From 655ba783f8b2c9a8c3f4edb45e6db468aca22188 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Tue, 17 Jan 2017 18:03:35 +0100 Subject: [PATCH] Defer set-to-zero in triangular = product so that no aliasing issue occur in the common: A.triangularView() = B*A.sefladjointView()*B.adjoint() case that used to work in 3.2. --- Eigen/src/Core/TriangularMatrix.h | 9 ++++----- .../Core/products/GeneralMatrixMatrixTriangular.h | 14 ++++++++++---- test/product_mmtr.cpp | 13 +++++++++++++ 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/Eigen/src/Core/TriangularMatrix.h b/Eigen/src/Core/TriangularMatrix.h index c35fb8083..667ef09dc 100644 --- a/Eigen/src/Core/TriangularMatrix.h +++ b/Eigen/src/Core/TriangularMatrix.h @@ -543,7 +543,7 @@ template class TriangularViewImpl<_Mat template EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE TriangularViewType& _assignProduct(const ProductType& prod, const Scalar& alpha); + EIGEN_STRONG_INLINE TriangularViewType& _assignProduct(const ProductType& prod, const Scalar& alpha, bool beta); }; /*************************************************************************** @@ -950,8 +950,7 @@ struct Assignment, internal::assign_ if((dst.rows()!=dstRows) || (dst.cols()!=dstCols)) dst.resize(dstRows, dstCols); - dst.setZero(); - dst._assignProduct(src, 1); + dst._assignProduct(src, 1, 0); } }; @@ -962,7 +961,7 @@ struct Assignment, internal::add_ass typedef Product SrcXprType; static void run(DstXprType &dst, const SrcXprType &src, const internal::add_assign_op &) { - dst._assignProduct(src, 1); + dst._assignProduct(src, 1, 1); } }; @@ -973,7 +972,7 @@ struct Assignment, internal::sub_ass typedef Product SrcXprType; static void run(DstXprType &dst, const SrcXprType &src, const internal::sub_assign_op &) { - dst._assignProduct(src, -1); + dst._assignProduct(src, -1, 1); } }; diff --git a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h index 29d6dc721..5cd2794a4 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h @@ -199,7 +199,7 @@ struct general_product_to_triangular_selector; template struct general_product_to_triangular_selector { - static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha) + static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha, bool beta) { typedef typename MatrixType::Scalar Scalar; @@ -217,6 +217,9 @@ struct general_product_to_triangular_selector Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived()); + if(!beta) + mat.template triangularView().setZero(); + enum { StorageOrder = (internal::traits::Flags&RowMajorBit) ? RowMajor : ColMajor, UseLhsDirectly = _ActualLhs::InnerStrideAtCompileTime==1, @@ -244,7 +247,7 @@ struct general_product_to_triangular_selector template struct general_product_to_triangular_selector { - static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha) + static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha, bool beta) { typedef typename internal::remove_all::type Lhs; typedef internal::blas_traits LhsBlasTraits; @@ -260,6 +263,9 @@ struct general_product_to_triangular_selector typename ProductType::Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived()); + if(!beta) + mat.template triangularView().setZero(); + enum { IsRowMajor = (internal::traits::Flags&RowMajorBit) ? 1 : 0, LhsIsRowMajor = _ActualLhs::Flags&RowMajorBit ? 1 : 0, @@ -286,11 +292,11 @@ struct general_product_to_triangular_selector template template -TriangularView& TriangularViewImpl::_assignProduct(const ProductType& prod, const Scalar& alpha) +TriangularView& TriangularViewImpl::_assignProduct(const ProductType& prod, const Scalar& alpha, bool beta) { eigen_assert(derived().nestedExpression().rows() == prod.rows() && derived().cols() == prod.cols()); - general_product_to_triangular_selector::InnerSize==1>::run(derived().nestedExpression().const_cast_derived(), prod, alpha); + general_product_to_triangular_selector::InnerSize==1>::run(derived().nestedExpression().const_cast_derived(), prod, alpha, beta); return derived(); } diff --git a/test/product_mmtr.cpp b/test/product_mmtr.cpp index b66529acd..f6e4bb1ae 100644 --- a/test/product_mmtr.cpp +++ b/test/product_mmtr.cpp @@ -62,6 +62,19 @@ template void mmtr(int size) CHECK_MMTR(matc, Upper, -= (s*sqc).template triangularView()*sqc); CHECK_MMTR(matc, Lower, = (s*sqr).template triangularView()*sqc); CHECK_MMTR(matc, Upper, += (s*sqc).template triangularView()*sqc); + + // check aliasing + ref2 = ref1 = matc; + ref1 = sqc.adjoint() * matc * sqc; + ref2.template triangularView() = ref1.template triangularView(); + matc.template triangularView() = sqc.adjoint() * matc * sqc; + VERIFY_IS_APPROX(matc, ref2); + + ref2 = ref1 = matc; + ref1 = sqc * matc * sqc.adjoint(); + ref2.template triangularView() = ref1.template triangularView(); + matc.template triangularView() = sqc * matc * sqc.adjoint(); + VERIFY_IS_APPROX(matc, ref2); } void test_product_mmtr()