trsv: simplifications/cleaning

This commit is contained in:
Gael Guennebaud 2010-11-05 12:54:32 +01:00
parent 0e6c1170ab
commit 3fdea699b8
2 changed files with 26 additions and 27 deletions

View File

@ -29,7 +29,7 @@ namespace internal {
// Forward declarations:
// The following two routines are implemented in the products/TriangularSolver*.h files
template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate, int StorageOrder>
template<typename LhsScalar, typename RhsScalar, typename Index, int Side, int Mode, bool Conjugate, int StorageOrder>
struct triangular_solve_vector;
template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder>
@ -55,13 +55,12 @@ template<typename Lhs, typename Rhs,
int Side, // can be OnTheLeft/OnTheRight
int Mode, // can be Upper/Lower | UnitDiag
int Unrolling = trsolve_traits<Lhs,Rhs,Side>::Unrolling,
int StorageOrder = (int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor,
int RhsVectors = trsolve_traits<Lhs,Rhs,Side>::RhsVectors
>
struct triangular_solver_selector;
template<typename Lhs, typename Rhs, int Mode, int StorageOrder>
struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,StorageOrder,1>
template<typename Lhs, typename Rhs, int Side, int Mode>
struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,1>
{
typedef typename Lhs::Scalar LhsScalar;
typedef typename Rhs::Scalar RhsScalar;
@ -86,8 +85,8 @@ struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,StorageOrde
MappedRhs(actualRhs,rhs.size()) = rhs;
}
triangular_solve_vector<LhsScalar, RhsScalar, typename Lhs::Index, Mode, LhsProductTraits::NeedToConjugate, StorageOrder>
triangular_solve_vector<LhsScalar, RhsScalar, typename Lhs::Index, Side, Mode, LhsProductTraits::NeedToConjugate,
(int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor>
::run(actualLhs.cols(), actualLhs.data(), actualLhs.outerStride(), actualRhs);
if(!useRhsDirectly)
@ -98,22 +97,9 @@ struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,StorageOrde
}
};
// transpose OnTheRight cases for vectors
template<typename Lhs, typename Rhs, int Mode, int Unrolling, int StorageOrder>
struct triangular_solver_selector<Lhs,Rhs,OnTheRight,Mode,Unrolling,StorageOrder,1>
{
static void run(const Lhs& lhs, Rhs& rhs)
{
Transpose<Rhs> rhsTr(rhs);
Transpose<Lhs> lhsTr(lhs);
triangular_solver_selector<Transpose<Lhs>,Transpose<Rhs>,OnTheLeft,TriangularView<Lhs,Mode>::TransposeMode>::run(lhsTr,rhsTr);
}
};
// the rhs is a matrix
template<typename Lhs, typename Rhs, int Side, int Mode, int StorageOrder>
struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,StorageOrder,Dynamic>
template<typename Lhs, typename Rhs, int Side, int Mode>
struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,Dynamic>
{
typedef typename Rhs::Scalar Scalar;
typedef typename Rhs::Index Index;
@ -122,7 +108,7 @@ struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,StorageOrder,Dyn
static void run(const Lhs& lhs, Rhs& rhs)
{
const ActualLhsType actualLhs = LhsProductTraits::extract(lhs);
triangular_solve_matrix<Scalar,Index,Side,Mode,LhsProductTraits::NeedToConjugate,StorageOrder,
triangular_solve_matrix<Scalar,Index,Side,Mode,LhsProductTraits::NeedToConjugate,(int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor,
(Rhs::Flags&RowMajorBit) ? RowMajor : ColMajor>
::run(lhs.rows(), Side==OnTheLeft? rhs.cols() : rhs.rows(), &actualLhs.coeff(0,0), actualLhs.outerStride(), &rhs.coeffRef(0,0), rhs.outerStride());
}
@ -146,7 +132,8 @@ struct triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,false> {
static void run(const Lhs& lhs, Rhs& rhs)
{
if (Index>0)
rhs.coeffRef(I) -= lhs.row(I).template segment<Index>(S).transpose().cwiseProduct(rhs.template segment<Index>(S)).sum();
rhs.coeffRef(I) -= lhs.row(I).template segment<Index>(S).transpose()
.cwiseProduct(rhs.template segment<Index>(S)).sum();
if(!(Mode & UnitDiag))
rhs.coeffRef(I) /= lhs.coeff(I,I);
@ -160,8 +147,8 @@ struct triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,true> {
static void run(const Lhs&, Rhs&) {}
};
template<typename Lhs, typename Rhs, int Mode, int StorageOrder>
struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,CompleteUnrolling,StorageOrder,1> {
template<typename Lhs, typename Rhs, int Mode>
struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,CompleteUnrolling,1> {
static void run(const Lhs& lhs, Rhs& rhs)
{ triangular_solver_unroller<Lhs,Rhs,Mode,0,Rhs::SizeAtCompileTime>::run(lhs,rhs); }
};

View File

@ -27,9 +27,21 @@
namespace internal {
template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate, int StorageOrder>
struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheRight, Mode, Conjugate, StorageOrder>
{
static void run(int size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs)
{
triangular_solve_vector<LhsScalar,RhsScalar,Index,OnTheLeft,
((Mode&Upper)==Upper ? Lower : Upper) | (Mode&UnitDiag),
Conjugate,StorageOrder==RowMajor?ColMajor:RowMajor
>::run(size, _lhs, lhsStride, rhs);
}
};
// forward and backward substitution, row-major, rhs is a vector
template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate>
struct triangular_solve_vector<LhsScalar, RhsScalar, Index, Mode, Conjugate, RowMajor>
struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheLeft, Mode, Conjugate, RowMajor>
{
enum {
IsLower = ((Mode&Lower)==Lower)
@ -83,7 +95,7 @@ struct triangular_solve_vector<LhsScalar, RhsScalar, Index, Mode, Conjugate, Row
// forward and backward substitution, column-major, rhs is a vector
template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate>
struct triangular_solve_vector<LhsScalar, RhsScalar, Index, Mode, Conjugate, ColMajor>
struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheLeft, Mode, Conjugate, ColMajor>
{
enum {
IsLower = ((Mode&Lower)==Lower)