Add numext::fma and missing pmadd implementations.

This commit is contained in:
Antonio Sánchez 2025-03-23 01:05:53 +00:00
parent 754bd24f5e
commit d935916ac6
11 changed files with 319 additions and 105 deletions

View File

@ -368,6 +368,11 @@ template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pdiv(const Packet& a, const Packet& b) {
// Generic coefficient-wise division; specialized for bool below.
return a / b;
}
// Avoid compiler warning for boolean algebra.
// For booleans, "division" degenerates to logical AND (1/1 == 1, x/0 and
// 0/x yield 0), which also sidesteps boolean-arithmetic compiler warnings.
template <>
EIGEN_DEVICE_FUNC inline bool pdiv(const bool& a, const bool& b) {
  if (!a) return false;
  return b;
}
// In the generic case, memset to all one bits.
template <typename Packet, typename EnableIf = void>
@ -1294,29 +1299,61 @@ EIGEN_DEVICE_FUNC inline bool predux_any(const Packet& a) {
* The following functions might not have to be overwritten for vectorized types
***************************************************************************/
// Generic fused-op emulation for packet types: each operation is an unfused
// multiply followed by an add/sub, relying only on pmul/padd/psub/pnegate.
template <typename Packet, typename EnableIf = void>
struct pmadd_impl {
  static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet pmadd(const Packet& a, const Packet& b, const Packet& c) {
    const Packet prod = pmul(a, b);
    return padd(prod, c);  // a * b + c
  }
  static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet pmsub(const Packet& a, const Packet& b, const Packet& c) {
    const Packet prod = pmul(a, b);
    return psub(prod, c);  // a * b - c
  }
  static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet pnmadd(const Packet& a, const Packet& b, const Packet& c) {
    const Packet prod = pmul(a, b);
    return psub(c, prod);  // c - a * b
  }
  static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet pnmsub(const Packet& a, const Packet& b, const Packet& c) {
    return pnegate(pmadd(a, b, c));  // -(a * b + c)
  }
};
// Signed scalar types route through numext::fma, which maps to a true fused
// multiply-add where the platform provides one. Restricted to signed types:
// the negations below are not meaningful for unsigned arithmetic.
template <typename Scalar>
struct pmadd_impl<Scalar, std::enable_if_t<is_scalar<Scalar>::value && NumTraits<Scalar>::IsSigned>> {
  static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pmadd(const Scalar& a, const Scalar& b, const Scalar& c) {
    return numext::fma(a, b, c);  // fused a * b + c
  }
  static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pmsub(const Scalar& a, const Scalar& b, const Scalar& c) {
    const Scalar neg_c = Scalar(-c);
    return numext::fma(a, b, neg_c);  // fused a * b - c
  }
  static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pnmadd(const Scalar& a, const Scalar& b, const Scalar& c) {
    const Scalar neg_a = Scalar(-a);
    return numext::fma(neg_a, b, c);  // fused c - a * b
  }
  static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pnmsub(const Scalar& a, const Scalar& b, const Scalar& c) {
    return -Scalar(numext::fma(a, b, c));  // -(a * b + c)
  }
};
// FMA instructions.
/** \internal \returns a * b + c (coeff-wise) */
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pmadd(const Packet& a, const Packet& b, const Packet& c) {
  // Dispatch through pmadd_impl: packets fall back to padd(pmul(a, b), c),
  // signed scalars map to numext::fma. (Removed the stale pre-refactor
  // `return padd(pmul(a, b), c);` that made this line unreachable.)
  return pmadd_impl<Packet>::pmadd(a, b, c);
}
/** \internal \returns a * b - c (coeff-wise) */
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pmsub(const Packet& a, const Packet& b, const Packet& c) {
  // Removed the stale duplicate `return psub(pmul(a, b), c);` left above the
  // new dispatch; pmadd_impl provides that fallback for packet types.
  return pmadd_impl<Packet>::pmsub(a, b, c);
}
/** \internal \returns -(a * b) + c (coeff-wise) */
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pnmadd(const Packet& a, const Packet& b, const Packet& c) {
  // Removed the stale duplicate `return psub(c, pmul(a, b));` left above the
  // new dispatch; pmadd_impl provides that fallback for packet types.
  return pmadd_impl<Packet>::pnmadd(a, b, c);
}
/** \internal \returns -(a * b + c) (coeff-wise) */
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pnmsub(const Packet& a, const Packet& b, const Packet& c) {
  // Removed the stale duplicate `return pnegate(pmadd(a, b, c));` left above
  // the new dispatch, and fixed the unbalanced parenthesis in the doc comment.
  return pmadd_impl<Packet>::pnmsub(a, b, c);
}
/** \internal copy a packet with constant coefficient \a a (e.g., [a,a,a,a]) to \a *to. \a to must be 16 bytes aligned

View File

@ -945,6 +945,38 @@ struct nearest_integer_impl<Scalar, true> {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_trunc(const Scalar& x) { return x; }
};
// Default implementation.
// Unfused fallback: plain multiply then add (two roundings for floating
// point types), used when no better fma is available for Scalar.
template <typename Scalar, typename Enable = void>
struct fma_impl {
  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Scalar& a, const Scalar& b, const Scalar& c) {
    const Scalar product = a * b;
    return product + c;
  }
};
// ADL version if it exists.
// Selected when an unqualified fma(T, T, T) -> T is found via argument-
// dependent lookup (e.g. the Eigen::half / Eigen::bfloat16 fma overloads).
// NOTE(review): run() is not marked EIGEN_DEVICE_FUNC/EIGEN_STRONG_INLINE,
// unlike the other fma_impl specializations -- presumably because an
// ADL-found fma cannot be assumed device-callable; confirm this is intended.
template <typename T>
struct fma_impl<
T,
std::enable_if_t<std::is_same<T, decltype(fma(std::declval<T>(), std::declval<T>(), std::declval<T>()))>::value>> {
static T run(const T& a, const T& b, const T& c) { return fma(a, b, c); }
};
#if defined(EIGEN_GPUCC)
// When compiling for GPU, map directly to the device fmaf/fma functions.
template <>
struct fma_impl<float, void> {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float run(const float& a, const float& b, const float& c) {
return ::fmaf(a, b, c);  // single-precision fused multiply-add
}
};
template <>
struct fma_impl<double, void> {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double run(const double& a, const double& b, const double& c) {
return ::fma(a, b, c);  // double-precision fused multiply-add
}
};
#endif
} // end namespace internal
/****************************************************************************
@ -1852,6 +1884,15 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar arithmetic_shift_right(const Scalar
return bit_cast<Scalar, SignedScalar>(bit_cast<SignedScalar, Scalar>(a) >> n);
}
// Use std::fma if available.
// The using-declaration brings the std overloads (float/double/long double)
// into this namespace so they win overload resolution over the template.
using std::fma;
// Otherwise, rely on template implementation.
// fma_impl dispatches, by specialization, to an ADL-found fma, a GPU
// intrinsic, or the unfused a * b + c fallback.
template <typename Scalar>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar fma(const Scalar& x, const Scalar& y, const Scalar& z) {
return internal::fma_impl<Scalar>::run(x, y, z);
}
} // end namespace numext
namespace internal {

View File

@ -2415,6 +2415,26 @@ EIGEN_STRONG_INLINE Packet8h pmul<Packet8h>(const Packet8h& a, const Packet8h& b
return float2half(rf);
}
// Packet8h FMA family: no native fp16 arithmetic here, so widen to float,
// use the float packet ops (which may fuse), and narrow back to half.
template <>
EIGEN_STRONG_INLINE Packet8h pmadd<Packet8h>(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
return float2half(pmadd(half2float(a), half2float(b), half2float(c)));  // a * b + c
}
template <>
EIGEN_STRONG_INLINE Packet8h pmsub<Packet8h>(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
return float2half(pmsub(half2float(a), half2float(b), half2float(c)));  // a * b - c
}
template <>
EIGEN_STRONG_INLINE Packet8h pnmadd<Packet8h>(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
return float2half(pnmadd(half2float(a), half2float(b), half2float(c)));  // c - a * b
}
template <>
EIGEN_STRONG_INLINE Packet8h pnmsub<Packet8h>(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
return float2half(pnmsub(half2float(a), half2float(b), half2float(c)));  // -(a * b + c)
}
template <>
EIGEN_STRONG_INLINE Packet8h pdiv<Packet8h>(const Packet8h& a, const Packet8h& b) {
Packet8f af = half2float(a);
@ -2785,6 +2805,26 @@ EIGEN_STRONG_INLINE Packet8bf pmul<Packet8bf>(const Packet8bf& a, const Packet8b
return F32ToBf16(pmul<Packet8f>(Bf16ToF32(a), Bf16ToF32(b)));
}
// Packet8bf FMA family: widen to float, run the (possibly fused) float
// packet ops, and narrow back to bfloat16.
template <>
EIGEN_STRONG_INLINE Packet8bf pmadd<Packet8bf>(const Packet8bf& a, const Packet8bf& b, const Packet8bf& c) {
return F32ToBf16(pmadd(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c)));  // a * b + c
}
template <>
EIGEN_STRONG_INLINE Packet8bf pmsub<Packet8bf>(const Packet8bf& a, const Packet8bf& b, const Packet8bf& c) {
return F32ToBf16(pmsub(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c)));  // a * b - c
}
template <>
EIGEN_STRONG_INLINE Packet8bf pnmadd<Packet8bf>(const Packet8bf& a, const Packet8bf& b, const Packet8bf& c) {
return F32ToBf16(pnmadd(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c)));  // c - a * b
}
template <>
EIGEN_STRONG_INLINE Packet8bf pnmsub<Packet8bf>(const Packet8bf& a, const Packet8bf& b, const Packet8bf& c) {
return F32ToBf16(pnmsub(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c)));  // -(a * b + c)
}
template <>
EIGEN_STRONG_INLINE Packet8bf pdiv<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
return F32ToBf16(pdiv<Packet8f>(Bf16ToF32(a), Bf16ToF32(b)));

View File

@ -2264,6 +2264,7 @@ template <>
EIGEN_STRONG_INLINE void pstore<half>(Eigen::half* to, const Packet16h& from) {
// (void*) -> workaround clang warning:
// cast from 'Eigen::half *' to '__m256i *' increases required alignment from 2 to 32
EIGEN_DEBUG_ALIGNED_STORE
_mm256_store_si256((__m256i*)(void*)to, from);
}
@ -2271,6 +2272,7 @@ template <>
EIGEN_STRONG_INLINE void pstoreu<half>(Eigen::half* to, const Packet16h& from) {
// (void*) -> workaround clang warning:
// cast from 'Eigen::half *' to '__m256i *' increases required alignment from 2 to 32
EIGEN_DEBUG_UNALIGNED_STORE
_mm256_storeu_si256((__m256i*)(void*)to, from);
}
@ -2754,11 +2756,13 @@ EIGEN_STRONG_INLINE Packet16bf ploadu<Packet16bf>(const bfloat16* from) {
template <>
EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to, const Packet16bf& from) {
EIGEN_DEBUG_ALIGNED_STORE  // debug instrumentation hook for aligned stores
_mm256_store_si256(reinterpret_cast<__m256i*>(to), from);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to, const Packet16bf& from) {
EIGEN_DEBUG_UNALIGNED_STORE  // debug instrumentation hook for unaligned stores
_mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from);
}
@ -3154,32 +3158,46 @@ EIGEN_STRONG_INLINE Packet8s pset1<Packet8s>(const numext::int16_t& x) {
template <>
EIGEN_STRONG_INLINE void pstore<numext::int16_t, Packet32s>(numext::int16_t* out, const Packet32s& x) {
  EIGEN_DEBUG_ALIGNED_STORE
  // Aligned 512-bit store; storing the int16 lanes via the int32 intrinsic is
  // a bitwise store, so lane width does not matter. (Removed the stale
  // pre-refactor `_mm512_storeu_epi16(out, x);` that duplicated the store.)
  _mm512_store_epi32(out, x);
}
template <>
EIGEN_STRONG_INLINE void pstore<numext::int16_t, Packet16s>(numext::int16_t* out, const Packet16s& x) {
  EIGEN_DEBUG_ALIGNED_STORE
  // Removed the stale `_mm256_storeu_epi16(out, x);` left above the guarded
  // store, which performed a redundant (unaligned) duplicate store.
#if defined(EIGEN_VECTORIZE_AVX512F) && defined(EIGEN_VECTORIZE_AVX512VL)
  // _mm256_store_epi32 requires AVX512VL.
  _mm256_store_epi32(out, x);
#else
  _mm256_store_si256(reinterpret_cast<__m256i*>(out), x);
#endif
}
template <>
EIGEN_STRONG_INLINE void pstore<numext::int16_t, Packet8s>(numext::int16_t* out, const Packet8s& x) {
  EIGEN_DEBUG_ALIGNED_STORE
  // Removed the stale `_mm_storeu_epi16(out, x);` duplicate store.
#if defined(EIGEN_VECTORIZE_AVX512F) && defined(EIGEN_VECTORIZE_AVX512VL)
  // Packet8s is a 128-bit packet (8 x int16): use the 128-bit store
  // _mm_store_epi32 (requires AVX512VL), not the 256-bit _mm256_store_epi32
  // which takes a __m256i and does not match this packet's width.
  _mm_store_epi32(out, x);
#else
  _mm_store_si128(reinterpret_cast<__m128i*>(out), x);
#endif
}
template <>
EIGEN_STRONG_INLINE void pstoreu<numext::int16_t, Packet32s>(numext::int16_t* out, const Packet32s& x) {
  EIGEN_DEBUG_UNALIGNED_STORE
  // Removed the stale `_mm512_storeu_epi16(out, x);` duplicate store; the
  // int32-lane intrinsic performs the same bitwise store and only needs
  // AVX512F (the epi16 variant requires AVX512BW).
  _mm512_storeu_epi32(out, x);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<numext::int16_t, Packet16s>(numext::int16_t* out, const Packet16s& x) {
  EIGEN_DEBUG_UNALIGNED_STORE
  // Removed the stale `_mm256_storeu_epi16(out, x);` duplicate store, and
  // added the same AVX512VL guard as the aligned pstore above:
  // _mm256_storeu_epi32 requires AVX512F + AVX512VL.
#if defined(EIGEN_VECTORIZE_AVX512F) && defined(EIGEN_VECTORIZE_AVX512VL)
  _mm256_storeu_epi32(out, x);
#else
  _mm256_storeu_si256(reinterpret_cast<__m256i*>(out), x);
#endif
}
template <>
EIGEN_STRONG_INLINE void pstoreu<numext::int16_t, Packet8s>(numext::int16_t* out, const Packet8s& x) {
  EIGEN_DEBUG_UNALIGNED_STORE
  // Removed the stale `_mm_storeu_epi16(out, x);` duplicate store, and added
  // the same AVX512VL guard as the aligned pstore above: _mm_storeu_epi32
  // requires AVX512F + AVX512VL.
#if defined(EIGEN_VECTORIZE_AVX512F) && defined(EIGEN_VECTORIZE_AVX512VL)
  _mm_storeu_epi32(out, x);
#else
  _mm_storeu_si128(reinterpret_cast<__m128i*>(out), x);
#endif
}
template <>

View File

@ -424,55 +424,6 @@ struct unpacket_traits<Packet8bf> {
masked_store_available = false
};
};
inline std::ostream& operator<<(std::ostream& s, const Packet16c& v) {
union {
Packet16c v;
signed char n[16];
} vt;
vt.v = v;
for (int i = 0; i < 16; i++) s << vt.n[i] << ", ";
return s;
}
inline std::ostream& operator<<(std::ostream& s, const Packet16uc& v) {
union {
Packet16uc v;
unsigned char n[16];
} vt;
vt.v = v;
for (int i = 0; i < 16; i++) s << vt.n[i] << ", ";
return s;
}
inline std::ostream& operator<<(std::ostream& s, const Packet4f& v) {
union {
Packet4f v;
float n[4];
} vt;
vt.v = v;
s << vt.n[0] << ", " << vt.n[1] << ", " << vt.n[2] << ", " << vt.n[3];
return s;
}
inline std::ostream& operator<<(std::ostream& s, const Packet4i& v) {
union {
Packet4i v;
int n[4];
} vt;
vt.v = v;
s << vt.n[0] << ", " << vt.n[1] << ", " << vt.n[2] << ", " << vt.n[3];
return s;
}
inline std::ostream& operator<<(std::ostream& s, const Packet4ui& v) {
union {
Packet4ui v;
unsigned int n[4];
} vt;
vt.v = v;
s << vt.n[0] << ", " << vt.n[1] << ", " << vt.n[2] << ", " << vt.n[3];
return s;
}
template <typename Packet>
EIGEN_STRONG_INLINE Packet pload_common(const __UNPACK_TYPE__(Packet) * from) {
@ -2384,6 +2335,44 @@ EIGEN_STRONG_INLINE Packet8bf pmadd(const Packet8bf& a, const Packet8bf& b, cons
return F32ToBf16(pmadd_even, pmadd_odd);
}
template <>
EIGEN_STRONG_INLINE Packet8bf pmsub(const Packet8bf& a, const Packet8bf& b, const Packet8bf& c) {
  // Widen the even/odd bfloat16 lanes to float, run the float pmsub on each
  // half, then narrow back to bfloat16.
  Packet4f msub_even = pmsub<Packet4f>(Bf16ToF32Even(a), Bf16ToF32Even(b), Bf16ToF32Even(c));
  Packet4f msub_odd = pmsub<Packet4f>(Bf16ToF32Odd(a), Bf16ToF32Odd(b), Bf16ToF32Odd(c));
  return F32ToBf16(msub_even, msub_odd);
}
template <>
EIGEN_STRONG_INLINE Packet8bf pnmadd(const Packet8bf& a, const Packet8bf& b, const Packet8bf& c) {
  // Widen the even/odd bfloat16 lanes to float, run the float pnmadd on each
  // half, then narrow back to bfloat16.
  Packet4f nmadd_even = pnmadd<Packet4f>(Bf16ToF32Even(a), Bf16ToF32Even(b), Bf16ToF32Even(c));
  Packet4f nmadd_odd = pnmadd<Packet4f>(Bf16ToF32Odd(a), Bf16ToF32Odd(b), Bf16ToF32Odd(c));
  return F32ToBf16(nmadd_even, nmadd_odd);
}
template <>
EIGEN_STRONG_INLINE Packet8bf pnmsub(const Packet8bf& a, const Packet8bf& b, const Packet8bf& c) {
  // Widen the even/odd bfloat16 lanes to float, run the float pnmsub on each
  // half, then narrow back to bfloat16.
  Packet4f nmsub_even = pnmsub<Packet4f>(Bf16ToF32Even(a), Bf16ToF32Even(b), Bf16ToF32Even(c));
  Packet4f nmsub_odd = pnmsub<Packet4f>(Bf16ToF32Odd(a), Bf16ToF32Odd(b), Bf16ToF32Odd(c));
  return F32ToBf16(nmsub_even, nmsub_odd);
}
template <>
EIGEN_STRONG_INLINE Packet8bf pmin<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
BF16_TO_F32_BINARY_OP_WRAPPER(pmin<Packet4f>, a, b);

View File

@ -673,6 +673,11 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmax(const bfloat16& a, const bfl
return bfloat16(::fmaxf(f1, f2));
}
EIGEN_DEVICE_FUNC inline bfloat16 fma(const bfloat16& a, const bfloat16& b, const bfloat16& c) {
  // Emulate FMA by computing in float and rounding the result back to
  // bfloat16 (there is no native scalar bfloat16 FMA here).
  const float fa = static_cast<float>(a);
  const float fb = static_cast<float>(b);
  const float fc = static_cast<float>(c);
  return bfloat16(fa * fb + fc);
}
#ifndef EIGEN_NO_IO
EIGEN_ALWAYS_INLINE std::ostream& operator<<(std::ostream& os, const bfloat16& v) {
os << static_cast<float>(v);

View File

@ -804,6 +804,18 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half(min)(const half& a, const half& b) {
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half(max)(const half& a, const half& b) { return a < b ? b : a; }
// Scalar half fused multiply-add: returns a * b + c.
EIGEN_DEVICE_FUNC inline half fma(const half& a, const half& b, const half& c) {
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
// vfmah_f16(addend, x, y) computes addend + x * y, hence (c, a, b) -> c + a * b.
return half(vfmah_f16(c.x, a.x, b.x));
#elif defined(EIGEN_VECTORIZE_AVX512FP16)
// Reduces to vfmadd213sh.
return half(_mm_cvtsh_h(_mm_fmadd_ph(_mm_set_sh(a.x), _mm_set_sh(b.x), _mm_set_sh(c.x))));
#else
// Emulate FMA via float.
return half(static_cast<float>(a) * static_cast<float>(b) + static_cast<float>(c));
#endif
}
#ifndef EIGEN_NO_IO
EIGEN_ALWAYS_INLINE std::ostream& operator<<(std::ostream& os, const half& v) {
os << static_cast<float>(v);
@ -1023,36 +1035,6 @@ struct cast_impl<half, float> {
}
};
#ifdef EIGEN_VECTORIZE_FMA
template <>
EIGEN_DEVICE_FUNC inline half pmadd(const half& a, const half& b, const half& c) {
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
// NOTE(review): vfmah_f16(x, y, z) computes x + y * z, so this returns
// a + b * c rather than a * b + c; the scalar fma() overload above passes
// (c.x, a.x, b.x) instead -- confirm this specialization is stale/buggy.
return half(vfmah_f16(a.x, b.x, c.x));
#elif defined(EIGEN_VECTORIZE_AVX512FP16)
// Reduces to vfmadd213sh.
return half(_mm_cvtsh_h(_mm_fmadd_ph(_mm_set_sh(a.x), _mm_set_sh(b.x), _mm_set_sh(c.x))));
#else
// Emulate FMA via float.
return half(static_cast<float>(a) * static_cast<float>(b) + static_cast<float>(c));
#endif
}
template <>
EIGEN_DEVICE_FUNC inline half pmsub(const half& a, const half& b, const half& c) {
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
// NOTE(review): same operand-order concern as pmadd above; this computes
// a + b * (-c), not a * b - c -- confirm.
return half(vfmah_f16(a.x, b.x, -c.x));
#elif defined(EIGEN_VECTORIZE_AVX512FP16)
// Reduces to vfmadd213sh.
return half(_mm_cvtsh_h(_mm_fmadd_ph(_mm_set_sh(a.x), _mm_set_sh(b.x), -_mm_set_sh(c.x))));
#else
// Emulate FMA via float.
return half(static_cast<float>(a) * static_cast<float>(b) - static_cast<float>(c));
#endif
}
#endif
} // namespace internal
} // namespace Eigen

View File

@ -4964,6 +4964,26 @@ EIGEN_STRONG_INLINE Packet4bf pmul<Packet4bf>(const Packet4bf& a, const Packet4b
return F32ToBf16(pmul<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
}
// Packet4bf FMA family: widen to float, run the float packet ops (which may
// fuse on NEON), and narrow back to bfloat16.
template <>
EIGEN_STRONG_INLINE Packet4bf pmadd<Packet4bf>(const Packet4bf& a, const Packet4bf& b, const Packet4bf& c) {
return F32ToBf16(pmadd<Packet4f>(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c)));  // a * b + c
}
template <>
EIGEN_STRONG_INLINE Packet4bf pmsub<Packet4bf>(const Packet4bf& a, const Packet4bf& b, const Packet4bf& c) {
return F32ToBf16(pmsub<Packet4f>(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c)));  // a * b - c
}
template <>
EIGEN_STRONG_INLINE Packet4bf pnmadd<Packet4bf>(const Packet4bf& a, const Packet4bf& b, const Packet4bf& c) {
return F32ToBf16(pnmadd<Packet4f>(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c)));  // c - a * b
}
template <>
EIGEN_STRONG_INLINE Packet4bf pnmsub<Packet4bf>(const Packet4bf& a, const Packet4bf& b, const Packet4bf& c) {
return F32ToBf16(pnmsub<Packet4f>(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c)));  // -(a * b + c)
}
template <>
EIGEN_STRONG_INLINE Packet4bf pdiv<Packet4bf>(const Packet4bf& a, const Packet4bf& b) {
return F32ToBf16(pdiv<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
@ -5634,6 +5654,21 @@ EIGEN_STRONG_INLINE Packet4hf pmadd(const Packet4hf& a, const Packet4hf& b, cons
return vfma_f16(c, a, b);
}
// NEON fp16 fused ops: vfma_f16/vfmaq_f16(addend, x, y) compute addend + x * y.
// NOTE(review): only Packet8hf pmsub and Packet4hf pnmadd/pnmsub appear here;
// verify the matching Packet4hf pmsub and Packet8hf pnmadd/pnmsub are
// defined elsewhere.
template <>
EIGEN_STRONG_INLINE Packet8hf pmsub(const Packet8hf& a, const Packet8hf& b, const Packet8hf& c) {
return vfmaq_f16(pnegate(c), a, b);  // -c + a * b == a * b - c
}
template <>
EIGEN_STRONG_INLINE Packet4hf pnmadd(const Packet4hf& a, const Packet4hf& b, const Packet4hf& c) {
return vfma_f16(c, pnegate(a), b);  // c + (-a) * b == c - a * b
}
template <>
EIGEN_STRONG_INLINE Packet4hf pnmsub(const Packet4hf& a, const Packet4hf& b, const Packet4hf& c) {
return vfma_f16(pnegate(c), pnegate(a), b);  // -c + (-a) * b == -(a * b + c)
}
template <>
EIGEN_STRONG_INLINE Packet8hf pmin<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
return vminq_f16(a, b);

21
scripts/msvc_setup.ps1 Normal file
View File

@ -0,0 +1,21 @@
# Powershell script to set up MSVC environment.
# Usage: msvc_setup.ps1 [-EIGEN_CI_MSVC_ARCH <arch>] [-EIGEN_CI_MSVC_VER <toolset-version>]
param ($EIGEN_CI_MSVC_ARCH, $EIGEN_CI_MSVC_VER)
Set-PSDebug -Trace 1
function Get-ScriptDirectory { Split-Path $MyInvocation.ScriptName }
# Set defaults if not already set.
IF (!$EIGEN_CI_MSVC_ARCH) { $EIGEN_CI_MSVC_ARCH = "x64" }
IF (!$EIGEN_CI_MSVC_VER) { $EIGEN_CI_MSVC_VER = "14.29" }
# Export variables into the global scope
$global:EIGEN_CI_MSVC_ARCH = $EIGEN_CI_MSVC_ARCH
$global:EIGEN_CI_MSVC_VER = $EIGEN_CI_MSVC_VER
# Find Visual Studio installation directory.
$global:VS_INSTALL_DIR = &"${Env:ProgramFiles(x86)}\Microsoft Visual Studio\Installer\vswhere.exe" -latest -property installationPath
# Run the VCVarsAll.bat initialization script and extract environment variables.
# http://allen-mack.blogspot.com/2008/03/replace-visual-studio-command-prompt.html
cmd.exe /c "`"${VS_INSTALL_DIR}\VC\Auxiliary\Build\vcvarsall.bat`" $EIGEN_CI_MSVC_ARCH -vcvars_ver=$EIGEN_CI_MSVC_VER & set" | foreach { if ($_ -match "=") { $v = $_.split("="); set-item -force -path "ENV:\$($v[0])" -value "$($v[1])" } }

View File

@ -24,21 +24,55 @@ template <typename T>
inline T REF_MUL(const T& a, const T& b) {
// Reference coefficient-wise multiply.
return a * b;
}
// Reference (unfused) implementations of the madd family used by the tests.
// This generic version applies to types not covered by the signed-scalar
// specialization below.
template <typename Scalar, typename EnableIf = void>
struct madd_impl {
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar madd(const Scalar& a, const Scalar& b, const Scalar& c) {
return a * b + c;
}
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar msub(const Scalar& a, const Scalar& b, const Scalar& c) {
return a * b - c;
}
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar nmadd(const Scalar& a, const Scalar& b, const Scalar& c) {
return c - a * b;
}
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar nmsub(const Scalar& a, const Scalar& b, const Scalar& c) {
// Written as 0 - x rather than -x so it stays warning-free for unsigned types.
return Scalar(0) - (a * b + c);
}
};
// Signed scalar types: mirror Eigen's pmadd_impl by routing through
// numext::fma, so reference values match a fused implementation exactly.
template <typename Scalar>
struct madd_impl<Scalar,
std::enable_if_t<Eigen::internal::is_scalar<Scalar>::value && Eigen::NumTraits<Scalar>::IsSigned>> {
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar madd(const Scalar& a, const Scalar& b, const Scalar& c) {
return numext::fma(a, b, c);  // a * b + c
}
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar msub(const Scalar& a, const Scalar& b, const Scalar& c) {
return numext::fma(a, b, Scalar(-c));  // a * b - c
}
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar nmadd(const Scalar& a, const Scalar& b, const Scalar& c) {
return numext::fma(Scalar(-a), b, c);  // c - a * b
}
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar nmsub(const Scalar& a, const Scalar& b, const Scalar& c) {
return -Scalar(numext::fma(a, b, c));  // -(a * b + c)
}
};
template <typename T>
inline T REF_MADD(const T& a, const T& b, const T& c) {
  // Reference a * b + c. Removed the stale `return internal::pmadd(a, b, c);`
  // that made the madd_impl dispatch unreachable (and would have compared the
  // packet op against itself).
  return madd_impl<T>::madd(a, b, c);
}
template <typename T>
inline T REF_MSUB(const T& a, const T& b, const T& c) {
  // Reference a * b - c. Removed the stale `return internal::pmsub(a, b, c);`
  // that shadowed the madd_impl dispatch.
  return madd_impl<T>::msub(a, b, c);
}
template <typename T>
inline T REF_NMADD(const T& a, const T& b, const T& c) {
  // Reference c - a * b. Removed the stale `return internal::pnmadd(a, b, c);`
  // that shadowed the madd_impl dispatch.
  return madd_impl<T>::nmadd(a, b, c);
}
template <typename T>
inline T REF_NMSUB(const T& a, const T& b, const T& c) {
  // Reference -(a * b + c). Removed the stale `return internal::pnmsub(a, b, c);`
  // that shadowed the madd_impl dispatch.
  return madd_impl<T>::nmsub(a, b, c);
}
template <typename T>
inline T REF_DIV(const T& a, const T& b) {
@ -70,6 +104,14 @@ template <>
// Boolean reference madd: multiply is AND, add is OR.
inline bool REF_MADD(const bool& a, const bool& b, const bool& c) {
return (a && b) || c;
}
// Boolean reference "division": matches pdiv<bool>, i.e. logical AND.
template <>
inline bool REF_DIV(const bool& a, const bool& b) {
  return b && a;
}
// Boolean reference reciprocal: identity on bool (1/1 == 1, 0 stays 0).
template <>
inline bool REF_RECIPROCAL(const bool& a) {
  return a ? true : false;
}
template <typename T>
inline T REF_FREXP(const T& x, T& exp) {
@ -501,8 +543,8 @@ void packetmath() {
eigen_optimization_barrier_test<Scalar>::run();
for (int i = 0; i < size; ++i) {
data1[i] = internal::random<Scalar>() / RealScalar(PacketSize);
data2[i] = internal::random<Scalar>() / RealScalar(PacketSize);
data1[i] = internal::random<Scalar>();
data2[i] = internal::random<Scalar>();
refvalue = (std::max)(refvalue, numext::abs(data1[i]));
}
@ -522,8 +564,8 @@ void packetmath() {
for (int M = 0; M < PacketSize; ++M) {
for (int N = 0; N <= PacketSize; ++N) {
for (int j = 0; j < size; ++j) {
data1[j] = internal::random<Scalar>() / RealScalar(PacketSize);
data2[j] = internal::random<Scalar>() / RealScalar(PacketSize);
data1[j] = internal::random<Scalar>();
data2[j] = internal::random<Scalar>();
refvalue = (std::max)(refvalue, numext::abs(data1[j]));
}
@ -652,11 +694,11 @@ void packetmath() {
// Avoid overflows.
if (NumTraits<Scalar>::IsInteger && NumTraits<Scalar>::IsSigned &&
Eigen::internal::unpacket_traits<Packet>::size > 1) {
Scalar limit =
static_cast<Scalar>(std::pow(static_cast<double>(numext::real(NumTraits<Scalar>::highest())),
1.0 / static_cast<double>(Eigen::internal::unpacket_traits<Packet>::size)));
Scalar limit = static_cast<Scalar>(
static_cast<RealScalar>(std::pow(static_cast<double>(numext::real(NumTraits<Scalar>::highest())),
1.0 / static_cast<double>(Eigen::internal::unpacket_traits<Packet>::size))));
for (int i = 0; i < PacketSize; ++i) {
data1[i] = internal::random<Scalar>(-limit, limit);
data1[i] = internal::random<Scalar>(Scalar(0) - limit, limit);
}
}
ref[0] = Scalar(1);
@ -1683,7 +1725,7 @@ void packetmath_scatter_gather() {
for (Index N = 0; N <= PacketSize; ++N) {
for (Index i = 0; i < N; ++i) {
data1[i] = internal::random<Scalar>() / RealScalar(PacketSize);
data1[i] = internal::random<Scalar>();
}
for (Index i = 0; i < N * 20; ++i) {
@ -1702,7 +1744,7 @@ void packetmath_scatter_gather() {
}
for (Index i = 0; i < N * 7; ++i) {
buffer[i] = internal::random<Scalar>() / RealScalar(PacketSize);
buffer[i] = internal::random<Scalar>();
}
packet = internal::pgather_partial<Scalar, Packet>(buffer, 7, N);
internal::pstore_partial(data1, packet, N);

View File

@ -162,7 +162,9 @@ struct packet_helper {
template <typename T>
inline Packet load(const T* from, unsigned long long umask) const {
  // ploadu's mask parameter is a packet-size-dependent unsigned integer type;
  // derive it and narrow the caller-supplied mask explicitly.
  // (Removed the stale `return internal::ploadu<Packet>(from, umask);` that
  // made the narrowed call below unreachable.)
  using UMaskType = typename numext::get_integer_by_size<internal::plain_enum_max(
      internal::unpacket_traits<Packet>::size / CHAR_BIT, 1)>::unsigned_type;
  return internal::ploadu<Packet>(from, static_cast<UMaskType>(umask));
}
template <typename T>
@ -172,7 +174,9 @@ struct packet_helper {
template <typename T>
inline void store(T* to, const Packet& x, unsigned long long umask) const {
  // Removed the stale `internal::pstoreu(to, x, umask);` which performed a
  // duplicate store with the un-narrowed mask before the intended one.
  using UMaskType = typename numext::get_integer_by_size<internal::plain_enum_max(
      internal::unpacket_traits<Packet>::size / CHAR_BIT, 1)>::unsigned_type;
  internal::pstoreu(to, x, static_cast<UMaskType>(umask));
}
template <typename T>