mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-30 10:15:13 +08:00
add initial openmp support for matrix-matrix products
=> x1.9 speedup on my core2 duo
This commit is contained in:
parent
1a70f3b48d
commit
b20935be9b
@ -128,6 +128,49 @@ struct ei_traits<GeneralProduct<Lhs,Rhs,GemmProduct> >
|
||||
: ei_traits<ProductBase<GeneralProduct<Lhs,Rhs,GemmProduct>, Lhs, Rhs> >
|
||||
{};
|
||||
|
||||
template<bool Prallelize,typename Functor>
|
||||
void ei_multithreaded_product(const Functor& func, int size)
|
||||
{
|
||||
if(!Prallelize)
|
||||
return func(0,size);
|
||||
#ifdef OMP
|
||||
int threads = omp_get_num_procs();
|
||||
#else
|
||||
int threads = 1;
|
||||
#endif
|
||||
int blockSize = size / threads;
|
||||
#pragma omp parallel for schedule(static,1)
|
||||
for(int i=0; i<threads; ++i)
|
||||
{
|
||||
int blockStart = i*blockSize;
|
||||
int actualBlockSize = std::min(blockSize, size - blockStart);
|
||||
|
||||
func(blockStart, actualBlockSize);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Scalar, typename Gemm, typename Lhs, typename Rhs, typename Dest> struct ei_gemm_callback
|
||||
{
|
||||
ei_gemm_callback(const Lhs& lhs, const Rhs& rhs, Dest& dest, Scalar actualAlpha)
|
||||
: m_lhs(lhs), m_rhs(rhs), m_dest(dest), m_actualAlpha(actualAlpha)
|
||||
{}
|
||||
|
||||
void operator() (int start, int size) const
|
||||
{
|
||||
Gemm::run(m_lhs.rows(), size, m_lhs.cols(),
|
||||
(const Scalar*)&(m_lhs.const_cast_derived().coeffRef(0,0)), m_lhs.stride(),
|
||||
(const Scalar*)&(m_rhs.const_cast_derived().coeffRef(0,start)), m_rhs.stride(),
|
||||
(Scalar*)&(m_dest.coeffRef(0,start)), m_dest.stride(),
|
||||
m_actualAlpha);
|
||||
}
|
||||
|
||||
protected:
|
||||
const Lhs& m_lhs;
|
||||
const Rhs& m_rhs;
|
||||
mutable Dest& m_dest;
|
||||
Scalar m_actualAlpha;
|
||||
};
|
||||
|
||||
template<typename Lhs, typename Rhs>
|
||||
class GeneralProduct<Lhs, Rhs, GemmProduct>
|
||||
: public ProductBase<GeneralProduct<Lhs,Rhs,GemmProduct>, Lhs, Rhs>
|
||||
@ -151,17 +194,28 @@ class GeneralProduct<Lhs, Rhs, GemmProduct>
|
||||
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
|
||||
* RhsBlasTraits::extractScalarFactor(m_rhs);
|
||||
|
||||
ei_general_matrix_matrix_product<
|
||||
Scalar,
|
||||
(_ActualLhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
|
||||
(_ActualRhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
|
||||
(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>
|
||||
::run(
|
||||
this->rows(), this->cols(), lhs.cols(),
|
||||
(const Scalar*)&(lhs.const_cast_derived().coeffRef(0,0)), lhs.stride(),
|
||||
(const Scalar*)&(rhs.const_cast_derived().coeffRef(0,0)), rhs.stride(),
|
||||
(Scalar*)&(dst.coeffRef(0,0)), dst.stride(),
|
||||
actualAlpha);
|
||||
typedef ei_gemm_callback<Scalar,ei_general_matrix_matrix_product<
|
||||
Scalar,
|
||||
(_ActualLhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
|
||||
(_ActualRhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
|
||||
(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>,
|
||||
_ActualLhsType, _ActualRhsType, Dest> Functor;
|
||||
|
||||
#ifdef OMP
|
||||
ei_multithreaded_product<true>(Functor(lhs, rhs, dst, actualAlpha), this->cols());
|
||||
#else
|
||||
ei_general_matrix_matrix_product<
|
||||
Scalar,
|
||||
(_ActualLhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
|
||||
(_ActualRhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
|
||||
(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>
|
||||
::run(
|
||||
this->rows(), this->cols(), lhs.cols(),
|
||||
(const Scalar*)&(lhs.const_cast_derived().coeffRef(0,0)), lhs.stride(),
|
||||
(const Scalar*)&(rhs.const_cast_derived().coeffRef(0,0)), rhs.stride(),
|
||||
(Scalar*)&(dst.coeffRef(0,0)), dst.stride(),
|
||||
actualAlpha);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user