Eliminate unnecessary copying for sparse Kronecker product.

This commit is contained in:
Chen-Pang He 2013-07-15 09:10:17 +08:00
parent 9be658f701
commit 4b780553e0
2 changed files with 75 additions and 52 deletions

View File

@ -14,7 +14,42 @@
namespace Eigen {
template<typename Scalar, int Options, typename Index> class SparseMatrix;
/*!
 * \brief Common base of the Kronecker tensor product helper classes.
 *
 * Stores the two operands and provides the size and coefficient accessors
 * that depend only on them. The operand, scalar and index types are all
 * taken from internal::traits<Derived>.
 *
 * \tparam Derived The concrete Kronecker product expression type.
 */
template<typename Derived>
class KroneckerProductBase : public ReturnByValue<Derived>
{
  private:
    typedef typename internal::traits<Derived> Traits;
    typedef typename Traits::Lhs Lhs;
    typedef typename Traits::Rhs Rhs;
    typedef typename Traits::Scalar Scalar;

  protected:
    typedef typename Traits::Index Index;

  public:
    /*! \brief Constructor: stores the operands \a A and \a B. */
    KroneckerProductBase(const Lhs& A, const Rhs& B)
      : m_A(A), m_B(B)
    {}

    /*! \brief Number of rows of the product: rows of A times rows of B. */
    inline Index rows() const { return m_A.rows() * m_B.rows(); }

    /*! \brief Number of columns of the product: columns of A times columns of B. */
    inline Index cols() const { return m_A.cols() * m_B.cols(); }

    /*!
     * \brief Coefficient (row, col) of the product.
     *
     * The quotient of each index by the corresponding dimension of B selects
     * the entry of A; the remainder selects the entry of B.
     */
    Scalar coeff(Index row, Index col) const
    {
      return m_A.coeff(row / m_B.rows(), col / m_B.cols()) *
             m_B.coeff(row % m_B.rows(), col % m_B.cols());
    }

    /*!
     * \brief Coefficient \a i of the product, for vector-shaped products only.
     *
     * Element i of the product of two vectors is A[i / size(B)] * B[i % size(B)]:
     * the result is made of consecutive runs of size(B) entries, one run per
     * entry of A. The block length is therefore the size of B, not of A.
     */
    Scalar coeff(Index i) const
    {
      EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived);
      const Index blockSize = m_B.size();
      return m_A.coeff(i / blockSize) * m_B.coeff(i % blockSize);
    }

  protected:
    // Operands, held through their Nested types.
    typename Lhs::Nested m_A;
    typename Rhs::Nested m_B;
};
/*!
* \brief Kronecker tensor product helper class for dense matrices
@ -27,40 +62,21 @@ template<typename Scalar, int Options, typename Index> class SparseMatrix;
* \tparam Rhs Type of the right-hand side, a matrix expression.
*/
template<typename Lhs, typename Rhs>
// NOTE(review): two class heads, two `Base` typedefs and two ctor-initializer lists
// appear in this span (a ReturnByValue-based set and a KroneckerProductBase-based set).
// It reads like both sides of a change shown together — confirm which set is live.
class KroneckerProduct : public ReturnByValue<KroneckerProduct<Lhs,Rhs> >
class KroneckerProduct : public KroneckerProductBase<KroneckerProduct<Lhs,Rhs> >
{
private:
typedef ReturnByValue<KroneckerProduct> Base;
typedef typename Base::Scalar Scalar;
typedef typename Base::Index Index;
typedef KroneckerProductBase<KroneckerProduct> Base;
// Make the operand members of the dependent base visible by their plain names.
using Base::m_A;
using Base::m_B;
public:
/*! \brief Constructor. */
// Takes the two operands A and B; one initializer list stores them in m_A/m_B
// directly, the other forwards them to Base.
KroneckerProduct(const Lhs& A, const Rhs& B)
: m_A(A), m_B(B)
: Base(A, B)
{}
/*! \brief Evaluate the Kronecker tensor product. */
// Declared here only; the definition is outside the class body.
template<typename Dest> void evalTo(Dest& dst) const;
// Product dimensions: rows of A times rows of B, columns of A times columns of B.
inline Index rows() const { return m_A.rows() * m_B.rows(); }
inline Index cols() const { return m_A.cols() * m_B.cols(); }
// Coefficient (row, col): index quotients by B's dimensions pick the entry of A,
// the remainders pick the entry of B; the two entries are multiplied.
Scalar coeff(Index row, Index col) const
{
return m_A.coeff(row / m_B.rows(), col / m_B.cols()) *
m_B.coeff(row % m_B.rows(), col % m_B.cols());
}
// Linear coefficient access, restricted to vector-shaped products by the static assert.
// NOTE(review): both the divisor and the modulus use m_A.size(); for a Kronecker
// product of vectors one would expect the block length to be m_B.size() — confirm.
Scalar coeff(Index i) const
{
EIGEN_STATIC_ASSERT_VECTOR_ONLY(KroneckerProduct);
return m_A.coeff(i / m_A.size()) * m_B.coeff(i % m_A.size());
}
private:
// Operands, held through their Nested types.
typename Lhs::Nested m_A;
typename Rhs::Nested m_B;
};
/*!
@ -77,40 +93,28 @@ class KroneckerProduct : public ReturnByValue<KroneckerProduct<Lhs,Rhs> >
* \tparam Rhs Type of the right-hand side, a matrix expression.
*/
template<typename Lhs, typename Rhs>
// NOTE(review): two class heads and two ctor-initializer lists appear in this span
// (an EigenBase-based set and a KroneckerProductBase-based set). It reads like both
// sides of a change shown together — confirm which set is live.
class KroneckerProductSparse : public EigenBase<KroneckerProductSparse<Lhs,Rhs> >
class KroneckerProductSparse : public KroneckerProductBase<KroneckerProductSparse<Lhs,Rhs> >
{
private:
typedef typename internal::traits<KroneckerProductSparse>::Index Index;
typedef KroneckerProductBase<KroneckerProductSparse> Base;
// Make the operand members of the dependent base visible by their plain names.
using Base::m_A;
using Base::m_B;
public:
/*! \brief Constructor. */
// Takes the two operands A and B; one initializer list stores them in m_A/m_B
// directly, the other forwards them to Base.
KroneckerProductSparse(const Lhs& A, const Rhs& B)
: m_A(A), m_B(B)
: Base(A, B)
{}
/*! \brief Evaluate the Kronecker tensor product. */
// Declared here only; the definition is outside the class body.
template<typename Dest> void evalTo(Dest& dst) const;
// Product dimensions: rows of A times rows of B, columns of A times columns of B.
inline Index rows() const { return m_A.rows() * m_B.rows(); }
inline Index cols() const { return m_A.cols() * m_B.cols(); }
// Conversion to SparseMatrix: default-constructs a local result, fills it through
// evalTo(), and returns it by value. Note the operator is not const-qualified.
template<typename Scalar, int Options, typename Index>
operator SparseMatrix<Scalar, Options, Index>()
{
SparseMatrix<Scalar, Options, Index> result;
evalTo(result.derived());
return result;
}
private:
// Operands, held through their Nested types.
typename Lhs::Nested m_A;
typename Rhs::Nested m_B;
};
template<typename Lhs, typename Rhs>
template<typename Dest>
void KroneckerProduct<Lhs,Rhs>::evalTo(Dest& dst) const
{
typedef typename Base::Index Index;
const int BlockRows = Rhs::RowsAtCompileTime,
BlockCols = Rhs::ColsAtCompileTime;
const Index Br = m_B.rows(),
@ -124,9 +128,10 @@ template<typename Lhs, typename Rhs>
template<typename Dest>
void KroneckerProductSparse<Lhs,Rhs>::evalTo(Dest& dst) const
{
typedef typename Base::Index Index;
const Index Br = m_B.rows(),
Bc = m_B.cols();
dst.resize(rows(),cols());
dst.resize(this->rows(), this->cols());
dst.resizeNonZeros(0);
dst.reserve(m_A.nonZeros() * m_B.nonZeros());
@ -155,6 +160,7 @@ struct traits<KroneckerProduct<_Lhs,_Rhs> >
typedef typename remove_all<_Lhs>::type Lhs;
typedef typename remove_all<_Rhs>::type Rhs;
typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
typedef typename promote_index_type<typename Lhs::Index, typename Rhs::Index>::type Index;
enum {
Rows = size_at_compile_time<traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime>::ret,
@ -193,6 +199,8 @@ struct traits<KroneckerProductSparse<_Lhs,_Rhs> >
| EvalBeforeNestingBit | EvalBeforeAssigningBit,
CoeffReadCost = Dynamic
};
typedef SparseMatrix<Scalar> ReturnType;
};
} // end namespace internal
@ -228,6 +236,16 @@ KroneckerProduct<A,B> kroneckerProduct(const MatrixBase<A>& a, const MatrixBase<
* Computes Kronecker tensor product of two matrices, at least one of
* which is sparse
*
* \warning If you want to replace a matrix by its Kronecker product
* with some matrix, do \b NOT do this:
* \code
* A = kroneckerProduct(A,B); // bug!!! caused by aliasing effect
* \endcode
* instead, use eval() to work around this:
* \code
* A = kroneckerProduct(A,B).eval();
* \endcode
*
* \param a Dense/sparse matrix a
* \param b Dense/sparse matrix b
* \return Kronecker tensor product of a and b, stored in a sparse

View File

@ -107,31 +107,34 @@ void test_kronecker_product()
SparseMatrix<double,RowMajor> SM_row_a(SM_a), SM_row_b(SM_b);
// test kroneckerProduct(DM_block,DM,DM_fixedSize)
// test DM_fixedSize = kroneckerProduct(DM_block,DM)
Matrix<double, 6, 6> DM_fix_ab = kroneckerProduct(DM_a.topLeftCorner<2,3>(),DM_b);
CALL_SUBTEST(check_kronecker_product(DM_fix_ab));
CALL_SUBTEST(check_kronecker_product(kroneckerProduct(DM_a.topLeftCorner<2,3>(),DM_b)));
for(int i=0;i<DM_fix_ab.rows();++i)
for(int j=0;j<DM_fix_ab.cols();++j)
VERIFY_IS_APPROX(kroneckerProduct(DM_a,DM_b).coeff(i,j), DM_fix_ab(i,j));
// test kroneckerProduct(DM,DM,DM_block)
// test DM_block = kroneckerProduct(DM,DM)
MatrixXd DM_block_ab(10,15);
DM_block_ab.block<6,6>(2,5) = kroneckerProduct(DM_a,DM_b);
CALL_SUBTEST(check_kronecker_product(DM_block_ab.block<6,6>(2,5)));
// test kroneckerProduct(DM,DM,DM)
// test DM = kroneckerProduct(DM,DM)
MatrixXd DM_ab = kroneckerProduct(DM_a,DM_b);
CALL_SUBTEST(check_kronecker_product(DM_ab));
CALL_SUBTEST(check_kronecker_product(kroneckerProduct(DM_a,DM_b)));
// test kroneckerProduct(SM,DM,SM)
// test SM = kroneckerProduct(SM,DM)
SparseMatrix<double> SM_ab = kroneckerProduct(SM_a,DM_b);
CALL_SUBTEST(check_kronecker_product(SM_ab));
SparseMatrix<double,RowMajor> SM_ab2 = kroneckerProduct(SM_a,DM_b);
CALL_SUBTEST(check_kronecker_product(SM_ab2));
CALL_SUBTEST(check_kronecker_product(kroneckerProduct(SM_a,DM_b)));
// test kroneckerProduct(DM,SM,SM)
// test SM = kroneckerProduct(DM,SM)
SM_ab.setZero();
SM_ab.insert(0,0)=37.0;
SM_ab = kroneckerProduct(DM_a,SM_b);
@ -140,8 +143,9 @@ void test_kronecker_product()
SM_ab2.insert(0,0)=37.0;
SM_ab2 = kroneckerProduct(DM_a,SM_b);
CALL_SUBTEST(check_kronecker_product(SM_ab2));
CALL_SUBTEST(check_kronecker_product(kroneckerProduct(DM_a,SM_b)));
// test kroneckerProduct(SM,SM,SM)
// test SM = kroneckerProduct(SM,SM)
SM_ab.resize(2,33);
SM_ab.insert(0,0)=37.0;
SM_ab = kroneckerProduct(SM_a,SM_b);
@ -150,8 +154,9 @@ void test_kronecker_product()
SM_ab2.insert(0,0)=37.0;
SM_ab2 = kroneckerProduct(SM_a,SM_b);
CALL_SUBTEST(check_kronecker_product(SM_ab2));
CALL_SUBTEST(check_kronecker_product(kroneckerProduct(SM_a,SM_b)));
// test kroneckerProduct(SM,SM,SM) with sparse pattern
// test SM = kroneckerProduct(SM,SM) with sparse pattern
SM_a.resize(4,5);
SM_b.resize(3,2);
SM_a.resizeNonZeros(0);
@ -169,7 +174,7 @@ void test_kronecker_product()
SM_ab = kroneckerProduct(SM_a,SM_b);
CALL_SUBTEST(check_sparse_kronecker_product(SM_ab));
// test dimension of result of kroneckerProduct(DM,DM,DM)
// test dimension of result of DM = kroneckerProduct(DM,DM)
MatrixXd DM_a2(2,1);
MatrixXd DM_b2(5,4);
MatrixXd DM_ab2 = kroneckerProduct(DM_a2,DM_b2);