Fix new generic nearest integer ops on GPU.

This commit is contained in:
Rasmus Munk Larsen 2024-04-30 22:18:25 +00:00
parent 0ee5c90aa9
commit 9000b37677
4 changed files with 41 additions and 13 deletions

View File

@ -2470,7 +2470,7 @@ struct unary_pow_impl<Packet, ScalarExponent, true, true, false> {
};
template <typename Packet>
EIGEN_STRONG_INLINE Packet generic_rint(const Packet& a) {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_rint(const Packet& a) {
using Scalar = typename unpacket_traits<Packet>::type;
using IntType = typename numext::get_integer_by_size<sizeof(Scalar)>::signed_type;
// Adds and subtracts signum(a) * 2^kMantissaBits to force rounding.
@ -2490,7 +2490,7 @@ EIGEN_STRONG_INLINE Packet generic_rint(const Packet& a) {
}
template <typename Packet>
EIGEN_STRONG_INLINE Packet generic_floor(const Packet& a) {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_floor(const Packet& a) {
using Scalar = typename unpacket_traits<Packet>::type;
const Packet cst_1 = pset1<Packet>(Scalar(1));
Packet rint_a = generic_rint(a);
@ -2502,7 +2502,7 @@ EIGEN_STRONG_INLINE Packet generic_floor(const Packet& a) {
}
template <typename Packet>
EIGEN_STRONG_INLINE Packet generic_ceil(const Packet& a) {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_ceil(const Packet& a) {
using Scalar = typename unpacket_traits<Packet>::type;
const Packet cst_1 = pset1<Packet>(Scalar(1));
Packet rint_a = generic_rint(a);
@ -2514,7 +2514,7 @@ EIGEN_STRONG_INLINE Packet generic_ceil(const Packet& a) {
}
template <typename Packet>
EIGEN_STRONG_INLINE Packet generic_trunc(const Packet& a) {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_trunc(const Packet& a) {
Packet abs_a = pabs(a);
Packet sign_a = pandnot(a, abs_a);
Packet floor_abs_a = generic_floor(abs_a);
@ -2523,7 +2523,7 @@ EIGEN_STRONG_INLINE Packet generic_trunc(const Packet& a) {
}
template <typename Packet>
EIGEN_STRONG_INLINE Packet generic_round(const Packet& a) {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_round(const Packet& a) {
using Scalar = typename unpacket_traits<Packet>::type;
const Packet cst_half = pset1<Packet>(Scalar(0.5));
const Packet cst_1 = pset1<Packet>(Scalar(1));

View File

@ -134,19 +134,19 @@ template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pexp_complex(const Packet& x);
// Forward declarations of the generic nearest-integer packet helpers
// (rint, floor, ceil, trunc, round). Each takes a packet by const reference
// and returns a packet of the same type. Every helper is listed twice below:
// once without and once with the EIGEN_DEVICE_FUNC qualifier.
template <typename Packet>
EIGEN_STRONG_INLINE Packet generic_rint(const Packet& a);
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_rint(const Packet& a);
template <typename Packet>
EIGEN_STRONG_INLINE Packet generic_floor(const Packet& a);
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_floor(const Packet& a);
template <typename Packet>
EIGEN_STRONG_INLINE Packet generic_ceil(const Packet& a);
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_ceil(const Packet& a);
template <typename Packet>
EIGEN_STRONG_INLINE Packet generic_trunc(const Packet& a);
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_trunc(const Packet& a);
template <typename Packet>
EIGEN_STRONG_INLINE Packet generic_round(const Packet& a);
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_round(const Packet& a);
// Macros for instantiating these generic functions for different backends.
#define EIGEN_PACKET_FUNCTION(METHOD, SCALAR, PACKET) \

View File

@ -74,7 +74,6 @@ struct packet_traits<float> : default_packet_traits {
HasGammaSampleDerAlpha = 1,
HasIGammac = 1,
HasBetaInc = 1,
HasBlend = 0
};
};
@ -106,9 +105,7 @@ struct packet_traits<double> : default_packet_traits {
HasGammaSampleDerAlpha = 1,
HasIGammac = 1,
HasBetaInc = 1,
HasBlend = 0,
HasFloor = 1,
};
};
@ -518,6 +515,33 @@ EIGEN_DEVICE_FUNC inline double2 pfloor<double2>(const double2& a) {
return make_double2(floor(a.x), floor(a.y));
}
// Packet ceil for float4: applies ceilf to each of the four lanes (x, y, z, w)
// and repacks the results with make_float4.
template <>
EIGEN_DEVICE_FUNC inline float4 pceil<float4>(const float4& a) {
  const float lane_x = ceilf(a.x);
  const float lane_y = ceilf(a.y);
  const float lane_z = ceilf(a.z);
  const float lane_w = ceilf(a.w);
  return make_float4(lane_x, lane_y, lane_z, lane_w);
}
// Packet ceil for double2: applies ceil to both lanes (x, y) and repacks the
// results with make_double2.
template <>
EIGEN_DEVICE_FUNC inline double2 pceil<double2>(const double2& a) {
  const double lane_x = ceil(a.x);
  const double lane_y = ceil(a.y);
  return make_double2(lane_x, lane_y);
}
// Packet rint for float4: applies rintf to each of the four lanes (x, y, z, w)
// and repacks the results with make_float4.
template <>
EIGEN_DEVICE_FUNC inline float4 print<float4>(const float4& a) {
  const float lane_x = rintf(a.x);
  const float lane_y = rintf(a.y);
  const float lane_z = rintf(a.z);
  const float lane_w = rintf(a.w);
  return make_float4(lane_x, lane_y, lane_z, lane_w);
}
// Packet rint for double2: applies rint to both lanes (x, y) and repacks the
// results with make_double2.
template <>
EIGEN_DEVICE_FUNC inline double2 print<double2>(const double2& a) {
  const double lane_x = rint(a.x);
  const double lane_y = rint(a.y);
  return make_double2(lane_x, lane_y);
}
// Packet trunc for float4: applies truncf to each of the four lanes
// (x, y, z, w) and repacks the results with make_float4.
template <>
EIGEN_DEVICE_FUNC inline float4 ptrunc<float4>(const float4& a) {
  const float lane_x = truncf(a.x);
  const float lane_y = truncf(a.y);
  const float lane_z = truncf(a.z);
  const float lane_w = truncf(a.w);
  return make_float4(lane_x, lane_y, lane_z, lane_w);
}
// Packet trunc for double2: applies trunc to both lanes (x, y) and repacks the
// results with make_double2.
template <>
EIGEN_DEVICE_FUNC inline double2 ptrunc<double2>(const double2& a) {
  const double lane_x = trunc(a.x);
  const double lane_y = trunc(a.y);
  return make_double2(lane_x, lane_y);
}
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<float4, 4>& kernel) {
float tmp = kernel.packet[0].y;
kernel.packet[0].y = kernel.packet[1].x;

View File

@ -964,6 +964,10 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE constexpr void ignore_unused_variable(cons
// added then subtracted, which is otherwise compiled away with -ffast-math.
//
// See bug 1674
// When EIGEN_GPU_COMPILE_PHASE is defined, define the barrier as a macro that
// expands to nothing. Because the macro is then already defined, the
// `#if !defined(EIGEN_OPTIMIZATION_BARRIER)` fallback that follows is skipped.
// NOTE(review): presumably the fallback implementations are not usable in
// device code — confirm against the rest of this header.
#if defined(EIGEN_GPU_COMPILE_PHASE)
#define EIGEN_OPTIMIZATION_BARRIER(X)
#endif
#if !defined(EIGEN_OPTIMIZATION_BARRIER)
#if EIGEN_COMP_GNUC
// According to https://gcc.gnu.org/onlinedocs/gcc/Constraints.html: