diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h
index 709cebe4e..af1b48744 100644
--- a/Eigen/src/Core/arch/NEON/PacketMath.h
+++ b/Eigen/src/Core/arch/NEON/PacketMath.h
@@ -3371,6 +3371,13 @@ EIGEN_STRONG_INLINE Packet4f Bf16ToF32(const Packet4bf& p)
   return reinterpret_cast<Packet4f>(vshlq_n_u32(vmovl_u16(p), 16));
 }
 
+// Narrows a 4-lane float comparison mask to a 4-lane bf16 mask. The float lanes
+// are reinterpreted (bit-for-bit, no numeric conversion) as 32-bit unsigned
+// integers, then vmovn_u32 keeps the low 16 bits of every lane.
+EIGEN_STRONG_INLINE Packet4bf F32MaskToBf16Mask(const Packet4f& p) {
+  return vmovn_u32(vreinterpretq_u32_f32(p));
+}
+
 template<> EIGEN_STRONG_INLINE Packet4bf pset1<Packet4bf>(const bfloat16& from) {
   return pset1<Packet4us>(from.value);
 }
@@ -3528,17 +3535,17 @@ template<> EIGEN_STRONG_INLINE Packet4bf pabsdiff<Packet4bf>(const Packet4bf& a,
 
 template<> EIGEN_STRONG_INLINE Packet4bf pcmp_eq<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
 {
-  return F32ToBf16(pcmp_eq<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+  return F32MaskToBf16Mask(pcmp_eq<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
 }
 
 template<> EIGEN_STRONG_INLINE Packet4bf pcmp_lt<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
 {
-  return F32ToBf16(pcmp_lt<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+  return F32MaskToBf16Mask(pcmp_lt<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
 }
 
 template<> EIGEN_STRONG_INLINE Packet4bf pcmp_le<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
 {
-  return F32ToBf16(pcmp_le<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+  return F32MaskToBf16Mask(pcmp_le<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
 }
 
 template<> EIGEN_STRONG_INLINE Packet4bf pnegate<Packet4bf>(const Packet4bf& a)