mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-09-12 17:33:15 +08:00
Fix pfrexp/pldexp for half.
The recent addition of vectorized pow (!330) relies on `pfrexp` and `pldexp`. These were missing for `Eigen::half` and `Eigen::bfloat16`. Adding tests for these packet ops also exposed an issue with handling negative values in `pfrexp`, which returned an incorrect exponent. Added the missing implementations, corrected the exponent in `pfrexp_float` and `pfrexp_double`, and added `packetmath` tests.
This commit is contained in:
parent
25d8498f8b
commit
b2126fd6b5
@ -442,7 +442,7 @@ template <typename Packet>
|
|||||||
EIGEN_DEVICE_FUNC inline Packet pfrexp(const Packet& a, Packet& exponent) {
|
EIGEN_DEVICE_FUNC inline Packet pfrexp(const Packet& a, Packet& exponent) {
|
||||||
int exp;
|
int exp;
|
||||||
EIGEN_USING_STD(frexp);
|
EIGEN_USING_STD(frexp);
|
||||||
Packet result = frexp(a, &exp);
|
Packet result = static_cast<Packet>(frexp(a, &exp));
|
||||||
exponent = static_cast<Packet>(exp);
|
exponent = static_cast<Packet>(exp);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
@ -453,7 +453,7 @@ EIGEN_DEVICE_FUNC inline Packet pfrexp(const Packet& a, Packet& exponent) {
|
|||||||
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
|
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
|
||||||
pldexp(const Packet &a, const Packet &exponent) {
|
pldexp(const Packet &a, const Packet &exponent) {
|
||||||
EIGEN_USING_STD(ldexp)
|
EIGEN_USING_STD(ldexp)
|
||||||
return ldexp(a, static_cast<int>(exponent));
|
return static_cast<Packet>(ldexp(a, static_cast<int>(exponent)));
|
||||||
}
|
}
|
||||||
|
|
||||||
/** \internal \returns the min of \a a and \a b (coeff-wise) */
|
/** \internal \returns the min of \a a and \a b (coeff-wise) */
|
||||||
|
@ -184,6 +184,19 @@ F16_PACKET_FUNCTION(Packet8f, Packet8h, ptanh)
|
|||||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, psqrt)
|
F16_PACKET_FUNCTION(Packet8f, Packet8h, psqrt)
|
||||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, prsqrt)
|
F16_PACKET_FUNCTION(Packet8f, Packet8h, prsqrt)
|
||||||
|
|
||||||
|
template <>
|
||||||
|
EIGEN_STRONG_INLINE Packet8h pfrexp(const Packet8h& a, Packet8h& exponent) {
|
||||||
|
Packet8f fexponent;
|
||||||
|
const Packet8h out = float2half(pfrexp<Packet8f>(half2float(a), fexponent));
|
||||||
|
exponent = float2half(fexponent);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
EIGEN_STRONG_INLINE Packet8h pldexp(const Packet8h& a, const Packet8h& exponent) {
|
||||||
|
return float2half(pldexp<Packet8f>(half2float(a), half2float(exponent)));
|
||||||
|
}
|
||||||
|
|
||||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psin)
|
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psin)
|
||||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pcos)
|
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pcos)
|
||||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog)
|
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog)
|
||||||
@ -195,6 +208,19 @@ BF16_PACKET_FUNCTION(Packet8f, Packet8bf, ptanh)
|
|||||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psqrt)
|
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psqrt)
|
||||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, prsqrt)
|
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, prsqrt)
|
||||||
|
|
||||||
|
template <>
|
||||||
|
EIGEN_STRONG_INLINE Packet8bf pfrexp(const Packet8bf& a, Packet8bf& exponent) {
|
||||||
|
Packet8f fexponent;
|
||||||
|
const Packet8bf out = F32ToBf16(pfrexp<Packet8f>(Bf16ToF32(a), fexponent));
|
||||||
|
exponent = F32ToBf16(fexponent);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
EIGEN_STRONG_INLINE Packet8bf pldexp(const Packet8bf& a, const Packet8bf& exponent) {
|
||||||
|
return F32ToBf16(pldexp<Packet8f>(Bf16ToF32(a), Bf16ToF32(exponent)));
|
||||||
|
}
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
@ -191,6 +191,32 @@ pexp<Packet8d>(const Packet8d& _x) {
|
|||||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, pexp)
|
F16_PACKET_FUNCTION(Packet16f, Packet16h, pexp)
|
||||||
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexp)
|
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexp)
|
||||||
|
|
||||||
|
template <>
|
||||||
|
EIGEN_STRONG_INLINE Packet16h pfrexp(const Packet16h& a, Packet16h& exponent) {
|
||||||
|
Packet16f fexponent;
|
||||||
|
const Packet16h out = float2half(pfrexp<Packet16f>(half2float(a), fexponent));
|
||||||
|
exponent = float2half(fexponent);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
EIGEN_STRONG_INLINE Packet16h pldexp(const Packet16h& a, const Packet16h& exponent) {
|
||||||
|
return float2half(pldexp<Packet16f>(half2float(a), half2float(exponent)));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
EIGEN_STRONG_INLINE Packet16bf pfrexp(const Packet16bf& a, Packet16bf& exponent) {
|
||||||
|
Packet16f fexponent;
|
||||||
|
const Packet16bf out = F32ToBf16(pfrexp<Packet16f>(Bf16ToF32(a), fexponent));
|
||||||
|
exponent = F32ToBf16(fexponent);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
EIGEN_STRONG_INLINE Packet16bf pldexp(const Packet16bf& a, const Packet16bf& exponent) {
|
||||||
|
return F32ToBf16(pldexp<Packet16f>(Bf16ToF32(a), Bf16ToF32(exponent)));
|
||||||
|
}
|
||||||
|
|
||||||
// Functions for sqrt.
|
// Functions for sqrt.
|
||||||
// The EIGEN_FAST_MATH version uses the _mm_rsqrt_ps approximation and one step
|
// The EIGEN_FAST_MATH version uses the _mm_rsqrt_ps approximation and one step
|
||||||
// of Newton's method, at a cost of 1-2 bits of precision as opposed to the
|
// of Newton's method, at a cost of 1-2 bits of precision as opposed to the
|
||||||
|
@ -31,7 +31,7 @@ pfrexp_float(const Packet& a, Packet& exponent) {
|
|||||||
const Packet cst_126f = pset1<Packet>(126.0f);
|
const Packet cst_126f = pset1<Packet>(126.0f);
|
||||||
const Packet cst_half = pset1<Packet>(0.5f);
|
const Packet cst_half = pset1<Packet>(0.5f);
|
||||||
const Packet cst_inv_mant_mask = pset1frombits<Packet>(~0x7f800000u);
|
const Packet cst_inv_mant_mask = pset1frombits<Packet>(~0x7f800000u);
|
||||||
exponent = psub(pcast<PacketI,Packet>(plogical_shift_right<23>(preinterpret<PacketI>(a))), cst_126f);
|
exponent = psub(pcast<PacketI,Packet>(plogical_shift_right<23>(preinterpret<PacketI>(pabs<Packet>(a)))), cst_126f);
|
||||||
return por(pand(a, cst_inv_mant_mask), cst_half);
|
return por(pand(a, cst_inv_mant_mask), cst_half);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -41,7 +41,7 @@ pfrexp_double(const Packet& a, Packet& exponent) {
|
|||||||
const Packet cst_1022d = pset1<Packet>(1022.0);
|
const Packet cst_1022d = pset1<Packet>(1022.0);
|
||||||
const Packet cst_half = pset1<Packet>(0.5);
|
const Packet cst_half = pset1<Packet>(0.5);
|
||||||
const Packet cst_inv_mant_mask = pset1frombits<Packet>(static_cast<uint64_t>(~0x7ff0000000000000ull));
|
const Packet cst_inv_mant_mask = pset1frombits<Packet>(static_cast<uint64_t>(~0x7ff0000000000000ull));
|
||||||
exponent = psub(pcast<PacketI,Packet>(plogical_shift_right<52>(preinterpret<PacketI>(a))), cst_1022d);
|
exponent = psub(pcast<PacketI,Packet>(plogical_shift_right<52>(preinterpret<PacketI>(pabs<Packet>(a)))), cst_1022d);
|
||||||
return por(pand(a, cst_inv_mant_mask), cst_half);
|
return por(pand(a, cst_inv_mant_mask), cst_half);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -44,6 +44,18 @@ BF16_PACKET_FUNCTION(Packet4f, Packet4bf, plog)
|
|||||||
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pexp)
|
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pexp)
|
||||||
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, ptanh)
|
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, ptanh)
|
||||||
|
|
||||||
|
template <>
|
||||||
|
EIGEN_STRONG_INLINE Packet4bf pfrexp(const Packet4bf& a, Packet4bf& exponent) {
|
||||||
|
Packet4f fexponent;
|
||||||
|
const Packet4bf out = F32ToBf16(pfrexp<Packet4f>(Bf16ToF32(a), fexponent));
|
||||||
|
exponent = F32ToBf16(fexponent);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
EIGEN_STRONG_INLINE Packet4bf pldexp(const Packet4bf& a, const Packet4bf& exponent) {
|
||||||
|
return F32ToBf16(pldexp<Packet4f>(Bf16ToF32(a), Bf16ToF32(exponent)));
|
||||||
|
}
|
||||||
|
|
||||||
//---------- double ----------
|
//---------- double ----------
|
||||||
|
|
||||||
|
@ -46,6 +46,21 @@ inline bool REF_MUL(const bool& a, const bool& b) {
|
|||||||
return a && b;
|
return a && b;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline T REF_FREXP(const T& x, T& exp) {
|
||||||
|
int iexp;
|
||||||
|
EIGEN_USING_STD(frexp)
|
||||||
|
const T out = static_cast<T>(frexp(x, &iexp));
|
||||||
|
exp = static_cast<T>(iexp);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline T REF_LDEXP(const T& x, const T& exp) {
|
||||||
|
EIGEN_USING_STD(ldexp)
|
||||||
|
return static_cast<T>(ldexp(x, static_cast<int>(exp)));
|
||||||
|
}
|
||||||
|
|
||||||
// Uses pcast to cast from one array to another.
|
// Uses pcast to cast from one array to another.
|
||||||
template <typename SrcPacket, typename TgtPacket, int SrcCoeffRatio, int TgtCoeffRatio>
|
template <typename SrcPacket, typename TgtPacket, int SrcCoeffRatio, int TgtCoeffRatio>
|
||||||
struct pcast_array;
|
struct pcast_array;
|
||||||
@ -552,6 +567,17 @@ void packetmath_real() {
|
|||||||
data2[i] = Scalar(internal::random<double>(-87, 88));
|
data2[i] = Scalar(internal::random<double>(-87, 88));
|
||||||
}
|
}
|
||||||
CHECK_CWISE1_IF(PacketTraits::HasExp, std::exp, internal::pexp);
|
CHECK_CWISE1_IF(PacketTraits::HasExp, std::exp, internal::pexp);
|
||||||
|
CHECK_CWISE1_BYREF1_IF(PacketTraits::HasExp, REF_FREXP, internal::pfrexp);
|
||||||
|
for (int i = 0; i < PacketSize; ++i) {
|
||||||
|
data1[i] = Scalar(internal::random<double>(-1, 1));
|
||||||
|
data2[i] = Scalar(internal::random<double>(-1, 1));
|
||||||
|
}
|
||||||
|
for (int i = 0; i < PacketSize; ++i) {
|
||||||
|
data1[i+PacketSize] = Scalar(internal::random<int>(0, 4));
|
||||||
|
data2[i+PacketSize] = Scalar(internal::random<double>(0, 4));
|
||||||
|
}
|
||||||
|
CHECK_CWISE2_IF(PacketTraits::HasExp, REF_LDEXP, internal::pldexp);
|
||||||
|
|
||||||
for (int i = 0; i < size; ++i) {
|
for (int i = 0; i < size; ++i) {
|
||||||
data1[i] = Scalar(internal::random<double>(-1, 1) * std::pow(10., internal::random<double>(-6, 6)));
|
data1[i] = Scalar(internal::random<double>(-1, 1) * std::pow(10., internal::random<double>(-6, 6)));
|
||||||
data2[i] = Scalar(internal::random<double>(-1, 1) * std::pow(10., internal::random<double>(-6, 6)));
|
data2[i] = Scalar(internal::random<double>(-1, 1) * std::pow(10., internal::random<double>(-6, 6)));
|
||||||
|
@ -143,6 +143,9 @@ struct packet_helper
|
|||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
inline void store(T* to, const Packet& x, unsigned long long umask) const { internal::pstoreu(to, x, umask); }
|
inline void store(T* to, const Packet& x, unsigned long long umask) const { internal::pstoreu(to, x, umask); }
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
inline Packet& forward_reference(Packet& packet, T& /*scalar*/) const { return packet; }
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename Packet>
|
template<typename Packet>
|
||||||
@ -162,6 +165,9 @@ struct packet_helper<false,Packet>
|
|||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
inline void store(T* to, const T& x, unsigned long long) const { *to = x; }
|
inline void store(T* to, const T& x, unsigned long long) const { *to = x; }
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
inline T& forward_reference(Packet& /*packet*/, T& scalar) const { return scalar; }
|
||||||
};
|
};
|
||||||
|
|
||||||
#define CHECK_CWISE1_IF(COND, REFOP, POP) if(COND) { \
|
#define CHECK_CWISE1_IF(COND, REFOP, POP) if(COND) { \
|
||||||
@ -180,6 +186,18 @@ struct packet_helper<false,Packet>
|
|||||||
VERIFY(test::areApprox(ref, data2, PacketSize) && #POP); \
|
VERIFY(test::areApprox(ref, data2, PacketSize) && #POP); \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// One input, one output by reference.
|
||||||
|
#define CHECK_CWISE1_BYREF1_IF(COND, REFOP, POP) if(COND) { \
|
||||||
|
test::packet_helper<COND,Packet> h; \
|
||||||
|
for (int i=0; i<PacketSize; ++i) \
|
||||||
|
ref[i] = Scalar(REFOP(data1[i], ref[i+PacketSize])); \
|
||||||
|
Packet pout; \
|
||||||
|
Scalar sout; \
|
||||||
|
h.store(data2, POP(h.load(data1), h.forward_reference(pout, sout))); \
|
||||||
|
h.store(data2+PacketSize, h.forward_reference(pout, sout)); \
|
||||||
|
VERIFY(test::areApprox(ref, data2, 2 * PacketSize) && #POP); \
|
||||||
|
}
|
||||||
|
|
||||||
#define CHECK_CWISE3_IF(COND, REFOP, POP) if (COND) { \
|
#define CHECK_CWISE3_IF(COND, REFOP, POP) if (COND) { \
|
||||||
test::packet_helper<COND, Packet> h; \
|
test::packet_helper<COND, Packet> h; \
|
||||||
for (int i = 0; i < PacketSize; ++i) \
|
for (int i = 0; i < PacketSize; ++i) \
|
||||||
|
Loading…
x
Reference in New Issue
Block a user