Remove AVX512VL dependency in trsm

This commit is contained in:
Shi, Brian 2022-04-14 11:35:26 -07:00
parent 07db964bde
commit fc1d888415
2 changed files with 14 additions and 13 deletions

View File

@ -190,9 +190,7 @@ using std::ptrdiff_t;
#include "src/Core/arch/SSE/MathFunctions.h"
#include "src/Core/arch/AVX/MathFunctions.h"
#include "src/Core/arch/AVX512/MathFunctions.h"
#ifdef __AVX512VL__
#include "src/Core/arch/AVX512/TrsmKernel.h"
#endif
#include "src/Core/arch/AVX512/TrsmKernel.h"
#elif defined EIGEN_VECTORIZE_AVX
// Use AVX for floats and doubles, SSE for integers
#include "src/Core/arch/SSE/PacketMath.h"

View File

@ -237,7 +237,7 @@ template<> struct unpacket_traits<Packet8f> {
typedef Packet8i integer_packet;
typedef uint8_t mask_t;
enum {size=8, alignment=Aligned32, vectorizable=true, masked_load_available=true, masked_store_available=true
#ifdef __AVX512VL__
#ifdef EIGEN_VECTORIZE_AVX512
, masked_fpops_available=true
#endif
};
@ -468,11 +468,14 @@ template<> EIGEN_STRONG_INLINE Packet8f pload1<Packet8f>(const float* from) { r
// Loads one double from `from` and replicates it into every lane of a Packet4d.
template<>
EIGEN_STRONG_INLINE Packet4d pload1<Packet4d>(const double* from) {
  const Packet4d broadcasted = _mm256_broadcast_sd(from);
  return broadcasted;
}
// Lane-wise sum of two 8-float packets.
template<>
EIGEN_STRONG_INLINE Packet8f padd<Packet8f>(const Packet8f& a, const Packet8f& b) {
  const Packet8f sum = _mm256_add_ps(a, b);
  return sum;
}
#ifdef __AVX512VL__
#ifdef EIGEN_VECTORIZE_AVX512
// Masked lane-wise add of two 8-float packets.
// The add intrinsics used below are the zero-masking ("maskz") forms: a result
// lane holds a[i] + b[i] when bit i of the mask is set and 0 otherwise.
//
// NOTE(review): two variants of the body appear back to back here (a 256-bit
// one and a 512-bit one) under two #ifdef lines but a single #endif. This looks
// like the before/after lines of a diff rather than one compilable definition —
// confirm which variant is live before relying on this text.
template <>
EIGEN_STRONG_INLINE Packet8f padd<Packet8f>(const Packet8f& a, const Packet8f& b, uint8_t umask) {
// 256-bit variant: operates directly on the 8-lane packets with an 8-bit mask.
__mmask8 mask = static_cast<__mmask8>(umask);
return _mm256_maskz_add_ps(mask, a, b);
// 512-bit variant: only the low 8 bits of the 16-bit mask can be set, so the
// upper 8 lanes of the widened operands never contribute to the result.
__mmask16 mask = static_cast<__mmask16>(umask & 0x00FF);
// Widen both operands to 512 bits, do the masked add there, then keep the low
// 256 bits as the Packet8f result.
return _mm512_castps512_ps256(_mm512_maskz_add_ps(
mask,
_mm512_castps256_ps512(a),
_mm512_castps256_ps512(b)));
}
#endif
// Lane-wise sum of two 4-double packets.
template<>
EIGEN_STRONG_INLINE Packet4d padd<Packet4d>(const Packet4d& a, const Packet4d& b) {
  const Packet4d sum = _mm256_add_pd(a, b);
  return sum;
}
@ -859,9 +862,9 @@ template<> EIGEN_STRONG_INLINE Packet4d ploadu<Packet4d>(const double* from) { E
// Unaligned load of a Packet8i from `from`.
template<>
EIGEN_STRONG_INLINE Packet8i ploadu<Packet8i>(const int* from) {
  // View the int buffer as a 256-bit integer vector for the load intrinsic.
  const __m256i* src = reinterpret_cast<const __m256i*>(from);
  EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256(src);
}
template<> EIGEN_STRONG_INLINE Packet8f ploadu<Packet8f>(const float* from, uint8_t umask) {
#ifdef __AVX512VL__
__mmask8 mask = static_cast<__mmask8>(umask);
EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_maskz_loadu_ps(mask, from);
#ifdef EIGEN_VECTORIZE_AVX512
__mmask16 mask = static_cast<__mmask16>(umask & 0x00FF);
EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_castps512_ps256(_mm512_maskz_loadu_ps(mask, from));
#else
Packet8i mask = _mm256_set1_epi8(static_cast<char>(umask));
const Packet8i bit_mask = _mm256_set_epi32(0xffffff7f, 0xffffffbf, 0xffffffdf, 0xffffffef, 0xfffffff7, 0xfffffffb, 0xfffffffd, 0xfffffffe);
@ -927,9 +930,9 @@ template<> EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet4d&
// Unaligned store of a Packet8i to `to`.
template<>
EIGEN_STRONG_INLINE void pstoreu<int>(int* to, const Packet8i& from) {
  // View the destination buffer as a 256-bit integer vector for the store intrinsic.
  __m256i* dst = reinterpret_cast<__m256i*>(to);
  EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256(dst, from);
}
template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet8f& from, uint8_t umask) {
#ifdef __AVX512VL__
__mmask8 mask = static_cast<__mmask8>(umask);
EIGEN_DEBUG_UNALIGNED_STORE return _mm256_mask_storeu_ps(to, mask, from);
#ifdef EIGEN_VECTORIZE_AVX512
__mmask16 mask = static_cast<__mmask16>(umask & 0x00FF);
EIGEN_DEBUG_UNALIGNED_STORE return _mm512_mask_storeu_ps(to, mask, _mm512_castps256_ps512(from));
#else
Packet8i mask = _mm256_set1_epi8(static_cast<char>(umask));
const Packet8i bit_mask = _mm256_set_epi32(0xffffff7f, 0xffffffbf, 0xffffffdf, 0xffffffef, 0xfffffff7, 0xfffffffb, 0xfffffffd, 0xfffffffe);