mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-10 18:59:01 +08:00
optimize new dot product
This commit is contained in:
parent
fb477b8be1
commit
84282c42fc
@ -132,24 +132,21 @@ struct inner_product_impl<Evaluator, true> {
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Evaluator& eval) {
|
||||
const UnsignedIndex size = static_cast<UnsignedIndex>(eval.size());
|
||||
if (size < PacketSize) return inner_product_impl<Evaluator, false>::run(eval);
|
||||
|
||||
const UnsignedIndex packetEnd = numext::round_down(size, PacketSize);
|
||||
const UnsignedIndex quadEnd = numext::round_down(size, 4 * PacketSize);
|
||||
const UnsignedIndex numPackets = size / PacketSize;
|
||||
const UnsignedIndex numRemPackets = (packetEnd - quadEnd) / PacketSize;
|
||||
|
||||
Packet presult0 = eval.template packet<Packet>(0 * PacketSize);
|
||||
Packet presult1 = pzero(Packet());
|
||||
Packet presult2 = pzero(Packet());
|
||||
Packet presult3 = pzero(Packet());
|
||||
Packet presult0, presult1, presult2, presult3;
|
||||
|
||||
presult0 = eval.template packet<Packet>(0 * PacketSize);
|
||||
if (numPackets >= 2) presult1 = eval.template packet<Packet>(1 * PacketSize);
|
||||
if (numPackets >= 3) presult2 = eval.template packet<Packet>(2 * PacketSize);
|
||||
if (numPackets >= 4) {
|
||||
presult3 = eval.template packet<Packet>(3 * PacketSize);
|
||||
|
||||
const UnsignedIndex numRemPackets = (numPackets - 4) % 4;
|
||||
const UnsignedIndex quadStart = 4 * PacketSize;
|
||||
const UnsignedIndex quadEnd = (numPackets - numRemPackets) * PacketSize;
|
||||
|
||||
for (UnsignedIndex k = quadStart; k < quadEnd; k += 4 * PacketSize) {
|
||||
for (UnsignedIndex k = 4 * PacketSize; k < quadEnd; k += 4 * PacketSize) {
|
||||
presult0 = eval.packet(presult0, k + 0 * PacketSize);
|
||||
presult1 = eval.packet(presult1, k + 1 * PacketSize);
|
||||
presult2 = eval.packet(presult2, k + 2 * PacketSize);
|
||||
@ -159,16 +156,16 @@ struct inner_product_impl<Evaluator, true> {
|
||||
if (numRemPackets >= 1) presult0 = eval.packet(presult0, quadEnd + 0 * PacketSize);
|
||||
if (numRemPackets >= 2) presult1 = eval.packet(presult1, quadEnd + 1 * PacketSize);
|
||||
if (numRemPackets == 3) presult2 = eval.packet(presult2, quadEnd + 2 * PacketSize);
|
||||
|
||||
presult2 = padd(presult2, presult3);
|
||||
}
|
||||
|
||||
Scalar result = predux(padd(padd(presult0, presult1), padd(presult2, presult3)));
|
||||
if (numPackets >= 3) presult1 = padd(presult1, presult2);
|
||||
if (numPackets >= 2) presult0 = padd(presult0, presult1);
|
||||
|
||||
if (size > packetEnd) {
|
||||
Scalar scalarAccum = eval.coeff(packetEnd);
|
||||
for (UnsignedIndex k = packetEnd + 1; k < size; k++) {
|
||||
scalarAccum = eval.coeff(scalarAccum, k);
|
||||
}
|
||||
result += scalarAccum;
|
||||
Scalar result = predux(presult0);
|
||||
for (UnsignedIndex k = packetEnd; k < size; k++) {
|
||||
result = eval.coeff(result, k);
|
||||
}
|
||||
|
||||
return result;
|
||||
|
Loading…
x
Reference in New Issue
Block a user