Mirror of https://gitlab.com/libeigen/eigen.git, synced 2025-07-22 21:04:28 +08:00.
GEMM: catch all scalar-multiple variants when falling back to a coeff-based product.
Previously, only s*A*B was caught, which was inconsistent with GEMM, sub-optimal, and could even lead to compilation errors (https://stackoverflow.com/questions/54738495).
This commit is contained in:
parent ec032ac03b, commit 512b74aaa1
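To make the scope of the fix concrete, here is a minimal sketch (not part of the commit) of the product forms that now all take the same lazy, scalar-factored path; the 4x4 size is arbitrary and assumes the small-size fallback from GEMM applies:

#include <Eigen/Dense>
using Eigen::MatrixXd;

int main()
{
  MatrixXd A = MatrixXd::Random(4,4), B = MatrixXd::Random(4,4), dst(4,4);
  double s1 = 2.0, s2 = 3.0;
  // All of these are now rewritten as dst {,+,-}= (s1*s2) * (A.lazyProduct(B)):
  dst.noalias()  = (s1*A) * B;       // the only form caught before this commit
  dst.noalias() += A * (B*s2);       // scalar factor on the right-hand side
  dst.noalias() -= (s1*A) * (B*s2);  // scalar factors on both sides
  dst.noalias()  = (-A) * B;         // a negation is a scalar factor too
  return 0;
}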
Eigen/src/Core/ProductEvaluators.h

@@ -411,35 +411,56 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode>
     call_assignment_no_alias(dst, lhs.lazyProduct(rhs), internal::sub_assign_op<typename Dst::Scalar,Scalar>());
   }
 
-  // Catch "dst {,+,-}= (s*A)*B" and evaluate it lazily by moving out the scalar factor:
-  //   dst {,+,-}= s * (A.lazyProduct(B))
-  // This is a huge benefit for heap-allocated matrix types as it saves one costly allocation.
-  // For them, this strategy is also faster than simply bypassing the heap allocation through
-  // stack allocation.
-  // For fixed-size matrices this is less obvious: it is sometimes x2 faster, but sometimes x3 slower,
-  // and the behavior also depends a lot on the compiler... so let's be conservative and enable it for dynamic sizes only,
-  // that is, when coming from generic_product_impl<...,GemmProduct> in file GeneralMatrixMatrix.h
-  template<typename Dst, typename Scalar1, typename Scalar2, typename Plain1, typename Xpr2, typename Func>
+  // This is a special evaluation path called from generic_product_impl<...,GemmProduct> in file GeneralMatrixMatrix.h
+  // This variant tries to extract scalar multiples from both the LHS and RHS and factor them out. For instance:
+  //   dst {,+,-}= (s1*A)*(B*s2)
+  // will be rewritten as:
+  //   dst {,+,-}= (s1*s2) * (A.lazyProduct(B))
+  // There are at least four benefits of doing so:
+  //  1 - it yields a huge performance gain for heap-allocated matrix types as it saves costly allocations.
+  //  2 - it is faster than simply bypassing the heap allocation through stack allocation.
+  //  3 - it makes this fallback consistent with the heavy GEMM routine.
+  //  4 - it fully bypasses huge stack-allocation attempts when multiplying huge fixed-size matrices.
+  //      (see https://stackoverflow.com/questions/54738495)
+  // For small fixed-size matrices, however, the gains are less obvious: it is sometimes x2 faster, but sometimes x3 slower,
+  // and the behavior also depends a lot on the compiler... This is why this rewriting strategy is currently
+  // enabled only when falling back from the main GEMM.
+  template<typename Dst, typename Func>
   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
-  void eval_dynamic(Dst& dst, const CwiseBinaryOp<internal::scalar_product_op<Scalar1,Scalar2>,
-                    const CwiseNullaryOp<internal::scalar_constant_op<Scalar1>, Plain1>, Xpr2>& lhs, const Rhs& rhs, const Func &func)
+  void eval_dynamic(Dst& dst, const Lhs& lhs, const Rhs& rhs, const Func &func)
   {
-    call_restricted_packet_assignment_no_alias(dst, lhs.lhs().functor().m_other * lhs.rhs().lazyProduct(rhs), func);
+    enum {
+      HasScalarFactor = blas_traits<Lhs>::HasScalarFactor || blas_traits<Rhs>::HasScalarFactor
+    };
+    // FIXME: in C++11 this should be auto, and extractScalarFactor should also return auto;
+    //        this is important for real*complex_mat
+    Scalar actualAlpha = blas_traits<Lhs>::extractScalarFactor(lhs)
+                       * blas_traits<Rhs>::extractScalarFactor(rhs);
+    eval_dynamic_impl(dst,
+                      blas_traits<Lhs>::extract(lhs),
+                      blas_traits<Rhs>::extract(rhs),
+                      func,
+                      actualAlpha,
+                      typename conditional<HasScalarFactor,true_type,false_type>::type());
   }
 
-  // Here, we always have LhsT==Lhs, but we need to make it a template type to make the above
-  // overload more specialized.
-  template<typename Dst, typename LhsT, typename Func>
+protected:
+
+  template<typename Dst, typename LhsT, typename RhsT, typename Func, typename Scalar>
   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
-  void eval_dynamic(Dst& dst, const LhsT& lhs, const Rhs& rhs, const Func &func)
+  void eval_dynamic_impl(Dst& dst, const LhsT& lhs, const RhsT& rhs, const Func &func, const Scalar& /* s == 1 */, false_type)
   {
     call_restricted_packet_assignment_no_alias(dst, lhs.lazyProduct(rhs), func);
   }
 
-  // template<typename Dst>
-  // static inline void scaleAndAddTo(Dst& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
-  // { dst.noalias() += alpha * lhs.lazyProduct(rhs); }
+  template<typename Dst, typename LhsT, typename RhsT, typename Func, typename Scalar>
+  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+  void eval_dynamic_impl(Dst& dst, const LhsT& lhs, const RhsT& rhs, const Func &func, const Scalar& s, true_type)
+  {
+    call_restricted_packet_assignment_no_alias(dst, s * lhs.lazyProduct(rhs), func);
+  }
 };
 
 // This specialization enforces the use of a coefficient-based evaluation strategy
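The eval_dynamic / eval_dynamic_impl split above is classic true_type/false_type tag dispatch: the scalar multiplication is compiled in only when at least one operand actually carries a factor. A minimal sketch of the pattern follows; eval and eval_impl are hypothetical names, and std::conditional/std::true_type stand in for Eigen's internal conditional/true_type:

#include <Eigen/Dense>
#include <type_traits>

// false_type overload: no scalar factor anywhere, so no multiply-by-one is emitted.
template<typename Dst, typename L, typename R>
void eval_impl(Dst& dst, const L& lhs, const R& rhs, double /* s == 1 */, std::false_type)
{ dst.noalias() = lhs.lazyProduct(rhs); }

// true_type overload: apply the extracted factor exactly once, up front.
template<typename Dst, typename L, typename R>
void eval_impl(Dst& dst, const L& lhs, const R& rhs, double s, std::true_type)
{ dst.noalias() = s * lhs.lazyProduct(rhs); }

template<bool HasScalarFactor, typename Dst, typename L, typename R>
void eval(Dst& dst, const L& lhs, const R& rhs, double s)
{
  eval_impl(dst, lhs, rhs, s,
            typename std::conditional<HasScalarFactor, std::true_type, std::false_type>::type());
}

int main()
{
  Eigen::MatrixXd A = Eigen::MatrixXd::Random(4,4), B = Eigen::MatrixXd::Random(4,4), dst(4,4);
  eval<false>(dst, A, B, 1.0); // plain A*B; the factor argument is ignored
  eval<true>(dst, A, B, 2.5);  // behaves like (2.5*A)*B after factor extraction
  return 0;
}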
Eigen/src/Core/util/BlasUtil.h

@@ -274,7 +274,8 @@ template<typename XprType> struct blas_traits
     HasUsableDirectAccess = (    (int(XprType::Flags)&DirectAccessBit)
                               && (  bool(XprType::IsVectorAtCompileTime)
                                   || int(inner_stride_at_compile_time<XprType>::ret) == 1)
-                             ) ?  1 : 0
+                             ) ?  1 : 0,
+    HasScalarFactor = false
   };
   typedef typename conditional<bool(HasUsableDirectAccess),
     ExtractType,
@@ -306,6 +307,9 @@ template<typename Scalar, typename NestedXpr, typename Plain>
 struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> >
  : blas_traits<NestedXpr>
 {
+  enum {
+    HasScalarFactor = true
+  };
   typedef blas_traits<NestedXpr> Base;
   typedef CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> XprType;
   typedef typename Base::ExtractType ExtractType;
@@ -317,6 +321,9 @@ template<typename Scalar, typename NestedXpr, typename Plain>
 struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > >
  : blas_traits<NestedXpr>
 {
+  enum {
+    HasScalarFactor = true
+  };
   typedef blas_traits<NestedXpr> Base;
   typedef CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > XprType;
   typedef typename Base::ExtractType ExtractType;
@@ -335,6 +342,9 @@ template<typename Scalar, typename NestedXpr>
 struct blas_traits<CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> >
  : blas_traits<NestedXpr>
 {
+  enum {
+    HasScalarFactor = true
+  };
   typedef blas_traits<NestedXpr> Base;
   typedef CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> XprType;
   typedef typename Base::ExtractType ExtractType;
@@ -358,7 +368,7 @@ struct blas_traits<Transpose<NestedXpr> >
           typename ExtractType::PlainObject
          >::type DirectLinearAccessType;
   enum {
-    IsTransposed = Base::IsTransposed ? 0 : 1
+    IsTransposed = Base::IsTransposed ? 0 : 1,
   };
   static inline ExtractType extract(const XprType& x) { return ExtractType(Base::extract(x.nestedExpression())); }
   static inline Scalar extractScalarFactor(const XprType& x) { return Base::extractScalarFactor(x.nestedExpression()); }
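The mechanism behind HasScalarFactor / extractScalarFactor is ordinary partial specialization over the expression type. The following self-contained sketch uses simplified stand-in types (ScaledExpr, my_blas_traits, Mat are all hypothetical, not Eigen's actual internals) to show how a trait can peel scalar wrappers off an expression and accumulate the factor:

#include <cstdio>

// Hypothetical stand-in for a scalar-times-expression node such as s*A or -A:
template<typename X> struct ScaledExpr { double s; const X& x; };

// Primary template: a bare expression carries no scalar factor.
template<typename X> struct my_blas_traits {
  enum { HasScalarFactor = false };
  typedef X ExtractType;
  static double extractScalarFactor(const X&) { return 1.0; }
  static const X& extract(const X& x) { return x; }
};

// Specialization: peel one scalar layer, recurse, and multiply the factors,
// so nested wrappers like 2*(3*A) yield the factor 6 and the bare A.
template<typename X> struct my_blas_traits< ScaledExpr<X> > {
  typedef my_blas_traits<X> Base;
  typedef typename Base::ExtractType ExtractType;
  enum { HasScalarFactor = true };
  static double extractScalarFactor(const ScaledExpr<X>& e)
  { return e.s * Base::extractScalarFactor(e.x); }
  static const ExtractType& extract(const ScaledExpr<X>& e)
  { return Base::extract(e.x); }
};

struct Mat {};

int main()
{
  Mat A;
  ScaledExpr<Mat> sA = {2.0, A};
  ScaledExpr< ScaledExpr<Mat> > ssA = {3.0, sA};
  // Recovers the combined factor 3*2 and ignores the bare matrix:
  std::printf("%f\n", my_blas_traits< ScaledExpr< ScaledExpr<Mat> > >::extractScalarFactor(ssA)); // 6.000000
  return 0;
}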
test/product_notemporary.cpp

@@ -11,6 +11,35 @@
 
 #include "main.h"
 
+template<typename Dst, typename Lhs, typename Rhs>
+void check_scalar_multiple3(Dst &dst, const Lhs& A, const Rhs& B)
+{
+  VERIFY_EVALUATION_COUNT( (dst.noalias() = A * B), 0);
+  VERIFY_IS_APPROX( dst, (A.eval() * B.eval()).eval() );
+  VERIFY_EVALUATION_COUNT( (dst.noalias() += A * B), 0);
+  VERIFY_IS_APPROX( dst, 2*(A.eval() * B.eval()).eval() );
+  VERIFY_EVALUATION_COUNT( (dst.noalias() -= A * B), 0);
+  VERIFY_IS_APPROX( dst, (A.eval() * B.eval()).eval() );
+}
+
+template<typename Dst, typename Lhs, typename Rhs, typename S2>
+void check_scalar_multiple2(Dst &dst, const Lhs& A, const Rhs& B, S2 s2)
+{
+  CALL_SUBTEST( check_scalar_multiple3(dst, A, B) );
+  CALL_SUBTEST( check_scalar_multiple3(dst, A, -B) );
+  CALL_SUBTEST( check_scalar_multiple3(dst, A, s2*B) );
+  CALL_SUBTEST( check_scalar_multiple3(dst, A, B*s2) );
+}
+
+template<typename Dst, typename Lhs, typename Rhs, typename S1, typename S2>
+void check_scalar_multiple1(Dst &dst, const Lhs& A, const Rhs& B, S1 s1, S2 s2)
+{
+  CALL_SUBTEST( check_scalar_multiple2(dst, A, B, s2) );
+  CALL_SUBTEST( check_scalar_multiple2(dst, -A, B, s2) );
+  CALL_SUBTEST( check_scalar_multiple2(dst, s1*A, B, s2) );
+  CALL_SUBTEST( check_scalar_multiple2(dst, A*s1, B, s2) );
+}
+
 template<typename MatrixType> void product_notemporary(const MatrixType& m)
 {
   /* This test checks the number of temporaries created
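Reading the three helpers bottom-up: check_scalar_multiple1 varies the LHS over {A, -A, s1*A, A*s1}, check_scalar_multiple2 varies the RHS over {B, -B, s2*B, B*s2}, and check_scalar_multiple3 checks the =, += and -= assignments for each pair. Each top-level call therefore exercises 4 x 4 x 3 = 48 product assignments, each of which must create zero temporaries and match the explicitly evaluated reference.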
@@ -148,6 +177,15 @@ template<typename MatrixType> void product_notemporary(const MatrixType& m)
   // Check nested products
   VERIFY_EVALUATION_COUNT( cvres.noalias() = m1.adjoint() * m1 * cv1, 1 );
   VERIFY_EVALUATION_COUNT( rvres.noalias() = rv1 * (m1 * m2.adjoint()), 1 );
+
+  // exhaustively check all scalar-multiple combinations:
+  {
+    // Generic path:
+    check_scalar_multiple1(m3, m1, m2, s1, s2);
+    // Force a fallback to the coeff-based path:
+    typename ColMajorMatrixType::BlockXpr m3_blck = m3.block(r0,r0,1,1);
+    check_scalar_multiple1(m3_blck, m1.block(r0,c0,1,1), m2.block(c0,r0,1,1), s1, s2);
+  }
 }
 
 EIGEN_DECLARE_TEST(product_notemporary)
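Outside the test suite, the no-temporary property can be observed with Eigen's EIGEN_RUNTIME_NO_MALLOC hook. A minimal sketch follows; the 4x4 size is arbitrary and assumes the small-size fallback from GEMM applies, so the whole assignment runs allocation-free:

// Define EIGEN_RUNTIME_NO_MALLOC before including Eigen to make
// Eigen assert on any heap allocation while allocation is disallowed.
#define EIGEN_RUNTIME_NO_MALLOC
#include <Eigen/Dense>

int main()
{
  Eigen::MatrixXd A = Eigen::MatrixXd::Random(4,4), B = Eigen::MatrixXd::Random(4,4), dst(4,4);
  Eigen::internal::set_is_malloc_allowed(false);
  dst.noalias() = A * (B * 2.0); // scalar on the RHS: with this commit the factor
                                 // is moved out and no temporary is allocated
  Eigen::internal::set_is_malloc_allowed(true);
  return 0;
}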