mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-07-30 00:32:01 +08:00
fix tanh inconsistent
This commit is contained in:
parent
5cf1e4c79b
commit
1031223c09
@ -491,19 +491,62 @@ struct functor_traits<scalar_atan_op<Scalar> >
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
/** \internal
|
/** \internal
|
||||||
* \brief Template functor to compute the tanh of a scalar
|
* \brief Template functor to compute the tanh of a scalar
|
||||||
* \sa class CwiseUnaryOp, ArrayBase::tanh()
|
* \sa class CwiseUnaryOp, ArrayBase::tanh()
|
||||||
*/
|
*/
|
||||||
template<typename Scalar> struct scalar_tanh_op {
|
template <typename Scalar>
|
||||||
|
struct scalar_tanh_op {
|
||||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_tanh_op)
|
EIGEN_EMPTY_STRUCT_CTOR(scalar_tanh_op)
|
||||||
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return numext::tanh(a); }
|
EIGEN_DEVICE_FUNC inline const Scalar operator()(const Scalar& a) const {
|
||||||
|
/** \internal \returns the hyperbolic tan of \a a (coeff-wise)
|
||||||
|
Doesn't do anything fancy, just a 13/6-degree rational interpolant
|
||||||
|
which
|
||||||
|
is accurate up to a couple of ulp in the range [-9, 9], outside of
|
||||||
|
which
|
||||||
|
the fl(tanh(x)) = +/-1. */
|
||||||
|
|
||||||
|
// Clamp the inputs to the range [-9, 9] since anything outside
|
||||||
|
// this range is +/-1.0f in single-precision.
|
||||||
|
const Scalar plus_9 = static_cast<Scalar>(9.0);
|
||||||
|
const Scalar minus_9 = static_cast<Scalar>(-9.0);
|
||||||
|
const Scalar x = numext::maxi(minus_9, numext::mini(plus_9, a));
|
||||||
|
// The monomial coefficients of the numerator polynomial (odd).
|
||||||
|
const Scalar alpha_1 = static_cast<Scalar>(4.89352455891786e-03);
|
||||||
|
const Scalar alpha_3 = static_cast<Scalar>(6.37261928875436e-04);
|
||||||
|
const Scalar alpha_5 = static_cast<Scalar>(1.48572235717979e-05);
|
||||||
|
const Scalar alpha_7 = static_cast<Scalar>(5.12229709037114e-08);
|
||||||
|
const Scalar alpha_9 = static_cast<Scalar>(-8.60467152213735e-11);
|
||||||
|
const Scalar alpha_11 = static_cast<Scalar>(2.00018790482477e-13);
|
||||||
|
const Scalar alpha_13 = static_cast<Scalar>(-2.76076847742355e-16);
|
||||||
|
// The monomial coefficients of the denominator polynomial (even).
|
||||||
|
const Scalar beta_0 = static_cast<Scalar>(4.89352518554385e-03);
|
||||||
|
const Scalar beta_2 = static_cast<Scalar>(2.26843463243900e-03);
|
||||||
|
const Scalar beta_4 = static_cast<Scalar>(1.18534705686654e-04);
|
||||||
|
const Scalar beta_6 = static_cast<Scalar>(1.19825839466702e-06);
|
||||||
|
// Since the polynomials are odd/even, we need x^2.
|
||||||
|
const Scalar x2 = x * x;
|
||||||
|
// Evaluate the numerator polynomial p.
|
||||||
|
Scalar p = x2 * alpha_13 + alpha_11;
|
||||||
|
p = x2 * p + alpha_9;
|
||||||
|
p = x2 * p + alpha_7;
|
||||||
|
p = x2 * p + alpha_5;
|
||||||
|
p = x2 * p + alpha_3;
|
||||||
|
p = x2 * p + alpha_1;
|
||||||
|
p = x * p;
|
||||||
|
// Evaluate the denominator polynomial q.
|
||||||
|
Scalar q = x2 * beta_6 + beta_4;
|
||||||
|
q = x2 * q + beta_2;
|
||||||
|
q = x2 * q + beta_0;
|
||||||
|
// Divide the numerator by the denominator.
|
||||||
|
return p / q;
|
||||||
|
}
|
||||||
template <typename Packet>
|
template <typename Packet>
|
||||||
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& _x) const {
|
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& _x) const {
|
||||||
/** \internal \returns the hyperbolic tan of \a a (coeff-wise)
|
/** \internal \returns the hyperbolic tan of \a a (coeff-wise)
|
||||||
Doesn't do anything fancy, just a 13/6-degree rational interpolant which
|
Doesn't do anything fancy, just a 13/6-degree rational interpolant which
|
||||||
is accurate up to a couple of ulp in the range [-9, 9], outside of which the
|
is accurate up to a couple of ulp in the range [-9, 9], outside of which
|
||||||
|
the
|
||||||
fl(tanh(x)) = +/-1. */
|
fl(tanh(x)) = +/-1. */
|
||||||
|
|
||||||
// Clamp the inputs to the range [-9, 9] since anything outside
|
// Clamp the inputs to the range [-9, 9] since anything outside
|
||||||
@ -511,7 +554,7 @@ template<typename Scalar> struct scalar_tanh_op {
|
|||||||
const Packet plus_9 = pset1<Packet>(9.0);
|
const Packet plus_9 = pset1<Packet>(9.0);
|
||||||
const Packet minus_9 = pset1<Packet>(-9.0);
|
const Packet minus_9 = pset1<Packet>(-9.0);
|
||||||
const Packet x = pmax(minus_9, pmin(plus_9, _x));
|
const Packet x = pmax(minus_9, pmin(plus_9, _x));
|
||||||
|
|
||||||
// The monomial coefficients of the numerator polynomial (odd).
|
// The monomial coefficients of the numerator polynomial (odd).
|
||||||
const Packet alpha_1 = pset1<Packet>(4.89352455891786e-03);
|
const Packet alpha_1 = pset1<Packet>(4.89352455891786e-03);
|
||||||
const Packet alpha_3 = pset1<Packet>(6.37261928875436e-04);
|
const Packet alpha_3 = pset1<Packet>(6.37261928875436e-04);
|
||||||
@ -520,17 +563,17 @@ template<typename Scalar> struct scalar_tanh_op {
|
|||||||
const Packet alpha_9 = pset1<Packet>(-8.60467152213735e-11);
|
const Packet alpha_9 = pset1<Packet>(-8.60467152213735e-11);
|
||||||
const Packet alpha_11 = pset1<Packet>(2.00018790482477e-13);
|
const Packet alpha_11 = pset1<Packet>(2.00018790482477e-13);
|
||||||
const Packet alpha_13 = pset1<Packet>(-2.76076847742355e-16);
|
const Packet alpha_13 = pset1<Packet>(-2.76076847742355e-16);
|
||||||
|
|
||||||
// The monomial coefficients of the denominator polynomial (even).
|
// The monomial coefficients of the denominator polynomial (even).
|
||||||
const Packet beta_0 = pset1<Packet>(4.89352518554385e-03);
|
const Packet beta_0 = pset1<Packet>(4.89352518554385e-03);
|
||||||
const Packet beta_2 = pset1<Packet>(2.26843463243900e-03);
|
const Packet beta_2 = pset1<Packet>(2.26843463243900e-03);
|
||||||
const Packet beta_4 = pset1<Packet>(1.18534705686654e-04);
|
const Packet beta_4 = pset1<Packet>(1.18534705686654e-04);
|
||||||
const Packet beta_6 = pset1<Packet>(1.19825839466702e-06);
|
const Packet beta_6 = pset1<Packet>(1.19825839466702e-06);
|
||||||
|
|
||||||
// Since the polynomials are odd/even, we need x^2.
|
// Since the polynomials are odd/even, we need x^2.
|
||||||
const Packet x2 = pmul(x, x);
|
const Packet x2 = pmul(x, x);
|
||||||
|
|
||||||
// Evaluate the numerator polynomial p.
|
// Evaluate the numerator polynomial p.
|
||||||
Packet p = pmadd(x2, alpha_13, alpha_11);
|
Packet p = pmadd(x2, alpha_13, alpha_11);
|
||||||
p = pmadd(x2, p, alpha_9);
|
p = pmadd(x2, p, alpha_9);
|
||||||
p = pmadd(x2, p, alpha_7);
|
p = pmadd(x2, p, alpha_7);
|
||||||
@ -538,38 +581,56 @@ template<typename Scalar> struct scalar_tanh_op {
|
|||||||
p = pmadd(x2, p, alpha_3);
|
p = pmadd(x2, p, alpha_3);
|
||||||
p = pmadd(x2, p, alpha_1);
|
p = pmadd(x2, p, alpha_1);
|
||||||
p = pmul(x, p);
|
p = pmul(x, p);
|
||||||
|
|
||||||
// Evaluate the denominator polynomial q.
|
// Evaluate the denominator polynomial q.
|
||||||
Packet q = pmadd(x2, beta_6, beta_4);
|
Packet q = pmadd(x2, beta_6, beta_4);
|
||||||
q = pmadd(x2, q, beta_2);
|
q = pmadd(x2, q, beta_2);
|
||||||
q = pmadd(x2, q, beta_0);
|
q = pmadd(x2, q, beta_0);
|
||||||
|
|
||||||
// Divide the numerator by the denominator.
|
// Divide the numerator by the denominator.
|
||||||
return pdiv(p, q);
|
return pdiv(p, q);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
template<typename Scalar>
|
template <>
|
||||||
struct functor_traits<scalar_tanh_op<Scalar> >
|
struct scalar_tanh_op<std::complex<double> > {
|
||||||
{
|
EIGEN_DEVICE_FUNC inline const std::complex<double> operator()(
|
||||||
|
const std::complex<double>& a) const {
|
||||||
|
return numext::tanh(a);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
template <>
|
||||||
|
struct scalar_tanh_op<std::complex<float> > {
|
||||||
|
EIGEN_DEVICE_FUNC inline const std::complex<float> operator()(
|
||||||
|
const std::complex<float>& a) const {
|
||||||
|
return numext::tanh(a);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
template <typename Scalar>
|
||||||
|
struct functor_traits<scalar_tanh_op<Scalar> > {
|
||||||
enum {
|
enum {
|
||||||
PacketAccess = packet_traits<Scalar>::HasTanh,
|
PacketAccess = packet_traits<Scalar>::HasTanh,
|
||||||
Cost =
|
Cost = (PacketAccess && (!is_same<Scalar, std::complex<float> >::value) &&
|
||||||
(PacketAccess
|
(!is_same<Scalar, std::complex<double> >::value)
|
||||||
// The following numbers are based on the AVX implementation,
|
// The following numbers are based on the AVX implementation,
|
||||||
#ifdef EIGEN_VECTORIZE_FMA
|
#ifdef EIGEN_VECTORIZE_FMA
|
||||||
// Haswell can issue 2 add/mul/madd per cycle.
|
// Haswell can issue 2 add/mul/madd per cycle.
|
||||||
// 9 pmadd, 2 pmul, 1 div, 2 other
|
// 9 pmadd, 2 pmul, 1 div, 2 other
|
||||||
? (2 * NumTraits<Scalar>::AddCost + 6 * NumTraits<Scalar>::MulCost +
|
? (2 * NumTraits<Scalar>::AddCost +
|
||||||
NumTraits<Scalar>::template Div<packet_traits<Scalar>::HasDiv>::Cost)
|
6 * NumTraits<Scalar>::MulCost +
|
||||||
|
NumTraits<Scalar>::template Div<
|
||||||
|
packet_traits<Scalar>::HasDiv>::Cost)
|
||||||
#else
|
#else
|
||||||
? (11 * NumTraits<Scalar>::AddCost +
|
? (11 * NumTraits<Scalar>::AddCost +
|
||||||
11 * NumTraits<Scalar>::MulCost +
|
11 * NumTraits<Scalar>::MulCost +
|
||||||
NumTraits<Scalar>::template Div<packet_traits<Scalar>::HasDiv>::Cost)
|
NumTraits<Scalar>::template Div<
|
||||||
|
packet_traits<Scalar>::HasDiv>::Cost)
|
||||||
#endif
|
#endif
|
||||||
// This number assumes a naive implementation of tanh
|
// This number assumes a naive implementation of tanh
|
||||||
: (6 * NumTraits<Scalar>::AddCost + 3 * NumTraits<Scalar>::MulCost +
|
: (6 * NumTraits<Scalar>::AddCost +
|
||||||
2 * NumTraits<Scalar>::template Div<packet_traits<Scalar>::HasDiv>::Cost +
|
3 * NumTraits<Scalar>::MulCost +
|
||||||
functor_traits<scalar_exp_op<Scalar> >::Cost))
|
2 * NumTraits<Scalar>::template Div<
|
||||||
|
packet_traits<Scalar>::HasDiv>::Cost +
|
||||||
|
functor_traits<scalar_exp_op<Scalar> >::Cost))
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user