Split general_matrix_vector_product interface for Power into two macros - one ColMajor and RowMajor.

This commit is contained in:
Chip Kerchner 2022-03-23 18:09:33 +00:00 committed by Antonio Sánchez
parent 19a6a827c4
commit 0699fa06fe

View File

@ -1996,9 +1996,9 @@ EIGEN_STRONG_INLINE void gemv_row(
}
}
#define EIGEN_POWER_GEMV_REAL_SPECIALIZE(Scalar, Major) \
#define EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL(Scalar) \
template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
struct general_matrix_vector_product<Index, Scalar, LhsMapper, Major, ConjugateLhs, Scalar, RhsMapper, ConjugateRhs, Version> \
struct general_matrix_vector_product<Index, Scalar, LhsMapper, ColMajor, ConjugateLhs, Scalar, RhsMapper, ConjugateRhs, Version> \
{ \
typedef typename ScalarBinaryOpTraits<Scalar, Scalar>::ReturnType ResScalar; \
\
@ -2008,18 +2008,30 @@ struct general_matrix_vector_product<Index, Scalar, LhsMapper, Major, ConjugateL
const RhsMapper& rhs, \
ResScalar* res, Index resIncr, \
ResScalar alpha) { \
if (Major == ColMajor) { \
gemv_col<Index, Scalar, LhsMapper, Scalar, RhsMapper, ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
} else { \
gemv_row<Index, Scalar, LhsMapper, Scalar, RhsMapper, ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
} \
} \
};
EIGEN_POWER_GEMV_REAL_SPECIALIZE(float, ColMajor)
EIGEN_POWER_GEMV_REAL_SPECIALIZE(double, ColMajor)
EIGEN_POWER_GEMV_REAL_SPECIALIZE(float, RowMajor)
EIGEN_POWER_GEMV_REAL_SPECIALIZE(double, RowMajor)
#define EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(Scalar) \
template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
struct general_matrix_vector_product<Index, Scalar, LhsMapper, RowMajor, ConjugateLhs, Scalar, RhsMapper, ConjugateRhs, Version> \
{ \
typedef typename ScalarBinaryOpTraits<Scalar, Scalar>::ReturnType ResScalar; \
\
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run( \
Index rows, Index cols, \
const LhsMapper& lhs, \
const RhsMapper& rhs, \
ResScalar* res, Index resIncr, \
ResScalar alpha) { \
gemv_row<Index, Scalar, LhsMapper, Scalar, RhsMapper, ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
} \
};
EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL(float)
EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL(double)
EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(float)
EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(double)
template<typename ResScalar, typename PResPacket, typename ResPacket, typename LhsPacket, typename RhsPacket>
EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_complex(PResPacket& a0, PResPacket& b0, ResPacket& a1, ResPacket& b1)
@ -2311,9 +2323,9 @@ EIGEN_STRONG_INLINE void gemv_complex_row(
}
}
#define EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE(Scalar, LhsScalar, RhsScalar, Major) \
#define EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(Scalar, LhsScalar, RhsScalar) \
template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
struct general_matrix_vector_product<Index, LhsScalar, LhsMapper, Major, ConjugateLhs, RhsScalar, RhsMapper, ConjugateRhs, Version> \
struct general_matrix_vector_product<Index, LhsScalar, LhsMapper, ColMajor, ConjugateLhs, RhsScalar, RhsMapper, ConjugateRhs, Version> \
{ \
typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar; \
\
@ -2323,26 +2335,38 @@ struct general_matrix_vector_product<Index, LhsScalar, LhsMapper, Major, Conjuga
const RhsMapper& rhs, \
ResScalar* res, Index resIncr, \
ResScalar alpha) { \
if (Major == ColMajor) { \
gemv_complex_col<Index, Scalar, LhsScalar, LhsMapper, ConjugateLhs, sizeof(Scalar) == sizeof(LhsScalar), RhsScalar, RhsMapper, ConjugateRhs, sizeof(Scalar) == sizeof(RhsScalar), ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
} else { \
gemv_complex_row<Index, Scalar, LhsScalar, LhsMapper, ConjugateLhs, sizeof(Scalar) == sizeof(LhsScalar), RhsScalar, RhsMapper, ConjugateRhs, sizeof(Scalar) == sizeof(RhsScalar), ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
} \
} \
};
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE(float, float, std::complex<float>, ColMajor)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE(float, std::complex<float>, float, ColMajor)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE(float, std::complex<float>, std::complex<float>, ColMajor)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE(double, double, std::complex<double>, ColMajor)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE(double, std::complex<double>, double, ColMajor)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE(double, std::complex<double>, std::complex<double>, ColMajor)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE(float, float, std::complex<float>, RowMajor)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE(float, std::complex<float>, float, RowMajor)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE(float, std::complex<float>, std::complex<float>, RowMajor)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE(double, double, std::complex<double>, RowMajor)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE(double, std::complex<double>, double, RowMajor)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE(double, std::complex<double>, std::complex<double>, RowMajor)
#define EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(Scalar, LhsScalar, RhsScalar) \
template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
struct general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLhs, RhsScalar, RhsMapper, ConjugateRhs, Version> \
{ \
typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar; \
\
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run( \
Index rows, Index cols, \
const LhsMapper& lhs, \
const RhsMapper& rhs, \
ResScalar* res, Index resIncr, \
ResScalar alpha) { \
gemv_complex_row<Index, Scalar, LhsScalar, LhsMapper, ConjugateLhs, sizeof(Scalar) == sizeof(LhsScalar), RhsScalar, RhsMapper, ConjugateRhs, sizeof(Scalar) == sizeof(RhsScalar), ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
} \
};
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(float, float, std::complex<float>)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(float, std::complex<float>, float)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(float, std::complex<float>, std::complex<float>)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(double, double, std::complex<double>)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(double, std::complex<double>, double)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(double, std::complex<double>, std::complex<double>)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(float, float, std::complex<float>)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(float, std::complex<float>, float)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(float, std::complex<float>, std::complex<float>)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(double, double, std::complex<double>)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(double, std::complex<double>, double)
EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(double, std::complex<double>, std::complex<double>)
#endif // EIGEN_MATRIX_VECTOR_PRODUCT_ALTIVEC_H