mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-11 03:09:01 +08:00
trsv: simplifications/cleaning
This commit is contained in:
parent
0e6c1170ab
commit
3fdea699b8
@ -29,7 +29,7 @@ namespace internal {
|
||||
|
||||
// Forward declarations:
|
||||
// The following two routines are implemented in the products/TriangularSolver*.h files
|
||||
template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate, int StorageOrder>
|
||||
template<typename LhsScalar, typename RhsScalar, typename Index, int Side, int Mode, bool Conjugate, int StorageOrder>
|
||||
struct triangular_solve_vector;
|
||||
|
||||
template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder>
|
||||
@ -55,13 +55,12 @@ template<typename Lhs, typename Rhs,
|
||||
int Side, // can be OnTheLeft/OnTheRight
|
||||
int Mode, // can be Upper/Lower | UnitDiag
|
||||
int Unrolling = trsolve_traits<Lhs,Rhs,Side>::Unrolling,
|
||||
int StorageOrder = (int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor,
|
||||
int RhsVectors = trsolve_traits<Lhs,Rhs,Side>::RhsVectors
|
||||
>
|
||||
struct triangular_solver_selector;
|
||||
|
||||
template<typename Lhs, typename Rhs, int Mode, int StorageOrder>
|
||||
struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,StorageOrder,1>
|
||||
template<typename Lhs, typename Rhs, int Side, int Mode>
|
||||
struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,1>
|
||||
{
|
||||
typedef typename Lhs::Scalar LhsScalar;
|
||||
typedef typename Rhs::Scalar RhsScalar;
|
||||
@ -86,8 +85,8 @@ struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,StorageOrde
|
||||
MappedRhs(actualRhs,rhs.size()) = rhs;
|
||||
}
|
||||
|
||||
|
||||
triangular_solve_vector<LhsScalar, RhsScalar, typename Lhs::Index, Mode, LhsProductTraits::NeedToConjugate, StorageOrder>
|
||||
triangular_solve_vector<LhsScalar, RhsScalar, typename Lhs::Index, Side, Mode, LhsProductTraits::NeedToConjugate,
|
||||
(int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor>
|
||||
::run(actualLhs.cols(), actualLhs.data(), actualLhs.outerStride(), actualRhs);
|
||||
|
||||
if(!useRhsDirectly)
|
||||
@ -98,22 +97,9 @@ struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,StorageOrde
|
||||
}
|
||||
};
|
||||
|
||||
// transpose OnTheRight cases for vectors
|
||||
template<typename Lhs, typename Rhs, int Mode, int Unrolling, int StorageOrder>
|
||||
struct triangular_solver_selector<Lhs,Rhs,OnTheRight,Mode,Unrolling,StorageOrder,1>
|
||||
{
|
||||
static void run(const Lhs& lhs, Rhs& rhs)
|
||||
{
|
||||
Transpose<Rhs> rhsTr(rhs);
|
||||
Transpose<Lhs> lhsTr(lhs);
|
||||
triangular_solver_selector<Transpose<Lhs>,Transpose<Rhs>,OnTheLeft,TriangularView<Lhs,Mode>::TransposeMode>::run(lhsTr,rhsTr);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// the rhs is a matrix
|
||||
template<typename Lhs, typename Rhs, int Side, int Mode, int StorageOrder>
|
||||
struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,StorageOrder,Dynamic>
|
||||
template<typename Lhs, typename Rhs, int Side, int Mode>
|
||||
struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,Dynamic>
|
||||
{
|
||||
typedef typename Rhs::Scalar Scalar;
|
||||
typedef typename Rhs::Index Index;
|
||||
@ -122,7 +108,7 @@ struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,StorageOrder,Dyn
|
||||
static void run(const Lhs& lhs, Rhs& rhs)
|
||||
{
|
||||
const ActualLhsType actualLhs = LhsProductTraits::extract(lhs);
|
||||
triangular_solve_matrix<Scalar,Index,Side,Mode,LhsProductTraits::NeedToConjugate,StorageOrder,
|
||||
triangular_solve_matrix<Scalar,Index,Side,Mode,LhsProductTraits::NeedToConjugate,(int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor,
|
||||
(Rhs::Flags&RowMajorBit) ? RowMajor : ColMajor>
|
||||
::run(lhs.rows(), Side==OnTheLeft? rhs.cols() : rhs.rows(), &actualLhs.coeff(0,0), actualLhs.outerStride(), &rhs.coeffRef(0,0), rhs.outerStride());
|
||||
}
|
||||
@ -146,7 +132,8 @@ struct triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,false> {
|
||||
static void run(const Lhs& lhs, Rhs& rhs)
|
||||
{
|
||||
if (Index>0)
|
||||
rhs.coeffRef(I) -= lhs.row(I).template segment<Index>(S).transpose().cwiseProduct(rhs.template segment<Index>(S)).sum();
|
||||
rhs.coeffRef(I) -= lhs.row(I).template segment<Index>(S).transpose()
|
||||
.cwiseProduct(rhs.template segment<Index>(S)).sum();
|
||||
|
||||
if(!(Mode & UnitDiag))
|
||||
rhs.coeffRef(I) /= lhs.coeff(I,I);
|
||||
@ -160,8 +147,8 @@ struct triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,true> {
|
||||
static void run(const Lhs&, Rhs&) {}
|
||||
};
|
||||
|
||||
template<typename Lhs, typename Rhs, int Mode, int StorageOrder>
|
||||
struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,CompleteUnrolling,StorageOrder,1> {
|
||||
template<typename Lhs, typename Rhs, int Mode>
|
||||
struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,CompleteUnrolling,1> {
|
||||
static void run(const Lhs& lhs, Rhs& rhs)
|
||||
{ triangular_solver_unroller<Lhs,Rhs,Mode,0,Rhs::SizeAtCompileTime>::run(lhs,rhs); }
|
||||
};
|
||||
|
@ -27,9 +27,21 @@
|
||||
|
||||
namespace internal {
|
||||
|
||||
template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate, int StorageOrder>
|
||||
struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheRight, Mode, Conjugate, StorageOrder>
|
||||
{
|
||||
static void run(int size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs)
|
||||
{
|
||||
triangular_solve_vector<LhsScalar,RhsScalar,Index,OnTheLeft,
|
||||
((Mode&Upper)==Upper ? Lower : Upper) | (Mode&UnitDiag),
|
||||
Conjugate,StorageOrder==RowMajor?ColMajor:RowMajor
|
||||
>::run(size, _lhs, lhsStride, rhs);
|
||||
}
|
||||
};
|
||||
|
||||
// forward and backward substitution, row-major, rhs is a vector
|
||||
template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate>
|
||||
struct triangular_solve_vector<LhsScalar, RhsScalar, Index, Mode, Conjugate, RowMajor>
|
||||
struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheLeft, Mode, Conjugate, RowMajor>
|
||||
{
|
||||
enum {
|
||||
IsLower = ((Mode&Lower)==Lower)
|
||||
@ -83,7 +95,7 @@ struct triangular_solve_vector<LhsScalar, RhsScalar, Index, Mode, Conjugate, Row
|
||||
|
||||
// forward and backward substitution, column-major, rhs is a vector
|
||||
template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate>
|
||||
struct triangular_solve_vector<LhsScalar, RhsScalar, Index, Mode, Conjugate, ColMajor>
|
||||
struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheLeft, Mode, Conjugate, ColMajor>
|
||||
{
|
||||
enum {
|
||||
IsLower = ((Mode&Lower)==Lower)
|
||||
|
Loading…
x
Reference in New Issue
Block a user