fix symmetric permuatation for mixed storage orders

This commit is contained in:
Gael Guennebaud 2012-02-27 13:21:41 +01:00
parent 128ff9cf07
commit bc8188f6a1
2 changed files with 43 additions and 18 deletions

View File

@ -309,12 +309,14 @@ void permute_symm_to_fullsymm(const MatrixType& mat, SparseMatrix<typename Matri
for(typename MatrixType::InnerIterator it(mat,j); it; ++it) for(typename MatrixType::InnerIterator it(mat,j); it; ++it)
{ {
Index i = it.index(); Index i = it.index();
Index r = it.row();
Index c = it.col();
Index ip = perm ? perm[i] : i; Index ip = perm ? perm[i] : i;
if(UpLo==(Upper|Lower)) if(UpLo==(Upper|Lower))
count[StorageOrderMatch ? jp : ip]++; count[StorageOrderMatch ? jp : ip]++;
else if(i==j) else if(r==c)
count[ip]++; count[ip]++;
else if(( UpLo==Lower && i>j) || ( UpLo==Upper && i<j)) else if(( UpLo==Lower && r>c) || ( UpLo==Upper && r<c))
{ {
count[ip]++; count[ip]++;
count[jp]++; count[jp]++;
@ -334,25 +336,31 @@ void permute_symm_to_fullsymm(const MatrixType& mat, SparseMatrix<typename Matri
// copy data // copy data
for(Index j = 0; j<size; ++j) for(Index j = 0; j<size; ++j)
{ {
Index jp = perm ? perm[j] : j;
for(typename MatrixType::InnerIterator it(mat,j); it; ++it) for(typename MatrixType::InnerIterator it(mat,j); it; ++it)
{ {
Index i = it.index(); Index i = it.index();
Index r = it.row();
Index c = it.col();
Index jp = perm ? perm[j] : j;
Index ip = perm ? perm[i] : i; Index ip = perm ? perm[i] : i;
if(UpLo==(Upper|Lower)) if(UpLo==(Upper|Lower))
{ {
Index k = count[StorageOrderMatch ? jp : ip]++; Index k = count[StorageOrderMatch ? jp : ip]++;
dest.innerIndexPtr()[k] = StorageOrderMatch ? ip : jp; dest.innerIndexPtr()[k] = StorageOrderMatch ? ip : jp;
dest.valuePtr()[k] = it.value(); dest.valuePtr()[k] = it.value();
} }
else if(i==j) else if(r==c)
{ {
Index k = count[ip]++; Index k = count[ip]++;
dest.innerIndexPtr()[k] = ip; dest.innerIndexPtr()[k] = ip;
dest.valuePtr()[k] = it.value(); dest.valuePtr()[k] = it.value();
} }
else if(( (UpLo&Lower)==Lower && i>j) || ( (UpLo&Upper)==Upper && i<j)) else if(( (UpLo&Lower)==Lower && r>c) || ( (UpLo&Upper)==Upper && r<c))
{ {
if(!StorageOrderMatch)
std::swap(ip,jp);
Index k = count[jp]++; Index k = count[jp]++;
dest.innerIndexPtr()[k] = ip; dest.innerIndexPtr()[k] = ip;
dest.valuePtr()[k] = it.value(); dest.valuePtr()[k] = it.value();
@ -364,15 +372,19 @@ void permute_symm_to_fullsymm(const MatrixType& mat, SparseMatrix<typename Matri
} }
} }
template<int SrcUpLo,int DstUpLo,typename MatrixType,int DestOrder> template<int _SrcUpLo,int _DstUpLo,typename MatrixType,int DstOrder>
void permute_symm_to_symm(const MatrixType& mat, SparseMatrix<typename MatrixType::Scalar,DestOrder,typename MatrixType::Index>& _dest, const typename MatrixType::Index* perm) void permute_symm_to_symm(const MatrixType& mat, SparseMatrix<typename MatrixType::Scalar,DstOrder,typename MatrixType::Index>& _dest, const typename MatrixType::Index* perm)
{ {
typedef typename MatrixType::Index Index; typedef typename MatrixType::Index Index;
typedef typename MatrixType::Scalar Scalar; typedef typename MatrixType::Scalar Scalar;
typedef SparseMatrix<Scalar,DestOrder,Index> Dest; SparseMatrix<Scalar,DstOrder,Index>& dest(_dest.derived());
Dest& dest(_dest.derived());
typedef Matrix<Index,Dynamic,1> VectorI; typedef Matrix<Index,Dynamic,1> VectorI;
//internal::conj_if<SrcUpLo!=DstUpLo> cj; enum {
SrcOrder = MatrixType::IsRowMajor ? RowMajor : ColMajor,
StorageOrderMatch = int(SrcOrder) == int(DstOrder),
DstUpLo = DstOrder==RowMajor ? (_DstUpLo==Upper ? Lower : Upper) : _DstUpLo,
SrcUpLo = SrcOrder==RowMajor ? (_SrcUpLo==Upper ? Lower : Upper) : _SrcUpLo
};
Index size = mat.rows(); Index size = mat.rows();
VectorI count(size); VectorI count(size);
@ -400,18 +412,21 @@ void permute_symm_to_symm(const MatrixType& mat, SparseMatrix<typename MatrixTyp
for(Index j = 0; j<size; ++j) for(Index j = 0; j<size; ++j)
{ {
Index jp = perm ? perm[j] : j;
for(typename MatrixType::InnerIterator it(mat,j); it; ++it) for(typename MatrixType::InnerIterator it(mat,j); it; ++it)
{ {
Index i = it.index(); Index i = it.index();
if((SrcUpLo==Lower && i<j) || (SrcUpLo==Upper && i>j)) if((SrcUpLo==Lower && i<j) || (SrcUpLo==Upper && i>j))
continue; continue;
Index jp = perm ? perm[j] : j;
Index ip = perm? perm[i] : i; Index ip = perm? perm[i] : i;
Index k = count[DstUpLo==Lower ? (std::min)(ip,jp) : (std::max)(ip,jp)]++; Index k = count[DstUpLo==Lower ? (std::min)(ip,jp) : (std::max)(ip,jp)]++;
dest.innerIndexPtr()[k] = DstUpLo==Lower ? (std::max)(ip,jp) : (std::min)(ip,jp); dest.innerIndexPtr()[k] = DstUpLo==Lower ? (std::max)(ip,jp) : (std::min)(ip,jp);
if((DstUpLo==Lower && ip<jp) || (DstUpLo==Upper && ip>jp)) if(!StorageOrderMatch) std::swap(ip,jp);
if( ((DstUpLo==Lower && ip<jp) || (DstUpLo==Upper && ip>jp)))
dest.valuePtr()[k] = conj(it.value()); dest.valuePtr()[k] = conj(it.value());
else else
dest.valuePtr()[k] = it.value(); dest.valuePtr()[k] = it.value();

View File

@ -24,19 +24,22 @@
#include "sparse.h" #include "sparse.h"
template<typename SparseMatrixType> void sparse_permutations(const SparseMatrixType& ref) template<int OtherStorage, typename SparseMatrixType> void sparse_permutations(const SparseMatrixType& ref)
{ {
typedef typename SparseMatrixType::Index Index; typedef typename SparseMatrixType::Index Index;
const Index rows = ref.rows(); const Index rows = ref.rows();
const Index cols = ref.cols(); const Index cols = ref.cols();
typedef typename SparseMatrixType::Scalar Scalar; typedef typename SparseMatrixType::Scalar Scalar;
typedef typename SparseMatrixType::Index Index;
typedef SparseMatrix<Scalar, OtherStorage, Index> OtherSparseMatrixType;
typedef Matrix<Scalar,Dynamic,Dynamic> DenseMatrix; typedef Matrix<Scalar,Dynamic,Dynamic> DenseMatrix;
typedef Matrix<int,Dynamic,1> VectorI; typedef Matrix<Index,Dynamic,1> VectorI;
double density = (std::max)(8./(rows*cols), 0.01); double density = (std::max)(8./(rows*cols), 0.01);
SparseMatrixType mat(rows, cols), up(rows,cols), lo(rows,cols), res; SparseMatrixType mat(rows, cols), up(rows,cols), lo(rows,cols);
OtherSparseMatrixType res;
DenseMatrix mat_d = DenseMatrix::Zero(rows, cols), up_sym_d, lo_sym_d, res_d; DenseMatrix mat_d = DenseMatrix::Zero(rows, cols), up_sym_d, lo_sym_d, res_d;
initSparse<Scalar>(density, mat_d, mat, 0); initSparse<Scalar>(density, mat_d, mat, 0);
@ -126,12 +129,19 @@ template<typename SparseMatrixType> void sparse_permutations(const SparseMatrixT
VERIFY(res.isApprox(res_d) && "lower selfadjoint twisted to full"); VERIFY(res.isApprox(res_d) && "lower selfadjoint twisted to full");
} }
template<typename Scalar> void sparse_permutations_all(int size)
{
CALL_SUBTEST(( sparse_permutations<ColMajor>(SparseMatrix<Scalar, ColMajor>(size,size)) ));
CALL_SUBTEST(( sparse_permutations<ColMajor>(SparseMatrix<Scalar, RowMajor>(size,size)) ));
CALL_SUBTEST(( sparse_permutations<RowMajor>(SparseMatrix<Scalar, ColMajor>(size,size)) ));
CALL_SUBTEST(( sparse_permutations<RowMajor>(SparseMatrix<Scalar, RowMajor>(size,size)) ));
}
void test_sparse_permutations() void test_sparse_permutations()
{ {
for(int i = 0; i < g_repeat; i++) { for(int i = 0; i < g_repeat; i++) {
int s = Eigen::internal::random<int>(1,50); int s = Eigen::internal::random<int>(1,50);
CALL_SUBTEST_1(( sparse_permutations(SparseMatrix<double>(8, 8)) )); CALL_SUBTEST_1(( sparse_permutations_all<double>(s) ));
CALL_SUBTEST_2(( sparse_permutations(SparseMatrix<std::complex<double> >(s, s)) )); CALL_SUBTEST_2(( sparse_permutations_all<std::complex<double> >(s) ));
CALL_SUBTEST_1(( sparse_permutations(SparseMatrix<double>(s, s)) ));
} }
} }