mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 18:54:00 +08:00
conjugate expressions are now properly caught by Product
=> significant speedup in expr. like a.adjoint() * b, for complex scalar type (~ x3)
This commit is contained in:
parent
5ed6ce90d3
commit
13b2dafb50
@ -96,7 +96,8 @@ class CwiseUnaryOp : ei_no_assignment_operator,
|
|||||||
const UnaryOp& _functor() const { return m_functor; }
|
const UnaryOp& _functor() const { return m_functor; }
|
||||||
|
|
||||||
/** \internal used for introspection */
|
/** \internal used for introspection */
|
||||||
const typename MatrixType::Nested& _expression() const { return m_matrix; }
|
const typename ei_cleantype<typename MatrixType::Nested>::type&
|
||||||
|
_expression() const { return m_matrix; }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
const typename MatrixType::Nested m_matrix;
|
const typename MatrixType::Nested m_matrix;
|
||||||
|
@ -65,9 +65,8 @@ struct ProductReturnType
|
|||||||
template<typename Lhs, typename Rhs>
|
template<typename Lhs, typename Rhs>
|
||||||
struct ProductReturnType<Lhs,Rhs,CacheFriendlyProduct>
|
struct ProductReturnType<Lhs,Rhs,CacheFriendlyProduct>
|
||||||
{
|
{
|
||||||
typedef typename ei_nested<Lhs,Rhs::ColsAtCompileTime>::type LhsNested;
|
typedef typename ei_nested<Lhs,1>::type LhsNested;
|
||||||
|
typedef typename ei_nested<Rhs,1,
|
||||||
typedef typename ei_nested<Rhs,Lhs::RowsAtCompileTime,
|
|
||||||
typename ei_plain_matrix_type_column_major<Rhs>::type
|
typename ei_plain_matrix_type_column_major<Rhs>::type
|
||||||
>::type RhsNested;
|
>::type RhsNested;
|
||||||
|
|
||||||
@ -95,14 +94,14 @@ template<typename Lhs, typename Rhs> struct ei_product_mode
|
|||||||
template<typename XprType> struct ei_product_factor_traits
|
template<typename XprType> struct ei_product_factor_traits
|
||||||
{
|
{
|
||||||
typedef typename ei_traits<XprType>::Scalar Scalar;
|
typedef typename ei_traits<XprType>::Scalar Scalar;
|
||||||
typedef XprType RealXprType;
|
typedef XprType ActualXprType;
|
||||||
enum {
|
enum {
|
||||||
IsComplex = NumTraits<Scalar>::IsComplex,
|
IsComplex = NumTraits<Scalar>::IsComplex,
|
||||||
NeedToConjugate = false,
|
NeedToConjugate = false,
|
||||||
HasScalarMultiple = false,
|
HasScalarMultiple = false,
|
||||||
Access = int(ei_traits<XprType>::Flags)&DirectAccessBit ? HasDirectAccess : NoDirectAccess
|
Access = int(ei_traits<XprType>::Flags)&DirectAccessBit ? HasDirectAccess : NoDirectAccess
|
||||||
};
|
};
|
||||||
static inline const RealXprType& extract(const XprType& x) { return x; }
|
static inline const ActualXprType& extract(const XprType& x) { return x; }
|
||||||
static inline Scalar extractSalarFactor(const XprType&) { return Scalar(1); }
|
static inline Scalar extractSalarFactor(const XprType&) { return Scalar(1); }
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -112,13 +111,13 @@ template<typename Scalar, typename NestedXpr> struct ei_product_factor_traits<Cw
|
|||||||
{
|
{
|
||||||
typedef ei_product_factor_traits<NestedXpr> Base;
|
typedef ei_product_factor_traits<NestedXpr> Base;
|
||||||
typedef CwiseUnaryOp<ei_scalar_conjugate_op<Scalar>, NestedXpr> XprType;
|
typedef CwiseUnaryOp<ei_scalar_conjugate_op<Scalar>, NestedXpr> XprType;
|
||||||
typedef typename Base::RealXprType RealXprType;
|
typedef typename Base::ActualXprType ActualXprType;
|
||||||
|
|
||||||
enum {
|
enum {
|
||||||
IsComplex = NumTraits<Scalar>::IsComplex,
|
IsComplex = NumTraits<Scalar>::IsComplex,
|
||||||
NeedToConjugate = IsComplex
|
NeedToConjugate = IsComplex
|
||||||
};
|
};
|
||||||
static inline const RealXprType& extract(const XprType& x) { return x._expression(); }
|
static inline const ActualXprType& extract(const XprType& x) { return x._expression(); }
|
||||||
static inline Scalar extractSalarFactor(const XprType& x) { return Base::extractSalarFactor(x._expression()); }
|
static inline Scalar extractSalarFactor(const XprType& x) { return Base::extractSalarFactor(x._expression()); }
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -128,12 +127,12 @@ template<typename Scalar, typename NestedXpr> struct ei_product_factor_traits<Cw
|
|||||||
{
|
{
|
||||||
typedef ei_product_factor_traits<NestedXpr> Base;
|
typedef ei_product_factor_traits<NestedXpr> Base;
|
||||||
typedef CwiseUnaryOp<ei_scalar_multiple_op<Scalar>, NestedXpr> XprType;
|
typedef CwiseUnaryOp<ei_scalar_multiple_op<Scalar>, NestedXpr> XprType;
|
||||||
typedef typename Base::RealXprType RealXprType;
|
typedef typename Base::ActualXprType ActualXprType;
|
||||||
enum {
|
enum {
|
||||||
HasScalarMultiple = true
|
HasScalarMultiple = true
|
||||||
};
|
};
|
||||||
static inline const RealXprType& extract(const XprType& x) { return x._expression(); }
|
static inline const ActualXprType& extract(const XprType& x) { return x._expression(); }
|
||||||
static inline Scalar extractSalarFactor(const XprType& x) { return x._functor().value; }
|
static inline Scalar extractSalarFactor(const XprType& x) { return x._functor().m_other; }
|
||||||
};
|
};
|
||||||
|
|
||||||
/** \class Product
|
/** \class Product
|
||||||
@ -819,18 +818,34 @@ template<typename Lhs, typename Rhs, int ProductMode>
|
|||||||
template<typename DestDerived>
|
template<typename DestDerived>
|
||||||
inline void Product<Lhs,Rhs,ProductMode>::_cacheFriendlyEvalAndAdd(DestDerived& res, Scalar alpha) const
|
inline void Product<Lhs,Rhs,ProductMode>::_cacheFriendlyEvalAndAdd(DestDerived& res, Scalar alpha) const
|
||||||
{
|
{
|
||||||
typedef typename ei_product_copy_lhs<_LhsNested>::type LhsCopy;
|
typedef ei_product_factor_traits<_LhsNested> LhsProductTraits;
|
||||||
|
typedef ei_product_factor_traits<_RhsNested> RhsProductTraits;
|
||||||
|
|
||||||
|
typedef typename LhsProductTraits::ActualXprType ActualLhsType;
|
||||||
|
typedef typename RhsProductTraits::ActualXprType ActualRhsType;
|
||||||
|
|
||||||
|
const ActualLhsType& actualLhs = LhsProductTraits::extract(m_lhs);
|
||||||
|
const ActualRhsType& actualRhs = RhsProductTraits::extract(m_rhs);
|
||||||
|
|
||||||
|
Scalar actualAlpha = alpha * LhsProductTraits::extractSalarFactor(m_lhs)
|
||||||
|
* RhsProductTraits::extractSalarFactor(m_rhs);
|
||||||
|
|
||||||
|
typedef typename ei_product_copy_lhs<ActualLhsType>::type LhsCopy;
|
||||||
typedef typename ei_unref<LhsCopy>::type _LhsCopy;
|
typedef typename ei_unref<LhsCopy>::type _LhsCopy;
|
||||||
typedef typename ei_product_copy_rhs<_RhsNested>::type RhsCopy;
|
typedef typename ei_product_copy_rhs<ActualRhsType>::type RhsCopy;
|
||||||
typedef typename ei_unref<RhsCopy>::type _RhsCopy;
|
typedef typename ei_unref<RhsCopy>::type _RhsCopy;
|
||||||
LhsCopy lhs(m_lhs);
|
LhsCopy lhs(actualLhs);
|
||||||
RhsCopy rhs(m_rhs);
|
RhsCopy rhs(actualRhs);
|
||||||
ei_cache_friendly_product<Scalar,false,false>(
|
ei_cache_friendly_product<Scalar,
|
||||||
|
// LhsProductTraits::NeedToConjugate,RhsProductTraits::NeedToConjugate>
|
||||||
|
((int(Flags)&RowMajorBit) ? bool(RhsProductTraits::NeedToConjugate) : bool(LhsProductTraits::NeedToConjugate)),
|
||||||
|
((int(Flags)&RowMajorBit) ? bool(LhsProductTraits::NeedToConjugate) : bool(RhsProductTraits::NeedToConjugate))>
|
||||||
|
(
|
||||||
rows(), cols(), lhs.cols(),
|
rows(), cols(), lhs.cols(),
|
||||||
_LhsCopy::Flags&RowMajorBit, (const Scalar*)&(lhs.const_cast_derived().coeffRef(0,0)), lhs.stride(),
|
_LhsCopy::Flags&RowMajorBit, (const Scalar*)&(lhs.const_cast_derived().coeffRef(0,0)), lhs.stride(),
|
||||||
_RhsCopy::Flags&RowMajorBit, (const Scalar*)&(rhs.const_cast_derived().coeffRef(0,0)), rhs.stride(),
|
_RhsCopy::Flags&RowMajorBit, (const Scalar*)&(rhs.const_cast_derived().coeffRef(0,0)), rhs.stride(),
|
||||||
Flags&RowMajorBit, (Scalar*)&(res.coeffRef(0,0)), res.stride(),
|
Flags&RowMajorBit, (Scalar*)&(res.coeffRef(0,0)), res.stride(),
|
||||||
alpha
|
actualAlpha
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -58,6 +58,9 @@ template<> struct ei_conj_pmadd<true,true>
|
|||||||
|
|
||||||
#ifndef EIGEN_EXTERN_INSTANTIATIONS
|
#ifndef EIGEN_EXTERN_INSTANTIATIONS
|
||||||
|
|
||||||
|
/** \warning you should never call this function directly,
|
||||||
|
* this is because the ConjugateLhs/ConjugateRhs have to
|
||||||
|
* be flipped is resRowMajor==true */
|
||||||
template<typename Scalar, bool ConjugateLhs, bool ConjugateRhs>
|
template<typename Scalar, bool ConjugateLhs, bool ConjugateRhs>
|
||||||
static void ei_cache_friendly_product(
|
static void ei_cache_friendly_product(
|
||||||
int _rows, int _cols, int depth,
|
int _rows, int _cols, int depth,
|
||||||
@ -76,6 +79,12 @@ static void ei_cache_friendly_product(
|
|||||||
|
|
||||||
if (resRowMajor)
|
if (resRowMajor)
|
||||||
{
|
{
|
||||||
|
// return ei_cache_friendly_product<Scalar,ConjugateRhs,ConjugateLhs>(_cols,_rows,depth,
|
||||||
|
// !_rhsRowMajor, _rhs, _rhsStride,
|
||||||
|
// !_lhsRowMajor, _lhs, _lhsStride,
|
||||||
|
// false, res, resStride,
|
||||||
|
// alpha);
|
||||||
|
|
||||||
lhs = _rhs;
|
lhs = _rhs;
|
||||||
rhs = _lhs;
|
rhs = _lhs;
|
||||||
lhsStride = _rhsStride;
|
lhsStride = _rhsStride;
|
||||||
@ -252,59 +261,59 @@ static void ei_cache_friendly_product(
|
|||||||
A1 = ei_pload(&blA[1*PacketSize]);
|
A1 = ei_pload(&blA[1*PacketSize]);
|
||||||
B0 = ei_pload(&blB[0*PacketSize]);
|
B0 = ei_pload(&blB[0*PacketSize]);
|
||||||
B1 = ei_pload(&blB[1*PacketSize]);
|
B1 = ei_pload(&blB[1*PacketSize]);
|
||||||
C0 = cj_pmadd(B0, A0, C0);
|
C0 = cj_pmadd(A0, B0, C0);
|
||||||
if(nr==4) B2 = ei_pload(&blB[2*PacketSize]);
|
if(nr==4) B2 = ei_pload(&blB[2*PacketSize]);
|
||||||
C4 = cj_pmadd(B0, A1, C4);
|
C4 = cj_pmadd(A1, B0, C4);
|
||||||
if(nr==4) B3 = ei_pload(&blB[3*PacketSize]);
|
if(nr==4) B3 = ei_pload(&blB[3*PacketSize]);
|
||||||
B0 = ei_pload(&blB[(nr==4 ? 4 : 2)*PacketSize]);
|
B0 = ei_pload(&blB[(nr==4 ? 4 : 2)*PacketSize]);
|
||||||
C1 = cj_pmadd(B1, A0, C1);
|
C1 = cj_pmadd(A0, B1, C1);
|
||||||
C5 = cj_pmadd(B1, A1, C5);
|
C5 = cj_pmadd(A1, B1, C5);
|
||||||
B1 = ei_pload(&blB[(nr==4 ? 5 : 3)*PacketSize]);
|
B1 = ei_pload(&blB[(nr==4 ? 5 : 3)*PacketSize]);
|
||||||
if(nr==4) C2 = cj_pmadd(B2, A0, C2);
|
if(nr==4) C2 = cj_pmadd(A0, B2, C2);
|
||||||
if(nr==4) C6 = cj_pmadd(B2, A1, C6);
|
if(nr==4) C6 = cj_pmadd(A1, B2, C6);
|
||||||
if(nr==4) B2 = ei_pload(&blB[6*PacketSize]);
|
if(nr==4) B2 = ei_pload(&blB[6*PacketSize]);
|
||||||
if(nr==4) C3 = cj_pmadd(B3, A0, C3);
|
if(nr==4) C3 = cj_pmadd(A0, B3, C3);
|
||||||
A0 = ei_pload(&blA[2*PacketSize]);
|
A0 = ei_pload(&blA[2*PacketSize]);
|
||||||
if(nr==4) C7 = cj_pmadd(B3, A1, C7);
|
if(nr==4) C7 = cj_pmadd(A1, B3, C7);
|
||||||
A1 = ei_pload(&blA[3*PacketSize]);
|
A1 = ei_pload(&blA[3*PacketSize]);
|
||||||
if(nr==4) B3 = ei_pload(&blB[7*PacketSize]);
|
if(nr==4) B3 = ei_pload(&blB[7*PacketSize]);
|
||||||
C0 = cj_pmadd(B0, A0, C0);
|
C0 = cj_pmadd(A0, B0, C0);
|
||||||
C4 = cj_pmadd(B0, A1, C4);
|
C4 = cj_pmadd(A1, B0, C4);
|
||||||
B0 = ei_pload(&blB[(nr==4 ? 8 : 4)*PacketSize]);
|
B0 = ei_pload(&blB[(nr==4 ? 8 : 4)*PacketSize]);
|
||||||
C1 = cj_pmadd(B1, A0, C1);
|
C1 = cj_pmadd(A0, B1, C1);
|
||||||
C5 = cj_pmadd(B1, A1, C5);
|
C5 = cj_pmadd(A1, B1, C5);
|
||||||
B1 = ei_pload(&blB[(nr==4 ? 9 : 5)*PacketSize]);
|
B1 = ei_pload(&blB[(nr==4 ? 9 : 5)*PacketSize]);
|
||||||
if(nr==4) C2 = cj_pmadd(B2, A0, C2);
|
if(nr==4) C2 = cj_pmadd(A0, B2, C2);
|
||||||
if(nr==4) C6 = cj_pmadd(B2, A1, C6);
|
if(nr==4) C6 = cj_pmadd(A1, B2, C6);
|
||||||
if(nr==4) B2 = ei_pload(&blB[10*PacketSize]);
|
if(nr==4) B2 = ei_pload(&blB[10*PacketSize]);
|
||||||
if(nr==4) C3 = cj_pmadd(B3, A0, C3);
|
if(nr==4) C3 = cj_pmadd(A0, B3, C3);
|
||||||
A0 = ei_pload(&blA[4*PacketSize]);
|
A0 = ei_pload(&blA[4*PacketSize]);
|
||||||
if(nr==4) C7 = cj_pmadd(B3, A1, C7);
|
if(nr==4) C7 = cj_pmadd(A1, B3, C7);
|
||||||
A1 = ei_pload(&blA[5*PacketSize]);
|
A1 = ei_pload(&blA[5*PacketSize]);
|
||||||
if(nr==4) B3 = ei_pload(&blB[11*PacketSize]);
|
if(nr==4) B3 = ei_pload(&blB[11*PacketSize]);
|
||||||
|
|
||||||
C0 = cj_pmadd(B0, A0, C0);
|
C0 = cj_pmadd(A0, B0, C0);
|
||||||
C4 = cj_pmadd(B0, A1, C4);
|
C4 = cj_pmadd(A1, B0, C4);
|
||||||
B0 = ei_pload(&blB[(nr==4 ? 12 : 6)*PacketSize]);
|
B0 = ei_pload(&blB[(nr==4 ? 12 : 6)*PacketSize]);
|
||||||
C1 = cj_pmadd(B1, A0, C1);
|
C1 = cj_pmadd(A0, B1, C1);
|
||||||
C5 = cj_pmadd(B1, A1, C5);
|
C5 = cj_pmadd(A1, B1, C5);
|
||||||
B1 = ei_pload(&blB[(nr==4 ? 13 : 7)*PacketSize]);
|
B1 = ei_pload(&blB[(nr==4 ? 13 : 7)*PacketSize]);
|
||||||
if(nr==4) C2 = cj_pmadd(B2, A0, C2);
|
if(nr==4) C2 = cj_pmadd(A0, B2, C2);
|
||||||
if(nr==4) C6 = cj_pmadd(B2, A1, C6);
|
if(nr==4) C6 = cj_pmadd(A1, B2, C6);
|
||||||
if(nr==4) B2 = ei_pload(&blB[14*PacketSize]);
|
if(nr==4) B2 = ei_pload(&blB[14*PacketSize]);
|
||||||
if(nr==4) C3 = cj_pmadd(B3, A0, C3);
|
if(nr==4) C3 = cj_pmadd(A0, B3, C3);
|
||||||
A0 = ei_pload(&blA[6*PacketSize]);
|
A0 = ei_pload(&blA[6*PacketSize]);
|
||||||
if(nr==4) C7 = cj_pmadd(B3, A1, C7);
|
if(nr==4) C7 = cj_pmadd(A1, B3, C7);
|
||||||
A1 = ei_pload(&blA[7*PacketSize]);
|
A1 = ei_pload(&blA[7*PacketSize]);
|
||||||
if(nr==4) B3 = ei_pload(&blB[15*PacketSize]);
|
if(nr==4) B3 = ei_pload(&blB[15*PacketSize]);
|
||||||
C0 = cj_pmadd(B0, A0, C0);
|
C0 = cj_pmadd(A0, B0, C0);
|
||||||
C4 = cj_pmadd(B0, A1, C4);
|
C4 = cj_pmadd(A1, B0, C4);
|
||||||
C1 = cj_pmadd(B1, A0, C1);
|
C1 = cj_pmadd(A0, B1, C1);
|
||||||
C5 = cj_pmadd(B1, A1, C5);
|
C5 = cj_pmadd(A1, B1, C5);
|
||||||
if(nr==4) C2 = cj_pmadd(B2, A0, C2);
|
if(nr==4) C2 = cj_pmadd(A0, B2, C2);
|
||||||
if(nr==4) C6 = cj_pmadd(B2, A1, C6);
|
if(nr==4) C6 = cj_pmadd(A1, B2, C6);
|
||||||
if(nr==4) C3 = cj_pmadd(B3, A0, C3);
|
if(nr==4) C3 = cj_pmadd(A0, B3, C3);
|
||||||
if(nr==4) C7 = cj_pmadd(B3, A1, C7);
|
if(nr==4) C7 = cj_pmadd(A1, B3, C7);
|
||||||
|
|
||||||
blB += 4*nr*PacketSize;
|
blB += 4*nr*PacketSize;
|
||||||
blA += 4*mr;
|
blA += 4*mr;
|
||||||
@ -318,16 +327,16 @@ static void ei_cache_friendly_product(
|
|||||||
A1 = ei_pload(&blA[1*PacketSize]);
|
A1 = ei_pload(&blA[1*PacketSize]);
|
||||||
B0 = ei_pload(&blB[0*PacketSize]);
|
B0 = ei_pload(&blB[0*PacketSize]);
|
||||||
B1 = ei_pload(&blB[1*PacketSize]);
|
B1 = ei_pload(&blB[1*PacketSize]);
|
||||||
C0 = cj_pmadd(B0, A0, C0);
|
C0 = cj_pmadd(A0, B0, C0);
|
||||||
if(nr==4) B2 = ei_pload(&blB[2*PacketSize]);
|
if(nr==4) B2 = ei_pload(&blB[2*PacketSize]);
|
||||||
C4 = cj_pmadd(B0, A1, C4);
|
C4 = cj_pmadd(A1, B0, C4);
|
||||||
if(nr==4) B3 = ei_pload(&blB[3*PacketSize]);
|
if(nr==4) B3 = ei_pload(&blB[3*PacketSize]);
|
||||||
C1 = cj_pmadd(B1, A0, C1);
|
C1 = cj_pmadd(A0, B1, C1);
|
||||||
C5 = cj_pmadd(B1, A1, C5);
|
C5 = cj_pmadd(A1, B1, C5);
|
||||||
if(nr==4) C2 = cj_pmadd(B2, A0, C2);
|
if(nr==4) C2 = cj_pmadd(A0, B2, C2);
|
||||||
if(nr==4) C6 = cj_pmadd(B2, A1, C6);
|
if(nr==4) C6 = cj_pmadd(A1, B2, C6);
|
||||||
if(nr==4) C3 = cj_pmadd(B3, A0, C3);
|
if(nr==4) C3 = cj_pmadd(A0, B3, C3);
|
||||||
if(nr==4) C7 = cj_pmadd(B3, A1, C7);
|
if(nr==4) C7 = cj_pmadd(A1, B3, C7);
|
||||||
|
|
||||||
blB += nr*PacketSize;
|
blB += nr*PacketSize;
|
||||||
blA += mr;
|
blA += mr;
|
||||||
@ -359,12 +368,12 @@ static void ei_cache_friendly_product(
|
|||||||
A0 = blA[k];
|
A0 = blA[k];
|
||||||
B0 = blB[0*PacketSize];
|
B0 = blB[0*PacketSize];
|
||||||
B1 = blB[1*PacketSize];
|
B1 = blB[1*PacketSize];
|
||||||
C0 += B0 * A0;
|
C0 = cj_pmadd(A0, B0, C0);
|
||||||
if(nr==4) B2 = blB[2*PacketSize];
|
if(nr==4) B2 = blB[2*PacketSize];
|
||||||
if(nr==4) B3 = blB[3*PacketSize];
|
if(nr==4) B3 = blB[3*PacketSize];
|
||||||
C1 += B1 * A0;
|
C1 = cj_pmadd(A0, B1, C1);
|
||||||
if(nr==4) C2 += B2 * A0;
|
if(nr==4) C2 = cj_pmadd(A0, B2, C2);
|
||||||
if(nr==4) C3 += B3 * A0;
|
if(nr==4) C3 = cj_pmadd(A0, B3, C3);
|
||||||
|
|
||||||
blB += nr*PacketSize;
|
blB += nr*PacketSize;
|
||||||
}
|
}
|
||||||
@ -382,10 +391,10 @@ static void ei_cache_friendly_product(
|
|||||||
Scalar c0 = Scalar(0);
|
Scalar c0 = Scalar(0);
|
||||||
if (lhsRowMajor)
|
if (lhsRowMajor)
|
||||||
for(int k=0; k<actual_kc; k++)
|
for(int k=0; k<actual_kc; k++)
|
||||||
c0 += lhs[(k2+k)+(i2+i)*lhsStride] * rhs[j2*rhsStride + k2 + k];
|
c0 = cj_pmadd(lhs[(k2+k)+(i2+i)*lhsStride], rhs[j2*rhsStride + k2 + k], c0);
|
||||||
else
|
else
|
||||||
for(int k=0; k<actual_kc; k++)
|
for(int k=0; k<actual_kc; k++)
|
||||||
c0 += lhs[(k2+k)*lhsStride + i2+i] * rhs[j2*rhsStride + k2 + k];
|
c0 = cj_pmadd(lhs[(k2+k)*lhsStride + i2+i], rhs[j2*rhsStride + k2 + k], c0);
|
||||||
res[(j2)*resStride + i2+i] += alpha * c0;
|
res[(j2)*resStride + i2+i] += alpha * c0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -395,6 +404,8 @@ static void ei_cache_friendly_product(
|
|||||||
ei_aligned_stack_delete(Scalar, blockA, kc*mc);
|
ei_aligned_stack_delete(Scalar, blockA, kc*mc);
|
||||||
ei_aligned_stack_delete(Scalar, blockB, kc*cols*PacketSize);
|
ei_aligned_stack_delete(Scalar, blockB, kc*cols*PacketSize);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#else // alternate product from cylmor
|
#else // alternate product from cylmor
|
||||||
|
|
||||||
enum {
|
enum {
|
||||||
@ -482,39 +493,39 @@ static void ei_cache_friendly_product(
|
|||||||
L0 = ei_pload(&lb[1*PacketSize]);
|
L0 = ei_pload(&lb[1*PacketSize]);
|
||||||
R1 = ei_pload(&lb[2*PacketSize]);
|
R1 = ei_pload(&lb[2*PacketSize]);
|
||||||
L1 = ei_pload(&lb[3*PacketSize]);
|
L1 = ei_pload(&lb[3*PacketSize]);
|
||||||
T0 = cj_pmadd(R0, A0, T0);
|
T0 = cj_pmadd(A0, R0, T0);
|
||||||
T1 = cj_pmadd(L0, A0, T1);
|
T1 = cj_pmadd(A0, L0, T1);
|
||||||
R0 = ei_pload(&lb[4*PacketSize]);
|
R0 = ei_pload(&lb[4*PacketSize]);
|
||||||
L0 = ei_pload(&lb[5*PacketSize]);
|
L0 = ei_pload(&lb[5*PacketSize]);
|
||||||
T0 = cj_pmadd(R1, A1, T0);
|
T0 = cj_pmadd(A1, R1, T0);
|
||||||
T1 = cj_pmadd(L1, A1, T1);
|
T1 = cj_pmadd(A1, L1, T1);
|
||||||
R1 = ei_pload(&lb[6*PacketSize]);
|
R1 = ei_pload(&lb[6*PacketSize]);
|
||||||
L1 = ei_pload(&lb[7*PacketSize]);
|
L1 = ei_pload(&lb[7*PacketSize]);
|
||||||
T0 = cj_pmadd(R0, A2, T0);
|
T0 = cj_pmadd(A2, R0, T0);
|
||||||
T1 = cj_pmadd(L0, A2, T1);
|
T1 = cj_pmadd(A2, L0, T1);
|
||||||
if(MaxBlockRows==8)
|
if(MaxBlockRows==8)
|
||||||
{
|
{
|
||||||
R0 = ei_pload(&lb[8*PacketSize]);
|
R0 = ei_pload(&lb[8*PacketSize]);
|
||||||
L0 = ei_pload(&lb[9*PacketSize]);
|
L0 = ei_pload(&lb[9*PacketSize]);
|
||||||
}
|
}
|
||||||
T0 = cj_pmadd(R1, A3, T0);
|
T0 = cj_pmadd(A3, R1, T0);
|
||||||
T1 = cj_pmadd(L1, A3, T1);
|
T1 = cj_pmadd(A3, L1, T1);
|
||||||
if(MaxBlockRows==8)
|
if(MaxBlockRows==8)
|
||||||
{
|
{
|
||||||
R1 = ei_pload(&lb[10*PacketSize]);
|
R1 = ei_pload(&lb[10*PacketSize]);
|
||||||
L1 = ei_pload(&lb[11*PacketSize]);
|
L1 = ei_pload(&lb[11*PacketSize]);
|
||||||
T0 = cj_pmadd(R0, A4, T0);
|
T0 = cj_pmadd(A4, R0, T0);
|
||||||
T1 = cj_pmadd(L0, A4, T1);
|
T1 = cj_pmadd(A4, L0, T1);
|
||||||
R0 = ei_pload(&lb[12*PacketSize]);
|
R0 = ei_pload(&lb[12*PacketSize]);
|
||||||
L0 = ei_pload(&lb[13*PacketSize]);
|
L0 = ei_pload(&lb[13*PacketSize]);
|
||||||
T0 = cj_pmadd(R1, A5, T0);
|
T0 = cj_pmadd(A5, R1, T0);
|
||||||
T1 = cj_pmadd(L1, A5, T1);
|
T1 = cj_pmadd(A5, L1, T1);
|
||||||
R1 = ei_pload(&lb[14*PacketSize]);
|
R1 = ei_pload(&lb[14*PacketSize]);
|
||||||
L1 = ei_pload(&lb[15*PacketSize]);
|
L1 = ei_pload(&lb[15*PacketSize]);
|
||||||
T0 = cj_pmadd(R0, A6, T0);
|
T0 = cj_pmadd(A6, R0, T0);
|
||||||
T1 = cj_pmadd(L0, A6, T1);
|
T1 = cj_pmadd(A6, L0, T1);
|
||||||
T0 = cj_pmadd(R1, A7, T0);
|
T0 = cj_pmadd(A7, R1, T0);
|
||||||
T1 = cj_pmadd(L1, A7, T1);
|
T1 = cj_pmadd(A7, L1, T1);
|
||||||
}
|
}
|
||||||
lb += MaxBlockRows*2*PacketSize;
|
lb += MaxBlockRows*2*PacketSize;
|
||||||
|
|
||||||
|
@ -28,19 +28,18 @@ void test_product_large()
|
|||||||
{
|
{
|
||||||
for(int i = 0; i < g_repeat; i++) {
|
for(int i = 0; i < g_repeat; i++) {
|
||||||
CALL_SUBTEST( product(MatrixXf(ei_random<int>(1,320), ei_random<int>(1,320))) );
|
CALL_SUBTEST( product(MatrixXf(ei_random<int>(1,320), ei_random<int>(1,320))) );
|
||||||
//CALL_SUBTEST( product(MatrixXf(ei_random<int>(1,320), ei_random<int>(1,320))) );
|
CALL_SUBTEST( product(MatrixXd(ei_random<int>(1,320), ei_random<int>(1,320))) );
|
||||||
// CALL_SUBTEST( product(MatrixXd(ei_random<int>(1,320), ei_random<int>(1,320))) );
|
CALL_SUBTEST( product(MatrixXi(ei_random<int>(1,320), ei_random<int>(1,320))) );
|
||||||
// CALL_SUBTEST( product(MatrixXi(ei_random<int>(1,320), ei_random<int>(1,320))) );
|
CALL_SUBTEST( product(MatrixXcf(ei_random<int>(1,50), ei_random<int>(1,50))) );
|
||||||
// CALL_SUBTEST( product(MatrixXcf(ei_random<int>(1,50), ei_random<int>(1,50))) );
|
CALL_SUBTEST( product(Matrix<float,Dynamic,Dynamic,RowMajor>(ei_random<int>(1,320), ei_random<int>(1,320))) );
|
||||||
// CALL_SUBTEST( product(Matrix<float,Dynamic,Dynamic,RowMajor>(ei_random<int>(1,320), ei_random<int>(1,320))) );
|
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
// test a specific issue in DiagonalProduct
|
// test a specific issue in DiagonalProduct
|
||||||
// int N = 1000000;
|
int N = 1000000;
|
||||||
// VectorXf v = VectorXf::Ones(N);
|
VectorXf v = VectorXf::Ones(N);
|
||||||
// MatrixXf m = MatrixXf::Ones(N,3);
|
MatrixXf m = MatrixXf::Ones(N,3);
|
||||||
// m = (v+v).asDiagonal() * m;
|
m = (v+v).asDiagonal() * m;
|
||||||
// VERIFY_IS_APPROX(m, MatrixXf::Constant(N,3,2));
|
VERIFY_IS_APPROX(m, MatrixXf::Constant(N,3,2));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user