fix trmv for Strictly* triangular matrices and trapezoidal matrices

This commit is contained in:
Gael Guennebaud 2011-03-28 17:42:26 +02:00
parent 568478ffe5
commit 6feb1d3c0b

View File

@ -36,12 +36,16 @@ struct product_triangular_matrix_vector<Index,Mode,LhsScalar,ConjLhs,RhsScalar,C
typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
enum {
IsLower = ((Mode&Lower)==Lower),
HasUnitDiag = (Mode & UnitDiag)==UnitDiag
HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
};
static EIGEN_DONT_INLINE void run(Index rows, Index cols, const LhsScalar* _lhs, Index lhsStride,
static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, ResScalar alpha)
{
static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
Index size = std::min(_rows,_cols);
Index rows = IsLower ? _rows : std::min(_rows,_cols);
Index cols = IsLower ? std::min(_rows,_cols) : _cols;
typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
@ -54,20 +58,20 @@ struct product_triangular_matrix_vector<Index,Mode,LhsScalar,ConjLhs,RhsScalar,C
typedef Map<Matrix<ResScalar,Dynamic,1> > ResMap;
ResMap res(_res,rows);
for (Index pi=0; pi<cols; pi+=PanelWidth)
for (Index pi=0; pi<size; pi+=PanelWidth)
{
Index actualPanelWidth = std::min(PanelWidth, cols-pi);
Index actualPanelWidth = std::min(PanelWidth, size-pi);
for (Index k=0; k<actualPanelWidth; ++k)
{
Index i = pi + k;
Index s = IsLower ? (HasUnitDiag ? i+1 : i ) : pi;
Index s = IsLower ? ((HasUnitDiag||HasZeroDiag) ? i+1 : i ) : pi;
Index r = IsLower ? actualPanelWidth-k : k+1;
if ((!HasUnitDiag) || (--r)>0)
if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
res.segment(s,r) += (alpha * cjRhs.coeff(i)) * cjLhs.col(i).segment(s,r);
if (HasUnitDiag)
res.coeffRef(i) += alpha * cjRhs.coeff(i);
}
Index r = IsLower ? cols - pi - actualPanelWidth : pi;
Index r = IsLower ? rows - pi - actualPanelWidth : pi;
if (r>0)
{
Index s = IsLower ? pi+actualPanelWidth : 0;
@ -78,6 +82,14 @@ struct product_triangular_matrix_vector<Index,Mode,LhsScalar,ConjLhs,RhsScalar,C
&res.coeffRef(s), resIncr, alpha);
}
}
if((!IsLower) && cols>size)
{
general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjLhs,RhsScalar,ConjRhs>::run(
rows, cols-size,
&lhs.coeffRef(0,size), lhsStride,
&rhs.coeffRef(size), rhsIncr,
_res, resIncr, alpha);
}
}
};
@ -87,12 +99,16 @@ struct product_triangular_matrix_vector<Index,Mode,LhsScalar,ConjLhs,RhsScalar,C
typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
enum {
IsLower = ((Mode&Lower)==Lower),
HasUnitDiag = (Mode & UnitDiag)==UnitDiag
HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
};
static void run(Index rows, Index cols, const LhsScalar* _lhs, Index lhsStride,
static void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, ResScalar alpha)
{
static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
Index diagSize = std::min(_rows,_cols);
Index rows = IsLower ? _rows : diagSize;
Index cols = IsLower ? diagSize : _cols;
typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
@ -105,15 +121,15 @@ struct product_triangular_matrix_vector<Index,Mode,LhsScalar,ConjLhs,RhsScalar,C
typedef Map<Matrix<ResScalar,Dynamic,1>, 0, InnerStride<> > ResMap;
ResMap res(_res,rows,InnerStride<>(resIncr));
for (Index pi=0; pi<cols; pi+=PanelWidth)
for (Index pi=0; pi<diagSize; pi+=PanelWidth)
{
Index actualPanelWidth = std::min(PanelWidth, cols-pi);
Index actualPanelWidth = std::min(PanelWidth, diagSize-pi);
for (Index k=0; k<actualPanelWidth; ++k)
{
Index i = pi + k;
Index s = IsLower ? pi : (HasUnitDiag ? i+1 : i);
Index s = IsLower ? pi : ((HasUnitDiag||HasZeroDiag) ? i+1 : i);
Index r = IsLower ? k+1 : actualPanelWidth-k;
if ((!HasUnitDiag) || (--r)>0)
if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
res.coeffRef(i) += alpha * (cjLhs.row(i).segment(s,r).cwiseProduct(cjRhs.segment(s,r).transpose())).sum();
if (HasUnitDiag)
res.coeffRef(i) += alpha * cjRhs.coeff(i);
@ -129,6 +145,14 @@ struct product_triangular_matrix_vector<Index,Mode,LhsScalar,ConjLhs,RhsScalar,C
&res.coeffRef(pi), resIncr, alpha);
}
}
if(IsLower && rows>diagSize)
{
general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjLhs,RhsScalar,ConjRhs>::run(
rows-diagSize, cols,
&lhs.coeffRef(diagSize,0), lhsStride,
&rhs.coeffRef(0), rhsIncr,
&res.coeffRef(diagSize), resIncr, alpha);
}
}
};
@ -180,7 +204,7 @@ struct TriangularProduct<Mode,false,Lhs,true,Rhs,false>
{
eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
typedef TriangularProduct<(Mode & UnitDiag) | ((Mode & Lower) ? Upper : Lower),true,Transpose<const Rhs>,false,Transpose<const Lhs>,true> TriangularProductTranspose;
typedef TriangularProduct<(Mode & (UnitDiag|ZeroDiag)) | ((Mode & Lower) ? Upper : Lower),true,Transpose<const Rhs>,false,Transpose<const Lhs>,true> TriangularProductTranspose;
Transpose<Dest> dstT(dst);
internal::trmv_selector<(int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>::run(
TriangularProductTranspose(m_rhs.transpose(),m_lhs.transpose()), dstT, alpha);