mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-11 19:29:02 +08:00
fix trmv for Strictly* triangular matrices and trapezoidal matrices
This commit is contained in:
parent
568478ffe5
commit
6feb1d3c0b
@ -36,12 +36,16 @@ struct product_triangular_matrix_vector<Index,Mode,LhsScalar,ConjLhs,RhsScalar,C
|
||||
typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
|
||||
enum {
|
||||
IsLower = ((Mode&Lower)==Lower),
|
||||
HasUnitDiag = (Mode & UnitDiag)==UnitDiag
|
||||
HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
|
||||
HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
|
||||
};
|
||||
static EIGEN_DONT_INLINE void run(Index rows, Index cols, const LhsScalar* _lhs, Index lhsStride,
|
||||
static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
|
||||
const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, ResScalar alpha)
|
||||
{
|
||||
static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
|
||||
Index size = std::min(_rows,_cols);
|
||||
Index rows = IsLower ? _rows : std::min(_rows,_cols);
|
||||
Index cols = IsLower ? std::min(_rows,_cols) : _cols;
|
||||
|
||||
typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
|
||||
const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
|
||||
@ -54,20 +58,20 @@ struct product_triangular_matrix_vector<Index,Mode,LhsScalar,ConjLhs,RhsScalar,C
|
||||
typedef Map<Matrix<ResScalar,Dynamic,1> > ResMap;
|
||||
ResMap res(_res,rows);
|
||||
|
||||
for (Index pi=0; pi<cols; pi+=PanelWidth)
|
||||
for (Index pi=0; pi<size; pi+=PanelWidth)
|
||||
{
|
||||
Index actualPanelWidth = std::min(PanelWidth, cols-pi);
|
||||
Index actualPanelWidth = std::min(PanelWidth, size-pi);
|
||||
for (Index k=0; k<actualPanelWidth; ++k)
|
||||
{
|
||||
Index i = pi + k;
|
||||
Index s = IsLower ? (HasUnitDiag ? i+1 : i ) : pi;
|
||||
Index s = IsLower ? ((HasUnitDiag||HasZeroDiag) ? i+1 : i ) : pi;
|
||||
Index r = IsLower ? actualPanelWidth-k : k+1;
|
||||
if ((!HasUnitDiag) || (--r)>0)
|
||||
if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
|
||||
res.segment(s,r) += (alpha * cjRhs.coeff(i)) * cjLhs.col(i).segment(s,r);
|
||||
if (HasUnitDiag)
|
||||
res.coeffRef(i) += alpha * cjRhs.coeff(i);
|
||||
}
|
||||
Index r = IsLower ? cols - pi - actualPanelWidth : pi;
|
||||
Index r = IsLower ? rows - pi - actualPanelWidth : pi;
|
||||
if (r>0)
|
||||
{
|
||||
Index s = IsLower ? pi+actualPanelWidth : 0;
|
||||
@ -78,6 +82,14 @@ struct product_triangular_matrix_vector<Index,Mode,LhsScalar,ConjLhs,RhsScalar,C
|
||||
&res.coeffRef(s), resIncr, alpha);
|
||||
}
|
||||
}
|
||||
if((!IsLower) && cols>size)
|
||||
{
|
||||
general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjLhs,RhsScalar,ConjRhs>::run(
|
||||
rows, cols-size,
|
||||
&lhs.coeffRef(0,size), lhsStride,
|
||||
&rhs.coeffRef(size), rhsIncr,
|
||||
_res, resIncr, alpha);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@ -87,12 +99,16 @@ struct product_triangular_matrix_vector<Index,Mode,LhsScalar,ConjLhs,RhsScalar,C
|
||||
typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
|
||||
enum {
|
||||
IsLower = ((Mode&Lower)==Lower),
|
||||
HasUnitDiag = (Mode & UnitDiag)==UnitDiag
|
||||
HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
|
||||
HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
|
||||
};
|
||||
static void run(Index rows, Index cols, const LhsScalar* _lhs, Index lhsStride,
|
||||
static void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
|
||||
const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, ResScalar alpha)
|
||||
{
|
||||
static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
|
||||
Index diagSize = std::min(_rows,_cols);
|
||||
Index rows = IsLower ? _rows : diagSize;
|
||||
Index cols = IsLower ? diagSize : _cols;
|
||||
|
||||
typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
|
||||
const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
|
||||
@ -105,15 +121,15 @@ struct product_triangular_matrix_vector<Index,Mode,LhsScalar,ConjLhs,RhsScalar,C
|
||||
typedef Map<Matrix<ResScalar,Dynamic,1>, 0, InnerStride<> > ResMap;
|
||||
ResMap res(_res,rows,InnerStride<>(resIncr));
|
||||
|
||||
for (Index pi=0; pi<cols; pi+=PanelWidth)
|
||||
for (Index pi=0; pi<diagSize; pi+=PanelWidth)
|
||||
{
|
||||
Index actualPanelWidth = std::min(PanelWidth, cols-pi);
|
||||
Index actualPanelWidth = std::min(PanelWidth, diagSize-pi);
|
||||
for (Index k=0; k<actualPanelWidth; ++k)
|
||||
{
|
||||
Index i = pi + k;
|
||||
Index s = IsLower ? pi : (HasUnitDiag ? i+1 : i);
|
||||
Index s = IsLower ? pi : ((HasUnitDiag||HasZeroDiag) ? i+1 : i);
|
||||
Index r = IsLower ? k+1 : actualPanelWidth-k;
|
||||
if ((!HasUnitDiag) || (--r)>0)
|
||||
if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
|
||||
res.coeffRef(i) += alpha * (cjLhs.row(i).segment(s,r).cwiseProduct(cjRhs.segment(s,r).transpose())).sum();
|
||||
if (HasUnitDiag)
|
||||
res.coeffRef(i) += alpha * cjRhs.coeff(i);
|
||||
@ -129,6 +145,14 @@ struct product_triangular_matrix_vector<Index,Mode,LhsScalar,ConjLhs,RhsScalar,C
|
||||
&res.coeffRef(pi), resIncr, alpha);
|
||||
}
|
||||
}
|
||||
if(IsLower && rows>diagSize)
|
||||
{
|
||||
general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjLhs,RhsScalar,ConjRhs>::run(
|
||||
rows-diagSize, cols,
|
||||
&lhs.coeffRef(diagSize,0), lhsStride,
|
||||
&rhs.coeffRef(0), rhsIncr,
|
||||
&res.coeffRef(diagSize), resIncr, alpha);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@ -180,7 +204,7 @@ struct TriangularProduct<Mode,false,Lhs,true,Rhs,false>
|
||||
{
|
||||
eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
|
||||
|
||||
typedef TriangularProduct<(Mode & UnitDiag) | ((Mode & Lower) ? Upper : Lower),true,Transpose<const Rhs>,false,Transpose<const Lhs>,true> TriangularProductTranspose;
|
||||
typedef TriangularProduct<(Mode & (UnitDiag|ZeroDiag)) | ((Mode & Lower) ? Upper : Lower),true,Transpose<const Rhs>,false,Transpose<const Lhs>,true> TriangularProductTranspose;
|
||||
Transpose<Dest> dstT(dst);
|
||||
internal::trmv_selector<(int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>::run(
|
||||
TriangularProductTranspose(m_rhs.transpose(),m_lhs.transpose()), dstT, alpha);
|
||||
|
Loading…
x
Reference in New Issue
Block a user