Added some syntactic sugar to make it simpler to compare a tensor to a scalar.

This commit is contained in:
Benoit Steiner 2015-10-21 11:28:28 -07:00
parent 5ca2e25967
commit b178cc3479
3 changed files with 71 additions and 0 deletions

View File

@ -279,6 +279,38 @@ class TensorBase<Derived, ReadOnlyAccessors>
return binaryExpr(other.derived(), std::not_equal_to<Scalar>());
}
// comparisons and tests for Scalars
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<std::less<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
operator<(Scalar threshold) const {
return operator<(constant(threshold));
}
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<std::less_equal<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
operator<=(Scalar threshold) const {
return operator<=(constant(threshold));
}
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<std::greater<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
operator>(Scalar threshold) const {
return operator>(constant(threshold));
}
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<std::greater_equal<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
operator>=(Scalar threshold) const {
return operator>=(constant(threshold));
}
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<std::equal_to<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
operator==(Scalar threshold) const {
return operator==(constant(threshold));
}
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<std::not_equal_to<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
operator!=(Scalar threshold) const {
return operator!=(constant(threshold));
}
// Coefficient-wise ternary operators.
template<typename ThenDerived, typename ElseDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorSelectOp<const Derived, const ThenDerived, const ElseDerived>

View File

@ -143,6 +143,7 @@ if(EIGEN_TEST_CXX11)
ei_add_test(cxx11_tensor_generator "-std=c++0x")
ei_add_test(cxx11_tensor_custom_op "-std=c++0x")
ei_add_test(cxx11_tensor_custom_index "-std=c++0x")
ei_add_test(cxx11_tensor_sugar "-std=c++0x")
# These tests needs nvcc
# ei_add_test(cxx11_tensor_device "-std=c++0x")

View File

@ -0,0 +1,38 @@
#include "main.h"
#include <Eigen/CXX11/Tensor>
using Eigen::Tensor;
using Eigen::RowMajor;
static void test_comparison_sugar() {
// we already trust comparisons between tensors, we're simply checking that
// the sugared versions are doing the same thing
Tensor<int, 3> t(6, 7, 5);
t.setRandom();
// make sure we have at least one value == 0
t(0,0,0) = 0;
Tensor<bool,1> b;
#define TEST_TENSOR_EQUAL(e1, e2) \
b = ((e1) == (e2)).all(); \
VERIFY(b(0))
#define TEST_OP(op) TEST_TENSOR_EQUAL(t op 0, t op t.constant(0))
TEST_OP(==);
TEST_OP(!=);
TEST_OP(<=);
TEST_OP(>=);
TEST_OP(<);
TEST_OP(>);
#undef TEST_OP
#undef TEST_TENSOR_EQUAL
}
void test_cxx11_tensor_sugar()
{
CALL_SUBTEST(test_comparison_sugar());
}