diff --git a/Eigen/src/plugins/IndexedViewMethods.h b/Eigen/src/plugins/IndexedViewMethods.h index ea1aa6e2e..7d63f8d62 100644 --- a/Eigen/src/plugins/IndexedViewMethods.h +++ b/Eigen/src/plugins/IndexedViewMethods.h @@ -13,7 +13,7 @@ #ifndef EIGEN_INDEXED_VIEW_METHOD_2ND_PASS #define EIGEN_INDEXED_VIEW_METHOD_CONST const -#define EIGEN_INDEXED_VIEW_METHOD_TYPE ConstIndexedViewType +#define EIGEN_INDEXED_VIEW_METHOD_TYPE ConstIndexedViewType #else #define EIGEN_INDEXED_VIEW_METHOD_CONST #define EIGEN_INDEXED_VIEW_METHOD_TYPE IndexedViewType @@ -84,6 +84,62 @@ operator()(const RowIndicesT (&rowIndices)[RowIndicesN], const ColIndicesT (&col (derived(), rowIndices, colIndices); } +// Overloads for 1D vectors/arrays + +template +typename internal::enable_if< + IsRowMajor && (!(internal::get_compile_time_incr::type>::value==1 || internal::is_integral::value)), + IndexedView::type,typename internal::MakeIndexing::type> >::type +operator()(const Indices& indices) EIGEN_INDEXED_VIEW_METHOD_CONST +{ + EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) + return IndexedView::type,typename internal::MakeIndexing::type> + (derived(), internal::make_indexing(0,derived().rows()), internal::make_indexing(indices,derived().cols())); +} + +template +typename internal::enable_if< + (!IsRowMajor) && (!(internal::get_compile_time_incr::type>::value==1 || internal::is_integral::value)), + IndexedView::type,typename internal::MakeIndexing::type> >::type +operator()(const Indices& indices) EIGEN_INDEXED_VIEW_METHOD_CONST +{ + EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) + return IndexedView::type,typename internal::MakeIndexing::type> + (derived(), internal::make_indexing(indices,derived().rows()), internal::make_indexing(Index(0),derived().cols())); +} + +template +typename internal::enable_if< + (internal::get_compile_time_incr::type>::value==1) && (!internal::is_integral::value), + VectorBlock::value> >::type +operator()(const Indices& indices) EIGEN_INDEXED_VIEW_METHOD_CONST +{ + EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) + typename internal::MakeIndexing::type actualIndices = internal::make_indexing(indices,derived().size()); + return VectorBlock::value> + (derived(), internal::first(actualIndices), internal::size(actualIndices)); +} + +template +typename internal::enable_if::type,const IndicesT (&)[IndicesN]> >::type +operator()(const IndicesT (&indices)[IndicesN]) EIGEN_INDEXED_VIEW_METHOD_CONST +{ + EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) + return IndexedView::type,const IndicesT (&)[IndicesN]> + (derived(), internal::make_indexing(0,derived().rows()), indices); +} + +template +typename internal::enable_if::type> >::type +operator()(const IndicesT (&indices)[IndicesN]) EIGEN_INDEXED_VIEW_METHOD_CONST +{ + EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) + return IndexedView::type> + (derived(), indices, internal::make_indexing(0,derived().rows())); +} + #undef EIGEN_INDEXED_VIEW_METHOD_CONST #undef EIGEN_INDEXED_VIEW_METHOD_TYPE @@ -123,11 +179,21 @@ operator()(const RowIndicesT (&rowIndices)[RowIndicesN], const ColIndicesT (&col * Otherwise a more general IndexedView object will be returned, after conversion of the inputs * to more suitable types \c RowIndices' and \c ColIndices'. * - * \sa class Block, class IndexedView, DenseBase::block(Index,Index,Index,Index) + * For 1D vectors and arrays, you better use the operator()(const Indices&) overload, which behave the same way but taking a single parameter. + * + * \sa operator()(const Indices&), class Block, class IndexedView, DenseBase::block(Index,Index,Index,Index) */ template IndexedView_or_Block operator()(const RowIndices& rowIndices, const ColIndices& colIndices); +/** This is an overload of operator()(const RowIndices&, const ColIndices&) for 1D vectors or arrays + * + * \only_for_vectors + */ +template +IndexedView_or_VectorBlock +operator()(const Indices& indices); + #endif // EIGEN_PARSED_BY_DOXYGEN diff --git a/test/indexed_view.cpp b/test/indexed_view.cpp index 42d136847..c15a8306a 100644 --- a/test/indexed_view.cpp +++ b/test/indexed_view.cpp @@ -33,9 +33,10 @@ IndexPair decode(Index ij) { template bool match(const T& xpr, std::string ref, std::string str_xpr = "") { EIGEN_UNUSED_VARIABLE(str_xpr); - //std::cout << str_xpr << "\n" << xpr << "\n\n"; std::stringstream str; str << xpr; + if(!(str.str() == ref)) + std::cout << str_xpr << "\n" << xpr << "\n\n"; return str.str() == ref; } @@ -55,15 +56,16 @@ void check_indexed_view() Index n = 10; + ArrayXd a = ArrayXd::LinSpaced(n,0,n-1); + Array b = a.transpose(); + ArrayXXi A = ArrayXXi::NullaryExpr(n,n, std::ptr_fun(encode)); for(Index i=0; i vala(10); Map(&vala[0],10) = eia; std::valarray vali(4); Map(&vali[0],4) = eii; std::vector veci(4); Map(veci.data(),4) = eii; @@ -118,6 +120,19 @@ void check_indexed_view() "300 301 302 303 304 305 306 307 308 309") ); + VERIFY( MATCH( a(seqN(3,3),0), "3\n4\n5" ) ); + VERIFY( MATCH( a(seq(3,5)), "3\n4\n5" ) ); + VERIFY( MATCH( a(seqN(3,3,1)), "3\n4\n5" ) ); + VERIFY( MATCH( a(seqN(5,3,-1)), "5\n4\n3" ) ); + + VERIFY( MATCH( b(0,seqN(3,3)), "3 4 5" ) ); + VERIFY( MATCH( b(seq(3,5)), "3 4 5" ) ); + VERIFY( MATCH( b(seqN(3,3,1)), "3 4 5" ) ); + VERIFY( MATCH( b(seqN(5,3,-1)), "5 4 3" ) ); + + VERIFY( MATCH( b(all), "0 1 2 3 4 5 6 7 8 9" ) ); + VERIFY( MATCH( b(eii), "3 1 6 5" ) ); + Array44i B; B.setRandom(); VERIFY( (A(seqN(2,5), 5)).ColsAtCompileTime == 1); @@ -180,6 +195,11 @@ void check_indexed_view() VERIFY( is_same_type(cA.block(0,0,2,2), cA(seqN(0,2),seq(0,1))) ); VERIFY( is_same_type(cA.middleRows(2,4), cA(seqN(2,4),all)) ); VERIFY( is_same_type(cA.middleCols(2,4), cA(all,seqN(2,4))) ); + + VERIFY( is_same_type(a.head(4), a(seq(0,3))) ); + VERIFY( is_same_type(a.tail(4), a(seqN(last-3,4))) ); + VERIFY( is_same_type(a.tail(4), a(seq(end-4,last))) ); + VERIFY( is_same_type(a.segment<4>(3), a(seqN(3,fix<4>))) ); } ArrayXXi A1=A, A2 = ArrayXXi::Random(4,4); @@ -203,6 +223,12 @@ void check_indexed_view() VERIFY_IS_EQUAL( A({1,3,5},{3, 1, 6, 5}).RowsAtCompileTime, 3 ); VERIFY_IS_EQUAL( A({1,3,5},{3, 1, 6, 5}).ColsAtCompileTime, 4 ); + + VERIFY_IS_APPROX( a({3, 1, 6, 5}), a(std::array{{3, 1, 6, 5}}) ); + VERIFY_IS_EQUAL( a({1,3,5}).SizeAtCompileTime, 3 ); + + VERIFY_IS_APPROX( b({3, 1, 6, 5}), b(std::array{{3, 1, 6, 5}}) ); + VERIFY_IS_EQUAL( b({1,3,5}).SizeAtCompileTime, 3 ); #endif #endif