mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-11 11:19:02 +08:00
Fix floor/ceil for NEON fp16.
Forgot to test this. Fixes bug introduced in !416.
This commit is contained in:
parent
5529db7524
commit
e19829c3b0
@ -4182,62 +4182,6 @@ EIGEN_STRONG_INLINE Packet4hf pcmp_lt_or_nan<Packet4hf>(const Packet4hf& a, cons
|
||||
return vreinterpret_f16_u16(vmvn_u16(vcge_f16(a, b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8hf pfloor<Packet8hf>(const Packet8hf& a) {
|
||||
const Packet8hf cst_1 = pset1<Packet8hf>(Eigen::half(1.0f));
|
||||
// Round to nearest.
|
||||
Packet8hf tmp = vcvtq_f16_s16(vcvtq_s16_f16(a));
|
||||
// If greater, substract one.
|
||||
uint16x8_t mask = vcgtq_f16(tmp, a);
|
||||
mask = vandq_u16(mask, vreinterpretq_u16_f16(cst_1));
|
||||
tmp = vsubq_f16(tmp, vreinterpretq_f16_u16(mask));
|
||||
// Handle saturation cases.
|
||||
EIGEN_CONSTEXPR Packet8hf cst_max = pset1<Packet8hf>(static_cast<Eigen::half>(NumTraits<int16_t>::highest()));
|
||||
return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4hf pfloor<Packet4hf>(const Packet4hf& a) {
|
||||
const Packet4hf cst_1 = pset1<Packet4hf>(Eigen::half(1.0f));
|
||||
// Round to nearest.
|
||||
Packet4hf tmp = vcvt_f16_s16(vcvt_s16_f16(a));
|
||||
// If greater, substract one.
|
||||
uint16x4_t mask = vcgt_f16(tmp, a);
|
||||
mask = vand_u16(mask, vreinterpret_u16_f16(cst_1));
|
||||
tmp = vsub_f16(tmp, vreinterpret_f16_u16(mask));
|
||||
// Handle saturation cases.
|
||||
EIGEN_CONSTEXPR Packet4hf cst_max = pset1<Packet4hf>(static_cast<Eigen::half>(NumTraits<int16_t>::highest()));
|
||||
return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8hf pceil<Packet8hf>(const Packet8hf& a) {
|
||||
const Packet8hf cst_1 = pset1<Packet8hf>(Eigen::half(1.0f));
|
||||
// Round to nearest.
|
||||
Packet8hf tmp = vcvtq_f16_s16(vcvtq_s16_f16(a));
|
||||
// If smaller, add one.
|
||||
uint16x8_t mask = vcltq_f16(tmp, a);
|
||||
mask = vandq_u16(mask, vreinterpretq_u16_f16(cst_1));
|
||||
tmp = vaddq_f16(tmp, vreinterpretq_f16_u16(mask));
|
||||
// Handle saturation cases.
|
||||
EIGEN_CONSTEXPR Packet8hf cst_max = pset1<Packet8hf>(static_cast<Eigen::half>(NumTraits<int16_t>::highest()));
|
||||
return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4hf pceil<Packet4hf>(const Packet4hf& a) {
|
||||
const Packet4hf cst_1 = pset1<Packet4hf>(Eigen::half(1.0f));
|
||||
// Round to nearest.
|
||||
Packet4hf tmp = vcvt_f16_s16(vcvt_s16_f16(a));
|
||||
// If smaller, add one.
|
||||
uint16x4_t mask = vclt_f16(tmp, a);
|
||||
mask = vand_u16(mask, vreinterpret_u16_f16(cst_1));
|
||||
tmp = vadd_f16(tmp, vreinterpret_f16_u16(mask));
|
||||
// Handle saturation cases.
|
||||
EIGEN_CONSTEXPR Packet4hf cst_max = pset1<Packet4hf>(static_cast<Eigen::half>(NumTraits<int16_t>::highest()));
|
||||
return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8hf psqrt<Packet8hf>(const Packet8hf& a) {
|
||||
return vsqrtq_f16(a);
|
||||
@ -4472,6 +4416,62 @@ EIGEN_STRONG_INLINE Packet4hf pabs<Packet4hf>(const Packet4hf& a) {
|
||||
return vabs_f16(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8hf pfloor<Packet8hf>(const Packet8hf& a) {
|
||||
const Packet8hf cst_1 = pset1<Packet8hf>(Eigen::half(1.0f));
|
||||
// Round to nearest.
|
||||
Packet8hf tmp = vcvtq_f16_s16(vcvtq_s16_f16(a));
|
||||
// If greater, substract one.
|
||||
uint16x8_t mask = vcgtq_f16(tmp, a);
|
||||
mask = vandq_u16(mask, vreinterpretq_u16_f16(cst_1));
|
||||
tmp = vsubq_f16(tmp, vreinterpretq_f16_u16(mask));
|
||||
// Handle saturation cases.
|
||||
const Packet8hf cst_max = pset1<Packet8hf>(static_cast<Eigen::half>(NumTraits<int16_t>::highest()));
|
||||
return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4hf pfloor<Packet4hf>(const Packet4hf& a) {
|
||||
const Packet4hf cst_1 = pset1<Packet4hf>(Eigen::half(1.0f));
|
||||
// Round to nearest.
|
||||
Packet4hf tmp = vcvt_f16_s16(vcvt_s16_f16(a));
|
||||
// If greater, substract one.
|
||||
uint16x4_t mask = vcgt_f16(tmp, a);
|
||||
mask = vand_u16(mask, vreinterpret_u16_f16(cst_1));
|
||||
tmp = vsub_f16(tmp, vreinterpret_f16_u16(mask));
|
||||
// Handle saturation cases.
|
||||
const Packet4hf cst_max = pset1<Packet4hf>(static_cast<Eigen::half>(NumTraits<int16_t>::highest()));
|
||||
return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8hf pceil<Packet8hf>(const Packet8hf& a) {
|
||||
const Packet8hf cst_1 = pset1<Packet8hf>(Eigen::half(1.0f));
|
||||
// Round to nearest.
|
||||
Packet8hf tmp = vcvtq_f16_s16(vcvtq_s16_f16(a));
|
||||
// If smaller, add one.
|
||||
uint16x8_t mask = vcltq_f16(tmp, a);
|
||||
mask = vandq_u16(mask, vreinterpretq_u16_f16(cst_1));
|
||||
tmp = vaddq_f16(tmp, vreinterpretq_f16_u16(mask));
|
||||
// Handle saturation cases.
|
||||
const Packet8hf cst_max = pset1<Packet8hf>(static_cast<Eigen::half>(NumTraits<int16_t>::highest()));
|
||||
return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4hf pceil<Packet4hf>(const Packet4hf& a) {
|
||||
const Packet4hf cst_1 = pset1<Packet4hf>(Eigen::half(1.0f));
|
||||
// Round to nearest.
|
||||
Packet4hf tmp = vcvt_f16_s16(vcvt_s16_f16(a));
|
||||
// If smaller, add one.
|
||||
uint16x4_t mask = vclt_f16(tmp, a);
|
||||
mask = vand_u16(mask, vreinterpret_u16_f16(cst_1));
|
||||
tmp = vadd_f16(tmp, vreinterpret_f16_u16(mask));
|
||||
// Handle saturation cases.
|
||||
const Packet4hf cst_max = pset1<Packet4hf>(static_cast<Eigen::half>(NumTraits<int16_t>::highest()));
|
||||
return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Eigen::half predux<Packet8hf>(const Packet8hf& a) {
|
||||
float16x4_t a_lo, a_hi, sum;
|
||||
|
@ -551,7 +551,9 @@ void packetmath_real() {
|
||||
if (PacketTraits::HasRound || PacketTraits::HasCeil || PacketTraits::HasFloor || PacketTraits::HasRint) {
|
||||
typedef typename internal::make_integer<Scalar>::type IntType;
|
||||
// Start with values that cannot fit inside an integer, work down to less than one.
|
||||
Scalar val = Scalar(2) * static_cast<Scalar>(NumTraits<IntType>::highest());
|
||||
Scalar val = numext::mini(
|
||||
Scalar(2) * static_cast<Scalar>(NumTraits<IntType>::highest()),
|
||||
NumTraits<Scalar>::highest());
|
||||
std::vector<Scalar> values;
|
||||
while (val > Scalar(0.25)) {
|
||||
// Cover both even and odd, positive and negative cases.
|
||||
|
Loading…
x
Reference in New Issue
Block a user