Rationalize assignment to sparse vectors

This commit is contained in:
Gael Guennebaud 2013-07-13 19:45:05 +02:00
parent 9a16519d62
commit 4bb0fff151

View File

@ -45,6 +45,21 @@ struct traits<SparseVector<_Scalar, _Options, _Index> >
SupportedAccessPatterns = InnerRandomAccessPattern SupportedAccessPatterns = InnerRandomAccessPattern
}; };
}; };
// Sparse-Vector-Assignment kinds:
enum {
SVA_RuntimeSwitch,
SVA_Inner,
SVA_Outer
};
template< typename Dest, typename Src,
int AssignmentKind = !bool(Src::IsVectorAtCompileTime) ? SVA_RuntimeSwitch
: (((Src::Flags&RowMajorBit)==RowMajorBit) && (Src::RowsAtCompileTime==1))
|| (((Src::Flags&RowMajorBit)==0) && (Src::ColsAtCompileTime==1)) ? SVA_Inner
: SVA_Outer>
struct sparse_vector_assign_selector;
} }
template<typename _Scalar, int _Options, typename _Index> template<typename _Scalar, int _Options, typename _Index>
@ -241,11 +256,11 @@ class SparseVector
template<typename OtherDerived> template<typename OtherDerived>
inline SparseVector& operator=(const SparseMatrixBase<OtherDerived>& other) inline SparseVector& operator=(const SparseMatrixBase<OtherDerived>& other)
{ {
if ( (bool(OtherDerived::IsVectorAtCompileTime) && int(RowsAtCompileTime)!=int(OtherDerived::RowsAtCompileTime)) SparseVector tmp(other.size());
|| ((!bool(OtherDerived::IsVectorAtCompileTime)) && ( bool(IsColVector) ? other.cols()>1 : other.rows()>1 ))) tmp.reserve(other.nonZeros());
return assign(other.transpose(), typename internal::conditional<((Flags & RowMajorBit) == (OtherDerived::Flags & RowMajorBit)),internal::true_type,internal::false_type>::type()); internal::sparse_vector_assign_selector<SparseVector,OtherDerived>::run(tmp,other.derived());
else this->swap(tmp);
return assign(other, typename internal::conditional<((Flags & RowMajorBit) != (OtherDerived::Flags & RowMajorBit)),internal::true_type,internal::false_type>::type()); return *this;
} }
#ifndef EIGEN_PARSED_BY_DOXYGEN #ifndef EIGEN_PARSED_BY_DOXYGEN
@ -327,12 +342,6 @@ protected:
EIGEN_STATIC_ASSERT((_Options&(ColMajor|RowMajor))==Options,INVALID_MATRIX_TEMPLATE_PARAMETERS); EIGEN_STATIC_ASSERT((_Options&(ColMajor|RowMajor))==Options,INVALID_MATRIX_TEMPLATE_PARAMETERS);
} }
template<typename OtherDerived>
EIGEN_DONT_INLINE SparseVector& assign(const SparseMatrixBase<OtherDerived>& _other, internal::true_type);
template<typename OtherDerived>
EIGEN_DONT_INLINE SparseVector& assign(const SparseMatrixBase<OtherDerived>& _other, internal::false_type);
Storage m_data; Storage m_data;
Index m_size; Index m_size;
}; };
@ -401,32 +410,38 @@ class SparseVector<Scalar,_Options,_Index>::ReverseInnerIterator
const Index m_start; const Index m_start;
}; };
template<typename Scalar, int _Options, typename _Index> namespace internal {
template<typename OtherDerived>
EIGEN_DONT_INLINE SparseVector<Scalar,_Options,_Index>& SparseVector<Scalar,_Options,_Index>::assign(const SparseMatrixBase<OtherDerived>& _other, internal::true_type)
{
const OtherDerived& other(_other.derived());
Index size = other.size(); template< typename Dest, typename Src>
Index nnz = other.nonZeros(); struct sparse_vector_assign_selector<Dest,Src,SVA_Inner> {
resize(size); static void run(Dest& dst, const Src& src) {
reserve(nnz); eigen_internal_assert(src.innerSize()==src.size());
for(Index i=0; i<size; ++i) for(typename Src::InnerIterator it(src, 0); it; ++it)
dst.insert(it.index()) = it.value();
}
};
template< typename Dest, typename Src>
struct sparse_vector_assign_selector<Dest,Src,SVA_Outer> {
static void run(Dest& dst, const Src& src) {
eigen_internal_assert(src.outerSize()==src.size());
for(typename Dest::Index i=0; i<src.size(); ++i)
{ {
typename OtherDerived::InnerIterator it(other, i); typename Src::InnerIterator it(src, i);
if(it) if(it)
insert(i) = it.value(); dst.insert(i) = it.value();
} }
return *this;
} }
};
template< typename Dest, typename Src>
struct sparse_vector_assign_selector<Dest,Src,SVA_RuntimeSwitch> {
static void run(Dest& dst, const Src& src) {
if(src.outerSize()==1) sparse_vector_assign_selector<Dest,Src,SVA_Inner>::run(dst, src);
else sparse_vector_assign_selector<Dest,Src,SVA_Outer>::run(dst, src);
}
};
template<typename Scalar, int _Options, typename _Index>
template<typename OtherDerived>
EIGEN_DONT_INLINE SparseVector<Scalar,_Options,_Index>& SparseVector<Scalar,_Options,_Index>::assign(const SparseMatrixBase<OtherDerived>& _other, internal::false_type)
{
const OtherDerived& other(_other.derived());
// there is no special optimization
return Base::operator=(other);
} }
} // end namespace Eigen } // end namespace Eigen