mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-09-25 07:43:14 +08:00
Add support for sparse * dense and dense * sparse matrix/vector products
This commit is contained in:
parent
c4c70669d1
commit
0b606dcccd
@ -250,10 +250,6 @@ template<typename Derived> class MatrixBase
|
||||
Derived& lazyAssign(const Flagged<OtherDerived, 0, EvalBeforeNestingBit | EvalBeforeAssigningBit>& other)
|
||||
{ return lazyAssign(other._expression()); }
|
||||
|
||||
/** Overloaded for sparse product evaluation */
|
||||
/*template<typename Derived1, typename Derived2>
|
||||
Derived& lazyAssign(const Product<Derived1,Derived2,SparseProduct>& product);*/
|
||||
|
||||
CommaInitializer<Derived> operator<< (const Scalar& s);
|
||||
|
||||
template<typename OtherDerived>
|
||||
@ -615,6 +611,15 @@ template<typename Derived> class MatrixBase
|
||||
PlainMatrixType unitOrthogonal(void) const;
|
||||
Matrix<Scalar,3,1> eulerAngles(int a0, int a1, int a2) const;
|
||||
|
||||
/////////// Sparse module ///////////
|
||||
|
||||
// dense = spasre * dense
|
||||
template<typename Derived1, typename Derived2>
|
||||
Derived& lazyAssign(const SparseProduct<Derived1,Derived2,SparseTimeDenseProduct>& product);
|
||||
// dense = dense * spasre
|
||||
template<typename Derived1, typename Derived2>
|
||||
Derived& lazyAssign(const SparseProduct<Derived1,Derived2,DenseTimeSparseProduct>& product);
|
||||
|
||||
#ifdef EIGEN_MATRIXBASE_PLUGIN
|
||||
#include EIGEN_MATRIXBASE_PLUGIN
|
||||
#endif
|
||||
|
@ -201,7 +201,7 @@ enum { ForceAligned, AsRequested };
|
||||
enum { ConditionalJumpCost = 5 };
|
||||
enum CornerType { TopLeft, TopRight, BottomLeft, BottomRight };
|
||||
enum DirectionType { Vertical, Horizontal };
|
||||
enum ProductEvaluationMode { NormalProduct, CacheFriendlyProduct, DiagonalProduct };
|
||||
enum ProductEvaluationMode { NormalProduct, CacheFriendlyProduct, DiagonalProduct, SparseTimeSparseProduct, SparseTimeDenseProduct, DenseTimeSparseProduct };
|
||||
|
||||
enum {
|
||||
/** \internal Equivalent to a slice vectorization for fixed-size matrices having good alignment
|
||||
|
@ -122,4 +122,7 @@ template <typename _Scalar, int _AmbientDim> class Hyperplane;
|
||||
template<typename Scalar,int Dim> class Translation;
|
||||
template<typename Scalar,int Dim> class Scaling;
|
||||
|
||||
// Sparse module:
|
||||
template<typename Lhs, typename Rhs, int ProductMode> class SparseProduct;
|
||||
|
||||
#endif // EIGEN_FORWARDDECLARATIONS_H
|
||||
|
@ -314,9 +314,10 @@ class SparseMatrix
|
||||
// 1 - compute the number of coeffs per dest inner vector
|
||||
// 2 - do the actual copy/eval
|
||||
// Since each coeff of the rhs has to be evaluated twice, let's evauluate it if needed
|
||||
typedef typename ei_nested<OtherDerived,2>::type OtherCopy;
|
||||
OtherCopy otherCopy(other.derived());
|
||||
//typedef typename ei_nested<OtherDerived,2>::type OtherCopy;
|
||||
typedef typename ei_eval<OtherDerived>::type OtherCopy;
|
||||
typedef typename ei_cleantype<OtherCopy>::type _OtherCopy;
|
||||
OtherCopy otherCopy(other.derived());
|
||||
|
||||
resize(other.rows(), other.cols());
|
||||
Eigen::Map<VectorXi>(m_outerIndex,outerSize()).setZero();
|
||||
|
@ -213,7 +213,7 @@ template<typename Derived> class SparseMatrixBase
|
||||
}
|
||||
|
||||
template<typename Lhs, typename Rhs>
|
||||
inline Derived& operator=(const SparseProduct<Lhs,Rhs>& product);
|
||||
inline Derived& operator=(const SparseProduct<Lhs,Rhs,SparseTimeSparseProduct>& product);
|
||||
|
||||
friend std::ostream & operator << (std::ostream & s, const SparseMatrixBase& m)
|
||||
{
|
||||
@ -291,6 +291,16 @@ template<typename Derived> class SparseMatrixBase
|
||||
template<typename OtherDerived>
|
||||
const typename SparseProductReturnType<Derived,OtherDerived>::Type
|
||||
operator*(const SparseMatrixBase<OtherDerived> &other) const;
|
||||
|
||||
// dense * sparse (return a dense object)
|
||||
template<typename OtherDerived> friend
|
||||
const typename SparseProductReturnType<OtherDerived,Derived>::Type
|
||||
operator*(const MatrixBase<OtherDerived>& lhs, const Derived& rhs)
|
||||
{ return typename SparseProductReturnType<OtherDerived,Derived>::Type(lhs.derived(),rhs); }
|
||||
|
||||
template<typename OtherDerived>
|
||||
const typename SparseProductReturnType<Derived,OtherDerived>::Type
|
||||
operator*(const MatrixBase<OtherDerived> &other) const;
|
||||
|
||||
template<typename OtherDerived>
|
||||
Derived& operator*=(const SparseMatrixBase<OtherDerived>& other);
|
||||
|
@ -25,9 +25,29 @@
|
||||
#ifndef EIGEN_SPARSEPRODUCT_H
|
||||
#define EIGEN_SPARSEPRODUCT_H
|
||||
|
||||
template<typename Lhs, typename Rhs> struct ei_sparse_product_mode
|
||||
{
|
||||
enum {
|
||||
|
||||
value = (Rhs::Flags&Lhs::Flags&SparseBit)==SparseBit
|
||||
? SparseTimeSparseProduct
|
||||
: (Lhs::Flags&SparseBit)==SparseBit
|
||||
? SparseTimeDenseProduct
|
||||
: DenseTimeSparseProduct };
|
||||
};
|
||||
|
||||
template<typename Lhs, typename Rhs, int ProductMode>
|
||||
struct SparseProductReturnType
|
||||
{
|
||||
typedef const typename ei_nested<Lhs,Rhs::RowsAtCompileTime>::type LhsNested;
|
||||
typedef const typename ei_nested<Rhs,Lhs::RowsAtCompileTime>::type RhsNested;
|
||||
|
||||
typedef SparseProduct<LhsNested, RhsNested, ProductMode> Type;
|
||||
};
|
||||
|
||||
// sparse product return type specialization
|
||||
template<typename Lhs, typename Rhs>
|
||||
struct SparseProductReturnType
|
||||
struct SparseProductReturnType<Lhs,Rhs,SparseTimeSparseProduct>
|
||||
{
|
||||
typedef typename ei_traits<Lhs>::Scalar Scalar;
|
||||
enum {
|
||||
@ -47,11 +67,11 @@ struct SparseProductReturnType
|
||||
SparseMatrix<Scalar,0>,
|
||||
const typename ei_nested<Rhs,Lhs::RowsAtCompileTime>::type>::ret RhsNested;
|
||||
|
||||
typedef SparseProduct<LhsNested, RhsNested> Type;
|
||||
typedef SparseProduct<LhsNested, RhsNested, SparseTimeSparseProduct> Type;
|
||||
};
|
||||
|
||||
template<typename LhsNested, typename RhsNested>
|
||||
struct ei_traits<SparseProduct<LhsNested, RhsNested> >
|
||||
template<typename LhsNested, typename RhsNested, int ProductMode>
|
||||
struct ei_traits<SparseProduct<LhsNested, RhsNested, ProductMode> >
|
||||
{
|
||||
// clean the nested types:
|
||||
typedef typename ei_cleantype<LhsNested>::type _LhsNested;
|
||||
@ -71,12 +91,13 @@ struct ei_traits<SparseProduct<LhsNested, RhsNested> >
|
||||
MaxRowsAtCompileTime = _LhsNested::MaxRowsAtCompileTime,
|
||||
MaxColsAtCompileTime = _RhsNested::MaxColsAtCompileTime,
|
||||
|
||||
LhsRowMajor = LhsFlags & RowMajorBit,
|
||||
RhsRowMajor = RhsFlags & RowMajorBit,
|
||||
// LhsIsRowMajor = (LhsFlags & RowMajorBit)==RowMajorBit,
|
||||
// RhsIsRowMajor = (RhsFlags & RowMajorBit)==RowMajorBit,
|
||||
|
||||
EvalToRowMajor = (RhsFlags & LhsFlags & RowMajorBit),
|
||||
ResultIsSparse = ProductMode==SparseTimeSparseProduct,
|
||||
|
||||
RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit),
|
||||
RemovedBits = ~( (EvalToRowMajor ? 0 : RowMajorBit) | (ResultIsSparse ? 0 : SparseBit) ),
|
||||
|
||||
Flags = (int(LhsFlags | RhsFlags) & HereditaryBits & RemovedBits)
|
||||
| EvalBeforeAssigningBit
|
||||
@ -84,11 +105,14 @@ struct ei_traits<SparseProduct<LhsNested, RhsNested> >
|
||||
|
||||
CoeffReadCost = Dynamic
|
||||
};
|
||||
|
||||
typedef typename ei_meta_if<ResultIsSparse,
|
||||
SparseMatrixBase<SparseProduct<LhsNested, RhsNested, ProductMode> >,
|
||||
MatrixBase<SparseProduct<LhsNested, RhsNested, ProductMode> > >::ret Base;
|
||||
};
|
||||
|
||||
template<typename LhsNested, typename RhsNested>
|
||||
class SparseProduct : ei_no_assignment_operator,
|
||||
public SparseMatrixBase<SparseProduct<LhsNested, RhsNested> >
|
||||
template<typename LhsNested, typename RhsNested, int ProductMode>
|
||||
class SparseProduct : ei_no_assignment_operator, public ei_traits<SparseProduct<LhsNested, RhsNested, ProductMode> >::Base
|
||||
{
|
||||
public:
|
||||
|
||||
@ -102,17 +126,33 @@ class SparseProduct : ei_no_assignment_operator,
|
||||
public:
|
||||
|
||||
template<typename Lhs, typename Rhs>
|
||||
inline SparseProduct(const Lhs& lhs, const Rhs& rhs)
|
||||
EIGEN_STRONG_INLINE SparseProduct(const Lhs& lhs, const Rhs& rhs)
|
||||
: m_lhs(lhs), m_rhs(rhs)
|
||||
{
|
||||
ei_assert(lhs.cols() == rhs.rows());
|
||||
|
||||
enum {
|
||||
ProductIsValid = _LhsNested::ColsAtCompileTime==Dynamic
|
||||
|| _RhsNested::RowsAtCompileTime==Dynamic
|
||||
|| int(_LhsNested::ColsAtCompileTime)==int(_RhsNested::RowsAtCompileTime),
|
||||
AreVectors = _LhsNested::IsVectorAtCompileTime && _RhsNested::IsVectorAtCompileTime,
|
||||
SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(_LhsNested,_RhsNested)
|
||||
};
|
||||
// note to the lost user:
|
||||
// * for a dot product use: v1.dot(v2)
|
||||
// * for a coeff-wise product use: v1.cwise()*v2
|
||||
EIGEN_STATIC_ASSERT(ProductIsValid || !(AreVectors && SameSizes),
|
||||
INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS)
|
||||
EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors),
|
||||
INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION)
|
||||
EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT)
|
||||
}
|
||||
|
||||
inline int rows() const { return m_lhs.rows(); }
|
||||
inline int cols() const { return m_rhs.cols(); }
|
||||
EIGEN_STRONG_INLINE int rows() const { return m_lhs.rows(); }
|
||||
EIGEN_STRONG_INLINE int cols() const { return m_rhs.cols(); }
|
||||
|
||||
const _LhsNested& lhs() const { return m_lhs; }
|
||||
const _LhsNested& rhs() const { return m_rhs; }
|
||||
EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; }
|
||||
EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; }
|
||||
|
||||
protected:
|
||||
LhsNested m_lhs;
|
||||
@ -240,9 +280,10 @@ struct ei_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
|
||||
// return derived();
|
||||
// }
|
||||
|
||||
// sparse = sparse * sparse
|
||||
template<typename Derived>
|
||||
template<typename Lhs, typename Rhs>
|
||||
inline Derived& SparseMatrixBase<Derived>::operator=(const SparseProduct<Lhs,Rhs>& product)
|
||||
inline Derived& SparseMatrixBase<Derived>::operator=(const SparseProduct<Lhs,Rhs,SparseTimeSparseProduct>& product)
|
||||
{
|
||||
// std::cout << "sparse product to sparse\n";
|
||||
ei_sparse_product_selector<
|
||||
@ -252,26 +293,51 @@ inline Derived& SparseMatrixBase<Derived>::operator=(const SparseProduct<Lhs,Rhs
|
||||
return derived();
|
||||
}
|
||||
|
||||
// dense = sparse * dense
|
||||
template<typename Derived>
|
||||
template<typename Lhs, typename Rhs>
|
||||
Derived& MatrixBase<Derived>::lazyAssign(const SparseProduct<Lhs,Rhs,SparseTimeDenseProduct>& product)
|
||||
{
|
||||
typedef typename ei_cleantype<Lhs>::type _Lhs;
|
||||
typedef typename _Lhs::InnerIterator LhsInnerIterator;
|
||||
enum { LhsIsRowMajor = (_Lhs::Flags&RowMajorBit)==RowMajorBit };
|
||||
derived().setZero();
|
||||
for (int j=0; j<product.lhs().outerSize(); ++j)
|
||||
for (LhsInnerIterator i(product.lhs(),j); i; ++i)
|
||||
derived().row(LhsIsRowMajor ? j : i.index()) += i.value() * product.rhs().row(LhsIsRowMajor ? i.index() : j);
|
||||
return derived();
|
||||
}
|
||||
|
||||
// dense = dense * sparse
|
||||
template<typename Derived>
|
||||
template<typename Lhs, typename Rhs>
|
||||
Derived& MatrixBase<Derived>::lazyAssign(const SparseProduct<Lhs,Rhs,DenseTimeSparseProduct>& product)
|
||||
{
|
||||
typedef typename ei_cleantype<Rhs>::type _Rhs;
|
||||
typedef typename _Rhs::InnerIterator RhsInnerIterator;
|
||||
enum { RhsIsRowMajor = (_Rhs::Flags&RowMajorBit)==RowMajorBit };
|
||||
derived().setZero();
|
||||
for (int j=0; j<product.rhs().outerSize(); ++j)
|
||||
for (RhsInnerIterator i(product.rhs(),j); i; ++i)
|
||||
derived().col(RhsIsRowMajor ? i.index() : j) += i.value() * product.lhs().col(RhsIsRowMajor ? j : i.index());
|
||||
return derived();
|
||||
}
|
||||
|
||||
// sparse * sparse
|
||||
template<typename Derived>
|
||||
template<typename OtherDerived>
|
||||
inline const typename SparseProductReturnType<Derived,OtherDerived>::Type
|
||||
EIGEN_STRONG_INLINE const typename SparseProductReturnType<Derived,OtherDerived>::Type
|
||||
SparseMatrixBase<Derived>::operator*(const SparseMatrixBase<OtherDerived> &other) const
|
||||
{
|
||||
enum {
|
||||
ProductIsValid = Derived::ColsAtCompileTime==Dynamic
|
||||
|| OtherDerived::RowsAtCompileTime==Dynamic
|
||||
|| int(Derived::ColsAtCompileTime)==int(OtherDerived::RowsAtCompileTime),
|
||||
AreVectors = Derived::IsVectorAtCompileTime && OtherDerived::IsVectorAtCompileTime,
|
||||
SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(Derived,OtherDerived)
|
||||
};
|
||||
// note to the lost user:
|
||||
// * for a dot product use: v1.dot(v2)
|
||||
// * for a coeff-wise product use: v1.cwise()*v2
|
||||
EIGEN_STATIC_ASSERT(ProductIsValid || !(AreVectors && SameSizes),
|
||||
INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS)
|
||||
EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors),
|
||||
INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION)
|
||||
EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT)
|
||||
return typename SparseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived());
|
||||
}
|
||||
|
||||
// sparse * dense
|
||||
template<typename Derived>
|
||||
template<typename OtherDerived>
|
||||
EIGEN_STRONG_INLINE const typename SparseProductReturnType<Derived,OtherDerived>::Type
|
||||
SparseMatrixBase<Derived>::operator*(const MatrixBase<OtherDerived> &other) const
|
||||
{
|
||||
return typename SparseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived());
|
||||
}
|
||||
|
||||
|
@ -109,10 +109,10 @@ template<typename MatrixType> class SparseInnerVector;
|
||||
template<typename Derived> class SparseCwise;
|
||||
template<typename UnaryOp, typename MatrixType> class SparseCwiseUnaryOp;
|
||||
template<typename BinaryOp, typename Lhs, typename Rhs> class SparseCwiseBinaryOp;
|
||||
template<typename Lhs, typename Rhs> class SparseProduct;
|
||||
template<typename ExpressionType, unsigned int Added, unsigned int Removed> class SparseFlagged;
|
||||
|
||||
template<typename Lhs, typename Rhs> struct SparseProductReturnType;
|
||||
template<typename Lhs, typename Rhs> struct ei_sparse_product_mode;
|
||||
template<typename Lhs, typename Rhs, int ProductMode = ei_sparse_product_mode<Lhs,Rhs>::value> struct SparseProductReturnType;
|
||||
|
||||
const int AccessPatternNotSupported = 0x0;
|
||||
const int AccessPatternSupported = 0x1;
|
||||
|
@ -216,6 +216,7 @@ template<typename Scalar> void sparse_basic(int rows, int cols)
|
||||
DenseMatrix refMat2 = DenseMatrix::Zero(rows, rows);
|
||||
DenseMatrix refMat3 = DenseMatrix::Zero(rows, rows);
|
||||
DenseMatrix refMat4 = DenseMatrix::Zero(rows, rows);
|
||||
DenseMatrix dm4 = DenseMatrix::Zero(rows, rows);
|
||||
SparseMatrix<Scalar> m2(rows, rows);
|
||||
SparseMatrix<Scalar> m3(rows, rows);
|
||||
SparseMatrix<Scalar> m4(rows, rows);
|
||||
@ -226,6 +227,18 @@ template<typename Scalar> void sparse_basic(int rows, int cols)
|
||||
VERIFY_IS_APPROX(m4=m2.transpose()*m3, refMat4=refMat2.transpose()*refMat3);
|
||||
VERIFY_IS_APPROX(m4=m2.transpose()*m3.transpose(), refMat4=refMat2.transpose()*refMat3.transpose());
|
||||
VERIFY_IS_APPROX(m4=m2*m3.transpose(), refMat4=refMat2*refMat3.transpose());
|
||||
|
||||
// sparse * dense
|
||||
VERIFY_IS_APPROX(dm4=m2*refMat3, refMat4=refMat2*refMat3);
|
||||
VERIFY_IS_APPROX(dm4=m2*refMat3.transpose(), refMat4=refMat2*refMat3.transpose());
|
||||
VERIFY_IS_APPROX(dm4=m2.transpose()*refMat3, refMat4=refMat2.transpose()*refMat3);
|
||||
VERIFY_IS_APPROX(dm4=m2.transpose()*refMat3.transpose(), refMat4=refMat2.transpose()*refMat3.transpose());
|
||||
|
||||
// dense * sparse
|
||||
VERIFY_IS_APPROX(dm4=refMat2*m3, refMat4=refMat2*refMat3);
|
||||
VERIFY_IS_APPROX(dm4=refMat2*m3.transpose(), refMat4=refMat2*refMat3.transpose());
|
||||
VERIFY_IS_APPROX(dm4=refMat2.transpose()*m3, refMat4=refMat2.transpose()*refMat3);
|
||||
VERIFY_IS_APPROX(dm4=refMat2.transpose()*m3.transpose(), refMat4=refMat2.transpose()*refMat3.transpose());
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user