diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h index 6f7dee743..b6123ca8b 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrix.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h @@ -77,21 +77,19 @@ static void run(int rows, int cols, int depth, typedef typename ei_packet_traits::type PacketType; typedef ei_product_blocking_traits Blocking; -// int kc = std::min(Blocking::Max_kc,depth); // cache block size along the K direction -// int mc = std::min(Blocking::Max_mc,rows); // cache block size along the M direction - - int kc = std::min(256,depth); // cache block size along the K direction - int mc = std::min(512,rows); // cache block size along the M direction + int kc = std::min(Blocking::Max_kc,depth); // cache block size along the K direction + int mc = std::min(Blocking::Max_mc,rows); // cache block size along the M direction ei_gemm_pack_rhs pack_rhs; ei_gemm_pack_lhs pack_lhs; ei_gebp_kernel > gebp; - #ifdef EIGEN_HAS_OPENMP +#ifdef EIGEN_HAS_OPENMP if(info) { // this is the parallel version! int tid = omp_get_thread_num(); + int threads = omp_get_num_threads(); Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc); std::size_t sizeW = kc*Blocking::PacketSize*Blocking::nr*8; @@ -109,20 +107,48 @@ static void run(int rows, int cols, int depth, // (==GEMM_VAR1) for(int k=0; k rows of B', and cols of the A' + #ifndef USEGOTOROUTINES pack_rhs(blockB+info[tid].rhs_start*kc, &rhs(k,info[tid].rhs_start), rhsStride, alpha, actual_kc, info[tid].rhs_length); #else sgemm_oncopy(actual_kc, info[tid].rhs_length, &rhs(k,info[tid].rhs_start), rhsStride, blockB+info[tid].rhs_start*kc); #endif - #pragma omp barrier +#if 0 + // this is an attempt to implement a smarter strategy as suggested by Aron + // the layout is good, but there is no synchronization yet + { + const int actual_mc = mc; + // pack to A' + pack_lhs(blockA, &lhs(0,k), lhsStride, actual_kc, actual_mc); + + // use our current thread's B' part right away, no need to wait for the other threads + sgemm_kernel(actual_mc, info[tid].rhs_length, actual_kc, alpha, blockA, blockB+info[tid].rhs_start*kc, res+info[tid].rhs_start*resStride, resStride); + + for(int shift=1; shift