Optimize expression matching "d?=a-b*c" as "d?=a; d?=b*c;"

This commit is contained in:
Gael Guennebaud 2016-08-23 16:52:22 +02:00
parent e47a8928ec
commit 504a4404f1
3 changed files with 28 additions and 19 deletions

View File

@ -194,7 +194,6 @@ struct Assignment<DstXprType, CwiseBinaryOp<internal::scalar_product_op<ScalarBi
//---------------------------------------- //----------------------------------------
// Catch "Dense ?= xpr + Product<>" expression to save one temporary // Catch "Dense ?= xpr + Product<>" expression to save one temporary
// FIXME we could probably enable these rules for any product, i.e., not only Dense and DefaultProduct // FIXME we could probably enable these rules for any product, i.e., not only Dense and DefaultProduct
// TODO enable it for "Dense ?= xpr - Product<>" as well.
template<typename OtherXpr, typename Lhs, typename Rhs> template<typename OtherXpr, typename Lhs, typename Rhs>
struct evaluator_assume_aliasing<CwiseBinaryOp<internal::scalar_sum_op<typename OtherXpr::Scalar,typename Product<Lhs,Rhs,DefaultProduct>::Scalar>, const OtherXpr, struct evaluator_assume_aliasing<CwiseBinaryOp<internal::scalar_sum_op<typename OtherXpr::Scalar,typename Product<Lhs,Rhs,DefaultProduct>::Scalar>, const OtherXpr,
@ -203,10 +202,9 @@ struct evaluator_assume_aliasing<CwiseBinaryOp<internal::scalar_sum_op<typename
}; };
template<typename DstXprType, typename OtherXpr, typename ProductType, typename Func1, typename Func2> template<typename DstXprType, typename OtherXpr, typename ProductType, typename Func1, typename Func2>
struct assignment_from_xpr_plus_product struct assignment_from_xpr_op_product
{ {
typedef CwiseBinaryOp<internal::scalar_sum_op<typename OtherXpr::Scalar,typename ProductType::Scalar>, const OtherXpr, const ProductType> SrcXprType; template<typename SrcXprType, typename InitialFunc>
template<typename InitialFunc>
static EIGEN_STRONG_INLINE static EIGEN_STRONG_INLINE
void run(DstXprType &dst, const SrcXprType &src, const InitialFunc& /*func*/) void run(DstXprType &dst, const SrcXprType &src, const InitialFunc& /*func*/)
{ {
@ -215,21 +213,21 @@ struct assignment_from_xpr_plus_product
} }
}; };
template< typename DstXprType, typename OtherXpr, typename Lhs, typename Rhs, typename DstScalar, typename SrcScalar, typename OtherScalar,typename ProdScalar> #define EIGEN_CATCH_ASSIGN_XPR_OP_PRODUCT(ASSIGN_OP,BINOP,ASSIGN_OP2) \
struct Assignment<DstXprType, CwiseBinaryOp<internal::scalar_sum_op<OtherScalar,ProdScalar>, const OtherXpr, template< typename DstXprType, typename OtherXpr, typename Lhs, typename Rhs, typename DstScalar, typename SrcScalar, typename OtherScalar,typename ProdScalar> \
const Product<Lhs,Rhs,DefaultProduct> >, internal::assign_op<DstScalar,SrcScalar>, Dense2Dense> struct Assignment<DstXprType, CwiseBinaryOp<internal::BINOP<OtherScalar,ProdScalar>, const OtherXpr, \
: assignment_from_xpr_plus_product<DstXprType, OtherXpr, Product<Lhs,Rhs,DefaultProduct>, internal::assign_op<DstScalar,OtherScalar>, internal::add_assign_op<DstScalar,ProdScalar> > const Product<Lhs,Rhs,DefaultProduct> >, internal::ASSIGN_OP<DstScalar,SrcScalar>, Dense2Dense> \
{}; : assignment_from_xpr_op_product<DstXprType, OtherXpr, Product<Lhs,Rhs,DefaultProduct>, internal::ASSIGN_OP<DstScalar,OtherScalar>, internal::ASSIGN_OP2<DstScalar,ProdScalar> > \
template< typename DstXprType, typename OtherXpr, typename Lhs, typename Rhs, typename DstScalar, typename SrcScalar, typename OtherScalar,typename ProdScalar> {}
struct Assignment<DstXprType, CwiseBinaryOp<internal::scalar_sum_op<OtherScalar,ProdScalar>, const OtherXpr,
const Product<Lhs,Rhs,DefaultProduct> >, internal::add_assign_op<DstScalar,SrcScalar>, Dense2Dense> EIGEN_CATCH_ASSIGN_XPR_OP_PRODUCT(assign_op, scalar_sum_op,add_assign_op);
: assignment_from_xpr_plus_product<DstXprType, OtherXpr, Product<Lhs,Rhs,DefaultProduct>, internal::add_assign_op<DstScalar,OtherScalar>, internal::add_assign_op<DstScalar,ProdScalar> > EIGEN_CATCH_ASSIGN_XPR_OP_PRODUCT(add_assign_op,scalar_sum_op,add_assign_op);
{}; EIGEN_CATCH_ASSIGN_XPR_OP_PRODUCT(sub_assign_op,scalar_sum_op,sub_assign_op);
template< typename DstXprType, typename OtherXpr, typename Lhs, typename Rhs, typename DstScalar, typename SrcScalar, typename OtherScalar,typename ProdScalar>
struct Assignment<DstXprType, CwiseBinaryOp<internal::scalar_sum_op<OtherScalar,ProdScalar>, const OtherXpr, EIGEN_CATCH_ASSIGN_XPR_OP_PRODUCT(assign_op, scalar_difference_op,sub_assign_op);
const Product<Lhs,Rhs,DefaultProduct> >, internal::sub_assign_op<DstScalar,SrcScalar>, Dense2Dense> EIGEN_CATCH_ASSIGN_XPR_OP_PRODUCT(add_assign_op,scalar_difference_op,sub_assign_op);
: assignment_from_xpr_plus_product<DstXprType, OtherXpr, Product<Lhs,Rhs,DefaultProduct>, internal::sub_assign_op<DstScalar,OtherScalar>, internal::sub_assign_op<DstScalar,ProdScalar> > EIGEN_CATCH_ASSIGN_XPR_OP_PRODUCT(sub_assign_op,scalar_difference_op,add_assign_op);
{};
//---------------------------------------- //----------------------------------------
template<typename Lhs, typename Rhs> template<typename Lhs, typename Rhs>

View File

@ -119,6 +119,14 @@ template<typename MatrixType> void product(const MatrixType& m)
res.noalias() -= square + m1 * m2.transpose(); res.noalias() -= square + m1 * m2.transpose();
VERIFY_IS_APPROX(res, square + m1 * m2.transpose()); VERIFY_IS_APPROX(res, square + m1 * m2.transpose());
// test d ?= a-b*c rules
res.noalias() = square - m1 * m2.transpose();
VERIFY_IS_APPROX(res, square - m1 * m2.transpose());
res.noalias() += square - m1 * m2.transpose();
VERIFY_IS_APPROX(res, 2*(square - m1 * m2.transpose()));
res.noalias() -= square - m1 * m2.transpose();
VERIFY_IS_APPROX(res, square - m1 * m2.transpose());
tm1 = m1; tm1 = m1;
VERIFY_IS_APPROX(tm1.transpose() * v1, m1.transpose() * v1); VERIFY_IS_APPROX(tm1.transpose() * v1, m1.transpose() * v1);

View File

@ -56,6 +56,9 @@ template<typename MatrixType> void product_notemporary(const MatrixType& m)
VERIFY_EVALUATION_COUNT( m3.noalias() = m3 + m1 * m2.transpose(), 0); VERIFY_EVALUATION_COUNT( m3.noalias() = m3 + m1 * m2.transpose(), 0);
VERIFY_EVALUATION_COUNT( m3.noalias() += m3 + m1 * m2.transpose(), 0); VERIFY_EVALUATION_COUNT( m3.noalias() += m3 + m1 * m2.transpose(), 0);
VERIFY_EVALUATION_COUNT( m3.noalias() -= m3 + m1 * m2.transpose(), 0); VERIFY_EVALUATION_COUNT( m3.noalias() -= m3 + m1 * m2.transpose(), 0);
VERIFY_EVALUATION_COUNT( m3.noalias() = m3 - m1 * m2.transpose(), 0);
VERIFY_EVALUATION_COUNT( m3.noalias() += m3 - m1 * m2.transpose(), 0);
VERIFY_EVALUATION_COUNT( m3.noalias() -= m3 - m1 * m2.transpose(), 0);
VERIFY_EVALUATION_COUNT( m3.noalias() = s1 * m1 * s2 * m2.adjoint(), 0); VERIFY_EVALUATION_COUNT( m3.noalias() = s1 * m1 * s2 * m2.adjoint(), 0);
VERIFY_EVALUATION_COUNT( m3.noalias() = s1 * m1 * s2 * (m1*s3+m2*s2).adjoint(), 1); VERIFY_EVALUATION_COUNT( m3.noalias() = s1 * m1 * s2 * (m1*s3+m2*s2).adjoint(), 1);