Fix tensor stridedlinearbuffercopy

This commit is contained in:
Charles Schlosser 2023-08-03 20:36:42 +00:00 committed by Rasmus Munk Larsen
parent 8d9f467036
commit a798d07659

View File

@ -1054,28 +1054,28 @@ class StridedLinearBufferCopy {
} }
return; return;
} }
const IndexType vectorized_size = count - PacketSize; const IndexType vectorized_size = PacketSize * (count / PacketSize);
IndexType i = 0; IndexType i = 0;
if (kind == StridedLinearBufferCopy::Kind::Linear) { if (kind == StridedLinearBufferCopy::Kind::Linear) {
// ******************************************************************** // // ******************************************************************** //
// Linear copy from `src` to `dst`. // Linear copy from `src` to `dst`.
const IndexType unrolled_size = count - 4 * PacketSize; const IndexType unrolled_size = (4 * PacketSize) * (count / (4 * PacketSize));
eigen_assert(src_stride == 1 && dst_stride == 1); eigen_assert(src_stride == 1 && dst_stride == 1);
for (; i <= unrolled_size; i += 4 * PacketSize) { for (; i < unrolled_size; i += 4 * PacketSize) {
for (int j = 0; j < 4; ++j) { for (int j = 0; j < 4; ++j) {
Packet p = ploadu<Packet>(src + i + j * PacketSize); Packet p = ploadu<Packet>(src + i + j * PacketSize);
pstoreu<Scalar, Packet>(dst + i + j * PacketSize, p); pstoreu<Scalar, Packet>(dst + i + j * PacketSize, p);
} }
} }
for (; i <= vectorized_size; i += PacketSize) { for (; i < vectorized_size; i += PacketSize) {
Packet p = ploadu<Packet>(src + i); Packet p = ploadu<Packet>(src + i);
pstoreu<Scalar, Packet>(dst + i, p); pstoreu<Scalar, Packet>(dst + i, p);
} }
if (HasHalfPacket) { if (HasHalfPacket) {
const IndexType vectorized_half_size = count - HalfPacketSize; const IndexType vectorized_half_size = HalfPacketSize * (count / HalfPacketSize);
if (i <= vectorized_half_size) { if (i < vectorized_half_size) {
HalfPacket p = ploadu<HalfPacket>(src + i); HalfPacket p = ploadu<HalfPacket>(src + i);
pstoreu<Scalar, HalfPacket>(dst + i, p); pstoreu<Scalar, HalfPacket>(dst + i, p);
i += HalfPacketSize; i += HalfPacketSize;
@ -1088,13 +1088,13 @@ class StridedLinearBufferCopy {
} else if (kind == StridedLinearBufferCopy::Kind::Scatter) { } else if (kind == StridedLinearBufferCopy::Kind::Scatter) {
// Scatter from `src` to `dst`. // Scatter from `src` to `dst`.
eigen_assert(src_stride == 1 && dst_stride != 1); eigen_assert(src_stride == 1 && dst_stride != 1);
for (; i <= vectorized_size; i += PacketSize) { for (; i < vectorized_size; i += PacketSize) {
Packet p = ploadu<Packet>(src + i); Packet p = ploadu<Packet>(src + i);
pscatter<Scalar, Packet>(dst + i * dst_stride, p, dst_stride); pscatter<Scalar, Packet>(dst + i * dst_stride, p, dst_stride);
} }
if (HasHalfPacket) { if (HasHalfPacket) {
const IndexType vectorized_half_size = count - HalfPacketSize; const IndexType vectorized_half_size = HalfPacketSize * (count / HalfPacketSize);
if (i <= vectorized_half_size) { if (i < vectorized_half_size) {
HalfPacket p = ploadu<HalfPacket>(src + i); HalfPacket p = ploadu<HalfPacket>(src + i);
pscatter<Scalar, HalfPacket>(dst + i * dst_stride, p, dst_stride); pscatter<Scalar, HalfPacket>(dst + i * dst_stride, p, dst_stride);
i += HalfPacketSize; i += HalfPacketSize;
@ -1107,20 +1107,21 @@ class StridedLinearBufferCopy {
} else if (kind == StridedLinearBufferCopy::Kind::FillLinear) { } else if (kind == StridedLinearBufferCopy::Kind::FillLinear) {
// Fill `dst` with value at `*src`. // Fill `dst` with value at `*src`.
eigen_assert(src_stride == 0 && dst_stride == 1); eigen_assert(src_stride == 0 && dst_stride == 1);
const IndexType unrolled_size = count - 4 * PacketSize;
const IndexType unrolled_size = (4 * PacketSize) * (count / (4 * PacketSize));
Scalar s = *src; Scalar s = *src;
Packet p = pset1<Packet>(s); Packet p = pset1<Packet>(s);
for (; i <= unrolled_size; i += 4 * PacketSize) { for (; i < unrolled_size; i += 4 * PacketSize) {
for (int j = 0; j < 4; ++j) { for (int j = 0; j < 4; ++j) {
pstoreu<Scalar, Packet>(dst + i + j * PacketSize, p); pstoreu<Scalar, Packet>(dst + i + j * PacketSize, p);
} }
} }
for (; i <= vectorized_size; i += PacketSize) { for (; i < vectorized_size; i += PacketSize) {
pstoreu<Scalar, Packet>(dst + i, p); pstoreu<Scalar, Packet>(dst + i, p);
} }
if (HasHalfPacket) { if (HasHalfPacket) {
const IndexType vectorized_half_size = count - HalfPacketSize; const IndexType vectorized_half_size = HalfPacketSize * (count / HalfPacketSize);
if (i <= vectorized_half_size) { if (i < vectorized_half_size) {
HalfPacket hp = pset1<HalfPacket>(s); HalfPacket hp = pset1<HalfPacket>(s);
pstoreu<Scalar, HalfPacket>(dst + i, hp); pstoreu<Scalar, HalfPacket>(dst + i, hp);
i += HalfPacketSize; i += HalfPacketSize;
@ -1135,12 +1136,12 @@ class StridedLinearBufferCopy {
eigen_assert(src_stride == 0 && dst_stride != 1); eigen_assert(src_stride == 0 && dst_stride != 1);
Scalar s = *src; Scalar s = *src;
Packet p = pset1<Packet>(s); Packet p = pset1<Packet>(s);
for (; i <= vectorized_size; i += PacketSize) { for (; i < vectorized_size; i += PacketSize) {
pscatter<Scalar, Packet>(dst + i * dst_stride, p, dst_stride); pscatter<Scalar, Packet>(dst + i * dst_stride, p, dst_stride);
} }
if (HasHalfPacket) { if (HasHalfPacket) {
const IndexType vectorized_half_size = count - HalfPacketSize; const IndexType vectorized_half_size = HalfPacketSize * (count / HalfPacketSize);
if (i <= vectorized_half_size) { if (i < vectorized_half_size) {
HalfPacket hp = pset1<HalfPacket>(s); HalfPacket hp = pset1<HalfPacket>(s);
pscatter<Scalar, HalfPacket>(dst + i * dst_stride, hp, dst_stride); pscatter<Scalar, HalfPacket>(dst + i * dst_stride, hp, dst_stride);
i += HalfPacketSize; i += HalfPacketSize;
@ -1153,13 +1154,13 @@ class StridedLinearBufferCopy {
} else if (kind == StridedLinearBufferCopy::Kind::Gather) { } else if (kind == StridedLinearBufferCopy::Kind::Gather) {
// Gather from `src` into `dst`. // Gather from `src` into `dst`.
eigen_assert(dst_stride == 1); eigen_assert(dst_stride == 1);
for (; i <= vectorized_size; i += PacketSize) { for (; i < vectorized_size; i += PacketSize) {
Packet p = pgather<Scalar, Packet>(src + i * src_stride, src_stride); Packet p = pgather<Scalar, Packet>(src + i * src_stride, src_stride);
pstoreu<Scalar, Packet>(dst + i, p); pstoreu<Scalar, Packet>(dst + i, p);
} }
if (HasHalfPacket) { if (HasHalfPacket) {
const IndexType vectorized_half_size = count - HalfPacketSize; const IndexType vectorized_half_size = HalfPacketSize * (count / HalfPacketSize);
if (i <= vectorized_half_size) { if (i < vectorized_half_size) {
HalfPacket p = HalfPacket p =
pgather<Scalar, HalfPacket>(src + i * src_stride, src_stride); pgather<Scalar, HalfPacket>(src + i * src_stride, src_stride);
pstoreu<Scalar, HalfPacket>(dst + i, p); pstoreu<Scalar, HalfPacket>(dst + i, p);
@ -1456,11 +1457,11 @@ class TensorBlockAssignment {
IndexType eval_offset) { IndexType eval_offset) {
typedef typename packet_traits<Scalar>::type Packet; typedef typename packet_traits<Scalar>::type Packet;
const IndexType unrolled_size = count - 4 * PacketSize; const IndexType unrolled_size = (4 * PacketSize) * (count / (4 * PacketSize));
const IndexType vectorized_size = count - PacketSize; const IndexType vectorized_size = PacketSize * (count / PacketSize);
IndexType i = 0; IndexType i = 0;
for (; i <= unrolled_size; i += 4 * PacketSize) { for (; i < unrolled_size; i += 4 * PacketSize) {
for (int j = 0; j < 4; ++j) { for (int j = 0; j < 4; ++j) {
const IndexType idx = eval_offset + i + j * PacketSize; const IndexType idx = eval_offset + i + j * PacketSize;
Packet p = eval.template packet<Unaligned>(idx); Packet p = eval.template packet<Unaligned>(idx);
@ -1468,7 +1469,7 @@ class TensorBlockAssignment {
} }
} }
for (; i <= vectorized_size; i += PacketSize) { for (; i < vectorized_size; i += PacketSize) {
Packet p = eval.template packet<Unaligned>(eval_offset + i); Packet p = eval.template packet<Unaligned>(eval_offset + i);
pstoreu<Scalar>(target + i, p); pstoreu<Scalar>(target + i, p);
} }