diff --git a/Eigen/src/Core/MatrixBase.h b/Eigen/src/Core/MatrixBase.h index eecd24c85..d342a8936 100644 --- a/Eigen/src/Core/MatrixBase.h +++ b/Eigen/src/Core/MatrixBase.h @@ -250,10 +250,6 @@ template class MatrixBase Derived& lazyAssign(const Flagged& other) { return lazyAssign(other._expression()); } - /** Overloaded for sparse product evaluation */ - /*template - Derived& lazyAssign(const Product& product);*/ - CommaInitializer operator<< (const Scalar& s); template @@ -615,6 +611,15 @@ template class MatrixBase PlainMatrixType unitOrthogonal(void) const; Matrix eulerAngles(int a0, int a1, int a2) const; +/////////// Sparse module /////////// + + // dense = spasre * dense + template + Derived& lazyAssign(const SparseProduct& product); + // dense = dense * spasre + template + Derived& lazyAssign(const SparseProduct& product); + #ifdef EIGEN_MATRIXBASE_PLUGIN #include EIGEN_MATRIXBASE_PLUGIN #endif diff --git a/Eigen/src/Core/util/Constants.h b/Eigen/src/Core/util/Constants.h index f2c76cc01..05df011cf 100644 --- a/Eigen/src/Core/util/Constants.h +++ b/Eigen/src/Core/util/Constants.h @@ -201,7 +201,7 @@ enum { ForceAligned, AsRequested }; enum { ConditionalJumpCost = 5 }; enum CornerType { TopLeft, TopRight, BottomLeft, BottomRight }; enum DirectionType { Vertical, Horizontal }; -enum ProductEvaluationMode { NormalProduct, CacheFriendlyProduct, DiagonalProduct }; +enum ProductEvaluationMode { NormalProduct, CacheFriendlyProduct, DiagonalProduct, SparseTimeSparseProduct, SparseTimeDenseProduct, DenseTimeSparseProduct }; enum { /** \internal Equivalent to a slice vectorization for fixed-size matrices having good alignment diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h index c194882d1..a45210e0c 100644 --- a/Eigen/src/Core/util/ForwardDeclarations.h +++ b/Eigen/src/Core/util/ForwardDeclarations.h @@ -122,4 +122,7 @@ template class Hyperplane; template class Translation; template class Scaling; +// Sparse module: +template class SparseProduct; + #endif // EIGEN_FORWARDDECLARATIONS_H diff --git a/Eigen/src/Sparse/SparseMatrix.h b/Eigen/src/Sparse/SparseMatrix.h index a732bdc31..07fc0be8d 100644 --- a/Eigen/src/Sparse/SparseMatrix.h +++ b/Eigen/src/Sparse/SparseMatrix.h @@ -314,9 +314,10 @@ class SparseMatrix // 1 - compute the number of coeffs per dest inner vector // 2 - do the actual copy/eval // Since each coeff of the rhs has to be evaluated twice, let's evauluate it if needed - typedef typename ei_nested::type OtherCopy; - OtherCopy otherCopy(other.derived()); + //typedef typename ei_nested::type OtherCopy; + typedef typename ei_eval::type OtherCopy; typedef typename ei_cleantype::type _OtherCopy; + OtherCopy otherCopy(other.derived()); resize(other.rows(), other.cols()); Eigen::Map(m_outerIndex,outerSize()).setZero(); diff --git a/Eigen/src/Sparse/SparseMatrixBase.h b/Eigen/src/Sparse/SparseMatrixBase.h index d01fa1ec5..14ac4e1cf 100644 --- a/Eigen/src/Sparse/SparseMatrixBase.h +++ b/Eigen/src/Sparse/SparseMatrixBase.h @@ -213,7 +213,7 @@ template class SparseMatrixBase } template - inline Derived& operator=(const SparseProduct& product); + inline Derived& operator=(const SparseProduct& product); friend std::ostream & operator << (std::ostream & s, const SparseMatrixBase& m) { @@ -291,6 +291,16 @@ template class SparseMatrixBase template const typename SparseProductReturnType::Type operator*(const SparseMatrixBase &other) const; + + // dense * sparse (return a dense object) + template friend + const typename SparseProductReturnType::Type + operator*(const MatrixBase& lhs, const Derived& rhs) + { return typename SparseProductReturnType::Type(lhs.derived(),rhs); } + + template + const typename SparseProductReturnType::Type + operator*(const MatrixBase &other) const; template Derived& operator*=(const SparseMatrixBase& other); diff --git a/Eigen/src/Sparse/SparseProduct.h b/Eigen/src/Sparse/SparseProduct.h index b4ba2ee6f..29f5208fa 100644 --- a/Eigen/src/Sparse/SparseProduct.h +++ b/Eigen/src/Sparse/SparseProduct.h @@ -25,9 +25,29 @@ #ifndef EIGEN_SPARSEPRODUCT_H #define EIGEN_SPARSEPRODUCT_H +template struct ei_sparse_product_mode +{ + enum { + + value = (Rhs::Flags&Lhs::Flags&SparseBit)==SparseBit + ? SparseTimeSparseProduct + : (Lhs::Flags&SparseBit)==SparseBit + ? SparseTimeDenseProduct + : DenseTimeSparseProduct }; +}; + +template +struct SparseProductReturnType +{ + typedef const typename ei_nested::type LhsNested; + typedef const typename ei_nested::type RhsNested; + + typedef SparseProduct Type; +}; + // sparse product return type specialization template -struct SparseProductReturnType +struct SparseProductReturnType { typedef typename ei_traits::Scalar Scalar; enum { @@ -47,11 +67,11 @@ struct SparseProductReturnType SparseMatrix, const typename ei_nested::type>::ret RhsNested; - typedef SparseProduct Type; + typedef SparseProduct Type; }; -template -struct ei_traits > +template +struct ei_traits > { // clean the nested types: typedef typename ei_cleantype::type _LhsNested; @@ -71,12 +91,13 @@ struct ei_traits > MaxRowsAtCompileTime = _LhsNested::MaxRowsAtCompileTime, MaxColsAtCompileTime = _RhsNested::MaxColsAtCompileTime, - LhsRowMajor = LhsFlags & RowMajorBit, - RhsRowMajor = RhsFlags & RowMajorBit, +// LhsIsRowMajor = (LhsFlags & RowMajorBit)==RowMajorBit, +// RhsIsRowMajor = (RhsFlags & RowMajorBit)==RowMajorBit, EvalToRowMajor = (RhsFlags & LhsFlags & RowMajorBit), + ResultIsSparse = ProductMode==SparseTimeSparseProduct, - RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit), + RemovedBits = ~( (EvalToRowMajor ? 0 : RowMajorBit) | (ResultIsSparse ? 0 : SparseBit) ), Flags = (int(LhsFlags | RhsFlags) & HereditaryBits & RemovedBits) | EvalBeforeAssigningBit @@ -84,11 +105,14 @@ struct ei_traits > CoeffReadCost = Dynamic }; + + typedef typename ei_meta_if >, + MatrixBase > >::ret Base; }; -template -class SparseProduct : ei_no_assignment_operator, - public SparseMatrixBase > +template +class SparseProduct : ei_no_assignment_operator, public ei_traits >::Base { public: @@ -102,17 +126,33 @@ class SparseProduct : ei_no_assignment_operator, public: template - inline SparseProduct(const Lhs& lhs, const Rhs& rhs) + EIGEN_STRONG_INLINE SparseProduct(const Lhs& lhs, const Rhs& rhs) : m_lhs(lhs), m_rhs(rhs) { ei_assert(lhs.cols() == rhs.rows()); + + enum { + ProductIsValid = _LhsNested::ColsAtCompileTime==Dynamic + || _RhsNested::RowsAtCompileTime==Dynamic + || int(_LhsNested::ColsAtCompileTime)==int(_RhsNested::RowsAtCompileTime), + AreVectors = _LhsNested::IsVectorAtCompileTime && _RhsNested::IsVectorAtCompileTime, + SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(_LhsNested,_RhsNested) + }; + // note to the lost user: + // * for a dot product use: v1.dot(v2) + // * for a coeff-wise product use: v1.cwise()*v2 + EIGEN_STATIC_ASSERT(ProductIsValid || !(AreVectors && SameSizes), + INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS) + EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors), + INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION) + EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT) } - inline int rows() const { return m_lhs.rows(); } - inline int cols() const { return m_rhs.cols(); } + EIGEN_STRONG_INLINE int rows() const { return m_lhs.rows(); } + EIGEN_STRONG_INLINE int cols() const { return m_rhs.cols(); } - const _LhsNested& lhs() const { return m_lhs; } - const _LhsNested& rhs() const { return m_rhs; } + EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; } + EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; } protected: LhsNested m_lhs; @@ -240,9 +280,10 @@ struct ei_sparse_product_selector // return derived(); // } +// sparse = sparse * sparse template template -inline Derived& SparseMatrixBase::operator=(const SparseProduct& product) +inline Derived& SparseMatrixBase::operator=(const SparseProduct& product) { // std::cout << "sparse product to sparse\n"; ei_sparse_product_selector< @@ -252,26 +293,51 @@ inline Derived& SparseMatrixBase::operator=(const SparseProduct +template +Derived& MatrixBase::lazyAssign(const SparseProduct& product) +{ + typedef typename ei_cleantype::type _Lhs; + typedef typename _Lhs::InnerIterator LhsInnerIterator; + enum { LhsIsRowMajor = (_Lhs::Flags&RowMajorBit)==RowMajorBit }; + derived().setZero(); + for (int j=0; j +template +Derived& MatrixBase::lazyAssign(const SparseProduct& product) +{ + typedef typename ei_cleantype::type _Rhs; + typedef typename _Rhs::InnerIterator RhsInnerIterator; + enum { RhsIsRowMajor = (_Rhs::Flags&RowMajorBit)==RowMajorBit }; + derived().setZero(); + for (int j=0; j template -inline const typename SparseProductReturnType::Type +EIGEN_STRONG_INLINE const typename SparseProductReturnType::Type SparseMatrixBase::operator*(const SparseMatrixBase &other) const { - enum { - ProductIsValid = Derived::ColsAtCompileTime==Dynamic - || OtherDerived::RowsAtCompileTime==Dynamic - || int(Derived::ColsAtCompileTime)==int(OtherDerived::RowsAtCompileTime), - AreVectors = Derived::IsVectorAtCompileTime && OtherDerived::IsVectorAtCompileTime, - SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(Derived,OtherDerived) - }; - // note to the lost user: - // * for a dot product use: v1.dot(v2) - // * for a coeff-wise product use: v1.cwise()*v2 - EIGEN_STATIC_ASSERT(ProductIsValid || !(AreVectors && SameSizes), - INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS) - EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors), - INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION) - EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT) + return typename SparseProductReturnType::Type(derived(), other.derived()); +} + +// sparse * dense +template +template +EIGEN_STRONG_INLINE const typename SparseProductReturnType::Type +SparseMatrixBase::operator*(const MatrixBase &other) const +{ return typename SparseProductReturnType::Type(derived(), other.derived()); } diff --git a/Eigen/src/Sparse/SparseUtil.h b/Eigen/src/Sparse/SparseUtil.h index 724fb9efb..046523d8f 100644 --- a/Eigen/src/Sparse/SparseUtil.h +++ b/Eigen/src/Sparse/SparseUtil.h @@ -109,10 +109,10 @@ template class SparseInnerVector; template class SparseCwise; template class SparseCwiseUnaryOp; template class SparseCwiseBinaryOp; -template class SparseProduct; template class SparseFlagged; -template struct SparseProductReturnType; +template struct ei_sparse_product_mode; +template::value> struct SparseProductReturnType; const int AccessPatternNotSupported = 0x0; const int AccessPatternSupported = 0x1; diff --git a/test/sparse_basic.cpp b/test/sparse_basic.cpp index 54272d871..07a38ddd8 100644 --- a/test/sparse_basic.cpp +++ b/test/sparse_basic.cpp @@ -216,6 +216,7 @@ template void sparse_basic(int rows, int cols) DenseMatrix refMat2 = DenseMatrix::Zero(rows, rows); DenseMatrix refMat3 = DenseMatrix::Zero(rows, rows); DenseMatrix refMat4 = DenseMatrix::Zero(rows, rows); + DenseMatrix dm4 = DenseMatrix::Zero(rows, rows); SparseMatrix m2(rows, rows); SparseMatrix m3(rows, rows); SparseMatrix m4(rows, rows); @@ -226,6 +227,18 @@ template void sparse_basic(int rows, int cols) VERIFY_IS_APPROX(m4=m2.transpose()*m3, refMat4=refMat2.transpose()*refMat3); VERIFY_IS_APPROX(m4=m2.transpose()*m3.transpose(), refMat4=refMat2.transpose()*refMat3.transpose()); VERIFY_IS_APPROX(m4=m2*m3.transpose(), refMat4=refMat2*refMat3.transpose()); + + // sparse * dense + VERIFY_IS_APPROX(dm4=m2*refMat3, refMat4=refMat2*refMat3); + VERIFY_IS_APPROX(dm4=m2*refMat3.transpose(), refMat4=refMat2*refMat3.transpose()); + VERIFY_IS_APPROX(dm4=m2.transpose()*refMat3, refMat4=refMat2.transpose()*refMat3); + VERIFY_IS_APPROX(dm4=m2.transpose()*refMat3.transpose(), refMat4=refMat2.transpose()*refMat3.transpose()); + + // dense * sparse + VERIFY_IS_APPROX(dm4=refMat2*m3, refMat4=refMat2*refMat3); + VERIFY_IS_APPROX(dm4=refMat2*m3.transpose(), refMat4=refMat2*refMat3.transpose()); + VERIFY_IS_APPROX(dm4=refMat2.transpose()*m3, refMat4=refMat2.transpose()*refMat3); + VERIFY_IS_APPROX(dm4=refMat2.transpose()*m3.transpose(), refMat4=refMat2.transpose()*refMat3.transpose()); } }