trsv: simplifications/cleaning

This commit is contained in:
Gael Guennebaud 2010-11-05 12:54:32 +01:00
parent 0e6c1170ab
commit 3fdea699b8
2 changed files with 26 additions and 27 deletions

View File

@ -29,7 +29,7 @@ namespace internal {
// Forward declarations: // Forward declarations:
// The following two routines are implemented in the products/TriangularSolver*.h files // The following two routines are implemented in the products/TriangularSolver*.h files
template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate, int StorageOrder> template<typename LhsScalar, typename RhsScalar, typename Index, int Side, int Mode, bool Conjugate, int StorageOrder>
struct triangular_solve_vector; struct triangular_solve_vector;
template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder> template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder>
@ -55,13 +55,12 @@ template<typename Lhs, typename Rhs,
int Side, // can be OnTheLeft/OnTheRight int Side, // can be OnTheLeft/OnTheRight
int Mode, // can be Upper/Lower | UnitDiag int Mode, // can be Upper/Lower | UnitDiag
int Unrolling = trsolve_traits<Lhs,Rhs,Side>::Unrolling, int Unrolling = trsolve_traits<Lhs,Rhs,Side>::Unrolling,
int StorageOrder = (int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor,
int RhsVectors = trsolve_traits<Lhs,Rhs,Side>::RhsVectors int RhsVectors = trsolve_traits<Lhs,Rhs,Side>::RhsVectors
> >
struct triangular_solver_selector; struct triangular_solver_selector;
template<typename Lhs, typename Rhs, int Mode, int StorageOrder> template<typename Lhs, typename Rhs, int Side, int Mode>
struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,StorageOrder,1> struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,1>
{ {
typedef typename Lhs::Scalar LhsScalar; typedef typename Lhs::Scalar LhsScalar;
typedef typename Rhs::Scalar RhsScalar; typedef typename Rhs::Scalar RhsScalar;
@ -86,8 +85,8 @@ struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,StorageOrde
MappedRhs(actualRhs,rhs.size()) = rhs; MappedRhs(actualRhs,rhs.size()) = rhs;
} }
triangular_solve_vector<LhsScalar, RhsScalar, typename Lhs::Index, Side, Mode, LhsProductTraits::NeedToConjugate,
triangular_solve_vector<LhsScalar, RhsScalar, typename Lhs::Index, Mode, LhsProductTraits::NeedToConjugate, StorageOrder> (int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor>
::run(actualLhs.cols(), actualLhs.data(), actualLhs.outerStride(), actualRhs); ::run(actualLhs.cols(), actualLhs.data(), actualLhs.outerStride(), actualRhs);
if(!useRhsDirectly) if(!useRhsDirectly)
@ -98,22 +97,9 @@ struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,StorageOrde
} }
}; };
// transpose OnTheRight cases for vectors
template<typename Lhs, typename Rhs, int Mode, int Unrolling, int StorageOrder>
struct triangular_solver_selector<Lhs,Rhs,OnTheRight,Mode,Unrolling,StorageOrder,1>
{
static void run(const Lhs& lhs, Rhs& rhs)
{
Transpose<Rhs> rhsTr(rhs);
Transpose<Lhs> lhsTr(lhs);
triangular_solver_selector<Transpose<Lhs>,Transpose<Rhs>,OnTheLeft,TriangularView<Lhs,Mode>::TransposeMode>::run(lhsTr,rhsTr);
}
};
// the rhs is a matrix // the rhs is a matrix
template<typename Lhs, typename Rhs, int Side, int Mode, int StorageOrder> template<typename Lhs, typename Rhs, int Side, int Mode>
struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,StorageOrder,Dynamic> struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,Dynamic>
{ {
typedef typename Rhs::Scalar Scalar; typedef typename Rhs::Scalar Scalar;
typedef typename Rhs::Index Index; typedef typename Rhs::Index Index;
@ -122,7 +108,7 @@ struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,StorageOrder,Dyn
static void run(const Lhs& lhs, Rhs& rhs) static void run(const Lhs& lhs, Rhs& rhs)
{ {
const ActualLhsType actualLhs = LhsProductTraits::extract(lhs); const ActualLhsType actualLhs = LhsProductTraits::extract(lhs);
triangular_solve_matrix<Scalar,Index,Side,Mode,LhsProductTraits::NeedToConjugate,StorageOrder, triangular_solve_matrix<Scalar,Index,Side,Mode,LhsProductTraits::NeedToConjugate,(int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor,
(Rhs::Flags&RowMajorBit) ? RowMajor : ColMajor> (Rhs::Flags&RowMajorBit) ? RowMajor : ColMajor>
::run(lhs.rows(), Side==OnTheLeft? rhs.cols() : rhs.rows(), &actualLhs.coeff(0,0), actualLhs.outerStride(), &rhs.coeffRef(0,0), rhs.outerStride()); ::run(lhs.rows(), Side==OnTheLeft? rhs.cols() : rhs.rows(), &actualLhs.coeff(0,0), actualLhs.outerStride(), &rhs.coeffRef(0,0), rhs.outerStride());
} }
@ -146,7 +132,8 @@ struct triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,false> {
static void run(const Lhs& lhs, Rhs& rhs) static void run(const Lhs& lhs, Rhs& rhs)
{ {
if (Index>0) if (Index>0)
rhs.coeffRef(I) -= lhs.row(I).template segment<Index>(S).transpose().cwiseProduct(rhs.template segment<Index>(S)).sum(); rhs.coeffRef(I) -= lhs.row(I).template segment<Index>(S).transpose()
.cwiseProduct(rhs.template segment<Index>(S)).sum();
if(!(Mode & UnitDiag)) if(!(Mode & UnitDiag))
rhs.coeffRef(I) /= lhs.coeff(I,I); rhs.coeffRef(I) /= lhs.coeff(I,I);
@ -160,8 +147,8 @@ struct triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,true> {
static void run(const Lhs&, Rhs&) {} static void run(const Lhs&, Rhs&) {}
}; };
template<typename Lhs, typename Rhs, int Mode, int StorageOrder> template<typename Lhs, typename Rhs, int Mode>
struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,CompleteUnrolling,StorageOrder,1> { struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,CompleteUnrolling,1> {
static void run(const Lhs& lhs, Rhs& rhs) static void run(const Lhs& lhs, Rhs& rhs)
{ triangular_solver_unroller<Lhs,Rhs,Mode,0,Rhs::SizeAtCompileTime>::run(lhs,rhs); } { triangular_solver_unroller<Lhs,Rhs,Mode,0,Rhs::SizeAtCompileTime>::run(lhs,rhs); }
}; };

View File

@ -27,9 +27,21 @@
namespace internal { namespace internal {
template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate, int StorageOrder>
struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheRight, Mode, Conjugate, StorageOrder>
{
static void run(int size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs)
{
triangular_solve_vector<LhsScalar,RhsScalar,Index,OnTheLeft,
((Mode&Upper)==Upper ? Lower : Upper) | (Mode&UnitDiag),
Conjugate,StorageOrder==RowMajor?ColMajor:RowMajor
>::run(size, _lhs, lhsStride, rhs);
}
};
// forward and backward substitution, row-major, rhs is a vector // forward and backward substitution, row-major, rhs is a vector
template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate> template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate>
struct triangular_solve_vector<LhsScalar, RhsScalar, Index, Mode, Conjugate, RowMajor> struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheLeft, Mode, Conjugate, RowMajor>
{ {
enum { enum {
IsLower = ((Mode&Lower)==Lower) IsLower = ((Mode&Lower)==Lower)
@ -83,7 +95,7 @@ struct triangular_solve_vector<LhsScalar, RhsScalar, Index, Mode, Conjugate, Row
// forward and backward substitution, column-major, rhs is a vector // forward and backward substitution, column-major, rhs is a vector
template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate> template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate>
struct triangular_solve_vector<LhsScalar, RhsScalar, Index, Mode, Conjugate, ColMajor> struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheLeft, Mode, Conjugate, ColMajor>
{ {
enum { enum {
IsLower = ((Mode&Lower)==Lower) IsLower = ((Mode&Lower)==Lower)