Fix OpenMP parallelization of GEMM for row-major destinations

This commit is contained in:
Gael Guennebaud 2010-07-03 12:52:39 +02:00
parent 0d9dc578dd
commit be1fdbf3af
2 changed files with 14 additions and 6 deletions

View File

@ -373,8 +373,8 @@ class ei_gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols
void allocateW()
{
if(this->m_blockB==0)
this->m_blockB = ei_aligned_new<RhsScalar>(m_sizeB);
if(this->m_blockW==0)
this->m_blockW = ei_aligned_new<RhsScalar>(m_sizeW);
}
void allocateAll()
@ -432,7 +432,7 @@ class GeneralProduct<Lhs, Rhs, GemmProduct>
BlockingType blocking(dst.rows(), dst.cols(), lhs.cols());
ei_parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==Dynamic)>(GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), this->rows(), this->cols());
ei_parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==Dynamic)>(GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), this->rows(), this->cols(), Dest::Flags&RowMajorBit);
}
};

View File

@ -81,7 +81,7 @@ template<typename Index> struct GemmParallelInfo
};
template<bool Condition, typename Functor, typename Index>
void ei_parallelize_gemm(const Functor& func, Index rows, Index cols)
void ei_parallelize_gemm(const Functor& func, Index rows, Index cols, bool transpose)
{
#ifndef EIGEN_HAS_OPENMP
func(0,rows, 0,cols);
@ -98,9 +98,11 @@ void ei_parallelize_gemm(const Functor& func, Index rows, Index cols)
if((!Condition) || (omp_get_num_threads()>1))
return func(0,rows, 0,cols);
Index size = transpose ? cols : rows;
// 2- compute the maximal number of threads from the size of the product:
// FIXME this has to be fine tuned
Index max_threads = std::max<Index>(1,rows / 32);
Index max_threads = std::max<Index>(1,size / 32);
// 3 - compute the number of threads we are going to use
Index threads = std::min<Index>(nbThreads(), max_threads);
@ -110,6 +112,9 @@ void ei_parallelize_gemm(const Functor& func, Index rows, Index cols)
func.initParallelSession();
if(transpose)
std::swap(rows,cols);
Index blockCols = (cols / threads) & ~Index(0x3);
Index blockRows = (rows / threads) & ~Index(0x7);
@ -127,7 +132,10 @@ void ei_parallelize_gemm(const Functor& func, Index rows, Index cols)
info[i].rhs_start = c0;
info[i].rhs_length = actualBlockCols;
func(r0, actualBlockRows, 0,cols, info);
if(transpose)
func(0, cols, r0, actualBlockRows, info);
else
func(r0, actualBlockRows, 0,cols, info);
}
delete[] info;