diff --git a/Eigen/src/Core/arch/AVX/TypeCasting.h b/Eigen/src/Core/arch/AVX/TypeCasting.h
index db19b56c5..2a611e497 100644
--- a/Eigen/src/Core/arch/AVX/TypeCasting.h
+++ b/Eigen/src/Core/arch/AVX/TypeCasting.h
@@ -64,7 +64,6 @@ struct type_casting_traits<float, bfloat16> {
 };
 #endif  // EIGEN_VECTORIZE_AVX512
 
-
 template<> EIGEN_STRONG_INLINE Packet8i pcast<Packet8f, Packet8i>(const Packet8f& a) {
   return _mm256_cvttps_epi32(a);
 }
@@ -77,6 +76,10 @@ template<> EIGEN_STRONG_INLINE Packet8f pcast<Packet4d, Packet8f>(const Packet4d
   return _mm256_set_m128(_mm256_cvtpd_ps(b), _mm256_cvtpd_ps(a));
 }
 
+template<> EIGEN_STRONG_INLINE Packet8i pcast<Packet4d, Packet8i>(const Packet4d& a, const Packet4d& b) {
+  return _mm256_set_m128i(_mm256_cvttpd_epi32(b), _mm256_cvttpd_epi32(a));
+}
+
 template <>
 EIGEN_STRONG_INLINE Packet16b pcast<Packet8f, Packet16b>(const Packet8f& a,
                                                          const Packet8f& b) {
diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h
index 543f424b4..17bff79e7 100644
--- a/Eigen/src/Core/arch/AVX512/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX512/PacketMath.h
@@ -544,6 +544,8 @@ EIGEN_STRONG_INLINE Packet8d pmax<Packet8d>(const Packet8d& a, con
 template<int I_> EIGEN_STRONG_INLINE Packet8f extract256(Packet16f x) { return _mm512_extractf32x8_ps(x,I_); }
 template<int I_> EIGEN_STRONG_INLINE Packet2d extract128(Packet8d x) { return _mm512_extractf64x2_pd(x,I_); }
 EIGEN_STRONG_INLINE Packet16f cat256(Packet8f a, Packet8f b) { return _mm512_insertf32x8(_mm512_castps256_ps512(a),b,1); }
+EIGEN_STRONG_INLINE Packet16i cat256i(Packet8i a, Packet8i b) {
+  return _mm512_inserti32x8(_mm512_castsi256_si512(a),b,1); }
 #else
 // AVX512F does not define _mm512_extractf32x8_ps to extract _m256 from _m512
 template<int I_> EIGEN_STRONG_INLINE Packet8f extract256(Packet16f x) {
@@ -559,6 +561,9 @@ EIGEN_STRONG_INLINE Packet16f cat256(Packet8f a, Packet8f b) {
   return _mm512_castsi512_ps(_mm512_inserti64x4(_mm512_castsi256_si512(_mm256_castps_si256(a)),
                                                 _mm256_castps_si256(b),1));
 }
+EIGEN_STRONG_INLINE Packet16i cat256i(Packet8i a, Packet8i b) {
+  return _mm512_inserti64x4(_mm512_castsi256_si512(a), b, 1);
+}
 #endif
 
 // Helper function for bit packing snippet of low precision comparison.
diff --git a/Eigen/src/Core/arch/AVX512/TypeCasting.h b/Eigen/src/Core/arch/AVX512/TypeCasting.h
index 60f49a3d9..ca588e93c 100644
--- a/Eigen/src/Core/arch/AVX512/TypeCasting.h
+++ b/Eigen/src/Core/arch/AVX512/TypeCasting.h
@@ -55,6 +55,10 @@ template<> EIGEN_STRONG_INLINE Packet16f pcast<Packet8d, Packet16f>(const Packet
   return cat256(_mm512_cvtpd_ps(a), _mm512_cvtpd_ps(b));
 }
 
+template<> EIGEN_STRONG_INLINE Packet16i pcast<Packet8d, Packet16i>(const Packet8d& a, const Packet8d& b) {
+  return cat256i(_mm512_cvttpd_epi32(a), _mm512_cvttpd_epi32(b));
+}
+
 template<> EIGEN_STRONG_INLINE Packet16i preinterpret<Packet16i, Packet16f>(const Packet16f& a) {
   return _mm512_castps_si512(a);
 }
diff --git a/Eigen/src/Core/arch/SSE/TypeCasting.h b/Eigen/src/Core/arch/SSE/TypeCasting.h
index 2ab09434f..dd146187f 100644
--- a/Eigen/src/Core/arch/SSE/TypeCasting.h
+++ b/Eigen/src/Core/arch/SSE/TypeCasting.h
@@ -27,13 +27,14 @@ struct type_casting_traits<float, int> {
 };
 
 template <>
-struct type_casting_traits<int, float> {
+struct type_casting_traits<float, double> {
   enum {
     VectorizedCast = 1,
     SrcCoeffRatio = 1,
-    TgtCoeffRatio = 1
+    TgtCoeffRatio = 2
   };
 };
+#endif
 
 template <>
 struct type_casting_traits<double, float> {
@@ -45,14 +46,22 @@ struct type_casting_traits<double, float> {
 };
 
 template <>
-struct type_casting_traits<float, double> {
+struct type_casting_traits<int, float> {
   enum {
     VectorizedCast = 1,
     SrcCoeffRatio = 1,
-    TgtCoeffRatio = 2
+    TgtCoeffRatio = 1
+  };
+};
+
+template <>
+struct type_casting_traits<double, int> {
+  enum {
+    VectorizedCast = 1,
+    SrcCoeffRatio = 2,
+    TgtCoeffRatio = 1
   };
 };
-#endif
 
 template <>
 struct type_casting_traits<float, bool> {
@@ -91,6 +100,12 @@ template<> EIGEN_STRONG_INLINE Packet4f pcast<Packet2d, Packet4f>(const Packet2d
   return _mm_shuffle_ps(_mm_cvtpd_ps(a), _mm_cvtpd_ps(b), (1 << 2) | (1 << 6));
 }
 
+template<> EIGEN_STRONG_INLINE Packet4i pcast<Packet2d, Packet4i>(const Packet2d& a, const Packet2d& b) {
+  return _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(_mm_cvttpd_epi32(a)),
+                                         _mm_castsi128_ps(_mm_cvttpd_epi32(b)),
+                                         (1 << 2) | (1 << 6)));
+}
+
 template<> EIGEN_STRONG_INLINE Packet2d pcast<Packet4f, Packet2d>(const Packet4f& a) {
   // Simply discard the second half of the input
   return _mm_cvtps_pd(a);
diff --git a/unsupported/test/cxx11_tensor_casts.cpp b/unsupported/test/cxx11_tensor_casts.cpp
index 7b67738ea..81d81efdd 100644
--- a/unsupported/test/cxx11_tensor_casts.cpp
+++ b/unsupported/test/cxx11_tensor_casts.cpp
@@ -15,113 +15,23 @@
 using Eigen::Tensor;
 using Eigen::array;
 
-static void test_simple_cast()
-{
-  Tensor<float, 2> ftensor(20,30);
-  ftensor = ftensor.random() * 100.f;
-  Tensor<char, 2> chartensor(20,30);
-  chartensor.setRandom();
-  Tensor<std::complex<float>, 2> cplextensor(20,30);
-  cplextensor.setRandom();
-
-  chartensor = ftensor.cast<char>();
-  cplextensor = ftensor.cast<std::complex<float> >();
-
-  for (int i = 0; i < 20; ++i) {
-    for (int j = 0; j < 30; ++j) {
-      VERIFY_IS_EQUAL(chartensor(i,j), static_cast<char>(ftensor(i,j)));
-      VERIFY_IS_EQUAL(cplextensor(i,j), static_cast<std::complex<float> >(ftensor(i,j)));
-    }
-  }
-}
-
-
-static void test_vectorized_cast()
-{
-  Tensor<int, 2> itensor(20,30);
-  itensor = itensor.random() / 1000;
-  Tensor<float, 2> ftensor(20,30);
-  ftensor.setRandom();
-  Tensor<double, 2> dtensor(20,30);
-  dtensor.setRandom();
-
-  ftensor = itensor.cast<float>();
-  dtensor = itensor.cast<double>();
-
-  for (int i = 0; i < 20; ++i) {
-    for (int j = 0; j < 30; ++j) {
-      VERIFY_IS_EQUAL(itensor(i,j), static_cast<int>(ftensor(i,j)));
-      VERIFY_IS_EQUAL(dtensor(i,j), static_cast<double>(ftensor(i,j)));
-    }
-  }
-}
-
-
-static void test_float_to_int_cast()
-{
-  Tensor<float, 2> ftensor(20,30);
-  ftensor = ftensor.random() * 1000.0f;
-  Tensor<double, 2> dtensor(20,30);
-  dtensor = dtensor.random() * 1000.0;
-
-  Tensor<int, 2> i1tensor = ftensor.cast<int>();
-  Tensor<int, 2> i2tensor = dtensor.cast<int>();
-
-  for (int i = 0; i < 20; ++i) {
-    for (int j = 0; j < 30; ++j) {
-      VERIFY_IS_EQUAL(i1tensor(i,j), static_cast<int>(ftensor(i,j)));
-      VERIFY_IS_EQUAL(i2tensor(i,j), static_cast<int>(dtensor(i,j)));
-    }
-  }
-}
-
-
-static void test_big_to_small_type_cast()
-{
-  Tensor<double, 2> dtensor(20, 30);
-  dtensor.setRandom();
-  Tensor<float, 2> ftensor(20, 30);
-  ftensor = dtensor.cast<float>();
-
-  for (int i = 0; i < 20; ++i) {
-    for (int j = 0; j < 30; ++j) {
-      VERIFY_IS_APPROX(dtensor(i,j), static_cast<double>(ftensor(i,j)));
-    }
-  }
-}
-
-
-static void test_small_to_big_type_cast()
-{
-  Tensor<float, 2> ftensor(20, 30);
-  ftensor.setRandom();
-  Tensor<double, 2> dtensor(20, 30);
-  dtensor = ftensor.cast<double>();
-
-  for (int i = 0; i < 20; ++i) {
-    for (int j = 0; j < 30; ++j) {
-      VERIFY_IS_APPROX(dtensor(i,j), static_cast<double>(ftensor(i,j)));
-    }
-  }
-}
-
 template <typename FromType, typename ToType>
 static void test_type_cast() {
-  Tensor<FromType, 2> ftensor(100, 200);
+  Tensor<FromType, 2> ftensor(101, 201);
   // Generate random values for a valid cast.
-  for (int i = 0; i < 100; ++i) {
-    for (int j = 0; j < 200; ++j) {
+  for (int i = 0; i < 101; ++i) {
+    for (int j = 0; j < 201; ++j) {
       ftensor(i, j) = internal::random_without_cast_overflow<FromType, ToType>::value();
     }
   }
 
-  Tensor<ToType, 2> ttensor(100, 200);
+  Tensor<ToType, 2> ttensor(101, 201);
   ttensor = ftensor.template cast<ToType>();
 
-  for (int i = 0; i < 100; ++i) {
-    for (int j = 0; j < 200; ++j) {
-      const ToType ref = internal::cast<FromType, ToType>(ftensor(i, j));
-      VERIFY_IS_APPROX(ttensor(i, j), ref);
+  for (int i = 0; i < 101; ++i) {
+    for (int j = 0; j < 201; ++j) {
+      const ToType ref = static_cast<ToType>(ftensor(i, j));
+      VERIFY_IS_EQUAL(ttensor(i, j), ref);
     }
   }
 }
@@ -161,12 +71,6 @@ struct test_cast_runner<T, std::enable_if_t<NumTraits<T>::IsComplex>> {
 
 EIGEN_DECLARE_TEST(cxx11_tensor_casts)
 {
-  CALL_SUBTEST(test_simple_cast());
-  CALL_SUBTEST(test_vectorized_cast());
-  CALL_SUBTEST(test_float_to_int_cast());
-  CALL_SUBTEST(test_big_to_small_type_cast());
-  CALL_SUBTEST(test_small_to_big_type_cast());
-
   CALL_SUBTEST(test_cast_runner<bool>::run());
   CALL_SUBTEST(test_cast_runner<int8_t>::run());
   CALL_SUBTEST(test_cast_runner<int16_t>::run());