mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-12 11:49:02 +08:00
Eliminate unnecessary copying for sparse Kronecker product.
This commit is contained in:
parent
9be658f701
commit
4b780553e0
@ -14,35 +14,23 @@
|
|||||||
|
|
||||||
namespace Eigen {
|
namespace Eigen {
|
||||||
|
|
||||||
template<typename Scalar, int Options, typename Index> class SparseMatrix;
|
template<typename Derived>
|
||||||
|
class KroneckerProductBase : public ReturnByValue<Derived>
|
||||||
/*!
|
|
||||||
* \brief Kronecker tensor product helper class for dense matrices
|
|
||||||
*
|
|
||||||
* This class is the return value of kroneckerProduct(MatrixBase,
|
|
||||||
* MatrixBase). Use the function rather than construct this class
|
|
||||||
* directly to avoid specifying template prarameters.
|
|
||||||
*
|
|
||||||
* \tparam Lhs Type of the left-hand side, a matrix expression.
|
|
||||||
* \tparam Rhs Type of the rignt-hand side, a matrix expression.
|
|
||||||
*/
|
|
||||||
template<typename Lhs, typename Rhs>
|
|
||||||
class KroneckerProduct : public ReturnByValue<KroneckerProduct<Lhs,Rhs> >
|
|
||||||
{
|
{
|
||||||
private:
|
private:
|
||||||
typedef ReturnByValue<KroneckerProduct> Base;
|
typedef typename internal::traits<Derived> Traits;
|
||||||
typedef typename Base::Scalar Scalar;
|
typedef typename Traits::Lhs Lhs;
|
||||||
typedef typename Base::Index Index;
|
typedef typename Traits::Rhs Rhs;
|
||||||
|
typedef typename Traits::Scalar Scalar;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
typedef typename Traits::Index Index;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/*! \brief Constructor. */
|
KroneckerProductBase(const Lhs& A, const Rhs& B)
|
||||||
KroneckerProduct(const Lhs& A, const Rhs& B)
|
|
||||||
: m_A(A), m_B(B)
|
: m_A(A), m_B(B)
|
||||||
{}
|
{}
|
||||||
|
|
||||||
/*! \brief Evaluate the Kronecker tensor product. */
|
|
||||||
template<typename Dest> void evalTo(Dest& dst) const;
|
|
||||||
|
|
||||||
inline Index rows() const { return m_A.rows() * m_B.rows(); }
|
inline Index rows() const { return m_A.rows() * m_B.rows(); }
|
||||||
inline Index cols() const { return m_A.cols() * m_B.cols(); }
|
inline Index cols() const { return m_A.cols() * m_B.cols(); }
|
||||||
|
|
||||||
@ -54,15 +42,43 @@ class KroneckerProduct : public ReturnByValue<KroneckerProduct<Lhs,Rhs> >
|
|||||||
|
|
||||||
Scalar coeff(Index i) const
|
Scalar coeff(Index i) const
|
||||||
{
|
{
|
||||||
EIGEN_STATIC_ASSERT_VECTOR_ONLY(KroneckerProduct);
|
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived);
|
||||||
return m_A.coeff(i / m_A.size()) * m_B.coeff(i % m_A.size());
|
return m_A.coeff(i / m_A.size()) * m_B.coeff(i % m_A.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
protected:
|
||||||
typename Lhs::Nested m_A;
|
typename Lhs::Nested m_A;
|
||||||
typename Rhs::Nested m_B;
|
typename Rhs::Nested m_B;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Kronecker tensor product helper class for dense matrices
|
||||||
|
*
|
||||||
|
* This class is the return value of kroneckerProduct(MatrixBase,
|
||||||
|
* MatrixBase). Use the function rather than construct this class
|
||||||
|
* directly to avoid specifying template prarameters.
|
||||||
|
*
|
||||||
|
* \tparam Lhs Type of the left-hand side, a matrix expression.
|
||||||
|
* \tparam Rhs Type of the rignt-hand side, a matrix expression.
|
||||||
|
*/
|
||||||
|
template<typename Lhs, typename Rhs>
|
||||||
|
class KroneckerProduct : public KroneckerProductBase<KroneckerProduct<Lhs,Rhs> >
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
typedef KroneckerProductBase<KroneckerProduct> Base;
|
||||||
|
using Base::m_A;
|
||||||
|
using Base::m_B;
|
||||||
|
|
||||||
|
public:
|
||||||
|
/*! \brief Constructor. */
|
||||||
|
KroneckerProduct(const Lhs& A, const Rhs& B)
|
||||||
|
: Base(A, B)
|
||||||
|
{}
|
||||||
|
|
||||||
|
/*! \brief Evaluate the Kronecker tensor product. */
|
||||||
|
template<typename Dest> void evalTo(Dest& dst) const;
|
||||||
|
};
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief Kronecker tensor product helper class for sparse matrices
|
* \brief Kronecker tensor product helper class for sparse matrices
|
||||||
*
|
*
|
||||||
@ -77,40 +93,28 @@ class KroneckerProduct : public ReturnByValue<KroneckerProduct<Lhs,Rhs> >
|
|||||||
* \tparam Rhs Type of the rignt-hand side, a matrix expression.
|
* \tparam Rhs Type of the rignt-hand side, a matrix expression.
|
||||||
*/
|
*/
|
||||||
template<typename Lhs, typename Rhs>
|
template<typename Lhs, typename Rhs>
|
||||||
class KroneckerProductSparse : public EigenBase<KroneckerProductSparse<Lhs,Rhs> >
|
class KroneckerProductSparse : public KroneckerProductBase<KroneckerProductSparse<Lhs,Rhs> >
|
||||||
{
|
{
|
||||||
private:
|
private:
|
||||||
typedef typename internal::traits<KroneckerProductSparse>::Index Index;
|
typedef KroneckerProductBase<KroneckerProductSparse> Base;
|
||||||
|
using Base::m_A;
|
||||||
|
using Base::m_B;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/*! \brief Constructor. */
|
/*! \brief Constructor. */
|
||||||
KroneckerProductSparse(const Lhs& A, const Rhs& B)
|
KroneckerProductSparse(const Lhs& A, const Rhs& B)
|
||||||
: m_A(A), m_B(B)
|
: Base(A, B)
|
||||||
{}
|
{}
|
||||||
|
|
||||||
/*! \brief Evaluate the Kronecker tensor product. */
|
/*! \brief Evaluate the Kronecker tensor product. */
|
||||||
template<typename Dest> void evalTo(Dest& dst) const;
|
template<typename Dest> void evalTo(Dest& dst) const;
|
||||||
|
|
||||||
inline Index rows() const { return m_A.rows() * m_B.rows(); }
|
|
||||||
inline Index cols() const { return m_A.cols() * m_B.cols(); }
|
|
||||||
|
|
||||||
template<typename Scalar, int Options, typename Index>
|
|
||||||
operator SparseMatrix<Scalar, Options, Index>()
|
|
||||||
{
|
|
||||||
SparseMatrix<Scalar, Options, Index> result;
|
|
||||||
evalTo(result.derived());
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
typename Lhs::Nested m_A;
|
|
||||||
typename Rhs::Nested m_B;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename Lhs, typename Rhs>
|
template<typename Lhs, typename Rhs>
|
||||||
template<typename Dest>
|
template<typename Dest>
|
||||||
void KroneckerProduct<Lhs,Rhs>::evalTo(Dest& dst) const
|
void KroneckerProduct<Lhs,Rhs>::evalTo(Dest& dst) const
|
||||||
{
|
{
|
||||||
|
typedef typename Base::Index Index;
|
||||||
const int BlockRows = Rhs::RowsAtCompileTime,
|
const int BlockRows = Rhs::RowsAtCompileTime,
|
||||||
BlockCols = Rhs::ColsAtCompileTime;
|
BlockCols = Rhs::ColsAtCompileTime;
|
||||||
const Index Br = m_B.rows(),
|
const Index Br = m_B.rows(),
|
||||||
@ -124,9 +128,10 @@ template<typename Lhs, typename Rhs>
|
|||||||
template<typename Dest>
|
template<typename Dest>
|
||||||
void KroneckerProductSparse<Lhs,Rhs>::evalTo(Dest& dst) const
|
void KroneckerProductSparse<Lhs,Rhs>::evalTo(Dest& dst) const
|
||||||
{
|
{
|
||||||
|
typedef typename Base::Index Index;
|
||||||
const Index Br = m_B.rows(),
|
const Index Br = m_B.rows(),
|
||||||
Bc = m_B.cols();
|
Bc = m_B.cols();
|
||||||
dst.resize(rows(),cols());
|
dst.resize(this->rows(), this->cols());
|
||||||
dst.resizeNonZeros(0);
|
dst.resizeNonZeros(0);
|
||||||
dst.reserve(m_A.nonZeros() * m_B.nonZeros());
|
dst.reserve(m_A.nonZeros() * m_B.nonZeros());
|
||||||
|
|
||||||
@ -155,6 +160,7 @@ struct traits<KroneckerProduct<_Lhs,_Rhs> >
|
|||||||
typedef typename remove_all<_Lhs>::type Lhs;
|
typedef typename remove_all<_Lhs>::type Lhs;
|
||||||
typedef typename remove_all<_Rhs>::type Rhs;
|
typedef typename remove_all<_Rhs>::type Rhs;
|
||||||
typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
|
typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
|
||||||
|
typedef typename promote_index_type<typename Lhs::Index, typename Rhs::Index>::type Index;
|
||||||
|
|
||||||
enum {
|
enum {
|
||||||
Rows = size_at_compile_time<traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime>::ret,
|
Rows = size_at_compile_time<traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime>::ret,
|
||||||
@ -193,6 +199,8 @@ struct traits<KroneckerProductSparse<_Lhs,_Rhs> >
|
|||||||
| EvalBeforeNestingBit | EvalBeforeAssigningBit,
|
| EvalBeforeNestingBit | EvalBeforeAssigningBit,
|
||||||
CoeffReadCost = Dynamic
|
CoeffReadCost = Dynamic
|
||||||
};
|
};
|
||||||
|
|
||||||
|
typedef SparseMatrix<Scalar> ReturnType;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
@ -228,6 +236,16 @@ KroneckerProduct<A,B> kroneckerProduct(const MatrixBase<A>& a, const MatrixBase<
|
|||||||
* Computes Kronecker tensor product of two matrices, at least one of
|
* Computes Kronecker tensor product of two matrices, at least one of
|
||||||
* which is sparse
|
* which is sparse
|
||||||
*
|
*
|
||||||
|
* \warning If you want to replace a matrix by its Kronecker product
|
||||||
|
* with some matrix, do \b NOT do this:
|
||||||
|
* \code
|
||||||
|
* A = kroneckerProduct(A,B); // bug!!! caused by aliasing effect
|
||||||
|
* \endcode
|
||||||
|
* instead, use eval() to work around this:
|
||||||
|
* \code
|
||||||
|
* A = kroneckerProduct(A,B).eval();
|
||||||
|
* \endcode
|
||||||
|
*
|
||||||
* \param a Dense/sparse matrix a
|
* \param a Dense/sparse matrix a
|
||||||
* \param b Dense/sparse matrix b
|
* \param b Dense/sparse matrix b
|
||||||
* \return Kronecker tensor product of a and b, stored in a sparse
|
* \return Kronecker tensor product of a and b, stored in a sparse
|
||||||
|
@ -107,31 +107,34 @@ void test_kronecker_product()
|
|||||||
|
|
||||||
SparseMatrix<double,RowMajor> SM_row_a(SM_a), SM_row_b(SM_b);
|
SparseMatrix<double,RowMajor> SM_row_a(SM_a), SM_row_b(SM_b);
|
||||||
|
|
||||||
// test kroneckerProduct(DM_block,DM,DM_fixedSize)
|
// test DM_fixedSize = kroneckerProduct(DM_block,DM)
|
||||||
Matrix<double, 6, 6> DM_fix_ab = kroneckerProduct(DM_a.topLeftCorner<2,3>(),DM_b);
|
Matrix<double, 6, 6> DM_fix_ab = kroneckerProduct(DM_a.topLeftCorner<2,3>(),DM_b);
|
||||||
|
|
||||||
CALL_SUBTEST(check_kronecker_product(DM_fix_ab));
|
CALL_SUBTEST(check_kronecker_product(DM_fix_ab));
|
||||||
|
CALL_SUBTEST(check_kronecker_product(kroneckerProduct(DM_a.topLeftCorner<2,3>(),DM_b)));
|
||||||
|
|
||||||
for(int i=0;i<DM_fix_ab.rows();++i)
|
for(int i=0;i<DM_fix_ab.rows();++i)
|
||||||
for(int j=0;j<DM_fix_ab.cols();++j)
|
for(int j=0;j<DM_fix_ab.cols();++j)
|
||||||
VERIFY_IS_APPROX(kroneckerProduct(DM_a,DM_b).coeff(i,j), DM_fix_ab(i,j));
|
VERIFY_IS_APPROX(kroneckerProduct(DM_a,DM_b).coeff(i,j), DM_fix_ab(i,j));
|
||||||
|
|
||||||
// test kroneckerProduct(DM,DM,DM_block)
|
// test DM_block = kroneckerProduct(DM,DM)
|
||||||
MatrixXd DM_block_ab(10,15);
|
MatrixXd DM_block_ab(10,15);
|
||||||
DM_block_ab.block<6,6>(2,5) = kroneckerProduct(DM_a,DM_b);
|
DM_block_ab.block<6,6>(2,5) = kroneckerProduct(DM_a,DM_b);
|
||||||
CALL_SUBTEST(check_kronecker_product(DM_block_ab.block<6,6>(2,5)));
|
CALL_SUBTEST(check_kronecker_product(DM_block_ab.block<6,6>(2,5)));
|
||||||
|
|
||||||
// test kroneckerProduct(DM,DM,DM)
|
// test DM = kroneckerProduct(DM,DM)
|
||||||
MatrixXd DM_ab = kroneckerProduct(DM_a,DM_b);
|
MatrixXd DM_ab = kroneckerProduct(DM_a,DM_b);
|
||||||
CALL_SUBTEST(check_kronecker_product(DM_ab));
|
CALL_SUBTEST(check_kronecker_product(DM_ab));
|
||||||
|
CALL_SUBTEST(check_kronecker_product(kroneckerProduct(DM_a,DM_b)));
|
||||||
|
|
||||||
// test kroneckerProduct(SM,DM,SM)
|
// test SM = kroneckerProduct(SM,DM)
|
||||||
SparseMatrix<double> SM_ab = kroneckerProduct(SM_a,DM_b);
|
SparseMatrix<double> SM_ab = kroneckerProduct(SM_a,DM_b);
|
||||||
CALL_SUBTEST(check_kronecker_product(SM_ab));
|
CALL_SUBTEST(check_kronecker_product(SM_ab));
|
||||||
SparseMatrix<double,RowMajor> SM_ab2 = kroneckerProduct(SM_a,DM_b);
|
SparseMatrix<double,RowMajor> SM_ab2 = kroneckerProduct(SM_a,DM_b);
|
||||||
CALL_SUBTEST(check_kronecker_product(SM_ab2));
|
CALL_SUBTEST(check_kronecker_product(SM_ab2));
|
||||||
|
CALL_SUBTEST(check_kronecker_product(kroneckerProduct(SM_a,DM_b)));
|
||||||
|
|
||||||
// test kroneckerProduct(DM,SM,SM)
|
// test SM = kroneckerProduct(DM,SM)
|
||||||
SM_ab.setZero();
|
SM_ab.setZero();
|
||||||
SM_ab.insert(0,0)=37.0;
|
SM_ab.insert(0,0)=37.0;
|
||||||
SM_ab = kroneckerProduct(DM_a,SM_b);
|
SM_ab = kroneckerProduct(DM_a,SM_b);
|
||||||
@ -140,8 +143,9 @@ void test_kronecker_product()
|
|||||||
SM_ab2.insert(0,0)=37.0;
|
SM_ab2.insert(0,0)=37.0;
|
||||||
SM_ab2 = kroneckerProduct(DM_a,SM_b);
|
SM_ab2 = kroneckerProduct(DM_a,SM_b);
|
||||||
CALL_SUBTEST(check_kronecker_product(SM_ab2));
|
CALL_SUBTEST(check_kronecker_product(SM_ab2));
|
||||||
|
CALL_SUBTEST(check_kronecker_product(kroneckerProduct(DM_a,SM_b)));
|
||||||
|
|
||||||
// test kroneckerProduct(SM,SM,SM)
|
// test SM = kroneckerProduct(SM,SM)
|
||||||
SM_ab.resize(2,33);
|
SM_ab.resize(2,33);
|
||||||
SM_ab.insert(0,0)=37.0;
|
SM_ab.insert(0,0)=37.0;
|
||||||
SM_ab = kroneckerProduct(SM_a,SM_b);
|
SM_ab = kroneckerProduct(SM_a,SM_b);
|
||||||
@ -150,8 +154,9 @@ void test_kronecker_product()
|
|||||||
SM_ab2.insert(0,0)=37.0;
|
SM_ab2.insert(0,0)=37.0;
|
||||||
SM_ab2 = kroneckerProduct(SM_a,SM_b);
|
SM_ab2 = kroneckerProduct(SM_a,SM_b);
|
||||||
CALL_SUBTEST(check_kronecker_product(SM_ab2));
|
CALL_SUBTEST(check_kronecker_product(SM_ab2));
|
||||||
|
CALL_SUBTEST(check_kronecker_product(kroneckerProduct(SM_a,SM_b)));
|
||||||
|
|
||||||
// test kroneckerProduct(SM,SM,SM) with sparse pattern
|
// test SM = kroneckerProduct(SM,SM) with sparse pattern
|
||||||
SM_a.resize(4,5);
|
SM_a.resize(4,5);
|
||||||
SM_b.resize(3,2);
|
SM_b.resize(3,2);
|
||||||
SM_a.resizeNonZeros(0);
|
SM_a.resizeNonZeros(0);
|
||||||
@ -169,7 +174,7 @@ void test_kronecker_product()
|
|||||||
SM_ab = kroneckerProduct(SM_a,SM_b);
|
SM_ab = kroneckerProduct(SM_a,SM_b);
|
||||||
CALL_SUBTEST(check_sparse_kronecker_product(SM_ab));
|
CALL_SUBTEST(check_sparse_kronecker_product(SM_ab));
|
||||||
|
|
||||||
// test dimension of result of kroneckerProduct(DM,DM,DM)
|
// test dimension of result of DM = kroneckerProduct(DM,DM)
|
||||||
MatrixXd DM_a2(2,1);
|
MatrixXd DM_a2(2,1);
|
||||||
MatrixXd DM_b2(5,4);
|
MatrixXd DM_b2(5,4);
|
||||||
MatrixXd DM_ab2 = kroneckerProduct(DM_a2,DM_b2);
|
MatrixXd DM_ab2 = kroneckerProduct(DM_a2,DM_b2);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user