mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-05-06 02:34:05 +08:00
Added support for padding, stridding, and shuffling
This commit is contained in:
parent
16047c8d4a
commit
eeb43f9e2b
@ -42,6 +42,9 @@
|
||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h"
|
||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h"
|
||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h"
|
||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h"
|
||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h"
|
||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorStriding.h"
|
||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h"
|
||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorForcedEval.h"
|
||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h"
|
||||
|
@ -215,6 +215,21 @@ class TensorBase<Derived, ReadOnlyAccessors>
|
||||
slice(const StartIndices& startIndices, const Sizes& sizes) const {
|
||||
return TensorSlicingOp<const StartIndices, const Sizes, const Derived>(derived(), startIndices, sizes);
|
||||
}
|
||||
template <typename PaddingDimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
TensorPaddingOp<const PaddingDimensions, Derived>
|
||||
pad(const PaddingDimensions& padding) const {
|
||||
return TensorPaddingOp<const PaddingDimensions, Derived>(derived(), padding);
|
||||
}
|
||||
template <typename Shuffle> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
TensorShufflingOp<const Shuffle, Derived>
|
||||
shuffle(const Shuffle& shuffle) const {
|
||||
return TensorShufflingOp<const Shuffle, Derived>(derived(), shuffle);
|
||||
}
|
||||
template <typename Strides> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
TensorStridingOp<const Strides, Derived>
|
||||
stride(const Strides& strides) const {
|
||||
return TensorStridingOp<const Strides, Derived>(derived(), strides);
|
||||
}
|
||||
|
||||
// Force the evaluation of the expression.
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
|
@ -26,6 +26,9 @@ template<typename Dimensions, typename LeftXprType, typename RightXprType> class
|
||||
template<typename Dimensions, typename InputXprType, typename KernelXprType> class TensorConvolutionOp;
|
||||
template<typename NewDimensions, typename XprType> class TensorReshapingOp;
|
||||
template<typename StartIndices, typename Sizes, typename XprType> class TensorSlicingOp;
|
||||
template<typename PaddingDimensions, typename XprType> class TensorPaddingOp;
|
||||
template<typename Shuffle, typename XprType> class TensorShufflingOp;
|
||||
template<typename Strides, typename XprType> class TensorStridingOp;
|
||||
template<typename LeftXprType, typename RightXprType> class TensorAssignOp;
|
||||
|
||||
template<typename XprType> class TensorEvalToOp;
|
||||
|
163
unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h
Normal file
163
unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h
Normal file
@ -0,0 +1,163 @@
|
||||
// This file is part of Eigen, a lightweight C++ template library
|
||||
// for linear algebra.
|
||||
//
|
||||
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
#ifndef EIGEN_CXX11_TENSOR_TENSOR_PADDING_H
|
||||
#define EIGEN_CXX11_TENSOR_TENSOR_PADDING_H
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
/** \class TensorPadding
|
||||
* \ingroup CXX11_Tensor_Module
|
||||
*
|
||||
* \brief Tensor padding class.
|
||||
* At the moment only 0-padding is supported.
|
||||
*
|
||||
*/
|
||||
namespace internal {
|
||||
template<typename PaddingDimensions, typename XprType>
|
||||
struct traits<TensorPaddingOp<PaddingDimensions, XprType> > : public traits<XprType>
|
||||
{
|
||||
typedef typename XprType::Scalar Scalar;
|
||||
typedef typename internal::packet_traits<Scalar>::type Packet;
|
||||
typedef typename traits<XprType>::StorageKind StorageKind;
|
||||
typedef typename traits<XprType>::Index Index;
|
||||
typedef typename XprType::Nested Nested;
|
||||
typedef typename remove_reference<Nested>::type _Nested;
|
||||
};
|
||||
|
||||
template<typename PaddingDimensions, typename XprType>
|
||||
struct eval<TensorPaddingOp<PaddingDimensions, XprType>, Eigen::Dense>
|
||||
{
|
||||
typedef const TensorPaddingOp<PaddingDimensions, XprType>& type;
|
||||
};
|
||||
|
||||
template<typename PaddingDimensions, typename XprType>
|
||||
struct nested<TensorPaddingOp<PaddingDimensions, XprType>, 1, typename eval<TensorPaddingOp<PaddingDimensions, XprType> >::type>
|
||||
{
|
||||
typedef TensorPaddingOp<PaddingDimensions, XprType> type;
|
||||
};
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
|
||||
|
||||
template<typename PaddingDimensions, typename XprType>
|
||||
class TensorPaddingOp : public TensorBase<TensorPaddingOp<PaddingDimensions, XprType> >
|
||||
{
|
||||
public:
|
||||
typedef typename Eigen::internal::traits<TensorPaddingOp>::Scalar Scalar;
|
||||
typedef typename Eigen::internal::traits<TensorPaddingOp>::Packet Packet;
|
||||
typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
|
||||
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
||||
typedef typename XprType::PacketReturnType PacketReturnType;
|
||||
typedef typename Eigen::internal::nested<TensorPaddingOp>::type Nested;
|
||||
typedef typename Eigen::internal::traits<TensorPaddingOp>::StorageKind StorageKind;
|
||||
typedef typename Eigen::internal::traits<TensorPaddingOp>::Index Index;
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorPaddingOp(const XprType& expr, const PaddingDimensions& padding_dims)
|
||||
: m_xpr(expr), m_padding_dims(padding_dims) {}
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
const PaddingDimensions& padding() const { return m_padding_dims; }
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
const typename internal::remove_all<typename XprType::Nested>::type&
|
||||
expression() const { return m_xpr; }
|
||||
|
||||
protected:
|
||||
typename XprType::Nested m_xpr;
|
||||
const PaddingDimensions m_padding_dims;
|
||||
};
|
||||
|
||||
|
||||
// Eval as rvalue
|
||||
template<typename PaddingDimensions, typename ArgType, typename Device>
|
||||
struct TensorEvaluator<const TensorPaddingOp<PaddingDimensions, ArgType>, Device>
|
||||
{
|
||||
typedef TensorPaddingOp<PaddingDimensions, ArgType> XprType;
|
||||
typedef typename XprType::Index Index;
|
||||
static const int NumDims = internal::array_size<PaddingDimensions>::value;
|
||||
typedef DSizes<Index, NumDims> Dimensions;
|
||||
|
||||
enum {
|
||||
IsAligned = false,
|
||||
PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/false,
|
||||
};
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
|
||||
: m_impl(op.expression(), device), m_padding(op.padding())
|
||||
{
|
||||
// Compute dimensions
|
||||
m_dimensions = m_impl.dimensions();
|
||||
for (int i = 0; i < NumDims; ++i) {
|
||||
m_dimensions[i] += m_padding[i].first + m_padding[i].second;
|
||||
}
|
||||
|
||||
const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
|
||||
for (int i = 0; i < NumDims; ++i) {
|
||||
if (i > 0) {
|
||||
m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
|
||||
m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
|
||||
} else {
|
||||
m_inputStrides[0] = 1;
|
||||
m_outputStrides[0] = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
typedef typename XprType::Scalar Scalar;
|
||||
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
||||
typedef typename XprType::PacketReturnType PacketReturnType;
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) {
|
||||
m_impl.evalSubExprsIfNeeded(NULL);
|
||||
return true;
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
|
||||
m_impl.cleanup();
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
|
||||
{
|
||||
Index inputIndex = 0;
|
||||
for (int i = NumDims - 1; i >= 0; --i) {
|
||||
const Index idx = index / m_outputStrides[i];
|
||||
if (idx < m_padding[i].first || idx >= m_dimensions[i] - m_padding[i].second) {
|
||||
return Scalar(0);
|
||||
}
|
||||
inputIndex += (idx - m_padding[i].first) * m_inputStrides[i];
|
||||
index -= idx * m_outputStrides[i];
|
||||
}
|
||||
return m_impl.coeff(inputIndex);
|
||||
}
|
||||
|
||||
/* template<int LoadMode>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
|
||||
{
|
||||
return m_impl.template packet<LoadMode>(index);
|
||||
}*/
|
||||
|
||||
Scalar* data() const { return NULL; }
|
||||
|
||||
protected:
|
||||
PaddingDimensions m_padding;
|
||||
Dimensions m_dimensions;
|
||||
array<Index, NumDims> m_outputStrides;
|
||||
array<Index, NumDims> m_inputStrides;
|
||||
TensorEvaluator<ArgType, Device> m_impl;
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_CXX11_TENSOR_TENSOR_PADDING_H
|
168
unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h
Normal file
168
unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h
Normal file
@ -0,0 +1,168 @@
|
||||
// This file is part of Eigen, a lightweight C++ template library
|
||||
// for linear algebra.
|
||||
//
|
||||
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
#ifndef EIGEN_CXX11_TENSOR_TENSOR_SHUFFLING_H
|
||||
#define EIGEN_CXX11_TENSOR_TENSOR_SHUFFLING_H
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
/** \class TensorShuffling
|
||||
* \ingroup CXX11_Tensor_Module
|
||||
*
|
||||
* \brief Tensor shuffling class.
|
||||
*
|
||||
*
|
||||
*/
|
||||
namespace internal {
|
||||
template<typename Shuffle, typename XprType>
|
||||
struct traits<TensorShufflingOp<Shuffle, XprType> > : public traits<XprType>
|
||||
{
|
||||
typedef typename XprType::Scalar Scalar;
|
||||
typedef typename internal::packet_traits<Scalar>::type Packet;
|
||||
typedef typename traits<XprType>::StorageKind StorageKind;
|
||||
typedef typename traits<XprType>::Index Index;
|
||||
typedef typename XprType::Nested Nested;
|
||||
typedef typename remove_reference<Nested>::type _Nested;
|
||||
};
|
||||
|
||||
template<typename Shuffle, typename XprType>
|
||||
struct eval<TensorShufflingOp<Shuffle, XprType>, Eigen::Dense>
|
||||
{
|
||||
typedef const TensorShufflingOp<Shuffle, XprType>& type;
|
||||
};
|
||||
|
||||
template<typename Shuffle, typename XprType>
|
||||
struct nested<TensorShufflingOp<Shuffle, XprType>, 1, typename eval<TensorShufflingOp<Shuffle, XprType> >::type>
|
||||
{
|
||||
typedef TensorShufflingOp<Shuffle, XprType> type;
|
||||
};
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
|
||||
|
||||
template<typename Shuffle, typename XprType>
|
||||
class TensorShufflingOp : public TensorBase<TensorShufflingOp<Shuffle, XprType>, WriteAccessors>
|
||||
{
|
||||
public:
|
||||
typedef typename Eigen::internal::traits<TensorShufflingOp>::Scalar Scalar;
|
||||
typedef typename Eigen::internal::traits<TensorShufflingOp>::Packet Packet;
|
||||
typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
|
||||
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
||||
typedef typename XprType::PacketReturnType PacketReturnType;
|
||||
typedef typename Eigen::internal::nested<TensorShufflingOp>::type Nested;
|
||||
typedef typename Eigen::internal::traits<TensorShufflingOp>::StorageKind StorageKind;
|
||||
typedef typename Eigen::internal::traits<TensorShufflingOp>::Index Index;
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorShufflingOp(const XprType& expr, const Shuffle& shuffle)
|
||||
: m_xpr(expr), m_shuffle(shuffle) {}
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
const Shuffle& shuffle() const { return m_shuffle; }
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
const typename internal::remove_all<typename XprType::Nested>::type&
|
||||
expression() const { return m_xpr; }
|
||||
|
||||
template<typename OtherDerived>
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE TensorShufflingOp& operator = (const OtherDerived& other)
|
||||
{
|
||||
typedef TensorAssignOp<TensorShufflingOp, const OtherDerived> Assign;
|
||||
Assign assign(*this, other);
|
||||
internal::TensorExecutor<const Assign, DefaultDevice, false>::run(assign, DefaultDevice());
|
||||
return *this;
|
||||
}
|
||||
|
||||
protected:
|
||||
typename XprType::Nested m_xpr;
|
||||
const Shuffle m_shuffle;
|
||||
};
|
||||
|
||||
|
||||
// Eval as rvalue
|
||||
template<typename Shuffle, typename ArgType, typename Device>
|
||||
struct TensorEvaluator<const TensorShufflingOp<Shuffle, ArgType>, Device>
|
||||
{
|
||||
typedef TensorShufflingOp<Shuffle, ArgType> XprType;
|
||||
typedef typename XprType::Index Index;
|
||||
static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
|
||||
typedef DSizes<Index, NumDims> Dimensions;
|
||||
|
||||
enum {
|
||||
IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/false,
|
||||
PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/false,
|
||||
};
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
|
||||
: m_impl(op.expression(), device), m_shuffle(op.shuffle())
|
||||
{
|
||||
const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
|
||||
for (int i = 0; i < NumDims; ++i) {
|
||||
m_dimensions[i] = input_dims[m_shuffle[i]];
|
||||
}
|
||||
|
||||
for (int i = 0; i < NumDims; ++i) {
|
||||
if (i > 0) {
|
||||
m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
|
||||
m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
|
||||
} else {
|
||||
m_inputStrides[0] = 1;
|
||||
m_outputStrides[0] = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// typedef typename XprType::Index Index;
|
||||
typedef typename XprType::Scalar Scalar;
|
||||
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
||||
typedef typename XprType::PacketReturnType PacketReturnType;
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
|
||||
m_impl.evalSubExprsIfNeeded(NULL);
|
||||
return true;
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
|
||||
m_impl.cleanup();
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
|
||||
{
|
||||
Index inputIndex = 0;
|
||||
for (int i = NumDims - 1; i > 0; --i) {
|
||||
const Index idx = index / m_outputStrides[i];
|
||||
inputIndex += idx * m_inputStrides[m_shuffle[i]];
|
||||
index -= idx * m_outputStrides[i];
|
||||
}
|
||||
inputIndex += index * m_inputStrides[m_shuffle[0]];
|
||||
return m_impl.coeff(inputIndex);
|
||||
}
|
||||
|
||||
/* template<int LoadMode>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
|
||||
{
|
||||
return m_impl.template packet<LoadMode>(index);
|
||||
}*/
|
||||
|
||||
Scalar* data() const { return NULL; }
|
||||
|
||||
protected:
|
||||
Dimensions m_dimensions;
|
||||
Shuffle m_shuffle;
|
||||
array<Index, NumDims> m_outputStrides;
|
||||
array<Index, NumDims> m_inputStrides;
|
||||
TensorEvaluator<ArgType, Device> m_impl;
|
||||
};
|
||||
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_CXX11_TENSOR_TENSOR_SHUFFLING_H
|
172
unsupported/Eigen/CXX11/src/Tensor/TensorStriding.h
Normal file
172
unsupported/Eigen/CXX11/src/Tensor/TensorStriding.h
Normal file
@ -0,0 +1,172 @@
|
||||
// This file is part of Eigen, a lightweight C++ template library
|
||||
// for linear algebra.
|
||||
//
|
||||
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
#ifndef EIGEN_CXX11_TENSOR_TENSOR_STRIDING_H
|
||||
#define EIGEN_CXX11_TENSOR_TENSOR_STRIDING_H
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
/** \class TensorStriding
|
||||
* \ingroup CXX11_Tensor_Module
|
||||
*
|
||||
* \brief Tensor striding class.
|
||||
*
|
||||
*
|
||||
*/
|
||||
namespace internal {
|
||||
template<typename Strides, typename XprType>
|
||||
struct traits<TensorStridingOp<Strides, XprType> > : public traits<XprType>
|
||||
{
|
||||
typedef typename XprType::Scalar Scalar;
|
||||
typedef typename internal::packet_traits<Scalar>::type Packet;
|
||||
typedef typename traits<XprType>::StorageKind StorageKind;
|
||||
typedef typename traits<XprType>::Index Index;
|
||||
typedef typename XprType::Nested Nested;
|
||||
typedef typename remove_reference<Nested>::type _Nested;
|
||||
};
|
||||
|
||||
template<typename Strides, typename XprType>
|
||||
struct eval<TensorStridingOp<Strides, XprType>, Eigen::Dense>
|
||||
{
|
||||
typedef const TensorStridingOp<Strides, XprType>& type;
|
||||
};
|
||||
|
||||
template<typename Strides, typename XprType>
|
||||
struct nested<TensorStridingOp<Strides, XprType>, 1, typename eval<TensorStridingOp<Strides, XprType> >::type>
|
||||
{
|
||||
typedef TensorStridingOp<Strides, XprType> type;
|
||||
};
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
|
||||
|
||||
template<typename Strides, typename XprType>
|
||||
class TensorStridingOp : public TensorBase<TensorStridingOp<Strides, XprType>, WriteAccessors>
|
||||
{
|
||||
public:
|
||||
typedef typename Eigen::internal::traits<TensorStridingOp>::Scalar Scalar;
|
||||
typedef typename Eigen::internal::traits<TensorStridingOp>::Packet Packet;
|
||||
typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
|
||||
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
||||
typedef typename XprType::PacketReturnType PacketReturnType;
|
||||
typedef typename Eigen::internal::nested<TensorStridingOp>::type Nested;
|
||||
typedef typename Eigen::internal::traits<TensorStridingOp>::StorageKind StorageKind;
|
||||
typedef typename Eigen::internal::traits<TensorStridingOp>::Index Index;
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorStridingOp(const XprType& expr, const Strides& dims)
|
||||
: m_xpr(expr), m_dims(dims) {}
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
const Strides& strides() const { return m_dims; }
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
const typename internal::remove_all<typename XprType::Nested>::type&
|
||||
expression() const { return m_xpr; }
|
||||
|
||||
template<typename OtherDerived>
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE TensorStridingOp& operator = (const OtherDerived& other)
|
||||
{
|
||||
typedef TensorAssignOp<TensorStridingOp, const OtherDerived> Assign;
|
||||
Assign assign(*this, other);
|
||||
internal::TensorExecutor<const Assign, DefaultDevice, false>::run(assign, DefaultDevice());
|
||||
return *this;
|
||||
}
|
||||
|
||||
protected:
|
||||
typename XprType::Nested m_xpr;
|
||||
const Strides m_dims;
|
||||
};
|
||||
|
||||
|
||||
// Eval as rvalue
|
||||
template<typename Strides, typename ArgType, typename Device>
|
||||
struct TensorEvaluator<const TensorStridingOp<Strides, ArgType>, Device>
|
||||
{
|
||||
typedef TensorStridingOp<Strides, ArgType> XprType;
|
||||
typedef typename XprType::Index Index;
|
||||
static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
|
||||
typedef DSizes<Index, NumDims> Dimensions;
|
||||
|
||||
enum {
|
||||
IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/false,
|
||||
PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/false,
|
||||
};
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
|
||||
: m_impl(op.expression(), device)
|
||||
{
|
||||
m_dimensions = m_impl.dimensions();
|
||||
for (int i = 0; i < NumDims; ++i) {
|
||||
m_dimensions[i] = ceilf(static_cast<float>(m_dimensions[i]) / op.strides()[i]);
|
||||
}
|
||||
|
||||
const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
|
||||
for (int i = 0; i < NumDims; ++i) {
|
||||
if (i > 0) {
|
||||
m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
|
||||
m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
|
||||
} else {
|
||||
m_inputStrides[0] = 1;
|
||||
m_outputStrides[0] = 1;
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < NumDims; ++i) {
|
||||
m_inputStrides[i] *= op.strides()[i];
|
||||
}
|
||||
}
|
||||
|
||||
// typedef typename XprType::Index Index;
|
||||
typedef typename XprType::Scalar Scalar;
|
||||
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
||||
typedef typename XprType::PacketReturnType PacketReturnType;
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
|
||||
m_impl.evalSubExprsIfNeeded(NULL);
|
||||
return true;
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
|
||||
m_impl.cleanup();
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
|
||||
{
|
||||
Index inputIndex = 0;
|
||||
for (int i = NumDims - 1; i > 0; --i) {
|
||||
const Index idx = index / m_outputStrides[i];
|
||||
inputIndex += idx * m_inputStrides[i];
|
||||
index -= idx * m_outputStrides[i];
|
||||
}
|
||||
inputIndex += index * m_inputStrides[0];
|
||||
return m_impl.coeff(inputIndex);
|
||||
}
|
||||
|
||||
/* template<int LoadMode>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
|
||||
{
|
||||
return m_impl.template packet<LoadMode>(index);
|
||||
}*/
|
||||
|
||||
Scalar* data() const { return NULL; }
|
||||
|
||||
protected:
|
||||
// Strides m_strides;
|
||||
Dimensions m_dimensions;
|
||||
array<Index, NumDims> m_outputStrides;
|
||||
array<Index, NumDims> m_inputStrides;
|
||||
TensorEvaluator<ArgType, Device> m_impl;
|
||||
};
|
||||
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_CXX11_TENSOR_TENSOR_STRIDING_H
|
Loading…
x
Reference in New Issue
Block a user