Speed up StableNorm for non-trivial sizes and improve consistency between aligned and unaligned inputs.

This commit is contained in:
Rasmus Munk Larsen 2024-08-14 21:42:04 +00:00
parent 1dbc7581ec
commit 92e373e6f5

View File

@ -48,34 +48,16 @@ inline void stable_norm_kernel(const ExpressionType& bl, Scalar& ssq, Scalar& sc
// Folds the coefficients of `vec` into the running (ssq, scale, invScale) state by handing
// consecutive pieces of the vector to internal::stable_norm_kernel. All three accumulators are
// taken by non-const reference and passed straight through to the kernel.
//
// NOTE(review): as written, this body contains two complete traversals of the vector: an
// alignment-aware one (copy.head(bi) followed by SegmentWrapper blocks) and a simpler one
// (fixed-size segments followed by vec.tail). Each covers [0, n), so the kernel would see every
// coefficient twice. This looks like a rendered diff hunk where removed and added lines appear
// together without +/- markers -- confirm against the actual header before relying on this text.
template <typename VectorType, typename RealScalar>
void stable_norm_impl_inner_step(const VectorType& vec, RealScalar& ssq, RealScalar& scale, RealScalar& invScale) {
typedef typename VectorType::Scalar Scalar;
// Maximum number of coefficients passed to the kernel in one call.
const Index blockSize = 4096;
// --- First traversal: works on `copy`, built from `vec` via internal::nested_eval. ---
typedef typename internal::nested_eval<VectorType, 2>::type VectorTypeCopy;
typedef internal::remove_all_t<VectorTypeCopy> VectorTypeCopyClean;
const VectorTypeCopy copy(vec);
// CanAlign is true only when all three conditions hold: the cleaned copy type has DirectAccessBit
// or a positive evaluator Alignment, two blocks of Scalar fit under EIGEN_STACK_ALLOCATION_LIMIT,
// and EIGEN_MAX_STATIC_ALIGN_BYTES is non-zero.
enum {
CanAlign =
((int(VectorTypeCopyClean::Flags) & DirectAccessBit) ||
(int(internal::evaluator<VectorTypeCopyClean>::Alignment) > 0) // FIXME Alignment)>0 might not be enough
) &&
(blockSize * sizeof(Scalar) * 2 < EIGEN_STACK_ALLOCATION_LIMIT) &&
(EIGEN_MAX_STATIC_ALIGN_BYTES >
0) // if we cannot allocate on the stack, then let's not bother about this optimization
};
// Type used to wrap each block before it reaches the kernel: a Ref to a column vector with at most
// blockSize rows (carrying the evaluator's Alignment) when CanAlign is set, otherwise the plain
// const segment type of the copy.
typedef std::conditional_t<
CanAlign,
Ref<const Matrix<Scalar, Dynamic, 1, 0, blockSize, 1>, internal::evaluator<VectorTypeCopyClean>::Alignment>,
typename VectorTypeCopyClean::ConstSegmentReturnType>
SegmentWrapper;
Index n = vec.size();
// `bi` is the index returned by internal::first_default_aligned(copy); the leading `bi`
// coefficients are processed on their own, then the remainder in chunks of at most blockSize.
Index bi = internal::first_default_aligned(copy);
if (bi > 0) internal::stable_norm_kernel(copy.head(bi), ssq, scale, invScale);
for (; bi < n; bi += blockSize)
internal::stable_norm_kernel(SegmentWrapper(copy.segment(bi, numext::mini(blockSize, n - bi))), ssq, scale,
invScale);
// --- Second traversal: works on `vec` directly, with no alignment handling. ---
// blockEnd is n rounded down to a multiple of blockSize, so the loop below only ever sees full
// blocks and can use the compile-time-sized segment<blockSize>.
Index blockEnd = numext::round_down(n, blockSize);
for (Index i = 0; i < blockEnd; i += blockSize) {
internal::stable_norm_kernel(vec.template segment<blockSize>(i), ssq, scale, invScale);
}
// Remaining n - blockEnd coefficients (fewer than blockSize) go through a dynamically sized tail.
if (n > blockEnd) {
internal::stable_norm_kernel(vec.tail(n - blockEnd), ssq, scale, invScale);
}
}
template <typename VectorType>
@ -85,8 +67,7 @@ typename VectorType::RealScalar stable_norm_impl(const VectorType& vec,
using std::sqrt;
Index n = vec.size();
if (n == 1) return abs(vec.coeff(0));
if (EIGEN_PREDICT_FALSE(n == 1)) return abs(vec.coeff(0));
typedef typename VectorType::RealScalar RealScalar;
RealScalar scale(0);