From b56e30841ca3140f310bd66aa93267ce134a1e09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20S=C3=A1nchez?= Date: Tue, 20 Feb 2024 18:21:45 +0000 Subject: [PATCH] Enable direct access for IndexedView. --- Eigen/src/Core/ArithmeticSequence.h | 5 ++ Eigen/src/Core/IndexedView.h | 99 +++++++++++++++++++++---- Eigen/src/Core/util/IndexedViewHelper.h | 5 ++ test/indexed_view.cpp | 33 ++++++++- 4 files changed, 126 insertions(+), 16 deletions(-) diff --git a/Eigen/src/Core/ArithmeticSequence.h b/Eigen/src/Core/ArithmeticSequence.h index 0f45e89ea..055beabd5 100644 --- a/Eigen/src/Core/ArithmeticSequence.h +++ b/Eigen/src/Core/ArithmeticSequence.h @@ -226,6 +226,11 @@ struct get_compile_time_incr > enum { value = get_fixed_value::value }; }; +template +constexpr Index get_runtime_incr(const ArithmeticSequence& x) EIGEN_NOEXCEPT { + return static_cast(x.incrObject()); +}; + } // end namespace internal /** \namespace Eigen::indexing diff --git a/Eigen/src/Core/IndexedView.h b/Eigen/src/Core/IndexedView.h index 0a024170e..b90ecb1e6 100644 --- a/Eigen/src/Core/IndexedView.h +++ b/Eigen/src/Core/IndexedView.h @@ -75,11 +75,11 @@ struct traits> : traits { typedef Block BlockType; }; -} // namespace internal - -template +template class IndexedViewImpl; +} // namespace internal + /** \class IndexedView * \ingroup Core_Module * @@ -120,19 +120,36 @@ class IndexedViewImpl; */ template class IndexedView - : public IndexedViewImpl::StorageKind> { + : public internal::IndexedViewImpl::StorageKind, + (internal::traits>::Flags & + DirectAccessBit) != 0> { public: - typedef - typename IndexedViewImpl::StorageKind>::Base - Base; + typedef typename internal::IndexedViewImpl< + XprType, RowIndices, ColIndices, typename internal::traits::StorageKind, + (internal::traits>::Flags & DirectAccessBit) != 0> + Base; EIGEN_GENERIC_PUBLIC_INTERFACE(IndexedView) EIGEN_INHERIT_ASSIGNMENT_OPERATORS(IndexedView) + template + IndexedView(XprType& xpr, const T0& rowIndices, const T1& colIndices) : Base(xpr, rowIndices, colIndices) {} +}; + +namespace internal { + +// Generic API dispatcher +template +class IndexedViewImpl : public internal::generic_xpr_base>::type { + public: + typedef typename internal::generic_xpr_base>::type Base; typedef typename internal::ref_selector::non_const_type MatrixTypeNested; typedef internal::remove_all_t NestedExpression; + typedef typename XprType::Scalar Scalar; + + EIGEN_INHERIT_ASSIGNMENT_OPERATORS(IndexedViewImpl) template - IndexedView(XprType& xpr, const T0& rowIndices, const T1& colIndices) + IndexedViewImpl(XprType& xpr, const T0& rowIndices, const T1& colIndices) : m_xpr(xpr), m_rowIndices(rowIndices), m_colIndices(colIndices) {} /** \returns number of rows */ @@ -153,20 +170,76 @@ class IndexedView /** \returns a const reference to the object storing/generating the column indices */ const ColIndices& colIndices() const { return m_colIndices; } + constexpr Scalar& coeffRef(Index rowId, Index colId) { + return nestedExpression().coeffRef(m_rowIndices[rowId], m_colIndices[colId]); + } + + constexpr const Scalar& coeffRef(Index rowId, Index colId) const { + return nestedExpression().coeffRef(m_rowIndices[rowId], m_colIndices[colId]); + } + protected: MatrixTypeNested m_xpr; RowIndices m_rowIndices; ColIndices m_colIndices; }; -// Generic API dispatcher template -class IndexedViewImpl : public internal::generic_xpr_base>::type { +class IndexedViewImpl + : public IndexedViewImpl { public: - typedef typename internal::generic_xpr_base>::type Base; -}; + using Base = internal::IndexedViewImpl::StorageKind, false>; + using Derived = IndexedView; -namespace internal { + EIGEN_INHERIT_ASSIGNMENT_OPERATORS(IndexedViewImpl) + + template + IndexedViewImpl(XprType& xpr, const T0& rowIndices, const T1& colIndices) : Base(xpr, rowIndices, colIndices) {} + + Index rowIncrement() const { + if (traits::RowIncr != DynamicIndex && traits::RowIncr != UndefinedIncr) { + return traits::RowIncr; + } + return get_runtime_incr(this->rowIndices()); + } + Index colIncrement() const { + if (traits::ColIncr != DynamicIndex && traits::ColIncr != UndefinedIncr) { + return traits::ColIncr; + } + return get_runtime_incr(this->colIndices()); + } + + Index innerIncrement() const { return traits::IsRowMajor ? colIncrement() : rowIncrement(); } + + Index outerIncrement() const { return traits::IsRowMajor ? rowIncrement() : colIncrement(); } + + std::decay_t* data() { + Index row_offset = this->rowIndices()[0] * this->nestedExpression().rowStride(); + Index col_offset = this->colIndices()[0] * this->nestedExpression().colStride(); + return this->nestedExpression().data() + row_offset + col_offset; + } + + const std::decay_t* data() const { + Index row_offset = this->rowIndices()[0] * this->nestedExpression().rowStride(); + Index col_offset = this->colIndices()[0] * this->nestedExpression().colStride(); + return this->nestedExpression().data() + row_offset + col_offset; + } + + EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index innerStride() const EIGEN_NOEXCEPT { + if (traits::InnerStrideAtCompileTime != Dynamic) { + return traits::InnerStrideAtCompileTime; + } + return innerIncrement() * this->nestedExpression().innerStride(); + } + + EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index outerStride() const EIGEN_NOEXCEPT { + if (traits::OuterStrideAtCompileTime != Dynamic) { + return traits::OuterStrideAtCompileTime; + } + return outerIncrement() * this->nestedExpression().outerStride(); + } +}; template struct unary_evaluator, IndexBased> diff --git a/Eigen/src/Core/util/IndexedViewHelper.h b/Eigen/src/Core/util/IndexedViewHelper.h index 3b451084b..00018a98d 100644 --- a/Eigen/src/Core/util/IndexedViewHelper.h +++ b/Eigen/src/Core/util/IndexedViewHelper.h @@ -67,6 +67,11 @@ struct get_compile_time_incr { enum { value = UndefinedIncr }; }; +template +constexpr Index get_runtime_incr(const T&) EIGEN_NOEXCEPT { + return Index(1); +}; + // Analogue of std::get<0>(x), but tailored for our needs. template EIGEN_CONSTEXPR Index first(const T& x) EIGEN_NOEXCEPT { diff --git a/test/indexed_view.cpp b/test/indexed_view.cpp index 4040448aa..d3cf4a679 100644 --- a/test/indexed_view.cpp +++ b/test/indexed_view.cpp @@ -498,12 +498,39 @@ void check_indexed_view() { // A(1, seq(0,2,1)).cwiseAbs().colwise().replicate(2).eval(); STATIC_CHECK(((internal::evaluator::Flags & RowMajorBit) == RowMajorBit)); } + + // Direct access. + { + int rows = 3; + int row_start = internal::random(0, rows - 1); + int row_inc = internal::random(1, rows - row_start); + int row_size = internal::random(1, (rows - row_start) / row_inc); + auto row_seq = seqN(row_start, row_size, row_inc); + + int cols = 3; + int col_start = internal::random(0, cols - 1); + int col_inc = internal::random(1, cols - col_start); + int col_size = internal::random(1, (cols - col_start) / col_inc); + auto col_seq = seqN(col_start, col_size, col_inc); + + MatrixXd m1 = MatrixXd::Random(rows, cols); + MatrixXd m2 = MatrixXd::Random(cols, rows); + VERIFY_IS_APPROX(m1(row_seq, indexing::all) * m2, m1(row_seq, indexing::all).eval() * m2); + VERIFY_IS_APPROX(m1 * m2(indexing::all, col_seq), m1 * m2(indexing::all, col_seq).eval()); + VERIFY_IS_APPROX(m1(row_seq, col_seq) * m2(col_seq, row_seq), + m1(row_seq, col_seq).eval() * m2(col_seq, row_seq).eval()); + + VectorXd v1 = VectorXd::Random(cols); + VERIFY_IS_APPROX(m1(row_seq, col_seq) * v1(col_seq), m1(row_seq, col_seq).eval() * v1(col_seq).eval()); + VERIFY_IS_APPROX(v1(col_seq).transpose() * m2(col_seq, row_seq), + v1(col_seq).transpose().eval() * m2(col_seq, row_seq).eval()); + } } EIGEN_DECLARE_TEST(indexed_view) { - // for(int i = 0; i < g_repeat; i++) { - CALL_SUBTEST_1(check_indexed_view()); - // } + for (int i = 0; i < g_repeat; i++) { + CALL_SUBTEST_1(check_indexed_view()); + } // static checks of some internals: STATIC_CHECK((internal::is_valid_index_type::value));