TensorEvaluators are now typed on the device: this will make it possible to use partial template specialization to optimize the strategy of each evaluator for each device type.

Started work on partial evaluations.
Benoit Steiner 2014-06-10 09:14:44 -07:00
parent a77458a8ff
commit 925fb6b937
9 changed files with 129 additions and 102 deletions
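To illustrate what the new Device template parameter enables, here is a minimal standalone sketch (not part of this commit; the names are simplified stand-ins) of an evaluator whose strategy is selected by partial template specialization on the device type:

// Hypothetical illustration only.
struct DefaultDevice {};
struct GpuDevice {};

template <typename Expr, typename Device = DefaultDevice>
struct Evaluator {
  // generic strategy: walk the coefficients with a single CPU thread
};

template <typename Expr>
struct Evaluator<Expr, GpuDevice> {
  // GPU strategy: launch a kernel, work on device memory, etc.
};

With this shape in place, a GPU build can pick the second definition without touching the generic one.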

View File

@@ -32,15 +32,15 @@ namespace Eigen {
namespace internal {
// Default strategy: the expressions are evaluated with a single cpu thread.
template<typename Derived1, typename Derived2, bool Vectorizable = TensorEvaluator<Derived1>::PacketAccess & TensorEvaluator<Derived2>::PacketAccess>
template<typename Derived1, typename Derived2, typename Device = DefaultDevice, bool Vectorizable = TensorEvaluator<Derived1, Device>::PacketAccess & TensorEvaluator<Derived2, Device>::PacketAccess>
struct TensorAssign
{
typedef typename Derived1::Index Index;
EIGEN_DEVICE_FUNC
static inline void run(Derived1& dst, const Derived2& src)
static inline void run(Derived1& dst, const Derived2& src, const Device& device = Device())
{
TensorEvaluator<Derived1> evalDst(dst);
TensorEvaluator<Derived2> evalSrc(src);
TensorEvaluator<Derived1, Device> evalDst(dst, device);
TensorEvaluator<Derived2, Device> evalSrc(src, device);
const Index size = dst.size();
for (Index i = 0; i < size; ++i) {
evalDst.coeffRef(i) = evalSrc.coeff(i);
@@ -49,19 +49,19 @@ struct TensorAssign
};
template<typename Derived1, typename Derived2>
struct TensorAssign<Derived1, Derived2, true>
template<typename Derived1, typename Derived2, typename Device>
struct TensorAssign<Derived1, Derived2, Device, true>
{
typedef typename Derived1::Index Index;
static inline void run(Derived1& dst, const Derived2& src)
static inline void run(Derived1& dst, const Derived2& src, const Device& device = Device())
{
TensorEvaluator<Derived1> evalDst(dst);
TensorEvaluator<Derived2> evalSrc(src);
TensorEvaluator<Derived1, Device> evalDst(dst, device);
TensorEvaluator<Derived2, Device> evalSrc(src, device);
const Index size = dst.size();
static const int LhsStoreMode = TensorEvaluator<Derived1>::IsAligned ? Aligned : Unaligned;
static const int RhsLoadMode = TensorEvaluator<Derived2>::IsAligned ? Aligned : Unaligned;
static const int PacketSize = unpacket_traits<typename TensorEvaluator<Derived1>::PacketReturnType>::size;
static const int LhsStoreMode = TensorEvaluator<Derived1, Device>::IsAligned ? Aligned : Unaligned;
static const int RhsLoadMode = TensorEvaluator<Derived2, Device>::IsAligned ? Aligned : Unaligned;
static const int PacketSize = unpacket_traits<typename TensorEvaluator<Derived1, Device>::PacketReturnType>::size;
const int VectorizedSize = (size / PacketSize) * PacketSize;
for (Index i = 0; i < VectorizedSize; i += PacketSize) {
@@ -116,12 +116,12 @@ struct TensorAssignMultiThreaded
typedef typename Derived1::Index Index;
static inline void run(Derived1& dst, const Derived2& src, const ThreadPoolDevice& device)
{
TensorEvaluator<Derived1> evalDst(dst);
TensorEvaluator<Derived2> evalSrc(src);
TensorEvaluator<Derived1, DefaultDevice> evalDst(dst, DefaultDevice());
TensorEvaluator<Derived2, DefaultDevice> evalSrc(src, DefaultDevice());
const Index size = dst.size();
static const bool Vectorizable = TensorEvaluator<Derived1>::PacketAccess & TensorEvaluator<Derived2>::PacketAccess;
static const int PacketSize = Vectorizable ? unpacket_traits<typename TensorEvaluator<Derived1>::PacketReturnType>::size : 1;
static const bool Vectorizable = TensorEvaluator<Derived1, DefaultDevice>::PacketAccess & TensorEvaluator<Derived2, DefaultDevice>::PacketAccess;
static const int PacketSize = Vectorizable ? unpacket_traits<typename TensorEvaluator<Derived1, DefaultDevice>::PacketReturnType>::size : 1;
int blocksz = static_cast<int>(ceil(static_cast<float>(size)/device.numThreads()) + PacketSize - 1);
const Index blocksize = std::max<Index>(PacketSize, (blocksz - (blocksz % PacketSize)));
@@ -131,7 +131,7 @@ struct TensorAssignMultiThreaded
vector<std::future<void> > results;
results.reserve(numblocks);
for (int i = 0; i < numblocks; ++i) {
results.push_back(std::async(std::launch::async, &EvalRange<TensorEvaluator<Derived1>, TensorEvaluator<Derived2>, Index>::run, evalDst, evalSrc, i*blocksize, (i+1)*blocksize));
results.push_back(std::async(std::launch::async, &EvalRange<TensorEvaluator<Derived1, DefaultDevice>, TensorEvaluator<Derived2, DefaultDevice>, Index>::run, evalDst, evalSrc, i*blocksize, (i+1)*blocksize));
}
for (int i = 0; i < numblocks; ++i) {
@@ -167,19 +167,19 @@ struct TensorAssignGpu
typedef typename Derived1::Index Index;
static inline void run(Derived1& dst, const Derived2& src, const GpuDevice& device)
{
TensorEvaluator<Derived1> evalDst(dst);
TensorEvaluator<Derived2> evalSrc(src);
TensorEvaluator<Derived1, GpuDevice> evalDst(dst, device);
TensorEvaluator<Derived2, GpuDevice> evalSrc(src, device);
const Index size = dst.size();
const int block_size = std::min<int>(size, 32*32);
const int num_blocks = size / block_size;
EigenMetaKernelNoCheck<TensorEvaluator<Derived1>, TensorEvaluator<Derived2> > <<<num_blocks, block_size, 0, device.stream()>>>(evalDst, evalSrc);
EigenMetaKernelNoCheck<TensorEvaluator<Derived1, GpuDevice>, TensorEvaluator<Derived2, GpuDevice> > <<<num_blocks, block_size, 0, device.stream()>>>(evalDst, evalSrc);
const int remaining_items = size % block_size;
if (remaining_items > 0) {
const int peel_start_offset = num_blocks * block_size;
const int peel_block_size = std::min<int>(size, 32);
const int peel_num_blocks = (remaining_items + peel_block_size - 1) / peel_block_size;
EigenMetaKernelPeel<TensorEvaluator<Derived1>, TensorEvaluator<Derived2> > <<<peel_num_blocks, peel_block_size, 0, device.stream()>>>(evalDst, evalSrc, peel_start_offset, size);
EigenMetaKernelPeel<TensorEvaluator<Derived1, GpuDevice>, TensorEvaluator<Derived2, GpuDevice> > <<<peel_num_blocks, peel_block_size, 0, device.stream()>>>(evalDst, evalSrc, peel_start_offset, size);
}
}
};
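As a hedged sketch of how these assign helpers are meant to be invoked with an explicit device (the wrapper function below is hypothetical; internal::TensorAssign and its run() signature come from the hunks above):

// Hypothetical call site: evaluate src into dst on a given device.
// The defaulted Vectorizable parameter selects the packet specialization
// when both evaluators report PacketAccess for that device.
template <typename Dst, typename Src, typename Device>
void assign_on(Dst& dst, const Src& src, const Device& d) {
  Eigen::internal::TensorAssign<Dst, const Src, Device>::run(dst, src, d);
}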

View File

@@ -198,19 +198,25 @@ class TensorBase<Derived, ReadOnlyAccessors>
}
// Coefficient-wise ternary operators.
template<typename ThenDerived, typename ElseDerived>
inline const TensorSelectOp<const Derived, const ThenDerived, const ElseDerived>
template<typename ThenDerived, typename ElseDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorSelectOp<const Derived, const ThenDerived, const ElseDerived>
select(const ThenDerived& thenTensor, const ElseDerived& elseTensor) const {
return TensorSelectOp<const Derived, const ThenDerived, const ElseDerived>(derived(), thenTensor.derived(), elseTensor.derived());
}
// Morphing operators (slicing tbd).
template <typename NewDimensions>
inline const TensorReshapingOp<const Derived, const NewDimensions>
template <typename NewDimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorReshapingOp<const Derived, const NewDimensions>
reshape(const NewDimensions& newDimensions) const {
return TensorReshapingOp<const Derived, const NewDimensions>(derived(), newDimensions);
}
// Force the evaluation of the expression.
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorForcedEvalOp<const Derived> eval() const {
return TensorForcedEvalOp<const Derived>(derived());
}
protected:
template <typename OtherDerived, int AccessLevel> friend class TensorBase;
EIGEN_DEVICE_FUNC
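A hedged usage sketch for this part of the API (the header path and the tensor setup are assumptions); note that eval() only wraps the expression in a TensorForcedEvalOp, since the forced evaluation itself is still work in progress in this commit:

#include <unsupported/Eigen/CXX11/Tensor>  // assumed header path

void forced_eval_example() {
  Eigen::Tensor<float, 2> a(8, 8), b(8, 8), c(8, 8);
  c = a + b;                      // lazy: the sum is evaluated during the assignment
  // eval() instead wraps a subexpression in a TensorForcedEvalOp, asking for it to
  // be materialized eagerly before the enclosing expression consumes it:
  // c = (a + b).eval();          // hypothetical, once the forced evaluator lands
}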

View File

@@ -102,31 +102,31 @@ template <> struct max_n_1<0> {
};
template<typename Indices, typename LeftArgType, typename RightArgType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType> >
template<typename Indices, typename LeftArgType, typename RightArgType, typename Device>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device>
{
typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
static const int NumDims = max_n_1<TensorEvaluator<LeftArgType>::Dimensions::count + TensorEvaluator<RightArgType>::Dimensions::count - 2 * internal::array_size<Indices>::value>::size;
static const int NumDims = max_n_1<TensorEvaluator<LeftArgType, Device>::Dimensions::count + TensorEvaluator<RightArgType, Device>::Dimensions::count - 2 * internal::array_size<Indices>::value>::size;
typedef typename XprType::Index Index;
typedef DSizes<Index, NumDims> Dimensions;
enum {
IsAligned = TensorEvaluator<LeftArgType>::IsAligned & TensorEvaluator<RightArgType>::IsAligned,
IsAligned = TensorEvaluator<LeftArgType, Device>::IsAligned & TensorEvaluator<RightArgType, Device>::IsAligned,
PacketAccess = /*TensorEvaluator<LeftArgType>::PacketAccess & TensorEvaluator<RightArgType>::PacketAccess */
false,
};
TensorEvaluator(const XprType& op)
: m_leftImpl(op.lhsExpression()), m_rightImpl(op.rhsExpression())
TensorEvaluator(const XprType& op, const Device& device)
: m_leftImpl(op.lhsExpression(), device), m_rightImpl(op.rhsExpression(), device)
{
Index index = 0;
Index stride = 1;
m_shiftright = 1;
int skipped = 0;
const typename TensorEvaluator<LeftArgType>::Dimensions& left_dims = m_leftImpl.dimensions();
for (int i = 0; i < TensorEvaluator<LeftArgType>::Dimensions::count; ++i) {
const typename TensorEvaluator<LeftArgType, Device>::Dimensions& left_dims = m_leftImpl.dimensions();
for (int i = 0; i < TensorEvaluator<LeftArgType, Device>::Dimensions::count; ++i) {
bool skip = false;
for (int j = 0; j < internal::array_size<Indices>::value; ++j) {
if (op.indices()[j].first == i) {
@@ -148,8 +148,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
stride = 1;
skipped = 0;
const typename TensorEvaluator<RightArgType>::Dimensions& right_dims = m_rightImpl.dimensions();
for (int i = 0; i < TensorEvaluator<RightArgType>::Dimensions::count; ++i) {
const typename TensorEvaluator<RightArgType, Device>::Dimensions& right_dims = m_rightImpl.dimensions();
for (int i = 0; i < TensorEvaluator<RightArgType, Device>::Dimensions::count; ++i) {
bool skip = false;
for (int j = 0; j < internal::array_size<Indices>::value; ++j) {
if (op.indices()[j].second == i) {
@@ -168,7 +168,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
}
// Scalar case
if (TensorEvaluator<LeftArgType>::Dimensions::count + TensorEvaluator<LeftArgType>::Dimensions::count == 2 * internal::array_size<Indices>::value) {
if (TensorEvaluator<LeftArgType, Device>::Dimensions::count + TensorEvaluator<LeftArgType, Device>::Dimensions::count == 2 * internal::array_size<Indices>::value) {
m_dimensions[0] = 1;
}
}
@@ -223,8 +223,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
array<Index, internal::array_size<Indices>::value> m_stitchsize;
Index m_shiftright;
Dimensions m_dimensions;
TensorEvaluator<LeftArgType> m_leftImpl;
TensorEvaluator<RightArgType> m_rightImpl;
TensorEvaluator<LeftArgType, Device> m_leftImpl;
TensorEvaluator<RightArgType, Device> m_rightImpl;
};
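For the NumDims computation above, a small worked example (the header path and the tensor setup are assumptions, not taken from this commit):

#include <unsupported/Eigen/CXX11/Tensor>  // assumed header path

// Contracting a rank-3 tensor with a rank-2 tensor over one index pair:
// result rank = 3 + 2 - 2*1 = 3. max_n_1 only matters when this drops to 0
// (the fully contracted scalar case), which is clamped to a single coefficient.
Eigen::Tensor<float, 3> a(4, 5, 6);
Eigen::Tensor<float, 2> b(6, 7);
// hypothetical: contracting a's dimension 2 with b's dimension 0
// would give result dimensions (4, 5, 7)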

View File

@@ -94,27 +94,27 @@ class TensorConvolutionOp : public TensorBase<TensorConvolutionOp<Indices, Input
};
template<typename Indices, typename InputArgType, typename KernelArgType>
struct TensorEvaluator<const TensorConvolutionOp<Indices, InputArgType, KernelArgType> >
template<typename Indices, typename InputArgType, typename KernelArgType, typename Device>
struct TensorEvaluator<const TensorConvolutionOp<Indices, InputArgType, KernelArgType>, Device>
{
typedef TensorConvolutionOp<Indices, InputArgType, KernelArgType> XprType;
static const int NumDims = TensorEvaluator<InputArgType>::Dimensions::count;
static const int NumDims = TensorEvaluator<InputArgType, Device>::Dimensions::count;
static const int KernelDims = Indices::size;
typedef typename XprType::Index Index;
typedef DSizes<Index, NumDims> Dimensions;
enum {
IsAligned = TensorEvaluator<InputArgType>::IsAligned & TensorEvaluator<KernelArgType>::IsAligned,
IsAligned = TensorEvaluator<InputArgType, Device>::IsAligned & TensorEvaluator<KernelArgType, Device>::IsAligned,
PacketAccess = /*TensorEvaluator<InputArgType>::PacketAccess & TensorEvaluator<KernelArgType>::PacketAccess */
false,
};
TensorEvaluator(const XprType& op)
: m_inputImpl(op.inputExpression()), m_kernelImpl(op.kernelExpression()), m_dimensions(op.inputExpression().dimensions())
TensorEvaluator(const XprType& op, const Device& device)
: m_inputImpl(op.inputExpression(), device), m_kernelImpl(op.kernelExpression(), device), m_dimensions(op.inputExpression().dimensions())
{
const typename TensorEvaluator<InputArgType>::Dimensions& input_dims = m_inputImpl.dimensions();
const typename TensorEvaluator<KernelArgType>::Dimensions& kernel_dims = m_kernelImpl.dimensions();
const typename TensorEvaluator<InputArgType, Device>::Dimensions& input_dims = m_inputImpl.dimensions();
const typename TensorEvaluator<KernelArgType, Device>::Dimensions& kernel_dims = m_kernelImpl.dimensions();
for (int i = 0; i < NumDims; ++i) {
if (i > 0) {
@@ -200,8 +200,8 @@ struct TensorEvaluator<const TensorConvolutionOp<Indices, InputArgType, KernelAr
array<Index, KernelDims> m_indexStride;
array<Index, KernelDims> m_kernelStride;
Dimensions m_dimensions;
TensorEvaluator<InputArgType> m_inputImpl;
TensorEvaluator<KernelArgType> m_kernelImpl;
TensorEvaluator<InputArgType, Device> m_inputImpl;
TensorEvaluator<KernelArgType, Device> m_kernelImpl;
};
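Similarly, a hedged worked example of the dimension bookkeeping in this evaluator (the sizes, the header path, and the valid-convolution assumption are illustrative only):

#include <unsupported/Eigen/CXX11/Tensor>  // assumed header path

// Assuming "valid" convolution semantics, a (20, 30) input convolved with a
// (3, 5) kernel over both dimensions yields (20 - 3 + 1, 30 - 5 + 1) = (18, 26);
// NumDims stays equal to the input rank.
Eigen::Tensor<float, 2> input(20, 30);
Eigen::Tensor<float, 2> kernel(3, 5);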

View File

@@ -31,7 +31,7 @@ template <typename ExpressionType, typename DeviceType> class TensorDevice {
template<typename OtherDerived>
EIGEN_STRONG_INLINE TensorDevice& operator=(const OtherDerived& other) {
internal::TensorAssign<ExpressionType, const OtherDerived>::run(m_expression, other);
internal::TensorAssign<ExpressionType, const OtherDerived, DeviceType>::run(m_expression, other, m_device);
return *this;
}
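A hedged usage sketch of the device-targeted assignment this operator enables (the device() accessor on the tensor, the header path, and thread-pool availability are assumptions):

#include <unsupported/Eigen/CXX11/Tensor>  // assumed header path

void device_assign_example() {
  Eigen::Tensor<float, 2> a(30, 40), b(30, 40), c(30, 40);
  Eigen::ThreadPoolDevice pool(4);  // 4 worker threads
  // device() returns a TensorDevice wrapper whose operator= forwards to the
  // device-aware assign helper, passing pool as the device argument.
  c.device(pool) = a + b;
}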

View File

@@ -15,6 +15,12 @@ namespace Eigen {
// Default device for the machine (typically a single cpu core)
struct DefaultDevice {
EIGEN_STRONG_INLINE void* allocate(size_t num_bytes) const {
return internal::aligned_malloc(num_bytes);
}
EIGEN_STRONG_INLINE void deallocate(void* buffer) const {
internal::aligned_free(buffer);
}
};
@@ -24,12 +30,17 @@ struct DefaultDevice {
struct ThreadPoolDevice {
ThreadPoolDevice(/*ThreadPool* pool, */size_t num_cores) : /*pool_(pool), */num_threads_(num_cores) { }
size_t numThreads() const { return num_threads_; }
/*ThreadPool* threadPool() const { return pool_; }*/
EIGEN_STRONG_INLINE void* allocate(size_t num_bytes) const {
return internal::aligned_malloc(num_bytes);
}
EIGEN_STRONG_INLINE void deallocate(void* buffer) const {
internal::aligned_free(buffer);
}
private:
// todo: NUMA, ...
size_t num_threads_;
/*ThreadPool* pool_;*/
};
#endif
@@ -40,7 +51,16 @@ struct GpuDevice {
// The cudastream is not owned: the caller is responsible for its initialization and eventual destruction.
GpuDevice(const cudaStream_t* stream) : stream_(stream) { eigen_assert(stream); }
const cudaStream_t& stream() const { return *stream_; }
EIGEN_STRONG_INLINE const cudaStream_t& stream() const { return *stream_; }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void* allocate(size_t num_bytes) const {
void* result;
cudaMalloc(&result, num_bytes);
return result;
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void deallocate(void* buffer) const {
cudaFree(buffer);
}
private:
// TODO: multigpu.
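A hedged sketch of the allocator surface the three device types now share (the helper function is hypothetical; only allocate()/deallocate() come from the hunks above):

#include <cstddef>

// allocate()/deallocate() behave like aligned_malloc/aligned_free on the CPU
// devices and like cudaMalloc/cudaFree on GpuDevice, so scratch-buffer code
// can be written once and instantiated per device type.
template <typename Device>
void scratch_roundtrip(const Device& d, std::size_t bytes) {
  void* buf = d.allocate(bytes);
  // ... fill / consume buf on the device ...
  d.deallocate(buf);
}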

View File

@@ -23,7 +23,7 @@ namespace Eigen {
* leading to lvalues (slicing, reshaping, etc...)
*/
template<typename Derived>
template<typename Derived, typename Device>
struct TensorEvaluator
{
typedef typename Derived::Index Index;
@@ -38,7 +38,7 @@ struct TensorEvaluator
PacketAccess = Derived::PacketAccess,
};
EIGEN_DEVICE_FUNC TensorEvaluator(Derived& m)
EIGEN_DEVICE_FUNC TensorEvaluator(Derived& m, const Device&)
: m_data(const_cast<Scalar*>(m.data())), m_dims(m.dimensions())
{ }
@@ -73,8 +73,8 @@ struct TensorEvaluator
// -------------------- CwiseNullaryOp --------------------
template<typename NullaryOp, typename ArgType>
struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, ArgType> >
template<typename NullaryOp, typename ArgType, typename Device>
struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, ArgType>, Device>
{
typedef TensorCwiseNullaryOp<NullaryOp, ArgType> XprType;
@@ -84,14 +84,14 @@ struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, ArgType> >
};
EIGEN_DEVICE_FUNC
TensorEvaluator(const XprType& op)
: m_functor(op.functor()), m_argImpl(op.nestedExpression())
TensorEvaluator(const XprType& op, const Device& device)
: m_functor(op.functor()), m_argImpl(op.nestedExpression(), device)
{ }
typedef typename XprType::Index Index;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::PacketReturnType PacketReturnType;
typedef typename TensorEvaluator<ArgType>::Dimensions Dimensions;
typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;
EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_argImpl.dimensions(); }
@@ -108,32 +108,32 @@ struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, ArgType> >
private:
const NullaryOp m_functor;
TensorEvaluator<ArgType> m_argImpl;
TensorEvaluator<ArgType, Device> m_argImpl;
};
// -------------------- CwiseUnaryOp --------------------
template<typename UnaryOp, typename ArgType>
struct TensorEvaluator<const TensorCwiseUnaryOp<UnaryOp, ArgType> >
template<typename UnaryOp, typename ArgType, typename Device>
struct TensorEvaluator<const TensorCwiseUnaryOp<UnaryOp, ArgType>, Device>
{
typedef TensorCwiseUnaryOp<UnaryOp, ArgType> XprType;
enum {
IsAligned = TensorEvaluator<ArgType>::IsAligned,
PacketAccess = TensorEvaluator<ArgType>::PacketAccess & internal::functor_traits<UnaryOp>::PacketAccess,
IsAligned = TensorEvaluator<ArgType, Device>::IsAligned,
PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess & internal::functor_traits<UnaryOp>::PacketAccess,
};
EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op)
EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
: m_functor(op.functor()),
m_argImpl(op.nestedExpression())
m_argImpl(op.nestedExpression(), device)
{ }
typedef typename XprType::Index Index;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::PacketReturnType PacketReturnType;
typedef typename TensorEvaluator<ArgType>::Dimensions Dimensions;
typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;
EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_argImpl.dimensions(); }
@@ -150,33 +150,33 @@ struct TensorEvaluator<const TensorCwiseUnaryOp<UnaryOp, ArgType> >
private:
const UnaryOp m_functor;
TensorEvaluator<ArgType> m_argImpl;
TensorEvaluator<ArgType, Device> m_argImpl;
};
// -------------------- CwiseBinaryOp --------------------
template<typename BinaryOp, typename LeftArgType, typename RightArgType>
struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArgType> >
template<typename BinaryOp, typename LeftArgType, typename RightArgType, typename Device>
struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArgType>, Device>
{
typedef TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArgType> XprType;
enum {
IsAligned = TensorEvaluator<LeftArgType>::IsAligned & TensorEvaluator<RightArgType>::IsAligned,
PacketAccess = TensorEvaluator<LeftArgType>::PacketAccess & TensorEvaluator<RightArgType>::PacketAccess &
IsAligned = TensorEvaluator<LeftArgType, Device>::IsAligned & TensorEvaluator<RightArgType, Device>::IsAligned,
PacketAccess = TensorEvaluator<LeftArgType, Device>::PacketAccess & TensorEvaluator<RightArgType, Device>::PacketAccess &
internal::functor_traits<BinaryOp>::PacketAccess,
};
EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op)
EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
: m_functor(op.functor()),
m_leftImpl(op.lhsExpression()),
m_rightImpl(op.rhsExpression())
m_leftImpl(op.lhsExpression(), device),
m_rightImpl(op.rhsExpression(), device)
{ }
typedef typename XprType::Index Index;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::PacketReturnType PacketReturnType;
typedef typename TensorEvaluator<LeftArgType>::Dimensions Dimensions;
typedef typename TensorEvaluator<LeftArgType, Device>::Dimensions Dimensions;
EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
{
@@ -196,34 +196,34 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg
private:
const BinaryOp m_functor;
TensorEvaluator<LeftArgType> m_leftImpl;
TensorEvaluator<RightArgType> m_rightImpl;
TensorEvaluator<LeftArgType, Device> m_leftImpl;
TensorEvaluator<RightArgType, Device> m_rightImpl;
};
// -------------------- SelectOp --------------------
template<typename IfArgType, typename ThenArgType, typename ElseArgType>
struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType> >
template<typename IfArgType, typename ThenArgType, typename ElseArgType, typename Device>
struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>, Device>
{
typedef TensorSelectOp<IfArgType, ThenArgType, ElseArgType> XprType;
enum {
IsAligned = TensorEvaluator<ThenArgType>::IsAligned & TensorEvaluator<ElseArgType>::IsAligned,
PacketAccess = TensorEvaluator<ThenArgType>::PacketAccess & TensorEvaluator<ElseArgType>::PacketAccess/* &
IsAligned = TensorEvaluator<ThenArgType, Device>::IsAligned & TensorEvaluator<ElseArgType, Device>::IsAligned,
PacketAccess = TensorEvaluator<ThenArgType, Device>::PacketAccess & TensorEvaluator<ElseArgType, Device>::PacketAccess/* &
TensorEvaluator<IfArgType>::PacketAccess*/,
};
EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op)
: m_condImpl(op.ifExpression()),
m_thenImpl(op.thenExpression()),
m_elseImpl(op.elseExpression())
EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
: m_condImpl(op.ifExpression(), device),
m_thenImpl(op.thenExpression(), device),
m_elseImpl(op.elseExpression(), device)
{ }
typedef typename XprType::Index Index;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::PacketReturnType PacketReturnType;
typedef typename TensorEvaluator<IfArgType>::Dimensions Dimensions;
typedef typename TensorEvaluator<IfArgType, Device>::Dimensions Dimensions;
EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
{
@@ -248,9 +248,9 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>
}
private:
TensorEvaluator<IfArgType> m_condImpl;
TensorEvaluator<ThenArgType> m_thenImpl;
TensorEvaluator<ElseArgType> m_elseImpl;
TensorEvaluator<IfArgType, Device> m_condImpl;
TensorEvaluator<ThenArgType, Device> m_thenImpl;
TensorEvaluator<ElseArgType, Device> m_elseImpl;
};
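As a standalone analogue (not Eigen code) of the pattern these evaluators follow: the device handed to the root evaluator is forwarded to every child evaluator, and coefficient requests recurse down the same tree:

struct FakeDevice {};

struct LeafEval {
  const float* data;
  LeafEval(const float* p, const FakeDevice&) : data(p) {}
  float coeff(int i) const { return data[i]; }
};

struct SumEval {
  LeafEval lhs, rhs;
  SumEval(const float* a, const float* b, const FakeDevice& d)
      : lhs(a, d), rhs(b, d) {}                 // device forwarded to both children
  float coeff(int i) const { return lhs.coeff(i) + rhs.coeff(i); }  // recurse per coefficient
};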

View File

@@ -21,16 +21,17 @@ template<typename NullaryOp, typename PlainObjectType> class TensorCwiseNullaryO
template<typename UnaryOp, typename XprType> class TensorCwiseUnaryOp;
template<typename BinaryOp, typename LeftXprType, typename RightXprType> class TensorCwiseBinaryOp;
template<typename IfXprType, typename ThenXprType, typename ElseXprType> class TensorSelectOp;
template<typename XprType> class TensorReductionOp;
template<typename Dimensions, typename LeftXprType, typename RightXprType> class TensorContractionOp;
template<typename Dimensions, typename InputXprType, typename KernelXprType> class TensorConvolutionOp;
template<typename NewDimensions, typename XprType> class TensorReshapingOp;
template<typename ExpressionType, typename DeviceType> class TensorDevice;
template<typename XprType> class TensorForcedEvalOp;
// Move to internal?
template<typename Derived> struct TensorEvaluator;
template<typename ExpressionType, typename DeviceType> class TensorDevice;
template<typename Derived, typename Device> struct TensorEvaluator;
namespace internal {
template<typename Derived, typename OtherDerived, bool Vectorizable> struct TensorAssign;
template<typename Derived, typename OtherDerived, typename Device, bool Vectorizable> struct TensorAssign;
} // end namespace internal
} // end namespace Eigen

View File

@@ -77,19 +77,19 @@ class TensorReshapingOp : public TensorBase<TensorReshapingOp<XprType, NewDimens
};
template<typename ArgType, typename NewDimensions>
struct TensorEvaluator<const TensorReshapingOp<ArgType, NewDimensions> >
template<typename ArgType, typename NewDimensions, typename Device>
struct TensorEvaluator<const TensorReshapingOp<ArgType, NewDimensions>, Device>
{
typedef TensorReshapingOp<ArgType, NewDimensions> XprType;
typedef NewDimensions Dimensions;
enum {
IsAligned = TensorEvaluator<ArgType>::IsAligned,
PacketAccess = TensorEvaluator<ArgType>::PacketAccess,
IsAligned = TensorEvaluator<ArgType, Device>::IsAligned,
PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
};
TensorEvaluator(const XprType& op)
: m_impl(op.expression()), m_dimensions(op.dimensions())
TensorEvaluator(const XprType& op, const Device& device)
: m_impl(op.expression(), device), m_dimensions(op.dimensions())
{ }
typedef typename XprType::Index Index;
@@ -111,7 +111,7 @@ struct TensorEvaluator<const TensorReshapingOp<ArgType, NewDimensions> >
private:
NewDimensions m_dimensions;
TensorEvaluator<ArgType> m_impl;
TensorEvaluator<ArgType, Device> m_impl;
};
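Finally, a hedged usage sketch of the reshaping evaluator above (the header path and the DSizes setup are assumptions):

#include <unsupported/Eigen/CXX11/Tensor>  // assumed header path

void reshape_example() {
  Eigen::Tensor<float, 2> m(6, 4);
  Eigen::DSizes<Eigen::DenseIndex, 3> shape;   // assumed: DSizes is indexable like an array
  shape[0] = 2; shape[1] = 3; shape[2] = 4;
  // Same 24 coefficients; only the dimensions() seen by the evaluator change.
  auto reshaped = m.reshape(shape);
  (void)reshaped;
}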