mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-09-23 14:53:13 +08:00
Vectorize pow for integer base / exponent types
This commit is contained in:
parent
8acbf5c11c
commit
e5af9f87f2
@ -208,6 +208,7 @@ template<> struct packet_traits<int> : default_packet_traits
|
||||
enum {
|
||||
Vectorizable = 1,
|
||||
AlignedOnScalar = 1,
|
||||
HasCmp = 1,
|
||||
size=8
|
||||
};
|
||||
};
|
||||
@ -222,6 +223,7 @@ template<> struct packet_traits<int64_t> : default_packet_traits
|
||||
enum {
|
||||
Vectorizable = 1,
|
||||
AlignedOnScalar = 1,
|
||||
HasCmp = 1,
|
||||
size=4,
|
||||
|
||||
// requires AVX512
|
||||
|
@ -174,6 +174,7 @@ template<> struct packet_traits<int> : default_packet_traits
|
||||
enum {
|
||||
Vectorizable = 1,
|
||||
AlignedOnScalar = 1,
|
||||
HasCmp = 1,
|
||||
size=16
|
||||
};
|
||||
};
|
||||
|
@ -253,7 +253,8 @@ struct packet_traits<int> : default_packet_traits {
|
||||
HasShift = 1,
|
||||
HasMul = 1,
|
||||
HasDiv = 0,
|
||||
HasBlend = 1
|
||||
HasBlend = 1,
|
||||
HasCmp = 1
|
||||
};
|
||||
};
|
||||
|
||||
@ -271,7 +272,8 @@ struct packet_traits<short int> : default_packet_traits {
|
||||
HasSub = 1,
|
||||
HasMul = 1,
|
||||
HasDiv = 0,
|
||||
HasBlend = 1
|
||||
HasBlend = 1,
|
||||
HasCmp = 1
|
||||
};
|
||||
};
|
||||
|
||||
@ -289,7 +291,8 @@ struct packet_traits<unsigned short int> : default_packet_traits {
|
||||
HasSub = 1,
|
||||
HasMul = 1,
|
||||
HasDiv = 0,
|
||||
HasBlend = 1
|
||||
HasBlend = 1,
|
||||
HasCmp = 1
|
||||
};
|
||||
};
|
||||
|
||||
@ -307,7 +310,8 @@ struct packet_traits<signed char> : default_packet_traits {
|
||||
HasSub = 1,
|
||||
HasMul = 1,
|
||||
HasDiv = 0,
|
||||
HasBlend = 1
|
||||
HasBlend = 1,
|
||||
HasCmp = 1
|
||||
};
|
||||
};
|
||||
|
||||
@ -325,7 +329,8 @@ struct packet_traits<unsigned char> : default_packet_traits {
|
||||
HasSub = 1,
|
||||
HasMul = 1,
|
||||
HasDiv = 0,
|
||||
HasBlend = 1
|
||||
HasBlend = 1,
|
||||
HasCmp = 1
|
||||
};
|
||||
};
|
||||
|
||||
|
@ -1880,6 +1880,36 @@ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_nonint_errors(
|
||||
result = pselect(pandnot(abs_x_is_one, x_is_neg), cst_pos_one, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename Packet, typename ScalarExponent>
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_int_int(const Packet& x, const ScalarExponent& exponent) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
|
||||
// integer base, integer exponent case
|
||||
|
||||
// This routine handles negative and very large positive exponents
|
||||
// Signed integer overflow and divide by zero is undefined behavior
|
||||
// Unsigned intgers do not overflow
|
||||
|
||||
const bool exponent_is_odd = unary_pow::is_odd<ScalarExponent>::run(exponent);
|
||||
|
||||
const Scalar zero = Scalar(0);
|
||||
const Scalar pos_one = Scalar(1);
|
||||
|
||||
const Packet cst_zero = pset1<Packet>(zero);
|
||||
const Packet cst_pos_one = pset1<Packet>(pos_one);
|
||||
|
||||
const Packet abs_x = pabs(x);
|
||||
|
||||
const Packet pow_is_zero = exponent < 0 ? pcmp_lt(cst_pos_one, abs_x) : pzero(x);
|
||||
const Packet pow_is_one = pcmp_eq(cst_pos_one, abs_x);
|
||||
const Packet pow_is_neg = exponent_is_odd ? pcmp_lt(x, cst_zero) : pzero(x);
|
||||
|
||||
Packet result = pselect(pow_is_zero, cst_zero, x);
|
||||
result = pselect(pandnot(pow_is_one, pow_is_neg), cst_pos_one, result);
|
||||
result = pselect(pand(pow_is_one, pow_is_neg), pnegate(cst_pos_one), result);
|
||||
return result;
|
||||
}
|
||||
} // end namespace unary_pow
|
||||
|
||||
template <typename Packet, typename ScalarExponent,
|
||||
@ -1914,6 +1944,19 @@ struct unary_pow_impl<Packet, ScalarExponent, false, true> {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Packet, typename ScalarExponent>
|
||||
struct unary_pow_impl<Packet, ScalarExponent, true, true> {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
|
||||
if (exponent < 0 || exponent > NumTraits<Scalar>::digits()) {
|
||||
return unary_pow::handle_int_int(x, exponent);
|
||||
}
|
||||
else {
|
||||
return unary_pow::int_pow(x, exponent);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace internal
|
||||
} // end namespace Eigen
|
||||
|
||||
|
@ -191,6 +191,7 @@ template<> struct packet_traits<int> : default_packet_traits
|
||||
enum {
|
||||
Vectorizable = 1,
|
||||
AlignedOnScalar = 1,
|
||||
HasCmp = 1,
|
||||
size=4,
|
||||
|
||||
HasShift = 1,
|
||||
|
@ -1109,8 +1109,8 @@ struct scalar_unary_pow_op<Scalar, ScalarExponent, false, false, false, false> {
|
||||
scalar_unary_pow_op() {}
|
||||
};
|
||||
|
||||
template <typename Scalar, typename ScalarExponent>
|
||||
struct scalar_unary_pow_op<Scalar, ScalarExponent, false, true, false, false> {
|
||||
template <typename Scalar, typename ScalarExponent, bool BaseIsInteger>
|
||||
struct scalar_unary_pow_op<Scalar, ScalarExponent, BaseIsInteger, true, false, false> {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) {
|
||||
EIGEN_STATIC_ASSERT((is_arithmetic<ScalarExponent>::value), EXPONENT_MUST_BE_ARITHMETIC);
|
||||
}
|
||||
@ -1132,7 +1132,7 @@ template <typename Scalar, typename ScalarExponent>
|
||||
struct functor_traits<scalar_unary_pow_op<Scalar, ScalarExponent>> {
|
||||
enum {
|
||||
GenPacketAccess = functor_traits<scalar_pow_op<Scalar, ScalarExponent>>::PacketAccess,
|
||||
IntPacketAccess = !NumTraits<Scalar>::IsComplex && !NumTraits<Scalar>::IsInteger && packet_traits<Scalar>::HasMul && packet_traits<Scalar>::HasDiv && packet_traits<Scalar>::HasCmp,
|
||||
IntPacketAccess = !NumTraits<Scalar>::IsComplex && packet_traits<Scalar>::HasMul && (packet_traits<Scalar>::HasDiv || NumTraits<Scalar>::IsInteger) && packet_traits<Scalar>::HasCmp,
|
||||
PacketAccess = NumTraits<ScalarExponent>::IsInteger ? IntPacketAccess : (IntPacketAccess && GenPacketAccess),
|
||||
Cost = functor_traits<scalar_pow_op<Scalar, ScalarExponent>>::Cost
|
||||
};
|
||||
|
@ -126,6 +126,72 @@ void pow_test() {
|
||||
VERIFY(all_pass);
|
||||
}
|
||||
|
||||
template <typename Scalar, typename ScalarExponent>
|
||||
Scalar calc_overflow_threshold(const ScalarExponent exponent) {
|
||||
EIGEN_USING_STD(exp2);
|
||||
EIGEN_STATIC_ASSERT((NumTraits<Scalar>::digits() < 2 * NumTraits<double>::digits()), BASE_TYPE_IS_TOO_BIG);
|
||||
|
||||
if (exponent < 2)
|
||||
return NumTraits<Scalar>::highest();
|
||||
else {
|
||||
const double max_exponent = static_cast<double>(NumTraits<Scalar>::digits());
|
||||
const double clamped_exponent = exponent < max_exponent ? static_cast<double>(exponent) : max_exponent;
|
||||
const double threshold = exp2(max_exponent / clamped_exponent);
|
||||
return static_cast<Scalar>(threshold);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Base, typename Exponent>
|
||||
void test_exponent(Exponent exponent) {
|
||||
EIGEN_USING_STD(pow);
|
||||
|
||||
const Base max_abs_bases = 10000;
|
||||
// avoid integer overflow in Base type
|
||||
Base threshold = calc_overflow_threshold<Base, Exponent>(numext::abs(exponent));
|
||||
// avoid numbers that can't be verified with std::pow
|
||||
double double_threshold = calc_overflow_threshold<double, Exponent>(numext::abs(exponent));
|
||||
// use the lesser of these two thresholds
|
||||
Base testing_threshold = threshold < double_threshold ? threshold : static_cast<Base>(double_threshold);
|
||||
// avoid divide by zero
|
||||
Base min_abs_base = exponent < 0 ? 1 : 0;
|
||||
// avoid excessively long test
|
||||
Base max_abs_base = numext::mini(testing_threshold, max_abs_bases);
|
||||
// test both vectorized and non-vectorized code paths
|
||||
const Index array_size = 2 * internal::packet_traits<Base>::size + 1;
|
||||
|
||||
Base max_base = numext::mini(testing_threshold, max_abs_bases);
|
||||
Base min_base = NumTraits<Base>::IsSigned ? -max_base : 0;
|
||||
|
||||
ArrayX<Base> x(array_size), y(array_size);
|
||||
|
||||
bool all_pass = true;
|
||||
|
||||
for (Base base = min_base; base <= max_base; base++) {
|
||||
if (exponent < 0 && base == 0) continue;
|
||||
x.setConstant(base);
|
||||
y = x.pow(exponent);
|
||||
Base e = pow(base, exponent);
|
||||
for (Base a : y) {
|
||||
bool pass = a == e;
|
||||
all_pass &= pass;
|
||||
if (!pass) {
|
||||
std::cout << "pow(" << base << "," << exponent << ") = " << a << " != " << e << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
VERIFY(all_pass);
|
||||
}
|
||||
template <typename Base, typename Exponent>
|
||||
void int_pow_test() {
|
||||
Exponent max_exponent = NumTraits<Base>::digits();
|
||||
Exponent min_exponent = NumTraits<Exponent>::IsSigned ? -max_exponent : 0;
|
||||
|
||||
for (Exponent exponent = min_exponent; exponent < max_exponent; exponent++) {
|
||||
test_exponent<Base, Exponent>(exponent);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename ArrayType> void array(const ArrayType& m)
|
||||
{
|
||||
typedef typename ArrayType::Scalar Scalar;
|
||||
@ -500,6 +566,14 @@ template<typename ArrayType> void array_real(const ArrayType& m)
|
||||
VERIFY_IS_APPROX(m3.pow(RealScalar(-2)), m3.square().inverse());
|
||||
pow_test<Scalar>();
|
||||
|
||||
typedef typename internal::make_integer<Scalar>::type SignedInt;
|
||||
typedef typename std::make_unsigned<SignedInt>::type UnsignedInt;
|
||||
|
||||
int_pow_test<SignedInt, SignedInt>();
|
||||
int_pow_test<SignedInt, UnsignedInt>();
|
||||
int_pow_test<UnsignedInt, SignedInt>();
|
||||
int_pow_test<UnsignedInt, UnsignedInt>();
|
||||
|
||||
VERIFY_IS_APPROX(log10(m3), log(m3)/numext::log(Scalar(10)));
|
||||
VERIFY_IS_APPROX(log2(m3), log(m3)/numext::log(Scalar(2)));
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user