Improved support for rvalues in tensor expressions.

This commit is contained in:
Benoit Steiner 2014-06-09 09:45:30 -07:00
parent 36a2b2e9dc
commit a669052f12
7 changed files with 71 additions and 20 deletions

View File

@ -22,7 +22,7 @@ namespace Eigen {
*/
template<typename Derived>
class TensorBase
class TensorBase<Derived, ReadOnlyAccessors>
{
public:
typedef typename internal::traits<Derived>::Scalar Scalar;
@ -30,19 +30,6 @@ class TensorBase
typedef Scalar CoeffReturnType;
typedef typename internal::packet_traits<Scalar>::type PacketReturnType;
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Derived& setZero() {
return setConstant(Scalar(0));
}
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Derived& setConstant(const Scalar& val) {
return derived() = constant(val);
}
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Derived& setRandom() {
return derived() = random();
}
// Nullary operators
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived>
@ -224,14 +211,53 @@ class TensorBase
return TensorReshapingOp<const Derived, const NewDimensions>(derived(), newDimensions);
}
protected:
template <typename OtherDerived, int AccessLevel> friend class TensorBase;
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const Derived& derived() const { return *static_cast<const Derived*>(this); }
};
template<typename Derived>
class TensorBase<Derived, WriteAccessors> : public TensorBase<Derived, ReadOnlyAccessors> {
public:
typedef typename internal::traits<Derived>::Scalar Scalar;
typedef typename internal::traits<Derived>::Index Index;
typedef Scalar CoeffReturnType;
typedef typename internal::packet_traits<Scalar>::type PacketReturnType;
template <typename OtherDerived, int AccessLevel> friend class TensorBase;
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Derived& setZero() {
return setConstant(Scalar(0));
}
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Derived& setConstant(const Scalar& val) {
return derived() = this->constant(val);
}
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Derived& setRandom() {
return derived() = this->random();
}
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Derived& operator+=(const OtherDerived& other) {
return derived() = TensorCwiseBinaryOp<internal::scalar_sum_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
}
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Derived& operator-=(const OtherDerived& other) {
return derived() = TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
}
// Select the device on which to evaluate the expression.
template <typename DeviceType>
TensorDevice<Derived, DeviceType> device(const DeviceType& device) {
return TensorDevice<Derived, DeviceType>(device, derived());
}
protected:
template <typename OtherDerived> friend class TensorBase;
protected:
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Derived& derived() { return *static_cast<Derived*>(this); }
EIGEN_DEVICE_FUNC

View File

@ -35,6 +35,10 @@ struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >
typedef typename RhsXprType::Nested RhsNested;
typedef typename remove_reference<LhsNested>::type _LhsNested;
typedef typename remove_reference<RhsNested>::type _RhsNested;
enum {
Flags = 0,
};
};
template<typename Dimensions, typename LhsXprType, typename RhsXprType>

View File

@ -35,6 +35,10 @@ struct traits<TensorConvolutionOp<Dimensions, InputXprType, KernelXprType> >
typedef typename KernelXprType::Nested RhsNested;
typedef typename remove_reference<LhsNested>::type _LhsNested;
typedef typename remove_reference<RhsNested>::type _RhsNested;
enum {
Flags = 0,
};
};
template<typename Dimensions, typename InputXprType, typename KernelXprType>

View File

@ -36,6 +36,10 @@ struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> >
typedef typename XprType::Scalar Scalar;
typedef typename XprType::Nested XprTypeNested;
typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
enum {
Flags = 0,
};
};
} // end namespace internal
@ -153,6 +157,10 @@ struct traits<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >
typedef typename RhsXprType::Nested RhsNested;
typedef typename remove_reference<LhsNested>::type _LhsNested;
typedef typename remove_reference<RhsNested>::type _RhsNested;
enum {
Flags = 0,
};
};
template<typename BinaryOp, typename LhsXprType, typename RhsXprType>

View File

@ -15,7 +15,7 @@ namespace Eigen {
template<typename Scalar_, std::size_t NumIndices_, int Options_ = 0> class Tensor;
template<typename Scalar_, typename Dimensions, int Options_ = 0> class TensorFixedSize;
template<typename PlainObjectType, int Options_ = Unaligned> class TensorMap;
template<typename Derived> class TensorBase;
template<typename Derived, int AccessLevel = internal::accessors_level<Derived>::value> class TensorBase;
template<typename NullaryOp, typename PlainObjectType> class TensorCwiseNullaryOp;
template<typename UnaryOp, typename XprType> class TensorCwiseUnaryOp;
@ -29,6 +29,10 @@ template<typename ExpressionType, typename DeviceType> class TensorDevice;
// Move to internal?
template<typename Derived> struct TensorEvaluator;
namespace internal {
template<typename Derived, typename OtherDerived, bool Vectorizable> struct TensorAssign;
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_CXX11_TENSOR_TENSOR_FORWARD_DECLARATIONS_H

View File

@ -21,7 +21,7 @@ namespace Eigen {
*/
namespace internal {
template<typename XprType, typename NewDimensions>
struct traits<TensorReshapingOp<XprType, NewDimensions> >
struct traits<TensorReshapingOp<XprType, NewDimensions> > : public traits<XprType>
{
// Type promotion to handle the case where the types of the lhs and the rhs are different.
typedef typename XprType::Scalar Scalar;
@ -81,6 +81,7 @@ template<typename ArgType, typename NewDimensions>
struct TensorEvaluator<const TensorReshapingOp<ArgType, NewDimensions> >
{
typedef TensorReshapingOp<ArgType, NewDimensions> XprType;
typedef NewDimensions Dimensions;
enum {
IsAligned = TensorEvaluator<ArgType>::IsAligned,
@ -95,7 +96,7 @@ struct TensorEvaluator<const TensorReshapingOp<ArgType, NewDimensions> >
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::PacketReturnType PacketReturnType;
const NewDimensions& dimensions() const { return m_dimensions; }
const Dimensions& dimensions() const { return m_dimensions; }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
{

View File

@ -52,7 +52,7 @@ struct traits<Tensor<Scalar_, NumIndices_, Options_> >
typedef DenseIndex Index;
enum {
Options = Options_,
Flags = compute_tensor_flags<Scalar_, Options_>::ret,
Flags = compute_tensor_flags<Scalar_, Options_>::ret | LvalueBit,
};
};
@ -63,6 +63,10 @@ struct traits<TensorFixedSize<Scalar_, Dimensions, Options_> >
typedef Scalar_ Scalar;
typedef Dense StorageKind;
typedef DenseIndex Index;
enum {
Options = Options_,
Flags = compute_tensor_flags<Scalar_, Options_>::ret | LvalueBit,
};
};