Added the -= operator to the device classes

This commit is contained in:
Benoit Steiner 2015-03-19 23:22:19 -07:00
parent e134226a03
commit a6a628ca6b

View File

@ -21,8 +21,7 @@ namespace Eigen {
* Example:
* C.device(EIGEN_GPU) = A + B;
*
* Todo: thread pools.
* Todo: operator +=, -=, *= and so on.
* Todo: operator *= and /=.
*/
template <typename ExpressionType, typename DeviceType> class TensorDevice {
@ -50,6 +49,18 @@ template <typename ExpressionType, typename DeviceType> class TensorDevice {
return *this;
}
template<typename OtherDerived>
EIGEN_STRONG_INLINE TensorDevice& operator-=(const OtherDerived& other) {
typedef typename OtherDerived::Scalar Scalar;
typedef TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const ExpressionType, const OtherDerived> Difference;
Difference difference(m_expression, other);
typedef TensorAssignOp<ExpressionType, const Difference> Assign;
Assign assign(m_expression, difference);
static const bool Vectorize = TensorEvaluator<const Assign, DeviceType>::PacketAccess;
internal::TensorExecutor<const Assign, DeviceType, Vectorize>::run(assign, m_device);
return *this;
}
protected:
const DeviceType& m_device;
ExpressionType& m_expression;
@ -82,6 +93,18 @@ template <typename ExpressionType> class TensorDevice<ExpressionType, ThreadPool
return *this;
}
template<typename OtherDerived>
EIGEN_STRONG_INLINE TensorDevice& operator-=(const OtherDerived& other) {
typedef typename OtherDerived::Scalar Scalar;
typedef TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const ExpressionType, const OtherDerived> Difference;
Difference difference(m_expression, other);
typedef TensorAssignOp<ExpressionType, const Difference> Assign;
Assign assign(m_expression, difference);
static const bool Vectorize = TensorEvaluator<const Assign, ThreadPoolDevice>::PacketAccess;
internal::TensorExecutor<const Assign, ThreadPoolDevice, Vectorize>::run(assign, m_device);
return *this;
}
protected:
const ThreadPoolDevice& m_device;
ExpressionType& m_expression;
@ -114,6 +137,18 @@ template <typename ExpressionType> class TensorDevice<ExpressionType, GpuDevice>
return *this;
}
template<typename OtherDerived>
EIGEN_STRONG_INLINE TensorDevice& operator-=(const OtherDerived& other) {
typedef typename OtherDerived::Scalar Scalar;
typedef TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const ExpressionType, const OtherDerived> Difference;
Difference difference(m_expression, other);
typedef TensorAssignOp<ExpressionType, const Difference> Assign;
Assign assign(m_expression, difference);
static const bool Vectorize = TensorEvaluator<const Assign, GpuDevice>::PacketAccess;
internal::TensorExecutor<const Assign, GpuDevice, Vectorize>::run(assign, m_device);
return *this;
}
protected:
const GpuDevice& m_device;
ExpressionType m_expression;