mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-09-13 01:43:13 +08:00
Enable direct access for IndexedView.
This commit is contained in:
parent
90087b990a
commit
b56e30841c
@ -226,6 +226,11 @@ struct get_compile_time_incr<ArithmeticSequence<FirstType, SizeType, IncrType> >
|
||||
enum { value = get_fixed_value<IncrType, DynamicIndex>::value };
|
||||
};
|
||||
|
||||
template <typename FirstType, typename SizeType, typename IncrType>
|
||||
constexpr Index get_runtime_incr(const ArithmeticSequence<FirstType, SizeType, IncrType>& x) EIGEN_NOEXCEPT {
|
||||
return static_cast<Index>(x.incrObject());
|
||||
};
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
/** \namespace Eigen::indexing
|
||||
|
@ -75,11 +75,11 @@ struct traits<IndexedView<XprType, RowIndices, ColIndices>> : traits<XprType> {
|
||||
typedef Block<XprType, RowsAtCompileTime, ColsAtCompileTime, IsInnerPannel> BlockType;
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
template <typename XprType, typename RowIndices, typename ColIndices, typename StorageKind>
|
||||
template <typename XprType, typename RowIndices, typename ColIndices, typename StorageKind, bool DirectAccess>
|
||||
class IndexedViewImpl;
|
||||
|
||||
} // namespace internal
|
||||
|
||||
/** \class IndexedView
|
||||
* \ingroup Core_Module
|
||||
*
|
||||
@ -120,19 +120,36 @@ class IndexedViewImpl;
|
||||
*/
|
||||
template <typename XprType, typename RowIndices, typename ColIndices>
|
||||
class IndexedView
|
||||
: public IndexedViewImpl<XprType, RowIndices, ColIndices, typename internal::traits<XprType>::StorageKind> {
|
||||
: public internal::IndexedViewImpl<XprType, RowIndices, ColIndices, typename internal::traits<XprType>::StorageKind,
|
||||
(internal::traits<IndexedView<XprType, RowIndices, ColIndices>>::Flags &
|
||||
DirectAccessBit) != 0> {
|
||||
public:
|
||||
typedef
|
||||
typename IndexedViewImpl<XprType, RowIndices, ColIndices, typename internal::traits<XprType>::StorageKind>::Base
|
||||
Base;
|
||||
typedef typename internal::IndexedViewImpl<
|
||||
XprType, RowIndices, ColIndices, typename internal::traits<XprType>::StorageKind,
|
||||
(internal::traits<IndexedView<XprType, RowIndices, ColIndices>>::Flags & DirectAccessBit) != 0>
|
||||
Base;
|
||||
EIGEN_GENERIC_PUBLIC_INTERFACE(IndexedView)
|
||||
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(IndexedView)
|
||||
|
||||
template <typename T0, typename T1>
|
||||
IndexedView(XprType& xpr, const T0& rowIndices, const T1& colIndices) : Base(xpr, rowIndices, colIndices) {}
|
||||
};
|
||||
|
||||
namespace internal {
|
||||
|
||||
// Generic API dispatcher
|
||||
template <typename XprType, typename RowIndices, typename ColIndices, typename StorageKind, bool DirectAccess>
|
||||
class IndexedViewImpl : public internal::generic_xpr_base<IndexedView<XprType, RowIndices, ColIndices>>::type {
|
||||
public:
|
||||
typedef typename internal::generic_xpr_base<IndexedView<XprType, RowIndices, ColIndices>>::type Base;
|
||||
typedef typename internal::ref_selector<XprType>::non_const_type MatrixTypeNested;
|
||||
typedef internal::remove_all_t<XprType> NestedExpression;
|
||||
typedef typename XprType::Scalar Scalar;
|
||||
|
||||
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(IndexedViewImpl)
|
||||
|
||||
template <typename T0, typename T1>
|
||||
IndexedView(XprType& xpr, const T0& rowIndices, const T1& colIndices)
|
||||
IndexedViewImpl(XprType& xpr, const T0& rowIndices, const T1& colIndices)
|
||||
: m_xpr(xpr), m_rowIndices(rowIndices), m_colIndices(colIndices) {}
|
||||
|
||||
/** \returns number of rows */
|
||||
@ -153,20 +170,76 @@ class IndexedView
|
||||
/** \returns a const reference to the object storing/generating the column indices */
|
||||
const ColIndices& colIndices() const { return m_colIndices; }
|
||||
|
||||
constexpr Scalar& coeffRef(Index rowId, Index colId) {
|
||||
return nestedExpression().coeffRef(m_rowIndices[rowId], m_colIndices[colId]);
|
||||
}
|
||||
|
||||
constexpr const Scalar& coeffRef(Index rowId, Index colId) const {
|
||||
return nestedExpression().coeffRef(m_rowIndices[rowId], m_colIndices[colId]);
|
||||
}
|
||||
|
||||
protected:
|
||||
MatrixTypeNested m_xpr;
|
||||
RowIndices m_rowIndices;
|
||||
ColIndices m_colIndices;
|
||||
};
|
||||
|
||||
// Generic API dispatcher
|
||||
template <typename XprType, typename RowIndices, typename ColIndices, typename StorageKind>
|
||||
class IndexedViewImpl : public internal::generic_xpr_base<IndexedView<XprType, RowIndices, ColIndices>>::type {
|
||||
class IndexedViewImpl<XprType, RowIndices, ColIndices, StorageKind, true>
|
||||
: public IndexedViewImpl<XprType, RowIndices, ColIndices, StorageKind, false> {
|
||||
public:
|
||||
typedef typename internal::generic_xpr_base<IndexedView<XprType, RowIndices, ColIndices>>::type Base;
|
||||
};
|
||||
using Base = internal::IndexedViewImpl<XprType, RowIndices, ColIndices,
|
||||
typename internal::traits<XprType>::StorageKind, false>;
|
||||
using Derived = IndexedView<XprType, RowIndices, ColIndices>;
|
||||
|
||||
namespace internal {
|
||||
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(IndexedViewImpl)
|
||||
|
||||
template <typename T0, typename T1>
|
||||
IndexedViewImpl(XprType& xpr, const T0& rowIndices, const T1& colIndices) : Base(xpr, rowIndices, colIndices) {}
|
||||
|
||||
Index rowIncrement() const {
|
||||
if (traits<Derived>::RowIncr != DynamicIndex && traits<Derived>::RowIncr != UndefinedIncr) {
|
||||
return traits<Derived>::RowIncr;
|
||||
}
|
||||
return get_runtime_incr(this->rowIndices());
|
||||
}
|
||||
Index colIncrement() const {
|
||||
if (traits<Derived>::ColIncr != DynamicIndex && traits<Derived>::ColIncr != UndefinedIncr) {
|
||||
return traits<Derived>::ColIncr;
|
||||
}
|
||||
return get_runtime_incr(this->colIndices());
|
||||
}
|
||||
|
||||
Index innerIncrement() const { return traits<Derived>::IsRowMajor ? colIncrement() : rowIncrement(); }
|
||||
|
||||
Index outerIncrement() const { return traits<Derived>::IsRowMajor ? rowIncrement() : colIncrement(); }
|
||||
|
||||
std::decay_t<typename XprType::Scalar>* data() {
|
||||
Index row_offset = this->rowIndices()[0] * this->nestedExpression().rowStride();
|
||||
Index col_offset = this->colIndices()[0] * this->nestedExpression().colStride();
|
||||
return this->nestedExpression().data() + row_offset + col_offset;
|
||||
}
|
||||
|
||||
const std::decay_t<typename XprType::Scalar>* data() const {
|
||||
Index row_offset = this->rowIndices()[0] * this->nestedExpression().rowStride();
|
||||
Index col_offset = this->colIndices()[0] * this->nestedExpression().colStride();
|
||||
return this->nestedExpression().data() + row_offset + col_offset;
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index innerStride() const EIGEN_NOEXCEPT {
|
||||
if (traits<Derived>::InnerStrideAtCompileTime != Dynamic) {
|
||||
return traits<Derived>::InnerStrideAtCompileTime;
|
||||
}
|
||||
return innerIncrement() * this->nestedExpression().innerStride();
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index outerStride() const EIGEN_NOEXCEPT {
|
||||
if (traits<Derived>::OuterStrideAtCompileTime != Dynamic) {
|
||||
return traits<Derived>::OuterStrideAtCompileTime;
|
||||
}
|
||||
return outerIncrement() * this->nestedExpression().outerStride();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ArgType, typename RowIndices, typename ColIndices>
|
||||
struct unary_evaluator<IndexedView<ArgType, RowIndices, ColIndices>, IndexBased>
|
||||
|
@ -67,6 +67,11 @@ struct get_compile_time_incr {
|
||||
enum { value = UndefinedIncr };
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
constexpr Index get_runtime_incr(const T&) EIGEN_NOEXCEPT {
|
||||
return Index(1);
|
||||
};
|
||||
|
||||
// Analogue of std::get<0>(x), but tailored for our needs.
|
||||
template <typename T>
|
||||
EIGEN_CONSTEXPR Index first(const T& x) EIGEN_NOEXCEPT {
|
||||
|
@ -498,12 +498,39 @@ void check_indexed_view() {
|
||||
// A(1, seq(0,2,1)).cwiseAbs().colwise().replicate(2).eval();
|
||||
STATIC_CHECK(((internal::evaluator<decltype(A(1, seq(0, 2, 1)))>::Flags & RowMajorBit) == RowMajorBit));
|
||||
}
|
||||
|
||||
// Direct access.
|
||||
{
|
||||
int rows = 3;
|
||||
int row_start = internal::random<int>(0, rows - 1);
|
||||
int row_inc = internal::random<int>(1, rows - row_start);
|
||||
int row_size = internal::random<int>(1, (rows - row_start) / row_inc);
|
||||
auto row_seq = seqN(row_start, row_size, row_inc);
|
||||
|
||||
int cols = 3;
|
||||
int col_start = internal::random<int>(0, cols - 1);
|
||||
int col_inc = internal::random<int>(1, cols - col_start);
|
||||
int col_size = internal::random<int>(1, (cols - col_start) / col_inc);
|
||||
auto col_seq = seqN(col_start, col_size, col_inc);
|
||||
|
||||
MatrixXd m1 = MatrixXd::Random(rows, cols);
|
||||
MatrixXd m2 = MatrixXd::Random(cols, rows);
|
||||
VERIFY_IS_APPROX(m1(row_seq, indexing::all) * m2, m1(row_seq, indexing::all).eval() * m2);
|
||||
VERIFY_IS_APPROX(m1 * m2(indexing::all, col_seq), m1 * m2(indexing::all, col_seq).eval());
|
||||
VERIFY_IS_APPROX(m1(row_seq, col_seq) * m2(col_seq, row_seq),
|
||||
m1(row_seq, col_seq).eval() * m2(col_seq, row_seq).eval());
|
||||
|
||||
VectorXd v1 = VectorXd::Random(cols);
|
||||
VERIFY_IS_APPROX(m1(row_seq, col_seq) * v1(col_seq), m1(row_seq, col_seq).eval() * v1(col_seq).eval());
|
||||
VERIFY_IS_APPROX(v1(col_seq).transpose() * m2(col_seq, row_seq),
|
||||
v1(col_seq).transpose().eval() * m2(col_seq, row_seq).eval());
|
||||
}
|
||||
}
|
||||
|
||||
EIGEN_DECLARE_TEST(indexed_view) {
|
||||
// for(int i = 0; i < g_repeat; i++) {
|
||||
CALL_SUBTEST_1(check_indexed_view());
|
||||
// }
|
||||
for (int i = 0; i < g_repeat; i++) {
|
||||
CALL_SUBTEST_1(check_indexed_view());
|
||||
}
|
||||
|
||||
// static checks of some internals:
|
||||
STATIC_CHECK((internal::is_valid_index_type<int>::value));
|
||||
|
Loading…
x
Reference in New Issue
Block a user