From e5af9f87f255bd6c0f3129413b0df4ea75fedfc8 Mon Sep 17 00:00:00 2001 From: Charles Schlosser Date: Mon, 29 Aug 2022 19:23:54 +0000 Subject: [PATCH] Vectorize pow for integer base / exponent types --- Eigen/src/Core/arch/AVX/PacketMath.h | 2 + Eigen/src/Core/arch/AVX512/PacketMath.h | 1 + Eigen/src/Core/arch/AltiVec/PacketMath.h | 15 ++-- .../arch/Default/GenericPacketMathFunctions.h | 43 +++++++++++ Eigen/src/Core/arch/SSE/PacketMath.h | 1 + Eigen/src/Core/functors/UnaryFunctors.h | 6 +- test/array_cwise.cpp | 74 +++++++++++++++++++ 7 files changed, 134 insertions(+), 8 deletions(-) diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h index 28e65d4a0..7608776b1 100644 --- a/Eigen/src/Core/arch/AVX/PacketMath.h +++ b/Eigen/src/Core/arch/AVX/PacketMath.h @@ -208,6 +208,7 @@ template<> struct packet_traits : default_packet_traits enum { Vectorizable = 1, AlignedOnScalar = 1, + HasCmp = 1, size=8 }; }; @@ -222,6 +223,7 @@ template<> struct packet_traits : default_packet_traits enum { Vectorizable = 1, AlignedOnScalar = 1, + HasCmp = 1, size=4, // requires AVX512 diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h index 213b6902a..0158b12fd 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMath.h +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h @@ -174,6 +174,7 @@ template<> struct packet_traits : default_packet_traits enum { Vectorizable = 1, AlignedOnScalar = 1, + HasCmp = 1, size=16 }; }; diff --git a/Eigen/src/Core/arch/AltiVec/PacketMath.h b/Eigen/src/Core/arch/AltiVec/PacketMath.h index 5ec06c458..7460bdc56 100644 --- a/Eigen/src/Core/arch/AltiVec/PacketMath.h +++ b/Eigen/src/Core/arch/AltiVec/PacketMath.h @@ -253,7 +253,8 @@ struct packet_traits : default_packet_traits { HasShift = 1, HasMul = 1, HasDiv = 0, - HasBlend = 1 + HasBlend = 1, + HasCmp = 1 }; }; @@ -271,7 +272,8 @@ struct packet_traits : default_packet_traits { HasSub = 1, HasMul = 1, HasDiv = 0, - HasBlend = 1 + HasBlend = 1, + HasCmp = 1 }; }; @@ -289,7 +291,8 @@ struct packet_traits : default_packet_traits { HasSub = 1, HasMul = 1, HasDiv = 0, - HasBlend = 1 + HasBlend = 1, + HasCmp = 1 }; }; @@ -307,7 +310,8 @@ struct packet_traits : default_packet_traits { HasSub = 1, HasMul = 1, HasDiv = 0, - HasBlend = 1 + HasBlend = 1, + HasCmp = 1 }; }; @@ -325,7 +329,8 @@ struct packet_traits : default_packet_traits { HasSub = 1, HasMul = 1, HasDiv = 0, - HasBlend = 1 + HasBlend = 1, + HasCmp = 1 }; }; diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index ea72b6a28..694ccfc39 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -1880,6 +1880,36 @@ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_nonint_errors( result = pselect(pandnot(abs_x_is_one, x_is_neg), cst_pos_one, result); return result; } + +template +static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_int_int(const Packet& x, const ScalarExponent& exponent) { + typedef typename unpacket_traits::type Scalar; + + // integer base, integer exponent case + + // This routine handles negative and very large positive exponents + // Signed integer overflow and divide by zero is undefined behavior + // Unsigned intgers do not overflow + + const bool exponent_is_odd = unary_pow::is_odd::run(exponent); + + const Scalar zero = Scalar(0); + const Scalar pos_one = Scalar(1); + + const Packet cst_zero = pset1(zero); + const Packet cst_pos_one = pset1(pos_one); + + const Packet abs_x = pabs(x); + + const Packet pow_is_zero = exponent < 0 ? pcmp_lt(cst_pos_one, abs_x) : pzero(x); + const Packet pow_is_one = pcmp_eq(cst_pos_one, abs_x); + const Packet pow_is_neg = exponent_is_odd ? pcmp_lt(x, cst_zero) : pzero(x); + + Packet result = pselect(pow_is_zero, cst_zero, x); + result = pselect(pandnot(pow_is_one, pow_is_neg), cst_pos_one, result); + result = pselect(pand(pow_is_one, pow_is_neg), pnegate(cst_pos_one), result); + return result; +} } // end namespace unary_pow template { } }; +template +struct unary_pow_impl { + typedef typename unpacket_traits::type Scalar; + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) { + if (exponent < 0 || exponent > NumTraits::digits()) { + return unary_pow::handle_int_int(x, exponent); + } + else { + return unary_pow::int_pow(x, exponent); + } + } +}; + } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h index 42b698ac1..bd3424fa3 100644 --- a/Eigen/src/Core/arch/SSE/PacketMath.h +++ b/Eigen/src/Core/arch/SSE/PacketMath.h @@ -191,6 +191,7 @@ template<> struct packet_traits : default_packet_traits enum { Vectorizable = 1, AlignedOnScalar = 1, + HasCmp = 1, size=4, HasShift = 1, diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h index 15f2cf12a..6b1bcc757 100644 --- a/Eigen/src/Core/functors/UnaryFunctors.h +++ b/Eigen/src/Core/functors/UnaryFunctors.h @@ -1109,8 +1109,8 @@ struct scalar_unary_pow_op { scalar_unary_pow_op() {} }; -template -struct scalar_unary_pow_op { +template +struct scalar_unary_pow_op { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) { EIGEN_STATIC_ASSERT((is_arithmetic::value), EXPONENT_MUST_BE_ARITHMETIC); } @@ -1132,7 +1132,7 @@ template struct functor_traits> { enum { GenPacketAccess = functor_traits>::PacketAccess, - IntPacketAccess = !NumTraits::IsComplex && !NumTraits::IsInteger && packet_traits::HasMul && packet_traits::HasDiv && packet_traits::HasCmp, + IntPacketAccess = !NumTraits::IsComplex && packet_traits::HasMul && (packet_traits::HasDiv || NumTraits::IsInteger) && packet_traits::HasCmp, PacketAccess = NumTraits::IsInteger ? IntPacketAccess : (IntPacketAccess && GenPacketAccess), Cost = functor_traits>::Cost }; diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp index 3e96e4cb5..64c4c2a86 100644 --- a/test/array_cwise.cpp +++ b/test/array_cwise.cpp @@ -126,6 +126,72 @@ void pow_test() { VERIFY(all_pass); } +template +Scalar calc_overflow_threshold(const ScalarExponent exponent) { + EIGEN_USING_STD(exp2); + EIGEN_STATIC_ASSERT((NumTraits::digits() < 2 * NumTraits::digits()), BASE_TYPE_IS_TOO_BIG); + + if (exponent < 2) + return NumTraits::highest(); + else { + const double max_exponent = static_cast(NumTraits::digits()); + const double clamped_exponent = exponent < max_exponent ? static_cast(exponent) : max_exponent; + const double threshold = exp2(max_exponent / clamped_exponent); + return static_cast(threshold); + } +} + +template +void test_exponent(Exponent exponent) { + EIGEN_USING_STD(pow); + + const Base max_abs_bases = 10000; + // avoid integer overflow in Base type + Base threshold = calc_overflow_threshold(numext::abs(exponent)); + // avoid numbers that can't be verified with std::pow + double double_threshold = calc_overflow_threshold(numext::abs(exponent)); + // use the lesser of these two thresholds + Base testing_threshold = threshold < double_threshold ? threshold : static_cast(double_threshold); + // avoid divide by zero + Base min_abs_base = exponent < 0 ? 1 : 0; + // avoid excessively long test + Base max_abs_base = numext::mini(testing_threshold, max_abs_bases); + // test both vectorized and non-vectorized code paths + const Index array_size = 2 * internal::packet_traits::size + 1; + + Base max_base = numext::mini(testing_threshold, max_abs_bases); + Base min_base = NumTraits::IsSigned ? -max_base : 0; + + ArrayX x(array_size), y(array_size); + + bool all_pass = true; + + for (Base base = min_base; base <= max_base; base++) { + if (exponent < 0 && base == 0) continue; + x.setConstant(base); + y = x.pow(exponent); + Base e = pow(base, exponent); + for (Base a : y) { + bool pass = a == e; + all_pass &= pass; + if (!pass) { + std::cout << "pow(" << base << "," << exponent << ") = " << a << " != " << e << std::endl; + } + } + } + + VERIFY(all_pass); +} +template +void int_pow_test() { + Exponent max_exponent = NumTraits::digits(); + Exponent min_exponent = NumTraits::IsSigned ? -max_exponent : 0; + + for (Exponent exponent = min_exponent; exponent < max_exponent; exponent++) { + test_exponent(exponent); + } +} + template void array(const ArrayType& m) { typedef typename ArrayType::Scalar Scalar; @@ -500,6 +566,14 @@ template void array_real(const ArrayType& m) VERIFY_IS_APPROX(m3.pow(RealScalar(-2)), m3.square().inverse()); pow_test(); + typedef typename internal::make_integer::type SignedInt; + typedef typename std::make_unsigned::type UnsignedInt; + + int_pow_test(); + int_pow_test(); + int_pow_test(); + int_pow_test(); + VERIFY_IS_APPROX(log10(m3), log(m3)/numext::log(Scalar(10))); VERIFY_IS_APPROX(log2(m3), log(m3)/numext::log(Scalar(2)));