From 40d1e2f8c7a85b7c0522d11d6e3d0c6a18bc9721 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Thu, 28 Apr 2016 13:57:08 -0700 Subject: [PATCH] Eliminate mutual recursion in igamma{,c}_impl::Run. Presently, igammac_impl::Run calls igamma_impl::Run, which in turn calls igammac_impl::Run. This isn't actually mutual recursion; the calls are guarded such that we never get into a loop. Nonetheless, it's a stretch for clang to prove this. As a result, clang emits a recursive call in both igammac_impl::Run and igamma_impl::Run. That this is suboptimal code is bad enough, but it's particularly bad when compiling for CUDA/nvptx. nvptx allows recursion, but only begrudgingly: If you have recursive calls in a kernel, it's on you to manually specify the kernel's stack size. Otherwise, ptxas will dump a warning, make a guess, and who knows if it's right. This change explicitly eliminates the mutual recursion in igammac_impl::Run and igamma_impl::Run. --- Eigen/src/Core/SpecialFunctions.h | 93 ++++++++++++++++++++++++------- 1 file changed, 72 insertions(+), 21 deletions(-) diff --git a/Eigen/src/Core/SpecialFunctions.h b/Eigen/src/Core/SpecialFunctions.h index a3857ae1f..10ff4371e 100644 --- a/Eigen/src/Core/SpecialFunctions.h +++ b/Eigen/src/Core/SpecialFunctions.h @@ -517,26 +517,51 @@ struct igammac_impl { */ const Scalar zero = 0; const Scalar one = 1; - const Scalar two = 2; - const Scalar machep = igamma_helper::machep(); - const Scalar maxlog = numext::log(NumTraits::highest()); - const Scalar big = igamma_helper::big(); - const Scalar biginv = 1 / big; const Scalar nan = NumTraits::quiet_NaN(); - const Scalar inf = NumTraits::infinity(); - Scalar ans, ax, c, yc, r, t, y, z; - Scalar pk, pkm1, pkm2, qk, qkm1, qkm2; - - if ((x < zero) || ( a <= zero)) { + if ((x < zero) || (a <= zero)) { // domain error return nan; } if ((x < one) || (x < a)) { - return (one - igamma_impl::run(a, x)); + /* The checks above ensure that we meet the preconditions for + * igamma_impl::Impl(), so call it, rather than igamma_impl::Run(). + * Calling Run() would also work, but in that case the compiler may not be + * able to prove that igammac_impl::Run and igamma_impl::Run are not + * mutually recursive. This leads to worse code, particularly on + * platforms like nvptx, where recursion is allowed only begrudgingly. + */ + return (one - igamma_impl::Impl(a, x)); } + return Impl(a, x); + } + + private: + /* igamma_impl calls igammac_impl::Impl. */ + friend struct igamma_impl; + + /* Actually computes igamc(a, x). + * + * Preconditions: + * a > 0 + * x >= 1 + * x >= a + */ + static Scalar Impl(Scalar a, Scalar x) { + const Scalar zero = 0; + const Scalar one = 1; + const Scalar two = 2; + const Scalar machep = igamma_helper::machep(); + const Scalar maxlog = numext::log(NumTraits::highest()); + const Scalar big = igamma_helper::big(); + const Scalar biginv = 1 / big; + const Scalar inf = NumTraits::infinity(); + + Scalar ans, ax, c, yc, r, t, y, z; + Scalar pk, pkm1, pkm2, qk, qkm1, qkm2; + if (x == inf) return zero; // std::isinf crashes on CUDA /* Compute x**a * exp(-x) / gamma(a) */ @@ -678,22 +703,48 @@ struct igamma_impl { */ const Scalar zero = 0; const Scalar one = 1; + const Scalar nan = NumTraits::quiet_NaN(); + + if (x == zero) return zero; + + if ((x < zero) || (a <= zero)) { // domain error + return nan; + } + + if ((x > one) && (x > a)) { + /* The checks above ensure that we meet the preconditions for + * igammac_impl::Impl(), so call it, rather than igammac_impl::Run(). + * Calling Run() would also work, but in that case the compiler may not be + * able to prove that igammac_impl::Run and igamma_impl::Run are not + * mutually recursive. This leads to worse code, particularly on + * platforms like nvptx, where recursion is allowed only begrudgingly. + */ + return (one - igammac_impl::Impl(a, x)); + } + + return Impl(a, x); + } + + private: + /* igammac_impl calls igamma_impl::Impl. */ + friend struct igammac_impl; + + /* Actually computes igam(a, x). + * + * Preconditions: + * x > 0 + * a > 0 + * !(x > 1 && x > a) + */ + static Scalar Impl(Scalar a, Scalar x) { + const Scalar zero = 0; + const Scalar one = 1; const Scalar machep = igamma_helper::machep(); const Scalar maxlog = numext::log(NumTraits::highest()); const Scalar nan = NumTraits::quiet_NaN(); double ans, ax, c, r; - if (x == zero) return zero; - - if ((x < zero) || ( a <= zero)) { // domain error - return nan; - } - - if ((x > one) && (x > a)) { - return (one - igammac_impl::run(a, x)); - } - /* Compute x**a * exp(-x) / gamma(a) */ ax = a * numext::log(x) - x - lgamma_impl::run(a); if (ax < -maxlog) {