From 5da90fc8dd1570ebfbc0a9b6c058207b3bec15b6 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Fri, 8 Apr 2016 19:40:48 -0700 Subject: [PATCH 01/15] Use numext::abs instead of std::abs in scalar_fuzzy_default_impl to make it usable inside GPU kernels. --- Eigen/src/Core/MathFunctions.h | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h index dd19f080b..8e7dd2b73 100644 --- a/Eigen/src/Core/MathFunctions.h +++ b/Eigen/src/Core/MathFunctions.h @@ -1128,14 +1128,12 @@ struct scalar_fuzzy_default_impl template EIGEN_DEVICE_FUNC static inline bool isMuchSmallerThan(const Scalar& x, const OtherScalar& y, const RealScalar& prec) { - EIGEN_USING_STD_MATH(abs); - return abs(x) <= abs(y) * prec; + return numext::abs(x) <= numext::abs(y) * prec; } EIGEN_DEVICE_FUNC static inline bool isApprox(const Scalar& x, const Scalar& y, const RealScalar& prec) { - EIGEN_USING_STD_MATH(abs); - return abs(x - y) <= numext::mini(abs(x), abs(y)) * prec; + return numext::abs(x - y) <= numext::mini(numext::abs(x), numext::abs(y)) * prec; } EIGEN_DEVICE_FUNC static inline bool isApproxOrLessThan(const Scalar& x, const Scalar& y, const RealScalar& prec) From de057ebe541d5a6c1297ea94a89dcaf35582d44e Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Sat, 9 Apr 2016 20:07:36 +0100 Subject: [PATCH 02/15] Added nans to zeta function. --- Eigen/src/Core/SpecialFunctions.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/Eigen/src/Core/SpecialFunctions.h b/Eigen/src/Core/SpecialFunctions.h index 2a0a6ff15..954972cdd 100644 --- a/Eigen/src/Core/SpecialFunctions.h +++ b/Eigen/src/Core/SpecialFunctions.h @@ -881,13 +881,14 @@ struct zeta_impl { const Scalar maxnum = NumTraits::infinity(); const Scalar zero = 0.0, half = 0.5, one = 1.0; const Scalar machep = igamma_helper::machep(); + const Scalar nan = NumTraits::quiet_NaN(); if( x == one ) return maxnum; if( x < one ) { - return zero; + return nan; } if( q <= zero ) @@ -899,7 +900,7 @@ struct zeta_impl { p = x; r = numext::floor(p); if (p != r) - return zero; + return nan; } /* Permit negative q but continue sum until n+q > +9 . From 643b6976493c122ffb7205cc3ab893f28f9e1634 Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Sun, 10 Apr 2016 00:37:53 +0100 Subject: [PATCH 03/15] Proper handling of domain errors. --- Eigen/src/Core/SpecialFunctions.h | 7 ++++++- test/array.cpp | 2 ++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/Eigen/src/Core/SpecialFunctions.h b/Eigen/src/Core/SpecialFunctions.h index 954972cdd..2dc7b22fc 100644 --- a/Eigen/src/Core/SpecialFunctions.h +++ b/Eigen/src/Core/SpecialFunctions.h @@ -970,9 +970,14 @@ struct polygamma_impl { static Scalar run(Scalar n, Scalar x) { Scalar zero = 0.0, one = 1.0; Scalar nplus = n + one; + const Scalar nan = NumTraits::quiet_NaN(); + // Check that n is an integer + if (numext::floor(n) != n) { + return nan; + } // Just return the digamma function for n = 1 - if (n == zero) { + else if (n == zero) { return digamma_impl::run(x); } // Use the same implementation as scipy diff --git a/test/array.cpp b/test/array.cpp index 8b0a34722..beaa62221 100644 --- a/test/array.cpp +++ b/test/array.cpp @@ -331,11 +331,13 @@ template void array_real(const ArrayType& m) VERIFY_IS_APPROX(numext::zeta(Scalar(3), Scalar(-2.5)), RealScalar(0.054102025820864097)); VERIFY_IS_EQUAL(numext::zeta(Scalar(1), Scalar(1.2345)), // The second scalar does not matter std::numeric_limits::infinity()); + VERIFY((numext::isnan)(numext::zeta(Scalar(0.9), Scalar(1.2345)))); // The second scalar does not matter // Check the polygamma against scipy.special.polygamma examples VERIFY_IS_APPROX(numext::polygamma(Scalar(1), Scalar(2)), RealScalar(0.644934066848)); VERIFY_IS_APPROX(numext::polygamma(Scalar(1), Scalar(3)), RealScalar(0.394934066848)); VERIFY_IS_APPROX(numext::polygamma(Scalar(1), Scalar(25.5)), RealScalar(0.0399946696496)); + VERIFY((numext::isnan)(numext::polygamma(Scalar(1.5), Scalar(1.2345)))); // The second scalar does not matter // Check the polygamma function over a larger range of values VERIFY_IS_APPROX(numext::polygamma(Scalar(17), Scalar(4.7)), RealScalar(293.334565435)); From fc6a0ebb1c98ab51c575bcd2688c1d9d11200267 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Mon, 11 Apr 2016 10:54:58 +0200 Subject: [PATCH 04/15] Typos in doc. --- doc/TutorialReshapeSlicing.dox | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/TutorialReshapeSlicing.dox b/doc/TutorialReshapeSlicing.dox index eb0fb0df0..3730a5de6 100644 --- a/doc/TutorialReshapeSlicing.dox +++ b/doc/TutorialReshapeSlicing.dox @@ -37,10 +37,10 @@ Here is another example reshaping a 2x6 matrix to a 6x2 one: \section TutorialSlicing Slicing -Slicing consists in taking a set of rows, or columns, or elements, uniformly spaced within a matrix. +Slicing consists in taking a set of rows, columns, or elements, uniformly spaced within a matrix. Again, the class Map allows to easily mimic this feature. -For instance, one can take skip every P elements in a vector: +For instance, one can skip every P elements in a vector:
Example:Output:
From 675e0a222442b1d7446a843f15128c467502160a Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Mon, 11 Apr 2016 15:06:20 +0200 Subject: [PATCH 05/15] Fix static/inline keywords order. --- Eigen/src/Core/AssignEvaluator.h | 8 +++--- Eigen/src/Core/SpecialFunctions.h | 48 +++++++++++++++---------------- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/Eigen/src/Core/AssignEvaluator.h b/Eigen/src/Core/AssignEvaluator.h index a9a524130..3de8aa9a2 100644 --- a/Eigen/src/Core/AssignEvaluator.h +++ b/Eigen/src/Core/AssignEvaluator.h @@ -788,8 +788,8 @@ template void check_for_aliasing(const Dst &dst, con template< typename DstXprType, typename SrcXprType, typename Functor, typename Scalar> struct Assignment { - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - static void run(DstXprType &dst, const SrcXprType &src, const Functor &func) + EIGEN_DEVICE_FUNC + static EIGEN_STRONG_INLINE void run(DstXprType &dst, const SrcXprType &src, const Functor &func) { eigen_assert(dst.rows() == src.rows() && dst.cols() == src.cols()); @@ -806,8 +806,8 @@ struct Assignment template< typename DstXprType, typename SrcXprType, typename Functor, typename Scalar> struct Assignment { - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op &/*func*/) + EIGEN_DEVICE_FUNC + static EIGEN_STRONG_INLINE void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op &/*func*/) { eigen_assert(dst.rows() == src.rows() && dst.cols() == src.cols()); src.evalTo(dst); diff --git a/Eigen/src/Core/SpecialFunctions.h b/Eigen/src/Core/SpecialFunctions.h index 2dc7b22fc..adb055b15 100644 --- a/Eigen/src/Core/SpecialFunctions.h +++ b/Eigen/src/Core/SpecialFunctions.h @@ -79,8 +79,8 @@ namespace cephes { */ template struct polevl { - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - static Scalar run(const Scalar x, const Scalar coef[]) { + EIGEN_DEVICE_FUNC + static EIGEN_STRONG_INLINE Scalar run(const Scalar x, const Scalar coef[]) { EIGEN_STATIC_ASSERT((N > 0), YOU_MADE_A_PROGRAMMING_MISTAKE); return polevl::run(x, coef) * x + coef[N]; @@ -89,8 +89,8 @@ struct polevl { template struct polevl { - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - static Scalar run(const Scalar, const Scalar coef[]) { + EIGEN_DEVICE_FUNC + static EIGEN_STRONG_INLINE Scalar run(const Scalar, const Scalar coef[]) { return coef[0]; } }; @@ -144,7 +144,7 @@ struct digamma_retval { template struct digamma_impl { EIGEN_DEVICE_FUNC - static Scalar run(Scalar x) { + static EIGEN_STRONG_INLINE Scalar run(Scalar x) { EIGEN_STATIC_ASSERT((internal::is_same::value == false), THIS_TYPE_IS_NOT_SUPPORTED); return Scalar(0); @@ -428,20 +428,20 @@ template struct igamma_impl; // predeclare igamma_impl template struct igamma_helper { - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - static Scalar machep() { assert(false && "machep not supported for this type"); return 0.0; } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - static Scalar big() { assert(false && "big not supported for this type"); return 0.0; } + EIGEN_DEVICE_FUNC + static EIGEN_STRONG_INLINE Scalar machep() { assert(false && "machep not supported for this type"); return 0.0; } + EIGEN_DEVICE_FUNC + static EIGEN_STRONG_INLINE Scalar big() { assert(false && "big not supported for this type"); return 0.0; } }; template <> struct igamma_helper { - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - static float machep() { + EIGEN_DEVICE_FUNC + static EIGEN_STRONG_INLINE float machep() { return NumTraits::epsilon() / 2; // 1.0 - machep == 1.0 } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - static float big() { + EIGEN_DEVICE_FUNC + static EIGEN_STRONG_INLINE float big() { // use epsneg (1.0 - epsneg == 1.0) return 1.0 / (NumTraits::epsilon() / 2); } @@ -449,12 +449,12 @@ struct igamma_helper { template <> struct igamma_helper { - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - static double machep() { + EIGEN_DEVICE_FUNC + static EIGEN_STRONG_INLINE double machep() { return NumTraits::epsilon() / 2; // 1.0 - machep == 1.0 } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - static double big() { + EIGEN_DEVICE_FUNC + static EIGEN_STRONG_INLINE double big() { return 1.0 / NumTraits::epsilon(); } }; @@ -605,7 +605,7 @@ struct igamma_retval { template struct igamma_impl { EIGEN_DEVICE_FUNC - static Scalar run(Scalar a, Scalar x) { + static EIGEN_STRONG_INLINE Scalar run(Scalar a, Scalar x) { EIGEN_STATIC_ASSERT((internal::is_same::value == false), THIS_TYPE_IS_NOT_SUPPORTED); return Scalar(0); @@ -736,7 +736,7 @@ struct zeta_retval { template struct zeta_impl { EIGEN_DEVICE_FUNC - static Scalar run(Scalar x, Scalar q) { + static EIGEN_STRONG_INLINE Scalar run(Scalar x, Scalar q) { EIGEN_STATIC_ASSERT((internal::is_same::value == false), THIS_TYPE_IS_NOT_SUPPORTED); return Scalar(0); @@ -757,8 +757,8 @@ struct zeta_impl_series { template <> struct zeta_impl_series { - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - static bool run(float& a, float& b, float& s, const float x, const float machep) { + EIGEN_DEVICE_FUNC + static EIGEN_STRONG_INLINE bool run(float& a, float& b, float& s, const float x, const float machep) { int i = 0; while(i < 9) { @@ -777,8 +777,8 @@ struct zeta_impl_series { template <> struct zeta_impl_series { - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - static bool run(double& a, double& b, double& s, const double x, const double machep) { + EIGEN_DEVICE_FUNC + static EIGEN_STRONG_INLINE bool run(double& a, double& b, double& s, const double x, const double machep) { int i = 0; while( (i < 9) || (a <= 9.0) ) { @@ -955,7 +955,7 @@ struct polygamma_retval { template struct polygamma_impl { EIGEN_DEVICE_FUNC - static Scalar run(Scalar n, Scalar x) { + static EIGEN_STRONG_INLINE Scalar run(Scalar n, Scalar x) { EIGEN_STATIC_ASSERT((internal::is_same::value == false), THIS_TYPE_IS_NOT_SUPPORTED); return Scalar(0); From 4e8e5888d7a78d514e54a518f6692f2838314328 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Mon, 11 Apr 2016 15:12:44 +0200 Subject: [PATCH 06/15] Improve constness of blas level-3 interface. --- Eigen/src/misc/blas.h | 252 +++++++++++++++--------------------------- blas/common.h | 19 ++++ blas/level3_impl.h | 81 ++++++++------ 3 files changed, 151 insertions(+), 201 deletions(-) diff --git a/Eigen/src/misc/blas.h b/Eigen/src/misc/blas.h index 6fce99ed5..ae0c393f1 100644 --- a/Eigen/src/misc/blas.h +++ b/Eigen/src/misc/blas.h @@ -30,15 +30,15 @@ int BLASFUNC(cdotcw) (int *, float *, int *, float *, int *, float*); int BLASFUNC(zdotuw) (int *, double *, int *, double *, int *, double*); int BLASFUNC(zdotcw) (int *, double *, int *, double *, int *, double*); -int BLASFUNC(saxpy) (int *, float *, float *, int *, float *, int *); -int BLASFUNC(daxpy) (int *, double *, double *, int *, double *, int *); -int BLASFUNC(qaxpy) (int *, double *, double *, int *, double *, int *); -int BLASFUNC(caxpy) (int *, float *, float *, int *, float *, int *); -int BLASFUNC(zaxpy) (int *, double *, double *, int *, double *, int *); -int BLASFUNC(xaxpy) (int *, double *, double *, int *, double *, int *); -int BLASFUNC(caxpyc)(int *, float *, float *, int *, float *, int *); -int BLASFUNC(zaxpyc)(int *, double *, double *, int *, double *, int *); -int BLASFUNC(xaxpyc)(int *, double *, double *, int *, double *, int *); +int BLASFUNC(saxpy) (const int *, const float *, const float *, const int *, float *, int *); +int BLASFUNC(daxpy) (const int *, const double *, const double *, const int *, double *, int *); +int BLASFUNC(qaxpy) (const int *, const double *, const double *, const int *, double *, int *); +int BLASFUNC(caxpy) (const int *, const float *, const float *, const int *, float *, int *); +int BLASFUNC(zaxpy) (const int *, const double *, const double *, const int *, double *, int *); +int BLASFUNC(xaxpy) (const int *, const double *, const double *, const int *, double *, int *); +int BLASFUNC(caxpyc)(const int *, const float *, const float *, const int *, float *, int *); +int BLASFUNC(zaxpyc)(const int *, const double *, const double *, const int *, double *, int *); +int BLASFUNC(xaxpyc)(const int *, const double *, const double *, const int *, double *, int *); int BLASFUNC(scopy) (int *, float *, int *, float *, int *); int BLASFUNC(dcopy) (int *, double *, int *, double *, int *); @@ -177,31 +177,19 @@ int BLASFUNC(xgeru)(int *, int *, double *, double *, int *, int BLASFUNC(xgerc)(int *, int *, double *, double *, int *, double *, int *, double *, int *); -int BLASFUNC(sgemv)(char *, int *, int *, float *, float *, int *, - float *, int *, float *, float *, int *); -int BLASFUNC(dgemv)(char *, int *, int *, double *, double *, int *, - double *, int *, double *, double *, int *); -int BLASFUNC(qgemv)(char *, int *, int *, double *, double *, int *, - double *, int *, double *, double *, int *); -int BLASFUNC(cgemv)(char *, int *, int *, float *, float *, int *, - float *, int *, float *, float *, int *); -int BLASFUNC(zgemv)(char *, int *, int *, double *, double *, int *, - double *, int *, double *, double *, int *); -int BLASFUNC(xgemv)(char *, int *, int *, double *, double *, int *, - double *, int *, double *, double *, int *); +int BLASFUNC(sgemv)(const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *); +int BLASFUNC(dgemv)(const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *); +int BLASFUNC(qgemv)(const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *); +int BLASFUNC(cgemv)(const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *); +int BLASFUNC(zgemv)(const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *); +int BLASFUNC(xgemv)(const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *); -int BLASFUNC(strsv) (char *, char *, char *, int *, float *, int *, - float *, int *); -int BLASFUNC(dtrsv) (char *, char *, char *, int *, double *, int *, - double *, int *); -int BLASFUNC(qtrsv) (char *, char *, char *, int *, double *, int *, - double *, int *); -int BLASFUNC(ctrsv) (char *, char *, char *, int *, float *, int *, - float *, int *); -int BLASFUNC(ztrsv) (char *, char *, char *, int *, double *, int *, - double *, int *); -int BLASFUNC(xtrsv) (char *, char *, char *, int *, double *, int *, - double *, int *); +int BLASFUNC(strsv) (const char *, const char *, const char *, const int *, const float *, const int *, float *, const int *); +int BLASFUNC(dtrsv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *); +int BLASFUNC(qtrsv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *); +int BLASFUNC(ctrsv) (const char *, const char *, const char *, const int *, const float *, const int *, float *, const int *); +int BLASFUNC(ztrsv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *); +int BLASFUNC(xtrsv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *); int BLASFUNC(stpsv) (char *, char *, char *, int *, float *, float *, int *); int BLASFUNC(dtpsv) (char *, char *, char *, int *, double *, double *, int *); @@ -210,18 +198,12 @@ int BLASFUNC(ctpsv) (char *, char *, char *, int *, float *, float *, int *); int BLASFUNC(ztpsv) (char *, char *, char *, int *, double *, double *, int *); int BLASFUNC(xtpsv) (char *, char *, char *, int *, double *, double *, int *); -int BLASFUNC(strmv) (char *, char *, char *, int *, float *, int *, - float *, int *); -int BLASFUNC(dtrmv) (char *, char *, char *, int *, double *, int *, - double *, int *); -int BLASFUNC(qtrmv) (char *, char *, char *, int *, double *, int *, - double *, int *); -int BLASFUNC(ctrmv) (char *, char *, char *, int *, float *, int *, - float *, int *); -int BLASFUNC(ztrmv) (char *, char *, char *, int *, double *, int *, - double *, int *); -int BLASFUNC(xtrmv) (char *, char *, char *, int *, double *, int *, - double *, int *); +int BLASFUNC(strmv) (const char *, const char *, const char *, const int *, const float *, const int *, float *, const int *); +int BLASFUNC(dtrmv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *); +int BLASFUNC(qtrmv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *); +int BLASFUNC(ctrmv) (const char *, const char *, const char *, const int *, const float *, const int *, float *, const int *); +int BLASFUNC(ztrmv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *); +int BLASFUNC(xtrmv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *); int BLASFUNC(stpmv) (char *, char *, char *, int *, float *, float *, int *); int BLASFUNC(dtpmv) (char *, char *, char *, int *, double *, double *, int *); @@ -244,18 +226,12 @@ int BLASFUNC(ctbsv) (char *, char *, char *, int *, int *, float *, int *, floa int BLASFUNC(ztbsv) (char *, char *, char *, int *, int *, double *, int *, double *, int *); int BLASFUNC(xtbsv) (char *, char *, char *, int *, int *, double *, int *, double *, int *); -int BLASFUNC(ssymv) (char *, int *, float *, float *, int *, - float *, int *, float *, float *, int *); -int BLASFUNC(dsymv) (char *, int *, double *, double *, int *, - double *, int *, double *, double *, int *); -int BLASFUNC(qsymv) (char *, int *, double *, double *, int *, - double *, int *, double *, double *, int *); -int BLASFUNC(csymv) (char *, int *, float *, float *, int *, - float *, int *, float *, float *, int *); -int BLASFUNC(zsymv) (char *, int *, double *, double *, int *, - double *, int *, double *, double *, int *); -int BLASFUNC(xsymv) (char *, int *, double *, double *, int *, - double *, int *, double *, double *, int *); +int BLASFUNC(ssymv) (const char *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *); +int BLASFUNC(dsymv) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *); +int BLASFUNC(qsymv) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *); +int BLASFUNC(csymv) (const char *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *); +int BLASFUNC(zsymv) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *); +int BLASFUNC(xsymv) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *); int BLASFUNC(sspmv) (char *, int *, float *, float *, float *, int *, float *, float *, int *); @@ -347,12 +323,9 @@ int BLASFUNC(zhpr2) (char *, int *, double *, int BLASFUNC(xhpr2) (char *, int *, double *, double *, int *, double *, int *, double *); -int BLASFUNC(chemv) (char *, int *, float *, float *, int *, - float *, int *, float *, float *, int *); -int BLASFUNC(zhemv) (char *, int *, double *, double *, int *, - double *, int *, double *, double *, int *); -int BLASFUNC(xhemv) (char *, int *, double *, double *, int *, - double *, int *, double *, double *, int *); +int BLASFUNC(chemv) (const char *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *); +int BLASFUNC(zhemv) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *); +int BLASFUNC(xhemv) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *); int BLASFUNC(chpmv) (char *, int *, float *, float *, float *, int *, float *, float *, int *); @@ -401,18 +374,12 @@ int BLASFUNC(xhbmv)(char *, int *, int *, double *, double *, int *, /* Level 3 routines */ -int BLASFUNC(sgemm)(char *, char *, int *, int *, int *, float *, - float *, int *, float *, int *, float *, float *, int *); -int BLASFUNC(dgemm)(char *, char *, int *, int *, int *, double *, - double *, int *, double *, int *, double *, double *, int *); -int BLASFUNC(qgemm)(char *, char *, int *, int *, int *, double *, - double *, int *, double *, int *, double *, double *, int *); -int BLASFUNC(cgemm)(char *, char *, int *, int *, int *, float *, - float *, int *, float *, int *, float *, float *, int *); -int BLASFUNC(zgemm)(char *, char *, int *, int *, int *, double *, - double *, int *, double *, int *, double *, double *, int *); -int BLASFUNC(xgemm)(char *, char *, int *, int *, int *, double *, - double *, int *, double *, int *, double *, double *, int *); +int BLASFUNC(sgemm)(const char *, const char *, const int *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *); +int BLASFUNC(dgemm)(const char *, const char *, const int *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *); +int BLASFUNC(qgemm)(const char *, const char *, const int *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *); +int BLASFUNC(cgemm)(const char *, const char *, const int *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *); +int BLASFUNC(zgemm)(const char *, const char *, const int *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *); +int BLASFUNC(xgemm)(const char *, const char *, const int *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *); int BLASFUNC(cgemm3m)(char *, char *, int *, int *, int *, float *, float *, int *, float *, int *, float *, float *, int *); @@ -434,84 +401,48 @@ int BLASFUNC(zge2mm)(char *, char *, char *, int *, int *, double *, double *, int *, double *, int *, double *, double *, int *); -int BLASFUNC(strsm)(char *, char *, char *, char *, int *, int *, - float *, float *, int *, float *, int *); -int BLASFUNC(dtrsm)(char *, char *, char *, char *, int *, int *, - double *, double *, int *, double *, int *); -int BLASFUNC(qtrsm)(char *, char *, char *, char *, int *, int *, - double *, double *, int *, double *, int *); -int BLASFUNC(ctrsm)(char *, char *, char *, char *, int *, int *, - float *, float *, int *, float *, int *); -int BLASFUNC(ztrsm)(char *, char *, char *, char *, int *, int *, - double *, double *, int *, double *, int *); -int BLASFUNC(xtrsm)(char *, char *, char *, char *, int *, int *, - double *, double *, int *, double *, int *); +int BLASFUNC(strsm)(const char *, const char *, const char *, const char *, const int *, const int *, const float *, const float *, const int *, float *, const int *); +int BLASFUNC(dtrsm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *); +int BLASFUNC(qtrsm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *); +int BLASFUNC(ctrsm)(const char *, const char *, const char *, const char *, const int *, const int *, const float *, const float *, const int *, float *, const int *); +int BLASFUNC(ztrsm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *); +int BLASFUNC(xtrsm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *); -int BLASFUNC(strmm)(char *, char *, char *, char *, int *, int *, - float *, float *, int *, float *, int *); -int BLASFUNC(dtrmm)(char *, char *, char *, char *, int *, int *, - double *, double *, int *, double *, int *); -int BLASFUNC(qtrmm)(char *, char *, char *, char *, int *, int *, - double *, double *, int *, double *, int *); -int BLASFUNC(ctrmm)(char *, char *, char *, char *, int *, int *, - float *, float *, int *, float *, int *); -int BLASFUNC(ztrmm)(char *, char *, char *, char *, int *, int *, - double *, double *, int *, double *, int *); -int BLASFUNC(xtrmm)(char *, char *, char *, char *, int *, int *, - double *, double *, int *, double *, int *); +int BLASFUNC(strmm)(const char *, const char *, const char *, const char *, const int *, const int *, const float *, const float *, const int *, float *, const int *); +int BLASFUNC(dtrmm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *); +int BLASFUNC(qtrmm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *); +int BLASFUNC(ctrmm)(const char *, const char *, const char *, const char *, const int *, const int *, const float *, const float *, const int *, float *, const int *); +int BLASFUNC(ztrmm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *); +int BLASFUNC(xtrmm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *); -int BLASFUNC(ssymm)(char *, char *, int *, int *, float *, float *, int *, - float *, int *, float *, float *, int *); -int BLASFUNC(dsymm)(char *, char *, int *, int *, double *, double *, int *, - double *, int *, double *, double *, int *); -int BLASFUNC(qsymm)(char *, char *, int *, int *, double *, double *, int *, - double *, int *, double *, double *, int *); -int BLASFUNC(csymm)(char *, char *, int *, int *, float *, float *, int *, - float *, int *, float *, float *, int *); -int BLASFUNC(zsymm)(char *, char *, int *, int *, double *, double *, int *, - double *, int *, double *, double *, int *); -int BLASFUNC(xsymm)(char *, char *, int *, int *, double *, double *, int *, - double *, int *, double *, double *, int *); +int BLASFUNC(ssymm)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *); +int BLASFUNC(dsymm)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *); +int BLASFUNC(qsymm)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *); +int BLASFUNC(csymm)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *); +int BLASFUNC(zsymm)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *); +int BLASFUNC(xsymm)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *); -int BLASFUNC(csymm3m)(char *, char *, int *, int *, float *, float *, int *, - float *, int *, float *, float *, int *); -int BLASFUNC(zsymm3m)(char *, char *, int *, int *, double *, double *, int *, - double *, int *, double *, double *, int *); -int BLASFUNC(xsymm3m)(char *, char *, int *, int *, double *, double *, int *, - double *, int *, double *, double *, int *); +int BLASFUNC(csymm3m)(char *, char *, int *, int *, float *, float *, int *, float *, int *, float *, float *, int *); +int BLASFUNC(zsymm3m)(char *, char *, int *, int *, double *, double *, int *, double *, int *, double *, double *, int *); +int BLASFUNC(xsymm3m)(char *, char *, int *, int *, double *, double *, int *, double *, int *, double *, double *, int *); -int BLASFUNC(ssyrk)(char *, char *, int *, int *, float *, float *, int *, - float *, float *, int *); -int BLASFUNC(dsyrk)(char *, char *, int *, int *, double *, double *, int *, - double *, double *, int *); -int BLASFUNC(qsyrk)(char *, char *, int *, int *, double *, double *, int *, - double *, double *, int *); -int BLASFUNC(csyrk)(char *, char *, int *, int *, float *, float *, int *, - float *, float *, int *); -int BLASFUNC(zsyrk)(char *, char *, int *, int *, double *, double *, int *, - double *, double *, int *); -int BLASFUNC(xsyrk)(char *, char *, int *, int *, double *, double *, int *, - double *, double *, int *); +int BLASFUNC(ssyrk)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, float *, const int *); +int BLASFUNC(dsyrk)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, double *, const int *); +int BLASFUNC(qsyrk)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, double *, const int *); +int BLASFUNC(csyrk)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, float *, const int *); +int BLASFUNC(zsyrk)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, double *, const int *); +int BLASFUNC(xsyrk)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, double *, const int *); -int BLASFUNC(ssyr2k)(char *, char *, int *, int *, float *, float *, int *, - float *, int *, float *, float *, int *); -int BLASFUNC(dsyr2k)(char *, char *, int *, int *, double *, double *, int *, - double*, int *, double *, double *, int *); -int BLASFUNC(qsyr2k)(char *, char *, int *, int *, double *, double *, int *, - double*, int *, double *, double *, int *); -int BLASFUNC(csyr2k)(char *, char *, int *, int *, float *, float *, int *, - float *, int *, float *, float *, int *); -int BLASFUNC(zsyr2k)(char *, char *, int *, int *, double *, double *, int *, - double*, int *, double *, double *, int *); -int BLASFUNC(xsyr2k)(char *, char *, int *, int *, double *, double *, int *, - double*, int *, double *, double *, int *); +int BLASFUNC(ssyr2k)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *); +int BLASFUNC(dsyr2k)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double*, const int *, const double *, double *, const int *); +int BLASFUNC(qsyr2k)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double*, const int *, const double *, double *, const int *); +int BLASFUNC(csyr2k)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *); +int BLASFUNC(zsyr2k)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double*, const int *, const double *, double *, const int *); +int BLASFUNC(xsyr2k)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double*, const int *, const double *, double *, const int *); -int BLASFUNC(chemm)(char *, char *, int *, int *, float *, float *, int *, - float *, int *, float *, float *, int *); -int BLASFUNC(zhemm)(char *, char *, int *, int *, double *, double *, int *, - double *, int *, double *, double *, int *); -int BLASFUNC(xhemm)(char *, char *, int *, int *, double *, double *, int *, - double *, int *, double *, double *, int *); +int BLASFUNC(chemm)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *); +int BLASFUNC(zhemm)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *); +int BLASFUNC(xhemm)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *); int BLASFUNC(chemm3m)(char *, char *, int *, int *, float *, float *, int *, float *, int *, float *, float *, int *); @@ -520,25 +451,16 @@ int BLASFUNC(zhemm3m)(char *, char *, int *, int *, double *, double *, int *, int BLASFUNC(xhemm3m)(char *, char *, int *, int *, double *, double *, int *, double *, int *, double *, double *, int *); -int BLASFUNC(cherk)(char *, char *, int *, int *, float *, float *, int *, - float *, float *, int *); -int BLASFUNC(zherk)(char *, char *, int *, int *, double *, double *, int *, - double *, double *, int *); -int BLASFUNC(xherk)(char *, char *, int *, int *, double *, double *, int *, - double *, double *, int *); +int BLASFUNC(cherk)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, float *, const int *); +int BLASFUNC(zherk)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, double *, const int *); +int BLASFUNC(xherk)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, double *, const int *); -int BLASFUNC(cher2k)(char *, char *, int *, int *, float *, float *, int *, - float *, int *, float *, float *, int *); -int BLASFUNC(zher2k)(char *, char *, int *, int *, double *, double *, int *, - double*, int *, double *, double *, int *); -int BLASFUNC(xher2k)(char *, char *, int *, int *, double *, double *, int *, - double*, int *, double *, double *, int *); -int BLASFUNC(cher2m)(char *, char *, char *, int *, int *, float *, float *, int *, - float *, int *, float *, float *, int *); -int BLASFUNC(zher2m)(char *, char *, char *, int *, int *, double *, double *, int *, - double*, int *, double *, double *, int *); -int BLASFUNC(xher2m)(char *, char *, char *, int *, int *, double *, double *, int *, - double*, int *, double *, double *, int *); +int BLASFUNC(cher2k)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *); +int BLASFUNC(zher2k)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *); +int BLASFUNC(xher2k)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *); +int BLASFUNC(cher2m)(const char *, const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *); +int BLASFUNC(zher2m)(const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double*, const int *, const double *, double *, const int *); +int BLASFUNC(xher2m)(const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double*, const int *, const double *, double *, const int *); int BLASFUNC(sgemt)(char *, int *, int *, float *, float *, int *, float *, int *); diff --git a/blas/common.h b/blas/common.h index 5ecb153e2..acb50af1b 100644 --- a/blas/common.h +++ b/blas/common.h @@ -104,18 +104,37 @@ matrix(T* data, int rows, int cols, int stride) return Map, 0, OuterStride<> >(data, rows, cols, OuterStride<>(stride)); } +template +Map, 0, OuterStride<> > +matrix(const T* data, int rows, int cols, int stride) +{ + return Map, 0, OuterStride<> >(data, rows, cols, OuterStride<>(stride)); +} + template Map, 0, InnerStride > make_vector(T* data, int size, int incr) { return Map, 0, InnerStride >(data, size, InnerStride(incr)); } +template +Map, 0, InnerStride > make_vector(const T* data, int size, int incr) +{ + return Map, 0, InnerStride >(data, size, InnerStride(incr)); +} + template Map > make_vector(T* data, int size) { return Map >(data, size); } +template +Map > make_vector(const T* data, int size) +{ + return Map >(data, size); +} + template T* get_compact_vector(T* x, int n, int incx) { diff --git a/blas/level3_impl.h b/blas/level3_impl.h index 267a727ef..beb36c47d 100644 --- a/blas/level3_impl.h +++ b/blas/level3_impl.h @@ -9,7 +9,8 @@ #include #include "common.h" -int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc) +int EIGEN_BLAS_FUNC(gemm)(const char *opa, const char *opb, const int *m, const int *n, const int *k, const RealScalar *palpha, + const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc) { // std::cerr << "in gemm " << *opa << " " << *opb << " " << *m << " " << *n << " " << *k << " " << *lda << " " << *ldb << " " << *ldc << " " << *palpha << " " << *pbeta << "\n"; typedef void (*functype)(DenseIndex, DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, Scalar, internal::level3_blocking&, Eigen::internal::GemmParallelInfo*); @@ -37,11 +38,11 @@ int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScal 0 }; - Scalar* a = reinterpret_cast(pa); - Scalar* b = reinterpret_cast(pb); + const Scalar* a = reinterpret_cast(pa); + const Scalar* b = reinterpret_cast(pb); Scalar* c = reinterpret_cast(pc); - Scalar alpha = *reinterpret_cast(palpha); - Scalar beta = *reinterpret_cast(pbeta); + Scalar alpha = *reinterpret_cast(palpha); + Scalar beta = *reinterpret_cast(pbeta); int info = 0; if(OP(*opa)==INVALID) info = 1; @@ -74,7 +75,8 @@ int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScal return 0; } -int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb) +int EIGEN_BLAS_FUNC(trsm)(const char *side, const char *uplo, const char *opa, const char *diag, const int *m, const int *n, + const RealScalar *palpha, const RealScalar *pa, const int *lda, RealScalar *pb, const int *ldb) { // std::cerr << "in trsm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << "," << *n << " " << *palpha << " " << *lda << " " << *ldb<< "\n"; typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, internal::level3_blocking&); @@ -137,9 +139,9 @@ int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m, 0 }; - Scalar* a = reinterpret_cast(pa); + const Scalar* a = reinterpret_cast(pa); Scalar* b = reinterpret_cast(pb); - Scalar alpha = *reinterpret_cast(palpha); + Scalar alpha = *reinterpret_cast(palpha); int info = 0; if(SIDE(*side)==INVALID) info = 1; @@ -178,7 +180,8 @@ int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m, // b = alpha*op(a)*b for side = 'L'or'l' // b = alpha*b*op(a) for side = 'R'or'r' -int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb) +int EIGEN_BLAS_FUNC(trmm)(const char *side, const char *uplo, const char *opa, const char *diag, const int *m, const int *n, + const RealScalar *palpha, const RealScalar *pa, const int *lda, RealScalar *pb, const int *ldb) { // std::cerr << "in trmm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << " " << *n << " " << *lda << " " << *ldb << " " << *palpha << "\n"; typedef void (*functype)(DenseIndex, DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, const Scalar&, internal::level3_blocking&); @@ -241,9 +244,9 @@ int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m, 0 }; - Scalar* a = reinterpret_cast(pa); + const Scalar* a = reinterpret_cast(pa); Scalar* b = reinterpret_cast(pb); - Scalar alpha = *reinterpret_cast(palpha); + Scalar alpha = *reinterpret_cast(palpha); int info = 0; if(SIDE(*side)==INVALID) info = 1; @@ -281,14 +284,15 @@ int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m, // c = alpha*a*b + beta*c for side = 'L'or'l' // c = alpha*b*a + beta*c for side = 'R'or'r -int EIGEN_BLAS_FUNC(symm)(char *side, char *uplo, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc) +int EIGEN_BLAS_FUNC(symm)(const char *side, const char *uplo, const int *m, const int *n, const RealScalar *palpha, + const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc) { // std::cerr << "in symm " << *side << " " << *uplo << " " << *m << "x" << *n << " lda:" << *lda << " ldb:" << *ldb << " ldc:" << *ldc << " alpha:" << *palpha << " beta:" << *pbeta << "\n"; - Scalar* a = reinterpret_cast(pa); - Scalar* b = reinterpret_cast(pb); + const Scalar* a = reinterpret_cast(pa); + const Scalar* b = reinterpret_cast(pb); Scalar* c = reinterpret_cast(pc); - Scalar alpha = *reinterpret_cast(palpha); - Scalar beta = *reinterpret_cast(pbeta); + Scalar alpha = *reinterpret_cast(palpha); + Scalar beta = *reinterpret_cast(pbeta); int info = 0; if(SIDE(*side)==INVALID) info = 1; @@ -350,7 +354,8 @@ int EIGEN_BLAS_FUNC(symm)(char *side, char *uplo, int *m, int *n, RealScalar *pa // c = alpha*a*a' + beta*c for op = 'N'or'n' // c = alpha*a'*a + beta*c for op = 'T'or't','C'or'c' -int EIGEN_BLAS_FUNC(syrk)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pbeta, RealScalar *pc, int *ldc) +int EIGEN_BLAS_FUNC(syrk)(const char *uplo, const char *op, const int *n, const int *k, + const RealScalar *palpha, const RealScalar *pa, const int *lda, const RealScalar *pbeta, RealScalar *pc, const int *ldc) { // std::cerr << "in syrk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " " << *pbeta << " " << *ldc << "\n"; #if !ISCOMPLEX @@ -373,10 +378,10 @@ int EIGEN_BLAS_FUNC(syrk)(char *uplo, char *op, int *n, int *k, RealScalar *palp }; #endif - Scalar* a = reinterpret_cast(pa); + const Scalar* a = reinterpret_cast(pa); Scalar* c = reinterpret_cast(pc); - Scalar alpha = *reinterpret_cast(palpha); - Scalar beta = *reinterpret_cast(pbeta); + Scalar alpha = *reinterpret_cast(palpha); + Scalar beta = *reinterpret_cast(pbeta); int info = 0; if(UPLO(*uplo)==INVALID) info = 1; @@ -429,13 +434,14 @@ int EIGEN_BLAS_FUNC(syrk)(char *uplo, char *op, int *n, int *k, RealScalar *palp // c = alpha*a*b' + alpha*b*a' + beta*c for op = 'N'or'n' // c = alpha*a'*b + alpha*b'*a + beta*c for op = 'T'or't' -int EIGEN_BLAS_FUNC(syr2k)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc) +int EIGEN_BLAS_FUNC(syr2k)(const char *uplo, const char *op, const int *n, const int *k, const RealScalar *palpha, + const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc) { - Scalar* a = reinterpret_cast(pa); - Scalar* b = reinterpret_cast(pb); + const Scalar* a = reinterpret_cast(pa); + const Scalar* b = reinterpret_cast(pb); Scalar* c = reinterpret_cast(pc); - Scalar alpha = *reinterpret_cast(palpha); - Scalar beta = *reinterpret_cast(pbeta); + Scalar alpha = *reinterpret_cast(palpha); + Scalar beta = *reinterpret_cast(pbeta); // std::cerr << "in syr2k " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " << *ldb << " " << beta << " " << *ldc << "\n"; @@ -496,13 +502,14 @@ int EIGEN_BLAS_FUNC(syr2k)(char *uplo, char *op, int *n, int *k, RealScalar *pal // c = alpha*a*b + beta*c for side = 'L'or'l' // c = alpha*b*a + beta*c for side = 'R'or'r -int EIGEN_BLAS_FUNC(hemm)(char *side, char *uplo, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc) +int EIGEN_BLAS_FUNC(hemm)(const char *side, const char *uplo, const int *m, const int *n, const RealScalar *palpha, + const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc) { - Scalar* a = reinterpret_cast(pa); - Scalar* b = reinterpret_cast(pb); + const Scalar* a = reinterpret_cast(pa); + const Scalar* b = reinterpret_cast(pb); Scalar* c = reinterpret_cast(pc); - Scalar alpha = *reinterpret_cast(palpha); - Scalar beta = *reinterpret_cast(pbeta); + Scalar alpha = *reinterpret_cast(palpha); + Scalar beta = *reinterpret_cast(pbeta); // std::cerr << "in hemm " << *side << " " << *uplo << " " << *m << " " << *n << " " << alpha << " " << *lda << " " << beta << " " << *ldc << "\n"; @@ -554,7 +561,8 @@ int EIGEN_BLAS_FUNC(hemm)(char *side, char *uplo, int *m, int *n, RealScalar *pa // c = alpha*a*conj(a') + beta*c for op = 'N'or'n' // c = alpha*conj(a')*a + beta*c for op = 'C'or'c' -int EIGEN_BLAS_FUNC(herk)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pbeta, RealScalar *pc, int *ldc) +int EIGEN_BLAS_FUNC(herk)(const char *uplo, const char *op, const int *n, const int *k, + const RealScalar *palpha, const RealScalar *pa, const int *lda, const RealScalar *pbeta, RealScalar *pc, const int *ldc) { // std::cerr << "in herk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " " << *pbeta << " " << *ldc << "\n"; @@ -574,7 +582,7 @@ int EIGEN_BLAS_FUNC(herk)(char *uplo, char *op, int *n, int *k, RealScalar *palp 0 }; - Scalar* a = reinterpret_cast(pa); + const Scalar* a = reinterpret_cast(pa); Scalar* c = reinterpret_cast(pc); RealScalar alpha = *palpha; RealScalar beta = *pbeta; @@ -620,12 +628,13 @@ int EIGEN_BLAS_FUNC(herk)(char *uplo, char *op, int *n, int *k, RealScalar *palp // c = alpha*a*conj(b') + conj(alpha)*b*conj(a') + beta*c, for op = 'N'or'n' // c = alpha*conj(a')*b + conj(alpha)*conj(b')*a + beta*c, for op = 'C'or'c' -int EIGEN_BLAS_FUNC(her2k)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc) +int EIGEN_BLAS_FUNC(her2k)(const char *uplo, const char *op, const int *n, const int *k, + const RealScalar *palpha, const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc) { - Scalar* a = reinterpret_cast(pa); - Scalar* b = reinterpret_cast(pb); + const Scalar* a = reinterpret_cast(pa); + const Scalar* b = reinterpret_cast(pb); Scalar* c = reinterpret_cast(pc); - Scalar alpha = *reinterpret_cast(palpha); + Scalar alpha = *reinterpret_cast(palpha); RealScalar beta = *pbeta; // std::cerr << "in her2k " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " << *ldb << " " << beta << " " << *ldc << "\n"; From 6a9ca88e7e1bb72de621806b51c5a4fd17310943 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Mon, 11 Apr 2016 15:17:14 +0200 Subject: [PATCH 07/15] Relax dependency on MKL for EIGEN_USE_BLAS --- .../GeneralMatrixMatrixTriangular_MKL.h | 24 +++++---- .../Core/products/GeneralMatrixMatrix_MKL.h | 17 +++--- .../Core/products/GeneralMatrixVector_MKL.h | 20 ++++--- .../products/SelfadjointMatrixMatrix_MKL.h | 52 ++++++------------- .../products/SelfadjointMatrixVector_MKL.h | 16 +++--- .../products/TriangularMatrixMatrix_MKL.h | 25 ++++----- .../products/TriangularMatrixVector_MKL.h | 36 ++++++------- .../products/TriangularSolverMatrix_MKL.h | 28 +++++----- Eigen/src/Core/util/MKL_support.h | 12 ++++- 9 files changed, 98 insertions(+), 132 deletions(-) diff --git a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_MKL.h b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_MKL.h index 3deed068e..1cdf48fbf 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_MKL.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_MKL.h @@ -50,25 +50,26 @@ template { \ static EIGEN_STRONG_INLINE void run(Index size, Index depth,const Scalar* lhs, Index lhsStride, \ - const Scalar* rhs, Index rhsStride, Scalar* res, Index resStride, Scalar alpha) \ + const Scalar* rhs, Index rhsStride, Scalar* res, Index resStride, Scalar alpha, level3_blocking& blocking) \ { \ if (lhs==rhs) { \ general_matrix_matrix_rankupdate \ - ::run(size,depth,lhs,lhsStride,rhs,rhsStride,res,resStride,alpha); \ + ::run(size,depth,lhs,lhsStride,rhs,rhsStride,res,resStride,alpha,blocking); \ } else { \ general_matrix_matrix_triangular_product \ - ::run(size,depth,lhs,lhsStride,rhs,rhsStride,res,resStride,alpha); \ + ::run(size,depth,lhs,lhsStride,rhs,rhsStride,res,resStride,alpha,blocking); \ } \ } \ }; EIGEN_MKL_RANKUPDATE_SPECIALIZE(double) -//EIGEN_MKL_RANKUPDATE_SPECIALIZE(dcomplex) EIGEN_MKL_RANKUPDATE_SPECIALIZE(float) -//EIGEN_MKL_RANKUPDATE_SPECIALIZE(scomplex) +// TODO handle complex cases +// EIGEN_MKL_RANKUPDATE_SPECIALIZE(dcomplex) +// EIGEN_MKL_RANKUPDATE_SPECIALIZE(scomplex) // SYRK for float/double #define EIGEN_MKL_RANKUPDATE_R(EIGTYPE, MKLTYPE, MKLFUNC) \ @@ -80,7 +81,7 @@ struct general_matrix_matrix_rankupdate& /*blocking*/) \ { \ /* typedef Matrix MatrixRhs;*/ \ \ @@ -105,7 +106,7 @@ struct general_matrix_matrix_rankupdate& /*blocking*/) \ { \ typedef Matrix MatrixType; \ \ @@ -132,11 +133,12 @@ struct general_matrix_matrix_rankupdate > map_x(rhs,cols,1,InnerStride<>(incx)); \ @@ -114,14 +112,14 @@ static void run( \ x_ptr=x_tmp.data(); \ incx=1; \ } else x_ptr=rhs; \ - MKLPREFIX##gemv(&trans, &m, &n, &alpha_, (const MKLTYPE*)lhs, &lda, (const MKLTYPE*)x_ptr, &incx, &beta_, (MKLTYPE*)res, &incy); \ + MKLPREFIX##gemv_(&trans, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)lhs, &lda, (const MKLTYPE*)x_ptr, &incx, &numext::real_ref(beta), (MKLTYPE*)res, &incy); \ }\ }; -EIGEN_MKL_GEMV_SPECIALIZATION(double, double, d) -EIGEN_MKL_GEMV_SPECIALIZATION(float, float, s) -EIGEN_MKL_GEMV_SPECIALIZATION(dcomplex, MKL_Complex16, z) -EIGEN_MKL_GEMV_SPECIALIZATION(scomplex, MKL_Complex8, c) +EIGEN_MKL_GEMV_SPECIALIZATION(double, double, d) +EIGEN_MKL_GEMV_SPECIALIZATION(float, float, s) +EIGEN_MKL_GEMV_SPECIALIZATION(dcomplex, double, z) +EIGEN_MKL_GEMV_SPECIALIZATION(scomplex, float, c) } // end namespase internal diff --git a/Eigen/src/Core/products/SelfadjointMatrixMatrix_MKL.h b/Eigen/src/Core/products/SelfadjointMatrixMatrix_MKL.h index dfa687fef..9c2e811dd 100644 --- a/Eigen/src/Core/products/SelfadjointMatrixMatrix_MKL.h +++ b/Eigen/src/Core/products/SelfadjointMatrixMatrix_MKL.h @@ -52,24 +52,19 @@ struct product_selfadjoint_matrix& /*blocking*/) \ { \ char side='L', uplo='L'; \ MKL_INT m, n, lda, ldb, ldc; \ const EIGTYPE *a, *b; \ - MKLTYPE alpha_, beta_; \ + EIGTYPE beta(1); \ MatrixX##EIGPREFIX b_tmp; \ - EIGTYPE myone(1);\ \ /* Set transpose options */ \ /* Set m, n, k */ \ m = (MKL_INT)rows; \ n = (MKL_INT)cols; \ \ -/* Set alpha_ & beta_ */ \ - assign_scalar_eig2mkl(alpha_, alpha); \ - assign_scalar_eig2mkl(beta_, myone); \ -\ /* Set lda, ldb, ldc */ \ lda = (MKL_INT)lhsStride; \ ldb = (MKL_INT)rhsStride; \ @@ -86,7 +81,7 @@ struct product_selfadjoint_matrix& /*blocking*/) \ { \ char side='L', uplo='L'; \ MKL_INT m, n, lda, ldb, ldc; \ const EIGTYPE *a, *b; \ - MKLTYPE alpha_, beta_; \ + EIGTYPE beta(1); \ MatrixX##EIGPREFIX b_tmp; \ Matrix a_tmp; \ - EIGTYPE myone(1); \ \ /* Set transpose options */ \ /* Set m, n, k */ \ m = (MKL_INT)rows; \ n = (MKL_INT)cols; \ \ -/* Set alpha_ & beta_ */ \ - assign_scalar_eig2mkl(alpha_, alpha); \ - assign_scalar_eig2mkl(beta_, myone); \ -\ /* Set lda, ldb, ldc */ \ lda = (MKL_INT)lhsStride; \ ldb = (MKL_INT)rhsStride; \ @@ -154,15 +144,15 @@ struct product_selfadjoint_matrix& /*blocking*/) \ { \ char side='R', uplo='L'; \ MKL_INT m, n, lda, ldb, ldc; \ const EIGTYPE *a, *b; \ - MKLTYPE alpha_, beta_; \ + EIGTYPE beta(1); \ MatrixX##EIGPREFIX b_tmp; \ - EIGTYPE myone(1);\ \ /* Set m, n, k */ \ m = (MKL_INT)rows; \ n = (MKL_INT)cols; \ \ -/* Set alpha_ & beta_ */ \ - assign_scalar_eig2mkl(alpha_, alpha); \ - assign_scalar_eig2mkl(beta_, myone); \ -\ /* Set lda, ldb, ldc */ \ lda = (MKL_INT)rhsStride; \ ldb = (MKL_INT)lhsStride; \ @@ -212,7 +197,7 @@ struct product_selfadjoint_matrix& /*blocking*/) \ { \ char side='R', uplo='L'; \ MKL_INT m, n, lda, ldb, ldc; \ const EIGTYPE *a, *b; \ - MKLTYPE alpha_, beta_; \ + EIGTYPE beta(1); \ MatrixX##EIGPREFIX b_tmp; \ Matrix a_tmp; \ - EIGTYPE myone(1); \ \ /* Set m, n, k */ \ m = (MKL_INT)rows; \ n = (MKL_INT)cols; \ \ -/* Set alpha_ & beta_ */ \ - assign_scalar_eig2mkl(alpha_, alpha); \ - assign_scalar_eig2mkl(beta_, myone); \ -\ /* Set lda, ldb, ldc */ \ lda = (MKL_INT)rhsStride; \ ldb = (MKL_INT)lhsStride; \ @@ -279,14 +259,14 @@ struct product_selfadjoint_matrix map_x(_rhs,size,1); \ x_tmp=map_x.conjugate(); \ x_ptr=x_tmp.data(); \ } else x_ptr=_rhs; \ - MKLFUNC(&uplo, &n, &alpha_, (const MKLTYPE*)lhs, &lda, (const MKLTYPE*)x_ptr, &incx, &beta_, (MKLTYPE*)res, &incy); \ + MKLFUNC(&uplo, &n, &numext::real_ref(alpha), (const MKLTYPE*)lhs, &lda, (const MKLTYPE*)x_ptr, &incx, &numext::real_ref(beta), (MKLTYPE*)res, &incy); \ }\ }; -EIGEN_MKL_SYMV_SPECIALIZATION(double, double, dsymv) -EIGEN_MKL_SYMV_SPECIALIZATION(float, float, ssymv) -EIGEN_MKL_SYMV_SPECIALIZATION(dcomplex, MKL_Complex16, zhemv) -EIGEN_MKL_SYMV_SPECIALIZATION(scomplex, MKL_Complex8, chemv) +EIGEN_MKL_SYMV_SPECIALIZATION(double, double, dsymv_) +EIGEN_MKL_SYMV_SPECIALIZATION(float, float, ssymv_) +EIGEN_MKL_SYMV_SPECIALIZATION(dcomplex, double, zhemv_) +EIGEN_MKL_SYMV_SPECIALIZATION(scomplex, float, chemv_) } // end namespace internal diff --git a/Eigen/src/Core/products/TriangularMatrixMatrix_MKL.h b/Eigen/src/Core/products/TriangularMatrixMatrix_MKL.h index d9e7cf852..31f6d2007 100644 --- a/Eigen/src/Core/products/TriangularMatrixMatrix_MKL.h +++ b/Eigen/src/Core/products/TriangularMatrixMatrix_MKL.h @@ -109,7 +109,8 @@ struct product_triangular_matrix_matrix_trmm(alpha_, alpha); \ \ /* Set m, n */ \ m = (MKL_INT)diagSize; \ @@ -175,7 +172,7 @@ struct product_triangular_matrix_matrix_trmm > res_tmp(res,rows,cols,OuterStride<>(resStride)); \ @@ -184,9 +181,9 @@ struct product_triangular_matrix_matrix_trmm(alpha_, alpha); \ \ /* Set m, n */ \ m = (MKL_INT)rows; \ @@ -289,7 +282,7 @@ struct product_triangular_matrix_matrix_trmm > res_tmp(res,rows,cols,OuterStride<>(resStride)); \ @@ -298,9 +291,9 @@ struct product_triangular_matrix_matrix_trmm(alpha_, alpha); \ - assign_scalar_eig2mkl(beta_, EIGTYPE(1)); \ + EIGTYPE beta(1); \ \ /* Set m, n */ \ n = (MKL_INT)size; \ @@ -123,10 +121,10 @@ struct triangular_matrix_vector_product_trmv(alpha_, alpha); \ - assign_scalar_eig2mkl(beta_, EIGTYPE(1)); \ + EIGTYPE beta(1); \ \ /* Set m, n */ \ n = (MKL_INT)size; \ @@ -207,10 +203,10 @@ struct triangular_matrix_vector_product_trmv dcomplex; typedef std::complex scomplex; +#if defined(EIGEN_USE_BLAS) && !defined(EIGEN_USE_MKL) +typedef int MKL_INT; +#endif + namespace internal { template @@ -125,6 +129,7 @@ static inline void assign_conj_scalar_eig2mkl(MKLType& mklScalar, const EigenTyp mklScalar=eigenScalar; } +#ifdef EIGEN_USE_MKL template <> inline void assign_scalar_eig2mkl(MKL_Complex16& mklScalar, const dcomplex& eigenScalar) { mklScalar.real=eigenScalar.real(); @@ -148,11 +153,14 @@ inline void assign_conj_scalar_eig2mkl(MKL_Complex8& mklS mklScalar.real=eigenScalar.real(); mklScalar.imag=-eigenScalar.imag(); } +#endif } // end namespace internal } // end namespace Eigen +#if defined(EIGEN_USE_BLAS) && !defined(EIGEN_USE_MKL) +#include "../../misc/blas.h" #endif #endif // EIGEN_MKL_SUPPORT_H From 8191f373befc6d02e473d99ce0d86e92ee3a8736 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Mon, 11 Apr 2016 15:37:16 +0200 Subject: [PATCH 08/15] Silent unused warning. --- Eigen/src/Core/products/GeneralMatrixMatrixTriangular_MKL.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_MKL.h b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_MKL.h index 1cdf48fbf..91b949137 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_MKL.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_MKL.h @@ -81,7 +81,7 @@ struct general_matrix_matrix_rankupdate& /*blocking*/) \ + const EIGTYPE* /*rhs*/, Index /*rhsStride*/, EIGTYPE* res, Index resStride, EIGTYPE alpha, level3_blocking& /*blocking*/) \ { \ /* typedef Matrix MatrixRhs;*/ \ \ @@ -106,7 +106,7 @@ struct general_matrix_matrix_rankupdate& /*blocking*/) \ + const EIGTYPE* /*rhs*/, Index /*rhsStride*/, EIGTYPE* res, Index resStride, EIGTYPE alpha, level3_blocking& /*blocking*/) \ { \ typedef Matrix MatrixType; \ \ From ddabc992faad25b8c1fca0d0c5ae35ea34e778a4 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Mon, 11 Apr 2016 15:52:01 +0200 Subject: [PATCH 09/15] Fix long to int conversion in BLAS API. --- .../GeneralMatrixMatrixTriangular_MKL.h | 20 +++--- .../Core/products/GeneralMatrixMatrix_MKL.h | 22 +++--- .../Core/products/GeneralMatrixVector_MKL.h | 11 +-- .../products/SelfadjointMatrixMatrix_MKL.h | 72 +++++++++---------- .../products/SelfadjointMatrixVector_MKL.h | 6 +- .../products/TriangularMatrixMatrix_MKL.h | 36 +++++----- .../products/TriangularMatrixVector_MKL.h | 48 ++++++------- .../products/TriangularSolverMatrix_MKL.h | 24 +++---- Eigen/src/Core/util/MKL_support.h | 4 +- 9 files changed, 123 insertions(+), 120 deletions(-) diff --git a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_MKL.h b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_MKL.h index 91b949137..6c835372c 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_MKL.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_MKL.h @@ -72,7 +72,7 @@ EIGEN_MKL_RANKUPDATE_SPECIALIZE(float) // EIGEN_MKL_RANKUPDATE_SPECIALIZE(scomplex) // SYRK for float/double -#define EIGEN_MKL_RANKUPDATE_R(EIGTYPE, MKLTYPE, MKLFUNC) \ +#define EIGEN_MKL_RANKUPDATE_R(EIGTYPE, BLASTYPE, MKLFUNC) \ template \ struct general_matrix_matrix_rankupdate { \ enum { \ @@ -85,19 +85,19 @@ struct general_matrix_matrix_rankupdate MatrixRhs;*/ \ \ - MKL_INT lda=lhsStride, ldc=resStride, n=size, k=depth; \ + BlasIndex lda=convert_index(lhsStride), ldc=convert_index(resStride), n=convert_index(size), k=convert_index(depth); \ char uplo=(IsLower) ? 'L' : 'U', trans=(AStorageOrder==RowMajor) ? 'T':'N'; \ - MKLTYPE alpha_, beta_; \ + BLASTYPE alpha_, beta_; \ \ /* Set alpha_ & beta_ */ \ - assign_scalar_eig2mkl(alpha_, alpha); \ - assign_scalar_eig2mkl(beta_, EIGTYPE(1)); \ + assign_scalar_eig2mkl(alpha_, alpha); \ + assign_scalar_eig2mkl(beta_, EIGTYPE(1)); \ MKLFUNC(&uplo, &trans, &n, &k, &alpha_, lhs, &lda, &beta_, res, &ldc); \ } \ }; // HERK for complex data -#define EIGEN_MKL_RANKUPDATE_C(EIGTYPE, MKLTYPE, RTYPE, MKLFUNC) \ +#define EIGEN_MKL_RANKUPDATE_C(EIGTYPE, BLASTYPE, RTYPE, MKLFUNC) \ template \ struct general_matrix_matrix_rankupdate { \ enum { \ @@ -110,14 +110,14 @@ struct general_matrix_matrix_rankupdate MatrixType; \ \ - MKL_INT lda=lhsStride, ldc=resStride, n=size, k=depth; \ + BlasIndex lda=convert_index(lhsStride), ldc=convert_index(resStride), n=convert_index(size), k=convert_index(depth); \ char uplo=(IsLower) ? 'L' : 'U', trans=(AStorageOrder==RowMajor) ? 'C':'N'; \ RTYPE alpha_, beta_; \ const EIGTYPE* a_ptr; \ \ /* Set alpha_ & beta_ */ \ -/* assign_scalar_eig2mkl(alpha_, alpha); */\ -/* assign_scalar_eig2mkl(beta_, EIGTYPE(1));*/ \ +/* assign_scalar_eig2mkl(alpha_, alpha); */\ +/* assign_scalar_eig2mkl(beta_, EIGTYPE(1));*/ \ alpha_ = alpha.real(); \ beta_ = 1.0; \ /* Copy with conjugation in some cases*/ \ @@ -128,7 +128,7 @@ struct general_matrix_matrix_rankupdate(rows); \ + n = convert_index(cols); \ + k = convert_index(depth); \ \ /* Set lda, ldb, ldc */ \ - lda = (MKL_INT)lhsStride; \ - ldb = (MKL_INT)rhsStride; \ - ldc = (MKL_INT)resStride; \ + lda = convert_index(lhsStride); \ + ldb = convert_index(rhsStride); \ + ldc = convert_index(resStride); \ \ /* Set a, b, c */ \ if ((LhsStorageOrder==ColMajor) && (ConjugateLhs)) { \ Map > lhs(_lhs,m,k,OuterStride<>(lhsStride)); \ a_tmp = lhs.conjugate(); \ a = a_tmp.data(); \ - lda = a_tmp.outerStride(); \ + lda = convert_index(a_tmp.outerStride()); \ } else a = _lhs; \ \ if ((RhsStorageOrder==ColMajor) && (ConjugateRhs)) { \ Map > rhs(_rhs,k,n,OuterStride<>(rhsStride)); \ b_tmp = rhs.conjugate(); \ b = b_tmp.data(); \ - ldb = b_tmp.outerStride(); \ + ldb = convert_index(b_tmp.outerStride()); \ } else b = _rhs; \ \ - MKLPREFIX##gemm_(&transa, &transb, &m, &n, &k, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &numext::real_ref(beta), (MKLTYPE*)res, &ldc); \ + MKLPREFIX##gemm_(&transa, &transb, &m, &n, &k, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \ }}; GEMM_SPECIALIZATION(double, d, double, d) diff --git a/Eigen/src/Core/products/GeneralMatrixVector_MKL.h b/Eigen/src/Core/products/GeneralMatrixVector_MKL.h index fa5c9b6a0..c447c4aed 100644 --- a/Eigen/src/Core/products/GeneralMatrixVector_MKL.h +++ b/Eigen/src/Core/products/GeneralMatrixVector_MKL.h @@ -85,7 +85,7 @@ EIGEN_MKL_GEMV_SPECIALIZE(float) EIGEN_MKL_GEMV_SPECIALIZE(dcomplex) EIGEN_MKL_GEMV_SPECIALIZE(scomplex) -#define EIGEN_MKL_GEMV_SPECIALIZATION(EIGTYPE,MKLTYPE,MKLPREFIX) \ +#define EIGEN_MKL_GEMV_SPECIALIZATION(EIGTYPE,BLASTYPE,MKLPREFIX) \ template \ struct general_matrix_vector_product_gemv \ { \ @@ -97,13 +97,14 @@ static void run( \ const EIGTYPE* rhs, Index rhsIncr, \ EIGTYPE* res, Index resIncr, EIGTYPE alpha) \ { \ - MKL_INT m=rows, n=cols, lda=lhsStride, incx=rhsIncr, incy=resIncr; \ + BlasIndex m=convert_index(rows), n=convert_index(cols), \ + lda=convert_index(lhsStride), incx=convert_index(rhsIncr), incy=convert_index(resIncr); \ const EIGTYPE beta(1); \ const EIGTYPE *x_ptr; \ char trans=(LhsStorageOrder==ColMajor) ? 'N' : (ConjugateLhs) ? 'C' : 'T'; \ if (LhsStorageOrder==RowMajor) { \ - m = cols; \ - n = rows; \ + m = convert_index(cols); \ + n = convert_index(rows); \ }\ GEMVVector x_tmp; \ if (ConjugateRhs) { \ @@ -112,7 +113,7 @@ static void run( \ x_ptr=x_tmp.data(); \ incx=1; \ } else x_ptr=rhs; \ - MKLPREFIX##gemv_(&trans, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)lhs, &lda, (const MKLTYPE*)x_ptr, &incx, &numext::real_ref(beta), (MKLTYPE*)res, &incy); \ + MKLPREFIX##gemv_(&trans, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)lhs, &lda, (const BLASTYPE*)x_ptr, &incx, &numext::real_ref(beta), (BLASTYPE*)res, &incy); \ }\ }; diff --git a/Eigen/src/Core/products/SelfadjointMatrixMatrix_MKL.h b/Eigen/src/Core/products/SelfadjointMatrixMatrix_MKL.h index 9c2e811dd..b1176962b 100644 --- a/Eigen/src/Core/products/SelfadjointMatrixMatrix_MKL.h +++ b/Eigen/src/Core/products/SelfadjointMatrixMatrix_MKL.h @@ -40,7 +40,7 @@ namespace internal { /* Optimized selfadjoint matrix * matrix (?SYMM/?HEMM) product */ -#define EIGEN_MKL_SYMM_L(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \ +#define EIGEN_MKL_SYMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, MKLPREFIX) \ template \ @@ -55,20 +55,20 @@ struct product_selfadjoint_matrix& /*blocking*/) \ { \ char side='L', uplo='L'; \ - MKL_INT m, n, lda, ldb, ldc; \ + BlasIndex m, n, lda, ldb, ldc; \ const EIGTYPE *a, *b; \ EIGTYPE beta(1); \ MatrixX##EIGPREFIX b_tmp; \ \ /* Set transpose options */ \ /* Set m, n, k */ \ - m = (MKL_INT)rows; \ - n = (MKL_INT)cols; \ + m = convert_index(rows); \ + n = convert_index(cols); \ \ /* Set lda, ldb, ldc */ \ - lda = (MKL_INT)lhsStride; \ - ldb = (MKL_INT)rhsStride; \ - ldc = (MKL_INT)resStride; \ + lda = convert_index(lhsStride); \ + ldb = convert_index(rhsStride); \ + ldc = convert_index(resStride); \ \ /* Set a, b, c */ \ if (LhsStorageOrder==RowMajor) uplo='U'; \ @@ -78,16 +78,16 @@ struct product_selfadjoint_matrix > rhs(_rhs,n,m,OuterStride<>(rhsStride)); \ b_tmp = rhs.adjoint(); \ b = b_tmp.data(); \ - ldb = b_tmp.outerStride(); \ + ldb = convert_index(b_tmp.outerStride()); \ } else b = _rhs; \ \ - MKLPREFIX##symm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &numext::real_ref(beta), (MKLTYPE*)res, &ldc); \ + MKLPREFIX##symm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \ \ } \ }; -#define EIGEN_MKL_HEMM_L(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \ +#define EIGEN_MKL_HEMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, MKLPREFIX) \ template \ @@ -101,7 +101,7 @@ struct product_selfadjoint_matrix& /*blocking*/) \ { \ char side='L', uplo='L'; \ - MKL_INT m, n, lda, ldb, ldc; \ + BlasIndex m, n, lda, ldb, ldc; \ const EIGTYPE *a, *b; \ EIGTYPE beta(1); \ MatrixX##EIGPREFIX b_tmp; \ @@ -109,13 +109,13 @@ struct product_selfadjoint_matrix(rows); \ + n = convert_index(cols); \ \ /* Set lda, ldb, ldc */ \ - lda = (MKL_INT)lhsStride; \ - ldb = (MKL_INT)rhsStride; \ - ldc = (MKL_INT)resStride; \ + lda = convert_index(lhsStride); \ + ldb = convert_index(rhsStride); \ + ldc = convert_index(resStride); \ \ /* Set a, b, c */ \ if (((LhsStorageOrder==ColMajor) && ConjugateLhs) || ((LhsStorageOrder==RowMajor) && (!ConjugateLhs))) { \ @@ -141,10 +141,10 @@ struct product_selfadjoint_matrix(b_tmp.outerStride()); \ } \ \ - MKLPREFIX##hemm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &numext::real_ref(beta), (MKLTYPE*)res, &ldc); \ + MKLPREFIX##hemm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \ \ } \ }; @@ -157,7 +157,7 @@ EIGEN_MKL_HEMM_L(scomplex, float, cf, c) /* Optimized matrix * selfadjoint matrix (?SYMM/?HEMM) product */ -#define EIGEN_MKL_SYMM_R(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \ +#define EIGEN_MKL_SYMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, MKLPREFIX) \ template \ @@ -172,19 +172,19 @@ struct product_selfadjoint_matrix& /*blocking*/) \ { \ char side='R', uplo='L'; \ - MKL_INT m, n, lda, ldb, ldc; \ + BlasIndex m, n, lda, ldb, ldc; \ const EIGTYPE *a, *b; \ EIGTYPE beta(1); \ MatrixX##EIGPREFIX b_tmp; \ \ /* Set m, n, k */ \ - m = (MKL_INT)rows; \ - n = (MKL_INT)cols; \ + m = convert_index(rows); \ + n = convert_index(cols); \ \ /* Set lda, ldb, ldc */ \ - lda = (MKL_INT)rhsStride; \ - ldb = (MKL_INT)lhsStride; \ - ldc = (MKL_INT)resStride; \ + lda = convert_index(rhsStride); \ + ldb = convert_index(lhsStride); \ + ldc = convert_index(resStride); \ \ /* Set a, b, c */ \ if (RhsStorageOrder==RowMajor) uplo='U'; \ @@ -194,16 +194,16 @@ struct product_selfadjoint_matrix > lhs(_lhs,n,m,OuterStride<>(rhsStride)); \ b_tmp = lhs.adjoint(); \ b = b_tmp.data(); \ - ldb = b_tmp.outerStride(); \ + ldb = convert_index(b_tmp.outerStride()); \ } else b = _lhs; \ \ - MKLPREFIX##symm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &numext::real_ref(beta), (MKLTYPE*)res, &ldc); \ + MKLPREFIX##symm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \ \ } \ }; -#define EIGEN_MKL_HEMM_R(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \ +#define EIGEN_MKL_HEMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, MKLPREFIX) \ template \ @@ -217,27 +217,27 @@ struct product_selfadjoint_matrix& /*blocking*/) \ { \ char side='R', uplo='L'; \ - MKL_INT m, n, lda, ldb, ldc; \ + BlasIndex m, n, lda, ldb, ldc; \ const EIGTYPE *a, *b; \ EIGTYPE beta(1); \ MatrixX##EIGPREFIX b_tmp; \ Matrix a_tmp; \ \ /* Set m, n, k */ \ - m = (MKL_INT)rows; \ - n = (MKL_INT)cols; \ + m = convert_index(rows); \ + n = convert_index(cols); \ \ /* Set lda, ldb, ldc */ \ - lda = (MKL_INT)rhsStride; \ - ldb = (MKL_INT)lhsStride; \ - ldc = (MKL_INT)resStride; \ + lda = convert_index(rhsStride); \ + ldb = convert_index(lhsStride); \ + ldc = convert_index(resStride); \ \ /* Set a, b, c */ \ if (((RhsStorageOrder==ColMajor) && ConjugateRhs) || ((RhsStorageOrder==RowMajor) && (!ConjugateRhs))) { \ Map, 0, OuterStride<> > rhs(_rhs,n,n,OuterStride<>(rhsStride)); \ a_tmp = rhs.conjugate(); \ a = a_tmp.data(); \ - lda = a_tmp.outerStride(); \ + lda = convert_index(a_tmp.outerStride()); \ } else a = _rhs; \ if (RhsStorageOrder==RowMajor) uplo='U'; \ \ @@ -259,7 +259,7 @@ struct product_selfadjoint_matrix \ struct selfadjoint_matrix_vector_product_symv \ { \ @@ -85,7 +85,7 @@ const EIGTYPE* _rhs, EIGTYPE* res, EIGTYPE alpha) \ IsRowMajor = StorageOrder==RowMajor ? 1 : 0, \ IsLower = UpLo == Lower ? 1 : 0 \ }; \ - MKL_INT n=size, lda=lhsStride, incx=1, incy=1; \ + BlasIndex n=convert_index(size), lda=convert_index(lhsStride), incx=1, incy=1; \ EIGTYPE beta(1); \ const EIGTYPE *x_ptr; \ char uplo=(IsRowMajor) ? (IsLower ? 'U' : 'L') : (IsLower ? 'L' : 'U'); \ @@ -95,7 +95,7 @@ const EIGTYPE* _rhs, EIGTYPE* res, EIGTYPE alpha) \ x_tmp=map_x.conjugate(); \ x_ptr=x_tmp.data(); \ } else x_ptr=_rhs; \ - MKLFUNC(&uplo, &n, &numext::real_ref(alpha), (const MKLTYPE*)lhs, &lda, (const MKLTYPE*)x_ptr, &incx, &numext::real_ref(beta), (MKLTYPE*)res, &incy); \ + MKLFUNC(&uplo, &n, &numext::real_ref(alpha), (const BLASTYPE*)lhs, &lda, (const BLASTYPE*)x_ptr, &incx, &numext::real_ref(beta), (BLASTYPE*)res, &incy); \ }\ }; diff --git a/Eigen/src/Core/products/TriangularMatrixMatrix_MKL.h b/Eigen/src/Core/products/TriangularMatrixMatrix_MKL.h index 31f6d2007..47a8698a7 100644 --- a/Eigen/src/Core/products/TriangularMatrixMatrix_MKL.h +++ b/Eigen/src/Core/products/TriangularMatrixMatrix_MKL.h @@ -75,7 +75,7 @@ EIGEN_MKL_TRMM_SPECIALIZE(scomplex, true) EIGEN_MKL_TRMM_SPECIALIZE(scomplex, false) // implements col-major += alpha * op(triangular) * op(general) -#define EIGEN_MKL_TRMM_L(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \ +#define EIGEN_MKL_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, MKLPREFIX) \ template \ @@ -122,7 +122,7 @@ struct product_triangular_matrix_matrix_trmm > lhsMap(_lhs,rows,depth,OuterStride<>(lhsStride)); \ MatrixLhs aa_tmp=lhsMap.template triangularView(); \ - MKL_INT aStride = aa_tmp.outerStride(); \ + BlasIndex aStride = convert_index(aa_tmp.outerStride()); \ gemm_blocking_space gemm_blocking(_rows,_cols,_depth, 1, true); \ general_matrix_matrix_product::run( \ rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, resStride, alpha, gemm_blocking, 0); \ @@ -134,11 +134,11 @@ struct product_triangular_matrix_matrix_trmm(diagSize); \ + n = convert_index(cols); \ \ /* Set trans */ \ transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \ @@ -149,7 +149,7 @@ struct product_triangular_matrix_matrix_trmm(b_tmp.outerStride()); \ \ /* Set uplo */ \ uplo = IsLower ? 'L' : 'U'; \ @@ -165,14 +165,14 @@ struct product_triangular_matrix_matrix_trmm(a_tmp.outerStride()); \ } else { \ a = _lhs; \ - lda = lhsStride; \ + lda = convert_index(lhsStride); \ } \ /*std::cout << "TRMM_L: A is square! Go to MKL TRMM implementation! \n";*/ \ /* call ?trmm*/ \ - MKLPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (MKLTYPE*)b, &ldb); \ + MKLPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \ \ /* Add op(a_triangular)*b into res*/ \ Map > res_tmp(res,rows,cols,OuterStride<>(resStride)); \ @@ -186,7 +186,7 @@ EIGEN_MKL_TRMM_L(float, float, f, s) EIGEN_MKL_TRMM_L(scomplex, float, cf, c) // implements col-major += alpha * op(general) * op(triangular) -#define EIGEN_MKL_TRMM_R(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \ +#define EIGEN_MKL_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, MKLPREFIX) \ template \ @@ -232,7 +232,7 @@ struct product_triangular_matrix_matrix_trmm > rhsMap(_rhs,depth,cols, OuterStride<>(rhsStride)); \ MatrixRhs aa_tmp=rhsMap.template triangularView(); \ - MKL_INT aStride = aa_tmp.outerStride(); \ + BlasIndex aStride = convert_index(aa_tmp.outerStride()); \ gemm_blocking_space gemm_blocking(_rows,_cols,_depth, 1, true); \ general_matrix_matrix_product::run( \ rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, resStride, alpha, gemm_blocking, 0); \ @@ -244,11 +244,11 @@ struct product_triangular_matrix_matrix_trmm(rows); \ + n = convert_index(diagSize); \ \ /* Set trans */ \ transa = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \ @@ -259,7 +259,7 @@ struct product_triangular_matrix_matrix_trmm(b_tmp.outerStride()); \ \ /* Set uplo */ \ uplo = IsLower ? 'L' : 'U'; \ @@ -275,14 +275,14 @@ struct product_triangular_matrix_matrix_trmm(a_tmp.outerStride()); \ } else { \ a = _rhs; \ - lda = rhsStride; \ + lda = convert_index(rhsStride); \ } \ /*std::cout << "TRMM_R: A is square! Go to MKL TRMM implementation! \n";*/ \ /* call ?trmm*/ \ - MKLPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (MKLTYPE*)b, &ldb); \ + MKLPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \ \ /* Add op(a_triangular)*b into res*/ \ Map > res_tmp(res,rows,cols,OuterStride<>(resStride)); \ diff --git a/Eigen/src/Core/products/TriangularMatrixVector_MKL.h b/Eigen/src/Core/products/TriangularMatrixVector_MKL.h index 3aaea3457..17c9eeb44 100644 --- a/Eigen/src/Core/products/TriangularMatrixVector_MKL.h +++ b/Eigen/src/Core/products/TriangularMatrixVector_MKL.h @@ -71,7 +71,7 @@ EIGEN_MKL_TRMV_SPECIALIZE(dcomplex) EIGEN_MKL_TRMV_SPECIALIZE(scomplex) // implements col-major: res += alpha * op(triangular) * vector -#define EIGEN_MKL_TRMV_CM(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \ +#define EIGEN_MKL_TRMV_CM(EIGTYPE, BLASTYPE, EIGPREFIX, MKLPREFIX) \ template \ struct triangular_matrix_vector_product_trmv { \ enum { \ @@ -105,15 +105,15 @@ struct triangular_matrix_vector_product_trmv(size); \ + lda = convert_index(lhsStride); \ incx = 1; \ - incy = resIncr; \ + incy = convert_index(resIncr); \ \ /* Set uplo, trans and diag*/ \ trans = 'N'; \ @@ -121,10 +121,10 @@ struct triangular_matrix_vector_product_trmv(rows-size); \ + n = convert_index(size); \ } \ else { \ x += size; \ y = _res; \ a = _lhs + size*lda; \ - m = size; \ - n = cols-size; \ + m = convert_index(size); \ + n = convert_index(cols-size); \ } \ - MKLPREFIX##gemv_(&trans, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (const MKLTYPE*)x, &incx, &numext::real_ref(beta), (MKLTYPE*)y, &incy); \ + MKLPREFIX##gemv_(&trans, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, &numext::real_ref(beta), (BLASTYPE*)y, &incy); \ } \ } \ }; @@ -153,7 +153,7 @@ EIGEN_MKL_TRMV_CM(float, float, f, s) EIGEN_MKL_TRMV_CM(scomplex, float, cf, c) // implements row-major: res += alpha * op(triangular) * vector -#define EIGEN_MKL_TRMV_RM(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \ +#define EIGEN_MKL_TRMV_RM(EIGTYPE, BLASTYPE, EIGPREFIX, MKLPREFIX) \ template \ struct triangular_matrix_vector_product_trmv { \ enum { \ @@ -187,15 +187,15 @@ struct triangular_matrix_vector_product_trmv(size); \ + lda = convert_index(lhsStride); \ incx = 1; \ - incy = resIncr; \ + incy = convert_index(resIncr); \ \ /* Set uplo, trans and diag*/ \ trans = ConjLhs ? 'C' : 'T'; \ @@ -203,10 +203,10 @@ struct triangular_matrix_vector_product_trmv(rows-size); \ + n = convert_index(size); \ } \ else { \ x += size; \ y = _res; \ a = _lhs + size; \ - m = size; \ - n = cols-size; \ + m = convert_index(size); \ + n = convert_index(cols-size); \ } \ - MKLPREFIX##gemv_(&trans, &n, &m, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (const MKLTYPE*)x, &incx, &numext::real_ref(beta), (MKLTYPE*)y, &incy); \ + MKLPREFIX##gemv_(&trans, &n, &m, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, &numext::real_ref(beta), (BLASTYPE*)y, &incy); \ } \ } \ }; diff --git a/Eigen/src/Core/products/TriangularSolverMatrix_MKL.h b/Eigen/src/Core/products/TriangularSolverMatrix_MKL.h index 3677364e3..1f68a1cec 100644 --- a/Eigen/src/Core/products/TriangularSolverMatrix_MKL.h +++ b/Eigen/src/Core/products/TriangularSolverMatrix_MKL.h @@ -38,7 +38,7 @@ namespace Eigen { namespace internal { // implements LeftSide op(triangular)^-1 * general -#define EIGEN_MKL_TRSM_L(EIGTYPE, MKLTYPE, MKLPREFIX) \ +#define EIGEN_MKL_TRSM_L(EIGTYPE, BLASTYPE, MKLPREFIX) \ template \ struct triangular_solve_matrix \ { \ @@ -53,11 +53,11 @@ struct triangular_solve_matrix& /*blocking*/) \ { \ - MKL_INT m = size, n = otherSize, lda, ldb; \ + BlasIndex m = convert_index(size), n = convert_index(otherSize), lda, ldb; \ char side = 'L', uplo, diag='N', transa; \ /* Set alpha_ */ \ EIGTYPE alpha(1); \ - ldb = otherStride;\ + ldb = convert_index(otherStride);\ \ const EIGTYPE *a; \ /* Set trans */ \ @@ -73,14 +73,14 @@ struct triangular_solve_matrix(a_tmp.outerStride()); \ } else { \ a = _tri; \ - lda = triStride; \ + lda = convert_index(triStride); \ } \ if (IsUnitDiag) diag='U'; \ /* call ?trsm*/ \ - MKLPREFIX##trsm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (MKLTYPE*)_other, &ldb); \ + MKLPREFIX##trsm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)_other, &ldb); \ } \ }; @@ -91,7 +91,7 @@ EIGEN_MKL_TRSM_L(scomplex, float, c) // implements RightSide general * op(triangular)^-1 -#define EIGEN_MKL_TRSM_R(EIGTYPE, MKLTYPE, MKLPREFIX) \ +#define EIGEN_MKL_TRSM_R(EIGTYPE, BLASTYPE, MKLPREFIX) \ template \ struct triangular_solve_matrix \ { \ @@ -106,11 +106,11 @@ struct triangular_solve_matrix& /*blocking*/) \ { \ - MKL_INT m = otherSize, n = size, lda, ldb; \ + BlasIndex m = convert_index(otherSize), n = convert_index(size), lda, ldb; \ char side = 'R', uplo, diag='N', transa; \ /* Set alpha_ */ \ EIGTYPE alpha(1); \ - ldb = otherStride;\ + ldb = convert_index(otherStride);\ \ const EIGTYPE *a; \ /* Set trans */ \ @@ -126,14 +126,14 @@ struct triangular_solve_matrix(a_tmp.outerStride()); \ } else { \ a = _tri; \ - lda = triStride; \ + lda = convert_index(triStride); \ } \ if (IsUnitDiag) diag='U'; \ /* call ?trsm*/ \ - MKLPREFIX##trsm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (MKLTYPE*)_other, &ldb); \ + MKLPREFIX##trsm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)_other, &ldb); \ /*std::cout << "TRMS_L specialization!\n";*/ \ } \ }; diff --git a/Eigen/src/Core/util/MKL_support.h b/Eigen/src/Core/util/MKL_support.h index de7847fc4..382014e66 100644 --- a/Eigen/src/Core/util/MKL_support.h +++ b/Eigen/src/Core/util/MKL_support.h @@ -114,7 +114,9 @@ typedef std::complex dcomplex; typedef std::complex scomplex; #if defined(EIGEN_USE_BLAS) && !defined(EIGEN_USE_MKL) -typedef int MKL_INT; +typedef int BlasIndex; +#else +typedef MKL_INT BlasIndex; #endif namespace internal { From fec4c334bac76bfabd14168bf0ac668402f551a7 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Mon, 11 Apr 2016 16:04:09 +0200 Subject: [PATCH 10/15] Remove all references to MKL in BLAS wrappers. --- Eigen/Core | 16 ++--- ...h => GeneralMatrixMatrixTriangular_BLAS.h} | 34 ++++----- ...atrix_MKL.h => GeneralMatrixMatrix_BLAS.h} | 12 ++-- ...ector_MKL.h => GeneralMatrixVector_BLAS.h} | 30 ++++---- ...x_MKL.h => SelfadjointMatrixMatrix_BLAS.h} | 40 +++++------ ...r_MKL.h => SelfadjointMatrixVector_BLAS.h} | 30 ++++---- ...ix_MKL.h => TriangularMatrixMatrix_BLAS.h} | 70 +++++++++---------- ...or_MKL.h => TriangularMatrixVector_BLAS.h} | 54 +++++++------- ...ix_MKL.h => TriangularSolverMatrix_BLAS.h} | 32 ++++----- 9 files changed, 159 insertions(+), 159 deletions(-) rename Eigen/src/Core/products/{GeneralMatrixMatrixTriangular_MKL.h => GeneralMatrixMatrixTriangular_BLAS.h} (86%) rename Eigen/src/Core/products/{GeneralMatrixMatrix_MKL.h => GeneralMatrixMatrix_BLAS.h} (91%) rename Eigen/src/Core/products/{GeneralMatrixVector_MKL.h => GeneralMatrixVector_BLAS.h} (86%) rename Eigen/src/Core/products/{SelfadjointMatrixMatrix_MKL.h => SelfadjointMatrixMatrix_BLAS.h} (85%) rename Eigen/src/Core/products/{SelfadjointMatrixVector_MKL.h => SelfadjointMatrixVector_BLAS.h} (83%) rename Eigen/src/Core/products/{TriangularMatrixMatrix_MKL.h => TriangularMatrixMatrix_BLAS.h} (83%) rename Eigen/src/Core/products/{TriangularMatrixVector_MKL.h => TriangularMatrixVector_BLAS.h} (82%) rename Eigen/src/Core/products/{TriangularSolverMatrix_MKL.h => TriangularSolverMatrix_BLAS.h} (85%) diff --git a/Eigen/Core b/Eigen/Core index 1e62f3ec1..30a572479 100644 --- a/Eigen/Core +++ b/Eigen/Core @@ -450,14 +450,14 @@ using std::ptrdiff_t; #include "src/Core/ArrayWrapper.h" #ifdef EIGEN_USE_BLAS -#include "src/Core/products/GeneralMatrixMatrix_MKL.h" -#include "src/Core/products/GeneralMatrixVector_MKL.h" -#include "src/Core/products/GeneralMatrixMatrixTriangular_MKL.h" -#include "src/Core/products/SelfadjointMatrixMatrix_MKL.h" -#include "src/Core/products/SelfadjointMatrixVector_MKL.h" -#include "src/Core/products/TriangularMatrixMatrix_MKL.h" -#include "src/Core/products/TriangularMatrixVector_MKL.h" -#include "src/Core/products/TriangularSolverMatrix_MKL.h" +#include "src/Core/products/GeneralMatrixMatrix_BLAS.h" +#include "src/Core/products/GeneralMatrixVector_BLAS.h" +#include "src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h" +#include "src/Core/products/SelfadjointMatrixMatrix_BLAS.h" +#include "src/Core/products/SelfadjointMatrixVector_BLAS.h" +#include "src/Core/products/TriangularMatrixMatrix_BLAS.h" +#include "src/Core/products/TriangularMatrixVector_BLAS.h" +#include "src/Core/products/TriangularSolverMatrix_BLAS.h" #endif // EIGEN_USE_BLAS #ifdef EIGEN_USE_MKL_VML diff --git a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_MKL.h b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h similarity index 86% rename from Eigen/src/Core/products/GeneralMatrixMatrixTriangular_MKL.h rename to Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h index 6c835372c..943d25bd1 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_MKL.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h @@ -25,13 +25,13 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ******************************************************************************** - * Content : Eigen bindings to Intel(R) MKL + * Content : Eigen bindings to BLAS F77 * Level 3 BLAS SYRK/HERK implementation. ******************************************************************************** */ -#ifndef EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_MKL_H -#define EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_MKL_H +#ifndef EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_BLAS_H +#define EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_BLAS_H namespace Eigen { @@ -44,7 +44,7 @@ struct general_matrix_matrix_rankupdate : // try to go to BLAS specialization -#define EIGEN_MKL_RANKUPDATE_SPECIALIZE(Scalar) \ +#define EIGEN_BLAS_RANKUPDATE_SPECIALIZE(Scalar) \ template \ struct general_matrix_matrix_triangular_product \ struct general_matrix_matrix_rankupdate { \ enum { \ @@ -92,12 +92,12 @@ struct general_matrix_matrix_rankupdate(alpha_, alpha); \ assign_scalar_eig2mkl(beta_, EIGTYPE(1)); \ - MKLFUNC(&uplo, &trans, &n, &k, &alpha_, lhs, &lda, &beta_, res, &ldc); \ + BLASFUNC(&uplo, &trans, &n, &k, &alpha_, lhs, &lda, &beta_, res, &ldc); \ } \ }; // HERK for complex data -#define EIGEN_MKL_RANKUPDATE_C(EIGTYPE, BLASTYPE, RTYPE, MKLFUNC) \ +#define EIGEN_BLAS_RANKUPDATE_C(EIGTYPE, BLASTYPE, RTYPE, BLASFUNC) \ template \ struct general_matrix_matrix_rankupdate { \ enum { \ @@ -128,21 +128,21 @@ struct general_matrix_matrix_rankupdate(b_tmp.outerStride()); \ } else b = _rhs; \ \ - MKLPREFIX##gemm_(&transa, &transb, &m, &n, &k, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \ + BLASPREFIX##gemm_(&transa, &transb, &m, &n, &k, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \ }}; GEMM_SPECIALIZATION(double, d, double, d) @@ -112,4 +112,4 @@ GEMM_SPECIALIZATION(scomplex, cf, float, c) } // end namespace Eigen -#endif // EIGEN_GENERAL_MATRIX_MATRIX_MKL_H +#endif // EIGEN_GENERAL_MATRIX_MATRIX_BLAS_H diff --git a/Eigen/src/Core/products/GeneralMatrixVector_MKL.h b/Eigen/src/Core/products/GeneralMatrixVector_BLAS.h similarity index 86% rename from Eigen/src/Core/products/GeneralMatrixVector_MKL.h rename to Eigen/src/Core/products/GeneralMatrixVector_BLAS.h index c447c4aed..e3a5d5892 100644 --- a/Eigen/src/Core/products/GeneralMatrixVector_MKL.h +++ b/Eigen/src/Core/products/GeneralMatrixVector_BLAS.h @@ -25,13 +25,13 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ******************************************************************************** - * Content : Eigen bindings to Intel(R) MKL + * Content : Eigen bindings to BLAS F77 * General matrix-vector product functionality based on ?GEMV. ******************************************************************************** */ -#ifndef EIGEN_GENERAL_MATRIX_VECTOR_MKL_H -#define EIGEN_GENERAL_MATRIX_VECTOR_MKL_H +#ifndef EIGEN_GENERAL_MATRIX_VECTOR_BLAS_H +#define EIGEN_GENERAL_MATRIX_VECTOR_BLAS_H namespace Eigen { @@ -49,7 +49,7 @@ namespace internal { template struct general_matrix_vector_product_gemv; -#define EIGEN_MKL_GEMV_SPECIALIZE(Scalar) \ +#define EIGEN_BLAS_GEMV_SPECIALIZE(Scalar) \ template \ struct general_matrix_vector_product,ColMajor,ConjugateLhs,Scalar,const_blas_data_mapper,ConjugateRhs,Specialized> { \ static void run( \ @@ -80,12 +80,12 @@ static void run( \ } \ }; \ -EIGEN_MKL_GEMV_SPECIALIZE(double) -EIGEN_MKL_GEMV_SPECIALIZE(float) -EIGEN_MKL_GEMV_SPECIALIZE(dcomplex) -EIGEN_MKL_GEMV_SPECIALIZE(scomplex) +EIGEN_BLAS_GEMV_SPECIALIZE(double) +EIGEN_BLAS_GEMV_SPECIALIZE(float) +EIGEN_BLAS_GEMV_SPECIALIZE(dcomplex) +EIGEN_BLAS_GEMV_SPECIALIZE(scomplex) -#define EIGEN_MKL_GEMV_SPECIALIZATION(EIGTYPE,BLASTYPE,MKLPREFIX) \ +#define EIGEN_BLAS_GEMV_SPECIALIZATION(EIGTYPE,BLASTYPE,BLASPREFIX) \ template \ struct general_matrix_vector_product_gemv \ { \ @@ -113,17 +113,17 @@ static void run( \ x_ptr=x_tmp.data(); \ incx=1; \ } else x_ptr=rhs; \ - MKLPREFIX##gemv_(&trans, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)lhs, &lda, (const BLASTYPE*)x_ptr, &incx, &numext::real_ref(beta), (BLASTYPE*)res, &incy); \ + BLASPREFIX##gemv_(&trans, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)lhs, &lda, (const BLASTYPE*)x_ptr, &incx, &numext::real_ref(beta), (BLASTYPE*)res, &incy); \ }\ }; -EIGEN_MKL_GEMV_SPECIALIZATION(double, double, d) -EIGEN_MKL_GEMV_SPECIALIZATION(float, float, s) -EIGEN_MKL_GEMV_SPECIALIZATION(dcomplex, double, z) -EIGEN_MKL_GEMV_SPECIALIZATION(scomplex, float, c) +EIGEN_BLAS_GEMV_SPECIALIZATION(double, double, d) +EIGEN_BLAS_GEMV_SPECIALIZATION(float, float, s) +EIGEN_BLAS_GEMV_SPECIALIZATION(dcomplex, double, z) +EIGEN_BLAS_GEMV_SPECIALIZATION(scomplex, float, c) } // end namespase internal } // end namespace Eigen -#endif // EIGEN_GENERAL_MATRIX_VECTOR_MKL_H +#endif // EIGEN_GENERAL_MATRIX_VECTOR_BLAS_H diff --git a/Eigen/src/Core/products/SelfadjointMatrixMatrix_MKL.h b/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h similarity index 85% rename from Eigen/src/Core/products/SelfadjointMatrixMatrix_MKL.h rename to Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h index b1176962b..c3e37b1e0 100644 --- a/Eigen/src/Core/products/SelfadjointMatrixMatrix_MKL.h +++ b/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h @@ -25,13 +25,13 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // ******************************************************************************** - * Content : Eigen bindings to Intel(R) MKL + * Content : Eigen bindings to BLAS F77 * Self adjoint matrix * matrix product functionality based on ?SYMM/?HEMM. ******************************************************************************** */ -#ifndef EIGEN_SELFADJOINT_MATRIX_MATRIX_MKL_H -#define EIGEN_SELFADJOINT_MATRIX_MATRIX_MKL_H +#ifndef EIGEN_SELFADJOINT_MATRIX_MATRIX_BLAS_H +#define EIGEN_SELFADJOINT_MATRIX_MATRIX_BLAS_H namespace Eigen { @@ -40,7 +40,7 @@ namespace internal { /* Optimized selfadjoint matrix * matrix (?SYMM/?HEMM) product */ -#define EIGEN_MKL_SYMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, MKLPREFIX) \ +#define EIGEN_BLAS_SYMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \ template \ @@ -81,13 +81,13 @@ struct product_selfadjoint_matrix(b_tmp.outerStride()); \ } else b = _rhs; \ \ - MKLPREFIX##symm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \ + BLASPREFIX##symm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \ \ } \ }; -#define EIGEN_MKL_HEMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, MKLPREFIX) \ +#define EIGEN_BLAS_HEMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \ template \ @@ -144,20 +144,20 @@ struct product_selfadjoint_matrix(b_tmp.outerStride()); \ } \ \ - MKLPREFIX##hemm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \ + BLASPREFIX##hemm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \ \ } \ }; -EIGEN_MKL_SYMM_L(double, double, d, d) -EIGEN_MKL_SYMM_L(float, float, f, s) -EIGEN_MKL_HEMM_L(dcomplex, double, cd, z) -EIGEN_MKL_HEMM_L(scomplex, float, cf, c) +EIGEN_BLAS_SYMM_L(double, double, d, d) +EIGEN_BLAS_SYMM_L(float, float, f, s) +EIGEN_BLAS_HEMM_L(dcomplex, double, cd, z) +EIGEN_BLAS_HEMM_L(scomplex, float, cf, c) /* Optimized matrix * selfadjoint matrix (?SYMM/?HEMM) product */ -#define EIGEN_MKL_SYMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, MKLPREFIX) \ +#define EIGEN_BLAS_SYMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \ template \ @@ -197,13 +197,13 @@ struct product_selfadjoint_matrix(b_tmp.outerStride()); \ } else b = _lhs; \ \ - MKLPREFIX##symm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \ + BLASPREFIX##symm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \ \ } \ }; -#define EIGEN_MKL_HEMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, MKLPREFIX) \ +#define EIGEN_BLAS_HEMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \ template \ @@ -259,17 +259,17 @@ struct product_selfadjoint_matrix {}; -#define EIGEN_MKL_SYMV_SPECIALIZE(Scalar) \ +#define EIGEN_BLAS_SYMV_SPECIALIZE(Scalar) \ template \ struct selfadjoint_matrix_vector_product { \ static void run( \ @@ -66,12 +66,12 @@ static void run( \ } \ }; \ -EIGEN_MKL_SYMV_SPECIALIZE(double) -EIGEN_MKL_SYMV_SPECIALIZE(float) -EIGEN_MKL_SYMV_SPECIALIZE(dcomplex) -EIGEN_MKL_SYMV_SPECIALIZE(scomplex) +EIGEN_BLAS_SYMV_SPECIALIZE(double) +EIGEN_BLAS_SYMV_SPECIALIZE(float) +EIGEN_BLAS_SYMV_SPECIALIZE(dcomplex) +EIGEN_BLAS_SYMV_SPECIALIZE(scomplex) -#define EIGEN_MKL_SYMV_SPECIALIZATION(EIGTYPE,BLASTYPE,MKLFUNC) \ +#define EIGEN_BLAS_SYMV_SPECIALIZATION(EIGTYPE,BLASTYPE,BLASFUNC) \ template \ struct selfadjoint_matrix_vector_product_symv \ { \ @@ -95,17 +95,17 @@ const EIGTYPE* _rhs, EIGTYPE* res, EIGTYPE alpha) \ x_tmp=map_x.conjugate(); \ x_ptr=x_tmp.data(); \ } else x_ptr=_rhs; \ - MKLFUNC(&uplo, &n, &numext::real_ref(alpha), (const BLASTYPE*)lhs, &lda, (const BLASTYPE*)x_ptr, &incx, &numext::real_ref(beta), (BLASTYPE*)res, &incy); \ + BLASFUNC(&uplo, &n, &numext::real_ref(alpha), (const BLASTYPE*)lhs, &lda, (const BLASTYPE*)x_ptr, &incx, &numext::real_ref(beta), (BLASTYPE*)res, &incy); \ }\ }; -EIGEN_MKL_SYMV_SPECIALIZATION(double, double, dsymv_) -EIGEN_MKL_SYMV_SPECIALIZATION(float, float, ssymv_) -EIGEN_MKL_SYMV_SPECIALIZATION(dcomplex, double, zhemv_) -EIGEN_MKL_SYMV_SPECIALIZATION(scomplex, float, chemv_) +EIGEN_BLAS_SYMV_SPECIALIZATION(double, double, dsymv_) +EIGEN_BLAS_SYMV_SPECIALIZATION(float, float, ssymv_) +EIGEN_BLAS_SYMV_SPECIALIZATION(dcomplex, double, zhemv_) +EIGEN_BLAS_SYMV_SPECIALIZATION(scomplex, float, chemv_) } // end namespace internal } // end namespace Eigen -#endif // EIGEN_SELFADJOINT_MATRIX_VECTOR_MKL_H +#endif // EIGEN_SELFADJOINT_MATRIX_VECTOR_BLAS_H diff --git a/Eigen/src/Core/products/TriangularMatrixMatrix_MKL.h b/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h similarity index 83% rename from Eigen/src/Core/products/TriangularMatrixMatrix_MKL.h rename to Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h index 47a8698a7..aecded6bb 100644 --- a/Eigen/src/Core/products/TriangularMatrixMatrix_MKL.h +++ b/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h @@ -25,13 +25,13 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ******************************************************************************** - * Content : Eigen bindings to Intel(R) MKL + * Content : Eigen bindings to BLAS F77 * Triangular matrix * matrix product functionality based on ?TRMM. ******************************************************************************** */ -#ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H -#define EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H +#ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H +#define EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H namespace Eigen { @@ -50,7 +50,7 @@ struct product_triangular_matrix_matrix_trmm : // try to go to BLAS specialization -#define EIGEN_MKL_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) \ +#define EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) \ template \ @@ -65,17 +65,17 @@ struct product_triangular_matrix_matrix \ @@ -106,14 +106,14 @@ struct product_triangular_matrix_matrix_trmm MatrixLhs; \ typedef Matrix MatrixRhs; \ \ -/* Non-square case - doesn't fit to MKL ?TRMM. Fall to default triangular product or call MKL ?GEMM*/ \ +/* Non-square case - doesn't fit to BLAS ?TRMM. Fall to default triangular product or call BLAS ?GEMM*/ \ if (rows != depth) { \ \ /* FIXME handle mkl_domain_get_max_threads */ \ - /*int nthr = mkl_domain_get_max_threads(EIGEN_MKL_DOMAIN_BLAS);*/ int nthr = 1;\ + /*int nthr = mkl_domain_get_max_threads(EIGEN_BLAS_DOMAIN_BLAS);*/ int nthr = 1;\ \ if (((nthr==1) && (((std::max)(rows,depth)-diagSize)/(double)diagSize < 0.5))) { \ - /* Most likely no benefit to call TRMM or GEMM from MKL*/ \ + /* Most likely no benefit to call TRMM or GEMM from BLAS */ \ product_triangular_matrix_matrix::run( \ _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \ @@ -127,7 +127,7 @@ struct product_triangular_matrix_matrix_trmm::run( \ rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, resStride, alpha, gemm_blocking, 0); \ \ - /*std::cout << "TRMM_L: A is not square! Go to MKL GEMM implementation! " << nthr<<" \n";*/ \ + /*std::cout << "TRMM_L: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \ } \ return; \ } \ @@ -170,9 +170,9 @@ struct product_triangular_matrix_matrix_trmm(lhsStride); \ } \ - /*std::cout << "TRMM_L: A is square! Go to MKL TRMM implementation! \n";*/ \ + /*std::cout << "TRMM_L: A is square! Go to BLAS TRMM implementation! \n";*/ \ /* call ?trmm*/ \ - MKLPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \ + BLASPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \ \ /* Add op(a_triangular)*b into res*/ \ Map > res_tmp(res,rows,cols,OuterStride<>(resStride)); \ @@ -180,13 +180,13 @@ struct product_triangular_matrix_matrix_trmm \ @@ -217,13 +217,13 @@ struct product_triangular_matrix_matrix_trmm MatrixLhs; \ typedef Matrix MatrixRhs; \ \ -/* Non-square case - doesn't fit to MKL ?TRMM. Fall to default triangular product or call MKL ?GEMM*/ \ +/* Non-square case - doesn't fit to BLAS ?TRMM. Fall to default triangular product or call BLAS ?GEMM*/ \ if (cols != depth) { \ \ - int nthr = 1 /*mkl_domain_get_max_threads(EIGEN_MKL_DOMAIN_BLAS)*/; \ + int nthr = 1 /*mkl_domain_get_max_threads(EIGEN_BLAS_DOMAIN_BLAS)*/; \ \ if ((nthr==1) && (((std::max)(cols,depth)-diagSize)/(double)diagSize < 0.5)) { \ - /* Most likely no benefit to call TRMM or GEMM from MKL*/ \ + /* Most likely no benefit to call TRMM or GEMM from BLAS*/ \ product_triangular_matrix_matrix::run( \ _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \ @@ -237,7 +237,7 @@ struct product_triangular_matrix_matrix_trmm::run( \ rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, resStride, alpha, gemm_blocking, 0); \ \ - /*std::cout << "TRMM_R: A is not square! Go to MKL GEMM implementation! " << nthr<<" \n";*/ \ + /*std::cout << "TRMM_R: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \ } \ return; \ } \ @@ -280,9 +280,9 @@ struct product_triangular_matrix_matrix_trmm(rhsStride); \ } \ - /*std::cout << "TRMM_R: A is square! Go to MKL TRMM implementation! \n";*/ \ + /*std::cout << "TRMM_R: A is square! Go to BLAS TRMM implementation! \n";*/ \ /* call ?trmm*/ \ - MKLPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \ + BLASPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \ \ /* Add op(a_triangular)*b into res*/ \ Map > res_tmp(res,rows,cols,OuterStride<>(resStride)); \ @@ -290,13 +290,13 @@ struct product_triangular_matrix_matrix_trmm {}; -#define EIGEN_MKL_TRMV_SPECIALIZE(Scalar) \ +#define EIGEN_BLAS_TRMV_SPECIALIZE(Scalar) \ template \ struct triangular_matrix_vector_product { \ static void run(Index _rows, Index _cols, const Scalar* _lhs, Index lhsStride, \ @@ -65,13 +65,13 @@ struct triangular_matrix_vector_product \ struct triangular_matrix_vector_product_trmv { \ enum { \ @@ -121,11 +121,11 @@ struct triangular_matrix_vector_product_trmv(size); \ n = convert_index(cols-size); \ } \ - MKLPREFIX##gemv_(&trans, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, &numext::real_ref(beta), (BLASTYPE*)y, &incy); \ + BLASPREFIX##gemv_(&trans, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, &numext::real_ref(beta), (BLASTYPE*)y, &incy); \ } \ } \ }; -EIGEN_MKL_TRMV_CM(double, double, d, d) -EIGEN_MKL_TRMV_CM(dcomplex, double, cd, z) -EIGEN_MKL_TRMV_CM(float, float, f, s) -EIGEN_MKL_TRMV_CM(scomplex, float, cf, c) +EIGEN_BLAS_TRMV_CM(double, double, d, d) +EIGEN_BLAS_TRMV_CM(dcomplex, double, cd, z) +EIGEN_BLAS_TRMV_CM(float, float, f, s) +EIGEN_BLAS_TRMV_CM(scomplex, float, cf, c) // implements row-major: res += alpha * op(triangular) * vector -#define EIGEN_MKL_TRMV_RM(EIGTYPE, BLASTYPE, EIGPREFIX, MKLPREFIX) \ +#define EIGEN_BLAS_TRMV_RM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \ template \ struct triangular_matrix_vector_product_trmv { \ enum { \ @@ -203,11 +203,11 @@ struct triangular_matrix_vector_product_trmv(size); \ n = convert_index(cols-size); \ } \ - MKLPREFIX##gemv_(&trans, &n, &m, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, &numext::real_ref(beta), (BLASTYPE*)y, &incy); \ + BLASPREFIX##gemv_(&trans, &n, &m, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, &numext::real_ref(beta), (BLASTYPE*)y, &incy); \ } \ } \ }; -EIGEN_MKL_TRMV_RM(double, double, d, d) -EIGEN_MKL_TRMV_RM(dcomplex, double, cd, z) -EIGEN_MKL_TRMV_RM(float, float, f, s) -EIGEN_MKL_TRMV_RM(scomplex, float, cf, c) +EIGEN_BLAS_TRMV_RM(double, double, d, d) +EIGEN_BLAS_TRMV_RM(dcomplex, double, cd, z) +EIGEN_BLAS_TRMV_RM(float, float, f, s) +EIGEN_BLAS_TRMV_RM(scomplex, float, cf, c) } // end namespase internal } // end namespace Eigen -#endif // EIGEN_TRIANGULAR_MATRIX_VECTOR_MKL_H +#endif // EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H diff --git a/Eigen/src/Core/products/TriangularSolverMatrix_MKL.h b/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h similarity index 85% rename from Eigen/src/Core/products/TriangularSolverMatrix_MKL.h rename to Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h index 1f68a1cec..88c0fb794 100644 --- a/Eigen/src/Core/products/TriangularSolverMatrix_MKL.h +++ b/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h @@ -25,20 +25,20 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ******************************************************************************** - * Content : Eigen bindings to Intel(R) MKL + * Content : Eigen bindings to BLAS F77 * Triangular matrix * matrix product functionality based on ?TRMM. ******************************************************************************** */ -#ifndef EIGEN_TRIANGULAR_SOLVER_MATRIX_MKL_H -#define EIGEN_TRIANGULAR_SOLVER_MATRIX_MKL_H +#ifndef EIGEN_TRIANGULAR_SOLVER_MATRIX_BLAS_H +#define EIGEN_TRIANGULAR_SOLVER_MATRIX_BLAS_H namespace Eigen { namespace internal { // implements LeftSide op(triangular)^-1 * general -#define EIGEN_MKL_TRSM_L(EIGTYPE, BLASTYPE, MKLPREFIX) \ +#define EIGEN_BLAS_TRSM_L(EIGTYPE, BLASTYPE, BLASPREFIX) \ template \ struct triangular_solve_matrix \ { \ @@ -80,18 +80,18 @@ struct triangular_solve_matrix \ struct triangular_solve_matrix \ { \ @@ -133,19 +133,19 @@ struct triangular_solve_matrix Date: Mon, 11 Apr 2016 16:09:29 +0200 Subject: [PATCH 11/15] Cleanup obsolete assign_scalar_eig2mkl helper. --- .../GeneralMatrixMatrixTriangular_BLAS.h | 11 +---- Eigen/src/Core/util/MKL_support.h | 40 ------------------- 2 files changed, 2 insertions(+), 49 deletions(-) diff --git a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h index 943d25bd1..911df8ff3 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h @@ -87,12 +87,8 @@ struct general_matrix_matrix_rankupdate(lhsStride), ldc=convert_index(resStride), n=convert_index(size), k=convert_index(depth); \ char uplo=(IsLower) ? 'L' : 'U', trans=(AStorageOrder==RowMajor) ? 'T':'N'; \ - BLASTYPE alpha_, beta_; \ -\ -/* Set alpha_ & beta_ */ \ - assign_scalar_eig2mkl(alpha_, alpha); \ - assign_scalar_eig2mkl(beta_, EIGTYPE(1)); \ - BLASFUNC(&uplo, &trans, &n, &k, &alpha_, lhs, &lda, &beta_, res, &ldc); \ + EIGTYPE beta; \ + BLASFUNC(&uplo, &trans, &n, &k, &numext::real_ref(alpha), lhs, &lda, &numext::real_ref(beta), res, &ldc); \ } \ }; @@ -115,9 +111,6 @@ struct general_matrix_matrix_rankupdate(alpha_, alpha); */\ -/* assign_scalar_eig2mkl(beta_, EIGTYPE(1));*/ \ alpha_ = alpha.real(); \ beta_ = 1.0; \ /* Copy with conjugation in some cases*/ \ diff --git a/Eigen/src/Core/util/MKL_support.h b/Eigen/src/Core/util/MKL_support.h index 382014e66..bf47a626b 100644 --- a/Eigen/src/Core/util/MKL_support.h +++ b/Eigen/src/Core/util/MKL_support.h @@ -119,46 +119,6 @@ typedef int BlasIndex; typedef MKL_INT BlasIndex; #endif -namespace internal { - -template -static inline void assign_scalar_eig2mkl(MKLType& mklScalar, const EigenType& eigenScalar) { - mklScalar=eigenScalar; -} - -template -static inline void assign_conj_scalar_eig2mkl(MKLType& mklScalar, const EigenType& eigenScalar) { - mklScalar=eigenScalar; -} - -#ifdef EIGEN_USE_MKL -template <> -inline void assign_scalar_eig2mkl(MKL_Complex16& mklScalar, const dcomplex& eigenScalar) { - mklScalar.real=eigenScalar.real(); - mklScalar.imag=eigenScalar.imag(); -} - -template <> -inline void assign_scalar_eig2mkl(MKL_Complex8& mklScalar, const scomplex& eigenScalar) { - mklScalar.real=eigenScalar.real(); - mklScalar.imag=eigenScalar.imag(); -} - -template <> -inline void assign_conj_scalar_eig2mkl(MKL_Complex16& mklScalar, const dcomplex& eigenScalar) { - mklScalar.real=eigenScalar.real(); - mklScalar.imag=-eigenScalar.imag(); -} - -template <> -inline void assign_conj_scalar_eig2mkl(MKL_Complex8& mklScalar, const scomplex& eigenScalar) { - mklScalar.real=eigenScalar.real(); - mklScalar.imag=-eigenScalar.imag(); -} -#endif - -} // end namespace internal - } // end namespace Eigen #if defined(EIGEN_USE_BLAS) && !defined(EIGEN_USE_MKL) From 048343028371a8b7f79e4007a48caa8aff83e0de Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Mon, 11 Apr 2016 17:12:31 +0200 Subject: [PATCH 12/15] Move LAPACK declarations from blas.h to lapack.h and fix compatibility with EIGEN_USE_MKL --- Eigen/src/Core/util/MKL_support.h | 8 +- Eigen/src/misc/blas.h | 176 +++--------------------------- Eigen/src/misc/lapack.h | 152 ++++++++++++++++++++++++++ 3 files changed, 174 insertions(+), 162 deletions(-) create mode 100644 Eigen/src/misc/lapack.h diff --git a/Eigen/src/Core/util/MKL_support.h b/Eigen/src/Core/util/MKL_support.h index bf47a626b..8c9239b1d 100644 --- a/Eigen/src/Core/util/MKL_support.h +++ b/Eigen/src/Core/util/MKL_support.h @@ -113,15 +113,15 @@ namespace Eigen { typedef std::complex dcomplex; typedef std::complex scomplex; -#if defined(EIGEN_USE_BLAS) && !defined(EIGEN_USE_MKL) -typedef int BlasIndex; -#else +#if defined(EIGEN_USE_MKL) typedef MKL_INT BlasIndex; +#else +typedef int BlasIndex; #endif } // end namespace Eigen -#if defined(EIGEN_USE_BLAS) && !defined(EIGEN_USE_MKL) +#if defined(EIGEN_USE_BLAS) #include "../../misc/blas.h" #endif diff --git a/Eigen/src/misc/blas.h b/Eigen/src/misc/blas.h index ae0c393f1..25215b15e 100644 --- a/Eigen/src/misc/blas.h +++ b/Eigen/src/misc/blas.h @@ -30,15 +30,15 @@ int BLASFUNC(cdotcw) (int *, float *, int *, float *, int *, float*); int BLASFUNC(zdotuw) (int *, double *, int *, double *, int *, double*); int BLASFUNC(zdotcw) (int *, double *, int *, double *, int *, double*); -int BLASFUNC(saxpy) (const int *, const float *, const float *, const int *, float *, int *); -int BLASFUNC(daxpy) (const int *, const double *, const double *, const int *, double *, int *); -int BLASFUNC(qaxpy) (const int *, const double *, const double *, const int *, double *, int *); -int BLASFUNC(caxpy) (const int *, const float *, const float *, const int *, float *, int *); -int BLASFUNC(zaxpy) (const int *, const double *, const double *, const int *, double *, int *); -int BLASFUNC(xaxpy) (const int *, const double *, const double *, const int *, double *, int *); -int BLASFUNC(caxpyc)(const int *, const float *, const float *, const int *, float *, int *); -int BLASFUNC(zaxpyc)(const int *, const double *, const double *, const int *, double *, int *); -int BLASFUNC(xaxpyc)(const int *, const double *, const double *, const int *, double *, int *); +int BLASFUNC(saxpy) (const int *, const float *, const float *, const int *, float *, const int *); +int BLASFUNC(daxpy) (const int *, const double *, const double *, const int *, double *, const int *); +int BLASFUNC(qaxpy) (const int *, const double *, const double *, const int *, double *, const int *); +int BLASFUNC(caxpy) (const int *, const float *, const float *, const int *, float *, const int *); +int BLASFUNC(zaxpy) (const int *, const double *, const double *, const int *, double *, const int *); +int BLASFUNC(xaxpy) (const int *, const double *, const double *, const int *, double *, const int *); +int BLASFUNC(caxpyc)(const int *, const float *, const float *, const int *, float *, const int *); +int BLASFUNC(zaxpyc)(const int *, const double *, const double *, const int *, double *, const int *); +int BLASFUNC(xaxpyc)(const int *, const double *, const double *, const int *, double *, const int *); int BLASFUNC(scopy) (int *, float *, int *, float *, int *); int BLASFUNC(dcopy) (int *, double *, int *, double *, int *); @@ -229,9 +229,6 @@ int BLASFUNC(xtbsv) (char *, char *, char *, int *, int *, double *, int *, doub int BLASFUNC(ssymv) (const char *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *); int BLASFUNC(dsymv) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *); int BLASFUNC(qsymv) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *); -int BLASFUNC(csymv) (const char *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *); -int BLASFUNC(zsymv) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *); -int BLASFUNC(xsymv) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *); int BLASFUNC(sspmv) (char *, int *, float *, float *, float *, int *, float *, float *, int *); @@ -239,38 +236,17 @@ int BLASFUNC(dspmv) (char *, int *, double *, double *, double *, int *, double *, double *, int *); int BLASFUNC(qspmv) (char *, int *, double *, double *, double *, int *, double *, double *, int *); -int BLASFUNC(cspmv) (char *, int *, float *, float *, - float *, int *, float *, float *, int *); -int BLASFUNC(zspmv) (char *, int *, double *, double *, - double *, int *, double *, double *, int *); -int BLASFUNC(xspmv) (char *, int *, double *, double *, - double *, int *, double *, double *, int *); -int BLASFUNC(ssyr) (char *, int *, float *, float *, int *, - float *, int *); -int BLASFUNC(dsyr) (char *, int *, double *, double *, int *, - double *, int *); -int BLASFUNC(qsyr) (char *, int *, double *, double *, int *, - double *, int *); -int BLASFUNC(csyr) (char *, int *, float *, float *, int *, - float *, int *); -int BLASFUNC(zsyr) (char *, int *, double *, double *, int *, - double *, int *); -int BLASFUNC(xsyr) (char *, int *, double *, double *, int *, - double *, int *); +int BLASFUNC(ssyr) (const char *, const int *, const float *, const float *, const int *, float *, const int *); +int BLASFUNC(dsyr) (const char *, const int *, const double *, const double *, const int *, double *, const int *); +int BLASFUNC(qsyr) (const char *, const int *, const double *, const double *, const int *, double *, const int *); -int BLASFUNC(ssyr2) (char *, int *, float *, - float *, int *, float *, int *, float *, int *); -int BLASFUNC(dsyr2) (char *, int *, double *, - double *, int *, double *, int *, double *, int *); -int BLASFUNC(qsyr2) (char *, int *, double *, - double *, int *, double *, int *, double *, int *); -int BLASFUNC(csyr2) (char *, int *, float *, - float *, int *, float *, int *, float *, int *); -int BLASFUNC(zsyr2) (char *, int *, double *, - double *, int *, double *, int *, double *, int *); -int BLASFUNC(xsyr2) (char *, int *, double *, - double *, int *, double *, int *, double *, int *); +int BLASFUNC(ssyr2) (const char *, const int *, const float *, const float *, const int *, const float *, const int *, float *, const int *); +int BLASFUNC(dsyr2) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, double *, const int *); +int BLASFUNC(qsyr2) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, double *, const int *); +int BLASFUNC(csyr2) (const char *, const int *, const float *, const float *, const int *, const float *, const int *, float *, const int *); +int BLASFUNC(zsyr2) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, double *, const int *); +int BLASFUNC(xsyr2) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, double *, const int *); int BLASFUNC(sspr) (char *, int *, float *, float *, int *, float *); @@ -278,12 +254,6 @@ int BLASFUNC(dspr) (char *, int *, double *, double *, int *, double *); int BLASFUNC(qspr) (char *, int *, double *, double *, int *, double *); -int BLASFUNC(cspr) (char *, int *, float *, float *, int *, - float *); -int BLASFUNC(zspr) (char *, int *, double *, double *, int *, - double *); -int BLASFUNC(xspr) (char *, int *, double *, double *, int *, - double *); int BLASFUNC(sspr2) (char *, int *, float *, float *, int *, float *, int *, float *); @@ -462,116 +432,6 @@ int BLASFUNC(cher2m)(const char *, const char *, const char *, const int *, cons int BLASFUNC(zher2m)(const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double*, const int *, const double *, double *, const int *); int BLASFUNC(xher2m)(const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double*, const int *, const double *, double *, const int *); -int BLASFUNC(sgemt)(char *, int *, int *, float *, float *, int *, - float *, int *); -int BLASFUNC(dgemt)(char *, int *, int *, double *, double *, int *, - double *, int *); -int BLASFUNC(cgemt)(char *, int *, int *, float *, float *, int *, - float *, int *); -int BLASFUNC(zgemt)(char *, int *, int *, double *, double *, int *, - double *, int *); - -int BLASFUNC(sgema)(char *, char *, int *, int *, float *, - float *, int *, float *, float *, int *, float *, int *); -int BLASFUNC(dgema)(char *, char *, int *, int *, double *, - double *, int *, double*, double *, int *, double*, int *); -int BLASFUNC(cgema)(char *, char *, int *, int *, float *, - float *, int *, float *, float *, int *, float *, int *); -int BLASFUNC(zgema)(char *, char *, int *, int *, double *, - double *, int *, double*, double *, int *, double*, int *); - -int BLASFUNC(sgems)(char *, char *, int *, int *, float *, - float *, int *, float *, float *, int *, float *, int *); -int BLASFUNC(dgems)(char *, char *, int *, int *, double *, - double *, int *, double*, double *, int *, double*, int *); -int BLASFUNC(cgems)(char *, char *, int *, int *, float *, - float *, int *, float *, float *, int *, float *, int *); -int BLASFUNC(zgems)(char *, char *, int *, int *, double *, - double *, int *, double*, double *, int *, double*, int *); - -int BLASFUNC(sgetf2)(int *, int *, float *, int *, int *, int *); -int BLASFUNC(dgetf2)(int *, int *, double *, int *, int *, int *); -int BLASFUNC(qgetf2)(int *, int *, double *, int *, int *, int *); -int BLASFUNC(cgetf2)(int *, int *, float *, int *, int *, int *); -int BLASFUNC(zgetf2)(int *, int *, double *, int *, int *, int *); -int BLASFUNC(xgetf2)(int *, int *, double *, int *, int *, int *); - -int BLASFUNC(sgetrf)(int *, int *, float *, int *, int *, int *); -int BLASFUNC(dgetrf)(int *, int *, double *, int *, int *, int *); -int BLASFUNC(qgetrf)(int *, int *, double *, int *, int *, int *); -int BLASFUNC(cgetrf)(int *, int *, float *, int *, int *, int *); -int BLASFUNC(zgetrf)(int *, int *, double *, int *, int *, int *); -int BLASFUNC(xgetrf)(int *, int *, double *, int *, int *, int *); - -int BLASFUNC(slaswp)(int *, float *, int *, int *, int *, int *, int *); -int BLASFUNC(dlaswp)(int *, double *, int *, int *, int *, int *, int *); -int BLASFUNC(qlaswp)(int *, double *, int *, int *, int *, int *, int *); -int BLASFUNC(claswp)(int *, float *, int *, int *, int *, int *, int *); -int BLASFUNC(zlaswp)(int *, double *, int *, int *, int *, int *, int *); -int BLASFUNC(xlaswp)(int *, double *, int *, int *, int *, int *, int *); - -int BLASFUNC(sgetrs)(char *, int *, int *, float *, int *, int *, float *, int *, int *); -int BLASFUNC(dgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *); -int BLASFUNC(qgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *); -int BLASFUNC(cgetrs)(char *, int *, int *, float *, int *, int *, float *, int *, int *); -int BLASFUNC(zgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *); -int BLASFUNC(xgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *); - -int BLASFUNC(sgesv)(int *, int *, float *, int *, int *, float *, int *, int *); -int BLASFUNC(dgesv)(int *, int *, double *, int *, int *, double*, int *, int *); -int BLASFUNC(qgesv)(int *, int *, double *, int *, int *, double*, int *, int *); -int BLASFUNC(cgesv)(int *, int *, float *, int *, int *, float *, int *, int *); -int BLASFUNC(zgesv)(int *, int *, double *, int *, int *, double*, int *, int *); -int BLASFUNC(xgesv)(int *, int *, double *, int *, int *, double*, int *, int *); - -int BLASFUNC(spotf2)(char *, int *, float *, int *, int *); -int BLASFUNC(dpotf2)(char *, int *, double *, int *, int *); -int BLASFUNC(qpotf2)(char *, int *, double *, int *, int *); -int BLASFUNC(cpotf2)(char *, int *, float *, int *, int *); -int BLASFUNC(zpotf2)(char *, int *, double *, int *, int *); -int BLASFUNC(xpotf2)(char *, int *, double *, int *, int *); - -int BLASFUNC(spotrf)(char *, int *, float *, int *, int *); -int BLASFUNC(dpotrf)(char *, int *, double *, int *, int *); -int BLASFUNC(qpotrf)(char *, int *, double *, int *, int *); -int BLASFUNC(cpotrf)(char *, int *, float *, int *, int *); -int BLASFUNC(zpotrf)(char *, int *, double *, int *, int *); -int BLASFUNC(xpotrf)(char *, int *, double *, int *, int *); - -int BLASFUNC(slauu2)(char *, int *, float *, int *, int *); -int BLASFUNC(dlauu2)(char *, int *, double *, int *, int *); -int BLASFUNC(qlauu2)(char *, int *, double *, int *, int *); -int BLASFUNC(clauu2)(char *, int *, float *, int *, int *); -int BLASFUNC(zlauu2)(char *, int *, double *, int *, int *); -int BLASFUNC(xlauu2)(char *, int *, double *, int *, int *); - -int BLASFUNC(slauum)(char *, int *, float *, int *, int *); -int BLASFUNC(dlauum)(char *, int *, double *, int *, int *); -int BLASFUNC(qlauum)(char *, int *, double *, int *, int *); -int BLASFUNC(clauum)(char *, int *, float *, int *, int *); -int BLASFUNC(zlauum)(char *, int *, double *, int *, int *); -int BLASFUNC(xlauum)(char *, int *, double *, int *, int *); - -int BLASFUNC(strti2)(char *, char *, int *, float *, int *, int *); -int BLASFUNC(dtrti2)(char *, char *, int *, double *, int *, int *); -int BLASFUNC(qtrti2)(char *, char *, int *, double *, int *, int *); -int BLASFUNC(ctrti2)(char *, char *, int *, float *, int *, int *); -int BLASFUNC(ztrti2)(char *, char *, int *, double *, int *, int *); -int BLASFUNC(xtrti2)(char *, char *, int *, double *, int *, int *); - -int BLASFUNC(strtri)(char *, char *, int *, float *, int *, int *); -int BLASFUNC(dtrtri)(char *, char *, int *, double *, int *, int *); -int BLASFUNC(qtrtri)(char *, char *, int *, double *, int *, int *); -int BLASFUNC(ctrtri)(char *, char *, int *, float *, int *, int *); -int BLASFUNC(ztrtri)(char *, char *, int *, double *, int *, int *); -int BLASFUNC(xtrtri)(char *, char *, int *, double *, int *, int *); - -int BLASFUNC(spotri)(char *, int *, float *, int *, int *); -int BLASFUNC(dpotri)(char *, int *, double *, int *, int *); -int BLASFUNC(qpotri)(char *, int *, double *, int *, int *); -int BLASFUNC(cpotri)(char *, int *, float *, int *, int *); -int BLASFUNC(zpotri)(char *, int *, double *, int *, int *); -int BLASFUNC(xpotri)(char *, int *, double *, int *, int *); #ifdef __cplusplus } diff --git a/Eigen/src/misc/lapack.h b/Eigen/src/misc/lapack.h new file mode 100644 index 000000000..249f3575c --- /dev/null +++ b/Eigen/src/misc/lapack.h @@ -0,0 +1,152 @@ +#ifndef LAPACK_H +#define LAPACK_H + +#include "blas.h" + +#ifdef __cplusplus +extern "C" +{ +#endif + +int BLASFUNC(csymv) (const char *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *); +int BLASFUNC(zsymv) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *); +int BLASFUNC(xsymv) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *); + + +int BLASFUNC(cspmv) (char *, int *, float *, float *, + float *, int *, float *, float *, int *); +int BLASFUNC(zspmv) (char *, int *, double *, double *, + double *, int *, double *, double *, int *); +int BLASFUNC(xspmv) (char *, int *, double *, double *, + double *, int *, double *, double *, int *); + +int BLASFUNC(csyr) (char *, int *, float *, float *, int *, + float *, int *); +int BLASFUNC(zsyr) (char *, int *, double *, double *, int *, + double *, int *); +int BLASFUNC(xsyr) (char *, int *, double *, double *, int *, + double *, int *); + +int BLASFUNC(cspr) (char *, int *, float *, float *, int *, + float *); +int BLASFUNC(zspr) (char *, int *, double *, double *, int *, + double *); +int BLASFUNC(xspr) (char *, int *, double *, double *, int *, + double *); + +int BLASFUNC(sgemt)(char *, int *, int *, float *, float *, int *, + float *, int *); +int BLASFUNC(dgemt)(char *, int *, int *, double *, double *, int *, + double *, int *); +int BLASFUNC(cgemt)(char *, int *, int *, float *, float *, int *, + float *, int *); +int BLASFUNC(zgemt)(char *, int *, int *, double *, double *, int *, + double *, int *); + +int BLASFUNC(sgema)(char *, char *, int *, int *, float *, + float *, int *, float *, float *, int *, float *, int *); +int BLASFUNC(dgema)(char *, char *, int *, int *, double *, + double *, int *, double*, double *, int *, double*, int *); +int BLASFUNC(cgema)(char *, char *, int *, int *, float *, + float *, int *, float *, float *, int *, float *, int *); +int BLASFUNC(zgema)(char *, char *, int *, int *, double *, + double *, int *, double*, double *, int *, double*, int *); + +int BLASFUNC(sgems)(char *, char *, int *, int *, float *, + float *, int *, float *, float *, int *, float *, int *); +int BLASFUNC(dgems)(char *, char *, int *, int *, double *, + double *, int *, double*, double *, int *, double*, int *); +int BLASFUNC(cgems)(char *, char *, int *, int *, float *, + float *, int *, float *, float *, int *, float *, int *); +int BLASFUNC(zgems)(char *, char *, int *, int *, double *, + double *, int *, double*, double *, int *, double*, int *); + +int BLASFUNC(sgetf2)(int *, int *, float *, int *, int *, int *); +int BLASFUNC(dgetf2)(int *, int *, double *, int *, int *, int *); +int BLASFUNC(qgetf2)(int *, int *, double *, int *, int *, int *); +int BLASFUNC(cgetf2)(int *, int *, float *, int *, int *, int *); +int BLASFUNC(zgetf2)(int *, int *, double *, int *, int *, int *); +int BLASFUNC(xgetf2)(int *, int *, double *, int *, int *, int *); + +int BLASFUNC(sgetrf)(int *, int *, float *, int *, int *, int *); +int BLASFUNC(dgetrf)(int *, int *, double *, int *, int *, int *); +int BLASFUNC(qgetrf)(int *, int *, double *, int *, int *, int *); +int BLASFUNC(cgetrf)(int *, int *, float *, int *, int *, int *); +int BLASFUNC(zgetrf)(int *, int *, double *, int *, int *, int *); +int BLASFUNC(xgetrf)(int *, int *, double *, int *, int *, int *); + +int BLASFUNC(slaswp)(int *, float *, int *, int *, int *, int *, int *); +int BLASFUNC(dlaswp)(int *, double *, int *, int *, int *, int *, int *); +int BLASFUNC(qlaswp)(int *, double *, int *, int *, int *, int *, int *); +int BLASFUNC(claswp)(int *, float *, int *, int *, int *, int *, int *); +int BLASFUNC(zlaswp)(int *, double *, int *, int *, int *, int *, int *); +int BLASFUNC(xlaswp)(int *, double *, int *, int *, int *, int *, int *); + +int BLASFUNC(sgetrs)(char *, int *, int *, float *, int *, int *, float *, int *, int *); +int BLASFUNC(dgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *); +int BLASFUNC(qgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *); +int BLASFUNC(cgetrs)(char *, int *, int *, float *, int *, int *, float *, int *, int *); +int BLASFUNC(zgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *); +int BLASFUNC(xgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *); + +int BLASFUNC(sgesv)(int *, int *, float *, int *, int *, float *, int *, int *); +int BLASFUNC(dgesv)(int *, int *, double *, int *, int *, double*, int *, int *); +int BLASFUNC(qgesv)(int *, int *, double *, int *, int *, double*, int *, int *); +int BLASFUNC(cgesv)(int *, int *, float *, int *, int *, float *, int *, int *); +int BLASFUNC(zgesv)(int *, int *, double *, int *, int *, double*, int *, int *); +int BLASFUNC(xgesv)(int *, int *, double *, int *, int *, double*, int *, int *); + +int BLASFUNC(spotf2)(char *, int *, float *, int *, int *); +int BLASFUNC(dpotf2)(char *, int *, double *, int *, int *); +int BLASFUNC(qpotf2)(char *, int *, double *, int *, int *); +int BLASFUNC(cpotf2)(char *, int *, float *, int *, int *); +int BLASFUNC(zpotf2)(char *, int *, double *, int *, int *); +int BLASFUNC(xpotf2)(char *, int *, double *, int *, int *); + +int BLASFUNC(spotrf)(char *, int *, float *, int *, int *); +int BLASFUNC(dpotrf)(char *, int *, double *, int *, int *); +int BLASFUNC(qpotrf)(char *, int *, double *, int *, int *); +int BLASFUNC(cpotrf)(char *, int *, float *, int *, int *); +int BLASFUNC(zpotrf)(char *, int *, double *, int *, int *); +int BLASFUNC(xpotrf)(char *, int *, double *, int *, int *); + +int BLASFUNC(slauu2)(char *, int *, float *, int *, int *); +int BLASFUNC(dlauu2)(char *, int *, double *, int *, int *); +int BLASFUNC(qlauu2)(char *, int *, double *, int *, int *); +int BLASFUNC(clauu2)(char *, int *, float *, int *, int *); +int BLASFUNC(zlauu2)(char *, int *, double *, int *, int *); +int BLASFUNC(xlauu2)(char *, int *, double *, int *, int *); + +int BLASFUNC(slauum)(char *, int *, float *, int *, int *); +int BLASFUNC(dlauum)(char *, int *, double *, int *, int *); +int BLASFUNC(qlauum)(char *, int *, double *, int *, int *); +int BLASFUNC(clauum)(char *, int *, float *, int *, int *); +int BLASFUNC(zlauum)(char *, int *, double *, int *, int *); +int BLASFUNC(xlauum)(char *, int *, double *, int *, int *); + +int BLASFUNC(strti2)(char *, char *, int *, float *, int *, int *); +int BLASFUNC(dtrti2)(char *, char *, int *, double *, int *, int *); +int BLASFUNC(qtrti2)(char *, char *, int *, double *, int *, int *); +int BLASFUNC(ctrti2)(char *, char *, int *, float *, int *, int *); +int BLASFUNC(ztrti2)(char *, char *, int *, double *, int *, int *); +int BLASFUNC(xtrti2)(char *, char *, int *, double *, int *, int *); + +int BLASFUNC(strtri)(char *, char *, int *, float *, int *, int *); +int BLASFUNC(dtrtri)(char *, char *, int *, double *, int *, int *); +int BLASFUNC(qtrtri)(char *, char *, int *, double *, int *, int *); +int BLASFUNC(ctrtri)(char *, char *, int *, float *, int *, int *); +int BLASFUNC(ztrtri)(char *, char *, int *, double *, int *, int *); +int BLASFUNC(xtrtri)(char *, char *, int *, double *, int *, int *); + +int BLASFUNC(spotri)(char *, int *, float *, int *, int *); +int BLASFUNC(dpotri)(char *, int *, double *, int *, int *); +int BLASFUNC(qpotri)(char *, int *, double *, int *, int *); +int BLASFUNC(cpotri)(char *, int *, float *, int *, int *); +int BLASFUNC(zpotri)(char *, int *, double *, int *, int *); +int BLASFUNC(xpotri)(char *, int *, double *, int *, int *); + +#ifdef __cplusplus +} +#endif + +#endif From 91bf925fc17c50a7898d84e56ce3fbbd93e7d920 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Mon, 11 Apr 2016 17:13:01 +0200 Subject: [PATCH 13/15] Improve constness of level2 blas API. --- blas/common.h | 10 +++++----- blas/level1_impl.h | 6 +++--- blas/level2_cplx_impl.h | 13 +++++++------ blas/level2_impl.h | 33 +++++++++++++++++---------------- blas/level2_real_impl.h | 33 +++++++++++++++++---------------- lapack/lapack_common.h | 1 + 6 files changed, 50 insertions(+), 46 deletions(-) diff --git a/blas/common.h b/blas/common.h index acb50af1b..61d8344d9 100644 --- a/blas/common.h +++ b/blas/common.h @@ -10,8 +10,8 @@ #ifndef EIGEN_BLAS_COMMON_H #define EIGEN_BLAS_COMMON_H -#include -#include +#include "../Eigen/Core" +#include "../Eigen/Jacobi" #include @@ -19,8 +19,7 @@ #error the token SCALAR must be defined to compile this file #endif -#include - +#include "../Eigen/src/misc/blas.h" #define NOTR 0 #define TR 1 @@ -94,6 +93,7 @@ enum typedef Matrix PlainMatrixType; typedef Map, 0, OuterStride<> > MatrixType; +typedef Map, 0, OuterStride<> > ConstMatrixType; typedef Map, 0, InnerStride > StridedVectorType; typedef Map > CompactVectorType; @@ -141,7 +141,7 @@ T* get_compact_vector(T* x, int n, int incx) if(incx==1) return x; - T* ret = new Scalar[n]; + typename Eigen::internal::remove_const::type* ret = new Scalar[n]; if(incx<0) make_vector(ret,n) = make_vector(x,n,-incx).reverse(); else make_vector(ret,n) = make_vector(x,n, incx); return ret; diff --git a/blas/level1_impl.h b/blas/level1_impl.h index e623bd178..f857bfa20 100644 --- a/blas/level1_impl.h +++ b/blas/level1_impl.h @@ -9,11 +9,11 @@ #include "common.h" -int EIGEN_BLAS_FUNC(axpy)(int *n, RealScalar *palpha, RealScalar *px, int *incx, RealScalar *py, int *incy) +int EIGEN_BLAS_FUNC(axpy)(const int *n, const RealScalar *palpha, const RealScalar *px, const int *incx, RealScalar *py, const int *incy) { - Scalar* x = reinterpret_cast(px); + const Scalar* x = reinterpret_cast(px); Scalar* y = reinterpret_cast(py); - Scalar alpha = *reinterpret_cast(palpha); + Scalar alpha = *reinterpret_cast(palpha); if(*n<=0) return 0; diff --git a/blas/level2_cplx_impl.h b/blas/level2_cplx_impl.h index 2edc51596..e3ce61435 100644 --- a/blas/level2_cplx_impl.h +++ b/blas/level2_cplx_impl.h @@ -16,7 +16,8 @@ * where alpha and beta are scalars, x and y are n element vectors and * A is an n by n hermitian matrix. */ -int EIGEN_BLAS_FUNC(hemv)(char *uplo, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *px, int *incx, RealScalar *pbeta, RealScalar *py, int *incy) +int EIGEN_BLAS_FUNC(hemv)(const char *uplo, const int *n, const RealScalar *palpha, const RealScalar *pa, const int *lda, + const RealScalar *px, const int *incx, const RealScalar *pbeta, RealScalar *py, const int *incy) { typedef void (*functype)(int, const Scalar*, int, const Scalar*, Scalar*, Scalar); static const functype func[2] = { @@ -26,11 +27,11 @@ int EIGEN_BLAS_FUNC(hemv)(char *uplo, int *n, RealScalar *palpha, RealScalar *pa (internal::selfadjoint_matrix_vector_product::run), }; - Scalar* a = reinterpret_cast(pa); - Scalar* x = reinterpret_cast(px); + const Scalar* a = reinterpret_cast(pa); + const Scalar* x = reinterpret_cast(px); Scalar* y = reinterpret_cast(py); - Scalar alpha = *reinterpret_cast(palpha); - Scalar beta = *reinterpret_cast(pbeta); + Scalar alpha = *reinterpret_cast(palpha); + Scalar beta = *reinterpret_cast(pbeta); // check arguments int info = 0; @@ -45,7 +46,7 @@ int EIGEN_BLAS_FUNC(hemv)(char *uplo, int *n, RealScalar *palpha, RealScalar *pa if(*n==0) return 1; - Scalar* actual_x = get_compact_vector(x,*n,*incx); + const Scalar* actual_x = get_compact_vector(x,*n,*incx); Scalar* actual_y = get_compact_vector(y,*n,*incy); if(beta!=Scalar(1)) diff --git a/blas/level2_impl.h b/blas/level2_impl.h index d09db0cc6..173f40b44 100644 --- a/blas/level2_impl.h +++ b/blas/level2_impl.h @@ -23,7 +23,8 @@ struct general_matrix_vector_product_wrapper } }; -int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *incb, RealScalar *pbeta, RealScalar *pc, int *incc) +int EIGEN_BLAS_FUNC(gemv)(const char *opa, const int *m, const int *n, const RealScalar *palpha, + const RealScalar *pa, const int *lda, const RealScalar *pb, const int *incb, const RealScalar *pbeta, RealScalar *pc, const int *incc) { typedef void (*functype)(int, int, const Scalar *, int, const Scalar *, int , Scalar *, int, Scalar); static const functype func[4] = { @@ -36,11 +37,11 @@ int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealSca 0 }; - Scalar* a = reinterpret_cast(pa); - Scalar* b = reinterpret_cast(pb); + const Scalar* a = reinterpret_cast(pa); + const Scalar* b = reinterpret_cast(pb); Scalar* c = reinterpret_cast(pc); - Scalar alpha = *reinterpret_cast(palpha); - Scalar beta = *reinterpret_cast(pbeta); + Scalar alpha = *reinterpret_cast(palpha); + Scalar beta = *reinterpret_cast(pbeta); // check arguments int info = 0; @@ -62,7 +63,7 @@ int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealSca if(code!=NOTR) std::swap(actual_m,actual_n); - Scalar* actual_b = get_compact_vector(b,actual_n,*incb); + const Scalar* actual_b = get_compact_vector(b,actual_n,*incb); Scalar* actual_c = get_compact_vector(c,actual_m,*incc); if(beta!=Scalar(1)) @@ -82,7 +83,7 @@ int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealSca return 1; } -int EIGEN_BLAS_FUNC(trsv)(char *uplo, char *opa, char *diag, int *n, RealScalar *pa, int *lda, RealScalar *pb, int *incb) +int EIGEN_BLAS_FUNC(trsv)(const char *uplo, const char *opa, const char *diag, const int *n, const RealScalar *pa, const int *lda, RealScalar *pb, const int *incb) { typedef void (*functype)(int, const Scalar *, int, Scalar *); static const functype func[16] = { @@ -116,7 +117,7 @@ int EIGEN_BLAS_FUNC(trsv)(char *uplo, char *opa, char *diag, int *n, RealScalar 0 }; - Scalar* a = reinterpret_cast(pa); + const Scalar* a = reinterpret_cast(pa); Scalar* b = reinterpret_cast(pb); int info = 0; @@ -141,7 +142,7 @@ int EIGEN_BLAS_FUNC(trsv)(char *uplo, char *opa, char *diag, int *n, RealScalar -int EIGEN_BLAS_FUNC(trmv)(char *uplo, char *opa, char *diag, int *n, RealScalar *pa, int *lda, RealScalar *pb, int *incb) +int EIGEN_BLAS_FUNC(trmv)(const char *uplo, const char *opa, const char *diag, const int *n, const RealScalar *pa, const int *lda, RealScalar *pb, const int *incb) { typedef void (*functype)(int, int, const Scalar *, int, const Scalar *, int, Scalar *, int, const Scalar&); static const functype func[16] = { @@ -175,7 +176,7 @@ int EIGEN_BLAS_FUNC(trmv)(char *uplo, char *opa, char *diag, int *n, RealScalar 0 }; - Scalar* a = reinterpret_cast(pa); + const Scalar* a = reinterpret_cast(pa); Scalar* b = reinterpret_cast(pb); int info = 0; @@ -217,11 +218,11 @@ int EIGEN_BLAS_FUNC(trmv)(char *uplo, char *opa, char *diag, int *n, RealScalar int EIGEN_BLAS_FUNC(gbmv)(char *trans, int *m, int *n, int *kl, int *ku, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *px, int *incx, RealScalar *pbeta, RealScalar *py, int *incy) { - Scalar* a = reinterpret_cast(pa); - Scalar* x = reinterpret_cast(px); + const Scalar* a = reinterpret_cast(pa); + const Scalar* x = reinterpret_cast(px); Scalar* y = reinterpret_cast(py); - Scalar alpha = *reinterpret_cast(palpha); - Scalar beta = *reinterpret_cast(pbeta); + Scalar alpha = *reinterpret_cast(palpha); + Scalar beta = *reinterpret_cast(pbeta); int coeff_rows = *kl+*ku+1; int info = 0; @@ -244,7 +245,7 @@ int EIGEN_BLAS_FUNC(gbmv)(char *trans, int *m, int *n, int *kl, int *ku, RealSca if(OP(*trans)!=NOTR) std::swap(actual_m,actual_n); - Scalar* actual_x = get_compact_vector(x,actual_n,*incx); + const Scalar* actual_x = get_compact_vector(x,actual_n,*incx); Scalar* actual_y = get_compact_vector(y,actual_m,*incy); if(beta!=Scalar(1)) @@ -253,7 +254,7 @@ int EIGEN_BLAS_FUNC(gbmv)(char *trans, int *m, int *n, int *kl, int *ku, RealSca else make_vector(actual_y, actual_m) *= beta; } - MatrixType mat_coeffs(a,coeff_rows,*n,*lda); + ConstMatrixType mat_coeffs(a,coeff_rows,*n,*lda); int nb = std::min(*n,(*m)+(*ku)); for(int j=0; j::run), }; - Scalar* a = reinterpret_cast(pa); - Scalar* x = reinterpret_cast(px); + const Scalar* a = reinterpret_cast(pa); + const Scalar* x = reinterpret_cast(px); Scalar* y = reinterpret_cast(py); - Scalar alpha = *reinterpret_cast(palpha); - Scalar beta = *reinterpret_cast(pbeta); + Scalar alpha = *reinterpret_cast(palpha); + Scalar beta = *reinterpret_cast(pbeta); // check arguments int info = 0; @@ -39,7 +40,7 @@ int EIGEN_BLAS_FUNC(symv) (char *uplo, int *n, RealScalar *palpha, RealScalar *p if(*n==0) return 0; - Scalar* actual_x = get_compact_vector(x,*n,*incx); + const Scalar* actual_x = get_compact_vector(x,*n,*incx); Scalar* actual_y = get_compact_vector(y,*n,*incy); if(beta!=Scalar(1)) @@ -61,7 +62,7 @@ int EIGEN_BLAS_FUNC(symv) (char *uplo, int *n, RealScalar *palpha, RealScalar *p } // C := alpha*x*x' + C -int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *px, int *incx, RealScalar *pc, int *ldc) +int EIGEN_BLAS_FUNC(syr)(const char *uplo, const int *n, const RealScalar *palpha, const RealScalar *px, const int *incx, RealScalar *pc, const int *ldc) { typedef void (*functype)(int, Scalar*, int, const Scalar*, const Scalar*, const Scalar&); @@ -72,9 +73,9 @@ int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *px, (selfadjoint_rank1_update::run), }; - Scalar* x = reinterpret_cast(px); + const Scalar* x = reinterpret_cast(px); Scalar* c = reinterpret_cast(pc); - Scalar alpha = *reinterpret_cast(palpha); + Scalar alpha = *reinterpret_cast(palpha); int info = 0; if(UPLO(*uplo)==INVALID) info = 1; @@ -87,7 +88,7 @@ int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *px, if(*n==0 || alpha==Scalar(0)) return 1; // if the increment is not 1, let's copy it to a temporary vector to enable vectorization - Scalar* x_cpy = get_compact_vector(x,*n,*incx); + const Scalar* x_cpy = get_compact_vector(x,*n,*incx); int code = UPLO(*uplo); if(code>=2 || func[code]==0) @@ -101,7 +102,7 @@ int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *px, } // C := alpha*x*y' + alpha*y*x' + C -int EIGEN_BLAS_FUNC(syr2)(char *uplo, int *n, RealScalar *palpha, RealScalar *px, int *incx, RealScalar *py, int *incy, RealScalar *pc, int *ldc) +int EIGEN_BLAS_FUNC(syr2)(const char *uplo, const int *n, const RealScalar *palpha, const RealScalar *px, const int *incx, const RealScalar *py, const int *incy, RealScalar *pc, const int *ldc) { typedef void (*functype)(int, Scalar*, int, const Scalar*, const Scalar*, Scalar); static const functype func[2] = { @@ -111,10 +112,10 @@ int EIGEN_BLAS_FUNC(syr2)(char *uplo, int *n, RealScalar *palpha, RealScalar *px (internal::rank2_update_selector::run), }; - Scalar* x = reinterpret_cast(px); - Scalar* y = reinterpret_cast(py); + const Scalar* x = reinterpret_cast(px); + const Scalar* y = reinterpret_cast(py); Scalar* c = reinterpret_cast(pc); - Scalar alpha = *reinterpret_cast(palpha); + Scalar alpha = *reinterpret_cast(palpha); int info = 0; if(UPLO(*uplo)==INVALID) info = 1; @@ -128,8 +129,8 @@ int EIGEN_BLAS_FUNC(syr2)(char *uplo, int *n, RealScalar *palpha, RealScalar *px if(alpha==Scalar(0)) return 1; - Scalar* x_cpy = get_compact_vector(x,*n,*incx); - Scalar* y_cpy = get_compact_vector(y,*n,*incy); + const Scalar* x_cpy = get_compact_vector(x,*n,*incx); + const Scalar* y_cpy = get_compact_vector(y,*n,*incy); int code = UPLO(*uplo); if(code>=2 || func[code]==0) diff --git a/lapack/lapack_common.h b/lapack/lapack_common.h index a93598784..c872a813e 100644 --- a/lapack/lapack_common.h +++ b/lapack/lapack_common.h @@ -11,6 +11,7 @@ #define EIGEN_LAPACK_COMMON_H #include "../blas/common.h" +#include "../Eigen/src/misc/lapack.h" #define EIGEN_LAPACK_FUNC(FUNC,ARGLIST) \ extern "C" { int EIGEN_BLAS_FUNC(FUNC) ARGLIST; } \ From 1744b5b5d2a488706cb26ff608741548d4853aa4 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Mon, 11 Apr 2016 17:16:07 +0200 Subject: [PATCH 14/15] Update doc regarding the genericity of EIGEN_USE_BLAS --- doc/UsingIntelMKL.dox | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/UsingIntelMKL.dox b/doc/UsingIntelMKL.dox index 02c62ad85..dbe559e53 100644 --- a/doc/UsingIntelMKL.dox +++ b/doc/UsingIntelMKL.dox @@ -55,7 +55,7 @@ Operations on other scalar types or mixing reals and complexes will continue to In addition you can choose which parts will be substituted by defining one or multiple of the following macros: - + From 833efb39bfe4957934982112fe435ab30a0c3b4f Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Mon, 11 Apr 2016 11:03:56 -0700 Subject: [PATCH 15/15] Added epsilon, dummy_precision, infinity and quiet_NaN NumTraits for fp16 --- Eigen/src/Core/arch/CUDA/Half.h | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/Eigen/src/Core/arch/CUDA/Half.h b/Eigen/src/Core/arch/CUDA/Half.h index 3be7e88d7..281b8e4c6 100644 --- a/Eigen/src/Core/arch/CUDA/Half.h +++ b/Eigen/src/Core/arch/CUDA/Half.h @@ -366,13 +366,22 @@ template<> struct is_arithmetic { enum { value = true }; }; template<> struct NumTraits : GenericNumTraits { - EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE float dummy_precision() { return 1e-3f; } + EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::half epsilon() { + return internal::raw_uint16_to_half(0x0800); + } + EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::half dummy_precision() { return half(1e-3f); } EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::half highest() { return internal::raw_uint16_to_half(0x7bff); } EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::half lowest() { return internal::raw_uint16_to_half(0xfbff); } + EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::half infinity() { + return internal::raw_uint16_to_half(0x7c00); + } + EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::half quiet_NaN() { + return internal::raw_uint16_to_half(0x7c01); + } }; // Infinity/NaN checks.
\c EIGEN_USE_BLAS Enables the use of external BLAS level 2 and 3 routines (currently works with Intel MKL only)
\c EIGEN_USE_BLAS Enables the use of external BLAS level 2 and 3 routines (compatible with any F77 BLAS interface, not only Intel MKL)
\c EIGEN_USE_LAPACKE Enables the use of external Lapack routines via the Intel Lapacke C interface to Lapack (currently works with Intel MKL only)
\c EIGEN_USE_LAPACKE_STRICT Same as \c EIGEN_USE_LAPACKE but algorithm of lower robustness are disabled. This currently concerns only JacobiSVD which otherwise would be replaced by \c gesvd that is less robust than Jacobi rotations.
\c EIGEN_USE_MKL_VML Enables the use of Intel VML (vector operations)