Improved std::complex sqrt and rsqrt.

Replaces `std::sqrt` with `complex_sqrt` for all platforms (previously
`complex_sqrt` was only used for CUDA and MSVC), and implements
custom `complex_rsqrt`.

Also introduces `numext::rsqrt` to simplify implementation, and modified
`numext::hypot` to adhere to IEEE IEC 6059 for special cases.

The `complex_sqrt` and `complex_rsqrt` implementations were found to be
significantly faster than `std::sqrt<std::complex<T>>` and
`1/numext::sqrt<std::complex<T>>`.

Benchmark file attached.
```
GCC 10, Intel Xeon, x86_64:
---------------------------------------------------------------------------
Benchmark                                 Time             CPU   Iterations
---------------------------------------------------------------------------
BM_Sqrt<std::complex<float>>           9.21 ns         9.21 ns     73225448
BM_StdSqrt<std::complex<float>>        17.1 ns         17.1 ns     40966545
BM_Sqrt<std::complex<double>>          8.53 ns         8.53 ns     81111062
BM_StdSqrt<std::complex<double>>       21.5 ns         21.5 ns     32757248
BM_Rsqrt<std::complex<float>>          10.3 ns         10.3 ns     68047474
BM_DivSqrt<std::complex<float>>        16.3 ns         16.3 ns     42770127
BM_Rsqrt<std::complex<double>>         11.3 ns         11.3 ns     61322028
BM_DivSqrt<std::complex<double>>       16.5 ns         16.5 ns     42200711

Clang 11, Intel Xeon, x86_64:
---------------------------------------------------------------------------
Benchmark                                 Time             CPU   Iterations
---------------------------------------------------------------------------
BM_Sqrt<std::complex<float>>           7.46 ns         7.45 ns     90742042
BM_StdSqrt<std::complex<float>>        16.6 ns         16.6 ns     42369878
BM_Sqrt<std::complex<double>>          8.49 ns         8.49 ns     81629030
BM_StdSqrt<std::complex<double>>       21.8 ns         21.7 ns     31809588
BM_Rsqrt<std::complex<float>>          8.39 ns         8.39 ns     82933666
BM_DivSqrt<std::complex<float>>        14.4 ns         14.4 ns     48638676
BM_Rsqrt<std::complex<double>>         9.83 ns         9.82 ns     70068956
BM_DivSqrt<std::complex<double>>       15.7 ns         15.7 ns     44487798

Clang 9, Pixel 2, aarch64:
---------------------------------------------------------------------------
Benchmark                                 Time             CPU   Iterations
---------------------------------------------------------------------------
BM_Sqrt<std::complex<float>>           24.2 ns         24.1 ns     28616031
BM_StdSqrt<std::complex<float>>         104 ns          103 ns      6826926
BM_Sqrt<std::complex<double>>          31.8 ns         31.8 ns     22157591
BM_StdSqrt<std::complex<double>>        128 ns          128 ns      5437375
BM_Rsqrt<std::complex<float>>          31.9 ns         31.8 ns     22384383
BM_DivSqrt<std::complex<float>>        99.2 ns         98.9 ns      7250438
BM_Rsqrt<std::complex<double>>         46.0 ns         45.8 ns     15338689
BM_DivSqrt<std::complex<double>>        119 ns          119 ns      5898944
```
This commit is contained in:
Antonio Sanchez 2021-01-16 10:22:07 -08:00
parent 21a8a2487c
commit bde6741641
8 changed files with 357 additions and 82 deletions

View File

@ -250,8 +250,7 @@ template<> EIGEN_DEVICE_FUNC inline double pzero<double>(const double& a) {
template <typename RealScalar>
EIGEN_DEVICE_FUNC inline std::complex<RealScalar> ptrue(const std::complex<RealScalar>& /*a*/) {
RealScalar b;
b = ptrue(b);
RealScalar b = ptrue(RealScalar(0));
return std::complex<RealScalar>(b, b);
}

View File

@ -324,7 +324,7 @@ struct abs2_retval
};
/****************************************************************************
* Implementation of sqrt *
* Implementation of sqrt/rsqrt *
****************************************************************************/
template<typename Scalar>
@ -341,8 +341,8 @@ struct sqrt_impl
// Complex sqrt defined in MathFunctionsImpl.h.
template<typename T> EIGEN_DEVICE_FUNC std::complex<T> complex_sqrt(const std::complex<T>& a_x);
// MSVC incorrectly handles inf cases.
#if EIGEN_COMP_MSVC > 0
// Custom implementation is faster than `std::sqrt`, works on
// GPU, and correctly handles special cases (unlike MSVC).
template<typename T>
struct sqrt_impl<std::complex<T> >
{
@ -352,7 +352,6 @@ struct sqrt_impl<std::complex<T> >
return complex_sqrt<T>(x);
}
};
#endif
template<typename Scalar>
struct sqrt_retval
@ -360,6 +359,29 @@ struct sqrt_retval
typedef Scalar type;
};
// Default implementation relies on numext::sqrt, at bottom of file.
template<typename T>
struct rsqrt_impl;
// Complex rsqrt defined in MathFunctionsImpl.h.
template<typename T> EIGEN_DEVICE_FUNC std::complex<T> complex_rsqrt(const std::complex<T>& a_x);
template<typename T>
struct rsqrt_impl<std::complex<T> >
{
EIGEN_DEVICE_FUNC
static EIGEN_ALWAYS_INLINE std::complex<T> run(const std::complex<T>& x)
{
return complex_rsqrt<T>(x);
}
};
template<typename Scalar>
struct rsqrt_retval
{
typedef Scalar type;
};
/****************************************************************************
* Implementation of norm1 *
****************************************************************************/
@ -623,36 +645,6 @@ struct expm1_impl {
}
};
// Specialization for complex types that are not supported by std::expm1.
template <typename RealScalar>
struct expm1_impl<std::complex<RealScalar> > {
EIGEN_DEVICE_FUNC static inline std::complex<RealScalar> run(
const std::complex<RealScalar>& x) {
EIGEN_STATIC_ASSERT_NON_INTEGER(RealScalar)
RealScalar xr = x.real();
RealScalar xi = x.imag();
// expm1(z) = exp(z) - 1
// = exp(x + i * y) - 1
// = exp(x) * (cos(y) + i * sin(y)) - 1
// = exp(x) * cos(y) - 1 + i * exp(x) * sin(y)
// Imag(expm1(z)) = exp(x) * sin(y)
// Real(expm1(z)) = exp(x) * cos(y) - 1
// = exp(x) * cos(y) - 1.
// = expm1(x) + exp(x) * (cos(y) - 1)
// = expm1(x) + exp(x) * (2 * sin(y / 2) ** 2)
// TODO better use numext::expm1 and numext::sin (but that would require forward declarations or moving this specialization down).
RealScalar erm1 = expm1_impl<RealScalar>::run(xr);
RealScalar er = erm1 + RealScalar(1.);
EIGEN_USING_STD(sin);
RealScalar sin2 = sin(xi / RealScalar(2.));
sin2 = sin2 * sin2;
RealScalar s = sin(xi);
RealScalar real_part = erm1 - RealScalar(2.) * er * sin2;
return std::complex<RealScalar>(real_part, er * s);
}
};
template<typename Scalar>
struct expm1_retval
{
@ -1421,6 +1413,14 @@ bool sqrt<bool>(const bool &x) { return x; }
SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(sqrt, sqrt)
#endif
/** \returns the reciprocal square root of \a x. **/
template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T rsqrt(const T& x)
{
return internal::rsqrt_impl<T>::run(x);
}
template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T log(const T &x) {
@ -1936,6 +1936,45 @@ template<> struct scalar_fuzzy_impl<bool>
};
} // end namespace internal
// Default implementations that rely on other numext implementations
namespace internal {
// Specialization for complex types that are not supported by std::expm1.
template <typename RealScalar>
struct expm1_impl<std::complex<RealScalar> > {
EIGEN_DEVICE_FUNC static inline std::complex<RealScalar> run(
const std::complex<RealScalar>& x) {
EIGEN_STATIC_ASSERT_NON_INTEGER(RealScalar)
RealScalar xr = x.real();
RealScalar xi = x.imag();
// expm1(z) = exp(z) - 1
// = exp(x + i * y) - 1
// = exp(x) * (cos(y) + i * sin(y)) - 1
// = exp(x) * cos(y) - 1 + i * exp(x) * sin(y)
// Imag(expm1(z)) = exp(x) * sin(y)
// Real(expm1(z)) = exp(x) * cos(y) - 1
// = exp(x) * cos(y) - 1.
// = expm1(x) + exp(x) * (cos(y) - 1)
// = expm1(x) + exp(x) * (2 * sin(y / 2) ** 2)
RealScalar erm1 = numext::expm1<RealScalar>(xr);
RealScalar er = erm1 + RealScalar(1.);
RealScalar sin2 = numext::sin(xi / RealScalar(2.));
sin2 = sin2 * sin2;
RealScalar s = numext::sin(xi);
RealScalar real_part = erm1 - RealScalar(2.) * er * sin2;
return std::complex<RealScalar>(real_part, er * s);
}
};
template<typename T>
struct rsqrt_impl {
EIGEN_DEVICE_FUNC
static EIGEN_ALWAYS_INLINE T run(const T& x) {
return T(1)/numext::sqrt(x);
}
};
} // end namespace internal

View File

@ -79,6 +79,12 @@ template<typename RealScalar>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
RealScalar positive_real_hypot(const RealScalar& x, const RealScalar& y)
{
// IEEE IEC 6059 special cases.
if ((numext::isinf)(x) || (numext::isinf)(y))
return NumTraits<RealScalar>::infinity();
if ((numext::isnan)(x) || (numext::isnan)(y))
return NumTraits<RealScalar>::quiet_NaN();
EIGEN_USING_STD(sqrt);
RealScalar p, qp;
p = numext::maxi(x,y);
@ -128,20 +134,56 @@ EIGEN_DEVICE_FUNC std::complex<T> complex_sqrt(const std::complex<T>& z) {
const T x = numext::real(z);
const T y = numext::imag(z);
const T zero = T(0);
const T cst_half = T(0.5);
const T w = numext::sqrt(T(0.5) * (numext::abs(x) + numext::hypot(x, y)));
// Special case of isinf(y)
if ((numext::isinf)(y)) {
return std::complex<T>(std::numeric_limits<T>::infinity(), y);
}
T w = numext::sqrt(cst_half * (numext::abs(x) + numext::abs(z)));
return
x == zero ? std::complex<T>(w, y < zero ? -w : w)
: x > zero ? std::complex<T>(w, y / (2 * w))
(numext::isinf)(y) ? std::complex<T>(NumTraits<T>::infinity(), y)
: x == zero ? std::complex<T>(w, y < zero ? -w : w)
: x > zero ? std::complex<T>(w, y / (2 * w))
: std::complex<T>(numext::abs(y) / (2 * w), y < zero ? -w : w );
}
// Generic complex rsqrt implementation.
template<typename T>
EIGEN_DEVICE_FUNC std::complex<T> complex_rsqrt(const std::complex<T>& z) {
// Computes the principal reciprocal sqrt of the input.
//
// For a complex reciprocal square root of the number z = x + i*y. We want to
// find real numbers u and v such that
// (u + i*v)^2 = 1 / (x + i*y) <=>
// u^2 - v^2 + i*2*u*v = x/|z|^2 - i*v/|z|^2.
// By equating the real and imaginary parts we get:
// u^2 - v^2 = x/|z|^2
// 2*u*v = y/|z|^2.
//
// For x >= 0, this has the numerically stable solution
// u = sqrt(0.5 * (x + |z|)) / |z|
// v = -y / (2 * u * |z|)
// and for x < 0,
// v = -sign(y) * sqrt(0.5 * (-x + |z|)) / |z|
// u = -y / (2 * v * |z|)
//
// Letting w = sqrt(0.5 * (|x| + |z|)),
// if x == 0: u = w / |z|, v = -sign(y) * w / |z|
// if x > 0: u = w / |z|, v = -y / (2 * w * |z|)
// if x < 0: u = |y| / (2 * w * |z|), v = -sign(y) * w / |z|
const T x = numext::real(z);
const T y = numext::imag(z);
const T zero = T(0);
const T abs_z = numext::hypot(x, y);
const T w = numext::sqrt(T(0.5) * (numext::abs(x) + abs_z));
const T woz = w / abs_z;
// Corner cases consistent with 1/sqrt(z) on gcc/clang.
return
abs_z == zero ? std::complex<T>(NumTraits<T>::infinity(), NumTraits<T>::quiet_NaN())
: ((numext::isinf)(x) || (numext::isinf)(y)) ? std::complex<T>(zero, zero)
: x == zero ? std::complex<T>(woz, y < zero ? woz : -woz)
: x > zero ? std::complex<T>(woz, -y / (2 * w * abs_z))
: std::complex<T>(numext::abs(y) / (2 * w * abs_z), y < zero ? woz : -woz );
}
} // end namespace internal
} // end namespace Eigen

View File

@ -94,19 +94,6 @@ template<typename T> struct scalar_quotient_op<const std::complex<T>, const std:
template<typename T> struct scalar_quotient_op<std::complex<T>, std::complex<T> > : scalar_quotient_op<const std::complex<T>, const std::complex<T> > {};
// Complex sqrt is already specialized on Windows.
#if EIGEN_COMP_MSVC == 0
template<typename T>
struct sqrt_impl<std::complex<T> >
{
EIGEN_DEVICE_FUNC
static EIGEN_ALWAYS_INLINE std::complex<T> run(const std::complex<T>& x)
{
return complex_sqrt<T>(x);
}
};
#endif
} // namespace internal
} // namespace Eigen

View File

@ -150,7 +150,7 @@ Packet4f prsqrt<Packet4f>(const Packet4f& _x) {
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f prsqrt<Packet4f>(const Packet4f& x) {
// Unfortunately we can't use the much faster mm_rqsrt_ps since it only provides an approximation.
// Unfortunately we can't use the much faster mm_rsqrt_ps since it only provides an approximation.
return _mm_div_ps(pset1<Packet4f>(1.0f), _mm_sqrt_ps(x));
}
@ -158,7 +158,6 @@ Packet4f prsqrt<Packet4f>(const Packet4f& x) {
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet2d prsqrt<Packet2d>(const Packet2d& x) {
// Unfortunately we can't use the much faster mm_rqsrt_pd since it only provides an approximation.
return _mm_div_pd(pset1<Packet2d>(1.0), _mm_sqrt_pd(x));
}

View File

@ -456,7 +456,7 @@ struct functor_traits<scalar_sqrt_op<bool> > {
*/
template<typename Scalar> struct scalar_rsqrt_op {
EIGEN_EMPTY_STRUCT_CTOR(scalar_rsqrt_op)
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return Scalar(1)/numext::sqrt(a); }
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return numext::rsqrt(a); }
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::prsqrt(a); }
};

View File

@ -9,6 +9,33 @@
#include "main.h"
template<typename T, typename U>
bool check_if_equal_or_nans(const T& actual, const U& expected) {
return ((actual == expected) || ((numext::isnan)(actual) && (numext::isnan)(expected)));
}
template<typename T, typename U>
bool check_if_equal_or_nans(const std::complex<T>& actual, const std::complex<U>& expected) {
return check_if_equal_or_nans(numext::real(actual), numext::real(expected))
&& check_if_equal_or_nans(numext::imag(actual), numext::imag(expected));
}
template<typename T, typename U>
bool test_is_equal_or_nans(const T& actual, const U& expected)
{
if (check_if_equal_or_nans(actual, expected)) {
return true;
}
// false:
std::cerr
<< "\n actual = " << actual
<< "\n expected = " << expected << "\n\n";
return false;
}
#define VERIFY_IS_EQUAL_OR_NANS(a, b) VERIFY(test_is_equal_or_nans(a, b))
template<typename T>
void check_abs() {
typedef typename NumTraits<T>::Real Real;
@ -19,7 +46,7 @@ void check_abs() {
VERIFY_IS_EQUAL(numext::abs(T(0)), T(0));
VERIFY_IS_EQUAL(numext::abs(T(1)), T(1));
for(int k=0; k<g_repeat*100; ++k)
for(int k=0; k<100; ++k)
{
T x = internal::random<T>();
if(!internal::is_same<T,bool>::value)
@ -34,22 +61,199 @@ void check_abs() {
}
}
EIGEN_DECLARE_TEST(numext) {
CALL_SUBTEST( check_abs<bool>() );
CALL_SUBTEST( check_abs<signed char>() );
CALL_SUBTEST( check_abs<unsigned char>() );
CALL_SUBTEST( check_abs<short>() );
CALL_SUBTEST( check_abs<unsigned short>() );
CALL_SUBTEST( check_abs<int>() );
CALL_SUBTEST( check_abs<unsigned int>() );
CALL_SUBTEST( check_abs<long>() );
CALL_SUBTEST( check_abs<unsigned long>() );
CALL_SUBTEST( check_abs<half>() );
CALL_SUBTEST( check_abs<bfloat16>() );
CALL_SUBTEST( check_abs<float>() );
CALL_SUBTEST( check_abs<double>() );
CALL_SUBTEST( check_abs<long double>() );
template<typename T>
struct check_sqrt_impl {
static void run() {
for (int i=0; i<1000; ++i) {
const T x = numext::abs(internal::random<T>());
const T sqrtx = numext::sqrt(x);
VERIFY_IS_APPROX(sqrtx*sqrtx, x);
}
CALL_SUBTEST( check_abs<std::complex<float> >() );
CALL_SUBTEST( check_abs<std::complex<double> >() );
// Corner cases.
const T zero = T(0);
const T one = T(1);
const T inf = std::numeric_limits<T>::infinity();
const T nan = std::numeric_limits<T>::quiet_NaN();
VERIFY_IS_EQUAL(numext::sqrt(zero), zero);
VERIFY_IS_EQUAL(numext::sqrt(inf), inf);
VERIFY((numext::isnan)(numext::sqrt(nan)));
VERIFY((numext::isnan)(numext::sqrt(-one)));
}
};
template<typename T>
struct check_sqrt_impl<std::complex<T> > {
static void run() {
typedef typename std::complex<T> ComplexT;
for (int i=0; i<1000; ++i) {
const ComplexT x = internal::random<ComplexT>();
const ComplexT sqrtx = numext::sqrt(x);
VERIFY_IS_APPROX(sqrtx*sqrtx, x);
}
// Corner cases.
const T zero = T(0);
const T one = T(1);
const T inf = std::numeric_limits<T>::infinity();
const T nan = std::numeric_limits<T>::quiet_NaN();
// Set of corner cases from https://en.cppreference.com/w/cpp/numeric/complex/sqrt
const int kNumCorners = 20;
const ComplexT corners[kNumCorners][2] = {
{ComplexT(zero, zero), ComplexT(zero, zero)},
{ComplexT(-zero, zero), ComplexT(zero, zero)},
{ComplexT(zero, -zero), ComplexT(zero, zero)},
{ComplexT(-zero, -zero), ComplexT(zero, zero)},
{ComplexT(one, inf), ComplexT(inf, inf)},
{ComplexT(nan, inf), ComplexT(inf, inf)},
{ComplexT(one, -inf), ComplexT(inf, -inf)},
{ComplexT(nan, -inf), ComplexT(inf, -inf)},
{ComplexT(-inf, one), ComplexT(zero, inf)},
{ComplexT(inf, one), ComplexT(inf, zero)},
{ComplexT(-inf, -one), ComplexT(zero, -inf)},
{ComplexT(inf, -one), ComplexT(inf, -zero)},
{ComplexT(-inf, nan), ComplexT(nan, inf)},
{ComplexT(inf, nan), ComplexT(inf, nan)},
{ComplexT(zero, nan), ComplexT(nan, nan)},
{ComplexT(one, nan), ComplexT(nan, nan)},
{ComplexT(nan, zero), ComplexT(nan, nan)},
{ComplexT(nan, one), ComplexT(nan, nan)},
{ComplexT(nan, -one), ComplexT(nan, nan)},
{ComplexT(nan, nan), ComplexT(nan, nan)},
};
for (int i=0; i<kNumCorners; ++i) {
const ComplexT& x = corners[i][0];
const ComplexT sqrtx = corners[i][1];
VERIFY_IS_EQUAL_OR_NANS(numext::sqrt(x), sqrtx);
}
}
};
template<typename T>
void check_sqrt() {
check_sqrt_impl<T>::run();
}
template<typename T>
struct check_rsqrt_impl {
static void run() {
const T zero = T(0);
const T one = T(1);
const T inf = std::numeric_limits<T>::infinity();
const T nan = std::numeric_limits<T>::quiet_NaN();
for (int i=0; i<1000; ++i) {
const T x = numext::abs(internal::random<T>());
const T rsqrtx = numext::rsqrt(x);
const T invx = one / x;
VERIFY_IS_APPROX(rsqrtx*rsqrtx, invx);
}
// Corner cases.
VERIFY_IS_EQUAL(numext::rsqrt(zero), inf);
VERIFY_IS_EQUAL(numext::rsqrt(inf), zero);
VERIFY((numext::isnan)(numext::rsqrt(nan)));
VERIFY((numext::isnan)(numext::rsqrt(-one)));
}
};
template<typename T>
struct check_rsqrt_impl<std::complex<T> > {
static void run() {
typedef typename std::complex<T> ComplexT;
const T zero = T(0);
const T one = T(1);
const T inf = std::numeric_limits<T>::infinity();
const T nan = std::numeric_limits<T>::quiet_NaN();
for (int i=0; i<1000; ++i) {
const ComplexT x = internal::random<ComplexT>();
const ComplexT invx = ComplexT(one, zero) / x;
const ComplexT rsqrtx = numext::rsqrt(x);
VERIFY_IS_APPROX(rsqrtx*rsqrtx, invx);
}
// GCC and MSVC differ in their treatment of 1/(0 + 0i)
// GCC/clang = (inf, nan)
// MSVC = (nan, nan)
// and 1 / (x + inf i)
// GCC/clang = (0, 0)
// MSVC = (nan, nan)
#if (EIGEN_COMP_GNUC)
{
const int kNumCorners = 20;
const ComplexT corners[kNumCorners][2] = {
// Only consistent across GCC, clang
{ComplexT(zero, zero), ComplexT(zero, zero)},
{ComplexT(-zero, zero), ComplexT(zero, zero)},
{ComplexT(zero, -zero), ComplexT(zero, zero)},
{ComplexT(-zero, -zero), ComplexT(zero, zero)},
{ComplexT(one, inf), ComplexT(inf, inf)},
{ComplexT(nan, inf), ComplexT(inf, inf)},
{ComplexT(one, -inf), ComplexT(inf, -inf)},
{ComplexT(nan, -inf), ComplexT(inf, -inf)},
// Consistent across GCC, clang, MSVC
{ComplexT(-inf, one), ComplexT(zero, inf)},
{ComplexT(inf, one), ComplexT(inf, zero)},
{ComplexT(-inf, -one), ComplexT(zero, -inf)},
{ComplexT(inf, -one), ComplexT(inf, -zero)},
{ComplexT(-inf, nan), ComplexT(nan, inf)},
{ComplexT(inf, nan), ComplexT(inf, nan)},
{ComplexT(zero, nan), ComplexT(nan, nan)},
{ComplexT(one, nan), ComplexT(nan, nan)},
{ComplexT(nan, zero), ComplexT(nan, nan)},
{ComplexT(nan, one), ComplexT(nan, nan)},
{ComplexT(nan, -one), ComplexT(nan, nan)},
{ComplexT(nan, nan), ComplexT(nan, nan)},
};
for (int i=0; i<kNumCorners; ++i) {
const ComplexT& x = corners[i][0];
const ComplexT rsqrtx = ComplexT(one, zero) / corners[i][1];
VERIFY_IS_EQUAL_OR_NANS(numext::rsqrt(x), rsqrtx);
}
}
#endif
}
};
template<typename T>
void check_rsqrt() {
check_rsqrt_impl<T>::run();
}
EIGEN_DECLARE_TEST(numext) {
for(int k=0; k<g_repeat; ++k)
{
CALL_SUBTEST( check_abs<bool>() );
CALL_SUBTEST( check_abs<signed char>() );
CALL_SUBTEST( check_abs<unsigned char>() );
CALL_SUBTEST( check_abs<short>() );
CALL_SUBTEST( check_abs<unsigned short>() );
CALL_SUBTEST( check_abs<int>() );
CALL_SUBTEST( check_abs<unsigned int>() );
CALL_SUBTEST( check_abs<long>() );
CALL_SUBTEST( check_abs<unsigned long>() );
CALL_SUBTEST( check_abs<half>() );
CALL_SUBTEST( check_abs<bfloat16>() );
CALL_SUBTEST( check_abs<float>() );
CALL_SUBTEST( check_abs<double>() );
CALL_SUBTEST( check_abs<long double>() );
CALL_SUBTEST( check_abs<std::complex<float> >() );
CALL_SUBTEST( check_abs<std::complex<double> >() );
CALL_SUBTEST( check_sqrt<float>() );
CALL_SUBTEST( check_sqrt<double>() );
CALL_SUBTEST( check_sqrt<std::complex<float> >() );
CALL_SUBTEST( check_sqrt<std::complex<double> >() );
CALL_SUBTEST( check_rsqrt<float>() );
CALL_SUBTEST( check_rsqrt<double>() );
CALL_SUBTEST( check_rsqrt<std::complex<float> >() );
CALL_SUBTEST( check_rsqrt<std::complex<double> >() );
}
}

View File

@ -161,8 +161,12 @@ template<typename MatrixType> void stable_norm(const MatrixType& m)
// mix
{
Index i2 = internal::random<Index>(0,rows-1);
Index j2 = internal::random<Index>(0,cols-1);
// Ensure unique indices otherwise inf may be overwritten by NaN.
Index i2, j2;
do {
i2 = internal::random<Index>(0,rows-1);
j2 = internal::random<Index>(0,cols-1);
} while (i2 == i && j2 == j);
v = vrand;
v(i,j) = -std::numeric_limits<RealScalar>::infinity();
v(i2,j2) = std::numeric_limits<RealScalar>::quiet_NaN();
@ -170,7 +174,8 @@ template<typename MatrixType> void stable_norm(const MatrixType& m)
VERIFY(!(numext::isfinite)(v.norm())); VERIFY((numext::isnan)(v.norm()));
VERIFY(!(numext::isfinite)(v.stableNorm())); VERIFY((numext::isnan)(v.stableNorm()));
VERIFY(!(numext::isfinite)(v.blueNorm())); VERIFY((numext::isnan)(v.blueNorm()));
VERIFY(!(numext::isfinite)(v.hypotNorm())); VERIFY((numext::isnan)(v.hypotNorm()));
// hypot propagates inf over NaN.
VERIFY(!(numext::isfinite)(v.hypotNorm())); VERIFY((numext::isinf)(v.hypotNorm()));
}
// stableNormalize[d]