Fix propagation of index type

This commit is contained in:
Gael Guennebaud 2014-02-13 23:58:28 +01:00
parent c0e08e9e4b
commit 0b1430ae10
4 changed files with 21 additions and 19 deletions

View File

@ -134,8 +134,8 @@ struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,C
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
{ {
typedef SparseMatrix<typename ResultType::Scalar,RowMajor> RowMajorMatrix; typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::Index> RowMajorMatrix;
typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix; typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::Index> ColMajorMatrix;
ColMajorMatrix resCol(lhs.rows(),rhs.cols()); ColMajorMatrix resCol(lhs.rows(),rhs.cols());
internal::conservative_sparse_sparse_product_impl<Lhs,Rhs,ColMajorMatrix>(lhs, rhs, resCol); internal::conservative_sparse_sparse_product_impl<Lhs,Rhs,ColMajorMatrix>(lhs, rhs, resCol);
// sort the non zeros: // sort the non zeros:
@ -149,7 +149,7 @@ struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,C
{ {
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
{ {
typedef SparseMatrix<typename ResultType::Scalar,RowMajor> RowMajorMatrix; typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::Index> RowMajorMatrix;
RowMajorMatrix rhsRow = rhs; RowMajorMatrix rhsRow = rhs;
RowMajorMatrix resRow(lhs.rows(), rhs.cols()); RowMajorMatrix resRow(lhs.rows(), rhs.cols());
internal::conservative_sparse_sparse_product_impl<RowMajorMatrix,Lhs,RowMajorMatrix>(rhsRow, lhs, resRow); internal::conservative_sparse_sparse_product_impl<RowMajorMatrix,Lhs,RowMajorMatrix>(rhsRow, lhs, resRow);
@ -162,7 +162,7 @@ struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,R
{ {
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
{ {
typedef SparseMatrix<typename ResultType::Scalar,RowMajor> RowMajorMatrix; typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::Index> RowMajorMatrix;
RowMajorMatrix lhsRow = lhs; RowMajorMatrix lhsRow = lhs;
RowMajorMatrix resRow(lhs.rows(), rhs.cols()); RowMajorMatrix resRow(lhs.rows(), rhs.cols());
internal::conservative_sparse_sparse_product_impl<Rhs,RowMajorMatrix,RowMajorMatrix>(rhs, lhsRow, resRow); internal::conservative_sparse_sparse_product_impl<Rhs,RowMajorMatrix,RowMajorMatrix>(rhs, lhsRow, resRow);
@ -175,7 +175,7 @@ struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,R
{ {
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
{ {
typedef SparseMatrix<typename ResultType::Scalar,RowMajor> RowMajorMatrix; typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::Index> RowMajorMatrix;
RowMajorMatrix resRow(lhs.rows(), rhs.cols()); RowMajorMatrix resRow(lhs.rows(), rhs.cols());
internal::conservative_sparse_sparse_product_impl<Rhs,Lhs,RowMajorMatrix>(rhs, lhs, resRow); internal::conservative_sparse_sparse_product_impl<Rhs,Lhs,RowMajorMatrix>(rhs, lhs, resRow);
res = resRow; res = resRow;
@ -190,7 +190,7 @@ struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,C
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
{ {
typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix; typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::Index> ColMajorMatrix;
ColMajorMatrix resCol(lhs.rows(), rhs.cols()); ColMajorMatrix resCol(lhs.rows(), rhs.cols());
internal::conservative_sparse_sparse_product_impl<Lhs,Rhs,ColMajorMatrix>(lhs, rhs, resCol); internal::conservative_sparse_sparse_product_impl<Lhs,Rhs,ColMajorMatrix>(lhs, rhs, resCol);
res = resCol; res = resCol;
@ -202,7 +202,7 @@ struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,C
{ {
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
{ {
typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix; typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::Index> ColMajorMatrix;
ColMajorMatrix lhsCol = lhs; ColMajorMatrix lhsCol = lhs;
ColMajorMatrix resCol(lhs.rows(), rhs.cols()); ColMajorMatrix resCol(lhs.rows(), rhs.cols());
internal::conservative_sparse_sparse_product_impl<ColMajorMatrix,Rhs,ColMajorMatrix>(lhsCol, rhs, resCol); internal::conservative_sparse_sparse_product_impl<ColMajorMatrix,Rhs,ColMajorMatrix>(lhsCol, rhs, resCol);
@ -215,7 +215,7 @@ struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,R
{ {
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
{ {
typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix; typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::Index> ColMajorMatrix;
ColMajorMatrix rhsCol = rhs; ColMajorMatrix rhsCol = rhs;
ColMajorMatrix resCol(lhs.rows(), rhs.cols()); ColMajorMatrix resCol(lhs.rows(), rhs.cols());
internal::conservative_sparse_sparse_product_impl<Lhs,ColMajorMatrix,ColMajorMatrix>(lhs, rhsCol, resCol); internal::conservative_sparse_sparse_product_impl<Lhs,ColMajorMatrix,ColMajorMatrix>(lhs, rhsCol, resCol);
@ -228,8 +228,8 @@ struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,R
{ {
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
{ {
typedef SparseMatrix<typename ResultType::Scalar,RowMajor> RowMajorMatrix; typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::Index> RowMajorMatrix;
typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix; typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::Index> ColMajorMatrix;
RowMajorMatrix resRow(lhs.rows(),rhs.cols()); RowMajorMatrix resRow(lhs.rows(),rhs.cols());
internal::conservative_sparse_sparse_product_impl<Rhs,Lhs,RowMajorMatrix>(rhs, lhs, resRow); internal::conservative_sparse_sparse_product_impl<Rhs,Lhs,RowMajorMatrix>(rhs, lhs, resRow);
// sort the non zeros: // sort the non zeros:

View File

@ -302,8 +302,8 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
} }
else else
{ {
SparseMatrix<Scalar, RowMajorBit> trans = m; SparseMatrix<Scalar, RowMajorBit, Index> trans = m;
s << static_cast<const SparseMatrixBase<SparseMatrix<Scalar, RowMajorBit> >&>(trans); s << static_cast<const SparseMatrixBase<SparseMatrix<Scalar, RowMajorBit, Index> >&>(trans);
} }
} }
return s; return s;

View File

@ -16,6 +16,7 @@ template<typename Lhs, typename Rhs>
struct SparseSparseProductReturnType struct SparseSparseProductReturnType
{ {
typedef typename internal::traits<Lhs>::Scalar Scalar; typedef typename internal::traits<Lhs>::Scalar Scalar;
typedef typename internal::traits<Lhs>::Index Index;
enum { enum {
LhsRowMajor = internal::traits<Lhs>::Flags & RowMajorBit, LhsRowMajor = internal::traits<Lhs>::Flags & RowMajorBit,
RhsRowMajor = internal::traits<Rhs>::Flags & RowMajorBit, RhsRowMajor = internal::traits<Rhs>::Flags & RowMajorBit,
@ -24,11 +25,11 @@ struct SparseSparseProductReturnType
}; };
typedef typename internal::conditional<TransposeLhs, typedef typename internal::conditional<TransposeLhs,
SparseMatrix<Scalar,0>, SparseMatrix<Scalar,0,Index>,
typename internal::nested<Lhs,Rhs::RowsAtCompileTime>::type>::type LhsNested; typename internal::nested<Lhs,Rhs::RowsAtCompileTime>::type>::type LhsNested;
typedef typename internal::conditional<TransposeRhs, typedef typename internal::conditional<TransposeRhs,
SparseMatrix<Scalar,0>, SparseMatrix<Scalar,0,Index>,
typename internal::nested<Rhs,Lhs::RowsAtCompileTime>::type>::type RhsNested; typename internal::nested<Rhs,Lhs::RowsAtCompileTime>::type>::type RhsNested;
typedef SparseSparseProduct<LhsNested, RhsNested> Type; typedef SparseSparseProduct<LhsNested, RhsNested> Type;

View File

@ -100,7 +100,7 @@ struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,C
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
{ {
// we need a col-major matrix to hold the result // we need a col-major matrix to hold the result
typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType; typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::Index> SparseTemporaryType;
SparseTemporaryType _res(res.rows(), res.cols()); SparseTemporaryType _res(res.rows(), res.cols());
internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res, tolerance); internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res, tolerance);
res = _res; res = _res;
@ -126,10 +126,11 @@ struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,R
typedef typename ResultType::RealScalar RealScalar; typedef typename ResultType::RealScalar RealScalar;
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
{ {
typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix; typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::Index> ColMajorMatrixLhs;
ColMajorMatrix colLhs(lhs); typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::Index> ColMajorMatrixRhs;
ColMajorMatrix colRhs(rhs); ColMajorMatrixLhs colLhs(lhs);
internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrix,ColMajorMatrix,ResultType>(colLhs, colRhs, res, tolerance); ColMajorMatrixRhs colRhs(rhs);
internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,ColMajorMatrixRhs,ResultType>(colLhs, colRhs, res, tolerance);
// let's transpose the product to get a column x column product // let's transpose the product to get a column x column product
// typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType; // typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;