optimize new dot product

This commit is contained in:
Charles Schlosser 2024-09-11 21:40:43 +00:00 committed by Rasmus Munk Larsen
parent fb477b8be1
commit 84282c42fc

View File

@ -132,24 +132,21 @@ struct inner_product_impl<Evaluator, true> {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Evaluator& eval) { static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Evaluator& eval) {
const UnsignedIndex size = static_cast<UnsignedIndex>(eval.size()); const UnsignedIndex size = static_cast<UnsignedIndex>(eval.size());
if (size < PacketSize) return inner_product_impl<Evaluator, false>::run(eval); if (size < PacketSize) return inner_product_impl<Evaluator, false>::run(eval);
const UnsignedIndex packetEnd = numext::round_down(size, PacketSize); const UnsignedIndex packetEnd = numext::round_down(size, PacketSize);
const UnsignedIndex quadEnd = numext::round_down(size, 4 * PacketSize);
const UnsignedIndex numPackets = size / PacketSize; const UnsignedIndex numPackets = size / PacketSize;
const UnsignedIndex numRemPackets = (packetEnd - quadEnd) / PacketSize;
Packet presult0 = eval.template packet<Packet>(0 * PacketSize); Packet presult0, presult1, presult2, presult3;
Packet presult1 = pzero(Packet());
Packet presult2 = pzero(Packet());
Packet presult3 = pzero(Packet());
presult0 = eval.template packet<Packet>(0 * PacketSize);
if (numPackets >= 2) presult1 = eval.template packet<Packet>(1 * PacketSize); if (numPackets >= 2) presult1 = eval.template packet<Packet>(1 * PacketSize);
if (numPackets >= 3) presult2 = eval.template packet<Packet>(2 * PacketSize); if (numPackets >= 3) presult2 = eval.template packet<Packet>(2 * PacketSize);
if (numPackets >= 4) { if (numPackets >= 4) {
presult3 = eval.template packet<Packet>(3 * PacketSize); presult3 = eval.template packet<Packet>(3 * PacketSize);
const UnsignedIndex numRemPackets = (numPackets - 4) % 4; for (UnsignedIndex k = 4 * PacketSize; k < quadEnd; k += 4 * PacketSize) {
const UnsignedIndex quadStart = 4 * PacketSize;
const UnsignedIndex quadEnd = (numPackets - numRemPackets) * PacketSize;
for (UnsignedIndex k = quadStart; k < quadEnd; k += 4 * PacketSize) {
presult0 = eval.packet(presult0, k + 0 * PacketSize); presult0 = eval.packet(presult0, k + 0 * PacketSize);
presult1 = eval.packet(presult1, k + 1 * PacketSize); presult1 = eval.packet(presult1, k + 1 * PacketSize);
presult2 = eval.packet(presult2, k + 2 * PacketSize); presult2 = eval.packet(presult2, k + 2 * PacketSize);
@ -159,16 +156,16 @@ struct inner_product_impl<Evaluator, true> {
if (numRemPackets >= 1) presult0 = eval.packet(presult0, quadEnd + 0 * PacketSize); if (numRemPackets >= 1) presult0 = eval.packet(presult0, quadEnd + 0 * PacketSize);
if (numRemPackets >= 2) presult1 = eval.packet(presult1, quadEnd + 1 * PacketSize); if (numRemPackets >= 2) presult1 = eval.packet(presult1, quadEnd + 1 * PacketSize);
if (numRemPackets == 3) presult2 = eval.packet(presult2, quadEnd + 2 * PacketSize); if (numRemPackets == 3) presult2 = eval.packet(presult2, quadEnd + 2 * PacketSize);
presult2 = padd(presult2, presult3);
} }
Scalar result = predux(padd(padd(presult0, presult1), padd(presult2, presult3))); if (numPackets >= 3) presult1 = padd(presult1, presult2);
if (numPackets >= 2) presult0 = padd(presult0, presult1);
if (size > packetEnd) { Scalar result = predux(presult0);
Scalar scalarAccum = eval.coeff(packetEnd); for (UnsignedIndex k = packetEnd; k < size; k++) {
for (UnsignedIndex k = packetEnd + 1; k < size; k++) { result = eval.coeff(result, k);
scalarAccum = eval.coeff(scalarAccum, k);
}
result += scalarAccum;
} }
return result; return result;