mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-07-12 16:11:49 +08:00
Fix scalar_logistic_function overflow for complex inputs.
This commit is contained in:
parent
9688081029
commit
3252ecc7a4
@ -1091,12 +1091,9 @@ struct functor_traits<scalar_sign_op<Scalar>> {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
/** \internal
|
// Real-valued implementation.
|
||||||
* \brief Template functor to compute the logistic function of a scalar
|
template <typename T, typename EnableIf = void>
|
||||||
* \sa class CwiseUnaryOp, ArrayBase::logistic()
|
struct scalar_logistic_op_impl {
|
||||||
*/
|
|
||||||
template <typename T>
|
|
||||||
struct scalar_logistic_op {
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const { return packetOp(x); }
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const { return packetOp(x); }
|
||||||
|
|
||||||
template <typename Packet>
|
template <typename Packet>
|
||||||
@ -1109,6 +1106,22 @@ struct scalar_logistic_op {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Complex-valud implementation.
|
||||||
|
template <typename T>
|
||||||
|
struct scalar_logistic_op_impl<T, std::enable_if_t<NumTraits<T>::IsComplex>> {
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const {
|
||||||
|
const T e = numext::exp(x);
|
||||||
|
return (numext::isinf)(numext::real(e)) ? T(1) : e / (e + T(1));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/** \internal
|
||||||
|
* \brief Template functor to compute the logistic function of a scalar
|
||||||
|
* \sa class CwiseUnaryOp, ArrayBase::logistic()
|
||||||
|
*/
|
||||||
|
template <typename T>
|
||||||
|
struct scalar_logistic_op : scalar_logistic_op_impl<T> {};
|
||||||
|
|
||||||
// TODO(rmlarsen): Enable the following on host when integer_packet is defined
|
// TODO(rmlarsen): Enable the following on host when integer_packet is defined
|
||||||
// for the relevant packet types.
|
// for the relevant packet types.
|
||||||
#ifdef EIGEN_GPU_CC
|
#ifdef EIGEN_GPU_CC
|
||||||
@ -1206,7 +1219,7 @@ struct functor_traits<scalar_logistic_op<T>> {
|
|||||||
Cost = scalar_div_cost<T, packet_traits<T>::HasDiv>::value +
|
Cost = scalar_div_cost<T, packet_traits<T>::HasDiv>::value +
|
||||||
(internal::is_same<T, float>::value ? NumTraits<T>::AddCost * 15 + NumTraits<T>::MulCost * 11
|
(internal::is_same<T, float>::value ? NumTraits<T>::AddCost * 15 + NumTraits<T>::MulCost * 11
|
||||||
: NumTraits<T>::AddCost * 2 + functor_traits<scalar_exp_op<T>>::Cost),
|
: NumTraits<T>::AddCost * 2 + functor_traits<scalar_exp_op<T>>::Cost),
|
||||||
PacketAccess = packet_traits<T>::HasAdd && packet_traits<T>::HasDiv &&
|
PacketAccess = !NumTraits<T>::IsComplex && packet_traits<T>::HasAdd && packet_traits<T>::HasDiv &&
|
||||||
(internal::is_same<T, float>::value
|
(internal::is_same<T, float>::value
|
||||||
? packet_traits<T>::HasMul && packet_traits<T>::HasMax && packet_traits<T>::HasMin
|
? packet_traits<T>::HasMul && packet_traits<T>::HasMax && packet_traits<T>::HasMin
|
||||||
: packet_traits<T>::HasNegate && packet_traits<T>::HasExp)
|
: packet_traits<T>::HasNegate && packet_traits<T>::HasExp)
|
||||||
|
@ -976,7 +976,14 @@ template<typename ArrayType> void array_complex(const ArrayType& m)
|
|||||||
VERIFY_IS_APPROX(sinh(m1), 0.5*(exp(m1)-exp(-m1)));
|
VERIFY_IS_APPROX(sinh(m1), 0.5*(exp(m1)-exp(-m1)));
|
||||||
VERIFY_IS_APPROX(cosh(m1), 0.5*(exp(m1)+exp(-m1)));
|
VERIFY_IS_APPROX(cosh(m1), 0.5*(exp(m1)+exp(-m1)));
|
||||||
VERIFY_IS_APPROX(tanh(m1), (0.5*(exp(m1)-exp(-m1)))/(0.5*(exp(m1)+exp(-m1))));
|
VERIFY_IS_APPROX(tanh(m1), (0.5*(exp(m1)-exp(-m1)))/(0.5*(exp(m1)+exp(-m1))));
|
||||||
VERIFY_IS_APPROX(logistic(m1), (1.0/(1.0 + exp(-m1))));
|
VERIFY_IS_APPROX(logistic(m1), (1.0 / (1.0 + exp(-m1))));
|
||||||
|
if (m1.size() > 0) {
|
||||||
|
// Complex exponential overflow edge-case.
|
||||||
|
Scalar old_m1_val = m1(0, 0);
|
||||||
|
m1(0, 0) = std::complex<RealScalar>(1000.0, 1000.0);
|
||||||
|
VERIFY_IS_APPROX(logistic(m1), (1.0 / (1.0 + exp(-m1))));
|
||||||
|
m1(0, 0) = old_m1_val; // Restore value for future tests.
|
||||||
|
}
|
||||||
|
|
||||||
for (Index i = 0; i < m.rows(); ++i)
|
for (Index i = 0; i < m.rows(); ++i)
|
||||||
for (Index j = 0; j < m.cols(); ++j)
|
for (Index j = 0; j < m.cols(); ++j)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user