From 19a6a827c42062133aee119bc57c67ac6aac043c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20S=C3=A1nchez?= Date: Wed, 23 Mar 2022 15:27:57 +0000 Subject: [PATCH] Optimize visitor traversal in case of RowMajor. --- Eigen/src/Core/Visitor.h | 31 +++++++-- test/visitor.cpp | 147 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 171 insertions(+), 7 deletions(-) diff --git a/Eigen/src/Core/Visitor.h b/Eigen/src/Core/Visitor.h index b384bb3d0..e1c17fc57 100644 --- a/Eigen/src/Core/Visitor.h +++ b/Eigen/src/Core/Visitor.h @@ -23,8 +23,10 @@ template struct visitor_impl { enum { - col = (UnrollCount-1) / Derived::RowsAtCompileTime, - row = (UnrollCount-1) % Derived::RowsAtCompileTime + col = Derived::IsRowMajor ? (UnrollCount-1) % Derived::ColsAtCompileTime + : (UnrollCount-1) / Derived::RowsAtCompileTime, + row = Derived::IsRowMajor ? (UnrollCount-1) / Derived::ColsAtCompileTime + : (UnrollCount-1) % Derived::RowsAtCompileTime }; EIGEN_DEVICE_FUNC @@ -60,11 +62,25 @@ struct visitor_impl static inline void run(const Derived& mat, Visitor& visitor) { visitor.init(mat.coeff(0,0), 0, 0); - for(Index i = 1; i < mat.rows(); ++i) - visitor(mat.coeff(i, 0), i, 0); - for(Index j = 1; j < mat.cols(); ++j) - for(Index i = 0; i < mat.rows(); ++i) - visitor(mat.coeff(i, j), i, j); + if (Derived::IsRowMajor) { + for(Index i = 1; i < mat.cols(); ++i) { + visitor(mat.coeff(0, i), 0, i); + } + for(Index j = 1; j < mat.rows(); ++j) { + for(Index i = 0; i < mat.cols(); ++i) { + visitor(mat.coeff(j, i), j, i); + } + } + } else { + for(Index i = 1; i < mat.rows(); ++i) { + visitor(mat.coeff(i, 0), i, 0); + } + for(Index j = 1; j < mat.cols(); ++j) { + for(Index i = 0; i < mat.rows(); ++i) { + visitor(mat.coeff(i, j), i, j); + } + } + } } }; @@ -114,6 +130,7 @@ public: PacketAccess = Evaluator::Flags & PacketAccessBit, IsRowMajor = XprType::IsRowMajor, RowsAtCompileTime = XprType::RowsAtCompileTime, + ColsAtCompileTime = XprType::ColsAtCompileTime, CoeffReadCost = Evaluator::CoeffReadCost }; diff --git a/test/visitor.cpp b/test/visitor.cpp index 05c2a4838..7ff7bf1ac 100644 --- a/test/visitor.cpp +++ b/test/visitor.cpp @@ -173,6 +173,152 @@ template void vectorVisitor(const VectorType& w) } } +template +struct TrackedVisitor { + void init(T v, int i, int j) { return this->operator()(v,i,j); } + void operator()(T v, int i, int j) { + EIGEN_UNUSED_VARIABLE(v) + visited.push_back({i, j}); + vectorized = false; + } + + template + void packet(Packet p, int i, int j) { + EIGEN_UNUSED_VARIABLE(p) + visited.push_back({i, j}); + vectorized = true; + } + std::vector> visited; + bool vectorized; +}; + +namespace Eigen { +namespace internal { + +template +struct functor_traits > { + enum { PacketAccess = Vectorizable, Cost = 1 }; +}; + +} // namespace internal +} // namespace Eigen + +void checkOptimalTraversal() { + + // Unrolled - ColMajor. + { + Eigen::Matrix4f X = Eigen::Matrix4f::Random(); + TrackedVisitor visitor; + X.visit(visitor); + int count = 0; + for (int j=0; j; + { + Matrix4fRowMajor X = Matrix4fRowMajor::Random(); + TrackedVisitor visitor; + X.visit(visitor); + int count = 0; + for (int i=0; i visitor; + X.visit(visitor); + int count = 0; + for (int j=0; j; + { + MatrixXfRowMajor X = MatrixXfRowMajor::Random(4, 4); + TrackedVisitor visitor; + X.visit(visitor); + int count = 0; + for (int i=0; i::size; + Eigen::MatrixXf X = Eigen::MatrixXf::Random(4 * PacketSize, 4 * PacketSize); + TrackedVisitor visitor; + X.visit(visitor); + int previ = -1; + int prevj = 0; + for (const auto& p : visitor.visited) { + int i = p.first; + int j = p.second; + VERIFY( + (j == prevj && i == previ + 1) // Advance single element + || (j == prevj && i == previ + PacketSize) // Advance packet + || (j == prevj + 1 && i == 0) // Advance column + ); + previ = i; + prevj = j; + } + if (Eigen::internal::packet_traits::Vectorizable) { + VERIFY(visitor.vectorized); + } + } + + // Vectorized - RowMajor. + { + // Ensure rows/cols is larger than packet size. + constexpr int PacketSize = Eigen::internal::packet_traits::size; + MatrixXfRowMajor X = MatrixXfRowMajor::Random(4 * PacketSize, 4 * PacketSize); + TrackedVisitor visitor; + X.visit(visitor); + int previ = 0; + int prevj = -1; + for (const auto& p : visitor.visited) { + int i = p.first; + int j = p.second; + VERIFY( + (i == previ && j == prevj + 1) // Advance single element + || (i == previ && j == prevj + PacketSize) // Advance packet + || (i == previ + 1 && j == 0) // Advance row + ); + previ = i; + prevj = j; + } + if (Eigen::internal::packet_traits::Vectorizable) { + VERIFY(visitor.vectorized); + } + } + +} + EIGEN_DECLARE_TEST(visitor) { for(int i = 0; i < g_repeat; i++) { @@ -190,4 +336,5 @@ EIGEN_DECLARE_TEST(visitor) CALL_SUBTEST_9( vectorVisitor(RowVectorXd(10)) ); CALL_SUBTEST_10( vectorVisitor(VectorXf(33)) ); } + CALL_SUBTEST_11(checkOptimalTraversal()); }