mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-05-18 23:57:39 +08:00
Added support for additional tensor operations:
* comparison (<, <=, ==, !=, ...) * selection * nullary ops such as random or constant generation * misc unary ops such as log(), exp(), or a user defined unaryExpr() Cleaned up the code a little.
This commit is contained in:
parent
7402fea0a8
commit
736267cf6b
@ -33,21 +33,25 @@ class TensorBase
|
|||||||
Derived& setZero() {
|
Derived& setZero() {
|
||||||
return setConstant(Scalar(0));
|
return setConstant(Scalar(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
Derived& setConstant(const Scalar& val) {
|
Derived& setConstant(const Scalar& val) {
|
||||||
Scalar* data = derived().data();
|
return derived() = constant(val);
|
||||||
for (int i = 0; i < derived().size(); ++i) {
|
}
|
||||||
data[i] = val;
|
Derived& setRandom() {
|
||||||
}
|
return derived() = random();
|
||||||
return derived();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Derived& setRandom() {
|
// Nullary operators
|
||||||
Scalar* data = derived().data();
|
EIGEN_DEVICE_FUNC
|
||||||
for (int i = 0; i < derived().size(); ++i) {
|
EIGEN_STRONG_INLINE const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived>
|
||||||
data[i] = internal::random_default_impl<Scalar, false, false>::run();
|
constant(const Scalar& value) const {
|
||||||
}
|
return TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived>
|
||||||
return derived();
|
(internal::scalar_constant_op<Scalar>(value));
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE const TensorCwiseNullaryOp<internal::scalar_random_op<Scalar>, const Derived>
|
||||||
|
random() const {
|
||||||
|
return TensorCwiseNullaryOp<internal::scalar_random_op<Scalar>, const Derived>();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Coefficient-wise unary operators
|
// Coefficient-wise unary operators
|
||||||
@ -57,15 +61,31 @@ class TensorBase
|
|||||||
|
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_sqrt_op<Scalar>, const Derived>
|
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_sqrt_op<Scalar>, const Derived>
|
||||||
cwiseSqrt() const { return derived(); }
|
sqrt() const { return derived(); }
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_square_op<Scalar>, const Derived>
|
||||||
|
square() const { return derived(); }
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_inverse_op<Scalar>, const Derived>
|
||||||
|
inverse() const { return derived(); }
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_exp_op<Scalar>, const Derived>
|
||||||
|
exp() const { return derived(); }
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_log_op<Scalar>, const Derived>
|
||||||
|
log() const { return derived(); }
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_abs_op<Scalar>, const Derived>
|
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_abs_op<Scalar>, const Derived>
|
||||||
cwiseAbs() const { return derived(); }
|
abs() const { return derived(); }
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_pow_op<Scalar>, const Derived>
|
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_pow_op<Scalar>, const Derived>
|
||||||
cwisePow(Scalar exponent) const {
|
pow(Scalar exponent) const {
|
||||||
return TensorCwiseUnaryOp<internal::scalar_pow_op<Scalar>, const Derived>
|
return TensorCwiseUnaryOp<internal::scalar_pow_op<Scalar>, const Derived>
|
||||||
(derived(), internal::scalar_pow_op<Scalar>(exponent));
|
(derived(), internal::scalar_pow_op<Scalar>(exponent));
|
||||||
}
|
}
|
||||||
@ -77,6 +97,30 @@ class TensorBase
|
|||||||
(derived(), internal::scalar_multiple_op<Scalar>(scale));
|
(derived(), internal::scalar_multiple_op<Scalar>(scale));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_max_op<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
|
||||||
|
cwiseMax(Scalar threshold) const {
|
||||||
|
return cwiseMax(constant(threshold));
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_min_op<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
|
||||||
|
cwiseMin(Scalar threshold) const {
|
||||||
|
return cwiseMin(constant(threshold));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename CustomUnaryOp> EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<CustomUnaryOp, const Derived>
|
||||||
|
unaryExpr(const CustomUnaryOp& func) const {
|
||||||
|
return TensorCwiseUnaryOp<CustomUnaryOp, const Derived>(derived(), func);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename NewType> EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_cast_op<Scalar, NewType>, const Derived>
|
||||||
|
cast() const {
|
||||||
|
return derived();
|
||||||
|
}
|
||||||
|
|
||||||
// Coefficient-wise binary operators.
|
// Coefficient-wise binary operators.
|
||||||
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
const TensorCwiseBinaryOp<internal::scalar_sum_op<Scalar>, const Derived, const OtherDerived>
|
const TensorCwiseBinaryOp<internal::scalar_sum_op<Scalar>, const Derived, const OtherDerived>
|
||||||
@ -90,6 +134,71 @@ class TensorBase
|
|||||||
return TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
|
return TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
|
const TensorCwiseBinaryOp<internal::scalar_product_op<Scalar>, const Derived, const OtherDerived>
|
||||||
|
operator*(const OtherDerived& other) const {
|
||||||
|
return TensorCwiseBinaryOp<internal::scalar_product_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
|
const TensorCwiseBinaryOp<internal::scalar_quotient_op<Scalar>, const Derived, const OtherDerived>
|
||||||
|
operator/(const OtherDerived& other) const {
|
||||||
|
return TensorCwiseBinaryOp<internal::scalar_quotient_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
|
const TensorCwiseBinaryOp<internal::scalar_max_op<Scalar>, const Derived, const OtherDerived>
|
||||||
|
cwiseMax(const OtherDerived& other) const {
|
||||||
|
return TensorCwiseBinaryOp<internal::scalar_max_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
|
const TensorCwiseBinaryOp<internal::scalar_min_op<Scalar>, const Derived, const OtherDerived>
|
||||||
|
cwiseMin(const OtherDerived& other) const {
|
||||||
|
return TensorCwiseBinaryOp<internal::scalar_min_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Comparisons and tests.
|
||||||
|
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
|
const TensorCwiseBinaryOp<std::less<Scalar>, const Derived, const OtherDerived>
|
||||||
|
operator<(const OtherDerived& other) const {
|
||||||
|
return TensorCwiseBinaryOp<std::less<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
|
||||||
|
}
|
||||||
|
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
|
const TensorCwiseBinaryOp<std::less_equal<Scalar>, const Derived, const OtherDerived>
|
||||||
|
operator<=(const OtherDerived& other) const {
|
||||||
|
return TensorCwiseBinaryOp<std::less_equal<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
|
||||||
|
}
|
||||||
|
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
|
const TensorCwiseBinaryOp<std::greater<Scalar>, const Derived, const OtherDerived>
|
||||||
|
operator>(const OtherDerived& other) const {
|
||||||
|
return TensorCwiseBinaryOp<std::greater<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
|
||||||
|
}
|
||||||
|
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
|
const TensorCwiseBinaryOp<std::greater_equal<Scalar>, const Derived, const OtherDerived>
|
||||||
|
operator>=(const OtherDerived& other) const {
|
||||||
|
return TensorCwiseBinaryOp<std::greater_equal<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
|
const TensorCwiseBinaryOp<std::equal_to<Scalar>, const Derived, const OtherDerived>
|
||||||
|
operator==(const OtherDerived& other) const {
|
||||||
|
return TensorCwiseBinaryOp<std::equal_to<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
|
||||||
|
}
|
||||||
|
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
|
const TensorCwiseBinaryOp<std::not_equal_to<Scalar>, const Derived, const OtherDerived>
|
||||||
|
operator!=(const OtherDerived& other) const {
|
||||||
|
return TensorCwiseBinaryOp<std::not_equal_to<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Coefficient-wise ternary operators.
|
||||||
|
template<typename ThenDerived,typename ElseDerived>
|
||||||
|
inline const TensorSelectOp<const Derived, const ThenDerived, const ElseDerived>
|
||||||
|
select(const ThenDerived& thenTensor, const ElseDerived& elseTensor) const{
|
||||||
|
return TensorSelectOp<const Derived, const ThenDerived, const ElseDerived>(derived(), thenTensor.derived(), elseTensor.derived());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select the device on which to evaluate the expression.
|
||||||
template <typename DeviceType>
|
template <typename DeviceType>
|
||||||
TensorDevice<Derived, DeviceType> device(const DeviceType& device) {
|
TensorDevice<Derived, DeviceType> device(const DeviceType& device) {
|
||||||
return TensorDevice<Derived, DeviceType>(device, derived());
|
return TensorDevice<Derived, DeviceType>(device, derived());
|
||||||
|
@ -68,6 +68,42 @@ struct TensorEvaluator
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// -------------------- CwiseNullaryOp --------------------
|
||||||
|
|
||||||
|
template<typename NullaryOp, typename PlainObjectType>
|
||||||
|
struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, PlainObjectType> >
|
||||||
|
{
|
||||||
|
typedef TensorCwiseNullaryOp<NullaryOp, PlainObjectType> XprType;
|
||||||
|
|
||||||
|
enum {
|
||||||
|
IsAligned = true,
|
||||||
|
PacketAccess = internal::functor_traits<NullaryOp>::PacketAccess,
|
||||||
|
};
|
||||||
|
|
||||||
|
TensorEvaluator(const XprType& op)
|
||||||
|
: m_functor(op.functor())
|
||||||
|
{ }
|
||||||
|
|
||||||
|
typedef typename XprType::Index Index;
|
||||||
|
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
||||||
|
typedef typename XprType::PacketReturnType PacketReturnType;
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
|
||||||
|
{
|
||||||
|
return m_functor(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int LoadMode>
|
||||||
|
EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const
|
||||||
|
{
|
||||||
|
return m_functor.packetOp(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const NullaryOp m_functor;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// -------------------- CwiseUnaryOp --------------------
|
// -------------------- CwiseUnaryOp --------------------
|
||||||
|
|
||||||
@ -146,6 +182,54 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg
|
|||||||
TensorEvaluator<RightArgType> m_rightImpl;
|
TensorEvaluator<RightArgType> m_rightImpl;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
// -------------------- SelectOp --------------------
|
||||||
|
|
||||||
|
template<typename IfArgType, typename ThenArgType, typename ElseArgType>
|
||||||
|
struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType> >
|
||||||
|
{
|
||||||
|
typedef TensorSelectOp<IfArgType, ThenArgType, ElseArgType> XprType;
|
||||||
|
|
||||||
|
enum {
|
||||||
|
IsAligned = TensorEvaluator<ThenArgType>::IsAligned & TensorEvaluator<ElseArgType>::IsAligned,
|
||||||
|
PacketAccess = TensorEvaluator<ThenArgType>::PacketAccess & TensorEvaluator<ElseArgType>::PacketAccess/* &
|
||||||
|
TensorEvaluator<IfArgType>::PacketAccess*/,
|
||||||
|
};
|
||||||
|
|
||||||
|
TensorEvaluator(const XprType& op)
|
||||||
|
: m_condImpl(op.ifExpression()),
|
||||||
|
m_thenImpl(op.thenExpression()),
|
||||||
|
m_elseImpl(op.elseExpression())
|
||||||
|
{ }
|
||||||
|
|
||||||
|
typedef typename XprType::Index Index;
|
||||||
|
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
||||||
|
typedef typename XprType::PacketReturnType PacketReturnType;
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
|
||||||
|
{
|
||||||
|
return m_condImpl.coeff(index) ? m_thenImpl.coeff(index) : m_elseImpl.coeff(index);
|
||||||
|
}
|
||||||
|
template<int LoadMode>
|
||||||
|
EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const
|
||||||
|
{
|
||||||
|
static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
|
||||||
|
internal::Selector<PacketSize> select;
|
||||||
|
for (Index i = 0; i < PacketSize; ++i) {
|
||||||
|
select.select[i] = m_condImpl.coeff(index+i);
|
||||||
|
}
|
||||||
|
return internal::pblend(select,
|
||||||
|
m_thenImpl.template packet<LoadMode>(index),
|
||||||
|
m_elseImpl.template packet<LoadMode>(index));
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
TensorEvaluator<IfArgType> m_condImpl;
|
||||||
|
TensorEvaluator<ThenArgType> m_thenImpl;
|
||||||
|
TensorEvaluator<ElseArgType> m_elseImpl;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
|
||||||
#endif // EIGEN_CXX11_TENSOR_TENSOR_EVALUATOR_H
|
#endif // EIGEN_CXX11_TENSOR_TENSOR_EVALUATOR_H
|
||||||
|
@ -17,6 +17,9 @@ namespace Eigen {
|
|||||||
*
|
*
|
||||||
* \brief Tensor expression classes.
|
* \brief Tensor expression classes.
|
||||||
*
|
*
|
||||||
|
* The TensorCwiseNullaryOp class applies a nullary operators to an expression. This
|
||||||
|
* is typically used to generate constants.
|
||||||
|
*
|
||||||
* The TensorCwiseUnaryOp class represents an expression where a unary operator
|
* The TensorCwiseUnaryOp class represents an expression where a unary operator
|
||||||
* (e.g. cwiseSqrt) is applied to an expression.
|
* (e.g. cwiseSqrt) is applied to an expression.
|
||||||
*
|
*
|
||||||
@ -24,6 +27,46 @@ namespace Eigen {
|
|||||||
* (e.g. addition) is applied to a lhs and a rhs expression.
|
* (e.g. addition) is applied to a lhs and a rhs expression.
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
namespace internal {
|
||||||
|
template<typename NullaryOp, typename PlainObjectType>
|
||||||
|
struct traits<TensorCwiseNullaryOp<NullaryOp, PlainObjectType> >
|
||||||
|
: traits<PlainObjectType>
|
||||||
|
{
|
||||||
|
typedef typename PlainObjectType::Packet Packet;
|
||||||
|
typedef typename PlainObjectType::Scalar Scalar;
|
||||||
|
typedef typename PlainObjectType::Nested XprTypeNested;
|
||||||
|
typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end namespace internal
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template<typename NullaryOp, typename PlainObjectType>
|
||||||
|
class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, PlainObjectType> >
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Scalar Scalar;
|
||||||
|
typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Packet Packet;
|
||||||
|
typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
|
||||||
|
typedef typename PlainObjectType::CoeffReturnType CoeffReturnType;
|
||||||
|
typedef typename PlainObjectType::PacketReturnType PacketReturnType;
|
||||||
|
typedef TensorCwiseNullaryOp<NullaryOp, PlainObjectType> Nested;
|
||||||
|
typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::StorageKind StorageKind;
|
||||||
|
typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Index Index;
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const NullaryOp& func = NullaryOp())
|
||||||
|
: m_functor(func) {}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
const NullaryOp& functor() const { return m_functor; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
// todo: add tensor dimension to be able to do some sanity checks
|
||||||
|
const NullaryOp m_functor;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
template<typename UnaryOp, typename XprType>
|
template<typename UnaryOp, typename XprType>
|
||||||
@ -160,6 +203,72 @@ class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsX
|
|||||||
const BinaryOp m_functor;
|
const BinaryOp m_functor;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
namespace internal {
|
||||||
|
template<typename IfXprType, typename ThenXprType, typename ElseXprType>
|
||||||
|
struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >
|
||||||
|
: traits<ThenXprType>
|
||||||
|
{
|
||||||
|
typedef typename traits<ThenXprType>::Scalar Scalar;
|
||||||
|
typedef typename internal::packet_traits<Scalar>::type Packet;
|
||||||
|
typedef typename promote_storage_type<typename traits<ThenXprType>::StorageKind,
|
||||||
|
typename traits<ElseXprType>::StorageKind>::ret StorageKind;
|
||||||
|
typedef typename promote_index_type<typename traits<ElseXprType>::Index,
|
||||||
|
typename traits<ThenXprType>::Index>::type Index;
|
||||||
|
typedef typename IfXprType::Nested IfNested;
|
||||||
|
typedef typename ThenXprType::Nested ThenNested;
|
||||||
|
typedef typename ElseXprType::Nested ElseNested;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename IfXprType, typename ThenXprType, typename ElseXprType>
|
||||||
|
struct eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, Eigen::Dense>
|
||||||
|
{
|
||||||
|
typedef const TensorSelectOp<IfXprType, ThenXprType, ElseXprType>& type;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename IfXprType, typename ThenXprType, typename ElseXprType>
|
||||||
|
struct nested<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, 1, typename eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >::type>
|
||||||
|
{
|
||||||
|
typedef TensorSelectOp<IfXprType, ThenXprType, ElseXprType> type;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end namespace internal
|
||||||
|
|
||||||
|
|
||||||
|
template<typename IfXprType, typename ThenXprType, typename ElseXprType>
|
||||||
|
class TensorSelectOp : public TensorBase<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
typedef typename Eigen::internal::traits<TensorSelectOp>::Scalar Scalar;
|
||||||
|
typedef typename Eigen::internal::traits<TensorSelectOp>::Packet Packet;
|
||||||
|
typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
|
||||||
|
typedef typename internal::promote_storage_type<typename ThenXprType::CoeffReturnType,
|
||||||
|
typename ElseXprType::CoeffReturnType>::ret CoeffReturnType;
|
||||||
|
typedef typename internal::promote_storage_type<typename ThenXprType::PacketReturnType,
|
||||||
|
typename ElseXprType::PacketReturnType>::ret PacketReturnType;
|
||||||
|
typedef typename Eigen::internal::nested<TensorSelectOp>::type Nested;
|
||||||
|
typedef typename Eigen::internal::traits<TensorSelectOp>::StorageKind StorageKind;
|
||||||
|
typedef typename Eigen::internal::traits<TensorSelectOp>::Index Index;
|
||||||
|
|
||||||
|
TensorSelectOp(const IfXprType& a_condition,
|
||||||
|
const ThenXprType& a_then,
|
||||||
|
const ElseXprType& a_else)
|
||||||
|
: m_condition(a_condition), m_then(a_then), m_else(a_else)
|
||||||
|
{ }
|
||||||
|
|
||||||
|
const IfXprType& ifExpression() const { return m_condition; }
|
||||||
|
|
||||||
|
const ThenXprType& thenExpression() const { return m_then; }
|
||||||
|
|
||||||
|
const ElseXprType& elseExpression() const { return m_else; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
typename IfXprType::Nested m_condition;
|
||||||
|
typename ThenXprType::Nested m_then;
|
||||||
|
typename ElseXprType::Nested m_else;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
|
||||||
#endif // EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
|
#endif // EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
|
||||||
|
@ -17,8 +17,10 @@ template<typename Scalar_, typename Dimensions, int Options_ = 0> class TensorFi
|
|||||||
template<typename PlainObjectType, int Options_ = Unaligned> class TensorMap;
|
template<typename PlainObjectType, int Options_ = Unaligned> class TensorMap;
|
||||||
template<typename Derived> class TensorBase;
|
template<typename Derived> class TensorBase;
|
||||||
|
|
||||||
|
template<typename NullaryOp, typename PlainObjectType> class TensorCwiseNullaryOp;
|
||||||
template<typename UnaryOp, typename XprType> class TensorCwiseUnaryOp;
|
template<typename UnaryOp, typename XprType> class TensorCwiseUnaryOp;
|
||||||
template<typename BinaryOp, typename LeftXprType, typename RightXprType> class TensorCwiseBinaryOp;
|
template<typename BinaryOp, typename LeftXprType, typename RightXprType> class TensorCwiseBinaryOp;
|
||||||
|
template<typename IfXprType, typename ThenXprType, typename ElseXprType> class TensorSelectOp;
|
||||||
|
|
||||||
template<typename ExpressionType, typename DeviceType> class TensorDevice;
|
template<typename ExpressionType, typename DeviceType> class TensorDevice;
|
||||||
|
|
||||||
|
@ -45,33 +45,37 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor
|
|||||||
|
|
||||||
static const int Options = Options_;
|
static const int Options = Options_;
|
||||||
|
|
||||||
|
static const std::size_t NumIndices = PlainObjectType::NumIndices;
|
||||||
|
typedef typename PlainObjectType::Dimensions Dimensions;
|
||||||
|
|
||||||
|
|
||||||
enum {
|
enum {
|
||||||
IsAligned = bool(EIGEN_ALIGN) && ((int(Options_)&Aligned)==Aligned),
|
IsAligned = bool(EIGEN_ALIGN) && ((int(Options_)&Aligned)==Aligned),
|
||||||
PacketAccess = true,
|
PacketAccess = true,
|
||||||
};
|
};
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension) : m_data(dataPtr), m_dimensions(array<DenseIndex, PlainObjectType::NumIndices>({{firstDimension}})) {
|
EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension) : m_data(dataPtr), m_dimensions(array<DenseIndex, NumIndices>(firstDimension)) {
|
||||||
// The number of dimensions used to construct a tensor must be equal to the rank of the tensor.
|
// The number of dimensions used to construct a tensor must be equal to the rank of the tensor.
|
||||||
EIGEN_STATIC_ASSERT(1 == PlainObjectType::NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE)
|
EIGEN_STATIC_ASSERT(1 == NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE)
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef EIGEN_HAS_VARIADIC_TEMPLATES
|
#ifdef EIGEN_HAS_VARIADIC_TEMPLATES
|
||||||
template<typename... IndexTypes> EIGEN_DEVICE_FUNC
|
template<typename... IndexTypes> EIGEN_DEVICE_FUNC
|
||||||
EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension, IndexTypes... otherDimensions) : m_data(dataPtr), m_dimensions(array<DenseIndex, PlainObjectType::NumIndices>({{firstDimension, otherDimensions...}})) {
|
EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension, IndexTypes... otherDimensions) : m_data(dataPtr), m_dimensions(array<DenseIndex, NumIndices>({{firstDimension, otherDimensions...}})) {
|
||||||
// The number of dimensions used to construct a tensor must be equal to the rank of the tensor.
|
// The number of dimensions used to construct a tensor must be equal to the rank of the tensor.
|
||||||
EIGEN_STATIC_ASSERT(sizeof...(otherDimensions) + 1 == PlainObjectType::NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE)
|
EIGEN_STATIC_ASSERT(sizeof...(otherDimensions) + 1 == NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE)
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
inline TensorMap(PointerArgType dataPtr, const array<Index, PlainObjectType::NumIndices>& dimensions)
|
inline TensorMap(PointerArgType dataPtr, const array<Index, NumIndices>& dimensions)
|
||||||
: m_data(dataPtr), m_dimensions(dimensions)
|
: m_data(dataPtr), m_dimensions(dimensions)
|
||||||
{ }
|
{ }
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_dimensions[n]; }
|
EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_dimensions[n]; }
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
EIGEN_STRONG_INLINE const typename PlainObjectType::Dimensions& dimensions() const { return m_dimensions; }
|
EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
EIGEN_STRONG_INLINE Index size() const { return m_dimensions.TotalSize(); }
|
EIGEN_STRONG_INLINE Index size() const { return m_dimensions.TotalSize(); }
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
@ -80,7 +84,7 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor
|
|||||||
EIGEN_STRONG_INLINE const Scalar* data() const { return m_data; }
|
EIGEN_STRONG_INLINE const Scalar* data() const { return m_data; }
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
EIGEN_STRONG_INLINE const Scalar& operator()(const array<Index, PlainObjectType::NumIndices>& indices) const
|
EIGEN_STRONG_INLINE const Scalar& operator()(const array<Index, NumIndices>& indices) const
|
||||||
{
|
{
|
||||||
// eigen_assert(checkIndexRange(indices));
|
// eigen_assert(checkIndexRange(indices));
|
||||||
if (PlainObjectType::Options&RowMajor) {
|
if (PlainObjectType::Options&RowMajor) {
|
||||||
@ -96,12 +100,12 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor
|
|||||||
template<typename... IndexTypes> EIGEN_DEVICE_FUNC
|
template<typename... IndexTypes> EIGEN_DEVICE_FUNC
|
||||||
EIGEN_STRONG_INLINE const Scalar& operator()(Index firstIndex, IndexTypes... otherIndices) const
|
EIGEN_STRONG_INLINE const Scalar& operator()(Index firstIndex, IndexTypes... otherIndices) const
|
||||||
{
|
{
|
||||||
static_assert(sizeof...(otherIndices) + 1 == PlainObjectType::NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor.");
|
static_assert(sizeof...(otherIndices) + 1 == NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor.");
|
||||||
if (PlainObjectType::Options&RowMajor) {
|
if (PlainObjectType::Options&RowMajor) {
|
||||||
const Index index = m_dimensions.IndexOfRowMajor(array<Index, PlainObjectType::NumIndices>{{firstIndex, otherIndices...}});
|
const Index index = m_dimensions.IndexOfRowMajor(array<Index, NumIndices>{{firstIndex, otherIndices...}});
|
||||||
return m_data[index];
|
return m_data[index];
|
||||||
} else {
|
} else {
|
||||||
const Index index = m_dimensions.IndexOfColMajor(array<Index, PlainObjectType::NumIndices>{{firstIndex, otherIndices...}});
|
const Index index = m_dimensions.IndexOfColMajor(array<Index, NumIndices>{{firstIndex, otherIndices...}});
|
||||||
return m_data[index];
|
return m_data[index];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -159,7 +163,7 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
EIGEN_STRONG_INLINE Scalar& operator()(const array<Index, PlainObjectType::NumIndices>& indices)
|
EIGEN_STRONG_INLINE Scalar& operator()(const array<Index, NumIndices>& indices)
|
||||||
{
|
{
|
||||||
// eigen_assert(checkIndexRange(indices));
|
// eigen_assert(checkIndexRange(indices));
|
||||||
if (PlainObjectType::Options&RowMajor) {
|
if (PlainObjectType::Options&RowMajor) {
|
||||||
@ -175,12 +179,12 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor
|
|||||||
template<typename... IndexTypes> EIGEN_DEVICE_FUNC
|
template<typename... IndexTypes> EIGEN_DEVICE_FUNC
|
||||||
EIGEN_STRONG_INLINE Scalar& operator()(Index firstIndex, IndexTypes... otherIndices)
|
EIGEN_STRONG_INLINE Scalar& operator()(Index firstIndex, IndexTypes... otherIndices)
|
||||||
{
|
{
|
||||||
static_assert(sizeof...(otherIndices) + 1 == PlainObjectType::NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor.");
|
static_assert(sizeof...(otherIndices) + 1 == NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor.");
|
||||||
if (PlainObjectType::Options&RowMajor) {
|
if (PlainObjectType::Options&RowMajor) {
|
||||||
const Index index = m_dimensions.IndexOfRowMajor(array<Index, PlainObjectType::NumIndices>{{firstIndex, otherIndices...}});
|
const Index index = m_dimensions.IndexOfRowMajor(array<Index, NumIndices>{{firstIndex, otherIndices...}});
|
||||||
return m_data[index];
|
return m_data[index];
|
||||||
} else {
|
} else {
|
||||||
const Index index = m_dimensions.IndexOfColMajor(array<Index, PlainObjectType::NumIndices>{{firstIndex, otherIndices...}});
|
const Index index = m_dimensions.IndexOfColMajor(array<Index, NumIndices>{{firstIndex, otherIndices...}});
|
||||||
return m_data[index];
|
return m_data[index];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -247,8 +251,8 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
typename PlainObjectType::Scalar* m_data;
|
Scalar* m_data;
|
||||||
typename PlainObjectType::Dimensions m_dimensions;
|
Dimensions m_dimensions;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
Loading…
x
Reference in New Issue
Block a user