optimize new dot product

This commit is contained in:
Charles Schlosser 2024-09-11 21:40:43 +00:00 committed by Rasmus Munk Larsen
parent fb477b8be1
commit 84282c42fc

View File

@ -132,24 +132,21 @@ struct inner_product_impl<Evaluator, true> {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Evaluator& eval) {
const UnsignedIndex size = static_cast<UnsignedIndex>(eval.size());
if (size < PacketSize) return inner_product_impl<Evaluator, false>::run(eval);
const UnsignedIndex packetEnd = numext::round_down(size, PacketSize);
const UnsignedIndex quadEnd = numext::round_down(size, 4 * PacketSize);
const UnsignedIndex numPackets = size / PacketSize;
const UnsignedIndex numRemPackets = (packetEnd - quadEnd) / PacketSize;
Packet presult0 = eval.template packet<Packet>(0 * PacketSize);
Packet presult1 = pzero(Packet());
Packet presult2 = pzero(Packet());
Packet presult3 = pzero(Packet());
Packet presult0, presult1, presult2, presult3;
presult0 = eval.template packet<Packet>(0 * PacketSize);
if (numPackets >= 2) presult1 = eval.template packet<Packet>(1 * PacketSize);
if (numPackets >= 3) presult2 = eval.template packet<Packet>(2 * PacketSize);
if (numPackets >= 4) {
presult3 = eval.template packet<Packet>(3 * PacketSize);
const UnsignedIndex numRemPackets = (numPackets - 4) % 4;
const UnsignedIndex quadStart = 4 * PacketSize;
const UnsignedIndex quadEnd = (numPackets - numRemPackets) * PacketSize;
for (UnsignedIndex k = quadStart; k < quadEnd; k += 4 * PacketSize) {
for (UnsignedIndex k = 4 * PacketSize; k < quadEnd; k += 4 * PacketSize) {
presult0 = eval.packet(presult0, k + 0 * PacketSize);
presult1 = eval.packet(presult1, k + 1 * PacketSize);
presult2 = eval.packet(presult2, k + 2 * PacketSize);
@ -159,16 +156,16 @@ struct inner_product_impl<Evaluator, true> {
if (numRemPackets >= 1) presult0 = eval.packet(presult0, quadEnd + 0 * PacketSize);
if (numRemPackets >= 2) presult1 = eval.packet(presult1, quadEnd + 1 * PacketSize);
if (numRemPackets == 3) presult2 = eval.packet(presult2, quadEnd + 2 * PacketSize);
presult2 = padd(presult2, presult3);
}
Scalar result = predux(padd(padd(presult0, presult1), padd(presult2, presult3)));
if (numPackets >= 3) presult1 = padd(presult1, presult2);
if (numPackets >= 2) presult0 = padd(presult0, presult1);
if (size > packetEnd) {
Scalar scalarAccum = eval.coeff(packetEnd);
for (UnsignedIndex k = packetEnd + 1; k < size; k++) {
scalarAccum = eval.coeff(scalarAccum, k);
}
result += scalarAccum;
Scalar result = predux(presult0);
for (UnsignedIndex k = packetEnd; k < size; k++) {
result = eval.coeff(result, k);
}
return result;