mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-11 19:29:02 +08:00
Added the -= operator to the device classes
This commit is contained in:
parent
e134226a03
commit
a6a628ca6b
@ -21,8 +21,7 @@ namespace Eigen {
|
||||
* Example:
|
||||
* C.device(EIGEN_GPU) = A + B;
|
||||
*
|
||||
* Todo: thread pools.
|
||||
* Todo: operator +=, -=, *= and so on.
|
||||
* Todo: operator *= and /=.
|
||||
*/
|
||||
|
||||
template <typename ExpressionType, typename DeviceType> class TensorDevice {
|
||||
@ -50,6 +49,18 @@ template <typename ExpressionType, typename DeviceType> class TensorDevice {
|
||||
return *this;
|
||||
}
|
||||
|
||||
template<typename OtherDerived>
|
||||
EIGEN_STRONG_INLINE TensorDevice& operator-=(const OtherDerived& other) {
|
||||
typedef typename OtherDerived::Scalar Scalar;
|
||||
typedef TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const ExpressionType, const OtherDerived> Difference;
|
||||
Difference difference(m_expression, other);
|
||||
typedef TensorAssignOp<ExpressionType, const Difference> Assign;
|
||||
Assign assign(m_expression, difference);
|
||||
static const bool Vectorize = TensorEvaluator<const Assign, DeviceType>::PacketAccess;
|
||||
internal::TensorExecutor<const Assign, DeviceType, Vectorize>::run(assign, m_device);
|
||||
return *this;
|
||||
}
|
||||
|
||||
protected:
|
||||
const DeviceType& m_device;
|
||||
ExpressionType& m_expression;
|
||||
@ -82,6 +93,18 @@ template <typename ExpressionType> class TensorDevice<ExpressionType, ThreadPool
|
||||
return *this;
|
||||
}
|
||||
|
||||
template<typename OtherDerived>
|
||||
EIGEN_STRONG_INLINE TensorDevice& operator-=(const OtherDerived& other) {
|
||||
typedef typename OtherDerived::Scalar Scalar;
|
||||
typedef TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const ExpressionType, const OtherDerived> Difference;
|
||||
Difference difference(m_expression, other);
|
||||
typedef TensorAssignOp<ExpressionType, const Difference> Assign;
|
||||
Assign assign(m_expression, difference);
|
||||
static const bool Vectorize = TensorEvaluator<const Assign, ThreadPoolDevice>::PacketAccess;
|
||||
internal::TensorExecutor<const Assign, ThreadPoolDevice, Vectorize>::run(assign, m_device);
|
||||
return *this;
|
||||
}
|
||||
|
||||
protected:
|
||||
const ThreadPoolDevice& m_device;
|
||||
ExpressionType& m_expression;
|
||||
@ -114,6 +137,18 @@ template <typename ExpressionType> class TensorDevice<ExpressionType, GpuDevice>
|
||||
return *this;
|
||||
}
|
||||
|
||||
template<typename OtherDerived>
|
||||
EIGEN_STRONG_INLINE TensorDevice& operator-=(const OtherDerived& other) {
|
||||
typedef typename OtherDerived::Scalar Scalar;
|
||||
typedef TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const ExpressionType, const OtherDerived> Difference;
|
||||
Difference difference(m_expression, other);
|
||||
typedef TensorAssignOp<ExpressionType, const Difference> Assign;
|
||||
Assign assign(m_expression, difference);
|
||||
static const bool Vectorize = TensorEvaluator<const Assign, GpuDevice>::PacketAccess;
|
||||
internal::TensorExecutor<const Assign, GpuDevice, Vectorize>::run(assign, m_device);
|
||||
return *this;
|
||||
}
|
||||
|
||||
protected:
|
||||
const GpuDevice& m_device;
|
||||
ExpressionType m_expression;
|
||||
|
Loading…
x
Reference in New Issue
Block a user