Fix scalar_logistic_function overflow for complex inputs.

This commit is contained in:
Antonio Sánchez 2023-12-05 18:21:04 +00:00 committed by Rasmus Munk Larsen
parent 9688081029
commit 3252ecc7a4
2 changed files with 28 additions and 8 deletions

View File

@ -1091,12 +1091,9 @@ struct functor_traits<scalar_sign_op<Scalar>> {
};
};
/** \internal
* \brief Template functor to compute the logistic function of a scalar
* \sa class CwiseUnaryOp, ArrayBase::logistic()
*/
template <typename T>
struct scalar_logistic_op {
// Real-valued implementation.
template <typename T, typename EnableIf = void>
struct scalar_logistic_op_impl {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const { return packetOp(x); }
template <typename Packet>
@ -1109,6 +1106,22 @@ struct scalar_logistic_op {
}
};
// Complex-valud implementation.
template <typename T>
struct scalar_logistic_op_impl<T, std::enable_if_t<NumTraits<T>::IsComplex>> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const {
const T e = numext::exp(x);
return (numext::isinf)(numext::real(e)) ? T(1) : e / (e + T(1));
}
};
/** \internal
* \brief Template functor to compute the logistic function of a scalar
* \sa class CwiseUnaryOp, ArrayBase::logistic()
*/
template <typename T>
struct scalar_logistic_op : scalar_logistic_op_impl<T> {};
// TODO(rmlarsen): Enable the following on host when integer_packet is defined
// for the relevant packet types.
#ifdef EIGEN_GPU_CC
@ -1206,7 +1219,7 @@ struct functor_traits<scalar_logistic_op<T>> {
Cost = scalar_div_cost<T, packet_traits<T>::HasDiv>::value +
(internal::is_same<T, float>::value ? NumTraits<T>::AddCost * 15 + NumTraits<T>::MulCost * 11
: NumTraits<T>::AddCost * 2 + functor_traits<scalar_exp_op<T>>::Cost),
PacketAccess = packet_traits<T>::HasAdd && packet_traits<T>::HasDiv &&
PacketAccess = !NumTraits<T>::IsComplex && packet_traits<T>::HasAdd && packet_traits<T>::HasDiv &&
(internal::is_same<T, float>::value
? packet_traits<T>::HasMul && packet_traits<T>::HasMax && packet_traits<T>::HasMin
: packet_traits<T>::HasNegate && packet_traits<T>::HasExp)

View File

@ -976,7 +976,14 @@ template<typename ArrayType> void array_complex(const ArrayType& m)
VERIFY_IS_APPROX(sinh(m1), 0.5*(exp(m1)-exp(-m1)));
VERIFY_IS_APPROX(cosh(m1), 0.5*(exp(m1)+exp(-m1)));
VERIFY_IS_APPROX(tanh(m1), (0.5*(exp(m1)-exp(-m1)))/(0.5*(exp(m1)+exp(-m1))));
VERIFY_IS_APPROX(logistic(m1), (1.0/(1.0 + exp(-m1))));
VERIFY_IS_APPROX(logistic(m1), (1.0 / (1.0 + exp(-m1))));
if (m1.size() > 0) {
// Complex exponential overflow edge-case.
Scalar old_m1_val = m1(0, 0);
m1(0, 0) = std::complex<RealScalar>(1000.0, 1000.0);
VERIFY_IS_APPROX(logistic(m1), (1.0 / (1.0 + exp(-m1))));
m1(0, 0) = old_m1_val; // Restore value for future tests.
}
for (Index i = 0; i < m.rows(); ++i)
for (Index j = 0; j < m.cols(); ++j)