From cdf6a1f5ed2e4937dbd9b6f86789af219125e11d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20S=C3=A1nchez?= Date: Mon, 16 Jun 2025 18:23:03 +0000 Subject: [PATCH] Add OpenBLAS sbgemm. --- .../Core/products/GeneralMatrixMatrix_BLAS.h | 92 ++++++++++++++++++- 1 file changed, 87 insertions(+), 5 deletions(-) diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h b/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h index 56743da3b..913beb696 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h @@ -55,7 +55,7 @@ namespace internal { ConjugateRhs, ColMajor, 1> { \ typedef gebp_traits Traits; \ \ - static void run(Index rows, Index cols, Index depth, const EIGTYPE* _lhs, Index lhsStride, const EIGTYPE* _rhs, \ + static void run(Index rows, Index cols, Index depth, const EIGTYPE* lhs_, Index lhsStride, const EIGTYPE* rhs_, \ Index rhsStride, EIGTYPE* res, Index resIncr, Index resStride, EIGTYPE alpha, \ level3_blocking& /*blocking*/, GemmParallelInfo* /*info = 0*/) { \ using std::conj; \ @@ -84,20 +84,20 @@ namespace internal { \ /* Set a, b, c */ \ if ((LhsStorageOrder == ColMajor) && (ConjugateLhs)) { \ - Map > lhs(_lhs, m, k, OuterStride<>(lhsStride)); \ + Map > lhs(lhs_, m, k, OuterStride<>(lhsStride)); \ a_tmp = lhs.conjugate(); \ a = a_tmp.data(); \ lda = convert_index(a_tmp.outerStride()); \ } else \ - a = _lhs; \ + a = lhs_; \ \ if ((RhsStorageOrder == ColMajor) && (ConjugateRhs)) { \ - Map > rhs(_rhs, k, n, OuterStride<>(rhsStride)); \ + Map > rhs(rhs_, k, n, OuterStride<>(rhsStride)); \ b_tmp = rhs.conjugate(); \ b = b_tmp.data(); \ ldb = convert_index(b_tmp.outerStride()); \ } else \ - b = _rhs; \ + b = rhs_; \ \ BLASFUNC(&transa, &transb, &m, &n, &k, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, \ (const BLASTYPE*)b, &ldb, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)res, &ldc); \ @@ -116,6 +116,88 @@ GEMM_SPECIALIZATION(dcomplex, cd, double, zgemm_) GEMM_SPECIALIZATION(scomplex, cf, float, cgemm_) #endif +// If OpenBLAS with BUILD_BFLOAT16=1 support is available, +// use sbgemm for bfloat16. +#if EIGEN_USE_OPENBLAS_BFLOAT16 + +extern "C" { +// OpenBLAS prototype. +void sbgemm_(const char* trans_a, const char* trans_b, const int* M, const int* N, const int* K, const float* alpha, + const Eigen::bfloat16* A, const int* lda, const Eigen::bfloat16* B, const int* ldb, const float* beta, + float* C, const int* ldc); +} // extern "C" + +template +struct general_matrix_matrix_product { + typedef gebp_traits Traits; + + static void run(Index rows, Index cols, Index depth, const Eigen::bfloat16* lhs_, Index lhsStride, + const Eigen::bfloat16* rhs_, Index rhsStride, Eigen::bfloat16* res, Index resIncr, Index resStride, + Eigen::bfloat16 alpha, level3_blocking& /*blocking*/, + GemmParallelInfo* /*info = 0*/) { + using std::conj; + if (rows == 0 || cols == 0 || depth == 0) return; + EIGEN_ONLY_USED_FOR_DEBUG(resIncr); + eigen_assert(resIncr == 1); + char transa, transb; + BlasIndex m, n, k, lda, ldb, ldc; + const Eigen::bfloat16 *a, *b; + + float falpha = static_cast(alpha); + float fbeta = float(1.0); + + using MatrixXbf16 = Matrix; + MatrixXbf16 a_tmp, b_tmp; + MatrixXf r_tmp; + + /* Set transpose options */ + transa = (LhsStorageOrder == RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; + transb = (RhsStorageOrder == RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; + + /* Set m, n, k */ + m = convert_index(rows); + n = convert_index(cols); + k = convert_index(depth); + + /* Set lda, ldb, ldc */ + lda = convert_index(lhsStride); + ldb = convert_index(rhsStride); + ldc = convert_index(m); + + /* Set a, b, c */ + if ((LhsStorageOrder == ColMajor) && (ConjugateLhs)) { + Map > lhs(lhs_, m, k, OuterStride<>(lhsStride)); + a_tmp = lhs.conjugate(); + a = a_tmp.data(); + lda = convert_index(a_tmp.outerStride()); + } else { + a = lhs_; + } + + if ((RhsStorageOrder == ColMajor) && (ConjugateRhs)) { + Map > rhs(rhs_, k, n, OuterStride<>(rhsStride)); + b_tmp = rhs.conjugate(); + b = b_tmp.data(); + ldb = convert_index(b_tmp.outerStride()); + } else { + b = rhs_; + } + + // Evaluate to a temporary intermediate array. + r_tmp.resize(m, n); + + sbgemm_(&transa, &transb, &m, &n, &k, (const float*)&numext::real_ref(falpha), a, &lda, b, &ldb, + (const float*)&numext::real_ref(fbeta), r_tmp.data(), &ldc); + + // Cast to the output. + Map > result(res, m, n, OuterStride<>(resStride)); + result = r_tmp.cast(); + } +}; + +#endif // EIGEN_USE_OPENBLAS_SBGEMM + } // namespace internal } // end namespace Eigen