implement high level API for SYMM and fix a couple of bugs related to complex

This commit is contained in:
Gael Guennebaud 2009-07-22 23:12:22 +02:00
parent e7f8e939e2
commit 0cb4f32e12
8 changed files with 200 additions and 56 deletions

View File

@ -300,6 +300,10 @@ template<typename Derived> class MatrixBase
template<typename OtherDerived,typename OtherEvalType>
Derived& operator=(const ReturnByValue<OtherDerived,OtherEvalType>& func);
template<typename OtherDerived,typename OtherEvalType>
Derived& operator+=(const ReturnByValue<OtherDerived,OtherEvalType>& func);
template<typename OtherDerived,typename OtherEvalType>
Derived& operator-=(const ReturnByValue<OtherDerived,OtherEvalType>& func);
#ifndef EIGEN_PARSED_BY_DOXYGEN
/** Copies \a other into *this without evaluating other. \returns a reference to *this. */

View File

@ -46,20 +46,36 @@ struct ei_nested<ReturnByValue<Functor,MatrixBase<EvalTypeDerived> >, n, EvalTyp
template<typename Functor, typename EvalType> class ReturnByValue
{
public:
template<typename Dest>
inline void evalTo(Dest& dst) const
template<typename Dest> inline void evalTo(Dest& dst) const
{ static_cast<const Functor*>(this)->evalTo(dst); }
template<typename Dest> inline void addTo(Dest& dst) const
{ static_cast<const Functor*>(this)->_addTo(dst); }
template<typename Dest> inline void subTo(Dest& dst) const
{ static_cast<const Functor*>(this)->_subTo(dst); }
template<typename Dest> inline void _addTo(Dest& dst) const
{ EvalType res; evalTo(res); dst += res; }
template<typename Dest> inline void _subTo(Dest& dst) const
{ EvalType res; evalTo(res); dst -= res; }
};
template<typename Functor, typename _Scalar,int _Rows,int _Cols,int _Options,int _MaxRows,int _MaxCols>
class ReturnByValue<Functor,Matrix<_Scalar,_Rows,_Cols,_Options,_MaxRows,_MaxCols> >
: public MatrixBase<ReturnByValue<Functor,Matrix<_Scalar,_Rows,_Cols,_Options,_MaxRows,_MaxCols> > >
{
typedef Matrix<_Scalar,_Rows,_Cols,_Options,_MaxRows,_MaxCols> EvalType;
public:
EIGEN_GENERIC_PUBLIC_INTERFACE(ReturnByValue)
template<typename Dest>
inline void evalTo(Dest& dst) const
{ static_cast<const Functor* const>(this)->evalTo(dst); }
template<typename Dest> inline void addTo(Dest& dst) const
{ static_cast<const Functor*>(this)->_addTo(dst); }
template<typename Dest> inline void subTo(Dest& dst) const
{ static_cast<const Functor*>(this)->_subTo(dst); }
template<typename Dest> inline void _addTo(Dest& dst) const
{ EvalType res; evalTo(res); dst += res; }
template<typename Dest> inline void _subTo(Dest& dst) const
{ EvalType res; evalTo(res); dst -= res; }
};
template<typename Derived>
@ -70,4 +86,20 @@ Derived& MatrixBase<Derived>::operator=(const ReturnByValue<OtherDerived,OtherEv
return derived();
}
template<typename Derived>
template<typename OtherDerived,typename OtherEvalType>
Derived& MatrixBase<Derived>::operator+=(const ReturnByValue<OtherDerived,OtherEvalType>& other)
{
other.addTo(derived());
return derived();
}
template<typename Derived>
template<typename OtherDerived,typename OtherEvalType>
Derived& MatrixBase<Derived>::operator-=(const ReturnByValue<OtherDerived,OtherEvalType>& other)
{
other.subTo(derived());
return derived();
}
#endif // EIGEN_RETURNBYVALUE_H

View File

@ -53,8 +53,8 @@ struct ei_traits<SelfAdjointView<MatrixType, TriangularPart> > : ei_traits<Matri
};
};
template<typename Lhs,typename Rhs>
struct ei_selfadjoint_vector_product_returntype;
template<typename Lhs,typename Rhs,bool RhsIsVector=Rhs::IsVectorAtCompileTime>
struct ei_selfadjoint_matrix_product_returntype;
// FIXME could also be called SelfAdjointWrapper to be consistent with DiagonalWrapper ??
template<typename MatrixType, unsigned int UpLo> class SelfAdjointView
@ -97,13 +97,12 @@ template<typename MatrixType, unsigned int UpLo> class SelfAdjointView
/** \internal */
const MatrixType& _expression() const { return m_matrix; }
/** Efficient self-adjoint matrix times vector product */
// TODO this product is far to be ready
/** Efficient self-adjoint matrix times vector/matrix product */
template<typename OtherDerived>
ei_selfadjoint_vector_product_returntype<SelfAdjointView,OtherDerived>
ei_selfadjoint_matrix_product_returntype<SelfAdjointView,OtherDerived>
operator*(const MatrixBase<OtherDerived>& rhs) const
{
return ei_selfadjoint_vector_product_returntype<SelfAdjointView,OtherDerived>(*this, rhs.derived());
return ei_selfadjoint_matrix_product_returntype<SelfAdjointView,OtherDerived>(*this, rhs.derived());
}
/** Perform a symmetric rank 2 update of the selfadjoint matrix \c *this:
@ -165,13 +164,13 @@ struct ei_triangular_assignment_selector<Derived1, Derived2, SelfAdjoint, Dynami
***************************************************************************/
template<typename Lhs,typename Rhs>
struct ei_selfadjoint_vector_product_returntype
: public ReturnByValue<ei_selfadjoint_vector_product_returntype<Lhs,Rhs>,
struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,true>
: public ReturnByValue<ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,true>,
Matrix<typename ei_traits<Rhs>::Scalar,
Rhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> >
{
typedef typename ei_cleantype<typename Rhs::Nested>::type RhsNested;
ei_selfadjoint_vector_product_returntype(const Lhs& lhs, const Rhs& rhs)
ei_selfadjoint_matrix_product_returntype(const Lhs& lhs, const Rhs& rhs)
: m_lhs(lhs), m_rhs(rhs)
{}
@ -194,6 +193,78 @@ struct ei_selfadjoint_vector_product_returntype
const typename Rhs::Nested m_rhs;
};
/***************************************************************************
* Wrapper to ei_product_selfadjoint_matrix
***************************************************************************/
template<typename Lhs,typename Rhs>
struct ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,false>
: public ReturnByValue<ei_selfadjoint_matrix_product_returntype<Lhs,Rhs,false>,
Matrix<typename ei_traits<Rhs>::Scalar,
Rhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> >
{
ei_selfadjoint_matrix_product_returntype(const Lhs& lhs, const Rhs& rhs)
: m_lhs(lhs), m_rhs(rhs)
{}
typedef typename Lhs::Scalar Scalar;
typedef typename Rhs::Nested RhsNested;
typedef typename ei_cleantype<RhsNested>::type _RhsNested;
typedef typename ei_traits<Lhs>::ExpressionType LhsExpr;
typedef typename LhsExpr::Nested LhsNested;
typedef typename ei_cleantype<LhsNested>::type _LhsNested;
enum { UpLo = ei_traits<Lhs>::Mode&(UpperTriangularBit|LowerTriangularBit) };
template<typename Dest> inline void _addTo(Dest& dst) const
{ evalTo(dst,1); }
template<typename Dest> inline void _subTo(Dest& dst) const
{ evalTo(dst,-1); }
template<typename Dest> void evalTo(Dest& dst) const
{
dst.resize(m_lhs.rows(), m_rhs.cols());
dst.setZero();
evalTo(dst,1);
}
template<typename Dest> void evalTo(Dest& dst, Scalar alpha) const
{
typedef ei_blas_traits<_LhsNested> LhsBlasTraits;
typedef ei_blas_traits<_RhsNested> RhsBlasTraits;
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
typedef typename ei_cleantype<ActualLhsType>::type _ActualLhsType;
typedef typename ei_cleantype<ActualRhsType>::type _ActualRhsType;
const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs._expression());
const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs);
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs._expression())
* RhsBlasTraits::extractScalarFactor(m_rhs);
ei_product_selfadjoint_matrix<Scalar,
EIGEN_LOGICAL_XOR(UpLo==UpperTriangular,
ei_traits<Lhs>::Flags &RowMajorBit) ? RowMajor : ColMajor, true,
NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(UpLo==UpperTriangular,bool(LhsBlasTraits::NeedToConjugate)),
ei_traits<Rhs>::Flags &RowMajorBit ? RowMajor : ColMajor, false, bool(RhsBlasTraits::NeedToConjugate),
ei_traits<Dest>::Flags&RowMajorBit ? RowMajor : ColMajor>
::run(
lhs.rows(), rhs.cols(), // sizes
&lhs.coeff(0,0), lhs.stride(), // lhs info
&rhs.coeff(0,0), rhs.stride(), // rhs info
&dst.coeffRef(0,0), dst.stride(), // result info
actualAlpha // alpha
);
}
const Lhs m_lhs;
const RhsNested m_rhs;
};
/***************************************************************************
* Implementation of MatrixBase methods
***************************************************************************/

View File

@ -334,22 +334,23 @@ struct ei_gebp_kernel
};
// pack a block of the lhs
template<typename Scalar, int mr, int StorageOrder>
template<typename Scalar, int mr, int StorageOrder, bool Conjugate>
struct ei_gemm_pack_lhs
{
void operator()(Scalar* blockA, const EIGEN_RESTRICT Scalar* _lhs, int lhsStride, int actual_kc, int actual_mc)
{
ei_conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
ei_const_blas_data_mapper<Scalar, StorageOrder> lhs(_lhs,lhsStride);
int count = 0;
const int peeled_mc = (actual_mc/mr)*mr;
for(int i=0; i<peeled_mc; i+=mr)
for(int k=0; k<actual_kc; k++)
for(int w=0; w<mr; w++)
blockA[count++] = lhs(i+w, k);
blockA[count++] = cj(lhs(i+w, k));
for(int i=peeled_mc; i<actual_mc; i++)
{
for(int k=0; k<actual_kc; k++)
blockA[count++] = lhs(i, k);
blockA[count++] = cj(lhs(i, k));
}
}
};

View File

@ -39,32 +39,30 @@ struct ei_symm_pack_lhs
// normal copy
for(int k=0; k<i; k++)
for(int w=0; w<mr; w++)
blockA[count++] = lhs(i+w,k);
blockA[count++] = lhs(i+w,k); // normal
// symmetric copy
int h = 0;
for(int k=i; k<i+mr; k++)
{
// transposed copy
for(int w=0; w<h; w++)
blockA[count++] = lhs(k, i+w);
blockA[count++] = ei_conj(lhs(k, i+w)); // transposed
for(int w=h; w<mr; w++)
blockA[count++] = lhs(i+w, k);
blockA[count++] = lhs(i+w, k); // normal
++h;
}
// transposed copy
for(int k=i+mr; k<actual_kc; k++)
for(int w=0; w<mr; w++)
blockA[count++] = lhs(k, i+w);
blockA[count++] = ei_conj(lhs(k, i+w)); // transposed
}
// do the same with mr==1
for(int i=peeled_mc; i<actual_mc; i++)
{
for(int k=0; k<=i; k++)
blockA[count++] = lhs(i, k);
// transposed copy
blockA[count++] = lhs(i, k); // normal
for(int k=i+1; k<actual_kc; k++)
blockA[count++] = lhs(k, i);
blockA[count++] = ei_conj(lhs(k, i)); // transposed
}
}
};
@ -79,7 +77,7 @@ struct ei_symm_pack_rhs
int count = 0;
ei_const_blas_data_mapper<Scalar,StorageOrder> rhs(_rhs,rhsStride);
// first part: standard case
// first part: normal case
for(int j2=0; j2<k2; j2+=nr)
{
for(int k=k2; k<end_k; k++)
@ -102,12 +100,12 @@ struct ei_symm_pack_rhs
// transpose
for(int k=k2; k<j2; k++)
{
ei_pstore(&blockB[count+0*PacketSize], ei_pset1(alpha*rhs(j2+0,k)));
ei_pstore(&blockB[count+1*PacketSize], ei_pset1(alpha*rhs(j2+1,k)));
ei_pstore(&blockB[count+0*PacketSize], ei_pset1(alpha*ei_conj(rhs(j2+0,k))));
ei_pstore(&blockB[count+1*PacketSize], ei_pset1(alpha*ei_conj(rhs(j2+1,k))));
if (nr==4)
{
ei_pstore(&blockB[count+2*PacketSize], ei_pset1(alpha*rhs(j2+2,k)));
ei_pstore(&blockB[count+3*PacketSize], ei_pset1(alpha*rhs(j2+3,k)));
ei_pstore(&blockB[count+2*PacketSize], ei_pset1(alpha*ei_conj(rhs(j2+2,k))));
ei_pstore(&blockB[count+3*PacketSize], ei_pset1(alpha*ei_conj(rhs(j2+3,k))));
}
count += nr*PacketSize;
}
@ -120,7 +118,7 @@ struct ei_symm_pack_rhs
ei_pstore(&blockB[count+w*PacketSize], ei_pset1(alpha*rhs(k,j2+w)));
// transpose
for (int w=h ; w<nr; ++w)
ei_pstore(&blockB[count+w*PacketSize], ei_pset1(alpha*rhs(j2+w,k)));
ei_pstore(&blockB[count+w*PacketSize], ei_pset1(alpha*ei_conj(rhs(j2+w,k))));
count += nr*PacketSize;
++h;
}
@ -138,17 +136,17 @@ struct ei_symm_pack_rhs
}
}
// third part: transpose
// third part: transposed
for(int j2=k2+actual_kc; j2<packet_cols; j2+=nr)
{
for(int k=k2; k<end_k; k++)
{
ei_pstore(&blockB[count+0*PacketSize], ei_pset1(alpha*rhs(j2+0,k)));
ei_pstore(&blockB[count+1*PacketSize], ei_pset1(alpha*rhs(j2+1,k)));
ei_pstore(&blockB[count+0*PacketSize], ei_pset1(alpha*ei_conj(rhs(j2+0,k))));
ei_pstore(&blockB[count+1*PacketSize], ei_pset1(alpha*ei_conj(rhs(j2+1,k))));
if (nr==4)
{
ei_pstore(&blockB[count+2*PacketSize], ei_pset1(alpha*rhs(j2+2,k)));
ei_pstore(&blockB[count+3*PacketSize], ei_pset1(alpha*rhs(j2+3,k)));
ei_pstore(&blockB[count+2*PacketSize], ei_pset1(alpha*ei_conj(rhs(j2+2,k))));
ei_pstore(&blockB[count+3*PacketSize], ei_pset1(alpha*ei_conj(rhs(j2+3,k))));
}
count += nr*PacketSize;
}
@ -161,7 +159,7 @@ struct ei_symm_pack_rhs
int half = std::min(end_k,j2);
for(int k=k2; k<half; k++)
{
ei_pstore(&blockB[count], ei_pset1(alpha*rhs(j2,k)));
ei_pstore(&blockB[count], ei_pset1(alpha*ei_conj(rhs(j2,k))));
count += PacketSize;
}
// normal
@ -198,8 +196,9 @@ struct ei_product_selfadjoint_matrix<Scalar,LhsStorageOrder,LhsSelfAdjoint,Conju
{
ei_product_selfadjoint_matrix<Scalar,
RhsStorageOrder==RowMajor ? ColMajor : RowMajor, RhsSelfAdjoint, ConjugateRhs,
LhsStorageOrder==RowMajor ? ColMajor : RowMajor, LhsSelfAdjoint, ConjugateLhs, ColMajor>
::run(rows, cols, rhs, rhsStride, lhs, lhsStride, res, resStride, alpha);
EIGEN_LOGICAL_XOR(LhsSelfAdjoint,LhsStorageOrder==RowMajor) ? ColMajor : RowMajor,
LhsSelfAdjoint, NumTraits<Scalar>::IsComplex && !ConjugateLhs, ColMajor>
::run(cols, rows, rhs, rhsStride, lhs, lhsStride, res, resStride, alpha);
}
};
@ -254,8 +253,8 @@ struct ei_product_selfadjoint_matrix<Scalar,LhsStorageOrder,true,ConjugateLhs, R
for(int i2=0; i2<k2; i2+=mc)
{
const int actual_mc = std::min(i2+mc,k2)-i2;
// transposed packed copy if Lower part
ei_gemm_pack_lhs<Scalar,Blocking::mr,LhsStorageOrder==RowMajor?ColMajor:RowMajor>()
// transposed packed copy
ei_gemm_pack_lhs<Scalar,Blocking::mr,LhsStorageOrder==RowMajor?ColMajor:RowMajor, true>()
(blockA, &lhs(k2, i2), lhsStride, actual_kc, actual_mc);
gebp_kernel(res, resStride, blockA, blockB, actual_mc, actual_kc, packet_cols, i2, cols);
@ -273,7 +272,7 @@ struct ei_product_selfadjoint_matrix<Scalar,LhsStorageOrder,true,ConjugateLhs, R
for(int i2=k2+kc; i2<size; i2+=mc)
{
const int actual_mc = std::min(i2+mc,size)-i2;
ei_gemm_pack_lhs<Scalar,Blocking::mr,LhsStorageOrder>()
ei_gemm_pack_lhs<Scalar,Blocking::mr,LhsStorageOrder,false>()
(blockA, &lhs(i2, k2), lhsStride, actual_kc, actual_mc);
gebp_kernel(res, resStride, blockA, blockB, actual_mc, actual_kc, packet_cols, i2, cols);

View File

@ -35,7 +35,7 @@ struct ei_gebp_kernel;
template<typename Scalar, int nr, int StorageOrder>
struct ei_gemm_pack_rhs;
template<typename Scalar, int mr, int StorageOrder>
template<typename Scalar, int mr, int StorageOrder, bool Conjugate = false>
struct ei_gemm_pack_lhs;
template<

View File

@ -277,5 +277,6 @@ _EIGEN_GENERIC_PUBLIC_INTERFACE(Derived, Eigen::MatrixBase<Derived>)
#define EIGEN_ENUM_MIN(a,b) (((int)a <= (int)b) ? (int)a : (int)b)
#define EIGEN_ENUM_MAX(a,b) (((int)a >= (int)b) ? (int)a : (int)b)
#define EIGEN_LOGICAL_XOR(a,b) (((a) || (b)) && !((a) && (b)))
#endif // EIGEN_MACROS_H

View File

@ -87,6 +87,53 @@ template<typename MatrixType> void product_selfadjoint(const MatrixType& m)
}
}
template<typename MatrixType> void symm(const MatrixType& m)
{
typedef typename MatrixType::Scalar Scalar;
typedef typename NumTraits<Scalar>::Real RealScalar;
typedef Matrix<Scalar, MatrixType::ColsAtCompileTime, Dynamic> Rhs1;
typedef Matrix<Scalar, Dynamic, MatrixType::RowsAtCompileTime> Rhs2;
typedef Matrix<Scalar, MatrixType::ColsAtCompileTime, Dynamic,RowMajor> Rhs3;
int rows = m.rows();
int cols = m.cols();
MatrixType m1 = MatrixType::Random(rows, cols),
m2 = MatrixType::Random(rows, cols);
m1 = (m1+m1.adjoint()).eval();
Rhs1 rhs1 = Rhs1::Random(cols, ei_random<int>(1,320)), rhs12, rhs13;
Rhs2 rhs2 = Rhs2::Random(ei_random<int>(1,320), rows), rhs22, rhs23;
Rhs3 rhs3 = Rhs3::Random(cols, ei_random<int>(1,320)), rhs32, rhs33;
// Scalar s1 = ei_random<Scalar>(),
// s2 = ei_random<Scalar>();
m2 = m1.template triangularView<LowerTriangular>();
VERIFY_IS_APPROX(rhs12 = m2.template selfadjointView<LowerTriangular>() * rhs1, rhs13 = m1 * rhs1);
m2 = m1.template triangularView<UpperTriangular>();
VERIFY_IS_APPROX(rhs12 = m2.template selfadjointView<UpperTriangular>() * rhs1, rhs13 = m1 * rhs1);
m2 = m1.template triangularView<LowerTriangular>();
VERIFY_IS_APPROX(rhs22 = m2.template selfadjointView<LowerTriangular>() * rhs2.adjoint(), rhs23 = m1 * rhs2.adjoint());
m2 = m1.template triangularView<UpperTriangular>();
VERIFY_IS_APPROX(rhs22 = m2.template selfadjointView<UpperTriangular>() * rhs2.adjoint(), rhs23 = m1 * rhs2.adjoint());
m2 = m1.template triangularView<UpperTriangular>();
VERIFY_IS_APPROX(rhs22 = m2.adjoint().template selfadjointView<LowerTriangular>() * rhs2.adjoint(),
rhs23 = m1.adjoint() * rhs2.adjoint());
// test row major = <...>
m2 = m1.template triangularView<LowerTriangular>();
VERIFY_IS_APPROX(rhs32 = m2.template selfadjointView<LowerTriangular>() * rhs3, rhs33 = m1 * rhs3);
m2 = m1.template triangularView<UpperTriangular>();
VERIFY_IS_APPROX(rhs32 = m2.adjoint().template selfadjointView<LowerTriangular>() * rhs3.conjugate(),
rhs33 = m1.adjoint() * rhs3.conjugate());
}
void test_product_selfadjoint()
{
for(int i = 0; i < g_repeat ; i++) {
@ -102,21 +149,10 @@ void test_product_selfadjoint()
for(int i = 0; i < g_repeat ; i++)
{
int size = ei_random<int>(10,1024);
int cols = ei_random<int>(10,320);
MatrixXf A = MatrixXf::Random(size,size);
MatrixXf B = MatrixXf::Random(size,cols);
MatrixXf C = MatrixXf::Random(size,cols);
MatrixXf R = MatrixXf::Random(size,cols);
A = (A+A.transpose()).eval();
R = C + (A * B).eval();
A.corner(TopRight,size-1,size-1).triangularView<UpperTriangular>().setZero();
ei_product_selfadjoint_matrix<float,ColMajor,LowerTriangular,false,false>
(size, A.data(), A.stride(), B.data(), B.stride(), false, B.cols(), C.data(), C.stride(), 1);
// std::cerr << A << "\n\n" << C << "\n\n" << R << "\n\n";
VERIFY_IS_APPROX(C,R);
int s;
s = ei_random<int>(10,320);
CALL_SUBTEST( symm(MatrixXf(s, s)) );
s = ei_random<int>(10,320);
CALL_SUBTEST( symm(MatrixXcd(s, s)) );
}
}