mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-07-14 09:01:47 +08:00
Improved support for rvalues in tensor expressions.
This commit is contained in:
parent
36a2b2e9dc
commit
a669052f12
@ -22,7 +22,7 @@ namespace Eigen {
|
||||
*/
|
||||
|
||||
template<typename Derived>
|
||||
class TensorBase
|
||||
class TensorBase<Derived, ReadOnlyAccessors>
|
||||
{
|
||||
public:
|
||||
typedef typename internal::traits<Derived>::Scalar Scalar;
|
||||
@ -30,19 +30,6 @@ class TensorBase
|
||||
typedef Scalar CoeffReturnType;
|
||||
typedef typename internal::packet_traits<Scalar>::type PacketReturnType;
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE Derived& setZero() {
|
||||
return setConstant(Scalar(0));
|
||||
}
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE Derived& setConstant(const Scalar& val) {
|
||||
return derived() = constant(val);
|
||||
}
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE Derived& setRandom() {
|
||||
return derived() = random();
|
||||
}
|
||||
|
||||
// Nullary operators
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived>
|
||||
@ -224,14 +211,53 @@ class TensorBase
|
||||
return TensorReshapingOp<const Derived, const NewDimensions>(derived(), newDimensions);
|
||||
}
|
||||
|
||||
protected:
|
||||
template <typename OtherDerived, int AccessLevel> friend class TensorBase;
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE const Derived& derived() const { return *static_cast<const Derived*>(this); }
|
||||
};
|
||||
|
||||
|
||||
template<typename Derived>
|
||||
class TensorBase<Derived, WriteAccessors> : public TensorBase<Derived, ReadOnlyAccessors> {
|
||||
public:
|
||||
typedef typename internal::traits<Derived>::Scalar Scalar;
|
||||
typedef typename internal::traits<Derived>::Index Index;
|
||||
typedef Scalar CoeffReturnType;
|
||||
typedef typename internal::packet_traits<Scalar>::type PacketReturnType;
|
||||
|
||||
template <typename OtherDerived, int AccessLevel> friend class TensorBase;
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE Derived& setZero() {
|
||||
return setConstant(Scalar(0));
|
||||
}
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE Derived& setConstant(const Scalar& val) {
|
||||
return derived() = this->constant(val);
|
||||
}
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE Derived& setRandom() {
|
||||
return derived() = this->random();
|
||||
}
|
||||
|
||||
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
Derived& operator+=(const OtherDerived& other) {
|
||||
return derived() = TensorCwiseBinaryOp<internal::scalar_sum_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
|
||||
}
|
||||
|
||||
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
Derived& operator-=(const OtherDerived& other) {
|
||||
return derived() = TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
|
||||
}
|
||||
|
||||
// Select the device on which to evaluate the expression.
|
||||
template <typename DeviceType>
|
||||
TensorDevice<Derived, DeviceType> device(const DeviceType& device) {
|
||||
return TensorDevice<Derived, DeviceType>(device, derived());
|
||||
}
|
||||
|
||||
protected:
|
||||
template <typename OtherDerived> friend class TensorBase;
|
||||
protected:
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE Derived& derived() { return *static_cast<Derived*>(this); }
|
||||
EIGEN_DEVICE_FUNC
|
||||
|
@ -35,6 +35,10 @@ struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >
|
||||
typedef typename RhsXprType::Nested RhsNested;
|
||||
typedef typename remove_reference<LhsNested>::type _LhsNested;
|
||||
typedef typename remove_reference<RhsNested>::type _RhsNested;
|
||||
|
||||
enum {
|
||||
Flags = 0,
|
||||
};
|
||||
};
|
||||
|
||||
template<typename Dimensions, typename LhsXprType, typename RhsXprType>
|
||||
|
@ -35,6 +35,10 @@ struct traits<TensorConvolutionOp<Dimensions, InputXprType, KernelXprType> >
|
||||
typedef typename KernelXprType::Nested RhsNested;
|
||||
typedef typename remove_reference<LhsNested>::type _LhsNested;
|
||||
typedef typename remove_reference<RhsNested>::type _RhsNested;
|
||||
|
||||
enum {
|
||||
Flags = 0,
|
||||
};
|
||||
};
|
||||
|
||||
template<typename Dimensions, typename InputXprType, typename KernelXprType>
|
||||
|
@ -36,6 +36,10 @@ struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> >
|
||||
typedef typename XprType::Scalar Scalar;
|
||||
typedef typename XprType::Nested XprTypeNested;
|
||||
typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
|
||||
|
||||
enum {
|
||||
Flags = 0,
|
||||
};
|
||||
};
|
||||
|
||||
} // end namespace internal
|
||||
@ -153,6 +157,10 @@ struct traits<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >
|
||||
typedef typename RhsXprType::Nested RhsNested;
|
||||
typedef typename remove_reference<LhsNested>::type _LhsNested;
|
||||
typedef typename remove_reference<RhsNested>::type _RhsNested;
|
||||
|
||||
enum {
|
||||
Flags = 0,
|
||||
};
|
||||
};
|
||||
|
||||
template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
|
||||
|
@ -15,7 +15,7 @@ namespace Eigen {
|
||||
template<typename Scalar_, std::size_t NumIndices_, int Options_ = 0> class Tensor;
|
||||
template<typename Scalar_, typename Dimensions, int Options_ = 0> class TensorFixedSize;
|
||||
template<typename PlainObjectType, int Options_ = Unaligned> class TensorMap;
|
||||
template<typename Derived> class TensorBase;
|
||||
template<typename Derived, int AccessLevel = internal::accessors_level<Derived>::value> class TensorBase;
|
||||
|
||||
template<typename NullaryOp, typename PlainObjectType> class TensorCwiseNullaryOp;
|
||||
template<typename UnaryOp, typename XprType> class TensorCwiseUnaryOp;
|
||||
@ -29,6 +29,10 @@ template<typename ExpressionType, typename DeviceType> class TensorDevice;
|
||||
// Move to internal?
|
||||
template<typename Derived> struct TensorEvaluator;
|
||||
|
||||
namespace internal {
|
||||
template<typename Derived, typename OtherDerived, bool Vectorizable> struct TensorAssign;
|
||||
} // end namespace internal
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_CXX11_TENSOR_TENSOR_FORWARD_DECLARATIONS_H
|
||||
|
@ -21,7 +21,7 @@ namespace Eigen {
|
||||
*/
|
||||
namespace internal {
|
||||
template<typename XprType, typename NewDimensions>
|
||||
struct traits<TensorReshapingOp<XprType, NewDimensions> >
|
||||
struct traits<TensorReshapingOp<XprType, NewDimensions> > : public traits<XprType>
|
||||
{
|
||||
// Type promotion to handle the case where the types of the lhs and the rhs are different.
|
||||
typedef typename XprType::Scalar Scalar;
|
||||
@ -81,6 +81,7 @@ template<typename ArgType, typename NewDimensions>
|
||||
struct TensorEvaluator<const TensorReshapingOp<ArgType, NewDimensions> >
|
||||
{
|
||||
typedef TensorReshapingOp<ArgType, NewDimensions> XprType;
|
||||
typedef NewDimensions Dimensions;
|
||||
|
||||
enum {
|
||||
IsAligned = TensorEvaluator<ArgType>::IsAligned,
|
||||
@ -95,7 +96,7 @@ struct TensorEvaluator<const TensorReshapingOp<ArgType, NewDimensions> >
|
||||
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
||||
typedef typename XprType::PacketReturnType PacketReturnType;
|
||||
|
||||
const NewDimensions& dimensions() const { return m_dimensions; }
|
||||
const Dimensions& dimensions() const { return m_dimensions; }
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
|
||||
{
|
||||
|
@ -52,7 +52,7 @@ struct traits<Tensor<Scalar_, NumIndices_, Options_> >
|
||||
typedef DenseIndex Index;
|
||||
enum {
|
||||
Options = Options_,
|
||||
Flags = compute_tensor_flags<Scalar_, Options_>::ret,
|
||||
Flags = compute_tensor_flags<Scalar_, Options_>::ret | LvalueBit,
|
||||
};
|
||||
};
|
||||
|
||||
@ -63,6 +63,10 @@ struct traits<TensorFixedSize<Scalar_, Dimensions, Options_> >
|
||||
typedef Scalar_ Scalar;
|
||||
typedef Dense StorageKind;
|
||||
typedef DenseIndex Index;
|
||||
enum {
|
||||
Options = Options_,
|
||||
Flags = compute_tensor_flags<Scalar_, Options_>::ret | LvalueBit,
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user