Fix a couple of issues with unary pow():

This commit is contained in:
Rasmus Munk Larsen 2022-09-09 17:21:11 +00:00
parent 07d0759951
commit e8a2aa24a2

View File

@ -1085,7 +1085,7 @@ struct scalar_unary_pow_op {
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const Scalar& a) const {
EIGEN_USING_STD(pow);
return pow(a, m_exponent);
return static_cast<result_type>(pow(a, m_exponent));
}
private:
@ -1096,12 +1096,12 @@ struct scalar_unary_pow_op {
template <typename Scalar, typename ScalarExponent>
struct scalar_unary_pow_op<Scalar, ScalarExponent, false, false, false, false> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_unary_pow_op(const ScalarExponent& exponent) : m_exponent(exponent) {
EIGEN_STATIC_ASSERT((is_same<Scalar, ScalarExponent>::value), NON_INTEGER_EXPONENT_MUST_BE_SAME_TYPE_AS_BASE);
EIGEN_STATIC_ASSERT((is_same<std::remove_const_t<Scalar>, std::remove_const_t<ScalarExponent>>::value), NON_INTEGER_EXPONENT_MUST_BE_SAME_TYPE_AS_BASE);
EIGEN_STATIC_ASSERT((is_arithmetic<ScalarExponent>::value), EXPONENT_MUST_BE_ARITHMETIC);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const Scalar& a) const {
EIGEN_USING_STD(pow);
return pow(a, m_exponent);
return static_cast<Scalar>(pow(a, m_exponent));
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a) const {