mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 18:54:00 +08:00
Revert the specialization for scalar_logistic_op<float> introduced in:
77b447c24e
Although this specialization provides a 50% speedup on Haswell+ processors, the large relative error of the approximation outside [-18, 18] causes problems, e.g., when computing gradients of activation functions like softplus in neural networks.
This commit is contained in:
parent
3b15373bb3
commit
66f07efeae
@ -905,83 +905,14 @@ struct scalar_logistic_op {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/** \internal
|
|
||||||
* \brief Template specialization of the logistic function for float.
|
|
||||||
*
|
|
||||||
* Uses just a 9/10-degree rational interpolant which
|
|
||||||
* interpolates 1/(1+exp(-x)) - 0.5 up to a couple of ulp in the range
|
|
||||||
* [-18, 18], outside of which the fl(logistic(x)) = {0|1}. The shifted
|
|
||||||
* logistic is interpolated because it was easier to make the fit converge.
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct scalar_logistic_op<float> {
|
|
||||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_logistic_op)
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()(const float& x) const {
|
|
||||||
if (x < -18.0f) return 0.0f;
|
|
||||||
else if (x > 18.0f) return 1.0f;
|
|
||||||
else return 1.0f / (1.0f + numext::exp(-x));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
|
||||||
Packet packetOp(const Packet& _x) const {
|
|
||||||
// Clamp the inputs to the range [-18, 18] since anything outside
|
|
||||||
// this range is 0.0f or 1.0f in single-precision.
|
|
||||||
const Packet x = pmax(pmin(_x, pset1<Packet>(18.0)), pset1<Packet>(-18.0));
|
|
||||||
|
|
||||||
// The monomial coefficients of the numerator polynomial (odd).
|
|
||||||
const Packet alpha_1 = pset1<Packet>(2.48287947061529e-01);
|
|
||||||
const Packet alpha_3 = pset1<Packet>(8.51377133304701e-03);
|
|
||||||
const Packet alpha_5 = pset1<Packet>(6.08574864600143e-05);
|
|
||||||
const Packet alpha_7 = pset1<Packet>(1.15627324459942e-07);
|
|
||||||
const Packet alpha_9 = pset1<Packet>(4.37031012579801e-11);
|
|
||||||
|
|
||||||
// The monomial coefficients of the denominator polynomial (even).
|
|
||||||
const Packet beta_0 = pset1<Packet>(9.93151921023180e-01);
|
|
||||||
const Packet beta_2 = pset1<Packet>(1.16817656904453e-01);
|
|
||||||
const Packet beta_4 = pset1<Packet>(1.70198817374094e-03);
|
|
||||||
const Packet beta_6 = pset1<Packet>(6.29106785017040e-06);
|
|
||||||
const Packet beta_8 = pset1<Packet>(5.76102136993427e-09);
|
|
||||||
const Packet beta_10 = pset1<Packet>(6.10247389755681e-13);
|
|
||||||
|
|
||||||
// Since the polynomials are odd/even, we need x^2.
|
|
||||||
const Packet x2 = pmul(x, x);
|
|
||||||
|
|
||||||
// Evaluate the numerator polynomial p.
|
|
||||||
Packet p = pmadd(x2, alpha_9, alpha_7);
|
|
||||||
p = pmadd(x2, p, alpha_5);
|
|
||||||
p = pmadd(x2, p, alpha_3);
|
|
||||||
p = pmadd(x2, p, alpha_1);
|
|
||||||
p = pmul(x, p);
|
|
||||||
|
|
||||||
// Evaluate the denominator polynomial q.
|
|
||||||
Packet q = pmadd(x2, beta_10, beta_8);
|
|
||||||
q = pmadd(x2, q, beta_6);
|
|
||||||
q = pmadd(x2, q, beta_4);
|
|
||||||
q = pmadd(x2, q, beta_2);
|
|
||||||
q = pmadd(x2, q, beta_0);
|
|
||||||
|
|
||||||
// Divide the numerator by the denominator and shift it up.
|
|
||||||
return pmax(pmin(padd(pdiv(p, q), pset1<Packet>(0.5)), pset1<Packet>(1.0)),
|
|
||||||
pset1<Packet>(0.0));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct functor_traits<scalar_logistic_op<T> > {
|
struct functor_traits<scalar_logistic_op<T> > {
|
||||||
enum {
|
enum {
|
||||||
Cost = scalar_div_cost<T, packet_traits<T>::HasDiv>::value +
|
Cost = scalar_div_cost<T, packet_traits<T>::HasDiv>::value +
|
||||||
(internal::is_same<T, float>::value
|
NumTraits<T>::AddCost * 2 + functor_traits<scalar_exp_op<T> >::Cost,
|
||||||
? NumTraits<T>::AddCost * 12 + NumTraits<T>::MulCost * 11
|
|
||||||
: NumTraits<T>::AddCost * 2 +
|
|
||||||
functor_traits<scalar_exp_op<T> >::Cost),
|
|
||||||
PacketAccess =
|
PacketAccess =
|
||||||
packet_traits<T>::HasAdd && packet_traits<T>::HasDiv &&
|
packet_traits<T>::HasAdd && packet_traits<T>::HasDiv &&
|
||||||
(internal::is_same<T, float>::value
|
packet_traits<T>::HasNegate && packet_traits<T>::HasExp
|
||||||
? packet_traits<T>::HasMul && packet_traits<T>::HasMax &&
|
|
||||||
packet_traits<T>::HasMin
|
|
||||||
: packet_traits<T>::HasNegate && packet_traits<T>::HasExp)
|
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user