From 87300c93cae6a8afd9a4f8aa8d9d5c5324cf02e1 Mon Sep 17 00:00:00 2001 From: Charles Schlosser Date: Mon, 17 Apr 2023 12:32:50 +0000 Subject: [PATCH] Refactor IndexedView --- Eigen/src/plugins/IndexedViewMethods.h | 344 +++++++++++++++++-------- test/indexed_view.cpp | 90 ++++++- 2 files changed, 319 insertions(+), 115 deletions(-) diff --git a/Eigen/src/plugins/IndexedViewMethods.h b/Eigen/src/plugins/IndexedViewMethods.h index b796b397f..78f12fe05 100644 --- a/Eigen/src/plugins/IndexedViewMethods.h +++ b/Eigen/src/plugins/IndexedViewMethods.h @@ -7,6 +7,7 @@ // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + #if !defined(EIGEN_PARSED_BY_DOXYGEN) protected: @@ -24,163 +25,278 @@ using IvcType = typename internal::IndexedViewCompatibleType::type IvcIndex; template -IvcRowType ivcRow(const Indices& indices) const { +inline IvcRowType ivcRow(const Indices& indices) const { return internal::makeIndexedViewCompatible( indices, internal::variable_if_dynamic(derived().rows()), Specialized); } template -IvcColType ivcCol(const Indices& indices) const { +inline IvcColType ivcCol(const Indices& indices) const { return internal::makeIndexedViewCompatible( indices, internal::variable_if_dynamic(derived().cols()), Specialized); } template -IvcColType ivcSize(const Indices& indices) const { +inline IvcColType ivcSize(const Indices& indices) const { return internal::makeIndexedViewCompatible( indices, internal::variable_if_dynamic(derived().size()), Specialized); } +// this helper class assumes internal::valid_indexed_view_overload::value == true +template , IvcColType>>::ReturnAsScalar, + bool UseBlock = internal::traits, IvcColType>>::ReturnAsBlock, + bool UseGeneric = internal::traits, IvcColType>>::ReturnAsIndexedView> +struct IndexedViewSelector; + +// Generic +template +struct IndexedViewSelector { + using ReturnType = IndexedView, IvcColType>; + using ConstReturnType = IndexedView, IvcColType>; + + static inline ReturnType run(Derived& derived, const RowIndices& rowIndices, const ColIndices& colIndices) { + return ReturnType(derived, derived.ivcRow(rowIndices), derived.ivcCol(colIndices)); + } + static inline ConstReturnType run(const Derived& derived, const RowIndices& rowIndices, + const ColIndices& colIndices) { + return ConstReturnType(derived, derived.ivcRow(rowIndices), derived.ivcCol(colIndices)); + } +}; + +// Block +template +struct IndexedViewSelector { + using IndexedViewType = IndexedView, IvcColType>; + using ConstIndexedViewType = IndexedView, IvcColType>; + using ReturnType = typename internal::traits::BlockType; + using ConstReturnType = typename internal::traits::BlockType; + + static inline ReturnType run(Derived& derived, const RowIndices& rowIndices, const ColIndices& colIndices) { + IvcRowType actualRowIndices = derived.ivcRow(rowIndices); + IvcColType actualColIndices = derived.ivcCol(colIndices); + return ReturnType(derived, internal::first(actualRowIndices), internal::first(actualColIndices), + internal::index_list_size(actualRowIndices), internal::index_list_size(actualColIndices)); + } + static inline ConstReturnType run(const Derived& derived, const RowIndices& rowIndices, + const ColIndices& colIndices) { + IvcRowType actualRowIndices = derived.ivcRow(rowIndices); + IvcColType actualColIndices = derived.ivcCol(colIndices); + return ConstReturnType(derived, internal::first(actualRowIndices), internal::first(actualColIndices), + internal::index_list_size(actualRowIndices), internal::index_list_size(actualColIndices)); + } +}; + +// Symbolic +template +struct IndexedViewSelector { + using ReturnType = typename DenseBase::Scalar&; + using ConstReturnType = typename DenseBase::CoeffReturnType; + + static inline ReturnType run(Derived& derived, const RowIndices& rowIndices, const ColIndices& colIndices) { + return derived(internal::eval_expr_given_size(rowIndices, derived.rows()), + internal::eval_expr_given_size(colIndices, derived.cols())); + } + static inline ConstReturnType run(const Derived& derived, const RowIndices& rowIndices, + const ColIndices& colIndices) { + return derived(internal::eval_expr_given_size(rowIndices, derived.rows()), + internal::eval_expr_given_size(colIndices, derived.cols())); + } +}; + +// this helper class assumes internal::is_valid_index_type::value == false +template ::value, + bool UseBlock = !UseSymbolic && internal::get_compile_time_incr>::value == 1, + bool UseGeneric = !UseSymbolic && !UseBlock> +struct VectorIndexedViewSelector; + +// Generic +template +struct VectorIndexedViewSelector { + + static constexpr bool IsRowMajor = DenseBase::IsRowMajor; + + using RowMajorReturnType = IndexedView>; + using ConstRowMajorReturnType = IndexedView>; + + using ColMajorReturnType = IndexedView, IvcIndex>; + using ConstColMajorReturnType = IndexedView, IvcIndex>; + + using ReturnType = typename internal::conditional::type; + using ConstReturnType = + typename internal::conditional::type; + + template = true> + static inline RowMajorReturnType run(Derived& derived, const Indices& indices) { + return RowMajorReturnType(derived, IvcIndex(0), derived.ivcCol(indices)); + } + template = true> + static inline ConstRowMajorReturnType run(const Derived& derived, const Indices& indices) { + return ConstRowMajorReturnType(derived, IvcIndex(0), derived.ivcCol(indices)); + } + template = true> + static inline ColMajorReturnType run(Derived& derived, const Indices& indices) { + return ColMajorReturnType(derived, derived.ivcRow(indices), IvcIndex(0)); + } + template = true> + static inline ConstColMajorReturnType run(const Derived& derived, const Indices& indices) { + return ConstColMajorReturnType(derived, derived.ivcRow(indices), IvcIndex(0)); + } +}; + +// Block +template +struct VectorIndexedViewSelector { + + using ReturnType = VectorBlock::value>; + using ConstReturnType = VectorBlock::value>; + + static inline ReturnType run(Derived& derived, const Indices& indices) { + IvcType actualIndices = derived.ivcSize(indices); + return ReturnType(derived, internal::first(actualIndices), internal::index_list_size(actualIndices)); + } + static inline ConstReturnType run(const Derived& derived, const Indices& indices) { + IvcType actualIndices = derived.ivcSize(indices); + return ConstReturnType(derived, internal::first(actualIndices), internal::index_list_size(actualIndices)); + } +}; + +// Symbolic +template +struct VectorIndexedViewSelector { + + using ReturnType = typename DenseBase::Scalar&; + using ConstReturnType = typename DenseBase::CoeffReturnType; + + static inline ReturnType run(Derived& derived, const Indices& id) { + return derived(internal::eval_expr_given_size(id, derived.size())); + } + static inline ConstReturnType run(const Derived& derived, const Indices& id) { + return derived(internal::eval_expr_given_size(id, derived.size())); + } +}; + +// SFINAE dummy types + +template +using EnableOverload = std::enable_if_t< + internal::valid_indexed_view_overload::value && internal::is_lvalue::value, bool>; + +template +using EnableConstOverload = + std::enable_if_t::value, bool>; + +template +using EnableVectorOverload = + std::enable_if_t::value && internal::is_lvalue::value, bool>; + +template +using EnableConstVectorOverload = std::enable_if_t::value, bool>; + public: -template -using IndexedViewType = IndexedView, IvcColType>; +// Public API for 2D matrices/arrays + +// non-const versions template -using ConstIndexedViewType = IndexedView, IvcColType>; +using IndexedViewType = typename IndexedViewSelector::ReturnType; -// This is the generic version - -template -std::enable_if_t::value && - internal::traits>::ReturnAsIndexedView, - IndexedViewType> -operator()(const RowIndices& rowIndices, const ColIndices& colIndices) { - return IndexedViewType(derived(), ivcRow(rowIndices), ivcCol(colIndices)); +template = true> +IndexedViewType operator()(const RowIndices& rowIndices, const ColIndices& colIndices) { + return IndexedViewSelector::run(derived(), rowIndices, colIndices); } -template -std::enable_if_t::value && - internal::traits>::ReturnAsIndexedView, - ConstIndexedViewType> -operator()(const RowIndices& rowIndices, const ColIndices& colIndices) const { - return ConstIndexedViewType(derived(), ivcRow(rowIndices), ivcCol(colIndices)); +template , + EnableOverload = true> +IndexedViewType operator()(const RowType (&rowIndices)[RowSize], const ColIndices& colIndices) { + return IndexedViewSelector::run(derived(), RowIndices{rowIndices}, colIndices); } -// The following overload returns a Block<> object - -template -std::enable_if_t::value && - internal::traits>::ReturnAsBlock, - typename internal::traits>::BlockType> -operator()(const RowIndices& rowIndices, const ColIndices& colIndices) { - typedef typename internal::traits>::BlockType BlockType; - IvcRowType actualRowIndices = ivcRow(rowIndices); - IvcColType actualColIndices = ivcCol(colIndices); - return BlockType(derived(), internal::first(actualRowIndices), internal::first(actualColIndices), - internal::index_list_size(actualRowIndices), internal::index_list_size(actualColIndices)); +template , + EnableOverload = true> +IndexedViewType operator()(const RowIndices& rowIndices, const ColType (&colIndices)[ColSize]) { + return IndexedViewSelector::run(derived(), rowIndices, ColIndices{colIndices}); } -template -std::enable_if_t::value && - internal::traits>::ReturnAsBlock, - typename internal::traits>::BlockType> -operator()(const RowIndices& rowIndices, const ColIndices& colIndices) const { - typedef typename internal::traits>::BlockType BlockType; - IvcRowType actualRowIndices = ivcRow(rowIndices); - IvcColType actualColIndices = ivcCol(colIndices); - return BlockType(derived(), internal::first(actualRowIndices), internal::first(actualColIndices), - internal::index_list_size(actualRowIndices), internal::index_list_size(actualColIndices)); +template , typename ColIndices = Array, + EnableOverload = true> +IndexedViewType operator()(const RowType (&rowIndices)[RowSize], + const ColType (&colIndices)[ColSize]) { + return IndexedViewSelector::run(derived(), RowIndices{rowIndices}, ColIndices{colIndices}); } -// The following overload returns a Scalar +// const versions template -std::enable_if_t::value && - internal::traits>::ReturnAsScalar && internal::is_lvalue::value, - Scalar&> -operator()(const RowIndices& rowIndices, const ColIndices& colIndices) { - return Base::operator()(internal::eval_expr_given_size(rowIndices, rows()), - internal::eval_expr_given_size(colIndices, cols())); +using ConstIndexedViewType = typename IndexedViewSelector::ConstReturnType; + +template = true> +ConstIndexedViewType operator()(const RowIndices& rowIndices, + const ColIndices& colIndices) const { + return IndexedViewSelector::run(derived(), rowIndices, colIndices); } -template -std::enable_if_t::value && - internal::traits>::ReturnAsScalar, - CoeffReturnType> -operator()(const RowIndices& rowIndices, const ColIndices& colIndices) const { - return Base::operator()(internal::eval_expr_given_size(rowIndices, rows()), - internal::eval_expr_given_size(colIndices, cols())); +template , + EnableConstOverload = true> +ConstIndexedViewType operator()(const RowType (&rowIndices)[RowSize], + const ColIndices& colIndices) const { + return IndexedViewSelector::run(derived(), RowIndices{rowIndices}, colIndices); } -// Overloads for 1D vectors/arrays +template , + EnableConstOverload = true> +ConstIndexedViewType operator()(const RowIndices& rowIndices, + const ColType (&colIndices)[ColSize]) const { + return IndexedViewSelector::run(derived(), rowIndices, ColIndices{colIndices}); +} + +template , typename ColIndices = Array, + EnableConstOverload = true> +ConstIndexedViewType operator()(const RowType (&rowIndices)[RowSize], + const ColType (&colIndices)[ColSize]) const { + return IndexedViewSelector::run(derived(), RowIndices{rowIndices}, ColIndices{colIndices}); +} + +// Public API for 1D vectors/arrays + +// non-const versions template -std::enable_if_t>::value == 1 || - internal::is_valid_index_type::value)), - IndexedView>> -operator()(const Indices& indices) { +using VectorIndexedViewType = typename VectorIndexedViewSelector::ReturnType; + +template = true> +VectorIndexedViewType operator()(const Indices& indices) { EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) - return IndexedView>(derived(), IvcIndex(0), ivcCol(indices)); + return VectorIndexedViewSelector::run(derived(), indices); } +template , + EnableVectorOverload = true> +VectorIndexedViewType operator()(const IndexType (&indices)[Size]) { + EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) + return VectorIndexedViewSelector::run(derived(), Indices{indices}); +} + +// const versions + template -std::enable_if_t>::value == 1 || - internal::is_valid_index_type::value)), - IndexedView>> -operator()(const Indices& indices) const { +using ConstVectorIndexedViewType = typename VectorIndexedViewSelector::ConstReturnType; + +template = true> +ConstVectorIndexedViewType operator()(const Indices& indices) const { EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) - return IndexedView>(derived(), IvcIndex(0), ivcCol(indices)); + return VectorIndexedViewSelector::run(derived(), indices); } -template -std::enable_if_t<(!IsRowMajor) && (!(internal::get_compile_time_incr>::value == 1 || - internal::is_valid_index_type::value)), - IndexedView, IvcIndex>> -operator()(const Indices& indices) { +template , + EnableConstVectorOverload = true> +ConstVectorIndexedViewType operator()(const IndexType (&indices)[Size]) const { EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) - return IndexedView, IvcIndex>(derived(), ivcRow(indices), IvcIndex(0)); -} - -template -std::enable_if_t<(!IsRowMajor) && (!(internal::get_compile_time_incr>::value == 1 || - internal::is_valid_index_type::value)), - IndexedView, IvcIndex>> -operator()(const Indices& indices) const { - EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) - return IndexedView, IvcIndex>(derived(), ivcRow(indices), IvcIndex(0)); -} - -template -std::enable_if_t<(internal::get_compile_time_incr>::value == 1) && - (!internal::is_valid_index_type::value) && (!symbolic::is_symbolic::value), - VectorBlock::value>> -operator()(const Indices& indices) { - EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) - IvcType actualIndices = ivcSize(indices); - return VectorBlock::value>(derived(), internal::first(actualIndices), - internal::index_list_size(actualIndices)); -} - -template -std::enable_if_t<(internal::get_compile_time_incr>::value == 1) && - (!internal::is_valid_index_type::value) && (!symbolic::is_symbolic::value), - VectorBlock::value>> -operator()(const Indices& indices) const { - EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) - IvcType actualIndices = ivcSize(indices); - return VectorBlock::value>(derived(), internal::first(actualIndices), - internal::index_list_size(actualIndices)); -} - -template -std::enable_if_t::value && internal::is_lvalue::value, Scalar&> operator()(const IndexType& id) { - return Base::operator()(internal::eval_expr_given_size(id, size())); -} - -template -std::enable_if_t::value, CoeffReturnType> operator()(const IndexType& id) const { - return Base::operator()(internal::eval_expr_given_size(id, size())); + return VectorIndexedViewSelector::run(derived(), Indices{indices}); } #else // EIGEN_PARSED_BY_DOXYGEN diff --git a/test/indexed_view.cpp b/test/indexed_view.cpp index 84a47679d..41ba52153 100644 --- a/test/indexed_view.cpp +++ b/test/indexed_view.cpp @@ -295,6 +295,69 @@ void check_indexed_view() VERIFY_IS_EQUAL( a(std::array{1,3,5}).SizeAtCompileTime, 3 ); VERIFY_IS_EQUAL( b(std::array{1,3,5}).SizeAtCompileTime, 3 ); + // check different index types (C-style array, STL container, Eigen type) + { + Index size = 10; + ArrayXd r = ArrayXd::Random(size); + ArrayXi idx = ArrayXi::EqualSpaced(size, 0, 1); + std::shuffle(idx.begin(), idx.end(), std::random_device()); + + int c_array[3] = { idx[0], idx[1], idx[2] }; + std::vector std_vector{ idx[0], idx[1], idx[2] }; + Matrix eigen_matrix{ idx[0], idx[1], idx[2] }; + + // non-const access + VERIFY_IS_CWISE_EQUAL(r({ idx[0], idx[1], idx[2] }), r(c_array)); + VERIFY_IS_CWISE_EQUAL(r({ idx[0], idx[1], idx[2] }), r(std_vector)); + VERIFY_IS_CWISE_EQUAL(r({ idx[0], idx[1], idx[2] }), r(eigen_matrix)); + VERIFY_IS_CWISE_EQUAL(r(std_vector), r(c_array)); + VERIFY_IS_CWISE_EQUAL(r(std_vector), r(eigen_matrix)); + VERIFY_IS_CWISE_EQUAL(r(eigen_matrix), r(c_array)); + + const ArrayXd& r_ref = r; + // const access + VERIFY_IS_CWISE_EQUAL(r_ref({ idx[0], idx[1], idx[2] }), r_ref(c_array)); + VERIFY_IS_CWISE_EQUAL(r_ref({ idx[0], idx[1], idx[2] }), r_ref(std_vector)); + VERIFY_IS_CWISE_EQUAL(r_ref({ idx[0], idx[1], idx[2] }), r_ref(eigen_matrix)); + VERIFY_IS_CWISE_EQUAL(r_ref(std_vector), r_ref(c_array)); + VERIFY_IS_CWISE_EQUAL(r_ref(std_vector), r_ref(eigen_matrix)); + VERIFY_IS_CWISE_EQUAL(r_ref(eigen_matrix), r_ref(c_array)); + } + + { + Index rows = 8; + Index cols = 11; + ArrayXXd R = ArrayXXd::Random(rows, cols); + ArrayXi r_idx = ArrayXi::EqualSpaced(rows, 0, 1); + ArrayXi c_idx = ArrayXi::EqualSpaced(cols, 0, 1); + std::shuffle(r_idx.begin(), r_idx.end(), std::random_device()); + std::shuffle(c_idx.begin(), c_idx.end(), std::random_device()); + + int c_array_rows[3] = { r_idx[0], r_idx[1], r_idx[2] }; + int c_array_cols[4] = { c_idx[0], c_idx[1], c_idx[2], c_idx[3] }; + std::vector std_vector_rows{ r_idx[0], r_idx[1], r_idx[2] }; + std::vector std_vector_cols{ c_idx[0], c_idx[1], c_idx[2], c_idx[3] }; + Matrix eigen_matrix_rows{ r_idx[0], r_idx[1], r_idx[2] }; + Matrix eigen_matrix_cols{ c_idx[0], c_idx[1], c_idx[2], c_idx[3] }; + + // non-const access + VERIFY_IS_CWISE_EQUAL(R({ r_idx[0], r_idx[1], r_idx[2] }, { c_idx[0], c_idx[1], c_idx[2], c_idx[3] }), R(c_array_rows, c_array_cols)); + VERIFY_IS_CWISE_EQUAL(R({ r_idx[0], r_idx[1], r_idx[2] }, { c_idx[0], c_idx[1], c_idx[2], c_idx[3] }), R(std_vector_rows, std_vector_cols)); + VERIFY_IS_CWISE_EQUAL(R({ r_idx[0], r_idx[1], r_idx[2] }, { c_idx[0], c_idx[1], c_idx[2], c_idx[3] }), R(eigen_matrix_rows, eigen_matrix_cols)); + VERIFY_IS_CWISE_EQUAL(R(std_vector_rows, std_vector_cols), R(c_array_rows, c_array_cols)); + VERIFY_IS_CWISE_EQUAL(R(std_vector_rows, std_vector_cols), R(eigen_matrix_rows, eigen_matrix_cols)); + VERIFY_IS_CWISE_EQUAL(R(eigen_matrix_rows, eigen_matrix_cols), R(c_array_rows, c_array_cols)); + + const ArrayXXd& R_ref = R; + // const access + VERIFY_IS_CWISE_EQUAL(R_ref({ r_idx[0], r_idx[1], r_idx[2] }, { c_idx[0], c_idx[1], c_idx[2], c_idx[3] }), R_ref(c_array_rows, c_array_cols)); + VERIFY_IS_CWISE_EQUAL(R_ref({ r_idx[0], r_idx[1], r_idx[2] }, { c_idx[0], c_idx[1], c_idx[2], c_idx[3] }), R_ref(std_vector_rows, std_vector_cols)); + VERIFY_IS_CWISE_EQUAL(R_ref({ r_idx[0], r_idx[1], r_idx[2] }, { c_idx[0], c_idx[1], c_idx[2], c_idx[3] }), R_ref(eigen_matrix_rows, eigen_matrix_cols)); + VERIFY_IS_CWISE_EQUAL(R_ref(std_vector_rows, std_vector_cols), R_ref(c_array_rows, c_array_cols)); + VERIFY_IS_CWISE_EQUAL(R_ref(std_vector_rows, std_vector_cols), R_ref(eigen_matrix_rows, eigen_matrix_cols)); + VERIFY_IS_CWISE_EQUAL(R_ref(eigen_matrix_rows, eigen_matrix_cols), R_ref(c_array_rows, c_array_cols)); + } + // check mat(i,j) with weird types for i and j { VERIFY_IS_APPROX( A(B.RowsAtCompileTime-1, 1), A(3,1) ); @@ -357,8 +420,33 @@ void check_indexed_view() A(XX,Y) = 1; A(X,YY) = 1; // check symbolic indices - a(last) = 1; + a(last) = 1.0; A(last, last) = 1; + // check weird non-const, non-lvalue scenarios + { + // in these scenarios, the objects are not declared 'const', and the compiler will atttempt to use the non-const + // overloads without intervention + + // non-const map to a const object + Map a_map(a.data(), a.size()); + Map A_map(A.data(), A.rows(), A.cols()); + + VERIFY_IS_EQUAL(a_map(last), a.coeff(a.size() - 1)); + VERIFY_IS_EQUAL(A_map(last, last), A.coeff(A.rows() - 1, A.cols() - 1)); + + // non-const expressions that have no modifiable data + using Op = internal::scalar_constant_op; + using VectorXpr = CwiseNullaryOp; + using MatrixXpr = CwiseNullaryOp; + double constant_val = internal::random(); + Op op(constant_val); + VectorXpr vectorXpr(10, 1, op); + MatrixXpr matrixXpr(8, 11, op); + + VERIFY_IS_EQUAL(vectorXpr.coeff(vectorXpr.size() - 1), vectorXpr(last)); + VERIFY_IS_EQUAL(matrixXpr.coeff(matrixXpr.rows() - 1, matrixXpr.cols() - 1), matrixXpr(last, last)); + } + // Check compilation of varying integer types as index types: Index i = n/2;