Use select ternary op in tensor select evaluator

This commit is contained in:
Charles Schlosser 2023-04-18 20:52:16 +00:00 committed by Rasmus Munk Larsen
parent 2b954be663
commit e2bbf496f6
3 changed files with 72 additions and 6 deletions

View File

@ -817,14 +817,21 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>
{ {
typedef TensorSelectOp<IfArgType, ThenArgType, ElseArgType> XprType; typedef TensorSelectOp<IfArgType, ThenArgType, ElseArgType> XprType;
typedef typename XprType::Scalar Scalar; typedef typename XprType::Scalar Scalar;
// Functor used for the packet-level select. It is instantiated with the scalar
// types of the "then", "else" and "if" argument expressions, in that order.
using TernarySelectOp = internal::scalar_boolean_select_op<typename internal::traits<ThenArgType>::Scalar,
typename internal::traits<ElseArgType>::Scalar,
typename internal::traits<IfArgType>::Scalar>;
// True only when all three sub-evaluators (then, else, if) report PacketAccess
// and functor_traits<TernarySelectOp> reports PacketAccess as well. This flag is
// the default for the UseTernary template parameter of the packet() overloads.
static constexpr bool TernaryPacketAccess =
TensorEvaluator<ThenArgType, Device>::PacketAccess && TensorEvaluator<ElseArgType, Device>::PacketAccess &&
TensorEvaluator<IfArgType, Device>::PacketAccess && internal::functor_traits<TernarySelectOp>::PacketAccess;
static constexpr int Layout = TensorEvaluator<IfArgType, Device>::Layout; static constexpr int Layout = TensorEvaluator<IfArgType, Device>::Layout;
enum { enum {
IsAligned = TensorEvaluator<ThenArgType, Device>::IsAligned & IsAligned = TensorEvaluator<ThenArgType, Device>::IsAligned &
TensorEvaluator<ElseArgType, Device>::IsAligned, TensorEvaluator<ElseArgType, Device>::IsAligned,
PacketAccess = TensorEvaluator<ThenArgType, Device>::PacketAccess & PacketAccess = (TensorEvaluator<ThenArgType, Device>::PacketAccess &&
TensorEvaluator<ElseArgType, Device>::PacketAccess & TensorEvaluator<ElseArgType, Device>::PacketAccess &&
PacketType<Scalar, Device>::HasBlend, PacketType<Scalar, Device>::HasBlend) || TernaryPacketAccess,
BlockAccess = TensorEvaluator<IfArgType, Device>::BlockAccess && BlockAccess = TensorEvaluator<IfArgType, Device>::BlockAccess &&
TensorEvaluator<ThenArgType, Device>::BlockAccess && TensorEvaluator<ThenArgType, Device>::BlockAccess &&
TensorEvaluator<ElseArgType, Device>::BlockAccess, TensorEvaluator<ElseArgType, Device>::BlockAccess,
@ -922,7 +929,9 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>
{ {
return m_condImpl.coeff(index) ? m_thenImpl.coeff(index) : m_elseImpl.coeff(index); return m_condImpl.coeff(index) ? m_thenImpl.coeff(index) : m_elseImpl.coeff(index);
} }
template<int LoadMode>
template<int LoadMode, bool UseTernary = TernaryPacketAccess,
std::enable_if_t<!UseTernary, bool> = true>
EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const
{ {
internal::Selector<PacketSize> select; internal::Selector<PacketSize> select;
@ -936,6 +945,14 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>
} }
// packet() overload that participates in overload resolution only when
// UseTernary (defaulting to TernaryPacketAccess) is true.
// Loads one packet at `index` from each of the then/else/condition evaluators
// using the requested LoadMode and hands them to TernarySelectOp::packetOp in
// (then, else, condition) order; the functor's result is returned unchanged.
template <int LoadMode, bool UseTernary = TernaryPacketAccess,
std::enable_if_t<UseTernary, bool> = true>
EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
return TernarySelectOp().template packetOp<PacketReturnType>(m_thenImpl.template packet<LoadMode>(index),
m_elseImpl.template packet<LoadMode>(index),
m_condImpl.template packet<LoadMode>(index));
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
costPerCoeff(bool vectorized) const { costPerCoeff(bool vectorized) const {
return m_condImpl.costPerCoeff(vectorized) + return m_condImpl.costPerCoeff(vectorized) +

View File

@ -16,23 +16,41 @@ using Eigen::RowMajor;
using Scalar = float; using Scalar = float;
// Comparison functors over (Scalar, Scalar) with the fourth template argument
// set to `true`, one alias per comparison kind (LT, LE, GT, GE, EQ, NEQ).
// NOTE(review): the `true` flag presumably selects the "typed" variant whose
// result is a Scalar rather than a bool -- confirm against the definition of
// internal::scalar_cmp_op. The tests in this file assign the results to
// Tensor<Scalar, 3> and cast each coefficient to bool before comparing.
using TypedLTOp = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LT, true>;
using TypedLEOp = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LE, true>;
using TypedGTOp = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GT, true>;
using TypedGEOp = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GE, true>;
using TypedEQOp = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_EQ, true>;
using TypedNEOp = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_NEQ, true>;
static void test_orderings() static void test_orderings()
{ {
Tensor<Scalar, 3> mat1(2,3,7); Tensor<Scalar, 3> mat1(2,3,7);
Tensor<Scalar, 3> mat2(2,3,7); Tensor<Scalar, 3> mat2(2,3,7);
mat1.setRandom();
mat2.setRandom();
Tensor<bool, 3> lt(2,3,7); Tensor<bool, 3> lt(2,3,7);
Tensor<bool, 3> le(2,3,7); Tensor<bool, 3> le(2,3,7);
Tensor<bool, 3> gt(2,3,7); Tensor<bool, 3> gt(2,3,7);
Tensor<bool, 3> ge(2,3,7); Tensor<bool, 3> ge(2,3,7);
mat1.setRandom(); Tensor<Scalar, 3> typed_lt(2, 3, 7);
mat2.setRandom(); Tensor<Scalar, 3> typed_le(2, 3, 7);
Tensor<Scalar, 3> typed_gt(2, 3, 7);
Tensor<Scalar, 3> typed_ge(2, 3, 7);
lt = mat1 < mat2; lt = mat1 < mat2;
le = mat1 <= mat2; le = mat1 <= mat2;
gt = mat1 > mat2; gt = mat1 > mat2;
ge = mat1 >= mat2; ge = mat1 >= mat2;
typed_lt = mat1.binaryExpr(mat2, TypedLTOp());
typed_le = mat1.binaryExpr(mat2, TypedLEOp());
typed_gt = mat1.binaryExpr(mat2, TypedGTOp());
typed_ge = mat1.binaryExpr(mat2, TypedGEOp());
for (int i = 0; i < 2; ++i) { for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 3; ++j) { for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 7; ++k) { for (int k = 0; k < 7; ++k) {
@ -40,6 +58,11 @@ static void test_orderings()
VERIFY_IS_EQUAL(le(i,j,k), mat1(i,j,k) <= mat2(i,j,k)); VERIFY_IS_EQUAL(le(i,j,k), mat1(i,j,k) <= mat2(i,j,k));
VERIFY_IS_EQUAL(gt(i,j,k), mat1(i,j,k) > mat2(i,j,k)); VERIFY_IS_EQUAL(gt(i,j,k), mat1(i,j,k) > mat2(i,j,k));
VERIFY_IS_EQUAL(ge(i,j,k), mat1(i,j,k) >= mat2(i,j,k)); VERIFY_IS_EQUAL(ge(i,j,k), mat1(i,j,k) >= mat2(i,j,k));
VERIFY_IS_EQUAL(lt(i, j, k), (bool)typed_lt(i, j, k));
VERIFY_IS_EQUAL(le(i, j, k), (bool)typed_le(i, j, k));
VERIFY_IS_EQUAL(gt(i, j, k), (bool)typed_gt(i, j, k));
VERIFY_IS_EQUAL(ge(i, j, k), (bool)typed_ge(i, j, k));
} }
} }
} }
@ -65,14 +88,24 @@ static void test_equality()
Tensor<bool, 3> eq(2,3,7); Tensor<bool, 3> eq(2,3,7);
Tensor<bool, 3> ne(2,3,7); Tensor<bool, 3> ne(2,3,7);
Tensor<Scalar, 3> typed_eq(2, 3, 7);
Tensor<Scalar, 3> typed_ne(2, 3, 7);
eq = (mat1 == mat2); eq = (mat1 == mat2);
ne = (mat1 != mat2); ne = (mat1 != mat2);
typed_eq = mat1.binaryExpr(mat2, TypedEQOp());
typed_ne = mat1.binaryExpr(mat2, TypedNEOp());
for (int i = 0; i < 2; ++i) { for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 3; ++j) { for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 7; ++k) { for (int k = 0; k < 7; ++k) {
VERIFY_IS_EQUAL(eq(i,j,k), mat1(i,j,k) == mat2(i,j,k)); VERIFY_IS_EQUAL(eq(i,j,k), mat1(i,j,k) == mat2(i,j,k));
VERIFY_IS_EQUAL(ne(i,j,k), mat1(i,j,k) != mat2(i,j,k)); VERIFY_IS_EQUAL(ne(i,j,k), mat1(i,j,k) != mat2(i,j,k));
VERIFY_IS_EQUAL(eq(i, j, k), (bool)typed_eq(i,j,k));
VERIFY_IS_EQUAL(ne(i, j, k), (bool)typed_ne(i,j,k));
} }
} }
} }

View File

@ -280,6 +280,8 @@ static void test_type_casting()
static void test_select() static void test_select()
{ {
using TypedGTOp = internal::scalar_cmp_op<float, float, internal::cmp_GT, true>;
Tensor<float, 3> selector(2,3,7); Tensor<float, 3> selector(2,3,7);
Tensor<float, 3> mat1(2,3,7); Tensor<float, 3> mat1(2,3,7);
Tensor<float, 3> mat2(2,3,7); Tensor<float, 3> mat2(2,3,7);
@ -288,6 +290,8 @@ static void test_select()
selector.setRandom(); selector.setRandom();
mat1.setRandom(); mat1.setRandom();
mat2.setRandom(); mat2.setRandom();
// test select with a boolean condition
result = (selector > selector.constant(0.5f)).select(mat1, mat2); result = (selector > selector.constant(0.5f)).select(mat1, mat2);
for (int i = 0; i < 2; ++i) { for (int i = 0; i < 2; ++i) {
@ -297,6 +301,18 @@ static void test_select()
} }
} }
} }
// test select with a typed condition
result = selector.binaryExpr(selector.constant(0.5f), TypedGTOp()).select(mat1, mat2);
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 7; ++k) {
VERIFY_IS_APPROX(result(i, j, k), (selector(i, j, k) > 0.5f) ? mat1(i, j, k) : mat2(i, j, k));
}
}
}
} }
template <typename Scalar> template <typename Scalar>