Refactored code into type-specific helper functions.

This commit is contained in:
Till Hoffmann 2016-04-04 19:16:03 +01:00
parent 80eba21ad0
commit b97911dd18

View File

@ -744,6 +744,56 @@ struct zeta_impl {
};
#else
template <typename Scalar>
struct zeta_impl_series {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar run(const Scalar) {
EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
THIS_TYPE_IS_NOT_SUPPORTED);
return Scalar(0);
}
};
template <>
struct zeta_impl_series<float> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
static bool run(float& a, float& b, float& s, const float x, const float machep) {
int i = 0;
while(i < 9)
{
i += 1;
a += 1.0f;
b = numext::pow( a, -x );
s += b;
if( numext::abs(b/s) < machep )
return true;
}
//Return whether we are done
return false;
}
};
template <>
struct zeta_impl_series<double> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
static bool run(double& a, double& b, double& s, const double x, const double machep) {
int i = 0;
while( (i < 9) || (a <= 9.0) )
{
i += 1;
a += 1.0;
b = numext::pow( a, -x );
s += b;
if( numext::abs(b/s) < machep )
return true;
}
//Return whether we are done
return false;
}
};
template <typename Scalar>
struct zeta_impl {
@ -809,18 +859,18 @@ struct zeta_impl {
Scalar p, r, a, b, k, s, t, w;
const Scalar A[] = {
12.0,
-720.0,
30240.0,
-1209600.0,
47900160.0,
-1.8924375803183791606e9, /*1.307674368e12/691*/
7.47242496e10,
-2.950130727918164224e12, /*1.067062284288e16/3617*/
1.1646782814350067249e14, /*5.109094217170944e18/43867*/
-4.5979787224074726105e15, /*8.028576626982912e20/174611*/
1.8152105401943546773e17, /*1.5511210043330985984e23/854513*/
-7.1661652561756670113e18 /*1.6938241367317436694528e27/236364091*/
Scalar(12.0),
Scalar(-720.0),
Scalar(30240.0),
Scalar(-1209600.0),
Scalar(47900160.0),
Scalar(-1.8924375803183791606e9), /*1.307674368e12/691*/
Scalar(7.47242496e10),
Scalar(-2.950130727918164224e12), /*1.067062284288e16/3617*/
Scalar(1.1646782814350067249e14), /*5.109094217170944e18/43867*/
Scalar(-4.5979787224074726105e15), /*8.028576626982912e20/174611*/
Scalar(1.8152105401943546773e17), /*1.5511210043330985984e23/854513*/
Scalar(-7.1661652561756670113e18) /*1.6938241367317436694528e27/236364091*/
};
const Scalar maxnum = NumTraits<Scalar>::infinity();
@ -854,16 +904,10 @@ struct zeta_impl {
*/
s = numext::pow( q, -x );
a = q;
i = 0;
b = zero;
while( (i < 9) || (a <= Scalar(9)) )
{
i += 1;
a += one;
b = numext::pow( a, -x );
s += b;
if( numext::abs(b/s) < machep )
return s;
// Run the summation in a helper function that is specific to the floating precision
if (zeta_impl_series<Scalar>::run(a, b, s, x, machep)) {
return s;
}
w = a;