add fixed power unary operation

This commit is contained in:
Charles Schlosser 2022-08-16 21:32:36 +00:00 committed by Rasmus Munk Larsen
parent 39fcc89798
commit 76a669fb45
9 changed files with 391 additions and 39 deletions

View File

@ -109,23 +109,25 @@ namespace Eigen
*
* \relates ArrayBase
*/
template <typename Derived, typename ScalarExponent>
using GlobalUnaryPowReturnType = std::enable_if_t<
!internal::is_arithmetic<Derived>::value && internal::is_arithmetic<ScalarExponent>::value,
CwiseUnaryOp<internal::scalar_unary_pow_op<typename Derived::Scalar, ScalarExponent>, const Derived> >;
#ifdef EIGEN_PARSED_BY_DOXYGEN
template<typename Derived,typename ScalarExponent>
inline const CwiseBinaryOp<internal::scalar_pow_op<Derived::Scalar,ScalarExponent>,Derived,Constant<ScalarExponent> >
pow(const Eigen::ArrayBase<Derived>& x, const ScalarExponent& exponent);
template <typename Derived, typename ScalarExponent>
EIGEN_DEVICE_FUNC inline const GlobalUnaryPowReturnType<Derived, ScalarExponent>
pow(const Eigen::ArrayBase<Derived>& x, const ScalarExponent& exponent);
#else
template <typename Derived,typename ScalarExponent>
EIGEN_DEVICE_FUNC inline
const EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,typename internal::promote_scalar_arg<typename Derived::Scalar
EIGEN_COMMA ScalarExponent EIGEN_COMMA
EIGEN_SCALAR_BINARY_SUPPORTED(pow,typename Derived::Scalar,ScalarExponent)>::type,pow)
pow(const Eigen::ArrayBase<Derived>& x, const ScalarExponent& exponent)
{
typedef typename internal::promote_scalar_arg<typename Derived::Scalar,ScalarExponent,
EIGEN_SCALAR_BINARY_SUPPORTED(pow,typename Derived::Scalar,ScalarExponent)>::type PromotedExponent;
return EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,PromotedExponent,pow)(x.derived(),
typename internal::plain_constant_type<Derived,PromotedExponent>::type(x.derived().rows(), x.derived().cols(), internal::scalar_constant_op<PromotedExponent>(exponent)));
}
template <typename Derived, typename ScalarExponent>
EIGEN_DEVICE_FUNC inline const typename std::enable_if<
!internal::is_arithmetic<Derived>::value && internal::is_arithmetic<ScalarExponent>::value,
CwiseUnaryOp<internal::scalar_unary_pow_op<typename Derived::Scalar, ScalarExponent>, const Derived> >::type
pow(const Eigen::ArrayBase<Derived>& x, const ScalarExponent& exponent) {
return CwiseUnaryOp<internal::scalar_unary_pow_op<typename Derived::Scalar, ScalarExponent>, const Derived>(
x.derived(), internal::scalar_unary_pow_op<typename Derived::Scalar, ScalarExponent>(exponent));
}
#endif
/** \returns an expression of the coefficient-wise power of \a x to the given array of \a exponents.

View File

@ -1690,6 +1690,225 @@ struct pchebevl {
}
};
namespace unary_pow {
template <typename ScalarExponent, bool IsIntegerAtCompileTime = NumTraits<ScalarExponent>::IsInteger>
struct is_odd {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ScalarExponent run(const ScalarExponent& x) {
ScalarExponent xdiv2 = x / ScalarExponent(2);
ScalarExponent floorxdiv2 = numext::floor(xdiv2);
return xdiv2 != floorxdiv2;
}
};
template <typename ScalarExponent>
struct is_odd<ScalarExponent, true> {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ScalarExponent run(const ScalarExponent& x) {
return x % ScalarExponent(2);
}
};
template <typename Packet, typename ScalarExponent,
bool BaseIsIntegerType = NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>
struct do_div {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
typedef typename unpacket_traits<Packet>::type Scalar;
const Packet cst_pos_one = pset1<Packet>(Scalar(1));
return exponent < 0 ? pdiv(cst_pos_one, x) : x;
}
};
template <typename Packet, typename ScalarExponent>
struct do_div<Packet, ScalarExponent, true> {
// pdiv not defined, nor necessary for integer base types
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
EIGEN_UNUSED_VARIABLE(exponent);
return x;
}
};
template <typename Packet, typename ScalarExponent>
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet int_pow(const Packet& x, const ScalarExponent& exponent) {
typedef typename unpacket_traits<Packet>::type Scalar;
const Packet cst_pos_one = pset1<Packet>(Scalar(1));
if (exponent == 0) return cst_pos_one;
Packet result = x;
Packet y = cst_pos_one;
ScalarExponent m = numext::abs(exponent);
while (m > 1) {
bool odd = is_odd<ScalarExponent>::run(m);
if (odd) y = pmul(y, result);
result = pmul(result, result);
m = numext::floor(m / ScalarExponent(2));
}
result = pmul(y, result);
result = do_div<Packet, ScalarExponent>::run(result, exponent);
return result;
}
template <typename Packet>
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet gen_pow(const Packet& x,
const typename unpacket_traits<Packet>::type& exponent) {
const Packet exponent_packet = pset1<Packet>(exponent);
return generic_pow_impl(x, exponent_packet);
}
template <typename Packet, typename ScalarExponent>
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_int_errors(const Packet& x, const Packet& powx,
const ScalarExponent& exponent) {
typedef typename unpacket_traits<Packet>::type Scalar;
// non-integer base, integer exponent case
const bool exponent_is_odd = is_odd<ScalarExponent>::run(exponent);
const bool exponent_is_neg = exponent < 0;
const Packet exp_is_odd = exponent_is_odd ? ptrue(x) : pzero(x);
const Packet exp_is_neg = exponent_is_neg ? ptrue(x) : pzero(x);
const Scalar pos_zero = Scalar(0);
const Scalar neg_zero = -Scalar(0);
const Scalar pos_one = Scalar(1);
const Scalar pos_inf = NumTraits<Scalar>::infinity();
const Scalar neg_inf = -NumTraits<Scalar>::infinity();
const Packet cst_pos_zero = pset1<Packet>(pos_zero);
const Packet cst_neg_zero = pset1<Packet>(neg_zero);
const Packet cst_pos_one = pset1<Packet>(pos_one);
const Packet cst_pos_inf = pset1<Packet>(pos_inf);
const Packet cst_neg_inf = pset1<Packet>(neg_inf);
const Packet abs_x = pabs(x);
const Packet abs_x_is_zero = pcmp_eq(abs_x, cst_pos_zero);
const Packet abs_x_is_one = pcmp_eq(abs_x, cst_pos_one);
const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf);
const Packet x_has_signbit = pcmp_eq(por(pand(x, cst_neg_inf), cst_pos_inf), cst_neg_inf);
const Packet x_is_neg = pandnot(x_has_signbit, abs_x_is_zero);
const Packet x_is_neg_zero = pand(x_has_signbit, abs_x_is_zero);
if (exponent == 0) {
return cst_pos_one;
}
Packet pow_is_pos_inf = pand(pandnot(abs_x_is_zero, x_is_neg_zero), pand(exp_is_odd, exp_is_neg));
pow_is_pos_inf = por(pow_is_pos_inf, pand(abs_x_is_zero, pandnot(exp_is_neg, exp_is_odd)));
pow_is_pos_inf = por(pow_is_pos_inf, pand(pand(abs_x_is_inf, x_is_neg), pandnot(pnot(exp_is_neg), exp_is_odd)));
pow_is_pos_inf = por(pow_is_pos_inf, pandnot(pandnot(abs_x_is_inf, x_is_neg), exp_is_neg));
Packet pow_is_neg_inf = pand(x_is_neg_zero, pand(exp_is_neg, exp_is_odd));
pow_is_neg_inf = por(pow_is_neg_inf, pand(pand(abs_x_is_inf, x_is_neg), pandnot(exp_is_odd, exp_is_neg)));
Packet pow_is_pos_zero = pandnot(abs_x_is_zero, exp_is_neg);
pow_is_pos_zero = por(pow_is_pos_zero, pand(pand(abs_x_is_inf, x_is_neg), pandnot(exp_is_neg, exp_is_odd)));
pow_is_pos_zero = por(pow_is_pos_zero, pand(pandnot(abs_x_is_inf, x_is_neg), exp_is_neg));
Packet pow_is_neg_zero = pand(x_is_neg_zero, pandnot(exp_is_odd, exp_is_neg));
pow_is_neg_zero = por(pow_is_neg_zero, pand(pand(abs_x_is_inf, x_is_neg), pand(exp_is_odd, exp_is_neg)));
Packet result = pselect(pow_is_neg_inf, cst_neg_inf, powx);
result = pselect(pow_is_neg_zero, cst_neg_zero, result);
result = pselect(pow_is_pos_zero, cst_pos_zero, result);
result = pselect(pow_is_pos_inf, cst_pos_inf, result);
result = pselect(pandnot(abs_x_is_one, x_is_neg), cst_pos_one, result);
return result;
}
template <typename Packet, typename ScalarExponent>
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_nonint_errors(const Packet& x, const Packet& powx,
const ScalarExponent& exponent) {
typedef typename unpacket_traits<Packet>::type Scalar;
// non-integer base and exponent case
const bool exponent_is_fin = (numext::isfinite)(exponent);
const bool exponent_is_nan = (numext::isnan)(exponent);
const bool exponent_is_neg = exponent < 0;
const bool exponent_is_inf = !exponent_is_fin && !exponent_is_nan;
const Packet exp_is_neg = exponent_is_neg ? ptrue(x) : pzero(x);
const Packet exp_is_inf = exponent_is_inf ? ptrue(x) : pzero(x);
const Scalar pos_zero = Scalar(0);
const Scalar pos_one = Scalar(1);
const Scalar pos_inf = NumTraits<Scalar>::infinity();
const Scalar neg_inf = -NumTraits<Scalar>::infinity();
const Scalar nan = NumTraits<Scalar>::quiet_NaN();
const Packet cst_pos_zero = pset1<Packet>(pos_zero);
const Packet cst_pos_one = pset1<Packet>(pos_one);
const Packet cst_pos_inf = pset1<Packet>(pos_inf);
const Packet cst_neg_inf = pset1<Packet>(neg_inf);
const Packet cst_nan = pset1<Packet>(nan);
const Packet abs_x = pabs(x);
const Packet abs_x_is_zero = pcmp_eq(abs_x, cst_pos_zero);
const Packet abs_x_is_lt_one = pcmp_lt(abs_x, cst_pos_one);
const Packet abs_x_is_gt_one = pcmp_lt(cst_pos_one, abs_x);
const Packet abs_x_is_one = pcmp_eq(abs_x, cst_pos_one);
const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf);
const Packet x_has_signbit = pcmp_eq(por(pand(x, cst_neg_inf), cst_pos_inf), cst_neg_inf);
const Packet x_is_neg = pandnot(x_has_signbit, abs_x_is_zero);
const Packet x_is_neg_zero = pand(x_has_signbit, abs_x_is_zero);
if (exponent_is_nan) {
return pselect(pandnot(abs_x_is_one, x_is_neg), cst_pos_one, cst_nan);
}
Packet pow_is_pos_zero = pandnot(abs_x_is_zero, exp_is_neg);
pow_is_pos_zero = por(pow_is_pos_zero, pand(abs_x_is_gt_one, pand(exp_is_inf, exp_is_neg)));
pow_is_pos_zero = por(pow_is_pos_zero, pand(abs_x_is_lt_one, pandnot(exp_is_inf, exp_is_neg)));
pow_is_pos_zero = por(pow_is_pos_zero, pand(abs_x_is_inf, exp_is_neg));
const Packet pow_is_pos_one = pand(abs_x_is_one, exp_is_inf);
Packet pow_is_pos_inf = pand(abs_x_is_zero, exp_is_neg);
pow_is_pos_inf = por(pow_is_pos_inf, pand(abs_x_is_lt_one, pand(exp_is_inf, exp_is_neg)));
pow_is_pos_inf = por(pow_is_pos_inf, pand(abs_x_is_gt_one, pandnot(exp_is_inf, exp_is_neg)));
pow_is_pos_inf = por(pow_is_pos_inf, pandnot(abs_x_is_inf, exp_is_neg));
const Packet pow_is_nan = pandnot(pandnot(x_is_neg, abs_x_is_inf), exp_is_inf);
Packet result = pselect(pow_is_pos_inf, cst_pos_inf, powx);
result = pselect(pow_is_pos_one, cst_pos_one, result);
result = pselect(pow_is_pos_zero, cst_pos_zero, result);
result = pselect(pow_is_nan, cst_nan, result);
result = pselect(pandnot(abs_x_is_one, x_is_neg), cst_pos_one, result);
return result;
}
} // end namespace unary_pow
template <typename Packet, typename ScalarExponent,
bool BaseIsIntegerType = NumTraits<typename unpacket_traits<Packet>::type>::IsInteger,
bool ExponentIsIntegerType = NumTraits<ScalarExponent>::IsInteger>
struct unary_pow_impl;
template <typename Packet, typename ScalarExponent>
struct unary_pow_impl<Packet, ScalarExponent, false, false> {
typedef typename unpacket_traits<Packet>::type Scalar;
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
const bool exponent_is_integer = (numext::isfinite)(exponent) && numext::round(exponent) == exponent;
if (exponent_is_integer) {
Packet result = unary_pow::int_pow(x, exponent);
result = unary_pow::handle_nonint_int_errors(x, result, exponent);
return result;
} else {
Packet result = unary_pow::gen_pow(x, exponent);
result = unary_pow::handle_nonint_nonint_errors(x, result, exponent);
return result;
}
}
};
template <typename Packet, typename ScalarExponent>
struct unary_pow_impl<Packet, ScalarExponent, false, true> {
typedef typename unpacket_traits<Packet>::type Scalar;
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
Packet result = unary_pow::int_pow(x, exponent);
result = unary_pow::handle_nonint_int_errors(x, result, exponent);
return result;
}
};
} // end namespace internal
} // end namespace Eigen

View File

@ -1070,6 +1070,70 @@ struct functor_traits<scalar_logistic_op<T> > {
};
};
template <typename Scalar, typename ScalarExponent, bool BaseIsIntegerType = NumTraits<Scalar>::IsInteger,
bool ExponentIsIntegerType = NumTraits<ScalarExponent>::IsInteger>
struct scalar_unary_pow_op {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) {
EIGEN_STATIC_ASSERT((is_arithmetic<ScalarExponent>::value), EXPONENT_MUST_BE_ARITHMETIC);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const {
EIGEN_USING_STD(pow);
return pow(a, m_exponent);
}
private:
const ScalarExponent m_exponent;
scalar_unary_pow_op() {}
};
template <typename Scalar, typename ScalarExponent>
struct scalar_unary_pow_op<Scalar, ScalarExponent, false, false> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) {
EIGEN_STATIC_ASSERT((is_same<Scalar, ScalarExponent>::value), NON_INTEGER_EXPONENT_MUST_BE_SAME_TYPE_AS_BASE);
EIGEN_STATIC_ASSERT((is_arithmetic<ScalarExponent>::value), EXPONENT_MUST_BE_ARITHMETIC);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const {
EIGEN_USING_STD(pow);
return pow(a, m_exponent);
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a) const {
return unary_pow_impl<Packet, ScalarExponent>::run(a, m_exponent);
}
private:
const ScalarExponent m_exponent;
scalar_unary_pow_op() {}
};
template <typename Scalar, typename ScalarExponent>
struct scalar_unary_pow_op<Scalar, ScalarExponent, false, true> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) {
EIGEN_STATIC_ASSERT((is_arithmetic<ScalarExponent>::value), EXPONENT_MUST_BE_ARITHMETIC);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const {
return unary_pow_impl<Scalar, ScalarExponent>::run(a, m_exponent);
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a) const {
return unary_pow_impl<Packet, ScalarExponent>::run(a, m_exponent);
}
private:
const ScalarExponent m_exponent;
scalar_unary_pow_op() {}
};
template <typename Scalar, typename ScalarExponent>
struct functor_traits<scalar_unary_pow_op<Scalar, ScalarExponent>> {
enum {
GenPacketAccess = functor_traits<scalar_pow_op<Scalar, ScalarExponent>>::PacketAccess,
IntPacketAccess = !NumTraits<Scalar>::IsComplex && !NumTraits<Scalar>::IsInteger && packet_traits<Scalar>::HasMul && packet_traits<Scalar>::HasDiv && packet_traits<Scalar>::HasCmp,
PacketAccess = NumTraits<ScalarExponent>::IsInteger ? IntPacketAccess : (IntPacketAccess && GenPacketAccess),
Cost = functor_traits<scalar_pow_op<Scalar, ScalarExponent>>::Cost
};
};
} // end namespace internal
} // end namespace Eigen

View File

@ -198,6 +198,8 @@ template<typename Scalar> struct scalar_constant_op;
template<typename Scalar> struct scalar_identity_op;
template<typename Scalar> struct scalar_sign_op;
template<typename Scalar,typename ScalarExponent> struct scalar_pow_op;
template <typename Scalar, typename ScalarExponent, bool BaseIsIntegerType, bool ExponentIsIntegerType>
struct scalar_unary_pow_op;
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_hypot_op;
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_product_op;
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_quotient_op;

View File

@ -134,26 +134,6 @@ absolute_difference
*/
EIGEN_MAKE_CWISE_BINARY_OP(pow,pow)
#ifndef EIGEN_PARSED_BY_DOXYGEN
EIGEN_MAKE_SCALAR_BINARY_OP_ONTHERIGHT(pow,pow)
#else
/** \returns an expression of the coefficients of \c *this rasied to the constant power \a exponent
*
* \tparam T is the scalar type of \a exponent. It must be compatible with the scalar type of the given expression.
*
* This function computes the coefficient-wise power. The function MatrixBase::pow() in the
* unsupported module MatrixFunctions computes the matrix power.
*
* Example: \include Cwise_pow.cpp
* Output: \verbinclude Cwise_pow.out
*
* \sa ArrayBase::pow(ArrayBase), square(), cube(), exp(), log()
*/
template<typename T>
const CwiseBinaryOp<internal::scalar_pow_op<Scalar,T>,Derived,Constant<T> > pow(const T& exponent) const;
#endif
// TODO code generating macros could be moved to Macros.h and could include generation of documentation
#define EIGEN_MAKE_CWISE_COMP_OP(OP, COMPARATOR) \
template<typename OtherDerived> \

View File

@ -694,3 +694,32 @@ ndtri() const
{
return NdtriReturnType(derived());
}
template <typename ScalarExponent>
using UnaryPowReturnType =
std::enable_if_t<internal::is_arithmetic<ScalarExponent>::value,
CwiseUnaryOp<internal::scalar_unary_pow_op<Scalar, ScalarExponent>, const Derived>>;
#ifndef EIGEN_PARSED_BY_DOXYGEN
template <typename ScalarExponent>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const UnaryPowReturnType<ScalarExponent> pow(
const ScalarExponent& exponent) const {
return UnaryPowReturnType<ScalarExponent>(derived(), internal::scalar_unary_pow_op<Scalar, ScalarExponent>(exponent));
#else
/** \returns an expression of the coefficients of \c *this rasied to the constant power \a exponent
*
* \tparam T is the scalar type of \a exponent. It must be compatible with the scalar type of the given expression.
*
* This function computes the coefficient-wise power. The function MatrixBase::pow() in the
* unsupported module MatrixFunctions computes the matrix power.
*
* Example: \include Cwise_pow.cpp
* Output: \verbinclude Cwise_pow.out
*
* \sa ArrayBase::pow(ArrayBase), square(), cube(), exp(), log()
*/
template <typename ScalarExponent>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const UnaryPowReturnType<ScalarExponent> pow(
const ScalarExponent& exponent) const;
#endif
}

View File

@ -93,3 +93,13 @@ EIGEN_DOC_UNARY_ADDONS(cwiseArg,arg)
EIGEN_DEVICE_FUNC
inline const CwiseArgReturnType
cwiseArg() const { return CwiseArgReturnType(derived()); }
template <typename ScalarExponent>
using CwisePowReturnType =
std::enable_if_t<internal::is_arithmetic<ScalarExponent>::value,
CwiseUnaryOp<internal::scalar_unary_pow_op<Scalar, ScalarExponent>, const Derived>>;
template <typename ScalarExponent>
EIGEN_DEVICE_FUNC inline const CwisePowReturnType<ScalarExponent> cwisePow(const ScalarExponent& exponent) const {
return CwisePowReturnType<ScalarExponent>(derived(), internal::scalar_unary_pow_op<Scalar, ScalarExponent>(exponent));
}

View File

@ -79,6 +79,50 @@ void pow_test() {
}
}
}
typedef typename internal::make_integer<Scalar>::type Int_t;
// ensure both vectorized and non-vectorized paths taken
Index test_size = 2 * internal::packet_traits<Scalar>::size + 1;
Array<Scalar, Dynamic, 1> eigenPow(test_size);
for (int i = 0; i < num_cases; ++i) {
Array<Scalar, Dynamic, 1> bases = x.col(i);
for (Scalar abs_exponent : abs_vals){
for (Scalar exponent : {-abs_exponent, abs_exponent}){
// test floating point exponent code path
eigenPow.setZero();
eigenPow = bases.pow(exponent);
for (int j = 0; j < num_repeats; j++){
Scalar e = static_cast<Scalar>(std::pow(bases(j), exponent));
Scalar a = eigenPow(j);
bool success = (a == e) || ((numext::isfinite)(e) && internal::isApprox(a, e, tol)) || ((numext::isnan)(a) && (numext::isnan)(e));
all_pass &= success;
if (!success) {
std::cout << "pow(" << x(i, j) << "," << y(i, j) << ") = " << a << " != " << e << std::endl;
}
}
// test integer exponent code path
bool exponent_is_integer = (numext::isfinite)(exponent) && (numext::round(exponent) == exponent) && (numext::abs(exponent) < static_cast<Scalar>(NumTraits<Int_t>::highest()));
if (exponent_is_integer)
{
Int_t exponent_as_int = static_cast<Int_t>(exponent);
eigenPow.setZero();
eigenPow = bases.pow(exponent_as_int);
for (int j = 0; j < num_repeats; j++){
Scalar e = static_cast<Scalar>(std::pow(bases(j), exponent));
Scalar a = eigenPow(j);
bool success = (a == e) || ((numext::isfinite)(e) && internal::isApprox(a, e, tol)) || ((numext::isnan)(a) && (numext::isnan)(e));
all_pass &= success;
if (!success) {
std::cout << "pow(" << x(i, j) << "," << y(i, j) << ") = " << a << " != " << e << std::endl;
}
}
}
}
}
}
VERIFY(all_pass);
}

View File

@ -331,10 +331,12 @@ class TensorBase<Derived, ReadOnlyAccessors>
return choose(Cond<NumTraits<CoeffReturnType>::IsComplex>(), unaryExpr(internal::scalar_conjugate_op<Scalar>()), derived());
}
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::bind2nd_op<internal::scalar_pow_op<Scalar,Scalar> >, const Derived>
pow(Scalar exponent) const {
return unaryExpr(internal::bind2nd_op<internal::scalar_pow_op<Scalar,Scalar> >(exponent));
template<typename ScalarExponent>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::enable_if_t<internal::is_arithmetic<ScalarExponent>::value,
TensorCwiseUnaryOp<internal::scalar_unary_pow_op<Scalar, ScalarExponent>, const Derived>>
pow(ScalarExponent exponent) const
{
return unaryExpr(internal::scalar_unary_pow_op<Scalar, ScalarExponent>(exponent));
}
EIGEN_DEVICE_FUNC