addd matrix * self adjoint high level API

This commit is contained in:
Gael Guennebaud 2009-07-23 10:05:38 +02:00
parent f696efc00e
commit ddb3ac98a2
2 changed files with 69 additions and 37 deletions

View File

@ -53,8 +53,9 @@ struct ei_traits<SelfAdjointView<MatrixType, TriangularPart> > : ei_traits<Matri
}; };
}; };
template<typename Lhs,typename Rhs,bool RhsIsVector=Rhs::IsVectorAtCompileTime> template <typename Lhs, int LhsMode, bool LhsIsVector,
struct ei_selfadjoint_matrix_product_returntype; typename Rhs, int RhsMode, bool RhsIsVector>
struct ei_selfadjoint_product_returntype;
// FIXME could also be called SelfAdjointWrapper to be consistent with DiagonalWrapper ?? // FIXME could also be called SelfAdjointWrapper to be consistent with DiagonalWrapper ??
template<typename MatrixType, unsigned int UpLo> class SelfAdjointView template<typename MatrixType, unsigned int UpLo> class SelfAdjointView
@ -99,10 +100,22 @@ template<typename MatrixType, unsigned int UpLo> class SelfAdjointView
/** Efficient self-adjoint matrix times vector/matrix product */ /** Efficient self-adjoint matrix times vector/matrix product */
template<typename OtherDerived> template<typename OtherDerived>
ei_selfadjoint_matrix_product_returntype<SelfAdjointView,OtherDerived> ei_selfadjoint_product_returntype<MatrixType,Mode,false,OtherDerived,0,OtherDerived::IsVectorAtCompileTime>
operator*(const MatrixBase<OtherDerived>& rhs) const operator*(const MatrixBase<OtherDerived>& rhs) const
{ {
return ei_selfadjoint_matrix_product_returntype<SelfAdjointView,OtherDerived>(*this, rhs.derived()); return ei_selfadjoint_product_returntype
<MatrixType,Mode,false,OtherDerived,0,OtherDerived::IsVectorAtCompileTime>
(m_matrix, rhs.derived());
}
/** Efficient vector/matrix times self-adjoint matrix product */
template<typename OtherDerived> friend
ei_selfadjoint_product_returntype<OtherDerived,0,OtherDerived::IsVectorAtCompileTime,MatrixType,Mode,false>
operator*(const MatrixBase<OtherDerived>& lhs, const SelfAdjointView& rhs)
{
return ei_selfadjoint_product_returntype
<OtherDerived,0,OtherDerived::IsVectorAtCompileTime,MatrixType,Mode,false>
(lhs.derived(),rhs.m_matrix);
} }
/** Perform a symmetric rank 2 update of the selfadjoint matrix \c *this: /** Perform a symmetric rank 2 update of the selfadjoint matrix \c *this:
@ -125,6 +138,14 @@ template<typename MatrixType, unsigned int UpLo> class SelfAdjointView
const typename MatrixType::Nested m_matrix; const typename MatrixType::Nested m_matrix;
}; };
// template<typename OtherDerived, typename MatrixType, unsigned int UpLo>
// ei_selfadjoint_matrix_product_returntype<OtherDerived,SelfAdjointView<MatrixType,UpLo> >
// operator*(const MatrixBase<OtherDerived>& lhs, const SelfAdjointView<MatrixType,UpLo>& rhs)
// {
// return ei_matrix_selfadjoint_product_returntype<OtherDerived,SelfAdjointView<MatrixType,UpLo> >(lhs.derived(),rhs);
// }
template<typename Derived1, typename Derived2, int UnrollCount, bool ClearOpposite> template<typename Derived1, typename Derived2, int UnrollCount, bool ClearOpposite>
struct ei_triangular_assignment_selector<Derived1, Derived2, SelfAdjoint, UnrollCount, ClearOpposite> struct ei_triangular_assignment_selector<Derived1, Derived2, SelfAdjoint, UnrollCount, ClearOpposite>
{ {
@ -163,14 +184,14 @@ struct ei_triangular_assignment_selector<Derived1, Derived2, SelfAdjoint, Dynami
* Wrapper to ei_product_selfadjoint_vector * Wrapper to ei_product_selfadjoint_vector
***************************************************************************/ ***************************************************************************/
template<typename Lhs,typename Rhs> template<typename Lhs, int LhsMode, typename Rhs, int RhsMode>
struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,true> struct ei_selfadjoint_product_returntype<Lhs,LhsMode,false,Rhs,RhsMode,true>
: public ReturnByValue<ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,true>, : public ReturnByValue<ei_selfadjoint_product_returntype<Lhs,LhsMode,false,Rhs,RhsMode,true>,
Matrix<typename ei_traits<Rhs>::Scalar, Matrix<typename ei_traits<Rhs>::Scalar,
Rhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> > Rhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> >
{ {
typedef typename ei_cleantype<typename Rhs::Nested>::type RhsNested; typedef typename ei_cleantype<typename Rhs::Nested>::type RhsNested;
ei_selfadjoint_matrix_product_returntype(const Lhs& lhs, const Rhs& rhs) ei_selfadjoint_product_returntype(const Lhs& lhs, const Rhs& rhs)
: m_lhs(lhs), m_rhs(rhs) : m_lhs(lhs), m_rhs(rhs)
{} {}
@ -178,10 +199,10 @@ struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,true>
{ {
dst.resize(m_rhs.rows(), m_rhs.cols()); dst.resize(m_rhs.rows(), m_rhs.cols());
ei_product_selfadjoint_vector<typename Lhs::Scalar,ei_traits<Lhs>::Flags&RowMajorBit, ei_product_selfadjoint_vector<typename Lhs::Scalar,ei_traits<Lhs>::Flags&RowMajorBit,
Lhs::Mode&(UpperTriangularBit|LowerTriangularBit)> LhsMode&(UpperTriangularBit|LowerTriangularBit)>
( (
m_lhs.rows(), // size m_lhs.rows(), // size
m_lhs._expression().data(), // lhs m_lhs.data(), // lhs
m_lhs.stride(), // lhsStride, m_lhs.stride(), // lhsStride,
m_rhs.data(), // rhs m_rhs.data(), // rhs
// int rhsIncr, // int rhsIncr,
@ -189,7 +210,7 @@ struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,true>
); );
} }
const Lhs m_lhs; const typename Lhs::Nested m_lhs;
const typename Rhs::Nested m_rhs; const typename Rhs::Nested m_rhs;
}; };
@ -197,25 +218,36 @@ struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,true>
* Wrapper to ei_product_selfadjoint_matrix * Wrapper to ei_product_selfadjoint_matrix
***************************************************************************/ ***************************************************************************/
template<typename Lhs,typename Rhs> template<typename Lhs, int LhsMode, typename Rhs, int RhsMode>
struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,false> struct ei_selfadjoint_product_returntype<Lhs,LhsMode,false,Rhs,RhsMode,false>
: public ReturnByValue<ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,false>, : public ReturnByValue<ei_selfadjoint_product_returntype<Lhs,LhsMode,false,Rhs,RhsMode,false>,
Matrix<typename ei_traits<Rhs>::Scalar, Matrix<typename ei_traits<Rhs>::Scalar,
Rhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> > Lhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> >
{ {
ei_selfadjoint_matrix_product_returntype(const Lhs& lhs, const Rhs& rhs) ei_selfadjoint_product_returntype(const Lhs& lhs, const Rhs& rhs)
: m_lhs(lhs), m_rhs(rhs) : m_lhs(lhs), m_rhs(rhs)
{} {}
typedef typename Lhs::Scalar Scalar; typedef typename Lhs::Scalar Scalar;
typedef typename Lhs::Nested LhsNested;
typedef typename ei_cleantype<LhsNested>::type _LhsNested;
typedef ei_blas_traits<_LhsNested> LhsBlasTraits;
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
typedef typename ei_cleantype<ActualLhsType>::type _ActualLhsType;
typedef typename Rhs::Nested RhsNested; typedef typename Rhs::Nested RhsNested;
typedef typename ei_cleantype<RhsNested>::type _RhsNested; typedef typename ei_cleantype<RhsNested>::type _RhsNested;
typedef ei_blas_traits<_RhsNested> RhsBlasTraits;
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
typedef typename ei_cleantype<ActualRhsType>::type _ActualRhsType;
typedef typename ei_traits<Lhs>::ExpressionType LhsExpr; enum {
typedef typename LhsExpr::Nested LhsNested; LhsUpLo = LhsMode&(UpperTriangularBit|LowerTriangularBit),
typedef typename ei_cleantype<LhsNested>::type _LhsNested; LhsIsSelfAdjoint = (LhsMode&SelfAdjointBit)==SelfAdjointBit,
RhsUpLo = RhsMode&(UpperTriangularBit|LowerTriangularBit),
enum { UpLo = ei_traits<Lhs>::Mode&(UpperTriangularBit|LowerTriangularBit) }; RhsIsSelfAdjoint = (RhsMode&SelfAdjointBit)==SelfAdjointBit
};
template<typename Dest> inline void _addTo(Dest& dst) const template<typename Dest> inline void _addTo(Dest& dst) const
{ evalTo(dst,1); } { evalTo(dst,1); }
@ -231,26 +263,19 @@ struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,false>
template<typename Dest> void evalTo(Dest& dst, Scalar alpha) const template<typename Dest> void evalTo(Dest& dst, Scalar alpha) const
{ {
typedef ei_blas_traits<_LhsNested> LhsBlasTraits; const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs);
typedef ei_blas_traits<_RhsNested> RhsBlasTraits;
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
typedef typename ei_cleantype<ActualLhsType>::type _ActualLhsType;
typedef typename ei_cleantype<ActualRhsType>::type _ActualRhsType;
const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs._expression());
const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs); const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs);
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs._expression()) Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
* RhsBlasTraits::extractScalarFactor(m_rhs); * RhsBlasTraits::extractScalarFactor(m_rhs);
ei_product_selfadjoint_matrix<Scalar, ei_product_selfadjoint_matrix<Scalar,
EIGEN_LOGICAL_XOR(UpLo==UpperTriangular, EIGEN_LOGICAL_XOR(LhsUpLo==UpperTriangular,
ei_traits<Lhs>::Flags &RowMajorBit) ? RowMajor : ColMajor, true, ei_traits<Lhs>::Flags &RowMajorBit) ? RowMajor : ColMajor, LhsIsSelfAdjoint,
NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(UpLo==UpperTriangular,bool(LhsBlasTraits::NeedToConjugate)), NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(LhsUpLo==UpperTriangular,bool(LhsBlasTraits::NeedToConjugate)),
ei_traits<Rhs>::Flags &RowMajorBit ? RowMajor : ColMajor, false, bool(RhsBlasTraits::NeedToConjugate), EIGEN_LOGICAL_XOR(RhsUpLo==UpperTriangular,
ei_traits<Rhs>::Flags &RowMajorBit) ? RowMajor : ColMajor, RhsIsSelfAdjoint,
NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(RhsUpLo==UpperTriangular,bool(RhsBlasTraits::NeedToConjugate)),
ei_traits<Dest>::Flags&RowMajorBit ? RowMajor : ColMajor> ei_traits<Dest>::Flags&RowMajorBit ? RowMajor : ColMajor>
::run( ::run(
lhs.rows(), rhs.cols(), // sizes lhs.rows(), rhs.cols(), // sizes
@ -261,7 +286,7 @@ struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,false>
); );
} }
const Lhs m_lhs; const LhsNested m_lhs;
const RhsNested m_rhs; const RhsNested m_rhs;
}; };

View File

@ -138,6 +138,13 @@ template<typename MatrixType> void symm(const MatrixType& m)
m2 = m1.template triangularView<UpperTriangular>(); m2 = m1.template triangularView<UpperTriangular>();
VERIFY_IS_APPROX(rhs32 = (s1*m2.adjoint()).template selfadjointView<LowerTriangular>() * (s2*rhs3).conjugate(), VERIFY_IS_APPROX(rhs32 = (s1*m2.adjoint()).template selfadjointView<LowerTriangular>() * (s2*rhs3).conjugate(),
rhs33 = (s1*m1.adjoint()) * (s2*rhs3).conjugate()); rhs33 = (s1*m1.adjoint()) * (s2*rhs3).conjugate());
// test matrix * selfadjoint
m2 = m1.template triangularView<LowerTriangular>();
VERIFY_IS_APPROX(rhs22 = (rhs2) * (m2).template selfadjointView<LowerTriangular>(),
rhs23 = (rhs2) * (m1));
VERIFY_IS_APPROX(rhs22 = (s2*rhs2) * (s1*m2).template selfadjointView<LowerTriangular>(),
rhs23 = (s2*rhs2) * (s1*m1));
} }
void test_product_selfadjoint() void test_product_selfadjoint()
{ {