mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-09-23 06:43:13 +08:00
Use select ternary op in tensor select evaluator
This commit is contained in:
parent
2b954be663
commit
e2bbf496f6
@ -817,14 +817,21 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>
|
|||||||
{
|
{
|
||||||
typedef TensorSelectOp<IfArgType, ThenArgType, ElseArgType> XprType;
|
typedef TensorSelectOp<IfArgType, ThenArgType, ElseArgType> XprType;
|
||||||
typedef typename XprType::Scalar Scalar;
|
typedef typename XprType::Scalar Scalar;
|
||||||
|
|
||||||
|
using TernarySelectOp = internal::scalar_boolean_select_op<typename internal::traits<ThenArgType>::Scalar,
|
||||||
|
typename internal::traits<ElseArgType>::Scalar,
|
||||||
|
typename internal::traits<IfArgType>::Scalar>;
|
||||||
|
static constexpr bool TernaryPacketAccess =
|
||||||
|
TensorEvaluator<ThenArgType, Device>::PacketAccess && TensorEvaluator<ElseArgType, Device>::PacketAccess &&
|
||||||
|
TensorEvaluator<IfArgType, Device>::PacketAccess && internal::functor_traits<TernarySelectOp>::PacketAccess;
|
||||||
|
|
||||||
static constexpr int Layout = TensorEvaluator<IfArgType, Device>::Layout;
|
static constexpr int Layout = TensorEvaluator<IfArgType, Device>::Layout;
|
||||||
enum {
|
enum {
|
||||||
IsAligned = TensorEvaluator<ThenArgType, Device>::IsAligned &
|
IsAligned = TensorEvaluator<ThenArgType, Device>::IsAligned &
|
||||||
TensorEvaluator<ElseArgType, Device>::IsAligned,
|
TensorEvaluator<ElseArgType, Device>::IsAligned,
|
||||||
PacketAccess = TensorEvaluator<ThenArgType, Device>::PacketAccess &
|
PacketAccess = (TensorEvaluator<ThenArgType, Device>::PacketAccess &&
|
||||||
TensorEvaluator<ElseArgType, Device>::PacketAccess &
|
TensorEvaluator<ElseArgType, Device>::PacketAccess &&
|
||||||
PacketType<Scalar, Device>::HasBlend,
|
PacketType<Scalar, Device>::HasBlend) || TernaryPacketAccess,
|
||||||
BlockAccess = TensorEvaluator<IfArgType, Device>::BlockAccess &&
|
BlockAccess = TensorEvaluator<IfArgType, Device>::BlockAccess &&
|
||||||
TensorEvaluator<ThenArgType, Device>::BlockAccess &&
|
TensorEvaluator<ThenArgType, Device>::BlockAccess &&
|
||||||
TensorEvaluator<ElseArgType, Device>::BlockAccess,
|
TensorEvaluator<ElseArgType, Device>::BlockAccess,
|
||||||
@ -922,7 +929,9 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>
|
|||||||
{
|
{
|
||||||
return m_condImpl.coeff(index) ? m_thenImpl.coeff(index) : m_elseImpl.coeff(index);
|
return m_condImpl.coeff(index) ? m_thenImpl.coeff(index) : m_elseImpl.coeff(index);
|
||||||
}
|
}
|
||||||
template<int LoadMode>
|
|
||||||
|
template<int LoadMode, bool UseTernary = TernaryPacketAccess,
|
||||||
|
std::enable_if_t<!UseTernary, bool> = true>
|
||||||
EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const
|
EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const
|
||||||
{
|
{
|
||||||
internal::Selector<PacketSize> select;
|
internal::Selector<PacketSize> select;
|
||||||
@ -936,6 +945,14 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <int LoadMode, bool UseTernary = TernaryPacketAccess,
|
||||||
|
std::enable_if_t<UseTernary, bool> = true>
|
||||||
|
EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
|
||||||
|
return TernarySelectOp().template packetOp<PacketReturnType>(m_thenImpl.template packet<LoadMode>(index),
|
||||||
|
m_elseImpl.template packet<LoadMode>(index),
|
||||||
|
m_condImpl.template packet<LoadMode>(index));
|
||||||
|
}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
|
||||||
costPerCoeff(bool vectorized) const {
|
costPerCoeff(bool vectorized) const {
|
||||||
return m_condImpl.costPerCoeff(vectorized) +
|
return m_condImpl.costPerCoeff(vectorized) +
|
||||||
|
@ -16,23 +16,41 @@ using Eigen::RowMajor;
|
|||||||
|
|
||||||
using Scalar = float;
|
using Scalar = float;
|
||||||
|
|
||||||
|
using TypedLTOp = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LT, true>;
|
||||||
|
using TypedLEOp = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LE, true>;
|
||||||
|
using TypedGTOp = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GT, true>;
|
||||||
|
using TypedGEOp = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GE, true>;
|
||||||
|
using TypedEQOp = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_EQ, true>;
|
||||||
|
using TypedNEOp = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_NEQ, true>;
|
||||||
|
|
||||||
static void test_orderings()
|
static void test_orderings()
|
||||||
{
|
{
|
||||||
Tensor<Scalar, 3> mat1(2,3,7);
|
Tensor<Scalar, 3> mat1(2,3,7);
|
||||||
Tensor<Scalar, 3> mat2(2,3,7);
|
Tensor<Scalar, 3> mat2(2,3,7);
|
||||||
|
|
||||||
|
mat1.setRandom();
|
||||||
|
mat2.setRandom();
|
||||||
|
|
||||||
Tensor<bool, 3> lt(2,3,7);
|
Tensor<bool, 3> lt(2,3,7);
|
||||||
Tensor<bool, 3> le(2,3,7);
|
Tensor<bool, 3> le(2,3,7);
|
||||||
Tensor<bool, 3> gt(2,3,7);
|
Tensor<bool, 3> gt(2,3,7);
|
||||||
Tensor<bool, 3> ge(2,3,7);
|
Tensor<bool, 3> ge(2,3,7);
|
||||||
|
|
||||||
mat1.setRandom();
|
Tensor<Scalar, 3> typed_lt(2, 3, 7);
|
||||||
mat2.setRandom();
|
Tensor<Scalar, 3> typed_le(2, 3, 7);
|
||||||
|
Tensor<Scalar, 3> typed_gt(2, 3, 7);
|
||||||
|
Tensor<Scalar, 3> typed_ge(2, 3, 7);
|
||||||
|
|
||||||
lt = mat1 < mat2;
|
lt = mat1 < mat2;
|
||||||
le = mat1 <= mat2;
|
le = mat1 <= mat2;
|
||||||
gt = mat1 > mat2;
|
gt = mat1 > mat2;
|
||||||
ge = mat1 >= mat2;
|
ge = mat1 >= mat2;
|
||||||
|
|
||||||
|
typed_lt = mat1.binaryExpr(mat2, TypedLTOp());
|
||||||
|
typed_le = mat1.binaryExpr(mat2, TypedLEOp());
|
||||||
|
typed_gt = mat1.binaryExpr(mat2, TypedGTOp());
|
||||||
|
typed_ge = mat1.binaryExpr(mat2, TypedGEOp());
|
||||||
|
|
||||||
for (int i = 0; i < 2; ++i) {
|
for (int i = 0; i < 2; ++i) {
|
||||||
for (int j = 0; j < 3; ++j) {
|
for (int j = 0; j < 3; ++j) {
|
||||||
for (int k = 0; k < 7; ++k) {
|
for (int k = 0; k < 7; ++k) {
|
||||||
@ -40,6 +58,11 @@ static void test_orderings()
|
|||||||
VERIFY_IS_EQUAL(le(i,j,k), mat1(i,j,k) <= mat2(i,j,k));
|
VERIFY_IS_EQUAL(le(i,j,k), mat1(i,j,k) <= mat2(i,j,k));
|
||||||
VERIFY_IS_EQUAL(gt(i,j,k), mat1(i,j,k) > mat2(i,j,k));
|
VERIFY_IS_EQUAL(gt(i,j,k), mat1(i,j,k) > mat2(i,j,k));
|
||||||
VERIFY_IS_EQUAL(ge(i,j,k), mat1(i,j,k) >= mat2(i,j,k));
|
VERIFY_IS_EQUAL(ge(i,j,k), mat1(i,j,k) >= mat2(i,j,k));
|
||||||
|
|
||||||
|
VERIFY_IS_EQUAL(lt(i, j, k), (bool)typed_lt(i, j, k));
|
||||||
|
VERIFY_IS_EQUAL(le(i, j, k), (bool)typed_le(i, j, k));
|
||||||
|
VERIFY_IS_EQUAL(gt(i, j, k), (bool)typed_gt(i, j, k));
|
||||||
|
VERIFY_IS_EQUAL(ge(i, j, k), (bool)typed_ge(i, j, k));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -65,14 +88,24 @@ static void test_equality()
|
|||||||
|
|
||||||
Tensor<bool, 3> eq(2,3,7);
|
Tensor<bool, 3> eq(2,3,7);
|
||||||
Tensor<bool, 3> ne(2,3,7);
|
Tensor<bool, 3> ne(2,3,7);
|
||||||
|
|
||||||
|
Tensor<Scalar, 3> typed_eq(2, 3, 7);
|
||||||
|
Tensor<Scalar, 3> typed_ne(2, 3, 7);
|
||||||
|
|
||||||
eq = (mat1 == mat2);
|
eq = (mat1 == mat2);
|
||||||
ne = (mat1 != mat2);
|
ne = (mat1 != mat2);
|
||||||
|
|
||||||
|
typed_eq = mat1.binaryExpr(mat2, TypedEQOp());
|
||||||
|
typed_ne = mat1.binaryExpr(mat2, TypedNEOp());
|
||||||
|
|
||||||
for (int i = 0; i < 2; ++i) {
|
for (int i = 0; i < 2; ++i) {
|
||||||
for (int j = 0; j < 3; ++j) {
|
for (int j = 0; j < 3; ++j) {
|
||||||
for (int k = 0; k < 7; ++k) {
|
for (int k = 0; k < 7; ++k) {
|
||||||
VERIFY_IS_EQUAL(eq(i,j,k), mat1(i,j,k) == mat2(i,j,k));
|
VERIFY_IS_EQUAL(eq(i,j,k), mat1(i,j,k) == mat2(i,j,k));
|
||||||
VERIFY_IS_EQUAL(ne(i,j,k), mat1(i,j,k) != mat2(i,j,k));
|
VERIFY_IS_EQUAL(ne(i,j,k), mat1(i,j,k) != mat2(i,j,k));
|
||||||
|
|
||||||
|
VERIFY_IS_EQUAL(eq(i, j, k), (bool)typed_eq(i,j,k));
|
||||||
|
VERIFY_IS_EQUAL(ne(i, j, k), (bool)typed_ne(i,j,k));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -280,6 +280,8 @@ static void test_type_casting()
|
|||||||
|
|
||||||
static void test_select()
|
static void test_select()
|
||||||
{
|
{
|
||||||
|
using TypedGTOp = internal::scalar_cmp_op<float, float, internal::cmp_GT, true>;
|
||||||
|
|
||||||
Tensor<float, 3> selector(2,3,7);
|
Tensor<float, 3> selector(2,3,7);
|
||||||
Tensor<float, 3> mat1(2,3,7);
|
Tensor<float, 3> mat1(2,3,7);
|
||||||
Tensor<float, 3> mat2(2,3,7);
|
Tensor<float, 3> mat2(2,3,7);
|
||||||
@ -288,6 +290,8 @@ static void test_select()
|
|||||||
selector.setRandom();
|
selector.setRandom();
|
||||||
mat1.setRandom();
|
mat1.setRandom();
|
||||||
mat2.setRandom();
|
mat2.setRandom();
|
||||||
|
|
||||||
|
// test select with a boolean condition
|
||||||
result = (selector > selector.constant(0.5f)).select(mat1, mat2);
|
result = (selector > selector.constant(0.5f)).select(mat1, mat2);
|
||||||
|
|
||||||
for (int i = 0; i < 2; ++i) {
|
for (int i = 0; i < 2; ++i) {
|
||||||
@ -297,6 +301,18 @@ static void test_select()
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// test select with a typed condition
|
||||||
|
result = selector.binaryExpr(selector.constant(0.5f), TypedGTOp()).select(mat1, mat2);
|
||||||
|
|
||||||
|
for (int i = 0; i < 2; ++i) {
|
||||||
|
for (int j = 0; j < 3; ++j) {
|
||||||
|
for (int k = 0; k < 7; ++k) {
|
||||||
|
VERIFY_IS_APPROX(result(i, j, k), (selector(i, j, k) > 0.5f) ? mat1(i, j, k) : mat2(i, j, k));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Scalar>
|
template <typename Scalar>
|
||||||
|
Loading…
x
Reference in New Issue
Block a user