mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-11 19:29:02 +08:00
Eliminate unnecessary copying for sparse Kronecker product.
This commit is contained in:
parent
9be658f701
commit
4b780553e0
@ -14,7 +14,42 @@
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
template<typename Scalar, int Options, typename Index> class SparseMatrix;
|
||||
template<typename Derived>
|
||||
class KroneckerProductBase : public ReturnByValue<Derived>
|
||||
{
|
||||
private:
|
||||
typedef typename internal::traits<Derived> Traits;
|
||||
typedef typename Traits::Lhs Lhs;
|
||||
typedef typename Traits::Rhs Rhs;
|
||||
typedef typename Traits::Scalar Scalar;
|
||||
|
||||
protected:
|
||||
typedef typename Traits::Index Index;
|
||||
|
||||
public:
|
||||
KroneckerProductBase(const Lhs& A, const Rhs& B)
|
||||
: m_A(A), m_B(B)
|
||||
{}
|
||||
|
||||
inline Index rows() const { return m_A.rows() * m_B.rows(); }
|
||||
inline Index cols() const { return m_A.cols() * m_B.cols(); }
|
||||
|
||||
Scalar coeff(Index row, Index col) const
|
||||
{
|
||||
return m_A.coeff(row / m_B.rows(), col / m_B.cols()) *
|
||||
m_B.coeff(row % m_B.rows(), col % m_B.cols());
|
||||
}
|
||||
|
||||
Scalar coeff(Index i) const
|
||||
{
|
||||
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived);
|
||||
return m_A.coeff(i / m_A.size()) * m_B.coeff(i % m_A.size());
|
||||
}
|
||||
|
||||
protected:
|
||||
typename Lhs::Nested m_A;
|
||||
typename Rhs::Nested m_B;
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Kronecker tensor product helper class for dense matrices
|
||||
@ -27,40 +62,21 @@ template<typename Scalar, int Options, typename Index> class SparseMatrix;
|
||||
* \tparam Rhs Type of the rignt-hand side, a matrix expression.
|
||||
*/
|
||||
template<typename Lhs, typename Rhs>
|
||||
class KroneckerProduct : public ReturnByValue<KroneckerProduct<Lhs,Rhs> >
|
||||
class KroneckerProduct : public KroneckerProductBase<KroneckerProduct<Lhs,Rhs> >
|
||||
{
|
||||
private:
|
||||
typedef ReturnByValue<KroneckerProduct> Base;
|
||||
typedef typename Base::Scalar Scalar;
|
||||
typedef typename Base::Index Index;
|
||||
typedef KroneckerProductBase<KroneckerProduct> Base;
|
||||
using Base::m_A;
|
||||
using Base::m_B;
|
||||
|
||||
public:
|
||||
/*! \brief Constructor. */
|
||||
KroneckerProduct(const Lhs& A, const Rhs& B)
|
||||
: m_A(A), m_B(B)
|
||||
: Base(A, B)
|
||||
{}
|
||||
|
||||
/*! \brief Evaluate the Kronecker tensor product. */
|
||||
template<typename Dest> void evalTo(Dest& dst) const;
|
||||
|
||||
inline Index rows() const { return m_A.rows() * m_B.rows(); }
|
||||
inline Index cols() const { return m_A.cols() * m_B.cols(); }
|
||||
|
||||
Scalar coeff(Index row, Index col) const
|
||||
{
|
||||
return m_A.coeff(row / m_B.rows(), col / m_B.cols()) *
|
||||
m_B.coeff(row % m_B.rows(), col % m_B.cols());
|
||||
}
|
||||
|
||||
Scalar coeff(Index i) const
|
||||
{
|
||||
EIGEN_STATIC_ASSERT_VECTOR_ONLY(KroneckerProduct);
|
||||
return m_A.coeff(i / m_A.size()) * m_B.coeff(i % m_A.size());
|
||||
}
|
||||
|
||||
private:
|
||||
typename Lhs::Nested m_A;
|
||||
typename Rhs::Nested m_B;
|
||||
};
|
||||
|
||||
/*!
|
||||
@ -77,40 +93,28 @@ class KroneckerProduct : public ReturnByValue<KroneckerProduct<Lhs,Rhs> >
|
||||
* \tparam Rhs Type of the rignt-hand side, a matrix expression.
|
||||
*/
|
||||
template<typename Lhs, typename Rhs>
|
||||
class KroneckerProductSparse : public EigenBase<KroneckerProductSparse<Lhs,Rhs> >
|
||||
class KroneckerProductSparse : public KroneckerProductBase<KroneckerProductSparse<Lhs,Rhs> >
|
||||
{
|
||||
private:
|
||||
typedef typename internal::traits<KroneckerProductSparse>::Index Index;
|
||||
typedef KroneckerProductBase<KroneckerProductSparse> Base;
|
||||
using Base::m_A;
|
||||
using Base::m_B;
|
||||
|
||||
public:
|
||||
/*! \brief Constructor. */
|
||||
KroneckerProductSparse(const Lhs& A, const Rhs& B)
|
||||
: m_A(A), m_B(B)
|
||||
: Base(A, B)
|
||||
{}
|
||||
|
||||
/*! \brief Evaluate the Kronecker tensor product. */
|
||||
template<typename Dest> void evalTo(Dest& dst) const;
|
||||
|
||||
inline Index rows() const { return m_A.rows() * m_B.rows(); }
|
||||
inline Index cols() const { return m_A.cols() * m_B.cols(); }
|
||||
|
||||
template<typename Scalar, int Options, typename Index>
|
||||
operator SparseMatrix<Scalar, Options, Index>()
|
||||
{
|
||||
SparseMatrix<Scalar, Options, Index> result;
|
||||
evalTo(result.derived());
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
typename Lhs::Nested m_A;
|
||||
typename Rhs::Nested m_B;
|
||||
};
|
||||
|
||||
template<typename Lhs, typename Rhs>
|
||||
template<typename Dest>
|
||||
void KroneckerProduct<Lhs,Rhs>::evalTo(Dest& dst) const
|
||||
{
|
||||
typedef typename Base::Index Index;
|
||||
const int BlockRows = Rhs::RowsAtCompileTime,
|
||||
BlockCols = Rhs::ColsAtCompileTime;
|
||||
const Index Br = m_B.rows(),
|
||||
@ -124,9 +128,10 @@ template<typename Lhs, typename Rhs>
|
||||
template<typename Dest>
|
||||
void KroneckerProductSparse<Lhs,Rhs>::evalTo(Dest& dst) const
|
||||
{
|
||||
typedef typename Base::Index Index;
|
||||
const Index Br = m_B.rows(),
|
||||
Bc = m_B.cols();
|
||||
dst.resize(rows(),cols());
|
||||
dst.resize(this->rows(), this->cols());
|
||||
dst.resizeNonZeros(0);
|
||||
dst.reserve(m_A.nonZeros() * m_B.nonZeros());
|
||||
|
||||
@ -155,6 +160,7 @@ struct traits<KroneckerProduct<_Lhs,_Rhs> >
|
||||
typedef typename remove_all<_Lhs>::type Lhs;
|
||||
typedef typename remove_all<_Rhs>::type Rhs;
|
||||
typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
|
||||
typedef typename promote_index_type<typename Lhs::Index, typename Rhs::Index>::type Index;
|
||||
|
||||
enum {
|
||||
Rows = size_at_compile_time<traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime>::ret,
|
||||
@ -193,6 +199,8 @@ struct traits<KroneckerProductSparse<_Lhs,_Rhs> >
|
||||
| EvalBeforeNestingBit | EvalBeforeAssigningBit,
|
||||
CoeffReadCost = Dynamic
|
||||
};
|
||||
|
||||
typedef SparseMatrix<Scalar> ReturnType;
|
||||
};
|
||||
|
||||
} // end namespace internal
|
||||
@ -228,6 +236,16 @@ KroneckerProduct<A,B> kroneckerProduct(const MatrixBase<A>& a, const MatrixBase<
|
||||
* Computes Kronecker tensor product of two matrices, at least one of
|
||||
* which is sparse
|
||||
*
|
||||
* \warning If you want to replace a matrix by its Kronecker product
|
||||
* with some matrix, do \b NOT do this:
|
||||
* \code
|
||||
* A = kroneckerProduct(A,B); // bug!!! caused by aliasing effect
|
||||
* \endcode
|
||||
* instead, use eval() to work around this:
|
||||
* \code
|
||||
* A = kroneckerProduct(A,B).eval();
|
||||
* \endcode
|
||||
*
|
||||
* \param a Dense/sparse matrix a
|
||||
* \param b Dense/sparse matrix b
|
||||
* \return Kronecker tensor product of a and b, stored in a sparse
|
||||
|
@ -107,31 +107,34 @@ void test_kronecker_product()
|
||||
|
||||
SparseMatrix<double,RowMajor> SM_row_a(SM_a), SM_row_b(SM_b);
|
||||
|
||||
// test kroneckerProduct(DM_block,DM,DM_fixedSize)
|
||||
// test DM_fixedSize = kroneckerProduct(DM_block,DM)
|
||||
Matrix<double, 6, 6> DM_fix_ab = kroneckerProduct(DM_a.topLeftCorner<2,3>(),DM_b);
|
||||
|
||||
CALL_SUBTEST(check_kronecker_product(DM_fix_ab));
|
||||
CALL_SUBTEST(check_kronecker_product(kroneckerProduct(DM_a.topLeftCorner<2,3>(),DM_b)));
|
||||
|
||||
for(int i=0;i<DM_fix_ab.rows();++i)
|
||||
for(int j=0;j<DM_fix_ab.cols();++j)
|
||||
VERIFY_IS_APPROX(kroneckerProduct(DM_a,DM_b).coeff(i,j), DM_fix_ab(i,j));
|
||||
|
||||
// test kroneckerProduct(DM,DM,DM_block)
|
||||
// test DM_block = kroneckerProduct(DM,DM)
|
||||
MatrixXd DM_block_ab(10,15);
|
||||
DM_block_ab.block<6,6>(2,5) = kroneckerProduct(DM_a,DM_b);
|
||||
CALL_SUBTEST(check_kronecker_product(DM_block_ab.block<6,6>(2,5)));
|
||||
|
||||
// test kroneckerProduct(DM,DM,DM)
|
||||
// test DM = kroneckerProduct(DM,DM)
|
||||
MatrixXd DM_ab = kroneckerProduct(DM_a,DM_b);
|
||||
CALL_SUBTEST(check_kronecker_product(DM_ab));
|
||||
CALL_SUBTEST(check_kronecker_product(kroneckerProduct(DM_a,DM_b)));
|
||||
|
||||
// test kroneckerProduct(SM,DM,SM)
|
||||
// test SM = kroneckerProduct(SM,DM)
|
||||
SparseMatrix<double> SM_ab = kroneckerProduct(SM_a,DM_b);
|
||||
CALL_SUBTEST(check_kronecker_product(SM_ab));
|
||||
SparseMatrix<double,RowMajor> SM_ab2 = kroneckerProduct(SM_a,DM_b);
|
||||
CALL_SUBTEST(check_kronecker_product(SM_ab2));
|
||||
CALL_SUBTEST(check_kronecker_product(kroneckerProduct(SM_a,DM_b)));
|
||||
|
||||
// test kroneckerProduct(DM,SM,SM)
|
||||
// test SM = kroneckerProduct(DM,SM)
|
||||
SM_ab.setZero();
|
||||
SM_ab.insert(0,0)=37.0;
|
||||
SM_ab = kroneckerProduct(DM_a,SM_b);
|
||||
@ -140,8 +143,9 @@ void test_kronecker_product()
|
||||
SM_ab2.insert(0,0)=37.0;
|
||||
SM_ab2 = kroneckerProduct(DM_a,SM_b);
|
||||
CALL_SUBTEST(check_kronecker_product(SM_ab2));
|
||||
CALL_SUBTEST(check_kronecker_product(kroneckerProduct(DM_a,SM_b)));
|
||||
|
||||
// test kroneckerProduct(SM,SM,SM)
|
||||
// test SM = kroneckerProduct(SM,SM)
|
||||
SM_ab.resize(2,33);
|
||||
SM_ab.insert(0,0)=37.0;
|
||||
SM_ab = kroneckerProduct(SM_a,SM_b);
|
||||
@ -150,8 +154,9 @@ void test_kronecker_product()
|
||||
SM_ab2.insert(0,0)=37.0;
|
||||
SM_ab2 = kroneckerProduct(SM_a,SM_b);
|
||||
CALL_SUBTEST(check_kronecker_product(SM_ab2));
|
||||
CALL_SUBTEST(check_kronecker_product(kroneckerProduct(SM_a,SM_b)));
|
||||
|
||||
// test kroneckerProduct(SM,SM,SM) with sparse pattern
|
||||
// test SM = kroneckerProduct(SM,SM) with sparse pattern
|
||||
SM_a.resize(4,5);
|
||||
SM_b.resize(3,2);
|
||||
SM_a.resizeNonZeros(0);
|
||||
@ -169,7 +174,7 @@ void test_kronecker_product()
|
||||
SM_ab = kroneckerProduct(SM_a,SM_b);
|
||||
CALL_SUBTEST(check_sparse_kronecker_product(SM_ab));
|
||||
|
||||
// test dimension of result of kroneckerProduct(DM,DM,DM)
|
||||
// test dimension of result of DM = kroneckerProduct(DM,DM)
|
||||
MatrixXd DM_a2(2,1);
|
||||
MatrixXd DM_b2(5,4);
|
||||
MatrixXd DM_ab2 = kroneckerProduct(DM_a2,DM_b2);
|
||||
|
Loading…
x
Reference in New Issue
Block a user