Add OpenBLAS sbgemm.

This commit is contained in:
Antonio Sánchez 2025-06-16 18:23:03 +00:00 committed by Rasmus Munk Larsen
parent d228bcdf8f
commit cdf6a1f5ed

View File

@ -55,7 +55,7 @@ namespace internal {
ConjugateRhs, ColMajor, 1> { \ ConjugateRhs, ColMajor, 1> { \
typedef gebp_traits<EIGTYPE, EIGTYPE> Traits; \ typedef gebp_traits<EIGTYPE, EIGTYPE> Traits; \
\ \
static void run(Index rows, Index cols, Index depth, const EIGTYPE* _lhs, Index lhsStride, const EIGTYPE* _rhs, \ static void run(Index rows, Index cols, Index depth, const EIGTYPE* lhs_, Index lhsStride, const EIGTYPE* rhs_, \
Index rhsStride, EIGTYPE* res, Index resIncr, Index resStride, EIGTYPE alpha, \ Index rhsStride, EIGTYPE* res, Index resIncr, Index resStride, EIGTYPE alpha, \
level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/, GemmParallelInfo<Index>* /*info = 0*/) { \ level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/, GemmParallelInfo<Index>* /*info = 0*/) { \
using std::conj; \ using std::conj; \
@ -84,20 +84,20 @@ namespace internal {
\ \
/* Set a, b, c */ \ /* Set a, b, c */ \
if ((LhsStorageOrder == ColMajor) && (ConjugateLhs)) { \ if ((LhsStorageOrder == ColMajor) && (ConjugateLhs)) { \
Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > lhs(_lhs, m, k, OuterStride<>(lhsStride)); \ Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > lhs(lhs_, m, k, OuterStride<>(lhsStride)); \
a_tmp = lhs.conjugate(); \ a_tmp = lhs.conjugate(); \
a = a_tmp.data(); \ a = a_tmp.data(); \
lda = convert_index<BlasIndex>(a_tmp.outerStride()); \ lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
} else \ } else \
a = _lhs; \ a = lhs_; \
\ \
if ((RhsStorageOrder == ColMajor) && (ConjugateRhs)) { \ if ((RhsStorageOrder == ColMajor) && (ConjugateRhs)) { \
Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > rhs(_rhs, k, n, OuterStride<>(rhsStride)); \ Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > rhs(rhs_, k, n, OuterStride<>(rhsStride)); \
b_tmp = rhs.conjugate(); \ b_tmp = rhs.conjugate(); \
b = b_tmp.data(); \ b = b_tmp.data(); \
ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \ ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
} else \ } else \
b = _rhs; \ b = rhs_; \
\ \
BLASFUNC(&transa, &transb, &m, &n, &k, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, \ BLASFUNC(&transa, &transb, &m, &n, &k, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, \
(const BLASTYPE*)b, &ldb, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)res, &ldc); \ (const BLASTYPE*)b, &ldb, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
@ -116,6 +116,88 @@ GEMM_SPECIALIZATION(dcomplex, cd, double, zgemm_)
GEMM_SPECIALIZATION(scomplex, cf, float, cgemm_) GEMM_SPECIALIZATION(scomplex, cf, float, cgemm_)
#endif #endif
// If OpenBLAS with BUILD_BFLOAT16=1 support is available,
// use sbgemm for bfloat16.
#if EIGEN_USE_OPENBLAS_BFLOAT16
extern "C" {
// OpenBLAS prototype.
//
// Declared here by hand so the bfloat16 GEMM specialization below can call it.
// As the signature shows, every argument is passed by pointer, the A and B
// operands are bfloat16, and alpha, beta and the C matrix are float.
// NOTE(review): the integer arguments are declared as `const int*`; callers
// pass BlasIndex values, so this presumably relies on BlasIndex being `int` -
// confirm for ILP64 builds.
// NOTE(review): presumably follows the usual BLAS GEMM contract
// C = alpha * op(A) * op(B) + beta * C - confirm against the OpenBLAS docs.
void sbgemm_(const char* trans_a, const char* trans_b, const int* M, const int* N, const int* K, const float* alpha,
const Eigen::bfloat16* A, const int* lda, const Eigen::bfloat16* B, const int* ldb, const float* beta,
float* C, const int* ldc);
} // extern "C"
/**
 * GEMM specialization for bfloat16 operands and a column-major bfloat16 result,
 * forwarding the product to sbgemm_.
 *
 * sbgemm_ consumes bfloat16 inputs but writes a float result, so the product is
 * accumulated in a float temporary and converted back to bfloat16 once at the
 * end.
 */
template <typename Index, int LhsStorageOrder, bool ConjugateLhs, int RhsStorageOrder, bool ConjugateRhs>
struct general_matrix_matrix_product<Index, Eigen::bfloat16, LhsStorageOrder, ConjugateLhs, Eigen::bfloat16,
                                     RhsStorageOrder, ConjugateRhs, ColMajor, 1> {
  typedef gebp_traits<Eigen::bfloat16, Eigen::bfloat16> Traits;

  /**
   * Updates the rows x cols matrix at `res` with alpha * op(lhs) * op(rhs),
   * added on top of the values already stored in `res`.
   *
   * \param rows, cols, depth  Product dimensions; nothing is done if any is zero.
   * \param lhs_, lhsStride    Left operand data and its outer stride.
   * \param rhs_, rhsStride    Right operand data and its outer stride.
   * \param res, resStride     Column-major result data and its outer stride.
   * \param resIncr            Inner increment of the result; must be 1.
   * \param alpha              Scaling factor applied to the product.
   *
   * The blocking and parallel-info arguments are unused by this backend.
   */
  static void run(Index rows, Index cols, Index depth, const Eigen::bfloat16* lhs_, Index lhsStride,
                  const Eigen::bfloat16* rhs_, Index rhsStride, Eigen::bfloat16* res, Index resIncr, Index resStride,
                  Eigen::bfloat16 alpha, level3_blocking<Eigen::bfloat16, Eigen::bfloat16>& /*blocking*/,
                  GemmParallelInfo<Index>* /*info = 0*/) {
    if (rows == 0 || cols == 0 || depth == 0) return;
    EIGEN_ONLY_USED_FOR_DEBUG(resIncr);
    eigen_assert(resIncr == 1);

    using MatrixXbf16 = Matrix<Eigen::bfloat16, Dynamic, Dynamic>;

    char transa, transb;
    BlasIndex m, n, k, lda, ldb, ldc;
    const Eigen::bfloat16 *a, *b;
    const float falpha = static_cast<float>(alpha);
    const float fbeta = 1.0f;
    MatrixXbf16 a_tmp, b_tmp;
    MatrixXf r_tmp;

    /* Set transpose options */
    transa = (LhsStorageOrder == RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N';
    transb = (RhsStorageOrder == RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N';

    /* Set m, n, k */
    m = convert_index<BlasIndex>(rows);
    n = convert_index<BlasIndex>(cols);
    k = convert_index<BlasIndex>(depth);

    /* Set lda, ldb, ldc */
    lda = convert_index<BlasIndex>(lhsStride);
    ldb = convert_index<BlasIndex>(rhsStride);
    /* C is the dense m x n float temporary, not res, so its leading dimension is m. */
    ldc = convert_index<BlasIndex>(m);

    /* Set a, b: a column-major conjugated operand is materialized into a temporary. */
    if ((LhsStorageOrder == ColMajor) && (ConjugateLhs)) {
      Map<const MatrixXbf16, 0, OuterStride<> > lhs(lhs_, m, k, OuterStride<>(lhsStride));
      a_tmp = lhs.conjugate();
      a = a_tmp.data();
      lda = convert_index<BlasIndex>(a_tmp.outerStride());
    } else {
      a = lhs_;
    }

    if ((RhsStorageOrder == ColMajor) && (ConjugateRhs)) {
      Map<const MatrixXbf16, 0, OuterStride<> > rhs(rhs_, k, n, OuterStride<>(rhsStride));
      b_tmp = rhs.conjugate();
      b = b_tmp.data();
      ldb = convert_index<BlasIndex>(b_tmp.outerStride());
    } else {
      b = rhs_;
    }

    Map<MatrixXbf16, 0, OuterStride<> > result(res, m, n, OuterStride<>(resStride));

    // beta is 1, so sbgemm_ reads C before writing it. Seed the float
    // temporary with the current contents of res: this avoids handing
    // uninitialized memory to BLAS and keeps the existing values of res in
    // the result instead of silently discarding them.
    r_tmp = result.cast<float>();

    sbgemm_(&transa, &transb, &m, &n, &k, &falpha, a, &lda, b, &ldb, &fbeta, r_tmp.data(), &ldc);

    // Cast to the output.
    result = r_tmp.cast<Eigen::bfloat16>();
  }
};
#endif // EIGEN_USE_OPENBLAS_BFLOAT16
} // namespace internal } // namespace internal
} // end namespace Eigen } // end namespace Eigen