bug #923: fix EIGEN_USE_BLAS mode

This commit is contained in:
Gael Guennebaud 2015-06-23 11:13:24 +02:00
parent 307c4fc292
commit 793e4c6d77
3 changed files with 15 additions and 15 deletions

25
Eigen/src/Core/products/GeneralMatrixVector_MKL.h Normal file → Executable file
View File

@ -46,38 +46,37 @@ namespace internal {
// gemv specialization // gemv specialization
template<typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, typename RhsScalar, bool ConjugateRhs> template<typename Index, typename LhsScalar, int StorageOrder, bool ConjugateLhs, typename RhsScalar, bool ConjugateRhs>
struct general_matrix_vector_product_gemv : struct general_matrix_vector_product_gemv;
general_matrix_vector_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,ConjugateRhs,BuiltIn> {};
#define EIGEN_MKL_GEMV_SPECIALIZE(Scalar) \ #define EIGEN_MKL_GEMV_SPECIALIZE(Scalar) \
template<typename Index, bool ConjugateLhs, bool ConjugateRhs> \ template<typename Index, bool ConjugateLhs, bool ConjugateRhs> \
struct general_matrix_vector_product<Index,Scalar,ColMajor,ConjugateLhs,Scalar,ConjugateRhs,Specialized> { \ struct general_matrix_vector_product<Index,Scalar,const_blas_data_mapper<Scalar,Index,ColMajor>,ColMajor,ConjugateLhs,Scalar,const_blas_data_mapper<Scalar,Index,RowMajor>,ConjugateRhs,Specialized> { \
static void run( \ static void run( \
Index rows, Index cols, \ Index rows, Index cols, \
const Scalar* lhs, Index lhsStride, \ const const_blas_data_mapper<Scalar,Index,ColMajor> &lhs, \
const Scalar* rhs, Index rhsIncr, \ const const_blas_data_mapper<Scalar,Index,RowMajor> &rhs, \
Scalar* res, Index resIncr, Scalar alpha) \ Scalar* res, Index resIncr, Scalar alpha) \
{ \ { \
if (ConjugateLhs) { \ if (ConjugateLhs) { \
general_matrix_vector_product<Index,Scalar,ColMajor,ConjugateLhs,Scalar,ConjugateRhs,BuiltIn>::run( \ general_matrix_vector_product<Index,Scalar,const_blas_data_mapper<Scalar,Index,ColMajor>,ColMajor,ConjugateLhs,Scalar,const_blas_data_mapper<Scalar,Index,RowMajor>,ConjugateRhs,BuiltIn>::run( \
rows, cols, lhs, lhsStride, rhs, rhsIncr, res, resIncr, alpha); \ rows, cols, lhs, rhs, res, resIncr, alpha); \
} else { \ } else { \
general_matrix_vector_product_gemv<Index,Scalar,ColMajor,ConjugateLhs,Scalar,ConjugateRhs>::run( \ general_matrix_vector_product_gemv<Index,Scalar,ColMajor,ConjugateLhs,Scalar,ConjugateRhs>::run( \
rows, cols, lhs, lhsStride, rhs, rhsIncr, res, resIncr, alpha); \ rows, cols, lhs.data(), lhs.stride(), rhs.data(), rhs.stride(), res, resIncr, alpha); \
} \ } \
} \ } \
}; \ }; \
template<typename Index, bool ConjugateLhs, bool ConjugateRhs> \ template<typename Index, bool ConjugateLhs, bool ConjugateRhs> \
struct general_matrix_vector_product<Index,Scalar,RowMajor,ConjugateLhs,Scalar,ConjugateRhs,Specialized> { \ struct general_matrix_vector_product<Index,Scalar,const_blas_data_mapper<Scalar,Index,RowMajor>,RowMajor,ConjugateLhs,Scalar,const_blas_data_mapper<Scalar,Index,ColMajor>,ConjugateRhs,Specialized> { \
static void run( \ static void run( \
Index rows, Index cols, \ Index rows, Index cols, \
const Scalar* lhs, Index lhsStride, \ const const_blas_data_mapper<Scalar,Index,RowMajor> &lhs, \
const Scalar* rhs, Index rhsIncr, \ const const_blas_data_mapper<Scalar,Index,ColMajor> &rhs, \
Scalar* res, Index resIncr, Scalar alpha) \ Scalar* res, Index resIncr, Scalar alpha) \
{ \ { \
general_matrix_vector_product_gemv<Index,Scalar,RowMajor,ConjugateLhs,Scalar,ConjugateRhs>::run( \ general_matrix_vector_product_gemv<Index,Scalar,RowMajor,ConjugateLhs,Scalar,ConjugateRhs>::run( \
rows, cols, lhs, lhsStride, rhs, rhsIncr, res, resIncr, alpha); \ rows, cols, lhs.data(), lhs.stride(), rhs.data(), rhs.stride(), res, resIncr, alpha); \
} \ } \
}; \ }; \

4
Eigen/src/Core/products/TriangularMatrixMatrix_MKL.h Normal file → Executable file
View File

@ -122,7 +122,7 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
Map<const MatrixLhs, 0, OuterStride<> > lhsMap(_lhs,rows,depth,OuterStride<>(lhsStride)); \ Map<const MatrixLhs, 0, OuterStride<> > lhsMap(_lhs,rows,depth,OuterStride<>(lhsStride)); \
MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); \ MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); \
MKL_INT aStride = aa_tmp.outerStride(); \ MKL_INT aStride = aa_tmp.outerStride(); \
gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth); \ gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \ general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, resStride, alpha, gemm_blocking, 0); \ rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, resStride, alpha, gemm_blocking, 0); \
\ \
@ -236,7 +236,7 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
Map<const MatrixRhs, 0, OuterStride<> > rhsMap(_rhs,depth,cols, OuterStride<>(rhsStride)); \ Map<const MatrixRhs, 0, OuterStride<> > rhsMap(_rhs,depth,cols, OuterStride<>(rhsStride)); \
MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); \ MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); \
MKL_INT aStride = aa_tmp.outerStride(); \ MKL_INT aStride = aa_tmp.outerStride(); \
gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth); \ gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \ general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, resStride, alpha, gemm_blocking, 0); \ rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, resStride, alpha, gemm_blocking, 0); \
\ \

1
Eigen/src/Core/util/BlasUtil.h Normal file → Executable file
View File

@ -224,6 +224,7 @@ class blas_data_mapper {
} }
const Index stride() const { return m_stride; } const Index stride() const { return m_stride; }
const Scalar* data() const { return m_data; }
Index firstAligned(Index size) const { Index firstAligned(Index size) const {
if (size_t(m_data)%sizeof(Scalar)) { if (size_t(m_data)%sizeof(Scalar)) {