From 648bce6cae448fac2accbdb0a456134749bbae82 Mon Sep 17 00:00:00 2001
From: Charles Schlosser
Date: Thu, 29 Aug 2024 17:37:57 +0000
Subject: [PATCH] SSE/AVX Complex FMA

---
 Eigen/src/Core/arch/AVX/Complex.h | 87 +++++++++++++++++++++++++++++++
 Eigen/src/Core/arch/SSE/Complex.h | 88 +++++++++++++++++++++++++++++++
 2 files changed, 175 insertions(+)

diff --git a/Eigen/src/Core/arch/AVX/Complex.h b/Eigen/src/Core/arch/AVX/Complex.h
index 67945cbd0..d5506dae4 100644
--- a/Eigen/src/Core/arch/AVX/Complex.h
+++ b/Eigen/src/Core/arch/AVX/Complex.h
@@ -455,6 +455,93 @@ EIGEN_STRONG_INLINE Packet4cf pexp<Packet4cf>(const Packet4cf& a) {
   return pexp_complex<Packet4cf>(a);
 }
 
+#ifdef EIGEN_VECTORIZE_FMA
+// Complex fused multiply-add. Packets hold interleaved (re, im) lanes; with a = (ar, ai) and b = (br, bi):
+//   a_even = (ar, ar), a_odd = (ai, ai), b_swap = (bi, br)
+//   fmaddsub(x, y, z) = (x*y - z, x*y + z), fmsubadd(x, y, z) = (x*y + z, x*y - z) per (re, im) lane pair.
+// std::complex<float>
+// pmadd: a * b + c
+template <>
+EIGEN_STRONG_INLINE Packet4cf pmadd(const Packet4cf& a, const Packet4cf& b, const Packet4cf& c) {
+  __m256 a_odd = _mm256_movehdup_ps(a.v);
+  __m256 a_even = _mm256_moveldup_ps(a.v);
+  __m256 b_swap = _mm256_permute_ps(b.v, _MM_SHUFFLE(2, 3, 0, 1));
+  __m256 result = _mm256_fmaddsub_ps(a_even, b.v, _mm256_fmaddsub_ps(a_odd, b_swap, c.v));
+  return Packet4cf(result);
+}
+// pmsub: a * b - c
+template <>
+EIGEN_STRONG_INLINE Packet4cf pmsub(const Packet4cf& a, const Packet4cf& b, const Packet4cf& c) {
+  __m256 a_odd = _mm256_movehdup_ps(a.v);
+  __m256 a_even = _mm256_moveldup_ps(a.v);
+  __m256 b_swap = _mm256_permute_ps(b.v, _MM_SHUFFLE(2, 3, 0, 1));
+  __m256 result = _mm256_fmaddsub_ps(a_even, b.v, _mm256_fmsubadd_ps(a_odd, b_swap, c.v));
+  return Packet4cf(result);
+}
+// pnmadd: c - a * b
+template <>
+EIGEN_STRONG_INLINE Packet4cf pnmadd(const Packet4cf& a, const Packet4cf& b, const Packet4cf& c) {
+  __m256 a_odd = _mm256_movehdup_ps(a.v);
+  __m256 a_even = _mm256_moveldup_ps(a.v);
+  __m256 b_swap = _mm256_permute_ps(b.v, _MM_SHUFFLE(2, 3, 0, 1));
+  // Evaluated on conj(c), the chain below yields the correct real lane and the negated imaginary lane;
+  // conjugating the result restores the imaginary sign.
+  __m256 result = _mm256_fmaddsub_ps(a_odd, b_swap, _mm256_fmaddsub_ps(a_even, b.v, pconj(c).v));
+  return pconj(Packet4cf(result));
+}
+// pnmsub: -(a * b) - c
+template <>
+EIGEN_STRONG_INLINE Packet4cf pnmsub(const Packet4cf& a, const Packet4cf& b, const Packet4cf& c) {
+  __m256 a_odd = _mm256_movehdup_ps(a.v);
+  __m256 a_even = _mm256_moveldup_ps(a.v);
+  __m256 b_swap = _mm256_permute_ps(b.v, _MM_SHUFFLE(2, 3, 0, 1));
+  // Evaluated on conj(c), the chain below yields the correct real lane and the negated imaginary lane;
+  // conjugating the result restores the imaginary sign.
+  __m256 result = _mm256_fmaddsub_ps(a_odd, b_swap, _mm256_fmsubadd_ps(a_even, b.v, pconj(c).v));
+  return pconj(Packet4cf(result));
+}
+// std::complex<double>
+// pmadd: a * b + c
+template <>
+EIGEN_STRONG_INLINE Packet2cd pmadd(const Packet2cd& a, const Packet2cd& b, const Packet2cd& c) {
+  __m256d a_odd = _mm256_permute_pd(a.v, 0xF);
+  __m256d a_even = _mm256_movedup_pd(a.v);
+  __m256d b_swap = _mm256_permute_pd(b.v, 0x5);
+  __m256d result = _mm256_fmaddsub_pd(a_even, b.v, _mm256_fmaddsub_pd(a_odd, b_swap, c.v));
+  return Packet2cd(result);
+}
+// pmsub: a * b - c
+template <>
+EIGEN_STRONG_INLINE Packet2cd pmsub(const Packet2cd& a, const Packet2cd& b, const Packet2cd& c) {
+  __m256d a_odd = _mm256_permute_pd(a.v, 0xF);
+  __m256d a_even = _mm256_movedup_pd(a.v);
+  __m256d b_swap = _mm256_permute_pd(b.v, 0x5);
+  __m256d result = _mm256_fmaddsub_pd(a_even, b.v, _mm256_fmsubadd_pd(a_odd, b_swap, c.v));
+  return Packet2cd(result);
+}
+// pnmadd: c - a * b
+template <>
+EIGEN_STRONG_INLINE Packet2cd pnmadd(const Packet2cd& a, const Packet2cd& b, const Packet2cd& c) {
+  __m256d a_odd = _mm256_permute_pd(a.v, 0xF);
+  __m256d a_even = _mm256_movedup_pd(a.v);
+  __m256d b_swap = _mm256_permute_pd(b.v, 0x5);
+  // Evaluated on conj(c), the chain below yields the correct real lane and the negated imaginary lane;
+  // conjugating the result restores the imaginary sign.
+  __m256d result = _mm256_fmaddsub_pd(a_odd, b_swap, _mm256_fmaddsub_pd(a_even, b.v, pconj(c).v));
+  return pconj(Packet2cd(result));
+}
+// pnmsub: -(a * b) - c
+template <>
+EIGEN_STRONG_INLINE Packet2cd pnmsub(const Packet2cd& a, const Packet2cd& b, const Packet2cd& c) {
+  __m256d a_odd = _mm256_permute_pd(a.v, 0xF);
+  __m256d a_even = _mm256_movedup_pd(a.v);
+  __m256d b_swap = _mm256_permute_pd(b.v, 0x5);
+  // Evaluated on conj(c), the chain below yields the correct real lane and the negated imaginary lane;
+  // conjugating the result restores the imaginary sign.
+  __m256d result = _mm256_fmaddsub_pd(a_odd, b_swap, _mm256_fmsubadd_pd(a_even, b.v, pconj(c).v));
+  return pconj(Packet2cd(result));
+}
+#endif
 }  // end namespace internal
 
 }  // end namespace Eigen
diff --git a/Eigen/src/Core/arch/SSE/Complex.h b/Eigen/src/Core/arch/SSE/Complex.h
index a390260db..66fcade3b 100644
--- a/Eigen/src/Core/arch/SSE/Complex.h
+++ b/Eigen/src/Core/arch/SSE/Complex.h
@@ -445,6 +445,94 @@ EIGEN_STRONG_INLINE Packet2cf pexp<Packet2cf>(const Packet2cf& a) {
   return pexp_complex<Packet2cf>(a);
 }
 
+#ifdef EIGEN_VECTORIZE_FMA
+// Complex fused multiply-add. Packets hold interleaved (re, im) lanes; with a = (ar, ai) and b = (br, bi):
+//   a_even = (ar, ar), a_odd = (ai, ai), b_swap = (bi, br)
+//   fmaddsub(x, y, z) = (x*y - z, x*y + z), fmsubadd(x, y, z) = (x*y + z, x*y - z) per (re, im) lane pair.
+// std::complex<float>
+// pmadd: a * b + c
+template <>
+EIGEN_STRONG_INLINE Packet2cf pmadd(const Packet2cf& a, const Packet2cf& b, const Packet2cf& c) {
+  __m128 a_odd = _mm_movehdup_ps(a.v);
+  __m128 a_even = _mm_moveldup_ps(a.v);
+  __m128 b_swap = _mm_permute_ps(b.v, _MM_SHUFFLE(2, 3, 0, 1));
+  __m128 result = _mm_fmaddsub_ps(a_even, b.v, _mm_fmaddsub_ps(a_odd, b_swap, c.v));
+  return Packet2cf(result);
+}
+// pmsub: a * b - c
+template <>
+EIGEN_STRONG_INLINE Packet2cf pmsub(const Packet2cf& a, const Packet2cf& b, const Packet2cf& c) {
+  __m128 a_odd = _mm_movehdup_ps(a.v);
+  __m128 a_even = _mm_moveldup_ps(a.v);
+  __m128 b_swap = _mm_permute_ps(b.v, _MM_SHUFFLE(2, 3, 0, 1));
+  __m128 result = _mm_fmaddsub_ps(a_even, b.v, _mm_fmsubadd_ps(a_odd, b_swap, c.v));
+  return Packet2cf(result);
+}
+// pnmadd: c - a * b
+template <>
+EIGEN_STRONG_INLINE Packet2cf pnmadd(const Packet2cf& a, const Packet2cf& b, const Packet2cf& c) {
+  __m128 a_odd = _mm_movehdup_ps(a.v);
+  __m128 a_even = _mm_moveldup_ps(a.v);
+  __m128 b_swap = _mm_permute_ps(b.v, _MM_SHUFFLE(2, 3, 0, 1));
+  // Evaluated on conj(c), the chain below yields the correct real lane and the negated imaginary lane;
+  // conjugating the result restores the imaginary sign.
+  __m128 result = _mm_fmaddsub_ps(a_odd, b_swap, _mm_fmaddsub_ps(a_even, b.v, pconj(c).v));
+  return pconj(Packet2cf(result));
+}
+// pnmsub: -(a * b) - c
+template <>
+EIGEN_STRONG_INLINE Packet2cf pnmsub(const Packet2cf& a, const Packet2cf& b, const Packet2cf& c) {
+  __m128 a_odd = _mm_movehdup_ps(a.v);
+  __m128 a_even = _mm_moveldup_ps(a.v);
+  __m128 b_swap = _mm_permute_ps(b.v, _MM_SHUFFLE(2, 3, 0, 1));
+  // Evaluated on conj(c), the chain below yields the correct real lane and the negated imaginary lane;
+  // conjugating the result restores the imaginary sign.
+  __m128 result = _mm_fmaddsub_ps(a_odd, b_swap, _mm_fmsubadd_ps(a_even, b.v, pconj(c).v));
+  return pconj(Packet2cf(result));
+}
+// std::complex<double>
+// _mm_permute_pd only uses a 2-bit immediate: 0x3 broadcasts the high lane, 0x1 swaps the two lanes.
+// pmadd: a * b + c
+template <>
+EIGEN_STRONG_INLINE Packet1cd pmadd(const Packet1cd& a, const Packet1cd& b, const Packet1cd& c) {
+  __m128d a_odd = _mm_permute_pd(a.v, 0x3);
+  __m128d a_even = _mm_movedup_pd(a.v);
+  __m128d b_swap = _mm_permute_pd(b.v, 0x1);
+  __m128d result = _mm_fmaddsub_pd(a_even, b.v, _mm_fmaddsub_pd(a_odd, b_swap, c.v));
+  return Packet1cd(result);
+}
+// pmsub: a * b - c
+template <>
+EIGEN_STRONG_INLINE Packet1cd pmsub(const Packet1cd& a, const Packet1cd& b, const Packet1cd& c) {
+  __m128d a_odd = _mm_permute_pd(a.v, 0x3);
+  __m128d a_even = _mm_movedup_pd(a.v);
+  __m128d b_swap = _mm_permute_pd(b.v, 0x1);
+  __m128d result = _mm_fmaddsub_pd(a_even, b.v, _mm_fmsubadd_pd(a_odd, b_swap, c.v));
+  return Packet1cd(result);
+}
+// pnmadd: c - a * b
+template <>
+EIGEN_STRONG_INLINE Packet1cd pnmadd(const Packet1cd& a, const Packet1cd& b, const Packet1cd& c) {
+  __m128d a_odd = _mm_permute_pd(a.v, 0x3);
+  __m128d a_even = _mm_movedup_pd(a.v);
+  __m128d b_swap = _mm_permute_pd(b.v, 0x1);
+  // Evaluated on conj(c), the chain below yields the correct real lane and the negated imaginary lane;
+  // conjugating the result restores the imaginary sign.
+  __m128d result = _mm_fmaddsub_pd(a_odd, b_swap, _mm_fmaddsub_pd(a_even, b.v, pconj(c).v));
+  return pconj(Packet1cd(result));
+}
+// pnmsub: -(a * b) - c
+template <>
+EIGEN_STRONG_INLINE Packet1cd pnmsub(const Packet1cd& a, const Packet1cd& b, const Packet1cd& c) {
+  __m128d a_odd = _mm_permute_pd(a.v, 0x3);
+  __m128d a_even = _mm_movedup_pd(a.v);
+  __m128d b_swap = _mm_permute_pd(b.v, 0x1);
+  // Evaluated on conj(c), the chain below yields the correct real lane and the negated imaginary lane;
+  // conjugating the result restores the imaginary sign.
+  __m128d result = _mm_fmaddsub_pd(a_odd, b_swap, _mm_fmsubadd_pd(a_even, b.v, pconj(c).v));
+  return pconj(Packet1cd(result));
+}
+#endif
 }  // end namespace internal
 
 }  // end namespace Eigen