mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-11 03:09:01 +08:00
add fixed power unary operation
This commit is contained in:
parent
39fcc89798
commit
76a669fb45
@ -109,23 +109,25 @@ namespace Eigen
|
||||
*
|
||||
* \relates ArrayBase
|
||||
*/
|
||||
|
||||
template <typename Derived, typename ScalarExponent>
|
||||
using GlobalUnaryPowReturnType = std::enable_if_t<
|
||||
!internal::is_arithmetic<Derived>::value && internal::is_arithmetic<ScalarExponent>::value,
|
||||
CwiseUnaryOp<internal::scalar_unary_pow_op<typename Derived::Scalar, ScalarExponent>, const Derived> >;
|
||||
|
||||
#ifdef EIGEN_PARSED_BY_DOXYGEN
|
||||
template<typename Derived,typename ScalarExponent>
|
||||
inline const CwiseBinaryOp<internal::scalar_pow_op<Derived::Scalar,ScalarExponent>,Derived,Constant<ScalarExponent> >
|
||||
pow(const Eigen::ArrayBase<Derived>& x, const ScalarExponent& exponent);
|
||||
template <typename Derived, typename ScalarExponent>
|
||||
EIGEN_DEVICE_FUNC inline const GlobalUnaryPowReturnType<Derived, ScalarExponent>
|
||||
pow(const Eigen::ArrayBase<Derived>& x, const ScalarExponent& exponent);
|
||||
#else
|
||||
template <typename Derived,typename ScalarExponent>
|
||||
EIGEN_DEVICE_FUNC inline
|
||||
const EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,typename internal::promote_scalar_arg<typename Derived::Scalar
|
||||
EIGEN_COMMA ScalarExponent EIGEN_COMMA
|
||||
EIGEN_SCALAR_BINARY_SUPPORTED(pow,typename Derived::Scalar,ScalarExponent)>::type,pow)
|
||||
pow(const Eigen::ArrayBase<Derived>& x, const ScalarExponent& exponent)
|
||||
{
|
||||
typedef typename internal::promote_scalar_arg<typename Derived::Scalar,ScalarExponent,
|
||||
EIGEN_SCALAR_BINARY_SUPPORTED(pow,typename Derived::Scalar,ScalarExponent)>::type PromotedExponent;
|
||||
return EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,PromotedExponent,pow)(x.derived(),
|
||||
typename internal::plain_constant_type<Derived,PromotedExponent>::type(x.derived().rows(), x.derived().cols(), internal::scalar_constant_op<PromotedExponent>(exponent)));
|
||||
}
|
||||
template <typename Derived, typename ScalarExponent>
|
||||
EIGEN_DEVICE_FUNC inline const typename std::enable_if<
|
||||
!internal::is_arithmetic<Derived>::value && internal::is_arithmetic<ScalarExponent>::value,
|
||||
CwiseUnaryOp<internal::scalar_unary_pow_op<typename Derived::Scalar, ScalarExponent>, const Derived> >::type
|
||||
pow(const Eigen::ArrayBase<Derived>& x, const ScalarExponent& exponent) {
|
||||
return CwiseUnaryOp<internal::scalar_unary_pow_op<typename Derived::Scalar, ScalarExponent>, const Derived>(
|
||||
x.derived(), internal::scalar_unary_pow_op<typename Derived::Scalar, ScalarExponent>(exponent));
|
||||
}
|
||||
#endif
|
||||
|
||||
/** \returns an expression of the coefficient-wise power of \a x to the given array of \a exponents.
|
||||
|
@ -1690,6 +1690,225 @@ struct pchebevl {
|
||||
}
|
||||
};
|
||||
|
||||
namespace unary_pow {
|
||||
template <typename ScalarExponent, bool IsIntegerAtCompileTime = NumTraits<ScalarExponent>::IsInteger>
|
||||
struct is_odd {
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ScalarExponent run(const ScalarExponent& x) {
|
||||
ScalarExponent xdiv2 = x / ScalarExponent(2);
|
||||
ScalarExponent floorxdiv2 = numext::floor(xdiv2);
|
||||
return xdiv2 != floorxdiv2;
|
||||
}
|
||||
};
|
||||
template <typename ScalarExponent>
|
||||
struct is_odd<ScalarExponent, true> {
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ScalarExponent run(const ScalarExponent& x) {
|
||||
return x % ScalarExponent(2);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Packet, typename ScalarExponent,
|
||||
bool BaseIsIntegerType = NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>
|
||||
struct do_div {
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
const Packet cst_pos_one = pset1<Packet>(Scalar(1));
|
||||
return exponent < 0 ? pdiv(cst_pos_one, x) : x;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Packet, typename ScalarExponent>
|
||||
struct do_div<Packet, ScalarExponent, true> {
|
||||
// pdiv not defined, nor necessary for integer base types
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
|
||||
EIGEN_UNUSED_VARIABLE(exponent);
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Packet, typename ScalarExponent>
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet int_pow(const Packet& x, const ScalarExponent& exponent) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
const Packet cst_pos_one = pset1<Packet>(Scalar(1));
|
||||
if (exponent == 0) return cst_pos_one;
|
||||
Packet result = x;
|
||||
Packet y = cst_pos_one;
|
||||
ScalarExponent m = numext::abs(exponent);
|
||||
while (m > 1) {
|
||||
bool odd = is_odd<ScalarExponent>::run(m);
|
||||
if (odd) y = pmul(y, result);
|
||||
result = pmul(result, result);
|
||||
m = numext::floor(m / ScalarExponent(2));
|
||||
}
|
||||
result = pmul(y, result);
|
||||
result = do_div<Packet, ScalarExponent>::run(result, exponent);
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename Packet>
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet gen_pow(const Packet& x,
|
||||
const typename unpacket_traits<Packet>::type& exponent) {
|
||||
const Packet exponent_packet = pset1<Packet>(exponent);
|
||||
return generic_pow_impl(x, exponent_packet);
|
||||
}
|
||||
|
||||
template <typename Packet, typename ScalarExponent>
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_int_errors(const Packet& x, const Packet& powx,
|
||||
const ScalarExponent& exponent) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
|
||||
// non-integer base, integer exponent case
|
||||
|
||||
const bool exponent_is_odd = is_odd<ScalarExponent>::run(exponent);
|
||||
const bool exponent_is_neg = exponent < 0;
|
||||
|
||||
const Packet exp_is_odd = exponent_is_odd ? ptrue(x) : pzero(x);
|
||||
const Packet exp_is_neg = exponent_is_neg ? ptrue(x) : pzero(x);
|
||||
|
||||
const Scalar pos_zero = Scalar(0);
|
||||
const Scalar neg_zero = -Scalar(0);
|
||||
const Scalar pos_one = Scalar(1);
|
||||
const Scalar pos_inf = NumTraits<Scalar>::infinity();
|
||||
const Scalar neg_inf = -NumTraits<Scalar>::infinity();
|
||||
|
||||
const Packet cst_pos_zero = pset1<Packet>(pos_zero);
|
||||
const Packet cst_neg_zero = pset1<Packet>(neg_zero);
|
||||
const Packet cst_pos_one = pset1<Packet>(pos_one);
|
||||
const Packet cst_pos_inf = pset1<Packet>(pos_inf);
|
||||
const Packet cst_neg_inf = pset1<Packet>(neg_inf);
|
||||
|
||||
const Packet abs_x = pabs(x);
|
||||
const Packet abs_x_is_zero = pcmp_eq(abs_x, cst_pos_zero);
|
||||
const Packet abs_x_is_one = pcmp_eq(abs_x, cst_pos_one);
|
||||
const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf);
|
||||
|
||||
const Packet x_has_signbit = pcmp_eq(por(pand(x, cst_neg_inf), cst_pos_inf), cst_neg_inf);
|
||||
const Packet x_is_neg = pandnot(x_has_signbit, abs_x_is_zero);
|
||||
const Packet x_is_neg_zero = pand(x_has_signbit, abs_x_is_zero);
|
||||
|
||||
if (exponent == 0) {
|
||||
return cst_pos_one;
|
||||
}
|
||||
|
||||
Packet pow_is_pos_inf = pand(pandnot(abs_x_is_zero, x_is_neg_zero), pand(exp_is_odd, exp_is_neg));
|
||||
pow_is_pos_inf = por(pow_is_pos_inf, pand(abs_x_is_zero, pandnot(exp_is_neg, exp_is_odd)));
|
||||
pow_is_pos_inf = por(pow_is_pos_inf, pand(pand(abs_x_is_inf, x_is_neg), pandnot(pnot(exp_is_neg), exp_is_odd)));
|
||||
pow_is_pos_inf = por(pow_is_pos_inf, pandnot(pandnot(abs_x_is_inf, x_is_neg), exp_is_neg));
|
||||
|
||||
Packet pow_is_neg_inf = pand(x_is_neg_zero, pand(exp_is_neg, exp_is_odd));
|
||||
pow_is_neg_inf = por(pow_is_neg_inf, pand(pand(abs_x_is_inf, x_is_neg), pandnot(exp_is_odd, exp_is_neg)));
|
||||
|
||||
Packet pow_is_pos_zero = pandnot(abs_x_is_zero, exp_is_neg);
|
||||
pow_is_pos_zero = por(pow_is_pos_zero, pand(pand(abs_x_is_inf, x_is_neg), pandnot(exp_is_neg, exp_is_odd)));
|
||||
pow_is_pos_zero = por(pow_is_pos_zero, pand(pandnot(abs_x_is_inf, x_is_neg), exp_is_neg));
|
||||
|
||||
Packet pow_is_neg_zero = pand(x_is_neg_zero, pandnot(exp_is_odd, exp_is_neg));
|
||||
pow_is_neg_zero = por(pow_is_neg_zero, pand(pand(abs_x_is_inf, x_is_neg), pand(exp_is_odd, exp_is_neg)));
|
||||
|
||||
Packet result = pselect(pow_is_neg_inf, cst_neg_inf, powx);
|
||||
result = pselect(pow_is_neg_zero, cst_neg_zero, result);
|
||||
result = pselect(pow_is_pos_zero, cst_pos_zero, result);
|
||||
result = pselect(pow_is_pos_inf, cst_pos_inf, result);
|
||||
result = pselect(pandnot(abs_x_is_one, x_is_neg), cst_pos_one, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename Packet, typename ScalarExponent>
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_nonint_errors(const Packet& x, const Packet& powx,
|
||||
const ScalarExponent& exponent) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
|
||||
// non-integer base and exponent case
|
||||
|
||||
const bool exponent_is_fin = (numext::isfinite)(exponent);
|
||||
const bool exponent_is_nan = (numext::isnan)(exponent);
|
||||
const bool exponent_is_neg = exponent < 0;
|
||||
const bool exponent_is_inf = !exponent_is_fin && !exponent_is_nan;
|
||||
|
||||
const Packet exp_is_neg = exponent_is_neg ? ptrue(x) : pzero(x);
|
||||
const Packet exp_is_inf = exponent_is_inf ? ptrue(x) : pzero(x);
|
||||
|
||||
const Scalar pos_zero = Scalar(0);
|
||||
const Scalar pos_one = Scalar(1);
|
||||
const Scalar pos_inf = NumTraits<Scalar>::infinity();
|
||||
const Scalar neg_inf = -NumTraits<Scalar>::infinity();
|
||||
const Scalar nan = NumTraits<Scalar>::quiet_NaN();
|
||||
|
||||
const Packet cst_pos_zero = pset1<Packet>(pos_zero);
|
||||
const Packet cst_pos_one = pset1<Packet>(pos_one);
|
||||
const Packet cst_pos_inf = pset1<Packet>(pos_inf);
|
||||
const Packet cst_neg_inf = pset1<Packet>(neg_inf);
|
||||
const Packet cst_nan = pset1<Packet>(nan);
|
||||
|
||||
const Packet abs_x = pabs(x);
|
||||
const Packet abs_x_is_zero = pcmp_eq(abs_x, cst_pos_zero);
|
||||
const Packet abs_x_is_lt_one = pcmp_lt(abs_x, cst_pos_one);
|
||||
const Packet abs_x_is_gt_one = pcmp_lt(cst_pos_one, abs_x);
|
||||
const Packet abs_x_is_one = pcmp_eq(abs_x, cst_pos_one);
|
||||
const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf);
|
||||
|
||||
const Packet x_has_signbit = pcmp_eq(por(pand(x, cst_neg_inf), cst_pos_inf), cst_neg_inf);
|
||||
const Packet x_is_neg = pandnot(x_has_signbit, abs_x_is_zero);
|
||||
const Packet x_is_neg_zero = pand(x_has_signbit, abs_x_is_zero);
|
||||
|
||||
if (exponent_is_nan) {
|
||||
return pselect(pandnot(abs_x_is_one, x_is_neg), cst_pos_one, cst_nan);
|
||||
}
|
||||
|
||||
Packet pow_is_pos_zero = pandnot(abs_x_is_zero, exp_is_neg);
|
||||
pow_is_pos_zero = por(pow_is_pos_zero, pand(abs_x_is_gt_one, pand(exp_is_inf, exp_is_neg)));
|
||||
pow_is_pos_zero = por(pow_is_pos_zero, pand(abs_x_is_lt_one, pandnot(exp_is_inf, exp_is_neg)));
|
||||
pow_is_pos_zero = por(pow_is_pos_zero, pand(abs_x_is_inf, exp_is_neg));
|
||||
|
||||
const Packet pow_is_pos_one = pand(abs_x_is_one, exp_is_inf);
|
||||
|
||||
Packet pow_is_pos_inf = pand(abs_x_is_zero, exp_is_neg);
|
||||
pow_is_pos_inf = por(pow_is_pos_inf, pand(abs_x_is_lt_one, pand(exp_is_inf, exp_is_neg)));
|
||||
pow_is_pos_inf = por(pow_is_pos_inf, pand(abs_x_is_gt_one, pandnot(exp_is_inf, exp_is_neg)));
|
||||
pow_is_pos_inf = por(pow_is_pos_inf, pandnot(abs_x_is_inf, exp_is_neg));
|
||||
|
||||
const Packet pow_is_nan = pandnot(pandnot(x_is_neg, abs_x_is_inf), exp_is_inf);
|
||||
|
||||
Packet result = pselect(pow_is_pos_inf, cst_pos_inf, powx);
|
||||
result = pselect(pow_is_pos_one, cst_pos_one, result);
|
||||
result = pselect(pow_is_pos_zero, cst_pos_zero, result);
|
||||
result = pselect(pow_is_nan, cst_nan, result);
|
||||
result = pselect(pandnot(abs_x_is_one, x_is_neg), cst_pos_one, result);
|
||||
return result;
|
||||
}
|
||||
} // end namespace unary_pow
|
||||
|
||||
template <typename Packet, typename ScalarExponent,
|
||||
bool BaseIsIntegerType = NumTraits<typename unpacket_traits<Packet>::type>::IsInteger,
|
||||
bool ExponentIsIntegerType = NumTraits<ScalarExponent>::IsInteger>
|
||||
struct unary_pow_impl;
|
||||
|
||||
template <typename Packet, typename ScalarExponent>
|
||||
struct unary_pow_impl<Packet, ScalarExponent, false, false> {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
|
||||
const bool exponent_is_integer = (numext::isfinite)(exponent) && numext::round(exponent) == exponent;
|
||||
if (exponent_is_integer) {
|
||||
Packet result = unary_pow::int_pow(x, exponent);
|
||||
result = unary_pow::handle_nonint_int_errors(x, result, exponent);
|
||||
return result;
|
||||
} else {
|
||||
Packet result = unary_pow::gen_pow(x, exponent);
|
||||
result = unary_pow::handle_nonint_nonint_errors(x, result, exponent);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Packet, typename ScalarExponent>
|
||||
struct unary_pow_impl<Packet, ScalarExponent, false, true> {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
|
||||
Packet result = unary_pow::int_pow(x, exponent);
|
||||
result = unary_pow::handle_nonint_int_errors(x, result, exponent);
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace internal
|
||||
} // end namespace Eigen
|
||||
|
||||
|
@ -1070,6 +1070,70 @@ struct functor_traits<scalar_logistic_op<T> > {
|
||||
};
|
||||
};
|
||||
|
||||
template <typename Scalar, typename ScalarExponent, bool BaseIsIntegerType = NumTraits<Scalar>::IsInteger,
|
||||
bool ExponentIsIntegerType = NumTraits<ScalarExponent>::IsInteger>
|
||||
struct scalar_unary_pow_op {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) {
|
||||
EIGEN_STATIC_ASSERT((is_arithmetic<ScalarExponent>::value), EXPONENT_MUST_BE_ARITHMETIC);
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const {
|
||||
EIGEN_USING_STD(pow);
|
||||
return pow(a, m_exponent);
|
||||
}
|
||||
|
||||
private:
|
||||
const ScalarExponent m_exponent;
|
||||
scalar_unary_pow_op() {}
|
||||
};
|
||||
|
||||
template <typename Scalar, typename ScalarExponent>
|
||||
struct scalar_unary_pow_op<Scalar, ScalarExponent, false, false> {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) {
|
||||
EIGEN_STATIC_ASSERT((is_same<Scalar, ScalarExponent>::value), NON_INTEGER_EXPONENT_MUST_BE_SAME_TYPE_AS_BASE);
|
||||
EIGEN_STATIC_ASSERT((is_arithmetic<ScalarExponent>::value), EXPONENT_MUST_BE_ARITHMETIC);
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const {
|
||||
EIGEN_USING_STD(pow);
|
||||
return pow(a, m_exponent);
|
||||
}
|
||||
template <typename Packet>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a) const {
|
||||
return unary_pow_impl<Packet, ScalarExponent>::run(a, m_exponent);
|
||||
}
|
||||
|
||||
private:
|
||||
const ScalarExponent m_exponent;
|
||||
scalar_unary_pow_op() {}
|
||||
};
|
||||
|
||||
template <typename Scalar, typename ScalarExponent>
|
||||
struct scalar_unary_pow_op<Scalar, ScalarExponent, false, true> {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) {
|
||||
EIGEN_STATIC_ASSERT((is_arithmetic<ScalarExponent>::value), EXPONENT_MUST_BE_ARITHMETIC);
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const {
|
||||
return unary_pow_impl<Scalar, ScalarExponent>::run(a, m_exponent);
|
||||
}
|
||||
template <typename Packet>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a) const {
|
||||
return unary_pow_impl<Packet, ScalarExponent>::run(a, m_exponent);
|
||||
}
|
||||
|
||||
private:
|
||||
const ScalarExponent m_exponent;
|
||||
scalar_unary_pow_op() {}
|
||||
};
|
||||
|
||||
template <typename Scalar, typename ScalarExponent>
|
||||
struct functor_traits<scalar_unary_pow_op<Scalar, ScalarExponent>> {
|
||||
enum {
|
||||
GenPacketAccess = functor_traits<scalar_pow_op<Scalar, ScalarExponent>>::PacketAccess,
|
||||
IntPacketAccess = !NumTraits<Scalar>::IsComplex && !NumTraits<Scalar>::IsInteger && packet_traits<Scalar>::HasMul && packet_traits<Scalar>::HasDiv && packet_traits<Scalar>::HasCmp,
|
||||
PacketAccess = NumTraits<ScalarExponent>::IsInteger ? IntPacketAccess : (IntPacketAccess && GenPacketAccess),
|
||||
Cost = functor_traits<scalar_pow_op<Scalar, ScalarExponent>>::Cost
|
||||
};
|
||||
};
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
} // end namespace Eigen
|
||||
|
@ -198,6 +198,8 @@ template<typename Scalar> struct scalar_constant_op;
|
||||
template<typename Scalar> struct scalar_identity_op;
|
||||
template<typename Scalar> struct scalar_sign_op;
|
||||
template<typename Scalar,typename ScalarExponent> struct scalar_pow_op;
|
||||
template <typename Scalar, typename ScalarExponent, bool BaseIsIntegerType, bool ExponentIsIntegerType>
|
||||
struct scalar_unary_pow_op;
|
||||
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_hypot_op;
|
||||
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_product_op;
|
||||
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_quotient_op;
|
||||
|
@ -134,26 +134,6 @@ absolute_difference
|
||||
*/
|
||||
EIGEN_MAKE_CWISE_BINARY_OP(pow,pow)
|
||||
|
||||
#ifndef EIGEN_PARSED_BY_DOXYGEN
|
||||
EIGEN_MAKE_SCALAR_BINARY_OP_ONTHERIGHT(pow,pow)
|
||||
#else
|
||||
/** \returns an expression of the coefficients of \c *this rasied to the constant power \a exponent
|
||||
*
|
||||
* \tparam T is the scalar type of \a exponent. It must be compatible with the scalar type of the given expression.
|
||||
*
|
||||
* This function computes the coefficient-wise power. The function MatrixBase::pow() in the
|
||||
* unsupported module MatrixFunctions computes the matrix power.
|
||||
*
|
||||
* Example: \include Cwise_pow.cpp
|
||||
* Output: \verbinclude Cwise_pow.out
|
||||
*
|
||||
* \sa ArrayBase::pow(ArrayBase), square(), cube(), exp(), log()
|
||||
*/
|
||||
template<typename T>
|
||||
const CwiseBinaryOp<internal::scalar_pow_op<Scalar,T>,Derived,Constant<T> > pow(const T& exponent) const;
|
||||
#endif
|
||||
|
||||
|
||||
// TODO code generating macros could be moved to Macros.h and could include generation of documentation
|
||||
#define EIGEN_MAKE_CWISE_COMP_OP(OP, COMPARATOR) \
|
||||
template<typename OtherDerived> \
|
||||
|
@ -694,3 +694,32 @@ ndtri() const
|
||||
{
|
||||
return NdtriReturnType(derived());
|
||||
}
|
||||
|
||||
template <typename ScalarExponent>
|
||||
using UnaryPowReturnType =
|
||||
std::enable_if_t<internal::is_arithmetic<ScalarExponent>::value,
|
||||
CwiseUnaryOp<internal::scalar_unary_pow_op<Scalar, ScalarExponent>, const Derived>>;
|
||||
|
||||
#ifndef EIGEN_PARSED_BY_DOXYGEN
|
||||
template <typename ScalarExponent>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const UnaryPowReturnType<ScalarExponent> pow(
|
||||
const ScalarExponent& exponent) const {
|
||||
return UnaryPowReturnType<ScalarExponent>(derived(), internal::scalar_unary_pow_op<Scalar, ScalarExponent>(exponent));
|
||||
#else
|
||||
/** \returns an expression of the coefficients of \c *this rasied to the constant power \a exponent
|
||||
*
|
||||
* \tparam T is the scalar type of \a exponent. It must be compatible with the scalar type of the given expression.
|
||||
*
|
||||
* This function computes the coefficient-wise power. The function MatrixBase::pow() in the
|
||||
* unsupported module MatrixFunctions computes the matrix power.
|
||||
*
|
||||
* Example: \include Cwise_pow.cpp
|
||||
* Output: \verbinclude Cwise_pow.out
|
||||
*
|
||||
* \sa ArrayBase::pow(ArrayBase), square(), cube(), exp(), log()
|
||||
*/
|
||||
template <typename ScalarExponent>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const UnaryPowReturnType<ScalarExponent> pow(
|
||||
const ScalarExponent& exponent) const;
|
||||
#endif
|
||||
}
|
||||
|
@ -93,3 +93,13 @@ EIGEN_DOC_UNARY_ADDONS(cwiseArg,arg)
|
||||
EIGEN_DEVICE_FUNC
|
||||
inline const CwiseArgReturnType
|
||||
cwiseArg() const { return CwiseArgReturnType(derived()); }
|
||||
|
||||
template <typename ScalarExponent>
|
||||
using CwisePowReturnType =
|
||||
std::enable_if_t<internal::is_arithmetic<ScalarExponent>::value,
|
||||
CwiseUnaryOp<internal::scalar_unary_pow_op<Scalar, ScalarExponent>, const Derived>>;
|
||||
|
||||
template <typename ScalarExponent>
|
||||
EIGEN_DEVICE_FUNC inline const CwisePowReturnType<ScalarExponent> cwisePow(const ScalarExponent& exponent) const {
|
||||
return CwisePowReturnType<ScalarExponent>(derived(), internal::scalar_unary_pow_op<Scalar, ScalarExponent>(exponent));
|
||||
}
|
||||
|
@ -79,6 +79,50 @@ void pow_test() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
typedef typename internal::make_integer<Scalar>::type Int_t;
|
||||
|
||||
// ensure both vectorized and non-vectorized paths taken
|
||||
Index test_size = 2 * internal::packet_traits<Scalar>::size + 1;
|
||||
|
||||
Array<Scalar, Dynamic, 1> eigenPow(test_size);
|
||||
for (int i = 0; i < num_cases; ++i) {
|
||||
Array<Scalar, Dynamic, 1> bases = x.col(i);
|
||||
for (Scalar abs_exponent : abs_vals){
|
||||
for (Scalar exponent : {-abs_exponent, abs_exponent}){
|
||||
// test floating point exponent code path
|
||||
eigenPow.setZero();
|
||||
eigenPow = bases.pow(exponent);
|
||||
for (int j = 0; j < num_repeats; j++){
|
||||
Scalar e = static_cast<Scalar>(std::pow(bases(j), exponent));
|
||||
Scalar a = eigenPow(j);
|
||||
bool success = (a == e) || ((numext::isfinite)(e) && internal::isApprox(a, e, tol)) || ((numext::isnan)(a) && (numext::isnan)(e));
|
||||
all_pass &= success;
|
||||
if (!success) {
|
||||
std::cout << "pow(" << x(i, j) << "," << y(i, j) << ") = " << a << " != " << e << std::endl;
|
||||
}
|
||||
}
|
||||
// test integer exponent code path
|
||||
bool exponent_is_integer = (numext::isfinite)(exponent) && (numext::round(exponent) == exponent) && (numext::abs(exponent) < static_cast<Scalar>(NumTraits<Int_t>::highest()));
|
||||
if (exponent_is_integer)
|
||||
{
|
||||
Int_t exponent_as_int = static_cast<Int_t>(exponent);
|
||||
eigenPow.setZero();
|
||||
eigenPow = bases.pow(exponent_as_int);
|
||||
for (int j = 0; j < num_repeats; j++){
|
||||
Scalar e = static_cast<Scalar>(std::pow(bases(j), exponent));
|
||||
Scalar a = eigenPow(j);
|
||||
bool success = (a == e) || ((numext::isfinite)(e) && internal::isApprox(a, e, tol)) || ((numext::isnan)(a) && (numext::isnan)(e));
|
||||
all_pass &= success;
|
||||
if (!success) {
|
||||
std::cout << "pow(" << x(i, j) << "," << y(i, j) << ") = " << a << " != " << e << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
VERIFY(all_pass);
|
||||
}
|
||||
|
||||
|
@ -331,10 +331,12 @@ class TensorBase<Derived, ReadOnlyAccessors>
|
||||
return choose(Cond<NumTraits<CoeffReturnType>::IsComplex>(), unaryExpr(internal::scalar_conjugate_op<Scalar>()), derived());
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::bind2nd_op<internal::scalar_pow_op<Scalar,Scalar> >, const Derived>
|
||||
pow(Scalar exponent) const {
|
||||
return unaryExpr(internal::bind2nd_op<internal::scalar_pow_op<Scalar,Scalar> >(exponent));
|
||||
template<typename ScalarExponent>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::enable_if_t<internal::is_arithmetic<ScalarExponent>::value,
|
||||
TensorCwiseUnaryOp<internal::scalar_unary_pow_op<Scalar, ScalarExponent>, const Derived>>
|
||||
pow(ScalarExponent exponent) const
|
||||
{
|
||||
return unaryExpr(internal::scalar_unary_pow_op<Scalar, ScalarExponent>(exponent));
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
|
Loading…
x
Reference in New Issue
Block a user