Add support for sparse * dense and dense * sparse matrix/vector products

This commit is contained in:
Gael Guennebaud 2009-01-14 17:41:55 +00:00
parent c4c70669d1
commit 0b606dcccd
8 changed files with 140 additions and 42 deletions

View File

@ -250,10 +250,6 @@ template<typename Derived> class MatrixBase
Derived& lazyAssign(const Flagged<OtherDerived, 0, EvalBeforeNestingBit | EvalBeforeAssigningBit>& other) Derived& lazyAssign(const Flagged<OtherDerived, 0, EvalBeforeNestingBit | EvalBeforeAssigningBit>& other)
{ return lazyAssign(other._expression()); } { return lazyAssign(other._expression()); }
/** Overloaded for sparse product evaluation */
/*template<typename Derived1, typename Derived2>
Derived& lazyAssign(const Product<Derived1,Derived2,SparseProduct>& product);*/
CommaInitializer<Derived> operator<< (const Scalar& s); CommaInitializer<Derived> operator<< (const Scalar& s);
template<typename OtherDerived> template<typename OtherDerived>
@ -615,6 +611,15 @@ template<typename Derived> class MatrixBase
PlainMatrixType unitOrthogonal(void) const; PlainMatrixType unitOrthogonal(void) const;
Matrix<Scalar,3,1> eulerAngles(int a0, int a1, int a2) const; Matrix<Scalar,3,1> eulerAngles(int a0, int a1, int a2) const;
/////////// Sparse module ///////////
// dense = spasre * dense
template<typename Derived1, typename Derived2>
Derived& lazyAssign(const SparseProduct<Derived1,Derived2,SparseTimeDenseProduct>& product);
// dense = dense * spasre
template<typename Derived1, typename Derived2>
Derived& lazyAssign(const SparseProduct<Derived1,Derived2,DenseTimeSparseProduct>& product);
#ifdef EIGEN_MATRIXBASE_PLUGIN #ifdef EIGEN_MATRIXBASE_PLUGIN
#include EIGEN_MATRIXBASE_PLUGIN #include EIGEN_MATRIXBASE_PLUGIN
#endif #endif

View File

@ -201,7 +201,7 @@ enum { ForceAligned, AsRequested };
enum { ConditionalJumpCost = 5 }; enum { ConditionalJumpCost = 5 };
enum CornerType { TopLeft, TopRight, BottomLeft, BottomRight }; enum CornerType { TopLeft, TopRight, BottomLeft, BottomRight };
enum DirectionType { Vertical, Horizontal }; enum DirectionType { Vertical, Horizontal };
enum ProductEvaluationMode { NormalProduct, CacheFriendlyProduct, DiagonalProduct }; enum ProductEvaluationMode { NormalProduct, CacheFriendlyProduct, DiagonalProduct, SparseTimeSparseProduct, SparseTimeDenseProduct, DenseTimeSparseProduct };
enum { enum {
/** \internal Equivalent to a slice vectorization for fixed-size matrices having good alignment /** \internal Equivalent to a slice vectorization for fixed-size matrices having good alignment

View File

@ -122,4 +122,7 @@ template <typename _Scalar, int _AmbientDim> class Hyperplane;
template<typename Scalar,int Dim> class Translation; template<typename Scalar,int Dim> class Translation;
template<typename Scalar,int Dim> class Scaling; template<typename Scalar,int Dim> class Scaling;
// Sparse module:
template<typename Lhs, typename Rhs, int ProductMode> class SparseProduct;
#endif // EIGEN_FORWARDDECLARATIONS_H #endif // EIGEN_FORWARDDECLARATIONS_H

View File

@ -314,9 +314,10 @@ class SparseMatrix
// 1 - compute the number of coeffs per dest inner vector // 1 - compute the number of coeffs per dest inner vector
// 2 - do the actual copy/eval // 2 - do the actual copy/eval
// Since each coeff of the rhs has to be evaluated twice, let's evauluate it if needed // Since each coeff of the rhs has to be evaluated twice, let's evauluate it if needed
typedef typename ei_nested<OtherDerived,2>::type OtherCopy; //typedef typename ei_nested<OtherDerived,2>::type OtherCopy;
OtherCopy otherCopy(other.derived()); typedef typename ei_eval<OtherDerived>::type OtherCopy;
typedef typename ei_cleantype<OtherCopy>::type _OtherCopy; typedef typename ei_cleantype<OtherCopy>::type _OtherCopy;
OtherCopy otherCopy(other.derived());
resize(other.rows(), other.cols()); resize(other.rows(), other.cols());
Eigen::Map<VectorXi>(m_outerIndex,outerSize()).setZero(); Eigen::Map<VectorXi>(m_outerIndex,outerSize()).setZero();

View File

@ -213,7 +213,7 @@ template<typename Derived> class SparseMatrixBase
} }
template<typename Lhs, typename Rhs> template<typename Lhs, typename Rhs>
inline Derived& operator=(const SparseProduct<Lhs,Rhs>& product); inline Derived& operator=(const SparseProduct<Lhs,Rhs,SparseTimeSparseProduct>& product);
friend std::ostream & operator << (std::ostream & s, const SparseMatrixBase& m) friend std::ostream & operator << (std::ostream & s, const SparseMatrixBase& m)
{ {
@ -291,6 +291,16 @@ template<typename Derived> class SparseMatrixBase
template<typename OtherDerived> template<typename OtherDerived>
const typename SparseProductReturnType<Derived,OtherDerived>::Type const typename SparseProductReturnType<Derived,OtherDerived>::Type
operator*(const SparseMatrixBase<OtherDerived> &other) const; operator*(const SparseMatrixBase<OtherDerived> &other) const;
// dense * sparse (return a dense object)
template<typename OtherDerived> friend
const typename SparseProductReturnType<OtherDerived,Derived>::Type
operator*(const MatrixBase<OtherDerived>& lhs, const Derived& rhs)
{ return typename SparseProductReturnType<OtherDerived,Derived>::Type(lhs.derived(),rhs); }
template<typename OtherDerived>
const typename SparseProductReturnType<Derived,OtherDerived>::Type
operator*(const MatrixBase<OtherDerived> &other) const;
template<typename OtherDerived> template<typename OtherDerived>
Derived& operator*=(const SparseMatrixBase<OtherDerived>& other); Derived& operator*=(const SparseMatrixBase<OtherDerived>& other);

View File

@ -25,9 +25,29 @@
#ifndef EIGEN_SPARSEPRODUCT_H #ifndef EIGEN_SPARSEPRODUCT_H
#define EIGEN_SPARSEPRODUCT_H #define EIGEN_SPARSEPRODUCT_H
template<typename Lhs, typename Rhs> struct ei_sparse_product_mode
{
enum {
value = (Rhs::Flags&Lhs::Flags&SparseBit)==SparseBit
? SparseTimeSparseProduct
: (Lhs::Flags&SparseBit)==SparseBit
? SparseTimeDenseProduct
: DenseTimeSparseProduct };
};
template<typename Lhs, typename Rhs, int ProductMode>
struct SparseProductReturnType
{
typedef const typename ei_nested<Lhs,Rhs::RowsAtCompileTime>::type LhsNested;
typedef const typename ei_nested<Rhs,Lhs::RowsAtCompileTime>::type RhsNested;
typedef SparseProduct<LhsNested, RhsNested, ProductMode> Type;
};
// sparse product return type specialization // sparse product return type specialization
template<typename Lhs, typename Rhs> template<typename Lhs, typename Rhs>
struct SparseProductReturnType struct SparseProductReturnType<Lhs,Rhs,SparseTimeSparseProduct>
{ {
typedef typename ei_traits<Lhs>::Scalar Scalar; typedef typename ei_traits<Lhs>::Scalar Scalar;
enum { enum {
@ -47,11 +67,11 @@ struct SparseProductReturnType
SparseMatrix<Scalar,0>, SparseMatrix<Scalar,0>,
const typename ei_nested<Rhs,Lhs::RowsAtCompileTime>::type>::ret RhsNested; const typename ei_nested<Rhs,Lhs::RowsAtCompileTime>::type>::ret RhsNested;
typedef SparseProduct<LhsNested, RhsNested> Type; typedef SparseProduct<LhsNested, RhsNested, SparseTimeSparseProduct> Type;
}; };
template<typename LhsNested, typename RhsNested> template<typename LhsNested, typename RhsNested, int ProductMode>
struct ei_traits<SparseProduct<LhsNested, RhsNested> > struct ei_traits<SparseProduct<LhsNested, RhsNested, ProductMode> >
{ {
// clean the nested types: // clean the nested types:
typedef typename ei_cleantype<LhsNested>::type _LhsNested; typedef typename ei_cleantype<LhsNested>::type _LhsNested;
@ -71,12 +91,13 @@ struct ei_traits<SparseProduct<LhsNested, RhsNested> >
MaxRowsAtCompileTime = _LhsNested::MaxRowsAtCompileTime, MaxRowsAtCompileTime = _LhsNested::MaxRowsAtCompileTime,
MaxColsAtCompileTime = _RhsNested::MaxColsAtCompileTime, MaxColsAtCompileTime = _RhsNested::MaxColsAtCompileTime,
LhsRowMajor = LhsFlags & RowMajorBit, // LhsIsRowMajor = (LhsFlags & RowMajorBit)==RowMajorBit,
RhsRowMajor = RhsFlags & RowMajorBit, // RhsIsRowMajor = (RhsFlags & RowMajorBit)==RowMajorBit,
EvalToRowMajor = (RhsFlags & LhsFlags & RowMajorBit), EvalToRowMajor = (RhsFlags & LhsFlags & RowMajorBit),
ResultIsSparse = ProductMode==SparseTimeSparseProduct,
RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit), RemovedBits = ~( (EvalToRowMajor ? 0 : RowMajorBit) | (ResultIsSparse ? 0 : SparseBit) ),
Flags = (int(LhsFlags | RhsFlags) & HereditaryBits & RemovedBits) Flags = (int(LhsFlags | RhsFlags) & HereditaryBits & RemovedBits)
| EvalBeforeAssigningBit | EvalBeforeAssigningBit
@ -84,11 +105,14 @@ struct ei_traits<SparseProduct<LhsNested, RhsNested> >
CoeffReadCost = Dynamic CoeffReadCost = Dynamic
}; };
typedef typename ei_meta_if<ResultIsSparse,
SparseMatrixBase<SparseProduct<LhsNested, RhsNested, ProductMode> >,
MatrixBase<SparseProduct<LhsNested, RhsNested, ProductMode> > >::ret Base;
}; };
template<typename LhsNested, typename RhsNested> template<typename LhsNested, typename RhsNested, int ProductMode>
class SparseProduct : ei_no_assignment_operator, class SparseProduct : ei_no_assignment_operator, public ei_traits<SparseProduct<LhsNested, RhsNested, ProductMode> >::Base
public SparseMatrixBase<SparseProduct<LhsNested, RhsNested> >
{ {
public: public:
@ -102,17 +126,33 @@ class SparseProduct : ei_no_assignment_operator,
public: public:
template<typename Lhs, typename Rhs> template<typename Lhs, typename Rhs>
inline SparseProduct(const Lhs& lhs, const Rhs& rhs) EIGEN_STRONG_INLINE SparseProduct(const Lhs& lhs, const Rhs& rhs)
: m_lhs(lhs), m_rhs(rhs) : m_lhs(lhs), m_rhs(rhs)
{ {
ei_assert(lhs.cols() == rhs.rows()); ei_assert(lhs.cols() == rhs.rows());
enum {
ProductIsValid = _LhsNested::ColsAtCompileTime==Dynamic
|| _RhsNested::RowsAtCompileTime==Dynamic
|| int(_LhsNested::ColsAtCompileTime)==int(_RhsNested::RowsAtCompileTime),
AreVectors = _LhsNested::IsVectorAtCompileTime && _RhsNested::IsVectorAtCompileTime,
SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(_LhsNested,_RhsNested)
};
// note to the lost user:
// * for a dot product use: v1.dot(v2)
// * for a coeff-wise product use: v1.cwise()*v2
EIGEN_STATIC_ASSERT(ProductIsValid || !(AreVectors && SameSizes),
INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS)
EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors),
INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION)
EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT)
} }
inline int rows() const { return m_lhs.rows(); } EIGEN_STRONG_INLINE int rows() const { return m_lhs.rows(); }
inline int cols() const { return m_rhs.cols(); } EIGEN_STRONG_INLINE int cols() const { return m_rhs.cols(); }
const _LhsNested& lhs() const { return m_lhs; } EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; }
const _LhsNested& rhs() const { return m_rhs; } EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; }
protected: protected:
LhsNested m_lhs; LhsNested m_lhs;
@ -240,9 +280,10 @@ struct ei_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
// return derived(); // return derived();
// } // }
// sparse = sparse * sparse
template<typename Derived> template<typename Derived>
template<typename Lhs, typename Rhs> template<typename Lhs, typename Rhs>
inline Derived& SparseMatrixBase<Derived>::operator=(const SparseProduct<Lhs,Rhs>& product) inline Derived& SparseMatrixBase<Derived>::operator=(const SparseProduct<Lhs,Rhs,SparseTimeSparseProduct>& product)
{ {
// std::cout << "sparse product to sparse\n"; // std::cout << "sparse product to sparse\n";
ei_sparse_product_selector< ei_sparse_product_selector<
@ -252,26 +293,51 @@ inline Derived& SparseMatrixBase<Derived>::operator=(const SparseProduct<Lhs,Rhs
return derived(); return derived();
} }
// dense = sparse * dense
template<typename Derived>
template<typename Lhs, typename Rhs>
Derived& MatrixBase<Derived>::lazyAssign(const SparseProduct<Lhs,Rhs,SparseTimeDenseProduct>& product)
{
typedef typename ei_cleantype<Lhs>::type _Lhs;
typedef typename _Lhs::InnerIterator LhsInnerIterator;
enum { LhsIsRowMajor = (_Lhs::Flags&RowMajorBit)==RowMajorBit };
derived().setZero();
for (int j=0; j<product.lhs().outerSize(); ++j)
for (LhsInnerIterator i(product.lhs(),j); i; ++i)
derived().row(LhsIsRowMajor ? j : i.index()) += i.value() * product.rhs().row(LhsIsRowMajor ? i.index() : j);
return derived();
}
// dense = dense * sparse
template<typename Derived>
template<typename Lhs, typename Rhs>
Derived& MatrixBase<Derived>::lazyAssign(const SparseProduct<Lhs,Rhs,DenseTimeSparseProduct>& product)
{
typedef typename ei_cleantype<Rhs>::type _Rhs;
typedef typename _Rhs::InnerIterator RhsInnerIterator;
enum { RhsIsRowMajor = (_Rhs::Flags&RowMajorBit)==RowMajorBit };
derived().setZero();
for (int j=0; j<product.rhs().outerSize(); ++j)
for (RhsInnerIterator i(product.rhs(),j); i; ++i)
derived().col(RhsIsRowMajor ? i.index() : j) += i.value() * product.lhs().col(RhsIsRowMajor ? j : i.index());
return derived();
}
// sparse * sparse
template<typename Derived> template<typename Derived>
template<typename OtherDerived> template<typename OtherDerived>
inline const typename SparseProductReturnType<Derived,OtherDerived>::Type EIGEN_STRONG_INLINE const typename SparseProductReturnType<Derived,OtherDerived>::Type
SparseMatrixBase<Derived>::operator*(const SparseMatrixBase<OtherDerived> &other) const SparseMatrixBase<Derived>::operator*(const SparseMatrixBase<OtherDerived> &other) const
{ {
enum { return typename SparseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived());
ProductIsValid = Derived::ColsAtCompileTime==Dynamic }
|| OtherDerived::RowsAtCompileTime==Dynamic
|| int(Derived::ColsAtCompileTime)==int(OtherDerived::RowsAtCompileTime), // sparse * dense
AreVectors = Derived::IsVectorAtCompileTime && OtherDerived::IsVectorAtCompileTime, template<typename Derived>
SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(Derived,OtherDerived) template<typename OtherDerived>
}; EIGEN_STRONG_INLINE const typename SparseProductReturnType<Derived,OtherDerived>::Type
// note to the lost user: SparseMatrixBase<Derived>::operator*(const MatrixBase<OtherDerived> &other) const
// * for a dot product use: v1.dot(v2) {
// * for a coeff-wise product use: v1.cwise()*v2
EIGEN_STATIC_ASSERT(ProductIsValid || !(AreVectors && SameSizes),
INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS)
EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors),
INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION)
EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT)
return typename SparseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived()); return typename SparseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived());
} }

View File

@ -109,10 +109,10 @@ template<typename MatrixType> class SparseInnerVector;
template<typename Derived> class SparseCwise; template<typename Derived> class SparseCwise;
template<typename UnaryOp, typename MatrixType> class SparseCwiseUnaryOp; template<typename UnaryOp, typename MatrixType> class SparseCwiseUnaryOp;
template<typename BinaryOp, typename Lhs, typename Rhs> class SparseCwiseBinaryOp; template<typename BinaryOp, typename Lhs, typename Rhs> class SparseCwiseBinaryOp;
template<typename Lhs, typename Rhs> class SparseProduct;
template<typename ExpressionType, unsigned int Added, unsigned int Removed> class SparseFlagged; template<typename ExpressionType, unsigned int Added, unsigned int Removed> class SparseFlagged;
template<typename Lhs, typename Rhs> struct SparseProductReturnType; template<typename Lhs, typename Rhs> struct ei_sparse_product_mode;
template<typename Lhs, typename Rhs, int ProductMode = ei_sparse_product_mode<Lhs,Rhs>::value> struct SparseProductReturnType;
const int AccessPatternNotSupported = 0x0; const int AccessPatternNotSupported = 0x0;
const int AccessPatternSupported = 0x1; const int AccessPatternSupported = 0x1;

View File

@ -216,6 +216,7 @@ template<typename Scalar> void sparse_basic(int rows, int cols)
DenseMatrix refMat2 = DenseMatrix::Zero(rows, rows); DenseMatrix refMat2 = DenseMatrix::Zero(rows, rows);
DenseMatrix refMat3 = DenseMatrix::Zero(rows, rows); DenseMatrix refMat3 = DenseMatrix::Zero(rows, rows);
DenseMatrix refMat4 = DenseMatrix::Zero(rows, rows); DenseMatrix refMat4 = DenseMatrix::Zero(rows, rows);
DenseMatrix dm4 = DenseMatrix::Zero(rows, rows);
SparseMatrix<Scalar> m2(rows, rows); SparseMatrix<Scalar> m2(rows, rows);
SparseMatrix<Scalar> m3(rows, rows); SparseMatrix<Scalar> m3(rows, rows);
SparseMatrix<Scalar> m4(rows, rows); SparseMatrix<Scalar> m4(rows, rows);
@ -226,6 +227,18 @@ template<typename Scalar> void sparse_basic(int rows, int cols)
VERIFY_IS_APPROX(m4=m2.transpose()*m3, refMat4=refMat2.transpose()*refMat3); VERIFY_IS_APPROX(m4=m2.transpose()*m3, refMat4=refMat2.transpose()*refMat3);
VERIFY_IS_APPROX(m4=m2.transpose()*m3.transpose(), refMat4=refMat2.transpose()*refMat3.transpose()); VERIFY_IS_APPROX(m4=m2.transpose()*m3.transpose(), refMat4=refMat2.transpose()*refMat3.transpose());
VERIFY_IS_APPROX(m4=m2*m3.transpose(), refMat4=refMat2*refMat3.transpose()); VERIFY_IS_APPROX(m4=m2*m3.transpose(), refMat4=refMat2*refMat3.transpose());
// sparse * dense
VERIFY_IS_APPROX(dm4=m2*refMat3, refMat4=refMat2*refMat3);
VERIFY_IS_APPROX(dm4=m2*refMat3.transpose(), refMat4=refMat2*refMat3.transpose());
VERIFY_IS_APPROX(dm4=m2.transpose()*refMat3, refMat4=refMat2.transpose()*refMat3);
VERIFY_IS_APPROX(dm4=m2.transpose()*refMat3.transpose(), refMat4=refMat2.transpose()*refMat3.transpose());
// dense * sparse
VERIFY_IS_APPROX(dm4=refMat2*m3, refMat4=refMat2*refMat3);
VERIFY_IS_APPROX(dm4=refMat2*m3.transpose(), refMat4=refMat2*refMat3.transpose());
VERIFY_IS_APPROX(dm4=refMat2.transpose()*m3, refMat4=refMat2.transpose()*refMat3);
VERIFY_IS_APPROX(dm4=refMat2.transpose()*m3.transpose(), refMat4=refMat2.transpose()*refMat3.transpose());
} }
} }