mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-20 00:29:38 +08:00
trmm is now fully working and available via TriangularView::operator*
This commit is contained in:
parent
6aba84719d
commit
f95b77be62
@ -142,8 +142,10 @@ struct ei_traits<TriangularView<MatrixType, _Mode> > : ei_traits<MatrixType>
|
||||
};
|
||||
};
|
||||
|
||||
template<typename Lhs,typename Rhs>
|
||||
struct ei_triangular_vector_product_returntype;
|
||||
template<int Mode, bool LhsIsTriangular,
|
||||
typename Lhs, bool LhsIsVector,
|
||||
typename Rhs, bool RhsIsVector>
|
||||
struct ei_triangular_product_returntype;
|
||||
|
||||
template<typename _MatrixType, unsigned int _Mode> class TriangularView
|
||||
: public TriangularBase<TriangularView<_MatrixType, _Mode> >
|
||||
@ -247,11 +249,24 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularView
|
||||
return res;
|
||||
}
|
||||
|
||||
/** Efficient triangular matrix times vector/matrix product */
|
||||
template<typename OtherDerived>
|
||||
ei_triangular_vector_product_returntype<TriangularView,OtherDerived>
|
||||
ei_triangular_product_returntype<Mode,true,MatrixType,false,OtherDerived,OtherDerived::IsVectorAtCompileTime>
|
||||
operator*(const MatrixBase<OtherDerived>& rhs) const
|
||||
{
|
||||
return ei_triangular_vector_product_returntype<TriangularView,OtherDerived>(*this, rhs.derived(), 1);
|
||||
return ei_triangular_product_returntype
|
||||
<Mode,true,MatrixType,false,OtherDerived,OtherDerived::IsVectorAtCompileTime>
|
||||
(m_matrix, rhs.derived());
|
||||
}
|
||||
|
||||
/** Efficient vector/matrix times triangular matrix product */
|
||||
template<typename OtherDerived> friend
|
||||
ei_triangular_product_returntype<Mode,false,OtherDerived,OtherDerived::IsVectorAtCompileTime,MatrixType,false>
|
||||
operator*(const MatrixBase<OtherDerived>& lhs, const TriangularView& rhs)
|
||||
{
|
||||
return ei_triangular_product_returntype
|
||||
<Mode,false,OtherDerived,OtherDerived::IsVectorAtCompileTime,MatrixType,false>
|
||||
(lhs.derived(),rhs.m_matrix);
|
||||
}
|
||||
|
||||
template<typename OtherDerived>
|
||||
|
392
Eigen/src/Core/products/TriangularMatrixMatrix.h
Normal file
392
Eigen/src/Core/products/TriangularMatrixMatrix.h
Normal file
@ -0,0 +1,392 @@
|
||||
// This file is part of Eigen, a lightweight C++ template library
|
||||
// for linear algebra.
|
||||
//
|
||||
// Copyright (C) 2009 Gael Guennebaud <g.gael@free.fr>
|
||||
//
|
||||
// Eigen is free software; you can redistribute it and/or
|
||||
// modify it under the terms of the GNU Lesser General Public
|
||||
// License as published by the Free Software Foundation; either
|
||||
// version 3 of the License, or (at your option) any later version.
|
||||
//
|
||||
// Alternatively, you can redistribute it and/or
|
||||
// modify it under the terms of the GNU General Public License as
|
||||
// published by the Free Software Foundation; either version 2 of
|
||||
// the License, or (at your option) any later version.
|
||||
//
|
||||
// Eigen is distributed in the hope that it will be useful, but WITHOUT ANY
|
||||
// WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
|
||||
// FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Lesser General Public
|
||||
// License and a copy of the GNU General Public License along with
|
||||
// Eigen. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
#ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_H
|
||||
#define EIGEN_TRIANGULAR_MATRIX_MATRIX_H
|
||||
|
||||
// template<typename Scalar, int mr, int StorageOrder, bool Conjugate, int Mode>
|
||||
// struct ei_gemm_pack_lhs_triangular
|
||||
// {
|
||||
// Matrix<Scalar,mr,mr,
|
||||
// void operator()(Scalar* blockA, const EIGEN_RESTRICT Scalar* _lhs, int lhsStride, int depth, int rows)
|
||||
// {
|
||||
// ei_conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
|
||||
// ei_const_blas_data_mapper<Scalar, StorageOrder> lhs(_lhs,lhsStride);
|
||||
// int count = 0;
|
||||
// const int peeled_mc = (rows/mr)*mr;
|
||||
// for(int i=0; i<peeled_mc; i+=mr)
|
||||
// {
|
||||
// for(int k=0; k<depth; k++)
|
||||
// for(int w=0; w<mr; w++)
|
||||
// blockA[count++] = cj(lhs(i+w, k));
|
||||
// }
|
||||
// for(int i=peeled_mc; i<rows; i++)
|
||||
// {
|
||||
// for(int k=0; k<depth; k++)
|
||||
// blockA[count++] = cj(lhs(i, k));
|
||||
// }
|
||||
// }
|
||||
// };
|
||||
|
||||
/* Optimized triangular matrix * general matrix (_TRMM) product built on top of
|
||||
* the general matrix matrix product.
|
||||
*/
|
||||
// Primary template (declaration only) of the triangular * general product kernel.
// It is specialized in this file on ResStorageOrder (RowMajor, ColMajor) and,
// for ColMajor, on LhsIsTriangular.
//   Mode             - triangular mode flags; the specializations test it against
//                      LowerTriangular, UpperTriangular and UnitDiagBit
//   LhsIsTriangular  - true when the triangular factor is the left operand,
//                      false when it is the right operand
//   Lhs/RhsStorageOrder, ResStorageOrder - RowMajor or ColMajor
//   ConjugateLhs/Rhs - forwarded to the GEBP kernel through ei_conj_helper
template <typename Scalar,
|
||||
int Mode, bool LhsIsTriangular,
|
||||
int LhsStorageOrder, bool ConjugateLhs,
|
||||
int RhsStorageOrder, bool ConjugateRhs,
|
||||
int ResStorageOrder>
|
||||
struct ei_product_triangular_matrix_matrix;
|
||||
|
||||
template <typename Scalar,
|
||||
int Mode, bool LhsIsTriangular,
|
||||
int LhsStorageOrder, bool ConjugateLhs,
|
||||
int RhsStorageOrder, bool ConjugateRhs>
|
||||
struct ei_product_triangular_matrix_matrix<Scalar,Mode,LhsIsTriangular,
|
||||
LhsStorageOrder,ConjugateLhs,
|
||||
RhsStorageOrder,ConjugateRhs,RowMajor>
|
||||
{
|
||||
static EIGEN_STRONG_INLINE void run(
|
||||
int size, int otherSize,
|
||||
const Scalar* lhs, int lhsStride,
|
||||
const Scalar* rhs, int rhsStride,
|
||||
Scalar* res, int resStride,
|
||||
Scalar alpha)
|
||||
{
|
||||
ei_product_triangular_matrix_matrix<Scalar,
|
||||
(Mode&UnitDiagBit) | (Mode&UpperTriangular) ? LowerTriangular : UpperTriangular,
|
||||
(!LhsIsTriangular),
|
||||
RhsStorageOrder==RowMajor ? ColMajor : RowMajor,
|
||||
ConjugateRhs,
|
||||
LhsStorageOrder==RowMajor ? ColMajor : RowMajor,
|
||||
ConjugateLhs,
|
||||
ColMajor>
|
||||
::run(size, otherSize, rhs, rhsStride, lhs, lhsStride, res, resStride, alpha);
|
||||
}
|
||||
};
|
||||
|
||||
// Column-major kernel for a triangular left-hand side:
//   res += alpha * op(lhs) * op(rhs)
// lhs is the square triangular factor (size x size) and rhs a general matrix
// with `size` rows and `cols` columns. Mode selects the lower or upper part of
// lhs and, through UnitDiagBit, an implicit unit diagonal. ConjugateLhs and
// ConjugateRhs are forwarded to the GEBP kernel via ei_conj_helper.
// implements col-major += alpha * op(triangular) * op(general)
|
||||
template <typename Scalar, int Mode,
|
||||
int LhsStorageOrder, bool ConjugateLhs,
|
||||
int RhsStorageOrder, bool ConjugateRhs>
|
||||
struct ei_product_triangular_matrix_matrix<Scalar,Mode,true,
|
||||
LhsStorageOrder,ConjugateLhs,
|
||||
RhsStorageOrder,ConjugateRhs,ColMajor>
|
||||
{
|
||||
|
||||
// size      : dimension of the triangular lhs (also the number of rows of rhs and res)
// cols      : number of columns of rhs and res
// _lhs/_rhs : raw data pointers, wrapped below in data mappers using the given strides
// res       : destination, addressed as res + row + col*resStride
// alpha     : scaling factor, handed to pack_rhs together with the rhs panel
static EIGEN_DONT_INLINE void run(
|
||||
int size, int cols,
|
||||
const Scalar* _lhs, int lhsStride,
|
||||
const Scalar* _rhs, int rhsStride,
|
||||
Scalar* res, int resStride,
|
||||
Scalar alpha)
|
||||
{
|
||||
// the triangular lhs is square, hence the result has `size` rows
int rows = size;
|
||||
|
||||
ei_const_blas_data_mapper<Scalar, LhsStorageOrder> lhs(_lhs,lhsStride);
|
||||
ei_const_blas_data_mapper<Scalar, RhsStorageOrder> rhs(_rhs,rhsStride);
|
||||
|
||||
// NOTE(review): alpha is applied while packing the rhs; presumably it is
// pre-conjugated here so that the conjugation later applied to the packed rhs
// by the kernel restores the original alpha - confirm against ei_gebp_kernel.
if (ConjugateRhs)
|
||||
alpha = ei_conj(alpha);
|
||||
|
||||
typedef ei_product_blocking_traits<Scalar> Blocking;
|
||||
enum {
|
||||
// width of the micro panels the diagonal block is cut into
SmallPanelWidth = EIGEN_ENUM_MAX(Blocking::mr,Blocking::nr),
|
||||
IsLowerTriangular = (Mode&LowerTriangular) == LowerTriangular
|
||||
};
|
||||
|
||||
int kc = std::min<int>(Blocking::Max_kc/4,size); // cache block size along the K direction
|
||||
int mc = std::min<int>(Blocking::Max_mc,rows); // cache block size along the M direction
|
||||
|
||||
// scratch buffers receiving the packed lhs (blockA) and packed rhs (blockB) blocks;
// released at the end of this function
Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc);
|
||||
Scalar* blockB = ei_aligned_stack_new(Scalar, kc*cols*Blocking::PacketSize);
|
||||
|
||||
// Small dense copy of one micro triangular block. It starts as the identity:
// the loops below only write the stored triangle (and the diagonal when the
// mode has no UnitDiagBit), so the opposite triangle stays zero and a unit
// diagonal stays one.
Matrix<Scalar,SmallPanelWidth,SmallPanelWidth,LhsStorageOrder> triangularBuffer;
|
||||
triangularBuffer.setZero();
|
||||
triangularBuffer.diagonal().setOnes();
|
||||
|
||||
ei_gebp_kernel<Scalar, Blocking::mr, Blocking::nr, ei_conj_helper<ConjugateLhs,ConjugateRhs> > gebp_kernel;
|
||||
ei_gemm_pack_lhs<Scalar,Blocking::mr,LhsStorageOrder> pack_lhs;
|
||||
ei_gemm_pack_rhs<Scalar,Blocking::nr,RhsStorageOrder> pack_rhs;
|
||||
|
||||
// Walk the depth (K) direction in panels of at most kc: lower mode starts at
// k2 == size and moves backwards, upper mode starts at k2 == 0 and moves forwards.
for(int k2=IsLowerTriangular ? size : 0;
|
||||
IsLowerTriangular ? k2>0 : k2<size;
|
||||
IsLowerTriangular ? k2-=kc : k2+=kc)
|
||||
{
|
||||
const int actual_kc = std::min(IsLowerTriangular ? k2 : size-k2, kc);
|
||||
// first row/column index of the current panel
int actual_k2 = IsLowerTriangular ? k2-actual_kc : k2;
|
||||
|
||||
// pack rows [actual_k2, actual_k2+actual_kc) of rhs, all `cols` columns
pack_rhs(blockB, &rhs(actual_k2,0), rhsStride, alpha, actual_kc, cols);
|
||||
|
||||
// the selected lhs's panel has to be split in three different parts:
|
||||
// 1 - the part which is above the diagonal block => skip it
|
||||
// 2 - the diagonal block => special kernel
|
||||
// 3 - the panel below the diagonal block => GEPP
// (this describes lower mode; in upper mode "above" and "below" are exchanged:
//  rows [0, actual_k2) go through GEPP and the rows after the block are skipped)
|
||||
// the block diagonal
|
||||
{
|
||||
// for each small vertical panels of lhs
|
||||
for (int k1=0; k1<actual_kc; k1+=SmallPanelWidth)
|
||||
{
|
||||
int actualPanelWidth = std::min<int>(actual_kc-k1, SmallPanelWidth);
|
||||
// number of rows of the diagonal block lying below (lower mode) or above
// (upper mode) the micro triangular block of this panel
int lengthTarget = IsLowerTriangular ? actual_kc-k1-actualPanelWidth : k1;
|
||||
// global index of the first row/column of the micro triangular block
int startBlock = actual_k2+k1;
|
||||
// position of this panel inside the packed rhs block (scaled by PacketSize at the call sites)
int blockBOffset = k1;
|
||||
|
||||
// => GEBP with the micro triangular block
|
||||
// The trick is to pack this micro block while filling the opposite triangular part with zeros.
|
||||
// To this end we do an extra triangular copy to small temporary buffer
|
||||
for (int k=0;k<actualPanelWidth;++k)
|
||||
{
|
||||
// with UnitDiagBit the diagonal keeps the ones written at initialization
if (!(Mode&UnitDiagBit))
|
||||
triangularBuffer.coeffRef(k,k) = lhs(startBlock+k,startBlock+k);
|
||||
// copy column k of the stored triangle, diagonal excluded
for (int i=IsLowerTriangular ? k+1 : 0; IsLowerTriangular ? i<actualPanelWidth : i<k; ++i)
|
||||
triangularBuffer.coeffRef(i,k) = lhs(startBlock+i,startBlock+k);
|
||||
}
|
||||
pack_lhs(blockA, triangularBuffer.data(), triangularBuffer.stride(), actualPanelWidth, actualPanelWidth);
|
||||
|
||||
// NOTE(review): the four trailing arguments look like (strideA, strideB, offsetA, offsetB) - confirm against ei_gebp_kernel
gebp_kernel(res+startBlock, resStride, blockA, blockB, actualPanelWidth, actualPanelWidth, cols,
|
||||
actualPanelWidth, actual_kc, 0, blockBOffset*Blocking::PacketSize);
|
||||
|
||||
// GEBP with remaining micro panel
|
||||
if (lengthTarget>0)
|
||||
{
|
||||
// first row of the dense part of this micro panel inside the diagonal block
int startTarget = IsLowerTriangular ? actual_k2+k1+actualPanelWidth : actual_k2;
|
||||
|
||||
pack_lhs(blockA, &lhs(startTarget,startBlock), lhsStride, actualPanelWidth, lengthTarget);
|
||||
|
||||
gebp_kernel(res+startTarget, resStride, blockA, blockB, lengthTarget, actualPanelWidth, cols,
|
||||
actualPanelWidth, actual_kc, 0, blockBOffset*Blocking::PacketSize);
|
||||
}
|
||||
}
|
||||
}
|
||||
// the part below the diagonal => GEPP
// (in upper mode this is the part above the diagonal block: rows [0, actual_k2))
|
||||
{
|
||||
int start = IsLowerTriangular ? k2 : 0;
|
||||
int end = IsLowerTriangular ? size : actual_k2;
|
||||
// process the dense rows [start, end) in slices of at most mc rows
for(int i2=start; i2<end; i2+=mc)
|
||||
{
|
||||
const int actual_mc = std::min(i2+mc,end)-i2;
|
||||
ei_gemm_pack_lhs<Scalar,Blocking::mr,LhsStorageOrder,false>()
|
||||
(blockA, &lhs(i2, actual_k2), lhsStride, actual_kc, actual_mc);
|
||||
|
||||
gebp_kernel(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, cols);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// release the scratch buffers with the same sizes they were requested with
ei_aligned_stack_delete(Scalar, blockA, kc*mc);
|
||||
ei_aligned_stack_delete(Scalar, blockB, kc*cols*Blocking::PacketSize);
|
||||
}
|
||||
};
|
||||
|
||||
// Column-major kernel for a triangular right-hand side:
//   res += alpha * op(lhs) * op(rhs)
// lhs is a general matrix with `rows` rows and `size` columns, rhs is the square
// triangular factor (size x size). Mode selects the lower or upper part of rhs
// and, through UnitDiagBit, an implicit unit diagonal. ConjugateLhs and
// ConjugateRhs are forwarded to the GEBP kernel via ei_conj_helper.
// implements col-major += alpha * op(general) * op(triangular)
|
||||
template <typename Scalar, int Mode,
|
||||
int LhsStorageOrder, bool ConjugateLhs,
|
||||
int RhsStorageOrder, bool ConjugateRhs>
|
||||
struct ei_product_triangular_matrix_matrix<Scalar,Mode,false,
|
||||
LhsStorageOrder,ConjugateLhs,
|
||||
RhsStorageOrder,ConjugateRhs,ColMajor>
|
||||
{
|
||||
|
||||
// size      : dimension of the triangular rhs (also the number of columns of lhs and res)
// rows      : number of rows of lhs and res
// _lhs/_rhs : raw data pointers, wrapped below in data mappers using the given strides
// res       : destination, addressed as res + row + col*resStride
// alpha     : scaling factor, handed to the rhs packing routines
static EIGEN_DONT_INLINE void run(
|
||||
int size, int rows,
|
||||
const Scalar* _lhs, int lhsStride,
|
||||
const Scalar* _rhs, int rhsStride,
|
||||
Scalar* res, int resStride,
|
||||
Scalar alpha)
|
||||
{
|
||||
// the triangular rhs is square, hence the result has `size` columns
int cols = size;
|
||||
|
||||
ei_const_blas_data_mapper<Scalar, LhsStorageOrder> lhs(_lhs,lhsStride);
|
||||
ei_const_blas_data_mapper<Scalar, RhsStorageOrder> rhs(_rhs,rhsStride);
|
||||
|
||||
// NOTE(review): alpha is applied while packing the rhs; presumably it is
// pre-conjugated here so that the conjugation later applied to the packed rhs
// by the kernel restores the original alpha - confirm against ei_gebp_kernel.
if (ConjugateRhs)
|
||||
alpha = ei_conj(alpha);
|
||||
|
||||
typedef ei_product_blocking_traits<Scalar> Blocking;
|
||||
enum {
|
||||
// width of the micro panels the triangular block is cut into
SmallPanelWidth = EIGEN_ENUM_MAX(Blocking::mr,Blocking::nr),
|
||||
IsLowerTriangular = (Mode&LowerTriangular) == LowerTriangular
|
||||
};
|
||||
|
||||
int kc = std::min<int>(Blocking::Max_kc/4,size); // cache block size along the K direction
|
||||
int mc = std::min<int>(Blocking::Max_mc,rows); // cache block size along the M direction
|
||||
|
||||
// scratch buffers receiving the packed lhs (blockA) and packed rhs (blockB) blocks;
// released at the end of this function
Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc);
|
||||
Scalar* blockB = ei_aligned_stack_new(Scalar, kc*cols*Blocking::PacketSize);
|
||||
|
||||
// Small dense copy of one micro triangular block. It starts as the identity:
// the loops below only write the stored triangle (and the diagonal when the
// mode has no UnitDiagBit), so the opposite triangle stays zero and a unit
// diagonal stays one.
Matrix<Scalar,SmallPanelWidth,SmallPanelWidth,RhsStorageOrder> triangularBuffer;
|
||||
triangularBuffer.setZero();
|
||||
triangularBuffer.diagonal().setOnes();
|
||||
|
||||
ei_gebp_kernel<Scalar, Blocking::mr, Blocking::nr, ei_conj_helper<ConjugateLhs,ConjugateRhs> > gebp_kernel;
|
||||
ei_gemm_pack_lhs<Scalar,Blocking::mr,LhsStorageOrder> pack_lhs;
|
||||
ei_gemm_pack_rhs<Scalar,Blocking::nr,RhsStorageOrder> pack_rhs;
|
||||
// variant of the rhs packer called below with two extra arguments (a stride and an offset)
ei_gemm_pack_rhs<Scalar,Blocking::nr,RhsStorageOrder,true> pack_rhs_panel;
|
||||
|
||||
// Walk the depth (K) direction in panels of at most kc: lower mode starts at
// k2 == 0 and moves forwards, upper mode starts at k2 == size and moves backwards.
for(int k2=IsLowerTriangular ? 0 : size;
|
||||
IsLowerTriangular ? k2<size : k2>0;
|
||||
IsLowerTriangular ? k2+=kc : k2-=kc)
|
||||
{
|
||||
const int actual_kc = std::min(IsLowerTriangular ? size-k2 : k2, kc);
|
||||
// first row/column index of the current panel
int actual_k2 = IsLowerTriangular ? k2 : k2-actual_kc;
|
||||
// number of columns of the dense (non triangular) part of this row panel of rhs:
// columns [0, actual_k2) in lower mode, columns [k2, size) in upper mode
int rs = IsLowerTriangular ? actual_k2 : size - k2;
|
||||
// the dense part is packed after the room reserved for the triangular block
Scalar* geb = blockB+actual_kc*actual_kc*Blocking::PacketSize;
|
||||
|
||||
pack_rhs(geb, &rhs(actual_k2,IsLowerTriangular ? 0 : k2), rhsStride, alpha, actual_kc, rs);
|
||||
|
||||
// pack the triangular part of the rhs padding the unrolled blocks with zeros
|
||||
{
|
||||
for (int j2=0; j2<actual_kc; j2+=SmallPanelWidth)
|
||||
{
|
||||
int actualPanelWidth = std::min<int>(actual_kc-j2, SmallPanelWidth);
|
||||
// global index of the first column of this micro panel
int actual_j2 = actual_k2 + j2;
|
||||
// dense rows of this micro panel inside the triangular block: the rows below
// the micro triangle in lower mode, the rows above it in upper mode
int panelOffset = IsLowerTriangular ? j2+actualPanelWidth : 0;
|
||||
int panelLength = IsLowerTriangular ? actual_kc-j2-actualPanelWidth : j2;
|
||||
// general part
|
||||
pack_rhs_panel(blockB+j2*actual_kc*Blocking::PacketSize,
|
||||
&rhs(actual_k2+panelOffset, actual_j2), rhsStride, alpha,
|
||||
panelLength, actualPanelWidth,
|
||||
actual_kc, panelOffset);
|
||||
|
||||
// append the triangular part via a temporary buffer
|
||||
for (int j=0;j<actualPanelWidth;++j)
|
||||
{
|
||||
// with UnitDiagBit the diagonal keeps the ones written at initialization
if (!(Mode&UnitDiagBit))
|
||||
triangularBuffer.coeffRef(j,j) = rhs(actual_j2+j,actual_j2+j);
|
||||
// copy column j of the stored triangle, diagonal excluded
for (int k=IsLowerTriangular ? j+1 : 0; IsLowerTriangular ? k<actualPanelWidth : k<j; ++k)
|
||||
triangularBuffer.coeffRef(k,j) = rhs(actual_j2+k,actual_j2+j);
|
||||
}
|
||||
|
||||
// same destination panel as above, written at offset j2 (the rows of the micro triangle)
pack_rhs_panel(blockB+j2*actual_kc*Blocking::PacketSize,
|
||||
triangularBuffer.data(), triangularBuffer.stride(), alpha,
|
||||
actualPanelWidth, actualPanelWidth,
|
||||
actual_kc, j2);
|
||||
}
|
||||
}
|
||||
|
||||
// sweep the rows of lhs/res in slices of at most mc rows
for (int i2=0; i2<rows; i2+=mc)
|
||||
{
|
||||
const int actual_mc = std::min(mc,rows-i2);
|
||||
// pack rows [i2, i2+actual_mc) and columns [actual_k2, actual_k2+actual_kc) of lhs
pack_lhs(blockA, &lhs(i2, actual_k2), lhsStride, actual_kc, actual_mc);
|
||||
|
||||
// triangular kernel
|
||||
{
|
||||
for (int j2=0; j2<actual_kc; j2+=SmallPanelWidth)
|
||||
{
|
||||
int actualPanelWidth = std::min<int>(actual_kc-j2, SmallPanelWidth);
|
||||
// depth actually covered by this micro panel (micro triangle plus its dense rows)
int panelLength = IsLowerTriangular ? actual_kc-j2 : j2+actualPanelWidth;
|
||||
// first depth index of that range inside the packed blocks
int blockOffset = IsLowerTriangular ? j2 : 0;
|
||||
|
||||
gebp_kernel(res+i2+(actual_k2+j2)*resStride, resStride,
|
||||
blockA, blockB+j2*actual_kc*Blocking::PacketSize,
|
||||
actual_mc, panelLength, actualPanelWidth,
|
||||
actual_kc, actual_kc, // strides
|
||||
blockOffset, blockOffset*Blocking::PacketSize);// offsets
|
||||
}
|
||||
}
|
||||
// dense part of the row panel, packed in geb above; it updates rs columns of res
gebp_kernel(res+i2+(IsLowerTriangular ? 0 : k2)*resStride, resStride,
|
||||
blockA, geb, actual_mc, actual_kc, rs);
|
||||
}
|
||||
}
|
||||
|
||||
// release the scratch buffers with the same sizes they were requested with
ei_aligned_stack_delete(Scalar, blockA, kc*mc);
|
||||
ei_aligned_stack_delete(Scalar, blockB, kc*cols*Blocking::PacketSize);
|
||||
}
|
||||
};
|
||||
|
||||
/***************************************************************************
|
||||
* Wrapper to ei_product_triangular_matrix_matrix
|
||||
***************************************************************************/
|
||||
|
||||
template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
|
||||
struct ei_triangular_product_returntype<Mode,LhsIsTriangular,Lhs,false,Rhs,false>
|
||||
: public ReturnByValue<ei_triangular_product_returntype<Mode,LhsIsTriangular,Lhs,false,Rhs,false>,
|
||||
Matrix<typename ei_traits<Rhs>::Scalar,
|
||||
Lhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> >
|
||||
{
|
||||
ei_triangular_product_returntype(const Lhs& lhs, const Rhs& rhs)
|
||||
: m_lhs(lhs), m_rhs(rhs)
|
||||
{}
|
||||
|
||||
typedef typename Lhs::Scalar Scalar;
|
||||
|
||||
typedef typename Lhs::Nested LhsNested;
|
||||
typedef typename ei_cleantype<LhsNested>::type _LhsNested;
|
||||
typedef ei_blas_traits<_LhsNested> LhsBlasTraits;
|
||||
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
|
||||
typedef typename ei_cleantype<ActualLhsType>::type _ActualLhsType;
|
||||
|
||||
typedef typename Rhs::Nested RhsNested;
|
||||
typedef typename ei_cleantype<RhsNested>::type _RhsNested;
|
||||
typedef ei_blas_traits<_RhsNested> RhsBlasTraits;
|
||||
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
|
||||
typedef typename ei_cleantype<ActualRhsType>::type _ActualRhsType;
|
||||
|
||||
// enum {
|
||||
// LhsUpLo = LhsMode&(UpperTriangularBit|LowerTriangularBit),
|
||||
// LhsIsTriangular = (LhsMode&SelfAdjointBit)==SelfAdjointBit,
|
||||
// RhsUpLo = RhsMode&(UpperTriangularBit|LowerTriangularBit),
|
||||
// RhsIsSelfAdjoint = (RhsMode&SelfAdjointBit)==SelfAdjointBit
|
||||
// };
|
||||
|
||||
template<typename Dest> inline void _addTo(Dest& dst) const
|
||||
{ evalTo(dst,1); }
|
||||
template<typename Dest> inline void _subTo(Dest& dst) const
|
||||
{ evalTo(dst,-1); }
|
||||
|
||||
template<typename Dest> void evalTo(Dest& dst) const
|
||||
{
|
||||
dst.resize(m_lhs.rows(), m_rhs.cols());
|
||||
dst.setZero();
|
||||
evalTo(dst,1);
|
||||
}
|
||||
|
||||
template<typename Dest> void evalTo(Dest& dst, Scalar alpha) const
|
||||
{
|
||||
const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs);
|
||||
const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs);
|
||||
|
||||
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
|
||||
* RhsBlasTraits::extractScalarFactor(m_rhs);
|
||||
|
||||
ei_product_triangular_matrix_matrix<Scalar,
|
||||
Mode, LhsIsTriangular,
|
||||
(ei_traits<_ActualLhsType>::Flags&RowMajorBit) ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate,
|
||||
(ei_traits<_ActualRhsType>::Flags&RowMajorBit) ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate,
|
||||
(ei_traits<Dest >::Flags&RowMajorBit) ? RowMajor : ColMajor>
|
||||
::run(
|
||||
lhs.rows(), LhsIsTriangular ? rhs.cols() : lhs.rows(), // sizes
|
||||
&lhs.coeff(0,0), lhs.stride(), // lhs info
|
||||
&rhs.coeff(0,0), rhs.stride(), // rhs info
|
||||
&dst.coeffRef(0,0), dst.stride(), // result info
|
||||
actualAlpha // alpha
|
||||
);
|
||||
}
|
||||
|
||||
const LhsNested m_lhs;
|
||||
const RhsNested m_rhs;
|
||||
};
|
||||
|
||||
#endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_H
|
@ -113,49 +113,113 @@ struct ei_product_triangular_vector_selector<Lhs,Rhs,Result,Mode,ConjLhs,ConjRhs
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Lhs,typename Rhs>
|
||||
struct ei_triangular_vector_product_returntype
|
||||
: public ReturnByValue<ei_triangular_vector_product_returntype<Lhs,Rhs>,
|
||||
// template<typename Lhs,typename Rhs>
|
||||
// struct ei_triangular_vector_product_returntype
|
||||
// : public ReturnByValue<ei_triangular_vector_product_returntype<Lhs,Rhs>,
|
||||
// Matrix<typename ei_traits<Rhs>::Scalar,
|
||||
// Rhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> >
|
||||
// {
|
||||
// typedef typename Lhs::Scalar Scalar;
|
||||
// typedef typename ei_cleantype<typename Rhs::Nested>::type RhsNested;
|
||||
// ei_triangular_vector_product_returntype(const Lhs& lhs, const Rhs& rhs, Scalar alpha)
|
||||
// : m_lhs(lhs), m_rhs(rhs), m_alpha(alpha)
|
||||
// {}
|
||||
//
|
||||
// template<typename Dest> void evalTo(Dest& dst) const
|
||||
// {
|
||||
// typedef typename Lhs::MatrixType MatrixType;
|
||||
//
|
||||
// typedef ei_blas_traits<MatrixType> LhsBlasTraits;
|
||||
// typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
|
||||
// typedef typename ei_cleantype<ActualLhsType>::type _ActualLhsType;
|
||||
// const ActualLhsType actualLhs = LhsBlasTraits::extract(m_lhs._expression());
|
||||
//
|
||||
// typedef ei_blas_traits<Rhs> RhsBlasTraits;
|
||||
// typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
|
||||
// typedef typename ei_cleantype<ActualRhsType>::type _ActualRhsType;
|
||||
// const ActualRhsType actualRhs = RhsBlasTraits::extract(m_rhs);
|
||||
//
|
||||
// Scalar actualAlpha = m_alpha * LhsBlasTraits::extractScalarFactor(m_lhs._expression())
|
||||
// * RhsBlasTraits::extractScalarFactor(m_rhs);
|
||||
//
|
||||
// dst.resize(m_rhs.rows(), m_rhs.cols());
|
||||
// dst.setZero();
|
||||
// ei_product_triangular_vector_selector
|
||||
// <_ActualLhsType,_ActualRhsType,Dest,
|
||||
// ei_traits<Lhs>::Mode,
|
||||
// LhsBlasTraits::NeedToConjugate,
|
||||
// RhsBlasTraits::NeedToConjugate,
|
||||
// ei_traits<Lhs>::Flags&RowMajorBit>
|
||||
// ::run(actualLhs,actualRhs,dst,actualAlpha);
|
||||
// }
|
||||
//
|
||||
// const Lhs m_lhs;
|
||||
// const typename Rhs::Nested m_rhs;
|
||||
// const Scalar m_alpha;
|
||||
// };
|
||||
|
||||
|
||||
/***************************************************************************
|
||||
* Wrapper to ei_product_triangular_vector
|
||||
***************************************************************************/
|
||||
|
||||
template<int Mode, /*bool LhsIsTriangular, */typename Lhs, typename Rhs>
|
||||
struct ei_triangular_product_returntype<Mode,true,Lhs,false,Rhs,true>
|
||||
: public ReturnByValue<ei_triangular_product_returntype<Mode,true,Lhs,false,Rhs,true>,
|
||||
Matrix<typename ei_traits<Rhs>::Scalar,
|
||||
Rhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> >
|
||||
{
|
||||
typedef typename Lhs::Scalar Scalar;
|
||||
typedef typename ei_cleantype<typename Rhs::Nested>::type RhsNested;
|
||||
ei_triangular_vector_product_returntype(const Lhs& lhs, const Rhs& rhs, Scalar alpha)
|
||||
: m_lhs(lhs), m_rhs(rhs), m_alpha(alpha)
|
||||
|
||||
typedef typename Lhs::Nested LhsNested;
|
||||
typedef typename ei_cleantype<LhsNested>::type _LhsNested;
|
||||
typedef ei_blas_traits<_LhsNested> LhsBlasTraits;
|
||||
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
|
||||
typedef typename ei_cleantype<ActualLhsType>::type _ActualLhsType;
|
||||
|
||||
typedef typename Rhs::Nested RhsNested;
|
||||
typedef typename ei_cleantype<RhsNested>::type _RhsNested;
|
||||
typedef ei_blas_traits<_RhsNested> RhsBlasTraits;
|
||||
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
|
||||
typedef typename ei_cleantype<ActualRhsType>::type _ActualRhsType;
|
||||
|
||||
ei_triangular_product_returntype(const Lhs& lhs, const Rhs& rhs)
|
||||
: m_lhs(lhs), m_rhs(rhs)
|
||||
{}
|
||||
|
||||
template<typename Dest> inline void _addTo(Dest& dst) const
|
||||
{ evalTo(dst,1); }
|
||||
template<typename Dest> inline void _subTo(Dest& dst) const
|
||||
{ evalTo(dst,-1); }
|
||||
|
||||
template<typename Dest> void evalTo(Dest& dst) const
|
||||
{
|
||||
typedef typename Lhs::MatrixType MatrixType;
|
||||
|
||||
typedef ei_blas_traits<MatrixType> LhsBlasTraits;
|
||||
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
|
||||
typedef typename ei_cleantype<ActualLhsType>::type _ActualLhsType;
|
||||
const ActualLhsType actualLhs = LhsBlasTraits::extract(m_lhs._expression());
|
||||
dst.resize(m_lhs.rows(), m_rhs.cols());
|
||||
dst.setZero();
|
||||
evalTo(dst,1);
|
||||
}
|
||||
|
||||
typedef ei_blas_traits<Rhs> RhsBlasTraits;
|
||||
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
|
||||
typedef typename ei_cleantype<ActualRhsType>::type _ActualRhsType;
|
||||
const ActualRhsType actualRhs = RhsBlasTraits::extract(m_rhs);
|
||||
template<typename Dest> void evalTo(Dest& dst, Scalar alpha) const
|
||||
{
|
||||
const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs);
|
||||
const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs);
|
||||
|
||||
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
|
||||
* RhsBlasTraits::extractScalarFactor(m_rhs);
|
||||
|
||||
Scalar actualAlpha = m_alpha * LhsBlasTraits::extractScalarFactor(m_lhs._expression())
|
||||
* RhsBlasTraits::extractScalarFactor(m_rhs);
|
||||
|
||||
dst.resize(m_rhs.rows(), m_rhs.cols());
|
||||
dst.setZero();
|
||||
ei_product_triangular_vector_selector
|
||||
<_ActualLhsType,_ActualRhsType,Dest,
|
||||
ei_traits<Lhs>::Mode,
|
||||
Mode,
|
||||
LhsBlasTraits::NeedToConjugate,
|
||||
RhsBlasTraits::NeedToConjugate,
|
||||
ei_traits<Lhs>::Flags&RowMajorBit>
|
||||
::run(actualLhs,actualRhs,dst,actualAlpha);
|
||||
::run(lhs,rhs,dst,actualAlpha);
|
||||
}
|
||||
|
||||
const Lhs m_lhs;
|
||||
const typename Rhs::Nested m_rhs;
|
||||
const Scalar m_alpha;
|
||||
const LhsNested m_lhs;
|
||||
const RhsNested m_rhs;
|
||||
};
|
||||
|
||||
#endif // EIGEN_TRIANGULARMATRIXVECTOR_H
|
||||
|
@ -98,10 +98,6 @@ ei_add_test(redux)
|
||||
ei_add_test(product_small)
|
||||
ei_add_test(product_large ${EI_OFLAG})
|
||||
ei_add_test(product_extra ${EI_OFLAG})
|
||||
ei_add_test(product_selfadjoint ${EI_OFLAG})
|
||||
ei_add_test(product_symm ${EI_OFLAG})
|
||||
ei_add_test(product_syrk ${EI_OFLAG})
|
||||
ei_add_test(product_trsm ${EI_OFLAG})
|
||||
ei_add_test(diagonalmatrices)
|
||||
ei_add_test(adjoint)
|
||||
ei_add_test(submatrices)
|
||||
@ -113,7 +109,12 @@ ei_add_test(array)
|
||||
ei_add_test(array_replicate)
|
||||
ei_add_test(array_reverse)
|
||||
ei_add_test(triangular)
|
||||
ei_add_test(product_triangular)
|
||||
ei_add_test(product_selfadjoint ${EI_OFLAG})
|
||||
ei_add_test(product_symm ${EI_OFLAG})
|
||||
ei_add_test(product_syrk ${EI_OFLAG})
|
||||
ei_add_test(product_trmv ${EI_OFLAG})
|
||||
ei_add_test(product_trmm ${EI_OFLAG})
|
||||
ei_add_test(product_trsm ${EI_OFLAG})
|
||||
ei_add_test(bandmatrix)
|
||||
ei_add_test(cholesky " " "${GSL_LIBRARIES}")
|
||||
ei_add_test(lu ${EI_OFLAG})
|
||||
|
69
test/product_trmm.cpp
Normal file
69
test/product_trmm.cpp
Normal file
@ -0,0 +1,69 @@
|
||||
// This file is part of Eigen, a lightweight C++ template library
|
||||
// for linear algebra.
|
||||
//
|
||||
// Copyright (C) 2008-2009 Gael Guennebaud <gael.guennebaud@gmail.com>
|
||||
//
|
||||
// Eigen is free software; you can redistribute it and/or
|
||||
// modify it under the terms of the GNU Lesser General Public
|
||||
// License as published by the Free Software Foundation; either
|
||||
// version 3 of the License, or (at your option) any later version.
|
||||
//
|
||||
// Alternatively, you can redistribute it and/or
|
||||
// modify it under the terms of the GNU General Public License as
|
||||
// published by the Free Software Foundation; either version 2 of
|
||||
// the License, or (at your option) any later version.
|
||||
//
|
||||
// Eigen is distributed in the hope that it will be useful, but WITHOUT ANY
|
||||
// WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
|
||||
// FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Lesser General Public
|
||||
// License and a copy of the GNU General Public License along with
|
||||
// Eigen. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
#include "main.h"
|
||||
|
||||
template<typename Scalar> void trmm(int size,int othersize)
|
||||
{
|
||||
typedef typename NumTraits<Scalar>::Real RealScalar;
|
||||
|
||||
Matrix<Scalar,Dynamic,Dynamic,ColMajor> tri(size,size), upTri(size,size), loTri(size,size);
|
||||
Matrix<Scalar,Dynamic,Dynamic,ColMajor> ge1(size,othersize), ge2(10,size), ge3;
|
||||
Matrix<Scalar,Dynamic,Dynamic,RowMajor> rge3;
|
||||
|
||||
Scalar s1 = ei_random<Scalar>(),
|
||||
s2 = ei_random<Scalar>();
|
||||
|
||||
tri.setRandom();
|
||||
loTri = tri.template triangularView<LowerTriangular>();
|
||||
upTri = tri.template triangularView<UpperTriangular>();
|
||||
ge1.setRandom();
|
||||
ge2.setRandom();
|
||||
|
||||
VERIFY_IS_APPROX( ge3 = tri.template triangularView<LowerTriangular>() * ge1, loTri * ge1);
|
||||
VERIFY_IS_APPROX(rge3 = tri.template triangularView<LowerTriangular>() * ge1, loTri * ge1);
|
||||
VERIFY_IS_APPROX( ge3 = tri.template triangularView<UpperTriangular>() * ge1, upTri * ge1);
|
||||
VERIFY_IS_APPROX(rge3 = tri.template triangularView<UpperTriangular>() * ge1, upTri * ge1);
|
||||
VERIFY_IS_APPROX( ge3 = (s1*tri.adjoint()).template triangularView<UpperTriangular>() * (s2*ge1), s1*loTri.adjoint() * (s2*ge1));
|
||||
VERIFY_IS_APPROX(rge3 = tri.adjoint().template triangularView<UpperTriangular>() * ge1, loTri.adjoint() * ge1);
|
||||
VERIFY_IS_APPROX( ge3 = tri.adjoint().template triangularView<LowerTriangular>() * ge1, upTri.adjoint() * ge1);
|
||||
VERIFY_IS_APPROX(rge3 = tri.adjoint().template triangularView<LowerTriangular>() * ge1, upTri.adjoint() * ge1);
|
||||
VERIFY_IS_APPROX( ge3 = tri.template triangularView<LowerTriangular>() * ge2.adjoint(), loTri * ge2.adjoint());
|
||||
VERIFY_IS_APPROX(rge3 = tri.template triangularView<LowerTriangular>() * ge2.adjoint(), loTri * ge2.adjoint());
|
||||
VERIFY_IS_APPROX( ge3 = tri.template triangularView<UpperTriangular>() * ge2.adjoint(), upTri * ge2.adjoint());
|
||||
VERIFY_IS_APPROX(rge3 = tri.template triangularView<UpperTriangular>() * ge2.adjoint(), upTri * ge2.adjoint());
|
||||
VERIFY_IS_APPROX( ge3 = tri.adjoint().template triangularView<UpperTriangular>() * ge2.adjoint(), loTri.adjoint() * ge2.adjoint());
|
||||
VERIFY_IS_APPROX(rge3 = tri.adjoint().template triangularView<UpperTriangular>() * ge2.adjoint(), loTri.adjoint() * ge2.adjoint());
|
||||
VERIFY_IS_APPROX( ge3 = tri.adjoint().template triangularView<LowerTriangular>() * ge2.adjoint(), upTri.adjoint() * ge2.adjoint());
|
||||
VERIFY_IS_APPROX(rge3 = tri.adjoint().template triangularView<LowerTriangular>() * ge2.adjoint(), upTri.adjoint() * ge2.adjoint());
|
||||
}
|
||||
|
||||
void test_product_trmm()
|
||||
{
|
||||
for(int i = 0; i < g_repeat ; i++)
|
||||
{
|
||||
trmm<float>(ei_random<int>(1,320),ei_random<int>(1,320));
|
||||
trmm<std::complex<double> >(ei_random<int>(1,320),ei_random<int>(1,320));
|
||||
}
|
||||
}
|
@ -85,6 +85,7 @@ template<typename Scalar> void trsm(int size,int cols)
|
||||
solve_ref(rmLhs.template triangularView<UpperTriangular>(),rmRef);
|
||||
VERIFY_IS_APPROX(rmRhs, rmRef);
|
||||
}
|
||||
|
||||
void test_product_trsm()
|
||||
{
|
||||
for(int i = 0; i < g_repeat ; i++)
|
||||
|
Loading…
x
Reference in New Issue
Block a user