Speed up the sparse x dense dot product by accumulating into two independent sums.

This commit is contained in:
Rasmus Munk Larsen 2024-02-24 19:13:33 +00:00
parent 7a88cdd6ad
commit a2f8eba026

View File

@ -17,7 +17,8 @@ namespace Eigen {
template <typename Derived> template <typename Derived>
template <typename OtherDerived> template <typename OtherDerived>
typename internal::traits<Derived>::Scalar SparseMatrixBase<Derived>::dot(const MatrixBase<OtherDerived>& other) const { inline typename internal::traits<Derived>::Scalar SparseMatrixBase<Derived>::dot(
const MatrixBase<OtherDerived>& other) const {
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
EIGEN_STATIC_ASSERT_VECTOR_ONLY(OtherDerived) EIGEN_STATIC_ASSERT_VECTOR_ONLY(OtherDerived)
EIGEN_STATIC_ASSERT_SAME_VECTOR_SIZE(Derived, OtherDerived) EIGEN_STATIC_ASSERT_SAME_VECTOR_SIZE(Derived, OtherDerived)
@ -30,17 +31,23 @@ typename internal::traits<Derived>::Scalar SparseMatrixBase<Derived>::dot(const
internal::evaluator<Derived> thisEval(derived()); internal::evaluator<Derived> thisEval(derived());
typename internal::evaluator<Derived>::InnerIterator i(thisEval, 0); typename internal::evaluator<Derived>::InnerIterator i(thisEval, 0);
Scalar res(0); // Two accumulators, which breaks the dependency chain on the accumulator
while (i) { // and allows more instruction-level parallelism in the following loop.
res += numext::conj(i.value()) * other.coeff(i.index()); Scalar res1(0);
Scalar res2(0);
for (; i; ++i) {
res1 += numext::conj(i.value()) * other.coeff(i.index());
++i; ++i;
if (i) {
res2 += numext::conj(i.value()) * other.coeff(i.index());
}
} }
return res; return res1 + res2;
} }
template <typename Derived> template <typename Derived>
template <typename OtherDerived> template <typename OtherDerived>
typename internal::traits<Derived>::Scalar SparseMatrixBase<Derived>::dot( inline typename internal::traits<Derived>::Scalar SparseMatrixBase<Derived>::dot(
const SparseMatrixBase<OtherDerived>& other) const { const SparseMatrixBase<OtherDerived>& other) const {
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
EIGEN_STATIC_ASSERT_VECTOR_ONLY(OtherDerived) EIGEN_STATIC_ASSERT_VECTOR_ONLY(OtherDerived)