Add signbit function

This commit is contained in:
Charles Schlosser 2022-11-04 00:31:20 +00:00 committed by Rasmus Munk Larsen
parent 8f8e36458f
commit 82b152dbe7
12 changed files with 332 additions and 9 deletions

View File

@ -563,13 +563,13 @@ template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
parg(const Packet& a) { using numext::arg; return arg(a); }
/** \internal \returns \a a logically shifted by N bits to the right */
/** \internal \returns \a a arithmetically shifted by N bits to the right */
template<int N> EIGEN_DEVICE_FUNC inline int
parithmetic_shift_right(const int& a) { return a >> N; }
template<int N> EIGEN_DEVICE_FUNC inline long int
parithmetic_shift_right(const long int& a) { return a >> N; }
/** \internal \returns \a a arithmetically shifted by N bits to the right */
/** \internal \returns \a a logically shifted by N bits to the right */
template<int N> EIGEN_DEVICE_FUNC inline int
plogical_shift_right(const int& a) { return static_cast<int>(static_cast<unsigned int>(a) >> N); }
template<int N> EIGEN_DEVICE_FUNC inline long int
@ -1191,6 +1191,34 @@ Packet prsqrt(const Packet& a) {
return preciprocal<Packet>(psqrt(a));
}
template <typename Packet, bool IsScalar = is_scalar<Packet>::value,
bool IsInteger = NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>
struct psignbit_impl;
template <typename Packet, bool IsInteger>
struct psignbit_impl<Packet, true, IsInteger> {
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static constexpr Packet run(const Packet& a) { return numext::signbit(a); }
};
template <typename Packet>
struct psignbit_impl<Packet, false, false> {
// generic implementation if not specialized in PacketMath.h
// slower than arithmetic shift
typedef typename unpacket_traits<Packet>::type Scalar;
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static Packet run(const Packet& a) {
const Packet cst_pos_one = pset1<Packet>(Scalar(1));
const Packet cst_neg_one = pset1<Packet>(Scalar(-1));
return pcmp_eq(por(pand(a, cst_neg_one), cst_pos_one), cst_neg_one);
}
};
template <typename Packet>
struct psignbit_impl<Packet, false, true> {
// generic implementation for integer packets
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static constexpr Packet run(const Packet& a) { return pcmp_lt(a, pzero(a)); }
};
/** \internal \returns the sign bit of \a a as a bitmask*/
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE constexpr Packet
psignbit(const Packet& a) { return psignbit_impl<Packet>::run(a); }
} // end namespace internal
} // end namespace Eigen

View File

@ -1531,6 +1531,37 @@ double abs(const std::complex<double>& x) {
}
#endif
template <typename Scalar, bool IsInteger = NumTraits<Scalar>::IsInteger, bool IsSigned = NumTraits<Scalar>::IsSigned>
struct signbit_impl;
template <typename Scalar>
struct signbit_impl<Scalar, false, true> {
static constexpr size_t Size = sizeof(Scalar);
static constexpr size_t Shift = (CHAR_BIT * Size) - 1;
using intSize_t = typename get_integer_by_size<Size>::signed_type;
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static Scalar run(const Scalar& x) {
intSize_t a = bit_cast<intSize_t, Scalar>(x);
a = a >> Shift;
Scalar result = bit_cast<Scalar, intSize_t>(a);
return result;
}
};
template <typename Scalar>
struct signbit_impl<Scalar, true, true> {
static constexpr size_t Size = sizeof(Scalar);
static constexpr size_t Shift = (CHAR_BIT * Size) - 1;
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static constexpr Scalar run(const Scalar& x) { return x >> Shift; }
};
template <typename Scalar>
struct signbit_impl<Scalar, true, false> {
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static constexpr Scalar run(const Scalar& ) {
return Scalar(0);
}
};
template <typename Scalar>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static constexpr Scalar signbit(const Scalar& x) {
return signbit_impl<Scalar>::run(x);
}
template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T exp(const T &x) {

View File

@ -95,7 +95,7 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Tgt bit_cast(const Src& src) {
// Load src into registers first. This allows the memcpy to be elided by CUDA.
const Src staged = src;
EIGEN_USING_STD(memcpy)
memcpy(&tgt, &staged, sizeof(Tgt));
memcpy(static_cast<void*>(&tgt),static_cast<const void*>(&staged), sizeof(Tgt));
return tgt;
}
} // namespace numext

View File

@ -229,10 +229,7 @@ template<> struct packet_traits<int64_t> : default_packet_traits
Vectorizable = 1,
AlignedOnScalar = 1,
HasCmp = 1,
size=4,
// requires AVX512
HasShift = 0,
size=4
};
};
#endif
@ -360,6 +357,35 @@ template <int N>
EIGEN_STRONG_INLINE Packet4l plogical_shift_left(Packet4l a) {
return _mm256_slli_epi64(a, N);
}
#ifdef EIGEN_VECTORIZE_AVX512FP16
template <int N>
EIGEN_STRONG_INLINE Packet4l parithmetic_shift_right(Packet4l a) { return _mm256_srai_epi64(a, N); }
#else
template <int N>
EIGEN_STRONG_INLINE std::enable_if_t< (N == 0), Packet4l> parithmetic_shift_right(Packet4l a) {
return a;
}
template <int N>
EIGEN_STRONG_INLINE std::enable_if_t< (N > 0) && (N < 32), Packet4l> parithmetic_shift_right(Packet4l a) {
__m256i hi_word = _mm256_srai_epi32(a, N);
__m256i lo_word = _mm256_srli_epi64(a, N);
return _mm256_blend_epi32(hi_word, lo_word, 0b01010101);
}
template <int N>
EIGEN_STRONG_INLINE std::enable_if_t< (N >= 32) && (N < 63), Packet4l> parithmetic_shift_right(Packet4l a) {
__m256i hi_word = _mm256_srai_epi32(a, 31);
__m256i lo_word = _mm256_shuffle_epi32(_mm256_srai_epi32(a, N - 32), (shuffle_mask<1, 1, 3, 3>::mask));
return _mm256_blend_epi32(hi_word, lo_word, 0b01010101);
}
template <int N>
EIGEN_STRONG_INLINE std::enable_if_t< (N == 63), Packet4l> parithmetic_shift_right(Packet4l a) {
return _mm256_shuffle_epi32(_mm256_srai_epi32(a, 31), (shuffle_mask<1, 1, 3, 3>::mask));
}
template <int N>
EIGEN_STRONG_INLINE std::enable_if_t< (N < 0) || (N > 63), Packet4l> parithmetic_shift_right(Packet4l a) {
return parithmetic_shift_right<int(N&63)>(a);
}
#endif
template <>
EIGEN_STRONG_INLINE Packet4l pload<Packet4l>(const int64_t* from) {
EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_si256(reinterpret_cast<const __m256i*>(from));
@ -1103,6 +1129,11 @@ template<> EIGEN_STRONG_INLINE Packet8i pabs(const Packet8i& a)
#endif
}
template<> EIGEN_STRONG_INLINE Packet8h psignbit(const Packet8h& a) { return _mm_srai_epi16(a, 15); }
template<> EIGEN_STRONG_INLINE Packet8bf psignbit(const Packet8bf& a) { return _mm_srai_epi16(a, 15); }
template<> EIGEN_STRONG_INLINE Packet8f psignbit(const Packet8f& a) { return _mm256_castsi256_ps(parithmetic_shift_right<31>((Packet8i)_mm256_castps_si256(a))); }
template<> EIGEN_STRONG_INLINE Packet4d psignbit(const Packet4d& a) { return _mm256_castsi256_pd(parithmetic_shift_right<63>((Packet4l)_mm256_castpd_si256(a))); }
template<> EIGEN_STRONG_INLINE Packet8f pfrexp<Packet8f>(const Packet8f& a, Packet8f& exponent) {
return pfrexp_generic(a,exponent);
}

View File

@ -1127,6 +1127,11 @@ template<> EIGEN_STRONG_INLINE Packet16i pabs(const Packet16i& a)
return _mm512_abs_epi32(a);
}
template<> EIGEN_STRONG_INLINE Packet16h psignbit(const Packet16h& a) { return _mm256_srai_epi16(a, 15); }
template<> EIGEN_STRONG_INLINE Packet16bf psignbit(const Packet16bf& a) { return _mm256_srai_epi16(a, 15); }
template<> EIGEN_STRONG_INLINE Packet16f psignbit(const Packet16f& a) { return _mm512_castsi512_ps(_mm512_srai_epi32(_mm512_castps_si512(a), 31)); }
template<> EIGEN_STRONG_INLINE Packet8d psignbit(const Packet8d& a) { return _mm512_castsi512_pd(_mm512_srai_epi64(_mm512_castpd_si512(a), 63)); }
template<>
EIGEN_STRONG_INLINE Packet16f pfrexp<Packet16f>(const Packet16f& a, Packet16f& exponent){
return pfrexp_generic(a, exponent);

View File

@ -196,6 +196,13 @@ EIGEN_STRONG_INLINE Packet32h pabs<Packet32h>(const Packet32h& a) {
return _mm512_abs_ph(a);
}
// psignbit
template <>
EIGEN_STRONG_INLINE Packet32h psignbit<Packet32h>(const Packet32h& a) {
return _mm512_castsi512_ph(_mm512_srai_epi16(_mm512_castph_si512(a), 15));
}
// pmin
template <>

View File

@ -1575,6 +1575,9 @@ template<> EIGEN_STRONG_INLINE Packet8bf pabs(const Packet8bf& a) {
return pand<Packet8us>(p8us_abs_mask, a);
}
template<> EIGEN_STRONG_INLINE Packet8bf psignbit(const Packet8bf& a) { return vec_sra(a.m_val, vec_splat_u16(15)); }
template<> EIGEN_STRONG_INLINE Packet4f psignbit(const Packet4f& a) { return (Packet4f)vec_sra((Packet4i)a, vec_splats(uint32_t(31))); }
template<int N> EIGEN_STRONG_INLINE Packet4i parithmetic_shift_right(const Packet4i& a)
{ return vec_sra(a,reinterpret_cast<Packet4ui>(pset1<Packet4i>(N))); }
template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_right(const Packet4i& a)
@ -2928,7 +2931,7 @@ template<> EIGEN_STRONG_INLINE Packet2d preverse(const Packet2d& a)
return vec_sld(a, a, 8);
}
template<> EIGEN_STRONG_INLINE Packet2d pabs(const Packet2d& a) { return vec_abs(a); }
template<> EIGEN_STRONG_INLINE Packet2d psignbit(const Packet2d& a) { return (Packet2d)vec_sra((Packet2l)a, vec_splats(uint64_t(63))); }
// VSX support varies between different compilers and even different
// versions of the same compiler. For gcc version >= 4.9.3, we can use
// vec_cts to efficiently convert Packet2d to Packet2l. Otherwise, use

View File

@ -2372,6 +2372,12 @@ template<> EIGEN_STRONG_INLINE Packet2l pabs(const Packet2l& a) {
}
template<> EIGEN_STRONG_INLINE Packet2ul pabs(const Packet2ul& a) { return a; }
template<> EIGEN_STRONG_INLINE Packet4h psignbit(const Packet4h& a) { vreinterpret_f16_s16( vshr_n_s16( vreinterpret_s16_f16(a), 15)); }
template<> EIGEN_STRONG_INLINE Packet8h psignbit(const Packet8h& a) { vreinterpretq_f16_s16(vshrq_n_s16(vreinterpretq_s16_f16(a), 15)); }
template<> EIGEN_STRONG_INLINE Packet2f psignbit(const Packet2f& a) { vreinterpret_f32_s32( vshr_n_s32( vreinterpret_s32_f32(a), 31)); }
template<> EIGEN_STRONG_INLINE Packet4f psignbit(const Packet4f& a) { vreinterpretq_f32_s32(vshrq_n_s32(vreinterpretq_s32_f32(a), 31)); }
template<> EIGEN_STRONG_INLINE Packet2d psignbit(const Packet2d& a) { vreinterpretq_f64_s64(vshrq_n_s64(vreinterpretq_s64_f64(a), 63)); }
template<> EIGEN_STRONG_INLINE Packet2f pfrexp<Packet2f>(const Packet2f& a, Packet2f& exponent)
{ return pfrexp_generic(a,exponent); }
template<> EIGEN_STRONG_INLINE Packet4f pfrexp<Packet4f>(const Packet4f& a, Packet4f& exponent)

View File

@ -649,6 +649,17 @@ template<> EIGEN_STRONG_INLINE Packet4i pabs(const Packet4i& a)
#endif
}
template<> EIGEN_STRONG_INLINE Packet4f psignbit(const Packet4f& a) { return _mm_castsi128_ps(_mm_srai_epi32(_mm_castps_si128(a), 31)); }
template<> EIGEN_STRONG_INLINE Packet2d psignbit(const Packet2d& a)
{
Packet4f tmp = psignbit<Packet4f>(_mm_castpd_ps(a));
#ifdef EIGEN_VECTORIZE_AVX
return _mm_castps_pd(_mm_permute_ps(tmp, (shuffle_mask<1, 1, 3, 3>::mask)));
#else
return _mm_castps_pd(_mm_shuffle_ps(tmp, tmp, (shuffle_mask<1, 1, 3, 3>::mask)));
#endif // EIGEN_VECTORIZE_AVX
}
#ifdef EIGEN_VECTORIZE_SSE4_1
template<> EIGEN_STRONG_INLINE Packet4f pround<Packet4f>(const Packet4f& a)
{

View File

@ -43,6 +43,32 @@ typedef std::uint32_t uint32_t;
typedef std::int32_t int32_t;
typedef std::uint64_t uint64_t;
typedef std::int64_t int64_t;
template <size_t Size>
struct get_integer_by_size {
typedef void signed_type;
typedef void unsigned_type;
};
template <>
struct get_integer_by_size<1> {
typedef int8_t signed_type;
typedef uint8_t unsigned_type;
};
template <>
struct get_integer_by_size<2> {
typedef int16_t signed_type;
typedef uint16_t unsigned_type;
};
template <>
struct get_integer_by_size<4> {
typedef int32_t signed_type;
typedef uint32_t unsigned_type;
};
template <>
struct get_integer_by_size<8> {
typedef int64_t signed_type;
typedef uint64_t unsigned_type;
};
}
}

View File

@ -219,7 +219,7 @@ void unary_pow_test() {
for (Exponent exponent = min_exponent; exponent < max_exponent; ++exponent) {
test_exponent<Base, Exponent>(exponent);
}
};
}
void mixed_pow_test() {
// The following cases will test promoting a smaller exponent type
@ -260,6 +260,81 @@ void int_pow_test() {
unary_pow_test<long long, int>();
}
namespace Eigen {
namespace internal {
template <typename Scalar>
struct test_signbit_op {
Scalar constexpr operator()(const Scalar& a) const { return numext::signbit(a); }
template <typename Packet>
inline Packet packetOp(const Packet& a) const {
return psignbit(a);
}
};
template <typename Scalar>
struct functor_traits<test_signbit_op<Scalar>> {
enum { Cost = 1, PacketAccess = true }; //todo: define HasSignbit flag
};
} // namespace internal
} // namespace Eigen
template <typename T, bool IsInteger = NumTraits<T>::IsInteger>
struct ref_signbit_func_impl {
static bool run(const T& x) { return std::signbit(x); }
};
template <typename T>
struct ref_signbit_func_impl<T, true> {
// MSVC (perhaps others) does not have a std::signbit overload for integers
static bool run(const T& x) { return x < T(0); }
};
template <typename T>
bool ref_signbit_func(const T& x) {
return ref_signbit_func_impl<T>::run(x);
}
template <typename Scalar>
void signbit_test() {
Scalar true_mask;
std::memset(static_cast<void*>(&true_mask), 0xff, sizeof(Scalar));
Scalar false_mask;
std::memset(static_cast<void*>(&false_mask), 0x00, sizeof(Scalar));
const size_t size = 100 * internal::packet_traits<Scalar>::size;
ArrayX<Scalar> x(size), y(size);
x.setRandom();
std::vector<Scalar> special_vals = special_values<Scalar>();
for (size_t i = 0; i < special_vals.size(); i++) {
x(2 * i + 0) = special_vals[i];
x(2 * i + 1) = -special_vals[i];
}
y = x.unaryExpr(internal::test_signbit_op<Scalar>());
bool all_pass = true;
for (size_t i = 0; i < size; i++) {
const Scalar ref_val = ref_signbit_func(x(i)) ? true_mask : false_mask;
bool not_same = internal::predux_any(internal::bitwise_helper<Scalar>::bitwise_xor(ref_val, y(i)));
if (not_same) std::cout << "signbit(" << x(i) << ") != " << y(i) << "\n";
all_pass = all_pass && !not_same;
}
VERIFY(all_pass);
}
void signbit_tests() {
signbit_test<float>();
signbit_test<double>();
signbit_test<Eigen::half>();
signbit_test<Eigen::bfloat16>();
signbit_test<uint8_t>();
signbit_test<uint16_t>();
signbit_test<uint32_t>();
signbit_test<uint64_t>();
signbit_test<int8_t>();
signbit_test<int16_t>();
signbit_test<int32_t>();
signbit_test<int64_t>();
}
template<typename ArrayType> void array(const ArrayType& m)
{
typedef typename ArrayType::Scalar Scalar;
@ -855,6 +930,35 @@ template<typename ArrayType> void array_integer(const ArrayType& m)
VERIFY( (m2 == m1.unaryExpr(arithmetic_shift_right<9>())).all() );
}
template <typename ArrayType>
struct signed_shift_test_impl {
typedef typename ArrayType::Scalar Scalar;
static constexpr size_t Size = sizeof(Scalar);
static constexpr size_t MaxShift = (CHAR_BIT * Size) - 1;
template <size_t N = 0>
static inline std::enable_if_t<(N > MaxShift), void> run(const ArrayType& ) {}
template <size_t N = 0>
static inline std::enable_if_t<(N <= MaxShift), void> run(const ArrayType& m) {
const Index rows = m.rows();
const Index cols = m.cols();
ArrayType m1 = ArrayType::Random(rows, cols), m2(rows, cols);
m2 = m1.unaryExpr([](const Scalar& x) { return x >> N; });
VERIFY((m2 == m1.unaryExpr(internal::scalar_shift_right_op<Scalar, N>())).all());
m2 = m1.unaryExpr([](const Scalar& x) { return x << N; });
VERIFY((m2 == m1.unaryExpr( internal::scalar_shift_left_op<Scalar, N>())).all());
run<N + 1>(m);
}
};
template <typename ArrayType>
void signed_shift_test(const ArrayType& m) {
signed_shift_test_impl<ArrayType>::run(m);
}
EIGEN_DECLARE_TEST(array_cwise)
{
for(int i = 0; i < g_repeat; i++) {
@ -867,6 +971,9 @@ EIGEN_DECLARE_TEST(array_cwise)
CALL_SUBTEST_6( array(Array<Index,Dynamic,Dynamic>(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
CALL_SUBTEST_6( array_integer(ArrayXXi(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
CALL_SUBTEST_6( array_integer(Array<Index,Dynamic,Dynamic>(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
CALL_SUBTEST_7( signed_shift_test(ArrayXXi(internal::random<int>(1, EIGEN_TEST_MAX_SIZE), internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));
CALL_SUBTEST_7( signed_shift_test(Array<Index, Dynamic, Dynamic>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE), internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));
}
for(int i = 0; i < g_repeat; i++) {
CALL_SUBTEST_1( comparisons(Array<float, 1, 1>()) );
@ -897,6 +1004,7 @@ EIGEN_DECLARE_TEST(array_cwise)
for(int i = 0; i < g_repeat; i++) {
CALL_SUBTEST_6( int_pow_test() );
CALL_SUBTEST_7( mixed_pow_test() );
CALL_SUBTEST_8( signbit_tests() );
}
VERIFY((internal::is_same< internal::global_math_functions_filtering_base<int>::type, int >::value));

View File

@ -239,6 +239,58 @@ void check_rsqrt() {
check_rsqrt_impl<T>::run();
}
template <typename T, bool IsInteger = NumTraits<T>::IsInteger>
struct ref_signbit_func_impl {
static bool run(const T& x) { return std::signbit(x); }
};
template <typename T>
struct ref_signbit_func_impl<T, true> {
// MSVC (perhaps others) does not have a std::signbit overload for integers
static bool run(const T& x) { return x < T(0); }
};
template <typename T>
bool ref_signbit_func(const T& x) {
return ref_signbit_func_impl<T>::run(x);
}
template <typename T>
struct check_signbit_impl {
static void run() {
T true_mask;
std::memset(static_cast<void*>(&true_mask), 0xff, sizeof(T));
T false_mask;
std::memset(static_cast<void*>(&false_mask), 0x00, sizeof(T));
// has sign bit
const T neg_zero = static_cast<T>(-0.0);
const T neg_one = static_cast<T>(-1.0);
const T neg_inf = -std::numeric_limits<T>::infinity();
const T neg_nan = -std::numeric_limits<T>::quiet_NaN();
// does not have sign bit
const T pos_zero = static_cast<T>(0.0);
const T pos_one = static_cast<T>(1.0);
const T pos_inf = std::numeric_limits<T>::infinity();
const T pos_nan = std::numeric_limits<T>::quiet_NaN();
std::vector<T> values = {neg_zero, neg_one, neg_inf, neg_nan, pos_zero, pos_one, pos_inf, pos_nan};
bool all_pass = true;
for (T val : values) {
const T numext_val = numext::signbit(val);
const T ref_val = ref_signbit_func(val) ? true_mask : false_mask;
bool not_same = internal::predux_any(internal::bitwise_helper<T>::bitwise_xor(ref_val, numext_val));
all_pass = all_pass && !not_same;
if (not_same) std::cout << "signbit(" << val << ") != " << numext_val << "\n";
}
VERIFY(all_pass);
}
};
template <typename T>
void check_signbit() {
check_signbit_impl<T>::run();
}
EIGEN_DECLARE_TEST(numext) {
for(int k=0; k<g_repeat; ++k)
{
@ -271,5 +323,20 @@ EIGEN_DECLARE_TEST(numext) {
CALL_SUBTEST( check_rsqrt<double>() );
CALL_SUBTEST( check_rsqrt<std::complex<float> >() );
CALL_SUBTEST( check_rsqrt<std::complex<double> >() );
CALL_SUBTEST( check_signbit<half>());
CALL_SUBTEST( check_signbit<bfloat16>());
CALL_SUBTEST( check_signbit<float>());
CALL_SUBTEST( check_signbit<double>());
CALL_SUBTEST( check_signbit<uint8_t>());
CALL_SUBTEST( check_signbit<uint16_t>());
CALL_SUBTEST( check_signbit<uint32_t>());
CALL_SUBTEST( check_signbit<uint64_t>());
CALL_SUBTEST( check_signbit<int8_t>());
CALL_SUBTEST( check_signbit<int16_t>());
CALL_SUBTEST( check_signbit<int32_t>());
CALL_SUBTEST( check_signbit<int64_t>());
}
}