From a7b9250ad04fe02f9c51085164478bc1687577f3 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Mon, 1 Mar 2010 19:06:07 +0100 Subject: [PATCH] blas interface: fix compilation, fix GEMM, SYMM, TRMM, and TRSM, i,e., they all pass the blas test suite. More to come --- blas/CMakeLists.txt | 3 +- blas/common.h | 49 ++++++----- blas/level1_impl.h | 16 ++-- blas/level3_impl.h | 204 +++++++++++++++++++++++++++----------------- 4 files changed, 159 insertions(+), 113 deletions(-) diff --git a/blas/CMakeLists.txt b/blas/CMakeLists.txt index a6c330a5c..ee67fe519 100644 --- a/blas/CMakeLists.txt +++ b/blas/CMakeLists.txt @@ -4,7 +4,8 @@ add_custom_target(blas) set(EigenBlas_SRCS single.cpp double.cpp complex_single.cpp complex_double.cpp) -add_library(eigen_blas SHARED ${EigenBlas_SRCS}) +add_library(eigen_blas ${EigenBlas_SRCS}) +# add_library(eigen_blas SHARED ${EigenBlas_SRCS}) add_dependencies(blas eigen_blas) install(TARGETS eigen_blas diff --git a/blas/common.h b/blas/common.h index e7bfda570..8b9c6ff09 100644 --- a/blas/common.h +++ b/blas/common.h @@ -25,6 +25,8 @@ #ifndef EIGEN_BLAS_COMMON_H #define EIGEN_BLAS_COMMON_H +#include + #ifndef SCALAR #error the token SCALAR must be defined to compile this file #endif @@ -34,13 +36,12 @@ extern "C" { #endif -#include +#include "../bench/btl/libs/C_BLAS/blas.h" #ifdef __cplusplus } #endif - #define NOTR 0 #define TR 1 #define ADJ 2 @@ -75,27 +76,6 @@ extern "C" #include using namespace Eigen; -template -Block >, Dynamic, Dynamic> -matrix(T* data, int rows, int cols, int stride) -{ - return Map >(data, stride, cols).block(0,0,rows,cols); -} - -template -Block >, Dynamic, 1> -vector(T* data, int size, int incr) -{ - return Map >(data, size, incr).col(0); -} - -template -Map > -vector(T* data, int size) -{ - return Map >(data, size); -} - typedef SCALAR Scalar; typedef NumTraits::Real RealScalar; typedef std::complex Complex; @@ -106,10 +86,29 @@ enum Conj = IsComplex }; -typedef Block >, Dynamic, Dynamic> MatrixType; -typedef Block >, Dynamic, 1> StridedVectorType; +typedef Map, 0, OuterStride > MatrixType; +typedef Map, 0, InnerStride > StridedVectorType; typedef Map > CompactVectorType; +template +Map, 0, OuterStride > +matrix(T* data, int rows, int cols, int stride) +{ + return Map, 0, OuterStride >(data, rows, cols, OuterStride(stride)); +} + +template +Map, 0, InnerStride > vector(T* data, int size, int incr) +{ + return Map, 0, InnerStride >(data, size, InnerStride(incr)); +} + +template +Map > vector(T* data, int size) +{ + return Map >(data, size); +} + #define EIGEN_BLAS_FUNC(X) EIGEN_CAT(SCALAR_SUFFIX,X##_) #endif // EIGEN_BLAS_COMMON_H diff --git a/blas/level1_impl.h b/blas/level1_impl.h index c508626db..5326c6917 100644 --- a/blas/level1_impl.h +++ b/blas/level1_impl.h @@ -45,9 +45,9 @@ RealScalar EIGEN_BLAS_FUNC(asum)(int *n, RealScalar *px, int *incx) int size = IsComplex ? 2* *n : *n; if(*incx==1) - return vector(px,size).cwise().abs().sum(); + return vector(px,size).cwiseAbs().sum(); else - return vector(px,size,*incx).cwise().abs().sum(); + return vector(px,size,*incx).cwiseAbs().sum(); return 1; } @@ -71,9 +71,9 @@ Scalar EIGEN_BLAS_FUNC(dot)(int *n, RealScalar *px, int *incx, RealScalar *py, i Scalar* y = reinterpret_cast(py); if(*incx==1 && *incy==1) - return (vector(x,*n).cwise()*vector(y,*n)).sum(); + return (vector(x,*n).cwiseProduct(vector(y,*n))).sum(); - return (vector(x,*n,*incx).cwise()*vector(y,*n,*incy)).sum(); + return (vector(x,*n,*incx).cwiseProduct(vector(y,*n,*incy))).sum(); } /* @@ -114,9 +114,9 @@ Scalar EIGEN_BLAS_FUNC(dotu)(int *n, RealScalar *px, int *incx, RealScalar *py, Scalar* y = reinterpret_cast(py); if(*incx==1 && *incy==1) - return (vector(x,*n).cwise()*vector(y,*n)).sum(); + return (vector(x,*n).cwiseProduct(vector(y,*n))).sum(); - return (vector(x,*n,*incx).cwise()*vector(y,*n,*incy)).sum(); + return (vector(x,*n,*incx).cwiseProduct(vector(y,*n,*incy))).sum(); } #endif // ISCOMPLEX @@ -215,9 +215,9 @@ RealScalar EIGEN_BLAS_FUNC(casum)(int *n, RealScalar *px, int *incx) Complex* x = reinterpret_cast(px); if(*incx==1) - return vector(x,*n).cwise().abs().sum(); + return vector(x,*n).cwiseAbs().sum(); else - return vector(x,*n,*incx).cwise().abs().sum(); + return vector(x,*n,*incx).cwiseAbs().sum(); return 1; } diff --git a/blas/level3_impl.h b/blas/level3_impl.h index d44de1b5d..76497ec26 100644 --- a/blas/level3_impl.h +++ b/blas/level3_impl.h @@ -26,8 +26,9 @@ int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc) { +// std::cerr << "in gemm " << *opa << " " << *opb << " " << *m << " " << *n << " " << *k << " " << *lda << " " << *ldb << " " << *ldc << " " << *palpha << " " << *pbeta << "\n"; typedef void (*functype)(int, int, int, const Scalar *, int, const Scalar *, int, Scalar *, int, Scalar); - functype func[12]; + static functype func[12]; static bool init = false; if(!init) @@ -52,21 +53,29 @@ int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScal Scalar alpha = *reinterpret_cast(palpha); Scalar beta = *reinterpret_cast(pbeta); - if(beta!=Scalar(1)) - matrix(c, *m, *n, *ldc) *= beta; - int code = OP(*opa) | (OP(*opb) << 2); - if(code>=12 || func[code]==0) + if(code>=12 || func[code]==0 || (*m<0) || (*n<0) || (*k<0)) + { + int info = 1; + xerbla_("GEMM", &info, 4); return 0; + } + + if(beta!=Scalar(1)) + if(beta==Scalar(0)) + matrix(c, *m, *n, *ldc).setZero(); + else + matrix(c, *m, *n, *ldc) *= beta; func[code](*m, *n, *k, a, *lda, b, *ldb, c, *ldc, alpha); - return 1; + return 0; } int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb) { +// std::cerr << "in trsm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << "," << *n << " " << *palpha << " " << *lda << " " << *ldb<< "\n"; typedef void (*functype)(int, int, const Scalar *, int, Scalar *, int); - functype func[32]; + static functype func[32]; static bool init = false; if(!init) @@ -74,38 +83,38 @@ int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m, for(int k=0; k<32; ++k) func[k] = 0; - func[NOTR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); - func[TR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); - func[ADJ | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); + func[NOTR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); + func[TR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); + func[ADJ | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); - func[NOTR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); - func[TR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); - func[ADJ | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); + func[NOTR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); + func[TR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); + func[ADJ | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); - func[NOTR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); - func[TR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); - func[ADJ | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); + func[NOTR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); + func[TR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); + func[ADJ | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); - func[NOTR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); - func[TR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); - func[ADJ | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); + func[NOTR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); + func[TR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); + func[ADJ | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); - func[NOTR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); - func[TR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); - func[ADJ | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); + func[NOTR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); + func[TR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); + func[ADJ | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); - func[NOTR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); - func[TR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); - func[ADJ | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); + func[NOTR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); + func[TR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); + func[ADJ | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); - func[NOTR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); - func[TR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); - func[ADJ | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); + func[NOTR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); + func[TR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); + func[ADJ | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); - func[NOTR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); - func[TR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); - func[ADJ | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); + func[NOTR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); + func[TR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); + func[ADJ | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); init = true; } @@ -114,14 +123,23 @@ int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m, Scalar* b = reinterpret_cast(pb); Scalar alpha = *reinterpret_cast(palpha); - // TODO handle alpha - int code = OP(*opa) | (SIDE(*side) << 2) | (UPLO(*uplo) << 3) | (DIAG(*diag) << 4); - if(code>=32 || func[code]==0) + if(code>=32 || func[code]==0 || *m<0 || *n <0) + { + int info=1; + xerbla_("TRSM",&info,4); return 0; + } - func[code](*m, *n, a, *lda, b, *ldb); - return 1; + if(SIDE(*side)==LEFT) + func[code](*m, *n, a, *lda, b, *ldb); + else + func[code](*n, *m, a, *lda, b, *ldb); + + if(alpha!=Scalar(1)) + matrix(b,*m,*n,*ldb) *= alpha; + + return 0; } @@ -129,46 +147,46 @@ int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m, // b = alpha*b*op(a) for side = 'R'or'r' int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb) { +// std::cerr << "in trmm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << " " << *n << " " << *lda << " " << *ldb << " " << *palpha << "\n"; typedef void (*functype)(int, int, const Scalar *, int, const Scalar *, int, Scalar *, int, Scalar); - functype func[32]; - + static functype func[32]; static bool init = false; if(!init) { for(int k=0; k<32; ++k) func[k] = 0; - func[NOTR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[TR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[ADJ | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[NOTR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[TR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[ADJ | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[NOTR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[TR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[ADJ | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[NOTR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[TR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[ADJ | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[NOTR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[TR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[ADJ | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[NOTR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[TR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[ADJ | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[NOTR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[TR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[ADJ | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[NOTR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[TR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[ADJ | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[NOTR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[TR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[ADJ | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[NOTR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[TR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[ADJ | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[NOTR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[TR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[ADJ | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[NOTR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[TR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[ADJ | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[NOTR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[TR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[ADJ | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[NOTR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[TR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[ADJ | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[NOTR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[TR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[ADJ | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[NOTR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[TR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[ADJ | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); init = true; } @@ -178,10 +196,21 @@ int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m, Scalar alpha = *reinterpret_cast(palpha); int code = OP(*opa) | (SIDE(*side) << 2) | (UPLO(*uplo) << 3) | (DIAG(*diag) << 4); - if(code>=32 || func[code]==0) + if(code>=32 || func[code]==0 || *m<0 || *n <0) + { + int info=1; + xerbla_("TRMM",&info,4); return 0; + } - func[code](*m, *n, a, *lda, b, *ldb, b, *ldb, alpha); + // FIXME find a way to avoid this copy + Matrix tmp = matrix(b,*m,*n,*ldb); + matrix(b,*m,*n,*ldb).setZero(); + + if(SIDE(*side)==LEFT) + func[code](*m, *n, a, *lda, tmp.data(), tmp.outerStride(), b, *ldb, alpha); + else + func[code](*n, *m, tmp.data(), tmp.outerStride(), a, *lda, b, *ldb, alpha); return 1; } @@ -189,14 +218,26 @@ int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m, // c = alpha*b*a + beta*c for side = 'R'or'r int EIGEN_BLAS_FUNC(symm)(char *side, char *uplo, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc) { +// std::cerr << "in symm " << *side << " " << *uplo << " " << *m << "x" << *n << " lda:" << *lda << " ldb:" << *ldb << " ldc:" << *ldc << " alpha:" << *palpha << " beta:" << *pbeta << " " +// << pa << " " << pb << " " << pc << "\n"; Scalar* a = reinterpret_cast(pa); Scalar* b = reinterpret_cast(pb); Scalar* c = reinterpret_cast(pc); Scalar alpha = *reinterpret_cast(palpha); Scalar beta = *reinterpret_cast(pbeta); + if(*m<0 || *n<0) + { + int info=1; + xerbla_("SYMM",&info,4); + return 0; + } + if(beta!=Scalar(1)) - matrix(c, *m, *n, *ldc) *= beta; + if(beta==Scalar(0)) + matrix(c, *m, *n, *ldc).setZero(); + else + matrix(c, *m, *n, *ldc) *= beta; if(SIDE(*side)==LEFT) if(UPLO(*uplo)==UP) @@ -215,15 +256,16 @@ int EIGEN_BLAS_FUNC(symm)(char *side, char *uplo, int *m, int *n, RealScalar *pa else return 0; - return 1; + return 0; } // c = alpha*a*a' + beta*c for op = 'N'or'n' // c = alpha*a'*a + beta*c for op = 'T'or't','C'or'c' int EIGEN_BLAS_FUNC(syrk)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pbeta, RealScalar *pc, int *ldc) { +// std::cerr << "in syrk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " " << *pbeta << "\n"; typedef void (*functype)(int, int, const Scalar *, int, Scalar *, int, Scalar); - functype func[8]; + static functype func[8]; static bool init = false; if(!init) @@ -231,13 +273,13 @@ int EIGEN_BLAS_FUNC(syrk)(char *uplo, char *op, int *n, int *k, RealScalar *palp for(int k=0; k<8; ++k) func[k] = 0; - func[NOTR | (UP << 2)] = (ei_selfadjoint_product::run); - func[TR | (UP << 2)] = (ei_selfadjoint_product::run); - func[ADJ | (UP << 2)] = (ei_selfadjoint_product::run); + func[NOTR | (UP << 2)] = (ei_selfadjoint_product::run); + func[TR | (UP << 2)] = (ei_selfadjoint_product::run); + func[ADJ | (UP << 2)] = (ei_selfadjoint_product::run); - func[NOTR | (LO << 2)] = (ei_selfadjoint_product::run); - func[TR | (LO << 2)] = (ei_selfadjoint_product::run); - func[ADJ | (LO << 2)] = (ei_selfadjoint_product::run); + func[NOTR | (LO << 2)] = (ei_selfadjoint_product::run); + func[TR | (LO << 2)] = (ei_selfadjoint_product::run); + func[ADJ | (LO << 2)] = (ei_selfadjoint_product::run); init = true; } @@ -248,8 +290,12 @@ int EIGEN_BLAS_FUNC(syrk)(char *uplo, char *op, int *n, int *k, RealScalar *palp Scalar beta = *reinterpret_cast(pbeta); int code = OP(*op) | (UPLO(*uplo) << 2); - if(code>=8 || func[code]==0) + if(code>=8 || func[code]==0 || *n<0 || *k<0) + { + int info=1; + xerbla_("SYRK",&info,4); return 0; + } if(beta!=Scalar(1)) matrix(c, *n, *n, *ldc) *= beta; @@ -314,7 +360,7 @@ int EIGEN_BLAS_FUNC(hemm)(char *side, char *uplo, int *m, int *n, RealScalar *pa int EIGEN_BLAS_FUNC(herk)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pbeta, RealScalar *pc, int *ldc) { typedef void (*functype)(int, int, const Scalar *, int, Scalar *, int, Scalar); - functype func[8]; + static functype func[8]; static bool init = false; if(!init) @@ -322,11 +368,11 @@ int EIGEN_BLAS_FUNC(herk)(char *uplo, char *op, int *n, int *k, RealScalar *palp for(int k=0; k<8; ++k) func[k] = 0; - func[NOTR | (UP << 2)] = (ei_selfadjoint_product::run); - func[ADJ | (UP << 2)] = (ei_selfadjoint_product::run); + func[NOTR | (UP << 2)] = (ei_selfadjoint_product::run); + func[ADJ | (UP << 2)] = (ei_selfadjoint_product::run); - func[NOTR | (LO << 2)] = (ei_selfadjoint_product::run); - func[ADJ | (LO << 2)] = (ei_selfadjoint_product::run); + func[NOTR | (LO << 2)] = (ei_selfadjoint_product::run); + func[ADJ | (LO << 2)] = (ei_selfadjoint_product::run); init = true; }