diff --git a/Eigen/src/Core/InnerProduct.h b/Eigen/src/Core/InnerProduct.h index c8b1c1d0d..38689daaf 100644 --- a/Eigen/src/Core/InnerProduct.h +++ b/Eigen/src/Core/InnerProduct.h @@ -132,24 +132,21 @@ struct inner_product_impl { static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Evaluator& eval) { const UnsignedIndex size = static_cast(eval.size()); if (size < PacketSize) return inner_product_impl::run(eval); + const UnsignedIndex packetEnd = numext::round_down(size, PacketSize); + const UnsignedIndex quadEnd = numext::round_down(size, 4 * PacketSize); const UnsignedIndex numPackets = size / PacketSize; + const UnsignedIndex numRemPackets = (packetEnd - quadEnd) / PacketSize; - Packet presult0 = eval.template packet(0 * PacketSize); - Packet presult1 = pzero(Packet()); - Packet presult2 = pzero(Packet()); - Packet presult3 = pzero(Packet()); + Packet presult0, presult1, presult2, presult3; + presult0 = eval.template packet(0 * PacketSize); if (numPackets >= 2) presult1 = eval.template packet(1 * PacketSize); if (numPackets >= 3) presult2 = eval.template packet(2 * PacketSize); if (numPackets >= 4) { presult3 = eval.template packet(3 * PacketSize); - const UnsignedIndex numRemPackets = (numPackets - 4) % 4; - const UnsignedIndex quadStart = 4 * PacketSize; - const UnsignedIndex quadEnd = (numPackets - numRemPackets) * PacketSize; - - for (UnsignedIndex k = quadStart; k < quadEnd; k += 4 * PacketSize) { + for (UnsignedIndex k = 4 * PacketSize; k < quadEnd; k += 4 * PacketSize) { presult0 = eval.packet(presult0, k + 0 * PacketSize); presult1 = eval.packet(presult1, k + 1 * PacketSize); presult2 = eval.packet(presult2, k + 2 * PacketSize); @@ -159,16 +156,16 @@ struct inner_product_impl { if (numRemPackets >= 1) presult0 = eval.packet(presult0, quadEnd + 0 * PacketSize); if (numRemPackets >= 2) presult1 = eval.packet(presult1, quadEnd + 1 * PacketSize); if (numRemPackets == 3) presult2 = eval.packet(presult2, quadEnd + 2 * PacketSize); + + presult2 = padd(presult2, presult3); } - Scalar result = predux(padd(padd(presult0, presult1), padd(presult2, presult3))); + if (numPackets >= 3) presult1 = padd(presult1, presult2); + if (numPackets >= 2) presult0 = padd(presult0, presult1); - if (size > packetEnd) { - Scalar scalarAccum = eval.coeff(packetEnd); - for (UnsignedIndex k = packetEnd + 1; k < size; k++) { - scalarAccum = eval.coeff(scalarAccum, k); - } - result += scalarAccum; + Scalar result = predux(presult0); + for (UnsignedIndex k = packetEnd; k < size; k++) { + result = eval.coeff(result, k); } return result;