From 874f5947f429d1da729a5fb54c5c9a673d1a0148 Mon Sep 17 00:00:00 2001 From: Pedro Gonnet Date: Mon, 1 May 2023 16:09:31 +0000 Subject: [PATCH] Add half-`Packet` operations to `StridedLinearBufferCopy`. --- .../Eigen/CXX11/src/Tensor/TensorBlock.h | 56 +++++++++++++++++-- 1 file changed, 51 insertions(+), 5 deletions(-) diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h index 2e663404e..ee01d9932 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h @@ -994,9 +994,12 @@ class TensorTernaryExprBlock { template class StridedLinearBufferCopy { typedef typename packet_traits::type Packet; + typedef typename unpacket_traits::half HalfPacket; enum { Vectorizable = packet_traits::Vectorizable, - PacketSize = packet_traits::size + PacketSize = packet_traits::size, + HasHalfPacket = unpacket_traits::size < PacketSize, + HalfPacketSize = unpacket_traits::size, }; public: @@ -1070,6 +1073,14 @@ class StridedLinearBufferCopy { Packet p = ploadu(src + i); pstoreu(dst + i, p); } + if (HasHalfPacket) { + const IndexType vectorized_half_size = count - HalfPacketSize; + if (i <= vectorized_half_size) { + HalfPacket p = ploadu(src + i); + pstoreu(dst + i, p); + i += HalfPacketSize; + } + } for (; i < count; ++i) { dst[i] = src[i]; } @@ -1081,6 +1092,14 @@ class StridedLinearBufferCopy { Packet p = ploadu(src + i); pscatter(dst + i * dst_stride, p, dst_stride); } + if (HasHalfPacket) { + const IndexType vectorized_half_size = count - HalfPacketSize; + if (i <= vectorized_half_size) { + HalfPacket p = ploadu(src + i); + pscatter(dst + i * dst_stride, p, dst_stride); + i += HalfPacketSize; + } + } for (; i < count; ++i) { dst[i * dst_stride] = src[i]; } @@ -1089,7 +1108,8 @@ class StridedLinearBufferCopy { // Fill `dst` with value at `*src`. eigen_assert(src_stride == 0 && dst_stride == 1); const IndexType unrolled_size = count - 4 * PacketSize; - Packet p = pload1(src); + Scalar s = *src; + Packet p = pset1(s); for (; i <= unrolled_size; i += 4 * PacketSize) { for (int j = 0; j < 4; ++j) { pstoreu(dst + i + j * PacketSize, p); @@ -1098,19 +1118,36 @@ class StridedLinearBufferCopy { for (; i <= vectorized_size; i += PacketSize) { pstoreu(dst + i, p); } + if (HasHalfPacket) { + const IndexType vectorized_half_size = count - HalfPacketSize; + if (i <= vectorized_half_size) { + HalfPacket hp = pset1(s); + pstoreu(dst + i, hp); + i += HalfPacketSize; + } + } for (; i < count; ++i) { - dst[i] = *src; + dst[i] = s; } // ******************************************************************** // } else if (kind == StridedLinearBufferCopy::Kind::FillScatter) { // Scatter `*src` into `dst`. eigen_assert(src_stride == 0 && dst_stride != 1); - Packet p = pload1(src); + Scalar s = *src; + Packet p = pset1(s); for (; i <= vectorized_size; i += PacketSize) { pscatter(dst + i * dst_stride, p, dst_stride); } + if (HasHalfPacket) { + const IndexType vectorized_half_size = count - HalfPacketSize; + if (i <= vectorized_half_size) { + HalfPacket hp = pset1(s); + pscatter(dst + i * dst_stride, hp, dst_stride); + i += HalfPacketSize; + } + } for (; i < count; ++i) { - dst[i * dst_stride] = *src; + dst[i * dst_stride] = s; } // ******************************************************************** // } else if (kind == StridedLinearBufferCopy::Kind::Gather) { @@ -1120,6 +1157,15 @@ class StridedLinearBufferCopy { Packet p = pgather(src + i * src_stride, src_stride); pstoreu(dst + i, p); } + if (HasHalfPacket) { + const IndexType vectorized_half_size = count - HalfPacketSize; + if (i <= vectorized_half_size) { + HalfPacket p = + pgather(src + i * src_stride, src_stride); + pstoreu(dst + i, p); + i += HalfPacketSize; + } + } for (; i < count; ++i) { dst[i] = src[i * src_stride]; }