Add numext::fma and missing pmadd implementations.
This commit is contained in:
parent 754bd24f5e
commit d935916ac6
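For reference, the four packet operations touched by this commit compute pmadd(a, b, c) = a * b + c, pmsub(a, b, c) = a * b - c, pnmadd(a, b, c) = c - a * b, and pnmsub(a, b, c) = -(a * b + c). A minimal scalar sketch of these identities using std::fma (illustrative only, not Eigen code):

// Scalar reference for the four FMA-style operations wired up in this commit
// (sketch only; double used for clarity).
#include <cmath>
#include <cstdio>

double ref_madd(double a, double b, double c) { return std::fma(a, b, c); }    // a * b + c
double ref_msub(double a, double b, double c) { return std::fma(a, b, -c); }   // a * b - c
double ref_nmadd(double a, double b, double c) { return std::fma(-a, b, c); }  // c - a * b
double ref_nmsub(double a, double b, double c) { return -std::fma(a, b, c); }  // -(a * b + c)

int main() {
  std::printf("%g %g %g %g\n", ref_madd(2, 3, 4), ref_msub(2, 3, 4), ref_nmadd(2, 3, 4), ref_nmsub(2, 3, 4));
  // prints: 10 2 -2 -10
  return 0;
}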
@@ -368,6 +368,11 @@ template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pdiv(const Packet& a, const Packet& b) {
  return a / b;
}
+// Avoid compiler warning for boolean algebra.
+template <>
+EIGEN_DEVICE_FUNC inline bool pdiv(const bool& a, const bool& b) {
+  return a && b;
+}

// In the generic case, memset to all one bits.
template <typename Packet, typename EnableIf = void>
@@ -1294,29 +1299,61 @@ EIGEN_DEVICE_FUNC inline bool predux_any(const Packet& a) {
* The following functions might not have to be overwritten for vectorized types
***************************************************************************/

+template <typename Packet, typename EnableIf = void>
+struct pmadd_impl {
+  static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet pmadd(const Packet& a, const Packet& b, const Packet& c) {
+    return padd(pmul(a, b), c);
+  }
+  static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet pmsub(const Packet& a, const Packet& b, const Packet& c) {
+    return psub(pmul(a, b), c);
+  }
+  static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet pnmadd(const Packet& a, const Packet& b, const Packet& c) {
+    return psub(c, pmul(a, b));
+  }
+  static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet pnmsub(const Packet& a, const Packet& b, const Packet& c) {
+    return pnegate(pmadd(a, b, c));
+  }
+};
+
+template <typename Scalar>
+struct pmadd_impl<Scalar, std::enable_if_t<is_scalar<Scalar>::value && NumTraits<Scalar>::IsSigned>> {
+  static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pmadd(const Scalar& a, const Scalar& b, const Scalar& c) {
+    return numext::fma(a, b, c);
+  }
+  static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pmsub(const Scalar& a, const Scalar& b, const Scalar& c) {
+    return numext::fma(a, b, Scalar(-c));
+  }
+  static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pnmadd(const Scalar& a, const Scalar& b, const Scalar& c) {
+    return numext::fma(Scalar(-a), b, c);
+  }
+  static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pnmsub(const Scalar& a, const Scalar& b, const Scalar& c) {
+    return -Scalar(numext::fma(a, b, c));
+  }
+};
+
// FMA instructions.
/** \internal \returns a * b + c (coeff-wise) */
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pmadd(const Packet& a, const Packet& b, const Packet& c) {
-  return padd(pmul(a, b), c);
+  return pmadd_impl<Packet>::pmadd(a, b, c);
}

/** \internal \returns a * b - c (coeff-wise) */
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pmsub(const Packet& a, const Packet& b, const Packet& c) {
-  return psub(pmul(a, b), c);
+  return pmadd_impl<Packet>::pmsub(a, b, c);
}

/** \internal \returns -(a * b) + c (coeff-wise) */
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pnmadd(const Packet& a, const Packet& b, const Packet& c) {
-  return psub(c, pmul(a, b));
+  return pmadd_impl<Packet>::pnmadd(a, b, c);
}

/** \internal \returns -((a * b + c) (coeff-wise) */
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pnmsub(const Packet& a, const Packet& b, const Packet& c) {
-  return pnegate(pmadd(a, b, c));
+  return pmadd_impl<Packet>::pnmsub(a, b, c);
}

/** \internal copy a packet with constant coefficient \a a (e.g., [a,a,a,a]) to \a *to. \a to must be 16 bytes aligned
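The generic pmadd_impl above keeps the padd(pmul(a, b), c) fallback, while the partial specialization routes signed scalar types through numext::fma. A minimal sketch of the resulting dispatch, assuming this commit is applied and Eigen/Core is available (the values shown are illustrative):

// Sketch only (not part of the commit): which pmadd_impl branch a type selects.
#include <Eigen/Core>
#include <iostream>

int main() {
  // float is a signed scalar -> specialized path -> numext::fma.
  float f = Eigen::internal::pmadd(2.0f, 3.0f, 4.0f);
  // unsigned int is not signed -> generic path -> padd(pmul(a, b), c).
  unsigned int u = Eigen::internal::pmadd(2u, 3u, 4u);
  std::cout << f << " " << u << std::endl;  // 10 10
  return 0;
}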
@@ -945,6 +945,38 @@ struct nearest_integer_impl<Scalar, true> {
  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_trunc(const Scalar& x) { return x; }
};

+// Default implementation.
+template <typename Scalar, typename Enable = void>
+struct fma_impl {
+  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Scalar& a, const Scalar& b, const Scalar& c) {
+    return a * b + c;
+  }
+};
+
+// ADL version if it exists.
+template <typename T>
+struct fma_impl<
+    T,
+    std::enable_if_t<std::is_same<T, decltype(fma(std::declval<T>(), std::declval<T>(), std::declval<T>()))>::value>> {
+  static T run(const T& a, const T& b, const T& c) { return fma(a, b, c); }
+};
+
+#if defined(EIGEN_GPUCC)
+template <>
+struct fma_impl<float, void> {
+  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float run(const float& a, const float& b, const float& c) {
+    return ::fmaf(a, b, c);
+  }
+};
+
+template <>
+struct fma_impl<double, void> {
+  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double run(const double& a, const double& b, const double& c) {
+    return ::fma(a, b, c);
+  }
+};
+#endif
+
} // end namespace internal

/****************************************************************************
@@ -1852,6 +1884,15 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar arithmetic_shift_right(const Scalar
  return bit_cast<Scalar, SignedScalar>(bit_cast<SignedScalar, Scalar>(a) >> n);
}

+// Use std::fma if available.
+using std::fma;
+
+// Otherwise, rely on template implementation.
+template <typename Scalar>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar fma(const Scalar& x, const Scalar& y, const Scalar& z) {
+  return internal::fma_impl<Scalar>::run(x, y, z);
+}
+
} // end namespace numext

namespace internal {
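With the two hunks above, numext::fma(x, y, z) resolves to std::fma for built-in floating-point types, to an ADL-found fma overload when one exists for the argument type, and to the plain a * b + c fallback otherwise. A hedged sketch of the ADL path using a hypothetical user type (mylib::MyFixed and its fma overload are illustrative, not Eigen code):

// Sketch only: a user-defined type whose own fma() is picked up by
// numext::fma through the ADL-based fma_impl specialization.
#include <Eigen/Core>
#include <iostream>

namespace mylib {
struct MyFixed {
  double v;
};
// Found by argument-dependent lookup from Eigen::internal::fma_impl.
inline MyFixed fma(const MyFixed& a, const MyFixed& b, const MyFixed& c) { return {a.v * b.v + c.v}; }
}  // namespace mylib

int main() {
  double d = Eigen::numext::fma(2.0, 3.0, 4.0);  // std::fma path
  mylib::MyFixed f = Eigen::numext::fma(mylib::MyFixed{2}, mylib::MyFixed{3}, mylib::MyFixed{4});  // ADL path
  std::cout << d << " " << f.v << std::endl;     // 10 10
  return 0;
}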
@@ -2415,6 +2415,26 @@ EIGEN_STRONG_INLINE Packet8h pmul<Packet8h>(const Packet8h& a, const Packet8h& b
  return float2half(rf);
}

+template <>
+EIGEN_STRONG_INLINE Packet8h pmadd<Packet8h>(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
+  return float2half(pmadd(half2float(a), half2float(b), half2float(c)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pmsub<Packet8h>(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
+  return float2half(pmsub(half2float(a), half2float(b), half2float(c)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pnmadd<Packet8h>(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
+  return float2half(pnmadd(half2float(a), half2float(b), half2float(c)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8h pnmsub<Packet8h>(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
+  return float2half(pnmsub(half2float(a), half2float(b), half2float(c)));
+}
+
template <>
EIGEN_STRONG_INLINE Packet8h pdiv<Packet8h>(const Packet8h& a, const Packet8h& b) {
  Packet8f af = half2float(a);
@@ -2785,6 +2805,26 @@ EIGEN_STRONG_INLINE Packet8bf pmul<Packet8bf>(const Packet8bf& a, const Packet8b
  return F32ToBf16(pmul<Packet8f>(Bf16ToF32(a), Bf16ToF32(b)));
}

+template <>
+EIGEN_STRONG_INLINE Packet8bf pmadd<Packet8bf>(const Packet8bf& a, const Packet8bf& b, const Packet8bf& c) {
+  return F32ToBf16(pmadd(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8bf pmsub<Packet8bf>(const Packet8bf& a, const Packet8bf& b, const Packet8bf& c) {
+  return F32ToBf16(pmsub(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8bf pnmadd<Packet8bf>(const Packet8bf& a, const Packet8bf& b, const Packet8bf& c) {
+  return F32ToBf16(pnmadd(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8bf pnmsub<Packet8bf>(const Packet8bf& a, const Packet8bf& b, const Packet8bf& c) {
+  return F32ToBf16(pnmsub(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c)));
+}
+
template <>
EIGEN_STRONG_INLINE Packet8bf pdiv<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
  return F32ToBf16(pdiv<Packet8f>(Bf16ToF32(a), Bf16ToF32(b)));
@@ -2264,6 +2264,7 @@ template <>
EIGEN_STRONG_INLINE void pstore<half>(Eigen::half* to, const Packet16h& from) {
  // (void*) -> workaround clang warning:
  // cast from 'Eigen::half *' to '__m256i *' increases required alignment from 2 to 32
+  EIGEN_DEBUG_ALIGNED_STORE
  _mm256_store_si256((__m256i*)(void*)to, from);
}

@@ -2271,6 +2272,7 @@ template <>
EIGEN_STRONG_INLINE void pstoreu<half>(Eigen::half* to, const Packet16h& from) {
  // (void*) -> workaround clang warning:
  // cast from 'Eigen::half *' to '__m256i *' increases required alignment from 2 to 32
+  EIGEN_DEBUG_UNALIGNED_STORE
  _mm256_storeu_si256((__m256i*)(void*)to, from);
}

@@ -2754,11 +2756,13 @@ EIGEN_STRONG_INLINE Packet16bf ploadu<Packet16bf>(const bfloat16* from) {

template <>
EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to, const Packet16bf& from) {
+  EIGEN_DEBUG_ALIGNED_STORE
  _mm256_store_si256(reinterpret_cast<__m256i*>(to), from);
}

template <>
EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to, const Packet16bf& from) {
+  EIGEN_DEBUG_UNALIGNED_STORE
  _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from);
}

@@ -3154,32 +3158,46 @@ EIGEN_STRONG_INLINE Packet8s pset1<Packet8s>(const numext::int16_t& x) {

template <>
EIGEN_STRONG_INLINE void pstore<numext::int16_t, Packet32s>(numext::int16_t* out, const Packet32s& x) {
-  _mm512_storeu_epi16(out, x);
+  EIGEN_DEBUG_ALIGNED_STORE
+  _mm512_store_epi32(out, x);
}

template <>
EIGEN_STRONG_INLINE void pstore<numext::int16_t, Packet16s>(numext::int16_t* out, const Packet16s& x) {
-  _mm256_storeu_epi16(out, x);
+  EIGEN_DEBUG_ALIGNED_STORE
+#if defined(EIGEN_VECTORIZE_AVX512F) && defined(EIGEN_VECTORIZE_AVX512VL)
+  _mm256_store_epi32(out, x);
+#else
+  _mm256_store_si256(reinterpret_cast<__m256i*>(out), x);
+#endif
}

template <>
EIGEN_STRONG_INLINE void pstore<numext::int16_t, Packet8s>(numext::int16_t* out, const Packet8s& x) {
-  _mm_storeu_epi16(out, x);
+  EIGEN_DEBUG_ALIGNED_STORE
+#if defined(EIGEN_VECTORIZE_AVX512F) && defined(EIGEN_VECTORIZE_AVX512VL)
+  _mm256_store_epi32(out, x);
+#else
+  _mm_store_si128(reinterpret_cast<__m128i*>(out), x);
+#endif
}

template <>
EIGEN_STRONG_INLINE void pstoreu<numext::int16_t, Packet32s>(numext::int16_t* out, const Packet32s& x) {
-  _mm512_storeu_epi16(out, x);
+  EIGEN_DEBUG_UNALIGNED_STORE
+  _mm512_storeu_epi32(out, x);
}

template <>
EIGEN_STRONG_INLINE void pstoreu<numext::int16_t, Packet16s>(numext::int16_t* out, const Packet16s& x) {
-  _mm256_storeu_epi16(out, x);
+  EIGEN_DEBUG_UNALIGNED_STORE
+  _mm256_storeu_epi32(out, x);
}

template <>
EIGEN_STRONG_INLINE void pstoreu<numext::int16_t, Packet8s>(numext::int16_t* out, const Packet8s& x) {
-  _mm_storeu_epi16(out, x);
+  EIGEN_DEBUG_UNALIGNED_STORE
+  _mm_storeu_epi32(out, x);
}

template <>
@@ -424,55 +424,6 @@ struct unpacket_traits<Packet8bf> {
    masked_store_available = false
  };
};
-inline std::ostream& operator<<(std::ostream& s, const Packet16c& v) {
-  union {
-    Packet16c v;
-    signed char n[16];
-  } vt;
-  vt.v = v;
-  for (int i = 0; i < 16; i++) s << vt.n[i] << ", ";
-  return s;
-}
-
-inline std::ostream& operator<<(std::ostream& s, const Packet16uc& v) {
-  union {
-    Packet16uc v;
-    unsigned char n[16];
-  } vt;
-  vt.v = v;
-  for (int i = 0; i < 16; i++) s << vt.n[i] << ", ";
-  return s;
-}
-
-inline std::ostream& operator<<(std::ostream& s, const Packet4f& v) {
-  union {
-    Packet4f v;
-    float n[4];
-  } vt;
-  vt.v = v;
-  s << vt.n[0] << ", " << vt.n[1] << ", " << vt.n[2] << ", " << vt.n[3];
-  return s;
-}
-
-inline std::ostream& operator<<(std::ostream& s, const Packet4i& v) {
-  union {
-    Packet4i v;
-    int n[4];
-  } vt;
-  vt.v = v;
-  s << vt.n[0] << ", " << vt.n[1] << ", " << vt.n[2] << ", " << vt.n[3];
-  return s;
-}
-
-inline std::ostream& operator<<(std::ostream& s, const Packet4ui& v) {
-  union {
-    Packet4ui v;
-    unsigned int n[4];
-  } vt;
-  vt.v = v;
-  s << vt.n[0] << ", " << vt.n[1] << ", " << vt.n[2] << ", " << vt.n[3];
-  return s;
-}

template <typename Packet>
EIGEN_STRONG_INLINE Packet pload_common(const __UNPACK_TYPE__(Packet) * from) {
@@ -2384,6 +2335,44 @@ EIGEN_STRONG_INLINE Packet8bf pmadd(const Packet8bf& a, const Packet8bf& b, cons
  return F32ToBf16(pmadd_even, pmadd_odd);
}

+template <>
+EIGEN_STRONG_INLINE Packet8bf pmsub(const Packet8bf& a, const Packet8bf& b, const Packet8bf& c) {
+  Packet4f a_even = Bf16ToF32Even(a);
+  Packet4f a_odd = Bf16ToF32Odd(a);
+  Packet4f b_even = Bf16ToF32Even(b);
+  Packet4f b_odd = Bf16ToF32Odd(b);
+  Packet4f c_even = Bf16ToF32Even(c);
+  Packet4f c_odd = Bf16ToF32Odd(c);
+  Packet4f pmadd_even = pmsub<Packet4f>(a_even, b_even, c_even);
+  Packet4f pmadd_odd = pmsub<Packet4f>(a_odd, b_odd, c_odd);
+  return F32ToBf16(pmadd_even, pmadd_odd);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8bf pnmadd(const Packet8bf& a, const Packet8bf& b, const Packet8bf& c) {
+  Packet4f a_even = Bf16ToF32Even(a);
+  Packet4f a_odd = Bf16ToF32Odd(a);
+  Packet4f b_even = Bf16ToF32Even(b);
+  Packet4f b_odd = Bf16ToF32Odd(b);
+  Packet4f c_even = Bf16ToF32Even(c);
+  Packet4f c_odd = Bf16ToF32Odd(c);
+  Packet4f pmadd_even = pnmadd<Packet4f>(a_even, b_even, c_even);
+  Packet4f pmadd_odd = pnmadd<Packet4f>(a_odd, b_odd, c_odd);
+  return F32ToBf16(pmadd_even, pmadd_odd);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8bf pnmsub(const Packet8bf& a, const Packet8bf& b, const Packet8bf& c) {
+  Packet4f a_even = Bf16ToF32Even(a);
+  Packet4f a_odd = Bf16ToF32Odd(a);
+  Packet4f b_even = Bf16ToF32Even(b);
+  Packet4f b_odd = Bf16ToF32Odd(b);
+  Packet4f c_even = Bf16ToF32Even(c);
+  Packet4f c_odd = Bf16ToF32Odd(c);
+  Packet4f pmadd_even = pnmsub<Packet4f>(a_even, b_even, c_even);
+  Packet4f pmadd_odd = pnmsub<Packet4f>(a_odd, b_odd, c_odd);
+  return F32ToBf16(pmadd_even, pmadd_odd);
+}
+
template <>
EIGEN_STRONG_INLINE Packet8bf pmin<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
  BF16_TO_F32_BINARY_OP_WRAPPER(pmin<Packet4f>, a, b);
@@ -673,6 +673,11 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmax(const bfloat16& a, const bfl
  return bfloat16(::fmaxf(f1, f2));
}

+EIGEN_DEVICE_FUNC inline bfloat16 fma(const bfloat16& a, const bfloat16& b, const bfloat16& c) {
+  // Emulate FMA via float.
+  return bfloat16(static_cast<float>(a) * static_cast<float>(b) + static_cast<float>(c));
+}
+
#ifndef EIGEN_NO_IO
EIGEN_ALWAYS_INLINE std::ostream& operator<<(std::ostream& os, const bfloat16& v) {
  os << static_cast<float>(v);
@@ -804,6 +804,18 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half(min)(const half& a, const half& b) {

EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half(max)(const half& a, const half& b) { return a < b ? b : a; }

+EIGEN_DEVICE_FUNC inline half fma(const half& a, const half& b, const half& c) {
+#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
+  return half(vfmah_f16(c.x, a.x, b.x));
+#elif defined(EIGEN_VECTORIZE_AVX512FP16)
+  // Reduces to vfmadd213sh.
+  return half(_mm_cvtsh_h(_mm_fmadd_ph(_mm_set_sh(a.x), _mm_set_sh(b.x), _mm_set_sh(c.x))));
+#else
+  // Emulate FMA via float.
+  return half(static_cast<float>(a) * static_cast<float>(b) + static_cast<float>(c));
+#endif
+}
+
#ifndef EIGEN_NO_IO
EIGEN_ALWAYS_INLINE std::ostream& operator<<(std::ostream& os, const half& v) {
  os << static_cast<float>(v);
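The scalar Eigen::half fma above is exactly what numext::fma picks up through the ADL-based fma_impl specialization shown earlier. A hedged usage sketch, assuming Eigen/Core is included (the vfmah_f16 and _mm_fmadd_ph paths are selected by the build flags shown in the code):

// Sketch only, not part of the commit.
#include <Eigen/Core>
#include <iostream>

int main() {
  Eigen::half a(1.5f), b(2.0f), c(0.25f);
  // Resolves to Eigen::fma(half, half, half) via ADL; computes a * b + c == 3.25.
  Eigen::half r = Eigen::numext::fma(a, b, c);
  std::cout << static_cast<float>(r) << std::endl;  // 3.25
  return 0;
}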
@@ -1023,36 +1035,6 @@ struct cast_impl<half, float> {
  }
};

-#ifdef EIGEN_VECTORIZE_FMA
-
-template <>
-EIGEN_DEVICE_FUNC inline half pmadd(const half& a, const half& b, const half& c) {
-#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
-  return half(vfmah_f16(a.x, b.x, c.x));
-#elif defined(EIGEN_VECTORIZE_AVX512FP16)
-  // Reduces to vfmadd213sh.
-  return half(_mm_cvtsh_h(_mm_fmadd_ph(_mm_set_sh(a.x), _mm_set_sh(b.x), _mm_set_sh(c.x))));
-#else
-  // Emulate FMA via float.
-  return half(static_cast<float>(a) * static_cast<float>(b) + static_cast<float>(c));
-#endif
-}
-
-template <>
-EIGEN_DEVICE_FUNC inline half pmsub(const half& a, const half& b, const half& c) {
-#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
-  return half(vfmah_f16(a.x, b.x, -c.x));
-#elif defined(EIGEN_VECTORIZE_AVX512FP16)
-  // Reduces to vfmadd213sh.
-  return half(_mm_cvtsh_h(_mm_fmadd_ph(_mm_set_sh(a.x), _mm_set_sh(b.x), -_mm_set_sh(c.x))));
-#else
-  // Emulate FMA via float.
-  return half(static_cast<float>(a) * static_cast<float>(b) - static_cast<float>(c));
-#endif
-}
-
-#endif
-
} // namespace internal
} // namespace Eigen

@@ -4964,6 +4964,26 @@ EIGEN_STRONG_INLINE Packet4bf pmul<Packet4bf>(const Packet4bf& a, const Packet4b
  return F32ToBf16(pmul<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
}

+template <>
+EIGEN_STRONG_INLINE Packet4bf pmadd<Packet4bf>(const Packet4bf& a, const Packet4bf& b, const Packet4bf& c) {
+  return F32ToBf16(pmadd<Packet4f>(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4bf pmsub<Packet4bf>(const Packet4bf& a, const Packet4bf& b, const Packet4bf& c) {
+  return F32ToBf16(pmsub<Packet4f>(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4bf pnmadd<Packet4bf>(const Packet4bf& a, const Packet4bf& b, const Packet4bf& c) {
+  return F32ToBf16(pnmadd<Packet4f>(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4bf pnmsub<Packet4bf>(const Packet4bf& a, const Packet4bf& b, const Packet4bf& c) {
+  return F32ToBf16(pnmsub<Packet4f>(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c)));
+}
+
template <>
EIGEN_STRONG_INLINE Packet4bf pdiv<Packet4bf>(const Packet4bf& a, const Packet4bf& b) {
  return F32ToBf16(pdiv<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
@@ -5634,6 +5654,21 @@ EIGEN_STRONG_INLINE Packet4hf pmadd(const Packet4hf& a, const Packet4hf& b, cons
  return vfma_f16(c, a, b);
}

+template <>
+EIGEN_STRONG_INLINE Packet8hf pmsub(const Packet8hf& a, const Packet8hf& b, const Packet8hf& c) {
+  return vfmaq_f16(pnegate(c), a, b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pnmadd(const Packet4hf& a, const Packet4hf& b, const Packet4hf& c) {
+  return vfma_f16(c, pnegate(a), b);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pnmsub(const Packet4hf& a, const Packet4hf& b, const Packet4hf& c) {
+  return vfma_f16(pnegate(c), pnegate(a), b);
+}
+
template <>
EIGEN_STRONG_INLINE Packet8hf pmin<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
  return vminq_f16(a, b);
scripts/msvc_setup.ps1 (new file, +21 lines)
@@ -0,0 +1,21 @@
+# Powershell script to set up MSVC environment.
+
+param ($EIGEN_CI_MSVC_ARCH, $EIGEN_CI_MSVC_VER)
+
+Set-PSDebug -Trace 1
+
+function Get-ScriptDirectory { Split-Path $MyInvocation.ScriptName }
+
+# Set defaults if not already set.
+IF (!$EIGEN_CI_MSVC_ARCH) { $EIGEN_CI_MSVC_ARCH = "x64" }
+IF (!$EIGEN_CI_MSVC_VER) { $EIGEN_CI_MSVC_VER = "14.29" }
+
+# Export variables into the global scope
+$global:EIGEN_CI_MSVC_ARCH = $EIGEN_CI_MSVC_ARCH
+$global:EIGEN_CI_MSVC_VER = $EIGEN_CI_MSVC_VER
+
+# Find Visual Studio installation directory.
+$global:VS_INSTALL_DIR = &"${Env:ProgramFiles(x86)}\Microsoft Visual Studio\Installer\vswhere.exe" -latest -property installationPath
+# Run VCVarsAll.bat initialization script and extract environment variables.
+# http://allen-mack.blogspot.com/2008/03/replace-visual-studio-command-prompt.html
+cmd.exe /c "`"${VS_INSTALL_DIR}\VC\Auxiliary\Build\vcvarsall.bat`" $EIGEN_CI_MSVC_ARCH -vcvars_ver=$EIGEN_CI_MSVC_VER & set" | foreach { if ($_ -match "=") { $v = $_.split("="); set-item -force -path "ENV:\$($v[0])" -value "$($v[1])" } }
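A hedged example of how the new script might be invoked from a CI step (the invocation line is illustrative; the parameter names and defaults come from the script above):

# Sketch only, not part of the commit.
powershell -ExecutionPolicy Bypass -File scripts\msvc_setup.ps1 -EIGEN_CI_MSVC_ARCH x64 -EIGEN_CI_MSVC_VER 14.29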
@@ -24,21 +24,55 @@ template <typename T>
inline T REF_MUL(const T& a, const T& b) {
  return a * b;
}
+
+template <typename Scalar, typename EnableIf = void>
+struct madd_impl {
+  static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar madd(const Scalar& a, const Scalar& b, const Scalar& c) {
+    return a * b + c;
+  }
+  static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar msub(const Scalar& a, const Scalar& b, const Scalar& c) {
+    return a * b - c;
+  }
+  static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar nmadd(const Scalar& a, const Scalar& b, const Scalar& c) {
+    return c - a * b;
+  }
+  static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar nmsub(const Scalar& a, const Scalar& b, const Scalar& c) {
+    return Scalar(0) - (a * b + c);
+  }
+};
+
+template <typename Scalar>
+struct madd_impl<Scalar,
+                 std::enable_if_t<Eigen::internal::is_scalar<Scalar>::value && Eigen::NumTraits<Scalar>::IsSigned>> {
+  static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar madd(const Scalar& a, const Scalar& b, const Scalar& c) {
+    return numext::fma(a, b, c);
+  }
+  static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar msub(const Scalar& a, const Scalar& b, const Scalar& c) {
+    return numext::fma(a, b, Scalar(-c));
+  }
+  static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar nmadd(const Scalar& a, const Scalar& b, const Scalar& c) {
+    return numext::fma(Scalar(-a), b, c);
+  }
+  static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar nmsub(const Scalar& a, const Scalar& b, const Scalar& c) {
+    return -Scalar(numext::fma(a, b, c));
+  }
+};
+
template <typename T>
inline T REF_MADD(const T& a, const T& b, const T& c) {
-  return internal::pmadd(a, b, c);
+  return madd_impl<T>::madd(a, b, c);
}
template <typename T>
inline T REF_MSUB(const T& a, const T& b, const T& c) {
-  return internal::pmsub(a, b, c);
+  return madd_impl<T>::msub(a, b, c);
}
template <typename T>
inline T REF_NMADD(const T& a, const T& b, const T& c) {
-  return internal::pnmadd(a, b, c);
+  return madd_impl<T>::nmadd(a, b, c);
}
template <typename T>
inline T REF_NMSUB(const T& a, const T& b, const T& c) {
-  return internal::pnmsub(a, b, c);
+  return madd_impl<T>::nmsub(a, b, c);
}
template <typename T>
inline T REF_DIV(const T& a, const T& b) {
@@ -70,6 +104,14 @@ template <>
inline bool REF_MADD(const bool& a, const bool& b, const bool& c) {
  return (a && b) || c;
}
+template <>
+inline bool REF_DIV(const bool& a, const bool& b) {
+  return a && b;
+}
+template <>
+inline bool REF_RECIPROCAL(const bool& a) {
+  return a;
+}

template <typename T>
inline T REF_FREXP(const T& x, T& exp) {
@@ -501,8 +543,8 @@ void packetmath() {
  eigen_optimization_barrier_test<Scalar>::run();

  for (int i = 0; i < size; ++i) {
-    data1[i] = internal::random<Scalar>() / RealScalar(PacketSize);
-    data2[i] = internal::random<Scalar>() / RealScalar(PacketSize);
+    data1[i] = internal::random<Scalar>();
+    data2[i] = internal::random<Scalar>();
    refvalue = (std::max)(refvalue, numext::abs(data1[i]));
  }

@@ -522,8 +564,8 @@ void packetmath() {
  for (int M = 0; M < PacketSize; ++M) {
    for (int N = 0; N <= PacketSize; ++N) {
      for (int j = 0; j < size; ++j) {
-        data1[j] = internal::random<Scalar>() / RealScalar(PacketSize);
-        data2[j] = internal::random<Scalar>() / RealScalar(PacketSize);
+        data1[j] = internal::random<Scalar>();
+        data2[j] = internal::random<Scalar>();
        refvalue = (std::max)(refvalue, numext::abs(data1[j]));
      }

@@ -652,11 +694,11 @@ void packetmath() {
  // Avoid overflows.
  if (NumTraits<Scalar>::IsInteger && NumTraits<Scalar>::IsSigned &&
      Eigen::internal::unpacket_traits<Packet>::size > 1) {
-    Scalar limit =
-        static_cast<Scalar>(std::pow(static_cast<double>(numext::real(NumTraits<Scalar>::highest())),
-                                     1.0 / static_cast<double>(Eigen::internal::unpacket_traits<Packet>::size)));
+    Scalar limit = static_cast<Scalar>(
+        static_cast<RealScalar>(std::pow(static_cast<double>(numext::real(NumTraits<Scalar>::highest())),
+                                         1.0 / static_cast<double>(Eigen::internal::unpacket_traits<Packet>::size))));
    for (int i = 0; i < PacketSize; ++i) {
-      data1[i] = internal::random<Scalar>(-limit, limit);
+      data1[i] = internal::random<Scalar>(Scalar(0) - limit, limit);
    }
  }
  ref[0] = Scalar(1);
@@ -1683,7 +1725,7 @@ void packetmath_scatter_gather() {

  for (Index N = 0; N <= PacketSize; ++N) {
    for (Index i = 0; i < N; ++i) {
-      data1[i] = internal::random<Scalar>() / RealScalar(PacketSize);
+      data1[i] = internal::random<Scalar>();
    }

    for (Index i = 0; i < N * 20; ++i) {
@@ -1702,7 +1744,7 @@ void packetmath_scatter_gather() {
    }

    for (Index i = 0; i < N * 7; ++i) {
-      buffer[i] = internal::random<Scalar>() / RealScalar(PacketSize);
+      buffer[i] = internal::random<Scalar>();
    }
    packet = internal::pgather_partial<Scalar, Packet>(buffer, 7, N);
    internal::pstore_partial(data1, packet, N);
@@ -162,7 +162,9 @@ struct packet_helper {

  template <typename T>
  inline Packet load(const T* from, unsigned long long umask) const {
-    return internal::ploadu<Packet>(from, umask);
+    using UMaskType = typename numext::get_integer_by_size<internal::plain_enum_max(
+        internal::unpacket_traits<Packet>::size / CHAR_BIT, 1)>::unsigned_type;
+    return internal::ploadu<Packet>(from, static_cast<UMaskType>(umask));
  }

  template <typename T>
@@ -172,7 +174,9 @@ struct packet_helper {

  template <typename T>
  inline void store(T* to, const Packet& x, unsigned long long umask) const {
-    internal::pstoreu(to, x, umask);
+    using UMaskType = typename numext::get_integer_by_size<internal::plain_enum_max(
+        internal::unpacket_traits<Packet>::size / CHAR_BIT, 1)>::unsigned_type;
+    internal::pstoreu(to, x, static_cast<UMaskType>(umask));
  }

  template <typename T>
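A worked example of the UMaskType computation in the two hunks above, assuming a 16-lane packet: unpacket_traits<Packet>::size / CHAR_BIT = 16 / 8 = 2, so get_integer_by_size<2>::unsigned_type is a 16-bit unsigned integer, one mask bit per lane. A hedged compile-time sketch (Packet16f assumes an AVX-512 build; the static_assert only documents the expectation):

// Sketch only, not part of the commit.
#include <Eigen/Core>
#include <climits>
#include <cstdint>
#include <type_traits>

#ifdef EIGEN_VECTORIZE_AVX512
using Packet = Eigen::internal::Packet16f;  // 16 lanes -> 16 / CHAR_BIT == 2 bytes
using UMaskType = typename Eigen::numext::get_integer_by_size<Eigen::internal::plain_enum_max(
    Eigen::internal::unpacket_traits<Packet>::size / CHAR_BIT, 1)>::unsigned_type;
static_assert(std::is_same<UMaskType, std::uint16_t>::value, "one mask bit per lane");
#endif

int main() { return 0; }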