mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-09-12 17:33:15 +08:00
Can now use the tensor 'reverse' operation as a lvalue
This commit is contained in:
parent
2fffe69b1b
commit
57154fdb32
@ -549,6 +549,11 @@ class TensorBase<Derived, WriteAccessors> : public TensorBase<Derived, ReadOnlyA
|
|||||||
chip(const Index offset, const Index dim) const {
|
chip(const Index offset, const Index dim) const {
|
||||||
return TensorChippingOp<Dynamic, Derived>(derived(), offset, dim);
|
return TensorChippingOp<Dynamic, Derived>(derived(), offset, dim);
|
||||||
}
|
}
|
||||||
|
template <typename ReverseDimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
|
TensorReverseOp<const ReverseDimensions, Derived>
|
||||||
|
reverse(const ReverseDimensions& rev) const {
|
||||||
|
return TensorReverseOp<const ReverseDimensions, Derived>(derived(), rev);
|
||||||
|
}
|
||||||
template <typename Shuffle> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
template <typename Shuffle> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
TensorShufflingOp<const Shuffle, Derived>
|
TensorShufflingOp<const Shuffle, Derived>
|
||||||
shuffle(const Shuffle& shuffle) const {
|
shuffle(const Shuffle& shuffle) const {
|
||||||
|
@ -49,12 +49,9 @@ struct nested<TensorReverseOp<ReverseDimensions, XprType>, 1,
|
|||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template<typename ReverseDimensions, typename XprType>
|
template<typename ReverseDimensions, typename XprType>
|
||||||
class TensorReverseOp : public TensorBase<TensorReverseOp<ReverseDimensions,
|
class TensorReverseOp : public TensorBase<TensorReverseOp<ReverseDimensions,
|
||||||
XprType>, ReadOnlyAccessors>
|
XprType>, WriteAccessors>
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
typedef typename Eigen::internal::traits<TensorReverseOp>::Scalar Scalar;
|
typedef typename Eigen::internal::traits<TensorReverseOp>::Scalar Scalar;
|
||||||
@ -67,8 +64,8 @@ class TensorReverseOp : public TensorBase<TensorReverseOp<ReverseDimensions,
|
|||||||
StorageKind;
|
StorageKind;
|
||||||
typedef typename Eigen::internal::traits<TensorReverseOp>::Index Index;
|
typedef typename Eigen::internal::traits<TensorReverseOp>::Index Index;
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorReverseOp(const XprType& expr,
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorReverseOp(
|
||||||
const ReverseDimensions& reverse_dims)
|
const XprType& expr, const ReverseDimensions& reverse_dims)
|
||||||
: m_xpr(expr), m_reverse_dims(reverse_dims) {}
|
: m_xpr(expr), m_reverse_dims(reverse_dims) {}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
@ -78,12 +75,30 @@ class TensorReverseOp : public TensorBase<TensorReverseOp<ReverseDimensions,
|
|||||||
const typename internal::remove_all<typename XprType::Nested>::type&
|
const typename internal::remove_all<typename XprType::Nested>::type&
|
||||||
expression() const { return m_xpr; }
|
expression() const { return m_xpr; }
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE TensorReverseOp& operator = (const TensorReverseOp& other)
|
||||||
|
{
|
||||||
|
typedef TensorAssignOp<TensorReverseOp, const TensorReverseOp> Assign;
|
||||||
|
Assign assign(*this, other);
|
||||||
|
internal::TensorExecutor<const Assign, DefaultDevice, false>::run(assign, DefaultDevice());
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename OtherDerived>
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE TensorReverseOp& operator = (const OtherDerived& other)
|
||||||
|
{
|
||||||
|
typedef TensorAssignOp<TensorReverseOp, const OtherDerived> Assign;
|
||||||
|
Assign assign(*this, other);
|
||||||
|
internal::TensorExecutor<const Assign, DefaultDevice, false>::run(assign, DefaultDevice());
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
typename XprType::Nested m_xpr;
|
typename XprType::Nested m_xpr;
|
||||||
const ReverseDimensions m_reverse_dims;
|
const ReverseDimensions m_reverse_dims;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
// Eval as rvalue
|
// Eval as rvalue
|
||||||
template<typename ReverseDimensions, typename ArgType, typename Device>
|
template<typename ReverseDimensions, typename ArgType, typename Device>
|
||||||
struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device>
|
struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device>
|
||||||
@ -134,8 +149,8 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device
|
|||||||
m_impl.cleanup();
|
m_impl.cleanup();
|
||||||
}
|
}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index reverseIndex(
|
||||||
{
|
Index index) const {
|
||||||
eigen_assert(index < dimensions().TotalSize());
|
eigen_assert(index < dimensions().TotalSize());
|
||||||
Index inputIndex = 0;
|
Index inputIndex = 0;
|
||||||
if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
|
if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
|
||||||
@ -152,7 +167,6 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device
|
|||||||
} else {
|
} else {
|
||||||
inputIndex += index;
|
inputIndex += index;
|
||||||
}
|
}
|
||||||
return m_impl.coeff(inputIndex);
|
|
||||||
} else {
|
} else {
|
||||||
for (int i = 0; i < NumDims - 1; ++i) {
|
for (int i = 0; i < NumDims - 1; ++i) {
|
||||||
Index idx = index / m_strides[i];
|
Index idx = index / m_strides[i];
|
||||||
@ -167,8 +181,13 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device
|
|||||||
} else {
|
} else {
|
||||||
inputIndex += index;
|
inputIndex += index;
|
||||||
}
|
}
|
||||||
return m_impl.coeff(inputIndex);
|
|
||||||
}
|
}
|
||||||
|
return inputIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(
|
||||||
|
Index index) const {
|
||||||
|
return m_impl.coeff(reverseIndex(index));
|
||||||
}
|
}
|
||||||
|
|
||||||
template<int LoadMode>
|
template<int LoadMode>
|
||||||
@ -199,7 +218,55 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device
|
|||||||
ReverseDimensions m_reverse;
|
ReverseDimensions m_reverse;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Eval as lvalue
|
||||||
|
|
||||||
|
template <typename ReverseDimensions, typename ArgType, typename Device>
|
||||||
|
struct TensorEvaluator<TensorReverseOp<ReverseDimensions, ArgType>, Device>
|
||||||
|
: public TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>,
|
||||||
|
Device> {
|
||||||
|
typedef TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>,
|
||||||
|
Device> Base;
|
||||||
|
typedef TensorReverseOp<ReverseDimensions, ArgType> XprType;
|
||||||
|
typedef typename XprType::Index Index;
|
||||||
|
static const int NumDims = internal::array_size<ReverseDimensions>::value;
|
||||||
|
typedef DSizes<Index, NumDims> Dimensions;
|
||||||
|
|
||||||
|
enum {
|
||||||
|
IsAligned = false,
|
||||||
|
PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
|
||||||
|
Layout = TensorEvaluator<ArgType, Device>::Layout,
|
||||||
|
CoordAccess = false, // to be implemented
|
||||||
|
};
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op,
|
||||||
|
const Device& device)
|
||||||
|
: Base(op, device) {}
|
||||||
|
|
||||||
|
typedef typename XprType::Scalar Scalar;
|
||||||
|
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
||||||
|
typedef typename XprType::PacketReturnType PacketReturnType;
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
|
const Dimensions& dimensions() const { return this->m_dimensions; }
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
|
||||||
|
return this->m_impl.coeffRef(this->reverseIndex(index));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int StoreMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
|
void writePacket(Index index, const PacketReturnType& x) {
|
||||||
|
const int packetSize = internal::unpacket_traits<PacketReturnType>::size;
|
||||||
|
EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
|
||||||
|
eigen_assert(index+packetSize-1 < dimensions().TotalSize());
|
||||||
|
|
||||||
|
// This code is pilfered from TensorMorphing.h
|
||||||
|
EIGEN_ALIGN_DEFAULT CoeffReturnType values[packetSize];
|
||||||
|
internal::pstore<CoeffReturnType, PacketReturnType>(values, x);
|
||||||
|
for (int i = 0; i < packetSize; ++i) {
|
||||||
|
this->coeffRef(index+i) = values[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
@ -94,7 +94,7 @@ static void test_simple_reverse()
|
|||||||
|
|
||||||
|
|
||||||
template <int DataLayout>
|
template <int DataLayout>
|
||||||
static void test_expr_reverse()
|
static void test_expr_reverse(bool LValue)
|
||||||
{
|
{
|
||||||
Tensor<float, 4, DataLayout> tensor(2,3,5,7);
|
Tensor<float, 4, DataLayout> tensor(2,3,5,7);
|
||||||
tensor.setRandom();
|
tensor.setRandom();
|
||||||
@ -105,9 +105,12 @@ static void test_expr_reverse()
|
|||||||
dim_rev[2] = false;
|
dim_rev[2] = false;
|
||||||
dim_rev[3] = true;
|
dim_rev[3] = true;
|
||||||
|
|
||||||
|
Tensor<float, 4, DataLayout> expected(2, 3, 5, 7);
|
||||||
Tensor<float, 4, DataLayout> expected;
|
if (LValue) {
|
||||||
|
expected.reverse(dim_rev) = tensor;
|
||||||
|
} else {
|
||||||
expected = tensor.reverse(dim_rev);
|
expected = tensor.reverse(dim_rev);
|
||||||
|
}
|
||||||
|
|
||||||
Tensor<float, 4, DataLayout> result(2,3,5,7);
|
Tensor<float, 4, DataLayout> result(2,3,5,7);
|
||||||
|
|
||||||
@ -117,8 +120,13 @@ static void test_expr_reverse()
|
|||||||
array<ptrdiff_t, 4> dst_slice_start{{0,0,0,0}};
|
array<ptrdiff_t, 4> dst_slice_start{{0,0,0,0}};
|
||||||
|
|
||||||
for (int i = 0; i < 5; ++i) {
|
for (int i = 0; i < 5; ++i) {
|
||||||
|
if (LValue) {
|
||||||
|
result.slice(dst_slice_start, dst_slice_dim).reverse(dim_rev) =
|
||||||
|
tensor.slice(src_slice_start, src_slice_dim);
|
||||||
|
} else {
|
||||||
result.slice(dst_slice_start, dst_slice_dim) =
|
result.slice(dst_slice_start, dst_slice_dim) =
|
||||||
tensor.slice(src_slice_start, src_slice_dim).reverse(dim_rev);
|
tensor.slice(src_slice_start, src_slice_dim).reverse(dim_rev);
|
||||||
|
}
|
||||||
src_slice_start[2] += 1;
|
src_slice_start[2] += 1;
|
||||||
dst_slice_start[2] += 1;
|
dst_slice_start[2] += 1;
|
||||||
}
|
}
|
||||||
@ -141,8 +149,13 @@ static void test_expr_reverse()
|
|||||||
dst_slice_start[2] = 0;
|
dst_slice_start[2] = 0;
|
||||||
result.setRandom();
|
result.setRandom();
|
||||||
for (int i = 0; i < 5; ++i) {
|
for (int i = 0; i < 5; ++i) {
|
||||||
|
if (LValue) {
|
||||||
|
result.slice(dst_slice_start, dst_slice_dim).reverse(dim_rev) =
|
||||||
|
tensor.slice(dst_slice_start, dst_slice_dim);
|
||||||
|
} else {
|
||||||
result.slice(dst_slice_start, dst_slice_dim) =
|
result.slice(dst_slice_start, dst_slice_dim) =
|
||||||
tensor.reverse(dim_rev).slice(dst_slice_start, dst_slice_dim);
|
tensor.reverse(dim_rev).slice(dst_slice_start, dst_slice_dim);
|
||||||
|
}
|
||||||
dst_slice_start[2] += 1;
|
dst_slice_start[2] += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -162,6 +175,8 @@ void test_cxx11_tensor_reverse()
|
|||||||
{
|
{
|
||||||
CALL_SUBTEST(test_simple_reverse<ColMajor>());
|
CALL_SUBTEST(test_simple_reverse<ColMajor>());
|
||||||
CALL_SUBTEST(test_simple_reverse<RowMajor>());
|
CALL_SUBTEST(test_simple_reverse<RowMajor>());
|
||||||
CALL_SUBTEST(test_expr_reverse<ColMajor>());
|
CALL_SUBTEST(test_expr_reverse<ColMajor>(true));
|
||||||
CALL_SUBTEST(test_expr_reverse<RowMajor>());
|
CALL_SUBTEST(test_expr_reverse<RowMajor>(true));
|
||||||
|
CALL_SUBTEST(test_expr_reverse<ColMajor>(false));
|
||||||
|
CALL_SUBTEST(test_expr_reverse<RowMajor>(false));
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user