mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-09-12 09:23:12 +08:00
avoid dynamic allocation for fixed size triangular solving
This commit is contained in:
parent
bc580bbffb
commit
924c7a9300
@ -100,12 +100,22 @@ struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,Dynamic>
|
|||||||
typedef typename Rhs::Index Index;
|
typedef typename Rhs::Index Index;
|
||||||
typedef blas_traits<Lhs> LhsProductTraits;
|
typedef blas_traits<Lhs> LhsProductTraits;
|
||||||
typedef typename LhsProductTraits::DirectLinearAccessType ActualLhsType;
|
typedef typename LhsProductTraits::DirectLinearAccessType ActualLhsType;
|
||||||
|
|
||||||
static void run(const Lhs& lhs, Rhs& rhs)
|
static void run(const Lhs& lhs, Rhs& rhs)
|
||||||
{
|
{
|
||||||
typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsProductTraits::extract(lhs);
|
typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsProductTraits::extract(lhs);
|
||||||
|
|
||||||
|
const Index size = lhs.rows();
|
||||||
|
const Index othersize = Side==OnTheLeft? rhs.cols() : rhs.rows();
|
||||||
|
|
||||||
|
typedef internal::gemm_blocking_space<(Rhs::Flags&RowMajorBit) ? RowMajor : ColMajor,Scalar,Scalar,
|
||||||
|
Rhs::MaxRowsAtCompileTime, Rhs::MaxColsAtCompileTime, Lhs::MaxRowsAtCompileTime,4> BlockingType;
|
||||||
|
|
||||||
|
BlockingType blocking(rhs.rows(), rhs.cols(), size);
|
||||||
|
|
||||||
triangular_solve_matrix<Scalar,Index,Side,Mode,LhsProductTraits::NeedToConjugate,(int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor,
|
triangular_solve_matrix<Scalar,Index,Side,Mode,LhsProductTraits::NeedToConjugate,(int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor,
|
||||||
(Rhs::Flags&RowMajorBit) ? RowMajor : ColMajor>
|
(Rhs::Flags&RowMajorBit) ? RowMajor : ColMajor>
|
||||||
::run(lhs.rows(), Side==OnTheLeft? rhs.cols() : rhs.rows(), &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &rhs.coeffRef(0,0), rhs.outerStride());
|
::run(size, othersize, &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &rhs.coeffRef(0,0), rhs.outerStride(), blocking);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -79,7 +79,7 @@ static void run(Index rows, Index cols, Index depth,
|
|||||||
|
|
||||||
typedef gebp_traits<LhsScalar,RhsScalar> Traits;
|
typedef gebp_traits<LhsScalar,RhsScalar> Traits;
|
||||||
|
|
||||||
Index kc = blocking.kc(); // cache block size along the K direction
|
Index kc = blocking.kc(); // cache block size along the K direction
|
||||||
Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction
|
Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction
|
||||||
//Index nc = blocking.nc(); // cache block size along the N direction
|
//Index nc = blocking.nc(); // cache block size along the N direction
|
||||||
|
|
||||||
@ -249,7 +249,7 @@ struct gemm_functor
|
|||||||
BlockingType& m_blocking;
|
BlockingType& m_blocking;
|
||||||
};
|
};
|
||||||
|
|
||||||
template<int StorageOrder, typename LhsScalar, typename RhsScalar, int MaxRows, int MaxCols, int MaxDepth,
|
template<int StorageOrder, typename LhsScalar, typename RhsScalar, int MaxRows, int MaxCols, int MaxDepth, int KcFactor=1,
|
||||||
bool FiniteAtCompileTime = MaxRows!=Dynamic && MaxCols!=Dynamic && MaxDepth != Dynamic> class gemm_blocking_space;
|
bool FiniteAtCompileTime = MaxRows!=Dynamic && MaxCols!=Dynamic && MaxDepth != Dynamic> class gemm_blocking_space;
|
||||||
|
|
||||||
template<typename _LhsScalar, typename _RhsScalar>
|
template<typename _LhsScalar, typename _RhsScalar>
|
||||||
@ -282,8 +282,8 @@ class level3_blocking
|
|||||||
inline RhsScalar* blockW() { return m_blockW; }
|
inline RhsScalar* blockW() { return m_blockW; }
|
||||||
};
|
};
|
||||||
|
|
||||||
template<int StorageOrder, typename _LhsScalar, typename _RhsScalar, int MaxRows, int MaxCols, int MaxDepth>
|
template<int StorageOrder, typename _LhsScalar, typename _RhsScalar, int MaxRows, int MaxCols, int MaxDepth, int KcFactor>
|
||||||
class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, MaxDepth, true>
|
class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, MaxDepth, KcFactor, true>
|
||||||
: public level3_blocking<
|
: public level3_blocking<
|
||||||
typename conditional<StorageOrder==RowMajor,_RhsScalar,_LhsScalar>::type,
|
typename conditional<StorageOrder==RowMajor,_RhsScalar,_LhsScalar>::type,
|
||||||
typename conditional<StorageOrder==RowMajor,_LhsScalar,_RhsScalar>::type>
|
typename conditional<StorageOrder==RowMajor,_LhsScalar,_RhsScalar>::type>
|
||||||
@ -324,8 +324,8 @@ class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, M
|
|||||||
inline void allocateAll() {}
|
inline void allocateAll() {}
|
||||||
};
|
};
|
||||||
|
|
||||||
template<int StorageOrder, typename _LhsScalar, typename _RhsScalar, int MaxRows, int MaxCols, int MaxDepth>
|
template<int StorageOrder, typename _LhsScalar, typename _RhsScalar, int MaxRows, int MaxCols, int MaxDepth, int KcFactor>
|
||||||
class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, MaxDepth, false>
|
class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, MaxDepth, KcFactor, false>
|
||||||
: public level3_blocking<
|
: public level3_blocking<
|
||||||
typename conditional<StorageOrder==RowMajor,_RhsScalar,_LhsScalar>::type,
|
typename conditional<StorageOrder==RowMajor,_RhsScalar,_LhsScalar>::type,
|
||||||
typename conditional<StorageOrder==RowMajor,_LhsScalar,_RhsScalar>::type>
|
typename conditional<StorageOrder==RowMajor,_LhsScalar,_RhsScalar>::type>
|
||||||
@ -349,7 +349,7 @@ class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, M
|
|||||||
this->m_nc = Transpose ? rows : cols;
|
this->m_nc = Transpose ? rows : cols;
|
||||||
this->m_kc = depth;
|
this->m_kc = depth;
|
||||||
|
|
||||||
computeProductBlockingSizes<LhsScalar,RhsScalar>(this->m_kc, this->m_mc, this->m_nc);
|
computeProductBlockingSizes<LhsScalar,RhsScalar,KcFactor>(this->m_kc, this->m_mc, this->m_nc);
|
||||||
m_sizeA = this->m_mc * this->m_kc;
|
m_sizeA = this->m_mc * this->m_kc;
|
||||||
m_sizeB = this->m_kc * this->m_nc;
|
m_sizeB = this->m_kc * this->m_nc;
|
||||||
m_sizeW = this->m_kc*Traits::WorkSpaceFactor;
|
m_sizeW = this->m_kc*Traits::WorkSpaceFactor;
|
||||||
|
@ -36,14 +36,15 @@ struct triangular_solve_matrix<Scalar,Index,Side,Mode,Conjugate,TriStorageOrder,
|
|||||||
static EIGEN_DONT_INLINE void run(
|
static EIGEN_DONT_INLINE void run(
|
||||||
Index size, Index cols,
|
Index size, Index cols,
|
||||||
const Scalar* tri, Index triStride,
|
const Scalar* tri, Index triStride,
|
||||||
Scalar* _other, Index otherStride)
|
Scalar* _other, Index otherStride,
|
||||||
|
level3_blocking<Scalar,Scalar>& blocking)
|
||||||
{
|
{
|
||||||
triangular_solve_matrix<
|
triangular_solve_matrix<
|
||||||
Scalar, Index, Side==OnTheLeft?OnTheRight:OnTheLeft,
|
Scalar, Index, Side==OnTheLeft?OnTheRight:OnTheLeft,
|
||||||
(Mode&UnitDiag) | ((Mode&Upper) ? Lower : Upper),
|
(Mode&UnitDiag) | ((Mode&Upper) ? Lower : Upper),
|
||||||
NumTraits<Scalar>::IsComplex && Conjugate,
|
NumTraits<Scalar>::IsComplex && Conjugate,
|
||||||
TriStorageOrder==RowMajor ? ColMajor : RowMajor, ColMajor>
|
TriStorageOrder==RowMajor ? ColMajor : RowMajor, ColMajor>
|
||||||
::run(size, cols, tri, triStride, _other, otherStride);
|
::run(size, cols, tri, triStride, _other, otherStride, blocking);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -55,7 +56,8 @@ struct triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageO
|
|||||||
static EIGEN_DONT_INLINE void run(
|
static EIGEN_DONT_INLINE void run(
|
||||||
Index size, Index otherSize,
|
Index size, Index otherSize,
|
||||||
const Scalar* _tri, Index triStride,
|
const Scalar* _tri, Index triStride,
|
||||||
Scalar* _other, Index otherStride)
|
Scalar* _other, Index otherStride,
|
||||||
|
level3_blocking<Scalar,Scalar>& blocking)
|
||||||
{
|
{
|
||||||
Index cols = otherSize;
|
Index cols = otherSize;
|
||||||
const_blas_data_mapper<Scalar, Index, TriStorageOrder> tri(_tri,triStride);
|
const_blas_data_mapper<Scalar, Index, TriStorageOrder> tri(_tri,triStride);
|
||||||
@ -67,17 +69,16 @@ struct triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageO
|
|||||||
IsLower = (Mode&Lower) == Lower
|
IsLower = (Mode&Lower) == Lower
|
||||||
};
|
};
|
||||||
|
|
||||||
Index kc = size; // cache block size along the K direction
|
Index kc = blocking.kc(); // cache block size along the K direction
|
||||||
Index mc = size; // cache block size along the M direction
|
Index mc = (std::min)(size,blocking.mc()); // cache block size along the M direction
|
||||||
Index nc = cols; // cache block size along the N direction
|
|
||||||
computeProductBlockingSizes<Scalar,Scalar,4>(kc, mc, nc);
|
|
||||||
|
|
||||||
|
std::size_t sizeA = kc*mc;
|
||||||
|
std::size_t sizeB = kc*cols;
|
||||||
std::size_t sizeW = kc*Traits::WorkSpaceFactor;
|
std::size_t sizeW = kc*Traits::WorkSpaceFactor;
|
||||||
std::size_t sizeB = sizeW + kc*cols;
|
|
||||||
ei_declare_aligned_stack_constructed_variable(Scalar, blockA, kc*mc, 0);
|
ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
|
||||||
ei_declare_aligned_stack_constructed_variable(Scalar, allocatedBlockB, sizeB, 0);
|
ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());
|
||||||
Scalar* blockB = allocatedBlockB + sizeW;
|
ei_declare_aligned_stack_constructed_variable(Scalar, blockW, sizeW, blocking.blockW());
|
||||||
Scalar* blockW = allocatedBlockB;
|
|
||||||
|
|
||||||
conj_if<Conjugate> conj;
|
conj_if<Conjugate> conj;
|
||||||
gebp_kernel<Scalar, Scalar, Index, Traits::mr, Traits::nr, Conjugate, false> gebp_kernel;
|
gebp_kernel<Scalar, Scalar, Index, Traits::mr, Traits::nr, Conjugate, false> gebp_kernel;
|
||||||
@ -181,7 +182,7 @@ struct triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageO
|
|||||||
{
|
{
|
||||||
pack_lhs(blockA, &tri(i2, IsLower ? k2 : k2-kc), triStride, actual_kc, actual_mc);
|
pack_lhs(blockA, &tri(i2, IsLower ? k2 : k2-kc), triStride, actual_kc, actual_mc);
|
||||||
|
|
||||||
gebp_kernel(_other+i2, otherStride, blockA, blockB, actual_mc, actual_kc, cols, Scalar(-1));
|
gebp_kernel(_other+i2, otherStride, blockA, blockB, actual_mc, actual_kc, cols, Scalar(-1), -1, -1, 0, 0, blockW);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -197,7 +198,8 @@ struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorage
|
|||||||
static EIGEN_DONT_INLINE void run(
|
static EIGEN_DONT_INLINE void run(
|
||||||
Index size, Index otherSize,
|
Index size, Index otherSize,
|
||||||
const Scalar* _tri, Index triStride,
|
const Scalar* _tri, Index triStride,
|
||||||
Scalar* _other, Index otherStride)
|
Scalar* _other, Index otherStride,
|
||||||
|
level3_blocking<Scalar,Scalar>& blocking)
|
||||||
{
|
{
|
||||||
Index rows = otherSize;
|
Index rows = otherSize;
|
||||||
const_blas_data_mapper<Scalar, Index, TriStorageOrder> rhs(_tri,triStride);
|
const_blas_data_mapper<Scalar, Index, TriStorageOrder> rhs(_tri,triStride);
|
||||||
@ -210,19 +212,16 @@ struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorage
|
|||||||
IsLower = (Mode&Lower) == Lower
|
IsLower = (Mode&Lower) == Lower
|
||||||
};
|
};
|
||||||
|
|
||||||
// Index kc = std::min<Index>(Traits::Max_kc/4,size); // cache block size along the K direction
|
Index kc = blocking.kc(); // cache block size along the K direction
|
||||||
// Index mc = std::min<Index>(Traits::Max_mc,size); // cache block size along the M direction
|
Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction
|
||||||
// check that !!!!
|
|
||||||
Index kc = size; // cache block size along the K direction
|
|
||||||
Index mc = size; // cache block size along the M direction
|
|
||||||
Index nc = rows; // cache block size along the N direction
|
|
||||||
computeProductBlockingSizes<Scalar,Scalar,4>(kc, mc, nc);
|
|
||||||
|
|
||||||
|
std::size_t sizeA = kc*mc;
|
||||||
|
std::size_t sizeB = kc*size;
|
||||||
std::size_t sizeW = kc*Traits::WorkSpaceFactor;
|
std::size_t sizeW = kc*Traits::WorkSpaceFactor;
|
||||||
std::size_t sizeB = sizeW + kc*size;
|
|
||||||
ei_declare_aligned_stack_constructed_variable(Scalar, blockA, kc*mc, 0);
|
ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
|
||||||
ei_declare_aligned_stack_constructed_variable(Scalar, allocatedBlockB, sizeB, 0);
|
ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());
|
||||||
Scalar* blockB = allocatedBlockB + sizeW;
|
ei_declare_aligned_stack_constructed_variable(Scalar, blockW, sizeW, blocking.blockW());
|
||||||
|
|
||||||
conj_if<Conjugate> conj;
|
conj_if<Conjugate> conj;
|
||||||
gebp_kernel<Scalar,Scalar, Index, Traits::mr, Traits::nr, false, Conjugate> gebp_kernel;
|
gebp_kernel<Scalar,Scalar, Index, Traits::mr, Traits::nr, false, Conjugate> gebp_kernel;
|
||||||
@ -289,7 +288,7 @@ struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorage
|
|||||||
Scalar(-1),
|
Scalar(-1),
|
||||||
actual_kc, actual_kc, // strides
|
actual_kc, actual_kc, // strides
|
||||||
panelOffset, panelOffset, // offsets
|
panelOffset, panelOffset, // offsets
|
||||||
allocatedBlockB); // workspace
|
blockW); // workspace
|
||||||
}
|
}
|
||||||
|
|
||||||
// unblocked triangular solve
|
// unblocked triangular solve
|
||||||
@ -320,7 +319,7 @@ struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorage
|
|||||||
if (rs>0)
|
if (rs>0)
|
||||||
gebp_kernel(_other+i2+startPanel*otherStride, otherStride, blockA, geb,
|
gebp_kernel(_other+i2+startPanel*otherStride, otherStride, blockA, geb,
|
||||||
actual_mc, actual_kc, rs, Scalar(-1),
|
actual_mc, actual_kc, rs, Scalar(-1),
|
||||||
-1, -1, 0, 0, allocatedBlockB);
|
-1, -1, 0, 0, blockW);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user