diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h index 955668bef..a64bda394 100644 --- a/Eigen/src/Core/ProductEvaluators.h +++ b/Eigen/src/Core/ProductEvaluators.h @@ -194,7 +194,6 @@ struct Assignment" expression to save one temporary // FIXME we could probably enable these rules for any product, i.e., not only Dense and DefaultProduct -// TODO enable it for "Dense ?= xpr - Product<>" as well. template struct evaluator_assume_aliasing::Scalar>, const OtherXpr, @@ -203,10 +202,9 @@ struct evaluator_assume_aliasing -struct assignment_from_xpr_plus_product +struct assignment_from_xpr_op_product { - typedef CwiseBinaryOp, const OtherXpr, const ProductType> SrcXprType; - template + template static EIGEN_STRONG_INLINE void run(DstXprType &dst, const SrcXprType &src, const InitialFunc& /*func*/) { @@ -215,21 +213,21 @@ struct assignment_from_xpr_plus_product } }; -template< typename DstXprType, typename OtherXpr, typename Lhs, typename Rhs, typename DstScalar, typename SrcScalar, typename OtherScalar,typename ProdScalar> -struct Assignment, const OtherXpr, - const Product >, internal::assign_op, Dense2Dense> - : assignment_from_xpr_plus_product, internal::assign_op, internal::add_assign_op > -{}; -template< typename DstXprType, typename OtherXpr, typename Lhs, typename Rhs, typename DstScalar, typename SrcScalar, typename OtherScalar,typename ProdScalar> -struct Assignment, const OtherXpr, - const Product >, internal::add_assign_op, Dense2Dense> - : assignment_from_xpr_plus_product, internal::add_assign_op, internal::add_assign_op > -{}; -template< typename DstXprType, typename OtherXpr, typename Lhs, typename Rhs, typename DstScalar, typename SrcScalar, typename OtherScalar,typename ProdScalar> -struct Assignment, const OtherXpr, - const Product >, internal::sub_assign_op, Dense2Dense> - : assignment_from_xpr_plus_product, internal::sub_assign_op, internal::sub_assign_op > -{}; +#define EIGEN_CATCH_ASSIGN_XPR_OP_PRODUCT(ASSIGN_OP,BINOP,ASSIGN_OP2) \ + template< typename DstXprType, typename OtherXpr, typename Lhs, typename Rhs, typename DstScalar, typename SrcScalar, typename OtherScalar,typename ProdScalar> \ + struct Assignment, const OtherXpr, \ + const Product >, internal::ASSIGN_OP, Dense2Dense> \ + : assignment_from_xpr_op_product, internal::ASSIGN_OP, internal::ASSIGN_OP2 > \ + {} + +EIGEN_CATCH_ASSIGN_XPR_OP_PRODUCT(assign_op, scalar_sum_op,add_assign_op); +EIGEN_CATCH_ASSIGN_XPR_OP_PRODUCT(add_assign_op,scalar_sum_op,add_assign_op); +EIGEN_CATCH_ASSIGN_XPR_OP_PRODUCT(sub_assign_op,scalar_sum_op,sub_assign_op); + +EIGEN_CATCH_ASSIGN_XPR_OP_PRODUCT(assign_op, scalar_difference_op,sub_assign_op); +EIGEN_CATCH_ASSIGN_XPR_OP_PRODUCT(add_assign_op,scalar_difference_op,sub_assign_op); +EIGEN_CATCH_ASSIGN_XPR_OP_PRODUCT(sub_assign_op,scalar_difference_op,add_assign_op); + //---------------------------------------- template diff --git a/test/product.h b/test/product.h index 27976a4ae..cabfc0b03 100644 --- a/test/product.h +++ b/test/product.h @@ -119,6 +119,14 @@ template void product(const MatrixType& m) res.noalias() -= square + m1 * m2.transpose(); VERIFY_IS_APPROX(res, square + m1 * m2.transpose()); + // test d ?= a-b*c rules + res.noalias() = square - m1 * m2.transpose(); + VERIFY_IS_APPROX(res, square - m1 * m2.transpose()); + res.noalias() += square - m1 * m2.transpose(); + VERIFY_IS_APPROX(res, 2*(square - m1 * m2.transpose())); + res.noalias() -= square - m1 * m2.transpose(); + VERIFY_IS_APPROX(res, square - m1 * m2.transpose()); + tm1 = m1; VERIFY_IS_APPROX(tm1.transpose() * v1, m1.transpose() * v1); diff --git a/test/product_notemporary.cpp b/test/product_notemporary.cpp index 5a3f3a01a..2bb19a681 100644 --- a/test/product_notemporary.cpp +++ b/test/product_notemporary.cpp @@ -56,6 +56,9 @@ template void product_notemporary(const MatrixType& m) VERIFY_EVALUATION_COUNT( m3.noalias() = m3 + m1 * m2.transpose(), 0); VERIFY_EVALUATION_COUNT( m3.noalias() += m3 + m1 * m2.transpose(), 0); VERIFY_EVALUATION_COUNT( m3.noalias() -= m3 + m1 * m2.transpose(), 0); + VERIFY_EVALUATION_COUNT( m3.noalias() = m3 - m1 * m2.transpose(), 0); + VERIFY_EVALUATION_COUNT( m3.noalias() += m3 - m1 * m2.transpose(), 0); + VERIFY_EVALUATION_COUNT( m3.noalias() -= m3 - m1 * m2.transpose(), 0); VERIFY_EVALUATION_COUNT( m3.noalias() = s1 * m1 * s2 * m2.adjoint(), 0); VERIFY_EVALUATION_COUNT( m3.noalias() = s1 * m1 * s2 * (m1*s3+m2*s2).adjoint(), 1);