Special function implementations for half/bfloat16 packets.

Current implementations fail to consider half-float packets, only
half-float scalars.  Added specializations for packets on AVX, AVX512 and
NEON.  Added tests to `special_packetmath`.

The current `special_functions` tests would fail for half and bfloat16 due to
lack of precision. The NEON tests also fail with precision issues and
due to different handling of `sqrt(inf)`, so special functions bessel, ndtri
have been disabled.

Tested with AVX, AVX512.
This commit is contained in:
Antonio Sanchez 2020-12-02 14:00:57 -08:00
parent 305b8bd277
commit e2f21465fe
17 changed files with 418 additions and 159 deletions

View File

@ -147,7 +147,9 @@ struct packet_traits<Eigen::half> : default_packet_traits {
HasRound = 1,
HasFloor = 1,
HasCeil = 1,
HasRint = 1
HasRint = 1,
HasBessel = 1,
HasNdtri = 1,
};
};
@ -189,7 +191,9 @@ struct packet_traits<bfloat16> : default_packet_traits {
HasRound = 1,
HasFloor = 1,
HasCeil = 1,
HasRint = 1
HasRint = 1,
HasBessel = 1,
HasNdtri = 1,
};
};
#endif

View File

@ -86,7 +86,9 @@ struct packet_traits<half> : default_packet_traits {
HasRound = 1,
HasFloor = 1,
HasCeil = 1,
HasRint = 1
HasRint = 1,
HasBessel = 1,
HasNdtri = 1,
};
};

View File

@ -58,7 +58,7 @@
#define F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, METHOD) \
template <> \
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED \
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_UNUSED \
PACKET_F16 METHOD<PACKET_F16>(const PACKET_F16& _x) { \
return float2half(METHOD<PACKET_F>(half2float(_x))); \
}

View File

@ -192,7 +192,9 @@ struct packet_traits<float> : default_packet_traits
HasExp = 1,
HasSqrt = 1,
HasTanh = EIGEN_FAST_MATH,
HasErf = EIGEN_FAST_MATH
HasErf = EIGEN_FAST_MATH,
HasBessel = 0, // Issues with accuracy.
HasNdtri = 0
};
};
@ -3321,7 +3323,9 @@ template<> struct packet_traits<bfloat16> : default_packet_traits
HasExp = 1,
HasSqrt = 0,
HasTanh = EIGEN_FAST_MATH,
HasErf = EIGEN_FAST_MATH
HasErf = EIGEN_FAST_MATH,
HasBessel = 0, // Issues with accuracy.
HasNdtri = 0,
};
};
@ -3887,7 +3891,10 @@ struct packet_traits<Eigen::half> : default_packet_traits {
HasCos = 0,
HasLog = 0,
HasExp = 0,
HasSqrt = 1
HasSqrt = 1,
HasErf = EIGEN_FAST_MATH,
HasBessel = 0, // Issues with accuracy.
HasNdtri = 0,
};
};

View File

@ -61,23 +61,36 @@ namespace Eigen {
}
#include "src/SpecialFunctions/BesselFunctionsImpl.h"
#include "src/SpecialFunctions/BesselFunctionsPacketMath.h"
#include "src/SpecialFunctions/BesselFunctionsBFloat16.h"
#include "src/SpecialFunctions/BesselFunctionsHalf.h"
#include "src/SpecialFunctions/BesselFunctionsPacketMath.h"
#include "src/SpecialFunctions/BesselFunctionsFunctors.h"
#include "src/SpecialFunctions/BesselFunctionsArrayAPI.h"
#include "src/SpecialFunctions/SpecialFunctionsImpl.h"
#if defined(EIGEN_HIPCC)
#include "src/SpecialFunctions/HipVectorCompatibility.h"
#endif
#include "src/SpecialFunctions/SpecialFunctionsPacketMath.h"
#include "src/SpecialFunctions/SpecialFunctionsBFloat16.h"
#include "src/SpecialFunctions/SpecialFunctionsHalf.h"
#include "src/SpecialFunctions/SpecialFunctionsPacketMath.h"
#include "src/SpecialFunctions/SpecialFunctionsFunctors.h"
#include "src/SpecialFunctions/SpecialFunctionsArrayAPI.h"
#if defined EIGEN_VECTORIZE_AVX512
#include "src/SpecialFunctions/arch/AVX/BesselFunctions.h"
#include "src/SpecialFunctions/arch/AVX/SpecialFunctions.h"
#include "src/SpecialFunctions/arch/AVX512/BesselFunctions.h"
#include "src/SpecialFunctions/arch/AVX512/SpecialFunctions.h"
#elif defined EIGEN_VECTORIZE_AVX
#include "src/SpecialFunctions/arch/AVX/BesselFunctions.h"
#include "src/SpecialFunctions/arch/AVX/SpecialFunctions.h"
#elif defined EIGEN_VECTORIZE_NEON
#include "src/SpecialFunctions/arch/NEON/BesselFunctions.h"
#include "src/SpecialFunctions/arch/NEON/SpecialFunctions.h"
#endif
#if defined EIGEN_VECTORIZE_GPU
#include "src/SpecialFunctions/arch/GPU/GpuSpecialFunctions.h"
#include "src/SpecialFunctions/arch/GPU/SpecialFunctions.h"
#endif
namespace Eigen {

View File

@ -46,7 +46,7 @@ struct bessel_i0e_retval {
typedef Scalar type;
};
template <typename T, typename ScalarType>
template <typename T, typename ScalarType = typename unpacket_traits<T>::type>
struct generic_i0e {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE T run(const T&) {
@ -201,11 +201,11 @@ struct generic_i0e<T, double> {
}
};
template <typename Scalar>
template <typename T>
struct bessel_i0e_impl {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar run(const Scalar x) {
return generic_i0e<Scalar, Scalar>::run(x);
static EIGEN_STRONG_INLINE T run(const T x) {
return generic_i0e<T>::run(x);
}
};
@ -214,7 +214,7 @@ struct bessel_i0_retval {
typedef Scalar type;
};
template <typename T, typename ScalarType>
template <typename T, typename ScalarType = typename unpacket_traits<T>::type>
struct generic_i0 {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE T run(const T& x) {
@ -224,11 +224,11 @@ struct generic_i0 {
}
};
template <typename Scalar>
template <typename T>
struct bessel_i0_impl {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar run(const Scalar x) {
return generic_i0<Scalar, Scalar>::run(x);
static EIGEN_STRONG_INLINE T run(const T x) {
return generic_i0<T>::run(x);
}
};
@ -237,7 +237,7 @@ struct bessel_i1e_retval {
typedef Scalar type;
};
template <typename T, typename ScalarType>
template <typename T, typename ScalarType = typename unpacket_traits<T>::type >
struct generic_i1e {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE T run(const T&) {
@ -396,20 +396,20 @@ struct generic_i1e<T, double> {
}
};
template <typename Scalar>
template <typename T>
struct bessel_i1e_impl {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar run(const Scalar x) {
return generic_i1e<Scalar, Scalar>::run(x);
static EIGEN_STRONG_INLINE T run(const T x) {
return generic_i1e<T>::run(x);
}
};
template <typename Scalar>
template <typename T>
struct bessel_i1_retval {
typedef Scalar type;
typedef T type;
};
template <typename T, typename ScalarType>
template <typename T, typename ScalarType = typename unpacket_traits<T>::type>
struct generic_i1 {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE T run(const T& x) {
@ -419,20 +419,20 @@ struct generic_i1 {
}
};
template <typename Scalar>
template <typename T>
struct bessel_i1_impl {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar run(const Scalar x) {
return generic_i1<Scalar, Scalar>::run(x);
static EIGEN_STRONG_INLINE T run(const T x) {
return generic_i1<T>::run(x);
}
};
template <typename Scalar>
template <typename T>
struct bessel_k0e_retval {
typedef Scalar type;
typedef T type;
};
template <typename T, typename ScalarType>
template <typename T, typename ScalarType = typename unpacket_traits<T>::type>
struct generic_k0e {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE T run(const T&) {
@ -582,20 +582,20 @@ struct generic_k0e<T, double> {
}
};
template <typename Scalar>
template <typename T>
struct bessel_k0e_impl {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar run(const Scalar x) {
return generic_k0e<Scalar, Scalar>::run(x);
static EIGEN_STRONG_INLINE T run(const T x) {
return generic_k0e<T>::run(x);
}
};
template <typename Scalar>
template <typename T>
struct bessel_k0_retval {
typedef Scalar type;
typedef T type;
};
template <typename T, typename ScalarType>
template <typename T, typename ScalarType = typename unpacket_traits<T>::type>
struct generic_k0 {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE T run(const T&) {
@ -754,20 +754,20 @@ struct generic_k0<T, double> {
}
};
template <typename Scalar>
template <typename T>
struct bessel_k0_impl {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar run(const Scalar x) {
return generic_k0<Scalar, Scalar>::run(x);
static EIGEN_STRONG_INLINE T run(const T x) {
return generic_k0<T>::run(x);
}
};
template <typename Scalar>
template <typename T>
struct bessel_k1e_retval {
typedef Scalar type;
typedef T type;
};
template <typename T, typename ScalarType>
template <typename T, typename ScalarType = typename unpacket_traits<T>::type>
struct generic_k1e {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE T run(const T&) {
@ -910,20 +910,20 @@ struct generic_k1e<T, double> {
}
};
template <typename Scalar>
template <typename T>
struct bessel_k1e_impl {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar run(const Scalar x) {
return generic_k1e<Scalar, Scalar>::run(x);
static EIGEN_STRONG_INLINE T run(const T x) {
return generic_k1e<T>::run(x);
}
};
template <typename Scalar>
template <typename T>
struct bessel_k1_retval {
typedef Scalar type;
typedef T type;
};
template <typename T, typename ScalarType>
template <typename T, typename ScalarType = typename unpacket_traits<T>::type>
struct generic_k1 {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE T run(const T&) {
@ -1076,20 +1076,20 @@ struct generic_k1<T, double> {
}
};
template <typename Scalar>
template <typename T>
struct bessel_k1_impl {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar run(const Scalar x) {
return generic_k1<Scalar, Scalar>::run(x);
static EIGEN_STRONG_INLINE T run(const T x) {
return generic_k1<T>::run(x);
}
};
template <typename Scalar>
template <typename T>
struct bessel_j0_retval {
typedef Scalar type;
typedef T type;
};
template <typename T, typename ScalarType>
template <typename T, typename ScalarType = typename unpacket_traits<T>::type>
struct generic_j0 {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE T run(const T&) {
@ -1276,20 +1276,20 @@ struct generic_j0<T, double> {
}
};
template <typename Scalar>
template <typename T>
struct bessel_j0_impl {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar run(const Scalar x) {
return generic_j0<Scalar, Scalar>::run(x);
static EIGEN_STRONG_INLINE T run(const T x) {
return generic_j0<T>::run(x);
}
};
template <typename Scalar>
template <typename T>
struct bessel_y0_retval {
typedef Scalar type;
typedef T type;
};
template <typename T, typename ScalarType>
template <typename T, typename ScalarType = typename unpacket_traits<T>::type>
struct generic_y0 {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE T run(const T&) {
@ -1474,20 +1474,20 @@ struct generic_y0<T, double> {
}
};
template <typename Scalar>
template <typename T>
struct bessel_y0_impl {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar run(const Scalar x) {
return generic_y0<Scalar, Scalar>::run(x);
static EIGEN_STRONG_INLINE T run(const T x) {
return generic_y0<T>::run(x);
}
};
template <typename Scalar>
template <typename T>
struct bessel_j1_retval {
typedef Scalar type;
typedef T type;
};
template <typename T, typename ScalarType>
template <typename T, typename ScalarType = typename unpacket_traits<T>::type>
struct generic_j1 {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE T run(const T&) {
@ -1665,20 +1665,20 @@ struct generic_j1<T, double> {
}
};
template <typename Scalar>
template <typename T>
struct bessel_j1_impl {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar run(const Scalar x) {
return generic_j1<Scalar, Scalar>::run(x);
static EIGEN_STRONG_INLINE T run(const T x) {
return generic_j1<T>::run(x);
}
};
template <typename Scalar>
template <typename T>
struct bessel_y1_retval {
typedef Scalar type;
typedef T type;
};
template <typename T, typename ScalarType>
template <typename T, typename ScalarType = typename unpacket_traits<T>::type>
struct generic_y1 {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE T run(const T&) {
@ -1868,11 +1868,11 @@ struct generic_y1<T, double> {
}
};
template <typename Scalar>
template <typename T>
struct bessel_y1_impl {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar run(const Scalar x) {
return generic_y1<Scalar, Scalar>::run(x);
static EIGEN_STRONG_INLINE T run(const T x) {
return generic_y1<T>::run(x);
}
};

View File

@ -19,8 +19,7 @@ namespace internal {
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet pbessel_i0(const Packet& x) {
typedef typename unpacket_traits<Packet>::type ScalarType;
using internal::generic_i0; return generic_i0<Packet, ScalarType>::run(x);
return numext::bessel_i0(x);
}
/** \internal \returns the exponentially scaled modified Bessel function of
@ -28,8 +27,7 @@ Packet pbessel_i0(const Packet& x) {
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet pbessel_i0e(const Packet& x) {
typedef typename unpacket_traits<Packet>::type ScalarType;
using internal::generic_i0e; return generic_i0e<Packet, ScalarType>::run(x);
return numext::bessel_i0e(x);
}
/** \internal \returns the exponentially scaled modified Bessel function of
@ -37,8 +35,7 @@ Packet pbessel_i0e(const Packet& x) {
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet pbessel_i1(const Packet& x) {
typedef typename unpacket_traits<Packet>::type ScalarType;
using internal::generic_i1; return generic_i1<Packet, ScalarType>::run(x);
return numext::bessel_i1(x);
}
/** \internal \returns the exponentially scaled modified Bessel function of
@ -46,8 +43,7 @@ Packet pbessel_i1(const Packet& x) {
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet pbessel_i1e(const Packet& x) {
typedef typename unpacket_traits<Packet>::type ScalarType;
using internal::generic_i1e; return generic_i1e<Packet, ScalarType>::run(x);
return numext::bessel_i1e(x);
}
/** \internal \returns the exponentially scaled modified Bessel function of
@ -55,8 +51,7 @@ Packet pbessel_i1e(const Packet& x) {
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet pbessel_j0(const Packet& x) {
typedef typename unpacket_traits<Packet>::type ScalarType;
using internal::generic_j0; return generic_j0<Packet, ScalarType>::run(x);
return numext::bessel_j0(x);
}
/** \internal \returns the exponentially scaled modified Bessel function of
@ -64,8 +59,7 @@ Packet pbessel_j0(const Packet& x) {
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet pbessel_j1(const Packet& x) {
typedef typename unpacket_traits<Packet>::type ScalarType;
using internal::generic_j1; return generic_j1<Packet, ScalarType>::run(x);
return numext::bessel_j1(x);
}
/** \internal \returns the exponentially scaled modified Bessel function of
@ -73,8 +67,7 @@ Packet pbessel_j1(const Packet& x) {
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet pbessel_y0(const Packet& x) {
typedef typename unpacket_traits<Packet>::type ScalarType;
using internal::generic_y0; return generic_y0<Packet, ScalarType>::run(x);
return numext::bessel_y0(x);
}
/** \internal \returns the exponentially scaled modified Bessel function of
@ -82,8 +75,7 @@ Packet pbessel_y0(const Packet& x) {
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet pbessel_y1(const Packet& x) {
typedef typename unpacket_traits<Packet>::type ScalarType;
using internal::generic_y1; return generic_y1<Packet, ScalarType>::run(x);
return numext::bessel_y1(x);
}
/** \internal \returns the exponentially scaled modified Bessel function of
@ -91,8 +83,7 @@ Packet pbessel_y1(const Packet& x) {
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet pbessel_k0(const Packet& x) {
typedef typename unpacket_traits<Packet>::type ScalarType;
using internal::generic_k0; return generic_k0<Packet, ScalarType>::run(x);
return numext::bessel_k0(x);
}
/** \internal \returns the exponentially scaled modified Bessel function of
@ -100,8 +91,7 @@ Packet pbessel_k0(const Packet& x) {
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet pbessel_k0e(const Packet& x) {
typedef typename unpacket_traits<Packet>::type ScalarType;
using internal::generic_k0e; return generic_k0e<Packet, ScalarType>::run(x);
return numext::bessel_k0e(x);
}
/** \internal \returns the exponentially scaled modified Bessel function of
@ -109,8 +99,7 @@ Packet pbessel_k0e(const Packet& x) {
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet pbessel_k1(const Packet& x) {
typedef typename unpacket_traits<Packet>::type ScalarType;
using internal::generic_k1; return generic_k1<Packet, ScalarType>::run(x);
return numext::bessel_k1(x);
}
/** \internal \returns the exponentially scaled modified Bessel function of
@ -118,8 +107,7 @@ Packet pbessel_k1(const Packet& x) {
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet pbessel_k1e(const Packet& x) {
typedef typename unpacket_traits<Packet>::type ScalarType;
using internal::generic_k1e; return generic_k1e<Packet, ScalarType>::run(x);
return numext::bessel_k1e(x);
}
} // end namespace internal

View File

@ -348,7 +348,7 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_fast_erf_float(const T& a_x) {
template <typename T>
struct erf_impl {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE T run(const T x) {
static EIGEN_STRONG_INLINE T run(const T& x) {
return generic_fast_erf_float(x);
}
};
@ -490,7 +490,8 @@ struct erfc_impl<double> {
template<typename T>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T flipsign(
const T& should_flipsign, const T& x) {
const T sign_mask = pset1<T>(-0.0);
typedef typename unpacket_traits<T>::type Scalar;
const T sign_mask = pset1<T>(Scalar(-0.0));
T sign_bit = pand<T>(should_flipsign, sign_mask);
return pxor<T>(sign_bit, x);
}

View File

@ -0,0 +1,46 @@
#ifndef EIGEN_AVX_BESSELFUNCTIONS_H
#define EIGEN_AVX_BESSELFUNCTIONS_H
namespace Eigen {
namespace internal {
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_i0)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_i0)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_i0e)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_i0e)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_i1)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_i1)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_i1e)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_i1e)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_j0)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_j0)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_j1)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_j1)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_k0)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_k0)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_k0e)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_k0e)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_k1)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_k1)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_k1e)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_k1e)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_y0)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_y0)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pbessel_y1)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pbessel_y1)
} // namespace internal
} // namespace Eigen
#endif // EIGEN_AVX_BESSELFUNCTIONS_H

View File

@ -0,0 +1,16 @@
#ifndef EIGEN_AVX_SPECIALFUNCTIONS_H
#define EIGEN_AVX_SPECIALFUNCTIONS_H
namespace Eigen {
namespace internal {
F16_PACKET_FUNCTION(Packet8f, Packet8h, perf)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, perf)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pndtri)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pndtri)
} // namespace internal
} // namespace Eigen
#endif // EIGEN_AVX_SPECIAL_FUNCTIONS_H

View File

@ -0,0 +1,46 @@
#ifndef EIGEN_AVX512_BESSELFUNCTIONS_H
#define EIGEN_AVX512_BESSELFUNCTIONS_H
namespace Eigen {
namespace internal {
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_i0)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_i0)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_i0e)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_i0e)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_i1)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_i1)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_i1e)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_i1e)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_j0)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_j0)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_j1)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_j1)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_k0)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_k0)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_k0e)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_k0e)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_k1)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_k1)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_k1e)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_k1e)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_y0)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_y0)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pbessel_y1)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pbessel_y1)
} // namespace internal
} // namespace Eigen
#endif // EIGEN_AVX512_BESSELFUNCTIONS_H

View File

@ -0,0 +1,16 @@
#ifndef EIGEN_AVX512_SPECIALFUNCTIONS_H
#define EIGEN_AVX512_SPECIALFUNCTIONS_H
namespace Eigen {
namespace internal {
F16_PACKET_FUNCTION(Packet16f, Packet16h, perf)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, perf)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pndtri)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pndtri)
} // namespace internal
} // namespace Eigen
#endif // EIGEN_AVX512_SPECIAL_FUNCTIONS_H

View File

@ -0,0 +1,54 @@
#ifndef EIGEN_NEON_BESSELFUNCTIONS_H
#define EIGEN_NEON_BESSELFUNCTIONS_H
namespace Eigen {
namespace internal {
#if EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC
#define NEON_HALF_TO_FLOAT_FUNCTIONS(METHOD) \
template <> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
Packet8hf METHOD<Packet8hf>(const Packet8hf& x) { \
const Packet4f lo = METHOD<Packet4f>(vcvt_f32_f16(vget_low_f16(x))); \
const Packet4f hi = METHOD<Packet4f>(vcvt_f32_f16(vget_high_f16(x))); \
return vcombine_f16(vcvt_f16_f32(lo), vcvt_f16_f32(hi)); \
} \
\
template <> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
Packet4hf METHOD<Packet4hf>(const Packet4hf& x) { \
return vcvt_f16_f32(METHOD<Packet4f>(vcvt_f32_f16(x))); \
}
NEON_HALF_TO_FLOAT_FUNCTIONS(pbessel_i0)
NEON_HALF_TO_FLOAT_FUNCTIONS(pbessel_i0e)
NEON_HALF_TO_FLOAT_FUNCTIONS(pbessel_i1)
NEON_HALF_TO_FLOAT_FUNCTIONS(pbessel_i1e)
NEON_HALF_TO_FLOAT_FUNCTIONS(pbessel_j0)
NEON_HALF_TO_FLOAT_FUNCTIONS(pbessel_j1)
NEON_HALF_TO_FLOAT_FUNCTIONS(pbessel_k0)
NEON_HALF_TO_FLOAT_FUNCTIONS(pbessel_k0e)
NEON_HALF_TO_FLOAT_FUNCTIONS(pbessel_k1)
NEON_HALF_TO_FLOAT_FUNCTIONS(pbessel_k1e)
NEON_HALF_TO_FLOAT_FUNCTIONS(pbessel_y0)
NEON_HALF_TO_FLOAT_FUNCTIONS(pbessel_y1)
#undef NEON_HALF_TO_FLOAT_FUNCTIONS
#endif
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_i0)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_i0e)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_i1)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_i1e)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_j0)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_j1)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_k0)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_k0e)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_k1)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_k1e)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_y0)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pbessel_y1)
} // namespace internal
} // namespace Eigen
#endif // EIGEN_NEON_BESSELFUNCTIONS_H

View File

@ -0,0 +1,34 @@
#ifndef EIGEN_NEON_SPECIALFUNCTIONS_H
#define EIGEN_NEON_SPECIALFUNCTIONS_H
namespace Eigen {
namespace internal {
#ifdef EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC
#define NEON_HALF_TO_FLOAT_FUNCTIONS(METHOD) \
template <> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
Packet8hf METHOD<Packet8hf>(const Packet8hf& x) { \
const Packet4f lo = METHOD<Packet4f>(vcvt_f32_f16(vget_low_f16(x))); \
const Packet4f hi = METHOD<Packet4f>(vcvt_f32_f16(vget_high_f16(x))); \
return vcombine_f16(vcvt_f16_f32(lo), vcvt_f16_f32(hi)); \
} \
\
template <> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
Packet4hf METHOD<Packet4hf>(const Packet4hf& x) { \
return vcvt_f16_f32(METHOD<Packet4f>(vcvt_f32_f16(x))); \
}
NEON_HALF_TO_FLOAT_FUNCTIONS(perf)
NEON_HALF_TO_FLOAT_FUNCTIONS(pndtri)
#undef NEON_HALF_TO_FLOAT_FUNCTIONS
#endif
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, perf)
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pndtri)
} // namespace internal
} // namespace Eigen
#endif // EIGEN_NEON_SPECIALFUNCTIONS_H

View File

@ -11,6 +11,17 @@
#include "main.h"
#include "../Eigen/SpecialFunctions"
// Hack to allow "implicit" conversions from double to Scalar via comma-initialization.
template<typename Derived>
Eigen::CommaInitializer<Derived> operator<<(Eigen::DenseBase<Derived>& dense, double v) {
return (dense << static_cast<typename Derived::Scalar>(v));
}
template<typename XprType>
Eigen::CommaInitializer<XprType>& operator,(Eigen::CommaInitializer<XprType>& ci, double v) {
return (ci, static_cast<typename XprType::Scalar>(v));
}
template<typename X, typename Y>
void verify_component_wise(const X& x, const Y& y)
{
@ -65,8 +76,8 @@ template<typename ArrayType> void array_special_functions()
// igamma(a, x) = gamma(a, x) / Gamma(a)
// where Gamma and gamma are considered the standard unnormalized
// upper and lower incomplete gamma functions, respectively.
ArrayType a = m1.abs() + 2;
ArrayType x = m2.abs() + 2;
ArrayType a = m1.abs() + Scalar(2);
ArrayType x = m2.abs() + Scalar(2);
ArrayType zero = ArrayType::Zero(rows, cols);
ArrayType one = ArrayType::Constant(rows, cols, Scalar(1.0));
ArrayType a_m1 = a - one;
@ -83,18 +94,18 @@ template<typename ArrayType> void array_special_functions()
VERIFY_IS_APPROX(Gamma_a_x + gamma_a_x, a.lgamma().exp());
// Gamma(a, x) == (a - 1) * Gamma(a-1, x) + x^(a-1) * exp(-x)
VERIFY_IS_APPROX(Gamma_a_x, (a - 1) * Gamma_a_m1_x + x.pow(a-1) * (-x).exp());
VERIFY_IS_APPROX(Gamma_a_x, (a - Scalar(1)) * Gamma_a_m1_x + x.pow(a-Scalar(1)) * (-x).exp());
// gamma(a, x) == (a - 1) * gamma(a-1, x) - x^(a-1) * exp(-x)
VERIFY_IS_APPROX(gamma_a_x, (a - 1) * gamma_a_m1_x - x.pow(a-1) * (-x).exp());
VERIFY_IS_APPROX(gamma_a_x, (a - Scalar(1)) * gamma_a_m1_x - x.pow(a-Scalar(1)) * (-x).exp());
}
{
// Verify for large a and x that values are between 0 and 1.
ArrayType m1 = ArrayType::Random(rows,cols);
ArrayType m2 = ArrayType::Random(rows,cols);
Scalar max_exponent = std::numeric_limits<Scalar>::max_exponent10;
ArrayType a = m1.abs() * pow(10., max_exponent - 1);
ArrayType x = m2.abs() * pow(10., max_exponent - 1);
int max_exponent = std::numeric_limits<Scalar>::max_exponent10;
ArrayType a = m1.abs() * Scalar(pow(10., max_exponent - 1));
ArrayType x = m2.abs() * Scalar(pow(10., max_exponent - 1));
for (int i = 0; i < a.size(); ++i) {
Scalar igam = numext::igamma(a(i), x(i));
VERIFY(0 <= igam);
@ -108,27 +119,37 @@ template<typename ArrayType> void array_special_functions()
Scalar x_s[] = {Scalar(0), Scalar(1), Scalar(1.5), Scalar(4), Scalar(0.0001), Scalar(1000.5)};
// location i*6+j corresponds to a_s[i], x_s[j].
Scalar igamma_s[][6] = {{0.0, nan, nan, nan, nan, nan},
{0.0, 0.6321205588285578, 0.7768698398515702,
0.9816843611112658, 9.999500016666262e-05, 1.0},
{0.0, 0.4275932955291202, 0.608374823728911,
0.9539882943107686, 7.522076445089201e-07, 1.0},
{0.0, 0.01898815687615381, 0.06564245437845008,
0.5665298796332909, 4.166333347221828e-18, 1.0},
{0.0, 0.9999780593618628, 0.9999899967080838,
0.9999996219837988, 0.9991370418689945, 1.0},
{0.0, 0.0, 0.0, 0.0, 0.0, 0.5042041932513908}};
Scalar igammac_s[][6] = {{nan, nan, nan, nan, nan, nan},
{1.0, 0.36787944117144233, 0.22313016014842982,
0.018315638888734182, 0.9999000049998333, 0.0},
{1.0, 0.5724067044708798, 0.3916251762710878,
0.04601170568923136, 0.9999992477923555, 0.0},
{1.0, 0.9810118431238462, 0.9343575456215499,
0.4334701203667089, 1.0, 0.0},
{1.0, 2.1940638138146658e-05, 1.0003291916285e-05,
3.7801620118431334e-07, 0.0008629581310054535,
0.0},
{1.0, 1.0, 1.0, 1.0, 1.0, 0.49579580674813944}};
Scalar igamma_s[][6] = {
{Scalar(0.0), nan, nan, nan, nan, nan},
{Scalar(0.0), Scalar(0.6321205588285578), Scalar(0.7768698398515702),
Scalar(0.9816843611112658), Scalar(9.999500016666262e-05),
Scalar(1.0)},
{Scalar(0.0), Scalar(0.4275932955291202), Scalar(0.608374823728911),
Scalar(0.9539882943107686), Scalar(7.522076445089201e-07),
Scalar(1.0)},
{Scalar(0.0), Scalar(0.01898815687615381),
Scalar(0.06564245437845008), Scalar(0.5665298796332909),
Scalar(4.166333347221828e-18), Scalar(1.0)},
{Scalar(0.0), Scalar(0.9999780593618628), Scalar(0.9999899967080838),
Scalar(0.9999996219837988), Scalar(0.9991370418689945), Scalar(1.0)},
{Scalar(0.0), Scalar(0.0), Scalar(0.0), Scalar(0.0), Scalar(0.0),
Scalar(0.5042041932513908)}};
Scalar igammac_s[][6] = {
{nan, nan, nan, nan, nan, nan},
{Scalar(1.0), Scalar(0.36787944117144233),
Scalar(0.22313016014842982), Scalar(0.018315638888734182),
Scalar(0.9999000049998333), Scalar(0.0)},
{Scalar(1.0), Scalar(0.5724067044708798), Scalar(0.3916251762710878),
Scalar(0.04601170568923136), Scalar(0.9999992477923555),
Scalar(0.0)},
{Scalar(1.0), Scalar(0.9810118431238462), Scalar(0.9343575456215499),
Scalar(0.4334701203667089), Scalar(1.0), Scalar(0.0)},
{Scalar(1.0), Scalar(2.1940638138146658e-05),
Scalar(1.0003291916285e-05), Scalar(3.7801620118431334e-07),
Scalar(0.0008629581310054535), Scalar(0.0)},
{Scalar(1.0), Scalar(1.0), Scalar(1.0), Scalar(1.0), Scalar(1.0),
Scalar(0.49579580674813944)}};
for (int i = 0; i < 6; ++i) {
for (int j = 0; j < 6; ++j) {
if ((std::isnan)(igamma_s[i][j])) {
@ -162,8 +183,8 @@ template<typename ArrayType> void array_special_functions()
ArrayType m1 = ArrayType::Random(32);
using std::sqrt;
ArrayType cdf_val = (m1 / sqrt(2.)).erf();
cdf_val = (cdf_val + 1.) / 2.;
ArrayType cdf_val = (m1 / Scalar(sqrt(2.))).erf();
cdf_val = (cdf_val + Scalar(1)) / Scalar(2);
verify_component_wise(cdf_val.ndtri(), m1););
}
@ -190,7 +211,6 @@ template<typename ArrayType> void array_special_functions()
CALL_SUBTEST( res = digamma(x); verify_component_wise(res, ref); );
}
#if EIGEN_HAS_C99_MATH
{
ArrayType n(11), x(11), res(11), ref(11);
@ -323,8 +343,8 @@ template<typename ArrayType> void array_special_functions()
ArrayType m3 = ArrayType::Random(32);
ArrayType one = ArrayType::Constant(32, Scalar(1.0));
const Scalar eps = std::numeric_limits<Scalar>::epsilon();
ArrayType a = (m1 * 4.0).exp();
ArrayType b = (m2 * 4.0).exp();
ArrayType a = (m1 * Scalar(4)).exp();
ArrayType b = (m2 * Scalar(4)).exp();
ArrayType x = m3.abs();
// betainc(a, 1, x) == x**a
@ -471,4 +491,7 @@ EIGEN_DECLARE_TEST(special_functions)
{
CALL_SUBTEST_1(array_special_functions<ArrayXf>());
CALL_SUBTEST_2(array_special_functions<ArrayXd>());
// TODO(cantonios): half/bfloat16 don't have enough precision to reproduce results above.
// CALL_SUBTEST_3(array_special_functions<ArrayX<Eigen::half>>());
// CALL_SUBTEST_4(array_special_functions<ArrayX<Eigen::bfloat16>>());
}

View File

@ -8,6 +8,7 @@
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#include <limits>
#include "packetmath_test_shared.h"
#include "../Eigen/SpecialFunctions"
@ -43,42 +44,48 @@ template<typename Scalar,typename Packet> void packetmath_real()
}
{
for (int i=0; i<size; ++i) {
data1[i] = internal::random<Scalar>(0,1);
data1[i] = internal::random<Scalar>(Scalar(0),Scalar(1));
}
CHECK_CWISE1_IF(internal::packet_traits<Scalar>::HasNdtri, numext::ndtri, internal::pndtri);
}
#endif // EIGEN_HAS_C99_MATH
// For bessel_i*e and bessel_j*, the valid range is negative reals.
for (int i=0; i<size; ++i)
{
data1[i] = internal::random<Scalar>(-1,1) * std::pow(Scalar(10), internal::random<Scalar>(-6,6));
data2[i] = internal::random<Scalar>(-1,1) * std::pow(Scalar(10), internal::random<Scalar>(-6,6));
}
const int max_exponent = numext::mini(std::numeric_limits<Scalar>::max_exponent10-1, 6);
for (int i=0; i<size; ++i)
{
data1[i] = internal::random<Scalar>(Scalar(-1),Scalar(1)) * Scalar(std::pow(Scalar(10), internal::random<Scalar>(Scalar(-max_exponent),Scalar(max_exponent))));
data2[i] = internal::random<Scalar>(Scalar(-1),Scalar(1)) * Scalar(std::pow(Scalar(10), internal::random<Scalar>(Scalar(-max_exponent),Scalar(max_exponent))));
}
CHECK_CWISE1_IF(PacketTraits::HasBessel, numext::bessel_i0e, internal::pbessel_i0e);
CHECK_CWISE1_IF(PacketTraits::HasBessel, numext::bessel_i1e, internal::pbessel_i1e);
CHECK_CWISE1_IF(PacketTraits::HasBessel, numext::bessel_j0, internal::pbessel_j0);
CHECK_CWISE1_IF(PacketTraits::HasBessel, numext::bessel_j1, internal::pbessel_j1);
CHECK_CWISE1_IF(PacketTraits::HasBessel, numext::bessel_i0e, internal::pbessel_i0e);
CHECK_CWISE1_IF(PacketTraits::HasBessel, numext::bessel_i1e, internal::pbessel_i1e);
CHECK_CWISE1_IF(PacketTraits::HasBessel, numext::bessel_j0, internal::pbessel_j0);
CHECK_CWISE1_IF(PacketTraits::HasBessel, numext::bessel_j1, internal::pbessel_j1);
}
// Use a smaller data range for the bessel_i* as these can become very large.
// Following #1693, we also restrict this range further to avoid inf's due to
// differences in pexp and exp.
for (int i=0; i<size; ++i) {
data1[i] = internal::random<Scalar>(0.01,1) * std::pow(
Scalar(9), internal::random<Scalar>(-1,2));
data2[i] = internal::random<Scalar>(0.01,1) * std::pow(
Scalar(9), internal::random<Scalar>(-1,2));
data1[i] = internal::random<Scalar>(Scalar(0.01),Scalar(1)) *
Scalar(std::pow(Scalar(9), internal::random<Scalar>(Scalar(-1),Scalar(2))));
data2[i] = internal::random<Scalar>(Scalar(0.01),Scalar(1)) *
Scalar(std::pow(Scalar(9), internal::random<Scalar>(Scalar(-1),Scalar(2))));
}
CHECK_CWISE1_IF(PacketTraits::HasBessel, numext::bessel_i0, internal::pbessel_i0);
CHECK_CWISE1_IF(PacketTraits::HasBessel, numext::bessel_i1, internal::pbessel_i1);
// y_i, and k_i are valid for x > 0.
for (int i=0; i<size; ++i)
{
data1[i] = internal::random<Scalar>(0.01,1) * std::pow(Scalar(10), internal::random<Scalar>(-2,5));
data2[i] = internal::random<Scalar>(0.01,1) * std::pow(Scalar(10), internal::random<Scalar>(-2,5));
const int max_exponent = numext::mini(std::numeric_limits<Scalar>::max_exponent10-1, 5);
for (int i=0; i<size; ++i)
{
data1[i] = internal::random<Scalar>(Scalar(0.01),Scalar(1)) * Scalar(std::pow(Scalar(10), internal::random<Scalar>(Scalar(-2),Scalar(max_exponent))));
data2[i] = internal::random<Scalar>(Scalar(0.01),Scalar(1)) * Scalar(std::pow(Scalar(10), internal::random<Scalar>(Scalar(-2),Scalar(max_exponent))));
}
}
// TODO(srvasude): Re-enable this test once properly investigated why the
@ -91,20 +98,20 @@ template<typename Scalar,typename Packet> void packetmath_real()
// Following #1693, we restrict the range for exp to avoid zeroing out too
// fast.
for (int i=0; i<size; ++i) {
data1[i] = internal::random<Scalar>(0.01,1) * std::pow(
Scalar(9), internal::random<Scalar>(-1,2));
data2[i] = internal::random<Scalar>(0.01,1) * std::pow(
Scalar(9), internal::random<Scalar>(-1,2));
data1[i] = internal::random<Scalar>(Scalar(0.01),Scalar(1)) *
Scalar(std::pow(Scalar(9), internal::random<Scalar>(Scalar(-1),Scalar(2))));
data2[i] = internal::random<Scalar>(Scalar(0.01),Scalar(1)) *
Scalar(std::pow(Scalar(9), internal::random<Scalar>(Scalar(-1),Scalar(2))));
}
CHECK_CWISE1_IF(PacketTraits::HasBessel, numext::bessel_k0, internal::pbessel_k0);
CHECK_CWISE1_IF(PacketTraits::HasBessel, numext::bessel_k1, internal::pbessel_k1);
for (int i=0; i<size; ++i) {
data1[i] = internal::random<Scalar>(0.01,1) * std::pow(
Scalar(10), internal::random<Scalar>(-1,2));
data2[i] = internal::random<Scalar>(0.01,1) * std::pow(
Scalar(10), internal::random<Scalar>(-1,2));
data1[i] = internal::random<Scalar>(Scalar(0.01),Scalar(1)) *
Scalar(std::pow(Scalar(10), internal::random<Scalar>(Scalar(-1),Scalar(2))));
data2[i] = internal::random<Scalar>(Scalar(0.01),Scalar(1)) *
Scalar(std::pow(Scalar(10), internal::random<Scalar>(Scalar(-1),Scalar(2))));
}
#if EIGEN_HAS_C99_MATH && (__cplusplus > 199711L)
@ -135,6 +142,8 @@ EIGEN_DECLARE_TEST(special_packetmath)
CALL_SUBTEST_1( test::runner<float>::run() );
CALL_SUBTEST_2( test::runner<double>::run() );
CALL_SUBTEST_3( test::runner<Eigen::half>::run() );
CALL_SUBTEST_4( test::runner<Eigen::bfloat16>::run() );
g_first_pass = false;
}
}