evaluate 1D sparse expressions into SparseVector and make the sparse operator<< and dot honor nested types

This commit is contained in:
Gael Guennebaud 2011-12-22 14:01:06 +01:00
parent 7f04845023
commit 2c03e6fccc
5 changed files with 56 additions and 16 deletions

View File

@ -62,8 +62,16 @@ SparseMatrixBase<Derived>::dot(const SparseMatrixBase<OtherDerived>& other) cons
eigen_assert(size() == other.size()); eigen_assert(size() == other.size());
typename Derived::InnerIterator i(derived(),0); typedef typename Derived::Nested Nested;
typename OtherDerived::InnerIterator j(other.derived(),0); typedef typename OtherDerived::Nested OtherNested;
typedef typename internal::remove_all<Nested>::type NestedCleaned;
typedef typename internal::remove_all<OtherNested>::type OtherNestedCleaned;
const Nested nthis(derived());
const OtherNested nother(other.derived());
typename NestedCleaned::InnerIterator i(nthis,0);
typename OtherNestedCleaned::InnerIterator j(nother,0);
Scalar res(0); Scalar res(0);
while (i && j) while (i && j)
{ {

View File

@ -274,12 +274,16 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
friend std::ostream & operator << (std::ostream & s, const SparseMatrixBase& m) friend std::ostream & operator << (std::ostream & s, const SparseMatrixBase& m)
{ {
typedef typename Derived::Nested Nested;
typedef typename internal::remove_all<Nested>::type NestedCleaned;
if (Flags&RowMajorBit) if (Flags&RowMajorBit)
{ {
for (Index row=0; row<m.outerSize(); ++row) const Nested nm(m.derived());
for (Index row=0; row<nm.outerSize(); ++row)
{ {
Index col = 0; Index col = 0;
for (typename Derived::InnerIterator it(m.derived(), row); it; ++it) for (typename NestedCleaned::InnerIterator it(nm.derived(), row); it; ++it)
{ {
for ( ; col<it.index(); ++col) for ( ; col<it.index(); ++col)
s << "0 "; s << "0 ";
@ -293,9 +297,10 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
} }
else else
{ {
const Nested nm(m.derived());
if (m.cols() == 1) { if (m.cols() == 1) {
Index row = 0; Index row = 0;
for (typename Derived::InnerIterator it(m.derived(), 0); it; ++it) for (typename NestedCleaned::InnerIterator it(nm.derived(), 0); it; ++it)
{ {
for ( ; row<it.index(); ++row) for ( ; row<it.index(); ++row)
s << "0" << std::endl; s << "0" << std::endl;
@ -307,7 +312,7 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
} }
else else
{ {
SparseMatrix<Scalar, RowMajorBit> trans = m.derived(); SparseMatrix<Scalar, RowMajorBit> trans = m;
s << static_cast<const SparseMatrixBase<SparseMatrix<Scalar, RowMajorBit> >&>(trans); s << static_cast<const SparseMatrixBase<SparseMatrix<Scalar, RowMajorBit> >&>(trans);
} }
} }

View File

@ -103,17 +103,39 @@ template<typename Lhs, typename Rhs, int InnerSize = internal::traits<Lhs>::Cols
namespace internal { namespace internal {
template<typename T> struct eval<T,Sparse> template<typename T,int Rows,int Cols> struct sparse_eval;
{
typedef typename traits<T>::Scalar _Scalar;
enum {
_Flags = traits<T>::Flags
};
template<typename T> struct eval<T,Sparse>
: public sparse_eval<T, traits<T>::RowsAtCompileTime,traits<T>::ColsAtCompileTime>
{};
template<typename T,int Cols> struct sparse_eval<T,1,Cols> {
typedef typename traits<T>::Scalar _Scalar;
enum { _Flags = traits<T>::Flags| RowMajorBit };
public:
typedef SparseVector<_Scalar, _Flags> type;
};
template<typename T,int Rows> struct sparse_eval<T,Rows,1> {
typedef typename traits<T>::Scalar _Scalar;
enum { _Flags = traits<T>::Flags & (~RowMajorBit) };
public:
typedef SparseVector<_Scalar, _Flags> type;
};
template<typename T,int Rows,int Cols> struct sparse_eval {
typedef typename traits<T>::Scalar _Scalar;
enum { _Flags = traits<T>::Flags };
public: public:
typedef SparseMatrix<_Scalar, _Flags> type; typedef SparseMatrix<_Scalar, _Flags> type;
}; };
template<typename T> struct sparse_eval<T,1,1> {
typedef typename traits<T>::Scalar _Scalar;
public:
typedef Matrix<_Scalar, 1, 1> type;
};
template<typename T> struct plain_matrix_type<T,Sparse> template<typename T> struct plain_matrix_type<T,Sparse>
{ {
typedef typename traits<T>::Scalar _Scalar; typedef typename traits<T>::Scalar _Scalar;

View File

@ -47,7 +47,7 @@ struct traits<SparseVector<_Scalar, _Options, _Index> >
typedef Sparse StorageKind; typedef Sparse StorageKind;
typedef MatrixXpr XprKind; typedef MatrixXpr XprKind;
enum { enum {
IsColVector = _Options & RowMajorBit ? 0 : 1, IsColVector = (_Options & RowMajorBit) ? 0 : 1,
RowsAtCompileTime = IsColVector ? Dynamic : 1, RowsAtCompileTime = IsColVector ? Dynamic : 1,
ColsAtCompileTime = IsColVector ? 1 : Dynamic, ColsAtCompileTime = IsColVector ? 1 : Dynamic,
@ -320,7 +320,7 @@ protected:
const bool needToTranspose = (Flags & RowMajorBit) != (OtherDerived::Flags & RowMajorBit); const bool needToTranspose = (Flags & RowMajorBit) != (OtherDerived::Flags & RowMajorBit);
if(needToTranspose) if(needToTranspose)
{ {
Index size = other.innerSize(); Index size = other.size();
Index nnz = other.nonZeros(); Index nnz = other.nonZeros();
resize(size); resize(size);
reserve(nnz); reserve(nnz);

View File

@ -34,9 +34,9 @@ template<typename Scalar> void sparse_vector(int rows, int cols)
typedef SparseMatrix<Scalar> SparseMatrixType; typedef SparseMatrix<Scalar> SparseMatrixType;
Scalar eps = 1e-6; Scalar eps = 1e-6;
SparseMatrixType m1(rows,cols); SparseMatrixType m1(rows,rows);
SparseVectorType v1(rows), v2(rows), v3(rows); SparseVectorType v1(rows), v2(rows), v3(rows);
DenseMatrix refM1 = DenseMatrix::Zero(rows, cols); DenseMatrix refM1 = DenseMatrix::Zero(rows, rows);
DenseVector refV1 = DenseVector::Random(rows), DenseVector refV1 = DenseVector::Random(rows),
refV2 = DenseVector::Random(rows), refV2 = DenseVector::Random(rows),
refV3 = DenseVector::Random(rows); refV3 = DenseVector::Random(rows);
@ -86,6 +86,11 @@ template<typename Scalar> void sparse_vector(int rows, int cols)
VERIFY_IS_APPROX(v1.dot(v2), refV1.dot(refV2)); VERIFY_IS_APPROX(v1.dot(v2), refV1.dot(refV2));
VERIFY_IS_APPROX(v1.dot(refV2), refV1.dot(refV2)); VERIFY_IS_APPROX(v1.dot(refV2), refV1.dot(refV2));
VERIFY_IS_APPROX(v1.dot(m1*v2), refV1.dot(refM1*refV2));
int i = internal::random<int>(0,rows-1);
VERIFY_IS_APPROX(v1.dot(m1.col(i)), refV1.dot(refM1.col(i)));
VERIFY_IS_APPROX(v1.squaredNorm(), refV1.squaredNorm()); VERIFY_IS_APPROX(v1.squaredNorm(), refV1.squaredNorm());
} }