From c253cc3d5346a5297dd623a4d92453ed0bb1a77c Mon Sep 17 00:00:00 2001 From: Benoit Jacob Date: Thu, 23 Sep 2010 09:51:08 -0400 Subject: [PATCH] SVD: * fix unit test for rectangular matrices. * enforce that rows >= cols since various places in the code assume that. --- Eigen/src/SVD/SVD.h | 9 ++++----- test/svd.cpp | 27 +++++++++++++++------------ 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/Eigen/src/SVD/SVD.h b/Eigen/src/SVD/SVD.h index 7379b72f4..15e284dea 100644 --- a/Eigen/src/SVD/SVD.h +++ b/Eigen/src/SVD/SVD.h @@ -38,6 +38,8 @@ template struct ei_svd_solve_impl; * * This class performs a standard SVD decomposition of a real matrix A of size \c M x \c N. * + * Requires M >= N, in other words, at least as many rows as columns. + * * \sa MatrixBase::SVD() */ template class SVD @@ -440,11 +442,8 @@ SVD& SVD::compute(const MatrixType& matrix) } } } - m_matU.setZero(); - if (m>=n) - m_matU.block(0,0,m,n) = A; - else - m_matU = A.block(0,0,m,m); + m_matU.leftCols(n) = A; + m_matU.rightCols(m-n).setZero(); m_isInitialized = true; return *this; diff --git a/test/svd.cpp b/test/svd.cpp index 3003c9dff..8ba510d9b 100644 --- a/test/svd.cpp +++ b/test/svd.cpp @@ -37,30 +37,30 @@ template void svd(const MatrixType& m) typedef typename MatrixType::Scalar Scalar; typedef typename NumTraits::Real RealScalar; MatrixType a = MatrixType::Random(rows,cols); - Matrix b = - Matrix::Random(rows,1); Matrix x(cols,1), x2(cols,1); { SVD svd(a); MatrixType sigma = MatrixType::Zero(rows,cols); MatrixType matU = MatrixType::Zero(rows,rows); + MatrixType matV = MatrixType::Zero(cols,cols); + sigma.diagonal() = svd.singularValues(); matU = svd.matrixU(); - VERIFY_IS_APPROX(a, matU * sigma * svd.matrixV().transpose()); + //VERIFY_IS_UNITARY(matU); + matV = svd.matrixV(); + //VERIFY_IS_UNITARY(matV); + VERIFY_IS_APPROX(a, matU * sigma * matV.transpose()); } if (rows>=cols) { - if (ei_is_same_type::ret) - { - MatrixType a1 = MatrixType::Random(rows,cols); - a += a * a.adjoint() + a1 * a1.adjoint(); - } SVD svd(a); - x = svd.solve(b); - VERIFY_IS_APPROX(a * x,b); + Matrix x = Matrix::Random(cols,1); + Matrix b = a * x; + Matrix result = svd.solve(b); + VERIFY_IS_APPROX(a * result, b); } @@ -102,8 +102,11 @@ void test_svd() for(int i = 0; i < g_repeat; i++) { CALL_SUBTEST_1( svd(Matrix3f()) ); CALL_SUBTEST_2( svd(Matrix4d()) ); - CALL_SUBTEST_3( svd(MatrixXf(7,7)) ); - CALL_SUBTEST_4( svd(MatrixXd(14,7)) ); + int cols = ei_random(2,50); + int rows = cols + ei_random(0,50); + + CALL_SUBTEST_3( svd(MatrixXf(rows,cols)) ); + CALL_SUBTEST_4( svd(MatrixXd(rows,cols)) ); // complex are not implemented yet // CALL_SUBTEST(svd(MatrixXcd(6,6)) ); // CALL_SUBTEST(svd(MatrixXcf(3,3)) );