mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-07-08 22:21:49 +08:00
fix trmm for some unusual trapezoidal cases (a dense set of columns or rows is zero)
(transplanted from 568478ffe5a82608ac0ce3b6a1f5eac551dc9543 )
This commit is contained in:
parent
55574053d0
commit
ffefe1bd2e
@ -96,24 +96,30 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,true,
|
|||||||
LhsStorageOrder,ConjugateLhs,
|
LhsStorageOrder,ConjugateLhs,
|
||||||
RhsStorageOrder,ConjugateRhs,ColMajor>
|
RhsStorageOrder,ConjugateRhs,ColMajor>
|
||||||
{
|
{
|
||||||
|
|
||||||
|
typedef gebp_traits<Scalar,Scalar> Traits;
|
||||||
|
enum {
|
||||||
|
SmallPanelWidth = EIGEN_PLAIN_ENUM_MAX(Traits::mr,Traits::nr),
|
||||||
|
IsLower = (Mode&Lower) == Lower,
|
||||||
|
SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1
|
||||||
|
};
|
||||||
|
|
||||||
static EIGEN_DONT_INLINE void run(
|
static EIGEN_DONT_INLINE void run(
|
||||||
Index rows, Index cols, Index depth,
|
Index _rows, Index _cols, Index _depth,
|
||||||
const Scalar* _lhs, Index lhsStride,
|
const Scalar* _lhs, Index lhsStride,
|
||||||
const Scalar* _rhs, Index rhsStride,
|
const Scalar* _rhs, Index rhsStride,
|
||||||
Scalar* res, Index resStride,
|
Scalar* res, Index resStride,
|
||||||
Scalar alpha)
|
Scalar alpha)
|
||||||
{
|
{
|
||||||
|
// strip zeros
|
||||||
|
Index diagSize = std::min(_rows,_depth);
|
||||||
|
Index rows = IsLower ? _rows : diagSize;
|
||||||
|
Index depth = IsLower ? diagSize : _depth;
|
||||||
|
Index cols = _cols;
|
||||||
|
|
||||||
const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride);
|
const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride);
|
||||||
const_blas_data_mapper<Scalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride);
|
const_blas_data_mapper<Scalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride);
|
||||||
|
|
||||||
typedef gebp_traits<Scalar,Scalar> Traits;
|
|
||||||
enum {
|
|
||||||
SmallPanelWidth = EIGEN_PLAIN_ENUM_MAX(Traits::mr,Traits::nr),
|
|
||||||
IsLower = (Mode&Lower) == Lower,
|
|
||||||
SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1
|
|
||||||
};
|
|
||||||
|
|
||||||
Index kc = depth; // cache block size along the K direction
|
Index kc = depth; // cache block size along the K direction
|
||||||
Index mc = rows; // cache block size along the M direction
|
Index mc = rows; // cache block size along the M direction
|
||||||
Index nc = cols; // cache block size along the N direction
|
Index nc = cols; // cache block size along the N direction
|
||||||
@ -152,10 +158,11 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,true,
|
|||||||
pack_rhs(blockB, &rhs(actual_k2,0), rhsStride, actual_kc, cols);
|
pack_rhs(blockB, &rhs(actual_k2,0), rhsStride, actual_kc, cols);
|
||||||
|
|
||||||
// the selected lhs's panel has to be split in three different parts:
|
// the selected lhs's panel has to be split in three different parts:
|
||||||
// 1 - the part which is above the diagonal block => skip it
|
// 1 - the part which is zero => skip it
|
||||||
// 2 - the diagonal block => special kernel
|
// 2 - the diagonal block => special kernel
|
||||||
// 3 - the panel below the diagonal block => GEPP
|
// 3 - the dense panel below (lower case) or above (upper case) the diagonal block => GEPP
|
||||||
// the block diagonal, if any
|
|
||||||
|
// the block diagonal, if any:
|
||||||
if(IsLower || actual_k2<rows)
|
if(IsLower || actual_k2<rows)
|
||||||
{
|
{
|
||||||
// for each small vertical panels of lhs
|
// for each small vertical panels of lhs
|
||||||
@ -193,7 +200,7 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,true,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// the part below the diagonal => GEPP
|
// the part below (lower case) or above (upper case) the diagonal => GEPP
|
||||||
{
|
{
|
||||||
Index start = IsLower ? k2 : 0;
|
Index start = IsLower ? k2 : 0;
|
||||||
Index end = IsLower ? rows : std::min(actual_k2,rows);
|
Index end = IsLower ? rows : std::min(actual_k2,rows);
|
||||||
@ -218,24 +225,29 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,false,
|
|||||||
LhsStorageOrder,ConjugateLhs,
|
LhsStorageOrder,ConjugateLhs,
|
||||||
RhsStorageOrder,ConjugateRhs,ColMajor>
|
RhsStorageOrder,ConjugateRhs,ColMajor>
|
||||||
{
|
{
|
||||||
|
typedef gebp_traits<Scalar,Scalar> Traits;
|
||||||
|
enum {
|
||||||
|
SmallPanelWidth = EIGEN_PLAIN_ENUM_MAX(Traits::mr,Traits::nr),
|
||||||
|
IsLower = (Mode&Lower) == Lower,
|
||||||
|
SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1
|
||||||
|
};
|
||||||
|
|
||||||
static EIGEN_DONT_INLINE void run(
|
static EIGEN_DONT_INLINE void run(
|
||||||
Index rows, Index cols, Index depth,
|
Index _rows, Index _cols, Index _depth,
|
||||||
const Scalar* _lhs, Index lhsStride,
|
const Scalar* _lhs, Index lhsStride,
|
||||||
const Scalar* _rhs, Index rhsStride,
|
const Scalar* _rhs, Index rhsStride,
|
||||||
Scalar* res, Index resStride,
|
Scalar* res, Index resStride,
|
||||||
Scalar alpha)
|
Scalar alpha)
|
||||||
{
|
{
|
||||||
|
// strip zeros
|
||||||
|
Index diagSize = std::min(_cols,_depth);
|
||||||
|
Index rows = _rows;
|
||||||
|
Index depth = IsLower ? _depth : diagSize;
|
||||||
|
Index cols = IsLower ? diagSize : _cols;
|
||||||
|
|
||||||
const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride);
|
const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride);
|
||||||
const_blas_data_mapper<Scalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride);
|
const_blas_data_mapper<Scalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride);
|
||||||
|
|
||||||
typedef gebp_traits<Scalar,Scalar> Traits;
|
|
||||||
enum {
|
|
||||||
SmallPanelWidth = EIGEN_PLAIN_ENUM_MAX(Traits::mr,Traits::nr),
|
|
||||||
IsLower = (Mode&Lower) == Lower,
|
|
||||||
SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1
|
|
||||||
};
|
|
||||||
|
|
||||||
Index kc = depth; // cache block size along the K direction
|
Index kc = depth; // cache block size along the K direction
|
||||||
Index mc = rows; // cache block size along the M direction
|
Index mc = rows; // cache block size along the M direction
|
||||||
Index nc = cols; // cache block size along the N direction
|
Index nc = cols; // cache block size along the N direction
|
||||||
|
Loading…
x
Reference in New Issue
Block a user