mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-07-11 23:51:49 +08:00
bug #1741: fix C.noalias() = A*C; with C.innerStride()!=1
(grafted from ea0d5dc956c1268dd91ce636d8fd5e07225acb06 )
This commit is contained in:
parent
32cb4853c6
commit
f483c7ea8a
@ -20,8 +20,9 @@ template<typename _LhsScalar, typename _RhsScalar> class level3_blocking;
|
|||||||
template<
|
template<
|
||||||
typename Index,
|
typename Index,
|
||||||
typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
|
typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
|
||||||
typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs>
|
typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
|
||||||
struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,RowMajor>
|
int ResInnerStride>
|
||||||
|
struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,RowMajor,ResInnerStride>
|
||||||
{
|
{
|
||||||
typedef gebp_traits<RhsScalar,LhsScalar> Traits;
|
typedef gebp_traits<RhsScalar,LhsScalar> Traits;
|
||||||
|
|
||||||
@ -30,7 +31,7 @@ struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLh
|
|||||||
Index rows, Index cols, Index depth,
|
Index rows, Index cols, Index depth,
|
||||||
const LhsScalar* lhs, Index lhsStride,
|
const LhsScalar* lhs, Index lhsStride,
|
||||||
const RhsScalar* rhs, Index rhsStride,
|
const RhsScalar* rhs, Index rhsStride,
|
||||||
ResScalar* res, Index resStride,
|
ResScalar* res, Index resIncr, Index resStride,
|
||||||
ResScalar alpha,
|
ResScalar alpha,
|
||||||
level3_blocking<RhsScalar,LhsScalar>& blocking,
|
level3_blocking<RhsScalar,LhsScalar>& blocking,
|
||||||
GemmParallelInfo<Index>* info = 0)
|
GemmParallelInfo<Index>* info = 0)
|
||||||
@ -39,8 +40,8 @@ struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLh
|
|||||||
general_matrix_matrix_product<Index,
|
general_matrix_matrix_product<Index,
|
||||||
RhsScalar, RhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateRhs,
|
RhsScalar, RhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateRhs,
|
||||||
LhsScalar, LhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateLhs,
|
LhsScalar, LhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateLhs,
|
||||||
ColMajor>
|
ColMajor,ResInnerStride>
|
||||||
::run(cols,rows,depth,rhs,rhsStride,lhs,lhsStride,res,resStride,alpha,blocking,info);
|
::run(cols,rows,depth,rhs,rhsStride,lhs,lhsStride,res,resIncr,resStride,alpha,blocking,info);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -49,8 +50,9 @@ struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLh
|
|||||||
template<
|
template<
|
||||||
typename Index,
|
typename Index,
|
||||||
typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
|
typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
|
||||||
typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs>
|
typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
|
||||||
struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor>
|
int ResInnerStride>
|
||||||
|
struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride>
|
||||||
{
|
{
|
||||||
|
|
||||||
typedef gebp_traits<LhsScalar,RhsScalar> Traits;
|
typedef gebp_traits<LhsScalar,RhsScalar> Traits;
|
||||||
@ -59,17 +61,17 @@ typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScala
|
|||||||
static void run(Index rows, Index cols, Index depth,
|
static void run(Index rows, Index cols, Index depth,
|
||||||
const LhsScalar* _lhs, Index lhsStride,
|
const LhsScalar* _lhs, Index lhsStride,
|
||||||
const RhsScalar* _rhs, Index rhsStride,
|
const RhsScalar* _rhs, Index rhsStride,
|
||||||
ResScalar* _res, Index resStride,
|
ResScalar* _res, Index resIncr, Index resStride,
|
||||||
ResScalar alpha,
|
ResScalar alpha,
|
||||||
level3_blocking<LhsScalar,RhsScalar>& blocking,
|
level3_blocking<LhsScalar,RhsScalar>& blocking,
|
||||||
GemmParallelInfo<Index>* info = 0)
|
GemmParallelInfo<Index>* info = 0)
|
||||||
{
|
{
|
||||||
typedef const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> LhsMapper;
|
typedef const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> LhsMapper;
|
||||||
typedef const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> RhsMapper;
|
typedef const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> RhsMapper;
|
||||||
typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor> ResMapper;
|
typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor,Unaligned,ResInnerStride> ResMapper;
|
||||||
LhsMapper lhs(_lhs,lhsStride);
|
LhsMapper lhs(_lhs, lhsStride);
|
||||||
RhsMapper rhs(_rhs,rhsStride);
|
RhsMapper rhs(_rhs, rhsStride);
|
||||||
ResMapper res(_res, resStride);
|
ResMapper res(_res, resStride, resIncr);
|
||||||
|
|
||||||
Index kc = blocking.kc(); // cache block size along the K direction
|
Index kc = blocking.kc(); // cache block size along the K direction
|
||||||
Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction
|
Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction
|
||||||
@ -226,7 +228,7 @@ struct gemm_functor
|
|||||||
Gemm::run(rows, cols, m_lhs.cols(),
|
Gemm::run(rows, cols, m_lhs.cols(),
|
||||||
&m_lhs.coeffRef(row,0), m_lhs.outerStride(),
|
&m_lhs.coeffRef(row,0), m_lhs.outerStride(),
|
||||||
&m_rhs.coeffRef(0,col), m_rhs.outerStride(),
|
&m_rhs.coeffRef(0,col), m_rhs.outerStride(),
|
||||||
(Scalar*)&(m_dest.coeffRef(row,col)), m_dest.outerStride(),
|
(Scalar*)&(m_dest.coeffRef(row,col)), m_dest.innerStride(), m_dest.outerStride(),
|
||||||
m_actualAlpha, m_blocking, info);
|
m_actualAlpha, m_blocking, info);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -476,7 +478,8 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
|
|||||||
Index,
|
Index,
|
||||||
LhsScalar, (ActualLhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
|
LhsScalar, (ActualLhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
|
||||||
RhsScalar, (ActualRhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
|
RhsScalar, (ActualRhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
|
||||||
(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>,
|
(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,
|
||||||
|
Dest::InnerStrideAtCompileTime>,
|
||||||
ActualLhsTypeCleaned, ActualRhsTypeCleaned, Dest, BlockingType> GemmFunctor;
|
ActualLhsTypeCleaned, ActualRhsTypeCleaned, Dest, BlockingType> GemmFunctor;
|
||||||
|
|
||||||
BlockingType blocking(dst.rows(), dst.cols(), lhs.cols(), 1, true);
|
BlockingType blocking(dst.rows(), dst.cols(), lhs.cols(), 1, true);
|
||||||
|
@ -51,20 +51,22 @@ template< \
|
|||||||
typename Index, \
|
typename Index, \
|
||||||
int LhsStorageOrder, bool ConjugateLhs, \
|
int LhsStorageOrder, bool ConjugateLhs, \
|
||||||
int RhsStorageOrder, bool ConjugateRhs> \
|
int RhsStorageOrder, bool ConjugateRhs> \
|
||||||
struct general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor> \
|
struct general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor,1> \
|
||||||
{ \
|
{ \
|
||||||
typedef gebp_traits<EIGTYPE,EIGTYPE> Traits; \
|
typedef gebp_traits<EIGTYPE,EIGTYPE> Traits; \
|
||||||
\
|
\
|
||||||
static void run(Index rows, Index cols, Index depth, \
|
static void run(Index rows, Index cols, Index depth, \
|
||||||
const EIGTYPE* _lhs, Index lhsStride, \
|
const EIGTYPE* _lhs, Index lhsStride, \
|
||||||
const EIGTYPE* _rhs, Index rhsStride, \
|
const EIGTYPE* _rhs, Index rhsStride, \
|
||||||
EIGTYPE* res, Index resStride, \
|
EIGTYPE* res, Index resIncr, Index resStride, \
|
||||||
EIGTYPE alpha, \
|
EIGTYPE alpha, \
|
||||||
level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/, \
|
level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/, \
|
||||||
GemmParallelInfo<Index>* /*info = 0*/) \
|
GemmParallelInfo<Index>* /*info = 0*/) \
|
||||||
{ \
|
{ \
|
||||||
using std::conj; \
|
using std::conj; \
|
||||||
\
|
\
|
||||||
|
EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \
|
||||||
|
eigen_assert(resIncr == 1); \
|
||||||
char transa, transb; \
|
char transa, transb; \
|
||||||
BlasIndex m, n, k, lda, ldb, ldc; \
|
BlasIndex m, n, k, lda, ldb, ldc; \
|
||||||
const EIGTYPE *a, *b; \
|
const EIGTYPE *a, *b; \
|
||||||
|
@ -124,8 +124,8 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
|
|||||||
MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); \
|
MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); \
|
||||||
BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
|
BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
|
||||||
gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
|
gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
|
||||||
general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
|
general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor,1>::run( \
|
||||||
rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, resStride, alpha, gemm_blocking, 0); \
|
rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, 1, resStride, alpha, gemm_blocking, 0); \
|
||||||
\
|
\
|
||||||
/*std::cout << "TRMM_L: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \
|
/*std::cout << "TRMM_L: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \
|
||||||
} \
|
} \
|
||||||
@ -241,8 +241,8 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
|
|||||||
MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); \
|
MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); \
|
||||||
BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
|
BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
|
||||||
gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
|
gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
|
||||||
general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
|
general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor,1>::run( \
|
||||||
rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, resStride, alpha, gemm_blocking, 0); \
|
rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, 1, resStride, alpha, gemm_blocking, 0); \
|
||||||
\
|
\
|
||||||
/*std::cout << "TRMM_R: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \
|
/*std::cout << "TRMM_R: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \
|
||||||
} \
|
} \
|
||||||
|
@ -31,7 +31,7 @@ template<
|
|||||||
typename Index,
|
typename Index,
|
||||||
typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
|
typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
|
||||||
typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
|
typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
|
||||||
int ResStorageOrder>
|
int ResStorageOrder, int ResInnerStride>
|
||||||
struct general_matrix_matrix_product;
|
struct general_matrix_matrix_product;
|
||||||
|
|
||||||
template<typename Index,
|
template<typename Index,
|
||||||
@ -155,13 +155,21 @@ class BlasVectorMapper {
|
|||||||
Scalar* m_data;
|
Scalar* m_data;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template<typename Scalar, typename Index, int AlignmentType, int Incr=1>
|
||||||
|
class BlasLinearMapper;
|
||||||
|
|
||||||
template<typename Scalar, typename Index, int AlignmentType>
|
template<typename Scalar, typename Index, int AlignmentType>
|
||||||
class BlasLinearMapper {
|
class BlasLinearMapper<Scalar,Index,AlignmentType,1> {
|
||||||
public:
|
public:
|
||||||
typedef typename packet_traits<Scalar>::type Packet;
|
typedef typename packet_traits<Scalar>::type Packet;
|
||||||
typedef typename packet_traits<Scalar>::half HalfPacket;
|
typedef typename packet_traits<Scalar>::half HalfPacket;
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar *data) : m_data(data) {}
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar *data, Index incr=1)
|
||||||
|
: m_data(data)
|
||||||
|
{
|
||||||
|
EIGEN_ONLY_USED_FOR_DEBUG(incr);
|
||||||
|
eigen_assert(incr==1);
|
||||||
|
}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void prefetch(int i) const {
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void prefetch(int i) const {
|
||||||
internal::prefetch(&operator()(i));
|
internal::prefetch(&operator()(i));
|
||||||
@ -188,16 +196,25 @@ class BlasLinearMapper {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Lightweight helper class to access matrix coefficients.
|
// Lightweight helper class to access matrix coefficients.
|
||||||
template<typename Scalar, typename Index, int StorageOrder, int AlignmentType = Unaligned>
|
template<typename Scalar, typename Index, int StorageOrder, int AlignmentType = Unaligned, int Incr = 1>
|
||||||
class blas_data_mapper {
|
class blas_data_mapper;
|
||||||
public:
|
|
||||||
|
template<typename Scalar, typename Index, int StorageOrder, int AlignmentType>
|
||||||
|
class blas_data_mapper<Scalar,Index,StorageOrder,AlignmentType,1>
|
||||||
|
{
|
||||||
|
public:
|
||||||
typedef typename packet_traits<Scalar>::type Packet;
|
typedef typename packet_traits<Scalar>::type Packet;
|
||||||
typedef typename packet_traits<Scalar>::half HalfPacket;
|
typedef typename packet_traits<Scalar>::half HalfPacket;
|
||||||
|
|
||||||
typedef BlasLinearMapper<Scalar, Index, AlignmentType> LinearMapper;
|
typedef BlasLinearMapper<Scalar, Index, AlignmentType> LinearMapper;
|
||||||
typedef BlasVectorMapper<Scalar, Index> VectorMapper;
|
typedef BlasVectorMapper<Scalar, Index> VectorMapper;
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride) : m_data(data), m_stride(stride) {}
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride, Index incr=1)
|
||||||
|
: m_data(data), m_stride(stride)
|
||||||
|
{
|
||||||
|
EIGEN_ONLY_USED_FOR_DEBUG(incr);
|
||||||
|
eigen_assert(incr==1);
|
||||||
|
}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType>
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType>
|
||||||
getSubMapper(Index i, Index j) const {
|
getSubMapper(Index i, Index j) const {
|
||||||
@ -251,6 +268,90 @@ class blas_data_mapper {
|
|||||||
const Index m_stride;
|
const Index m_stride;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Implementation of non-natural increment (i.e. inner-stride != 1)
|
||||||
|
// The exposed API is not complete yet compared to the Incr==1 case
|
||||||
|
// because some features make less sense in this case.
|
||||||
|
template<typename Scalar, typename Index, int AlignmentType, int Incr>
|
||||||
|
class BlasLinearMapper
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
typedef typename packet_traits<Scalar>::type Packet;
|
||||||
|
typedef typename packet_traits<Scalar>::half HalfPacket;
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar *data,Index incr) : m_data(data), m_incr(incr) {}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void prefetch(int i) const {
|
||||||
|
internal::prefetch(&operator()(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar& operator()(Index i) const {
|
||||||
|
return m_data[i*m_incr.value()];
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
|
||||||
|
return pgather<Scalar,Packet>(m_data + i*m_incr.value(), m_incr.value());
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename PacketType>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, const PacketType &p) const {
|
||||||
|
pscatter<Scalar, PacketType>(m_data + i*m_incr.value(), p, m_incr.value());
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
Scalar *m_data;
|
||||||
|
const internal::variable_if_dynamic<Index,Incr> m_incr;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Scalar, typename Index, int StorageOrder, int AlignmentType,int Incr>
|
||||||
|
class blas_data_mapper
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
typedef typename packet_traits<Scalar>::type Packet;
|
||||||
|
typedef typename packet_traits<Scalar>::half HalfPacket;
|
||||||
|
|
||||||
|
typedef BlasLinearMapper<Scalar, Index, AlignmentType,Incr> LinearMapper;
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride, Index incr) : m_data(data), m_stride(stride), m_incr(incr) {}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper
|
||||||
|
getSubMapper(Index i, Index j) const {
|
||||||
|
return blas_data_mapper(&operator()(i, j), m_stride, m_incr.value());
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
|
||||||
|
return LinearMapper(&operator()(i, j), m_incr.value());
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_ALWAYS_INLINE Scalar& operator()(Index i, Index j) const {
|
||||||
|
return m_data[StorageOrder==RowMajor ? j*m_incr.value() + i*m_stride : i*m_incr.value() + j*m_stride];
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const {
|
||||||
|
return pgather<Scalar,Packet>(&operator()(i, j),m_incr.value());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename PacketT, int AlignmentT>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i, Index j) const {
|
||||||
|
return pgather<Scalar,PacketT>(&operator()(i, j),m_incr.value());
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename SubPacket>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void scatterPacket(Index i, Index j, const SubPacket &p) const {
|
||||||
|
pscatter<Scalar, SubPacket>(&operator()(i, j), p, m_stride);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename SubPacket>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubPacket gatherPacket(Index i, Index j) const {
|
||||||
|
return pgather<Scalar, SubPacket>(&operator()(i, j), m_stride);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
Scalar* EIGEN_RESTRICT m_data;
|
||||||
|
const Index m_stride;
|
||||||
|
const internal::variable_if_dynamic<Index,Incr> m_incr;
|
||||||
|
};
|
||||||
|
|
||||||
// lightweight helper class to access matrix coefficients (const version)
|
// lightweight helper class to access matrix coefficients (const version)
|
||||||
template<typename Scalar, typename Index, int StorageOrder>
|
template<typename Scalar, typename Index, int StorageOrder>
|
||||||
class const_blas_data_mapper : public blas_data_mapper<const Scalar, Index, StorageOrder> {
|
class const_blas_data_mapper : public blas_data_mapper<const Scalar, Index, StorageOrder> {
|
||||||
|
@ -13,28 +13,28 @@ int EIGEN_BLAS_FUNC(gemm)(const char *opa, const char *opb, const int *m, const
|
|||||||
const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
|
const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
|
||||||
{
|
{
|
||||||
// std::cerr << "in gemm " << *opa << " " << *opb << " " << *m << " " << *n << " " << *k << " " << *lda << " " << *ldb << " " << *ldc << " " << *palpha << " " << *pbeta << "\n";
|
// std::cerr << "in gemm " << *opa << " " << *opb << " " << *m << " " << *n << " " << *k << " " << *lda << " " << *ldb << " " << *ldc << " " << *palpha << " " << *pbeta << "\n";
|
||||||
typedef void (*functype)(DenseIndex, DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, Scalar, internal::level3_blocking<Scalar,Scalar>&, Eigen::internal::GemmParallelInfo<DenseIndex>*);
|
typedef void (*functype)(DenseIndex, DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, DenseIndex, Scalar, internal::level3_blocking<Scalar,Scalar>&, Eigen::internal::GemmParallelInfo<DenseIndex>*);
|
||||||
static const functype func[12] = {
|
static const functype func[12] = {
|
||||||
// array index: NOTR | (NOTR << 2)
|
// array index: NOTR | (NOTR << 2)
|
||||||
(internal::general_matrix_matrix_product<DenseIndex,Scalar,ColMajor,false,Scalar,ColMajor,false,ColMajor>::run),
|
(internal::general_matrix_matrix_product<DenseIndex,Scalar,ColMajor,false,Scalar,ColMajor,false,ColMajor,1>::run),
|
||||||
// array index: TR | (NOTR << 2)
|
// array index: TR | (NOTR << 2)
|
||||||
(internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,false,Scalar,ColMajor,false,ColMajor>::run),
|
(internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,false,Scalar,ColMajor,false,ColMajor,1>::run),
|
||||||
// array index: ADJ | (NOTR << 2)
|
// array index: ADJ | (NOTR << 2)
|
||||||
(internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,false,ColMajor>::run),
|
(internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,false,ColMajor,1>::run),
|
||||||
0,
|
0,
|
||||||
// array index: NOTR | (TR << 2)
|
// array index: NOTR | (TR << 2)
|
||||||
(internal::general_matrix_matrix_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,false,ColMajor>::run),
|
(internal::general_matrix_matrix_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,false,ColMajor,1>::run),
|
||||||
// array index: TR | (TR << 2)
|
// array index: TR | (TR << 2)
|
||||||
(internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,false,Scalar,RowMajor,false,ColMajor>::run),
|
(internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,false,Scalar,RowMajor,false,ColMajor,1>::run),
|
||||||
// array index: ADJ | (TR << 2)
|
// array index: ADJ | (TR << 2)
|
||||||
(internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,RowMajor,false,ColMajor>::run),
|
(internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,RowMajor,false,ColMajor,1>::run),
|
||||||
0,
|
0,
|
||||||
// array index: NOTR | (ADJ << 2)
|
// array index: NOTR | (ADJ << 2)
|
||||||
(internal::general_matrix_matrix_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,Conj, ColMajor>::run),
|
(internal::general_matrix_matrix_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,Conj, ColMajor,1>::run),
|
||||||
// array index: TR | (ADJ << 2)
|
// array index: TR | (ADJ << 2)
|
||||||
(internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,false,Scalar,RowMajor,Conj, ColMajor>::run),
|
(internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,false,Scalar,RowMajor,Conj, ColMajor,1>::run),
|
||||||
// array index: ADJ | (ADJ << 2)
|
// array index: ADJ | (ADJ << 2)
|
||||||
(internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,RowMajor,Conj, ColMajor>::run),
|
(internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,RowMajor,Conj, ColMajor,1>::run),
|
||||||
0
|
0
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -71,7 +71,7 @@ int EIGEN_BLAS_FUNC(gemm)(const char *opa, const char *opb, const int *m, const
|
|||||||
internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic> blocking(*m,*n,*k,1,true);
|
internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic> blocking(*m,*n,*k,1,true);
|
||||||
|
|
||||||
int code = OP(*opa) | (OP(*opb) << 2);
|
int code = OP(*opa) | (OP(*opb) << 2);
|
||||||
func[code](*m, *n, *k, a, *lda, b, *ldb, c, *ldc, alpha, blocking, 0);
|
func[code](*m, *n, *k, a, *lda, b, *ldb, c, 1, *ldc, alpha, blocking, 0);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -239,4 +239,19 @@ template<typename MatrixType> void product(const MatrixType& m)
|
|||||||
VERIFY_IS_APPROX(square * (square*square).conjugate(), square * square.conjugate() * square.conjugate());
|
VERIFY_IS_APPROX(square * (square*square).conjugate(), square * square.conjugate() * square.conjugate());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// destination with a non-default inner-stride
|
||||||
|
// see bug 1741
|
||||||
|
if(!MatrixType::IsRowMajor)
|
||||||
|
{
|
||||||
|
typedef Matrix<Scalar,Dynamic,Dynamic> MatrixX;
|
||||||
|
MatrixX buffer(2*rows,2*rows);
|
||||||
|
Map<RowSquareMatrixType,0,Stride<Dynamic,2> > map1(buffer.data(),rows,rows,Stride<Dynamic,2>(2*rows,2));
|
||||||
|
buffer.setZero();
|
||||||
|
VERIFY_IS_APPROX(map1 = m1 * m2.transpose(), (m1 * m2.transpose()).eval());
|
||||||
|
buffer.setZero();
|
||||||
|
VERIFY_IS_APPROX(map1.noalias() = m1 * m2.transpose(), (m1 * m2.transpose()).eval());
|
||||||
|
buffer.setZero();
|
||||||
|
VERIFY_IS_APPROX(map1.noalias() += m1 * m2.transpose(), (m1 * m2.transpose()).eval());
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user