* GEMM enhancement: no need to pre-transpose the rhs

  => faster a * b.transpose() products (see the sketch after the file list)
  => this also fixes a bug in a so-far-untested situation
* SYMM is now ready for use => still have to write the high-level
  glue that converts natural expressions into a call to SYMM
This commit is contained in:
Gael Guennebaud 2009-07-22 18:04:16 +02:00
parent d6475ea390
commit e7f8e939e2
4 changed files with 418 additions and 188 deletions
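
What this buys at the user level, as a minimal sketch (not part of the commit; it assumes the Eigen 2 development API of the time, e.g. the USING_PART_OF_NAMESPACE_EIGEN macro):

#include <Eigen/Core>
USING_PART_OF_NAMESPACE_EIGEN

int main()
{
  MatrixXf a = MatrixXf::Random(512, 512);
  MatrixXf b = MatrixXf::Random(512, 512);
  // before this commit, b.transpose() was evaluated into a column-major
  // temporary before the GEMM kernel ran; now the row-major operand is
  // consumed in place by the new ei_gemm_pack_rhs<Scalar,nr,RowMajor>
  MatrixXf c = a * b.transpose();
  return 0;
}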

Eigen/src/Core/Product.h

@@ -73,79 +73,6 @@ struct ProductReturnType<Lhs,Rhs,CacheFriendlyProduct>
   typedef Product<LhsNested, RhsNested, CacheFriendlyProduct> Type;
 };
 
-/* Helper class to analyze the factors of a Product expression.
- * In particular it allows to pop out operator-, scalar multiples,
- * and conjugate */
-template<typename XprType> struct ei_blas_traits
-{
-  typedef typename ei_traits<XprType>::Scalar Scalar;
-  typedef XprType ActualXprType;
-  enum {
-    IsComplex = NumTraits<Scalar>::IsComplex,
-    NeedToConjugate = false,
-    ActualAccess = int(ei_traits<XprType>::Flags)&DirectAccessBit ? HasDirectAccess : NoDirectAccess
-  };
-  typedef typename ei_meta_if<int(ActualAccess)==HasDirectAccess,
-    const ActualXprType&,
-    typename ActualXprType::PlainMatrixType
-  >::ret DirectLinearAccessType;
-  static inline const ActualXprType& extract(const XprType& x) { return x; }
-  static inline Scalar extractScalarFactor(const XprType&) { return Scalar(1); }
-};
-
-// pop conjugate
-template<typename Scalar, typename NestedXpr> struct ei_blas_traits<CwiseUnaryOp<ei_scalar_conjugate_op<Scalar>, NestedXpr> >
-  : ei_blas_traits<NestedXpr>
-{
-  typedef ei_blas_traits<NestedXpr> Base;
-  typedef CwiseUnaryOp<ei_scalar_conjugate_op<Scalar>, NestedXpr> XprType;
-  typedef typename Base::ActualXprType ActualXprType;
-  enum {
-    IsComplex = NumTraits<Scalar>::IsComplex,
-    NeedToConjugate = IsComplex
-  };
-  static inline const ActualXprType& extract(const XprType& x) { return Base::extract(x._expression()); }
-  static inline Scalar extractScalarFactor(const XprType& x) { return ei_conj(Base::extractScalarFactor(x._expression())); }
-};
-
-// pop scalar multiple
-template<typename Scalar, typename NestedXpr> struct ei_blas_traits<CwiseUnaryOp<ei_scalar_multiple_op<Scalar>, NestedXpr> >
-  : ei_blas_traits<NestedXpr>
-{
-  typedef ei_blas_traits<NestedXpr> Base;
-  typedef CwiseUnaryOp<ei_scalar_multiple_op<Scalar>, NestedXpr> XprType;
-  typedef typename Base::ActualXprType ActualXprType;
-  static inline const ActualXprType& extract(const XprType& x) { return Base::extract(x._expression()); }
-  static inline Scalar extractScalarFactor(const XprType& x)
-  { return x._functor().m_other * Base::extractScalarFactor(x._expression()); }
-};
-
-// pop opposite
-template<typename Scalar, typename NestedXpr> struct ei_blas_traits<CwiseUnaryOp<ei_scalar_opposite_op<Scalar>, NestedXpr> >
-  : ei_blas_traits<NestedXpr>
-{
-  typedef ei_blas_traits<NestedXpr> Base;
-  typedef CwiseUnaryOp<ei_scalar_opposite_op<Scalar>, NestedXpr> XprType;
-  typedef typename Base::ActualXprType ActualXprType;
-  static inline const ActualXprType& extract(const XprType& x) { return Base::extract(x._expression()); }
-  static inline Scalar extractScalarFactor(const XprType& x)
-  { return - Base::extractScalarFactor(x._expression()); }
-};
-
-// pop NestByValue
-template<typename NestedXpr> struct ei_blas_traits<NestByValue<NestedXpr> >
-  : ei_blas_traits<NestedXpr>
-{
-  typedef typename NestedXpr::Scalar Scalar;
-  typedef ei_blas_traits<NestedXpr> Base;
-  typedef NestByValue<NestedXpr> XprType;
-  typedef typename Base::ActualXprType ActualXprType;
-  static inline const ActualXprType& extract(const XprType& x) { return Base::extract(static_cast<const NestedXpr&>(x)); }
-  static inline Scalar extractScalarFactor(const XprType& x)
-  { return Base::extractScalarFactor(static_cast<const NestedXpr&>(x)); }
-};
-
 /* Helper class to determine the type of the product, can be either:
  * - NormalProduct
  * - CacheFriendlyProduct
@@ -869,25 +796,6 @@ inline Derived& MatrixBase<Derived>::lazyAssign(const Product<Lhs,Rhs,CacheFrien
   return derived();
 }
 
-template<typename T> struct ei_product_copy_rhs
-{
-  typedef typename ei_meta_if<
-    (ei_traits<T>::Flags & RowMajorBit)
-    || (!(ei_traits<T>::Flags & DirectAccessBit)),
-    typename ei_plain_matrix_type_column_major<T>::type,
-    const T&
-  >::ret type;
-};
-
-template<typename T> struct ei_product_copy_lhs
-{
-  typedef typename ei_meta_if<
-    (!(int(ei_traits<T>::Flags) & DirectAccessBit)),
-    typename ei_plain_matrix_type<T>::type,
-    const T&
-  >::ret type;
-};
-
 template<typename Lhs, typename Rhs, int ProductMode>
 template<typename DestDerived>
 inline void Product<Lhs,Rhs,ProductMode>::_cacheFriendlyEvalAndAdd(DestDerived& res, Scalar alpha) const
@@ -895,26 +803,22 @@ inline void Product<Lhs,Rhs,ProductMode>::_cacheFriendlyEvalAndAdd(DestDerived&
   typedef ei_blas_traits<_LhsNested> LhsProductTraits;
   typedef ei_blas_traits<_RhsNested> RhsProductTraits;
 
-  typedef typename LhsProductTraits::ActualXprType ActualLhsType;
-  typedef typename RhsProductTraits::ActualXprType ActualRhsType;
-  const ActualLhsType& actualLhs = LhsProductTraits::extract(m_lhs);
-  const ActualRhsType& actualRhs = RhsProductTraits::extract(m_rhs);
+  typedef typename LhsProductTraits::DirectLinearAccessType ActualLhsType;
+  typedef typename RhsProductTraits::DirectLinearAccessType ActualRhsType;
+  typedef typename ei_cleantype<ActualLhsType>::type _ActualLhsType;
+  typedef typename ei_cleantype<ActualRhsType>::type _ActualRhsType;
+
+  const ActualLhsType lhs = LhsProductTraits::extract(m_lhs);
+  const ActualRhsType rhs = RhsProductTraits::extract(m_rhs);
 
   Scalar actualAlpha = alpha * LhsProductTraits::extractScalarFactor(m_lhs)
                              * RhsProductTraits::extractScalarFactor(m_rhs);
 
-  typedef typename ei_product_copy_lhs<ActualLhsType>::type LhsCopy;
-  typedef typename ei_unref<LhsCopy>::type _LhsCopy;
-  typedef typename ei_product_copy_rhs<ActualRhsType>::type RhsCopy;
-  typedef typename ei_unref<RhsCopy>::type _RhsCopy;
-  LhsCopy lhs(actualLhs);
-  RhsCopy rhs(actualRhs);
-
   ei_general_matrix_matrix_product<
     Scalar,
-    (_LhsCopy::Flags&RowMajorBit)?RowMajor:ColMajor, bool(LhsProductTraits::NeedToConjugate),
-    (_RhsCopy::Flags&RowMajorBit)?RowMajor:ColMajor, bool(RhsProductTraits::NeedToConjugate),
+    (_ActualLhsType::Flags&RowMajorBit)?RowMajor:ColMajor, bool(LhsProductTraits::NeedToConjugate),
+    (_ActualRhsType::Flags&RowMajorBit)?RowMajor:ColMajor, bool(RhsProductTraits::NeedToConjugate),
     (DestDerived::Flags&RowMajorBit)?RowMajor:ColMajor>
     ::run(
       rows(), cols(), lhs.cols(),

Eigen/src/Core/products/GeneralMatrixMatrix.h

@@ -89,16 +89,16 @@ static void run(int rows, int cols, int depth,
     // we have selected one row panel of rhs and one column panel of lhs
     // pack rhs's panel into a sequential chunk of memory
     // and expand each coeff to a constant packet for further reuse
-    ei_gemm_pack_rhs<Scalar, Blocking::PacketSize, Blocking::nr>()(blockB, &rhs(k2,0), rhsStride, alpha, actual_kc, packet_cols, cols);
+    ei_gemm_pack_rhs<Scalar, Blocking::nr, RhsStorageOrder>()(blockB, &rhs(k2,0), rhsStride, alpha, actual_kc, packet_cols, cols);
 
     // => GEPP_VAR1
     for(int i2=0; i2<rows; i2+=mc)
     {
       const int actual_mc = std::min(i2+mc,rows)-i2;
 
       ei_gemm_pack_lhs<Scalar, Blocking::mr, LhsStorageOrder>()(blockA, &lhs(i2,k2), lhsStride, actual_kc, actual_mc);
 
-      ei_gebp_kernel<Scalar, PacketType, Blocking::PacketSize, Blocking::mr, Blocking::nr, ei_conj_helper<ConjugateLhs,ConjugateRhs> >()
+      ei_gebp_kernel<Scalar, Blocking::mr, Blocking::nr, ei_conj_helper<ConjugateLhs,ConjugateRhs> >()
         (res, resStride, blockA, blockB, actual_mc, actual_kc, packet_cols, i2, cols);
     }
   }
@@ -110,11 +110,13 @@ static void run(int rows, int cols, int depth,
 };
 
 // optimized GEneral packed Block * packed Panel product kernel
-template<typename Scalar, typename PacketType, int PacketSize, int mr, int nr, typename Conj>
+template<typename Scalar, int mr, int nr, typename Conj>
 struct ei_gebp_kernel
 {
   void operator()(Scalar* res, int resStride, const Scalar* blockA, const Scalar* blockB, int actual_mc, int actual_kc, int packet_cols, int i2, int cols)
   {
+    typedef typename ei_packet_traits<Scalar>::type PacketType;
+    enum { PacketSize = ei_packet_traits<Scalar>::size };
     Conj cj;
     const int peeled_mc = (actual_mc/mr)*mr;
     // loops on each cache friendly block of the result/rhs
@@ -276,7 +278,7 @@ struct ei_gebp_kernel
         if(nr==4) res[(j2+3)*resStride + i2 + i] += C3;
       }
     }
-    // process remaining rhs/res columns one at a time
+    // => do the same but with nr==1
     for(int j2=packet_cols; j2<cols; j2++)
@@ -353,9 +355,11 @@ struct ei_gemm_pack_lhs
 };
 
 // copy a complete panel of the rhs while expanding each coefficient into a packet form
-template<typename Scalar, int PacketSize, int nr>
-struct ei_gemm_pack_rhs
+// this version is optimized for column major matrices
+template<typename Scalar, int nr>
+struct ei_gemm_pack_rhs<Scalar, nr, ColMajor>
 {
+  enum { PacketSize = ei_packet_traits<Scalar>::size };
   void operator()(Scalar* blockB, const Scalar* rhs, int rhsStride, Scalar alpha, int actual_kc, int packet_cols, int cols)
   {
     bool hasAlpha = alpha != Scalar(1);
@@ -419,6 +423,61 @@ struct ei_gemm_pack_rhs
   }
 };
 
+// this version is optimized for row major matrices
+template<typename Scalar, int nr>
+struct ei_gemm_pack_rhs<Scalar, nr, RowMajor>
+{
+  enum { PacketSize = ei_packet_traits<Scalar>::size };
+  void operator()(Scalar* blockB, const Scalar* rhs, int rhsStride, Scalar alpha, int actual_kc, int packet_cols, int cols)
+  {
+    bool hasAlpha = alpha != Scalar(1);
+    int count = 0;
+    for(int j2=0; j2<packet_cols; j2+=nr)
+    {
+      if (hasAlpha)
+      {
+        for(int k=0; k<actual_kc; k++)
+        {
+          const Scalar* b0 = &rhs[k*rhsStride + j2];
+          ei_pstore(&blockB[count+0*PacketSize], ei_pset1(alpha*b0[0]));
+          ei_pstore(&blockB[count+1*PacketSize], ei_pset1(alpha*b0[1]));
+          if (nr==4)
+          {
+            ei_pstore(&blockB[count+2*PacketSize], ei_pset1(alpha*b0[2]));
+            ei_pstore(&blockB[count+3*PacketSize], ei_pset1(alpha*b0[3]));
+          }
+          count += nr*PacketSize;
+        }
+      }
+      else
+      {
+        for(int k=0; k<actual_kc; k++)
+        {
+          const Scalar* b0 = &rhs[k*rhsStride + j2];
+          ei_pstore(&blockB[count+0*PacketSize], ei_pset1(b0[0]));
+          ei_pstore(&blockB[count+1*PacketSize], ei_pset1(b0[1]));
+          if (nr==4)
+          {
+            ei_pstore(&blockB[count+2*PacketSize], ei_pset1(b0[2]));
+            ei_pstore(&blockB[count+3*PacketSize], ei_pset1(b0[3]));
+          }
+          count += nr*PacketSize;
+        }
+      }
+    }
+    // copy the remaining columns one at a time (nr==1)
+    for(int j2=packet_cols; j2<cols; ++j2)
+    {
+      const Scalar* b0 = &rhs[j2];
+      for(int k=0; k<actual_kc; k++)
+      {
+        ei_pstore(&blockB[count], ei_pset1(alpha*b0[k*rhsStride]));
+        count += PacketSize;
+      }
+    }
+  }
+};
+
 #endif // EIGEN_EXTERN_INSTANTIATIONS
 
 #endif // EIGEN_GENERAL_MATRIX_MATRIX_H
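
The pre-transposition becomes unnecessary because of the new RowMajor packing specialization above. A hedged scalar model of what it computes (hypothetical helper name, plain loops standing in for ei_pset1/ei_pstore):

// Hypothetical model; PacketSize, nr, blockB mirror the names in the patch.
template<typename Scalar, int nr, int PacketSize>
void pack_rhs_rowmajor_model(Scalar* blockB, const Scalar* rhs, int rhsStride,
                             Scalar alpha, int depth, int packet_cols)
{
  int count = 0;
  for(int j2=0; j2<packet_cols; j2+=nr)        // nr-wide panels of rhs columns
    for(int k=0; k<depth; k++)
    {
      // for a row-major rhs, the nr coefficients of row k are contiguous,
      // which is what makes the pre-transposition unnecessary
      const Scalar* b0 = &rhs[k*rhsStride + j2];
      for(int w=0; w<nr; w++)
        for(int p=0; p<PacketSize; p++)        // replication models ei_pset1
          blockB[count++] = alpha * b0[w];
    }
}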

Eigen/src/Core/products/SelfadjointMatrixMatrix.h

@@ -31,11 +31,12 @@ struct ei_symm_pack_lhs
 {
   void operator()(Scalar* blockA, const Scalar* _lhs, int lhsStride, int actual_kc, int actual_mc)
   {
-    ei_const_blas_data_mapper<Scalar, StorageOrder> lhs(_lhs,lhsStride);
+    ei_const_blas_data_mapper<Scalar,StorageOrder> lhs(_lhs,lhsStride);
     int count = 0;
     const int peeled_mc = (actual_mc/mr)*mr;
     for(int i=0; i<peeled_mc; i+=mr)
     {
+      // normal copy
       for(int k=0; k<i; k++)
         for(int w=0; w<mr; w++)
           blockA[count++] = lhs(i+w,k);
@@ -55,6 +56,7 @@ struct ei_symm_pack_lhs
         for(int w=0; w<mr; w++)
           blockA[count++] = lhs(k, i+w);
     }
+
     // do the same with mr==1
     for(int i=peeled_mc; i<actual_mc; i++)
     {
@@ -67,86 +69,278 @@ struct ei_symm_pack_lhs
   }
 };
 
+template<typename Scalar, int nr, int StorageOrder>
+struct ei_symm_pack_rhs
+{
+  enum { PacketSize = ei_packet_traits<Scalar>::size };
+  void operator()(Scalar* blockB, const Scalar* _rhs, int rhsStride, Scalar alpha, int actual_kc, int packet_cols, int cols, int k2)
+  {
+    int end_k = k2 + actual_kc;
+    int count = 0;
+    ei_const_blas_data_mapper<Scalar,StorageOrder> rhs(_rhs,rhsStride);
+
+    // first part: standard case
+    for(int j2=0; j2<k2; j2+=nr)
+    {
+      for(int k=k2; k<end_k; k++)
+      {
+        ei_pstore(&blockB[count+0*PacketSize], ei_pset1(alpha*rhs(k,j2+0)));
+        ei_pstore(&blockB[count+1*PacketSize], ei_pset1(alpha*rhs(k,j2+1)));
+        if (nr==4)
+        {
+          ei_pstore(&blockB[count+2*PacketSize], ei_pset1(alpha*rhs(k,j2+2)));
+          ei_pstore(&blockB[count+3*PacketSize], ei_pset1(alpha*rhs(k,j2+3)));
+        }
+        count += nr*PacketSize;
+      }
+    }
+
+    // second part: diagonal block
+    for(int j2=k2; j2<std::min(k2+actual_kc,packet_cols); j2+=nr)
+    {
+      // again we can split vertically into three different parts (transpose, symmetric, normal)
+      // transpose
+      for(int k=k2; k<j2; k++)
+      {
+        ei_pstore(&blockB[count+0*PacketSize], ei_pset1(alpha*rhs(j2+0,k)));
+        ei_pstore(&blockB[count+1*PacketSize], ei_pset1(alpha*rhs(j2+1,k)));
+        if (nr==4)
+        {
+          ei_pstore(&blockB[count+2*PacketSize], ei_pset1(alpha*rhs(j2+2,k)));
+          ei_pstore(&blockB[count+3*PacketSize], ei_pset1(alpha*rhs(j2+3,k)));
+        }
+        count += nr*PacketSize;
+      }
+      // symmetric
+      int h = 0;
+      for(int k=j2; k<j2+nr; k++)
+      {
+        // normal
+        for (int w=0 ; w<h; ++w)
+          ei_pstore(&blockB[count+w*PacketSize], ei_pset1(alpha*rhs(k,j2+w)));
+        // transpose
+        for (int w=h ; w<nr; ++w)
+          ei_pstore(&blockB[count+w*PacketSize], ei_pset1(alpha*rhs(j2+w,k)));
+        count += nr*PacketSize;
+        ++h;
+      }
+      // normal
+      for(int k=j2+nr; k<end_k; k++)
+      {
+        ei_pstore(&blockB[count+0*PacketSize], ei_pset1(alpha*rhs(k,j2+0)));
+        ei_pstore(&blockB[count+1*PacketSize], ei_pset1(alpha*rhs(k,j2+1)));
+        if (nr==4)
+        {
+          ei_pstore(&blockB[count+2*PacketSize], ei_pset1(alpha*rhs(k,j2+2)));
+          ei_pstore(&blockB[count+3*PacketSize], ei_pset1(alpha*rhs(k,j2+3)));
+        }
+        count += nr*PacketSize;
+      }
+    }
+
+    // third part: transpose
+    for(int j2=k2+actual_kc; j2<packet_cols; j2+=nr)
+    {
+      for(int k=k2; k<end_k; k++)
+      {
+        ei_pstore(&blockB[count+0*PacketSize], ei_pset1(alpha*rhs(j2+0,k)));
+        ei_pstore(&blockB[count+1*PacketSize], ei_pset1(alpha*rhs(j2+1,k)));
+        if (nr==4)
+        {
+          ei_pstore(&blockB[count+2*PacketSize], ei_pset1(alpha*rhs(j2+2,k)));
+          ei_pstore(&blockB[count+3*PacketSize], ei_pset1(alpha*rhs(j2+3,k)));
+        }
+        count += nr*PacketSize;
+      }
+    }
+
+    // copy the remaining columns one at a time (=> the same with nr==1)
+    for(int j2=packet_cols; j2<cols; ++j2)
+    {
+      // transpose
+      int half = std::min(end_k,j2);
+      for(int k=k2; k<half; k++)
+      {
+        ei_pstore(&blockB[count], ei_pset1(alpha*rhs(j2,k)));
+        count += PacketSize;
+      }
+      // normal
+      for(int k=half; k<k2+actual_kc; k++)
+      {
+        ei_pstore(&blockB[count], ei_pset1(alpha*rhs(k,j2)));
+        count += PacketSize;
+      }
+    }
+  }
+};
+
 /* Optimized selfadjoint matrix * matrix (_SYMM) product built on top of
  * the general matrix matrix product.
  */
-template<typename Scalar, int StorageOrder, int UpLo, bool ConjugateLhs, bool ConjugateRhs>
-static EIGEN_DONT_INLINE void ei_product_selfadjoint_matrix(
-  int size,
-  const Scalar* _lhs, int lhsStride,
-  const Scalar* _rhs, int rhsStride, bool rhsRowMajor, int cols,
-  Scalar* res, int resStride,
-  Scalar alpha)
-{
-  typedef typename ei_packet_traits<Scalar>::type Packet;
-  ei_const_blas_data_mapper<Scalar, StorageOrder> lhs(_lhs,lhsStride);
-  ei_const_blas_data_mapper<Scalar, ColMajor> rhs(_rhs,rhsStride);
-
-  if (ConjugateRhs)
-    alpha = ei_conj(alpha);
-
-  typedef typename ei_packet_traits<Scalar>::type PacketType;
-  const bool lhsRowMajor = (StorageOrder==RowMajor);
-
-  typedef ei_product_blocking_traits<Scalar> Blocking;
-
-  int kc = std::min<int>(Blocking::Max_kc,size);  // cache block size along the K direction
-  int mc = std::min<int>(Blocking::Max_mc,size);  // cache block size along the M direction
-
-  Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc);
-  Scalar* blockB = ei_aligned_stack_new(Scalar, kc*cols*Blocking::PacketSize);
-
-  // number of columns which can be processed by packet of nr columns
-  int packet_cols = (cols/Blocking::nr)*Blocking::nr;
-
-  ei_gebp_kernel<Scalar, PacketType, Blocking::PacketSize,
-                 Blocking::mr, Blocking::nr, ei_conj_helper<ConjugateLhs,ConjugateRhs> > gebp_kernel;
-
-  for(int k2=0; k2<size; k2+=kc)
-  {
-    const int actual_kc = std::min(k2+kc,size)-k2;
-
-    // we have selected one row panel of rhs and one column panel of lhs
-    // pack rhs's panel into a sequential chunk of memory
-    // and expand each coeff to a constant packet for further reuse
-    ei_gemm_pack_rhs<Scalar,Blocking::PacketSize,Blocking::nr>()
-      (blockB, &rhs(k2,0), rhsStride, alpha, actual_kc, packet_cols, cols);
-
-    // the selected lhs's panel has to be split in three different parts:
-    //  1 - the transposed panel above the diagonal block => transposed packed copy
-    //  2 - the diagonal block => special packed copy
-    //  3 - the panel below the diagonal block => generic packed copy
-    for(int i2=0; i2<k2; i2+=mc)
-    {
-      const int actual_mc = std::min(i2+mc,k2)-i2;
-      // transposed packed copy
-      ei_gemm_pack_lhs<Scalar,Blocking::mr,StorageOrder==RowMajor?ColMajor:RowMajor>()
-        (blockA, &lhs(k2,i2), lhsStride, actual_kc, actual_mc);
-      gebp_kernel(res, resStride, blockA, blockB, actual_mc, actual_kc, packet_cols, i2, cols);
-    }
-    // the block diagonal
-    {
-      const int actual_mc = std::min(k2+kc,size)-k2;
-      // symmetric packed copy
-      ei_symm_pack_lhs<Scalar,Blocking::mr,StorageOrder>()
-        (blockA, &lhs(k2,k2), lhsStride, actual_kc, actual_mc);
-      gebp_kernel(res, resStride, blockA, blockB, actual_mc, actual_kc, packet_cols, k2, cols);
-    }
-    for(int i2=k2+kc; i2<size; i2+=mc)
-    {
-      const int actual_mc = std::min(i2+mc,size)-i2;
-      ei_gemm_pack_lhs<Scalar,Blocking::mr,StorageOrder>()
-        (blockA, &lhs(i2,k2), lhsStride, actual_kc, actual_mc);
-      gebp_kernel(res, resStride, blockA, blockB, actual_mc, actual_kc, packet_cols, i2, cols);
-    }
-  }
-
-  ei_aligned_stack_delete(Scalar, blockA, kc*mc);
-  ei_aligned_stack_delete(Scalar, blockB, kc*cols*Blocking::PacketSize);
-}
+template <typename Scalar,
+          int LhsStorageOrder, bool LhsSelfAdjoint, bool ConjugateLhs,
+          int RhsStorageOrder, bool RhsSelfAdjoint, bool ConjugateRhs,
+          int ResStorageOrder>
+struct ei_product_selfadjoint_matrix;
+
+template <typename Scalar,
+          int LhsStorageOrder, bool LhsSelfAdjoint, bool ConjugateLhs,
+          int RhsStorageOrder, bool RhsSelfAdjoint, bool ConjugateRhs>
+struct ei_product_selfadjoint_matrix<Scalar,LhsStorageOrder,LhsSelfAdjoint,ConjugateLhs, RhsStorageOrder,RhsSelfAdjoint,ConjugateRhs,RowMajor>
+{
+  static EIGEN_STRONG_INLINE void run(
+    int rows, int cols,
+    const Scalar* lhs, int lhsStride,
+    const Scalar* rhs, int rhsStride,
+    Scalar* res, int resStride,
+    Scalar alpha)
+  {
+    ei_product_selfadjoint_matrix<Scalar,
+      RhsStorageOrder==RowMajor ? ColMajor : RowMajor, RhsSelfAdjoint, ConjugateRhs,
+      LhsStorageOrder==RowMajor ? ColMajor : RowMajor, LhsSelfAdjoint, ConjugateLhs, ColMajor>
+      ::run(rows, cols, rhs, rhsStride, lhs, lhsStride, res, resStride, alpha);
+  }
+};
+
+template <typename Scalar,
+          int LhsStorageOrder, bool ConjugateLhs,
+          int RhsStorageOrder, bool ConjugateRhs>
+struct ei_product_selfadjoint_matrix<Scalar,LhsStorageOrder,true,ConjugateLhs, RhsStorageOrder,false,ConjugateRhs,ColMajor>
+{
+  static EIGEN_DONT_INLINE void run(
+    int rows, int cols,
+    const Scalar* _lhs, int lhsStride,
+    const Scalar* _rhs, int rhsStride,
+    Scalar* res, int resStride,
+    Scalar alpha)
+  {
+    int size = rows;
+
+    ei_const_blas_data_mapper<Scalar, LhsStorageOrder> lhs(_lhs,lhsStride);
+    ei_const_blas_data_mapper<Scalar, RhsStorageOrder> rhs(_rhs,rhsStride);
+
+    if (ConjugateRhs)
+      alpha = ei_conj(alpha);
+
+    typedef ei_product_blocking_traits<Scalar> Blocking;
+
+    int kc = std::min<int>(Blocking::Max_kc,size);  // cache block size along the K direction
+    int mc = std::min<int>(Blocking::Max_mc,rows);  // cache block size along the M direction
+
+    Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc);
+    Scalar* blockB = ei_aligned_stack_new(Scalar, kc*cols*Blocking::PacketSize);
+
+    // number of columns which can be processed by packet of nr columns
+    int packet_cols = (cols/Blocking::nr)*Blocking::nr;
+
+    ei_gebp_kernel<Scalar, Blocking::mr, Blocking::nr, ei_conj_helper<ConjugateLhs,ConjugateRhs> > gebp_kernel;
+
+    for(int k2=0; k2<size; k2+=kc)
+    {
+      const int actual_kc = std::min(k2+kc,size)-k2;
+
+      // we have selected one row panel of rhs and one column panel of lhs
+      // pack rhs's panel into a sequential chunk of memory
+      // and expand each coeff to a constant packet for further reuse
+      ei_gemm_pack_rhs<Scalar,Blocking::nr,RhsStorageOrder>()
+        (blockB, &rhs(k2,0), rhsStride, alpha, actual_kc, packet_cols, cols);
+
+      // the selected lhs's panel has to be split in three different parts:
+      //  1 - the transposed panel above the diagonal block => transposed packed copy
+      //  2 - the diagonal block => special packed copy
+      //  3 - the panel below the diagonal block => generic packed copy
+      for(int i2=0; i2<k2; i2+=mc)
+      {
+        const int actual_mc = std::min(i2+mc,k2)-i2;
+        // transposed packed copy if Lower part
+        ei_gemm_pack_lhs<Scalar,Blocking::mr,LhsStorageOrder==RowMajor?ColMajor:RowMajor>()
+          (blockA, &lhs(k2, i2), lhsStride, actual_kc, actual_mc);
+        gebp_kernel(res, resStride, blockA, blockB, actual_mc, actual_kc, packet_cols, i2, cols);
+      }
+      // the block diagonal
+      {
+        const int actual_mc = std::min(k2+kc,size)-k2;
+        // symmetric packed copy
+        ei_symm_pack_lhs<Scalar,Blocking::mr,LhsStorageOrder>()
+          (blockA, &lhs(k2,k2), lhsStride, actual_kc, actual_mc);
+        gebp_kernel(res, resStride, blockA, blockB, actual_mc, actual_kc, packet_cols, k2, cols);
+      }
+      for(int i2=k2+kc; i2<size; i2+=mc)
+      {
+        const int actual_mc = std::min(i2+mc,size)-i2;
+        ei_gemm_pack_lhs<Scalar,Blocking::mr,LhsStorageOrder>()
+          (blockA, &lhs(i2, k2), lhsStride, actual_kc, actual_mc);
+        gebp_kernel(res, resStride, blockA, blockB, actual_mc, actual_kc, packet_cols, i2, cols);
+      }
+    }
+
+    ei_aligned_stack_delete(Scalar, blockA, kc*mc);
+    ei_aligned_stack_delete(Scalar, blockB, kc*cols*Blocking::PacketSize);
+  }
+};
+
+// matrix * selfadjoint product
+template <typename Scalar,
+          int LhsStorageOrder, bool ConjugateLhs,
+          int RhsStorageOrder, bool ConjugateRhs>
+struct ei_product_selfadjoint_matrix<Scalar,LhsStorageOrder,false,ConjugateLhs, RhsStorageOrder,true,ConjugateRhs,ColMajor>
+{
+  static EIGEN_DONT_INLINE void run(
+    int rows, int cols,
+    const Scalar* _lhs, int lhsStride,
+    const Scalar* _rhs, int rhsStride,
+    Scalar* res, int resStride,
+    Scalar alpha)
+  {
+    int size = cols;
+
+    ei_const_blas_data_mapper<Scalar, LhsStorageOrder> lhs(_lhs,lhsStride);
+    ei_const_blas_data_mapper<Scalar, RhsStorageOrder> rhs(_rhs,rhsStride);
+
+    if (ConjugateRhs)
+      alpha = ei_conj(alpha);
+
+    typedef ei_product_blocking_traits<Scalar> Blocking;
+
+    int kc = std::min<int>(Blocking::Max_kc,size);  // cache block size along the K direction
+    int mc = std::min<int>(Blocking::Max_mc,rows);  // cache block size along the M direction
+
+    Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc);
+    Scalar* blockB = ei_aligned_stack_new(Scalar, kc*cols*Blocking::PacketSize);
+
+    // number of columns which can be processed by packet of nr columns
+    int packet_cols = (cols/Blocking::nr)*Blocking::nr;
+
+    ei_gebp_kernel<Scalar, Blocking::mr, Blocking::nr, ei_conj_helper<ConjugateLhs,ConjugateRhs> > gebp_kernel;
+
+    for(int k2=0; k2<size; k2+=kc)
+    {
+      const int actual_kc = std::min(k2+kc,size)-k2;
+
+      ei_symm_pack_rhs<Scalar,Blocking::nr,RhsStorageOrder>()
+        (blockB, _rhs, rhsStride, alpha, actual_kc, packet_cols, cols, k2);
+
+      // => GEPP
+      for(int i2=0; i2<rows; i2+=mc)
+      {
+        const int actual_mc = std::min(i2+mc,rows)-i2;
+        ei_gemm_pack_lhs<Scalar,Blocking::mr,LhsStorageOrder>()
+          (blockA, &lhs(i2, k2), lhsStride, actual_kc, actual_mc);
+        gebp_kernel(res, resStride, blockA, blockB, actual_mc, actual_kc, packet_cols, i2, cols);
+      }
+    }
+
+    ei_aligned_stack_delete(Scalar, blockA, kc*mc);
+    ei_aligned_stack_delete(Scalar, blockB, kc*cols*Blocking::PacketSize);
+  }
+};
 
 #endif // EIGEN_SELFADJOINT_MATRIX_MATRIX_H
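
Since the high-level glue announced in the commit message does not exist yet, the only way to exercise this SYMM path for now is to call the low-level struct directly. A hedged sketch, not part of the patch, with sizes, strides, and template flags chosen purely for illustration:

#include <Eigen/Core>
USING_PART_OF_NAMESPACE_EIGEN

int main()
{
  int n = 256;
  MatrixXf lhs = MatrixXf::Random(n, n);
  lhs = (lhs + lhs.transpose()).eval();    // make the lhs actually symmetric
  MatrixXf rhs = MatrixXf::Random(n, n);
  MatrixXf res = MatrixXf::Zero(n, n);

  // template arguments: Scalar, LhsStorageOrder, LhsSelfAdjoint, ConjugateLhs,
  //                     RhsStorageOrder, RhsSelfAdjoint, ConjugateRhs, ResStorageOrder
  ei_product_selfadjoint_matrix<float, ColMajor, true, false,
                                ColMajor, false, false, ColMajor>
    ::run(n, n, lhs.data(), n, rhs.data(), n, res.data(), n, 1.0f);
  return 0;
}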

Eigen/src/Core/util/BlasUtil.h

@@ -29,10 +29,10 @@
 // implement and control fast level 2 and level 3 BLAS-like routines.
 
 // forward declarations
-template<typename Scalar, typename Packet, int PacketSize, int mr, int nr, typename Conj>
+template<typename Scalar, int mr, int nr, typename Conj>
 struct ei_gebp_kernel;
 
-template<typename Scalar, int PacketSize, int nr>
+template<typename Scalar, int nr, int StorageOrder>
 struct ei_gemm_pack_rhs;
 
 template<typename Scalar, int mr, int StorageOrder>
@@ -154,4 +154,77 @@ struct ei_product_blocking_traits
   };
 };
 
+/* Helper class to analyze the factors of a Product expression.
+ * In particular it allows to pop out operator-, scalar multiples,
+ * and conjugate */
+template<typename XprType> struct ei_blas_traits
+{
+  typedef typename ei_traits<XprType>::Scalar Scalar;
+  typedef XprType ActualXprType;
+  enum {
+    IsComplex = NumTraits<Scalar>::IsComplex,
+    NeedToConjugate = false,
+    ActualAccess = int(ei_traits<XprType>::Flags)&DirectAccessBit ? HasDirectAccess : NoDirectAccess
+  };
+  typedef typename ei_meta_if<int(ActualAccess)==HasDirectAccess,
+    const ActualXprType&,
+    typename ActualXprType::PlainMatrixType
+  >::ret DirectLinearAccessType;
+  static inline const ActualXprType& extract(const XprType& x) { return x; }
+  static inline Scalar extractScalarFactor(const XprType&) { return Scalar(1); }
+};
+
+// pop conjugate
+template<typename Scalar, typename NestedXpr> struct ei_blas_traits<CwiseUnaryOp<ei_scalar_conjugate_op<Scalar>, NestedXpr> >
+  : ei_blas_traits<NestedXpr>
+{
+  typedef ei_blas_traits<NestedXpr> Base;
+  typedef CwiseUnaryOp<ei_scalar_conjugate_op<Scalar>, NestedXpr> XprType;
+  typedef typename Base::ActualXprType ActualXprType;
+  enum {
+    IsComplex = NumTraits<Scalar>::IsComplex,
+    NeedToConjugate = IsComplex
+  };
+  static inline const ActualXprType& extract(const XprType& x) { return Base::extract(x._expression()); }
+  static inline Scalar extractScalarFactor(const XprType& x) { return ei_conj(Base::extractScalarFactor(x._expression())); }
+};
+
+// pop scalar multiple
+template<typename Scalar, typename NestedXpr> struct ei_blas_traits<CwiseUnaryOp<ei_scalar_multiple_op<Scalar>, NestedXpr> >
+  : ei_blas_traits<NestedXpr>
+{
+  typedef ei_blas_traits<NestedXpr> Base;
+  typedef CwiseUnaryOp<ei_scalar_multiple_op<Scalar>, NestedXpr> XprType;
+  typedef typename Base::ActualXprType ActualXprType;
+  static inline const ActualXprType& extract(const XprType& x) { return Base::extract(x._expression()); }
+  static inline Scalar extractScalarFactor(const XprType& x)
+  { return x._functor().m_other * Base::extractScalarFactor(x._expression()); }
+};
+
+// pop opposite
+template<typename Scalar, typename NestedXpr> struct ei_blas_traits<CwiseUnaryOp<ei_scalar_opposite_op<Scalar>, NestedXpr> >
+  : ei_blas_traits<NestedXpr>
+{
+  typedef ei_blas_traits<NestedXpr> Base;
+  typedef CwiseUnaryOp<ei_scalar_opposite_op<Scalar>, NestedXpr> XprType;
+  typedef typename Base::ActualXprType ActualXprType;
+  static inline const ActualXprType& extract(const XprType& x) { return Base::extract(x._expression()); }
+  static inline Scalar extractScalarFactor(const XprType& x)
+  { return - Base::extractScalarFactor(x._expression()); }
+};
+
+// pop NestByValue
+template<typename NestedXpr> struct ei_blas_traits<NestByValue<NestedXpr> >
+  : ei_blas_traits<NestedXpr>
+{
+  typedef typename NestedXpr::Scalar Scalar;
+  typedef ei_blas_traits<NestedXpr> Base;
+  typedef NestByValue<NestedXpr> XprType;
+  typedef typename Base::ActualXprType ActualXprType;
+  static inline const ActualXprType& extract(const XprType& x) { return Base::extract(static_cast<const NestedXpr&>(x)); }
+  static inline Scalar extractScalarFactor(const XprType& x)
+  { return Base::extractScalarFactor(static_cast<const NestedXpr&>(x)); }
+};
+
 #endif // EIGEN_BLASUTIL_H
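
For context, a hedged illustration (not from the patch) of the kind of decorated product expression ei_blas_traits is meant to decompose once the high-level glue uses it:

#include <Eigen/Core>
USING_PART_OF_NAMESPACE_EIGEN

int main()
{
  MatrixXcf a = MatrixXcf::Random(64, 64);
  MatrixXcf b = MatrixXcf::Random(64, 64);
  // ei_blas_traits peels these decorations off before the kernel runs:
  // the scalar multiple and the negation are folded into alpha, the
  // conjugation into the kernel's ConjugateLhs flag, and the transposed
  // rhs is now consumed in place thanks to the new row-major packing.
  MatrixXcf c = -(std::complex<float>(2,0) * a.conjugate()) * b.transpose();
  return 0;
}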