mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-21 17:19:36 +08:00
Fix sqrt/rsqrt for NEON.
This commit is contained in:
parent
23755030c9
commit
28e008b99a
@ -3340,23 +3340,13 @@ template<> EIGEN_STRONG_INLINE Packet4ui psqrt(const Packet4ui& a) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet4f prsqrt(const Packet4f& a) {
|
template<> EIGEN_STRONG_INLINE Packet4f prsqrt(const Packet4f& a) {
|
||||||
// Compute approximate reciprocal sqrt.
|
|
||||||
Packet4f x = vrsqrteq_f32(a);
|
|
||||||
// Do Newton iterations for 1/sqrt(x).
|
// Do Newton iterations for 1/sqrt(x).
|
||||||
x = vmulq_f32(vrsqrtsq_f32(vmulq_f32(a, x), x), x);
|
return generic_rsqrt_newton_step<Packet4f, /*Steps=*/2>::run(a, vrsqrteq_f32(a));
|
||||||
x = vmulq_f32(vrsqrtsq_f32(vmulq_f32(a, x), x), x);
|
|
||||||
const Packet4f infinity = pset1<Packet4f>(NumTraits<float>::infinity());
|
|
||||||
return pselect(pcmp_eq(a, pzero(a)), infinity, x);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet2f prsqrt(const Packet2f& a) {
|
template<> EIGEN_STRONG_INLINE Packet2f prsqrt(const Packet2f& a) {
|
||||||
// Compute approximate reciprocal sqrt.
|
// Compute approximate reciprocal sqrt.
|
||||||
Packet2f x = vrsqrte_f32(a);
|
return generic_rsqrt_newton_step<Packet2f, /*Steps=*/2>::run(a, vrsqrte_f32(a));
|
||||||
// Do Newton iterations for 1/sqrt(x).
|
|
||||||
x = vmul_f32(vrsqrts_f32(vmul_f32(a, x), x), x);
|
|
||||||
x = vmul_f32(vrsqrts_f32(vmul_f32(a, x), x), x);
|
|
||||||
const Packet2f infinity = pset1<Packet2f>(NumTraits<float>::infinity());
|
|
||||||
return pselect(pcmp_eq(a, pzero(a)), infinity, x);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Unfortunately vsqrt_f32 is only available for A64.
|
// Unfortunately vsqrt_f32 is only available for A64.
|
||||||
@ -3365,14 +3355,10 @@ template<> EIGEN_STRONG_INLINE Packet4f psqrt(const Packet4f& _x){return vsqrtq_
|
|||||||
template<> EIGEN_STRONG_INLINE Packet2f psqrt(const Packet2f& _x){return vsqrt_f32(_x); }
|
template<> EIGEN_STRONG_INLINE Packet2f psqrt(const Packet2f& _x){return vsqrt_f32(_x); }
|
||||||
#else
|
#else
|
||||||
template<> EIGEN_STRONG_INLINE Packet4f psqrt(const Packet4f& a) {
|
template<> EIGEN_STRONG_INLINE Packet4f psqrt(const Packet4f& a) {
|
||||||
const Packet4f infinity = pset1<Packet4f>(NumTraits<float>::infinity());
|
return generic_sqrt_newton_step<Packet4f>::run(a, prsqrt(a));
|
||||||
const Packet4f is_zero_or_inf = por(pcmp_eq(a, pzero(a)), pcmp_eq(a, infinity));
|
|
||||||
return pselect(is_zero_or_inf, a, pmul(a, prsqrt(a)));
|
|
||||||
}
|
}
|
||||||
template<> EIGEN_STRONG_INLINE Packet2f psqrt(const Packet2f& a) {
|
template<> EIGEN_STRONG_INLINE Packet2f psqrt(const Packet2f& a) {
|
||||||
const Packet2f infinity = pset1<Packet2f>(NumTraits<float>::infinity());
|
return generic_sqrt_newton_step<Packet2f>::run(a, prsqrt(a));
|
||||||
const Packet2f is_zero_or_inf = por(pcmp_eq(a, pzero(a)), pcmp_eq(a, infinity));
|
|
||||||
return pselect(is_zero_or_inf, a, pmul(a, prsqrt(a)));
|
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@ -3966,14 +3952,8 @@ template<> EIGEN_STRONG_INLINE Packet2d pset1frombits<Packet2d>(uint64_t from)
|
|||||||
{ return vreinterpretq_f64_u64(vdupq_n_u64(from)); }
|
{ return vreinterpretq_f64_u64(vdupq_n_u64(from)); }
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet2d prsqrt(const Packet2d& a) {
|
template<> EIGEN_STRONG_INLINE Packet2d prsqrt(const Packet2d& a) {
|
||||||
// Compute approximate reciprocal sqrt.
|
|
||||||
Packet2d x = vrsqrteq_f64(a);
|
|
||||||
// Do Newton iterations for 1/sqrt(x).
|
// Do Newton iterations for 1/sqrt(x).
|
||||||
x = vmulq_f64(vrsqrtsq_f64(vmulq_f64(a, x), x), x);
|
return generic_rsqrt_newton_step<Packet2d, /*Steps=*/3>::run(a, vrsqrteq_f64(a));
|
||||||
x = vmulq_f64(vrsqrtsq_f64(vmulq_f64(a, x), x), x);
|
|
||||||
x = vmulq_f64(vrsqrtsq_f64(vmulq_f64(a, x), x), x);
|
|
||||||
const Packet2d infinity = pset1<Packet2d>(NumTraits<double>::infinity());
|
|
||||||
return pselect(pcmp_eq(a, pzero(a)), infinity, x);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet2d psqrt(const Packet2d& _x){ return vsqrtq_f64(_x); }
|
template<> EIGEN_STRONG_INLINE Packet2d psqrt(const Packet2d& _x){ return vsqrtq_f64(_x); }
|
||||||
|
@ -954,11 +954,11 @@ void packetmath_real() {
|
|||||||
} else {
|
} else {
|
||||||
data1[1] = -((std::numeric_limits<Scalar>::min)());
|
data1[1] = -((std::numeric_limits<Scalar>::min)());
|
||||||
}
|
}
|
||||||
CHECK_CWISE1(numext::sqrt, internal::psqrt);
|
CHECK_CWISE1_IF(PacketTraits::HasSqrt, numext::sqrt, internal::psqrt);
|
||||||
|
|
||||||
data1[0] = Scalar(0.0f);
|
data1[0] = Scalar(0.0f);
|
||||||
data1[1] = NumTraits<Scalar>::infinity();
|
data1[1] = NumTraits<Scalar>::infinity();
|
||||||
CHECK_CWISE1(numext::sqrt, internal::psqrt);
|
CHECK_CWISE1_IF(PacketTraits::HasSqrt, numext::sqrt, internal::psqrt);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (PacketTraits::HasRsqrt) {
|
if (PacketTraits::HasRsqrt) {
|
||||||
@ -968,11 +968,11 @@ void packetmath_real() {
|
|||||||
} else {
|
} else {
|
||||||
data1[1] = -((std::numeric_limits<Scalar>::min)());
|
data1[1] = -((std::numeric_limits<Scalar>::min)());
|
||||||
}
|
}
|
||||||
CHECK_CWISE1(numext::rsqrt, internal::prsqrt);
|
CHECK_CWISE1_IF(PacketTraits::HasRsqrt, numext::rsqrt, internal::prsqrt);
|
||||||
|
|
||||||
data1[0] = Scalar(0.0f);
|
data1[0] = Scalar(0.0f);
|
||||||
data1[1] = NumTraits<Scalar>::infinity();
|
data1[1] = NumTraits<Scalar>::infinity();
|
||||||
CHECK_CWISE1(numext::rsqrt, internal::prsqrt);
|
CHECK_CWISE1_IF(PacketTraits::HasRsqrt, numext::rsqrt, internal::prsqrt);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(rmlarsen): Re-enable for half and bfloat16.
|
// TODO(rmlarsen): Re-enable for half and bfloat16.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user