mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 18:54:00 +08:00
fix matrix product bug with OpenMP
This commit is contained in:
parent
8d27f55eb3
commit
1eea88bff7
@ -91,9 +91,10 @@ static void run(Index rows, Index cols, Index depth,
|
||||
// this is the parallel version!
|
||||
Index tid = omp_get_thread_num();
|
||||
Index threads = omp_get_num_threads();
|
||||
|
||||
LhsScalar* blockA = ei_aligned_stack_new(LhsScalar, kc*mc);
|
||||
|
||||
std::size_t sizeA = kc*mc;
|
||||
std::size_t sizeW = kc*Traits::WorkSpaceFactor;
|
||||
LhsScalar* blockA = ei_aligned_stack_new(LhsScalar, sizeA);
|
||||
RhsScalar* w = ei_aligned_stack_new(RhsScalar, sizeW);
|
||||
RhsScalar* blockB = blocking.blockB();
|
||||
eigen_internal_assert(blockB!=0);
|
||||
@ -116,7 +117,7 @@ static void run(Index rows, Index cols, Index depth,
|
||||
while(info[tid].users!=0) {}
|
||||
info[tid].users += threads;
|
||||
|
||||
pack_rhs(blockB+info[tid].rhs_start*kc, &rhs(k,info[tid].rhs_start), rhsStride, actual_kc, info[tid].rhs_length);
|
||||
pack_rhs(blockB+info[tid].rhs_start*actual_kc, &rhs(k,info[tid].rhs_start), rhsStride, actual_kc, info[tid].rhs_length);
|
||||
|
||||
// Notify the other threads that the part B'_j is ready to go.
|
||||
info[tid].sync = k;
|
||||
@ -132,7 +133,7 @@ static void run(Index rows, Index cols, Index depth,
|
||||
if(shift>0)
|
||||
while(info[j].sync!=k) {}
|
||||
|
||||
gebp(res+info[j].rhs_start*resStride, resStride, blockA, blockB+info[j].rhs_start*kc, mc, actual_kc, info[j].rhs_length, alpha, -1,-1,0,0, w);
|
||||
gebp(res+info[j].rhs_start*resStride, resStride, blockA, blockB+info[j].rhs_start*actual_kc, mc, actual_kc, info[j].rhs_length, alpha, -1,-1,0,0, w);
|
||||
}
|
||||
|
||||
// Then keep going as usual with the remaining A'
|
||||
@ -200,7 +201,7 @@ static void run(Index rows, Index cols, Index depth,
|
||||
}
|
||||
}
|
||||
|
||||
if(blocking.blockA()==0) ei_aligned_stack_delete(LhsScalar, blockA, kc*mc);
|
||||
if(blocking.blockA()==0) ei_aligned_stack_delete(LhsScalar, blockA, sizeA);
|
||||
if(blocking.blockB()==0) ei_aligned_stack_delete(RhsScalar, blockB, sizeB);
|
||||
if(blocking.blockW()==0) ei_aligned_stack_delete(RhsScalar, blockW, sizeW);
|
||||
}
|
||||
|
@ -124,7 +124,7 @@ void parallelize_gemm(const Functor& func, Index rows, Index cols, bool transpos
|
||||
|
||||
Index blockCols = (cols / threads) & ~Index(0x3);
|
||||
Index blockRows = (rows / threads) & ~Index(0x7);
|
||||
|
||||
|
||||
GemmParallelInfo<Index>* info = new GemmParallelInfo<Index>[threads];
|
||||
|
||||
#pragma omp parallel for schedule(static,1) num_threads(threads)
|
||||
|
Loading…
x
Reference in New Issue
Block a user