Create the ability to disable the specialized gemm_pack_rhs in Eigen (only PPC) for TensorFlow

(cherry picked from commit 91e99ec1e02100d07e35a7abb1b5c76707237219)
This commit is contained in:
Chip Kerchner 2021-06-30 23:05:04 +00:00 committed by Rasmus Munk Larsen
parent 8190739f12
commit eebde572d9

View File

@ -11,6 +11,10 @@
#ifndef EIGEN_MATRIX_PRODUCT_ALTIVEC_H #ifndef EIGEN_MATRIX_PRODUCT_ALTIVEC_H
#define EIGEN_MATRIX_PRODUCT_ALTIVEC_H #define EIGEN_MATRIX_PRODUCT_ALTIVEC_H
#ifndef EIGEN_ALTIVEC_USE_CUSTOM_PACK
#define EIGEN_ALTIVEC_USE_CUSTOM_PACK 1
#endif
#include "MatrixProductCommon.h" #include "MatrixProductCommon.h"
// Since LLVM doesn't support dynamic dispatching, force either always MMA or VSX // Since LLVM doesn't support dynamic dispatching, force either always MMA or VSX
@ -2423,6 +2427,7 @@ void gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Co
pack(blockA, lhs, depth, rows, stride, offset); pack(blockA, lhs, depth, rows, stride, offset);
} }
#if EIGEN_ALTIVEC_USE_CUSTOM_PACK
template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode> template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<double, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> struct gemm_pack_rhs<double, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
{ {
@ -2450,6 +2455,7 @@ void gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode
dhs_pack<double, Index, DataMapper, Packet2d, RowMajor, PanelMode, false> pack; dhs_pack<double, Index, DataMapper, Packet2d, RowMajor, PanelMode, false> pack;
pack(blockB, rhs, depth, cols, stride, offset); pack(blockB, rhs, depth, cols, stride, offset);
} }
#endif
template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode> template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
@ -2478,6 +2484,7 @@ void gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Con
dhs_pack<float, Index, DataMapper, Packet4f, ColMajor, PanelMode, true> pack; dhs_pack<float, Index, DataMapper, Packet4f, ColMajor, PanelMode, true> pack;
pack(blockA, lhs, depth, rows, stride, offset); pack(blockA, lhs, depth, rows, stride, offset);
} }
template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode> template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> struct gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
{ {
@ -2506,6 +2513,7 @@ void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet,
pack(blockA, lhs, depth, rows, stride, offset); pack(blockA, lhs, depth, rows, stride, offset);
} }
#if EIGEN_ALTIVEC_USE_CUSTOM_PACK
template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode> template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> struct gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
{ {
@ -2533,6 +2541,7 @@ void gemm_pack_rhs<float, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
dhs_pack<float, Index, DataMapper, Packet4f, RowMajor, PanelMode, false> pack; dhs_pack<float, Index, DataMapper, Packet4f, RowMajor, PanelMode, false> pack;
pack(blockB, rhs, depth, cols, stride, offset); pack(blockB, rhs, depth, cols, stride, offset);
} }
#endif
template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode> template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> struct gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>