fix matrix product bug with OpenMP

This commit is contained in:
Gael Guennebaud 2010-11-03 16:12:37 +01:00
parent 8d27f55eb3
commit 1eea88bff7
2 changed files with 7 additions and 6 deletions

View File

@ -91,9 +91,10 @@ static void run(Index rows, Index cols, Index depth,
// this is the parallel version!
Index tid = omp_get_thread_num();
Index threads = omp_get_num_threads();
LhsScalar* blockA = ei_aligned_stack_new(LhsScalar, kc*mc);
std::size_t sizeA = kc*mc;
std::size_t sizeW = kc*Traits::WorkSpaceFactor;
LhsScalar* blockA = ei_aligned_stack_new(LhsScalar, sizeA);
RhsScalar* w = ei_aligned_stack_new(RhsScalar, sizeW);
RhsScalar* blockB = blocking.blockB();
eigen_internal_assert(blockB!=0);
@ -116,7 +117,7 @@ static void run(Index rows, Index cols, Index depth,
while(info[tid].users!=0) {}
info[tid].users += threads;
pack_rhs(blockB+info[tid].rhs_start*kc, &rhs(k,info[tid].rhs_start), rhsStride, actual_kc, info[tid].rhs_length);
pack_rhs(blockB+info[tid].rhs_start*actual_kc, &rhs(k,info[tid].rhs_start), rhsStride, actual_kc, info[tid].rhs_length);
// Notify the other threads that the part B'_j is ready to go.
info[tid].sync = k;
@ -132,7 +133,7 @@ static void run(Index rows, Index cols, Index depth,
if(shift>0)
while(info[j].sync!=k) {}
gebp(res+info[j].rhs_start*resStride, resStride, blockA, blockB+info[j].rhs_start*kc, mc, actual_kc, info[j].rhs_length, alpha, -1,-1,0,0, w);
gebp(res+info[j].rhs_start*resStride, resStride, blockA, blockB+info[j].rhs_start*actual_kc, mc, actual_kc, info[j].rhs_length, alpha, -1,-1,0,0, w);
}
// Then keep going as usual with the remaining A'
@ -200,7 +201,7 @@ static void run(Index rows, Index cols, Index depth,
}
}
if(blocking.blockA()==0) ei_aligned_stack_delete(LhsScalar, blockA, kc*mc);
if(blocking.blockA()==0) ei_aligned_stack_delete(LhsScalar, blockA, sizeA);
if(blocking.blockB()==0) ei_aligned_stack_delete(RhsScalar, blockB, sizeB);
if(blocking.blockW()==0) ei_aligned_stack_delete(RhsScalar, blockW, sizeW);
}

View File

@ -124,7 +124,7 @@ void parallelize_gemm(const Functor& func, Index rows, Index cols, bool transpos
Index blockCols = (cols / threads) & ~Index(0x3);
Index blockRows = (rows / threads) & ~Index(0x7);
GemmParallelInfo<Index>* info = new GemmParallelInfo<Index>[threads];
#pragma omp parallel for schedule(static,1) num_threads(threads)