diff --git a/Eigen/src/Core/SolveTriangular.h b/Eigen/src/Core/SolveTriangular.h index c25317989..960da31f3 100644 --- a/Eigen/src/Core/SolveTriangular.h +++ b/Eigen/src/Core/SolveTriangular.h @@ -53,7 +53,8 @@ struct ei_triangular_solver_selector; template struct ei_triangular_solver_selector { - typedef typename Rhs::Scalar Scalar; + typedef typename Lhs::Scalar LhsScalar; + typedef typename Rhs::Scalar RhsScalar; typedef ei_blas_traits LhsProductTraits; typedef typename LhsProductTraits::ExtractType ActualLhsType; typedef typename Lhs::Index Index; @@ -81,12 +82,12 @@ struct ei_triangular_solver_selector::run( + ei_general_matrix_vector_product::run( actualPanelWidth, r, &(actualLhs.const_cast_derived().coeffRef(startRow,startCol)), actualLhs.outerStride(), &(other.coeffRef(startCol)), other.innerStride(), &other.coeffRef(startRow), other.innerStride(), - Scalar(-1)); + RhsScalar(-1)); } for(Index k=0; k struct ei_triangular_solver_selector { - typedef typename Rhs::Scalar Scalar; - typedef typename ei_packet_traits::type Packet; + typedef typename Lhs::Scalar LhsScalar; + typedef typename Rhs::Scalar RhsScalar; typedef ei_blas_traits LhsProductTraits; typedef typename LhsProductTraits::ExtractType ActualLhsType; typedef typename Lhs::Index Index; enum { - PacketSize = ei_packet_traits::size, IsLower = ((Mode&Lower)==Lower) }; @@ -148,11 +148,11 @@ struct ei_triangular_solver_selector::run( + ei_general_matrix_vector_product::run( r, actualPanelWidth, &(actualLhs.const_cast_derived().coeffRef(endBlock,startBlock)), actualLhs.outerStride(), &other.coeff(startBlock), other.innerStride(), - &(other.coeffRef(endBlock, 0)), other.innerStride(), Scalar(-1)); + &(other.coeffRef(endBlock, 0)), other.innerStride(), RhsScalar(-1)); } } } diff --git a/test/cholesky.cpp b/test/cholesky.cpp index 136c69266..0edf9a793 100644 --- a/test/cholesky.cpp +++ b/test/cholesky.cpp @@ -170,6 +170,66 @@ template void cholesky(const MatrixType& m) } +template void cholesky_cplx(const MatrixType& m) +{ + // classic test + cholesky(m); + + // test mixing real/scalar types + + typedef typename MatrixType::Index Index; + + Index rows = m.rows(); + Index cols = m.cols(); + + typedef typename MatrixType::Scalar Scalar; + typedef typename NumTraits::Real RealScalar; + typedef Matrix RealMatrixType; + typedef Matrix VectorType; + + RealMatrixType a0 = RealMatrixType::Random(rows,cols); + VectorType vecB = VectorType::Random(rows), vecX(rows); + MatrixType matB = MatrixType::Random(rows,cols), matX(rows,cols); + RealMatrixType symm = a0 * a0.adjoint(); + // let's make sure the matrix is not singular or near singular + for (int k=0; k<3; ++k) + { + RealMatrixType a1 = RealMatrixType::Random(rows,cols); + symm += a1 * a1.adjoint(); + } + + { + RealMatrixType symmLo = symm.template triangularView(); + + LLT chollo(symmLo); + VERIFY_IS_APPROX(symm, chollo.reconstructedMatrix()); + vecX = chollo.solve(vecB); + VERIFY_IS_APPROX(symm * vecX, vecB); +// matX = chollo.solve(matB); +// VERIFY_IS_APPROX(symm * matX, matB); + } + + // LDLT + { + int sign = ei_random()%2 ? 1 : -1; + + if(sign == -1) + { + symm = -symm; // test a negative matrix + } + + RealMatrixType symmLo = symm.template triangularView(); + + LDLT ldltlo(symmLo); + VERIFY_IS_APPROX(symm, ldltlo.reconstructedMatrix()); + vecX = ldltlo.solve(vecB); + VERIFY_IS_APPROX(symm * vecX, vecB); +// matX = ldltlo.solve(matB); +// VERIFY_IS_APPROX(symm * matX, matB); + } + +} + template void cholesky_verify_assert() { MatrixType tmp; @@ -192,14 +252,16 @@ template void cholesky_verify_assert() void test_cholesky() { + int s; for(int i = 0; i < g_repeat; i++) { CALL_SUBTEST_1( cholesky(Matrix()) ); - CALL_SUBTEST_2( cholesky(MatrixXd(1,1)) ); CALL_SUBTEST_3( cholesky(Matrix2d()) ); CALL_SUBTEST_4( cholesky(Matrix3f()) ); CALL_SUBTEST_5( cholesky(Matrix4d()) ); - CALL_SUBTEST_2( cholesky(MatrixXd(200,200)) ); - CALL_SUBTEST_6( cholesky(MatrixXcd(100,100)) ); + s = ei_random(1,200); + CALL_SUBTEST_2( cholesky(MatrixXd(s,s)) ); + s = ei_random(1,100); + CALL_SUBTEST_6( cholesky_cplx(MatrixXcd(s,s)) ); } CALL_SUBTEST_4( cholesky_verify_assert() );