diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0af36a53a..43530b463 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -360,11 +360,19 @@ else()
   endif()

   option(EIGEN_TEST_FMA "Enable/Disable FMA/AVX2 in tests/examples" OFF)
-  if(EIGEN_TEST_FMA AND NOT EIGEN_TEST_NEON)
+  option(EIGEN_TEST_AVX2 "Enable/Disable FMA/AVX2 in tests/examples" OFF)
+  if((EIGEN_TEST_FMA AND NOT EIGEN_TEST_NEON) OR EIGEN_TEST_AVX2)
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2")
     message(STATUS "Enabling FMA/AVX2 in tests/examples")
   endif()

+  option(EIGEN_TEST_AVX512 "Enable/Disable AVX512 in tests/examples" OFF)
+  option(EIGEN_TEST_AVX512DQ "Enable/Disable AVX512DQ in tests/examples" OFF)
+  if(EIGEN_TEST_AVX512 OR EIGEN_TEST_AVX512DQ)
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX512")
+    message(STATUS "Enabling AVX512 in tests/examples")
+  endif()
+
 endif()

 option(EIGEN_TEST_NO_EXPLICIT_VECTORIZATION "Disable explicit vectorization in tests/examples" OFF)
diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h
index f6916c853..d34e04873 100644
--- a/Eigen/src/Core/arch/AVX512/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX512/PacketMath.h
@@ -784,7 +784,7 @@ EIGEN_STRONG_INLINE Packet8d pload<Packet8d>(const double* from) {
 template <>
 EIGEN_STRONG_INLINE Packet16i pload<Packet16i>(const int* from) {
   EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_si512(
-      reinterpret_cast(from));
+      reinterpret_cast(from));
 }

 template <>
@@ -1440,38 +1440,30 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 8>& kernel) {
   __m512 T6 = _mm512_unpacklo_ps(kernel.packet[6],kernel.packet[7]);
   __m512 T7 = _mm512_unpackhi_ps(kernel.packet[6],kernel.packet[7]);

-  kernel.packet[0] = reinterpret_cast<__m512>(
-      _mm512_unpacklo_pd(reinterpret_cast<__m512d>(T0),reinterpret_cast<__m512d>(T2)));
-  kernel.packet[1] = reinterpret_cast<__m512>(
-      _mm512_unpackhi_pd(reinterpret_cast<__m512d>(T0),reinterpret_cast<__m512d>(T2)));
-  kernel.packet[2] = reinterpret_cast<__m512>(
-      _mm512_unpacklo_pd(reinterpret_cast<__m512d>(T1),reinterpret_cast<__m512d>(T3)));
-  kernel.packet[3] = reinterpret_cast<__m512>(
-      _mm512_unpackhi_pd(reinterpret_cast<__m512d>(T1),reinterpret_cast<__m512d>(T3)));
-  kernel.packet[4] = reinterpret_cast<__m512>(
-      _mm512_unpacklo_pd(reinterpret_cast<__m512d>(T4),reinterpret_cast<__m512d>(T6)));
-  kernel.packet[5] = reinterpret_cast<__m512>(
-      _mm512_unpackhi_pd(reinterpret_cast<__m512d>(T4),reinterpret_cast<__m512d>(T6)));
-  kernel.packet[6] = reinterpret_cast<__m512>(
-      _mm512_unpacklo_pd(reinterpret_cast<__m512d>(T5),reinterpret_cast<__m512d>(T7)));
-  kernel.packet[7] = reinterpret_cast<__m512>(
-      _mm512_unpackhi_pd(reinterpret_cast<__m512d>(T5),reinterpret_cast<__m512d>(T7)));
+  kernel.packet[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T0),_mm512_castps_pd(T2)));
+  kernel.packet[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T0),_mm512_castps_pd(T2)));
+  kernel.packet[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T1),_mm512_castps_pd(T3)));
+  kernel.packet[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T1),_mm512_castps_pd(T3)));
+  kernel.packet[4] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T4),_mm512_castps_pd(T6)));
+  kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4),_mm512_castps_pd(T6)));
+  kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5),_mm512_castps_pd(T7)));
+  kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5),_mm512_castps_pd(T7)));

-  T0 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[4]), 0x4E));
+  T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E));
   T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0);
-  T4 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[0]), 0x4E));
+  T4 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[0]), 0x4E));
   T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]);
-  T1 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[5]), 0x4E));
+  T1 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[5]), 0x4E));
   T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1);
-  T5 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[1]), 0x4E));
+  T5 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[1]), 0x4E));
   T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]);
-  T2 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[6]), 0x4E));
+  T2 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[6]), 0x4E));
   T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2);
-  T6 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[2]), 0x4E));
+  T6 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[2]), 0x4E));
   T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]);
-  T3 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[7]), 0x4E));
+  T3 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[7]), 0x4E));
   T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3);
-  T7 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[3]), 0x4E));
+  T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E));
   T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]);

   kernel.packet[0] = T0;
   kernel.packet[1] = T1;
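
Note (not part of the patch): the PacketMath.h hunks replace value-level reinterpret_cast between __m512 and __m512d with the _mm512_castps_pd / _mm512_castpd_ps intrinsics. GCC and Clang accept the reinterpret_cast form as a vector extension, but MSVC, which the new /arch:AVX512 CMake options appear to target, treats the AVX-512 vector types as unrelated types and rejects it; the cast intrinsics are documented as compile-time-only bit reinterpretations, so they cost nothing at run time. A minimal sketch of the pattern follows, assuming an AVX-512F-enabled build (e.g. -mavx512f or /arch:AVX512); the helper name swap_128bit_halves is illustrative, not taken from the patch.

// Illustrative sketch only; not part of the diff above.
#include <immintrin.h>

static inline __m512 swap_128bit_halves(__m512 v) {
  // _mm512_castps_pd / _mm512_castpd_ps reinterpret the bits without emitting
  // an instruction and compile on GCC, Clang, and MSVC alike, whereas
  // reinterpret_cast<__m512d>(v) is a GCC/Clang vector extension.
  __m512d as_pd = _mm512_castps_pd(v);
  // 0x4E selects doubles 2,3,0,1 within each 256-bit lane, i.e. it swaps the
  // two 128-bit halves of each lane, as in the patched ptranspose code.
  __m512d perm = _mm512_permutex_pd(as_pd, 0x4E);
  return _mm512_castpd_ps(perm);
}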