mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-10 02:39:03 +08:00
bug #231: initial implementation of STL iterators for dense expressions
This commit is contained in:
parent
2088c0897f
commit
b0c66adfb1
@ -310,6 +310,7 @@ using std::ptrdiff_t;
|
||||
#include "src/Core/Replicate.h"
|
||||
#include "src/Core/Reverse.h"
|
||||
#include "src/Core/ArrayWrapper.h"
|
||||
#include "src/Core/StlIterators.h"
|
||||
|
||||
#ifdef EIGEN_USE_BLAS
|
||||
#include "src/Core/products/GeneralMatrixMatrix_BLAS.h"
|
||||
|
@ -572,6 +572,17 @@ template<typename Derived> class DenseBase
|
||||
}
|
||||
EIGEN_DEVICE_FUNC void reverseInPlace();
|
||||
|
||||
inline DenseStlIterator<Derived> begin();
|
||||
inline DenseStlIterator<const Derived> begin() const;
|
||||
inline DenseStlIterator<const Derived> cbegin() const;
|
||||
inline DenseStlIterator<Derived> end();
|
||||
inline DenseStlIterator<const Derived> end() const;
|
||||
inline DenseStlIterator<const Derived> cend() const;
|
||||
inline ColsProxy<Derived> allCols();
|
||||
inline ColsProxy<const Derived> allCols() const;
|
||||
inline RowsProxy<Derived> allRows();
|
||||
inline RowsProxy<const Derived> allRows() const;
|
||||
|
||||
#define EIGEN_CURRENT_STORAGE_BASE_CLASS Eigen::DenseBase
|
||||
#define EIGEN_DOC_BLOCK_ADDONS_NOT_INNER_PANEL
|
||||
#define EIGEN_DOC_BLOCK_ADDONS_INNER_PANEL_IF(COND)
|
||||
|
235
Eigen/src/Core/StlIterators.h
Normal file
235
Eigen/src/Core/StlIterators.h
Normal file
@ -0,0 +1,235 @@
|
||||
// This file is part of Eigen, a lightweight C++ template library
|
||||
// for linear algebra.
|
||||
//
|
||||
// Copyright (C) 2018 Gael Guennebaud <gael.guennebaud@inria.fr>
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
template<typename XprType,typename Derived>
|
||||
class DenseStlIteratorBase
|
||||
{
|
||||
public:
|
||||
typedef std::ptrdiff_t difference_type;
|
||||
typedef std::random_access_iterator_tag iterator_category;
|
||||
|
||||
DenseStlIteratorBase() : mp_xpr(0), m_index(0) {}
|
||||
DenseStlIteratorBase(XprType& xpr, Index index) : mp_xpr(&xpr), m_index(index) {}
|
||||
|
||||
void swap(DenseStlIteratorBase& other) {
|
||||
std::swap(mp_xpr,other.mp_xpr);
|
||||
std::swap(m_index,other.m_index);
|
||||
}
|
||||
|
||||
Derived& operator++() { ++m_index; return derived(); }
|
||||
Derived& operator--() { --m_index; return derived(); }
|
||||
|
||||
Derived operator++(int) { Derived prev(derived()); operator++(); return prev;}
|
||||
Derived operator--(int) { Derived prev(derived()); operator--(); return prev;}
|
||||
|
||||
friend Derived operator+(const DenseStlIteratorBase& a, int b) { Derived ret(a.derived()); ret += b; return ret; }
|
||||
friend Derived operator-(const DenseStlIteratorBase& a, int b) { Derived ret(a.derived()); ret -= b; return ret; }
|
||||
friend Derived operator+(int a, const DenseStlIteratorBase& b) { Derived ret(b.derived()); ret += a; return ret; }
|
||||
friend Derived operator-(int a, const DenseStlIteratorBase& b) { Derived ret(b.derived()); ret -= a; return ret; }
|
||||
|
||||
Derived& operator+=(int b) { m_index += b; return derived(); }
|
||||
Derived& operator-=(int b) { m_index -= b; return derived(); }
|
||||
|
||||
difference_type operator-(const DenseStlIteratorBase& other) const { eigen_assert(mp_xpr == other.mp_xpr);return m_index - other.m_index; }
|
||||
|
||||
bool operator==(const DenseStlIteratorBase& other) { eigen_assert(mp_xpr == other.mp_xpr); return m_index == other.m_index; }
|
||||
bool operator!=(const DenseStlIteratorBase& other) { eigen_assert(mp_xpr == other.mp_xpr); return m_index != other.m_index; }
|
||||
bool operator< (const DenseStlIteratorBase& other) { eigen_assert(mp_xpr == other.mp_xpr); return m_index < other.m_index; }
|
||||
bool operator<=(const DenseStlIteratorBase& other) { eigen_assert(mp_xpr == other.mp_xpr); return m_index <= other.m_index; }
|
||||
bool operator> (const DenseStlIteratorBase& other) { eigen_assert(mp_xpr == other.mp_xpr); return m_index > other.m_index; }
|
||||
bool operator>=(const DenseStlIteratorBase& other) { eigen_assert(mp_xpr == other.mp_xpr); return m_index >= other.m_index; }
|
||||
|
||||
protected:
|
||||
|
||||
Derived& derived() { return static_cast<Derived&>(*this); }
|
||||
const Derived& derived() const { return static_cast<const Derived&>(*this); }
|
||||
|
||||
XprType *mp_xpr;
|
||||
Index m_index;
|
||||
};
|
||||
|
||||
template<typename XprType>
|
||||
class DenseStlIterator : public DenseStlIteratorBase<XprType, DenseStlIterator<XprType> >
|
||||
{
|
||||
public:
|
||||
typedef typename XprType::Scalar value_type;
|
||||
|
||||
protected:
|
||||
|
||||
enum {
|
||||
has_direct_access = (internal::traits<XprType>::Flags & DirectAccessBit) ? 1 : 0,
|
||||
has_write_access = internal::is_lvalue<XprType>::value
|
||||
};
|
||||
|
||||
typedef DenseStlIteratorBase<XprType,DenseStlIterator> Base;
|
||||
using Base::m_index;
|
||||
using Base::mp_xpr;
|
||||
|
||||
typedef typename internal::conditional<bool(has_direct_access), const value_type&, const value_type>::type read_only_ref_t;
|
||||
|
||||
public:
|
||||
|
||||
typedef typename internal::conditional<bool(has_write_access), value_type *, const value_type *>::type pointer;
|
||||
typedef typename internal::conditional<bool(has_write_access), value_type&, read_only_ref_t>::type reference;
|
||||
|
||||
|
||||
DenseStlIterator() : Base() {}
|
||||
DenseStlIterator(XprType& xpr, Index index) : Base(xpr,index) {}
|
||||
|
||||
reference operator*() const { return (*mp_xpr)(m_index); }
|
||||
reference operator[](int i) const { return (*mp_xpr)(i); }
|
||||
|
||||
pointer operator->() const { return &((*mp_xpr)(m_index)); }
|
||||
};
|
||||
|
||||
template<typename XprType,typename Derived>
|
||||
void swap(DenseStlIteratorBase<XprType,Derived>& a, DenseStlIteratorBase<XprType,Derived>& b) {
|
||||
a.swap(b);
|
||||
}
|
||||
|
||||
template<typename Derived>
|
||||
inline DenseStlIterator<Derived> DenseBase<Derived>::begin()
|
||||
{
|
||||
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived);
|
||||
return DenseStlIterator<Derived>(derived(), 0);
|
||||
}
|
||||
|
||||
template<typename Derived>
|
||||
inline DenseStlIterator<const Derived> DenseBase<Derived>::begin() const
|
||||
{
|
||||
return cbegin();
|
||||
}
|
||||
|
||||
template<typename Derived>
|
||||
inline DenseStlIterator<const Derived> DenseBase<Derived>::cbegin() const
|
||||
{
|
||||
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived);
|
||||
return DenseStlIterator<const Derived>(derived(), 0);
|
||||
}
|
||||
|
||||
template<typename Derived>
|
||||
inline DenseStlIterator<Derived> DenseBase<Derived>::end()
|
||||
{
|
||||
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived);
|
||||
return DenseStlIterator<Derived>(derived(), size());
|
||||
}
|
||||
|
||||
template<typename Derived>
|
||||
inline DenseStlIterator<const Derived> DenseBase<Derived>::end() const
|
||||
{
|
||||
return cend();
|
||||
}
|
||||
|
||||
template<typename Derived>
|
||||
inline DenseStlIterator<const Derived> DenseBase<Derived>::cend() const
|
||||
{
|
||||
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived);
|
||||
return DenseStlIterator<const Derived>(derived(), size());
|
||||
}
|
||||
|
||||
template<typename XprType>
|
||||
class DenseColStlIterator : public DenseStlIteratorBase<XprType, DenseColStlIterator<XprType> >
|
||||
{
|
||||
protected:
|
||||
|
||||
enum { is_lvalue = internal::is_lvalue<XprType>::value };
|
||||
|
||||
typedef DenseStlIteratorBase<XprType,DenseColStlIterator> Base;
|
||||
using Base::m_index;
|
||||
using Base::mp_xpr;
|
||||
|
||||
public:
|
||||
typedef typename internal::conditional<bool(is_lvalue), typename XprType::ColXpr, typename XprType::ConstColXpr>::type value_type;
|
||||
typedef value_type* pointer;
|
||||
typedef value_type reference;
|
||||
|
||||
DenseColStlIterator() : Base() {}
|
||||
DenseColStlIterator(XprType& xpr, Index index) : Base(xpr,index) {}
|
||||
|
||||
reference operator*() const { return (*mp_xpr).col(m_index); }
|
||||
reference operator[](int i) const { return (*mp_xpr).col(i); }
|
||||
|
||||
pointer operator->() const { return &((*mp_xpr).col(m_index)); }
|
||||
};
|
||||
|
||||
template<typename XprType>
|
||||
class DenseRowStlIterator : public DenseStlIteratorBase<XprType, DenseRowStlIterator<XprType> >
|
||||
{
|
||||
protected:
|
||||
|
||||
enum { is_lvalue = internal::is_lvalue<XprType>::value };
|
||||
|
||||
typedef DenseStlIteratorBase<XprType,DenseRowStlIterator> Base;
|
||||
using Base::m_index;
|
||||
using Base::mp_xpr;
|
||||
|
||||
public:
|
||||
typedef typename internal::conditional<bool(is_lvalue), typename XprType::RowXpr, typename XprType::ConstRowXpr>::type value_type;
|
||||
typedef value_type* pointer;
|
||||
typedef value_type reference;
|
||||
|
||||
DenseRowStlIterator() : Base() {}
|
||||
DenseRowStlIterator(XprType& xpr, Index index) : Base(xpr,index) {}
|
||||
|
||||
reference operator*() const { return (*mp_xpr).row(m_index); }
|
||||
reference operator[](int i) const { return (*mp_xpr).row(i); }
|
||||
|
||||
pointer operator->() const { return &((*mp_xpr).row(m_index)); }
|
||||
};
|
||||
|
||||
|
||||
template<typename Xpr>
|
||||
class ColsProxy
|
||||
{
|
||||
public:
|
||||
ColsProxy(Xpr& xpr) : m_xpr(xpr) {}
|
||||
DenseColStlIterator<Xpr> begin() const { return DenseColStlIterator<Xpr>(m_xpr, 0); }
|
||||
DenseColStlIterator<const Xpr> cbegin() const { return DenseColStlIterator<const Xpr>(m_xpr, 0); }
|
||||
|
||||
DenseColStlIterator<Xpr> end() const { return DenseColStlIterator<Xpr>(m_xpr, m_xpr.cols()); }
|
||||
DenseColStlIterator<const Xpr> cend() const { return DenseColStlIterator<const Xpr>(m_xpr, m_xpr.cols()); }
|
||||
|
||||
protected:
|
||||
Xpr& m_xpr;
|
||||
};
|
||||
|
||||
template<typename Xpr>
|
||||
class RowsProxy
|
||||
{
|
||||
public:
|
||||
RowsProxy(Xpr& xpr) : m_xpr(xpr) {}
|
||||
DenseRowStlIterator<Xpr> begin() const { return DenseRowStlIterator<Xpr>(m_xpr, 0); }
|
||||
DenseRowStlIterator<const Xpr> cbegin() const { return DenseRowStlIterator<const Xpr>(m_xpr, 0); }
|
||||
|
||||
DenseRowStlIterator<Xpr> end() const { return DenseRowStlIterator<Xpr>(m_xpr, m_xpr.rows()); }
|
||||
DenseRowStlIterator<const Xpr> cend() const { return DenseRowStlIterator<const Xpr>(m_xpr, m_xpr.rows()); }
|
||||
|
||||
protected:
|
||||
Xpr& m_xpr;
|
||||
};
|
||||
|
||||
template<typename Derived>
|
||||
ColsProxy<Derived> DenseBase<Derived>::allCols()
|
||||
{ return ColsProxy<Derived>(derived()); }
|
||||
|
||||
template<typename Derived>
|
||||
ColsProxy<const Derived> DenseBase<Derived>::allCols() const
|
||||
{ return ColsProxy<const Derived>(derived()); }
|
||||
|
||||
template<typename Derived>
|
||||
RowsProxy<Derived> DenseBase<Derived>::allRows()
|
||||
{ return RowsProxy<Derived>(derived()); }
|
||||
|
||||
template<typename Derived>
|
||||
RowsProxy<const Derived> DenseBase<Derived>::allRows() const
|
||||
{ return RowsProxy<const Derived>(derived()); }
|
||||
|
||||
} // namespace Eigen
|
@ -133,6 +133,9 @@ template<typename ExpressionType> class ArrayWrapper;
|
||||
template<typename ExpressionType> class MatrixWrapper;
|
||||
template<typename Derived> class SolverBase;
|
||||
template<typename XprType> class InnerIterator;
|
||||
template<typename XprType> class DenseStlIterator;
|
||||
template<typename XprType> class ColsProxy;
|
||||
template<typename XprType> class RowsProxy;
|
||||
|
||||
namespace internal {
|
||||
template<typename DecompositionType> struct kernel_retval_base;
|
||||
|
@ -285,6 +285,7 @@ ei_add_test(inplace_decomposition)
|
||||
ei_add_test(half_float)
|
||||
ei_add_test(array_of_string)
|
||||
ei_add_test(num_dimensions)
|
||||
ei_add_test(stl_iterators)
|
||||
|
||||
add_executable(bug1213 bug1213.cpp bug1213_main.cpp)
|
||||
|
||||
|
128
test/stl_iterators.cpp
Normal file
128
test/stl_iterators.cpp
Normal file
@ -0,0 +1,128 @@
|
||||
// This file is part of Eigen, a lightweight C++ template library
|
||||
// for linear algebra.
|
||||
//
|
||||
// Copyright (C) 2018 Gael Guennebaud <gael.guennebaud@inria.fr>
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
#include "main.h"
|
||||
|
||||
template< class Iterator >
|
||||
std::reverse_iterator<Iterator>
|
||||
make_reverse_iterator( Iterator i )
|
||||
{
|
||||
return std::reverse_iterator<Iterator>(i);
|
||||
}
|
||||
|
||||
template<typename Scalar, int Rows, int Cols>
|
||||
void test_range_for_loop(int rows=Rows, int cols=Cols)
|
||||
{
|
||||
using std::begin;
|
||||
using std::end;
|
||||
|
||||
typedef Matrix<Scalar,Rows,1> VectorType;
|
||||
typedef Matrix<Scalar,Rows,Cols,ColMajor> ColMatrixType;
|
||||
typedef Matrix<Scalar,Rows,Cols,RowMajor> RowMatrixType;
|
||||
VectorType v = VectorType::Random(rows);
|
||||
ColMatrixType A = ColMatrixType::Random(rows,cols);
|
||||
RowMatrixType B = RowMatrixType::Random(rows,cols);
|
||||
|
||||
Index i, j;
|
||||
|
||||
#if EIGEN_HAS_CXX11
|
||||
i = 0;
|
||||
for(auto x : v) { VERIFY_IS_EQUAL(x,v[i++]); }
|
||||
|
||||
j = internal::random<Index>(0,A.cols()-1);
|
||||
i = 0;
|
||||
for(auto x : A.col(j)) { VERIFY_IS_EQUAL(x,A(i++,j)); }
|
||||
|
||||
i = 0;
|
||||
for(auto x : (v+A.col(j))) { VERIFY_IS_APPROX(x,v(i)+A(i,j)); ++i; }
|
||||
|
||||
j = 0;
|
||||
i = internal::random<Index>(0,A.rows()-1);
|
||||
for(auto x : A.row(i)) { VERIFY_IS_EQUAL(x,A(i,j++)); }
|
||||
|
||||
i = 0;
|
||||
for(auto x : A.reshaped()) { VERIFY_IS_EQUAL(x,A(i++)); }
|
||||
|
||||
Matrix<Scalar,Dynamic,Dynamic,ColMajor> Bc = B;
|
||||
i = 0;
|
||||
for(auto x : B.reshaped()) { VERIFY_IS_EQUAL(x,Bc(i++)); }
|
||||
|
||||
VectorType w(v.size());
|
||||
i = 0;
|
||||
for(auto& x : w) { x = v(i++); }
|
||||
VERIFY_IS_EQUAL(v,w);
|
||||
#endif
|
||||
|
||||
if(rows>=2)
|
||||
{
|
||||
v(1) = v(0)-Scalar(1);
|
||||
VERIFY(!std::is_sorted(begin(v),end(v)));
|
||||
}
|
||||
std::sort(begin(v),end(v));
|
||||
VERIFY(std::is_sorted(begin(v),end(v)));
|
||||
VERIFY(!std::is_sorted(make_reverse_iterator(end(v)),make_reverse_iterator(begin(v))));
|
||||
|
||||
{
|
||||
j = internal::random<Index>(0,A.cols()-1);
|
||||
// std::sort(begin(A.col(j)),end(A.col(j))); // does not compile because this returns const iterators
|
||||
typename ColMatrixType::ColXpr Acol = A.col(j);
|
||||
std::sort(begin(Acol),end(Acol));
|
||||
VERIFY(std::is_sorted(Acol.cbegin(),Acol.cend()));
|
||||
|
||||
// This raises an assert because this creates a pair of iterator referencing two different proxy objects:
|
||||
// std::sort(A.col(j).begin(),A.col(j).end());
|
||||
// VERIFY(std::is_sorted(A.col(j).cbegin(),A.col(j).cend())); // same issue
|
||||
}
|
||||
|
||||
{
|
||||
j = internal::random<Index>(0,A.cols()-1);
|
||||
typename ColMatrixType::ColXpr Acol = A.col(j);
|
||||
std::partial_sum(begin(Acol), end(Acol), begin(v));
|
||||
VERIFY_IS_APPROX(v(seq(1,last)), v(seq(0,last-1))+Acol(seq(1,last)));
|
||||
|
||||
// inplace
|
||||
std::partial_sum(begin(Acol), end(Acol), begin(Acol));
|
||||
VERIFY_IS_APPROX(v, Acol);
|
||||
}
|
||||
|
||||
#if EIGEN_HAS_CXX11
|
||||
j = 0;
|
||||
for(auto c : A.allCols()) { VERIFY_IS_APPROX(c.sum(), A.col(j).sum()); ++j; }
|
||||
j = 0;
|
||||
for(auto c : B.allCols()) { VERIFY_IS_APPROX(c.sum(), B.col(j).sum()); ++j; }
|
||||
|
||||
j = 0;
|
||||
for(auto c : B.allCols()) {
|
||||
i = 0;
|
||||
for(auto& x : c) {
|
||||
VERIFY_IS_EQUAL(x, B(i,j));
|
||||
x = A(i,j);
|
||||
++i;
|
||||
}
|
||||
++j;
|
||||
}
|
||||
VERIFY_IS_APPROX(A,B);
|
||||
B = Bc; // restore B
|
||||
|
||||
i = 0;
|
||||
for(auto r : A.allRows()) { VERIFY_IS_APPROX(r.sum(), A.row(i).sum()); ++i; }
|
||||
i = 0;
|
||||
for(auto r : B.allRows()) { VERIFY_IS_APPROX(r.sum(), B.row(i).sum()); ++i; }
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
EIGEN_DECLARE_TEST(stl_iterators)
|
||||
{
|
||||
for(int i = 0; i < g_repeat; i++) {
|
||||
CALL_SUBTEST_1(( test_range_for_loop<double,2,3>() ));
|
||||
CALL_SUBTEST_1(( test_range_for_loop<float,7,5>() ));
|
||||
CALL_SUBTEST_1(( test_range_for_loop<int,Dynamic,Dynamic>(internal::random<int>(10,200), internal::random<int>(10,200)) ));
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user