From 512b74aaa19fa12a05774dd30205d2c97e8bdef9 Mon Sep 17 00:00:00 2001
From: Gael Guennebaud
Date: Mon, 18 Feb 2019 11:47:54 +0100
Subject: [PATCH] GEMM: catch all scalar-multiple variants when falling back
 to a coeff-based product.

Previously, only s*A*B was caught. This was inconsistent with GEMM,
sub-optimal, and could even lead to compilation errors
(https://stackoverflow.com/questions/54738495).
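As an illustration (editorial sketch, not part of the original message;
the sizes are hypothetical and assume the product is small enough at
runtime for GEMM to fall back to the coeff-based path):

    Eigen::MatrixXf A(4,4), B(4,4), C(4,4);
    float s1 = 2.f, s2 = 0.5f;
    C.noalias()  = (s1*A) * B;       // the only form caught before
    C.noalias() += A * (B*s2);       // now caught as well
    C.noalias() -= (s1*A) * (B*s2);  // evaluated as (s1*s2) * A.lazyProduct(B)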
---
 Eigen/src/Core/ProductEvaluators.h | 63 ++++++++++++++++++++----------
 Eigen/src/Core/util/BlasUtil.h     | 14 ++++++-
 test/product_notemporary.cpp       | 38 ++++++++++++++++++
 3 files changed, 92 insertions(+), 23 deletions(-)

diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h
index 27796315d..60b79b855 100644
--- a/Eigen/src/Core/ProductEvaluators.h
+++ b/Eigen/src/Core/ProductEvaluators.h
@@ -411,35 +411,56 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode>
     call_assignment_no_alias(dst, lhs.lazyProduct(rhs), internal::sub_assign_op<typename Dst::Scalar,Scalar>());
   }
 
-  // Catch "dst {,+,-}= (s*A)*B" and evaluate it lazily by moving out the scalar factor:
-  //  dst {,+,-}= s * (A.lazyProduct(B))
-  // This is a huge benefit for heap-allocated matrix types as it save one costly allocation.
-  // For them, this strategy is also faster than simply by-passing the heap allocation through
-  // stack allocation.
-  // For fixed sizes matrices, this is less obvious, it is sometimes x2 faster, but sometimes x3 slower,
-  // and the behavior depends also a lot on the compiler... so let's be conservative and enable them for dynamic-size only,
-  // that is when coming from generic_product_impl<...,GemmProduct> in file GeneralMatrixMatrix.h
-  template<typename Dst, typename Scalar1, typename Scalar2, typename Plain1, typename Xpr2, typename Func>
+  // This is a special evaluation path called from generic_product_impl<...,GemmProduct> in file GeneralMatrixMatrix.h
+  // This variant tries to extract scalar multiples from both the LHS and RHS and factor them out. For instance:
+  //   dst {,+,-}= (s1*A)*(B*s2)
+  // will be rewritten as:
+  //   dst {,+,-}= (s1*s2) * (A.lazyProduct(B))
+  // There are at least four benefits of doing so:
+  //  1 - huge performance gain for heap-allocated matrix types as it saves costly allocations.
+  //  2 - it is faster than simply bypassing the heap allocation through stack allocation.
+  //  3 - it makes this fallback consistent with the heavy GEMM routine.
+  //  4 - it fully bypasses huge stack allocation attempts when multiplying huge fixed-size matrices.
+  //      (see https://stackoverflow.com/questions/54738495)
+  // For small fixed-size matrices, however, the gains are less obvious: it is sometimes x2 faster, but sometimes x3 slower,
+  // and the behavior also depends a lot on the compiler... This is why this rewriting strategy is currently
+  // enabled only when falling back from the main GEMM.
+  template<typename Dst, typename Func>
   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
-  void eval_dynamic(Dst& dst, const CwiseBinaryOp<internal::scalar_product_op<Scalar1,Scalar2>,
-                    const CwiseNullaryOp<internal::scalar_constant_op<Scalar1>, Plain1>, Xpr2>& lhs, const Rhs& rhs, const Func &func)
+  void eval_dynamic(Dst& dst, const Lhs& lhs, const Rhs& rhs, const Func &func)
   {
-    call_restricted_packet_assignment_no_alias(dst, lhs.lhs().functor().m_other * lhs.rhs().lazyProduct(rhs), func);
+    enum {
+      HasScalarFactor = blas_traits<Lhs>::HasScalarFactor || blas_traits<Rhs>::HasScalarFactor
+    };
+    // FIXME: in C++11 this should be auto, and extractScalarFactor should also return auto;
+    //        this is important for real*complex_mat
+    Scalar actualAlpha = blas_traits<Lhs>::extractScalarFactor(lhs)
+                       * blas_traits<Rhs>::extractScalarFactor(rhs);
+    eval_dynamic_impl(dst,
+                      blas_traits<Lhs>::extract(lhs),
+                      blas_traits<Rhs>::extract(rhs),
+                      func,
+                      actualAlpha,
+                      typename conditional<HasScalarFactor,true_type,false_type>::type());
+
+
   }
 
-  // Here, we we always have LhsT==Lhs, but we need to make it a template type to make the above
-  // overload more specialized.
-  template<typename Dst, typename LhsT, typename Func>
+protected:
+
+  template<typename Dst, typename LhsT, typename RhsT, typename Func, typename Scalar>
   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
-  void eval_dynamic(Dst& dst, const LhsT& lhs, const Rhs& rhs, const Func &func)
+  void eval_dynamic_impl(Dst& dst, const LhsT& lhs, const RhsT& rhs, const Func &func, const Scalar& /* s == 1 */, false_type)
   {
     call_restricted_packet_assignment_no_alias(dst, lhs.lazyProduct(rhs), func);
   }
-
-
-//   template<typename Dst>
-//   static inline void scaleAndAddTo(Dst& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
-//   { dst.noalias() += alpha * lhs.lazyProduct(rhs); }
+
+  template<typename Dst, typename LhsT, typename RhsT, typename Func, typename Scalar>
+  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+  void eval_dynamic_impl(Dst& dst, const LhsT& lhs, const RhsT& rhs, const Func &func, const Scalar& s, true_type)
+  {
+    call_restricted_packet_assignment_no_alias(dst, s * lhs.lazyProduct(rhs), func);
+  }
 };
 
 // This specialization enforces the use of a coefficient-based evaluation strategy
diff --git a/Eigen/src/Core/util/BlasUtil.h b/Eigen/src/Core/util/BlasUtil.h
index a32630ed7..bc0a01540 100755
--- a/Eigen/src/Core/util/BlasUtil.h
+++ b/Eigen/src/Core/util/BlasUtil.h
@@ -274,7 +274,8 @@ template<typename XprType> struct blas_traits
   enum {
     HasUsableDirectAccess = (    (int(XprType::Flags)&DirectAccessBit)
                               && (   bool(XprType::IsVectorAtCompileTime)
                                   || int(inner_stride_at_compile_time<XprType>::ret) == 1)
-                             ) ?  1 : 0
+                             ) ?  1 : 0,
+    HasScalarFactor = false
   };
   typedef typename conditional<bool(HasUsableDirectAccess),
@@ -305,6 +306,9 @@ template<typename Scalar, typename NestedXpr, typename Plain>
 struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> >
  : blas_traits<NestedXpr>
 {
+  enum {
+    HasScalarFactor = true
+  };
   typedef blas_traits<NestedXpr> Base;
   typedef CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> XprType;
   typedef typename Base::ExtractType ExtractType;
@@ -317,6 +321,9 @@ template<typename Scalar, typename NestedXpr, typename Plain>
 struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > >
  : blas_traits<NestedXpr>
 {
+  enum {
+    HasScalarFactor = true
+  };
   typedef blas_traits<NestedXpr> Base;
   typedef CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > XprType;
   typedef typename Base::ExtractType ExtractType;
@@ -335,6 +342,9 @@ template<typename Scalar, typename NestedXpr>
 struct blas_traits<CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> >
  : blas_traits<NestedXpr>
 {
+  enum {
+    HasScalarFactor = true
+  };
   typedef blas_traits<NestedXpr> Base;
   typedef CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> XprType;
   typedef typename Base::ExtractType ExtractType;
@@ -358,7 +368,7 @@ struct blas_traits<Transpose<NestedXpr> >
     typename ExtractType::PlainObject
     >::type DirectLinearAccessType;
   enum {
-    IsTransposed = Base::IsTransposed ? 0 : 1
+    IsTransposed = Base::IsTransposed ? 0 : 1,
   };
   static inline ExtractType extract(const XprType& x) { return ExtractType(Base::extract(x.nestedExpression())); }
   static inline Scalar extractScalarFactor(const XprType& x) { return Base::extractScalarFactor(x.nestedExpression()); }
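
// Editorial sketch, not part of the patch: how the pieces above cooperate for
// an expression xpr of (hypothetical) type Xpr built as s1*A*s2. The new
// HasScalarFactor flag is what eval_dynamic dispatches on:
//
//   internal::blas_traits<Xpr>::HasScalarFactor           // true at compile time
//   internal::blas_traits<Xpr>::extract(xpr)              // the bare matrix A
//   internal::blas_traits<Xpr>::extractScalarFactor(xpr)  // s1*s2
//
// With a scalar factor present, eval_dynamic selects the true_type overload of
// eval_dynamic_impl and assigns s * A.lazyProduct(B); otherwise the false_type
// overload is selected and the multiply by s == 1 is skipped entirely.
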
diff --git a/test/product_notemporary.cpp b/test/product_notemporary.cpp
index dffb07608..7f169e6ae 100644
--- a/test/product_notemporary.cpp
+++ b/test/product_notemporary.cpp
@@ -11,6 +11,35 @@
 
 #include "main.h"
 
+template<typename Dst, typename Lhs, typename Rhs>
+void check_scalar_multiple3(Dst &dst, const Lhs& A, const Rhs& B)
+{
+  VERIFY_EVALUATION_COUNT( (dst.noalias() = A * B), 0);
+  VERIFY_IS_APPROX( dst, (A.eval() * B.eval()).eval() );
+  VERIFY_EVALUATION_COUNT( (dst.noalias() += A * B), 0);
+  VERIFY_IS_APPROX( dst, 2*(A.eval() * B.eval()).eval() );
+  VERIFY_EVALUATION_COUNT( (dst.noalias() -= A * B), 0);
+  VERIFY_IS_APPROX( dst, (A.eval() * B.eval()).eval() );
+}
+
+template<typename Dst, typename Lhs, typename Rhs, typename S2>
+void check_scalar_multiple2(Dst &dst, const Lhs& A, const Rhs& B, S2 s2)
+{
+  CALL_SUBTEST( check_scalar_multiple3(dst, A,    B) );
+  CALL_SUBTEST( check_scalar_multiple3(dst, A,   -B) );
+  CALL_SUBTEST( check_scalar_multiple3(dst, A, s2*B) );
+  CALL_SUBTEST( check_scalar_multiple3(dst, A, B*s2) );
+}
+
+template<typename Dst, typename Lhs, typename Rhs, typename S1, typename S2>
+void check_scalar_multiple1(Dst &dst, const Lhs& A, const Rhs& B, S1 s1, S2 s2)
+{
+  CALL_SUBTEST( check_scalar_multiple2(dst, A,    B, s2) );
+  CALL_SUBTEST( check_scalar_multiple2(dst, -A,   B, s2) );
+  CALL_SUBTEST( check_scalar_multiple2(dst, s1*A, B, s2) );
+  CALL_SUBTEST( check_scalar_multiple2(dst, A*s1, B, s2) );
+}
+
 template<typename MatrixType> void product_notemporary(const MatrixType& m)
 {
   /* This test checks the number of temporaries created
@@ -148,6 +177,15 @@ template<typename MatrixType> void product_notemporary(const MatrixType& m)
   // Check nested products
   VERIFY_EVALUATION_COUNT( cvres.noalias() = m1.adjoint() * m1 * cv1, 1 );
   VERIFY_EVALUATION_COUNT( rvres.noalias() = rv1 * (m1 * m2.adjoint()), 1 );
+
+  // Exhaustively check all scalar-multiple combinations:
+  {
+    // Generic path:
+    check_scalar_multiple1(m3, m1, m2, s1, s2);
+    // Force a fallback to the coeff-based path:
+    typename ColMajorMatrixType::BlockXpr m3_blck = m3.block(r0,r0,1,1);
+    check_scalar_multiple1(m3_blck, m1.block(r0,c0,1,1), m2.block(c0,r0,1,1), s1, s2);
+  }
 }
 
 EIGEN_DECLARE_TEST(product_notemporary)
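
Editorial usage note, not part of the patch: a minimal sketch of the failure
mode behind the Stack Overflow link above. The sizes are hypothetical; the
point is a fixed-size product whose operands are too big for a stack temporary
(Eigen's default EIGEN_STACK_ALLOCATION_LIMIT is 128kB):

    #include <Eigen/Dense>

    int main()
    {
      // ~180kB per matrix; static storage so the named objects themselves
      // do not overflow the stack.
      static Eigen::Matrix<double,150,150> A, B, C;
      A.setRandom();
      B.setRandom();
      // A scalar-multiple variant that was not caught before this patch:
      // the coeff-based fallback could end up evaluating (0.5*B) into a
      // fixed-size temporary, hitting the "object allocated on stack is
      // too big" static assertion. The factor is now extracted and the
      // product is evaluated as 0.5 * A.lazyProduct(B), with no temporary.
      C.noalias() = A * (0.5 * B);
      return 0;
    }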