Use true compile-time branching in SparseVector::assign to handle automatic transposition.

This commit is contained in:
Gael Guennebaud 2013-07-05 09:14:32 +02:00
parent edba612f68
commit 7d8823c8b7
2 changed files with 30 additions and 26 deletions

View File

@@ -86,7 +86,7 @@ template<typename MatrixType, unsigned int UpLo> class SparseSelfAdjointView
   * Note that there is no algorithmic advantage of performing such a product compared to a general sparse-sparse matrix product.
   * Indeed, the SparseSelfadjointView operand is first copied into a temporary SparseMatrix before computing the product.
   */
 template<typename OtherDerived> friend
 SparseSparseProduct<OtherDerived, SparseMatrix<Scalar, ((internal::traits<OtherDerived>::Flags&RowMajorBit) ? RowMajor : ColMajor),Index> >
 operator*(const SparseMatrixBase<OtherDerived>& lhs, const SparseSelfAdjointView& rhs)
 {

View File

@@ -241,11 +241,11 @@ class SparseVector
     template<typename OtherDerived>
     inline SparseVector& operator=(const SparseMatrixBase<OtherDerived>& other)
     {
       if ( (bool(OtherDerived::IsVectorAtCompileTime) && int(RowsAtCompileTime)!=int(OtherDerived::RowsAtCompileTime))
           || ((!bool(OtherDerived::IsVectorAtCompileTime)) && ( bool(IsColVector) ? other.cols()>1 : other.rows()>1 )))
-        return assign(other.transpose());
+        return assign(other.transpose(), typename internal::conditional<((Flags & RowMajorBit) == (OtherDerived::Flags & RowMajorBit)),internal::true_type,internal::false_type>::type());
       else
-        return assign(other);
+        return assign(other, typename internal::conditional<((Flags & RowMajorBit) != (OtherDerived::Flags & RowMajorBit)),internal::true_type,internal::false_type>::type());
     }

 #ifndef EIGEN_PARSED_BY_DOXYGEN
@@ -328,7 +328,10 @@ protected:
     }

     template<typename OtherDerived>
-    EIGEN_DONT_INLINE SparseVector& assign(const SparseMatrixBase<OtherDerived>& _other);
+    EIGEN_DONT_INLINE SparseVector& assign(const SparseMatrixBase<OtherDerived>& _other, internal::true_type);
+
+    template<typename OtherDerived>
+    EIGEN_DONT_INLINE SparseVector& assign(const SparseMatrixBase<OtherDerived>& _other, internal::false_type);

     Storage m_data;
     Index m_size;
@@ -400,31 +403,32 @@ class SparseVector<Scalar,_Options,_Index>::ReverseInnerIterator

 template<typename Scalar, int _Options, typename _Index>
 template<typename OtherDerived>
-EIGEN_DONT_INLINE SparseVector<Scalar,_Options,_Index>& SparseVector<Scalar,_Options,_Index>::assign(const SparseMatrixBase<OtherDerived>& _other)
+EIGEN_DONT_INLINE SparseVector<Scalar,_Options,_Index>& SparseVector<Scalar,_Options,_Index>::assign(const SparseMatrixBase<OtherDerived>& _other, internal::true_type)
 {
   const OtherDerived& other(_other.derived());
-  const bool needToTranspose = (Flags & RowMajorBit) != (OtherDerived::Flags & RowMajorBit);
-  if(needToTranspose)
+  Index size = other.size();
+  Index nnz = other.nonZeros();
+  resize(size);
+  reserve(nnz);
+  for(Index i=0; i<size; ++i)
   {
-    Index size = other.size();
-    Index nnz = other.nonZeros();
-    resize(size);
-    reserve(nnz);
-    for(Index i=0; i<size; ++i)
-    {
-      typename OtherDerived::InnerIterator it(other, i);
-      if(it)
-        insert(i) = it.value();
-    }
-    return *this;
-  }
-  else
-  {
-    // there is no special optimization
-    return Base::operator=(other);
+    typename OtherDerived::InnerIterator it(other, i);
+    if(it)
+      insert(i) = it.value();
   }
+  return *this;
+}
+
+template<typename Scalar, int _Options, typename _Index>
+template<typename OtherDerived>
+EIGEN_DONT_INLINE SparseVector<Scalar,_Options,_Index>& SparseVector<Scalar,_Options,_Index>::assign(const SparseMatrixBase<OtherDerived>& _other, internal::false_type)
+{
+  const OtherDerived& other(_other.derived());
+  // there is no special optimization
+  return Base::operator=(other);
+}

 } // end namespace Eigen

 #endif // EIGEN_SPARSEVECTOR_H