SVE intrinsics with the "_x" suffix are faster than their "_z" counterparts

qile lin 2024-08-23 12:52:22 +00:00 committed by Rasmus Munk Larsen
parent 98f1ac5e65
commit 3b5a1b4157
2 changed files with 54 additions and 54 deletions
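
For context: ACLE SVE intrinsics come in three predication flavours. The "_m" (merging) and "_z" (zeroing) forms pin down the value of inactive lanes, while "_x" ("don't care") leaves inactive lanes unspecified. Every call in these files passes an all-true predicate (svptrue_b32()), so the results are identical either way, but "_x" lets the compiler pick the cheapest encoding, typically the unpredicated instruction, whereas "_z" can force an extra zeroing move (movprfx). A minimal sketch of the two flavours, illustrative only and not part of the commit:

#include <arm_sve.h>  // assumes an SVE-enabled target, e.g. -march=armv8-a+sve

// With an all-true predicate both functions return the same vector; the
// difference is only in what the result's inactive lanes would hold.
svint32_t add_zeroing(svbool_t pg, svint32_t a, svint32_t b) {
  return svadd_s32_z(pg, a, b);  // inactive lanes of the result are forced to 0
}

svint32_t add_dont_care(svbool_t pg, svint32_t a, svint32_t b) {
  return svadd_s32_x(pg, a, b);  // inactive lanes are unspecified ("don't care")
}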

Eigen/src/Core/arch/SVE/PacketMath.h

@@ -86,22 +86,22 @@ template <>
EIGEN_STRONG_INLINE PacketXi plset<PacketXi>(const numext::int32_t& a) {
numext::int32_t c[packet_traits<numext::int32_t>::size];
for (int i = 0; i < packet_traits<numext::int32_t>::size; i++) c[i] = i;
return svadd_s32_z(svptrue_b32(), pset1<PacketXi>(a), svld1_s32(svptrue_b32(), c));
return svadd_s32_x(svptrue_b32(), pset1<PacketXi>(a), svld1_s32(svptrue_b32(), c));
}
template <>
EIGEN_STRONG_INLINE PacketXi padd<PacketXi>(const PacketXi& a, const PacketXi& b) {
return svadd_s32_z(svptrue_b32(), a, b);
return svadd_s32_x(svptrue_b32(), a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXi psub<PacketXi>(const PacketXi& a, const PacketXi& b) {
return svsub_s32_z(svptrue_b32(), a, b);
return svsub_s32_x(svptrue_b32(), a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXi pnegate(const PacketXi& a) {
return svneg_s32_z(svptrue_b32(), a);
return svneg_s32_x(svptrue_b32(), a);
}
template <>
@@ -111,27 +111,27 @@ EIGEN_STRONG_INLINE PacketXi pconj(const PacketXi& a) {
template <>
EIGEN_STRONG_INLINE PacketXi pmul<PacketXi>(const PacketXi& a, const PacketXi& b) {
return svmul_s32_z(svptrue_b32(), a, b);
return svmul_s32_x(svptrue_b32(), a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXi pdiv<PacketXi>(const PacketXi& a, const PacketXi& b) {
return svdiv_s32_z(svptrue_b32(), a, b);
return svdiv_s32_x(svptrue_b32(), a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXi pmadd(const PacketXi& a, const PacketXi& b, const PacketXi& c) {
return svmla_s32_z(svptrue_b32(), c, a, b);
return svmla_s32_x(svptrue_b32(), c, a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXi pmin<PacketXi>(const PacketXi& a, const PacketXi& b) {
return svmin_s32_z(svptrue_b32(), a, b);
return svmin_s32_x(svptrue_b32(), a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXi pmax<PacketXi>(const PacketXi& a, const PacketXi& b) {
return svmax_s32_z(svptrue_b32(), a, b);
return svmax_s32_x(svptrue_b32(), a, b);
}
template <>
@@ -151,47 +151,47 @@ EIGEN_STRONG_INLINE PacketXi pcmp_eq<PacketXi>(const PacketXi& a, const PacketXi
template <>
EIGEN_STRONG_INLINE PacketXi ptrue<PacketXi>(const PacketXi& /*a*/) {
return svdup_n_s32_z(svptrue_b32(), 0xffffffffu);
return svdup_n_s32_x(svptrue_b32(), 0xffffffffu);
}
template <>
EIGEN_STRONG_INLINE PacketXi pzero<PacketXi>(const PacketXi& /*a*/) {
return svdup_n_s32_z(svptrue_b32(), 0);
return svdup_n_s32_x(svptrue_b32(), 0);
}
template <>
EIGEN_STRONG_INLINE PacketXi pand<PacketXi>(const PacketXi& a, const PacketXi& b) {
return svand_s32_z(svptrue_b32(), a, b);
return svand_s32_x(svptrue_b32(), a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXi por<PacketXi>(const PacketXi& a, const PacketXi& b) {
return svorr_s32_z(svptrue_b32(), a, b);
return svorr_s32_x(svptrue_b32(), a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXi pxor<PacketXi>(const PacketXi& a, const PacketXi& b) {
return sveor_s32_z(svptrue_b32(), a, b);
return sveor_s32_x(svptrue_b32(), a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXi pandnot<PacketXi>(const PacketXi& a, const PacketXi& b) {
return svbic_s32_z(svptrue_b32(), a, b);
return svbic_s32_x(svptrue_b32(), a, b);
}
template <int N>
EIGEN_STRONG_INLINE PacketXi parithmetic_shift_right(PacketXi a) {
return svasrd_n_s32_z(svptrue_b32(), a, N);
return svasrd_n_s32_x(svptrue_b32(), a, N);
}
template <int N>
EIGEN_STRONG_INLINE PacketXi plogical_shift_right(PacketXi a) {
return svreinterpret_s32_u32(svlsr_n_u32_z(svptrue_b32(), svreinterpret_u32_s32(a), N));
return svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), svreinterpret_u32_s32(a), N));
}
template <int N>
EIGEN_STRONG_INLINE PacketXi plogical_shift_left(PacketXi a) {
return svlsl_n_s32_z(svptrue_b32(), a, N);
return svlsl_n_s32_x(svptrue_b32(), a, N);
}
template <>
@@ -257,7 +257,7 @@ EIGEN_STRONG_INLINE PacketXi preverse(const PacketXi& a) {
template <>
EIGEN_STRONG_INLINE PacketXi pabs(const PacketXi& a) {
return svabs_s32_z(svptrue_b32(), a);
return svabs_s32_x(svptrue_b32(), a);
}
template <>
@@ -270,29 +270,29 @@ EIGEN_STRONG_INLINE numext::int32_t predux_mul<PacketXi>(const PacketXi& a) {
EIGEN_STATIC_ASSERT((EIGEN_ARM64_SVE_VL % 128 == 0), EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
// Multiply the vector by its reverse
svint32_t prod = svmul_s32_z(svptrue_b32(), a, svrev_s32(a));
svint32_t prod = svmul_s32_x(svptrue_b32(), a, svrev_s32(a));
svint32_t half_prod;
// Extract the high half of the vector. Depending on the VL more reductions need to be done
if (EIGEN_ARM64_SVE_VL >= 2048) {
half_prod = svtbl_s32(prod, svindex_u32(32, 1));
prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
prod = svmul_s32_x(svptrue_b32(), prod, half_prod);
}
if (EIGEN_ARM64_SVE_VL >= 1024) {
half_prod = svtbl_s32(prod, svindex_u32(16, 1));
prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
prod = svmul_s32_x(svptrue_b32(), prod, half_prod);
}
if (EIGEN_ARM64_SVE_VL >= 512) {
half_prod = svtbl_s32(prod, svindex_u32(8, 1));
prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
prod = svmul_s32_x(svptrue_b32(), prod, half_prod);
}
if (EIGEN_ARM64_SVE_VL >= 256) {
half_prod = svtbl_s32(prod, svindex_u32(4, 1));
prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
prod = svmul_s32_x(svptrue_b32(), prod, half_prod);
}
// Last reduction
half_prod = svtbl_s32(prod, svindex_u32(2, 1));
prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
prod = svmul_s32_x(svptrue_b32(), prod, half_prod);
// The reduction is done to the first element.
return pfirst<PacketXi>(prod);
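
As the comments above describe, predux_mul multiplies the vector by its reverse and then keeps folding a shifted upper slice onto the lower lanes (offsets 32, 16, 8, 4, 2 for the largest vector length) until lane 0 holds the product of every element, which pfirst then extracts. A scalar sketch of the same folding pattern, illustrative only and not part of the commit (n stands for the number of 32-bit lanes, i.e. EIGEN_ARM64_SVE_VL / 32):

#include <cstdint>
#include <vector>

// Scalar model of the vector reduction above. After prod = a * rev(a) the
// full product is already contained in the low half of prod, so the fold
// offsets stop at 2 rather than 1.
std::int32_t predux_mul_sketch(const std::int32_t* a, int n) {
  std::vector<std::int32_t> prod(n);
  for (int i = 0; i < n; ++i) prod[i] = a[i] * a[n - 1 - i];       // a * rev(a)
  for (int offset = n / 2; offset >= 2; offset /= 2)               // 32, 16, 8, 4, 2 when n == 64
    for (int i = 0; i < offset; ++i) prod[i] *= prod[i + offset];  // fold the upper slice down
  return prod[0];                                                  // pfirst: lane 0 has the product
}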
@@ -386,29 +386,29 @@ EIGEN_STRONG_INLINE PacketXf pset1<PacketXf>(const float& from) {
template <>
EIGEN_STRONG_INLINE PacketXf pset1frombits<PacketXf>(numext::uint32_t from) {
return svreinterpret_f32_u32(svdup_n_u32_z(svptrue_b32(), from));
return svreinterpret_f32_u32(svdup_n_u32_x(svptrue_b32(), from));
}
template <>
EIGEN_STRONG_INLINE PacketXf plset<PacketXf>(const float& a) {
float c[packet_traits<float>::size];
for (int i = 0; i < packet_traits<float>::size; i++) c[i] = i;
return svadd_f32_z(svptrue_b32(), pset1<PacketXf>(a), svld1_f32(svptrue_b32(), c));
return svadd_f32_x(svptrue_b32(), pset1<PacketXf>(a), svld1_f32(svptrue_b32(), c));
}
template <>
EIGEN_STRONG_INLINE PacketXf padd<PacketXf>(const PacketXf& a, const PacketXf& b) {
return svadd_f32_z(svptrue_b32(), a, b);
return svadd_f32_x(svptrue_b32(), a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXf psub<PacketXf>(const PacketXf& a, const PacketXf& b) {
return svsub_f32_z(svptrue_b32(), a, b);
return svsub_f32_x(svptrue_b32(), a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXf pnegate(const PacketXf& a) {
return svneg_f32_z(svptrue_b32(), a);
return svneg_f32_x(svptrue_b32(), a);
}
template <>
@@ -418,22 +418,22 @@ EIGEN_STRONG_INLINE PacketXf pconj(const PacketXf& a) {
template <>
EIGEN_STRONG_INLINE PacketXf pmul<PacketXf>(const PacketXf& a, const PacketXf& b) {
return svmul_f32_z(svptrue_b32(), a, b);
return svmul_f32_x(svptrue_b32(), a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXf pdiv<PacketXf>(const PacketXf& a, const PacketXf& b) {
return svdiv_f32_z(svptrue_b32(), a, b);
return svdiv_f32_x(svptrue_b32(), a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXf pmadd(const PacketXf& a, const PacketXf& b, const PacketXf& c) {
return svmla_f32_z(svptrue_b32(), c, a, b);
return svmla_f32_x(svptrue_b32(), c, a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXf pmin<PacketXf>(const PacketXf& a, const PacketXf& b) {
return svmin_f32_z(svptrue_b32(), a, b);
return svmin_f32_x(svptrue_b32(), a, b);
}
template <>
@@ -443,12 +443,12 @@ EIGEN_STRONG_INLINE PacketXf pmin<PropagateNaN, PacketXf>(const PacketXf& a, con
template <>
EIGEN_STRONG_INLINE PacketXf pmin<PropagateNumbers, PacketXf>(const PacketXf& a, const PacketXf& b) {
return svminnm_f32_z(svptrue_b32(), a, b);
return svminnm_f32_x(svptrue_b32(), a, b);
}
template <>
EIGEN_STRONG_INLINE PacketXf pmax<PacketXf>(const PacketXf& a, const PacketXf& b) {
return svmax_f32_z(svptrue_b32(), a, b);
return svmax_f32_x(svptrue_b32(), a, b);
}
template <>
@@ -458,7 +458,7 @@ EIGEN_STRONG_INLINE PacketXf pmax<PropagateNaN, PacketXf>(const PacketXf& a, con
template <>
EIGEN_STRONG_INLINE PacketXf pmax<PropagateNumbers, PacketXf>(const PacketXf& a, const PacketXf& b) {
return svmaxnm_f32_z(svptrue_b32(), a, b);
return svmaxnm_f32_x(svptrue_b32(), a, b);
}
// Float comparisons in SVE return svbool (predicate). Use svdup to set active
@@ -478,43 +478,43 @@ EIGEN_STRONG_INLINE PacketXf pcmp_eq<PacketXf>(const PacketXf& a, const PacketXf
return svreinterpret_f32_u32(svdup_n_u32_z(svcmpeq_f32(svptrue_b32(), a, b), 0xffffffffu));
}
// Do a predicate inverse (svnot_b_z) on the predicate resulted from the
// Do a predicate inverse (svnot_b_x) on the predicate resulted from the
// greater/equal comparison (svcmpge_f32). Then fill a float vector with the
// active elements.
template <>
EIGEN_STRONG_INLINE PacketXf pcmp_lt_or_nan<PacketXf>(const PacketXf& a, const PacketXf& b) {
return svreinterpret_f32_u32(svdup_n_u32_z(svnot_b_z(svptrue_b32(), svcmpge_f32(svptrue_b32(), a, b)), 0xffffffffu));
return svreinterpret_f32_u32(svdup_n_u32_z(svnot_b_x(svptrue_b32(), svcmpge_f32(svptrue_b32(), a, b)), 0xffffffffu));
}
template <>
EIGEN_STRONG_INLINE PacketXf pfloor<PacketXf>(const PacketXf& a) {
return svrintm_f32_z(svptrue_b32(), a);
return svrintm_f32_x(svptrue_b32(), a);
}
template <>
EIGEN_STRONG_INLINE PacketXf ptrue<PacketXf>(const PacketXf& /*a*/) {
return svreinterpret_f32_u32(svdup_n_u32_z(svptrue_b32(), 0xffffffffu));
return svreinterpret_f32_u32(svdup_n_u32_x(svptrue_b32(), 0xffffffffu));
}
// Logical Operations are not supported for float, so reinterpret casts
template <>
EIGEN_STRONG_INLINE PacketXf pand<PacketXf>(const PacketXf& a, const PacketXf& b) {
return svreinterpret_f32_u32(svand_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
return svreinterpret_f32_u32(svand_u32_x(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
}
template <>
EIGEN_STRONG_INLINE PacketXf por<PacketXf>(const PacketXf& a, const PacketXf& b) {
return svreinterpret_f32_u32(svorr_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
return svreinterpret_f32_u32(svorr_u32_x(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
}
template <>
EIGEN_STRONG_INLINE PacketXf pxor<PacketXf>(const PacketXf& a, const PacketXf& b) {
return svreinterpret_f32_u32(sveor_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
return svreinterpret_f32_u32(sveor_u32_x(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
}
template <>
EIGEN_STRONG_INLINE PacketXf pandnot<PacketXf>(const PacketXf& a, const PacketXf& b) {
return svreinterpret_f32_u32(svbic_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
return svreinterpret_f32_u32(svbic_u32_x(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
}
template <>
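
Note that the svdup_n_u32_z calls that materialize comparison results above keep the zeroing suffix on purpose: the comparison yields a predicate, and zeroing the inactive lanes is exactly what produces the 0xffffffff / 0x00000000 integer mask that is then reinterpreted as float, so a "don't care" (_x) form would not be a valid substitute there. A minimal sketch of that pattern, illustrative only and not part of the commit:

#include <arm_sve.h>  // assumes an SVE-enabled target, e.g. -march=armv8-a+sve

// Turn a per-lane float comparison into an all-ones / all-zeros float mask,
// the same shape as pcmp_eq above. The zeroing (_z) duplicate is what writes
// 0 into lanes where the predicate is false.
svfloat32_t cmp_eq_mask(svfloat32_t a, svfloat32_t b) {
  svbool_t eq = svcmpeq_f32(svptrue_b32(), a, b);    // predicate: lane-wise a == b
  svuint32_t bits = svdup_n_u32_z(eq, 0xffffffffu);  // 0xffffffff where true, 0 elsewhere
  return svreinterpret_f32_u32(bits);                // reuse the bit pattern as float
}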
@@ -579,7 +579,7 @@ EIGEN_STRONG_INLINE PacketXf preverse(const PacketXf& a) {
template <>
EIGEN_STRONG_INLINE PacketXf pabs(const PacketXf& a) {
return svabs_f32_z(svptrue_b32(), a);
return svabs_f32_x(svptrue_b32(), a);
}
// TODO(tellenbach): Should this go into MathFunctions.h? If so, change for
@@ -601,29 +601,29 @@ template <>
EIGEN_STRONG_INLINE float predux_mul<PacketXf>(const PacketXf& a) {
EIGEN_STATIC_ASSERT((EIGEN_ARM64_SVE_VL % 128 == 0), EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
// Multiply the vector by its reverse
svfloat32_t prod = svmul_f32_z(svptrue_b32(), a, svrev_f32(a));
svfloat32_t prod = svmul_f32_x(svptrue_b32(), a, svrev_f32(a));
svfloat32_t half_prod;
// Extract the high half of the vector. Depending on the VL more reductions need to be done
if (EIGEN_ARM64_SVE_VL >= 2048) {
half_prod = svtbl_f32(prod, svindex_u32(32, 1));
prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
prod = svmul_f32_x(svptrue_b32(), prod, half_prod);
}
if (EIGEN_ARM64_SVE_VL >= 1024) {
half_prod = svtbl_f32(prod, svindex_u32(16, 1));
prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
prod = svmul_f32_x(svptrue_b32(), prod, half_prod);
}
if (EIGEN_ARM64_SVE_VL >= 512) {
half_prod = svtbl_f32(prod, svindex_u32(8, 1));
prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
prod = svmul_f32_x(svptrue_b32(), prod, half_prod);
}
if (EIGEN_ARM64_SVE_VL >= 256) {
half_prod = svtbl_f32(prod, svindex_u32(4, 1));
prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
prod = svmul_f32_x(svptrue_b32(), prod, half_prod);
}
// Last reduction
half_prod = svtbl_f32(prod, svindex_u32(2, 1));
prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
prod = svmul_f32_x(svptrue_b32(), prod, half_prod);
// The reduction is done to the first element.
return pfirst<PacketXf>(prod);

Eigen/src/Core/arch/SVE/TypeCasting.h

@@ -28,12 +28,12 @@ struct type_casting_traits<numext::int32_t, float> {
template <>
EIGEN_STRONG_INLINE PacketXf pcast<PacketXi, PacketXf>(const PacketXi& a) {
return svcvt_f32_s32_z(svptrue_b32(), a);
return svcvt_f32_s32_x(svptrue_b32(), a);
}
template <>
EIGEN_STRONG_INLINE PacketXi pcast<PacketXf, PacketXi>(const PacketXf& a) {
return svcvt_s32_f32_z(svptrue_b32(), a);
return svcvt_s32_f32_x(svptrue_b32(), a);
}
template <>