Fix packetmath compilation error.

This commit is contained in:
Antonio Sánchez 2022-02-23 23:27:08 +00:00
parent 8970719771
commit 3d7e2d0e3e

View File

@ -657,10 +657,23 @@ Scalar log2(Scalar x) {
return Scalar(EIGEN_LOG2E) * std::log(x);
}
// Create a functor out of a function so it can be passed (with overloads)
// to another function as an input argument.
#define CREATE_FUNCTOR(Name, Func) \
struct Name { \
template<typename T> \
T operator()(const T& val) const { \
return Func(val); \
} \
}
CREATE_FUNCTOR(psqrt_functor, internal::psqrt);
CREATE_FUNCTOR(prsqrt_functor, internal::prsqrt);
// TODO(rmlarsen): Run this test for more functions.
template <bool Cond, typename Scalar, typename Packet, typename REF_FUNCTOR_T, typename FUNCTOR_T>
void packetmath_test_IEEE_corner_cases(const REF_FUNCTOR_T& ref_fun,
const FUNCTOR_T& fun) {
template <bool Cond, typename Scalar, typename Packet, typename RefFunctorT, typename FunctorT>
void packetmath_test_IEEE_corner_cases(const RefFunctorT& ref_fun,
const FunctorT& fun) {
const int PacketSize = internal::unpacket_traits<Packet>::size;
const Scalar norm_min = (std::numeric_limits<Scalar>::min)();
@ -1013,8 +1026,8 @@ void packetmath_real() {
VERIFY((numext::isnan)(data2[1]));
}
packetmath_test_IEEE_corner_cases<PacketTraits::HasSqrt, Scalar, Packet>(numext::sqrt<Scalar>, internal::psqrt<Packet>);
packetmath_test_IEEE_corner_cases<PacketTraits::HasRsqrt, Scalar, Packet>(numext::rsqrt<Scalar>, internal::prsqrt<Packet>);
packetmath_test_IEEE_corner_cases<PacketTraits::HasSqrt, Scalar, Packet>(numext::sqrt<Scalar>, psqrt_functor());
packetmath_test_IEEE_corner_cases<PacketTraits::HasRsqrt, Scalar, Packet>(numext::rsqrt<Scalar>, prsqrt_functor());
// TODO(rmlarsen): Re-enable for half and bfloat16.
if (PacketTraits::HasCos