From f3fde74695eff236fe24b05ffb053d3890346420 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Sun, 26 Jul 2009 13:01:37 +0200 Subject: [PATCH] finalize trsm: works in all situations, and it is now used by solve() and solveInPlace() --- Eigen/src/Core/SolveTriangular.h | 31 ++++- .../Core/products/TriangularSolverMatrix.h | 111 ++++++++++-------- test/CMakeLists.txt | 1 + test/product_trsm.cpp | 95 +++++++++++++++ 4 files changed, 185 insertions(+), 53 deletions(-) create mode 100644 test/product_trsm.cpp diff --git a/Eigen/src/Core/SolveTriangular.h b/Eigen/src/Core/SolveTriangular.h index cb162ca91..d0656eacb 100644 --- a/Eigen/src/Core/SolveTriangular.h +++ b/Eigen/src/Core/SolveTriangular.h @@ -29,13 +29,14 @@ template struct ei_triangular_solver_selector; -// forward and backward substitution, row-major +// forward and backward substitution, row-major, rhs is a vector template -struct ei_triangular_solver_selector +struct ei_triangular_solver_selector { typedef typename Rhs::Scalar Scalar; typedef ei_blas_traits LhsProductTraits; @@ -89,9 +90,9 @@ struct ei_triangular_solver_selector } }; -// forward and backward substitution, column-major +// forward and backward substitution, column-major, rhs is a vector template -struct ei_triangular_solver_selector +struct ei_triangular_solver_selector { typedef typename Rhs::Scalar Scalar; typedef typename ei_packet_traits::type Packet; @@ -150,6 +151,24 @@ struct ei_triangular_solver_selector } }; +template +struct ei_triangular_solve_matrix; + +// the rhs is a matrix +template +struct ei_triangular_solver_selector +{ + typedef typename Rhs::Scalar Scalar; + typedef ei_blas_traits LhsProductTraits; + typedef typename LhsProductTraits::ActualXprType ActualLhsType; + static void run(const Lhs& lhs, Rhs& rhs) + { + const ActualLhsType& actualLhs = LhsProductTraits::extract(lhs); + ei_triangular_solve_matrix + ::run(lhs.rows(), rhs.cols(), &actualLhs.coeff(0,0), actualLhs.stride(), &rhs.coeffRef(0,0), rhs.stride()); + } +}; + /*************************************************************************** * meta-unrolling implementation ***************************************************************************/ @@ -184,7 +203,7 @@ struct ei_triangular_solver_unroller { }; template -struct ei_triangular_solver_selector { +struct ei_triangular_solver_selector { static void run(const Lhs& lhs, Rhs& rhs) { ei_triangular_solver_unroller::run(lhs,rhs); } }; diff --git a/Eigen/src/Core/products/TriangularSolverMatrix.h b/Eigen/src/Core/products/TriangularSolverMatrix.h index eeb445f0b..550076f68 100644 --- a/Eigen/src/Core/products/TriangularSolverMatrix.h +++ b/Eigen/src/Core/products/TriangularSolverMatrix.h @@ -26,63 +26,37 @@ #define EIGEN_TRIANGULAR_SOLVER_MATRIX_H template -struct ei_gemm_pack_rhs_panel +struct ei_gemm_pack_rhs_panel; + +// if the rhs is row major, we have to evaluate it in a temporary colmajor matrix +template +struct ei_triangular_solve_matrix { - enum { PacketSize = ei_packet_traits::size }; - void operator()(Scalar* blockB, const Scalar* rhs, int rhsStride, Scalar alpha, int depth, int cols, int stride, int offset) + static EIGEN_DONT_INLINE void run( + int size, int cols, + const Scalar* lhs, int lhsStride, + Scalar* _rhs, int rhsStride) { - int packet_cols = (cols/nr) * nr; - int count = 0; - for(int j2=0; j2 > rhs(_rhs, rhsStride, cols); + Matrix aux = rhs.block(0,0,size,cols); + ei_triangular_solve_matrix + ::run(size, cols, lhs, lhsStride, aux.data(), aux.stride()); + rhs.block(0,0,size,cols) = aux; } }; /* Optimized triangular solver with multiple right hand side (_TRSM) */ -template -struct ei_triangular_solve_matrix// +template +struct ei_triangular_solve_matrix { - static EIGEN_DONT_INLINE void run( int size, int cols, const Scalar* _lhs, int lhsStride, Scalar* _rhs, int rhsStride) { ei_const_blas_data_mapper lhs(_lhs,lhsStride); - ei_blas_data_mapper rhs(_rhs,rhsStride); + ei_blas_data_mapper rhs(_rhs,rhsStride); typedef ei_product_blocking_traits Blocking; enum { @@ -96,7 +70,8 @@ struct ei_triangular_solve_matrix// Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc); Scalar* blockB = ei_aligned_stack_new(Scalar, kc*cols*Blocking::PacketSize); - ei_gebp_kernel > gebp_kernel; + ei_conj_if conj; + ei_gebp_kernel > gebp_kernel; ei_gemm_pack_lhs pack_lhs; for(int k2=IsLowerTriangular ? 0 : size; @@ -131,7 +106,7 @@ struct ei_triangular_solve_matrix// int s = IsLowerTriangular ? k2+k1 : i+1; int rs = actualPanelWidth - k - 1; // remaining size - Scalar a = (Mode & UnitDiagBit) ? Scalar(1) : Scalar(1)/lhs(i,i); + Scalar a = (Mode & UnitDiagBit) ? Scalar(1) : Scalar(1)/conj(lhs(i,i)); for (int j=0; j const Scalar* l = &lhs(i,s); Scalar* r = &rhs(s,j); for (int i3=0; i3 Scalar* r = &rhs(s,j); const Scalar* l = &lhs(s,i); for (int i3=0;i3 } }; +template +struct ei_gemm_pack_rhs_panel +{ + enum { PacketSize = ei_packet_traits::size }; + void operator()(Scalar* blockB, const Scalar* rhs, int rhsStride, Scalar alpha, int depth, int cols, int stride, int offset) + { + int packet_cols = (cols/nr) * nr; + int count = 0; + for(int j2=0; j2 +// +// Eigen is free software; you can redistribute it and/or +// modify it under the terms of the GNU Lesser General Public +// License as published by the Free Software Foundation; either +// version 3 of the License, or (at your option) any later version. +// +// Alternatively, you can redistribute it and/or +// modify it under the terms of the GNU General Public License as +// published by the Free Software Foundation; either version 2 of +// the License, or (at your option) any later version. +// +// Eigen is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +// FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public +// License and a copy of the GNU General Public License along with +// Eigen. If not, see . + +#include "main.h" + +template +void solve_ref(const Lhs& lhs, Rhs& rhs) +{ + for (int j=0; j void trsm(int size,int cols) +{ + typedef typename NumTraits::Real RealScalar; + + Matrix cmLhs(size,size); + Matrix rmLhs(size,size); + + Matrix cmRef(size,cols), cmRhs(size,cols); + Matrix rmRef(size,cols), rmRhs(size,cols); + + cmLhs.setRandom(); cmLhs.diagonal().cwise() += 10; + rmLhs.setRandom(); rmLhs.diagonal().cwise() += 10; + + cmRhs.setRandom(); cmRef = cmRhs; + cmLhs.conjugate().template triangularView().solveInPlace(cmRhs); + solve_ref(cmLhs.conjugate().template triangularView(),cmRef); + VERIFY_IS_APPROX(cmRhs, cmRef); + + cmRhs.setRandom(); cmRef = cmRhs; + cmLhs.conjugate().template triangularView().solveInPlace(cmRhs); + solve_ref(cmLhs.conjugate().template triangularView(),cmRef); + VERIFY_IS_APPROX(cmRhs, cmRef); + + rmRhs.setRandom(); rmRef = rmRhs; + cmLhs.template triangularView().solveInPlace(rmRhs); + solve_ref(cmLhs.template triangularView(),rmRef); + VERIFY_IS_APPROX(rmRhs, rmRef); + + rmRhs.setRandom(); rmRef = rmRhs; + cmLhs.template triangularView().solveInPlace(rmRhs); + solve_ref(cmLhs.template triangularView(),rmRef); + VERIFY_IS_APPROX(rmRhs, rmRef); + + + cmRhs.setRandom(); cmRef = cmRhs; + rmLhs.template triangularView().solveInPlace(cmRhs); + solve_ref(rmLhs.template triangularView(),cmRef); + VERIFY_IS_APPROX(cmRhs, cmRef); + + cmRhs.setRandom(); cmRef = cmRhs; + rmLhs.template triangularView().solveInPlace(cmRhs); + solve_ref(rmLhs.template triangularView(),cmRef); + VERIFY_IS_APPROX(cmRhs, cmRef); + + rmRhs.setRandom(); rmRef = rmRhs; + rmLhs.template triangularView().solveInPlace(rmRhs); + solve_ref(rmLhs.template triangularView(),rmRef); + VERIFY_IS_APPROX(rmRhs, rmRef); + + rmRhs.setRandom(); rmRef = rmRhs; + rmLhs.template triangularView().solveInPlace(rmRhs); + solve_ref(rmLhs.template triangularView(),rmRef); + VERIFY_IS_APPROX(rmRhs, rmRef); +} +void test_product_trsm() +{ + for(int i = 0; i < g_repeat ; i++) + { + trsm(ei_random(1,320),ei_random(1,320)); + trsm >(ei_random(1,320),ei_random(1,320)); + } +}