From 8acbf5c11c2bbbfb2279cc0dc58278a91235eac6 Mon Sep 17 00:00:00 2001 From: chuckyschluz Date: Fri, 26 Aug 2022 17:29:02 -0400 Subject: [PATCH] re-enable pow for complex types --- Eigen/src/Core/GlobalFunctions.h | 47 +++++++++---------- Eigen/src/Core/functors/UnaryFunctors.h | 14 ++++-- Eigen/src/Core/util/ForwardDeclarations.h | 6 ++- Eigen/src/plugins/ArrayCwiseUnaryOps.h | 2 +- Eigen/src/plugins/MatrixCwiseUnaryOps.h | 2 +- .../Eigen/CXX11/src/Tensor/TensorBase.h | 2 +- 6 files changed, 39 insertions(+), 34 deletions(-) diff --git a/Eigen/src/Core/GlobalFunctions.h b/Eigen/src/Core/GlobalFunctions.h index a828e4bf7..f8d00b165 100644 --- a/Eigen/src/Core/GlobalFunctions.h +++ b/Eigen/src/Core/GlobalFunctions.h @@ -101,33 +101,32 @@ namespace Eigen EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(isfinite,scalar_isfinite_op,finite value test,\sa Eigen::isinf DOXCOMMA Eigen::isnan DOXCOMMA ArrayBase::isfinite) EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(sign,scalar_sign_op,sign (or 0),\sa ArrayBase::sign) - /** \returns an expression of the coefficient-wise power of \a x to the given constant \a exponent. - * - * \tparam ScalarExponent is the scalar type of \a exponent. It must be compatible with the scalar type of the given expression (\c Derived::Scalar). - * - * \sa ArrayBase::pow() - * - * \relates ArrayBase - */ - -template -using GlobalUnaryPowReturnType = std::enable_if_t< - !internal::is_arithmetic::value && internal::is_arithmetic::value, - CwiseUnaryOp, const Derived> >; + template + using GlobalUnaryPowReturnType = std::enable_if_t< + !internal::is_arithmetic::Real>::value && + internal::is_arithmetic::Real>::value, + CwiseUnaryOp, const Derived> >; + /** \returns an expression of the coefficient-wise power of \a x to the given constant \a exponent. + * + * \tparam ScalarExponent is the scalar type of \a exponent. It must be compatible with the scalar type of the given + * expression (\c Derived::Scalar). + * + * \sa ArrayBase::pow() + * + * \relates ArrayBase + */ #ifdef EIGEN_PARSED_BY_DOXYGEN -template -EIGEN_DEVICE_FUNC inline const GlobalUnaryPowReturnType -pow(const Eigen::ArrayBase& x, const ScalarExponent& exponent); + template + EIGEN_DEVICE_FUNC inline const GlobalUnaryPowReturnType pow( + const Eigen::ArrayBase& x, const ScalarExponent& exponent); #else -template -EIGEN_DEVICE_FUNC inline const typename std::enable_if< - !internal::is_arithmetic::value && internal::is_arithmetic::value, - CwiseUnaryOp, const Derived> >::type -pow(const Eigen::ArrayBase& x, const ScalarExponent& exponent) { - return CwiseUnaryOp, const Derived>( - x.derived(), internal::scalar_unary_pow_op(exponent)); -} + template + EIGEN_DEVICE_FUNC inline const GlobalUnaryPowReturnType pow( + const Eigen::ArrayBase& x, const ScalarExponent& exponent) { + return GlobalUnaryPowReturnType( + x.derived(), internal::scalar_unary_pow_op(exponent)); + } #endif /** \returns an expression of the coefficient-wise power of \a x to the given array of \a exponents. diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h index 3b629a22f..15f2cf12a 100644 --- a/Eigen/src/Core/functors/UnaryFunctors.h +++ b/Eigen/src/Core/functors/UnaryFunctors.h @@ -1070,11 +1070,14 @@ struct functor_traits > { }; }; -template ::IsInteger, - bool ExponentIsIntegerType = NumTraits::IsInteger> +template ::IsInteger, + bool ExponentIsInteger = NumTraits::IsInteger, + bool BaseIsComplex = NumTraits::IsComplex, + bool ExponentIsComplex = NumTraits::IsComplex> struct scalar_unary_pow_op { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) { - EIGEN_STATIC_ASSERT((is_arithmetic::value), EXPONENT_MUST_BE_ARITHMETIC); + EIGEN_STATIC_ASSERT((is_arithmetic::Real>::value), EXPONENT_MUST_BE_ARITHMETIC_OR_COMPLEX); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const { EIGEN_USING_STD(pow); @@ -1087,7 +1090,7 @@ struct scalar_unary_pow_op { }; template -struct scalar_unary_pow_op { +struct scalar_unary_pow_op { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) { EIGEN_STATIC_ASSERT((is_same::value), NON_INTEGER_EXPONENT_MUST_BE_SAME_TYPE_AS_BASE); EIGEN_STATIC_ASSERT((is_arithmetic::value), EXPONENT_MUST_BE_ARITHMETIC); @@ -1107,10 +1110,11 @@ struct scalar_unary_pow_op { }; template -struct scalar_unary_pow_op { +struct scalar_unary_pow_op { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) { EIGEN_STATIC_ASSERT((is_arithmetic::value), EXPONENT_MUST_BE_ARITHMETIC); } + // TODO: error handling logic for complex^real_integer EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const { return unary_pow_impl::run(a, m_exponent); } diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h index 11250209f..5c55639cd 100644 --- a/Eigen/src/Core/util/ForwardDeclarations.h +++ b/Eigen/src/Core/util/ForwardDeclarations.h @@ -197,8 +197,10 @@ template struct scalar_random_op; template struct scalar_constant_op; template struct scalar_identity_op; template struct scalar_sign_op; -template struct scalar_pow_op; -template +template +struct scalar_pow_op; +template struct scalar_unary_pow_op; template struct scalar_hypot_op; template struct scalar_product_op; diff --git a/Eigen/src/plugins/ArrayCwiseUnaryOps.h b/Eigen/src/plugins/ArrayCwiseUnaryOps.h index ee301f5b0..b2d933190 100644 --- a/Eigen/src/plugins/ArrayCwiseUnaryOps.h +++ b/Eigen/src/plugins/ArrayCwiseUnaryOps.h @@ -697,7 +697,7 @@ ndtri() const template using UnaryPowReturnType = - std::enable_if_t::value, + std::enable_if_t::Real>::value, CwiseUnaryOp, const Derived>>; #ifndef EIGEN_PARSED_BY_DOXYGEN diff --git a/Eigen/src/plugins/MatrixCwiseUnaryOps.h b/Eigen/src/plugins/MatrixCwiseUnaryOps.h index 62f91db51..98d925dd2 100644 --- a/Eigen/src/plugins/MatrixCwiseUnaryOps.h +++ b/Eigen/src/plugins/MatrixCwiseUnaryOps.h @@ -96,7 +96,7 @@ cwiseArg() const { return CwiseArgReturnType(derived()); } template using CwisePowReturnType = - std::enable_if_t::value, + std::enable_if_t::Real>::value, CwiseUnaryOp, const Derived>>; template diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index 74e11dfad..a4ac2ad6d 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -332,7 +332,7 @@ class TensorBase } template - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::enable_if_t::value, + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::enable_if_t::Real>::value, TensorCwiseUnaryOp, const Derived>> pow(ScalarExponent exponent) const {