Enable direct access for IndexedView.

This commit is contained in:
Antonio Sánchez 2024-02-20 18:21:45 +00:00 committed by Rasmus Munk Larsen
parent 90087b990a
commit b56e30841c
4 changed files with 126 additions and 16 deletions

View File

@ -226,6 +226,11 @@ struct get_compile_time_incr<ArithmeticSequence<FirstType, SizeType, IncrType> >
enum { value = get_fixed_value<IncrType, DynamicIndex>::value };
};
template <typename FirstType, typename SizeType, typename IncrType>
constexpr Index get_runtime_incr(const ArithmeticSequence<FirstType, SizeType, IncrType>& x) EIGEN_NOEXCEPT {
return static_cast<Index>(x.incrObject());
};
} // end namespace internal
/** \namespace Eigen::indexing

View File

@ -75,11 +75,11 @@ struct traits<IndexedView<XprType, RowIndices, ColIndices>> : traits<XprType> {
typedef Block<XprType, RowsAtCompileTime, ColsAtCompileTime, IsInnerPannel> BlockType;
};
} // namespace internal
template <typename XprType, typename RowIndices, typename ColIndices, typename StorageKind>
template <typename XprType, typename RowIndices, typename ColIndices, typename StorageKind, bool DirectAccess>
class IndexedViewImpl;
} // namespace internal
/** \class IndexedView
* \ingroup Core_Module
*
@ -120,19 +120,36 @@ class IndexedViewImpl;
*/
template <typename XprType, typename RowIndices, typename ColIndices>
class IndexedView
: public IndexedViewImpl<XprType, RowIndices, ColIndices, typename internal::traits<XprType>::StorageKind> {
: public internal::IndexedViewImpl<XprType, RowIndices, ColIndices, typename internal::traits<XprType>::StorageKind,
(internal::traits<IndexedView<XprType, RowIndices, ColIndices>>::Flags &
DirectAccessBit) != 0> {
public:
typedef
typename IndexedViewImpl<XprType, RowIndices, ColIndices, typename internal::traits<XprType>::StorageKind>::Base
typedef typename internal::IndexedViewImpl<
XprType, RowIndices, ColIndices, typename internal::traits<XprType>::StorageKind,
(internal::traits<IndexedView<XprType, RowIndices, ColIndices>>::Flags & DirectAccessBit) != 0>
Base;
EIGEN_GENERIC_PUBLIC_INTERFACE(IndexedView)
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(IndexedView)
template <typename T0, typename T1>
IndexedView(XprType& xpr, const T0& rowIndices, const T1& colIndices) : Base(xpr, rowIndices, colIndices) {}
};
namespace internal {
// Generic API dispatcher
template <typename XprType, typename RowIndices, typename ColIndices, typename StorageKind, bool DirectAccess>
class IndexedViewImpl : public internal::generic_xpr_base<IndexedView<XprType, RowIndices, ColIndices>>::type {
public:
typedef typename internal::generic_xpr_base<IndexedView<XprType, RowIndices, ColIndices>>::type Base;
typedef typename internal::ref_selector<XprType>::non_const_type MatrixTypeNested;
typedef internal::remove_all_t<XprType> NestedExpression;
typedef typename XprType::Scalar Scalar;
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(IndexedViewImpl)
template <typename T0, typename T1>
IndexedView(XprType& xpr, const T0& rowIndices, const T1& colIndices)
IndexedViewImpl(XprType& xpr, const T0& rowIndices, const T1& colIndices)
: m_xpr(xpr), m_rowIndices(rowIndices), m_colIndices(colIndices) {}
/** \returns number of rows */
@ -153,20 +170,76 @@ class IndexedView
/** \returns a const reference to the object storing/generating the column indices */
const ColIndices& colIndices() const { return m_colIndices; }
constexpr Scalar& coeffRef(Index rowId, Index colId) {
return nestedExpression().coeffRef(m_rowIndices[rowId], m_colIndices[colId]);
}
constexpr const Scalar& coeffRef(Index rowId, Index colId) const {
return nestedExpression().coeffRef(m_rowIndices[rowId], m_colIndices[colId]);
}
protected:
MatrixTypeNested m_xpr;
RowIndices m_rowIndices;
ColIndices m_colIndices;
};
// Generic API dispatcher
template <typename XprType, typename RowIndices, typename ColIndices, typename StorageKind>
class IndexedViewImpl : public internal::generic_xpr_base<IndexedView<XprType, RowIndices, ColIndices>>::type {
class IndexedViewImpl<XprType, RowIndices, ColIndices, StorageKind, true>
: public IndexedViewImpl<XprType, RowIndices, ColIndices, StorageKind, false> {
public:
typedef typename internal::generic_xpr_base<IndexedView<XprType, RowIndices, ColIndices>>::type Base;
};
using Base = internal::IndexedViewImpl<XprType, RowIndices, ColIndices,
typename internal::traits<XprType>::StorageKind, false>;
using Derived = IndexedView<XprType, RowIndices, ColIndices>;
namespace internal {
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(IndexedViewImpl)
template <typename T0, typename T1>
IndexedViewImpl(XprType& xpr, const T0& rowIndices, const T1& colIndices) : Base(xpr, rowIndices, colIndices) {}
Index rowIncrement() const {
if (traits<Derived>::RowIncr != DynamicIndex && traits<Derived>::RowIncr != UndefinedIncr) {
return traits<Derived>::RowIncr;
}
return get_runtime_incr(this->rowIndices());
}
Index colIncrement() const {
if (traits<Derived>::ColIncr != DynamicIndex && traits<Derived>::ColIncr != UndefinedIncr) {
return traits<Derived>::ColIncr;
}
return get_runtime_incr(this->colIndices());
}
Index innerIncrement() const { return traits<Derived>::IsRowMajor ? colIncrement() : rowIncrement(); }
Index outerIncrement() const { return traits<Derived>::IsRowMajor ? rowIncrement() : colIncrement(); }
std::decay_t<typename XprType::Scalar>* data() {
Index row_offset = this->rowIndices()[0] * this->nestedExpression().rowStride();
Index col_offset = this->colIndices()[0] * this->nestedExpression().colStride();
return this->nestedExpression().data() + row_offset + col_offset;
}
const std::decay_t<typename XprType::Scalar>* data() const {
Index row_offset = this->rowIndices()[0] * this->nestedExpression().rowStride();
Index col_offset = this->colIndices()[0] * this->nestedExpression().colStride();
return this->nestedExpression().data() + row_offset + col_offset;
}
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index innerStride() const EIGEN_NOEXCEPT {
if (traits<Derived>::InnerStrideAtCompileTime != Dynamic) {
return traits<Derived>::InnerStrideAtCompileTime;
}
return innerIncrement() * this->nestedExpression().innerStride();
}
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index outerStride() const EIGEN_NOEXCEPT {
if (traits<Derived>::OuterStrideAtCompileTime != Dynamic) {
return traits<Derived>::OuterStrideAtCompileTime;
}
return outerIncrement() * this->nestedExpression().outerStride();
}
};
template <typename ArgType, typename RowIndices, typename ColIndices>
struct unary_evaluator<IndexedView<ArgType, RowIndices, ColIndices>, IndexBased>

View File

@ -67,6 +67,11 @@ struct get_compile_time_incr {
enum { value = UndefinedIncr };
};
template <typename T>
constexpr Index get_runtime_incr(const T&) EIGEN_NOEXCEPT {
return Index(1);
};
// Analogue of std::get<0>(x), but tailored for our needs.
template <typename T>
EIGEN_CONSTEXPR Index first(const T& x) EIGEN_NOEXCEPT {

View File

@ -498,12 +498,39 @@ void check_indexed_view() {
// A(1, seq(0,2,1)).cwiseAbs().colwise().replicate(2).eval();
STATIC_CHECK(((internal::evaluator<decltype(A(1, seq(0, 2, 1)))>::Flags & RowMajorBit) == RowMajorBit));
}
// Direct access.
{
int rows = 3;
int row_start = internal::random<int>(0, rows - 1);
int row_inc = internal::random<int>(1, rows - row_start);
int row_size = internal::random<int>(1, (rows - row_start) / row_inc);
auto row_seq = seqN(row_start, row_size, row_inc);
int cols = 3;
int col_start = internal::random<int>(0, cols - 1);
int col_inc = internal::random<int>(1, cols - col_start);
int col_size = internal::random<int>(1, (cols - col_start) / col_inc);
auto col_seq = seqN(col_start, col_size, col_inc);
MatrixXd m1 = MatrixXd::Random(rows, cols);
MatrixXd m2 = MatrixXd::Random(cols, rows);
VERIFY_IS_APPROX(m1(row_seq, indexing::all) * m2, m1(row_seq, indexing::all).eval() * m2);
VERIFY_IS_APPROX(m1 * m2(indexing::all, col_seq), m1 * m2(indexing::all, col_seq).eval());
VERIFY_IS_APPROX(m1(row_seq, col_seq) * m2(col_seq, row_seq),
m1(row_seq, col_seq).eval() * m2(col_seq, row_seq).eval());
VectorXd v1 = VectorXd::Random(cols);
VERIFY_IS_APPROX(m1(row_seq, col_seq) * v1(col_seq), m1(row_seq, col_seq).eval() * v1(col_seq).eval());
VERIFY_IS_APPROX(v1(col_seq).transpose() * m2(col_seq, row_seq),
v1(col_seq).transpose().eval() * m2(col_seq, row_seq).eval());
}
}
EIGEN_DECLARE_TEST(indexed_view) {
// for(int i = 0; i < g_repeat; i++) {
for (int i = 0; i < g_repeat; i++) {
CALL_SUBTEST_1(check_indexed_view());
// }
}
// static checks of some internals:
STATIC_CHECK((internal::is_valid_index_type<int>::value));