mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-11 19:29:02 +08:00
Merged in tillahoffmann/eigen (pull request PR-173)
Added zeta function of two arguments and polygamma function
This commit is contained in:
commit
10bdd8e378
@ -76,6 +76,8 @@ struct default_packet_traits
|
||||
HasTanh = 0,
|
||||
HasLGamma = 0,
|
||||
HasDiGamma = 0,
|
||||
HasZeta = 0,
|
||||
HasPolygamma = 0,
|
||||
HasErf = 0,
|
||||
HasErfc = 0,
|
||||
HasIGamma = 0,
|
||||
@ -450,6 +452,14 @@ Packet plgamma(const Packet& a) { using numext::lgamma; return lgamma(a); }
|
||||
/** \internal \returns the derivative of lgamma, psi(\a a) (coeff-wise) */
|
||||
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
||||
Packet pdigamma(const Packet& a) { using numext::digamma; return digamma(a); }
|
||||
|
||||
/** \internal \returns the zeta function of two arguments (coeff-wise) */
|
||||
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
||||
Packet pzeta(const Packet& x, const Packet& q) { using numext::zeta; return zeta(x, q); }
|
||||
|
||||
/** \internal \returns the polygamma function (coeff-wise) */
|
||||
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
||||
Packet ppolygamma(const Packet& n, const Packet& x) { using numext::polygamma; return polygamma(n, x); }
|
||||
|
||||
/** \internal \returns the erf(\a a) (coeff-wise) */
|
||||
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
||||
|
@ -51,6 +51,8 @@ namespace Eigen
|
||||
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(tanh,scalar_tanh_op)
|
||||
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(lgamma,scalar_lgamma_op)
|
||||
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(digamma,scalar_digamma_op)
|
||||
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(zeta,scalar_zeta_op)
|
||||
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(polygamma,scalar_polygamma_op)
|
||||
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(erf,scalar_erf_op)
|
||||
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(erfc,scalar_erfc_op)
|
||||
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(exp,scalar_exp_op)
|
||||
|
@ -722,6 +722,268 @@ struct igamma_impl {
|
||||
|
||||
#endif // EIGEN_HAS_C99_MATH
|
||||
|
||||
/****************************************************************************
|
||||
* Implementation of Riemann zeta function of two arguments *
|
||||
****************************************************************************/
|
||||
|
||||
template <typename Scalar>
|
||||
struct zeta_retval {
|
||||
typedef Scalar type;
|
||||
};
|
||||
|
||||
#ifndef EIGEN_HAS_C99_MATH
|
||||
|
||||
template <typename Scalar>
|
||||
struct zeta_impl {
|
||||
EIGEN_DEVICE_FUNC
|
||||
static Scalar run(Scalar x, Scalar q) {
|
||||
EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
|
||||
THIS_TYPE_IS_NOT_SUPPORTED);
|
||||
return Scalar(0);
|
||||
}
|
||||
};
|
||||
|
||||
#else
|
||||
|
||||
template <typename Scalar>
|
||||
struct zeta_impl_series {
|
||||
EIGEN_DEVICE_FUNC
|
||||
static EIGEN_STRONG_INLINE Scalar run(const Scalar) {
|
||||
EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
|
||||
THIS_TYPE_IS_NOT_SUPPORTED);
|
||||
return Scalar(0);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct zeta_impl_series<float> {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
static bool run(float& a, float& b, float& s, const float x, const float machep) {
|
||||
int i = 0;
|
||||
while(i < 9)
|
||||
{
|
||||
i += 1;
|
||||
a += 1.0f;
|
||||
b = numext::pow( a, -x );
|
||||
s += b;
|
||||
if( numext::abs(b/s) < machep )
|
||||
return true;
|
||||
}
|
||||
|
||||
//Return whether we are done
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct zeta_impl_series<double> {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
static bool run(double& a, double& b, double& s, const double x, const double machep) {
|
||||
int i = 0;
|
||||
while( (i < 9) || (a <= 9.0) )
|
||||
{
|
||||
i += 1;
|
||||
a += 1.0;
|
||||
b = numext::pow( a, -x );
|
||||
s += b;
|
||||
if( numext::abs(b/s) < machep )
|
||||
return true;
|
||||
}
|
||||
|
||||
//Return whether we are done
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Scalar>
|
||||
struct zeta_impl {
|
||||
EIGEN_DEVICE_FUNC
|
||||
static Scalar run(Scalar x, Scalar q) {
|
||||
/* zeta.c
|
||||
*
|
||||
* Riemann zeta function of two arguments
|
||||
*
|
||||
*
|
||||
*
|
||||
* SYNOPSIS:
|
||||
*
|
||||
* double x, q, y, zeta();
|
||||
*
|
||||
* y = zeta( x, q );
|
||||
*
|
||||
*
|
||||
*
|
||||
* DESCRIPTION:
|
||||
*
|
||||
*
|
||||
*
|
||||
* inf.
|
||||
* - -x
|
||||
* zeta(x,q) = > (k+q)
|
||||
* -
|
||||
* k=0
|
||||
*
|
||||
* where x > 1 and q is not a negative integer or zero.
|
||||
* The Euler-Maclaurin summation formula is used to obtain
|
||||
* the expansion
|
||||
*
|
||||
* n
|
||||
* - -x
|
||||
* zeta(x,q) = > (k+q)
|
||||
* -
|
||||
* k=1
|
||||
*
|
||||
* 1-x inf. B x(x+1)...(x+2j)
|
||||
* (n+q) 1 - 2j
|
||||
* + --------- - ------- + > --------------------
|
||||
* x-1 x - x+2j+1
|
||||
* 2(n+q) j=1 (2j)! (n+q)
|
||||
*
|
||||
* where the B2j are Bernoulli numbers. Note that (see zetac.c)
|
||||
* zeta(x,1) = zetac(x) + 1.
|
||||
*
|
||||
*
|
||||
*
|
||||
* ACCURACY:
|
||||
*
|
||||
* Relative error for single precision:
|
||||
* arithmetic domain # trials peak rms
|
||||
* IEEE 0,25 10000 6.9e-7 1.0e-7
|
||||
*
|
||||
* Large arguments may produce underflow in powf(), in which
|
||||
* case the results are inaccurate.
|
||||
*
|
||||
* REFERENCE:
|
||||
*
|
||||
* Gradshteyn, I. S., and I. M. Ryzhik, Tables of Integrals,
|
||||
* Series, and Products, p. 1073; Academic Press, 1980.
|
||||
*
|
||||
*/
|
||||
|
||||
int i;
|
||||
Scalar p, r, a, b, k, s, t, w;
|
||||
|
||||
const Scalar A[] = {
|
||||
Scalar(12.0),
|
||||
Scalar(-720.0),
|
||||
Scalar(30240.0),
|
||||
Scalar(-1209600.0),
|
||||
Scalar(47900160.0),
|
||||
Scalar(-1.8924375803183791606e9), /*1.307674368e12/691*/
|
||||
Scalar(7.47242496e10),
|
||||
Scalar(-2.950130727918164224e12), /*1.067062284288e16/3617*/
|
||||
Scalar(1.1646782814350067249e14), /*5.109094217170944e18/43867*/
|
||||
Scalar(-4.5979787224074726105e15), /*8.028576626982912e20/174611*/
|
||||
Scalar(1.8152105401943546773e17), /*1.5511210043330985984e23/854513*/
|
||||
Scalar(-7.1661652561756670113e18) /*1.6938241367317436694528e27/236364091*/
|
||||
};
|
||||
|
||||
const Scalar maxnum = NumTraits<Scalar>::infinity();
|
||||
const Scalar zero = 0.0, half = 0.5, one = 1.0;
|
||||
const Scalar machep = igamma_helper<Scalar>::machep();
|
||||
|
||||
if( x == one )
|
||||
return maxnum;
|
||||
|
||||
if( x < one )
|
||||
{
|
||||
return zero;
|
||||
}
|
||||
|
||||
if( q <= zero )
|
||||
{
|
||||
if(q == numext::floor(q))
|
||||
{
|
||||
return maxnum;
|
||||
}
|
||||
p = x;
|
||||
r = numext::floor(p);
|
||||
if (p != r)
|
||||
return zero;
|
||||
}
|
||||
|
||||
/* Permit negative q but continue sum until n+q > +9 .
|
||||
* This case should be handled by a reflection formula.
|
||||
* If q<0 and x is an integer, there is a relation to
|
||||
* the polygamma function.
|
||||
*/
|
||||
s = numext::pow( q, -x );
|
||||
a = q;
|
||||
b = zero;
|
||||
// Run the summation in a helper function that is specific to the floating precision
|
||||
if (zeta_impl_series<Scalar>::run(a, b, s, x, machep)) {
|
||||
return s;
|
||||
}
|
||||
|
||||
w = a;
|
||||
s += b*w/(x-one);
|
||||
s -= half * b;
|
||||
a = one;
|
||||
k = zero;
|
||||
for( i=0; i<12; i++ )
|
||||
{
|
||||
a *= x + k;
|
||||
b /= w;
|
||||
t = a*b/A[i];
|
||||
s = s + t;
|
||||
t = numext::abs(t/s);
|
||||
if( t < machep )
|
||||
return s;
|
||||
k += one;
|
||||
a *= x + k;
|
||||
b /= w;
|
||||
k += one;
|
||||
}
|
||||
return s;
|
||||
}
|
||||
};
|
||||
|
||||
#endif // EIGEN_HAS_C99_MATH
|
||||
|
||||
/****************************************************************************
|
||||
* Implementation of polygamma function *
|
||||
****************************************************************************/
|
||||
|
||||
template <typename Scalar>
|
||||
struct polygamma_retval {
|
||||
typedef Scalar type;
|
||||
};
|
||||
|
||||
#ifndef EIGEN_HAS_C99_MATH
|
||||
|
||||
template <typename Scalar>
|
||||
struct polygamma_impl {
|
||||
EIGEN_DEVICE_FUNC
|
||||
static Scalar run(Scalar n, Scalar x) {
|
||||
EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
|
||||
THIS_TYPE_IS_NOT_SUPPORTED);
|
||||
return Scalar(0);
|
||||
}
|
||||
};
|
||||
|
||||
#else
|
||||
|
||||
template <typename Scalar>
|
||||
struct polygamma_impl {
|
||||
EIGEN_DEVICE_FUNC
|
||||
static Scalar run(Scalar n, Scalar x) {
|
||||
Scalar zero = 0.0, one = 1.0;
|
||||
Scalar nplus = n + one;
|
||||
|
||||
// Just return the digamma function for n = 1
|
||||
if (n == zero) {
|
||||
return digamma_impl<Scalar>::run(x);
|
||||
}
|
||||
// Use the same implementation as scipy
|
||||
else {
|
||||
Scalar factorial = numext::exp(lgamma_impl<Scalar>::run(nplus));
|
||||
return numext::pow(-one, nplus) * factorial * zeta_impl<Scalar>::run(nplus, x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#endif // EIGEN_HAS_C99_MATH
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
namespace numext {
|
||||
@ -737,6 +999,18 @@ EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(digamma, Scalar)
|
||||
digamma(const Scalar& x) {
|
||||
return EIGEN_MATHFUNC_IMPL(digamma, Scalar)::run(x);
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(zeta, Scalar)
|
||||
zeta(const Scalar& x, const Scalar& q) {
|
||||
return EIGEN_MATHFUNC_IMPL(zeta, Scalar)::run(x, q);
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(polygamma, Scalar)
|
||||
polygamma(const Scalar& n, const Scalar& x) {
|
||||
return EIGEN_MATHFUNC_IMPL(polygamma, Scalar)::run(n, x);
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(erf, Scalar)
|
||||
|
@ -91,6 +91,34 @@ double2 pdigamma<double2>(const double2& a)
|
||||
using numext::digamma;
|
||||
return make_double2(digamma(a.x), digamma(a.y));
|
||||
}
|
||||
|
||||
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
float4 pzeta<float4>(const float4& x, const float4& q)
|
||||
{
|
||||
using numext::zeta;
|
||||
return make_float4(zeta(x.x, q.x), zeta(x.y, q.y), zeta(x.z, q.z), zeta(x.w, q.w));
|
||||
}
|
||||
|
||||
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
double2 pzeta<double2>(const double2& x, const double2& q)
|
||||
{
|
||||
using numext::zeta;
|
||||
return make_double2(zeta(x.x, q.x), zeta(x.y, q.y));
|
||||
}
|
||||
|
||||
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
float4 ppolygamma<float4>(const float4& n, const float4& x)
|
||||
{
|
||||
using numext::polygamma;
|
||||
return make_float4(polygamma(n.x, x.x), polygamma(n.y, x.y), polygamma(n.z, x.z), polygamma(n.w, x.w));
|
||||
}
|
||||
|
||||
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
double2 ppolygamma<double2>(const double2& n, const double2& x)
|
||||
{
|
||||
using numext::polygamma;
|
||||
return make_double2(polygamma(n.x, x.x), polygamma(n.y, x.y));
|
||||
}
|
||||
|
||||
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
float4 perf<float4>(const float4& a)
|
||||
|
@ -40,6 +40,8 @@ template<> struct packet_traits<float> : default_packet_traits
|
||||
HasRsqrt = 1,
|
||||
HasLGamma = 1,
|
||||
HasDiGamma = 1,
|
||||
HasZeta = 1,
|
||||
HasPolygamma = 1,
|
||||
HasErf = 1,
|
||||
HasErfc = 1,
|
||||
HasIgamma = 1,
|
||||
|
@ -448,6 +448,50 @@ struct functor_traits<scalar_digamma_op<Scalar> >
|
||||
PacketAccess = packet_traits<Scalar>::HasDiGamma
|
||||
};
|
||||
};
|
||||
|
||||
/** \internal
|
||||
* \brief Template functor to compute the Riemann Zeta function of two arguments.
|
||||
* \sa class CwiseUnaryOp, Cwise::zeta()
|
||||
*/
|
||||
template<typename Scalar> struct scalar_zeta_op {
|
||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_zeta_op)
|
||||
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& x, const Scalar& q) const {
|
||||
using numext::zeta; return zeta(x, q);
|
||||
}
|
||||
typedef typename packet_traits<Scalar>::type Packet;
|
||||
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& x, const Packet& q) const { return internal::pzeta(x, q); }
|
||||
};
|
||||
template<typename Scalar>
|
||||
struct functor_traits<scalar_zeta_op<Scalar> >
|
||||
{
|
||||
enum {
|
||||
// Guesstimate
|
||||
Cost = 10 * NumTraits<Scalar>::MulCost + 5 * NumTraits<Scalar>::AddCost,
|
||||
PacketAccess = packet_traits<Scalar>::HasZeta
|
||||
};
|
||||
};
|
||||
|
||||
/** \internal
|
||||
* \brief Template functor to compute the polygamma function.
|
||||
* \sa class CwiseUnaryOp, Cwise::polygamma()
|
||||
*/
|
||||
template<typename Scalar> struct scalar_polygamma_op {
|
||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_polygamma_op)
|
||||
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& n, const Scalar& x) const {
|
||||
using numext::polygamma; return polygamma(n, x);
|
||||
}
|
||||
typedef typename packet_traits<Scalar>::type Packet;
|
||||
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& n, const Packet& x) const { return internal::ppolygamma(n, x); }
|
||||
};
|
||||
template<typename Scalar>
|
||||
struct functor_traits<scalar_polygamma_op<Scalar> >
|
||||
{
|
||||
enum {
|
||||
// Guesstimate
|
||||
Cost = 10 * NumTraits<Scalar>::MulCost + 5 * NumTraits<Scalar>::AddCost,
|
||||
PacketAccess = packet_traits<Scalar>::HasPolygamma
|
||||
};
|
||||
};
|
||||
|
||||
/** \internal
|
||||
* \brief Template functor to compute the Gauss error function of a
|
||||
|
@ -23,6 +23,8 @@ typedef CwiseUnaryOp<internal::scalar_sinh_op<Scalar>, const Derived> SinhReturn
|
||||
typedef CwiseUnaryOp<internal::scalar_cosh_op<Scalar>, const Derived> CoshReturnType;
|
||||
typedef CwiseUnaryOp<internal::scalar_lgamma_op<Scalar>, const Derived> LgammaReturnType;
|
||||
typedef CwiseUnaryOp<internal::scalar_digamma_op<Scalar>, const Derived> DigammaReturnType;
|
||||
typedef CwiseUnaryOp<internal::scalar_zeta_op<Scalar>, const Derived> ZetaReturnType;
|
||||
typedef CwiseUnaryOp<internal::scalar_polygamma_op<Scalar>, const Derived> PolygammaReturnType;
|
||||
typedef CwiseUnaryOp<internal::scalar_erf_op<Scalar>, const Derived> ErfReturnType;
|
||||
typedef CwiseUnaryOp<internal::scalar_erfc_op<Scalar>, const Derived> ErfcReturnType;
|
||||
typedef CwiseUnaryOp<internal::scalar_pow_op<Scalar>, const Derived> PowReturnType;
|
||||
@ -329,6 +331,22 @@ digamma() const
|
||||
return DigammaReturnType(derived());
|
||||
}
|
||||
|
||||
/** \returns an expression of the coefficient-wise zeta function.
|
||||
*/
|
||||
inline const ZetaReturnType
|
||||
zeta() const
|
||||
{
|
||||
return ZetaReturnType(derived());
|
||||
}
|
||||
|
||||
/** \returns an expression of the coefficient-wise polygamma function.
|
||||
*/
|
||||
inline const PolygammaReturnType
|
||||
polygamma() const
|
||||
{
|
||||
return PolygammaReturnType(derived());
|
||||
}
|
||||
|
||||
/** \returns an expression of the coefficient-wise Gauss error
|
||||
* function of *this.
|
||||
*
|
||||
|
@ -322,6 +322,32 @@ template<typename ArrayType> void array_real(const ArrayType& m)
|
||||
std::numeric_limits<RealScalar>::infinity());
|
||||
VERIFY_IS_EQUAL(numext::digamma(Scalar(-1)),
|
||||
std::numeric_limits<RealScalar>::infinity());
|
||||
|
||||
// Check the zeta function against scipy.special.zeta
|
||||
VERIFY_IS_APPROX(numext::zeta(Scalar(1.5), Scalar(2)), RealScalar(1.61237534869));
|
||||
VERIFY_IS_APPROX(numext::zeta(Scalar(4), Scalar(1.5)), RealScalar(0.234848505667));
|
||||
VERIFY_IS_APPROX(numext::zeta(Scalar(10.5), Scalar(3)), RealScalar(1.03086757337e-5));
|
||||
VERIFY_IS_APPROX(numext::zeta(Scalar(10000.5), Scalar(1.0001)), RealScalar(0.367879440865));
|
||||
VERIFY_IS_APPROX(numext::zeta(Scalar(3), Scalar(-2.5)), RealScalar(0.054102025820864097));
|
||||
VERIFY_IS_EQUAL(numext::zeta(Scalar(1), Scalar(1.2345)), // The second scalar does not matter
|
||||
std::numeric_limits<RealScalar>::infinity());
|
||||
|
||||
// Check the polygamma against scipy.special.polygamma examples
|
||||
VERIFY_IS_APPROX(numext::polygamma(Scalar(1), Scalar(2)), RealScalar(0.644934066848));
|
||||
VERIFY_IS_APPROX(numext::polygamma(Scalar(1), Scalar(3)), RealScalar(0.394934066848));
|
||||
VERIFY_IS_APPROX(numext::polygamma(Scalar(1), Scalar(25.5)), RealScalar(0.0399946696496));
|
||||
|
||||
// Check the polygamma function over a larger range of values
|
||||
VERIFY_IS_APPROX(numext::polygamma(Scalar(17), Scalar(4.7)), RealScalar(293.334565435));
|
||||
VERIFY_IS_APPROX(numext::polygamma(Scalar(31), Scalar(11.8)), RealScalar(0.445487887616));
|
||||
VERIFY_IS_APPROX(numext::polygamma(Scalar(28), Scalar(17.7)), RealScalar(-2.47810300902e-07));
|
||||
VERIFY_IS_APPROX(numext::polygamma(Scalar(8), Scalar(30.2)), RealScalar(-8.29668781082e-09));
|
||||
/* The following tests only pass for doubles because floats cannot handle the large values of
|
||||
the gamma function.
|
||||
VERIFY_IS_APPROX(numext::polygamma(Scalar(42), Scalar(15.8)), RealScalar(-0.434562276666));
|
||||
VERIFY_IS_APPROX(numext::polygamma(Scalar(147), Scalar(54.1)), RealScalar(0.567742190178));
|
||||
VERIFY_IS_APPROX(numext::polygamma(Scalar(170), Scalar(64)), RealScalar(-0.0108615497927));
|
||||
*/
|
||||
|
||||
{
|
||||
// Test various propreties of igamma & igammac. These are normalized
|
||||
|
@ -353,6 +353,20 @@ class TensorBase<Derived, ReadOnlyAccessors>
|
||||
igammac(const OtherDerived& other) const {
|
||||
return binaryExpr(other.derived(), internal::scalar_igammac_op<Scalar>());
|
||||
}
|
||||
|
||||
// zeta(x = this, q = other)
|
||||
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
const TensorCwiseBinaryOp<internal::scalar_zeta_op<Scalar>, const Derived, const OtherDerived>
|
||||
igammac(const OtherDerived& other) const {
|
||||
return binaryExpr(other.derived(), internal::scalar_igammac_op<Scalar>());
|
||||
}
|
||||
|
||||
// polygamma(n = this, x = other)
|
||||
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
const TensorCwiseBinaryOp<internal::scalar_polygamma_op<Scalar>, const Derived, const OtherDerived>
|
||||
igammac(const OtherDerived& other) const {
|
||||
return binaryExpr(other.derived(), internal::scalar_igammac_op<Scalar>());
|
||||
}
|
||||
|
||||
// comparisons and tests for Scalars
|
||||
EIGEN_DEVICE_FUNC
|
||||
|
@ -626,6 +626,127 @@ void test_cuda_digamma()
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
void test_cuda_zeta()
|
||||
{
|
||||
Tensor<Scalar, 1> in_x(6);
|
||||
Tensor<Scalar, 1> in_q(6);
|
||||
Tensor<Scalar, 1> out(6);
|
||||
Tensor<Scalar, 1> expected_out(6);
|
||||
out.setZero();
|
||||
|
||||
in_x(0) = Scalar(1);
|
||||
in_x(1) = Scalar(1.5);
|
||||
in_x(2) = Scalar(4);
|
||||
in_x(3) = Scalar(-10.5);
|
||||
in_x(4) = Scalar(10000.5);
|
||||
in_x(5) = Scalar(3);
|
||||
|
||||
in_q(0) = Scalar(1.2345);
|
||||
in_q(1) = Scalar(2);
|
||||
in_q(2) = Scalar(1.5);
|
||||
in_q(3) = Scalar(3);
|
||||
in_q(4) = Scalar(1.0001);
|
||||
in_q(5) = Scalar(-2.5);
|
||||
|
||||
expected_out(0) = std::numeric_limits<Scalar>::infinity();
|
||||
expected_out(1) = Scalar(1.61237534869);
|
||||
expected_out(2) = Scalar(0.234848505667);
|
||||
expected_out(3) = Scalar(1.03086757337e-5);
|
||||
expected_out(4) = Scalar(0.367879440865);
|
||||
expected_out(5) = Scalar(0.054102025820864097);
|
||||
|
||||
std::size_t bytes = in_x.size() * sizeof(Scalar);
|
||||
|
||||
Scalar* d_in_x, d_in_q;
|
||||
Scalar* d_out;
|
||||
cudaMalloc((void**)(&d_in_x), bytes);
|
||||
cudaMalloc((void**)(&d_in_q), bytes);
|
||||
cudaMalloc((void**)(&d_out), bytes);
|
||||
|
||||
cudaMemcpy(d_in_x, in_x.data(), bytes, cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(d_in_q, in_q.data(), bytes, cudaMemcpyHostToDevice);
|
||||
|
||||
Eigen::CudaStreamDevice stream;
|
||||
Eigen::GpuDevice gpu_device(&stream);
|
||||
|
||||
Eigen::TensorMap<Eigen::Tensor<Scalar, 1> > gpu_in_x(d_in_x, 6);
|
||||
Eigen::TensorMap<Eigen::Tensor<Scalar, 1> > gpu_in_q(d_in_q, 6);
|
||||
Eigen::TensorMap<Eigen::Tensor<Scalar, 1> > gpu_out(d_out, 6);
|
||||
|
||||
gpu_out.device(gpu_device) = gpu_in_x.zeta(gpu_in_q);
|
||||
|
||||
assert(cudaMemcpyAsync(out.data(), d_out, bytes, cudaMemcpyDeviceToHost, gpu_device.stream()) == cudaSuccess);
|
||||
assert(cudaStreamSynchronize(gpu_device.stream()) == cudaSuccess);
|
||||
|
||||
VERIFY_IS_EQUAL(out(0), expected_out(0));
|
||||
|
||||
for (int i = 1; i < 6; ++i) {
|
||||
VERIFY_IS_APPROX(out(i), expected_out(i));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
void test_cuda_polygamma()
|
||||
{
|
||||
Tensor<Scalar, 1> in_x(7);
|
||||
Tensor<Scalar, 1> in_n(7);
|
||||
Tensor<Scalar, 1> out(7);
|
||||
Tensor<Scalar, 1> expected_out(7);
|
||||
out.setZero();
|
||||
|
||||
in_n(0) = Scalar(1);
|
||||
in_n(1) = Scalar(1);
|
||||
in_n(2) = Scalar(1);
|
||||
in_n(3) = Scalar(17);
|
||||
in_n(4) = Scalar(31);
|
||||
in_n(5) = Scalar(28);
|
||||
in_n(6) = Scalar(8);
|
||||
|
||||
in_x(0) = Scalar(2);
|
||||
in_x(1) = Scalar(3);
|
||||
in_x(2) = Scalar(25.5);
|
||||
in_x(3) = Scalar(4.7);
|
||||
in_x(4) = Scalar(11.8);
|
||||
in_x(5) = Scalar(17.7);
|
||||
in_x(6) = Scalar(30.2);
|
||||
|
||||
expected_out(0) = Scalar(0.644934066848);
|
||||
expected_out(1) = Scalar(0.394934066848);
|
||||
expected_out(2) = Scalar(0.0399946696496);
|
||||
expected_out(3) = Scalar(293.334565435);
|
||||
expected_out(4) = Scalar(0.445487887616);
|
||||
expected_out(5) = Scalar(-2.47810300902e-07);
|
||||
expected_out(6) = Scalar(-8.29668781082e-09);
|
||||
|
||||
std::size_t bytes = in_x.size() * sizeof(Scalar);
|
||||
|
||||
Scalar* d_in_x, d_in_n;
|
||||
Scalar* d_out;
|
||||
cudaMalloc((void**)(&d_in_x), bytes);
|
||||
cudaMalloc((void**)(&d_in_n), bytes);
|
||||
cudaMalloc((void**)(&d_out), bytes);
|
||||
|
||||
cudaMemcpy(d_in_x, in_x.data(), bytes, cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(d_in_n, in_n.data(), bytes, cudaMemcpyHostToDevice);
|
||||
|
||||
Eigen::CudaStreamDevice stream;
|
||||
Eigen::GpuDevice gpu_device(&stream);
|
||||
|
||||
Eigen::TensorMap<Eigen::Tensor<Scalar, 1> > gpu_in_x(d_in_x, 7);
|
||||
Eigen::TensorMap<Eigen::Tensor<Scalar, 1> > gpu_in_n(d_in_n, 7);
|
||||
Eigen::TensorMap<Eigen::Tensor<Scalar, 1> > gpu_out(d_out, 7);
|
||||
|
||||
gpu_out.device(gpu_device) = gpu_in_n.zeta(gpu_in_x);
|
||||
|
||||
assert(cudaMemcpyAsync(out.data(), d_out, bytes, cudaMemcpyDeviceToHost, gpu_device.stream()) == cudaSuccess);
|
||||
assert(cudaStreamSynchronize(gpu_device.stream()) == cudaSuccess);
|
||||
|
||||
for (int i = 0; i < 7; ++i) {
|
||||
VERIFY_IS_APPROX(out(i), expected_out(i));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
void test_cuda_igamma()
|
||||
{
|
||||
|
Loading…
x
Reference in New Issue
Block a user