From e5f3fa2d6143e3c781b5ee51b3c2b04a0fc3f5d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20S=C3=A1nchez?= Date: Fri, 5 Sep 2025 22:31:30 +0000 Subject: [PATCH] Add gemmtr implementation. --- blas/blas.h | 13 ++++ blas/level3_impl.h | 159 +++++++++++++++++++++++++++++++++++++-------- 2 files changed, 146 insertions(+), 26 deletions(-) diff --git a/blas/blas.h b/blas/blas.h index 8962dc954..178b5e5cc 100644 --- a/blas/blas.h +++ b/blas/blas.h @@ -464,6 +464,19 @@ void BLASFUNC(zher2m)(const char *, const char *, const char *, const int *, con void BLASFUNC(xher2m)(const char *, const char *, const char *, const int *, const int *, const double *, const double *, const int *, const double *, const int *, const double *, double *, const int *); +void BLASFUNC(sgemmtr)(const char *, const char *, const char *, const int *, const int *, const float *, const float *, + const int *, const float *, const int *, const float *, float *, const int *); +void BLASFUNC(dgemmtr)(const char *, const char *, const char *, const int *, const int *, const double *, + const double *, const int *, const double *, const int *, const double *, double *, const int *); +void BLASFUNC(qgemmtr)(const char *, const char *, const char *, const int *, const int *, const double *, + const double *, const int *, const double *, const int *, const double *, double *, const int *); +void BLASFUNC(cgemmtr)(const char *, const char *, const char *, const int *, const int *, const float *, const float *, + const int *, const float *, const int *, const float *, float *, const int *); +void BLASFUNC(zgemmtr)(const char *, const char *, const char *, const int *, const int *, const double *, + const double *, const int *, const double *, const int *, const double *, double *, const int *); +void BLASFUNC(xgemmtr)(const char *, const char *, const char *, const int *, const int *, const double *, + const double *, const int *, const double *, const int *, const double *, double *, const int *); + #ifdef __cplusplus } #endif diff --git a/blas/level3_impl.h b/blas/level3_impl.h index 66a7d468b..b2e317388 100644 --- a/blas/level3_impl.h +++ b/blas/level3_impl.h @@ -6,15 +6,12 @@ // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. -#include #include "common.h" EIGEN_BLAS_FUNC(gemm) (const char *opa, const char *opb, const int *m, const int *n, const int *k, const RealScalar *palpha, const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc) { - // std::cerr << "in gemm " << *opa << " " << *opb << " " << *m << " " << *n << " " << *k << " " << *lda << " " << - // *ldb << " " << *ldc << " " << *palpha << " " << *pbeta << "\n"; using Eigen::ColMajor; using Eigen::DenseIndex; using Eigen::Dynamic; @@ -97,11 +94,141 @@ EIGEN_BLAS_FUNC(gemm) func[code](*m, *n, *k, a, *lda, b, *ldb, c, 1, *ldc, alpha, blocking, 0); } +EIGEN_BLAS_FUNC(gemmtr) +(const char *uplo, const char *opa, const char *opb, const int *n, const int *k, const RealScalar *palpha, + const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, + const int *ldc) { + using Eigen::ColMajor; + using Eigen::DenseIndex; + using Eigen::Dynamic; + using Eigen::Lower; + using Eigen::RowMajor; + using Eigen::Upper; + typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, + DenseIndex, DenseIndex, const Scalar &, Eigen::internal::level3_blocking &); + static const functype func[24] = { + // Upper-triangular result. + // array index: NOTR | (NOTR << 2) + (Eigen::internal::general_matrix_matrix_triangular_product::run), + // array index: TR | (NOTR << 2) + (Eigen::internal::general_matrix_matrix_triangular_product::run), + // array index: ADJ | (NOTR << 2) + (Eigen::internal::general_matrix_matrix_triangular_product::run), + 0, + // array index: NOTR | (TR << 2) + (Eigen::internal::general_matrix_matrix_triangular_product::run), + // array index: TR | (TR << 2) + (Eigen::internal::general_matrix_matrix_triangular_product::run), + // array index: ADJ | (TR << 2) + (Eigen::internal::general_matrix_matrix_triangular_product::run), + 0, + // array index: NOTR | (ADJ << 2) + (Eigen::internal::general_matrix_matrix_triangular_product::run), + // array index: TR | (ADJ << 2) + (Eigen::internal::general_matrix_matrix_triangular_product::run), + // array index: ADJ | (ADJ << 2) + (Eigen::internal::general_matrix_matrix_triangular_product::run), + 0, + + // Lower-triangular result. + // array index: NOTR | (NOTR << 2) + (Eigen::internal::general_matrix_matrix_triangular_product::run), + // array index: TR | (NOTR << 2) + (Eigen::internal::general_matrix_matrix_triangular_product::run), + // array index: ADJ | (NOTR << 2) + (Eigen::internal::general_matrix_matrix_triangular_product::run), + 0, + // array index: NOTR | (TR << 2) + (Eigen::internal::general_matrix_matrix_triangular_product::run), + // array index: TR | (TR << 2) + (Eigen::internal::general_matrix_matrix_triangular_product::run), + // array index: ADJ | (TR << 2) + (Eigen::internal::general_matrix_matrix_triangular_product::run), + 0, + // array index: NOTR | (ADJ << 2) + (Eigen::internal::general_matrix_matrix_triangular_product::run), + // array index: TR | (ADJ << 2) + (Eigen::internal::general_matrix_matrix_triangular_product::run), + // array index: ADJ | (ADJ << 2) + (Eigen::internal::general_matrix_matrix_triangular_product::run), + 0, + }; + + const Scalar *a = reinterpret_cast(pa); + const Scalar *b = reinterpret_cast(pb); + Scalar *c = reinterpret_cast(pc); + Scalar alpha = *reinterpret_cast(palpha); + Scalar beta = *reinterpret_cast(pbeta); + + int info = 0; + if (UPLO(*uplo) == INVALID) + info = 1; + else if (OP(*opa) == INVALID) + info = 2; + else if (OP(*opb) == INVALID) + info = 3; + else if (*n < 0) + info = 4; + else if (*k < 0) + info = 5; + else if (*lda < std::max(1, (OP(*opa) == NOTR) ? *n : *k)) + info = 8; + else if (*ldb < std::max(1, (OP(*opb) == NOTR) ? *k : *n)) + info = 10; + else if (*ldc < std::max(1, *n)) + info = 13; + if (info) return xerbla_(SCALAR_SUFFIX_UP "GEMMTR ", &info); + + if (*n == 0) return; + + const int upper = (UPLO(*uplo) == UP); + + if (beta != Scalar(1)) { + if (beta == Scalar(0)) { + if (upper) { + matrix(c, *n, *n, *ldc).triangularView().setZero(); + } else { + matrix(c, *n, *n, *ldc).triangularView().setZero(); + } + } else { + if (upper) { + matrix(c, *n, *n, *ldc).triangularView() *= beta; + } else { + matrix(c, *n, *n, *ldc).triangularView() *= beta; + } + } + } + + if (*k == 0) return; + + Eigen::internal::gemm_blocking_space blocking(*n, *n, *k, 1, + false); + + int code = OP(*opa) | (OP(*opb) << 2) + (UPLO(*uplo) * 12); + func[code](*n, *k, a, *lda, b, *ldb, c, 1, *ldc, alpha, blocking); +} + EIGEN_BLAS_FUNC(trsm) (const char *side, const char *uplo, const char *opa, const char *diag, const int *m, const int *n, const RealScalar *palpha, const RealScalar *pa, const int *lda, RealScalar *pb, const int *ldb) { - // std::cerr << "in trsm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << "," << *n << " " - // << *palpha << " " << *lda << " " << *ldb<< "\n"; using Eigen::ColMajor; using Eigen::DenseIndex; using Eigen::Dynamic; @@ -240,8 +367,6 @@ EIGEN_BLAS_FUNC(trsm) EIGEN_BLAS_FUNC(trmm) (const char *side, const char *uplo, const char *opa, const char *diag, const int *m, const int *n, const RealScalar *palpha, const RealScalar *pa, const int *lda, RealScalar *pb, const int *ldb) { - // std::cerr << "in trmm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << " " << *n << " " - // << *lda << " " << *ldb << " " << *palpha << "\n"; using Eigen::ColMajor; using Eigen::DenseIndex; using Eigen::Dynamic; @@ -381,8 +506,6 @@ EIGEN_BLAS_FUNC(trmm) EIGEN_BLAS_FUNC(symm) (const char *side, const char *uplo, const int *m, const int *n, const RealScalar *palpha, const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc) { - // std::cerr << "in symm " << *side << " " << *uplo << " " << *m << "x" << *n << " lda:" << *lda << " ldb:" << *ldb - // << " ldc:" << *ldc << " alpha:" << *palpha << " beta:" << *pbeta << "\n"; const Scalar *a = reinterpret_cast(pa); const Scalar *b = reinterpret_cast(pb); Scalar *c = reinterpret_cast(pc); @@ -472,8 +595,6 @@ EIGEN_BLAS_FUNC(symm) EIGEN_BLAS_FUNC(syrk) (const char *uplo, const char *op, const int *n, const int *k, const RealScalar *palpha, const RealScalar *pa, const int *lda, const RealScalar *pbeta, RealScalar *pc, const int *ldc) { - // std::cerr << "in syrk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " " - // << *pbeta << " " << *ldc << "\n"; using Eigen::ColMajor; using Eigen::DenseIndex; using Eigen::Dynamic; @@ -577,9 +698,6 @@ EIGEN_BLAS_FUNC(syr2k) Scalar alpha = *reinterpret_cast(palpha); Scalar beta = *reinterpret_cast(pbeta); - // std::cerr << "in syr2k " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " - // << *ldb << " " << beta << " " << *ldc << "\n"; - int info = 0; if (UPLO(*uplo) == INVALID) info = 1; @@ -647,9 +765,6 @@ EIGEN_BLAS_FUNC(hemm) Scalar alpha = *reinterpret_cast(palpha); Scalar beta = *reinterpret_cast(pbeta); - // std::cerr << "in hemm " << *side << " " << *uplo << " " << *m << " " << *n << " " << alpha << " " << *lda << " " - // << beta << " " << *ldc << "\n"; - int info = 0; if (SIDE(*side) == INVALID) info = 1; @@ -719,8 +834,6 @@ RowMajor,true,Conj, ColMajor, 1> EIGEN_BLAS_FUNC(herk) (const char *uplo, const char *op, const int *n, const int *k, const RealScalar *palpha, const RealScalar *pa, const int *lda, const RealScalar *pbeta, RealScalar *pc, const int *ldc) { - // std::cerr << "in herk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " " - // << *pbeta << " " << *ldc << "\n"; using Eigen::ColMajor; using Eigen::DenseIndex; using Eigen::Dynamic; @@ -754,9 +867,6 @@ EIGEN_BLAS_FUNC(herk) RealScalar alpha = *palpha; RealScalar beta = *pbeta; - // std::cerr << "in herk " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " << - // beta << " " << *ldc << "\n"; - int info = 0; if (UPLO(*uplo) == INVALID) info = 1; @@ -810,9 +920,6 @@ EIGEN_BLAS_FUNC(her2k) Scalar alpha = *reinterpret_cast(palpha); RealScalar beta = *pbeta; - // std::cerr << "in her2k " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " - // << *ldb << " " << beta << " " << *ldc << "\n"; - int info = 0; if (UPLO(*uplo) == INVALID) info = 1; @@ -875,4 +982,4 @@ EIGEN_BLAS_FUNC(her2k) } } -#endif // ISCOMPLEX +#endif