Mirror of https://gitlab.com/libeigen/eigen.git

Tensor Roll / Circular Shift / Rotate

commit d49021212b (parent bb73be8a2e)
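The commit adds a roll() member to TensorBase, analogous to the existing reverse() and shuffle() operations: it takes one shift per dimension and circularly shifts the tensor along each axis, with shifts reduced modulo the axis length and negative shifts wrapping the other way. A minimal usage sketch (not part of the diff; it assumes a build of Eigen that includes this patch, and the tensor and variable names are illustrative), following the convention checked by the new test that rolled(i, ...) == input((i + shift) % dim, ...):

#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  Eigen::Tensor<int, 2> t(2, 3);
  t.setValues({{0, 1, 2}, {3, 4, 5}});

  // One shift per dimension; negative and out-of-range shifts wrap around.
  Eigen::array<Eigen::Index, 2> shifts{1, -1};
  Eigen::Tensor<int, 2> rolled = t.roll(shifts);

  // Per the tests below: rolled(i, j) == t((i + 1) % 2, (j - 1 + 3) % 3).
  std::cout << rolled << std::endl;
  return 0;
}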
unsupported/Eigen/CXX11/Tensor
@@ -109,6 +109,7 @@
#include "src/Tensor/TensorMorphing.h"
#include "src/Tensor/TensorPadding.h"
#include "src/Tensor/TensorReverse.h"
#include "src/Tensor/TensorRoll.h"
#include "src/Tensor/TensorShuffling.h"
#include "src/Tensor/TensorStriding.h"
#include "src/Tensor/TensorCustomOp.h"
unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
@@ -946,6 +946,11 @@ class TensorBase<Derived, ReadOnlyAccessors>
  reverse(const ReverseDimensions& rev) const {
    return TensorReverseOp<const ReverseDimensions, const Derived>(derived(), rev);
  }
  template <typename Rolls> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  const TensorRollOp<const Rolls, const Derived>
  roll(const Rolls& rolls) const {
    return TensorRollOp<const Rolls, const Derived>(derived(), rolls);
  }
  template <typename PaddingDimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  const TensorPaddingOp<const PaddingDimensions, const Derived>
  pad(const PaddingDimensions& padding) const {
unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
@@ -1166,6 +1171,17 @@ class TensorBase : public TensorBase<Derived, ReadOnlyAccessors> {
    return TensorReverseOp<const ReverseDimensions, Derived>(derived(), rev);
  }

  template <typename Rolls> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  const TensorRollOp<const Rolls, const Derived>
  roll(const Rolls& roll) const {
    return TensorRollOp<const Rolls, const Derived>(derived(), roll);
  }
  template <typename Rolls> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  TensorRollOp<const Rolls, Derived>
  roll(const Rolls& roll) {
    return TensorRollOp<const Rolls, Derived>(derived(), roll);
  }

  template <typename Shuffle> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  const TensorShufflingOp<const Shuffle, const Derived>
  shuffle(const Shuffle& shfl) const {
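The second overload above returns a writable TensorRollOp over a non-const Derived, so roll() can also appear on the left-hand side of an assignment and write through the mapping, which the new test exercises via expected.roll(dim_roll) = tensor. A minimal sketch of that write-through use (illustrative names, assumes this patch is applied):

#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  Eigen::Tensor<float, 2> src(2, 3), dst(2, 3);
  src.setRandom();

  Eigen::array<Eigen::Index, 2> shifts{1, 2};

  // Writing through roll(): dst is filled so that dst.roll(shifts) == src,
  // i.e. dst ends up holding src rolled by the opposite shifts.
  dst.roll(shifts) = src;
  return 0;
}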
unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h
@@ -111,6 +111,8 @@ template <typename StartIndices, typename Sizes, typename XprType>
class TensorSlicingOp;
template <typename ReverseDimensions, typename XprType>
class TensorReverseOp;
template <typename Rolls, typename XprType>
class TensorRollOp;
template <typename PaddingDimensions, typename XprType>
class TensorPaddingOp;
template <typename Shuffle, typename XprType>
unsupported/Eigen/CXX11/src/Tensor/TensorRoll.h (new file, 361 lines)
@@ -0,0 +1,361 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2024 Tobias Wood tobias@spinicist.org.uk
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_ROLL_H
#define EIGEN_CXX11_TENSOR_TENSOR_ROLL_H
// IWYU pragma: private
#include "./InternalHeaderCheck.h"

namespace Eigen {

/** \class TensorRoll
 * \ingroup CXX11_Tensor_Module
 *
 * \brief Tensor roll (circular shift) elements class.
 *
 */
namespace internal {
template <typename RollDimensions, typename XprType>
struct traits<TensorRollOp<RollDimensions, XprType> > : public traits<XprType> {
  typedef typename XprType::Scalar Scalar;
  typedef traits<XprType> XprTraits;
  typedef typename XprTraits::StorageKind StorageKind;
  typedef typename XprTraits::Index Index;
  typedef typename XprType::Nested Nested;
  typedef std::remove_reference_t<Nested> Nested_;
  static constexpr int NumDimensions = XprTraits::NumDimensions;
  static constexpr int Layout = XprTraits::Layout;
  typedef typename XprTraits::PointerType PointerType;
};

template <typename RollDimensions, typename XprType>
struct eval<TensorRollOp<RollDimensions, XprType>, Eigen::Dense> {
  typedef const TensorRollOp<RollDimensions, XprType>& type;
};

template <typename RollDimensions, typename XprType>
struct nested<TensorRollOp<RollDimensions, XprType>, 1, typename eval<TensorRollOp<RollDimensions, XprType> >::type> {
  typedef TensorRollOp<RollDimensions, XprType> type;
};

}  // end namespace internal

template <typename RollDimensions, typename XprType>
class TensorRollOp : public TensorBase<TensorRollOp<RollDimensions, XprType>, WriteAccessors> {
 public:
  typedef TensorBase<TensorRollOp<RollDimensions, XprType>, WriteAccessors> Base;
  typedef typename Eigen::internal::traits<TensorRollOp>::Scalar Scalar;
  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename Eigen::internal::nested<TensorRollOp>::type Nested;
  typedef typename Eigen::internal::traits<TensorRollOp>::StorageKind StorageKind;
  typedef typename Eigen::internal::traits<TensorRollOp>::Index Index;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorRollOp(const XprType& expr, const RollDimensions& roll_dims)
      : m_xpr(expr), m_roll_dims(roll_dims) {}

  EIGEN_DEVICE_FUNC const RollDimensions& roll() const { return m_roll_dims; }

  EIGEN_DEVICE_FUNC const internal::remove_all_t<typename XprType::Nested>& expression() const { return m_xpr; }

  EIGEN_TENSOR_INHERIT_ASSIGNMENT_OPERATORS(TensorRollOp)

 protected:
  typename XprType::Nested m_xpr;
  const RollDimensions m_roll_dims;
};

// Eval as rvalue
template <typename RollDimensions, typename ArgType, typename Device>
struct TensorEvaluator<const TensorRollOp<RollDimensions, ArgType>, Device> {
  typedef TensorRollOp<RollDimensions, ArgType> XprType;
  typedef typename XprType::Index Index;
  static constexpr int NumDims = internal::array_size<RollDimensions>::value;
  typedef DSizes<Index, NumDims> Dimensions;
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static constexpr int PacketSize = PacketType<CoeffReturnType, Device>::size;
  typedef StorageMemory<CoeffReturnType, Device> Storage;
  typedef typename Storage::Type EvaluatorPointerType;

  static constexpr int Layout = TensorEvaluator<ArgType, Device>::Layout;
  enum {
    IsAligned = false,
    PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
    BlockAccess = NumDims > 0,
    PreferBlockAccess = true,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  typedef internal::TensorIntDivisor<Index> IndexDivisor;

  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
  using TensorBlockDesc = internal::TensorBlockDescriptor<NumDims, Index>;
  using TensorBlockScratch = internal::TensorBlockScratchAllocator<Device>;
  using ArgTensorBlock = typename TensorEvaluator<const ArgType, Device>::TensorBlock;
  using TensorBlock = typename internal::TensorMaterializedBlock<CoeffReturnType, NumDims, Layout, Index>;
  //===--------------------------------------------------------------------===//

  EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : m_impl(op.expression(), device), m_rolls(op.roll()), m_device(device) {
    EIGEN_STATIC_ASSERT((NumDims > 0), Must_Have_At_Least_One_Dimension_To_Roll);

    // Compute strides
    m_dimensions = m_impl.dimensions();
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      m_strides[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        m_strides[i] = m_strides[i - 1] * m_dimensions[i - 1];
        if (m_strides[i] > 0) m_fast_strides[i] = IndexDivisor(m_strides[i]);
      }
    } else {
      m_strides[NumDims - 1] = 1;
      for (int i = NumDims - 2; i >= 0; --i) {
        m_strides[i] = m_strides[i + 1] * m_dimensions[i + 1];
        if (m_strides[i] > 0) m_fast_strides[i] = IndexDivisor(m_strides[i]);
      }
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType) {
    m_impl.evalSubExprsIfNeeded(nullptr);
    return true;
  }

#ifdef EIGEN_USE_THREADS
  template <typename EvalSubExprsCallback>
  EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync(EvaluatorPointerType, EvalSubExprsCallback done) {
    m_impl.evalSubExprsIfNeededAsync(nullptr, [done](bool) { done(true); });
  }
#endif  // EIGEN_USE_THREADS

  EIGEN_STRONG_INLINE void cleanup() { m_impl.cleanup(); }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index roll(Index const i, Index const r, Index const n) const {
    auto const tmp = (i + r) % n;
    if (tmp < 0) {
      return tmp + n;
    } else {
      return tmp;
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE array<Index, NumDims> rollCoords(array<Index, NumDims> const& coords) const {
    array<Index, NumDims> rolledCoords;
    for (int id = 0; id < NumDims; id++) {
      eigen_assert(coords[id] < m_dimensions[id]);
      rolledCoords[id] = roll(coords[id], m_rolls[id], m_dimensions[id]);
    }
    return rolledCoords;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index rollIndex(Index index) const {
    eigen_assert(index < dimensions().TotalSize());
    Index rolledIndex = 0;
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      EIGEN_UNROLL_LOOP
      for (int i = NumDims - 1; i > 0; --i) {
        Index idx = index / m_fast_strides[i];
        index -= idx * m_strides[i];
        rolledIndex += roll(idx, m_rolls[i], m_dimensions[i]) * m_strides[i];
      }
      rolledIndex += roll(index, m_rolls[0], m_dimensions[0]);
    } else {
      EIGEN_UNROLL_LOOP
      for (int i = 0; i < NumDims - 1; ++i) {
        Index idx = index / m_fast_strides[i];
        index -= idx * m_strides[i];
        rolledIndex += roll(idx, m_rolls[i], m_dimensions[i]) * m_strides[i];
      }
      rolledIndex += roll(index, m_rolls[NumDims - 1], m_dimensions[NumDims - 1]);
    }
    return rolledIndex;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    return m_impl.coeff(rollIndex(index));
  }

  template <int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const {
    eigen_assert(index + PacketSize - 1 < dimensions().TotalSize());
    EIGEN_ALIGN_MAX std::remove_const_t<CoeffReturnType> values[PacketSize];
    EIGEN_UNROLL_LOOP
    for (int i = 0; i < PacketSize; ++i) {
      values[i] = coeff(index + i);
    }
    PacketReturnType rslt = internal::pload<PacketReturnType>(values);
    return rslt;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE internal::TensorBlockResourceRequirements getResourceRequirements() const {
    const size_t target_size = m_device.lastLevelCacheSize();
    return internal::TensorBlockResourceRequirements::skewed<Scalar>(target_size).addCostPerCoeff({0, 0, 24});
  }

  struct BlockIteratorState {
    Index stride;
    Index span;
    Index size;
    Index count;
  };

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlock block(TensorBlockDesc& desc, TensorBlockScratch& scratch,
                                                          bool /*root_of_expr_ast*/ = false) const {
    static const bool is_col_major = static_cast<int>(Layout) == static_cast<int>(ColMajor);

    // Compute spatial coordinates for the first block element.
    array<Index, NumDims> coords;
    extract_coordinates(desc.offset(), coords);
    array<Index, NumDims> initial_coords = coords;
    Index offset = 0;  // Offset in the output block buffer.

    // Initialize output block iterator state. Dimension in this array are
    // always in inner_most -> outer_most order (col major layout).
    array<BlockIteratorState, NumDims> it;
    for (int i = 0; i < NumDims; ++i) {
      const int dim = is_col_major ? i : NumDims - 1 - i;
      it[i].size = desc.dimension(dim);
      it[i].stride = i == 0 ? 1 : (it[i - 1].size * it[i - 1].stride);
      it[i].span = it[i].stride * (it[i].size - 1);
      it[i].count = 0;
    }
    eigen_assert(it[0].stride == 1);

    // Prepare storage for the materialized generator result.
    const typename TensorBlock::Storage block_storage = TensorBlock::prepareStorage(desc, scratch);
    CoeffReturnType* block_buffer = block_storage.data();

    static const int inner_dim = is_col_major ? 0 : NumDims - 1;
    const Index inner_dim_size = it[0].size;

    while (it[NumDims - 1].count < it[NumDims - 1].size) {
      Index i = 0;
      for (; i < inner_dim_size; ++i) {
        auto const rolled = rollCoords(coords);
        auto const index = is_col_major ? m_dimensions.IndexOfColMajor(rolled) : m_dimensions.IndexOfRowMajor(rolled);
        *(block_buffer + offset + i) = m_impl.coeff(index);
        coords[inner_dim]++;
      }
      coords[inner_dim] = initial_coords[inner_dim];

      if (NumDims == 1) break;  // For the 1d tensor we need to generate only one inner-most dimension.

      // Update offset.
      for (i = 1; i < NumDims; ++i) {
        if (++it[i].count < it[i].size) {
          offset += it[i].stride;
          coords[is_col_major ? i : NumDims - 1 - i]++;
          break;
        }
        if (i != NumDims - 1) it[i].count = 0;
        coords[is_col_major ? i : NumDims - 1 - i] = initial_coords[is_col_major ? i : NumDims - 1 - i];
        offset -= it[i].span;
      }
    }

    return block_storage.AsTensorMaterializedBlock();
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
    double compute_cost = NumDims * (2 * TensorOpCost::AddCost<Index>() + 2 * TensorOpCost::MulCost<Index>() +
                                     TensorOpCost::DivCost<Index>());
    for (int i = 0; i < NumDims; ++i) {
      compute_cost += 2 * TensorOpCost::AddCost<Index>();
    }
    return m_impl.costPerCoeff(vectorized) + TensorOpCost(0, 0, compute_cost, false /* vectorized */, PacketSize);
  }

  EIGEN_DEVICE_FUNC typename Storage::Type data() const { return nullptr; }

 protected:
  Dimensions m_dimensions;
  array<Index, NumDims> m_strides;
  array<IndexDivisor, NumDims> m_fast_strides;
  TensorEvaluator<ArgType, Device> m_impl;
  RollDimensions m_rolls;
  const Device EIGEN_DEVICE_REF m_device;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void extract_coordinates(Index index, array<Index, NumDims>& coords) const {
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      for (int i = NumDims - 1; i > 0; --i) {
        const Index idx = index / m_fast_strides[i];
        index -= idx * m_strides[i];
        coords[i] = idx;
      }
      coords[0] = index;
    } else {
      for (int i = 0; i < NumDims - 1; ++i) {
        const Index idx = index / m_fast_strides[i];
        index -= idx * m_strides[i];
        coords[i] = idx;
      }
      coords[NumDims - 1] = index;
    }
  }

 private:
};

// Eval as lvalue

template <typename RollDimensions, typename ArgType, typename Device>
struct TensorEvaluator<TensorRollOp<RollDimensions, ArgType>, Device>
    : public TensorEvaluator<const TensorRollOp<RollDimensions, ArgType>, Device> {
  typedef TensorEvaluator<const TensorRollOp<RollDimensions, ArgType>, Device> Base;
  typedef TensorRollOp<RollDimensions, ArgType> XprType;
  typedef typename XprType::Index Index;
  static constexpr int NumDims = internal::array_size<RollDimensions>::value;
  typedef DSizes<Index, NumDims> Dimensions;

  static constexpr int Layout = TensorEvaluator<ArgType, Device>::Layout;
  enum {
    IsAligned = false,
    PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
    BlockAccess = false,
    PreferBlockAccess = false,
    CoordAccess = false,
    RawAccess = false
  };
  EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) : Base(op, device) {}

  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static constexpr int PacketSize = PacketType<CoeffReturnType, Device>::size;

  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
  typedef internal::TensorBlockNotImplemented TensorBlock;
  //===--------------------------------------------------------------------===//

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return this->m_dimensions; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) const {
    return this->m_impl.coeffRef(this->rollIndex(index));
  }

  template <int StoreMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void writePacket(Index index, const PacketReturnType& x) const {
    eigen_assert(index + PacketSize - 1 < dimensions().TotalSize());
    EIGEN_ALIGN_MAX CoeffReturnType values[PacketSize];
    internal::pstore<CoeffReturnType, PacketReturnType>(values, x);
    EIGEN_UNROLL_LOOP
    for (int i = 0; i < PacketSize; ++i) {
      this->coeffRef(index + i) = values[i];
    }
  }
};

}  // end namespace Eigen

#endif  // EIGEN_CXX11_TENSOR_TENSOR_ROLL_H
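The core of the evaluator is the roll(i, r, n) helper above: an output coordinate i reads the input coordinate (i + r) mod n, with the C++ remainder fixed up so that negative shifts wrap back into [0, n). A standalone sketch of that arithmetic (plain C++, independent of Eigen; roll_coord is a hypothetical stand-in for the member function) shows the wrap-around behaviour the tests below rely on:

#include <cassert>
#include <cstdint>

// Mirrors the evaluator's roll(i, r, n): map an output coordinate i to the
// input coordinate it reads from, wrapping (i + r) into [0, n).
std::int64_t roll_coord(std::int64_t i, std::int64_t r, std::int64_t n) {
  std::int64_t tmp = (i + r) % n;  // '%' can be negative when (i + r) < 0
  return tmp < 0 ? tmp + n : tmp;
}

int main() {
  // Positive shift: output 0 reads input 1 on an axis of length 5.
  assert(roll_coord(0, 1, 5) == 1);
  // Shifts larger than the axis length wrap: 8 on length 7 acts like 1.
  assert(roll_coord(0, 8, 7) == 1);
  // Negative shift: output 0 reads the last element.
  assert(roll_coord(0, -1, 5) == 4);
  return 0;
}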
unsupported/test/CMakeLists.txt
@@ -198,8 +198,10 @@ ei_add_test(cxx11_tensor_of_strings)
ei_add_test(cxx11_tensor_padding)
ei_add_test(cxx11_tensor_patch)
ei_add_test(cxx11_tensor_random)
ei_add_test(cxx11_tensor_reverse)
ei_add_test(cxx11_tensor_reduction)
ei_add_test(cxx11_tensor_ref)
ei_add_test(cxx11_tensor_roll)
ei_add_test(cxx11_tensor_roundings)
ei_add_test(cxx11_tensor_scan)
ei_add_test(cxx11_tensor_shuffling)
unsupported/test/cxx11_tensor_roll.cpp (new file, 156 lines)
@@ -0,0 +1,156 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2024 Tobias Wood tobias@spinicist.org.uk
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include "main.h"

#include <Eigen/CXX11/Tensor>

using Eigen::array;
using Eigen::Tensor;

template <int DataLayout>
static void test_simple_roll() {
  Tensor<float, 4, DataLayout> tensor(2, 3, 5, 7);
  tensor.setRandom();

  array<Index, 4> dim_roll;
  dim_roll[0] = 0;
  dim_roll[1] = 1;
  dim_roll[2] = 4;
  dim_roll[3] = 8;

  Tensor<float, 4, DataLayout> rolled_tensor;
  rolled_tensor = tensor.roll(dim_roll);

  VERIFY_IS_EQUAL(rolled_tensor.dimension(0), 2);
  VERIFY_IS_EQUAL(rolled_tensor.dimension(1), 3);
  VERIFY_IS_EQUAL(rolled_tensor.dimension(2), 5);
  VERIFY_IS_EQUAL(rolled_tensor.dimension(3), 7);

  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 3; ++j) {
      for (int k = 0; k < 5; ++k) {
        for (int l = 0; l < 7; ++l) {
          VERIFY_IS_EQUAL(tensor(i, (j + 1) % 3, (k + 4) % 5, (l + 8) % 7), rolled_tensor(i, j, k, l));
        }
      }
    }
  }

  dim_roll[0] = -3;
  dim_roll[1] = -2;
  dim_roll[2] = -1;
  dim_roll[3] = 0;

  rolled_tensor = tensor.roll(dim_roll);

  VERIFY_IS_EQUAL(rolled_tensor.dimension(0), 2);
  VERIFY_IS_EQUAL(rolled_tensor.dimension(1), 3);
  VERIFY_IS_EQUAL(rolled_tensor.dimension(2), 5);
  VERIFY_IS_EQUAL(rolled_tensor.dimension(3), 7);

  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 3; ++j) {
      for (int k = 0; k < 5; ++k) {
        for (int l = 0; l < 7; ++l) {
          VERIFY_IS_EQUAL(tensor((i + 1) % 2, (j + 1) % 3, (k + 4) % 5, l), rolled_tensor(i, j, k, l));
        }
      }
    }
  }
}

template <int DataLayout>
static void test_expr_roll(bool LValue) {
  Tensor<float, 4, DataLayout> tensor(2, 3, 5, 7);
  tensor.setRandom();

  array<Index, 4> dim_roll;
  dim_roll[0] = 2;
  dim_roll[1] = 1;
  dim_roll[2] = 0;
  dim_roll[3] = 3;

  Tensor<float, 4, DataLayout> expected(tensor.dimensions());
  if (LValue) {
    expected.roll(dim_roll) = tensor;
  } else {
    expected = tensor.roll(dim_roll);
  }

  Tensor<float, 4, DataLayout> result(tensor.dimensions());

  array<ptrdiff_t, 4> src_slice_dim;
  src_slice_dim[0] = tensor.dimension(0);
  src_slice_dim[1] = tensor.dimension(1);
  src_slice_dim[2] = 1;
  src_slice_dim[3] = tensor.dimension(3);
  array<ptrdiff_t, 4> src_slice_start;
  src_slice_start[0] = 0;
  src_slice_start[1] = 0;
  src_slice_start[2] = 0;
  src_slice_start[3] = 0;
  array<ptrdiff_t, 4> dst_slice_dim = src_slice_dim;
  array<ptrdiff_t, 4> dst_slice_start = src_slice_start;

  for (int i = 0; i < tensor.dimension(2); ++i) {
    if (LValue) {
      result.slice(dst_slice_start, dst_slice_dim).roll(dim_roll) = tensor.slice(src_slice_start, src_slice_dim);
    } else {
      result.slice(dst_slice_start, dst_slice_dim) = tensor.slice(src_slice_start, src_slice_dim).roll(dim_roll);
    }
    src_slice_start[2] += 1;
    dst_slice_start[2] += 1;
  }

  VERIFY_IS_EQUAL(result.dimension(0), tensor.dimension(0));
  VERIFY_IS_EQUAL(result.dimension(1), tensor.dimension(1));
  VERIFY_IS_EQUAL(result.dimension(2), tensor.dimension(2));
  VERIFY_IS_EQUAL(result.dimension(3), tensor.dimension(3));

  for (int i = 0; i < expected.dimension(0); ++i) {
    for (int j = 0; j < expected.dimension(1); ++j) {
      for (int k = 0; k < expected.dimension(2); ++k) {
        for (int l = 0; l < expected.dimension(3); ++l) {
          VERIFY_IS_EQUAL(result(i, j, k, l), expected(i, j, k, l));
        }
      }
    }
  }

  dst_slice_start[2] = 0;
  result.setRandom();
  for (int i = 0; i < tensor.dimension(2); ++i) {
    if (LValue) {
      result.slice(dst_slice_start, dst_slice_dim).roll(dim_roll) = tensor.slice(dst_slice_start, dst_slice_dim);
    } else {
      result.slice(dst_slice_start, dst_slice_dim) = tensor.roll(dim_roll).slice(dst_slice_start, dst_slice_dim);
    }
    dst_slice_start[2] += 1;
  }

  for (int i = 0; i < expected.dimension(0); ++i) {
    for (int j = 0; j < expected.dimension(1); ++j) {
      for (int k = 0; k < expected.dimension(2); ++k) {
        for (int l = 0; l < expected.dimension(3); ++l) {
          VERIFY_IS_EQUAL(result(i, j, k, l), expected(i, j, k, l));
        }
      }
    }
  }
}

EIGEN_DECLARE_TEST(cxx11_tensor_roll) {
  CALL_SUBTEST(test_simple_roll<ColMajor>());
  CALL_SUBTEST(test_simple_roll<RowMajor>());
  CALL_SUBTEST(test_expr_roll<ColMajor>(true));
  CALL_SUBTEST(test_expr_roll<RowMajor>(true));
  CALL_SUBTEST(test_expr_roll<ColMajor>(false));
  CALL_SUBTEST(test_expr_roll<RowMajor>(false));
}