mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-23 18:19:34 +08:00
Let KroneckerProduct inherit ReturnByValue to eliminate temporary evaluation. It's uncommon to store the product back to one of the operands.
This commit is contained in:
parent
8284e7134b
commit
0508a0620b
@ -165,9 +165,6 @@ template<typename Derived> class MatrixBase
|
|||||||
|
|
||||||
template<typename ProductDerived, typename Lhs, typename Rhs>
|
template<typename ProductDerived, typename Lhs, typename Rhs>
|
||||||
Derived& lazyAssign(const MatrixPowerProductBase<ProductDerived, Lhs,Rhs>& other);
|
Derived& lazyAssign(const MatrixPowerProductBase<ProductDerived, Lhs,Rhs>& other);
|
||||||
|
|
||||||
template<typename Lhs, typename Rhs>
|
|
||||||
Derived& lazyAssign(const KroneckerProduct<Lhs,Rhs>& other);
|
|
||||||
#endif // not EIGEN_PARSED_BY_DOXYGEN
|
#endif // not EIGEN_PARSED_BY_DOXYGEN
|
||||||
|
|
||||||
template<typename OtherDerived>
|
template<typename OtherDerived>
|
||||||
|
@ -283,7 +283,6 @@ struct stem_function
|
|||||||
}
|
}
|
||||||
|
|
||||||
// KroneckerProduct module
|
// KroneckerProduct module
|
||||||
template<typename Lhs, typename Rhs> class KroneckerProduct;
|
|
||||||
template<typename Lhs, typename Rhs> class KroneckerProductSparse;
|
template<typename Lhs, typename Rhs> class KroneckerProductSparse;
|
||||||
|
|
||||||
#ifdef EIGEN2_SUPPORT
|
#ifdef EIGEN2_SUPPORT
|
||||||
|
@ -18,59 +18,6 @@
|
|||||||
|
|
||||||
namespace Eigen {
|
namespace Eigen {
|
||||||
|
|
||||||
namespace internal {
|
|
||||||
|
|
||||||
template<typename _Lhs, typename _Rhs>
|
|
||||||
struct traits<KroneckerProduct<_Lhs,_Rhs> >
|
|
||||||
{
|
|
||||||
typedef MatrixXpr XprKind;
|
|
||||||
typedef typename remove_all<_Lhs>::type Lhs;
|
|
||||||
typedef typename remove_all<_Rhs>::type Rhs;
|
|
||||||
typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
|
|
||||||
typedef Dense StorageKind;
|
|
||||||
typedef typename promote_index_type<typename Lhs::Index, typename Rhs::Index>::type Index;
|
|
||||||
|
|
||||||
enum {
|
|
||||||
RowsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime),
|
|
||||||
ColsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime),
|
|
||||||
MaxRowsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime),
|
|
||||||
MaxColsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime),
|
|
||||||
Flags = (MaxRowsAtCompileTime==1 ? RowMajorBit : 0)
|
|
||||||
| EvalBeforeNestingBit | EvalBeforeAssigningBit | NestByRefBit,
|
|
||||||
CoeffReadCost = Lhs::CoeffReadCost + Rhs::CoeffReadCost + NumTraits<Scalar>::MulCost
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
template<typename _Lhs, typename _Rhs>
|
|
||||||
struct traits<KroneckerProductSparse<_Lhs,_Rhs> >
|
|
||||||
{
|
|
||||||
typedef MatrixXpr XprKind;
|
|
||||||
typedef typename remove_all<_Lhs>::type Lhs;
|
|
||||||
typedef typename remove_all<_Rhs>::type Rhs;
|
|
||||||
typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
|
|
||||||
typedef Sparse StorageKind;
|
|
||||||
typedef typename promote_index_type<typename Lhs::Index, typename Rhs::Index>::type Index;
|
|
||||||
|
|
||||||
enum {
|
|
||||||
LhsFlags = Lhs::Flags,
|
|
||||||
RhsFlags = Rhs::Flags,
|
|
||||||
|
|
||||||
RowsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime),
|
|
||||||
ColsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime),
|
|
||||||
MaxRowsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime),
|
|
||||||
MaxColsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime),
|
|
||||||
|
|
||||||
EvalToRowMajor = (LhsFlags & RhsFlags & RowMajorBit),
|
|
||||||
RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit),
|
|
||||||
|
|
||||||
Flags = ((LhsFlags | RhsFlags) & HereditaryBits & RemovedBits)
|
|
||||||
| EvalBeforeNestingBit | EvalBeforeAssigningBit,
|
|
||||||
CoeffReadCost = Dynamic
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
} // end namespace internal
|
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief Kronecker tensor product helper class for dense matrices
|
* \brief Kronecker tensor product helper class for dense matrices
|
||||||
*
|
*
|
||||||
@ -82,12 +29,14 @@ struct traits<KroneckerProductSparse<_Lhs,_Rhs> >
|
|||||||
* \tparam Rhs Type of the rignt-hand side, a matrix expression.
|
* \tparam Rhs Type of the rignt-hand side, a matrix expression.
|
||||||
*/
|
*/
|
||||||
template<typename Lhs, typename Rhs>
|
template<typename Lhs, typename Rhs>
|
||||||
class KroneckerProduct : public MatrixBase<KroneckerProduct<Lhs,Rhs> >
|
class KroneckerProduct : public ReturnByValue<KroneckerProduct<Lhs,Rhs> >
|
||||||
{
|
{
|
||||||
public:
|
private:
|
||||||
typedef MatrixBase<KroneckerProduct> Base;
|
typedef ReturnByValue<KroneckerProduct> Base;
|
||||||
EIGEN_DENSE_PUBLIC_INTERFACE(KroneckerProduct)
|
typedef typename Base::Scalar Scalar;
|
||||||
|
typedef typename Base::Index Index;
|
||||||
|
|
||||||
|
public:
|
||||||
/*! \brief Constructor. */
|
/*! \brief Constructor. */
|
||||||
KroneckerProduct(const Lhs& A, const Rhs& B)
|
KroneckerProduct(const Lhs& A, const Rhs& B)
|
||||||
: m_A(A), m_B(B)
|
: m_A(A), m_B(B)
|
||||||
@ -99,13 +48,13 @@ class KroneckerProduct : public MatrixBase<KroneckerProduct<Lhs,Rhs> >
|
|||||||
inline Index rows() const { return m_A.rows() * m_B.rows(); }
|
inline Index rows() const { return m_A.rows() * m_B.rows(); }
|
||||||
inline Index cols() const { return m_A.cols() * m_B.cols(); }
|
inline Index cols() const { return m_A.cols() * m_B.cols(); }
|
||||||
|
|
||||||
typename Base::CoeffReturnType coeff(Index row, Index col) const
|
Scalar coeff(Index row, Index col) const
|
||||||
{
|
{
|
||||||
return m_A.coeff(row / m_A.cols(), col / m_A.rows()) *
|
return m_A.coeff(row / m_A.cols(), col / m_A.rows()) *
|
||||||
m_B.coeff(row % m_A.cols(), col % m_A.rows());
|
m_B.coeff(row % m_A.cols(), col % m_A.rows());
|
||||||
}
|
}
|
||||||
|
|
||||||
typename Base::CoeffReturnType coeff(Index i) const
|
Scalar coeff(Index i) const
|
||||||
{
|
{
|
||||||
EIGEN_STATIC_ASSERT_VECTOR_ONLY(KroneckerProduct);
|
EIGEN_STATIC_ASSERT_VECTOR_ONLY(KroneckerProduct);
|
||||||
return m_A.coeff(i / m_A.size()) * m_B.coeff(i % m_A.size());
|
return m_A.coeff(i / m_A.size()) * m_B.coeff(i % m_A.size());
|
||||||
@ -198,9 +147,71 @@ void KroneckerProductSparse<Lhs,Rhs>::evalTo(Dest& dst) const
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace internal {
|
||||||
|
|
||||||
|
template<typename _Lhs, typename _Rhs>
|
||||||
|
struct traits<KroneckerProduct<_Lhs,_Rhs> >
|
||||||
|
{
|
||||||
|
typedef typename remove_all<_Lhs>::type Lhs;
|
||||||
|
typedef typename remove_all<_Rhs>::type Rhs;
|
||||||
|
typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
|
||||||
|
|
||||||
|
enum {
|
||||||
|
Rows = EIGEN_SIZE_PRODUCT(traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime),
|
||||||
|
Cols = EIGEN_SIZE_PRODUCT(traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime),
|
||||||
|
MaxRows = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime),
|
||||||
|
MaxCols = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime),
|
||||||
|
CoeffReadCost = Lhs::CoeffReadCost + Rhs::CoeffReadCost + NumTraits<Scalar>::MulCost
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef Matrix<Scalar,Rows,Cols> ReturnType;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename _Lhs, typename _Rhs>
|
||||||
|
struct traits<KroneckerProductSparse<_Lhs,_Rhs> >
|
||||||
|
{
|
||||||
|
typedef MatrixXpr XprKind;
|
||||||
|
typedef typename remove_all<_Lhs>::type Lhs;
|
||||||
|
typedef typename remove_all<_Rhs>::type Rhs;
|
||||||
|
typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
|
||||||
|
typedef Sparse StorageKind;
|
||||||
|
typedef typename promote_index_type<typename Lhs::Index, typename Rhs::Index>::type Index;
|
||||||
|
|
||||||
|
enum {
|
||||||
|
LhsFlags = Lhs::Flags,
|
||||||
|
RhsFlags = Rhs::Flags,
|
||||||
|
|
||||||
|
RowsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime),
|
||||||
|
ColsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime),
|
||||||
|
MaxRowsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime),
|
||||||
|
MaxColsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime),
|
||||||
|
|
||||||
|
EvalToRowMajor = (LhsFlags & RhsFlags & RowMajorBit),
|
||||||
|
RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit),
|
||||||
|
|
||||||
|
Flags = ((LhsFlags | RhsFlags) & HereditaryBits & RemovedBits)
|
||||||
|
| EvalBeforeNestingBit | EvalBeforeAssigningBit,
|
||||||
|
CoeffReadCost = Dynamic
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end namespace internal
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
|
* \ingroup KroneckerProduct_Module
|
||||||
|
*
|
||||||
* Computes Kronecker tensor product of two dense matrices
|
* Computes Kronecker tensor product of two dense matrices
|
||||||
*
|
*
|
||||||
|
* \warning If you want to replace a matrix by its Kronecker product
|
||||||
|
* with some matrix, do \b NOT do this:
|
||||||
|
* \code
|
||||||
|
* A = kroneckerProduct(A,B); // bug!!! caused by aliasing effect
|
||||||
|
* \endcode
|
||||||
|
* instead, use eval() to work around this:
|
||||||
|
* \code
|
||||||
|
* A = kroneckerProduct(A,B).eval();
|
||||||
|
* \endcode
|
||||||
|
*
|
||||||
* \param a Dense matrix a
|
* \param a Dense matrix a
|
||||||
* \param b Dense matrix b
|
* \param b Dense matrix b
|
||||||
* \return Kronecker tensor product of a and b
|
* \return Kronecker tensor product of a and b
|
||||||
@ -212,8 +223,10 @@ KroneckerProduct<A,B> kroneckerProduct(const MatrixBase<A>& a, const MatrixBase<
|
|||||||
}
|
}
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
|
* \ingroup KroneckerProduct_Module
|
||||||
|
*
|
||||||
* Computes Kronecker tensor product of two matrices, at least one of
|
* Computes Kronecker tensor product of two matrices, at least one of
|
||||||
* which is sparse.
|
* which is sparse
|
||||||
*
|
*
|
||||||
* \param a Dense/sparse matrix a
|
* \param a Dense/sparse matrix a
|
||||||
* \param b Dense/sparse matrix b
|
* \param b Dense/sparse matrix b
|
||||||
@ -226,14 +239,6 @@ KroneckerProductSparse<A,B> kroneckerProduct(const EigenBase<A>& a, const EigenB
|
|||||||
return KroneckerProductSparse<A,B>(a.derived(), b.derived());
|
return KroneckerProductSparse<A,B>(a.derived(), b.derived());
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename Derived>
|
|
||||||
template<typename Lhs, typename Rhs>
|
|
||||||
Derived& MatrixBase<Derived>::lazyAssign(const KroneckerProduct<Lhs,Rhs>& other)
|
|
||||||
{
|
|
||||||
other.evalTo(derived());
|
|
||||||
return derived();
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename Derived>
|
template<typename Derived>
|
||||||
template<typename Lhs, typename Rhs>
|
template<typename Lhs, typename Rhs>
|
||||||
Derived& SparseMatrixBase<Derived>::operator=(const KroneckerProductSparse<Lhs,Rhs>& product)
|
Derived& SparseMatrixBase<Derived>::operator=(const KroneckerProductSparse<Lhs,Rhs>& product)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user