mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 18:54:00 +08:00
Added some syntactic sugar to make it simpler to compare a tensor to a scalar.
This commit is contained in:
parent
5ca2e25967
commit
b178cc3479
@ -279,6 +279,38 @@ class TensorBase<Derived, ReadOnlyAccessors>
|
|||||||
return binaryExpr(other.derived(), std::not_equal_to<Scalar>());
|
return binaryExpr(other.derived(), std::not_equal_to<Scalar>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// comparisons and tests for Scalars
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<std::less<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
|
||||||
|
operator<(Scalar threshold) const {
|
||||||
|
return operator<(constant(threshold));
|
||||||
|
}
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<std::less_equal<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
|
||||||
|
operator<=(Scalar threshold) const {
|
||||||
|
return operator<=(constant(threshold));
|
||||||
|
}
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<std::greater<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
|
||||||
|
operator>(Scalar threshold) const {
|
||||||
|
return operator>(constant(threshold));
|
||||||
|
}
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<std::greater_equal<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
|
||||||
|
operator>=(Scalar threshold) const {
|
||||||
|
return operator>=(constant(threshold));
|
||||||
|
}
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<std::equal_to<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
|
||||||
|
operator==(Scalar threshold) const {
|
||||||
|
return operator==(constant(threshold));
|
||||||
|
}
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<std::not_equal_to<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
|
||||||
|
operator!=(Scalar threshold) const {
|
||||||
|
return operator!=(constant(threshold));
|
||||||
|
}
|
||||||
|
|
||||||
// Coefficient-wise ternary operators.
|
// Coefficient-wise ternary operators.
|
||||||
template<typename ThenDerived, typename ElseDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
template<typename ThenDerived, typename ElseDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
const TensorSelectOp<const Derived, const ThenDerived, const ElseDerived>
|
const TensorSelectOp<const Derived, const ThenDerived, const ElseDerived>
|
||||||
|
@ -143,6 +143,7 @@ if(EIGEN_TEST_CXX11)
|
|||||||
ei_add_test(cxx11_tensor_generator "-std=c++0x")
|
ei_add_test(cxx11_tensor_generator "-std=c++0x")
|
||||||
ei_add_test(cxx11_tensor_custom_op "-std=c++0x")
|
ei_add_test(cxx11_tensor_custom_op "-std=c++0x")
|
||||||
ei_add_test(cxx11_tensor_custom_index "-std=c++0x")
|
ei_add_test(cxx11_tensor_custom_index "-std=c++0x")
|
||||||
|
ei_add_test(cxx11_tensor_sugar "-std=c++0x")
|
||||||
|
|
||||||
# These tests needs nvcc
|
# These tests needs nvcc
|
||||||
# ei_add_test(cxx11_tensor_device "-std=c++0x")
|
# ei_add_test(cxx11_tensor_device "-std=c++0x")
|
||||||
|
38
unsupported/test/cxx11_tensor_sugar.cpp
Normal file
38
unsupported/test/cxx11_tensor_sugar.cpp
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
#include "main.h"
|
||||||
|
|
||||||
|
#include <Eigen/CXX11/Tensor>
|
||||||
|
|
||||||
|
using Eigen::Tensor;
|
||||||
|
using Eigen::RowMajor;
|
||||||
|
|
||||||
|
static void test_comparison_sugar() {
|
||||||
|
// we already trust comparisons between tensors, we're simply checking that
|
||||||
|
// the sugared versions are doing the same thing
|
||||||
|
Tensor<int, 3> t(6, 7, 5);
|
||||||
|
|
||||||
|
t.setRandom();
|
||||||
|
// make sure we have at least one value == 0
|
||||||
|
t(0,0,0) = 0;
|
||||||
|
|
||||||
|
Tensor<bool,1> b;
|
||||||
|
|
||||||
|
#define TEST_TENSOR_EQUAL(e1, e2) \
|
||||||
|
b = ((e1) == (e2)).all(); \
|
||||||
|
VERIFY(b(0))
|
||||||
|
|
||||||
|
#define TEST_OP(op) TEST_TENSOR_EQUAL(t op 0, t op t.constant(0))
|
||||||
|
|
||||||
|
TEST_OP(==);
|
||||||
|
TEST_OP(!=);
|
||||||
|
TEST_OP(<=);
|
||||||
|
TEST_OP(>=);
|
||||||
|
TEST_OP(<);
|
||||||
|
TEST_OP(>);
|
||||||
|
#undef TEST_OP
|
||||||
|
#undef TEST_TENSOR_EQUAL
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_cxx11_tensor_sugar()
|
||||||
|
{
|
||||||
|
CALL_SUBTEST(test_comparison_sugar());
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user