Pulled latest update from trunk

This commit is contained in:
Benoit Steiner 2016-04-11 11:03:02 -07:00
commit e939b087fe
26 changed files with 716 additions and 794 deletions

View File

@ -450,14 +450,14 @@ using std::ptrdiff_t;
#include "src/Core/ArrayWrapper.h" #include "src/Core/ArrayWrapper.h"
#ifdef EIGEN_USE_BLAS #ifdef EIGEN_USE_BLAS
#include "src/Core/products/GeneralMatrixMatrix_MKL.h" #include "src/Core/products/GeneralMatrixMatrix_BLAS.h"
#include "src/Core/products/GeneralMatrixVector_MKL.h" #include "src/Core/products/GeneralMatrixVector_BLAS.h"
#include "src/Core/products/GeneralMatrixMatrixTriangular_MKL.h" #include "src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h"
#include "src/Core/products/SelfadjointMatrixMatrix_MKL.h" #include "src/Core/products/SelfadjointMatrixMatrix_BLAS.h"
#include "src/Core/products/SelfadjointMatrixVector_MKL.h" #include "src/Core/products/SelfadjointMatrixVector_BLAS.h"
#include "src/Core/products/TriangularMatrixMatrix_MKL.h" #include "src/Core/products/TriangularMatrixMatrix_BLAS.h"
#include "src/Core/products/TriangularMatrixVector_MKL.h" #include "src/Core/products/TriangularMatrixVector_BLAS.h"
#include "src/Core/products/TriangularSolverMatrix_MKL.h" #include "src/Core/products/TriangularSolverMatrix_BLAS.h"
#endif // EIGEN_USE_BLAS #endif // EIGEN_USE_BLAS
#ifdef EIGEN_USE_MKL_VML #ifdef EIGEN_USE_MKL_VML

View File

@ -788,8 +788,8 @@ template<typename Dst, typename Src> void check_for_aliasing(const Dst &dst, con
template< typename DstXprType, typename SrcXprType, typename Functor, typename Scalar> template< typename DstXprType, typename SrcXprType, typename Functor, typename Scalar>
struct Assignment<DstXprType, SrcXprType, Functor, Dense2Dense, Scalar> struct Assignment<DstXprType, SrcXprType, Functor, Dense2Dense, Scalar>
{ {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
static void run(DstXprType &dst, const SrcXprType &src, const Functor &func) static EIGEN_STRONG_INLINE void run(DstXprType &dst, const SrcXprType &src, const Functor &func)
{ {
eigen_assert(dst.rows() == src.rows() && dst.cols() == src.cols()); eigen_assert(dst.rows() == src.rows() && dst.cols() == src.cols());
@ -806,8 +806,8 @@ struct Assignment<DstXprType, SrcXprType, Functor, Dense2Dense, Scalar>
template< typename DstXprType, typename SrcXprType, typename Functor, typename Scalar> template< typename DstXprType, typename SrcXprType, typename Functor, typename Scalar>
struct Assignment<DstXprType, SrcXprType, Functor, EigenBase2EigenBase, Scalar> struct Assignment<DstXprType, SrcXprType, Functor, EigenBase2EigenBase, Scalar>
{ {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<typename DstXprType::Scalar> &/*func*/) static EIGEN_STRONG_INLINE void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<typename DstXprType::Scalar> &/*func*/)
{ {
eigen_assert(dst.rows() == src.rows() && dst.cols() == src.cols()); eigen_assert(dst.rows() == src.rows() && dst.cols() == src.cols());
src.evalTo(dst); src.evalTo(dst);

View File

@ -79,8 +79,8 @@ namespace cephes {
*/ */
template <typename Scalar, int N> template <typename Scalar, int N>
struct polevl { struct polevl {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
static Scalar run(const Scalar x, const Scalar coef[]) { static EIGEN_STRONG_INLINE Scalar run(const Scalar x, const Scalar coef[]) {
EIGEN_STATIC_ASSERT((N > 0), YOU_MADE_A_PROGRAMMING_MISTAKE); EIGEN_STATIC_ASSERT((N > 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
return polevl<Scalar, N - 1>::run(x, coef) * x + coef[N]; return polevl<Scalar, N - 1>::run(x, coef) * x + coef[N];
@ -89,8 +89,8 @@ struct polevl {
template <typename Scalar> template <typename Scalar>
struct polevl<Scalar, 0> { struct polevl<Scalar, 0> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
static Scalar run(const Scalar, const Scalar coef[]) { static EIGEN_STRONG_INLINE Scalar run(const Scalar, const Scalar coef[]) {
return coef[0]; return coef[0];
} }
}; };
@ -144,7 +144,7 @@ struct digamma_retval {
template <typename Scalar> template <typename Scalar>
struct digamma_impl { struct digamma_impl {
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
static Scalar run(Scalar x) { static EIGEN_STRONG_INLINE Scalar run(Scalar x) {
EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false), EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
THIS_TYPE_IS_NOT_SUPPORTED); THIS_TYPE_IS_NOT_SUPPORTED);
return Scalar(0); return Scalar(0);
@ -428,20 +428,20 @@ template <typename Scalar> struct igamma_impl; // predeclare igamma_impl
template <typename Scalar> template <typename Scalar>
struct igamma_helper { struct igamma_helper {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
static Scalar machep() { assert(false && "machep not supported for this type"); return 0.0; } static EIGEN_STRONG_INLINE Scalar machep() { assert(false && "machep not supported for this type"); return 0.0; }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
static Scalar big() { assert(false && "big not supported for this type"); return 0.0; } static EIGEN_STRONG_INLINE Scalar big() { assert(false && "big not supported for this type"); return 0.0; }
}; };
template <> template <>
struct igamma_helper<float> { struct igamma_helper<float> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
static float machep() { static EIGEN_STRONG_INLINE float machep() {
return NumTraits<float>::epsilon() / 2; // 1.0 - machep == 1.0 return NumTraits<float>::epsilon() / 2; // 1.0 - machep == 1.0
} }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
static float big() { static EIGEN_STRONG_INLINE float big() {
// use epsneg (1.0 - epsneg == 1.0) // use epsneg (1.0 - epsneg == 1.0)
return 1.0 / (NumTraits<float>::epsilon() / 2); return 1.0 / (NumTraits<float>::epsilon() / 2);
} }
@ -449,12 +449,12 @@ struct igamma_helper<float> {
template <> template <>
struct igamma_helper<double> { struct igamma_helper<double> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
static double machep() { static EIGEN_STRONG_INLINE double machep() {
return NumTraits<double>::epsilon() / 2; // 1.0 - machep == 1.0 return NumTraits<double>::epsilon() / 2; // 1.0 - machep == 1.0
} }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
static double big() { static EIGEN_STRONG_INLINE double big() {
return 1.0 / NumTraits<double>::epsilon(); return 1.0 / NumTraits<double>::epsilon();
} }
}; };
@ -605,7 +605,7 @@ struct igamma_retval {
template <typename Scalar> template <typename Scalar>
struct igamma_impl { struct igamma_impl {
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
static Scalar run(Scalar a, Scalar x) { static EIGEN_STRONG_INLINE Scalar run(Scalar a, Scalar x) {
EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false), EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
THIS_TYPE_IS_NOT_SUPPORTED); THIS_TYPE_IS_NOT_SUPPORTED);
return Scalar(0); return Scalar(0);
@ -736,7 +736,7 @@ struct zeta_retval {
template <typename Scalar> template <typename Scalar>
struct zeta_impl { struct zeta_impl {
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
static Scalar run(Scalar x, Scalar q) { static EIGEN_STRONG_INLINE Scalar run(Scalar x, Scalar q) {
EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false), EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
THIS_TYPE_IS_NOT_SUPPORTED); THIS_TYPE_IS_NOT_SUPPORTED);
return Scalar(0); return Scalar(0);
@ -757,8 +757,8 @@ struct zeta_impl_series {
template <> template <>
struct zeta_impl_series<float> { struct zeta_impl_series<float> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
static bool run(float& a, float& b, float& s, const float x, const float machep) { static EIGEN_STRONG_INLINE bool run(float& a, float& b, float& s, const float x, const float machep) {
int i = 0; int i = 0;
while(i < 9) while(i < 9)
{ {
@ -777,8 +777,8 @@ struct zeta_impl_series<float> {
template <> template <>
struct zeta_impl_series<double> { struct zeta_impl_series<double> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
static bool run(double& a, double& b, double& s, const double x, const double machep) { static EIGEN_STRONG_INLINE bool run(double& a, double& b, double& s, const double x, const double machep) {
int i = 0; int i = 0;
while( (i < 9) || (a <= 9.0) ) while( (i < 9) || (a <= 9.0) )
{ {
@ -881,13 +881,14 @@ struct zeta_impl {
const Scalar maxnum = NumTraits<Scalar>::infinity(); const Scalar maxnum = NumTraits<Scalar>::infinity();
const Scalar zero = 0.0, half = 0.5, one = 1.0; const Scalar zero = 0.0, half = 0.5, one = 1.0;
const Scalar machep = igamma_helper<Scalar>::machep(); const Scalar machep = igamma_helper<Scalar>::machep();
const Scalar nan = NumTraits<Scalar>::quiet_NaN();
if( x == one ) if( x == one )
return maxnum; return maxnum;
if( x < one ) if( x < one )
{ {
return zero; return nan;
} }
if( q <= zero ) if( q <= zero )
@ -899,7 +900,7 @@ struct zeta_impl {
p = x; p = x;
r = numext::floor(p); r = numext::floor(p);
if (p != r) if (p != r)
return zero; return nan;
} }
/* Permit negative q but continue sum until n+q > +9 . /* Permit negative q but continue sum until n+q > +9 .
@ -954,7 +955,7 @@ struct polygamma_retval {
template <typename Scalar> template <typename Scalar>
struct polygamma_impl { struct polygamma_impl {
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
static Scalar run(Scalar n, Scalar x) { static EIGEN_STRONG_INLINE Scalar run(Scalar n, Scalar x) {
EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false), EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
THIS_TYPE_IS_NOT_SUPPORTED); THIS_TYPE_IS_NOT_SUPPORTED);
return Scalar(0); return Scalar(0);
@ -969,9 +970,14 @@ struct polygamma_impl {
static Scalar run(Scalar n, Scalar x) { static Scalar run(Scalar n, Scalar x) {
Scalar zero = 0.0, one = 1.0; Scalar zero = 0.0, one = 1.0;
Scalar nplus = n + one; Scalar nplus = n + one;
const Scalar nan = NumTraits<Scalar>::quiet_NaN();
// Check that n is an integer
if (numext::floor(n) != n) {
return nan;
}
// Just return the digamma function for n = 1 // Just return the digamma function for n = 1
if (n == zero) { else if (n == zero) {
return digamma_impl<Scalar>::run(x); return digamma_impl<Scalar>::run(x);
} }
// Use the same implementation as scipy // Use the same implementation as scipy

View File

@ -25,13 +25,13 @@
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
******************************************************************************** ********************************************************************************
* Content : Eigen bindings to Intel(R) MKL * Content : Eigen bindings to BLAS F77
* Level 3 BLAS SYRK/HERK implementation. * Level 3 BLAS SYRK/HERK implementation.
******************************************************************************** ********************************************************************************
*/ */
#ifndef EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_MKL_H #ifndef EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_BLAS_H
#define EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_MKL_H #define EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_BLAS_H
namespace Eigen { namespace Eigen {
@ -44,34 +44,35 @@ struct general_matrix_matrix_rankupdate :
// try to go to BLAS specialization // try to go to BLAS specialization
#define EIGEN_MKL_RANKUPDATE_SPECIALIZE(Scalar) \ #define EIGEN_BLAS_RANKUPDATE_SPECIALIZE(Scalar) \
template <typename Index, int LhsStorageOrder, bool ConjugateLhs, \ template <typename Index, int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs, int UpLo> \ int RhsStorageOrder, bool ConjugateRhs, int UpLo> \
struct general_matrix_matrix_triangular_product<Index,Scalar,LhsStorageOrder,ConjugateLhs, \ struct general_matrix_matrix_triangular_product<Index,Scalar,LhsStorageOrder,ConjugateLhs, \
Scalar,RhsStorageOrder,ConjugateRhs,ColMajor,UpLo,Specialized> { \ Scalar,RhsStorageOrder,ConjugateRhs,ColMajor,UpLo,Specialized> { \
static EIGEN_STRONG_INLINE void run(Index size, Index depth,const Scalar* lhs, Index lhsStride, \ static EIGEN_STRONG_INLINE void run(Index size, Index depth,const Scalar* lhs, Index lhsStride, \
const Scalar* rhs, Index rhsStride, Scalar* res, Index resStride, Scalar alpha) \ const Scalar* rhs, Index rhsStride, Scalar* res, Index resStride, Scalar alpha, level3_blocking<Scalar, Scalar>& blocking) \
{ \ { \
if (lhs==rhs) { \ if (lhs==rhs) { \
general_matrix_matrix_rankupdate<Index,Scalar,LhsStorageOrder,ConjugateLhs,ColMajor,UpLo> \ general_matrix_matrix_rankupdate<Index,Scalar,LhsStorageOrder,ConjugateLhs,ColMajor,UpLo> \
::run(size,depth,lhs,lhsStride,rhs,rhsStride,res,resStride,alpha); \ ::run(size,depth,lhs,lhsStride,rhs,rhsStride,res,resStride,alpha,blocking); \
} else { \ } else { \
general_matrix_matrix_triangular_product<Index, \ general_matrix_matrix_triangular_product<Index, \
Scalar, LhsStorageOrder, ConjugateLhs, \ Scalar, LhsStorageOrder, ConjugateLhs, \
Scalar, RhsStorageOrder, ConjugateRhs, \ Scalar, RhsStorageOrder, ConjugateRhs, \
ColMajor, UpLo, BuiltIn> \ ColMajor, UpLo, BuiltIn> \
::run(size,depth,lhs,lhsStride,rhs,rhsStride,res,resStride,alpha); \ ::run(size,depth,lhs,lhsStride,rhs,rhsStride,res,resStride,alpha,blocking); \
} \ } \
} \ } \
}; };
EIGEN_MKL_RANKUPDATE_SPECIALIZE(double) EIGEN_BLAS_RANKUPDATE_SPECIALIZE(double)
//EIGEN_MKL_RANKUPDATE_SPECIALIZE(dcomplex) EIGEN_BLAS_RANKUPDATE_SPECIALIZE(float)
EIGEN_MKL_RANKUPDATE_SPECIALIZE(float) // TODO handle complex cases
//EIGEN_MKL_RANKUPDATE_SPECIALIZE(scomplex) // EIGEN_BLAS_RANKUPDATE_SPECIALIZE(dcomplex)
// EIGEN_BLAS_RANKUPDATE_SPECIALIZE(scomplex)
// SYRK for float/double // SYRK for float/double
#define EIGEN_MKL_RANKUPDATE_R(EIGTYPE, MKLTYPE, MKLFUNC) \ #define EIGEN_BLAS_RANKUPDATE_R(EIGTYPE, BLASTYPE, BLASFUNC) \
template <typename Index, int AStorageOrder, bool ConjugateA, int UpLo> \ template <typename Index, int AStorageOrder, bool ConjugateA, int UpLo> \
struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,ColMajor,UpLo> { \ struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,ColMajor,UpLo> { \
enum { \ enum { \
@ -80,23 +81,19 @@ struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,C
conjA = ((AStorageOrder==ColMajor) && ConjugateA) ? 1 : 0 \ conjA = ((AStorageOrder==ColMajor) && ConjugateA) ? 1 : 0 \
}; \ }; \
static EIGEN_STRONG_INLINE void run(Index size, Index depth,const EIGTYPE* lhs, Index lhsStride, \ static EIGEN_STRONG_INLINE void run(Index size, Index depth,const EIGTYPE* lhs, Index lhsStride, \
const EIGTYPE* rhs, Index rhsStride, EIGTYPE* res, Index resStride, EIGTYPE alpha) \ const EIGTYPE* /*rhs*/, Index /*rhsStride*/, EIGTYPE* res, Index resStride, EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
{ \ { \
/* typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs;*/ \ /* typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs;*/ \
\ \
MKL_INT lda=lhsStride, ldc=resStride, n=size, k=depth; \ BlasIndex lda=convert_index<BlasIndex>(lhsStride), ldc=convert_index<BlasIndex>(resStride), n=convert_index<BlasIndex>(size), k=convert_index<BlasIndex>(depth); \
char uplo=(IsLower) ? 'L' : 'U', trans=(AStorageOrder==RowMajor) ? 'T':'N'; \ char uplo=(IsLower) ? 'L' : 'U', trans=(AStorageOrder==RowMajor) ? 'T':'N'; \
MKLTYPE alpha_, beta_; \ EIGTYPE beta; \
\ BLASFUNC(&uplo, &trans, &n, &k, &numext::real_ref(alpha), lhs, &lda, &numext::real_ref(beta), res, &ldc); \
/* Set alpha_ & beta_ */ \
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(beta_, EIGTYPE(1)); \
MKLFUNC(&uplo, &trans, &n, &k, &alpha_, lhs, &lda, &beta_, res, &ldc); \
} \ } \
}; };
// HERK for complex data // HERK for complex data
#define EIGEN_MKL_RANKUPDATE_C(EIGTYPE, MKLTYPE, RTYPE, MKLFUNC) \ #define EIGEN_BLAS_RANKUPDATE_C(EIGTYPE, BLASTYPE, RTYPE, BLASFUNC) \
template <typename Index, int AStorageOrder, bool ConjugateA, int UpLo> \ template <typename Index, int AStorageOrder, bool ConjugateA, int UpLo> \
struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,ColMajor,UpLo> { \ struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,ColMajor,UpLo> { \
enum { \ enum { \
@ -105,18 +102,15 @@ struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,C
conjA = (((AStorageOrder==ColMajor) && ConjugateA) || ((AStorageOrder==RowMajor) && !ConjugateA)) ? 1 : 0 \ conjA = (((AStorageOrder==ColMajor) && ConjugateA) || ((AStorageOrder==RowMajor) && !ConjugateA)) ? 1 : 0 \
}; \ }; \
static EIGEN_STRONG_INLINE void run(Index size, Index depth,const EIGTYPE* lhs, Index lhsStride, \ static EIGEN_STRONG_INLINE void run(Index size, Index depth,const EIGTYPE* lhs, Index lhsStride, \
const EIGTYPE* rhs, Index rhsStride, EIGTYPE* res, Index resStride, EIGTYPE alpha) \ const EIGTYPE* /*rhs*/, Index /*rhsStride*/, EIGTYPE* res, Index resStride, EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
{ \ { \
typedef Matrix<EIGTYPE, Dynamic, Dynamic, AStorageOrder> MatrixType; \ typedef Matrix<EIGTYPE, Dynamic, Dynamic, AStorageOrder> MatrixType; \
\ \
MKL_INT lda=lhsStride, ldc=resStride, n=size, k=depth; \ BlasIndex lda=convert_index<BlasIndex>(lhsStride), ldc=convert_index<BlasIndex>(resStride), n=convert_index<BlasIndex>(size), k=convert_index<BlasIndex>(depth); \
char uplo=(IsLower) ? 'L' : 'U', trans=(AStorageOrder==RowMajor) ? 'C':'N'; \ char uplo=(IsLower) ? 'L' : 'U', trans=(AStorageOrder==RowMajor) ? 'C':'N'; \
RTYPE alpha_, beta_; \ RTYPE alpha_, beta_; \
const EIGTYPE* a_ptr; \ const EIGTYPE* a_ptr; \
\ \
/* Set alpha_ & beta_ */ \
/* assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); */\
/* assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(beta_, EIGTYPE(1));*/ \
alpha_ = alpha.real(); \ alpha_ = alpha.real(); \
beta_ = 1.0; \ beta_ = 1.0; \
/* Copy with conjugation in some cases*/ \ /* Copy with conjugation in some cases*/ \
@ -127,20 +121,21 @@ struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,C
lda = a.outerStride(); \ lda = a.outerStride(); \
a_ptr = a.data(); \ a_ptr = a.data(); \
} else a_ptr=lhs; \ } else a_ptr=lhs; \
MKLFUNC(&uplo, &trans, &n, &k, &alpha_, (MKLTYPE*)a_ptr, &lda, &beta_, (MKLTYPE*)res, &ldc); \ BLASFUNC(&uplo, &trans, &n, &k, &alpha_, (BLASTYPE*)a_ptr, &lda, &beta_, (BLASTYPE*)res, &ldc); \
} \ } \
}; };
EIGEN_MKL_RANKUPDATE_R(double, double, dsyrk) EIGEN_BLAS_RANKUPDATE_R(double, double, dsyrk_)
EIGEN_MKL_RANKUPDATE_R(float, float, ssyrk) EIGEN_BLAS_RANKUPDATE_R(float, float, ssyrk_)
//EIGEN_MKL_RANKUPDATE_C(dcomplex, MKL_Complex16, double, zherk) // TODO hanlde complex cases
//EIGEN_MKL_RANKUPDATE_C(scomplex, MKL_Complex8, double, cherk) // EIGEN_BLAS_RANKUPDATE_C(dcomplex, double, double, zherk_)
// EIGEN_BLAS_RANKUPDATE_C(scomplex, float, float, cherk_)
} // end namespace internal } // end namespace internal
} // end namespace Eigen } // end namespace Eigen
#endif // EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_MKL_H #endif // EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_BLAS_H

View File

@ -25,13 +25,13 @@
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
******************************************************************************** ********************************************************************************
* Content : Eigen bindings to Intel(R) MKL * Content : Eigen bindings to BLAS F77
* General matrix-matrix product functionality based on ?GEMM. * General matrix-matrix product functionality based on ?GEMM.
******************************************************************************** ********************************************************************************
*/ */
#ifndef EIGEN_GENERAL_MATRIX_MATRIX_MKL_H #ifndef EIGEN_GENERAL_MATRIX_MATRIX_BLAS_H
#define EIGEN_GENERAL_MATRIX_MATRIX_MKL_H #define EIGEN_GENERAL_MATRIX_MATRIX_BLAS_H
namespace Eigen { namespace Eigen {
@ -46,7 +46,7 @@ namespace internal {
// gemm specialization // gemm specialization
#define GEMM_SPECIALIZATION(EIGTYPE, EIGPREFIX, MKLTYPE, MKLPREFIX) \ #define GEMM_SPECIALIZATION(EIGTYPE, EIGPREFIX, BLASTYPE, BLASPREFIX) \
template< \ template< \
typename Index, \ typename Index, \
int LhsStorageOrder, bool ConjugateLhs, \ int LhsStorageOrder, bool ConjugateLhs, \
@ -66,55 +66,50 @@ static void run(Index rows, Index cols, Index depth, \
using std::conj; \ using std::conj; \
\ \
char transa, transb; \ char transa, transb; \
MKL_INT m, n, k, lda, ldb, ldc; \ BlasIndex m, n, k, lda, ldb, ldc; \
const EIGTYPE *a, *b; \ const EIGTYPE *a, *b; \
MKLTYPE alpha_, beta_; \ EIGTYPE beta(1); \
MatrixX##EIGPREFIX a_tmp, b_tmp; \ MatrixX##EIGPREFIX a_tmp, b_tmp; \
EIGTYPE myone(1);\
\ \
/* Set transpose options */ \ /* Set transpose options */ \
transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \ transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \
transb = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \ transb = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \
\ \
/* Set m, n, k */ \ /* Set m, n, k */ \
m = (MKL_INT)rows; \ m = convert_index<BlasIndex>(rows); \
n = (MKL_INT)cols; \ n = convert_index<BlasIndex>(cols); \
k = (MKL_INT)depth; \ k = convert_index<BlasIndex>(depth); \
\
/* Set alpha_ & beta_ */ \
assign_scalar_eig2mkl(alpha_, alpha); \
assign_scalar_eig2mkl(beta_, myone); \
\ \
/* Set lda, ldb, ldc */ \ /* Set lda, ldb, ldc */ \
lda = (MKL_INT)lhsStride; \ lda = convert_index<BlasIndex>(lhsStride); \
ldb = (MKL_INT)rhsStride; \ ldb = convert_index<BlasIndex>(rhsStride); \
ldc = (MKL_INT)resStride; \ ldc = convert_index<BlasIndex>(resStride); \
\ \
/* Set a, b, c */ \ /* Set a, b, c */ \
if ((LhsStorageOrder==ColMajor) && (ConjugateLhs)) { \ if ((LhsStorageOrder==ColMajor) && (ConjugateLhs)) { \
Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > lhs(_lhs,m,k,OuterStride<>(lhsStride)); \ Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > lhs(_lhs,m,k,OuterStride<>(lhsStride)); \
a_tmp = lhs.conjugate(); \ a_tmp = lhs.conjugate(); \
a = a_tmp.data(); \ a = a_tmp.data(); \
lda = a_tmp.outerStride(); \ lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
} else a = _lhs; \ } else a = _lhs; \
\ \
if ((RhsStorageOrder==ColMajor) && (ConjugateRhs)) { \ if ((RhsStorageOrder==ColMajor) && (ConjugateRhs)) { \
Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > rhs(_rhs,k,n,OuterStride<>(rhsStride)); \ Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > rhs(_rhs,k,n,OuterStride<>(rhsStride)); \
b_tmp = rhs.conjugate(); \ b_tmp = rhs.conjugate(); \
b = b_tmp.data(); \ b = b_tmp.data(); \
ldb = b_tmp.outerStride(); \ ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
} else b = _rhs; \ } else b = _rhs; \
\ \
MKLPREFIX##gemm(&transa, &transb, &m, &n, &k, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &beta_, (MKLTYPE*)res, &ldc); \ BLASPREFIX##gemm_(&transa, &transb, &m, &n, &k, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
}}; }};
GEMM_SPECIALIZATION(double, d, double, d) GEMM_SPECIALIZATION(double, d, double, d)
GEMM_SPECIALIZATION(float, f, float, s) GEMM_SPECIALIZATION(float, f, float, s)
GEMM_SPECIALIZATION(dcomplex, cd, MKL_Complex16, z) GEMM_SPECIALIZATION(dcomplex, cd, double, z)
GEMM_SPECIALIZATION(scomplex, cf, MKL_Complex8, c) GEMM_SPECIALIZATION(scomplex, cf, float, c)
} // end namespase internal } // end namespase internal
} // end namespace Eigen } // end namespace Eigen
#endif // EIGEN_GENERAL_MATRIX_MATRIX_MKL_H #endif // EIGEN_GENERAL_MATRIX_MATRIX_BLAS_H

View File

@ -25,13 +25,13 @@
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
******************************************************************************** ********************************************************************************
* Content : Eigen bindings to Intel(R) MKL * Content : Eigen bindings to BLAS F77
* General matrix-vector product functionality based on ?GEMV. * General matrix-vector product functionality based on ?GEMV.
******************************************************************************** ********************************************************************************
*/ */
#ifndef EIGEN_GENERAL_MATRIX_VECTOR_MKL_H #ifndef EIGEN_GENERAL_MATRIX_VECTOR_BLAS_H
#define EIGEN_GENERAL_MATRIX_VECTOR_MKL_H #define EIGEN_GENERAL_MATRIX_VECTOR_BLAS_H
namespace Eigen { namespace Eigen {
@ -49,7 +49,7 @@ namespace internal {
template<typename Index, typename LhsScalar, int StorageOrder, bool ConjugateLhs, typename RhsScalar, bool ConjugateRhs> template<typename Index, typename LhsScalar, int StorageOrder, bool ConjugateLhs, typename RhsScalar, bool ConjugateRhs>
struct general_matrix_vector_product_gemv; struct general_matrix_vector_product_gemv;
#define EIGEN_MKL_GEMV_SPECIALIZE(Scalar) \ #define EIGEN_BLAS_GEMV_SPECIALIZE(Scalar) \
template<typename Index, bool ConjugateLhs, bool ConjugateRhs> \ template<typename Index, bool ConjugateLhs, bool ConjugateRhs> \
struct general_matrix_vector_product<Index,Scalar,const_blas_data_mapper<Scalar,Index,ColMajor>,ColMajor,ConjugateLhs,Scalar,const_blas_data_mapper<Scalar,Index,RowMajor>,ConjugateRhs,Specialized> { \ struct general_matrix_vector_product<Index,Scalar,const_blas_data_mapper<Scalar,Index,ColMajor>,ColMajor,ConjugateLhs,Scalar,const_blas_data_mapper<Scalar,Index,RowMajor>,ConjugateRhs,Specialized> { \
static void run( \ static void run( \
@ -80,12 +80,12 @@ static void run( \
} \ } \
}; \ }; \
EIGEN_MKL_GEMV_SPECIALIZE(double) EIGEN_BLAS_GEMV_SPECIALIZE(double)
EIGEN_MKL_GEMV_SPECIALIZE(float) EIGEN_BLAS_GEMV_SPECIALIZE(float)
EIGEN_MKL_GEMV_SPECIALIZE(dcomplex) EIGEN_BLAS_GEMV_SPECIALIZE(dcomplex)
EIGEN_MKL_GEMV_SPECIALIZE(scomplex) EIGEN_BLAS_GEMV_SPECIALIZE(scomplex)
#define EIGEN_MKL_GEMV_SPECIALIZATION(EIGTYPE,MKLTYPE,MKLPREFIX) \ #define EIGEN_BLAS_GEMV_SPECIALIZATION(EIGTYPE,BLASTYPE,BLASPREFIX) \
template<typename Index, int LhsStorageOrder, bool ConjugateLhs, bool ConjugateRhs> \ template<typename Index, int LhsStorageOrder, bool ConjugateLhs, bool ConjugateRhs> \
struct general_matrix_vector_product_gemv<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,ConjugateRhs> \ struct general_matrix_vector_product_gemv<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,ConjugateRhs> \
{ \ { \
@ -97,16 +97,15 @@ static void run( \
const EIGTYPE* rhs, Index rhsIncr, \ const EIGTYPE* rhs, Index rhsIncr, \
EIGTYPE* res, Index resIncr, EIGTYPE alpha) \ EIGTYPE* res, Index resIncr, EIGTYPE alpha) \
{ \ { \
MKL_INT m=rows, n=cols, lda=lhsStride, incx=rhsIncr, incy=resIncr; \ BlasIndex m=convert_index<BlasIndex>(rows), n=convert_index<BlasIndex>(cols), \
MKLTYPE alpha_, beta_; \ lda=convert_index<BlasIndex>(lhsStride), incx=convert_index<BlasIndex>(rhsIncr), incy=convert_index<BlasIndex>(resIncr); \
const EIGTYPE *x_ptr, myone(1); \ const EIGTYPE beta(1); \
const EIGTYPE *x_ptr; \
char trans=(LhsStorageOrder==ColMajor) ? 'N' : (ConjugateLhs) ? 'C' : 'T'; \ char trans=(LhsStorageOrder==ColMajor) ? 'N' : (ConjugateLhs) ? 'C' : 'T'; \
if (LhsStorageOrder==RowMajor) { \ if (LhsStorageOrder==RowMajor) { \
m=cols; \ m = convert_index<BlasIndex>(cols); \
n=rows; \ n = convert_index<BlasIndex>(rows); \
}\ }\
assign_scalar_eig2mkl(alpha_, alpha); \
assign_scalar_eig2mkl(beta_, myone); \
GEMVVector x_tmp; \ GEMVVector x_tmp; \
if (ConjugateRhs) { \ if (ConjugateRhs) { \
Map<const GEMVVector, 0, InnerStride<> > map_x(rhs,cols,1,InnerStride<>(incx)); \ Map<const GEMVVector, 0, InnerStride<> > map_x(rhs,cols,1,InnerStride<>(incx)); \
@ -114,17 +113,17 @@ static void run( \
x_ptr=x_tmp.data(); \ x_ptr=x_tmp.data(); \
incx=1; \ incx=1; \
} else x_ptr=rhs; \ } else x_ptr=rhs; \
MKLPREFIX##gemv(&trans, &m, &n, &alpha_, (const MKLTYPE*)lhs, &lda, (const MKLTYPE*)x_ptr, &incx, &beta_, (MKLTYPE*)res, &incy); \ BLASPREFIX##gemv_(&trans, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)lhs, &lda, (const BLASTYPE*)x_ptr, &incx, &numext::real_ref(beta), (BLASTYPE*)res, &incy); \
}\ }\
}; };
EIGEN_MKL_GEMV_SPECIALIZATION(double, double, d) EIGEN_BLAS_GEMV_SPECIALIZATION(double, double, d)
EIGEN_MKL_GEMV_SPECIALIZATION(float, float, s) EIGEN_BLAS_GEMV_SPECIALIZATION(float, float, s)
EIGEN_MKL_GEMV_SPECIALIZATION(dcomplex, MKL_Complex16, z) EIGEN_BLAS_GEMV_SPECIALIZATION(dcomplex, double, z)
EIGEN_MKL_GEMV_SPECIALIZATION(scomplex, MKL_Complex8, c) EIGEN_BLAS_GEMV_SPECIALIZATION(scomplex, float, c)
} // end namespase internal } // end namespase internal
} // end namespace Eigen } // end namespace Eigen
#endif // EIGEN_GENERAL_MATRIX_VECTOR_MKL_H #endif // EIGEN_GENERAL_MATRIX_VECTOR_BLAS_H

View File

@ -25,13 +25,13 @@
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// //
******************************************************************************** ********************************************************************************
* Content : Eigen bindings to Intel(R) MKL * Content : Eigen bindings to BLAS F77
* Self adjoint matrix * matrix product functionality based on ?SYMM/?HEMM. * Self adjoint matrix * matrix product functionality based on ?SYMM/?HEMM.
******************************************************************************** ********************************************************************************
*/ */
#ifndef EIGEN_SELFADJOINT_MATRIX_MATRIX_MKL_H #ifndef EIGEN_SELFADJOINT_MATRIX_MATRIX_BLAS_H
#define EIGEN_SELFADJOINT_MATRIX_MATRIX_MKL_H #define EIGEN_SELFADJOINT_MATRIX_MATRIX_BLAS_H
namespace Eigen { namespace Eigen {
@ -40,7 +40,7 @@ namespace internal {
/* Optimized selfadjoint matrix * matrix (?SYMM/?HEMM) product */ /* Optimized selfadjoint matrix * matrix (?SYMM/?HEMM) product */
#define EIGEN_MKL_SYMM_L(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \ #define EIGEN_BLAS_SYMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
template <typename Index, \ template <typename Index, \
int LhsStorageOrder, bool ConjugateLhs, \ int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \ int RhsStorageOrder, bool ConjugateRhs> \
@ -52,28 +52,23 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLh
const EIGTYPE* _lhs, Index lhsStride, \ const EIGTYPE* _lhs, Index lhsStride, \
const EIGTYPE* _rhs, Index rhsStride, \ const EIGTYPE* _rhs, Index rhsStride, \
EIGTYPE* res, Index resStride, \ EIGTYPE* res, Index resStride, \
EIGTYPE alpha) \ EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
{ \ { \
char side='L', uplo='L'; \ char side='L', uplo='L'; \
MKL_INT m, n, lda, ldb, ldc; \ BlasIndex m, n, lda, ldb, ldc; \
const EIGTYPE *a, *b; \ const EIGTYPE *a, *b; \
MKLTYPE alpha_, beta_; \ EIGTYPE beta(1); \
MatrixX##EIGPREFIX b_tmp; \ MatrixX##EIGPREFIX b_tmp; \
EIGTYPE myone(1);\
\ \
/* Set transpose options */ \ /* Set transpose options */ \
/* Set m, n, k */ \ /* Set m, n, k */ \
m = (MKL_INT)rows; \ m = convert_index<BlasIndex>(rows); \
n = (MKL_INT)cols; \ n = convert_index<BlasIndex>(cols); \
\
/* Set alpha_ & beta_ */ \
assign_scalar_eig2mkl(alpha_, alpha); \
assign_scalar_eig2mkl(beta_, myone); \
\ \
/* Set lda, ldb, ldc */ \ /* Set lda, ldb, ldc */ \
lda = (MKL_INT)lhsStride; \ lda = convert_index<BlasIndex>(lhsStride); \
ldb = (MKL_INT)rhsStride; \ ldb = convert_index<BlasIndex>(rhsStride); \
ldc = (MKL_INT)resStride; \ ldc = convert_index<BlasIndex>(resStride); \
\ \
/* Set a, b, c */ \ /* Set a, b, c */ \
if (LhsStorageOrder==RowMajor) uplo='U'; \ if (LhsStorageOrder==RowMajor) uplo='U'; \
@ -83,16 +78,16 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLh
Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > rhs(_rhs,n,m,OuterStride<>(rhsStride)); \ Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > rhs(_rhs,n,m,OuterStride<>(rhsStride)); \
b_tmp = rhs.adjoint(); \ b_tmp = rhs.adjoint(); \
b = b_tmp.data(); \ b = b_tmp.data(); \
ldb = b_tmp.outerStride(); \ ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
} else b = _rhs; \ } else b = _rhs; \
\ \
MKLPREFIX##symm(&side, &uplo, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &beta_, (MKLTYPE*)res, &ldc); \ BLASPREFIX##symm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
\ \
} \ } \
}; };
#define EIGEN_MKL_HEMM_L(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \ #define EIGEN_BLAS_HEMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
template <typename Index, \ template <typename Index, \
int LhsStorageOrder, bool ConjugateLhs, \ int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \ int RhsStorageOrder, bool ConjugateRhs> \
@ -103,29 +98,24 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLh
const EIGTYPE* _lhs, Index lhsStride, \ const EIGTYPE* _lhs, Index lhsStride, \
const EIGTYPE* _rhs, Index rhsStride, \ const EIGTYPE* _rhs, Index rhsStride, \
EIGTYPE* res, Index resStride, \ EIGTYPE* res, Index resStride, \
EIGTYPE alpha) \ EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
{ \ { \
char side='L', uplo='L'; \ char side='L', uplo='L'; \
MKL_INT m, n, lda, ldb, ldc; \ BlasIndex m, n, lda, ldb, ldc; \
const EIGTYPE *a, *b; \ const EIGTYPE *a, *b; \
MKLTYPE alpha_, beta_; \ EIGTYPE beta(1); \
MatrixX##EIGPREFIX b_tmp; \ MatrixX##EIGPREFIX b_tmp; \
Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> a_tmp; \ Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> a_tmp; \
EIGTYPE myone(1); \
\ \
/* Set transpose options */ \ /* Set transpose options */ \
/* Set m, n, k */ \ /* Set m, n, k */ \
m = (MKL_INT)rows; \ m = convert_index<BlasIndex>(rows); \
n = (MKL_INT)cols; \ n = convert_index<BlasIndex>(cols); \
\
/* Set alpha_ & beta_ */ \
assign_scalar_eig2mkl(alpha_, alpha); \
assign_scalar_eig2mkl(beta_, myone); \
\ \
/* Set lda, ldb, ldc */ \ /* Set lda, ldb, ldc */ \
lda = (MKL_INT)lhsStride; \ lda = convert_index<BlasIndex>(lhsStride); \
ldb = (MKL_INT)rhsStride; \ ldb = convert_index<BlasIndex>(rhsStride); \
ldc = (MKL_INT)resStride; \ ldc = convert_index<BlasIndex>(resStride); \
\ \
/* Set a, b, c */ \ /* Set a, b, c */ \
if (((LhsStorageOrder==ColMajor) && ConjugateLhs) || ((LhsStorageOrder==RowMajor) && (!ConjugateLhs))) { \ if (((LhsStorageOrder==ColMajor) && ConjugateLhs) || ((LhsStorageOrder==RowMajor) && (!ConjugateLhs))) { \
@ -151,23 +141,23 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLh
b_tmp = rhs.transpose(); \ b_tmp = rhs.transpose(); \
} \ } \
b = b_tmp.data(); \ b = b_tmp.data(); \
ldb = b_tmp.outerStride(); \ ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
} \ } \
\ \
MKLPREFIX##hemm(&side, &uplo, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &beta_, (MKLTYPE*)res, &ldc); \ BLASPREFIX##hemm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
\ \
} \ } \
}; };
EIGEN_MKL_SYMM_L(double, double, d, d) EIGEN_BLAS_SYMM_L(double, double, d, d)
EIGEN_MKL_SYMM_L(float, float, f, s) EIGEN_BLAS_SYMM_L(float, float, f, s)
EIGEN_MKL_HEMM_L(dcomplex, MKL_Complex16, cd, z) EIGEN_BLAS_HEMM_L(dcomplex, double, cd, z)
EIGEN_MKL_HEMM_L(scomplex, MKL_Complex8, cf, c) EIGEN_BLAS_HEMM_L(scomplex, float, cf, c)
/* Optimized matrix * selfadjoint matrix (?SYMM/?HEMM) product */ /* Optimized matrix * selfadjoint matrix (?SYMM/?HEMM) product */
#define EIGEN_MKL_SYMM_R(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \ #define EIGEN_BLAS_SYMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
template <typename Index, \ template <typename Index, \
int LhsStorageOrder, bool ConjugateLhs, \ int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \ int RhsStorageOrder, bool ConjugateRhs> \
@ -179,27 +169,22 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateL
const EIGTYPE* _lhs, Index lhsStride, \ const EIGTYPE* _lhs, Index lhsStride, \
const EIGTYPE* _rhs, Index rhsStride, \ const EIGTYPE* _rhs, Index rhsStride, \
EIGTYPE* res, Index resStride, \ EIGTYPE* res, Index resStride, \
EIGTYPE alpha) \ EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
{ \ { \
char side='R', uplo='L'; \ char side='R', uplo='L'; \
MKL_INT m, n, lda, ldb, ldc; \ BlasIndex m, n, lda, ldb, ldc; \
const EIGTYPE *a, *b; \ const EIGTYPE *a, *b; \
MKLTYPE alpha_, beta_; \ EIGTYPE beta(1); \
MatrixX##EIGPREFIX b_tmp; \ MatrixX##EIGPREFIX b_tmp; \
EIGTYPE myone(1);\
\ \
/* Set m, n, k */ \ /* Set m, n, k */ \
m = (MKL_INT)rows; \ m = convert_index<BlasIndex>(rows); \
n = (MKL_INT)cols; \ n = convert_index<BlasIndex>(cols); \
\
/* Set alpha_ & beta_ */ \
assign_scalar_eig2mkl(alpha_, alpha); \
assign_scalar_eig2mkl(beta_, myone); \
\ \
/* Set lda, ldb, ldc */ \ /* Set lda, ldb, ldc */ \
lda = (MKL_INT)rhsStride; \ lda = convert_index<BlasIndex>(rhsStride); \
ldb = (MKL_INT)lhsStride; \ ldb = convert_index<BlasIndex>(lhsStride); \
ldc = (MKL_INT)resStride; \ ldc = convert_index<BlasIndex>(resStride); \
\ \
/* Set a, b, c */ \ /* Set a, b, c */ \
if (RhsStorageOrder==RowMajor) uplo='U'; \ if (RhsStorageOrder==RowMajor) uplo='U'; \
@ -209,16 +194,16 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateL
Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > lhs(_lhs,n,m,OuterStride<>(rhsStride)); \ Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > lhs(_lhs,n,m,OuterStride<>(rhsStride)); \
b_tmp = lhs.adjoint(); \ b_tmp = lhs.adjoint(); \
b = b_tmp.data(); \ b = b_tmp.data(); \
ldb = b_tmp.outerStride(); \ ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
} else b = _lhs; \ } else b = _lhs; \
\ \
MKLPREFIX##symm(&side, &uplo, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &beta_, (MKLTYPE*)res, &ldc); \ BLASPREFIX##symm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
\ \
} \ } \
}; };
#define EIGEN_MKL_HEMM_R(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \ #define EIGEN_BLAS_HEMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
template <typename Index, \ template <typename Index, \
int LhsStorageOrder, bool ConjugateLhs, \ int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \ int RhsStorageOrder, bool ConjugateRhs> \
@ -229,35 +214,30 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateL
const EIGTYPE* _lhs, Index lhsStride, \ const EIGTYPE* _lhs, Index lhsStride, \
const EIGTYPE* _rhs, Index rhsStride, \ const EIGTYPE* _rhs, Index rhsStride, \
EIGTYPE* res, Index resStride, \ EIGTYPE* res, Index resStride, \
EIGTYPE alpha) \ EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
{ \ { \
char side='R', uplo='L'; \ char side='R', uplo='L'; \
MKL_INT m, n, lda, ldb, ldc; \ BlasIndex m, n, lda, ldb, ldc; \
const EIGTYPE *a, *b; \ const EIGTYPE *a, *b; \
MKLTYPE alpha_, beta_; \ EIGTYPE beta(1); \
MatrixX##EIGPREFIX b_tmp; \ MatrixX##EIGPREFIX b_tmp; \
Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> a_tmp; \ Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> a_tmp; \
EIGTYPE myone(1); \
\ \
/* Set m, n, k */ \ /* Set m, n, k */ \
m = (MKL_INT)rows; \ m = convert_index<BlasIndex>(rows); \
n = (MKL_INT)cols; \ n = convert_index<BlasIndex>(cols); \
\
/* Set alpha_ & beta_ */ \
assign_scalar_eig2mkl(alpha_, alpha); \
assign_scalar_eig2mkl(beta_, myone); \
\ \
/* Set lda, ldb, ldc */ \ /* Set lda, ldb, ldc */ \
lda = (MKL_INT)rhsStride; \ lda = convert_index<BlasIndex>(rhsStride); \
ldb = (MKL_INT)lhsStride; \ ldb = convert_index<BlasIndex>(lhsStride); \
ldc = (MKL_INT)resStride; \ ldc = convert_index<BlasIndex>(resStride); \
\ \
/* Set a, b, c */ \ /* Set a, b, c */ \
if (((RhsStorageOrder==ColMajor) && ConjugateRhs) || ((RhsStorageOrder==RowMajor) && (!ConjugateRhs))) { \ if (((RhsStorageOrder==ColMajor) && ConjugateRhs) || ((RhsStorageOrder==RowMajor) && (!ConjugateRhs))) { \
Map<const Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder>, 0, OuterStride<> > rhs(_rhs,n,n,OuterStride<>(rhsStride)); \ Map<const Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder>, 0, OuterStride<> > rhs(_rhs,n,n,OuterStride<>(rhsStride)); \
a_tmp = rhs.conjugate(); \ a_tmp = rhs.conjugate(); \
a = a_tmp.data(); \ a = a_tmp.data(); \
lda = a_tmp.outerStride(); \ lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
} else a = _rhs; \ } else a = _rhs; \
if (RhsStorageOrder==RowMajor) uplo='U'; \ if (RhsStorageOrder==RowMajor) uplo='U'; \
\ \
@ -279,17 +259,17 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateL
ldb = b_tmp.outerStride(); \ ldb = b_tmp.outerStride(); \
} \ } \
\ \
MKLPREFIX##hemm(&side, &uplo, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &beta_, (MKLTYPE*)res, &ldc); \ BLASPREFIX##hemm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
} \ } \
}; };
EIGEN_MKL_SYMM_R(double, double, d, d) EIGEN_BLAS_SYMM_R(double, double, d, d)
EIGEN_MKL_SYMM_R(float, float, f, s) EIGEN_BLAS_SYMM_R(float, float, f, s)
EIGEN_MKL_HEMM_R(dcomplex, MKL_Complex16, cd, z) EIGEN_BLAS_HEMM_R(dcomplex, double, cd, z)
EIGEN_MKL_HEMM_R(scomplex, MKL_Complex8, cf, c) EIGEN_BLAS_HEMM_R(scomplex, float, cf, c)
} // end namespace internal } // end namespace internal
} // end namespace Eigen } // end namespace Eigen
#endif // EIGEN_SELFADJOINT_MATRIX_MATRIX_MKL_H #endif // EIGEN_SELFADJOINT_MATRIX_MATRIX_BLAS_H

View File

@ -25,13 +25,13 @@
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
******************************************************************************** ********************************************************************************
* Content : Eigen bindings to Intel(R) MKL * Content : Eigen bindings to BLAS F77
* Selfadjoint matrix-vector product functionality based on ?SYMV/HEMV. * Selfadjoint matrix-vector product functionality based on ?SYMV/HEMV.
******************************************************************************** ********************************************************************************
*/ */
#ifndef EIGEN_SELFADJOINT_MATRIX_VECTOR_MKL_H #ifndef EIGEN_SELFADJOINT_MATRIX_VECTOR_BLAS_H
#define EIGEN_SELFADJOINT_MATRIX_VECTOR_MKL_H #define EIGEN_SELFADJOINT_MATRIX_VECTOR_BLAS_H
namespace Eigen { namespace Eigen {
@ -47,7 +47,7 @@ template<typename Scalar, typename Index, int StorageOrder, int UpLo, bool Conju
struct selfadjoint_matrix_vector_product_symv : struct selfadjoint_matrix_vector_product_symv :
selfadjoint_matrix_vector_product<Scalar,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs,BuiltIn> {}; selfadjoint_matrix_vector_product<Scalar,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs,BuiltIn> {};
#define EIGEN_MKL_SYMV_SPECIALIZE(Scalar) \ #define EIGEN_BLAS_SYMV_SPECIALIZE(Scalar) \
template<typename Index, int StorageOrder, int UpLo, bool ConjugateLhs, bool ConjugateRhs> \ template<typename Index, int StorageOrder, int UpLo, bool ConjugateLhs, bool ConjugateRhs> \
struct selfadjoint_matrix_vector_product<Scalar,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs,Specialized> { \ struct selfadjoint_matrix_vector_product<Scalar,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs,Specialized> { \
static void run( \ static void run( \
@ -66,12 +66,12 @@ static void run( \
} \ } \
}; \ }; \
EIGEN_MKL_SYMV_SPECIALIZE(double) EIGEN_BLAS_SYMV_SPECIALIZE(double)
EIGEN_MKL_SYMV_SPECIALIZE(float) EIGEN_BLAS_SYMV_SPECIALIZE(float)
EIGEN_MKL_SYMV_SPECIALIZE(dcomplex) EIGEN_BLAS_SYMV_SPECIALIZE(dcomplex)
EIGEN_MKL_SYMV_SPECIALIZE(scomplex) EIGEN_BLAS_SYMV_SPECIALIZE(scomplex)
#define EIGEN_MKL_SYMV_SPECIALIZATION(EIGTYPE,MKLTYPE,MKLFUNC) \ #define EIGEN_BLAS_SYMV_SPECIALIZATION(EIGTYPE,BLASTYPE,BLASFUNC) \
template<typename Index, int StorageOrder, int UpLo, bool ConjugateLhs, bool ConjugateRhs> \ template<typename Index, int StorageOrder, int UpLo, bool ConjugateLhs, bool ConjugateRhs> \
struct selfadjoint_matrix_vector_product_symv<EIGTYPE,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs> \ struct selfadjoint_matrix_vector_product_symv<EIGTYPE,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs> \
{ \ { \
@ -85,29 +85,27 @@ const EIGTYPE* _rhs, EIGTYPE* res, EIGTYPE alpha) \
IsRowMajor = StorageOrder==RowMajor ? 1 : 0, \ IsRowMajor = StorageOrder==RowMajor ? 1 : 0, \
IsLower = UpLo == Lower ? 1 : 0 \ IsLower = UpLo == Lower ? 1 : 0 \
}; \ }; \
MKL_INT n=size, lda=lhsStride, incx=1, incy=1; \ BlasIndex n=convert_index<BlasIndex>(size), lda=convert_index<BlasIndex>(lhsStride), incx=1, incy=1; \
MKLTYPE alpha_, beta_; \ EIGTYPE beta(1); \
const EIGTYPE *x_ptr, myone(1); \ const EIGTYPE *x_ptr; \
char uplo=(IsRowMajor) ? (IsLower ? 'U' : 'L') : (IsLower ? 'L' : 'U'); \ char uplo=(IsRowMajor) ? (IsLower ? 'U' : 'L') : (IsLower ? 'L' : 'U'); \
assign_scalar_eig2mkl(alpha_, alpha); \
assign_scalar_eig2mkl(beta_, myone); \
SYMVVector x_tmp; \ SYMVVector x_tmp; \
if (ConjugateRhs) { \ if (ConjugateRhs) { \
Map<const SYMVVector, 0 > map_x(_rhs,size,1); \ Map<const SYMVVector, 0 > map_x(_rhs,size,1); \
x_tmp=map_x.conjugate(); \ x_tmp=map_x.conjugate(); \
x_ptr=x_tmp.data(); \ x_ptr=x_tmp.data(); \
} else x_ptr=_rhs; \ } else x_ptr=_rhs; \
MKLFUNC(&uplo, &n, &alpha_, (const MKLTYPE*)lhs, &lda, (const MKLTYPE*)x_ptr, &incx, &beta_, (MKLTYPE*)res, &incy); \ BLASFUNC(&uplo, &n, &numext::real_ref(alpha), (const BLASTYPE*)lhs, &lda, (const BLASTYPE*)x_ptr, &incx, &numext::real_ref(beta), (BLASTYPE*)res, &incy); \
}\ }\
}; };
EIGEN_MKL_SYMV_SPECIALIZATION(double, double, dsymv) EIGEN_BLAS_SYMV_SPECIALIZATION(double, double, dsymv_)
EIGEN_MKL_SYMV_SPECIALIZATION(float, float, ssymv) EIGEN_BLAS_SYMV_SPECIALIZATION(float, float, ssymv_)
EIGEN_MKL_SYMV_SPECIALIZATION(dcomplex, MKL_Complex16, zhemv) EIGEN_BLAS_SYMV_SPECIALIZATION(dcomplex, double, zhemv_)
EIGEN_MKL_SYMV_SPECIALIZATION(scomplex, MKL_Complex8, chemv) EIGEN_BLAS_SYMV_SPECIALIZATION(scomplex, float, chemv_)
} // end namespace internal } // end namespace internal
} // end namespace Eigen } // end namespace Eigen
#endif // EIGEN_SELFADJOINT_MATRIX_VECTOR_MKL_H #endif // EIGEN_SELFADJOINT_MATRIX_VECTOR_BLAS_H

View File

@ -25,13 +25,13 @@
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
******************************************************************************** ********************************************************************************
* Content : Eigen bindings to Intel(R) MKL * Content : Eigen bindings to BLAS F77
* Triangular matrix * matrix product functionality based on ?TRMM. * Triangular matrix * matrix product functionality based on ?TRMM.
******************************************************************************** ********************************************************************************
*/ */
#ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H #ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
#define EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H #define EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
namespace Eigen { namespace Eigen {
@ -50,7 +50,7 @@ struct product_triangular_matrix_matrix_trmm :
// try to go to BLAS specialization // try to go to BLAS specialization
#define EIGEN_MKL_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) \ #define EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) \
template <typename Index, int Mode, \ template <typename Index, int Mode, \
int LhsStorageOrder, bool ConjugateLhs, \ int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \ int RhsStorageOrder, bool ConjugateRhs> \
@ -65,17 +65,17 @@ struct product_triangular_matrix_matrix<Scalar,Index, Mode, LhsIsTriangular, \
} \ } \
}; };
EIGEN_MKL_TRMM_SPECIALIZE(double, true) EIGEN_BLAS_TRMM_SPECIALIZE(double, true)
EIGEN_MKL_TRMM_SPECIALIZE(double, false) EIGEN_BLAS_TRMM_SPECIALIZE(double, false)
EIGEN_MKL_TRMM_SPECIALIZE(dcomplex, true) EIGEN_BLAS_TRMM_SPECIALIZE(dcomplex, true)
EIGEN_MKL_TRMM_SPECIALIZE(dcomplex, false) EIGEN_BLAS_TRMM_SPECIALIZE(dcomplex, false)
EIGEN_MKL_TRMM_SPECIALIZE(float, true) EIGEN_BLAS_TRMM_SPECIALIZE(float, true)
EIGEN_MKL_TRMM_SPECIALIZE(float, false) EIGEN_BLAS_TRMM_SPECIALIZE(float, false)
EIGEN_MKL_TRMM_SPECIALIZE(scomplex, true) EIGEN_BLAS_TRMM_SPECIALIZE(scomplex, true)
EIGEN_MKL_TRMM_SPECIALIZE(scomplex, false) EIGEN_BLAS_TRMM_SPECIALIZE(scomplex, false)
// implements col-major += alpha * op(triangular) * op(general) // implements col-major += alpha * op(triangular) * op(general)
#define EIGEN_MKL_TRMM_L(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \ #define EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
template <typename Index, int Mode, \ template <typename Index, int Mode, \
int LhsStorageOrder, bool ConjugateLhs, \ int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \ int RhsStorageOrder, bool ConjugateRhs> \
@ -106,13 +106,14 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \ typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \ typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
\ \
/* Non-square case - doesn't fit to MKL ?TRMM. Fall to default triangular product or call MKL ?GEMM*/ \ /* Non-square case - doesn't fit to BLAS ?TRMM. Fall to default triangular product or call BLAS ?GEMM*/ \
if (rows != depth) { \ if (rows != depth) { \
\ \
int nthr = mkl_domain_get_max_threads(EIGEN_MKL_DOMAIN_BLAS); \ /* FIXME handle mkl_domain_get_max_threads */ \
/*int nthr = mkl_domain_get_max_threads(EIGEN_BLAS_DOMAIN_BLAS);*/ int nthr = 1;\
\ \
if (((nthr==1) && (((std::max)(rows,depth)-diagSize)/(double)diagSize < 0.5))) { \ if (((nthr==1) && (((std::max)(rows,depth)-diagSize)/(double)diagSize < 0.5))) { \
/* Most likely no benefit to call TRMM or GEMM from MKL*/ \ /* Most likely no benefit to call TRMM or GEMM from BLAS */ \
product_triangular_matrix_matrix<EIGTYPE,Index,Mode,true, \ product_triangular_matrix_matrix<EIGTYPE,Index,Mode,true, \
LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \ LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \
_rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \ _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
@ -121,27 +122,23 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
/* Make sense to call GEMM */ \ /* Make sense to call GEMM */ \
Map<const MatrixLhs, 0, OuterStride<> > lhsMap(_lhs,rows,depth,OuterStride<>(lhsStride)); \ Map<const MatrixLhs, 0, OuterStride<> > lhsMap(_lhs,rows,depth,OuterStride<>(lhsStride)); \
MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); \ MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); \
MKL_INT aStride = aa_tmp.outerStride(); \ BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \ gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \ general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, resStride, alpha, gemm_blocking, 0); \ rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, resStride, alpha, gemm_blocking, 0); \
\ \
/*std::cout << "TRMM_L: A is not square! Go to MKL GEMM implementation! " << nthr<<" \n";*/ \ /*std::cout << "TRMM_L: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \
} \ } \
return; \ return; \
} \ } \
char side = 'L', transa, uplo, diag = 'N'; \ char side = 'L', transa, uplo, diag = 'N'; \
EIGTYPE *b; \ EIGTYPE *b; \
const EIGTYPE *a; \ const EIGTYPE *a; \
MKL_INT m, n, lda, ldb; \ BlasIndex m, n, lda, ldb; \
MKLTYPE alpha_; \
\
/* Set alpha_*/ \
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
\ \
/* Set m, n */ \ /* Set m, n */ \
m = (MKL_INT)diagSize; \ m = convert_index<BlasIndex>(diagSize); \
n = (MKL_INT)cols; \ n = convert_index<BlasIndex>(cols); \
\ \
/* Set trans */ \ /* Set trans */ \
transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \ transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \
@ -152,7 +149,7 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
\ \
if (ConjugateRhs) b_tmp = rhs.conjugate(); else b_tmp = rhs; \ if (ConjugateRhs) b_tmp = rhs.conjugate(); else b_tmp = rhs; \
b = b_tmp.data(); \ b = b_tmp.data(); \
ldb = b_tmp.outerStride(); \ ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
\ \
/* Set uplo */ \ /* Set uplo */ \
uplo = IsLower ? 'L' : 'U'; \ uplo = IsLower ? 'L' : 'U'; \
@ -168,14 +165,14 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
else if (IsUnitDiag) \ else if (IsUnitDiag) \
a_tmp.diagonal().setOnes();\ a_tmp.diagonal().setOnes();\
a = a_tmp.data(); \ a = a_tmp.data(); \
lda = a_tmp.outerStride(); \ lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
} else { \ } else { \
a = _lhs; \ a = _lhs; \
lda = lhsStride; \ lda = convert_index<BlasIndex>(lhsStride); \
} \ } \
/*std::cout << "TRMM_L: A is square! Go to MKL TRMM implementation! \n";*/ \ /*std::cout << "TRMM_L: A is square! Go to BLAS TRMM implementation! \n";*/ \
/* call ?trmm*/ \ /* call ?trmm*/ \
MKLPREFIX##trmm(&side, &uplo, &transa, &diag, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (MKLTYPE*)b, &ldb); \ BLASPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \
\ \
/* Add op(a_triangular)*b into res*/ \ /* Add op(a_triangular)*b into res*/ \
Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \ Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
@ -183,13 +180,13 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
} \ } \
}; };
EIGEN_MKL_TRMM_L(double, double, d, d) EIGEN_BLAS_TRMM_L(double, double, d, d)
EIGEN_MKL_TRMM_L(dcomplex, MKL_Complex16, cd, z) EIGEN_BLAS_TRMM_L(dcomplex, double, cd, z)
EIGEN_MKL_TRMM_L(float, float, f, s) EIGEN_BLAS_TRMM_L(float, float, f, s)
EIGEN_MKL_TRMM_L(scomplex, MKL_Complex8, cf, c) EIGEN_BLAS_TRMM_L(scomplex, float, cf, c)
// implements col-major += alpha * op(general) * op(triangular) // implements col-major += alpha * op(general) * op(triangular)
#define EIGEN_MKL_TRMM_R(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \ #define EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
template <typename Index, int Mode, \ template <typename Index, int Mode, \
int LhsStorageOrder, bool ConjugateLhs, \ int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \ int RhsStorageOrder, bool ConjugateRhs> \
@ -220,13 +217,13 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \ typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \ typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
\ \
/* Non-square case - doesn't fit to MKL ?TRMM. Fall to default triangular product or call MKL ?GEMM*/ \ /* Non-square case - doesn't fit to BLAS ?TRMM. Fall to default triangular product or call BLAS ?GEMM*/ \
if (cols != depth) { \ if (cols != depth) { \
\ \
int nthr = mkl_domain_get_max_threads(EIGEN_MKL_DOMAIN_BLAS); \ int nthr = 1 /*mkl_domain_get_max_threads(EIGEN_BLAS_DOMAIN_BLAS)*/; \
\ \
if ((nthr==1) && (((std::max)(cols,depth)-diagSize)/(double)diagSize < 0.5)) { \ if ((nthr==1) && (((std::max)(cols,depth)-diagSize)/(double)diagSize < 0.5)) { \
/* Most likely no benefit to call TRMM or GEMM from MKL*/ \ /* Most likely no benefit to call TRMM or GEMM from BLAS*/ \
product_triangular_matrix_matrix<EIGTYPE,Index,Mode,false, \ product_triangular_matrix_matrix<EIGTYPE,Index,Mode,false, \
LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \ LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \
_rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \ _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
@ -235,27 +232,23 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
/* Make sense to call GEMM */ \ /* Make sense to call GEMM */ \
Map<const MatrixRhs, 0, OuterStride<> > rhsMap(_rhs,depth,cols, OuterStride<>(rhsStride)); \ Map<const MatrixRhs, 0, OuterStride<> > rhsMap(_rhs,depth,cols, OuterStride<>(rhsStride)); \
MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); \ MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); \
MKL_INT aStride = aa_tmp.outerStride(); \ BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \ gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \ general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, resStride, alpha, gemm_blocking, 0); \ rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, resStride, alpha, gemm_blocking, 0); \
\ \
/*std::cout << "TRMM_R: A is not square! Go to MKL GEMM implementation! " << nthr<<" \n";*/ \ /*std::cout << "TRMM_R: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \
} \ } \
return; \ return; \
} \ } \
char side = 'R', transa, uplo, diag = 'N'; \ char side = 'R', transa, uplo, diag = 'N'; \
EIGTYPE *b; \ EIGTYPE *b; \
const EIGTYPE *a; \ const EIGTYPE *a; \
MKL_INT m, n, lda, ldb; \ BlasIndex m, n, lda, ldb; \
MKLTYPE alpha_; \
\
/* Set alpha_*/ \
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
\ \
/* Set m, n */ \ /* Set m, n */ \
m = (MKL_INT)rows; \ m = convert_index<BlasIndex>(rows); \
n = (MKL_INT)diagSize; \ n = convert_index<BlasIndex>(diagSize); \
\ \
/* Set trans */ \ /* Set trans */ \
transa = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \ transa = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \
@ -266,7 +259,7 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
\ \
if (ConjugateLhs) b_tmp = lhs.conjugate(); else b_tmp = lhs; \ if (ConjugateLhs) b_tmp = lhs.conjugate(); else b_tmp = lhs; \
b = b_tmp.data(); \ b = b_tmp.data(); \
ldb = b_tmp.outerStride(); \ ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
\ \
/* Set uplo */ \ /* Set uplo */ \
uplo = IsLower ? 'L' : 'U'; \ uplo = IsLower ? 'L' : 'U'; \
@ -282,14 +275,14 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
else if (IsUnitDiag) \ else if (IsUnitDiag) \
a_tmp.diagonal().setOnes();\ a_tmp.diagonal().setOnes();\
a = a_tmp.data(); \ a = a_tmp.data(); \
lda = a_tmp.outerStride(); \ lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
} else { \ } else { \
a = _rhs; \ a = _rhs; \
lda = rhsStride; \ lda = convert_index<BlasIndex>(rhsStride); \
} \ } \
/*std::cout << "TRMM_R: A is square! Go to MKL TRMM implementation! \n";*/ \ /*std::cout << "TRMM_R: A is square! Go to BLAS TRMM implementation! \n";*/ \
/* call ?trmm*/ \ /* call ?trmm*/ \
MKLPREFIX##trmm(&side, &uplo, &transa, &diag, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (MKLTYPE*)b, &ldb); \ BLASPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \
\ \
/* Add op(a_triangular)*b into res*/ \ /* Add op(a_triangular)*b into res*/ \
Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \ Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
@ -297,13 +290,13 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
} \ } \
}; };
EIGEN_MKL_TRMM_R(double, double, d, d) EIGEN_BLAS_TRMM_R(double, double, d, d)
EIGEN_MKL_TRMM_R(dcomplex, MKL_Complex16, cd, z) EIGEN_BLAS_TRMM_R(dcomplex, double, cd, z)
EIGEN_MKL_TRMM_R(float, float, f, s) EIGEN_BLAS_TRMM_R(float, float, f, s)
EIGEN_MKL_TRMM_R(scomplex, MKL_Complex8, cf, c) EIGEN_BLAS_TRMM_R(scomplex, float, cf, c)
} // end namespace internal } // end namespace internal
} // end namespace Eigen } // end namespace Eigen
#endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H #endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H

View File

@ -25,13 +25,13 @@
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
******************************************************************************** ********************************************************************************
* Content : Eigen bindings to Intel(R) MKL * Content : Eigen bindings to BLAS F77
* Triangular matrix-vector product functionality based on ?TRMV. * Triangular matrix-vector product functionality based on ?TRMV.
******************************************************************************** ********************************************************************************
*/ */
#ifndef EIGEN_TRIANGULAR_MATRIX_VECTOR_MKL_H #ifndef EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H
#define EIGEN_TRIANGULAR_MATRIX_VECTOR_MKL_H #define EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H
namespace Eigen { namespace Eigen {
@ -47,7 +47,7 @@ template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename Rh
struct triangular_matrix_vector_product_trmv : struct triangular_matrix_vector_product_trmv :
triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,StorageOrder,BuiltIn> {}; triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,StorageOrder,BuiltIn> {};
#define EIGEN_MKL_TRMV_SPECIALIZE(Scalar) \ #define EIGEN_BLAS_TRMV_SPECIALIZE(Scalar) \
template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \ template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
struct triangular_matrix_vector_product<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,ColMajor,Specialized> { \ struct triangular_matrix_vector_product<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,ColMajor,Specialized> { \
static void run(Index _rows, Index _cols, const Scalar* _lhs, Index lhsStride, \ static void run(Index _rows, Index _cols, const Scalar* _lhs, Index lhsStride, \
@ -65,13 +65,13 @@ struct triangular_matrix_vector_product<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs
} \ } \
}; };
EIGEN_MKL_TRMV_SPECIALIZE(double) EIGEN_BLAS_TRMV_SPECIALIZE(double)
EIGEN_MKL_TRMV_SPECIALIZE(float) EIGEN_BLAS_TRMV_SPECIALIZE(float)
EIGEN_MKL_TRMV_SPECIALIZE(dcomplex) EIGEN_BLAS_TRMV_SPECIALIZE(dcomplex)
EIGEN_MKL_TRMV_SPECIALIZE(scomplex) EIGEN_BLAS_TRMV_SPECIALIZE(scomplex)
// implements col-major: res += alpha * op(triangular) * vector // implements col-major: res += alpha * op(triangular) * vector
#define EIGEN_MKL_TRMV_CM(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \ #define EIGEN_BLAS_TRMV_CM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \ template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,ColMajor> { \ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,ColMajor> { \
enum { \ enum { \
@ -105,17 +105,15 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
/* Square part handling */\ /* Square part handling */\
\ \
char trans, uplo, diag; \ char trans, uplo, diag; \
MKL_INT m, n, lda, incx, incy; \ BlasIndex m, n, lda, incx, incy; \
EIGTYPE const *a; \ EIGTYPE const *a; \
MKLTYPE alpha_, beta_; \ EIGTYPE beta(1); \
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(beta_, EIGTYPE(1)); \
\ \
/* Set m, n */ \ /* Set m, n */ \
n = (MKL_INT)size; \ n = convert_index<BlasIndex>(size); \
lda = lhsStride; \ lda = convert_index<BlasIndex>(lhsStride); \
incx = 1; \ incx = 1; \
incy = resIncr; \ incy = convert_index<BlasIndex>(resIncr); \
\ \
/* Set uplo, trans and diag*/ \ /* Set uplo, trans and diag*/ \
trans = 'N'; \ trans = 'N'; \
@ -123,39 +121,39 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
diag = IsUnitDiag ? 'U' : 'N'; \ diag = IsUnitDiag ? 'U' : 'N'; \
\ \
/* call ?TRMV*/ \ /* call ?TRMV*/ \
MKLPREFIX##trmv(&uplo, &trans, &diag, &n, (const MKLTYPE*)_lhs, &lda, (MKLTYPE*)x, &incx); \ BLASPREFIX##trmv_(&uplo, &trans, &diag, &n, (const BLASTYPE*)_lhs, &lda, (BLASTYPE*)x, &incx); \
\ \
/* Add op(a_tr)rhs into res*/ \ /* Add op(a_tr)rhs into res*/ \
MKLPREFIX##axpy(&n, &alpha_,(const MKLTYPE*)x, &incx, (MKLTYPE*)_res, &incy); \ BLASPREFIX##axpy_(&n, &numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)_res, &incy); \
/* Non-square case - doesn't fit to MKL ?TRMV. Fall to default triangular product*/ \ /* Non-square case - doesn't fit to BLAS ?TRMV. Fall to default triangular product*/ \
if (size<(std::max)(rows,cols)) { \ if (size<(std::max)(rows,cols)) { \
if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \ if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
x = x_tmp.data(); \ x = x_tmp.data(); \
if (size<rows) { \ if (size<rows) { \
y = _res + size*resIncr; \ y = _res + size*resIncr; \
a = _lhs + size; \ a = _lhs + size; \
m = rows-size; \ m = convert_index<BlasIndex>(rows-size); \
n = size; \ n = convert_index<BlasIndex>(size); \
} \ } \
else { \ else { \
x += size; \ x += size; \
y = _res; \ y = _res; \
a = _lhs + size*lda; \ a = _lhs + size*lda; \
m = size; \ m = convert_index<BlasIndex>(size); \
n = cols-size; \ n = convert_index<BlasIndex>(cols-size); \
} \ } \
MKLPREFIX##gemv(&trans, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)x, &incx, &beta_, (MKLTYPE*)y, &incy); \ BLASPREFIX##gemv_(&trans, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, &numext::real_ref(beta), (BLASTYPE*)y, &incy); \
} \ } \
} \ } \
}; };
EIGEN_MKL_TRMV_CM(double, double, d, d) EIGEN_BLAS_TRMV_CM(double, double, d, d)
EIGEN_MKL_TRMV_CM(dcomplex, MKL_Complex16, cd, z) EIGEN_BLAS_TRMV_CM(dcomplex, double, cd, z)
EIGEN_MKL_TRMV_CM(float, float, f, s) EIGEN_BLAS_TRMV_CM(float, float, f, s)
EIGEN_MKL_TRMV_CM(scomplex, MKL_Complex8, cf, c) EIGEN_BLAS_TRMV_CM(scomplex, float, cf, c)
// implements row-major: res += alpha * op(triangular) * vector // implements row-major: res += alpha * op(triangular) * vector
#define EIGEN_MKL_TRMV_RM(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \ #define EIGEN_BLAS_TRMV_RM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \ template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,RowMajor> { \ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,RowMajor> { \
enum { \ enum { \
@ -189,17 +187,15 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
/* Square part handling */\ /* Square part handling */\
\ \
char trans, uplo, diag; \ char trans, uplo, diag; \
MKL_INT m, n, lda, incx, incy; \ BlasIndex m, n, lda, incx, incy; \
EIGTYPE const *a; \ EIGTYPE const *a; \
MKLTYPE alpha_, beta_; \ EIGTYPE beta(1); \
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(beta_, EIGTYPE(1)); \
\ \
/* Set m, n */ \ /* Set m, n */ \
n = (MKL_INT)size; \ n = convert_index<BlasIndex>(size); \
lda = lhsStride; \ lda = convert_index<BlasIndex>(lhsStride); \
incx = 1; \ incx = 1; \
incy = resIncr; \ incy = convert_index<BlasIndex>(resIncr); \
\ \
/* Set uplo, trans and diag*/ \ /* Set uplo, trans and diag*/ \
trans = ConjLhs ? 'C' : 'T'; \ trans = ConjLhs ? 'C' : 'T'; \
@ -207,39 +203,39 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
diag = IsUnitDiag ? 'U' : 'N'; \ diag = IsUnitDiag ? 'U' : 'N'; \
\ \
/* call ?TRMV*/ \ /* call ?TRMV*/ \
MKLPREFIX##trmv(&uplo, &trans, &diag, &n, (const MKLTYPE*)_lhs, &lda, (MKLTYPE*)x, &incx); \ BLASPREFIX##trmv_(&uplo, &trans, &diag, &n, (const BLASTYPE*)_lhs, &lda, (BLASTYPE*)x, &incx); \
\ \
/* Add op(a_tr)rhs into res*/ \ /* Add op(a_tr)rhs into res*/ \
MKLPREFIX##axpy(&n, &alpha_,(const MKLTYPE*)x, &incx, (MKLTYPE*)_res, &incy); \ BLASPREFIX##axpy_(&n, &numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)_res, &incy); \
/* Non-square case - doesn't fit to MKL ?TRMV. Fall to default triangular product*/ \ /* Non-square case - doesn't fit to BLAS ?TRMV. Fall to default triangular product*/ \
if (size<(std::max)(rows,cols)) { \ if (size<(std::max)(rows,cols)) { \
if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \ if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
x = x_tmp.data(); \ x = x_tmp.data(); \
if (size<rows) { \ if (size<rows) { \
y = _res + size*resIncr; \ y = _res + size*resIncr; \
a = _lhs + size*lda; \ a = _lhs + size*lda; \
m = rows-size; \ m = convert_index<BlasIndex>(rows-size); \
n = size; \ n = convert_index<BlasIndex>(size); \
} \ } \
else { \ else { \
x += size; \ x += size; \
y = _res; \ y = _res; \
a = _lhs + size; \ a = _lhs + size; \
m = size; \ m = convert_index<BlasIndex>(size); \
n = cols-size; \ n = convert_index<BlasIndex>(cols-size); \
} \ } \
MKLPREFIX##gemv(&trans, &n, &m, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)x, &incx, &beta_, (MKLTYPE*)y, &incy); \ BLASPREFIX##gemv_(&trans, &n, &m, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, &numext::real_ref(beta), (BLASTYPE*)y, &incy); \
} \ } \
} \ } \
}; };
EIGEN_MKL_TRMV_RM(double, double, d, d) EIGEN_BLAS_TRMV_RM(double, double, d, d)
EIGEN_MKL_TRMV_RM(dcomplex, MKL_Complex16, cd, z) EIGEN_BLAS_TRMV_RM(dcomplex, double, cd, z)
EIGEN_MKL_TRMV_RM(float, float, f, s) EIGEN_BLAS_TRMV_RM(float, float, f, s)
EIGEN_MKL_TRMV_RM(scomplex, MKL_Complex8, cf, c) EIGEN_BLAS_TRMV_RM(scomplex, float, cf, c)
} // end namespase internal } // end namespase internal
} // end namespace Eigen } // end namespace Eigen
#endif // EIGEN_TRIANGULAR_MATRIX_VECTOR_MKL_H #endif // EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H

View File

@ -25,20 +25,20 @@
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
******************************************************************************** ********************************************************************************
* Content : Eigen bindings to Intel(R) MKL * Content : Eigen bindings to BLAS F77
* Triangular matrix * matrix product functionality based on ?TRMM. * Triangular matrix * matrix product functionality based on ?TRMM.
******************************************************************************** ********************************************************************************
*/ */
#ifndef EIGEN_TRIANGULAR_SOLVER_MATRIX_MKL_H #ifndef EIGEN_TRIANGULAR_SOLVER_MATRIX_BLAS_H
#define EIGEN_TRIANGULAR_SOLVER_MATRIX_MKL_H #define EIGEN_TRIANGULAR_SOLVER_MATRIX_BLAS_H
namespace Eigen { namespace Eigen {
namespace internal { namespace internal {
// implements LeftSide op(triangular)^-1 * general // implements LeftSide op(triangular)^-1 * general
#define EIGEN_MKL_TRSM_L(EIGTYPE, MKLTYPE, MKLPREFIX) \ #define EIGEN_BLAS_TRSM_L(EIGTYPE, BLASTYPE, BLASPREFIX) \
template <typename Index, int Mode, bool Conjugate, int TriStorageOrder> \ template <typename Index, int Mode, bool Conjugate, int TriStorageOrder> \
struct triangular_solve_matrix<EIGTYPE,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor> \ struct triangular_solve_matrix<EIGTYPE,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor> \
{ \ { \
@ -53,13 +53,11 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheLeft,Mode,Conjugate,TriStorage
const EIGTYPE* _tri, Index triStride, \ const EIGTYPE* _tri, Index triStride, \
EIGTYPE* _other, Index otherStride, level3_blocking<EIGTYPE,EIGTYPE>& /*blocking*/) \ EIGTYPE* _other, Index otherStride, level3_blocking<EIGTYPE,EIGTYPE>& /*blocking*/) \
{ \ { \
MKL_INT m = size, n = otherSize, lda, ldb; \ BlasIndex m = convert_index<BlasIndex>(size), n = convert_index<BlasIndex>(otherSize), lda, ldb; \
char side = 'L', uplo, diag='N', transa; \ char side = 'L', uplo, diag='N', transa; \
/* Set alpha_ */ \ /* Set alpha_ */ \
MKLTYPE alpha; \ EIGTYPE alpha(1); \
EIGTYPE myone(1); \ ldb = convert_index<BlasIndex>(otherStride);\
assign_scalar_eig2mkl(alpha, myone); \
ldb = otherStride;\
\ \
const EIGTYPE *a; \ const EIGTYPE *a; \
/* Set trans */ \ /* Set trans */ \
@ -75,25 +73,25 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheLeft,Mode,Conjugate,TriStorage
if (conjA) { \ if (conjA) { \
a_tmp = tri.conjugate(); \ a_tmp = tri.conjugate(); \
a = a_tmp.data(); \ a = a_tmp.data(); \
lda = a_tmp.outerStride(); \ lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
} else { \ } else { \
a = _tri; \ a = _tri; \
lda = triStride; \ lda = convert_index<BlasIndex>(triStride); \
} \ } \
if (IsUnitDiag) diag='U'; \ if (IsUnitDiag) diag='U'; \
/* call ?trsm*/ \ /* call ?trsm*/ \
MKLPREFIX##trsm(&side, &uplo, &transa, &diag, &m, &n, &alpha, (const MKLTYPE*)a, &lda, (MKLTYPE*)_other, &ldb); \ BLASPREFIX##trsm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)_other, &ldb); \
} \ } \
}; };
EIGEN_MKL_TRSM_L(double, double, d) EIGEN_BLAS_TRSM_L(double, double, d)
EIGEN_MKL_TRSM_L(dcomplex, MKL_Complex16, z) EIGEN_BLAS_TRSM_L(dcomplex, double, z)
EIGEN_MKL_TRSM_L(float, float, s) EIGEN_BLAS_TRSM_L(float, float, s)
EIGEN_MKL_TRSM_L(scomplex, MKL_Complex8, c) EIGEN_BLAS_TRSM_L(scomplex, float, c)
// implements RightSide general * op(triangular)^-1 // implements RightSide general * op(triangular)^-1
#define EIGEN_MKL_TRSM_R(EIGTYPE, MKLTYPE, MKLPREFIX) \ #define EIGEN_BLAS_TRSM_R(EIGTYPE, BLASTYPE, BLASPREFIX) \
template <typename Index, int Mode, bool Conjugate, int TriStorageOrder> \ template <typename Index, int Mode, bool Conjugate, int TriStorageOrder> \
struct triangular_solve_matrix<EIGTYPE,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor> \ struct triangular_solve_matrix<EIGTYPE,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor> \
{ \ { \
@ -108,13 +106,11 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheRight,Mode,Conjugate,TriStorag
const EIGTYPE* _tri, Index triStride, \ const EIGTYPE* _tri, Index triStride, \
EIGTYPE* _other, Index otherStride, level3_blocking<EIGTYPE,EIGTYPE>& /*blocking*/) \ EIGTYPE* _other, Index otherStride, level3_blocking<EIGTYPE,EIGTYPE>& /*blocking*/) \
{ \ { \
MKL_INT m = otherSize, n = size, lda, ldb; \ BlasIndex m = convert_index<BlasIndex>(otherSize), n = convert_index<BlasIndex>(size), lda, ldb; \
char side = 'R', uplo, diag='N', transa; \ char side = 'R', uplo, diag='N', transa; \
/* Set alpha_ */ \ /* Set alpha_ */ \
MKLTYPE alpha; \ EIGTYPE alpha(1); \
EIGTYPE myone(1); \ ldb = convert_index<BlasIndex>(otherStride);\
assign_scalar_eig2mkl(alpha, myone); \
ldb = otherStride;\
\ \
const EIGTYPE *a; \ const EIGTYPE *a; \
/* Set trans */ \ /* Set trans */ \
@ -130,26 +126,26 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheRight,Mode,Conjugate,TriStorag
if (conjA) { \ if (conjA) { \
a_tmp = tri.conjugate(); \ a_tmp = tri.conjugate(); \
a = a_tmp.data(); \ a = a_tmp.data(); \
lda = a_tmp.outerStride(); \ lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
} else { \ } else { \
a = _tri; \ a = _tri; \
lda = triStride; \ lda = convert_index<BlasIndex>(triStride); \
} \ } \
if (IsUnitDiag) diag='U'; \ if (IsUnitDiag) diag='U'; \
/* call ?trsm*/ \ /* call ?trsm*/ \
MKLPREFIX##trsm(&side, &uplo, &transa, &diag, &m, &n, &alpha, (const MKLTYPE*)a, &lda, (MKLTYPE*)_other, &ldb); \ BLASPREFIX##trsm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)_other, &ldb); \
/*std::cout << "TRMS_L specialization!\n";*/ \ /*std::cout << "TRMS_L specialization!\n";*/ \
} \ } \
}; };
EIGEN_MKL_TRSM_R(double, double, d) EIGEN_BLAS_TRSM_R(double, double, d)
EIGEN_MKL_TRSM_R(dcomplex, MKL_Complex16, z) EIGEN_BLAS_TRSM_R(dcomplex, double, z)
EIGEN_MKL_TRSM_R(float, float, s) EIGEN_BLAS_TRSM_R(float, float, s)
EIGEN_MKL_TRSM_R(scomplex, MKL_Complex8, c) EIGEN_BLAS_TRSM_R(scomplex, float, c)
} // end namespace internal } // end namespace internal
} // end namespace Eigen } // end namespace Eigen
#endif // EIGEN_TRIANGULAR_SOLVER_MATRIX_MKL_H #endif // EIGEN_TRIANGULAR_SOLVER_MATRIX_BLAS_H

View File

@ -49,7 +49,7 @@
#define EIGEN_USE_LAPACKE #define EIGEN_USE_LAPACKE
#endif #endif
#if defined(EIGEN_USE_BLAS) || defined(EIGEN_USE_LAPACKE) || defined(EIGEN_USE_MKL_VML) #if defined(EIGEN_USE_LAPACKE) || defined(EIGEN_USE_MKL_VML)
#define EIGEN_USE_MKL #define EIGEN_USE_MKL
#endif #endif
@ -64,7 +64,6 @@
# ifndef EIGEN_USE_MKL # ifndef EIGEN_USE_MKL
/*If the MKL version is too old, undef everything*/ /*If the MKL version is too old, undef everything*/
# undef EIGEN_USE_MKL_ALL # undef EIGEN_USE_MKL_ALL
# undef EIGEN_USE_BLAS
# undef EIGEN_USE_LAPACKE # undef EIGEN_USE_LAPACKE
# undef EIGEN_USE_MKL_VML # undef EIGEN_USE_MKL_VML
# undef EIGEN_USE_LAPACKE_STRICT # undef EIGEN_USE_LAPACKE_STRICT
@ -107,52 +106,23 @@
#else #else
#define EIGEN_MKL_DOMAIN_PARDISO MKL_PARDISO #define EIGEN_MKL_DOMAIN_PARDISO MKL_PARDISO
#endif #endif
#endif
namespace Eigen { namespace Eigen {
typedef std::complex<double> dcomplex; typedef std::complex<double> dcomplex;
typedef std::complex<float> scomplex; typedef std::complex<float> scomplex;
namespace internal { #if defined(EIGEN_USE_MKL)
typedef MKL_INT BlasIndex;
template<typename MKLType, typename EigenType> #else
static inline void assign_scalar_eig2mkl(MKLType& mklScalar, const EigenType& eigenScalar) { typedef int BlasIndex;
mklScalar=eigenScalar; #endif
}
template<typename MKLType, typename EigenType>
static inline void assign_conj_scalar_eig2mkl(MKLType& mklScalar, const EigenType& eigenScalar) {
mklScalar=eigenScalar;
}
template <>
inline void assign_scalar_eig2mkl<MKL_Complex16,dcomplex>(MKL_Complex16& mklScalar, const dcomplex& eigenScalar) {
mklScalar.real=eigenScalar.real();
mklScalar.imag=eigenScalar.imag();
}
template <>
inline void assign_scalar_eig2mkl<MKL_Complex8,scomplex>(MKL_Complex8& mklScalar, const scomplex& eigenScalar) {
mklScalar.real=eigenScalar.real();
mklScalar.imag=eigenScalar.imag();
}
template <>
inline void assign_conj_scalar_eig2mkl<MKL_Complex16,dcomplex>(MKL_Complex16& mklScalar, const dcomplex& eigenScalar) {
mklScalar.real=eigenScalar.real();
mklScalar.imag=-eigenScalar.imag();
}
template <>
inline void assign_conj_scalar_eig2mkl<MKL_Complex8,scomplex>(MKL_Complex8& mklScalar, const scomplex& eigenScalar) {
mklScalar.real=eigenScalar.real();
mklScalar.imag=-eigenScalar.imag();
}
} // end namespace internal
} // end namespace Eigen } // end namespace Eigen
#if defined(EIGEN_USE_BLAS)
#include "../../misc/blas.h"
#endif #endif
#endif // EIGEN_MKL_SUPPORT_H #endif // EIGEN_MKL_SUPPORT_H

View File

@ -30,15 +30,15 @@ int BLASFUNC(cdotcw) (int *, float *, int *, float *, int *, float*);
int BLASFUNC(zdotuw) (int *, double *, int *, double *, int *, double*); int BLASFUNC(zdotuw) (int *, double *, int *, double *, int *, double*);
int BLASFUNC(zdotcw) (int *, double *, int *, double *, int *, double*); int BLASFUNC(zdotcw) (int *, double *, int *, double *, int *, double*);
int BLASFUNC(saxpy) (int *, float *, float *, int *, float *, int *); int BLASFUNC(saxpy) (const int *, const float *, const float *, const int *, float *, const int *);
int BLASFUNC(daxpy) (int *, double *, double *, int *, double *, int *); int BLASFUNC(daxpy) (const int *, const double *, const double *, const int *, double *, const int *);
int BLASFUNC(qaxpy) (int *, double *, double *, int *, double *, int *); int BLASFUNC(qaxpy) (const int *, const double *, const double *, const int *, double *, const int *);
int BLASFUNC(caxpy) (int *, float *, float *, int *, float *, int *); int BLASFUNC(caxpy) (const int *, const float *, const float *, const int *, float *, const int *);
int BLASFUNC(zaxpy) (int *, double *, double *, int *, double *, int *); int BLASFUNC(zaxpy) (const int *, const double *, const double *, const int *, double *, const int *);
int BLASFUNC(xaxpy) (int *, double *, double *, int *, double *, int *); int BLASFUNC(xaxpy) (const int *, const double *, const double *, const int *, double *, const int *);
int BLASFUNC(caxpyc)(int *, float *, float *, int *, float *, int *); int BLASFUNC(caxpyc)(const int *, const float *, const float *, const int *, float *, const int *);
int BLASFUNC(zaxpyc)(int *, double *, double *, int *, double *, int *); int BLASFUNC(zaxpyc)(const int *, const double *, const double *, const int *, double *, const int *);
int BLASFUNC(xaxpyc)(int *, double *, double *, int *, double *, int *); int BLASFUNC(xaxpyc)(const int *, const double *, const double *, const int *, double *, const int *);
int BLASFUNC(scopy) (int *, float *, int *, float *, int *); int BLASFUNC(scopy) (int *, float *, int *, float *, int *);
int BLASFUNC(dcopy) (int *, double *, int *, double *, int *); int BLASFUNC(dcopy) (int *, double *, int *, double *, int *);
@ -177,31 +177,19 @@ int BLASFUNC(xgeru)(int *, int *, double *, double *, int *,
int BLASFUNC(xgerc)(int *, int *, double *, double *, int *, int BLASFUNC(xgerc)(int *, int *, double *, double *, int *,
double *, int *, double *, int *); double *, int *, double *, int *);
int BLASFUNC(sgemv)(char *, int *, int *, float *, float *, int *, int BLASFUNC(sgemv)(const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
float *, int *, float *, float *, int *); int BLASFUNC(dgemv)(const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(dgemv)(char *, int *, int *, double *, double *, int *, int BLASFUNC(qgemv)(const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
double *, int *, double *, double *, int *); int BLASFUNC(cgemv)(const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
int BLASFUNC(qgemv)(char *, int *, int *, double *, double *, int *, int BLASFUNC(zgemv)(const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
double *, int *, double *, double *, int *); int BLASFUNC(xgemv)(const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(cgemv)(char *, int *, int *, float *, float *, int *,
float *, int *, float *, float *, int *);
int BLASFUNC(zgemv)(char *, int *, int *, double *, double *, int *,
double *, int *, double *, double *, int *);
int BLASFUNC(xgemv)(char *, int *, int *, double *, double *, int *,
double *, int *, double *, double *, int *);
int BLASFUNC(strsv) (char *, char *, char *, int *, float *, int *, int BLASFUNC(strsv) (const char *, const char *, const char *, const int *, const float *, const int *, float *, const int *);
float *, int *); int BLASFUNC(dtrsv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *);
int BLASFUNC(dtrsv) (char *, char *, char *, int *, double *, int *, int BLASFUNC(qtrsv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *);
double *, int *); int BLASFUNC(ctrsv) (const char *, const char *, const char *, const int *, const float *, const int *, float *, const int *);
int BLASFUNC(qtrsv) (char *, char *, char *, int *, double *, int *, int BLASFUNC(ztrsv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *);
double *, int *); int BLASFUNC(xtrsv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *);
int BLASFUNC(ctrsv) (char *, char *, char *, int *, float *, int *,
float *, int *);
int BLASFUNC(ztrsv) (char *, char *, char *, int *, double *, int *,
double *, int *);
int BLASFUNC(xtrsv) (char *, char *, char *, int *, double *, int *,
double *, int *);
int BLASFUNC(stpsv) (char *, char *, char *, int *, float *, float *, int *); int BLASFUNC(stpsv) (char *, char *, char *, int *, float *, float *, int *);
int BLASFUNC(dtpsv) (char *, char *, char *, int *, double *, double *, int *); int BLASFUNC(dtpsv) (char *, char *, char *, int *, double *, double *, int *);
@ -210,18 +198,12 @@ int BLASFUNC(ctpsv) (char *, char *, char *, int *, float *, float *, int *);
int BLASFUNC(ztpsv) (char *, char *, char *, int *, double *, double *, int *); int BLASFUNC(ztpsv) (char *, char *, char *, int *, double *, double *, int *);
int BLASFUNC(xtpsv) (char *, char *, char *, int *, double *, double *, int *); int BLASFUNC(xtpsv) (char *, char *, char *, int *, double *, double *, int *);
int BLASFUNC(strmv) (char *, char *, char *, int *, float *, int *, int BLASFUNC(strmv) (const char *, const char *, const char *, const int *, const float *, const int *, float *, const int *);
float *, int *); int BLASFUNC(dtrmv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *);
int BLASFUNC(dtrmv) (char *, char *, char *, int *, double *, int *, int BLASFUNC(qtrmv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *);
double *, int *); int BLASFUNC(ctrmv) (const char *, const char *, const char *, const int *, const float *, const int *, float *, const int *);
int BLASFUNC(qtrmv) (char *, char *, char *, int *, double *, int *, int BLASFUNC(ztrmv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *);
double *, int *); int BLASFUNC(xtrmv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *);
int BLASFUNC(ctrmv) (char *, char *, char *, int *, float *, int *,
float *, int *);
int BLASFUNC(ztrmv) (char *, char *, char *, int *, double *, int *,
double *, int *);
int BLASFUNC(xtrmv) (char *, char *, char *, int *, double *, int *,
double *, int *);
int BLASFUNC(stpmv) (char *, char *, char *, int *, float *, float *, int *); int BLASFUNC(stpmv) (char *, char *, char *, int *, float *, float *, int *);
int BLASFUNC(dtpmv) (char *, char *, char *, int *, double *, double *, int *); int BLASFUNC(dtpmv) (char *, char *, char *, int *, double *, double *, int *);
@ -244,18 +226,9 @@ int BLASFUNC(ctbsv) (char *, char *, char *, int *, int *, float *, int *, floa
int BLASFUNC(ztbsv) (char *, char *, char *, int *, int *, double *, int *, double *, int *); int BLASFUNC(ztbsv) (char *, char *, char *, int *, int *, double *, int *, double *, int *);
int BLASFUNC(xtbsv) (char *, char *, char *, int *, int *, double *, int *, double *, int *); int BLASFUNC(xtbsv) (char *, char *, char *, int *, int *, double *, int *, double *, int *);
int BLASFUNC(ssymv) (char *, int *, float *, float *, int *, int BLASFUNC(ssymv) (const char *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
float *, int *, float *, float *, int *); int BLASFUNC(dsymv) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(dsymv) (char *, int *, double *, double *, int *, int BLASFUNC(qsymv) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
double *, int *, double *, double *, int *);
int BLASFUNC(qsymv) (char *, int *, double *, double *, int *,
double *, int *, double *, double *, int *);
int BLASFUNC(csymv) (char *, int *, float *, float *, int *,
float *, int *, float *, float *, int *);
int BLASFUNC(zsymv) (char *, int *, double *, double *, int *,
double *, int *, double *, double *, int *);
int BLASFUNC(xsymv) (char *, int *, double *, double *, int *,
double *, int *, double *, double *, int *);
int BLASFUNC(sspmv) (char *, int *, float *, float *, int BLASFUNC(sspmv) (char *, int *, float *, float *,
float *, int *, float *, float *, int *); float *, int *, float *, float *, int *);
@ -263,38 +236,17 @@ int BLASFUNC(dspmv) (char *, int *, double *, double *,
double *, int *, double *, double *, int *); double *, int *, double *, double *, int *);
int BLASFUNC(qspmv) (char *, int *, double *, double *, int BLASFUNC(qspmv) (char *, int *, double *, double *,
double *, int *, double *, double *, int *); double *, int *, double *, double *, int *);
int BLASFUNC(cspmv) (char *, int *, float *, float *,
float *, int *, float *, float *, int *);
int BLASFUNC(zspmv) (char *, int *, double *, double *,
double *, int *, double *, double *, int *);
int BLASFUNC(xspmv) (char *, int *, double *, double *,
double *, int *, double *, double *, int *);
int BLASFUNC(ssyr) (char *, int *, float *, float *, int *, int BLASFUNC(ssyr) (const char *, const int *, const float *, const float *, const int *, float *, const int *);
float *, int *); int BLASFUNC(dsyr) (const char *, const int *, const double *, const double *, const int *, double *, const int *);
int BLASFUNC(dsyr) (char *, int *, double *, double *, int *, int BLASFUNC(qsyr) (const char *, const int *, const double *, const double *, const int *, double *, const int *);
double *, int *);
int BLASFUNC(qsyr) (char *, int *, double *, double *, int *,
double *, int *);
int BLASFUNC(csyr) (char *, int *, float *, float *, int *,
float *, int *);
int BLASFUNC(zsyr) (char *, int *, double *, double *, int *,
double *, int *);
int BLASFUNC(xsyr) (char *, int *, double *, double *, int *,
double *, int *);
int BLASFUNC(ssyr2) (char *, int *, float *, int BLASFUNC(ssyr2) (const char *, const int *, const float *, const float *, const int *, const float *, const int *, float *, const int *);
float *, int *, float *, int *, float *, int *); int BLASFUNC(dsyr2) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, double *, const int *);
int BLASFUNC(dsyr2) (char *, int *, double *, int BLASFUNC(qsyr2) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, double *, const int *);
double *, int *, double *, int *, double *, int *); int BLASFUNC(csyr2) (const char *, const int *, const float *, const float *, const int *, const float *, const int *, float *, const int *);
int BLASFUNC(qsyr2) (char *, int *, double *, int BLASFUNC(zsyr2) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, double *, const int *);
double *, int *, double *, int *, double *, int *); int BLASFUNC(xsyr2) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, double *, const int *);
int BLASFUNC(csyr2) (char *, int *, float *,
float *, int *, float *, int *, float *, int *);
int BLASFUNC(zsyr2) (char *, int *, double *,
double *, int *, double *, int *, double *, int *);
int BLASFUNC(xsyr2) (char *, int *, double *,
double *, int *, double *, int *, double *, int *);
int BLASFUNC(sspr) (char *, int *, float *, float *, int *, int BLASFUNC(sspr) (char *, int *, float *, float *, int *,
float *); float *);
@ -302,12 +254,6 @@ int BLASFUNC(dspr) (char *, int *, double *, double *, int *,
double *); double *);
int BLASFUNC(qspr) (char *, int *, double *, double *, int *, int BLASFUNC(qspr) (char *, int *, double *, double *, int *,
double *); double *);
int BLASFUNC(cspr) (char *, int *, float *, float *, int *,
float *);
int BLASFUNC(zspr) (char *, int *, double *, double *, int *,
double *);
int BLASFUNC(xspr) (char *, int *, double *, double *, int *,
double *);
int BLASFUNC(sspr2) (char *, int *, float *, int BLASFUNC(sspr2) (char *, int *, float *,
float *, int *, float *, int *, float *); float *, int *, float *, int *, float *);
@ -347,12 +293,9 @@ int BLASFUNC(zhpr2) (char *, int *, double *,
int BLASFUNC(xhpr2) (char *, int *, double *, int BLASFUNC(xhpr2) (char *, int *, double *,
double *, int *, double *, int *, double *); double *, int *, double *, int *, double *);
int BLASFUNC(chemv) (char *, int *, float *, float *, int *, int BLASFUNC(chemv) (const char *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
float *, int *, float *, float *, int *); int BLASFUNC(zhemv) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(zhemv) (char *, int *, double *, double *, int *, int BLASFUNC(xhemv) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
double *, int *, double *, double *, int *);
int BLASFUNC(xhemv) (char *, int *, double *, double *, int *,
double *, int *, double *, double *, int *);
int BLASFUNC(chpmv) (char *, int *, float *, float *, int BLASFUNC(chpmv) (char *, int *, float *, float *,
float *, int *, float *, float *, int *); float *, int *, float *, float *, int *);
@ -401,18 +344,12 @@ int BLASFUNC(xhbmv)(char *, int *, int *, double *, double *, int *,
/* Level 3 routines */ /* Level 3 routines */
int BLASFUNC(sgemm)(char *, char *, int *, int *, int *, float *, int BLASFUNC(sgemm)(const char *, const char *, const int *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
float *, int *, float *, int *, float *, float *, int *); int BLASFUNC(dgemm)(const char *, const char *, const int *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(dgemm)(char *, char *, int *, int *, int *, double *, int BLASFUNC(qgemm)(const char *, const char *, const int *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
double *, int *, double *, int *, double *, double *, int *); int BLASFUNC(cgemm)(const char *, const char *, const int *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
int BLASFUNC(qgemm)(char *, char *, int *, int *, int *, double *, int BLASFUNC(zgemm)(const char *, const char *, const int *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
double *, int *, double *, int *, double *, double *, int *); int BLASFUNC(xgemm)(const char *, const char *, const int *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(cgemm)(char *, char *, int *, int *, int *, float *,
float *, int *, float *, int *, float *, float *, int *);
int BLASFUNC(zgemm)(char *, char *, int *, int *, int *, double *,
double *, int *, double *, int *, double *, double *, int *);
int BLASFUNC(xgemm)(char *, char *, int *, int *, int *, double *,
double *, int *, double *, int *, double *, double *, int *);
int BLASFUNC(cgemm3m)(char *, char *, int *, int *, int *, float *, int BLASFUNC(cgemm3m)(char *, char *, int *, int *, int *, float *,
float *, int *, float *, int *, float *, float *, int *); float *, int *, float *, int *, float *, float *, int *);
@ -434,84 +371,48 @@ int BLASFUNC(zge2mm)(char *, char *, char *, int *, int *,
double *, double *, int *, double *, int *, double *, double *, int *, double *, int *,
double *, double *, int *); double *, double *, int *);
int BLASFUNC(strsm)(char *, char *, char *, char *, int *, int *, int BLASFUNC(strsm)(const char *, const char *, const char *, const char *, const int *, const int *, const float *, const float *, const int *, float *, const int *);
float *, float *, int *, float *, int *); int BLASFUNC(dtrsm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *);
int BLASFUNC(dtrsm)(char *, char *, char *, char *, int *, int *, int BLASFUNC(qtrsm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *);
double *, double *, int *, double *, int *); int BLASFUNC(ctrsm)(const char *, const char *, const char *, const char *, const int *, const int *, const float *, const float *, const int *, float *, const int *);
int BLASFUNC(qtrsm)(char *, char *, char *, char *, int *, int *, int BLASFUNC(ztrsm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *);
double *, double *, int *, double *, int *); int BLASFUNC(xtrsm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *);
int BLASFUNC(ctrsm)(char *, char *, char *, char *, int *, int *,
float *, float *, int *, float *, int *);
int BLASFUNC(ztrsm)(char *, char *, char *, char *, int *, int *,
double *, double *, int *, double *, int *);
int BLASFUNC(xtrsm)(char *, char *, char *, char *, int *, int *,
double *, double *, int *, double *, int *);
int BLASFUNC(strmm)(char *, char *, char *, char *, int *, int *, int BLASFUNC(strmm)(const char *, const char *, const char *, const char *, const int *, const int *, const float *, const float *, const int *, float *, const int *);
float *, float *, int *, float *, int *); int BLASFUNC(dtrmm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *);
int BLASFUNC(dtrmm)(char *, char *, char *, char *, int *, int *, int BLASFUNC(qtrmm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *);
double *, double *, int *, double *, int *); int BLASFUNC(ctrmm)(const char *, const char *, const char *, const char *, const int *, const int *, const float *, const float *, const int *, float *, const int *);
int BLASFUNC(qtrmm)(char *, char *, char *, char *, int *, int *, int BLASFUNC(ztrmm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *);
double *, double *, int *, double *, int *); int BLASFUNC(xtrmm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *);
int BLASFUNC(ctrmm)(char *, char *, char *, char *, int *, int *,
float *, float *, int *, float *, int *);
int BLASFUNC(ztrmm)(char *, char *, char *, char *, int *, int *,
double *, double *, int *, double *, int *);
int BLASFUNC(xtrmm)(char *, char *, char *, char *, int *, int *,
double *, double *, int *, double *, int *);
int BLASFUNC(ssymm)(char *, char *, int *, int *, float *, float *, int *, int BLASFUNC(ssymm)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
float *, int *, float *, float *, int *); int BLASFUNC(dsymm)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(dsymm)(char *, char *, int *, int *, double *, double *, int *, int BLASFUNC(qsymm)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
double *, int *, double *, double *, int *); int BLASFUNC(csymm)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
int BLASFUNC(qsymm)(char *, char *, int *, int *, double *, double *, int *, int BLASFUNC(zsymm)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
double *, int *, double *, double *, int *); int BLASFUNC(xsymm)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(csymm)(char *, char *, int *, int *, float *, float *, int *,
float *, int *, float *, float *, int *);
int BLASFUNC(zsymm)(char *, char *, int *, int *, double *, double *, int *,
double *, int *, double *, double *, int *);
int BLASFUNC(xsymm)(char *, char *, int *, int *, double *, double *, int *,
double *, int *, double *, double *, int *);
int BLASFUNC(csymm3m)(char *, char *, int *, int *, float *, float *, int *, int BLASFUNC(csymm3m)(char *, char *, int *, int *, float *, float *, int *, float *, int *, float *, float *, int *);
float *, int *, float *, float *, int *); int BLASFUNC(zsymm3m)(char *, char *, int *, int *, double *, double *, int *, double *, int *, double *, double *, int *);
int BLASFUNC(zsymm3m)(char *, char *, int *, int *, double *, double *, int *, int BLASFUNC(xsymm3m)(char *, char *, int *, int *, double *, double *, int *, double *, int *, double *, double *, int *);
double *, int *, double *, double *, int *);
int BLASFUNC(xsymm3m)(char *, char *, int *, int *, double *, double *, int *,
double *, int *, double *, double *, int *);
int BLASFUNC(ssyrk)(char *, char *, int *, int *, float *, float *, int *, int BLASFUNC(ssyrk)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, float *, const int *);
float *, float *, int *); int BLASFUNC(dsyrk)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(dsyrk)(char *, char *, int *, int *, double *, double *, int *, int BLASFUNC(qsyrk)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, double *, const int *);
double *, double *, int *); int BLASFUNC(csyrk)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, float *, const int *);
int BLASFUNC(qsyrk)(char *, char *, int *, int *, double *, double *, int *, int BLASFUNC(zsyrk)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, double *, const int *);
double *, double *, int *); int BLASFUNC(xsyrk)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(csyrk)(char *, char *, int *, int *, float *, float *, int *,
float *, float *, int *);
int BLASFUNC(zsyrk)(char *, char *, int *, int *, double *, double *, int *,
double *, double *, int *);
int BLASFUNC(xsyrk)(char *, char *, int *, int *, double *, double *, int *,
double *, double *, int *);
int BLASFUNC(ssyr2k)(char *, char *, int *, int *, float *, float *, int *, int BLASFUNC(ssyr2k)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
float *, int *, float *, float *, int *); int BLASFUNC(dsyr2k)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double*, const int *, const double *, double *, const int *);
int BLASFUNC(dsyr2k)(char *, char *, int *, int *, double *, double *, int *, int BLASFUNC(qsyr2k)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double*, const int *, const double *, double *, const int *);
double*, int *, double *, double *, int *); int BLASFUNC(csyr2k)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
int BLASFUNC(qsyr2k)(char *, char *, int *, int *, double *, double *, int *, int BLASFUNC(zsyr2k)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double*, const int *, const double *, double *, const int *);
double*, int *, double *, double *, int *); int BLASFUNC(xsyr2k)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double*, const int *, const double *, double *, const int *);
int BLASFUNC(csyr2k)(char *, char *, int *, int *, float *, float *, int *,
float *, int *, float *, float *, int *);
int BLASFUNC(zsyr2k)(char *, char *, int *, int *, double *, double *, int *,
double*, int *, double *, double *, int *);
int BLASFUNC(xsyr2k)(char *, char *, int *, int *, double *, double *, int *,
double*, int *, double *, double *, int *);
int BLASFUNC(chemm)(char *, char *, int *, int *, float *, float *, int *, int BLASFUNC(chemm)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
float *, int *, float *, float *, int *); int BLASFUNC(zhemm)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(zhemm)(char *, char *, int *, int *, double *, double *, int *, int BLASFUNC(xhemm)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
double *, int *, double *, double *, int *);
int BLASFUNC(xhemm)(char *, char *, int *, int *, double *, double *, int *,
double *, int *, double *, double *, int *);
int BLASFUNC(chemm3m)(char *, char *, int *, int *, float *, float *, int *, int BLASFUNC(chemm3m)(char *, char *, int *, int *, float *, float *, int *,
float *, int *, float *, float *, int *); float *, int *, float *, float *, int *);
@ -520,136 +421,17 @@ int BLASFUNC(zhemm3m)(char *, char *, int *, int *, double *, double *, int *,
int BLASFUNC(xhemm3m)(char *, char *, int *, int *, double *, double *, int *, int BLASFUNC(xhemm3m)(char *, char *, int *, int *, double *, double *, int *,
double *, int *, double *, double *, int *); double *, int *, double *, double *, int *);
int BLASFUNC(cherk)(char *, char *, int *, int *, float *, float *, int *, int BLASFUNC(cherk)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, float *, const int *);
float *, float *, int *); int BLASFUNC(zherk)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(zherk)(char *, char *, int *, int *, double *, double *, int *, int BLASFUNC(xherk)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, double *, const int *);
double *, double *, int *);
int BLASFUNC(xherk)(char *, char *, int *, int *, double *, double *, int *,
double *, double *, int *);
int BLASFUNC(cher2k)(char *, char *, int *, int *, float *, float *, int *, int BLASFUNC(cher2k)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
float *, int *, float *, float *, int *); int BLASFUNC(zher2k)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(zher2k)(char *, char *, int *, int *, double *, double *, int *, int BLASFUNC(xher2k)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
double*, int *, double *, double *, int *); int BLASFUNC(cher2m)(const char *, const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
int BLASFUNC(xher2k)(char *, char *, int *, int *, double *, double *, int *, int BLASFUNC(zher2m)(const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double*, const int *, const double *, double *, const int *);
double*, int *, double *, double *, int *); int BLASFUNC(xher2m)(const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double*, const int *, const double *, double *, const int *);
int BLASFUNC(cher2m)(char *, char *, char *, int *, int *, float *, float *, int *,
float *, int *, float *, float *, int *);
int BLASFUNC(zher2m)(char *, char *, char *, int *, int *, double *, double *, int *,
double*, int *, double *, double *, int *);
int BLASFUNC(xher2m)(char *, char *, char *, int *, int *, double *, double *, int *,
double*, int *, double *, double *, int *);
int BLASFUNC(sgemt)(char *, int *, int *, float *, float *, int *,
float *, int *);
int BLASFUNC(dgemt)(char *, int *, int *, double *, double *, int *,
double *, int *);
int BLASFUNC(cgemt)(char *, int *, int *, float *, float *, int *,
float *, int *);
int BLASFUNC(zgemt)(char *, int *, int *, double *, double *, int *,
double *, int *);
int BLASFUNC(sgema)(char *, char *, int *, int *, float *,
float *, int *, float *, float *, int *, float *, int *);
int BLASFUNC(dgema)(char *, char *, int *, int *, double *,
double *, int *, double*, double *, int *, double*, int *);
int BLASFUNC(cgema)(char *, char *, int *, int *, float *,
float *, int *, float *, float *, int *, float *, int *);
int BLASFUNC(zgema)(char *, char *, int *, int *, double *,
double *, int *, double*, double *, int *, double*, int *);
int BLASFUNC(sgems)(char *, char *, int *, int *, float *,
float *, int *, float *, float *, int *, float *, int *);
int BLASFUNC(dgems)(char *, char *, int *, int *, double *,
double *, int *, double*, double *, int *, double*, int *);
int BLASFUNC(cgems)(char *, char *, int *, int *, float *,
float *, int *, float *, float *, int *, float *, int *);
int BLASFUNC(zgems)(char *, char *, int *, int *, double *,
double *, int *, double*, double *, int *, double*, int *);
int BLASFUNC(sgetf2)(int *, int *, float *, int *, int *, int *);
int BLASFUNC(dgetf2)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(qgetf2)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(cgetf2)(int *, int *, float *, int *, int *, int *);
int BLASFUNC(zgetf2)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(xgetf2)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(sgetrf)(int *, int *, float *, int *, int *, int *);
int BLASFUNC(dgetrf)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(qgetrf)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(cgetrf)(int *, int *, float *, int *, int *, int *);
int BLASFUNC(zgetrf)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(xgetrf)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(slaswp)(int *, float *, int *, int *, int *, int *, int *);
int BLASFUNC(dlaswp)(int *, double *, int *, int *, int *, int *, int *);
int BLASFUNC(qlaswp)(int *, double *, int *, int *, int *, int *, int *);
int BLASFUNC(claswp)(int *, float *, int *, int *, int *, int *, int *);
int BLASFUNC(zlaswp)(int *, double *, int *, int *, int *, int *, int *);
int BLASFUNC(xlaswp)(int *, double *, int *, int *, int *, int *, int *);
int BLASFUNC(sgetrs)(char *, int *, int *, float *, int *, int *, float *, int *, int *);
int BLASFUNC(dgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *);
int BLASFUNC(qgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *);
int BLASFUNC(cgetrs)(char *, int *, int *, float *, int *, int *, float *, int *, int *);
int BLASFUNC(zgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *);
int BLASFUNC(xgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *);
int BLASFUNC(sgesv)(int *, int *, float *, int *, int *, float *, int *, int *);
int BLASFUNC(dgesv)(int *, int *, double *, int *, int *, double*, int *, int *);
int BLASFUNC(qgesv)(int *, int *, double *, int *, int *, double*, int *, int *);
int BLASFUNC(cgesv)(int *, int *, float *, int *, int *, float *, int *, int *);
int BLASFUNC(zgesv)(int *, int *, double *, int *, int *, double*, int *, int *);
int BLASFUNC(xgesv)(int *, int *, double *, int *, int *, double*, int *, int *);
int BLASFUNC(spotf2)(char *, int *, float *, int *, int *);
int BLASFUNC(dpotf2)(char *, int *, double *, int *, int *);
int BLASFUNC(qpotf2)(char *, int *, double *, int *, int *);
int BLASFUNC(cpotf2)(char *, int *, float *, int *, int *);
int BLASFUNC(zpotf2)(char *, int *, double *, int *, int *);
int BLASFUNC(xpotf2)(char *, int *, double *, int *, int *);
int BLASFUNC(spotrf)(char *, int *, float *, int *, int *);
int BLASFUNC(dpotrf)(char *, int *, double *, int *, int *);
int BLASFUNC(qpotrf)(char *, int *, double *, int *, int *);
int BLASFUNC(cpotrf)(char *, int *, float *, int *, int *);
int BLASFUNC(zpotrf)(char *, int *, double *, int *, int *);
int BLASFUNC(xpotrf)(char *, int *, double *, int *, int *);
int BLASFUNC(slauu2)(char *, int *, float *, int *, int *);
int BLASFUNC(dlauu2)(char *, int *, double *, int *, int *);
int BLASFUNC(qlauu2)(char *, int *, double *, int *, int *);
int BLASFUNC(clauu2)(char *, int *, float *, int *, int *);
int BLASFUNC(zlauu2)(char *, int *, double *, int *, int *);
int BLASFUNC(xlauu2)(char *, int *, double *, int *, int *);
int BLASFUNC(slauum)(char *, int *, float *, int *, int *);
int BLASFUNC(dlauum)(char *, int *, double *, int *, int *);
int BLASFUNC(qlauum)(char *, int *, double *, int *, int *);
int BLASFUNC(clauum)(char *, int *, float *, int *, int *);
int BLASFUNC(zlauum)(char *, int *, double *, int *, int *);
int BLASFUNC(xlauum)(char *, int *, double *, int *, int *);
int BLASFUNC(strti2)(char *, char *, int *, float *, int *, int *);
int BLASFUNC(dtrti2)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(qtrti2)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(ctrti2)(char *, char *, int *, float *, int *, int *);
int BLASFUNC(ztrti2)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(xtrti2)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(strtri)(char *, char *, int *, float *, int *, int *);
int BLASFUNC(dtrtri)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(qtrtri)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(ctrtri)(char *, char *, int *, float *, int *, int *);
int BLASFUNC(ztrtri)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(xtrtri)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(spotri)(char *, int *, float *, int *, int *);
int BLASFUNC(dpotri)(char *, int *, double *, int *, int *);
int BLASFUNC(qpotri)(char *, int *, double *, int *, int *);
int BLASFUNC(cpotri)(char *, int *, float *, int *, int *);
int BLASFUNC(zpotri)(char *, int *, double *, int *, int *);
int BLASFUNC(xpotri)(char *, int *, double *, int *, int *);
#ifdef __cplusplus #ifdef __cplusplus
} }

152
Eigen/src/misc/lapack.h Normal file
View File

@ -0,0 +1,152 @@
#ifndef LAPACK_H
#define LAPACK_H
#include "blas.h"
#ifdef __cplusplus
extern "C"
{
#endif
int BLASFUNC(csymv) (const char *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
int BLASFUNC(zsymv) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(xsymv) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(cspmv) (char *, int *, float *, float *,
float *, int *, float *, float *, int *);
int BLASFUNC(zspmv) (char *, int *, double *, double *,
double *, int *, double *, double *, int *);
int BLASFUNC(xspmv) (char *, int *, double *, double *,
double *, int *, double *, double *, int *);
int BLASFUNC(csyr) (char *, int *, float *, float *, int *,
float *, int *);
int BLASFUNC(zsyr) (char *, int *, double *, double *, int *,
double *, int *);
int BLASFUNC(xsyr) (char *, int *, double *, double *, int *,
double *, int *);
int BLASFUNC(cspr) (char *, int *, float *, float *, int *,
float *);
int BLASFUNC(zspr) (char *, int *, double *, double *, int *,
double *);
int BLASFUNC(xspr) (char *, int *, double *, double *, int *,
double *);
int BLASFUNC(sgemt)(char *, int *, int *, float *, float *, int *,
float *, int *);
int BLASFUNC(dgemt)(char *, int *, int *, double *, double *, int *,
double *, int *);
int BLASFUNC(cgemt)(char *, int *, int *, float *, float *, int *,
float *, int *);
int BLASFUNC(zgemt)(char *, int *, int *, double *, double *, int *,
double *, int *);
int BLASFUNC(sgema)(char *, char *, int *, int *, float *,
float *, int *, float *, float *, int *, float *, int *);
int BLASFUNC(dgema)(char *, char *, int *, int *, double *,
double *, int *, double*, double *, int *, double*, int *);
int BLASFUNC(cgema)(char *, char *, int *, int *, float *,
float *, int *, float *, float *, int *, float *, int *);
int BLASFUNC(zgema)(char *, char *, int *, int *, double *,
double *, int *, double*, double *, int *, double*, int *);
int BLASFUNC(sgems)(char *, char *, int *, int *, float *,
float *, int *, float *, float *, int *, float *, int *);
int BLASFUNC(dgems)(char *, char *, int *, int *, double *,
double *, int *, double*, double *, int *, double*, int *);
int BLASFUNC(cgems)(char *, char *, int *, int *, float *,
float *, int *, float *, float *, int *, float *, int *);
int BLASFUNC(zgems)(char *, char *, int *, int *, double *,
double *, int *, double*, double *, int *, double*, int *);
int BLASFUNC(sgetf2)(int *, int *, float *, int *, int *, int *);
int BLASFUNC(dgetf2)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(qgetf2)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(cgetf2)(int *, int *, float *, int *, int *, int *);
int BLASFUNC(zgetf2)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(xgetf2)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(sgetrf)(int *, int *, float *, int *, int *, int *);
int BLASFUNC(dgetrf)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(qgetrf)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(cgetrf)(int *, int *, float *, int *, int *, int *);
int BLASFUNC(zgetrf)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(xgetrf)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(slaswp)(int *, float *, int *, int *, int *, int *, int *);
int BLASFUNC(dlaswp)(int *, double *, int *, int *, int *, int *, int *);
int BLASFUNC(qlaswp)(int *, double *, int *, int *, int *, int *, int *);
int BLASFUNC(claswp)(int *, float *, int *, int *, int *, int *, int *);
int BLASFUNC(zlaswp)(int *, double *, int *, int *, int *, int *, int *);
int BLASFUNC(xlaswp)(int *, double *, int *, int *, int *, int *, int *);
int BLASFUNC(sgetrs)(char *, int *, int *, float *, int *, int *, float *, int *, int *);
int BLASFUNC(dgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *);
int BLASFUNC(qgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *);
int BLASFUNC(cgetrs)(char *, int *, int *, float *, int *, int *, float *, int *, int *);
int BLASFUNC(zgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *);
int BLASFUNC(xgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *);
int BLASFUNC(sgesv)(int *, int *, float *, int *, int *, float *, int *, int *);
int BLASFUNC(dgesv)(int *, int *, double *, int *, int *, double*, int *, int *);
int BLASFUNC(qgesv)(int *, int *, double *, int *, int *, double*, int *, int *);
int BLASFUNC(cgesv)(int *, int *, float *, int *, int *, float *, int *, int *);
int BLASFUNC(zgesv)(int *, int *, double *, int *, int *, double*, int *, int *);
int BLASFUNC(xgesv)(int *, int *, double *, int *, int *, double*, int *, int *);
int BLASFUNC(spotf2)(char *, int *, float *, int *, int *);
int BLASFUNC(dpotf2)(char *, int *, double *, int *, int *);
int BLASFUNC(qpotf2)(char *, int *, double *, int *, int *);
int BLASFUNC(cpotf2)(char *, int *, float *, int *, int *);
int BLASFUNC(zpotf2)(char *, int *, double *, int *, int *);
int BLASFUNC(xpotf2)(char *, int *, double *, int *, int *);
int BLASFUNC(spotrf)(char *, int *, float *, int *, int *);
int BLASFUNC(dpotrf)(char *, int *, double *, int *, int *);
int BLASFUNC(qpotrf)(char *, int *, double *, int *, int *);
int BLASFUNC(cpotrf)(char *, int *, float *, int *, int *);
int BLASFUNC(zpotrf)(char *, int *, double *, int *, int *);
int BLASFUNC(xpotrf)(char *, int *, double *, int *, int *);
int BLASFUNC(slauu2)(char *, int *, float *, int *, int *);
int BLASFUNC(dlauu2)(char *, int *, double *, int *, int *);
int BLASFUNC(qlauu2)(char *, int *, double *, int *, int *);
int BLASFUNC(clauu2)(char *, int *, float *, int *, int *);
int BLASFUNC(zlauu2)(char *, int *, double *, int *, int *);
int BLASFUNC(xlauu2)(char *, int *, double *, int *, int *);
int BLASFUNC(slauum)(char *, int *, float *, int *, int *);
int BLASFUNC(dlauum)(char *, int *, double *, int *, int *);
int BLASFUNC(qlauum)(char *, int *, double *, int *, int *);
int BLASFUNC(clauum)(char *, int *, float *, int *, int *);
int BLASFUNC(zlauum)(char *, int *, double *, int *, int *);
int BLASFUNC(xlauum)(char *, int *, double *, int *, int *);
int BLASFUNC(strti2)(char *, char *, int *, float *, int *, int *);
int BLASFUNC(dtrti2)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(qtrti2)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(ctrti2)(char *, char *, int *, float *, int *, int *);
int BLASFUNC(ztrti2)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(xtrti2)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(strtri)(char *, char *, int *, float *, int *, int *);
int BLASFUNC(dtrtri)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(qtrtri)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(ctrtri)(char *, char *, int *, float *, int *, int *);
int BLASFUNC(ztrtri)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(xtrtri)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(spotri)(char *, int *, float *, int *, int *);
int BLASFUNC(dpotri)(char *, int *, double *, int *, int *);
int BLASFUNC(qpotri)(char *, int *, double *, int *, int *);
int BLASFUNC(cpotri)(char *, int *, float *, int *, int *);
int BLASFUNC(zpotri)(char *, int *, double *, int *, int *);
int BLASFUNC(xpotri)(char *, int *, double *, int *, int *);
#ifdef __cplusplus
}
#endif
#endif

View File

@ -10,8 +10,8 @@
#ifndef EIGEN_BLAS_COMMON_H #ifndef EIGEN_BLAS_COMMON_H
#define EIGEN_BLAS_COMMON_H #define EIGEN_BLAS_COMMON_H
#include <Eigen/Core> #include "../Eigen/Core"
#include <Eigen/Jacobi> #include "../Eigen/Jacobi"
#include <complex> #include <complex>
@ -19,8 +19,7 @@
#error the token SCALAR must be defined to compile this file #error the token SCALAR must be defined to compile this file
#endif #endif
#include <Eigen/src/misc/blas.h> #include "../Eigen/src/misc/blas.h"
#define NOTR 0 #define NOTR 0
#define TR 1 #define TR 1
@ -94,6 +93,7 @@ enum
typedef Matrix<Scalar,Dynamic,Dynamic,ColMajor> PlainMatrixType; typedef Matrix<Scalar,Dynamic,Dynamic,ColMajor> PlainMatrixType;
typedef Map<Matrix<Scalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > MatrixType; typedef Map<Matrix<Scalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > MatrixType;
typedef Map<const Matrix<Scalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > ConstMatrixType;
typedef Map<Matrix<Scalar,Dynamic,1>, 0, InnerStride<Dynamic> > StridedVectorType; typedef Map<Matrix<Scalar,Dynamic,1>, 0, InnerStride<Dynamic> > StridedVectorType;
typedef Map<Matrix<Scalar,Dynamic,1> > CompactVectorType; typedef Map<Matrix<Scalar,Dynamic,1> > CompactVectorType;
@ -104,25 +104,44 @@ matrix(T* data, int rows, int cols, int stride)
return Map<Matrix<T,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> >(data, rows, cols, OuterStride<>(stride)); return Map<Matrix<T,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> >(data, rows, cols, OuterStride<>(stride));
} }
template<typename T>
Map<const Matrix<T,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> >
matrix(const T* data, int rows, int cols, int stride)
{
return Map<const Matrix<T,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> >(data, rows, cols, OuterStride<>(stride));
}
template<typename T> template<typename T>
Map<Matrix<T,Dynamic,1>, 0, InnerStride<Dynamic> > make_vector(T* data, int size, int incr) Map<Matrix<T,Dynamic,1>, 0, InnerStride<Dynamic> > make_vector(T* data, int size, int incr)
{ {
return Map<Matrix<T,Dynamic,1>, 0, InnerStride<Dynamic> >(data, size, InnerStride<Dynamic>(incr)); return Map<Matrix<T,Dynamic,1>, 0, InnerStride<Dynamic> >(data, size, InnerStride<Dynamic>(incr));
} }
template<typename T>
Map<const Matrix<T,Dynamic,1>, 0, InnerStride<Dynamic> > make_vector(const T* data, int size, int incr)
{
return Map<const Matrix<T,Dynamic,1>, 0, InnerStride<Dynamic> >(data, size, InnerStride<Dynamic>(incr));
}
template<typename T> template<typename T>
Map<Matrix<T,Dynamic,1> > make_vector(T* data, int size) Map<Matrix<T,Dynamic,1> > make_vector(T* data, int size)
{ {
return Map<Matrix<T,Dynamic,1> >(data, size); return Map<Matrix<T,Dynamic,1> >(data, size);
} }
template<typename T>
Map<const Matrix<T,Dynamic,1> > make_vector(const T* data, int size)
{
return Map<const Matrix<T,Dynamic,1> >(data, size);
}
template<typename T> template<typename T>
T* get_compact_vector(T* x, int n, int incx) T* get_compact_vector(T* x, int n, int incx)
{ {
if(incx==1) if(incx==1)
return x; return x;
T* ret = new Scalar[n]; typename Eigen::internal::remove_const<T>::type* ret = new Scalar[n];
if(incx<0) make_vector(ret,n) = make_vector(x,n,-incx).reverse(); if(incx<0) make_vector(ret,n) = make_vector(x,n,-incx).reverse();
else make_vector(ret,n) = make_vector(x,n, incx); else make_vector(ret,n) = make_vector(x,n, incx);
return ret; return ret;

View File

@ -9,11 +9,11 @@
#include "common.h" #include "common.h"
int EIGEN_BLAS_FUNC(axpy)(int *n, RealScalar *palpha, RealScalar *px, int *incx, RealScalar *py, int *incy) int EIGEN_BLAS_FUNC(axpy)(const int *n, const RealScalar *palpha, const RealScalar *px, const int *incx, RealScalar *py, const int *incy)
{ {
Scalar* x = reinterpret_cast<Scalar*>(px); const Scalar* x = reinterpret_cast<const Scalar*>(px);
Scalar* y = reinterpret_cast<Scalar*>(py); Scalar* y = reinterpret_cast<Scalar*>(py);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha); Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
if(*n<=0) return 0; if(*n<=0) return 0;

View File

@ -16,7 +16,8 @@
* where alpha and beta are scalars, x and y are n element vectors and * where alpha and beta are scalars, x and y are n element vectors and
* A is an n by n hermitian matrix. * A is an n by n hermitian matrix.
*/ */
int EIGEN_BLAS_FUNC(hemv)(char *uplo, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *px, int *incx, RealScalar *pbeta, RealScalar *py, int *incy) int EIGEN_BLAS_FUNC(hemv)(const char *uplo, const int *n, const RealScalar *palpha, const RealScalar *pa, const int *lda,
const RealScalar *px, const int *incx, const RealScalar *pbeta, RealScalar *py, const int *incy)
{ {
typedef void (*functype)(int, const Scalar*, int, const Scalar*, Scalar*, Scalar); typedef void (*functype)(int, const Scalar*, int, const Scalar*, Scalar*, Scalar);
static const functype func[2] = { static const functype func[2] = {
@ -26,11 +27,11 @@ int EIGEN_BLAS_FUNC(hemv)(char *uplo, int *n, RealScalar *palpha, RealScalar *pa
(internal::selfadjoint_matrix_vector_product<Scalar,int,ColMajor,Lower,false,false>::run), (internal::selfadjoint_matrix_vector_product<Scalar,int,ColMajor,Lower,false,false>::run),
}; };
Scalar* a = reinterpret_cast<Scalar*>(pa); const Scalar* a = reinterpret_cast<const Scalar*>(pa);
Scalar* x = reinterpret_cast<Scalar*>(px); const Scalar* x = reinterpret_cast<const Scalar*>(px);
Scalar* y = reinterpret_cast<Scalar*>(py); Scalar* y = reinterpret_cast<Scalar*>(py);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha); Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
Scalar beta = *reinterpret_cast<Scalar*>(pbeta); Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
// check arguments // check arguments
int info = 0; int info = 0;
@ -45,7 +46,7 @@ int EIGEN_BLAS_FUNC(hemv)(char *uplo, int *n, RealScalar *palpha, RealScalar *pa
if(*n==0) if(*n==0)
return 1; return 1;
Scalar* actual_x = get_compact_vector(x,*n,*incx); const Scalar* actual_x = get_compact_vector(x,*n,*incx);
Scalar* actual_y = get_compact_vector(y,*n,*incy); Scalar* actual_y = get_compact_vector(y,*n,*incy);
if(beta!=Scalar(1)) if(beta!=Scalar(1))

View File

@ -23,7 +23,8 @@ struct general_matrix_vector_product_wrapper
} }
}; };
int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *incb, RealScalar *pbeta, RealScalar *pc, int *incc) int EIGEN_BLAS_FUNC(gemv)(const char *opa, const int *m, const int *n, const RealScalar *palpha,
const RealScalar *pa, const int *lda, const RealScalar *pb, const int *incb, const RealScalar *pbeta, RealScalar *pc, const int *incc)
{ {
typedef void (*functype)(int, int, const Scalar *, int, const Scalar *, int , Scalar *, int, Scalar); typedef void (*functype)(int, int, const Scalar *, int, const Scalar *, int , Scalar *, int, Scalar);
static const functype func[4] = { static const functype func[4] = {
@ -36,11 +37,11 @@ int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealSca
0 0
}; };
Scalar* a = reinterpret_cast<Scalar*>(pa); const Scalar* a = reinterpret_cast<const Scalar*>(pa);
Scalar* b = reinterpret_cast<Scalar*>(pb); const Scalar* b = reinterpret_cast<const Scalar*>(pb);
Scalar* c = reinterpret_cast<Scalar*>(pc); Scalar* c = reinterpret_cast<Scalar*>(pc);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha); Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
Scalar beta = *reinterpret_cast<Scalar*>(pbeta); Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
// check arguments // check arguments
int info = 0; int info = 0;
@ -62,7 +63,7 @@ int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealSca
if(code!=NOTR) if(code!=NOTR)
std::swap(actual_m,actual_n); std::swap(actual_m,actual_n);
Scalar* actual_b = get_compact_vector(b,actual_n,*incb); const Scalar* actual_b = get_compact_vector(b,actual_n,*incb);
Scalar* actual_c = get_compact_vector(c,actual_m,*incc); Scalar* actual_c = get_compact_vector(c,actual_m,*incc);
if(beta!=Scalar(1)) if(beta!=Scalar(1))
@ -82,7 +83,7 @@ int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealSca
return 1; return 1;
} }
int EIGEN_BLAS_FUNC(trsv)(char *uplo, char *opa, char *diag, int *n, RealScalar *pa, int *lda, RealScalar *pb, int *incb) int EIGEN_BLAS_FUNC(trsv)(const char *uplo, const char *opa, const char *diag, const int *n, const RealScalar *pa, const int *lda, RealScalar *pb, const int *incb)
{ {
typedef void (*functype)(int, const Scalar *, int, Scalar *); typedef void (*functype)(int, const Scalar *, int, Scalar *);
static const functype func[16] = { static const functype func[16] = {
@ -116,7 +117,7 @@ int EIGEN_BLAS_FUNC(trsv)(char *uplo, char *opa, char *diag, int *n, RealScalar
0 0
}; };
Scalar* a = reinterpret_cast<Scalar*>(pa); const Scalar* a = reinterpret_cast<const Scalar*>(pa);
Scalar* b = reinterpret_cast<Scalar*>(pb); Scalar* b = reinterpret_cast<Scalar*>(pb);
int info = 0; int info = 0;
@ -141,7 +142,7 @@ int EIGEN_BLAS_FUNC(trsv)(char *uplo, char *opa, char *diag, int *n, RealScalar
int EIGEN_BLAS_FUNC(trmv)(char *uplo, char *opa, char *diag, int *n, RealScalar *pa, int *lda, RealScalar *pb, int *incb) int EIGEN_BLAS_FUNC(trmv)(const char *uplo, const char *opa, const char *diag, const int *n, const RealScalar *pa, const int *lda, RealScalar *pb, const int *incb)
{ {
typedef void (*functype)(int, int, const Scalar *, int, const Scalar *, int, Scalar *, int, const Scalar&); typedef void (*functype)(int, int, const Scalar *, int, const Scalar *, int, Scalar *, int, const Scalar&);
static const functype func[16] = { static const functype func[16] = {
@ -175,7 +176,7 @@ int EIGEN_BLAS_FUNC(trmv)(char *uplo, char *opa, char *diag, int *n, RealScalar
0 0
}; };
Scalar* a = reinterpret_cast<Scalar*>(pa); const Scalar* a = reinterpret_cast<const Scalar*>(pa);
Scalar* b = reinterpret_cast<Scalar*>(pb); Scalar* b = reinterpret_cast<Scalar*>(pb);
int info = 0; int info = 0;
@ -217,11 +218,11 @@ int EIGEN_BLAS_FUNC(trmv)(char *uplo, char *opa, char *diag, int *n, RealScalar
int EIGEN_BLAS_FUNC(gbmv)(char *trans, int *m, int *n, int *kl, int *ku, RealScalar *palpha, RealScalar *pa, int *lda, int EIGEN_BLAS_FUNC(gbmv)(char *trans, int *m, int *n, int *kl, int *ku, RealScalar *palpha, RealScalar *pa, int *lda,
RealScalar *px, int *incx, RealScalar *pbeta, RealScalar *py, int *incy) RealScalar *px, int *incx, RealScalar *pbeta, RealScalar *py, int *incy)
{ {
Scalar* a = reinterpret_cast<Scalar*>(pa); const Scalar* a = reinterpret_cast<const Scalar*>(pa);
Scalar* x = reinterpret_cast<Scalar*>(px); const Scalar* x = reinterpret_cast<const Scalar*>(px);
Scalar* y = reinterpret_cast<Scalar*>(py); Scalar* y = reinterpret_cast<Scalar*>(py);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha); Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
Scalar beta = *reinterpret_cast<Scalar*>(pbeta); Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
int coeff_rows = *kl+*ku+1; int coeff_rows = *kl+*ku+1;
int info = 0; int info = 0;
@ -244,7 +245,7 @@ int EIGEN_BLAS_FUNC(gbmv)(char *trans, int *m, int *n, int *kl, int *ku, RealSca
if(OP(*trans)!=NOTR) if(OP(*trans)!=NOTR)
std::swap(actual_m,actual_n); std::swap(actual_m,actual_n);
Scalar* actual_x = get_compact_vector(x,actual_n,*incx); const Scalar* actual_x = get_compact_vector(x,actual_n,*incx);
Scalar* actual_y = get_compact_vector(y,actual_m,*incy); Scalar* actual_y = get_compact_vector(y,actual_m,*incy);
if(beta!=Scalar(1)) if(beta!=Scalar(1))
@ -253,7 +254,7 @@ int EIGEN_BLAS_FUNC(gbmv)(char *trans, int *m, int *n, int *kl, int *ku, RealSca
else make_vector(actual_y, actual_m) *= beta; else make_vector(actual_y, actual_m) *= beta;
} }
MatrixType mat_coeffs(a,coeff_rows,*n,*lda); ConstMatrixType mat_coeffs(a,coeff_rows,*n,*lda);
int nb = std::min(*n,(*m)+(*ku)); int nb = std::min(*n,(*m)+(*ku));
for(int j=0; j<nb; ++j) for(int j=0; j<nb; ++j)

View File

@ -10,7 +10,8 @@
#include "common.h" #include "common.h"
// y = alpha*A*x + beta*y // y = alpha*A*x + beta*y
int EIGEN_BLAS_FUNC(symv) (char *uplo, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *px, int *incx, RealScalar *pbeta, RealScalar *py, int *incy) int EIGEN_BLAS_FUNC(symv) (const char *uplo, const int *n, const RealScalar *palpha, const RealScalar *pa, const int *lda,
const RealScalar *px, const int *incx, const RealScalar *pbeta, RealScalar *py, const int *incy)
{ {
typedef void (*functype)(int, const Scalar*, int, const Scalar*, Scalar*, Scalar); typedef void (*functype)(int, const Scalar*, int, const Scalar*, Scalar*, Scalar);
static const functype func[2] = { static const functype func[2] = {
@ -20,11 +21,11 @@ int EIGEN_BLAS_FUNC(symv) (char *uplo, int *n, RealScalar *palpha, RealScalar *p
(internal::selfadjoint_matrix_vector_product<Scalar,int,ColMajor,Lower,false,false>::run), (internal::selfadjoint_matrix_vector_product<Scalar,int,ColMajor,Lower,false,false>::run),
}; };
Scalar* a = reinterpret_cast<Scalar*>(pa); const Scalar* a = reinterpret_cast<const Scalar*>(pa);
Scalar* x = reinterpret_cast<Scalar*>(px); const Scalar* x = reinterpret_cast<const Scalar*>(px);
Scalar* y = reinterpret_cast<Scalar*>(py); Scalar* y = reinterpret_cast<Scalar*>(py);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha); Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
Scalar beta = *reinterpret_cast<Scalar*>(pbeta); Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
// check arguments // check arguments
int info = 0; int info = 0;
@ -39,7 +40,7 @@ int EIGEN_BLAS_FUNC(symv) (char *uplo, int *n, RealScalar *palpha, RealScalar *p
if(*n==0) if(*n==0)
return 0; return 0;
Scalar* actual_x = get_compact_vector(x,*n,*incx); const Scalar* actual_x = get_compact_vector(x,*n,*incx);
Scalar* actual_y = get_compact_vector(y,*n,*incy); Scalar* actual_y = get_compact_vector(y,*n,*incy);
if(beta!=Scalar(1)) if(beta!=Scalar(1))
@ -61,7 +62,7 @@ int EIGEN_BLAS_FUNC(symv) (char *uplo, int *n, RealScalar *palpha, RealScalar *p
} }
// C := alpha*x*x' + C // C := alpha*x*x' + C
int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *px, int *incx, RealScalar *pc, int *ldc) int EIGEN_BLAS_FUNC(syr)(const char *uplo, const int *n, const RealScalar *palpha, const RealScalar *px, const int *incx, RealScalar *pc, const int *ldc)
{ {
typedef void (*functype)(int, Scalar*, int, const Scalar*, const Scalar*, const Scalar&); typedef void (*functype)(int, Scalar*, int, const Scalar*, const Scalar*, const Scalar&);
@ -72,9 +73,9 @@ int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *px,
(selfadjoint_rank1_update<Scalar,int,ColMajor,Lower,false,Conj>::run), (selfadjoint_rank1_update<Scalar,int,ColMajor,Lower,false,Conj>::run),
}; };
Scalar* x = reinterpret_cast<Scalar*>(px); const Scalar* x = reinterpret_cast<const Scalar*>(px);
Scalar* c = reinterpret_cast<Scalar*>(pc); Scalar* c = reinterpret_cast<Scalar*>(pc);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha); Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
int info = 0; int info = 0;
if(UPLO(*uplo)==INVALID) info = 1; if(UPLO(*uplo)==INVALID) info = 1;
@ -87,7 +88,7 @@ int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *px,
if(*n==0 || alpha==Scalar(0)) return 1; if(*n==0 || alpha==Scalar(0)) return 1;
// if the increment is not 1, let's copy it to a temporary vector to enable vectorization // if the increment is not 1, let's copy it to a temporary vector to enable vectorization
Scalar* x_cpy = get_compact_vector(x,*n,*incx); const Scalar* x_cpy = get_compact_vector(x,*n,*incx);
int code = UPLO(*uplo); int code = UPLO(*uplo);
if(code>=2 || func[code]==0) if(code>=2 || func[code]==0)
@ -101,7 +102,7 @@ int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *px,
} }
// C := alpha*x*y' + alpha*y*x' + C // C := alpha*x*y' + alpha*y*x' + C
int EIGEN_BLAS_FUNC(syr2)(char *uplo, int *n, RealScalar *palpha, RealScalar *px, int *incx, RealScalar *py, int *incy, RealScalar *pc, int *ldc) int EIGEN_BLAS_FUNC(syr2)(const char *uplo, const int *n, const RealScalar *palpha, const RealScalar *px, const int *incx, const RealScalar *py, const int *incy, RealScalar *pc, const int *ldc)
{ {
typedef void (*functype)(int, Scalar*, int, const Scalar*, const Scalar*, Scalar); typedef void (*functype)(int, Scalar*, int, const Scalar*, const Scalar*, Scalar);
static const functype func[2] = { static const functype func[2] = {
@ -111,10 +112,10 @@ int EIGEN_BLAS_FUNC(syr2)(char *uplo, int *n, RealScalar *palpha, RealScalar *px
(internal::rank2_update_selector<Scalar,int,Lower>::run), (internal::rank2_update_selector<Scalar,int,Lower>::run),
}; };
Scalar* x = reinterpret_cast<Scalar*>(px); const Scalar* x = reinterpret_cast<const Scalar*>(px);
Scalar* y = reinterpret_cast<Scalar*>(py); const Scalar* y = reinterpret_cast<const Scalar*>(py);
Scalar* c = reinterpret_cast<Scalar*>(pc); Scalar* c = reinterpret_cast<Scalar*>(pc);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha); Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
int info = 0; int info = 0;
if(UPLO(*uplo)==INVALID) info = 1; if(UPLO(*uplo)==INVALID) info = 1;
@ -128,8 +129,8 @@ int EIGEN_BLAS_FUNC(syr2)(char *uplo, int *n, RealScalar *palpha, RealScalar *px
if(alpha==Scalar(0)) if(alpha==Scalar(0))
return 1; return 1;
Scalar* x_cpy = get_compact_vector(x,*n,*incx); const Scalar* x_cpy = get_compact_vector(x,*n,*incx);
Scalar* y_cpy = get_compact_vector(y,*n,*incy); const Scalar* y_cpy = get_compact_vector(y,*n,*incy);
int code = UPLO(*uplo); int code = UPLO(*uplo);
if(code>=2 || func[code]==0) if(code>=2 || func[code]==0)

View File

@ -9,7 +9,8 @@
#include <iostream> #include <iostream>
#include "common.h" #include "common.h"
int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc) int EIGEN_BLAS_FUNC(gemm)(const char *opa, const char *opb, const int *m, const int *n, const int *k, const RealScalar *palpha,
const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
{ {
// std::cerr << "in gemm " << *opa << " " << *opb << " " << *m << " " << *n << " " << *k << " " << *lda << " " << *ldb << " " << *ldc << " " << *palpha << " " << *pbeta << "\n"; // std::cerr << "in gemm " << *opa << " " << *opb << " " << *m << " " << *n << " " << *k << " " << *lda << " " << *ldb << " " << *ldc << " " << *palpha << " " << *pbeta << "\n";
typedef void (*functype)(DenseIndex, DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, Scalar, internal::level3_blocking<Scalar,Scalar>&, Eigen::internal::GemmParallelInfo<DenseIndex>*); typedef void (*functype)(DenseIndex, DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, Scalar, internal::level3_blocking<Scalar,Scalar>&, Eigen::internal::GemmParallelInfo<DenseIndex>*);
@ -37,11 +38,11 @@ int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScal
0 0
}; };
Scalar* a = reinterpret_cast<Scalar*>(pa); const Scalar* a = reinterpret_cast<const Scalar*>(pa);
Scalar* b = reinterpret_cast<Scalar*>(pb); const Scalar* b = reinterpret_cast<const Scalar*>(pb);
Scalar* c = reinterpret_cast<Scalar*>(pc); Scalar* c = reinterpret_cast<Scalar*>(pc);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha); Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
Scalar beta = *reinterpret_cast<Scalar*>(pbeta); Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
int info = 0; int info = 0;
if(OP(*opa)==INVALID) info = 1; if(OP(*opa)==INVALID) info = 1;
@ -74,7 +75,8 @@ int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScal
return 0; return 0;
} }
int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb) int EIGEN_BLAS_FUNC(trsm)(const char *side, const char *uplo, const char *opa, const char *diag, const int *m, const int *n,
const RealScalar *palpha, const RealScalar *pa, const int *lda, RealScalar *pb, const int *ldb)
{ {
// std::cerr << "in trsm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << "," << *n << " " << *palpha << " " << *lda << " " << *ldb<< "\n"; // std::cerr << "in trsm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << "," << *n << " " << *palpha << " " << *lda << " " << *ldb<< "\n";
typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, internal::level3_blocking<Scalar,Scalar>&); typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, internal::level3_blocking<Scalar,Scalar>&);
@ -137,9 +139,9 @@ int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m,
0 0
}; };
Scalar* a = reinterpret_cast<Scalar*>(pa); const Scalar* a = reinterpret_cast<const Scalar*>(pa);
Scalar* b = reinterpret_cast<Scalar*>(pb); Scalar* b = reinterpret_cast<Scalar*>(pb);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha); Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
int info = 0; int info = 0;
if(SIDE(*side)==INVALID) info = 1; if(SIDE(*side)==INVALID) info = 1;
@ -178,7 +180,8 @@ int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m,
// b = alpha*op(a)*b for side = 'L'or'l' // b = alpha*op(a)*b for side = 'L'or'l'
// b = alpha*b*op(a) for side = 'R'or'r' // b = alpha*b*op(a) for side = 'R'or'r'
int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb) int EIGEN_BLAS_FUNC(trmm)(const char *side, const char *uplo, const char *opa, const char *diag, const int *m, const int *n,
const RealScalar *palpha, const RealScalar *pa, const int *lda, RealScalar *pb, const int *ldb)
{ {
// std::cerr << "in trmm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << " " << *n << " " << *lda << " " << *ldb << " " << *palpha << "\n"; // std::cerr << "in trmm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << " " << *n << " " << *lda << " " << *ldb << " " << *palpha << "\n";
typedef void (*functype)(DenseIndex, DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, const Scalar&, internal::level3_blocking<Scalar,Scalar>&); typedef void (*functype)(DenseIndex, DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, const Scalar&, internal::level3_blocking<Scalar,Scalar>&);
@ -241,9 +244,9 @@ int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m,
0 0
}; };
Scalar* a = reinterpret_cast<Scalar*>(pa); const Scalar* a = reinterpret_cast<const Scalar*>(pa);
Scalar* b = reinterpret_cast<Scalar*>(pb); Scalar* b = reinterpret_cast<Scalar*>(pb);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha); Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
int info = 0; int info = 0;
if(SIDE(*side)==INVALID) info = 1; if(SIDE(*side)==INVALID) info = 1;
@ -281,14 +284,15 @@ int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m,
// c = alpha*a*b + beta*c for side = 'L'or'l' // c = alpha*a*b + beta*c for side = 'L'or'l'
// c = alpha*b*a + beta*c for side = 'R'or'r // c = alpha*b*a + beta*c for side = 'R'or'r
int EIGEN_BLAS_FUNC(symm)(char *side, char *uplo, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc) int EIGEN_BLAS_FUNC(symm)(const char *side, const char *uplo, const int *m, const int *n, const RealScalar *palpha,
const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
{ {
// std::cerr << "in symm " << *side << " " << *uplo << " " << *m << "x" << *n << " lda:" << *lda << " ldb:" << *ldb << " ldc:" << *ldc << " alpha:" << *palpha << " beta:" << *pbeta << "\n"; // std::cerr << "in symm " << *side << " " << *uplo << " " << *m << "x" << *n << " lda:" << *lda << " ldb:" << *ldb << " ldc:" << *ldc << " alpha:" << *palpha << " beta:" << *pbeta << "\n";
Scalar* a = reinterpret_cast<Scalar*>(pa); const Scalar* a = reinterpret_cast<const Scalar*>(pa);
Scalar* b = reinterpret_cast<Scalar*>(pb); const Scalar* b = reinterpret_cast<const Scalar*>(pb);
Scalar* c = reinterpret_cast<Scalar*>(pc); Scalar* c = reinterpret_cast<Scalar*>(pc);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha); Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
Scalar beta = *reinterpret_cast<Scalar*>(pbeta); Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
int info = 0; int info = 0;
if(SIDE(*side)==INVALID) info = 1; if(SIDE(*side)==INVALID) info = 1;
@ -350,7 +354,8 @@ int EIGEN_BLAS_FUNC(symm)(char *side, char *uplo, int *m, int *n, RealScalar *pa
// c = alpha*a*a' + beta*c for op = 'N'or'n' // c = alpha*a*a' + beta*c for op = 'N'or'n'
// c = alpha*a'*a + beta*c for op = 'T'or't','C'or'c' // c = alpha*a'*a + beta*c for op = 'T'or't','C'or'c'
int EIGEN_BLAS_FUNC(syrk)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pbeta, RealScalar *pc, int *ldc) int EIGEN_BLAS_FUNC(syrk)(const char *uplo, const char *op, const int *n, const int *k,
const RealScalar *palpha, const RealScalar *pa, const int *lda, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
{ {
// std::cerr << "in syrk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " " << *pbeta << " " << *ldc << "\n"; // std::cerr << "in syrk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " " << *pbeta << " " << *ldc << "\n";
#if !ISCOMPLEX #if !ISCOMPLEX
@ -373,10 +378,10 @@ int EIGEN_BLAS_FUNC(syrk)(char *uplo, char *op, int *n, int *k, RealScalar *palp
}; };
#endif #endif
Scalar* a = reinterpret_cast<Scalar*>(pa); const Scalar* a = reinterpret_cast<const Scalar*>(pa);
Scalar* c = reinterpret_cast<Scalar*>(pc); Scalar* c = reinterpret_cast<Scalar*>(pc);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha); Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
Scalar beta = *reinterpret_cast<Scalar*>(pbeta); Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
int info = 0; int info = 0;
if(UPLO(*uplo)==INVALID) info = 1; if(UPLO(*uplo)==INVALID) info = 1;
@ -429,13 +434,14 @@ int EIGEN_BLAS_FUNC(syrk)(char *uplo, char *op, int *n, int *k, RealScalar *palp
// c = alpha*a*b' + alpha*b*a' + beta*c for op = 'N'or'n' // c = alpha*a*b' + alpha*b*a' + beta*c for op = 'N'or'n'
// c = alpha*a'*b + alpha*b'*a + beta*c for op = 'T'or't' // c = alpha*a'*b + alpha*b'*a + beta*c for op = 'T'or't'
int EIGEN_BLAS_FUNC(syr2k)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc) int EIGEN_BLAS_FUNC(syr2k)(const char *uplo, const char *op, const int *n, const int *k, const RealScalar *palpha,
const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
{ {
Scalar* a = reinterpret_cast<Scalar*>(pa); const Scalar* a = reinterpret_cast<const Scalar*>(pa);
Scalar* b = reinterpret_cast<Scalar*>(pb); const Scalar* b = reinterpret_cast<const Scalar*>(pb);
Scalar* c = reinterpret_cast<Scalar*>(pc); Scalar* c = reinterpret_cast<Scalar*>(pc);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha); Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
Scalar beta = *reinterpret_cast<Scalar*>(pbeta); Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
// std::cerr << "in syr2k " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " << *ldb << " " << beta << " " << *ldc << "\n"; // std::cerr << "in syr2k " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " << *ldb << " " << beta << " " << *ldc << "\n";
@ -496,13 +502,14 @@ int EIGEN_BLAS_FUNC(syr2k)(char *uplo, char *op, int *n, int *k, RealScalar *pal
// c = alpha*a*b + beta*c for side = 'L'or'l' // c = alpha*a*b + beta*c for side = 'L'or'l'
// c = alpha*b*a + beta*c for side = 'R'or'r // c = alpha*b*a + beta*c for side = 'R'or'r
int EIGEN_BLAS_FUNC(hemm)(char *side, char *uplo, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc) int EIGEN_BLAS_FUNC(hemm)(const char *side, const char *uplo, const int *m, const int *n, const RealScalar *palpha,
const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
{ {
Scalar* a = reinterpret_cast<Scalar*>(pa); const Scalar* a = reinterpret_cast<const Scalar*>(pa);
Scalar* b = reinterpret_cast<Scalar*>(pb); const Scalar* b = reinterpret_cast<const Scalar*>(pb);
Scalar* c = reinterpret_cast<Scalar*>(pc); Scalar* c = reinterpret_cast<Scalar*>(pc);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha); Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
Scalar beta = *reinterpret_cast<Scalar*>(pbeta); Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
// std::cerr << "in hemm " << *side << " " << *uplo << " " << *m << " " << *n << " " << alpha << " " << *lda << " " << beta << " " << *ldc << "\n"; // std::cerr << "in hemm " << *side << " " << *uplo << " " << *m << " " << *n << " " << alpha << " " << *lda << " " << beta << " " << *ldc << "\n";
@ -554,7 +561,8 @@ int EIGEN_BLAS_FUNC(hemm)(char *side, char *uplo, int *m, int *n, RealScalar *pa
// c = alpha*a*conj(a') + beta*c for op = 'N'or'n' // c = alpha*a*conj(a') + beta*c for op = 'N'or'n'
// c = alpha*conj(a')*a + beta*c for op = 'C'or'c' // c = alpha*conj(a')*a + beta*c for op = 'C'or'c'
int EIGEN_BLAS_FUNC(herk)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pbeta, RealScalar *pc, int *ldc) int EIGEN_BLAS_FUNC(herk)(const char *uplo, const char *op, const int *n, const int *k,
const RealScalar *palpha, const RealScalar *pa, const int *lda, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
{ {
// std::cerr << "in herk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " " << *pbeta << " " << *ldc << "\n"; // std::cerr << "in herk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " " << *pbeta << " " << *ldc << "\n";
@ -574,7 +582,7 @@ int EIGEN_BLAS_FUNC(herk)(char *uplo, char *op, int *n, int *k, RealScalar *palp
0 0
}; };
Scalar* a = reinterpret_cast<Scalar*>(pa); const Scalar* a = reinterpret_cast<const Scalar*>(pa);
Scalar* c = reinterpret_cast<Scalar*>(pc); Scalar* c = reinterpret_cast<Scalar*>(pc);
RealScalar alpha = *palpha; RealScalar alpha = *palpha;
RealScalar beta = *pbeta; RealScalar beta = *pbeta;
@ -620,12 +628,13 @@ int EIGEN_BLAS_FUNC(herk)(char *uplo, char *op, int *n, int *k, RealScalar *palp
// c = alpha*a*conj(b') + conj(alpha)*b*conj(a') + beta*c, for op = 'N'or'n' // c = alpha*a*conj(b') + conj(alpha)*b*conj(a') + beta*c, for op = 'N'or'n'
// c = alpha*conj(a')*b + conj(alpha)*conj(b')*a + beta*c, for op = 'C'or'c' // c = alpha*conj(a')*b + conj(alpha)*conj(b')*a + beta*c, for op = 'C'or'c'
int EIGEN_BLAS_FUNC(her2k)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc) int EIGEN_BLAS_FUNC(her2k)(const char *uplo, const char *op, const int *n, const int *k,
const RealScalar *palpha, const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
{ {
Scalar* a = reinterpret_cast<Scalar*>(pa); const Scalar* a = reinterpret_cast<const Scalar*>(pa);
Scalar* b = reinterpret_cast<Scalar*>(pb); const Scalar* b = reinterpret_cast<const Scalar*>(pb);
Scalar* c = reinterpret_cast<Scalar*>(pc); Scalar* c = reinterpret_cast<Scalar*>(pc);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha); Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
RealScalar beta = *pbeta; RealScalar beta = *pbeta;
// std::cerr << "in her2k " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " << *ldb << " " << beta << " " << *ldc << "\n"; // std::cerr << "in her2k " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " << *ldb << " " << beta << " " << *ldc << "\n";

View File

@ -37,10 +37,10 @@ Here is another example reshaping a 2x6 matrix to a 6x2 one:
\section TutorialSlicing Slicing \section TutorialSlicing Slicing
Slicing consists in taking a set of rows, or columns, or elements, uniformly spaced within a matrix. Slicing consists in taking a set of rows, columns, or elements, uniformly spaced within a matrix.
Again, the class Map allows to easily mimic this feature. Again, the class Map allows to easily mimic this feature.
For instance, one can take skip every P elements in a vector: For instance, one can skip every P elements in a vector:
<table class="example"> <table class="example">
<tr><th>Example:</th><th>Output:</th></tr> <tr><th>Example:</th><th>Output:</th></tr>
<tr><td> <tr><td>

View File

@ -55,7 +55,7 @@ Operations on other scalar types or mixing reals and complexes will continue to
In addition you can choose which parts will be substituted by defining one or multiple of the following macros: In addition you can choose which parts will be substituted by defining one or multiple of the following macros:
<table class="manual"> <table class="manual">
<tr><td>\c EIGEN_USE_BLAS </td><td>Enables the use of external BLAS level 2 and 3 routines (currently works with Intel MKL only)</td></tr> <tr><td>\c EIGEN_USE_BLAS </td><td>Enables the use of external BLAS level 2 and 3 routines (compatible with any F77 BLAS interface, not only Intel MKL)</td></tr>
<tr class="alt"><td>\c EIGEN_USE_LAPACKE </td><td>Enables the use of external Lapack routines via the <a href="http://www.netlib.org/lapack/lapacke.html">Intel Lapacke</a> C interface to Lapack (currently works with Intel MKL only)</td></tr> <tr class="alt"><td>\c EIGEN_USE_LAPACKE </td><td>Enables the use of external Lapack routines via the <a href="http://www.netlib.org/lapack/lapacke.html">Intel Lapacke</a> C interface to Lapack (currently works with Intel MKL only)</td></tr>
<tr><td>\c EIGEN_USE_LAPACKE_STRICT </td><td>Same as \c EIGEN_USE_LAPACKE but algorithm of lower robustness are disabled. This currently concerns only JacobiSVD which otherwise would be replaced by \c gesvd that is less robust than Jacobi rotations.</td></tr> <tr><td>\c EIGEN_USE_LAPACKE_STRICT </td><td>Same as \c EIGEN_USE_LAPACKE but algorithm of lower robustness are disabled. This currently concerns only JacobiSVD which otherwise would be replaced by \c gesvd that is less robust than Jacobi rotations.</td></tr>
<tr class="alt"><td>\c EIGEN_USE_MKL_VML </td><td>Enables the use of Intel VML (vector operations)</td></tr> <tr class="alt"><td>\c EIGEN_USE_MKL_VML </td><td>Enables the use of Intel VML (vector operations)</td></tr>

View File

@ -11,6 +11,7 @@
#define EIGEN_LAPACK_COMMON_H #define EIGEN_LAPACK_COMMON_H
#include "../blas/common.h" #include "../blas/common.h"
#include "../Eigen/src/misc/lapack.h"
#define EIGEN_LAPACK_FUNC(FUNC,ARGLIST) \ #define EIGEN_LAPACK_FUNC(FUNC,ARGLIST) \
extern "C" { int EIGEN_BLAS_FUNC(FUNC) ARGLIST; } \ extern "C" { int EIGEN_BLAS_FUNC(FUNC) ARGLIST; } \

View File

@ -331,11 +331,13 @@ template<typename ArrayType> void array_real(const ArrayType& m)
VERIFY_IS_APPROX(numext::zeta(Scalar(3), Scalar(-2.5)), RealScalar(0.054102025820864097)); VERIFY_IS_APPROX(numext::zeta(Scalar(3), Scalar(-2.5)), RealScalar(0.054102025820864097));
VERIFY_IS_EQUAL(numext::zeta(Scalar(1), Scalar(1.2345)), // The second scalar does not matter VERIFY_IS_EQUAL(numext::zeta(Scalar(1), Scalar(1.2345)), // The second scalar does not matter
std::numeric_limits<RealScalar>::infinity()); std::numeric_limits<RealScalar>::infinity());
VERIFY((numext::isnan)(numext::zeta(Scalar(0.9), Scalar(1.2345)))); // The second scalar does not matter
// Check the polygamma against scipy.special.polygamma examples // Check the polygamma against scipy.special.polygamma examples
VERIFY_IS_APPROX(numext::polygamma(Scalar(1), Scalar(2)), RealScalar(0.644934066848)); VERIFY_IS_APPROX(numext::polygamma(Scalar(1), Scalar(2)), RealScalar(0.644934066848));
VERIFY_IS_APPROX(numext::polygamma(Scalar(1), Scalar(3)), RealScalar(0.394934066848)); VERIFY_IS_APPROX(numext::polygamma(Scalar(1), Scalar(3)), RealScalar(0.394934066848));
VERIFY_IS_APPROX(numext::polygamma(Scalar(1), Scalar(25.5)), RealScalar(0.0399946696496)); VERIFY_IS_APPROX(numext::polygamma(Scalar(1), Scalar(25.5)), RealScalar(0.0399946696496));
VERIFY((numext::isnan)(numext::polygamma(Scalar(1.5), Scalar(1.2345)))); // The second scalar does not matter
// Check the polygamma function over a larger range of values // Check the polygamma function over a larger range of values
VERIFY_IS_APPROX(numext::polygamma(Scalar(17), Scalar(4.7)), RealScalar(293.334565435)); VERIFY_IS_APPROX(numext::polygamma(Scalar(17), Scalar(4.7)), RealScalar(293.334565435));

View File

@ -14,6 +14,9 @@
using std::sqrt; using std::sqrt;
// tolerance for chekcing number of iterations
#define LM_EVAL_COUNT_TOL 4/3
int fcn_chkder(const VectorXd &x, VectorXd &fvec, MatrixXd &fjac, int iflag) int fcn_chkder(const VectorXd &x, VectorXd &fvec, MatrixXd &fjac, int iflag)
{ {
/* subroutine fcn for chkder example. */ /* subroutine fcn for chkder example. */
@ -1023,7 +1026,8 @@ void testNistLanczos1(void)
VERIFY_IS_EQUAL(lm.njev, 72); VERIFY_IS_EQUAL(lm.njev, 72);
// check norm^2 // check norm^2
std::cout.precision(30); std::cout.precision(30);
VERIFY_IS_APPROX(lm.fvec.squaredNorm(), 1.4290986055242372e-25); // should be 1.4307867721E-25, but nist results are on 128-bit floats std::cout << lm.fvec.squaredNorm() << "\n";
VERIFY(lm.fvec.squaredNorm() <= 1.4307867721E-25);
// check x // check x
VERIFY_IS_APPROX(x[0], 9.5100000027E-02); VERIFY_IS_APPROX(x[0], 9.5100000027E-02);
VERIFY_IS_APPROX(x[1], 1.0000000001E+00); VERIFY_IS_APPROX(x[1], 1.0000000001E+00);
@ -1044,7 +1048,7 @@ void testNistLanczos1(void)
VERIFY_IS_EQUAL(lm.nfev, 9); VERIFY_IS_EQUAL(lm.nfev, 9);
VERIFY_IS_EQUAL(lm.njev, 8); VERIFY_IS_EQUAL(lm.njev, 8);
// check norm^2 // check norm^2
VERIFY_IS_APPROX(lm.fvec.squaredNorm(), 1.430571737783119393e-25); // should be 1.4307867721E-25, but nist results are on 128-bit floats VERIFY(lm.fvec.squaredNorm() <= 1.4307867721E-25);
// check x // check x
VERIFY_IS_APPROX(x[0], 9.5100000027E-02); VERIFY_IS_APPROX(x[0], 9.5100000027E-02);
VERIFY_IS_APPROX(x[1], 1.0000000001E+00); VERIFY_IS_APPROX(x[1], 1.0000000001E+00);
@ -1354,8 +1358,12 @@ void testNistMGH17(void)
// check return value // check return value
VERIFY_IS_EQUAL(info, 2); VERIFY_IS_EQUAL(info, 2);
VERIFY(lm.nfev < 650); // 602 ++g_test_level;
VERIFY(lm.njev < 600); // 545 VERIFY_IS_EQUAL(lm.nfev, 602); // 602
VERIFY_IS_EQUAL(lm.njev, 545); // 545
--g_test_level;
VERIFY(lm.nfev < 602 * LM_EVAL_COUNT_TOL);
VERIFY(lm.njev < 545 * LM_EVAL_COUNT_TOL);
/* /*
* Second try * Second try

View File

@ -23,6 +23,9 @@
using std::sqrt; using std::sqrt;
// tolerance for chekcing number of iterations
#define LM_EVAL_COUNT_TOL 4/3
struct lmder_functor : DenseFunctor<double> struct lmder_functor : DenseFunctor<double>
{ {
lmder_functor(void): DenseFunctor<double>(3,15) {} lmder_functor(void): DenseFunctor<double>(3,15) {}
@ -631,7 +634,7 @@ void testNistLanczos1(void)
VERIFY_IS_EQUAL(lm.nfev(), 79); VERIFY_IS_EQUAL(lm.nfev(), 79);
VERIFY_IS_EQUAL(lm.njev(), 72); VERIFY_IS_EQUAL(lm.njev(), 72);
// check norm^2 // check norm^2
// VERIFY_IS_APPROX(lm.fvec().squaredNorm(), 1.430899764097e-25); // should be 1.4307867721E-25, but nist results are on 128-bit floats VERIFY(lm.fvec().squaredNorm() <= 1.4307867721E-25);
// check x // check x
VERIFY_IS_APPROX(x[0], 9.5100000027E-02); VERIFY_IS_APPROX(x[0], 9.5100000027E-02);
VERIFY_IS_APPROX(x[1], 1.0000000001E+00); VERIFY_IS_APPROX(x[1], 1.0000000001E+00);
@ -652,7 +655,7 @@ void testNistLanczos1(void)
VERIFY_IS_EQUAL(lm.nfev(), 9); VERIFY_IS_EQUAL(lm.nfev(), 9);
VERIFY_IS_EQUAL(lm.njev(), 8); VERIFY_IS_EQUAL(lm.njev(), 8);
// check norm^2 // check norm^2
// VERIFY_IS_APPROX(lm.fvec().squaredNorm(), 1.428595533845e-25); // should be 1.4307867721E-25, but nist results are on 128-bit floats VERIFY(lm.fvec().squaredNorm() <= 1.4307867721E-25);
// check x // check x
VERIFY_IS_APPROX(x[0], 9.5100000027E-02); VERIFY_IS_APPROX(x[0], 9.5100000027E-02);
VERIFY_IS_APPROX(x[1], 1.0000000001E+00); VERIFY_IS_APPROX(x[1], 1.0000000001E+00);
@ -789,7 +792,8 @@ void testNistMGH10(void)
MGH10_functor functor; MGH10_functor functor;
LevenbergMarquardt<MGH10_functor> lm(functor); LevenbergMarquardt<MGH10_functor> lm(functor);
info = lm.minimize(x); info = lm.minimize(x);
VERIFY_IS_EQUAL(info, LevenbergMarquardtSpace::RelativeErrorTooSmall); VERIFY_IS_EQUAL(info, LevenbergMarquardtSpace::RelativeReductionTooSmall);
// was: VERIFY_IS_EQUAL(info, 1);
// check norm^2 // check norm^2
VERIFY_IS_APPROX(lm.fvec().squaredNorm(), 8.7945855171E+01); VERIFY_IS_APPROX(lm.fvec().squaredNorm(), 8.7945855171E+01);
@ -799,9 +803,13 @@ void testNistMGH10(void)
VERIFY_IS_APPROX(x[2], 3.4522363462E+02); VERIFY_IS_APPROX(x[2], 3.4522363462E+02);
// check return value // check return value
//VERIFY_IS_EQUAL(info, 1);
++g_test_level;
VERIFY_IS_EQUAL(lm.nfev(), 284 ); VERIFY_IS_EQUAL(lm.nfev(), 284 );
VERIFY_IS_EQUAL(lm.njev(), 249 ); VERIFY_IS_EQUAL(lm.njev(), 249 );
--g_test_level;
VERIFY(lm.nfev() < 284 * LM_EVAL_COUNT_TOL);
VERIFY(lm.njev() < 249 * LM_EVAL_COUNT_TOL);
/* /*
* Second try * Second try
@ -809,7 +817,10 @@ void testNistMGH10(void)
x<< 0.02, 4000., 250.; x<< 0.02, 4000., 250.;
// do the computation // do the computation
info = lm.minimize(x); info = lm.minimize(x);
++g_test_level;
VERIFY_IS_EQUAL(info, LevenbergMarquardtSpace::RelativeReductionTooSmall); VERIFY_IS_EQUAL(info, LevenbergMarquardtSpace::RelativeReductionTooSmall);
// was: VERIFY_IS_EQUAL(info, 1);
--g_test_level;
// check norm^2 // check norm^2
VERIFY_IS_APPROX(lm.fvec().squaredNorm(), 8.7945855171E+01); VERIFY_IS_APPROX(lm.fvec().squaredNorm(), 8.7945855171E+01);
@ -819,9 +830,12 @@ void testNistMGH10(void)
VERIFY_IS_APPROX(x[2], 3.4522363462E+02); VERIFY_IS_APPROX(x[2], 3.4522363462E+02);
// check return value // check return value
//VERIFY_IS_EQUAL(info, 1); ++g_test_level;
VERIFY_IS_EQUAL(lm.nfev(), 126); VERIFY_IS_EQUAL(lm.nfev(), 126);
VERIFY_IS_EQUAL(lm.njev(), 116); VERIFY_IS_EQUAL(lm.njev(), 116);
--g_test_level;
VERIFY(lm.nfev() < 126 * LM_EVAL_COUNT_TOL);
VERIFY(lm.njev() < 116 * LM_EVAL_COUNT_TOL);
} }
@ -896,8 +910,12 @@ void testNistBoxBOD(void)
// check return value // check return value
VERIFY_IS_EQUAL(info, 1); VERIFY_IS_EQUAL(info, 1);
++g_test_level;
VERIFY_IS_EQUAL(lm.nfev(), 16 ); VERIFY_IS_EQUAL(lm.nfev(), 16 );
VERIFY_IS_EQUAL(lm.njev(), 15 ); VERIFY_IS_EQUAL(lm.njev(), 15 );
--g_test_level;
VERIFY(lm.nfev() < 16 * LM_EVAL_COUNT_TOL);
VERIFY(lm.njev() < 15 * LM_EVAL_COUNT_TOL);
// check norm^2 // check norm^2
VERIFY_IS_APPROX(lm.fvec().squaredNorm(), 1.1680088766E+03); VERIFY_IS_APPROX(lm.fvec().squaredNorm(), 1.1680088766E+03);
// check x // check x