clean up array_cwise test

This commit is contained in:
Charles Schlosser 2023-05-04 16:02:08 +00:00
parent fda1373a15
commit 2af03fb685

View File

@ -10,6 +10,20 @@
#include <vector>
#include "main.h"
// suppress annoying unsigned integer warnings
template <typename Scalar, bool IsSigned = NumTraits<Scalar>::IsSigned>
struct negative_or_zero_impl {
static Scalar run(const Scalar& a) { return -a; }
};
template <typename Scalar>
struct negative_or_zero_impl<Scalar, false> {
static Scalar run(const Scalar&) { return 0; }
};
template <typename Scalar>
Scalar negative_or_zero(const Scalar& a) {
return negative_or_zero_impl<Scalar>::run(a);
}
template <typename Scalar, std::enable_if_t<NumTraits<Scalar>::IsInteger,int> = 0>
std::vector<Scalar> special_values() {
const Scalar zero = Scalar(0);
@ -249,7 +263,7 @@ template <typename Base, typename Exponent, bool ExpIsInteger = NumTraits<Expone
struct ref_pow {
static Base run(Base base, Exponent exponent) {
EIGEN_USING_STD(pow);
return pow(base, static_cast<Base>(exponent));
return static_cast<Base>(pow(base, static_cast<Base>(exponent)));
}
};
@ -257,7 +271,7 @@ template <typename Base, typename Exponent>
struct ref_pow<Base, Exponent, true> {
static Base run(Base base, Exponent exponent) {
EIGEN_USING_STD(pow);
return pow(base, exponent);
return static_cast<Base>(pow(base, exponent));
}
};
@ -302,7 +316,7 @@ void test_exponent(Exponent exponent) {
template <typename Base, typename Exponent>
void unary_pow_test() {
Exponent max_exponent = static_cast<Exponent>(NumTraits<Base>::digits());
Exponent min_exponent = static_cast<Exponent>(NumTraits<Exponent>::IsSigned ? -max_exponent : 0);
Exponent min_exponent = negative_or_zero(max_exponent);
for (Exponent exponent = min_exponent; exponent < max_exponent; ++exponent) {
test_exponent<Base, Exponent>(exponent);
@ -374,7 +388,7 @@ void signbit_test() {
std::vector<Scalar> special_vals = special_values<Scalar>();
for (size_t i = 0; i < special_vals.size(); i++) {
x(2 * i + 0) = special_vals[i];
x(2 * i + 1) = -special_vals[i];
x(2 * i + 1) = negative_or_zero(special_vals[i]);
}
y = x.unaryExpr(internal::test_signbit_op<Scalar>());
@ -1020,7 +1034,7 @@ template<int N>
struct shift_left {
template<typename Scalar>
Scalar operator()(const Scalar& v) const {
return v << N;
return (v << N);
}
};
@ -1028,29 +1042,10 @@ template<int N>
struct arithmetic_shift_right {
template<typename Scalar>
Scalar operator()(const Scalar& v) const {
return v >> N;
return (v >> N);
}
};
template<typename ArrayType> void array_integer(const ArrayType& m)
{
Index rows = m.rows();
Index cols = m.cols();
ArrayType m1 = ArrayType::Random(rows, cols),
m2(rows, cols);
m2 = m1.template shiftLeft<2>();
VERIFY( (m2 == m1.unaryExpr(shift_left<2>())).all() );
m2 = m1.template shiftLeft<9>();
VERIFY( (m2 == m1.unaryExpr(shift_left<9>())).all() );
m2 = m1.template shiftRight<2>();
VERIFY( (m2 == m1.unaryExpr(arithmetic_shift_right<2>())).all() );
m2 = m1.template shiftRight<9>();
VERIFY( (m2 == m1.unaryExpr(arithmetic_shift_right<9>())).all() );
}
template <typename ArrayType>
struct signed_shift_test_impl {
typedef typename ArrayType::Scalar Scalar;
@ -1064,13 +1059,15 @@ struct signed_shift_test_impl {
const Index rows = m.rows();
const Index cols = m.cols();
ArrayType m1 = ArrayType::Random(rows, cols), m2(rows, cols);
ArrayType m1 = ArrayType::Random(rows, cols), m2(rows, cols), m3(rows, cols);
m2 = m1.unaryExpr([](const Scalar& x) { return x >> N; });
VERIFY((m2 == m1.unaryExpr(internal::scalar_shift_right_op<Scalar, N>())).all());
m2 = m1.unaryExpr(internal::scalar_shift_right_op<Scalar, N>());
m3 = m1.unaryExpr(arithmetic_shift_right<N>());
VERIFY_IS_CWISE_EQUAL(m2, m3);
m2 = m1.unaryExpr([](const Scalar& x) { return x << N; });
VERIFY((m2 == m1.unaryExpr( internal::scalar_shift_left_op<Scalar, N>())).all());
m2 = m1.unaryExpr(internal::scalar_shift_left_op<Scalar, N>());
m3 = m1.unaryExpr(shift_left<N>());
VERIFY_IS_CWISE_EQUAL(m2, m3);
run<N + 1>(m);
}
@ -1193,8 +1190,6 @@ EIGEN_DECLARE_TEST(array_cwise)
CALL_SUBTEST_5( array(ArrayXXf(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
CALL_SUBTEST_6( array(ArrayXXi(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
CALL_SUBTEST_6( array(Array<Index,Dynamic,Dynamic>(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
CALL_SUBTEST_6( array_integer(ArrayXXi(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
CALL_SUBTEST_6( array_integer(Array<Index,Dynamic,Dynamic>(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
CALL_SUBTEST_7( signed_shift_test(ArrayXXi(internal::random<int>(1, EIGEN_TEST_MAX_SIZE), internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));
CALL_SUBTEST_7( signed_shift_test(Array<Index, Dynamic, Dynamic>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE), internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));
CALL_SUBTEST_8( array(Array<uint32_t, Dynamic, Dynamic>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE), internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));