mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-05-01 16:24:28 +08:00
Fix the OpenMP version for scalar types other than float
This commit is contained in:
parent
d13b877014
commit
62ac021606
@ -40,7 +40,7 @@ struct ei_general_matrix_matrix_product<Scalar,LhsStorageOrder,ConjugateLhs,RhsS
|
|||||||
const Scalar* rhs, int rhsStride,
|
const Scalar* rhs, int rhsStride,
|
||||||
Scalar* res, int resStride,
|
Scalar* res, int resStride,
|
||||||
Scalar alpha,
|
Scalar alpha,
|
||||||
GemmParallelInfo* info = 0)
|
GemmParallelInfo<Scalar>* info = 0)
|
||||||
{
|
{
|
||||||
// transpose the product such that the result is column major
|
// transpose the product such that the result is column major
|
||||||
ei_general_matrix_matrix_product<Scalar,
|
ei_general_matrix_matrix_product<Scalar,
|
||||||
@ -66,7 +66,7 @@ static void run(int rows, int cols, int depth,
|
|||||||
const Scalar* _rhs, int rhsStride,
|
const Scalar* _rhs, int rhsStride,
|
||||||
Scalar* res, int resStride,
|
Scalar* res, int resStride,
|
||||||
Scalar alpha,
|
Scalar alpha,
|
||||||
GemmParallelInfo* info = 0)
|
GemmParallelInfo<Scalar>* info = 0)
|
||||||
{
|
{
|
||||||
ei_const_blas_data_mapper<Scalar, LhsStorageOrder> lhs(_lhs,lhsStride);
|
ei_const_blas_data_mapper<Scalar, LhsStorageOrder> lhs(_lhs,lhsStride);
|
||||||
ei_const_blas_data_mapper<Scalar, RhsStorageOrder> rhs(_rhs,rhsStride);
|
ei_const_blas_data_mapper<Scalar, RhsStorageOrder> rhs(_rhs,rhsStride);
|
||||||
@ -218,11 +218,13 @@ struct ei_traits<GeneralProduct<Lhs,Rhs,GemmProduct> >
|
|||||||
template<typename Scalar, typename Gemm, typename Lhs, typename Rhs, typename Dest>
|
template<typename Scalar, typename Gemm, typename Lhs, typename Rhs, typename Dest>
|
||||||
struct ei_gemm_functor
|
struct ei_gemm_functor
|
||||||
{
|
{
|
||||||
|
typedef typename Rhs::Scalar BlockBScalar;
|
||||||
|
|
||||||
ei_gemm_functor(const Lhs& lhs, const Rhs& rhs, Dest& dest, Scalar actualAlpha)
|
ei_gemm_functor(const Lhs& lhs, const Rhs& rhs, Dest& dest, Scalar actualAlpha)
|
||||||
: m_lhs(lhs), m_rhs(rhs), m_dest(dest), m_actualAlpha(actualAlpha)
|
: m_lhs(lhs), m_rhs(rhs), m_dest(dest), m_actualAlpha(actualAlpha)
|
||||||
{}
|
{}
|
||||||
|
|
||||||
void operator() (int row, int rows, int col=0, int cols=-1, GemmParallelInfo* info=0) const
|
void operator() (int row, int rows, int col=0, int cols=-1, GemmParallelInfo<BlockBScalar>* info=0) const
|
||||||
{
|
{
|
||||||
if(cols==-1)
|
if(cols==-1)
|
||||||
cols = m_rhs.cols();
|
cols = m_rhs.cols();
|
||||||
@ -234,6 +236,12 @@ struct ei_gemm_functor
|
|||||||
info);
|
info);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
int sharedBlockBSize() const
|
||||||
|
{
|
||||||
|
return std::min<int>(ei_product_blocking_traits<Scalar>::Max_kc,m_rhs.rows()) * m_rhs.cols();
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
const Lhs& m_lhs;
|
const Lhs& m_lhs;
|
||||||
const Rhs& m_rhs;
|
const Rhs& m_rhs;
|
||||||
@ -275,7 +283,7 @@ class GeneralProduct<Lhs, Rhs, GemmProduct>
|
|||||||
_ActualRhsType,
|
_ActualRhsType,
|
||||||
Dest> GemmFunctor;
|
Dest> GemmFunctor;
|
||||||
|
|
||||||
ei_parallelize_gemm<Dest::MaxRowsAtCompileTime>32>(GemmFunctor(lhs, rhs, dst, actualAlpha), this->rows(), this->cols());
|
ei_parallelize_gemm<(Dest::MaxRowsAtCompileTime>32)>(GemmFunctor(lhs, rhs, dst, actualAlpha), this->rows(), this->cols());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -25,16 +25,16 @@
|
|||||||
#ifndef EIGEN_PARALLELIZER_H
|
#ifndef EIGEN_PARALLELIZER_H
|
||||||
#define EIGEN_PARALLELIZER_H
|
#define EIGEN_PARALLELIZER_H
|
||||||
|
|
||||||
struct GemmParallelInfo
|
template<typename BlockBScalar> struct GemmParallelInfo
|
||||||
{
|
{
|
||||||
GemmParallelInfo() : sync(-1), users(0) {}
|
GemmParallelInfo() : sync(-1), users(0), rhs_start(0), rhs_length(0), blockB(0) {}
|
||||||
|
|
||||||
int volatile sync;
|
int volatile sync;
|
||||||
int volatile users;
|
int volatile users;
|
||||||
|
|
||||||
int rhs_start;
|
int rhs_start;
|
||||||
int rhs_length;
|
int rhs_length;
|
||||||
float* blockB;
|
BlockBScalar* blockB;
|
||||||
};
|
};
|
||||||
|
|
||||||
template<bool Condition,typename Functor>
|
template<bool Condition,typename Functor>
|
||||||
@ -51,9 +51,10 @@ void ei_parallelize_gemm(const Functor& func, int rows, int cols)
|
|||||||
int blockCols = (cols / threads) & ~0x3;
|
int blockCols = (cols / threads) & ~0x3;
|
||||||
int blockRows = (rows / threads) & ~0x7;
|
int blockRows = (rows / threads) & ~0x7;
|
||||||
|
|
||||||
float* sharedBlockB = new float[2048*2048*4];
|
typedef typename Functor::BlockBScalar BlockBScalar;
|
||||||
|
BlockBScalar* sharedBlockB = new BlockBScalar[func.sharedBlockBSize()];
|
||||||
|
|
||||||
GemmParallelInfo* info = new GemmParallelInfo[threads];
|
GemmParallelInfo<BlockBScalar>* info = new GemmParallelInfo<BlockBScalar>[threads];
|
||||||
|
|
||||||
#pragma omp parallel for schedule(static,1)
|
#pragma omp parallel for schedule(static,1)
|
||||||
for(int i=0; i<threads; ++i)
|
for(int i=0; i<threads; ++i)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user