From be06c9ad51e23d0b6b10cbb96a9e7db7f3299077 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damiano=20Franz=C3=B2?= Date: Wed, 14 Feb 2024 14:55:03 +0100 Subject: [PATCH] Implement float pexp_complex --- Eigen/src/Core/arch/AVX/Complex.h | 6 ++ Eigen/src/Core/arch/AVX512/Complex.h | 6 ++ Eigen/src/Core/arch/AltiVec/Complex.h | 6 ++ .../arch/Default/GenericPacketMathFunctions.h | 63 +++++++++++++++++-- .../Default/GenericPacketMathFunctionsFwd.h | 4 ++ Eigen/src/Core/arch/NEON/Complex.h | 11 ++++ Eigen/src/Core/arch/SSE/Complex.h | 6 ++ Eigen/src/Core/arch/ZVector/Complex.h | 6 ++ test/packetmath.cpp | 29 +++++++++ 9 files changed, 133 insertions(+), 4 deletions(-) diff --git a/Eigen/src/Core/arch/AVX/Complex.h b/Eigen/src/Core/arch/AVX/Complex.h index 6a8bee890..bae57146b 100644 --- a/Eigen/src/Core/arch/AVX/Complex.h +++ b/Eigen/src/Core/arch/AVX/Complex.h @@ -41,6 +41,7 @@ struct packet_traits > : default_packet_traits { HasNegate = 1, HasSqrt = 1, HasLog = 1, + HasExp = 1, HasAbs = 0, HasAbs2 = 0, HasMin = 0, @@ -443,6 +444,11 @@ EIGEN_STRONG_INLINE Packet4cf plog(const Packet4cf& a) { return plog_complex(a); } +template <> +EIGEN_STRONG_INLINE Packet4cf pexp(const Packet4cf& a) { + return pexp_complex(a); +} + } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/Core/arch/AVX512/Complex.h b/Eigen/src/Core/arch/AVX512/Complex.h index c14b4a0b6..b70c7fefe 100644 --- a/Eigen/src/Core/arch/AVX512/Complex.h +++ b/Eigen/src/Core/arch/AVX512/Complex.h @@ -40,6 +40,7 @@ struct packet_traits > : default_packet_traits { HasNegate = 1, HasSqrt = 1, HasLog = 1, + HasExp = 1, HasAbs = 0, HasAbs2 = 0, HasMin = 0, @@ -460,6 +461,11 @@ EIGEN_STRONG_INLINE Packet8cf plog(const Packet8cf& a) { return plog_complex(a); } +template <> +EIGEN_STRONG_INLINE Packet8cf pexp(const Packet8cf& a) { + return pexp_complex(a); +} + } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/Core/arch/AltiVec/Complex.h b/Eigen/src/Core/arch/AltiVec/Complex.h 
index e3c4436c5..0252efa6d 100644 --- a/Eigen/src/Core/arch/AltiVec/Complex.h +++ b/Eigen/src/Core/arch/AltiVec/Complex.h @@ -99,6 +99,7 @@ struct packet_traits > : default_packet_traits { HasMax = 0, HasSqrt = 1, HasLog = 1, + HasExp = 1, #ifdef EIGEN_VECTORIZE_VSX HasBlend = 1, #endif @@ -375,6 +376,11 @@ EIGEN_STRONG_INLINE Packet2cf plog(const Packet2cf& a) { return plog_complex(a); } +template <> +EIGEN_STRONG_INLINE Packet2cf pexp(const Packet2cf& a) { + return pexp_complex(a); +} + //---------- double ---------- #ifdef EIGEN_VECTORIZE_VSX struct Packet1cd { diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index 839df3781..626185c40 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -555,7 +555,7 @@ inline float trig_reduce_huge(float xf, Eigen::numext::int32_t* quadrant) { return float(double(int64_t(p)) * pio2_62); } -template +template EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS #if EIGEN_COMP_GNUC_STRICT __attribute__((optimize("-fno-unsafe-math-optimizations"))) @@ -669,10 +669,21 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS y2 = pmadd(y2, x, x); // Select the correct result from the two polynomials. - y = ComputeSine ? pselect(poly_mask, y2, y1) : pselect(poly_mask, y1, y2); - + if (ComputeBoth) { + Packet peven = peven_mask(x); + Packet ysin = pselect(poly_mask, y2, y1); + Packet ycos = pselect(poly_mask, y1, y2); + Packet sign_bit_sin = pxor(_x, preinterpret(plogical_shift_left<30>(y_int))); + Packet sign_bit_cos = preinterpret(plogical_shift_left<30>(padd(y_int, csti_1))); + sign_bit_sin = pand(sign_bit_sin, cst_sign_mask); // clear all but left most bit + sign_bit_cos = pand(sign_bit_cos, cst_sign_mask); // clear all but left most bit + y = pselect(peven, pxor(ysin, sign_bit_sin), pxor(ycos, sign_bit_cos)); + } else { + y = ComputeSine ? 
pselect(poly_mask, y2, y1) : pselect(poly_mask, y1, y2); + y = pxor(y, sign_bit); + } // Update the sign and filter huge inputs - return pxor(y, sign_bit); + return y; } template @@ -1051,6 +1062,50 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog_complex(const Pa return xres; } +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pexp_complex(const Packet& a) { + typedef typename unpacket_traits::as_real RealPacket; + typedef typename unpacket_traits::type Scalar; + typedef typename Scalar::value_type RealScalar; + const RealPacket even_mask = peven_mask(a.v); + const Packet even_maskp = Packet(even_mask); + const RealPacket odd_mask = pcplxflip(Packet(even_mask)).v; + + Packet p0y = Packet(pand(odd_mask, a.v)); + Packet py0 = pcplxflip(p0y); + Packet pyy = padd(p0y, py0); + + RealPacket sincos = psincos_float(pyy.v); + RealPacket cossin = pcplxflip(Packet(sincos)).v; + + const RealPacket cst_pos_inf = pset1(NumTraits::infinity()); + const RealPacket cst_neg_inf = pset1(-NumTraits::infinity()); + Packet x_is_inf = Packet(pcmp_eq(a.v, cst_pos_inf)); + Packet x_is_minf = Packet(pcmp_eq(a.v, cst_neg_inf)); + Packet x_is_zero = Packet(pcmp_eq(pzero(a).v, a.v)); + Packet x_real_is_inf = pand(even_maskp, x_is_inf); + Packet x_real_is_minf = pand(even_maskp, x_is_minf); + Packet inf0 = pset1(Scalar(NumTraits::infinity(), RealScalar(0))); + Packet x_is_inf0 = pand(x_real_is_inf, pcplxflip(x_is_zero)); + x_is_inf0 = por(x_is_inf0, pcplxflip(x_is_inf0)); + Packet x_imag_goes_zero = pand(por(x_is_minf, x_is_inf), pcplxflip(x_real_is_minf)); + Packet x_is_nan = Packet(pisnan(a.v)); + Packet x_real_goes_zero = pand(x_is_nan, pcplxflip(x_real_is_minf)); + + RealPacket pexp_real = pexp(a.v); + Packet pexp_half = Packet(pand(even_mask, pexp_real)); + RealPacket xexp_flip_rp = pcplxflip(pexp_half).v; + RealPacket xexp = padd(pexp_half.v, xexp_flip_rp); + Packet result(pmul(cossin, xexp)); + + result = pselect(x_is_inf0, inf0, result); + result = 
pselect(x_real_is_minf, pzero(a), result); + result = pselect(x_imag_goes_zero, pzero(a), result); + result = pselect(x_real_goes_zero, pzero(a), result); + + return result; +} + template EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet psqrt_complex(const Packet& a) { typedef typename unpacket_traits::type Scalar; diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h index dd1698830..9560de2fc 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h @@ -121,6 +121,10 @@ struct ppolevl; template EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog_complex(const Packet& x); +/** \internal \returns exp(x) for complex types */ +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pexp_complex(const Packet& x); + // Macros for instantiating these generic functions for different backends. #define EIGEN_PACKET_FUNCTION(METHOD, SCALAR, PACKET) \ template <> \ diff --git a/Eigen/src/Core/arch/NEON/Complex.h b/Eigen/src/Core/arch/NEON/Complex.h index 22c776574..5257c03c8 100644 --- a/Eigen/src/Core/arch/NEON/Complex.h +++ b/Eigen/src/Core/arch/NEON/Complex.h @@ -63,6 +63,7 @@ struct packet_traits > : default_packet_traits { HasNegate = 1, HasSqrt = 1, HasLog = 1, + HasExp = 1, HasAbs = 0, HasAbs2 = 0, HasMin = 0, @@ -447,6 +448,16 @@ EIGEN_STRONG_INLINE Packet2cf plog(const Packet2cf& a) { return plog_complex(a); } +template <> +EIGEN_STRONG_INLINE Packet1cf pexp(const Packet1cf& a) { + return pexp_complex(a); +} + +template <> +EIGEN_STRONG_INLINE Packet2cf pexp(const Packet2cf& a) { + return pexp_complex(a); +} + //---------- double ---------- #if EIGEN_ARCH_ARM64 && !EIGEN_APPLE_DOUBLE_NEON_BUG diff --git a/Eigen/src/Core/arch/SSE/Complex.h b/Eigen/src/Core/arch/SSE/Complex.h index 76c3a0547..0e70f0318 100644 --- a/Eigen/src/Core/arch/SSE/Complex.h +++ 
b/Eigen/src/Core/arch/SSE/Complex.h @@ -43,6 +43,7 @@ struct packet_traits > : default_packet_traits { HasNegate = 1, HasSqrt = 1, HasLog = 1, + HasExp = 1, HasAbs = 0, HasAbs2 = 0, HasMin = 0, @@ -424,6 +425,11 @@ EIGEN_STRONG_INLINE Packet2cf plog(const Packet2cf& a) { return plog_complex(a); } +template <> +EIGEN_STRONG_INLINE Packet2cf pexp(const Packet2cf& a) { + return pexp_complex(a); +} + } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/Core/arch/ZVector/Complex.h b/Eigen/src/Core/arch/ZVector/Complex.h index e8bd17da1..9b8974742 100644 --- a/Eigen/src/Core/arch/ZVector/Complex.h +++ b/Eigen/src/Core/arch/ZVector/Complex.h @@ -61,6 +61,7 @@ struct packet_traits > : default_packet_traits { HasMul = 1, HasDiv = 1, HasLog = 1, + HasExp = 1, HasNegate = 1, HasAbs = 0, HasAbs2 = 0, @@ -436,6 +437,11 @@ EIGEN_STRONG_INLINE Packet2cf plog(const Packet2cf& a, const Packet2c return plog_complex(a, b); } +template <> +EIGEN_STRONG_INLINE Packet2cf pexp(const Packet2cf& a) { /* unary: pexp_complex takes a single packet */ + return pexp_complex(a); +} + EIGEN_STRONG_INLINE Packet2cf pcplxflip /**/ (const Packet2cf& x) { Packet2cf res; res.cd[0] = pcplxflip(x.cd[0]); diff --git a/test/packetmath.cpp b/test/packetmath.cpp index bf2970cef..c5e4897ad 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp @@ -1447,6 +1447,35 @@ void packetmath_complex() { data1[3] = Scalar(nan, -inf); CHECK_CWISE1_IM1ULP_N(std::log, internal::plog, 4); } + + if (PacketTraits::HasExp) { + for (int i = 0; i < size; ++i) { + data1[i] = Scalar(internal::random(), internal::random()); + } + CHECK_CWISE1_N(std::exp, internal::pexp, size); + + // Test misc. corner cases. 
+ const RealScalar zero = RealScalar(0); + const RealScalar one = RealScalar(1); + const RealScalar inf = std::numeric_limits::infinity(); + const RealScalar nan = std::numeric_limits::quiet_NaN(); + for (RealScalar x : {zero, one, inf}) { + for (RealScalar y : {zero, one, inf}) { + data1[0] = Scalar(x, y); + data1[1] = Scalar(-x, y); + data1[2] = Scalar(x, -y); + data1[3] = Scalar(-x, -y); + CHECK_CWISE1_N(std::exp, internal::pexp, 4); + } + } + for (RealScalar x : {zero, one, inf}) { + data1[0] = Scalar(x, nan); + data1[1] = Scalar(-x, nan); + data1[2] = Scalar(nan, x); + data1[3] = Scalar(nan, -x); + CHECK_CWISE1_N(std::exp, internal::pexp, 4); + } + } } template