optimize sparse-sparse Kronecker product

This commit is contained in:
Gael Guennebaud 2014-02-14 14:46:01 +01:00
parent 0d3f496233
commit 3283d98d13
2 changed files with 50 additions and 1 deletions

View File

@ -153,7 +153,22 @@ void KroneckerProductSparse<Lhs,Rhs>::evalTo(Dest& dst) const
Bc = m_B.cols();
dst.resize(this->rows(), this->cols());
dst.resizeNonZeros(0);
dst.reserve(m_A.nonZeros() * m_B.nonZeros());
// compute number of non-zeros per innervectors of dst
{
VectorXi nnzA = VectorXi::Zero(Dest::IsRowMajor ? m_A.rows() : m_A.cols());
for (Index kA=0; kA < m_A.outerSize(); ++kA)
for (typename Lhs::InnerIterator itA(m_A,kA); itA; ++itA)
nnzA(Dest::IsRowMajor ? itA.row() : itA.col())++;
VectorXi nnzB = VectorXi::Zero(Dest::IsRowMajor ? m_B.rows() : m_B.cols());
for (Index kB=0; kB < m_B.outerSize(); ++kB)
for (typename Rhs::InnerIterator itB(m_B,kB); itB; ++itB)
nnzB(Dest::IsRowMajor ? itB.row() : itB.col())++;
Matrix<int,Dynamic,Dynamic,ColMajor> nnzAB = nnzB * nnzA.transpose();
dst.reserve(VectorXi::Map(nnzAB.data(), nnzAB.size()));
}
for (Index kA=0; kA < m_A.outerSize(); ++kA)
{

View File

@ -183,4 +183,38 @@ void test_kronecker_product()
DM_b2.resize(4,8);
DM_ab2 = kroneckerProduct(DM_a2,DM_b2);
CALL_SUBTEST(check_dimension(DM_ab2,10*4,9*8));
for(int i = 0; i < g_repeat; i++)
{
double density = Eigen::internal::random<double>(0.01,0.5);
int ra = Eigen::internal::random<int>(1,50);
int ca = Eigen::internal::random<int>(1,50);
int rb = Eigen::internal::random<int>(1,50);
int cb = Eigen::internal::random<int>(1,50);
SparseMatrix<float,ColMajor> sA(ra,ca), sB(rb,cb), sC;
SparseMatrix<float,RowMajor> sC2;
MatrixXf dA(ra,ca), dB(rb,cb), dC;
initSparse(density, dA, sA);
initSparse(density, dB, sB);
sC = kroneckerProduct(sA,sB);
dC = kroneckerProduct(dA,dB);
VERIFY_IS_APPROX(MatrixXf(sC),dC);
sC = kroneckerProduct(sA.transpose(),sB);
dC = kroneckerProduct(dA.transpose(),dB);
VERIFY_IS_APPROX(MatrixXf(sC),dC);
sC = kroneckerProduct(sA.transpose(),sB.transpose());
dC = kroneckerProduct(dA.transpose(),dB.transpose());
VERIFY_IS_APPROX(MatrixXf(sC),dC);
sC = kroneckerProduct(sA,sB.transpose());
dC = kroneckerProduct(dA,dB.transpose());
VERIFY_IS_APPROX(MatrixXf(sC),dC);
sC2 = kroneckerProduct(sA,sB);
dC = kroneckerProduct(dA,dB);
VERIFY_IS_APPROX(MatrixXf(sC2),dC);
}
}