diff --git a/Eigen/src/Core/arch/AltiVec/Complex.h b/Eigen/src/Core/arch/AltiVec/Complex.h index 4fd923e84..ba5a3fddd 100644 --- a/Eigen/src/Core/arch/AltiVec/Complex.h +++ b/Eigen/src/Core/arch/AltiVec/Complex.h @@ -114,11 +114,19 @@ template<> struct unpacket_traits { typedef std::complex type; template<> EIGEN_STRONG_INLINE Packet2cf pset1(const std::complex& from) { Packet2cf res; +#ifdef __VSX__ + // Load a single std::complex from memory and duplicate + // + // Using pload would read past the end of the reference in this case + // Using vec_xl_len + vec_splat, generates poor assembly + __asm__ ("lxvdsx %x0,%y1" : "=wa" (res.v) : "Z" (from)); +#else if((std::ptrdiff_t(&from) % 16) == 0) res.v = pload((const float *)&from); else res.v = ploadu((const float *)&from); res.v = vec_perm(res.v, res.v, p16uc_PSET64_HI); +#endif return res; } @@ -133,6 +141,7 @@ EIGEN_STRONG_INLINE Packet2cf pload2(const std::complex& from0, const std { Packet4f res0, res1; #ifdef __VSX__ + // Load two std::complex from memory and combine __asm__ ("lxsdx %x0,%y1" : "=wa" (res0) : "Z" (from0)); __asm__ ("lxsdx %x0,%y1" : "=wa" (res1) : "Z" (from1)); #ifdef _BIG_ENDIAN @@ -186,7 +195,7 @@ template<> EIGEN_STRONG_INLINE std::complex pfirst(const Pack template<> EIGEN_STRONG_INLINE Packet2cf preverse(const Packet2cf& a) { Packet4f rev_a; - rev_a = vec_perm(a.v, a.v, p16uc_COMPLEX32_REV2); + rev_a = vec_sld(a.v, a.v, 8); return Packet2cf(rev_a); } @@ -222,8 +231,8 @@ template<> EIGEN_STRONG_INLINE Packet2cf pcplxflip(const Packet2cf& x EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { - Packet4f tmp = vec_perm(kernel.packet[0].v, kernel.packet[1].v, p16uc_TRANSPOSE64_HI); - kernel.packet[1].v = vec_perm(kernel.packet[0].v, kernel.packet[1].v, p16uc_TRANSPOSE64_LO); + Packet4f tmp = reinterpret_cast(vec_mergeh(reinterpret_cast(kernel.packet[0].v), reinterpret_cast(kernel.packet[1].v))); + kernel.packet[1].v = reinterpret_cast(vec_mergel(reinterpret_cast(kernel.packet[0].v), reinterpret_cast(kernel.packet[1].v))); kernel.packet[0].v = tmp; } @@ -358,7 +367,7 @@ template<> EIGEN_STRONG_INLINE void prefetch >(const std::c template<> EIGEN_STRONG_INLINE std::complex pfirst(const Packet1cd& a) { - EIGEN_ALIGN16 std::complex res[2]; + EIGEN_ALIGN16 std::complex res[1]; pstore >(res, a); return res[0]; @@ -384,8 +393,8 @@ EIGEN_STRONG_INLINE Packet1cd pcplxflip/**/(const Packet1cd& x) EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { - Packet2d tmp = vec_perm(kernel.packet[0].v, kernel.packet[1].v, p16uc_TRANSPOSE64_HI); - kernel.packet[1].v = vec_perm(kernel.packet[0].v, kernel.packet[1].v, p16uc_TRANSPOSE64_LO); + Packet2d tmp = vec_mergeh(kernel.packet[0].v, kernel.packet[1].v); + kernel.packet[1].v = vec_mergel(kernel.packet[0].v, kernel.packet[1].v); kernel.packet[0].v = tmp; } diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h index 312970771..e24b5d5b3 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h @@ -1020,26 +1020,12 @@ EIGEN_ALWAYS_INLINE void pger_common(PacketBlock* acc, const Packet& l { if(NegativeAccumulate) { - acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]); - if (N > 1) { - acc->packet[1] = vec_nmsub(lhsV, rhsV[1], acc->packet[1]); - } - if (N > 2) { - acc->packet[2] = vec_nmsub(lhsV, rhsV[2], acc->packet[2]); - } - if (N > 3) { - acc->packet[3] = vec_nmsub(lhsV, rhsV[3], acc->packet[3]); + for (int M = 0; M < N; M++) { + acc->packet[M] = vec_nmsub(lhsV, rhsV[M], 
acc->packet[M]); } } else { - acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]); - if (N > 1) { - acc->packet[1] = vec_madd(lhsV, rhsV[1], acc->packet[1]); - } - if (N > 2) { - acc->packet[2] = vec_madd(lhsV, rhsV[2], acc->packet[2]); - } - if (N > 3) { - acc->packet[3] = vec_madd(lhsV, rhsV[3], acc->packet[3]); + for (int M = 0; M < N; M++) { + acc->packet[M] = vec_madd(lhsV, rhsV[M], acc->packet[M]); } } } @@ -1052,31 +1038,9 @@ EIGEN_ALWAYS_INLINE void pger(PacketBlock* acc, const Scalar* lhs, con pger_common(acc, lhsV, rhsV); } -template -EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs, Packet &lhsV) -{ -#ifdef _ARCH_PWR9 - lhsV = vec_xl_len((Scalar *)lhs, remaining_rows * sizeof(Scalar)); -#else - Index i = 0; - do { - lhsV[i] = lhs[i]; - } while (++i < remaining_rows); -#endif -} - -template -EIGEN_ALWAYS_INLINE void pger(PacketBlock* acc, const Scalar* lhs, const Packet* rhsV) -{ - Packet lhsV; - loadPacketRemaining(lhs, lhsV); - - pger_common(acc, lhsV, rhsV); -} - // 512-bits rank1-update of complex acc. It takes decoupled accumulators as entries. It also takes cares of mixed types real * complex and complex * real. template -EIGEN_ALWAYS_INLINE void pgerc_common(PacketBlock* accReal, PacketBlock* accImag, const Packet &lhsV, const Packet &lhsVi, const Packet* rhsV, const Packet* rhsVi) +EIGEN_ALWAYS_INLINE void pgerc_common(PacketBlock* accReal, PacketBlock* accImag, const Packet &lhsV, Packet &lhsVi, const Packet* rhsV, const Packet* rhsVi) { pger_common(accReal, lhsV, rhsV); if(LhsIsReal) @@ -1097,97 +1061,56 @@ EIGEN_ALWAYS_INLINE void pgerc_common(PacketBlock* accReal, PacketBloc template EIGEN_ALWAYS_INLINE void pgerc(PacketBlock* accReal, PacketBlock* accImag, const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi) { - Packet lhsV = ploadLhs(lhs_ptr); + Packet lhsV = ploadLhs(lhs_ptr); Packet lhsVi; - if(!LhsIsReal) lhsVi = ploadLhs(lhs_ptr_imag); + if(!LhsIsReal) lhsVi = ploadLhs(lhs_ptr_imag); else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); pgerc_common(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi); } -template -EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, Packet &lhsV, Packet &lhsVi) -{ -#ifdef _ARCH_PWR9 - lhsV = vec_xl_len((Scalar *)lhs_ptr, remaining_rows * sizeof(Scalar)); - if(!LhsIsReal) lhsVi = vec_xl_len((Scalar *)lhs_ptr_imag, remaining_rows * sizeof(Scalar)); - else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); -#else - Index i = 0; - do { - lhsV[i] = lhs_ptr[i]; - if(!LhsIsReal) lhsVi[i] = lhs_ptr_imag[i]; - } while (++i < remaining_rows); - if(LhsIsReal) EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); -#endif -} - -template -EIGEN_ALWAYS_INLINE void pgerc(PacketBlock* accReal, PacketBlock* accImag, const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi) -{ - Packet lhsV, lhsVi; - loadPacketRemaining(lhs_ptr, lhs_ptr_imag, lhsV, lhsVi); - - pgerc_common(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi); -} - -template -EIGEN_ALWAYS_INLINE Packet ploadLhs(const Scalar* lhs) +template +EIGEN_ALWAYS_INLINE Packet ploadLhs(const __UNPACK_TYPE__(Packet)* lhs) { return ploadu(lhs); } // Zero the accumulator on PacketBlock. -template +template EIGEN_ALWAYS_INLINE void bsetzero(PacketBlock& acc) { - acc.packet[0] = pset1((Scalar)0); - if (N > 1) { - acc.packet[1] = pset1((Scalar)0); - } - if (N > 2) { - acc.packet[2] = pset1((Scalar)0); - } - if (N > 3) { - acc.packet[3] = pset1((Scalar)0); - } -} - -// Scale the PacketBlock vectors by alpha. 
-template -EIGEN_ALWAYS_INLINE void bscale(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha) -{ - acc.packet[0] = pmadd(pAlpha, accZ.packet[0], acc.packet[0]); - if (N > 1) { - acc.packet[1] = pmadd(pAlpha, accZ.packet[1], acc.packet[1]); - } - if (N > 2) { - acc.packet[2] = pmadd(pAlpha, accZ.packet[2], acc.packet[2]); - } - if (N > 3) { - acc.packet[3] = pmadd(pAlpha, accZ.packet[3], acc.packet[3]); + for (int M = 0; M < N; M++) { + acc.packet[M] = pset1((__UNPACK_TYPE__(Packet))0); } } template EIGEN_ALWAYS_INLINE void bscalec_common(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha) { - acc.packet[0] = pmul(accZ.packet[0], pAlpha); - if (N > 1) { - acc.packet[1] = pmul(accZ.packet[1], pAlpha); + for (int M = 0; M < N; M++) { + acc.packet[M] = vec_mul(accZ.packet[M], pAlpha); } - if (N > 2) { - acc.packet[2] = pmul(accZ.packet[2], pAlpha); - } - if (N > 3) { - acc.packet[3] = pmul(accZ.packet[3], pAlpha); +} + +template +EIGEN_ALWAYS_INLINE void band(PacketBlock& acc, const Packet& pMask) +{ + for (int M = 0; M < N; M++) { + acc.packet[M] = pand(acc.packet[M], pMask); } } // Complex version of PacketBlock scaling. -template -EIGEN_ALWAYS_INLINE void bscalec(PacketBlock& aReal, PacketBlock& aImag, const Packet& bReal, const Packet& bImag, PacketBlock& cReal, PacketBlock& cImag) +template +EIGEN_ALWAYS_INLINE void bscalec(PacketBlock& aReal, PacketBlock& aImag, const Packet& bReal, const Packet& bImag, PacketBlock& cReal, PacketBlock& cImag, const Packet& pMask) { + if (mask && (sizeof(__UNPACK_TYPE__(Packet)) == sizeof(float))) { + band(aReal, pMask); + band(aImag, pMask); + } else { + EIGEN_UNUSED_VARIABLE(pMask); + } + bscalec_common(cReal, aReal, bReal); bscalec_common(cImag, aImag, bReal); @@ -1197,213 +1120,253 @@ EIGEN_ALWAYS_INLINE void bscalec(PacketBlock& aReal, PacketBlock(&cImag, bImag, aReal.packet); } -template -EIGEN_ALWAYS_INLINE void band(PacketBlock& acc, const Packet& pMask) -{ - acc.packet[0] = pand(acc.packet[0], pMask); - if (N > 1) { - acc.packet[1] = pand(acc.packet[1], pMask); - } - if (N > 2) { - acc.packet[2] = pand(acc.packet[2], pMask); - } - if (N > 3) { - acc.packet[3] = pand(acc.packet[3], pMask); - } -} - -template -EIGEN_ALWAYS_INLINE void bscalec(PacketBlock& aReal, PacketBlock& aImag, const Packet& bReal, const Packet& bImag, PacketBlock& cReal, PacketBlock& cImag, const Packet& pMask) -{ - band(aReal, pMask); - band(aImag, pMask); - - bscalec(aReal, aImag, bReal, bImag, cReal, cImag); -} - // Load a PacketBlock, the N parameters make tunning gemm easier so we can add more accumulators as needed. 
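// For instance (a sketch, assuming ColMajor float with N == accRows == 4 and Complex == false),
// the loop form of bload below is equivalent to four packet loads from consecutive result columns:
//   acc.packet[0] = res.template loadPacket<Packet4f>(row, col + 0);
//   ...
//   acc.packet[3] = res.template loadPacket<Packet4f>(row, col + 3);
// When Complex is set, a second group of N packets taken accCols rows further down fills
// packet[N]..packet[2*N-1] (in the ColMajor case only when `full` is set).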
-template +// +// full = operate (load) on the entire PacketBlock or only half +template EIGEN_ALWAYS_INLINE void bload(PacketBlock& acc, const DataMapper& res, Index row, Index col) { if (StorageOrder == RowMajor) { - acc.packet[0] = res.template loadPacket(row + 0, col); - if (N > 1) { - acc.packet[1] = res.template loadPacket(row + 1, col); - } - if (N > 2) { - acc.packet[2] = res.template loadPacket(row + 2, col); - } - if (N > 3) { - acc.packet[3] = res.template loadPacket(row + 3, col); + for (int M = 0; M < N; M++) { + acc.packet[M] = res.template loadPacket(row + M, col); } if (Complex) { - acc.packet[0+N] = res.template loadPacket(row + 0, col + accCols); - if (N > 1) { - acc.packet[1+N] = res.template loadPacket(row + 1, col + accCols); - } - if (N > 2) { - acc.packet[2+N] = res.template loadPacket(row + 2, col + accCols); - } - if (N > 3) { - acc.packet[3+N] = res.template loadPacket(row + 3, col + accCols); + for (int M = 0; M < N; M++) { + acc.packet[M+N] = res.template loadPacket(row + M, col + accCols); } } } else { - acc.packet[0] = res.template loadPacket(row, col + 0); - if (N > 1) { - acc.packet[1] = res.template loadPacket(row, col + 1); + for (int M = 0; M < N; M++) { + acc.packet[M] = res.template loadPacket(row, col + M); } - if (N > 2) { - acc.packet[2] = res.template loadPacket(row, col + 2); - } - if (N > 3) { - acc.packet[3] = res.template loadPacket(row, col + 3); - } - if (Complex) { - acc.packet[0+N] = res.template loadPacket(row + accCols, col + 0); - if (N > 1) { - acc.packet[1+N] = res.template loadPacket(row + accCols, col + 1); - } - if (N > 2) { - acc.packet[2+N] = res.template loadPacket(row + accCols, col + 2); - } - if (N > 3) { - acc.packet[3+N] = res.template loadPacket(row + accCols, col + 3); + if (Complex && full) { + for (int M = 0; M < N; M++) { + acc.packet[M+N] = res.template loadPacket(row + accCols, col + M); } } } } -const static Packet4i mask41 = { -1, 0, 0, 0 }; -const static Packet4i mask42 = { -1, -1, 0, 0 }; -const static Packet4i mask43 = { -1, -1, -1, 0 }; +template +EIGEN_ALWAYS_INLINE void bstore(PacketBlock& acc, const DataMapper& res, Index row) +{ + for (int M = 0; M < N; M++) { + res.template storePacket(row, M, acc.packet[M]); + } +} -const static Packet2l mask21 = { -1, 0 }; +#ifdef _ARCH_PWR10 +#define USE_P10_AND_PVIPR2_0 (EIGEN_COMP_LLVM || (__GNUC__ >= 11)) +#else +#define USE_P10_AND_PVIPR2_0 0 +#endif + +#if !USE_P10_AND_PVIPR2_0 +const static Packet4i mask4[4] = { { 0, 0, 0, 0 }, { -1, 0, 0, 0 }, { -1, -1, 0, 0 }, { -1, -1, -1, 0 } }; +#endif template EIGEN_ALWAYS_INLINE Packet bmask(const Index remaining_rows) { - if (remaining_rows == 0) { - return pset1(float(0.0)); // Not used - } else { - switch (remaining_rows) { - case 1: return Packet(mask41); - case 2: return Packet(mask42); - default: return Packet(mask43); - } - } +#if USE_P10_AND_PVIPR2_0 +#ifdef _BIG_ENDIAN + return Packet(vec_reve(vec_genwm((1 << remaining_rows) - 1))); +#else + return Packet(vec_genwm((1 << remaining_rows) - 1)); +#endif +#else + return Packet(mask4[remaining_rows]); +#endif } template<> EIGEN_ALWAYS_INLINE Packet2d bmask(const Index remaining_rows) { - if (remaining_rows == 0) { - return pset1(double(0.0)); // Not used - } else { - return Packet2d(mask21); - } +#if USE_P10_AND_PVIPR2_0 + Packet2d mask2 = Packet2d(vec_gendm(remaining_rows)); +#ifdef _BIG_ENDIAN + return preverse(mask2); +#else + return mask2; +#endif +#else + Packet2l ret = { -remaining_rows, 0 }; + return Packet2d(ret); +#endif } -template +// Scale the PacketBlock 
vectors by alpha. +template EIGEN_ALWAYS_INLINE void bscale(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha, const Packet& pMask) { - band(accZ, pMask); + if (mask) { + band(accZ, pMask); + } else { + EIGEN_UNUSED_VARIABLE(pMask); + } - bscale(acc, accZ, pAlpha); + for (int M = 0; M < N; M++) { + acc.packet[M] = pmadd(pAlpha, accZ.packet[M], acc.packet[M]); + } } -template EIGEN_ALWAYS_INLINE void -pbroadcastN_old(const __UNPACK_TYPE__(Packet) *a, - Packet& a0, Packet& a1, Packet& a2, Packet& a3) +template +EIGEN_ALWAYS_INLINE void pbroadcastN(const __UNPACK_TYPE__(Packet) *ap0, + const __UNPACK_TYPE__(Packet) *ap1, const __UNPACK_TYPE__(Packet) *ap2, + Packet& a0, Packet& a1, Packet& a2, Packet& a3) { - a0 = pset1(a[0]); - if (N > 1) { - a1 = pset1(a[1]); + a0 = pset1(ap0[0]); + if (N == 4) { + a1 = pset1(ap0[1]); + a2 = pset1(ap0[2]); + a3 = pset1(ap0[3]); + EIGEN_UNUSED_VARIABLE(ap1); + EIGEN_UNUSED_VARIABLE(ap2); } else { - EIGEN_UNUSED_VARIABLE(a1); - } - if (N > 2) { - a2 = pset1(a[2]); - } else { - EIGEN_UNUSED_VARIABLE(a2); - } - if (N > 3) { - a3 = pset1(a[3]); - } else { - EIGEN_UNUSED_VARIABLE(a3); + if (N > 1) { + a1 = pset1(ap1[0]); + } else { + EIGEN_UNUSED_VARIABLE(a1); + EIGEN_UNUSED_VARIABLE(ap1); + } + if (N > 2) { + a2 = pset1(ap2[0]); + } else { + EIGEN_UNUSED_VARIABLE(a2); + EIGEN_UNUSED_VARIABLE(ap2); + } } } +template<> EIGEN_ALWAYS_INLINE void +pbroadcastN(const float *ap0, const float *, const float *, + Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3) +{ + pbroadcast4(ap0, a0, a1, a2, a3); +} + +template<> EIGEN_ALWAYS_INLINE void +pbroadcastN(const float *ap0, const float *ap1, const float *ap2, + Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3) +{ + pbroadcastN(ap0, ap1, ap2, a0, a1, a2, a3); +} + template<> -EIGEN_ALWAYS_INLINE void pbroadcastN_old(const float* a, Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3) +EIGEN_ALWAYS_INLINE void pbroadcastN(const double* ap0, const double *, + const double *, Packet2d& a0, Packet2d& a1, Packet2d& a2, Packet2d& a3) { - pbroadcast4(a, a0, a1, a2, a3); -} - -template<> -EIGEN_ALWAYS_INLINE void pbroadcastN_old(const double* a, Packet2d& a0, Packet2d& a1, Packet2d& a2, Packet2d& a3) -{ - a1 = pload(a); - a3 = pload(a + 2); + a1 = pload(ap0); + a3 = pload(ap0 + 2); a0 = vec_splat(a1, 0); a1 = vec_splat(a1, 1); a2 = vec_splat(a3, 0); a3 = vec_splat(a3, 1); } -template EIGEN_ALWAYS_INLINE void -pbroadcastN(const __UNPACK_TYPE__(Packet) *a, - Packet& a0, Packet& a1, Packet& a2, Packet& a3) +// Grab two decouples real/imaginary PacketBlocks and return two coupled (real/imaginary pairs) PacketBlocks. 
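// A sketch of the merge semantics for Packet4f: if taccReal.packet[M] = { r0, r1, r2, r3 }
// and taccImag.packet[M] = { i0, i1, i2, i3 }, then vec_mergeh produces { r0, i0, r1, i1 }
// and vec_mergel produces { r2, i2, r3, i3 }, i.e. the interleaved real/imaginary layout of
// std::complex<float>; the vec_mergel half (acc2) is only computed when `full` is set.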
+template +EIGEN_ALWAYS_INLINE void bcouple_common(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& acc1, PacketBlock& acc2) { - a0 = pset1(a[0]); - if (N > 1) { - a1 = pset1(a[1]); - } else { - EIGEN_UNUSED_VARIABLE(a1); + for (int M = 0; M < N; M++) { + acc1.packet[M].v = vec_mergeh(taccReal.packet[M], taccImag.packet[M]); } - if (N > 2) { - a2 = pset1(a[2]); - } else { - EIGEN_UNUSED_VARIABLE(a2); - } - if (N > 3) { - a3 = pset1(a[3]); - } else { - EIGEN_UNUSED_VARIABLE(a3); + + if (full) { + for (int M = 0; M < N; M++) { + acc2.packet[M].v = vec_mergel(taccReal.packet[M], taccImag.packet[M]); + } } } -template<> EIGEN_ALWAYS_INLINE void -pbroadcastN(const float *a, - Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3) +template +EIGEN_ALWAYS_INLINE void bcouple(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& tRes, PacketBlock& acc1, PacketBlock& acc2) { - a3 = pload(a); - a0 = vec_splat(a3, 0); - a1 = vec_splat(a3, 1); - a2 = vec_splat(a3, 2); - a3 = vec_splat(a3, 3); + bcouple_common(taccReal, taccImag, acc1, acc2); + + for (int M = 0; M < N; M++) { + acc1.packet[M] = padd(tRes.packet[M], acc1.packet[M]); + } + + if (full) { + for (int M = 0; M < N; M++) { + acc2.packet[M] = padd(tRes.packet[M+N], acc2.packet[M]); + } + } } // PEEL loop factor. #define PEEL 7 #define PEEL_ROW 7 -#define MICRO_UNROLL_PEEL(func) \ +#define MICRO_UNROLL(func) \ func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7) +#define MICRO_NORMAL_ROWS \ + accRows == quad_traits::rows || accRows == 1 + +#define MICRO_NEW_ROWS ((MICRO_NORMAL_ROWS) ? accRows : 1) + +#define MICRO_RHS(ptr, N) rhs_##ptr##N + #define MICRO_ZERO_PEEL(peel) \ if ((PEEL_ROW > peel) && (peel != 0)) { \ - bsetzero(accZero##peel); \ + bsetzero(accZero##peel); \ } else { \ EIGEN_UNUSED_VARIABLE(accZero##peel); \ } -#define MICRO_ZERO_PEEL_ROW \ - MICRO_UNROLL_PEEL(MICRO_ZERO_PEEL); +#define MICRO_ADD(ptr, N) \ + if (MICRO_NORMAL_ROWS) { \ + MICRO_RHS(ptr,0) += (accRows * N); \ + } else { \ + MICRO_RHS(ptr,0) += N; \ + MICRO_RHS(ptr,1) += N; \ + if (accRows == 3) { \ + MICRO_RHS(ptr,2) += N; \ + } \ + } + +#define MICRO_ADD_ROWS(N) MICRO_ADD(ptr, N) + +#define MICRO_BROADCAST1(peel, ptr, rhsV, real) \ + if (MICRO_NORMAL_ROWS) { \ + pbroadcastN(MICRO_RHS(ptr,0) + (accRows * peel), MICRO_RHS(ptr,0), MICRO_RHS(ptr,0), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \ + } else { \ + pbroadcastN(MICRO_RHS(ptr,0) + peel, MICRO_RHS(ptr,1) + peel, MICRO_RHS(ptr,2) + peel, rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \ + } + +#define MICRO_BROADCAST(peel) MICRO_BROADCAST1(peel, ptr, rhsV, true) + +#define MICRO_BROADCAST_EXTRA1(ptr, rhsV, real) \ + pbroadcastN(MICRO_RHS(ptr,0), MICRO_RHS(ptr,1), MICRO_RHS(ptr,2), rhsV[0], rhsV[1], rhsV[2], rhsV[3]); + +#define MICRO_BROADCAST_EXTRA \ + Packet rhsV[4]; \ + MICRO_BROADCAST_EXTRA1(ptr, rhsV, true) \ + MICRO_ADD_ROWS(1) + +#define MICRO_SRC2(ptr, N, M) \ + if (MICRO_NORMAL_ROWS) { \ + EIGEN_UNUSED_VARIABLE(strideB); \ + EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr,1)); \ + EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr,2)); \ + } else { \ + MICRO_RHS(ptr,1) = rhs_base + N + M; \ + if (accRows == 3) { \ + MICRO_RHS(ptr,2) = rhs_base + N*2 + M; \ + } else { \ + EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr,2)); \ + } \ + } + +#define MICRO_SRC2_PTR MICRO_SRC2(ptr, strideB, 0) + +#define MICRO_ZERO_PEEL_ROW MICRO_UNROLL(MICRO_ZERO_PEEL) #define MICRO_WORK_PEEL(peel) \ if (PEEL_ROW > peel) { \ - pbroadcastN(rhs_ptr + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], 
rhsV##peel[3]); \ + MICRO_BROADCAST(peel) \ pger(&accZero##peel, lhs_ptr + (remaining_rows * peel), rhsV##peel); \ } else { \ EIGEN_UNUSED_VARIABLE(rhsV##peel); \ @@ -1411,9 +1374,9 @@ pbroadcastN(const float *a, #define MICRO_WORK_PEEL_ROW \ Packet rhsV0[4], rhsV1[4], rhsV2[4], rhsV3[4], rhsV4[4], rhsV5[4], rhsV6[4], rhsV7[4]; \ - MICRO_UNROLL_PEEL(MICRO_WORK_PEEL); \ + MICRO_UNROLL(MICRO_WORK_PEEL) \ lhs_ptr += (remaining_rows * PEEL_ROW); \ - rhs_ptr += (accRows * PEEL_ROW); + MICRO_ADD_ROWS(PEEL_ROW) #define MICRO_ADD_PEEL(peel, sum) \ if (PEEL_ROW > peel) { \ @@ -1426,17 +1389,34 @@ pbroadcastN(const float *a, MICRO_ADD_PEEL(4, 0) MICRO_ADD_PEEL(5, 1) MICRO_ADD_PEEL(6, 2) MICRO_ADD_PEEL(7, 3) \ MICRO_ADD_PEEL(2, 0) MICRO_ADD_PEEL(3, 1) MICRO_ADD_PEEL(1, 0) +#define MICRO_PREFETCHN1(ptr, N) \ + EIGEN_POWER_PREFETCH(MICRO_RHS(ptr,0)); \ + if (N == 2 || N == 3) { \ + EIGEN_POWER_PREFETCH(MICRO_RHS(ptr,1)); \ + if (N == 3) { \ + EIGEN_POWER_PREFETCH(MICRO_RHS(ptr,2)); \ + } \ + } + +#define MICRO_PREFETCHN(N) MICRO_PREFETCHN1(ptr, N) + +#define MICRO_COMPLEX_PREFETCHN(N) \ + MICRO_PREFETCHN1(ptr_real, N); \ + if(!RhsIsReal) { \ + MICRO_PREFETCHN1(ptr_imag, N); \ + } + template EIGEN_ALWAYS_INLINE void MICRO_EXTRA_ROW( const Scalar* &lhs_ptr, - const Scalar* &rhs_ptr, + const Scalar* &rhs_ptr0, + const Scalar* &rhs_ptr1, + const Scalar* &rhs_ptr2, PacketBlock &accZero) { - Packet rhsV[4]; - pbroadcastN(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]); + MICRO_BROADCAST_EXTRA pger(&accZero, lhs_ptr, rhsV); lhs_ptr += remaining_rows; - rhs_ptr += accRows; } template @@ -1447,60 +1427,71 @@ EIGEN_ALWAYS_INLINE void gemm_unrolled_row_iteration( Index depth, Index strideA, Index offsetA, + Index strideB, Index row, - Index col, Index rows, - Index cols, const Packet& pAlpha, const Packet& pMask) { - const Scalar* rhs_ptr = rhs_base; + const Scalar* rhs_ptr0 = rhs_base, * rhs_ptr1 = NULL, * rhs_ptr2 = NULL; const Scalar* lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA; PacketBlock accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7, acc; - bsetzero(accZero0); + MICRO_SRC2_PTR + bsetzero(accZero0); - Index remaining_depth = (col + quad_traits::rows < cols) ? 
depth : (depth & -quad_traits::rows); + Index remaining_depth = depth & -quad_traits::rows; Index k = 0; if (remaining_depth >= PEEL_ROW) { MICRO_ZERO_PEEL_ROW do { - EIGEN_POWER_PREFETCH(rhs_ptr); + MICRO_PREFETCHN(accRows) EIGEN_POWER_PREFETCH(lhs_ptr); MICRO_WORK_PEEL_ROW } while ((k += PEEL_ROW) + PEEL_ROW <= remaining_depth); MICRO_ADD_PEEL_ROW } - for(; k < remaining_depth; k++) + for(; k < depth; k++) { - MICRO_EXTRA_ROW(lhs_ptr, rhs_ptr, accZero0); + MICRO_EXTRA_ROW(lhs_ptr, rhs_ptr0, rhs_ptr1, rhs_ptr2, accZero0); } - if ((remaining_depth == depth) && (rows >= accCols)) + bload(acc, res, row, 0); + if ((accRows == 1) || (rows >= accCols)) { - bload(acc, res, row, 0); - bscale(acc, accZero0, pAlpha, pMask); - res.template storePacketBlock(row, 0, acc); + bscale(acc, accZero0, pAlpha, pMask); + bstore(acc, res, row); } else { - for(; k < depth; k++) - { - Packet rhsV[4]; - pbroadcastN(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]); - pger(&accZero0, lhs_ptr, rhsV); - lhs_ptr += remaining_rows; - rhs_ptr += accRows; - } - + bscale(acc, accZero0, pAlpha, pMask); for(Index j = 0; j < accRows; j++) { - accZero0.packet[j] = vec_mul(pAlpha, accZero0.packet[j]); for(Index i = 0; i < remaining_rows; i++) { - res(row + i, j) += accZero0.packet[j][i]; + res(row + i, j) = acc.packet[j][i]; } } } } +#define MICRO_EXTRA(MICRO_EXTRA_UNROLL, value, is_col) \ + switch(value) { \ + default: \ + MICRO_EXTRA_UNROLL(1) \ + break; \ + case 2: \ + if (is_col || (sizeof(Scalar) == sizeof(float))) { \ + MICRO_EXTRA_UNROLL(2) \ + } \ + break; \ + case 3: \ + if (is_col || (sizeof(Scalar) == sizeof(float))) { \ + MICRO_EXTRA_UNROLL(3) \ + } \ + break; \ + } + +#define MICRO_EXTRA_ROWS(N) \ + gemm_unrolled_row_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, pAlpha, pMask); + template EIGEN_ALWAYS_INLINE void gemm_extra_row( const DataMapper& res, @@ -1509,46 +1500,20 @@ EIGEN_ALWAYS_INLINE void gemm_extra_row( Index depth, Index strideA, Index offsetA, + Index strideB, Index row, - Index col, Index rows, - Index cols, Index remaining_rows, const Packet& pAlpha, const Packet& pMask) { - switch(remaining_rows) { - case 1: - gemm_unrolled_row_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, pAlpha, pMask); - break; - case 2: - if (sizeof(Scalar) == sizeof(float)) { - gemm_unrolled_row_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, pAlpha, pMask); - } - break; - default: - if (sizeof(Scalar) == sizeof(float)) { - gemm_unrolled_row_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, pAlpha, pMask); - } - break; - } + MICRO_EXTRA(MICRO_EXTRA_ROWS, remaining_rows, false) } -#define MICRO_UNROLL(func) \ - func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7) - #define MICRO_UNROLL_WORK(func, func2, peel) \ - MICRO_UNROLL(func2); \ - func(0,peel) func(1,peel) func(2,peel) func(3,peel) \ - func(4,peel) func(5,peel) func(6,peel) func(7,peel) - -#define MICRO_LOAD_ONE(iter) \ - if (unroll_factor > iter) { \ - lhsV##iter = ploadLhs(lhs_ptr##iter); \ - lhs_ptr##iter += accCols; \ - } else { \ - EIGEN_UNUSED_VARIABLE(lhsV##iter); \ - } + MICRO_UNROLL(func2); \ + func(0,peel) func(1,peel) func(2,peel) func(3,peel) \ + func(4,peel) func(5,peel) func(6,peel) func(7,peel) #define MICRO_WORK_ONE(iter, peel) \ if (unroll_factor > iter) { \ @@ -1558,7 +1523,7 @@ EIGEN_ALWAYS_INLINE void gemm_extra_row( #define MICRO_TYPE_PEEL4(func, func2, peel) \ if (PEEL > peel) { \ Packet lhsV0, lhsV1, lhsV2, 
lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \ - pbroadcastN(rhs_ptr + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \ + MICRO_BROADCAST(peel) \ MICRO_UNROLL_WORK(func, func2, peel) \ } else { \ EIGEN_UNUSED_VARIABLE(rhsV##peel); \ @@ -1566,79 +1531,71 @@ EIGEN_ALWAYS_INLINE void gemm_extra_row( #define MICRO_UNROLL_TYPE_PEEL(M, func, func1, func2) \ Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M]; \ - func(func1,func2,0); func(func1,func2,1); \ - func(func1,func2,2); func(func1,func2,3); \ - func(func1,func2,4); func(func1,func2,5); \ - func(func1,func2,6); func(func1,func2,7); + func(func1,func2,0) func(func1,func2,1) \ + func(func1,func2,2) func(func1,func2,3) \ + func(func1,func2,4) func(func1,func2,5) \ + func(func1,func2,6) func(func1,func2,7) #define MICRO_UNROLL_TYPE_ONE(M, func, func1, func2) \ Packet rhsV0[M]; \ - func(func1,func2,0); + func(func1,func2,0) -#define MICRO_ONE_PEEL4 \ - MICRO_UNROLL_TYPE_PEEL(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE); \ - rhs_ptr += (accRows * PEEL); +#define MICRO_UNROLL_TYPE(MICRO_TYPE, size) \ + MICRO_TYPE(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE) \ + MICRO_ADD_ROWS(size) -#define MICRO_ONE4 \ - MICRO_UNROLL_TYPE_ONE(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE); \ - rhs_ptr += accRows; +#define MICRO_ONE_PEEL4 MICRO_UNROLL_TYPE(MICRO_UNROLL_TYPE_PEEL, PEEL) + +#define MICRO_ONE4 MICRO_UNROLL_TYPE(MICRO_UNROLL_TYPE_ONE, 1) #define MICRO_DST_PTR_ONE(iter) \ if (unroll_factor > iter) { \ - bsetzero(accZero##iter); \ + bsetzero(accZero##iter); \ } else { \ EIGEN_UNUSED_VARIABLE(accZero##iter); \ } #define MICRO_DST_PTR MICRO_UNROLL(MICRO_DST_PTR_ONE) -#define MICRO_SRC_PTR_ONE(iter) \ - if (unroll_factor > iter) { \ - lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols; \ - } else { \ - EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \ - } - #define MICRO_SRC_PTR MICRO_UNROLL(MICRO_SRC_PTR_ONE) -#define MICRO_PREFETCH_ONE(iter) \ - if (unroll_factor > iter) { \ - EIGEN_POWER_PREFETCH(lhs_ptr##iter); \ - } - #define MICRO_PREFETCH MICRO_UNROLL(MICRO_PREFETCH_ONE) #define MICRO_STORE_ONE(iter) \ if (unroll_factor > iter) { \ bload(acc, res, row + iter*accCols, 0); \ - bscale(acc, accZero##iter, pAlpha); \ - res.template storePacketBlock(row + iter*accCols, 0, acc); \ + bscale(acc, accZero##iter, pAlpha, pMask); \ + bstore(acc, res, row + iter*accCols); \ } #define MICRO_STORE MICRO_UNROLL(MICRO_STORE_ONE) -template -EIGEN_STRONG_INLINE void gemm_unrolled_iteration( +template +EIGEN_ALWAYS_INLINE void gemm_unrolled_iteration( const DataMapper& res, const Scalar* lhs_base, const Scalar* rhs_base, Index depth, Index strideA, + Index offsetA, + Index strideB, Index& row, - const Packet& pAlpha) + const Packet& pAlpha, + const Packet& pMask) { - const Scalar* rhs_ptr = rhs_base; - const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL; + const Scalar* rhs_ptr0 = rhs_base, * rhs_ptr1 = NULL, * rhs_ptr2 = NULL; + const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL; PacketBlock accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7; PacketBlock acc; + MICRO_SRC2_PTR MICRO_SRC_PTR MICRO_DST_PTR Index k = 0; for(; k + PEEL <= depth; k+= PEEL) { - EIGEN_POWER_PREFETCH(rhs_ptr); + MICRO_PREFETCHN(accRows) MICRO_PREFETCH 
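    // Each pass of this loop consumes PEEL (7) depth steps: MICRO_ONE_PEEL4 broadcasts the
    // rhs values for peel indices 0..6, feeds them to every active lhs tile via
    // MICRO_UNROLL_WORK, and finally MICRO_ADD_ROWS(PEEL) advances the rhs pointers; the
    // eighth unroll slot is compiled out by the `PEEL > peel` guard.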
MICRO_ONE_PEEL4 } @@ -1648,9 +1605,13 @@ EIGEN_STRONG_INLINE void gemm_unrolled_iteration( } MICRO_STORE - row += unroll_factor*accCols; + MICRO_UPDATE } +#define MICRO_UNROLL_ITER2(N, M) \ + gemm_unrolled_iteration(res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlpha, pMask); \ + if (M) return; + template EIGEN_ALWAYS_INLINE void gemm_cols( const DataMapper& res, @@ -1663,55 +1624,54 @@ EIGEN_ALWAYS_INLINE void gemm_cols( Index offsetB, Index col, Index rows, - Index cols, Index remaining_rows, const Packet& pAlpha, const Packet& pMask) { const DataMapper res3 = res.getSubMapper(0, col); - const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB; + const Scalar* rhs_base = blockB + col*strideB + MICRO_NEW_ROWS*offsetB; const Scalar* lhs_base = blockA + accCols*offsetA; Index row = 0; -#define MAX_UNROLL 6 +#define MAX_UNROLL 7 while(row + MAX_UNROLL*accCols <= rows) { - gemm_unrolled_iteration(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); + MICRO_UNROLL_ITER2(MAX_UNROLL, 0); } switch( (rows-row)/accCols ) { #if MAX_UNROLL > 7 case 7: - gemm_unrolled_iteration<7, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); + MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 7) break; #endif #if MAX_UNROLL > 6 case 6: - gemm_unrolled_iteration<6, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); + MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 6) break; #endif #if MAX_UNROLL > 5 case 5: - gemm_unrolled_iteration<5, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); + MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 5) break; #endif #if MAX_UNROLL > 4 case 4: - gemm_unrolled_iteration<4, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); + MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 4) break; #endif #if MAX_UNROLL > 3 case 3: - gemm_unrolled_iteration<3, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); + MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 3) break; #endif #if MAX_UNROLL > 2 case 2: - gemm_unrolled_iteration<2, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); + MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 2) break; #endif #if MAX_UNROLL > 1 case 1: - gemm_unrolled_iteration<1, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); + MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 1) break; #endif default: @@ -1721,10 +1681,13 @@ EIGEN_ALWAYS_INLINE void gemm_cols( if(remaining_rows > 0) { - gemm_extra_row(res3, blockA, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask); + gemm_extra_row(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlpha, pMask); } } +#define MICRO_EXTRA_COLS(N) \ + gemm_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlpha, pMask); + template EIGEN_STRONG_INLINE void gemm_extra_cols( const DataMapper& res, @@ -1742,9 +1705,7 @@ EIGEN_STRONG_INLINE void gemm_extra_cols( const Packet& pAlpha, const Packet& pMask) { - for (; col < cols; col++) { - gemm_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask); - } + MICRO_EXTRA(MICRO_EXTRA_COLS, cols-col, true) } /**************** @@ -1764,10 +1725,13 @@ 
EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const Index col = 0; for(; col + accRows <= cols; col += accRows) { - gemm_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask); + gemm_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlpha, pMask); } - gemm_extra_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask); + if (col != cols) + { + gemm_extra_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask); + } } #define accColsC (accCols / 2) @@ -1778,41 +1742,79 @@ EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const #define PEEL_COMPLEX 3 #define PEEL_COMPLEX_ROW 3 -#define MICRO_COMPLEX_UNROLL_PEEL(func) \ +#define MICRO_COMPLEX_UNROLL(func) \ func(0) func(1) func(2) func(3) #define MICRO_COMPLEX_ZERO_PEEL(peel) \ if ((PEEL_COMPLEX_ROW > peel) && (peel != 0)) { \ - bsetzero(accReal##peel); \ - bsetzero(accImag##peel); \ + bsetzero(accReal##peel); \ + bsetzero(accImag##peel); \ } else { \ EIGEN_UNUSED_VARIABLE(accReal##peel); \ EIGEN_UNUSED_VARIABLE(accImag##peel); \ } -#define MICRO_COMPLEX_ZERO_PEEL_ROW \ - MICRO_COMPLEX_UNROLL_PEEL(MICRO_COMPLEX_ZERO_PEEL); +#define MICRO_COMPLEX_ADD_ROWS(N, used) \ + MICRO_ADD(ptr_real, N) \ + if (!RhsIsReal) { \ + MICRO_ADD(ptr_imag, N) \ + } else if (used) { \ + EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag,0)); \ + EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag,1)); \ + EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag,2)); \ + } + +#define MICRO_COMPLEX_BROADCAST(peel) \ + MICRO_BROADCAST1(peel, ptr_real, rhsV, false) \ + if (!RhsIsReal) { \ + MICRO_BROADCAST1(peel, ptr_imag, rhsVi, false) \ + } else { \ + EIGEN_UNUSED_VARIABLE(rhsVi##peel); \ + } + +#define MICRO_COMPLEX_BROADCAST_EXTRA \ + Packet rhsV[4], rhsVi[4]; \ + MICRO_BROADCAST_EXTRA1(ptr_real, rhsV, false) \ + if(!RhsIsReal) { \ + MICRO_BROADCAST_EXTRA1(ptr_imag, rhsVi, false) \ + } else { \ + EIGEN_UNUSED_VARIABLE(rhsVi); \ + } \ + MICRO_COMPLEX_ADD_ROWS(1, true) + +#define MICRO_COMPLEX_SRC2_PTR \ + MICRO_SRC2(ptr_real, strideB*advanceCols, 0) \ + if (!RhsIsReal) { \ + MICRO_RHS(ptr_imag,0) = rhs_base + MICRO_NEW_ROWS*strideB; \ + MICRO_SRC2(ptr_imag, strideB*advanceCols, strideB) \ + } else { \ + EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag,0)); \ + EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag,1)); \ + EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag,2)); \ + } + +#define MICRO_COMPLEX_ZERO_PEEL_ROW MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_ZERO_PEEL) #define MICRO_COMPLEX_WORK_PEEL(peel) \ if (PEEL_COMPLEX_ROW > peel) { \ - pbroadcastN_old(rhs_ptr_real + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \ - if(!RhsIsReal) pbroadcastN_old(rhs_ptr_imag + (accRows * peel), rhsVi##peel[0], rhsVi##peel[1], rhsVi##peel[2], rhsVi##peel[3]); \ + MICRO_COMPLEX_BROADCAST(peel) \ pgerc(&accReal##peel, &accImag##peel, lhs_ptr_real + (remaining_rows * peel), lhs_ptr_imag + (remaining_rows * peel), rhsV##peel, rhsVi##peel); \ } else { \ EIGEN_UNUSED_VARIABLE(rhsV##peel); \ EIGEN_UNUSED_VARIABLE(rhsVi##peel); \ } +#define MICRO_COMPLEX_ADD_COLS(size) \ + lhs_ptr_real += (remaining_rows * size); \ + if(!LhsIsReal) lhs_ptr_imag += (remaining_rows * size); \ + else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); + #define MICRO_COMPLEX_WORK_PEEL_ROW \ Packet rhsV0[4], rhsV1[4], rhsV2[4], rhsV3[4]; \ Packet rhsVi0[4], rhsVi1[4], rhsVi2[4], 
rhsVi3[4]; \ - MICRO_COMPLEX_UNROLL_PEEL(MICRO_COMPLEX_WORK_PEEL); \ - lhs_ptr_real += (remaining_rows * PEEL_COMPLEX_ROW); \ - if(!LhsIsReal) lhs_ptr_imag += (remaining_rows * PEEL_COMPLEX_ROW); \ - else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); \ - rhs_ptr_real += (accRows * PEEL_COMPLEX_ROW); \ - if(!RhsIsReal) rhs_ptr_imag += (accRows * PEEL_COMPLEX_ROW); \ - else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag); + MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_WORK_PEEL) \ + MICRO_COMPLEX_ADD_COLS(PEEL_COMPLEX_ROW) \ + MICRO_COMPLEX_ADD_ROWS(PEEL_COMPLEX_ROW, false) #define MICRO_COMPLEX_ADD_PEEL(peel, sum) \ if (PEEL_COMPLEX_ROW > peel) { \ @@ -1829,19 +1831,13 @@ EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const template EIGEN_ALWAYS_INLINE void MICRO_COMPLEX_EXTRA_ROW( const Scalar* &lhs_ptr_real, const Scalar* &lhs_ptr_imag, - const Scalar* &rhs_ptr_real, const Scalar* &rhs_ptr_imag, + const Scalar* &rhs_ptr_real0, const Scalar* &rhs_ptr_real1, const Scalar* &rhs_ptr_real2, + const Scalar* &rhs_ptr_imag0, const Scalar* &rhs_ptr_imag1, const Scalar* &rhs_ptr_imag2, PacketBlock &accReal, PacketBlock &accImag) { - Packet rhsV[4], rhsVi[4]; - pbroadcastN_old(rhs_ptr_real, rhsV[0], rhsV[1], rhsV[2], rhsV[3]); - if(!RhsIsReal) pbroadcastN_old(rhs_ptr_imag, rhsVi[0], rhsVi[1], rhsVi[2], rhsVi[3]); + MICRO_COMPLEX_BROADCAST_EXTRA pgerc(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi); - lhs_ptr_real += remaining_rows; - if(!LhsIsReal) lhs_ptr_imag += remaining_rows; - else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); - rhs_ptr_real += accRows; - if(!RhsIsReal) rhs_ptr_imag += accRows; - else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag); + MICRO_COMPLEX_ADD_COLS(1) } template @@ -1854,17 +1850,13 @@ EIGEN_ALWAYS_INLINE void gemm_unrolled_complex_row_iteration( Index offsetA, Index strideB, Index row, - Index col, Index rows, - Index cols, const Packet& pAlphaReal, const Packet& pAlphaImag, const Packet& pMask) { - const Scalar* rhs_ptr_real = rhs_base; - const Scalar* rhs_ptr_imag = NULL; - if(!RhsIsReal) rhs_ptr_imag = rhs_base + accRows*strideB; - else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag); + const Scalar* rhs_ptr_real0 = rhs_base, * rhs_ptr_real1 = NULL, * rhs_ptr_real2 = NULL; + const Scalar* rhs_ptr_imag0 = NULL, * rhs_ptr_imag1 = NULL, * rhs_ptr_imag2 = NULL; const Scalar* lhs_ptr_real = lhs_base + advanceRows*row*strideA + remaining_rows*offsetA; const Scalar* lhs_ptr_imag = NULL; if(!LhsIsReal) lhs_ptr_imag = lhs_ptr_real + remaining_rows*strideA; @@ -1874,19 +1866,18 @@ EIGEN_ALWAYS_INLINE void gemm_unrolled_complex_row_iteration( PacketBlock acc0, acc1; PacketBlock tRes; - bsetzero(accReal0); - bsetzero(accImag0); + MICRO_COMPLEX_SRC2_PTR - Index remaining_depth = (col + quad_traits::rows < cols) ? 
depth : (depth & -quad_traits::rows); + bsetzero(accReal0); + bsetzero(accImag0); + + Index remaining_depth = depth & -quad_traits::rows; Index k = 0; if (remaining_depth >= PEEL_COMPLEX_ROW) { MICRO_COMPLEX_ZERO_PEEL_ROW do { - EIGEN_POWER_PREFETCH(rhs_ptr_real); - if(!RhsIsReal) { - EIGEN_POWER_PREFETCH(rhs_ptr_imag); - } + MICRO_COMPLEX_PREFETCHN(accRows) EIGEN_POWER_PREFETCH(lhs_ptr_real); if(!LhsIsReal) { EIGEN_POWER_PREFETCH(lhs_ptr_imag); @@ -1895,52 +1886,44 @@ EIGEN_ALWAYS_INLINE void gemm_unrolled_complex_row_iteration( } while ((k += PEEL_COMPLEX_ROW) + PEEL_COMPLEX_ROW <= remaining_depth); MICRO_COMPLEX_ADD_PEEL_ROW } - for(; k < remaining_depth; k++) + for(; k < depth; k++) { - MICRO_COMPLEX_EXTRA_ROW(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal0, accImag0); + MICRO_COMPLEX_EXTRA_ROW(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real0, rhs_ptr_real1, rhs_ptr_real2, rhs_ptr_imag0, rhs_ptr_imag1, rhs_ptr_imag2, accReal0, accImag0); } - if ((remaining_depth == depth) && (rows >= accCols)) + const bool full = (remaining_rows > accColsC); + bload(tRes, res, row, 0); + if ((accRows == 1) || (rows >= accCols)) { - bload(tRes, res, row, 0); - bscalec(accReal0, accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask); - bcouple(taccReal, taccImag, tRes, acc0, acc1); - res.template storePacketBlock(row + 0, 0, acc0); - res.template storePacketBlock(row + accColsC, 0, acc1); - } else { - for(; k < depth; k++) - { - Packet rhsV[4], rhsVi[4]; - pbroadcastN_old(rhs_ptr_real, rhsV[0], rhsV[1], rhsV[2], rhsV[3]); - if(!RhsIsReal) pbroadcastN_old(rhs_ptr_imag, rhsVi[0], rhsVi[1], rhsVi[2], rhsVi[3]); - pgerc(&accReal0, &accImag0, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi); - lhs_ptr_real += remaining_rows; - if(!LhsIsReal) lhs_ptr_imag += remaining_rows; - rhs_ptr_real += accRows; - if(!RhsIsReal) rhs_ptr_imag += accRows; + bscalec(accReal0, accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask); + bcouple(taccReal, taccImag, tRes, acc0, acc1); + bstore(acc0, res, row + 0); + if (full) { + bstore(acc1, res, row + accColsC); } - - bscalec(accReal0, accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag); - bcouple_common(taccReal, taccImag, acc0, acc1); + } else { + bscalec(accReal0, accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask); + bcouple(taccReal, taccImag, tRes, acc0, acc1); if ((sizeof(Scalar) == sizeof(float)) && (remaining_rows == 1)) { for(Index j = 0; j < accRows; j++) { - res(row + 0, j) += pfirst(acc0.packet[j]); + res(row + 0, j) = pfirst(acc0.packet[j]); } } else { - for(Index j = 0; j < accRows; j++) { - PacketBlock acc2; - acc2.packet[0] = res.template loadPacket(row + 0, j) + acc0.packet[j]; - res.template storePacketBlock(row + 0, j, acc2); - if(remaining_rows > accColsC) { - res(row + accColsC, j) += pfirst(acc1.packet[j]); + bstore(acc0, res, row + 0); + if (full) { + for(Index j = 0; j < accRows; j++) { + res(row + accColsC, j) = pfirst(acc1.packet[j]); } } } } } +#define MICRO_COMPLEX_EXTRA_ROWS(N) \ + gemm_unrolled_complex_row_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, pAlphaReal, pAlphaImag, pMask); + template EIGEN_ALWAYS_INLINE void gemm_complex_extra_row( const DataMapper& res, @@ -1951,51 +1934,18 @@ EIGEN_ALWAYS_INLINE void gemm_complex_extra_row( Index offsetA, Index strideB, Index row, - Index col, Index rows, - Index cols, Index remaining_rows, const Packet& pAlphaReal, const Packet& pAlphaImag, const Packet& pMask) { - switch(remaining_rows) { - case 1: - gemm_unrolled_complex_row_iteration(res, 
lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, pAlphaReal, pAlphaImag, pMask); - break; - case 2: - if (sizeof(Scalar) == sizeof(float)) { - gemm_unrolled_complex_row_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, pAlphaReal, pAlphaImag, pMask); - } - break; - default: - if (sizeof(Scalar) == sizeof(float)) { - gemm_unrolled_complex_row_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, pAlphaReal, pAlphaImag, pMask); - } - break; - } + MICRO_EXTRA(MICRO_COMPLEX_EXTRA_ROWS, remaining_rows, false) } -#define MICRO_COMPLEX_UNROLL(func) \ - func(0) func(1) func(2) func(3) - #define MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \ - MICRO_COMPLEX_UNROLL(func2); \ - func(0,peel) func(1,peel) func(2,peel) func(3,peel) - -#define MICRO_COMPLEX_LOAD_ONE(iter) \ - if (unroll_factor > iter) { \ - lhsV##iter = ploadLhs(lhs_ptr_real##iter); \ - if(!LhsIsReal) { \ - lhsVi##iter = ploadLhs(lhs_ptr_real##iter + imag_delta); \ - } else { \ - EIGEN_UNUSED_VARIABLE(lhsVi##iter); \ - } \ - lhs_ptr_real##iter += accCols; \ - } else { \ - EIGEN_UNUSED_VARIABLE(lhsV##iter); \ - EIGEN_UNUSED_VARIABLE(lhsVi##iter); \ - } + MICRO_COMPLEX_UNROLL(func2); \ + func(0,peel) func(1,peel) func(2,peel) func(3,peel) #define MICRO_COMPLEX_WORK_ONE4(iter, peel) \ if (unroll_factor > iter) { \ @@ -2006,12 +1956,7 @@ EIGEN_ALWAYS_INLINE void gemm_complex_extra_row( if (PEEL_COMPLEX > peel) { \ Packet lhsV0, lhsV1, lhsV2, lhsV3; \ Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3; \ - pbroadcastN_old(rhs_ptr_real + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \ - if(!RhsIsReal) { \ - pbroadcastN_old(rhs_ptr_imag + (accRows * peel), rhsVi##peel[0], rhsVi##peel[1], rhsVi##peel[2], rhsVi##peel[3]); \ - } else { \ - EIGEN_UNUSED_VARIABLE(rhsVi##peel); \ - } \ + MICRO_COMPLEX_BROADCAST(peel) \ MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \ } else { \ EIGEN_UNUSED_VARIABLE(rhsV##peel); \ @@ -2021,27 +1966,25 @@ EIGEN_ALWAYS_INLINE void gemm_complex_extra_row( #define MICRO_COMPLEX_UNROLL_TYPE_PEEL(M, func, func1, func2) \ Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M]; \ Packet rhsVi0[M], rhsVi1[M], rhsVi2[M], rhsVi3[M]; \ - func(func1,func2,0); func(func1,func2,1); \ - func(func1,func2,2); func(func1,func2,3); + func(func1,func2,0) func(func1,func2,1) \ + func(func1,func2,2) func(func1,func2,3) #define MICRO_COMPLEX_UNROLL_TYPE_ONE(M, func, func1, func2) \ Packet rhsV0[M], rhsVi0[M];\ - func(func1,func2,0); + func(func1,func2,0) -#define MICRO_COMPLEX_ONE_PEEL4 \ - MICRO_COMPLEX_UNROLL_TYPE_PEEL(4, MICRO_COMPLEX_TYPE_PEEL4, MICRO_COMPLEX_WORK_ONE4, MICRO_COMPLEX_LOAD_ONE); \ - rhs_ptr_real += (accRows * PEEL_COMPLEX); \ - if(!RhsIsReal) rhs_ptr_imag += (accRows * PEEL_COMPLEX); +#define MICRO_COMPLEX_UNROLL_TYPE(MICRO_COMPLEX_TYPE, size) \ + MICRO_COMPLEX_TYPE(4, MICRO_COMPLEX_TYPE_PEEL4, MICRO_COMPLEX_WORK_ONE4, MICRO_COMPLEX_LOAD_ONE) \ + MICRO_COMPLEX_ADD_ROWS(size, false) -#define MICRO_COMPLEX_ONE4 \ - MICRO_COMPLEX_UNROLL_TYPE_ONE(4, MICRO_COMPLEX_TYPE_PEEL4, MICRO_COMPLEX_WORK_ONE4, MICRO_COMPLEX_LOAD_ONE); \ - rhs_ptr_real += accRows; \ - if(!RhsIsReal) rhs_ptr_imag += accRows; +#define MICRO_COMPLEX_ONE_PEEL4 MICRO_COMPLEX_UNROLL_TYPE(MICRO_COMPLEX_UNROLL_TYPE_PEEL, PEEL_COMPLEX) + +#define MICRO_COMPLEX_ONE4 MICRO_COMPLEX_UNROLL_TYPE(MICRO_COMPLEX_UNROLL_TYPE_ONE, 1) #define MICRO_COMPLEX_DST_PTR_ONE(iter) \ if (unroll_factor > iter) { \ - bsetzero(accReal##iter); \ - bsetzero(accImag##iter); 
\ + bsetzero(accReal##iter); \ + bsetzero(accImag##iter); \ } else { \ EIGEN_UNUSED_VARIABLE(accReal##iter); \ EIGEN_UNUSED_VARIABLE(accImag##iter); \ @@ -2049,53 +1992,42 @@ EIGEN_ALWAYS_INLINE void gemm_complex_extra_row( #define MICRO_COMPLEX_DST_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_DST_PTR_ONE) -#define MICRO_COMPLEX_SRC_PTR_ONE(iter) \ - if (unroll_factor > iter) { \ - lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols; \ - } else { \ - EIGEN_UNUSED_VARIABLE(lhs_ptr_real##iter); \ - } - #define MICRO_COMPLEX_SRC_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_SRC_PTR_ONE) -#define MICRO_COMPLEX_PREFETCH_ONE(iter) \ - if (unroll_factor > iter) { \ - EIGEN_POWER_PREFETCH(lhs_ptr_real##iter); \ - } - #define MICRO_COMPLEX_PREFETCH MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_PREFETCH_ONE) #define MICRO_COMPLEX_STORE_ONE(iter) \ if (unroll_factor > iter) { \ - bload(tRes, res, row + iter*accCols, 0); \ - bscalec(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag, taccReal, taccImag); \ - bcouple(taccReal, taccImag, tRes, acc0, acc1); \ - res.template storePacketBlock(row + iter*accCols + 0, 0, acc0); \ - res.template storePacketBlock(row + iter*accCols + accColsC, 0, acc1); \ + const bool full = ((MICRO_NORMAL(iter)) || (accCols2 > accColsC)); \ + bload(tRes, res, row + iter*accCols, 0); \ + bscalec(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask); \ + bcouple(taccReal, taccImag, tRes, acc0, acc1); \ + bstore(acc0, res, row + iter*accCols + 0); \ + if (full) { \ + bstore(acc1, res, row + iter*accCols + accColsC); \ + } \ } #define MICRO_COMPLEX_STORE MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_STORE_ONE) -template -EIGEN_STRONG_INLINE void gemm_complex_unrolled_iteration( +template +EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_iteration( const DataMapper& res, const Scalar* lhs_base, const Scalar* rhs_base, Index depth, Index strideA, + Index offsetA, Index strideB, Index& row, const Packet& pAlphaReal, - const Packet& pAlphaImag) + const Packet& pAlphaImag, + const Packet& pMask) { - const Scalar* rhs_ptr_real = rhs_base; - const Scalar* rhs_ptr_imag = NULL; + const Scalar* rhs_ptr_real0 = rhs_base, * rhs_ptr_real1 = NULL, * rhs_ptr_real2 = NULL; + const Scalar* rhs_ptr_imag0 = NULL, * rhs_ptr_imag1 = NULL, * rhs_ptr_imag2 = NULL; const Index imag_delta = accCols*strideA; - if(!RhsIsReal) { - rhs_ptr_imag = rhs_base + accRows*strideB; - } else { - EIGEN_UNUSED_VARIABLE(rhs_ptr_imag); - } + const Index imag_delta2 = accCols2*strideA; const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_real1 = NULL; const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_real3 = NULL; PacketBlock accReal0, accImag0, accReal1, accImag1; @@ -2104,16 +2036,14 @@ EIGEN_STRONG_INLINE void gemm_complex_unrolled_iteration( PacketBlock acc0, acc1; PacketBlock tRes; + MICRO_COMPLEX_SRC2_PTR MICRO_COMPLEX_SRC_PTR MICRO_COMPLEX_DST_PTR Index k = 0; for(; k + PEEL_COMPLEX <= depth; k+= PEEL_COMPLEX) { - EIGEN_POWER_PREFETCH(rhs_ptr_real); - if(!RhsIsReal) { - EIGEN_POWER_PREFETCH(rhs_ptr_imag); - } + MICRO_COMPLEX_PREFETCHN(accRows) MICRO_COMPLEX_PREFETCH MICRO_COMPLEX_ONE_PEEL4 } @@ -2123,9 +2053,13 @@ EIGEN_STRONG_INLINE void gemm_complex_unrolled_iteration( } MICRO_COMPLEX_STORE - row += unroll_factor*accCols; + MICRO_COMPLEX_UPDATE } +#define MICRO_COMPLEX_UNROLL_ITER2(N, M) \ + gemm_complex_unrolled_iteration(res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlphaReal, pAlphaImag, pMask); \ + if (M) return; + template EIGEN_ALWAYS_INLINE void 
gemm_complex_cols( const DataMapper& res, @@ -2138,7 +2072,6 @@ EIGEN_ALWAYS_INLINE void gemm_complex_cols( Index offsetB, Index col, Index rows, - Index cols, Index remaining_rows, const Packet& pAlphaReal, const Packet& pAlphaImag, @@ -2146,33 +2079,33 @@ EIGEN_ALWAYS_INLINE void gemm_complex_cols( { const DataMapper res3 = res.getSubMapper(0, col); - const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB; + const Scalar* rhs_base = blockB + advanceCols*col*strideB + MICRO_NEW_ROWS*offsetB; const Scalar* lhs_base = blockA + accCols*offsetA; Index row = 0; -#define MAX_COMPLEX_UNROLL 3 +#define MAX_COMPLEX_UNROLL 4 while(row + MAX_COMPLEX_UNROLL*accCols <= rows) { - gemm_complex_unrolled_iteration(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag); + MICRO_COMPLEX_UNROLL_ITER2(MAX_COMPLEX_UNROLL, 0); } switch( (rows-row)/accCols ) { #if MAX_COMPLEX_UNROLL > 4 case 4: - gemm_complex_unrolled_iteration<4, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag); + MICRO_UNROLL_ITER(MICRO_COMPLEX_UNROLL_ITER2, 4) break; #endif #if MAX_COMPLEX_UNROLL > 3 case 3: - gemm_complex_unrolled_iteration<3, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag); + MICRO_UNROLL_ITER(MICRO_COMPLEX_UNROLL_ITER2, 3) break; #endif #if MAX_COMPLEX_UNROLL > 2 case 2: - gemm_complex_unrolled_iteration<2, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag); + MICRO_UNROLL_ITER(MICRO_COMPLEX_UNROLL_ITER2, 2) break; #endif #if MAX_COMPLEX_UNROLL > 1 case 1: - gemm_complex_unrolled_iteration<1, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag); + MICRO_UNROLL_ITER(MICRO_COMPLEX_UNROLL_ITER2, 1) break; #endif default: @@ -2182,10 +2115,13 @@ EIGEN_ALWAYS_INLINE void gemm_complex_cols( if(remaining_rows > 0) { - gemm_complex_extra_row(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask); + gemm_complex_extra_row(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask); } } +#define MICRO_COMPLEX_EXTRA_COLS(N) \ + gemm_complex_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask); + template EIGEN_STRONG_INLINE void gemm_complex_extra_cols( const DataMapper& res, @@ -2204,9 +2140,7 @@ EIGEN_STRONG_INLINE void gemm_complex_extra_cols( const Packet& pAlphaImag, const Packet& pMask) { - for (; col < cols; col++) { - gemm_complex_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask); - } + MICRO_EXTRA(MICRO_COMPLEX_EXTRA_COLS, cols-col, true) } template @@ -2227,10 +2161,13 @@ EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* bl Index col = 0; for(; col + accRows <= cols; col += accRows) { - gemm_complex_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, 
pAlphaImag, pMask); + gemm_complex_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask); } - gemm_complex_extra_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask); + if (col != cols) + { + gemm_complex_extra_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask); + } } #undef accColsC diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h index d92b67815..e68c595c7 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h @@ -19,10 +19,9 @@ EIGEN_ALWAYS_INLINE void gemm_extra_row( Index depth, Index strideA, Index offsetA, + Index strideB, Index row, - Index col, Index rows, - Index cols, Index remaining_rows, const Packet& pAlpha, const Packet& pMask); @@ -57,9 +56,7 @@ EIGEN_ALWAYS_INLINE void gemm_complex_extra_row( Index offsetA, Index strideB, Index row, - Index col, Index rows, - Index cols, Index remaining_rows, const Packet& pAlphaReal, const Packet& pAlphaImag, @@ -83,79 +80,100 @@ EIGEN_STRONG_INLINE void gemm_complex_extra_cols( const Packet& pAlphaImag, const Packet& pMask); -template -EIGEN_ALWAYS_INLINE Packet ploadLhs(const Scalar* lhs); +template +EIGEN_ALWAYS_INLINE Packet ploadLhs(const __UNPACK_TYPE__(Packet)* lhs); -template +template EIGEN_ALWAYS_INLINE void bload(PacketBlock& acc, const DataMapper& res, Index row, Index col); -template -EIGEN_ALWAYS_INLINE void bscale(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha); +template +EIGEN_ALWAYS_INLINE void bstore(PacketBlock& acc, const DataMapper& res, Index row); -template -EIGEN_ALWAYS_INLINE void bscalec(PacketBlock& aReal, PacketBlock& aImag, const Packet& bReal, const Packet& bImag, PacketBlock& cReal, PacketBlock& cImag); +template +EIGEN_ALWAYS_INLINE void bscale(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha, const Packet& pMask); -// Grab two decouples real/imaginary PacketBlocks and return two coupled (real/imaginary pairs) PacketBlocks. 
-template -EIGEN_ALWAYS_INLINE void bcouple_common(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& acc1, PacketBlock& acc2) -{ - acc1.packet[0].v = vec_mergeh(taccReal.packet[0], taccImag.packet[0]); - if (N > 1) { - acc1.packet[1].v = vec_mergeh(taccReal.packet[1], taccImag.packet[1]); - } - if (N > 2) { - acc1.packet[2].v = vec_mergeh(taccReal.packet[2], taccImag.packet[2]); - } - if (N > 3) { - acc1.packet[3].v = vec_mergeh(taccReal.packet[3], taccImag.packet[3]); +template +EIGEN_ALWAYS_INLINE void bscalec(PacketBlock& aReal, PacketBlock& aImag, const Packet& bReal, const Packet& bImag, PacketBlock& cReal, PacketBlock& cImag, const Packet& pMask); + +template +EIGEN_ALWAYS_INLINE void bcouple(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& tRes, PacketBlock& acc1, PacketBlock& acc2); + +#define MICRO_NORMAL(iter) \ + (accCols == accCols2) || (unroll_factor != (iter + 1)) + +#define MICRO_UNROLL_ITER(func, N) \ + switch (remaining_rows) { \ + default: \ + func(N, 0) \ + break; \ + case 1: \ + func(N, 1) \ + break; \ + case 2: \ + if (sizeof(Scalar) == sizeof(float)) { \ + func(N, 2) \ + } \ + break; \ + case 3: \ + if (sizeof(Scalar) == sizeof(float)) { \ + func(N, 3) \ + } \ + break; \ } - acc2.packet[0].v = vec_mergel(taccReal.packet[0], taccImag.packet[0]); - if (N > 1) { - acc2.packet[1].v = vec_mergel(taccReal.packet[1], taccImag.packet[1]); - } - if (N > 2) { - acc2.packet[2].v = vec_mergel(taccReal.packet[2], taccImag.packet[2]); - } - if (N > 3) { - acc2.packet[3].v = vec_mergel(taccReal.packet[3], taccImag.packet[3]); - } -} +#define MICRO_NORMAL_COLS(iter, a, b) ((MICRO_NORMAL(iter)) ? a : b) -template -EIGEN_ALWAYS_INLINE void bcouple(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& tRes, PacketBlock& acc1, PacketBlock& acc2) -{ - bcouple_common(taccReal, taccImag, acc1, acc2); - - acc1.packet[0] = padd(tRes.packet[0], acc1.packet[0]); - if (N > 1) { - acc1.packet[1] = padd(tRes.packet[1], acc1.packet[1]); - } - if (N > 2) { - acc1.packet[2] = padd(tRes.packet[2], acc1.packet[2]); - } - if (N > 3) { - acc1.packet[3] = padd(tRes.packet[3], acc1.packet[3]); +#define MICRO_LOAD1(lhs_ptr, iter) \ + if (unroll_factor > iter) { \ + lhsV##iter = ploadLhs(lhs_ptr##iter); \ + lhs_ptr##iter += MICRO_NORMAL_COLS(iter, accCols, accCols2); \ + } else { \ + EIGEN_UNUSED_VARIABLE(lhsV##iter); \ } - acc2.packet[0] = padd(tRes.packet[0+N], acc2.packet[0]); - if (N > 1) { - acc2.packet[1] = padd(tRes.packet[1+N], acc2.packet[1]); - } - if (N > 2) { - acc2.packet[2] = padd(tRes.packet[2+N], acc2.packet[2]); - } - if (N > 3) { - acc2.packet[3] = padd(tRes.packet[3+N], acc2.packet[3]); - } -} +#define MICRO_LOAD_ONE(iter) MICRO_LOAD1(lhs_ptr, iter) + +#define MICRO_COMPLEX_LOAD_ONE(iter) \ + if (!LhsIsReal && (unroll_factor > iter)) { \ + lhsVi##iter = ploadLhs(lhs_ptr_real##iter + MICRO_NORMAL_COLS(iter, imag_delta, imag_delta2)); \ + } else { \ + EIGEN_UNUSED_VARIABLE(lhsVi##iter); \ + } \ + MICRO_LOAD1(lhs_ptr_real, iter) \ + +#define MICRO_SRC_PTR1(lhs_ptr, advRows, iter) \ + if (unroll_factor > iter) { \ + lhs_ptr##iter = lhs_base + (row+(iter*accCols))*strideA*advRows - MICRO_NORMAL_COLS(iter, 0, (accCols-accCols2)*offsetA); \ + } else { \ + EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \ + } + +#define MICRO_SRC_PTR_ONE(iter) MICRO_SRC_PTR1(lhs_ptr, 1, iter) + +#define MICRO_COMPLEX_SRC_PTR_ONE(iter) MICRO_SRC_PTR1(lhs_ptr_real, advanceRows, iter) + +#define MICRO_PREFETCH1(lhs_ptr, iter) \ + if (unroll_factor > iter) { \ + EIGEN_POWER_PREFETCH(lhs_ptr##iter); \ + } 
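// accCols2 in the helper macros above equals accCols for every fully populated micro-tile;
// only the last tile of a partial panel is instantiated with accCols2 set to the remaining
// row count (judging from MICRO_UNROLL_ITER, which dispatches on remaining_rows).
// MICRO_NORMAL(iter) is therefore false only for that final iteration, the one that needs
// the row mask and the offsetA correction in MICRO_SRC_PTR1.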
+ +#define MICRO_PREFETCH_ONE(iter) MICRO_PREFETCH1(lhs_ptr, iter) + +#define MICRO_COMPLEX_PREFETCH_ONE(iter) MICRO_PREFETCH1(lhs_ptr_real, iter) + +#define MICRO_UPDATE \ + if (accCols == accCols2) { \ + EIGEN_UNUSED_VARIABLE(pMask); \ + EIGEN_UNUSED_VARIABLE(offsetA); \ + row += unroll_factor*accCols; \ + } + +#define MICRO_COMPLEX_UPDATE \ + MICRO_UPDATE \ + if(LhsIsReal || (accCols == accCols2)) { \ + EIGEN_UNUSED_VARIABLE(imag_delta2); \ + } -// This is necessary because ploadRhs for double returns a pair of vectors when MMA is enabled. -template -EIGEN_ALWAYS_INLINE Packet ploadRhs(const Scalar* rhs) -{ - return ploadu(rhs); -} } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h index 5a0834758..1cb82eed5 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h @@ -21,6 +21,9 @@ #if !__has_builtin(__builtin_vsx_assemble_pair) #define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair #endif +#if !__has_builtin(__builtin_vsx_disassemble_pair) +#define __builtin_vsx_disassemble_pair __builtin_mma_disassemble_pair +#endif #endif #include "../../InternalHeaderCheck.h" @@ -29,44 +32,48 @@ namespace Eigen { namespace internal { -template +#define accColsC (accCols / 2) + EIGEN_ALWAYS_INLINE void bsetzeroMMA(__vector_quad* acc) { __builtin_mma_xxsetaccz(acc); } -template -EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, const DataMapper& data, const Packet& alpha, __vector_quad* acc) +template +EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, const DataMapper& data, const Packet& alpha, const Packet& pMask, __vector_quad* acc) { PacketBlock result; __builtin_mma_disassemble_acc(&result.packet, acc); PacketBlock tRes; - bload(tRes, data, i, 0); + bload(tRes, data, i, 0); - bscale(tRes, result, alpha); + bscale(tRes, result, alpha, pMask); - data.template storePacketBlock(i, 0, tRes); + bstore(tRes, data, i); } -template -EIGEN_ALWAYS_INLINE void storeComplexAccumulator(Index i, const DataMapper& data, const Packet& alphaReal, const Packet& alphaImag, __vector_quad* accReal, __vector_quad* accImag) +template +EIGEN_ALWAYS_INLINE void storeComplexAccumulator(Index i, const DataMapper& data, const Packet& alphaReal, const Packet& alphaImag, const Packet& pMask, __vector_quad* accReal, __vector_quad* accImag) { + const bool full = (accCols2 > accColsC); PacketBlock resultReal, resultImag; __builtin_mma_disassemble_acc(&resultReal.packet, accReal); __builtin_mma_disassemble_acc(&resultImag.packet, accImag); PacketBlock tRes; - bload(tRes, data, i, 0); + bload(tRes, data, i, 0); - PacketBlock taccReal, taccImag; - bscalec(resultReal, resultImag, alphaReal, alphaImag, taccReal, taccImag); + PacketBlock taccReal, taccImag; + bscalec(resultReal, resultImag, alphaReal, alphaImag, taccReal, taccImag, pMask); PacketBlock acc1, acc2; - bcouple(taccReal, taccImag, tRes, acc1, acc2); + bcouple(taccReal, taccImag, tRes, acc1, acc2); - data.template storePacketBlock(i, 0, acc1); - data.template storePacketBlock(i + accColsC, 0, acc2); + bstore(acc1, data, i); + if (full) { + bstore(acc2, data, i + accColsC); + } } // Defaults to float32, since Eigen still supports C++03 we can't use default template arguments @@ -81,18 +88,6 @@ EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const RhsPacket& a, const L } } -template -EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const PacketBlock& a, const Packet2d& b) -{ - __vector_pair* 
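storeAccumulator and the pgerMMA helpers revolve around the POWER10 MMA accumulator type. A minimal, self-contained sketch of that lifecycle, assuming the GCC/clang MMA builtins (compiled with MMA support, e.g. -mcpu=power10); rank1_accumulate and rows are illustrative names:

#include <altivec.h>

typedef __vector float vecf;

// acc = 0; acc += outer(a, b); copy the 4x4 float tile back out.
void rank1_accumulate(const float* a, const float* b, float out[16])
{
  __vector_quad acc;
  __builtin_mma_xxsetaccz(&acc);                                  // as in bsetzeroMMA
  vecf va = vec_xl(0, const_cast<float*>(a));
  vecf vb = vec_xl(0, const_cast<float*>(b));
  __builtin_mma_xvf32gerpp(&acc,                                  // positive rank-1 accumulate
                           reinterpret_cast<__vector unsigned char>(va),
                           reinterpret_cast<__vector unsigned char>(vb));
  vecf rows[4];
  __builtin_mma_disassemble_acc(rows, &acc);                      // as in storeAccumulator
  for (int i = 0; i < 4; i++)
    vec_xst(rows[i], 0, out + 4 * i);
}

The negative-accumulate branch of pgerMMA maps to __builtin_mma_xvf32gernp in the same way.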
a0 = reinterpret_cast<__vector_pair *>(const_cast(&a.packet[0])); - if(NegativeAccumulate) - { - __builtin_mma_xvf64gernp(acc, *a0, (__vector unsigned char)b); - } else { - __builtin_mma_xvf64gerpp(acc, *a0, (__vector unsigned char)b); - } -} - template EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const __vector_pair& a, const Packet2d& b) { @@ -104,18 +99,13 @@ EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const __vector_pair& a, con } } -template -EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad*, const __vector_pair&, const Packet4f&) -{ - // Just for compilation -} - -template -EIGEN_ALWAYS_INLINE void pgercMMA(__vector_quad* accReal, __vector_quad* accImag, const Packet& lhsV, const Packet& lhsVi, const RhsPacket& rhsV, const RhsPacket& rhsVi) +template +EIGEN_ALWAYS_INLINE void pgercMMA(__vector_quad* accReal, __vector_quad* accImag, const Packet& lhsV, Packet& lhsVi, const RhsPacket& rhsV, RhsPacket& rhsVi) { pgerMMA(accReal, rhsV, lhsV); if(LhsIsReal) { pgerMMA(accImag, rhsVi, lhsV); + EIGEN_UNUSED_VARIABLE(lhsVi); } else { if(!RhsIsReal) { pgerMMA(accReal, rhsVi, lhsVi); @@ -128,35 +118,33 @@ EIGEN_ALWAYS_INLINE void pgercMMA(__vector_quad* accReal, __vector_quad* accImag } // This is necessary because ploadRhs for double returns a pair of vectors when MMA is enabled. +template +EIGEN_ALWAYS_INLINE Packet ploadRhs(const __UNPACK_TYPE__(Packet)* rhs) +{ + return ploadu(rhs); +} + template EIGEN_ALWAYS_INLINE void ploadRhsMMA(const Scalar* rhs, Packet& rhsV) { - rhsV = ploadRhs(rhs); + rhsV = ploadRhs(rhs); } template<> -EIGEN_ALWAYS_INLINE void ploadRhsMMA >(const double* rhs, PacketBlock& rhsV) -{ - rhsV.packet[0] = ploadRhs(rhs); - rhsV.packet[1] = ploadRhs(rhs + (sizeof(Packet2d) / sizeof(double))); -} - -template<> -EIGEN_ALWAYS_INLINE void ploadRhsMMA(const double* rhs, __vector_pair& rhsV) +EIGEN_ALWAYS_INLINE void ploadRhsMMA(const double* rhs, __vector_pair& rhsV) { #if EIGEN_COMP_LLVM __builtin_vsx_assemble_pair(&rhsV, - reinterpret_cast<__vector unsigned char>(ploadRhs(rhs + (sizeof(Packet2d) / sizeof(double)))), - reinterpret_cast<__vector unsigned char>(ploadRhs(rhs))); + reinterpret_cast<__vector unsigned char>(ploadRhs(rhs + (sizeof(Packet2d) / sizeof(double)))), + reinterpret_cast<__vector unsigned char>(ploadRhs(rhs))); #else __asm__ ("lxvp %x0,%1" : "=wa" (rhsV) : "Y" (*rhs)); #endif } -template<> -EIGEN_ALWAYS_INLINE void ploadRhsMMA(const float*, __vector_pair&) +EIGEN_ALWAYS_INLINE void ploadLhsMMA(const double* lhs, __vector_pair& lhsV) { - // Just for compilation + ploadRhsMMA(lhs, lhsV); } // PEEL_MMA loop factor. 
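For double, the rhs (and now pairs of lhs packets as well) travels as a 256-bit __vector_pair; ploadLhsMMA reuses the ploadRhsMMA path, and the paired-load macros split such a pair back into two Packet2d with __builtin_vsx_disassemble_pair. A reduced sketch of the GCC branch shown above (clang instead assembles the pair from two vec_xl loads); load_pair_and_split is an illustrative name:

#include <altivec.h>

typedef __vector double vecd;

void load_pair_and_split(const double* rhs, vecd v[2])
{
  __vector_pair pair;
  __asm__ ("lxvp %x0,%1" : "=wa" (pair) : "Y" (*rhs));   // one 256-bit load (POWER10)
  __builtin_vsx_disassemble_pair(v, &pair);              // back to two ordinary 2-double vectors
}

The pair can also be fed directly to the double-precision ger builtins, which is what the __vector_pair overload of pgerMMA does.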
@@ -165,98 +153,116 @@ EIGEN_ALWAYS_INLINE void ploadRhsMMA(const float*, __vector_pair&) #define MICRO_MMA_UNROLL(func) \ func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7) -#define MICRO_MMA_LOAD_ONE(iter) \ - if (unroll_factor > iter) { \ - lhsV##iter = ploadLhs(lhs_ptr##iter); \ - lhs_ptr##iter += accCols; \ - } else { \ - EIGEN_UNUSED_VARIABLE(lhsV##iter); \ - } +#define MICRO_MMA_WORK(func, type, peel) \ + func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) \ + func(4,type,peel) func(5,type,peel) func(6,type,peel) func(7,type,peel) #define MICRO_MMA_WORK_ONE(iter, type, peel) \ if (unroll_factor > iter) { \ pgerMMA(&accZero##iter, rhsV##peel, lhsV##iter); \ } -#define MICRO_MMA_TYPE_PEEL(func, func2, type, peel) \ +#define MICRO_MMA_WORK_TWO(iter, type, peel) \ + if (unroll_factor > iter) { \ + pgerMMA(&accZero##iter, rhsV##peel, lhsV2##iter.packet[peel & 1]); \ + } + +#define MICRO_MMA_LOAD1_TWO(lhs_ptr, iter) \ + if (unroll_factor > iter) { \ + if (MICRO_NORMAL(iter)) { \ + ploadLhsMMA(reinterpret_cast(lhs_ptr##iter), plhsV##iter); \ + __builtin_vsx_disassemble_pair(reinterpret_cast(&lhsV2##iter.packet), &plhsV##iter); \ + lhs_ptr##iter += accCols*2; \ + } else { \ + lhsV2##iter.packet[0] = ploadLhs(lhs_ptr##iter); \ + lhsV2##iter.packet[1] = ploadLhs(lhs_ptr##iter + accCols2); \ + lhs_ptr##iter += accCols2*2; \ + EIGEN_UNUSED_VARIABLE(plhsV##iter) \ + } \ + } else { \ + EIGEN_UNUSED_VARIABLE(lhsV2##iter); \ + EIGEN_UNUSED_VARIABLE(plhsV##iter) \ + } + +#define MICRO_MMA_LOAD_TWO(iter) MICRO_MMA_LOAD1_TWO(lhs_ptr, iter) + +#define MICRO_MMA_TYPE_PEEL(funcw, funcl, type, peel) \ if (PEEL_MMA > peel) { \ Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \ - ploadRhsMMA(rhs_ptr + (accRows * peel), rhsV##peel); \ - MICRO_MMA_UNROLL(func2); \ - func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) \ - func(4,type,peel) func(5,type,peel) func(6,type,peel) func(7,type,peel) \ + ploadRhsMMA(rhs_ptr + (accRows * peel), rhsV##peel); \ + MICRO_MMA_UNROLL(funcl) \ + MICRO_MMA_WORK(funcw, type, peel) \ } else { \ EIGEN_UNUSED_VARIABLE(rhsV##peel); \ } -#define MICRO_MMA_UNROLL_TYPE_PEEL(func, func2, type) \ +#define MICRO_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, peel1, peel2) \ + if (PEEL_MMA > peel2) { \ + PacketBlock lhsV20, lhsV21, lhsV22, lhsV23, lhsV24, lhsV25, lhsV26, lhsV27; \ + __vector_pair plhsV0, plhsV1, plhsV2, plhsV3, plhsV4, plhsV5, plhsV6, plhsV7; \ + ploadRhsMMA(rhs_ptr + (accRows * peel1), rhsV##peel1); \ + ploadRhsMMA(rhs_ptr + (accRows * peel2), rhsV##peel2); \ + MICRO_MMA_UNROLL(funcl2) \ + MICRO_MMA_WORK(funcw2, type, peel1) \ + MICRO_MMA_WORK(funcw2, type, peel2) \ + } else { \ + MICRO_MMA_TYPE_PEEL(funcw1, funcl1, type, peel1) \ + } + +#define MICRO_MMA_UNROLL_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type) \ type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7; \ - MICRO_MMA_TYPE_PEEL(func,func2,type,0); MICRO_MMA_TYPE_PEEL(func,func2,type,1); \ - MICRO_MMA_TYPE_PEEL(func,func2,type,2); MICRO_MMA_TYPE_PEEL(func,func2,type,3); \ - MICRO_MMA_TYPE_PEEL(func,func2,type,4); MICRO_MMA_TYPE_PEEL(func,func2,type,5); \ - MICRO_MMA_TYPE_PEEL(func,func2,type,6); MICRO_MMA_TYPE_PEEL(func,func2,type,7); + MICRO_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,0,1) \ + MICRO_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,2,3) \ + MICRO_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,4,5) \ + MICRO_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,6,7) -#define MICRO_MMA_UNROLL_TYPE_ONE(func, func2, type) 
\ +#define MICRO_MMA_UNROLL_TYPE_ONE(funcw, funcl, type) \ type rhsV0; \ - MICRO_MMA_TYPE_PEEL(func,func2,type,0); + MICRO_MMA_TYPE_PEEL(funcw,funcl,type,0) -#define MICRO_MMA_ONE_PEEL \ - if (sizeof(Scalar) == sizeof(float)) { \ - MICRO_MMA_UNROLL_TYPE_PEEL(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, RhsPacket); \ - } else { \ - MICRO_MMA_UNROLL_TYPE_PEEL(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, __vector_pair); \ - } \ - rhs_ptr += (accRows * PEEL_MMA); +#define MICRO_MMA_UNROLL_TYPE(MICRO_MMA_TYPE, size) \ + MICRO_MMA_TYPE(MICRO_MMA_WORK_ONE, MICRO_LOAD_ONE, RhsPacket) \ + rhs_ptr += (accRows * size); -#define MICRO_MMA_ONE \ - if (sizeof(Scalar) == sizeof(float)) { \ - MICRO_MMA_UNROLL_TYPE_ONE(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, RhsPacket); \ - } else { \ - MICRO_MMA_UNROLL_TYPE_ONE(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, __vector_pair); \ - } \ - rhs_ptr += accRows; +#define MICRO_MMA_UNROLL_TYPE2(MICRO_MMA_TYPE, size) \ + MICRO_MMA_TYPE(MICRO_MMA_WORK_ONE, MICRO_LOAD_ONE, MICRO_MMA_WORK_TWO, MICRO_MMA_LOAD_TWO, RhsPacket) \ + rhs_ptr += (accRows * size); + +#define MICRO_MMA_ONE_PEEL MICRO_MMA_UNROLL_TYPE2(MICRO_MMA_UNROLL_TYPE_PEEL2, PEEL_MMA) + +#define MICRO_MMA_ONE MICRO_MMA_UNROLL_TYPE(MICRO_MMA_UNROLL_TYPE_ONE, 1) #define MICRO_MMA_DST_PTR_ONE(iter) \ if (unroll_factor > iter) { \ - bsetzeroMMA(&accZero##iter); \ + bsetzeroMMA(&accZero##iter); \ } else { \ EIGEN_UNUSED_VARIABLE(accZero##iter); \ } #define MICRO_MMA_DST_PTR MICRO_MMA_UNROLL(MICRO_MMA_DST_PTR_ONE) -#define MICRO_MMA_SRC_PTR_ONE(iter) \ - if (unroll_factor > iter) { \ - lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols; \ - } else { \ - EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \ - } +#define MICRO_MMA_SRC_PTR MICRO_MMA_UNROLL(MICRO_SRC_PTR_ONE) -#define MICRO_MMA_SRC_PTR MICRO_MMA_UNROLL(MICRO_MMA_SRC_PTR_ONE) - -#define MICRO_MMA_PREFETCH_ONE(iter) \ - if (unroll_factor > iter) { \ - EIGEN_POWER_PREFETCH(lhs_ptr##iter); \ - } - -#define MICRO_MMA_PREFETCH MICRO_MMA_UNROLL(MICRO_MMA_PREFETCH_ONE) +#define MICRO_MMA_PREFETCH MICRO_MMA_UNROLL(MICRO_PREFETCH_ONE) #define MICRO_MMA_STORE_ONE(iter) \ if (unroll_factor > iter) { \ - storeAccumulator(row + iter*accCols, res, pAlpha, &accZero##iter); \ + storeAccumulator(row + iter*accCols, res, pAlpha, pMask, &accZero##iter); \ } #define MICRO_MMA_STORE MICRO_MMA_UNROLL(MICRO_MMA_STORE_ONE) -template +template EIGEN_ALWAYS_INLINE void gemm_unrolled_MMA_iteration( const DataMapper& res, const Scalar* lhs_base, const Scalar* rhs_base, Index depth, Index strideA, + Index offsetA, Index& row, - const Packet& pAlpha) + const Packet& pAlpha, + const Packet& pMask) { const Scalar* rhs_ptr = rhs_base; const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL; @@ -265,8 +271,8 @@ EIGEN_ALWAYS_INLINE void gemm_unrolled_MMA_iteration( MICRO_MMA_SRC_PTR MICRO_MMA_DST_PTR - Index k = 0; - for(; k + PEEL_MMA <= depth; k+= PEEL_MMA) + Index k = 0, depth2 = depth - PEEL_MMA; + for(; k <= depth2; k += PEEL_MMA) { EIGEN_POWER_PREFETCH(rhs_ptr); MICRO_MMA_PREFETCH @@ -278,9 +284,13 @@ EIGEN_ALWAYS_INLINE void gemm_unrolled_MMA_iteration( } MICRO_MMA_STORE - row += unroll_factor*accCols; + MICRO_UPDATE } +#define MICRO_MMA_UNROLL_ITER2(N, M) \ + gemm_unrolled_MMA_iteration(res3, lhs_base, rhs_base, depth, strideA, offsetA, row, pAlpha, pMask); \ + if (M) return; + template EIGEN_ALWAYS_INLINE void gemmMMA_cols( const DataMapper& res, @@ -293,7 +303,6 @@ EIGEN_ALWAYS_INLINE 
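The main depth loop now precomputes depth - PEEL_MMA and compares k against it; for a signed Index this is the same condition as the previous k + PEEL_MMA <= depth, and a depth smaller than the peel factor simply skips the peeled loop. A schematic of the loop shape with a hypothetical body and peel factor:

// PEEL stands in for PEEL_MMA / PEEL_COMPLEX_MMA; the comments stand in for
// MICRO_MMA_ONE_PEEL and MICRO_MMA_ONE.
void peeled_depth_loop(long depth)
{
  const long PEEL = 4;                    // illustrative peel factor
  long k = 0, depth2 = depth - PEEL;
  for (; k <= depth2; k += PEEL) {
    // ... PEEL rank-1 updates, rhs reloaded per step, lhs loaded in pairs ...
  }
  for (; k < depth; k++) {
    // ... one update at a time for the leftover depth ...
  }
}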
void gemmMMA_cols( Index offsetB, Index col, Index rows, - Index cols, Index remaining_rows, const Packet& pAlpha, const Packet& pMask) @@ -306,42 +315,42 @@ EIGEN_ALWAYS_INLINE void gemmMMA_cols( #define MAX_MMA_UNROLL 7 while(row + MAX_MMA_UNROLL*accCols <= rows) { - gemm_unrolled_MMA_iteration(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); + MICRO_MMA_UNROLL_ITER2(MAX_MMA_UNROLL, 0); } switch( (rows-row)/accCols ) { #if MAX_MMA_UNROLL > 7 case 7: - gemm_unrolled_MMA_iteration<7, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); + MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 7) break; #endif #if MAX_MMA_UNROLL > 6 case 6: - gemm_unrolled_MMA_iteration<6, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); + MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 6) break; #endif #if MAX_MMA_UNROLL > 5 case 5: - gemm_unrolled_MMA_iteration<5, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); + MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 5) break; #endif #if MAX_MMA_UNROLL > 4 case 4: - gemm_unrolled_MMA_iteration<4, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); + MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 4) break; #endif #if MAX_MMA_UNROLL > 3 case 3: - gemm_unrolled_MMA_iteration<3, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); + MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 3) break; #endif #if MAX_MMA_UNROLL > 2 case 2: - gemm_unrolled_MMA_iteration<2, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); + MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 2) break; #endif #if MAX_MMA_UNROLL > 1 case 1: - gemm_unrolled_MMA_iteration<1, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); + MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 1) break; #endif default: @@ -351,7 +360,7 @@ EIGEN_ALWAYS_INLINE void gemmMMA_cols( if(remaining_rows > 0) { - gemm_extra_row(res3, blockA, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask); + gemm_extra_row(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlpha, pMask); } } @@ -366,16 +375,20 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, const Packet pAlpha = pset1(alpha); const Packet pMask = bmask(remaining_rows); + typedef typename std::conditional_t<(sizeof(Scalar) == sizeof(float)), RhsPacket, __vector_pair> RhsPacket2; + Index col = 0; for(; col + accRows <= cols; col += accRows) { - gemmMMA_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask); + gemmMMA_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlpha, pMask); } - gemm_extra_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask); + if (col != cols) + { + gemm_extra_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask); + } } -#define accColsC (accCols / 2) #define advanceRows ((LhsIsReal) ? 1 : 2) #define advanceCols ((RhsIsReal) ? 
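gemmMMA now picks the rhs packet type once, at compile time: the plain RhsPacket for float, a __vector_pair for double. A small sketch of that selection, assuming C++14's std::conditional_t and an MMA-enabled compiler for the __vector_pair type:

#include <type_traits>
#include <altivec.h>

template<typename Scalar, typename RhsPacket>
using rhs_packet2_t =
    std::conditional_t<sizeof(Scalar) == sizeof(float),
                       RhsPacket,        // float: one 128-bit vector per step
                       __vector_pair>;   // double: one 256-bit pair per step

static_assert(std::is_same<rhs_packet2_t<float,  __vector float>, __vector float>::value, "");
static_assert(std::is_same<rhs_packet2_t<double, __vector float>, __vector_pair>::value, "");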
1 : 2) @@ -385,74 +398,104 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, #define MICRO_COMPLEX_MMA_UNROLL(func) \ func(0) func(1) func(2) func(3) -#define MICRO_COMPLEX_MMA_LOAD_ONE(iter) \ - if (unroll_factor > iter) { \ - lhsV##iter = ploadLhs(lhs_ptr_real##iter); \ - if(!LhsIsReal) { \ - lhsVi##iter = ploadLhs(lhs_ptr_real##iter + imag_delta); \ - } else { \ - EIGEN_UNUSED_VARIABLE(lhsVi##iter); \ - } \ - lhs_ptr_real##iter += accCols; \ - } else { \ - EIGEN_UNUSED_VARIABLE(lhsV##iter); \ - EIGEN_UNUSED_VARIABLE(lhsVi##iter); \ - } +#define MICRO_COMPLEX_MMA_WORK(func, type, peel) \ + func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) #define MICRO_COMPLEX_MMA_WORK_ONE(iter, type, peel) \ if (unroll_factor > iter) { \ - pgercMMA(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \ + pgercMMA(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \ } -#define MICRO_COMPLEX_MMA_TYPE_PEEL(func, func2, type, peel) \ +#define MICRO_COMPLEX_MMA_WORK_TWO(iter, type, peel) \ + if (unroll_factor > iter) { \ + pgercMMA(&accReal##iter, &accImag##iter, lhsV2##iter.packet[peel & 1], lhsVi2##iter.packet[peel & 1], rhsV##peel, rhsVi##peel); \ + } + +#define MICRO_COMPLEX_MMA_LOAD1_TWO(lhs_ptr, iter) \ + if (!LhsIsReal && (unroll_factor > iter)) { \ + if (MICRO_NORMAL(iter)) { \ + ploadLhsMMA(reinterpret_cast(lhs_ptr_real##iter + imag_delta), plhsVi##iter); \ + __builtin_vsx_disassemble_pair(reinterpret_cast(&lhsVi2##iter.packet), &plhsVi##iter); \ + } else { \ + lhsVi2##iter.packet[0] = ploadLhs(lhs_ptr_real##iter + imag_delta2); \ + lhsVi2##iter.packet[1] = ploadLhs(lhs_ptr_real##iter + imag_delta2 + accCols2); \ + EIGEN_UNUSED_VARIABLE(plhsVi##iter) \ + } \ + } else { \ + EIGEN_UNUSED_VARIABLE(lhsVi2##iter); \ + EIGEN_UNUSED_VARIABLE(plhsVi##iter) \ + } \ + MICRO_MMA_LOAD1_TWO(lhs_ptr_real, iter) + +#define MICRO_COMPLEX_MMA_LOAD_TWO(iter) MICRO_COMPLEX_MMA_LOAD1_TWO(lhs_ptr, iter) + +#define MICRO_COMPLEX_MMA_TYPE_PEEL(funcw, funcl, type, peel) \ if (PEEL_COMPLEX_MMA > peel) { \ Packet lhsV0, lhsV1, lhsV2, lhsV3; \ Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3; \ - ploadRhsMMA(rhs_ptr_real + (accRows * peel), rhsV##peel); \ + ploadRhsMMA(rhs_ptr_real + (accRows * peel), rhsV##peel); \ if(!RhsIsReal) { \ - ploadRhsMMA(rhs_ptr_imag + (accRows * peel), rhsVi##peel); \ + ploadRhsMMA(rhs_ptr_imag + (accRows * peel), rhsVi##peel); \ } else { \ EIGEN_UNUSED_VARIABLE(rhsVi##peel); \ } \ - MICRO_COMPLEX_MMA_UNROLL(func2); \ - func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) \ + MICRO_COMPLEX_MMA_UNROLL(funcl) \ + MICRO_COMPLEX_MMA_WORK(funcw, type, peel) \ } else { \ EIGEN_UNUSED_VARIABLE(rhsV##peel); \ EIGEN_UNUSED_VARIABLE(rhsVi##peel); \ } -#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(func, func2, type) \ +#define MICRO_COMPLEX_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, peel1, peel2) \ + if (PEEL_COMPLEX_MMA > peel2) { \ + PacketBlock lhsV20, lhsV21, lhsV22, lhsV23; \ + PacketBlock lhsVi20, lhsVi21, lhsVi22, lhsVi23; \ + __vector_pair plhsV0, plhsV1, plhsV2, plhsV3; \ + __vector_pair plhsVi0, plhsVi1, plhsVi2, plhsVi3; \ + ploadRhsMMA(rhs_ptr_real + (accRows * peel1), rhsV##peel1); \ + ploadRhsMMA(rhs_ptr_real + (accRows * peel2), rhsV##peel2); \ + if(!RhsIsReal) { \ + ploadRhsMMA(rhs_ptr_imag + (accRows * peel1), rhsVi##peel1); \ + ploadRhsMMA(rhs_ptr_imag + (accRows * peel2), rhsVi##peel2); \ + } else { \ + EIGEN_UNUSED_VARIABLE(rhsVi##peel1); \ + 
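pgercMMA (driven by the WORK_ONE/WORK_TWO macros below) keeps the real and imaginary results in two decoupled accumulators and builds the complex product out of four real multiply-accumulates. A scalar sketch of the same bookkeeping for the non-conjugated case; the ConjugateLhs/ConjugateRhs flags only flip the signs of the imaginary contributions:

struct acc_pair { double re, im; };

// (lr + li*i) * (rr + ri*i) accumulated into separate real/imag sums,
// the lane-wise arithmetic the decoupled accReal/accImag tiles perform.
inline void cplx_madd(acc_pair& acc, double lr, double li, double rr, double ri)
{
  acc.re += lr * rr;   // real * real
  acc.re -= li * ri;   // minus imag * imag
  acc.im += lr * ri;   // real * imag
  acc.im += li * rr;   // imag * real
}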
EIGEN_UNUSED_VARIABLE(rhsVi##peel2); \ + } \ + MICRO_COMPLEX_MMA_UNROLL(funcl2) \ + MICRO_COMPLEX_MMA_WORK(funcw2, type, peel1) \ + MICRO_COMPLEX_MMA_WORK(funcw2, type, peel2) \ + } else { \ + MICRO_COMPLEX_MMA_TYPE_PEEL(funcw1, funcl1, type, peel1) \ + } + +#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type) \ type rhsV0, rhsV1, rhsV2, rhsV3; \ type rhsVi0, rhsVi1, rhsVi2, rhsVi3; \ - MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,1); \ - MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,2); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,3); + MICRO_COMPLEX_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,0,1) \ + MICRO_COMPLEX_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,2,3) -#define MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(func, func2, type) \ +#define MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(funcw, funcl, type) \ type rhsV0, rhsVi0; \ - MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0); + MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,0) -#define MICRO_COMPLEX_MMA_ONE_PEEL \ - if (sizeof(Scalar) == sizeof(float)) { \ - MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, RhsPacket); \ - } else { \ - MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, __vector_pair); \ - } \ - rhs_ptr_real += (accRows * PEEL_COMPLEX_MMA); \ - if(!RhsIsReal) rhs_ptr_imag += (accRows * PEEL_COMPLEX_MMA); +#define MICRO_COMPLEX_MMA_UNROLL_TYPE(MICRO_COMPLEX_MMA_TYPE, size) \ + MICRO_COMPLEX_MMA_TYPE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_LOAD_ONE, RhsPacket) \ + rhs_ptr_real += (accRows * size); \ + if(!RhsIsReal) rhs_ptr_imag += (accRows * size); -#define MICRO_COMPLEX_MMA_ONE \ - if (sizeof(Scalar) == sizeof(float)) { \ - MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, RhsPacket); \ - } else { \ - MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, __vector_pair); \ - } \ - rhs_ptr_real += accRows; \ - if(!RhsIsReal) rhs_ptr_imag += accRows; +#define MICRO_COMPLEX_MMA_UNROLL_TYPE2(MICRO_COMPLEX_MMA_TYPE, size) \ + MICRO_COMPLEX_MMA_TYPE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_LOAD_ONE, MICRO_COMPLEX_MMA_WORK_TWO, MICRO_COMPLEX_MMA_LOAD_TWO, RhsPacket) \ + rhs_ptr_real += (accRows * size); \ + if(!RhsIsReal) rhs_ptr_imag += (accRows * size); + +#define MICRO_COMPLEX_MMA_ONE_PEEL MICRO_COMPLEX_MMA_UNROLL_TYPE2(MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL2, PEEL_COMPLEX_MMA) + +#define MICRO_COMPLEX_MMA_ONE MICRO_COMPLEX_MMA_UNROLL_TYPE(MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE, 1) #define MICRO_COMPLEX_MMA_DST_PTR_ONE(iter) \ if (unroll_factor > iter) { \ - bsetzeroMMA(&accReal##iter); \ - bsetzeroMMA(&accImag##iter); \ + bsetzeroMMA(&accReal##iter); \ + bsetzeroMMA(&accImag##iter); \ } else { \ EIGEN_UNUSED_VARIABLE(accReal##iter); \ EIGEN_UNUSED_VARIABLE(accImag##iter); \ @@ -460,44 +503,35 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, #define MICRO_COMPLEX_MMA_DST_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_DST_PTR_ONE) -#define MICRO_COMPLEX_MMA_SRC_PTR_ONE(iter) \ - if (unroll_factor > iter) { \ - lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols; \ - } else { \ - EIGEN_UNUSED_VARIABLE(lhs_ptr_real##iter); \ - } +#define MICRO_COMPLEX_MMA_SRC_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_SRC_PTR_ONE) -#define MICRO_COMPLEX_MMA_SRC_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_SRC_PTR_ONE) - -#define 
MICRO_COMPLEX_MMA_PREFETCH_ONE(iter) \ - if (unroll_factor > iter) { \ - EIGEN_POWER_PREFETCH(lhs_ptr_real##iter); \ - } - -#define MICRO_COMPLEX_MMA_PREFETCH MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_PREFETCH_ONE) +#define MICRO_COMPLEX_MMA_PREFETCH MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_PREFETCH_ONE) #define MICRO_COMPLEX_MMA_STORE_ONE(iter) \ if (unroll_factor > iter) { \ - storeComplexAccumulator(row + iter*accCols, res, pAlphaReal, pAlphaImag, &accReal##iter, &accImag##iter); \ + storeComplexAccumulator(row + iter*accCols, res, pAlphaReal, pAlphaImag, pMask, &accReal##iter, &accImag##iter); \ } #define MICRO_COMPLEX_MMA_STORE MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_STORE_ONE) -template +template EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_MMA_iteration( const DataMapper& res, const Scalar* lhs_base, const Scalar* rhs_base, Index depth, Index strideA, + Index offsetA, Index strideB, Index& row, const Packet& pAlphaReal, - const Packet& pAlphaImag) + const Packet& pAlphaImag, + const Packet& pMask) { const Scalar* rhs_ptr_real = rhs_base; const Scalar* rhs_ptr_imag = NULL; const Index imag_delta = accCols*strideA; + const Index imag_delta2 = accCols2*strideA; if(!RhsIsReal) { rhs_ptr_imag = rhs_base + accRows*strideB; } else { @@ -510,8 +544,8 @@ EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_MMA_iteration( MICRO_COMPLEX_MMA_SRC_PTR MICRO_COMPLEX_MMA_DST_PTR - Index k = 0; - for(; k + PEEL_COMPLEX_MMA <= depth; k+= PEEL_COMPLEX_MMA) + Index k = 0, depth2 = depth - PEEL_COMPLEX_MMA; + for(; k <= depth2; k += PEEL_COMPLEX_MMA) { EIGEN_POWER_PREFETCH(rhs_ptr_real); if(!RhsIsReal) { @@ -526,9 +560,13 @@ EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_MMA_iteration( } MICRO_COMPLEX_MMA_STORE - row += unroll_factor*accCols; + MICRO_COMPLEX_UPDATE } +#define MICRO_COMPLEX_MMA_UNROLL_ITER2(N, M) \ + gemm_complex_unrolled_MMA_iteration(res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlphaReal, pAlphaImag, pMask); \ + if (M) return; + template EIGEN_ALWAYS_INLINE void gemmMMA_complex_cols( const DataMapper& res, @@ -541,7 +579,6 @@ EIGEN_ALWAYS_INLINE void gemmMMA_complex_cols( Index offsetB, Index col, Index rows, - Index cols, Index remaining_rows, const Packet& pAlphaReal, const Packet& pAlphaImag, @@ -555,27 +592,27 @@ EIGEN_ALWAYS_INLINE void gemmMMA_complex_cols( #define MAX_COMPLEX_MMA_UNROLL 4 while(row + MAX_COMPLEX_MMA_UNROLL*accCols <= rows) { - gemm_complex_unrolled_MMA_iteration(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag); + MICRO_COMPLEX_MMA_UNROLL_ITER2(MAX_COMPLEX_MMA_UNROLL, 0); } switch( (rows-row)/accCols ) { #if MAX_COMPLEX_MMA_UNROLL > 4 case 4: - gemm_complex_unrolled_MMA_iteration<4, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag); + MICRO_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 4) break; #endif #if MAX_COMPLEX_MMA_UNROLL > 3 case 3: - gemm_complex_unrolled_MMA_iteration<3, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag); + MICRO_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 3) break; #endif #if MAX_COMPLEX_MMA_UNROLL > 2 case 2: - gemm_complex_unrolled_MMA_iteration<2, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, 
rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag); + MICRO_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 2) break; #endif #if MAX_COMPLEX_MMA_UNROLL > 1 case 1: - gemm_complex_unrolled_MMA_iteration<1, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag); + MICRO_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 1) break; #endif default: @@ -585,7 +622,7 @@ EIGEN_ALWAYS_INLINE void gemmMMA_complex_cols( if(remaining_rows > 0) { - gemm_complex_extra_row(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask); + gemm_complex_extra_row(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask); } } @@ -604,13 +641,18 @@ void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsS const Scalar* blockA = (Scalar *) blockAc; const Scalar* blockB = (Scalar *) blockBc; + typedef typename std::conditional_t<(sizeof(Scalar) == sizeof(float)), RhsPacket, __vector_pair> RhsPacket2; + Index col = 0; for(; col + accRows <= cols; col += accRows) { - gemmMMA_complex_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask); + gemmMMA_complex_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask); } - gemm_complex_extra_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask); + if (col != cols) + { + gemm_complex_extra_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask); + } } #undef accColsC diff --git a/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h index b9b97313a..72596ebea 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h @@ -503,7 +503,11 @@ EIGEN_ALWAYS_INLINE Packet1cd pconj2(const Packet1cd& a) { /** \internal packet conjugate with real & imaginary operation inverted */ EIGEN_ALWAYS_INLINE Packet2cf pconjinv(const Packet2cf& a) { +#ifdef __POWER8_VECTOR__ + return Packet2cf(Packet4f(vec_neg(Packet2d(a.v)))); +#else return Packet2cf(pxor(a.v, reinterpret_cast(p16uc_COMPLEX32_CONJ_XOR2))); +#endif } EIGEN_ALWAYS_INLINE Packet1cd pconjinv(const Packet1cd& a) { @@ -555,12 +559,20 @@ EIGEN_ALWAYS_INLINE Packet1cd pcplxconjflip(Packet1cd a) /** \internal packet negate */ EIGEN_ALWAYS_INLINE Packet2cf pnegate2(Packet2cf a) { +#ifdef __POWER8_VECTOR__ + return Packet2cf(vec_neg(a.v)); +#else return Packet2cf(pxor(a.v, reinterpret_cast(p16uc_COMPLEX32_NEGATE))); +#endif } EIGEN_ALWAYS_INLINE Packet1cd pnegate2(Packet1cd a) { +#ifdef __POWER8_VECTOR__ + return Packet1cd(vec_neg(a.v)); +#else return Packet1cd(pxor(a.v, reinterpret_cast(p16uc_COMPLEX64_NEGATE))); +#endif } /** \internal flip the real & imaginary results and negate */ @@ -637,13 +649,24 @@ EIGEN_ALWAYS_INLINE void pload_realimag(RhsScalar* src, Packet2d& r, Packet2d& i #endif } +#ifndef __POWER8_VECTOR__ +const Packet16uc p16uc_MERGEE = { 0x00, 0x01, 0x02, 0x03, 0x10, 0x11, 0x12, 0x13, 0x08, 0x09, 0x0A, 0x0B, 0x18, 0x19, 0x1A, 0x1B }; + +const Packet16uc p16uc_MERGEO = { 0x04, 0x05, 0x06, 0x07, 0x14, 0x15, 0x16, 0x17, 0x0C, 
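pnegate2 (and, further down, pnegate itself) switches from XOR-ing the sign bits with a constant to vec_neg, which POWER8 provides as a single hardware negate; for IEEE floats both forms simply flip every lane's sign bit. A sketch of the two forms, with flip_sign as an illustrative name:

#include <altivec.h>

typedef __vector float        Packet4f;
typedef __vector unsigned int Packet4ui;

Packet4f flip_sign(Packet4f a)
{
#ifdef __POWER8_VECTOR__
  return vec_neg(a);                                         // single negate instruction
#else
  const Packet4ui sign = { 0x80000000u, 0x80000000u, 0x80000000u, 0x80000000u };
  return (Packet4f)vec_xor((Packet4ui)a, sign);              // classic AltiVec fallback
#endif
}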
0x0D, 0x0E, 0x0F, 0x1C, 0x1D, 0x1E, 0x1F }; +#endif + /** \internal load two vectors from the interleaved real & imaginary values of src */ template EIGEN_ALWAYS_INLINE void pload_realimag_row(RhsScalar* src, Packet4f& r, Packet4f& i) { Packet4f t = ploadu(reinterpret_cast(src)); +#ifdef __POWER8_VECTOR__ r = vec_mergee(t, t); i = vec_mergeo(t, t); +#else + r = vec_perm(t, t, p16uc_MERGEE); + i = vec_perm(t, t, p16uc_MERGEO); +#endif } template @@ -909,7 +932,7 @@ EIGEN_ALWAYS_INLINE void pstoreu_pmadd_complex(PResPacket& c0, PResPacket& c1, A { PResPacket c2 = pcplxflipconj(c0); PResPacket c3 = pcplxflipconj(c1); -#if EIGEN_COMP_LLVM +#if EIGEN_COMP_LLVM || !defined(_ARCH_PWR10) ScalarPacket c4 = pload_complex(res + (iter2 * ResPacketSize)); ScalarPacket c5 = pload_complex(res + ((iter2 + 1) * ResPacketSize)); PResPacket c6 = PResPacket(pmadd_complex(c0.v, c2.v, c4, b0)); diff --git a/Eigen/src/Core/arch/AltiVec/PacketMath.h b/Eigen/src/Core/arch/AltiVec/PacketMath.h index 3117f5659..563008c94 100755 --- a/Eigen/src/Core/arch/AltiVec/PacketMath.h +++ b/Eigen/src/Core/arch/AltiVec/PacketMath.h @@ -83,8 +83,10 @@ static EIGEN_DECLARE_CONST_FAST_Packet4i(MINUS16,-16); //{ -16, -16, -16, -16} static EIGEN_DECLARE_CONST_FAST_Packet4i(MINUS1,-1); //{ -1, -1, -1, -1} static EIGEN_DECLARE_CONST_FAST_Packet4ui(SIGN, 0x80000000u); static EIGEN_DECLARE_CONST_FAST_Packet4ui(PREV0DOT5, 0x3EFFFFFFu); +#ifndef __POWER8_VECTOR__ static EIGEN_DECLARE_CONST_FAST_Packet8us(ONE,1); //{ 1, 1, 1, 1, 1, 1, 1, 1} static EIGEN_DECLARE_CONST_FAST_Packet16uc(ONE,1); +#endif static Packet4f p4f_MZERO = (Packet4f) vec_sl((Packet4ui)p4i_MINUS1, (Packet4ui)p4i_MINUS1); //{ 0x80000000, 0x80000000, 0x80000000, 0x80000000} #ifndef __VSX__ static Packet4f p4f_ONE = vec_ctf(p4i_ONE, 0); //{ 1.0, 1.0, 1.0, 1.0} @@ -102,11 +104,13 @@ static Packet16uc p16uc_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7, static Packet16uc p16uc_REVERSE32 = { 12,13,14,15, 8,9,10,11, 4,5,6,7, 0,1,2,3 }; static Packet16uc p16uc_REVERSE16 = { 14,15, 12,13, 10,11, 8,9, 6,7, 4,5, 2,3, 0,1 }; +#ifndef _ARCH_PWR9 static Packet16uc p16uc_REVERSE8 = { 15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0 }; +#endif +#ifdef _BIG_ENDIAN static Packet16uc p16uc_DUPLICATE32_HI = { 0,1,2,3, 0,1,2,3, 4,5,6,7, 4,5,6,7 }; -static Packet16uc p16uc_DUPLICATE16_HI = { 0,1,0,1, 2,3,2,3, 4,5,4,5, 6,7,6,7 }; -static Packet16uc p16uc_DUPLICATE8_HI = { 0,0, 1,1, 2,2, 3,3, 4,4, 5,5, 6,6, 7,7 }; +#endif static const Packet16uc p16uc_DUPLICATE16_EVEN= { 0,1 ,0,1, 4,5, 4,5, 8,9, 8,9, 12,13, 12,13 }; static const Packet16uc p16uc_DUPLICATE16_ODD = { 2,3 ,2,3, 6,7, 6,7, 10,11, 10,11, 14,15, 14,15 }; @@ -116,15 +120,11 @@ static Packet16uc p16uc_QUADRUPLICATE16_HI = { 0,1,0,1,0,1,0,1, 2,3,2,3,2,3,2,3 // Define global static constants: #ifdef _BIG_ENDIAN static Packet16uc p16uc_FORWARD = vec_lvsl(0, (float*)0); -#ifdef __VSX__ -static Packet16uc p16uc_REVERSE64 = { 8,9,10,11, 12,13,14,15, 0,1,2,3, 4,5,6,7 }; -#endif static Packet16uc p16uc_PSET32_WODD = vec_sld((Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 0), (Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 2), 8);//{ 0,1,2,3, 0,1,2,3, 8,9,10,11, 8,9,10,11 }; static Packet16uc p16uc_PSET32_WEVEN = vec_sld(p16uc_DUPLICATE32_HI, (Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 3), 8);//{ 4,5,6,7, 4,5,6,7, 12,13,14,15, 12,13,14,15 }; static Packet16uc p16uc_HALF64_0_16 = vec_sld((Packet16uc)p4i_ZERO, vec_splat((Packet16uc) vec_abs(p4i_MINUS16), 3), 8); //{ 0,0,0,0, 0,0,0,0, 16,16,16,16, 16,16,16,16}; #else static Packet16uc p16uc_FORWARD = 
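pload_realimag_row splats the real and imaginary parts of two interleaved complex<float> values into separate vectors. vec_mergee/vec_mergeo only exist from POWER8 on, so the pre-POWER8 (big-endian) path expresses the same shuffle with vec_perm and the p16uc_MERGEE/p16uc_MERGEO masks added above. A sketch of what both branches compute; split_real_imag is an illustrative name:

#include <altivec.h>

typedef __vector float Packet4f;

// t = { r0, i0, r1, i1 }  (two interleaved std::complex<float>)
void split_real_imag(Packet4f t, Packet4f& r, Packet4f& i)
{
#ifdef __POWER8_VECTOR__
  r = vec_mergee(t, t);   // { r0, r0, r1, r1 }
  i = vec_mergeo(t, t);   // { i0, i0, i1, i1 }
#else
  const __vector unsigned char mergee =
    { 0x00,0x01,0x02,0x03, 0x10,0x11,0x12,0x13, 0x08,0x09,0x0A,0x0B, 0x18,0x19,0x1A,0x1B };
  const __vector unsigned char mergeo =
    { 0x04,0x05,0x06,0x07, 0x14,0x15,0x16,0x17, 0x0C,0x0D,0x0E,0x0F, 0x1C,0x1D,0x1E,0x1F };
  r = vec_perm(t, t, mergee);
  i = vec_perm(t, t, mergeo);
#endif
}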
p16uc_REVERSE32; -static Packet16uc p16uc_REVERSE64 = { 8,9,10,11, 12,13,14,15, 0,1,2,3, 4,5,6,7 }; static Packet16uc p16uc_PSET32_WODD = vec_sld((Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 1), (Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 3), 8);//{ 0,1,2,3, 0,1,2,3, 8,9,10,11, 8,9,10,11 }; static Packet16uc p16uc_PSET32_WEVEN = vec_sld((Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 0), (Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 2), 8);//{ 4,5,6,7, 4,5,6,7, 12,13,14,15, 12,13,14,15 }; static Packet16uc p16uc_HALF64_0_16 = vec_sld(vec_splat((Packet16uc) vec_abs(p4i_MINUS16), 0), (Packet16uc)p4i_ZERO, 8); //{ 0,0,0,0, 0,0,0,0, 16,16,16,16, 16,16,16,16}; @@ -137,12 +137,6 @@ static Packet16uc p16uc_TRANSPOSE64_LO = p16uc_PSET64_LO + p16uc_HALF64_0_16; static Packet16uc p16uc_COMPLEX32_REV = vec_sld(p16uc_REVERSE32, p16uc_REVERSE32, 8); //{ 4,5,6,7, 0,1,2,3, 12,13,14,15, 8,9,10,11 }; -#ifdef _BIG_ENDIAN -static Packet16uc p16uc_COMPLEX32_REV2 = vec_sld(p16uc_FORWARD, p16uc_FORWARD, 8); //{ 8,9,10,11, 12,13,14,15, 0,1,2,3, 4,5,6,7 }; -#else -static Packet16uc p16uc_COMPLEX32_REV2 = vec_sld(p16uc_PSET64_HI, p16uc_PSET64_LO, 8); //{ 8,9,10,11, 12,13,14,15, 0,1,2,3, 4,5,6,7 }; -#endif // _BIG_ENDIAN - #if EIGEN_HAS_BUILTIN(__builtin_prefetch) || EIGEN_COMP_GNUC #define EIGEN_PPC_PREFETCH(ADDR) __builtin_prefetch(ADDR); #else @@ -788,8 +782,22 @@ template<> EIGEN_STRONG_INLINE Packet8us psub (const Packet8us& a, template<> EIGEN_STRONG_INLINE Packet16c psub (const Packet16c& a, const Packet16c& b) { return a - b; } template<> EIGEN_STRONG_INLINE Packet16uc psub(const Packet16uc& a, const Packet16uc& b) { return a - b; } -template<> EIGEN_STRONG_INLINE Packet4f pnegate(const Packet4f& a) { return p4f_ZERO - a; } -template<> EIGEN_STRONG_INLINE Packet4i pnegate(const Packet4i& a) { return p4i_ZERO - a; } +template<> EIGEN_STRONG_INLINE Packet4f pnegate(const Packet4f& a) +{ +#ifdef __POWER8_VECTOR__ + return vec_neg(a); +#else + return p4f_ZERO - a; +#endif +} +template<> EIGEN_STRONG_INLINE Packet4i pnegate(const Packet4i& a) +{ +#ifdef __POWER8_VECTOR__ + return vec_neg(a); +#else + return p4i_ZERO - a; +#endif +} template<> EIGEN_STRONG_INLINE Packet4f pconj(const Packet4f& a) { return a; } template<> EIGEN_STRONG_INLINE Packet4i pconj(const Packet4i& a) { return a; } @@ -953,7 +961,10 @@ template<> EIGEN_STRONG_INLINE Packet4f print(const Packet4f& a) template EIGEN_STRONG_INLINE Packet ploadu_common(const __UNPACK_TYPE__(Packet)* from) { EIGEN_DEBUG_ALIGNED_LOAD -#ifdef _BIG_ENDIAN +#if defined(__VSX__) || !defined(_BIG_ENDIAN) + EIGEN_DEBUG_UNALIGNED_LOAD + return vec_xl(0, const_cast<__UNPACK_TYPE__(Packet)*>(from)); +#else Packet16uc MSQ, LSQ; Packet16uc mask; MSQ = vec_ld(0, (unsigned char *)from); // most significant quadword @@ -961,9 +972,6 @@ template EIGEN_STRONG_INLINE Packet ploadu_common(const __UNPAC mask = vec_lvsl(0, from); // create the permute mask //TODO: Add static_cast here return (Packet) vec_perm(MSQ, LSQ, mask); // align the data -#else - EIGEN_DEBUG_UNALIGNED_LOAD - return vec_xl(0, const_cast<__UNPACK_TYPE__(Packet)*>(from)); #endif } @@ -1001,7 +1009,7 @@ template EIGEN_STRONG_INLINE Packet ploaddup_common(const __UNP Packet p; if((std::ptrdiff_t(from) % 16) == 0) p = pload(from); else p = ploadu(from); - return vec_perm(p, p, p16uc_DUPLICATE32_HI); + return vec_mergeh(p, p); } template<> EIGEN_STRONG_INLINE Packet4f ploaddup(const float* from) { @@ -1017,7 +1025,7 @@ template<> EIGEN_STRONG_INLINE Packet8s ploaddup(const short int* Packet8s p; 
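ploaddup duplicates each element of the lower half of the packet; merging a vector with itself produces exactly that pattern, which is why the p16uc_DUPLICATE* permute constants can be dropped. Minimal illustration for the float case (the 16- and 8-bit overloads below use the same trick):

#include <altivec.h>

typedef __vector float Packet4f;

// p = { a, b, c, d }
Packet4f duplicate_lo(Packet4f p)
{
  return vec_mergeh(p, p);   // { a, a, b, b }
}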
if((std::ptrdiff_t(from) % 16) == 0) p = pload(from); else p = ploadu(from); - return vec_perm(p, p, p16uc_DUPLICATE16_HI); + return vec_mergeh(p, p); } template<> EIGEN_STRONG_INLINE Packet8us ploaddup(const unsigned short int* from) @@ -1025,7 +1033,7 @@ template<> EIGEN_STRONG_INLINE Packet8us ploaddup(const unsigned shor Packet8us p; if((std::ptrdiff_t(from) % 16) == 0) p = pload(from); else p = ploadu(from); - return vec_perm(p, p, p16uc_DUPLICATE16_HI); + return vec_mergeh(p, p); } template<> EIGEN_STRONG_INLINE Packet8s ploadquad(const short int* from) @@ -1054,7 +1062,7 @@ template<> EIGEN_STRONG_INLINE Packet16c ploaddup(const signed char* Packet16c p; if((std::ptrdiff_t(from) % 16) == 0) p = pload(from); else p = ploadu(from); - return vec_perm(p, p, p16uc_DUPLICATE8_HI); + return vec_mergeh(p, p); } template<> EIGEN_STRONG_INLINE Packet16uc ploaddup(const unsigned char* from) @@ -1062,13 +1070,15 @@ template<> EIGEN_STRONG_INLINE Packet16uc ploaddup(const unsigned ch Packet16uc p; if((std::ptrdiff_t(from) % 16) == 0) p = pload(from); else p = ploadu(from); - return vec_perm(p, p, p16uc_DUPLICATE8_HI); + return vec_mergeh(p, p); } template EIGEN_STRONG_INLINE void pstoreu_common(__UNPACK_TYPE__(Packet)* to, const Packet& from) { EIGEN_DEBUG_UNALIGNED_STORE -#ifdef _BIG_ENDIAN +#if defined(__VSX__) || !defined(_BIG_ENDIAN) + vec_xst(from, 0, to); +#else // Taken from http://developer.apple.com/hardwaredrivers/ve/alignment.html // Warning: not thread safe! Packet16uc MSQ, LSQ, edges; @@ -1083,8 +1093,6 @@ template EIGEN_STRONG_INLINE void pstoreu_common(__UNPACK_TYPE_ LSQ = vec_perm((Packet16uc)from,edges,align); // misalign the data (LSQ) vec_st( LSQ, 15, (unsigned char *)to ); // Store the LSQ part first vec_st( MSQ, 0, (unsigned char *)to ); // Store the MSQ part second -#else - vec_xst(from, 0, to); #endif } template<> EIGEN_STRONG_INLINE void pstoreu(float* to, const Packet4f& from) @@ -1164,11 +1172,19 @@ template<> EIGEN_STRONG_INLINE Packet8us preverse(const Packet8us& a) } template<> EIGEN_STRONG_INLINE Packet16c preverse(const Packet16c& a) { +#ifdef _ARCH_PWR9 + return vec_revb(a); +#else return vec_perm(a, a, p16uc_REVERSE8); +#endif } template<> EIGEN_STRONG_INLINE Packet16uc preverse(const Packet16uc& a) { +#ifdef _ARCH_PWR9 + return vec_revb(a); +#else return vec_perm(a, a, p16uc_REVERSE8); +#endif } template<> EIGEN_STRONG_INLINE Packet8bf preverse(const Packet8bf& a) { @@ -2102,7 +2118,11 @@ ptranspose(PacketBlock& kernel) { template EIGEN_STRONG_INLINE Packet pblend4(const Selector<4>& ifPacket, const Packet& thenPacket, const Packet& elsePacket) { Packet4ui select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3] }; +#ifdef __POWER8_VECTOR__ + Packet4ui mask = reinterpret_cast(vec_neg(reinterpret_cast(select))); +#else Packet4ui mask = reinterpret_cast(vec_cmpeq(reinterpret_cast(select), reinterpret_cast(p4i_ONE))); +#endif return vec_sel(elsePacket, thenPacket, mask); } @@ -2117,7 +2137,11 @@ template<> EIGEN_STRONG_INLINE Packet4f pblend(const Selector<4>& ifPacket, cons template<> EIGEN_STRONG_INLINE Packet8s pblend(const Selector<8>& ifPacket, const Packet8s& thenPacket, const Packet8s& elsePacket) { Packet8us select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3], ifPacket.select[4], ifPacket.select[5], ifPacket.select[6], ifPacket.select[7] }; +#ifdef __POWER8_VECTOR__ + Packet8us mask = reinterpret_cast(vec_neg(reinterpret_cast(select))); +#else Packet8us mask = 
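pblend needs an all-zeros/all-ones mask for vec_sel but receives a 0/1 selector. On POWER8 and later, negating the selector as a signed integer turns 1 into 0xFFFFFFFF directly, replacing the vec_cmpeq-against-one idiom. Sketch for the 4-lane case; blend4_sketch is an illustrative name:

#include <altivec.h>

typedef __vector unsigned int Packet4ui;
typedef __vector signed int   Packet4i;
typedef __vector float        Packet4f;

Packet4f blend4_sketch(Packet4ui select01, Packet4f thenP, Packet4f elseP)
{
#ifdef __POWER8_VECTOR__
  Packet4ui mask = (Packet4ui)vec_neg((Packet4i)select01);        // 1 -> 0xFFFFFFFF, 0 -> 0
#else
  const Packet4i one = { 1, 1, 1, 1 };
  Packet4ui mask = (Packet4ui)vec_cmpeq((Packet4i)select01, one); // pre-POWER8 fallback
#endif
  return vec_sel(elseP, thenP, mask);
}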
reinterpret_cast(vec_cmpeq(select, p8us_ONE)); +#endif Packet8s result = vec_sel(elsePacket, thenPacket, mask); return result; } @@ -2125,7 +2149,11 @@ template<> EIGEN_STRONG_INLINE Packet8s pblend(const Selector<8>& ifPacket, cons template<> EIGEN_STRONG_INLINE Packet8us pblend(const Selector<8>& ifPacket, const Packet8us& thenPacket, const Packet8us& elsePacket) { Packet8us select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3], ifPacket.select[4], ifPacket.select[5], ifPacket.select[6], ifPacket.select[7] }; +#ifdef __POWER8_VECTOR__ + Packet8us mask = reinterpret_cast(vec_neg(reinterpret_cast(select))); +#else Packet8us mask = reinterpret_cast(vec_cmpeq(reinterpret_cast(select), p8us_ONE)); +#endif return vec_sel(elsePacket, thenPacket, mask); } @@ -2139,7 +2167,11 @@ template<> EIGEN_STRONG_INLINE Packet16c pblend(const Selector<16>& ifPacket, co ifPacket.select[8], ifPacket.select[9], ifPacket.select[10], ifPacket.select[11], ifPacket.select[12], ifPacket.select[13], ifPacket.select[14], ifPacket.select[15] }; +#ifdef __POWER8_VECTOR__ + Packet16uc mask = reinterpret_cast(vec_neg(reinterpret_cast(select))); +#else Packet16uc mask = reinterpret_cast(vec_cmpeq(reinterpret_cast(select), p16uc_ONE)); +#endif return vec_sel(elsePacket, thenPacket, mask); } @@ -2149,7 +2181,11 @@ template<> EIGEN_STRONG_INLINE Packet16uc pblend(const Selector<16>& ifPacket, c ifPacket.select[8], ifPacket.select[9], ifPacket.select[10], ifPacket.select[11], ifPacket.select[12], ifPacket.select[13], ifPacket.select[14], ifPacket.select[15] }; +#ifdef __POWER8_VECTOR__ + Packet16uc mask = reinterpret_cast(vec_neg(reinterpret_cast(select))); +#else Packet16uc mask = reinterpret_cast(vec_cmpeq(reinterpret_cast(select), p16uc_ONE)); +#endif return vec_sel(elsePacket, thenPacket, mask); } @@ -2395,7 +2431,14 @@ template<> EIGEN_STRONG_INLINE Packet2d padd(const Packet2d& a, const template<> EIGEN_STRONG_INLINE Packet2d psub(const Packet2d& a, const Packet2d& b) { return a - b; } -template<> EIGEN_STRONG_INLINE Packet2d pnegate(const Packet2d& a) { return p2d_ZERO - a; } +template<> EIGEN_STRONG_INLINE Packet2d pnegate(const Packet2d& a) +{ +#ifdef __POWER8_VECTOR__ + return vec_neg(a); +#else + return p2d_ZERO - a; +#endif +} template<> EIGEN_STRONG_INLINE Packet2d pconj(const Packet2d& a) { return a; } @@ -2487,7 +2530,7 @@ template<> EIGEN_STRONG_INLINE double pfirst(const Packet2d& a) { EIG template<> EIGEN_STRONG_INLINE Packet2d preverse(const Packet2d& a) { - return reinterpret_cast(vec_perm(reinterpret_cast(a), reinterpret_cast(a), p16uc_REVERSE64)); + return vec_sld(a, a, 8); } template<> EIGEN_STRONG_INLINE Packet2d pabs(const Packet2d& a) { return vec_abs(a); } @@ -2692,8 +2735,8 @@ template<> EIGEN_STRONG_INLINE double predux_max(const Packet2d& a) EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { Packet2d t0, t1; - t0 = vec_perm(kernel.packet[0], kernel.packet[1], p16uc_TRANSPOSE64_HI); - t1 = vec_perm(kernel.packet[0], kernel.packet[1], p16uc_TRANSPOSE64_LO); + t0 = vec_mergeh(kernel.packet[0], kernel.packet[1]); + t1 = vec_mergel(kernel.packet[0], kernel.packet[1]); kernel.packet[0] = t0; kernel.packet[1] = t1; } diff --git a/Eigen/src/Core/util/BlasUtil.h b/Eigen/src/Core/util/BlasUtil.h index ea6b87699..f45665e3b 100755 --- a/Eigen/src/Core/util/BlasUtil.h +++ b/Eigen/src/Core/util/BlasUtil.h @@ -213,6 +213,11 @@ public: return ploadt(&operator()(i, j)); } + template + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Index j, const 
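The 2x2 double transpose (like its complex counterpart earlier in the patch) now uses vec_mergeh/vec_mergel instead of vec_perm with the TRANSPOSE64 masks: mergeh gathers the first element of each input, mergel the second. Data-movement sketch, assuming VSX for the double vectors:

#include <altivec.h>

typedef __vector double Packet2d;

// r0 = { a0, a1 }, r1 = { b0, b1 }  ->  r0 = { a0, b0 }, r1 = { a1, b1 }
void transpose2x2(Packet2d& r0, Packet2d& r1)
{
  Packet2d t0 = vec_mergeh(r0, r1);
  Packet2d t1 = vec_mergel(r0, r1);
  r0 = t0;
  r1 = t1;
}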
PacketType &p) const { + pstoret(&operator()(i, j), p); + } + template EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void scatterPacket(Index i, Index j, const SubPacket &p) const { pscatter(&operator()(i, j), p, m_stride); @@ -311,6 +316,11 @@ public: return pgather(&operator()(i, j),m_incr.value()); } + template + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Index j, const PacketType &p) const { + pscatter(&operator()(i, j), p, m_incr.value()); + } + template EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void scatterPacket(Index i, Index j, const SubPacket &p) const { pscatter(&operator()(i, j), p, m_stride);
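The new storePacket members mirror loadPacket: the contiguous blas_data_mapper can write a whole packet at once with pstoret, while the inner-stride variant has to scatter it, one element per stride step. A reduced sketch of the two paths built on Eigen's pstoret/pscatter; the wrapper structs are illustrative stand-ins, not the real mappers (which also account for the outer stride):

#include <Eigen/Core>

template<typename Scalar, typename Packet>
struct contiguous_writer {
  Scalar* data;
  void store(Eigen::Index i, const Packet& p) const {
    Eigen::internal::pstoret<Scalar, Packet, Eigen::Unaligned>(data + i, p);  // one vector store
  }
};

template<typename Scalar, typename Packet>
struct strided_writer {
  Scalar* data;
  Eigen::Index incr;                       // plays the role of m_incr.value()
  void store(Eigen::Index i, const Packet& p) const {
    Eigen::internal::pscatter<Scalar, Packet>(data + i * incr, p, incr);      // element-wise scatter
  }
};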