Vectorize cbrt for float and double.

This commit is contained in:
Rasmus Munk Larsen 2025-04-17 23:31:20 +00:00
parent 5330960900
commit 33f5f59614
12 changed files with 183 additions and 11 deletions

View File

@ -72,6 +72,7 @@ struct default_packet_traits {
HasReciprocal = 0,
HasSqrt = 0,
HasRsqrt = 0,
HasCbrt = 0,
HasExp = 0,
HasExpm1 = 0,
HasLog = 0,

View File

@ -836,8 +836,8 @@ EIGEN_DEVICE_FUNC std::enable_if_t<(std::numeric_limits<T>::has_infinity && !Num
template <typename T>
EIGEN_DEVICE_FUNC
std::enable_if_t<!(std::numeric_limits<T>::has_quiet_NaN || std::numeric_limits<T>::has_signaling_NaN), bool>
isnan_impl(const T&) {
std::enable_if_t<!(std::numeric_limits<T>::has_quiet_NaN || std::numeric_limits<T>::has_signaling_NaN), bool>
isnan_impl(const T&) {
return false;
}
@ -1361,11 +1361,17 @@ SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(sqrt, sqrt)
/** \returns the cube root of \a x. **/
template <typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T cbrt(const T& x) {
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE std::enable_if_t<!NumTraits<T>::IsComplex, T> cbrt(const T& x) {
EIGEN_USING_STD(cbrt);
return static_cast<T>(cbrt(x));
}
template <typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE std::enable_if_t<NumTraits<T>::IsComplex, T> cbrt(const T& x) {
EIGEN_USING_STD(pow);
return pow(x, typename NumTraits<T>::Real(1.0 / 3.0));
}
/** \returns the reciprocal square root of \a x. **/
template <typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T rsqrt(const T& x) {
@ -1394,17 +1400,17 @@ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double log(const double& x) {
#endif
template <typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
std::enable_if_t<NumTraits<T>::IsSigned || NumTraits<T>::IsComplex, typename NumTraits<T>::Real>
abs(const T& x) {
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE std::enable_if_t<NumTraits<T>::IsSigned || NumTraits<T>::IsComplex, typename NumTraits<T>::Real>
abs(const T& x) {
EIGEN_USING_STD(abs);
return abs(x);
}
template <typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
std::enable_if_t<!(NumTraits<T>::IsSigned || NumTraits<T>::IsComplex), typename NumTraits<T>::Real>
abs(const T& x) {
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE std::enable_if_t<!(NumTraits<T>::IsSigned || NumTraits<T>::IsComplex), typename NumTraits<T>::Real>
abs(const T& x) {
return x;
}

View File

@ -28,6 +28,7 @@ EIGEN_DOUBLE_PACKET_FUNCTION(log, Packet4d)
EIGEN_DOUBLE_PACKET_FUNCTION(log2, Packet4d)
EIGEN_DOUBLE_PACKET_FUNCTION(exp, Packet4d)
EIGEN_DOUBLE_PACKET_FUNCTION(tanh, Packet4d)
EIGEN_DOUBLE_PACKET_FUNCTION(cbrt, Packet4d)
#ifdef EIGEN_VECTORIZE_AVX2
EIGEN_DOUBLE_PACKET_FUNCTION(sin, Packet4d)
EIGEN_DOUBLE_PACKET_FUNCTION(cos, Packet4d)

View File

@ -122,6 +122,7 @@ struct packet_traits<float> : default_packet_traits {
HasBessel = 1,
HasSqrt = 1,
HasRsqrt = 1,
HasCbrt = 1,
HasTanh = EIGEN_FAST_MATH,
HasErf = EIGEN_FAST_MATH,
HasErfc = EIGEN_FAST_MATH,
@ -150,6 +151,7 @@ struct packet_traits<double> : default_packet_traits {
HasExp = 1,
HasSqrt = 1,
HasRsqrt = 1,
HasCbrt = 1,
HasATan = 1,
HasATanh = 1,
HasBlend = 1

View File

@ -128,6 +128,7 @@ struct packet_traits<float> : default_packet_traits {
HasATanh = 1,
HasSqrt = 1,
HasRsqrt = 1,
HasCbrt = 1,
HasLog = 1,
HasLog1p = 1,
HasExpm1 = 1,
@ -153,6 +154,7 @@ struct packet_traits<double> : default_packet_traits {
HasBlend = 1,
HasSqrt = 1,
HasRsqrt = 1,
HasCbrt = 1,
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
HasLog = 1,

View File

@ -186,6 +186,7 @@ struct packet_traits<float> : default_packet_traits {
HasExp = 1,
#ifdef EIGEN_VECTORIZE_VSX
HasSqrt = 1,
HasCbrt = 1,
#if !EIGEN_COMP_CLANG
HasRsqrt = 1,
#else
@ -3176,6 +3177,7 @@ struct packet_traits<double> : default_packet_traits {
HasLog = 0,
HasExp = 1,
HasSqrt = 1,
HasCbrt = 1,
#if !EIGEN_COMP_CLANG
HasRsqrt = 1,
#else

View File

@ -289,6 +289,143 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet pldexp_fast(const Packet& a, const
return pmul(a, preinterpret<Packet>(plogical_shift_left<MantissaBits>(e)));
}
// This function implements a single step of Halley's iteration for
// computing x = y^(1/3):
// x_{k+1} = x_k - (x_k^3 - y) x_k / (2x_k^3 + y)
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet cbrt_halley_iteration_step(const Packet& x_k,
const Packet& y) {
typedef typename unpacket_traits<Packet>::type Scalar;
Packet x_k_cb = pmul(x_k, pmul(x_k, x_k));
Packet denom = pmadd(pset1<Packet>(Scalar(2)), x_k_cb, y);
Packet num = psub(x_k_cb, y);
Packet r = pdiv(num, denom);
return pnmadd(x_k, r, x_k);
}
// Decompose the input such that x^(1/3) = y^(1/3) * 2^e_div3, and y is in the
// interval [0.125,1].
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet cbrt_decompose(const Packet& x, Packet& e_div3) {
typedef typename unpacket_traits<Packet>::type Scalar;
// Extract the significant s in the range [0.5,1) and exponent e, such that
// x = 2^e * s.
Packet e, s;
s = pfrexp(x, e);
// Split the exponent into a part divisible by 3 and the remainder.
// e = 3*e_div3 + e_mod3.
constexpr Scalar kOneThird = Scalar(1) / 3;
e_div3 = pceil(pmul(e, pset1<Packet>(kOneThird)));
Packet e_mod3 = pnmadd(pset1<Packet>(Scalar(3)), e_div3, e);
// Replace s by y = (s * 2^e_mod3).
return pldexp_fast(s, e_mod3);
}
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet cbrt_special_cases_and_sign(const Packet& x,
const Packet& abs_root) {
typedef typename unpacket_traits<Packet>::type Scalar;
// Set sign.
const Packet sign_mask = pset1<Packet>(Scalar(-0.0));
const Packet x_sign = pand(sign_mask, x);
Packet root = por(x_sign, abs_root);
// Handle non-finite and zero values of x.
// constexpr Scalar kInf = NumTraits<Scalar>::infinity();
const Packet is_not_finite = psub(x,x);;
const Packet is_zero = pcmp_eq(pzero(x), x);
const Packet use_root = por(is_not_finite, is_zero);
return pselect(use_root, x, root);
}
// Generic implementation of cbrt(x) for float.
//
// The algorithm computes the cubic root of the input by first
// decomposing it into a exponent and significant
// x = s * 2^e.
//
// We can then write the cube root as
//
// x^(1/3) = 2^(e/3) * s^(1/3)
// = 2^((3*e_div3 + e_mod3)/3) * s^(1/3)
// = 2^(e_div3) * 2^(e_mod3/3) * s^(1/3)
// = 2^(e_div3) * (s * 2^e_mod3)^(1/3)
//
// where e_div3 = ceil(e/3) and e_mod3 = e - 3*e_div3.
//
// The cube root of the second term y = (s * 2^e_mod3)^(1/3) is coarsely
// approximated using a cubic polynomial and subsequently refined using a
// single step of Halley's iteration, and finally the two terms are combined
// using pldexp_fast.
//
// Note: Many alternatives exist for implementing cbrt. See, for example,
// the excellent discussion in Kahan's note:
// https://csclub.uwaterloo.ca/~pbarfuss/qbrt.pdf
// This particular implementation was found to be very fast and accurate
// among several alternatives tried, but is probably not "optimal" on all
// platforms.
//
// This is accurate to 2 ULP.
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pcbrt_float(const Packet& x) {
typedef typename unpacket_traits<Packet>::type Scalar;
static_assert(std::is_same<Scalar, float>::value, "Scalar type must be float");
// Decompose the input such that x^(1/3) = y^(1/3) * 2^e_div3, and y is in the
// interval [0.125,1].
Packet e_div3;
const Packet y = cbrt_decompose(pabs(x), e_div3);
// Compute initial approximation accurate to 5.22e-3.
// The polynomial was computed using Rminimax.
constexpr float alpha[] = {5.9220016002655029296875e-01f, -1.3859539031982421875e+00f, 1.4581282138824462890625e+00f,
3.408401906490325927734375e-01f};
Packet r = ppolevl<Packet, 3>::run(y, alpha);
// Take one step of Halley's iteration.
r = cbrt_halley_iteration_step(r, y);
// Finally multiply by 2^(e_div3)
r = pldexp_fast(r, e_div3);
return cbrt_special_cases_and_sign(x, r);
}
// Generic implementation of cbrt(x) for double.
//
// The algorithm is identical to the one for float except that a different initial
// approximation is used for y^(1/3) and two Halley iteration steps are peformed.
//
// This is accurate to 1 ULP.
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pcbrt_double(const Packet& x) {
typedef typename unpacket_traits<Packet>::type Scalar;
static_assert(std::is_same<Scalar, double>::value, "Scalar type must be double");
// Decompose the input such that x^(1/3) = y^(1/3) * 2^e_div3, and y is in the
// interval [0.125,1].
Packet e_div3;
const Packet y = cbrt_decompose(pabs(x), e_div3);
// Compute initial approximation accurate to 0.016.
// The polynomial was computed using Rminimax.
constexpr double alpha[] = {-4.69470621553356115551736138513660989701747894287109375e-01,
1.072314636518546304699839311069808900356292724609375e+00,
3.81249427609571867048288140722434036433696746826171875e-01};
Packet r = ppolevl<Packet, 2>::run(y, alpha);
// Take two steps of Halley's iteration.
r = cbrt_halley_iteration_step(r, y);
r = cbrt_halley_iteration_step(r, y);
// Finally multiply by 2^(e_div3).
r = pldexp_fast(r, e_div3);
return cbrt_special_cases_and_sign(x, r);
}
// Natural or base 2 logarithm.
// Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C =log(2)
// and m is in the range [sqrt(1/2),sqrt(2)). In this range, the logarithm can
@ -1123,7 +1260,7 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_atan(const Pa
constexpr Scalar kPiOverTwo = static_cast<Scalar>(EIGEN_PI / 2);
const Packet cst_signmask = pset1<Packet>(-Scalar(0));
const Packet cst_signmask = pset1<Packet>(Scalar(-0.0));
const Packet cst_one = pset1<Packet>(Scalar(1));
const Packet cst_pi_over_two = pset1<Packet>(kPiOverTwo);

View File

@ -54,6 +54,14 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet pldexp_generic(const Packet& a, con
template <typename Packet>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet pldexp_fast(const Packet& a, const Packet& exponent);
/** \internal \returns cbrt(x) for single precision float */
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pcbrt_float(const Packet& x_in);
/** \internal \returns cbrt(x) for double precision float */
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pcbrt_double(const Packet& x_in);
/** \internal \returns log(x) for single precision float */
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog_float(const Packet _x);
@ -195,6 +203,7 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_round(const Packet& a);
EIGEN_FLOAT_PACKET_FUNCTION(log, PACKET) \
EIGEN_FLOAT_PACKET_FUNCTION(log2, PACKET) \
EIGEN_FLOAT_PACKET_FUNCTION(exp, PACKET) \
EIGEN_FLOAT_PACKET_FUNCTION(cbrt, PACKET) \
EIGEN_GENERIC_PACKET_FUNCTION(expm1, PACKET) \
EIGEN_GENERIC_PACKET_FUNCTION(exp2, PACKET) \
EIGEN_GENERIC_PACKET_FUNCTION(log1p, PACKET) \
@ -208,6 +217,7 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_round(const Packet& a);
EIGEN_DOUBLE_PACKET_FUNCTION(log2, PACKET) \
EIGEN_DOUBLE_PACKET_FUNCTION(exp, PACKET) \
EIGEN_DOUBLE_PACKET_FUNCTION(tanh, PACKET) \
EIGEN_DOUBLE_PACKET_FUNCTION(cbrt, PACKET) \
EIGEN_GENERIC_PACKET_FUNCTION(atan, PACKET) \
EIGEN_GENERIC_PACKET_FUNCTION(exp2, PACKET)

View File

@ -207,6 +207,7 @@ struct packet_traits<float> : default_packet_traits {
HasExp = 1,
HasSqrt = 1,
HasRsqrt = 1,
HasCbrt = 1,
HasTanh = EIGEN_FAST_MATH,
HasErf = EIGEN_FAST_MATH,
HasErfc = EIGEN_FAST_MATH,
@ -5160,6 +5161,7 @@ struct packet_traits<double> : default_packet_traits {
HasCos = EIGEN_FAST_MATH,
HasSqrt = 1,
HasRsqrt = 1,
HasCbrt = 1,
HasTanh = EIGEN_FAST_MATH,
HasErf = EIGEN_FAST_MATH,
HasErfc = EIGEN_FAST_MATH

View File

@ -195,6 +195,7 @@ struct packet_traits<float> : default_packet_traits {
HasBessel = 1,
HasSqrt = 1,
HasRsqrt = 1,
HasCbrt = 1,
HasTanh = EIGEN_FAST_MATH,
HasErf = EIGEN_FAST_MATH,
HasErfc = EIGEN_FAST_MATH,
@ -222,6 +223,7 @@ struct packet_traits<double> : default_packet_traits {
HasExp = 1,
HasSqrt = 1,
HasRsqrt = 1,
HasCbrt = 1,
HasATan = 1,
HasATanh = 1,
HasBlend = 1

View File

@ -558,11 +558,15 @@ struct functor_traits<scalar_sqrt_op<bool>> {
template <typename Scalar>
struct scalar_cbrt_op {
EIGEN_DEVICE_FUNC inline const Scalar operator()(const Scalar& a) const { return numext::cbrt(a); }
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const {
return internal::pcbrt(a);
}
};
template <typename Scalar>
struct functor_traits<scalar_cbrt_op<Scalar>> {
enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false };
enum { Cost = 20 * NumTraits<Scalar>::MulCost, PacketAccess = packet_traits<Scalar>::HasCbrt };
};
/** \internal

View File

@ -812,6 +812,7 @@ void packetmath() {
CHECK_CWISE1_IF(PacketTraits::HasSqrt, numext::sqrt, internal::psqrt);
CHECK_CWISE1_IF(PacketTraits::HasRsqrt, numext::rsqrt, internal::prsqrt);
CHECK_CWISE1_IF(PacketTraits::HasCbrt, numext::cbrt, internal::pcbrt);
}
// Notice that this definition works for complex types as well.
@ -833,6 +834,7 @@ Scalar log2(Scalar x) {
CREATE_FUNCTOR(psqrt_functor, internal::psqrt);
CREATE_FUNCTOR(prsqrt_functor, internal::prsqrt);
CREATE_FUNCTOR(pcbrt_functor, internal::pcbrt);
// TODO(rmlarsen): Run this test for more functions.
template <bool Cond, typename Scalar, typename Packet, typename RefFunctorT, typename FunctorT>
@ -1203,6 +1205,7 @@ void packetmath_real() {
packetmath_test_IEEE_corner_cases<PacketTraits::HasSqrt, Scalar, Packet>(numext::sqrt<Scalar>, psqrt_functor());
packetmath_test_IEEE_corner_cases<PacketTraits::HasRsqrt, Scalar, Packet>(numext::rsqrt<Scalar>, prsqrt_functor());
packetmath_test_IEEE_corner_cases<PacketTraits::HasCbrt, Scalar, Packet>(numext::cbrt<Scalar>, pcbrt_functor());
// TODO(rmlarsen): Re-enable for half and bfloat16.
if (PacketTraits::HasCos && !internal::is_same<Scalar, half>::value &&