From bfbc66e078570f992f97f9f4ff4119ea737b957c Mon Sep 17 00:00:00 2001 From: Charles Schlosser Date: Wed, 29 Mar 2023 01:35:26 +0000 Subject: [PATCH] refactor indexedviewmethods, enable non-const ref access with symbolic indices --- Eigen/src/Core/IndexedView.h | 1 - Eigen/src/plugins/IndexedViewMethods.h | 297 +++++++++++++------------ test/indexed_view.cpp | 12 +- 3 files changed, 157 insertions(+), 153 deletions(-) diff --git a/Eigen/src/Core/IndexedView.h b/Eigen/src/Core/IndexedView.h index f9673011f..feab3a9a3 100644 --- a/Eigen/src/Core/IndexedView.h +++ b/Eigen/src/Core/IndexedView.h @@ -93,7 +93,6 @@ class IndexedViewImpl; * - std::vector * - std::valarray * - std::array - * - Plain C arrays: int[N] * - Eigen::ArrayXi * - decltype(ArrayXi::LinSpaced(...)) * - Any view/expressions of the previous types diff --git a/Eigen/src/plugins/IndexedViewMethods.h b/Eigen/src/plugins/IndexedViewMethods.h index 011fcbed7..cef34a551 100644 --- a/Eigen/src/plugins/IndexedViewMethods.h +++ b/Eigen/src/plugins/IndexedViewMethods.h @@ -9,200 +9,207 @@ #if !defined(EIGEN_PARSED_BY_DOXYGEN) -// This file is automatically included twice to generate const and non-const versions - -#ifndef EIGEN_INDEXED_VIEW_METHOD_2ND_PASS -#define EIGEN_INDEXED_VIEW_METHOD_CONST const -#define EIGEN_INDEXED_VIEW_METHOD_TYPE ConstIndexedViewType -#else -#define EIGEN_INDEXED_VIEW_METHOD_CONST -#define EIGEN_INDEXED_VIEW_METHOD_TYPE IndexedViewType -#endif - -#ifndef EIGEN_INDEXED_VIEW_METHOD_2ND_PASS protected: - // define some aliases to ease readability -template -struct IvcRowType : public internal::IndexedViewCompatibleType {}; +template +using IvcRowType = typename internal::IndexedViewCompatibleType::type; -template -struct IvcColType : public internal::IndexedViewCompatibleType {}; +template +using IvcColType = typename internal::IndexedViewCompatibleType::type; -template -struct IvcType : public internal::IndexedViewCompatibleType {}; +template +using IvcType = typename internal::IndexedViewCompatibleType::type; -typedef typename internal::IndexedViewCompatibleType::type IvcIndex; +typedef typename internal::IndexedViewCompatibleType::type IvcIndex; -template -typename IvcRowType::type -ivcRow(const Indices& indices) const { - return internal::makeIndexedViewCompatible(indices, internal::variable_if_dynamic(derived().rows()),Specialized); +template +IvcRowType ivcRow(const Indices& indices) const { + return internal::makeIndexedViewCompatible( + indices, internal::variable_if_dynamic(derived().rows()), Specialized); } -template -typename IvcColType::type -ivcCol(const Indices& indices) const { - return internal::makeIndexedViewCompatible(indices, internal::variable_if_dynamic(derived().cols()),Specialized); +template +IvcColType ivcCol(const Indices& indices) const { + return internal::makeIndexedViewCompatible( + indices, internal::variable_if_dynamic(derived().cols()), Specialized); } -template -typename IvcColType::type -ivcSize(const Indices& indices) const { - return internal::makeIndexedViewCompatible(indices, internal::variable_if_dynamic(derived().size()),Specialized); +template +IvcColType ivcSize(const Indices& indices) const { + return internal::makeIndexedViewCompatible( + indices, internal::variable_if_dynamic(derived().size()), Specialized); } public: -#endif +template +using IndexedViewType = IndexedView, IvcColType>; -template -struct EIGEN_INDEXED_VIEW_METHOD_TYPE { - typedef IndexedView::type, - typename IvcColType::type> type; -}; +template +using ConstIndexedViewType = IndexedView, IvcColType>; // This is the generic version -template -std::enable_if_t::value - && internal::traits::type>::ReturnAsIndexedView, - typename EIGEN_INDEXED_VIEW_METHOD_TYPE::type> -operator()(const RowIndices& rowIndices, const ColIndices& colIndices) EIGEN_INDEXED_VIEW_METHOD_CONST -{ - return typename EIGEN_INDEXED_VIEW_METHOD_TYPE::type - (derived(), ivcRow(rowIndices), ivcCol(colIndices)); +template +std::enable_if_t::value && + internal::traits>::ReturnAsIndexedView, + IndexedViewType> +operator()(const RowIndices& rowIndices, const ColIndices& colIndices) { + return IndexedViewType(derived(), ivcRow(rowIndices), ivcCol(colIndices)); +} + +template +std::enable_if_t::value && + internal::traits>::ReturnAsIndexedView, + ConstIndexedViewType> +operator()(const RowIndices& rowIndices, const ColIndices& colIndices) const { + return ConstIndexedViewType(derived(), ivcRow(rowIndices), ivcCol(colIndices)); } // The following overload returns a Block<> object -template -std::enable_if_t::value - && internal::traits::type>::ReturnAsBlock, - typename internal::traits::type>::BlockType> -operator()(const RowIndices& rowIndices, const ColIndices& colIndices) EIGEN_INDEXED_VIEW_METHOD_CONST -{ - typedef typename internal::traits::type>::BlockType BlockType; - typename IvcRowType::type actualRowIndices = ivcRow(rowIndices); - typename IvcColType::type actualColIndices = ivcCol(colIndices); - return BlockType(derived(), - internal::first(actualRowIndices), - internal::first(actualColIndices), - internal::index_list_size(actualRowIndices), - internal::index_list_size(actualColIndices)); +template +std::enable_if_t::value && + internal::traits>::ReturnAsBlock, + typename internal::traits>::BlockType> +operator()(const RowIndices& rowIndices, const ColIndices& colIndices) { + typedef typename internal::traits>::BlockType BlockType; + IvcRowType actualRowIndices = ivcRow(rowIndices); + IvcColType actualColIndices = ivcCol(colIndices); + return BlockType(derived(), internal::first(actualRowIndices), internal::first(actualColIndices), + internal::index_list_size(actualRowIndices), internal::index_list_size(actualColIndices)); +} + +template +std::enable_if_t::value && + internal::traits>::ReturnAsBlock, + typename internal::traits>::BlockType> +operator()(const RowIndices& rowIndices, const ColIndices& colIndices) const { + typedef typename internal::traits>::BlockType BlockType; + IvcRowType actualRowIndices = ivcRow(rowIndices); + IvcColType actualColIndices = ivcCol(colIndices); + return BlockType(derived(), internal::first(actualRowIndices), internal::first(actualColIndices), + internal::index_list_size(actualRowIndices), internal::index_list_size(actualColIndices)); } // The following overload returns a Scalar -template -std::enable_if_t::value - && internal::traits::type>::ReturnAsScalar, - CoeffReturnType > -operator()(const RowIndices& rowIndices, const ColIndices& colIndices) EIGEN_INDEXED_VIEW_METHOD_CONST -{ - return Base::operator()(internal::eval_expr_given_size(rowIndices,rows()),internal::eval_expr_given_size(colIndices,cols())); +template +std::enable_if_t::value && + internal::traits>::ReturnAsScalar, + Scalar&> +operator()(const RowIndices& rowIndices, const ColIndices& colIndices) { + return Base::operator()(internal::eval_expr_given_size(rowIndices, rows()), + internal::eval_expr_given_size(colIndices, cols())); } -// The following three overloads are needed to handle raw Index[N] arrays. - -template -IndexedView::type> -operator()(const RowIndicesT (&rowIndices)[RowIndicesN], const ColIndices& colIndices) EIGEN_INDEXED_VIEW_METHOD_CONST -{ - return IndexedView::type> - (derived(), rowIndices, ivcCol(colIndices)); +template +std::enable_if_t::value && + internal::traits>::ReturnAsScalar, + CoeffReturnType> +operator()(const RowIndices& rowIndices, const ColIndices& colIndices) const { + return Base::operator()(internal::eval_expr_given_size(rowIndices, rows()), + internal::eval_expr_given_size(colIndices, cols())); } -template -IndexedView::type, const ColIndicesT (&)[ColIndicesN]> -operator()(const RowIndices& rowIndices, const ColIndicesT (&colIndices)[ColIndicesN]) EIGEN_INDEXED_VIEW_METHOD_CONST -{ - return IndexedView::type,const ColIndicesT (&)[ColIndicesN]> - (derived(), ivcRow(rowIndices), colIndices); -} - -template -IndexedView -operator()(const RowIndicesT (&rowIndices)[RowIndicesN], const ColIndicesT (&colIndices)[ColIndicesN]) EIGEN_INDEXED_VIEW_METHOD_CONST -{ - return IndexedView - (derived(), rowIndices, colIndices); -} - - // Overloads for 1D vectors/arrays -template -std::enable_if_t< - IsRowMajor && (!(internal::get_compile_time_incr::type>::value==1 || internal::is_valid_index_type::value)), - IndexedView::type> > -operator()(const Indices& indices) EIGEN_INDEXED_VIEW_METHOD_CONST -{ +template +std::enable_if_t>::value == 1 || + internal::is_valid_index_type::value)), + IndexedView>> +operator()(const Indices& indices) { EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) - return IndexedView::type> - (derived(), IvcIndex(0), ivcCol(indices)); + return IndexedView>(derived(), IvcIndex(0), ivcCol(indices)); } -template -std::enable_if_t< - (!IsRowMajor) && (!(internal::get_compile_time_incr::type>::value==1 || internal::is_valid_index_type::value)), - IndexedView::type,IvcIndex> > -operator()(const Indices& indices) EIGEN_INDEXED_VIEW_METHOD_CONST -{ +template +std::enable_if_t>::value == 1 || + internal::is_valid_index_type::value)), + IndexedView>> +operator()(const Indices& indices) const { EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) - return IndexedView::type,IvcIndex> - (derived(), ivcRow(indices), IvcIndex(0)); + return IndexedView>(derived(), IvcIndex(0), ivcCol(indices)); } -template -std::enable_if_t< - (internal::get_compile_time_incr::type>::value==1) && (!internal::is_valid_index_type::value) && (!symbolic::is_symbolic::value), - VectorBlock::value> > -operator()(const Indices& indices) EIGEN_INDEXED_VIEW_METHOD_CONST -{ +template +std::enable_if_t<(!IsRowMajor) && (!(internal::get_compile_time_incr>::value == 1 || + internal::is_valid_index_type::value)), + IndexedView, IvcIndex>> +operator()(const Indices& indices) { EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) - typename IvcType::type actualIndices = ivcSize(indices); - return VectorBlock::value> - (derived(), internal::first(actualIndices), internal::index_list_size(actualIndices)); + return IndexedView, IvcIndex>(derived(), ivcRow(indices), IvcIndex(0)); } -template -std::enable_if_t::value, CoeffReturnType > -operator()(const IndexType& id) EIGEN_INDEXED_VIEW_METHOD_CONST -{ - return Base::operator()(internal::eval_expr_given_size(id,size())); -} - -template -std::enable_if_t > -operator()(const IndicesT (&indices)[IndicesN]) EIGEN_INDEXED_VIEW_METHOD_CONST -{ +template +std::enable_if_t<(!IsRowMajor) && (!(internal::get_compile_time_incr>::value == 1 || + internal::is_valid_index_type::value)), + IndexedView, IvcIndex>> +operator()(const Indices& indices) const { EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) - return IndexedView - (derived(), IvcIndex(0), indices); + return IndexedView, IvcIndex>(derived(), ivcRow(indices), IvcIndex(0)); } -template -std::enable_if_t > -operator()(const IndicesT (&indices)[IndicesN]) EIGEN_INDEXED_VIEW_METHOD_CONST -{ +template +std::enable_if_t<(internal::get_compile_time_incr>::value == 1) && + (!internal::is_valid_index_type::value) && (!symbolic::is_symbolic::value), + VectorBlock::value>> +operator()(const Indices& indices) { EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) - return IndexedView - (derived(), indices, IvcIndex(0)); + IvcType actualIndices = ivcSize(indices); + return VectorBlock::value>(derived(), internal::first(actualIndices), + internal::index_list_size(actualIndices)); } -#undef EIGEN_INDEXED_VIEW_METHOD_CONST -#undef EIGEN_INDEXED_VIEW_METHOD_TYPE +template +std::enable_if_t<(internal::get_compile_time_incr>::value == 1) && + (!internal::is_valid_index_type::value) && (!symbolic::is_symbolic::value), + VectorBlock::value>> +operator()(const Indices& indices) const { + EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) + IvcType actualIndices = ivcSize(indices); + return VectorBlock::value>(derived(), internal::first(actualIndices), + internal::index_list_size(actualIndices)); +} -#ifndef EIGEN_INDEXED_VIEW_METHOD_2ND_PASS -#define EIGEN_INDEXED_VIEW_METHOD_2ND_PASS -#include "IndexedViewMethods.h" -#undef EIGEN_INDEXED_VIEW_METHOD_2ND_PASS -#endif +template +std::enable_if_t::value, Scalar&> operator()(const IndexType& id) { + return Base::operator()(internal::eval_expr_given_size(id, size())); +} + +template +std::enable_if_t::value, CoeffReturnType> operator()(const IndexType& id) const { + return Base::operator()(internal::eval_expr_given_size(id, size())); +} + +template +std::enable_if_t> operator()( + const IndicesT (&indices)[IndicesN]) { + EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) + return IndexedView(derived(), IvcIndex(0), indices); +} + +template +std::enable_if_t> operator()( + const IndicesT (&indices)[IndicesN]) const { + EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) + return IndexedView(derived(), IvcIndex(0), indices); +} + +template +std::enable_if_t> operator()( + const IndicesT (&indices)[IndicesN]) { + EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) + return IndexedView(derived(), indices, IvcIndex(0)); +} + +template +std::enable_if_t> operator()( + const IndicesT (&indices)[IndicesN]) const { + EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) + return IndexedView(derived(), indices, IvcIndex(0)); +} #else // EIGEN_PARSED_BY_DOXYGEN diff --git a/test/indexed_view.cpp b/test/indexed_view.cpp index d14996066..1eab082b0 100644 --- a/test/indexed_view.cpp +++ b/test/indexed_view.cpp @@ -289,13 +289,8 @@ void check_indexed_view() VERIFY( (A(all, std::array{{1,3,2,4}})).ColsAtCompileTime == 4); VERIFY_IS_APPROX( (A(std::array{{1,3,5}}, std::array{{9,6,3,0}})), A(seqN(1,3,2), seqN(9,4,-3)) ); - - VERIFY_IS_APPROX( A({3, 1, 6, 5}, all), A(std::array{{3, 1, 6, 5}}, all) ); - VERIFY_IS_APPROX( A(all,{3, 1, 6, 5}), A(all,std::array{{3, 1, 6, 5}}) ); - VERIFY_IS_APPROX( A({1,3,5},{3, 1, 6, 5}), A(std::array{{1,3,5}},std::array{{3, 1, 6, 5}}) ); - - VERIFY_IS_EQUAL( A({1,3,5},{3, 1, 6, 5}).RowsAtCompileTime, 3 ); - VERIFY_IS_EQUAL( A({1,3,5},{3, 1, 6, 5}).ColsAtCompileTime, 4 ); + VERIFY_IS_EQUAL(A(std::array{1, 3, 5}, std::array{3, 1, 6, 5}).RowsAtCompileTime, 3); + VERIFY_IS_EQUAL(A(std::array{1, 3, 5}, std::array{3, 1, 6, 5}).ColsAtCompileTime, 4); VERIFY_IS_APPROX( a({3, 1, 6, 5}), a(std::array{{3, 1, 6, 5}}) ); VERIFY_IS_EQUAL( a({1,3,5}).SizeAtCompileTime, 3 ); @@ -364,6 +359,9 @@ void check_indexed_view() A(X,Y) = 1; A(XX,Y) = 1; A(X,YY) = 1; + // check symbolic indices + a(last) = 1; + A(last, last) = 1; // Check compilation of varying integer types as index types: Index i = n/2;