mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-12 11:49:02 +08:00
Let KroneckerProduct exploits the recently introduced generic InnerIterator class.
This commit is contained in:
parent
abd3502e9e
commit
842e31cf5c
@ -157,40 +157,27 @@ void KroneckerProductSparse<Lhs,Rhs>::evalTo(Dest& dst) const
|
|||||||
dst.resizeNonZeros(0);
|
dst.resizeNonZeros(0);
|
||||||
|
|
||||||
// 1 - evaluate the operands if needed:
|
// 1 - evaluate the operands if needed:
|
||||||
typedef typename internal::nested_eval<Lhs,10>::type Lhs1;
|
typedef typename internal::nested_eval<Lhs,Dynamic>::type Lhs1;
|
||||||
typedef typename internal::remove_all<Lhs1>::type Lhs1Cleaned;
|
typedef typename internal::remove_all<Lhs1>::type Lhs1Cleaned;
|
||||||
const Lhs1 lhs1(m_A);
|
const Lhs1 lhs1(m_A);
|
||||||
typedef typename internal::nested_eval<Rhs,10>::type Rhs1;
|
typedef typename internal::nested_eval<Rhs,Dynamic>::type Rhs1;
|
||||||
typedef typename internal::remove_all<Rhs1>::type Rhs1Cleaned;
|
typedef typename internal::remove_all<Rhs1>::type Rhs1Cleaned;
|
||||||
const Rhs1 rhs1(m_B);
|
const Rhs1 rhs1(m_B);
|
||||||
|
|
||||||
// 2 - construct a SparseView for dense operands
|
// 2 - construct respective iterators
|
||||||
typedef typename internal::conditional<internal::is_same<typename internal::traits<Lhs1Cleaned>::StorageKind,Sparse>::value, Lhs1, SparseView<const Lhs1Cleaned> >::type Lhs2;
|
typedef InnerIterator<Lhs1Cleaned> LhsInnerIterator;
|
||||||
typedef typename internal::remove_all<Lhs2>::type Lhs2Cleaned;
|
typedef InnerIterator<Rhs1Cleaned> RhsInnerIterator;
|
||||||
const Lhs2 lhs2(lhs1);
|
|
||||||
typedef typename internal::conditional<internal::is_same<typename internal::traits<Rhs1Cleaned>::StorageKind,Sparse>::value, Rhs1, SparseView<const Rhs1Cleaned> >::type Rhs2;
|
|
||||||
typedef typename internal::remove_all<Rhs2>::type Rhs2Cleaned;
|
|
||||||
const Rhs2 rhs2(rhs1);
|
|
||||||
|
|
||||||
// 3 - construct respective evaluators
|
|
||||||
typedef typename internal::evaluator<Lhs2Cleaned>::type LhsEval;
|
|
||||||
LhsEval lhsEval(lhs2);
|
|
||||||
typedef typename internal::evaluator<Rhs2Cleaned>::type RhsEval;
|
|
||||||
RhsEval rhsEval(rhs2);
|
|
||||||
|
|
||||||
typedef typename LhsEval::InnerIterator LhsInnerIterator;
|
|
||||||
typedef typename RhsEval::InnerIterator RhsInnerIterator;
|
|
||||||
|
|
||||||
// compute number of non-zeros per innervectors of dst
|
// compute number of non-zeros per innervectors of dst
|
||||||
{
|
{
|
||||||
VectorXi nnzA = VectorXi::Zero(Dest::IsRowMajor ? m_A.rows() : m_A.cols());
|
VectorXi nnzA = VectorXi::Zero(Dest::IsRowMajor ? m_A.rows() : m_A.cols());
|
||||||
for (typename Lhs::Index kA=0; kA < m_A.outerSize(); ++kA)
|
for (typename Lhs::Index kA=0; kA < m_A.outerSize(); ++kA)
|
||||||
for (LhsInnerIterator itA(lhsEval,kA); itA; ++itA)
|
for (LhsInnerIterator itA(lhs1,kA); itA; ++itA)
|
||||||
nnzA(Dest::IsRowMajor ? itA.row() : itA.col())++;
|
nnzA(Dest::IsRowMajor ? itA.row() : itA.col())++;
|
||||||
|
|
||||||
VectorXi nnzB = VectorXi::Zero(Dest::IsRowMajor ? m_B.rows() : m_B.cols());
|
VectorXi nnzB = VectorXi::Zero(Dest::IsRowMajor ? m_B.rows() : m_B.cols());
|
||||||
for (typename Rhs::Index kB=0; kB < m_B.outerSize(); ++kB)
|
for (typename Rhs::Index kB=0; kB < m_B.outerSize(); ++kB)
|
||||||
for (RhsInnerIterator itB(rhsEval,kB); itB; ++itB)
|
for (RhsInnerIterator itB(rhs1,kB); itB; ++itB)
|
||||||
nnzB(Dest::IsRowMajor ? itB.row() : itB.col())++;
|
nnzB(Dest::IsRowMajor ? itB.row() : itB.col())++;
|
||||||
|
|
||||||
Matrix<int,Dynamic,Dynamic,ColMajor> nnzAB = nnzB * nnzA.transpose();
|
Matrix<int,Dynamic,Dynamic,ColMajor> nnzAB = nnzB * nnzA.transpose();
|
||||||
@ -201,9 +188,9 @@ void KroneckerProductSparse<Lhs,Rhs>::evalTo(Dest& dst) const
|
|||||||
{
|
{
|
||||||
for (typename Rhs::Index kB=0; kB < m_B.outerSize(); ++kB)
|
for (typename Rhs::Index kB=0; kB < m_B.outerSize(); ++kB)
|
||||||
{
|
{
|
||||||
for (LhsInnerIterator itA(lhsEval,kA); itA; ++itA)
|
for (LhsInnerIterator itA(lhs1,kA); itA; ++itA)
|
||||||
{
|
{
|
||||||
for (RhsInnerIterator itB(rhsEval,kB); itB; ++itB)
|
for (RhsInnerIterator itB(rhs1,kB); itB; ++itB)
|
||||||
{
|
{
|
||||||
const DestIndex
|
const DestIndex
|
||||||
i = DestIndex(itA.row() * Br + itB.row()),
|
i = DestIndex(itA.row() * Br + itB.row()),
|
||||||
|
@ -216,5 +216,17 @@ void test_kronecker_product()
|
|||||||
sC2 = kroneckerProduct(sA,sB);
|
sC2 = kroneckerProduct(sA,sB);
|
||||||
dC = kroneckerProduct(dA,dB);
|
dC = kroneckerProduct(dA,dB);
|
||||||
VERIFY_IS_APPROX(MatrixXf(sC2),dC);
|
VERIFY_IS_APPROX(MatrixXf(sC2),dC);
|
||||||
|
|
||||||
|
sC2 = kroneckerProduct(dA,sB);
|
||||||
|
dC = kroneckerProduct(dA,dB);
|
||||||
|
VERIFY_IS_APPROX(MatrixXf(sC2),dC);
|
||||||
|
|
||||||
|
sC2 = kroneckerProduct(sA,dB);
|
||||||
|
dC = kroneckerProduct(dA,dB);
|
||||||
|
VERIFY_IS_APPROX(MatrixXf(sC2),dC);
|
||||||
|
|
||||||
|
sC2 = kroneckerProduct(2*sA,sB);
|
||||||
|
dC = kroneckerProduct(2*dA,dB);
|
||||||
|
VERIFY_IS_APPROX(MatrixXf(sC2),dC);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user