mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-05-09 22:39:05 +08:00
Added support for evaluation of tensor shuffling operations as lvalues
This commit is contained in:
parent
f50548e86a
commit
d43f737b4a
@ -222,19 +222,19 @@ class TensorBase<Derived, ReadOnlyAccessors>
|
|||||||
return TensorSlicingOp<const StartIndices, const Sizes, const Derived>(derived(), startIndices, sizes);
|
return TensorSlicingOp<const StartIndices, const Sizes, const Derived>(derived(), startIndices, sizes);
|
||||||
}
|
}
|
||||||
template <typename PaddingDimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
template <typename PaddingDimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
TensorPaddingOp<const PaddingDimensions, Derived>
|
const TensorPaddingOp<const PaddingDimensions, const Derived>
|
||||||
pad(const PaddingDimensions& padding) const {
|
pad(const PaddingDimensions& padding) const {
|
||||||
return TensorPaddingOp<const PaddingDimensions, Derived>(derived(), padding);
|
return TensorPaddingOp<const PaddingDimensions, const Derived>(derived(), padding);
|
||||||
}
|
}
|
||||||
template <typename Shuffle> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
template <typename Shuffle> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
TensorShufflingOp<const Shuffle, Derived>
|
const TensorShufflingOp<const Shuffle, const Derived>
|
||||||
shuffle(const Shuffle& shuffle) const {
|
shuffle(const Shuffle& shuffle) const {
|
||||||
return TensorShufflingOp<const Shuffle, Derived>(derived(), shuffle);
|
return TensorShufflingOp<const Shuffle, const Derived>(derived(), shuffle);
|
||||||
}
|
}
|
||||||
template <typename Strides> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
template <typename Strides> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
TensorStridingOp<const Strides, Derived>
|
const TensorStridingOp<const Strides, const Derived>
|
||||||
stride(const Strides& strides) const {
|
stride(const Strides& strides) const {
|
||||||
return TensorStridingOp<const Strides, Derived>(derived(), strides);
|
return TensorStridingOp<const Strides, const Derived>(derived(), strides);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Force the evaluation of the expression.
|
// Force the evaluation of the expression.
|
||||||
@ -244,6 +244,7 @@ class TensorBase<Derived, ReadOnlyAccessors>
|
|||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
template <typename Scalar, std::size_t NumIndices, int Options> friend class Tensor;
|
||||||
template <typename OtherDerived, int AccessLevel> friend class TensorBase;
|
template <typename OtherDerived, int AccessLevel> friend class TensorBase;
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
EIGEN_STRONG_INLINE const Derived& derived() const { return *static_cast<const Derived*>(this); }
|
EIGEN_STRONG_INLINE const Derived& derived() const { return *static_cast<const Derived*>(this); }
|
||||||
@ -258,6 +259,7 @@ class TensorBase<Derived, WriteAccessors> : public TensorBase<Derived, ReadOnlyA
|
|||||||
typedef Scalar CoeffReturnType;
|
typedef Scalar CoeffReturnType;
|
||||||
typedef typename internal::packet_traits<Scalar>::type PacketReturnType;
|
typedef typename internal::packet_traits<Scalar>::type PacketReturnType;
|
||||||
|
|
||||||
|
template <typename Scalar, std::size_t NumIndices, int Options> friend class Tensor;
|
||||||
template <typename OtherDerived, int AccessLevel> friend class TensorBase;
|
template <typename OtherDerived, int AccessLevel> friend class TensorBase;
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
@ -293,6 +295,11 @@ class TensorBase<Derived, WriteAccessors> : public TensorBase<Derived, ReadOnlyA
|
|||||||
slice(const StartIndices& startIndices, const Sizes& sizes) const {
|
slice(const StartIndices& startIndices, const Sizes& sizes) const {
|
||||||
return TensorSlicingOp<const StartIndices, const Sizes, Derived>(derived(), startIndices, sizes);
|
return TensorSlicingOp<const StartIndices, const Sizes, Derived>(derived(), startIndices, sizes);
|
||||||
}
|
}
|
||||||
|
template <typename Shuffle> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
|
TensorShufflingOp<const Shuffle, Derived>
|
||||||
|
shuffle(const Shuffle& shuffle) const {
|
||||||
|
return TensorShufflingOp<const Shuffle, Derived>(derived(), shuffle);
|
||||||
|
}
|
||||||
|
|
||||||
// Select the device on which to evaluate the expression.
|
// Select the device on which to evaluate the expression.
|
||||||
template <typename DeviceType>
|
template <typename DeviceType>
|
||||||
|
@ -48,7 +48,7 @@ struct nested<TensorShufflingOp<Shuffle, XprType>, 1, typename eval<TensorShuffl
|
|||||||
|
|
||||||
|
|
||||||
template<typename Shuffle, typename XprType>
|
template<typename Shuffle, typename XprType>
|
||||||
class TensorShufflingOp : public TensorBase<TensorShufflingOp<Shuffle, XprType>, WriteAccessors>
|
class TensorShufflingOp : public TensorBase<TensorShufflingOp<Shuffle, XprType> >
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
typedef typename Eigen::internal::traits<TensorShufflingOp>::Scalar Scalar;
|
typedef typename Eigen::internal::traits<TensorShufflingOp>::Scalar Scalar;
|
||||||
@ -94,33 +94,38 @@ struct TensorEvaluator<const TensorShufflingOp<Shuffle, ArgType>, Device>
|
|||||||
typedef typename XprType::Index Index;
|
typedef typename XprType::Index Index;
|
||||||
static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
|
static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
|
||||||
typedef DSizes<Index, NumDims> Dimensions;
|
typedef DSizes<Index, NumDims> Dimensions;
|
||||||
|
typedef typename XprType::Scalar Scalar;
|
||||||
|
|
||||||
enum {
|
enum {
|
||||||
IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/false,
|
IsAligned = true,
|
||||||
PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/false,
|
PacketAccess = (internal::packet_traits<Scalar>::size > 1),
|
||||||
};
|
};
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
|
||||||
: m_impl(op.expression(), device), m_shuffle(op.shuffle())
|
: m_impl(op.expression(), device)
|
||||||
{
|
{
|
||||||
const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
|
const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
|
||||||
|
const Shuffle& shuffle = op.shuffle();
|
||||||
for (int i = 0; i < NumDims; ++i) {
|
for (int i = 0; i < NumDims; ++i) {
|
||||||
m_dimensions[i] = input_dims[m_shuffle[i]];
|
m_dimensions[i] = input_dims[shuffle[i]];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array<Index, NumDims> inputStrides;
|
||||||
|
|
||||||
for (int i = 0; i < NumDims; ++i) {
|
for (int i = 0; i < NumDims; ++i) {
|
||||||
if (i > 0) {
|
if (i > 0) {
|
||||||
m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
|
inputStrides[i] = inputStrides[i-1] * input_dims[i-1];
|
||||||
m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
|
m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
|
||||||
} else {
|
} else {
|
||||||
m_inputStrides[0] = 1;
|
inputStrides[0] = 1;
|
||||||
m_outputStrides[0] = 1;
|
m_outputStrides[0] = 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for (int i = 0; i < NumDims; ++i) {
|
||||||
|
m_inputStrides[i] = inputStrides[shuffle[i]];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// typedef typename XprType::Index Index;
|
|
||||||
typedef typename XprType::Scalar Scalar;
|
|
||||||
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
||||||
typedef typename XprType::PacketReturnType PacketReturnType;
|
typedef typename XprType::PacketReturnType PacketReturnType;
|
||||||
|
|
||||||
@ -136,33 +141,90 @@ struct TensorEvaluator<const TensorShufflingOp<Shuffle, ArgType>, Device>
|
|||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
|
||||||
{
|
{
|
||||||
Index inputIndex = 0;
|
return m_impl.coeff(srcCoeff(index));
|
||||||
for (int i = NumDims - 1; i > 0; --i) {
|
|
||||||
const Index idx = index / m_outputStrides[i];
|
|
||||||
inputIndex += idx * m_inputStrides[m_shuffle[i]];
|
|
||||||
index -= idx * m_outputStrides[i];
|
|
||||||
}
|
|
||||||
inputIndex += index * m_inputStrides[m_shuffle[0]];
|
|
||||||
return m_impl.coeff(inputIndex);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* template<int LoadMode>
|
template<int LoadMode>
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
|
||||||
{
|
{
|
||||||
return m_impl.template packet<LoadMode>(index);
|
const int packetSize = internal::unpacket_traits<PacketReturnType>::size;
|
||||||
}*/
|
EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
|
||||||
|
eigen_assert(index+packetSize-1 < dimensions().TotalSize());
|
||||||
|
|
||||||
|
EIGEN_ALIGN_DEFAULT typename internal::remove_const<CoeffReturnType>::type values[packetSize];
|
||||||
|
for (int i = 0; i < packetSize; ++i) {
|
||||||
|
values[i] = coeff(index+i);
|
||||||
|
}
|
||||||
|
PacketReturnType rslt = internal::pload<PacketReturnType>(values);
|
||||||
|
return rslt;
|
||||||
|
}
|
||||||
|
|
||||||
Scalar* data() const { return NULL; }
|
Scalar* data() const { return NULL; }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index srcCoeff(Index index) const
|
||||||
|
{
|
||||||
|
Index inputIndex = 0;
|
||||||
|
for (int i = NumDims - 1; i > 0; --i) {
|
||||||
|
const Index idx = index / m_outputStrides[i];
|
||||||
|
inputIndex += idx * m_inputStrides[i];
|
||||||
|
index -= idx * m_outputStrides[i];
|
||||||
|
}
|
||||||
|
return inputIndex + index * m_inputStrides[0];
|
||||||
|
}
|
||||||
|
|
||||||
Dimensions m_dimensions;
|
Dimensions m_dimensions;
|
||||||
Shuffle m_shuffle;
|
|
||||||
array<Index, NumDims> m_outputStrides;
|
array<Index, NumDims> m_outputStrides;
|
||||||
array<Index, NumDims> m_inputStrides;
|
array<Index, NumDims> m_inputStrides;
|
||||||
TensorEvaluator<ArgType, Device> m_impl;
|
TensorEvaluator<ArgType, Device> m_impl;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
// Eval as lvalue
|
||||||
|
template<typename Shuffle, typename ArgType, typename Device>
|
||||||
|
struct TensorEvaluator<TensorShufflingOp<Shuffle, ArgType>, Device>
|
||||||
|
: public TensorEvaluator<const TensorShufflingOp<Shuffle, ArgType>, Device>
|
||||||
|
{
|
||||||
|
typedef TensorEvaluator<const TensorShufflingOp<Shuffle, ArgType>, Device> Base;
|
||||||
|
|
||||||
|
typedef TensorShufflingOp<Shuffle, ArgType> XprType;
|
||||||
|
typedef typename XprType::Index Index;
|
||||||
|
static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
|
||||||
|
typedef DSizes<Index, NumDims> Dimensions;
|
||||||
|
typedef typename XprType::Scalar Scalar;
|
||||||
|
|
||||||
|
enum {
|
||||||
|
IsAligned = true,
|
||||||
|
PacketAccess = (internal::packet_traits<Scalar>::size > 1),
|
||||||
|
};
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
|
||||||
|
: Base(op, device)
|
||||||
|
{ }
|
||||||
|
|
||||||
|
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
||||||
|
typedef typename XprType::PacketReturnType PacketReturnType;
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType& coeffRef(Index index)
|
||||||
|
{
|
||||||
|
return this->m_impl.coeffRef(this->srcCoeff(index));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int StoreMode> EIGEN_STRONG_INLINE
|
||||||
|
void writePacket(Index index, const PacketReturnType& x)
|
||||||
|
{
|
||||||
|
static const int packetSize = internal::unpacket_traits<PacketReturnType>::size;
|
||||||
|
EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
|
||||||
|
|
||||||
|
EIGEN_ALIGN_DEFAULT typename internal::remove_const<CoeffReturnType>::type values[packetSize];
|
||||||
|
internal::pstore<CoeffReturnType, PacketReturnType>(values, x);
|
||||||
|
for (int i = 0; i < packetSize; ++i) {
|
||||||
|
this->coeffRef(index+i) = values[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
|
||||||
#endif // EIGEN_CXX11_TENSOR_TENSOR_SHUFFLING_H
|
#endif // EIGEN_CXX11_TENSOR_TENSOR_SHUFFLING_H
|
||||||
|
Loading…
x
Reference in New Issue
Block a user