eigen/test/numext.cpp
Antonio Sanchez 42acbd5700 Fix numext::arg return type.
The cxx11 path for `numext::arg` incorrectly returned the complex type
instead of the real type, leading to compile errors. Fixed this and
added tests.

Related to !477, which uncovered the issue.


(cherry picked from commit 90e9a33e1ce3e4e7663dd67e6c1f225afaf5c206)
2021-05-07 17:52:07 +00:00

276 lines
8.8 KiB
C++

// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#include "main.h"
// Returns true when `actual` compares equal to `expected`, or when both
// values are NaN, so that a NaN result can match a NaN expectation.
template<typename T, typename U>
bool check_if_equal_or_nans(const T& actual, const U& expected) {
  if (actual == expected) {
    return true;
  }
  // Not equal: the only remaining way to match is for both sides to be NaN.
  return (numext::isnan)(actual) && (numext::isnan)(expected);
}
// Complex overload: the real parts and the imaginary parts are compared
// independently, each with the equal-or-both-NaN rule of the scalar version.
template<typename T, typename U>
bool check_if_equal_or_nans(const std::complex<T>& actual, const std::complex<U>& expected) {
  if (!check_if_equal_or_nans(numext::real(actual), numext::real(expected))) {
    return false;
  }
  return check_if_equal_or_nans(numext::imag(actual), numext::imag(expected));
}
// Verification helper: reports whether `actual` matches `expected` under
// check_if_equal_or_nans and, on a mismatch, prints both values to std::cerr.
template<typename T, typename U>
bool test_is_equal_or_nans(const T& actual, const U& expected)
{
  const bool matched = check_if_equal_or_nans(actual, expected);
  if (!matched) {
    // Mismatch: show what was computed and what was wanted.
    std::cerr
        << "\n actual = " << actual
        << "\n expected = " << expected << "\n\n";
  }
  return matched;
}

// Passes the result of test_is_equal_or_nans(a, b) to VERIFY.
#define VERIFY_IS_EQUAL_OR_NANS(a, b) VERIFY(test_is_equal_or_nans(a, b))
// Exercises numext::abs (and numext::abs2) for scalar type T: first on the
// fixed values 0 and 1 (plus -1 for signed types), then on 100 random samples.
template<typename T>
void check_abs() {
  typedef typename NumTraits<T>::Real Real;
  Real zero(0);

  // Fixed values. Negation is only checked when T is signed.
  if (NumTraits<T>::IsSigned) {
    VERIFY_IS_EQUAL(numext::abs(-T(1)), T(1));
  }
  VERIFY_IS_EQUAL(numext::abs(T(0)), T(0));
  VERIFY_IS_EQUAL(numext::abs(T(1)), T(1));

  // Random samples.
  for (int sample = 0; sample < 100; ++sample) {
    T x = internal::random<T>();
    // Halve the sample for every type except bool.
    // NOTE(review): presumably this keeps -x representable for signed
    // integer types -- confirm.
    if (!internal::is_same<T,bool>::value) {
      x = x/Real(2);
    }

    if (NumTraits<T>::IsSigned) {
      // abs must be symmetric under negation and non-negative.
      VERIFY_IS_EQUAL(numext::abs(x), numext::abs(-x));
      VERIFY( numext::abs(-x) >= zero );
    }
    VERIFY( numext::abs(x) >= zero );
    // abs2 must give the same result for x and for abs(x).
    VERIFY_IS_APPROX( numext::abs2(x), numext::abs2(numext::abs(x)) );
  }
}
template<typename T>
void check_arg() {
typedef typename NumTraits<T>::Real Real;
VERIFY_IS_EQUAL(numext::abs(T(0)), T(0));
VERIFY_IS_EQUAL(numext::abs(T(1)), T(1));
for(int k=0; k<100; ++k)
{
T x = internal::random<T>();
Real y = numext::arg(x);
VERIFY_IS_APPROX( y, std::arg(x) );
}
}
// numext::sqrt checks for a real scalar type T.
template<typename T>
struct check_sqrt_impl {
  // Entry point: random round-trip checks, then the special values.
  static void run() {
    check_random_values();
    check_corner_cases();
  }

  // For 1000 random non-negative x, sqrt(x) * sqrt(x) must approximate x.
  static void check_random_values() {
    for (int trial = 0; trial < 1000; ++trial) {
      const T x = numext::abs(internal::random<T>());
      const T sqrtx = numext::sqrt(x);
      VERIFY_IS_APPROX(sqrtx*sqrtx, x);
    }
  }

  // Corner cases: sqrt of 0, +inf, NaN and -1.
  static void check_corner_cases() {
    const T zero = T(0);
    const T one = T(1);
    const T inf = std::numeric_limits<T>::infinity();
    const T nan = std::numeric_limits<T>::quiet_NaN();

    VERIFY_IS_EQUAL(numext::sqrt(zero), zero);
    VERIFY_IS_EQUAL(numext::sqrt(inf), inf);
    // NaN input and negative input must both produce NaN.
    VERIFY((numext::isnan)(numext::sqrt(nan)));
    VERIFY((numext::isnan)(numext::sqrt(-one)));
  }
};
// numext::sqrt checks for std::complex<T>.
template<typename T>
struct check_sqrt_impl<std::complex<T> > {
  typedef std::complex<T> ComplexT;

  // Entry point: random round-trip checks, then the special values.
  static void run() {
    check_random_values();
    check_corner_cases();
  }

  // For 1000 random complex x, sqrt(x) * sqrt(x) must approximate x.
  static void check_random_values() {
    for (int trial = 0; trial < 1000; ++trial) {
      const ComplexT x = internal::random<ComplexT>();
      const ComplexT sqrtx = numext::sqrt(x);
      VERIFY_IS_APPROX(sqrtx*sqrtx, x);
    }
  }

  // Corner cases: each table row is {input, expected sqrt(input)}.
  static void check_corner_cases() {
    const T zero = T(0);
    const T one = T(1);
    const T inf = std::numeric_limits<T>::infinity();
    const T nan = std::numeric_limits<T>::quiet_NaN();

    // Set of corner cases from https://en.cppreference.com/w/cpp/numeric/complex/sqrt
    const int kNumCorners = 20;
    const ComplexT corners[kNumCorners][2] = {
      {ComplexT(zero, zero), ComplexT(zero, zero)},
      {ComplexT(-zero, zero), ComplexT(zero, zero)},
      {ComplexT(zero, -zero), ComplexT(zero, zero)},
      {ComplexT(-zero, -zero), ComplexT(zero, zero)},
      {ComplexT(one, inf), ComplexT(inf, inf)},
      {ComplexT(nan, inf), ComplexT(inf, inf)},
      {ComplexT(one, -inf), ComplexT(inf, -inf)},
      {ComplexT(nan, -inf), ComplexT(inf, -inf)},
      {ComplexT(-inf, one), ComplexT(zero, inf)},
      {ComplexT(inf, one), ComplexT(inf, zero)},
      {ComplexT(-inf, -one), ComplexT(zero, -inf)},
      {ComplexT(inf, -one), ComplexT(inf, -zero)},
      {ComplexT(-inf, nan), ComplexT(nan, inf)},
      {ComplexT(inf, nan), ComplexT(inf, nan)},
      {ComplexT(zero, nan), ComplexT(nan, nan)},
      {ComplexT(one, nan), ComplexT(nan, nan)},
      {ComplexT(nan, zero), ComplexT(nan, nan)},
      {ComplexT(nan, one), ComplexT(nan, nan)},
      {ComplexT(nan, -one), ComplexT(nan, nan)},
      {ComplexT(nan, nan), ComplexT(nan, nan)},
    };

    // NaN components are accepted wherever the expectation is NaN.
    for (int row = 0; row < kNumCorners; ++row) {
      const ComplexT& x = corners[row][0];
      const ComplexT sqrtx = corners[row][1];
      VERIFY_IS_EQUAL_OR_NANS(numext::sqrt(x), sqrtx);
    }
  }
};
template<typename T>
void check_sqrt() {
check_sqrt_impl<T>::run();
}
// numext::rsqrt (reciprocal square root) checks for a real scalar type T.
template<typename T>
struct check_rsqrt_impl {
  // Entry point: random checks, then the special values.
  static void run() {
    check_random_values();
    check_corner_cases();
  }

  // For 1000 random non-negative x, rsqrt(x) * rsqrt(x) must approximate 1/x.
  static void check_random_values() {
    const T one = T(1);
    for (int trial = 0; trial < 1000; ++trial) {
      const T x = numext::abs(internal::random<T>());
      const T rsqrtx = numext::rsqrt(x);
      const T invx = one / x;
      VERIFY_IS_APPROX(rsqrtx*rsqrtx, invx);
    }
  }

  // Corner cases: rsqrt of 0, +inf, NaN and -1.
  static void check_corner_cases() {
    const T zero = T(0);
    const T one = T(1);
    const T inf = std::numeric_limits<T>::infinity();
    const T nan = std::numeric_limits<T>::quiet_NaN();

    VERIFY_IS_EQUAL(numext::rsqrt(zero), inf);
    VERIFY_IS_EQUAL(numext::rsqrt(inf), zero);
    // NaN input and negative input must both produce NaN.
    VERIFY((numext::isnan)(numext::rsqrt(nan)));
    VERIFY((numext::isnan)(numext::rsqrt(-one)));
  }
};
// numext::rsqrt (reciprocal square root) checks for std::complex<T>.
template<typename T>
struct check_rsqrt_impl<std::complex<T> > {
  static void run() {
    typedef std::complex<T> ComplexT;

    const T zero = T(0);
    const T one = T(1);
    const T inf = std::numeric_limits<T>::infinity();
    const T nan = std::numeric_limits<T>::quiet_NaN();

    // For 1000 random complex x, rsqrt(x) * rsqrt(x) must approximate 1/x.
    for (int trial = 0; trial < 1000; ++trial) {
      const ComplexT x = internal::random<ComplexT>();
      const ComplexT invx = ComplexT(one, zero) / x;
      const ComplexT rsqrtx = numext::rsqrt(x);
      VERIFY_IS_APPROX(rsqrtx*rsqrtx, invx);
    }

    // GCC and MSVC differ in their treatment of 1/(0 + 0i)
    //   GCC/clang = (inf, nan)
    //   MSVC = (nan, nan)
    // and 1 / (x + inf i)
    //   GCC/clang = (0, 0)
    //   MSVC = (nan, nan)
    // The corner cases are therefore only compiled when EIGEN_COMP_GNUC is set.
#if (EIGEN_COMP_GNUC)
    {
      // Each row is {input x, sqrt(x)}; the expected rsqrt(x) is obtained
      // below as 1 / sqrt(x).
      const int kNumCorners = 20;
      const ComplexT corners[kNumCorners][2] = {
        // Only consistent across GCC, clang
        {ComplexT(zero, zero), ComplexT(zero, zero)},
        {ComplexT(-zero, zero), ComplexT(zero, zero)},
        {ComplexT(zero, -zero), ComplexT(zero, zero)},
        {ComplexT(-zero, -zero), ComplexT(zero, zero)},
        {ComplexT(one, inf), ComplexT(inf, inf)},
        {ComplexT(nan, inf), ComplexT(inf, inf)},
        {ComplexT(one, -inf), ComplexT(inf, -inf)},
        {ComplexT(nan, -inf), ComplexT(inf, -inf)},
        // Consistent across GCC, clang, MSVC
        {ComplexT(-inf, one), ComplexT(zero, inf)},
        {ComplexT(inf, one), ComplexT(inf, zero)},
        {ComplexT(-inf, -one), ComplexT(zero, -inf)},
        {ComplexT(inf, -one), ComplexT(inf, -zero)},
        {ComplexT(-inf, nan), ComplexT(nan, inf)},
        {ComplexT(inf, nan), ComplexT(inf, nan)},
        {ComplexT(zero, nan), ComplexT(nan, nan)},
        {ComplexT(one, nan), ComplexT(nan, nan)},
        {ComplexT(nan, zero), ComplexT(nan, nan)},
        {ComplexT(nan, one), ComplexT(nan, nan)},
        {ComplexT(nan, -one), ComplexT(nan, nan)},
        {ComplexT(nan, nan), ComplexT(nan, nan)},
      };

      for (int row = 0; row < kNumCorners; ++row) {
        const ComplexT& x = corners[row][0];
        const ComplexT rsqrtx = ComplexT(one, zero) / corners[row][1];
        VERIFY_IS_EQUAL_OR_NANS(numext::rsqrt(x), rsqrtx);
      }
    }
#endif
  }
};
template<typename T>
void check_rsqrt() {
check_rsqrt_impl<T>::run();
}
// Test entry point: runs every numext check g_repeat times.
EIGEN_DECLARE_TEST(numext) {
  for (int repeat = 0; repeat < g_repeat; ++repeat) {
    // numext::abs -- bool and integer types.
    CALL_SUBTEST( check_abs<bool>() );
    CALL_SUBTEST( check_abs<signed char>() );
    CALL_SUBTEST( check_abs<unsigned char>() );
    CALL_SUBTEST( check_abs<short>() );
    CALL_SUBTEST( check_abs<unsigned short>() );
    CALL_SUBTEST( check_abs<int>() );
    CALL_SUBTEST( check_abs<unsigned int>() );
    CALL_SUBTEST( check_abs<long>() );
    CALL_SUBTEST( check_abs<unsigned long>() );

    // numext::abs -- floating-point and complex types.
    CALL_SUBTEST( check_abs<half>() );
    CALL_SUBTEST( check_abs<bfloat16>() );
    CALL_SUBTEST( check_abs<float>() );
    CALL_SUBTEST( check_abs<double>() );
    CALL_SUBTEST( check_abs<long double>() );
    CALL_SUBTEST( check_abs<std::complex<float> >() );
    CALL_SUBTEST( check_abs<std::complex<double> >() );

    // numext::arg -- complex types.
    CALL_SUBTEST( check_arg<std::complex<float> >() );
    CALL_SUBTEST( check_arg<std::complex<double> >() );

    // numext::sqrt -- real and complex types.
    CALL_SUBTEST( check_sqrt<float>() );
    CALL_SUBTEST( check_sqrt<double>() );
    CALL_SUBTEST( check_sqrt<std::complex<float> >() );
    CALL_SUBTEST( check_sqrt<std::complex<double> >() );

    // numext::rsqrt -- real and complex types.
    CALL_SUBTEST( check_rsqrt<float>() );
    CALL_SUBTEST( check_rsqrt<double>() );
    CALL_SUBTEST( check_rsqrt<std::complex<float> >() );
    CALL_SUBTEST( check_rsqrt<std::complex<double> >() );
  }
}