mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-09-23 06:43:13 +08:00
Use select ternary op in tensor select evaulator
This commit is contained in:
parent
2b954be663
commit
e2bbf496f6
@ -818,13 +818,20 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>
|
||||
typedef TensorSelectOp<IfArgType, ThenArgType, ElseArgType> XprType;
|
||||
typedef typename XprType::Scalar Scalar;
|
||||
|
||||
using TernarySelectOp = internal::scalar_boolean_select_op<typename internal::traits<ThenArgType>::Scalar,
|
||||
typename internal::traits<ElseArgType>::Scalar,
|
||||
typename internal::traits<IfArgType>::Scalar>;
|
||||
static constexpr bool TernaryPacketAccess =
|
||||
TensorEvaluator<ThenArgType, Device>::PacketAccess && TensorEvaluator<ElseArgType, Device>::PacketAccess &&
|
||||
TensorEvaluator<IfArgType, Device>::PacketAccess && internal::functor_traits<TernarySelectOp>::PacketAccess;
|
||||
|
||||
static constexpr int Layout = TensorEvaluator<IfArgType, Device>::Layout;
|
||||
enum {
|
||||
IsAligned = TensorEvaluator<ThenArgType, Device>::IsAligned &
|
||||
TensorEvaluator<ElseArgType, Device>::IsAligned,
|
||||
PacketAccess = TensorEvaluator<ThenArgType, Device>::PacketAccess &
|
||||
TensorEvaluator<ElseArgType, Device>::PacketAccess &
|
||||
PacketType<Scalar, Device>::HasBlend,
|
||||
PacketAccess = (TensorEvaluator<ThenArgType, Device>::PacketAccess &&
|
||||
TensorEvaluator<ElseArgType, Device>::PacketAccess &&
|
||||
PacketType<Scalar, Device>::HasBlend) || TernaryPacketAccess,
|
||||
BlockAccess = TensorEvaluator<IfArgType, Device>::BlockAccess &&
|
||||
TensorEvaluator<ThenArgType, Device>::BlockAccess &&
|
||||
TensorEvaluator<ElseArgType, Device>::BlockAccess,
|
||||
@ -922,7 +929,9 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>
|
||||
{
|
||||
return m_condImpl.coeff(index) ? m_thenImpl.coeff(index) : m_elseImpl.coeff(index);
|
||||
}
|
||||
template<int LoadMode>
|
||||
|
||||
template<int LoadMode, bool UseTernary = TernaryPacketAccess,
|
||||
std::enable_if_t<!UseTernary, bool> = true>
|
||||
EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const
|
||||
{
|
||||
internal::Selector<PacketSize> select;
|
||||
@ -936,6 +945,14 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>
|
||||
|
||||
}
|
||||
|
||||
template <int LoadMode, bool UseTernary = TernaryPacketAccess,
|
||||
std::enable_if_t<UseTernary, bool> = true>
|
||||
EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
|
||||
return TernarySelectOp().template packetOp<PacketReturnType>(m_thenImpl.template packet<LoadMode>(index),
|
||||
m_elseImpl.template packet<LoadMode>(index),
|
||||
m_condImpl.template packet<LoadMode>(index));
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
|
||||
costPerCoeff(bool vectorized) const {
|
||||
return m_condImpl.costPerCoeff(vectorized) +
|
||||
|
@ -16,23 +16,41 @@ using Eigen::RowMajor;
|
||||
|
||||
using Scalar = float;
|
||||
|
||||
using TypedLTOp = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LT, true>;
|
||||
using TypedLEOp = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LE, true>;
|
||||
using TypedGTOp = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GT, true>;
|
||||
using TypedGEOp = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GE, true>;
|
||||
using TypedEQOp = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_EQ, true>;
|
||||
using TypedNEOp = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_NEQ, true>;
|
||||
|
||||
static void test_orderings()
|
||||
{
|
||||
Tensor<Scalar, 3> mat1(2,3,7);
|
||||
Tensor<Scalar, 3> mat2(2,3,7);
|
||||
|
||||
mat1.setRandom();
|
||||
mat2.setRandom();
|
||||
|
||||
Tensor<bool, 3> lt(2,3,7);
|
||||
Tensor<bool, 3> le(2,3,7);
|
||||
Tensor<bool, 3> gt(2,3,7);
|
||||
Tensor<bool, 3> ge(2,3,7);
|
||||
|
||||
mat1.setRandom();
|
||||
mat2.setRandom();
|
||||
Tensor<Scalar, 3> typed_lt(2, 3, 7);
|
||||
Tensor<Scalar, 3> typed_le(2, 3, 7);
|
||||
Tensor<Scalar, 3> typed_gt(2, 3, 7);
|
||||
Tensor<Scalar, 3> typed_ge(2, 3, 7);
|
||||
|
||||
lt = mat1 < mat2;
|
||||
le = mat1 <= mat2;
|
||||
gt = mat1 > mat2;
|
||||
ge = mat1 >= mat2;
|
||||
|
||||
typed_lt = mat1.binaryExpr(mat2, TypedLTOp());
|
||||
typed_le = mat1.binaryExpr(mat2, TypedLEOp());
|
||||
typed_gt = mat1.binaryExpr(mat2, TypedGTOp());
|
||||
typed_ge = mat1.binaryExpr(mat2, TypedGEOp());
|
||||
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
for (int j = 0; j < 3; ++j) {
|
||||
for (int k = 0; k < 7; ++k) {
|
||||
@ -40,6 +58,11 @@ static void test_orderings()
|
||||
VERIFY_IS_EQUAL(le(i,j,k), mat1(i,j,k) <= mat2(i,j,k));
|
||||
VERIFY_IS_EQUAL(gt(i,j,k), mat1(i,j,k) > mat2(i,j,k));
|
||||
VERIFY_IS_EQUAL(ge(i,j,k), mat1(i,j,k) >= mat2(i,j,k));
|
||||
|
||||
VERIFY_IS_EQUAL(lt(i, j, k), (bool)typed_lt(i, j, k));
|
||||
VERIFY_IS_EQUAL(le(i, j, k), (bool)typed_le(i, j, k));
|
||||
VERIFY_IS_EQUAL(gt(i, j, k), (bool)typed_gt(i, j, k));
|
||||
VERIFY_IS_EQUAL(ge(i, j, k), (bool)typed_ge(i, j, k));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -65,14 +88,24 @@ static void test_equality()
|
||||
|
||||
Tensor<bool, 3> eq(2,3,7);
|
||||
Tensor<bool, 3> ne(2,3,7);
|
||||
|
||||
Tensor<Scalar, 3> typed_eq(2, 3, 7);
|
||||
Tensor<Scalar, 3> typed_ne(2, 3, 7);
|
||||
|
||||
eq = (mat1 == mat2);
|
||||
ne = (mat1 != mat2);
|
||||
|
||||
typed_eq = mat1.binaryExpr(mat2, TypedEQOp());
|
||||
typed_ne = mat1.binaryExpr(mat2, TypedNEOp());
|
||||
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
for (int j = 0; j < 3; ++j) {
|
||||
for (int k = 0; k < 7; ++k) {
|
||||
VERIFY_IS_EQUAL(eq(i,j,k), mat1(i,j,k) == mat2(i,j,k));
|
||||
VERIFY_IS_EQUAL(ne(i,j,k), mat1(i,j,k) != mat2(i,j,k));
|
||||
|
||||
VERIFY_IS_EQUAL(eq(i, j, k), (bool)typed_eq(i,j,k));
|
||||
VERIFY_IS_EQUAL(ne(i, j, k), (bool)typed_ne(i,j,k));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -280,6 +280,8 @@ static void test_type_casting()
|
||||
|
||||
static void test_select()
|
||||
{
|
||||
using TypedGTOp = internal::scalar_cmp_op<float, float, internal::cmp_GT, true>;
|
||||
|
||||
Tensor<float, 3> selector(2,3,7);
|
||||
Tensor<float, 3> mat1(2,3,7);
|
||||
Tensor<float, 3> mat2(2,3,7);
|
||||
@ -288,6 +290,8 @@ static void test_select()
|
||||
selector.setRandom();
|
||||
mat1.setRandom();
|
||||
mat2.setRandom();
|
||||
|
||||
// test select with a boolean condition
|
||||
result = (selector > selector.constant(0.5f)).select(mat1, mat2);
|
||||
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
@ -297,6 +301,18 @@ static void test_select()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// test select with a typed condition
|
||||
result = selector.binaryExpr(selector.constant(0.5f), TypedGTOp()).select(mat1, mat2);
|
||||
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
for (int j = 0; j < 3; ++j) {
|
||||
for (int k = 0; k < 7; ++k) {
|
||||
VERIFY_IS_APPROX(result(i, j, k), (selector(i, j, k) > 0.5f) ? mat1(i, j, k) : mat2(i, j, k));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
|
Loading…
x
Reference in New Issue
Block a user