From 793e4c6d770e30bbc335f8cee6a8f388fc2c9330 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Tue, 23 Jun 2015 11:13:24 +0200 Subject: [PATCH] bug #923: fix EIGEN_USE_BLAS mode --- .../Core/products/GeneralMatrixVector_MKL.h | 25 +++++++++---------- .../products/TriangularMatrixMatrix_MKL.h | 4 +-- Eigen/src/Core/util/BlasUtil.h | 1 + 3 files changed, 15 insertions(+), 15 deletions(-) mode change 100644 => 100755 Eigen/src/Core/products/GeneralMatrixVector_MKL.h mode change 100644 => 100755 Eigen/src/Core/products/TriangularMatrixMatrix_MKL.h mode change 100644 => 100755 Eigen/src/Core/util/BlasUtil.h diff --git a/Eigen/src/Core/products/GeneralMatrixVector_MKL.h b/Eigen/src/Core/products/GeneralMatrixVector_MKL.h old mode 100644 new mode 100755 index 1cb9fe6b5..12c3d13bd --- a/Eigen/src/Core/products/GeneralMatrixVector_MKL.h +++ b/Eigen/src/Core/products/GeneralMatrixVector_MKL.h @@ -46,38 +46,37 @@ namespace internal { // gemv specialization -template -struct general_matrix_vector_product_gemv : - general_matrix_vector_product {}; +template +struct general_matrix_vector_product_gemv; #define EIGEN_MKL_GEMV_SPECIALIZE(Scalar) \ template \ -struct general_matrix_vector_product { \ +struct general_matrix_vector_product,ColMajor,ConjugateLhs,Scalar,const_blas_data_mapper,ConjugateRhs,Specialized> { \ static void run( \ Index rows, Index cols, \ - const Scalar* lhs, Index lhsStride, \ - const Scalar* rhs, Index rhsIncr, \ + const const_blas_data_mapper &lhs, \ + const const_blas_data_mapper &rhs, \ Scalar* res, Index resIncr, Scalar alpha) \ { \ if (ConjugateLhs) { \ - general_matrix_vector_product::run( \ - rows, cols, lhs, lhsStride, rhs, rhsIncr, res, resIncr, alpha); \ + general_matrix_vector_product,ColMajor,ConjugateLhs,Scalar,const_blas_data_mapper,ConjugateRhs,BuiltIn>::run( \ + rows, cols, lhs, rhs, res, resIncr, alpha); \ } else { \ general_matrix_vector_product_gemv::run( \ - rows, cols, lhs, lhsStride, rhs, rhsIncr, res, resIncr, alpha); \ + rows, cols, lhs.data(), lhs.stride(), rhs.data(), rhs.stride(), res, resIncr, alpha); \ } \ } \ }; \ template \ -struct general_matrix_vector_product { \ +struct general_matrix_vector_product,RowMajor,ConjugateLhs,Scalar,const_blas_data_mapper,ConjugateRhs,Specialized> { \ static void run( \ Index rows, Index cols, \ - const Scalar* lhs, Index lhsStride, \ - const Scalar* rhs, Index rhsIncr, \ + const const_blas_data_mapper &lhs, \ + const const_blas_data_mapper &rhs, \ Scalar* res, Index resIncr, Scalar alpha) \ { \ general_matrix_vector_product_gemv::run( \ - rows, cols, lhs, lhsStride, rhs, rhsIncr, res, resIncr, alpha); \ + rows, cols, lhs.data(), lhs.stride(), rhs.data(), rhs.stride(), res, resIncr, alpha); \ } \ }; \ diff --git a/Eigen/src/Core/products/TriangularMatrixMatrix_MKL.h b/Eigen/src/Core/products/TriangularMatrixMatrix_MKL.h old mode 100644 new mode 100755 index 4cc56a42f..d9e7cf852 --- a/Eigen/src/Core/products/TriangularMatrixMatrix_MKL.h +++ b/Eigen/src/Core/products/TriangularMatrixMatrix_MKL.h @@ -122,7 +122,7 @@ struct product_triangular_matrix_matrix_trmm > lhsMap(_lhs,rows,depth,OuterStride<>(lhsStride)); \ MatrixLhs aa_tmp=lhsMap.template triangularView(); \ MKL_INT aStride = aa_tmp.outerStride(); \ - gemm_blocking_space gemm_blocking(_rows,_cols,_depth); \ + gemm_blocking_space gemm_blocking(_rows,_cols,_depth, 1, true); \ general_matrix_matrix_product::run( \ rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, resStride, alpha, gemm_blocking, 0); \ \ @@ -236,7 +236,7 @@ struct product_triangular_matrix_matrix_trmm > rhsMap(_rhs,depth,cols, OuterStride<>(rhsStride)); \ MatrixRhs aa_tmp=rhsMap.template triangularView(); \ MKL_INT aStride = aa_tmp.outerStride(); \ - gemm_blocking_space gemm_blocking(_rows,_cols,_depth); \ + gemm_blocking_space gemm_blocking(_rows,_cols,_depth, 1, true); \ general_matrix_matrix_product::run( \ rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, resStride, alpha, gemm_blocking, 0); \ \ diff --git a/Eigen/src/Core/util/BlasUtil.h b/Eigen/src/Core/util/BlasUtil.h old mode 100644 new mode 100755 index ffeb5ac5f..934948ebd --- a/Eigen/src/Core/util/BlasUtil.h +++ b/Eigen/src/Core/util/BlasUtil.h @@ -224,6 +224,7 @@ class blas_data_mapper { } const Index stride() const { return m_stride; } + const Scalar* data() const { return m_data; } Index firstAligned(Index size) const { if (size_t(m_data)%sizeof(Scalar)) {