From e2bbf496f69352b05d4d2cf55d52c38b415cadf6 Mon Sep 17 00:00:00 2001 From: Charles Schlosser Date: Tue, 18 Apr 2023 20:52:16 +0000 Subject: [PATCH] Use select ternary op in tensor select evaluator --- .../Eigen/CXX11/src/Tensor/TensorEvaluator.h | 25 +++++++++++-- unsupported/test/cxx11_tensor_comparisons.cpp | 37 ++++++++++++++++++- unsupported/test/cxx11_tensor_expr.cpp | 16 ++++++++ 3 files changed, 72 insertions(+), 6 deletions(-) diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h index f8e3f2981..b16e5a661 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h @@ -817,14 +817,21 @@ struct TensorEvaluator { typedef TensorSelectOp XprType; typedef typename XprType::Scalar Scalar; + + using TernarySelectOp = internal::scalar_boolean_select_op::Scalar, + typename internal::traits::Scalar, + typename internal::traits::Scalar>; + static constexpr bool TernaryPacketAccess = + TensorEvaluator::PacketAccess && TensorEvaluator::PacketAccess && + TensorEvaluator::PacketAccess && internal::functor_traits::PacketAccess; static constexpr int Layout = TensorEvaluator::Layout; enum { IsAligned = TensorEvaluator::IsAligned & TensorEvaluator::IsAligned, - PacketAccess = TensorEvaluator::PacketAccess & - TensorEvaluator::PacketAccess & - PacketType::HasBlend, + PacketAccess = (TensorEvaluator::PacketAccess && + TensorEvaluator::PacketAccess && + PacketType::HasBlend) || TernaryPacketAccess, BlockAccess = TensorEvaluator::BlockAccess && TensorEvaluator::BlockAccess && TensorEvaluator::BlockAccess, @@ -922,7 +929,9 @@ struct TensorEvaluator { return m_condImpl.coeff(index) ? 
m_thenImpl.coeff(index) : m_elseImpl.coeff(index); } - template + + template = true> EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const { internal::Selector select; @@ -936,6 +945,14 @@ struct TensorEvaluator } + template = true> + EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const { + return TernarySelectOp().template packetOp(m_thenImpl.template packet(index), + m_elseImpl.template packet(index), + m_condImpl.template packet(index)); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const { return m_condImpl.costPerCoeff(vectorized) + diff --git a/unsupported/test/cxx11_tensor_comparisons.cpp b/unsupported/test/cxx11_tensor_comparisons.cpp index 86c73355b..e0bd90de9 100644 --- a/unsupported/test/cxx11_tensor_comparisons.cpp +++ b/unsupported/test/cxx11_tensor_comparisons.cpp @@ -16,23 +16,41 @@ using Eigen::RowMajor; using Scalar = float; +using TypedLTOp = internal::scalar_cmp_op; +using TypedLEOp = internal::scalar_cmp_op; +using TypedGTOp = internal::scalar_cmp_op; +using TypedGEOp = internal::scalar_cmp_op; +using TypedEQOp = internal::scalar_cmp_op; +using TypedNEOp = internal::scalar_cmp_op; + static void test_orderings() { Tensor mat1(2,3,7); Tensor mat2(2,3,7); + + mat1.setRandom(); + mat2.setRandom(); + Tensor lt(2,3,7); Tensor le(2,3,7); Tensor gt(2,3,7); Tensor ge(2,3,7); - mat1.setRandom(); - mat2.setRandom(); + Tensor typed_lt(2, 3, 7); + Tensor typed_le(2, 3, 7); + Tensor typed_gt(2, 3, 7); + Tensor typed_ge(2, 3, 7); lt = mat1 < mat2; le = mat1 <= mat2; gt = mat1 > mat2; ge = mat1 >= mat2; + typed_lt = mat1.binaryExpr(mat2, TypedLTOp()); + typed_le = mat1.binaryExpr(mat2, TypedLEOp()); + typed_gt = mat1.binaryExpr(mat2, TypedGTOp()); + typed_ge = mat1.binaryExpr(mat2, TypedGEOp()); + for (int i = 0; i < 2; ++i) { for (int j = 0; j < 3; ++j) { for (int k = 0; k < 7; ++k) { @@ -40,6 +58,11 @@ static void test_orderings() VERIFY_IS_EQUAL(le(i,j,k), mat1(i,j,k) <= mat2(i,j,k)); 
VERIFY_IS_EQUAL(gt(i,j,k), mat1(i,j,k) > mat2(i,j,k)); VERIFY_IS_EQUAL(ge(i,j,k), mat1(i,j,k) >= mat2(i,j,k)); + + VERIFY_IS_EQUAL(lt(i, j, k), (bool)typed_lt(i, j, k)); + VERIFY_IS_EQUAL(le(i, j, k), (bool)typed_le(i, j, k)); + VERIFY_IS_EQUAL(gt(i, j, k), (bool)typed_gt(i, j, k)); + VERIFY_IS_EQUAL(ge(i, j, k), (bool)typed_ge(i, j, k)); } } } @@ -65,14 +88,24 @@ static void test_equality() Tensor eq(2,3,7); Tensor ne(2,3,7); + + Tensor typed_eq(2, 3, 7); + Tensor typed_ne(2, 3, 7); + eq = (mat1 == mat2); ne = (mat1 != mat2); + typed_eq = mat1.binaryExpr(mat2, TypedEQOp()); + typed_ne = mat1.binaryExpr(mat2, TypedNEOp()); + for (int i = 0; i < 2; ++i) { for (int j = 0; j < 3; ++j) { for (int k = 0; k < 7; ++k) { VERIFY_IS_EQUAL(eq(i,j,k), mat1(i,j,k) == mat2(i,j,k)); VERIFY_IS_EQUAL(ne(i,j,k), mat1(i,j,k) != mat2(i,j,k)); + + VERIFY_IS_EQUAL(eq(i, j, k), (bool)typed_eq(i,j,k)); + VERIFY_IS_EQUAL(ne(i, j, k), (bool)typed_ne(i,j,k)); } } } diff --git a/unsupported/test/cxx11_tensor_expr.cpp b/unsupported/test/cxx11_tensor_expr.cpp index f99c80a22..c76fbc5e9 100644 --- a/unsupported/test/cxx11_tensor_expr.cpp +++ b/unsupported/test/cxx11_tensor_expr.cpp @@ -280,6 +280,8 @@ static void test_type_casting() static void test_select() { + using TypedGTOp = internal::scalar_cmp_op; + Tensor selector(2,3,7); Tensor mat1(2,3,7); Tensor mat2(2,3,7); @@ -288,6 +290,8 @@ static void test_select() selector.setRandom(); mat1.setRandom(); mat2.setRandom(); + + // test select with a boolean condition result = (selector > selector.constant(0.5f)).select(mat1, mat2); for (int i = 0; i < 2; ++i) { @@ -297,6 +301,18 @@ static void test_select() } } } + + // test select with a typed condition + result = selector.binaryExpr(selector.constant(0.5f), TypedGTOp()).select(mat1, mat2); + + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 7; ++k) { + VERIFY_IS_APPROX(result(i, j, k), (selector(i, j, k) > 0.5f) ? 
mat1(i, j, k) : mat2(i, j, k)); + } + } + } + } template