mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-05-10 23:09:06 +08:00
fix openmp for row major destination
This commit is contained in:
parent
0d9dc578dd
commit
be1fdbf3af
@ -373,8 +373,8 @@ class ei_gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols
|
|||||||
|
|
||||||
void allocateW()
|
void allocateW()
|
||||||
{
|
{
|
||||||
if(this->m_blockB==0)
|
if(this->m_blockW==0)
|
||||||
this->m_blockB = ei_aligned_new<RhsScalar>(m_sizeB);
|
this->m_blockW = ei_aligned_new<RhsScalar>(m_sizeW);
|
||||||
}
|
}
|
||||||
|
|
||||||
void allocateAll()
|
void allocateAll()
|
||||||
@ -432,7 +432,7 @@ class GeneralProduct<Lhs, Rhs, GemmProduct>
|
|||||||
|
|
||||||
BlockingType blocking(dst.rows(), dst.cols(), lhs.cols());
|
BlockingType blocking(dst.rows(), dst.cols(), lhs.cols());
|
||||||
|
|
||||||
ei_parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==Dynamic)>(GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), this->rows(), this->cols());
|
ei_parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==Dynamic)>(GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), this->rows(), this->cols(), Dest::Flags&RowMajorBit);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -81,7 +81,7 @@ template<typename Index> struct GemmParallelInfo
|
|||||||
};
|
};
|
||||||
|
|
||||||
template<bool Condition, typename Functor, typename Index>
|
template<bool Condition, typename Functor, typename Index>
|
||||||
void ei_parallelize_gemm(const Functor& func, Index rows, Index cols)
|
void ei_parallelize_gemm(const Functor& func, Index rows, Index cols, bool transpose)
|
||||||
{
|
{
|
||||||
#ifndef EIGEN_HAS_OPENMP
|
#ifndef EIGEN_HAS_OPENMP
|
||||||
func(0,rows, 0,cols);
|
func(0,rows, 0,cols);
|
||||||
@ -98,9 +98,11 @@ void ei_parallelize_gemm(const Functor& func, Index rows, Index cols)
|
|||||||
if((!Condition) || (omp_get_num_threads()>1))
|
if((!Condition) || (omp_get_num_threads()>1))
|
||||||
return func(0,rows, 0,cols);
|
return func(0,rows, 0,cols);
|
||||||
|
|
||||||
|
Index size = transpose ? cols : rows;
|
||||||
|
|
||||||
// 2- compute the maximal number of threads from the size of the product:
|
// 2- compute the maximal number of threads from the size of the product:
|
||||||
// FIXME this has to be fine tuned
|
// FIXME this has to be fine tuned
|
||||||
Index max_threads = std::max<Index>(1,rows / 32);
|
Index max_threads = std::max<Index>(1,size / 32);
|
||||||
|
|
||||||
// 3 - compute the number of threads we are going to use
|
// 3 - compute the number of threads we are going to use
|
||||||
Index threads = std::min<Index>(nbThreads(), max_threads);
|
Index threads = std::min<Index>(nbThreads(), max_threads);
|
||||||
@ -110,6 +112,9 @@ void ei_parallelize_gemm(const Functor& func, Index rows, Index cols)
|
|||||||
|
|
||||||
func.initParallelSession();
|
func.initParallelSession();
|
||||||
|
|
||||||
|
if(transpose)
|
||||||
|
std::swap(rows,cols);
|
||||||
|
|
||||||
Index blockCols = (cols / threads) & ~Index(0x3);
|
Index blockCols = (cols / threads) & ~Index(0x3);
|
||||||
Index blockRows = (rows / threads) & ~Index(0x7);
|
Index blockRows = (rows / threads) & ~Index(0x7);
|
||||||
|
|
||||||
@ -127,7 +132,10 @@ void ei_parallelize_gemm(const Functor& func, Index rows, Index cols)
|
|||||||
info[i].rhs_start = c0;
|
info[i].rhs_start = c0;
|
||||||
info[i].rhs_length = actualBlockCols;
|
info[i].rhs_length = actualBlockCols;
|
||||||
|
|
||||||
func(r0, actualBlockRows, 0,cols, info);
|
if(transpose)
|
||||||
|
func(0, cols, r0, actualBlockRows, info);
|
||||||
|
else
|
||||||
|
func(r0, actualBlockRows, 0,cols, info);
|
||||||
}
|
}
|
||||||
|
|
||||||
delete[] info;
|
delete[] info;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user