Fix long to int conversion in BLAS API.

This commit is contained in:
Gael Guennebaud 2016-04-11 15:52:01 +02:00
parent 8191f373be
commit ddabc992fa
9 changed files with 123 additions and 120 deletions

View File

@ -72,7 +72,7 @@ EIGEN_MKL_RANKUPDATE_SPECIALIZE(float)
// EIGEN_MKL_RANKUPDATE_SPECIALIZE(scomplex) // EIGEN_MKL_RANKUPDATE_SPECIALIZE(scomplex)
// SYRK for float/double // SYRK for float/double
#define EIGEN_MKL_RANKUPDATE_R(EIGTYPE, MKLTYPE, MKLFUNC) \ #define EIGEN_MKL_RANKUPDATE_R(EIGTYPE, BLASTYPE, MKLFUNC) \
template <typename Index, int AStorageOrder, bool ConjugateA, int UpLo> \ template <typename Index, int AStorageOrder, bool ConjugateA, int UpLo> \
struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,ColMajor,UpLo> { \ struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,ColMajor,UpLo> { \
enum { \ enum { \
@ -85,19 +85,19 @@ struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,C
{ \ { \
/* typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs;*/ \ /* typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs;*/ \
\ \
MKL_INT lda=lhsStride, ldc=resStride, n=size, k=depth; \ BlasIndex lda=convert_index<BlasIndex>(lhsStride), ldc=convert_index<BlasIndex>(resStride), n=convert_index<BlasIndex>(size), k=convert_index<BlasIndex>(depth); \
char uplo=(IsLower) ? 'L' : 'U', trans=(AStorageOrder==RowMajor) ? 'T':'N'; \ char uplo=(IsLower) ? 'L' : 'U', trans=(AStorageOrder==RowMajor) ? 'T':'N'; \
MKLTYPE alpha_, beta_; \ BLASTYPE alpha_, beta_; \
\ \
/* Set alpha_ & beta_ */ \ /* Set alpha_ & beta_ */ \
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \ assign_scalar_eig2mkl<BLASTYPE, EIGTYPE>(alpha_, alpha); \
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(beta_, EIGTYPE(1)); \ assign_scalar_eig2mkl<BLASTYPE, EIGTYPE>(beta_, EIGTYPE(1)); \
MKLFUNC(&uplo, &trans, &n, &k, &alpha_, lhs, &lda, &beta_, res, &ldc); \ MKLFUNC(&uplo, &trans, &n, &k, &alpha_, lhs, &lda, &beta_, res, &ldc); \
} \ } \
}; };
// HERK for complex data // HERK for complex data
#define EIGEN_MKL_RANKUPDATE_C(EIGTYPE, MKLTYPE, RTYPE, MKLFUNC) \ #define EIGEN_MKL_RANKUPDATE_C(EIGTYPE, BLASTYPE, RTYPE, MKLFUNC) \
template <typename Index, int AStorageOrder, bool ConjugateA, int UpLo> \ template <typename Index, int AStorageOrder, bool ConjugateA, int UpLo> \
struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,ColMajor,UpLo> { \ struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,ColMajor,UpLo> { \
enum { \ enum { \
@ -110,14 +110,14 @@ struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,C
{ \ { \
typedef Matrix<EIGTYPE, Dynamic, Dynamic, AStorageOrder> MatrixType; \ typedef Matrix<EIGTYPE, Dynamic, Dynamic, AStorageOrder> MatrixType; \
\ \
MKL_INT lda=lhsStride, ldc=resStride, n=size, k=depth; \ BlasIndex lda=convert_index<BlasIndex>(lhsStride), ldc=convert_index<BlasIndex>(resStride), n=convert_index<BlasIndex>(size), k=convert_index<BlasIndex>(depth); \
char uplo=(IsLower) ? 'L' : 'U', trans=(AStorageOrder==RowMajor) ? 'C':'N'; \ char uplo=(IsLower) ? 'L' : 'U', trans=(AStorageOrder==RowMajor) ? 'C':'N'; \
RTYPE alpha_, beta_; \ RTYPE alpha_, beta_; \
const EIGTYPE* a_ptr; \ const EIGTYPE* a_ptr; \
\ \
/* Set alpha_ & beta_ */ \ /* Set alpha_ & beta_ */ \
/* assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); */\ /* assign_scalar_eig2mkl<BLASTYPE, EIGTYPE>(alpha_, alpha); */\
/* assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(beta_, EIGTYPE(1));*/ \ /* assign_scalar_eig2mkl<BLASTYPE, EIGTYPE>(beta_, EIGTYPE(1));*/ \
alpha_ = alpha.real(); \ alpha_ = alpha.real(); \
beta_ = 1.0; \ beta_ = 1.0; \
/* Copy with conjugation in some cases*/ \ /* Copy with conjugation in some cases*/ \
@ -128,7 +128,7 @@ struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,C
lda = a.outerStride(); \ lda = a.outerStride(); \
a_ptr = a.data(); \ a_ptr = a.data(); \
} else a_ptr=lhs; \ } else a_ptr=lhs; \
MKLFUNC(&uplo, &trans, &n, &k, &alpha_, (MKLTYPE*)a_ptr, &lda, &beta_, (MKLTYPE*)res, &ldc); \ MKLFUNC(&uplo, &trans, &n, &k, &alpha_, (BLASTYPE*)a_ptr, &lda, &beta_, (BLASTYPE*)res, &ldc); \
} \ } \
}; };

View File

@ -46,7 +46,7 @@ namespace internal {
// gemm specialization // gemm specialization
#define GEMM_SPECIALIZATION(EIGTYPE, EIGPREFIX, MKLTYPE, MKLPREFIX) \ #define GEMM_SPECIALIZATION(EIGTYPE, EIGPREFIX, BLASTYPE, MKLPREFIX) \
template< \ template< \
typename Index, \ typename Index, \
int LhsStorageOrder, bool ConjugateLhs, \ int LhsStorageOrder, bool ConjugateLhs, \
@ -66,7 +66,7 @@ static void run(Index rows, Index cols, Index depth, \
using std::conj; \ using std::conj; \
\ \
char transa, transb; \ char transa, transb; \
MKL_INT m, n, k, lda, ldb, ldc; \ BlasIndex m, n, k, lda, ldb, ldc; \
const EIGTYPE *a, *b; \ const EIGTYPE *a, *b; \
EIGTYPE beta(1); \ EIGTYPE beta(1); \
MatrixX##EIGPREFIX a_tmp, b_tmp; \ MatrixX##EIGPREFIX a_tmp, b_tmp; \
@ -76,31 +76,31 @@ static void run(Index rows, Index cols, Index depth, \
transb = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \ transb = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \
\ \
/* Set m, n, k */ \ /* Set m, n, k */ \
m = (MKL_INT)rows; \ m = convert_index<BlasIndex>(rows); \
n = (MKL_INT)cols; \ n = convert_index<BlasIndex>(cols); \
k = (MKL_INT)depth; \ k = convert_index<BlasIndex>(depth); \
\ \
/* Set lda, ldb, ldc */ \ /* Set lda, ldb, ldc */ \
lda = (MKL_INT)lhsStride; \ lda = convert_index<BlasIndex>(lhsStride); \
ldb = (MKL_INT)rhsStride; \ ldb = convert_index<BlasIndex>(rhsStride); \
ldc = (MKL_INT)resStride; \ ldc = convert_index<BlasIndex>(resStride); \
\ \
/* Set a, b, c */ \ /* Set a, b, c */ \
if ((LhsStorageOrder==ColMajor) && (ConjugateLhs)) { \ if ((LhsStorageOrder==ColMajor) && (ConjugateLhs)) { \
Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > lhs(_lhs,m,k,OuterStride<>(lhsStride)); \ Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > lhs(_lhs,m,k,OuterStride<>(lhsStride)); \
a_tmp = lhs.conjugate(); \ a_tmp = lhs.conjugate(); \
a = a_tmp.data(); \ a = a_tmp.data(); \
lda = a_tmp.outerStride(); \ lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
} else a = _lhs; \ } else a = _lhs; \
\ \
if ((RhsStorageOrder==ColMajor) && (ConjugateRhs)) { \ if ((RhsStorageOrder==ColMajor) && (ConjugateRhs)) { \
Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > rhs(_rhs,k,n,OuterStride<>(rhsStride)); \ Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > rhs(_rhs,k,n,OuterStride<>(rhsStride)); \
b_tmp = rhs.conjugate(); \ b_tmp = rhs.conjugate(); \
b = b_tmp.data(); \ b = b_tmp.data(); \
ldb = b_tmp.outerStride(); \ ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
} else b = _rhs; \ } else b = _rhs; \
\ \
MKLPREFIX##gemm_(&transa, &transb, &m, &n, &k, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &numext::real_ref(beta), (MKLTYPE*)res, &ldc); \ MKLPREFIX##gemm_(&transa, &transb, &m, &n, &k, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
}}; }};
GEMM_SPECIALIZATION(double, d, double, d) GEMM_SPECIALIZATION(double, d, double, d)

View File

@ -85,7 +85,7 @@ EIGEN_MKL_GEMV_SPECIALIZE(float)
EIGEN_MKL_GEMV_SPECIALIZE(dcomplex) EIGEN_MKL_GEMV_SPECIALIZE(dcomplex)
EIGEN_MKL_GEMV_SPECIALIZE(scomplex) EIGEN_MKL_GEMV_SPECIALIZE(scomplex)
#define EIGEN_MKL_GEMV_SPECIALIZATION(EIGTYPE,MKLTYPE,MKLPREFIX) \ #define EIGEN_MKL_GEMV_SPECIALIZATION(EIGTYPE,BLASTYPE,MKLPREFIX) \
template<typename Index, int LhsStorageOrder, bool ConjugateLhs, bool ConjugateRhs> \ template<typename Index, int LhsStorageOrder, bool ConjugateLhs, bool ConjugateRhs> \
struct general_matrix_vector_product_gemv<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,ConjugateRhs> \ struct general_matrix_vector_product_gemv<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,ConjugateRhs> \
{ \ { \
@ -97,13 +97,14 @@ static void run( \
const EIGTYPE* rhs, Index rhsIncr, \ const EIGTYPE* rhs, Index rhsIncr, \
EIGTYPE* res, Index resIncr, EIGTYPE alpha) \ EIGTYPE* res, Index resIncr, EIGTYPE alpha) \
{ \ { \
MKL_INT m=rows, n=cols, lda=lhsStride, incx=rhsIncr, incy=resIncr; \ BlasIndex m=convert_index<BlasIndex>(rows), n=convert_index<BlasIndex>(cols), \
lda=convert_index<BlasIndex>(lhsStride), incx=convert_index<BlasIndex>(rhsIncr), incy=convert_index<BlasIndex>(resIncr); \
const EIGTYPE beta(1); \ const EIGTYPE beta(1); \
const EIGTYPE *x_ptr; \ const EIGTYPE *x_ptr; \
char trans=(LhsStorageOrder==ColMajor) ? 'N' : (ConjugateLhs) ? 'C' : 'T'; \ char trans=(LhsStorageOrder==ColMajor) ? 'N' : (ConjugateLhs) ? 'C' : 'T'; \
if (LhsStorageOrder==RowMajor) { \ if (LhsStorageOrder==RowMajor) { \
m = cols; \ m = convert_index<BlasIndex>(cols); \
n = rows; \ n = convert_index<BlasIndex>(rows); \
}\ }\
GEMVVector x_tmp; \ GEMVVector x_tmp; \
if (ConjugateRhs) { \ if (ConjugateRhs) { \
@ -112,7 +113,7 @@ static void run( \
x_ptr=x_tmp.data(); \ x_ptr=x_tmp.data(); \
incx=1; \ incx=1; \
} else x_ptr=rhs; \ } else x_ptr=rhs; \
MKLPREFIX##gemv_(&trans, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)lhs, &lda, (const MKLTYPE*)x_ptr, &incx, &numext::real_ref(beta), (MKLTYPE*)res, &incy); \ MKLPREFIX##gemv_(&trans, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)lhs, &lda, (const BLASTYPE*)x_ptr, &incx, &numext::real_ref(beta), (BLASTYPE*)res, &incy); \
}\ }\
}; };

View File

@ -40,7 +40,7 @@ namespace internal {
/* Optimized selfadjoint matrix * matrix (?SYMM/?HEMM) product */ /* Optimized selfadjoint matrix * matrix (?SYMM/?HEMM) product */
#define EIGEN_MKL_SYMM_L(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \ #define EIGEN_MKL_SYMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, MKLPREFIX) \
template <typename Index, \ template <typename Index, \
int LhsStorageOrder, bool ConjugateLhs, \ int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \ int RhsStorageOrder, bool ConjugateRhs> \
@ -55,20 +55,20 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLh
EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \ EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
{ \ { \
char side='L', uplo='L'; \ char side='L', uplo='L'; \
MKL_INT m, n, lda, ldb, ldc; \ BlasIndex m, n, lda, ldb, ldc; \
const EIGTYPE *a, *b; \ const EIGTYPE *a, *b; \
EIGTYPE beta(1); \ EIGTYPE beta(1); \
MatrixX##EIGPREFIX b_tmp; \ MatrixX##EIGPREFIX b_tmp; \
\ \
/* Set transpose options */ \ /* Set transpose options */ \
/* Set m, n, k */ \ /* Set m, n, k */ \
m = (MKL_INT)rows; \ m = convert_index<BlasIndex>(rows); \
n = (MKL_INT)cols; \ n = convert_index<BlasIndex>(cols); \
\ \
/* Set lda, ldb, ldc */ \ /* Set lda, ldb, ldc */ \
lda = (MKL_INT)lhsStride; \ lda = convert_index<BlasIndex>(lhsStride); \
ldb = (MKL_INT)rhsStride; \ ldb = convert_index<BlasIndex>(rhsStride); \
ldc = (MKL_INT)resStride; \ ldc = convert_index<BlasIndex>(resStride); \
\ \
/* Set a, b, c */ \ /* Set a, b, c */ \
if (LhsStorageOrder==RowMajor) uplo='U'; \ if (LhsStorageOrder==RowMajor) uplo='U'; \
@ -78,16 +78,16 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLh
Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > rhs(_rhs,n,m,OuterStride<>(rhsStride)); \ Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > rhs(_rhs,n,m,OuterStride<>(rhsStride)); \
b_tmp = rhs.adjoint(); \ b_tmp = rhs.adjoint(); \
b = b_tmp.data(); \ b = b_tmp.data(); \
ldb = b_tmp.outerStride(); \ ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
} else b = _rhs; \ } else b = _rhs; \
\ \
MKLPREFIX##symm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &numext::real_ref(beta), (MKLTYPE*)res, &ldc); \ MKLPREFIX##symm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
\ \
} \ } \
}; };
#define EIGEN_MKL_HEMM_L(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \ #define EIGEN_MKL_HEMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, MKLPREFIX) \
template <typename Index, \ template <typename Index, \
int LhsStorageOrder, bool ConjugateLhs, \ int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \ int RhsStorageOrder, bool ConjugateRhs> \
@ -101,7 +101,7 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLh
EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \ EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
{ \ { \
char side='L', uplo='L'; \ char side='L', uplo='L'; \
MKL_INT m, n, lda, ldb, ldc; \ BlasIndex m, n, lda, ldb, ldc; \
const EIGTYPE *a, *b; \ const EIGTYPE *a, *b; \
EIGTYPE beta(1); \ EIGTYPE beta(1); \
MatrixX##EIGPREFIX b_tmp; \ MatrixX##EIGPREFIX b_tmp; \
@ -109,13 +109,13 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLh
\ \
/* Set transpose options */ \ /* Set transpose options */ \
/* Set m, n, k */ \ /* Set m, n, k */ \
m = (MKL_INT)rows; \ m = convert_index<BlasIndex>(rows); \
n = (MKL_INT)cols; \ n = convert_index<BlasIndex>(cols); \
\ \
/* Set lda, ldb, ldc */ \ /* Set lda, ldb, ldc */ \
lda = (MKL_INT)lhsStride; \ lda = convert_index<BlasIndex>(lhsStride); \
ldb = (MKL_INT)rhsStride; \ ldb = convert_index<BlasIndex>(rhsStride); \
ldc = (MKL_INT)resStride; \ ldc = convert_index<BlasIndex>(resStride); \
\ \
/* Set a, b, c */ \ /* Set a, b, c */ \
if (((LhsStorageOrder==ColMajor) && ConjugateLhs) || ((LhsStorageOrder==RowMajor) && (!ConjugateLhs))) { \ if (((LhsStorageOrder==ColMajor) && ConjugateLhs) || ((LhsStorageOrder==RowMajor) && (!ConjugateLhs))) { \
@ -141,10 +141,10 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLh
b_tmp = rhs.transpose(); \ b_tmp = rhs.transpose(); \
} \ } \
b = b_tmp.data(); \ b = b_tmp.data(); \
ldb = b_tmp.outerStride(); \ ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
} \ } \
\ \
MKLPREFIX##hemm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &numext::real_ref(beta), (MKLTYPE*)res, &ldc); \ MKLPREFIX##hemm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
\ \
} \ } \
}; };
@ -157,7 +157,7 @@ EIGEN_MKL_HEMM_L(scomplex, float, cf, c)
/* Optimized matrix * selfadjoint matrix (?SYMM/?HEMM) product */ /* Optimized matrix * selfadjoint matrix (?SYMM/?HEMM) product */
#define EIGEN_MKL_SYMM_R(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \ #define EIGEN_MKL_SYMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, MKLPREFIX) \
template <typename Index, \ template <typename Index, \
int LhsStorageOrder, bool ConjugateLhs, \ int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \ int RhsStorageOrder, bool ConjugateRhs> \
@ -172,19 +172,19 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateL
EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \ EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
{ \ { \
char side='R', uplo='L'; \ char side='R', uplo='L'; \
MKL_INT m, n, lda, ldb, ldc; \ BlasIndex m, n, lda, ldb, ldc; \
const EIGTYPE *a, *b; \ const EIGTYPE *a, *b; \
EIGTYPE beta(1); \ EIGTYPE beta(1); \
MatrixX##EIGPREFIX b_tmp; \ MatrixX##EIGPREFIX b_tmp; \
\ \
/* Set m, n, k */ \ /* Set m, n, k */ \
m = (MKL_INT)rows; \ m = convert_index<BlasIndex>(rows); \
n = (MKL_INT)cols; \ n = convert_index<BlasIndex>(cols); \
\ \
/* Set lda, ldb, ldc */ \ /* Set lda, ldb, ldc */ \
lda = (MKL_INT)rhsStride; \ lda = convert_index<BlasIndex>(rhsStride); \
ldb = (MKL_INT)lhsStride; \ ldb = convert_index<BlasIndex>(lhsStride); \
ldc = (MKL_INT)resStride; \ ldc = convert_index<BlasIndex>(resStride); \
\ \
/* Set a, b, c */ \ /* Set a, b, c */ \
if (RhsStorageOrder==RowMajor) uplo='U'; \ if (RhsStorageOrder==RowMajor) uplo='U'; \
@ -194,16 +194,16 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateL
Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > lhs(_lhs,n,m,OuterStride<>(rhsStride)); \ Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > lhs(_lhs,n,m,OuterStride<>(rhsStride)); \
b_tmp = lhs.adjoint(); \ b_tmp = lhs.adjoint(); \
b = b_tmp.data(); \ b = b_tmp.data(); \
ldb = b_tmp.outerStride(); \ ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
} else b = _lhs; \ } else b = _lhs; \
\ \
MKLPREFIX##symm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &numext::real_ref(beta), (MKLTYPE*)res, &ldc); \ MKLPREFIX##symm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
\ \
} \ } \
}; };
#define EIGEN_MKL_HEMM_R(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \ #define EIGEN_MKL_HEMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, MKLPREFIX) \
template <typename Index, \ template <typename Index, \
int LhsStorageOrder, bool ConjugateLhs, \ int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \ int RhsStorageOrder, bool ConjugateRhs> \
@ -217,27 +217,27 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateL
EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \ EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
{ \ { \
char side='R', uplo='L'; \ char side='R', uplo='L'; \
MKL_INT m, n, lda, ldb, ldc; \ BlasIndex m, n, lda, ldb, ldc; \
const EIGTYPE *a, *b; \ const EIGTYPE *a, *b; \
EIGTYPE beta(1); \ EIGTYPE beta(1); \
MatrixX##EIGPREFIX b_tmp; \ MatrixX##EIGPREFIX b_tmp; \
Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> a_tmp; \ Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> a_tmp; \
\ \
/* Set m, n, k */ \ /* Set m, n, k */ \
m = (MKL_INT)rows; \ m = convert_index<BlasIndex>(rows); \
n = (MKL_INT)cols; \ n = convert_index<BlasIndex>(cols); \
\ \
/* Set lda, ldb, ldc */ \ /* Set lda, ldb, ldc */ \
lda = (MKL_INT)rhsStride; \ lda = convert_index<BlasIndex>(rhsStride); \
ldb = (MKL_INT)lhsStride; \ ldb = convert_index<BlasIndex>(lhsStride); \
ldc = (MKL_INT)resStride; \ ldc = convert_index<BlasIndex>(resStride); \
\ \
/* Set a, b, c */ \ /* Set a, b, c */ \
if (((RhsStorageOrder==ColMajor) && ConjugateRhs) || ((RhsStorageOrder==RowMajor) && (!ConjugateRhs))) { \ if (((RhsStorageOrder==ColMajor) && ConjugateRhs) || ((RhsStorageOrder==RowMajor) && (!ConjugateRhs))) { \
Map<const Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder>, 0, OuterStride<> > rhs(_rhs,n,n,OuterStride<>(rhsStride)); \ Map<const Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder>, 0, OuterStride<> > rhs(_rhs,n,n,OuterStride<>(rhsStride)); \
a_tmp = rhs.conjugate(); \ a_tmp = rhs.conjugate(); \
a = a_tmp.data(); \ a = a_tmp.data(); \
lda = a_tmp.outerStride(); \ lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
} else a = _rhs; \ } else a = _rhs; \
if (RhsStorageOrder==RowMajor) uplo='U'; \ if (RhsStorageOrder==RowMajor) uplo='U'; \
\ \
@ -259,7 +259,7 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateL
ldb = b_tmp.outerStride(); \ ldb = b_tmp.outerStride(); \
} \ } \
\ \
MKLPREFIX##hemm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &numext::real_ref(beta), (MKLTYPE*)res, &ldc); \ MKLPREFIX##hemm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
} \ } \
}; };

View File

@ -71,7 +71,7 @@ EIGEN_MKL_SYMV_SPECIALIZE(float)
EIGEN_MKL_SYMV_SPECIALIZE(dcomplex) EIGEN_MKL_SYMV_SPECIALIZE(dcomplex)
EIGEN_MKL_SYMV_SPECIALIZE(scomplex) EIGEN_MKL_SYMV_SPECIALIZE(scomplex)
#define EIGEN_MKL_SYMV_SPECIALIZATION(EIGTYPE,MKLTYPE,MKLFUNC) \ #define EIGEN_MKL_SYMV_SPECIALIZATION(EIGTYPE,BLASTYPE,MKLFUNC) \
template<typename Index, int StorageOrder, int UpLo, bool ConjugateLhs, bool ConjugateRhs> \ template<typename Index, int StorageOrder, int UpLo, bool ConjugateLhs, bool ConjugateRhs> \
struct selfadjoint_matrix_vector_product_symv<EIGTYPE,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs> \ struct selfadjoint_matrix_vector_product_symv<EIGTYPE,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs> \
{ \ { \
@ -85,7 +85,7 @@ const EIGTYPE* _rhs, EIGTYPE* res, EIGTYPE alpha) \
IsRowMajor = StorageOrder==RowMajor ? 1 : 0, \ IsRowMajor = StorageOrder==RowMajor ? 1 : 0, \
IsLower = UpLo == Lower ? 1 : 0 \ IsLower = UpLo == Lower ? 1 : 0 \
}; \ }; \
MKL_INT n=size, lda=lhsStride, incx=1, incy=1; \ BlasIndex n=convert_index<BlasIndex>(size), lda=convert_index<BlasIndex>(lhsStride), incx=1, incy=1; \
EIGTYPE beta(1); \ EIGTYPE beta(1); \
const EIGTYPE *x_ptr; \ const EIGTYPE *x_ptr; \
char uplo=(IsRowMajor) ? (IsLower ? 'U' : 'L') : (IsLower ? 'L' : 'U'); \ char uplo=(IsRowMajor) ? (IsLower ? 'U' : 'L') : (IsLower ? 'L' : 'U'); \
@ -95,7 +95,7 @@ const EIGTYPE* _rhs, EIGTYPE* res, EIGTYPE alpha) \
x_tmp=map_x.conjugate(); \ x_tmp=map_x.conjugate(); \
x_ptr=x_tmp.data(); \ x_ptr=x_tmp.data(); \
} else x_ptr=_rhs; \ } else x_ptr=_rhs; \
MKLFUNC(&uplo, &n, &numext::real_ref(alpha), (const MKLTYPE*)lhs, &lda, (const MKLTYPE*)x_ptr, &incx, &numext::real_ref(beta), (MKLTYPE*)res, &incy); \ MKLFUNC(&uplo, &n, &numext::real_ref(alpha), (const BLASTYPE*)lhs, &lda, (const BLASTYPE*)x_ptr, &incx, &numext::real_ref(beta), (BLASTYPE*)res, &incy); \
}\ }\
}; };

View File

@ -75,7 +75,7 @@ EIGEN_MKL_TRMM_SPECIALIZE(scomplex, true)
EIGEN_MKL_TRMM_SPECIALIZE(scomplex, false) EIGEN_MKL_TRMM_SPECIALIZE(scomplex, false)
// implements col-major += alpha * op(triangular) * op(general) // implements col-major += alpha * op(triangular) * op(general)
#define EIGEN_MKL_TRMM_L(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \ #define EIGEN_MKL_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, MKLPREFIX) \
template <typename Index, int Mode, \ template <typename Index, int Mode, \
int LhsStorageOrder, bool ConjugateLhs, \ int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \ int RhsStorageOrder, bool ConjugateRhs> \
@ -122,7 +122,7 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
/* Make sense to call GEMM */ \ /* Make sense to call GEMM */ \
Map<const MatrixLhs, 0, OuterStride<> > lhsMap(_lhs,rows,depth,OuterStride<>(lhsStride)); \ Map<const MatrixLhs, 0, OuterStride<> > lhsMap(_lhs,rows,depth,OuterStride<>(lhsStride)); \
MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); \ MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); \
MKL_INT aStride = aa_tmp.outerStride(); \ BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \ gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \ general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, resStride, alpha, gemm_blocking, 0); \ rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, resStride, alpha, gemm_blocking, 0); \
@ -134,11 +134,11 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
char side = 'L', transa, uplo, diag = 'N'; \ char side = 'L', transa, uplo, diag = 'N'; \
EIGTYPE *b; \ EIGTYPE *b; \
const EIGTYPE *a; \ const EIGTYPE *a; \
MKL_INT m, n, lda, ldb; \ BlasIndex m, n, lda, ldb; \
\ \
/* Set m, n */ \ /* Set m, n */ \
m = (MKL_INT)diagSize; \ m = convert_index<BlasIndex>(diagSize); \
n = (MKL_INT)cols; \ n = convert_index<BlasIndex>(cols); \
\ \
/* Set trans */ \ /* Set trans */ \
transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \ transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \
@ -149,7 +149,7 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
\ \
if (ConjugateRhs) b_tmp = rhs.conjugate(); else b_tmp = rhs; \ if (ConjugateRhs) b_tmp = rhs.conjugate(); else b_tmp = rhs; \
b = b_tmp.data(); \ b = b_tmp.data(); \
ldb = b_tmp.outerStride(); \ ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
\ \
/* Set uplo */ \ /* Set uplo */ \
uplo = IsLower ? 'L' : 'U'; \ uplo = IsLower ? 'L' : 'U'; \
@ -165,14 +165,14 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
else if (IsUnitDiag) \ else if (IsUnitDiag) \
a_tmp.diagonal().setOnes();\ a_tmp.diagonal().setOnes();\
a = a_tmp.data(); \ a = a_tmp.data(); \
lda = a_tmp.outerStride(); \ lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
} else { \ } else { \
a = _lhs; \ a = _lhs; \
lda = lhsStride; \ lda = convert_index<BlasIndex>(lhsStride); \
} \ } \
/*std::cout << "TRMM_L: A is square! Go to MKL TRMM implementation! \n";*/ \ /*std::cout << "TRMM_L: A is square! Go to MKL TRMM implementation! \n";*/ \
/* call ?trmm*/ \ /* call ?trmm*/ \
MKLPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (MKLTYPE*)b, &ldb); \ MKLPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \
\ \
/* Add op(a_triangular)*b into res*/ \ /* Add op(a_triangular)*b into res*/ \
Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \ Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
@ -186,7 +186,7 @@ EIGEN_MKL_TRMM_L(float, float, f, s)
EIGEN_MKL_TRMM_L(scomplex, float, cf, c) EIGEN_MKL_TRMM_L(scomplex, float, cf, c)
// implements col-major += alpha * op(general) * op(triangular) // implements col-major += alpha * op(general) * op(triangular)
#define EIGEN_MKL_TRMM_R(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \ #define EIGEN_MKL_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, MKLPREFIX) \
template <typename Index, int Mode, \ template <typename Index, int Mode, \
int LhsStorageOrder, bool ConjugateLhs, \ int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \ int RhsStorageOrder, bool ConjugateRhs> \
@ -232,7 +232,7 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
/* Make sense to call GEMM */ \ /* Make sense to call GEMM */ \
Map<const MatrixRhs, 0, OuterStride<> > rhsMap(_rhs,depth,cols, OuterStride<>(rhsStride)); \ Map<const MatrixRhs, 0, OuterStride<> > rhsMap(_rhs,depth,cols, OuterStride<>(rhsStride)); \
MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); \ MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); \
MKL_INT aStride = aa_tmp.outerStride(); \ BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \ gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \ general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, resStride, alpha, gemm_blocking, 0); \ rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, resStride, alpha, gemm_blocking, 0); \
@ -244,11 +244,11 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
char side = 'R', transa, uplo, diag = 'N'; \ char side = 'R', transa, uplo, diag = 'N'; \
EIGTYPE *b; \ EIGTYPE *b; \
const EIGTYPE *a; \ const EIGTYPE *a; \
MKL_INT m, n, lda, ldb; \ BlasIndex m, n, lda, ldb; \
\ \
/* Set m, n */ \ /* Set m, n */ \
m = (MKL_INT)rows; \ m = convert_index<BlasIndex>(rows); \
n = (MKL_INT)diagSize; \ n = convert_index<BlasIndex>(diagSize); \
\ \
/* Set trans */ \ /* Set trans */ \
transa = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \ transa = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \
@ -259,7 +259,7 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
\ \
if (ConjugateLhs) b_tmp = lhs.conjugate(); else b_tmp = lhs; \ if (ConjugateLhs) b_tmp = lhs.conjugate(); else b_tmp = lhs; \
b = b_tmp.data(); \ b = b_tmp.data(); \
ldb = b_tmp.outerStride(); \ ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
\ \
/* Set uplo */ \ /* Set uplo */ \
uplo = IsLower ? 'L' : 'U'; \ uplo = IsLower ? 'L' : 'U'; \
@ -275,14 +275,14 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
else if (IsUnitDiag) \ else if (IsUnitDiag) \
a_tmp.diagonal().setOnes();\ a_tmp.diagonal().setOnes();\
a = a_tmp.data(); \ a = a_tmp.data(); \
lda = a_tmp.outerStride(); \ lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
} else { \ } else { \
a = _rhs; \ a = _rhs; \
lda = rhsStride; \ lda = convert_index<BlasIndex>(rhsStride); \
} \ } \
/*std::cout << "TRMM_R: A is square! Go to MKL TRMM implementation! \n";*/ \ /*std::cout << "TRMM_R: A is square! Go to MKL TRMM implementation! \n";*/ \
/* call ?trmm*/ \ /* call ?trmm*/ \
MKLPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (MKLTYPE*)b, &ldb); \ MKLPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \
\ \
/* Add op(a_triangular)*b into res*/ \ /* Add op(a_triangular)*b into res*/ \
Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \ Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \

View File

@ -71,7 +71,7 @@ EIGEN_MKL_TRMV_SPECIALIZE(dcomplex)
EIGEN_MKL_TRMV_SPECIALIZE(scomplex) EIGEN_MKL_TRMV_SPECIALIZE(scomplex)
// implements col-major: res += alpha * op(triangular) * vector // implements col-major: res += alpha * op(triangular) * vector
#define EIGEN_MKL_TRMV_CM(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \ #define EIGEN_MKL_TRMV_CM(EIGTYPE, BLASTYPE, EIGPREFIX, MKLPREFIX) \
template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \ template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,ColMajor> { \ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,ColMajor> { \
enum { \ enum { \
@ -105,15 +105,15 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
/* Square part handling */\ /* Square part handling */\
\ \
char trans, uplo, diag; \ char trans, uplo, diag; \
MKL_INT m, n, lda, incx, incy; \ BlasIndex m, n, lda, incx, incy; \
EIGTYPE const *a; \ EIGTYPE const *a; \
EIGTYPE beta(1); \ EIGTYPE beta(1); \
\ \
/* Set m, n */ \ /* Set m, n */ \
n = (MKL_INT)size; \ n = convert_index<BlasIndex>(size); \
lda = lhsStride; \ lda = convert_index<BlasIndex>(lhsStride); \
incx = 1; \ incx = 1; \
incy = resIncr; \ incy = convert_index<BlasIndex>(resIncr); \
\ \
/* Set uplo, trans and diag*/ \ /* Set uplo, trans and diag*/ \
trans = 'N'; \ trans = 'N'; \
@ -121,10 +121,10 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
diag = IsUnitDiag ? 'U' : 'N'; \ diag = IsUnitDiag ? 'U' : 'N'; \
\ \
/* call ?TRMV*/ \ /* call ?TRMV*/ \
MKLPREFIX##trmv_(&uplo, &trans, &diag, &n, (const MKLTYPE*)_lhs, &lda, (MKLTYPE*)x, &incx); \ MKLPREFIX##trmv_(&uplo, &trans, &diag, &n, (const BLASTYPE*)_lhs, &lda, (BLASTYPE*)x, &incx); \
\ \
/* Add op(a_tr)rhs into res*/ \ /* Add op(a_tr)rhs into res*/ \
MKLPREFIX##axpy_(&n, &numext::real_ref(alpha),(const MKLTYPE*)x, &incx, (MKLTYPE*)_res, &incy); \ MKLPREFIX##axpy_(&n, &numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)_res, &incy); \
/* Non-square case - doesn't fit to MKL ?TRMV. Fall to default triangular product*/ \ /* Non-square case - doesn't fit to MKL ?TRMV. Fall to default triangular product*/ \
if (size<(std::max)(rows,cols)) { \ if (size<(std::max)(rows,cols)) { \
if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \ if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
@ -132,17 +132,17 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
if (size<rows) { \ if (size<rows) { \
y = _res + size*resIncr; \ y = _res + size*resIncr; \
a = _lhs + size; \ a = _lhs + size; \
m = rows-size; \ m = convert_index<BlasIndex>(rows-size); \
n = size; \ n = convert_index<BlasIndex>(size); \
} \ } \
else { \ else { \
x += size; \ x += size; \
y = _res; \ y = _res; \
a = _lhs + size*lda; \ a = _lhs + size*lda; \
m = size; \ m = convert_index<BlasIndex>(size); \
n = cols-size; \ n = convert_index<BlasIndex>(cols-size); \
} \ } \
MKLPREFIX##gemv_(&trans, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (const MKLTYPE*)x, &incx, &numext::real_ref(beta), (MKLTYPE*)y, &incy); \ MKLPREFIX##gemv_(&trans, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, &numext::real_ref(beta), (BLASTYPE*)y, &incy); \
} \ } \
} \ } \
}; };
@ -153,7 +153,7 @@ EIGEN_MKL_TRMV_CM(float, float, f, s)
EIGEN_MKL_TRMV_CM(scomplex, float, cf, c) EIGEN_MKL_TRMV_CM(scomplex, float, cf, c)
// implements row-major: res += alpha * op(triangular) * vector // implements row-major: res += alpha * op(triangular) * vector
#define EIGEN_MKL_TRMV_RM(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \ #define EIGEN_MKL_TRMV_RM(EIGTYPE, BLASTYPE, EIGPREFIX, MKLPREFIX) \
template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \ template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,RowMajor> { \ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,RowMajor> { \
enum { \ enum { \
@ -187,15 +187,15 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
/* Square part handling */\ /* Square part handling */\
\ \
char trans, uplo, diag; \ char trans, uplo, diag; \
MKL_INT m, n, lda, incx, incy; \ BlasIndex m, n, lda, incx, incy; \
EIGTYPE const *a; \ EIGTYPE const *a; \
EIGTYPE beta(1); \ EIGTYPE beta(1); \
\ \
/* Set m, n */ \ /* Set m, n */ \
n = (MKL_INT)size; \ n = convert_index<BlasIndex>(size); \
lda = lhsStride; \ lda = convert_index<BlasIndex>(lhsStride); \
incx = 1; \ incx = 1; \
incy = resIncr; \ incy = convert_index<BlasIndex>(resIncr); \
\ \
/* Set uplo, trans and diag*/ \ /* Set uplo, trans and diag*/ \
trans = ConjLhs ? 'C' : 'T'; \ trans = ConjLhs ? 'C' : 'T'; \
@ -203,10 +203,10 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
diag = IsUnitDiag ? 'U' : 'N'; \ diag = IsUnitDiag ? 'U' : 'N'; \
\ \
/* call ?TRMV*/ \ /* call ?TRMV*/ \
MKLPREFIX##trmv_(&uplo, &trans, &diag, &n, (const MKLTYPE*)_lhs, &lda, (MKLTYPE*)x, &incx); \ MKLPREFIX##trmv_(&uplo, &trans, &diag, &n, (const BLASTYPE*)_lhs, &lda, (BLASTYPE*)x, &incx); \
\ \
/* Add op(a_tr)rhs into res*/ \ /* Add op(a_tr)rhs into res*/ \
MKLPREFIX##axpy_(&n, &numext::real_ref(alpha),(const MKLTYPE*)x, &incx, (MKLTYPE*)_res, &incy); \ MKLPREFIX##axpy_(&n, &numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)_res, &incy); \
/* Non-square case - doesn't fit to MKL ?TRMV. Fall to default triangular product*/ \ /* Non-square case - doesn't fit to MKL ?TRMV. Fall to default triangular product*/ \
if (size<(std::max)(rows,cols)) { \ if (size<(std::max)(rows,cols)) { \
if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \ if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
@ -214,17 +214,17 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
if (size<rows) { \ if (size<rows) { \
y = _res + size*resIncr; \ y = _res + size*resIncr; \
a = _lhs + size*lda; \ a = _lhs + size*lda; \
m = rows-size; \ m = convert_index<BlasIndex>(rows-size); \
n = size; \ n = convert_index<BlasIndex>(size); \
} \ } \
else { \ else { \
x += size; \ x += size; \
y = _res; \ y = _res; \
a = _lhs + size; \ a = _lhs + size; \
m = size; \ m = convert_index<BlasIndex>(size); \
n = cols-size; \ n = convert_index<BlasIndex>(cols-size); \
} \ } \
MKLPREFIX##gemv_(&trans, &n, &m, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (const MKLTYPE*)x, &incx, &numext::real_ref(beta), (MKLTYPE*)y, &incy); \ MKLPREFIX##gemv_(&trans, &n, &m, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, &numext::real_ref(beta), (BLASTYPE*)y, &incy); \
} \ } \
} \ } \
}; };

View File

@ -38,7 +38,7 @@ namespace Eigen {
namespace internal { namespace internal {
// implements LeftSide op(triangular)^-1 * general // implements LeftSide op(triangular)^-1 * general
#define EIGEN_MKL_TRSM_L(EIGTYPE, MKLTYPE, MKLPREFIX) \ #define EIGEN_MKL_TRSM_L(EIGTYPE, BLASTYPE, MKLPREFIX) \
template <typename Index, int Mode, bool Conjugate, int TriStorageOrder> \ template <typename Index, int Mode, bool Conjugate, int TriStorageOrder> \
struct triangular_solve_matrix<EIGTYPE,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor> \ struct triangular_solve_matrix<EIGTYPE,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor> \
{ \ { \
@ -53,11 +53,11 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheLeft,Mode,Conjugate,TriStorage
const EIGTYPE* _tri, Index triStride, \ const EIGTYPE* _tri, Index triStride, \
EIGTYPE* _other, Index otherStride, level3_blocking<EIGTYPE,EIGTYPE>& /*blocking*/) \ EIGTYPE* _other, Index otherStride, level3_blocking<EIGTYPE,EIGTYPE>& /*blocking*/) \
{ \ { \
MKL_INT m = size, n = otherSize, lda, ldb; \ BlasIndex m = convert_index<BlasIndex>(size), n = convert_index<BlasIndex>(otherSize), lda, ldb; \
char side = 'L', uplo, diag='N', transa; \ char side = 'L', uplo, diag='N', transa; \
/* Set alpha_ */ \ /* Set alpha_ */ \
EIGTYPE alpha(1); \ EIGTYPE alpha(1); \
ldb = otherStride;\ ldb = convert_index<BlasIndex>(otherStride);\
\ \
const EIGTYPE *a; \ const EIGTYPE *a; \
/* Set trans */ \ /* Set trans */ \
@ -73,14 +73,14 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheLeft,Mode,Conjugate,TriStorage
if (conjA) { \ if (conjA) { \
a_tmp = tri.conjugate(); \ a_tmp = tri.conjugate(); \
a = a_tmp.data(); \ a = a_tmp.data(); \
lda = a_tmp.outerStride(); \ lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
} else { \ } else { \
a = _tri; \ a = _tri; \
lda = triStride; \ lda = convert_index<BlasIndex>(triStride); \
} \ } \
if (IsUnitDiag) diag='U'; \ if (IsUnitDiag) diag='U'; \
/* call ?trsm*/ \ /* call ?trsm*/ \
MKLPREFIX##trsm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (MKLTYPE*)_other, &ldb); \ MKLPREFIX##trsm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)_other, &ldb); \
} \ } \
}; };
@ -91,7 +91,7 @@ EIGEN_MKL_TRSM_L(scomplex, float, c)
// implements RightSide general * op(triangular)^-1 // implements RightSide general * op(triangular)^-1
#define EIGEN_MKL_TRSM_R(EIGTYPE, MKLTYPE, MKLPREFIX) \ #define EIGEN_MKL_TRSM_R(EIGTYPE, BLASTYPE, MKLPREFIX) \
template <typename Index, int Mode, bool Conjugate, int TriStorageOrder> \ template <typename Index, int Mode, bool Conjugate, int TriStorageOrder> \
struct triangular_solve_matrix<EIGTYPE,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor> \ struct triangular_solve_matrix<EIGTYPE,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor> \
{ \ { \
@ -106,11 +106,11 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheRight,Mode,Conjugate,TriStorag
const EIGTYPE* _tri, Index triStride, \ const EIGTYPE* _tri, Index triStride, \
EIGTYPE* _other, Index otherStride, level3_blocking<EIGTYPE,EIGTYPE>& /*blocking*/) \ EIGTYPE* _other, Index otherStride, level3_blocking<EIGTYPE,EIGTYPE>& /*blocking*/) \
{ \ { \
MKL_INT m = otherSize, n = size, lda, ldb; \ BlasIndex m = convert_index<BlasIndex>(otherSize), n = convert_index<BlasIndex>(size), lda, ldb; \
char side = 'R', uplo, diag='N', transa; \ char side = 'R', uplo, diag='N', transa; \
/* Set alpha_ */ \ /* Set alpha_ */ \
EIGTYPE alpha(1); \ EIGTYPE alpha(1); \
ldb = otherStride;\ ldb = convert_index<BlasIndex>(otherStride);\
\ \
const EIGTYPE *a; \ const EIGTYPE *a; \
/* Set trans */ \ /* Set trans */ \
@ -126,14 +126,14 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheRight,Mode,Conjugate,TriStorag
if (conjA) { \ if (conjA) { \
a_tmp = tri.conjugate(); \ a_tmp = tri.conjugate(); \
a = a_tmp.data(); \ a = a_tmp.data(); \
lda = a_tmp.outerStride(); \ lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
} else { \ } else { \
a = _tri; \ a = _tri; \
lda = triStride; \ lda = convert_index<BlasIndex>(triStride); \
} \ } \
if (IsUnitDiag) diag='U'; \ if (IsUnitDiag) diag='U'; \
/* call ?trsm*/ \ /* call ?trsm*/ \
MKLPREFIX##trsm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (MKLTYPE*)_other, &ldb); \ MKLPREFIX##trsm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)_other, &ldb); \
/*std::cout << "TRMS_L specialization!\n";*/ \ /*std::cout << "TRMS_L specialization!\n";*/ \
} \ } \
}; };

View File

@ -114,7 +114,9 @@ typedef std::complex<double> dcomplex;
typedef std::complex<float> scomplex; typedef std::complex<float> scomplex;
#if defined(EIGEN_USE_BLAS) && !defined(EIGEN_USE_MKL) #if defined(EIGEN_USE_BLAS) && !defined(EIGEN_USE_MKL)
typedef int MKL_INT; typedef int BlasIndex;
#else
typedef MKL_INT BlasIndex;
#endif #endif
namespace internal { namespace internal {