mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-13 04:09:10 +08:00
WIP
This commit is contained in:
parent
421891e1db
commit
5bffe09624
@ -560,7 +560,7 @@ EIGEN_STRONG_INLINE void gemm_old(const DataMapper& res, const LhsScalar* blockA
|
||||
#endif
|
||||
|
||||
template<int Architecture, int CPU, typename LhsScalar, typename RhsScalar>
|
||||
constexpr int SHAPES_COUNT = 2;
|
||||
constexpr int SHAPES_COUNT = 4;
|
||||
|
||||
constexpr int SHAPES_DIMENSION = 4;
|
||||
constexpr int SHAPES_LHS_DIMENSION = 0;
|
||||
@ -578,7 +578,7 @@ constexpr int PACK_SHAPES_END = -1;
|
||||
|
||||
// lhs_progress x depth_progress x rhs_progress (depth_progress > 1 matrix ops) x pointer to next rhs_progress on the shapes map
|
||||
template<int Architecture, int CPU, typename LhsScalar, typename RhsScalar>
|
||||
constexpr int SHAPES[SHAPES_COUNT<Architecture, CPU, LhsScalar,RhsScalar>][SHAPES_DIMENSION] = {{1,1,1,SHAPES_POINTER_END},{4,1,4,0}};
|
||||
constexpr int SHAPES[SHAPES_COUNT<Architecture, CPU, LhsScalar,RhsScalar>][SHAPES_DIMENSION] = {{1,1,1,SHAPES_POINTER_END},{4,1,1,0},{1,1,4,1},{4,1,4,1}};
|
||||
|
||||
// d1progress x d2progress
|
||||
template<int Architecture, int CPU, typename Scalar, bool isLhs>
|
||||
@ -694,8 +694,7 @@ template<int Architecture, int CPU, typename Index, typename Scalar, typename Da
|
||||
struct PackMapCalculator
|
||||
{
|
||||
PackMapCalculator<Architecture, CPU, Index, Scalar, DataMapper, isLhs, PACK_SHAPES<Architecture, CPU, Scalar, isLhs>[IDX][PACK_SHAPES_POINTER]> pmc;
|
||||
|
||||
inline Index getPosition(Index pos, Index d2Size)
|
||||
EIGEN_STRONG_INLINE Index getPosition(Index pos, Index d2Size)
|
||||
{
|
||||
constexpr auto d1Progress = PACK_SHAPES<Architecture, CPU, Scalar, isLhs>[IDX][0];
|
||||
Index v = (pos / d1Progress) * d1Progress;
|
||||
@ -706,7 +705,7 @@ struct PackMapCalculator
|
||||
template<int Architecture, int CPU, typename Index, typename Scalar, typename DataMapper, bool isLhs>
|
||||
struct PackMapCalculator<Architecture, CPU, Index, Scalar, DataMapper, isLhs, -1>
|
||||
{
|
||||
inline Index getPosition(Index, Index) { return Index(0); }
|
||||
EIGEN_STRONG_INLINE Index getPosition(Index, Index) { return Index(0); }
|
||||
};
|
||||
|
||||
template<int Architecture, int CPU, typename Index, typename Scalar, typename DataMapper, bool isLhs>
|
||||
@ -719,41 +718,87 @@ struct PackMap
|
||||
|
||||
PackMap(const Scalar *base, Index d2Size) : pBase(base), pCur(base), d2Size(d2Size) {}
|
||||
|
||||
inline void resetCur() { pCur = pBase; }
|
||||
inline void moveTo(Index pos)
|
||||
{
|
||||
Index inc = pmc.getPosition(pos, d2Size);
|
||||
std::cout << isLhs << " MOVE_TO " << pos << " " << inc << std::endl;
|
||||
pCur = pBase + inc;
|
||||
}
|
||||
inline void advance(int progress) { pCur += progress; }
|
||||
EIGEN_STRONG_INLINE void resetCur() { pCur = pBase; }
|
||||
EIGEN_STRONG_INLINE void moveTo(Index p1) { pCur = pBase + pmc.getPosition(p1, d2Size); }
|
||||
EIGEN_STRONG_INLINE void advance(int progress) { pCur += progress; }
|
||||
};
|
||||
|
||||
template<int Architecture, int CPU, typename Index, typename LhsScalar, typename RhsScalar, typename AccScalar, typename ResScalar, typename DataMapper, int M, int K, int N>
|
||||
template<int Architecture, int CPU, typename Scalar, typename ResScalar, typename DataMapper, int M, int N>
|
||||
struct Accumulator
|
||||
{
|
||||
Scalar dt[M][N];
|
||||
|
||||
EIGEN_STRONG_INLINE void zero()
|
||||
{
|
||||
for(auto i = 0; i < M; i++)
|
||||
{
|
||||
for(auto j = 0; j < N; j++)
|
||||
{
|
||||
dt[i][j] = Scalar(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE void scale(ResScalar alpha)
|
||||
{
|
||||
for(auto i = 0; i < M; i++)
|
||||
{
|
||||
for(auto j = 0; j < N; j++)
|
||||
{
|
||||
dt[i][j] *= alpha;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE void store(const DataMapper& dest, Index row, Index col)
|
||||
{
|
||||
for(auto i = 0; i < M; i++)
|
||||
{
|
||||
for(auto j = 0; j < N; j++)
|
||||
{
|
||||
dest(row + i, col + j) = dt[i][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<int Architecture, int CPU, typename Index, typename LhsScalar, typename RhsScalar, typename AccScalar, typename ResScalar, typename DataMapper, int SHAPE_IDX, int M, int K, int N>
|
||||
struct MicroKernel
|
||||
{
|
||||
EIGEN_STRONG_INLINE void operator()(PackMap<Architecture, CPU, Index, LhsScalar, DataMapper, true>& lhsPackMap, PackMap<Architecture, CPU, Index, RhsScalar, DataMapper, false>& rhsPackMap, Index rowIdx, Index colIdx, Index depthIdx)
|
||||
EIGEN_STRONG_INLINE void operator()(PackMap<Architecture, CPU, Index, LhsScalar, DataMapper, true>& lhsPackMap,
|
||||
PackMap<Architecture, CPU, Index, RhsScalar, DataMapper, false>& rhsPackMap,
|
||||
Index rowIdx, Index colIdx, Index depthIdx,
|
||||
Accumulator<Architecture, CPU, AccScalar, ResScalar, DataMapper, M, N>& acc)
|
||||
{
|
||||
std::cout << "Kernel " << M << " x " << K << " x " << N << " @ " << rowIdx << ", " << depthIdx << ", " << colIdx << std::endl;
|
||||
std::cout << "LHS ";
|
||||
for(auto i = rowIdx; i < M + rowIdx; i++)
|
||||
for(auto i = 0; i < M; i++)
|
||||
{
|
||||
for(auto j = depthIdx; j < K + depthIdx; j++)
|
||||
for(auto j = 0; j < K; j++)
|
||||
{
|
||||
std::cout << *lhsPackMap.pCur << " ";
|
||||
lhsPackMap.advance(1);
|
||||
std::cout << lhsPackMap.pCur[i*K + j] << " ";
|
||||
}
|
||||
}
|
||||
std::cout << std::endl << "RHS ";
|
||||
for(auto i = depthIdx; i < K + depthIdx; i++)
|
||||
for(auto i = 0; i < K; i++)
|
||||
{
|
||||
for(auto j = colIdx; j < N + colIdx; j++)
|
||||
for(auto j = 0; j < N; j++)
|
||||
{
|
||||
std::cout << *rhsPackMap.pCur << " ";
|
||||
rhsPackMap.advance(1);
|
||||
std::cout << rhsPackMap.pCur[i*N + j] << " ";
|
||||
}
|
||||
}
|
||||
std::cout << std::endl;
|
||||
const RhsScalar *pRhs = rhsPackMap.pCur;
|
||||
for(auto i = 0; i < N; i++)
|
||||
{
|
||||
const LhsScalar *pLhs = lhsPackMap.pCur;
|
||||
for(auto j = 0; j < M; j++)
|
||||
{
|
||||
acc.dt[j][i] += pRhs[i]*pLhs[j];
|
||||
}
|
||||
}
|
||||
lhsPackMap.advance(M*K);
|
||||
rhsPackMap.advance(K*N);
|
||||
};
|
||||
};
|
||||
|
||||
@ -770,11 +815,15 @@ struct DepthLoopStruct
|
||||
|
||||
if(rhsProgress == SHAPES<Architecture, CPU, LhsScalar, RhsScalar>[IDX][SHAPES_RHS_DIMENSION] && lhsProgress == SHAPES<Architecture, CPU, LhsScalar, RhsScalar>[IDX][SHAPES_LHS_DIMENSION])
|
||||
{
|
||||
MicroKernel<Architecture, CPU, Index, LhsScalar, RhsScalar, AccScalar, ResScalar, DataMapper, lhsProgress, depthProgress, rhsProgress> mkt;
|
||||
MicroKernel<Architecture, CPU, Index, LhsScalar, RhsScalar, AccScalar, ResScalar, DataMapper, IDX, lhsProgress, depthProgress, rhsProgress> mkt;
|
||||
Accumulator<Architecture, CPU, AccScalar, ResScalar, DataMapper, lhsProgress, rhsProgress> acc;
|
||||
acc.zero();
|
||||
for(; depthIdx + depthProgress <= depth; depthIdx+=depthProgress)
|
||||
{
|
||||
mkt(lhsPackMap, rhsPackMap, rowIdx, colIdx, depthIdx);
|
||||
mkt(lhsPackMap, rhsPackMap, rowIdx, colIdx, depthIdx, acc);
|
||||
}
|
||||
acc.scale(alpha);
|
||||
acc.store(res, rowIdx, colIdx);
|
||||
}
|
||||
depthLS(rowIdx, colIdx, depthIdx, res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, lhsPackMap, rhsPackMap);
|
||||
}
|
||||
@ -825,10 +874,8 @@ struct RhsLoopStruct
|
||||
{
|
||||
constexpr auto rhsProgress = SHAPES<Architecture, CPU, LhsScalar, RhsScalar>[IDX][SHAPES_RHS_DIMENSION];
|
||||
|
||||
std::cout << __PRETTY_FUNCTION__ << std::endl;
|
||||
for(;colIdx + rhsProgress <= cols; colIdx+=rhsProgress)
|
||||
{
|
||||
//rhsPackMap.moveTo(colIdx);
|
||||
LhsLoopStruct<Architecture, CPU, Index, LhsScalar, RhsScalar, AccScalar, ResScalar, DataMapper, IDX, IDX> lhsLS;
|
||||
lhsLS(0, colIdx, res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, lhsPackMap, rhsPackMap);
|
||||
}
|
||||
@ -843,7 +890,7 @@ struct RhsLoopStruct<Architecture, CPU, Index, LhsScalar, RhsScalar, AccScalar,
|
||||
Index, Index, Index, ResScalar, Index, Index, Index, Index, PackMap<Architecture, CPU, Index, LhsScalar, DataMapper, true>&, PackMap<Architecture, CPU, Index, RhsScalar, DataMapper, false>&) {}
|
||||
};
|
||||
|
||||
template<typename ResScalar, typename AccScalar, typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper>
|
||||
template<int Architecture, int CPU, typename ResScalar, typename AccScalar, typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper>
|
||||
EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const LhsScalar* blockA, const RhsScalar* blockB,
|
||||
Index rows, Index depth, Index cols, ResScalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
|
||||
{
|
||||
@ -864,9 +911,9 @@ EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const LhsScalar* blockA, co
|
||||
}
|
||||
std::cout << std::endl;
|
||||
|
||||
RhsLoopStruct<0, 0, Index, LhsScalar, RhsScalar, AccScalar, ResScalar, DataMapper, SHAPES_COUNT<0, 0, LhsScalar, RhsScalar>-1> rhsLS;
|
||||
PackMap<0, 0, Index, LhsScalar, DataMapper, true> lhsPackMap(blockA, depth);
|
||||
PackMap<0, 0, Index, RhsScalar, DataMapper, false> rhsPackMap(blockB, depth);
|
||||
RhsLoopStruct<Architecture, CPU, Index, LhsScalar, RhsScalar, AccScalar, ResScalar, DataMapper, SHAPES_COUNT<0, 0, LhsScalar, RhsScalar>-1> rhsLS;
|
||||
PackMap<Architecture, CPU, Index, LhsScalar, DataMapper, true> lhsPackMap(blockA, depth);
|
||||
PackMap<Architecture, CPU, Index, RhsScalar, DataMapper, false> rhsPackMap(blockB, depth);
|
||||
rhsLS(0, res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, lhsPackMap, rhsPackMap);
|
||||
}
|
||||
|
||||
@ -940,7 +987,7 @@ void gebp_kernel<float, float, Index, DataMapper, mr, nr, ConjugateLhs, Conjugat
|
||||
Index rows, Index depth, Index cols, float alpha,
|
||||
Index strideA, Index strideB, Index offsetA, Index offsetB)
|
||||
{
|
||||
gemm<float, float, float, float, Index, DataMapper>(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
|
||||
gemm<0, 0, float, float, float, float, Index, DataMapper>(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
|
||||
}
|
||||
} // end namespace internal
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user