mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-05-01 16:24:28 +08:00
Add load vector_pairs for RHS of GEMM MMA. Improved predux GEMV.
This commit is contained in:
parent
9e026e5e28
commit
c2f15edc43
@ -129,7 +129,7 @@ const static Packet16uc p16uc_GETIMAG64 = { 8, 9, 10, 11, 12, 13, 14, 15,
|
|||||||
* reason why packing for complex is broken down into several different parts, also the reason why we endup having a
|
* reason why packing for complex is broken down into several different parts, also the reason why we endup having a
|
||||||
* float32/64 and complex float32/64 version.
|
* float32/64 and complex float32/64 version.
|
||||||
**/
|
**/
|
||||||
template<typename Scalar, typename Index, int StorageOrder>
|
template<typename Scalar, int StorageOrder>
|
||||||
EIGEN_ALWAYS_INLINE std::complex<Scalar> getAdjointVal(Index i, Index j, const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder>& dt)
|
EIGEN_ALWAYS_INLINE std::complex<Scalar> getAdjointVal(Index i, Index j, const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder>& dt)
|
||||||
{
|
{
|
||||||
std::complex<Scalar> v;
|
std::complex<Scalar> v;
|
||||||
@ -148,7 +148,7 @@ EIGEN_ALWAYS_INLINE std::complex<Scalar> getAdjointVal(Index i, Index j, const_b
|
|||||||
return v;
|
return v;
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename Scalar, typename Index, int StorageOrder, int N>
|
template<typename Scalar, int StorageOrder, int N>
|
||||||
EIGEN_STRONG_INLINE void symm_pack_complex_rhs_helper(std::complex<Scalar>* blockB, const std::complex<Scalar>* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
|
EIGEN_STRONG_INLINE void symm_pack_complex_rhs_helper(std::complex<Scalar>* blockB, const std::complex<Scalar>* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
|
||||||
{
|
{
|
||||||
const Index depth = k2 + rows;
|
const Index depth = k2 + rows;
|
||||||
@ -166,7 +166,7 @@ EIGEN_STRONG_INLINE void symm_pack_complex_rhs_helper(std::complex<Scalar>* bloc
|
|||||||
{
|
{
|
||||||
for(Index k = 0; k < vectorSize; k++)
|
for(Index k = 0; k < vectorSize; k++)
|
||||||
{
|
{
|
||||||
std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(i, j + k, rhs);
|
std::complex<Scalar> v = getAdjointVal<Scalar, StorageOrder>(i, j + k, rhs);
|
||||||
|
|
||||||
blockBf[rir + k] = v.real();
|
blockBf[rir + k] = v.real();
|
||||||
blockBf[rii + k] = v.imag();
|
blockBf[rii + k] = v.imag();
|
||||||
@ -184,7 +184,7 @@ EIGEN_STRONG_INLINE void symm_pack_complex_rhs_helper(std::complex<Scalar>* bloc
|
|||||||
|
|
||||||
for(Index i = k2; i < depth; i++)
|
for(Index i = k2; i < depth; i++)
|
||||||
{
|
{
|
||||||
std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(i, j, rhs);
|
std::complex<Scalar> v = getAdjointVal<Scalar, StorageOrder>(i, j, rhs);
|
||||||
|
|
||||||
blockBf[rir] = v.real();
|
blockBf[rir] = v.real();
|
||||||
blockBf[rii] = v.imag();
|
blockBf[rii] = v.imag();
|
||||||
@ -197,7 +197,7 @@ EIGEN_STRONG_INLINE void symm_pack_complex_rhs_helper(std::complex<Scalar>* bloc
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename Scalar, typename Index, int StorageOrder>
|
template<typename Scalar, int StorageOrder>
|
||||||
EIGEN_STRONG_INLINE void symm_pack_complex_lhs_helper(std::complex<Scalar>* blockA, const std::complex<Scalar>* _lhs, Index lhsStride, Index cols, Index rows)
|
EIGEN_STRONG_INLINE void symm_pack_complex_lhs_helper(std::complex<Scalar>* blockA, const std::complex<Scalar>* _lhs, Index lhsStride, Index cols, Index rows)
|
||||||
{
|
{
|
||||||
const Index depth = cols;
|
const Index depth = cols;
|
||||||
@ -215,7 +215,7 @@ EIGEN_STRONG_INLINE void symm_pack_complex_lhs_helper(std::complex<Scalar>* bloc
|
|||||||
{
|
{
|
||||||
for(Index k = 0; k < vectorSize; k++)
|
for(Index k = 0; k < vectorSize; k++)
|
||||||
{
|
{
|
||||||
std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(j+k, i, lhs);
|
std::complex<Scalar> v = getAdjointVal<Scalar, StorageOrder>(j+k, i, lhs);
|
||||||
|
|
||||||
blockAf[rir + k] = v.real();
|
blockAf[rir + k] = v.real();
|
||||||
blockAf[rii + k] = v.imag();
|
blockAf[rii + k] = v.imag();
|
||||||
@ -236,7 +236,7 @@ EIGEN_STRONG_INLINE void symm_pack_complex_lhs_helper(std::complex<Scalar>* bloc
|
|||||||
Index k = j;
|
Index k = j;
|
||||||
for(; k < rows; k++)
|
for(; k < rows; k++)
|
||||||
{
|
{
|
||||||
std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(k, i, lhs);
|
std::complex<Scalar> v = getAdjointVal<Scalar, StorageOrder>(k, i, lhs);
|
||||||
|
|
||||||
blockAf[rir] = v.real();
|
blockAf[rir] = v.real();
|
||||||
blockAf[rii] = v.imag();
|
blockAf[rii] = v.imag();
|
||||||
@ -248,7 +248,7 @@ EIGEN_STRONG_INLINE void symm_pack_complex_lhs_helper(std::complex<Scalar>* bloc
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename Scalar, typename Index, int StorageOrder, int N>
|
template<typename Scalar, int StorageOrder, int N>
|
||||||
EIGEN_STRONG_INLINE void symm_pack_rhs_helper(Scalar* blockB, const Scalar* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
|
EIGEN_STRONG_INLINE void symm_pack_rhs_helper(Scalar* blockB, const Scalar* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
|
||||||
{
|
{
|
||||||
const Index depth = k2 + rows;
|
const Index depth = k2 + rows;
|
||||||
@ -285,7 +285,7 @@ EIGEN_STRONG_INLINE void symm_pack_rhs_helper(Scalar* blockB, const Scalar* _rhs
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename Scalar, typename Index, int StorageOrder>
|
template<typename Scalar, int StorageOrder>
|
||||||
EIGEN_STRONG_INLINE void symm_pack_lhs_helper(Scalar* blockA, const Scalar* _lhs, Index lhsStride, Index cols, Index rows)
|
EIGEN_STRONG_INLINE void symm_pack_lhs_helper(Scalar* blockA, const Scalar* _lhs, Index lhsStride, Index cols, Index rows)
|
||||||
{
|
{
|
||||||
const Index depth = cols;
|
const Index depth = cols;
|
||||||
@ -332,7 +332,7 @@ struct symm_pack_rhs<std::complex<float>, Index, nr, StorageOrder>
|
|||||||
{
|
{
|
||||||
void operator()(std::complex<float>* blockB, const std::complex<float>* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
|
void operator()(std::complex<float>* blockB, const std::complex<float>* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
|
||||||
{
|
{
|
||||||
symm_pack_complex_rhs_helper<float, Index, StorageOrder, 1>(blockB, _rhs, rhsStride, rows, cols, k2);
|
symm_pack_complex_rhs_helper<float, StorageOrder, 1>(blockB, _rhs, rhsStride, rows, cols, k2);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -341,7 +341,7 @@ struct symm_pack_lhs<std::complex<float>, Index, Pack1, Pack2_dummy, StorageOrde
|
|||||||
{
|
{
|
||||||
void operator()(std::complex<float>* blockA, const std::complex<float>* _lhs, Index lhsStride, Index cols, Index rows)
|
void operator()(std::complex<float>* blockA, const std::complex<float>* _lhs, Index lhsStride, Index cols, Index rows)
|
||||||
{
|
{
|
||||||
symm_pack_complex_lhs_helper<float, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
|
symm_pack_complex_lhs_helper<float, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -352,7 +352,7 @@ struct symm_pack_rhs<std::complex<double>, Index, nr, StorageOrder>
|
|||||||
{
|
{
|
||||||
void operator()(std::complex<double>* blockB, const std::complex<double>* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
|
void operator()(std::complex<double>* blockB, const std::complex<double>* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
|
||||||
{
|
{
|
||||||
symm_pack_complex_rhs_helper<double, Index, StorageOrder, 2>(blockB, _rhs, rhsStride, rows, cols, k2);
|
symm_pack_complex_rhs_helper<double, StorageOrder, 2>(blockB, _rhs, rhsStride, rows, cols, k2);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -361,7 +361,7 @@ struct symm_pack_lhs<std::complex<double>, Index, Pack1, Pack2_dummy, StorageOrd
|
|||||||
{
|
{
|
||||||
void operator()(std::complex<double>* blockA, const std::complex<double>* _lhs, Index lhsStride, Index cols, Index rows)
|
void operator()(std::complex<double>* blockA, const std::complex<double>* _lhs, Index lhsStride, Index cols, Index rows)
|
||||||
{
|
{
|
||||||
symm_pack_complex_lhs_helper<double, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
|
symm_pack_complex_lhs_helper<double, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -371,7 +371,7 @@ struct symm_pack_rhs<float, Index, nr, StorageOrder>
|
|||||||
{
|
{
|
||||||
void operator()(float* blockB, const float* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
|
void operator()(float* blockB, const float* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
|
||||||
{
|
{
|
||||||
symm_pack_rhs_helper<float, Index, StorageOrder, 1>(blockB, _rhs, rhsStride, rows, cols, k2);
|
symm_pack_rhs_helper<float, StorageOrder, 1>(blockB, _rhs, rhsStride, rows, cols, k2);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -380,7 +380,7 @@ struct symm_pack_lhs<float, Index, Pack1, Pack2_dummy, StorageOrder>
|
|||||||
{
|
{
|
||||||
void operator()(float* blockA, const float* _lhs, Index lhsStride, Index cols, Index rows)
|
void operator()(float* blockA, const float* _lhs, Index lhsStride, Index cols, Index rows)
|
||||||
{
|
{
|
||||||
symm_pack_lhs_helper<float, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
|
symm_pack_lhs_helper<float, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -390,7 +390,7 @@ struct symm_pack_rhs<double, Index, nr, StorageOrder>
|
|||||||
{
|
{
|
||||||
void operator()(double* blockB, const double* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
|
void operator()(double* blockB, const double* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
|
||||||
{
|
{
|
||||||
symm_pack_rhs_helper<double, Index, StorageOrder, 2>(blockB, _rhs, rhsStride, rows, cols, k2);
|
symm_pack_rhs_helper<double, StorageOrder, 2>(blockB, _rhs, rhsStride, rows, cols, k2);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -399,7 +399,7 @@ struct symm_pack_lhs<double, Index, Pack1, Pack2_dummy, StorageOrder>
|
|||||||
{
|
{
|
||||||
void operator()(double* blockA, const double* _lhs, Index lhsStride, Index cols, Index rows)
|
void operator()(double* blockA, const double* _lhs, Index lhsStride, Index cols, Index rows)
|
||||||
{
|
{
|
||||||
symm_pack_lhs_helper<double, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
|
symm_pack_lhs_helper<double, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -414,7 +414,7 @@ struct symm_pack_lhs<double, Index, Pack1, Pack2_dummy, StorageOrder>
|
|||||||
* and offset and behaves accordingly.
|
* and offset and behaves accordingly.
|
||||||
**/
|
**/
|
||||||
|
|
||||||
template<typename Scalar, typename Packet, typename Index, int N>
|
template<typename Scalar, typename Packet, int N>
|
||||||
EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock<Packet,N>& block)
|
EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock<Packet,N>& block)
|
||||||
{
|
{
|
||||||
const Index size = 16 / sizeof(Scalar);
|
const Index size = 16 / sizeof(Scalar);
|
||||||
@ -429,7 +429,7 @@ EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock<Packet,N>& block)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// General template for lhs & rhs complex packing.
|
// General template for lhs & rhs complex packing.
|
||||||
template<typename Scalar, typename Index, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode, bool UseLhs>
|
template<typename Scalar, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode, bool UseLhs>
|
||||||
struct dhs_cpack {
|
struct dhs_cpack {
|
||||||
EIGEN_STRONG_INLINE void operator()(std::complex<Scalar>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
|
EIGEN_STRONG_INLINE void operator()(std::complex<Scalar>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
|
||||||
{
|
{
|
||||||
@ -451,9 +451,9 @@ struct dhs_cpack {
|
|||||||
PacketBlock<PacketC,8> cblock;
|
PacketBlock<PacketC,8> cblock;
|
||||||
|
|
||||||
if (UseLhs) {
|
if (UseLhs) {
|
||||||
bload<DataMapper, PacketC, Index, 2, StorageOrder, true, 4>(cblock, lhs, j, i);
|
bload<DataMapper, PacketC, 2, StorageOrder, true, 4>(cblock, lhs, j, i);
|
||||||
} else {
|
} else {
|
||||||
bload<DataMapper, PacketC, Index, 2, StorageOrder, true, 4>(cblock, lhs, i, j);
|
bload<DataMapper, PacketC, 2, StorageOrder, true, 4>(cblock, lhs, i, j);
|
||||||
}
|
}
|
||||||
|
|
||||||
blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETREAL32);
|
blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETREAL32);
|
||||||
@ -480,8 +480,8 @@ struct dhs_cpack {
|
|||||||
ptranspose(blocki);
|
ptranspose(blocki);
|
||||||
}
|
}
|
||||||
|
|
||||||
storeBlock<Scalar, Packet, Index, 4>(blockAt + rir, blockr);
|
storeBlock<Scalar, Packet, 4>(blockAt + rir, blockr);
|
||||||
storeBlock<Scalar, Packet, Index, 4>(blockAt + rii, blocki);
|
storeBlock<Scalar, Packet, 4>(blockAt + rii, blocki);
|
||||||
|
|
||||||
rir += 4*vectorSize;
|
rir += 4*vectorSize;
|
||||||
rii += 4*vectorSize;
|
rii += 4*vectorSize;
|
||||||
@ -579,7 +579,7 @@ struct dhs_cpack {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// General template for lhs & rhs packing.
|
// General template for lhs & rhs packing.
|
||||||
template<typename Scalar, typename Index, typename DataMapper, typename Packet, int StorageOrder, bool PanelMode, bool UseLhs>
|
template<typename Scalar, typename DataMapper, typename Packet, int StorageOrder, bool PanelMode, bool UseLhs>
|
||||||
struct dhs_pack{
|
struct dhs_pack{
|
||||||
EIGEN_STRONG_INLINE void operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
|
EIGEN_STRONG_INLINE void operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
|
||||||
{
|
{
|
||||||
@ -597,16 +597,16 @@ struct dhs_pack{
|
|||||||
PacketBlock<Packet,4> block;
|
PacketBlock<Packet,4> block;
|
||||||
|
|
||||||
if (UseLhs) {
|
if (UseLhs) {
|
||||||
bload<DataMapper, Packet, Index, 4, StorageOrder, false, 4>(block, lhs, j, i);
|
bload<DataMapper, Packet, 4, StorageOrder, false, 4>(block, lhs, j, i);
|
||||||
} else {
|
} else {
|
||||||
bload<DataMapper, Packet, Index, 4, StorageOrder, false, 4>(block, lhs, i, j);
|
bload<DataMapper, Packet, 4, StorageOrder, false, 4>(block, lhs, i, j);
|
||||||
}
|
}
|
||||||
if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs))
|
if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs))
|
||||||
{
|
{
|
||||||
ptranspose(block);
|
ptranspose(block);
|
||||||
}
|
}
|
||||||
|
|
||||||
storeBlock<Scalar, Packet, Index, 4>(blockA + ri, block);
|
storeBlock<Scalar, Packet, 4>(blockA + ri, block);
|
||||||
|
|
||||||
ri += 4*vectorSize;
|
ri += 4*vectorSize;
|
||||||
}
|
}
|
||||||
@ -675,8 +675,8 @@ struct dhs_pack{
|
|||||||
};
|
};
|
||||||
|
|
||||||
// General template for lhs packing, float64 specialization.
|
// General template for lhs packing, float64 specialization.
|
||||||
template<typename Index, typename DataMapper, int StorageOrder, bool PanelMode>
|
template<typename DataMapper, int StorageOrder, bool PanelMode>
|
||||||
struct dhs_pack<double, Index, DataMapper, Packet2d, StorageOrder, PanelMode, true>
|
struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, true>
|
||||||
{
|
{
|
||||||
EIGEN_STRONG_INLINE void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
|
EIGEN_STRONG_INLINE void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
|
||||||
{
|
{
|
||||||
@ -703,7 +703,7 @@ struct dhs_pack<double, Index, DataMapper, Packet2d, StorageOrder, PanelMode, tr
|
|||||||
block.packet[1] = lhs.template loadPacket<Packet2d>(j, i + 1);
|
block.packet[1] = lhs.template loadPacket<Packet2d>(j, i + 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
storeBlock<double, Packet2d, Index, 2>(blockA + ri, block);
|
storeBlock<double, Packet2d, 2>(blockA + ri, block);
|
||||||
|
|
||||||
ri += 2*vectorSize;
|
ri += 2*vectorSize;
|
||||||
}
|
}
|
||||||
@ -742,8 +742,8 @@ struct dhs_pack<double, Index, DataMapper, Packet2d, StorageOrder, PanelMode, tr
|
|||||||
};
|
};
|
||||||
|
|
||||||
// General template for rhs packing, float64 specialization.
|
// General template for rhs packing, float64 specialization.
|
||||||
template<typename Index, typename DataMapper, int StorageOrder, bool PanelMode>
|
template<typename DataMapper, int StorageOrder, bool PanelMode>
|
||||||
struct dhs_pack<double, Index, DataMapper, Packet2d, StorageOrder, PanelMode, false>
|
struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, false>
|
||||||
{
|
{
|
||||||
EIGEN_STRONG_INLINE void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
|
EIGEN_STRONG_INLINE void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
|
||||||
{
|
{
|
||||||
@ -780,7 +780,7 @@ struct dhs_pack<double, Index, DataMapper, Packet2d, StorageOrder, PanelMode, fa
|
|||||||
block.packet[2] = rhs.template loadPacket<Packet2d>(i + 1, j + 0); //[b1 b2]
|
block.packet[2] = rhs.template loadPacket<Packet2d>(i + 1, j + 0); //[b1 b2]
|
||||||
block.packet[3] = rhs.template loadPacket<Packet2d>(i + 1, j + 2); //[b3 b4]
|
block.packet[3] = rhs.template loadPacket<Packet2d>(i + 1, j + 2); //[b3 b4]
|
||||||
|
|
||||||
storeBlock<double, Packet2d, Index, 4>(blockB + ri, block);
|
storeBlock<double, Packet2d, 4>(blockB + ri, block);
|
||||||
}
|
}
|
||||||
|
|
||||||
ri += 4*vectorSize;
|
ri += 4*vectorSize;
|
||||||
@ -827,8 +827,8 @@ struct dhs_pack<double, Index, DataMapper, Packet2d, StorageOrder, PanelMode, fa
|
|||||||
};
|
};
|
||||||
|
|
||||||
// General template for lhs complex packing, float64 specialization.
|
// General template for lhs complex packing, float64 specialization.
|
||||||
template<typename Index, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
|
template<typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
|
||||||
struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, true>
|
struct dhs_cpack<double, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, true>
|
||||||
{
|
{
|
||||||
EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
|
EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
|
||||||
{
|
{
|
||||||
@ -882,8 +882,8 @@ struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conju
|
|||||||
blocki.packet[1] = -blocki.packet[1];
|
blocki.packet[1] = -blocki.packet[1];
|
||||||
}
|
}
|
||||||
|
|
||||||
storeBlock<double, Packet, Index, 2>(blockAt + rir, blockr);
|
storeBlock<double, Packet, 2>(blockAt + rir, blockr);
|
||||||
storeBlock<double, Packet, Index, 2>(blockAt + rii, blocki);
|
storeBlock<double, Packet, 2>(blockAt + rii, blocki);
|
||||||
|
|
||||||
rir += 2*vectorSize;
|
rir += 2*vectorSize;
|
||||||
rii += 2*vectorSize;
|
rii += 2*vectorSize;
|
||||||
@ -940,8 +940,8 @@ struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conju
|
|||||||
};
|
};
|
||||||
|
|
||||||
// General template for rhs complex packing, float64 specialization.
|
// General template for rhs complex packing, float64 specialization.
|
||||||
template<typename Index, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
|
template<typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
|
||||||
struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, false>
|
struct dhs_cpack<double, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, false>
|
||||||
{
|
{
|
||||||
EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
|
EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
|
||||||
{
|
{
|
||||||
@ -962,7 +962,7 @@ struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conju
|
|||||||
PacketBlock<PacketC,4> cblock;
|
PacketBlock<PacketC,4> cblock;
|
||||||
PacketBlock<Packet,2> blockr, blocki;
|
PacketBlock<Packet,2> blockr, blocki;
|
||||||
|
|
||||||
bload<DataMapper, PacketC, Index, 2, ColMajor, false, 4>(cblock, rhs, i, j);
|
bload<DataMapper, PacketC, 2, ColMajor, false, 4>(cblock, rhs, i, j);
|
||||||
|
|
||||||
blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64);
|
blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64);
|
||||||
blockr.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64);
|
blockr.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64);
|
||||||
@ -976,8 +976,8 @@ struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conju
|
|||||||
blocki.packet[1] = -blocki.packet[1];
|
blocki.packet[1] = -blocki.packet[1];
|
||||||
}
|
}
|
||||||
|
|
||||||
storeBlock<double, Packet, Index, 2>(blockBt + rir, blockr);
|
storeBlock<double, Packet, 2>(blockBt + rir, blockr);
|
||||||
storeBlock<double, Packet, Index, 2>(blockBt + rii, blocki);
|
storeBlock<double, Packet, 2>(blockBt + rii, blocki);
|
||||||
|
|
||||||
rir += 2*vectorSize;
|
rir += 2*vectorSize;
|
||||||
rii += 2*vectorSize;
|
rii += 2*vectorSize;
|
||||||
@ -1123,7 +1123,7 @@ EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,N>& aReal, PacketBlock<Packe
|
|||||||
// Load a PacketBlock, the N parameters make tunning gemm easier so we can add more accumulators as needed.
|
// Load a PacketBlock, the N parameters make tunning gemm easier so we can add more accumulators as needed.
|
||||||
//
|
//
|
||||||
// full = operate (load) on the entire PacketBlock or only half
|
// full = operate (load) on the entire PacketBlock or only half
|
||||||
template<typename DataMapper, typename Packet, typename Index, const Index accCols, int StorageOrder, bool Complex, int N, bool full>
|
template<typename DataMapper, typename Packet, const Index accCols, int StorageOrder, bool Complex, int N, bool full>
|
||||||
EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,N*(Complex?2:1)>& acc, const DataMapper& res, Index row, Index col)
|
EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,N*(Complex?2:1)>& acc, const DataMapper& res, Index row, Index col)
|
||||||
{
|
{
|
||||||
if (StorageOrder == RowMajor) {
|
if (StorageOrder == RowMajor) {
|
||||||
@ -1147,7 +1147,7 @@ EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,N*(Complex?2:1)>& acc, const D
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename DataMapper, typename Packet, typename Index, int N>
|
template<typename DataMapper, typename Packet, int N>
|
||||||
EIGEN_ALWAYS_INLINE void bstore(PacketBlock<Packet,N>& acc, const DataMapper& res, Index row)
|
EIGEN_ALWAYS_INLINE void bstore(PacketBlock<Packet,N>& acc, const DataMapper& res, Index row)
|
||||||
{
|
{
|
||||||
for (int M = 0; M < N; M++) {
|
for (int M = 0; M < N; M++) {
|
||||||
@ -1165,7 +1165,7 @@ EIGEN_ALWAYS_INLINE void bstore(PacketBlock<Packet,N>& acc, const DataMapper& re
|
|||||||
const static Packet4i mask4[4] = { { 0, 0, 0, 0 }, { -1, 0, 0, 0 }, { -1, -1, 0, 0 }, { -1, -1, -1, 0 } };
|
const static Packet4i mask4[4] = { { 0, 0, 0, 0 }, { -1, 0, 0, 0 }, { -1, -1, 0, 0 }, { -1, -1, -1, 0 } };
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
template<typename Packet, typename Index>
|
template<typename Packet>
|
||||||
EIGEN_ALWAYS_INLINE Packet bmask(const Index remaining_rows)
|
EIGEN_ALWAYS_INLINE Packet bmask(const Index remaining_rows)
|
||||||
{
|
{
|
||||||
#if USE_P10_AND_PVIPR2_0
|
#if USE_P10_AND_PVIPR2_0
|
||||||
@ -1180,7 +1180,7 @@ EIGEN_ALWAYS_INLINE Packet bmask(const Index remaining_rows)
|
|||||||
}
|
}
|
||||||
|
|
||||||
template<>
|
template<>
|
||||||
EIGEN_ALWAYS_INLINE Packet2d bmask<Packet2d,Index>(const Index remaining_rows)
|
EIGEN_ALWAYS_INLINE Packet2d bmask<Packet2d>(const Index remaining_rows)
|
||||||
{
|
{
|
||||||
#if USE_P10_AND_PVIPR2_0
|
#if USE_P10_AND_PVIPR2_0
|
||||||
Packet2d mask2 = Packet2d(vec_gendm(remaining_rows));
|
Packet2d mask2 = Packet2d(vec_gendm(remaining_rows));
|
||||||
@ -1406,7 +1406,7 @@ EIGEN_ALWAYS_INLINE void bcouple(PacketBlock<Packet,N>& taccReal, PacketBlock<Pa
|
|||||||
MICRO_PREFETCHN1(ptr_imag, N); \
|
MICRO_PREFETCHN1(ptr_imag, N); \
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename Scalar, typename Packet, typename Index, const Index accRows, const Index remaining_rows>
|
template<typename Scalar, typename Packet, const Index accRows, const Index remaining_rows>
|
||||||
EIGEN_ALWAYS_INLINE void MICRO_EXTRA_ROW(
|
EIGEN_ALWAYS_INLINE void MICRO_EXTRA_ROW(
|
||||||
const Scalar* &lhs_ptr,
|
const Scalar* &lhs_ptr,
|
||||||
const Scalar* &rhs_ptr0,
|
const Scalar* &rhs_ptr0,
|
||||||
@ -1419,7 +1419,7 @@ EIGEN_ALWAYS_INLINE void MICRO_EXTRA_ROW(
|
|||||||
lhs_ptr += remaining_rows;
|
lhs_ptr += remaining_rows;
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols, const Index remaining_rows>
|
template<typename Scalar, typename Packet, typename DataMapper, const Index accRows, const Index accCols, const Index remaining_rows>
|
||||||
EIGEN_ALWAYS_INLINE void gemm_unrolled_row_iteration(
|
EIGEN_ALWAYS_INLINE void gemm_unrolled_row_iteration(
|
||||||
const DataMapper& res,
|
const DataMapper& res,
|
||||||
const Scalar* lhs_base,
|
const Scalar* lhs_base,
|
||||||
@ -1454,14 +1454,14 @@ EIGEN_ALWAYS_INLINE void gemm_unrolled_row_iteration(
|
|||||||
}
|
}
|
||||||
for(; k < depth; k++)
|
for(; k < depth; k++)
|
||||||
{
|
{
|
||||||
MICRO_EXTRA_ROW<Scalar, Packet, Index, accRows, remaining_rows>(lhs_ptr, rhs_ptr0, rhs_ptr1, rhs_ptr2, accZero0);
|
MICRO_EXTRA_ROW<Scalar, Packet, accRows, remaining_rows>(lhs_ptr, rhs_ptr0, rhs_ptr1, rhs_ptr2, accZero0);
|
||||||
}
|
}
|
||||||
|
|
||||||
bload<DataMapper, Packet, Index, 0, ColMajor, false, accRows>(acc, res, row, 0);
|
bload<DataMapper, Packet, 0, ColMajor, false, accRows>(acc, res, row, 0);
|
||||||
if ((accRows == 1) || (rows >= accCols))
|
if ((accRows == 1) || (rows >= accCols))
|
||||||
{
|
{
|
||||||
bscale<Packet,accRows,true>(acc, accZero0, pAlpha, pMask);
|
bscale<Packet,accRows,true>(acc, accZero0, pAlpha, pMask);
|
||||||
bstore<DataMapper, Packet, Index, accRows>(acc, res, row);
|
bstore<DataMapper, Packet, accRows>(acc, res, row);
|
||||||
} else {
|
} else {
|
||||||
bscale<Packet,accRows,false>(acc, accZero0, pAlpha, pMask);
|
bscale<Packet,accRows,false>(acc, accZero0, pAlpha, pMask);
|
||||||
for(Index j = 0; j < accRows; j++) {
|
for(Index j = 0; j < accRows; j++) {
|
||||||
@ -1490,9 +1490,9 @@ EIGEN_ALWAYS_INLINE void gemm_unrolled_row_iteration(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define MICRO_EXTRA_ROWS(N) \
|
#define MICRO_EXTRA_ROWS(N) \
|
||||||
gemm_unrolled_row_iteration<Scalar, Packet, DataMapper, Index, accRows, accCols, N>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, pAlpha, pMask);
|
gemm_unrolled_row_iteration<Scalar, Packet, DataMapper, accRows, accCols, N>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, pAlpha, pMask);
|
||||||
|
|
||||||
template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
|
template<typename Scalar, typename Packet, typename DataMapper, const Index accRows, const Index accCols>
|
||||||
EIGEN_ALWAYS_INLINE void gemm_extra_row(
|
EIGEN_ALWAYS_INLINE void gemm_extra_row(
|
||||||
const DataMapper& res,
|
const DataMapper& res,
|
||||||
const Scalar* lhs_base,
|
const Scalar* lhs_base,
|
||||||
@ -1563,14 +1563,14 @@ EIGEN_ALWAYS_INLINE void gemm_extra_row(
|
|||||||
|
|
||||||
#define MICRO_STORE_ONE(iter) \
|
#define MICRO_STORE_ONE(iter) \
|
||||||
if (unroll_factor > iter) { \
|
if (unroll_factor > iter) { \
|
||||||
bload<DataMapper, Packet, Index, 0, ColMajor, false, accRows>(acc, res, row + iter*accCols, 0); \
|
bload<DataMapper, Packet, 0, ColMajor, false, accRows>(acc, res, row + iter*accCols, 0); \
|
||||||
bscale<Packet,accRows,!(MICRO_NORMAL(iter))>(acc, accZero##iter, pAlpha, pMask); \
|
bscale<Packet,accRows,!(MICRO_NORMAL(iter))>(acc, accZero##iter, pAlpha, pMask); \
|
||||||
bstore<DataMapper, Packet, Index, accRows>(acc, res, row + iter*accCols); \
|
bstore<DataMapper, Packet, accRows>(acc, res, row + iter*accCols); \
|
||||||
}
|
}
|
||||||
|
|
||||||
#define MICRO_STORE MICRO_UNROLL(MICRO_STORE_ONE)
|
#define MICRO_STORE MICRO_UNROLL(MICRO_STORE_ONE)
|
||||||
|
|
||||||
template<int unroll_factor, typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols, const Index accCols2>
|
template<int unroll_factor, typename Scalar, typename Packet, typename DataMapper, const Index accRows, const Index accCols, const Index accCols2>
|
||||||
EIGEN_ALWAYS_INLINE void gemm_unrolled_iteration(
|
EIGEN_ALWAYS_INLINE void gemm_unrolled_iteration(
|
||||||
const DataMapper& res,
|
const DataMapper& res,
|
||||||
const Scalar* lhs_base,
|
const Scalar* lhs_base,
|
||||||
@ -1609,10 +1609,10 @@ EIGEN_ALWAYS_INLINE void gemm_unrolled_iteration(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define MICRO_UNROLL_ITER2(N, M) \
|
#define MICRO_UNROLL_ITER2(N, M) \
|
||||||
gemm_unrolled_iteration<N + ((M) ? 1 : 0), Scalar, Packet, DataMapper, Index, accRows, accCols, M ? M : accCols>(res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlpha, pMask); \
|
gemm_unrolled_iteration<N + ((M) ? 1 : 0), Scalar, Packet, DataMapper, accRows, accCols, M ? M : accCols>(res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlpha, pMask); \
|
||||||
if (M) return;
|
if (M) return;
|
||||||
|
|
||||||
template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
|
template<typename Scalar, typename Packet, typename DataMapper, const Index accRows, const Index accCols>
|
||||||
EIGEN_ALWAYS_INLINE void gemm_cols(
|
EIGEN_ALWAYS_INLINE void gemm_cols(
|
||||||
const DataMapper& res,
|
const DataMapper& res,
|
||||||
const Scalar* blockA,
|
const Scalar* blockA,
|
||||||
@ -1681,14 +1681,14 @@ EIGEN_ALWAYS_INLINE void gemm_cols(
|
|||||||
|
|
||||||
if(remaining_rows > 0)
|
if(remaining_rows > 0)
|
||||||
{
|
{
|
||||||
gemm_extra_row<Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlpha, pMask);
|
gemm_extra_row<Scalar, Packet, DataMapper, accRows, accCols>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlpha, pMask);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#define MICRO_EXTRA_COLS(N) \
|
#define MICRO_EXTRA_COLS(N) \
|
||||||
gemm_cols<Scalar, Packet, DataMapper, Index, N, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlpha, pMask);
|
gemm_cols<Scalar, Packet, DataMapper, N, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlpha, pMask);
|
||||||
|
|
||||||
template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accCols>
|
template<typename Scalar, typename Packet, typename DataMapper, const Index accCols>
|
||||||
EIGEN_STRONG_INLINE void gemm_extra_cols(
|
EIGEN_STRONG_INLINE void gemm_extra_cols(
|
||||||
const DataMapper& res,
|
const DataMapper& res,
|
||||||
const Scalar* blockA,
|
const Scalar* blockA,
|
||||||
@ -1711,7 +1711,7 @@ EIGEN_STRONG_INLINE void gemm_extra_cols(
|
|||||||
/****************
|
/****************
|
||||||
* GEMM kernels *
|
* GEMM kernels *
|
||||||
* **************/
|
* **************/
|
||||||
template<typename Scalar, typename Index, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
|
template<typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
|
||||||
EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
|
EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
|
||||||
{
|
{
|
||||||
const Index remaining_rows = rows % accCols;
|
const Index remaining_rows = rows % accCols;
|
||||||
@ -1725,12 +1725,12 @@ EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const
|
|||||||
Index col = 0;
|
Index col = 0;
|
||||||
for(; col + accRows <= cols; col += accRows)
|
for(; col + accRows <= cols; col += accRows)
|
||||||
{
|
{
|
||||||
gemm_cols<Scalar, Packet, DataMapper, Index, accRows, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlpha, pMask);
|
gemm_cols<Scalar, Packet, DataMapper, accRows, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlpha, pMask);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (col != cols)
|
if (col != cols)
|
||||||
{
|
{
|
||||||
gemm_extra_cols<Scalar, Packet, DataMapper, Index, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
|
gemm_extra_cols<Scalar, Packet, DataMapper, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1828,7 +1828,7 @@ EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const
|
|||||||
MICRO_COMPLEX_ADD_PEEL(2, 0) MICRO_COMPLEX_ADD_PEEL(3, 1) \
|
MICRO_COMPLEX_ADD_PEEL(2, 0) MICRO_COMPLEX_ADD_PEEL(3, 1) \
|
||||||
MICRO_COMPLEX_ADD_PEEL(1, 0)
|
MICRO_COMPLEX_ADD_PEEL(1, 0)
|
||||||
|
|
||||||
template<typename Scalar, typename Packet, typename Index, const Index accRows, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal, const Index remaining_rows>
|
template<typename Scalar, typename Packet, const Index accRows, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal, const Index remaining_rows>
|
||||||
EIGEN_ALWAYS_INLINE void MICRO_COMPLEX_EXTRA_ROW(
|
EIGEN_ALWAYS_INLINE void MICRO_COMPLEX_EXTRA_ROW(
|
||||||
const Scalar* &lhs_ptr_real, const Scalar* &lhs_ptr_imag,
|
const Scalar* &lhs_ptr_real, const Scalar* &lhs_ptr_imag,
|
||||||
const Scalar* &rhs_ptr_real0, const Scalar* &rhs_ptr_real1, const Scalar* &rhs_ptr_real2,
|
const Scalar* &rhs_ptr_real0, const Scalar* &rhs_ptr_real1, const Scalar* &rhs_ptr_real2,
|
||||||
@ -1840,7 +1840,7 @@ EIGEN_ALWAYS_INLINE void MICRO_COMPLEX_EXTRA_ROW(
|
|||||||
MICRO_COMPLEX_ADD_COLS(1)
|
MICRO_COMPLEX_ADD_COLS(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal, const Index remaining_rows>
|
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal, const Index remaining_rows>
|
||||||
EIGEN_ALWAYS_INLINE void gemm_unrolled_complex_row_iteration(
|
EIGEN_ALWAYS_INLINE void gemm_unrolled_complex_row_iteration(
|
||||||
const DataMapper& res,
|
const DataMapper& res,
|
||||||
const Scalar* lhs_base,
|
const Scalar* lhs_base,
|
||||||
@ -1888,18 +1888,18 @@ EIGEN_ALWAYS_INLINE void gemm_unrolled_complex_row_iteration(
|
|||||||
}
|
}
|
||||||
for(; k < depth; k++)
|
for(; k < depth; k++)
|
||||||
{
|
{
|
||||||
MICRO_COMPLEX_EXTRA_ROW<Scalar, Packet, Index, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, remaining_rows>(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real0, rhs_ptr_real1, rhs_ptr_real2, rhs_ptr_imag0, rhs_ptr_imag1, rhs_ptr_imag2, accReal0, accImag0);
|
MICRO_COMPLEX_EXTRA_ROW<Scalar, Packet, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, remaining_rows>(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real0, rhs_ptr_real1, rhs_ptr_real2, rhs_ptr_imag0, rhs_ptr_imag1, rhs_ptr_imag2, accReal0, accImag0);
|
||||||
}
|
}
|
||||||
|
|
||||||
const bool full = (remaining_rows > accColsC);
|
constexpr bool full = (remaining_rows > accColsC);
|
||||||
bload<DataMapper, Packetc, Index, accColsC, ColMajor, true, accRows, full>(tRes, res, row, 0);
|
bload<DataMapper, Packetc, accColsC, ColMajor, true, accRows, full>(tRes, res, row, 0);
|
||||||
if ((accRows == 1) || (rows >= accCols))
|
if ((accRows == 1) || (rows >= accCols))
|
||||||
{
|
{
|
||||||
bscalec<Packet,accRows,true>(accReal0, accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask);
|
bscalec<Packet,accRows,true>(accReal0, accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask);
|
||||||
bcouple<Packet, Packetc, accRows, full>(taccReal, taccImag, tRes, acc0, acc1);
|
bcouple<Packet, Packetc, accRows, full>(taccReal, taccImag, tRes, acc0, acc1);
|
||||||
bstore<DataMapper, Packetc, Index, accRows>(acc0, res, row + 0);
|
bstore<DataMapper, Packetc, accRows>(acc0, res, row + 0);
|
||||||
if (full) {
|
if (full) {
|
||||||
bstore<DataMapper, Packetc, Index, accRows>(acc1, res, row + accColsC);
|
bstore<DataMapper, Packetc, accRows>(acc1, res, row + accColsC);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
bscalec<Packet,accRows,false>(accReal0, accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask);
|
bscalec<Packet,accRows,false>(accReal0, accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask);
|
||||||
@ -1911,7 +1911,7 @@ EIGEN_ALWAYS_INLINE void gemm_unrolled_complex_row_iteration(
|
|||||||
res(row + 0, j) = pfirst<Packetc>(acc0.packet[j]);
|
res(row + 0, j) = pfirst<Packetc>(acc0.packet[j]);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
bstore<DataMapper, Packetc, Index, accRows>(acc0, res, row + 0);
|
bstore<DataMapper, Packetc, accRows>(acc0, res, row + 0);
|
||||||
if (full) {
|
if (full) {
|
||||||
for(Index j = 0; j < accRows; j++) {
|
for(Index j = 0; j < accRows; j++) {
|
||||||
res(row + accColsC, j) = pfirst<Packetc>(acc1.packet[j]);
|
res(row + accColsC, j) = pfirst<Packetc>(acc1.packet[j]);
|
||||||
@ -1922,9 +1922,9 @@ EIGEN_ALWAYS_INLINE void gemm_unrolled_complex_row_iteration(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define MICRO_COMPLEX_EXTRA_ROWS(N) \
|
#define MICRO_COMPLEX_EXTRA_ROWS(N) \
|
||||||
gemm_unrolled_complex_row_iteration<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, N>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, pAlphaReal, pAlphaImag, pMask);
|
gemm_unrolled_complex_row_iteration<Scalar, Packet, Packetc, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, N>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, pAlphaReal, pAlphaImag, pMask);
|
||||||
|
|
||||||
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
||||||
EIGEN_ALWAYS_INLINE void gemm_complex_extra_row(
|
EIGEN_ALWAYS_INLINE void gemm_complex_extra_row(
|
||||||
const DataMapper& res,
|
const DataMapper& res,
|
||||||
const Scalar* lhs_base,
|
const Scalar* lhs_base,
|
||||||
@ -1998,19 +1998,19 @@ EIGEN_ALWAYS_INLINE void gemm_complex_extra_row(
|
|||||||
|
|
||||||
#define MICRO_COMPLEX_STORE_ONE(iter) \
|
#define MICRO_COMPLEX_STORE_ONE(iter) \
|
||||||
if (unroll_factor > iter) { \
|
if (unroll_factor > iter) { \
|
||||||
const bool full = ((MICRO_NORMAL(iter)) || (accCols2 > accColsC)); \
|
constexpr bool full = ((MICRO_NORMAL(iter)) || (accCols2 > accColsC)); \
|
||||||
bload<DataMapper, Packetc, Index, accColsC, ColMajor, true, accRows, full>(tRes, res, row + iter*accCols, 0); \
|
bload<DataMapper, Packetc, accColsC, ColMajor, true, accRows, full>(tRes, res, row + iter*accCols, 0); \
|
||||||
bscalec<Packet,accRows,!(MICRO_NORMAL(iter))>(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask); \
|
bscalec<Packet,accRows,!(MICRO_NORMAL(iter))>(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask); \
|
||||||
bcouple<Packet, Packetc, accRows, full>(taccReal, taccImag, tRes, acc0, acc1); \
|
bcouple<Packet, Packetc, accRows, full>(taccReal, taccImag, tRes, acc0, acc1); \
|
||||||
bstore<DataMapper, Packetc, Index, accRows>(acc0, res, row + iter*accCols + 0); \
|
bstore<DataMapper, Packetc, accRows>(acc0, res, row + iter*accCols + 0); \
|
||||||
if (full) { \
|
if (full) { \
|
||||||
bstore<DataMapper, Packetc, Index, accRows>(acc1, res, row + iter*accCols + accColsC); \
|
bstore<DataMapper, Packetc, accRows>(acc1, res, row + iter*accCols + accColsC); \
|
||||||
} \
|
} \
|
||||||
}
|
}
|
||||||
|
|
||||||
#define MICRO_COMPLEX_STORE MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_STORE_ONE)
|
#define MICRO_COMPLEX_STORE MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_STORE_ONE)
|
||||||
|
|
||||||
template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, const Index accCols2, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accRows, const Index accCols, const Index accCols2, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
||||||
EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_iteration(
|
EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_iteration(
|
||||||
const DataMapper& res,
|
const DataMapper& res,
|
||||||
const Scalar* lhs_base,
|
const Scalar* lhs_base,
|
||||||
@ -2057,10 +2057,10 @@ EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_iteration(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define MICRO_COMPLEX_UNROLL_ITER2(N, M) \
|
#define MICRO_COMPLEX_UNROLL_ITER2(N, M) \
|
||||||
gemm_complex_unrolled_iteration<N + (M ? 1 : 0), Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, M ? M : accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlphaReal, pAlphaImag, pMask); \
|
gemm_complex_unrolled_iteration<N + (M ? 1 : 0), Scalar, Packet, Packetc, DataMapper, accRows, accCols, M ? M : accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlphaReal, pAlphaImag, pMask); \
|
||||||
if (M) return;
|
if (M) return;
|
||||||
|
|
||||||
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
||||||
EIGEN_ALWAYS_INLINE void gemm_complex_cols(
|
EIGEN_ALWAYS_INLINE void gemm_complex_cols(
|
||||||
const DataMapper& res,
|
const DataMapper& res,
|
||||||
const Scalar* blockA,
|
const Scalar* blockA,
|
||||||
@ -2115,14 +2115,14 @@ EIGEN_ALWAYS_INLINE void gemm_complex_cols(
|
|||||||
|
|
||||||
if(remaining_rows > 0)
|
if(remaining_rows > 0)
|
||||||
{
|
{
|
||||||
gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask);
|
gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#define MICRO_COMPLEX_EXTRA_COLS(N) \
|
#define MICRO_COMPLEX_EXTRA_COLS(N) \
|
||||||
gemm_complex_cols<Scalar, Packet, Packetc, DataMapper, Index, N, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask);
|
gemm_complex_cols<Scalar, Packet, Packetc, DataMapper, N, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask);
|
||||||
|
|
||||||
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
||||||
EIGEN_STRONG_INLINE void gemm_complex_extra_cols(
|
EIGEN_STRONG_INLINE void gemm_complex_extra_cols(
|
||||||
const DataMapper& res,
|
const DataMapper& res,
|
||||||
const Scalar* blockA,
|
const Scalar* blockA,
|
||||||
@ -2143,7 +2143,7 @@ EIGEN_STRONG_INLINE void gemm_complex_extra_cols(
|
|||||||
MICRO_EXTRA(MICRO_COMPLEX_EXTRA_COLS, cols-col, true)
|
MICRO_EXTRA(MICRO_COMPLEX_EXTRA_COLS, cols-col, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Index, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
||||||
EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
|
EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
|
||||||
{
|
{
|
||||||
const Index remaining_rows = rows % accCols;
|
const Index remaining_rows = rows % accCols;
|
||||||
@ -2161,12 +2161,12 @@ EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* bl
|
|||||||
Index col = 0;
|
Index col = 0;
|
||||||
for(; col + accRows <= cols; col += accRows)
|
for(; col + accRows <= cols; col += accRows)
|
||||||
{
|
{
|
||||||
gemm_complex_cols<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask);
|
gemm_complex_cols<Scalar, Packet, Packetc, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (col != cols)
|
if (col != cols)
|
||||||
{
|
{
|
||||||
gemm_complex_extra_cols<Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
|
gemm_complex_extra_cols<Scalar, Packet, Packetc, DataMapper, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2189,7 +2189,7 @@ template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Pac
|
|||||||
void gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
|
void gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
|
||||||
::operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
|
::operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
|
||||||
{
|
{
|
||||||
dhs_pack<double, Index, DataMapper, Packet2d, ColMajor, PanelMode, true> pack;
|
dhs_pack<double, DataMapper, Packet2d, ColMajor, PanelMode, true> pack;
|
||||||
pack(blockA, lhs, depth, rows, stride, offset);
|
pack(blockA, lhs, depth, rows, stride, offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2203,7 +2203,7 @@ template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Pac
|
|||||||
void gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
|
void gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
|
||||||
::operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
|
::operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
|
||||||
{
|
{
|
||||||
dhs_pack<double, Index, DataMapper, Packet2d, RowMajor, PanelMode, true> pack;
|
dhs_pack<double, DataMapper, Packet2d, RowMajor, PanelMode, true> pack;
|
||||||
pack(blockA, lhs, depth, rows, stride, offset);
|
pack(blockA, lhs, depth, rows, stride, offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2218,7 +2218,7 @@ template<typename Index, typename DataMapper, int nr, bool Conjugate, bool Panel
|
|||||||
void gemm_pack_rhs<double, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
|
void gemm_pack_rhs<double, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
|
||||||
::operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
|
::operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
|
||||||
{
|
{
|
||||||
dhs_pack<double, Index, DataMapper, Packet2d, ColMajor, PanelMode, false> pack;
|
dhs_pack<double, DataMapper, Packet2d, ColMajor, PanelMode, false> pack;
|
||||||
pack(blockB, rhs, depth, cols, stride, offset);
|
pack(blockB, rhs, depth, cols, stride, offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2232,7 +2232,7 @@ template<typename Index, typename DataMapper, int nr, bool Conjugate, bool Panel
|
|||||||
void gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
|
void gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
|
||||||
::operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
|
::operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
|
||||||
{
|
{
|
||||||
dhs_pack<double, Index, DataMapper, Packet2d, RowMajor, PanelMode, false> pack;
|
dhs_pack<double, DataMapper, Packet2d, RowMajor, PanelMode, false> pack;
|
||||||
pack(blockB, rhs, depth, cols, stride, offset);
|
pack(blockB, rhs, depth, cols, stride, offset);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
@ -2247,7 +2247,7 @@ template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Pac
|
|||||||
void gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
|
void gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
|
||||||
::operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
|
::operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
|
||||||
{
|
{
|
||||||
dhs_pack<float, Index, DataMapper, Packet4f, RowMajor, PanelMode, true> pack;
|
dhs_pack<float, DataMapper, Packet4f, RowMajor, PanelMode, true> pack;
|
||||||
pack(blockA, lhs, depth, rows, stride, offset);
|
pack(blockA, lhs, depth, rows, stride, offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2261,7 +2261,7 @@ template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Pac
|
|||||||
void gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
|
void gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
|
||||||
::operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
|
::operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
|
||||||
{
|
{
|
||||||
dhs_pack<float, Index, DataMapper, Packet4f, ColMajor, PanelMode, true> pack;
|
dhs_pack<float, DataMapper, Packet4f, ColMajor, PanelMode, true> pack;
|
||||||
pack(blockA, lhs, depth, rows, stride, offset);
|
pack(blockA, lhs, depth, rows, stride, offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2275,7 +2275,7 @@ template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Pac
|
|||||||
void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
|
void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
|
||||||
::operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
|
::operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
|
||||||
{
|
{
|
||||||
dhs_cpack<float, Index, DataMapper, Packet4f, Packet2cf, RowMajor, Conjugate, PanelMode, true> pack;
|
dhs_cpack<float, DataMapper, Packet4f, Packet2cf, RowMajor, Conjugate, PanelMode, true> pack;
|
||||||
pack(blockA, lhs, depth, rows, stride, offset);
|
pack(blockA, lhs, depth, rows, stride, offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2289,7 +2289,7 @@ template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Pac
|
|||||||
void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
|
void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
|
||||||
::operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
|
::operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
|
||||||
{
|
{
|
||||||
dhs_cpack<float, Index, DataMapper, Packet4f, Packet2cf, ColMajor, Conjugate, PanelMode, true> pack;
|
dhs_cpack<float, DataMapper, Packet4f, Packet2cf, ColMajor, Conjugate, PanelMode, true> pack;
|
||||||
pack(blockA, lhs, depth, rows, stride, offset);
|
pack(blockA, lhs, depth, rows, stride, offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2304,7 +2304,7 @@ template<typename Index, typename DataMapper, int nr, bool Conjugate, bool Panel
|
|||||||
void gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
|
void gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
|
||||||
::operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
|
::operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
|
||||||
{
|
{
|
||||||
dhs_pack<float, Index, DataMapper, Packet4f, ColMajor, PanelMode, false> pack;
|
dhs_pack<float, DataMapper, Packet4f, ColMajor, PanelMode, false> pack;
|
||||||
pack(blockB, rhs, depth, cols, stride, offset);
|
pack(blockB, rhs, depth, cols, stride, offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2318,7 +2318,7 @@ template<typename Index, typename DataMapper, int nr, bool Conjugate, bool Panel
|
|||||||
void gemm_pack_rhs<float, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
|
void gemm_pack_rhs<float, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
|
||||||
::operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
|
::operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
|
||||||
{
|
{
|
||||||
dhs_pack<float, Index, DataMapper, Packet4f, RowMajor, PanelMode, false> pack;
|
dhs_pack<float, DataMapper, Packet4f, RowMajor, PanelMode, false> pack;
|
||||||
pack(blockB, rhs, depth, cols, stride, offset);
|
pack(blockB, rhs, depth, cols, stride, offset);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
@ -2333,7 +2333,7 @@ template<typename Index, typename DataMapper, int nr, bool Conjugate, bool Panel
|
|||||||
void gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
|
void gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
|
||||||
::operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
|
::operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
|
||||||
{
|
{
|
||||||
dhs_cpack<float, Index, DataMapper, Packet4f, Packet2cf, ColMajor, Conjugate, PanelMode, false> pack;
|
dhs_cpack<float, DataMapper, Packet4f, Packet2cf, ColMajor, Conjugate, PanelMode, false> pack;
|
||||||
pack(blockB, rhs, depth, cols, stride, offset);
|
pack(blockB, rhs, depth, cols, stride, offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2347,7 +2347,7 @@ template<typename Index, typename DataMapper, int nr, bool Conjugate, bool Panel
|
|||||||
void gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
|
void gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
|
||||||
::operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
|
::operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
|
||||||
{
|
{
|
||||||
dhs_cpack<float, Index, DataMapper, Packet4f, Packet2cf, RowMajor, Conjugate, PanelMode, false> pack;
|
dhs_cpack<float, DataMapper, Packet4f, Packet2cf, RowMajor, Conjugate, PanelMode, false> pack;
|
||||||
pack(blockB, rhs, depth, cols, stride, offset);
|
pack(blockB, rhs, depth, cols, stride, offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2361,7 +2361,7 @@ template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Pac
|
|||||||
void gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
|
void gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
|
||||||
::operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
|
::operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
|
||||||
{
|
{
|
||||||
dhs_cpack<double, Index, DataMapper, Packet2d, Packet1cd, RowMajor, Conjugate, PanelMode, true> pack;
|
dhs_cpack<double, DataMapper, Packet2d, Packet1cd, RowMajor, Conjugate, PanelMode, true> pack;
|
||||||
pack(blockA, lhs, depth, rows, stride, offset);
|
pack(blockA, lhs, depth, rows, stride, offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2375,7 +2375,7 @@ template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Pac
|
|||||||
void gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
|
void gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
|
||||||
::operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
|
::operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
|
||||||
{
|
{
|
||||||
dhs_cpack<double, Index, DataMapper, Packet2d, Packet1cd, ColMajor, Conjugate, PanelMode, true> pack;
|
dhs_cpack<double, DataMapper, Packet2d, Packet1cd, ColMajor, Conjugate, PanelMode, true> pack;
|
||||||
pack(blockA, lhs, depth, rows, stride, offset);
|
pack(blockA, lhs, depth, rows, stride, offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2389,7 +2389,7 @@ template<typename Index, typename DataMapper, int nr, bool Conjugate, bool Panel
|
|||||||
void gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
|
void gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
|
||||||
::operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
|
::operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
|
||||||
{
|
{
|
||||||
dhs_cpack<double, Index, DataMapper, Packet2d, Packet1cd, ColMajor, Conjugate, PanelMode, false> pack;
|
dhs_cpack<double, DataMapper, Packet2d, Packet1cd, ColMajor, Conjugate, PanelMode, false> pack;
|
||||||
pack(blockB, rhs, depth, cols, stride, offset);
|
pack(blockB, rhs, depth, cols, stride, offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2403,7 +2403,7 @@ template<typename Index, typename DataMapper, int nr, bool Conjugate, bool Panel
|
|||||||
void gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
|
void gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
|
||||||
::operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
|
::operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
|
||||||
{
|
{
|
||||||
dhs_cpack<double, Index, DataMapper, Packet2d, Packet1cd, RowMajor, Conjugate, PanelMode, false> pack;
|
dhs_cpack<double, DataMapper, Packet2d, Packet1cd, RowMajor, Conjugate, PanelMode, false> pack;
|
||||||
pack(blockB, rhs, depth, cols, stride, offset);
|
pack(blockB, rhs, depth, cols, stride, offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2431,16 +2431,16 @@ void gebp_kernel<float, float, Index, DataMapper, mr, nr, ConjugateLhs, Conjugat
|
|||||||
|
|
||||||
#if defined(EIGEN_ALTIVEC_MMA_ONLY)
|
#if defined(EIGEN_ALTIVEC_MMA_ONLY)
|
||||||
//generate with MMA only
|
//generate with MMA only
|
||||||
gemm_function = &Eigen::internal::gemmMMA<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
|
gemm_function = &Eigen::internal::gemmMMA<float, Packet, RhsPacket, DataMapper, accRows, accCols>;
|
||||||
#elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
|
#elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
|
||||||
if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
|
if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
|
||||||
gemm_function = &Eigen::internal::gemmMMA<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
|
gemm_function = &Eigen::internal::gemmMMA<float, Packet, RhsPacket, DataMapper, accRows, accCols>;
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
gemm_function = &Eigen::internal::gemm<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
|
gemm_function = &Eigen::internal::gemm<float, Packet, RhsPacket, DataMapper, accRows, accCols>;
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
gemm_function = &Eigen::internal::gemm<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
|
gemm_function = &Eigen::internal::gemm<float, Packet, RhsPacket, DataMapper, accRows, accCols>;
|
||||||
#endif
|
#endif
|
||||||
gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
|
gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
|
||||||
}
|
}
|
||||||
@ -2470,16 +2470,16 @@ void gebp_kernel<std::complex<float>, std::complex<float>, Index, DataMapper, mr
|
|||||||
|
|
||||||
#if defined(EIGEN_ALTIVEC_MMA_ONLY)
|
#if defined(EIGEN_ALTIVEC_MMA_ONLY)
|
||||||
//generate with MMA only
|
//generate with MMA only
|
||||||
gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
|
gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
|
||||||
#elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
|
#elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
|
||||||
if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
|
if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
|
||||||
gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
|
gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
|
gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
|
gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
|
||||||
#endif
|
#endif
|
||||||
gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
|
gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
|
||||||
}
|
}
|
||||||
@ -2508,16 +2508,16 @@ void gebp_kernel<float, std::complex<float>, Index, DataMapper, mr, nr, Conjugat
|
|||||||
Index, Index, Index, std::complex<float>, Index, Index, Index, Index);
|
Index, Index, Index, std::complex<float>, Index, Index, Index, Index);
|
||||||
#if defined(EIGEN_ALTIVEC_MMA_ONLY)
|
#if defined(EIGEN_ALTIVEC_MMA_ONLY)
|
||||||
//generate with MMA only
|
//generate with MMA only
|
||||||
gemm_function = &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
|
gemm_function = &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
|
||||||
#elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
|
#elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
|
||||||
if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
|
if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
|
||||||
gemm_function = &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
|
gemm_function = &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
gemm_function = &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
|
gemm_function = &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
gemm_function = &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
|
gemm_function = &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
|
||||||
#endif
|
#endif
|
||||||
gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
|
gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
|
||||||
}
|
}
|
||||||
@ -2546,16 +2546,16 @@ void gebp_kernel<std::complex<float>, float, Index, DataMapper, mr, nr, Conjugat
|
|||||||
Index, Index, Index, std::complex<float>, Index, Index, Index, Index);
|
Index, Index, Index, std::complex<float>, Index, Index, Index, Index);
|
||||||
#if defined(EIGEN_ALTIVEC_MMA_ONLY)
|
#if defined(EIGEN_ALTIVEC_MMA_ONLY)
|
||||||
//generate with MMA only
|
//generate with MMA only
|
||||||
gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
|
gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
|
||||||
#elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
|
#elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
|
||||||
if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
|
if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
|
||||||
gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
|
gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
|
gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
|
gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
|
||||||
#endif
|
#endif
|
||||||
gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
|
gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
|
||||||
}
|
}
|
||||||
@ -2583,16 +2583,16 @@ void gebp_kernel<double, double, Index, DataMapper, mr, nr, ConjugateLhs, Conjug
|
|||||||
|
|
||||||
#if defined(EIGEN_ALTIVEC_MMA_ONLY)
|
#if defined(EIGEN_ALTIVEC_MMA_ONLY)
|
||||||
//generate with MMA only
|
//generate with MMA only
|
||||||
gemm_function = &Eigen::internal::gemmMMA<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
|
gemm_function = &Eigen::internal::gemmMMA<double, Packet, RhsPacket, DataMapper, accRows, accCols>;
|
||||||
#elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
|
#elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
|
||||||
if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
|
if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
|
||||||
gemm_function = &Eigen::internal::gemmMMA<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
|
gemm_function = &Eigen::internal::gemmMMA<double, Packet, RhsPacket, DataMapper, accRows, accCols>;
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
gemm_function = &Eigen::internal::gemm<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
|
gemm_function = &Eigen::internal::gemm<double, Packet, RhsPacket, DataMapper, accRows, accCols>;
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
gemm_function = &Eigen::internal::gemm<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
|
gemm_function = &Eigen::internal::gemm<double, Packet, RhsPacket, DataMapper, accRows, accCols>;
|
||||||
#endif
|
#endif
|
||||||
gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
|
gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
|
||||||
}
|
}
|
||||||
@ -2621,16 +2621,16 @@ void gebp_kernel<std::complex<double>, std::complex<double>, Index, DataMapper,
|
|||||||
Index, Index, Index, std::complex<double>, Index, Index, Index, Index);
|
Index, Index, Index, std::complex<double>, Index, Index, Index, Index);
|
||||||
#if defined(EIGEN_ALTIVEC_MMA_ONLY)
|
#if defined(EIGEN_ALTIVEC_MMA_ONLY)
|
||||||
//generate with MMA only
|
//generate with MMA only
|
||||||
gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
|
gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
|
||||||
#elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
|
#elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
|
||||||
if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
|
if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
|
||||||
gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
|
gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
|
gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
|
gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
|
||||||
#endif
|
#endif
|
||||||
gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
|
gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
|
||||||
}
|
}
|
||||||
@ -2659,16 +2659,16 @@ void gebp_kernel<std::complex<double>, double, Index, DataMapper, mr, nr, Conjug
|
|||||||
Index, Index, Index, std::complex<double>, Index, Index, Index, Index);
|
Index, Index, Index, std::complex<double>, Index, Index, Index, Index);
|
||||||
#if defined(EIGEN_ALTIVEC_MMA_ONLY)
|
#if defined(EIGEN_ALTIVEC_MMA_ONLY)
|
||||||
//generate with MMA only
|
//generate with MMA only
|
||||||
gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
|
gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
|
||||||
#elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
|
#elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
|
||||||
if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
|
if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
|
||||||
gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
|
gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
|
gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
|
gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
|
||||||
#endif
|
#endif
|
||||||
gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
|
gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
|
||||||
}
|
}
|
||||||
@ -2697,16 +2697,16 @@ void gebp_kernel<double, std::complex<double>, Index, DataMapper, mr, nr, Conjug
|
|||||||
Index, Index, Index, std::complex<double>, Index, Index, Index, Index);
|
Index, Index, Index, std::complex<double>, Index, Index, Index, Index);
|
||||||
#if defined(EIGEN_ALTIVEC_MMA_ONLY)
|
#if defined(EIGEN_ALTIVEC_MMA_ONLY)
|
||||||
//generate with MMA only
|
//generate with MMA only
|
||||||
gemm_function = &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
|
gemm_function = &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
|
||||||
#elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
|
#elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
|
||||||
if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
|
if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
|
||||||
gemm_function = &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
|
gemm_function = &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
gemm_function = &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
|
gemm_function = &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
gemm_function = &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
|
gemm_function = &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
|
||||||
#endif
|
#endif
|
||||||
gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
|
gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
|
||||||
}
|
}
|
||||||
|
@ -11,7 +11,7 @@ namespace Eigen {
|
|||||||
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
|
||||||
template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
|
template<typename Scalar, typename Packet, typename DataMapper, const Index accRows, const Index accCols>
|
||||||
EIGEN_ALWAYS_INLINE void gemm_extra_row(
|
EIGEN_ALWAYS_INLINE void gemm_extra_row(
|
||||||
const DataMapper& res,
|
const DataMapper& res,
|
||||||
const Scalar* lhs_base,
|
const Scalar* lhs_base,
|
||||||
@ -26,7 +26,7 @@ EIGEN_ALWAYS_INLINE void gemm_extra_row(
|
|||||||
const Packet& pAlpha,
|
const Packet& pAlpha,
|
||||||
const Packet& pMask);
|
const Packet& pMask);
|
||||||
|
|
||||||
template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
template<typename Scalar, typename Packet, typename DataMapper, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
||||||
EIGEN_STRONG_INLINE void gemm_extra_cols(
|
EIGEN_STRONG_INLINE void gemm_extra_cols(
|
||||||
const DataMapper& res,
|
const DataMapper& res,
|
||||||
const Scalar* blockA,
|
const Scalar* blockA,
|
||||||
@ -43,10 +43,10 @@ EIGEN_STRONG_INLINE void gemm_extra_cols(
|
|||||||
const Packet& pAlpha,
|
const Packet& pAlpha,
|
||||||
const Packet& pMask);
|
const Packet& pMask);
|
||||||
|
|
||||||
template<typename Packet, typename Index>
|
template<typename Packet>
|
||||||
EIGEN_ALWAYS_INLINE Packet bmask(const Index remaining_rows);
|
EIGEN_ALWAYS_INLINE Packet bmask(const Index remaining_rows);
|
||||||
|
|
||||||
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
||||||
EIGEN_ALWAYS_INLINE void gemm_complex_extra_row(
|
EIGEN_ALWAYS_INLINE void gemm_complex_extra_row(
|
||||||
const DataMapper& res,
|
const DataMapper& res,
|
||||||
const Scalar* lhs_base,
|
const Scalar* lhs_base,
|
||||||
@ -62,7 +62,7 @@ EIGEN_ALWAYS_INLINE void gemm_complex_extra_row(
|
|||||||
const Packet& pAlphaImag,
|
const Packet& pAlphaImag,
|
||||||
const Packet& pMask);
|
const Packet& pMask);
|
||||||
|
|
||||||
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
||||||
EIGEN_STRONG_INLINE void gemm_complex_extra_cols(
|
EIGEN_STRONG_INLINE void gemm_complex_extra_cols(
|
||||||
const DataMapper& res,
|
const DataMapper& res,
|
||||||
const Scalar* blockA,
|
const Scalar* blockA,
|
||||||
@ -83,10 +83,10 @@ EIGEN_STRONG_INLINE void gemm_complex_extra_cols(
|
|||||||
template<typename Packet>
|
template<typename Packet>
|
||||||
EIGEN_ALWAYS_INLINE Packet ploadLhs(const __UNPACK_TYPE__(Packet)* lhs);
|
EIGEN_ALWAYS_INLINE Packet ploadLhs(const __UNPACK_TYPE__(Packet)* lhs);
|
||||||
|
|
||||||
template<typename DataMapper, typename Packet, typename Index, const Index accCols, int StorageOrder, bool Complex, int N, bool full = true>
|
template<typename DataMapper, typename Packet, const Index accCols, int StorageOrder, bool Complex, int N, bool full = true>
|
||||||
EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,N*(Complex?2:1)>& acc, const DataMapper& res, Index row, Index col);
|
EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,N*(Complex?2:1)>& acc, const DataMapper& res, Index row, Index col);
|
||||||
|
|
||||||
template<typename DataMapper, typename Packet, typename Index, int N>
|
template<typename DataMapper, typename Packet, int N>
|
||||||
EIGEN_ALWAYS_INLINE void bstore(PacketBlock<Packet,N>& acc, const DataMapper& res, Index row);
|
EIGEN_ALWAYS_INLINE void bstore(PacketBlock<Packet,N>& acc, const DataMapper& res, Index row);
|
||||||
|
|
||||||
template<typename Packet, int N, bool mask>
|
template<typename Packet, int N, bool mask>
|
||||||
|
@ -39,30 +39,30 @@ EIGEN_ALWAYS_INLINE void bsetzeroMMA(__vector_quad* acc)
|
|||||||
__builtin_mma_xxsetaccz(acc);
|
__builtin_mma_xxsetaccz(acc);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename DataMapper, typename Index, typename Packet, const Index accCols, const Index accCols2>
|
template<typename DataMapper, typename Packet, const Index accCols, const Index accCols2>
|
||||||
EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, const DataMapper& data, const Packet& alpha, const Packet& pMask, __vector_quad* acc)
|
EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, const DataMapper& data, const Packet& alpha, const Packet& pMask, __vector_quad* acc)
|
||||||
{
|
{
|
||||||
PacketBlock<Packet, 4> result;
|
PacketBlock<Packet, 4> result;
|
||||||
__builtin_mma_disassemble_acc(&result.packet, acc);
|
__builtin_mma_disassemble_acc(&result.packet, acc);
|
||||||
|
|
||||||
PacketBlock<Packet, 4> tRes;
|
PacketBlock<Packet, 4> tRes;
|
||||||
bload<DataMapper, Packet, Index, 0, ColMajor, false, 4>(tRes, data, i, 0);
|
bload<DataMapper, Packet, 0, ColMajor, false, 4>(tRes, data, i, 0);
|
||||||
|
|
||||||
bscale<Packet, 4, (accCols != accCols2)>(tRes, result, alpha, pMask);
|
bscale<Packet, 4, (accCols != accCols2)>(tRes, result, alpha, pMask);
|
||||||
|
|
||||||
bstore<DataMapper, Packet, Index, 4>(tRes, data, i);
|
bstore<DataMapper, Packet, 4>(tRes, data, i);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename DataMapper, typename Index, typename Packet, typename Packetc, const Index accCols, const Index accCols2>
|
template<typename DataMapper, typename Packet, typename Packetc, const Index accCols, const Index accCols2>
|
||||||
EIGEN_ALWAYS_INLINE void storeComplexAccumulator(Index i, const DataMapper& data, const Packet& alphaReal, const Packet& alphaImag, const Packet& pMask, __vector_quad* accReal, __vector_quad* accImag)
|
EIGEN_ALWAYS_INLINE void storeComplexAccumulator(Index i, const DataMapper& data, const Packet& alphaReal, const Packet& alphaImag, const Packet& pMask, __vector_quad* accReal, __vector_quad* accImag)
|
||||||
{
|
{
|
||||||
const bool full = (accCols2 > accColsC);
|
constexpr bool full = (accCols2 > accColsC);
|
||||||
PacketBlock<Packet, 4> resultReal, resultImag;
|
PacketBlock<Packet, 4> resultReal, resultImag;
|
||||||
__builtin_mma_disassemble_acc(&resultReal.packet, accReal);
|
__builtin_mma_disassemble_acc(&resultReal.packet, accReal);
|
||||||
__builtin_mma_disassemble_acc(&resultImag.packet, accImag);
|
__builtin_mma_disassemble_acc(&resultImag.packet, accImag);
|
||||||
|
|
||||||
PacketBlock<Packetc, 8> tRes;
|
PacketBlock<Packetc, 8> tRes;
|
||||||
bload<DataMapper, Packetc, Index, accColsC, ColMajor, true, 4, full>(tRes, data, i, 0);
|
bload<DataMapper, Packetc, accColsC, ColMajor, true, 4, full>(tRes, data, i, 0);
|
||||||
|
|
||||||
PacketBlock<Packet, 4> taccReal, taccImag;
|
PacketBlock<Packet, 4> taccReal, taccImag;
|
||||||
bscalec<Packet, 4, (accCols != accCols2)>(resultReal, resultImag, alphaReal, alphaImag, taccReal, taccImag, pMask);
|
bscalec<Packet, 4, (accCols != accCols2)>(resultReal, resultImag, alphaReal, alphaImag, taccReal, taccImag, pMask);
|
||||||
@ -70,9 +70,9 @@ EIGEN_ALWAYS_INLINE void storeComplexAccumulator(Index i, const DataMapper& data
|
|||||||
PacketBlock<Packetc, 4> acc1, acc2;
|
PacketBlock<Packetc, 4> acc1, acc2;
|
||||||
bcouple<Packet, Packetc, 4, full>(taccReal, taccImag, tRes, acc1, acc2);
|
bcouple<Packet, Packetc, 4, full>(taccReal, taccImag, tRes, acc1, acc2);
|
||||||
|
|
||||||
bstore<DataMapper, Packetc, Index, 4>(acc1, data, i);
|
bstore<DataMapper, Packetc, 4>(acc1, data, i);
|
||||||
if (full) {
|
if (full) {
|
||||||
bstore<DataMapper, Packetc, Index, 4>(acc2, data, i + accColsC);
|
bstore<DataMapper, Packetc, 4>(acc2, data, i + accColsC);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -163,13 +163,13 @@ EIGEN_ALWAYS_INLINE void ploadLhsMMA(const double* lhs, __vector_pair& lhsV)
|
|||||||
|
|
||||||
#define MICRO_MMA_WORK_ONE(iter, type, peel) \
|
#define MICRO_MMA_WORK_ONE(iter, type, peel) \
|
||||||
if (unroll_factor > iter) { \
|
if (unroll_factor > iter) { \
|
||||||
pgerMMA<Packet, type, false>(&accZero##iter, rhsV##peel, lhsV##iter); \
|
pgerMMA<Packet, type, false>(&accZero##iter, rhsV[peel], lhsV##iter); \
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef VECTOR_PAIR_LOADS_LHS
|
#ifdef VECTOR_PAIR_LOADS_LHS
|
||||||
#define MICRO_MMA_WORK_TWO(iter, type, peel) \
|
#define MICRO_MMA_WORK_TWO(iter, type, peel) \
|
||||||
if (unroll_factor > iter) { \
|
if (unroll_factor > iter) { \
|
||||||
pgerMMA<Packet, type, false>(&accZero##iter, rhsV##peel, lhsV2##iter.packet[peel & 1]); \
|
pgerMMA<Packet, type, false>(&accZero##iter, rhsV[peel], lhsV2##iter.packet[peel & 1]); \
|
||||||
}
|
}
|
||||||
|
|
||||||
#define MICRO_MMA_LOAD1_TWO(lhs_ptr, iter) \
|
#define MICRO_MMA_LOAD1_TWO(lhs_ptr, iter) \
|
||||||
@ -195,16 +195,14 @@ EIGEN_ALWAYS_INLINE void ploadLhsMMA(const double* lhs, __vector_pair& lhsV)
|
|||||||
#define MICRO_MMA_TYPE_PEEL(funcw, funcl, type, peel) \
|
#define MICRO_MMA_TYPE_PEEL(funcw, funcl, type, peel) \
|
||||||
if (PEEL_MMA > peel) { \
|
if (PEEL_MMA > peel) { \
|
||||||
Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
|
Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
|
||||||
ploadRhsMMA(rhs_ptr + (accRows * peel), rhsV##peel); \
|
ploadRhsMMA(rhs_ptr + (accRows * peel), rhsV[peel]); \
|
||||||
MICRO_MMA_UNROLL(funcl) \
|
MICRO_MMA_UNROLL(funcl) \
|
||||||
MICRO_MMA_WORK(funcw, type, peel) \
|
MICRO_MMA_WORK(funcw, type, peel) \
|
||||||
} else { \
|
|
||||||
EIGEN_UNUSED_VARIABLE(rhsV##peel); \
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifndef VECTOR_PAIR_LOADS_LHS
|
#ifndef VECTOR_PAIR_LOADS_LHS
|
||||||
#define MICRO_MMA_UNROLL_TYPE_PEEL(funcw, funcl, type) \
|
#define MICRO_MMA_UNROLL_TYPE_PEEL(funcw, funcl, type) \
|
||||||
type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7; \
|
type rhsV[8]; \
|
||||||
MICRO_MMA_TYPE_PEEL(funcw,funcl,type,0) MICRO_MMA_TYPE_PEEL(funcw,funcl,type,1) \
|
MICRO_MMA_TYPE_PEEL(funcw,funcl,type,0) MICRO_MMA_TYPE_PEEL(funcw,funcl,type,1) \
|
||||||
MICRO_MMA_TYPE_PEEL(funcw,funcl,type,2) MICRO_MMA_TYPE_PEEL(funcw,funcl,type,3) \
|
MICRO_MMA_TYPE_PEEL(funcw,funcl,type,2) MICRO_MMA_TYPE_PEEL(funcw,funcl,type,3) \
|
||||||
MICRO_MMA_TYPE_PEEL(funcw,funcl,type,4) MICRO_MMA_TYPE_PEEL(funcw,funcl,type,5) \
|
MICRO_MMA_TYPE_PEEL(funcw,funcl,type,4) MICRO_MMA_TYPE_PEEL(funcw,funcl,type,5) \
|
||||||
@ -214,17 +212,25 @@ EIGEN_ALWAYS_INLINE void ploadLhsMMA(const double* lhs, __vector_pair& lhsV)
|
|||||||
if (PEEL_MMA > peel2) { \
|
if (PEEL_MMA > peel2) { \
|
||||||
PacketBlock<Packet,2> lhsV20, lhsV21, lhsV22, lhsV23, lhsV24, lhsV25, lhsV26, lhsV27; \
|
PacketBlock<Packet,2> lhsV20, lhsV21, lhsV22, lhsV23, lhsV24, lhsV25, lhsV26, lhsV27; \
|
||||||
__vector_pair plhsV0, plhsV1, plhsV2, plhsV3, plhsV4, plhsV5, plhsV6, plhsV7; \
|
__vector_pair plhsV0, plhsV1, plhsV2, plhsV3, plhsV4, plhsV5, plhsV6, plhsV7; \
|
||||||
ploadRhsMMA(rhs_ptr + (accRows * peel1), rhsV##peel1); \
|
if (sizeof(type) == 16) { \
|
||||||
ploadRhsMMA(rhs_ptr + (accRows * peel2), rhsV##peel2); \
|
ploadRhsMMA(reinterpret_cast<const double*>(rhs_ptr + (accRows * peel1)), prhsV##peel1); \
|
||||||
|
__builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&rhsV[peel1]), &prhsV##peel1); \
|
||||||
|
} else { \
|
||||||
|
EIGEN_UNUSED_VARIABLE(prhsV##peel1); \
|
||||||
|
ploadRhsMMA(rhs_ptr + (accRows * peel1), rhsV[peel1]); \
|
||||||
|
ploadRhsMMA(rhs_ptr + (accRows * peel2), rhsV[peel2]); \
|
||||||
|
} \
|
||||||
MICRO_MMA_UNROLL(funcl2) \
|
MICRO_MMA_UNROLL(funcl2) \
|
||||||
MICRO_MMA_WORK(funcw2, type, peel1) \
|
MICRO_MMA_WORK(funcw2, type, peel1) \
|
||||||
MICRO_MMA_WORK(funcw2, type, peel2) \
|
MICRO_MMA_WORK(funcw2, type, peel2) \
|
||||||
} else { \
|
} else { \
|
||||||
|
EIGEN_UNUSED_VARIABLE(prhsV##peel1); \
|
||||||
MICRO_MMA_TYPE_PEEL(funcw1, funcl1, type, peel1) \
|
MICRO_MMA_TYPE_PEEL(funcw1, funcl1, type, peel1) \
|
||||||
}
|
}
|
||||||
|
|
||||||
#define MICRO_MMA_UNROLL_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type) \
|
#define MICRO_MMA_UNROLL_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type) \
|
||||||
type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7; \
|
type rhsV[8]; \
|
||||||
|
__vector_pair prhsV0, prhsV2, prhsV4, prhsV6; \
|
||||||
MICRO_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,0,1) \
|
MICRO_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,0,1) \
|
||||||
MICRO_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,2,3) \
|
MICRO_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,2,3) \
|
||||||
MICRO_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,4,5) \
|
MICRO_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,4,5) \
|
||||||
@ -232,7 +238,7 @@ EIGEN_ALWAYS_INLINE void ploadLhsMMA(const double* lhs, __vector_pair& lhsV)
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#define MICRO_MMA_UNROLL_TYPE_ONE(funcw, funcl, type) \
|
#define MICRO_MMA_UNROLL_TYPE_ONE(funcw, funcl, type) \
|
||||||
type rhsV0; \
|
type rhsV[1]; \
|
||||||
MICRO_MMA_TYPE_PEEL(funcw,funcl,type,0)
|
MICRO_MMA_TYPE_PEEL(funcw,funcl,type,0)
|
||||||
|
|
||||||
#define MICRO_MMA_UNROLL_TYPE(MICRO_MMA_TYPE, size) \
|
#define MICRO_MMA_UNROLL_TYPE(MICRO_MMA_TYPE, size) \
|
||||||
@ -266,12 +272,12 @@ EIGEN_ALWAYS_INLINE void ploadLhsMMA(const double* lhs, __vector_pair& lhsV)
|
|||||||
|
|
||||||
#define MICRO_MMA_STORE_ONE(iter) \
|
#define MICRO_MMA_STORE_ONE(iter) \
|
||||||
if (unroll_factor > iter) { \
|
if (unroll_factor > iter) { \
|
||||||
storeAccumulator<DataMapper, Index, Packet, accCols, (unroll_factor != (iter + 1)) ? accCols : accCols2>(row + iter*accCols, res, pAlpha, pMask, &accZero##iter); \
|
storeAccumulator<DataMapper, Packet, accCols, (unroll_factor != (iter + 1)) ? accCols : accCols2>(row + iter*accCols, res, pAlpha, pMask, &accZero##iter); \
|
||||||
}
|
}
|
||||||
|
|
||||||
#define MICRO_MMA_STORE MICRO_MMA_UNROLL(MICRO_MMA_STORE_ONE)
|
#define MICRO_MMA_STORE MICRO_MMA_UNROLL(MICRO_MMA_STORE_ONE)
|
||||||
|
|
||||||
template<int unroll_factor, typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols, const Index accCols2>
|
template<int unroll_factor, typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, const Index accCols2>
|
||||||
EIGEN_ALWAYS_INLINE void gemm_unrolled_MMA_iteration(
|
EIGEN_ALWAYS_INLINE void gemm_unrolled_MMA_iteration(
|
||||||
const DataMapper& res,
|
const DataMapper& res,
|
||||||
const Scalar* lhs_base,
|
const Scalar* lhs_base,
|
||||||
@ -307,10 +313,10 @@ EIGEN_ALWAYS_INLINE void gemm_unrolled_MMA_iteration(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define MICRO_MMA_UNROLL_ITER2(N, M) \
|
#define MICRO_MMA_UNROLL_ITER2(N, M) \
|
||||||
gemm_unrolled_MMA_iteration<N + (M ? 1 : 0), Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols, M ? M : accCols>(res3, lhs_base, rhs_base, depth, strideA, offsetA, row, pAlpha, pMask); \
|
gemm_unrolled_MMA_iteration<N + (M ? 1 : 0), Scalar, Packet, RhsPacket, DataMapper, accRows, accCols, M ? M : accCols>(res3, lhs_base, rhs_base, depth, strideA, offsetA, row, pAlpha, pMask); \
|
||||||
if (M) return;
|
if (M) return;
|
||||||
|
|
||||||
template<typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols>
|
template<typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
|
||||||
EIGEN_ALWAYS_INLINE void gemmMMA_cols(
|
EIGEN_ALWAYS_INLINE void gemmMMA_cols(
|
||||||
const DataMapper& res,
|
const DataMapper& res,
|
||||||
const Scalar* blockA,
|
const Scalar* blockA,
|
||||||
@ -379,11 +385,11 @@ EIGEN_ALWAYS_INLINE void gemmMMA_cols(
|
|||||||
|
|
||||||
if(remaining_rows > 0)
|
if(remaining_rows > 0)
|
||||||
{
|
{
|
||||||
gemm_extra_row<Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlpha, pMask);
|
gemm_extra_row<Scalar, Packet, DataMapper, accRows, accCols>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlpha, pMask);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename Scalar, typename Index, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
|
template<typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
|
||||||
void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
|
void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
|
||||||
{
|
{
|
||||||
const Index remaining_rows = rows % accCols;
|
const Index remaining_rows = rows % accCols;
|
||||||
@ -399,12 +405,12 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
|
|||||||
Index col = 0;
|
Index col = 0;
|
||||||
for(; col + accRows <= cols; col += accRows)
|
for(; col + accRows <= cols; col += accRows)
|
||||||
{
|
{
|
||||||
gemmMMA_cols<Scalar, Packet, RhsPacket2, DataMapper, Index, accRows, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlpha, pMask);
|
gemmMMA_cols<Scalar, Packet, RhsPacket2, DataMapper, accRows, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlpha, pMask);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (col != cols)
|
if (col != cols)
|
||||||
{
|
{
|
||||||
gemm_extra_cols<Scalar, Packet, DataMapper, Index, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
|
gemm_extra_cols<Scalar, Packet, DataMapper, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -422,13 +428,13 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
|
|||||||
|
|
||||||
#define MICRO_COMPLEX_MMA_WORK_ONE(iter, type, peel) \
|
#define MICRO_COMPLEX_MMA_WORK_ONE(iter, type, peel) \
|
||||||
if (unroll_factor > iter) { \
|
if (unroll_factor > iter) { \
|
||||||
pgercMMA<Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
|
pgercMMA<Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV[peel], rhsVi[peel]); \
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef VECTOR_PAIR_LOADS_LHS
|
#ifdef VECTOR_PAIR_LOADS_LHS
|
||||||
#define MICRO_COMPLEX_MMA_WORK_TWO(iter, type, peel) \
|
#define MICRO_COMPLEX_MMA_WORK_TWO(iter, type, peel) \
|
||||||
if (unroll_factor > iter) { \
|
if (unroll_factor > iter) { \
|
||||||
pgercMMA<Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV2##iter.packet[peel & 1], lhsVi2##iter.packet[peel & 1], rhsV##peel, rhsVi##peel); \
|
pgercMMA<Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV2##iter.packet[peel & 1], lhsVi2##iter.packet[peel & 1], rhsV[peel], rhsVi[peel]); \
|
||||||
}
|
}
|
||||||
|
|
||||||
#define MICRO_COMPLEX_MMA_LOAD1_TWO(lhs_ptr, iter) \
|
#define MICRO_COMPLEX_MMA_LOAD1_TWO(lhs_ptr, iter) \
|
||||||
@ -454,23 +460,17 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
|
|||||||
if (PEEL_COMPLEX_MMA > peel) { \
|
if (PEEL_COMPLEX_MMA > peel) { \
|
||||||
Packet lhsV0, lhsV1, lhsV2, lhsV3; \
|
Packet lhsV0, lhsV1, lhsV2, lhsV3; \
|
||||||
Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3; \
|
Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3; \
|
||||||
ploadRhsMMA(rhs_ptr_real + (accRows * peel), rhsV##peel); \
|
ploadRhsMMA(rhs_ptr_real + (accRows * peel), rhsV[peel]); \
|
||||||
if(!RhsIsReal) { \
|
if(!RhsIsReal) { \
|
||||||
ploadRhsMMA(rhs_ptr_imag + (accRows * peel), rhsVi##peel); \
|
ploadRhsMMA(rhs_ptr_imag + (accRows * peel), rhsVi[peel]); \
|
||||||
} else { \
|
|
||||||
EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
|
|
||||||
} \
|
} \
|
||||||
MICRO_COMPLEX_MMA_UNROLL(funcl) \
|
MICRO_COMPLEX_MMA_UNROLL(funcl) \
|
||||||
MICRO_COMPLEX_MMA_WORK(funcw, type, peel) \
|
MICRO_COMPLEX_MMA_WORK(funcw, type, peel) \
|
||||||
} else { \
|
|
||||||
EIGEN_UNUSED_VARIABLE(rhsV##peel); \
|
|
||||||
EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifndef VECTOR_PAIR_LOADS_LHS
|
#ifndef VECTOR_PAIR_LOADS_LHS
|
||||||
#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(funcw, funcl, type) \
|
#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(funcw, funcl, type) \
|
||||||
type rhsV0, rhsV1, rhsV2, rhsV3; \
|
type rhsV[4], rhsVi[4]; \
|
||||||
type rhsVi0, rhsVi1, rhsVi2, rhsVi3; \
|
|
||||||
MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,0) MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,1) \
|
MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,0) MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,1) \
|
||||||
MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,2) MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,3)
|
MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,2) MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,3)
|
||||||
#else
|
#else
|
||||||
@ -480,31 +480,44 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
|
|||||||
PacketBlock<Packet,2> lhsVi20, lhsVi21, lhsVi22, lhsVi23; \
|
PacketBlock<Packet,2> lhsVi20, lhsVi21, lhsVi22, lhsVi23; \
|
||||||
__vector_pair plhsV0, plhsV1, plhsV2, plhsV3; \
|
__vector_pair plhsV0, plhsV1, plhsV2, plhsV3; \
|
||||||
__vector_pair plhsVi0, plhsVi1, plhsVi2, plhsVi3; \
|
__vector_pair plhsVi0, plhsVi1, plhsVi2, plhsVi3; \
|
||||||
ploadRhsMMA(rhs_ptr_real + (accRows * peel1), rhsV##peel1); \
|
if (sizeof(type) == 16) { \
|
||||||
ploadRhsMMA(rhs_ptr_real + (accRows * peel2), rhsV##peel2); \
|
ploadRhsMMA(reinterpret_cast<const double*>(rhs_ptr_real + (accRows * peel1)), prhsV##peel1); \
|
||||||
|
__builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&rhsV[peel1]), &prhsV##peel1); \
|
||||||
if(!RhsIsReal) { \
|
if(!RhsIsReal) { \
|
||||||
ploadRhsMMA(rhs_ptr_imag + (accRows * peel1), rhsVi##peel1); \
|
ploadRhsMMA(reinterpret_cast<const double*>(rhs_ptr_imag + (accRows * peel1)), prhsVi##peel1); \
|
||||||
ploadRhsMMA(rhs_ptr_imag + (accRows * peel2), rhsVi##peel2); \
|
__builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&rhsVi[peel1]), &prhsVi##peel1); \
|
||||||
} else { \
|
} else { \
|
||||||
EIGEN_UNUSED_VARIABLE(rhsVi##peel1); \
|
EIGEN_UNUSED_VARIABLE(prhsVi##peel1); \
|
||||||
EIGEN_UNUSED_VARIABLE(rhsVi##peel2); \
|
} \
|
||||||
|
} else { \
|
||||||
|
EIGEN_UNUSED_VARIABLE(prhsV##peel1); \
|
||||||
|
EIGEN_UNUSED_VARIABLE(prhsVi##peel1); \
|
||||||
|
ploadRhsMMA(rhs_ptr_real + (accRows * peel1), rhsV[peel1]); \
|
||||||
|
ploadRhsMMA(rhs_ptr_real + (accRows * peel2), rhsV[peel2]); \
|
||||||
|
if(!RhsIsReal) { \
|
||||||
|
ploadRhsMMA(rhs_ptr_imag + (accRows * peel1), rhsVi[peel1]); \
|
||||||
|
ploadRhsMMA(rhs_ptr_imag + (accRows * peel2), rhsVi[peel2]); \
|
||||||
|
} \
|
||||||
} \
|
} \
|
||||||
MICRO_COMPLEX_MMA_UNROLL(funcl2) \
|
MICRO_COMPLEX_MMA_UNROLL(funcl2) \
|
||||||
MICRO_COMPLEX_MMA_WORK(funcw2, type, peel1) \
|
MICRO_COMPLEX_MMA_WORK(funcw2, type, peel1) \
|
||||||
MICRO_COMPLEX_MMA_WORK(funcw2, type, peel2) \
|
MICRO_COMPLEX_MMA_WORK(funcw2, type, peel2) \
|
||||||
} else { \
|
} else { \
|
||||||
|
EIGEN_UNUSED_VARIABLE(prhsV##peel1); \
|
||||||
|
EIGEN_UNUSED_VARIABLE(prhsVi##peel1); \
|
||||||
MICRO_COMPLEX_MMA_TYPE_PEEL(funcw1, funcl1, type, peel1) \
|
MICRO_COMPLEX_MMA_TYPE_PEEL(funcw1, funcl1, type, peel1) \
|
||||||
}
|
}
|
||||||
|
|
||||||
#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type) \
|
#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type) \
|
||||||
type rhsV0, rhsV1, rhsV2, rhsV3; \
|
type rhsV[4], rhsVi[4]; \
|
||||||
type rhsVi0, rhsVi1, rhsVi2, rhsVi3; \
|
__vector_pair prhsV0, prhsV2; \
|
||||||
|
__vector_pair prhsVi0, prhsVi2; \
|
||||||
MICRO_COMPLEX_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,0,1) \
|
MICRO_COMPLEX_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,0,1) \
|
||||||
MICRO_COMPLEX_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,2,3)
|
MICRO_COMPLEX_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,2,3)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#define MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(funcw, funcl, type) \
|
#define MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(funcw, funcl, type) \
|
||||||
type rhsV0, rhsVi0; \
|
type rhsV[1], rhsVi[1]; \
|
||||||
MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,0)
|
MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,0)
|
||||||
|
|
||||||
#define MICRO_COMPLEX_MMA_UNROLL_TYPE(MICRO_COMPLEX_MMA_TYPE, size) \
|
#define MICRO_COMPLEX_MMA_UNROLL_TYPE(MICRO_COMPLEX_MMA_TYPE, size) \
|
||||||
@ -542,12 +555,12 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
|
|||||||
|
|
||||||
#define MICRO_COMPLEX_MMA_STORE_ONE(iter) \
|
#define MICRO_COMPLEX_MMA_STORE_ONE(iter) \
|
||||||
if (unroll_factor > iter) { \
|
if (unroll_factor > iter) { \
|
||||||
storeComplexAccumulator<DataMapper, Index, Packet, Packetc, accCols, (unroll_factor != (iter + 1)) ? accCols : accCols2>(row + iter*accCols, res, pAlphaReal, pAlphaImag, pMask, &accReal##iter, &accImag##iter); \
|
storeComplexAccumulator<DataMapper, Packet, Packetc, accCols, (unroll_factor != (iter + 1)) ? accCols : accCols2>(row + iter*accCols, res, pAlphaReal, pAlphaImag, pMask, &accReal##iter, &accImag##iter); \
|
||||||
}
|
}
|
||||||
|
|
||||||
#define MICRO_COMPLEX_MMA_STORE MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_STORE_ONE)
|
#define MICRO_COMPLEX_MMA_STORE MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_STORE_ONE)
|
||||||
|
|
||||||
template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols, const Index accCols2, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, const Index accCols2, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
||||||
EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_MMA_iteration(
|
EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_MMA_iteration(
|
||||||
const DataMapper& res,
|
const DataMapper& res,
|
||||||
const Scalar* lhs_base,
|
const Scalar* lhs_base,
|
||||||
@ -597,10 +610,10 @@ EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_MMA_iteration(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define MICRO_COMPLEX_MMA_UNROLL_ITER2(N, M) \
|
#define MICRO_COMPLEX_MMA_UNROLL_ITER2(N, M) \
|
||||||
gemm_complex_unrolled_MMA_iteration<N + (M ? 1 : 0), Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, M ? M : accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlphaReal, pAlphaImag, pMask); \
|
gemm_complex_unrolled_MMA_iteration<N + (M ? 1 : 0), Scalar, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, M ? M : accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlphaReal, pAlphaImag, pMask); \
|
||||||
if (M) return;
|
if (M) return;
|
||||||
|
|
||||||
template<typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
template<typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
||||||
EIGEN_ALWAYS_INLINE void gemmMMA_complex_cols(
|
EIGEN_ALWAYS_INLINE void gemmMMA_complex_cols(
|
||||||
const DataMapper& res,
|
const DataMapper& res,
|
||||||
const Scalar* blockA,
|
const Scalar* blockA,
|
||||||
@ -655,11 +668,11 @@ EIGEN_ALWAYS_INLINE void gemmMMA_complex_cols(
|
|||||||
|
|
||||||
if(remaining_rows > 0)
|
if(remaining_rows > 0)
|
||||||
{
|
{
|
||||||
gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask);
|
gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Index, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
||||||
void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
|
void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
|
||||||
{
|
{
|
||||||
const Index remaining_rows = rows % accCols;
|
const Index remaining_rows = rows % accCols;
|
||||||
@ -679,12 +692,12 @@ void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsS
|
|||||||
Index col = 0;
|
Index col = 0;
|
||||||
for(; col + accRows <= cols; col += accRows)
|
for(; col + accRows <= cols; col += accRows)
|
||||||
{
|
{
|
||||||
gemmMMA_complex_cols<Scalar, Packet, Packetc, RhsPacket2, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask);
|
gemmMMA_complex_cols<Scalar, Packet, Packetc, RhsPacket2, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (col != cols)
|
if (col != cols)
|
||||||
{
|
{
|
||||||
gemm_complex_extra_cols<Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
|
gemm_complex_extra_cols<Scalar, Packet, Packetc, DataMapper, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -375,7 +375,7 @@ EIGEN_ALWAYS_INLINE void pger_vecMMA_acc(__vector_quad* acc, __vector_pair& a, c
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
template<typename Index, typename LhsScalar, typename LhsMapper, typename RhsScalar, typename RhsMapper, typename ResScalar>
|
template<typename LhsScalar, typename LhsMapper, typename RhsScalar, typename RhsMapper, typename ResScalar>
|
||||||
EIGEN_STRONG_INLINE void gemv_col(
|
EIGEN_STRONG_INLINE void gemv_col(
|
||||||
Index rows, Index cols,
|
Index rows, Index cols,
|
||||||
const LhsMapper& alhs,
|
const LhsMapper& alhs,
|
||||||
@ -927,7 +927,7 @@ EIGEN_ALWAYS_INLINE void pstoreu_pmadd_complex(PResPacket& c0, AlphaData& b0, Re
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename Index, typename ScalarPacket, typename PResPacket, typename ResPacket, typename ResScalar, typename AlphaData, Index ResPacketSize, Index iter2>
|
template<typename ScalarPacket, typename PResPacket, typename ResPacket, typename ResScalar, typename AlphaData, Index ResPacketSize, Index iter2>
|
||||||
EIGEN_ALWAYS_INLINE void pstoreu_pmadd_complex(PResPacket& c0, PResPacket& c1, AlphaData& b0, ResScalar* res)
|
EIGEN_ALWAYS_INLINE void pstoreu_pmadd_complex(PResPacket& c0, PResPacket& c1, AlphaData& b0, ResScalar* res)
|
||||||
{
|
{
|
||||||
PResPacket c2 = pcplxflipconj(c0);
|
PResPacket c2 = pcplxflipconj(c0);
|
||||||
@ -953,7 +953,7 @@ EIGEN_ALWAYS_INLINE void pstoreu_pmadd_complex(PResPacket& c0, PResPacket& c1, A
|
|||||||
}
|
}
|
||||||
|
|
||||||
/** \internal load lhs packet */
|
/** \internal load lhs packet */
|
||||||
template<typename Scalar, typename LhsScalar, typename LhsMapper, typename LhsPacket, typename Index>
|
template<typename Scalar, typename LhsScalar, typename LhsMapper, typename LhsPacket>
|
||||||
EIGEN_ALWAYS_INLINE LhsPacket loadLhsPacket(LhsMapper& lhs, Index i, Index j)
|
EIGEN_ALWAYS_INLINE LhsPacket loadLhsPacket(LhsMapper& lhs, Index i, Index j)
|
||||||
{
|
{
|
||||||
if (sizeof(Scalar) == sizeof(LhsScalar)) {
|
if (sizeof(Scalar) == sizeof(LhsScalar)) {
|
||||||
@ -1337,14 +1337,14 @@ EIGEN_ALWAYS_INLINE void disassembleResults2(__vector_quad* c0, PacketBlock<Scal
|
|||||||
result0.packet[0] = tmp0;
|
result0.packet[0] = tmp0;
|
||||||
|
|
||||||
if (ConjugateLhs) {
|
if (ConjugateLhs) {
|
||||||
result0.packet[0] = convertReal(pconj2(convertComplex(result0.packet[0])));
|
result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
|
||||||
result0.packet[2] = convertReal(pconj2(convertComplex(result0.packet[2])));
|
result0.packet[2] = pconj2(convertComplex(result0.packet[2])).v;
|
||||||
} else if (ConjugateRhs) {
|
} else if (ConjugateRhs) {
|
||||||
result0.packet[1] = convertReal(pconj2(convertComplex(result0.packet[1])));
|
result0.packet[1] = pconj2(convertComplex(result0.packet[1])).v;
|
||||||
result0.packet[3] = convertReal(pconj2(convertComplex(result0.packet[3])));
|
result0.packet[3] = pconj2(convertComplex(result0.packet[3])).v;
|
||||||
} else {
|
} else {
|
||||||
result0.packet[1] = convertReal(pconjinv(convertComplex(result0.packet[1])));
|
result0.packet[1] = pconjinv(convertComplex(result0.packet[1])).v;
|
||||||
result0.packet[3] = convertReal(pconjinv(convertComplex(result0.packet[3])));
|
result0.packet[3] = pconjinv(convertComplex(result0.packet[3])).v;
|
||||||
}
|
}
|
||||||
result0.packet[0] = vec_add(result0.packet[0], result0.packet[1]);
|
result0.packet[0] = vec_add(result0.packet[0], result0.packet[1]);
|
||||||
result0.packet[2] = vec_add(result0.packet[2], result0.packet[3]);
|
result0.packet[2] = vec_add(result0.packet[2], result0.packet[3]);
|
||||||
@ -1361,19 +1361,19 @@ EIGEN_ALWAYS_INLINE void disassembleResults4(__vector_quad* c0, PacketBlock<Scal
|
|||||||
__builtin_mma_disassemble_acc(&result0.packet, c0);
|
__builtin_mma_disassemble_acc(&result0.packet, c0);
|
||||||
if (GEMV_IS_COMPLEX_COMPLEX) {
|
if (GEMV_IS_COMPLEX_COMPLEX) {
|
||||||
if (ConjugateLhs) {
|
if (ConjugateLhs) {
|
||||||
result0.packet[0] = convertReal(pconj2(convertComplex(result0.packet[0])));
|
result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
|
||||||
result0.packet[1] = convertReal(pcplxflip2(convertComplex(result0.packet[1])));
|
result0.packet[1] = pcplxflip2(convertComplex(result0.packet[1])).v;
|
||||||
} else {
|
} else {
|
||||||
if (ConjugateRhs) {
|
if (ConjugateRhs) {
|
||||||
result0.packet[1] = convertReal(pcplxconjflip(convertComplex(result0.packet[1])));
|
result0.packet[1] = pcplxconjflip(convertComplex(result0.packet[1])).v;
|
||||||
} else {
|
} else {
|
||||||
result0.packet[1] = convertReal(pcplxflipconj(convertComplex(result0.packet[1])));
|
result0.packet[1] = pcplxflipconj(convertComplex(result0.packet[1])).v;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
result0.packet[0] = vec_add(result0.packet[0], result0.packet[1]);
|
result0.packet[0] = vec_add(result0.packet[0], result0.packet[1]);
|
||||||
} else if (sizeof(LhsPacket) == sizeof(std::complex<float>)) {
|
} else if (sizeof(LhsPacket) == sizeof(std::complex<float>)) {
|
||||||
if (ConjugateLhs) {
|
if (ConjugateLhs) {
|
||||||
result0.packet[0] = convertReal(pconj2(convertComplex(result0.packet[0])));
|
result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
result0.packet[0] = vec_mergee(result0.packet[0], result0.packet[1]);
|
result0.packet[0] = vec_mergee(result0.packet[0], result0.packet[1]);
|
||||||
@ -1394,7 +1394,7 @@ EIGEN_ALWAYS_INLINE void disassembleResults(__vector_quad* c0, PacketBlock<Scala
|
|||||||
#define GEMV_GETN_COMPLEX(N) (((N) * ResPacketSize) >> 1)
|
#define GEMV_GETN_COMPLEX(N) (((N) * ResPacketSize) >> 1)
|
||||||
|
|
||||||
#define GEMV_LOADPACKET_COL_COMPLEX(iter) \
|
#define GEMV_LOADPACKET_COL_COMPLEX(iter) \
|
||||||
loadLhsPacket<Scalar, LhsScalar, LhsMapper, PLhsPacket, Index>(lhs, i + ((iter) * ResPacketSize), j)
|
loadLhsPacket<Scalar, LhsScalar, LhsMapper, PLhsPacket>(lhs, i + ((iter) * ResPacketSize), j)
|
||||||
|
|
||||||
#define GEMV_LOADPACKET_COL_COMPLEX_DATA(iter) \
|
#define GEMV_LOADPACKET_COL_COMPLEX_DATA(iter) \
|
||||||
convertReal(GEMV_LOADPACKET_COL_COMPLEX(iter))
|
convertReal(GEMV_LOADPACKET_COL_COMPLEX(iter))
|
||||||
@ -1444,7 +1444,7 @@ EIGEN_ALWAYS_INLINE void disassembleResults(__vector_quad* c0, PacketBlock<Scala
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define GEMV_LOADPAIR2_COL_COMPLEX_MMA(iter1, iter2) \
|
#define GEMV_LOADPAIR2_COL_COMPLEX_MMA(iter1, iter2) \
|
||||||
GEMV_BUILDPAIR_MMA(a##iter1, GEMV_LOADPACKET_COL_COMPLEX_DATA(iter2), GEMV_LOADPACKET_COL_COMPLEX_DATA((iter2) + 1)); \
|
GEMV_BUILDPAIR_MMA(a##iter1, GEMV_LOADPACKET_COL_COMPLEX_DATA(iter2), GEMV_LOADPACKET_COL_COMPLEX_DATA((iter2) + 1));
|
||||||
|
|
||||||
#define GEMV_LOAD2_COL_COMPLEX_MMA(iter1, iter2, iter3, N) \
|
#define GEMV_LOAD2_COL_COMPLEX_MMA(iter1, iter2, iter3, N) \
|
||||||
if (GEMV_GETN_COMPLEX(N) > iter1) { \
|
if (GEMV_GETN_COMPLEX(N) > iter1) { \
|
||||||
@ -1498,7 +1498,7 @@ EIGEN_ALWAYS_INLINE void disassembleResults(__vector_quad* c0, PacketBlock<Scala
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#define GEMV_DISASSEMBLE_COMPLEX_MMA(iter) \
|
#define GEMV_DISASSEMBLE_COMPLEX_MMA(iter) \
|
||||||
disassembleResults<Scalar, ScalarPacket, ResPacketSize, LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>(&e0##iter, result0##iter); \
|
disassembleResults<Scalar, ScalarPacket, ResPacketSize, LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>(&e0##iter, result0##iter);
|
||||||
|
|
||||||
#define GEMV_STORE_COL_COMPLEX_MMA(iter, N) \
|
#define GEMV_STORE_COL_COMPLEX_MMA(iter, N) \
|
||||||
if (GEMV_GETN_COMPLEX(N) > iter) { \
|
if (GEMV_GETN_COMPLEX(N) > iter) { \
|
||||||
@ -1520,13 +1520,13 @@ EIGEN_ALWAYS_INLINE void disassembleResults(__vector_quad* c0, PacketBlock<Scala
|
|||||||
c0##iter2 = PResPacket(result0##iter2.packet[0]); \
|
c0##iter2 = PResPacket(result0##iter2.packet[0]); \
|
||||||
if (GEMV_IS_COMPLEX_FLOAT) { \
|
if (GEMV_IS_COMPLEX_FLOAT) { \
|
||||||
c0##iter3 = PResPacket(result0##iter3.packet[0]); \
|
c0##iter3 = PResPacket(result0##iter3.packet[0]); \
|
||||||
pstoreu_pmadd_complex<Index, ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData, ResPacketSize, iter2>(c0##iter2, c0##iter3, alpha_data, res + i); \
|
pstoreu_pmadd_complex<ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData, ResPacketSize, iter2>(c0##iter2, c0##iter3, alpha_data, res + i); \
|
||||||
} else { \
|
} else { \
|
||||||
c0##iter3 = PResPacket(result0##iter2.packet[2]); \
|
c0##iter3 = PResPacket(result0##iter2.packet[2]); \
|
||||||
pstoreu_pmadd_complex<Index, ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData, ResPacketSize, iter2 << 1>(c0##iter2, c0##iter3, alpha_data, res + i); \
|
pstoreu_pmadd_complex<ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData, ResPacketSize, iter2 << 1>(c0##iter2, c0##iter3, alpha_data, res + i); \
|
||||||
c0##iter2 = PResPacket(result0##iter3.packet[0]); \
|
c0##iter2 = PResPacket(result0##iter3.packet[0]); \
|
||||||
c0##iter3 = PResPacket(result0##iter3.packet[2]); \
|
c0##iter3 = PResPacket(result0##iter3.packet[2]); \
|
||||||
pstoreu_pmadd_complex<Index, ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData, ResPacketSize, iter3 << 1>(c0##iter2, c0##iter3, alpha_data, res + i); \
|
pstoreu_pmadd_complex<ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData, ResPacketSize, iter3 << 1>(c0##iter2, c0##iter3, alpha_data, res + i); \
|
||||||
} \
|
} \
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1607,7 +1607,7 @@ EIGEN_ALWAYS_INLINE void disassembleResults(__vector_quad* c0, PacketBlock<Scala
|
|||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
template<typename Index, typename Scalar, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, bool LhsIsReal, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, bool RhsIsReal, typename ResScalar>
|
template<typename Scalar, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, bool LhsIsReal, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, bool RhsIsReal, typename ResScalar>
|
||||||
EIGEN_STRONG_INLINE void gemv_complex_col(
|
EIGEN_STRONG_INLINE void gemv_complex_col(
|
||||||
Index rows, Index cols,
|
Index rows, Index cols,
|
||||||
const LhsMapper& alhs,
|
const LhsMapper& alhs,
|
||||||
@ -1725,10 +1725,6 @@ static Packet16uc p16uc_ELEMENT_3 = { 0x0c,0x0d,0x0e,0x0f, 0x1c,0x1d,0x1e,0x1f,
|
|||||||
template<typename ResScalar, typename ResPacket>
|
template<typename ResScalar, typename ResPacket>
|
||||||
EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_real(__vector_quad* acc0, __vector_quad* acc1)
|
EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_real(__vector_quad* acc0, __vector_quad* acc1)
|
||||||
{
|
{
|
||||||
union {
|
|
||||||
ScalarBlock<ResScalar, 2> cs;
|
|
||||||
double cd;
|
|
||||||
} cc0;
|
|
||||||
PacketBlock<ResPacket, 4> result0, result1;
|
PacketBlock<ResPacket, 4> result0, result1;
|
||||||
__builtin_mma_disassemble_acc(&result0.packet, acc0);
|
__builtin_mma_disassemble_acc(&result0.packet, acc0);
|
||||||
__builtin_mma_disassemble_acc(&result1.packet, acc1);
|
__builtin_mma_disassemble_acc(&result1.packet, acc1);
|
||||||
@ -1737,20 +1733,17 @@ EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_real(__vector_quad* acc0, _
|
|||||||
result0.packet[2] = vec_mergel(result0.packet[2], result1.packet[2]);
|
result0.packet[2] = vec_mergel(result0.packet[2], result1.packet[2]);
|
||||||
result0.packet[3] = vec_perm(result0.packet[3], result1.packet[3], p16uc_ELEMENT_3);
|
result0.packet[3] = vec_perm(result0.packet[3], result1.packet[3], p16uc_ELEMENT_3);
|
||||||
result0.packet[0] = vec_add(vec_add(result0.packet[0], result0.packet[2]), vec_add(result0.packet[1], result0.packet[3]));
|
result0.packet[0] = vec_add(vec_add(result0.packet[0], result0.packet[2]), vec_add(result0.packet[1], result0.packet[3]));
|
||||||
cc0.cd = pfirst(reinterpret_cast<Packet2d>(result0.packet[0]));
|
return *reinterpret_cast<ScalarBlock<ResScalar, 2> *>(&result0.packet[0]);
|
||||||
return cc0.cs;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template<>
|
template<>
|
||||||
EIGEN_ALWAYS_INLINE ScalarBlock<double, 2> predux_real<double, Packet2d>(__vector_quad* acc0, __vector_quad* acc1)
|
EIGEN_ALWAYS_INLINE ScalarBlock<double, 2> predux_real<double, Packet2d>(__vector_quad* acc0, __vector_quad* acc1)
|
||||||
{
|
{
|
||||||
ScalarBlock<double, 2> cc0;
|
|
||||||
PacketBlock<Packet2d, 4> result0, result1;
|
PacketBlock<Packet2d, 4> result0, result1;
|
||||||
__builtin_mma_disassemble_acc(&result0.packet, acc0);
|
__builtin_mma_disassemble_acc(&result0.packet, acc0);
|
||||||
__builtin_mma_disassemble_acc(&result1.packet, acc1);
|
__builtin_mma_disassemble_acc(&result1.packet, acc1);
|
||||||
cc0.scalar[0] = result0.packet[0][0] + result0.packet[1][1];
|
result0.packet[0] = vec_add(vec_mergeh(result0.packet[0], result1.packet[0]), vec_mergel(result0.packet[1], result1.packet[1]));
|
||||||
cc0.scalar[1] = result1.packet[0][0] + result1.packet[1][1];
|
return *reinterpret_cast<ScalarBlock<double, 2> *>(&result0.packet[0]);
|
||||||
return cc0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/** \internal add complex results together */
|
/** \internal add complex results together */
|
||||||
@ -1766,17 +1759,17 @@ EIGEN_ALWAYS_INLINE ScalarBlock<std::complex<float>, 2> addComplexResults(Packet
|
|||||||
result0.packet[3] = reinterpret_cast<Packet4f>(vec_mergel(reinterpret_cast<Packet2d>(result0.packet[3]), reinterpret_cast<Packet2d>(result1.packet[3])));
|
result0.packet[3] = reinterpret_cast<Packet4f>(vec_mergel(reinterpret_cast<Packet2d>(result0.packet[3]), reinterpret_cast<Packet2d>(result1.packet[3])));
|
||||||
result0.packet[1] = vec_add(result0.packet[1], result0.packet[3]);
|
result0.packet[1] = vec_add(result0.packet[1], result0.packet[3]);
|
||||||
if (ConjugateLhs) {
|
if (ConjugateLhs) {
|
||||||
result0.packet[0] = convertReal(pconj2(convertComplex(result0.packet[0])));
|
result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
|
||||||
result0.packet[1] = convertReal(pcplxflip2(convertComplex(result0.packet[1])));
|
result0.packet[1] = pcplxflip2(convertComplex(result0.packet[1])).v;
|
||||||
} else if (ConjugateRhs) {
|
} else if (ConjugateRhs) {
|
||||||
result0.packet[1] = convertReal(pcplxconjflip(convertComplex(result0.packet[1])));
|
result0.packet[1] = pcplxconjflip(convertComplex(result0.packet[1])).v;
|
||||||
} else {
|
} else {
|
||||||
result0.packet[1] = convertReal(pcplxflipconj(convertComplex(result0.packet[1])));
|
result0.packet[1] = pcplxflipconj(convertComplex(result0.packet[1])).v;
|
||||||
}
|
}
|
||||||
result0.packet[0] = vec_add(result0.packet[0], result0.packet[1]);
|
result0.packet[0] = vec_add(result0.packet[0], result0.packet[1]);
|
||||||
} else {
|
} else {
|
||||||
if (ConjugateLhs && (sizeof(LhsPacket) == sizeof(std::complex<float>))) {
|
if (ConjugateLhs && (sizeof(LhsPacket) == sizeof(std::complex<float>))) {
|
||||||
result0.packet[0] = convertReal(pconj2(convertComplex(result0.packet[0])));
|
result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cc0.scalar[0].real(result0.packet[0][0]);
|
cc0.scalar[0].real(result0.packet[0][0]);
|
||||||
@ -1807,12 +1800,10 @@ EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_complex(__vector_quad* acc0
|
|||||||
template<typename ResScalar, typename ResPacket>
|
template<typename ResScalar, typename ResPacket>
|
||||||
EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_real(__vector_quad* acc0)
|
EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_real(__vector_quad* acc0)
|
||||||
{
|
{
|
||||||
ScalarBlock<ResScalar, 2> cc0;
|
|
||||||
PacketBlock<ResPacket, 4> result0;
|
PacketBlock<ResPacket, 4> result0;
|
||||||
__builtin_mma_disassemble_acc(&result0.packet, acc0);
|
__builtin_mma_disassemble_acc(&result0.packet, acc0);
|
||||||
cc0.scalar[0] = result0.packet[0][0] + result0.packet[1][1];
|
result0.packet[0] = vec_add(vec_mergeh(result0.packet[0], result0.packet[2]), vec_mergel(result0.packet[1], result0.packet[3]));
|
||||||
cc0.scalar[1] = result0.packet[2][0] + result0.packet[3][1];
|
return *reinterpret_cast<ScalarBlock<ResScalar, 2> *>(&result0.packet[0]);
|
||||||
return cc0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename ResScalar, typename ResPacket, typename LhsPacket, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs>
|
template<typename ResScalar, typename ResPacket, typename LhsPacket, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs>
|
||||||
@ -1823,25 +1814,25 @@ EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_complex(__vector_quad* acc0
|
|||||||
__builtin_mma_disassemble_acc(&result0.packet, acc0);
|
__builtin_mma_disassemble_acc(&result0.packet, acc0);
|
||||||
if (GEMV_IS_COMPLEX_COMPLEX) {
|
if (GEMV_IS_COMPLEX_COMPLEX) {
|
||||||
if (ConjugateLhs) {
|
if (ConjugateLhs) {
|
||||||
result0.packet[1] = convertReal(pconjinv(convertComplex(result0.packet[1])));
|
result0.packet[1] = pconjinv(convertComplex(result0.packet[1])).v;
|
||||||
result0.packet[3] = convertReal(pconjinv(convertComplex(result0.packet[3])));
|
result0.packet[3] = pconjinv(convertComplex(result0.packet[3])).v;
|
||||||
} else if (ConjugateRhs) {
|
} else if (ConjugateRhs) {
|
||||||
result0.packet[0] = convertReal(pconj2(convertComplex(result0.packet[0])));
|
result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
|
||||||
result0.packet[2] = convertReal(pconj2(convertComplex(result0.packet[2])));
|
result0.packet[2] = pconj2(convertComplex(result0.packet[2])).v;
|
||||||
} else {
|
} else {
|
||||||
result0.packet[1] = convertReal(pconj2(convertComplex(result0.packet[1])));
|
result0.packet[1] = pconj2(convertComplex(result0.packet[1])).v;
|
||||||
result0.packet[3] = convertReal(pconj2(convertComplex(result0.packet[3])));
|
result0.packet[3] = pconj2(convertComplex(result0.packet[3])).v;
|
||||||
}
|
}
|
||||||
cc0.scalar[0].real(result0.packet[0][0] + result0.packet[1][1]);
|
result0.packet[0] = vec_add(result0.packet[0], __builtin_vsx_xxpermdi(result0.packet[1], result0.packet[1], 2));
|
||||||
cc0.scalar[0].imag(result0.packet[0][1] + result0.packet[1][0]);
|
result0.packet[2] = vec_add(result0.packet[2], __builtin_vsx_xxpermdi(result0.packet[3], result0.packet[3], 2));
|
||||||
cc0.scalar[1].real(result0.packet[2][0] + result0.packet[3][1]);
|
|
||||||
cc0.scalar[1].imag(result0.packet[2][1] + result0.packet[3][0]);
|
|
||||||
} else {
|
} else {
|
||||||
|
result0.packet[0] = __builtin_vsx_xxpermdi(result0.packet[0], result0.packet[1], 1);
|
||||||
|
result0.packet[2] = __builtin_vsx_xxpermdi(result0.packet[2], result0.packet[3], 1);
|
||||||
|
}
|
||||||
cc0.scalar[0].real(result0.packet[0][0]);
|
cc0.scalar[0].real(result0.packet[0][0]);
|
||||||
cc0.scalar[0].imag(result0.packet[1][1]);
|
cc0.scalar[0].imag(result0.packet[0][1]);
|
||||||
cc0.scalar[1].real(result0.packet[2][0]);
|
cc0.scalar[1].real(result0.packet[2][0]);
|
||||||
cc0.scalar[1].imag(result0.packet[3][1]);
|
cc0.scalar[1].imag(result0.packet[2][1]);
|
||||||
}
|
|
||||||
return cc0;
|
return cc0;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
@ -1957,7 +1948,7 @@ EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_complex(ResPacket& a, ResPa
|
|||||||
GEMV_UNROLL_ROW_HALF(GEMV_STORE_ROW, (N >> 1)) \
|
GEMV_UNROLL_ROW_HALF(GEMV_STORE_ROW, (N >> 1)) \
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename Index, typename LhsScalar, typename LhsMapper, typename RhsScalar, typename RhsMapper, typename ResScalar>
|
template<typename LhsScalar, typename LhsMapper, typename RhsScalar, typename RhsMapper, typename ResScalar>
|
||||||
EIGEN_STRONG_INLINE void gemv_row(
|
EIGEN_STRONG_INLINE void gemv_row(
|
||||||
Index rows, Index cols,
|
Index rows, Index cols,
|
||||||
const LhsMapper& alhs,
|
const LhsMapper& alhs,
|
||||||
@ -2040,7 +2031,7 @@ struct general_matrix_vector_product<Index, Scalar, LhsMapper, ColMajor, Conjuga
|
|||||||
const RhsMapper& rhs, \
|
const RhsMapper& rhs, \
|
||||||
ResScalar* res, Index resIncr, \
|
ResScalar* res, Index resIncr, \
|
||||||
ResScalar alpha) { \
|
ResScalar alpha) { \
|
||||||
gemv_col<Index, Scalar, LhsMapper, Scalar, RhsMapper, ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
|
gemv_col<Scalar, LhsMapper, Scalar, RhsMapper, ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
|
||||||
} \
|
} \
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -2056,7 +2047,7 @@ struct general_matrix_vector_product<Index, Scalar, LhsMapper, RowMajor, Conjuga
|
|||||||
const RhsMapper& rhs, \
|
const RhsMapper& rhs, \
|
||||||
ResScalar* res, Index resIncr, \
|
ResScalar* res, Index resIncr, \
|
||||||
ResScalar alpha) { \
|
ResScalar alpha) { \
|
||||||
gemv_row<Index, Scalar, LhsMapper, Scalar, RhsMapper, ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
|
gemv_row<Scalar, LhsMapper, Scalar, RhsMapper, ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
|
||||||
} \
|
} \
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -2076,7 +2067,7 @@ EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_complex(PResPacket& a0, PRe
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define GEMV_LOADPACKET_ROW_COMPLEX(iter) \
|
#define GEMV_LOADPACKET_ROW_COMPLEX(iter) \
|
||||||
loadLhsPacket<Scalar, LhsScalar, LhsMapper, PLhsPacket, Index>(lhs, i + (iter), j)
|
loadLhsPacket<Scalar, LhsScalar, LhsMapper, PLhsPacket>(lhs, i + (iter), j)
|
||||||
|
|
||||||
#define GEMV_LOADPACKET_ROW_COMPLEX_DATA(iter) \
|
#define GEMV_LOADPACKET_ROW_COMPLEX_DATA(iter) \
|
||||||
convertReal(GEMV_LOADPACKET_ROW_COMPLEX(iter))
|
convertReal(GEMV_LOADPACKET_ROW_COMPLEX(iter))
|
||||||
@ -2276,7 +2267,7 @@ EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_complex(PResPacket& a0, PRe
|
|||||||
GEMV_PROCESS_ROW_COMPLEX_ONE(N)
|
GEMV_PROCESS_ROW_COMPLEX_ONE(N)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
template<typename Index, typename Scalar, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, bool LhsIsReal, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, bool RhsIsReal, typename ResScalar>
|
template<typename Scalar, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, bool LhsIsReal, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, bool RhsIsReal, typename ResScalar>
|
||||||
EIGEN_STRONG_INLINE void gemv_complex_row(
|
EIGEN_STRONG_INLINE void gemv_complex_row(
|
||||||
Index rows, Index cols,
|
Index rows, Index cols,
|
||||||
const LhsMapper& alhs,
|
const LhsMapper& alhs,
|
||||||
@ -2367,7 +2358,7 @@ struct general_matrix_vector_product<Index, LhsScalar, LhsMapper, ColMajor, Conj
|
|||||||
const RhsMapper& rhs, \
|
const RhsMapper& rhs, \
|
||||||
ResScalar* res, Index resIncr, \
|
ResScalar* res, Index resIncr, \
|
||||||
ResScalar alpha) { \
|
ResScalar alpha) { \
|
||||||
gemv_complex_col<Index, Scalar, LhsScalar, LhsMapper, ConjugateLhs, sizeof(Scalar) == sizeof(LhsScalar), RhsScalar, RhsMapper, ConjugateRhs, sizeof(Scalar) == sizeof(RhsScalar), ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
|
gemv_complex_col<Scalar, LhsScalar, LhsMapper, ConjugateLhs, sizeof(Scalar) == sizeof(LhsScalar), RhsScalar, RhsMapper, ConjugateRhs, sizeof(Scalar) == sizeof(RhsScalar), ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
|
||||||
} \
|
} \
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -2383,7 +2374,7 @@ struct general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, Conj
|
|||||||
const RhsMapper& rhs, \
|
const RhsMapper& rhs, \
|
||||||
ResScalar* res, Index resIncr, \
|
ResScalar* res, Index resIncr, \
|
||||||
ResScalar alpha) { \
|
ResScalar alpha) { \
|
||||||
gemv_complex_row<Index, Scalar, LhsScalar, LhsMapper, ConjugateLhs, sizeof(Scalar) == sizeof(LhsScalar), RhsScalar, RhsMapper, ConjugateRhs, sizeof(Scalar) == sizeof(RhsScalar), ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
|
gemv_complex_row<Scalar, LhsScalar, LhsMapper, ConjugateLhs, sizeof(Scalar) == sizeof(LhsScalar), RhsScalar, RhsMapper, ConjugateRhs, sizeof(Scalar) == sizeof(RhsScalar), ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
|
||||||
} \
|
} \
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user