diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h index 642f08cc7..8db258991 100644 --- a/Eigen/src/Core/GenericPacketMath.h +++ b/Eigen/src/Core/GenericPacketMath.h @@ -72,6 +72,7 @@ struct default_packet_traits { HasReciprocal = 0, HasSqrt = 0, HasRsqrt = 0, + HasCbrt = 0, HasExp = 0, HasExpm1 = 0, HasLog = 0, diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h index d4aa3f889..a48eedb9f 100644 --- a/Eigen/src/Core/MathFunctions.h +++ b/Eigen/src/Core/MathFunctions.h @@ -836,8 +836,8 @@ EIGEN_DEVICE_FUNC std::enable_if_t<(std::numeric_limits::has_infinity && !Num template EIGEN_DEVICE_FUNC - std::enable_if_t::has_quiet_NaN || std::numeric_limits::has_signaling_NaN), bool> - isnan_impl(const T&) { +std::enable_if_t::has_quiet_NaN || std::numeric_limits::has_signaling_NaN), bool> +isnan_impl(const T&) { return false; } @@ -1361,11 +1361,17 @@ SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(sqrt, sqrt) /** \returns the cube root of \a x. **/ template -EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T cbrt(const T& x) { +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE std::enable_if_t::IsComplex, T> cbrt(const T& x) { EIGEN_USING_STD(cbrt); return static_cast(cbrt(x)); } +template +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE std::enable_if_t::IsComplex, T> cbrt(const T& x) { + EIGEN_USING_STD(pow); + return pow(x, typename NumTraits::Real(1.0 / 3.0)); +} + /** \returns the reciprocal square root of \a x. 
**/ template EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T rsqrt(const T& x) { @@ -1394,17 +1400,17 @@ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double log(const double& x) { #endif template -EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE - std::enable_if_t::IsSigned || NumTraits::IsComplex, typename NumTraits::Real> - abs(const T& x) { +EIGEN_DEVICE_FUNC +EIGEN_ALWAYS_INLINE std::enable_if_t::IsSigned || NumTraits::IsComplex, typename NumTraits::Real> +abs(const T& x) { EIGEN_USING_STD(abs); return abs(x); } template -EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE - std::enable_if_t::IsSigned || NumTraits::IsComplex), typename NumTraits::Real> - abs(const T& x) { +EIGEN_DEVICE_FUNC +EIGEN_ALWAYS_INLINE std::enable_if_t::IsSigned || NumTraits::IsComplex), typename NumTraits::Real> +abs(const T& x) { return x; } diff --git a/Eigen/src/Core/arch/AVX/MathFunctions.h b/Eigen/src/Core/arch/AVX/MathFunctions.h index eb0011c2b..5b7285f99 100644 --- a/Eigen/src/Core/arch/AVX/MathFunctions.h +++ b/Eigen/src/Core/arch/AVX/MathFunctions.h @@ -28,6 +28,7 @@ EIGEN_DOUBLE_PACKET_FUNCTION(log, Packet4d) EIGEN_DOUBLE_PACKET_FUNCTION(log2, Packet4d) EIGEN_DOUBLE_PACKET_FUNCTION(exp, Packet4d) EIGEN_DOUBLE_PACKET_FUNCTION(tanh, Packet4d) +EIGEN_DOUBLE_PACKET_FUNCTION(cbrt, Packet4d) #ifdef EIGEN_VECTORIZE_AVX2 EIGEN_DOUBLE_PACKET_FUNCTION(sin, Packet4d) EIGEN_DOUBLE_PACKET_FUNCTION(cos, Packet4d) diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h index fb49206ea..470e36d8d 100644 --- a/Eigen/src/Core/arch/AVX/PacketMath.h +++ b/Eigen/src/Core/arch/AVX/PacketMath.h @@ -122,6 +122,7 @@ struct packet_traits : default_packet_traits { HasBessel = 1, HasSqrt = 1, HasRsqrt = 1, + HasCbrt = 1, HasTanh = EIGEN_FAST_MATH, HasErf = EIGEN_FAST_MATH, HasErfc = EIGEN_FAST_MATH, @@ -150,6 +151,7 @@ struct packet_traits : default_packet_traits { HasExp = 1, HasSqrt = 1, HasRsqrt = 1, + HasCbrt = 1, HasATan = 1, HasATanh = 1, HasBlend = 1 diff --git 
a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h index b7c8b9028..27a0f1023 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMath.h +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h @@ -128,6 +128,7 @@ struct packet_traits : default_packet_traits { HasATanh = 1, HasSqrt = 1, HasRsqrt = 1, + HasCbrt = 1, HasLog = 1, HasLog1p = 1, HasExpm1 = 1, @@ -153,6 +154,7 @@ struct packet_traits : default_packet_traits { HasBlend = 1, HasSqrt = 1, HasRsqrt = 1, + HasCbrt = 1, HasSin = EIGEN_FAST_MATH, HasCos = EIGEN_FAST_MATH, HasLog = 1, diff --git a/Eigen/src/Core/arch/AltiVec/PacketMath.h b/Eigen/src/Core/arch/AltiVec/PacketMath.h index 35cc273f0..d7bd9bee4 100644 --- a/Eigen/src/Core/arch/AltiVec/PacketMath.h +++ b/Eigen/src/Core/arch/AltiVec/PacketMath.h @@ -186,6 +186,7 @@ struct packet_traits : default_packet_traits { HasExp = 1, #ifdef EIGEN_VECTORIZE_VSX HasSqrt = 1, + HasCbrt = 1, #if !EIGEN_COMP_CLANG HasRsqrt = 1, #else @@ -3176,6 +3177,7 @@ struct packet_traits : default_packet_traits { HasLog = 0, HasExp = 1, HasSqrt = 1, + HasCbrt = 1, #if !EIGEN_COMP_CLANG HasRsqrt = 1, #else diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index 174eb575d..eaa20ac7c 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -289,6 +289,143 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet pldexp_fast(const Packet& a, const return pmul(a, preinterpret(plogical_shift_left(e))); } +// This function implements a single step of Halley's iteration for +// computing x = y^(1/3): +// x_{k+1} = x_k - (x_k^3 - y) x_k / (2x_k^3 + y) +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet cbrt_halley_iteration_step(const Packet& x_k, + const Packet& y) { + typedef typename unpacket_traits::type Scalar; + Packet x_k_cb = pmul(x_k, pmul(x_k, x_k)); + Packet denom = pmadd(pset1(Scalar(2)), x_k_cb, 
y); + Packet num = psub(x_k_cb, y); + Packet r = pdiv(num, denom); + return pnmadd(x_k, r, x_k); +} + +// Decompose the input such that x^(1/3) = y^(1/3) * 2^e_div3, and y is in the +// interval [0.125,1]. +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet cbrt_decompose(const Packet& x, Packet& e_div3) { + typedef typename unpacket_traits::type Scalar; + // Extract the significand s in the range [0.5,1) and exponent e, such that + // x = 2^e * s. + Packet e, s; + s = pfrexp(x, e); + + // Split the exponent into a part divisible by 3 and the remainder. + // e = 3*e_div3 + e_mod3. + constexpr Scalar kOneThird = Scalar(1) / 3; + e_div3 = pceil(pmul(e, pset1(kOneThird))); + Packet e_mod3 = pnmadd(pset1(Scalar(3)), e_div3, e); + + // Replace s by y = (s * 2^e_mod3). + return pldexp_fast(s, e_mod3); +} + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet cbrt_special_cases_and_sign(const Packet& x, + const Packet& abs_root) { + typedef typename unpacket_traits::type Scalar; + + // Set sign. + const Packet sign_mask = pset1(Scalar(-0.0)); + const Packet x_sign = pand(sign_mask, x); + Packet root = por(x_sign, abs_root); + + // Pass non-finite and zero values of x straight through. + // x - x is NaN iff x is Inf or NaN; pisnan() turns that into a comparison-style mask matching is_zero. + const Packet is_not_finite = pisnan(psub(x, x)); + const Packet is_zero = pcmp_eq(pzero(x), x); + const Packet use_x = por(is_not_finite, is_zero); + return pselect(use_x, x, root); +} + +// Generic implementation of cbrt(x) for float. +// +// The algorithm computes the cubic root of the input by first +// decomposing it into an exponent and significand +// x = s * 2^e. +// +// We can then write the cube root as +// +// x^(1/3) = 2^(e/3) * s^(1/3) +// = 2^((3*e_div3 + e_mod3)/3) * s^(1/3) +// = 2^(e_div3) * 2^(e_mod3/3) * s^(1/3) +// = 2^(e_div3) * (s * 2^e_mod3)^(1/3) +// +// where e_div3 = ceil(e/3) and e_mod3 = e - 3*e_div3. 
+// +// The second term y^(1/3), where y = s * 2^e_mod3, is coarsely +// approximated using a cubic polynomial and subsequently refined using a +// single step of Halley's iteration, and finally the two terms are combined +// using pldexp_fast. +// +// Note: Many alternatives exist for implementing cbrt. See, for example, +// the excellent discussion in Kahan's note: +// https://csclub.uwaterloo.ca/~pbarfuss/qbrt.pdf +// This particular implementation was found to be very fast and accurate +// among several alternatives tried, but is probably not "optimal" on all +// platforms. +// +// This is accurate to 2 ULP. +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pcbrt_float(const Packet& x) { + typedef typename unpacket_traits::type Scalar; + static_assert(std::is_same::value, "Scalar type must be float"); + + // Decompose the input such that x^(1/3) = y^(1/3) * 2^e_div3, and y is in the + // interval [0.125,1]. + Packet e_div3; + const Packet y = cbrt_decompose(pabs(x), e_div3); + + // Compute initial approximation accurate to 5.22e-3. + // The cubic polynomial was computed using Rminimax. + constexpr float alpha[] = {5.9220016002655029296875e-01f, -1.3859539031982421875e+00f, 1.4581282138824462890625e+00f, + 3.408401906490325927734375e-01f}; + Packet r = ppolevl::run(y, alpha); + + // Take one step of Halley's iteration. + r = cbrt_halley_iteration_step(r, y); + + // Finally multiply by 2^(e_div3). + r = pldexp_fast(r, e_div3); + + return cbrt_special_cases_and_sign(x, r); +} + +// Generic implementation of cbrt(x) for double. +// +// The algorithm is identical to the one for float except that a different initial +// approximation is used for y^(1/3) and two Halley iteration steps are performed. +// +// This is accurate to 1 ULP. 
+template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pcbrt_double(const Packet& x) { + typedef typename unpacket_traits::type Scalar; + static_assert(std::is_same::value, "Scalar type must be double"); + + // Decompose the input such that x^(1/3) = y^(1/3) * 2^e_div3, and y is in the + // interval [0.125,1]. + Packet e_div3; + const Packet y = cbrt_decompose(pabs(x), e_div3); + + // Compute initial approximation accurate to 0.016. + // The quadratic polynomial (three coefficients) was computed using Rminimax. + constexpr double alpha[] = {-4.69470621553356115551736138513660989701747894287109375e-01, + 1.072314636518546304699839311069808900356292724609375e+00, + 3.81249427609571867048288140722434036433696746826171875e-01}; + Packet r = ppolevl::run(y, alpha); + + // Take two steps of Halley's iteration. + r = cbrt_halley_iteration_step(r, y); + r = cbrt_halley_iteration_step(r, y); + + // Finally multiply by 2^(e_div3). + r = pldexp_fast(r, e_div3); + return cbrt_special_cases_and_sign(x, r); +} + // Natural or base 2 logarithm. // Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C =log(2) // and m is in the range [sqrt(1/2),sqrt(2)). 
In this range, the logarithm can @@ -1123,7 +1260,7 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_atan(const Pa constexpr Scalar kPiOverTwo = static_cast(EIGEN_PI / 2); - const Packet cst_signmask = pset1(-Scalar(0)); + const Packet cst_signmask = pset1(Scalar(-0.0)); const Packet cst_one = pset1(Scalar(1)); const Packet cst_pi_over_two = pset1(kPiOverTwo); diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h index ac0e2cfd3..673954e92 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h @@ -54,6 +54,14 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet pldexp_generic(const Packet& a, con template EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet pldexp_fast(const Packet& a, const Packet& exponent); +/** \internal \returns cbrt(x) for single precision float */ +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pcbrt_float(const Packet& x_in); + +/** \internal \returns cbrt(x) for double precision float */ +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pcbrt_double(const Packet& x_in); + /** \internal \returns log(x) for single precision float */ template EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog_float(const Packet _x); @@ -195,6 +203,7 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_round(const Packet& a); EIGEN_FLOAT_PACKET_FUNCTION(log, PACKET) \ EIGEN_FLOAT_PACKET_FUNCTION(log2, PACKET) \ EIGEN_FLOAT_PACKET_FUNCTION(exp, PACKET) \ + EIGEN_FLOAT_PACKET_FUNCTION(cbrt, PACKET) \ EIGEN_GENERIC_PACKET_FUNCTION(expm1, PACKET) \ EIGEN_GENERIC_PACKET_FUNCTION(exp2, PACKET) \ EIGEN_GENERIC_PACKET_FUNCTION(log1p, PACKET) \ @@ -208,6 +217,7 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_round(const Packet& a); EIGEN_DOUBLE_PACKET_FUNCTION(log2, PACKET) \ EIGEN_DOUBLE_PACKET_FUNCTION(exp, PACKET) \ 
EIGEN_DOUBLE_PACKET_FUNCTION(tanh, PACKET) \ + EIGEN_DOUBLE_PACKET_FUNCTION(cbrt, PACKET) \ EIGEN_GENERIC_PACKET_FUNCTION(atan, PACKET) \ EIGEN_GENERIC_PACKET_FUNCTION(exp2, PACKET) diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h index 8c6cc3271..6d7f038ff 100644 --- a/Eigen/src/Core/arch/NEON/PacketMath.h +++ b/Eigen/src/Core/arch/NEON/PacketMath.h @@ -207,6 +207,7 @@ struct packet_traits : default_packet_traits { HasExp = 1, HasSqrt = 1, HasRsqrt = 1, + HasCbrt = 1, HasTanh = EIGEN_FAST_MATH, HasErf = EIGEN_FAST_MATH, HasErfc = EIGEN_FAST_MATH, @@ -5160,6 +5161,7 @@ struct packet_traits : default_packet_traits { HasCos = EIGEN_FAST_MATH, HasSqrt = 1, HasRsqrt = 1, + HasCbrt = 1, HasTanh = EIGEN_FAST_MATH, HasErf = EIGEN_FAST_MATH, HasErfc = EIGEN_FAST_MATH diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h index 5e91fbaca..70d13d6af 100644 --- a/Eigen/src/Core/arch/SSE/PacketMath.h +++ b/Eigen/src/Core/arch/SSE/PacketMath.h @@ -195,6 +195,7 @@ struct packet_traits : default_packet_traits { HasBessel = 1, HasSqrt = 1, HasRsqrt = 1, + HasCbrt = 1, HasTanh = EIGEN_FAST_MATH, HasErf = EIGEN_FAST_MATH, HasErfc = EIGEN_FAST_MATH, @@ -222,6 +223,7 @@ struct packet_traits : default_packet_traits { HasExp = 1, HasSqrt = 1, HasRsqrt = 1, + HasCbrt = 1, HasATan = 1, HasATanh = 1, HasBlend = 1 diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h index 03542e331..ba7d97a03 100644 --- a/Eigen/src/Core/functors/UnaryFunctors.h +++ b/Eigen/src/Core/functors/UnaryFunctors.h @@ -558,11 +558,15 @@ struct functor_traits> { template struct scalar_cbrt_op { EIGEN_DEVICE_FUNC inline const Scalar operator()(const Scalar& a) const { return numext::cbrt(a); } + template + EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { + return internal::pcbrt(a); + } }; template struct functor_traits> { - enum { Cost = 5 * NumTraits::MulCost, PacketAccess = false }; 
+ enum { Cost = 20 * NumTraits::MulCost, PacketAccess = packet_traits::HasCbrt }; }; /** \internal diff --git a/test/packetmath.cpp b/test/packetmath.cpp index 7333ad8fa..76475923f 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp @@ -812,6 +812,7 @@ void packetmath() { CHECK_CWISE1_IF(PacketTraits::HasSqrt, numext::sqrt, internal::psqrt); CHECK_CWISE1_IF(PacketTraits::HasRsqrt, numext::rsqrt, internal::prsqrt); + CHECK_CWISE1_IF(PacketTraits::HasCbrt, numext::cbrt, internal::pcbrt); } // Notice that this definition works for complex types as well. @@ -833,6 +834,7 @@ Scalar log2(Scalar x) { CREATE_FUNCTOR(psqrt_functor, internal::psqrt); CREATE_FUNCTOR(prsqrt_functor, internal::prsqrt); +CREATE_FUNCTOR(pcbrt_functor, internal::pcbrt); // TODO(rmlarsen): Run this test for more functions. template @@ -1203,6 +1205,7 @@ void packetmath_real() { packetmath_test_IEEE_corner_cases(numext::sqrt, psqrt_functor()); packetmath_test_IEEE_corner_cases(numext::rsqrt, prsqrt_functor()); + packetmath_test_IEEE_corner_cases(numext::cbrt, pcbrt_functor()); // TODO(rmlarsen): Re-enable for half and bfloat16. if (PacketTraits::HasCos && !internal::is_same::value &&