From ccdcebcf03a529b429feda831ff4b44f8433e045 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Thu, 15 Jan 2009 18:52:14 +0000 Subject: [PATCH] Sparse module: add support for sparse selfadjoint * dense --- Eigen/src/Sparse/SparseMatrixBase.h | 4 +-- Eigen/src/Sparse/SparseProduct.h | 49 +++++++++++++++++++++++++++-- test/sparse.h | 7 ++++- test/sparse_basic.cpp | 33 +++++++++++++++++++ 4 files changed, 87 insertions(+), 6 deletions(-) diff --git a/Eigen/src/Sparse/SparseMatrixBase.h b/Eigen/src/Sparse/SparseMatrixBase.h index 953dd30a7..dd4eeff16 100644 --- a/Eigen/src/Sparse/SparseMatrixBase.h +++ b/Eigen/src/Sparse/SparseMatrixBase.h @@ -88,7 +88,7 @@ template class SparseMatrixBase /** \internal the return type of MatrixBase::imag() */ typedef CwiseUnaryOp, Derived> ImagReturnType; /** \internal the return type of MatrixBase::adjoint() */ - typedef Eigen::Transpose::type> > + typedef SparseTranspose::type> /*>*/ AdjointReturnType; #ifndef EIGEN_PARSED_BY_DOXYGEN @@ -322,7 +322,7 @@ template class SparseMatrixBase SparseTranspose transpose() { return derived(); } const SparseTranspose transpose() const { return derived(); } // void transposeInPlace(); - // const AdjointReturnType adjoint() const; + const AdjointReturnType adjoint() const { return conjugate()/*.nestByValue()*/; } SparseInnerVector innerVector(int outer); const SparseInnerVector innerVector(int outer) const; diff --git a/Eigen/src/Sparse/SparseProduct.h b/Eigen/src/Sparse/SparseProduct.h index 29f5208fa..5a2c294a2 100644 --- a/Eigen/src/Sparse/SparseProduct.h +++ b/Eigen/src/Sparse/SparseProduct.h @@ -294,17 +294,60 @@ inline Derived& SparseMatrixBase::operator=(const SparseProduct +// template +// Derived& MatrixBase::lazyAssign(const SparseProduct& product) +// { +// typedef typename ei_cleantype::type _Lhs; +// typedef typename _Lhs::InnerIterator LhsInnerIterator; +// enum { LhsIsRowMajor = (_Lhs::Flags&RowMajorBit)==RowMajorBit }; +// derived().setZero(); +// for (int j=0; j template Derived& MatrixBase::lazyAssign(const SparseProduct& product) { typedef typename ei_cleantype::type _Lhs; typedef typename _Lhs::InnerIterator LhsInnerIterator; - enum { LhsIsRowMajor = (_Lhs::Flags&RowMajorBit)==RowMajorBit }; + enum { + LhsIsRowMajor = (_Lhs::Flags&RowMajorBit)==RowMajorBit, + LhsIsSelfAdjoint = (_Lhs::Flags&SelfAdjointBit)==SelfAdjointBit, + ProcessFirstHalf = LhsIsSelfAdjoint + && ( ((_Lhs::Flags&(UpperTriangularBit|LowerTriangularBit))==0) + || ( (_Lhs::Flags&UpperTriangularBit) && !LhsIsRowMajor) + || ( (_Lhs::Flags&LowerTriangularBit) && LhsIsRowMajor) ), + ProcessSecondHalf = LhsIsSelfAdjoint && (!ProcessFirstHalf) + }; derived().setZero(); for (int j=0; j void sparse_basic(int rows, int cols) VERIFY_IS_APPROX(dm4=refMat2.transpose()*m3, refMat4=refMat2.transpose()*refMat3); VERIFY_IS_APPROX(dm4=refMat2.transpose()*m3.transpose(), refMat4=refMat2.transpose()*refMat3.transpose()); } + + // test self adjoint products + { + DenseMatrix b = DenseMatrix::Random(rows, rows); + DenseMatrix x = DenseMatrix::Random(rows, rows); + DenseMatrix refX = DenseMatrix::Random(rows, rows); + DenseMatrix refUp = DenseMatrix::Zero(rows, rows); + DenseMatrix refLo = DenseMatrix::Zero(rows, rows); + DenseMatrix refS = DenseMatrix::Zero(rows, rows); + SparseMatrix mUp(rows, rows); + SparseMatrix mLo(rows, rows); + SparseMatrix mS(rows, rows); + do { + initSparse(density, refUp, mUp, ForceRealDiag|/*ForceNonZeroDiag|*/MakeUpperTriangular); + } while (refUp.isZero()); + refLo = refUp.transpose().conjugate(); + mLo = mUp.transpose().conjugate(); + refS = refUp + refLo; + refS.diagonal() *= 0.5; + mS = mUp + mLo; + for (int k=0; k::InnerIterator it(mS,k); it; ++it) + if (it.index() == k) + it.valueRef() *= 0.5; + + VERIFY_IS_APPROX(refS.adjoint(), refS); + VERIFY_IS_APPROX(mS.transpose().conjugate(), mS); + VERIFY_IS_APPROX(mS, refS); + VERIFY_IS_APPROX(x=mS*b, refX=refS*b); + VERIFY_IS_APPROX(x=mUp.template marked()*b, refX=refS*b); + VERIFY_IS_APPROX(x=mLo.template marked()*b, refX=refS*b); + VERIFY_IS_APPROX(x=mS.template marked()*b, refX=refS*b); + } } void test_sparse_basic()