From c8d94ae944407c05ae7600347afb6a532783c962 Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Wed, 27 Jan 2016 09:52:29 -0800 Subject: [PATCH] digamma special function: merge shared code. Moved type-specific code into a helper class digamma_impl_maybe_poly. --- Eigen/src/Core/SpecialFunctions.h | 219 +++++++++++------------------- 1 file changed, 82 insertions(+), 137 deletions(-) diff --git a/Eigen/src/Core/SpecialFunctions.h b/Eigen/src/Core/SpecialFunctions.h index bd022946c..21583e6f5 100644 --- a/Eigen/src/Core/SpecialFunctions.h +++ b/Eigen/src/Core/SpecialFunctions.h @@ -134,8 +134,23 @@ struct lgamma_impl { * Implementation of digamma (psi) * ****************************************************************************/ +#ifdef EIGEN_HAS_C99_MATH + +/* + * + * Polynomial evaluation helper for the Psi (digamma) function. + * + * digamma_impl_maybe_poly::run(s) evaluates the asymptotic Psi expansion for + * input Scalar s, assuming s is above 10.0. + * + * If s is above a certain threshold for the given Scalar type, zero + * is returned. Otherwise the polynomial is evaluated with enough + * coefficients for results matching Scalar machine precision. + * + * + */ template -struct digamma_impl { +struct digamma_impl_maybe_poly { EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Scalar run(const Scalar) { EIGEN_STATIC_ASSERT((internal::is_same::value == false), @@ -144,72 +159,11 @@ struct digamma_impl { } }; -template -struct digamma_retval { - typedef Scalar type; -}; -#ifdef EIGEN_HAS_C99_MATH template <> -struct digamma_impl { - /* - * Psi (digamma) function (modified for Eigen) - * - * - * SYNOPSIS: - * - * float x, y, psif(); - * - * y = psif( x ); - * - * - * DESCRIPTION: - * - * d - - * psi(x) = -- ln | (x) - * dx - * - * is the logarithmic derivative of the gamma function. - * For integer x, - * n-1 - * - - * psi(n) = -EUL + > 1/k. - * - - * k=1 - * - * If x is negative, it is transformed to a positive argument by the - * reflection formula psi(1-x) = psi(x) + pi cot(pi x). - * For general positive x, the argument is made greater than 10 - * using the recurrence psi(x+1) = psi(x) + 1/x. - * Then the following asymptotic expansion is applied: - * - * inf. B - * - 2k - * psi(x) = log(x) - 1/2x - > ------- - * - 2k - * k=1 2k x - * - * where the B2k are Bernoulli numbers. - * - * ACCURACY: - * Absolute error, relative when |psi| > 1 : - * arithmetic domain # trials peak rms - * IEEE -33,0 30000 8.2e-7 1.2e-7 - * IEEE 0,33 100000 7.3e-7 7.7e-8 - * - * ERROR MESSAGES: - * message condition value returned - * psi singularity x integer <=0 INFINITY - */ +struct digamma_impl_maybe_poly { EIGEN_DEVICE_FUNC - static float run(float xx) { - float p, q, nz, x, s, w, y, z; - bool negative; - - // Some necessary constants - const float m_pif = 3.141592653589793238; - const float maxnumf = std::numeric_limits::infinity(); - + static EIGEN_STRONG_INLINE float run(const float s) { const float A[] = { -4.16666666666666666667E-3, 3.96825396825396825397E-3, @@ -217,53 +171,49 @@ struct digamma_impl { 8.33333333333333333333E-2 }; - x = xx; - nz = 0.0f; - negative = 0; - if (x <= 0.0f) { - negative = 1; - q = x; - p = ::floor(q); - if (p == q) { - return (maxnumf); - } - nz = q - p; - if (nz != 0.5f) { - if (nz > 0.5f) { - p += 1.0f; - nz = q - p; - } - nz = m_pif / ::tan(m_pif * nz); - } else { - nz = 0.0f; - } - x = 1.0f - x; - } - - /* use the recurrence psi(x+1) = psi(x) + 1/x. */ - s = x; - w = 0.0f; - while (s < 10.0f) { - w += 1.0f / s; - s += 1.0f; - } - + float z; if (s < 1.0e8f) { z = 1.0f / (s * s); - y = z * cephes::polevl::run(z, A); - } else - y = 0.0f; - - y = ::log(s) - (0.5f / s) - y - w; - - return (negative) ? y - nz : y; + return z * cephes::polevl::run(z, A); + } else return 0.0f; } }; template <> -struct digamma_impl { +struct digamma_impl_maybe_poly { EIGEN_DEVICE_FUNC - static double run(double x) { + static EIGEN_STRONG_INLINE double run(const double s) { + const double A[] = { + 8.33333333333333333333E-2, + -2.10927960927960927961E-2, + 7.57575757575757575758E-3, + -4.16666666666666666667E-3, + 3.96825396825396825397E-3, + -8.33333333333333333333E-3, + 8.33333333333333333333E-2 + }; + + double z; + if (s < 1.0e17) { + z = 1.0 / (s * s); + return z * cephes::polevl::run(z, A); + } + else return 0.0; + } +}; + +#endif // EIGEN_HAS_C99_MATH + +template +struct digamma_retval { + typedef Scalar type; +}; + +#ifdef EIGEN_HAS_C99_MATH +template +struct digamma_impl { + EIGEN_DEVICE_FUNC + static Scalar run(Scalar x) { /* * * Psi (digamma) function (modified for Eigen) @@ -304,38 +254,38 @@ struct digamma_impl { * * where the B2k are Bernoulli numbers. * - * ACCURACY: + * ACCURACY (float): * Relative error (except absolute when |psi| < 1): * arithmetic domain # trials peak rms * IEEE 0,30 30000 1.3e-15 1.4e-16 * IEEE -30,0 40000 1.5e-15 2.2e-16 * + * ACCURACY (double): + * Absolute error, relative when |psi| > 1 : + * arithmetic domain # trials peak rms + * IEEE -33,0 30000 8.2e-7 1.2e-7 + * IEEE 0,33 100000 7.3e-7 7.7e-8 + * * ERROR MESSAGES: * message condition value returned * psi singularity x integer <=0 INFINITY */ - double p, q, nz, s, w, y, z; + Scalar p, q, nz, s, w, y; bool negative; - const double A[] = { - 8.33333333333333333333E-2, - -2.10927960927960927961E-2, - 7.57575757575757575758E-3, - -4.16666666666666666667E-3, - 3.96825396825396825397E-3, - -8.33333333333333333333E-3, - 8.33333333333333333333E-2 - }; - - const double maxnum = std::numeric_limits::infinity(); - const double m_pi = 3.14159265358979323846; + const Scalar maxnum = std::numeric_limits::infinity(); + const Scalar m_pi = 3.14159265358979323846; negative = 0; nz = 0.0; - if (x <= 0.0) { - negative = 1; + const Scalar zero = 0.0; + const Scalar one = 1.0; + const Scalar half = 0.5; + + if (x <= zero) { + negative = one; q = x; p = ::floor(q); if (p == q) { @@ -345,41 +295,36 @@ struct digamma_impl { * by subtracting the nearest integer from x */ nz = q - p; - if (nz != 0.5) { - if (nz > 0.5) { - p += 1.0; + if (nz != half) { + if (nz > half) { + p += one; nz = q - p; } nz = m_pi / ::tan(m_pi * nz); } else { - nz = 0.0; + nz = zero; } - x = 1.0 - x; + x = one - x; } /* use the recurrence psi(x+1) = psi(x) + 1/x. */ s = x; - w = 0.0; - while (s < 10.0) { - w += 1.0 / s; - s += 1.0; + w = zero; + while (s < Scalar(10)) { + w += one / s; + s += one; } - if (s < 1.0e17) { - z = 1.0 / (s * s); - y = z * cephes::polevl::run(z, A); - } - else - y = 0.0; + y = digamma_impl_maybe_poly::run(s); - y = ::log(s) - (0.5 / s) - y - w; + y = ::log(s) - (half / s) - y - w; return (negative) ? y - nz : y; } }; -#endif +#endif // EIGEN_HAS_C99_MATH /**************************************************************************** * Implementation of erf *