From 7852a48a2f7ccba4a526640d0aac03fe12580da3 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Mon, 25 Jan 2010 21:56:01 +0100 Subject: [PATCH] fix matrix product with EIGEN_DEFAULT_TO_ROW_MAJOR --- Eigen/src/Core/Product.h | 74 ++++++++++++++++++++++++--------- Eigen/src/Core/util/XprHelper.h | 13 ++++++ test/product.h | 1 + 3 files changed, 69 insertions(+), 19 deletions(-) diff --git a/Eigen/src/Core/Product.h b/Eigen/src/Core/Product.h index 1151b2164..ca769c622 100644 --- a/Eigen/src/Core/Product.h +++ b/Eigen/src/Core/Product.h @@ -66,11 +66,8 @@ struct ProductReturnType template struct ProductReturnType { - typedef typename ei_nested::type LhsNested; - - typedef typename ei_nested::type - >::type RhsNested; + typedef const Lhs& LhsNested; + typedef const Rhs& RhsNested; typedef Product Type; }; @@ -144,7 +141,7 @@ struct ei_traits > EvalToRowMajor = RhsRowMajor && (ProductMode==(int)CacheFriendlyProduct ? LhsRowMajor : (!CanVectorizeLhs)), - RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit), + RemovedBits = ~((EvalToRowMajor ? 0 : RowMajorBit)|DirectAccessBit), Flags = ((unsigned int)(LhsFlags | RhsFlags) & HereditaryBits & RemovedBits) | EvalBeforeAssigningBit @@ -571,7 +568,7 @@ struct ei_cache_friendly_product_selector >(_res, res.size()) = res; + Map >(_res, res.size()) = res; } ei_cache_friendly_product_colmajor_times_vector(res.size(), &product.lhs().const_cast_derived().coeffRef(0,0), product.lhs().stride(), @@ -579,7 +576,7 @@ struct ei_cache_friendly_product_selector >(_res, res.size()); + res = Map >(_res, res.size()); ei_aligned_stack_delete(Scalar, _res, res.size()); } } @@ -617,7 +614,7 @@ struct ei_cache_friendly_product_selector >(_res, res.size()) = res; + Map >(_res, res.size()) = res; } ei_cache_friendly_product_colmajor_times_vector(res.size(), &product.rhs().const_cast_derived().coeffRef(0,0), product.rhs().stride(), @@ -625,7 +622,7 @@ struct ei_cache_friendly_product_selector >(_res, res.size()); + res = Map >(_res, res.size()); ei_aligned_stack_delete(Scalar, _res, res.size()); } } @@ -650,7 +647,7 @@ struct ei_cache_friendly_product_selector >(_rhs, product.rhs().size()) = product.rhs(); + Map >(_rhs, product.rhs().size()) = product.rhs(); } ei_cache_friendly_product_rowmajor_times_vector(&product.lhs().const_cast_derived().coeffRef(0,0), product.lhs().stride(), _rhs, product.rhs().size(), res); @@ -678,7 +675,7 @@ struct ei_cache_friendly_product_selector >(_lhs, product.lhs().size()) = product.lhs(); + Map >(_lhs, product.lhs().size()) = product.lhs(); } ei_cache_friendly_product_rowmajor_times_vector(&product.rhs().const_cast_derived().coeffRef(0,0), product.rhs().stride(), _lhs, product.lhs().size(), res); @@ -709,7 +706,17 @@ MatrixBase::operator+=(const Flagged >::run(const_cast_derived(), other._expression()); else - lazyAssign(derived() + other._expression()); + { + typedef typename ei_cleantype::type _Lhs; + typedef typename ei_cleantype::type _Rhs; + + typedef typename ei_nested<_Lhs,_Rhs::ColsAtCompileTime>::type LhsNested; + typedef typename ei_nested<_Rhs,_Lhs::RowsAtCompileTime>::type RhsNested; + + Product prod(other._expression().lhs(),other._expression().rhs()); + + lazyAssign(derived() + prod); + } return derived(); } @@ -724,12 +731,21 @@ inline Derived& MatrixBase::lazyAssign(const Product >(product); + typedef typename ei_cleantype::type _Lhs; + typedef typename ei_cleantype::type _Rhs; + + typedef typename ei_nested<_Lhs,_Rhs::ColsAtCompileTime>::type LhsNested; + typedef typename ei_nested<_Rhs,_Lhs::RowsAtCompileTime>::type RhsNested; + + typedef Product NormalProduct; + NormalProduct normal_prod(product.lhs(),product.rhs()); + + lazyAssign(normal_prod); } return derived(); } -template struct ei_product_copy_rhs +template struct ei_product_copy_rhs { typedef typename ei_meta_if< (ei_traits::Flags & RowMajorBit) @@ -739,11 +755,30 @@ template struct ei_product_copy_rhs >::ret type; }; -template struct ei_product_copy_lhs +template struct ei_product_copy_rhs +{ + typedef typename ei_meta_if< + (!(ei_traits::Flags & DirectAccessBit)), + typename ei_plain_matrix_type::type, + const T& + >::ret type; +}; + +template struct ei_product_copy_lhs { typedef typename ei_meta_if< (!(int(ei_traits::Flags) & DirectAccessBit)), - typename ei_plain_matrix_type::type, + typename ei_plain_matrix_type_row_major::type, + const T& + >::ret type; +}; + +template struct ei_product_copy_lhs +{ + typedef typename ei_meta_if< + ((ei_traits::Flags & RowMajorBit)==0) + || (!(int(ei_traits::Flags) & DirectAccessBit)), + typename ei_plain_matrix_type_row_major::type, const T& >::ret type; }; @@ -752,9 +787,9 @@ template template inline void Product::_cacheFriendlyEvalAndAdd(DestDerived& res) const { - typedef typename ei_product_copy_lhs<_LhsNested>::type LhsCopy; + typedef typename ei_product_copy_lhs<_LhsNested,DestDerived::Flags&RowMajorBit>::type LhsCopy; typedef typename ei_unref::type _LhsCopy; - typedef typename ei_product_copy_rhs<_RhsNested>::type RhsCopy; + typedef typename ei_product_copy_rhs<_RhsNested,DestDerived::Flags&RowMajorBit>::type RhsCopy; typedef typename ei_unref::type _RhsCopy; LhsCopy lhs(m_lhs); RhsCopy rhs(m_rhs); @@ -764,6 +799,7 @@ inline void Product::_cacheFriendlyEvalAndAdd(DestDerived& _RhsCopy::Flags&RowMajorBit, (const Scalar*)&(rhs.const_cast_derived().coeffRef(0,0)), rhs.stride(), DestDerived::Flags&RowMajorBit, (Scalar*)&(res.coeffRef(0,0)), res.stride() ); + } #endif // EIGEN_PRODUCT_H diff --git a/Eigen/src/Core/util/XprHelper.h b/Eigen/src/Core/util/XprHelper.h index 7cadd343d..c22d46a52 100644 --- a/Eigen/src/Core/util/XprHelper.h +++ b/Eigen/src/Core/util/XprHelper.h @@ -161,6 +161,19 @@ template struct ei_plain_matrix_type_column_major > type; }; +/* ei_plain_matrix_type_row_major : same as ei_plain_matrix_type but guaranteed to be row-major + */ +template struct ei_plain_matrix_type_row_major +{ + typedef Matrix::Scalar, + ei_traits::RowsAtCompileTime, + ei_traits::ColsAtCompileTime, + AutoAlign | RowMajor, + ei_traits::MaxRowsAtCompileTime, + ei_traits::MaxColsAtCompileTime + > type; +}; + template struct ei_must_nest_by_value { enum { ret = false }; }; template struct ei_must_nest_by_value > { enum { ret = true }; }; diff --git a/test/product.h b/test/product.h index 4f4aeb965..ec84ff92b 100644 --- a/test/product.h +++ b/test/product.h @@ -77,6 +77,7 @@ template void product(const MatrixType& m) // begin testing Product.h: only associativity for now // (we use Transpose.h but this doesn't count as a test for it) + VERIFY_IS_APPROX((m1*m1.transpose())*m2, m1*(m1.transpose()*m2)); m3 = m1; m3 *= m1.transpose() * m2;