mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-30 15:54:13 +08:00
Fix long to int conversion in BLAS API.
This commit is contained in:
parent
8191f373be
commit
ddabc992fa
@ -72,7 +72,7 @@ EIGEN_MKL_RANKUPDATE_SPECIALIZE(float)
|
||||
// EIGEN_MKL_RANKUPDATE_SPECIALIZE(scomplex)
|
||||
|
||||
// SYRK for float/double
|
||||
#define EIGEN_MKL_RANKUPDATE_R(EIGTYPE, MKLTYPE, MKLFUNC) \
|
||||
#define EIGEN_MKL_RANKUPDATE_R(EIGTYPE, BLASTYPE, MKLFUNC) \
|
||||
template <typename Index, int AStorageOrder, bool ConjugateA, int UpLo> \
|
||||
struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,ColMajor,UpLo> { \
|
||||
enum { \
|
||||
@ -85,19 +85,19 @@ struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,C
|
||||
{ \
|
||||
/* typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs;*/ \
|
||||
\
|
||||
MKL_INT lda=lhsStride, ldc=resStride, n=size, k=depth; \
|
||||
BlasIndex lda=convert_index<BlasIndex>(lhsStride), ldc=convert_index<BlasIndex>(resStride), n=convert_index<BlasIndex>(size), k=convert_index<BlasIndex>(depth); \
|
||||
char uplo=(IsLower) ? 'L' : 'U', trans=(AStorageOrder==RowMajor) ? 'T':'N'; \
|
||||
MKLTYPE alpha_, beta_; \
|
||||
BLASTYPE alpha_, beta_; \
|
||||
\
|
||||
/* Set alpha_ & beta_ */ \
|
||||
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
|
||||
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(beta_, EIGTYPE(1)); \
|
||||
assign_scalar_eig2mkl<BLASTYPE, EIGTYPE>(alpha_, alpha); \
|
||||
assign_scalar_eig2mkl<BLASTYPE, EIGTYPE>(beta_, EIGTYPE(1)); \
|
||||
MKLFUNC(&uplo, &trans, &n, &k, &alpha_, lhs, &lda, &beta_, res, &ldc); \
|
||||
} \
|
||||
};
|
||||
|
||||
// HERK for complex data
|
||||
#define EIGEN_MKL_RANKUPDATE_C(EIGTYPE, MKLTYPE, RTYPE, MKLFUNC) \
|
||||
#define EIGEN_MKL_RANKUPDATE_C(EIGTYPE, BLASTYPE, RTYPE, MKLFUNC) \
|
||||
template <typename Index, int AStorageOrder, bool ConjugateA, int UpLo> \
|
||||
struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,ColMajor,UpLo> { \
|
||||
enum { \
|
||||
@ -110,14 +110,14 @@ struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,C
|
||||
{ \
|
||||
typedef Matrix<EIGTYPE, Dynamic, Dynamic, AStorageOrder> MatrixType; \
|
||||
\
|
||||
MKL_INT lda=lhsStride, ldc=resStride, n=size, k=depth; \
|
||||
BlasIndex lda=convert_index<BlasIndex>(lhsStride), ldc=convert_index<BlasIndex>(resStride), n=convert_index<BlasIndex>(size), k=convert_index<BlasIndex>(depth); \
|
||||
char uplo=(IsLower) ? 'L' : 'U', trans=(AStorageOrder==RowMajor) ? 'C':'N'; \
|
||||
RTYPE alpha_, beta_; \
|
||||
const EIGTYPE* a_ptr; \
|
||||
\
|
||||
/* Set alpha_ & beta_ */ \
|
||||
/* assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); */\
|
||||
/* assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(beta_, EIGTYPE(1));*/ \
|
||||
/* assign_scalar_eig2mkl<BLASTYPE, EIGTYPE>(alpha_, alpha); */\
|
||||
/* assign_scalar_eig2mkl<BLASTYPE, EIGTYPE>(beta_, EIGTYPE(1));*/ \
|
||||
alpha_ = alpha.real(); \
|
||||
beta_ = 1.0; \
|
||||
/* Copy with conjugation in some cases*/ \
|
||||
@ -128,7 +128,7 @@ struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,C
|
||||
lda = a.outerStride(); \
|
||||
a_ptr = a.data(); \
|
||||
} else a_ptr=lhs; \
|
||||
MKLFUNC(&uplo, &trans, &n, &k, &alpha_, (MKLTYPE*)a_ptr, &lda, &beta_, (MKLTYPE*)res, &ldc); \
|
||||
MKLFUNC(&uplo, &trans, &n, &k, &alpha_, (BLASTYPE*)a_ptr, &lda, &beta_, (BLASTYPE*)res, &ldc); \
|
||||
} \
|
||||
};
|
||||
|
||||
|
@ -46,7 +46,7 @@ namespace internal {
|
||||
|
||||
// gemm specialization
|
||||
|
||||
#define GEMM_SPECIALIZATION(EIGTYPE, EIGPREFIX, MKLTYPE, MKLPREFIX) \
|
||||
#define GEMM_SPECIALIZATION(EIGTYPE, EIGPREFIX, BLASTYPE, MKLPREFIX) \
|
||||
template< \
|
||||
typename Index, \
|
||||
int LhsStorageOrder, bool ConjugateLhs, \
|
||||
@ -66,7 +66,7 @@ static void run(Index rows, Index cols, Index depth, \
|
||||
using std::conj; \
|
||||
\
|
||||
char transa, transb; \
|
||||
MKL_INT m, n, k, lda, ldb, ldc; \
|
||||
BlasIndex m, n, k, lda, ldb, ldc; \
|
||||
const EIGTYPE *a, *b; \
|
||||
EIGTYPE beta(1); \
|
||||
MatrixX##EIGPREFIX a_tmp, b_tmp; \
|
||||
@ -76,31 +76,31 @@ static void run(Index rows, Index cols, Index depth, \
|
||||
transb = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \
|
||||
\
|
||||
/* Set m, n, k */ \
|
||||
m = (MKL_INT)rows; \
|
||||
n = (MKL_INT)cols; \
|
||||
k = (MKL_INT)depth; \
|
||||
m = convert_index<BlasIndex>(rows); \
|
||||
n = convert_index<BlasIndex>(cols); \
|
||||
k = convert_index<BlasIndex>(depth); \
|
||||
\
|
||||
/* Set lda, ldb, ldc */ \
|
||||
lda = (MKL_INT)lhsStride; \
|
||||
ldb = (MKL_INT)rhsStride; \
|
||||
ldc = (MKL_INT)resStride; \
|
||||
lda = convert_index<BlasIndex>(lhsStride); \
|
||||
ldb = convert_index<BlasIndex>(rhsStride); \
|
||||
ldc = convert_index<BlasIndex>(resStride); \
|
||||
\
|
||||
/* Set a, b, c */ \
|
||||
if ((LhsStorageOrder==ColMajor) && (ConjugateLhs)) { \
|
||||
Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > lhs(_lhs,m,k,OuterStride<>(lhsStride)); \
|
||||
a_tmp = lhs.conjugate(); \
|
||||
a = a_tmp.data(); \
|
||||
lda = a_tmp.outerStride(); \
|
||||
lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
|
||||
} else a = _lhs; \
|
||||
\
|
||||
if ((RhsStorageOrder==ColMajor) && (ConjugateRhs)) { \
|
||||
Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > rhs(_rhs,k,n,OuterStride<>(rhsStride)); \
|
||||
b_tmp = rhs.conjugate(); \
|
||||
b = b_tmp.data(); \
|
||||
ldb = b_tmp.outerStride(); \
|
||||
ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
|
||||
} else b = _rhs; \
|
||||
\
|
||||
MKLPREFIX##gemm_(&transa, &transb, &m, &n, &k, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &numext::real_ref(beta), (MKLTYPE*)res, &ldc); \
|
||||
MKLPREFIX##gemm_(&transa, &transb, &m, &n, &k, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
|
||||
}};
|
||||
|
||||
GEMM_SPECIALIZATION(double, d, double, d)
|
||||
|
@ -85,7 +85,7 @@ EIGEN_MKL_GEMV_SPECIALIZE(float)
|
||||
EIGEN_MKL_GEMV_SPECIALIZE(dcomplex)
|
||||
EIGEN_MKL_GEMV_SPECIALIZE(scomplex)
|
||||
|
||||
#define EIGEN_MKL_GEMV_SPECIALIZATION(EIGTYPE,MKLTYPE,MKLPREFIX) \
|
||||
#define EIGEN_MKL_GEMV_SPECIALIZATION(EIGTYPE,BLASTYPE,MKLPREFIX) \
|
||||
template<typename Index, int LhsStorageOrder, bool ConjugateLhs, bool ConjugateRhs> \
|
||||
struct general_matrix_vector_product_gemv<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,ConjugateRhs> \
|
||||
{ \
|
||||
@ -97,13 +97,14 @@ static void run( \
|
||||
const EIGTYPE* rhs, Index rhsIncr, \
|
||||
EIGTYPE* res, Index resIncr, EIGTYPE alpha) \
|
||||
{ \
|
||||
MKL_INT m=rows, n=cols, lda=lhsStride, incx=rhsIncr, incy=resIncr; \
|
||||
BlasIndex m=convert_index<BlasIndex>(rows), n=convert_index<BlasIndex>(cols), \
|
||||
lda=convert_index<BlasIndex>(lhsStride), incx=convert_index<BlasIndex>(rhsIncr), incy=convert_index<BlasIndex>(resIncr); \
|
||||
const EIGTYPE beta(1); \
|
||||
const EIGTYPE *x_ptr; \
|
||||
char trans=(LhsStorageOrder==ColMajor) ? 'N' : (ConjugateLhs) ? 'C' : 'T'; \
|
||||
if (LhsStorageOrder==RowMajor) { \
|
||||
m = cols; \
|
||||
n = rows; \
|
||||
m = convert_index<BlasIndex>(cols); \
|
||||
n = convert_index<BlasIndex>(rows); \
|
||||
}\
|
||||
GEMVVector x_tmp; \
|
||||
if (ConjugateRhs) { \
|
||||
@ -112,7 +113,7 @@ static void run( \
|
||||
x_ptr=x_tmp.data(); \
|
||||
incx=1; \
|
||||
} else x_ptr=rhs; \
|
||||
MKLPREFIX##gemv_(&trans, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)lhs, &lda, (const MKLTYPE*)x_ptr, &incx, &numext::real_ref(beta), (MKLTYPE*)res, &incy); \
|
||||
MKLPREFIX##gemv_(&trans, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)lhs, &lda, (const BLASTYPE*)x_ptr, &incx, &numext::real_ref(beta), (BLASTYPE*)res, &incy); \
|
||||
}\
|
||||
};
|
||||
|
||||
|
@ -40,7 +40,7 @@ namespace internal {
|
||||
|
||||
/* Optimized selfadjoint matrix * matrix (?SYMM/?HEMM) product */
|
||||
|
||||
#define EIGEN_MKL_SYMM_L(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
|
||||
#define EIGEN_MKL_SYMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, MKLPREFIX) \
|
||||
template <typename Index, \
|
||||
int LhsStorageOrder, bool ConjugateLhs, \
|
||||
int RhsStorageOrder, bool ConjugateRhs> \
|
||||
@ -55,20 +55,20 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLh
|
||||
EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
|
||||
{ \
|
||||
char side='L', uplo='L'; \
|
||||
MKL_INT m, n, lda, ldb, ldc; \
|
||||
BlasIndex m, n, lda, ldb, ldc; \
|
||||
const EIGTYPE *a, *b; \
|
||||
EIGTYPE beta(1); \
|
||||
MatrixX##EIGPREFIX b_tmp; \
|
||||
\
|
||||
/* Set transpose options */ \
|
||||
/* Set m, n, k */ \
|
||||
m = (MKL_INT)rows; \
|
||||
n = (MKL_INT)cols; \
|
||||
m = convert_index<BlasIndex>(rows); \
|
||||
n = convert_index<BlasIndex>(cols); \
|
||||
\
|
||||
/* Set lda, ldb, ldc */ \
|
||||
lda = (MKL_INT)lhsStride; \
|
||||
ldb = (MKL_INT)rhsStride; \
|
||||
ldc = (MKL_INT)resStride; \
|
||||
lda = convert_index<BlasIndex>(lhsStride); \
|
||||
ldb = convert_index<BlasIndex>(rhsStride); \
|
||||
ldc = convert_index<BlasIndex>(resStride); \
|
||||
\
|
||||
/* Set a, b, c */ \
|
||||
if (LhsStorageOrder==RowMajor) uplo='U'; \
|
||||
@ -78,16 +78,16 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLh
|
||||
Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > rhs(_rhs,n,m,OuterStride<>(rhsStride)); \
|
||||
b_tmp = rhs.adjoint(); \
|
||||
b = b_tmp.data(); \
|
||||
ldb = b_tmp.outerStride(); \
|
||||
ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
|
||||
} else b = _rhs; \
|
||||
\
|
||||
MKLPREFIX##symm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &numext::real_ref(beta), (MKLTYPE*)res, &ldc); \
|
||||
MKLPREFIX##symm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
|
||||
\
|
||||
} \
|
||||
};
|
||||
|
||||
|
||||
#define EIGEN_MKL_HEMM_L(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
|
||||
#define EIGEN_MKL_HEMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, MKLPREFIX) \
|
||||
template <typename Index, \
|
||||
int LhsStorageOrder, bool ConjugateLhs, \
|
||||
int RhsStorageOrder, bool ConjugateRhs> \
|
||||
@ -101,7 +101,7 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLh
|
||||
EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
|
||||
{ \
|
||||
char side='L', uplo='L'; \
|
||||
MKL_INT m, n, lda, ldb, ldc; \
|
||||
BlasIndex m, n, lda, ldb, ldc; \
|
||||
const EIGTYPE *a, *b; \
|
||||
EIGTYPE beta(1); \
|
||||
MatrixX##EIGPREFIX b_tmp; \
|
||||
@ -109,13 +109,13 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLh
|
||||
\
|
||||
/* Set transpose options */ \
|
||||
/* Set m, n, k */ \
|
||||
m = (MKL_INT)rows; \
|
||||
n = (MKL_INT)cols; \
|
||||
m = convert_index<BlasIndex>(rows); \
|
||||
n = convert_index<BlasIndex>(cols); \
|
||||
\
|
||||
/* Set lda, ldb, ldc */ \
|
||||
lda = (MKL_INT)lhsStride; \
|
||||
ldb = (MKL_INT)rhsStride; \
|
||||
ldc = (MKL_INT)resStride; \
|
||||
lda = convert_index<BlasIndex>(lhsStride); \
|
||||
ldb = convert_index<BlasIndex>(rhsStride); \
|
||||
ldc = convert_index<BlasIndex>(resStride); \
|
||||
\
|
||||
/* Set a, b, c */ \
|
||||
if (((LhsStorageOrder==ColMajor) && ConjugateLhs) || ((LhsStorageOrder==RowMajor) && (!ConjugateLhs))) { \
|
||||
@ -141,10 +141,10 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLh
|
||||
b_tmp = rhs.transpose(); \
|
||||
} \
|
||||
b = b_tmp.data(); \
|
||||
ldb = b_tmp.outerStride(); \
|
||||
ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
|
||||
} \
|
||||
\
|
||||
MKLPREFIX##hemm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &numext::real_ref(beta), (MKLTYPE*)res, &ldc); \
|
||||
MKLPREFIX##hemm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
|
||||
\
|
||||
} \
|
||||
};
|
||||
@ -157,7 +157,7 @@ EIGEN_MKL_HEMM_L(scomplex, float, cf, c)
|
||||
|
||||
/* Optimized matrix * selfadjoint matrix (?SYMM/?HEMM) product */
|
||||
|
||||
#define EIGEN_MKL_SYMM_R(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
|
||||
#define EIGEN_MKL_SYMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, MKLPREFIX) \
|
||||
template <typename Index, \
|
||||
int LhsStorageOrder, bool ConjugateLhs, \
|
||||
int RhsStorageOrder, bool ConjugateRhs> \
|
||||
@ -172,19 +172,19 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateL
|
||||
EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
|
||||
{ \
|
||||
char side='R', uplo='L'; \
|
||||
MKL_INT m, n, lda, ldb, ldc; \
|
||||
BlasIndex m, n, lda, ldb, ldc; \
|
||||
const EIGTYPE *a, *b; \
|
||||
EIGTYPE beta(1); \
|
||||
MatrixX##EIGPREFIX b_tmp; \
|
||||
\
|
||||
/* Set m, n, k */ \
|
||||
m = (MKL_INT)rows; \
|
||||
n = (MKL_INT)cols; \
|
||||
m = convert_index<BlasIndex>(rows); \
|
||||
n = convert_index<BlasIndex>(cols); \
|
||||
\
|
||||
/* Set lda, ldb, ldc */ \
|
||||
lda = (MKL_INT)rhsStride; \
|
||||
ldb = (MKL_INT)lhsStride; \
|
||||
ldc = (MKL_INT)resStride; \
|
||||
lda = convert_index<BlasIndex>(rhsStride); \
|
||||
ldb = convert_index<BlasIndex>(lhsStride); \
|
||||
ldc = convert_index<BlasIndex>(resStride); \
|
||||
\
|
||||
/* Set a, b, c */ \
|
||||
if (RhsStorageOrder==RowMajor) uplo='U'; \
|
||||
@ -194,16 +194,16 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateL
|
||||
Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > lhs(_lhs,n,m,OuterStride<>(rhsStride)); \
|
||||
b_tmp = lhs.adjoint(); \
|
||||
b = b_tmp.data(); \
|
||||
ldb = b_tmp.outerStride(); \
|
||||
ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
|
||||
} else b = _lhs; \
|
||||
\
|
||||
MKLPREFIX##symm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &numext::real_ref(beta), (MKLTYPE*)res, &ldc); \
|
||||
MKLPREFIX##symm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
|
||||
\
|
||||
} \
|
||||
};
|
||||
|
||||
|
||||
#define EIGEN_MKL_HEMM_R(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
|
||||
#define EIGEN_MKL_HEMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, MKLPREFIX) \
|
||||
template <typename Index, \
|
||||
int LhsStorageOrder, bool ConjugateLhs, \
|
||||
int RhsStorageOrder, bool ConjugateRhs> \
|
||||
@ -217,27 +217,27 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateL
|
||||
EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
|
||||
{ \
|
||||
char side='R', uplo='L'; \
|
||||
MKL_INT m, n, lda, ldb, ldc; \
|
||||
BlasIndex m, n, lda, ldb, ldc; \
|
||||
const EIGTYPE *a, *b; \
|
||||
EIGTYPE beta(1); \
|
||||
MatrixX##EIGPREFIX b_tmp; \
|
||||
Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> a_tmp; \
|
||||
\
|
||||
/* Set m, n, k */ \
|
||||
m = (MKL_INT)rows; \
|
||||
n = (MKL_INT)cols; \
|
||||
m = convert_index<BlasIndex>(rows); \
|
||||
n = convert_index<BlasIndex>(cols); \
|
||||
\
|
||||
/* Set lda, ldb, ldc */ \
|
||||
lda = (MKL_INT)rhsStride; \
|
||||
ldb = (MKL_INT)lhsStride; \
|
||||
ldc = (MKL_INT)resStride; \
|
||||
lda = convert_index<BlasIndex>(rhsStride); \
|
||||
ldb = convert_index<BlasIndex>(lhsStride); \
|
||||
ldc = convert_index<BlasIndex>(resStride); \
|
||||
\
|
||||
/* Set a, b, c */ \
|
||||
if (((RhsStorageOrder==ColMajor) && ConjugateRhs) || ((RhsStorageOrder==RowMajor) && (!ConjugateRhs))) { \
|
||||
Map<const Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder>, 0, OuterStride<> > rhs(_rhs,n,n,OuterStride<>(rhsStride)); \
|
||||
a_tmp = rhs.conjugate(); \
|
||||
a = a_tmp.data(); \
|
||||
lda = a_tmp.outerStride(); \
|
||||
lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
|
||||
} else a = _rhs; \
|
||||
if (RhsStorageOrder==RowMajor) uplo='U'; \
|
||||
\
|
||||
@ -259,7 +259,7 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateL
|
||||
ldb = b_tmp.outerStride(); \
|
||||
} \
|
||||
\
|
||||
MKLPREFIX##hemm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &numext::real_ref(beta), (MKLTYPE*)res, &ldc); \
|
||||
MKLPREFIX##hemm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
|
||||
} \
|
||||
};
|
||||
|
||||
|
@ -71,7 +71,7 @@ EIGEN_MKL_SYMV_SPECIALIZE(float)
|
||||
EIGEN_MKL_SYMV_SPECIALIZE(dcomplex)
|
||||
EIGEN_MKL_SYMV_SPECIALIZE(scomplex)
|
||||
|
||||
#define EIGEN_MKL_SYMV_SPECIALIZATION(EIGTYPE,MKLTYPE,MKLFUNC) \
|
||||
#define EIGEN_MKL_SYMV_SPECIALIZATION(EIGTYPE,BLASTYPE,MKLFUNC) \
|
||||
template<typename Index, int StorageOrder, int UpLo, bool ConjugateLhs, bool ConjugateRhs> \
|
||||
struct selfadjoint_matrix_vector_product_symv<EIGTYPE,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs> \
|
||||
{ \
|
||||
@ -85,7 +85,7 @@ const EIGTYPE* _rhs, EIGTYPE* res, EIGTYPE alpha) \
|
||||
IsRowMajor = StorageOrder==RowMajor ? 1 : 0, \
|
||||
IsLower = UpLo == Lower ? 1 : 0 \
|
||||
}; \
|
||||
MKL_INT n=size, lda=lhsStride, incx=1, incy=1; \
|
||||
BlasIndex n=convert_index<BlasIndex>(size), lda=convert_index<BlasIndex>(lhsStride), incx=1, incy=1; \
|
||||
EIGTYPE beta(1); \
|
||||
const EIGTYPE *x_ptr; \
|
||||
char uplo=(IsRowMajor) ? (IsLower ? 'U' : 'L') : (IsLower ? 'L' : 'U'); \
|
||||
@ -95,7 +95,7 @@ const EIGTYPE* _rhs, EIGTYPE* res, EIGTYPE alpha) \
|
||||
x_tmp=map_x.conjugate(); \
|
||||
x_ptr=x_tmp.data(); \
|
||||
} else x_ptr=_rhs; \
|
||||
MKLFUNC(&uplo, &n, &numext::real_ref(alpha), (const MKLTYPE*)lhs, &lda, (const MKLTYPE*)x_ptr, &incx, &numext::real_ref(beta), (MKLTYPE*)res, &incy); \
|
||||
MKLFUNC(&uplo, &n, &numext::real_ref(alpha), (const BLASTYPE*)lhs, &lda, (const BLASTYPE*)x_ptr, &incx, &numext::real_ref(beta), (BLASTYPE*)res, &incy); \
|
||||
}\
|
||||
};
|
||||
|
||||
|
@ -75,7 +75,7 @@ EIGEN_MKL_TRMM_SPECIALIZE(scomplex, true)
|
||||
EIGEN_MKL_TRMM_SPECIALIZE(scomplex, false)
|
||||
|
||||
// implements col-major += alpha * op(triangular) * op(general)
|
||||
#define EIGEN_MKL_TRMM_L(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
|
||||
#define EIGEN_MKL_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, MKLPREFIX) \
|
||||
template <typename Index, int Mode, \
|
||||
int LhsStorageOrder, bool ConjugateLhs, \
|
||||
int RhsStorageOrder, bool ConjugateRhs> \
|
||||
@ -122,7 +122,7 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
|
||||
/* Make sense to call GEMM */ \
|
||||
Map<const MatrixLhs, 0, OuterStride<> > lhsMap(_lhs,rows,depth,OuterStride<>(lhsStride)); \
|
||||
MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); \
|
||||
MKL_INT aStride = aa_tmp.outerStride(); \
|
||||
BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
|
||||
gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
|
||||
general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
|
||||
rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, resStride, alpha, gemm_blocking, 0); \
|
||||
@ -134,11 +134,11 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
|
||||
char side = 'L', transa, uplo, diag = 'N'; \
|
||||
EIGTYPE *b; \
|
||||
const EIGTYPE *a; \
|
||||
MKL_INT m, n, lda, ldb; \
|
||||
BlasIndex m, n, lda, ldb; \
|
||||
\
|
||||
/* Set m, n */ \
|
||||
m = (MKL_INT)diagSize; \
|
||||
n = (MKL_INT)cols; \
|
||||
m = convert_index<BlasIndex>(diagSize); \
|
||||
n = convert_index<BlasIndex>(cols); \
|
||||
\
|
||||
/* Set trans */ \
|
||||
transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \
|
||||
@ -149,7 +149,7 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
|
||||
\
|
||||
if (ConjugateRhs) b_tmp = rhs.conjugate(); else b_tmp = rhs; \
|
||||
b = b_tmp.data(); \
|
||||
ldb = b_tmp.outerStride(); \
|
||||
ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
|
||||
\
|
||||
/* Set uplo */ \
|
||||
uplo = IsLower ? 'L' : 'U'; \
|
||||
@ -165,14 +165,14 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
|
||||
else if (IsUnitDiag) \
|
||||
a_tmp.diagonal().setOnes();\
|
||||
a = a_tmp.data(); \
|
||||
lda = a_tmp.outerStride(); \
|
||||
lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
|
||||
} else { \
|
||||
a = _lhs; \
|
||||
lda = lhsStride; \
|
||||
lda = convert_index<BlasIndex>(lhsStride); \
|
||||
} \
|
||||
/*std::cout << "TRMM_L: A is square! Go to MKL TRMM implementation! \n";*/ \
|
||||
/* call ?trmm*/ \
|
||||
MKLPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (MKLTYPE*)b, &ldb); \
|
||||
MKLPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \
|
||||
\
|
||||
/* Add op(a_triangular)*b into res*/ \
|
||||
Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
|
||||
@ -186,7 +186,7 @@ EIGEN_MKL_TRMM_L(float, float, f, s)
|
||||
EIGEN_MKL_TRMM_L(scomplex, float, cf, c)
|
||||
|
||||
// implements col-major += alpha * op(general) * op(triangular)
|
||||
#define EIGEN_MKL_TRMM_R(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
|
||||
#define EIGEN_MKL_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, MKLPREFIX) \
|
||||
template <typename Index, int Mode, \
|
||||
int LhsStorageOrder, bool ConjugateLhs, \
|
||||
int RhsStorageOrder, bool ConjugateRhs> \
|
||||
@ -232,7 +232,7 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
|
||||
/* Make sense to call GEMM */ \
|
||||
Map<const MatrixRhs, 0, OuterStride<> > rhsMap(_rhs,depth,cols, OuterStride<>(rhsStride)); \
|
||||
MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); \
|
||||
MKL_INT aStride = aa_tmp.outerStride(); \
|
||||
BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
|
||||
gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
|
||||
general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
|
||||
rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, resStride, alpha, gemm_blocking, 0); \
|
||||
@ -244,11 +244,11 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
|
||||
char side = 'R', transa, uplo, diag = 'N'; \
|
||||
EIGTYPE *b; \
|
||||
const EIGTYPE *a; \
|
||||
MKL_INT m, n, lda, ldb; \
|
||||
BlasIndex m, n, lda, ldb; \
|
||||
\
|
||||
/* Set m, n */ \
|
||||
m = (MKL_INT)rows; \
|
||||
n = (MKL_INT)diagSize; \
|
||||
m = convert_index<BlasIndex>(rows); \
|
||||
n = convert_index<BlasIndex>(diagSize); \
|
||||
\
|
||||
/* Set trans */ \
|
||||
transa = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \
|
||||
@ -259,7 +259,7 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
|
||||
\
|
||||
if (ConjugateLhs) b_tmp = lhs.conjugate(); else b_tmp = lhs; \
|
||||
b = b_tmp.data(); \
|
||||
ldb = b_tmp.outerStride(); \
|
||||
ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
|
||||
\
|
||||
/* Set uplo */ \
|
||||
uplo = IsLower ? 'L' : 'U'; \
|
||||
@ -275,14 +275,14 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
|
||||
else if (IsUnitDiag) \
|
||||
a_tmp.diagonal().setOnes();\
|
||||
a = a_tmp.data(); \
|
||||
lda = a_tmp.outerStride(); \
|
||||
lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
|
||||
} else { \
|
||||
a = _rhs; \
|
||||
lda = rhsStride; \
|
||||
lda = convert_index<BlasIndex>(rhsStride); \
|
||||
} \
|
||||
/*std::cout << "TRMM_R: A is square! Go to MKL TRMM implementation! \n";*/ \
|
||||
/* call ?trmm*/ \
|
||||
MKLPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (MKLTYPE*)b, &ldb); \
|
||||
MKLPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \
|
||||
\
|
||||
/* Add op(a_triangular)*b into res*/ \
|
||||
Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
|
||||
|
@ -71,7 +71,7 @@ EIGEN_MKL_TRMV_SPECIALIZE(dcomplex)
|
||||
EIGEN_MKL_TRMV_SPECIALIZE(scomplex)
|
||||
|
||||
// implements col-major: res += alpha * op(triangular) * vector
|
||||
#define EIGEN_MKL_TRMV_CM(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
|
||||
#define EIGEN_MKL_TRMV_CM(EIGTYPE, BLASTYPE, EIGPREFIX, MKLPREFIX) \
|
||||
template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
|
||||
struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,ColMajor> { \
|
||||
enum { \
|
||||
@ -105,15 +105,15 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
|
||||
/* Square part handling */\
|
||||
\
|
||||
char trans, uplo, diag; \
|
||||
MKL_INT m, n, lda, incx, incy; \
|
||||
BlasIndex m, n, lda, incx, incy; \
|
||||
EIGTYPE const *a; \
|
||||
EIGTYPE beta(1); \
|
||||
\
|
||||
/* Set m, n */ \
|
||||
n = (MKL_INT)size; \
|
||||
lda = lhsStride; \
|
||||
n = convert_index<BlasIndex>(size); \
|
||||
lda = convert_index<BlasIndex>(lhsStride); \
|
||||
incx = 1; \
|
||||
incy = resIncr; \
|
||||
incy = convert_index<BlasIndex>(resIncr); \
|
||||
\
|
||||
/* Set uplo, trans and diag*/ \
|
||||
trans = 'N'; \
|
||||
@ -121,10 +121,10 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
|
||||
diag = IsUnitDiag ? 'U' : 'N'; \
|
||||
\
|
||||
/* call ?TRMV*/ \
|
||||
MKLPREFIX##trmv_(&uplo, &trans, &diag, &n, (const MKLTYPE*)_lhs, &lda, (MKLTYPE*)x, &incx); \
|
||||
MKLPREFIX##trmv_(&uplo, &trans, &diag, &n, (const BLASTYPE*)_lhs, &lda, (BLASTYPE*)x, &incx); \
|
||||
\
|
||||
/* Add op(a_tr)rhs into res*/ \
|
||||
MKLPREFIX##axpy_(&n, &numext::real_ref(alpha),(const MKLTYPE*)x, &incx, (MKLTYPE*)_res, &incy); \
|
||||
MKLPREFIX##axpy_(&n, &numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)_res, &incy); \
|
||||
/* Non-square case - doesn't fit to MKL ?TRMV. Fall to default triangular product*/ \
|
||||
if (size<(std::max)(rows,cols)) { \
|
||||
if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
|
||||
@ -132,17 +132,17 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
|
||||
if (size<rows) { \
|
||||
y = _res + size*resIncr; \
|
||||
a = _lhs + size; \
|
||||
m = rows-size; \
|
||||
n = size; \
|
||||
m = convert_index<BlasIndex>(rows-size); \
|
||||
n = convert_index<BlasIndex>(size); \
|
||||
} \
|
||||
else { \
|
||||
x += size; \
|
||||
y = _res; \
|
||||
a = _lhs + size*lda; \
|
||||
m = size; \
|
||||
n = cols-size; \
|
||||
m = convert_index<BlasIndex>(size); \
|
||||
n = convert_index<BlasIndex>(cols-size); \
|
||||
} \
|
||||
MKLPREFIX##gemv_(&trans, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (const MKLTYPE*)x, &incx, &numext::real_ref(beta), (MKLTYPE*)y, &incy); \
|
||||
MKLPREFIX##gemv_(&trans, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, &numext::real_ref(beta), (BLASTYPE*)y, &incy); \
|
||||
} \
|
||||
} \
|
||||
};
|
||||
@ -153,7 +153,7 @@ EIGEN_MKL_TRMV_CM(float, float, f, s)
|
||||
EIGEN_MKL_TRMV_CM(scomplex, float, cf, c)
|
||||
|
||||
// implements row-major: res += alpha * op(triangular) * vector
|
||||
#define EIGEN_MKL_TRMV_RM(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
|
||||
#define EIGEN_MKL_TRMV_RM(EIGTYPE, BLASTYPE, EIGPREFIX, MKLPREFIX) \
|
||||
template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
|
||||
struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,RowMajor> { \
|
||||
enum { \
|
||||
@ -187,15 +187,15 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
|
||||
/* Square part handling */\
|
||||
\
|
||||
char trans, uplo, diag; \
|
||||
MKL_INT m, n, lda, incx, incy; \
|
||||
BlasIndex m, n, lda, incx, incy; \
|
||||
EIGTYPE const *a; \
|
||||
EIGTYPE beta(1); \
|
||||
\
|
||||
/* Set m, n */ \
|
||||
n = (MKL_INT)size; \
|
||||
lda = lhsStride; \
|
||||
n = convert_index<BlasIndex>(size); \
|
||||
lda = convert_index<BlasIndex>(lhsStride); \
|
||||
incx = 1; \
|
||||
incy = resIncr; \
|
||||
incy = convert_index<BlasIndex>(resIncr); \
|
||||
\
|
||||
/* Set uplo, trans and diag*/ \
|
||||
trans = ConjLhs ? 'C' : 'T'; \
|
||||
@ -203,10 +203,10 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
|
||||
diag = IsUnitDiag ? 'U' : 'N'; \
|
||||
\
|
||||
/* call ?TRMV*/ \
|
||||
MKLPREFIX##trmv_(&uplo, &trans, &diag, &n, (const MKLTYPE*)_lhs, &lda, (MKLTYPE*)x, &incx); \
|
||||
MKLPREFIX##trmv_(&uplo, &trans, &diag, &n, (const BLASTYPE*)_lhs, &lda, (BLASTYPE*)x, &incx); \
|
||||
\
|
||||
/* Add op(a_tr)rhs into res*/ \
|
||||
MKLPREFIX##axpy_(&n, &numext::real_ref(alpha),(const MKLTYPE*)x, &incx, (MKLTYPE*)_res, &incy); \
|
||||
MKLPREFIX##axpy_(&n, &numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)_res, &incy); \
|
||||
/* Non-square case - doesn't fit to MKL ?TRMV. Fall to default triangular product*/ \
|
||||
if (size<(std::max)(rows,cols)) { \
|
||||
if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
|
||||
@ -214,17 +214,17 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
|
||||
if (size<rows) { \
|
||||
y = _res + size*resIncr; \
|
||||
a = _lhs + size*lda; \
|
||||
m = rows-size; \
|
||||
n = size; \
|
||||
m = convert_index<BlasIndex>(rows-size); \
|
||||
n = convert_index<BlasIndex>(size); \
|
||||
} \
|
||||
else { \
|
||||
x += size; \
|
||||
y = _res; \
|
||||
a = _lhs + size; \
|
||||
m = size; \
|
||||
n = cols-size; \
|
||||
m = convert_index<BlasIndex>(size); \
|
||||
n = convert_index<BlasIndex>(cols-size); \
|
||||
} \
|
||||
MKLPREFIX##gemv_(&trans, &n, &m, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (const MKLTYPE*)x, &incx, &numext::real_ref(beta), (MKLTYPE*)y, &incy); \
|
||||
MKLPREFIX##gemv_(&trans, &n, &m, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, &numext::real_ref(beta), (BLASTYPE*)y, &incy); \
|
||||
} \
|
||||
} \
|
||||
};
|
||||
|
@ -38,7 +38,7 @@ namespace Eigen {
|
||||
namespace internal {
|
||||
|
||||
// implements LeftSide op(triangular)^-1 * general
|
||||
#define EIGEN_MKL_TRSM_L(EIGTYPE, MKLTYPE, MKLPREFIX) \
|
||||
#define EIGEN_MKL_TRSM_L(EIGTYPE, BLASTYPE, MKLPREFIX) \
|
||||
template <typename Index, int Mode, bool Conjugate, int TriStorageOrder> \
|
||||
struct triangular_solve_matrix<EIGTYPE,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor> \
|
||||
{ \
|
||||
@ -53,11 +53,11 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheLeft,Mode,Conjugate,TriStorage
|
||||
const EIGTYPE* _tri, Index triStride, \
|
||||
EIGTYPE* _other, Index otherStride, level3_blocking<EIGTYPE,EIGTYPE>& /*blocking*/) \
|
||||
{ \
|
||||
MKL_INT m = size, n = otherSize, lda, ldb; \
|
||||
BlasIndex m = convert_index<BlasIndex>(size), n = convert_index<BlasIndex>(otherSize), lda, ldb; \
|
||||
char side = 'L', uplo, diag='N', transa; \
|
||||
/* Set alpha_ */ \
|
||||
EIGTYPE alpha(1); \
|
||||
ldb = otherStride;\
|
||||
ldb = convert_index<BlasIndex>(otherStride);\
|
||||
\
|
||||
const EIGTYPE *a; \
|
||||
/* Set trans */ \
|
||||
@ -73,14 +73,14 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheLeft,Mode,Conjugate,TriStorage
|
||||
if (conjA) { \
|
||||
a_tmp = tri.conjugate(); \
|
||||
a = a_tmp.data(); \
|
||||
lda = a_tmp.outerStride(); \
|
||||
lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
|
||||
} else { \
|
||||
a = _tri; \
|
||||
lda = triStride; \
|
||||
lda = convert_index<BlasIndex>(triStride); \
|
||||
} \
|
||||
if (IsUnitDiag) diag='U'; \
|
||||
/* call ?trsm*/ \
|
||||
MKLPREFIX##trsm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (MKLTYPE*)_other, &ldb); \
|
||||
MKLPREFIX##trsm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)_other, &ldb); \
|
||||
} \
|
||||
};
|
||||
|
||||
@ -91,7 +91,7 @@ EIGEN_MKL_TRSM_L(scomplex, float, c)
|
||||
|
||||
|
||||
// implements RightSide general * op(triangular)^-1
|
||||
#define EIGEN_MKL_TRSM_R(EIGTYPE, MKLTYPE, MKLPREFIX) \
|
||||
#define EIGEN_MKL_TRSM_R(EIGTYPE, BLASTYPE, MKLPREFIX) \
|
||||
template <typename Index, int Mode, bool Conjugate, int TriStorageOrder> \
|
||||
struct triangular_solve_matrix<EIGTYPE,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor> \
|
||||
{ \
|
||||
@ -106,11 +106,11 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheRight,Mode,Conjugate,TriStorag
|
||||
const EIGTYPE* _tri, Index triStride, \
|
||||
EIGTYPE* _other, Index otherStride, level3_blocking<EIGTYPE,EIGTYPE>& /*blocking*/) \
|
||||
{ \
|
||||
MKL_INT m = otherSize, n = size, lda, ldb; \
|
||||
BlasIndex m = convert_index<BlasIndex>(otherSize), n = convert_index<BlasIndex>(size), lda, ldb; \
|
||||
char side = 'R', uplo, diag='N', transa; \
|
||||
/* Set alpha_ */ \
|
||||
EIGTYPE alpha(1); \
|
||||
ldb = otherStride;\
|
||||
ldb = convert_index<BlasIndex>(otherStride);\
|
||||
\
|
||||
const EIGTYPE *a; \
|
||||
/* Set trans */ \
|
||||
@ -126,14 +126,14 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheRight,Mode,Conjugate,TriStorag
|
||||
if (conjA) { \
|
||||
a_tmp = tri.conjugate(); \
|
||||
a = a_tmp.data(); \
|
||||
lda = a_tmp.outerStride(); \
|
||||
lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
|
||||
} else { \
|
||||
a = _tri; \
|
||||
lda = triStride; \
|
||||
lda = convert_index<BlasIndex>(triStride); \
|
||||
} \
|
||||
if (IsUnitDiag) diag='U'; \
|
||||
/* call ?trsm*/ \
|
||||
MKLPREFIX##trsm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (MKLTYPE*)_other, &ldb); \
|
||||
MKLPREFIX##trsm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)_other, &ldb); \
|
||||
/*std::cout << "TRMS_L specialization!\n";*/ \
|
||||
} \
|
||||
};
|
||||
|
@ -114,7 +114,9 @@ typedef std::complex<double> dcomplex;
|
||||
typedef std::complex<float> scomplex;
|
||||
|
||||
#if defined(EIGEN_USE_BLAS) && !defined(EIGEN_USE_MKL)
|
||||
typedef int MKL_INT;
|
||||
typedef int BlasIndex;
|
||||
#else
|
||||
typedef MKL_INT BlasIndex;
|
||||
#endif
|
||||
|
||||
namespace internal {
|
||||
|
Loading…
x
Reference in New Issue
Block a user