mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-22 01:29:35 +08:00
Fixed the return types of unary and binary expressions to properly handle the case where it is different from the input type (e.g. abs(complex<float>))
This commit is contained in:
parent
d853adffdb
commit
94e47798f4
@ -155,8 +155,8 @@ struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, ArgType>, Device>
|
|||||||
|
|
||||||
typedef typename XprType::Index Index;
|
typedef typename XprType::Index Index;
|
||||||
typedef typename XprType::Scalar Scalar;
|
typedef typename XprType::Scalar Scalar;
|
||||||
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
|
||||||
typedef typename XprType::PacketReturnType PacketReturnType;
|
typedef typename internal::traits<XprType>::Packet PacketReturnType;
|
||||||
typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;
|
typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_argImpl.dimensions(); }
|
EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_argImpl.dimensions(); }
|
||||||
@ -203,8 +203,8 @@ struct TensorEvaluator<const TensorCwiseUnaryOp<UnaryOp, ArgType>, Device>
|
|||||||
|
|
||||||
typedef typename XprType::Index Index;
|
typedef typename XprType::Index Index;
|
||||||
typedef typename XprType::Scalar Scalar;
|
typedef typename XprType::Scalar Scalar;
|
||||||
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
|
||||||
typedef typename XprType::PacketReturnType PacketReturnType;
|
typedef typename internal::traits<XprType>::Packet PacketReturnType;
|
||||||
typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;
|
typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_argImpl.dimensions(); }
|
EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_argImpl.dimensions(); }
|
||||||
@ -257,8 +257,8 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg
|
|||||||
|
|
||||||
typedef typename XprType::Index Index;
|
typedef typename XprType::Index Index;
|
||||||
typedef typename XprType::Scalar Scalar;
|
typedef typename XprType::Scalar Scalar;
|
||||||
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
|
||||||
typedef typename XprType::PacketReturnType PacketReturnType;
|
typedef typename internal::traits<XprType>::Packet PacketReturnType;
|
||||||
typedef typename TensorEvaluator<LeftArgType, Device>::Dimensions Dimensions;
|
typedef typename TensorEvaluator<LeftArgType, Device>::Dimensions Dimensions;
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
|
EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
|
||||||
@ -317,8 +317,8 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>
|
|||||||
|
|
||||||
typedef typename XprType::Index Index;
|
typedef typename XprType::Index Index;
|
||||||
typedef typename XprType::Scalar Scalar;
|
typedef typename XprType::Scalar Scalar;
|
||||||
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
|
||||||
typedef typename XprType::PacketReturnType PacketReturnType;
|
typedef typename internal::traits<XprType>::Packet PacketReturnType;
|
||||||
typedef typename TensorEvaluator<IfArgType, Device>::Dimensions Dimensions;
|
typedef typename TensorEvaluator<IfArgType, Device>::Dimensions Dimensions;
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
|
EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
|
||||||
|
@ -84,9 +84,7 @@ struct traits<TensorCwiseUnaryOp<UnaryOp, XprType> >
|
|||||||
typedef typename result_of<
|
typedef typename result_of<
|
||||||
UnaryOp(typename XprType::Scalar)
|
UnaryOp(typename XprType::Scalar)
|
||||||
>::type Scalar;
|
>::type Scalar;
|
||||||
typedef typename result_of<
|
typedef typename internal::packet_traits<Scalar>::type Packet;
|
||||||
UnaryOp(typename XprType::Packet)
|
|
||||||
>::type Packet;
|
|
||||||
typedef typename XprType::Nested XprTypeNested;
|
typedef typename XprType::Nested XprTypeNested;
|
||||||
typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
|
typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
|
||||||
};
|
};
|
||||||
@ -188,8 +186,7 @@ class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsX
|
|||||||
typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
|
typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
|
||||||
typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
|
typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
|
||||||
typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
|
typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
|
||||||
typedef typename internal::promote_storage_type<typename LhsXprType::PacketReturnType,
|
typedef typename internal::packet_traits<CoeffReturnType>::type PacketReturnType;
|
||||||
typename RhsXprType::PacketReturnType>::ret PacketReturnType;
|
|
||||||
typedef typename Eigen::internal::nested<TensorCwiseBinaryOp>::type Nested;
|
typedef typename Eigen::internal::nested<TensorCwiseBinaryOp>::type Nested;
|
||||||
typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::StorageKind StorageKind;
|
typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::StorageKind StorageKind;
|
||||||
typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Index Index;
|
typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Index Index;
|
||||||
|
@ -32,6 +32,22 @@ static void test_additions()
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static void test_abs()
|
||||||
|
{
|
||||||
|
Tensor<std::complex<float>, 1> data1(3);
|
||||||
|
Tensor<std::complex<double>, 1> data2(3);
|
||||||
|
data1.setRandom();
|
||||||
|
data2.setRandom();
|
||||||
|
|
||||||
|
Tensor<float, 1> abs1 = data1.abs();
|
||||||
|
Tensor<double, 1> abs2 = data2.abs();
|
||||||
|
for (int i = 0; i < 3; ++i) {
|
||||||
|
VERIFY_IS_APPROX(abs1(i), std::abs(data1(i)));
|
||||||
|
VERIFY_IS_APPROX(abs2(i), std::abs(data2(i)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
static void test_contractions()
|
static void test_contractions()
|
||||||
{
|
{
|
||||||
Tensor<std::complex<float>, 4> t_left(30, 50, 8, 31);
|
Tensor<std::complex<float>, 4> t_left(30, 50, 8, 31);
|
||||||
@ -60,5 +76,6 @@ static void test_contractions()
|
|||||||
void test_cxx11_tensor_of_complex()
|
void test_cxx11_tensor_of_complex()
|
||||||
{
|
{
|
||||||
CALL_SUBTEST(test_additions());
|
CALL_SUBTEST(test_additions());
|
||||||
|
CALL_SUBTEST(test_abs());
|
||||||
CALL_SUBTEST(test_contractions());
|
CALL_SUBTEST(test_contractions());
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user