mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-07-23 21:34:30 +08:00
Add OpenBLAS sbgemm.
This commit is contained in:
parent
d228bcdf8f
commit
cdf6a1f5ed
@ -55,7 +55,7 @@ namespace internal {
|
|||||||
ConjugateRhs, ColMajor, 1> { \
|
ConjugateRhs, ColMajor, 1> { \
|
||||||
typedef gebp_traits<EIGTYPE, EIGTYPE> Traits; \
|
typedef gebp_traits<EIGTYPE, EIGTYPE> Traits; \
|
||||||
\
|
\
|
||||||
static void run(Index rows, Index cols, Index depth, const EIGTYPE* _lhs, Index lhsStride, const EIGTYPE* _rhs, \
|
static void run(Index rows, Index cols, Index depth, const EIGTYPE* lhs_, Index lhsStride, const EIGTYPE* rhs_, \
|
||||||
Index rhsStride, EIGTYPE* res, Index resIncr, Index resStride, EIGTYPE alpha, \
|
Index rhsStride, EIGTYPE* res, Index resIncr, Index resStride, EIGTYPE alpha, \
|
||||||
level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/, GemmParallelInfo<Index>* /*info = 0*/) { \
|
level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/, GemmParallelInfo<Index>* /*info = 0*/) { \
|
||||||
using std::conj; \
|
using std::conj; \
|
||||||
@ -84,20 +84,20 @@ namespace internal {
|
|||||||
\
|
\
|
||||||
/* Set a, b, c */ \
|
/* Set a, b, c */ \
|
||||||
if ((LhsStorageOrder == ColMajor) && (ConjugateLhs)) { \
|
if ((LhsStorageOrder == ColMajor) && (ConjugateLhs)) { \
|
||||||
Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > lhs(_lhs, m, k, OuterStride<>(lhsStride)); \
|
Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > lhs(lhs_, m, k, OuterStride<>(lhsStride)); \
|
||||||
a_tmp = lhs.conjugate(); \
|
a_tmp = lhs.conjugate(); \
|
||||||
a = a_tmp.data(); \
|
a = a_tmp.data(); \
|
||||||
lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
|
lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
|
||||||
} else \
|
} else \
|
||||||
a = _lhs; \
|
a = lhs_; \
|
||||||
\
|
\
|
||||||
if ((RhsStorageOrder == ColMajor) && (ConjugateRhs)) { \
|
if ((RhsStorageOrder == ColMajor) && (ConjugateRhs)) { \
|
||||||
Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > rhs(_rhs, k, n, OuterStride<>(rhsStride)); \
|
Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > rhs(rhs_, k, n, OuterStride<>(rhsStride)); \
|
||||||
b_tmp = rhs.conjugate(); \
|
b_tmp = rhs.conjugate(); \
|
||||||
b = b_tmp.data(); \
|
b = b_tmp.data(); \
|
||||||
ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
|
ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
|
||||||
} else \
|
} else \
|
||||||
b = _rhs; \
|
b = rhs_; \
|
||||||
\
|
\
|
||||||
BLASFUNC(&transa, &transb, &m, &n, &k, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, \
|
BLASFUNC(&transa, &transb, &m, &n, &k, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, \
|
||||||
(const BLASTYPE*)b, &ldb, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
|
(const BLASTYPE*)b, &ldb, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
|
||||||
@ -116,6 +116,88 @@ GEMM_SPECIALIZATION(dcomplex, cd, double, zgemm_)
|
|||||||
GEMM_SPECIALIZATION(scomplex, cf, float, cgemm_)
|
GEMM_SPECIALIZATION(scomplex, cf, float, cgemm_)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// If OpenBLAS with BUILD_BFLOAT16=1 support is available,
|
||||||
|
// use sbgemm for bfloat16.
|
||||||
|
#if EIGEN_USE_OPENBLAS_BFLOAT16
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
// OpenBLAS prototype.
|
||||||
|
void sbgemm_(const char* trans_a, const char* trans_b, const int* M, const int* N, const int* K, const float* alpha,
|
||||||
|
const Eigen::bfloat16* A, const int* lda, const Eigen::bfloat16* B, const int* ldb, const float* beta,
|
||||||
|
float* C, const int* ldc);
|
||||||
|
} // extern "C"
|
||||||
|
|
||||||
|
template <typename Index, int LhsStorageOrder, bool ConjugateLhs, int RhsStorageOrder, bool ConjugateRhs>
|
||||||
|
struct general_matrix_matrix_product<Index, Eigen::bfloat16, LhsStorageOrder, ConjugateLhs, Eigen::bfloat16,
|
||||||
|
RhsStorageOrder, ConjugateRhs, ColMajor, 1> {
|
||||||
|
typedef gebp_traits<Eigen::bfloat16, Eigen::bfloat16> Traits;
|
||||||
|
|
||||||
|
static void run(Index rows, Index cols, Index depth, const Eigen::bfloat16* lhs_, Index lhsStride,
|
||||||
|
const Eigen::bfloat16* rhs_, Index rhsStride, Eigen::bfloat16* res, Index resIncr, Index resStride,
|
||||||
|
Eigen::bfloat16 alpha, level3_blocking<Eigen::bfloat16, Eigen::bfloat16>& /*blocking*/,
|
||||||
|
GemmParallelInfo<Index>* /*info = 0*/) {
|
||||||
|
using std::conj;
|
||||||
|
if (rows == 0 || cols == 0 || depth == 0) return;
|
||||||
|
EIGEN_ONLY_USED_FOR_DEBUG(resIncr);
|
||||||
|
eigen_assert(resIncr == 1);
|
||||||
|
char transa, transb;
|
||||||
|
BlasIndex m, n, k, lda, ldb, ldc;
|
||||||
|
const Eigen::bfloat16 *a, *b;
|
||||||
|
|
||||||
|
float falpha = static_cast<float>(alpha);
|
||||||
|
float fbeta = float(1.0);
|
||||||
|
|
||||||
|
using MatrixXbf16 = Matrix<Eigen::bfloat16, Dynamic, Dynamic>;
|
||||||
|
MatrixXbf16 a_tmp, b_tmp;
|
||||||
|
MatrixXf r_tmp;
|
||||||
|
|
||||||
|
/* Set transpose options */
|
||||||
|
transa = (LhsStorageOrder == RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N';
|
||||||
|
transb = (RhsStorageOrder == RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N';
|
||||||
|
|
||||||
|
/* Set m, n, k */
|
||||||
|
m = convert_index<BlasIndex>(rows);
|
||||||
|
n = convert_index<BlasIndex>(cols);
|
||||||
|
k = convert_index<BlasIndex>(depth);
|
||||||
|
|
||||||
|
/* Set lda, ldb, ldc */
|
||||||
|
lda = convert_index<BlasIndex>(lhsStride);
|
||||||
|
ldb = convert_index<BlasIndex>(rhsStride);
|
||||||
|
ldc = convert_index<BlasIndex>(m);
|
||||||
|
|
||||||
|
/* Set a, b, c */
|
||||||
|
if ((LhsStorageOrder == ColMajor) && (ConjugateLhs)) {
|
||||||
|
Map<const MatrixXbf16, 0, OuterStride<> > lhs(lhs_, m, k, OuterStride<>(lhsStride));
|
||||||
|
a_tmp = lhs.conjugate();
|
||||||
|
a = a_tmp.data();
|
||||||
|
lda = convert_index<BlasIndex>(a_tmp.outerStride());
|
||||||
|
} else {
|
||||||
|
a = lhs_;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((RhsStorageOrder == ColMajor) && (ConjugateRhs)) {
|
||||||
|
Map<const MatrixXbf16, 0, OuterStride<> > rhs(rhs_, k, n, OuterStride<>(rhsStride));
|
||||||
|
b_tmp = rhs.conjugate();
|
||||||
|
b = b_tmp.data();
|
||||||
|
ldb = convert_index<BlasIndex>(b_tmp.outerStride());
|
||||||
|
} else {
|
||||||
|
b = rhs_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Evaluate to a temporary intermediate array.
|
||||||
|
r_tmp.resize(m, n);
|
||||||
|
|
||||||
|
sbgemm_(&transa, &transb, &m, &n, &k, (const float*)&numext::real_ref(falpha), a, &lda, b, &ldb,
|
||||||
|
(const float*)&numext::real_ref(fbeta), r_tmp.data(), &ldc);
|
||||||
|
|
||||||
|
// Cast to the output.
|
||||||
|
Map<MatrixXbf16, 0, OuterStride<> > result(res, m, n, OuterStride<>(resStride));
|
||||||
|
result = r_tmp.cast<Eigen::bfloat16>();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // EIGEN_USE_OPENBLAS_SBGEMM
|
||||||
|
|
||||||
} // namespace internal
|
} // namespace internal
|
||||||
|
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
Loading…
x
Reference in New Issue
Block a user