mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-11 11:19:02 +08:00
Added support for reverse iterators for Vectorwise operations.
This commit is contained in:
parent
fa8fd4b4d5
commit
d640276d31
@ -93,6 +93,85 @@ protected:
|
||||
Index m_index;
|
||||
};
|
||||
|
||||
template<typename Derived>
|
||||
class indexed_based_stl_reverse_iterator_base
|
||||
{
|
||||
protected:
|
||||
typedef indexed_based_stl_iterator_traits<Derived> traits;
|
||||
typedef typename traits::XprType XprType;
|
||||
typedef indexed_based_stl_reverse_iterator_base<typename traits::non_const_iterator> non_const_iterator;
|
||||
typedef indexed_based_stl_reverse_iterator_base<typename traits::const_iterator> const_iterator;
|
||||
typedef typename internal::conditional<internal::is_const<XprType>::value,non_const_iterator,const_iterator>::type other_iterator;
|
||||
// NOTE: in C++03 we cannot declare friend classes through typedefs because we need to write friend class:
|
||||
friend class indexed_based_stl_reverse_iterator_base<typename traits::const_iterator>;
|
||||
friend class indexed_based_stl_reverse_iterator_base<typename traits::non_const_iterator>;
|
||||
public:
|
||||
typedef Index difference_type;
|
||||
typedef std::random_access_iterator_tag iterator_category;
|
||||
|
||||
indexed_based_stl_reverse_iterator_base() : mp_xpr(0), m_index(0) {}
|
||||
indexed_based_stl_reverse_iterator_base(XprType& xpr, Index index) : mp_xpr(&xpr), m_index(index) {}
|
||||
|
||||
indexed_based_stl_reverse_iterator_base(const non_const_iterator& other)
|
||||
: mp_xpr(other.mp_xpr), m_index(other.m_index)
|
||||
{}
|
||||
|
||||
indexed_based_stl_reverse_iterator_base& operator=(const non_const_iterator& other)
|
||||
{
|
||||
mp_xpr = other.mp_xpr;
|
||||
m_index = other.m_index;
|
||||
return *this;
|
||||
}
|
||||
|
||||
Derived& operator++() { --m_index; return derived(); }
|
||||
Derived& operator--() { ++m_index; return derived(); }
|
||||
|
||||
Derived operator++(int) { Derived prev(derived()); operator++(); return prev;}
|
||||
Derived operator--(int) { Derived prev(derived()); operator--(); return prev;}
|
||||
|
||||
friend Derived operator+(const indexed_based_stl_reverse_iterator_base& a, Index b) { Derived ret(a.derived()); ret += b; return ret; }
|
||||
friend Derived operator-(const indexed_based_stl_reverse_iterator_base& a, Index b) { Derived ret(a.derived()); ret -= b; return ret; }
|
||||
friend Derived operator+(Index a, const indexed_based_stl_reverse_iterator_base& b) { Derived ret(b.derived()); ret += a; return ret; }
|
||||
friend Derived operator-(Index a, const indexed_based_stl_reverse_iterator_base& b) { Derived ret(b.derived()); ret -= a; return ret; }
|
||||
|
||||
Derived& operator+=(Index b) { m_index -= b; return derived(); }
|
||||
Derived& operator-=(Index b) { m_index += b; return derived(); }
|
||||
|
||||
difference_type operator-(const indexed_based_stl_reverse_iterator_base& other) const
|
||||
{
|
||||
eigen_assert(mp_xpr == other.mp_xpr);
|
||||
return other.m_index - m_index;
|
||||
}
|
||||
|
||||
difference_type operator-(const other_iterator& other) const
|
||||
{
|
||||
eigen_assert(mp_xpr == other.mp_xpr);
|
||||
return other.m_index - m_index;
|
||||
}
|
||||
|
||||
bool operator==(const indexed_based_stl_reverse_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index == other.m_index; }
|
||||
bool operator!=(const indexed_based_stl_reverse_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index != other.m_index; }
|
||||
bool operator< (const indexed_based_stl_reverse_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index > other.m_index; }
|
||||
bool operator<=(const indexed_based_stl_reverse_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index >= other.m_index; }
|
||||
bool operator> (const indexed_based_stl_reverse_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index < other.m_index; }
|
||||
bool operator>=(const indexed_based_stl_reverse_iterator_base& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index <= other.m_index; }
|
||||
|
||||
bool operator==(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index == other.m_index; }
|
||||
bool operator!=(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index != other.m_index; }
|
||||
bool operator< (const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index > other.m_index; }
|
||||
bool operator<=(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index >= other.m_index; }
|
||||
bool operator> (const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index < other.m_index; }
|
||||
bool operator>=(const other_iterator& other) const { eigen_assert(mp_xpr == other.mp_xpr); return m_index <= other.m_index; }
|
||||
|
||||
protected:
|
||||
|
||||
Derived& derived() { return static_cast<Derived&>(*this); }
|
||||
const Derived& derived() const { return static_cast<const Derived&>(*this); }
|
||||
|
||||
XprType *mp_xpr;
|
||||
Index m_index;
|
||||
};
|
||||
|
||||
template<typename XprType>
|
||||
class pointer_based_stl_iterator
|
||||
{
|
||||
@ -267,6 +346,54 @@ public:
|
||||
pointer operator->() const { return (*mp_xpr).template subVector<Direction>(m_index); }
|
||||
};
|
||||
|
||||
template<typename _XprType, DirectionType Direction>
|
||||
struct indexed_based_stl_iterator_traits<subvector_stl_reverse_iterator<_XprType,Direction> >
|
||||
{
|
||||
typedef _XprType XprType;
|
||||
typedef subvector_stl_reverse_iterator<typename internal::remove_const<XprType>::type, Direction> non_const_iterator;
|
||||
typedef subvector_stl_reverse_iterator<typename internal::add_const<XprType>::type, Direction> const_iterator;
|
||||
};
|
||||
|
||||
template<typename XprType, DirectionType Direction>
|
||||
class subvector_stl_reverse_iterator : public indexed_based_stl_reverse_iterator_base<subvector_stl_reverse_iterator<XprType,Direction> >
|
||||
{
|
||||
protected:
|
||||
|
||||
enum { is_lvalue = internal::is_lvalue<XprType>::value };
|
||||
|
||||
typedef indexed_based_stl_reverse_iterator_base<subvector_stl_reverse_iterator> Base;
|
||||
using Base::m_index;
|
||||
using Base::mp_xpr;
|
||||
|
||||
typedef typename internal::conditional<Direction==Vertical,typename XprType::ColXpr,typename XprType::RowXpr>::type SubVectorType;
|
||||
typedef typename internal::conditional<Direction==Vertical,typename XprType::ConstColXpr,typename XprType::ConstRowXpr>::type ConstSubVectorType;
|
||||
|
||||
|
||||
public:
|
||||
typedef typename internal::conditional<bool(is_lvalue), SubVectorType, ConstSubVectorType>::type reference;
|
||||
typedef typename reference::PlainObject value_type;
|
||||
|
||||
private:
|
||||
class subvector_stl_reverse_iterator_ptr
|
||||
{
|
||||
public:
|
||||
subvector_stl_reverse_iterator_ptr(const reference &subvector) : m_subvector(subvector) {}
|
||||
reference* operator->() { return &m_subvector; }
|
||||
private:
|
||||
reference m_subvector;
|
||||
};
|
||||
public:
|
||||
|
||||
typedef subvector_stl_reverse_iterator_ptr pointer;
|
||||
|
||||
subvector_stl_reverse_iterator() : Base() {}
|
||||
subvector_stl_reverse_iterator(XprType& xpr, Index index) : Base(xpr,index) {}
|
||||
|
||||
reference operator*() const { return (*mp_xpr).template subVector<Direction>(m_index); }
|
||||
reference operator[](Index i) const { return (*mp_xpr).template subVector<Direction>(m_index+i); }
|
||||
pointer operator->() const { return (*mp_xpr).template subVector<Direction>(m_index); }
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
|
||||
@ -328,4 +455,4 @@ inline typename DenseBase<Derived>::const_iterator DenseBase<Derived>::cend() co
|
||||
return const_iterator(derived(), size());
|
||||
}
|
||||
|
||||
} // namespace Eigen
|
||||
} // namespace Eigen
|
@ -279,27 +279,47 @@ template<typename ExpressionType, int Direction> class VectorwiseOp
|
||||
/** This is the const version of iterator (aka read-only) */
|
||||
random_access_iterator_type const_iterator;
|
||||
#else
|
||||
typedef internal::subvector_stl_iterator<ExpressionType, DirectionType(Direction)> iterator;
|
||||
typedef internal::subvector_stl_iterator<const ExpressionType, DirectionType(Direction)> const_iterator;
|
||||
typedef internal::subvector_stl_iterator<ExpressionType, DirectionType(Direction)> iterator;
|
||||
typedef internal::subvector_stl_iterator<const ExpressionType, DirectionType(Direction)> const_iterator;
|
||||
typedef internal::subvector_stl_reverse_iterator<ExpressionType, DirectionType(Direction)> reverse_iterator;
|
||||
typedef internal::subvector_stl_reverse_iterator<const ExpressionType, DirectionType(Direction)> const_reverse_iterator;
|
||||
#endif
|
||||
|
||||
/** returns an iterator to the first row (rowwise) or column (colwise) of the nested expression.
|
||||
* \sa end(), cbegin()
|
||||
*/
|
||||
iterator begin() { return iterator (m_matrix, 0); }
|
||||
iterator begin() { return iterator (m_matrix, 0); }
|
||||
/** const version of begin() */
|
||||
const_iterator begin() const { return const_iterator(m_matrix, 0); }
|
||||
const_iterator begin() const { return const_iterator(m_matrix, 0); }
|
||||
/** const version of begin() */
|
||||
const_iterator cbegin() const { return const_iterator(m_matrix, 0); }
|
||||
const_iterator cbegin() const { return const_iterator(m_matrix, 0); }
|
||||
|
||||
/** returns a reverse iterator to the last row (rowwise) or column (colwise) of the nested expression.
|
||||
* \sa rend(), crbegin()
|
||||
*/
|
||||
reverse_iterator rbegin() { return reverse_iterator (m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()-1); }
|
||||
/** const version of rbegin() */
|
||||
const_reverse_iterator rbegin() const { return const_reverse_iterator (m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()-1); }
|
||||
/** const version of rbegin() */
|
||||
const_reverse_iterator crbegin() const { return const_reverse_iterator (m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()-1); }
|
||||
|
||||
/** returns an iterator to the row (resp. column) following the last row (resp. column) of the nested expression
|
||||
* \sa begin(), cend()
|
||||
*/
|
||||
iterator end() { return iterator (m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()); }
|
||||
iterator end() { return iterator (m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()); }
|
||||
/** const version of end() */
|
||||
const_iterator end() const { return const_iterator(m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()); }
|
||||
const_iterator end() const { return const_iterator(m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()); }
|
||||
/** const version of end() */
|
||||
const_iterator cend() const { return const_iterator(m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()); }
|
||||
const_iterator cend() const { return const_iterator(m_matrix, m_matrix.template subVectors<DirectionType(Direction)>()); }
|
||||
|
||||
/** returns a reverse iterator to the row (resp. column) before the first row (resp. column) of the nested expression
|
||||
* \sa begin(), cend()
|
||||
*/
|
||||
reverse_iterator rend() { return reverse_iterator (m_matrix, -1); }
|
||||
/** const version of rend() */
|
||||
const_reverse_iterator rend() const { return const_reverse_iterator (m_matrix, -1); }
|
||||
/** const version of rend() */
|
||||
const_reverse_iterator crend() const { return const_reverse_iterator (m_matrix, -1); }
|
||||
|
||||
/** \returns a row or column vector expression of \c *this reduxed by \a func
|
||||
*
|
||||
|
@ -134,6 +134,7 @@ namespace internal {
|
||||
template<typename XprType> class generic_randaccess_stl_iterator;
|
||||
template<typename XprType> class pointer_based_stl_iterator;
|
||||
template<typename XprType, DirectionType Direction> class subvector_stl_iterator;
|
||||
template<typename XprType, DirectionType Direction> class subvector_stl_reverse_iterator;
|
||||
template<typename DecompositionType> struct kernel_retval_base;
|
||||
template<typename DecompositionType> struct kernel_retval;
|
||||
template<typename DecompositionType> struct image_retval_base;
|
||||
|
@ -431,22 +431,27 @@ void test_stl_iterators(int rows=Rows, int cols=Cols)
|
||||
{
|
||||
RowVectorType row = RowVectorType::Random(cols);
|
||||
A.rowwise() = row;
|
||||
VERIFY( std::all_of(A.rowwise().begin(), A.rowwise().end(), [&row](typename ColMatrixType::RowXpr x) { return internal::isApprox(x.squaredNorm(),row.squaredNorm()); }) );
|
||||
VERIFY( std::all_of(A.rowwise().begin(), A.rowwise().end(), [&row](typename ColMatrixType::RowXpr x) { return internal::isApprox(x.squaredNorm(),row.squaredNorm()); }) );
|
||||
VERIFY( std::all_of(A.rowwise().rbegin(), A.rowwise().rend(), [&row](typename ColMatrixType::RowXpr x) { return internal::isApprox(x.squaredNorm(),row.squaredNorm()); }) );
|
||||
|
||||
VectorType col = VectorType::Random(rows);
|
||||
A.colwise() = col;
|
||||
VERIFY( std::all_of(A.colwise().begin(), A.colwise().end(), [&col](typename ColMatrixType::ColXpr x) { return internal::isApprox(x.squaredNorm(),col.squaredNorm()); }) );
|
||||
VERIFY( std::all_of(A.colwise().cbegin(), A.colwise().cend(), [&col](typename ColMatrixType::ConstColXpr x) { return internal::isApprox(x.squaredNorm(),col.squaredNorm()); }) );
|
||||
VERIFY( std::all_of(A.colwise().begin(), A.colwise().end(), [&col](typename ColMatrixType::ColXpr x) { return internal::isApprox(x.squaredNorm(),col.squaredNorm()); }) );
|
||||
VERIFY( std::all_of(A.colwise().rbegin(), A.colwise().rend(), [&col](typename ColMatrixType::ColXpr x) { return internal::isApprox(x.squaredNorm(),col.squaredNorm()); }) );
|
||||
VERIFY( std::all_of(A.colwise().cbegin(), A.colwise().cend(), [&col](typename ColMatrixType::ConstColXpr x) { return internal::isApprox(x.squaredNorm(),col.squaredNorm()); }) );
|
||||
VERIFY( std::all_of(A.colwise().crbegin(), A.colwise().crend(), [&col](typename ColMatrixType::ConstColXpr x) { return internal::isApprox(x.squaredNorm(),col.squaredNorm()); }) );
|
||||
|
||||
i = internal::random<Index>(0,A.rows()-1);
|
||||
A.setRandom();
|
||||
A.row(i).setZero();
|
||||
VERIFY_IS_EQUAL( std::find_if(A.rowwise().begin(), A.rowwise().end(), [](typename ColMatrixType::RowXpr x) { return x.squaredNorm() == Scalar(0); })-A.rowwise().begin(), i );
|
||||
VERIFY_IS_EQUAL( std::find_if(A.rowwise().begin(), A.rowwise().end(), [](typename ColMatrixType::RowXpr x) { return x.squaredNorm() == Scalar(0); })-A.rowwise().begin(), i );
|
||||
VERIFY_IS_EQUAL( std::find_if(A.rowwise().rbegin(), A.rowwise().rend(), [](typename ColMatrixType::RowXpr x) { return x.squaredNorm() == Scalar(0); })-A.rowwise().rbegin(), (A.rows()-1) - i );
|
||||
|
||||
j = internal::random<Index>(0,A.cols()-1);
|
||||
A.setRandom();
|
||||
A.col(j).setZero();
|
||||
VERIFY_IS_EQUAL( std::find_if(A.colwise().begin(), A.colwise().end(), [](typename ColMatrixType::ColXpr x) { return x.squaredNorm() == Scalar(0); })-A.colwise().begin(), j );
|
||||
VERIFY_IS_EQUAL( std::find_if(A.colwise().begin(), A.colwise().end(), [](typename ColMatrixType::ColXpr x) { return x.squaredNorm() == Scalar(0); })-A.colwise().begin(), j );
|
||||
VERIFY_IS_EQUAL( std::find_if(A.colwise().rbegin(), A.colwise().rend(), [](typename ColMatrixType::ColXpr x) { return x.squaredNorm() == Scalar(0); })-A.colwise().rbegin(), (A.cols()-1) - j );
|
||||
}
|
||||
|
||||
{
|
||||
|
Loading…
x
Reference in New Issue
Block a user