This commit is contained in:
Gael Guennebaud 2010-12-31 17:26:48 +01:00
commit 902af035d3
2 changed files with 23 additions and 43 deletions

View File

@ -169,32 +169,15 @@ class SparseTimeDenseProduct
enum { LhsIsRowMajor = (_Lhs::Flags&RowMajorBit)==RowMajorBit }; enum { LhsIsRowMajor = (_Lhs::Flags&RowMajorBit)==RowMajorBit };
for(Index j=0; j<m_lhs.outerSize(); ++j) for(Index j=0; j<m_lhs.outerSize(); ++j)
{ {
if(LhsIsRowMajor) typename Rhs::Scalar rhs_j = alpha * m_rhs.coeff(LhsIsRowMajor ? 0 : j,0);
{ typename Dest::RowXpr dest_j(dest.row(LhsIsRowMajor ? j : 0));
// Block<Dest,1,Dest::ColsAtCompileTime> dest_j(dest.row(LhsIsRowMajor ? j : 0)); // this does not work in all cases. Why?
Block<Dest,1,Dest::ColsAtCompileTime> dest_j(dest, LhsIsRowMajor ? j : 0);
for(LhsInnerIterator it(m_lhs,j); it ;++it) for(LhsInnerIterator it(m_lhs,j); it ;++it)
{ {
dest_j += (alpha*it.value()) * m_rhs.row(it.index()); if(LhsIsRowMajor) dest_j += (alpha*it.value()) * m_rhs.row(it.index());
else if(Rhs::ColsAtCompileTime==1) dest.coeffRef(it.index()) += it.value() * rhs_j;
else dest.row(it.index()) += (alpha*it.value()) * m_rhs.row(j);
} }
} }
else if(Rhs::ColsAtCompileTime==1)
{
typename Rhs::Scalar rhs_j = alpha * m_rhs.coeff(j,0);
for(LhsInnerIterator it(m_lhs,j); it ;++it)
{
dest.coeffRef(it.index()) += it.value() * rhs_j;
}
}
else
{
for(LhsInnerIterator it(m_lhs,j); it ;++it)
{
dest.row(it.index()) += (alpha*it.value()) * m_rhs.row(j);
}
}
}
} }
private: private:

View File

@ -141,23 +141,20 @@ template<typename SparseMatrixType> void sparse_product(const SparseMatrixType&
template<typename SparseMatrixType, typename DenseMatrixType> void sparse_product_regression_test() template<typename SparseMatrixType, typename DenseMatrixType> void sparse_product_regression_test()
{ {
// This code does not compile with afflicted versions of the bug // This code does not compile with afflicted versions of the bug
/* SparseMatrixType sm1(3,2); SparseMatrixType sm1(3,2);
DenseMatrixType m2(2,2); DenseMatrixType m2(2,2);
sm1.setZero(); sm1.setZero();
m2.setZero(); m2.setZero();
DenseMatrixType m3 = sm1*m2; DenseMatrixType m3 = sm1*m2;
*/
// This code produces a segfault with afflicted versions of another SparseTimeDenseProduct // This code produces a segfault with afflicted versions of another SparseTimeDenseProduct
// bug // bug
SparseMatrixType sm2(20000,2); SparseMatrixType sm2(20000,2);
DenseMatrixType m3(2,2);
sm2.setZero(); sm2.setZero();
m3.setZero(); DenseMatrixType m4(sm2*m2);
DenseMatrixType m4(sm2*m3);
VERIFY_IS_APPROX( m4(0,0), 0.0 ); VERIFY_IS_APPROX( m4(0,0), 0.0 );
} }