mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 18:54:00 +08:00
Vectorize isfinite and isinf.
This commit is contained in:
parent
5a9f66fb35
commit
9148c47d67
@ -579,12 +579,6 @@ EIGEN_DEVICE_FUNC inline Packet pandnot(const Packet& a, const Packet& b) {
|
||||
return pand(a, pnot(b));
|
||||
}
|
||||
|
||||
/** \internal \returns isnan(a) */
|
||||
template <typename Packet>
|
||||
EIGEN_DEVICE_FUNC inline Packet pisnan(const Packet& a) {
|
||||
return pandnot(ptrue(a), pcmp_eq(a, a));
|
||||
}
|
||||
|
||||
// In the general case, use bitwise select.
|
||||
template <typename Packet, typename EnableIf = void>
|
||||
struct pselect_impl {
|
||||
@ -1002,6 +996,20 @@ EIGEN_DEVICE_FUNC inline Packet pcplxflip(const Packet& a) {
|
||||
* Special math functions
|
||||
***************************/
|
||||
|
||||
/** \internal \returns isnan(a) */
|
||||
template <typename Packet>
|
||||
EIGEN_DEVICE_FUNC inline Packet pisnan(const Packet& a) {
|
||||
return pandnot(ptrue(a), pcmp_eq(a, a));
|
||||
}
|
||||
|
||||
/** \internal \returns isinf(a) */
|
||||
template <typename Packet>
|
||||
EIGEN_DEVICE_FUNC inline Packet pisinf(const Packet& a) {
|
||||
using Scalar = typename unpacket_traits<Packet>::type;
|
||||
constexpr Scalar inf = NumTraits<Scalar>::infinity();
|
||||
return pcmp_eq(pabs(a), pset1<Packet>(inf));
|
||||
}
|
||||
|
||||
/** \internal \returns the sine of \a a (coeff-wise) */
|
||||
template <typename Packet>
|
||||
EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet psin(const Packet& a) {
|
||||
|
@ -989,10 +989,9 @@ struct functor_traits<scalar_isnan_op<Scalar, UseTypedPredicate>> {
|
||||
* \brief Template functor to check whether a scalar is +/-inf
|
||||
* \sa class CwiseUnaryOp, ArrayBase::isinf()
|
||||
*/
|
||||
template <typename Scalar>
|
||||
template <typename Scalar, bool UseTypedPredicate = false>
|
||||
struct scalar_isinf_op {
|
||||
typedef bool result_type;
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const Scalar& a) const {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const Scalar& a) const {
|
||||
#if defined(SYCL_DEVICE_ONLY)
|
||||
return numext::isinf(a);
|
||||
#else
|
||||
@ -1000,19 +999,33 @@ struct scalar_isinf_op {
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Scalar>
|
||||
struct functor_traits<scalar_isinf_op<Scalar>> {
|
||||
enum { Cost = NumTraits<Scalar>::MulCost, PacketAccess = false };
|
||||
struct scalar_isinf_op<Scalar, true> {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const Scalar& a) const {
|
||||
#if defined(SYCL_DEVICE_ONLY)
|
||||
return (numext::isinf(a) ? ptrue(a) : pzero(a));
|
||||
#else
|
||||
return (numext::isinf EIGEN_NOT_A_MACRO(a) ? ptrue(a) : pzero(a));
|
||||
#endif
|
||||
}
|
||||
template <typename Packet>
|
||||
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const {
|
||||
return pisinf(a);
|
||||
}
|
||||
};
|
||||
template <typename Scalar, bool UseTypedPredicate>
|
||||
struct functor_traits<scalar_isinf_op<Scalar, UseTypedPredicate>> {
|
||||
enum { Cost = NumTraits<Scalar>::MulCost, PacketAccess = packet_traits<Scalar>::HasCmp && UseTypedPredicate };
|
||||
};
|
||||
|
||||
/** \internal
|
||||
* \brief Template functor to check whether a scalar has a finite value
|
||||
* \sa class CwiseUnaryOp, ArrayBase::isfinite()
|
||||
*/
|
||||
template <typename Scalar>
|
||||
template <typename Scalar, bool UseTypedPredicate = false>
|
||||
struct scalar_isfinite_op {
|
||||
typedef bool result_type;
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const Scalar& a) const {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const Scalar& a) const {
|
||||
#if defined(SYCL_DEVICE_ONLY)
|
||||
return numext::isfinite(a);
|
||||
#else
|
||||
@ -1020,9 +1033,25 @@ struct scalar_isfinite_op {
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Scalar>
|
||||
struct functor_traits<scalar_isfinite_op<Scalar>> {
|
||||
enum { Cost = NumTraits<Scalar>::MulCost, PacketAccess = false };
|
||||
struct scalar_isfinite_op<Scalar, true> {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const Scalar& a) const {
|
||||
#if defined(SYCL_DEVICE_ONLY)
|
||||
return (numext::isfinite(a) ? ptrue(a) : pzero(a));
|
||||
#else
|
||||
return (numext::isfinite EIGEN_NOT_A_MACRO(a) ? ptrue(a) : pzero(a));
|
||||
#endif
|
||||
}
|
||||
template <typename Packet>
|
||||
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const {
|
||||
constexpr Scalar inf = NumTraits<Scalar>::infinity();
|
||||
return pcmp_lt(pabs(a), pset1<Packet>(inf));
|
||||
}
|
||||
};
|
||||
template <typename Scalar, bool UseTypedPredicate>
|
||||
struct functor_traits<scalar_isfinite_op<Scalar, UseTypedPredicate>> {
|
||||
enum { Cost = NumTraits<Scalar>::MulCost, PacketAccess = packet_traits<Scalar>::HasCmp && UseTypedPredicate };
|
||||
};
|
||||
|
||||
/** \internal
|
||||
|
@ -618,16 +618,15 @@ class TensorBase<Derived, ReadOnlyAccessors>
|
||||
(isnan)() const {
|
||||
return unaryExpr(internal::scalar_isnan_op<Scalar, true>()).template cast<bool>();
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_isinf_op<Scalar>, const Derived>
|
||||
EIGEN_STRONG_INLINE const TensorConversionOp<bool, const TensorCwiseUnaryOp<internal::scalar_isinf_op<Scalar, true>, const Derived>>
|
||||
(isinf)() const {
|
||||
return unaryExpr(internal::scalar_isinf_op<Scalar>());
|
||||
return unaryExpr(internal::scalar_isinf_op<Scalar, true>()).template cast<bool>();
|
||||
}
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_isfinite_op<Scalar>, const Derived>
|
||||
EIGEN_STRONG_INLINE const TensorConversionOp<bool, const TensorCwiseUnaryOp<internal::scalar_isfinite_op<Scalar, true>, const Derived>>
|
||||
(isfinite)() const {
|
||||
return unaryExpr(internal::scalar_isfinite_op<Scalar>());
|
||||
return unaryExpr(internal::scalar_isfinite_op<Scalar, true>()).template cast<bool>();
|
||||
}
|
||||
|
||||
// Coefficient-wise ternary operators.
|
||||
|
@ -132,8 +132,61 @@ static void test_isnan() {
|
||||
}
|
||||
}
|
||||
|
||||
static void test_isinf() {
|
||||
Tensor<Scalar, 3> mat(2, 3, 7);
|
||||
|
||||
mat.setRandom();
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
for (int j = 0; j < 3; ++j) {
|
||||
for (int k = 0; k < 7; ++k) {
|
||||
if (internal::random<bool>()) {
|
||||
mat(i, j, k) = std::numeric_limits<Scalar>::infinity();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Tensor<bool, 3> inf(2, 3, 7);
|
||||
inf = (mat.isinf)();
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
for (int j = 0; j < 3; ++j) {
|
||||
for (int k = 0; k < 7; ++k) {
|
||||
VERIFY_IS_EQUAL(inf(i, j, k), (std::isinf)(mat(i, j, k)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void test_isfinite() {
|
||||
Tensor<Scalar, 3> mat(2, 3, 7);
|
||||
|
||||
mat.setRandom();
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
for (int j = 0; j < 3; ++j) {
|
||||
for (int k = 0; k < 7; ++k) {
|
||||
if (internal::random<bool>()) {
|
||||
mat(i, j, k) = std::numeric_limits<Scalar>::infinity();
|
||||
}
|
||||
if (internal::random<bool>()) {
|
||||
mat(i, j, k) = std::numeric_limits<Scalar>::quiet_NaN();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Tensor<bool, 3> inf(2, 3, 7);
|
||||
inf = (mat.isfinite)();
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
for (int j = 0; j < 3; ++j) {
|
||||
for (int k = 0; k < 7; ++k) {
|
||||
VERIFY_IS_EQUAL(inf(i, j, k), (std::isfinite)(mat(i, j, k)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
EIGEN_DECLARE_TEST(cxx11_tensor_comparisons) {
|
||||
CALL_SUBTEST(test_orderings());
|
||||
CALL_SUBTEST(test_equality());
|
||||
CALL_SUBTEST(test_isnan());
|
||||
CALL_SUBTEST(test_isinf());
|
||||
CALL_SUBTEST(test_isfinite());
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user