Make vectorized compute_inverse_size4 compile with AVX.

(cherry picked from commit 85a76a16ea835fcfa7d4c185a338ae2aef9a272a)
This commit is contained in:
Rasmus Munk Larsen 2021-04-22 15:21:01 +00:00
parent 34d0be9ec1
commit 54425a39b2
2 changed files with 25 additions and 24 deletions

View File

@@ -38,9 +38,7 @@
#include "src/LU/Determinant.h" #include "src/LU/Determinant.h"
#include "src/LU/InverseImpl.h" #include "src/LU/InverseImpl.h"
// Use the SSE optimized version whenever possible. At the moment the #if defined EIGEN_VECTORIZE_SSE || defined EIGEN_VECTORIZE_NEON
// SSE version doesn't compile when AVX is enabled
#if (defined EIGEN_VECTORIZE_SSE && !defined EIGEN_VECTORIZE_AVX) || defined EIGEN_VECTORIZE_NEON
#include "src/LU/arch/InverseSize4.h" #include "src/LU/arch/InverseSize4.h"
#endif #endif

View File

@@ -54,10 +54,12 @@ struct compute_inverse_size4<Architecture::Target, float, MatrixType, ResultType
{ {
ActualMatrixType matrix(mat); ActualMatrixType matrix(mat);
Packet4f _L1 = matrix.template packet<MatrixAlignment>(0); const float* data = matrix.data();
Packet4f _L2 = matrix.template packet<MatrixAlignment>(4); const Index stride = matrix.innerStride();
Packet4f _L3 = matrix.template packet<MatrixAlignment>(8); Packet4f _L1 = ploadt<Packet4f,MatrixAlignment>(data);
Packet4f _L4 = matrix.template packet<MatrixAlignment>(12); Packet4f _L2 = ploadt<Packet4f,MatrixAlignment>(data + stride*4);
Packet4f _L3 = ploadt<Packet4f,MatrixAlignment>(data + stride*8);
Packet4f _L4 = ploadt<Packet4f,MatrixAlignment>(data + stride*12);
// Four 2x2 sub-matrices of the input matrix // Four 2x2 sub-matrices of the input matrix
// input = [[A, B], // input = [[A, B],
@@ -189,25 +191,26 @@ struct compute_inverse_size4<Architecture::Target, double, MatrixType, ResultTyp
Packet2d A1, A2, B1, B2, C1, C2, D1, D2; Packet2d A1, A2, B1, B2, C1, C2, D1, D2;
const double* data = matrix.data();
const Index stride = matrix.innerStride();
if (StorageOrdersMatch) if (StorageOrdersMatch)
{ {
A1 = matrix.template packet<MatrixAlignment>(0); A1 = ploadt<Packet2d,MatrixAlignment>(data + stride*0);
B1 = matrix.template packet<MatrixAlignment>(2); B1 = ploadt<Packet2d,MatrixAlignment>(data + stride*2);
A2 = matrix.template packet<MatrixAlignment>(4); A2 = ploadt<Packet2d,MatrixAlignment>(data + stride*4);
B2 = matrix.template packet<MatrixAlignment>(6); B2 = ploadt<Packet2d,MatrixAlignment>(data + stride*6);
C1 = matrix.template packet<MatrixAlignment>(8); C1 = ploadt<Packet2d,MatrixAlignment>(data + stride*8);
D1 = matrix.template packet<MatrixAlignment>(10); D1 = ploadt<Packet2d,MatrixAlignment>(data + stride*10);
C2 = matrix.template packet<MatrixAlignment>(12); C2 = ploadt<Packet2d,MatrixAlignment>(data + stride*12);
D2 = matrix.template packet<MatrixAlignment>(14); D2 = ploadt<Packet2d,MatrixAlignment>(data + stride*14);
} }
else else
{ {
Packet2d temp; Packet2d temp;
A1 = matrix.template packet<MatrixAlignment>(0); A1 = ploadt<Packet2d,MatrixAlignment>(data + stride*0);
C1 = matrix.template packet<MatrixAlignment>(2); C1 = ploadt<Packet2d,MatrixAlignment>(data + stride*2);
A2 = matrix.template packet<MatrixAlignment>(4); A2 = ploadt<Packet2d,MatrixAlignment>(data + stride*4);
C2 = matrix.template packet<MatrixAlignment>(6); C2 = ploadt<Packet2d,MatrixAlignment>(data + stride*6);
temp = A1; temp = A1;
A1 = vec2d_unpacklo(A1, A2); A1 = vec2d_unpacklo(A1, A2);
A2 = vec2d_unpackhi(temp, A2); A2 = vec2d_unpackhi(temp, A2);
@@ -216,10 +219,10 @@ struct compute_inverse_size4<Architecture::Target, double, MatrixType, ResultTyp
C1 = vec2d_unpacklo(C1, C2); C1 = vec2d_unpacklo(C1, C2);
C2 = vec2d_unpackhi(temp, C2); C2 = vec2d_unpackhi(temp, C2);
B1 = matrix.template packet<MatrixAlignment>(8); B1 = ploadt<Packet2d,MatrixAlignment>(data + stride*8);
D1 = matrix.template packet<MatrixAlignment>(10); D1 = ploadt<Packet2d,MatrixAlignment>(data + stride*10);
B2 = matrix.template packet<MatrixAlignment>(12); B2 = ploadt<Packet2d,MatrixAlignment>(data + stride*12);
D2 = matrix.template packet<MatrixAlignment>(14); D2 = ploadt<Packet2d,MatrixAlignment>(data + stride*14);
temp = B1; temp = B1;
B1 = vec2d_unpacklo(B1, B2); B1 = vec2d_unpacklo(B1, B2);