mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-20 22:52:51 +08:00
Merged in rmlarsen/eigen2 (pull request PR-392)
Add vectorized clip functor for Eigen Tensors
This commit is contained in:
commit
ad355b3f05
@ -209,6 +209,12 @@ class TensorBase<Derived, ReadOnlyAccessors>
|
|||||||
return unaryExpr(internal::scalar_abs_op<Scalar>());
|
return unaryExpr(internal::scalar_abs_op<Scalar>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_clip_op<Scalar>, const Derived>
|
||||||
|
clip(Scalar min, Scalar max) const {
|
||||||
|
return unaryExpr(internal::scalar_clip_op<Scalar>(min, max));
|
||||||
|
}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_conjugate_op<Scalar>, const Derived>
|
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_conjugate_op<Scalar>, const Derived>
|
||||||
conjugate() const {
|
conjugate() const {
|
||||||
|
@ -487,6 +487,25 @@ struct functor_traits<GaussianGenerator<T, Index, NumDims> > {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename Scalar>
|
||||||
|
struct scalar_clip_op {
|
||||||
|
EIGEN_DEVICE_FUNC inline scalar_clip_op(const Scalar& _min, const Scalar& _max) : m_min(_min), m_max(_max) {}
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
|
||||||
|
operator()(const Scalar& x) const {
|
||||||
|
return numext::mini(numext::maxi(x, m_min), m_max);
|
||||||
|
}
|
||||||
|
template <typename Packet>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
|
||||||
|
packetOp(const Packet& x) const {
|
||||||
|
return internal::pmin(internal::pmax(x, pset1<Packet>(m_min)), pset1<Packet>(m_max));
|
||||||
|
}
|
||||||
|
const Scalar m_min;
|
||||||
|
const Scalar m_max;
|
||||||
|
};
|
||||||
|
template<typename Scalar>
|
||||||
|
struct functor_traits<scalar_clip_op<Scalar> >
|
||||||
|
{ enum { Cost = 2 * NumTraits<Scalar>::AddCost, PacketAccess = (packet_traits<Scalar>::HasMin && packet_traits<Scalar>::HasMax)}; };
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
|
||||||
|
@ -340,6 +340,26 @@ void test_minmax_nan_propagation_templ() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void test_clip()
|
||||||
|
{
|
||||||
|
Tensor<float, 1> vec(6);
|
||||||
|
vec(0) = 4.0;
|
||||||
|
vec(1) = 8.0;
|
||||||
|
vec(2) = 15.0;
|
||||||
|
vec(3) = 16.0;
|
||||||
|
vec(4) = 23.0;
|
||||||
|
vec(5) = 42.0;
|
||||||
|
|
||||||
|
float kMin = 20;
|
||||||
|
float kMax = 30;
|
||||||
|
|
||||||
|
Tensor<float, 1> vec_clipped(6);
|
||||||
|
vec_clipped = vec.clip(kMin, kMax);
|
||||||
|
for (int i = 0; i < 6; ++i) {
|
||||||
|
VERIFY_IS_EQUAL(vec_clipped(i), numext::mini(numext::maxi(vec(i), kMin), kMax));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static void test_minmax_nan_propagation()
|
static void test_minmax_nan_propagation()
|
||||||
{
|
{
|
||||||
test_minmax_nan_propagation_templ<float>();
|
test_minmax_nan_propagation_templ<float>();
|
||||||
@ -356,5 +376,6 @@ void test_cxx11_tensor_expr()
|
|||||||
CALL_SUBTEST(test_functors());
|
CALL_SUBTEST(test_functors());
|
||||||
CALL_SUBTEST(test_type_casting());
|
CALL_SUBTEST(test_type_casting());
|
||||||
CALL_SUBTEST(test_select());
|
CALL_SUBTEST(test_select());
|
||||||
|
CALL_SUBTEST(test_clip());
|
||||||
CALL_SUBTEST(test_minmax_nan_propagation());
|
CALL_SUBTEST(test_minmax_nan_propagation());
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user