Eigen pblend

This commit is contained in:
Charles Schlosser 2024-04-15 16:19:53 +00:00 committed by Rasmus Munk Larsen
parent 9099c5eac7
commit 6ad2ccea4e
4 changed files with 24 additions and 59 deletions

View File

@ -1401,12 +1401,6 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet, 1>& /*kernel*/) {
// Fixed-size set of N boolean flags, one per packet lane, used to drive
// per-lane selection (see the pblend specialisations).
template <size_t N>
struct Selector {
  // select[i] is the flag for lane i.
  bool select[N];

  // Packs the flags in the half-open lane range [begin, end) into an integer:
  // bit `lane` of the result is set when select[lane] is true. Bits outside
  // the range are left at zero. With the default arguments every flag is packed.
  template <typename MaskType = int>
  EIGEN_DEVICE_FUNC inline MaskType mask(size_t begin = 0, size_t end = N) const {
    MaskType bits = 0;
    size_t lane = begin;
    while (lane < end) {
      bits |= (static_cast<MaskType>(select[lane]) << lane);
      ++lane;
    }
    return bits;
  }
};
template <typename Packet>

View File

@ -2133,37 +2133,20 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet4d, 4>& kernel) {
// Per-lane blend of two Packet8f values: lane i of the result comes from
// thenPacket when ifPacket.select[i] is true and from elsePacket otherwise.
//
// Only the pselect-based implementation is kept. The blendv-based variant that
// used to precede it returned from both preprocessor branches, which left the
// pselect code unreachable.
template <>
EIGEN_STRONG_INLINE Packet8f pblend(const Selector<8>& ifPacket, const Packet8f& thenPacket,
                                    const Packet8f& elsePacket) {
  // Negating a flag (0 or 1) as a 32-bit int yields 0 or -1, i.e. an all-zeros
  // or all-ones lane. The negation is done on scalars because the vector
  // subtraction _mm256_sub_epi32 needs AVX2, while _mm256_set_epi32 only needs AVX.
  const __m256i true_mask =
      _mm256_set_epi32(-static_cast<int>(ifPacket.select[7]), -static_cast<int>(ifPacket.select[6]),
                       -static_cast<int>(ifPacket.select[5]), -static_cast<int>(ifPacket.select[4]),
                       -static_cast<int>(ifPacket.select[3]), -static_cast<int>(ifPacket.select[2]),
                       -static_cast<int>(ifPacket.select[1]), -static_cast<int>(ifPacket.select[0]));
  // NOTE(review): relies on pselect(mask, a, b) taking `a` in lanes where the mask is all ones.
  return pselect<Packet8f>(_mm256_castsi256_ps(true_mask), thenPacket, elsePacket);
}
// Per-lane blend of two Packet4d values: lane i of the result comes from
// thenPacket when ifPacket.select[i] is true and from elsePacket otherwise.
//
// Only the pselect-based implementation is kept. The blendv-based variant that
// used to precede it returned from both preprocessor branches, which left the
// pselect code unreachable.
template <>
EIGEN_STRONG_INLINE Packet4d pblend(const Selector<4>& ifPacket, const Packet4d& thenPacket,
                                    const Packet4d& elsePacket) {
  // Negating a flag (0 or 1) as a 64-bit integer yields 0 or -1, i.e. an
  // all-zeros or all-ones lane. The negation is done on scalars because the
  // vector subtraction _mm256_sub_epi64 needs AVX2, while _mm256_set_epi64x only needs AVX.
  const __m256i true_mask =
      _mm256_set_epi64x(-static_cast<long long>(ifPacket.select[3]), -static_cast<long long>(ifPacket.select[2]),
                        -static_cast<long long>(ifPacket.select[1]), -static_cast<long long>(ifPacket.select[0]));
  // NOTE(review): relies on pselect(mask, a, b) taking `a` in lanes where the mask is all ones.
  return pselect<Packet4d>(_mm256_castsi256_pd(true_mask), thenPacket, elsePacket);
}
// Packet math for Eigen::half

View File

@ -2148,16 +2148,24 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16i, 4>& kernel) {
PACK_OUTPUT_I32_2(kernel.packet, tmp.packet, 3, 1);
}
// Packs the N flags of a Selector into an integer bit mask: bit i of the
// result is set when ifPacket.select[i] is true; bits N..15 are zero.
// At most 16 flags are supported because they are staged in one 128-bit register.
template <size_t N>
EIGEN_STRONG_INLINE int avx512_blend_mask(const Selector<N>& ifPacket) {
  static_assert(N <= sizeof(__m128i), "avx512_blend_mask handles at most 16 lanes");
  // Zero-initialised so the bytes past N that the 128-bit load reads are defined
  // (and contribute zero bits) instead of being indeterminate.
  alignas(__m128i) uint8_t aux[sizeof(__m128i)] = {};
  for (size_t i = 0; i < N; i++) aux[i] = static_cast<uint8_t>(ifPacket.select[i]);
  // 0 - byte maps 1 -> 0xFF and 0 -> 0x00, so the byte's top bit is set only for true flags.
  const __m128i paux = _mm_sub_epi8(_mm_setzero_si128(), _mm_load_si128(reinterpret_cast<const __m128i*>(aux)));
  // Collects the top bit of each of the 16 bytes into bits 0..15 of the result.
  return _mm_movemask_epi8(paux);
}
// Per-lane blend of two Packet16f values: lane i of the result comes from
// thenPacket when ifPacket.select[i] is true and from elsePacket otherwise.
template <>
EIGEN_STRONG_INLINE Packet16f pblend(const Selector<16>& ifPacket, const Packet16f& thenPacket,
                                     const Packet16f& elsePacket) {
  // A single declaration of the mask: the earlier Selector::mask-based line
  // declared `m` a second time in the same scope.
  const __mmask16 m = static_cast<__mmask16>(avx512_blend_mask(ifPacket));
  // _mm512_mask_blend_ps takes the second source where the mask bit is set,
  // hence elsePacket is passed first.
  return _mm512_mask_blend_ps(m, elsePacket, thenPacket);
}
// Per-lane blend of two Packet8d values: lane i of the result comes from
// thenPacket when ifPacket.select[i] is true and from elsePacket otherwise.
template <>
EIGEN_STRONG_INLINE Packet8d pblend(const Selector<8>& ifPacket, const Packet8d& thenPacket,
                                    const Packet8d& elsePacket) {
  // A single declaration of the mask: the earlier Selector::mask-based line
  // declared `m` a second time in the same scope. The cast keeps only the low
  // 8 bits of the helper's int result, one per lane.
  const __mmask8 m = static_cast<__mmask8>(avx512_blend_mask(ifPacket));
  // _mm512_mask_blend_pd takes the second source where the mask bit is set,
  // hence elsePacket is passed first.
  return _mm512_mask_blend_pd(m, elsePacket, thenPacket);
}

View File

@ -2233,26 +2233,16 @@ EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16b, 16>& kernel) {
// Per-lane blend of two Packet2l values: lane i of the result comes from
// thenPacket when ifPacket.select[i] is true and from elsePacket otherwise.
//
// Only the pselect-based implementation is kept. The compare-and-blend variant
// that used to precede it returned from both preprocessor branches, which left
// the pselect code unreachable.
template <>
EIGEN_STRONG_INLINE Packet2l pblend(const Selector<2>& ifPacket, const Packet2l& thenPacket,
                                    const Packet2l& elsePacket) {
  // Each 64-bit lane holds the flag as 0 or 1.
  const __m128i select = _mm_set_epi64x(ifPacket.select[1], ifPacket.select[0]);
  // 0 - flag turns 1 into an all-ones lane and leaves 0 as all zeros.
  const __m128i true_mask = _mm_sub_epi64(_mm_setzero_si128(), select);
  // NOTE(review): relies on pselect(mask, a, b) taking `a` in lanes where the mask is all ones.
  return pselect<Packet2l>(true_mask, thenPacket, elsePacket);
}
// Per-lane blend of two Packet4i values: lane i of the result comes from
// thenPacket when ifPacket.select[i] is true and from elsePacket otherwise.
//
// Only the pselect-based implementation is kept. The compare-and-blend variant
// that used to precede it returned from both preprocessor branches, which left
// the pselect code unreachable.
template <>
EIGEN_STRONG_INLINE Packet4i pblend(const Selector<4>& ifPacket, const Packet4i& thenPacket,
                                    const Packet4i& elsePacket) {
  // Each 32-bit lane holds the flag as 0 or 1.
  const __m128i select = _mm_set_epi32(ifPacket.select[3], ifPacket.select[2], ifPacket.select[1], ifPacket.select[0]);
  // 0 - flag turns 1 into an all-ones lane and leaves 0 as all zeros.
  const __m128i true_mask = _mm_sub_epi32(_mm_setzero_si128(), select);
  // NOTE(review): relies on pselect(mask, a, b) taking `a` in lanes where the mask is all ones.
  return pselect<Packet4i>(true_mask, thenPacket, elsePacket);
}
template <>
EIGEN_STRONG_INLINE Packet4ui pblend(const Selector<4>& ifPacket, const Packet4ui& thenPacket,
@ -2262,26 +2252,16 @@ EIGEN_STRONG_INLINE Packet4ui pblend(const Selector<4>& ifPacket, const Packet4u
// Per-lane blend of two Packet4f values: lane i of the result comes from
// thenPacket when ifPacket.select[i] is true and from elsePacket otherwise.
//
// Only the pselect-based implementation is kept. The float-compare variant that
// used to precede it declared its own `select` (clashing with the integer one
// below) and returned from both preprocessor branches.
template <>
EIGEN_STRONG_INLINE Packet4f pblend(const Selector<4>& ifPacket, const Packet4f& thenPacket,
                                    const Packet4f& elsePacket) {
  // Each 32-bit integer lane holds the flag as 0 or 1.
  const __m128i select = _mm_set_epi32(ifPacket.select[3], ifPacket.select[2], ifPacket.select[1], ifPacket.select[0]);
  // 0 - flag turns 1 into an all-ones lane and leaves 0 as all zeros.
  const __m128i true_mask = _mm_sub_epi32(_mm_setzero_si128(), select);
  // The integer mask is reinterpreted (bits unchanged) as a float vector for pselect.
  // NOTE(review): relies on pselect(mask, a, b) taking `a` in lanes where the mask is all ones.
  return pselect<Packet4f>(_mm_castsi128_ps(true_mask), thenPacket, elsePacket);
}
// Per-lane blend of two Packet2d values: lane i of the result comes from
// thenPacket when ifPacket.select[i] is true and from elsePacket otherwise.
//
// Only the pselect-based implementation is kept. The double-compare variant
// that used to precede it declared its own `select` (clashing with the integer
// one below) and returned from both preprocessor branches.
template <>
EIGEN_STRONG_INLINE Packet2d pblend(const Selector<2>& ifPacket, const Packet2d& thenPacket,
                                    const Packet2d& elsePacket) {
  // Each 64-bit integer lane holds the flag as 0 or 1.
  const __m128i select = _mm_set_epi64x(ifPacket.select[1], ifPacket.select[0]);
  // 0 - flag turns 1 into an all-ones lane and leaves 0 as all zeros.
  const __m128i true_mask = _mm_sub_epi64(_mm_setzero_si128(), select);
  // The integer mask is reinterpreted (bits unchanged) as a double vector for pselect.
  // NOTE(review): relies on pselect(mask, a, b) taking `a` in lanes where the mask is all ones.
  return pselect<Packet2d>(_mm_castsi128_pd(true_mask), thenPacket, elsePacket);
}
// Scalar path for pmadd with FMA to ensure consistency with vectorized path.