mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-12 11:49:02 +08:00
- add a low level mechanism to provide preallocated memory to gemm
- ensure static allocation for the product of "large" fixed size matrix
This commit is contained in:
parent
e039edcb42
commit
566867428c
@ -25,6 +25,8 @@
|
|||||||
#ifndef EIGEN_GENERAL_MATRIX_MATRIX_H
|
#ifndef EIGEN_GENERAL_MATRIX_MATRIX_H
|
||||||
#define EIGEN_GENERAL_MATRIX_MATRIX_H
|
#define EIGEN_GENERAL_MATRIX_MATRIX_H
|
||||||
|
|
||||||
|
template<typename _LhsScalar, typename _RhsScalar> class ei_level3_blocking;
|
||||||
|
|
||||||
/* Specialization for a row-major destination matrix => simple transposition of the product */
|
/* Specialization for a row-major destination matrix => simple transposition of the product */
|
||||||
template<
|
template<
|
||||||
typename Scalar, typename Index,
|
typename Scalar, typename Index,
|
||||||
@ -38,7 +40,8 @@ struct ei_general_matrix_matrix_product<Scalar,Index,LhsStorageOrder,ConjugateLh
|
|||||||
const Scalar* rhs, Index rhsStride,
|
const Scalar* rhs, Index rhsStride,
|
||||||
Scalar* res, Index resStride,
|
Scalar* res, Index resStride,
|
||||||
Scalar alpha,
|
Scalar alpha,
|
||||||
GemmParallelInfo<Scalar, Index>* info = 0)
|
ei_level3_blocking<Scalar,Scalar>& blocking,
|
||||||
|
GemmParallelInfo<Index>* info = 0)
|
||||||
{
|
{
|
||||||
// transpose the product such that the result is column major
|
// transpose the product such that the result is column major
|
||||||
ei_general_matrix_matrix_product<Scalar, Index,
|
ei_general_matrix_matrix_product<Scalar, Index,
|
||||||
@ -47,7 +50,7 @@ struct ei_general_matrix_matrix_product<Scalar,Index,LhsStorageOrder,ConjugateLh
|
|||||||
LhsStorageOrder==RowMajor ? ColMajor : RowMajor,
|
LhsStorageOrder==RowMajor ? ColMajor : RowMajor,
|
||||||
ConjugateLhs,
|
ConjugateLhs,
|
||||||
ColMajor>
|
ColMajor>
|
||||||
::run(cols,rows,depth,rhs,rhsStride,lhs,lhsStride,res,resStride,alpha,info);
|
::run(cols,rows,depth,rhs,rhsStride,lhs,lhsStride,res,resStride,alpha,blocking,info);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -64,7 +67,8 @@ static void run(Index rows, Index cols, Index depth,
|
|||||||
const Scalar* _rhs, Index rhsStride,
|
const Scalar* _rhs, Index rhsStride,
|
||||||
Scalar* res, Index resStride,
|
Scalar* res, Index resStride,
|
||||||
Scalar alpha,
|
Scalar alpha,
|
||||||
GemmParallelInfo<Scalar,Index>* info = 0)
|
ei_level3_blocking<Scalar,Scalar>& blocking,
|
||||||
|
GemmParallelInfo<Index>* info = 0)
|
||||||
{
|
{
|
||||||
ei_const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride);
|
ei_const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride);
|
||||||
ei_const_blas_data_mapper<Scalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride);
|
ei_const_blas_data_mapper<Scalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride);
|
||||||
@ -75,10 +79,9 @@ static void run(Index rows, Index cols, Index depth,
|
|||||||
typedef typename ei_packet_traits<Scalar>::type PacketType;
|
typedef typename ei_packet_traits<Scalar>::type PacketType;
|
||||||
typedef ei_product_blocking_traits<Scalar> Blocking;
|
typedef ei_product_blocking_traits<Scalar> Blocking;
|
||||||
|
|
||||||
Index kc = depth; // cache block size along the K direction
|
Index kc = blocking.kc(); // cache block size along the K direction
|
||||||
Index mc = rows; // cache block size along the M direction
|
Index mc = std::min(rows,blocking.mc()); // cache block size along the M direction
|
||||||
Index nc = cols; // cache block size along the N direction
|
//Index nc = blocking.nc(); // cache block size along the N direction
|
||||||
computeProductBlockingSizes<Scalar,Scalar>(kc, mc, nc);
|
|
||||||
|
|
||||||
ei_gemm_pack_rhs<Scalar, Index, Blocking::nr, RhsStorageOrder> pack_rhs;
|
ei_gemm_pack_rhs<Scalar, Index, Blocking::nr, RhsStorageOrder> pack_rhs;
|
||||||
ei_gemm_pack_lhs<Scalar, Index, Blocking::mr, LhsStorageOrder> pack_lhs;
|
ei_gemm_pack_lhs<Scalar, Index, Blocking::mr, LhsStorageOrder> pack_lhs;
|
||||||
@ -94,10 +97,10 @@ static void run(Index rows, Index cols, Index depth,
|
|||||||
Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc);
|
Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc);
|
||||||
std::size_t sizeW = kc*Blocking::PacketSize*Blocking::nr*8;
|
std::size_t sizeW = kc*Blocking::PacketSize*Blocking::nr*8;
|
||||||
Scalar* w = ei_aligned_stack_new(Scalar, sizeW);
|
Scalar* w = ei_aligned_stack_new(Scalar, sizeW);
|
||||||
Scalar* blockB = (Scalar*)info[tid].blockB;
|
Scalar* blockB = blocking.blockB();
|
||||||
|
ei_internal_assert(blockB!=0);
|
||||||
|
|
||||||
// For each horizontal panel of the rhs, and corresponding panel of the lhs...
|
// For each horizontal panel of the rhs, and corresponding vertical panel of the lhs...
|
||||||
// (==GEMM_VAR1)
|
|
||||||
for(Index k=0; k<depth; k+=kc)
|
for(Index k=0; k<depth; k+=kc)
|
||||||
{
|
{
|
||||||
const Index actual_kc = std::min(k+kc,depth)-k; // => rows of B', and cols of the A'
|
const Index actual_kc = std::min(k+kc,depth)-k; // => rows of B', and cols of the A'
|
||||||
@ -106,7 +109,7 @@ static void run(Index rows, Index cols, Index depth,
|
|||||||
// let's start by packing A'.
|
// let's start by packing A'.
|
||||||
pack_lhs(blockA, &lhs(0,k), lhsStride, actual_kc, mc);
|
pack_lhs(blockA, &lhs(0,k), lhsStride, actual_kc, mc);
|
||||||
|
|
||||||
// Pack B_k to B' in parallel fashion:
|
// Pack B_k to B' in a parallel fashion:
|
||||||
// each thread packs the sub block B_k,j to B'_j where j is the thread id.
|
// each thread packs the sub block B_k,j to B'_j where j is the thread id.
|
||||||
|
|
||||||
// However, before copying to B'_j, we have to make sure that no other thread is still using it,
|
// However, before copying to B'_j, we have to make sure that no other thread is still using it,
|
||||||
@ -162,10 +165,12 @@ static void run(Index rows, Index cols, Index depth,
|
|||||||
EIGEN_UNUSED_VARIABLE(info);
|
EIGEN_UNUSED_VARIABLE(info);
|
||||||
|
|
||||||
// this is the sequential version!
|
// this is the sequential version!
|
||||||
Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc);
|
std::size_t sizeA = kc*mc;
|
||||||
std::size_t sizeB = kc*Blocking::PacketSize*Blocking::nr + kc*cols;
|
std::size_t sizeB = kc*cols;
|
||||||
Scalar* allocatedBlockB = ei_aligned_stack_new(Scalar, sizeB);
|
std::size_t sizeW = kc*Blocking::PacketSize*Blocking::nr;
|
||||||
Scalar* blockB = allocatedBlockB + kc*Blocking::PacketSize*Blocking::nr;
|
Scalar *blockA = blocking.blockA()==0 ? ei_aligned_stack_new(Scalar, sizeA) : blocking.blockA();
|
||||||
|
Scalar *blockB = blocking.blockB()==0 ? ei_aligned_stack_new(Scalar, sizeB) : blocking.blockB();
|
||||||
|
Scalar *blockW = blocking.blockW()==0 ? ei_aligned_stack_new(Scalar, sizeW) : blocking.blockW();
|
||||||
|
|
||||||
// For each horizontal panel of the rhs, and corresponding panel of the lhs...
|
// For each horizontal panel of the rhs, and corresponding panel of the lhs...
|
||||||
// (==GEMM_VAR1)
|
// (==GEMM_VAR1)
|
||||||
@ -192,13 +197,14 @@ static void run(Index rows, Index cols, Index depth,
|
|||||||
pack_lhs(blockA, &lhs(i2,k2), lhsStride, actual_kc, actual_mc);
|
pack_lhs(blockA, &lhs(i2,k2), lhsStride, actual_kc, actual_mc);
|
||||||
|
|
||||||
// Everything is packed, we can now call the block * panel kernel:
|
// Everything is packed, we can now call the block * panel kernel:
|
||||||
gebp(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, cols);
|
gebp(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, cols, -1, -1, 0, 0, blockW);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ei_aligned_stack_delete(Scalar, blockA, kc*mc);
|
if(blocking.blockA()==0) ei_aligned_stack_delete(Scalar, blockA, kc*mc);
|
||||||
ei_aligned_stack_delete(Scalar, allocatedBlockB, sizeB);
|
if(blocking.blockB()==0) ei_aligned_stack_delete(Scalar, blockB, sizeB);
|
||||||
|
if(blocking.blockW()==0) ei_aligned_stack_delete(Scalar, blockW, sizeW);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -214,33 +220,25 @@ struct ei_traits<GeneralProduct<Lhs,Rhs,GemmProduct> >
|
|||||||
: ei_traits<ProductBase<GeneralProduct<Lhs,Rhs,GemmProduct>, Lhs, Rhs> >
|
: ei_traits<ProductBase<GeneralProduct<Lhs,Rhs,GemmProduct>, Lhs, Rhs> >
|
||||||
{};
|
{};
|
||||||
|
|
||||||
template<typename Scalar, typename Index, typename Gemm, typename Lhs, typename Rhs, typename Dest>
|
template<typename Scalar, typename Index, typename Gemm, typename Lhs, typename Rhs, typename Dest, typename BlockingType>
|
||||||
struct ei_gemm_functor
|
struct ei_gemm_functor
|
||||||
{
|
{
|
||||||
typedef typename Rhs::Scalar BlockBScalar;
|
ei_gemm_functor(const Lhs& lhs, const Rhs& rhs, Dest& dest, Scalar actualAlpha,
|
||||||
|
BlockingType& blocking)
|
||||||
ei_gemm_functor(const Lhs& lhs, const Rhs& rhs, Dest& dest, Scalar actualAlpha)
|
: m_lhs(lhs), m_rhs(rhs), m_dest(dest), m_actualAlpha(actualAlpha), m_blocking(blocking)
|
||||||
: m_lhs(lhs), m_rhs(rhs), m_dest(dest), m_actualAlpha(actualAlpha)
|
|
||||||
{}
|
{}
|
||||||
|
|
||||||
void operator() (Index row, Index rows, Index col=0, Index cols=-1, GemmParallelInfo<BlockBScalar,Index>* info=0) const
|
void operator() (Index row, Index rows, Index col=0, Index cols=-1, GemmParallelInfo<Index>* info=0) const
|
||||||
{
|
{
|
||||||
if(cols==-1)
|
if(cols==-1)
|
||||||
cols = m_rhs.cols();
|
cols = m_rhs.cols();
|
||||||
|
if(info)
|
||||||
|
m_blocking.allocateB();
|
||||||
Gemm::run(rows, cols, m_lhs.cols(),
|
Gemm::run(rows, cols, m_lhs.cols(),
|
||||||
(const Scalar*)&(m_lhs.const_cast_derived().coeffRef(row,0)), m_lhs.outerStride(),
|
(const Scalar*)&(m_lhs.const_cast_derived().coeffRef(row,0)), m_lhs.outerStride(),
|
||||||
(const Scalar*)&(m_rhs.const_cast_derived().coeffRef(0,col)), m_rhs.outerStride(),
|
(const Scalar*)&(m_rhs.const_cast_derived().coeffRef(0,col)), m_rhs.outerStride(),
|
||||||
(Scalar*)&(m_dest.coeffRef(row,col)), m_dest.outerStride(),
|
(Scalar*)&(m_dest.coeffRef(row,col)), m_dest.outerStride(),
|
||||||
m_actualAlpha,
|
m_actualAlpha, m_blocking, info);
|
||||||
info);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
Index sharedBlockBSize() const
|
|
||||||
{
|
|
||||||
Index kc = m_rhs.rows(), mc = m_lhs.rows(), nc = m_rhs.cols();
|
|
||||||
computeProductBlockingSizes<Scalar,Scalar>(kc, mc, nc);
|
|
||||||
return kc * nc;;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
@ -248,12 +246,155 @@ struct ei_gemm_functor
|
|||||||
const Rhs& m_rhs;
|
const Rhs& m_rhs;
|
||||||
Dest& m_dest;
|
Dest& m_dest;
|
||||||
Scalar m_actualAlpha;
|
Scalar m_actualAlpha;
|
||||||
|
BlockingType& m_blocking;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<int StorageOrder, typename LhsScalar, typename RhsScalar, int MaxRows, int MaxCols, int MaxDepth,
|
||||||
|
bool FiniteAtCompileTime = MaxRows!=Dynamic && MaxCols!=Dynamic && MaxDepth != Dynamic> struct ei_gemm_blocking_space;
|
||||||
|
|
||||||
|
template<typename _LhsScalar, typename _RhsScalar>
|
||||||
|
class ei_level3_blocking
|
||||||
|
{
|
||||||
|
typedef _LhsScalar LhsScalar;
|
||||||
|
typedef _RhsScalar RhsScalar;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
LhsScalar* m_blockA;
|
||||||
|
RhsScalar* m_blockB;
|
||||||
|
RhsScalar* m_blockW;
|
||||||
|
|
||||||
|
DenseIndex m_mc;
|
||||||
|
DenseIndex m_nc;
|
||||||
|
DenseIndex m_kc;
|
||||||
|
|
||||||
|
public:
|
||||||
|
|
||||||
|
ei_level3_blocking()
|
||||||
|
: m_blockA(0), m_blockB(0), m_blockW(0), m_mc(0), m_nc(0), m_kc(0)
|
||||||
|
{}
|
||||||
|
|
||||||
|
inline DenseIndex mc() const { return m_mc; }
|
||||||
|
inline DenseIndex nc() const { return m_nc; }
|
||||||
|
inline DenseIndex kc() const { return m_kc; }
|
||||||
|
|
||||||
|
inline LhsScalar* blockA() { return m_blockA; }
|
||||||
|
inline RhsScalar* blockB() { return m_blockB; }
|
||||||
|
inline RhsScalar* blockW() { return m_blockW; }
|
||||||
|
};
|
||||||
|
|
||||||
|
template<int StorageOrder, typename _LhsScalar, typename _RhsScalar, int MaxRows, int MaxCols, int MaxDepth>
|
||||||
|
class ei_gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, MaxDepth, true>
|
||||||
|
: public ei_level3_blocking<
|
||||||
|
typename ei_meta_if<StorageOrder==RowMajor,_RhsScalar,_LhsScalar>::ret,
|
||||||
|
typename ei_meta_if<StorageOrder==RowMajor,_LhsScalar,_RhsScalar>::ret>
|
||||||
|
{
|
||||||
|
enum {
|
||||||
|
Transpose = StorageOrder==RowMajor,
|
||||||
|
ActualRows = Transpose ? MaxCols : MaxRows,
|
||||||
|
ActualCols = Transpose ? MaxRows : MaxCols
|
||||||
|
};
|
||||||
|
typedef typename ei_meta_if<Transpose,_RhsScalar,_LhsScalar>::ret LhsScalar;
|
||||||
|
typedef typename ei_meta_if<Transpose,_LhsScalar,_RhsScalar>::ret RhsScalar;
|
||||||
|
typedef ei_product_blocking_traits<RhsScalar> Blocking;
|
||||||
|
enum {
|
||||||
|
SizeA = ActualCols * MaxDepth,
|
||||||
|
SizeB = ActualRows * MaxDepth,
|
||||||
|
SizeW = MaxDepth * Blocking::nr * ei_packet_traits<RhsScalar>::size
|
||||||
|
};
|
||||||
|
|
||||||
|
EIGEN_ALIGN16 LhsScalar m_staticA[SizeA];
|
||||||
|
EIGEN_ALIGN16 RhsScalar m_staticB[SizeB];
|
||||||
|
EIGEN_ALIGN16 RhsScalar m_staticW[SizeW];
|
||||||
|
|
||||||
|
public:
|
||||||
|
|
||||||
|
ei_gemm_blocking_space(DenseIndex rows, DenseIndex cols, DenseIndex depth)
|
||||||
|
{
|
||||||
|
this->m_mc = ActualRows;
|
||||||
|
this->m_nc = ActualCols;
|
||||||
|
this->m_kc = MaxDepth;
|
||||||
|
this->m_blockA = m_staticA;
|
||||||
|
this->m_blockB = m_staticB;
|
||||||
|
this->m_blockW = m_staticW;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void allocateA() {}
|
||||||
|
inline void allocateB() {}
|
||||||
|
inline void allocateW() {}
|
||||||
|
inline void allocateAll() {}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<int StorageOrder, typename _LhsScalar, typename _RhsScalar, int MaxRows, int MaxCols, int MaxDepth>
|
||||||
|
struct ei_gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, MaxDepth, false>
|
||||||
|
: public ei_level3_blocking<
|
||||||
|
typename ei_meta_if<StorageOrder==RowMajor,_RhsScalar,_LhsScalar>::ret,
|
||||||
|
typename ei_meta_if<StorageOrder==RowMajor,_LhsScalar,_RhsScalar>::ret>
|
||||||
|
{
|
||||||
|
enum {
|
||||||
|
Transpose = StorageOrder==RowMajor
|
||||||
|
};
|
||||||
|
typedef typename ei_meta_if<Transpose,_RhsScalar,_LhsScalar>::ret LhsScalar;
|
||||||
|
typedef typename ei_meta_if<Transpose,_LhsScalar,_RhsScalar>::ret RhsScalar;
|
||||||
|
typedef ei_product_blocking_traits<RhsScalar> Blocking;
|
||||||
|
|
||||||
|
DenseIndex m_sizeA;
|
||||||
|
DenseIndex m_sizeB;
|
||||||
|
DenseIndex m_sizeW;
|
||||||
|
|
||||||
|
public:
|
||||||
|
|
||||||
|
ei_gemm_blocking_space(DenseIndex rows, DenseIndex cols, DenseIndex depth)
|
||||||
|
{
|
||||||
|
this->m_mc = Transpose ? cols : rows;
|
||||||
|
this->m_nc = Transpose ? rows : cols;
|
||||||
|
this->m_kc = depth;
|
||||||
|
|
||||||
|
computeProductBlockingSizes<LhsScalar,RhsScalar>(this->m_kc, this->m_mc, this->m_nc);
|
||||||
|
m_sizeA = this->m_mc * this->m_kc;
|
||||||
|
m_sizeB = this->m_kc * this->m_nc;
|
||||||
|
m_sizeW = this->m_kc*ei_packet_traits<RhsScalar>::size*Blocking::nr;
|
||||||
|
}
|
||||||
|
|
||||||
|
void allocateA()
|
||||||
|
{
|
||||||
|
if(this->m_blockA==0)
|
||||||
|
this->m_blockA = ei_aligned_new<LhsScalar>(m_sizeA);
|
||||||
|
}
|
||||||
|
|
||||||
|
void allocateB()
|
||||||
|
{
|
||||||
|
if(this->m_blockB==0)
|
||||||
|
this->m_blockB = ei_aligned_new<RhsScalar>(m_sizeB);
|
||||||
|
}
|
||||||
|
|
||||||
|
void allocateW()
|
||||||
|
{
|
||||||
|
if(this->m_blockB==0)
|
||||||
|
this->m_blockB = ei_aligned_new<RhsScalar>(m_sizeB);
|
||||||
|
}
|
||||||
|
|
||||||
|
void allocateAll()
|
||||||
|
{
|
||||||
|
allocateA();
|
||||||
|
allocateB();
|
||||||
|
allocateW();
|
||||||
|
}
|
||||||
|
|
||||||
|
~ei_gemm_blocking_space()
|
||||||
|
{
|
||||||
|
ei_aligned_delete(this->m_blockA, m_sizeA);
|
||||||
|
ei_aligned_delete(this->m_blockB, m_sizeB);
|
||||||
|
ei_aligned_delete(this->m_blockW, m_sizeW);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename Lhs, typename Rhs>
|
template<typename Lhs, typename Rhs>
|
||||||
class GeneralProduct<Lhs, Rhs, GemmProduct>
|
class GeneralProduct<Lhs, Rhs, GemmProduct>
|
||||||
: public ProductBase<GeneralProduct<Lhs,Rhs,GemmProduct>, Lhs, Rhs>
|
: public ProductBase<GeneralProduct<Lhs,Rhs,GemmProduct>, Lhs, Rhs>
|
||||||
{
|
{
|
||||||
|
enum {
|
||||||
|
MaxDepthAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(Lhs::MaxColsAtCompileTime,Rhs::MaxRowsAtCompileTime)
|
||||||
|
};
|
||||||
public:
|
public:
|
||||||
EIGEN_PRODUCT_PUBLIC_INTERFACE(GeneralProduct)
|
EIGEN_PRODUCT_PUBLIC_INTERFACE(GeneralProduct)
|
||||||
|
|
||||||
@ -273,6 +414,9 @@ class GeneralProduct<Lhs, Rhs, GemmProduct>
|
|||||||
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
|
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
|
||||||
* RhsBlasTraits::extractScalarFactor(m_rhs);
|
* RhsBlasTraits::extractScalarFactor(m_rhs);
|
||||||
|
|
||||||
|
typedef ei_gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,Scalar,Scalar,
|
||||||
|
Dest::MaxRowsAtCompileTime,Dest::MaxColsAtCompileTime,MaxDepthAtCompileTime> BlockingType;
|
||||||
|
|
||||||
typedef ei_gemm_functor<
|
typedef ei_gemm_functor<
|
||||||
Scalar, Index,
|
Scalar, Index,
|
||||||
ei_general_matrix_matrix_product<
|
ei_general_matrix_matrix_product<
|
||||||
@ -280,11 +424,11 @@ class GeneralProduct<Lhs, Rhs, GemmProduct>
|
|||||||
(_ActualLhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
|
(_ActualLhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
|
||||||
(_ActualRhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
|
(_ActualRhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
|
||||||
(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>,
|
(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>,
|
||||||
_ActualLhsType,
|
_ActualLhsType, _ActualRhsType, Dest, BlockingType> GemmFunctor;
|
||||||
_ActualRhsType,
|
|
||||||
Dest> GemmFunctor;
|
BlockingType blocking(dst.rows(), dst.cols(), lhs.cols());
|
||||||
|
|
||||||
ei_parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==Dynamic)>(GemmFunctor(lhs, rhs, dst, actualAlpha), this->rows(), this->cols());
|
ei_parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==Dynamic)>(GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), this->rows(), this->cols());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -69,16 +69,15 @@ inline void setNbThreads(int v)
|
|||||||
ei_manage_multi_threading(SetAction, &v);
|
ei_manage_multi_threading(SetAction, &v);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename BlockBScalar, typename Index> struct GemmParallelInfo
|
template<typename Index> struct GemmParallelInfo
|
||||||
{
|
{
|
||||||
GemmParallelInfo() : sync(-1), users(0), rhs_start(0), rhs_length(0), blockB(0) {}
|
GemmParallelInfo() : sync(-1), users(0), rhs_start(0), rhs_length(0) {}
|
||||||
|
|
||||||
int volatile sync;
|
int volatile sync;
|
||||||
int volatile users;
|
int volatile users;
|
||||||
|
|
||||||
Index rhs_start;
|
Index rhs_start;
|
||||||
Index rhs_length;
|
Index rhs_length;
|
||||||
BlockBScalar* blockB;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template<bool Condition, typename Functor, typename Index>
|
template<bool Condition, typename Functor, typename Index>
|
||||||
@ -112,11 +111,7 @@ void ei_parallelize_gemm(const Functor& func, Index rows, Index cols)
|
|||||||
Index blockCols = (cols / threads) & ~Index(0x3);
|
Index blockCols = (cols / threads) & ~Index(0x3);
|
||||||
Index blockRows = (rows / threads) & ~Index(0x7);
|
Index blockRows = (rows / threads) & ~Index(0x7);
|
||||||
|
|
||||||
typedef typename Functor::BlockBScalar BlockBScalar;
|
GemmParallelInfo<Index>* info = new GemmParallelInfo<Index>[threads];
|
||||||
BlockBScalar* sharedBlockB = new BlockBScalar[func.sharedBlockBSize()];
|
|
||||||
|
|
||||||
GemmParallelInfo<BlockBScalar,Index>* info = new
|
|
||||||
GemmParallelInfo<BlockBScalar,Index>[threads];
|
|
||||||
|
|
||||||
#pragma omp parallel for schedule(static,1) num_threads(threads)
|
#pragma omp parallel for schedule(static,1) num_threads(threads)
|
||||||
for(Index i=0; i<threads; ++i)
|
for(Index i=0; i<threads; ++i)
|
||||||
@ -129,12 +124,10 @@ void ei_parallelize_gemm(const Functor& func, Index rows, Index cols)
|
|||||||
|
|
||||||
info[i].rhs_start = c0;
|
info[i].rhs_start = c0;
|
||||||
info[i].rhs_length = actualBlockCols;
|
info[i].rhs_length = actualBlockCols;
|
||||||
info[i].blockB = sharedBlockB;
|
|
||||||
|
|
||||||
func(r0, actualBlockRows, 0,cols, info);
|
func(r0, actualBlockRows, 0,cols, info);
|
||||||
}
|
}
|
||||||
|
|
||||||
delete[] sharedBlockB;
|
|
||||||
delete[] info;
|
delete[] info;
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
@ -100,9 +100,7 @@ void ctms_decompositions()
|
|||||||
|
|
||||||
const Matrix A(Matrix::Random(size, size));
|
const Matrix A(Matrix::Random(size, size));
|
||||||
const ComplexMatrix complexA(ComplexMatrix::Random(size, size));
|
const ComplexMatrix complexA(ComplexMatrix::Random(size, size));
|
||||||
// const Matrix saA = A.adjoint() * A; // NOTE: This product allocates on the stack. The two following lines are a kludgy workaround
|
const Matrix saA = A.adjoint() * A;
|
||||||
Matrix saA(Matrix::Constant(size, size, 1.0));
|
|
||||||
saA.diagonal().setConstant(2.0);
|
|
||||||
|
|
||||||
// Cholesky module
|
// Cholesky module
|
||||||
Eigen::LLT<Matrix> LLT; LLT.compute(A);
|
Eigen::LLT<Matrix> LLT; LLT.compute(A);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user