Fix pow and other cwise ops for half/bfloat16.

The new `generic_pow` implementation was failing for half/bfloat16 since
their construction from int/float is not `constexpr`. Modified
in `GenericPacketMathFunctions` to remove `constexpr`.

While adding tests for half/bfloat16, found other issues related to
implicit conversions.

Also needed to implement `numext::arg` for non-integer, non-complex,
non-float/double/long double types.  These seem to be  implicitly
converted to `std::complex<T>`, which then fails for half/bfloat16.
This commit is contained in:
Antonio Sanchez 2021-01-22 11:10:54 -08:00
parent f19bcffee6
commit f0e46ed5d4
4 changed files with 92 additions and 69 deletions

View File

@ -555,8 +555,15 @@ struct rint_retval
****************************************************************************/ ****************************************************************************/
#if EIGEN_HAS_CXX11_MATH #if EIGEN_HAS_CXX11_MATH
template<typename Scalar> // std::arg is only defined for types of std::complex, or integer types or float/double/long double
struct arg_impl { template<typename Scalar,
bool HasStdImpl = NumTraits<Scalar>::IsComplex || is_integral<Scalar>::value
|| is_same<Scalar, float>::value || is_same<Scalar, double>::value
|| is_same<Scalar, long double>::value >
struct arg_default_impl;
template<typename Scalar>
struct arg_default_impl<Scalar, true> {
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
static inline Scalar run(const Scalar& x) static inline Scalar run(const Scalar& x)
{ {
@ -566,23 +573,35 @@ struct rint_retval
#else #else
EIGEN_USING_STD(arg); EIGEN_USING_STD(arg);
#endif #endif
return arg(x); return static_cast<Scalar>(arg(x));
} }
}; };
#else
template<typename Scalar, bool IsComplex = NumTraits<Scalar>::IsComplex> // Must be non-complex floating-point type (e.g. half/bfloat16).
struct arg_default_impl template<typename Scalar>
{ struct arg_default_impl<Scalar, false> {
typedef typename NumTraits<Scalar>::Real RealScalar; typedef typename NumTraits<Scalar>::Real RealScalar;
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
static inline RealScalar run(const Scalar& x) static inline RealScalar run(const Scalar& x)
{ {
return (x < Scalar(0)) ? Scalar(EIGEN_PI) : Scalar(0); } return (x < Scalar(0)) ? Scalar(EIGEN_PI) : Scalar(0);
}; }
};
template<typename Scalar> #else
struct arg_default_impl<Scalar,true> template<typename Scalar, bool IsComplex = NumTraits<Scalar>::IsComplex>
struct arg_default_impl
{
typedef typename NumTraits<Scalar>::Real RealScalar;
EIGEN_DEVICE_FUNC
static inline RealScalar run(const Scalar& x)
{ {
return (x < Scalar(0)) ? Scalar(EIGEN_PI) : Scalar(0);
}
};
template<typename Scalar>
struct arg_default_impl<Scalar,true>
{
typedef typename NumTraits<Scalar>::Real RealScalar; typedef typename NumTraits<Scalar>::Real RealScalar;
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
static inline RealScalar run(const Scalar& x) static inline RealScalar run(const Scalar& x)
@ -590,10 +609,9 @@ struct rint_retval
EIGEN_USING_STD(arg); EIGEN_USING_STD(arg);
return arg(x); return arg(x);
} }
}; };
template<typename Scalar> struct arg_impl : arg_default_impl<Scalar> {};
#endif #endif
template<typename Scalar> struct arg_impl : arg_default_impl<Scalar> {};
template<typename Scalar> template<typename Scalar>
struct arg_retval struct arg_retval
@ -1425,7 +1443,7 @@ template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T log(const T &x) { T log(const T &x) {
EIGEN_USING_STD(log); EIGEN_USING_STD(log);
return log(x); return static_cast<T>(log(x));
} }
#if defined(SYCL_DEVICE_ONLY) #if defined(SYCL_DEVICE_ONLY)
@ -1602,7 +1620,7 @@ template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T acosh(const T &x) { T acosh(const T &x) {
EIGEN_USING_STD(acosh); EIGEN_USING_STD(acosh);
return acosh(x); return static_cast<T>(acosh(x));
} }
#endif #endif
@ -1631,7 +1649,7 @@ template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T asinh(const T &x) { T asinh(const T &x) {
EIGEN_USING_STD(asinh); EIGEN_USING_STD(asinh);
return asinh(x); return static_cast<T>(asinh(x));
} }
#endif #endif
@ -1652,7 +1670,7 @@ template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T atan(const T &x) { T atan(const T &x) {
EIGEN_USING_STD(atan); EIGEN_USING_STD(atan);
return atan(x); return static_cast<T>(atan(x));
} }
#if EIGEN_HAS_CXX11_MATH #if EIGEN_HAS_CXX11_MATH
@ -1660,7 +1678,7 @@ template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T atanh(const T &x) { T atanh(const T &x) {
EIGEN_USING_STD(atanh); EIGEN_USING_STD(atanh);
return atanh(x); return static_cast<T>(atanh(x));
} }
#endif #endif
@ -1682,7 +1700,7 @@ template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T cosh(const T &x) { T cosh(const T &x) {
EIGEN_USING_STD(cosh); EIGEN_USING_STD(cosh);
return cosh(x); return static_cast<T>(cosh(x));
} }
#if defined(SYCL_DEVICE_ONLY) #if defined(SYCL_DEVICE_ONLY)
@ -1701,7 +1719,7 @@ template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T sinh(const T &x) { T sinh(const T &x) {
EIGEN_USING_STD(sinh); EIGEN_USING_STD(sinh);
return sinh(x); return static_cast<T>(sinh(x));
} }
#if defined(SYCL_DEVICE_ONLY) #if defined(SYCL_DEVICE_ONLY)

View File

@ -804,8 +804,8 @@ EIGEN_STRONG_INLINE
void veltkamp_splitting(const Packet& x, Packet& x_hi, Packet& x_lo) { void veltkamp_splitting(const Packet& x, Packet& x_hi, Packet& x_lo) {
typedef typename unpacket_traits<Packet>::type Scalar; typedef typename unpacket_traits<Packet>::type Scalar;
EIGEN_CONSTEXPR int shift = (NumTraits<Scalar>::digits() + 1) / 2; EIGEN_CONSTEXPR int shift = (NumTraits<Scalar>::digits() + 1) / 2;
EIGEN_CONSTEXPR Scalar shift_scale = Scalar(uint64_t(1) << shift); Scalar shift_scale = Scalar(uint64_t(1) << shift); // Scalar constructor not necessarily constexpr.
Packet gamma = pmul(pset1<Packet>(shift_scale + 1), x); Packet gamma = pmul(pset1<Packet>(shift_scale + Scalar(1)), x);
#ifdef EIGEN_HAS_SINGLE_INSTRUCTION_MADD #ifdef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
x_hi = pmadd(pset1<Packet>(-shift_scale), x, gamma); x_hi = pmadd(pset1<Packet>(-shift_scale), x, gamma);
#else #else

View File

@ -403,7 +403,7 @@ struct functor_traits<scalar_log10_op<Scalar> >
*/ */
template<typename Scalar> struct scalar_log2_op { template<typename Scalar> struct scalar_log2_op {
EIGEN_EMPTY_STRUCT_CTOR(scalar_log2_op) EIGEN_EMPTY_STRUCT_CTOR(scalar_log2_op)
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return Scalar(EIGEN_LOG2E) * std::log(a); } EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return Scalar(EIGEN_LOG2E) * numext::log(a); }
template <typename Packet> template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::plog2(a); } EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::plog2(a); }
}; };

View File

@ -329,7 +329,7 @@ template<typename ArrayType> void array_real(const ArrayType& m)
m3(rows, cols), m3(rows, cols),
m4 = m1; m4 = m1;
m4 = (m4.abs()==Scalar(0)).select(1,m4); m4 = (m4.abs()==Scalar(0)).select(Scalar(1),m4);
Scalar s1 = internal::random<Scalar>(); Scalar s1 = internal::random<Scalar>();
@ -358,7 +358,7 @@ template<typename ArrayType> void array_real(const ArrayType& m)
VERIFY((m1.isNaN() == (Eigen::isnan)(m1)).all()); VERIFY((m1.isNaN() == (Eigen::isnan)(m1)).all());
VERIFY((m1.isInf() == (Eigen::isinf)(m1)).all()); VERIFY((m1.isInf() == (Eigen::isinf)(m1)).all());
VERIFY((m1.isFinite() == (Eigen::isfinite)(m1)).all()); VERIFY((m1.isFinite() == (Eigen::isfinite)(m1)).all());
VERIFY_IS_APPROX(m1.inverse(), inverse(m1)); VERIFY_IS_APPROX(m4.inverse(), inverse(m4));
VERIFY_IS_APPROX(m1.abs(), abs(m1)); VERIFY_IS_APPROX(m1.abs(), abs(m1));
VERIFY_IS_APPROX(m1.abs2(), abs2(m1)); VERIFY_IS_APPROX(m1.abs2(), abs2(m1));
VERIFY_IS_APPROX(m1.square(), square(m1)); VERIFY_IS_APPROX(m1.square(), square(m1));
@ -367,11 +367,11 @@ template<typename ArrayType> void array_real(const ArrayType& m)
VERIFY_IS_APPROX(m1.sign(), sign(m1)); VERIFY_IS_APPROX(m1.sign(), sign(m1));
VERIFY((m1.sqrt().sign().isNaN() == (Eigen::isnan)(sign(sqrt(m1)))).all()); VERIFY((m1.sqrt().sign().isNaN() == (Eigen::isnan)(sign(sqrt(m1)))).all());
// avoid NaNs with abs() so verification doesn't fail // avoid inf and NaNs so verification doesn't fail
m3 = m1.abs(); m3 = m4.abs();
VERIFY_IS_APPROX(m3.sqrt(), sqrt(abs(m1))); VERIFY_IS_APPROX(m3.sqrt(), sqrt(abs(m3)));
VERIFY_IS_APPROX(m3.rsqrt(), Scalar(1)/sqrt(abs(m1))); VERIFY_IS_APPROX(m3.rsqrt(), Scalar(1)/sqrt(abs(m3)));
VERIFY_IS_APPROX(rsqrt(m3), Scalar(1)/sqrt(abs(m1))); VERIFY_IS_APPROX(rsqrt(m3), Scalar(1)/sqrt(abs(m3)));
VERIFY_IS_APPROX(m3.log(), log(m3)); VERIFY_IS_APPROX(m3.log(), log(m3));
VERIFY_IS_APPROX(m3.log1p(), log1p(m3)); VERIFY_IS_APPROX(m3.log1p(), log1p(m3));
VERIFY_IS_APPROX(m3.log10(), log10(m3)); VERIFY_IS_APPROX(m3.log10(), log10(m3));
@ -383,23 +383,23 @@ template<typename ArrayType> void array_real(const ArrayType& m)
VERIFY_IS_APPROX(sin(m1.asin()), m1); VERIFY_IS_APPROX(sin(m1.asin()), m1);
VERIFY_IS_APPROX(cos(m1.acos()), m1); VERIFY_IS_APPROX(cos(m1.acos()), m1);
VERIFY_IS_APPROX(tan(m1.atan()), m1); VERIFY_IS_APPROX(tan(m1.atan()), m1);
VERIFY_IS_APPROX(sinh(m1), 0.5*(exp(m1)-exp(-m1))); VERIFY_IS_APPROX(sinh(m1), Scalar(0.5)*(exp(m1)-exp(-m1)));
VERIFY_IS_APPROX(cosh(m1), 0.5*(exp(m1)+exp(-m1))); VERIFY_IS_APPROX(cosh(m1), Scalar(0.5)*(exp(m1)+exp(-m1)));
VERIFY_IS_APPROX(tanh(m1), (0.5*(exp(m1)-exp(-m1)))/(0.5*(exp(m1)+exp(-m1)))); VERIFY_IS_APPROX(tanh(m1), (Scalar(0.5)*(exp(m1)-exp(-m1)))/(Scalar(0.5)*(exp(m1)+exp(-m1))));
VERIFY_IS_APPROX(logistic(m1), (1.0/(1.0+exp(-m1)))); VERIFY_IS_APPROX(logistic(m1), (Scalar(1)/(Scalar(1)+exp(-m1))));
VERIFY_IS_APPROX(arg(m1), ((m1<0).template cast<Scalar>())*std::acos(-1.0)); VERIFY_IS_APPROX(arg(m1), ((m1<Scalar(0)).template cast<Scalar>())*Scalar(std::acos(Scalar(-1))));
VERIFY((round(m1) <= ceil(m1) && round(m1) >= floor(m1)).all()); VERIFY((round(m1) <= ceil(m1) && round(m1) >= floor(m1)).all());
VERIFY((rint(m1) <= ceil(m1) && rint(m1) >= floor(m1)).all()); VERIFY((rint(m1) <= ceil(m1) && rint(m1) >= floor(m1)).all());
VERIFY(((ceil(m1) - round(m1)) <= Scalar(0.5) || (round(m1) - floor(m1)) <= Scalar(0.5)).all()); VERIFY(((ceil(m1) - round(m1)) <= Scalar(0.5) || (round(m1) - floor(m1)) <= Scalar(0.5)).all());
VERIFY(((ceil(m1) - round(m1)) <= Scalar(1.0) && (round(m1) - floor(m1)) <= Scalar(1.0)).all()); VERIFY(((ceil(m1) - round(m1)) <= Scalar(1.0) && (round(m1) - floor(m1)) <= Scalar(1.0)).all());
VERIFY(((ceil(m1) - rint(m1)) <= Scalar(0.5) || (rint(m1) - floor(m1)) <= Scalar(0.5)).all()); VERIFY(((ceil(m1) - rint(m1)) <= Scalar(0.5) || (rint(m1) - floor(m1)) <= Scalar(0.5)).all());
VERIFY(((ceil(m1) - rint(m1)) <= Scalar(1.0) && (rint(m1) - floor(m1)) <= Scalar(1.0)).all()); VERIFY(((ceil(m1) - rint(m1)) <= Scalar(1.0) && (rint(m1) - floor(m1)) <= Scalar(1.0)).all());
VERIFY((Eigen::isnan)((m1*0.0)/0.0).all()); VERIFY((Eigen::isnan)((m1*Scalar(0))/Scalar(0)).all());
VERIFY((Eigen::isinf)(m4/0.0).all()); VERIFY((Eigen::isinf)(m4/Scalar(0)).all());
VERIFY(((Eigen::isfinite)(m1) && (!(Eigen::isfinite)(m1*0.0/0.0)) && (!(Eigen::isfinite)(m4/0.0))).all()); VERIFY(((Eigen::isfinite)(m1) && (!(Eigen::isfinite)(m1*Scalar(0)/Scalar(0))) && (!(Eigen::isfinite)(m4/Scalar(0)))).all());
VERIFY_IS_APPROX(inverse(inverse(m1)),m1); VERIFY_IS_APPROX(inverse(inverse(m4)),m4);
VERIFY((abs(m1) == m1 || abs(m1) == -m1).all()); VERIFY((abs(m1) == m1 || abs(m1) == -m1).all());
VERIFY_IS_APPROX(m3, sqrt(abs2(m1))); VERIFY_IS_APPROX(m3, sqrt(abs2(m3)));
VERIFY_IS_APPROX(m1.absolute_difference(m2), (m1 > m2).select(m1 - m2, m2 - m1)); VERIFY_IS_APPROX(m1.absolute_difference(m2), (m1 > m2).select(m1 - m2, m2 - m1));
VERIFY_IS_APPROX( m1.sign(), -(-m1).sign() ); VERIFY_IS_APPROX( m1.sign(), -(-m1).sign() );
VERIFY_IS_APPROX( m1*m1.sign(),m1.abs()); VERIFY_IS_APPROX( m1*m1.sign(),m1.abs());
@ -412,26 +412,29 @@ template<typename ArrayType> void array_real(const ArrayType& m)
// shift argument of logarithm so that it is not zero // shift argument of logarithm so that it is not zero
Scalar smallNumber = NumTraits<Scalar>::dummy_precision(); Scalar smallNumber = NumTraits<Scalar>::dummy_precision();
VERIFY_IS_APPROX((m3 + smallNumber).log() , log(abs(m1) + smallNumber)); VERIFY_IS_APPROX((m3 + smallNumber).log() , log(abs(m3) + smallNumber));
VERIFY_IS_APPROX((m3 + smallNumber + 1).log() , log1p(abs(m1) + smallNumber)); VERIFY_IS_APPROX((m3 + smallNumber + Scalar(1)).log() , log1p(abs(m3) + smallNumber));
VERIFY_IS_APPROX(m1.exp() * m2.exp(), exp(m1+m2)); VERIFY_IS_APPROX(m1.exp() * m2.exp(), exp(m1+m2));
VERIFY_IS_APPROX(m1.exp(), exp(m1)); VERIFY_IS_APPROX(m1.exp(), exp(m1));
VERIFY_IS_APPROX(m1.exp() / m2.exp(),(m1-m2).exp()); VERIFY_IS_APPROX(m1.exp() / m2.exp(),(m1-m2).exp());
VERIFY_IS_APPROX(m1.expm1(), expm1(m1)); VERIFY_IS_APPROX(m1.expm1(), expm1(m1));
VERIFY_IS_APPROX((m3 + smallNumber).exp() - 1, expm1(abs(m3) + smallNumber)); VERIFY_IS_APPROX((m3 + smallNumber).exp() - Scalar(1), expm1(abs(m3) + smallNumber));
VERIFY_IS_APPROX(m3.pow(RealScalar(0.5)), m3.sqrt()); VERIFY_IS_APPROX(m3.pow(RealScalar(0.5)), m3.sqrt());
VERIFY_IS_APPROX(pow(m3,RealScalar(0.5)), m3.sqrt()); VERIFY_IS_APPROX(pow(m3,RealScalar(0.5)), m3.sqrt());
VERIFY_IS_APPROX(m3.pow(RealScalar(-0.5)), m3.rsqrt()); VERIFY_IS_APPROX(m3.pow(RealScalar(-0.5)), m3.rsqrt());
VERIFY_IS_APPROX(pow(m3,RealScalar(-0.5)), m3.rsqrt()); VERIFY_IS_APPROX(pow(m3,RealScalar(-0.5)), m3.rsqrt());
VERIFY_IS_APPROX(m1.pow(RealScalar(-2)), m1.square().inverse());
// Avoid inf and NaN.
m3 = (m1.square()<NumTraits<Scalar>::epsilon()).select(Scalar(1),m3);
VERIFY_IS_APPROX(m3.pow(RealScalar(-2)), m3.square().inverse());
pow_test<Scalar>(); pow_test<Scalar>();
VERIFY_IS_APPROX(log10(m3), log(m3)/log(10)); VERIFY_IS_APPROX(log10(m3), log(m3)/log(Scalar(10)));
VERIFY_IS_APPROX(log2(m3), log(m3)/log(2)); VERIFY_IS_APPROX(log2(m3), log(m3)/log(Scalar(2)));
// scalar by array division // scalar by array division
const RealScalar tiny = sqrt(std::numeric_limits<RealScalar>::epsilon()); const RealScalar tiny = sqrt(std::numeric_limits<RealScalar>::epsilon());
@ -480,7 +483,7 @@ template<typename ArrayType> void array_complex(const ArrayType& m)
VERIFY((m1.isNaN() == (Eigen::isnan)(m1)).all()); VERIFY((m1.isNaN() == (Eigen::isnan)(m1)).all());
VERIFY((m1.isInf() == (Eigen::isinf)(m1)).all()); VERIFY((m1.isInf() == (Eigen::isinf)(m1)).all());
VERIFY((m1.isFinite() == (Eigen::isfinite)(m1)).all()); VERIFY((m1.isFinite() == (Eigen::isfinite)(m1)).all());
VERIFY_IS_APPROX(m1.inverse(), inverse(m1)); VERIFY_IS_APPROX(m4.inverse(), inverse(m4));
VERIFY_IS_APPROX(m1.log(), log(m1)); VERIFY_IS_APPROX(m1.log(), log(m1));
VERIFY_IS_APPROX(m1.log10(), log10(m1)); VERIFY_IS_APPROX(m1.log10(), log10(m1));
VERIFY_IS_APPROX(m1.log2(), log2(m1)); VERIFY_IS_APPROX(m1.log2(), log2(m1));
@ -534,7 +537,7 @@ template<typename ArrayType> void array_complex(const ArrayType& m)
VERIFY(((Eigen::isfinite)(m1) && (!(Eigen::isfinite)(m1*zero/zero)) && (!(Eigen::isfinite)(m1/zero))).all()); VERIFY(((Eigen::isfinite)(m1) && (!(Eigen::isfinite)(m1*zero/zero)) && (!(Eigen::isfinite)(m1/zero))).all());
VERIFY_IS_APPROX(inverse(inverse(m1)),m1); VERIFY_IS_APPROX(inverse(inverse(m4)),m4);
VERIFY_IS_APPROX(conj(m1.conjugate()), m1); VERIFY_IS_APPROX(conj(m1.conjugate()), m1);
VERIFY_IS_APPROX(abs(m1), sqrt(square(m1.real())+square(m1.imag()))); VERIFY_IS_APPROX(abs(m1), sqrt(square(m1.real())+square(m1.imag())));
VERIFY_IS_APPROX(abs(m1), sqrt(abs2(m1))); VERIFY_IS_APPROX(abs(m1), sqrt(abs2(m1)));
@ -622,6 +625,8 @@ EIGEN_DECLARE_TEST(array_cwise)
CALL_SUBTEST_2( array_real(Array22f()) ); CALL_SUBTEST_2( array_real(Array22f()) );
CALL_SUBTEST_3( array_real(Array44d()) ); CALL_SUBTEST_3( array_real(Array44d()) );
CALL_SUBTEST_5( array_real(ArrayXXf(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) ); CALL_SUBTEST_5( array_real(ArrayXXf(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
CALL_SUBTEST_7( array_real(Array<Eigen::half, 32, 32>()) );
CALL_SUBTEST_8( array_real(Array<Eigen::bfloat16, 32, 32>()) );
} }
for(int i = 0; i < g_repeat; i++) { for(int i = 0; i < g_repeat; i++) {
CALL_SUBTEST_4( array_complex(ArrayXXcf(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) ); CALL_SUBTEST_4( array_complex(ArrayXXcf(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );