Fix ColMajor BF16 GEMV for when vector is RowMajor

Chip Kerchner 2023-05-03 20:12:50 +00:00 committed by Rasmus Munk Larsen
parent fdc749de2a
commit fda1373a15
4 changed files with 104 additions and 65 deletions
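
What changed, in brief: the ColMajor GEMV kernels used to load the right-hand-side vector with a linear loadPacket, which is only correct when the vector's elements are contiguous in memory (stride 1). When the vector is RowMajor its elements sit stride() apart, so they must be gathered instead. The commit threads a compile-time `linear` flag through the kernels and adds a `loadColData` helper, selected via `std::enable_if_t`, that picks between a contiguous load and a strided gather. Below is a minimal, self-contained sketch of that dispatch pattern; `Packet8` and `Mapper` are illustrative stand-ins for Eigen's packet and mapper types, simplified to plain float.

    // Sketch of the loadColData dispatch. Build: g++ -std=c++17 load_dispatch.cpp
    #include <cstdio>
    #include <type_traits>

    struct Packet8 { float v[8]; };            // stand-in for Packet8bf

    struct Mapper {                            // stand-in for Eigen's RHS mapper
      const float* data;
      long stride_;
      long stride() const { return stride_; }
      const float& operator()(long row, long col) const { return data[row * stride_ + col]; }
      template <typename P>
      P loadPacket(long j) const {             // contiguous load; valid only when stride_ == 1
        P p;
        for (int i = 0; i < 8; i++) p.v[i] = data[j + i];
        return p;
      }
    };

    // linear == true: the vector is contiguous, one straight packet load.
    template <bool linear, std::enable_if_t<linear, bool> = true>
    Packet8 loadColData(const Mapper& rhs, long j) {
      return rhs.loadPacket<Packet8>(j);
    }

    // linear == false: elements sit stride() apart and must be gathered.
    template <bool linear, std::enable_if_t<!linear, bool> = true>
    Packet8 loadColData(const Mapper& rhs, long j) {
      Packet8 p;
      for (int i = 0; i < 8; i++) p.v[i] = rhs(j + i, 0);
      return p;
    }

    int main() {
      float buf[64];
      for (int i = 0; i < 64; i++) buf[i] = float(i);
      Mapper colMajorVec{buf, 1};              // contiguous: old code was fine here
      Mapper rowMajorVec{buf, 4};              // strided: old code read the wrong elements
      Packet8 a = loadColData<true>(colMajorVec, 0);   // 0,1,...,7
      Packet8 b = loadColData<false>(rowMajorVec, 0);  // 0,4,...,28
      std::printf("%g %g\n", a.v[7], b.v[7]);          // prints: 7 28
    }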

[File 1 of 4: bfloat16 GEMV helper declarations]

@@ -102,6 +102,9 @@ EIGEN_ALWAYS_INLINE void outputVecColResults(Packet4f (&acc)[num_acc][size], flo
 template<Index num_acc, Index size = 4>
 EIGEN_ALWAYS_INLINE void outputVecResults(Packet4f (&acc)[num_acc][size], float *result, Packet4f pAlpha);
 
+template<typename RhsMapper, bool linear, std::enable_if_t<linear, bool>>
+EIGEN_ALWAYS_INLINE Packet8bf loadColData(RhsMapper& rhs, Index j);
+
 template<typename Packet>
 EIGEN_ALWAYS_INLINE Packet ploadLhs(const __UNPACK_TYPE__(Packet)* lhs);

[File 2 of 4: MMA bfloat16 GEMV kernels (gemvMMA_bfloat16_col)]

@@ -383,12 +383,12 @@ EIGEN_ALWAYS_INLINE void multVec(__vector_quad (&quad_acc)[num_acc], Packet8bf (
   }
 }
 
-template<Index num_acc, typename LhsMapper, typename RhsMapper, bool zero>
+template<Index num_acc, typename LhsMapper, typename RhsMapper, bool zero, bool linear>
 EIGEN_ALWAYS_INLINE void vecColLoop(Index j, LhsMapper& lhs, RhsMapper& rhs, __vector_quad (&quad_acc)[num_acc])
 {
   Packet8bf a0[num_acc];
   Packet8bf b1 = pset1<Packet8bf>(Eigen::bfloat16(0));
-  Packet8bf b0 = rhs.template loadPacket<Packet8bf>(j + 0);
+  Packet8bf b0 = loadColData<RhsMapper, linear>(rhs, j);
 
   if (zero) {
     b0 = vec_mergeh(b0.m_val, b1.m_val);
@@ -405,7 +405,7 @@ EIGEN_ALWAYS_INLINE void vecColLoop(Index j, LhsMapper& lhs, RhsMapper& rhs, __v
 
 #define MAX_BFLOAT16_VEC_ACC 8
 
-template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows>
+template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
 void colVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
 {
   constexpr Index step = (num_acc * 4);
@@ -420,10 +420,10 @@ void colVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMa
 
     LhsMapper lhs2 = lhs.getSubMapper(row, 0);
     for(Index j = 0; j + 2 <= cend; j += 2) {
-      vecColLoop<num_acc, LhsMapper, RhsMapper, false>(j, lhs2, rhs, quad_acc);
+      vecColLoop<num_acc, LhsMapper, RhsMapper, false, linear>(j, lhs2, rhs, quad_acc);
     }
     if (cend & 1) {
-      vecColLoop<num_acc, LhsMapper, RhsMapper, true>(cend - 1, lhs2, rhs, quad_acc);
+      vecColLoop<num_acc, LhsMapper, RhsMapper, true, linear>(cend - 1, lhs2, rhs, quad_acc);
     }
 
     disassembleAccumulators<num_acc>(quad_acc, acc);
@@ -434,59 +434,59 @@ void colVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMa
   } while(multiIters && (step <= rows - (row += step)));
 }
 
-template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows>
+template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
 EIGEN_ALWAYS_INLINE void colVecColLoopBodyExtraN(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
 {
   if (MAX_BFLOAT16_VEC_ACC > num_acc) {
-    colVecColLoopBody<num_acc + (extraRows ? 1 : 0), LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVecColLoopBody<num_acc + (extraRows ? 1 : 0), LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
   }
 }
 
-template<typename LhsMapper, typename RhsMapper, bool extraRows>
+template<typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
 EIGEN_ALWAYS_INLINE void colVecColLoopBodyExtra(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
 {
   switch ((rows - row) >> 2) {
     case 7:
-      colVecColLoopBodyExtraN<7, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+      colVecColLoopBodyExtraN<7, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 6:
-      colVecColLoopBodyExtraN<6, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+      colVecColLoopBodyExtraN<6, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 5:
-      colVecColLoopBodyExtraN<5, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+      colVecColLoopBodyExtraN<5, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 4:
-      colVecColLoopBodyExtraN<4, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+      colVecColLoopBodyExtraN<4, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 3:
-      colVecColLoopBodyExtraN<3, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+      colVecColLoopBodyExtraN<3, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 2:
-      colVecColLoopBodyExtraN<2, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+      colVecColLoopBodyExtraN<2, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 1:
-      colVecColLoopBodyExtraN<1, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+      colVecColLoopBodyExtraN<1, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    default:
      if (extraRows) {
-        colVecColLoopBody<1, LhsMapper, RhsMapper, true>(row, cend, rows, lhs, rhs, pAlpha, result);
+        colVecColLoopBody<1, LhsMapper, RhsMapper, true, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      }
      break;
  }
 }
 
-template<typename LhsMapper, typename RhsMapper>
+template<typename LhsMapper, typename RhsMapper, bool linear>
 EIGEN_ALWAYS_INLINE void calcVecColLoops(Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
 {
   Index row = 0;
   if (rows >= (MAX_BFLOAT16_VEC_ACC * 4)) {
-    colVecColLoopBody<MAX_BFLOAT16_VEC_ACC, LhsMapper, RhsMapper, false>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVecColLoopBody<MAX_BFLOAT16_VEC_ACC, LhsMapper, RhsMapper, false, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
     result += row;
   }
   if (rows & 3) {
-    colVecColLoopBodyExtra<LhsMapper, RhsMapper, true>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVecColLoopBodyExtra<LhsMapper, RhsMapper, true, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
   } else {
-    colVecColLoopBodyExtra<LhsMapper, RhsMapper, false>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVecColLoopBodyExtra<LhsMapper, RhsMapper, false, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
   }
 }
@@ -524,8 +524,13 @@ void gemvMMA_bfloat16_col(
     Index jend = numext::mini(j2 + block_cols, cols);
 
     LhsMapper lhs2 = lhs.getSubMapper(0, j2);
-    LinearMapper rhs3 = rhs2.getLinearMapper(j2, 0);
-    calcVecColLoops<LhsMapper, LinearMapper>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
+    if (rhs.stride() == 1) {
+      LinearMapper rhs3 = rhs2.getLinearMapper(j2, 0);
+      calcVecColLoops<LhsMapper, LinearMapper, true>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
+    } else {
+      RhsMapper rhs3 = rhs2.getSubMapper(j2, 0);
+      calcVecColLoops<LhsMapper, RhsMapper, false>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
+    }
   }
 
   convertArrayPointerF32toBF16(result, rows, res);
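
Note how gemvMMA_bfloat16_col (and its VSX twin in the next file) checks rhs.stride() once, outside the row/column loops, and commits to either the LinearMapper/`linear=true` instantiation or the generic-mapper/`linear=false` one, so the hot loops themselves carry no per-element branch. A hedged sketch of that runtime-to-compile-time hand-off, with hypothetical names:

    #include <cstdio>

    // The kernel is stamped out twice; the bool is folded away at compile time.
    template <bool linear>
    void calcLoops(const float* rhs, long stride, long n) {
      for (long j = 0; j < n; j++) {
        // 'linear' is a constant inside each instantiation, so each copy
        // compiles down to a single addressing mode.
        float x = linear ? rhs[j] : rhs[j * stride];
        std::printf("%g ", x);
      }
      std::printf("\n");
    }

    void gemvCol(const float* rhs, long stride, long n) {
      if (stride == 1) {
        calcLoops<true>(rhs, stride, n);   // contiguous fast path
      } else {
        calcLoops<false>(rhs, stride, n);  // strided (RowMajor vector) path
      }
    }

    int main() {
      float v[8] = {0, 1, 2, 3, 4, 5, 6, 7};
      gemvCol(v, 1, 4);  // 0 1 2 3
      gemvCol(v, 2, 4);  // 0 2 4 6
    }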

[File 3 of 4: VSX bfloat16 GEMV kernels (gemv_bfloat16_col)]

@@ -521,11 +521,23 @@ EIGEN_ALWAYS_INLINE void multVecVSX(Packet4f (&acc)[num_acc][2], Packet4f (&a0)[
   }
 }
 
-template<Index num_acc, typename LhsMapper, typename RhsMapper, bool zero>
+template<typename RhsMapper, bool linear, std::enable_if_t<linear, bool> = true>
+EIGEN_ALWAYS_INLINE Packet8bf loadColData(RhsMapper& rhs, Index j)
+{
+  return rhs.template loadPacket<Packet8bf>(j + 0);
+}
+
+template<typename RhsMapper, bool linear, std::enable_if_t<!linear, bool> = true>
+EIGEN_ALWAYS_INLINE Packet8bf loadColData(RhsMapper& rhs, Index j)
+{
+  return pgather<bfloat16, Packet8bf>(&rhs(j + 0, 0), rhs.stride());
+}
+
+template<Index num_acc, typename LhsMapper, typename RhsMapper, bool zero, bool linear>
 EIGEN_ALWAYS_INLINE void vecColLoopVSX(Index j, LhsMapper& lhs, RhsMapper& rhs, Packet4f (&acc)[num_acc][2])
 {
   Packet4f a0[num_acc][2], b0[2];
-  Packet8bf b2 = rhs.template loadPacket<Packet8bf>(j + 0);
+  Packet8bf b2 = loadColData<RhsMapper, linear>(rhs, j);
 
   b0[0] = oneConvertBF16Perm(b2.m_val, p16uc_MERGE16_32_V1);
   if (!zero) {
@@ -551,7 +563,7 @@ EIGEN_ALWAYS_INLINE void addResultsVSX(Packet4f (&acc)[num_acc][2])
 // Uses 2X the accumulators or 4X the number of VSX registers
 #define MAX_BFLOAT16_VEC_ACC_VSX 8
 
-template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows>
+template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
 void colVSXVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
 {
   constexpr Index step = (num_acc * 4);
@@ -565,10 +577,10 @@ void colVSXVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, Rh
 
     LhsMapper lhs2 = lhs.getSubMapper(row, 0);
     for(Index j = 0; j + 2 <= cend; j += 2) {
-      vecColLoopVSX<num_acc, LhsMapper, RhsMapper, false>(j, lhs2, rhs, acc);
+      vecColLoopVSX<num_acc, LhsMapper, RhsMapper, false, linear>(j, lhs2, rhs, acc);
     }
     if (cend & 1) {
-      vecColLoopVSX<num_acc, LhsMapper, RhsMapper, true>(cend - 1, lhs2, rhs, acc);
+      vecColLoopVSX<num_acc, LhsMapper, RhsMapper, true, linear>(cend - 1, lhs2, rhs, acc);
     }
 
     addResultsVSX<num_acc>(acc);
@@ -579,59 +591,59 @@ void colVSXVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, Rh
   } while(multiIters && (step <= rows - (row += step)));
 }
 
-template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows>
+template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
 EIGEN_ALWAYS_INLINE void colVSXVecColLoopBodyExtraN(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
 {
   if (MAX_BFLOAT16_VEC_ACC_VSX > num_acc) {
-    colVSXVecColLoopBody<num_acc + (extraRows ? 1 : 0), LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVSXVecColLoopBody<num_acc + (extraRows ? 1 : 0), LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
   }
 }
 
-template<typename LhsMapper, typename RhsMapper, bool extraRows>
+template<typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
 EIGEN_ALWAYS_INLINE void colVSXVecColLoopBodyExtra(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
 {
   switch ((rows - row) >> 2) {
    case 7:
-      colVSXVecColLoopBodyExtraN<7, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+      colVSXVecColLoopBodyExtraN<7, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 6:
-      colVSXVecColLoopBodyExtraN<6, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+      colVSXVecColLoopBodyExtraN<6, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 5:
-      colVSXVecColLoopBodyExtraN<5, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+      colVSXVecColLoopBodyExtraN<5, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 4:
-      colVSXVecColLoopBodyExtraN<4, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+      colVSXVecColLoopBodyExtraN<4, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 3:
-      colVSXVecColLoopBodyExtraN<3, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+      colVSXVecColLoopBodyExtraN<3, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 2:
-      colVSXVecColLoopBodyExtraN<2, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+      colVSXVecColLoopBodyExtraN<2, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 1:
-      colVSXVecColLoopBodyExtraN<1, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+      colVSXVecColLoopBodyExtraN<1, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    default:
      if (extraRows) {
-        colVSXVecColLoopBody<1, LhsMapper, RhsMapper, true>(row, cend, rows, lhs, rhs, pAlpha, result);
+        colVSXVecColLoopBody<1, LhsMapper, RhsMapper, true, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      }
      break;
  }
 }
 
-template<typename LhsMapper, typename RhsMapper>
+template<typename LhsMapper, typename RhsMapper, bool linear>
 EIGEN_ALWAYS_INLINE void calcVSXVecColLoops(Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
 {
   Index row = 0;
   if (rows >= (MAX_BFLOAT16_VEC_ACC_VSX * 4)) {
-    colVSXVecColLoopBody<MAX_BFLOAT16_VEC_ACC_VSX, LhsMapper, RhsMapper, false>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVSXVecColLoopBody<MAX_BFLOAT16_VEC_ACC_VSX, LhsMapper, RhsMapper, false, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
     result += row;
   }
   if (rows & 3) {
-    colVSXVecColLoopBodyExtra<LhsMapper, RhsMapper, true>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVSXVecColLoopBodyExtra<LhsMapper, RhsMapper, true, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
   } else {
-    colVSXVecColLoopBodyExtra<LhsMapper, RhsMapper, false>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVSXVecColLoopBodyExtra<LhsMapper, RhsMapper, false, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
  }
 }
@@ -724,8 +736,13 @@ void gemv_bfloat16_col(
     Index jend = numext::mini(j2 + block_cols, cols);
 
     LhsMapper lhs2 = lhs.getSubMapper(0, j2);
-    LinearMapper rhs3 = rhs2.getLinearMapper(j2, 0);
-    calcVSXVecColLoops<LhsMapper, LinearMapper>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
+    if (rhs.stride() == 1) {
+      LinearMapper rhs3 = rhs2.getLinearMapper(j2, 0);
+      calcVSXVecColLoops<LhsMapper, LinearMapper, true>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
+    } else {
+      RhsMapper rhs3 = rhs2.getSubMapper(j2, 0);
+      calcVSXVecColLoops<LhsMapper, RhsMapper, false>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
+    }
   }
 
   convertArrayPointerF32toBF16VSX(result, rows, res);

[File 4 of 4: AltiVec PacketMath load/store/gather/scatter helpers]

@@ -796,12 +796,20 @@ template<typename Packet> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet pgather_c
 {
   EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) a[unpacket_traits<Packet>::size];
   eigen_internal_assert(n <= unpacket_traits<Packet>::size && "number of elements will gather past end of packet");
-  LOAD_STORE_UNROLL_16
-  for (Index i = 0; i < n; i++) {
-    a[i] = from[i*stride];
+  if (stride == 1) {
+    if (n == unpacket_traits<Packet>::size) {
+      return ploadu<Packet>(from);
+    } else {
+      return ploadu_partial<Packet>(from, n);
+    }
+  } else {
+    LOAD_STORE_UNROLL_16
+    for (Index i = 0; i < n; i++) {
+      a[i] = from[i*stride];
+    }
+    // Leave rest of the array uninitialized
+    return pload_ignore<Packet>(a);
   }
-  // Leave rest of the array uninitialized
-  return pload_ignore<Packet>(a);
 }
 
 template<> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet4f pgather<float, Packet4f>(const float* from, Index stride)
@@ -878,10 +886,18 @@ template<typename Packet> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void pscatter_co
 {
   EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) a[unpacket_traits<Packet>::size];
   eigen_internal_assert(n <= unpacket_traits<Packet>::size && "number of elements will scatter past end of packet");
-  pstore<__UNPACK_TYPE__(Packet)>(a, from);
-  LOAD_STORE_UNROLL_16
-  for (Index i = 0; i < n; i++) {
-    to[i*stride] = a[i];
+  if (stride == 1) {
+    if (n == unpacket_traits<Packet>::size) {
+      return pstoreu(to, from);
+    } else {
+      return pstoreu_partial(to, from, n);
+    }
+  } else {
+    pstore<__UNPACK_TYPE__(Packet)>(a, from);
+    LOAD_STORE_UNROLL_16
+    for (Index i = 0; i < n; i++) {
+      to[i*stride] = a[i];
+    }
   }
 }
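
Both pgather_common and pscatter_common gain a stride==1 fast path: a unit-stride gather or scatter is just an unaligned load or store, so the per-element loop and the round-trip through an aligned stack buffer can be skipped. This matters here because the new non-linear loadColData overload funnels through pgather with the RHS stride. A plain-C++ sketch of the idea, simplified to float with hypothetical helper names:

    #include <cstdio>
    #include <cstring>

    // Gather n floats that sit 'stride' apart; unit stride degenerates to a
    // contiguous copy, so the element loop (and, in the real vector code, the
    // staging buffer) can be skipped.
    void gather(float* out, const float* from, long stride, int n) {
      if (stride == 1) {
        std::memcpy(out, from, n * sizeof(float));  // fast path: one block copy
      } else {
        for (int i = 0; i < n; i++) out[i] = from[i * stride];  // true gather
      }
    }

    // Scatter is the mirror image.
    void scatter(float* to, const float* in, long stride, int n) {
      if (stride == 1) {
        std::memcpy(to, in, n * sizeof(float));
      } else {
        for (int i = 0; i < n; i++) to[i * stride] = in[i];
      }
    }

    int main() {
      float src[8] = {0, 1, 2, 3, 4, 5, 6, 7}, dst[8] = {0};
      gather(dst, src, 2, 4);                  // dst = 0 2 4 6
      scatter(src, dst, 1, 4);                 // contiguous: block copy back
      std::printf("%g %g\n", dst[3], src[3]);  // 6 6
    }
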
@@ -1256,15 +1272,14 @@ template<typename Packet> EIGEN_ALWAYS_INLINE Packet ploadu_partial_common(const
   return vec_xl_len(const_cast<__UNPACK_TYPE__(Packet)*>(from), n * size);
 #else
   if (n) {
+    Index n2 = n * size;
+    if (16 <= n2) {
+      return ploadu<Packet>(from);
+    }
     EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) load[packet_size];
     unsigned char* load2 = reinterpret_cast<unsigned char *>(load);
     unsigned char* from2 = reinterpret_cast<unsigned char *>(const_cast<__UNPACK_TYPE__(Packet)*>(from));
-    Index n2 = n * size;
-    if (16 <= n2) {
-      pstore(load2, ploadu<Packet16uc>(from2));
-    } else {
-      memcpy((void *)load2, (void *)from2, n2);
-    }
+    memcpy((void *)load2, (void *)from2, n2);
     return pload_ignore<Packet>(load);
   } else {
     return Packet(pset1<Packet16uc>(0));
@@ -1432,16 +1447,15 @@ template<typename Packet> EIGEN_ALWAYS_INLINE void pstoreu_partial_common(__UNPA
   vec_xst_len(from, to, n * size);
 #else
   if (n) {
+    Index n2 = n * size;
+    if (16 <= n2) {
+      pstoreu(to, from);
+    }
     EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) store[packet_size];
     pstore(store, from);
     unsigned char* store2 = reinterpret_cast<unsigned char *>(store);
     unsigned char* to2 = reinterpret_cast<unsigned char *>(to);
-    Index n2 = n * size;
-    if (16 <= n2) {
-      pstoreu(to2, pload<Packet16uc>(store2));
-    } else {
-      memcpy((void *)to2, (void *)store2, n2);
-    }
+    memcpy((void *)to2, (void *)store2, n2);
   }
 #endif
 }
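
The partial load/store helpers get the same treatment: when n * size fills the whole 16-byte register, the new code issues a single unaligned ploadu/pstoreu directly on the user pointer instead of staging the data through an aligned scratch buffer; only genuinely partial packets still take the memcpy path. The sketch below contrasts the old staging shape with the new early-out, using plain floats and a hypothetical four-float packet as stand-ins:

    #include <cstring>

    struct Packet4 { float v[4]; };

    // Old shape: always fill an aligned scratch buffer, then load out of it.
    Packet4 load_partial_old(const float* from, int n) {
      alignas(16) float scratch[4] = {0, 0, 0, 0};
      std::memcpy(scratch, from, n * sizeof(float));  // staged copy
      Packet4 p;
      std::memcpy(&p, scratch, 16);                   // second copy out of scratch
      return p;
    }

    // New shape: a full packet bypasses the scratch buffer entirely.
    Packet4 load_partial_new(const float* from, int n) {
      if (n * sizeof(float) >= 16) {
        Packet4 p;
        std::memcpy(&p, from, 16);                    // direct unaligned load
        return p;
      }
      alignas(16) float scratch[4] = {0, 0, 0, 0};
      std::memcpy(scratch, from, n * sizeof(float));  // partial: stage only n elements
      Packet4 p;
      std::memcpy(&p, scratch, 16);
      return p;
    }

    int main() {
      float buf[4] = {1, 2, 3, 4};
      // Both shapes agree on a full packet; exits 0.
      return int(load_partial_old(buf, 4).v[3]) - int(load_partial_new(buf, 4).v[3]);
    }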