From 26b8fabd80e882235682f58331ef232bf78b9f10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20S=C3=A1nchez?= Date: Fri, 5 May 2023 16:27:26 +0000 Subject: [PATCH] Return NaN in ndtri for values outside valid input range. (cherry picked from commit 1f79a6078fb77da47069c8aec23c4e309fb982e2) --- .../Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h | 10 +++++----- unsupported/test/special_functions.cpp | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h index 7634bf72f..243ffdd5e 100644 --- a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h +++ b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h @@ -471,9 +471,9 @@ struct erfc_impl { * ERROR MESSAGES: * * message condition value returned - * ndtri domain x <= 0 -MAXNUM - * ndtri domain x >= 1 MAXNUM - * + * ndtri domain x == 0 -INF + * ndtri domain x == 1 INF + * ndtri domain x < 0, x > 1 NAN */ /* Cephes Math Library Release 2.2: June, 1992 @@ -635,8 +635,8 @@ T generic_ndtri(const T& a) { generic_ndtri_lt_exp_neg_two(b, should_flipsign)); return pselect( - pcmp_le(a, zero), neg_maxnum, - pselect(pcmp_le(one, a), maxnum, ndtri)); + pcmp_eq(a, zero), neg_maxnum, + pselect(pcmp_eq(one, a), maxnum, ndtri)); } template diff --git a/unsupported/test/special_functions.cpp b/unsupported/test/special_functions.cpp index 756f031c2..44c77535e 100644 --- a/unsupported/test/special_functions.cpp +++ b/unsupported/test/special_functions.cpp @@ -171,9 +171,9 @@ template void array_special_functions() // Check the ndtri function against scipy.special.ndtri { - ArrayType x(7), res(7), ref(7); - x << 0.5, 0.2, 0.8, 0.9, 0.1, 0.99, 0.01; - ref << 0., -0.8416212335729142, 0.8416212335729142, 1.2815515655446004, -1.2815515655446004, 2.3263478740408408, -2.3263478740408408; + ArrayType x(11), res(11), ref(11); + x << 0.5, 0.2, 0.8, 0.9, 0.1, 0.99, 0.01, 0, 1, -0.01, 1.01; + ref << 0., -0.8416212335729142, 0.8416212335729142, 1.2815515655446004, -1.2815515655446004, 2.3263478740408408, -2.3263478740408408, -plusinf, plusinf, nan, nan; CALL_SUBTEST( verify_component_wise(ref, ref); ); CALL_SUBTEST( res = x.ndtri(); verify_component_wise(res, ref); ); CALL_SUBTEST( res = ndtri(x); verify_component_wise(res, ref); );