From bde6741641b7c677d901cd48db844fcea1fd32fe Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Sat, 16 Jan 2021 10:22:07 -0800 Subject: [PATCH] Improved std::complex sqrt and rsqrt. Replaces `std::sqrt` with `complex_sqrt` for all platforms (previously `complex_sqrt` was only used for CUDA and MSVC), and implements custom `complex_rsqrt`. Also introduces `numext::rsqrt` to simplify implementation, and modifies `numext::hypot` to adhere to IEEE 754 / IEC 60559 for special cases. The `complex_sqrt` and `complex_rsqrt` implementations were found to be significantly faster than `std::sqrt<std::complex<T>>` and `1/numext::sqrt<std::complex<T>>`. Benchmark file attached. ``` GCC 10, Intel Xeon, x86_64: --------------------------------------------------------------------------- Benchmark Time CPU Iterations --------------------------------------------------------------------------- BM_Sqrt> 9.21 ns 9.21 ns 73225448 BM_StdSqrt> 17.1 ns 17.1 ns 40966545 BM_Sqrt> 8.53 ns 8.53 ns 81111062 BM_StdSqrt> 21.5 ns 21.5 ns 32757248 BM_Rsqrt> 10.3 ns 10.3 ns 68047474 BM_DivSqrt> 16.3 ns 16.3 ns 42770127 BM_Rsqrt> 11.3 ns 11.3 ns 61322028 BM_DivSqrt> 16.5 ns 16.5 ns 42200711 Clang 11, Intel Xeon, x86_64: --------------------------------------------------------------------------- Benchmark Time CPU Iterations --------------------------------------------------------------------------- BM_Sqrt> 7.46 ns 7.45 ns 90742042 BM_StdSqrt> 16.6 ns 16.6 ns 42369878 BM_Sqrt> 8.49 ns 8.49 ns 81629030 BM_StdSqrt> 21.8 ns 21.7 ns 31809588 BM_Rsqrt> 8.39 ns 8.39 ns 82933666 BM_DivSqrt> 14.4 ns 14.4 ns 48638676 BM_Rsqrt> 9.83 ns 9.82 ns 70068956 BM_DivSqrt> 15.7 ns 15.7 ns 44487798 Clang 9, Pixel 2, aarch64: --------------------------------------------------------------------------- Benchmark Time CPU Iterations --------------------------------------------------------------------------- BM_Sqrt> 24.2 ns 24.1 ns 28616031 BM_StdSqrt> 104 ns 103 ns 6826926 BM_Sqrt> 31.8 ns 31.8 ns 22157591 BM_StdSqrt> 128 ns 128 ns 5437375 BM_Rsqrt> 31.9 ns 
31.8 ns 22384383 BM_DivSqrt> 99.2 ns 98.9 ns 7250438 BM_Rsqrt> 46.0 ns 45.8 ns 15338689 BM_DivSqrt> 119 ns 119 ns 5898944 ``` --- Eigen/src/Core/GenericPacketMath.h | 3 +- Eigen/src/Core/MathFunctions.h | 107 +++++++---- Eigen/src/Core/MathFunctionsImpl.h | 60 +++++- Eigen/src/Core/arch/CUDA/Complex.h | 13 -- Eigen/src/Core/arch/SSE/MathFunctions.h | 3 +- Eigen/src/Core/functors/UnaryFunctors.h | 2 +- test/numext.cpp | 240 ++++++++++++++++++++++-- test/stable_norm.cpp | 11 +- 8 files changed, 357 insertions(+), 82 deletions(-) diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h index ec7d20e73..16119c1d8 100644 --- a/Eigen/src/Core/GenericPacketMath.h +++ b/Eigen/src/Core/GenericPacketMath.h @@ -250,8 +250,7 @@ template<> EIGEN_DEVICE_FUNC inline double pzero(const double& a) { template EIGEN_DEVICE_FUNC inline std::complex ptrue(const std::complex& /*a*/) { - RealScalar b; - b = ptrue(b); + RealScalar b = ptrue(RealScalar(0)); return std::complex(b, b); } diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h index f64116a41..511a4276f 100644 --- a/Eigen/src/Core/MathFunctions.h +++ b/Eigen/src/Core/MathFunctions.h @@ -324,7 +324,7 @@ struct abs2_retval }; /**************************************************************************** -* Implementation of sqrt * +* Implementation of sqrt/rsqrt * ****************************************************************************/ template @@ -341,8 +341,8 @@ struct sqrt_impl // Complex sqrt defined in MathFunctionsImpl.h. template EIGEN_DEVICE_FUNC std::complex complex_sqrt(const std::complex& a_x); -// MSVC incorrectly handles inf cases. -#if EIGEN_COMP_MSVC > 0 +// Custom implementation is faster than `std::sqrt`, works on +// GPU, and correctly handles special cases (unlike MSVC). 
template struct sqrt_impl > { @@ -352,7 +352,6 @@ struct sqrt_impl > return complex_sqrt(x); } }; -#endif template struct sqrt_retval @@ -360,6 +359,29 @@ struct sqrt_retval typedef Scalar type; }; +// Default implementation relies on numext::sqrt, at bottom of file. +template +struct rsqrt_impl; + +// Complex rsqrt defined in MathFunctionsImpl.h. +template EIGEN_DEVICE_FUNC std::complex complex_rsqrt(const std::complex& a_x); + +template +struct rsqrt_impl > +{ + EIGEN_DEVICE_FUNC + static EIGEN_ALWAYS_INLINE std::complex run(const std::complex& x) + { + return complex_rsqrt(x); + } +}; + +template +struct rsqrt_retval +{ + typedef Scalar type; +}; + /**************************************************************************** * Implementation of norm1 * ****************************************************************************/ @@ -623,36 +645,6 @@ struct expm1_impl { } }; -// Specialization for complex types that are not supported by std::expm1. -template -struct expm1_impl > { - EIGEN_DEVICE_FUNC static inline std::complex run( - const std::complex& x) { - EIGEN_STATIC_ASSERT_NON_INTEGER(RealScalar) - RealScalar xr = x.real(); - RealScalar xi = x.imag(); - // expm1(z) = exp(z) - 1 - // = exp(x + i * y) - 1 - // = exp(x) * (cos(y) + i * sin(y)) - 1 - // = exp(x) * cos(y) - 1 + i * exp(x) * sin(y) - // Imag(expm1(z)) = exp(x) * sin(y) - // Real(expm1(z)) = exp(x) * cos(y) - 1 - // = exp(x) * cos(y) - 1. - // = expm1(x) + exp(x) * (cos(y) - 1) - // = expm1(x) + exp(x) * (2 * sin(y / 2) ** 2) - - // TODO better use numext::expm1 and numext::sin (but that would require forward declarations or moving this specialization down). - RealScalar erm1 = expm1_impl::run(xr); - RealScalar er = erm1 + RealScalar(1.); - EIGEN_USING_STD(sin); - RealScalar sin2 = sin(xi / RealScalar(2.)); - sin2 = sin2 * sin2; - RealScalar s = sin(xi); - RealScalar real_part = erm1 - RealScalar(2.) 
* er * sin2; - return std::complex(real_part, er * s); - } -}; - template struct expm1_retval { @@ -1421,6 +1413,14 @@ bool sqrt(const bool &x) { return x; } SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(sqrt, sqrt) #endif +/** \returns the reciprocal square root of \a x. **/ +template +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE +T rsqrt(const T& x) +{ + return internal::rsqrt_impl::run(x); +} + template EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T log(const T &x) { @@ -1936,6 +1936,45 @@ template<> struct scalar_fuzzy_impl }; +} // end namespace internal + +// Default implementations that rely on other numext implementations +namespace internal { + +// Specialization for complex types that are not supported by std::expm1. +template +struct expm1_impl > { + EIGEN_DEVICE_FUNC static inline std::complex run( + const std::complex& x) { + EIGEN_STATIC_ASSERT_NON_INTEGER(RealScalar) + RealScalar xr = x.real(); + RealScalar xi = x.imag(); + // expm1(z) = exp(z) - 1 + // = exp(x + i * y) - 1 + // = exp(x) * (cos(y) + i * sin(y)) - 1 + // = exp(x) * cos(y) - 1 + i * exp(x) * sin(y) + // Imag(expm1(z)) = exp(x) * sin(y) + // Real(expm1(z)) = exp(x) * cos(y) - 1 + // = exp(x) * cos(y) - 1. + // = expm1(x) + exp(x) * (cos(y) - 1) + // = expm1(x) + exp(x) * (2 * sin(y / 2) ** 2) + RealScalar erm1 = numext::expm1(xr); + RealScalar er = erm1 + RealScalar(1.); + RealScalar sin2 = numext::sin(xi / RealScalar(2.)); + sin2 = sin2 * sin2; + RealScalar s = numext::sin(xi); + RealScalar real_part = erm1 - RealScalar(2.) 
* er * sin2; + return std::complex(real_part, er * s); + } +}; + +template +struct rsqrt_impl { + EIGEN_DEVICE_FUNC + static EIGEN_ALWAYS_INLINE T run(const T& x) { + return T(1)/numext::sqrt(x); + } +}; } // end namespace internal diff --git a/Eigen/src/Core/MathFunctionsImpl.h b/Eigen/src/Core/MathFunctionsImpl.h index 9222285b4..0d3f317bb 100644 --- a/Eigen/src/Core/MathFunctionsImpl.h +++ b/Eigen/src/Core/MathFunctionsImpl.h @@ -79,6 +79,12 @@ template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE RealScalar positive_real_hypot(const RealScalar& x, const RealScalar& y) { + // IEEE 754 / IEC 60559 special cases. + if ((numext::isinf)(x) || (numext::isinf)(y)) + return NumTraits::infinity(); + if ((numext::isnan)(x) || (numext::isnan)(y)) + return NumTraits::quiet_NaN(); + EIGEN_USING_STD(sqrt); RealScalar p, qp; p = numext::maxi(x,y); @@ -128,20 +134,56 @@ EIGEN_DEVICE_FUNC std::complex complex_sqrt(const std::complex& z) { const T x = numext::real(z); const T y = numext::imag(z); const T zero = T(0); - const T cst_half = T(0.5); + const T w = numext::sqrt(T(0.5) * (numext::abs(x) + numext::hypot(x, y))); - // Special case of isinf(y) - if ((numext::isinf)(y)) { - return std::complex(std::numeric_limits::infinity(), y); - } - - T w = numext::sqrt(cst_half * (numext::abs(x) + numext::abs(z))); return - x == zero ? std::complex(w, y < zero ? -w : w) - : x > zero ? std::complex(w, y / (2 * w)) + (numext::isinf)(y) ? std::complex(NumTraits::infinity(), y) + : x == zero ? std::complex(w, y < zero ? -w : w) + : x > zero ? std::complex(w, y / (2 * w)) : std::complex(numext::abs(y) / (2 * w), y < zero ? -w : w ); } +// Generic complex rsqrt implementation. +template +EIGEN_DEVICE_FUNC std::complex complex_rsqrt(const std::complex& z) { + // Computes the principal reciprocal sqrt of the input. + // + // For a complex reciprocal square root of the number z = x + i*y. 
We want to + // find real numbers u and v such that + // (u + i*v)^2 = 1 / (x + i*y) <=> + // u^2 - v^2 + i*2*u*v = x/|z|^2 - i*y/|z|^2. + // By equating the real and imaginary parts we get: + // u^2 - v^2 = x/|z|^2 + // 2*u*v = -y/|z|^2. + // + // For x >= 0, this has the numerically stable solution + // u = sqrt(0.5 * (x + |z|)) / |z| + // v = -y / (2 * u * |z|^2) + // and for x < 0, + // v = -sign(y) * sqrt(0.5 * (-x + |z|)) / |z| + // u = -y / (2 * v * |z|^2) + // + // Letting w = sqrt(0.5 * (|x| + |z|)), + // if x == 0: u = w / |z|, v = -sign(y) * w / |z| + // if x > 0: u = w / |z|, v = -y / (2 * w * |z|) + // if x < 0: u = |y| / (2 * w * |z|), v = -sign(y) * w / |z| + + const T x = numext::real(z); + const T y = numext::imag(z); + const T zero = T(0); + + const T abs_z = numext::hypot(x, y); + const T w = numext::sqrt(T(0.5) * (numext::abs(x) + abs_z)); + const T woz = w / abs_z; + // Corner cases consistent with 1/sqrt(z) on gcc/clang. + return + abs_z == zero ? std::complex(NumTraits::infinity(), NumTraits::quiet_NaN()) + : ((numext::isinf)(x) || (numext::isinf)(y)) ? std::complex(zero, zero) + : x == zero ? std::complex(woz, y < zero ? woz : -woz) + : x > zero ? std::complex(woz, -y / (2 * w * abs_z)) + : std::complex(numext::abs(y) / (2 * w * abs_z), y < zero ? woz : -woz ); +} + } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/Core/arch/CUDA/Complex.h b/Eigen/src/Core/arch/CUDA/Complex.h index df5a3c2a4..6e77372b0 100644 --- a/Eigen/src/Core/arch/CUDA/Complex.h +++ b/Eigen/src/Core/arch/CUDA/Complex.h @@ -94,19 +94,6 @@ template struct scalar_quotient_op, const std: template struct scalar_quotient_op, std::complex > : scalar_quotient_op, const std::complex > {}; -// Complex sqrt is already specialized on Windows. 
-#if EIGEN_COMP_MSVC == 0 -template -struct sqrt_impl > -{ - EIGEN_DEVICE_FUNC - static EIGEN_ALWAYS_INLINE std::complex run(const std::complex& x) - { - return complex_sqrt(x); - } -}; -#endif - } // namespace internal } // namespace Eigen diff --git a/Eigen/src/Core/arch/SSE/MathFunctions.h b/Eigen/src/Core/arch/SSE/MathFunctions.h index 9f66d8ab3..8736d0d6b 100644 --- a/Eigen/src/Core/arch/SSE/MathFunctions.h +++ b/Eigen/src/Core/arch/SSE/MathFunctions.h @@ -150,7 +150,7 @@ Packet4f prsqrt(const Packet4f& _x) { template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f prsqrt(const Packet4f& x) { - // Unfortunately we can't use the much faster mm_rqsrt_ps since it only provides an approximation. + // Unfortunately we can't use the much faster mm_rsqrt_ps since it only provides an approximation. return _mm_div_ps(pset1(1.0f), _mm_sqrt_ps(x)); } @@ -158,7 +158,6 @@ Packet4f prsqrt(const Packet4f& x) { template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet2d prsqrt(const Packet2d& x) { - // Unfortunately we can't use the much faster mm_rqsrt_pd since it only provides an approximation. 
return _mm_div_pd(pset1(1.0), _mm_sqrt_pd(x)); } diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h index eee6ae194..976ecba59 100644 --- a/Eigen/src/Core/functors/UnaryFunctors.h +++ b/Eigen/src/Core/functors/UnaryFunctors.h @@ -456,7 +456,7 @@ struct functor_traits > { */ template struct scalar_rsqrt_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_rsqrt_op) - EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return Scalar(1)/numext::sqrt(a); } + EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return numext::rsqrt(a); } template EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::prsqrt(a); } }; diff --git a/test/numext.cpp b/test/numext.cpp index ff4d13ff3..cf1ca173d 100644 --- a/test/numext.cpp +++ b/test/numext.cpp @@ -9,6 +9,33 @@ #include "main.h" +template +bool check_if_equal_or_nans(const T& actual, const U& expected) { + return ((actual == expected) || ((numext::isnan)(actual) && (numext::isnan)(expected))); +} + +template +bool check_if_equal_or_nans(const std::complex& actual, const std::complex& expected) { + return check_if_equal_or_nans(numext::real(actual), numext::real(expected)) + && check_if_equal_or_nans(numext::imag(actual), numext::imag(expected)); +} + +template +bool test_is_equal_or_nans(const T& actual, const U& expected) +{ + if (check_if_equal_or_nans(actual, expected)) { + return true; + } + + // false: + std::cerr + << "\n actual = " << actual + << "\n expected = " << expected << "\n\n"; + return false; +} + +#define VERIFY_IS_EQUAL_OR_NANS(a, b) VERIFY(test_is_equal_or_nans(a, b)) + template void check_abs() { typedef typename NumTraits::Real Real; @@ -19,7 +46,7 @@ void check_abs() { VERIFY_IS_EQUAL(numext::abs(T(0)), T(0)); VERIFY_IS_EQUAL(numext::abs(T(1)), T(1)); - for(int k=0; k(); if(!internal::is_same::value) @@ -34,22 +61,199 @@ void check_abs() { } } -EIGEN_DECLARE_TEST(numext) { - CALL_SUBTEST( check_abs() ); - 
CALL_SUBTEST( check_abs() ); - CALL_SUBTEST( check_abs() ); - CALL_SUBTEST( check_abs() ); - CALL_SUBTEST( check_abs() ); - CALL_SUBTEST( check_abs() ); - CALL_SUBTEST( check_abs() ); - CALL_SUBTEST( check_abs() ); - CALL_SUBTEST( check_abs() ); - CALL_SUBTEST( check_abs() ); - CALL_SUBTEST( check_abs() ); - CALL_SUBTEST( check_abs() ); - CALL_SUBTEST( check_abs() ); - CALL_SUBTEST( check_abs() ); +template +struct check_sqrt_impl { + static void run() { + for (int i=0; i<1000; ++i) { + const T x = numext::abs(internal::random()); + const T sqrtx = numext::sqrt(x); + VERIFY_IS_APPROX(sqrtx*sqrtx, x); + } - CALL_SUBTEST( check_abs >() ); - CALL_SUBTEST( check_abs >() ); + // Corner cases. + const T zero = T(0); + const T one = T(1); + const T inf = std::numeric_limits::infinity(); + const T nan = std::numeric_limits::quiet_NaN(); + VERIFY_IS_EQUAL(numext::sqrt(zero), zero); + VERIFY_IS_EQUAL(numext::sqrt(inf), inf); + VERIFY((numext::isnan)(numext::sqrt(nan))); + VERIFY((numext::isnan)(numext::sqrt(-one))); + } +}; + +template +struct check_sqrt_impl > { + static void run() { + typedef typename std::complex ComplexT; + + for (int i=0; i<1000; ++i) { + const ComplexT x = internal::random(); + const ComplexT sqrtx = numext::sqrt(x); + VERIFY_IS_APPROX(sqrtx*sqrtx, x); + } + + // Corner cases. 
+ const T zero = T(0); + const T one = T(1); + const T inf = std::numeric_limits::infinity(); + const T nan = std::numeric_limits::quiet_NaN(); + + // Set of corner cases from https://en.cppreference.com/w/cpp/numeric/complex/sqrt + const int kNumCorners = 20; + const ComplexT corners[kNumCorners][2] = { + {ComplexT(zero, zero), ComplexT(zero, zero)}, + {ComplexT(-zero, zero), ComplexT(zero, zero)}, + {ComplexT(zero, -zero), ComplexT(zero, zero)}, + {ComplexT(-zero, -zero), ComplexT(zero, zero)}, + {ComplexT(one, inf), ComplexT(inf, inf)}, + {ComplexT(nan, inf), ComplexT(inf, inf)}, + {ComplexT(one, -inf), ComplexT(inf, -inf)}, + {ComplexT(nan, -inf), ComplexT(inf, -inf)}, + {ComplexT(-inf, one), ComplexT(zero, inf)}, + {ComplexT(inf, one), ComplexT(inf, zero)}, + {ComplexT(-inf, -one), ComplexT(zero, -inf)}, + {ComplexT(inf, -one), ComplexT(inf, -zero)}, + {ComplexT(-inf, nan), ComplexT(nan, inf)}, + {ComplexT(inf, nan), ComplexT(inf, nan)}, + {ComplexT(zero, nan), ComplexT(nan, nan)}, + {ComplexT(one, nan), ComplexT(nan, nan)}, + {ComplexT(nan, zero), ComplexT(nan, nan)}, + {ComplexT(nan, one), ComplexT(nan, nan)}, + {ComplexT(nan, -one), ComplexT(nan, nan)}, + {ComplexT(nan, nan), ComplexT(nan, nan)}, + }; + + for (int i=0; i +void check_sqrt() { + check_sqrt_impl::run(); +} + +template +struct check_rsqrt_impl { + static void run() { + const T zero = T(0); + const T one = T(1); + const T inf = std::numeric_limits::infinity(); + const T nan = std::numeric_limits::quiet_NaN(); + + for (int i=0; i<1000; ++i) { + const T x = numext::abs(internal::random()); + const T rsqrtx = numext::rsqrt(x); + const T invx = one / x; + VERIFY_IS_APPROX(rsqrtx*rsqrtx, invx); + } + + // Corner cases. 
+ VERIFY_IS_EQUAL(numext::rsqrt(zero), inf); + VERIFY_IS_EQUAL(numext::rsqrt(inf), zero); + VERIFY((numext::isnan)(numext::rsqrt(nan))); + VERIFY((numext::isnan)(numext::rsqrt(-one))); + } +}; + +template +struct check_rsqrt_impl > { + static void run() { + typedef typename std::complex ComplexT; + const T zero = T(0); + const T one = T(1); + const T inf = std::numeric_limits::infinity(); + const T nan = std::numeric_limits::quiet_NaN(); + + for (int i=0; i<1000; ++i) { + const ComplexT x = internal::random(); + const ComplexT invx = ComplexT(one, zero) / x; + const ComplexT rsqrtx = numext::rsqrt(x); + VERIFY_IS_APPROX(rsqrtx*rsqrtx, invx); + } + + // GCC and MSVC differ in their treatment of 1/(0 + 0i) + // GCC/clang = (inf, nan) + // MSVC = (nan, nan) + // and 1 / (x + inf i) + // GCC/clang = (0, 0) + // MSVC = (nan, nan) + #if (EIGEN_COMP_GNUC) + { + const int kNumCorners = 20; + const ComplexT corners[kNumCorners][2] = { + // Only consistent across GCC, clang + {ComplexT(zero, zero), ComplexT(zero, zero)}, + {ComplexT(-zero, zero), ComplexT(zero, zero)}, + {ComplexT(zero, -zero), ComplexT(zero, zero)}, + {ComplexT(-zero, -zero), ComplexT(zero, zero)}, + {ComplexT(one, inf), ComplexT(inf, inf)}, + {ComplexT(nan, inf), ComplexT(inf, inf)}, + {ComplexT(one, -inf), ComplexT(inf, -inf)}, + {ComplexT(nan, -inf), ComplexT(inf, -inf)}, + // Consistent across GCC, clang, MSVC + {ComplexT(-inf, one), ComplexT(zero, inf)}, + {ComplexT(inf, one), ComplexT(inf, zero)}, + {ComplexT(-inf, -one), ComplexT(zero, -inf)}, + {ComplexT(inf, -one), ComplexT(inf, -zero)}, + {ComplexT(-inf, nan), ComplexT(nan, inf)}, + {ComplexT(inf, nan), ComplexT(inf, nan)}, + {ComplexT(zero, nan), ComplexT(nan, nan)}, + {ComplexT(one, nan), ComplexT(nan, nan)}, + {ComplexT(nan, zero), ComplexT(nan, nan)}, + {ComplexT(nan, one), ComplexT(nan, nan)}, + {ComplexT(nan, -one), ComplexT(nan, nan)}, + {ComplexT(nan, nan), ComplexT(nan, nan)}, + }; + + for (int i=0; i +void check_rsqrt() { + 
check_rsqrt_impl::run(); +} + +EIGEN_DECLARE_TEST(numext) { + for(int k=0; k() ); + CALL_SUBTEST( check_abs() ); + CALL_SUBTEST( check_abs() ); + CALL_SUBTEST( check_abs() ); + CALL_SUBTEST( check_abs() ); + CALL_SUBTEST( check_abs() ); + CALL_SUBTEST( check_abs() ); + CALL_SUBTEST( check_abs() ); + CALL_SUBTEST( check_abs() ); + CALL_SUBTEST( check_abs() ); + CALL_SUBTEST( check_abs() ); + CALL_SUBTEST( check_abs() ); + CALL_SUBTEST( check_abs() ); + CALL_SUBTEST( check_abs() ); + + CALL_SUBTEST( check_abs >() ); + CALL_SUBTEST( check_abs >() ); + + CALL_SUBTEST( check_sqrt() ); + CALL_SUBTEST( check_sqrt() ); + CALL_SUBTEST( check_sqrt >() ); + CALL_SUBTEST( check_sqrt >() ); + + CALL_SUBTEST( check_rsqrt() ); + CALL_SUBTEST( check_rsqrt() ); + CALL_SUBTEST( check_rsqrt >() ); + CALL_SUBTEST( check_rsqrt >() ); + } } diff --git a/test/stable_norm.cpp b/test/stable_norm.cpp index ee5f91674..008e35d87 100644 --- a/test/stable_norm.cpp +++ b/test/stable_norm.cpp @@ -161,8 +161,12 @@ template void stable_norm(const MatrixType& m) // mix { - Index i2 = internal::random(0,rows-1); - Index j2 = internal::random(0,cols-1); + // Ensure unique indices otherwise inf may be overwritten by NaN. + Index i2, j2; + do { + i2 = internal::random(0,rows-1); + j2 = internal::random(0,cols-1); + } while (i2 == i && j2 == j); v = vrand; v(i,j) = -std::numeric_limits::infinity(); v(i2,j2) = std::numeric_limits::quiet_NaN(); @@ -170,7 +174,8 @@ template void stable_norm(const MatrixType& m) VERIFY(!(numext::isfinite)(v.norm())); VERIFY((numext::isnan)(v.norm())); VERIFY(!(numext::isfinite)(v.stableNorm())); VERIFY((numext::isnan)(v.stableNorm())); VERIFY(!(numext::isfinite)(v.blueNorm())); VERIFY((numext::isnan)(v.blueNorm())); - VERIFY(!(numext::isfinite)(v.hypotNorm())); VERIFY((numext::isnan)(v.hypotNorm())); + // hypot propagates inf over NaN. + VERIFY(!(numext::isfinite)(v.hypotNorm())); VERIFY((numext::isinf)(v.hypotNorm())); } // stableNormalize[d]