Add gemmtr implementation.

This commit is contained in:
Antonio Sánchez 2025-09-05 22:31:30 +00:00 committed by GitLab
parent f426eff949
commit e5f3fa2d61
No known key found for this signature in database
2 changed files with 146 additions and 26 deletions

View File

@ -464,6 +464,19 @@ void BLASFUNC(zher2m)(const char *, const char *, const char *, const int *, con
void BLASFUNC(xher2m)(const char *, const char *, const char *, const int *, const int *, const double *, void BLASFUNC(xher2m)(const char *, const char *, const char *, const int *, const int *, const double *,
const double *, const int *, const double *, const int *, const double *, double *, const int *); const double *, const int *, const double *, const int *, const double *, double *, const int *);
// GEMMTR entry points, one declaration per scalar prefix (s, d, q, c, z, x).
// Every variant takes 13 pointer arguments in the same order; following the matching implementation's parameter
// names these are: uplo, transa, transb, n, k, alpha, A, lda, B, ldb, beta, C, ldc.
// Only C (the 12th argument) is non-const, i.e. it is the sole output.
// NOTE(review): the q* and x* variants are declared with `double` pointers, same as the neighbouring x* prototypes
// in this header -- confirm this is intended for the extended-precision prefixes.
void BLASFUNC(sgemmtr)(const char *, const char *, const char *, const int *, const int *, const float *, const float *,
const int *, const float *, const int *, const float *, float *, const int *);
void BLASFUNC(dgemmtr)(const char *, const char *, const char *, const int *, const int *, const double *,
const double *, const int *, const double *, const int *, const double *, double *, const int *);
void BLASFUNC(qgemmtr)(const char *, const char *, const char *, const int *, const int *, const double *,
const double *, const int *, const double *, const int *, const double *, double *, const int *);
void BLASFUNC(cgemmtr)(const char *, const char *, const char *, const int *, const int *, const float *, const float *,
const int *, const float *, const int *, const float *, float *, const int *);
void BLASFUNC(zgemmtr)(const char *, const char *, const char *, const int *, const int *, const double *,
const double *, const int *, const double *, const int *, const double *, double *, const int *);
void BLASFUNC(xgemmtr)(const char *, const char *, const char *, const int *, const int *, const double *,
const double *, const int *, const double *, const int *, const double *, double *, const int *);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -6,15 +6,12 @@
// This Source Code Form is subject to the terms of the Mozilla // This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed // Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#include <iostream>
#include "common.h" #include "common.h"
EIGEN_BLAS_FUNC(gemm) EIGEN_BLAS_FUNC(gemm)
(const char *opa, const char *opb, const int *m, const int *n, const int *k, const RealScalar *palpha, (const char *opa, const char *opb, const int *m, const int *n, const int *k, const RealScalar *palpha,
const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc,
const int *ldc) { const int *ldc) {
// std::cerr << "in gemm " << *opa << " " << *opb << " " << *m << " " << *n << " " << *k << " " << *lda << " " <<
// *ldb << " " << *ldc << " " << *palpha << " " << *pbeta << "\n";
using Eigen::ColMajor; using Eigen::ColMajor;
using Eigen::DenseIndex; using Eigen::DenseIndex;
using Eigen::Dynamic; using Eigen::Dynamic;
@ -97,11 +94,141 @@ EIGEN_BLAS_FUNC(gemm)
func[code](*m, *n, *k, a, *lda, b, *ldb, c, 1, *ldc, alpha, blocking, 0); func[code](*m, *n, *k, a, *lda, b, *ldb, c, 1, *ldc, alpha, blocking, 0);
} }
// GEMMTR: general matrix-matrix product restricted to one triangle of the result.
//
// Intended operation (BLAS ?GEMMTR): C := alpha * op(A) * op(B) + beta * C, where only the `uplo` triangle of the
// n-by-n matrix C is read and written. The beta scaling is done here on a triangular view of C; the product term is
// delegated to Eigen's general_matrix_matrix_triangular_product kernel selected from a dispatch table.
//
// Arguments (all passed by pointer, Fortran style):
//   uplo        selects the upper or lower triangle of C.
//   opa, opb    operation applied to A and B respectively (no-transpose / transpose / adjoint).
//   n           order of C; k is the inner (contraction) dimension.
//   palpha      scalar alpha, reinterpreted as Scalar.
//   pa, lda     matrix A and its leading dimension; lda must be >= max(1, n) if opa is NOTR, else >= max(1, k).
//   pb, ldb     matrix B and its leading dimension; ldb must be >= max(1, k) if opb is NOTR, else >= max(1, n).
//   pbeta       scalar beta, reinterpreted as Scalar.
//   pc, ldc     matrix C (in/out) and its leading dimension; ldc must be >= max(1, n).
//
// On an invalid argument, xerbla_ is called with the 1-based position of the first offending argument and the
// function returns without touching C.
EIGEN_BLAS_FUNC(gemmtr)
(const char *uplo, const char *opa, const char *opb, const int *n, const int *k, const RealScalar *palpha,
 const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc,
 const int *ldc) {
  using Eigen::ColMajor;
  using Eigen::DenseIndex;
  using Eigen::Dynamic;
  using Eigen::Lower;
  using Eigen::RowMajor;
  using Eigen::Upper;

  typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *,
                           DenseIndex, DenseIndex, const Scalar &, Eigen::internal::level3_blocking<Scalar, Scalar> &);

  // Dispatch table layout: two halves of 12 entries (upper first, then lower). Inside a half the index is
  // OP(opa) | (OP(opb) << 2); every fourth slot is an unused placeholder because each op code occupies two bits
  // but only three op values have a kernel.
  static const functype func[24] = {
      // Upper-triangular result.
      // array index: NOTR | (NOTR << 2)
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, ColMajor, false, Scalar, ColMajor,
                                                                 false, ColMajor, 1, Upper>::run),
      // array index: TR | (NOTR << 2)
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, RowMajor, false, Scalar, ColMajor,
                                                                 false, ColMajor, 1, Upper>::run),
      // array index: ADJ | (NOTR << 2)
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, RowMajor, Conj, Scalar, ColMajor,
                                                                 false, ColMajor, 1, Upper>::run),
      0,
      // array index: NOTR | (TR << 2)
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, ColMajor, false, Scalar, RowMajor,
                                                                 false, ColMajor, 1, Upper>::run),
      // array index: TR | (TR << 2)
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, RowMajor, false, Scalar, RowMajor,
                                                                 false, ColMajor, 1, Upper>::run),
      // array index: ADJ | (TR << 2)
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, RowMajor, Conj, Scalar, RowMajor,
                                                                 false, ColMajor, 1, Upper>::run),
      0,
      // array index: NOTR | (ADJ << 2)
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, ColMajor, false, Scalar, RowMajor,
                                                                 Conj, ColMajor, 1, Upper>::run),
      // array index: TR | (ADJ << 2)
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, RowMajor, false, Scalar, RowMajor,
                                                                 Conj, ColMajor, 1, Upper>::run),
      // array index: ADJ | (ADJ << 2)
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, RowMajor, Conj, Scalar, RowMajor,
                                                                 Conj, ColMajor, 1, Upper>::run),
      0,
      // Lower-triangular result (same ordering as above, offset by 12).
      // array index: 12 + (NOTR | (NOTR << 2))
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, ColMajor, false, Scalar, ColMajor,
                                                                 false, ColMajor, 1, Lower>::run),
      // array index: 12 + (TR | (NOTR << 2))
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, RowMajor, false, Scalar, ColMajor,
                                                                 false, ColMajor, 1, Lower>::run),
      // array index: 12 + (ADJ | (NOTR << 2))
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, RowMajor, Conj, Scalar, ColMajor,
                                                                 false, ColMajor, 1, Lower>::run),
      0,
      // array index: 12 + (NOTR | (TR << 2))
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, ColMajor, false, Scalar, RowMajor,
                                                                 false, ColMajor, 1, Lower>::run),
      // array index: 12 + (TR | (TR << 2))
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, RowMajor, false, Scalar, RowMajor,
                                                                 false, ColMajor, 1, Lower>::run),
      // array index: 12 + (ADJ | (TR << 2))
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, RowMajor, Conj, Scalar, RowMajor,
                                                                 false, ColMajor, 1, Lower>::run),
      0,
      // array index: 12 + (NOTR | (ADJ << 2))
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, ColMajor, false, Scalar, RowMajor,
                                                                 Conj, ColMajor, 1, Lower>::run),
      // array index: 12 + (TR | (ADJ << 2))
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, RowMajor, false, Scalar, RowMajor,
                                                                 Conj, ColMajor, 1, Lower>::run),
      // array index: 12 + (ADJ | (ADJ << 2))
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, RowMajor, Conj, Scalar, RowMajor,
                                                                 Conj, ColMajor, 1, Lower>::run),
      0,
  };

  const Scalar *a = reinterpret_cast<const Scalar *>(pa);
  const Scalar *b = reinterpret_cast<const Scalar *>(pb);
  Scalar *c = reinterpret_cast<Scalar *>(pc);
  Scalar alpha = *reinterpret_cast<const Scalar *>(palpha);
  Scalar beta = *reinterpret_cast<const Scalar *>(pbeta);

  // Argument validation. `info` is the 1-based position of the first bad argument in the parameter list
  // (alpha, A, B, beta and C themselves are not checked, hence the gaps in the numbering).
  int info = 0;
  if (UPLO(*uplo) == INVALID)
    info = 1;
  else if (OP(*opa) == INVALID)
    info = 2;
  else if (OP(*opb) == INVALID)
    info = 3;
  else if (*n < 0)
    info = 4;
  else if (*k < 0)
    info = 5;
  else if (*lda < std::max(1, (OP(*opa) == NOTR) ? *n : *k))
    info = 8;
  else if (*ldb < std::max(1, (OP(*opb) == NOTR) ? *k : *n))
    info = 10;
  else if (*ldc < std::max(1, *n))
    info = 13;
  if (info) return xerbla_(SCALAR_SUFFIX_UP "GEMMTR ", &info);

  // Empty result: nothing to scale or accumulate.
  if (*n == 0) return;

  const bool upper = (UPLO(*uplo) == UP);

  // Apply beta to the selected triangle only. beta == 0 is handled by an explicit setZero() instead of a multiply.
  if (beta != Scalar(1)) {
    if (beta == Scalar(0)) {
      if (upper) {
        matrix(c, *n, *n, *ldc).triangularView<Upper>().setZero();
      } else {
        matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero();
      }
    } else {
      if (upper) {
        matrix(c, *n, *n, *ldc).triangularView<Upper>() *= beta;
      } else {
        matrix(c, *n, *n, *ldc).triangularView<Lower>() *= beta;
      }
    }
  }

  // With an empty inner dimension the product contributes nothing; C has already been scaled.
  if (*k == 0) return;

  Eigen::internal::gemm_blocking_space<ColMajor, Scalar, Scalar, Dynamic, Dynamic, Dynamic> blocking(*n, *n, *k, 1,
                                                                                                     false);

  // Table slot: op codes packed as opa | (opb << 2), then shifted into the lower-triangle half when needed.
  // The grouping is spelled out because `+` binds tighter than `|`; an unparenthesized
  // `a | (b << 2) + offset` would OR opa into the already-offset value and only be correct by coincidence.
  const int code = (OP(*opa) | (OP(*opb) << 2)) + (upper ? 0 : 12);
  func[code](*n, *k, a, *lda, b, *ldb, c, 1, *ldc, alpha, blocking);
}
EIGEN_BLAS_FUNC(trsm) EIGEN_BLAS_FUNC(trsm)
(const char *side, const char *uplo, const char *opa, const char *diag, const int *m, const int *n, (const char *side, const char *uplo, const char *opa, const char *diag, const int *m, const int *n,
const RealScalar *palpha, const RealScalar *pa, const int *lda, RealScalar *pb, const int *ldb) { const RealScalar *palpha, const RealScalar *pa, const int *lda, RealScalar *pb, const int *ldb) {
// std::cerr << "in trsm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << "," << *n << " "
// << *palpha << " " << *lda << " " << *ldb<< "\n";
using Eigen::ColMajor; using Eigen::ColMajor;
using Eigen::DenseIndex; using Eigen::DenseIndex;
using Eigen::Dynamic; using Eigen::Dynamic;
@ -240,8 +367,6 @@ EIGEN_BLAS_FUNC(trsm)
EIGEN_BLAS_FUNC(trmm) EIGEN_BLAS_FUNC(trmm)
(const char *side, const char *uplo, const char *opa, const char *diag, const int *m, const int *n, (const char *side, const char *uplo, const char *opa, const char *diag, const int *m, const int *n,
const RealScalar *palpha, const RealScalar *pa, const int *lda, RealScalar *pb, const int *ldb) { const RealScalar *palpha, const RealScalar *pa, const int *lda, RealScalar *pb, const int *ldb) {
// std::cerr << "in trmm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << " " << *n << " "
// << *lda << " " << *ldb << " " << *palpha << "\n";
using Eigen::ColMajor; using Eigen::ColMajor;
using Eigen::DenseIndex; using Eigen::DenseIndex;
using Eigen::Dynamic; using Eigen::Dynamic;
@ -381,8 +506,6 @@ EIGEN_BLAS_FUNC(trmm)
EIGEN_BLAS_FUNC(symm) EIGEN_BLAS_FUNC(symm)
(const char *side, const char *uplo, const int *m, const int *n, const RealScalar *palpha, const RealScalar *pa, (const char *side, const char *uplo, const int *m, const int *n, const RealScalar *palpha, const RealScalar *pa,
const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc) { const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc) {
// std::cerr << "in symm " << *side << " " << *uplo << " " << *m << "x" << *n << " lda:" << *lda << " ldb:" << *ldb
// << " ldc:" << *ldc << " alpha:" << *palpha << " beta:" << *pbeta << "\n";
const Scalar *a = reinterpret_cast<const Scalar *>(pa); const Scalar *a = reinterpret_cast<const Scalar *>(pa);
const Scalar *b = reinterpret_cast<const Scalar *>(pb); const Scalar *b = reinterpret_cast<const Scalar *>(pb);
Scalar *c = reinterpret_cast<Scalar *>(pc); Scalar *c = reinterpret_cast<Scalar *>(pc);
@ -472,8 +595,6 @@ EIGEN_BLAS_FUNC(symm)
EIGEN_BLAS_FUNC(syrk) EIGEN_BLAS_FUNC(syrk)
(const char *uplo, const char *op, const int *n, const int *k, const RealScalar *palpha, const RealScalar *pa, (const char *uplo, const char *op, const int *n, const int *k, const RealScalar *palpha, const RealScalar *pa,
const int *lda, const RealScalar *pbeta, RealScalar *pc, const int *ldc) { const int *lda, const RealScalar *pbeta, RealScalar *pc, const int *ldc) {
// std::cerr << "in syrk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " "
// << *pbeta << " " << *ldc << "\n";
using Eigen::ColMajor; using Eigen::ColMajor;
using Eigen::DenseIndex; using Eigen::DenseIndex;
using Eigen::Dynamic; using Eigen::Dynamic;
@ -577,9 +698,6 @@ EIGEN_BLAS_FUNC(syr2k)
Scalar alpha = *reinterpret_cast<const Scalar *>(palpha); Scalar alpha = *reinterpret_cast<const Scalar *>(palpha);
Scalar beta = *reinterpret_cast<const Scalar *>(pbeta); Scalar beta = *reinterpret_cast<const Scalar *>(pbeta);
// std::cerr << "in syr2k " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " "
// << *ldb << " " << beta << " " << *ldc << "\n";
int info = 0; int info = 0;
if (UPLO(*uplo) == INVALID) if (UPLO(*uplo) == INVALID)
info = 1; info = 1;
@ -647,9 +765,6 @@ EIGEN_BLAS_FUNC(hemm)
Scalar alpha = *reinterpret_cast<const Scalar *>(palpha); Scalar alpha = *reinterpret_cast<const Scalar *>(palpha);
Scalar beta = *reinterpret_cast<const Scalar *>(pbeta); Scalar beta = *reinterpret_cast<const Scalar *>(pbeta);
// std::cerr << "in hemm " << *side << " " << *uplo << " " << *m << " " << *n << " " << alpha << " " << *lda << " "
// << beta << " " << *ldc << "\n";
int info = 0; int info = 0;
if (SIDE(*side) == INVALID) if (SIDE(*side) == INVALID)
info = 1; info = 1;
@ -719,8 +834,6 @@ RowMajor,true,Conj, ColMajor, 1>
EIGEN_BLAS_FUNC(herk) EIGEN_BLAS_FUNC(herk)
(const char *uplo, const char *op, const int *n, const int *k, const RealScalar *palpha, const RealScalar *pa, (const char *uplo, const char *op, const int *n, const int *k, const RealScalar *palpha, const RealScalar *pa,
const int *lda, const RealScalar *pbeta, RealScalar *pc, const int *ldc) { const int *lda, const RealScalar *pbeta, RealScalar *pc, const int *ldc) {
// std::cerr << "in herk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " "
// << *pbeta << " " << *ldc << "\n";
using Eigen::ColMajor; using Eigen::ColMajor;
using Eigen::DenseIndex; using Eigen::DenseIndex;
using Eigen::Dynamic; using Eigen::Dynamic;
@ -754,9 +867,6 @@ EIGEN_BLAS_FUNC(herk)
RealScalar alpha = *palpha; RealScalar alpha = *palpha;
RealScalar beta = *pbeta; RealScalar beta = *pbeta;
// std::cerr << "in herk " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " <<
// beta << " " << *ldc << "\n";
int info = 0; int info = 0;
if (UPLO(*uplo) == INVALID) if (UPLO(*uplo) == INVALID)
info = 1; info = 1;
@ -810,9 +920,6 @@ EIGEN_BLAS_FUNC(her2k)
Scalar alpha = *reinterpret_cast<const Scalar *>(palpha); Scalar alpha = *reinterpret_cast<const Scalar *>(palpha);
RealScalar beta = *pbeta; RealScalar beta = *pbeta;
// std::cerr << "in her2k " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " "
// << *ldb << " " << beta << " " << *ldc << "\n";
int info = 0; int info = 0;
if (UPLO(*uplo) == INVALID) if (UPLO(*uplo) == INVALID)
info = 1; info = 1;
@ -875,4 +982,4 @@ EIGEN_BLAS_FUNC(her2k)
} }
} }
#endif // ISCOMPLEX #endif