mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 18:54:00 +08:00
Move sigmoid functor to core.
This commit is contained in:
parent
09c81ac033
commit
fa68342ef8
@ -66,6 +66,7 @@ namespace Eigen
|
|||||||
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(sinh,scalar_sinh_op,hyperbolic sine,\sa ArrayBase::sinh)
|
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(sinh,scalar_sinh_op,hyperbolic sine,\sa ArrayBase::sinh)
|
||||||
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(cosh,scalar_cosh_op,hyperbolic cosine,\sa ArrayBase::cosh)
|
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(cosh,scalar_cosh_op,hyperbolic cosine,\sa ArrayBase::cosh)
|
||||||
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(tanh,scalar_tanh_op,hyperbolic tangent,\sa ArrayBase::tanh)
|
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(tanh,scalar_tanh_op,hyperbolic tangent,\sa ArrayBase::tanh)
|
||||||
|
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(sigmoid,scalar_sigmoid_op,sigmoid function,\sa ArrayBase::sigmoid)
|
||||||
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(lgamma,scalar_lgamma_op,natural logarithm of the gamma function,\sa ArrayBase::lgamma)
|
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(lgamma,scalar_lgamma_op,natural logarithm of the gamma function,\sa ArrayBase::lgamma)
|
||||||
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(digamma,scalar_digamma_op,derivative of lgamma,\sa ArrayBase::digamma)
|
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(digamma,scalar_digamma_op,derivative of lgamma,\sa ArrayBase::digamma)
|
||||||
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(erf,scalar_erf_op,error function,\sa ArrayBase::erf)
|
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(erf,scalar_erf_op,error function,\sa ArrayBase::erf)
|
||||||
|
@ -823,6 +823,34 @@ struct functor_traits<scalar_sign_op<Scalar> >
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/** \internal
|
||||||
|
* \brief Template functor to compute the sigmoid of a scalar
|
||||||
|
* \sa class CwiseUnaryOp, ArrayBase::sigmoid()
|
||||||
|
*/
|
||||||
|
template <typename T>
|
||||||
|
struct scalar_sigmoid_op {
|
||||||
|
EIGEN_EMPTY_STRUCT_CTOR(scalar_sigmoid_op)
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const {
|
||||||
|
const T one = T(1);
|
||||||
|
return one / (one + numext::exp(-x));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
|
Packet packetOp(const Packet& x) const {
|
||||||
|
const Packet one = pset1<Packet>(T(1));
|
||||||
|
return pdiv(one, padd(one, pexp(pnegate(x))));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
template <typename T>
|
||||||
|
struct functor_traits<scalar_sigmoid_op<T> > {
|
||||||
|
enum {
|
||||||
|
Cost = NumTraits<T>::AddCost * 2 + NumTraits<T>::MulCost * 6,
|
||||||
|
PacketAccess = packet_traits<T>::HasAdd && packet_traits<T>::HasDiv &&
|
||||||
|
packet_traits<T>::HasNegate && packet_traits<T>::HasExp
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
@ -21,6 +21,7 @@ typedef CwiseUnaryOp<internal::scalar_acos_op<Scalar>, const Derived> AcosReturn
|
|||||||
typedef CwiseUnaryOp<internal::scalar_asin_op<Scalar>, const Derived> AsinReturnType;
|
typedef CwiseUnaryOp<internal::scalar_asin_op<Scalar>, const Derived> AsinReturnType;
|
||||||
typedef CwiseUnaryOp<internal::scalar_atan_op<Scalar>, const Derived> AtanReturnType;
|
typedef CwiseUnaryOp<internal::scalar_atan_op<Scalar>, const Derived> AtanReturnType;
|
||||||
typedef CwiseUnaryOp<internal::scalar_tanh_op<Scalar>, const Derived> TanhReturnType;
|
typedef CwiseUnaryOp<internal::scalar_tanh_op<Scalar>, const Derived> TanhReturnType;
|
||||||
|
typedef CwiseUnaryOp<internal::scalar_sigmoid_op<Scalar>, const Derived> SigmoidReturnType;
|
||||||
typedef CwiseUnaryOp<internal::scalar_sinh_op<Scalar>, const Derived> SinhReturnType;
|
typedef CwiseUnaryOp<internal::scalar_sinh_op<Scalar>, const Derived> SinhReturnType;
|
||||||
typedef CwiseUnaryOp<internal::scalar_cosh_op<Scalar>, const Derived> CoshReturnType;
|
typedef CwiseUnaryOp<internal::scalar_cosh_op<Scalar>, const Derived> CoshReturnType;
|
||||||
typedef CwiseUnaryOp<internal::scalar_square_op<Scalar>, const Derived> SquareReturnType;
|
typedef CwiseUnaryOp<internal::scalar_square_op<Scalar>, const Derived> SquareReturnType;
|
||||||
@ -335,6 +336,15 @@ cosh() const
|
|||||||
return CoshReturnType(derived());
|
return CoshReturnType(derived());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** \returns an expression of the coefficient-wise sigmoid of *this.
|
||||||
|
*/
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
inline const SigmoidReturnType
|
||||||
|
sigmoid() const
|
||||||
|
{
|
||||||
|
return SigmoidReturnType(derived());
|
||||||
|
}
|
||||||
|
|
||||||
/** \returns an expression of the coefficient-wise inverse of *this.
|
/** \returns an expression of the coefficient-wise inverse of *this.
|
||||||
*
|
*
|
||||||
* Example: \include Cwise_inverse.cpp
|
* Example: \include Cwise_inverse.cpp
|
||||||
|
@ -231,6 +231,7 @@ template<typename ArrayType> void array_real(const ArrayType& m)
|
|||||||
VERIFY_IS_APPROX(m1.sinh(), sinh(m1));
|
VERIFY_IS_APPROX(m1.sinh(), sinh(m1));
|
||||||
VERIFY_IS_APPROX(m1.cosh(), cosh(m1));
|
VERIFY_IS_APPROX(m1.cosh(), cosh(m1));
|
||||||
VERIFY_IS_APPROX(m1.tanh(), tanh(m1));
|
VERIFY_IS_APPROX(m1.tanh(), tanh(m1));
|
||||||
|
VERIFY_IS_APPROX(m1.sigmoid(), sigmoid(m1));
|
||||||
|
|
||||||
VERIFY_IS_APPROX(m1.arg(), arg(m1));
|
VERIFY_IS_APPROX(m1.arg(), arg(m1));
|
||||||
VERIFY_IS_APPROX(m1.round(), round(m1));
|
VERIFY_IS_APPROX(m1.round(), round(m1));
|
||||||
@ -266,6 +267,7 @@ template<typename ArrayType> void array_real(const ArrayType& m)
|
|||||||
VERIFY_IS_APPROX(sinh(m1), 0.5*(exp(m1)-exp(-m1)));
|
VERIFY_IS_APPROX(sinh(m1), 0.5*(exp(m1)-exp(-m1)));
|
||||||
VERIFY_IS_APPROX(cosh(m1), 0.5*(exp(m1)+exp(-m1)));
|
VERIFY_IS_APPROX(cosh(m1), 0.5*(exp(m1)+exp(-m1)));
|
||||||
VERIFY_IS_APPROX(tanh(m1), (0.5*(exp(m1)-exp(-m1)))/(0.5*(exp(m1)+exp(-m1))));
|
VERIFY_IS_APPROX(tanh(m1), (0.5*(exp(m1)-exp(-m1)))/(0.5*(exp(m1)+exp(-m1))));
|
||||||
|
VERIFY_IS_APPROX(sigmoid(m1), (1.0/(1.0+exp(-m1))));
|
||||||
VERIFY_IS_APPROX(arg(m1), ((m1<0).template cast<Scalar>())*std::acos(-1.0));
|
VERIFY_IS_APPROX(arg(m1), ((m1<0).template cast<Scalar>())*std::acos(-1.0));
|
||||||
VERIFY((round(m1) <= ceil(m1) && round(m1) >= floor(m1)).all());
|
VERIFY((round(m1) <= ceil(m1) && round(m1) >= floor(m1)).all());
|
||||||
VERIFY((Eigen::isnan)((m1*0.0)/0.0).all());
|
VERIFY((Eigen::isnan)((m1*0.0)/0.0).all());
|
||||||
@ -345,6 +347,7 @@ template<typename ArrayType> void array_complex(const ArrayType& m)
|
|||||||
VERIFY_IS_APPROX(m1.sinh(), sinh(m1));
|
VERIFY_IS_APPROX(m1.sinh(), sinh(m1));
|
||||||
VERIFY_IS_APPROX(m1.cosh(), cosh(m1));
|
VERIFY_IS_APPROX(m1.cosh(), cosh(m1));
|
||||||
VERIFY_IS_APPROX(m1.tanh(), tanh(m1));
|
VERIFY_IS_APPROX(m1.tanh(), tanh(m1));
|
||||||
|
VERIFY_IS_APPROX(m1.sigmoid(), sigmoid(m1));
|
||||||
VERIFY_IS_APPROX(m1.arg(), arg(m1));
|
VERIFY_IS_APPROX(m1.arg(), arg(m1));
|
||||||
VERIFY((m1.isNaN() == (Eigen::isnan)(m1)).all());
|
VERIFY((m1.isNaN() == (Eigen::isnan)(m1)).all());
|
||||||
VERIFY((m1.isInf() == (Eigen::isinf)(m1)).all());
|
VERIFY((m1.isInf() == (Eigen::isinf)(m1)).all());
|
||||||
@ -368,6 +371,7 @@ template<typename ArrayType> void array_complex(const ArrayType& m)
|
|||||||
VERIFY_IS_APPROX(sinh(m1), 0.5*(exp(m1)-exp(-m1)));
|
VERIFY_IS_APPROX(sinh(m1), 0.5*(exp(m1)-exp(-m1)));
|
||||||
VERIFY_IS_APPROX(cosh(m1), 0.5*(exp(m1)+exp(-m1)));
|
VERIFY_IS_APPROX(cosh(m1), 0.5*(exp(m1)+exp(-m1)));
|
||||||
VERIFY_IS_APPROX(tanh(m1), (0.5*(exp(m1)-exp(-m1)))/(0.5*(exp(m1)+exp(-m1))));
|
VERIFY_IS_APPROX(tanh(m1), (0.5*(exp(m1)-exp(-m1)))/(0.5*(exp(m1)+exp(-m1))));
|
||||||
|
VERIFY_IS_APPROX(sigmoid(m1), (1.0/(1.0 + exp(-m1))));
|
||||||
|
|
||||||
for (Index i = 0; i < m.rows(); ++i)
|
for (Index i = 0; i < m.rows(); ++i)
|
||||||
for (Index j = 0; j < m.cols(); ++j)
|
for (Index j = 0; j < m.cols(); ++j)
|
||||||
|
@ -54,36 +54,6 @@ struct functor_traits<scalar_fmod_op<Scalar> > {
|
|||||||
PacketAccess = false };
|
PacketAccess = false };
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
/** \internal
|
|
||||||
* \brief Template functor to compute the sigmoid of a scalar
|
|
||||||
* \sa class CwiseUnaryOp, ArrayBase::sigmoid()
|
|
||||||
*/
|
|
||||||
template <typename T>
|
|
||||||
struct scalar_sigmoid_op {
|
|
||||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_sigmoid_op)
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const {
|
|
||||||
const T one = T(1);
|
|
||||||
return one / (one + numext::exp(-x));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
|
||||||
Packet packetOp(const Packet& x) const {
|
|
||||||
const Packet one = pset1<Packet>(T(1));
|
|
||||||
return pdiv(one, padd(one, pexp(pnegate(x))));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
struct functor_traits<scalar_sigmoid_op<T> > {
|
|
||||||
enum {
|
|
||||||
Cost = NumTraits<T>::AddCost * 2 + NumTraits<T>::MulCost * 6,
|
|
||||||
PacketAccess = packet_traits<T>::HasAdd && packet_traits<T>::HasDiv &&
|
|
||||||
packet_traits<T>::HasNegate && packet_traits<T>::HasExp
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
template<typename Reducer, typename Device>
|
template<typename Reducer, typename Device>
|
||||||
struct reducer_traits {
|
struct reducer_traits {
|
||||||
enum {
|
enum {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user