re-enable pow for complex types

This commit is contained in:
chuckyschluz 2022-08-26 17:29:02 -04:00
parent 7064ed1345
commit 8acbf5c11c
6 changed files with 39 additions and 34 deletions

View File

@ -101,33 +101,32 @@ namespace Eigen
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(isfinite,scalar_isfinite_op,finite value test,\sa Eigen::isinf DOXCOMMA Eigen::isnan DOXCOMMA ArrayBase::isfinite)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(sign,scalar_sign_op,sign (or 0),\sa ArrayBase::sign)
/** \returns an expression of the coefficient-wise power of \a x to the given constant \a exponent.
*
* \tparam ScalarExponent is the scalar type of \a exponent. It must be compatible with the scalar type of the given expression (\c Derived::Scalar).
*
* \sa ArrayBase::pow()
*
* \relates ArrayBase
*/
template <typename Derived, typename ScalarExponent>
using GlobalUnaryPowReturnType = std::enable_if_t<
!internal::is_arithmetic<Derived>::value && internal::is_arithmetic<ScalarExponent>::value,
CwiseUnaryOp<internal::scalar_unary_pow_op<typename Derived::Scalar, ScalarExponent>, const Derived> >;
template <typename Derived, typename ScalarExponent>
using GlobalUnaryPowReturnType = std::enable_if_t<
!internal::is_arithmetic<typename NumTraits<Derived>::Real>::value &&
internal::is_arithmetic<typename NumTraits<ScalarExponent>::Real>::value,
CwiseUnaryOp<internal::scalar_unary_pow_op<typename Derived::Scalar, ScalarExponent>, const Derived> >;
/** \returns an expression of the coefficient-wise power of \a x to the given constant \a exponent.
*
* \tparam ScalarExponent is the scalar type of \a exponent. It must be compatible with the scalar type of the given
* expression (\c Derived::Scalar).
*
* \sa ArrayBase::pow()
*
* \relates ArrayBase
*/
#ifdef EIGEN_PARSED_BY_DOXYGEN
template <typename Derived, typename ScalarExponent>
EIGEN_DEVICE_FUNC inline const GlobalUnaryPowReturnType<Derived, ScalarExponent>
pow(const Eigen::ArrayBase<Derived>& x, const ScalarExponent& exponent);
template <typename Derived, typename ScalarExponent>
EIGEN_DEVICE_FUNC inline const GlobalUnaryPowReturnType<Derived, ScalarExponent> pow(
const Eigen::ArrayBase<Derived>& x, const ScalarExponent& exponent);
#else
template <typename Derived, typename ScalarExponent>
EIGEN_DEVICE_FUNC inline const typename std::enable_if<
!internal::is_arithmetic<Derived>::value && internal::is_arithmetic<ScalarExponent>::value,
CwiseUnaryOp<internal::scalar_unary_pow_op<typename Derived::Scalar, ScalarExponent>, const Derived> >::type
pow(const Eigen::ArrayBase<Derived>& x, const ScalarExponent& exponent) {
return CwiseUnaryOp<internal::scalar_unary_pow_op<typename Derived::Scalar, ScalarExponent>, const Derived>(
x.derived(), internal::scalar_unary_pow_op<typename Derived::Scalar, ScalarExponent>(exponent));
}
template <typename Derived, typename ScalarExponent>
EIGEN_DEVICE_FUNC inline const GlobalUnaryPowReturnType<Derived, ScalarExponent> pow(
const Eigen::ArrayBase<Derived>& x, const ScalarExponent& exponent) {
return GlobalUnaryPowReturnType<Derived, ScalarExponent>(
x.derived(), internal::scalar_unary_pow_op<typename Derived::Scalar, ScalarExponent>(exponent));
}
#endif
/** \returns an expression of the coefficient-wise power of \a x to the given array of \a exponents.

View File

@ -1070,11 +1070,14 @@ struct functor_traits<scalar_logistic_op<T> > {
};
};
template <typename Scalar, typename ScalarExponent, bool BaseIsIntegerType = NumTraits<Scalar>::IsInteger,
bool ExponentIsIntegerType = NumTraits<ScalarExponent>::IsInteger>
template <typename Scalar, typename ScalarExponent,
bool BaseIsInteger = NumTraits<Scalar>::IsInteger,
bool ExponentIsInteger = NumTraits<ScalarExponent>::IsInteger,
bool BaseIsComplex = NumTraits<Scalar>::IsComplex,
bool ExponentIsComplex = NumTraits<ScalarExponent>::IsComplex>
struct scalar_unary_pow_op {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) {
EIGEN_STATIC_ASSERT((is_arithmetic<ScalarExponent>::value), EXPONENT_MUST_BE_ARITHMETIC);
EIGEN_STATIC_ASSERT((is_arithmetic<typename NumTraits<ScalarExponent>::Real>::value), EXPONENT_MUST_BE_ARITHMETIC_OR_COMPLEX);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const {
EIGEN_USING_STD(pow);
@ -1087,7 +1090,7 @@ struct scalar_unary_pow_op {
};
template <typename Scalar, typename ScalarExponent>
struct scalar_unary_pow_op<Scalar, ScalarExponent, false, false> {
struct scalar_unary_pow_op<Scalar, ScalarExponent, false, false, false, false> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) {
EIGEN_STATIC_ASSERT((is_same<Scalar, ScalarExponent>::value), NON_INTEGER_EXPONENT_MUST_BE_SAME_TYPE_AS_BASE);
EIGEN_STATIC_ASSERT((is_arithmetic<ScalarExponent>::value), EXPONENT_MUST_BE_ARITHMETIC);
@ -1107,10 +1110,11 @@ struct scalar_unary_pow_op<Scalar, ScalarExponent, false, false> {
};
template <typename Scalar, typename ScalarExponent>
struct scalar_unary_pow_op<Scalar, ScalarExponent, false, true> {
struct scalar_unary_pow_op<Scalar, ScalarExponent, false, true, false, false> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) {
EIGEN_STATIC_ASSERT((is_arithmetic<ScalarExponent>::value), EXPONENT_MUST_BE_ARITHMETIC);
}
// TODO: error handling logic for complex^real_integer
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const {
return unary_pow_impl<Scalar, ScalarExponent>::run(a, m_exponent);
}

View File

@ -197,8 +197,10 @@ template<typename Scalar> struct scalar_random_op;
template<typename Scalar> struct scalar_constant_op;
template<typename Scalar> struct scalar_identity_op;
template<typename Scalar> struct scalar_sign_op;
template<typename Scalar,typename ScalarExponent> struct scalar_pow_op;
template <typename Scalar, typename ScalarExponent, bool BaseIsIntegerType, bool ExponentIsIntegerType>
template <typename Scalar, typename ScalarExponent>
struct scalar_pow_op;
template <typename Scalar, typename ScalarExponent, bool BaseIsInteger, bool ExponentIsInteger, bool BaseIsComplex,
bool ExponentIsComplex>
struct scalar_unary_pow_op;
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_hypot_op;
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_product_op;

View File

@ -697,7 +697,7 @@ ndtri() const
template <typename ScalarExponent>
using UnaryPowReturnType =
std::enable_if_t<internal::is_arithmetic<ScalarExponent>::value,
std::enable_if_t<internal::is_arithmetic<typename NumTraits<ScalarExponent>::Real>::value,
CwiseUnaryOp<internal::scalar_unary_pow_op<Scalar, ScalarExponent>, const Derived>>;
#ifndef EIGEN_PARSED_BY_DOXYGEN

View File

@ -96,7 +96,7 @@ cwiseArg() const { return CwiseArgReturnType(derived()); }
template <typename ScalarExponent>
using CwisePowReturnType =
std::enable_if_t<internal::is_arithmetic<ScalarExponent>::value,
std::enable_if_t<internal::is_arithmetic<typename NumTraits<ScalarExponent>::Real>::value,
CwiseUnaryOp<internal::scalar_unary_pow_op<Scalar, ScalarExponent>, const Derived>>;
template <typename ScalarExponent>

View File

@ -332,7 +332,7 @@ class TensorBase<Derived, ReadOnlyAccessors>
}
template<typename ScalarExponent>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::enable_if_t<internal::is_arithmetic<ScalarExponent>::value,
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::enable_if_t<internal::is_arithmetic<typename NumTraits<ScalarExponent>::Real>::value,
TensorCwiseUnaryOp<internal::scalar_unary_pow_op<Scalar, ScalarExponent>, const Derived>>
pow(ScalarExponent exponent) const
{