mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-14 04:35:57 +08:00
Add degenerate checks before calling BLAS routines.
This commit is contained in:
parent
fa201f1bb3
commit
6893287c99
@ -84,7 +84,7 @@ EIGEN_BLAS_RANKUPDATE_SPECIALIZE(float)
|
|||||||
const EIGTYPE* /*rhs*/, Index /*rhsStride*/, EIGTYPE* res, Index resStride, \
|
const EIGTYPE* /*rhs*/, Index /*rhsStride*/, EIGTYPE* res, Index resStride, \
|
||||||
EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) { \
|
EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) { \
|
||||||
/* typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs;*/ \
|
/* typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs;*/ \
|
||||||
\
|
if (size == 0 || depth == 0) return; \
|
||||||
BlasIndex lda = convert_index<BlasIndex>(lhsStride), ldc = convert_index<BlasIndex>(resStride), \
|
BlasIndex lda = convert_index<BlasIndex>(lhsStride), ldc = convert_index<BlasIndex>(resStride), \
|
||||||
n = convert_index<BlasIndex>(size), k = convert_index<BlasIndex>(depth); \
|
n = convert_index<BlasIndex>(size), k = convert_index<BlasIndex>(depth); \
|
||||||
char uplo = ((IsLower) ? 'L' : 'U'), trans = ((AStorageOrder == RowMajor) ? 'T' : 'N'); \
|
char uplo = ((IsLower) ? 'L' : 'U'), trans = ((AStorageOrder == RowMajor) ? 'T' : 'N'); \
|
||||||
@ -107,7 +107,7 @@ EIGEN_BLAS_RANKUPDATE_SPECIALIZE(float)
|
|||||||
const EIGTYPE* /*rhs*/, Index /*rhsStride*/, EIGTYPE* res, Index resStride, \
|
const EIGTYPE* /*rhs*/, Index /*rhsStride*/, EIGTYPE* res, Index resStride, \
|
||||||
EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) { \
|
EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) { \
|
||||||
typedef Matrix<EIGTYPE, Dynamic, Dynamic, AStorageOrder> MatrixType; \
|
typedef Matrix<EIGTYPE, Dynamic, Dynamic, AStorageOrder> MatrixType; \
|
||||||
\
|
if (size == 0 || depth == 0) return; \
|
||||||
BlasIndex lda = convert_index<BlasIndex>(lhsStride), ldc = convert_index<BlasIndex>(resStride), \
|
BlasIndex lda = convert_index<BlasIndex>(lhsStride), ldc = convert_index<BlasIndex>(resStride), \
|
||||||
n = convert_index<BlasIndex>(size), k = convert_index<BlasIndex>(depth); \
|
n = convert_index<BlasIndex>(size), k = convert_index<BlasIndex>(depth); \
|
||||||
char uplo = ((IsLower) ? 'L' : 'U'), trans = ((AStorageOrder == RowMajor) ? 'C' : 'N'); \
|
char uplo = ((IsLower) ? 'L' : 'U'), trans = ((AStorageOrder == RowMajor) ? 'C' : 'N'); \
|
||||||
|
@ -59,7 +59,7 @@ namespace internal {
|
|||||||
Index rhsStride, EIGTYPE* res, Index resIncr, Index resStride, EIGTYPE alpha, \
|
Index rhsStride, EIGTYPE* res, Index resIncr, Index resStride, EIGTYPE alpha, \
|
||||||
level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/, GemmParallelInfo<Index>* /*info = 0*/) { \
|
level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/, GemmParallelInfo<Index>* /*info = 0*/) { \
|
||||||
using std::conj; \
|
using std::conj; \
|
||||||
\
|
if (rows == 0 || cols == 0 || depth == 0) return; \
|
||||||
EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \
|
EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \
|
||||||
eigen_assert(resIncr == 1); \
|
eigen_assert(resIncr == 1); \
|
||||||
char transa, transb; \
|
char transa, transb; \
|
||||||
|
@ -95,6 +95,7 @@ EIGEN_BLAS_GEMV_SPECIALIZE(scomplex)
|
|||||||
\
|
\
|
||||||
static void run(Index rows, Index cols, const EIGTYPE* lhs, Index lhsStride, const EIGTYPE* rhs, Index rhsIncr, \
|
static void run(Index rows, Index cols, const EIGTYPE* lhs, Index lhsStride, const EIGTYPE* rhs, Index rhsIncr, \
|
||||||
EIGTYPE* res, Index resIncr, EIGTYPE alpha) { \
|
EIGTYPE* res, Index resIncr, EIGTYPE alpha) { \
|
||||||
|
if (rows == 0 || cols == 0) return; \
|
||||||
BlasIndex m = convert_index<BlasIndex>(rows), n = convert_index<BlasIndex>(cols), \
|
BlasIndex m = convert_index<BlasIndex>(rows), n = convert_index<BlasIndex>(cols), \
|
||||||
lda = convert_index<BlasIndex>(lhsStride), incx = convert_index<BlasIndex>(rhsIncr), \
|
lda = convert_index<BlasIndex>(lhsStride), incx = convert_index<BlasIndex>(rhsIncr), \
|
||||||
incy = convert_index<BlasIndex>(resIncr); \
|
incy = convert_index<BlasIndex>(resIncr); \
|
||||||
@ -111,8 +112,9 @@ EIGEN_BLAS_GEMV_SPECIALIZE(scomplex)
|
|||||||
x_tmp = map_x.conjugate(); \
|
x_tmp = map_x.conjugate(); \
|
||||||
x_ptr = x_tmp.data(); \
|
x_ptr = x_tmp.data(); \
|
||||||
incx = 1; \
|
incx = 1; \
|
||||||
} else \
|
} else { \
|
||||||
x_ptr = rhs; \
|
x_ptr = rhs; \
|
||||||
|
} \
|
||||||
BLASFUNC(&trans, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)lhs, &lda, \
|
BLASFUNC(&trans, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)lhs, &lda, \
|
||||||
(const BLASTYPE*)x_ptr, &incx, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)res, &incy); \
|
(const BLASTYPE*)x_ptr, &incx, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)res, &incy); \
|
||||||
} \
|
} \
|
||||||
|
@ -49,6 +49,7 @@ namespace internal {
|
|||||||
static void run(Index rows, Index cols, const EIGTYPE* _lhs, Index lhsStride, const EIGTYPE* _rhs, \
|
static void run(Index rows, Index cols, const EIGTYPE* _lhs, Index lhsStride, const EIGTYPE* _rhs, \
|
||||||
Index rhsStride, EIGTYPE* res, Index resIncr, Index resStride, EIGTYPE alpha, \
|
Index rhsStride, EIGTYPE* res, Index resIncr, Index resStride, EIGTYPE alpha, \
|
||||||
level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) { \
|
level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) { \
|
||||||
|
if (rows == 0 || cols == 0) return; \
|
||||||
EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \
|
EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \
|
||||||
eigen_assert(resIncr == 1); \
|
eigen_assert(resIncr == 1); \
|
||||||
char side = 'L', uplo = 'L'; \
|
char side = 'L', uplo = 'L'; \
|
||||||
@ -91,6 +92,7 @@ namespace internal {
|
|||||||
static void run(Index rows, Index cols, const EIGTYPE* _lhs, Index lhsStride, const EIGTYPE* _rhs, \
|
static void run(Index rows, Index cols, const EIGTYPE* _lhs, Index lhsStride, const EIGTYPE* _rhs, \
|
||||||
Index rhsStride, EIGTYPE* res, Index resIncr, Index resStride, EIGTYPE alpha, \
|
Index rhsStride, EIGTYPE* res, Index resIncr, Index resStride, EIGTYPE alpha, \
|
||||||
level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) { \
|
level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) { \
|
||||||
|
if (rows == 0 || cols == 0) return; \
|
||||||
EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \
|
EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \
|
||||||
eigen_assert(resIncr == 1); \
|
eigen_assert(resIncr == 1); \
|
||||||
char side = 'L', uplo = 'L'; \
|
char side = 'L', uplo = 'L'; \
|
||||||
@ -164,6 +166,7 @@ EIGEN_BLAS_HEMM_L(scomplex, float, cf, chemm_)
|
|||||||
static void run(Index rows, Index cols, const EIGTYPE* _lhs, Index lhsStride, const EIGTYPE* _rhs, \
|
static void run(Index rows, Index cols, const EIGTYPE* _lhs, Index lhsStride, const EIGTYPE* _rhs, \
|
||||||
Index rhsStride, EIGTYPE* res, Index resIncr, Index resStride, EIGTYPE alpha, \
|
Index rhsStride, EIGTYPE* res, Index resIncr, Index resStride, EIGTYPE alpha, \
|
||||||
level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) { \
|
level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) { \
|
||||||
|
if (rows == 0 || cols == 0) return; \
|
||||||
EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \
|
EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \
|
||||||
eigen_assert(resIncr == 1); \
|
eigen_assert(resIncr == 1); \
|
||||||
char side = 'R', uplo = 'L'; \
|
char side = 'R', uplo = 'L'; \
|
||||||
|
@ -78,6 +78,7 @@ EIGEN_BLAS_SYMV_SPECIALIZE(scomplex)
|
|||||||
\
|
\
|
||||||
static void run(Index size, const EIGTYPE* lhs, Index lhsStride, const EIGTYPE* _rhs, EIGTYPE* res, \
|
static void run(Index size, const EIGTYPE* lhs, Index lhsStride, const EIGTYPE* _rhs, EIGTYPE* res, \
|
||||||
EIGTYPE alpha) { \
|
EIGTYPE alpha) { \
|
||||||
|
if (size == 0) return; \
|
||||||
enum { IsRowMajor = StorageOrder == RowMajor ? 1 : 0, IsLower = UpLo == Lower ? 1 : 0 }; \
|
enum { IsRowMajor = StorageOrder == RowMajor ? 1 : 0, IsLower = UpLo == Lower ? 1 : 0 }; \
|
||||||
BlasIndex n = convert_index<BlasIndex>(size), lda = convert_index<BlasIndex>(lhsStride), incx = 1, incy = 1; \
|
BlasIndex n = convert_index<BlasIndex>(size), lda = convert_index<BlasIndex>(lhsStride), incx = 1, incy = 1; \
|
||||||
EIGTYPE beta(1); \
|
EIGTYPE beta(1); \
|
||||||
|
@ -90,6 +90,7 @@ EIGEN_BLAS_TRMM_SPECIALIZE(scomplex, false)
|
|||||||
static void run(Index _rows, Index _cols, Index _depth, const EIGTYPE* _lhs, Index lhsStride, const EIGTYPE* _rhs, \
|
static void run(Index _rows, Index _cols, Index _depth, const EIGTYPE* _lhs, Index lhsStride, const EIGTYPE* _rhs, \
|
||||||
Index rhsStride, EIGTYPE* res, Index resStride, EIGTYPE alpha, \
|
Index rhsStride, EIGTYPE* res, Index resStride, EIGTYPE alpha, \
|
||||||
level3_blocking<EIGTYPE, EIGTYPE>& blocking) { \
|
level3_blocking<EIGTYPE, EIGTYPE>& blocking) { \
|
||||||
|
if (_rows == 0 || _cols == 0 || _depth == 0) return; \
|
||||||
Index diagSize = (std::min)(_rows, _depth); \
|
Index diagSize = (std::min)(_rows, _depth); \
|
||||||
Index rows = IsLower ? _rows : diagSize; \
|
Index rows = IsLower ? _rows : diagSize; \
|
||||||
Index depth = IsLower ? diagSize : _depth; \
|
Index depth = IsLower ? diagSize : _depth; \
|
||||||
@ -211,6 +212,7 @@ EIGEN_BLAS_TRMM_L(scomplex, float, cf, ctrmm_)
|
|||||||
static void run(Index _rows, Index _cols, Index _depth, const EIGTYPE* _lhs, Index lhsStride, const EIGTYPE* _rhs, \
|
static void run(Index _rows, Index _cols, Index _depth, const EIGTYPE* _lhs, Index lhsStride, const EIGTYPE* _rhs, \
|
||||||
Index rhsStride, EIGTYPE* res, Index resStride, EIGTYPE alpha, \
|
Index rhsStride, EIGTYPE* res, Index resStride, EIGTYPE alpha, \
|
||||||
level3_blocking<EIGTYPE, EIGTYPE>& blocking) { \
|
level3_blocking<EIGTYPE, EIGTYPE>& blocking) { \
|
||||||
|
if (_rows == 0 || _cols == 0 || _depth == 0) return; \
|
||||||
Index diagSize = (std::min)(_cols, _depth); \
|
Index diagSize = (std::min)(_cols, _depth); \
|
||||||
Index rows = _rows; \
|
Index rows = _rows; \
|
||||||
Index depth = IsLower ? _depth : diagSize; \
|
Index depth = IsLower ? _depth : diagSize; \
|
||||||
|
@ -87,6 +87,7 @@ EIGEN_BLAS_TRMV_SPECIALIZE(scomplex)
|
|||||||
}; \
|
}; \
|
||||||
static void run(Index rows_, Index cols_, const EIGTYPE* lhs_, Index lhsStride, const EIGTYPE* rhs_, \
|
static void run(Index rows_, Index cols_, const EIGTYPE* lhs_, Index lhsStride, const EIGTYPE* rhs_, \
|
||||||
Index rhsIncr, EIGTYPE* res_, Index resIncr, EIGTYPE alpha) { \
|
Index rhsIncr, EIGTYPE* res_, Index resIncr, EIGTYPE alpha) { \
|
||||||
|
if (rows_ == 0 || cols_ == 0) return; \
|
||||||
if (ConjLhs || IsZeroDiag) { \
|
if (ConjLhs || IsZeroDiag) { \
|
||||||
triangular_matrix_vector_product<Index, Mode, EIGTYPE, ConjLhs, EIGTYPE, ConjRhs, ColMajor, BuiltIn>::run( \
|
triangular_matrix_vector_product<Index, Mode, EIGTYPE, ConjLhs, EIGTYPE, ConjRhs, ColMajor, BuiltIn>::run( \
|
||||||
rows_, cols_, lhs_, lhsStride, rhs_, rhsIncr, res_, resIncr, alpha); \
|
rows_, cols_, lhs_, lhsStride, rhs_, rhsIncr, res_, resIncr, alpha); \
|
||||||
@ -183,6 +184,7 @@ EIGEN_BLAS_TRMV_CM(scomplex, float, cf, c, _)
|
|||||||
}; \
|
}; \
|
||||||
static void run(Index rows_, Index cols_, const EIGTYPE* lhs_, Index lhsStride, const EIGTYPE* rhs_, \
|
static void run(Index rows_, Index cols_, const EIGTYPE* lhs_, Index lhsStride, const EIGTYPE* rhs_, \
|
||||||
Index rhsIncr, EIGTYPE* res_, Index resIncr, EIGTYPE alpha) { \
|
Index rhsIncr, EIGTYPE* res_, Index resIncr, EIGTYPE alpha) { \
|
||||||
|
if (rows_ == 0 || cols_ == 0) return; \
|
||||||
if (IsZeroDiag) { \
|
if (IsZeroDiag) { \
|
||||||
triangular_matrix_vector_product<Index, Mode, EIGTYPE, ConjLhs, EIGTYPE, ConjRhs, RowMajor, BuiltIn>::run( \
|
triangular_matrix_vector_product<Index, Mode, EIGTYPE, ConjLhs, EIGTYPE, ConjRhs, RowMajor, BuiltIn>::run( \
|
||||||
rows_, cols_, lhs_, lhsStride, rhs_, rhsIncr, res_, resIncr, alpha); \
|
rows_, cols_, lhs_, lhsStride, rhs_, rhsIncr, res_, resIncr, alpha); \
|
||||||
|
@ -52,6 +52,7 @@ namespace internal {
|
|||||||
}; \
|
}; \
|
||||||
static void run(Index size, Index otherSize, const EIGTYPE* _tri, Index triStride, EIGTYPE* _other, \
|
static void run(Index size, Index otherSize, const EIGTYPE* _tri, Index triStride, EIGTYPE* _other, \
|
||||||
Index otherIncr, Index otherStride, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) { \
|
Index otherIncr, Index otherStride, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) { \
|
||||||
|
if (size == 0 || otherSize == 0) return; \
|
||||||
EIGEN_ONLY_USED_FOR_DEBUG(otherIncr); \
|
EIGEN_ONLY_USED_FOR_DEBUG(otherIncr); \
|
||||||
eigen_assert(otherIncr == 1); \
|
eigen_assert(otherIncr == 1); \
|
||||||
BlasIndex m = convert_index<BlasIndex>(size), n = convert_index<BlasIndex>(otherSize), lda, ldb; \
|
BlasIndex m = convert_index<BlasIndex>(size), n = convert_index<BlasIndex>(otherSize), lda, ldb; \
|
||||||
@ -110,6 +111,7 @@ EIGEN_BLAS_TRSM_L(scomplex, float, ctrsm_)
|
|||||||
}; \
|
}; \
|
||||||
static void run(Index size, Index otherSize, const EIGTYPE* _tri, Index triStride, EIGTYPE* _other, \
|
static void run(Index size, Index otherSize, const EIGTYPE* _tri, Index triStride, EIGTYPE* _other, \
|
||||||
Index otherIncr, Index otherStride, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) { \
|
Index otherIncr, Index otherStride, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) { \
|
||||||
|
if (size == 0 || otherSize == 0) return; \
|
||||||
EIGEN_ONLY_USED_FOR_DEBUG(otherIncr); \
|
EIGEN_ONLY_USED_FOR_DEBUG(otherIncr); \
|
||||||
eigen_assert(otherIncr == 1); \
|
eigen_assert(otherIncr == 1); \
|
||||||
BlasIndex m = convert_index<BlasIndex>(otherSize), n = convert_index<BlasIndex>(size), lda, ldb; \
|
BlasIndex m = convert_index<BlasIndex>(otherSize), n = convert_index<BlasIndex>(size), lda, ldb; \
|
||||||
|
Loading…
x
Reference in New Issue
Block a user