From 736267cf6b17832a571acf7e34ca07c7f55907ee Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Thu, 22 May 2014 16:22:35 -0700 Subject: [PATCH] Added support for additional tensor operations: * comparison (<, <=, ==, !=, ...) * selection * nullary ops such as random or constant generation * misc unary ops such as log(), exp(), or a user defined unaryExpr() Cleaned up the code a little. --- .../Eigen/CXX11/src/Tensor/TensorBase.h | 139 ++++++++++++++++-- .../Eigen/CXX11/src/Tensor/TensorEvaluator.h | 84 +++++++++++ .../Eigen/CXX11/src/Tensor/TensorExpr.h | 109 ++++++++++++++ .../src/Tensor/TensorForwardDeclarations.h | 2 + .../Eigen/CXX11/src/Tensor/TensorMap.h | 36 +++-- 5 files changed, 339 insertions(+), 31 deletions(-) diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index fa1bd3498..8a88ba806 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -33,21 +33,25 @@ class TensorBase Derived& setZero() { return setConstant(Scalar(0)); } - Derived& setConstant(const Scalar& val) { - Scalar* data = derived().data(); - for (int i = 0; i < derived().size(); ++i) { - data[i] = val; - } - return derived(); + return derived() = constant(val); + } + Derived& setRandom() { + return derived() = random(); } - Derived& setRandom() { - Scalar* data = derived().data(); - for (int i = 0; i < derived().size(); ++i) { - data[i] = internal::random_default_impl::run(); - } - return derived(); + // Nullary operators + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseNullaryOp, const Derived> + constant(const Scalar& value) const { + return TensorCwiseNullaryOp, const Derived> + (internal::scalar_constant_op(value)); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseNullaryOp, const Derived> + random() const { + return TensorCwiseNullaryOp, const Derived>(); } // Coefficient-wise unary operators @@ -57,15 +61,31 @@ class TensorBase EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> - cwiseSqrt() const { return derived(); } + sqrt() const { return derived(); } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> + square() const { return derived(); } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> + inverse() const { return derived(); } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> + exp() const { return derived(); } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> + log() const { return derived(); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> - cwiseAbs() const { return derived(); } + abs() const { return derived(); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> - cwisePow(Scalar exponent) const { + pow(Scalar exponent) const { return TensorCwiseUnaryOp, const Derived> (derived(), internal::scalar_pow_op(exponent)); } @@ -77,6 +97,30 @@ class TensorBase (derived(), internal::scalar_multiple_op(scale)); } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const TensorCwiseNullaryOp, const Derived> > + cwiseMax(Scalar threshold) const { + return cwiseMax(constant(threshold)); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const TensorCwiseNullaryOp, const Derived> > + cwiseMin(Scalar threshold) const { + return cwiseMin(constant(threshold)); + } + + template EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp + unaryExpr(const CustomUnaryOp& func) const { + return TensorCwiseUnaryOp(derived(), func); + } + + template EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> + cast() const { + return derived(); + } + // Coefficient-wise binary operators. template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> @@ -90,6 +134,71 @@ class TensorBase return TensorCwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); } + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp, const Derived, const OtherDerived> + operator*(const OtherDerived& other) const { + return TensorCwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); + } + + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp, const Derived, const OtherDerived> + operator/(const OtherDerived& other) const { + return TensorCwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); + } + + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp, const Derived, const OtherDerived> + cwiseMax(const OtherDerived& other) const { + return TensorCwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); + } + + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp, const Derived, const OtherDerived> + cwiseMin(const OtherDerived& other) const { + return TensorCwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); + } + + // Comparisons and tests. + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp, const Derived, const OtherDerived> + operator<(const OtherDerived& other) const { + return TensorCwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); + } + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp, const Derived, const OtherDerived> + operator<=(const OtherDerived& other) const { + return TensorCwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); + } + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp, const Derived, const OtherDerived> + operator>(const OtherDerived& other) const { + return TensorCwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); + } + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp, const Derived, const OtherDerived> + operator>=(const OtherDerived& other) const { + return TensorCwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); + } + + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp, const Derived, const OtherDerived> + operator==(const OtherDerived& other) const { + return TensorCwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); + } + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp, const Derived, const OtherDerived> + operator!=(const OtherDerived& other) const { + return TensorCwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); + } + + // Coefficient-wise ternary operators. + template + inline const TensorSelectOp + select(const ThenDerived& thenTensor, const ElseDerived& elseTensor) const{ + return TensorSelectOp(derived(), thenTensor.derived(), elseTensor.derived()); + } + + // Select the device on which to evaluate the expression. template TensorDevice device(const DeviceType& device) { return TensorDevice(device, derived()); diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h index 3ce924dc3..e0c0863b7 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h @@ -68,6 +68,42 @@ struct TensorEvaluator +// -------------------- CwiseNullaryOp -------------------- + +template +struct TensorEvaluator > +{ + typedef TensorCwiseNullaryOp XprType; + + enum { + IsAligned = true, + PacketAccess = internal::functor_traits::PacketAccess, + }; + + TensorEvaluator(const XprType& op) + : m_functor(op.functor()) + { } + + typedef typename XprType::Index Index; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketReturnType PacketReturnType; + + EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const + { + return m_functor(index); + } + + template + EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const + { + return m_functor.packetOp(index); + } + + private: + const NullaryOp m_functor; +}; + + // -------------------- CwiseUnaryOp -------------------- @@ -146,6 +182,54 @@ struct TensorEvaluator m_rightImpl; }; + +// -------------------- SelectOp -------------------- + +template +struct TensorEvaluator > +{ + typedef TensorSelectOp XprType; + + enum { + IsAligned = TensorEvaluator::IsAligned & TensorEvaluator::IsAligned, + PacketAccess = TensorEvaluator::PacketAccess & TensorEvaluator::PacketAccess/* & + TensorEvaluator::PacketAccess*/, + }; + + TensorEvaluator(const XprType& op) + : m_condImpl(op.ifExpression()), + m_thenImpl(op.thenExpression()), + m_elseImpl(op.elseExpression()) + { } + + typedef typename XprType::Index Index; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketReturnType PacketReturnType; + + EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const + { + return m_condImpl.coeff(index) ? m_thenImpl.coeff(index) : m_elseImpl.coeff(index); + } + template + EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const + { + static const int PacketSize = internal::unpacket_traits::size; + internal::Selector select; + for (Index i = 0; i < PacketSize; ++i) { + select.select[i] = m_condImpl.coeff(index+i); + } + return internal::pblend(select, + m_thenImpl.template packet(index), + m_elseImpl.template packet(index)); + } + + private: + TensorEvaluator m_condImpl; + TensorEvaluator m_thenImpl; + TensorEvaluator m_elseImpl; +}; + + } // end namespace Eigen #endif // EIGEN_CXX11_TENSOR_TENSOR_EVALUATOR_H diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h index e32077f6e..94cfae05c 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h @@ -17,6 +17,9 @@ namespace Eigen { * * \brief Tensor expression classes. * + * The TensorCwiseNullaryOp class applies a nullary operators to an expression. This + * is typically used to generate constants. + * * The TensorCwiseUnaryOp class represents an expression where a unary operator * (e.g. cwiseSqrt) is applied to an expression. * @@ -24,6 +27,46 @@ namespace Eigen { * (e.g. addition) is applied to a lhs and a rhs expression. * */ +namespace internal { +template +struct traits > + : traits +{ + typedef typename PlainObjectType::Packet Packet; + typedef typename PlainObjectType::Scalar Scalar; + typedef typename PlainObjectType::Nested XprTypeNested; + typedef typename remove_reference::type _XprTypeNested; +}; + +} // end namespace internal + + + +template +class TensorCwiseNullaryOp : public TensorBase > +{ + public: + typedef typename Eigen::internal::traits::Scalar Scalar; + typedef typename Eigen::internal::traits::Packet Packet; + typedef typename Eigen::NumTraits::Real RealScalar; + typedef typename PlainObjectType::CoeffReturnType CoeffReturnType; + typedef typename PlainObjectType::PacketReturnType PacketReturnType; + typedef TensorCwiseNullaryOp Nested; + typedef typename Eigen::internal::traits::StorageKind StorageKind; + typedef typename Eigen::internal::traits::Index Index; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const NullaryOp& func = NullaryOp()) + : m_functor(func) {} + + EIGEN_DEVICE_FUNC + const NullaryOp& functor() const { return m_functor; } + + protected: + // todo: add tensor dimension to be able to do some sanity checks + const NullaryOp m_functor; +}; + + namespace internal { template @@ -160,6 +203,72 @@ class TensorCwiseBinaryOp : public TensorBase +struct traits > + : traits +{ + typedef typename traits::Scalar Scalar; + typedef typename internal::packet_traits::type Packet; + typedef typename promote_storage_type::StorageKind, + typename traits::StorageKind>::ret StorageKind; + typedef typename promote_index_type::Index, + typename traits::Index>::type Index; + typedef typename IfXprType::Nested IfNested; + typedef typename ThenXprType::Nested ThenNested; + typedef typename ElseXprType::Nested ElseNested; +}; + +template +struct eval, Eigen::Dense> +{ + typedef const TensorSelectOp& type; +}; + +template +struct nested, 1, typename eval >::type> +{ + typedef TensorSelectOp type; +}; + +} // end namespace internal + + +template +class TensorSelectOp : public TensorBase > +{ + public: + typedef typename Eigen::internal::traits::Scalar Scalar; + typedef typename Eigen::internal::traits::Packet Packet; + typedef typename Eigen::NumTraits::Real RealScalar; + typedef typename internal::promote_storage_type::ret CoeffReturnType; + typedef typename internal::promote_storage_type::ret PacketReturnType; + typedef typename Eigen::internal::nested::type Nested; + typedef typename Eigen::internal::traits::StorageKind StorageKind; + typedef typename Eigen::internal::traits::Index Index; + + TensorSelectOp(const IfXprType& a_condition, + const ThenXprType& a_then, + const ElseXprType& a_else) + : m_condition(a_condition), m_then(a_then), m_else(a_else) + { } + + const IfXprType& ifExpression() const { return m_condition; } + + const ThenXprType& thenExpression() const { return m_then; } + + const ElseXprType& elseExpression() const { return m_else; } + + protected: + typename IfXprType::Nested m_condition; + typename ThenXprType::Nested m_then; + typename ElseXprType::Nested m_else; +}; + + } // end namespace Eigen #endif // EIGEN_CXX11_TENSOR_TENSOR_EXPR_H diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h index 09b0fe66d..03ac8d516 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h @@ -17,8 +17,10 @@ template class TensorFi template class TensorMap; template class TensorBase; +template class TensorCwiseNullaryOp; template class TensorCwiseUnaryOp; template class TensorCwiseBinaryOp; +template class TensorSelectOp; template class TensorDevice; diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h b/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h index 3fc9c5335..3a2ff5b30 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h @@ -45,33 +45,37 @@ template class TensorMap : public Tensor static const int Options = Options_; + static const std::size_t NumIndices = PlainObjectType::NumIndices; + typedef typename PlainObjectType::Dimensions Dimensions; + + enum { IsAligned = bool(EIGEN_ALIGN) && ((int(Options_)&Aligned)==Aligned), PacketAccess = true, }; EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension) : m_data(dataPtr), m_dimensions(array({{firstDimension}})) { + EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension) : m_data(dataPtr), m_dimensions(array(firstDimension)) { // The number of dimensions used to construct a tensor must be equal to the rank of the tensor. - EIGEN_STATIC_ASSERT(1 == PlainObjectType::NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE) + EIGEN_STATIC_ASSERT(1 == NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE) } #ifdef EIGEN_HAS_VARIADIC_TEMPLATES template EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension, IndexTypes... otherDimensions) : m_data(dataPtr), m_dimensions(array({{firstDimension, otherDimensions...}})) { + EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension, IndexTypes... otherDimensions) : m_data(dataPtr), m_dimensions(array({{firstDimension, otherDimensions...}})) { // The number of dimensions used to construct a tensor must be equal to the rank of the tensor. - EIGEN_STATIC_ASSERT(sizeof...(otherDimensions) + 1 == PlainObjectType::NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE) + EIGEN_STATIC_ASSERT(sizeof...(otherDimensions) + 1 == NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE) } #endif - inline TensorMap(PointerArgType dataPtr, const array& dimensions) + inline TensorMap(PointerArgType dataPtr, const array& dimensions) : m_data(dataPtr), m_dimensions(dimensions) { } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_dimensions[n]; } EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE const typename PlainObjectType::Dimensions& dimensions() const { return m_dimensions; } + EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index size() const { return m_dimensions.TotalSize(); } EIGEN_DEVICE_FUNC @@ -80,7 +84,7 @@ template class TensorMap : public Tensor EIGEN_STRONG_INLINE const Scalar* data() const { return m_data; } EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE const Scalar& operator()(const array& indices) const + EIGEN_STRONG_INLINE const Scalar& operator()(const array& indices) const { // eigen_assert(checkIndexRange(indices)); if (PlainObjectType::Options&RowMajor) { @@ -96,12 +100,12 @@ template class TensorMap : public Tensor template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& operator()(Index firstIndex, IndexTypes... otherIndices) const { - static_assert(sizeof...(otherIndices) + 1 == PlainObjectType::NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor."); + static_assert(sizeof...(otherIndices) + 1 == NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor."); if (PlainObjectType::Options&RowMajor) { - const Index index = m_dimensions.IndexOfRowMajor(array{{firstIndex, otherIndices...}}); + const Index index = m_dimensions.IndexOfRowMajor(array{{firstIndex, otherIndices...}}); return m_data[index]; } else { - const Index index = m_dimensions.IndexOfColMajor(array{{firstIndex, otherIndices...}}); + const Index index = m_dimensions.IndexOfColMajor(array{{firstIndex, otherIndices...}}); return m_data[index]; } } @@ -159,7 +163,7 @@ template class TensorMap : public Tensor #endif EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE Scalar& operator()(const array& indices) + EIGEN_STRONG_INLINE Scalar& operator()(const array& indices) { // eigen_assert(checkIndexRange(indices)); if (PlainObjectType::Options&RowMajor) { @@ -175,12 +179,12 @@ template class TensorMap : public Tensor template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator()(Index firstIndex, IndexTypes... otherIndices) { - static_assert(sizeof...(otherIndices) + 1 == PlainObjectType::NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor."); + static_assert(sizeof...(otherIndices) + 1 == NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor."); if (PlainObjectType::Options&RowMajor) { - const Index index = m_dimensions.IndexOfRowMajor(array{{firstIndex, otherIndices...}}); + const Index index = m_dimensions.IndexOfRowMajor(array{{firstIndex, otherIndices...}}); return m_data[index]; } else { - const Index index = m_dimensions.IndexOfColMajor(array{{firstIndex, otherIndices...}}); + const Index index = m_dimensions.IndexOfColMajor(array{{firstIndex, otherIndices...}}); return m_data[index]; } } @@ -247,8 +251,8 @@ template class TensorMap : public Tensor } private: - typename PlainObjectType::Scalar* m_data; - typename PlainObjectType::Dimensions m_dimensions; + Scalar* m_data; + Dimensions m_dimensions; }; } // end namespace Eigen