ok now all the complex mat-mat and mat-vec products involving conjugate,

adjoint, -, and scalar multiple seem to be well handled. Only the simpler
case remains: C = alpha*(A*B) ... for the next commit
This commit is contained in:
Gael Guennebaud 2009-07-08 18:24:37 +02:00
parent 13b2dafb50
commit 96e7d9f896
5 changed files with 403 additions and 174 deletions

View File

@ -73,24 +73,9 @@ struct ProductReturnType<Lhs,Rhs,CacheFriendlyProduct>
typedef Product<LhsNested, RhsNested, CacheFriendlyProduct> Type; typedef Product<LhsNested, RhsNested, CacheFriendlyProduct> Type;
}; };
/* Helper class to determine the type of the product, can be either: /* Helper class to analyze the factors of a Product expression.
* - NormalProduct * In particular it allows to pop out operator-, scalar multiples,
* - CacheFriendlyProduct * and conjugate */
*/
template<typename Lhs, typename Rhs> struct ei_product_mode
{
enum{
value = Lhs::MaxColsAtCompileTime == Dynamic
&& ( Lhs::MaxRowsAtCompileTime == Dynamic
|| Rhs::MaxColsAtCompileTime == Dynamic )
&& (!(Rhs::IsVectorAtCompileTime && (Lhs::Flags&RowMajorBit) && (!(Lhs::Flags&DirectAccessBit))))
&& (!(Lhs::IsVectorAtCompileTime && (!(Rhs::Flags&RowMajorBit)) && (!(Rhs::Flags&DirectAccessBit))))
&& (ei_is_same_type<typename Lhs::Scalar, typename Rhs::Scalar>::ret)
? CacheFriendlyProduct
: NormalProduct };
};
template<typename XprType> struct ei_product_factor_traits template<typename XprType> struct ei_product_factor_traits
{ {
typedef typename ei_traits<XprType>::Scalar Scalar; typedef typename ei_traits<XprType>::Scalar Scalar;
@ -98,11 +83,10 @@ template<typename XprType> struct ei_product_factor_traits
enum { enum {
IsComplex = NumTraits<Scalar>::IsComplex, IsComplex = NumTraits<Scalar>::IsComplex,
NeedToConjugate = false, NeedToConjugate = false,
HasScalarMultiple = false, ActualAccess = int(ei_traits<XprType>::Flags)&DirectAccessBit ? HasDirectAccess : NoDirectAccess
Access = int(ei_traits<XprType>::Flags)&DirectAccessBit ? HasDirectAccess : NoDirectAccess
}; };
static inline const ActualXprType& extract(const XprType& x) { return x; } static inline const ActualXprType& extract(const XprType& x) { return x; }
static inline Scalar extractSalarFactor(const XprType&) { return Scalar(1); } static inline Scalar extractScalarFactor(const XprType&) { return Scalar(1); }
}; };
// pop conjugate // pop conjugate
@ -117,8 +101,8 @@ template<typename Scalar, typename NestedXpr> struct ei_product_factor_traits<Cw
IsComplex = NumTraits<Scalar>::IsComplex, IsComplex = NumTraits<Scalar>::IsComplex,
NeedToConjugate = IsComplex NeedToConjugate = IsComplex
}; };
static inline const ActualXprType& extract(const XprType& x) { return x._expression(); } static inline const ActualXprType& extract(const XprType& x) { return Base::extract(x._expression()); }
static inline Scalar extractSalarFactor(const XprType& x) { return Base::extractSalarFactor(x._expression()); } static inline Scalar extractScalarFactor(const XprType& x) { return ei_conj(Base::extractScalarFactor(x._expression())); }
}; };
// pop scalar multiple // pop scalar multiple
@ -128,11 +112,41 @@ template<typename Scalar, typename NestedXpr> struct ei_product_factor_traits<Cw
typedef ei_product_factor_traits<NestedXpr> Base; typedef ei_product_factor_traits<NestedXpr> Base;
typedef CwiseUnaryOp<ei_scalar_multiple_op<Scalar>, NestedXpr> XprType; typedef CwiseUnaryOp<ei_scalar_multiple_op<Scalar>, NestedXpr> XprType;
typedef typename Base::ActualXprType ActualXprType; typedef typename Base::ActualXprType ActualXprType;
enum { static inline const ActualXprType& extract(const XprType& x) { return Base::extract(x._expression()); }
HasScalarMultiple = true static inline Scalar extractScalarFactor(const XprType& x)
}; { return x._functor().m_other * Base::extractScalarFactor(x._expression()); }
static inline const ActualXprType& extract(const XprType& x) { return x._expression(); } };
static inline Scalar extractSalarFactor(const XprType& x) { return x._functor().m_other; }
// pop opposite
// Specialization of ei_product_factor_traits for a negated factor, i.e. an expression of type
// CwiseUnaryOp<ei_scalar_opposite_op<Scalar>, NestedXpr>. The unary minus is removed from the
// factor and folded into the extracted scalar factor instead. Everything not redefined below
// is inherited from the traits of NestedXpr.
template<typename Scalar, typename NestedXpr> struct ei_product_factor_traits<CwiseUnaryOp<ei_scalar_opposite_op<Scalar>, NestedXpr> >
: ei_product_factor_traits<NestedXpr>
{
typedef ei_product_factor_traits<NestedXpr> Base;
typedef CwiseUnaryOp<ei_scalar_opposite_op<Scalar>, NestedXpr> XprType;
typedef typename Base::ActualXprType ActualXprType;
// Skips the negation node: forwards the nested expression to the base traits, which may
// strip further wrappers before returning the actual expression.
static inline const ActualXprType& extract(const XprType& x) { return Base::extract(x._expression()); }
// Returns the scalar factor of the nested expression with its sign flipped, which accounts
// for the unary minus dropped by extract().
static inline Scalar extractScalarFactor(const XprType& x)
{ return - Base::extractScalarFactor(x._expression()); }
};
/* Helper class to determine the type of the product, can be either:
* - NormalProduct
* - CacheFriendlyProduct
*/
template<typename Lhs, typename Rhs> struct ei_product_mode
{
typedef typename ei_product_factor_traits<Lhs>::ActualXprType ActualLhs;
typedef typename ei_product_factor_traits<Rhs>::ActualXprType ActualRhs;
enum{
value = Lhs::MaxColsAtCompileTime == Dynamic
&& ( Lhs::MaxRowsAtCompileTime == Dynamic
|| Rhs::MaxColsAtCompileTime == Dynamic )
&& (!(Rhs::IsVectorAtCompileTime && (Lhs::Flags&RowMajorBit) && (!(ActualLhs::Flags&DirectAccessBit))))
&& (!(Lhs::IsVectorAtCompileTime && (!(Rhs::Flags&RowMajorBit)) && (!(ActualRhs::Flags&DirectAccessBit))))
&& (ei_is_same_type<typename Lhs::Scalar, typename Rhs::Scalar>::ret)
? CacheFriendlyProduct
: NormalProduct };
}; };
/** \class Product /** \class Product
@ -552,11 +566,11 @@ void ei_cache_friendly_product(
bool resRowMajor, Scalar* res, int resStride, bool resRowMajor, Scalar* res, int resStride,
Scalar alpha); Scalar alpha);
template<typename Scalar, typename RhsType> template<bool ConjugateLhs, bool ConjugateRhs, typename Scalar, typename RhsType>
static void ei_cache_friendly_product_colmajor_times_vector( static void ei_cache_friendly_product_colmajor_times_vector(
int size, const Scalar* lhs, int lhsStride, const RhsType& rhs, Scalar* res, Scalar alpha); int size, const Scalar* lhs, int lhsStride, const RhsType& rhs, Scalar* res, Scalar alpha);
template<typename Scalar, typename ResType> template<bool ConjugateLhs, bool ConjugateRhs, typename Scalar, typename ResType>
static void ei_cache_friendly_product_rowmajor_times_vector( static void ei_cache_friendly_product_rowmajor_times_vector(
const Scalar* lhs, int lhsStride, const Scalar* rhs, int rhsSize, ResType& res, Scalar alpha); const Scalar* lhs, int lhsStride, const Scalar* rhs, int rhsSize, ResType& res, Scalar alpha);
@ -572,10 +586,10 @@ static void ei_cache_friendly_product_rowmajor_times_vector(
template<typename ProductType, template<typename ProductType,
int LhsRows = ei_traits<ProductType>::RowsAtCompileTime, int LhsRows = ei_traits<ProductType>::RowsAtCompileTime,
int LhsOrder = int(ei_traits<ProductType>::LhsFlags)&RowMajorBit ? RowMajor : ColMajor, int LhsOrder = int(ei_traits<ProductType>::LhsFlags)&RowMajorBit ? RowMajor : ColMajor,
int LhsHasDirectAccess = int(ei_traits<ProductType>::LhsFlags)&DirectAccessBit? HasDirectAccess : NoDirectAccess, int LhsHasDirectAccess = ei_product_factor_traits<typename ei_traits<ProductType>::_LhsNested>::ActualAccess,
int RhsCols = ei_traits<ProductType>::ColsAtCompileTime, int RhsCols = ei_traits<ProductType>::ColsAtCompileTime,
int RhsOrder = int(ei_traits<ProductType>::RhsFlags)&RowMajorBit ? RowMajor : ColMajor, int RhsOrder = int(ei_traits<ProductType>::RhsFlags)&RowMajorBit ? RowMajor : ColMajor,
int RhsHasDirectAccess = int(ei_traits<ProductType>::RhsFlags)&DirectAccessBit? HasDirectAccess : NoDirectAccess> int RhsHasDirectAccess = ei_product_factor_traits<typename ei_traits<ProductType>::_RhsNested>::ActualAccess>
struct ei_cache_friendly_product_selector struct ei_cache_friendly_product_selector
{ {
template<typename DestDerived> template<typename DestDerived>
@ -592,7 +606,6 @@ struct ei_cache_friendly_product_selector<ProductType,LhsRows,ColMajor,NoDirectA
template<typename DestDerived> template<typename DestDerived>
inline static void run(DestDerived& res, const ProductType& product, typename ProductType::Scalar alpha) inline static void run(DestDerived& res, const ProductType& product, typename ProductType::Scalar alpha)
{ {
// FIXME is it really used ?
ei_assert(alpha==typename ProductType::Scalar(1)); ei_assert(alpha==typename ProductType::Scalar(1));
const int size = product.rhs().rows(); const int size = product.rhs().rows();
for (int k=0; k<size; ++k) for (int k=0; k<size; ++k)
@ -606,10 +619,21 @@ template<typename ProductType, int LhsRows, int RhsOrder, int RhsAccess>
struct ei_cache_friendly_product_selector<ProductType,LhsRows,ColMajor,HasDirectAccess,1,RhsOrder,RhsAccess> struct ei_cache_friendly_product_selector<ProductType,LhsRows,ColMajor,HasDirectAccess,1,RhsOrder,RhsAccess>
{ {
typedef typename ProductType::Scalar Scalar; typedef typename ProductType::Scalar Scalar;
typedef ei_product_factor_traits<typename ei_traits<ProductType>::_LhsNested> LhsProductTraits;
typedef ei_product_factor_traits<typename ei_traits<ProductType>::_RhsNested> RhsProductTraits;
typedef typename LhsProductTraits::ActualXprType ActualLhsType;
typedef typename RhsProductTraits::ActualXprType ActualRhsType;
template<typename DestDerived> template<typename DestDerived>
inline static void run(DestDerived& res, const ProductType& product, typename ProductType::Scalar alpha) inline static void run(DestDerived& res, const ProductType& product, typename ProductType::Scalar alpha)
{ {
const ActualLhsType& actualLhs = LhsProductTraits::extract(product.lhs());
const ActualRhsType& actualRhs = RhsProductTraits::extract(product.rhs());
Scalar actualAlpha = alpha * LhsProductTraits::extractScalarFactor(product.lhs())
* RhsProductTraits::extractScalarFactor(product.rhs());
enum { enum {
EvalToRes = (ei_packet_traits<Scalar>::size==1) EvalToRes = (ei_packet_traits<Scalar>::size==1)
||((DestDerived::Flags&ActualPacketAccessBit) && (!(DestDerived::Flags & RowMajorBit))) }; ||((DestDerived::Flags&ActualPacketAccessBit) && (!(DestDerived::Flags & RowMajorBit))) };
@ -621,9 +645,12 @@ struct ei_cache_friendly_product_selector<ProductType,LhsRows,ColMajor,HasDirect
_res = ei_aligned_stack_new(Scalar,res.size()); _res = ei_aligned_stack_new(Scalar,res.size());
Map<Matrix<Scalar,DestDerived::RowsAtCompileTime,1> >(_res, res.size()) = res; Map<Matrix<Scalar,DestDerived::RowsAtCompileTime,1> >(_res, res.size()) = res;
} }
ei_cache_friendly_product_colmajor_times_vector(res.size(),
&product.lhs().const_cast_derived().coeffRef(0,0), product.lhs().stride(), ei_cache_friendly_product_colmajor_times_vector
product.rhs(), _res, alpha); <LhsProductTraits::NeedToConjugate,RhsProductTraits::NeedToConjugate>(
res.size(),
&actualLhs.const_cast_derived().coeffRef(0,0), actualLhs.stride(),
actualRhs, _res, actualAlpha);
if (!EvalToRes) if (!EvalToRes)
{ {
@ -653,10 +680,21 @@ template<typename ProductType, int LhsOrder, int LhsAccess, int RhsCols>
struct ei_cache_friendly_product_selector<ProductType,1,LhsOrder,LhsAccess,RhsCols,RowMajor,HasDirectAccess> struct ei_cache_friendly_product_selector<ProductType,1,LhsOrder,LhsAccess,RhsCols,RowMajor,HasDirectAccess>
{ {
typedef typename ProductType::Scalar Scalar; typedef typename ProductType::Scalar Scalar;
typedef ei_product_factor_traits<typename ei_traits<ProductType>::_LhsNested> LhsProductTraits;
typedef ei_product_factor_traits<typename ei_traits<ProductType>::_RhsNested> RhsProductTraits;
typedef typename LhsProductTraits::ActualXprType ActualLhsType;
typedef typename RhsProductTraits::ActualXprType ActualRhsType;
template<typename DestDerived> template<typename DestDerived>
inline static void run(DestDerived& res, const ProductType& product, typename ProductType::Scalar alpha) inline static void run(DestDerived& res, const ProductType& product, typename ProductType::Scalar alpha)
{ {
const ActualLhsType& actualLhs = LhsProductTraits::extract(product.lhs());
const ActualRhsType& actualRhs = RhsProductTraits::extract(product.rhs());
Scalar actualAlpha = alpha * LhsProductTraits::extractScalarFactor(product.lhs())
* RhsProductTraits::extractScalarFactor(product.rhs());
enum { enum {
EvalToRes = (ei_packet_traits<Scalar>::size==1) EvalToRes = (ei_packet_traits<Scalar>::size==1)
||((DestDerived::Flags & ActualPacketAccessBit) && (DestDerived::Flags & RowMajorBit)) }; ||((DestDerived::Flags & ActualPacketAccessBit) && (DestDerived::Flags & RowMajorBit)) };
@ -668,9 +706,11 @@ struct ei_cache_friendly_product_selector<ProductType,1,LhsOrder,LhsAccess,RhsCo
_res = ei_aligned_stack_new(Scalar, res.size()); _res = ei_aligned_stack_new(Scalar, res.size());
Map<Matrix<Scalar,DestDerived::SizeAtCompileTime,1> >(_res, res.size()) = res; Map<Matrix<Scalar,DestDerived::SizeAtCompileTime,1> >(_res, res.size()) = res;
} }
ei_cache_friendly_product_colmajor_times_vector(res.size(),
&product.rhs().const_cast_derived().coeffRef(0,0), product.rhs().stride(), ei_cache_friendly_product_colmajor_times_vector
product.lhs().transpose(), _res, alpha); <RhsProductTraits::NeedToConjugate,LhsProductTraits::NeedToConjugate>(res.size(),
&actualRhs.const_cast_derived().coeffRef(0,0), actualRhs.stride(),
actualLhs.transpose(), _res, actualAlpha);
if (!EvalToRes) if (!EvalToRes)
{ {
@ -685,24 +725,39 @@ template<typename ProductType, int LhsRows, int RhsOrder, int RhsAccess>
struct ei_cache_friendly_product_selector<ProductType,LhsRows,RowMajor,HasDirectAccess,1,RhsOrder,RhsAccess> struct ei_cache_friendly_product_selector<ProductType,LhsRows,RowMajor,HasDirectAccess,1,RhsOrder,RhsAccess>
{ {
typedef typename ProductType::Scalar Scalar; typedef typename ProductType::Scalar Scalar;
typedef typename ei_traits<ProductType>::_RhsNested Rhs;
typedef ei_product_factor_traits<typename ei_traits<ProductType>::_LhsNested> LhsProductTraits;
typedef ei_product_factor_traits<typename ei_traits<ProductType>::_RhsNested> RhsProductTraits;
typedef typename LhsProductTraits::ActualXprType ActualLhsType;
typedef typename RhsProductTraits::ActualXprType ActualRhsType;
enum { enum {
UseRhsDirectly = ((ei_packet_traits<Scalar>::size==1) || (Rhs::Flags&ActualPacketAccessBit)) UseRhsDirectly = ((ei_packet_traits<Scalar>::size==1) || (ActualRhsType::Flags&ActualPacketAccessBit))
&& (!(Rhs::Flags & RowMajorBit)) }; && (!(ActualRhsType::Flags & RowMajorBit)) };
template<typename DestDerived> template<typename DestDerived>
inline static void run(DestDerived& res, const ProductType& product, typename ProductType::Scalar alpha) inline static void run(DestDerived& res, const ProductType& product, typename ProductType::Scalar alpha)
{ {
const ActualLhsType& actualLhs = LhsProductTraits::extract(product.lhs());
const ActualRhsType& actualRhs = RhsProductTraits::extract(product.rhs());
Scalar actualAlpha = alpha * LhsProductTraits::extractScalarFactor(product.lhs())
* RhsProductTraits::extractScalarFactor(product.rhs());
Scalar* EIGEN_RESTRICT _rhs; Scalar* EIGEN_RESTRICT _rhs;
if (UseRhsDirectly) if (UseRhsDirectly)
_rhs = &product.rhs().const_cast_derived().coeffRef(0); _rhs = &actualRhs.const_cast_derived().coeffRef(0);
else else
{ {
_rhs = ei_aligned_stack_new(Scalar, product.rhs().size()); _rhs = ei_aligned_stack_new(Scalar, actualRhs.size());
Map<Matrix<Scalar,Rhs::SizeAtCompileTime,1> >(_rhs, product.rhs().size()) = product.rhs(); Map<Matrix<Scalar,ActualRhsType::SizeAtCompileTime,1> >(_rhs, actualRhs.size()) = actualRhs;
} }
ei_cache_friendly_product_rowmajor_times_vector(&product.lhs().const_cast_derived().coeffRef(0,0), product.lhs().stride(),
_rhs, product.rhs().size(), res, alpha); ei_cache_friendly_product_rowmajor_times_vector
<LhsProductTraits::NeedToConjugate,RhsProductTraits::NeedToConjugate>(
&actualLhs.const_cast_derived().coeffRef(0,0), actualLhs.stride(),
_rhs, product.rhs().size(), res, actualAlpha);
if (!UseRhsDirectly) ei_aligned_stack_delete(Scalar, _rhs, product.rhs().size()); if (!UseRhsDirectly) ei_aligned_stack_delete(Scalar, _rhs, product.rhs().size());
} }
@ -713,24 +768,39 @@ template<typename ProductType, int LhsOrder, int LhsAccess, int RhsCols>
struct ei_cache_friendly_product_selector<ProductType,1,LhsOrder,LhsAccess,RhsCols,ColMajor,HasDirectAccess> struct ei_cache_friendly_product_selector<ProductType,1,LhsOrder,LhsAccess,RhsCols,ColMajor,HasDirectAccess>
{ {
typedef typename ProductType::Scalar Scalar; typedef typename ProductType::Scalar Scalar;
typedef typename ei_traits<ProductType>::_LhsNested Lhs;
typedef ei_product_factor_traits<typename ei_traits<ProductType>::_LhsNested> LhsProductTraits;
typedef ei_product_factor_traits<typename ei_traits<ProductType>::_RhsNested> RhsProductTraits;
typedef typename LhsProductTraits::ActualXprType ActualLhsType;
typedef typename RhsProductTraits::ActualXprType ActualRhsType;
enum { enum {
UseLhsDirectly = ((ei_packet_traits<Scalar>::size==1) || (Lhs::Flags&ActualPacketAccessBit)) UseLhsDirectly = ((ei_packet_traits<Scalar>::size==1) || (ActualLhsType::Flags&ActualPacketAccessBit))
&& (Lhs::Flags & RowMajorBit) }; && (ActualLhsType::Flags & RowMajorBit) };
template<typename DestDerived> template<typename DestDerived>
inline static void run(DestDerived& res, const ProductType& product, typename ProductType::Scalar alpha) inline static void run(DestDerived& res, const ProductType& product, typename ProductType::Scalar alpha)
{ {
const ActualLhsType& actualLhs = LhsProductTraits::extract(product.lhs());
const ActualRhsType& actualRhs = RhsProductTraits::extract(product.rhs());
Scalar actualAlpha = alpha * LhsProductTraits::extractScalarFactor(product.lhs())
* RhsProductTraits::extractScalarFactor(product.rhs());
Scalar* EIGEN_RESTRICT _lhs; Scalar* EIGEN_RESTRICT _lhs;
if (UseLhsDirectly) if (UseLhsDirectly)
_lhs = &product.lhs().const_cast_derived().coeffRef(0); _lhs = &actualLhs.const_cast_derived().coeffRef(0);
else else
{ {
_lhs = ei_aligned_stack_new(Scalar, product.lhs().size()); _lhs = ei_aligned_stack_new(Scalar, actualLhs.size());
Map<Matrix<Scalar,Lhs::SizeAtCompileTime,1> >(_lhs, product.lhs().size()) = product.lhs(); Map<Matrix<Scalar,ActualLhsType::SizeAtCompileTime,1> >(_lhs, actualLhs.size()) = actualLhs;
} }
ei_cache_friendly_product_rowmajor_times_vector(&product.rhs().const_cast_derived().coeffRef(0,0), product.rhs().stride(),
_lhs, product.lhs().size(), res, alpha); ei_cache_friendly_product_rowmajor_times_vector
<RhsProductTraits::NeedToConjugate, LhsProductTraits::NeedToConjugate>(
&actualRhs.const_cast_derived().coeffRef(0,0), actualRhs.stride(),
_lhs, product.lhs().size(), res, actualAlpha);
if(!UseLhsDirectly) ei_aligned_stack_delete(Scalar, _lhs, product.lhs().size()); if(!UseLhsDirectly) ei_aligned_stack_delete(Scalar, _lhs, product.lhs().size());
} }
@ -827,8 +897,8 @@ inline void Product<Lhs,Rhs,ProductMode>::_cacheFriendlyEvalAndAdd(DestDerived&
const ActualLhsType& actualLhs = LhsProductTraits::extract(m_lhs); const ActualLhsType& actualLhs = LhsProductTraits::extract(m_lhs);
const ActualRhsType& actualRhs = RhsProductTraits::extract(m_rhs); const ActualRhsType& actualRhs = RhsProductTraits::extract(m_rhs);
Scalar actualAlpha = alpha * LhsProductTraits::extractSalarFactor(m_lhs) Scalar actualAlpha = alpha * LhsProductTraits::extractScalarFactor(m_lhs)
* RhsProductTraits::extractSalarFactor(m_rhs); * RhsProductTraits::extractScalarFactor(m_rhs);
typedef typename ei_product_copy_lhs<ActualLhsType>::type LhsCopy; typedef typename ei_product_copy_lhs<ActualLhsType>::type LhsCopy;
typedef typename ei_unref<LhsCopy>::type _LhsCopy; typedef typename ei_unref<LhsCopy>::type _LhsCopy;
@ -837,7 +907,6 @@ inline void Product<Lhs,Rhs,ProductMode>::_cacheFriendlyEvalAndAdd(DestDerived&
LhsCopy lhs(actualLhs); LhsCopy lhs(actualLhs);
RhsCopy rhs(actualRhs); RhsCopy rhs(actualRhs);
ei_cache_friendly_product<Scalar, ei_cache_friendly_product<Scalar,
// LhsProductTraits::NeedToConjugate,RhsProductTraits::NeedToConjugate>
((int(Flags)&RowMajorBit) ? bool(RhsProductTraits::NeedToConjugate) : bool(LhsProductTraits::NeedToConjugate)), ((int(Flags)&RowMajorBit) ? bool(RhsProductTraits::NeedToConjugate) : bool(LhsProductTraits::NeedToConjugate)),
((int(Flags)&RowMajorBit) ? bool(LhsProductTraits::NeedToConjugate) : bool(RhsProductTraits::NeedToConjugate))> ((int(Flags)&RowMajorBit) ? bool(LhsProductTraits::NeedToConjugate) : bool(RhsProductTraits::NeedToConjugate))>
( (

View File

@ -30,30 +30,46 @@ struct ei_L2_block_traits {
enum {width = 8 * ei_meta_sqrt<L2MemorySize/(64*sizeof(Scalar))>::ret }; enum {width = 8 * ei_meta_sqrt<L2MemorySize/(64*sizeof(Scalar))>::ret };
}; };
template<bool ConjLhs, bool ConjRhs> struct ei_conj_pmadd; template<bool ConjLhs, bool ConjRhs> struct ei_conj_helper;
template<> struct ei_conj_pmadd<false,false> template<> struct ei_conj_helper<false,false>
{ {
template<typename T> template<typename T>
EIGEN_STRONG_INLINE T operator()(const T& x, const T& y, T& c) const { return ei_pmadd(x,y,c); } EIGEN_STRONG_INLINE T pmadd(const T& x, const T& y, const T& c) const { return ei_pmadd(x,y,c); }
template<typename T>
EIGEN_STRONG_INLINE T pmul(const T& x, const T& y) const { return ei_pmul(x,y); }
}; };
template<> struct ei_conj_pmadd<false,true> template<> struct ei_conj_helper<false,true>
{ {
template<typename T> std::complex<T> operator()(const std::complex<T>& x, const std::complex<T>& y, std::complex<T>& c) const template<typename T> std::complex<T>
{ return c + std::complex<T>(ei_real(x)*ei_real(y) + ei_imag(x)*ei_imag(y), ei_imag(x)*ei_real(y) - ei_real(x)*ei_imag(y)); } pmadd(const std::complex<T>& x, const std::complex<T>& y, const std::complex<T>& c) const
{ return c + pmul(x,y); }
template<typename T> std::complex<T> pmul(const std::complex<T>& x, const std::complex<T>& y) const
//{ return std::complex<T>(ei_real(x)*ei_real(y) + ei_imag(x)*ei_imag(y), ei_imag(x)*ei_real(y) - ei_real(x)*ei_imag(y)); }
{ return x * ei_conj(y); }
}; };
template<> struct ei_conj_pmadd<true,false> template<> struct ei_conj_helper<true,false>
{ {
template<typename T> std::complex<T> operator()(const std::complex<T>& x, const std::complex<T>& y, std::complex<T>& c) const template<typename T> std::complex<T>
{ return c + std::complex<T>(ei_real(x)*ei_real(y) + ei_imag(x)*ei_imag(y), ei_real(x)*ei_imag(y) - ei_imag(x)*ei_real(y)); } pmadd(const std::complex<T>& x, const std::complex<T>& y, const std::complex<T>& c) const
{ return c + pmul(x,y); }
template<typename T> std::complex<T> pmul(const std::complex<T>& x, const std::complex<T>& y) const
{ return std::complex<T>(ei_real(x)*ei_real(y) + ei_imag(x)*ei_imag(y), ei_real(x)*ei_imag(y) - ei_imag(x)*ei_real(y)); }
}; };
template<> struct ei_conj_pmadd<true,true> template<> struct ei_conj_helper<true,true>
{ {
template<typename T> std::complex<T> operator()(const std::complex<T>& x, const std::complex<T>& y, std::complex<T>& c) const template<typename T> std::complex<T>
{ return c + std::complex<T>(ei_real(x)*ei_real(y) - ei_imag(x)*ei_imag(y), - ei_real(x)*ei_imag(y) - ei_imag(x)*ei_real(y)); } pmadd(const std::complex<T>& x, const std::complex<T>& y, const std::complex<T>& c) const
{ return c + pmul(x,y); }
template<typename T> std::complex<T> pmul(const std::complex<T>& x, const std::complex<T>& y) const
// { return std::complex<T>(ei_real(x)*ei_real(y) - ei_imag(x)*ei_imag(y), - ei_real(x)*ei_imag(y) - ei_imag(x)*ei_real(y)); }
{ return ei_conj(x) * ei_conj(y); }
}; };
#ifndef EIGEN_EXTERN_INSTANTIATIONS #ifndef EIGEN_EXTERN_INSTANTIATIONS
@ -74,7 +90,9 @@ static void ei_cache_friendly_product(
int lhsStride, rhsStride, rows, cols; int lhsStride, rhsStride, rows, cols;
bool lhsRowMajor; bool lhsRowMajor;
ei_conj_pmadd<ConjugateLhs,ConjugateRhs> cj_pmadd; ei_conj_helper<ConjugateLhs,ConjugateRhs> cj;
if (ConjugateRhs)
alpha = ei_conj(alpha);
bool hasAlpha = alpha != Scalar(1); bool hasAlpha = alpha != Scalar(1);
if (resRowMajor) if (resRowMajor)
@ -261,59 +279,59 @@ static void ei_cache_friendly_product(
A1 = ei_pload(&blA[1*PacketSize]); A1 = ei_pload(&blA[1*PacketSize]);
B0 = ei_pload(&blB[0*PacketSize]); B0 = ei_pload(&blB[0*PacketSize]);
B1 = ei_pload(&blB[1*PacketSize]); B1 = ei_pload(&blB[1*PacketSize]);
C0 = cj_pmadd(A0, B0, C0); C0 = cj.pmadd(A0, B0, C0);
if(nr==4) B2 = ei_pload(&blB[2*PacketSize]); if(nr==4) B2 = ei_pload(&blB[2*PacketSize]);
C4 = cj_pmadd(A1, B0, C4); C4 = cj.pmadd(A1, B0, C4);
if(nr==4) B3 = ei_pload(&blB[3*PacketSize]); if(nr==4) B3 = ei_pload(&blB[3*PacketSize]);
B0 = ei_pload(&blB[(nr==4 ? 4 : 2)*PacketSize]); B0 = ei_pload(&blB[(nr==4 ? 4 : 2)*PacketSize]);
C1 = cj_pmadd(A0, B1, C1); C1 = cj.pmadd(A0, B1, C1);
C5 = cj_pmadd(A1, B1, C5); C5 = cj.pmadd(A1, B1, C5);
B1 = ei_pload(&blB[(nr==4 ? 5 : 3)*PacketSize]); B1 = ei_pload(&blB[(nr==4 ? 5 : 3)*PacketSize]);
if(nr==4) C2 = cj_pmadd(A0, B2, C2); if(nr==4) C2 = cj.pmadd(A0, B2, C2);
if(nr==4) C6 = cj_pmadd(A1, B2, C6); if(nr==4) C6 = cj.pmadd(A1, B2, C6);
if(nr==4) B2 = ei_pload(&blB[6*PacketSize]); if(nr==4) B2 = ei_pload(&blB[6*PacketSize]);
if(nr==4) C3 = cj_pmadd(A0, B3, C3); if(nr==4) C3 = cj.pmadd(A0, B3, C3);
A0 = ei_pload(&blA[2*PacketSize]); A0 = ei_pload(&blA[2*PacketSize]);
if(nr==4) C7 = cj_pmadd(A1, B3, C7); if(nr==4) C7 = cj.pmadd(A1, B3, C7);
A1 = ei_pload(&blA[3*PacketSize]); A1 = ei_pload(&blA[3*PacketSize]);
if(nr==4) B3 = ei_pload(&blB[7*PacketSize]); if(nr==4) B3 = ei_pload(&blB[7*PacketSize]);
C0 = cj_pmadd(A0, B0, C0); C0 = cj.pmadd(A0, B0, C0);
C4 = cj_pmadd(A1, B0, C4); C4 = cj.pmadd(A1, B0, C4);
B0 = ei_pload(&blB[(nr==4 ? 8 : 4)*PacketSize]); B0 = ei_pload(&blB[(nr==4 ? 8 : 4)*PacketSize]);
C1 = cj_pmadd(A0, B1, C1); C1 = cj.pmadd(A0, B1, C1);
C5 = cj_pmadd(A1, B1, C5); C5 = cj.pmadd(A1, B1, C5);
B1 = ei_pload(&blB[(nr==4 ? 9 : 5)*PacketSize]); B1 = ei_pload(&blB[(nr==4 ? 9 : 5)*PacketSize]);
if(nr==4) C2 = cj_pmadd(A0, B2, C2); if(nr==4) C2 = cj.pmadd(A0, B2, C2);
if(nr==4) C6 = cj_pmadd(A1, B2, C6); if(nr==4) C6 = cj.pmadd(A1, B2, C6);
if(nr==4) B2 = ei_pload(&blB[10*PacketSize]); if(nr==4) B2 = ei_pload(&blB[10*PacketSize]);
if(nr==4) C3 = cj_pmadd(A0, B3, C3); if(nr==4) C3 = cj.pmadd(A0, B3, C3);
A0 = ei_pload(&blA[4*PacketSize]); A0 = ei_pload(&blA[4*PacketSize]);
if(nr==4) C7 = cj_pmadd(A1, B3, C7); if(nr==4) C7 = cj.pmadd(A1, B3, C7);
A1 = ei_pload(&blA[5*PacketSize]); A1 = ei_pload(&blA[5*PacketSize]);
if(nr==4) B3 = ei_pload(&blB[11*PacketSize]); if(nr==4) B3 = ei_pload(&blB[11*PacketSize]);
C0 = cj_pmadd(A0, B0, C0); C0 = cj.pmadd(A0, B0, C0);
C4 = cj_pmadd(A1, B0, C4); C4 = cj.pmadd(A1, B0, C4);
B0 = ei_pload(&blB[(nr==4 ? 12 : 6)*PacketSize]); B0 = ei_pload(&blB[(nr==4 ? 12 : 6)*PacketSize]);
C1 = cj_pmadd(A0, B1, C1); C1 = cj.pmadd(A0, B1, C1);
C5 = cj_pmadd(A1, B1, C5); C5 = cj.pmadd(A1, B1, C5);
B1 = ei_pload(&blB[(nr==4 ? 13 : 7)*PacketSize]); B1 = ei_pload(&blB[(nr==4 ? 13 : 7)*PacketSize]);
if(nr==4) C2 = cj_pmadd(A0, B2, C2); if(nr==4) C2 = cj.pmadd(A0, B2, C2);
if(nr==4) C6 = cj_pmadd(A1, B2, C6); if(nr==4) C6 = cj.pmadd(A1, B2, C6);
if(nr==4) B2 = ei_pload(&blB[14*PacketSize]); if(nr==4) B2 = ei_pload(&blB[14*PacketSize]);
if(nr==4) C3 = cj_pmadd(A0, B3, C3); if(nr==4) C3 = cj.pmadd(A0, B3, C3);
A0 = ei_pload(&blA[6*PacketSize]); A0 = ei_pload(&blA[6*PacketSize]);
if(nr==4) C7 = cj_pmadd(A1, B3, C7); if(nr==4) C7 = cj.pmadd(A1, B3, C7);
A1 = ei_pload(&blA[7*PacketSize]); A1 = ei_pload(&blA[7*PacketSize]);
if(nr==4) B3 = ei_pload(&blB[15*PacketSize]); if(nr==4) B3 = ei_pload(&blB[15*PacketSize]);
C0 = cj_pmadd(A0, B0, C0); C0 = cj.pmadd(A0, B0, C0);
C4 = cj_pmadd(A1, B0, C4); C4 = cj.pmadd(A1, B0, C4);
C1 = cj_pmadd(A0, B1, C1); C1 = cj.pmadd(A0, B1, C1);
C5 = cj_pmadd(A1, B1, C5); C5 = cj.pmadd(A1, B1, C5);
if(nr==4) C2 = cj_pmadd(A0, B2, C2); if(nr==4) C2 = cj.pmadd(A0, B2, C2);
if(nr==4) C6 = cj_pmadd(A1, B2, C6); if(nr==4) C6 = cj.pmadd(A1, B2, C6);
if(nr==4) C3 = cj_pmadd(A0, B3, C3); if(nr==4) C3 = cj.pmadd(A0, B3, C3);
if(nr==4) C7 = cj_pmadd(A1, B3, C7); if(nr==4) C7 = cj.pmadd(A1, B3, C7);
blB += 4*nr*PacketSize; blB += 4*nr*PacketSize;
blA += 4*mr; blA += 4*mr;
@ -327,16 +345,16 @@ static void ei_cache_friendly_product(
A1 = ei_pload(&blA[1*PacketSize]); A1 = ei_pload(&blA[1*PacketSize]);
B0 = ei_pload(&blB[0*PacketSize]); B0 = ei_pload(&blB[0*PacketSize]);
B1 = ei_pload(&blB[1*PacketSize]); B1 = ei_pload(&blB[1*PacketSize]);
C0 = cj_pmadd(A0, B0, C0); C0 = cj.pmadd(A0, B0, C0);
if(nr==4) B2 = ei_pload(&blB[2*PacketSize]); if(nr==4) B2 = ei_pload(&blB[2*PacketSize]);
C4 = cj_pmadd(A1, B0, C4); C4 = cj.pmadd(A1, B0, C4);
if(nr==4) B3 = ei_pload(&blB[3*PacketSize]); if(nr==4) B3 = ei_pload(&blB[3*PacketSize]);
C1 = cj_pmadd(A0, B1, C1); C1 = cj.pmadd(A0, B1, C1);
C5 = cj_pmadd(A1, B1, C5); C5 = cj.pmadd(A1, B1, C5);
if(nr==4) C2 = cj_pmadd(A0, B2, C2); if(nr==4) C2 = cj.pmadd(A0, B2, C2);
if(nr==4) C6 = cj_pmadd(A1, B2, C6); if(nr==4) C6 = cj.pmadd(A1, B2, C6);
if(nr==4) C3 = cj_pmadd(A0, B3, C3); if(nr==4) C3 = cj.pmadd(A0, B3, C3);
if(nr==4) C7 = cj_pmadd(A1, B3, C7); if(nr==4) C7 = cj.pmadd(A1, B3, C7);
blB += nr*PacketSize; blB += nr*PacketSize;
blA += mr; blA += mr;
@ -368,12 +386,12 @@ static void ei_cache_friendly_product(
A0 = blA[k]; A0 = blA[k];
B0 = blB[0*PacketSize]; B0 = blB[0*PacketSize];
B1 = blB[1*PacketSize]; B1 = blB[1*PacketSize];
C0 = cj_pmadd(A0, B0, C0); C0 = cj.pmadd(A0, B0, C0);
if(nr==4) B2 = blB[2*PacketSize]; if(nr==4) B2 = blB[2*PacketSize];
if(nr==4) B3 = blB[3*PacketSize]; if(nr==4) B3 = blB[3*PacketSize];
C1 = cj_pmadd(A0, B1, C1); C1 = cj.pmadd(A0, B1, C1);
if(nr==4) C2 = cj_pmadd(A0, B2, C2); if(nr==4) C2 = cj.pmadd(A0, B2, C2);
if(nr==4) C3 = cj_pmadd(A0, B3, C3); if(nr==4) C3 = cj.pmadd(A0, B3, C3);
blB += nr*PacketSize; blB += nr*PacketSize;
} }
@ -391,11 +409,11 @@ static void ei_cache_friendly_product(
Scalar c0 = Scalar(0); Scalar c0 = Scalar(0);
if (lhsRowMajor) if (lhsRowMajor)
for(int k=0; k<actual_kc; k++) for(int k=0; k<actual_kc; k++)
c0 = cj_pmadd(lhs[(k2+k)+(i2+i)*lhsStride], rhs[j2*rhsStride + k2 + k], c0); c0 += cj.pmul(lhs[(k2+k)+(i2+i)*lhsStride], rhs[j2*rhsStride + k2 + k]);
else else
for(int k=0; k<actual_kc; k++) for(int k=0; k<actual_kc; k++)
c0 = cj_pmadd(lhs[(k2+k)*lhsStride + i2+i], rhs[j2*rhsStride + k2 + k], c0); c0 += cj.pmul(lhs[(k2+k)*lhsStride + i2+i], rhs[j2*rhsStride + k2 + k]);
res[(j2)*resStride + i2+i] += alpha * c0; res[(j2)*resStride + i2+i] += (ConjugateRhs ? ei_conj(alpha) : alpha) * c0;
} }
} }
} }
@ -493,39 +511,39 @@ static void ei_cache_friendly_product(
L0 = ei_pload(&lb[1*PacketSize]); L0 = ei_pload(&lb[1*PacketSize]);
R1 = ei_pload(&lb[2*PacketSize]); R1 = ei_pload(&lb[2*PacketSize]);
L1 = ei_pload(&lb[3*PacketSize]); L1 = ei_pload(&lb[3*PacketSize]);
T0 = cj_pmadd(A0, R0, T0); T0 = cj.pmadd(A0, R0, T0);
T1 = cj_pmadd(A0, L0, T1); T1 = cj.pmadd(A0, L0, T1);
R0 = ei_pload(&lb[4*PacketSize]); R0 = ei_pload(&lb[4*PacketSize]);
L0 = ei_pload(&lb[5*PacketSize]); L0 = ei_pload(&lb[5*PacketSize]);
T0 = cj_pmadd(A1, R1, T0); T0 = cj.pmadd(A1, R1, T0);
T1 = cj_pmadd(A1, L1, T1); T1 = cj.pmadd(A1, L1, T1);
R1 = ei_pload(&lb[6*PacketSize]); R1 = ei_pload(&lb[6*PacketSize]);
L1 = ei_pload(&lb[7*PacketSize]); L1 = ei_pload(&lb[7*PacketSize]);
T0 = cj_pmadd(A2, R0, T0); T0 = cj.pmadd(A2, R0, T0);
T1 = cj_pmadd(A2, L0, T1); T1 = cj.pmadd(A2, L0, T1);
if(MaxBlockRows==8) if(MaxBlockRows==8)
{ {
R0 = ei_pload(&lb[8*PacketSize]); R0 = ei_pload(&lb[8*PacketSize]);
L0 = ei_pload(&lb[9*PacketSize]); L0 = ei_pload(&lb[9*PacketSize]);
} }
T0 = cj_pmadd(A3, R1, T0); T0 = cj.pmadd(A3, R1, T0);
T1 = cj_pmadd(A3, L1, T1); T1 = cj.pmadd(A3, L1, T1);
if(MaxBlockRows==8) if(MaxBlockRows==8)
{ {
R1 = ei_pload(&lb[10*PacketSize]); R1 = ei_pload(&lb[10*PacketSize]);
L1 = ei_pload(&lb[11*PacketSize]); L1 = ei_pload(&lb[11*PacketSize]);
T0 = cj_pmadd(A4, R0, T0); T0 = cj.pmadd(A4, R0, T0);
T1 = cj_pmadd(A4, L0, T1); T1 = cj.pmadd(A4, L0, T1);
R0 = ei_pload(&lb[12*PacketSize]); R0 = ei_pload(&lb[12*PacketSize]);
L0 = ei_pload(&lb[13*PacketSize]); L0 = ei_pload(&lb[13*PacketSize]);
T0 = cj_pmadd(A5, R1, T0); T0 = cj.pmadd(A5, R1, T0);
T1 = cj_pmadd(A5, L1, T1); T1 = cj.pmadd(A5, L1, T1);
R1 = ei_pload(&lb[14*PacketSize]); R1 = ei_pload(&lb[14*PacketSize]);
L1 = ei_pload(&lb[15*PacketSize]); L1 = ei_pload(&lb[15*PacketSize]);
T0 = cj_pmadd(A6, R0, T0); T0 = cj.pmadd(A6, R0, T0);
T1 = cj_pmadd(A6, L0, T1); T1 = cj.pmadd(A6, L0, T1);
T0 = cj_pmadd(A7, R1, T0); T0 = cj.pmadd(A7, R1, T0);
T1 = cj_pmadd(A7, L1, T1); T1 = cj.pmadd(A7, L1, T1);
} }
lb += MaxBlockRows*2*PacketSize; lb += MaxBlockRows*2*PacketSize;

View File

@ -32,8 +32,9 @@
* same alignment pattern. * same alignment pattern.
* TODO: since rhs gets evaluated only once, no need to evaluate it * TODO: since rhs gets evaluated only once, no need to evaluate it
*/ */
template<typename Scalar, typename RhsType> template<bool ConjugateLhs, bool ConjugateRhs, typename Scalar, typename RhsType>
static EIGEN_DONT_INLINE void ei_cache_friendly_product_colmajor_times_vector( static EIGEN_DONT_INLINE
void ei_cache_friendly_product_colmajor_times_vector(
int size, int size,
const Scalar* lhs, int lhsStride, const Scalar* lhs, int lhsStride,
const RhsType& rhs, const RhsType& rhs,
@ -47,10 +48,14 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_colmajor_times_vector(
ei_pstore(&res[j], \ ei_pstore(&res[j], \
ei_padd(ei_pload(&res[j]), \ ei_padd(ei_pload(&res[j]), \
ei_padd( \ ei_padd( \
ei_padd(ei_pmul(ptmp0,EIGEN_CAT(ei_ploa , A0)(&lhs0[j])), \ ei_padd(cj.pmul(EIGEN_CAT(ei_ploa , A0)(&lhs0[j]), ptmp0), \
ei_pmul(ptmp1,EIGEN_CAT(ei_ploa , A13)(&lhs1[j]))), \ cj.pmul(EIGEN_CAT(ei_ploa , A13)(&lhs1[j]), ptmp1)), \
ei_padd(ei_pmul(ptmp2,EIGEN_CAT(ei_ploa , A2)(&lhs2[j])), \ ei_padd(cj.pmul(EIGEN_CAT(ei_ploa , A2)(&lhs2[j]), ptmp2), \
ei_pmul(ptmp3,EIGEN_CAT(ei_ploa , A13)(&lhs3[j]))) ))) cj.pmul(EIGEN_CAT(ei_ploa , A13)(&lhs3[j]), ptmp3)) )))
ei_conj_helper<ConjugateLhs,ConjugateRhs> cj;
if(ConjugateRhs)
alpha = ei_conj(alpha);
typedef typename ei_packet_traits<Scalar>::type Packet; typedef typename ei_packet_traits<Scalar>::type Packet;
const int PacketSize = sizeof(Packet)/sizeof(Scalar); const int PacketSize = sizeof(Packet)/sizeof(Scalar);
@ -109,7 +114,7 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_colmajor_times_vector(
ptmp2 = ei_pset1(alpha*rhs[i+2]), ptmp3 = ei_pset1(alpha*rhs[i+offset3]); ptmp2 = ei_pset1(alpha*rhs[i+2]), ptmp3 = ei_pset1(alpha*rhs[i+offset3]);
// this helps a lot generating better binary code // this helps a lot generating better binary code
const Scalar *lhs0 = lhs + i*lhsStride, *lhs1 = lhs + (i+offset1)*lhsStride, const Scalar *lhs0 = lhs + i*lhsStride, *lhs1 = lhs + (i+offset1)*lhsStride,
*lhs2 = lhs + (i+2)*lhsStride, *lhs3 = lhs + (i+offset3)*lhsStride; *lhs2 = lhs + (i+2)*lhsStride, *lhs3 = lhs + (i+offset3)*lhsStride;
if (PacketSize>1) if (PacketSize>1)
@ -117,7 +122,13 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_colmajor_times_vector(
/* explicit vectorization */ /* explicit vectorization */
// process initial unaligned coeffs // process initial unaligned coeffs
for (int j=0; j<alignedStart; ++j) for (int j=0; j<alignedStart; ++j)
res[j] += ei_pfirst(ptmp0)*lhs0[j] + ei_pfirst(ptmp1)*lhs1[j] + ei_pfirst(ptmp2)*lhs2[j] + ei_pfirst(ptmp3)*lhs3[j]; {
res[j] = cj.pmadd(lhs0[j], ei_pfirst(ptmp0), res[j]);
res[j] = cj.pmadd(lhs1[j], ei_pfirst(ptmp1), res[j]);
res[j] = cj.pmadd(lhs2[j], ei_pfirst(ptmp2), res[j]);
res[j] = cj.pmadd(lhs3[j], ei_pfirst(ptmp3), res[j]);
// res[j] += ei_pfirst(ptmp0)*lhs0[j] + ei_pfirst(ptmp1)*lhs1[j] + ei_pfirst(ptmp2)*lhs2[j] + ei_pfirst(ptmp3)*lhs3[j];
}
if (alignedSize>alignedStart) if (alignedSize>alignedStart)
{ {
@ -148,19 +159,19 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_colmajor_times_vector(
A00 = ei_pload (&lhs0[j]); A00 = ei_pload (&lhs0[j]);
A10 = ei_pload (&lhs0[j+PacketSize]); A10 = ei_pload (&lhs0[j+PacketSize]);
A00 = ei_pmadd(ptmp0, A00, ei_pload(&res[j])); A00 = cj.pmadd(A00, ptmp0, ei_pload(&res[j]));
A10 = ei_pmadd(ptmp0, A10, ei_pload(&res[j+PacketSize])); A10 = cj.pmadd(A10, ptmp0, ei_pload(&res[j+PacketSize]));
A00 = ei_pmadd(ptmp1, A01, A00); A00 = cj.pmadd(A01, ptmp1, A00);
A01 = ei_pload(&lhs1[j-1+2*PacketSize]); ei_palign<1>(A11,A01); A01 = ei_pload(&lhs1[j-1+2*PacketSize]); ei_palign<1>(A11,A01);
A00 = ei_pmadd(ptmp2, A02, A00); A00 = cj.pmadd(A02, ptmp2, A00);
A02 = ei_pload(&lhs2[j-2+2*PacketSize]); ei_palign<2>(A12,A02); A02 = ei_pload(&lhs2[j-2+2*PacketSize]); ei_palign<2>(A12,A02);
A00 = ei_pmadd(ptmp3, A03, A00); A00 = cj.pmadd(A03, ptmp3, A00);
ei_pstore(&res[j],A00); ei_pstore(&res[j],A00);
A03 = ei_pload(&lhs3[j-3+2*PacketSize]); ei_palign<3>(A13,A03); A03 = ei_pload(&lhs3[j-3+2*PacketSize]); ei_palign<3>(A13,A03);
A10 = ei_pmadd(ptmp1, A11, A10); A10 = cj.pmadd(A11, ptmp1, A10);
A10 = ei_pmadd(ptmp2, A12, A10); A10 = cj.pmadd(A12, ptmp2, A10);
A10 = ei_pmadd(ptmp3, A13, A10); A10 = cj.pmadd(A13, ptmp3, A10);
ei_pstore(&res[j+PacketSize],A10); ei_pstore(&res[j+PacketSize],A10);
} }
} }
@ -177,7 +188,13 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_colmajor_times_vector(
/* process remaining coeffs (or all if there is no explicit vectorization) */ /* process remaining coeffs (or all if there is no explicit vectorization) */
for (int j=alignedSize; j<size; ++j) for (int j=alignedSize; j<size; ++j)
res[j] += ei_pfirst(ptmp0)*lhs0[j] + ei_pfirst(ptmp1)*lhs1[j] + ei_pfirst(ptmp2)*lhs2[j] + ei_pfirst(ptmp3)*lhs3[j]; {
res[j] = cj.pmadd(lhs0[j], ei_pfirst(ptmp0), res[j]);
res[j] = cj.pmadd(lhs1[j], ei_pfirst(ptmp1), res[j]);
res[j] = cj.pmadd(lhs2[j], ei_pfirst(ptmp2), res[j]);
res[j] = cj.pmadd(lhs3[j], ei_pfirst(ptmp3), res[j]);
// res[j] += ei_pfirst(ptmp0)*lhs0[j] + ei_pfirst(ptmp1)*lhs1[j] + ei_pfirst(ptmp2)*lhs2[j] + ei_pfirst(ptmp3)*lhs3[j];
}
} }
// process remaining first and last columns (at most columnsAtOnce-1) // process remaining first and last columns (at most columnsAtOnce-1)
@ -195,20 +212,20 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_colmajor_times_vector(
/* explicit vectorization */ /* explicit vectorization */
// process first unaligned result's coeffs // process first unaligned result's coeffs
// Accumulate into res[j] rather than overwriting it: the aligned packet loop and
// the scalar tail loop of this same section both add to res, so the unaligned
// head must do the same (a plain '=' would discard what res already holds).
for (int j=0; j<alignedStart; ++j)
  res[j] += cj.pmul(lhs0[j], ei_pfirst(ptmp0));
// process aligned result's coeffs // process aligned result's coeffs
if ((size_t(lhs0+alignedStart)%sizeof(Packet))==0) if ((size_t(lhs0+alignedStart)%sizeof(Packet))==0)
for (int j = alignedStart;j<alignedSize;j+=PacketSize) for (int j = alignedStart;j<alignedSize;j+=PacketSize)
ei_pstore(&res[j], ei_pmadd(ptmp0,ei_pload(&lhs0[j]),ei_pload(&res[j]))); ei_pstore(&res[j], cj.pmadd(ei_pload(&lhs0[j]), ptmp0, ei_pload(&res[j])));
else else
for (int j = alignedStart;j<alignedSize;j+=PacketSize) for (int j = alignedStart;j<alignedSize;j+=PacketSize)
ei_pstore(&res[j], ei_pmadd(ptmp0,ei_ploadu(&lhs0[j]),ei_pload(&res[j]))); ei_pstore(&res[j], cj.pmadd(ei_ploadu(&lhs0[j]), ptmp0, ei_pload(&res[j])));
} }
// process remaining scalars (or all if no explicit vectorization) // process remaining scalars (or all if no explicit vectorization)
for (int j=alignedSize; j<size; ++j) for (int j=alignedSize; j<size; ++j)
res[j] += ei_pfirst(ptmp0) * lhs0[j]; res[j] += cj.pmul(lhs0[j], ei_pfirst(ptmp0));
} }
if (skipColumns) if (skipColumns)
{ {
@ -223,7 +240,7 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_colmajor_times_vector(
} }
// TODO add peeling to mask unaligned load/stores // TODO add peeling to mask unaligned load/stores
template<typename Scalar, typename ResType> template<bool ConjugateLhs, bool ConjugateRhs, typename Scalar, typename ResType>
static EIGEN_DONT_INLINE void ei_cache_friendly_product_rowmajor_times_vector( static EIGEN_DONT_INLINE void ei_cache_friendly_product_rowmajor_times_vector(
const Scalar* lhs, int lhsStride, const Scalar* lhs, int lhsStride,
const Scalar* rhs, int rhsSize, const Scalar* rhs, int rhsSize,
@ -236,10 +253,12 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_rowmajor_times_vector(
#define _EIGEN_ACCUMULATE_PACKETS(A0,A13,A2) {\ #define _EIGEN_ACCUMULATE_PACKETS(A0,A13,A2) {\
Packet b = ei_pload(&rhs[j]); \ Packet b = ei_pload(&rhs[j]); \
ptmp0 = ei_pmadd(b, EIGEN_CAT(ei_ploa,A0) (&lhs0[j]), ptmp0); \ ptmp0 = cj.pmadd(EIGEN_CAT(ei_ploa,A0) (&lhs0[j]), b, ptmp0); \
ptmp1 = ei_pmadd(b, EIGEN_CAT(ei_ploa,A13)(&lhs1[j]), ptmp1); \ ptmp1 = cj.pmadd(EIGEN_CAT(ei_ploa,A13)(&lhs1[j]), b, ptmp1); \
ptmp2 = ei_pmadd(b, EIGEN_CAT(ei_ploa,A2) (&lhs2[j]), ptmp2); \ ptmp2 = cj.pmadd(EIGEN_CAT(ei_ploa,A2) (&lhs2[j]), b, ptmp2); \
ptmp3 = ei_pmadd(b, EIGEN_CAT(ei_ploa,A13)(&lhs3[j]), ptmp3); } ptmp3 = cj.pmadd(EIGEN_CAT(ei_ploa,A13)(&lhs3[j]), b, ptmp3); }
ei_conj_helper<ConjugateLhs,ConjugateRhs> cj;
typedef typename ei_packet_traits<Scalar>::type Packet; typedef typename ei_packet_traits<Scalar>::type Packet;
const int PacketSize = sizeof(Packet)/sizeof(Scalar); const int PacketSize = sizeof(Packet)/sizeof(Scalar);
@ -311,7 +330,8 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_rowmajor_times_vector(
for (int j=0; j<alignedStart; ++j) for (int j=0; j<alignedStart; ++j)
{ {
Scalar b = rhs[j]; Scalar b = rhs[j];
tmp0 += b*lhs0[j]; tmp1 += b*lhs1[j]; tmp2 += b*lhs2[j]; tmp3 += b*lhs3[j]; tmp0 += cj.pmul(lhs0[j],b); tmp1 += cj.pmul(lhs1[j],b);
tmp2 += cj.pmul(lhs2[j],b); tmp3 += cj.pmul(lhs3[j],b);
} }
if (alignedSize>alignedStart) if (alignedSize>alignedStart)
@ -347,19 +367,19 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_rowmajor_times_vector(
A12 = ei_pload(&lhs2[j-2+PacketSize]); ei_palign<2>(A02,A12); A12 = ei_pload(&lhs2[j-2+PacketSize]); ei_palign<2>(A02,A12);
A13 = ei_pload(&lhs3[j-3+PacketSize]); ei_palign<3>(A03,A13); A13 = ei_pload(&lhs3[j-3+PacketSize]); ei_palign<3>(A03,A13);
ptmp0 = ei_pmadd(b, ei_pload (&lhs0[j]), ptmp0); ptmp0 = cj.pmadd(ei_pload (&lhs0[j]), b, ptmp0);
ptmp1 = ei_pmadd(b, A01, ptmp1); ptmp1 = cj.pmadd(A01, b, ptmp1);
A01 = ei_pload(&lhs1[j-1+2*PacketSize]); ei_palign<1>(A11,A01); A01 = ei_pload(&lhs1[j-1+2*PacketSize]); ei_palign<1>(A11,A01);
ptmp2 = ei_pmadd(b, A02, ptmp2); ptmp2 = cj.pmadd(A02, b, ptmp2);
A02 = ei_pload(&lhs2[j-2+2*PacketSize]); ei_palign<2>(A12,A02); A02 = ei_pload(&lhs2[j-2+2*PacketSize]); ei_palign<2>(A12,A02);
ptmp3 = ei_pmadd(b, A03, ptmp3); ptmp3 = cj.pmadd(A03, b, ptmp3);
A03 = ei_pload(&lhs3[j-3+2*PacketSize]); ei_palign<3>(A13,A03); A03 = ei_pload(&lhs3[j-3+2*PacketSize]); ei_palign<3>(A13,A03);
b = ei_pload(&rhs[j+PacketSize]); b = ei_pload(&rhs[j+PacketSize]);
ptmp0 = ei_pmadd(b, ei_pload (&lhs0[j+PacketSize]), ptmp0); ptmp0 = cj.pmadd(ei_pload (&lhs0[j+PacketSize]), b, ptmp0);
ptmp1 = ei_pmadd(b, A11, ptmp1); ptmp1 = cj.pmadd(A11, b, ptmp1);
ptmp2 = ei_pmadd(b, A12, ptmp2); ptmp2 = cj.pmadd(A12, b, ptmp2);
ptmp3 = ei_pmadd(b, A13, ptmp3); ptmp3 = cj.pmadd(A13, b, ptmp3);
} }
} }
for (int j = peeledSize; j<alignedSize; j+=PacketSize) for (int j = peeledSize; j<alignedSize; j+=PacketSize)
@ -382,7 +402,8 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_rowmajor_times_vector(
for (int j=alignedSize; j<size; ++j) for (int j=alignedSize; j<size; ++j)
{ {
Scalar b = rhs[j]; Scalar b = rhs[j];
tmp0 += b*lhs0[j]; tmp1 += b*lhs1[j]; tmp2 += b*lhs2[j]; tmp3 += b*lhs3[j]; tmp0 += cj.pmul(lhs0[j],b); tmp1 += cj.pmul(lhs1[j],b);
tmp2 += cj.pmul(lhs2[j],b); tmp3 += cj.pmul(lhs3[j],b);
} }
res[i] += alpha*tmp0; res[i+offset1] += alpha*tmp1; res[i+2] += alpha*tmp2; res[i+offset3] += alpha*tmp3; res[i] += alpha*tmp0; res[i+offset1] += alpha*tmp1; res[i+2] += alpha*tmp2; res[i+offset3] += alpha*tmp3;
} }
@ -400,24 +421,24 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_rowmajor_times_vector(
// process first unaligned result's coeffs // process first unaligned result's coeffs
// FIXME this loop get vectorized by the compiler ! // FIXME this loop get vectorized by the compiler !
for (int j=0; j<alignedStart; ++j) for (int j=0; j<alignedStart; ++j)
tmp0 += rhs[j] * lhs0[j]; tmp0 += cj.pmul(lhs0[j], rhs[j]);
if (alignedSize>alignedStart) if (alignedSize>alignedStart)
{ {
// process aligned rhs coeffs // process aligned rhs coeffs
if ((size_t(lhs0+alignedStart)%sizeof(Packet))==0) if ((size_t(lhs0+alignedStart)%sizeof(Packet))==0)
for (int j = alignedStart;j<alignedSize;j+=PacketSize) for (int j = alignedStart;j<alignedSize;j+=PacketSize)
ptmp0 = ei_pmadd(ei_pload(&rhs[j]), ei_pload(&lhs0[j]), ptmp0); ptmp0 = cj.pmadd(ei_pload(&lhs0[j]), ei_pload(&rhs[j]), ptmp0);
else else
for (int j = alignedStart;j<alignedSize;j+=PacketSize) for (int j = alignedStart;j<alignedSize;j+=PacketSize)
ptmp0 = ei_pmadd(ei_pload(&rhs[j]), ei_ploadu(&lhs0[j]), ptmp0); ptmp0 = cj.pmadd(ei_ploadu(&lhs0[j]), ei_pload(&rhs[j]), ptmp0);
tmp0 += ei_predux(ptmp0); tmp0 += ei_predux(ptmp0);
} }
// process remaining scalars // process remaining scalars
// FIXME this loop get vectorized by the compiler ! // FIXME this loop get vectorized by the compiler !
for (int j=alignedSize; j<size; ++j) for (int j=alignedSize; j<size; ++j)
tmp0 += rhs[j] * lhs0[j]; tmp0 += cj.pmul(lhs0[j], rhs[j]);
res[i] += alpha*tmp0; res[i] += alpha*tmp0;
} }
if (skipRows) if (skipRows)

View File

@ -98,6 +98,7 @@ ei_add_test(redux)
ei_add_test(product_small) ei_add_test(product_small)
ei_add_test(product_large ${EI_OFLAG}) ei_add_test(product_large ${EI_OFLAG})
ei_add_test(product_selfadjoint) ei_add_test(product_selfadjoint)
ei_add_test(product_extra)
ei_add_test(diagonalmatrices) ei_add_test(diagonalmatrices)
ei_add_test(adjoint) ei_add_test(adjoint)
ei_add_test(submatrices) ei_add_test(submatrices)

120
test/product_extra.cpp Normal file
View File

@ -0,0 +1,120 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2006-2008 Benoit Jacob <jacob.benoit.1@gmail.com>
//
// Eigen is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License as published by the Free Software Foundation; either
// version 3 of the License, or (at your option) any later version.
//
// Alternatively, you can redistribute it and/or
// modify it under the terms of the GNU General Public License as
// published by the Free Software Foundation; either version 2 of
// the License, or (at your option) any later version.
//
// Eigen is distributed in the hope that it will be useful, but WITHOUT ANY
// WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
// FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License and a copy of the GNU General Public License along with
// Eigen. If not, see <http://www.gnu.org/licenses/>.
#include "main.h"
#include <Eigen/Array>
// Checks that products whose factors are wrapped in conjugate(), adjoint(),
// transpose(), unary minus and/or scalar multiples give the same result as the
// product of the explicitly evaluated (.eval()) factors.
//
// m: only its dimensions (rows(), cols()) are read; all operands are random.
template<typename MatrixType> void product_extra(const MatrixType& m)
{
  typedef typename MatrixType::Scalar Scalar;
  typedef Matrix<Scalar, 1, Dynamic> RowVectorType;
  typedef Matrix<Scalar, Dynamic, 1> ColVectorType;

  const int rows = m.rows();
  const int cols = m.cols();

  // random operands: two rows x cols matrices and one rows x rows matrix
  MatrixType m1 = MatrixType::Random(rows, cols),
             m2 = MatrixType::Random(rows, cols),
             square = MatrixType::Random(rows, rows);
  // a row vector of size rows and a column vector of size cols
  RowVectorType v1 = RowVectorType::Random(rows);
  ColVectorType vc2 = ColVectorType::Random(cols);

  // random scalar factors
  Scalar s1 = ei_random<Scalar>(),
         s2 = ei_random<Scalar>(),
         s3 = ei_random<Scalar>();

  // all the expressions in this test should be compiled as a single matrix product
  // TODO: add internal checks to verify that
  VERIFY_IS_APPROX(m1 * m2.adjoint(), m1 * m2.adjoint().eval());
  VERIFY_IS_APPROX(m1.adjoint() * square.adjoint(), m1.adjoint().eval() * square.adjoint().eval());
  VERIFY_IS_APPROX(m1.adjoint() * m2, m1.adjoint().eval() * m2);
  VERIFY_IS_APPROX( (s1 * m1.adjoint()) * m2, (s1 * m1.adjoint()).eval() * m2);
  VERIFY_IS_APPROX( (- m1.adjoint() * s1) * (s3 * m2), (- m1.adjoint() * s1).eval() * (s3 * m2).eval());
  VERIFY_IS_APPROX( (s2 * m1.adjoint() * s1) * m2, (s2 * m1.adjoint() * s1).eval() * m2);
  // the rhs factor is parenthesized on purpose: without the parentheses the
  // expression parses as ((-m1*s2)*s1) * m2.adjoint() and the scalar multiple
  // on the right-hand side factor is never exercised
  VERIFY_IS_APPROX( (-m1*s2) * (s1*m2.adjoint()), (-m1*s2).eval() * (s1*m2.adjoint()).eval());

  // a very tricky case where a scale factor has to be automatically conjugated:
  VERIFY_IS_APPROX( m1.adjoint() * (s1*m2).conjugate(), (m1.adjoint()).eval() * ((s1*m2).conjugate()).eval());

  // test all possible conjugate combinations for the four matrix-vector product cases:

  // (a) matrix * column vector
  VERIFY_IS_APPROX((-m1.conjugate() * s2) * (s1 * vc2),
                   (-m1.conjugate()*s2).eval() * (s1 * vc2).eval());
  VERIFY_IS_APPROX((-m1 * s2) * (s1 * vc2.conjugate()),
                   (-m1*s2).eval() * (s1 * vc2.conjugate()).eval());
  VERIFY_IS_APPROX((-m1.conjugate() * s2) * (s1 * vc2.conjugate()),
                   (-m1.conjugate()*s2).eval() * (s1 * vc2.conjugate()).eval());

  // (b) transposed column vector * transposed matrix
  VERIFY_IS_APPROX((s1 * vc2.transpose()) * (-m1.adjoint() * s2),
                   (s1 * vc2.transpose()).eval() * (-m1.adjoint()*s2).eval());
  VERIFY_IS_APPROX((s1 * vc2.adjoint()) * (-m1.transpose() * s2),
                   (s1 * vc2.adjoint()).eval() * (-m1.transpose()*s2).eval());
  VERIFY_IS_APPROX((s1 * vc2.adjoint()) * (-m1.adjoint() * s2),
                   (s1 * vc2.adjoint()).eval() * (-m1.adjoint()*s2).eval());

  // (c) transposed matrix * transposed row vector
  VERIFY_IS_APPROX((-m1.adjoint() * s2) * (s1 * v1.transpose()),
                   (-m1.adjoint()*s2).eval() * (s1 * v1.transpose()).eval());
  VERIFY_IS_APPROX((-m1.transpose() * s2) * (s1 * v1.adjoint()),
                   (-m1.transpose()*s2).eval() * (s1 * v1.adjoint()).eval());
  VERIFY_IS_APPROX((-m1.adjoint() * s2) * (s1 * v1.adjoint()),
                   (-m1.adjoint()*s2).eval() * (s1 * v1.adjoint()).eval());

  // (d) row vector * matrix
  VERIFY_IS_APPROX((s1 * v1) * (-m1.conjugate() * s2),
                   (s1 * v1).eval() * (-m1.conjugate()*s2).eval());
  VERIFY_IS_APPROX((s1 * v1.conjugate()) * (-m1 * s2),
                   (s1 * v1.conjugate()).eval() * (-m1*s2).eval());
  VERIFY_IS_APPROX((s1 * v1.conjugate()) * (-m1.conjugate() * s2),
                   (s1 * v1.conjugate()).eval() * (-m1.conjugate()*s2).eval());
}
// Entry point of the product_extra test: runs product_extra() on a
// dynamic-size MatrixXcf. The repeat loop and the subtests for other matrix
// types are currently commented out, so only this one case is executed.
void test_product_extra()
{
// Disabled: repetition over g_repeat and subtests for other scalar types / storage orders.
// NOTE(review): three of the disabled lines call product(), not product_extra() --
// presumably left over from another test; fix the callee before re-enabling them.
// for(int i = 0; i < g_repeat; i++) {
// CALL_SUBTEST( product_extra(MatrixXf(ei_random<int>(1,320), ei_random<int>(1,320))) );
// CALL_SUBTEST( product(MatrixXd(ei_random<int>(1,320), ei_random<int>(1,320))) );
// CALL_SUBTEST( product(MatrixXi(ei_random<int>(1,320), ei_random<int>(1,320))) );
// Both bounds passed to ei_random<int> are 50, so presumably the matrix is always
// 50x50 -- TODO confirm against ei_random's definition.
CALL_SUBTEST( product_extra(MatrixXcf(ei_random<int>(50,50), ei_random<int>(50,50))) );
// CALL_SUBTEST( product(Matrix<float,Dynamic,Dynamic,RowMajor>(ei_random<int>(1,320), ei_random<int>(1,320))) );
// }
}