Support for in place evaluation of expressions containing slicing and reshaping operations

This commit is contained in:
Benoit Steiner 2014-08-13 08:27:58 -07:00
parent b1892ab14d
commit 1aa2bf8274

View File

@ -103,13 +103,14 @@ struct TensorEvaluator<const TensorReshapingOp<NewDimensions, ArgType>, Device>
{ }
typedef typename XprType::Index Index;
typedef typename XprType::Scalar Scalar;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::PacketReturnType PacketReturnType;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeeded() {
m_impl.evalSubExprsIfNeeded();
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
return m_impl.evalSubExprsIfNeeded(data);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
m_impl.cleanup();
@ -126,6 +127,8 @@ struct TensorEvaluator<const TensorReshapingOp<NewDimensions, ArgType>, Device>
return m_impl.template packet<LoadMode>(index);
}
Scalar* data() const { return NULL; }
protected:
NewDimensions m_dimensions;
TensorEvaluator<ArgType, Device> m_impl;
@ -150,13 +153,14 @@ struct TensorEvaluator<TensorReshapingOp<NewDimensions, ArgType>, Device>
{ }
typedef typename XprType::Index Index;
typedef typename XprType::Scalar Scalar;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::PacketReturnType PacketReturnType;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeeded() {
m_impl.evalSubExprsIfNeeded();
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
return m_impl.evalSubExprsIfNeeded(data);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
m_impl.cleanup();
@ -182,6 +186,8 @@ struct TensorEvaluator<TensorReshapingOp<NewDimensions, ArgType>, Device>
return m_impl.template packet<LoadMode>(index);
}
Scalar* data() const { return NULL; }
private:
NewDimensions m_dimensions;
TensorEvaluator<ArgType, Device> m_impl;
@ -306,14 +312,16 @@ struct TensorEvaluator<const TensorSlicingOp<StartIndices, Sizes, ArgType>, Devi
}
typedef typename XprType::Index Index;
typedef typename XprType::Scalar Scalar;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::PacketReturnType PacketReturnType;
typedef Sizes Dimensions;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeeded() {
m_impl.evalSubExprsIfNeeded();
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) {
m_impl.evalSubExprsIfNeeded(NULL);
return true;
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
m_impl.cleanup();
@ -366,6 +374,8 @@ struct TensorEvaluator<const TensorSlicingOp<StartIndices, Sizes, ArgType>, Devi
}
}
Scalar* data() const { return NULL; }
private:
Dimensions m_dimensions;
array<Index, NumDims> m_outputStrides;
@ -415,14 +425,16 @@ struct TensorEvaluator<TensorSlicingOp<StartIndices, Sizes, ArgType>, Device>
}
typedef typename XprType::Index Index;
typedef typename XprType::Scalar Scalar;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::PacketReturnType PacketReturnType;
typedef Sizes Dimensions;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeeded() {
m_impl.evalSubExprsIfNeeded();
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) {
m_impl.evalSubExprsIfNeeded(NULL);
return true;
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
m_impl.cleanup();
@ -515,6 +527,8 @@ struct TensorEvaluator<TensorSlicingOp<StartIndices, Sizes, ArgType>, Device>
}
}
Scalar* data() const { return NULL; }
private:
Dimensions m_dimensions;
array<Index, NumDims> m_outputStrides;