mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-03 02:03:59 +08:00
Fix x86 complex vectorized fma
This commit is contained in:
parent
464c1d0978
commit
10e62ccd22
@ -475,19 +475,11 @@ EIGEN_STRONG_INLINE Packet4cf pmsub(const Packet4cf& a, const Packet4cf& b, cons
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4cf pnmadd(const Packet4cf& a, const Packet4cf& b, const Packet4cf& c) {
|
||||
__m256 a_odd = _mm256_movehdup_ps(a.v);
|
||||
__m256 a_even = _mm256_moveldup_ps(a.v);
|
||||
__m256 b_swap = _mm256_permute_ps(b.v, _MM_SHUFFLE(2, 3, 0, 1));
|
||||
__m256 result = _mm256_fmaddsub_ps(a_odd, b_swap, _mm256_fmaddsub_ps(a_even, b.v, c.v));
|
||||
return Packet4cf(result);
|
||||
return pnegate(pmsub(a, b, c));
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4cf pnmsub(const Packet4cf& a, const Packet4cf& b, const Packet4cf& c) {
|
||||
__m256 a_odd = _mm256_movehdup_ps(a.v);
|
||||
__m256 a_even = _mm256_moveldup_ps(a.v);
|
||||
__m256 b_swap = _mm256_permute_ps(b.v, _MM_SHUFFLE(2, 3, 0, 1));
|
||||
__m256 result = _mm256_fmaddsub_ps(a_odd, b_swap, _mm256_fmsubadd_ps(a_even, b.v, c.v));
|
||||
return Packet4cf(result);
|
||||
return pnegate(pmadd(a, b, c));
|
||||
}
|
||||
// std::complex<double>
|
||||
template <>
|
||||
@ -508,19 +500,11 @@ EIGEN_STRONG_INLINE Packet2cd pmsub(const Packet2cd& a, const Packet2cd& b, cons
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet2cd pnmadd(const Packet2cd& a, const Packet2cd& b, const Packet2cd& c) {
|
||||
__m256d a_odd = _mm256_permute_pd(a.v, 0xF);
|
||||
__m256d a_even = _mm256_movedup_pd(a.v);
|
||||
__m256d b_swap = _mm256_permute_pd(b.v, 0x5);
|
||||
__m256d result = _mm256_fmaddsub_pd(a_odd, b_swap, _mm256_fmaddsub_pd(a_even, b.v, c.v));
|
||||
return Packet2cd(result);
|
||||
return pnegate(pmsub(a, b, c));
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet2cd pnmsub(const Packet2cd& a, const Packet2cd& b, const Packet2cd& c) {
|
||||
__m256d a_odd = _mm256_permute_pd(a.v, 0xF);
|
||||
__m256d a_even = _mm256_movedup_pd(a.v);
|
||||
__m256d b_swap = _mm256_permute_pd(b.v, 0x5);
|
||||
__m256d result = _mm256_fmaddsub_pd(a_odd, b_swap, _mm256_fmsubadd_pd(a_even, b.v, c.v));
|
||||
return Packet2cd(result);
|
||||
return pnegate(pmadd(a, b, c));
|
||||
}
|
||||
#endif
|
||||
} // end namespace internal
|
||||
|
@ -465,19 +465,11 @@ EIGEN_STRONG_INLINE Packet2cf pmsub(const Packet2cf& a, const Packet2cf& b, cons
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet2cf pnmadd(const Packet2cf& a, const Packet2cf& b, const Packet2cf& c) {
|
||||
__m128 a_odd = _mm_movehdup_ps(a.v);
|
||||
__m128 a_even = _mm_moveldup_ps(a.v);
|
||||
__m128 b_swap = _mm_permute_ps(b.v, _MM_SHUFFLE(2, 3, 0, 1));
|
||||
__m128 result = _mm_fmaddsub_ps(a_odd, b_swap, _mm_fmaddsub_ps(a_even, b.v, c.v));
|
||||
return Packet2cf(result);
|
||||
return pnegate(pmsub(a, b, c));
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet2cf pnmsub(const Packet2cf& a, const Packet2cf& b, const Packet2cf& c) {
|
||||
__m128 a_odd = _mm_movehdup_ps(a.v);
|
||||
__m128 a_even = _mm_moveldup_ps(a.v);
|
||||
__m128 b_swap = _mm_permute_ps(b.v, _MM_SHUFFLE(2, 3, 0, 1));
|
||||
__m128 result = _mm_fmaddsub_ps(a_odd, b_swap, _mm_fmsubadd_ps(a_even, b.v, c.v));
|
||||
return Packet2cf(result);
|
||||
return pnegate(pmadd(a, b, c));
|
||||
}
|
||||
// std::complex<double>
|
||||
template <>
|
||||
@ -498,19 +490,11 @@ EIGEN_STRONG_INLINE Packet1cd pmsub(const Packet1cd& a, const Packet1cd& b, cons
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet1cd pnmadd(const Packet1cd& a, const Packet1cd& b, const Packet1cd& c) {
|
||||
__m128d a_odd = _mm_permute_pd(a.v, 0x3);
|
||||
__m128d a_even = _mm_movedup_pd(a.v);
|
||||
__m128d b_swap = _mm_permute_pd(b.v, 0x1);
|
||||
__m128d result = _mm_fmaddsub_pd(a_odd, b_swap, _mm_fmaddsub_pd(a_even, b.v, c.v));
|
||||
return Packet1cd(result);
|
||||
return pnegate(pmsub(a, b, c));
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet1cd pnmsub(const Packet1cd& a, const Packet1cd& b, const Packet1cd& c) {
|
||||
__m128d a_odd = _mm_permute_pd(a.v, 0x3);
|
||||
__m128d a_even = _mm_movedup_pd(a.v);
|
||||
__m128d b_swap = _mm_permute_pd(b.v, 0x1);
|
||||
__m128d result = _mm_fmaddsub_pd(a_odd, b_swap, _mm_fmsubadd_pd(a_even, b.v, c.v));
|
||||
return Packet1cd(result);
|
||||
return pnegate(pmadd(a, b, c));
|
||||
}
|
||||
#endif
|
||||
} // end namespace internal
|
||||
|
@ -92,6 +92,26 @@ inline T REF_LDEXP(const T& x, const T& exp) {
|
||||
return static_cast<T>(ldexp(x, static_cast<int>(exp)));
|
||||
}
|
||||
|
||||
// provides a convenient function to take the absolute value of each component of a complex number to prevent
|
||||
// catastrophic cancellation in randomly generated complex numbers
|
||||
template <typename T, bool IsComplex = NumTraits<T>::IsComplex>
|
||||
struct abs_helper_impl {
|
||||
static T run(T x) { return numext::abs(x); }
|
||||
};
|
||||
template <typename T>
|
||||
struct abs_helper_impl<T, true> {
|
||||
static T run(T x) {
|
||||
T res = x;
|
||||
numext::real_ref(res) = numext::abs(numext::real(res));
|
||||
numext::imag_ref(res) = numext::abs(numext::imag(res));
|
||||
return res;
|
||||
}
|
||||
};
|
||||
template <typename T>
|
||||
T abs_helper(T x) {
|
||||
return abs_helper_impl<T>::run(x);
|
||||
}
|
||||
|
||||
// Uses pcast to cast from one array to another.
|
||||
template <typename SrcPacket, typename TgtPacket, int SrcCoeffRatio, int TgtCoeffRatio>
|
||||
struct pcast_array;
|
||||
@ -724,11 +744,6 @@ void packetmath() {
|
||||
packetmath_pcast_ops_runner<Scalar, Packet>::run();
|
||||
packetmath_minus_zero_add_test<Scalar, Packet>::run();
|
||||
|
||||
for (int i = 0; i < size; ++i) {
|
||||
data1[i] = numext::abs(internal::random<Scalar>());
|
||||
}
|
||||
CHECK_CWISE1_IF(PacketTraits::HasSqrt, numext::sqrt, internal::psqrt);
|
||||
CHECK_CWISE1_IF(PacketTraits::HasRsqrt, numext::rsqrt, internal::prsqrt);
|
||||
CHECK_CWISE3_IF(true, REF_MADD, internal::pmadd);
|
||||
if (!std::is_same<Scalar, bool>::value && NumTraits<Scalar>::IsSigned) {
|
||||
nmsub_test<Scalar, Packet>(data1, data2, ref, PacketSize);
|
||||
@ -738,14 +753,17 @@ void packetmath() {
|
||||
// which can lead to very flaky tests. Here we ensure the signs are such that
|
||||
// they do not cancel.
|
||||
for (int i = 0; i < PacketSize; ++i) {
|
||||
data1[i] = numext::abs(internal::random<Scalar>());
|
||||
data1[i + PacketSize] = numext::abs(internal::random<Scalar>());
|
||||
data1[i + 2 * PacketSize] = Scalar(0) - numext::abs(internal::random<Scalar>());
|
||||
data1[i] = abs_helper(internal::random<Scalar>());
|
||||
data1[i + PacketSize] = abs_helper(internal::random<Scalar>());
|
||||
data1[i + 2 * PacketSize] = Scalar(0) - abs_helper(internal::random<Scalar>());
|
||||
}
|
||||
if (!std::is_same<Scalar, bool>::value && NumTraits<Scalar>::IsSigned) {
|
||||
CHECK_CWISE3_IF(true, REF_MSUB, internal::pmsub);
|
||||
CHECK_CWISE3_IF(true, REF_NMADD, internal::pnmadd);
|
||||
}
|
||||
|
||||
CHECK_CWISE1_IF(PacketTraits::HasSqrt, numext::sqrt, internal::psqrt);
|
||||
CHECK_CWISE1_IF(PacketTraits::HasRsqrt, numext::rsqrt, internal::prsqrt);
|
||||
}
|
||||
|
||||
// Notice that this definition works for complex types as well.
|
||||
|
Loading…
x
Reference in New Issue
Block a user