From 382279eb7f0160b1b20a0e1b95df2397277ede08 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Mon, 26 Nov 2018 14:10:07 +0100 Subject: [PATCH] Extend unit test to recursively check half-packet types and non packet types --- test/packetmath.cpp | 169 +++++++++++++++++++++++++++++--------------- 1 file changed, 112 insertions(+), 57 deletions(-) diff --git a/test/packetmath.cpp b/test/packetmath.cpp index babb7c20e..43c33ba94 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp @@ -10,6 +10,7 @@ #include "main.h" #include "unsupported/Eigen/SpecialFunctions" +#include #if defined __GNUC__ && __GNUC__>=6 #pragma GCC diagnostic ignored "-Wignored-attributes" @@ -22,6 +23,8 @@ const bool g_vectorize_sse = true; const bool g_vectorize_sse = false; #endif +bool g_first_pass = true; + namespace Eigen { namespace internal { template T negate(const T& x) { return -x; } @@ -109,14 +112,18 @@ struct packet_helper #define REF_MUL(a,b) ((a)*(b)) #define REF_DIV(a,b) ((a)/(b)) -template void packetmath() +template void packetmath() { using std::abs; typedef internal::packet_traits PacketTraits; - typedef typename PacketTraits::type Packet; - const int PacketSize = PacketTraits::size; + const int PacketSize = internal::unpacket_traits::size; typedef typename NumTraits::Real RealScalar; + if (g_first_pass) + std::cerr << "=== Testing packet of type '" << typeid(Packet).name() + << "' and scalar type '" << typeid(Scalar).name() + << "' and size '" << PacketSize << "' ===\n" ; + const int max_size = PacketSize > 4 ? PacketSize : 4; const int size = PacketSize*max_size; EIGEN_ALIGN_MAX Scalar data1[size]; @@ -254,7 +261,7 @@ template void packetmath() ref[0] += data1[i]; VERIFY(isApproxAbs(ref[0], internal::predux(internal::pload(data1)), refvalue) && "internal::predux"); - if(PacketSize==8 && internal::unpacket_traits::half>::size ==4) // so far, predux_half_dowto4 is only required in such a case + if(PacketSize==8 && internal::unpacket_traits::half>::size ==4) // so far, predux_half_downto4 is only required in such a case { int HalfPacketSize = PacketSize>4 ? PacketSize/2 : PacketSize; for (int i=0; i void packetmath() } } -template void packetmath_real() +template void packetmath_real() { using std::abs; typedef internal::packet_traits PacketTraits; - typedef typename PacketTraits::type Packet; - const int PacketSize = PacketTraits::size; + const int PacketSize = internal::unpacket_traits::size; const int size = PacketSize*4; - EIGEN_ALIGN_MAX Scalar data1[PacketTraits::size*4]; - EIGEN_ALIGN_MAX Scalar data2[PacketTraits::size*4]; - EIGEN_ALIGN_MAX Scalar ref[PacketTraits::size*4]; + EIGEN_ALIGN_MAX Scalar data1[PacketSize*4]; + EIGEN_ALIGN_MAX Scalar data2[PacketSize*4]; + EIGEN_ALIGN_MAX Scalar ref[PacketSize*4]; for (int i=0; i void packetmath_real() data2[i] = internal::random(-1,1) * std::pow(Scalar(10), internal::random(-6,6)); } CHECK_CWISE1_IF(PacketTraits::HasTanh, std::tanh, internal::ptanh); - if(PacketTraits::HasExp && PacketTraits::size>=2) + if(PacketTraits::HasExp && PacketSize>=2) { data1[0] = std::numeric_limits::quiet_NaN(); data1[1] = std::numeric_limits::epsilon(); @@ -455,7 +461,7 @@ template void packetmath_real() CHECK_CWISE1_IF(internal::packet_traits::HasErfc, std::erfc, internal::perfc); #endif - if(PacketTraits::HasLog && PacketTraits::size>=2) + if(PacketTraits::HasLog && PacketSize>=2) { data1[0] = std::numeric_limits::quiet_NaN(); data1[1] = std::numeric_limits::epsilon(); @@ -497,18 +503,17 @@ template void packetmath_real() } } -template void packetmath_notcomplex() +template void packetmath_notcomplex() { using std::abs; typedef internal::packet_traits PacketTraits; - typedef typename PacketTraits::type Packet; - const int PacketSize = PacketTraits::size; + const int PacketSize = internal::unpacket_traits::size; - EIGEN_ALIGN_MAX Scalar data1[PacketTraits::size*4]; - EIGEN_ALIGN_MAX Scalar data2[PacketTraits::size*4]; - EIGEN_ALIGN_MAX Scalar ref[PacketTraits::size*4]; + EIGEN_ALIGN_MAX Scalar data1[PacketSize*4]; + EIGEN_ALIGN_MAX Scalar data2[PacketSize*4]; + EIGEN_ALIGN_MAX Scalar ref[PacketSize*4]; - Array::Map(data1, PacketTraits::size*4).setRandom(); + Array::Map(data1, PacketSize*4).setRandom(); ref[0] = data1[0]; for (int i=0; i void packetmath_notcomplex() VERIFY(areApprox(ref, data2, PacketSize) && "internal::plset"); } -template void test_conj_helper(Scalar* data1, Scalar* data2, Scalar* ref, Scalar* pval) +template void test_conj_helper(Scalar* data1, Scalar* data2, Scalar* ref, Scalar* pval) { - typedef internal::packet_traits PacketTraits; - typedef typename PacketTraits::type Packet; - const int PacketSize = PacketTraits::size; + const int PacketSize = internal::unpacket_traits::size; internal::conj_if cj0; internal::conj_if cj1; @@ -562,11 +565,9 @@ template void test_conj_helper(Scalar VERIFY(areApprox(ref, pval, PacketSize) && "conj_helper pmadd"); } -template void packetmath_complex() +template void packetmath_complex() { - typedef internal::packet_traits PacketTraits; - typedef typename PacketTraits::type Packet; - const int PacketSize = PacketTraits::size; + const int PacketSize = internal::unpacket_traits::size; const int size = PacketSize*4; EIGEN_ALIGN_MAX Scalar data1[PacketSize*4]; @@ -580,10 +581,10 @@ template void packetmath_complex() data2[i] = internal::random() * Scalar(1e2); } - test_conj_helper (data1,data2,ref,pval); - test_conj_helper (data1,data2,ref,pval); - test_conj_helper (data1,data2,ref,pval); - test_conj_helper (data1,data2,ref,pval); + test_conj_helper (data1,data2,ref,pval); + test_conj_helper (data1,data2,ref,pval); + test_conj_helper (data1,data2,ref,pval); + test_conj_helper (data1,data2,ref,pval); { for(int i=0;i void packetmath_complex() } } -template void packetmath_scatter_gather() +template void packetmath_scatter_gather() { - typedef internal::packet_traits PacketTraits; - typedef typename PacketTraits::type Packet; typedef typename NumTraits::Real RealScalar; - const int PacketSize = PacketTraits::size; + const int PacketSize = internal::unpacket_traits::size; EIGEN_ALIGN_MAX Scalar data1[PacketSize]; RealScalar refvalue = 0; for (int i=0; i void packetmath_scatter_gather() } } + +template< + typename Scalar, + typename PacketType, + bool IsComplex = NumTraits::IsComplex, + bool IsInteger = NumTraits::IsInteger> +struct runall; + +template +struct runall { // i.e. float or double + static void run() { + packetmath(); + packetmath_scatter_gather(); + packetmath_notcomplex(); + packetmath_real(); + } +}; + +template +struct runall { // i.e. int + static void run() { + packetmath(); + packetmath_scatter_gather(); + packetmath_notcomplex(); + } +}; + +template +struct runall { // i.e. complex + static void run() { + packetmath(); + packetmath_scatter_gather(); + packetmath_complex(); + } +}; + +template< + typename Scalar, + typename PacketType = typename internal::packet_traits::type, + bool Vectorized = internal::packet_traits::Vectorizable, + bool HasHalf = !internal::is_same::half,PacketType>::value > +struct runner; + +template +struct runner +{ + static void run() { + runall::run(); + runner::half>::run(); + } +}; + +template +struct runner +{ + static void run() { + runall::run(); + runall::run(); + } +}; + +template +struct runner +{ + static void run() { + runall::run(); + } +}; + EIGEN_DECLARE_TEST(packetmath) { + g_first_pass = true; for(int i = 0; i < g_repeat; i++) { - CALL_SUBTEST_1( packetmath() ); - CALL_SUBTEST_2( packetmath() ); - CALL_SUBTEST_3( packetmath() ); - CALL_SUBTEST_4( packetmath >() ); - CALL_SUBTEST_5( packetmath >() ); - CALL_SUBTEST_6( packetmath() ); - - CALL_SUBTEST_1( packetmath_notcomplex() ); - CALL_SUBTEST_2( packetmath_notcomplex() ); - CALL_SUBTEST_3( packetmath_notcomplex() ); - - CALL_SUBTEST_1( packetmath_real() ); - CALL_SUBTEST_2( packetmath_real() ); - - CALL_SUBTEST_4( packetmath_complex >() ); - CALL_SUBTEST_5( packetmath_complex >() ); - - CALL_SUBTEST_1( packetmath_scatter_gather() ); - CALL_SUBTEST_2( packetmath_scatter_gather() ); - CALL_SUBTEST_3( packetmath_scatter_gather() ); - CALL_SUBTEST_4( packetmath_scatter_gather >() ); - CALL_SUBTEST_5( packetmath_scatter_gather >() ); + + CALL_SUBTEST_1( runner::run() ); + CALL_SUBTEST_2( runner::run() ); + CALL_SUBTEST_3( runner::run() ); + CALL_SUBTEST_4( runner >::run() ); + CALL_SUBTEST_5( runner >::run() ); + CALL_SUBTEST_6(( packetmath::type>() )); + g_first_pass = false; } }