mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-21 20:09:06 +08:00
Pulled latest update from trunk
This commit is contained in:
commit
e939b087fe
16
Eigen/Core
16
Eigen/Core
@ -450,14 +450,14 @@ using std::ptrdiff_t;
|
||||
#include "src/Core/ArrayWrapper.h"
|
||||
|
||||
#ifdef EIGEN_USE_BLAS
|
||||
#include "src/Core/products/GeneralMatrixMatrix_MKL.h"
|
||||
#include "src/Core/products/GeneralMatrixVector_MKL.h"
|
||||
#include "src/Core/products/GeneralMatrixMatrixTriangular_MKL.h"
|
||||
#include "src/Core/products/SelfadjointMatrixMatrix_MKL.h"
|
||||
#include "src/Core/products/SelfadjointMatrixVector_MKL.h"
|
||||
#include "src/Core/products/TriangularMatrixMatrix_MKL.h"
|
||||
#include "src/Core/products/TriangularMatrixVector_MKL.h"
|
||||
#include "src/Core/products/TriangularSolverMatrix_MKL.h"
|
||||
#include "src/Core/products/GeneralMatrixMatrix_BLAS.h"
|
||||
#include "src/Core/products/GeneralMatrixVector_BLAS.h"
|
||||
#include "src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h"
|
||||
#include "src/Core/products/SelfadjointMatrixMatrix_BLAS.h"
|
||||
#include "src/Core/products/SelfadjointMatrixVector_BLAS.h"
|
||||
#include "src/Core/products/TriangularMatrixMatrix_BLAS.h"
|
||||
#include "src/Core/products/TriangularMatrixVector_BLAS.h"
|
||||
#include "src/Core/products/TriangularSolverMatrix_BLAS.h"
|
||||
#endif // EIGEN_USE_BLAS
|
||||
|
||||
#ifdef EIGEN_USE_MKL_VML
|
||||
|
@ -788,8 +788,8 @@ template<typename Dst, typename Src> void check_for_aliasing(const Dst &dst, con
|
||||
template< typename DstXprType, typename SrcXprType, typename Functor, typename Scalar>
|
||||
struct Assignment<DstXprType, SrcXprType, Functor, Dense2Dense, Scalar>
|
||||
{
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
static void run(DstXprType &dst, const SrcXprType &src, const Functor &func)
|
||||
EIGEN_DEVICE_FUNC
|
||||
static EIGEN_STRONG_INLINE void run(DstXprType &dst, const SrcXprType &src, const Functor &func)
|
||||
{
|
||||
eigen_assert(dst.rows() == src.rows() && dst.cols() == src.cols());
|
||||
|
||||
@ -806,8 +806,8 @@ struct Assignment<DstXprType, SrcXprType, Functor, Dense2Dense, Scalar>
|
||||
template< typename DstXprType, typename SrcXprType, typename Functor, typename Scalar>
|
||||
struct Assignment<DstXprType, SrcXprType, Functor, EigenBase2EigenBase, Scalar>
|
||||
{
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<typename DstXprType::Scalar> &/*func*/)
|
||||
EIGEN_DEVICE_FUNC
|
||||
static EIGEN_STRONG_INLINE void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<typename DstXprType::Scalar> &/*func*/)
|
||||
{
|
||||
eigen_assert(dst.rows() == src.rows() && dst.cols() == src.cols());
|
||||
src.evalTo(dst);
|
||||
|
@ -79,8 +79,8 @@ namespace cephes {
|
||||
*/
|
||||
template <typename Scalar, int N>
|
||||
struct polevl {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
static Scalar run(const Scalar x, const Scalar coef[]) {
|
||||
EIGEN_DEVICE_FUNC
|
||||
static EIGEN_STRONG_INLINE Scalar run(const Scalar x, const Scalar coef[]) {
|
||||
EIGEN_STATIC_ASSERT((N > 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||
|
||||
return polevl<Scalar, N - 1>::run(x, coef) * x + coef[N];
|
||||
@ -89,8 +89,8 @@ struct polevl {
|
||||
|
||||
template <typename Scalar>
|
||||
struct polevl<Scalar, 0> {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
static Scalar run(const Scalar, const Scalar coef[]) {
|
||||
EIGEN_DEVICE_FUNC
|
||||
static EIGEN_STRONG_INLINE Scalar run(const Scalar, const Scalar coef[]) {
|
||||
return coef[0];
|
||||
}
|
||||
};
|
||||
@ -144,7 +144,7 @@ struct digamma_retval {
|
||||
template <typename Scalar>
|
||||
struct digamma_impl {
|
||||
EIGEN_DEVICE_FUNC
|
||||
static Scalar run(Scalar x) {
|
||||
static EIGEN_STRONG_INLINE Scalar run(Scalar x) {
|
||||
EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
|
||||
THIS_TYPE_IS_NOT_SUPPORTED);
|
||||
return Scalar(0);
|
||||
@ -428,20 +428,20 @@ template <typename Scalar> struct igamma_impl; // predeclare igamma_impl
|
||||
|
||||
template <typename Scalar>
|
||||
struct igamma_helper {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
static Scalar machep() { assert(false && "machep not supported for this type"); return 0.0; }
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
static Scalar big() { assert(false && "big not supported for this type"); return 0.0; }
|
||||
EIGEN_DEVICE_FUNC
|
||||
static EIGEN_STRONG_INLINE Scalar machep() { assert(false && "machep not supported for this type"); return 0.0; }
|
||||
EIGEN_DEVICE_FUNC
|
||||
static EIGEN_STRONG_INLINE Scalar big() { assert(false && "big not supported for this type"); return 0.0; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct igamma_helper<float> {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
static float machep() {
|
||||
EIGEN_DEVICE_FUNC
|
||||
static EIGEN_STRONG_INLINE float machep() {
|
||||
return NumTraits<float>::epsilon() / 2; // 1.0 - machep == 1.0
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
static float big() {
|
||||
EIGEN_DEVICE_FUNC
|
||||
static EIGEN_STRONG_INLINE float big() {
|
||||
// use epsneg (1.0 - epsneg == 1.0)
|
||||
return 1.0 / (NumTraits<float>::epsilon() / 2);
|
||||
}
|
||||
@ -449,12 +449,12 @@ struct igamma_helper<float> {
|
||||
|
||||
template <>
|
||||
struct igamma_helper<double> {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
static double machep() {
|
||||
EIGEN_DEVICE_FUNC
|
||||
static EIGEN_STRONG_INLINE double machep() {
|
||||
return NumTraits<double>::epsilon() / 2; // 1.0 - machep == 1.0
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
static double big() {
|
||||
EIGEN_DEVICE_FUNC
|
||||
static EIGEN_STRONG_INLINE double big() {
|
||||
return 1.0 / NumTraits<double>::epsilon();
|
||||
}
|
||||
};
|
||||
@ -605,7 +605,7 @@ struct igamma_retval {
|
||||
template <typename Scalar>
|
||||
struct igamma_impl {
|
||||
EIGEN_DEVICE_FUNC
|
||||
static Scalar run(Scalar a, Scalar x) {
|
||||
static EIGEN_STRONG_INLINE Scalar run(Scalar a, Scalar x) {
|
||||
EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
|
||||
THIS_TYPE_IS_NOT_SUPPORTED);
|
||||
return Scalar(0);
|
||||
@ -736,7 +736,7 @@ struct zeta_retval {
|
||||
template <typename Scalar>
|
||||
struct zeta_impl {
|
||||
EIGEN_DEVICE_FUNC
|
||||
static Scalar run(Scalar x, Scalar q) {
|
||||
static EIGEN_STRONG_INLINE Scalar run(Scalar x, Scalar q) {
|
||||
EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
|
||||
THIS_TYPE_IS_NOT_SUPPORTED);
|
||||
return Scalar(0);
|
||||
@ -757,8 +757,8 @@ struct zeta_impl_series {
|
||||
|
||||
template <>
|
||||
struct zeta_impl_series<float> {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
static bool run(float& a, float& b, float& s, const float x, const float machep) {
|
||||
EIGEN_DEVICE_FUNC
|
||||
static EIGEN_STRONG_INLINE bool run(float& a, float& b, float& s, const float x, const float machep) {
|
||||
int i = 0;
|
||||
while(i < 9)
|
||||
{
|
||||
@ -777,8 +777,8 @@ struct zeta_impl_series<float> {
|
||||
|
||||
template <>
|
||||
struct zeta_impl_series<double> {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
static bool run(double& a, double& b, double& s, const double x, const double machep) {
|
||||
EIGEN_DEVICE_FUNC
|
||||
static EIGEN_STRONG_INLINE bool run(double& a, double& b, double& s, const double x, const double machep) {
|
||||
int i = 0;
|
||||
while( (i < 9) || (a <= 9.0) )
|
||||
{
|
||||
@ -881,13 +881,14 @@ struct zeta_impl {
|
||||
const Scalar maxnum = NumTraits<Scalar>::infinity();
|
||||
const Scalar zero = 0.0, half = 0.5, one = 1.0;
|
||||
const Scalar machep = igamma_helper<Scalar>::machep();
|
||||
const Scalar nan = NumTraits<Scalar>::quiet_NaN();
|
||||
|
||||
if( x == one )
|
||||
return maxnum;
|
||||
|
||||
if( x < one )
|
||||
{
|
||||
return zero;
|
||||
return nan;
|
||||
}
|
||||
|
||||
if( q <= zero )
|
||||
@ -899,7 +900,7 @@ struct zeta_impl {
|
||||
p = x;
|
||||
r = numext::floor(p);
|
||||
if (p != r)
|
||||
return zero;
|
||||
return nan;
|
||||
}
|
||||
|
||||
/* Permit negative q but continue sum until n+q > +9 .
|
||||
@ -954,7 +955,7 @@ struct polygamma_retval {
|
||||
template <typename Scalar>
|
||||
struct polygamma_impl {
|
||||
EIGEN_DEVICE_FUNC
|
||||
static Scalar run(Scalar n, Scalar x) {
|
||||
static EIGEN_STRONG_INLINE Scalar run(Scalar n, Scalar x) {
|
||||
EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
|
||||
THIS_TYPE_IS_NOT_SUPPORTED);
|
||||
return Scalar(0);
|
||||
@ -969,9 +970,14 @@ struct polygamma_impl {
|
||||
static Scalar run(Scalar n, Scalar x) {
|
||||
Scalar zero = 0.0, one = 1.0;
|
||||
Scalar nplus = n + one;
|
||||
const Scalar nan = NumTraits<Scalar>::quiet_NaN();
|
||||
|
||||
// Check that n is an integer
|
||||
if (numext::floor(n) != n) {
|
||||
return nan;
|
||||
}
|
||||
// Just return the digamma function for n = 0
|
||||
if (n == zero) {
|
||||
else if (n == zero) {
|
||||
return digamma_impl<Scalar>::run(x);
|
||||
}
|
||||
// Use the same implementation as scipy
|
||||
|
@ -25,13 +25,13 @@
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
********************************************************************************
|
||||
* Content : Eigen bindings to Intel(R) MKL
|
||||
* Content : Eigen bindings to BLAS F77
|
||||
* Level 3 BLAS SYRK/HERK implementation.
|
||||
********************************************************************************
|
||||
*/
|
||||
|
||||
#ifndef EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_MKL_H
|
||||
#define EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_MKL_H
|
||||
#ifndef EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_BLAS_H
|
||||
#define EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_BLAS_H
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
@ -44,34 +44,35 @@ struct general_matrix_matrix_rankupdate :
|
||||
|
||||
|
||||
// try to go to BLAS specialization
|
||||
#define EIGEN_MKL_RANKUPDATE_SPECIALIZE(Scalar) \
|
||||
#define EIGEN_BLAS_RANKUPDATE_SPECIALIZE(Scalar) \
|
||||
template <typename Index, int LhsStorageOrder, bool ConjugateLhs, \
|
||||
int RhsStorageOrder, bool ConjugateRhs, int UpLo> \
|
||||
struct general_matrix_matrix_triangular_product<Index,Scalar,LhsStorageOrder,ConjugateLhs, \
|
||||
Scalar,RhsStorageOrder,ConjugateRhs,ColMajor,UpLo,Specialized> { \
|
||||
static EIGEN_STRONG_INLINE void run(Index size, Index depth,const Scalar* lhs, Index lhsStride, \
|
||||
const Scalar* rhs, Index rhsStride, Scalar* res, Index resStride, Scalar alpha) \
|
||||
const Scalar* rhs, Index rhsStride, Scalar* res, Index resStride, Scalar alpha, level3_blocking<Scalar, Scalar>& blocking) \
|
||||
{ \
|
||||
if (lhs==rhs) { \
|
||||
general_matrix_matrix_rankupdate<Index,Scalar,LhsStorageOrder,ConjugateLhs,ColMajor,UpLo> \
|
||||
::run(size,depth,lhs,lhsStride,rhs,rhsStride,res,resStride,alpha); \
|
||||
::run(size,depth,lhs,lhsStride,rhs,rhsStride,res,resStride,alpha,blocking); \
|
||||
} else { \
|
||||
general_matrix_matrix_triangular_product<Index, \
|
||||
Scalar, LhsStorageOrder, ConjugateLhs, \
|
||||
Scalar, RhsStorageOrder, ConjugateRhs, \
|
||||
ColMajor, UpLo, BuiltIn> \
|
||||
::run(size,depth,lhs,lhsStride,rhs,rhsStride,res,resStride,alpha); \
|
||||
::run(size,depth,lhs,lhsStride,rhs,rhsStride,res,resStride,alpha,blocking); \
|
||||
} \
|
||||
} \
|
||||
};
|
||||
|
||||
EIGEN_MKL_RANKUPDATE_SPECIALIZE(double)
|
||||
//EIGEN_MKL_RANKUPDATE_SPECIALIZE(dcomplex)
|
||||
EIGEN_MKL_RANKUPDATE_SPECIALIZE(float)
|
||||
//EIGEN_MKL_RANKUPDATE_SPECIALIZE(scomplex)
|
||||
EIGEN_BLAS_RANKUPDATE_SPECIALIZE(double)
|
||||
EIGEN_BLAS_RANKUPDATE_SPECIALIZE(float)
|
||||
// TODO handle complex cases
|
||||
// EIGEN_BLAS_RANKUPDATE_SPECIALIZE(dcomplex)
|
||||
// EIGEN_BLAS_RANKUPDATE_SPECIALIZE(scomplex)
|
||||
|
||||
// SYRK for float/double
|
||||
#define EIGEN_MKL_RANKUPDATE_R(EIGTYPE, MKLTYPE, MKLFUNC) \
|
||||
#define EIGEN_BLAS_RANKUPDATE_R(EIGTYPE, BLASTYPE, BLASFUNC) \
|
||||
template <typename Index, int AStorageOrder, bool ConjugateA, int UpLo> \
|
||||
struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,ColMajor,UpLo> { \
|
||||
enum { \
|
||||
@ -80,23 +81,19 @@ struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,C
|
||||
conjA = ((AStorageOrder==ColMajor) && ConjugateA) ? 1 : 0 \
|
||||
}; \
|
||||
static EIGEN_STRONG_INLINE void run(Index size, Index depth,const EIGTYPE* lhs, Index lhsStride, \
|
||||
const EIGTYPE* rhs, Index rhsStride, EIGTYPE* res, Index resStride, EIGTYPE alpha) \
|
||||
const EIGTYPE* /*rhs*/, Index /*rhsStride*/, EIGTYPE* res, Index resStride, EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
|
||||
{ \
|
||||
/* typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs;*/ \
|
||||
\
|
||||
MKL_INT lda=lhsStride, ldc=resStride, n=size, k=depth; \
|
||||
BlasIndex lda=convert_index<BlasIndex>(lhsStride), ldc=convert_index<BlasIndex>(resStride), n=convert_index<BlasIndex>(size), k=convert_index<BlasIndex>(depth); \
|
||||
char uplo=(IsLower) ? 'L' : 'U', trans=(AStorageOrder==RowMajor) ? 'T':'N'; \
|
||||
MKLTYPE alpha_, beta_; \
|
||||
\
|
||||
/* Set alpha_ & beta_ */ \
|
||||
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
|
||||
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(beta_, EIGTYPE(1)); \
|
||||
MKLFUNC(&uplo, &trans, &n, &k, &alpha_, lhs, &lda, &beta_, res, &ldc); \
|
||||
EIGTYPE beta; \
|
||||
BLASFUNC(&uplo, &trans, &n, &k, &numext::real_ref(alpha), lhs, &lda, &numext::real_ref(beta), res, &ldc); \
|
||||
} \
|
||||
};
|
||||
|
||||
// HERK for complex data
|
||||
#define EIGEN_MKL_RANKUPDATE_C(EIGTYPE, MKLTYPE, RTYPE, MKLFUNC) \
|
||||
#define EIGEN_BLAS_RANKUPDATE_C(EIGTYPE, BLASTYPE, RTYPE, BLASFUNC) \
|
||||
template <typename Index, int AStorageOrder, bool ConjugateA, int UpLo> \
|
||||
struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,ColMajor,UpLo> { \
|
||||
enum { \
|
||||
@ -105,18 +102,15 @@ struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,C
|
||||
conjA = (((AStorageOrder==ColMajor) && ConjugateA) || ((AStorageOrder==RowMajor) && !ConjugateA)) ? 1 : 0 \
|
||||
}; \
|
||||
static EIGEN_STRONG_INLINE void run(Index size, Index depth,const EIGTYPE* lhs, Index lhsStride, \
|
||||
const EIGTYPE* rhs, Index rhsStride, EIGTYPE* res, Index resStride, EIGTYPE alpha) \
|
||||
const EIGTYPE* /*rhs*/, Index /*rhsStride*/, EIGTYPE* res, Index resStride, EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
|
||||
{ \
|
||||
typedef Matrix<EIGTYPE, Dynamic, Dynamic, AStorageOrder> MatrixType; \
|
||||
\
|
||||
MKL_INT lda=lhsStride, ldc=resStride, n=size, k=depth; \
|
||||
BlasIndex lda=convert_index<BlasIndex>(lhsStride), ldc=convert_index<BlasIndex>(resStride), n=convert_index<BlasIndex>(size), k=convert_index<BlasIndex>(depth); \
|
||||
char uplo=(IsLower) ? 'L' : 'U', trans=(AStorageOrder==RowMajor) ? 'C':'N'; \
|
||||
RTYPE alpha_, beta_; \
|
||||
const EIGTYPE* a_ptr; \
|
||||
\
|
||||
/* Set alpha_ & beta_ */ \
|
||||
/* assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); */\
|
||||
/* assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(beta_, EIGTYPE(1));*/ \
|
||||
alpha_ = alpha.real(); \
|
||||
beta_ = 1.0; \
|
||||
/* Copy with conjugation in some cases*/ \
|
||||
@ -127,20 +121,21 @@ struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,C
|
||||
lda = a.outerStride(); \
|
||||
a_ptr = a.data(); \
|
||||
} else a_ptr=lhs; \
|
||||
MKLFUNC(&uplo, &trans, &n, &k, &alpha_, (MKLTYPE*)a_ptr, &lda, &beta_, (MKLTYPE*)res, &ldc); \
|
||||
BLASFUNC(&uplo, &trans, &n, &k, &alpha_, (BLASTYPE*)a_ptr, &lda, &beta_, (BLASTYPE*)res, &ldc); \
|
||||
} \
|
||||
};
|
||||
|
||||
|
||||
EIGEN_MKL_RANKUPDATE_R(double, double, dsyrk)
|
||||
EIGEN_MKL_RANKUPDATE_R(float, float, ssyrk)
|
||||
EIGEN_BLAS_RANKUPDATE_R(double, double, dsyrk_)
|
||||
EIGEN_BLAS_RANKUPDATE_R(float, float, ssyrk_)
|
||||
|
||||
//EIGEN_MKL_RANKUPDATE_C(dcomplex, MKL_Complex16, double, zherk)
|
||||
//EIGEN_MKL_RANKUPDATE_C(scomplex, MKL_Complex8, double, cherk)
|
||||
// TODO handle complex cases
|
||||
// EIGEN_BLAS_RANKUPDATE_C(dcomplex, double, double, zherk_)
|
||||
// EIGEN_BLAS_RANKUPDATE_C(scomplex, float, float, cherk_)
|
||||
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_MKL_H
|
||||
#endif // EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_BLAS_H
|
@ -25,13 +25,13 @@
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
********************************************************************************
|
||||
* Content : Eigen bindings to Intel(R) MKL
|
||||
* Content : Eigen bindings to BLAS F77
|
||||
* General matrix-matrix product functionality based on ?GEMM.
|
||||
********************************************************************************
|
||||
*/
|
||||
|
||||
#ifndef EIGEN_GENERAL_MATRIX_MATRIX_MKL_H
|
||||
#define EIGEN_GENERAL_MATRIX_MATRIX_MKL_H
|
||||
#ifndef EIGEN_GENERAL_MATRIX_MATRIX_BLAS_H
|
||||
#define EIGEN_GENERAL_MATRIX_MATRIX_BLAS_H
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
@ -46,7 +46,7 @@ namespace internal {
|
||||
|
||||
// gemm specialization
|
||||
|
||||
#define GEMM_SPECIALIZATION(EIGTYPE, EIGPREFIX, MKLTYPE, MKLPREFIX) \
|
||||
#define GEMM_SPECIALIZATION(EIGTYPE, EIGPREFIX, BLASTYPE, BLASPREFIX) \
|
||||
template< \
|
||||
typename Index, \
|
||||
int LhsStorageOrder, bool ConjugateLhs, \
|
||||
@ -66,55 +66,50 @@ static void run(Index rows, Index cols, Index depth, \
|
||||
using std::conj; \
|
||||
\
|
||||
char transa, transb; \
|
||||
MKL_INT m, n, k, lda, ldb, ldc; \
|
||||
BlasIndex m, n, k, lda, ldb, ldc; \
|
||||
const EIGTYPE *a, *b; \
|
||||
MKLTYPE alpha_, beta_; \
|
||||
EIGTYPE beta(1); \
|
||||
MatrixX##EIGPREFIX a_tmp, b_tmp; \
|
||||
EIGTYPE myone(1);\
|
||||
\
|
||||
/* Set transpose options */ \
|
||||
transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \
|
||||
transb = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \
|
||||
\
|
||||
/* Set m, n, k */ \
|
||||
m = (MKL_INT)rows; \
|
||||
n = (MKL_INT)cols; \
|
||||
k = (MKL_INT)depth; \
|
||||
\
|
||||
/* Set alpha_ & beta_ */ \
|
||||
assign_scalar_eig2mkl(alpha_, alpha); \
|
||||
assign_scalar_eig2mkl(beta_, myone); \
|
||||
m = convert_index<BlasIndex>(rows); \
|
||||
n = convert_index<BlasIndex>(cols); \
|
||||
k = convert_index<BlasIndex>(depth); \
|
||||
\
|
||||
/* Set lda, ldb, ldc */ \
|
||||
lda = (MKL_INT)lhsStride; \
|
||||
ldb = (MKL_INT)rhsStride; \
|
||||
ldc = (MKL_INT)resStride; \
|
||||
lda = convert_index<BlasIndex>(lhsStride); \
|
||||
ldb = convert_index<BlasIndex>(rhsStride); \
|
||||
ldc = convert_index<BlasIndex>(resStride); \
|
||||
\
|
||||
/* Set a, b, c */ \
|
||||
if ((LhsStorageOrder==ColMajor) && (ConjugateLhs)) { \
|
||||
Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > lhs(_lhs,m,k,OuterStride<>(lhsStride)); \
|
||||
a_tmp = lhs.conjugate(); \
|
||||
a = a_tmp.data(); \
|
||||
lda = a_tmp.outerStride(); \
|
||||
lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
|
||||
} else a = _lhs; \
|
||||
\
|
||||
if ((RhsStorageOrder==ColMajor) && (ConjugateRhs)) { \
|
||||
Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > rhs(_rhs,k,n,OuterStride<>(rhsStride)); \
|
||||
b_tmp = rhs.conjugate(); \
|
||||
b = b_tmp.data(); \
|
||||
ldb = b_tmp.outerStride(); \
|
||||
ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
|
||||
} else b = _rhs; \
|
||||
\
|
||||
MKLPREFIX##gemm(&transa, &transb, &m, &n, &k, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &beta_, (MKLTYPE*)res, &ldc); \
|
||||
BLASPREFIX##gemm_(&transa, &transb, &m, &n, &k, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
|
||||
}};
|
||||
|
||||
GEMM_SPECIALIZATION(double, d, double, d)
|
||||
GEMM_SPECIALIZATION(float, f, float, s)
|
||||
GEMM_SPECIALIZATION(dcomplex, cd, MKL_Complex16, z)
|
||||
GEMM_SPECIALIZATION(scomplex, cf, MKL_Complex8, c)
|
||||
GEMM_SPECIALIZATION(dcomplex, cd, double, z)
|
||||
GEMM_SPECIALIZATION(scomplex, cf, float, c)
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_GENERAL_MATRIX_MATRIX_MKL_H
|
||||
#endif // EIGEN_GENERAL_MATRIX_MATRIX_BLAS_H
|
@ -25,13 +25,13 @@
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
********************************************************************************
|
||||
* Content : Eigen bindings to Intel(R) MKL
|
||||
* Content : Eigen bindings to BLAS F77
|
||||
* General matrix-vector product functionality based on ?GEMV.
|
||||
********************************************************************************
|
||||
*/
|
||||
|
||||
#ifndef EIGEN_GENERAL_MATRIX_VECTOR_MKL_H
|
||||
#define EIGEN_GENERAL_MATRIX_VECTOR_MKL_H
|
||||
#ifndef EIGEN_GENERAL_MATRIX_VECTOR_BLAS_H
|
||||
#define EIGEN_GENERAL_MATRIX_VECTOR_BLAS_H
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
@ -49,7 +49,7 @@ namespace internal {
|
||||
template<typename Index, typename LhsScalar, int StorageOrder, bool ConjugateLhs, typename RhsScalar, bool ConjugateRhs>
|
||||
struct general_matrix_vector_product_gemv;
|
||||
|
||||
#define EIGEN_MKL_GEMV_SPECIALIZE(Scalar) \
|
||||
#define EIGEN_BLAS_GEMV_SPECIALIZE(Scalar) \
|
||||
template<typename Index, bool ConjugateLhs, bool ConjugateRhs> \
|
||||
struct general_matrix_vector_product<Index,Scalar,const_blas_data_mapper<Scalar,Index,ColMajor>,ColMajor,ConjugateLhs,Scalar,const_blas_data_mapper<Scalar,Index,RowMajor>,ConjugateRhs,Specialized> { \
|
||||
static void run( \
|
||||
@ -80,12 +80,12 @@ static void run( \
|
||||
} \
|
||||
}; \
|
||||
|
||||
EIGEN_MKL_GEMV_SPECIALIZE(double)
|
||||
EIGEN_MKL_GEMV_SPECIALIZE(float)
|
||||
EIGEN_MKL_GEMV_SPECIALIZE(dcomplex)
|
||||
EIGEN_MKL_GEMV_SPECIALIZE(scomplex)
|
||||
EIGEN_BLAS_GEMV_SPECIALIZE(double)
|
||||
EIGEN_BLAS_GEMV_SPECIALIZE(float)
|
||||
EIGEN_BLAS_GEMV_SPECIALIZE(dcomplex)
|
||||
EIGEN_BLAS_GEMV_SPECIALIZE(scomplex)
|
||||
|
||||
#define EIGEN_MKL_GEMV_SPECIALIZATION(EIGTYPE,MKLTYPE,MKLPREFIX) \
|
||||
#define EIGEN_BLAS_GEMV_SPECIALIZATION(EIGTYPE,BLASTYPE,BLASPREFIX) \
|
||||
template<typename Index, int LhsStorageOrder, bool ConjugateLhs, bool ConjugateRhs> \
|
||||
struct general_matrix_vector_product_gemv<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,ConjugateRhs> \
|
||||
{ \
|
||||
@ -97,16 +97,15 @@ static void run( \
|
||||
const EIGTYPE* rhs, Index rhsIncr, \
|
||||
EIGTYPE* res, Index resIncr, EIGTYPE alpha) \
|
||||
{ \
|
||||
MKL_INT m=rows, n=cols, lda=lhsStride, incx=rhsIncr, incy=resIncr; \
|
||||
MKLTYPE alpha_, beta_; \
|
||||
const EIGTYPE *x_ptr, myone(1); \
|
||||
BlasIndex m=convert_index<BlasIndex>(rows), n=convert_index<BlasIndex>(cols), \
|
||||
lda=convert_index<BlasIndex>(lhsStride), incx=convert_index<BlasIndex>(rhsIncr), incy=convert_index<BlasIndex>(resIncr); \
|
||||
const EIGTYPE beta(1); \
|
||||
const EIGTYPE *x_ptr; \
|
||||
char trans=(LhsStorageOrder==ColMajor) ? 'N' : (ConjugateLhs) ? 'C' : 'T'; \
|
||||
if (LhsStorageOrder==RowMajor) { \
|
||||
m=cols; \
|
||||
n=rows; \
|
||||
m = convert_index<BlasIndex>(cols); \
|
||||
n = convert_index<BlasIndex>(rows); \
|
||||
}\
|
||||
assign_scalar_eig2mkl(alpha_, alpha); \
|
||||
assign_scalar_eig2mkl(beta_, myone); \
|
||||
GEMVVector x_tmp; \
|
||||
if (ConjugateRhs) { \
|
||||
Map<const GEMVVector, 0, InnerStride<> > map_x(rhs,cols,1,InnerStride<>(incx)); \
|
||||
@ -114,17 +113,17 @@ static void run( \
|
||||
x_ptr=x_tmp.data(); \
|
||||
incx=1; \
|
||||
} else x_ptr=rhs; \
|
||||
MKLPREFIX##gemv(&trans, &m, &n, &alpha_, (const MKLTYPE*)lhs, &lda, (const MKLTYPE*)x_ptr, &incx, &beta_, (MKLTYPE*)res, &incy); \
|
||||
BLASPREFIX##gemv_(&trans, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)lhs, &lda, (const BLASTYPE*)x_ptr, &incx, &numext::real_ref(beta), (BLASTYPE*)res, &incy); \
|
||||
}\
|
||||
};
|
||||
|
||||
EIGEN_MKL_GEMV_SPECIALIZATION(double, double, d)
|
||||
EIGEN_MKL_GEMV_SPECIALIZATION(float, float, s)
|
||||
EIGEN_MKL_GEMV_SPECIALIZATION(dcomplex, MKL_Complex16, z)
|
||||
EIGEN_MKL_GEMV_SPECIALIZATION(scomplex, MKL_Complex8, c)
|
||||
EIGEN_BLAS_GEMV_SPECIALIZATION(double, double, d)
|
||||
EIGEN_BLAS_GEMV_SPECIALIZATION(float, float, s)
|
||||
EIGEN_BLAS_GEMV_SPECIALIZATION(dcomplex, double, z)
|
||||
EIGEN_BLAS_GEMV_SPECIALIZATION(scomplex, float, c)
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_GENERAL_MATRIX_VECTOR_MKL_H
|
||||
#endif // EIGEN_GENERAL_MATRIX_VECTOR_BLAS_H
|
@ -25,13 +25,13 @@
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
//
|
||||
********************************************************************************
|
||||
* Content : Eigen bindings to Intel(R) MKL
|
||||
* Content : Eigen bindings to BLAS F77
|
||||
* Self adjoint matrix * matrix product functionality based on ?SYMM/?HEMM.
|
||||
********************************************************************************
|
||||
*/
|
||||
|
||||
#ifndef EIGEN_SELFADJOINT_MATRIX_MATRIX_MKL_H
|
||||
#define EIGEN_SELFADJOINT_MATRIX_MATRIX_MKL_H
|
||||
#ifndef EIGEN_SELFADJOINT_MATRIX_MATRIX_BLAS_H
|
||||
#define EIGEN_SELFADJOINT_MATRIX_MATRIX_BLAS_H
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
@ -40,7 +40,7 @@ namespace internal {
|
||||
|
||||
/* Optimized selfadjoint matrix * matrix (?SYMM/?HEMM) product */
|
||||
|
||||
#define EIGEN_MKL_SYMM_L(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
|
||||
#define EIGEN_BLAS_SYMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
|
||||
template <typename Index, \
|
||||
int LhsStorageOrder, bool ConjugateLhs, \
|
||||
int RhsStorageOrder, bool ConjugateRhs> \
|
||||
@ -52,28 +52,23 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLh
|
||||
const EIGTYPE* _lhs, Index lhsStride, \
|
||||
const EIGTYPE* _rhs, Index rhsStride, \
|
||||
EIGTYPE* res, Index resStride, \
|
||||
EIGTYPE alpha) \
|
||||
EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
|
||||
{ \
|
||||
char side='L', uplo='L'; \
|
||||
MKL_INT m, n, lda, ldb, ldc; \
|
||||
BlasIndex m, n, lda, ldb, ldc; \
|
||||
const EIGTYPE *a, *b; \
|
||||
MKLTYPE alpha_, beta_; \
|
||||
EIGTYPE beta(1); \
|
||||
MatrixX##EIGPREFIX b_tmp; \
|
||||
EIGTYPE myone(1);\
|
||||
\
|
||||
/* Set transpose options */ \
|
||||
/* Set m, n, k */ \
|
||||
m = (MKL_INT)rows; \
|
||||
n = (MKL_INT)cols; \
|
||||
\
|
||||
/* Set alpha_ & beta_ */ \
|
||||
assign_scalar_eig2mkl(alpha_, alpha); \
|
||||
assign_scalar_eig2mkl(beta_, myone); \
|
||||
m = convert_index<BlasIndex>(rows); \
|
||||
n = convert_index<BlasIndex>(cols); \
|
||||
\
|
||||
/* Set lda, ldb, ldc */ \
|
||||
lda = (MKL_INT)lhsStride; \
|
||||
ldb = (MKL_INT)rhsStride; \
|
||||
ldc = (MKL_INT)resStride; \
|
||||
lda = convert_index<BlasIndex>(lhsStride); \
|
||||
ldb = convert_index<BlasIndex>(rhsStride); \
|
||||
ldc = convert_index<BlasIndex>(resStride); \
|
||||
\
|
||||
/* Set a, b, c */ \
|
||||
if (LhsStorageOrder==RowMajor) uplo='U'; \
|
||||
@ -83,16 +78,16 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLh
|
||||
Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > rhs(_rhs,n,m,OuterStride<>(rhsStride)); \
|
||||
b_tmp = rhs.adjoint(); \
|
||||
b = b_tmp.data(); \
|
||||
ldb = b_tmp.outerStride(); \
|
||||
ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
|
||||
} else b = _rhs; \
|
||||
\
|
||||
MKLPREFIX##symm(&side, &uplo, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &beta_, (MKLTYPE*)res, &ldc); \
|
||||
BLASPREFIX##symm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
|
||||
\
|
||||
} \
|
||||
};
|
||||
|
||||
|
||||
#define EIGEN_MKL_HEMM_L(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
|
||||
#define EIGEN_BLAS_HEMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
|
||||
template <typename Index, \
|
||||
int LhsStorageOrder, bool ConjugateLhs, \
|
||||
int RhsStorageOrder, bool ConjugateRhs> \
|
||||
@ -103,29 +98,24 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLh
|
||||
const EIGTYPE* _lhs, Index lhsStride, \
|
||||
const EIGTYPE* _rhs, Index rhsStride, \
|
||||
EIGTYPE* res, Index resStride, \
|
||||
EIGTYPE alpha) \
|
||||
EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
|
||||
{ \
|
||||
char side='L', uplo='L'; \
|
||||
MKL_INT m, n, lda, ldb, ldc; \
|
||||
BlasIndex m, n, lda, ldb, ldc; \
|
||||
const EIGTYPE *a, *b; \
|
||||
MKLTYPE alpha_, beta_; \
|
||||
EIGTYPE beta(1); \
|
||||
MatrixX##EIGPREFIX b_tmp; \
|
||||
Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> a_tmp; \
|
||||
EIGTYPE myone(1); \
|
||||
\
|
||||
/* Set transpose options */ \
|
||||
/* Set m, n, k */ \
|
||||
m = (MKL_INT)rows; \
|
||||
n = (MKL_INT)cols; \
|
||||
\
|
||||
/* Set alpha_ & beta_ */ \
|
||||
assign_scalar_eig2mkl(alpha_, alpha); \
|
||||
assign_scalar_eig2mkl(beta_, myone); \
|
||||
m = convert_index<BlasIndex>(rows); \
|
||||
n = convert_index<BlasIndex>(cols); \
|
||||
\
|
||||
/* Set lda, ldb, ldc */ \
|
||||
lda = (MKL_INT)lhsStride; \
|
||||
ldb = (MKL_INT)rhsStride; \
|
||||
ldc = (MKL_INT)resStride; \
|
||||
lda = convert_index<BlasIndex>(lhsStride); \
|
||||
ldb = convert_index<BlasIndex>(rhsStride); \
|
||||
ldc = convert_index<BlasIndex>(resStride); \
|
||||
\
|
||||
/* Set a, b, c */ \
|
||||
if (((LhsStorageOrder==ColMajor) && ConjugateLhs) || ((LhsStorageOrder==RowMajor) && (!ConjugateLhs))) { \
|
||||
@ -151,23 +141,23 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLh
|
||||
b_tmp = rhs.transpose(); \
|
||||
} \
|
||||
b = b_tmp.data(); \
|
||||
ldb = b_tmp.outerStride(); \
|
||||
ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
|
||||
} \
|
||||
\
|
||||
MKLPREFIX##hemm(&side, &uplo, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &beta_, (MKLTYPE*)res, &ldc); \
|
||||
BLASPREFIX##hemm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
|
||||
\
|
||||
} \
|
||||
};
|
||||
|
||||
EIGEN_MKL_SYMM_L(double, double, d, d)
|
||||
EIGEN_MKL_SYMM_L(float, float, f, s)
|
||||
EIGEN_MKL_HEMM_L(dcomplex, MKL_Complex16, cd, z)
|
||||
EIGEN_MKL_HEMM_L(scomplex, MKL_Complex8, cf, c)
|
||||
EIGEN_BLAS_SYMM_L(double, double, d, d)
|
||||
EIGEN_BLAS_SYMM_L(float, float, f, s)
|
||||
EIGEN_BLAS_HEMM_L(dcomplex, double, cd, z)
|
||||
EIGEN_BLAS_HEMM_L(scomplex, float, cf, c)
|
||||
|
||||
|
||||
/* Optimized matrix * selfadjoint matrix (?SYMM/?HEMM) product */
|
||||
|
||||
#define EIGEN_MKL_SYMM_R(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
|
||||
#define EIGEN_BLAS_SYMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
|
||||
template <typename Index, \
|
||||
int LhsStorageOrder, bool ConjugateLhs, \
|
||||
int RhsStorageOrder, bool ConjugateRhs> \
|
||||
@ -179,27 +169,22 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateL
|
||||
const EIGTYPE* _lhs, Index lhsStride, \
|
||||
const EIGTYPE* _rhs, Index rhsStride, \
|
||||
EIGTYPE* res, Index resStride, \
|
||||
EIGTYPE alpha) \
|
||||
EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
|
||||
{ \
|
||||
char side='R', uplo='L'; \
|
||||
MKL_INT m, n, lda, ldb, ldc; \
|
||||
BlasIndex m, n, lda, ldb, ldc; \
|
||||
const EIGTYPE *a, *b; \
|
||||
MKLTYPE alpha_, beta_; \
|
||||
EIGTYPE beta(1); \
|
||||
MatrixX##EIGPREFIX b_tmp; \
|
||||
EIGTYPE myone(1);\
|
||||
\
|
||||
/* Set m, n */ \
|
||||
m = (MKL_INT)rows; \
|
||||
n = (MKL_INT)cols; \
|
||||
\
|
||||
/* Set alpha_ & beta_ */ \
|
||||
assign_scalar_eig2mkl(alpha_, alpha); \
|
||||
assign_scalar_eig2mkl(beta_, myone); \
|
||||
m = convert_index<BlasIndex>(rows); \
|
||||
n = convert_index<BlasIndex>(cols); \
|
||||
\
|
||||
/* Set lda, ldb, ldc */ \
|
||||
lda = (MKL_INT)rhsStride; \
|
||||
ldb = (MKL_INT)lhsStride; \
|
||||
ldc = (MKL_INT)resStride; \
|
||||
lda = convert_index<BlasIndex>(rhsStride); \
|
||||
ldb = convert_index<BlasIndex>(lhsStride); \
|
||||
ldc = convert_index<BlasIndex>(resStride); \
|
||||
\
|
||||
/* Set a, b, c */ \
|
||||
if (RhsStorageOrder==RowMajor) uplo='U'; \
|
||||
@ -209,16 +194,16 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateL
|
||||
Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > lhs(_lhs,n,m,OuterStride<>(rhsStride)); \
|
||||
b_tmp = lhs.adjoint(); \
|
||||
b = b_tmp.data(); \
|
||||
ldb = b_tmp.outerStride(); \
|
||||
ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
|
||||
} else b = _lhs; \
|
||||
\
|
||||
MKLPREFIX##symm(&side, &uplo, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &beta_, (MKLTYPE*)res, &ldc); \
|
||||
BLASPREFIX##symm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
|
||||
\
|
||||
} \
|
||||
};
|
||||
|
||||
|
||||
#define EIGEN_MKL_HEMM_R(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
|
||||
#define EIGEN_BLAS_HEMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
|
||||
template <typename Index, \
|
||||
int LhsStorageOrder, bool ConjugateLhs, \
|
||||
int RhsStorageOrder, bool ConjugateRhs> \
|
||||
@ -229,35 +214,30 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateL
|
||||
const EIGTYPE* _lhs, Index lhsStride, \
|
||||
const EIGTYPE* _rhs, Index rhsStride, \
|
||||
EIGTYPE* res, Index resStride, \
|
||||
EIGTYPE alpha) \
|
||||
EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
|
||||
{ \
|
||||
char side='R', uplo='L'; \
|
||||
MKL_INT m, n, lda, ldb, ldc; \
|
||||
BlasIndex m, n, lda, ldb, ldc; \
|
||||
const EIGTYPE *a, *b; \
|
||||
MKLTYPE alpha_, beta_; \
|
||||
EIGTYPE beta(1); \
|
||||
MatrixX##EIGPREFIX b_tmp; \
|
||||
Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> a_tmp; \
|
||||
EIGTYPE myone(1); \
|
||||
\
|
||||
/* Set m, n */ \
|
||||
m = (MKL_INT)rows; \
|
||||
n = (MKL_INT)cols; \
|
||||
\
|
||||
/* Set alpha_ & beta_ */ \
|
||||
assign_scalar_eig2mkl(alpha_, alpha); \
|
||||
assign_scalar_eig2mkl(beta_, myone); \
|
||||
m = convert_index<BlasIndex>(rows); \
|
||||
n = convert_index<BlasIndex>(cols); \
|
||||
\
|
||||
/* Set lda, ldb, ldc */ \
|
||||
lda = (MKL_INT)rhsStride; \
|
||||
ldb = (MKL_INT)lhsStride; \
|
||||
ldc = (MKL_INT)resStride; \
|
||||
lda = convert_index<BlasIndex>(rhsStride); \
|
||||
ldb = convert_index<BlasIndex>(lhsStride); \
|
||||
ldc = convert_index<BlasIndex>(resStride); \
|
||||
\
|
||||
/* Set a, b, c */ \
|
||||
if (((RhsStorageOrder==ColMajor) && ConjugateRhs) || ((RhsStorageOrder==RowMajor) && (!ConjugateRhs))) { \
|
||||
Map<const Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder>, 0, OuterStride<> > rhs(_rhs,n,n,OuterStride<>(rhsStride)); \
|
||||
a_tmp = rhs.conjugate(); \
|
||||
a = a_tmp.data(); \
|
||||
lda = a_tmp.outerStride(); \
|
||||
lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
|
||||
} else a = _rhs; \
|
||||
if (RhsStorageOrder==RowMajor) uplo='U'; \
|
||||
\
|
||||
@ -279,17 +259,17 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateL
|
||||
ldb = b_tmp.outerStride(); \
|
||||
} \
|
||||
\
|
||||
MKLPREFIX##hemm(&side, &uplo, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)b, &ldb, &beta_, (MKLTYPE*)res, &ldc); \
|
||||
BLASPREFIX##hemm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
|
||||
} \
|
||||
};
|
||||
|
||||
EIGEN_MKL_SYMM_R(double, double, d, d)
|
||||
EIGEN_MKL_SYMM_R(float, float, f, s)
|
||||
EIGEN_MKL_HEMM_R(dcomplex, MKL_Complex16, cd, z)
|
||||
EIGEN_MKL_HEMM_R(scomplex, MKL_Complex8, cf, c)
|
||||
EIGEN_BLAS_SYMM_R(double, double, d, d)
|
||||
EIGEN_BLAS_SYMM_R(float, float, f, s)
|
||||
EIGEN_BLAS_HEMM_R(dcomplex, double, cd, z)
|
||||
EIGEN_BLAS_HEMM_R(scomplex, float, cf, c)
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_SELFADJOINT_MATRIX_MATRIX_MKL_H
|
||||
#endif // EIGEN_SELFADJOINT_MATRIX_MATRIX_BLAS_H
|
@ -25,13 +25,13 @@
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
********************************************************************************
|
||||
* Content : Eigen bindings to Intel(R) MKL
|
||||
* Content : Eigen bindings to BLAS F77
|
||||
* Selfadjoint matrix-vector product functionality based on ?SYMV/HEMV.
|
||||
********************************************************************************
|
||||
*/
|
||||
|
||||
#ifndef EIGEN_SELFADJOINT_MATRIX_VECTOR_MKL_H
|
||||
#define EIGEN_SELFADJOINT_MATRIX_VECTOR_MKL_H
|
||||
#ifndef EIGEN_SELFADJOINT_MATRIX_VECTOR_BLAS_H
|
||||
#define EIGEN_SELFADJOINT_MATRIX_VECTOR_BLAS_H
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
@ -47,7 +47,7 @@ template<typename Scalar, typename Index, int StorageOrder, int UpLo, bool Conju
|
||||
struct selfadjoint_matrix_vector_product_symv :
|
||||
selfadjoint_matrix_vector_product<Scalar,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs,BuiltIn> {};
|
||||
|
||||
#define EIGEN_MKL_SYMV_SPECIALIZE(Scalar) \
|
||||
#define EIGEN_BLAS_SYMV_SPECIALIZE(Scalar) \
|
||||
template<typename Index, int StorageOrder, int UpLo, bool ConjugateLhs, bool ConjugateRhs> \
|
||||
struct selfadjoint_matrix_vector_product<Scalar,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs,Specialized> { \
|
||||
static void run( \
|
||||
@ -66,12 +66,12 @@ static void run( \
|
||||
} \
|
||||
}; \
|
||||
|
||||
EIGEN_MKL_SYMV_SPECIALIZE(double)
|
||||
EIGEN_MKL_SYMV_SPECIALIZE(float)
|
||||
EIGEN_MKL_SYMV_SPECIALIZE(dcomplex)
|
||||
EIGEN_MKL_SYMV_SPECIALIZE(scomplex)
|
||||
EIGEN_BLAS_SYMV_SPECIALIZE(double)
|
||||
EIGEN_BLAS_SYMV_SPECIALIZE(float)
|
||||
EIGEN_BLAS_SYMV_SPECIALIZE(dcomplex)
|
||||
EIGEN_BLAS_SYMV_SPECIALIZE(scomplex)
|
||||
|
||||
#define EIGEN_MKL_SYMV_SPECIALIZATION(EIGTYPE,MKLTYPE,MKLFUNC) \
|
||||
#define EIGEN_BLAS_SYMV_SPECIALIZATION(EIGTYPE,BLASTYPE,BLASFUNC) \
|
||||
template<typename Index, int StorageOrder, int UpLo, bool ConjugateLhs, bool ConjugateRhs> \
|
||||
struct selfadjoint_matrix_vector_product_symv<EIGTYPE,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs> \
|
||||
{ \
|
||||
@ -85,29 +85,27 @@ const EIGTYPE* _rhs, EIGTYPE* res, EIGTYPE alpha) \
|
||||
IsRowMajor = StorageOrder==RowMajor ? 1 : 0, \
|
||||
IsLower = UpLo == Lower ? 1 : 0 \
|
||||
}; \
|
||||
MKL_INT n=size, lda=lhsStride, incx=1, incy=1; \
|
||||
MKLTYPE alpha_, beta_; \
|
||||
const EIGTYPE *x_ptr, myone(1); \
|
||||
BlasIndex n=convert_index<BlasIndex>(size), lda=convert_index<BlasIndex>(lhsStride), incx=1, incy=1; \
|
||||
EIGTYPE beta(1); \
|
||||
const EIGTYPE *x_ptr; \
|
||||
char uplo=(IsRowMajor) ? (IsLower ? 'U' : 'L') : (IsLower ? 'L' : 'U'); \
|
||||
assign_scalar_eig2mkl(alpha_, alpha); \
|
||||
assign_scalar_eig2mkl(beta_, myone); \
|
||||
SYMVVector x_tmp; \
|
||||
if (ConjugateRhs) { \
|
||||
Map<const SYMVVector, 0 > map_x(_rhs,size,1); \
|
||||
x_tmp=map_x.conjugate(); \
|
||||
x_ptr=x_tmp.data(); \
|
||||
} else x_ptr=_rhs; \
|
||||
MKLFUNC(&uplo, &n, &alpha_, (const MKLTYPE*)lhs, &lda, (const MKLTYPE*)x_ptr, &incx, &beta_, (MKLTYPE*)res, &incy); \
|
||||
BLASFUNC(&uplo, &n, &numext::real_ref(alpha), (const BLASTYPE*)lhs, &lda, (const BLASTYPE*)x_ptr, &incx, &numext::real_ref(beta), (BLASTYPE*)res, &incy); \
|
||||
}\
|
||||
};
|
||||
|
||||
EIGEN_MKL_SYMV_SPECIALIZATION(double, double, dsymv)
|
||||
EIGEN_MKL_SYMV_SPECIALIZATION(float, float, ssymv)
|
||||
EIGEN_MKL_SYMV_SPECIALIZATION(dcomplex, MKL_Complex16, zhemv)
|
||||
EIGEN_MKL_SYMV_SPECIALIZATION(scomplex, MKL_Complex8, chemv)
|
||||
EIGEN_BLAS_SYMV_SPECIALIZATION(double, double, dsymv_)
|
||||
EIGEN_BLAS_SYMV_SPECIALIZATION(float, float, ssymv_)
|
||||
EIGEN_BLAS_SYMV_SPECIALIZATION(dcomplex, double, zhemv_)
|
||||
EIGEN_BLAS_SYMV_SPECIALIZATION(scomplex, float, chemv_)
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_SELFADJOINT_MATRIX_VECTOR_MKL_H
|
||||
#endif // EIGEN_SELFADJOINT_MATRIX_VECTOR_BLAS_H
|
@ -25,13 +25,13 @@
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
********************************************************************************
|
||||
* Content : Eigen bindings to Intel(R) MKL
|
||||
* Content : Eigen bindings to BLAS F77
|
||||
* Triangular matrix * matrix product functionality based on ?TRMM.
|
||||
********************************************************************************
|
||||
*/
|
||||
|
||||
#ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H
|
||||
#define EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H
|
||||
#ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
|
||||
#define EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
@ -50,7 +50,7 @@ struct product_triangular_matrix_matrix_trmm :
|
||||
|
||||
|
||||
// try to go to BLAS specialization
|
||||
#define EIGEN_MKL_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) \
|
||||
#define EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) \
|
||||
template <typename Index, int Mode, \
|
||||
int LhsStorageOrder, bool ConjugateLhs, \
|
||||
int RhsStorageOrder, bool ConjugateRhs> \
|
||||
@ -65,17 +65,17 @@ struct product_triangular_matrix_matrix<Scalar,Index, Mode, LhsIsTriangular, \
|
||||
} \
|
||||
};
|
||||
|
||||
EIGEN_MKL_TRMM_SPECIALIZE(double, true)
|
||||
EIGEN_MKL_TRMM_SPECIALIZE(double, false)
|
||||
EIGEN_MKL_TRMM_SPECIALIZE(dcomplex, true)
|
||||
EIGEN_MKL_TRMM_SPECIALIZE(dcomplex, false)
|
||||
EIGEN_MKL_TRMM_SPECIALIZE(float, true)
|
||||
EIGEN_MKL_TRMM_SPECIALIZE(float, false)
|
||||
EIGEN_MKL_TRMM_SPECIALIZE(scomplex, true)
|
||||
EIGEN_MKL_TRMM_SPECIALIZE(scomplex, false)
|
||||
EIGEN_BLAS_TRMM_SPECIALIZE(double, true)
|
||||
EIGEN_BLAS_TRMM_SPECIALIZE(double, false)
|
||||
EIGEN_BLAS_TRMM_SPECIALIZE(dcomplex, true)
|
||||
EIGEN_BLAS_TRMM_SPECIALIZE(dcomplex, false)
|
||||
EIGEN_BLAS_TRMM_SPECIALIZE(float, true)
|
||||
EIGEN_BLAS_TRMM_SPECIALIZE(float, false)
|
||||
EIGEN_BLAS_TRMM_SPECIALIZE(scomplex, true)
|
||||
EIGEN_BLAS_TRMM_SPECIALIZE(scomplex, false)
|
||||
|
||||
// implements col-major += alpha * op(triangular) * op(general)
|
||||
#define EIGEN_MKL_TRMM_L(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
|
||||
#define EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
|
||||
template <typename Index, int Mode, \
|
||||
int LhsStorageOrder, bool ConjugateLhs, \
|
||||
int RhsStorageOrder, bool ConjugateRhs> \
|
||||
@ -106,13 +106,14 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
|
||||
typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
|
||||
typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
|
||||
\
|
||||
/* Non-square case - doesn't fit to MKL ?TRMM. Fall to default triangular product or call MKL ?GEMM*/ \
|
||||
/* Non-square case - doesn't fit to BLAS ?TRMM. Fall to default triangular product or call BLAS ?GEMM*/ \
|
||||
if (rows != depth) { \
|
||||
\
|
||||
int nthr = mkl_domain_get_max_threads(EIGEN_MKL_DOMAIN_BLAS); \
|
||||
/* FIXME handle mkl_domain_get_max_threads */ \
|
||||
/*int nthr = mkl_domain_get_max_threads(EIGEN_BLAS_DOMAIN_BLAS);*/ int nthr = 1;\
|
||||
\
|
||||
if (((nthr==1) && (((std::max)(rows,depth)-diagSize)/(double)diagSize < 0.5))) { \
|
||||
/* Most likely no benefit to call TRMM or GEMM from MKL*/ \
|
||||
/* Most likely no benefit to call TRMM or GEMM from BLAS */ \
|
||||
product_triangular_matrix_matrix<EIGTYPE,Index,Mode,true, \
|
||||
LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \
|
||||
_rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
|
||||
@ -121,27 +122,23 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
|
||||
/* Make sense to call GEMM */ \
|
||||
Map<const MatrixLhs, 0, OuterStride<> > lhsMap(_lhs,rows,depth,OuterStride<>(lhsStride)); \
|
||||
MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); \
|
||||
MKL_INT aStride = aa_tmp.outerStride(); \
|
||||
BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
|
||||
gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
|
||||
general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
|
||||
rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, resStride, alpha, gemm_blocking, 0); \
|
||||
\
|
||||
/*std::cout << "TRMM_L: A is not square! Go to MKL GEMM implementation! " << nthr<<" \n";*/ \
|
||||
/*std::cout << "TRMM_L: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \
|
||||
} \
|
||||
return; \
|
||||
} \
|
||||
char side = 'L', transa, uplo, diag = 'N'; \
|
||||
EIGTYPE *b; \
|
||||
const EIGTYPE *a; \
|
||||
MKL_INT m, n, lda, ldb; \
|
||||
MKLTYPE alpha_; \
|
||||
\
|
||||
/* Set alpha_*/ \
|
||||
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
|
||||
BlasIndex m, n, lda, ldb; \
|
||||
\
|
||||
/* Set m, n */ \
|
||||
m = (MKL_INT)diagSize; \
|
||||
n = (MKL_INT)cols; \
|
||||
m = convert_index<BlasIndex>(diagSize); \
|
||||
n = convert_index<BlasIndex>(cols); \
|
||||
\
|
||||
/* Set trans */ \
|
||||
transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \
|
||||
@ -152,7 +149,7 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
|
||||
\
|
||||
if (ConjugateRhs) b_tmp = rhs.conjugate(); else b_tmp = rhs; \
|
||||
b = b_tmp.data(); \
|
||||
ldb = b_tmp.outerStride(); \
|
||||
ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
|
||||
\
|
||||
/* Set uplo */ \
|
||||
uplo = IsLower ? 'L' : 'U'; \
|
||||
@ -168,14 +165,14 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
|
||||
else if (IsUnitDiag) \
|
||||
a_tmp.diagonal().setOnes();\
|
||||
a = a_tmp.data(); \
|
||||
lda = a_tmp.outerStride(); \
|
||||
lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
|
||||
} else { \
|
||||
a = _lhs; \
|
||||
lda = lhsStride; \
|
||||
lda = convert_index<BlasIndex>(lhsStride); \
|
||||
} \
|
||||
/*std::cout << "TRMM_L: A is square! Go to MKL TRMM implementation! \n";*/ \
|
||||
/*std::cout << "TRMM_L: A is square! Go to BLAS TRMM implementation! \n";*/ \
|
||||
/* call ?trmm*/ \
|
||||
MKLPREFIX##trmm(&side, &uplo, &transa, &diag, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (MKLTYPE*)b, &ldb); \
|
||||
BLASPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \
|
||||
\
|
||||
/* Add op(a_triangular)*b into res*/ \
|
||||
Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
|
||||
@ -183,13 +180,13 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
|
||||
} \
|
||||
};
|
||||
|
||||
EIGEN_MKL_TRMM_L(double, double, d, d)
|
||||
EIGEN_MKL_TRMM_L(dcomplex, MKL_Complex16, cd, z)
|
||||
EIGEN_MKL_TRMM_L(float, float, f, s)
|
||||
EIGEN_MKL_TRMM_L(scomplex, MKL_Complex8, cf, c)
|
||||
EIGEN_BLAS_TRMM_L(double, double, d, d)
|
||||
EIGEN_BLAS_TRMM_L(dcomplex, double, cd, z)
|
||||
EIGEN_BLAS_TRMM_L(float, float, f, s)
|
||||
EIGEN_BLAS_TRMM_L(scomplex, float, cf, c)
|
||||
|
||||
// implements col-major += alpha * op(general) * op(triangular)
|
||||
#define EIGEN_MKL_TRMM_R(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
|
||||
#define EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
|
||||
template <typename Index, int Mode, \
|
||||
int LhsStorageOrder, bool ConjugateLhs, \
|
||||
int RhsStorageOrder, bool ConjugateRhs> \
|
||||
@ -220,13 +217,13 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
|
||||
typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
|
||||
typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
|
||||
\
|
||||
/* Non-square case - doesn't fit to MKL ?TRMM. Fall to default triangular product or call MKL ?GEMM*/ \
|
||||
/* Non-square case - doesn't fit to BLAS ?TRMM. Fall to default triangular product or call BLAS ?GEMM*/ \
|
||||
if (cols != depth) { \
|
||||
\
|
||||
int nthr = mkl_domain_get_max_threads(EIGEN_MKL_DOMAIN_BLAS); \
|
||||
int nthr = 1 /*mkl_domain_get_max_threads(EIGEN_BLAS_DOMAIN_BLAS)*/; \
|
||||
\
|
||||
if ((nthr==1) && (((std::max)(cols,depth)-diagSize)/(double)diagSize < 0.5)) { \
|
||||
/* Most likely no benefit to call TRMM or GEMM from MKL*/ \
|
||||
/* Most likely no benefit to call TRMM or GEMM from BLAS*/ \
|
||||
product_triangular_matrix_matrix<EIGTYPE,Index,Mode,false, \
|
||||
LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \
|
||||
_rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
|
||||
@ -235,27 +232,23 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
|
||||
/* Make sense to call GEMM */ \
|
||||
Map<const MatrixRhs, 0, OuterStride<> > rhsMap(_rhs,depth,cols, OuterStride<>(rhsStride)); \
|
||||
MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); \
|
||||
MKL_INT aStride = aa_tmp.outerStride(); \
|
||||
BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
|
||||
gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
|
||||
general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
|
||||
rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, resStride, alpha, gemm_blocking, 0); \
|
||||
\
|
||||
/*std::cout << "TRMM_R: A is not square! Go to MKL GEMM implementation! " << nthr<<" \n";*/ \
|
||||
/*std::cout << "TRMM_R: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \
|
||||
} \
|
||||
return; \
|
||||
} \
|
||||
char side = 'R', transa, uplo, diag = 'N'; \
|
||||
EIGTYPE *b; \
|
||||
const EIGTYPE *a; \
|
||||
MKL_INT m, n, lda, ldb; \
|
||||
MKLTYPE alpha_; \
|
||||
\
|
||||
/* Set alpha_*/ \
|
||||
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
|
||||
BlasIndex m, n, lda, ldb; \
|
||||
\
|
||||
/* Set m, n */ \
|
||||
m = (MKL_INT)rows; \
|
||||
n = (MKL_INT)diagSize; \
|
||||
m = convert_index<BlasIndex>(rows); \
|
||||
n = convert_index<BlasIndex>(diagSize); \
|
||||
\
|
||||
/* Set trans */ \
|
||||
transa = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \
|
||||
@ -266,7 +259,7 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
|
||||
\
|
||||
if (ConjugateLhs) b_tmp = lhs.conjugate(); else b_tmp = lhs; \
|
||||
b = b_tmp.data(); \
|
||||
ldb = b_tmp.outerStride(); \
|
||||
ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
|
||||
\
|
||||
/* Set uplo */ \
|
||||
uplo = IsLower ? 'L' : 'U'; \
|
||||
@ -282,14 +275,14 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
|
||||
else if (IsUnitDiag) \
|
||||
a_tmp.diagonal().setOnes();\
|
||||
a = a_tmp.data(); \
|
||||
lda = a_tmp.outerStride(); \
|
||||
lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
|
||||
} else { \
|
||||
a = _rhs; \
|
||||
lda = rhsStride; \
|
||||
lda = convert_index<BlasIndex>(rhsStride); \
|
||||
} \
|
||||
/*std::cout << "TRMM_R: A is square! Go to MKL TRMM implementation! \n";*/ \
|
||||
/*std::cout << "TRMM_R: A is square! Go to BLAS TRMM implementation! \n";*/ \
|
||||
/* call ?trmm*/ \
|
||||
MKLPREFIX##trmm(&side, &uplo, &transa, &diag, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (MKLTYPE*)b, &ldb); \
|
||||
BLASPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \
|
||||
\
|
||||
/* Add op(a_triangular)*b into res*/ \
|
||||
Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
|
||||
@ -297,13 +290,13 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
|
||||
} \
|
||||
};
|
||||
|
||||
EIGEN_MKL_TRMM_R(double, double, d, d)
|
||||
EIGEN_MKL_TRMM_R(dcomplex, MKL_Complex16, cd, z)
|
||||
EIGEN_MKL_TRMM_R(float, float, f, s)
|
||||
EIGEN_MKL_TRMM_R(scomplex, MKL_Complex8, cf, c)
|
||||
EIGEN_BLAS_TRMM_R(double, double, d, d)
|
||||
EIGEN_BLAS_TRMM_R(dcomplex, double, cd, z)
|
||||
EIGEN_BLAS_TRMM_R(float, float, f, s)
|
||||
EIGEN_BLAS_TRMM_R(scomplex, float, cf, c)
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H
|
||||
#endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
|
@ -25,13 +25,13 @@
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
********************************************************************************
|
||||
* Content : Eigen bindings to Intel(R) MKL
|
||||
* Content : Eigen bindings to BLAS F77
|
||||
* Triangular matrix-vector product functionality based on ?TRMV.
|
||||
********************************************************************************
|
||||
*/
|
||||
|
||||
#ifndef EIGEN_TRIANGULAR_MATRIX_VECTOR_MKL_H
|
||||
#define EIGEN_TRIANGULAR_MATRIX_VECTOR_MKL_H
|
||||
#ifndef EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H
|
||||
#define EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
@ -47,7 +47,7 @@ template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename Rh
|
||||
struct triangular_matrix_vector_product_trmv :
|
||||
triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,StorageOrder,BuiltIn> {};
|
||||
|
||||
#define EIGEN_MKL_TRMV_SPECIALIZE(Scalar) \
|
||||
#define EIGEN_BLAS_TRMV_SPECIALIZE(Scalar) \
|
||||
template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
|
||||
struct triangular_matrix_vector_product<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,ColMajor,Specialized> { \
|
||||
static void run(Index _rows, Index _cols, const Scalar* _lhs, Index lhsStride, \
|
||||
@ -65,13 +65,13 @@ struct triangular_matrix_vector_product<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs
|
||||
} \
|
||||
};
|
||||
|
||||
EIGEN_MKL_TRMV_SPECIALIZE(double)
|
||||
EIGEN_MKL_TRMV_SPECIALIZE(float)
|
||||
EIGEN_MKL_TRMV_SPECIALIZE(dcomplex)
|
||||
EIGEN_MKL_TRMV_SPECIALIZE(scomplex)
|
||||
EIGEN_BLAS_TRMV_SPECIALIZE(double)
|
||||
EIGEN_BLAS_TRMV_SPECIALIZE(float)
|
||||
EIGEN_BLAS_TRMV_SPECIALIZE(dcomplex)
|
||||
EIGEN_BLAS_TRMV_SPECIALIZE(scomplex)
|
||||
|
||||
// implements col-major: res += alpha * op(triangular) * vector
|
||||
#define EIGEN_MKL_TRMV_CM(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
|
||||
#define EIGEN_BLAS_TRMV_CM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
|
||||
template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
|
||||
struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,ColMajor> { \
|
||||
enum { \
|
||||
@ -105,17 +105,15 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
|
||||
/* Square part handling */\
|
||||
\
|
||||
char trans, uplo, diag; \
|
||||
MKL_INT m, n, lda, incx, incy; \
|
||||
BlasIndex m, n, lda, incx, incy; \
|
||||
EIGTYPE const *a; \
|
||||
MKLTYPE alpha_, beta_; \
|
||||
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
|
||||
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(beta_, EIGTYPE(1)); \
|
||||
EIGTYPE beta(1); \
|
||||
\
|
||||
/* Set m, n */ \
|
||||
n = (MKL_INT)size; \
|
||||
lda = lhsStride; \
|
||||
n = convert_index<BlasIndex>(size); \
|
||||
lda = convert_index<BlasIndex>(lhsStride); \
|
||||
incx = 1; \
|
||||
incy = resIncr; \
|
||||
incy = convert_index<BlasIndex>(resIncr); \
|
||||
\
|
||||
/* Set uplo, trans and diag*/ \
|
||||
trans = 'N'; \
|
||||
@ -123,39 +121,39 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
|
||||
diag = IsUnitDiag ? 'U' : 'N'; \
|
||||
\
|
||||
/* call ?TRMV*/ \
|
||||
MKLPREFIX##trmv(&uplo, &trans, &diag, &n, (const MKLTYPE*)_lhs, &lda, (MKLTYPE*)x, &incx); \
|
||||
BLASPREFIX##trmv_(&uplo, &trans, &diag, &n, (const BLASTYPE*)_lhs, &lda, (BLASTYPE*)x, &incx); \
|
||||
\
|
||||
/* Add op(a_tr)rhs into res*/ \
|
||||
MKLPREFIX##axpy(&n, &alpha_,(const MKLTYPE*)x, &incx, (MKLTYPE*)_res, &incy); \
|
||||
/* Non-square case - doesn't fit to MKL ?TRMV. Fall to default triangular product*/ \
|
||||
BLASPREFIX##axpy_(&n, &numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)_res, &incy); \
|
||||
/* Non-square case - doesn't fit BLAS ?TRMV. Handle the remaining rectangular part with BLAS ?GEMV */ \
|
||||
if (size<(std::max)(rows,cols)) { \
|
||||
if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
|
||||
x = x_tmp.data(); \
|
||||
if (size<rows) { \
|
||||
y = _res + size*resIncr; \
|
||||
a = _lhs + size; \
|
||||
m = rows-size; \
|
||||
n = size; \
|
||||
m = convert_index<BlasIndex>(rows-size); \
|
||||
n = convert_index<BlasIndex>(size); \
|
||||
} \
|
||||
else { \
|
||||
x += size; \
|
||||
y = _res; \
|
||||
a = _lhs + size*lda; \
|
||||
m = size; \
|
||||
n = cols-size; \
|
||||
m = convert_index<BlasIndex>(size); \
|
||||
n = convert_index<BlasIndex>(cols-size); \
|
||||
} \
|
||||
MKLPREFIX##gemv(&trans, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)x, &incx, &beta_, (MKLTYPE*)y, &incy); \
|
||||
BLASPREFIX##gemv_(&trans, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, &numext::real_ref(beta), (BLASTYPE*)y, &incy); \
|
||||
} \
|
||||
} \
|
||||
};
|
||||
|
||||
EIGEN_MKL_TRMV_CM(double, double, d, d)
|
||||
EIGEN_MKL_TRMV_CM(dcomplex, MKL_Complex16, cd, z)
|
||||
EIGEN_MKL_TRMV_CM(float, float, f, s)
|
||||
EIGEN_MKL_TRMV_CM(scomplex, MKL_Complex8, cf, c)
|
||||
EIGEN_BLAS_TRMV_CM(double, double, d, d)
|
||||
EIGEN_BLAS_TRMV_CM(dcomplex, double, cd, z)
|
||||
EIGEN_BLAS_TRMV_CM(float, float, f, s)
|
||||
EIGEN_BLAS_TRMV_CM(scomplex, float, cf, c)
|
||||
|
||||
// implements row-major: res += alpha * op(triangular) * vector
|
||||
#define EIGEN_MKL_TRMV_RM(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
|
||||
#define EIGEN_BLAS_TRMV_RM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
|
||||
template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
|
||||
struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,RowMajor> { \
|
||||
enum { \
|
||||
@ -189,17 +187,15 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
|
||||
/* Square part handling */\
|
||||
\
|
||||
char trans, uplo, diag; \
|
||||
MKL_INT m, n, lda, incx, incy; \
|
||||
BlasIndex m, n, lda, incx, incy; \
|
||||
EIGTYPE const *a; \
|
||||
MKLTYPE alpha_, beta_; \
|
||||
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
|
||||
assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(beta_, EIGTYPE(1)); \
|
||||
EIGTYPE beta(1); \
|
||||
\
|
||||
/* Set m, n */ \
|
||||
n = (MKL_INT)size; \
|
||||
lda = lhsStride; \
|
||||
n = convert_index<BlasIndex>(size); \
|
||||
lda = convert_index<BlasIndex>(lhsStride); \
|
||||
incx = 1; \
|
||||
incy = resIncr; \
|
||||
incy = convert_index<BlasIndex>(resIncr); \
|
||||
\
|
||||
/* Set uplo, trans and diag*/ \
|
||||
trans = ConjLhs ? 'C' : 'T'; \
|
||||
@ -207,39 +203,39 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
|
||||
diag = IsUnitDiag ? 'U' : 'N'; \
|
||||
\
|
||||
/* call ?TRMV*/ \
|
||||
MKLPREFIX##trmv(&uplo, &trans, &diag, &n, (const MKLTYPE*)_lhs, &lda, (MKLTYPE*)x, &incx); \
|
||||
BLASPREFIX##trmv_(&uplo, &trans, &diag, &n, (const BLASTYPE*)_lhs, &lda, (BLASTYPE*)x, &incx); \
|
||||
\
|
||||
/* Add op(a_tr)rhs into res*/ \
|
||||
MKLPREFIX##axpy(&n, &alpha_,(const MKLTYPE*)x, &incx, (MKLTYPE*)_res, &incy); \
|
||||
/* Non-square case - doesn't fit to MKL ?TRMV. Fall to default triangular product*/ \
|
||||
BLASPREFIX##axpy_(&n, &numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)_res, &incy); \
|
||||
/* Non-square case - doesn't fit BLAS ?TRMV. Handle the remaining rectangular part with BLAS ?GEMV */ \
|
||||
if (size<(std::max)(rows,cols)) { \
|
||||
if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
|
||||
x = x_tmp.data(); \
|
||||
if (size<rows) { \
|
||||
y = _res + size*resIncr; \
|
||||
a = _lhs + size*lda; \
|
||||
m = rows-size; \
|
||||
n = size; \
|
||||
m = convert_index<BlasIndex>(rows-size); \
|
||||
n = convert_index<BlasIndex>(size); \
|
||||
} \
|
||||
else { \
|
||||
x += size; \
|
||||
y = _res; \
|
||||
a = _lhs + size; \
|
||||
m = size; \
|
||||
n = cols-size; \
|
||||
m = convert_index<BlasIndex>(size); \
|
||||
n = convert_index<BlasIndex>(cols-size); \
|
||||
} \
|
||||
MKLPREFIX##gemv(&trans, &n, &m, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)x, &incx, &beta_, (MKLTYPE*)y, &incy); \
|
||||
BLASPREFIX##gemv_(&trans, &n, &m, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, &numext::real_ref(beta), (BLASTYPE*)y, &incy); \
|
||||
} \
|
||||
} \
|
||||
};
|
||||
|
||||
EIGEN_MKL_TRMV_RM(double, double, d, d)
|
||||
EIGEN_MKL_TRMV_RM(dcomplex, MKL_Complex16, cd, z)
|
||||
EIGEN_MKL_TRMV_RM(float, float, f, s)
|
||||
EIGEN_MKL_TRMV_RM(scomplex, MKL_Complex8, cf, c)
|
||||
EIGEN_BLAS_TRMV_RM(double, double, d, d)
|
||||
EIGEN_BLAS_TRMV_RM(dcomplex, double, cd, z)
|
||||
EIGEN_BLAS_TRMV_RM(float, float, f, s)
|
||||
EIGEN_BLAS_TRMV_RM(scomplex, float, cf, c)
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_TRIANGULAR_MATRIX_VECTOR_MKL_H
|
||||
#endif // EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H
|
@ -25,20 +25,20 @@
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
********************************************************************************
|
||||
* Content : Eigen bindings to Intel(R) MKL
|
||||
* Content : Eigen bindings to BLAS F77
|
||||
* Triangular solver (op(triangular)^-1 * general and general * op(triangular)^-1) functionality based on ?TRSM.
|
||||
********************************************************************************
|
||||
*/
|
||||
|
||||
#ifndef EIGEN_TRIANGULAR_SOLVER_MATRIX_MKL_H
|
||||
#define EIGEN_TRIANGULAR_SOLVER_MATRIX_MKL_H
|
||||
#ifndef EIGEN_TRIANGULAR_SOLVER_MATRIX_BLAS_H
|
||||
#define EIGEN_TRIANGULAR_SOLVER_MATRIX_BLAS_H
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
namespace internal {
|
||||
|
||||
// implements LeftSide op(triangular)^-1 * general
|
||||
#define EIGEN_MKL_TRSM_L(EIGTYPE, MKLTYPE, MKLPREFIX) \
|
||||
#define EIGEN_BLAS_TRSM_L(EIGTYPE, BLASTYPE, BLASPREFIX) \
|
||||
template <typename Index, int Mode, bool Conjugate, int TriStorageOrder> \
|
||||
struct triangular_solve_matrix<EIGTYPE,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor> \
|
||||
{ \
|
||||
@ -53,13 +53,11 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheLeft,Mode,Conjugate,TriStorage
|
||||
const EIGTYPE* _tri, Index triStride, \
|
||||
EIGTYPE* _other, Index otherStride, level3_blocking<EIGTYPE,EIGTYPE>& /*blocking*/) \
|
||||
{ \
|
||||
MKL_INT m = size, n = otherSize, lda, ldb; \
|
||||
BlasIndex m = convert_index<BlasIndex>(size), n = convert_index<BlasIndex>(otherSize), lda, ldb; \
|
||||
char side = 'L', uplo, diag='N', transa; \
|
||||
/* Set alpha */ \
|
||||
MKLTYPE alpha; \
|
||||
EIGTYPE myone(1); \
|
||||
assign_scalar_eig2mkl(alpha, myone); \
|
||||
ldb = otherStride;\
|
||||
EIGTYPE alpha(1); \
|
||||
ldb = convert_index<BlasIndex>(otherStride);\
|
||||
\
|
||||
const EIGTYPE *a; \
|
||||
/* Set trans */ \
|
||||
@ -75,25 +73,25 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheLeft,Mode,Conjugate,TriStorage
|
||||
if (conjA) { \
|
||||
a_tmp = tri.conjugate(); \
|
||||
a = a_tmp.data(); \
|
||||
lda = a_tmp.outerStride(); \
|
||||
lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
|
||||
} else { \
|
||||
a = _tri; \
|
||||
lda = triStride; \
|
||||
lda = convert_index<BlasIndex>(triStride); \
|
||||
} \
|
||||
if (IsUnitDiag) diag='U'; \
|
||||
/* call ?trsm*/ \
|
||||
MKLPREFIX##trsm(&side, &uplo, &transa, &diag, &m, &n, &alpha, (const MKLTYPE*)a, &lda, (MKLTYPE*)_other, &ldb); \
|
||||
BLASPREFIX##trsm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)_other, &ldb); \
|
||||
} \
|
||||
};
|
||||
|
||||
EIGEN_MKL_TRSM_L(double, double, d)
|
||||
EIGEN_MKL_TRSM_L(dcomplex, MKL_Complex16, z)
|
||||
EIGEN_MKL_TRSM_L(float, float, s)
|
||||
EIGEN_MKL_TRSM_L(scomplex, MKL_Complex8, c)
|
||||
EIGEN_BLAS_TRSM_L(double, double, d)
|
||||
EIGEN_BLAS_TRSM_L(dcomplex, double, z)
|
||||
EIGEN_BLAS_TRSM_L(float, float, s)
|
||||
EIGEN_BLAS_TRSM_L(scomplex, float, c)
|
||||
|
||||
|
||||
// implements RightSide general * op(triangular)^-1
|
||||
#define EIGEN_MKL_TRSM_R(EIGTYPE, MKLTYPE, MKLPREFIX) \
|
||||
#define EIGEN_BLAS_TRSM_R(EIGTYPE, BLASTYPE, BLASPREFIX) \
|
||||
template <typename Index, int Mode, bool Conjugate, int TriStorageOrder> \
|
||||
struct triangular_solve_matrix<EIGTYPE,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor> \
|
||||
{ \
|
||||
@ -108,13 +106,11 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheRight,Mode,Conjugate,TriStorag
|
||||
const EIGTYPE* _tri, Index triStride, \
|
||||
EIGTYPE* _other, Index otherStride, level3_blocking<EIGTYPE,EIGTYPE>& /*blocking*/) \
|
||||
{ \
|
||||
MKL_INT m = otherSize, n = size, lda, ldb; \
|
||||
BlasIndex m = convert_index<BlasIndex>(otherSize), n = convert_index<BlasIndex>(size), lda, ldb; \
|
||||
char side = 'R', uplo, diag='N', transa; \
|
||||
/* Set alpha_ */ \
|
||||
MKLTYPE alpha; \
|
||||
EIGTYPE myone(1); \
|
||||
assign_scalar_eig2mkl(alpha, myone); \
|
||||
ldb = otherStride;\
|
||||
EIGTYPE alpha(1); \
|
||||
ldb = convert_index<BlasIndex>(otherStride);\
|
||||
\
|
||||
const EIGTYPE *a; \
|
||||
/* Set trans */ \
|
||||
@ -130,26 +126,26 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheRight,Mode,Conjugate,TriStorag
|
||||
if (conjA) { \
|
||||
a_tmp = tri.conjugate(); \
|
||||
a = a_tmp.data(); \
|
||||
lda = a_tmp.outerStride(); \
|
||||
lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
|
||||
} else { \
|
||||
a = _tri; \
|
||||
lda = triStride; \
|
||||
lda = convert_index<BlasIndex>(triStride); \
|
||||
} \
|
||||
if (IsUnitDiag) diag='U'; \
|
||||
/* call ?trsm*/ \
|
||||
MKLPREFIX##trsm(&side, &uplo, &transa, &diag, &m, &n, &alpha, (const MKLTYPE*)a, &lda, (MKLTYPE*)_other, &ldb); \
|
||||
BLASPREFIX##trsm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)_other, &ldb); \
|
||||
/*std::cout << "TRSM_R specialization!\n";*/ \
|
||||
} \
|
||||
};
|
||||
|
||||
EIGEN_MKL_TRSM_R(double, double, d)
|
||||
EIGEN_MKL_TRSM_R(dcomplex, MKL_Complex16, z)
|
||||
EIGEN_MKL_TRSM_R(float, float, s)
|
||||
EIGEN_MKL_TRSM_R(scomplex, MKL_Complex8, c)
|
||||
EIGEN_BLAS_TRSM_R(double, double, d)
|
||||
EIGEN_BLAS_TRSM_R(dcomplex, double, z)
|
||||
EIGEN_BLAS_TRSM_R(float, float, s)
|
||||
EIGEN_BLAS_TRSM_R(scomplex, float, c)
|
||||
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_TRIANGULAR_SOLVER_MATRIX_MKL_H
|
||||
#endif // EIGEN_TRIANGULAR_SOLVER_MATRIX_BLAS_H
|
@ -49,7 +49,7 @@
|
||||
#define EIGEN_USE_LAPACKE
|
||||
#endif
|
||||
|
||||
#if defined(EIGEN_USE_BLAS) || defined(EIGEN_USE_LAPACKE) || defined(EIGEN_USE_MKL_VML)
|
||||
#if defined(EIGEN_USE_LAPACKE) || defined(EIGEN_USE_MKL_VML)
|
||||
#define EIGEN_USE_MKL
|
||||
#endif
|
||||
|
||||
@ -64,7 +64,6 @@
|
||||
# ifndef EIGEN_USE_MKL
|
||||
/*If the MKL version is too old, undef everything*/
|
||||
# undef EIGEN_USE_MKL_ALL
|
||||
# undef EIGEN_USE_BLAS
|
||||
# undef EIGEN_USE_LAPACKE
|
||||
# undef EIGEN_USE_MKL_VML
|
||||
# undef EIGEN_USE_LAPACKE_STRICT
|
||||
@ -107,52 +106,23 @@
|
||||
#else
|
||||
#define EIGEN_MKL_DOMAIN_PARDISO MKL_PARDISO
|
||||
#endif
|
||||
#endif
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
typedef std::complex<double> dcomplex;
|
||||
typedef std::complex<float> scomplex;
|
||||
|
||||
namespace internal {
|
||||
|
||||
template<typename MKLType, typename EigenType>
|
||||
static inline void assign_scalar_eig2mkl(MKLType& mklScalar, const EigenType& eigenScalar) {
|
||||
mklScalar=eigenScalar;
|
||||
}
|
||||
|
||||
template<typename MKLType, typename EigenType>
|
||||
static inline void assign_conj_scalar_eig2mkl(MKLType& mklScalar, const EigenType& eigenScalar) {
|
||||
mklScalar=eigenScalar;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void assign_scalar_eig2mkl<MKL_Complex16,dcomplex>(MKL_Complex16& mklScalar, const dcomplex& eigenScalar) {
|
||||
mklScalar.real=eigenScalar.real();
|
||||
mklScalar.imag=eigenScalar.imag();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void assign_scalar_eig2mkl<MKL_Complex8,scomplex>(MKL_Complex8& mklScalar, const scomplex& eigenScalar) {
|
||||
mklScalar.real=eigenScalar.real();
|
||||
mklScalar.imag=eigenScalar.imag();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void assign_conj_scalar_eig2mkl<MKL_Complex16,dcomplex>(MKL_Complex16& mklScalar, const dcomplex& eigenScalar) {
|
||||
mklScalar.real=eigenScalar.real();
|
||||
mklScalar.imag=-eigenScalar.imag();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void assign_conj_scalar_eig2mkl<MKL_Complex8,scomplex>(MKL_Complex8& mklScalar, const scomplex& eigenScalar) {
|
||||
mklScalar.real=eigenScalar.real();
|
||||
mklScalar.imag=-eigenScalar.imag();
|
||||
}
|
||||
|
||||
} // end namespace internal
|
||||
#if defined(EIGEN_USE_MKL)
|
||||
typedef MKL_INT BlasIndex;
|
||||
#else
|
||||
typedef int BlasIndex;
|
||||
#endif
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
#if defined(EIGEN_USE_BLAS)
|
||||
#include "../../misc/blas.h"
|
||||
#endif
|
||||
|
||||
#endif // EIGEN_MKL_SUPPORT_H
|
||||
|
@ -30,15 +30,15 @@ int BLASFUNC(cdotcw) (int *, float *, int *, float *, int *, float*);
|
||||
int BLASFUNC(zdotuw) (int *, double *, int *, double *, int *, double*);
|
||||
int BLASFUNC(zdotcw) (int *, double *, int *, double *, int *, double*);
|
||||
|
||||
int BLASFUNC(saxpy) (int *, float *, float *, int *, float *, int *);
|
||||
int BLASFUNC(daxpy) (int *, double *, double *, int *, double *, int *);
|
||||
int BLASFUNC(qaxpy) (int *, double *, double *, int *, double *, int *);
|
||||
int BLASFUNC(caxpy) (int *, float *, float *, int *, float *, int *);
|
||||
int BLASFUNC(zaxpy) (int *, double *, double *, int *, double *, int *);
|
||||
int BLASFUNC(xaxpy) (int *, double *, double *, int *, double *, int *);
|
||||
int BLASFUNC(caxpyc)(int *, float *, float *, int *, float *, int *);
|
||||
int BLASFUNC(zaxpyc)(int *, double *, double *, int *, double *, int *);
|
||||
int BLASFUNC(xaxpyc)(int *, double *, double *, int *, double *, int *);
|
||||
int BLASFUNC(saxpy) (const int *, const float *, const float *, const int *, float *, const int *);
|
||||
int BLASFUNC(daxpy) (const int *, const double *, const double *, const int *, double *, const int *);
|
||||
int BLASFUNC(qaxpy) (const int *, const double *, const double *, const int *, double *, const int *);
|
||||
int BLASFUNC(caxpy) (const int *, const float *, const float *, const int *, float *, const int *);
|
||||
int BLASFUNC(zaxpy) (const int *, const double *, const double *, const int *, double *, const int *);
|
||||
int BLASFUNC(xaxpy) (const int *, const double *, const double *, const int *, double *, const int *);
|
||||
int BLASFUNC(caxpyc)(const int *, const float *, const float *, const int *, float *, const int *);
|
||||
int BLASFUNC(zaxpyc)(const int *, const double *, const double *, const int *, double *, const int *);
|
||||
int BLASFUNC(xaxpyc)(const int *, const double *, const double *, const int *, double *, const int *);
|
||||
|
||||
int BLASFUNC(scopy) (int *, float *, int *, float *, int *);
|
||||
int BLASFUNC(dcopy) (int *, double *, int *, double *, int *);
|
||||
@ -177,31 +177,19 @@ int BLASFUNC(xgeru)(int *, int *, double *, double *, int *,
|
||||
int BLASFUNC(xgerc)(int *, int *, double *, double *, int *,
|
||||
double *, int *, double *, int *);
|
||||
|
||||
int BLASFUNC(sgemv)(char *, int *, int *, float *, float *, int *,
|
||||
float *, int *, float *, float *, int *);
|
||||
int BLASFUNC(dgemv)(char *, int *, int *, double *, double *, int *,
|
||||
double *, int *, double *, double *, int *);
|
||||
int BLASFUNC(qgemv)(char *, int *, int *, double *, double *, int *,
|
||||
double *, int *, double *, double *, int *);
|
||||
int BLASFUNC(cgemv)(char *, int *, int *, float *, float *, int *,
|
||||
float *, int *, float *, float *, int *);
|
||||
int BLASFUNC(zgemv)(char *, int *, int *, double *, double *, int *,
|
||||
double *, int *, double *, double *, int *);
|
||||
int BLASFUNC(xgemv)(char *, int *, int *, double *, double *, int *,
|
||||
double *, int *, double *, double *, int *);
|
||||
int BLASFUNC(sgemv)(const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
|
||||
int BLASFUNC(dgemv)(const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
|
||||
int BLASFUNC(qgemv)(const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
|
||||
int BLASFUNC(cgemv)(const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
|
||||
int BLASFUNC(zgemv)(const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
|
||||
int BLASFUNC(xgemv)(const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
|
||||
|
||||
int BLASFUNC(strsv) (char *, char *, char *, int *, float *, int *,
|
||||
float *, int *);
|
||||
int BLASFUNC(dtrsv) (char *, char *, char *, int *, double *, int *,
|
||||
double *, int *);
|
||||
int BLASFUNC(qtrsv) (char *, char *, char *, int *, double *, int *,
|
||||
double *, int *);
|
||||
int BLASFUNC(ctrsv) (char *, char *, char *, int *, float *, int *,
|
||||
float *, int *);
|
||||
int BLASFUNC(ztrsv) (char *, char *, char *, int *, double *, int *,
|
||||
double *, int *);
|
||||
int BLASFUNC(xtrsv) (char *, char *, char *, int *, double *, int *,
|
||||
double *, int *);
|
||||
int BLASFUNC(strsv) (const char *, const char *, const char *, const int *, const float *, const int *, float *, const int *);
|
||||
int BLASFUNC(dtrsv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *);
|
||||
int BLASFUNC(qtrsv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *);
|
||||
int BLASFUNC(ctrsv) (const char *, const char *, const char *, const int *, const float *, const int *, float *, const int *);
|
||||
int BLASFUNC(ztrsv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *);
|
||||
int BLASFUNC(xtrsv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *);
|
||||
|
||||
int BLASFUNC(stpsv) (char *, char *, char *, int *, float *, float *, int *);
|
||||
int BLASFUNC(dtpsv) (char *, char *, char *, int *, double *, double *, int *);
|
||||
@ -210,18 +198,12 @@ int BLASFUNC(ctpsv) (char *, char *, char *, int *, float *, float *, int *);
|
||||
int BLASFUNC(ztpsv) (char *, char *, char *, int *, double *, double *, int *);
|
||||
int BLASFUNC(xtpsv) (char *, char *, char *, int *, double *, double *, int *);
|
||||
|
||||
int BLASFUNC(strmv) (char *, char *, char *, int *, float *, int *,
|
||||
float *, int *);
|
||||
int BLASFUNC(dtrmv) (char *, char *, char *, int *, double *, int *,
|
||||
double *, int *);
|
||||
int BLASFUNC(qtrmv) (char *, char *, char *, int *, double *, int *,
|
||||
double *, int *);
|
||||
int BLASFUNC(ctrmv) (char *, char *, char *, int *, float *, int *,
|
||||
float *, int *);
|
||||
int BLASFUNC(ztrmv) (char *, char *, char *, int *, double *, int *,
|
||||
double *, int *);
|
||||
int BLASFUNC(xtrmv) (char *, char *, char *, int *, double *, int *,
|
||||
double *, int *);
|
||||
int BLASFUNC(strmv) (const char *, const char *, const char *, const int *, const float *, const int *, float *, const int *);
|
||||
int BLASFUNC(dtrmv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *);
|
||||
int BLASFUNC(qtrmv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *);
|
||||
int BLASFUNC(ctrmv) (const char *, const char *, const char *, const int *, const float *, const int *, float *, const int *);
|
||||
int BLASFUNC(ztrmv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *);
|
||||
int BLASFUNC(xtrmv) (const char *, const char *, const char *, const int *, const double *, const int *, double *, const int *);
|
||||
|
||||
int BLASFUNC(stpmv) (char *, char *, char *, int *, float *, float *, int *);
|
||||
int BLASFUNC(dtpmv) (char *, char *, char *, int *, double *, double *, int *);
|
||||
@ -244,18 +226,9 @@ int BLASFUNC(ctbsv) (char *, char *, char *, int *, int *, float *, int *, floa
|
||||
int BLASFUNC(ztbsv) (char *, char *, char *, int *, int *, double *, int *, double *, int *);
|
||||
int BLASFUNC(xtbsv) (char *, char *, char *, int *, int *, double *, int *, double *, int *);
|
||||
|
||||
int BLASFUNC(ssymv) (char *, int *, float *, float *, int *,
|
||||
float *, int *, float *, float *, int *);
|
||||
int BLASFUNC(dsymv) (char *, int *, double *, double *, int *,
|
||||
double *, int *, double *, double *, int *);
|
||||
int BLASFUNC(qsymv) (char *, int *, double *, double *, int *,
|
||||
double *, int *, double *, double *, int *);
|
||||
int BLASFUNC(csymv) (char *, int *, float *, float *, int *,
|
||||
float *, int *, float *, float *, int *);
|
||||
int BLASFUNC(zsymv) (char *, int *, double *, double *, int *,
|
||||
double *, int *, double *, double *, int *);
|
||||
int BLASFUNC(xsymv) (char *, int *, double *, double *, int *,
|
||||
double *, int *, double *, double *, int *);
|
||||
int BLASFUNC(ssymv) (const char *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
|
||||
int BLASFUNC(dsymv) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
|
||||
int BLASFUNC(qsymv) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
|
||||
|
||||
int BLASFUNC(sspmv) (char *, int *, float *, float *,
|
||||
float *, int *, float *, float *, int *);
|
||||
@ -263,38 +236,17 @@ int BLASFUNC(dspmv) (char *, int *, double *, double *,
|
||||
double *, int *, double *, double *, int *);
|
||||
int BLASFUNC(qspmv) (char *, int *, double *, double *,
|
||||
double *, int *, double *, double *, int *);
|
||||
int BLASFUNC(cspmv) (char *, int *, float *, float *,
|
||||
float *, int *, float *, float *, int *);
|
||||
int BLASFUNC(zspmv) (char *, int *, double *, double *,
|
||||
double *, int *, double *, double *, int *);
|
||||
int BLASFUNC(xspmv) (char *, int *, double *, double *,
|
||||
double *, int *, double *, double *, int *);
|
||||
|
||||
int BLASFUNC(ssyr) (char *, int *, float *, float *, int *,
|
||||
float *, int *);
|
||||
int BLASFUNC(dsyr) (char *, int *, double *, double *, int *,
|
||||
double *, int *);
|
||||
int BLASFUNC(qsyr) (char *, int *, double *, double *, int *,
|
||||
double *, int *);
|
||||
int BLASFUNC(csyr) (char *, int *, float *, float *, int *,
|
||||
float *, int *);
|
||||
int BLASFUNC(zsyr) (char *, int *, double *, double *, int *,
|
||||
double *, int *);
|
||||
int BLASFUNC(xsyr) (char *, int *, double *, double *, int *,
|
||||
double *, int *);
|
||||
int BLASFUNC(ssyr) (const char *, const int *, const float *, const float *, const int *, float *, const int *);
|
||||
int BLASFUNC(dsyr) (const char *, const int *, const double *, const double *, const int *, double *, const int *);
|
||||
int BLASFUNC(qsyr) (const char *, const int *, const double *, const double *, const int *, double *, const int *);
|
||||
|
||||
int BLASFUNC(ssyr2) (char *, int *, float *,
|
||||
float *, int *, float *, int *, float *, int *);
|
||||
int BLASFUNC(dsyr2) (char *, int *, double *,
|
||||
double *, int *, double *, int *, double *, int *);
|
||||
int BLASFUNC(qsyr2) (char *, int *, double *,
|
||||
double *, int *, double *, int *, double *, int *);
|
||||
int BLASFUNC(csyr2) (char *, int *, float *,
|
||||
float *, int *, float *, int *, float *, int *);
|
||||
int BLASFUNC(zsyr2) (char *, int *, double *,
|
||||
double *, int *, double *, int *, double *, int *);
|
||||
int BLASFUNC(xsyr2) (char *, int *, double *,
|
||||
double *, int *, double *, int *, double *, int *);
|
||||
int BLASFUNC(ssyr2) (const char *, const int *, const float *, const float *, const int *, const float *, const int *, float *, const int *);
|
||||
int BLASFUNC(dsyr2) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, double *, const int *);
|
||||
int BLASFUNC(qsyr2) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, double *, const int *);
|
||||
int BLASFUNC(csyr2) (const char *, const int *, const float *, const float *, const int *, const float *, const int *, float *, const int *);
|
||||
int BLASFUNC(zsyr2) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, double *, const int *);
|
||||
int BLASFUNC(xsyr2) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, double *, const int *);
|
||||
|
||||
int BLASFUNC(sspr) (char *, int *, float *, float *, int *,
|
||||
float *);
|
||||
@ -302,12 +254,6 @@ int BLASFUNC(dspr) (char *, int *, double *, double *, int *,
|
||||
double *);
|
||||
int BLASFUNC(qspr) (char *, int *, double *, double *, int *,
|
||||
double *);
|
||||
int BLASFUNC(cspr) (char *, int *, float *, float *, int *,
|
||||
float *);
|
||||
int BLASFUNC(zspr) (char *, int *, double *, double *, int *,
|
||||
double *);
|
||||
int BLASFUNC(xspr) (char *, int *, double *, double *, int *,
|
||||
double *);
|
||||
|
||||
int BLASFUNC(sspr2) (char *, int *, float *,
|
||||
float *, int *, float *, int *, float *);
|
||||
@ -347,12 +293,9 @@ int BLASFUNC(zhpr2) (char *, int *, double *,
|
||||
int BLASFUNC(xhpr2) (char *, int *, double *,
|
||||
double *, int *, double *, int *, double *);
|
||||
|
||||
int BLASFUNC(chemv) (char *, int *, float *, float *, int *,
|
||||
float *, int *, float *, float *, int *);
|
||||
int BLASFUNC(zhemv) (char *, int *, double *, double *, int *,
|
||||
double *, int *, double *, double *, int *);
|
||||
int BLASFUNC(xhemv) (char *, int *, double *, double *, int *,
|
||||
double *, int *, double *, double *, int *);
|
||||
int BLASFUNC(chemv) (const char *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
|
||||
int BLASFUNC(zhemv) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
|
||||
int BLASFUNC(xhemv) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
|
||||
|
||||
int BLASFUNC(chpmv) (char *, int *, float *, float *,
|
||||
float *, int *, float *, float *, int *);
|
||||
@ -401,18 +344,12 @@ int BLASFUNC(xhbmv)(char *, int *, int *, double *, double *, int *,
|
||||
|
||||
/* Level 3 routines */
|
||||
|
||||
int BLASFUNC(sgemm)(char *, char *, int *, int *, int *, float *,
|
||||
float *, int *, float *, int *, float *, float *, int *);
|
||||
int BLASFUNC(dgemm)(char *, char *, int *, int *, int *, double *,
|
||||
double *, int *, double *, int *, double *, double *, int *);
|
||||
int BLASFUNC(qgemm)(char *, char *, int *, int *, int *, double *,
|
||||
double *, int *, double *, int *, double *, double *, int *);
|
||||
int BLASFUNC(cgemm)(char *, char *, int *, int *, int *, float *,
|
||||
float *, int *, float *, int *, float *, float *, int *);
|
||||
int BLASFUNC(zgemm)(char *, char *, int *, int *, int *, double *,
|
||||
double *, int *, double *, int *, double *, double *, int *);
|
||||
int BLASFUNC(xgemm)(char *, char *, int *, int *, int *, double *,
|
||||
double *, int *, double *, int *, double *, double *, int *);
|
||||
int BLASFUNC(sgemm)(const char *, const char *, const int *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
|
||||
int BLASFUNC(dgemm)(const char *, const char *, const int *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
|
||||
int BLASFUNC(qgemm)(const char *, const char *, const int *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
|
||||
int BLASFUNC(cgemm)(const char *, const char *, const int *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
|
||||
int BLASFUNC(zgemm)(const char *, const char *, const int *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
|
||||
int BLASFUNC(xgemm)(const char *, const char *, const int *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
|
||||
|
||||
int BLASFUNC(cgemm3m)(char *, char *, int *, int *, int *, float *,
|
||||
float *, int *, float *, int *, float *, float *, int *);
|
||||
@ -434,84 +371,48 @@ int BLASFUNC(zge2mm)(char *, char *, char *, int *, int *,
|
||||
double *, double *, int *, double *, int *,
|
||||
double *, double *, int *);
|
||||
|
||||
int BLASFUNC(strsm)(char *, char *, char *, char *, int *, int *,
|
||||
float *, float *, int *, float *, int *);
|
||||
int BLASFUNC(dtrsm)(char *, char *, char *, char *, int *, int *,
|
||||
double *, double *, int *, double *, int *);
|
||||
int BLASFUNC(qtrsm)(char *, char *, char *, char *, int *, int *,
|
||||
double *, double *, int *, double *, int *);
|
||||
int BLASFUNC(ctrsm)(char *, char *, char *, char *, int *, int *,
|
||||
float *, float *, int *, float *, int *);
|
||||
int BLASFUNC(ztrsm)(char *, char *, char *, char *, int *, int *,
|
||||
double *, double *, int *, double *, int *);
|
||||
int BLASFUNC(xtrsm)(char *, char *, char *, char *, int *, int *,
|
||||
double *, double *, int *, double *, int *);
|
||||
int BLASFUNC(strsm)(const char *, const char *, const char *, const char *, const int *, const int *, const float *, const float *, const int *, float *, const int *);
|
||||
int BLASFUNC(dtrsm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *);
|
||||
int BLASFUNC(qtrsm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *);
|
||||
int BLASFUNC(ctrsm)(const char *, const char *, const char *, const char *, const int *, const int *, const float *, const float *, const int *, float *, const int *);
|
||||
int BLASFUNC(ztrsm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *);
|
||||
int BLASFUNC(xtrsm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *);
|
||||
|
||||
int BLASFUNC(strmm)(char *, char *, char *, char *, int *, int *,
|
||||
float *, float *, int *, float *, int *);
|
||||
int BLASFUNC(dtrmm)(char *, char *, char *, char *, int *, int *,
|
||||
double *, double *, int *, double *, int *);
|
||||
int BLASFUNC(qtrmm)(char *, char *, char *, char *, int *, int *,
|
||||
double *, double *, int *, double *, int *);
|
||||
int BLASFUNC(ctrmm)(char *, char *, char *, char *, int *, int *,
|
||||
float *, float *, int *, float *, int *);
|
||||
int BLASFUNC(ztrmm)(char *, char *, char *, char *, int *, int *,
|
||||
double *, double *, int *, double *, int *);
|
||||
int BLASFUNC(xtrmm)(char *, char *, char *, char *, int *, int *,
|
||||
double *, double *, int *, double *, int *);
|
||||
int BLASFUNC(strmm)(const char *, const char *, const char *, const char *, const int *, const int *, const float *, const float *, const int *, float *, const int *);
|
||||
int BLASFUNC(dtrmm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *);
|
||||
int BLASFUNC(qtrmm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *);
|
||||
int BLASFUNC(ctrmm)(const char *, const char *, const char *, const char *, const int *, const int *, const float *, const float *, const int *, float *, const int *);
|
||||
int BLASFUNC(ztrmm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *);
|
||||
int BLASFUNC(xtrmm)(const char *, const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, double *, const int *);
|
||||
|
||||
int BLASFUNC(ssymm)(char *, char *, int *, int *, float *, float *, int *,
|
||||
float *, int *, float *, float *, int *);
|
||||
int BLASFUNC(dsymm)(char *, char *, int *, int *, double *, double *, int *,
|
||||
double *, int *, double *, double *, int *);
|
||||
int BLASFUNC(qsymm)(char *, char *, int *, int *, double *, double *, int *,
|
||||
double *, int *, double *, double *, int *);
|
||||
int BLASFUNC(csymm)(char *, char *, int *, int *, float *, float *, int *,
|
||||
float *, int *, float *, float *, int *);
|
||||
int BLASFUNC(zsymm)(char *, char *, int *, int *, double *, double *, int *,
|
||||
double *, int *, double *, double *, int *);
|
||||
int BLASFUNC(xsymm)(char *, char *, int *, int *, double *, double *, int *,
|
||||
double *, int *, double *, double *, int *);
|
||||
int BLASFUNC(ssymm)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
|
||||
int BLASFUNC(dsymm)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
|
||||
int BLASFUNC(qsymm)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
|
||||
int BLASFUNC(csymm)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
|
||||
int BLASFUNC(zsymm)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
|
||||
int BLASFUNC(xsymm)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
|
||||
|
||||
int BLASFUNC(csymm3m)(char *, char *, int *, int *, float *, float *, int *,
|
||||
float *, int *, float *, float *, int *);
|
||||
int BLASFUNC(zsymm3m)(char *, char *, int *, int *, double *, double *, int *,
|
||||
double *, int *, double *, double *, int *);
|
||||
int BLASFUNC(xsymm3m)(char *, char *, int *, int *, double *, double *, int *,
|
||||
double *, int *, double *, double *, int *);
|
||||
int BLASFUNC(csymm3m)(char *, char *, int *, int *, float *, float *, int *, float *, int *, float *, float *, int *);
|
||||
int BLASFUNC(zsymm3m)(char *, char *, int *, int *, double *, double *, int *, double *, int *, double *, double *, int *);
|
||||
int BLASFUNC(xsymm3m)(char *, char *, int *, int *, double *, double *, int *, double *, int *, double *, double *, int *);
|
||||
|
||||
int BLASFUNC(ssyrk)(char *, char *, int *, int *, float *, float *, int *,
|
||||
float *, float *, int *);
|
||||
int BLASFUNC(dsyrk)(char *, char *, int *, int *, double *, double *, int *,
|
||||
double *, double *, int *);
|
||||
int BLASFUNC(qsyrk)(char *, char *, int *, int *, double *, double *, int *,
|
||||
double *, double *, int *);
|
||||
int BLASFUNC(csyrk)(char *, char *, int *, int *, float *, float *, int *,
|
||||
float *, float *, int *);
|
||||
int BLASFUNC(zsyrk)(char *, char *, int *, int *, double *, double *, int *,
|
||||
double *, double *, int *);
|
||||
int BLASFUNC(xsyrk)(char *, char *, int *, int *, double *, double *, int *,
|
||||
double *, double *, int *);
|
||||
int BLASFUNC(ssyrk)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, float *, const int *);
|
||||
int BLASFUNC(dsyrk)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, double *, const int *);
|
||||
int BLASFUNC(qsyrk)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, double *, const int *);
|
||||
int BLASFUNC(csyrk)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, float *, const int *);
|
||||
int BLASFUNC(zsyrk)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, double *, const int *);
|
||||
int BLASFUNC(xsyrk)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, double *, const int *);
|
||||
|
||||
int BLASFUNC(ssyr2k)(char *, char *, int *, int *, float *, float *, int *,
|
||||
float *, int *, float *, float *, int *);
|
||||
int BLASFUNC(dsyr2k)(char *, char *, int *, int *, double *, double *, int *,
|
||||
double*, int *, double *, double *, int *);
|
||||
int BLASFUNC(qsyr2k)(char *, char *, int *, int *, double *, double *, int *,
|
||||
double*, int *, double *, double *, int *);
|
||||
int BLASFUNC(csyr2k)(char *, char *, int *, int *, float *, float *, int *,
|
||||
float *, int *, float *, float *, int *);
|
||||
int BLASFUNC(zsyr2k)(char *, char *, int *, int *, double *, double *, int *,
|
||||
double*, int *, double *, double *, int *);
|
||||
int BLASFUNC(xsyr2k)(char *, char *, int *, int *, double *, double *, int *,
|
||||
double*, int *, double *, double *, int *);
|
||||
int BLASFUNC(ssyr2k)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
|
||||
int BLASFUNC(dsyr2k)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double*, const int *, const double *, double *, const int *);
|
||||
int BLASFUNC(qsyr2k)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double*, const int *, const double *, double *, const int *);
|
||||
int BLASFUNC(csyr2k)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
|
||||
int BLASFUNC(zsyr2k)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double*, const int *, const double *, double *, const int *);
|
||||
int BLASFUNC(xsyr2k)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double*, const int *, const double *, double *, const int *);
|
||||
|
||||
int BLASFUNC(chemm)(char *, char *, int *, int *, float *, float *, int *,
|
||||
float *, int *, float *, float *, int *);
|
||||
int BLASFUNC(zhemm)(char *, char *, int *, int *, double *, double *, int *,
|
||||
double *, int *, double *, double *, int *);
|
||||
int BLASFUNC(xhemm)(char *, char *, int *, int *, double *, double *, int *,
|
||||
double *, int *, double *, double *, int *);
|
||||
int BLASFUNC(chemm)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
|
||||
int BLASFUNC(zhemm)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
|
||||
int BLASFUNC(xhemm)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
|
||||
|
||||
int BLASFUNC(chemm3m)(char *, char *, int *, int *, float *, float *, int *,
|
||||
float *, int *, float *, float *, int *);
|
||||
@ -520,136 +421,17 @@ int BLASFUNC(zhemm3m)(char *, char *, int *, int *, double *, double *, int *,
|
||||
int BLASFUNC(xhemm3m)(char *, char *, int *, int *, double *, double *, int *,
|
||||
double *, int *, double *, double *, int *);
|
||||
|
||||
int BLASFUNC(cherk)(char *, char *, int *, int *, float *, float *, int *,
|
||||
float *, float *, int *);
|
||||
int BLASFUNC(zherk)(char *, char *, int *, int *, double *, double *, int *,
|
||||
double *, double *, int *);
|
||||
int BLASFUNC(xherk)(char *, char *, int *, int *, double *, double *, int *,
|
||||
double *, double *, int *);
|
||||
int BLASFUNC(cherk)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, float *, const int *);
|
||||
int BLASFUNC(zherk)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, double *, const int *);
|
||||
int BLASFUNC(xherk)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, double *, const int *);
|
||||
|
||||
int BLASFUNC(cher2k)(char *, char *, int *, int *, float *, float *, int *,
|
||||
float *, int *, float *, float *, int *);
|
||||
int BLASFUNC(zher2k)(char *, char *, int *, int *, double *, double *, int *,
|
||||
double*, int *, double *, double *, int *);
|
||||
int BLASFUNC(xher2k)(char *, char *, int *, int *, double *, double *, int *,
|
||||
double*, int *, double *, double *, int *);
|
||||
int BLASFUNC(cher2m)(char *, char *, char *, int *, int *, float *, float *, int *,
|
||||
float *, int *, float *, float *, int *);
|
||||
int BLASFUNC(zher2m)(char *, char *, char *, int *, int *, double *, double *, int *,
|
||||
double*, int *, double *, double *, int *);
|
||||
int BLASFUNC(xher2m)(char *, char *, char *, int *, int *, double *, double *, int *,
|
||||
double*, int *, double *, double *, int *);
|
||||
int BLASFUNC(cher2k)(const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
|
||||
int BLASFUNC(zher2k)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
|
||||
int BLASFUNC(xher2k)(const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
|
||||
int BLASFUNC(cher2m)(const char *, const char *, const char *, const int *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
|
||||
int BLASFUNC(zher2m)(const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double*, const int *, const double *, double *, const int *);
|
||||
int BLASFUNC(xher2m)(const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double*, const int *, const double *, double *, const int *);
|
||||
|
||||
int BLASFUNC(sgemt)(char *, int *, int *, float *, float *, int *,
|
||||
float *, int *);
|
||||
int BLASFUNC(dgemt)(char *, int *, int *, double *, double *, int *,
|
||||
double *, int *);
|
||||
int BLASFUNC(cgemt)(char *, int *, int *, float *, float *, int *,
|
||||
float *, int *);
|
||||
int BLASFUNC(zgemt)(char *, int *, int *, double *, double *, int *,
|
||||
double *, int *);
|
||||
|
||||
int BLASFUNC(sgema)(char *, char *, int *, int *, float *,
|
||||
float *, int *, float *, float *, int *, float *, int *);
|
||||
int BLASFUNC(dgema)(char *, char *, int *, int *, double *,
|
||||
double *, int *, double*, double *, int *, double*, int *);
|
||||
int BLASFUNC(cgema)(char *, char *, int *, int *, float *,
|
||||
float *, int *, float *, float *, int *, float *, int *);
|
||||
int BLASFUNC(zgema)(char *, char *, int *, int *, double *,
|
||||
double *, int *, double*, double *, int *, double*, int *);
|
||||
|
||||
int BLASFUNC(sgems)(char *, char *, int *, int *, float *,
|
||||
float *, int *, float *, float *, int *, float *, int *);
|
||||
int BLASFUNC(dgems)(char *, char *, int *, int *, double *,
|
||||
double *, int *, double*, double *, int *, double*, int *);
|
||||
int BLASFUNC(cgems)(char *, char *, int *, int *, float *,
|
||||
float *, int *, float *, float *, int *, float *, int *);
|
||||
int BLASFUNC(zgems)(char *, char *, int *, int *, double *,
|
||||
double *, int *, double*, double *, int *, double*, int *);
|
||||
|
||||
int BLASFUNC(sgetf2)(int *, int *, float *, int *, int *, int *);
|
||||
int BLASFUNC(dgetf2)(int *, int *, double *, int *, int *, int *);
|
||||
int BLASFUNC(qgetf2)(int *, int *, double *, int *, int *, int *);
|
||||
int BLASFUNC(cgetf2)(int *, int *, float *, int *, int *, int *);
|
||||
int BLASFUNC(zgetf2)(int *, int *, double *, int *, int *, int *);
|
||||
int BLASFUNC(xgetf2)(int *, int *, double *, int *, int *, int *);
|
||||
|
||||
int BLASFUNC(sgetrf)(int *, int *, float *, int *, int *, int *);
|
||||
int BLASFUNC(dgetrf)(int *, int *, double *, int *, int *, int *);
|
||||
int BLASFUNC(qgetrf)(int *, int *, double *, int *, int *, int *);
|
||||
int BLASFUNC(cgetrf)(int *, int *, float *, int *, int *, int *);
|
||||
int BLASFUNC(zgetrf)(int *, int *, double *, int *, int *, int *);
|
||||
int BLASFUNC(xgetrf)(int *, int *, double *, int *, int *, int *);
|
||||
|
||||
int BLASFUNC(slaswp)(int *, float *, int *, int *, int *, int *, int *);
|
||||
int BLASFUNC(dlaswp)(int *, double *, int *, int *, int *, int *, int *);
|
||||
int BLASFUNC(qlaswp)(int *, double *, int *, int *, int *, int *, int *);
|
||||
int BLASFUNC(claswp)(int *, float *, int *, int *, int *, int *, int *);
|
||||
int BLASFUNC(zlaswp)(int *, double *, int *, int *, int *, int *, int *);
|
||||
int BLASFUNC(xlaswp)(int *, double *, int *, int *, int *, int *, int *);
|
||||
|
||||
int BLASFUNC(sgetrs)(char *, int *, int *, float *, int *, int *, float *, int *, int *);
|
||||
int BLASFUNC(dgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *);
|
||||
int BLASFUNC(qgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *);
|
||||
int BLASFUNC(cgetrs)(char *, int *, int *, float *, int *, int *, float *, int *, int *);
|
||||
int BLASFUNC(zgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *);
|
||||
int BLASFUNC(xgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *);
|
||||
|
||||
int BLASFUNC(sgesv)(int *, int *, float *, int *, int *, float *, int *, int *);
|
||||
int BLASFUNC(dgesv)(int *, int *, double *, int *, int *, double*, int *, int *);
|
||||
int BLASFUNC(qgesv)(int *, int *, double *, int *, int *, double*, int *, int *);
|
||||
int BLASFUNC(cgesv)(int *, int *, float *, int *, int *, float *, int *, int *);
|
||||
int BLASFUNC(zgesv)(int *, int *, double *, int *, int *, double*, int *, int *);
|
||||
int BLASFUNC(xgesv)(int *, int *, double *, int *, int *, double*, int *, int *);
|
||||
|
||||
int BLASFUNC(spotf2)(char *, int *, float *, int *, int *);
|
||||
int BLASFUNC(dpotf2)(char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(qpotf2)(char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(cpotf2)(char *, int *, float *, int *, int *);
|
||||
int BLASFUNC(zpotf2)(char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(xpotf2)(char *, int *, double *, int *, int *);
|
||||
|
||||
int BLASFUNC(spotrf)(char *, int *, float *, int *, int *);
|
||||
int BLASFUNC(dpotrf)(char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(qpotrf)(char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(cpotrf)(char *, int *, float *, int *, int *);
|
||||
int BLASFUNC(zpotrf)(char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(xpotrf)(char *, int *, double *, int *, int *);
|
||||
|
||||
int BLASFUNC(slauu2)(char *, int *, float *, int *, int *);
|
||||
int BLASFUNC(dlauu2)(char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(qlauu2)(char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(clauu2)(char *, int *, float *, int *, int *);
|
||||
int BLASFUNC(zlauu2)(char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(xlauu2)(char *, int *, double *, int *, int *);
|
||||
|
||||
int BLASFUNC(slauum)(char *, int *, float *, int *, int *);
|
||||
int BLASFUNC(dlauum)(char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(qlauum)(char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(clauum)(char *, int *, float *, int *, int *);
|
||||
int BLASFUNC(zlauum)(char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(xlauum)(char *, int *, double *, int *, int *);
|
||||
|
||||
int BLASFUNC(strti2)(char *, char *, int *, float *, int *, int *);
|
||||
int BLASFUNC(dtrti2)(char *, char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(qtrti2)(char *, char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(ctrti2)(char *, char *, int *, float *, int *, int *);
|
||||
int BLASFUNC(ztrti2)(char *, char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(xtrti2)(char *, char *, int *, double *, int *, int *);
|
||||
|
||||
int BLASFUNC(strtri)(char *, char *, int *, float *, int *, int *);
|
||||
int BLASFUNC(dtrtri)(char *, char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(qtrtri)(char *, char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(ctrtri)(char *, char *, int *, float *, int *, int *);
|
||||
int BLASFUNC(ztrtri)(char *, char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(xtrtri)(char *, char *, int *, double *, int *, int *);
|
||||
|
||||
int BLASFUNC(spotri)(char *, int *, float *, int *, int *);
|
||||
int BLASFUNC(dpotri)(char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(qpotri)(char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(cpotri)(char *, int *, float *, int *, int *);
|
||||
int BLASFUNC(zpotri)(char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(xpotri)(char *, int *, double *, int *, int *);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
152
Eigen/src/misc/lapack.h
Normal file
152
Eigen/src/misc/lapack.h
Normal file
@ -0,0 +1,152 @@
|
||||
#ifndef LAPACK_H
|
||||
#define LAPACK_H
|
||||
|
||||
#include "blas.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C"
|
||||
{
|
||||
#endif
|
||||
|
||||
int BLASFUNC(csymv) (const char *, const int *, const float *, const float *, const int *, const float *, const int *, const float *, float *, const int *);
|
||||
int BLASFUNC(zsymv) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
|
||||
int BLASFUNC(xsymv) (const char *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *);
|
||||
|
||||
|
||||
int BLASFUNC(cspmv) (char *, int *, float *, float *,
|
||||
float *, int *, float *, float *, int *);
|
||||
int BLASFUNC(zspmv) (char *, int *, double *, double *,
|
||||
double *, int *, double *, double *, int *);
|
||||
int BLASFUNC(xspmv) (char *, int *, double *, double *,
|
||||
double *, int *, double *, double *, int *);
|
||||
|
||||
int BLASFUNC(csyr) (char *, int *, float *, float *, int *,
|
||||
float *, int *);
|
||||
int BLASFUNC(zsyr) (char *, int *, double *, double *, int *,
|
||||
double *, int *);
|
||||
int BLASFUNC(xsyr) (char *, int *, double *, double *, int *,
|
||||
double *, int *);
|
||||
|
||||
int BLASFUNC(cspr) (char *, int *, float *, float *, int *,
|
||||
float *);
|
||||
int BLASFUNC(zspr) (char *, int *, double *, double *, int *,
|
||||
double *);
|
||||
int BLASFUNC(xspr) (char *, int *, double *, double *, int *,
|
||||
double *);
|
||||
|
||||
int BLASFUNC(sgemt)(char *, int *, int *, float *, float *, int *,
|
||||
float *, int *);
|
||||
int BLASFUNC(dgemt)(char *, int *, int *, double *, double *, int *,
|
||||
double *, int *);
|
||||
int BLASFUNC(cgemt)(char *, int *, int *, float *, float *, int *,
|
||||
float *, int *);
|
||||
int BLASFUNC(zgemt)(char *, int *, int *, double *, double *, int *,
|
||||
double *, int *);
|
||||
|
||||
int BLASFUNC(sgema)(char *, char *, int *, int *, float *,
|
||||
float *, int *, float *, float *, int *, float *, int *);
|
||||
int BLASFUNC(dgema)(char *, char *, int *, int *, double *,
|
||||
double *, int *, double*, double *, int *, double*, int *);
|
||||
int BLASFUNC(cgema)(char *, char *, int *, int *, float *,
|
||||
float *, int *, float *, float *, int *, float *, int *);
|
||||
int BLASFUNC(zgema)(char *, char *, int *, int *, double *,
|
||||
double *, int *, double*, double *, int *, double*, int *);
|
||||
|
||||
int BLASFUNC(sgems)(char *, char *, int *, int *, float *,
|
||||
float *, int *, float *, float *, int *, float *, int *);
|
||||
int BLASFUNC(dgems)(char *, char *, int *, int *, double *,
|
||||
double *, int *, double*, double *, int *, double*, int *);
|
||||
int BLASFUNC(cgems)(char *, char *, int *, int *, float *,
|
||||
float *, int *, float *, float *, int *, float *, int *);
|
||||
int BLASFUNC(zgems)(char *, char *, int *, int *, double *,
|
||||
double *, int *, double*, double *, int *, double*, int *);
|
||||
|
||||
int BLASFUNC(sgetf2)(int *, int *, float *, int *, int *, int *);
|
||||
int BLASFUNC(dgetf2)(int *, int *, double *, int *, int *, int *);
|
||||
int BLASFUNC(qgetf2)(int *, int *, double *, int *, int *, int *);
|
||||
int BLASFUNC(cgetf2)(int *, int *, float *, int *, int *, int *);
|
||||
int BLASFUNC(zgetf2)(int *, int *, double *, int *, int *, int *);
|
||||
int BLASFUNC(xgetf2)(int *, int *, double *, int *, int *, int *);
|
||||
|
||||
int BLASFUNC(sgetrf)(int *, int *, float *, int *, int *, int *);
|
||||
int BLASFUNC(dgetrf)(int *, int *, double *, int *, int *, int *);
|
||||
int BLASFUNC(qgetrf)(int *, int *, double *, int *, int *, int *);
|
||||
int BLASFUNC(cgetrf)(int *, int *, float *, int *, int *, int *);
|
||||
int BLASFUNC(zgetrf)(int *, int *, double *, int *, int *, int *);
|
||||
int BLASFUNC(xgetrf)(int *, int *, double *, int *, int *, int *);
|
||||
|
||||
int BLASFUNC(slaswp)(int *, float *, int *, int *, int *, int *, int *);
|
||||
int BLASFUNC(dlaswp)(int *, double *, int *, int *, int *, int *, int *);
|
||||
int BLASFUNC(qlaswp)(int *, double *, int *, int *, int *, int *, int *);
|
||||
int BLASFUNC(claswp)(int *, float *, int *, int *, int *, int *, int *);
|
||||
int BLASFUNC(zlaswp)(int *, double *, int *, int *, int *, int *, int *);
|
||||
int BLASFUNC(xlaswp)(int *, double *, int *, int *, int *, int *, int *);
|
||||
|
||||
int BLASFUNC(sgetrs)(char *, int *, int *, float *, int *, int *, float *, int *, int *);
|
||||
int BLASFUNC(dgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *);
|
||||
int BLASFUNC(qgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *);
|
||||
int BLASFUNC(cgetrs)(char *, int *, int *, float *, int *, int *, float *, int *, int *);
|
||||
int BLASFUNC(zgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *);
|
||||
int BLASFUNC(xgetrs)(char *, int *, int *, double *, int *, int *, double *, int *, int *);
|
||||
|
||||
int BLASFUNC(sgesv)(int *, int *, float *, int *, int *, float *, int *, int *);
|
||||
int BLASFUNC(dgesv)(int *, int *, double *, int *, int *, double*, int *, int *);
|
||||
int BLASFUNC(qgesv)(int *, int *, double *, int *, int *, double*, int *, int *);
|
||||
int BLASFUNC(cgesv)(int *, int *, float *, int *, int *, float *, int *, int *);
|
||||
int BLASFUNC(zgesv)(int *, int *, double *, int *, int *, double*, int *, int *);
|
||||
int BLASFUNC(xgesv)(int *, int *, double *, int *, int *, double*, int *, int *);
|
||||
|
||||
int BLASFUNC(spotf2)(char *, int *, float *, int *, int *);
|
||||
int BLASFUNC(dpotf2)(char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(qpotf2)(char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(cpotf2)(char *, int *, float *, int *, int *);
|
||||
int BLASFUNC(zpotf2)(char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(xpotf2)(char *, int *, double *, int *, int *);
|
||||
|
||||
int BLASFUNC(spotrf)(char *, int *, float *, int *, int *);
|
||||
int BLASFUNC(dpotrf)(char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(qpotrf)(char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(cpotrf)(char *, int *, float *, int *, int *);
|
||||
int BLASFUNC(zpotrf)(char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(xpotrf)(char *, int *, double *, int *, int *);
|
||||
|
||||
int BLASFUNC(slauu2)(char *, int *, float *, int *, int *);
|
||||
int BLASFUNC(dlauu2)(char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(qlauu2)(char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(clauu2)(char *, int *, float *, int *, int *);
|
||||
int BLASFUNC(zlauu2)(char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(xlauu2)(char *, int *, double *, int *, int *);
|
||||
|
||||
int BLASFUNC(slauum)(char *, int *, float *, int *, int *);
|
||||
int BLASFUNC(dlauum)(char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(qlauum)(char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(clauum)(char *, int *, float *, int *, int *);
|
||||
int BLASFUNC(zlauum)(char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(xlauum)(char *, int *, double *, int *, int *);
|
||||
|
||||
int BLASFUNC(strti2)(char *, char *, int *, float *, int *, int *);
|
||||
int BLASFUNC(dtrti2)(char *, char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(qtrti2)(char *, char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(ctrti2)(char *, char *, int *, float *, int *, int *);
|
||||
int BLASFUNC(ztrti2)(char *, char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(xtrti2)(char *, char *, int *, double *, int *, int *);
|
||||
|
||||
int BLASFUNC(strtri)(char *, char *, int *, float *, int *, int *);
|
||||
int BLASFUNC(dtrtri)(char *, char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(qtrtri)(char *, char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(ctrtri)(char *, char *, int *, float *, int *, int *);
|
||||
int BLASFUNC(ztrtri)(char *, char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(xtrtri)(char *, char *, int *, double *, int *, int *);
|
||||
|
||||
int BLASFUNC(spotri)(char *, int *, float *, int *, int *);
|
||||
int BLASFUNC(dpotri)(char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(qpotri)(char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(cpotri)(char *, int *, float *, int *, int *);
|
||||
int BLASFUNC(zpotri)(char *, int *, double *, int *, int *);
|
||||
int BLASFUNC(xpotri)(char *, int *, double *, int *, int *);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
@ -10,8 +10,8 @@
|
||||
#ifndef EIGEN_BLAS_COMMON_H
|
||||
#define EIGEN_BLAS_COMMON_H
|
||||
|
||||
#include <Eigen/Core>
|
||||
#include <Eigen/Jacobi>
|
||||
#include "../Eigen/Core"
|
||||
#include "../Eigen/Jacobi"
|
||||
|
||||
#include <complex>
|
||||
|
||||
@ -19,8 +19,7 @@
|
||||
#error the token SCALAR must be defined to compile this file
|
||||
#endif
|
||||
|
||||
#include <Eigen/src/misc/blas.h>
|
||||
|
||||
#include "../Eigen/src/misc/blas.h"
|
||||
|
||||
#define NOTR 0
|
||||
#define TR 1
|
||||
@ -94,6 +93,7 @@ enum
|
||||
|
||||
typedef Matrix<Scalar,Dynamic,Dynamic,ColMajor> PlainMatrixType;
|
||||
typedef Map<Matrix<Scalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > MatrixType;
|
||||
typedef Map<const Matrix<Scalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > ConstMatrixType;
|
||||
typedef Map<Matrix<Scalar,Dynamic,1>, 0, InnerStride<Dynamic> > StridedVectorType;
|
||||
typedef Map<Matrix<Scalar,Dynamic,1> > CompactVectorType;
|
||||
|
||||
@ -104,25 +104,44 @@ matrix(T* data, int rows, int cols, int stride)
|
||||
return Map<Matrix<T,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> >(data, rows, cols, OuterStride<>(stride));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
Map<const Matrix<T,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> >
|
||||
matrix(const T* data, int rows, int cols, int stride)
|
||||
{
|
||||
return Map<const Matrix<T,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> >(data, rows, cols, OuterStride<>(stride));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
Map<Matrix<T,Dynamic,1>, 0, InnerStride<Dynamic> > make_vector(T* data, int size, int incr)
|
||||
{
|
||||
return Map<Matrix<T,Dynamic,1>, 0, InnerStride<Dynamic> >(data, size, InnerStride<Dynamic>(incr));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
Map<const Matrix<T,Dynamic,1>, 0, InnerStride<Dynamic> > make_vector(const T* data, int size, int incr)
|
||||
{
|
||||
return Map<const Matrix<T,Dynamic,1>, 0, InnerStride<Dynamic> >(data, size, InnerStride<Dynamic>(incr));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
Map<Matrix<T,Dynamic,1> > make_vector(T* data, int size)
|
||||
{
|
||||
return Map<Matrix<T,Dynamic,1> >(data, size);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
Map<const Matrix<T,Dynamic,1> > make_vector(const T* data, int size)
|
||||
{
|
||||
return Map<const Matrix<T,Dynamic,1> >(data, size);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
T* get_compact_vector(T* x, int n, int incx)
|
||||
{
|
||||
if(incx==1)
|
||||
return x;
|
||||
|
||||
T* ret = new Scalar[n];
|
||||
typename Eigen::internal::remove_const<T>::type* ret = new Scalar[n];
|
||||
if(incx<0) make_vector(ret,n) = make_vector(x,n,-incx).reverse();
|
||||
else make_vector(ret,n) = make_vector(x,n, incx);
|
||||
return ret;
|
||||
|
@ -9,11 +9,11 @@
|
||||
|
||||
#include "common.h"
|
||||
|
||||
int EIGEN_BLAS_FUNC(axpy)(int *n, RealScalar *palpha, RealScalar *px, int *incx, RealScalar *py, int *incy)
|
||||
int EIGEN_BLAS_FUNC(axpy)(const int *n, const RealScalar *palpha, const RealScalar *px, const int *incx, RealScalar *py, const int *incy)
|
||||
{
|
||||
Scalar* x = reinterpret_cast<Scalar*>(px);
|
||||
const Scalar* x = reinterpret_cast<const Scalar*>(px);
|
||||
Scalar* y = reinterpret_cast<Scalar*>(py);
|
||||
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
|
||||
Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
|
||||
|
||||
if(*n<=0) return 0;
|
||||
|
||||
|
@ -16,7 +16,8 @@
|
||||
* where alpha and beta are scalars, x and y are n element vectors and
|
||||
* A is an n by n hermitian matrix.
|
||||
*/
|
||||
int EIGEN_BLAS_FUNC(hemv)(char *uplo, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *px, int *incx, RealScalar *pbeta, RealScalar *py, int *incy)
|
||||
int EIGEN_BLAS_FUNC(hemv)(const char *uplo, const int *n, const RealScalar *palpha, const RealScalar *pa, const int *lda,
|
||||
const RealScalar *px, const int *incx, const RealScalar *pbeta, RealScalar *py, const int *incy)
|
||||
{
|
||||
typedef void (*functype)(int, const Scalar*, int, const Scalar*, Scalar*, Scalar);
|
||||
static const functype func[2] = {
|
||||
@ -26,11 +27,11 @@ int EIGEN_BLAS_FUNC(hemv)(char *uplo, int *n, RealScalar *palpha, RealScalar *pa
|
||||
(internal::selfadjoint_matrix_vector_product<Scalar,int,ColMajor,Lower,false,false>::run),
|
||||
};
|
||||
|
||||
Scalar* a = reinterpret_cast<Scalar*>(pa);
|
||||
Scalar* x = reinterpret_cast<Scalar*>(px);
|
||||
const Scalar* a = reinterpret_cast<const Scalar*>(pa);
|
||||
const Scalar* x = reinterpret_cast<const Scalar*>(px);
|
||||
Scalar* y = reinterpret_cast<Scalar*>(py);
|
||||
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
|
||||
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
|
||||
Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
|
||||
Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
|
||||
|
||||
// check arguments
|
||||
int info = 0;
|
||||
@ -45,7 +46,7 @@ int EIGEN_BLAS_FUNC(hemv)(char *uplo, int *n, RealScalar *palpha, RealScalar *pa
|
||||
if(*n==0)
|
||||
return 1;
|
||||
|
||||
Scalar* actual_x = get_compact_vector(x,*n,*incx);
|
||||
const Scalar* actual_x = get_compact_vector(x,*n,*incx);
|
||||
Scalar* actual_y = get_compact_vector(y,*n,*incy);
|
||||
|
||||
if(beta!=Scalar(1))
|
||||
|
@ -23,7 +23,8 @@ struct general_matrix_vector_product_wrapper
|
||||
}
|
||||
};
|
||||
|
||||
int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *incb, RealScalar *pbeta, RealScalar *pc, int *incc)
|
||||
int EIGEN_BLAS_FUNC(gemv)(const char *opa, const int *m, const int *n, const RealScalar *palpha,
|
||||
const RealScalar *pa, const int *lda, const RealScalar *pb, const int *incb, const RealScalar *pbeta, RealScalar *pc, const int *incc)
|
||||
{
|
||||
typedef void (*functype)(int, int, const Scalar *, int, const Scalar *, int , Scalar *, int, Scalar);
|
||||
static const functype func[4] = {
|
||||
@ -36,11 +37,11 @@ int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealSca
|
||||
0
|
||||
};
|
||||
|
||||
Scalar* a = reinterpret_cast<Scalar*>(pa);
|
||||
Scalar* b = reinterpret_cast<Scalar*>(pb);
|
||||
const Scalar* a = reinterpret_cast<const Scalar*>(pa);
|
||||
const Scalar* b = reinterpret_cast<const Scalar*>(pb);
|
||||
Scalar* c = reinterpret_cast<Scalar*>(pc);
|
||||
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
|
||||
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
|
||||
Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
|
||||
Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
|
||||
|
||||
// check arguments
|
||||
int info = 0;
|
||||
@ -62,7 +63,7 @@ int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealSca
|
||||
if(code!=NOTR)
|
||||
std::swap(actual_m,actual_n);
|
||||
|
||||
Scalar* actual_b = get_compact_vector(b,actual_n,*incb);
|
||||
const Scalar* actual_b = get_compact_vector(b,actual_n,*incb);
|
||||
Scalar* actual_c = get_compact_vector(c,actual_m,*incc);
|
||||
|
||||
if(beta!=Scalar(1))
|
||||
@ -82,7 +83,7 @@ int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealSca
|
||||
return 1;
|
||||
}
|
||||
|
||||
int EIGEN_BLAS_FUNC(trsv)(char *uplo, char *opa, char *diag, int *n, RealScalar *pa, int *lda, RealScalar *pb, int *incb)
|
||||
int EIGEN_BLAS_FUNC(trsv)(const char *uplo, const char *opa, const char *diag, const int *n, const RealScalar *pa, const int *lda, RealScalar *pb, const int *incb)
|
||||
{
|
||||
typedef void (*functype)(int, const Scalar *, int, Scalar *);
|
||||
static const functype func[16] = {
|
||||
@ -116,7 +117,7 @@ int EIGEN_BLAS_FUNC(trsv)(char *uplo, char *opa, char *diag, int *n, RealScalar
|
||||
0
|
||||
};
|
||||
|
||||
Scalar* a = reinterpret_cast<Scalar*>(pa);
|
||||
const Scalar* a = reinterpret_cast<const Scalar*>(pa);
|
||||
Scalar* b = reinterpret_cast<Scalar*>(pb);
|
||||
|
||||
int info = 0;
|
||||
@ -141,7 +142,7 @@ int EIGEN_BLAS_FUNC(trsv)(char *uplo, char *opa, char *diag, int *n, RealScalar
|
||||
|
||||
|
||||
|
||||
int EIGEN_BLAS_FUNC(trmv)(char *uplo, char *opa, char *diag, int *n, RealScalar *pa, int *lda, RealScalar *pb, int *incb)
|
||||
int EIGEN_BLAS_FUNC(trmv)(const char *uplo, const char *opa, const char *diag, const int *n, const RealScalar *pa, const int *lda, RealScalar *pb, const int *incb)
|
||||
{
|
||||
typedef void (*functype)(int, int, const Scalar *, int, const Scalar *, int, Scalar *, int, const Scalar&);
|
||||
static const functype func[16] = {
|
||||
@ -175,7 +176,7 @@ int EIGEN_BLAS_FUNC(trmv)(char *uplo, char *opa, char *diag, int *n, RealScalar
|
||||
0
|
||||
};
|
||||
|
||||
Scalar* a = reinterpret_cast<Scalar*>(pa);
|
||||
const Scalar* a = reinterpret_cast<const Scalar*>(pa);
|
||||
Scalar* b = reinterpret_cast<Scalar*>(pb);
|
||||
|
||||
int info = 0;
|
||||
@ -217,11 +218,11 @@ int EIGEN_BLAS_FUNC(trmv)(char *uplo, char *opa, char *diag, int *n, RealScalar
|
||||
int EIGEN_BLAS_FUNC(gbmv)(char *trans, int *m, int *n, int *kl, int *ku, RealScalar *palpha, RealScalar *pa, int *lda,
|
||||
RealScalar *px, int *incx, RealScalar *pbeta, RealScalar *py, int *incy)
|
||||
{
|
||||
Scalar* a = reinterpret_cast<Scalar*>(pa);
|
||||
Scalar* x = reinterpret_cast<Scalar*>(px);
|
||||
const Scalar* a = reinterpret_cast<const Scalar*>(pa);
|
||||
const Scalar* x = reinterpret_cast<const Scalar*>(px);
|
||||
Scalar* y = reinterpret_cast<Scalar*>(py);
|
||||
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
|
||||
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
|
||||
Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
|
||||
Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
|
||||
int coeff_rows = *kl+*ku+1;
|
||||
|
||||
int info = 0;
|
||||
@ -244,7 +245,7 @@ int EIGEN_BLAS_FUNC(gbmv)(char *trans, int *m, int *n, int *kl, int *ku, RealSca
|
||||
if(OP(*trans)!=NOTR)
|
||||
std::swap(actual_m,actual_n);
|
||||
|
||||
Scalar* actual_x = get_compact_vector(x,actual_n,*incx);
|
||||
const Scalar* actual_x = get_compact_vector(x,actual_n,*incx);
|
||||
Scalar* actual_y = get_compact_vector(y,actual_m,*incy);
|
||||
|
||||
if(beta!=Scalar(1))
|
||||
@ -253,7 +254,7 @@ int EIGEN_BLAS_FUNC(gbmv)(char *trans, int *m, int *n, int *kl, int *ku, RealSca
|
||||
else make_vector(actual_y, actual_m) *= beta;
|
||||
}
|
||||
|
||||
MatrixType mat_coeffs(a,coeff_rows,*n,*lda);
|
||||
ConstMatrixType mat_coeffs(a,coeff_rows,*n,*lda);
|
||||
|
||||
int nb = std::min(*n,(*m)+(*ku));
|
||||
for(int j=0; j<nb; ++j)
|
||||
|
@ -10,7 +10,8 @@
|
||||
#include "common.h"
|
||||
|
||||
// y = alpha*A*x + beta*y
|
||||
int EIGEN_BLAS_FUNC(symv) (char *uplo, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *px, int *incx, RealScalar *pbeta, RealScalar *py, int *incy)
|
||||
int EIGEN_BLAS_FUNC(symv) (const char *uplo, const int *n, const RealScalar *palpha, const RealScalar *pa, const int *lda,
|
||||
const RealScalar *px, const int *incx, const RealScalar *pbeta, RealScalar *py, const int *incy)
|
||||
{
|
||||
typedef void (*functype)(int, const Scalar*, int, const Scalar*, Scalar*, Scalar);
|
||||
static const functype func[2] = {
|
||||
@ -20,11 +21,11 @@ int EIGEN_BLAS_FUNC(symv) (char *uplo, int *n, RealScalar *palpha, RealScalar *p
|
||||
(internal::selfadjoint_matrix_vector_product<Scalar,int,ColMajor,Lower,false,false>::run),
|
||||
};
|
||||
|
||||
Scalar* a = reinterpret_cast<Scalar*>(pa);
|
||||
Scalar* x = reinterpret_cast<Scalar*>(px);
|
||||
const Scalar* a = reinterpret_cast<const Scalar*>(pa);
|
||||
const Scalar* x = reinterpret_cast<const Scalar*>(px);
|
||||
Scalar* y = reinterpret_cast<Scalar*>(py);
|
||||
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
|
||||
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
|
||||
Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
|
||||
Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
|
||||
|
||||
// check arguments
|
||||
int info = 0;
|
||||
@ -39,7 +40,7 @@ int EIGEN_BLAS_FUNC(symv) (char *uplo, int *n, RealScalar *palpha, RealScalar *p
|
||||
if(*n==0)
|
||||
return 0;
|
||||
|
||||
Scalar* actual_x = get_compact_vector(x,*n,*incx);
|
||||
const Scalar* actual_x = get_compact_vector(x,*n,*incx);
|
||||
Scalar* actual_y = get_compact_vector(y,*n,*incy);
|
||||
|
||||
if(beta!=Scalar(1))
|
||||
@ -61,7 +62,7 @@ int EIGEN_BLAS_FUNC(symv) (char *uplo, int *n, RealScalar *palpha, RealScalar *p
|
||||
}
|
||||
|
||||
// C := alpha*x*x' + C
|
||||
int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *px, int *incx, RealScalar *pc, int *ldc)
|
||||
int EIGEN_BLAS_FUNC(syr)(const char *uplo, const int *n, const RealScalar *palpha, const RealScalar *px, const int *incx, RealScalar *pc, const int *ldc)
|
||||
{
|
||||
|
||||
typedef void (*functype)(int, Scalar*, int, const Scalar*, const Scalar*, const Scalar&);
|
||||
@ -72,9 +73,9 @@ int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *px,
|
||||
(selfadjoint_rank1_update<Scalar,int,ColMajor,Lower,false,Conj>::run),
|
||||
};
|
||||
|
||||
Scalar* x = reinterpret_cast<Scalar*>(px);
|
||||
const Scalar* x = reinterpret_cast<const Scalar*>(px);
|
||||
Scalar* c = reinterpret_cast<Scalar*>(pc);
|
||||
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
|
||||
Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
|
||||
|
||||
int info = 0;
|
||||
if(UPLO(*uplo)==INVALID) info = 1;
|
||||
@ -87,7 +88,7 @@ int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *px,
|
||||
if(*n==0 || alpha==Scalar(0)) return 1;
|
||||
|
||||
// if the increment is not 1, let's copy it to a temporary vector to enable vectorization
|
||||
Scalar* x_cpy = get_compact_vector(x,*n,*incx);
|
||||
const Scalar* x_cpy = get_compact_vector(x,*n,*incx);
|
||||
|
||||
int code = UPLO(*uplo);
|
||||
if(code>=2 || func[code]==0)
|
||||
@ -101,7 +102,7 @@ int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *px,
|
||||
}
|
||||
|
||||
// C := alpha*x*y' + alpha*y*x' + C
|
||||
int EIGEN_BLAS_FUNC(syr2)(char *uplo, int *n, RealScalar *palpha, RealScalar *px, int *incx, RealScalar *py, int *incy, RealScalar *pc, int *ldc)
|
||||
int EIGEN_BLAS_FUNC(syr2)(const char *uplo, const int *n, const RealScalar *palpha, const RealScalar *px, const int *incx, const RealScalar *py, const int *incy, RealScalar *pc, const int *ldc)
|
||||
{
|
||||
typedef void (*functype)(int, Scalar*, int, const Scalar*, const Scalar*, Scalar);
|
||||
static const functype func[2] = {
|
||||
@ -111,10 +112,10 @@ int EIGEN_BLAS_FUNC(syr2)(char *uplo, int *n, RealScalar *palpha, RealScalar *px
|
||||
(internal::rank2_update_selector<Scalar,int,Lower>::run),
|
||||
};
|
||||
|
||||
Scalar* x = reinterpret_cast<Scalar*>(px);
|
||||
Scalar* y = reinterpret_cast<Scalar*>(py);
|
||||
const Scalar* x = reinterpret_cast<const Scalar*>(px);
|
||||
const Scalar* y = reinterpret_cast<const Scalar*>(py);
|
||||
Scalar* c = reinterpret_cast<Scalar*>(pc);
|
||||
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
|
||||
Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
|
||||
|
||||
int info = 0;
|
||||
if(UPLO(*uplo)==INVALID) info = 1;
|
||||
@ -128,8 +129,8 @@ int EIGEN_BLAS_FUNC(syr2)(char *uplo, int *n, RealScalar *palpha, RealScalar *px
|
||||
if(alpha==Scalar(0))
|
||||
return 1;
|
||||
|
||||
Scalar* x_cpy = get_compact_vector(x,*n,*incx);
|
||||
Scalar* y_cpy = get_compact_vector(y,*n,*incy);
|
||||
const Scalar* x_cpy = get_compact_vector(x,*n,*incx);
|
||||
const Scalar* y_cpy = get_compact_vector(y,*n,*incy);
|
||||
|
||||
int code = UPLO(*uplo);
|
||||
if(code>=2 || func[code]==0)
|
||||
|
@ -9,7 +9,8 @@
|
||||
#include <iostream>
|
||||
#include "common.h"
|
||||
|
||||
int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc)
|
||||
int EIGEN_BLAS_FUNC(gemm)(const char *opa, const char *opb, const int *m, const int *n, const int *k, const RealScalar *palpha,
|
||||
const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
|
||||
{
|
||||
// std::cerr << "in gemm " << *opa << " " << *opb << " " << *m << " " << *n << " " << *k << " " << *lda << " " << *ldb << " " << *ldc << " " << *palpha << " " << *pbeta << "\n";
|
||||
typedef void (*functype)(DenseIndex, DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, Scalar, internal::level3_blocking<Scalar,Scalar>&, Eigen::internal::GemmParallelInfo<DenseIndex>*);
|
||||
@ -37,11 +38,11 @@ int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScal
|
||||
0
|
||||
};
|
||||
|
||||
Scalar* a = reinterpret_cast<Scalar*>(pa);
|
||||
Scalar* b = reinterpret_cast<Scalar*>(pb);
|
||||
const Scalar* a = reinterpret_cast<const Scalar*>(pa);
|
||||
const Scalar* b = reinterpret_cast<const Scalar*>(pb);
|
||||
Scalar* c = reinterpret_cast<Scalar*>(pc);
|
||||
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
|
||||
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
|
||||
Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
|
||||
Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
|
||||
|
||||
int info = 0;
|
||||
if(OP(*opa)==INVALID) info = 1;
|
||||
@ -74,7 +75,8 @@ int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScal
|
||||
return 0;
|
||||
}
|
||||
|
||||
int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb)
|
||||
int EIGEN_BLAS_FUNC(trsm)(const char *side, const char *uplo, const char *opa, const char *diag, const int *m, const int *n,
|
||||
const RealScalar *palpha, const RealScalar *pa, const int *lda, RealScalar *pb, const int *ldb)
|
||||
{
|
||||
// std::cerr << "in trsm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << "," << *n << " " << *palpha << " " << *lda << " " << *ldb<< "\n";
|
||||
typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, internal::level3_blocking<Scalar,Scalar>&);
|
||||
@ -137,9 +139,9 @@ int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m,
|
||||
0
|
||||
};
|
||||
|
||||
Scalar* a = reinterpret_cast<Scalar*>(pa);
|
||||
const Scalar* a = reinterpret_cast<const Scalar*>(pa);
|
||||
Scalar* b = reinterpret_cast<Scalar*>(pb);
|
||||
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
|
||||
Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
|
||||
|
||||
int info = 0;
|
||||
if(SIDE(*side)==INVALID) info = 1;
|
||||
@ -178,7 +180,8 @@ int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m,
|
||||
|
||||
// b = alpha*op(a)*b for side = 'L'or'l'
|
||||
// b = alpha*b*op(a) for side = 'R'or'r'
|
||||
int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb)
|
||||
int EIGEN_BLAS_FUNC(trmm)(const char *side, const char *uplo, const char *opa, const char *diag, const int *m, const int *n,
|
||||
const RealScalar *palpha, const RealScalar *pa, const int *lda, RealScalar *pb, const int *ldb)
|
||||
{
|
||||
// std::cerr << "in trmm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << " " << *n << " " << *lda << " " << *ldb << " " << *palpha << "\n";
|
||||
typedef void (*functype)(DenseIndex, DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, const Scalar&, internal::level3_blocking<Scalar,Scalar>&);
|
||||
@ -241,9 +244,9 @@ int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m,
|
||||
0
|
||||
};
|
||||
|
||||
Scalar* a = reinterpret_cast<Scalar*>(pa);
|
||||
const Scalar* a = reinterpret_cast<const Scalar*>(pa);
|
||||
Scalar* b = reinterpret_cast<Scalar*>(pb);
|
||||
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
|
||||
Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
|
||||
|
||||
int info = 0;
|
||||
if(SIDE(*side)==INVALID) info = 1;
|
||||
@ -281,14 +284,15 @@ int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m,
|
||||
|
||||
// c = alpha*a*b + beta*c for side = 'L'or'l'
|
||||
// c = alpha*b*a + beta*c for side = 'R'or'r'
|
||||
int EIGEN_BLAS_FUNC(symm)(char *side, char *uplo, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc)
|
||||
int EIGEN_BLAS_FUNC(symm)(const char *side, const char *uplo, const int *m, const int *n, const RealScalar *palpha,
|
||||
const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
|
||||
{
|
||||
// std::cerr << "in symm " << *side << " " << *uplo << " " << *m << "x" << *n << " lda:" << *lda << " ldb:" << *ldb << " ldc:" << *ldc << " alpha:" << *palpha << " beta:" << *pbeta << "\n";
|
||||
Scalar* a = reinterpret_cast<Scalar*>(pa);
|
||||
Scalar* b = reinterpret_cast<Scalar*>(pb);
|
||||
const Scalar* a = reinterpret_cast<const Scalar*>(pa);
|
||||
const Scalar* b = reinterpret_cast<const Scalar*>(pb);
|
||||
Scalar* c = reinterpret_cast<Scalar*>(pc);
|
||||
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
|
||||
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
|
||||
Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
|
||||
Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
|
||||
|
||||
int info = 0;
|
||||
if(SIDE(*side)==INVALID) info = 1;
|
||||
@ -350,7 +354,8 @@ int EIGEN_BLAS_FUNC(symm)(char *side, char *uplo, int *m, int *n, RealScalar *pa
|
||||
|
||||
// c = alpha*a*a' + beta*c for op = 'N'or'n'
|
||||
// c = alpha*a'*a + beta*c for op = 'T'or't','C'or'c'
|
||||
int EIGEN_BLAS_FUNC(syrk)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pbeta, RealScalar *pc, int *ldc)
|
||||
int EIGEN_BLAS_FUNC(syrk)(const char *uplo, const char *op, const int *n, const int *k,
|
||||
const RealScalar *palpha, const RealScalar *pa, const int *lda, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
|
||||
{
|
||||
// std::cerr << "in syrk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " " << *pbeta << " " << *ldc << "\n";
|
||||
#if !ISCOMPLEX
|
||||
@ -373,10 +378,10 @@ int EIGEN_BLAS_FUNC(syrk)(char *uplo, char *op, int *n, int *k, RealScalar *palp
|
||||
};
|
||||
#endif
|
||||
|
||||
Scalar* a = reinterpret_cast<Scalar*>(pa);
|
||||
const Scalar* a = reinterpret_cast<const Scalar*>(pa);
|
||||
Scalar* c = reinterpret_cast<Scalar*>(pc);
|
||||
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
|
||||
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
|
||||
Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
|
||||
Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
|
||||
|
||||
int info = 0;
|
||||
if(UPLO(*uplo)==INVALID) info = 1;
|
||||
@ -429,13 +434,14 @@ int EIGEN_BLAS_FUNC(syrk)(char *uplo, char *op, int *n, int *k, RealScalar *palp
|
||||
|
||||
// c = alpha*a*b' + alpha*b*a' + beta*c for op = 'N'or'n'
|
||||
// c = alpha*a'*b + alpha*b'*a + beta*c for op = 'T'or't'
|
||||
int EIGEN_BLAS_FUNC(syr2k)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc)
|
||||
int EIGEN_BLAS_FUNC(syr2k)(const char *uplo, const char *op, const int *n, const int *k, const RealScalar *palpha,
|
||||
const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
|
||||
{
|
||||
Scalar* a = reinterpret_cast<Scalar*>(pa);
|
||||
Scalar* b = reinterpret_cast<Scalar*>(pb);
|
||||
const Scalar* a = reinterpret_cast<const Scalar*>(pa);
|
||||
const Scalar* b = reinterpret_cast<const Scalar*>(pb);
|
||||
Scalar* c = reinterpret_cast<Scalar*>(pc);
|
||||
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
|
||||
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
|
||||
Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
|
||||
Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
|
||||
|
||||
// std::cerr << "in syr2k " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " << *ldb << " " << beta << " " << *ldc << "\n";
|
||||
|
||||
@ -496,13 +502,14 @@ int EIGEN_BLAS_FUNC(syr2k)(char *uplo, char *op, int *n, int *k, RealScalar *pal
|
||||
|
||||
// c = alpha*a*b + beta*c for side = 'L'or'l'
|
||||
// c = alpha*b*a + beta*c for side = 'R'or'r'
|
||||
int EIGEN_BLAS_FUNC(hemm)(char *side, char *uplo, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc)
|
||||
int EIGEN_BLAS_FUNC(hemm)(const char *side, const char *uplo, const int *m, const int *n, const RealScalar *palpha,
|
||||
const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
|
||||
{
|
||||
Scalar* a = reinterpret_cast<Scalar*>(pa);
|
||||
Scalar* b = reinterpret_cast<Scalar*>(pb);
|
||||
const Scalar* a = reinterpret_cast<const Scalar*>(pa);
|
||||
const Scalar* b = reinterpret_cast<const Scalar*>(pb);
|
||||
Scalar* c = reinterpret_cast<Scalar*>(pc);
|
||||
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
|
||||
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
|
||||
Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
|
||||
Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);
|
||||
|
||||
// std::cerr << "in hemm " << *side << " " << *uplo << " " << *m << " " << *n << " " << alpha << " " << *lda << " " << beta << " " << *ldc << "\n";
|
||||
|
||||
@ -554,7 +561,8 @@ int EIGEN_BLAS_FUNC(hemm)(char *side, char *uplo, int *m, int *n, RealScalar *pa
|
||||
|
||||
// c = alpha*a*conj(a') + beta*c for op = 'N'or'n'
|
||||
// c = alpha*conj(a')*a + beta*c for op = 'C'or'c'
|
||||
int EIGEN_BLAS_FUNC(herk)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pbeta, RealScalar *pc, int *ldc)
|
||||
int EIGEN_BLAS_FUNC(herk)(const char *uplo, const char *op, const int *n, const int *k,
|
||||
const RealScalar *palpha, const RealScalar *pa, const int *lda, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
|
||||
{
|
||||
// std::cerr << "in herk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " " << *pbeta << " " << *ldc << "\n";
|
||||
|
||||
@ -574,7 +582,7 @@ int EIGEN_BLAS_FUNC(herk)(char *uplo, char *op, int *n, int *k, RealScalar *palp
|
||||
0
|
||||
};
|
||||
|
||||
Scalar* a = reinterpret_cast<Scalar*>(pa);
|
||||
const Scalar* a = reinterpret_cast<const Scalar*>(pa);
|
||||
Scalar* c = reinterpret_cast<Scalar*>(pc);
|
||||
RealScalar alpha = *palpha;
|
||||
RealScalar beta = *pbeta;
|
||||
@ -620,12 +628,13 @@ int EIGEN_BLAS_FUNC(herk)(char *uplo, char *op, int *n, int *k, RealScalar *palp
|
||||
|
||||
// c = alpha*a*conj(b') + conj(alpha)*b*conj(a') + beta*c, for op = 'N'or'n'
|
||||
// c = alpha*conj(a')*b + conj(alpha)*conj(b')*a + beta*c, for op = 'C'or'c'
|
||||
int EIGEN_BLAS_FUNC(her2k)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc)
|
||||
int EIGEN_BLAS_FUNC(her2k)(const char *uplo, const char *op, const int *n, const int *k,
|
||||
const RealScalar *palpha, const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
|
||||
{
|
||||
Scalar* a = reinterpret_cast<Scalar*>(pa);
|
||||
Scalar* b = reinterpret_cast<Scalar*>(pb);
|
||||
const Scalar* a = reinterpret_cast<const Scalar*>(pa);
|
||||
const Scalar* b = reinterpret_cast<const Scalar*>(pb);
|
||||
Scalar* c = reinterpret_cast<Scalar*>(pc);
|
||||
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
|
||||
Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
|
||||
RealScalar beta = *pbeta;
|
||||
|
||||
// std::cerr << "in her2k " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " << *ldb << " " << beta << " " << *ldc << "\n";
|
||||
|
@ -37,10 +37,10 @@ Here is another example reshaping a 2x6 matrix to a 6x2 one:
|
||||
|
||||
\section TutorialSlicing Slicing
|
||||
|
||||
Slicing consists in taking a set of rows, or columns, or elements, uniformly spaced within a matrix.
|
||||
Slicing consists in taking a set of rows, columns, or elements, uniformly spaced within a matrix.
|
||||
Again, the class Map makes it easy to mimic this feature.
|
||||
|
||||
For instance, one can take skip every P elements in a vector:
|
||||
For instance, one can skip every P elements in a vector:
|
||||
<table class="example">
|
||||
<tr><th>Example:</th><th>Output:</th></tr>
|
||||
<tr><td>
|
||||
|
@ -55,7 +55,7 @@ Operations on other scalar types or mixing reals and complexes will continue to
|
||||
In addition you can choose which parts will be substituted by defining one or multiple of the following macros:
|
||||
|
||||
<table class="manual">
|
||||
<tr><td>\c EIGEN_USE_BLAS </td><td>Enables the use of external BLAS level 2 and 3 routines (currently works with Intel MKL only)</td></tr>
|
||||
<tr><td>\c EIGEN_USE_BLAS </td><td>Enables the use of external BLAS level 2 and 3 routines (compatible with any F77 BLAS interface, not only Intel MKL)</td></tr>
|
||||
<tr class="alt"><td>\c EIGEN_USE_LAPACKE </td><td>Enables the use of external Lapack routines via the <a href="http://www.netlib.org/lapack/lapacke.html">Intel Lapacke</a> C interface to Lapack (currently works with Intel MKL only)</td></tr>
|
||||
<tr><td>\c EIGEN_USE_LAPACKE_STRICT </td><td>Same as \c EIGEN_USE_LAPACKE but algorithms of lower robustness are disabled. This currently concerns only JacobiSVD, which would otherwise be replaced by \c gesvd, which is less robust than Jacobi rotations.</td></tr>
|
||||
<tr class="alt"><td>\c EIGEN_USE_MKL_VML </td><td>Enables the use of Intel VML (vector operations)</td></tr>
|
||||
|
@ -11,6 +11,7 @@
|
||||
#define EIGEN_LAPACK_COMMON_H
|
||||
|
||||
#include "../blas/common.h"
|
||||
#include "../Eigen/src/misc/lapack.h"
|
||||
|
||||
#define EIGEN_LAPACK_FUNC(FUNC,ARGLIST) \
|
||||
extern "C" { int EIGEN_BLAS_FUNC(FUNC) ARGLIST; } \
|
||||
|
@ -331,11 +331,13 @@ template<typename ArrayType> void array_real(const ArrayType& m)
|
||||
VERIFY_IS_APPROX(numext::zeta(Scalar(3), Scalar(-2.5)), RealScalar(0.054102025820864097));
|
||||
VERIFY_IS_EQUAL(numext::zeta(Scalar(1), Scalar(1.2345)), // The second scalar does not matter
|
||||
std::numeric_limits<RealScalar>::infinity());
|
||||
VERIFY((numext::isnan)(numext::zeta(Scalar(0.9), Scalar(1.2345)))); // The second scalar does not matter
|
||||
|
||||
// Check the polygamma against scipy.special.polygamma examples
|
||||
VERIFY_IS_APPROX(numext::polygamma(Scalar(1), Scalar(2)), RealScalar(0.644934066848));
|
||||
VERIFY_IS_APPROX(numext::polygamma(Scalar(1), Scalar(3)), RealScalar(0.394934066848));
|
||||
VERIFY_IS_APPROX(numext::polygamma(Scalar(1), Scalar(25.5)), RealScalar(0.0399946696496));
|
||||
VERIFY((numext::isnan)(numext::polygamma(Scalar(1.5), Scalar(1.2345)))); // The second scalar does not matter
|
||||
|
||||
// Check the polygamma function over a larger range of values
|
||||
VERIFY_IS_APPROX(numext::polygamma(Scalar(17), Scalar(4.7)), RealScalar(293.334565435));
|
||||
|
@ -14,6 +14,9 @@
|
||||
|
||||
using std::sqrt;
|
||||
|
||||
// tolerance for checking number of iterations
|
||||
#define LM_EVAL_COUNT_TOL 4/3
|
||||
|
||||
int fcn_chkder(const VectorXd &x, VectorXd &fvec, MatrixXd &fjac, int iflag)
|
||||
{
|
||||
/* subroutine fcn for chkder example. */
|
||||
@ -1023,7 +1026,8 @@ void testNistLanczos1(void)
|
||||
VERIFY_IS_EQUAL(lm.njev, 72);
|
||||
// check norm^2
|
||||
std::cout.precision(30);
|
||||
VERIFY_IS_APPROX(lm.fvec.squaredNorm(), 1.4290986055242372e-25); // should be 1.4307867721E-25, but nist results are on 128-bit floats
|
||||
std::cout << lm.fvec.squaredNorm() << "\n";
|
||||
VERIFY(lm.fvec.squaredNorm() <= 1.4307867721E-25);
|
||||
// check x
|
||||
VERIFY_IS_APPROX(x[0], 9.5100000027E-02);
|
||||
VERIFY_IS_APPROX(x[1], 1.0000000001E+00);
|
||||
@ -1044,7 +1048,7 @@ void testNistLanczos1(void)
|
||||
VERIFY_IS_EQUAL(lm.nfev, 9);
|
||||
VERIFY_IS_EQUAL(lm.njev, 8);
|
||||
// check norm^2
|
||||
VERIFY_IS_APPROX(lm.fvec.squaredNorm(), 1.430571737783119393e-25); // should be 1.4307867721E-25, but nist results are on 128-bit floats
|
||||
VERIFY(lm.fvec.squaredNorm() <= 1.4307867721E-25);
|
||||
// check x
|
||||
VERIFY_IS_APPROX(x[0], 9.5100000027E-02);
|
||||
VERIFY_IS_APPROX(x[1], 1.0000000001E+00);
|
||||
@ -1354,8 +1358,12 @@ void testNistMGH17(void)
|
||||
|
||||
// check return value
|
||||
VERIFY_IS_EQUAL(info, 2);
|
||||
VERIFY(lm.nfev < 650); // 602
|
||||
VERIFY(lm.njev < 600); // 545
|
||||
++g_test_level;
|
||||
VERIFY_IS_EQUAL(lm.nfev, 602); // 602
|
||||
VERIFY_IS_EQUAL(lm.njev, 545); // 545
|
||||
--g_test_level;
|
||||
VERIFY(lm.nfev < 602 * LM_EVAL_COUNT_TOL);
|
||||
VERIFY(lm.njev < 545 * LM_EVAL_COUNT_TOL);
|
||||
|
||||
/*
|
||||
* Second try
|
||||
|
@ -23,6 +23,9 @@
|
||||
|
||||
using std::sqrt;
|
||||
|
||||
// tolerance for checking number of iterations
|
||||
#define LM_EVAL_COUNT_TOL 4/3
|
||||
|
||||
struct lmder_functor : DenseFunctor<double>
|
||||
{
|
||||
lmder_functor(void): DenseFunctor<double>(3,15) {}
|
||||
@ -631,7 +634,7 @@ void testNistLanczos1(void)
|
||||
VERIFY_IS_EQUAL(lm.nfev(), 79);
|
||||
VERIFY_IS_EQUAL(lm.njev(), 72);
|
||||
// check norm^2
|
||||
// VERIFY_IS_APPROX(lm.fvec().squaredNorm(), 1.430899764097e-25); // should be 1.4307867721E-25, but nist results are on 128-bit floats
|
||||
VERIFY(lm.fvec().squaredNorm() <= 1.4307867721E-25);
|
||||
// check x
|
||||
VERIFY_IS_APPROX(x[0], 9.5100000027E-02);
|
||||
VERIFY_IS_APPROX(x[1], 1.0000000001E+00);
|
||||
@ -652,7 +655,7 @@ void testNistLanczos1(void)
|
||||
VERIFY_IS_EQUAL(lm.nfev(), 9);
|
||||
VERIFY_IS_EQUAL(lm.njev(), 8);
|
||||
// check norm^2
|
||||
// VERIFY_IS_APPROX(lm.fvec().squaredNorm(), 1.428595533845e-25); // should be 1.4307867721E-25, but nist results are on 128-bit floats
|
||||
VERIFY(lm.fvec().squaredNorm() <= 1.4307867721E-25);
|
||||
// check x
|
||||
VERIFY_IS_APPROX(x[0], 9.5100000027E-02);
|
||||
VERIFY_IS_APPROX(x[1], 1.0000000001E+00);
|
||||
@ -789,7 +792,8 @@ void testNistMGH10(void)
|
||||
MGH10_functor functor;
|
||||
LevenbergMarquardt<MGH10_functor> lm(functor);
|
||||
info = lm.minimize(x);
|
||||
VERIFY_IS_EQUAL(info, LevenbergMarquardtSpace::RelativeErrorTooSmall);
|
||||
VERIFY_IS_EQUAL(info, LevenbergMarquardtSpace::RelativeReductionTooSmall);
|
||||
// was: VERIFY_IS_EQUAL(info, 1);
|
||||
|
||||
// check norm^2
|
||||
VERIFY_IS_APPROX(lm.fvec().squaredNorm(), 8.7945855171E+01);
|
||||
@ -799,9 +803,13 @@ void testNistMGH10(void)
|
||||
VERIFY_IS_APPROX(x[2], 3.4522363462E+02);
|
||||
|
||||
// check return value
|
||||
//VERIFY_IS_EQUAL(info, 1);
|
||||
|
||||
++g_test_level;
|
||||
VERIFY_IS_EQUAL(lm.nfev(), 284 );
|
||||
VERIFY_IS_EQUAL(lm.njev(), 249 );
|
||||
--g_test_level;
|
||||
VERIFY(lm.nfev() < 284 * LM_EVAL_COUNT_TOL);
|
||||
VERIFY(lm.njev() < 249 * LM_EVAL_COUNT_TOL);
|
||||
|
||||
/*
|
||||
* Second try
|
||||
@ -809,7 +817,10 @@ void testNistMGH10(void)
|
||||
x<< 0.02, 4000., 250.;
|
||||
// do the computation
|
||||
info = lm.minimize(x);
|
||||
++g_test_level;
|
||||
VERIFY_IS_EQUAL(info, LevenbergMarquardtSpace::RelativeReductionTooSmall);
|
||||
// was: VERIFY_IS_EQUAL(info, 1);
|
||||
--g_test_level;
|
||||
|
||||
// check norm^2
|
||||
VERIFY_IS_APPROX(lm.fvec().squaredNorm(), 8.7945855171E+01);
|
||||
@ -819,9 +830,12 @@ void testNistMGH10(void)
|
||||
VERIFY_IS_APPROX(x[2], 3.4522363462E+02);
|
||||
|
||||
// check return value
|
||||
//VERIFY_IS_EQUAL(info, 1);
|
||||
++g_test_level;
|
||||
VERIFY_IS_EQUAL(lm.nfev(), 126);
|
||||
VERIFY_IS_EQUAL(lm.njev(), 116);
|
||||
--g_test_level;
|
||||
VERIFY(lm.nfev() < 126 * LM_EVAL_COUNT_TOL);
|
||||
VERIFY(lm.njev() < 116 * LM_EVAL_COUNT_TOL);
|
||||
}
|
||||
|
||||
|
||||
@ -896,8 +910,12 @@ void testNistBoxBOD(void)
|
||||
|
||||
// check return value
|
||||
VERIFY_IS_EQUAL(info, 1);
|
||||
++g_test_level;
|
||||
VERIFY_IS_EQUAL(lm.nfev(), 16 );
|
||||
VERIFY_IS_EQUAL(lm.njev(), 15 );
|
||||
--g_test_level;
|
||||
VERIFY(lm.nfev() < 16 * LM_EVAL_COUNT_TOL);
|
||||
VERIFY(lm.njev() < 15 * LM_EVAL_COUNT_TOL);
|
||||
// check norm^2
|
||||
VERIFY_IS_APPROX(lm.fvec().squaredNorm(), 1.1680088766E+03);
|
||||
// check x
|
||||
|
Loading…
x
Reference in New Issue
Block a user