Allow mixed types for pow(), as long as the exponent is exactly representable in the base type.

This commit is contained in:
Rasmus Munk Larsen 2022-09-12 21:55:30 +00:00
parent b2c82a9347
commit afc014f1b5
3 changed files with 140 additions and 82 deletions

View File

@ -63,10 +63,10 @@ struct default_digits_impl<T,false,false> // Floating point
{
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
static int run() {
using std::log;
using std::log2;
using std::ceil;
typedef typename NumTraits<T>::Real Real;
return int(ceil(-log(NumTraits<Real>::epsilon())/log(static_cast<Real>(2))));
return int(ceil(-log2(NumTraits<Real>::epsilon())));
}
};

View File

@ -1070,75 +1070,94 @@ struct functor_traits<scalar_logistic_op<T> > {
};
};
template <typename Scalar, typename ScalarExponent,
bool BaseIsInteger = NumTraits<Scalar>::IsInteger,
bool ExponentIsInteger = NumTraits<ScalarExponent>::IsInteger,
bool BaseIsComplex = NumTraits<Scalar>::IsComplex,
bool ExponentIsComplex = NumTraits<ScalarExponent>::IsComplex>
template <typename Scalar, typename ExponentScalar,
bool IsBaseInteger = NumTraits<Scalar>::IsInteger,
bool IsExponentInteger = NumTraits<ExponentScalar>::IsInteger,
bool IsBaseComplex = NumTraits<Scalar>::IsComplex,
bool IsExponentComplex = NumTraits<ExponentScalar>::IsComplex>
struct scalar_unary_pow_op {
typedef typename internal::promote_scalar_arg<
Scalar, ScalarExponent,
internal::has_ReturnType<ScalarBinaryOpTraits<Scalar,ScalarExponent,scalar_unary_pow_op> >::value>::type PromotedExponent;
Scalar, ExponentScalar,
internal::has_ReturnType<ScalarBinaryOpTraits<Scalar,ExponentScalar,scalar_unary_pow_op> >::value>::type PromotedExponent;
typedef typename ScalarBinaryOpTraits<Scalar, PromotedExponent, scalar_unary_pow_op>::ReturnType result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) {
EIGEN_STATIC_ASSERT((is_arithmetic<typename NumTraits<ScalarExponent>::Real>::value), EXPONENT_MUST_BE_ARITHMETIC_OR_COMPLEX);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ExponentScalar& exponent) : m_exponent(exponent) {}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const Scalar& a) const {
EIGEN_USING_STD(pow);
return static_cast<result_type>(pow(a, m_exponent));
}
private:
const ScalarExponent m_exponent;
const ExponentScalar m_exponent;
scalar_unary_pow_op() {}
};
template <typename Scalar, typename ScalarExponent>
struct scalar_unary_pow_op<Scalar, ScalarExponent, false, false, false, false> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) {
EIGEN_STATIC_ASSERT((is_same<std::remove_const_t<Scalar>, std::remove_const_t<ScalarExponent>>::value), NON_INTEGER_EXPONENT_MUST_BE_SAME_TYPE_AS_BASE);
EIGEN_STATIC_ASSERT((is_arithmetic<ScalarExponent>::value), EXPONENT_MUST_BE_ARITHMETIC);
// Returns the number of storage bits of T that remain after subtracting the
// significand digits reported by NumTraits<T>::digits() and, for signed
// types, one sign bit. is_floating_exactly_representable compares this
// quantity between two types as a measure of their exponent range.
//
// The arithmetic is carried out in int: `CHAR_BIT * sizeof(T)` has type
// std::size_t, so without the casts the subtraction would be performed in
// unsigned arithmetic and then implicitly narrowed to the int return type.
template <typename T>
constexpr int exponent_digits() {
  return static_cast<int>(CHAR_BIT * sizeof(T)) - NumTraits<T>::digits() -
         static_cast<int>(NumTraits<T>::IsSigned);
}
// Compile-time trait comparing two scalar types. `value` is true when To has
// at least as many exponent digits (see exponent_digits) and at least as many
// significand digits (NumTraits<>::digits()) as From. The radix of the two
// types is not compared yet; see the TODO below.
template <typename From, typename To>
struct is_floating_exactly_representable {
  // TODO(rmlarsen): Add radix to NumTraits and enable this check.
  // (NumTraits<To>::radix == NumTraits<From>::radix) &&
  static constexpr bool value =
      !(exponent_digits<To>() < exponent_digits<From>()) &&
      !(NumTraits<To>::digits() < NumTraits<From>::digits());
};
// Specialization for real (non-complex), non-integer base and exponent types.
template <typename Scalar, typename ExponentScalar>
struct scalar_unary_pow_op<Scalar, ExponentScalar, false, false, false, false> {
template <bool IsExactlyRepresentable = is_floating_exactly_representable<ExponentScalar, Scalar>::value>
std::enable_if_t<IsExactlyRepresentable, void> check_is_representable() const {}
// Issue a deprecation warning if we do a narrowing conversion on the exponent.
template <bool IsExactlyRepresentable = is_floating_exactly_representable<ExponentScalar, Scalar>::value>
EIGEN_DEPRECATED std::enable_if_t<!IsExactlyRepresentable, void> check_is_representable() const {}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
scalar_unary_pow_op(const ExponentScalar& exponent) : m_exponent(static_cast<Scalar>(exponent)) {
check_is_representable();
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const Scalar& a) const {
EIGEN_USING_STD(pow);
return static_cast<Scalar>(pow(a, m_exponent));
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a) const {
return unary_pow_impl<Packet, ScalarExponent>::run(a, m_exponent);
return unary_pow_impl<Packet, Scalar>::run(a, m_exponent);
}
private:
const ScalarExponent m_exponent;
const Scalar m_exponent;
scalar_unary_pow_op() {}
};
template <typename Scalar, typename ScalarExponent, bool BaseIsInteger>
struct scalar_unary_pow_op<Scalar, ScalarExponent, BaseIsInteger, true, false, false> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) {
EIGEN_STATIC_ASSERT((is_arithmetic<ScalarExponent>::value), EXPONENT_MUST_BE_ARITHMETIC);
}
template <typename Scalar, typename ExponentScalar, bool BaseIsInteger>
struct scalar_unary_pow_op<Scalar, ExponentScalar, BaseIsInteger, true, false, false> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ExponentScalar& exponent) : m_exponent(exponent) {}
// TODO: error handling logic for complex^real_integer
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const Scalar& a) const {
return unary_pow_impl<Scalar, ScalarExponent>::run(a, m_exponent);
return unary_pow_impl<Scalar, ExponentScalar>::run(a, m_exponent);
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a) const {
return unary_pow_impl<Packet, ScalarExponent>::run(a, m_exponent);
return unary_pow_impl<Packet, ExponentScalar>::run(a, m_exponent);
}
private:
const ScalarExponent m_exponent;
const ExponentScalar m_exponent;
scalar_unary_pow_op() {}
};
template <typename Scalar, typename ScalarExponent>
struct functor_traits<scalar_unary_pow_op<Scalar, ScalarExponent>> {
template <typename Scalar, typename ExponentScalar>
struct functor_traits<scalar_unary_pow_op<Scalar, ExponentScalar>> {
enum {
GenPacketAccess = functor_traits<scalar_pow_op<Scalar, ScalarExponent>>::PacketAccess,
GenPacketAccess = functor_traits<scalar_pow_op<Scalar, ExponentScalar>>::PacketAccess,
IntPacketAccess = !NumTraits<Scalar>::IsComplex && packet_traits<Scalar>::HasMul && (packet_traits<Scalar>::HasDiv || NumTraits<Scalar>::IsInteger) && packet_traits<Scalar>::HasCmp,
PacketAccess = NumTraits<ScalarExponent>::IsInteger ? IntPacketAccess : (IntPacketAccess && GenPacketAccess),
Cost = functor_traits<scalar_pow_op<Scalar, ScalarExponent>>::Cost
PacketAccess = NumTraits<ExponentScalar>::IsInteger ? IntPacketAccess : (IntPacketAccess && GenPacketAccess),
Cost = functor_traits<scalar_pow_op<Scalar, ExponentScalar>>::Cost
};
};

View File

@ -138,7 +138,7 @@ Scalar calc_overflow_threshold(const ScalarExponent exponent) {
// base^e <= highest ==> base <= 2^(log2(highest)/e)
// For floating-point types, consider the bound for integer values that can be reproduced exactly = 2 ^ digits
double highest_bits = numext::mini(static_cast<double>(NumTraits<Scalar>::digits()),
log2(NumTraits<Scalar>::highest()));
static_cast<double>(log2(NumTraits<Scalar>::highest())));
return static_cast<Scalar>(
numext::floor(exp2(highest_bits / static_cast<double>(exponent))));
}
@ -146,49 +146,90 @@ Scalar calc_overflow_threshold(const ScalarExponent exponent) {
template <typename Base, typename Exponent>
void test_exponent(Exponent exponent) {
const Base max_abs_bases = static_cast<Base>(10000);
// avoid integer overflow in Base type
Base threshold = calc_overflow_threshold<Base, Exponent>(numext::abs(exponent));
// avoid numbers that can't be verified with std::pow
double double_threshold = calc_overflow_threshold<double, Exponent>(numext::abs(exponent));
// use the lesser of these two thresholds
Base testing_threshold =
static_cast<double>(threshold) < double_threshold ? threshold : static_cast<Base>(double_threshold);
// test both vectorized and non-vectorized code paths
const Index array_size = 2 * internal::packet_traits<Base>::size + 1;
Base max_base = numext::mini(testing_threshold, max_abs_bases);
Base min_base = NumTraits<Base>::IsSigned ? -max_base : Base(0);
ArrayX<Base> x(array_size), y(array_size);
bool all_pass = true;
for (Base base = min_base; base <= max_base; base++) {
if (exponent < 0 && base == 0) continue;
x.setConstant(base);
y = x.pow(exponent);
EIGEN_USING_STD(pow);
const Base max_abs_bases = 10000;
// avoid integer overflow in Base type
Base threshold = calc_overflow_threshold<Base, Exponent>(numext::abs(exponent));
// avoid numbers that can't be verified with std::pow
double double_threshold = calc_overflow_threshold<double, Exponent>(numext::abs(exponent));
// use the lesser of these two thresholds
Base testing_threshold = threshold < double_threshold ? threshold : static_cast<Base>(double_threshold);
// test both vectorized and non-vectorized code paths
const Index array_size = 2 * internal::packet_traits<Base>::size + 1;
Base max_base = numext::mini(testing_threshold, max_abs_bases);
Base min_base = NumTraits<Base>::IsSigned ? -max_base : 0;
ArrayX<Base> x(array_size), y(array_size);
bool all_pass = true;
for (Base base = min_base; base <= max_base; base++) {
if (exponent < 0 && base == 0) continue;
x.setConstant(base);
y = x.pow(exponent);
Base e = pow(base, exponent);
for (Base a : y) {
bool pass = a == e;
all_pass &= pass;
if (!pass) {
std::cout << "pow(" << base << "," << exponent << ") = " << a << " != " << e << std::endl;
}
}
Base e = pow(base, static_cast<Base>(exponent));
for (Base a : y) {
bool pass = (a == e);
if (!NumTraits<Base>::IsInteger) {
pass = pass || (((numext::isfinite)(e) && internal::isApprox(a, e)) ||
((numext::isnan)(a) && (numext::isnan)(e)));
}
all_pass &= pass;
if (!pass) {
std::cout << "pow(" << base << "," << exponent << ") = " << a << " != " << e << std::endl;
}
}
VERIFY(all_pass);
}
VERIFY(all_pass);
}
template <typename Base, typename Exponent>
void int_pow_test() {
Exponent max_exponent = NumTraits<Base>::digits();
Exponent min_exponent = NumTraits<Exponent>::IsSigned ? -max_exponent : 0;
for (Exponent exponent = min_exponent; exponent < max_exponent; exponent++) {
test_exponent<Base, Exponent>(exponent);
}
template <typename Base, typename Exponent>
void unary_pow_test() {
Exponent max_exponent = static_cast<Exponent>(NumTraits<Base>::digits());
Exponent min_exponent = static_cast<Exponent>(NumTraits<Exponent>::IsSigned ? -max_exponent : 0);
for (Exponent exponent = min_exponent; exponent < max_exponent; ++exponent) {
test_exponent<Base, Exponent>(exponent);
}
};
// Runs unary_pow_test for (base, exponent) pairs whose scalar types differ.
void mixed_pow_test() {
// The following cases will test promoting a smaller exponent type
// to a wider base type.
unary_pow_test<double, int>();
unary_pow_test<double, float>();
unary_pow_test<float, half>();
unary_pow_test<double, half>();
unary_pow_test<float, bfloat16>();
unary_pow_test<double, bfloat16>();
// Although in the following cases the exponent cannot be represented exactly
// in the base type, we do not perform a conversion, but implement
// the operation using repeated squaring.
unary_pow_test<float, int>();
unary_pow_test<double, long long>();
// The following case tests converting a wider floating-point exponent type
// to a narrower base type (a narrowing conversion, not a promotion). This
// should compile but generate a deprecation warning:
unary_pow_test<float, double>();
}
// Runs unary_pow_test for integer base types combined with integer exponent
// types, first with matching types and then with mixed width/signedness.
void int_pow_test() {
  // Base and exponent share the same integer type.
  unary_pow_test<int, int>();
  unary_pow_test<unsigned int, unsigned int>();
  unary_pow_test<long long, long long>();
  unary_pow_test<unsigned long long, unsigned long long>();
  // In the following cases the exponent type differs from the base type and
  // is not always exactly representable in it. We do not perform a
  // conversion, but implement the operation using repeated squaring.
  // Each combination is listed once; the original list repeated
  // <long long, int> at the end, which only re-ran an identical test.
  unary_pow_test<long long, int>();
  unary_pow_test<int, unsigned int>();
  unary_pow_test<unsigned int, int>();
  unary_pow_test<long long, unsigned long long>();
  unary_pow_test<unsigned long long, long long>();
}
template<typename ArrayType> void array(const ArrayType& m)
@ -207,7 +248,7 @@ template<typename ArrayType> void array(const ArrayType& m)
// Here we cap the size of the values in m1 such that pow(3)/cube()
// doesn't overflow and result in undefined behavior. Notice that because
// pow(int, int) promotes its inputs and output to double (according to
// the C++ standard), we hvae to make sure that the result fits in 53 bits
// the C++ standard), we have to make sure that the result fits in 53 bits
// for int64,
RealScalar max_val =
numext::mini(RealScalar(std::cbrt(NumTraits<RealScalar>::highest())),
@ -565,14 +606,6 @@ template<typename ArrayType> void array_real(const ArrayType& m)
VERIFY_IS_APPROX(m3.pow(RealScalar(-2)), m3.square().inverse());
pow_test<Scalar>();
typedef typename internal::make_integer<Scalar>::type SignedInt;
typedef typename std::make_unsigned<SignedInt>::type UnsignedInt;
int_pow_test<SignedInt, SignedInt>();
int_pow_test<SignedInt, UnsignedInt>();
int_pow_test<UnsignedInt, SignedInt>();
int_pow_test<UnsignedInt, UnsignedInt>();
VERIFY_IS_APPROX(log10(m3), log(m3)/numext::log(Scalar(10)));
VERIFY_IS_APPROX(log2(m3), log(m3)/numext::log(Scalar(2)));
@ -590,6 +623,7 @@ template<typename ArrayType> void array_real(const ArrayType& m)
VERIFY_IS_APPROX(m3, m1);
}
template<typename ArrayType> void array_complex(const ArrayType& m)
{
typedef typename ArrayType::Scalar Scalar;
@ -823,6 +857,11 @@ EIGEN_DECLARE_TEST(array_cwise)
CALL_SUBTEST_4( array_complex(ArrayXXcf(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
}
for(int i = 0; i < g_repeat; i++) {
CALL_SUBTEST_6( int_pow_test() );
CALL_SUBTEST_7( mixed_pow_test() );
}
VERIFY((internal::is_same< internal::global_math_functions_filtering_base<int>::type, int >::value));
VERIFY((internal::is_same< internal::global_math_functions_filtering_base<float>::type, float >::value));
VERIFY((internal::is_same< internal::global_math_functions_filtering_base<Array2i>::type, ArrayBase<Array2i> >::value));