mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-13 04:09:10 +08:00
Fixed the return type of coefficient wise operations. For example, the abs function returns a floating point value when called on a complex input.
This commit is contained in:
parent
378bdfb7f0
commit
1ac8600126
@ -34,9 +34,15 @@ struct TensorEvaluator
|
||||
typedef typename Derived::Packet PacketReturnType;
|
||||
typedef typename Derived::Dimensions Dimensions;
|
||||
|
||||
// NumDimensions is -1 for variable dim tensors
|
||||
static const int NumCoords = internal::traits<Derived>::NumDimensions > 0 ?
|
||||
internal::traits<Derived>::NumDimensions : 0;
|
||||
|
||||
enum {
|
||||
IsAligned = Derived::IsAligned,
|
||||
PacketAccess = Derived::PacketAccess,
|
||||
Layout = Derived::Layout,
|
||||
CoordAccess = NumCoords > 0,
|
||||
};
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device& device)
|
||||
@ -77,6 +83,24 @@ struct TensorEvaluator
|
||||
return internal::pstoret<Scalar, Packet, StoreMode>(m_data + index, x);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(const array<DenseIndex, NumCoords>& coords) const {
|
||||
eigen_assert(m_data);
|
||||
if (Layout == ColMajor) {
|
||||
return m_data[m_dims.IndexOfColMajor(coords)];
|
||||
} else {
|
||||
return m_data[m_dims.IndexOfRowMajor(coords)];
|
||||
}
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(const array<DenseIndex, NumCoords>& coords) {
|
||||
eigen_assert(m_data);
|
||||
if (Layout == ColMajor) {
|
||||
return m_data[m_dims.IndexOfColMajor(coords)];
|
||||
} else {
|
||||
return m_data[m_dims.IndexOfRowMajor(coords)];
|
||||
}
|
||||
}
|
||||
|
||||
Scalar* data() const { return m_data; }
|
||||
|
||||
protected:
|
||||
@ -97,9 +121,15 @@ struct TensorEvaluator<const Derived, Device>
|
||||
typedef typename Derived::Packet PacketReturnType;
|
||||
typedef typename Derived::Dimensions Dimensions;
|
||||
|
||||
// NumDimensions is -1 for variable dim tensors
|
||||
static const int NumCoords = internal::traits<Derived>::NumDimensions > 0 ?
|
||||
internal::traits<Derived>::NumDimensions : 0;
|
||||
|
||||
enum {
|
||||
IsAligned = Derived::IsAligned,
|
||||
PacketAccess = Derived::PacketAccess,
|
||||
Layout = Derived::Layout,
|
||||
CoordAccess = NumCoords > 0,
|
||||
};
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device&)
|
||||
@ -126,6 +156,17 @@ struct TensorEvaluator<const Derived, Device>
|
||||
return internal::ploadt_ro<Packet, LoadMode>(m_data + index);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(const array<DenseIndex, NumCoords>& coords) const {
|
||||
eigen_assert(m_data);
|
||||
const Index index = (Layout == ColMajor) ? m_dims.IndexOfColMajor(coords)
|
||||
: m_dims.IndexOfRowMajor(coords);
|
||||
#ifdef __CUDA_ARCH__
|
||||
return __ldg(m_data+index);
|
||||
#else
|
||||
return m_data[index];
|
||||
#endif
|
||||
}
|
||||
|
||||
const Scalar* data() const { return m_data; }
|
||||
|
||||
protected:
|
||||
@ -146,6 +187,8 @@ struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, ArgType>, Device>
|
||||
enum {
|
||||
IsAligned = true,
|
||||
PacketAccess = internal::functor_traits<NullaryOp>::PacketAccess,
|
||||
Layout = TensorEvaluator<ArgType, Device>::Layout,
|
||||
CoordAccess = false, // to be implemented
|
||||
};
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
@ -194,6 +237,8 @@ struct TensorEvaluator<const TensorCwiseUnaryOp<UnaryOp, ArgType>, Device>
|
||||
enum {
|
||||
IsAligned = TensorEvaluator<ArgType, Device>::IsAligned,
|
||||
PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess & internal::functor_traits<UnaryOp>::PacketAccess,
|
||||
Layout = TensorEvaluator<ArgType, Device>::Layout,
|
||||
CoordAccess = false, // to be implemented
|
||||
};
|
||||
|
||||
EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
|
||||
@ -247,6 +292,8 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg
|
||||
IsAligned = TensorEvaluator<LeftArgType, Device>::IsAligned & TensorEvaluator<RightArgType, Device>::IsAligned,
|
||||
PacketAccess = TensorEvaluator<LeftArgType, Device>::PacketAccess & TensorEvaluator<RightArgType, Device>::PacketAccess &
|
||||
internal::functor_traits<BinaryOp>::PacketAccess,
|
||||
Layout = TensorEvaluator<LeftArgType, Device>::Layout,
|
||||
CoordAccess = false, // to be implemented
|
||||
};
|
||||
|
||||
EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
|
||||
@ -254,7 +301,8 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg
|
||||
m_leftImpl(op.lhsExpression(), device),
|
||||
m_rightImpl(op.rhsExpression(), device)
|
||||
{
|
||||
eigen_assert(internal::dimensions_match(m_leftImpl.dimensions(), m_rightImpl.dimensions()));
|
||||
EIGEN_STATIC_ASSERT((TensorEvaluator<LeftArgType, Device>::Layout == TensorEvaluator<RightArgType, Device>::Layout || internal::traits<XprType>::NumDimensions == 1), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||
eigen_assert(dimensions_match(m_leftImpl.dimensions(), m_rightImpl.dimensions()));
|
||||
}
|
||||
|
||||
typedef typename XprType::Index Index;
|
||||
@ -309,6 +357,8 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>
|
||||
IsAligned = TensorEvaluator<ThenArgType, Device>::IsAligned & TensorEvaluator<ElseArgType, Device>::IsAligned,
|
||||
PacketAccess = TensorEvaluator<ThenArgType, Device>::PacketAccess & TensorEvaluator<ElseArgType, Device>::PacketAccess/* &
|
||||
TensorEvaluator<IfArgType>::PacketAccess*/,
|
||||
Layout = TensorEvaluator<IfArgType, Device>::Layout,
|
||||
CoordAccess = false, // to be implemented
|
||||
};
|
||||
|
||||
EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
|
||||
@ -316,8 +366,10 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>
|
||||
m_thenImpl(op.thenExpression(), device),
|
||||
m_elseImpl(op.elseExpression(), device)
|
||||
{
|
||||
eigen_assert(internal::dimensions_match(m_condImpl.dimensions(), m_thenImpl.dimensions()));
|
||||
eigen_assert(internal::dimensions_match(m_thenImpl.dimensions(), m_elseImpl.dimensions()));
|
||||
EIGEN_STATIC_ASSERT((TensorEvaluator<IfArgType, Device>::Layout == TensorEvaluator<ThenArgType, Device>::Layout), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||
EIGEN_STATIC_ASSERT((TensorEvaluator<IfArgType, Device>::Layout == TensorEvaluator<ElseArgType, Device>::Layout), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||
eigen_assert(dimensions_match(m_condImpl.dimensions(), m_thenImpl.dimensions()));
|
||||
eigen_assert(dimensions_match(m_thenImpl.dimensions(), m_elseImpl.dimensions()));
|
||||
}
|
||||
|
||||
typedef typename XprType::Index Index;
|
||||
|
@ -17,14 +17,14 @@ namespace Eigen {
|
||||
*
|
||||
* \brief Tensor expression classes.
|
||||
*
|
||||
* The TensorCwiseNullaryOp class applies a nullary operators to an expression. This
|
||||
* is typically used to generate constants.
|
||||
* The TensorCwiseNullaryOp class applies a nullary operators to an expression.
|
||||
* This is typically used to generate constants.
|
||||
*
|
||||
* The TensorCwiseUnaryOp class represents an expression where a unary operator
|
||||
* (e.g. cwiseSqrt) is applied to an expression.
|
||||
*
|
||||
* The TensorCwiseBinaryOp class represents an expression where a binary operator
|
||||
* (e.g. addition) is applied to a lhs and a rhs expression.
|
||||
* The TensorCwiseBinaryOp class represents an expression where a binary
|
||||
* operator (e.g. addition) is applied to a lhs and a rhs expression.
|
||||
*
|
||||
*/
|
||||
namespace internal {
|
||||
@ -33,9 +33,12 @@ struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> >
|
||||
: traits<XprType>
|
||||
{
|
||||
typedef typename XprType::Packet Packet;
|
||||
typedef traits<XprType> XprTraits;
|
||||
typedef typename XprType::Scalar Scalar;
|
||||
typedef typename XprType::Nested XprTypeNested;
|
||||
typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
|
||||
static const int NumDimensions = XprTraits::NumDimensions;
|
||||
static const int Layout = XprTraits::Layout;
|
||||
|
||||
enum {
|
||||
Flags = 0,
|
||||
@ -47,7 +50,7 @@ struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> >
|
||||
|
||||
|
||||
template<typename NullaryOp, typename XprType>
|
||||
class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, XprType> >
|
||||
class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, XprType>, ReadOnlyAccessors>
|
||||
{
|
||||
public:
|
||||
typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Scalar Scalar;
|
||||
@ -81,12 +84,15 @@ template<typename UnaryOp, typename XprType>
|
||||
struct traits<TensorCwiseUnaryOp<UnaryOp, XprType> >
|
||||
: traits<XprType>
|
||||
{
|
||||
typedef typename result_of<
|
||||
UnaryOp(typename XprType::Scalar)
|
||||
>::type Scalar;
|
||||
// TODO(phli): Add InputScalar, InputPacket. Check references to
|
||||
// current Scalar/Packet to see if the intent is Input or Output.
|
||||
typedef typename result_of<UnaryOp(typename XprType::Scalar)>::type Scalar;
|
||||
typedef traits<XprType> XprTraits;
|
||||
typedef typename internal::packet_traits<Scalar>::type Packet;
|
||||
typedef typename XprType::Nested XprTypeNested;
|
||||
typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
|
||||
static const int NumDimensions = XprTraits::NumDimensions;
|
||||
static const int Layout = XprTraits::Layout;
|
||||
};
|
||||
|
||||
template<typename UnaryOp, typename XprType>
|
||||
@ -106,14 +112,16 @@ struct nested<TensorCwiseUnaryOp<UnaryOp, XprType>, 1, typename eval<TensorCwise
|
||||
|
||||
|
||||
template<typename UnaryOp, typename XprType>
|
||||
class TensorCwiseUnaryOp : public TensorBase<TensorCwiseUnaryOp<UnaryOp, XprType> >
|
||||
class TensorCwiseUnaryOp : public TensorBase<TensorCwiseUnaryOp<UnaryOp, XprType>, ReadOnlyAccessors>
|
||||
{
|
||||
public:
|
||||
// TODO(phli): Add InputScalar, InputPacket. Check references to
|
||||
// current Scalar/Packet to see if the intent is Input or Output.
|
||||
typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Scalar Scalar;
|
||||
typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Packet Packet;
|
||||
typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
|
||||
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
||||
typedef typename XprType::PacketReturnType PacketReturnType;
|
||||
typedef Scalar CoeffReturnType;
|
||||
typedef typename internal::packet_traits<CoeffReturnType>::type PacketReturnType;
|
||||
typedef typename Eigen::internal::nested<TensorCwiseUnaryOp>::type Nested;
|
||||
typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::StorageKind StorageKind;
|
||||
typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Index Index;
|
||||
@ -139,22 +147,27 @@ namespace internal {
|
||||
template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
|
||||
struct traits<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >
|
||||
{
|
||||
// Type promotion to handle the case where the types of the lhs and the rhs are different.
|
||||
// Type promotion to handle the case where the types of the lhs and the rhs
|
||||
// are different.
|
||||
// TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket. Check references to
|
||||
// current Scalar/Packet to see if the intent is Inputs or Output.
|
||||
typedef typename result_of<
|
||||
BinaryOp(
|
||||
typename LhsXprType::Scalar,
|
||||
typename RhsXprType::Scalar
|
||||
)
|
||||
>::type Scalar;
|
||||
BinaryOp(typename LhsXprType::Scalar,
|
||||
typename RhsXprType::Scalar)>::type Scalar;
|
||||
typedef traits<LhsXprType> XprTraits;
|
||||
typedef typename internal::packet_traits<Scalar>::type Packet;
|
||||
typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
|
||||
typename traits<RhsXprType>::StorageKind>::ret StorageKind;
|
||||
typedef typename promote_index_type<typename traits<LhsXprType>::Index,
|
||||
typename traits<RhsXprType>::Index>::type Index;
|
||||
typedef typename promote_storage_type<
|
||||
typename traits<LhsXprType>::StorageKind,
|
||||
typename traits<RhsXprType>::StorageKind>::ret StorageKind;
|
||||
typedef typename promote_index_type<
|
||||
typename traits<LhsXprType>::Index,
|
||||
typename traits<RhsXprType>::Index>::type Index;
|
||||
typedef typename LhsXprType::Nested LhsNested;
|
||||
typedef typename RhsXprType::Nested RhsNested;
|
||||
typedef typename remove_reference<LhsNested>::type _LhsNested;
|
||||
typedef typename remove_reference<RhsNested>::type _RhsNested;
|
||||
static const int NumDimensions = XprTraits::NumDimensions;
|
||||
static const int Layout = XprTraits::Layout;
|
||||
|
||||
enum {
|
||||
Flags = 0,
|
||||
@ -178,21 +191,22 @@ struct nested<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, 1, typename
|
||||
|
||||
|
||||
template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
|
||||
class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >
|
||||
class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, ReadOnlyAccessors>
|
||||
{
|
||||
public:
|
||||
typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Scalar Scalar;
|
||||
typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Packet Packet;
|
||||
typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
|
||||
typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
|
||||
typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
|
||||
typedef typename internal::packet_traits<CoeffReturnType>::type PacketReturnType;
|
||||
typedef typename Eigen::internal::nested<TensorCwiseBinaryOp>::type Nested;
|
||||
typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::StorageKind StorageKind;
|
||||
typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Index Index;
|
||||
// TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket. Check references to
|
||||
// current Scalar/Packet to see if the intent is Inputs or Output.
|
||||
typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Scalar Scalar;
|
||||
typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Packet Packet;
|
||||
typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
|
||||
typedef Scalar CoeffReturnType;
|
||||
typedef typename internal::packet_traits<CoeffReturnType>::type PacketReturnType;
|
||||
typedef typename Eigen::internal::nested<TensorCwiseBinaryOp>::type Nested;
|
||||
typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::StorageKind StorageKind;
|
||||
typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Index Index;
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const BinaryOp& func = BinaryOp())
|
||||
: m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_functor(func) {}
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const BinaryOp& func = BinaryOp())
|
||||
: m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_functor(func) {}
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
const BinaryOp& functor() const { return m_functor; }
|
||||
@ -219,7 +233,8 @@ struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >
|
||||
: traits<ThenXprType>
|
||||
{
|
||||
typedef typename traits<ThenXprType>::Scalar Scalar;
|
||||
typedef typename internal::packet_traits<Scalar>::type Packet;
|
||||
typedef traits<ThenXprType> XprTraits;
|
||||
typedef typename packet_traits<Scalar>::type Packet;
|
||||
typedef typename promote_storage_type<typename traits<ThenXprType>::StorageKind,
|
||||
typename traits<ElseXprType>::StorageKind>::ret StorageKind;
|
||||
typedef typename promote_index_type<typename traits<ElseXprType>::Index,
|
||||
@ -227,6 +242,8 @@ struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >
|
||||
typedef typename IfXprType::Nested IfNested;
|
||||
typedef typename ThenXprType::Nested ThenNested;
|
||||
typedef typename ElseXprType::Nested ElseNested;
|
||||
static const int NumDimensions = XprTraits::NumDimensions;
|
||||
static const int Layout = XprTraits::Layout;
|
||||
};
|
||||
|
||||
template<typename IfXprType, typename ThenXprType, typename ElseXprType>
|
||||
|
Loading…
x
Reference in New Issue
Block a user