mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-21 17:19:36 +08:00
evaluate 1D sparse expressions into SparseVector and make the sparse operator<< and dot honor nested types
This commit is contained in:
parent
7f04845023
commit
2c03e6fccc
@ -62,8 +62,16 @@ SparseMatrixBase<Derived>::dot(const SparseMatrixBase<OtherDerived>& other) cons
|
|||||||
|
|
||||||
eigen_assert(size() == other.size());
|
eigen_assert(size() == other.size());
|
||||||
|
|
||||||
typename Derived::InnerIterator i(derived(),0);
|
typedef typename Derived::Nested Nested;
|
||||||
typename OtherDerived::InnerIterator j(other.derived(),0);
|
typedef typename OtherDerived::Nested OtherNested;
|
||||||
|
typedef typename internal::remove_all<Nested>::type NestedCleaned;
|
||||||
|
typedef typename internal::remove_all<OtherNested>::type OtherNestedCleaned;
|
||||||
|
|
||||||
|
const Nested nthis(derived());
|
||||||
|
const OtherNested nother(other.derived());
|
||||||
|
|
||||||
|
typename NestedCleaned::InnerIterator i(nthis,0);
|
||||||
|
typename OtherNestedCleaned::InnerIterator j(nother,0);
|
||||||
Scalar res(0);
|
Scalar res(0);
|
||||||
while (i && j)
|
while (i && j)
|
||||||
{
|
{
|
||||||
|
@ -274,12 +274,16 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
|
|||||||
|
|
||||||
friend std::ostream & operator << (std::ostream & s, const SparseMatrixBase& m)
|
friend std::ostream & operator << (std::ostream & s, const SparseMatrixBase& m)
|
||||||
{
|
{
|
||||||
|
typedef typename Derived::Nested Nested;
|
||||||
|
typedef typename internal::remove_all<Nested>::type NestedCleaned;
|
||||||
|
|
||||||
if (Flags&RowMajorBit)
|
if (Flags&RowMajorBit)
|
||||||
{
|
{
|
||||||
for (Index row=0; row<m.outerSize(); ++row)
|
const Nested nm(m.derived());
|
||||||
|
for (Index row=0; row<nm.outerSize(); ++row)
|
||||||
{
|
{
|
||||||
Index col = 0;
|
Index col = 0;
|
||||||
for (typename Derived::InnerIterator it(m.derived(), row); it; ++it)
|
for (typename NestedCleaned::InnerIterator it(nm.derived(), row); it; ++it)
|
||||||
{
|
{
|
||||||
for ( ; col<it.index(); ++col)
|
for ( ; col<it.index(); ++col)
|
||||||
s << "0 ";
|
s << "0 ";
|
||||||
@ -293,9 +297,10 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
|
|||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
|
const Nested nm(m.derived());
|
||||||
if (m.cols() == 1) {
|
if (m.cols() == 1) {
|
||||||
Index row = 0;
|
Index row = 0;
|
||||||
for (typename Derived::InnerIterator it(m.derived(), 0); it; ++it)
|
for (typename NestedCleaned::InnerIterator it(nm.derived(), 0); it; ++it)
|
||||||
{
|
{
|
||||||
for ( ; row<it.index(); ++row)
|
for ( ; row<it.index(); ++row)
|
||||||
s << "0" << std::endl;
|
s << "0" << std::endl;
|
||||||
@ -307,7 +312,7 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
|
|||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
SparseMatrix<Scalar, RowMajorBit> trans = m.derived();
|
SparseMatrix<Scalar, RowMajorBit> trans = m;
|
||||||
s << static_cast<const SparseMatrixBase<SparseMatrix<Scalar, RowMajorBit> >&>(trans);
|
s << static_cast<const SparseMatrixBase<SparseMatrix<Scalar, RowMajorBit> >&>(trans);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -103,17 +103,39 @@ template<typename Lhs, typename Rhs, int InnerSize = internal::traits<Lhs>::Cols
|
|||||||
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
|
||||||
template<typename T> struct eval<T,Sparse>
|
template<typename T,int Rows,int Cols> struct sparse_eval;
|
||||||
{
|
|
||||||
typedef typename traits<T>::Scalar _Scalar;
|
|
||||||
enum {
|
|
||||||
_Flags = traits<T>::Flags
|
|
||||||
};
|
|
||||||
|
|
||||||
|
template<typename T> struct eval<T,Sparse>
|
||||||
|
: public sparse_eval<T, traits<T>::RowsAtCompileTime,traits<T>::ColsAtCompileTime>
|
||||||
|
{};
|
||||||
|
|
||||||
|
template<typename T,int Cols> struct sparse_eval<T,1,Cols> {
|
||||||
|
typedef typename traits<T>::Scalar _Scalar;
|
||||||
|
enum { _Flags = traits<T>::Flags| RowMajorBit };
|
||||||
|
public:
|
||||||
|
typedef SparseVector<_Scalar, _Flags> type;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename T,int Rows> struct sparse_eval<T,Rows,1> {
|
||||||
|
typedef typename traits<T>::Scalar _Scalar;
|
||||||
|
enum { _Flags = traits<T>::Flags & (~RowMajorBit) };
|
||||||
|
public:
|
||||||
|
typedef SparseVector<_Scalar, _Flags> type;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename T,int Rows,int Cols> struct sparse_eval {
|
||||||
|
typedef typename traits<T>::Scalar _Scalar;
|
||||||
|
enum { _Flags = traits<T>::Flags };
|
||||||
public:
|
public:
|
||||||
typedef SparseMatrix<_Scalar, _Flags> type;
|
typedef SparseMatrix<_Scalar, _Flags> type;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template<typename T> struct sparse_eval<T,1,1> {
|
||||||
|
typedef typename traits<T>::Scalar _Scalar;
|
||||||
|
public:
|
||||||
|
typedef Matrix<_Scalar, 1, 1> type;
|
||||||
|
};
|
||||||
|
|
||||||
template<typename T> struct plain_matrix_type<T,Sparse>
|
template<typename T> struct plain_matrix_type<T,Sparse>
|
||||||
{
|
{
|
||||||
typedef typename traits<T>::Scalar _Scalar;
|
typedef typename traits<T>::Scalar _Scalar;
|
||||||
|
@ -47,7 +47,7 @@ struct traits<SparseVector<_Scalar, _Options, _Index> >
|
|||||||
typedef Sparse StorageKind;
|
typedef Sparse StorageKind;
|
||||||
typedef MatrixXpr XprKind;
|
typedef MatrixXpr XprKind;
|
||||||
enum {
|
enum {
|
||||||
IsColVector = _Options & RowMajorBit ? 0 : 1,
|
IsColVector = (_Options & RowMajorBit) ? 0 : 1,
|
||||||
|
|
||||||
RowsAtCompileTime = IsColVector ? Dynamic : 1,
|
RowsAtCompileTime = IsColVector ? Dynamic : 1,
|
||||||
ColsAtCompileTime = IsColVector ? 1 : Dynamic,
|
ColsAtCompileTime = IsColVector ? 1 : Dynamic,
|
||||||
@ -320,7 +320,7 @@ protected:
|
|||||||
const bool needToTranspose = (Flags & RowMajorBit) != (OtherDerived::Flags & RowMajorBit);
|
const bool needToTranspose = (Flags & RowMajorBit) != (OtherDerived::Flags & RowMajorBit);
|
||||||
if(needToTranspose)
|
if(needToTranspose)
|
||||||
{
|
{
|
||||||
Index size = other.innerSize();
|
Index size = other.size();
|
||||||
Index nnz = other.nonZeros();
|
Index nnz = other.nonZeros();
|
||||||
resize(size);
|
resize(size);
|
||||||
reserve(nnz);
|
reserve(nnz);
|
||||||
|
@ -34,9 +34,9 @@ template<typename Scalar> void sparse_vector(int rows, int cols)
|
|||||||
typedef SparseMatrix<Scalar> SparseMatrixType;
|
typedef SparseMatrix<Scalar> SparseMatrixType;
|
||||||
Scalar eps = 1e-6;
|
Scalar eps = 1e-6;
|
||||||
|
|
||||||
SparseMatrixType m1(rows,cols);
|
SparseMatrixType m1(rows,rows);
|
||||||
SparseVectorType v1(rows), v2(rows), v3(rows);
|
SparseVectorType v1(rows), v2(rows), v3(rows);
|
||||||
DenseMatrix refM1 = DenseMatrix::Zero(rows, cols);
|
DenseMatrix refM1 = DenseMatrix::Zero(rows, rows);
|
||||||
DenseVector refV1 = DenseVector::Random(rows),
|
DenseVector refV1 = DenseVector::Random(rows),
|
||||||
refV2 = DenseVector::Random(rows),
|
refV2 = DenseVector::Random(rows),
|
||||||
refV3 = DenseVector::Random(rows);
|
refV3 = DenseVector::Random(rows);
|
||||||
@ -86,6 +86,11 @@ template<typename Scalar> void sparse_vector(int rows, int cols)
|
|||||||
VERIFY_IS_APPROX(v1.dot(v2), refV1.dot(refV2));
|
VERIFY_IS_APPROX(v1.dot(v2), refV1.dot(refV2));
|
||||||
VERIFY_IS_APPROX(v1.dot(refV2), refV1.dot(refV2));
|
VERIFY_IS_APPROX(v1.dot(refV2), refV1.dot(refV2));
|
||||||
|
|
||||||
|
VERIFY_IS_APPROX(v1.dot(m1*v2), refV1.dot(refM1*refV2));
|
||||||
|
int i = internal::random<int>(0,rows-1);
|
||||||
|
VERIFY_IS_APPROX(v1.dot(m1.col(i)), refV1.dot(refM1.col(i)));
|
||||||
|
|
||||||
|
|
||||||
VERIFY_IS_APPROX(v1.squaredNorm(), refV1.squaredNorm());
|
VERIFY_IS_APPROX(v1.squaredNorm(), refV1.squaredNorm());
|
||||||
|
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user