Eliminate mutual recursion in igamma{,c}_impl::run.

Presently, igammac_impl::run calls igamma_impl::run, which in turn calls
igammac_impl::run.

This never actually recurses at run time; the two calls are guarded by mutually
exclusive conditions, so we never get into a loop.  Nonetheless, it's a stretch
for clang to prove this.  As a result, clang emits a recursive call in both
igammac_impl::run and igamma_impl::run.

That this is suboptimal code is bad enough, but it's particularly bad when
compiling for CUDA/nvptx.  nvptx allows recursion, but only begrudgingly: if
you have recursive calls in a kernel, it's on you to manually specify the
kernel's stack size.  Otherwise, ptxas will emit a warning and guess at the
stack size, with no guarantee that the guess is right.

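This change explicitly eliminates the mutual recursion in igammac_impl::run and
igamma_impl::run by moving each computation into a private Impl() helper, which
the other struct calls directly once run() has checked its preconditions.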
This commit is contained in:
Justin Lebar 2016-04-28 13:57:08 -07:00
parent 3ec81fc00f
commit 40d1e2f8c7

View File

@ -517,26 +517,51 @@ struct igammac_impl {
*/ */
const Scalar zero = 0; const Scalar zero = 0;
const Scalar one = 1; const Scalar one = 1;
const Scalar two = 2;
const Scalar machep = igamma_helper<Scalar>::machep();
const Scalar maxlog = numext::log(NumTraits<Scalar>::highest());
const Scalar big = igamma_helper<Scalar>::big();
const Scalar biginv = 1 / big;
const Scalar nan = NumTraits<Scalar>::quiet_NaN(); const Scalar nan = NumTraits<Scalar>::quiet_NaN();
const Scalar inf = NumTraits<Scalar>::infinity();
Scalar ans, ax, c, yc, r, t, y, z; if ((x < zero) || (a <= zero)) {
Scalar pk, pkm1, pkm2, qk, qkm1, qkm2;
if ((x < zero) || ( a <= zero)) {
// domain error // domain error
return nan; return nan;
} }
if ((x < one) || (x < a)) { if ((x < one) || (x < a)) {
return (one - igamma_impl<Scalar>::run(a, x)); /* The checks above ensure that we meet the preconditions for
* igamma_impl::Impl(), so call it, rather than igamma_impl::Run().
* Calling Run() would also work, but in that case the compiler may not be
* able to prove that igammac_impl::Run and igamma_impl::Run are not
* mutually recursive. This leads to worse code, particularly on
* platforms like nvptx, where recursion is allowed only begrudgingly.
*/
return (one - igamma_impl<Scalar>::Impl(a, x));
} }
return Impl(a, x);
}
private:
/* igamma_impl calls igammac_impl::Impl. */
friend struct igamma_impl<Scalar>;
/* Actually computes igamc(a, x).
*
* Preconditions:
* a > 0
* x >= 1
* x >= a
*/
static Scalar Impl(Scalar a, Scalar x) {
const Scalar zero = 0;
const Scalar one = 1;
const Scalar two = 2;
const Scalar machep = igamma_helper<Scalar>::machep();
const Scalar maxlog = numext::log(NumTraits<Scalar>::highest());
const Scalar big = igamma_helper<Scalar>::big();
const Scalar biginv = 1 / big;
const Scalar inf = NumTraits<Scalar>::infinity();
Scalar ans, ax, c, yc, r, t, y, z;
Scalar pk, pkm1, pkm2, qk, qkm1, qkm2;
if (x == inf) return zero; // std::isinf crashes on CUDA if (x == inf) return zero; // std::isinf crashes on CUDA
/* Compute x**a * exp(-x) / gamma(a) */ /* Compute x**a * exp(-x) / gamma(a) */
@ -678,22 +703,48 @@ struct igamma_impl {
*/ */
const Scalar zero = 0; const Scalar zero = 0;
const Scalar one = 1; const Scalar one = 1;
const Scalar nan = NumTraits<Scalar>::quiet_NaN();
if (x == zero) return zero;
if ((x < zero) || (a <= zero)) { // domain error
return nan;
}
if ((x > one) && (x > a)) {
/* The checks above ensure that we meet the preconditions for
* igammac_impl::Impl(), so call it, rather than igammac_impl::Run().
* Calling Run() would also work, but in that case the compiler may not be
* able to prove that igammac_impl::Run and igamma_impl::Run are not
* mutually recursive. This leads to worse code, particularly on
* platforms like nvptx, where recursion is allowed only begrudgingly.
*/
return (one - igammac_impl<Scalar>::Impl(a, x));
}
return Impl(a, x);
}
private:
/* igammac_impl calls igamma_impl::Impl. */
friend struct igammac_impl<Scalar>;
/* Actually computes igam(a, x).
*
* Preconditions:
* x > 0
* a > 0
* !(x > 1 && x > a)
*/
static Scalar Impl(Scalar a, Scalar x) {
const Scalar zero = 0;
const Scalar one = 1;
const Scalar machep = igamma_helper<Scalar>::machep(); const Scalar machep = igamma_helper<Scalar>::machep();
const Scalar maxlog = numext::log(NumTraits<Scalar>::highest()); const Scalar maxlog = numext::log(NumTraits<Scalar>::highest());
const Scalar nan = NumTraits<Scalar>::quiet_NaN(); const Scalar nan = NumTraits<Scalar>::quiet_NaN();
double ans, ax, c, r; double ans, ax, c, r;
if (x == zero) return zero;
if ((x < zero) || ( a <= zero)) { // domain error
return nan;
}
if ((x > one) && (x > a)) {
return (one - igammac_impl<Scalar>::run(a, x));
}
/* Compute x**a * exp(-x) / gamma(a) */ /* Compute x**a * exp(-x) / gamma(a) */
ax = a * numext::log(x) - x - lgamma_impl<Scalar>::run(a); ax = a * numext::log(x) - x - lgamma_impl<Scalar>::run(a);
if (ax < -maxlog) { if (ax < -maxlog) {