Implement integer square-root for NEON

This commit is contained in:
Joel Holdsworth 2020-03-19 17:05:13 +00:00 committed by Rasmus Munk Larsen
parent 37ccb86916
commit 54aa8fa186
2 changed files with 102 additions and 4 deletions

View File

@ -216,7 +216,9 @@ struct packet_traits<uint8_t> : default_packet_traits
HasConj = 1,
HasSetLinear = 0,
HasBlend = 0,
HasReduxp = 1
HasReduxp = 1,
HasSqrt = 1
};
};
@ -272,7 +274,9 @@ struct packet_traits<uint16_t> : default_packet_traits
HasConj = 1,
HasSetLinear = 0,
HasBlend = 0,
HasReduxp = 1
HasReduxp = 1,
HasSqrt = 1
};
};
@ -328,7 +332,9 @@ struct packet_traits<uint32_t> : default_packet_traits
HasConj = 1,
HasSetLinear = 0,
HasBlend = 0,
HasReduxp = 1
HasReduxp = 1,
HasSqrt = 1
};
};
@ -3344,6 +3350,97 @@ ptranspose(PacketBlock<Packet2ul, 2>& kernel)
#endif
}
/**
* Computes the integer square root
* @remarks The calculation is performed using an algorithm which iterates through each binary digit of the result
* and tests whether setting that digit to 1 would cause the square of the value to be greater than the argument
* value. The algorithm is described in detail here: http://ww1.microchip.com/downloads/en/AppNotes/91040a.pdf .
*/
template<> EIGEN_STRONG_INLINE Packet4uc psqrt(const Packet4uc& a) {
uint8x8_t x = vreinterpret_u8_u32(vdup_n_u32(a));
uint8x8_t res = vdup_n_u8(0);
uint8x8_t add = vdup_n_u8(0x8);
for (int i = 0; i < 4; i++)
{
const uint8x8_t temp = vorr_u8(res, add);
res = vbsl_u8(vcge_u8(x, vmul_u8(temp, temp)), temp, res);
add = vshr_n_u8(add, 1);
}
return vget_lane_u32(vreinterpret_u32_u8(res), 0);
}
/// @copydoc Eigen::internal::psqrt(const Packet4uc& a)
template<> EIGEN_STRONG_INLINE Packet8uc psqrt(const Packet8uc& a) {
uint8x8_t res = vdup_n_u8(0);
uint8x8_t add = vdup_n_u8(0x8);
for (int i = 0; i < 4; i++)
{
const uint8x8_t temp = vorr_u8(res, add);
res = vbsl_u8(vcge_u8(a, vmul_u8(temp, temp)), temp, res);
add = vshr_n_u8(add, 1);
}
return res;
}
/// @copydoc Eigen::internal::psqrt(const Packet4uc& a)
template<> EIGEN_STRONG_INLINE Packet16uc psqrt(const Packet16uc& a) {
uint8x16_t res = vdupq_n_u8(0);
uint8x16_t add = vdupq_n_u8(0x8);
for (int i = 0; i < 4; i++)
{
const uint8x16_t temp = vorrq_u8(res, add);
res = vbslq_u8(vcgeq_u8(a, vmulq_u8(temp, temp)), temp, res);
add = vshrq_n_u8(add, 1);
}
return res;
}
/// @copydoc Eigen::internal::psqrt(const Packet4uc& a)
template<> EIGEN_STRONG_INLINE Packet4us psqrt(const Packet4us& a) {
uint16x4_t res = vdup_n_u16(0);
uint16x4_t add = vdup_n_u16(0x80);
for (int i = 0; i < 8; i++)
{
const uint16x4_t temp = vorr_u16(res, add);
res = vbsl_u16(vcge_u16(a, vmul_u16(temp, temp)), temp, res);
add = vshr_n_u16(add, 1);
}
return res;
}
/// @copydoc Eigen::internal::psqrt(const Packet4uc& a)
template<> EIGEN_STRONG_INLINE Packet8us psqrt(const Packet8us& a) {
uint16x8_t res = vdupq_n_u16(0);
uint16x8_t add = vdupq_n_u16(0x80);
for (int i = 0; i < 8; i++)
{
const uint16x8_t temp = vorrq_u16(res, add);
res = vbslq_u16(vcgeq_u16(a, vmulq_u16(temp, temp)), temp, res);
add = vshrq_n_u16(add, 1);
}
return res;
}
/// @copydoc Eigen::internal::psqrt(const Packet4uc& a)
template<> EIGEN_STRONG_INLINE Packet2ui psqrt(const Packet2ui& a) {
uint32x2_t res = vdup_n_u32(0);
uint32x2_t add = vdup_n_u32(0x8000);
for (int i = 0; i < 16; i++)
{
const uint32x2_t temp = vorr_u32(res, add);
res = vbsl_u32(vcge_u32(a, vmul_u32(temp, temp)), temp, res);
add = vshr_n_u32(add, 1);
}
return res;
}
/// @copydoc Eigen::internal::psqrt(const Packet4uc& a)
template<> EIGEN_STRONG_INLINE Packet4ui psqrt(const Packet4ui& a) {
uint32x4_t res = vdupq_n_u32(0);
uint32x4_t add = vdupq_n_u32(0x8000);
for (int i = 0; i < 16; i++)
{
const uint32x4_t temp = vorrq_u32(res, add);
res = vbslq_u32(vcgeq_u32(a, vmulq_u32(temp, temp)), temp, res);
add = vshrq_n_u32(add, 1);
}
return res;
}
//---------- double ----------
// Clang 3.5 in the iOS toolchain has an ICE triggered by NEON intrisics for double.

View File

@ -320,6 +320,8 @@ template<typename Scalar,typename Packet> void packetmath()
}
CHECK_CWISE2_IF(true, internal::pcmp_eq, internal::pcmp_eq);
}
CHECK_CWISE1_IF(PacketTraits::HasSqrt, numext::sqrt, internal::psqrt);
}
template<typename Scalar,typename Packet> void packetmath_real()
@ -341,7 +343,6 @@ template<typename Scalar,typename Packet> void packetmath_real()
if(internal::random<float>(0,1)<0.1f)
data1[internal::random<int>(0, PacketSize)] = 0;
CHECK_CWISE1_IF(PacketTraits::HasSqrt, std::sqrt, internal::psqrt);
CHECK_CWISE1_IF(PacketTraits::HasLog, std::log, internal::plog);
CHECK_CWISE1_IF(PacketTraits::HasRsqrt, Scalar(1)/std::sqrt, internal::prsqrt);