Add arg() to tensor

This commit is contained in:
Tobias Wood 2022-05-17 12:01:39 +01:00 committed by Antonio Sánchez
parent 028ab12586
commit a9868bd5be
3 changed files with 38 additions and 0 deletions

View File

@ -886,6 +886,23 @@ containing the natural logarithms of the original tensor.
Returns a tensor of the same type and dimensions as the original tensor
containing the absolute values of the original tensor.
### <Operation> arg()
Returns a tensor with the same dimensions as the original tensor
containing the complex argument (phase angle) of the values of the
original tensor.
### <Operation> real()
Returns a tensor with the same dimensions as the original tensor
containing the real part of the complex values of the original tensor.
### <Operation> imag()
Returns a tensor with the same dimensions as the orginal tensor
containing the imaginary part of the complex values of the original
tensor.
### <Operation> pow(Scalar exponent)
Returns a tensor of the same type and dimensions as the original tensor

View File

@ -311,6 +311,12 @@ class TensorBase<Derived, ReadOnlyAccessors>
return unaryExpr(internal::scalar_abs_op<Scalar>());
}
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_arg_op<Scalar>, const Derived>
arg() const {
return unaryExpr(internal::scalar_arg_op<Scalar>());
}
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_clamp_op<Scalar>, const Derived>
clip(Scalar min, Scalar max) const {

View File

@ -47,6 +47,20 @@ static void test_abs()
}
}
static void test_arg()
{
Tensor<std::complex<float>, 1> data1(3);
Tensor<std::complex<double>, 1> data2(3);
data1.setRandom();
data2.setRandom();
Tensor<float, 1> arg1 = data1.arg();
Tensor<double, 1> arg2 = data2.arg();
for (int i = 0; i < 3; ++i) {
VERIFY_IS_APPROX(arg1(i), std::arg(data1(i)));
VERIFY_IS_APPROX(arg2(i), std::arg(data2(i)));
}
}
static void test_conjugate()
{
@ -98,6 +112,7 @@ EIGEN_DECLARE_TEST(cxx11_tensor_of_complex)
{
CALL_SUBTEST(test_additions());
CALL_SUBTEST(test_abs());
CALL_SUBTEST(test_arg());
CALL_SUBTEST(test_conjugate());
CALL_SUBTEST(test_contractions());
}