From 3fdea699b80c429738ac0af8c9b7479594b90583 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Fri, 5 Nov 2010 12:54:32 +0100 Subject: [PATCH] trsv: simplifications/cleaning --- Eigen/src/Core/SolveTriangular.h | 37 ++++++------------- .../Core/products/TriangularSolverVector.h | 16 +++++++- 2 files changed, 26 insertions(+), 27 deletions(-) diff --git a/Eigen/src/Core/SolveTriangular.h b/Eigen/src/Core/SolveTriangular.h index b950d2c31..d85f967cb 100644 --- a/Eigen/src/Core/SolveTriangular.h +++ b/Eigen/src/Core/SolveTriangular.h @@ -29,7 +29,7 @@ namespace internal { // Forward declarations: // The following two routines are implemented in the products/TriangularSolver*.h files -template +template struct triangular_solve_vector; template @@ -55,13 +55,12 @@ template::Unrolling, - int StorageOrder = (int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor, int RhsVectors = trsolve_traits::RhsVectors > struct triangular_solver_selector; -template -struct triangular_solver_selector +template +struct triangular_solver_selector { typedef typename Lhs::Scalar LhsScalar; typedef typename Rhs::Scalar RhsScalar; @@ -86,8 +85,8 @@ struct triangular_solver_selector + triangular_solve_vector ::run(actualLhs.cols(), actualLhs.data(), actualLhs.outerStride(), actualRhs); if(!useRhsDirectly) @@ -98,22 +97,9 @@ struct triangular_solver_selector -struct triangular_solver_selector -{ - static void run(const Lhs& lhs, Rhs& rhs) - { - Transpose rhsTr(rhs); - Transpose lhsTr(lhs); - triangular_solver_selector,Transpose,OnTheLeft,TriangularView::TransposeMode>::run(lhsTr,rhsTr); - } -}; - - // the rhs is a matrix -template -struct triangular_solver_selector +template +struct triangular_solver_selector { typedef typename Rhs::Scalar Scalar; typedef typename Rhs::Index Index; @@ -122,7 +108,7 @@ struct triangular_solver_selector ::run(lhs.rows(), Side==OnTheLeft? rhs.cols() : rhs.rows(), &actualLhs.coeff(0,0), actualLhs.outerStride(), &rhs.coeffRef(0,0), rhs.outerStride()); } @@ -146,7 +132,8 @@ struct triangular_solver_unroller { static void run(const Lhs& lhs, Rhs& rhs) { if (Index>0) - rhs.coeffRef(I) -= lhs.row(I).template segment(S).transpose().cwiseProduct(rhs.template segment(S)).sum(); + rhs.coeffRef(I) -= lhs.row(I).template segment(S).transpose() + .cwiseProduct(rhs.template segment(S)).sum(); if(!(Mode & UnitDiag)) rhs.coeffRef(I) /= lhs.coeff(I,I); @@ -160,8 +147,8 @@ struct triangular_solver_unroller { static void run(const Lhs&, Rhs&) {} }; -template -struct triangular_solver_selector { +template +struct triangular_solver_selector { static void run(const Lhs& lhs, Rhs& rhs) { triangular_solver_unroller::run(lhs,rhs); } }; diff --git a/Eigen/src/Core/products/TriangularSolverVector.h b/Eigen/src/Core/products/TriangularSolverVector.h index fcf8bcae0..25e739178 100644 --- a/Eigen/src/Core/products/TriangularSolverVector.h +++ b/Eigen/src/Core/products/TriangularSolverVector.h @@ -27,9 +27,21 @@ namespace internal { +template +struct triangular_solve_vector +{ + static void run(int size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs) + { + triangular_solve_vector::run(size, _lhs, lhsStride, rhs); + } +}; + // forward and backward substitution, row-major, rhs is a vector template -struct triangular_solve_vector +struct triangular_solve_vector { enum { IsLower = ((Mode&Lower)==Lower) @@ -83,7 +95,7 @@ struct triangular_solve_vector -struct triangular_solve_vector +struct triangular_solve_vector { enum { IsLower = ((Mode&Lower)==Lower)