Update KroneckerProduct wrt evaluator changes

This commit is contained in:
Gael Guennebaud 2014-09-18 22:08:49 +02:00
parent 62bce6e5e6
commit 2ae20d558b
2 changed files with 31 additions and 7 deletions

View File

@ -154,16 +154,41 @@ void KroneckerProductSparse<Lhs,Rhs>::evalTo(Dest& dst) const
dst.resize(this->rows(), this->cols());
dst.resizeNonZeros(0);
// 1 - evaluate the operands if needed:
typedef typename internal::nested_eval<Lhs,10>::type Lhs1;
typedef typename internal::remove_all<Lhs1>::type Lhs1Cleaned;
const Lhs1 lhs1(m_A);
typedef typename internal::nested_eval<Rhs,10>::type Rhs1;
typedef typename internal::remove_all<Rhs1>::type Rhs1Cleaned;
const Rhs1 rhs1(m_B);
// 2 - construct a SparseView for dense operands
typedef typename internal::conditional<internal::is_same<typename internal::traits<Lhs1Cleaned>::StorageKind,Sparse>::value, Lhs1, SparseView<const Lhs1Cleaned> >::type Lhs2;
typedef typename internal::remove_all<Lhs2>::type Lhs2Cleaned;
const Lhs2 lhs2(lhs1);
typedef typename internal::conditional<internal::is_same<typename internal::traits<Rhs1Cleaned>::StorageKind,Sparse>::value, Rhs1, SparseView<const Rhs1Cleaned> >::type Rhs2;
typedef typename internal::remove_all<Rhs2>::type Rhs2Cleaned;
const Rhs2 rhs2(rhs1);
// 3 - construct respective evaluators
typedef typename internal::evaluator<Lhs2Cleaned>::type LhsEval;
LhsEval lhsEval(lhs2);
typedef typename internal::evaluator<Rhs2Cleaned>::type RhsEval;
RhsEval rhsEval(rhs2);
typedef typename LhsEval::InnerIterator LhsInnerIterator;
typedef typename RhsEval::InnerIterator RhsInnerIterator;
// compute the number of non-zeros per inner vector of dst
{
VectorXi nnzA = VectorXi::Zero(Dest::IsRowMajor ? m_A.rows() : m_A.cols());
for (Index kA=0; kA < m_A.outerSize(); ++kA)
for (typename Lhs::InnerIterator itA(m_A,kA); itA; ++itA)
for (LhsInnerIterator itA(lhsEval,kA); itA; ++itA)
nnzA(Dest::IsRowMajor ? itA.row() : itA.col())++;
VectorXi nnzB = VectorXi::Zero(Dest::IsRowMajor ? m_B.rows() : m_B.cols());
for (Index kB=0; kB < m_B.outerSize(); ++kB)
for (typename Rhs::InnerIterator itB(m_B,kB); itB; ++itB)
for (RhsInnerIterator itB(rhsEval,kB); itB; ++itB)
nnzB(Dest::IsRowMajor ? itB.row() : itB.col())++;
Matrix<int,Dynamic,Dynamic,ColMajor> nnzAB = nnzB * nnzA.transpose();
@ -174,9 +199,9 @@ void KroneckerProductSparse<Lhs,Rhs>::evalTo(Dest& dst) const
{
for (Index kB=0; kB < m_B.outerSize(); ++kB)
{
for (typename Lhs::InnerIterator itA(m_A,kA); itA; ++itA)
for (LhsInnerIterator itA(lhsEval,kA); itA; ++itA)
{
for (typename Rhs::InnerIterator itB(m_B,kB); itB; ++itB)
for (RhsInnerIterator itB(rhsEval,kB); itB; ++itB)
{
const Index i = itA.row() * Br + itB.row(),
j = itA.col() * Bc + itB.col();
@ -201,8 +226,7 @@ struct traits<KroneckerProduct<_Lhs,_Rhs> >
Rows = size_at_compile_time<traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime>::ret,
Cols = size_at_compile_time<traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime>::ret,
MaxRows = size_at_compile_time<traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime>::ret,
MaxCols = size_at_compile_time<traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime>::ret,
CoeffReadCost = Lhs::CoeffReadCost + Rhs::CoeffReadCost + NumTraits<Scalar>::MulCost
MaxCols = size_at_compile_time<traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime>::ret
};
typedef Matrix<Scalar,Rows,Cols> ReturnType;

View File

@ -93,7 +93,7 @@ ei_add_test(gmres)
ei_add_test(minres)
ei_add_test(levenberg_marquardt)
ei_add_test(bdcsvd)
# TODO ei_add_test(kronecker_product)
ei_add_test(kronecker_product)
option(EIGEN_TEST_CXX11 "Enable testing of C++11 features (e.g. Tensor module)." OFF)
if(EIGEN_TEST_CXX11)