diff --git a/Eigen/src/Core/GlobalFunctions.h b/Eigen/src/Core/GlobalFunctions.h index 53f9dfa2a..a828e4bf7 100644 --- a/Eigen/src/Core/GlobalFunctions.h +++ b/Eigen/src/Core/GlobalFunctions.h @@ -109,23 +109,25 @@ namespace Eigen * * \relates ArrayBase */ + +template +using GlobalUnaryPowReturnType = std::enable_if_t< + !internal::is_arithmetic::value && internal::is_arithmetic::value, + CwiseUnaryOp, const Derived> >; + #ifdef EIGEN_PARSED_BY_DOXYGEN - template - inline const CwiseBinaryOp,Derived,Constant > - pow(const Eigen::ArrayBase& x, const ScalarExponent& exponent); +template +EIGEN_DEVICE_FUNC inline const GlobalUnaryPowReturnType +pow(const Eigen::ArrayBase& x, const ScalarExponent& exponent); #else - template - EIGEN_DEVICE_FUNC inline - const EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,typename internal::promote_scalar_arg::type,pow) - pow(const Eigen::ArrayBase& x, const ScalarExponent& exponent) - { - typedef typename internal::promote_scalar_arg::type PromotedExponent; - return EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,PromotedExponent,pow)(x.derived(), - typename internal::plain_constant_type::type(x.derived().rows(), x.derived().cols(), internal::scalar_constant_op(exponent))); - } +template +EIGEN_DEVICE_FUNC inline const typename std::enable_if< + !internal::is_arithmetic::value && internal::is_arithmetic::value, + CwiseUnaryOp, const Derived> >::type +pow(const Eigen::ArrayBase& x, const ScalarExponent& exponent) { + return CwiseUnaryOp, const Derived>( + x.derived(), internal::scalar_unary_pow_op(exponent)); +} #endif /** \returns an expression of the coefficient-wise power of \a x to the given array of \a exponents. diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index 66310b651..a88439cac 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -1690,6 +1690,225 @@ struct pchebevl { } }; +namespace unary_pow { +template ::IsInteger> +struct is_odd { + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ScalarExponent run(const ScalarExponent& x) { + ScalarExponent xdiv2 = x / ScalarExponent(2); + ScalarExponent floorxdiv2 = numext::floor(xdiv2); + return xdiv2 != floorxdiv2; + } +}; +template +struct is_odd { + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ScalarExponent run(const ScalarExponent& x) { + return x % ScalarExponent(2); + } +}; + +template ::type>::IsInteger> +struct do_div { + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) { + typedef typename unpacket_traits::type Scalar; + const Packet cst_pos_one = pset1(Scalar(1)); + return exponent < 0 ? pdiv(cst_pos_one, x) : x; + } +}; + +template +struct do_div { + // pdiv not defined, nor necessary for integer base types + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) { + EIGEN_UNUSED_VARIABLE(exponent); + return x; + } +}; + +template +static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet int_pow(const Packet& x, const ScalarExponent& exponent) { + typedef typename unpacket_traits::type Scalar; + const Packet cst_pos_one = pset1(Scalar(1)); + if (exponent == 0) return cst_pos_one; + Packet result = x; + Packet y = cst_pos_one; + ScalarExponent m = numext::abs(exponent); + while (m > 1) { + bool odd = is_odd::run(m); + if (odd) y = pmul(y, result); + result = pmul(result, result); + m = numext::floor(m / ScalarExponent(2)); + } + result = pmul(y, result); + result = do_div::run(result, exponent); + return result; +} + +template +static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet gen_pow(const Packet& x, + const typename unpacket_traits::type& exponent) { + const Packet exponent_packet = pset1(exponent); + return generic_pow_impl(x, exponent_packet); +} + +template +static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_int_errors(const Packet& x, const Packet& powx, + const ScalarExponent& exponent) { + typedef typename unpacket_traits::type Scalar; + + // non-integer base, integer exponent case + + const bool exponent_is_odd = is_odd::run(exponent); + const bool exponent_is_neg = exponent < 0; + + const Packet exp_is_odd = exponent_is_odd ? ptrue(x) : pzero(x); + const Packet exp_is_neg = exponent_is_neg ? ptrue(x) : pzero(x); + + const Scalar pos_zero = Scalar(0); + const Scalar neg_zero = -Scalar(0); + const Scalar pos_one = Scalar(1); + const Scalar pos_inf = NumTraits::infinity(); + const Scalar neg_inf = -NumTraits::infinity(); + + const Packet cst_pos_zero = pset1(pos_zero); + const Packet cst_neg_zero = pset1(neg_zero); + const Packet cst_pos_one = pset1(pos_one); + const Packet cst_pos_inf = pset1(pos_inf); + const Packet cst_neg_inf = pset1(neg_inf); + + const Packet abs_x = pabs(x); + const Packet abs_x_is_zero = pcmp_eq(abs_x, cst_pos_zero); + const Packet abs_x_is_one = pcmp_eq(abs_x, cst_pos_one); + const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf); + + const Packet x_has_signbit = pcmp_eq(por(pand(x, cst_neg_inf), cst_pos_inf), cst_neg_inf); + const Packet x_is_neg = pandnot(x_has_signbit, abs_x_is_zero); + const Packet x_is_neg_zero = pand(x_has_signbit, abs_x_is_zero); + + if (exponent == 0) { + return cst_pos_one; + } + + Packet pow_is_pos_inf = pand(pandnot(abs_x_is_zero, x_is_neg_zero), pand(exp_is_odd, exp_is_neg)); + pow_is_pos_inf = por(pow_is_pos_inf, pand(abs_x_is_zero, pandnot(exp_is_neg, exp_is_odd))); + pow_is_pos_inf = por(pow_is_pos_inf, pand(pand(abs_x_is_inf, x_is_neg), pandnot(pnot(exp_is_neg), exp_is_odd))); + pow_is_pos_inf = por(pow_is_pos_inf, pandnot(pandnot(abs_x_is_inf, x_is_neg), exp_is_neg)); + + Packet pow_is_neg_inf = pand(x_is_neg_zero, pand(exp_is_neg, exp_is_odd)); + pow_is_neg_inf = por(pow_is_neg_inf, pand(pand(abs_x_is_inf, x_is_neg), pandnot(exp_is_odd, exp_is_neg))); + + Packet pow_is_pos_zero = pandnot(abs_x_is_zero, exp_is_neg); + pow_is_pos_zero = por(pow_is_pos_zero, pand(pand(abs_x_is_inf, x_is_neg), pandnot(exp_is_neg, exp_is_odd))); + pow_is_pos_zero = por(pow_is_pos_zero, pand(pandnot(abs_x_is_inf, x_is_neg), exp_is_neg)); + + Packet pow_is_neg_zero = pand(x_is_neg_zero, pandnot(exp_is_odd, exp_is_neg)); + pow_is_neg_zero = por(pow_is_neg_zero, pand(pand(abs_x_is_inf, x_is_neg), pand(exp_is_odd, exp_is_neg))); + + Packet result = pselect(pow_is_neg_inf, cst_neg_inf, powx); + result = pselect(pow_is_neg_zero, cst_neg_zero, result); + result = pselect(pow_is_pos_zero, cst_pos_zero, result); + result = pselect(pow_is_pos_inf, cst_pos_inf, result); + result = pselect(pandnot(abs_x_is_one, x_is_neg), cst_pos_one, result); + return result; +} + +template +static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_nonint_errors(const Packet& x, const Packet& powx, + const ScalarExponent& exponent) { + typedef typename unpacket_traits::type Scalar; + + // non-integer base and exponent case + + const bool exponent_is_fin = (numext::isfinite)(exponent); + const bool exponent_is_nan = (numext::isnan)(exponent); + const bool exponent_is_neg = exponent < 0; + const bool exponent_is_inf = !exponent_is_fin && !exponent_is_nan; + + const Packet exp_is_neg = exponent_is_neg ? ptrue(x) : pzero(x); + const Packet exp_is_inf = exponent_is_inf ? ptrue(x) : pzero(x); + + const Scalar pos_zero = Scalar(0); + const Scalar pos_one = Scalar(1); + const Scalar pos_inf = NumTraits::infinity(); + const Scalar neg_inf = -NumTraits::infinity(); + const Scalar nan = NumTraits::quiet_NaN(); + + const Packet cst_pos_zero = pset1(pos_zero); + const Packet cst_pos_one = pset1(pos_one); + const Packet cst_pos_inf = pset1(pos_inf); + const Packet cst_neg_inf = pset1(neg_inf); + const Packet cst_nan = pset1(nan); + + const Packet abs_x = pabs(x); + const Packet abs_x_is_zero = pcmp_eq(abs_x, cst_pos_zero); + const Packet abs_x_is_lt_one = pcmp_lt(abs_x, cst_pos_one); + const Packet abs_x_is_gt_one = pcmp_lt(cst_pos_one, abs_x); + const Packet abs_x_is_one = pcmp_eq(abs_x, cst_pos_one); + const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf); + + const Packet x_has_signbit = pcmp_eq(por(pand(x, cst_neg_inf), cst_pos_inf), cst_neg_inf); + const Packet x_is_neg = pandnot(x_has_signbit, abs_x_is_zero); + const Packet x_is_neg_zero = pand(x_has_signbit, abs_x_is_zero); + + if (exponent_is_nan) { + return pselect(pandnot(abs_x_is_one, x_is_neg), cst_pos_one, cst_nan); + } + + Packet pow_is_pos_zero = pandnot(abs_x_is_zero, exp_is_neg); + pow_is_pos_zero = por(pow_is_pos_zero, pand(abs_x_is_gt_one, pand(exp_is_inf, exp_is_neg))); + pow_is_pos_zero = por(pow_is_pos_zero, pand(abs_x_is_lt_one, pandnot(exp_is_inf, exp_is_neg))); + pow_is_pos_zero = por(pow_is_pos_zero, pand(abs_x_is_inf, exp_is_neg)); + + const Packet pow_is_pos_one = pand(abs_x_is_one, exp_is_inf); + + Packet pow_is_pos_inf = pand(abs_x_is_zero, exp_is_neg); + pow_is_pos_inf = por(pow_is_pos_inf, pand(abs_x_is_lt_one, pand(exp_is_inf, exp_is_neg))); + pow_is_pos_inf = por(pow_is_pos_inf, pand(abs_x_is_gt_one, pandnot(exp_is_inf, exp_is_neg))); + pow_is_pos_inf = por(pow_is_pos_inf, pandnot(abs_x_is_inf, exp_is_neg)); + + const Packet pow_is_nan = pandnot(pandnot(x_is_neg, abs_x_is_inf), exp_is_inf); + + Packet result = pselect(pow_is_pos_inf, cst_pos_inf, powx); + result = pselect(pow_is_pos_one, cst_pos_one, result); + result = pselect(pow_is_pos_zero, cst_pos_zero, result); + result = pselect(pow_is_nan, cst_nan, result); + result = pselect(pandnot(abs_x_is_one, x_is_neg), cst_pos_one, result); + return result; +} +} // end namespace unary_pow + +template ::type>::IsInteger, + bool ExponentIsIntegerType = NumTraits::IsInteger> +struct unary_pow_impl; + +template +struct unary_pow_impl { + typedef typename unpacket_traits::type Scalar; + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) { + const bool exponent_is_integer = (numext::isfinite)(exponent) && numext::round(exponent) == exponent; + if (exponent_is_integer) { + Packet result = unary_pow::int_pow(x, exponent); + result = unary_pow::handle_nonint_int_errors(x, result, exponent); + return result; + } else { + Packet result = unary_pow::gen_pow(x, exponent); + result = unary_pow::handle_nonint_nonint_errors(x, result, exponent); + return result; + } + } +}; + +template +struct unary_pow_impl { + typedef typename unpacket_traits::type Scalar; + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) { + Packet result = unary_pow::int_pow(x, exponent); + result = unary_pow::handle_nonint_int_errors(x, result, exponent); + return result; + } +}; + } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h index f4d5fcae4..3b629a22f 100644 --- a/Eigen/src/Core/functors/UnaryFunctors.h +++ b/Eigen/src/Core/functors/UnaryFunctors.h @@ -1070,6 +1070,70 @@ struct functor_traits > { }; }; +template ::IsInteger, + bool ExponentIsIntegerType = NumTraits::IsInteger> +struct scalar_unary_pow_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) { + EIGEN_STATIC_ASSERT((is_arithmetic::value), EXPONENT_MUST_BE_ARITHMETIC); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const { + EIGEN_USING_STD(pow); + return pow(a, m_exponent); + } + + private: + const ScalarExponent m_exponent; + scalar_unary_pow_op() {} +}; + +template +struct scalar_unary_pow_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) { + EIGEN_STATIC_ASSERT((is_same::value), NON_INTEGER_EXPONENT_MUST_BE_SAME_TYPE_AS_BASE); + EIGEN_STATIC_ASSERT((is_arithmetic::value), EXPONENT_MUST_BE_ARITHMETIC); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const { + EIGEN_USING_STD(pow); + return pow(a, m_exponent); + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a) const { + return unary_pow_impl::run(a, m_exponent); + } + + private: + const ScalarExponent m_exponent; + scalar_unary_pow_op() {} +}; + +template +struct scalar_unary_pow_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) { + EIGEN_STATIC_ASSERT((is_arithmetic::value), EXPONENT_MUST_BE_ARITHMETIC); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const { + return unary_pow_impl::run(a, m_exponent); + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a) const { + return unary_pow_impl::run(a, m_exponent); + } + + private: + const ScalarExponent m_exponent; + scalar_unary_pow_op() {} +}; + +template +struct functor_traits> { + enum { + GenPacketAccess = functor_traits>::PacketAccess, + IntPacketAccess = !NumTraits::IsComplex && !NumTraits::IsInteger && packet_traits::HasMul && packet_traits::HasDiv && packet_traits::HasCmp, + PacketAccess = NumTraits::IsInteger ? IntPacketAccess : (IntPacketAccess && GenPacketAccess), + Cost = functor_traits>::Cost + }; +}; + } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h index bbed82c1f..11250209f 100644 --- a/Eigen/src/Core/util/ForwardDeclarations.h +++ b/Eigen/src/Core/util/ForwardDeclarations.h @@ -198,6 +198,8 @@ template struct scalar_constant_op; template struct scalar_identity_op; template struct scalar_sign_op; template struct scalar_pow_op; +template +struct scalar_unary_pow_op; template struct scalar_hypot_op; template struct scalar_product_op; template struct scalar_quotient_op; diff --git a/Eigen/src/plugins/ArrayCwiseBinaryOps.h b/Eigen/src/plugins/ArrayCwiseBinaryOps.h index a4dbfbc33..5f1e84459 100644 --- a/Eigen/src/plugins/ArrayCwiseBinaryOps.h +++ b/Eigen/src/plugins/ArrayCwiseBinaryOps.h @@ -134,26 +134,6 @@ absolute_difference */ EIGEN_MAKE_CWISE_BINARY_OP(pow,pow) -#ifndef EIGEN_PARSED_BY_DOXYGEN -EIGEN_MAKE_SCALAR_BINARY_OP_ONTHERIGHT(pow,pow) -#else -/** \returns an expression of the coefficients of \c *this rasied to the constant power \a exponent - * - * \tparam T is the scalar type of \a exponent. It must be compatible with the scalar type of the given expression. - * - * This function computes the coefficient-wise power. The function MatrixBase::pow() in the - * unsupported module MatrixFunctions computes the matrix power. - * - * Example: \include Cwise_pow.cpp - * Output: \verbinclude Cwise_pow.out - * - * \sa ArrayBase::pow(ArrayBase), square(), cube(), exp(), log() - */ -template -const CwiseBinaryOp,Derived,Constant > pow(const T& exponent) const; -#endif - - // TODO code generating macros could be moved to Macros.h and could include generation of documentation #define EIGEN_MAKE_CWISE_COMP_OP(OP, COMPARATOR) \ template \ diff --git a/Eigen/src/plugins/ArrayCwiseUnaryOps.h b/Eigen/src/plugins/ArrayCwiseUnaryOps.h index 13c55f4b1..ee301f5b0 100644 --- a/Eigen/src/plugins/ArrayCwiseUnaryOps.h +++ b/Eigen/src/plugins/ArrayCwiseUnaryOps.h @@ -694,3 +694,32 @@ ndtri() const { return NdtriReturnType(derived()); } + +template +using UnaryPowReturnType = + std::enable_if_t::value, + CwiseUnaryOp, const Derived>>; + +#ifndef EIGEN_PARSED_BY_DOXYGEN +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const UnaryPowReturnType pow( + const ScalarExponent& exponent) const { + return UnaryPowReturnType(derived(), internal::scalar_unary_pow_op(exponent)); +#else +/** \returns an expression of the coefficients of \c *this rasied to the constant power \a exponent + * + * \tparam T is the scalar type of \a exponent. It must be compatible with the scalar type of the given expression. + * + * This function computes the coefficient-wise power. The function MatrixBase::pow() in the + * unsupported module MatrixFunctions computes the matrix power. + * + * Example: \include Cwise_pow.cpp + * Output: \verbinclude Cwise_pow.out + * + * \sa ArrayBase::pow(ArrayBase), square(), cube(), exp(), log() + */ +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const UnaryPowReturnType pow( + const ScalarExponent& exponent) const; +#endif +} diff --git a/Eigen/src/plugins/MatrixCwiseUnaryOps.h b/Eigen/src/plugins/MatrixCwiseUnaryOps.h index 0514d8f78..62f91db51 100644 --- a/Eigen/src/plugins/MatrixCwiseUnaryOps.h +++ b/Eigen/src/plugins/MatrixCwiseUnaryOps.h @@ -93,3 +93,13 @@ EIGEN_DOC_UNARY_ADDONS(cwiseArg,arg) EIGEN_DEVICE_FUNC inline const CwiseArgReturnType cwiseArg() const { return CwiseArgReturnType(derived()); } + +template +using CwisePowReturnType = + std::enable_if_t::value, + CwiseUnaryOp, const Derived>>; + +template +EIGEN_DEVICE_FUNC inline const CwisePowReturnType cwisePow(const ScalarExponent& exponent) const { + return CwisePowReturnType(derived(), internal::scalar_unary_pow_op(exponent)); +} diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp index 48290a171..d877cb141 100644 --- a/test/array_cwise.cpp +++ b/test/array_cwise.cpp @@ -79,6 +79,50 @@ void pow_test() { } } } + + typedef typename internal::make_integer::type Int_t; + + // ensure both vectorized and non-vectorized paths taken + Index test_size = 2 * internal::packet_traits::size + 1; + + Array eigenPow(test_size); + for (int i = 0; i < num_cases; ++i) { + Array bases = x.col(i); + for (Scalar abs_exponent : abs_vals){ + for (Scalar exponent : {-abs_exponent, abs_exponent}){ + // test floating point exponent code path + eigenPow.setZero(); + eigenPow = bases.pow(exponent); + for (int j = 0; j < num_repeats; j++){ + Scalar e = static_cast(std::pow(bases(j), exponent)); + Scalar a = eigenPow(j); + bool success = (a == e) || ((numext::isfinite)(e) && internal::isApprox(a, e, tol)) || ((numext::isnan)(a) && (numext::isnan)(e)); + all_pass &= success; + if (!success) { + std::cout << "pow(" << x(i, j) << "," << y(i, j) << ") = " << a << " != " << e << std::endl; + } + } + // test integer exponent code path + bool exponent_is_integer = (numext::isfinite)(exponent) && (numext::round(exponent) == exponent) && (numext::abs(exponent) < static_cast(NumTraits::highest())); + if (exponent_is_integer) + { + Int_t exponent_as_int = static_cast(exponent); + eigenPow.setZero(); + eigenPow = bases.pow(exponent_as_int); + for (int j = 0; j < num_repeats; j++){ + Scalar e = static_cast(std::pow(bases(j), exponent)); + Scalar a = eigenPow(j); + bool success = (a == e) || ((numext::isfinite)(e) && internal::isApprox(a, e, tol)) || ((numext::isnan)(a) && (numext::isnan)(e)); + all_pass &= success; + if (!success) { + std::cout << "pow(" << x(i, j) << "," << y(i, j) << ") = " << a << " != " << e << std::endl; + } + } + } + } + } + } + VERIFY(all_pass); } diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index 4a6edb1cb..74e11dfad 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -331,10 +331,12 @@ class TensorBase return choose(Cond::IsComplex>(), unaryExpr(internal::scalar_conjugate_op()), derived()); } - EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE const TensorCwiseUnaryOp >, const Derived> - pow(Scalar exponent) const { - return unaryExpr(internal::bind2nd_op >(exponent)); + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::enable_if_t::value, + TensorCwiseUnaryOp, const Derived>> + pow(ScalarExponent exponent) const + { + return unaryExpr(internal::scalar_unary_pow_op(exponent)); } EIGEN_DEVICE_FUNC