Added support for additional tensor operations:

* comparison (<, <=, ==, !=, ...)
  * selection
  * nullary ops such as random or constant generation
  * misc unary ops such as log(), exp(), or a user defined unaryExpr()
Cleaned up the code a little.
This commit is contained in:
Benoit Steiner 2014-05-22 16:22:35 -07:00
parent 7402fea0a8
commit 736267cf6b
5 changed files with 339 additions and 31 deletions

View File

@ -33,21 +33,25 @@ class TensorBase
Derived& setZero() { Derived& setZero() {
return setConstant(Scalar(0)); return setConstant(Scalar(0));
} }
Derived& setConstant(const Scalar& val) { Derived& setConstant(const Scalar& val) {
Scalar* data = derived().data(); return derived() = constant(val);
for (int i = 0; i < derived().size(); ++i) {
data[i] = val;
} }
return derived(); Derived& setRandom() {
return derived() = random();
} }
Derived& setRandom() { // Nullary operators
Scalar* data = derived().data(); EIGEN_DEVICE_FUNC
for (int i = 0; i < derived().size(); ++i) { EIGEN_STRONG_INLINE const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived>
data[i] = internal::random_default_impl<Scalar, false, false>::run(); constant(const Scalar& value) const {
return TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived>
(internal::scalar_constant_op<Scalar>(value));
} }
return derived();
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseNullaryOp<internal::scalar_random_op<Scalar>, const Derived>
random() const {
return TensorCwiseNullaryOp<internal::scalar_random_op<Scalar>, const Derived>();
} }
// Coefficient-wise unary operators // Coefficient-wise unary operators
@ -57,15 +61,31 @@ class TensorBase
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_sqrt_op<Scalar>, const Derived> EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_sqrt_op<Scalar>, const Derived>
cwiseSqrt() const { return derived(); } sqrt() const { return derived(); }
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_square_op<Scalar>, const Derived>
square() const { return derived(); }
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_inverse_op<Scalar>, const Derived>
inverse() const { return derived(); }
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_exp_op<Scalar>, const Derived>
exp() const { return derived(); }
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_log_op<Scalar>, const Derived>
log() const { return derived(); }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_abs_op<Scalar>, const Derived> EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_abs_op<Scalar>, const Derived>
cwiseAbs() const { return derived(); } abs() const { return derived(); }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_pow_op<Scalar>, const Derived> EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_pow_op<Scalar>, const Derived>
cwisePow(Scalar exponent) const { pow(Scalar exponent) const {
return TensorCwiseUnaryOp<internal::scalar_pow_op<Scalar>, const Derived> return TensorCwiseUnaryOp<internal::scalar_pow_op<Scalar>, const Derived>
(derived(), internal::scalar_pow_op<Scalar>(exponent)); (derived(), internal::scalar_pow_op<Scalar>(exponent));
} }
@ -77,6 +97,30 @@ class TensorBase
(derived(), internal::scalar_multiple_op<Scalar>(scale)); (derived(), internal::scalar_multiple_op<Scalar>(scale));
} }
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_max_op<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
cwiseMax(Scalar threshold) const {
return cwiseMax(constant(threshold));
}
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_min_op<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
cwiseMin(Scalar threshold) const {
return cwiseMin(constant(threshold));
}
template <typename CustomUnaryOp> EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<CustomUnaryOp, const Derived>
unaryExpr(const CustomUnaryOp& func) const {
return TensorCwiseUnaryOp<CustomUnaryOp, const Derived>(derived(), func);
}
template <typename NewType> EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_cast_op<Scalar, NewType>, const Derived>
cast() const {
return derived();
}
// Coefficient-wise binary operators. // Coefficient-wise binary operators.
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorCwiseBinaryOp<internal::scalar_sum_op<Scalar>, const Derived, const OtherDerived> const TensorCwiseBinaryOp<internal::scalar_sum_op<Scalar>, const Derived, const OtherDerived>
@ -90,6 +134,71 @@ class TensorBase
return TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived()); return TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
} }
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorCwiseBinaryOp<internal::scalar_product_op<Scalar>, const Derived, const OtherDerived>
operator*(const OtherDerived& other) const {
return TensorCwiseBinaryOp<internal::scalar_product_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
}
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorCwiseBinaryOp<internal::scalar_quotient_op<Scalar>, const Derived, const OtherDerived>
operator/(const OtherDerived& other) const {
return TensorCwiseBinaryOp<internal::scalar_quotient_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
}
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorCwiseBinaryOp<internal::scalar_max_op<Scalar>, const Derived, const OtherDerived>
cwiseMax(const OtherDerived& other) const {
return TensorCwiseBinaryOp<internal::scalar_max_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
}
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorCwiseBinaryOp<internal::scalar_min_op<Scalar>, const Derived, const OtherDerived>
cwiseMin(const OtherDerived& other) const {
return TensorCwiseBinaryOp<internal::scalar_min_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
}
// Comparisons and tests.
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorCwiseBinaryOp<std::less<Scalar>, const Derived, const OtherDerived>
operator<(const OtherDerived& other) const {
return TensorCwiseBinaryOp<std::less<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
}
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorCwiseBinaryOp<std::less_equal<Scalar>, const Derived, const OtherDerived>
operator<=(const OtherDerived& other) const {
return TensorCwiseBinaryOp<std::less_equal<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
}
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorCwiseBinaryOp<std::greater<Scalar>, const Derived, const OtherDerived>
operator>(const OtherDerived& other) const {
return TensorCwiseBinaryOp<std::greater<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
}
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorCwiseBinaryOp<std::greater_equal<Scalar>, const Derived, const OtherDerived>
operator>=(const OtherDerived& other) const {
return TensorCwiseBinaryOp<std::greater_equal<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
}
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorCwiseBinaryOp<std::equal_to<Scalar>, const Derived, const OtherDerived>
operator==(const OtherDerived& other) const {
return TensorCwiseBinaryOp<std::equal_to<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
}
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorCwiseBinaryOp<std::not_equal_to<Scalar>, const Derived, const OtherDerived>
operator!=(const OtherDerived& other) const {
return TensorCwiseBinaryOp<std::not_equal_to<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
}
// Coefficient-wise ternary operators.
template<typename ThenDerived,typename ElseDerived>
inline const TensorSelectOp<const Derived, const ThenDerived, const ElseDerived>
select(const ThenDerived& thenTensor, const ElseDerived& elseTensor) const{
return TensorSelectOp<const Derived, const ThenDerived, const ElseDerived>(derived(), thenTensor.derived(), elseTensor.derived());
}
// Select the device on which to evaluate the expression.
template <typename DeviceType> template <typename DeviceType>
TensorDevice<Derived, DeviceType> device(const DeviceType& device) { TensorDevice<Derived, DeviceType> device(const DeviceType& device) {
return TensorDevice<Derived, DeviceType>(device, derived()); return TensorDevice<Derived, DeviceType>(device, derived());

View File

@ -68,6 +68,42 @@ struct TensorEvaluator
// -------------------- CwiseNullaryOp --------------------
template<typename NullaryOp, typename PlainObjectType>
struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, PlainObjectType> >
{
typedef TensorCwiseNullaryOp<NullaryOp, PlainObjectType> XprType;
enum {
IsAligned = true,
PacketAccess = internal::functor_traits<NullaryOp>::PacketAccess,
};
TensorEvaluator(const XprType& op)
: m_functor(op.functor())
{ }
typedef typename XprType::Index Index;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::PacketReturnType PacketReturnType;
EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
{
return m_functor(index);
}
template<int LoadMode>
EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const
{
return m_functor.packetOp(index);
}
private:
const NullaryOp m_functor;
};
// -------------------- CwiseUnaryOp -------------------- // -------------------- CwiseUnaryOp --------------------
@ -146,6 +182,54 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg
TensorEvaluator<RightArgType> m_rightImpl; TensorEvaluator<RightArgType> m_rightImpl;
}; };
// -------------------- SelectOp --------------------
template<typename IfArgType, typename ThenArgType, typename ElseArgType>
struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType> >
{
typedef TensorSelectOp<IfArgType, ThenArgType, ElseArgType> XprType;
enum {
IsAligned = TensorEvaluator<ThenArgType>::IsAligned & TensorEvaluator<ElseArgType>::IsAligned,
PacketAccess = TensorEvaluator<ThenArgType>::PacketAccess & TensorEvaluator<ElseArgType>::PacketAccess/* &
TensorEvaluator<IfArgType>::PacketAccess*/,
};
TensorEvaluator(const XprType& op)
: m_condImpl(op.ifExpression()),
m_thenImpl(op.thenExpression()),
m_elseImpl(op.elseExpression())
{ }
typedef typename XprType::Index Index;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::PacketReturnType PacketReturnType;
EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
{
return m_condImpl.coeff(index) ? m_thenImpl.coeff(index) : m_elseImpl.coeff(index);
}
template<int LoadMode>
EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const
{
static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
internal::Selector<PacketSize> select;
for (Index i = 0; i < PacketSize; ++i) {
select.select[i] = m_condImpl.coeff(index+i);
}
return internal::pblend(select,
m_thenImpl.template packet<LoadMode>(index),
m_elseImpl.template packet<LoadMode>(index));
}
private:
TensorEvaluator<IfArgType> m_condImpl;
TensorEvaluator<ThenArgType> m_thenImpl;
TensorEvaluator<ElseArgType> m_elseImpl;
};
} // end namespace Eigen } // end namespace Eigen
#endif // EIGEN_CXX11_TENSOR_TENSOR_EVALUATOR_H #endif // EIGEN_CXX11_TENSOR_TENSOR_EVALUATOR_H

View File

@ -17,6 +17,9 @@ namespace Eigen {
* *
* \brief Tensor expression classes. * \brief Tensor expression classes.
* *
* The TensorCwiseNullaryOp class applies a nullary operators to an expression. This
* is typically used to generate constants.
*
* The TensorCwiseUnaryOp class represents an expression where a unary operator * The TensorCwiseUnaryOp class represents an expression where a unary operator
* (e.g. cwiseSqrt) is applied to an expression. * (e.g. cwiseSqrt) is applied to an expression.
* *
@ -24,6 +27,46 @@ namespace Eigen {
* (e.g. addition) is applied to a lhs and a rhs expression. * (e.g. addition) is applied to a lhs and a rhs expression.
* *
*/ */
namespace internal {
template<typename NullaryOp, typename PlainObjectType>
struct traits<TensorCwiseNullaryOp<NullaryOp, PlainObjectType> >
: traits<PlainObjectType>
{
typedef typename PlainObjectType::Packet Packet;
typedef typename PlainObjectType::Scalar Scalar;
typedef typename PlainObjectType::Nested XprTypeNested;
typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
};
} // end namespace internal
template<typename NullaryOp, typename PlainObjectType>
class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, PlainObjectType> >
{
public:
typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Scalar Scalar;
typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Packet Packet;
typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
typedef typename PlainObjectType::CoeffReturnType CoeffReturnType;
typedef typename PlainObjectType::PacketReturnType PacketReturnType;
typedef TensorCwiseNullaryOp<NullaryOp, PlainObjectType> Nested;
typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::StorageKind StorageKind;
typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Index Index;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const NullaryOp& func = NullaryOp())
: m_functor(func) {}
EIGEN_DEVICE_FUNC
const NullaryOp& functor() const { return m_functor; }
protected:
// todo: add tensor dimension to be able to do some sanity checks
const NullaryOp m_functor;
};
namespace internal { namespace internal {
template<typename UnaryOp, typename XprType> template<typename UnaryOp, typename XprType>
@ -160,6 +203,72 @@ class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsX
const BinaryOp m_functor; const BinaryOp m_functor;
}; };
namespace internal {
template<typename IfXprType, typename ThenXprType, typename ElseXprType>
struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >
: traits<ThenXprType>
{
typedef typename traits<ThenXprType>::Scalar Scalar;
typedef typename internal::packet_traits<Scalar>::type Packet;
typedef typename promote_storage_type<typename traits<ThenXprType>::StorageKind,
typename traits<ElseXprType>::StorageKind>::ret StorageKind;
typedef typename promote_index_type<typename traits<ElseXprType>::Index,
typename traits<ThenXprType>::Index>::type Index;
typedef typename IfXprType::Nested IfNested;
typedef typename ThenXprType::Nested ThenNested;
typedef typename ElseXprType::Nested ElseNested;
};
template<typename IfXprType, typename ThenXprType, typename ElseXprType>
struct eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, Eigen::Dense>
{
typedef const TensorSelectOp<IfXprType, ThenXprType, ElseXprType>& type;
};
template<typename IfXprType, typename ThenXprType, typename ElseXprType>
struct nested<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, 1, typename eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >::type>
{
typedef TensorSelectOp<IfXprType, ThenXprType, ElseXprType> type;
};
} // end namespace internal
template<typename IfXprType, typename ThenXprType, typename ElseXprType>
class TensorSelectOp : public TensorBase<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >
{
public:
typedef typename Eigen::internal::traits<TensorSelectOp>::Scalar Scalar;
typedef typename Eigen::internal::traits<TensorSelectOp>::Packet Packet;
typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
typedef typename internal::promote_storage_type<typename ThenXprType::CoeffReturnType,
typename ElseXprType::CoeffReturnType>::ret CoeffReturnType;
typedef typename internal::promote_storage_type<typename ThenXprType::PacketReturnType,
typename ElseXprType::PacketReturnType>::ret PacketReturnType;
typedef typename Eigen::internal::nested<TensorSelectOp>::type Nested;
typedef typename Eigen::internal::traits<TensorSelectOp>::StorageKind StorageKind;
typedef typename Eigen::internal::traits<TensorSelectOp>::Index Index;
TensorSelectOp(const IfXprType& a_condition,
const ThenXprType& a_then,
const ElseXprType& a_else)
: m_condition(a_condition), m_then(a_then), m_else(a_else)
{ }
const IfXprType& ifExpression() const { return m_condition; }
const ThenXprType& thenExpression() const { return m_then; }
const ElseXprType& elseExpression() const { return m_else; }
protected:
typename IfXprType::Nested m_condition;
typename ThenXprType::Nested m_then;
typename ElseXprType::Nested m_else;
};
} // end namespace Eigen } // end namespace Eigen
#endif // EIGEN_CXX11_TENSOR_TENSOR_EXPR_H #endif // EIGEN_CXX11_TENSOR_TENSOR_EXPR_H

View File

@ -17,8 +17,10 @@ template<typename Scalar_, typename Dimensions, int Options_ = 0> class TensorFi
template<typename PlainObjectType, int Options_ = Unaligned> class TensorMap; template<typename PlainObjectType, int Options_ = Unaligned> class TensorMap;
template<typename Derived> class TensorBase; template<typename Derived> class TensorBase;
template<typename NullaryOp, typename PlainObjectType> class TensorCwiseNullaryOp;
template<typename UnaryOp, typename XprType> class TensorCwiseUnaryOp; template<typename UnaryOp, typename XprType> class TensorCwiseUnaryOp;
template<typename BinaryOp, typename LeftXprType, typename RightXprType> class TensorCwiseBinaryOp; template<typename BinaryOp, typename LeftXprType, typename RightXprType> class TensorCwiseBinaryOp;
template<typename IfXprType, typename ThenXprType, typename ElseXprType> class TensorSelectOp;
template<typename ExpressionType, typename DeviceType> class TensorDevice; template<typename ExpressionType, typename DeviceType> class TensorDevice;

View File

@ -45,33 +45,37 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor
static const int Options = Options_; static const int Options = Options_;
static const std::size_t NumIndices = PlainObjectType::NumIndices;
typedef typename PlainObjectType::Dimensions Dimensions;
enum { enum {
IsAligned = bool(EIGEN_ALIGN) && ((int(Options_)&Aligned)==Aligned), IsAligned = bool(EIGEN_ALIGN) && ((int(Options_)&Aligned)==Aligned),
PacketAccess = true, PacketAccess = true,
}; };
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension) : m_data(dataPtr), m_dimensions(array<DenseIndex, PlainObjectType::NumIndices>({{firstDimension}})) { EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension) : m_data(dataPtr), m_dimensions(array<DenseIndex, NumIndices>(firstDimension)) {
// The number of dimensions used to construct a tensor must be equal to the rank of the tensor. // The number of dimensions used to construct a tensor must be equal to the rank of the tensor.
EIGEN_STATIC_ASSERT(1 == PlainObjectType::NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE) EIGEN_STATIC_ASSERT(1 == NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE)
} }
#ifdef EIGEN_HAS_VARIADIC_TEMPLATES #ifdef EIGEN_HAS_VARIADIC_TEMPLATES
template<typename... IndexTypes> EIGEN_DEVICE_FUNC template<typename... IndexTypes> EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension, IndexTypes... otherDimensions) : m_data(dataPtr), m_dimensions(array<DenseIndex, PlainObjectType::NumIndices>({{firstDimension, otherDimensions...}})) { EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension, IndexTypes... otherDimensions) : m_data(dataPtr), m_dimensions(array<DenseIndex, NumIndices>({{firstDimension, otherDimensions...}})) {
// The number of dimensions used to construct a tensor must be equal to the rank of the tensor. // The number of dimensions used to construct a tensor must be equal to the rank of the tensor.
EIGEN_STATIC_ASSERT(sizeof...(otherDimensions) + 1 == PlainObjectType::NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE) EIGEN_STATIC_ASSERT(sizeof...(otherDimensions) + 1 == NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE)
} }
#endif #endif
inline TensorMap(PointerArgType dataPtr, const array<Index, PlainObjectType::NumIndices>& dimensions) inline TensorMap(PointerArgType dataPtr, const array<Index, NumIndices>& dimensions)
: m_data(dataPtr), m_dimensions(dimensions) : m_data(dataPtr), m_dimensions(dimensions)
{ } { }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_dimensions[n]; } EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_dimensions[n]; }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const typename PlainObjectType::Dimensions& dimensions() const { return m_dimensions; } EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Index size() const { return m_dimensions.TotalSize(); } EIGEN_STRONG_INLINE Index size() const { return m_dimensions.TotalSize(); }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
@ -80,7 +84,7 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor
EIGEN_STRONG_INLINE const Scalar* data() const { return m_data; } EIGEN_STRONG_INLINE const Scalar* data() const { return m_data; }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const Scalar& operator()(const array<Index, PlainObjectType::NumIndices>& indices) const EIGEN_STRONG_INLINE const Scalar& operator()(const array<Index, NumIndices>& indices) const
{ {
// eigen_assert(checkIndexRange(indices)); // eigen_assert(checkIndexRange(indices));
if (PlainObjectType::Options&RowMajor) { if (PlainObjectType::Options&RowMajor) {
@ -96,12 +100,12 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor
template<typename... IndexTypes> EIGEN_DEVICE_FUNC template<typename... IndexTypes> EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const Scalar& operator()(Index firstIndex, IndexTypes... otherIndices) const EIGEN_STRONG_INLINE const Scalar& operator()(Index firstIndex, IndexTypes... otherIndices) const
{ {
static_assert(sizeof...(otherIndices) + 1 == PlainObjectType::NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor."); static_assert(sizeof...(otherIndices) + 1 == NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor.");
if (PlainObjectType::Options&RowMajor) { if (PlainObjectType::Options&RowMajor) {
const Index index = m_dimensions.IndexOfRowMajor(array<Index, PlainObjectType::NumIndices>{{firstIndex, otherIndices...}}); const Index index = m_dimensions.IndexOfRowMajor(array<Index, NumIndices>{{firstIndex, otherIndices...}});
return m_data[index]; return m_data[index];
} else { } else {
const Index index = m_dimensions.IndexOfColMajor(array<Index, PlainObjectType::NumIndices>{{firstIndex, otherIndices...}}); const Index index = m_dimensions.IndexOfColMajor(array<Index, NumIndices>{{firstIndex, otherIndices...}});
return m_data[index]; return m_data[index];
} }
} }
@ -159,7 +163,7 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor
#endif #endif
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Scalar& operator()(const array<Index, PlainObjectType::NumIndices>& indices) EIGEN_STRONG_INLINE Scalar& operator()(const array<Index, NumIndices>& indices)
{ {
// eigen_assert(checkIndexRange(indices)); // eigen_assert(checkIndexRange(indices));
if (PlainObjectType::Options&RowMajor) { if (PlainObjectType::Options&RowMajor) {
@ -175,12 +179,12 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor
template<typename... IndexTypes> EIGEN_DEVICE_FUNC template<typename... IndexTypes> EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Scalar& operator()(Index firstIndex, IndexTypes... otherIndices) EIGEN_STRONG_INLINE Scalar& operator()(Index firstIndex, IndexTypes... otherIndices)
{ {
static_assert(sizeof...(otherIndices) + 1 == PlainObjectType::NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor."); static_assert(sizeof...(otherIndices) + 1 == NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor.");
if (PlainObjectType::Options&RowMajor) { if (PlainObjectType::Options&RowMajor) {
const Index index = m_dimensions.IndexOfRowMajor(array<Index, PlainObjectType::NumIndices>{{firstIndex, otherIndices...}}); const Index index = m_dimensions.IndexOfRowMajor(array<Index, NumIndices>{{firstIndex, otherIndices...}});
return m_data[index]; return m_data[index];
} else { } else {
const Index index = m_dimensions.IndexOfColMajor(array<Index, PlainObjectType::NumIndices>{{firstIndex, otherIndices...}}); const Index index = m_dimensions.IndexOfColMajor(array<Index, NumIndices>{{firstIndex, otherIndices...}});
return m_data[index]; return m_data[index];
} }
} }
@ -247,8 +251,8 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor
} }
private: private:
typename PlainObjectType::Scalar* m_data; Scalar* m_data;
typename PlainObjectType::Dimensions m_dimensions; Dimensions m_dimensions;
}; };
} // end namespace Eigen } // end namespace Eigen