From 10e62ccd22a19b7dfaa05a26c2b04e249d82d3ce Mon Sep 17 00:00:00 2001
From: Charles Schlosser
Date: Wed, 12 Mar 2025 17:06:32 +0000
Subject: [PATCH] Fix x86 complex vectorized fma

---
Note (not part of the commit; text in this area is ignored when the patch is
applied): this copy was recovered from a whitespace-collapsed extraction. Line
breaks were re-derived from the +/-/context markers and checked against the
line counts in each @@ hunk header. Indentation and the position of blank
context lines are best-effort. Angle-bracketed text (template parameter and
argument lists, the author e-mail, the type argument after "std::complex")
appears to have been stripped before this copy was made and is left as found,
so the test/packetmath.cpp hunks will likely not apply or compile verbatim.

Summary of the change as visible below: the hand-written FMA intrinsic bodies
of pnmadd/pnmsub for Packet4cf, Packet2cd, Packet2cf and Packet1cd are replaced
by pnegate(pmsub(a, b, c)) and pnegate(pmadd(a, b, c)) respectively, and the
packetmath test gains a per-component abs_helper used to build the
pmsub/pnmadd test inputs.

 Eigen/src/Core/arch/AVX/Complex.h | 24 ++++------------------
 Eigen/src/Core/arch/SSE/Complex.h | 24 ++++------------------
 test/packetmath.cpp               | 34 +++++++++++++++++++++++--------
 3 files changed, 34 insertions(+), 48 deletions(-)

diff --git a/Eigen/src/Core/arch/AVX/Complex.h b/Eigen/src/Core/arch/AVX/Complex.h
index d5506dae4..09fa20b7c 100644
--- a/Eigen/src/Core/arch/AVX/Complex.h
+++ b/Eigen/src/Core/arch/AVX/Complex.h
@@ -475,19 +475,11 @@ EIGEN_STRONG_INLINE Packet4cf pmsub(const Packet4cf& a, const Packet4cf& b, cons
 }
 template <>
 EIGEN_STRONG_INLINE Packet4cf pnmadd(const Packet4cf& a, const Packet4cf& b, const Packet4cf& c) {
-  __m256 a_odd = _mm256_movehdup_ps(a.v);
-  __m256 a_even = _mm256_moveldup_ps(a.v);
-  __m256 b_swap = _mm256_permute_ps(b.v, _MM_SHUFFLE(2, 3, 0, 1));
-  __m256 result = _mm256_fmaddsub_ps(a_odd, b_swap, _mm256_fmaddsub_ps(a_even, b.v, c.v));
-  return Packet4cf(result);
+  return pnegate(pmsub(a, b, c));
 }
 template <>
 EIGEN_STRONG_INLINE Packet4cf pnmsub(const Packet4cf& a, const Packet4cf& b, const Packet4cf& c) {
-  __m256 a_odd = _mm256_movehdup_ps(a.v);
-  __m256 a_even = _mm256_moveldup_ps(a.v);
-  __m256 b_swap = _mm256_permute_ps(b.v, _MM_SHUFFLE(2, 3, 0, 1));
-  __m256 result = _mm256_fmaddsub_ps(a_odd, b_swap, _mm256_fmsubadd_ps(a_even, b.v, c.v));
-  return Packet4cf(result);
+  return pnegate(pmadd(a, b, c));
 }
 // std::complex
 template <>
@@ -508,19 +500,11 @@ EIGEN_STRONG_INLINE Packet2cd pmsub(const Packet2cd& a, const Packet2cd& b, cons
 }
 template <>
 EIGEN_STRONG_INLINE Packet2cd pnmadd(const Packet2cd& a, const Packet2cd& b, const Packet2cd& c) {
-  __m256d a_odd = _mm256_permute_pd(a.v, 0xF);
-  __m256d a_even = _mm256_movedup_pd(a.v);
-  __m256d b_swap = _mm256_permute_pd(b.v, 0x5);
-  __m256d result = _mm256_fmaddsub_pd(a_odd, b_swap, _mm256_fmaddsub_pd(a_even, b.v, c.v));
-  return Packet2cd(result);
+  return pnegate(pmsub(a, b, c));
 }
 template <>
 EIGEN_STRONG_INLINE Packet2cd pnmsub(const Packet2cd& a, const Packet2cd& b, const Packet2cd& c) {
-  __m256d a_odd = _mm256_permute_pd(a.v, 0xF);
-  __m256d a_even = _mm256_movedup_pd(a.v);
-  __m256d b_swap = _mm256_permute_pd(b.v, 0x5);
-  __m256d result = _mm256_fmaddsub_pd(a_odd, b_swap, _mm256_fmsubadd_pd(a_even, b.v, c.v));
-  return Packet2cd(result);
+  return pnegate(pmadd(a, b, c));
 }
 #endif
 } // end namespace internal
diff --git a/Eigen/src/Core/arch/SSE/Complex.h b/Eigen/src/Core/arch/SSE/Complex.h
index c69e3d479..f79da7b8c 100644
--- a/Eigen/src/Core/arch/SSE/Complex.h
+++ b/Eigen/src/Core/arch/SSE/Complex.h
@@ -465,19 +465,11 @@ EIGEN_STRONG_INLINE Packet2cf pmsub(const Packet2cf& a, const Packet2cf& b, cons
 }
 template <>
 EIGEN_STRONG_INLINE Packet2cf pnmadd(const Packet2cf& a, const Packet2cf& b, const Packet2cf& c) {
-  __m128 a_odd = _mm_movehdup_ps(a.v);
-  __m128 a_even = _mm_moveldup_ps(a.v);
-  __m128 b_swap = _mm_permute_ps(b.v, _MM_SHUFFLE(2, 3, 0, 1));
-  __m128 result = _mm_fmaddsub_ps(a_odd, b_swap, _mm_fmaddsub_ps(a_even, b.v, c.v));
-  return Packet2cf(result);
+  return pnegate(pmsub(a, b, c));
 }
 template <>
 EIGEN_STRONG_INLINE Packet2cf pnmsub(const Packet2cf& a, const Packet2cf& b, const Packet2cf& c) {
-  __m128 a_odd = _mm_movehdup_ps(a.v);
-  __m128 a_even = _mm_moveldup_ps(a.v);
-  __m128 b_swap = _mm_permute_ps(b.v, _MM_SHUFFLE(2, 3, 0, 1));
-  __m128 result = _mm_fmaddsub_ps(a_odd, b_swap, _mm_fmsubadd_ps(a_even, b.v, c.v));
-  return Packet2cf(result);
+  return pnegate(pmadd(a, b, c));
 }
 // std::complex
 template <>
@@ -498,19 +490,11 @@ EIGEN_STRONG_INLINE Packet1cd pmsub(const Packet1cd& a, const Packet1cd& b, cons
 }
 template <>
 EIGEN_STRONG_INLINE Packet1cd pnmadd(const Packet1cd& a, const Packet1cd& b, const Packet1cd& c) {
-  __m128d a_odd = _mm_permute_pd(a.v, 0x3);
-  __m128d a_even = _mm_movedup_pd(a.v);
-  __m128d b_swap = _mm_permute_pd(b.v, 0x1);
-  __m128d result = _mm_fmaddsub_pd(a_odd, b_swap, _mm_fmaddsub_pd(a_even, b.v, c.v));
-  return Packet1cd(result);
+  return pnegate(pmsub(a, b, c));
 }
 template <>
 EIGEN_STRONG_INLINE Packet1cd pnmsub(const Packet1cd& a, const Packet1cd& b, const Packet1cd& c) {
-  __m128d a_odd = _mm_permute_pd(a.v, 0x3);
-  __m128d a_even = _mm_movedup_pd(a.v);
-  __m128d b_swap = _mm_permute_pd(b.v, 0x1);
-  __m128d result = _mm_fmaddsub_pd(a_odd, b_swap, _mm_fmsubadd_pd(a_even, b.v, c.v));
-  return Packet1cd(result);
+  return pnegate(pmadd(a, b, c));
 }
 #endif
 } // end namespace internal
diff --git a/test/packetmath.cpp b/test/packetmath.cpp
index 9c5d6cf46..102817f02 100644
--- a/test/packetmath.cpp
+++ b/test/packetmath.cpp
@@ -92,6 +92,26 @@ inline T REF_LDEXP(const T& x, const T& exp) {
   return static_cast(ldexp(x, static_cast(exp)));
 }

+// provides a convenient function to take the absolute value of each component of a complex number to prevent
+// catastrophic cancellation in randomly generated complex numbers
+template ::IsComplex>
+struct abs_helper_impl {
+  static T run(T x) { return numext::abs(x); }
+};
+template
+struct abs_helper_impl {
+  static T run(T x) {
+    T res = x;
+    numext::real_ref(res) = numext::abs(numext::real(res));
+    numext::imag_ref(res) = numext::abs(numext::imag(res));
+    return res;
+  }
+};
+template
+T abs_helper(T x) {
+  return abs_helper_impl::run(x);
+}
+
 // Uses pcast to cast from one array to another.
 template
 struct pcast_array;
@@ -724,11 +744,6 @@ void packetmath() {
   packetmath_pcast_ops_runner::run();
   packetmath_minus_zero_add_test::run();

-  for (int i = 0; i < size; ++i) {
-    data1[i] = numext::abs(internal::random());
-  }
-  CHECK_CWISE1_IF(PacketTraits::HasSqrt, numext::sqrt, internal::psqrt);
-  CHECK_CWISE1_IF(PacketTraits::HasRsqrt, numext::rsqrt, internal::prsqrt);
   CHECK_CWISE3_IF(true, REF_MADD, internal::pmadd);
   if (!std::is_same::value && NumTraits::IsSigned) {
     nmsub_test(data1, data2, ref, PacketSize);
@@ -738,14 +753,17 @@ void packetmath() {
   // which can lead to very flaky tests. Here we ensure the signs are such that
   // they do not cancel.
   for (int i = 0; i < PacketSize; ++i) {
-    data1[i] = numext::abs(internal::random());
-    data1[i + PacketSize] = numext::abs(internal::random());
-    data1[i + 2 * PacketSize] = Scalar(0) - numext::abs(internal::random());
+    data1[i] = abs_helper(internal::random());
+    data1[i + PacketSize] = abs_helper(internal::random());
+    data1[i + 2 * PacketSize] = Scalar(0) - abs_helper(internal::random());
   }
   if (!std::is_same::value && NumTraits::IsSigned) {
     CHECK_CWISE3_IF(true, REF_MSUB, internal::pmsub);
     CHECK_CWISE3_IF(true, REF_NMADD, internal::pnmadd);
   }
+
+  CHECK_CWISE1_IF(PacketTraits::HasSqrt, numext::sqrt, internal::psqrt);
+  CHECK_CWISE1_IF(PacketTraits::HasRsqrt, numext::rsqrt, internal::prsqrt);
 }

 // Notice that this definition works for complex types as well.