mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-16 14:49:39 +08:00
Fix scalar_logistic_function overflow for complex inputs.
This commit is contained in:
parent
9688081029
commit
3252ecc7a4
@ -1091,12 +1091,9 @@ struct functor_traits<scalar_sign_op<Scalar>> {
|
||||
};
|
||||
};
|
||||
|
||||
/** \internal
|
||||
* \brief Template functor to compute the logistic function of a scalar
|
||||
* \sa class CwiseUnaryOp, ArrayBase::logistic()
|
||||
*/
|
||||
template <typename T>
|
||||
struct scalar_logistic_op {
|
||||
// Real-valued implementation.
|
||||
template <typename T, typename EnableIf = void>
|
||||
struct scalar_logistic_op_impl {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const { return packetOp(x); }
|
||||
|
||||
template <typename Packet>
|
||||
@ -1109,6 +1106,22 @@ struct scalar_logistic_op {
|
||||
}
|
||||
};
|
||||
|
||||
// Complex-valud implementation.
|
||||
template <typename T>
|
||||
struct scalar_logistic_op_impl<T, std::enable_if_t<NumTraits<T>::IsComplex>> {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const {
|
||||
const T e = numext::exp(x);
|
||||
return (numext::isinf)(numext::real(e)) ? T(1) : e / (e + T(1));
|
||||
}
|
||||
};
|
||||
|
||||
/** \internal
|
||||
* \brief Template functor to compute the logistic function of a scalar
|
||||
* \sa class CwiseUnaryOp, ArrayBase::logistic()
|
||||
*/
|
||||
template <typename T>
|
||||
struct scalar_logistic_op : scalar_logistic_op_impl<T> {};
|
||||
|
||||
// TODO(rmlarsen): Enable the following on host when integer_packet is defined
|
||||
// for the relevant packet types.
|
||||
#ifdef EIGEN_GPU_CC
|
||||
@ -1206,7 +1219,7 @@ struct functor_traits<scalar_logistic_op<T>> {
|
||||
Cost = scalar_div_cost<T, packet_traits<T>::HasDiv>::value +
|
||||
(internal::is_same<T, float>::value ? NumTraits<T>::AddCost * 15 + NumTraits<T>::MulCost * 11
|
||||
: NumTraits<T>::AddCost * 2 + functor_traits<scalar_exp_op<T>>::Cost),
|
||||
PacketAccess = packet_traits<T>::HasAdd && packet_traits<T>::HasDiv &&
|
||||
PacketAccess = !NumTraits<T>::IsComplex && packet_traits<T>::HasAdd && packet_traits<T>::HasDiv &&
|
||||
(internal::is_same<T, float>::value
|
||||
? packet_traits<T>::HasMul && packet_traits<T>::HasMax && packet_traits<T>::HasMin
|
||||
: packet_traits<T>::HasNegate && packet_traits<T>::HasExp)
|
||||
|
@ -976,7 +976,14 @@ template<typename ArrayType> void array_complex(const ArrayType& m)
|
||||
VERIFY_IS_APPROX(sinh(m1), 0.5*(exp(m1)-exp(-m1)));
|
||||
VERIFY_IS_APPROX(cosh(m1), 0.5*(exp(m1)+exp(-m1)));
|
||||
VERIFY_IS_APPROX(tanh(m1), (0.5*(exp(m1)-exp(-m1)))/(0.5*(exp(m1)+exp(-m1))));
|
||||
VERIFY_IS_APPROX(logistic(m1), (1.0/(1.0 + exp(-m1))));
|
||||
VERIFY_IS_APPROX(logistic(m1), (1.0 / (1.0 + exp(-m1))));
|
||||
if (m1.size() > 0) {
|
||||
// Complex exponential overflow edge-case.
|
||||
Scalar old_m1_val = m1(0, 0);
|
||||
m1(0, 0) = std::complex<RealScalar>(1000.0, 1000.0);
|
||||
VERIFY_IS_APPROX(logistic(m1), (1.0 / (1.0 + exp(-m1))));
|
||||
m1(0, 0) = old_m1_val; // Restore value for future tests.
|
||||
}
|
||||
|
||||
for (Index i = 0; i < m.rows(); ++i)
|
||||
for (Index j = 0; j < m.cols(); ++j)
|
||||
|
Loading…
x
Reference in New Issue
Block a user