mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-14 04:35:57 +08:00
trsv: simplifications/cleaning
This commit is contained in:
parent
0e6c1170ab
commit
3fdea699b8
@ -29,7 +29,7 @@ namespace internal {
|
|||||||
|
|
||||||
// Forward declarations:
|
// Forward declarations:
|
||||||
// The following two routines are implemented in the products/TriangularSolver*.h files
|
// The following two routines are implemented in the products/TriangularSolver*.h files
|
||||||
template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate, int StorageOrder>
|
template<typename LhsScalar, typename RhsScalar, typename Index, int Side, int Mode, bool Conjugate, int StorageOrder>
|
||||||
struct triangular_solve_vector;
|
struct triangular_solve_vector;
|
||||||
|
|
||||||
template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder>
|
template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder>
|
||||||
@ -55,13 +55,12 @@ template<typename Lhs, typename Rhs,
|
|||||||
int Side, // can be OnTheLeft/OnTheRight
|
int Side, // can be OnTheLeft/OnTheRight
|
||||||
int Mode, // can be Upper/Lower | UnitDiag
|
int Mode, // can be Upper/Lower | UnitDiag
|
||||||
int Unrolling = trsolve_traits<Lhs,Rhs,Side>::Unrolling,
|
int Unrolling = trsolve_traits<Lhs,Rhs,Side>::Unrolling,
|
||||||
int StorageOrder = (int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor,
|
|
||||||
int RhsVectors = trsolve_traits<Lhs,Rhs,Side>::RhsVectors
|
int RhsVectors = trsolve_traits<Lhs,Rhs,Side>::RhsVectors
|
||||||
>
|
>
|
||||||
struct triangular_solver_selector;
|
struct triangular_solver_selector;
|
||||||
|
|
||||||
template<typename Lhs, typename Rhs, int Mode, int StorageOrder>
|
template<typename Lhs, typename Rhs, int Side, int Mode>
|
||||||
struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,StorageOrder,1>
|
struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,1>
|
||||||
{
|
{
|
||||||
typedef typename Lhs::Scalar LhsScalar;
|
typedef typename Lhs::Scalar LhsScalar;
|
||||||
typedef typename Rhs::Scalar RhsScalar;
|
typedef typename Rhs::Scalar RhsScalar;
|
||||||
@ -86,8 +85,8 @@ struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,StorageOrde
|
|||||||
MappedRhs(actualRhs,rhs.size()) = rhs;
|
MappedRhs(actualRhs,rhs.size()) = rhs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
triangular_solve_vector<LhsScalar, RhsScalar, typename Lhs::Index, Side, Mode, LhsProductTraits::NeedToConjugate,
|
||||||
triangular_solve_vector<LhsScalar, RhsScalar, typename Lhs::Index, Mode, LhsProductTraits::NeedToConjugate, StorageOrder>
|
(int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor>
|
||||||
::run(actualLhs.cols(), actualLhs.data(), actualLhs.outerStride(), actualRhs);
|
::run(actualLhs.cols(), actualLhs.data(), actualLhs.outerStride(), actualRhs);
|
||||||
|
|
||||||
if(!useRhsDirectly)
|
if(!useRhsDirectly)
|
||||||
@ -98,22 +97,9 @@ struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,StorageOrde
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// transpose OnTheRight cases for vectors
|
|
||||||
template<typename Lhs, typename Rhs, int Mode, int Unrolling, int StorageOrder>
|
|
||||||
struct triangular_solver_selector<Lhs,Rhs,OnTheRight,Mode,Unrolling,StorageOrder,1>
|
|
||||||
{
|
|
||||||
static void run(const Lhs& lhs, Rhs& rhs)
|
|
||||||
{
|
|
||||||
Transpose<Rhs> rhsTr(rhs);
|
|
||||||
Transpose<Lhs> lhsTr(lhs);
|
|
||||||
triangular_solver_selector<Transpose<Lhs>,Transpose<Rhs>,OnTheLeft,TriangularView<Lhs,Mode>::TransposeMode>::run(lhsTr,rhsTr);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
// the rhs is a matrix
|
// the rhs is a matrix
|
||||||
template<typename Lhs, typename Rhs, int Side, int Mode, int StorageOrder>
|
template<typename Lhs, typename Rhs, int Side, int Mode>
|
||||||
struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,StorageOrder,Dynamic>
|
struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,Dynamic>
|
||||||
{
|
{
|
||||||
typedef typename Rhs::Scalar Scalar;
|
typedef typename Rhs::Scalar Scalar;
|
||||||
typedef typename Rhs::Index Index;
|
typedef typename Rhs::Index Index;
|
||||||
@ -122,7 +108,7 @@ struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,StorageOrder,Dyn
|
|||||||
static void run(const Lhs& lhs, Rhs& rhs)
|
static void run(const Lhs& lhs, Rhs& rhs)
|
||||||
{
|
{
|
||||||
const ActualLhsType actualLhs = LhsProductTraits::extract(lhs);
|
const ActualLhsType actualLhs = LhsProductTraits::extract(lhs);
|
||||||
triangular_solve_matrix<Scalar,Index,Side,Mode,LhsProductTraits::NeedToConjugate,StorageOrder,
|
triangular_solve_matrix<Scalar,Index,Side,Mode,LhsProductTraits::NeedToConjugate,(int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor,
|
||||||
(Rhs::Flags&RowMajorBit) ? RowMajor : ColMajor>
|
(Rhs::Flags&RowMajorBit) ? RowMajor : ColMajor>
|
||||||
::run(lhs.rows(), Side==OnTheLeft? rhs.cols() : rhs.rows(), &actualLhs.coeff(0,0), actualLhs.outerStride(), &rhs.coeffRef(0,0), rhs.outerStride());
|
::run(lhs.rows(), Side==OnTheLeft? rhs.cols() : rhs.rows(), &actualLhs.coeff(0,0), actualLhs.outerStride(), &rhs.coeffRef(0,0), rhs.outerStride());
|
||||||
}
|
}
|
||||||
@ -146,7 +132,8 @@ struct triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,false> {
|
|||||||
static void run(const Lhs& lhs, Rhs& rhs)
|
static void run(const Lhs& lhs, Rhs& rhs)
|
||||||
{
|
{
|
||||||
if (Index>0)
|
if (Index>0)
|
||||||
rhs.coeffRef(I) -= lhs.row(I).template segment<Index>(S).transpose().cwiseProduct(rhs.template segment<Index>(S)).sum();
|
rhs.coeffRef(I) -= lhs.row(I).template segment<Index>(S).transpose()
|
||||||
|
.cwiseProduct(rhs.template segment<Index>(S)).sum();
|
||||||
|
|
||||||
if(!(Mode & UnitDiag))
|
if(!(Mode & UnitDiag))
|
||||||
rhs.coeffRef(I) /= lhs.coeff(I,I);
|
rhs.coeffRef(I) /= lhs.coeff(I,I);
|
||||||
@ -160,8 +147,8 @@ struct triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,true> {
|
|||||||
static void run(const Lhs&, Rhs&) {}
|
static void run(const Lhs&, Rhs&) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename Lhs, typename Rhs, int Mode, int StorageOrder>
|
template<typename Lhs, typename Rhs, int Mode>
|
||||||
struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,CompleteUnrolling,StorageOrder,1> {
|
struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,CompleteUnrolling,1> {
|
||||||
static void run(const Lhs& lhs, Rhs& rhs)
|
static void run(const Lhs& lhs, Rhs& rhs)
|
||||||
{ triangular_solver_unroller<Lhs,Rhs,Mode,0,Rhs::SizeAtCompileTime>::run(lhs,rhs); }
|
{ triangular_solver_unroller<Lhs,Rhs,Mode,0,Rhs::SizeAtCompileTime>::run(lhs,rhs); }
|
||||||
};
|
};
|
||||||
|
@ -27,9 +27,21 @@
|
|||||||
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
|
||||||
|
template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate, int StorageOrder>
|
||||||
|
struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheRight, Mode, Conjugate, StorageOrder>
|
||||||
|
{
|
||||||
|
static void run(int size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs)
|
||||||
|
{
|
||||||
|
triangular_solve_vector<LhsScalar,RhsScalar,Index,OnTheLeft,
|
||||||
|
((Mode&Upper)==Upper ? Lower : Upper) | (Mode&UnitDiag),
|
||||||
|
Conjugate,StorageOrder==RowMajor?ColMajor:RowMajor
|
||||||
|
>::run(size, _lhs, lhsStride, rhs);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// forward and backward substitution, row-major, rhs is a vector
|
// forward and backward substitution, row-major, rhs is a vector
|
||||||
template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate>
|
template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate>
|
||||||
struct triangular_solve_vector<LhsScalar, RhsScalar, Index, Mode, Conjugate, RowMajor>
|
struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheLeft, Mode, Conjugate, RowMajor>
|
||||||
{
|
{
|
||||||
enum {
|
enum {
|
||||||
IsLower = ((Mode&Lower)==Lower)
|
IsLower = ((Mode&Lower)==Lower)
|
||||||
@ -83,7 +95,7 @@ struct triangular_solve_vector<LhsScalar, RhsScalar, Index, Mode, Conjugate, Row
|
|||||||
|
|
||||||
// forward and backward substitution, column-major, rhs is a vector
|
// forward and backward substitution, column-major, rhs is a vector
|
||||||
template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate>
|
template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate>
|
||||||
struct triangular_solve_vector<LhsScalar, RhsScalar, Index, Mode, Conjugate, ColMajor>
|
struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheLeft, Mode, Conjugate, ColMajor>
|
||||||
{
|
{
|
||||||
enum {
|
enum {
|
||||||
IsLower = ((Mode&Lower)==Lower)
|
IsLower = ((Mode&Lower)==Lower)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user