mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-09-13 01:43:13 +08:00
Enable direct access for IndexedView.
This commit is contained in:
parent
90087b990a
commit
b56e30841c
@ -226,6 +226,11 @@ struct get_compile_time_incr<ArithmeticSequence<FirstType, SizeType, IncrType> >
|
|||||||
enum { value = get_fixed_value<IncrType, DynamicIndex>::value };
|
enum { value = get_fixed_value<IncrType, DynamicIndex>::value };
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename FirstType, typename SizeType, typename IncrType>
|
||||||
|
constexpr Index get_runtime_incr(const ArithmeticSequence<FirstType, SizeType, IncrType>& x) EIGEN_NOEXCEPT {
|
||||||
|
return static_cast<Index>(x.incrObject());
|
||||||
|
};
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
/** \namespace Eigen::indexing
|
/** \namespace Eigen::indexing
|
||||||
|
@ -75,11 +75,11 @@ struct traits<IndexedView<XprType, RowIndices, ColIndices>> : traits<XprType> {
|
|||||||
typedef Block<XprType, RowsAtCompileTime, ColsAtCompileTime, IsInnerPannel> BlockType;
|
typedef Block<XprType, RowsAtCompileTime, ColsAtCompileTime, IsInnerPannel> BlockType;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace internal
|
template <typename XprType, typename RowIndices, typename ColIndices, typename StorageKind, bool DirectAccess>
|
||||||
|
|
||||||
template <typename XprType, typename RowIndices, typename ColIndices, typename StorageKind>
|
|
||||||
class IndexedViewImpl;
|
class IndexedViewImpl;
|
||||||
|
|
||||||
|
} // namespace internal
|
||||||
|
|
||||||
/** \class IndexedView
|
/** \class IndexedView
|
||||||
* \ingroup Core_Module
|
* \ingroup Core_Module
|
||||||
*
|
*
|
||||||
@ -120,19 +120,36 @@ class IndexedViewImpl;
|
|||||||
*/
|
*/
|
||||||
template <typename XprType, typename RowIndices, typename ColIndices>
|
template <typename XprType, typename RowIndices, typename ColIndices>
|
||||||
class IndexedView
|
class IndexedView
|
||||||
: public IndexedViewImpl<XprType, RowIndices, ColIndices, typename internal::traits<XprType>::StorageKind> {
|
: public internal::IndexedViewImpl<XprType, RowIndices, ColIndices, typename internal::traits<XprType>::StorageKind,
|
||||||
|
(internal::traits<IndexedView<XprType, RowIndices, ColIndices>>::Flags &
|
||||||
|
DirectAccessBit) != 0> {
|
||||||
public:
|
public:
|
||||||
typedef
|
typedef typename internal::IndexedViewImpl<
|
||||||
typename IndexedViewImpl<XprType, RowIndices, ColIndices, typename internal::traits<XprType>::StorageKind>::Base
|
XprType, RowIndices, ColIndices, typename internal::traits<XprType>::StorageKind,
|
||||||
|
(internal::traits<IndexedView<XprType, RowIndices, ColIndices>>::Flags & DirectAccessBit) != 0>
|
||||||
Base;
|
Base;
|
||||||
EIGEN_GENERIC_PUBLIC_INTERFACE(IndexedView)
|
EIGEN_GENERIC_PUBLIC_INTERFACE(IndexedView)
|
||||||
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(IndexedView)
|
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(IndexedView)
|
||||||
|
|
||||||
|
template <typename T0, typename T1>
|
||||||
|
IndexedView(XprType& xpr, const T0& rowIndices, const T1& colIndices) : Base(xpr, rowIndices, colIndices) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
namespace internal {
|
||||||
|
|
||||||
|
// Generic API dispatcher
|
||||||
|
template <typename XprType, typename RowIndices, typename ColIndices, typename StorageKind, bool DirectAccess>
|
||||||
|
class IndexedViewImpl : public internal::generic_xpr_base<IndexedView<XprType, RowIndices, ColIndices>>::type {
|
||||||
|
public:
|
||||||
|
typedef typename internal::generic_xpr_base<IndexedView<XprType, RowIndices, ColIndices>>::type Base;
|
||||||
typedef typename internal::ref_selector<XprType>::non_const_type MatrixTypeNested;
|
typedef typename internal::ref_selector<XprType>::non_const_type MatrixTypeNested;
|
||||||
typedef internal::remove_all_t<XprType> NestedExpression;
|
typedef internal::remove_all_t<XprType> NestedExpression;
|
||||||
|
typedef typename XprType::Scalar Scalar;
|
||||||
|
|
||||||
|
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(IndexedViewImpl)
|
||||||
|
|
||||||
template <typename T0, typename T1>
|
template <typename T0, typename T1>
|
||||||
IndexedView(XprType& xpr, const T0& rowIndices, const T1& colIndices)
|
IndexedViewImpl(XprType& xpr, const T0& rowIndices, const T1& colIndices)
|
||||||
: m_xpr(xpr), m_rowIndices(rowIndices), m_colIndices(colIndices) {}
|
: m_xpr(xpr), m_rowIndices(rowIndices), m_colIndices(colIndices) {}
|
||||||
|
|
||||||
/** \returns number of rows */
|
/** \returns number of rows */
|
||||||
@ -153,20 +170,76 @@ class IndexedView
|
|||||||
/** \returns a const reference to the object storing/generating the column indices */
|
/** \returns a const reference to the object storing/generating the column indices */
|
||||||
const ColIndices& colIndices() const { return m_colIndices; }
|
const ColIndices& colIndices() const { return m_colIndices; }
|
||||||
|
|
||||||
|
constexpr Scalar& coeffRef(Index rowId, Index colId) {
|
||||||
|
return nestedExpression().coeffRef(m_rowIndices[rowId], m_colIndices[colId]);
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr const Scalar& coeffRef(Index rowId, Index colId) const {
|
||||||
|
return nestedExpression().coeffRef(m_rowIndices[rowId], m_colIndices[colId]);
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
MatrixTypeNested m_xpr;
|
MatrixTypeNested m_xpr;
|
||||||
RowIndices m_rowIndices;
|
RowIndices m_rowIndices;
|
||||||
ColIndices m_colIndices;
|
ColIndices m_colIndices;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Generic API dispatcher
|
|
||||||
template <typename XprType, typename RowIndices, typename ColIndices, typename StorageKind>
|
template <typename XprType, typename RowIndices, typename ColIndices, typename StorageKind>
|
||||||
class IndexedViewImpl : public internal::generic_xpr_base<IndexedView<XprType, RowIndices, ColIndices>>::type {
|
class IndexedViewImpl<XprType, RowIndices, ColIndices, StorageKind, true>
|
||||||
|
: public IndexedViewImpl<XprType, RowIndices, ColIndices, StorageKind, false> {
|
||||||
public:
|
public:
|
||||||
typedef typename internal::generic_xpr_base<IndexedView<XprType, RowIndices, ColIndices>>::type Base;
|
using Base = internal::IndexedViewImpl<XprType, RowIndices, ColIndices,
|
||||||
};
|
typename internal::traits<XprType>::StorageKind, false>;
|
||||||
|
using Derived = IndexedView<XprType, RowIndices, ColIndices>;
|
||||||
|
|
||||||
namespace internal {
|
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(IndexedViewImpl)
|
||||||
|
|
||||||
|
template <typename T0, typename T1>
|
||||||
|
IndexedViewImpl(XprType& xpr, const T0& rowIndices, const T1& colIndices) : Base(xpr, rowIndices, colIndices) {}
|
||||||
|
|
||||||
|
Index rowIncrement() const {
|
||||||
|
if (traits<Derived>::RowIncr != DynamicIndex && traits<Derived>::RowIncr != UndefinedIncr) {
|
||||||
|
return traits<Derived>::RowIncr;
|
||||||
|
}
|
||||||
|
return get_runtime_incr(this->rowIndices());
|
||||||
|
}
|
||||||
|
Index colIncrement() const {
|
||||||
|
if (traits<Derived>::ColIncr != DynamicIndex && traits<Derived>::ColIncr != UndefinedIncr) {
|
||||||
|
return traits<Derived>::ColIncr;
|
||||||
|
}
|
||||||
|
return get_runtime_incr(this->colIndices());
|
||||||
|
}
|
||||||
|
|
||||||
|
Index innerIncrement() const { return traits<Derived>::IsRowMajor ? colIncrement() : rowIncrement(); }
|
||||||
|
|
||||||
|
Index outerIncrement() const { return traits<Derived>::IsRowMajor ? rowIncrement() : colIncrement(); }
|
||||||
|
|
||||||
|
std::decay_t<typename XprType::Scalar>* data() {
|
||||||
|
Index row_offset = this->rowIndices()[0] * this->nestedExpression().rowStride();
|
||||||
|
Index col_offset = this->colIndices()[0] * this->nestedExpression().colStride();
|
||||||
|
return this->nestedExpression().data() + row_offset + col_offset;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::decay_t<typename XprType::Scalar>* data() const {
|
||||||
|
Index row_offset = this->rowIndices()[0] * this->nestedExpression().rowStride();
|
||||||
|
Index col_offset = this->colIndices()[0] * this->nestedExpression().colStride();
|
||||||
|
return this->nestedExpression().data() + row_offset + col_offset;
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index innerStride() const EIGEN_NOEXCEPT {
|
||||||
|
if (traits<Derived>::InnerStrideAtCompileTime != Dynamic) {
|
||||||
|
return traits<Derived>::InnerStrideAtCompileTime;
|
||||||
|
}
|
||||||
|
return innerIncrement() * this->nestedExpression().innerStride();
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index outerStride() const EIGEN_NOEXCEPT {
|
||||||
|
if (traits<Derived>::OuterStrideAtCompileTime != Dynamic) {
|
||||||
|
return traits<Derived>::OuterStrideAtCompileTime;
|
||||||
|
}
|
||||||
|
return outerIncrement() * this->nestedExpression().outerStride();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename ArgType, typename RowIndices, typename ColIndices>
|
template <typename ArgType, typename RowIndices, typename ColIndices>
|
||||||
struct unary_evaluator<IndexedView<ArgType, RowIndices, ColIndices>, IndexBased>
|
struct unary_evaluator<IndexedView<ArgType, RowIndices, ColIndices>, IndexBased>
|
||||||
|
@ -67,6 +67,11 @@ struct get_compile_time_incr {
|
|||||||
enum { value = UndefinedIncr };
|
enum { value = UndefinedIncr };
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
constexpr Index get_runtime_incr(const T&) EIGEN_NOEXCEPT {
|
||||||
|
return Index(1);
|
||||||
|
};
|
||||||
|
|
||||||
// Analogue of std::get<0>(x), but tailored for our needs.
|
// Analogue of std::get<0>(x), but tailored for our needs.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
EIGEN_CONSTEXPR Index first(const T& x) EIGEN_NOEXCEPT {
|
EIGEN_CONSTEXPR Index first(const T& x) EIGEN_NOEXCEPT {
|
||||||
|
@ -498,12 +498,39 @@ void check_indexed_view() {
|
|||||||
// A(1, seq(0,2,1)).cwiseAbs().colwise().replicate(2).eval();
|
// A(1, seq(0,2,1)).cwiseAbs().colwise().replicate(2).eval();
|
||||||
STATIC_CHECK(((internal::evaluator<decltype(A(1, seq(0, 2, 1)))>::Flags & RowMajorBit) == RowMajorBit));
|
STATIC_CHECK(((internal::evaluator<decltype(A(1, seq(0, 2, 1)))>::Flags & RowMajorBit) == RowMajorBit));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Direct access.
|
||||||
|
{
|
||||||
|
int rows = 3;
|
||||||
|
int row_start = internal::random<int>(0, rows - 1);
|
||||||
|
int row_inc = internal::random<int>(1, rows - row_start);
|
||||||
|
int row_size = internal::random<int>(1, (rows - row_start) / row_inc);
|
||||||
|
auto row_seq = seqN(row_start, row_size, row_inc);
|
||||||
|
|
||||||
|
int cols = 3;
|
||||||
|
int col_start = internal::random<int>(0, cols - 1);
|
||||||
|
int col_inc = internal::random<int>(1, cols - col_start);
|
||||||
|
int col_size = internal::random<int>(1, (cols - col_start) / col_inc);
|
||||||
|
auto col_seq = seqN(col_start, col_size, col_inc);
|
||||||
|
|
||||||
|
MatrixXd m1 = MatrixXd::Random(rows, cols);
|
||||||
|
MatrixXd m2 = MatrixXd::Random(cols, rows);
|
||||||
|
VERIFY_IS_APPROX(m1(row_seq, indexing::all) * m2, m1(row_seq, indexing::all).eval() * m2);
|
||||||
|
VERIFY_IS_APPROX(m1 * m2(indexing::all, col_seq), m1 * m2(indexing::all, col_seq).eval());
|
||||||
|
VERIFY_IS_APPROX(m1(row_seq, col_seq) * m2(col_seq, row_seq),
|
||||||
|
m1(row_seq, col_seq).eval() * m2(col_seq, row_seq).eval());
|
||||||
|
|
||||||
|
VectorXd v1 = VectorXd::Random(cols);
|
||||||
|
VERIFY_IS_APPROX(m1(row_seq, col_seq) * v1(col_seq), m1(row_seq, col_seq).eval() * v1(col_seq).eval());
|
||||||
|
VERIFY_IS_APPROX(v1(col_seq).transpose() * m2(col_seq, row_seq),
|
||||||
|
v1(col_seq).transpose().eval() * m2(col_seq, row_seq).eval());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
EIGEN_DECLARE_TEST(indexed_view) {
|
EIGEN_DECLARE_TEST(indexed_view) {
|
||||||
// for(int i = 0; i < g_repeat; i++) {
|
for (int i = 0; i < g_repeat; i++) {
|
||||||
CALL_SUBTEST_1(check_indexed_view());
|
CALL_SUBTEST_1(check_indexed_view());
|
||||||
// }
|
}
|
||||||
|
|
||||||
// static checks of some internals:
|
// static checks of some internals:
|
||||||
STATIC_CHECK((internal::is_valid_index_type<int>::value));
|
STATIC_CHECK((internal::is_valid_index_type<int>::value));
|
||||||
|
Loading…
x
Reference in New Issue
Block a user