Use numext::fma in more places in SparseCore.

This commit is contained in:
Rasmus Munk Larsen 2025-07-17 21:20:39 +00:00
parent d7fa5ebe0e
commit 2cf66d4b0d
2 changed files with 15 additions and 7 deletions

View File

@@ -67,7 +67,7 @@ inline typename internal::traits<Derived>::Scalar SparseMatrixBase<Derived>::dot
Scalar res(0);
while (i && j) {
if (i.index() == j.index()) {
- res += numext::conj(i.value()) * j.value();
+ res = numext::fma(numext::conj(i.value()), j.value(), res);
++i;
++j;
} else if (i.index() < j.index())

View File

@@ -41,7 +41,7 @@ struct sparse_solve_triangular_selector<Lhs, Rhs, Mode, Lower, RowMajor> {
lastVal = it.value();
lastIndex = it.index();
if (lastIndex == i) break;
- tmp -= lastVal * other.coeff(lastIndex, col);
+ tmp = numext::fma(-lastVal, other.coeff(lastIndex, col), tmp);
}
if (Mode & UnitDiag)
other.coeffRef(i, col) = tmp;
@@ -75,7 +75,7 @@ struct sparse_solve_triangular_selector<Lhs, Rhs, Mode, Upper, RowMajor> {
} else if (it && it.index() == i)
++it;
for (; it; ++it) {
- tmp -= it.value() * other.coeff(it.index(), col);
+ tmp = numext::fma(-it.value(), other.coeff(it.index(), col), tmp);
}
if (Mode & UnitDiag)
@@ -107,7 +107,9 @@ struct sparse_solve_triangular_selector<Lhs, Rhs, Mode, Lower, ColMajor> {
tmp /= it.value();
}
if (it && it.index() == i) ++it;
- for (; it; ++it) other.coeffRef(it.index(), col) -= tmp * it.value();
+ for (; it; ++it) {
+ other.coeffRef(it.index(), col) = numext::fma(-tmp, it.value(), other.coeffRef(it.index(), col));
+ }
}
}
}
@@ -135,7 +137,9 @@ struct sparse_solve_triangular_selector<Lhs, Rhs, Mode, Upper, ColMajor> {
other.coeffRef(i, col) /= it.value();
}
LhsIterator it(lhsEval, i);
- for (; it && it.index() < i; ++it) other.coeffRef(it.index(), col) -= tmp * it.value();
+ for (; it && it.index() < i; ++it) {
+ other.coeffRef(it.index(), col) = numext::fma(-tmp, it.value(), other.coeffRef(it.index(), col));
+ }
}
}
}
@@ -215,9 +219,13 @@ struct sparse_solve_triangular_sparse_selector<Lhs, Rhs, Mode, UpLo, ColMajor> {
tempVector.restart();
if (IsLower) {
if (it.index() == i) ++it;
- for (; it; ++it) tempVector.coeffRef(it.index()) -= ci * it.value();
+ for (; it; ++it) {
+ tempVector.coeffRef(it.index()) = numext::fma(-ci, it.value(), tempVector.coeffRef(it.index()));
+ }
} else {
- for (; it && it.index() < i; ++it) tempVector.coeffRef(it.index()) -= ci * it.value();
+ for (; it && it.index() < i; ++it) {
+ tempVector.coeffRef(it.index()) = numext::fma(-ci, it.value(), tempVector.coeffRef(it.index()));
+ }
}
}
}