re-enable pow for complex types

This commit is contained in:
chuckyschluz 2022-08-26 17:29:02 -04:00
parent 7064ed1345
commit 8acbf5c11c
6 changed files with 39 additions and 34 deletions

View File

@ -101,33 +101,32 @@ namespace Eigen
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(isfinite,scalar_isfinite_op,finite value test,\sa Eigen::isinf DOXCOMMA Eigen::isnan DOXCOMMA ArrayBase::isfinite) EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(isfinite,scalar_isfinite_op,finite value test,\sa Eigen::isinf DOXCOMMA Eigen::isnan DOXCOMMA ArrayBase::isfinite)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(sign,scalar_sign_op,sign (or 0),\sa ArrayBase::sign) EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(sign,scalar_sign_op,sign (or 0),\sa ArrayBase::sign)
/** \returns an expression of the coefficient-wise power of \a x to the given constant \a exponent. template <typename Derived, typename ScalarExponent>
* using GlobalUnaryPowReturnType = std::enable_if_t<
* \tparam ScalarExponent is the scalar type of \a exponent. It must be compatible with the scalar type of the given expression (\c Derived::Scalar). !internal::is_arithmetic<typename NumTraits<Derived>::Real>::value &&
* internal::is_arithmetic<typename NumTraits<ScalarExponent>::Real>::value,
* \sa ArrayBase::pow() CwiseUnaryOp<internal::scalar_unary_pow_op<typename Derived::Scalar, ScalarExponent>, const Derived> >;
*
* \relates ArrayBase
*/
template <typename Derived, typename ScalarExponent>
using GlobalUnaryPowReturnType = std::enable_if_t<
!internal::is_arithmetic<Derived>::value && internal::is_arithmetic<ScalarExponent>::value,
CwiseUnaryOp<internal::scalar_unary_pow_op<typename Derived::Scalar, ScalarExponent>, const Derived> >;
/** \returns an expression of the coefficient-wise power of \a x to the given constant \a exponent.
*
* \tparam ScalarExponent is the scalar type of \a exponent. It must be compatible with the scalar type of the given
* expression (\c Derived::Scalar).
*
* \sa ArrayBase::pow()
*
* \relates ArrayBase
*/
#ifdef EIGEN_PARSED_BY_DOXYGEN #ifdef EIGEN_PARSED_BY_DOXYGEN
template <typename Derived, typename ScalarExponent> template <typename Derived, typename ScalarExponent>
EIGEN_DEVICE_FUNC inline const GlobalUnaryPowReturnType<Derived, ScalarExponent> EIGEN_DEVICE_FUNC inline const GlobalUnaryPowReturnType<Derived, ScalarExponent> pow(
pow(const Eigen::ArrayBase<Derived>& x, const ScalarExponent& exponent); const Eigen::ArrayBase<Derived>& x, const ScalarExponent& exponent);
#else #else
template <typename Derived, typename ScalarExponent> template <typename Derived, typename ScalarExponent>
EIGEN_DEVICE_FUNC inline const typename std::enable_if< EIGEN_DEVICE_FUNC inline const GlobalUnaryPowReturnType<Derived, ScalarExponent> pow(
!internal::is_arithmetic<Derived>::value && internal::is_arithmetic<ScalarExponent>::value, const Eigen::ArrayBase<Derived>& x, const ScalarExponent& exponent) {
CwiseUnaryOp<internal::scalar_unary_pow_op<typename Derived::Scalar, ScalarExponent>, const Derived> >::type return GlobalUnaryPowReturnType<Derived, ScalarExponent>(
pow(const Eigen::ArrayBase<Derived>& x, const ScalarExponent& exponent) { x.derived(), internal::scalar_unary_pow_op<typename Derived::Scalar, ScalarExponent>(exponent));
return CwiseUnaryOp<internal::scalar_unary_pow_op<typename Derived::Scalar, ScalarExponent>, const Derived>( }
x.derived(), internal::scalar_unary_pow_op<typename Derived::Scalar, ScalarExponent>(exponent));
}
#endif #endif
/** \returns an expression of the coefficient-wise power of \a x to the given array of \a exponents. /** \returns an expression of the coefficient-wise power of \a x to the given array of \a exponents.

View File

@ -1070,11 +1070,14 @@ struct functor_traits<scalar_logistic_op<T> > {
}; };
}; };
template <typename Scalar, typename ScalarExponent, bool BaseIsIntegerType = NumTraits<Scalar>::IsInteger, template <typename Scalar, typename ScalarExponent,
bool ExponentIsIntegerType = NumTraits<ScalarExponent>::IsInteger> bool BaseIsInteger = NumTraits<Scalar>::IsInteger,
bool ExponentIsInteger = NumTraits<ScalarExponent>::IsInteger,
bool BaseIsComplex = NumTraits<Scalar>::IsComplex,
bool ExponentIsComplex = NumTraits<ScalarExponent>::IsComplex>
struct scalar_unary_pow_op { struct scalar_unary_pow_op {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) {
EIGEN_STATIC_ASSERT((is_arithmetic<ScalarExponent>::value), EXPONENT_MUST_BE_ARITHMETIC); EIGEN_STATIC_ASSERT((is_arithmetic<typename NumTraits<ScalarExponent>::Real>::value), EXPONENT_MUST_BE_ARITHMETIC_OR_COMPLEX);
} }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const {
EIGEN_USING_STD(pow); EIGEN_USING_STD(pow);
@ -1087,7 +1090,7 @@ struct scalar_unary_pow_op {
}; };
template <typename Scalar, typename ScalarExponent> template <typename Scalar, typename ScalarExponent>
struct scalar_unary_pow_op<Scalar, ScalarExponent, false, false> { struct scalar_unary_pow_op<Scalar, ScalarExponent, false, false, false, false> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) {
EIGEN_STATIC_ASSERT((is_same<Scalar, ScalarExponent>::value), NON_INTEGER_EXPONENT_MUST_BE_SAME_TYPE_AS_BASE); EIGEN_STATIC_ASSERT((is_same<Scalar, ScalarExponent>::value), NON_INTEGER_EXPONENT_MUST_BE_SAME_TYPE_AS_BASE);
EIGEN_STATIC_ASSERT((is_arithmetic<ScalarExponent>::value), EXPONENT_MUST_BE_ARITHMETIC); EIGEN_STATIC_ASSERT((is_arithmetic<ScalarExponent>::value), EXPONENT_MUST_BE_ARITHMETIC);
@ -1107,10 +1110,11 @@ struct scalar_unary_pow_op<Scalar, ScalarExponent, false, false> {
}; };
template <typename Scalar, typename ScalarExponent> template <typename Scalar, typename ScalarExponent>
struct scalar_unary_pow_op<Scalar, ScalarExponent, false, true> { struct scalar_unary_pow_op<Scalar, ScalarExponent, false, true, false, false> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) {
EIGEN_STATIC_ASSERT((is_arithmetic<ScalarExponent>::value), EXPONENT_MUST_BE_ARITHMETIC); EIGEN_STATIC_ASSERT((is_arithmetic<ScalarExponent>::value), EXPONENT_MUST_BE_ARITHMETIC);
} }
// TODO: error handling logic for complex^real_integer
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const {
return unary_pow_impl<Scalar, ScalarExponent>::run(a, m_exponent); return unary_pow_impl<Scalar, ScalarExponent>::run(a, m_exponent);
} }

View File

@ -197,8 +197,10 @@ template<typename Scalar> struct scalar_random_op;
template<typename Scalar> struct scalar_constant_op; template<typename Scalar> struct scalar_constant_op;
template<typename Scalar> struct scalar_identity_op; template<typename Scalar> struct scalar_identity_op;
template<typename Scalar> struct scalar_sign_op; template<typename Scalar> struct scalar_sign_op;
template<typename Scalar,typename ScalarExponent> struct scalar_pow_op; template <typename Scalar, typename ScalarExponent>
template <typename Scalar, typename ScalarExponent, bool BaseIsIntegerType, bool ExponentIsIntegerType> struct scalar_pow_op;
template <typename Scalar, typename ScalarExponent, bool BaseIsInteger, bool ExponentIsInteger, bool BaseIsComplex,
bool ExponentIsComplex>
struct scalar_unary_pow_op; struct scalar_unary_pow_op;
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_hypot_op; template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_hypot_op;
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_product_op; template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_product_op;

View File

@ -697,7 +697,7 @@ ndtri() const
template <typename ScalarExponent> template <typename ScalarExponent>
using UnaryPowReturnType = using UnaryPowReturnType =
std::enable_if_t<internal::is_arithmetic<ScalarExponent>::value, std::enable_if_t<internal::is_arithmetic<typename NumTraits<ScalarExponent>::Real>::value,
CwiseUnaryOp<internal::scalar_unary_pow_op<Scalar, ScalarExponent>, const Derived>>; CwiseUnaryOp<internal::scalar_unary_pow_op<Scalar, ScalarExponent>, const Derived>>;
#ifndef EIGEN_PARSED_BY_DOXYGEN #ifndef EIGEN_PARSED_BY_DOXYGEN

View File

@ -96,7 +96,7 @@ cwiseArg() const { return CwiseArgReturnType(derived()); }
template <typename ScalarExponent> template <typename ScalarExponent>
using CwisePowReturnType = using CwisePowReturnType =
std::enable_if_t<internal::is_arithmetic<ScalarExponent>::value, std::enable_if_t<internal::is_arithmetic<typename NumTraits<ScalarExponent>::Real>::value,
CwiseUnaryOp<internal::scalar_unary_pow_op<Scalar, ScalarExponent>, const Derived>>; CwiseUnaryOp<internal::scalar_unary_pow_op<Scalar, ScalarExponent>, const Derived>>;
template <typename ScalarExponent> template <typename ScalarExponent>

View File

@ -332,7 +332,7 @@ class TensorBase<Derived, ReadOnlyAccessors>
} }
template<typename ScalarExponent> template<typename ScalarExponent>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::enable_if_t<internal::is_arithmetic<ScalarExponent>::value, EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::enable_if_t<internal::is_arithmetic<typename NumTraits<ScalarExponent>::Real>::value,
TensorCwiseUnaryOp<internal::scalar_unary_pow_op<Scalar, ScalarExponent>, const Derived>> TensorCwiseUnaryOp<internal::scalar_unary_pow_op<Scalar, ScalarExponent>, const Derived>>
pow(ScalarExponent exponent) const pow(ScalarExponent exponent) const
{ {