eigen/Eigen/src/Sparse/SparseMatrixBase.h
Gael Guennebaud 709e903335 Sparse module:
* extend unit tests
* add support for generic sum reduction and dot product
* optimize cwise()*: this is a special case of CwiseBinaryOp where
  we only have to process the coeffs which are non-zero in *both* matrices.
  Perhaps other binary operations behave like that?
2009-01-07 17:01:57 +00:00

219 lines
7.4 KiB
C++

// This file is part of Eigen, a lightweight C++ template library
// for linear algebra. Eigen itself is part of the KDE project.
//
// Copyright (C) 2008 Gael Guennebaud <g.gael@free.fr>
//
// Eigen is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License as published by the Free Software Foundation; either
// version 3 of the License, or (at your option) any later version.
//
// Alternatively, you can redistribute it and/or
// modify it under the terms of the GNU General Public License as
// published by the Free Software Foundation; either version 2 of
// the License, or (at your option) any later version.
//
// Eigen is distributed in the hope that it will be useful, but WITHOUT ANY
// WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
// FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License and a copy of the GNU General Public License along with
// Eigen. If not, see <http://www.gnu.org/licenses/>.
#ifndef EIGEN_SPARSEMATRIXBASE_H
#define EIGEN_SPARSEMATRIXBASE_H
/** \class SparseMatrixBase
  *
  * \brief Common base class of the sparse matrix types and sparse expressions.
  *
  * It provides:
  *  - the r-value flag (markAsRValue()/isRValue()) that lets an assignment steal
  *    the storage of a temporary by swapping instead of copying;
  *  - generic assignment from any expression exposing an InnerIterator, which
  *    walks the source outer vector by outer vector and drops explicit zeros;
  *  - a textual stream output operator printing the matrix in dense form.
  *
  * \param Derived the concrete sparse type (curiously recurring template parameter)
  */
template<typename Derived>
class SparseMatrixBase : public MatrixBase<Derived>
{
  public:

    typedef MatrixBase<Derived> Base;
    typedef typename Base::Scalar Scalar;
    typedef typename Base::RealScalar RealScalar;
    enum {
      Flags = Base::Flags,
      /** 1 if Derived has the RowMajorBit flag set, 0 otherwise */
      RowMajor = ei_traits<Derived>::Flags&RowMajorBit ? 1 : 0
    };

    inline const Derived& derived() const { return *static_cast<const Derived*>(this); }
    inline Derived& derived() { return *static_cast<Derived*>(this); }
    /** \returns a non-const reference to the derived object even from a const context.
      * Used by operator=(const Derived&) to swap with an r-value source. */
    inline Derived& const_cast_derived() const
    { return *static_cast<Derived*>(const_cast<SparseMatrixBase*>(this)); }

    SparseMatrixBase()
      : m_isRValue(false)
    {}

    /** \returns true if this object has been flagged with markAsRValue() */
    bool isRValue() const { return m_isRValue; }
    /** Flags this object as a temporary whose storage may be taken over (swapped)
      * by the next assignment from it, and returns the derived object. */
    Derived& markAsRValue() { m_isRValue = true; return derived(); }

    /** Copy assignment.
      *
      * If \a other is flagged as an r-value, its storage is swapped into *this.
      * Otherwise the coefficients are copied through the generic sparse assignment.
      */
    inline Derived& operator=(const Derived& other)
    {
      if (other.isRValue())
        derived().swap(other.const_cast_derived());
      else
        this->operator=<Derived>(other);
      return derived();
    }

    /** Generic assignment from any expression providing an InnerIterator.
      *
      * The source is evaluated into a temporary of type Derived, which is then
      * assigned to *this as an r-value. Coefficients equal to zero are not stored.
      *
      * \warning the storage orders of *this and \a other must match (asserted);
      * the transposing case is expected to be handled by SparseMatrix::operator=.
      */
    template<typename OtherDerived>
    inline Derived& operator=(const MatrixBase<OtherDerived>& other)
    {
      ei_assert(((int(Flags) & RowMajorBit) == (int(OtherDerived::Flags) & RowMajorBit))
             && "the transpose operation is supposed to be handled in SparseMatrix::operator=");
      const int outerSize = other.outerSize();
      // thanks to shallow copies, we always eval to a temporary
      Derived temp(other.rows(), other.cols());
      // The reserve hint must follow the size of the temporary (i.e. of the source),
      // not the size of *this which may still be empty or of a different size.
      temp.startFill(std::max(other.rows(),other.cols())*2);
      for (int j=0; j<outerSize; ++j)
      {
        for (typename OtherDerived::InnerIterator it(other.derived(), j); it; ++it)
        {
          Scalar v = it.value();
          if (v!=Scalar(0))
          {
            // j is the outer index: a row for row-major sources, a column otherwise
            if (OtherDerived::Flags & RowMajorBit) temp.fill(j,it.index()) = v;
            else temp.fill(it.index(),j) = v;
          }
        }
      }
      temp.endFill();
      derived() = temp.markAsRValue();
      return derived();
    }

    /** Assignment from another sparse expression.
      *
      * When both sides have the same storage order and \a other is flagged as an
      * r-value, the coefficients are filled directly into *this (no temporary).
      * Every other case is forwarded to the generic operator=(const MatrixBase&).
      */
    template<typename OtherDerived>
    inline Derived& operator=(const SparseMatrixBase<OtherDerived>& other)
    {
      const bool transpose = (int(Flags) & RowMajorBit) != (int(OtherDerived::Flags) & RowMajorBit);
      const int outerSize = other.outerSize();
      if ((!transpose) && other.isRValue())
      {
        // eval without temporary
        derived().resize(other.rows(), other.cols());
        derived().startFill(std::max(other.rows(),other.cols())*2);
        for (int j=0; j<outerSize; ++j)
        {
          for (typename OtherDerived::InnerIterator it(other.derived(), j); it; ++it)
          {
            Scalar v = it.value();
            if (v!=Scalar(0))
            {
              if (RowMajor) derived().fill(j,it.index()) = v;
              else derived().fill(it.index(),j) = v;
            }
          }
        }
        derived().endFill();
      }
      else
        this->operator=<OtherDerived>(static_cast<const MatrixBase<OtherDerived>&>(other));
      return derived();
    }

    /** Assignment from a sparse product expression (defined outside this class). */
    template<typename Lhs, typename Rhs>
    inline Derived& operator=(const Product<Lhs,Rhs,SparseProduct>& product);

    /** Writes \a m to \a s in dense form: every coefficient, zeros included, with
      * one matrix row per line. Column-major column vectors are printed one
      * coefficient per line. Other column-major matrices are first converted to a
      * row-major SparseMatrix and printed in that form. */
    friend std::ostream & operator << (std::ostream & s, const SparseMatrixBase& m)
    {
      if (Flags&RowMajorBit)
      {
        for (int row=0; row<m.outerSize(); ++row)
        {
          int col = 0;
          for (typename Derived::InnerIterator it(m.derived(), row); it; ++it)
          {
            // pad the implicit zeros preceding this stored coefficient
            for ( ; col<it.index(); ++col)
              s << "0 ";
            s << it.value() << " ";
            ++col;
          }
          for ( ; col<m.cols(); ++col)
            s << "0 ";
          s << '\n';
        }
      }
      else
      {
        if (m.cols() == 1) {
          int row = 0;
          for (typename Derived::InnerIterator it(m.derived(), 0); it; ++it)
          {
            for ( ; row<it.index(); ++row)
              s << "0" << '\n';
            s << it.value() << '\n';
            ++row;
          }
          for ( ; row<m.rows(); ++row)
            s << "0" << '\n';
        }
        else
        {
          SparseMatrix<Scalar, RowMajorBit> trans = m.derived();
          s << trans;
        }
      }
      return s;
    }

//     template<typename OtherDerived>
//     Scalar dot(const MatrixBase<OtherDerived>& other) const
//     {
//       EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
//       EIGEN_STATIC_ASSERT_VECTOR_ONLY(OtherDerived)
//       EIGEN_STATIC_ASSERT((ei_is_same_type<Scalar, typename OtherDerived::Scalar>::ret),
//         YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY)
//
//       ei_assert(derived().size() == other.size());
//       // short version, but the assembly looks more complicated because
//       // of the CwiseBinaryOp iterator complexity
//       // return res = (derived().cwise() * other.derived().conjugate()).sum();
//
//       // optimized, generic version
//       typename Derived::InnerIterator i(derived(),0);
//       typename OtherDerived::InnerIterator j(other.derived(),0);
//       Scalar res = 0;
//       while (i && j)
//       {
//         if (i.index()==j.index())
//         {
// //           std::cerr << i.value() << " * " << j.value() << "\n";
//           res += i.value() * ei_conj(j.value());
//           ++i; ++j;
//         }
//         else if (i.index()<j.index())
//           ++i;
//         else
//           ++j;
//       }
//       return res;
//     }
//
//     Scalar sum() const
//     {
//       Scalar res = 0;
//       for (typename Derived::InnerIterator iter(*this,0); iter; ++iter)
//       {
//         res += iter.value();
//       }
//       return res;
//     }

  protected:

    /** set by markAsRValue(); read by isRValue() */
    bool m_isRValue;
};
#endif // EIGEN_SPARSEMATRIXBASE_H