From 3252ecc7a4975f6eb8c8c171fb449ca328e6d08a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20S=C3=A1nchez?= Date: Tue, 5 Dec 2023 18:21:04 +0000 Subject: [PATCH] Fix scalar_logistic_function overflow for complex inputs. --- Eigen/src/Core/functors/UnaryFunctors.h | 27 ++++++++++++++++++------- test/array_cwise.cpp | 9 ++++++++- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h index 3c7dfb769..a3fc44c1f 100644 --- a/Eigen/src/Core/functors/UnaryFunctors.h +++ b/Eigen/src/Core/functors/UnaryFunctors.h @@ -1091,12 +1091,9 @@ struct functor_traits> { }; }; -/** \internal - * \brief Template functor to compute the logistic function of a scalar - * \sa class CwiseUnaryOp, ArrayBase::logistic() - */ -template -struct scalar_logistic_op { +// Real-valued implementation. +template +struct scalar_logistic_op_impl { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const { return packetOp(x); } template @@ -1109,6 +1106,22 @@ struct scalar_logistic_op { } }; +// Complex-valud implementation. +template +struct scalar_logistic_op_impl::IsComplex>> { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const { + const T e = numext::exp(x); + return (numext::isinf)(numext::real(e)) ? T(1) : e / (e + T(1)); + } +}; + +/** \internal + * \brief Template functor to compute the logistic function of a scalar + * \sa class CwiseUnaryOp, ArrayBase::logistic() + */ +template +struct scalar_logistic_op : scalar_logistic_op_impl {}; + // TODO(rmlarsen): Enable the following on host when integer_packet is defined // for the relevant packet types. #ifdef EIGEN_GPU_CC @@ -1206,7 +1219,7 @@ struct functor_traits> { Cost = scalar_div_cost::HasDiv>::value + (internal::is_same::value ? NumTraits::AddCost * 15 + NumTraits::MulCost * 11 : NumTraits::AddCost * 2 + functor_traits>::Cost), - PacketAccess = packet_traits::HasAdd && packet_traits::HasDiv && + PacketAccess = !NumTraits::IsComplex && packet_traits::HasAdd && packet_traits::HasDiv && (internal::is_same::value ? packet_traits::HasMul && packet_traits::HasMax && packet_traits::HasMin : packet_traits::HasNegate && packet_traits::HasExp) diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp index bfea96ace..9b629697d 100644 --- a/test/array_cwise.cpp +++ b/test/array_cwise.cpp @@ -976,7 +976,14 @@ template void array_complex(const ArrayType& m) VERIFY_IS_APPROX(sinh(m1), 0.5*(exp(m1)-exp(-m1))); VERIFY_IS_APPROX(cosh(m1), 0.5*(exp(m1)+exp(-m1))); VERIFY_IS_APPROX(tanh(m1), (0.5*(exp(m1)-exp(-m1)))/(0.5*(exp(m1)+exp(-m1)))); - VERIFY_IS_APPROX(logistic(m1), (1.0/(1.0 + exp(-m1)))); + VERIFY_IS_APPROX(logistic(m1), (1.0 / (1.0 + exp(-m1)))); + if (m1.size() > 0) { + // Complex exponential overflow edge-case. + Scalar old_m1_val = m1(0, 0); + m1(0, 0) = std::complex(1000.0, 1000.0); + VERIFY_IS_APPROX(logistic(m1), (1.0 / (1.0 + exp(-m1)))); + m1(0, 0) = old_m1_val; // Restore value for future tests. + } for (Index i = 0; i < m.rows(); ++i) for (Index j = 0; j < m.cols(); ++j)