Implement vectorized complex square root.

Closes #1905

Measured speedup for sqrt of `complex<float>` on Skylake:

SSE:
```
name                      old time/op             new time/op  delta
BM_eigen_sqrt_ctype/1     49.4ns ± 0%             54.3ns ± 0%  +10.01%
BM_eigen_sqrt_ctype/8      332ns ± 0%               50ns ± 1%  -84.97%
BM_eigen_sqrt_ctype/64    2.81µs ± 1%             0.38µs ± 0%  -86.49%
BM_eigen_sqrt_ctype/512   23.8µs ± 0%              3.0µs ± 0%  -87.32%
BM_eigen_sqrt_ctype/4k     202µs ± 0%               24µs ± 2%  -88.03%
BM_eigen_sqrt_ctype/32k   1.63ms ± 0%             0.19ms ± 0%  -88.18%
BM_eigen_sqrt_ctype/256k  13.0ms ± 0%              1.5ms ± 1%  -88.20%
BM_eigen_sqrt_ctype/1M    52.1ms ± 0%              6.2ms ± 0%  -88.18%
```

AVX2:
```
name                      old cpu/op  new cpu/op  delta
BM_eigen_sqrt_ctype/1     53.6ns ± 0%  55.6ns ± 0%   +3.71%
BM_eigen_sqrt_ctype/8      334ns ± 0%    27ns ± 0%  -91.86%
BM_eigen_sqrt_ctype/64    2.79µs ± 0%  0.22µs ± 2%  -92.28%
BM_eigen_sqrt_ctype/512   23.8µs ± 1%   1.7µs ± 1%  -92.81%
BM_eigen_sqrt_ctype/4k     201µs ± 0%    14µs ± 1%  -93.24%
BM_eigen_sqrt_ctype/32k   1.62ms ± 0%  0.11ms ± 1%  -93.29%
BM_eigen_sqrt_ctype/256k  13.0ms ± 0%   0.9ms ± 1%  -93.31%
BM_eigen_sqrt_ctype/1M    52.0ms ± 0%   3.5ms ± 1%  -93.31%
```

AVX512:
```
name                      old cpu/op  new cpu/op  delta
BM_eigen_sqrt_ctype/1     53.7ns ± 0%  56.2ns ± 1%   +4.75%
BM_eigen_sqrt_ctype/8      334ns ± 0%    18ns ± 2%  -94.63%
BM_eigen_sqrt_ctype/64    2.79µs ± 0%  0.12µs ± 1%  -95.54%
BM_eigen_sqrt_ctype/512   23.9µs ± 1%   1.0µs ± 1%  -95.89%
BM_eigen_sqrt_ctype/4k     202µs ± 0%     8µs ± 1%  -96.13%
BM_eigen_sqrt_ctype/32k   1.63ms ± 0%  0.06ms ± 1%  -96.15%
BM_eigen_sqrt_ctype/256k  13.0ms ± 0%   0.5ms ± 4%  -96.11%
BM_eigen_sqrt_ctype/1M    52.1ms ± 0%   2.0ms ± 1%  -96.13%
```
This commit is contained in:
Rasmus Munk Larsen 2020-12-08 18:13:35 -08:00
parent 8cfe0db108
commit 125cc9a5df
10 changed files with 290 additions and 11 deletions

View File

@ -539,6 +539,20 @@ inline void pbroadcast2(const typename unpacket_traits<Packet>::type *a,
template<typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet
plset(const typename unpacket_traits<Packet>::type& a) { return a; }
/** \internal \returns a packet with constant coefficients \a a, e.g.: (x, 0, x, 0),
where x is the value of all 1-bits. */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
peven_mask(const Packet& /*a*/) {
typedef typename unpacket_traits<Packet>::type Scalar;
const size_t n = unpacket_traits<Packet>::size;
Scalar elements[n];
for(size_t i = 0; i < n; ++i) {
memset(elements+i, ((i & 1) == 0 ? 0xff : 0), sizeof(Scalar));
}
return ploadu<Packet>(elements);
}
/** \internal copy the packet \a from to \a *to, \a to must be 16 bytes aligned */
template<typename Scalar, typename Packet> EIGEN_DEVICE_FUNC inline void pstore(Scalar* to, const Packet& from)
{ (*to) = from; }

View File

@ -38,6 +38,7 @@ template<> struct packet_traits<std::complex<float> > : default_packet_traits
HasMul = 1,
HasDiv = 1,
HasNegate = 1,
HasSqrt = 1,
HasAbs = 0,
HasAbs2 = 0,
HasMin = 0,
@ -47,7 +48,18 @@ template<> struct packet_traits<std::complex<float> > : default_packet_traits
};
#endif
template<> struct unpacket_traits<Packet4cf> { typedef std::complex<float> type; enum {size=4, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet2cf half; };
template<> struct unpacket_traits<Packet4cf> {
typedef std::complex<float> type;
typedef Packet2cf half;
typedef Packet8f as_real;
enum {
size=4,
alignment=Aligned32,
vectorizable=true,
masked_load_available=false,
masked_store_available=false
};
};
template<> EIGEN_STRONG_INLINE Packet4cf padd<Packet4cf>(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_add_ps(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet4cf psub<Packet4cf>(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_sub_ps(a.v,b.v)); }
@ -228,6 +240,7 @@ template<> struct packet_traits<std::complex<double> > : default_packet_traits
HasMul = 1,
HasDiv = 1,
HasNegate = 1,
HasSqrt = 1,
HasAbs = 0,
HasAbs2 = 0,
HasMin = 0,
@ -237,7 +250,18 @@ template<> struct packet_traits<std::complex<double> > : default_packet_traits
};
#endif
template<> struct unpacket_traits<Packet2cd> { typedef std::complex<double> type; enum {size=2, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet1cd half; };
template<> struct unpacket_traits<Packet2cd> {
typedef std::complex<double> type;
typedef Packet1cd half;
typedef Packet4d as_real;
enum {
size=2,
alignment=Aligned32,
vectorizable=true,
masked_load_available=false,
masked_store_available=false
};
};
template<> EIGEN_STRONG_INLINE Packet2cd padd<Packet2cd>(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_add_pd(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet2cd psub<Packet2cd>(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_sub_pd(a.v,b.v)); }
@ -399,6 +423,14 @@ ptranspose(PacketBlock<Packet2cd,2>& kernel) {
kernel.packet[0].v = tmp;
}
template<> EIGEN_STRONG_INLINE Packet2cd psqrt<Packet2cd>(const Packet2cd& a) {
return psqrt_complex<Packet2cd>(a);
}
template<> EIGEN_STRONG_INLINE Packet4cf psqrt<Packet4cf>(const Packet4cf& a) {
return psqrt_complex<Packet4cf>(a);
}
} // end namespace internal
} // end namespace Eigen

View File

@ -248,6 +248,11 @@ template<> EIGEN_STRONG_INLINE Packet8f pzero(const Packet8f& /*a*/) { return _m
template<> EIGEN_STRONG_INLINE Packet4d pzero(const Packet4d& /*a*/) { return _mm256_setzero_pd(); }
template<> EIGEN_STRONG_INLINE Packet8i pzero(const Packet8i& /*a*/) { return _mm256_setzero_si256(); }
template<> EIGEN_STRONG_INLINE Packet8f peven_mask(const Packet8f& /*a*/) { return Packet8f(_mm256_set_epi32(0, -1, 0, -1, 0, -1, 0, -1)); }
template<> EIGEN_STRONG_INLINE Packet8i peven_mask(const Packet8i& /*a*/) { return Packet8i(_mm256_set_epi32(0, -1, 0, -1, 0, -1, 0, -1)); }
template<> EIGEN_STRONG_INLINE Packet4d peven_mask(const Packet4d& /*a*/) { return Packet4d(_mm256_set_epi32(0, 0, -1, -1, 0, 0, -1, -1)); }
template<> EIGEN_STRONG_INLINE Packet8f pload1<Packet8f>(const float* from) { return _mm256_broadcast_ss(from); }
template<> EIGEN_STRONG_INLINE Packet4d pload1<Packet4d>(const double* from) { return _mm256_broadcast_sd(from); }

View File

@ -37,6 +37,7 @@ template<> struct packet_traits<std::complex<float> > : default_packet_traits
HasMul = 1,
HasDiv = 1,
HasNegate = 1,
HasSqrt = 1,
HasAbs = 0,
HasAbs2 = 0,
HasMin = 0,
@ -47,6 +48,8 @@ template<> struct packet_traits<std::complex<float> > : default_packet_traits
template<> struct unpacket_traits<Packet8cf> {
typedef std::complex<float> type;
typedef Packet4cf half;
typedef Packet16f as_real;
enum {
size = 8,
alignment=unpacket_traits<Packet16f>::alignment,
@ -54,7 +57,6 @@ template<> struct unpacket_traits<Packet8cf> {
masked_load_available=false,
masked_store_available=false
};
typedef Packet4cf half;
};
template<> EIGEN_STRONG_INLINE Packet8cf ptrue<Packet8cf>(const Packet8cf& a) { return Packet8cf(ptrue(Packet16f(a.v))); }
@ -223,6 +225,7 @@ template<> struct packet_traits<std::complex<double> > : default_packet_traits
HasMul = 1,
HasDiv = 1,
HasNegate = 1,
HasSqrt = 1,
HasAbs = 0,
HasAbs2 = 0,
HasMin = 0,
@ -233,6 +236,8 @@ template<> struct packet_traits<std::complex<double> > : default_packet_traits
template<> struct unpacket_traits<Packet4cd> {
typedef std::complex<double> type;
typedef Packet2cd half;
typedef Packet8d as_real;
enum {
size = 4,
alignment = unpacket_traits<Packet8d>::alignment,
@ -240,7 +245,6 @@ template<> struct unpacket_traits<Packet4cd> {
masked_load_available=false,
masked_store_available=false
};
typedef Packet2cd half;
};
template<> EIGEN_STRONG_INLINE Packet4cd padd<Packet4cd>(const Packet4cd& a, const Packet4cd& b) { return Packet4cd(_mm512_add_pd(a.v,b.v)); }
@ -437,8 +441,15 @@ ptranspose(PacketBlock<Packet4cd,4>& kernel) {
kernel.packet[0] = Packet4cd(_mm512_shuffle_f64x2(T0, T2, (shuffle_mask<0,2,0,2>::mask))); // [a0 b0 c0 d0]
}
} // end namespace internal
template<> EIGEN_STRONG_INLINE Packet4cd psqrt<Packet4cd>(const Packet4cd& a) {
return psqrt_complex<Packet4cd>(a);
}
template<> EIGEN_STRONG_INLINE Packet8cf psqrt<Packet8cf>(const Packet8cf& a) {
return psqrt_complex<Packet8cf>(a);
}
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_COMPLEX_AVX512_H

View File

@ -219,6 +219,19 @@ template<> EIGEN_STRONG_INLINE Packet16f pzero(const Packet16f& /*a*/) { return
template<> EIGEN_STRONG_INLINE Packet8d pzero(const Packet8d& /*a*/) { return _mm512_setzero_pd(); }
template<> EIGEN_STRONG_INLINE Packet16i pzero(const Packet16i& /*a*/) { return _mm512_setzero_si512(); }
template<> EIGEN_STRONG_INLINE Packet16f peven_mask(const Packet16f& /*a*/) {
return Packet16f(_mm512_set_epi32(0, -1, 0, -1, 0, -1, 0, -1,
0, -1, 0, -1, 0, -1, 0, -1));
}
template<> EIGEN_STRONG_INLINE Packet16i peven_mask(const Packet16i& /*a*/) {
return Packet16i(_mm512_set_epi32(0, -1, 0, -1, 0, -1, 0, -1,
0, -1, 0, -1, 0, -1, 0, -1));
}
template<> EIGEN_STRONG_INLINE Packet8d peven_mask(const Packet8d& /*a*/) {
return Packet8d(_mm512_set_epi32(0, 0, -1, -1, 0, 0, -1, -1,
0, 0, -1, -1, 0, 0, -1, -1));
}
template <>
EIGEN_STRONG_INLINE Packet16f pload1<Packet16f>(const float* from) {
return _mm512_broadcastss_ps(_mm_load_ps1(from));

View File

@ -673,6 +673,120 @@ Packet pcos_float(const Packet& x)
return psincos_float<false>(x);
}
template<typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet psqrt_complex(const Packet& a) {
typedef typename unpacket_traits<Packet>::type Scalar;
typedef typename Scalar::value_type RealScalar;
typedef typename unpacket_traits<Packet>::as_real RealPacket;
// Computes the principal sqrt of the complex numbers in the input.
//
// For example, for packets containing 2 complex numbers stored in interleaved format
// a = [a0, a1] = [x0, y0, x1, y1],
// where x0 = real(a0), y0 = imag(a0) etc., this function returns
// b = [b0, b1] = [u0, v0, u1, v1],
// such that b0^2 = a0, b1^2 = a1.
//
// To derive the formula for the complex square roots, let's consider the equation for
// a single complex square root of the number x + i*y. We want to find real numbers
// u and v such that
// (u + i*v)^2 = x + i*y <=>
// u^2 - v^2 + i*2*u*v = x + i*v.
// By equating the real and imaginary parts we get:
// u^2 - v^2 = x
// 2*u*v = y.
//
// For x >= 0, this has the numerically stable solution
// u = sqrt(0.5 * (x + sqrt(x^2 + y^2)))
// v = 0.5 * (y / u)
// and for x < 0,
// v = sign(y) * sqrt(0.5 * (x + sqrt(x^2 + y^2)))
// u = |0.5 * (y / v)|
//
// To avoid unnecessary over- and underflow, we compute sqrt(x^2 + y^2) as
// l = max(|x|, |y|) * sqrt(1 + (min(|x|, |y|) / max(|x|, |y|))^2) ,
// In the following, without lack of generality, we have annotated the code, assuming
// that the input is a packet of 2 complex numbers.
//
// Step 1. Compute l = [l0, l0, l1, l1], where
// l0 = sqrt(x0^2 + y0^2), l1 = sqrt(x1^2 + y1^2)
// To avoid over- and underflow, we use the stable formula for each hypotenuse
// l0 = (min0 == 0 ? max0 : max0 * sqrt(1 + (min0/max0)**2)),
// where max0 = max(|x0|, |y0|), min0 = min(|x0|, |y0|), and similarly for l1.
Packet a_flip = pcplxflip(a);
RealPacket a_abs = pabs(a.v); // [|x0|, |y0|, |x1|, |y1|]
RealPacket a_abs_flip = pabs(a_flip.v); // [|y0|, |x0|, |y1|, |x1|]
RealPacket a_max = pmax(a_abs, a_abs_flip);
RealPacket a_min = pmin(a_abs, a_abs_flip);
RealPacket a_min_zero_mask = pcmp_eq(a_min, pzero(a_min));
RealPacket a_max_zero_mask = pcmp_eq(a_max, pzero(a_max));
RealPacket r = pdiv(a_min, a_max);
const RealPacket cst_one = pset1<RealPacket>(RealScalar(1));
RealPacket l = pmul(a_max, psqrt(padd(cst_one, pmul(r, r)))); // [l0, l0, l1, l1]
// Set l to a_max if a_min is zero.
l = pselect(a_min_zero_mask, a_max, l);
// Step 2. Compute [rho0, *, rho1, *], where
// rho0 = sqrt(0.5 * (l0 + |x0|)), rho1 = sqrt(0.5 * (l1 + |x1|))
// We don't care about the imaginary parts computed here. They will be overwritten later.
const RealPacket cst_half = pset1<RealPacket>(RealScalar(0.5));
Packet rho;
rho.v = psqrt(pmul(cst_half, padd(a_abs, l)));
// Step 3. Compute [rho0, eta0, rho1, eta1], where
// eta0 = (y0 / l0) / 2, and eta1 = (y1 / l1) / 2.
// set eta = 0 of input is 0 + i0.
RealPacket eta = pandnot(pmul(cst_half, pdiv(a.v, pcplxflip(rho).v)), a_max_zero_mask);
RealPacket real_mask = peven_mask(a.v);
Packet positive_real_result;
// Compute result for inputs with positive real part.
positive_real_result.v = pselect(real_mask, rho.v, eta);
// Step 4. Compute solution for inputs with negative real part:
// [|eta0|, sign(y0)*rho0, |eta1|, sign(y1)*rho1]
const RealPacket cst_imag_sign_mask = pset1<Packet>(Scalar(RealScalar(0.0), RealScalar(-0.0))).v;
RealPacket imag_signs = pand(a.v, cst_imag_sign_mask);
Packet negative_real_result;
// Notice that rho is positive, so taking it's absolute value is a noop.
negative_real_result.v = por(pabs(pcplxflip(positive_real_result).v), imag_signs);
// Step 5. Select solution branch based on the sign of the real parts.
Packet negative_real_mask;
negative_real_mask.v = pcmp_lt(pand(real_mask, a.v), pzero(a.v));
negative_real_mask.v = por(negative_real_mask.v, pcplxflip(negative_real_mask).v);
Packet result = pselect(negative_real_mask, negative_real_result, positive_real_result);
// Step 6. Handle special cases for infinities:
// * If z is (x,+∞), the result is (+∞,+∞) even if x is NaN
// * If z is (x,-∞), the result is (+∞,-∞) even if x is NaN
// * If z is (-∞,y), the result is (0*|y|,+∞) for finite or NaN y
// * If z is (+∞,y), the result is (+∞,0*|y|) for finite or NaN y
const RealPacket cst_pos_inf = pset1<RealPacket>(NumTraits<RealScalar>::infinity());
Packet is_inf;
is_inf.v = pcmp_eq(a_abs, cst_pos_inf);
Packet is_real_inf;
is_real_inf.v = pand(is_inf.v, real_mask);
is_real_inf = por(is_real_inf, pcplxflip(is_real_inf));
// prepare packet of (+∞,0*|y|) or (0*|y|,+∞), depending on the sign of the infinite real part.
Packet real_inf_result;
real_inf_result.v = pmul(a_abs, pset1<Packet>(Scalar(RealScalar(1.0), RealScalar(0.0))).v);
real_inf_result.v = pselect(negative_real_mask.v, pcplxflip(real_inf_result).v, real_inf_result.v);
// prepare packet of (+∞,+∞) or (+∞,-∞), depending on the sign of the infinite imaginary part.
Packet is_imag_inf;
is_imag_inf.v = pandnot(is_inf.v, real_mask);
is_imag_inf = por(is_imag_inf, pcplxflip(is_imag_inf));
Packet imag_inf_result;
imag_inf_result.v = por(pand(cst_pos_inf, real_mask), pandnot(a.v, real_mask));
return pselect(is_imag_inf, imag_inf_result,
pselect(is_real_inf, real_inf_result,result));
}
/* polevl (modified for Eigen)
*
* Evaluate polynomial

View File

@ -82,8 +82,15 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet pcos_float(const Packet& x);
/** \internal \returns sqrt(x) for complex types */
template<typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet psqrt_complex(const Packet& a);
template <typename Packet, int N> struct ppolevl;
} // end namespace internal
} // end namespace Eigen

View File

@ -40,6 +40,7 @@ template<> struct packet_traits<std::complex<float> > : default_packet_traits
HasMul = 1,
HasDiv = 1,
HasNegate = 1,
HasSqrt = 1,
HasAbs = 0,
HasAbs2 = 0,
HasMin = 0,
@ -50,7 +51,18 @@ template<> struct packet_traits<std::complex<float> > : default_packet_traits
};
#endif
template<> struct unpacket_traits<Packet2cf> { typedef std::complex<float> type; enum {size=2, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet2cf half; };
template<> struct unpacket_traits<Packet2cf> {
typedef std::complex<float> type;
typedef Packet2cf half;
typedef Packet4f as_real;
enum {
size=2,
alignment=Aligned16,
vectorizable=true,
masked_load_available=false,
masked_store_available=false
};
};
template<> EIGEN_STRONG_INLINE Packet2cf padd<Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_add_ps(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet2cf psub<Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_sub_ps(a.v,b.v)); }
@ -83,7 +95,6 @@ template<> EIGEN_STRONG_INLINE Packet2cf pmul<Packet2cf>(const Packet2cf& a, con
}
template<> EIGEN_STRONG_INLINE Packet2cf ptrue <Packet2cf>(const Packet2cf& a) { return Packet2cf(ptrue(Packet4f(a.v))); }
template<> EIGEN_STRONG_INLINE Packet2cf pand <Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_and_ps(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet2cf por <Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_or_ps(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet2cf pxor <Packet2cf>(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_xor_ps(a.v,b.v)); }
@ -255,6 +266,7 @@ template<> struct packet_traits<std::complex<double> > : default_packet_traits
HasMul = 1,
HasDiv = 1,
HasNegate = 1,
HasSqrt = 1,
HasAbs = 0,
HasAbs2 = 0,
HasMin = 0,
@ -264,7 +276,18 @@ template<> struct packet_traits<std::complex<double> > : default_packet_traits
};
#endif
template<> struct unpacket_traits<Packet1cd> { typedef std::complex<double> type; enum {size=1, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet1cd half; };
template<> struct unpacket_traits<Packet1cd> {
typedef std::complex<double> type;
typedef Packet1cd half;
typedef Packet2d as_real;
enum {
size=1,
alignment=Aligned16,
vectorizable=true,
masked_load_available=false,
masked_store_available=false
};
};
template<> EIGEN_STRONG_INLINE Packet1cd padd<Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(_mm_add_pd(a.v,b.v)); }
template<> EIGEN_STRONG_INLINE Packet1cd psub<Packet1cd>(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(_mm_sub_pd(a.v,b.v)); }
@ -426,8 +449,15 @@ template<> EIGEN_STRONG_INLINE Packet2cf pblend(const Selector<2>& ifPacket, co
return Packet2cf(_mm_castpd_ps(result));
}
} // end namespace internal
template<> EIGEN_STRONG_INLINE Packet1cd psqrt<Packet1cd>(const Packet1cd& a) {
return psqrt_complex<Packet1cd>(a);
}
template<> EIGEN_STRONG_INLINE Packet2cf psqrt<Packet2cf>(const Packet2cf& a) {
return psqrt_complex<Packet2cf>(a);
}
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_COMPLEX_SSE_H

View File

@ -267,6 +267,10 @@ template<> EIGEN_STRONG_INLINE Packet16b pset1<Packet16b>(const bool& from) {
template<> EIGEN_STRONG_INLINE Packet4f pset1frombits<Packet4f>(unsigned int from) { return _mm_castsi128_ps(pset1<Packet4i>(from)); }
template<> EIGEN_STRONG_INLINE Packet2d pset1frombits<Packet2d>(uint64_t from) { return _mm_castsi128_pd(_mm_set1_epi64x(from)); }
template<> EIGEN_STRONG_INLINE Packet4f peven_mask(const Packet4f& /*a*/) { return Packet4f(_mm_set_epi32(0, -1, 0, -1)); }
template<> EIGEN_STRONG_INLINE Packet4i peven_mask(const Packet4i& /*a*/) { return Packet4i(_mm_set_epi32(0, -1, 0, -1)); }
template<> EIGEN_STRONG_INLINE Packet2d peven_mask(const Packet2d& /*a*/) { return Packet2d(_mm_set_epi32(0, 0, -1, -1)); }
template<> EIGEN_STRONG_INLINE Packet4f pzero(const Packet4f& /*a*/) { return _mm_setzero_ps(); }
template<> EIGEN_STRONG_INLINE Packet2d pzero(const Packet2d& /*a*/) { return _mm_setzero_pd(); }
template<> EIGEN_STRONG_INLINE Packet4i pzero(const Packet4i& /*a*/) { return _mm_setzero_si128(); }

View File

@ -473,8 +473,6 @@ void packetmath() {
CHECK_CWISE3_IF(true, internal::pselect, internal::pselect);
}
CHECK_CWISE1_IF(PacketTraits::HasSqrt, numext::sqrt, internal::psqrt);
for (int i = 0; i < size; ++i) {
data1[i] = internal::random<Scalar>();
}
@ -486,6 +484,11 @@ void packetmath() {
packetmath_boolean_mask_ops<Scalar, Packet>();
packetmath_pcast_ops_runner<Scalar, Packet>::run();
packetmath_minus_zero_add<Scalar, Packet>();
for (int i = 0; i < size; ++i) {
data1[i] = numext::abs(internal::random<Scalar>());
}
CHECK_CWISE1_IF(PacketTraits::HasSqrt, numext::sqrt, internal::psqrt);
}
// Notice that this definition works for complex types as well.
@ -899,6 +902,8 @@ void test_conj_helper(Scalar* data1, Scalar* data2, Scalar* ref, Scalar* pval) {
template <typename Scalar, typename Packet>
void packetmath_complex() {
typedef internal::packet_traits<Scalar> PacketTraits;
typedef typename Scalar::value_type RealScalar;
const int PacketSize = internal::unpacket_traits<Packet>::size;
const int size = PacketSize * 4;
@ -917,11 +922,55 @@ void packetmath_complex() {
test_conj_helper<Scalar, Packet, true, false>(data1, data2, ref, pval);
test_conj_helper<Scalar, Packet, true, true>(data1, data2, ref, pval);
// Test pcplxflip.
{
for (int i = 0; i < PacketSize; ++i) ref[i] = Scalar(std::imag(data1[i]), std::real(data1[i]));
internal::pstore(pval, internal::pcplxflip(internal::pload<Packet>(data1)));
VERIFY(test::areApprox(ref, pval, PacketSize) && "pcplxflip");
}
if (PacketTraits::HasSqrt) {
for (int i = 0; i < size; ++i) {
data1[i] = Scalar(internal::random<RealScalar>(), internal::random<RealScalar>());
}
CHECK_CWISE1(numext::sqrt, internal::psqrt);
// Test misc. corner cases.
const RealScalar zero = RealScalar(0);
const RealScalar one = RealScalar(1);
const RealScalar inf = std::numeric_limits<RealScalar>::infinity();
const RealScalar nan = std::numeric_limits<RealScalar>::quiet_NaN();
data1[0] = Scalar(zero, zero);
data1[1] = Scalar(-zero, zero);
data1[2] = Scalar(one, zero);
data1[3] = Scalar(zero, one);
CHECK_CWISE1(numext::sqrt, internal::psqrt);
data1[0] = Scalar(-one, zero);
data1[1] = Scalar(zero, -one);
data1[2] = Scalar(one, one);
data1[3] = Scalar(-one, -one);
CHECK_CWISE1(numext::sqrt, internal::psqrt);
data1[0] = Scalar(inf, zero);
data1[1] = Scalar(zero, inf);
data1[2] = Scalar(-inf, zero);
data1[3] = Scalar(zero, -inf);
CHECK_CWISE1(numext::sqrt, internal::psqrt);
data1[0] = Scalar(inf, inf);
data1[1] = Scalar(-inf, inf);
data1[2] = Scalar(inf, -inf);
data1[3] = Scalar(-inf, -inf);
CHECK_CWISE1(numext::sqrt, internal::psqrt);
data1[0] = Scalar(nan, zero);
data1[1] = Scalar(zero, nan);
data1[2] = Scalar(nan, one);
data1[3] = Scalar(one, nan);
CHECK_CWISE1(numext::sqrt, internal::psqrt);
data1[0] = Scalar(nan, nan);
data1[1] = Scalar(inf, nan);
data1[2] = Scalar(nan, inf);
data1[3] = Scalar(-inf, nan);
CHECK_CWISE1(numext::sqrt, internal::psqrt);
}
}
template <typename Scalar, typename Packet>