bug #1741: fix SelfAdjointView::rankUpdate and product to triangular part for destination with non-trivial inner stride

This commit is contained in:
Gael Guennebaud 2019-09-10 23:29:52 +02:00
parent ea0d5dc956
commit c06e6fd115
4 changed files with 58 additions and 31 deletions

View File

@ -25,51 +25,54 @@ namespace internal {
**********************************************************************/ **********************************************************************/
// forward declarations (defined at the end of this file) // forward declarations (defined at the end of this file)
template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjLhs, bool ConjRhs, int UpLo> template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjLhs, bool ConjRhs, int ResInnerStride, int UpLo>
struct tribb_kernel; struct tribb_kernel;
/* Optimized matrix-matrix product evaluating only one triangular half */ /* Optimized matrix-matrix product evaluating only one triangular half */
template <typename Index, template <typename Index,
typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs, typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
int ResStorageOrder, int UpLo, int Version = Specialized> int ResStorageOrder, int ResInnerStride, int UpLo, int Version = Specialized>
struct general_matrix_matrix_triangular_product; struct general_matrix_matrix_triangular_product;
// as usual if the result is row major => we transpose the product // as usual if the result is row major => we transpose the product
template <typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, template <typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs, int UpLo, int Version> typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,RowMajor,UpLo,Version> int ResInnerStride, int UpLo, int Version>
struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,RowMajor,ResInnerStride,UpLo,Version>
{ {
typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar; typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
static EIGEN_STRONG_INLINE void run(Index size, Index depth,const LhsScalar* lhs, Index lhsStride, static EIGEN_STRONG_INLINE void run(Index size, Index depth,const LhsScalar* lhs, Index lhsStride,
const RhsScalar* rhs, Index rhsStride, ResScalar* res, Index resStride, const RhsScalar* rhs, Index rhsStride, ResScalar* res, Index resIncr, Index resStride,
const ResScalar& alpha, level3_blocking<RhsScalar,LhsScalar>& blocking) const ResScalar& alpha, level3_blocking<RhsScalar,LhsScalar>& blocking)
{ {
general_matrix_matrix_triangular_product<Index, general_matrix_matrix_triangular_product<Index,
RhsScalar, RhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateRhs, RhsScalar, RhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateRhs,
LhsScalar, LhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateLhs, LhsScalar, LhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateLhs,
ColMajor, UpLo==Lower?Upper:Lower> ColMajor, ResInnerStride, UpLo==Lower?Upper:Lower>
::run(size,depth,rhs,rhsStride,lhs,lhsStride,res,resStride,alpha,blocking); ::run(size,depth,rhs,rhsStride,lhs,lhsStride,res,resIncr,resStride,alpha,blocking);
} }
}; };
template <typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, template <typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs, int UpLo, int Version> typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor,UpLo,Version> int ResInnerStride, int UpLo, int Version>
struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride,UpLo,Version>
{ {
typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar; typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
static EIGEN_STRONG_INLINE void run(Index size, Index depth,const LhsScalar* _lhs, Index lhsStride, static EIGEN_STRONG_INLINE void run(Index size, Index depth,const LhsScalar* _lhs, Index lhsStride,
const RhsScalar* _rhs, Index rhsStride, ResScalar* _res, Index resStride, const RhsScalar* _rhs, Index rhsStride,
ResScalar* _res, Index resIncr, Index resStride,
const ResScalar& alpha, level3_blocking<LhsScalar,RhsScalar>& blocking) const ResScalar& alpha, level3_blocking<LhsScalar,RhsScalar>& blocking)
{ {
typedef gebp_traits<LhsScalar,RhsScalar> Traits; typedef gebp_traits<LhsScalar,RhsScalar> Traits;
typedef const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> LhsMapper; typedef const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> LhsMapper;
typedef const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> RhsMapper; typedef const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> RhsMapper;
typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor> ResMapper; typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
LhsMapper lhs(_lhs,lhsStride); LhsMapper lhs(_lhs,lhsStride);
RhsMapper rhs(_rhs,rhsStride); RhsMapper rhs(_rhs,rhsStride);
ResMapper res(_res, resStride); ResMapper res(_res, resStride, resIncr);
Index kc = blocking.kc(); Index kc = blocking.kc();
Index mc = (std::min)(size,blocking.mc()); Index mc = (std::min)(size,blocking.mc());
@ -87,7 +90,7 @@ struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,
gemm_pack_lhs<LhsScalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, LhsStorageOrder> pack_lhs; gemm_pack_lhs<LhsScalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, LhsStorageOrder> pack_lhs;
gemm_pack_rhs<RhsScalar, Index, RhsMapper, Traits::nr, RhsStorageOrder> pack_rhs; gemm_pack_rhs<RhsScalar, Index, RhsMapper, Traits::nr, RhsStorageOrder> pack_rhs;
gebp_kernel<LhsScalar, RhsScalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp; gebp_kernel<LhsScalar, RhsScalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp;
tribb_kernel<LhsScalar, RhsScalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs, UpLo> sybb; tribb_kernel<LhsScalar, RhsScalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs, ResInnerStride, UpLo> sybb;
for(Index k2=0; k2<depth; k2+=kc) for(Index k2=0; k2<depth; k2+=kc)
{ {
@ -110,7 +113,7 @@ struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,
gebp(res.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc, gebp(res.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc,
(std::min)(size,i2), alpha, -1, -1, 0, 0); (std::min)(size,i2), alpha, -1, -1, 0, 0);
sybb(_res+resStride*i2 + i2, resStride, blockA, blockB + actual_kc*i2, actual_mc, actual_kc, alpha); sybb(_res+resStride*i2 + resIncr*i2, resIncr, resStride, blockA, blockB + actual_kc*i2, actual_mc, actual_kc, alpha);
if (UpLo==Upper) if (UpLo==Upper)
{ {
@ -132,7 +135,7 @@ struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,
// while the triangular block overlapping the diagonal is evaluated into a // while the triangular block overlapping the diagonal is evaluated into a
// small temporary buffer which is then accumulated into the result using a // small temporary buffer which is then accumulated into the result using a
// triangular traversal. // triangular traversal.
template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjLhs, bool ConjRhs, int UpLo> template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjLhs, bool ConjRhs, int ResInnerStride, int UpLo>
struct tribb_kernel struct tribb_kernel
{ {
typedef gebp_traits<LhsScalar,RhsScalar,ConjLhs,ConjRhs> Traits; typedef gebp_traits<LhsScalar,RhsScalar,ConjLhs,ConjRhs> Traits;
@ -141,11 +144,13 @@ struct tribb_kernel
enum { enum {
BlockSize = meta_least_common_multiple<EIGEN_PLAIN_ENUM_MAX(mr,nr),EIGEN_PLAIN_ENUM_MIN(mr,nr)>::ret BlockSize = meta_least_common_multiple<EIGEN_PLAIN_ENUM_MAX(mr,nr),EIGEN_PLAIN_ENUM_MIN(mr,nr)>::ret
}; };
void operator()(ResScalar* _res, Index resStride, const LhsScalar* blockA, const RhsScalar* blockB, Index size, Index depth, const ResScalar& alpha) void operator()(ResScalar* _res, Index resIncr, Index resStride, const LhsScalar* blockA, const RhsScalar* blockB, Index size, Index depth, const ResScalar& alpha)
{ {
typedef blas_data_mapper<ResScalar, Index, ColMajor> ResMapper; typedef blas_data_mapper<ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
ResMapper res(_res, resStride); typedef blas_data_mapper<ResScalar, Index, ColMajor, Unaligned> BufferMapper;
gebp_kernel<LhsScalar, RhsScalar, Index, ResMapper, mr, nr, ConjLhs, ConjRhs> gebp_kernel; ResMapper res(_res, resStride, resIncr);
gebp_kernel<LhsScalar, RhsScalar, Index, ResMapper, mr, nr, ConjLhs, ConjRhs> gebp_kernel1;
gebp_kernel<LhsScalar, RhsScalar, Index, BufferMapper, mr, nr, ConjLhs, ConjRhs> gebp_kernel2;
Matrix<ResScalar,BlockSize,BlockSize,ColMajor> buffer((internal::constructor_without_unaligned_array_assert())); Matrix<ResScalar,BlockSize,BlockSize,ColMajor> buffer((internal::constructor_without_unaligned_array_assert()));
@ -157,7 +162,7 @@ struct tribb_kernel
const RhsScalar* actual_b = blockB+j*depth; const RhsScalar* actual_b = blockB+j*depth;
if(UpLo==Upper) if(UpLo==Upper)
gebp_kernel(res.getSubMapper(0, j), blockA, actual_b, j, depth, actualBlockSize, alpha, gebp_kernel1(res.getSubMapper(0, j), blockA, actual_b, j, depth, actualBlockSize, alpha,
-1, -1, 0, 0); -1, -1, 0, 0);
// selfadjoint micro block // selfadjoint micro block
@ -165,23 +170,23 @@ struct tribb_kernel
Index i = j; Index i = j;
buffer.setZero(); buffer.setZero();
// 1 - apply the kernel on the temporary buffer // 1 - apply the kernel on the temporary buffer
gebp_kernel(ResMapper(buffer.data(), BlockSize), blockA+depth*i, actual_b, actualBlockSize, depth, actualBlockSize, alpha, gebp_kernel2(BufferMapper(buffer.data(), BlockSize), blockA+depth*i, actual_b, actualBlockSize, depth, actualBlockSize, alpha,
-1, -1, 0, 0); -1, -1, 0, 0);
// 2 - triangular accumulation // 2 - triangular accumulation
for(Index j1=0; j1<actualBlockSize; ++j1) for(Index j1=0; j1<actualBlockSize; ++j1)
{ {
ResScalar* r = &res(i, j + j1); typename ResMapper::LinearMapper r = res.getLinearMapper(i,j+j1);
for(Index i1=UpLo==Lower ? j1 : 0; for(Index i1=UpLo==Lower ? j1 : 0;
UpLo==Lower ? i1<actualBlockSize : i1<=j1; ++i1) UpLo==Lower ? i1<actualBlockSize : i1<=j1; ++i1)
r[i1] += buffer(i1,j1); r(i1) += buffer(i1,j1);
} }
} }
if(UpLo==Lower) if(UpLo==Lower)
{ {
Index i = j+actualBlockSize; Index i = j+actualBlockSize;
gebp_kernel(res.getSubMapper(i, j), blockA+depth*i, actual_b, size-i, gebp_kernel1(res.getSubMapper(i, j), blockA+depth*i, actual_b, size-i,
depth, actualBlockSize, alpha, -1, -1, 0, 0); depth, actualBlockSize, alpha, -1, -1, 0, 0);
} }
} }
@ -286,11 +291,12 @@ struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,false>
internal::general_matrix_matrix_triangular_product<Index, internal::general_matrix_matrix_triangular_product<Index,
typename Lhs::Scalar, LhsIsRowMajor ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate, typename Lhs::Scalar, LhsIsRowMajor ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate,
typename Rhs::Scalar, RhsIsRowMajor ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate, typename Rhs::Scalar, RhsIsRowMajor ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate,
IsRowMajor ? RowMajor : ColMajor, UpLo&(Lower|Upper)> IsRowMajor ? RowMajor : ColMajor, MatrixType::InnerStrideAtCompileTime, UpLo&(Lower|Upper)>
::run(size, depth, ::run(size, depth,
&actualLhs.coeffRef(SkipDiag&&(UpLo&Lower)==Lower ? 1 : 0,0), actualLhs.outerStride(), &actualLhs.coeffRef(SkipDiag&&(UpLo&Lower)==Lower ? 1 : 0,0), actualLhs.outerStride(),
&actualRhs.coeffRef(0,SkipDiag&&(UpLo&Upper)==Upper ? 1 : 0), actualRhs.outerStride(), &actualRhs.coeffRef(0,SkipDiag&&(UpLo&Upper)==Upper ? 1 : 0), actualRhs.outerStride(),
mat.data() + (SkipDiag ? (bool(IsRowMajor) != ((UpLo&Lower)==Lower) ? 1 : mat.outerStride() ) : 0), mat.outerStride(), actualAlpha, blocking); mat.data() + (SkipDiag ? (bool(IsRowMajor) != ((UpLo&Lower)==Lower) ? mat.innerStride() : mat.outerStride() ) : 0),
mat.innerStride(), mat.outerStride(), actualAlpha, blocking);
} }
}; };

View File

@ -109,10 +109,10 @@ struct selfadjoint_product_selector<MatrixType,OtherType,UpLo,false>
internal::general_matrix_matrix_triangular_product<Index, internal::general_matrix_matrix_triangular_product<Index,
Scalar, OtherIsRowMajor ? RowMajor : ColMajor, OtherBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex, Scalar, OtherIsRowMajor ? RowMajor : ColMajor, OtherBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex,
Scalar, OtherIsRowMajor ? ColMajor : RowMajor, (!OtherBlasTraits::NeedToConjugate) && NumTraits<Scalar>::IsComplex, Scalar, OtherIsRowMajor ? ColMajor : RowMajor, (!OtherBlasTraits::NeedToConjugate) && NumTraits<Scalar>::IsComplex,
IsRowMajor ? RowMajor : ColMajor, UpLo> IsRowMajor ? RowMajor : ColMajor, MatrixType::InnerStrideAtCompileTime, UpLo>
::run(size, depth, ::run(size, depth,
&actualOther.coeffRef(0,0), actualOther.outerStride(), &actualOther.coeffRef(0,0), actualOther.outerStride(), &actualOther.coeffRef(0,0), actualOther.outerStride(), &actualOther.coeffRef(0,0), actualOther.outerStride(),
mat.data(), mat.outerStride(), actualAlpha, blocking); mat.data(), mat.innerStride(), mat.outerStride(), actualAlpha, blocking);
} }
}; };

View File

@ -82,6 +82,16 @@ template<typename Scalar> void mmtr(int size)
ref2.template triangularView<Lower>() = ref1.template triangularView<Lower>(); ref2.template triangularView<Lower>() = ref1.template triangularView<Lower>();
matc.template triangularView<Lower>() = sqc * matc * sqc.adjoint(); matc.template triangularView<Lower>() = sqc * matc * sqc.adjoint();
VERIFY_IS_APPROX(matc, ref2); VERIFY_IS_APPROX(matc, ref2);
// destination with a non-default inner-stride
// see bug 1741
{
typedef Matrix<Scalar,Dynamic,Dynamic> MatrixX;
MatrixX buffer(2*size,2*size);
Map<MatrixColMaj,0,Stride<Dynamic,Dynamic> > map1(buffer.data(),size,size,Stride<Dynamic,Dynamic>(2*size,2));
buffer.setZero();
CHECK_MMTR(map1, Lower, = s*soc*sor.adjoint());
}
} }
EIGEN_DECLARE_TEST(product_mmtr) EIGEN_DECLARE_TEST(product_mmtr)

View File

@ -115,6 +115,17 @@ template<typename MatrixType> void syrk(const MatrixType& m)
m2.setZero(); m2.setZero();
VERIFY_IS_APPROX((m2.template selfadjointView<Upper>().rankUpdate(m1.row(c).adjoint(),s1)._expression()), VERIFY_IS_APPROX((m2.template selfadjointView<Upper>().rankUpdate(m1.row(c).adjoint(),s1)._expression()),
((s1 * m1.row(c).adjoint() * m1.row(c).adjoint().adjoint()).eval().template triangularView<Upper>().toDenseMatrix())); ((s1 * m1.row(c).adjoint() * m1.row(c).adjoint().adjoint()).eval().template triangularView<Upper>().toDenseMatrix()));
// destination with a non-default inner-stride
// see bug 1741
{
typedef Matrix<Scalar,Dynamic,Dynamic> MatrixX;
MatrixX buffer(2*rows,2*cols);
Map<MatrixType,0,Stride<Dynamic,2> > map1(buffer.data(),rows,cols,Stride<Dynamic,2>(2*rows,2));
buffer.setZero();
VERIFY_IS_APPROX((map1.template selfadjointView<Lower>().rankUpdate(rhs2,s1)._expression()),
((s1 * rhs2 * rhs2.adjoint()).eval().template triangularView<Lower>().toDenseMatrix()));
}
} }
EIGEN_DECLARE_TEST(product_syrk) EIGEN_DECLARE_TEST(product_syrk)