diff --git a/Eigen/src/Core/SolveTriangular.h b/Eigen/src/Core/SolveTriangular.h index fc8496a4e..ef09226a4 100644 --- a/Eigen/src/Core/SolveTriangular.h +++ b/Eigen/src/Core/SolveTriangular.h @@ -100,12 +100,22 @@ struct triangular_solver_selector typedef typename Rhs::Index Index; typedef blas_traits LhsProductTraits; typedef typename LhsProductTraits::DirectLinearAccessType ActualLhsType; + static void run(const Lhs& lhs, Rhs& rhs) { typename internal::add_const_on_value_type::type actualLhs = LhsProductTraits::extract(lhs); + + const Index size = lhs.rows(); + const Index othersize = Side==OnTheLeft? rhs.cols() : rhs.rows(); + + typedef internal::gemm_blocking_space<(Rhs::Flags&RowMajorBit) ? RowMajor : ColMajor,Scalar,Scalar, + Rhs::MaxRowsAtCompileTime, Rhs::MaxColsAtCompileTime, Lhs::MaxRowsAtCompileTime,4> BlockingType; + + BlockingType blocking(rhs.rows(), rhs.cols(), size); + triangular_solve_matrix - ::run(lhs.rows(), Side==OnTheLeft? rhs.cols() : rhs.rows(), &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &rhs.coeffRef(0,0), rhs.outerStride()); + ::run(size, othersize, &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &rhs.coeffRef(0,0), rhs.outerStride(), blocking); } }; diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h index 545beebc2..76fc32032 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrix.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h @@ -79,7 +79,7 @@ static void run(Index rows, Index cols, Index depth, typedef gebp_traits Traits; - Index kc = blocking.kc(); // cache block size along the K direction + Index kc = blocking.kc(); // cache block size along the K direction Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction //Index nc = blocking.nc(); // cache block size along the N direction @@ -249,7 +249,7 @@ struct gemm_functor BlockingType& m_blocking; }; -template class gemm_blocking_space; template @@ -282,8 +282,8 @@ class level3_blocking inline RhsScalar* blockW() { return m_blockW; } }; -template -class gemm_blocking_space +template +class gemm_blocking_space : public level3_blocking< typename conditional::type, typename conditional::type> @@ -324,8 +324,8 @@ class gemm_blocking_space -class gemm_blocking_space +template +class gemm_blocking_space : public level3_blocking< typename conditional::type, typename conditional::type> @@ -349,7 +349,7 @@ class gemm_blocking_spacem_nc = Transpose ? rows : cols; this->m_kc = depth; - computeProductBlockingSizes(this->m_kc, this->m_mc, this->m_nc); + computeProductBlockingSizes(this->m_kc, this->m_mc, this->m_nc); m_sizeA = this->m_mc * this->m_kc; m_sizeB = this->m_kc * this->m_nc; m_sizeW = this->m_kc*Traits::WorkSpaceFactor; diff --git a/Eigen/src/Core/products/TriangularSolverMatrix.h b/Eigen/src/Core/products/TriangularSolverMatrix.h index 4bba12cfe..0dd94638b 100644 --- a/Eigen/src/Core/products/TriangularSolverMatrix.h +++ b/Eigen/src/Core/products/TriangularSolverMatrix.h @@ -36,14 +36,15 @@ struct triangular_solve_matrix& blocking) { triangular_solve_matrix< Scalar, Index, Side==OnTheLeft?OnTheRight:OnTheLeft, (Mode&UnitDiag) | ((Mode&Upper) ? Lower : Upper), NumTraits::IsComplex && Conjugate, TriStorageOrder==RowMajor ? ColMajor : RowMajor, ColMajor> - ::run(size, cols, tri, triStride, _other, otherStride); + ::run(size, cols, tri, triStride, _other, otherStride, blocking); } }; @@ -55,7 +56,8 @@ struct triangular_solve_matrix& blocking) { Index cols = otherSize; const_blas_data_mapper tri(_tri,triStride); @@ -67,17 +69,16 @@ struct triangular_solve_matrix(kc, mc, nc); + Index kc = blocking.kc(); // cache block size along the K direction + Index mc = (std::min)(size,blocking.mc()); // cache block size along the M direction + std::size_t sizeA = kc*mc; + std::size_t sizeB = kc*cols; std::size_t sizeW = kc*Traits::WorkSpaceFactor; - std::size_t sizeB = sizeW + kc*cols; - ei_declare_aligned_stack_constructed_variable(Scalar, blockA, kc*mc, 0); - ei_declare_aligned_stack_constructed_variable(Scalar, allocatedBlockB, sizeB, 0); - Scalar* blockB = allocatedBlockB + sizeW; - Scalar* blockW = allocatedBlockB; + + ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA()); + ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB()); + ei_declare_aligned_stack_constructed_variable(Scalar, blockW, sizeW, blocking.blockW()); conj_if conj; gebp_kernel gebp_kernel; @@ -181,7 +182,7 @@ struct triangular_solve_matrix& blocking) { Index rows = otherSize; const_blas_data_mapper rhs(_tri,triStride); @@ -210,19 +212,16 @@ struct triangular_solve_matrix(Traits::Max_kc/4,size); // cache block size along the K direction -// Index mc = std::min(Traits::Max_mc,size); // cache block size along the M direction - // check that !!!! - Index kc = size; // cache block size along the K direction - Index mc = size; // cache block size along the M direction - Index nc = rows; // cache block size along the N direction - computeProductBlockingSizes(kc, mc, nc); + Index kc = blocking.kc(); // cache block size along the K direction + Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction + std::size_t sizeA = kc*mc; + std::size_t sizeB = kc*size; std::size_t sizeW = kc*Traits::WorkSpaceFactor; - std::size_t sizeB = sizeW + kc*size; - ei_declare_aligned_stack_constructed_variable(Scalar, blockA, kc*mc, 0); - ei_declare_aligned_stack_constructed_variable(Scalar, allocatedBlockB, sizeB, 0); - Scalar* blockB = allocatedBlockB + sizeW; + + ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA()); + ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB()); + ei_declare_aligned_stack_constructed_variable(Scalar, blockW, sizeW, blocking.blockW()); conj_if conj; gebp_kernel gebp_kernel; @@ -289,7 +288,7 @@ struct triangular_solve_matrix0) gebp_kernel(_other+i2+startPanel*otherStride, otherStride, blockA, geb, actual_mc, actual_kc, rs, Scalar(-1), - -1, -1, 0, 0, allocatedBlockB); + -1, -1, 0, 0, blockW); } } }