mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-19 08:09:36 +08:00
clean up array_cwise test
This commit is contained in:
parent
fda1373a15
commit
2af03fb685
@ -10,6 +10,20 @@
|
||||
#include <vector>
|
||||
#include "main.h"
|
||||
|
||||
// suppress annoying unsigned integer warnings
|
||||
template <typename Scalar, bool IsSigned = NumTraits<Scalar>::IsSigned>
|
||||
struct negative_or_zero_impl {
|
||||
static Scalar run(const Scalar& a) { return -a; }
|
||||
};
|
||||
template <typename Scalar>
|
||||
struct negative_or_zero_impl<Scalar, false> {
|
||||
static Scalar run(const Scalar&) { return 0; }
|
||||
};
|
||||
template <typename Scalar>
|
||||
Scalar negative_or_zero(const Scalar& a) {
|
||||
return negative_or_zero_impl<Scalar>::run(a);
|
||||
}
|
||||
|
||||
template <typename Scalar, std::enable_if_t<NumTraits<Scalar>::IsInteger,int> = 0>
|
||||
std::vector<Scalar> special_values() {
|
||||
const Scalar zero = Scalar(0);
|
||||
@ -249,7 +263,7 @@ template <typename Base, typename Exponent, bool ExpIsInteger = NumTraits<Expone
|
||||
struct ref_pow {
|
||||
static Base run(Base base, Exponent exponent) {
|
||||
EIGEN_USING_STD(pow);
|
||||
return pow(base, static_cast<Base>(exponent));
|
||||
return static_cast<Base>(pow(base, static_cast<Base>(exponent)));
|
||||
}
|
||||
};
|
||||
|
||||
@ -257,7 +271,7 @@ template <typename Base, typename Exponent>
|
||||
struct ref_pow<Base, Exponent, true> {
|
||||
static Base run(Base base, Exponent exponent) {
|
||||
EIGEN_USING_STD(pow);
|
||||
return pow(base, exponent);
|
||||
return static_cast<Base>(pow(base, exponent));
|
||||
}
|
||||
};
|
||||
|
||||
@ -302,7 +316,7 @@ void test_exponent(Exponent exponent) {
|
||||
template <typename Base, typename Exponent>
|
||||
void unary_pow_test() {
|
||||
Exponent max_exponent = static_cast<Exponent>(NumTraits<Base>::digits());
|
||||
Exponent min_exponent = static_cast<Exponent>(NumTraits<Exponent>::IsSigned ? -max_exponent : 0);
|
||||
Exponent min_exponent = negative_or_zero(max_exponent);
|
||||
|
||||
for (Exponent exponent = min_exponent; exponent < max_exponent; ++exponent) {
|
||||
test_exponent<Base, Exponent>(exponent);
|
||||
@ -374,7 +388,7 @@ void signbit_test() {
|
||||
std::vector<Scalar> special_vals = special_values<Scalar>();
|
||||
for (size_t i = 0; i < special_vals.size(); i++) {
|
||||
x(2 * i + 0) = special_vals[i];
|
||||
x(2 * i + 1) = -special_vals[i];
|
||||
x(2 * i + 1) = negative_or_zero(special_vals[i]);
|
||||
}
|
||||
y = x.unaryExpr(internal::test_signbit_op<Scalar>());
|
||||
|
||||
@ -1020,7 +1034,7 @@ template<int N>
|
||||
struct shift_left {
|
||||
template<typename Scalar>
|
||||
Scalar operator()(const Scalar& v) const {
|
||||
return v << N;
|
||||
return (v << N);
|
||||
}
|
||||
};
|
||||
|
||||
@ -1028,29 +1042,10 @@ template<int N>
|
||||
struct arithmetic_shift_right {
|
||||
template<typename Scalar>
|
||||
Scalar operator()(const Scalar& v) const {
|
||||
return v >> N;
|
||||
return (v >> N);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename ArrayType> void array_integer(const ArrayType& m)
|
||||
{
|
||||
Index rows = m.rows();
|
||||
Index cols = m.cols();
|
||||
|
||||
ArrayType m1 = ArrayType::Random(rows, cols),
|
||||
m2(rows, cols);
|
||||
|
||||
m2 = m1.template shiftLeft<2>();
|
||||
VERIFY( (m2 == m1.unaryExpr(shift_left<2>())).all() );
|
||||
m2 = m1.template shiftLeft<9>();
|
||||
VERIFY( (m2 == m1.unaryExpr(shift_left<9>())).all() );
|
||||
|
||||
m2 = m1.template shiftRight<2>();
|
||||
VERIFY( (m2 == m1.unaryExpr(arithmetic_shift_right<2>())).all() );
|
||||
m2 = m1.template shiftRight<9>();
|
||||
VERIFY( (m2 == m1.unaryExpr(arithmetic_shift_right<9>())).all() );
|
||||
}
|
||||
|
||||
template <typename ArrayType>
|
||||
struct signed_shift_test_impl {
|
||||
typedef typename ArrayType::Scalar Scalar;
|
||||
@ -1064,13 +1059,15 @@ struct signed_shift_test_impl {
|
||||
const Index rows = m.rows();
|
||||
const Index cols = m.cols();
|
||||
|
||||
ArrayType m1 = ArrayType::Random(rows, cols), m2(rows, cols);
|
||||
ArrayType m1 = ArrayType::Random(rows, cols), m2(rows, cols), m3(rows, cols);
|
||||
|
||||
m2 = m1.unaryExpr([](const Scalar& x) { return x >> N; });
|
||||
VERIFY((m2 == m1.unaryExpr(internal::scalar_shift_right_op<Scalar, N>())).all());
|
||||
m2 = m1.unaryExpr(internal::scalar_shift_right_op<Scalar, N>());
|
||||
m3 = m1.unaryExpr(arithmetic_shift_right<N>());
|
||||
VERIFY_IS_CWISE_EQUAL(m2, m3);
|
||||
|
||||
m2 = m1.unaryExpr([](const Scalar& x) { return x << N; });
|
||||
VERIFY((m2 == m1.unaryExpr( internal::scalar_shift_left_op<Scalar, N>())).all());
|
||||
m2 = m1.unaryExpr(internal::scalar_shift_left_op<Scalar, N>());
|
||||
m3 = m1.unaryExpr(shift_left<N>());
|
||||
VERIFY_IS_CWISE_EQUAL(m2, m3);
|
||||
|
||||
run<N + 1>(m);
|
||||
}
|
||||
@ -1193,8 +1190,6 @@ EIGEN_DECLARE_TEST(array_cwise)
|
||||
CALL_SUBTEST_5( array(ArrayXXf(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
|
||||
CALL_SUBTEST_6( array(ArrayXXi(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
|
||||
CALL_SUBTEST_6( array(Array<Index,Dynamic,Dynamic>(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
|
||||
CALL_SUBTEST_6( array_integer(ArrayXXi(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
|
||||
CALL_SUBTEST_6( array_integer(Array<Index,Dynamic,Dynamic>(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
|
||||
CALL_SUBTEST_7( signed_shift_test(ArrayXXi(internal::random<int>(1, EIGEN_TEST_MAX_SIZE), internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));
|
||||
CALL_SUBTEST_7( signed_shift_test(Array<Index, Dynamic, Dynamic>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE), internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));
|
||||
CALL_SUBTEST_8( array(Array<uint32_t, Dynamic, Dynamic>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE), internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));
|
||||
|
Loading…
x
Reference in New Issue
Block a user