Pulled latest update from trunk

This commit is contained in:
Benoit Steiner 2016-04-11 11:03:02 -07:00
commit e939b087fe
26 changed files with 716 additions and 794 deletions

View File

@ -450,14 +450,14 @@ using std::ptrdiff_t;
#include "src/Core/ArrayWrapper.h"
#ifdef EIGEN_USE_BLAS
#include "src/Core/products/GeneralMatrixMatrix_MKL.h"
#include "src/Core/products/GeneralMatrixVector_MKL.h"
#include "src/Core/products/GeneralMatrixMatrixTriangular_MKL.h"
#include "src/Core/products/SelfadjointMatrixMatrix_MKL.h"
#include "src/Core/products/SelfadjointMatrixVector_MKL.h"
#include "src/Core/products/TriangularMatrixMatrix_MKL.h"
#include "src/Core/products/TriangularMatrixVector_MKL.h"
#include "src/Core/products/TriangularSolverMatrix_MKL.h"
#include "src/Core/products/GeneralMatrixMatrix_BLAS.h"
#include "src/Core/products/GeneralMatrixVector_BLAS.h"
#include "src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h"
#include "src/Core/products/SelfadjointMatrixMatrix_BLAS.h"
#include "src/Core/products/SelfadjointMatrixVector_BLAS.h"
#include "src/Core/products/TriangularMatrixMatrix_BLAS.h"
#include "src/Core/products/TriangularMatrixVector_BLAS.h"
#include "src/Core/products/TriangularSolverMatrix_BLAS.h"
#endif // EIGEN_USE_BLAS
#ifdef EIGEN_USE_MKL_VML

View File

@ -788,8 +788,8 @@ template<typename Dst, typename Src> void check_for_aliasing(const Dst &dst, con
template< typename DstXprType, typename SrcXprType, typename Functor, typename Scalar>
struct Assignment<DstXprType, SrcXprType, Functor, Dense2Dense, Scalar>
{
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
static void run(DstXprType &dst, const SrcXprType &src, const Functor &func)
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE void run(DstXprType &dst, const SrcXprType &src, const Functor &func)
{
eigen_assert(dst.rows() == src.rows() && dst.cols() == src.cols());
@ -806,8 +806,8 @@ struct Assignment<DstXprType, SrcXprType, Functor, Dense2Dense, Scalar>
template< typename DstXprType, typename SrcXprType, typename Functor, typename Scalar>
struct Assignment<DstXprType, SrcXprType, Functor, EigenBase2EigenBase, Scalar>
{
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<typename DstXprType::Scalar> &/*func*/)
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<typename DstXprType::Scalar> &/*func*/)
{
eigen_assert(dst.rows() == src.rows() && dst.cols() == src.cols());
src.evalTo(dst);

View File

@ -79,8 +79,8 @@ namespace cephes {
*/
template <typename Scalar, int N>
struct polevl {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
static Scalar run(const Scalar x, const Scalar coef[]) {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar run(const Scalar x, const Scalar coef[]) {
EIGEN_STATIC_ASSERT((N > 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
return polevl<Scalar, N - 1>::run(x, coef) * x + coef[N];
@ -89,8 +89,8 @@ struct polevl {
template <typename Scalar>
struct polevl<Scalar, 0> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
static Scalar run(const Scalar, const Scalar coef[]) {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar run(const Scalar, const Scalar coef[]) {
return coef[0];
}
};
@ -144,7 +144,7 @@ struct digamma_retval {
template <typename Scalar>
struct digamma_impl {
EIGEN_DEVICE_FUNC
static Scalar run(Scalar x) {
static EIGEN_STRONG_INLINE Scalar run(Scalar x) {
EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
THIS_TYPE_IS_NOT_SUPPORTED);
return Scalar(0);
@ -428,20 +428,20 @@ template <typename Scalar> struct igamma_impl; // predeclare igamma_impl
template <typename Scalar>
struct igamma_helper {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
static Scalar machep() { assert(false && "machep not supported for this type"); return 0.0; }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
static Scalar big() { assert(false && "big not supported for this type"); return 0.0; }
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar machep() { assert(false && "machep not supported for this type"); return 0.0; }
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar big() { assert(false && "big not supported for this type"); return 0.0; }
};
template <>
struct igamma_helper<float> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
static float machep() {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE float machep() {
return NumTraits<float>::epsilon() / 2; // 1.0 - machep == 1.0
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
static float big() {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE float big() {
// use epsneg (1.0 - epsneg == 1.0)
return 1.0 / (NumTraits<float>::epsilon() / 2);
}
@ -449,12 +449,12 @@ struct igamma_helper<float> {
template <>
struct igamma_helper<double> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
static double machep() {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE double machep() {
return NumTraits<double>::epsilon() / 2; // 1.0 - machep == 1.0
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
static double big() {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE double big() {
return 1.0 / NumTraits<double>::epsilon();
}
};
@ -605,7 +605,7 @@ struct igamma_retval {
template <typename Scalar>
struct igamma_impl {
EIGEN_DEVICE_FUNC
static Scalar run(Scalar a, Scalar x) {
static EIGEN_STRONG_INLINE Scalar run(Scalar a, Scalar x) {
EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
THIS_TYPE_IS_NOT_SUPPORTED);
return Scalar(0);
@ -736,7 +736,7 @@ struct zeta_retval {
template <typename Scalar>
struct zeta_impl {
EIGEN_DEVICE_FUNC
static Scalar run(Scalar x, Scalar q) {
static EIGEN_STRONG_INLINE Scalar run(Scalar x, Scalar q) {
EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
THIS_TYPE_IS_NOT_SUPPORTED);
return Scalar(0);
@ -757,8 +757,8 @@ struct zeta_impl_series {
template <>
struct zeta_impl_series<float> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
static bool run(float& a, float& b, float& s, const float x, const float machep) {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE bool run(float& a, float& b, float& s, const float x, const float machep) {
int i = 0;
while(i < 9)
{
@ -777,8 +777,8 @@ struct zeta_impl_series<float> {
template <>
struct zeta_impl_series<double> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
static bool run(double& a, double& b, double& s, const double x, const double machep) {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE bool run(double& a, double& b, double& s, const double x, const double machep) {
int i = 0;
while( (i < 9) || (a <= 9.0) )
{
@ -881,13 +881,14 @@ struct zeta_impl {
const Scalar maxnum = NumTraits<Scalar>::infinity();
const Scalar zero = 0.0, half = 0.5, one = 1.0;
const Scalar machep = igamma_helper<Scalar>::machep();
const Scalar nan = NumTraits<Scalar>::quiet_NaN();
if( x == one )
return maxnum;
if( x < one )
{
return zero;
return nan;
}
if( q <= zero )
@ -899,7 +900,7 @@ struct zeta_impl {
p = x;
r = numext::floor(p);
if (p != r)
return zero;
return nan;
}
/* Permit negative q but continue sum until n+q > +9 .
@ -954,7 +955,7 @@ struct polygamma_retval {
template <typename Scalar>
struct polygamma_impl {
EIGEN_DEVICE_FUNC
static Scalar run(Scalar n, Scalar x) {
static EIGEN_STRONG_INLINE Scalar run(Scalar n, Scalar x) {
EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
THIS_TYPE_IS_NOT_SUPPORTED);
return Scalar(0);
@ -969,9 +970,14 @@ struct polygamma_impl {
static Scalar run(Scalar n, Scalar x) {
Scalar zero = 0.0, one = 1.0;
Scalar nplus = n + one;
const Scalar nan = NumTraits<Scalar>::quiet_NaN();
// Check that n is an integer
if (numext::floor(n) != n) {
return nan;
}
// Just return the digamma function for n = 1
if (n == zero) {
else if (n == zero) {
return digamma_impl<Scalar>::run(x);
}
// Use the same implementation as scipy

View File

@ -25,13 +25,13 @@
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
********************************************************************************
* Content : Eigen bindings to Intel(R) MKL
* Content : Eigen bindings to BLAS F77
* Level 3 BLAS SYRK/HERK implementation.
********************************************************************************
*/
#ifndef EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_MKL_H
#define EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_MKL_H
#ifndef EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_BLAS_H
#define EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_BLAS_H
namespace Eigen {
@ -44,34 +44,35 @@ struct general_matrix_matrix_rankupdate :
// try to go to BLAS specialization
#define EIGEN_MKL_RANKUPDATE_SPECIALIZE(Scalar) \
#define EIGEN_BLAS_RANKUPDATE_SPECIALIZE(Scalar) \
template <typename Index, int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs, int UpLo> \
struct general_matrix_matrix_triangular_product<Index,Scalar,LhsStorageOrder,ConjugateLhs, \
Scalar,RhsStorageOrder,ConjugateRhs,ColMajor,UpLo,Specialized> { \
static EIGEN_STRONG_INLINE void run(Index size, Index depth,const Scalar* lhs, Index lhsStride, \
const Scalar* rhs, Index rhsStride, Scalar* res, Index resStride, Scalar alpha) \
const Scalar* rhs, Index rhsStride, Scalar* res, Index resStride, Scalar alpha, level3_blocking<Scalar, Scalar>& blocking) \
{ \
if (lhs==rhs) { \
general_matrix_matrix_rankupdate<Index,Scalar,LhsStorageOrder,ConjugateLhs,ColMajor,UpLo> \
::run(size,depth,lhs,lhsStride,rhs,rhsStride,res,resStride,alpha); \
::run(size,depth,lhs,lhsStride,rhs,rhsStride,res,resStride,alpha,blocking); \
} else { \
general_matrix_matrix_triangular_product<Index, \
Scalar, LhsStorageOrder, ConjugateLhs, \
Scalar, RhsStorageOrder, ConjugateRhs, \
ColMajor, UpLo, BuiltIn> \
::run(size,depth,lhs,lhsStride,rhs,rhsStride,res,resStride,alpha); \
::run(size,depth,lhs,lhsStride,rhs,rhsStride,res,resStride,alpha,blocking); \
} \
} \
};
EIGEN_MKL_RANKUPDATE_SPECIALIZE(double)
//EIGEN_MKL_RANKUPDATE_SPECIALIZE(dcomplex)
EIGEN_MKL_RANKUPDATE_SPECIALIZE(float)
//EIGEN_MKL_RANKUPDATE_SPECIALIZE(scomplex)
EIGEN_BLAS_RANKUPDATE_SPECIALIZE(double)
EIGEN_BLAS_RANKUPDATE_SPECIALIZE(float)
// TODO handle complex cases
// EIGEN_BLAS_RANKUPDATE_SPECIALIZE(dcomplex)
// EIGEN_BLAS_RANKUPDATE_SPECIALIZE(scomplex)
// SYRK for float/double
#define EIGEN_MKL_RANKUPDATE_R(EIGTYPE, MKLTYPE, MKLFUNC) \
#define EIGEN_BLAS_RANKUPDATE_R(EIGTYPE, BLASTYPE, BLASFUNC) \
template <typename Index, int AStorageOrder, bool ConjugateA, int UpLo> \
struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,ColMajor,UpLo> { \
enum { \
@ -80,23 +81,19 @@ struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,C
conjA = ((AStorageOrder==ColMajor) && ConjugateA) ? 1 : 0 \
}; \
static EIGEN_STRONG_INLINE void run(Index size, Index depth,const EIGTYPE* lhs, Index lhsStride, \
const EIGTYPE* rhs, Index rhsStride, EIGTYPE* res, Index resStride, EIGTYPE alpha) \
const EIGTYPE* /*rhs*/, Index /*rhsStride*/, EIGTYPE* res, Index resStride, EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
{ \
/* typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs;*/ \
\
MKL_INT lda=lhsStride, ldc=resStride, n=size, k=depth; \
BlasIndex lda=convert_index<BlasIndex>(lhsStride), ldc=convert_index<BlasIndex>(resStride), n=convert_index<BlasIndex>(size), k=convert_index<BlasIndex>(depth); \
char uplo=(IsLower) ? 'L' : 'U', trans=(AStorageOrder==RowMajor) ? 'T':'N'; \
MKLTYPE alpha_, beta_; \
\
/* Set alpha_ & beta_ */ \
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(beta_, EIGTYPE(1)); \
MKLFUNC(&uplo, &trans, &n, &k, &alpha_, lhs, &lda, &beta_, res, &ldc); \
EIGTYPE beta; \
BLASFUNC(&uplo, &trans, &n, &k, &numext::real_ref(alpha), lhs, &lda, &numext::real_ref(beta), res, &ldc); \
} \
};
// HERK for complex data
#define EIGEN_MKL_RANKUPDATE_C(EIGTYPE, MKLTYPE, RTYPE, MKLFUNC) \
#define EIGEN_BLAS_RANKUPDATE_C(EIGTYPE, BLASTYPE, RTYPE, BLASFUNC) \
template <typename Index, int AStorageOrder, bool ConjugateA, int UpLo> \
struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,ColMajor,UpLo> { \
enum { \
@ -105,18 +102,15 @@ struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,C
conjA = (((AStorageOrder==ColMajor) && ConjugateA) || ((AStorageOrder==RowMajor) && !ConjugateA)) ? 1 : 0 \
}; \
static EIGEN_STRONG_INLINE void run(Index size, Index depth,const EIGTYPE* lhs, Index lhsStride, \
const EIGTYPE* rhs, Index rhsStride, EIGTYPE* res, Index resStride, EIGTYPE alpha) \
const EIGTYPE* /*rhs*/, Index /*rhsStride*/, EIGTYPE* res, Index resStride, EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
{ \
typedef Matrix<EIGTYPE, Dynamic, Dynamic, AStorageOrder> MatrixType; \
\
MKL_INT lda=lhsStride, ldc=resStride, n=size, k=depth; \
BlasIndex lda=convert_index<BlasIndex>(lhsStride), ldc=convert_index<BlasIndex>(resStride), n=convert_index<BlasIndex>(size), k=convert_index<BlasIndex>(depth); \
char uplo=(IsLower) ? 'L' : 'U', trans=(AStorageOrder==RowMajor) ? 'C':'N'; \
RTYPE alpha_, beta_; \
const EIGTYPE* a_ptr; \
\
/* Set alpha_ & beta_ */ \
/* assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); */\
/* assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(beta_, EIGTYPE(1));*/ \
alpha_ = alpha.real(); \
beta_ = 1.0; \
/* Copy with conjugation in some cases*/ \
@ -127,20 +121,21 @@ struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,C
lda = a.outerStride(); \
a_ptr = a.data(); \
} else a_ptr=lhs; \
MKLFUNC(&uplo, &trans, &n, &k, &alpha_, (MKLTYPE*)a_ptr, &lda, &beta_, (MKLTYPE*)res, &ldc); \
BLASFUNC(&uplo, &trans, &n, &k, &alpha_, (BLASTYPE*)a_ptr, &lda, &beta_, (BLASTYPE*)res, &ldc); \
} \
};
EIGEN_MKL_RANKUPDATE_R(double, double, dsyrk)
EIGEN_MKL_RANKUPDATE_R(float, float, ssyrk)
EIGEN_BLAS_RANKUPDATE_R(double, double, dsyrk_)
EIGEN_BLAS_RANKUPDATE_R(float, float, ssyrk_)
//EIGEN_MKL_RANKUPDATE_C(dcomplex, MKL_Complex16, double, zherk)
//EIGEN_MKL_RANKUPDATE_C(scomplex, MKL_Complex8, double, cherk)
// TODO hanlde complex cases
// EIGEN_BLAS_RANKUPDATE_C(dcomplex, double, double, zherk_)
// EIGEN_BLAS_RANKUPDATE_C(scomplex, float, float, cherk_)
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_MKL_H
#endif // EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_BLAS_H

View File

@ -25,13 +25,13 @@
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
********************************************************************************
* Content : Eigen bindings to Intel(R) MKL
* Content : Eigen bindings to BLAS F77
* General matrix-matrix product functionality based on ?GEMM.
********************************************************************************
*/
#ifndef EIGEN_GENERAL_MATRIX_MATRIX_MKL_H
#define EIGEN_GENERAL_MATRIX_MATRIX_MKL_H
#ifndef EIGEN_GENERAL_MATRIX_MATRIX_BLAS_H
#define EIGEN_GENERAL_MATRIX_MATRIX_BLAS_H
namespace Eigen {
@ -46,7 +46,7 @@ namespace internal {
// gemm specialization
#define GEMM_SPECIALIZATION(EIGTYPE, EIGPREFIX, MKLTYPE, MKLPREFIX) \
#define GEMM_SPECIALIZATION(EIGTYPE, EIGPREFIX, BLASTYPE, BLASPREFIX) \
template< \
typename Index, \
int LhsStorageOrder, bool ConjugateLhs, \
@ -66,55 +66,50 @@ static void run(Index rows, Index cols, Index depth, \
using std::conj; \
\
char transa, transb; \
MKL_INT m, n, k, lda, ldb, ldc; \
BlasIndex m, n, k, lda, ldb, ldc; \
const EIGTYPE *a, *b; \
MKLTYPE alpha_, beta_; \
EIGTYPE beta(1); \
MatrixX##EIGPREFIX a_tmp, b_tmp; \
EIGTYPE myone(1);\
\
/* Set transpose options */ \
transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \
transb = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \
\
/* Set m, n, k */ \
m = (MKL_INT)rows; \
n = (MKL_INT)cols; \
k = (MKL_INT)depth; \
\
/* Set alpha_ & beta_ */ \
assign_scalar_eig2mkl(alpha_, alpha); \
assign_scalar_eig2mkl(beta_, myone); \
m = convert_index<BlasIndex>(rows); \
n = convert_index<BlasIndex>(cols); \
k = convert_index<BlasIndex>(depth); \
\
/* Set lda, ldb, ldc */ \
lda = (MKL_INT)lhsStride; \
ldb = (MKL_INT)rhsStride; \
ldc = (MKL_INT)resStride; \
lda = convert_index<BlasIndex>(lhsStride); \
ldb = convert_index<BlasIndex>(rhsStride); \
ldc = convert_index<BlasIndex>(resStride); \
\
/* Set a, b, c */ \
if ((LhsStorageOrder==ColMajor) && (ConjugateLhs)) { \
Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > lhs(_lhs,m,k,OuterStride<>(lhsStride)); \
a_tmp = lhs.conjugate(); \
a = a_tmp.data(); \
lda = a_tmp.outerStride(); \
lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
} else a = _lhs; \
\
if ((RhsStorageOrder==ColMajor) && (ConjugateRhs)) { \
Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > rhs(_rhs,k,n,OuterStride<>(rhsStride)); \
b_tmp = rhs.conjugate(); \
b = b_tmp.data(); \
ldb = b_tmp.outerStride(); \
ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
} else b = _rhs; \
\
MKLPREFIX##gemm(&transa, &transb, &m, &n, &k, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &beta_, (MKLTYPE*)res, &ldc); \
BLASPREFIX##gemm_(&transa, &transb, &m, &n, &k, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
}};
GEMM_SPECIALIZATION(double, d, double, d)
GEMM_SPECIALIZATION(float, f, float, s)
GEMM_SPECIALIZATION(dcomplex, cd, MKL_Complex16, z)
GEMM_SPECIALIZATION(scomplex, cf, MKL_Complex8, c)
GEMM_SPECIALIZATION(dcomplex, cd, double, z)
GEMM_SPECIALIZATION(scomplex, cf, float, c)
} // end namespase internal
} // end namespace Eigen
#endif // EIGEN_GENERAL_MATRIX_MATRIX_MKL_H
#endif // EIGEN_GENERAL_MATRIX_MATRIX_BLAS_H

View File

@ -25,13 +25,13 @@
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
********************************************************************************
* Content : Eigen bindings to Intel(R) MKL
* Content : Eigen bindings to BLAS F77
* General matrix-vector product functionality based on ?GEMV.
********************************************************************************
*/
#ifndef EIGEN_GENERAL_MATRIX_VECTOR_MKL_H
#define EIGEN_GENERAL_MATRIX_VECTOR_MKL_H
#ifndef EIGEN_GENERAL_MATRIX_VECTOR_BLAS_H
#define EIGEN_GENERAL_MATRIX_VECTOR_BLAS_H
namespace Eigen {
@ -49,7 +49,7 @@ namespace internal {
template<typename Index, typename LhsScalar, int StorageOrder, bool ConjugateLhs, typename RhsScalar, bool ConjugateRhs>
struct general_matrix_vector_product_gemv;
#define EIGEN_MKL_GEMV_SPECIALIZE(Scalar) \
#define EIGEN_BLAS_GEMV_SPECIALIZE(Scalar) \
template<typename Index, bool ConjugateLhs, bool ConjugateRhs> \
struct general_matrix_vector_product<Index,Scalar,const_blas_data_mapper<Scalar,Index,ColMajor>,ColMajor,ConjugateLhs,Scalar,const_blas_data_mapper<Scalar,Index,RowMajor>,ConjugateRhs,Specialized> { \
static void run( \
@ -80,12 +80,12 @@ static void run( \
} \
}; \
EIGEN_MKL_GEMV_SPECIALIZE(double)
EIGEN_MKL_GEMV_SPECIALIZE(float)
EIGEN_MKL_GEMV_SPECIALIZE(dcomplex)
EIGEN_MKL_GEMV_SPECIALIZE(scomplex)
EIGEN_BLAS_GEMV_SPECIALIZE(double)
EIGEN_BLAS_GEMV_SPECIALIZE(float)
EIGEN_BLAS_GEMV_SPECIALIZE(dcomplex)
EIGEN_BLAS_GEMV_SPECIALIZE(scomplex)
#define EIGEN_MKL_GEMV_SPECIALIZATION(EIGTYPE,MKLTYPE,MKLPREFIX) \
#define EIGEN_BLAS_GEMV_SPECIALIZATION(EIGTYPE,BLASTYPE,BLASPREFIX) \
template<typename Index, int LhsStorageOrder, bool ConjugateLhs, bool ConjugateRhs> \
struct general_matrix_vector_product_gemv<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,ConjugateRhs> \
{ \
@ -97,16 +97,15 @@ static void run( \
const EIGTYPE* rhs, Index rhsIncr, \
EIGTYPE* res, Index resIncr, EIGTYPE alpha) \
{ \
MKL_INT m=rows, n=cols, lda=lhsStride, incx=rhsIncr, incy=resIncr; \
MKLTYPE alpha_, beta_; \
const EIGTYPE *x_ptr, myone(1); \
BlasIndex m=convert_index<BlasIndex>(rows), n=convert_index<BlasIndex>(cols), \
lda=convert_index<BlasIndex>(lhsStride), incx=convert_index<BlasIndex>(rhsIncr), incy=convert_index<BlasIndex>(resIncr); \
const EIGTYPE beta(1); \
const EIGTYPE *x_ptr; \
char trans=(LhsStorageOrder==ColMajor) ? 'N' : (ConjugateLhs) ? 'C' : 'T'; \
if (LhsStorageOrder==RowMajor) { \
m=cols; \
n=rows; \
m = convert_index<BlasIndex>(cols); \
n = convert_index<BlasIndex>(rows); \
}\
assign_scalar_eig2mkl(alpha_, alpha); \
assign_scalar_eig2mkl(beta_, myone); \
GEMVVector x_tmp; \
if (ConjugateRhs) { \
Map<const GEMVVector, 0, InnerStride<> > map_x(rhs,cols,1,InnerStride<>(incx)); \
@ -114,17 +113,17 @@ static void run( \
x_ptr=x_tmp.data(); \
incx=1; \
} else x_ptr=rhs; \
MKLPREFIX##gemv(&trans, &m, &n, &alpha_, (const MKLTYPE*)lhs, &lda, (const MKLTYPE*)x_ptr, &incx, &beta_, (MKLTYPE*)res, &incy); \
BLASPREFIX##gemv_(&trans, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)lhs, &lda, (const BLASTYPE*)x_ptr, &incx, &numext::real_ref(beta), (BLASTYPE*)res, &incy); \
}\
};
EIGEN_MKL_GEMV_SPECIALIZATION(double, double, d)
EIGEN_MKL_GEMV_SPECIALIZATION(float, float, s)
EIGEN_MKL_GEMV_SPECIALIZATION(dcomplex, MKL_Complex16, z)
EIGEN_MKL_GEMV_SPECIALIZATION(scomplex, MKL_Complex8, c)
EIGEN_BLAS_GEMV_SPECIALIZATION(double, double, d)
EIGEN_BLAS_GEMV_SPECIALIZATION(float, float, s)
EIGEN_BLAS_GEMV_SPECIALIZATION(dcomplex, double, z)
EIGEN_BLAS_GEMV_SPECIALIZATION(scomplex, float, c)
} // end namespase internal
} // end namespace Eigen
#endif // EIGEN_GENERAL_MATRIX_VECTOR_MKL_H
#endif // EIGEN_GENERAL_MATRIX_VECTOR_BLAS_H

View File

@ -25,13 +25,13 @@
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
********************************************************************************
* Content : Eigen bindings to Intel(R) MKL
* Content : Eigen bindings to BLAS F77
* Self adjoint matrix * matrix product functionality based on ?SYMM/?HEMM.
********************************************************************************
*/
#ifndef EIGEN_SELFADJOINT_MATRIX_MATRIX_MKL_H
#define EIGEN_SELFADJOINT_MATRIX_MATRIX_MKL_H
#ifndef EIGEN_SELFADJOINT_MATRIX_MATRIX_BLAS_H
#define EIGEN_SELFADJOINT_MATRIX_MATRIX_BLAS_H
namespace Eigen {
@ -40,7 +40,7 @@ namespace internal {
/* Optimized selfadjoint matrix * matrix (?SYMM/?HEMM) product */
#define EIGEN_MKL_SYMM_L(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
#define EIGEN_BLAS_SYMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
template <typename Index, \
int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \
@ -52,28 +52,23 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLh
const EIGTYPE* _lhs, Index lhsStride, \
const EIGTYPE* _rhs, Index rhsStride, \
EIGTYPE* res, Index resStride, \
EIGTYPE alpha) \
EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
{ \
char side='L', uplo='L'; \
MKL_INT m, n, lda, ldb, ldc; \
BlasIndex m, n, lda, ldb, ldc; \
const EIGTYPE *a, *b; \
MKLTYPE alpha_, beta_; \
EIGTYPE beta(1); \
MatrixX##EIGPREFIX b_tmp; \
EIGTYPE myone(1);\
\
/* Set transpose options */ \
/* Set m, n, k */ \
m = (MKL_INT)rows; \
n = (MKL_INT)cols; \
\
/* Set alpha_ & beta_ */ \
assign_scalar_eig2mkl(alpha_, alpha); \
assign_scalar_eig2mkl(beta_, myone); \
m = convert_index<BlasIndex>(rows); \
n = convert_index<BlasIndex>(cols); \
\
/* Set lda, ldb, ldc */ \
lda = (MKL_INT)lhsStride; \
ldb = (MKL_INT)rhsStride; \
ldc = (MKL_INT)resStride; \
lda = convert_index<BlasIndex>(lhsStride); \
ldb = convert_index<BlasIndex>(rhsStride); \
ldc = convert_index<BlasIndex>(resStride); \
\
/* Set a, b, c */ \
if (LhsStorageOrder==RowMajor) uplo='U'; \
@ -83,16 +78,16 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLh
Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > rhs(_rhs,n,m,OuterStride<>(rhsStride)); \
b_tmp = rhs.adjoint(); \
b = b_tmp.data(); \
ldb = b_tmp.outerStride(); \
ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
} else b = _rhs; \
\
MKLPREFIX##symm(&side, &uplo, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &beta_, (MKLTYPE*)res, &ldc); \
BLASPREFIX##symm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
\
} \
};
#define EIGEN_MKL_HEMM_L(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
#define EIGEN_BLAS_HEMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
template <typename Index, \
int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \
@ -103,29 +98,24 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLh
const EIGTYPE* _lhs, Index lhsStride, \
const EIGTYPE* _rhs, Index rhsStride, \
EIGTYPE* res, Index resStride, \
EIGTYPE alpha) \
EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
{ \
char side='L', uplo='L'; \
MKL_INT m, n, lda, ldb, ldc; \
BlasIndex m, n, lda, ldb, ldc; \
const EIGTYPE *a, *b; \
MKLTYPE alpha_, beta_; \
EIGTYPE beta(1); \
MatrixX##EIGPREFIX b_tmp; \
Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> a_tmp; \
EIGTYPE myone(1); \
\
/* Set transpose options */ \
/* Set m, n, k */ \
m = (MKL_INT)rows; \
n = (MKL_INT)cols; \
\
/* Set alpha_ & beta_ */ \
assign_scalar_eig2mkl(alpha_, alpha); \
assign_scalar_eig2mkl(beta_, myone); \
m = convert_index<BlasIndex>(rows); \
n = convert_index<BlasIndex>(cols); \
\
/* Set lda, ldb, ldc */ \
lda = (MKL_INT)lhsStride; \
ldb = (MKL_INT)rhsStride; \
ldc = (MKL_INT)resStride; \
lda = convert_index<BlasIndex>(lhsStride); \
ldb = convert_index<BlasIndex>(rhsStride); \
ldc = convert_index<BlasIndex>(resStride); \
\
/* Set a, b, c */ \
if (((LhsStorageOrder==ColMajor) && ConjugateLhs) || ((LhsStorageOrder==RowMajor) && (!ConjugateLhs))) { \
@ -151,23 +141,23 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLh
b_tmp = rhs.transpose(); \
} \
b = b_tmp.data(); \
ldb = b_tmp.outerStride(); \
ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
} \
\
MKLPREFIX##hemm(&side, &uplo, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &beta_, (MKLTYPE*)res, &ldc); \
BLASPREFIX##hemm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
\
} \
};
EIGEN_MKL_SYMM_L(double, double, d, d)
EIGEN_MKL_SYMM_L(float, float, f, s)
EIGEN_MKL_HEMM_L(dcomplex, MKL_Complex16, cd, z)
EIGEN_MKL_HEMM_L(scomplex, MKL_Complex8, cf, c)
EIGEN_BLAS_SYMM_L(double, double, d, d)
EIGEN_BLAS_SYMM_L(float, float, f, s)
EIGEN_BLAS_HEMM_L(dcomplex, double, cd, z)
EIGEN_BLAS_HEMM_L(scomplex, float, cf, c)
/* Optimized matrix * selfadjoint matrix (?SYMM/?HEMM) product */
#define EIGEN_MKL_SYMM_R(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
#define EIGEN_BLAS_SYMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
template <typename Index, \
int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \
@ -179,27 +169,22 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateL
const EIGTYPE* _lhs, Index lhsStride, \
const EIGTYPE* _rhs, Index rhsStride, \
EIGTYPE* res, Index resStride, \
EIGTYPE alpha) \
EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
{ \
char side='R', uplo='L'; \
MKL_INT m, n, lda, ldb, ldc; \
BlasIndex m, n, lda, ldb, ldc; \
const EIGTYPE *a, *b; \
MKLTYPE alpha_, beta_; \
EIGTYPE beta(1); \
MatrixX##EIGPREFIX b_tmp; \
EIGTYPE myone(1);\
\
/* Set m, n, k */ \
m = (MKL_INT)rows; \
n = (MKL_INT)cols; \
\
/* Set alpha_ & beta_ */ \
assign_scalar_eig2mkl(alpha_, alpha); \
assign_scalar_eig2mkl(beta_, myone); \
m = convert_index<BlasIndex>(rows); \
n = convert_index<BlasIndex>(cols); \
\
/* Set lda, ldb, ldc */ \
lda = (MKL_INT)rhsStride; \
ldb = (MKL_INT)lhsStride; \
ldc = (MKL_INT)resStride; \
lda = convert_index<BlasIndex>(rhsStride); \
ldb = convert_index<BlasIndex>(lhsStride); \
ldc = convert_index<BlasIndex>(resStride); \
\
/* Set a, b, c */ \
if (RhsStorageOrder==RowMajor) uplo='U'; \
@ -209,16 +194,16 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateL
Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > lhs(_lhs,n,m,OuterStride<>(rhsStride)); \
b_tmp = lhs.adjoint(); \
b = b_tmp.data(); \
ldb = b_tmp.outerStride(); \
ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
} else b = _lhs; \
\
MKLPREFIX##symm(&side, &uplo, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &beta_, (MKLTYPE*)res, &ldc); \
BLASPREFIX##symm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
\
} \
};
#define EIGEN_MKL_HEMM_R(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
#define EIGEN_BLAS_HEMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
template <typename Index, \
int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \
@ -229,35 +214,30 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateL
const EIGTYPE* _lhs, Index lhsStride, \
const EIGTYPE* _rhs, Index rhsStride, \
EIGTYPE* res, Index resStride, \
EIGTYPE alpha) \
EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
{ \
char side='R', uplo='L'; \
MKL_INT m, n, lda, ldb, ldc; \
BlasIndex m, n, lda, ldb, ldc; \
const EIGTYPE *a, *b; \
MKLTYPE alpha_, beta_; \
EIGTYPE beta(1); \
MatrixX##EIGPREFIX b_tmp; \
Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> a_tmp; \
EIGTYPE myone(1); \
\
/* Set m, n, k */ \
m = (MKL_INT)rows; \
n = (MKL_INT)cols; \
\
/* Set alpha_ & beta_ */ \
assign_scalar_eig2mkl(alpha_, alpha); \
assign_scalar_eig2mkl(beta_, myone); \
m = convert_index<BlasIndex>(rows); \
n = convert_index<BlasIndex>(cols); \
\
/* Set lda, ldb, ldc */ \
lda = (MKL_INT)rhsStride; \
ldb = (MKL_INT)lhsStride; \
ldc = (MKL_INT)resStride; \
lda = convert_index<BlasIndex>(rhsStride); \
ldb = convert_index<BlasIndex>(lhsStride); \
ldc = convert_index<BlasIndex>(resStride); \
\
/* Set a, b, c */ \
if (((RhsStorageOrder==ColMajor) && ConjugateRhs) || ((RhsStorageOrder==RowMajor) && (!ConjugateRhs))) { \
Map<const Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder>, 0, OuterStride<> > rhs(_rhs,n,n,OuterStride<>(rhsStride)); \
a_tmp = rhs.conjugate(); \
a = a_tmp.data(); \
lda = a_tmp.outerStride(); \
lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
} else a = _rhs; \
if (RhsStorageOrder==RowMajor) uplo='U'; \
\
@ -279,17 +259,17 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateL
ldb = b_tmp.outerStride(); \
} \
\
MKLPREFIX##hemm(&side, &uplo, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &beta_, (MKLTYPE*)res, &ldc); \
BLASPREFIX##hemm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
} \
};
EIGEN_MKL_SYMM_R(double, double, d, d)
EIGEN_MKL_SYMM_R(float, float, f, s)
EIGEN_MKL_HEMM_R(dcomplex, MKL_Complex16, cd, z)
EIGEN_MKL_HEMM_R(scomplex, MKL_Complex8, cf, c)
EIGEN_BLAS_SYMM_R(double, double, d, d)
EIGEN_BLAS_SYMM_R(float, float, f, s)
EIGEN_BLAS_HEMM_R(dcomplex, double, cd, z)
EIGEN_BLAS_HEMM_R(scomplex, float, cf, c)
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_SELFADJOINT_MATRIX_MATRIX_MKL_H
#endif // EIGEN_SELFADJOINT_MATRIX_MATRIX_BLAS_H

View File

@ -25,13 +25,13 @@
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
********************************************************************************
* Content : Eigen bindings to Intel(R) MKL
* Content : Eigen bindings to BLAS F77
* Selfadjoint matrix-vector product functionality based on ?SYMV/HEMV.
********************************************************************************
*/
#ifndef EIGEN_SELFADJOINT_MATRIX_VECTOR_MKL_H
#define EIGEN_SELFADJOINT_MATRIX_VECTOR_MKL_H
#ifndef EIGEN_SELFADJOINT_MATRIX_VECTOR_BLAS_H
#define EIGEN_SELFADJOINT_MATRIX_VECTOR_BLAS_H
namespace Eigen {
@ -47,7 +47,7 @@ template<typename Scalar, typename Index, int StorageOrder, int UpLo, bool Conju
struct selfadjoint_matrix_vector_product_symv :
selfadjoint_matrix_vector_product<Scalar,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs,BuiltIn> {};
#define EIGEN_MKL_SYMV_SPECIALIZE(Scalar) \
#define EIGEN_BLAS_SYMV_SPECIALIZE(Scalar) \
template<typename Index, int StorageOrder, int UpLo, bool ConjugateLhs, bool ConjugateRhs> \
struct selfadjoint_matrix_vector_product<Scalar,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs,Specialized> { \
static void run( \
@ -66,12 +66,12 @@ static void run( \
} \
}; \
EIGEN_MKL_SYMV_SPECIALIZE(double)
EIGEN_MKL_SYMV_SPECIALIZE(float)
EIGEN_MKL_SYMV_SPECIALIZE(dcomplex)
EIGEN_MKL_SYMV_SPECIALIZE(scomplex)
EIGEN_BLAS_SYMV_SPECIALIZE(double)
EIGEN_BLAS_SYMV_SPECIALIZE(float)
EIGEN_BLAS_SYMV_SPECIALIZE(dcomplex)
EIGEN_BLAS_SYMV_SPECIALIZE(scomplex)
#define EIGEN_MKL_SYMV_SPECIALIZATION(EIGTYPE,MKLTYPE,MKLFUNC) \
#define EIGEN_BLAS_SYMV_SPECIALIZATION(EIGTYPE,BLASTYPE,BLASFUNC) \
template<typename Index, int StorageOrder, int UpLo, bool ConjugateLhs, bool ConjugateRhs> \
struct selfadjoint_matrix_vector_product_symv<EIGTYPE,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs> \
{ \
@ -85,29 +85,27 @@ const EIGTYPE* _rhs, EIGTYPE* res, EIGTYPE alpha) \
IsRowMajor = StorageOrder==RowMajor ? 1 : 0, \
IsLower = UpLo == Lower ? 1 : 0 \
}; \
MKL_INT n=size, lda=lhsStride, incx=1, incy=1; \
MKLTYPE alpha_, beta_; \
const EIGTYPE *x_ptr, myone(1); \
BlasIndex n=convert_index<BlasIndex>(size), lda=convert_index<BlasIndex>(lhsStride), incx=1, incy=1; \
EIGTYPE beta(1); \
const EIGTYPE *x_ptr; \
char uplo=(IsRowMajor) ? (IsLower ? 'U' : 'L') : (IsLower ? 'L' : 'U'); \
assign_scalar_eig2mkl(alpha_, alpha); \
assign_scalar_eig2mkl(beta_, myone); \
SYMVVector x_tmp; \
if (ConjugateRhs) { \
Map<const SYMVVector, 0 > map_x(_rhs,size,1); \
x_tmp=map_x.conjugate(); \
x_ptr=x_tmp.data(); \
} else x_ptr=_rhs; \
MKLFUNC(&uplo, &n, &alpha_, (const MKLTYPE*)lhs, &lda, (const MKLTYPE*)x_ptr, &incx, &beta_, (MKLTYPE*)res, &incy); \
BLASFUNC(&uplo, &n, &numext::real_ref(alpha), (const BLASTYPE*)lhs, &lda, (const BLASTYPE*)x_ptr, &incx, &numext::real_ref(beta), (BLASTYPE*)res, &incy); \
}\
};
EIGEN_MKL_SYMV_SPECIALIZATION(double, double, dsymv)
EIGEN_MKL_SYMV_SPECIALIZATION(float, float, ssymv)
EIGEN_MKL_SYMV_SPECIALIZATION(dcomplex, MKL_Complex16, zhemv)
EIGEN_MKL_SYMV_SPECIALIZATION(scomplex, MKL_Complex8, chemv)
EIGEN_BLAS_SYMV_SPECIALIZATION(double, double, dsymv_)
EIGEN_BLAS_SYMV_SPECIALIZATION(float, float, ssymv_)
EIGEN_BLAS_SYMV_SPECIALIZATION(dcomplex, double, zhemv_)
EIGEN_BLAS_SYMV_SPECIALIZATION(scomplex, float, chemv_)
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_SELFADJOINT_MATRIX_VECTOR_MKL_H
#endif // EIGEN_SELFADJOINT_MATRIX_VECTOR_BLAS_H

View File

@ -25,13 +25,13 @@
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
********************************************************************************
* Content : Eigen bindings to Intel(R) MKL
* Content : Eigen bindings to BLAS F77
* Triangular matrix * matrix product functionality based on ?TRMM.
********************************************************************************
*/
#ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H
#define EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H
#ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
#define EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
namespace Eigen {
@ -50,7 +50,7 @@ struct product_triangular_matrix_matrix_trmm :
// try to go to BLAS specialization
#define EIGEN_MKL_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) \
#define EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) \
template <typename Index, int Mode, \
int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \
@ -65,17 +65,17 @@ struct product_triangular_matrix_matrix<Scalar,Index, Mode, LhsIsTriangular, \
} \
};
EIGEN_MKL_TRMM_SPECIALIZE(double, true)
EIGEN_MKL_TRMM_SPECIALIZE(double, false)
EIGEN_MKL_TRMM_SPECIALIZE(dcomplex, true)
EIGEN_MKL_TRMM_SPECIALIZE(dcomplex, false)
EIGEN_MKL_TRMM_SPECIALIZE(float, true)
EIGEN_MKL_TRMM_SPECIALIZE(float, false)
EIGEN_MKL_TRMM_SPECIALIZE(scomplex, true)
EIGEN_MKL_TRMM_SPECIALIZE(scomplex, false)
EIGEN_BLAS_TRMM_SPECIALIZE(double, true)
EIGEN_BLAS_TRMM_SPECIALIZE(double, false)
EIGEN_BLAS_TRMM_SPECIALIZE(dcomplex, true)
EIGEN_BLAS_TRMM_SPECIALIZE(dcomplex, false)
EIGEN_BLAS_TRMM_SPECIALIZE(float, true)
EIGEN_BLAS_TRMM_SPECIALIZE(float, false)
EIGEN_BLAS_TRMM_SPECIALIZE(scomplex, true)
EIGEN_BLAS_TRMM_SPECIALIZE(scomplex, false)
// implements col-major += alpha * op(triangular) * op(general)
#define EIGEN_MKL_TRMM_L(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
#define EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
template <typename Index, int Mode, \
int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \
@ -106,13 +106,14 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
\
/* Non-square case - doesn't fit to MKL ?TRMM. Fall to default triangular product or call MKL ?GEMM*/ \
/* Non-square case - doesn't fit to BLAS ?TRMM. Fall to default triangular product or call BLAS ?GEMM*/ \
if (rows != depth) { \
\
int nthr = mkl_domain_get_max_threads(EIGEN_MKL_DOMAIN_BLAS); \
/* FIXME handle mkl_domain_get_max_threads */ \
/*int nthr = mkl_domain_get_max_threads(EIGEN_BLAS_DOMAIN_BLAS);*/ int nthr = 1;\
\
if (((nthr==1) && (((std::max)(rows,depth)-diagSize)/(double)diagSize < 0.5))) { \
/* Most likely no benefit to call TRMM or GEMM from MKL*/ \
/* Most likely no benefit to call TRMM or GEMM from BLAS */ \
product_triangular_matrix_matrix<EIGTYPE,Index,Mode,true, \
LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \
_rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
@ -121,27 +122,23 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
/* Make sense to call GEMM */ \
Map<const MatrixLhs, 0, OuterStride<> > lhsMap(_lhs,rows,depth,OuterStride<>(lhsStride)); \
MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); \
MKL_INT aStride = aa_tmp.outerStride(); \
BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, resStride, alpha, gemm_blocking, 0); \
\
/*std::cout << "TRMM_L: A is not square! Go to MKL GEMM implementation! " << nthr<<" \n";*/ \
/*std::cout << "TRMM_L: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \
} \
return; \
} \
char side = 'L', transa, uplo, diag = 'N'; \
EIGTYPE *b; \
const EIGTYPE *a; \
MKL_INT m, n, lda, ldb; \
MKLTYPE alpha_; \
\
/* Set alpha_*/ \
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
BlasIndex m, n, lda, ldb; \
\
/* Set m, n */ \
m = (MKL_INT)diagSize; \
n = (MKL_INT)cols; \
m = convert_index<BlasIndex>(diagSize); \
n = convert_index<BlasIndex>(cols); \
\
/* Set trans */ \
transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \
@ -152,7 +149,7 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
\
if (ConjugateRhs) b_tmp = rhs.conjugate(); else b_tmp = rhs; \
b = b_tmp.data(); \
ldb = b_tmp.outerStride(); \
ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
\
/* Set uplo */ \
uplo = IsLower ? 'L' : 'U'; \
@ -168,14 +165,14 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
else if (IsUnitDiag) \
a_tmp.diagonal().setOnes();\
a = a_tmp.data(); \
lda = a_tmp.outerStride(); \
lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
} else { \
a = _lhs; \
lda = lhsStride; \
lda = convert_index<BlasIndex>(lhsStride); \
} \
/*std::cout << "TRMM_L: A is square! Go to MKL TRMM implementation! \n";*/ \
/*std::cout << "TRMM_L: A is square! Go to BLAS TRMM implementation! \n";*/ \
/* call ?trmm*/ \
MKLPREFIX##trmm(&side, &uplo, &transa, &diag, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (MKLTYPE*)b, &ldb); \
BLASPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \
\
/* Add op(a_triangular)*b into res*/ \
Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
@ -183,13 +180,13 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
} \
};
EIGEN_MKL_TRMM_L(double, double, d, d)
EIGEN_MKL_TRMM_L(dcomplex, MKL_Complex16, cd, z)
EIGEN_MKL_TRMM_L(float, float, f, s)
EIGEN_MKL_TRMM_L(scomplex, MKL_Complex8, cf, c)
EIGEN_BLAS_TRMM_L(double, double, d, d)
EIGEN_BLAS_TRMM_L(dcomplex, double, cd, z)
EIGEN_BLAS_TRMM_L(float, float, f, s)
EIGEN_BLAS_TRMM_L(scomplex, float, cf, c)
// implements col-major += alpha * op(general) * op(triangular)
#define EIGEN_MKL_TRMM_R(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
#define EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
template <typename Index, int Mode, \
int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \
@ -220,13 +217,13 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
\
/* Non-square case - doesn't fit to MKL ?TRMM. Fall to default triangular product or call MKL ?GEMM*/ \
/* Non-square case - doesn't fit to BLAS ?TRMM. Fall to default triangular product or call BLAS ?GEMM*/ \
if (cols != depth) { \
\
int nthr = mkl_domain_get_max_threads(EIGEN_MKL_DOMAIN_BLAS); \
int nthr = 1 /*mkl_domain_get_max_threads(EIGEN_BLAS_DOMAIN_BLAS)*/; \
\
if ((nthr==1) && (((std::max)(cols,depth)-diagSize)/(double)diagSize < 0.5)) { \
/* Most likely no benefit to call TRMM or GEMM from MKL*/ \
/* Most likely no benefit to call TRMM or GEMM from BLAS*/ \
product_triangular_matrix_matrix<EIGTYPE,Index,Mode,false, \
LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \
_rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
@ -235,27 +232,23 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
/* Make sense to call GEMM */ \
Map<const MatrixRhs, 0, OuterStride<> > rhsMap(_rhs,depth,cols, OuterStride<>(rhsStride)); \
MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); \
MKL_INT aStride = aa_tmp.outerStride(); \
BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, resStride, alpha, gemm_blocking, 0); \
\
/*std::cout << "TRMM_R: A is not square! Go to MKL GEMM implementation! " << nthr<<" \n";*/ \
/*std::cout << "TRMM_R: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \
} \
return; \
} \
char side = 'R', transa, uplo, diag = 'N'; \
EIGTYPE *b; \
const EIGTYPE *a; \
MKL_INT m, n, lda, ldb; \
MKLTYPE alpha_; \
\
/* Set alpha_*/ \
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
BlasIndex m, n, lda, ldb; \
\
/* Set m, n */ \
m = (MKL_INT)rows; \
n = (MKL_INT)diagSize; \
m = convert_index<BlasIndex>(rows); \
n = convert_index<BlasIndex>(diagSize); \
\
/* Set trans */ \
transa = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \
@ -266,7 +259,7 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
\
if (ConjugateLhs) b_tmp = lhs.conjugate(); else b_tmp = lhs; \
b = b_tmp.data(); \
ldb = b_tmp.outerStride(); \
ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
\
/* Set uplo */ \
uplo = IsLower ? 'L' : 'U'; \
@ -282,14 +275,14 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
else if (IsUnitDiag) \
a_tmp.diagonal().setOnes();\
a = a_tmp.data(); \
lda = a_tmp.outerStride(); \
lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
} else { \
a = _rhs; \
lda = rhsStride; \
lda = convert_index<BlasIndex>(rhsStride); \
} \
/*std::cout << "TRMM_R: A is square! Go to MKL TRMM implementation! \n";*/ \
/*std::cout << "TRMM_R: A is square! Go to BLAS TRMM implementation! \n";*/ \
/* call ?trmm*/ \
MKLPREFIX##trmm(&side, &uplo, &transa, &diag, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (MKLTYPE*)b, &ldb); \
BLASPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \
\
/* Add op(a_triangular)*b into res*/ \
Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
@ -297,13 +290,13 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
} \
};
EIGEN_MKL_TRMM_R(double, double, d, d)
EIGEN_MKL_TRMM_R(dcomplex, MKL_Complex16, cd, z)
EIGEN_MKL_TRMM_R(float, float, f, s)
EIGEN_MKL_TRMM_R(scomplex, MKL_Complex8, cf, c)
EIGEN_BLAS_TRMM_R(double, double, d, d)
EIGEN_BLAS_TRMM_R(dcomplex, double, cd, z)
EIGEN_BLAS_TRMM_R(float, float, f, s)
EIGEN_BLAS_TRMM_R(scomplex, float, cf, c)
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H
#endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H

View File

@ -25,13 +25,13 @@
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
********************************************************************************
* Content : Eigen bindings to Intel(R) MKL
* Content : Eigen bindings to BLAS F77
* Triangular matrix-vector product functionality based on ?TRMV.
********************************************************************************
*/
#ifndef EIGEN_TRIANGULAR_MATRIX_VECTOR_MKL_H
#define EIGEN_TRIANGULAR_MATRIX_VECTOR_MKL_H
#ifndef EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H
#define EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H
namespace Eigen {
@ -47,7 +47,7 @@ template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename Rh
struct triangular_matrix_vector_product_trmv :
triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,StorageOrder,BuiltIn> {};
#define EIGEN_MKL_TRMV_SPECIALIZE(Scalar) \
#define EIGEN_BLAS_TRMV_SPECIALIZE(Scalar) \
template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
struct triangular_matrix_vector_product<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,ColMajor,Specialized> { \
static void run(Index _rows, Index _cols, const Scalar* _lhs, Index lhsStride, \
@ -65,13 +65,13 @@ struct triangular_matrix_vector_product<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs
} \
};
EIGEN_MKL_TRMV_SPECIALIZE(double)
EIGEN_MKL_TRMV_SPECIALIZE(float)
EIGEN_MKL_TRMV_SPECIALIZE(dcomplex)
EIGEN_MKL_TRMV_SPECIALIZE(scomplex)
EIGEN_BLAS_TRMV_SPECIALIZE(double)
EIGEN_BLAS_TRMV_SPECIALIZE(float)
EIGEN_BLAS_TRMV_SPECIALIZE(dcomplex)
EIGEN_BLAS_TRMV_SPECIALIZE(scomplex)
// implements col-major: res += alpha * op(triangular) * vector
#define EIGEN_MKL_TRMV_CM(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
#define EIGEN_BLAS_TRMV_CM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,ColMajor> { \
enum { \
@ -105,17 +105,15 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
/* Square part handling */\
\
char trans, uplo, diag; \
MKL_INT m, n, lda, incx, incy; \
BlasIndex m, n, lda, incx, incy; \
EIGTYPE const *a; \
MKLTYPE alpha_, beta_; \
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(beta_, EIGTYPE(1)); \
EIGTYPE beta(1); \
\
/* Set m, n */ \
n = (MKL_INT)size; \
lda = lhsStride; \
n = convert_index<BlasIndex>(size); \
lda = convert_index<BlasIndex>(lhsStride); \
incx = 1; \
incy = resIncr; \
incy = convert_index<BlasIndex>(resIncr); \
\
/* Set uplo, trans and diag*/ \
trans = 'N'; \
@ -123,39 +121,39 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
diag = IsUnitDiag ? 'U' : 'N'; \
\
/* call ?TRMV*/ \
MKLPREFIX##trmv(&uplo, &trans, &diag, &n, (const MKLTYPE*)_lhs, &lda, (MKLTYPE*)x, &incx); \
BLASPREFIX##trmv_(&uplo, &trans, &diag, &n, (const BLASTYPE*)_lhs, &lda, (BLASTYPE*)x, &incx); \
\
/* Add op(a_tr)rhs into res*/ \
MKLPREFIX##axpy(&n, &alpha_,(const MKLTYPE*)x, &incx, (MKLTYPE*)_res, &incy); \
/* Non-square case - doesn't fit to MKL ?TRMV. Fall to default triangular product*/ \
BLASPREFIX##axpy_(&n, &numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)_res, &incy); \
/* Non-square case - doesn't fit to BLAS ?TRMV. Fall to default triangular product*/ \
if (size<(std::max)(rows,cols)) { \
if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
x = x_tmp.data(); \
if (size<rows) { \
y = _res + size*resIncr; \
a = _lhs + size; \
m = rows-size; \
n = size; \
m = convert_index<BlasIndex>(rows-size); \
n = convert_index<BlasIndex>(size); \
} \
else { \
x += size; \
y = _res; \
a = _lhs + size*lda; \
m = size; \
n = cols-size; \
m = convert_index<BlasIndex>(size); \
n = convert_index<BlasIndex>(cols-size); \
} \
MKLPREFIX##gemv(&trans, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)x, &incx, &beta_, (MKLTYPE*)y, &incy); \
BLASPREFIX##gemv_(&trans, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, &numext::real_ref(beta), (BLASTYPE*)y, &incy); \
} \
} \
};
EIGEN_MKL_TRMV_CM(double, double, d, d)
EIGEN_MKL_TRMV_CM(dcomplex, MKL_Complex16, cd, z)
EIGEN_MKL_TRMV_CM(float, float, f, s)
EIGEN_MKL_TRMV_CM(scomplex, MKL_Complex8, cf, c)
EIGEN_BLAS_TRMV_CM(double, double, d, d)
EIGEN_BLAS_TRMV_CM(dcomplex, double, cd, z)
EIGEN_BLAS_TRMV_CM(float, float, f, s)
EIGEN_BLAS_TRMV_CM(scomplex, float, cf, c)
// implements row-major: res += alpha * op(triangular) * vector
#define EIGEN_MKL_TRMV_RM(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
#define EIGEN_BLAS_TRMV_RM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,RowMajor> { \
enum { \
@ -189,17 +187,15 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
/* Square part handling */\
\
char trans, uplo, diag; \
MKL_INT m, n, lda, incx, incy; \
BlasIndex m, n, lda, incx, incy; \
EIGTYPE const *a; \
MKLTYPE alpha_, beta_; \
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(beta_, EIGTYPE(1)); \
EIGTYPE beta(1); \
\
/* Set m, n */ \
n = (MKL_INT)size; \
lda = lhsStride; \
n = convert_index<BlasIndex>(size); \
lda = convert_index<BlasIndex>(lhsStride); \
incx = 1; \
incy = resIncr; \
incy = convert_index<BlasIndex>(resIncr); \
\
/* Set uplo, trans and diag*/ \
trans = ConjLhs ? 'C' : 'T'; \
@ -207,39 +203,39 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
diag = IsUnitDiag ? 'U' : 'N'; \
\
/* call ?TRMV*/ \
MKLPREFIX##trmv(&uplo, &trans, &diag, &n, (const MKLTYPE*)_lhs, &lda, (MKLTYPE*)x, &incx); \
BLASPREFIX##trmv_(&uplo, &trans, &diag, &n, (const BLASTYPE*)_lhs, &lda, (BLASTYPE*)x, &incx); \
\
/* Add op(a_tr)rhs into res*/ \
MKLPREFIX##axpy(&n, &alpha_,(const MKLTYPE*)x, &incx, (MKLTYPE*)_res, &incy); \
/* Non-square case - doesn't fit to MKL ?TRMV. Fall to default triangular product*/ \
BLASPREFIX##axpy_(&n, &numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)_res, &incy); \
/* Non-square case - doesn't fit to BLAS ?TRMV. Fall to default triangular product*/ \
if (size<(std::max)(rows,cols)) { \
if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
x = x_tmp.data(); \
if (size<rows) { \
y = _res + size*resIncr; \
a = _lhs + size*lda; \
m = rows-size; \
n = size; \
m = convert_index<BlasIndex>(rows-size); \
n = convert_index<BlasIndex>(size); \
} \
else { \
x += size; \
y = _res; \
a = _lhs + size; \
m = size; \
n = cols-size; \
m = convert_index<BlasIndex>(size); \
n = convert_index<BlasIndex>(cols-size); \
} \
MKLPREFIX##gemv(&trans, &n, &m, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)x, &incx, &beta_, (MKLTYPE*)y, &incy); \
BLASPREFIX##gemv_(&trans, &n, &m, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, &numext::real_ref(beta), (BLASTYPE*)y, &incy); \
} \
} \
};
EIGEN_MKL_TRMV_RM(double, double, d, d)
EIGEN_MKL_TRMV_RM(dcomplex, MKL_Complex16, cd, z)
EIGEN_MKL_TRMV_RM(float, float, f, s)
EIGEN_MKL_TRMV_RM(scomplex, MKL_Complex8, cf, c)
EIGEN_BLAS_TRMV_RM(double, double, d, d)
EIGEN_BLAS_TRMV_RM(dcomplex, double, cd, z)
EIGEN_BLAS_TRMV_RM(float, float, f, s)
EIGEN_BLAS_TRMV_RM(scomplex, float, cf, c)
} // end namespase internal
} // end namespace Eigen
#endif // EIGEN_TRIANGULAR_MATRIX_VECTOR_MKL_H
#endif // EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H

View File

@ -25,20 +25,20 @@
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
********************************************************************************
* Content : Eigen bindings to Intel(R) MKL
* Content : Eigen bindings to BLAS F77
* Triangular matrix * matrix product functionality based on ?TRMM.
********************************************************************************
*/
#ifndef EIGEN_TRIANGULAR_SOLVER_MATRIX_MKL_H
#define EIGEN_TRIANGULAR_SOLVER_MATRIX_MKL_H
#ifndef EIGEN_TRIANGULAR_SOLVER_MATRIX_BLAS_H
#define EIGEN_TRIANGULAR_SOLVER_MATRIX_BLAS_H
namespace Eigen {
namespace internal {
// implements LeftSide op(triangular)^-1 * general
#define EIGEN_MKL_TRSM_L(EIGTYPE, MKLTYPE, MKLPREFIX) \
#define EIGEN_BLAS_TRSM_L(EIGTYPE, BLASTYPE, BLASPREFIX) \
template <typename Index, int Mode, bool Conjugate, int TriStorageOrder> \
struct triangular_solve_matrix<EIGTYPE,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor> \
{ \
@ -53,13 +53,11 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheLeft,Mode,Conjugate,TriStorage
const EIGTYPE* _tri, Index triStride, \
EIGTYPE* _other, Index otherStride, level3_blocking<EIGTYPE,EIGTYPE>& /*blocking*/) \
{ \
MKL_INT m = size, n = otherSize, lda, ldb; \
BlasIndex m = convert_index<BlasIndex>(size), n = convert_index<BlasIndex>(otherSize), lda, ldb; \
char side = 'L', uplo, diag='N', transa; \
/* Set alpha_ */ \
MKLTYPE alpha; \
EIGTYPE myone(1); \
assign_scalar_eig2mkl(alpha, myone); \
ldb = otherStride;\
EIGTYPE alpha(1); \
ldb = convert_index<BlasIndex>(otherStride);\
\
const EIGTYPE *a; \
/* Set trans */ \
@ -75,25 +73,25 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheLeft,Mode,Conjugate,TriStorage
if (conjA) { \
a_tmp = tri.conjugate(); \
a = a_tmp.data(); \
lda = a_tmp.outerStride(); \
lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
} else { \
a = _tri; \
lda = triStride; \
lda = convert_index<BlasIndex>(triStride); \
} \
if (IsUnitDiag) diag='U'; \
/* call ?trsm*/ \
MKLPREFIX##trsm(&side, &uplo, &transa, &diag, &m, &n, &alpha, (const MKLTYPE*)a, &lda, (MKLTYPE*)_other, &ldb); \
BLASPREFIX##trsm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)_other, &ldb); \
} \
};
EIGEN_MKL_TRSM_L(double, double, d)
EIGEN_MKL_TRSM_L(dcomplex, MKL_Complex16, z)
EIGEN_MKL_TRSM_L(float, float, s)
EIGEN_MKL_TRSM_L(scomplex, MKL_Complex8, c)
EIGEN_BLAS_TRSM_L(double, double, d)
EIGEN_BLAS_TRSM_L(dcomplex, double, z)
EIGEN_BLAS_TRSM_L(float, float, s)
EIGEN_BLAS_TRSM_L(scomplex, float, c)
// implements RightSide general * op(triangular)^-1
#define EIGEN_MKL_TRSM_R(EIGTYPE, MKLTYPE, MKLPREFIX) \
#define EIGEN_BLAS_TRSM_R(EIGTYPE, BLASTYPE, BLASPREFIX) \
template <typename Index, int Mode, bool Conjugate, int TriStorageOrder> \
struct triangular_solve_matrix<EIGTYPE,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor> \
{ \
@ -108,13 +106,11 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheRight,Mode,Conjugate,TriStorag
const EIGTYPE* _tri, Index triStride, \
EIGTYPE* _other, Index otherStride, level3_blocking<EIGTYPE,EIGTYPE>& /*blocking*/) \
{ \
MKL_INT m = otherSize, n = size, lda, ldb; \
BlasIndex m = convert_index<BlasIndex>(otherSize), n = convert_index<BlasIndex>(size), lda, ldb; \
char side = 'R', uplo, diag='N', transa; \
/* Set alpha_ */ \
MKLTYPE alpha; \
EIGTYPE myone(1); \
assign_scalar_eig2mkl(alpha, myone); \
ldb = otherStride;\
EIGTYPE alpha(1); \
ldb = convert_index<BlasIndex>(otherStride);\
\
const EIGTYPE *a; \
/* Set trans */ \
@ -130,26 +126,26 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheRight,Mode,Conjugate,TriStorag
if (conjA) { \
a_tmp = tri.conjugate(); \
a = a_tmp.data(); \
lda = a_tmp.outerStride(); \
lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
} else { \
a = _tri; \
lda = triStride; \
lda = convert_index<BlasIndex>(triStride); \
} \
if (IsUnitDiag) diag='U'; \
/* call ?trsm*/ \
MKLPREFIX##trsm(&side, &uplo, &transa, &diag, &m, &n, &alpha, (const MKLTYPE*)a, &lda, (MKLTYPE*)_other, &ldb); \
BLASPREFIX##trsm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)_other, &ldb); \
/*std::cout << "TRMS_L specialization!\n";*/ \
} \
};
EIGEN_MKL_TRSM_R(double, double, d)
EIGEN_MKL_TRSM_R(dcomplex, MKL_Complex16, z)
EIGEN_MKL_TRSM_R(float, float, s)
EIGEN_MKL_TRSM_R(scomplex, MKL_Complex8, c)
EIGEN_BLAS_TRSM_R(double, double, d)
EIGEN_BLAS_TRSM_R(dcomplex, double, z)
EIGEN_BLAS_TRSM_R(float, float, s)
EIGEN_BLAS_TRSM_R(scomplex, float, c)
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_TRIANGULAR_SOLVER_MATRIX_MKL_H
#endif // EIGEN_TRIANGULAR_SOLVER_MATRIX_BLAS_H

View File

@ -49,7 +49,7 @@
#define EIGEN_USE_LAPACKE
#endif
#if defined(EIGEN_USE_BLAS) || defined(EIGEN_USE_LAPACKE) || defined(EIGEN_USE_MKL_VML)
#if defined(EIGEN_USE_LAPACKE) || defined(EIGEN_USE_MKL_VML)
#define EIGEN_USE_MKL
#endif
@ -64,7 +64,6 @@
# ifndef EIGEN_USE_MKL
/*If the MKL version is too old, undef everything*/
# undef EIGEN_USE_MKL_ALL
# undef EIGEN_USE_BLAS
# undef EIGEN_USE_LAPACKE
# undef EIGEN_USE_MKL_VML
# undef EIGEN_USE_LAPACKE_STRICT
@ -107,52 +106,23 @@
#else
#define EIGEN_MKL_DOMAIN_PARDISO MKL_PARDISO
#endif
#endif
namespace Eigen {
typedef std::complex<double> dcomplex;
typedef std::complex<float> scomplex;
namespace internal {
template<typename MKLType, typename EigenType>
static inline void assign_scalar_eig2mkl(MKLType& mklScalar, const EigenType& eigenScalar) {
mklScalar=eigenScalar;
}
template<typename MKLType, typename EigenType>
static inline void assign_conj_scalar_eig2mkl(MKLType& mklScalar, const EigenType& eigenScalar) {
mklScalar=eigenScalar;
}
template <>
inline void assign_scalar_eig2mkl<MKL_Complex16,dcomplex>(MKL_Complex16& mklScalar, const dcomplex& eigenScalar) {
mklScalar.real=eigenScalar.real();
mklScalar.imag=eigenScalar.imag();
}
template <>
inline void assign_scalar_eig2mkl<MKL_Complex8,scomplex>(MKL_Complex8& mklScalar, const scomplex& eigenScalar) {
mklScalar.real=eigenScalar.real();
mklScalar.imag=eigenScalar.imag();
}
template <>
inline void assign_conj_scalar_eig2mkl<MKL_Complex16,dcomplex>(MKL_Complex16& mklScalar, const dcomplex& eigenScalar) {
mklScalar.real=eigenScalar.real();
mklScalar.imag=-eigenScalar.imag();
}
template <>
inline void assign_conj_scalar_eig2mkl<MKL_Complex8,scomplex>(MKL_Complex8& mklScalar, const scomplex& eigenScalar) {
mklScalar.real=eigenScalar.real();
mklScalar.imag=-eigenScalar.imag();
}
} // end namespace internal
#if defined(EIGEN_USE_MKL)
typedef MKL_INT BlasIndex;
#else
typedef int BlasIndex;
#endif
} // end namespace Eigen
#if defined(EIGEN_USE_BLAS)
#include "../../misc/blas.h"
#endif
#endif // EIGEN_MKL_SUPPORT_H

View File

@ -30,15 +30,15 @@ int BLASFUNC(cdotcw) (int *, float *, int *, float *, int *, float*);
int BLASFUNC(zdotuw) (int *, double *, int *, double *, int *, double*);
int BLASFUNC(zdotcw) (int *, double *, int *, double *, int *, double*);
int BLASFUNC(saxpy) (int *, float *, float *, int *, float *, int *);
int BLASFUNC(daxpy) (int *, double *, double *, int *, double *, int *);
int BLASFUNC(qaxpy) (int *, double *, double *, int *, double *, int *);
int BLASFUNC(caxpy) (int *, float *, float *, int *, float *, int *);
int BLASFUNC(zaxpy) (int *, double *, double *, int *, double *, int *);
int BLASFUNC(xaxpy) (int *, double *, double *, int *, double *, int *);
int BLASFUNC(caxpyc)(int *, float *, float *, int *, float *, int *);
int BLASFUNC(zaxpyc)(int *, double *, double *, int *, double *, int *);
int BLASFUNC(xaxpyc)(int *, double *, double *, int *, double *, int *);
int BLASFUNC(saxpy) (const int *, const float *, const float *, const int *, float *, const int *);
int BLASFUNC(daxpy) (const int *, const double *, const double *, const int *, double *, const int *);
int BLASFUNC(qaxpy) (const int *, const double *, const double *, const int *, double *, const int *);
int BLASFUNC(caxpy) (const int *, const float *, const float *, const int *, float *, const int *);
int BLASFUNC(zaxpy) (const int *, const double *, const double *, const int *, double *, const int *);
int BLASFUNC(xaxpy) (const int *, const double *, const double *, const int *, double *, const int *);
int BLASFUNC(caxpyc)(const int *, const float *, const float *, const int *, float *, const int *);
int BLASFUNC(zaxpyc)(const int *, const double *, const double *, const int *, double *, const int *);
int BLASFUNC(xaxpyc)(const int *, const double *, const double *, const int *, double *, const int *);
int BLASFUNC(scopy) (int *, float *, int *, float *, int *);
int BLASFUNC(dcopy) (int *, double *, int *, double *, int *);
@ -177,31 +177,19 @@ int BLASFUNC(xgeru)(int *, int *, double *, double *, int *,
int BLASFUNC(xgerc)(int *, int *, double *, double *, int *,
double *, int *, double *, int *);
int BLASFUNC(sgemv)(char *, int *, int *, float *, float *, int *,
float *, int *, float *, float *, int *);
int BLASFUNC(dgemv)(char *, int *, int *, double *, double *, int *,
double *, int *, double *, double *, int *);
int BLASFUNC(qgemv)(char *, int *, int *, double *, double *, int *,
double *, int *, double *, double *, int *);
int BLASFUNC(cgemv)(char *, int *, int *, float *, float *, int *,
float *, int *, float *, float *, int *);
int BLASFUNC(zgemv)(char *, int *, int *, double *, double *, int *,
double *, int *, double *, double *, int *);
int BLASFUNC(xgemv)(char *, int *, int *, double *, double *, int *,
double *, int *, double *, double *, int *);
int BLASFUNC(sgemv)(const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
int BLASFUNC(dgemv)(const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(qgemv)(const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(cgemv)(const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
int BLASFUNC(zgemv)(const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(xgemv)(const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(strsv) (char *, char *, char *, int *, float *, int *,
float *, int *);
int BLASFUNC(dtrsv) (char *, char *, char *, int *, double *, int *,
double *, int *);
int BLASFUNC(qtrsv) (char *, char *, char *, int *, double *, int *,
double *, int *);
int BLASFUNC(ctrsv) (char *, char *, char *, int *, float *, int *,
float *, int *);
int BLASFUNC(ztrsv) (char *, char *, char *, int *, double *, int *,
double *, int *);
int BLASFUNC(xtrsv) (char *, char *, char *, int *, double *, int *,
double *, int *);
int BLASFUNC(strsv) (const char *, const char *, const char *, const int *, const float *, const int *, float *, const int *);
int BLASFUNC(dtrsv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *);
int BLASFUNC(qtrsv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *);
int BLASFUNC(ctrsv) (const char *, const char *, const char *, const int *, const float *, const int *, float *, const int *);
int BLASFUNC(ztrsv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *);
int BLASFUNC(xtrsv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *);
int BLASFUNC(stpsv) (char *, char *, char *, int *, float *, float *, int *);
int BLASFUNC(dtpsv) (char *, char *, char *, int *, double *, double *, int *);
@ -210,18 +198,12 @@ int BLASFUNC(ctpsv) (char *, char *, char *, int *, float *, float *, int *);
int BLASFUNC(ztpsv) (char *, char *, char *, int *, double *, double *, int *);
int BLASFUNC(xtpsv) (char *, char *, char *, int *, double *, double *, int *);
int BLASFUNC(strmv) (char *, char *, char *, int *, float *, int *,
float *, int *);
int BLASFUNC(dtrmv) (char *, char *, char *, int *, double *, int *,
double *, int *);
int BLASFUNC(qtrmv) (char *, char *, char *, int *, double *, int *,
double *, int *);
int BLASFUNC(ctrmv) (char *, char *, char *, int *, float *, int *,
float *, int *);
int BLASFUNC(ztrmv) (char *, char *, char *, int *, double *, int *,
double *, int *);
int BLASFUNC(xtrmv) (char *, char *, char *, int *, double *, int *,
double *, int *);
int BLASFUNC(strmv) (const char *, const char *, const char *, const int *, const float *, const int *, float *, const int *);
int BLASFUNC(dtrmv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *);
int BLASFUNC(qtrmv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *);
int BLASFUNC(ctrmv) (const char *, const char *, const char *, const int *, const float *, const int *, float *, const int *);
int BLASFUNC(ztrmv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *);
int BLASFUNC(xtrmv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *);
int BLASFUNC(stpmv) (char *, char *, char *, int *, float *, float *, int *);
int BLASFUNC(dtpmv) (char *, char *, char *, int *, double *, double *, int *);
@ -244,18 +226,9 @@ int BLASFUNC(ctbsv) (char *, char *, char *, int *, int *, float *, int *, floa
int BLASFUNC(ztbsv) (char *, char *, char *, int *, int *, double *, int *, double *, int *);
int BLASFUNC(xtbsv) (char *, char *, char *, int *, int *, double *, int *, double *, int *);
int BLASFUNC(ssymv) (char *, int *, float *, float *, int *,
float *, int *, float *, float *, int *);
int BLASFUNC(dsymv) (char *, int *, double *, double *, int *,
double *, int *, double *, double *, int *);
int BLASFUNC(qsymv) (char *, int *, double *, double *, int *,
double *, int *, double *, double *, int *);
int BLASFUNC(csymv) (char *, int *, float *, float *, int *,
float *, int *, float *, float *, int *);
int BLASFUNC(zsymv) (char *, int *, double *, double *, int *,
double *, int *, double *, double *, int *);
int BLASFUNC(xsymv) (char *, int *, double *, double *, int *,
double *, int *, double *, double *, int *);
int BLASFUNC(ssymv) (const char *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
int BLASFUNC(dsymv) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(qsymv) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(sspmv) (char *, int *, float *, float *,
float *, int *, float *, float *, int *);
@ -263,38 +236,17 @@ int BLASFUNC(dspmv) (char *, int *, double *, double *,
double *, int *, double *, double *, int *);
int BLASFUNC(qspmv) (char *, int *, double *, double *,
double *, int *, double *, double *, int *);
int BLASFUNC(cspmv) (char *, int *, float *, float *,
float *, int *, float *, float *, int *);
int BLASFUNC(zspmv) (char *, int *, double *, double *,
double *, int *, double *, double *, int *);
int BLASFUNC(xspmv) (char *, int *, double *, double *,
double *, int *, double *, double *, int *);
int BLASFUNC(ssyr) (char *, int *, float *, float *, int *,
float *, int *);
int BLASFUNC(dsyr) (char *, int *, double *, double *, int *,
double *, int *);
int BLASFUNC(qsyr) (char *, int *, double *, double *, int *,
double *, int *);
int BLASFUNC(csyr) (char *, int *, float *, float *, int *,
float *, int *);
int BLASFUNC(zsyr) (char *, int *, double *, double *, int *,
double *, int *);
int BLASFUNC(xsyr) (char *, int *, double *, double *, int *,
double *, int *);
int BLASFUNC(ssyr) (const char *, const int *, const float *, const float *, const int *, float *, const int *);
int BLASFUNC(dsyr) (const char *, const int *, const double *, const double *, const int *, double *, const int *);
int BLASFUNC(qsyr) (const char *, const int *, const double *, const double *, const int *, double *, const int *);
int BLASFUNC(ssyr2) (char *, int *, float *,
float *, int *, float *, int *, float *, int *);
int BLASFUNC(dsyr2) (char *, int *, double *,
double *, int *, double *, int *, double *, int *);
int BLASFUNC(qsyr2) (char *, int *, double *,
double *, int *, double *, int *, double *, int *);
int BLASFUNC(csyr2) (char *, int *, float *,
float *, int *, float *, int *, float *, int *);
int BLASFUNC(zsyr2) (char *, int *, double *,
double *, int *, double *, int *, double *, int *);
int BLASFUNC(xsyr2) (char *, int *, double *,
double *, int *, double *, int *, double *, int *);
int BLASFUNC(ssyr2) (const char *, const int *, const float *, const float *, const int *, const float *, const int *, float *, const int *);
int BLASFUNC(dsyr2) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, double *, const int *);
int BLASFUNC(qsyr2) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, double *, const int *);
int BLASFUNC(csyr2) (const char *, const int *, const float *, const float *, const int *, const float *, const int *, float *, const int *);
int BLASFUNC(zsyr2) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, double *, const int *);
int BLASFUNC(xsyr2) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, double *, const int *);
int BLASFUNC(sspr) (char *, int *, float *, float *, int *,
float *);
@ -302,12 +254,6 @@ int BLASFUNC(dspr) (char *, int *, double *, double *, int *,
double *);
int BLASFUNC(qspr) (char *, int *, double *, double *, int *,
double *);
int BLASFUNC(cspr) (char *, int *, float *, float *, int *,
float *);
int BLASFUNC(zspr) (char *, int *, double *, double *, int *,
double *);
int BLASFUNC(xspr) (char *, int *, double *, double *, int *,
double *);
int BLASFUNC(sspr2) (char *, int *, float *,
float *, int *, float *, int *, float *);
@ -347,12 +293,9 @@ int BLASFUNC(zhpr2) (char *, int *, double *,
int BLASFUNC(xhpr2) (char *, int *, double *,
double *, int *, double *, int *, double *);
int BLASFUNC(chemv) (char *, int *, float *, float *, int *,
float *, int *, float *, float *, int *);
int BLASFUNC(zhemv) (char *, int *, double *, double *, int *,
double *, int *, double *, double *, int *);
int BLASFUNC(xhemv) (char *, int *, double *, double *, int *,
double *, int *, double *, double *, int *);
int BLASFUNC(chemv) (const char *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
int BLASFUNC(zhemv) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(xhemv) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(chpmv) (char *, int *, float *, float *,
float *, int *, float *, float *, int *);
@ -401,18 +344,12 @@ int BLASFUNC(xhbmv)(char *, int *, int *, double *, double *, int *,
/* Level 3 routines */
int BLASFUNC(sgemm)(char *, char *, int *, int *, int *, float *,
float *, int *, float *, int *, float *, float *, int *);
int BLASFUNC(dgemm)(char *, char *, int *, int *, int *, double *,
double *, int *, double *, int *, double *, double *, int *);
int BLASFUNC(qgemm)(char *, char *, int *, int *, int *, double *,
double *, int *, double *, int *, double *, double *, int *);
int BLASFUNC(cgemm)(char *, char *, int *, int *, int *, float *,
float *, int *, float *, int *, float *, float *, int *);
int BLASFUNC(zgemm)(char *, char *, int *, int *, int *, double *,
double *, int *, double *, int *, double *, double *, int *);
int BLASFUNC(xgemm)(char *, char *, int *, int *, int *, double *,
double *, int *, double *, int *, double *, double *, int *);
int BLASFUNC(sgemm)(const char *, const char *, const int *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
int BLASFUNC(dgemm)(const char *, const char *, const int *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(qgemm)(const char *, const char *, const int *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(cgemm)(const char *, const char *, const int *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
int BLASFUNC(zgemm)(const char *, const char *, const int *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(xgemm)(const char *, const char *, const int *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(cgemm3m)(char *, char *, int *, int *, int *, float *,
float *, int *, float *, int *, float *, float *, int *);
@ -434,84 +371,48 @@ int BLASFUNC(zge2mm)(char *, char *, char *, int *, int *,
double *, double *, int *, double *, int *,
double *, double *, int *);
int BLASFUNC(strsm)(char *, char *, char *, char *, int *, int *,
float *, float *, int *, float *, int *);
int BLASFUNC(dtrsm)(char *, char *, char *, char *, int *, int *,
double *, double *, int *, double *, int *);
int BLASFUNC(qtrsm)(char *, char *, char *, char *, int *, int *,
double *, double *, int *, double *, int *);
int BLASFUNC(ctrsm)(char *, char *, char *, char *, int *, int *,
float *, float *, int *, float *, int *);
int BLASFUNC(ztrsm)(char *, char *, char *, char *, int *, int *,
double *, double *, int *, double *, int *);
int BLASFUNC(xtrsm)(char *, char *, char *, char *, int *, int *,
double *, double *, int *, double *, int *);
int BLASFUNC(strsm)(const char *, const char *, const char *, const char *, const int *, const int *, const float *, const float *, const int *, float *, const int *);
int BLASFUNC(dtrsm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *);
int BLASFUNC(qtrsm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *);
int BLASFUNC(ctrsm)(const char *, const char *, const char *, const char *, const int *, const int *, const float *, const float *, const int *, float *, const int *);
int BLASFUNC(ztrsm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *);
int BLASFUNC(xtrsm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *);
int BLASFUNC(strmm)(char *, char *, char *, char *, int *, int *,
float *, float *, int *, float *, int *);
int BLASFUNC(dtrmm)(char *, char *, char *, char *, int *, int *,
double *, double *, int *, double *, int *);
int BLASFUNC(qtrmm)(char *, char *, char *, char *, int *, int *,
double *, double *, int *, double *, int *);
int BLASFUNC(ctrmm)(char *, char *, char *, char *, int *, int *,
float *, float *, int *, float *, int *);
int BLASFUNC(ztrmm)(char *, char *, char *, char *, int *, int *,
double *, double *, int *, double *, int *);
int BLASFUNC(xtrmm)(char *, char *, char *, char *, int *, int *,
double *, double *, int *, double *, int *);
int BLASFUNC(strmm)(const char *, const char *, const char *, const char *, const int *, const int *, const float *, const float *, const int *, float *, const int *);
int BLASFUNC(dtrmm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *);
int BLASFUNC(qtrmm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *);
int BLASFUNC(ctrmm)(const char *, const char *, const char *, const char *, const int *, const int *, const float *, const float *, const int *, float *, const int *);
int BLASFUNC(ztrmm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *);
int BLASFUNC(xtrmm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *);
int BLASFUNC(ssymm)(char *, char *, int *, int *, float *, float *, int *,
float *, int *, float *, float *, int *);
int BLASFUNC(dsymm)(char *, char *, int *, int *, double *, double *, int *,
double *, int *, double *, double *, int *);
int BLASFUNC(qsymm)(char *, char *, int *, int *, double *, double *, int *,
double *, int *, double *, double *, int *);
int BLASFUNC(csymm)(char *, char *, int *, int *, float *, float *, int *,
float *, int *, float *, float *, int *);
int BLASFUNC(zsymm)(char *, char *, int *, int *, double *, double *, int *,
double *, int *, double *, double *, int *);
int BLASFUNC(xsymm)(char *, char *, int *, int *, double *, double *, int *,
double *, int *, double *, double *, int *);
int BLASFUNC(ssymm)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
int BLASFUNC(dsymm)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(qsymm)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(csymm)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
int BLASFUNC(zsymm)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(xsymm)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(csymm3m)(char *, char *, int *, int *, float *, float *, int *,
float *, int *, float *, float *, int *);
int BLASFUNC(zsymm3m)(char *, char *, int *, int *, double *, double *, int *,
double *, int *, double *, double *, int *);
int BLASFUNC(xsymm3m)(char *, char *, int *, int *, double *, double *, int *,
double *, int *, double *, double *, int *);
int BLASFUNC(csymm3m)(char *, char *, int *, int *, float *, float *, int *, float *, int *, float *, float *, int *);
int BLASFUNC(zsymm3m)(char *, char *, int *, int *, double *, double *, int *, double *, int *, double *, double *, int *);
int BLASFUNC(xsymm3m)(char *, char *, int *, int *, double *, double *, int *, double *, int *, double *, double *, int *);
int BLASFUNC(ssyrk)(char *, char *, int *, int *, float *, float *, int *,
float *, float *, int *);
int BLASFUNC(dsyrk)(char *, char *, int *, int *, double *, double *, int *,
double *, double *, int *);
int BLASFUNC(qsyrk)(char *, char *, int *, int *, double *, double *, int *,
double *, double *, int *);
int BLASFUNC(csyrk)(char *, char *, int *, int *, float *, float *, int *,
float *, float *, int *);
int BLASFUNC(zsyrk)(char *, char *, int *, int *, double *, double *, int *,
double *, double *, int *);
int BLASFUNC(xsyrk)(char *, char *, int *, int *, double *, double *, int *,
double *, double *, int *);
int BLASFUNC(ssyrk)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, float *, const int *);
int BLASFUNC(dsyrk)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(qsyrk)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(csyrk)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, float *, const int *);
int BLASFUNC(zsyrk)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(xsyrk)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(ssyr2k)(char *, char *, int *, int *, float *, float *, int *,
float *, int *, float *, float *, int *);
int BLASFUNC(dsyr2k)(char *, char *, int *, int *, double *, double *, int *,
double*, int *, double *, double *, int *);
int BLASFUNC(qsyr2k)(char *, char *, int *, int *, double *, double *, int *,
double*, int *, double *, double *, int *);
int BLASFUNC(csyr2k)(char *, char *, int *, int *, float *, float *, int *,
float *, int *, float *, float *, int *);
int BLASFUNC(zsyr2k)(char *, char *, int *, int *, double *, double *, int *,
double*, int *, double *, double *, int *);
int BLASFUNC(xsyr2k)(char *, char *, int *, int *, double *, double *, int *,
double*, int *, double *, double *, int *);
int BLASFUNC(ssyr2k)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
int BLASFUNC(dsyr2k)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double*, const int *, const double *, double *, const int *);
int BLASFUNC(qsyr2k)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double*, const int *, const double *, double *, const int *);
int BLASFUNC(csyr2k)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
int BLASFUNC(zsyr2k)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double*, const int *, const double *, double *, const int *);
int BLASFUNC(xsyr2k)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double*, const int *, const double *, double *, const int *);
int BLASFUNC(chemm)(char *, char *, int *, int *, float *, float *, int *,
float *, int *, float *, float *, int *);
int BLASFUNC(zhemm)(char *, char *, int *, int *, double *, double *, int *,
double *, int *, double *, double *, int *);
int BLASFUNC(xhemm)(char *, char *, int *, int *, double *, double *, int *,
double *, int *, double *, double *, int *);
int BLASFUNC(chemm)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
int BLASFUNC(zhemm)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(xhemm)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(chemm3m)(char *, char *, int *, int *, float *, float *, int *,
float *, int *, float *, float *, int *);
@ -520,136 +421,17 @@ int BLASFUNC(zhemm3m)(char *, char *, int *, int *, double *, double *, int *,
int BLASFUNC(xhemm3m)(char *, char *, int *, int *, double *, double *, int *,
double *, int *, double *, double *, int *);
int BLASFUNC(cherk)(char *, char *, int *, int *, float *, float *, int *,
float *, float *, int *);
int BLASFUNC(zherk)(char *, char *, int *, int *, double *, double *, int *,
double *, double *, int *);
int BLASFUNC(xherk)(char *, char *, int *, int *, double *, double *, int *,
double *, double *, int *);
int BLASFUNC(cherk)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, float *, const int *);
int BLASFUNC(zherk)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(xherk)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(cher2k)(char *, char *, int *, int *, float *, float *, int *,
float *, int *, float *, float *, int *);
int BLASFUNC(zher2k)(char *, char *, int *, int *, double *, double *, int *,
double*, int *, double *, double *, int *);
int BLASFUNC(xher2k)(char *, char *, int *, int *, double *, double *, int *,
double*, int *, double *, double *, int *);
int BLASFUNC(cher2m)(char *, char *, char *, int *, int *, float *, float *, int *,
float *, int *, float *, float *, int *);
int BLASFUNC(zher2m)(char *, char *, char *, int *, int *, double *, double *, int *,
double*, int *, double *, double *, int *);
int BLASFUNC(xher2m)(char *, char *, char *, int *, int *, double *, double *, int *,
double*, int *, double *, double *, int *);
int BLASFUNC(cher2k)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
int BLASFUNC(zher2k)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(xher2k)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(cher2m)(const char *, const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
int BLASFUNC(zher2m)(const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double*, const int *, const double *, double *, const int *);
int BLASFUNC(xher2m)(const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double*, const int *, const double *, double *, const int *);
int BLASFUNC(sgemt)(char *, int *, int *, float *, float *, int *,
float *, int *);
int BLASFUNC(dgemt)(char *, int *, int *, double *, double *, int *,
double *, int *);
int BLASFUNC(cgemt)(char *, int *, int *, float *, float *, int *,
float *, int *);
int BLASFUNC(zgemt)(char *, int *, int *, double *, double *, int *,
double *, int *);
int BLASFUNC(sgema)(char *, char *, int *, int *, float *,
float *, int *, float *, float *, int *, float *, int *);
int BLASFUNC(dgema)(char *, char *, int *, int *, double *,
double *, int *, double*, double *, int *, double*, int *);
int BLASFUNC(cgema)(char *, char *, int *, int *, float *,
float *, int *, float *, float *, int *, float *, int *);
int BLASFUNC(zgema)(char *, char *, int *, int *, double *,
double *, int *, double*, double *, int *, double*, int *);
int BLASFUNC(sgems)(char *, char *, int *, int *, float *,
float *, int *, float *, float *, int *, float *, int *);
int BLASFUNC(dgems)(char *, char *, int *, int *, double *,
double *, int *, double*, double *, int *, double*, int *);
int BLASFUNC(cgems)(char *, char *, int *, int *, float *,
float *, int *, float *, float *, int *, float *, int *);
int BLASFUNC(zgems)(char *, char *, int *, int *, double *,
double *, int *, double*, double *, int *, double*, int *);
int BLASFUNC(sgetf2)(int *, int *, float *, int *, int *, int *);
int BLASFUNC(dgetf2)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(qgetf2)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(cgetf2)(int *, int *, float *, int *, int *, int *);
int BLASFUNC(zgetf2)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(xgetf2)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(sgetrf)(int *, int *, float *, int *, int *, int *);
int BLASFUNC(dgetrf)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(qgetrf)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(cgetrf)(int *, int *, float *, int *, int *, int *);
int BLASFUNC(zgetrf)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(xgetrf)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(slaswp)(int *, float *, int *, int *, int *, int *, int *);
int BLASFUNC(dlaswp)(int *, double *, int *, int *, int *, int *, int *);
int BLASFUNC(qlaswp)(int *, double *, int *, int *, int *, int *, int *);
int BLASFUNC(claswp)(int *, float *, int *, int *, int *, int *, int *);
int BLASFUNC(zlaswp)(int *, double *, int *, int *, int *, int *, int *);
int BLASFUNC(xlaswp)(int *, double *, int *, int *, int *, int *, int *);
int BLASFUNC(sgetrs)(char *, int *, int *, float *, int *, int *, float *, int *, int *);
int BLASFUNC(dgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *);
int BLASFUNC(qgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *);
int BLASFUNC(cgetrs)(char *, int *, int *, float *, int *, int *, float *, int *, int *);
int BLASFUNC(zgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *);
int BLASFUNC(xgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *);
int BLASFUNC(sgesv)(int *, int *, float *, int *, int *, float *, int *, int *);
int BLASFUNC(dgesv)(int *, int *, double *, int *, int *, double*, int *, int *);
int BLASFUNC(qgesv)(int *, int *, double *, int *, int *, double*, int *, int *);
int BLASFUNC(cgesv)(int *, int *, float *, int *, int *, float *, int *, int *);
int BLASFUNC(zgesv)(int *, int *, double *, int *, int *, double*, int *, int *);
int BLASFUNC(xgesv)(int *, int *, double *, int *, int *, double*, int *, int *);
int BLASFUNC(spotf2)(char *, int *, float *, int *, int *);
int BLASFUNC(dpotf2)(char *, int *, double *, int *, int *);
int BLASFUNC(qpotf2)(char *, int *, double *, int *, int *);
int BLASFUNC(cpotf2)(char *, int *, float *, int *, int *);
int BLASFUNC(zpotf2)(char *, int *, double *, int *, int *);
int BLASFUNC(xpotf2)(char *, int *, double *, int *, int *);
int BLASFUNC(spotrf)(char *, int *, float *, int *, int *);
int BLASFUNC(dpotrf)(char *, int *, double *, int *, int *);
int BLASFUNC(qpotrf)(char *, int *, double *, int *, int *);
int BLASFUNC(cpotrf)(char *, int *, float *, int *, int *);
int BLASFUNC(zpotrf)(char *, int *, double *, int *, int *);
int BLASFUNC(xpotrf)(char *, int *, double *, int *, int *);
int BLASFUNC(slauu2)(char *, int *, float *, int *, int *);
int BLASFUNC(dlauu2)(char *, int *, double *, int *, int *);
int BLASFUNC(qlauu2)(char *, int *, double *, int *, int *);
int BLASFUNC(clauu2)(char *, int *, float *, int *, int *);
int BLASFUNC(zlauu2)(char *, int *, double *, int *, int *);
int BLASFUNC(xlauu2)(char *, int *, double *, int *, int *);
int BLASFUNC(slauum)(char *, int *, float *, int *, int *);
int BLASFUNC(dlauum)(char *, int *, double *, int *, int *);
int BLASFUNC(qlauum)(char *, int *, double *, int *, int *);
int BLASFUNC(clauum)(char *, int *, float *, int *, int *);
int BLASFUNC(zlauum)(char *, int *, double *, int *, int *);
int BLASFUNC(xlauum)(char *, int *, double *, int *, int *);
int BLASFUNC(strti2)(char *, char *, int *, float *, int *, int *);
int BLASFUNC(dtrti2)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(qtrti2)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(ctrti2)(char *, char *, int *, float *, int *, int *);
int BLASFUNC(ztrti2)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(xtrti2)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(strtri)(char *, char *, int *, float *, int *, int *);
int BLASFUNC(dtrtri)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(qtrtri)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(ctrtri)(char *, char *, int *, float *, int *, int *);
int BLASFUNC(ztrtri)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(xtrtri)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(spotri)(char *, int *, float *, int *, int *);
int BLASFUNC(dpotri)(char *, int *, double *, int *, int *);
int BLASFUNC(qpotri)(char *, int *, double *, int *, int *);
int BLASFUNC(cpotri)(char *, int *, float *, int *, int *);
int BLASFUNC(zpotri)(char *, int *, double *, int *, int *);
int BLASFUNC(xpotri)(char *, int *, double *, int *, int *);
#ifdef __cplusplus
}

152
Eigen/src/misc/lapack.h Normal file
View File

@ -0,0 +1,152 @@
#ifndef LAPACK_H
#define LAPACK_H
#include "blas.h"
#ifdef __cplusplus
extern "C"
{
#endif
int BLASFUNC(csymv) (const char *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
int BLASFUNC(zsymv) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(xsymv) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
int BLASFUNC(cspmv) (char *, int *, float *, float *,
float *, int *, float *, float *, int *);
int BLASFUNC(zspmv) (char *, int *, double *, double *,
double *, int *, double *, double *, int *);
int BLASFUNC(xspmv) (char *, int *, double *, double *,
double *, int *, double *, double *, int *);
int BLASFUNC(csyr) (char *, int *, float *, float *, int *,
float *, int *);
int BLASFUNC(zsyr) (char *, int *, double *, double *, int *,
double *, int *);
int BLASFUNC(xsyr) (char *, int *, double *, double *, int *,
double *, int *);
int BLASFUNC(cspr) (char *, int *, float *, float *, int *,
float *);
int BLASFUNC(zspr) (char *, int *, double *, double *, int *,
double *);
int BLASFUNC(xspr) (char *, int *, double *, double *, int *,
double *);
int BLASFUNC(sgemt)(char *, int *, int *, float *, float *, int *,
float *, int *);
int BLASFUNC(dgemt)(char *, int *, int *, double *, double *, int *,
double *, int *);
int BLASFUNC(cgemt)(char *, int *, int *, float *, float *, int *,
float *, int *);
int BLASFUNC(zgemt)(char *, int *, int *, double *, double *, int *,
double *, int *);
int BLASFUNC(sgema)(char *, char *, int *, int *, float *,
float *, int *, float *, float *, int *, float *, int *);
int BLASFUNC(dgema)(char *, char *, int *, int *, double *,
double *, int *, double*, double *, int *, double*, int *);
int BLASFUNC(cgema)(char *, char *, int *, int *, float *,
float *, int *, float *, float *, int *, float *, int *);
int BLASFUNC(zgema)(char *, char *, int *, int *, double *,
double *, int *, double*, double *, int *, double*, int *);
int BLASFUNC(sgems)(char *, char *, int *, int *, float *,
float *, int *, float *, float *, int *, float *, int *);
int BLASFUNC(dgems)(char *, char *, int *, int *, double *,
double *, int *, double*, double *, int *, double*, int *);
int BLASFUNC(cgems)(char *, char *, int *, int *, float *,
float *, int *, float *, float *, int *, float *, int *);
int BLASFUNC(zgems)(char *, char *, int *, int *, double *,
double *, int *, double*, double *, int *, double*, int *);
int BLASFUNC(sgetf2)(int *, int *, float *, int *, int *, int *);
int BLASFUNC(dgetf2)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(qgetf2)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(cgetf2)(int *, int *, float *, int *, int *, int *);
int BLASFUNC(zgetf2)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(xgetf2)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(sgetrf)(int *, int *, float *, int *, int *, int *);
int BLASFUNC(dgetrf)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(qgetrf)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(cgetrf)(int *, int *, float *, int *, int *, int *);
int BLASFUNC(zgetrf)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(xgetrf)(int *, int *, double *, int *, int *, int *);
int BLASFUNC(slaswp)(int *, float *, int *, int *, int *, int *, int *);
int BLASFUNC(dlaswp)(int *, double *, int *, int *, int *, int *, int *);
int BLASFUNC(qlaswp)(int *, double *, int *, int *, int *, int *, int *);
int BLASFUNC(claswp)(int *, float *, int *, int *, int *, int *, int *);
int BLASFUNC(zlaswp)(int *, double *, int *, int *, int *, int *, int *);
int BLASFUNC(xlaswp)(int *, double *, int *, int *, int *, int *, int *);
int BLASFUNC(sgetrs)(char *, int *, int *, float *, int *, int *, float *, int *, int *);
int BLASFUNC(dgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *);
int BLASFUNC(qgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *);
int BLASFUNC(cgetrs)(char *, int *, int *, float *, int *, int *, float *, int *, int *);
int BLASFUNC(zgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *);
int BLASFUNC(xgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *);
int BLASFUNC(sgesv)(int *, int *, float *, int *, int *, float *, int *, int *);
int BLASFUNC(dgesv)(int *, int *, double *, int *, int *, double*, int *, int *);
int BLASFUNC(qgesv)(int *, int *, double *, int *, int *, double*, int *, int *);
int BLASFUNC(cgesv)(int *, int *, float *, int *, int *, float *, int *, int *);
int BLASFUNC(zgesv)(int *, int *, double *, int *, int *, double*, int *, int *);
int BLASFUNC(xgesv)(int *, int *, double *, int *, int *, double*, int *, int *);
int BLASFUNC(spotf2)(char *, int *, float *, int *, int *);
int BLASFUNC(dpotf2)(char *, int *, double *, int *, int *);
int BLASFUNC(qpotf2)(char *, int *, double *, int *, int *);
int BLASFUNC(cpotf2)(char *, int *, float *, int *, int *);
int BLASFUNC(zpotf2)(char *, int *, double *, int *, int *);
int BLASFUNC(xpotf2)(char *, int *, double *, int *, int *);
int BLASFUNC(spotrf)(char *, int *, float *, int *, int *);
int BLASFUNC(dpotrf)(char *, int *, double *, int *, int *);
int BLASFUNC(qpotrf)(char *, int *, double *, int *, int *);
int BLASFUNC(cpotrf)(char *, int *, float *, int *, int *);
int BLASFUNC(zpotrf)(char *, int *, double *, int *, int *);
int BLASFUNC(xpotrf)(char *, int *, double *, int *, int *);
int BLASFUNC(slauu2)(char *, int *, float *, int *, int *);
int BLASFUNC(dlauu2)(char *, int *, double *, int *, int *);
int BLASFUNC(qlauu2)(char *, int *, double *, int *, int *);
int BLASFUNC(clauu2)(char *, int *, float *, int *, int *);
int BLASFUNC(zlauu2)(char *, int *, double *, int *, int *);
int BLASFUNC(xlauu2)(char *, int *, double *, int *, int *);
int BLASFUNC(slauum)(char *, int *, float *, int *, int *);
int BLASFUNC(dlauum)(char *, int *, double *, int *, int *);
int BLASFUNC(qlauum)(char *, int *, double *, int *, int *);
int BLASFUNC(clauum)(char *, int *, float *, int *, int *);
int BLASFUNC(zlauum)(char *, int *, double *, int *, int *);
int BLASFUNC(xlauum)(char *, int *, double *, int *, int *);
int BLASFUNC(strti2)(char *, char *, int *, float *, int *, int *);
int BLASFUNC(dtrti2)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(qtrti2)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(ctrti2)(char *, char *, int *, float *, int *, int *);
int BLASFUNC(ztrti2)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(xtrti2)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(strtri)(char *, char *, int *, float *, int *, int *);
int BLASFUNC(dtrtri)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(qtrtri)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(ctrtri)(char *, char *, int *, float *, int *, int *);
int BLASFUNC(ztrtri)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(xtrtri)(char *, char *, int *, double *, int *, int *);
int BLASFUNC(spotri)(char *, int *, float *, int *, int *);
int BLASFUNC(dpotri)(char *, int *, double *, int *, int *);
int BLASFUNC(qpotri)(char *, int *, double *, int *, int *);
int BLASFUNC(cpotri)(char *, int *, float *, int *, int *);
int BLASFUNC(zpotri)(char *, int *, double *, int *, int *);
int BLASFUNC(xpotri)(char *, int *, double *, int *, int *);
#ifdef __cplusplus
}
#endif
#endif

View File

@ -10,8 +10,8 @@
#ifndef EIGEN_BLAS_COMMON_H
#define EIGEN_BLAS_COMMON_H
#include <Eigen/Core>
#include <Eigen/Jacobi>
#include "../Eigen/Core"
#include "../Eigen/Jacobi"
#include <complex>
@ -19,8 +19,7 @@
#error the token SCALAR must be defined to compile this file
#endif
#include <Eigen/src/misc/blas.h>
#include "../Eigen/src/misc/blas.h"
#define NOTR 0
#define TR 1
@ -94,6 +93,7 @@ enum
typedef Matrix<Scalar,Dynamic,Dynamic,ColMajor> PlainMatrixType;
typedef Map<Matrix<Scalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > MatrixType;
typedef Map<const Matrix<Scalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > ConstMatrixType;
typedef Map<Matrix<Scalar,Dynamic,1>, 0, InnerStride<Dynamic> > StridedVectorType;
typedef Map<Matrix<Scalar,Dynamic,1> > CompactVectorType;
@ -104,25 +104,44 @@ matrix(T* data, int rows, int cols, int stride)
return Map<Matrix<T,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> >(data, rows, cols, OuterStride<>(stride));
}
template<typename T>
Map<const Matrix<T,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> >
matrix(const T* data, int rows, int cols, int stride)
{
return Map<const Matrix<T,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> >(data, rows, cols, OuterStride<>(stride));
}
template<typename T>
Map<Matrix<T,Dynamic,1>, 0, InnerStride<Dynamic> > make_vector(T* data, int size, int incr)
{
return Map<Matrix<T,Dynamic,1>, 0, InnerStride<Dynamic> >(data, size, InnerStride<Dynamic>(incr));
}
template<typename T>
Map<const Matrix<T,Dynamic,1>, 0, InnerStride<Dynamic> > make_vector(const T* data, int size, int incr)
{
return Map<const Matrix<T,Dynamic,1>, 0, InnerStride<Dynamic> >(data, size, InnerStride<Dynamic>(incr));
}
template<typename T>
Map<Matrix<T,Dynamic,1> > make_vector(T* data, int size)
{
return Map<Matrix<T,Dynamic,1> >(data, size);
}
template<typename T>
Map<const Matrix<T,Dynamic,1> > make_vector(const T* data, int size)
{
return Map<const Matrix<T,Dynamic,1> >(data, size);
}
template<typename T>
T* get_compact_vector(T* x, int n, int incx)
{
if(incx==1)
return x;
T* ret = new Scalar[n];
typename Eigen::internal::remove_const<T>::type* ret = new Scalar[n];
if(incx<0) make_vector(ret,n) = make_vector(x,n,-incx).reverse();
else make_vector(ret,n) = make_vector(x,n, incx);
return ret;

View File

@ -9,11 +9,11 @@
#include "common.h"
int EIGEN_BLAS_FUNC(axpy)(int *n, RealScalar *palpha, RealScalar *px, int *incx, RealScalar *py, int *incy)
int EIGEN_BLAS_FUNC(axpy)(const int *n, const RealScalar *palpha, const RealScalar *px, const int *incx, RealScalar *py, const int *incy)
{
Scalar* x = reinterpret_cast<Scalar*>(px);
const Scalar* x = reinterpret_cast<const Scalar*>(px);
Scalar* y = reinterpret_cast<Scalar*>(py);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
if(*n<=0) return 0;

View File

@ -16,7 +16,8 @@
* where alpha and beta are scalars, x and y are n element vectors and
* A is an n by n hermitian matrix.
*/
int EIGEN_BLAS_FUNC(hemv)(char *uplo, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *px, int *incx, RealScalar *pbeta, RealScalar *py, int *incy)
int EIGEN_BLAS_FUNC(hemv)(const char *uplo, const int *n, const RealScalar *palpha, const RealScalar *pa, const int *lda,
const RealScalar *px, const int *incx, const RealScalar *pbeta, RealScalar *py, const int *incy)
{
typedef void (*functype)(int, const Scalar*, int, const Scalar*, Scalar*, Scalar);
static const functype func[2] = {
@ -26,11 +27,11 @@ int EIGEN_BLAS_FUNC(hemv)(char *uplo, int *n, RealScalar *palpha, RealScalar *pa
(internal::selfadjoint_matrix_vector_product<Scalar,int,ColMajor,Lower,false,false>::run),
};
Scalar* a = reinterpret_cast<Scalar*>(pa);
Scalar* x = reinterpret_cast<Scalar*>(px);
const Scalar* a = reinterpret_cast<const Scalar*>(pa);
const Scalar* x = reinterpret_cast<const Scalar*>(px);
Scalar* y = reinterpret_cast<Scalar*>(py);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
// check arguments
int info = 0;
@ -45,7 +46,7 @@ int EIGEN_BLAS_FUNC(hemv)(char *uplo, int *n, RealScalar *palpha, RealScalar *pa
if(*n==0)
return 1;
Scalar* actual_x = get_compact_vector(x,*n,*incx);
const Scalar* actual_x = get_compact_vector(x,*n,*incx);
Scalar* actual_y = get_compact_vector(y,*n,*incy);
if(beta!=Scalar(1))

View File

@ -23,7 +23,8 @@ struct general_matrix_vector_product_wrapper
}
};
int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *incb, RealScalar *pbeta, RealScalar *pc, int *incc)
int EIGEN_BLAS_FUNC(gemv)(const char *opa, const int *m, const int *n, const RealScalar *palpha,
const RealScalar *pa, const int *lda, const RealScalar *pb, const int *incb, const RealScalar *pbeta, RealScalar *pc, const int *incc)
{
typedef void (*functype)(int, int, const Scalar *, int, const Scalar *, int , Scalar *, int, Scalar);
static const functype func[4] = {
@ -36,11 +37,11 @@ int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealSca
0
};
Scalar* a = reinterpret_cast<Scalar*>(pa);
Scalar* b = reinterpret_cast<Scalar*>(pb);
const Scalar* a = reinterpret_cast<const Scalar*>(pa);
const Scalar* b = reinterpret_cast<const Scalar*>(pb);
Scalar* c = reinterpret_cast<Scalar*>(pc);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
// check arguments
int info = 0;
@ -62,7 +63,7 @@ int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealSca
if(code!=NOTR)
std::swap(actual_m,actual_n);
Scalar* actual_b = get_compact_vector(b,actual_n,*incb);
const Scalar* actual_b = get_compact_vector(b,actual_n,*incb);
Scalar* actual_c = get_compact_vector(c,actual_m,*incc);
if(beta!=Scalar(1))
@ -82,7 +83,7 @@ int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealSca
return 1;
}
int EIGEN_BLAS_FUNC(trsv)(char *uplo, char *opa, char *diag, int *n, RealScalar *pa, int *lda, RealScalar *pb, int *incb)
int EIGEN_BLAS_FUNC(trsv)(const char *uplo, const char *opa, const char *diag, const int *n, const RealScalar *pa, const int *lda, RealScalar *pb, const int *incb)
{
typedef void (*functype)(int, const Scalar *, int, Scalar *);
static const functype func[16] = {
@ -116,7 +117,7 @@ int EIGEN_BLAS_FUNC(trsv)(char *uplo, char *opa, char *diag, int *n, RealScalar
0
};
Scalar* a = reinterpret_cast<Scalar*>(pa);
const Scalar* a = reinterpret_cast<const Scalar*>(pa);
Scalar* b = reinterpret_cast<Scalar*>(pb);
int info = 0;
@ -141,7 +142,7 @@ int EIGEN_BLAS_FUNC(trsv)(char *uplo, char *opa, char *diag, int *n, RealScalar
int EIGEN_BLAS_FUNC(trmv)(char *uplo, char *opa, char *diag, int *n, RealScalar *pa, int *lda, RealScalar *pb, int *incb)
int EIGEN_BLAS_FUNC(trmv)(const char *uplo, const char *opa, const char *diag, const int *n, const RealScalar *pa, const int *lda, RealScalar *pb, const int *incb)
{
typedef void (*functype)(int, int, const Scalar *, int, const Scalar *, int, Scalar *, int, const Scalar&);
static const functype func[16] = {
@ -175,7 +176,7 @@ int EIGEN_BLAS_FUNC(trmv)(char *uplo, char *opa, char *diag, int *n, RealScalar
0
};
Scalar* a = reinterpret_cast<Scalar*>(pa);
const Scalar* a = reinterpret_cast<const Scalar*>(pa);
Scalar* b = reinterpret_cast<Scalar*>(pb);
int info = 0;
@ -217,11 +218,11 @@ int EIGEN_BLAS_FUNC(trmv)(char *uplo, char *opa, char *diag, int *n, RealScalar
int EIGEN_BLAS_FUNC(gbmv)(char *trans, int *m, int *n, int *kl, int *ku, RealScalar *palpha, RealScalar *pa, int *lda,
RealScalar *px, int *incx, RealScalar *pbeta, RealScalar *py, int *incy)
{
Scalar* a = reinterpret_cast<Scalar*>(pa);
Scalar* x = reinterpret_cast<Scalar*>(px);
const Scalar* a = reinterpret_cast<const Scalar*>(pa);
const Scalar* x = reinterpret_cast<const Scalar*>(px);
Scalar* y = reinterpret_cast<Scalar*>(py);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
int coeff_rows = *kl+*ku+1;
int info = 0;
@ -244,7 +245,7 @@ int EIGEN_BLAS_FUNC(gbmv)(char *trans, int *m, int *n, int *kl, int *ku, RealSca
if(OP(*trans)!=NOTR)
std::swap(actual_m,actual_n);
Scalar* actual_x = get_compact_vector(x,actual_n,*incx);
const Scalar* actual_x = get_compact_vector(x,actual_n,*incx);
Scalar* actual_y = get_compact_vector(y,actual_m,*incy);
if(beta!=Scalar(1))
@ -253,7 +254,7 @@ int EIGEN_BLAS_FUNC(gbmv)(char *trans, int *m, int *n, int *kl, int *ku, RealSca
else make_vector(actual_y, actual_m) *= beta;
}
MatrixType mat_coeffs(a,coeff_rows,*n,*lda);
ConstMatrixType mat_coeffs(a,coeff_rows,*n,*lda);
int nb = std::min(*n,(*m)+(*ku));
for(int j=0; j<nb; ++j)

View File

@ -10,7 +10,8 @@
#include "common.h"
// y = alpha*A*x + beta*y
int EIGEN_BLAS_FUNC(symv) (char *uplo, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *px, int *incx, RealScalar *pbeta, RealScalar *py, int *incy)
int EIGEN_BLAS_FUNC(symv) (const char *uplo, const int *n, const RealScalar *palpha, const RealScalar *pa, const int *lda,
const RealScalar *px, const int *incx, const RealScalar *pbeta, RealScalar *py, const int *incy)
{
typedef void (*functype)(int, const Scalar*, int, const Scalar*, Scalar*, Scalar);
static const functype func[2] = {
@ -20,11 +21,11 @@ int EIGEN_BLAS_FUNC(symv) (char *uplo, int *n, RealScalar *palpha, RealScalar *p
(internal::selfadjoint_matrix_vector_product<Scalar,int,ColMajor,Lower,false,false>::run),
};
Scalar* a = reinterpret_cast<Scalar*>(pa);
Scalar* x = reinterpret_cast<Scalar*>(px);
const Scalar* a = reinterpret_cast<const Scalar*>(pa);
const Scalar* x = reinterpret_cast<const Scalar*>(px);
Scalar* y = reinterpret_cast<Scalar*>(py);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
// check arguments
int info = 0;
@ -39,7 +40,7 @@ int EIGEN_BLAS_FUNC(symv) (char *uplo, int *n, RealScalar *palpha, RealScalar *p
if(*n==0)
return 0;
Scalar* actual_x = get_compact_vector(x,*n,*incx);
const Scalar* actual_x = get_compact_vector(x,*n,*incx);
Scalar* actual_y = get_compact_vector(y,*n,*incy);
if(beta!=Scalar(1))
@ -61,7 +62,7 @@ int EIGEN_BLAS_FUNC(symv) (char *uplo, int *n, RealScalar *palpha, RealScalar *p
}
// C := alpha*x*x' + C
int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *px, int *incx, RealScalar *pc, int *ldc)
int EIGEN_BLAS_FUNC(syr)(const char *uplo, const int *n, const RealScalar *palpha, const RealScalar *px, const int *incx, RealScalar *pc, const int *ldc)
{
typedef void (*functype)(int, Scalar*, int, const Scalar*, const Scalar*, const Scalar&);
@ -72,9 +73,9 @@ int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *px,
(selfadjoint_rank1_update<Scalar,int,ColMajor,Lower,false,Conj>::run),
};
Scalar* x = reinterpret_cast<Scalar*>(px);
const Scalar* x = reinterpret_cast<const Scalar*>(px);
Scalar* c = reinterpret_cast<Scalar*>(pc);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
int info = 0;
if(UPLO(*uplo)==INVALID) info = 1;
@ -87,7 +88,7 @@ int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *px,
if(*n==0 || alpha==Scalar(0)) return 1;
// if the increment is not 1, let's copy it to a temporary vector to enable vectorization
Scalar* x_cpy = get_compact_vector(x,*n,*incx);
const Scalar* x_cpy = get_compact_vector(x,*n,*incx);
int code = UPLO(*uplo);
if(code>=2 || func[code]==0)
@ -101,7 +102,7 @@ int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *px,
}
// C := alpha*x*y' + alpha*y*x' + C
int EIGEN_BLAS_FUNC(syr2)(char *uplo, int *n, RealScalar *palpha, RealScalar *px, int *incx, RealScalar *py, int *incy, RealScalar *pc, int *ldc)
int EIGEN_BLAS_FUNC(syr2)(const char *uplo, const int *n, const RealScalar *palpha, const RealScalar *px, const int *incx, const RealScalar *py, const int *incy, RealScalar *pc, const int *ldc)
{
typedef void (*functype)(int, Scalar*, int, const Scalar*, const Scalar*, Scalar);
static const functype func[2] = {
@ -111,10 +112,10 @@ int EIGEN_BLAS_FUNC(syr2)(char *uplo, int *n, RealScalar *palpha, RealScalar *px
(internal::rank2_update_selector<Scalar,int,Lower>::run),
};
Scalar* x = reinterpret_cast<Scalar*>(px);
Scalar* y = reinterpret_cast<Scalar*>(py);
const Scalar* x = reinterpret_cast<const Scalar*>(px);
const Scalar* y = reinterpret_cast<const Scalar*>(py);
Scalar* c = reinterpret_cast<Scalar*>(pc);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
int info = 0;
if(UPLO(*uplo)==INVALID) info = 1;
@ -128,8 +129,8 @@ int EIGEN_BLAS_FUNC(syr2)(char *uplo, int *n, RealScalar *palpha, RealScalar *px
if(alpha==Scalar(0))
return 1;
Scalar* x_cpy = get_compact_vector(x,*n,*incx);
Scalar* y_cpy = get_compact_vector(y,*n,*incy);
const Scalar* x_cpy = get_compact_vector(x,*n,*incx);
const Scalar* y_cpy = get_compact_vector(y,*n,*incy);
int code = UPLO(*uplo);
if(code>=2 || func[code]==0)

View File

@ -9,7 +9,8 @@
#include <iostream>
#include "common.h"
int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc)
int EIGEN_BLAS_FUNC(gemm)(const char *opa, const char *opb, const int *m, const int *n, const int *k, const RealScalar *palpha,
const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
{
// std::cerr << "in gemm " << *opa << " " << *opb << " " << *m << " " << *n << " " << *k << " " << *lda << " " << *ldb << " " << *ldc << " " << *palpha << " " << *pbeta << "\n";
typedef void (*functype)(DenseIndex, DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, Scalar, internal::level3_blocking<Scalar,Scalar>&, Eigen::internal::GemmParallelInfo<DenseIndex>*);
@ -37,11 +38,11 @@ int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScal
0
};
Scalar* a = reinterpret_cast<Scalar*>(pa);
Scalar* b = reinterpret_cast<Scalar*>(pb);
const Scalar* a = reinterpret_cast<const Scalar*>(pa);
const Scalar* b = reinterpret_cast<const Scalar*>(pb);
Scalar* c = reinterpret_cast<Scalar*>(pc);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
int info = 0;
if(OP(*opa)==INVALID) info = 1;
@ -74,7 +75,8 @@ int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScal
return 0;
}
int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb)
int EIGEN_BLAS_FUNC(trsm)(const char *side, const char *uplo, const char *opa, const char *diag, const int *m, const int *n,
const RealScalar *palpha, const RealScalar *pa, const int *lda, RealScalar *pb, const int *ldb)
{
// std::cerr << "in trsm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << "," << *n << " " << *palpha << " " << *lda << " " << *ldb<< "\n";
typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, internal::level3_blocking<Scalar,Scalar>&);
@ -137,9 +139,9 @@ int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m,
0
};
Scalar* a = reinterpret_cast<Scalar*>(pa);
const Scalar* a = reinterpret_cast<const Scalar*>(pa);
Scalar* b = reinterpret_cast<Scalar*>(pb);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
int info = 0;
if(SIDE(*side)==INVALID) info = 1;
@ -178,7 +180,8 @@ int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m,
// b = alpha*op(a)*b for side = 'L'or'l'
// b = alpha*b*op(a) for side = 'R'or'r'
int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb)
int EIGEN_BLAS_FUNC(trmm)(const char *side, const char *uplo, const char *opa, const char *diag, const int *m, const int *n,
const RealScalar *palpha, const RealScalar *pa, const int *lda, RealScalar *pb, const int *ldb)
{
// std::cerr << "in trmm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << " " << *n << " " << *lda << " " << *ldb << " " << *palpha << "\n";
typedef void (*functype)(DenseIndex, DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, const Scalar&, internal::level3_blocking<Scalar,Scalar>&);
@ -241,9 +244,9 @@ int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m,
0
};
Scalar* a = reinterpret_cast<Scalar*>(pa);
const Scalar* a = reinterpret_cast<const Scalar*>(pa);
Scalar* b = reinterpret_cast<Scalar*>(pb);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
int info = 0;
if(SIDE(*side)==INVALID) info = 1;
@ -281,14 +284,15 @@ int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m,
// c = alpha*a*b + beta*c for side = 'L'or'l'
// c = alpha*b*a + beta*c for side = 'R'or'r
int EIGEN_BLAS_FUNC(symm)(char *side, char *uplo, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc)
int EIGEN_BLAS_FUNC(symm)(const char *side, const char *uplo, const int *m, const int *n, const RealScalar *palpha,
const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
{
// std::cerr << "in symm " << *side << " " << *uplo << " " << *m << "x" << *n << " lda:" << *lda << " ldb:" << *ldb << " ldc:" << *ldc << " alpha:" << *palpha << " beta:" << *pbeta << "\n";
Scalar* a = reinterpret_cast<Scalar*>(pa);
Scalar* b = reinterpret_cast<Scalar*>(pb);
const Scalar* a = reinterpret_cast<const Scalar*>(pa);
const Scalar* b = reinterpret_cast<const Scalar*>(pb);
Scalar* c = reinterpret_cast<Scalar*>(pc);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
int info = 0;
if(SIDE(*side)==INVALID) info = 1;
@ -350,7 +354,8 @@ int EIGEN_BLAS_FUNC(symm)(char *side, char *uplo, int *m, int *n, RealScalar *pa
// c = alpha*a*a' + beta*c for op = 'N'or'n'
// c = alpha*a'*a + beta*c for op = 'T'or't','C'or'c'
int EIGEN_BLAS_FUNC(syrk)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pbeta, RealScalar *pc, int *ldc)
int EIGEN_BLAS_FUNC(syrk)(const char *uplo, const char *op, const int *n, const int *k,
const RealScalar *palpha, const RealScalar *pa, const int *lda, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
{
// std::cerr << "in syrk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " " << *pbeta << " " << *ldc << "\n";
#if !ISCOMPLEX
@ -373,10 +378,10 @@ int EIGEN_BLAS_FUNC(syrk)(char *uplo, char *op, int *n, int *k, RealScalar *palp
};
#endif
Scalar* a = reinterpret_cast<Scalar*>(pa);
const Scalar* a = reinterpret_cast<const Scalar*>(pa);
Scalar* c = reinterpret_cast<Scalar*>(pc);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
int info = 0;
if(UPLO(*uplo)==INVALID) info = 1;
@ -429,13 +434,14 @@ int EIGEN_BLAS_FUNC(syrk)(char *uplo, char *op, int *n, int *k, RealScalar *palp
// c = alpha*a*b' + alpha*b*a' + beta*c for op = 'N'or'n'
// c = alpha*a'*b + alpha*b'*a + beta*c for op = 'T'or't'
int EIGEN_BLAS_FUNC(syr2k)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc)
int EIGEN_BLAS_FUNC(syr2k)(const char *uplo, const char *op, const int *n, const int *k, const RealScalar *palpha,
const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
{
Scalar* a = reinterpret_cast<Scalar*>(pa);
Scalar* b = reinterpret_cast<Scalar*>(pb);
const Scalar* a = reinterpret_cast<const Scalar*>(pa);
const Scalar* b = reinterpret_cast<const Scalar*>(pb);
Scalar* c = reinterpret_cast<Scalar*>(pc);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
// std::cerr << "in syr2k " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " << *ldb << " " << beta << " " << *ldc << "\n";
@ -496,13 +502,14 @@ int EIGEN_BLAS_FUNC(syr2k)(char *uplo, char *op, int *n, int *k, RealScalar *pal
// c = alpha*a*b + beta*c for side = 'L'or'l'
// c = alpha*b*a + beta*c for side = 'R'or'r
int EIGEN_BLAS_FUNC(hemm)(char *side, char *uplo, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc)
int EIGEN_BLAS_FUNC(hemm)(const char *side, const char *uplo, const int *m, const int *n, const RealScalar *palpha,
const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
{
Scalar* a = reinterpret_cast<Scalar*>(pa);
Scalar* b = reinterpret_cast<Scalar*>(pb);
const Scalar* a = reinterpret_cast<const Scalar*>(pa);
const Scalar* b = reinterpret_cast<const Scalar*>(pb);
Scalar* c = reinterpret_cast<Scalar*>(pc);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
// std::cerr << "in hemm " << *side << " " << *uplo << " " << *m << " " << *n << " " << alpha << " " << *lda << " " << beta << " " << *ldc << "\n";
@ -554,7 +561,8 @@ int EIGEN_BLAS_FUNC(hemm)(char *side, char *uplo, int *m, int *n, RealScalar *pa
// c = alpha*a*conj(a') + beta*c for op = 'N'or'n'
// c = alpha*conj(a')*a + beta*c for op = 'C'or'c'
int EIGEN_BLAS_FUNC(herk)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pbeta, RealScalar *pc, int *ldc)
int EIGEN_BLAS_FUNC(herk)(const char *uplo, const char *op, const int *n, const int *k,
const RealScalar *palpha, const RealScalar *pa, const int *lda, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
{
// std::cerr << "in herk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " " << *pbeta << " " << *ldc << "\n";
@ -574,7 +582,7 @@ int EIGEN_BLAS_FUNC(herk)(char *uplo, char *op, int *n, int *k, RealScalar *palp
0
};
Scalar* a = reinterpret_cast<Scalar*>(pa);
const Scalar* a = reinterpret_cast<const Scalar*>(pa);
Scalar* c = reinterpret_cast<Scalar*>(pc);
RealScalar alpha = *palpha;
RealScalar beta = *pbeta;
@ -620,12 +628,13 @@ int EIGEN_BLAS_FUNC(herk)(char *uplo, char *op, int *n, int *k, RealScalar *palp
// c = alpha*a*conj(b') + conj(alpha)*b*conj(a') + beta*c, for op = 'N'or'n'
// c = alpha*conj(a')*b + conj(alpha)*conj(b')*a + beta*c, for op = 'C'or'c'
int EIGEN_BLAS_FUNC(her2k)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc)
int EIGEN_BLAS_FUNC(her2k)(const char *uplo, const char *op, const int *n, const int *k,
const RealScalar *palpha, const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
{
Scalar* a = reinterpret_cast<Scalar*>(pa);
Scalar* b = reinterpret_cast<Scalar*>(pb);
const Scalar* a = reinterpret_cast<const Scalar*>(pa);
const Scalar* b = reinterpret_cast<const Scalar*>(pb);
Scalar* c = reinterpret_cast<Scalar*>(pc);
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
RealScalar beta = *pbeta;
// std::cerr << "in her2k " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " << *ldb << " " << beta << " " << *ldc << "\n";

View File

@ -37,10 +37,10 @@ Here is another example reshaping a 2x6 matrix to a 6x2 one:
\section TutorialSlicing Slicing
Slicing consists in taking a set of rows, or columns, or elements, uniformly spaced within a matrix.
Slicing consists in taking a set of rows, columns, or elements, uniformly spaced within a matrix.
Again, the class Map allows to easily mimic this feature.
For instance, one can take skip every P elements in a vector:
For instance, one can skip every P elements in a vector:
<table class="example">
<tr><th>Example:</th><th>Output:</th></tr>
<tr><td>

View File

@ -55,7 +55,7 @@ Operations on other scalar types or mixing reals and complexes will continue to
In addition you can choose which parts will be substituted by defining one or multiple of the following macros:
<table class="manual">
<tr><td>\c EIGEN_USE_BLAS </td><td>Enables the use of external BLAS level 2 and 3 routines (currently works with Intel MKL only)</td></tr>
<tr><td>\c EIGEN_USE_BLAS </td><td>Enables the use of external BLAS level 2 and 3 routines (compatible with any F77 BLAS interface, not only Intel MKL)</td></tr>
<tr class="alt"><td>\c EIGEN_USE_LAPACKE </td><td>Enables the use of external Lapack routines via the <a href="http://www.netlib.org/lapack/lapacke.html">Intel Lapacke</a> C interface to Lapack (currently works with Intel MKL only)</td></tr>
<tr><td>\c EIGEN_USE_LAPACKE_STRICT </td><td>Same as \c EIGEN_USE_LAPACKE but algorithm of lower robustness are disabled. This currently concerns only JacobiSVD which otherwise would be replaced by \c gesvd that is less robust than Jacobi rotations.</td></tr>
<tr class="alt"><td>\c EIGEN_USE_MKL_VML </td><td>Enables the use of Intel VML (vector operations)</td></tr>

View File

@ -11,6 +11,7 @@
#define EIGEN_LAPACK_COMMON_H
#include "../blas/common.h"
#include "../Eigen/src/misc/lapack.h"
#define EIGEN_LAPACK_FUNC(FUNC,ARGLIST) \
extern "C" { int EIGEN_BLAS_FUNC(FUNC) ARGLIST; } \

View File

@ -331,11 +331,13 @@ template<typename ArrayType> void array_real(const ArrayType& m)
VERIFY_IS_APPROX(numext::zeta(Scalar(3), Scalar(-2.5)), RealScalar(0.054102025820864097));
VERIFY_IS_EQUAL(numext::zeta(Scalar(1), Scalar(1.2345)), // The second scalar does not matter
std::numeric_limits<RealScalar>::infinity());
VERIFY((numext::isnan)(numext::zeta(Scalar(0.9), Scalar(1.2345)))); // The second scalar does not matter
// Check the polygamma against scipy.special.polygamma examples
VERIFY_IS_APPROX(numext::polygamma(Scalar(1), Scalar(2)), RealScalar(0.644934066848));
VERIFY_IS_APPROX(numext::polygamma(Scalar(1), Scalar(3)), RealScalar(0.394934066848));
VERIFY_IS_APPROX(numext::polygamma(Scalar(1), Scalar(25.5)), RealScalar(0.0399946696496));
VERIFY((numext::isnan)(numext::polygamma(Scalar(1.5), Scalar(1.2345)))); // The second scalar does not matter
// Check the polygamma function over a larger range of values
VERIFY_IS_APPROX(numext::polygamma(Scalar(17), Scalar(4.7)), RealScalar(293.334565435));

View File

@ -14,6 +14,9 @@
using std::sqrt;
// tolerance for chekcing number of iterations
#define LM_EVAL_COUNT_TOL 4/3
int fcn_chkder(const VectorXd &x, VectorXd &fvec, MatrixXd &fjac, int iflag)
{
/* subroutine fcn for chkder example. */
@ -1023,7 +1026,8 @@ void testNistLanczos1(void)
VERIFY_IS_EQUAL(lm.njev, 72);
// check norm^2
std::cout.precision(30);
VERIFY_IS_APPROX(lm.fvec.squaredNorm(), 1.4290986055242372e-25); // should be 1.4307867721E-25, but nist results are on 128-bit floats
std::cout << lm.fvec.squaredNorm() << "\n";
VERIFY(lm.fvec.squaredNorm() <= 1.4307867721E-25);
// check x
VERIFY_IS_APPROX(x[0], 9.5100000027E-02);
VERIFY_IS_APPROX(x[1], 1.0000000001E+00);
@ -1044,7 +1048,7 @@ void testNistLanczos1(void)
VERIFY_IS_EQUAL(lm.nfev, 9);
VERIFY_IS_EQUAL(lm.njev, 8);
// check norm^2
VERIFY_IS_APPROX(lm.fvec.squaredNorm(), 1.430571737783119393e-25); // should be 1.4307867721E-25, but nist results are on 128-bit floats
VERIFY(lm.fvec.squaredNorm() <= 1.4307867721E-25);
// check x
VERIFY_IS_APPROX(x[0], 9.5100000027E-02);
VERIFY_IS_APPROX(x[1], 1.0000000001E+00);
@ -1354,8 +1358,12 @@ void testNistMGH17(void)
// check return value
VERIFY_IS_EQUAL(info, 2);
VERIFY(lm.nfev < 650); // 602
VERIFY(lm.njev < 600); // 545
++g_test_level;
VERIFY_IS_EQUAL(lm.nfev, 602); // 602
VERIFY_IS_EQUAL(lm.njev, 545); // 545
--g_test_level;
VERIFY(lm.nfev < 602 * LM_EVAL_COUNT_TOL);
VERIFY(lm.njev < 545 * LM_EVAL_COUNT_TOL);
/*
* Second try

View File

@ -23,6 +23,9 @@
using std::sqrt;
// tolerance for chekcing number of iterations
#define LM_EVAL_COUNT_TOL 4/3
struct lmder_functor : DenseFunctor<double>
{
lmder_functor(void): DenseFunctor<double>(3,15) {}
@ -631,7 +634,7 @@ void testNistLanczos1(void)
VERIFY_IS_EQUAL(lm.nfev(), 79);
VERIFY_IS_EQUAL(lm.njev(), 72);
// check norm^2
// VERIFY_IS_APPROX(lm.fvec().squaredNorm(), 1.430899764097e-25); // should be 1.4307867721E-25, but nist results are on 128-bit floats
VERIFY(lm.fvec().squaredNorm() <= 1.4307867721E-25);
// check x
VERIFY_IS_APPROX(x[0], 9.5100000027E-02);
VERIFY_IS_APPROX(x[1], 1.0000000001E+00);
@ -652,7 +655,7 @@ void testNistLanczos1(void)
VERIFY_IS_EQUAL(lm.nfev(), 9);
VERIFY_IS_EQUAL(lm.njev(), 8);
// check norm^2
// VERIFY_IS_APPROX(lm.fvec().squaredNorm(), 1.428595533845e-25); // should be 1.4307867721E-25, but nist results are on 128-bit floats
VERIFY(lm.fvec().squaredNorm() <= 1.4307867721E-25);
// check x
VERIFY_IS_APPROX(x[0], 9.5100000027E-02);
VERIFY_IS_APPROX(x[1], 1.0000000001E+00);
@ -789,7 +792,8 @@ void testNistMGH10(void)
MGH10_functor functor;
LevenbergMarquardt<MGH10_functor> lm(functor);
info = lm.minimize(x);
VERIFY_IS_EQUAL(info, LevenbergMarquardtSpace::RelativeErrorTooSmall);
VERIFY_IS_EQUAL(info, LevenbergMarquardtSpace::RelativeReductionTooSmall);
// was: VERIFY_IS_EQUAL(info, 1);
// check norm^2
VERIFY_IS_APPROX(lm.fvec().squaredNorm(), 8.7945855171E+01);
@ -799,9 +803,13 @@ void testNistMGH10(void)
VERIFY_IS_APPROX(x[2], 3.4522363462E+02);
// check return value
//VERIFY_IS_EQUAL(info, 1);
++g_test_level;
VERIFY_IS_EQUAL(lm.nfev(), 284 );
VERIFY_IS_EQUAL(lm.njev(), 249 );
--g_test_level;
VERIFY(lm.nfev() < 284 * LM_EVAL_COUNT_TOL);
VERIFY(lm.njev() < 249 * LM_EVAL_COUNT_TOL);
/*
* Second try
@ -809,7 +817,10 @@ void testNistMGH10(void)
x<< 0.02, 4000., 250.;
// do the computation
info = lm.minimize(x);
++g_test_level;
VERIFY_IS_EQUAL(info, LevenbergMarquardtSpace::RelativeReductionTooSmall);
// was: VERIFY_IS_EQUAL(info, 1);
--g_test_level;
// check norm^2
VERIFY_IS_APPROX(lm.fvec().squaredNorm(), 8.7945855171E+01);
@ -819,9 +830,12 @@ void testNistMGH10(void)
VERIFY_IS_APPROX(x[2], 3.4522363462E+02);
// check return value
//VERIFY_IS_EQUAL(info, 1);
++g_test_level;
VERIFY_IS_EQUAL(lm.nfev(), 126);
VERIFY_IS_EQUAL(lm.njev(), 116);
--g_test_level;
VERIFY(lm.nfev() < 126 * LM_EVAL_COUNT_TOL);
VERIFY(lm.njev() < 116 * LM_EVAL_COUNT_TOL);
}
@ -896,8 +910,12 @@ void testNistBoxBOD(void)
// check return value
VERIFY_IS_EQUAL(info, 1);
++g_test_level;
VERIFY_IS_EQUAL(lm.nfev(), 16 );
VERIFY_IS_EQUAL(lm.njev(), 15 );
--g_test_level;
VERIFY(lm.nfev() < 16 * LM_EVAL_COUNT_TOL);
VERIFY(lm.njev() < 15 * LM_EVAL_COUNT_TOL);
// check norm^2
VERIFY_IS_APPROX(lm.fvec().squaredNorm(), 1.1680088766E+03);
// check x