Add initial OpenMP support for matrix-matrix products

=> 1.9x speedup on my Core 2 Duo
This commit is contained in:
Gael Guennebaud 2010-02-22 09:40:34 +01:00
parent 1a70f3b48d
commit b20935be9b

View File

@ -128,6 +128,49 @@ struct ei_traits<GeneralProduct<Lhs,Rhs,GemmProduct> >
: ei_traits<ProductBase<GeneralProduct<Lhs,Rhs,GemmProduct>, Lhs, Rhs> >
{};
// Runs `func` over the index range [0,size), optionally split into blocks.
//
// Template parameters:
//   Parallelize - when false, `func(0,size)` is called once and nothing else
//                 happens.
//   Functor     - callable as `func(int blockStart, int blockSize)`.
//
// When Parallelize is true, the range is cut into contiguous, non-overlapping
// blocks that together cover [0,size) exactly, and `func` is invoked once per
// block from an OpenMP `parallel for` loop. The number of blocks is
// omp_get_num_procs() when the OMP macro is defined and 1 otherwise, clamped
// so that no block is empty.
//
// NOTE(review): the OMP branch relies on <omp.h> being included by the
// surrounding file - confirm.
template<bool Parallelize,typename Functor>
void ei_multithreaded_product(const Functor& func, int size)
{
  if(!Parallelize)
    return func(0,size);

#ifdef OMP
  int threads = omp_get_num_procs();
#else
  int threads = 1;
#endif

  // Never create more blocks than there are indices: with size < threads the
  // truncated quotient size/threads would be 0 and no work would be done.
  threads = std::min(threads, size);
  if(threads <= 1)
    return func(0,size);

  // size = threads*baseSize + remainder. The first `remainder` blocks get one
  // extra index so that the tail of the range is not silently dropped.
  const int baseSize  = size / threads;
  const int remainder = size % threads;

  #pragma omp parallel for schedule(static,1)
  for(int i=0; i<threads; ++i)
  {
    const int blockStart      = i*baseSize + std::min(i, remainder);
    const int actualBlockSize = baseSize + (i<remainder ? 1 : 0);
    func(blockStart, actualBlockSize);
  }
}
// Functor that binds the operands of a matrix-matrix product so that the
// product can be evaluated on a sub-range through `operator()(start, size)`.
//
// Template parameters:
//   Scalar - element type passed to the kernel.
//   Gemm   - kernel type; must provide a static
//            run(rows, cols, depth, lhs, lhsStride, rhs, rhsStride,
//                res, resStride, alpha).
//   Lhs, Rhs, Dest - operand/destination types; must provide rows()/cols()
//            (Lhs), stride(), and coeffRef(i,j) (Lhs and Rhs via
//            const_cast_derived()).
//
// The object only stores references: lhs, rhs and dest must outlive it.
template<typename Scalar, typename Gemm, typename Lhs, typename Rhs, typename Dest> struct ei_gemm_callback
{
  ei_gemm_callback(const Lhs& lhs, const Rhs& rhs, Dest& dest, Scalar actualAlpha)
    : m_lhs(lhs), m_rhs(rhs), m_dest(dest), m_actualAlpha(actualAlpha)
  {}

  // Calls Gemm::run on the whole lhs and on the parts of rhs and dest that
  // begin at second index `start`, with `size` as the kernel's second
  // dimension (presumably a block of `size` columns - confirm against Gemm).
  void operator() (int start, int size) const
  {
    Gemm::run(m_lhs.rows(), size, m_lhs.cols(),
              (const Scalar*)&(m_lhs.const_cast_derived().coeffRef(0,0)), m_lhs.stride(),
              (const Scalar*)&(m_rhs.const_cast_derived().coeffRef(0,start)), m_rhs.stride(),
              (Scalar*)&(m_dest.coeffRef(0,start)), m_dest.stride(),
              m_actualAlpha);
  }

protected:
  const Lhs& m_lhs;
  const Rhs& m_rhs;
  // Plain reference: `mutable` is not allowed on reference members, and a
  // const member function can already write through a non-const reference.
  Dest& m_dest;
  Scalar m_actualAlpha;
};
template<typename Lhs, typename Rhs>
class GeneralProduct<Lhs, Rhs, GemmProduct>
: public ProductBase<GeneralProduct<Lhs,Rhs,GemmProduct>, Lhs, Rhs>
@ -151,17 +194,28 @@ class GeneralProduct<Lhs, Rhs, GemmProduct>
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
* RhsBlasTraits::extractScalarFactor(m_rhs);
ei_general_matrix_matrix_product<
Scalar,
(_ActualLhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
(_ActualRhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>
::run(
this->rows(), this->cols(), lhs.cols(),
(const Scalar*)&(lhs.const_cast_derived().coeffRef(0,0)), lhs.stride(),
(const Scalar*)&(rhs.const_cast_derived().coeffRef(0,0)), rhs.stride(),
(Scalar*)&(dst.coeffRef(0,0)), dst.stride(),
actualAlpha);
typedef ei_gemm_callback<Scalar,ei_general_matrix_matrix_product<
Scalar,
(_ActualLhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
(_ActualRhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>,
_ActualLhsType, _ActualRhsType, Dest> Functor;
#ifdef OMP
ei_multithreaded_product<true>(Functor(lhs, rhs, dst, actualAlpha), this->cols());
#else
ei_general_matrix_matrix_product<
Scalar,
(_ActualLhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
(_ActualRhsType::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>
::run(
this->rows(), this->cols(), lhs.cols(),
(const Scalar*)&(lhs.const_cast_derived().coeffRef(0,0)), lhs.stride(),
(const Scalar*)&(rhs.const_cast_derived().coeffRef(0,0)), rhs.stride(),
(Scalar*)&(dst.coeffRef(0,0)), dst.stride(),
actualAlpha);
#endif
}
};