From 9296bb4b933973365d19b4b71e7d2b205d00a1ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20S=C3=A1nchez?= Date: Tue, 8 Mar 2022 21:21:20 +0000 Subject: [PATCH] Fix edge-case in zeta for large inputs. --- .../Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h | 10 +++++++++- unsupported/test/special_functions.cpp | 8 ++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h index c1609f1bd..0addd09a3 100644 --- a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h +++ b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h @@ -1388,7 +1388,7 @@ struct zeta_impl { }; const Scalar maxnum = NumTraits::infinity(); - const Scalar zero = 0.0, half = 0.5, one = 1.0; + const Scalar zero = Scalar(0.0), half = Scalar(0.5), one = Scalar(1.0); const Scalar machep = cephes_helper::machep(); const Scalar nan = NumTraits::quiet_NaN(); @@ -1430,11 +1430,19 @@ struct zeta_impl { return s; } + // If b is zero, then the tail sum will also end up being zero. + // Exiting early here can prevent NaNs for some large inputs, where + // the tail sum computed below has term `a` which can overflow to `inf`. + if (numext::equal_strict(b, zero)) { + return s; + } + w = a; s += b*w/(x-one); s -= half * b; a = one; k = zero; + for( i=0; i<12; i++ ) { a *= x + k; diff --git a/unsupported/test/special_functions.cpp b/unsupported/test/special_functions.cpp index 589bb76e1..756f031c2 100644 --- a/unsupported/test/special_functions.cpp +++ b/unsupported/test/special_functions.cpp @@ -191,10 +191,10 @@ template void array_special_functions() // Check the zeta function against scipy.special.zeta { - ArrayType x(10), q(10), res(10), ref(10); - x << 1.5, 4, 10.5, 10000.5, 3, 1, 0.9, 2, 3, 4; - q << 2, 1.5, 3, 1.0001, -2.5, 1.2345, 1.2345, -1, -2, -3; - ref << 1.61237534869, 0.234848505667, 1.03086757337e-5, 0.367879440865, 0.054102025820864097, plusinf, nan, plusinf, nan, plusinf; + ArrayType x(11), q(11), res(11), ref(11); + x << 1.5, 4, 10.5, 10000.5, 3, 1, 0.9, 2, 3, 4, 2000; + q << 2, 1.5, 3, 1.0001, -2.5, 1.2345, 1.2345, -1, -2, -3, 2000; + ref << 1.61237534869, 0.234848505667, 1.03086757337e-5, 0.367879440865, 0.054102025820864097, plusinf, nan, plusinf, nan, plusinf, 0; CALL_SUBTEST( verify_component_wise(ref, ref); ); CALL_SUBTEST( res = x.zeta(q); verify_component_wise(res, ref); ); CALL_SUBTEST( res = zeta(x,q); verify_component_wise(res, ref); );