mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-12 03:39:01 +08:00
Remove reinterpret_cast from AVX512 complex implementation
The reinterpret_casts used in ptranspose(PacketBlock<Packet8cf,4>&) and ptranspose(PacketBlock<Packet8cf,8>&) don't appear to work correctly. They convert the kernel parameter to a PacketBlock<Packet8d,T>& so that the complex-number versions of ptranspose can be written in terms of the existing double implementations. Unfortunately, the casts don't seem to work and are responsible for 9 unit test failures in the AVX512 build of tensorflow master. This commit fixes the issue by manually initialising PacketBlock<Packet8d,T> variables with the contents of the kernel parameter before calling the double version of ptranspose, and then copying the resulting values back into the kernel parameter before returning.
This commit is contained in:
parent
0522460a0d
commit
3c9add6598
@@ -390,12 +390,40 @@ template<> EIGEN_STRONG_INLINE Packet4cd pcplxflip<Packet4cd>(const Packet4cd& x
|
||||
|
||||
// Transposes a block of four Packet8cf packets in place.
//
// The work is delegated to the PacketBlock<Packet8d,4> overload of ptranspose.
// The block is NOT reinterpret_cast to PacketBlock<Packet8d,4>&: that cast was
// observed to produce wrong results in AVX512 builds. Instead each packet's
// register is copied into a local double-packet block, the block is
// transposed, and the results are copied back into `kernel`.
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet8cf,4>& kernel) {
  PacketBlock<Packet8d,4> pb;

  // _mm512_castps_pd only relabels the 512-bit register as packed doubles;
  // the bit pattern is left untouched.
  for (int i = 0; i < 4; ++i) {
    pb.packet[i] = _mm512_castps_pd(kernel.packet[i].v);
  }

  ptranspose(pb);

  // Relabel the transposed registers as packed floats and store them back.
  for (int i = 0; i < 4; ++i) {
    kernel.packet[i].v = _mm512_castpd_ps(pb.packet[i]);
  }
}
|
||||
|
||||
// Transposes a block of eight Packet8cf packets in place.
//
// The work is delegated to the PacketBlock<Packet8d,8> overload of ptranspose.
// The block is NOT reinterpret_cast to PacketBlock<Packet8d,8>&: that cast was
// observed to produce wrong results in AVX512 builds. Instead each packet's
// register is copied into a local double-packet block, the block is
// transposed, and the results are copied back into `kernel`.
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet8cf,8>& kernel) {
  PacketBlock<Packet8d,8> pb;

  // _mm512_castps_pd only relabels the 512-bit register as packed doubles;
  // the bit pattern is left untouched.
  for (int i = 0; i < 8; ++i) {
    pb.packet[i] = _mm512_castps_pd(kernel.packet[i].v);
  }

  ptranspose(pb);

  // Relabel the transposed registers as packed floats and store them back.
  for (int i = 0; i < 8; ++i) {
    kernel.packet[i].v = _mm512_castpd_ps(pb.packet[i]);
  }
}
|
||||
|
||||
EIGEN_DEVICE_FUNC inline void
|
||||
|
Loading…
x
Reference in New Issue
Block a user