mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-23 01:59:38 +08:00
implement proper error handling in level 3 routines
This commit is contained in:
parent
a8fb6b0ad3
commit
fd88d721d2
@ -56,22 +56,40 @@ extern "C"
|
||||
#define NUNIT 0
|
||||
#define UNIT 1
|
||||
|
||||
#define INVALID 0xff
|
||||
|
||||
#define OP(X) ( ((X)=='N' || (X)=='n') ? NOTR \
|
||||
: ((X)=='T' || (X)=='t') ? TR \
|
||||
: ((X)=='C' || (X)=='c') ? ADJ \
|
||||
: 0xff)
|
||||
: INVALID)
|
||||
|
||||
#define SIDE(X) ( ((X)=='L' || (X)=='l') ? LEFT \
|
||||
: ((X)=='R' || (X)=='r') ? RIGHT \
|
||||
: 0xff)
|
||||
: INVALID)
|
||||
|
||||
#define UPLO(X) ( ((X)=='U' || (X)=='u') ? UP \
|
||||
: ((X)=='L' || (X)=='l') ? LO \
|
||||
: 0xff)
|
||||
: INVALID)
|
||||
|
||||
#define DIAG(X) ( ((X)=='N' || (X)=='N') ? NUNIT \
|
||||
: ((X)=='U' || (X)=='u') ? UNIT \
|
||||
: 0xff)
|
||||
: INVALID)
|
||||
|
||||
|
||||
inline bool check_op(const char* op)
|
||||
{
|
||||
return OP(*op)!=0xff;
|
||||
}
|
||||
|
||||
inline bool check_side(const char* side)
|
||||
{
|
||||
return SIDE(*side)!=0xff;
|
||||
}
|
||||
|
||||
inline bool check_uplo(const char* uplo)
|
||||
{
|
||||
return UPLO(*uplo)!=0xff;
|
||||
}
|
||||
|
||||
#include <Eigen/Core>
|
||||
#include <Eigen/Jacobi>
|
||||
|
@ -24,6 +24,7 @@
|
||||
|
||||
#define SCALAR std::complex<double>
|
||||
#define SCALAR_SUFFIX z
|
||||
#define SCALAR_SUFFIX_UP "Z"
|
||||
#define REAL_SCALAR_SUFFIX d
|
||||
#define ISCOMPLEX 1
|
||||
|
||||
|
@ -24,6 +24,7 @@
|
||||
|
||||
#define SCALAR std::complex<float>
|
||||
#define SCALAR_SUFFIX c
|
||||
#define SCALAR_SUFFIX_UP "C"
|
||||
#define REAL_SCALAR_SUFFIX s
|
||||
#define ISCOMPLEX 1
|
||||
|
||||
|
@ -24,6 +24,7 @@
|
||||
|
||||
#define SCALAR double
|
||||
#define SCALAR_SUFFIX d
|
||||
#define SCALAR_SUFFIX_UP "D"
|
||||
#define ISCOMPLEX 0
|
||||
|
||||
#include "level1_impl.h"
|
||||
|
@ -32,6 +32,20 @@ int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealSca
|
||||
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
|
||||
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
|
||||
|
||||
// check arguments
|
||||
int info = 0;
|
||||
if( OP(*opa)!=NOTR
|
||||
&& OP(*opa)!=TR
|
||||
&& OP(*opa)!=ADJ) info = 1;
|
||||
else if(*m<0) info = 2;
|
||||
else if(*n<0) info = 3;
|
||||
else if(*lda<std::max(1,*m)) info = 6;
|
||||
else if(*incb==0) info = 8;
|
||||
else if(*incc==0) info = 11;
|
||||
if(info)
|
||||
return xerbla_(SCALAR_SUFFIX_UP"GEMV ",&info,6);
|
||||
// return xerbla_("SGEMV ",&info,sizeof("SGEMV "));
|
||||
|
||||
if(beta!=Scalar(1))
|
||||
vector(c, *m, *incc) *= beta;
|
||||
|
||||
|
@ -53,13 +53,17 @@ int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScal
|
||||
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
|
||||
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
|
||||
|
||||
int code = OP(*opa) | (OP(*opb) << 2);
|
||||
if(code>=12 || func[code]==0 || (*m<0) || (*n<0) || (*k<0))
|
||||
{
|
||||
int info = 1;
|
||||
xerbla_("GEMM", &info, 4);
|
||||
return 0;
|
||||
}
|
||||
int info = 0;
|
||||
if(OP(*opa)==INVALID) info = 1;
|
||||
else if(OP(*opb)==INVALID) info = 2;
|
||||
else if(*m<0) info = 3;
|
||||
else if(*n<0) info = 4;
|
||||
else if(*k<0) info = 5;
|
||||
else if(*lda<std::max(1,(OP(*opa)==NOTR)?*m:*k)) info = 8;
|
||||
else if(*ldb<std::max(1,(OP(*opb)==NOTR)?*k:*n)) info = 10;
|
||||
else if(*ldc<std::max(1,*m)) info = 13;
|
||||
if(info)
|
||||
return xerbla_(SCALAR_SUFFIX_UP"GEMM ",&info,6);
|
||||
|
||||
if(beta!=Scalar(1))
|
||||
{
|
||||
@ -69,6 +73,7 @@ int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScal
|
||||
|
||||
ei_gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic> blocking(*m,*n,*k);
|
||||
|
||||
int code = OP(*opa) | (OP(*opb) << 2);
|
||||
func[code](*m, *n, *k, a, *lda, b, *ldb, c, *ldc, alpha, blocking, 0);
|
||||
return 0;
|
||||
}
|
||||
@ -125,13 +130,19 @@ int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m,
|
||||
Scalar* b = reinterpret_cast<Scalar*>(pb);
|
||||
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
|
||||
|
||||
int info = 0;
|
||||
if(SIDE(*side)==INVALID) info = 1;
|
||||
else if(UPLO(*uplo)==INVALID) info = 2;
|
||||
else if(OP(*opa)==INVALID) info = 3;
|
||||
else if(DIAG(*diag)==INVALID) info = 4;
|
||||
else if(*m<0) info = 5;
|
||||
else if(*n<0) info = 6;
|
||||
else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 9;
|
||||
else if(*ldb<std::max(1,*m)) info = 11;
|
||||
if(info)
|
||||
return xerbla_(SCALAR_SUFFIX_UP"TRSM ",&info,6);
|
||||
|
||||
int code = OP(*opa) | (SIDE(*side) << 2) | (UPLO(*uplo) << 3) | (DIAG(*diag) << 4);
|
||||
if(code>=32 || func[code]==0 || *m<0 || *n <0)
|
||||
{
|
||||
int info=1;
|
||||
xerbla_("TRSM",&info,4);
|
||||
return 0;
|
||||
}
|
||||
|
||||
if(SIDE(*side)==LEFT)
|
||||
func[code](*m, *n, a, *lda, b, *ldb);
|
||||
@ -197,13 +208,19 @@ int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m,
|
||||
Scalar* b = reinterpret_cast<Scalar*>(pb);
|
||||
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
|
||||
|
||||
int info = 0;
|
||||
if(SIDE(*side)==INVALID) info = 1;
|
||||
else if(UPLO(*uplo)==INVALID) info = 2;
|
||||
else if(OP(*opa)==INVALID) info = 3;
|
||||
else if(DIAG(*diag)==INVALID) info = 4;
|
||||
else if(*m<0) info = 5;
|
||||
else if(*n<0) info = 6;
|
||||
else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 9;
|
||||
else if(*ldb<std::max(1,*m)) info = 11;
|
||||
if(info)
|
||||
return xerbla_(SCALAR_SUFFIX_UP"TRMM ",&info,6);
|
||||
|
||||
int code = OP(*opa) | (SIDE(*side) << 2) | (UPLO(*uplo) << 3) | (DIAG(*diag) << 4);
|
||||
if(code>=32 || func[code]==0 || *m<0 || *n <0)
|
||||
{
|
||||
int info=1;
|
||||
xerbla_("TRMM",&info,4);
|
||||
return 0;
|
||||
}
|
||||
|
||||
if(*m==0 || *n==0)
|
||||
return 1;
|
||||
@ -230,12 +247,16 @@ int EIGEN_BLAS_FUNC(symm)(char *side, char *uplo, int *m, int *n, RealScalar *pa
|
||||
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
|
||||
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
|
||||
|
||||
if(*m<0 || *n<0)
|
||||
{
|
||||
int info=1;
|
||||
xerbla_("SYMM",&info,4);
|
||||
return 0;
|
||||
}
|
||||
int info = 0;
|
||||
if(SIDE(*side)==INVALID) info = 1;
|
||||
else if(UPLO(*uplo)==INVALID) info = 2;
|
||||
else if(*m<0) info = 3;
|
||||
else if(*n<0) info = 4;
|
||||
else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 7;
|
||||
else if(*ldb<std::max(1,*m)) info = 9;
|
||||
else if(*ldc<std::max(1,*m)) info = 12;
|
||||
if(info)
|
||||
return xerbla_(SCALAR_SUFFIX_UP"SYMM ",&info,6);
|
||||
|
||||
if(beta!=Scalar(1))
|
||||
{
|
||||
@ -312,13 +333,17 @@ int EIGEN_BLAS_FUNC(syrk)(char *uplo, char *op, int *n, int *k, RealScalar *palp
|
||||
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
|
||||
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
|
||||
|
||||
int info = 0;
|
||||
if(UPLO(*uplo)==INVALID) info = 1;
|
||||
else if(OP(*op)==INVALID) info = 2;
|
||||
else if(*n<0) info = 3;
|
||||
else if(*k<0) info = 4;
|
||||
else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7;
|
||||
else if(*ldc<std::max(1,*n)) info = 10;
|
||||
if(info)
|
||||
return xerbla_(SCALAR_SUFFIX_UP"SYRK ",&info,6);
|
||||
|
||||
int code = OP(*op) | (UPLO(*uplo) << 2);
|
||||
if(code>=8 || func[code]==0 || *n<0 || *k<0)
|
||||
{
|
||||
int info=1;
|
||||
xerbla_("SYRK",&info,4);
|
||||
return 0;
|
||||
}
|
||||
|
||||
if(beta!=Scalar(1))
|
||||
{
|
||||
@ -359,10 +384,16 @@ int EIGEN_BLAS_FUNC(syr2k)(char *uplo, char *op, int *n, int *k, RealScalar *pal
|
||||
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
|
||||
Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
|
||||
|
||||
if(*n<=0 || *k<0)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
int info = 0;
|
||||
if(UPLO(*uplo)==INVALID) info = 1;
|
||||
else if(OP(*op)==INVALID) info = 2;
|
||||
else if(*n<0) info = 3;
|
||||
else if(*k<0) info = 4;
|
||||
else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7;
|
||||
else if(*ldb<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 9;
|
||||
else if(*ldc<std::max(1,*n)) info = 12;
|
||||
if(info)
|
||||
return xerbla_(SCALAR_SUFFIX_UP"SYR2K",&info,6);
|
||||
|
||||
if(beta!=Scalar(1))
|
||||
{
|
||||
@ -416,10 +447,16 @@ int EIGEN_BLAS_FUNC(hemm)(char *side, char *uplo, int *m, int *n, RealScalar *pa
|
||||
|
||||
// std::cerr << "in hemm " << *side << " " << *uplo << " " << *m << " " << *n << " " << alpha << " " << *lda << " " << beta << " " << *ldc << "\n";
|
||||
|
||||
if(*m<0 || *n<0)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
int info = 0;
|
||||
if(SIDE(*side)==INVALID) info = 1;
|
||||
else if(UPLO(*uplo)==INVALID) info = 2;
|
||||
else if(*m<0) info = 3;
|
||||
else if(*n<0) info = 4;
|
||||
else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 7;
|
||||
else if(*ldb<std::max(1,*m)) info = 9;
|
||||
else if(*ldc<std::max(1,*m)) info = 12;
|
||||
if(info)
|
||||
return xerbla_(SCALAR_SUFFIX_UP"HEMM ",&info,6);
|
||||
|
||||
if(beta==Scalar(0))
|
||||
matrix(c, *m, *n, *ldc).setZero();
|
||||
@ -484,14 +521,17 @@ int EIGEN_BLAS_FUNC(herk)(char *uplo, char *op, int *n, int *k, RealScalar *palp
|
||||
|
||||
// std::cerr << "in herk " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " << beta << " " << *ldc << "\n";
|
||||
|
||||
if(*n<0 || *k<0)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
int info = 0;
|
||||
if(UPLO(*uplo)==INVALID) info = 1;
|
||||
else if((OP(*op)==INVALID) || (OP(*op)==TR)) info = 2;
|
||||
else if(*n<0) info = 3;
|
||||
else if(*k<0) info = 4;
|
||||
else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7;
|
||||
else if(*ldc<std::max(1,*n)) info = 10;
|
||||
if(info)
|
||||
return xerbla_(SCALAR_SUFFIX_UP"HERK ",&info,6);
|
||||
|
||||
int code = OP(*op) | (UPLO(*uplo) << 2);
|
||||
if(code>=8 || func[code]==0)
|
||||
return 0;
|
||||
|
||||
if(beta!=RealScalar(1))
|
||||
{
|
||||
@ -520,10 +560,16 @@ int EIGEN_BLAS_FUNC(her2k)(char *uplo, char *op, int *n, int *k, RealScalar *pal
|
||||
Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
|
||||
RealScalar beta = *pbeta;
|
||||
|
||||
if(*n<=0 || *k<0)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
int info = 0;
|
||||
if(UPLO(*uplo)==INVALID) info = 1;
|
||||
else if((OP(*op)==INVALID) || (OP(*op)==TR)) info = 2;
|
||||
else if(*n<0) info = 3;
|
||||
else if(*k<0) info = 4;
|
||||
else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7;
|
||||
else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 9;
|
||||
else if(*ldc<std::max(1,*n)) info = 12;
|
||||
if(info)
|
||||
return xerbla_(SCALAR_SUFFIX_UP"HER2K",&info,6);
|
||||
|
||||
if(beta!=RealScalar(1))
|
||||
{
|
||||
|
@ -24,6 +24,7 @@
|
||||
|
||||
#define SCALAR float
|
||||
#define SCALAR_SUFFIX s
|
||||
#define SCALAR_SUFFIX_UP "S"
|
||||
#define ISCOMPLEX 0
|
||||
|
||||
#include "level1_impl.h"
|
||||
|
Loading…
x
Reference in New Issue
Block a user