mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-13 20:26:03 +08:00
add fixed power unary operation
This commit is contained in:
parent
39fcc89798
commit
76a669fb45
@ -109,23 +109,25 @@ namespace Eigen
|
|||||||
*
|
*
|
||||||
* \relates ArrayBase
|
* \relates ArrayBase
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
template <typename Derived, typename ScalarExponent>
|
||||||
|
using GlobalUnaryPowReturnType = std::enable_if_t<
|
||||||
|
!internal::is_arithmetic<Derived>::value && internal::is_arithmetic<ScalarExponent>::value,
|
||||||
|
CwiseUnaryOp<internal::scalar_unary_pow_op<typename Derived::Scalar, ScalarExponent>, const Derived> >;
|
||||||
|
|
||||||
#ifdef EIGEN_PARSED_BY_DOXYGEN
|
#ifdef EIGEN_PARSED_BY_DOXYGEN
|
||||||
template<typename Derived,typename ScalarExponent>
|
template <typename Derived, typename ScalarExponent>
|
||||||
inline const CwiseBinaryOp<internal::scalar_pow_op<Derived::Scalar,ScalarExponent>,Derived,Constant<ScalarExponent> >
|
EIGEN_DEVICE_FUNC inline const GlobalUnaryPowReturnType<Derived, ScalarExponent>
|
||||||
pow(const Eigen::ArrayBase<Derived>& x, const ScalarExponent& exponent);
|
pow(const Eigen::ArrayBase<Derived>& x, const ScalarExponent& exponent);
|
||||||
#else
|
#else
|
||||||
template <typename Derived,typename ScalarExponent>
|
template <typename Derived, typename ScalarExponent>
|
||||||
EIGEN_DEVICE_FUNC inline
|
EIGEN_DEVICE_FUNC inline const typename std::enable_if<
|
||||||
const EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,typename internal::promote_scalar_arg<typename Derived::Scalar
|
!internal::is_arithmetic<Derived>::value && internal::is_arithmetic<ScalarExponent>::value,
|
||||||
EIGEN_COMMA ScalarExponent EIGEN_COMMA
|
CwiseUnaryOp<internal::scalar_unary_pow_op<typename Derived::Scalar, ScalarExponent>, const Derived> >::type
|
||||||
EIGEN_SCALAR_BINARY_SUPPORTED(pow,typename Derived::Scalar,ScalarExponent)>::type,pow)
|
pow(const Eigen::ArrayBase<Derived>& x, const ScalarExponent& exponent) {
|
||||||
pow(const Eigen::ArrayBase<Derived>& x, const ScalarExponent& exponent)
|
return CwiseUnaryOp<internal::scalar_unary_pow_op<typename Derived::Scalar, ScalarExponent>, const Derived>(
|
||||||
{
|
x.derived(), internal::scalar_unary_pow_op<typename Derived::Scalar, ScalarExponent>(exponent));
|
||||||
typedef typename internal::promote_scalar_arg<typename Derived::Scalar,ScalarExponent,
|
}
|
||||||
EIGEN_SCALAR_BINARY_SUPPORTED(pow,typename Derived::Scalar,ScalarExponent)>::type PromotedExponent;
|
|
||||||
return EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,PromotedExponent,pow)(x.derived(),
|
|
||||||
typename internal::plain_constant_type<Derived,PromotedExponent>::type(x.derived().rows(), x.derived().cols(), internal::scalar_constant_op<PromotedExponent>(exponent)));
|
|
||||||
}
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/** \returns an expression of the coefficient-wise power of \a x to the given array of \a exponents.
|
/** \returns an expression of the coefficient-wise power of \a x to the given array of \a exponents.
|
||||||
|
@ -1690,6 +1690,225 @@ struct pchebevl {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
namespace unary_pow {
|
||||||
|
template <typename ScalarExponent, bool IsIntegerAtCompileTime = NumTraits<ScalarExponent>::IsInteger>
|
||||||
|
struct is_odd {
|
||||||
|
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ScalarExponent run(const ScalarExponent& x) {
|
||||||
|
ScalarExponent xdiv2 = x / ScalarExponent(2);
|
||||||
|
ScalarExponent floorxdiv2 = numext::floor(xdiv2);
|
||||||
|
return xdiv2 != floorxdiv2;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
template <typename ScalarExponent>
|
||||||
|
struct is_odd<ScalarExponent, true> {
|
||||||
|
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ScalarExponent run(const ScalarExponent& x) {
|
||||||
|
return x % ScalarExponent(2);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Packet, typename ScalarExponent,
|
||||||
|
bool BaseIsIntegerType = NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>
|
||||||
|
struct do_div {
|
||||||
|
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
|
||||||
|
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||||
|
const Packet cst_pos_one = pset1<Packet>(Scalar(1));
|
||||||
|
return exponent < 0 ? pdiv(cst_pos_one, x) : x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Packet, typename ScalarExponent>
|
||||||
|
struct do_div<Packet, ScalarExponent, true> {
|
||||||
|
// pdiv not defined, nor necessary for integer base types
|
||||||
|
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
|
||||||
|
EIGEN_UNUSED_VARIABLE(exponent);
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Packet, typename ScalarExponent>
|
||||||
|
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet int_pow(const Packet& x, const ScalarExponent& exponent) {
|
||||||
|
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||||
|
const Packet cst_pos_one = pset1<Packet>(Scalar(1));
|
||||||
|
if (exponent == 0) return cst_pos_one;
|
||||||
|
Packet result = x;
|
||||||
|
Packet y = cst_pos_one;
|
||||||
|
ScalarExponent m = numext::abs(exponent);
|
||||||
|
while (m > 1) {
|
||||||
|
bool odd = is_odd<ScalarExponent>::run(m);
|
||||||
|
if (odd) y = pmul(y, result);
|
||||||
|
result = pmul(result, result);
|
||||||
|
m = numext::floor(m / ScalarExponent(2));
|
||||||
|
}
|
||||||
|
result = pmul(y, result);
|
||||||
|
result = do_div<Packet, ScalarExponent>::run(result, exponent);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Packet>
|
||||||
|
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet gen_pow(const Packet& x,
|
||||||
|
const typename unpacket_traits<Packet>::type& exponent) {
|
||||||
|
const Packet exponent_packet = pset1<Packet>(exponent);
|
||||||
|
return generic_pow_impl(x, exponent_packet);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Packet, typename ScalarExponent>
|
||||||
|
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_int_errors(const Packet& x, const Packet& powx,
|
||||||
|
const ScalarExponent& exponent) {
|
||||||
|
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||||
|
|
||||||
|
// non-integer base, integer exponent case
|
||||||
|
|
||||||
|
const bool exponent_is_odd = is_odd<ScalarExponent>::run(exponent);
|
||||||
|
const bool exponent_is_neg = exponent < 0;
|
||||||
|
|
||||||
|
const Packet exp_is_odd = exponent_is_odd ? ptrue(x) : pzero(x);
|
||||||
|
const Packet exp_is_neg = exponent_is_neg ? ptrue(x) : pzero(x);
|
||||||
|
|
||||||
|
const Scalar pos_zero = Scalar(0);
|
||||||
|
const Scalar neg_zero = -Scalar(0);
|
||||||
|
const Scalar pos_one = Scalar(1);
|
||||||
|
const Scalar pos_inf = NumTraits<Scalar>::infinity();
|
||||||
|
const Scalar neg_inf = -NumTraits<Scalar>::infinity();
|
||||||
|
|
||||||
|
const Packet cst_pos_zero = pset1<Packet>(pos_zero);
|
||||||
|
const Packet cst_neg_zero = pset1<Packet>(neg_zero);
|
||||||
|
const Packet cst_pos_one = pset1<Packet>(pos_one);
|
||||||
|
const Packet cst_pos_inf = pset1<Packet>(pos_inf);
|
||||||
|
const Packet cst_neg_inf = pset1<Packet>(neg_inf);
|
||||||
|
|
||||||
|
const Packet abs_x = pabs(x);
|
||||||
|
const Packet abs_x_is_zero = pcmp_eq(abs_x, cst_pos_zero);
|
||||||
|
const Packet abs_x_is_one = pcmp_eq(abs_x, cst_pos_one);
|
||||||
|
const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf);
|
||||||
|
|
||||||
|
const Packet x_has_signbit = pcmp_eq(por(pand(x, cst_neg_inf), cst_pos_inf), cst_neg_inf);
|
||||||
|
const Packet x_is_neg = pandnot(x_has_signbit, abs_x_is_zero);
|
||||||
|
const Packet x_is_neg_zero = pand(x_has_signbit, abs_x_is_zero);
|
||||||
|
|
||||||
|
if (exponent == 0) {
|
||||||
|
return cst_pos_one;
|
||||||
|
}
|
||||||
|
|
||||||
|
Packet pow_is_pos_inf = pand(pandnot(abs_x_is_zero, x_is_neg_zero), pand(exp_is_odd, exp_is_neg));
|
||||||
|
pow_is_pos_inf = por(pow_is_pos_inf, pand(abs_x_is_zero, pandnot(exp_is_neg, exp_is_odd)));
|
||||||
|
pow_is_pos_inf = por(pow_is_pos_inf, pand(pand(abs_x_is_inf, x_is_neg), pandnot(pnot(exp_is_neg), exp_is_odd)));
|
||||||
|
pow_is_pos_inf = por(pow_is_pos_inf, pandnot(pandnot(abs_x_is_inf, x_is_neg), exp_is_neg));
|
||||||
|
|
||||||
|
Packet pow_is_neg_inf = pand(x_is_neg_zero, pand(exp_is_neg, exp_is_odd));
|
||||||
|
pow_is_neg_inf = por(pow_is_neg_inf, pand(pand(abs_x_is_inf, x_is_neg), pandnot(exp_is_odd, exp_is_neg)));
|
||||||
|
|
||||||
|
Packet pow_is_pos_zero = pandnot(abs_x_is_zero, exp_is_neg);
|
||||||
|
pow_is_pos_zero = por(pow_is_pos_zero, pand(pand(abs_x_is_inf, x_is_neg), pandnot(exp_is_neg, exp_is_odd)));
|
||||||
|
pow_is_pos_zero = por(pow_is_pos_zero, pand(pandnot(abs_x_is_inf, x_is_neg), exp_is_neg));
|
||||||
|
|
||||||
|
Packet pow_is_neg_zero = pand(x_is_neg_zero, pandnot(exp_is_odd, exp_is_neg));
|
||||||
|
pow_is_neg_zero = por(pow_is_neg_zero, pand(pand(abs_x_is_inf, x_is_neg), pand(exp_is_odd, exp_is_neg)));
|
||||||
|
|
||||||
|
Packet result = pselect(pow_is_neg_inf, cst_neg_inf, powx);
|
||||||
|
result = pselect(pow_is_neg_zero, cst_neg_zero, result);
|
||||||
|
result = pselect(pow_is_pos_zero, cst_pos_zero, result);
|
||||||
|
result = pselect(pow_is_pos_inf, cst_pos_inf, result);
|
||||||
|
result = pselect(pandnot(abs_x_is_one, x_is_neg), cst_pos_one, result);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Packet, typename ScalarExponent>
|
||||||
|
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_nonint_errors(const Packet& x, const Packet& powx,
|
||||||
|
const ScalarExponent& exponent) {
|
||||||
|
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||||
|
|
||||||
|
// non-integer base and exponent case
|
||||||
|
|
||||||
|
const bool exponent_is_fin = (numext::isfinite)(exponent);
|
||||||
|
const bool exponent_is_nan = (numext::isnan)(exponent);
|
||||||
|
const bool exponent_is_neg = exponent < 0;
|
||||||
|
const bool exponent_is_inf = !exponent_is_fin && !exponent_is_nan;
|
||||||
|
|
||||||
|
const Packet exp_is_neg = exponent_is_neg ? ptrue(x) : pzero(x);
|
||||||
|
const Packet exp_is_inf = exponent_is_inf ? ptrue(x) : pzero(x);
|
||||||
|
|
||||||
|
const Scalar pos_zero = Scalar(0);
|
||||||
|
const Scalar pos_one = Scalar(1);
|
||||||
|
const Scalar pos_inf = NumTraits<Scalar>::infinity();
|
||||||
|
const Scalar neg_inf = -NumTraits<Scalar>::infinity();
|
||||||
|
const Scalar nan = NumTraits<Scalar>::quiet_NaN();
|
||||||
|
|
||||||
|
const Packet cst_pos_zero = pset1<Packet>(pos_zero);
|
||||||
|
const Packet cst_pos_one = pset1<Packet>(pos_one);
|
||||||
|
const Packet cst_pos_inf = pset1<Packet>(pos_inf);
|
||||||
|
const Packet cst_neg_inf = pset1<Packet>(neg_inf);
|
||||||
|
const Packet cst_nan = pset1<Packet>(nan);
|
||||||
|
|
||||||
|
const Packet abs_x = pabs(x);
|
||||||
|
const Packet abs_x_is_zero = pcmp_eq(abs_x, cst_pos_zero);
|
||||||
|
const Packet abs_x_is_lt_one = pcmp_lt(abs_x, cst_pos_one);
|
||||||
|
const Packet abs_x_is_gt_one = pcmp_lt(cst_pos_one, abs_x);
|
||||||
|
const Packet abs_x_is_one = pcmp_eq(abs_x, cst_pos_one);
|
||||||
|
const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf);
|
||||||
|
|
||||||
|
const Packet x_has_signbit = pcmp_eq(por(pand(x, cst_neg_inf), cst_pos_inf), cst_neg_inf);
|
||||||
|
const Packet x_is_neg = pandnot(x_has_signbit, abs_x_is_zero);
|
||||||
|
const Packet x_is_neg_zero = pand(x_has_signbit, abs_x_is_zero);
|
||||||
|
|
||||||
|
if (exponent_is_nan) {
|
||||||
|
return pselect(pandnot(abs_x_is_one, x_is_neg), cst_pos_one, cst_nan);
|
||||||
|
}
|
||||||
|
|
||||||
|
Packet pow_is_pos_zero = pandnot(abs_x_is_zero, exp_is_neg);
|
||||||
|
pow_is_pos_zero = por(pow_is_pos_zero, pand(abs_x_is_gt_one, pand(exp_is_inf, exp_is_neg)));
|
||||||
|
pow_is_pos_zero = por(pow_is_pos_zero, pand(abs_x_is_lt_one, pandnot(exp_is_inf, exp_is_neg)));
|
||||||
|
pow_is_pos_zero = por(pow_is_pos_zero, pand(abs_x_is_inf, exp_is_neg));
|
||||||
|
|
||||||
|
const Packet pow_is_pos_one = pand(abs_x_is_one, exp_is_inf);
|
||||||
|
|
||||||
|
Packet pow_is_pos_inf = pand(abs_x_is_zero, exp_is_neg);
|
||||||
|
pow_is_pos_inf = por(pow_is_pos_inf, pand(abs_x_is_lt_one, pand(exp_is_inf, exp_is_neg)));
|
||||||
|
pow_is_pos_inf = por(pow_is_pos_inf, pand(abs_x_is_gt_one, pandnot(exp_is_inf, exp_is_neg)));
|
||||||
|
pow_is_pos_inf = por(pow_is_pos_inf, pandnot(abs_x_is_inf, exp_is_neg));
|
||||||
|
|
||||||
|
const Packet pow_is_nan = pandnot(pandnot(x_is_neg, abs_x_is_inf), exp_is_inf);
|
||||||
|
|
||||||
|
Packet result = pselect(pow_is_pos_inf, cst_pos_inf, powx);
|
||||||
|
result = pselect(pow_is_pos_one, cst_pos_one, result);
|
||||||
|
result = pselect(pow_is_pos_zero, cst_pos_zero, result);
|
||||||
|
result = pselect(pow_is_nan, cst_nan, result);
|
||||||
|
result = pselect(pandnot(abs_x_is_one, x_is_neg), cst_pos_one, result);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
} // end namespace unary_pow
|
||||||
|
|
||||||
|
template <typename Packet, typename ScalarExponent,
|
||||||
|
bool BaseIsIntegerType = NumTraits<typename unpacket_traits<Packet>::type>::IsInteger,
|
||||||
|
bool ExponentIsIntegerType = NumTraits<ScalarExponent>::IsInteger>
|
||||||
|
struct unary_pow_impl;
|
||||||
|
|
||||||
|
template <typename Packet, typename ScalarExponent>
|
||||||
|
struct unary_pow_impl<Packet, ScalarExponent, false, false> {
|
||||||
|
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||||
|
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
|
||||||
|
const bool exponent_is_integer = (numext::isfinite)(exponent) && numext::round(exponent) == exponent;
|
||||||
|
if (exponent_is_integer) {
|
||||||
|
Packet result = unary_pow::int_pow(x, exponent);
|
||||||
|
result = unary_pow::handle_nonint_int_errors(x, result, exponent);
|
||||||
|
return result;
|
||||||
|
} else {
|
||||||
|
Packet result = unary_pow::gen_pow(x, exponent);
|
||||||
|
result = unary_pow::handle_nonint_nonint_errors(x, result, exponent);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Packet, typename ScalarExponent>
|
||||||
|
struct unary_pow_impl<Packet, ScalarExponent, false, true> {
|
||||||
|
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||||
|
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
|
||||||
|
Packet result = unary_pow::int_pow(x, exponent);
|
||||||
|
result = unary_pow::handle_nonint_int_errors(x, result, exponent);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
|
||||||
|
@ -1070,6 +1070,70 @@ struct functor_traits<scalar_logistic_op<T> > {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename Scalar, typename ScalarExponent, bool BaseIsIntegerType = NumTraits<Scalar>::IsInteger,
|
||||||
|
bool ExponentIsIntegerType = NumTraits<ScalarExponent>::IsInteger>
|
||||||
|
struct scalar_unary_pow_op {
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) {
|
||||||
|
EIGEN_STATIC_ASSERT((is_arithmetic<ScalarExponent>::value), EXPONENT_MUST_BE_ARITHMETIC);
|
||||||
|
}
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const {
|
||||||
|
EIGEN_USING_STD(pow);
|
||||||
|
return pow(a, m_exponent);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const ScalarExponent m_exponent;
|
||||||
|
scalar_unary_pow_op() {}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Scalar, typename ScalarExponent>
|
||||||
|
struct scalar_unary_pow_op<Scalar, ScalarExponent, false, false> {
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) {
|
||||||
|
EIGEN_STATIC_ASSERT((is_same<Scalar, ScalarExponent>::value), NON_INTEGER_EXPONENT_MUST_BE_SAME_TYPE_AS_BASE);
|
||||||
|
EIGEN_STATIC_ASSERT((is_arithmetic<ScalarExponent>::value), EXPONENT_MUST_BE_ARITHMETIC);
|
||||||
|
}
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const {
|
||||||
|
EIGEN_USING_STD(pow);
|
||||||
|
return pow(a, m_exponent);
|
||||||
|
}
|
||||||
|
template <typename Packet>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a) const {
|
||||||
|
return unary_pow_impl<Packet, ScalarExponent>::run(a, m_exponent);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const ScalarExponent m_exponent;
|
||||||
|
scalar_unary_pow_op() {}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Scalar, typename ScalarExponent>
|
||||||
|
struct scalar_unary_pow_op<Scalar, ScalarExponent, false, true> {
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) {
|
||||||
|
EIGEN_STATIC_ASSERT((is_arithmetic<ScalarExponent>::value), EXPONENT_MUST_BE_ARITHMETIC);
|
||||||
|
}
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const {
|
||||||
|
return unary_pow_impl<Scalar, ScalarExponent>::run(a, m_exponent);
|
||||||
|
}
|
||||||
|
template <typename Packet>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a) const {
|
||||||
|
return unary_pow_impl<Packet, ScalarExponent>::run(a, m_exponent);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const ScalarExponent m_exponent;
|
||||||
|
scalar_unary_pow_op() {}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Scalar, typename ScalarExponent>
|
||||||
|
struct functor_traits<scalar_unary_pow_op<Scalar, ScalarExponent>> {
|
||||||
|
enum {
|
||||||
|
GenPacketAccess = functor_traits<scalar_pow_op<Scalar, ScalarExponent>>::PacketAccess,
|
||||||
|
IntPacketAccess = !NumTraits<Scalar>::IsComplex && !NumTraits<Scalar>::IsInteger && packet_traits<Scalar>::HasMul && packet_traits<Scalar>::HasDiv && packet_traits<Scalar>::HasCmp,
|
||||||
|
PacketAccess = NumTraits<ScalarExponent>::IsInteger ? IntPacketAccess : (IntPacketAccess && GenPacketAccess),
|
||||||
|
Cost = functor_traits<scalar_pow_op<Scalar, ScalarExponent>>::Cost
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
@ -198,6 +198,8 @@ template<typename Scalar> struct scalar_constant_op;
|
|||||||
template<typename Scalar> struct scalar_identity_op;
|
template<typename Scalar> struct scalar_identity_op;
|
||||||
template<typename Scalar> struct scalar_sign_op;
|
template<typename Scalar> struct scalar_sign_op;
|
||||||
template<typename Scalar,typename ScalarExponent> struct scalar_pow_op;
|
template<typename Scalar,typename ScalarExponent> struct scalar_pow_op;
|
||||||
|
template <typename Scalar, typename ScalarExponent, bool BaseIsIntegerType, bool ExponentIsIntegerType>
|
||||||
|
struct scalar_unary_pow_op;
|
||||||
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_hypot_op;
|
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_hypot_op;
|
||||||
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_product_op;
|
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_product_op;
|
||||||
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_quotient_op;
|
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_quotient_op;
|
||||||
|
@ -134,26 +134,6 @@ absolute_difference
|
|||||||
*/
|
*/
|
||||||
EIGEN_MAKE_CWISE_BINARY_OP(pow,pow)
|
EIGEN_MAKE_CWISE_BINARY_OP(pow,pow)
|
||||||
|
|
||||||
#ifndef EIGEN_PARSED_BY_DOXYGEN
|
|
||||||
EIGEN_MAKE_SCALAR_BINARY_OP_ONTHERIGHT(pow,pow)
|
|
||||||
#else
|
|
||||||
/** \returns an expression of the coefficients of \c *this rasied to the constant power \a exponent
|
|
||||||
*
|
|
||||||
* \tparam T is the scalar type of \a exponent. It must be compatible with the scalar type of the given expression.
|
|
||||||
*
|
|
||||||
* This function computes the coefficient-wise power. The function MatrixBase::pow() in the
|
|
||||||
* unsupported module MatrixFunctions computes the matrix power.
|
|
||||||
*
|
|
||||||
* Example: \include Cwise_pow.cpp
|
|
||||||
* Output: \verbinclude Cwise_pow.out
|
|
||||||
*
|
|
||||||
* \sa ArrayBase::pow(ArrayBase), square(), cube(), exp(), log()
|
|
||||||
*/
|
|
||||||
template<typename T>
|
|
||||||
const CwiseBinaryOp<internal::scalar_pow_op<Scalar,T>,Derived,Constant<T> > pow(const T& exponent) const;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
|
|
||||||
// TODO code generating macros could be moved to Macros.h and could include generation of documentation
|
// TODO code generating macros could be moved to Macros.h and could include generation of documentation
|
||||||
#define EIGEN_MAKE_CWISE_COMP_OP(OP, COMPARATOR) \
|
#define EIGEN_MAKE_CWISE_COMP_OP(OP, COMPARATOR) \
|
||||||
template<typename OtherDerived> \
|
template<typename OtherDerived> \
|
||||||
|
@ -694,3 +694,32 @@ ndtri() const
|
|||||||
{
|
{
|
||||||
return NdtriReturnType(derived());
|
return NdtriReturnType(derived());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename ScalarExponent>
|
||||||
|
using UnaryPowReturnType =
|
||||||
|
std::enable_if_t<internal::is_arithmetic<ScalarExponent>::value,
|
||||||
|
CwiseUnaryOp<internal::scalar_unary_pow_op<Scalar, ScalarExponent>, const Derived>>;
|
||||||
|
|
||||||
|
#ifndef EIGEN_PARSED_BY_DOXYGEN
|
||||||
|
template <typename ScalarExponent>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const UnaryPowReturnType<ScalarExponent> pow(
|
||||||
|
const ScalarExponent& exponent) const {
|
||||||
|
return UnaryPowReturnType<ScalarExponent>(derived(), internal::scalar_unary_pow_op<Scalar, ScalarExponent>(exponent));
|
||||||
|
#else
|
||||||
|
/** \returns an expression of the coefficients of \c *this rasied to the constant power \a exponent
|
||||||
|
*
|
||||||
|
* \tparam T is the scalar type of \a exponent. It must be compatible with the scalar type of the given expression.
|
||||||
|
*
|
||||||
|
* This function computes the coefficient-wise power. The function MatrixBase::pow() in the
|
||||||
|
* unsupported module MatrixFunctions computes the matrix power.
|
||||||
|
*
|
||||||
|
* Example: \include Cwise_pow.cpp
|
||||||
|
* Output: \verbinclude Cwise_pow.out
|
||||||
|
*
|
||||||
|
* \sa ArrayBase::pow(ArrayBase), square(), cube(), exp(), log()
|
||||||
|
*/
|
||||||
|
template <typename ScalarExponent>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const UnaryPowReturnType<ScalarExponent> pow(
|
||||||
|
const ScalarExponent& exponent) const;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
@ -93,3 +93,13 @@ EIGEN_DOC_UNARY_ADDONS(cwiseArg,arg)
|
|||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
inline const CwiseArgReturnType
|
inline const CwiseArgReturnType
|
||||||
cwiseArg() const { return CwiseArgReturnType(derived()); }
|
cwiseArg() const { return CwiseArgReturnType(derived()); }
|
||||||
|
|
||||||
|
template <typename ScalarExponent>
|
||||||
|
using CwisePowReturnType =
|
||||||
|
std::enable_if_t<internal::is_arithmetic<ScalarExponent>::value,
|
||||||
|
CwiseUnaryOp<internal::scalar_unary_pow_op<Scalar, ScalarExponent>, const Derived>>;
|
||||||
|
|
||||||
|
template <typename ScalarExponent>
|
||||||
|
EIGEN_DEVICE_FUNC inline const CwisePowReturnType<ScalarExponent> cwisePow(const ScalarExponent& exponent) const {
|
||||||
|
return CwisePowReturnType<ScalarExponent>(derived(), internal::scalar_unary_pow_op<Scalar, ScalarExponent>(exponent));
|
||||||
|
}
|
||||||
|
@ -79,6 +79,50 @@ void pow_test() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
typedef typename internal::make_integer<Scalar>::type Int_t;
|
||||||
|
|
||||||
|
// ensure both vectorized and non-vectorized paths taken
|
||||||
|
Index test_size = 2 * internal::packet_traits<Scalar>::size + 1;
|
||||||
|
|
||||||
|
Array<Scalar, Dynamic, 1> eigenPow(test_size);
|
||||||
|
for (int i = 0; i < num_cases; ++i) {
|
||||||
|
Array<Scalar, Dynamic, 1> bases = x.col(i);
|
||||||
|
for (Scalar abs_exponent : abs_vals){
|
||||||
|
for (Scalar exponent : {-abs_exponent, abs_exponent}){
|
||||||
|
// test floating point exponent code path
|
||||||
|
eigenPow.setZero();
|
||||||
|
eigenPow = bases.pow(exponent);
|
||||||
|
for (int j = 0; j < num_repeats; j++){
|
||||||
|
Scalar e = static_cast<Scalar>(std::pow(bases(j), exponent));
|
||||||
|
Scalar a = eigenPow(j);
|
||||||
|
bool success = (a == e) || ((numext::isfinite)(e) && internal::isApprox(a, e, tol)) || ((numext::isnan)(a) && (numext::isnan)(e));
|
||||||
|
all_pass &= success;
|
||||||
|
if (!success) {
|
||||||
|
std::cout << "pow(" << x(i, j) << "," << y(i, j) << ") = " << a << " != " << e << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// test integer exponent code path
|
||||||
|
bool exponent_is_integer = (numext::isfinite)(exponent) && (numext::round(exponent) == exponent) && (numext::abs(exponent) < static_cast<Scalar>(NumTraits<Int_t>::highest()));
|
||||||
|
if (exponent_is_integer)
|
||||||
|
{
|
||||||
|
Int_t exponent_as_int = static_cast<Int_t>(exponent);
|
||||||
|
eigenPow.setZero();
|
||||||
|
eigenPow = bases.pow(exponent_as_int);
|
||||||
|
for (int j = 0; j < num_repeats; j++){
|
||||||
|
Scalar e = static_cast<Scalar>(std::pow(bases(j), exponent));
|
||||||
|
Scalar a = eigenPow(j);
|
||||||
|
bool success = (a == e) || ((numext::isfinite)(e) && internal::isApprox(a, e, tol)) || ((numext::isnan)(a) && (numext::isnan)(e));
|
||||||
|
all_pass &= success;
|
||||||
|
if (!success) {
|
||||||
|
std::cout << "pow(" << x(i, j) << "," << y(i, j) << ") = " << a << " != " << e << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
VERIFY(all_pass);
|
VERIFY(all_pass);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -331,10 +331,12 @@ class TensorBase<Derived, ReadOnlyAccessors>
|
|||||||
return choose(Cond<NumTraits<CoeffReturnType>::IsComplex>(), unaryExpr(internal::scalar_conjugate_op<Scalar>()), derived());
|
return choose(Cond<NumTraits<CoeffReturnType>::IsComplex>(), unaryExpr(internal::scalar_conjugate_op<Scalar>()), derived());
|
||||||
}
|
}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC
|
template<typename ScalarExponent>
|
||||||
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::bind2nd_op<internal::scalar_pow_op<Scalar,Scalar> >, const Derived>
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::enable_if_t<internal::is_arithmetic<ScalarExponent>::value,
|
||||||
pow(Scalar exponent) const {
|
TensorCwiseUnaryOp<internal::scalar_unary_pow_op<Scalar, ScalarExponent>, const Derived>>
|
||||||
return unaryExpr(internal::bind2nd_op<internal::scalar_pow_op<Scalar,Scalar> >(exponent));
|
pow(ScalarExponent exponent) const
|
||||||
|
{
|
||||||
|
return unaryExpr(internal::scalar_unary_pow_op<Scalar, ScalarExponent>(exponent));
|
||||||
}
|
}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
|
Loading…
x
Reference in New Issue
Block a user