diff --git a/Eigen/src/Core/NumTraits.h b/Eigen/src/Core/NumTraits.h index 74edd2c27..4f1f992ee 100644 --- a/Eigen/src/Core/NumTraits.h +++ b/Eigen/src/Core/NumTraits.h @@ -63,10 +63,10 @@ struct default_digits_impl // Floating point { EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static int run() { - using std::log; + using std::log2; using std::ceil; typedef typename NumTraits::Real Real; - return int(ceil(-log(NumTraits::epsilon())/log(static_cast(2)))); + return int(ceil(-log2(NumTraits::epsilon()))); } }; diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h index f98801294..39b7be7c3 100644 --- a/Eigen/src/Core/functors/UnaryFunctors.h +++ b/Eigen/src/Core/functors/UnaryFunctors.h @@ -1070,75 +1070,94 @@ struct functor_traits > { }; }; -template ::IsInteger, - bool ExponentIsInteger = NumTraits::IsInteger, - bool BaseIsComplex = NumTraits::IsComplex, - bool ExponentIsComplex = NumTraits::IsComplex> +template ::IsInteger, + bool IsExponentInteger = NumTraits::IsInteger, + bool IsBaseComplex = NumTraits::IsComplex, + bool IsExponentComplex = NumTraits::IsComplex> struct scalar_unary_pow_op { typedef typename internal::promote_scalar_arg< - Scalar, ScalarExponent, - internal::has_ReturnType >::value>::type PromotedExponent; + Scalar, ExponentScalar, + internal::has_ReturnType >::value>::type PromotedExponent; typedef typename ScalarBinaryOpTraits::ReturnType result_type; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) { - EIGEN_STATIC_ASSERT((is_arithmetic::Real>::value), EXPONENT_MUST_BE_ARITHMETIC_OR_COMPLEX); - } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ExponentScalar& exponent) : m_exponent(exponent) {} EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const Scalar& a) const { EIGEN_USING_STD(pow); return static_cast(pow(a, m_exponent)); } private: - const ScalarExponent m_exponent; + const ExponentScalar m_exponent; scalar_unary_pow_op() {} }; -template -struct scalar_unary_pow_op { - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) { - EIGEN_STATIC_ASSERT((is_same, std::remove_const_t>::value), NON_INTEGER_EXPONENT_MUST_BE_SAME_TYPE_AS_BASE); - EIGEN_STATIC_ASSERT((is_arithmetic::value), EXPONENT_MUST_BE_ARITHMETIC); +template +constexpr int exponent_digits() { + return CHAR_BIT * sizeof(T) - NumTraits::digits() - NumTraits::IsSigned; +} + +template +struct is_floating_exactly_representable { + // TODO(rmlarsen): Add radix to NumTraits and enable this check. + // (NumTraits::radix == NumTraits::radix) && + static constexpr bool value = (exponent_digits() >= exponent_digits() && + NumTraits::digits() >= NumTraits::digits()); +}; + + +// Specialization for real, non-integer types, non-complex types. +template +struct scalar_unary_pow_op { + template ::value> + std::enable_if_t check_is_representable() const {} + + // Issue a deprecation warning if we do a narrowing conversion on the exponent. + template ::value> + EIGEN_DEPRECATED std::enable_if_t check_is_representable() const {} + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + scalar_unary_pow_op(const ExponentScalar& exponent) : m_exponent(static_cast(exponent)) { + check_is_representable(); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const Scalar& a) const { EIGEN_USING_STD(pow); return static_cast(pow(a, m_exponent)); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a) const { - return unary_pow_impl::run(a, m_exponent); + return unary_pow_impl::run(a, m_exponent); } private: - const ScalarExponent m_exponent; + const Scalar m_exponent; scalar_unary_pow_op() {} }; -template -struct scalar_unary_pow_op { - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) { - EIGEN_STATIC_ASSERT((is_arithmetic::value), EXPONENT_MUST_BE_ARITHMETIC); - } +template +struct scalar_unary_pow_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ExponentScalar& exponent) : m_exponent(exponent) {} // TODO: error handling logic for complex^real_integer EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const Scalar& a) const { - return unary_pow_impl::run(a, m_exponent); + return unary_pow_impl::run(a, m_exponent); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a) const { - return unary_pow_impl::run(a, m_exponent); + return unary_pow_impl::run(a, m_exponent); } private: - const ScalarExponent m_exponent; + const ExponentScalar m_exponent; scalar_unary_pow_op() {} }; -template -struct functor_traits> { +template +struct functor_traits> { enum { - GenPacketAccess = functor_traits>::PacketAccess, + GenPacketAccess = functor_traits>::PacketAccess, IntPacketAccess = !NumTraits::IsComplex && packet_traits::HasMul && (packet_traits::HasDiv || NumTraits::IsInteger) && packet_traits::HasCmp, - PacketAccess = NumTraits::IsInteger ? IntPacketAccess : (IntPacketAccess && GenPacketAccess), - Cost = functor_traits>::Cost + PacketAccess = NumTraits::IsInteger ? IntPacketAccess : (IntPacketAccess && GenPacketAccess), + Cost = functor_traits>::Cost }; }; diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp index 1d4f7b505..5a4c28278 100644 --- a/test/array_cwise.cpp +++ b/test/array_cwise.cpp @@ -138,7 +138,7 @@ Scalar calc_overflow_threshold(const ScalarExponent exponent) { // base^e <= highest ==> base <= 2^(log2(highest)/e) // For floating-point types, consider the bound for integer values that can be reproduced exactly = 2 ^ digits double highest_bits = numext::mini(static_cast(NumTraits::digits()), - log2(NumTraits::highest())); + static_cast(log2(NumTraits::highest()))); return static_cast( numext::floor(exp2(highest_bits / static_cast(exponent)))); } @@ -146,49 +146,90 @@ Scalar calc_overflow_threshold(const ScalarExponent exponent) { template void test_exponent(Exponent exponent) { + const Base max_abs_bases = static_cast(10000); + // avoid integer overflow in Base type + Base threshold = calc_overflow_threshold(numext::abs(exponent)); + // avoid numbers that can't be verified with std::pow + double double_threshold = calc_overflow_threshold(numext::abs(exponent)); + // use the lesser of these two thresholds + Base testing_threshold = + static_cast(threshold) < double_threshold ? threshold : static_cast(double_threshold); + // test both vectorized and non-vectorized code paths + const Index array_size = 2 * internal::packet_traits::size + 1; + + Base max_base = numext::mini(testing_threshold, max_abs_bases); + Base min_base = NumTraits::IsSigned ? -max_base : Base(0); + + ArrayX x(array_size), y(array_size); + bool all_pass = true; + for (Base base = min_base; base <= max_base; base++) { + if (exponent < 0 && base == 0) continue; + x.setConstant(base); + y = x.pow(exponent); EIGEN_USING_STD(pow); - - const Base max_abs_bases = 10000; - // avoid integer overflow in Base type - Base threshold = calc_overflow_threshold(numext::abs(exponent)); - // avoid numbers that can't be verified with std::pow - double double_threshold = calc_overflow_threshold(numext::abs(exponent)); - // use the lesser of these two thresholds - Base testing_threshold = threshold < double_threshold ? threshold : static_cast(double_threshold); - // test both vectorized and non-vectorized code paths - const Index array_size = 2 * internal::packet_traits::size + 1; - - Base max_base = numext::mini(testing_threshold, max_abs_bases); - Base min_base = NumTraits::IsSigned ? -max_base : 0; - - ArrayX x(array_size), y(array_size); - - bool all_pass = true; - - for (Base base = min_base; base <= max_base; base++) { - if (exponent < 0 && base == 0) continue; - x.setConstant(base); - y = x.pow(exponent); - Base e = pow(base, exponent); - for (Base a : y) { - bool pass = a == e; - all_pass &= pass; - if (!pass) { - std::cout << "pow(" << base << "," << exponent << ") = " << a << " != " << e << std::endl; - } - } + Base e = pow(base, static_cast(exponent)); + for (Base a : y) { + bool pass = (a == e); + if (!NumTraits::IsInteger) { + pass = pass || (((numext::isfinite)(e) && internal::isApprox(a, e)) || + ((numext::isnan)(a) && (numext::isnan)(e))); + } + all_pass &= pass; + if (!pass) { + std::cout << "pow(" << base << "," << exponent << ") = " << a << " != " << e << std::endl; + } } - - VERIFY(all_pass); + } + VERIFY(all_pass); } -template -void int_pow_test() { - Exponent max_exponent = NumTraits::digits(); - Exponent min_exponent = NumTraits::IsSigned ? -max_exponent : 0; - for (Exponent exponent = min_exponent; exponent < max_exponent; exponent++) { - test_exponent(exponent); - } +template +void unary_pow_test() { + Exponent max_exponent = static_cast(NumTraits::digits()); + Exponent min_exponent = static_cast(NumTraits::IsSigned ? -max_exponent : 0); + + for (Exponent exponent = min_exponent; exponent < max_exponent; ++exponent) { + test_exponent(exponent); + } +}; + +void mixed_pow_test() { + // The following cases will test promoting a smaller exponent type + // to a wider base type. + unary_pow_test(); + unary_pow_test(); + unary_pow_test(); + unary_pow_test(); + unary_pow_test(); + unary_pow_test(); + + // Although in the following cases the exponent cannot be represented exactly + // in the base type, we do not perform a conversion, but implement + // the operation using repeated squaring. + unary_pow_test(); + unary_pow_test(); + + // The following cases will test promoting a wider exponent type + // to a narrower base type. This should compile but generate a + // deprecation warning: + unary_pow_test(); +} + +void int_pow_test() { + unary_pow_test(); + unary_pow_test(); + unary_pow_test(); + unary_pow_test(); + + // Although in the following cases the exponent cannot be represented exactly + // in the base type, we do not perform a conversion, but implement the + // operation using repeated squaring. + unary_pow_test(); + unary_pow_test(); + unary_pow_test(); + unary_pow_test(); + unary_pow_test(); + unary_pow_test(); } template void array(const ArrayType& m) @@ -207,7 +248,7 @@ template void array(const ArrayType& m) // Here we cap the size of the values in m1 such that pow(3)/cube() // doesn't overflow and result in undefined behavior. Notice that because // pow(int, int) promotes its inputs and output to double (according to - // the C++ standard), we hvae to make sure that the result fits in 53 bits + // the C++ standard), we have to make sure that the result fits in 53 bits // for int64, RealScalar max_val = numext::mini(RealScalar(std::cbrt(NumTraits::highest())), @@ -565,14 +606,6 @@ template void array_real(const ArrayType& m) VERIFY_IS_APPROX(m3.pow(RealScalar(-2)), m3.square().inverse()); pow_test(); - typedef typename internal::make_integer::type SignedInt; - typedef typename std::make_unsigned::type UnsignedInt; - - int_pow_test(); - int_pow_test(); - int_pow_test(); - int_pow_test(); - VERIFY_IS_APPROX(log10(m3), log(m3)/numext::log(Scalar(10))); VERIFY_IS_APPROX(log2(m3), log(m3)/numext::log(Scalar(2))); @@ -590,6 +623,7 @@ template void array_real(const ArrayType& m) VERIFY_IS_APPROX(m3, m1); } + template void array_complex(const ArrayType& m) { typedef typename ArrayType::Scalar Scalar; @@ -823,6 +857,11 @@ EIGEN_DECLARE_TEST(array_cwise) CALL_SUBTEST_4( array_complex(ArrayXXcf(internal::random(1,EIGEN_TEST_MAX_SIZE), internal::random(1,EIGEN_TEST_MAX_SIZE))) ); } + for(int i = 0; i < g_repeat; i++) { + CALL_SUBTEST_6( int_pow_test() ); + CALL_SUBTEST_7( mixed_pow_test() ); + } + VERIFY((internal::is_same< internal::global_math_functions_filtering_base::type, int >::value)); VERIFY((internal::is_same< internal::global_math_functions_filtering_base::type, float >::value)); VERIFY((internal::is_same< internal::global_math_functions_filtering_base::type, ArrayBase >::value));