From 1615a2799384a2964d01ba77fe98e3f6fcc412f4 Mon Sep 17 00:00:00 2001
From: Antonio Sanchez
Date: Tue, 26 Jan 2021 10:23:23 -0800
Subject: [PATCH] Fix altivec packetmath.

Allows the altivec packetmath tests to pass. There were a few issues:
- `pstoreu` was missing MSQ on `_BIG_ENDIAN` systems
- `cmp_*` didn't properly handle conversion of bool flags (0x7FC instead of 0xFFFF)
- `pfrexp` needed to set the `exponent` argument.

Related to !370, #2128

cc: @ChipKerchner @pdrocaldeira

Tested on `_BIG_ENDIAN` running on QEMU with VSX. Couldn't figure out build flags to get it to work for little endian.
---
 Eigen/src/Core/arch/AltiVec/PacketMath.h | 54 ++++++++++++++++++------
 1 file changed, 41 insertions(+), 13 deletions(-)

diff --git a/Eigen/src/Core/arch/AltiVec/PacketMath.h b/Eigen/src/Core/arch/AltiVec/PacketMath.h
index 6d7842021..fdf4f1e9c 100755
--- a/Eigen/src/Core/arch/AltiVec/PacketMath.h
+++ b/Eigen/src/Core/arch/AltiVec/PacketMath.h
@@ -1056,6 +1056,7 @@ template<typename Packet> EIGEN_STRONG_INLINE void pstoreu_common(__UNPACK_TYPE_
   MSQ = vec_perm(edges,(Packet16uc)from,align); // misalign the data (MSQ)
   LSQ = vec_perm((Packet16uc)from,edges,align); // misalign the data (LSQ)
   vec_st( LSQ, 15, (unsigned char *)to ); // Store the LSQ part first
+  vec_st( MSQ, 0, (unsigned char *)to ); // Store the MSQ part second
 #else
   vec_xst(from, 0, to);
 #endif
@@ -1209,6 +1210,16 @@ EIGEN_STRONG_INLINE Packet4f Bf16ToF32Odd(const Packet8bf& bf){
   );
 }
 
+// Simple interleaving of bool masks, prevents true values from being
+// converted to NaNs.
+EIGEN_STRONG_INLINE Packet8bf F32ToBf16Bool(Packet4f even, Packet4f odd) {
+  const _EIGEN_DECLARE_CONST_FAST_Packet4ui(high_mask, 0xFFFF0000);
+  Packet4f bf_odd, bf_even;
+  bf_odd = pand(reinterpret_cast<Packet4f>(p4ui_high_mask), odd);
+  bf_even = plogical_shift_right<16>(even);
+  return reinterpret_cast<Packet8us>(por<Packet4f>(bf_even, bf_odd));
+}
+
 EIGEN_STRONG_INLINE Packet8bf F32ToBf16(Packet4f p4f){
   Packet4ui input = reinterpret_cast<Packet4ui>(p4f);
   Packet4ui lsb = plogical_shift_right<16>(input);
@@ -1272,6 +1283,15 @@ EIGEN_STRONG_INLINE Packet8bf F32ToBf16(Packet4f even, Packet4f odd){
   Packet4f op_odd = OP(a_odd, b_odd);\
   return F32ToBf16(op_even, op_odd);\
 
+#define BF16_TO_F32_BINARY_OP_WRAPPER_BOOL(OP, A, B) \
+  Packet4f a_even = Bf16ToF32Even(A);\
+  Packet4f a_odd = Bf16ToF32Odd(A);\
+  Packet4f b_even = Bf16ToF32Even(B);\
+  Packet4f b_odd = Bf16ToF32Odd(B);\
+  Packet4f op_even = OP(a_even, b_even);\
+  Packet4f op_odd = OP(a_odd, b_odd);\
+  return F32ToBf16Bool(op_even, op_odd);\
+
 template<> EIGEN_STRONG_INLINE Packet8bf padd(const Packet8bf& a, const Packet8bf& b) {
   BF16_TO_F32_BINARY_OP_WRAPPER(padd, a, b);
 }
@@ -1301,12 +1321,28 @@ template<> EIGEN_STRONG_INLINE Packet8bf prsqrt<Packet8bf> (const Packet8bf& a){
 template<> EIGEN_STRONG_INLINE Packet8bf pexp<Packet8bf> (const Packet8bf& a){
   BF16_TO_F32_UNARY_OP_WRAPPER(pexp_float, a);
 }
+
+template<> EIGEN_STRONG_INLINE Packet4f pldexp(const Packet4f& a, const Packet4f& exponent) {
+  return pldexp_float(a,exponent);
+}
 template<> EIGEN_STRONG_INLINE Packet8bf pldexp<Packet8bf> (const Packet8bf& a, const Packet8bf& exponent){
   BF16_TO_F32_BINARY_OP_WRAPPER(pldexp_float, a, exponent);
 }
-template<> EIGEN_STRONG_INLINE Packet8bf pfrexp<Packet8bf> (const Packet8bf& a, Packet8bf& exponent){
-  BF16_TO_F32_BINARY_OP_WRAPPER(pfrexp_float, a, exponent);
+
+template<> EIGEN_STRONG_INLINE Packet4f pfrexp(const Packet4f& a, Packet4f& exponent) {
+  return pfrexp_float(a,exponent);
 }
+template<> EIGEN_STRONG_INLINE Packet8bf pfrexp<Packet8bf> (const Packet8bf& a, Packet8bf& e){
+  Packet4f a_even = Bf16ToF32Even(a);
+  Packet4f a_odd = Bf16ToF32Odd(a);
+  Packet4f e_even;
+  Packet4f e_odd;
+  Packet4f op_even = pfrexp(a_even, e_even);
+  Packet4f op_odd = pfrexp(a_odd, e_odd);
+  e = F32ToBf16(e_even, e_odd);
+  return F32ToBf16(op_even, op_odd);
+}
+
 template<> EIGEN_STRONG_INLINE Packet8bf psin<Packet8bf> (const Packet8bf& a){
   BF16_TO_F32_UNARY_OP_WRAPPER(psin_float, a);
 }
@@ -1346,13 +1382,13 @@ template<> EIGEN_STRONG_INLINE Packet8bf pmax<Packet8bf>(const Packet8bf& a, con
 }
 
 template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt(const Packet8bf& a, const Packet8bf& b) {
-  BF16_TO_F32_BINARY_OP_WRAPPER(pcmp_lt, a, b);
+  BF16_TO_F32_BINARY_OP_WRAPPER_BOOL(pcmp_lt, a, b);
 }
 template<> EIGEN_STRONG_INLINE Packet8bf pcmp_le(const Packet8bf& a, const Packet8bf& b) {
-  BF16_TO_F32_BINARY_OP_WRAPPER(pcmp_le, a, b);
+  BF16_TO_F32_BINARY_OP_WRAPPER_BOOL(pcmp_le, a, b);
 }
 template<> EIGEN_STRONG_INLINE Packet8bf pcmp_eq(const Packet8bf& a, const Packet8bf& b) {
-  BF16_TO_F32_BINARY_OP_WRAPPER(pcmp_eq, a, b);
+  BF16_TO_F32_BINARY_OP_WRAPPER_BOOL(pcmp_eq, a, b);
 }
 
 template<> EIGEN_STRONG_INLINE bfloat16 pfirst(const Packet8bf& a) {
@@ -1370,14 +1406,6 @@ template<> EIGEN_STRONG_INLINE Packet8bf plset(const bfloat16& a) {
   return padd(pset1<Packet8bf>(a), pload<Packet8bf>(countdown));
 }
 
-template<> EIGEN_STRONG_INLINE Packet4f pfrexp(const Packet4f& a, Packet4f& exponent) {
-  return pfrexp_float(a,exponent);
-}
-
-template<> EIGEN_STRONG_INLINE Packet4f pldexp(const Packet4f& a, const Packet4f& exponent) {
-  return pldexp_float(a,exponent);
-}
-
 template<> EIGEN_STRONG_INLINE float predux(const Packet4f& a)
 {
   Packet4f b, sum;