mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-13 04:09:10 +08:00
Update KroneckerProduct wrt evaluator changes
This commit is contained in:
parent
62bce6e5e6
commit
2ae20d558b
@ -154,16 +154,41 @@ void KroneckerProductSparse<Lhs,Rhs>::evalTo(Dest& dst) const
|
|||||||
dst.resize(this->rows(), this->cols());
|
dst.resize(this->rows(), this->cols());
|
||||||
dst.resizeNonZeros(0);
|
dst.resizeNonZeros(0);
|
||||||
|
|
||||||
|
// 1 - evaluate the operands if needed:
|
||||||
|
typedef typename internal::nested_eval<Lhs,10>::type Lhs1;
|
||||||
|
typedef typename internal::remove_all<Lhs1>::type Lhs1Cleaned;
|
||||||
|
const Lhs1 lhs1(m_A);
|
||||||
|
typedef typename internal::nested_eval<Rhs,10>::type Rhs1;
|
||||||
|
typedef typename internal::remove_all<Rhs1>::type Rhs1Cleaned;
|
||||||
|
const Rhs1 rhs1(m_B);
|
||||||
|
|
||||||
|
// 2 - construct a SparseView for dense operands
|
||||||
|
typedef typename internal::conditional<internal::is_same<typename internal::traits<Lhs1Cleaned>::StorageKind,Sparse>::value, Lhs1, SparseView<const Lhs1Cleaned> >::type Lhs2;
|
||||||
|
typedef typename internal::remove_all<Lhs2>::type Lhs2Cleaned;
|
||||||
|
const Lhs2 lhs2(lhs1);
|
||||||
|
typedef typename internal::conditional<internal::is_same<typename internal::traits<Rhs1Cleaned>::StorageKind,Sparse>::value, Rhs1, SparseView<const Rhs1Cleaned> >::type Rhs2;
|
||||||
|
typedef typename internal::remove_all<Rhs2>::type Rhs2Cleaned;
|
||||||
|
const Rhs2 rhs2(rhs1);
|
||||||
|
|
||||||
|
// 3 - construct respective evaluators
|
||||||
|
typedef typename internal::evaluator<Lhs2Cleaned>::type LhsEval;
|
||||||
|
LhsEval lhsEval(lhs2);
|
||||||
|
typedef typename internal::evaluator<Rhs2Cleaned>::type RhsEval;
|
||||||
|
RhsEval rhsEval(rhs2);
|
||||||
|
|
||||||
|
typedef typename LhsEval::InnerIterator LhsInnerIterator;
|
||||||
|
typedef typename RhsEval::InnerIterator RhsInnerIterator;
|
||||||
|
|
||||||
// compute number of non-zeros per innervectors of dst
|
// compute number of non-zeros per innervectors of dst
|
||||||
{
|
{
|
||||||
VectorXi nnzA = VectorXi::Zero(Dest::IsRowMajor ? m_A.rows() : m_A.cols());
|
VectorXi nnzA = VectorXi::Zero(Dest::IsRowMajor ? m_A.rows() : m_A.cols());
|
||||||
for (Index kA=0; kA < m_A.outerSize(); ++kA)
|
for (Index kA=0; kA < m_A.outerSize(); ++kA)
|
||||||
for (typename Lhs::InnerIterator itA(m_A,kA); itA; ++itA)
|
for (LhsInnerIterator itA(lhsEval,kA); itA; ++itA)
|
||||||
nnzA(Dest::IsRowMajor ? itA.row() : itA.col())++;
|
nnzA(Dest::IsRowMajor ? itA.row() : itA.col())++;
|
||||||
|
|
||||||
VectorXi nnzB = VectorXi::Zero(Dest::IsRowMajor ? m_B.rows() : m_B.cols());
|
VectorXi nnzB = VectorXi::Zero(Dest::IsRowMajor ? m_B.rows() : m_B.cols());
|
||||||
for (Index kB=0; kB < m_B.outerSize(); ++kB)
|
for (Index kB=0; kB < m_B.outerSize(); ++kB)
|
||||||
for (typename Rhs::InnerIterator itB(m_B,kB); itB; ++itB)
|
for (RhsInnerIterator itB(rhsEval,kB); itB; ++itB)
|
||||||
nnzB(Dest::IsRowMajor ? itB.row() : itB.col())++;
|
nnzB(Dest::IsRowMajor ? itB.row() : itB.col())++;
|
||||||
|
|
||||||
Matrix<int,Dynamic,Dynamic,ColMajor> nnzAB = nnzB * nnzA.transpose();
|
Matrix<int,Dynamic,Dynamic,ColMajor> nnzAB = nnzB * nnzA.transpose();
|
||||||
@ -174,9 +199,9 @@ void KroneckerProductSparse<Lhs,Rhs>::evalTo(Dest& dst) const
|
|||||||
{
|
{
|
||||||
for (Index kB=0; kB < m_B.outerSize(); ++kB)
|
for (Index kB=0; kB < m_B.outerSize(); ++kB)
|
||||||
{
|
{
|
||||||
for (typename Lhs::InnerIterator itA(m_A,kA); itA; ++itA)
|
for (LhsInnerIterator itA(lhsEval,kA); itA; ++itA)
|
||||||
{
|
{
|
||||||
for (typename Rhs::InnerIterator itB(m_B,kB); itB; ++itB)
|
for (RhsInnerIterator itB(rhsEval,kB); itB; ++itB)
|
||||||
{
|
{
|
||||||
const Index i = itA.row() * Br + itB.row(),
|
const Index i = itA.row() * Br + itB.row(),
|
||||||
j = itA.col() * Bc + itB.col();
|
j = itA.col() * Bc + itB.col();
|
||||||
@ -201,8 +226,7 @@ struct traits<KroneckerProduct<_Lhs,_Rhs> >
|
|||||||
Rows = size_at_compile_time<traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime>::ret,
|
Rows = size_at_compile_time<traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime>::ret,
|
||||||
Cols = size_at_compile_time<traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime>::ret,
|
Cols = size_at_compile_time<traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime>::ret,
|
||||||
MaxRows = size_at_compile_time<traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime>::ret,
|
MaxRows = size_at_compile_time<traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime>::ret,
|
||||||
MaxCols = size_at_compile_time<traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime>::ret,
|
MaxCols = size_at_compile_time<traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime>::ret
|
||||||
CoeffReadCost = Lhs::CoeffReadCost + Rhs::CoeffReadCost + NumTraits<Scalar>::MulCost
|
|
||||||
};
|
};
|
||||||
|
|
||||||
typedef Matrix<Scalar,Rows,Cols> ReturnType;
|
typedef Matrix<Scalar,Rows,Cols> ReturnType;
|
||||||
|
@ -93,7 +93,7 @@ ei_add_test(gmres)
|
|||||||
ei_add_test(minres)
|
ei_add_test(minres)
|
||||||
ei_add_test(levenberg_marquardt)
|
ei_add_test(levenberg_marquardt)
|
||||||
ei_add_test(bdcsvd)
|
ei_add_test(bdcsvd)
|
||||||
# TODO ei_add_test(kronecker_product)
|
ei_add_test(kronecker_product)
|
||||||
|
|
||||||
option(EIGEN_TEST_CXX11 "Enable testing of C++11 features (e.g. Tensor module)." OFF)
|
option(EIGEN_TEST_CXX11 "Enable testing of C++11 features (e.g. Tensor module)." OFF)
|
||||||
if(EIGEN_TEST_CXX11)
|
if(EIGEN_TEST_CXX11)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user