bug #1405: enable StrictlyLower/StrictlyUpper triangularView as the destination of matrix*matrix products.

(grafted from ba5cab576a7615af12389ff159c6ed57b5312d5e
)
This commit is contained in:
Gael Guennebaud 2017-06-09 14:38:04 +02:00
parent 8880be60fa
commit 7a0a9581b5
3 changed files with 19 additions and 7 deletions

View File

@ -269,10 +269,13 @@ struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,false>
enum {
IsRowMajor = (internal::traits<MatrixType>::Flags&RowMajorBit) ? 1 : 0,
LhsIsRowMajor = _ActualLhs::Flags&RowMajorBit ? 1 : 0,
RhsIsRowMajor = _ActualRhs::Flags&RowMajorBit ? 1 : 0
RhsIsRowMajor = _ActualRhs::Flags&RowMajorBit ? 1 : 0,
SkipDiag = (UpLo&(UnitDiag|ZeroDiag))!=0
};
Index size = mat.cols();
if(SkipDiag)
size--;
Index depth = actualLhs.cols();
typedef internal::gemm_blocking_space<IsRowMajor ? RowMajor : ColMajor,typename Lhs::Scalar,typename Rhs::Scalar,
@ -283,10 +286,11 @@ struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,false>
internal::general_matrix_matrix_triangular_product<Index,
typename Lhs::Scalar, LhsIsRowMajor ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate,
typename Rhs::Scalar, RhsIsRowMajor ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate,
IsRowMajor ? RowMajor : ColMajor, UpLo>
IsRowMajor ? RowMajor : ColMajor, UpLo&(Lower|Upper)>
::run(size, depth,
&actualLhs.coeffRef(0,0), actualLhs.outerStride(), &actualRhs.coeffRef(0,0), actualRhs.outerStride(),
mat.data(), mat.outerStride(), actualAlpha, blocking);
&actualLhs.coeffRef(SkipDiag&&(UpLo&Lower)==Lower ? 1 : 0,0), actualLhs.outerStride(),
&actualRhs.coeffRef(0,SkipDiag&&(UpLo&Upper)==Upper ? 1 : 0), actualRhs.outerStride(),
mat.data() + (SkipDiag ? (bool(IsRowMajor) != ((UpLo&Lower)==Lower) ? 1 : mat.outerStride() ) : 0), mat.outerStride(), actualAlpha, blocking);
}
};
@ -294,6 +298,7 @@ template<typename MatrixType, unsigned int UpLo>
template<typename ProductType>
TriangularView<MatrixType,UpLo>& TriangularViewImpl<MatrixType,UpLo,Dense>::_assignProduct(const ProductType& prod, const Scalar& alpha, bool beta)
{
EIGEN_STATIC_ASSERT((UpLo&UnitDiag)==0, WRITING_TO_TRIANGULAR_PART_WITH_UNIT_DIAGONAL_IS_NOT_SUPPORTED);
eigen_assert(derived().nestedExpression().rows() == prod.rows() && derived().cols() == prod.cols());
general_product_to_triangular_selector<MatrixType, ProductType, UpLo, internal::traits<ProductType>::InnerSize==1>::run(derived().nestedExpression().const_cast_derived(), prod, alpha, beta);

View File

@ -52,7 +52,7 @@ struct general_matrix_matrix_triangular_product<Index,Scalar,LhsStorageOrder,Con
static EIGEN_STRONG_INLINE void run(Index size, Index depth,const Scalar* lhs, Index lhsStride, \
const Scalar* rhs, Index rhsStride, Scalar* res, Index resStride, Scalar alpha, level3_blocking<Scalar, Scalar>& blocking) \
{ \
if (lhs==rhs) { \
if ( lhs==rhs && ((UpLo&(Lower|Upper)==UpLo)) ) { \
general_matrix_matrix_rankupdate<Index,Scalar,LhsStorageOrder,ConjugateLhs,ColMajor,UpLo> \
::run(size,depth,lhs,lhsStride,rhs,rhsStride,res,resStride,alpha,blocking); \
} else { \

View File

@ -1,7 +1,7 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2010 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2010-2017 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
@ -10,12 +10,19 @@
#include "main.h"
#define CHECK_MMTR(DEST, TRI, OP) { \
ref3 = DEST; \
ref2 = ref1 = DEST; \
DEST.template triangularView<TRI>() OP; \
ref1 OP; \
ref2.template triangularView<TRI>() \
= ref1.template triangularView<TRI>(); \
VERIFY_IS_APPROX(DEST,ref2); \
\
DEST = ref3; \
ref3 = ref2; \
ref3.diagonal() = DEST.diagonal(); \
DEST.template triangularView<TRI|ZeroDiag>() OP; \
VERIFY_IS_APPROX(DEST,ref3); \
}
template<typename Scalar> void mmtr(int size)
@ -27,7 +34,7 @@ template<typename Scalar> void mmtr(int size)
MatrixColMaj matc = MatrixColMaj::Zero(size, size);
MatrixRowMaj matr = MatrixRowMaj::Zero(size, size);
MatrixColMaj ref1(size, size), ref2(size, size);
MatrixColMaj ref1(size, size), ref2(size, size), ref3(size,size);
MatrixColMaj soc(size,othersize); soc.setRandom();
MatrixColMaj osc(othersize,size); osc.setRandom();