mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 02:33:59 +08:00
Vectorize cbrt for float and double.
This commit is contained in:
parent
5330960900
commit
33f5f59614
@ -72,6 +72,7 @@ struct default_packet_traits {
|
||||
HasReciprocal = 0,
|
||||
HasSqrt = 0,
|
||||
HasRsqrt = 0,
|
||||
HasCbrt = 0,
|
||||
HasExp = 0,
|
||||
HasExpm1 = 0,
|
||||
HasLog = 0,
|
||||
|
@ -836,8 +836,8 @@ EIGEN_DEVICE_FUNC std::enable_if_t<(std::numeric_limits<T>::has_infinity && !Num
|
||||
|
||||
template <typename T>
|
||||
EIGEN_DEVICE_FUNC
|
||||
std::enable_if_t<!(std::numeric_limits<T>::has_quiet_NaN || std::numeric_limits<T>::has_signaling_NaN), bool>
|
||||
isnan_impl(const T&) {
|
||||
std::enable_if_t<!(std::numeric_limits<T>::has_quiet_NaN || std::numeric_limits<T>::has_signaling_NaN), bool>
|
||||
isnan_impl(const T&) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -1361,11 +1361,17 @@ SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(sqrt, sqrt)
|
||||
|
||||
/** \returns the cube root of \a x. **/
|
||||
template <typename T>
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T cbrt(const T& x) {
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE std::enable_if_t<!NumTraits<T>::IsComplex, T> cbrt(const T& x) {
|
||||
EIGEN_USING_STD(cbrt);
|
||||
return static_cast<T>(cbrt(x));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE std::enable_if_t<NumTraits<T>::IsComplex, T> cbrt(const T& x) {
|
||||
EIGEN_USING_STD(pow);
|
||||
return pow(x, typename NumTraits<T>::Real(1.0 / 3.0));
|
||||
}
|
||||
|
||||
/** \returns the reciprocal square root of \a x. **/
|
||||
template <typename T>
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T rsqrt(const T& x) {
|
||||
@ -1394,17 +1400,17 @@ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double log(const double& x) {
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
|
||||
std::enable_if_t<NumTraits<T>::IsSigned || NumTraits<T>::IsComplex, typename NumTraits<T>::Real>
|
||||
abs(const T& x) {
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_ALWAYS_INLINE std::enable_if_t<NumTraits<T>::IsSigned || NumTraits<T>::IsComplex, typename NumTraits<T>::Real>
|
||||
abs(const T& x) {
|
||||
EIGEN_USING_STD(abs);
|
||||
return abs(x);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
|
||||
std::enable_if_t<!(NumTraits<T>::IsSigned || NumTraits<T>::IsComplex), typename NumTraits<T>::Real>
|
||||
abs(const T& x) {
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_ALWAYS_INLINE std::enable_if_t<!(NumTraits<T>::IsSigned || NumTraits<T>::IsComplex), typename NumTraits<T>::Real>
|
||||
abs(const T& x) {
|
||||
return x;
|
||||
}
|
||||
|
||||
|
@ -28,6 +28,7 @@ EIGEN_DOUBLE_PACKET_FUNCTION(log, Packet4d)
|
||||
EIGEN_DOUBLE_PACKET_FUNCTION(log2, Packet4d)
|
||||
EIGEN_DOUBLE_PACKET_FUNCTION(exp, Packet4d)
|
||||
EIGEN_DOUBLE_PACKET_FUNCTION(tanh, Packet4d)
|
||||
EIGEN_DOUBLE_PACKET_FUNCTION(cbrt, Packet4d)
|
||||
#ifdef EIGEN_VECTORIZE_AVX2
|
||||
EIGEN_DOUBLE_PACKET_FUNCTION(sin, Packet4d)
|
||||
EIGEN_DOUBLE_PACKET_FUNCTION(cos, Packet4d)
|
||||
|
@ -122,6 +122,7 @@ struct packet_traits<float> : default_packet_traits {
|
||||
HasBessel = 1,
|
||||
HasSqrt = 1,
|
||||
HasRsqrt = 1,
|
||||
HasCbrt = 1,
|
||||
HasTanh = EIGEN_FAST_MATH,
|
||||
HasErf = EIGEN_FAST_MATH,
|
||||
HasErfc = EIGEN_FAST_MATH,
|
||||
@ -150,6 +151,7 @@ struct packet_traits<double> : default_packet_traits {
|
||||
HasExp = 1,
|
||||
HasSqrt = 1,
|
||||
HasRsqrt = 1,
|
||||
HasCbrt = 1,
|
||||
HasATan = 1,
|
||||
HasATanh = 1,
|
||||
HasBlend = 1
|
||||
|
@ -128,6 +128,7 @@ struct packet_traits<float> : default_packet_traits {
|
||||
HasATanh = 1,
|
||||
HasSqrt = 1,
|
||||
HasRsqrt = 1,
|
||||
HasCbrt = 1,
|
||||
HasLog = 1,
|
||||
HasLog1p = 1,
|
||||
HasExpm1 = 1,
|
||||
@ -153,6 +154,7 @@ struct packet_traits<double> : default_packet_traits {
|
||||
HasBlend = 1,
|
||||
HasSqrt = 1,
|
||||
HasRsqrt = 1,
|
||||
HasCbrt = 1,
|
||||
HasSin = EIGEN_FAST_MATH,
|
||||
HasCos = EIGEN_FAST_MATH,
|
||||
HasLog = 1,
|
||||
|
@ -186,6 +186,7 @@ struct packet_traits<float> : default_packet_traits {
|
||||
HasExp = 1,
|
||||
#ifdef EIGEN_VECTORIZE_VSX
|
||||
HasSqrt = 1,
|
||||
HasCbrt = 1,
|
||||
#if !EIGEN_COMP_CLANG
|
||||
HasRsqrt = 1,
|
||||
#else
|
||||
@ -3176,6 +3177,7 @@ struct packet_traits<double> : default_packet_traits {
|
||||
HasLog = 0,
|
||||
HasExp = 1,
|
||||
HasSqrt = 1,
|
||||
HasCbrt = 1,
|
||||
#if !EIGEN_COMP_CLANG
|
||||
HasRsqrt = 1,
|
||||
#else
|
||||
|
@ -289,6 +289,143 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet pldexp_fast(const Packet& a, const
|
||||
return pmul(a, preinterpret<Packet>(plogical_shift_left<MantissaBits>(e)));
|
||||
}
|
||||
|
||||
// This function implements a single step of Halley's iteration for
|
||||
// computing x = y^(1/3):
|
||||
// x_{k+1} = x_k - (x_k^3 - y) x_k / (2x_k^3 + y)
|
||||
template <typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet cbrt_halley_iteration_step(const Packet& x_k,
|
||||
const Packet& y) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
Packet x_k_cb = pmul(x_k, pmul(x_k, x_k));
|
||||
Packet denom = pmadd(pset1<Packet>(Scalar(2)), x_k_cb, y);
|
||||
Packet num = psub(x_k_cb, y);
|
||||
Packet r = pdiv(num, denom);
|
||||
return pnmadd(x_k, r, x_k);
|
||||
}
|
||||
|
||||
// Decompose the input such that x^(1/3) = y^(1/3) * 2^e_div3, and y is in the
|
||||
// interval [0.125,1].
|
||||
template <typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet cbrt_decompose(const Packet& x, Packet& e_div3) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
// Extract the significant s in the range [0.5,1) and exponent e, such that
|
||||
// x = 2^e * s.
|
||||
Packet e, s;
|
||||
s = pfrexp(x, e);
|
||||
|
||||
// Split the exponent into a part divisible by 3 and the remainder.
|
||||
// e = 3*e_div3 + e_mod3.
|
||||
constexpr Scalar kOneThird = Scalar(1) / 3;
|
||||
e_div3 = pceil(pmul(e, pset1<Packet>(kOneThird)));
|
||||
Packet e_mod3 = pnmadd(pset1<Packet>(Scalar(3)), e_div3, e);
|
||||
|
||||
// Replace s by y = (s * 2^e_mod3).
|
||||
return pldexp_fast(s, e_mod3);
|
||||
}
|
||||
|
||||
template <typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet cbrt_special_cases_and_sign(const Packet& x,
|
||||
const Packet& abs_root) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
|
||||
// Set sign.
|
||||
const Packet sign_mask = pset1<Packet>(Scalar(-0.0));
|
||||
const Packet x_sign = pand(sign_mask, x);
|
||||
Packet root = por(x_sign, abs_root);
|
||||
|
||||
// Handle non-finite and zero values of x.
|
||||
// constexpr Scalar kInf = NumTraits<Scalar>::infinity();
|
||||
const Packet is_not_finite = psub(x,x);;
|
||||
const Packet is_zero = pcmp_eq(pzero(x), x);
|
||||
const Packet use_root = por(is_not_finite, is_zero);
|
||||
return pselect(use_root, x, root);
|
||||
}
|
||||
|
||||
// Generic implementation of cbrt(x) for float.
|
||||
//
|
||||
// The algorithm computes the cubic root of the input by first
|
||||
// decomposing it into a exponent and significant
|
||||
// x = s * 2^e.
|
||||
//
|
||||
// We can then write the cube root as
|
||||
//
|
||||
// x^(1/3) = 2^(e/3) * s^(1/3)
|
||||
// = 2^((3*e_div3 + e_mod3)/3) * s^(1/3)
|
||||
// = 2^(e_div3) * 2^(e_mod3/3) * s^(1/3)
|
||||
// = 2^(e_div3) * (s * 2^e_mod3)^(1/3)
|
||||
//
|
||||
// where e_div3 = ceil(e/3) and e_mod3 = e - 3*e_div3.
|
||||
//
|
||||
// The cube root of the second term y = (s * 2^e_mod3)^(1/3) is coarsely
|
||||
// approximated using a cubic polynomial and subsequently refined using a
|
||||
// single step of Halley's iteration, and finally the two terms are combined
|
||||
// using pldexp_fast.
|
||||
//
|
||||
// Note: Many alternatives exist for implementing cbrt. See, for example,
|
||||
// the excellent discussion in Kahan's note:
|
||||
// https://csclub.uwaterloo.ca/~pbarfuss/qbrt.pdf
|
||||
// This particular implementation was found to be very fast and accurate
|
||||
// among several alternatives tried, but is probably not "optimal" on all
|
||||
// platforms.
|
||||
//
|
||||
// This is accurate to 2 ULP.
|
||||
template <typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pcbrt_float(const Packet& x) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
static_assert(std::is_same<Scalar, float>::value, "Scalar type must be float");
|
||||
|
||||
// Decompose the input such that x^(1/3) = y^(1/3) * 2^e_div3, and y is in the
|
||||
// interval [0.125,1].
|
||||
Packet e_div3;
|
||||
const Packet y = cbrt_decompose(pabs(x), e_div3);
|
||||
|
||||
// Compute initial approximation accurate to 5.22e-3.
|
||||
// The polynomial was computed using Rminimax.
|
||||
constexpr float alpha[] = {5.9220016002655029296875e-01f, -1.3859539031982421875e+00f, 1.4581282138824462890625e+00f,
|
||||
3.408401906490325927734375e-01f};
|
||||
Packet r = ppolevl<Packet, 3>::run(y, alpha);
|
||||
|
||||
// Take one step of Halley's iteration.
|
||||
r = cbrt_halley_iteration_step(r, y);
|
||||
|
||||
// Finally multiply by 2^(e_div3)
|
||||
r = pldexp_fast(r, e_div3);
|
||||
|
||||
return cbrt_special_cases_and_sign(x, r);
|
||||
}
|
||||
|
||||
// Generic implementation of cbrt(x) for double.
|
||||
//
|
||||
// The algorithm is identical to the one for float except that a different initial
|
||||
// approximation is used for y^(1/3) and two Halley iteration steps are peformed.
|
||||
//
|
||||
// This is accurate to 1 ULP.
|
||||
template <typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pcbrt_double(const Packet& x) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
static_assert(std::is_same<Scalar, double>::value, "Scalar type must be double");
|
||||
|
||||
// Decompose the input such that x^(1/3) = y^(1/3) * 2^e_div3, and y is in the
|
||||
// interval [0.125,1].
|
||||
Packet e_div3;
|
||||
const Packet y = cbrt_decompose(pabs(x), e_div3);
|
||||
|
||||
// Compute initial approximation accurate to 0.016.
|
||||
// The polynomial was computed using Rminimax.
|
||||
constexpr double alpha[] = {-4.69470621553356115551736138513660989701747894287109375e-01,
|
||||
1.072314636518546304699839311069808900356292724609375e+00,
|
||||
3.81249427609571867048288140722434036433696746826171875e-01};
|
||||
Packet r = ppolevl<Packet, 2>::run(y, alpha);
|
||||
|
||||
// Take two steps of Halley's iteration.
|
||||
r = cbrt_halley_iteration_step(r, y);
|
||||
r = cbrt_halley_iteration_step(r, y);
|
||||
|
||||
// Finally multiply by 2^(e_div3).
|
||||
r = pldexp_fast(r, e_div3);
|
||||
return cbrt_special_cases_and_sign(x, r);
|
||||
}
|
||||
|
||||
// Natural or base 2 logarithm.
|
||||
// Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C =log(2)
|
||||
// and m is in the range [sqrt(1/2),sqrt(2)). In this range, the logarithm can
|
||||
@ -1123,7 +1260,7 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_atan(const Pa
|
||||
|
||||
constexpr Scalar kPiOverTwo = static_cast<Scalar>(EIGEN_PI / 2);
|
||||
|
||||
const Packet cst_signmask = pset1<Packet>(-Scalar(0));
|
||||
const Packet cst_signmask = pset1<Packet>(Scalar(-0.0));
|
||||
const Packet cst_one = pset1<Packet>(Scalar(1));
|
||||
const Packet cst_pi_over_two = pset1<Packet>(kPiOverTwo);
|
||||
|
||||
|
@ -54,6 +54,14 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet pldexp_generic(const Packet& a, con
|
||||
template <typename Packet>
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet pldexp_fast(const Packet& a, const Packet& exponent);
|
||||
|
||||
/** \internal \returns cbrt(x) for single precision float */
|
||||
template <typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pcbrt_float(const Packet& x_in);
|
||||
|
||||
/** \internal \returns cbrt(x) for double precision float */
|
||||
template <typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pcbrt_double(const Packet& x_in);
|
||||
|
||||
/** \internal \returns log(x) for single precision float */
|
||||
template <typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog_float(const Packet _x);
|
||||
@ -195,6 +203,7 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_round(const Packet& a);
|
||||
EIGEN_FLOAT_PACKET_FUNCTION(log, PACKET) \
|
||||
EIGEN_FLOAT_PACKET_FUNCTION(log2, PACKET) \
|
||||
EIGEN_FLOAT_PACKET_FUNCTION(exp, PACKET) \
|
||||
EIGEN_FLOAT_PACKET_FUNCTION(cbrt, PACKET) \
|
||||
EIGEN_GENERIC_PACKET_FUNCTION(expm1, PACKET) \
|
||||
EIGEN_GENERIC_PACKET_FUNCTION(exp2, PACKET) \
|
||||
EIGEN_GENERIC_PACKET_FUNCTION(log1p, PACKET) \
|
||||
@ -208,6 +217,7 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_round(const Packet& a);
|
||||
EIGEN_DOUBLE_PACKET_FUNCTION(log2, PACKET) \
|
||||
EIGEN_DOUBLE_PACKET_FUNCTION(exp, PACKET) \
|
||||
EIGEN_DOUBLE_PACKET_FUNCTION(tanh, PACKET) \
|
||||
EIGEN_DOUBLE_PACKET_FUNCTION(cbrt, PACKET) \
|
||||
EIGEN_GENERIC_PACKET_FUNCTION(atan, PACKET) \
|
||||
EIGEN_GENERIC_PACKET_FUNCTION(exp2, PACKET)
|
||||
|
||||
|
@ -207,6 +207,7 @@ struct packet_traits<float> : default_packet_traits {
|
||||
HasExp = 1,
|
||||
HasSqrt = 1,
|
||||
HasRsqrt = 1,
|
||||
HasCbrt = 1,
|
||||
HasTanh = EIGEN_FAST_MATH,
|
||||
HasErf = EIGEN_FAST_MATH,
|
||||
HasErfc = EIGEN_FAST_MATH,
|
||||
@ -5160,6 +5161,7 @@ struct packet_traits<double> : default_packet_traits {
|
||||
HasCos = EIGEN_FAST_MATH,
|
||||
HasSqrt = 1,
|
||||
HasRsqrt = 1,
|
||||
HasCbrt = 1,
|
||||
HasTanh = EIGEN_FAST_MATH,
|
||||
HasErf = EIGEN_FAST_MATH,
|
||||
HasErfc = EIGEN_FAST_MATH
|
||||
|
@ -195,6 +195,7 @@ struct packet_traits<float> : default_packet_traits {
|
||||
HasBessel = 1,
|
||||
HasSqrt = 1,
|
||||
HasRsqrt = 1,
|
||||
HasCbrt = 1,
|
||||
HasTanh = EIGEN_FAST_MATH,
|
||||
HasErf = EIGEN_FAST_MATH,
|
||||
HasErfc = EIGEN_FAST_MATH,
|
||||
@ -222,6 +223,7 @@ struct packet_traits<double> : default_packet_traits {
|
||||
HasExp = 1,
|
||||
HasSqrt = 1,
|
||||
HasRsqrt = 1,
|
||||
HasCbrt = 1,
|
||||
HasATan = 1,
|
||||
HasATanh = 1,
|
||||
HasBlend = 1
|
||||
|
@ -558,11 +558,15 @@ struct functor_traits<scalar_sqrt_op<bool>> {
|
||||
template <typename Scalar>
|
||||
struct scalar_cbrt_op {
|
||||
EIGEN_DEVICE_FUNC inline const Scalar operator()(const Scalar& a) const { return numext::cbrt(a); }
|
||||
template <typename Packet>
|
||||
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const {
|
||||
return internal::pcbrt(a);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Scalar>
|
||||
struct functor_traits<scalar_cbrt_op<Scalar>> {
|
||||
enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false };
|
||||
enum { Cost = 20 * NumTraits<Scalar>::MulCost, PacketAccess = packet_traits<Scalar>::HasCbrt };
|
||||
};
|
||||
|
||||
/** \internal
|
||||
|
@ -812,6 +812,7 @@ void packetmath() {
|
||||
|
||||
CHECK_CWISE1_IF(PacketTraits::HasSqrt, numext::sqrt, internal::psqrt);
|
||||
CHECK_CWISE1_IF(PacketTraits::HasRsqrt, numext::rsqrt, internal::prsqrt);
|
||||
CHECK_CWISE1_IF(PacketTraits::HasCbrt, numext::cbrt, internal::pcbrt);
|
||||
}
|
||||
|
||||
// Notice that this definition works for complex types as well.
|
||||
@ -833,6 +834,7 @@ Scalar log2(Scalar x) {
|
||||
|
||||
CREATE_FUNCTOR(psqrt_functor, internal::psqrt);
|
||||
CREATE_FUNCTOR(prsqrt_functor, internal::prsqrt);
|
||||
CREATE_FUNCTOR(pcbrt_functor, internal::pcbrt);
|
||||
|
||||
// TODO(rmlarsen): Run this test for more functions.
|
||||
template <bool Cond, typename Scalar, typename Packet, typename RefFunctorT, typename FunctorT>
|
||||
@ -1203,6 +1205,7 @@ void packetmath_real() {
|
||||
|
||||
packetmath_test_IEEE_corner_cases<PacketTraits::HasSqrt, Scalar, Packet>(numext::sqrt<Scalar>, psqrt_functor());
|
||||
packetmath_test_IEEE_corner_cases<PacketTraits::HasRsqrt, Scalar, Packet>(numext::rsqrt<Scalar>, prsqrt_functor());
|
||||
packetmath_test_IEEE_corner_cases<PacketTraits::HasCbrt, Scalar, Packet>(numext::cbrt<Scalar>, pcbrt_functor());
|
||||
|
||||
// TODO(rmlarsen): Re-enable for half and bfloat16.
|
||||
if (PacketTraits::HasCos && !internal::is_same<Scalar, half>::value &&
|
||||
|
Loading…
x
Reference in New Issue
Block a user