mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-13 20:26:03 +08:00
Added polygamma function.
This commit is contained in:
parent
dd5d390daf
commit
57239f4a81
@ -77,6 +77,7 @@ struct default_packet_traits
|
|||||||
HasLGamma = 0,
|
HasLGamma = 0,
|
||||||
HasDiGamma = 0,
|
HasDiGamma = 0,
|
||||||
HasZeta = 0,
|
HasZeta = 0,
|
||||||
|
HasPolygamma = 0,
|
||||||
HasErf = 0,
|
HasErf = 0,
|
||||||
HasErfc = 0,
|
HasErfc = 0,
|
||||||
HasIGamma = 0,
|
HasIGamma = 0,
|
||||||
@ -456,6 +457,10 @@ Packet pdigamma(const Packet& a) { using numext::digamma; return digamma(a); }
|
|||||||
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
||||||
Packet pzeta(const Packet& x, const Packet& q) { using numext::zeta; return zeta(x, q); }
|
Packet pzeta(const Packet& x, const Packet& q) { using numext::zeta; return zeta(x, q); }
|
||||||
|
|
||||||
|
/** \internal \returns the polygamma function (coeff-wise) */
|
||||||
|
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
||||||
|
Packet ppolygamma(const Packet& n, const Packet& x) { using numext::polygamma; return polygamma(n, x); }
|
||||||
|
|
||||||
/** \internal \returns the erf(\a a) (coeff-wise) */
|
/** \internal \returns the erf(\a a) (coeff-wise) */
|
||||||
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
||||||
Packet perf(const Packet& a) { using numext::erf; return erf(a); }
|
Packet perf(const Packet& a) { using numext::erf; return erf(a); }
|
||||||
|
@ -52,6 +52,7 @@ namespace Eigen
|
|||||||
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(lgamma,scalar_lgamma_op)
|
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(lgamma,scalar_lgamma_op)
|
||||||
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(digamma,scalar_digamma_op)
|
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(digamma,scalar_digamma_op)
|
||||||
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(zeta,scalar_zeta_op)
|
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(zeta,scalar_zeta_op)
|
||||||
|
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(polygamma,scalar_polygamma_op)
|
||||||
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(erf,scalar_erf_op)
|
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(erf,scalar_erf_op)
|
||||||
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(erfc,scalar_erfc_op)
|
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(erfc,scalar_erfc_op)
|
||||||
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(exp,scalar_exp_op)
|
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(exp,scalar_exp_op)
|
||||||
|
@ -736,7 +736,7 @@ struct zeta_retval {
|
|||||||
template <typename Scalar>
|
template <typename Scalar>
|
||||||
struct zeta_impl {
|
struct zeta_impl {
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
static Scalar run(Scalar x) {
|
static Scalar run(Scalar x, Scalar q) {
|
||||||
EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
|
EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
|
||||||
THIS_TYPE_IS_NOT_SUPPORTED);
|
THIS_TYPE_IS_NOT_SUPPORTED);
|
||||||
return Scalar(0);
|
return Scalar(0);
|
||||||
@ -905,6 +905,50 @@ struct zeta_impl {
|
|||||||
|
|
||||||
#endif // EIGEN_HAS_C99_MATH
|
#endif // EIGEN_HAS_C99_MATH
|
||||||
|
|
||||||
|
/****************************************************************************
|
||||||
|
* Implementation of polygamma function *
|
||||||
|
****************************************************************************/
|
||||||
|
|
||||||
|
template <typename Scalar>
|
||||||
|
struct polygamma_retval {
|
||||||
|
typedef Scalar type;
|
||||||
|
};
|
||||||
|
|
||||||
|
#ifndef EIGEN_HAS_C99_MATH
|
||||||
|
|
||||||
|
template <typename Scalar>
|
||||||
|
struct polygamma_impl {
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
static Scalar run(Scalar n, Scalar x) {
|
||||||
|
EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
|
||||||
|
THIS_TYPE_IS_NOT_SUPPORTED);
|
||||||
|
return Scalar(0);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
template <typename Scalar>
|
||||||
|
struct polygamma_impl {
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
static Scalar run(Scalar n, Scalar x) {
|
||||||
|
Scalar zero = 0.0, one = 1.0;
|
||||||
|
Scalar nplus = n + one;
|
||||||
|
|
||||||
|
// Just return the digamma function for n = 1
|
||||||
|
if (n == zero) {
|
||||||
|
return digamma_impl<Scalar>::run(x);
|
||||||
|
}
|
||||||
|
// Use the same implementation as scipy
|
||||||
|
else {
|
||||||
|
Scalar factorial = numext::exp(lgamma_impl<Scalar>::run(nplus));
|
||||||
|
return numext::pow(-one, nplus) * factorial * zeta_impl<Scalar>::run(nplus, x);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // EIGEN_HAS_C99_MATH
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
namespace numext {
|
namespace numext {
|
||||||
@ -927,6 +971,12 @@ zeta(const Scalar& x, const Scalar& q) {
|
|||||||
return EIGEN_MATHFUNC_IMPL(zeta, Scalar)::run(x, q);
|
return EIGEN_MATHFUNC_IMPL(zeta, Scalar)::run(x, q);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename Scalar>
|
||||||
|
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(polygamma, Scalar)
|
||||||
|
polygamma(const Scalar& n, const Scalar& x) {
|
||||||
|
return EIGEN_MATHFUNC_IMPL(polygamma, Scalar)::run(n, x);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename Scalar>
|
template <typename Scalar>
|
||||||
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(erf, Scalar)
|
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(erf, Scalar)
|
||||||
erf(const Scalar& x) {
|
erf(const Scalar& x) {
|
||||||
|
@ -93,17 +93,31 @@ double2 pdigamma<double2>(const double2& a)
|
|||||||
}
|
}
|
||||||
|
|
||||||
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
float4 pzeta<float4>(const float4& a)
|
float4 pzeta<float4>(const float4& x, const float4& q)
|
||||||
{
|
{
|
||||||
using numext::zeta;
|
using numext::zeta;
|
||||||
return make_float4(zeta(a.x), zeta(a.y), zeta(a.z), zeta(a.w));
|
return make_float4(zeta(x.x, q.x), zeta(x.y, q.y), zeta(x.z, q.z), zeta(x.w, q.w));
|
||||||
}
|
}
|
||||||
|
|
||||||
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
double2 pzeta<double2>(const double2& a)
|
double2 pzeta<double2>(const double2& x, const double2& q)
|
||||||
{
|
{
|
||||||
using numext::zeta;
|
using numext::zeta;
|
||||||
return make_double2(zeta(a.x), zeta(a.y));
|
return make_double2(zeta(x.x, q.x), zeta(x.y, q.y));
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
|
float4 ppolygamma<float4>(const float4& n, const float4& x)
|
||||||
|
{
|
||||||
|
using numext::polygamma;
|
||||||
|
return make_float4(polygamma(n.x, x.x), polygamma(n.y, x.y), polygamma(n.z, x.z), polygamma(n.w, x.w));
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
|
double2 ppolygamma<double2>(const double2& n, const double2& x)
|
||||||
|
{
|
||||||
|
using numext::polygamma;
|
||||||
|
return make_double2(polygamma(n.x, x.x), polygamma(n.y, x.y));
|
||||||
}
|
}
|
||||||
|
|
||||||
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
|
@ -41,6 +41,7 @@ template<> struct packet_traits<float> : default_packet_traits
|
|||||||
HasLGamma = 1,
|
HasLGamma = 1,
|
||||||
HasDiGamma = 1,
|
HasDiGamma = 1,
|
||||||
HasZeta = 1,
|
HasZeta = 1,
|
||||||
|
HasPolygamma = 1,
|
||||||
HasErf = 1,
|
HasErf = 1,
|
||||||
HasErfc = 1,
|
HasErfc = 1,
|
||||||
HasIgamma = 1,
|
HasIgamma = 1,
|
||||||
|
@ -471,6 +471,28 @@ struct functor_traits<scalar_zeta_op<Scalar> >
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/** \internal
|
||||||
|
* \brief Template functor to compute the polygamma function.
|
||||||
|
* \sa class CwiseUnaryOp, Cwise::polygamma()
|
||||||
|
*/
|
||||||
|
template<typename Scalar> struct scalar_polygamma_op {
|
||||||
|
EIGEN_EMPTY_STRUCT_CTOR(scalar_polygamma_op)
|
||||||
|
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& n, const Scalar& x) const {
|
||||||
|
using numext::polygamma; return polygamma(n, x);
|
||||||
|
}
|
||||||
|
typedef typename packet_traits<Scalar>::type Packet;
|
||||||
|
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& n, const Packet& x) const { return internal::ppolygamma(n, x); }
|
||||||
|
};
|
||||||
|
template<typename Scalar>
|
||||||
|
struct functor_traits<scalar_polygamma_op<Scalar> >
|
||||||
|
{
|
||||||
|
enum {
|
||||||
|
// Guesstimate
|
||||||
|
Cost = 10 * NumTraits<Scalar>::MulCost + 5 * NumTraits<Scalar>::AddCost,
|
||||||
|
PacketAccess = packet_traits<Scalar>::HasPolygamma
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
/** \internal
|
/** \internal
|
||||||
* \brief Template functor to compute the Gauss error function of a
|
* \brief Template functor to compute the Gauss error function of a
|
||||||
* scalar
|
* scalar
|
||||||
|
@ -24,6 +24,7 @@ typedef CwiseUnaryOp<internal::scalar_cosh_op<Scalar>, const Derived> CoshReturn
|
|||||||
typedef CwiseUnaryOp<internal::scalar_lgamma_op<Scalar>, const Derived> LgammaReturnType;
|
typedef CwiseUnaryOp<internal::scalar_lgamma_op<Scalar>, const Derived> LgammaReturnType;
|
||||||
typedef CwiseUnaryOp<internal::scalar_digamma_op<Scalar>, const Derived> DigammaReturnType;
|
typedef CwiseUnaryOp<internal::scalar_digamma_op<Scalar>, const Derived> DigammaReturnType;
|
||||||
typedef CwiseUnaryOp<internal::scalar_zeta_op<Scalar>, const Derived> ZetaReturnType;
|
typedef CwiseUnaryOp<internal::scalar_zeta_op<Scalar>, const Derived> ZetaReturnType;
|
||||||
|
typedef CwiseUnaryOp<internal::scalar_polygamma_op<Scalar>, const Derived> PolygammaReturnType;
|
||||||
typedef CwiseUnaryOp<internal::scalar_erf_op<Scalar>, const Derived> ErfReturnType;
|
typedef CwiseUnaryOp<internal::scalar_erf_op<Scalar>, const Derived> ErfReturnType;
|
||||||
typedef CwiseUnaryOp<internal::scalar_erfc_op<Scalar>, const Derived> ErfcReturnType;
|
typedef CwiseUnaryOp<internal::scalar_erfc_op<Scalar>, const Derived> ErfcReturnType;
|
||||||
typedef CwiseUnaryOp<internal::scalar_pow_op<Scalar>, const Derived> PowReturnType;
|
typedef CwiseUnaryOp<internal::scalar_pow_op<Scalar>, const Derived> PowReturnType;
|
||||||
@ -338,6 +339,14 @@ zeta() const
|
|||||||
return ZetaReturnType(derived());
|
return ZetaReturnType(derived());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** \returns an expression of the coefficient-wise polygamma function.
|
||||||
|
*/
|
||||||
|
inline const PolygammaReturnType
|
||||||
|
polygamma() const
|
||||||
|
{
|
||||||
|
return PolygammaReturnType(derived());
|
||||||
|
}
|
||||||
|
|
||||||
/** \returns an expression of the coefficient-wise Gauss error
|
/** \returns an expression of the coefficient-wise Gauss error
|
||||||
* function of *this.
|
* function of *this.
|
||||||
*
|
*
|
||||||
|
@ -331,6 +331,11 @@ template<typename ArrayType> void array_real(const ArrayType& m)
|
|||||||
VERIFY_IS_APPROX(numext::zeta(Scalar(3), Scalar(-2.5)), RealScalar(0.054102025820864097));
|
VERIFY_IS_APPROX(numext::zeta(Scalar(3), Scalar(-2.5)), RealScalar(0.054102025820864097));
|
||||||
VERIFY_IS_EQUAL(numext::zeta(Scalar(1), Scalar(1.2345)), // The second scalar does not matter
|
VERIFY_IS_EQUAL(numext::zeta(Scalar(1), Scalar(1.2345)), // The second scalar does not matter
|
||||||
std::numeric_limits<RealScalar>::infinity());
|
std::numeric_limits<RealScalar>::infinity());
|
||||||
|
|
||||||
|
// Check the polygamma against scipy.special.polygamma
|
||||||
|
VERIFY_IS_APPROX(numext::polygamma(Scalar(1), Scalar(2)), RealScalar(0.644934066848));
|
||||||
|
VERIFY_IS_APPROX(numext::polygamma(Scalar(1), Scalar(3)), RealScalar(0.394934066848));
|
||||||
|
VERIFY_IS_APPROX(numext::polygamma(Scalar(1), Scalar(25.5)), RealScalar(0.0399946696496));
|
||||||
|
|
||||||
{
|
{
|
||||||
// Test various propreties of igamma & igammac. These are normalized
|
// Test various propreties of igamma & igammac. These are normalized
|
||||||
|
@ -138,6 +138,12 @@ class TensorBase<Derived, ReadOnlyAccessors>
|
|||||||
zeta() const {
|
zeta() const {
|
||||||
return unaryExpr(internal::scalar_zeta_op<Scalar>());
|
return unaryExpr(internal::scalar_zeta_op<Scalar>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_polygamma_op<Scalar>, const Derived>
|
||||||
|
polygamma() const {
|
||||||
|
return unaryExpr(internal::scalar_polygamma_op<Scalar>());
|
||||||
|
}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_erf_op<Scalar>, const Derived>
|
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_erf_op<Scalar>, const Derived>
|
||||||
|
Loading…
x
Reference in New Issue
Block a user