mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-09-09 15:53:16 +08:00
Add gemmtr implementation.
This commit is contained in:
parent
f426eff949
commit
e5f3fa2d61
13
blas/blas.h
13
blas/blas.h
@ -464,6 +464,19 @@ void BLASFUNC(zher2m)(const char *, const char *, const char *, const int *, con
|
||||
void BLASFUNC(xher2m)(const char *, const char *, const char *, const int *, const int *, const double *,
|
||||
const double *, const int *, const double *, const int *, const double *, double *, const int *);
|
||||
|
||||
/* GEMMTR prototypes, one per scalar prefix (s, d, q, c, z, x).
 * Argument order, matching the implementation's parameter list:
 *   uplo, opa, opb, n, k, alpha, a, lda, b, ldb, beta, c, ldc.
 * All arguments are passed by pointer; the complex variants (c, z, x) are
 * declared with pointers to the underlying real type (float / double). */
void BLASFUNC(sgemmtr)(const char *, const char *, const char *, const int *, const int *, const float *, const float *,
|
||||
const int *, const float *, const int *, const float *, float *, const int *);
|
||||
void BLASFUNC(dgemmtr)(const char *, const char *, const char *, const int *, const int *, const double *,
|
||||
const double *, const int *, const double *, const int *, const double *, double *, const int *);
|
||||
void BLASFUNC(qgemmtr)(const char *, const char *, const char *, const int *, const int *, const double *,
|
||||
const double *, const int *, const double *, const int *, const double *, double *, const int *);
|
||||
void BLASFUNC(cgemmtr)(const char *, const char *, const char *, const int *, const int *, const float *, const float *,
|
||||
const int *, const float *, const int *, const float *, float *, const int *);
|
||||
void BLASFUNC(zgemmtr)(const char *, const char *, const char *, const int *, const int *, const double *,
|
||||
const double *, const int *, const double *, const int *, const double *, double *, const int *);
|
||||
void BLASFUNC(xgemmtr)(const char *, const char *, const char *, const int *, const int *, const double *,
|
||||
const double *, const int *, const double *, const int *, const double *, double *, const int *);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
@ -6,15 +6,12 @@
|
||||
// This Source Code Form is subject to the terms of the Mozilla
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
#include <iostream>
|
||||
#include "common.h"
|
||||
|
||||
EIGEN_BLAS_FUNC(gemm)
|
||||
(const char *opa, const char *opb, const int *m, const int *n, const int *k, const RealScalar *palpha,
|
||||
const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc,
|
||||
const int *ldc) {
|
||||
// std::cerr << "in gemm " << *opa << " " << *opb << " " << *m << " " << *n << " " << *k << " " << *lda << " " <<
|
||||
// *ldb << " " << *ldc << " " << *palpha << " " << *pbeta << "\n";
|
||||
using Eigen::ColMajor;
|
||||
using Eigen::DenseIndex;
|
||||
using Eigen::Dynamic;
|
||||
@ -97,11 +94,141 @@ EIGEN_BLAS_FUNC(gemm)
|
||||
func[code](*m, *n, *k, a, *lda, b, *ldb, c, 1, *ldc, alpha, blocking, 0);
|
||||
}
|
||||
|
||||
// GEMMTR: triangular part of a general matrix-matrix product.
//
//   C := alpha * op(A) * op(B) + beta * C
//
// Only the triangle of the n-by-n matrix C selected by `uplo` is scaled and updated.
//
// Arguments (all passed by pointer):
//   uplo      selects the upper or lower triangle of C.
//   opa, opb  select the operation applied to A and B respectively (none / transpose / adjoint).
//   n         order of C.
//   k         inner dimension of the product.
//   palpha    scalar alpha.
//   pa, lda   matrix A and its leading dimension (at least max(1, n) if opa is NOTR, else max(1, k)).
//   pb, ldb   matrix B and its leading dimension (at least max(1, k) if opb is NOTR, else max(1, n)).
//   pbeta     scalar beta.
//   pc, ldc   matrix C and its leading dimension (at least max(1, n)).
//
// On an invalid argument, xerbla_ is called with the 1-based position of the offending argument
// and C is left untouched.
EIGEN_BLAS_FUNC(gemmtr)
(const char *uplo, const char *opa, const char *opb, const int *n, const int *k, const RealScalar *palpha,
 const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc,
 const int *ldc) {
  using Eigen::ColMajor;
  using Eigen::DenseIndex;
  using Eigen::Dynamic;
  using Eigen::Lower;
  using Eigen::RowMajor;
  using Eigen::Upper;
  typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *,
                           DenseIndex, DenseIndex, const Scalar &, Eigen::internal::level3_blocking<Scalar, Scalar> &);

  // Dispatch table indexed by (opa | (opb << 2)) + 12 * uplo.
  // Each operation code occupies two bits but only three values are used (see the index comments
  // below), so every fourth slot is a null placeholder that is never selected for validated input.
  // A transposed operand is expressed as a RowMajor view of the same storage; an adjoint operand
  // additionally sets the conjugation flag.
  static const functype func[24] = {
      // Upper-triangular result (offset 0).
      // array index: NOTR | (NOTR << 2)
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, ColMajor, false, Scalar, ColMajor,
                                                                 false, ColMajor, 1, Upper>::run),
      // array index: TR | (NOTR << 2)
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, RowMajor, false, Scalar, ColMajor,
                                                                 false, ColMajor, 1, Upper>::run),
      // array index: ADJ | (NOTR << 2)
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, RowMajor, Conj, Scalar, ColMajor,
                                                                 false, ColMajor, 1, Upper>::run),
      0,  // unused slot
      // array index: NOTR | (TR << 2)
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, ColMajor, false, Scalar, RowMajor,
                                                                 false, ColMajor, 1, Upper>::run),
      // array index: TR | (TR << 2)
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, RowMajor, false, Scalar, RowMajor,
                                                                 false, ColMajor, 1, Upper>::run),
      // array index: ADJ | (TR << 2)
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, RowMajor, Conj, Scalar, RowMajor,
                                                                 false, ColMajor, 1, Upper>::run),
      0,  // unused slot
      // array index: NOTR | (ADJ << 2)
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, ColMajor, false, Scalar, RowMajor,
                                                                 Conj, ColMajor, 1, Upper>::run),
      // array index: TR | (ADJ << 2)
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, RowMajor, false, Scalar, RowMajor,
                                                                 Conj, ColMajor, 1, Upper>::run),
      // array index: ADJ | (ADJ << 2)
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, RowMajor, Conj, Scalar, RowMajor,
                                                                 Conj, ColMajor, 1, Upper>::run),
      0,  // unused slot

      // Lower-triangular result (offset 12).
      // array index: 12 + (NOTR | (NOTR << 2))
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, ColMajor, false, Scalar, ColMajor,
                                                                 false, ColMajor, 1, Lower>::run),
      // array index: 12 + (TR | (NOTR << 2))
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, RowMajor, false, Scalar, ColMajor,
                                                                 false, ColMajor, 1, Lower>::run),
      // array index: 12 + (ADJ | (NOTR << 2))
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, RowMajor, Conj, Scalar, ColMajor,
                                                                 false, ColMajor, 1, Lower>::run),
      0,  // unused slot
      // array index: 12 + (NOTR | (TR << 2))
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, ColMajor, false, Scalar, RowMajor,
                                                                 false, ColMajor, 1, Lower>::run),
      // array index: 12 + (TR | (TR << 2))
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, RowMajor, false, Scalar, RowMajor,
                                                                 false, ColMajor, 1, Lower>::run),
      // array index: 12 + (ADJ | (TR << 2))
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, RowMajor, Conj, Scalar, RowMajor,
                                                                 false, ColMajor, 1, Lower>::run),
      0,  // unused slot
      // array index: 12 + (NOTR | (ADJ << 2))
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, ColMajor, false, Scalar, RowMajor,
                                                                 Conj, ColMajor, 1, Lower>::run),
      // array index: 12 + (TR | (ADJ << 2))
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, RowMajor, false, Scalar, RowMajor,
                                                                 Conj, ColMajor, 1, Lower>::run),
      // array index: 12 + (ADJ | (ADJ << 2))
      (Eigen::internal::general_matrix_matrix_triangular_product<DenseIndex, Scalar, RowMajor, Conj, Scalar, RowMajor,
                                                                 Conj, ColMajor, 1, Lower>::run),
      0,  // unused slot
  };

  const Scalar *a = reinterpret_cast<const Scalar *>(pa);
  const Scalar *b = reinterpret_cast<const Scalar *>(pb);
  Scalar *c = reinterpret_cast<Scalar *>(pc);
  Scalar alpha = *reinterpret_cast<const Scalar *>(palpha);
  Scalar beta = *reinterpret_cast<const Scalar *>(pbeta);

  // Argument validation: `info` is the 1-based position of the first invalid argument.
  int info = 0;
  if (UPLO(*uplo) == INVALID)
    info = 1;
  else if (OP(*opa) == INVALID)
    info = 2;
  else if (OP(*opb) == INVALID)
    info = 3;
  else if (*n < 0)
    info = 4;
  else if (*k < 0)
    info = 5;
  else if (*lda < std::max(1, (OP(*opa) == NOTR) ? *n : *k))
    info = 8;
  else if (*ldb < std::max(1, (OP(*opb) == NOTR) ? *k : *n))
    info = 10;
  else if (*ldc < std::max(1, *n))
    info = 13;
  if (info) return xerbla_(SCALAR_SUFFIX_UP "GEMMTR ", &info);

  // Nothing to do for an empty result.
  if (*n == 0) return;

  const int upper = (UPLO(*uplo) == UP);

  // Apply beta to the selected triangle only. beta == 0 overwrites instead of multiplying,
  // so whatever was previously stored in that triangle does not influence the result.
  if (beta != Scalar(1)) {
    if (beta == Scalar(0)) {
      if (upper) {
        matrix(c, *n, *n, *ldc).triangularView<Eigen::Upper>().setZero();
      } else {
        matrix(c, *n, *n, *ldc).triangularView<Eigen::Lower>().setZero();
      }
    } else {
      if (upper) {
        matrix(c, *n, *n, *ldc).triangularView<Eigen::Upper>() *= beta;
      } else {
        matrix(c, *n, *n, *ldc).triangularView<Eigen::Lower>() *= beta;
      }
    }
  }

  // With an empty inner dimension the product contributes nothing.
  if (*k == 0) return;

  Eigen::internal::gemm_blocking_space<ColMajor, Scalar, Scalar, Dynamic, Dynamic, Dynamic> blocking(*n, *n, *k, 1,
                                                                                                     false);

  // Explicit parentheses: `+` binds tighter than `|`, so the unparenthesized form only produced
  // the intended index because the offsets 12, 16 and 20 happen to have their two low bits clear.
  const int code = (OP(*opa) | (OP(*opb) << 2)) + (UPLO(*uplo) * 12);
  func[code](*n, *k, a, *lda, b, *ldb, c, 1, *ldc, alpha, blocking);
}
|
||||
|
||||
EIGEN_BLAS_FUNC(trsm)
|
||||
(const char *side, const char *uplo, const char *opa, const char *diag, const int *m, const int *n,
|
||||
const RealScalar *palpha, const RealScalar *pa, const int *lda, RealScalar *pb, const int *ldb) {
|
||||
// std::cerr << "in trsm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << "," << *n << " "
|
||||
// << *palpha << " " << *lda << " " << *ldb<< "\n";
|
||||
using Eigen::ColMajor;
|
||||
using Eigen::DenseIndex;
|
||||
using Eigen::Dynamic;
|
||||
@ -240,8 +367,6 @@ EIGEN_BLAS_FUNC(trsm)
|
||||
EIGEN_BLAS_FUNC(trmm)
|
||||
(const char *side, const char *uplo, const char *opa, const char *diag, const int *m, const int *n,
|
||||
const RealScalar *palpha, const RealScalar *pa, const int *lda, RealScalar *pb, const int *ldb) {
|
||||
// std::cerr << "in trmm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << " " << *n << " "
|
||||
// << *lda << " " << *ldb << " " << *palpha << "\n";
|
||||
using Eigen::ColMajor;
|
||||
using Eigen::DenseIndex;
|
||||
using Eigen::Dynamic;
|
||||
@ -381,8 +506,6 @@ EIGEN_BLAS_FUNC(trmm)
|
||||
EIGEN_BLAS_FUNC(symm)
|
||||
(const char *side, const char *uplo, const int *m, const int *n, const RealScalar *palpha, const RealScalar *pa,
|
||||
const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc) {
|
||||
// std::cerr << "in symm " << *side << " " << *uplo << " " << *m << "x" << *n << " lda:" << *lda << " ldb:" << *ldb
|
||||
// << " ldc:" << *ldc << " alpha:" << *palpha << " beta:" << *pbeta << "\n";
|
||||
const Scalar *a = reinterpret_cast<const Scalar *>(pa);
|
||||
const Scalar *b = reinterpret_cast<const Scalar *>(pb);
|
||||
Scalar *c = reinterpret_cast<Scalar *>(pc);
|
||||
@ -472,8 +595,6 @@ EIGEN_BLAS_FUNC(symm)
|
||||
EIGEN_BLAS_FUNC(syrk)
|
||||
(const char *uplo, const char *op, const int *n, const int *k, const RealScalar *palpha, const RealScalar *pa,
|
||||
const int *lda, const RealScalar *pbeta, RealScalar *pc, const int *ldc) {
|
||||
// std::cerr << "in syrk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " "
|
||||
// << *pbeta << " " << *ldc << "\n";
|
||||
using Eigen::ColMajor;
|
||||
using Eigen::DenseIndex;
|
||||
using Eigen::Dynamic;
|
||||
@ -577,9 +698,6 @@ EIGEN_BLAS_FUNC(syr2k)
|
||||
Scalar alpha = *reinterpret_cast<const Scalar *>(palpha);
|
||||
Scalar beta = *reinterpret_cast<const Scalar *>(pbeta);
|
||||
|
||||
// std::cerr << "in syr2k " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " "
|
||||
// << *ldb << " " << beta << " " << *ldc << "\n";
|
||||
|
||||
int info = 0;
|
||||
if (UPLO(*uplo) == INVALID)
|
||||
info = 1;
|
||||
@ -647,9 +765,6 @@ EIGEN_BLAS_FUNC(hemm)
|
||||
Scalar alpha = *reinterpret_cast<const Scalar *>(palpha);
|
||||
Scalar beta = *reinterpret_cast<const Scalar *>(pbeta);
|
||||
|
||||
// std::cerr << "in hemm " << *side << " " << *uplo << " " << *m << " " << *n << " " << alpha << " " << *lda << " "
|
||||
// << beta << " " << *ldc << "\n";
|
||||
|
||||
int info = 0;
|
||||
if (SIDE(*side) == INVALID)
|
||||
info = 1;
|
||||
@ -719,8 +834,6 @@ RowMajor,true,Conj, ColMajor, 1>
|
||||
EIGEN_BLAS_FUNC(herk)
|
||||
(const char *uplo, const char *op, const int *n, const int *k, const RealScalar *palpha, const RealScalar *pa,
|
||||
const int *lda, const RealScalar *pbeta, RealScalar *pc, const int *ldc) {
|
||||
// std::cerr << "in herk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " "
|
||||
// << *pbeta << " " << *ldc << "\n";
|
||||
using Eigen::ColMajor;
|
||||
using Eigen::DenseIndex;
|
||||
using Eigen::Dynamic;
|
||||
@ -754,9 +867,6 @@ EIGEN_BLAS_FUNC(herk)
|
||||
RealScalar alpha = *palpha;
|
||||
RealScalar beta = *pbeta;
|
||||
|
||||
// std::cerr << "in herk " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " <<
|
||||
// beta << " " << *ldc << "\n";
|
||||
|
||||
int info = 0;
|
||||
if (UPLO(*uplo) == INVALID)
|
||||
info = 1;
|
||||
@ -810,9 +920,6 @@ EIGEN_BLAS_FUNC(her2k)
|
||||
Scalar alpha = *reinterpret_cast<const Scalar *>(palpha);
|
||||
RealScalar beta = *pbeta;
|
||||
|
||||
// std::cerr << "in her2k " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " "
|
||||
// << *ldb << " " << beta << " " << *ldc << "\n";
|
||||
|
||||
int info = 0;
|
||||
if (UPLO(*uplo) == INVALID)
|
||||
info = 1;
|
||||
@ -875,4 +982,4 @@ EIGEN_BLAS_FUNC(her2k)
|
||||
}
|
||||
}
|
||||
|
||||
#endif // ISCOMPLEX
|
||||
#endif
|
||||
|
Loading…
x
Reference in New Issue
Block a user