diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h index a40d4cbb0..4cc0a94ff 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h @@ -104,12 +104,6 @@ const static Packet16uc p16uc_GETIMAG32 = { 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; -const static Packet16uc p16uc_GETREAL64 = { 0, 1, 2, 3, 4, 5, 6, 7, - 16, 17, 18, 19, 20, 21, 22, 23}; - -//[a,ai],[b,bi] = [ai,bi] -const static Packet16uc p16uc_GETIMAG64 = { 8, 9, 10, 11, 12, 13, 14, 15, - 24, 25, 26, 27, 28, 29, 30, 31}; /********************************************* * Single precision real and complex packing * @@ -441,6 +435,7 @@ struct dhs_cpack { for(; j + vectorSize <= rows; j+=vectorSize) { + const DataMapper lhs2 = UseLhs ? lhs.getSubMapper(j, 0) : lhs.getSubMapper(0, j); Index i = 0; rii = rir + vectorDelta; @@ -451,9 +446,9 @@ struct dhs_cpack { PacketBlock cblock; if (UseLhs) { - bload(cblock, lhs, j, i); + bload(cblock, lhs2, 0, i); } else { - bload(cblock, lhs, i, j); + bload(cblock, lhs2, i, 0); } blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETREAL32); @@ -494,19 +489,19 @@ struct dhs_cpack { if(((StorageOrder == ColMajor) && UseLhs) || (((StorageOrder == RowMajor) && !UseLhs))) { if (UseLhs) { - cblock.packet[0] = lhs.template loadPacket(j + 0, i); - cblock.packet[1] = lhs.template loadPacket(j + 2, i); + cblock.packet[0] = lhs2.template loadPacket(0, i); + cblock.packet[1] = lhs2.template loadPacket(2, i); } else { - cblock.packet[0] = lhs.template loadPacket(i, j + 0); - cblock.packet[1] = lhs.template loadPacket(i, j + 2); + cblock.packet[0] = lhs2.template loadPacket(i, 0); + cblock.packet[1] = lhs2.template loadPacket(i, 2); } } else { if (UseLhs) { - cblock.packet[0] = pload2(lhs(j + 0, i), lhs(j + 1, i)); - cblock.packet[1] = pload2(lhs(j + 2, i), lhs(j + 3, i)); + cblock.packet[0] = pload2(lhs2(0, i), lhs2(1, i)); + cblock.packet[1] = 
pload2(lhs2(2, i), lhs2(3, i)); } else { - cblock.packet[0] = pload2(lhs(i, j + 0), lhs(i, j + 1)); - cblock.packet[1] = pload2(lhs(i, j + 2), lhs(i, j + 3)); + cblock.packet[0] = pload2(lhs2(i, 0), lhs2(i, 1)); + cblock.packet[1] = pload2(lhs2(i, 2), lhs2(i, 3)); } } @@ -534,16 +529,17 @@ struct dhs_cpack { for(; j < rows; j++) { + const DataMapper lhs2 = lhs.getSubMapper(0, j); rii = rir + ((PanelMode) ? stride : depth); for(Index i = 0; i < depth; i++) { - blockAt[rir] = lhs(i, j).real(); + blockAt[rir] = lhs2(i, 0).real(); if(Conjugate) - blockAt[rii] = -lhs(i, j).imag(); + blockAt[rii] = -lhs2(i, 0).imag(); else - blockAt[rii] = lhs(i, j).imag(); + blockAt[rii] = lhs2(i, 0).imag(); rir += 1; rii += 1; @@ -588,6 +584,7 @@ struct dhs_pack{ for(; j + vectorSize <= rows; j+=vectorSize) { + const DataMapper lhs2 = UseLhs ? lhs.getSubMapper(j, 0) : lhs.getSubMapper(0, j); Index i = 0; if(PanelMode) ri += vectorSize*offset; @@ -597,9 +594,9 @@ struct dhs_pack{ PacketBlock block; if (UseLhs) { - bload(block, lhs, j, i); + bload(block, lhs2, 0, i); } else { - bload(block, lhs, i, j); + bload(block, lhs2, i, 0); } if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs)) { @@ -615,22 +612,22 @@ struct dhs_pack{ if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs)) { if (UseLhs) { - blockA[ri+0] = lhs(j+0, i); - blockA[ri+1] = lhs(j+1, i); - blockA[ri+2] = lhs(j+2, i); - blockA[ri+3] = lhs(j+3, i); + blockA[ri+0] = lhs2(0, i); + blockA[ri+1] = lhs2(1, i); + blockA[ri+2] = lhs2(2, i); + blockA[ri+3] = lhs2(3, i); } else { - blockA[ri+0] = lhs(i, j+0); - blockA[ri+1] = lhs(i, j+1); - blockA[ri+2] = lhs(i, j+2); - blockA[ri+3] = lhs(i, j+3); + blockA[ri+0] = lhs2(i, 0); + blockA[ri+1] = lhs2(i, 1); + blockA[ri+2] = lhs2(i, 2); + blockA[ri+3] = lhs2(i, 3); } } else { Packet lhsV; if (UseLhs) { - lhsV = lhs.template loadPacket(j, i); + lhsV = lhs2.template loadPacket(0, i); } else { - lhsV = lhs.template 
loadPacket(i, j); + lhsV = lhs2.template loadPacket(i, 0); } pstore(blockA + ri, lhsV); } @@ -647,9 +644,10 @@ struct dhs_pack{ for(; j < rows; j++) { + const DataMapper lhs2 = lhs.getSubMapper(0, j); for(Index i = 0; i < depth; i++) { - blockA[ri] = lhs(i, j); + blockA[ri] = lhs2(i, 0); ri += 1; } @@ -685,6 +683,7 @@ struct dhs_pack for(; j + vectorSize <= rows; j+=vectorSize) { + const DataMapper lhs2 = lhs.getSubMapper(j, 0); Index i = 0; if(PanelMode) ri += vectorSize*offset; @@ -694,13 +693,13 @@ struct dhs_pack PacketBlock block; if(StorageOrder == RowMajor) { - block.packet[0] = lhs.template loadPacket(j + 0, i); - block.packet[1] = lhs.template loadPacket(j + 1, i); + block.packet[0] = lhs2.template loadPacket(0, i); + block.packet[1] = lhs2.template loadPacket(1, i); ptranspose(block); } else { - block.packet[0] = lhs.template loadPacket(j, i + 0); - block.packet[1] = lhs.template loadPacket(j, i + 1); + block.packet[0] = lhs2.template loadPacket(0, i + 0); + block.packet[1] = lhs2.template loadPacket(0, i + 1); } storeBlock(blockA + ri, block); @@ -711,10 +710,10 @@ struct dhs_pack { if(StorageOrder == RowMajor) { - blockA[ri+0] = lhs(j+0, i); - blockA[ri+1] = lhs(j+1, i); + blockA[ri+0] = lhs2(0, i); + blockA[ri+1] = lhs2(1, i); } else { - Packet2d lhsV = lhs.template loadPacket(j, i); + Packet2d lhsV = lhs2.template loadPacket(0, i); pstore(blockA + ri, lhsV); } @@ -752,6 +751,7 @@ struct dhs_pack for(; j + 2*vectorSize <= cols; j+=2*vectorSize) { + const DataMapper rhs2 = rhs.getSubMapper(0, j); Index i = 0; if(PanelMode) ri += offset*(2*vectorSize); @@ -762,10 +762,10 @@ struct dhs_pack if(StorageOrder == ColMajor) { PacketBlock block1, block2; - block1.packet[0] = rhs.template loadPacket(i, j + 0); - block1.packet[1] = rhs.template loadPacket(i, j + 1); - block2.packet[0] = rhs.template loadPacket(i, j + 2); - block2.packet[1] = rhs.template loadPacket(i, j + 3); + block1.packet[0] = rhs2.template loadPacket(i, 0); + block1.packet[1] = rhs2.template 
loadPacket(i, 1); + block2.packet[0] = rhs2.template loadPacket(i, 2); + block2.packet[1] = rhs2.template loadPacket(i, 3); ptranspose(block1); ptranspose(block2); @@ -775,10 +775,10 @@ struct dhs_pack pstore(blockB + ri + 4, block1.packet[1]); pstore(blockB + ri + 6, block2.packet[1]); } else { - block.packet[0] = rhs.template loadPacket(i + 0, j + 0); //[a1 a2] - block.packet[1] = rhs.template loadPacket(i + 0, j + 2); //[a3 a4] - block.packet[2] = rhs.template loadPacket(i + 1, j + 0); //[b1 b2] - block.packet[3] = rhs.template loadPacket(i + 1, j + 2); //[b3 b4] + block.packet[0] = rhs2.template loadPacket(i + 0, 0); //[a1 a2] + block.packet[1] = rhs2.template loadPacket(i + 0, 2); //[a3 a4] + block.packet[2] = rhs2.template loadPacket(i + 1, 0); //[b1 b2] + block.packet[3] = rhs2.template loadPacket(i + 1, 2); //[b3 b4] storeBlock(blockB + ri, block); } @@ -789,20 +789,20 @@ struct dhs_pack { if(StorageOrder == ColMajor) { - blockB[ri+0] = rhs(i, j+0); - blockB[ri+1] = rhs(i, j+1); + blockB[ri+0] = rhs2(i, 0); + blockB[ri+1] = rhs2(i, 1); ri += vectorSize; - blockB[ri+0] = rhs(i, j+2); - blockB[ri+1] = rhs(i, j+3); + blockB[ri+0] = rhs2(i, 2); + blockB[ri+1] = rhs2(i, 3); } else { - Packet2d rhsV = rhs.template loadPacket(i, j); + Packet2d rhsV = rhs2.template loadPacket(i, 0); pstore(blockB + ri, rhsV); ri += vectorSize; - rhsV = rhs.template loadPacket(i, j + 2); + rhsV = rhs2.template loadPacket(i, 2); pstore(blockB + ri, rhsV); } ri += vectorSize; @@ -815,9 +815,10 @@ struct dhs_pack for(; j < cols; j++) { + const DataMapper rhs2 = rhs.getSubMapper(0, j); for(Index i = 0; i < depth; i++) { - blockB[ri] = rhs(i, j); + blockB[ri] = rhs2(i, 0); ri += 1; } @@ -840,6 +841,7 @@ struct dhs_cpack(j, i + 0); //[a1 a1i] - cblock.packet[1] = lhs.template loadPacket(j, i + 1); //[b1 b1i] + cblock.packet[0] = lhs2.template loadPacket(0, i + 0); //[a1 a1i] + cblock.packet[1] = lhs2.template loadPacket(0, i + 1); //[b1 b1i] - cblock.packet[2] = lhs.template loadPacket(j 
+ 1, i + 0); //[a2 a2i] - cblock.packet[3] = lhs.template loadPacket(j + 1, i + 1); //[b2 b2i] + cblock.packet[2] = lhs2.template loadPacket(1, i + 0); //[a2 a2i] + cblock.packet[3] = lhs2.template loadPacket(1, i + 1); //[b2 b2i] - blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[2].v, p16uc_GETREAL64); //[a1 a2] - blockr.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2] + blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[2].v); //[a1 a2] + blockr.packet[1] = vec_mergeh(cblock.packet[1].v, cblock.packet[3].v); //[b1 b2] - blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[2].v, p16uc_GETIMAG64); - blocki.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[3].v, p16uc_GETIMAG64); + blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[2].v); + blocki.packet[1] = vec_mergel(cblock.packet[1].v, cblock.packet[3].v); } else { - cblock.packet[0] = lhs.template loadPacket(j + 0, i); //[a1 a1i] - cblock.packet[1] = lhs.template loadPacket(j + 1, i); //[a2 a2i] + cblock.packet[0] = lhs2.template loadPacket(0, i); //[a1 a1i] + cblock.packet[1] = lhs2.template loadPacket(1, i); //[a2 a2i] - cblock.packet[2] = lhs.template loadPacket(j + 0, i + 1); //[b1 b1i] - cblock.packet[3] = lhs.template loadPacket(j + 1, i + 1); //[b2 b2i + cblock.packet[2] = lhs2.template loadPacket(0, i + 1); //[b1 b1i] + cblock.packet[3] = lhs2.template loadPacket(1, i + 1); //[b2 b2i - blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64); //[a1 a2] - blockr.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2] + blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v); //[a1 a2] + blockr.packet[1] = vec_mergeh(cblock.packet[2].v, cblock.packet[3].v); //[b1 b2] - blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64); - blocki.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETIMAG64); 
+ blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v); + blocki.packet[1] = vec_mergel(cblock.packet[2].v, cblock.packet[3].v); } if(Conjugate) @@ -893,11 +895,11 @@ struct dhs_cpack blockr, blocki; PacketBlock cblock; - cblock.packet[0] = lhs.template loadPacket(j + 0, i); - cblock.packet[1] = lhs.template loadPacket(j + 1, i); + cblock.packet[0] = lhs2.template loadPacket(0, i); + cblock.packet[1] = lhs2.template loadPacket(1, i); - blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64); - blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64); + blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v); + blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v); if(Conjugate) { @@ -953,6 +955,7 @@ struct dhs_cpack cblock; PacketBlock blockr, blocki; - bload(cblock, rhs, i, j); + bload(cblock, rhs2, i, 0); - blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64); - blockr.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64); + blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v); + blockr.packet[1] = vec_mergeh(cblock.packet[2].v, cblock.packet[3].v); - blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64); - blocki.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETIMAG64); + blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v); + blocki.packet[1] = vec_mergel(cblock.packet[2].v, cblock.packet[3].v); if(Conjugate) { @@ -990,16 +993,17 @@ struct dhs_cpack predux_complex(ResPacket& a, ResPa GEMV_UNROLL_ROW(GEMV_INIT_ROW, N) \ Index j = 0; \ for (; j + LhsPacketSize <= cols; j += LhsPacketSize) { \ - RhsPacket a0 = rhs2.template load(j, 0); \ + RhsPacket a0 = rhs2.template load(j); \ GEMV_UNROLL_ROW(GEMV_WORK_ROW, N) \ } \ GEMV_UNROLL_ROW_HALF(GEMV_PREDUX2, (N >> 1)) \ for (; j < cols; ++j) { \ - RhsScalar a0 = rhs2(j, 0); \ + 
RhsScalar a0 = rhs2(j); \ GEMV_UNROLL_ROW_HALF(GEMV_MULT, (N >> 1)) \ } \ GEMV_UNROLL_ROW_HALF(GEMV_STORE_ROW, (N >> 1)) \ @@ -1965,7 +1965,7 @@ EIGEN_STRONG_INLINE void gemv_row( // The following copy tells the compiler that lhs's attributes are not modified outside this function // This helps GCC to generate proper code. LhsMapper lhs(alhs); - RhsMapper rhs2(rhs); + typename RhsMapper::LinearMapper rhs2 = rhs.getLinearMapper(0, 0); eigen_internal_assert(rhs.stride() == 1); conj_helper cj; @@ -2006,14 +2006,14 @@ EIGEN_STRONG_INLINE void gemv_row( Index j = 0; for (; j + LhsPacketSize <= cols; j += LhsPacketSize) { - RhsPacket b0 = rhs2.template load(j, 0); + RhsPacket b0 = rhs2.template load(j); d0 = pcj.pmadd(lhs.template load(i + 0, j), b0, d0); } ResScalar dd0 = predux(d0); for (; j < cols; ++j) { - dd0 += cj.pmul(lhs(i, j), rhs2(j, 0)); + dd0 += cj.pmul(lhs(i, j), rhs2(j)); } res[i * resIncr] += alpha * dd0; } @@ -2075,14 +2075,14 @@ EIGEN_ALWAYS_INLINE ScalarBlock predux_complex(PResPacket& a0, PRe #define GEMV_PROCESS_ROW_COMPLEX_SINGLE_WORK(which, N) \ j = 0; \ for (; j + LhsPacketSize <= cols; j += LhsPacketSize) { \ - const RhsScalar& b1 = rhs2(j, 0); \ + const RhsScalar& b1 = rhs2(j); \ RhsScalar* b = const_cast(&b1); \ GEMV_UNROLL_ROW(which, N) \ } #define GEMV_PROCESS_END_ROW_COMPLEX(N) \ for (; j < cols; ++j) { \ - RhsScalar b0 = rhs2(j, 0); \ + RhsScalar b0 = rhs2(j); \ GEMV_UNROLL_ROW_HALF(GEMV_MULT_COMPLEX, (N >> 1)) \ } \ GEMV_UNROLL_ROW_HALF(GEMV_STORE_ROW_COMPLEX, (N >> 1)) @@ -2216,7 +2216,7 @@ EIGEN_ALWAYS_INLINE ScalarBlock predux_complex(PResPacket& a0, PRe GEMV_UNROLL_ROW(GEMV_INIT_COMPLEX_OLD, N) \ j = 0; \ for (; j + LhsPacketSize <= cols; j += LhsPacketSize) { \ - RhsPacket b0 = rhs2.template load(j, 0); \ + RhsPacket b0 = rhs2.template load(j); \ GEMV_UNROLL_ROW(GEMV_WORK_ROW_COMPLEX_OLD, N) \ } @@ -2289,7 +2289,7 @@ EIGEN_STRONG_INLINE void gemv_complex_row( // The following copy tells the compiler that lhs's attributes are not 
modified outside this function // This helps GCC to generate proper code. LhsMapper lhs(alhs); - RhsMapper rhs2(rhs); + typename RhsMapper::LinearMapper rhs2 = rhs.getLinearMapper(0, 0); eigen_internal_assert(rhs.stride() == 1); conj_helper cj; @@ -2340,7 +2340,7 @@ EIGEN_STRONG_INLINE void gemv_complex_row( GEMV_PROCESS_ROW_COMPLEX_PREDUX(0) for (; j < cols; ++j) { - dd0 += cj.pmul(lhs(i, j), rhs2(j, 0)); + dd0 += cj.pmul(lhs(i, j), rhs2(j)); } res[i * resIncr] += alpha * dd0; } diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h index 227d4f30f..92cbaf6ff 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h @@ -416,6 +416,7 @@ class TensorContractionSubMapper { typedef BaseTensorContractionMapper ParentMapper; typedef TensorContractionSubMapper Self; typedef Self LinearMapper; + typedef Self SubMapper; enum { // We can use direct offsets iff the parent mapper supports then and we can compute the strides. 
@@ -485,6 +486,13 @@ class TensorContractionSubMapper { return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset); } + /* Builds a SubMapper over m_base_mapper anchored at (i, j); uses the same offset arithmetic as the LinearMapper return visible just above. */ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubMapper getSubMapper(Index i, Index j) const { + if (UseDirectOffsets) { + /* Direct-offset mode: (i, j) are forwarded unchanged. NOTE(review): presumably the offsets are already folded into the base mapper in this mode; confirm. */ return SubMapper(m_base_mapper, i, j); + } + /* Otherwise shift the requested origin by this sub-mapper's own vertical and horizontal offsets. */ return SubMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset); + } + template EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const { EIGEN_STATIC_ASSERT((internal::is_same::value), YOU_MADE_A_PROGRAMMING_MISTAKE); @@ -531,6 +539,7 @@ class TensorContractionInputMapper typedef BaseTensorContractionMapper Base; typedef TensorContractionSubMapper SubMapper; typedef SubMapper VectorMapper; + typedef SubMapper LinearMapper; EIGEN_DEVICE_FUNC TensorContractionInputMapper(const Tensor& tensor, const nocontract_t& nocontract_strides, @@ -544,6 +553,10 @@ class TensorContractionInputMapper return SubMapper(*this, i, j); } + /* Returns a LinearMapper (a typedef of SubMapper in this class) constructed from this mapper and (i, j); same construction as the neighbouring SubMapper and VectorMapper accessors. */ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const { + return LinearMapper(*this, i, j); + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const { return VectorMapper(*this, i, j); }