From 42acbd570028c5dee7e6dbfcfe0ea614f09d9d75 Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Fri, 7 May 2021 08:24:32 -0700 Subject: [PATCH] Fix numext::arg return type. The cxx11 path for `numext::arg` incorrectly returned the complex type instead of the real type, leading to compile errors. Fixed this and added tests. Related to !477, which uncovered the issue. (cherry picked from commit 90e9a33e1ce3e4e7663dd67e6c1f225afaf5c206) --- Eigen/src/Core/MathFunctions.h | 9 +++++---- test/numext.cpp | 18 +++++++++++++++++- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h index 29201214f..67b1d8263 100644 --- a/Eigen/src/Core/MathFunctions.h +++ b/Eigen/src/Core/MathFunctions.h @@ -592,8 +592,9 @@ struct arg_default_impl; template struct arg_default_impl { + typedef typename NumTraits::Real RealScalar; EIGEN_DEVICE_FUNC - static inline Scalar run(const Scalar& x) + static inline RealScalar run(const Scalar& x) { #if defined(EIGEN_HIP_DEVICE_COMPILE) // HIP does not seem to have a native device side implementation for the math routine "arg" @@ -601,7 +602,7 @@ struct arg_default_impl { #else EIGEN_USING_STD(arg); #endif - return static_cast(arg(x)); + return static_cast(arg(x)); } }; @@ -612,7 +613,7 @@ struct arg_default_impl { EIGEN_DEVICE_FUNC static inline RealScalar run(const Scalar& x) { - return (x < Scalar(0)) ? Scalar(EIGEN_PI) : Scalar(0); + return (x < Scalar(0)) ? RealScalar(EIGEN_PI) : RealScalar(0); } }; #else @@ -623,7 +624,7 @@ struct arg_default_impl EIGEN_DEVICE_FUNC static inline RealScalar run(const Scalar& x) { - return (x < Scalar(0)) ? Scalar(EIGEN_PI) : Scalar(0); + return (x < RealScalar(0)) ? RealScalar(EIGEN_PI) : RealScalar(0); } }; diff --git a/test/numext.cpp b/test/numext.cpp index cf1ca173d..8a2fde501 100644 --- a/test/numext.cpp +++ b/test/numext.cpp @@ -61,6 +61,20 @@ void check_abs() { } } +template +void check_arg() { + typedef typename NumTraits::Real Real; + VERIFY_IS_EQUAL(numext::abs(T(0)), T(0)); + VERIFY_IS_EQUAL(numext::abs(T(1)), T(1)); + + for(int k=0; k<100; ++k) + { + T x = internal::random(); + Real y = numext::arg(x); + VERIFY_IS_APPROX( y, std::arg(x) ); + } +} + template struct check_sqrt_impl { static void run() { @@ -242,10 +256,12 @@ EIGEN_DECLARE_TEST(numext) { CALL_SUBTEST( check_abs() ); CALL_SUBTEST( check_abs() ); CALL_SUBTEST( check_abs() ); - CALL_SUBTEST( check_abs >() ); CALL_SUBTEST( check_abs >() ); + CALL_SUBTEST( check_arg >() ); + CALL_SUBTEST( check_arg >() ); + CALL_SUBTEST( check_sqrt() ); CALL_SUBTEST( check_sqrt() ); CALL_SUBTEST( check_sqrt >() );