implement a ProductBase class and, as a proof of concept, update TriangularProduct

and SelfAdjointMatrixProduct to take advantage of it => fewer LOC
This commit is contained in:
Gael Guennebaud 2009-08-04 16:54:17 +02:00
parent f3a6bc48c4
commit 7d607048a9
9 changed files with 219 additions and 160 deletions

View File

@ -178,6 +178,7 @@ namespace Eigen {
#include "src/Core/Swap.h"
#include "src/Core/CommaInitializer.h"
#include "src/Core/util/BlasUtil.h"
#include "src/Core/ProductBase.h"
#include "src/Core/Product.h"
#include "src/Core/products/GeneralMatrixMatrix.h"
#include "src/Core/products/GeneralMatrixVector.h"

View File

@ -339,6 +339,13 @@ class Matrix
return Base::operator=(func);
}
template<typename ProductDerived, typename Lhs, typename Rhs>
EIGEN_STRONG_INLINE Matrix& operator=(const ProductBase<ProductDerived,Lhs,Rhs>& other)
{
resize(other.rows(), other.cols());
return Base::operator=(other);
}
using Base::operator +=;
using Base::operator -=;
using Base::operator *=;
@ -444,6 +451,15 @@ class Matrix
resize(other.rows(), other.cols());
other.evalTo(*this);
}
template<typename ProductDerived, typename Lhs, typename Rhs>
EIGEN_STRONG_INLINE Matrix(const ProductBase<ProductDerived,Lhs,Rhs>& other)
{
_check_template_params();
resize(other.rows(), other.cols());
other.evalTo(*this);
}
/** Destructor */
inline ~Matrix() {}

View File

@ -318,6 +318,17 @@ template<typename Derived> class MatrixBase
Derived& operator-=(const AnyMatrixBase<OtherDerived> &other)
{ other.derived().subToDense(derived()); return derived(); }
template<typename ProductDerived, typename Lhs, typename Rhs>
Derived& operator=(const ProductBase<ProductDerived, Lhs, Rhs> &other);
template<typename ProductDerived, typename Lhs, typename Rhs>
Derived& operator+=(const ProductBase<ProductDerived, Lhs, Rhs> &other);
template<typename ProductDerived, typename Lhs, typename Rhs>
Derived& operator-=(const ProductBase<ProductDerived, Lhs, Rhs> &other);
template<typename OtherDerived,typename OtherEvalType>
Derived& operator=(const ReturnByValue<OtherDerived,OtherEvalType>& func);

View File

@ -0,0 +1,153 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2009 Gael Guennebaud <g.gael@free.fr>
//
// Eigen is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License as published by the Free Software Foundation; either
// version 3 of the License, or (at your option) any later version.
//
// Alternatively, you can redistribute it and/or
// modify it under the terms of the GNU General Public License as
// published by the Free Software Foundation; either version 2 of
// the License, or (at your option) any later version.
//
// Eigen is distributed in the hope that it will be useful, but WITHOUT ANY
// WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
// FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License and a copy of the GNU General Public License along with
// Eigen. If not, see <http://www.gnu.org/licenses/>.
#ifndef EIGEN_PRODUCTBASE_H
#define EIGEN_PRODUCTBASE_H
/** \class ProductBase
*
*/
template<typename Derived, typename _Lhs, typename _Rhs>
struct ei_traits<ProductBase<Derived,_Lhs,_Rhs> >
{
typedef typename ei_cleantype<_Lhs>::type Lhs;
typedef typename ei_cleantype<_Rhs>::type Rhs;
typedef typename ei_traits<Lhs>::Scalar Scalar;
enum {
RowsAtCompileTime = ei_traits<Lhs>::RowsAtCompileTime,
ColsAtCompileTime = ei_traits<Rhs>::ColsAtCompileTime,
MaxRowsAtCompileTime = ei_traits<Lhs>::MaxRowsAtCompileTime,
MaxColsAtCompileTime = ei_traits<Rhs>::MaxColsAtCompileTime,
Flags = EvalBeforeNestingBit,
CoeffReadCost = 0 // FIXME why is it needed ?
};
};
*
// enforce evaluation before nesting
template<typename Derived, typename Lhs, typename Rhs,int N,typename EvalType>
struct ei_nested<ProductBase<Derived,Lhs,Rhs>, N, EvalType>
{
typedef EvalType type;
};
#define EIGEN_PRODUCT_PUBLIC_INTERFACE(Derived) \
typedef ProductBase<Derived, Lhs, Rhs > ProductBaseType; \
_EIGEN_GENERIC_PUBLIC_INTERFACE(Derived, ProductBaseType) \
typedef typename Base::LhsNested LhsNested; \
typedef typename Base::_LhsNested _LhsNested; \
typedef typename Base::LhsBlasTraits LhsBlasTraits; \
typedef typename Base::ActualLhsType ActualLhsType; \
typedef typename Base::_ActualLhsType _ActualLhsType; \
typedef typename Base::RhsNested RhsNested; \
typedef typename Base::_RhsNested _RhsNested; \
typedef typename Base::RhsBlasTraits RhsBlasTraits; \
typedef typename Base::ActualRhsType ActualRhsType; \
typedef typename Base::_ActualRhsType _ActualRhsType; \
using Base::m_lhs; \
using Base::m_rhs;
template<typename Derived, typename Lhs, typename Rhs>
class ProductBase : public MatrixBase<Derived>
{
public:
_EIGEN_GENERIC_PUBLIC_INTERFACE(ProductBase,MatrixBase<Derived>)
typedef typename Lhs::Nested LhsNested;
typedef typename ei_cleantype<LhsNested>::type _LhsNested;
typedef ei_blas_traits<_LhsNested> LhsBlasTraits;
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
typedef typename ei_cleantype<ActualLhsType>::type _ActualLhsType;
typedef typename Rhs::Nested RhsNested;
typedef typename ei_cleantype<RhsNested>::type _RhsNested;
typedef ei_blas_traits<_RhsNested> RhsBlasTraits;
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
typedef typename ei_cleantype<ActualRhsType>::type _ActualRhsType;
using Base::derived;
typedef typename Base::PlainMatrixType PlainMatrixType;
ProductBase(const Lhs& lhs, const Rhs& rhs)
: m_lhs(lhs), m_rhs(rhs)
{}
inline int rows() const { return m_lhs.rows(); }
inline int cols() const { return m_rhs.cols(); }
template<typename Dest>
inline void evalTo(Dest& dst) const { dst.setZero(); addTo(dst,1); }
template<typename Dest>
inline void addTo(Dest& dst) const { addTo(dst,1); }
template<typename Dest>
inline void subTo(Dest& dst) const { addTo(dst,-1); }
template<typename Dest>
inline void addTo(Dest& dst,Scalar alpha) const { derived().addTo(dst,alpha); }
PlainMatrixType eval() const
{
PlainMatrixType res(rows(), cols());
res.setZero();
evalTo(res);
return res;
}
protected:
const LhsNested m_lhs;
const RhsNested m_rhs;
private:
// discard coeff methods
void coeff(int,int) const;
void coeffRef(int,int);
void coeff(int) const;
void coeffRef(int);
};
template<typename Derived>
template<typename ProductDerived, typename Lhs, typename Rhs>
Derived& MatrixBase<Derived>::operator=(const ProductBase<ProductDerived,Lhs,Rhs>& other)
{
other.evalTo(derived()); return derived();
}
template<typename Derived>
template<typename ProductDerived, typename Lhs, typename Rhs>
Derived& MatrixBase<Derived>::operator+=(const ProductBase<ProductDerived,Lhs,Rhs>& other)
{
other.addTo(derived()); return derived();
}
template<typename Derived>
template<typename ProductDerived, typename Lhs, typename Rhs>
Derived& MatrixBase<Derived>::operator-=(const ProductBase<ProductDerived,Lhs,Rhs>& other)
{
other.subTo(derived()); return derived();
}
#endif // EIGEN_PRODUCTBASE_H

View File

@ -202,50 +202,22 @@ struct ei_triangular_assignment_selector<Derived1, Derived2, SelfAdjoint, Dynami
template<typename Lhs, int LhsMode, typename Rhs>
struct ei_traits<SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,0,true> >
: ei_traits<Matrix<typename ei_traits<Rhs>::Scalar,Lhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> >
: ei_traits<ProductBase<SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,0,true>, Lhs, Rhs> >
{};
template<typename Lhs, int LhsMode, typename Rhs>
struct SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,0,true>
: public AnyMatrixBase<SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,0,true> >
: public ProductBase<SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,0,true>, Lhs, Rhs >
{
typedef typename Lhs::Scalar Scalar;
typedef typename Lhs::Nested LhsNested;
typedef typename ei_cleantype<LhsNested>::type _LhsNested;
typedef ei_blas_traits<_LhsNested> LhsBlasTraits;
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
typedef typename ei_cleantype<ActualLhsType>::type _ActualLhsType;
typedef typename Rhs::Nested RhsNested;
typedef typename ei_cleantype<RhsNested>::type _RhsNested;
typedef ei_blas_traits<_RhsNested> RhsBlasTraits;
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
typedef typename ei_cleantype<ActualRhsType>::type _ActualRhsType;
EIGEN_PRODUCT_PUBLIC_INTERFACE(SelfadjointProductMatrix)
enum {
LhsUpLo = LhsMode&(UpperTriangularBit|LowerTriangularBit)
};
SelfadjointProductMatrix(const Lhs& lhs, const Rhs& rhs)
: m_lhs(lhs), m_rhs(rhs)
{}
SelfadjointProductMatrix(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
inline int rows() const { return m_lhs.rows(); }
inline int cols() const { return m_rhs.cols(); }
template<typename Dest> inline void addToDense(Dest& dst) const
{ evalTo(dst,1); }
template<typename Dest> inline void subToDense(Dest& dst) const
{ evalTo(dst,-1); }
template<typename Dest> void evalToDense(Dest& dst) const
{
dst.setZero();
evalTo(dst,1);
}
template<typename Dest> void evalTo(Dest& dst, Scalar alpha) const
template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
{
ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
@ -265,9 +237,6 @@ struct SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,0,true>
actualAlpha // scale factor
);
}
const LhsNested m_lhs;
const RhsNested m_rhs;
};
/***************************************************************************
@ -276,33 +245,16 @@ struct SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,0,true>
template<typename Lhs, int LhsMode, typename Rhs, int RhsMode>
struct ei_traits<SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,RhsMode,false> >
: ei_traits<Matrix<typename ei_traits<Rhs>::Scalar,Lhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> >
: ei_traits<ProductBase<SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,RhsMode,false>, Lhs, Rhs> >
{};
template<typename Lhs, int LhsMode, typename Rhs, int RhsMode>
struct SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,RhsMode,false>
: public AnyMatrixBase<SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,RhsMode,false> >
: public ProductBase<SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,RhsMode,false>, Lhs, Rhs >
{
SelfadjointProductMatrix(const Lhs& lhs, const Rhs& rhs)
: m_lhs(lhs), m_rhs(rhs)
{}
EIGEN_PRODUCT_PUBLIC_INTERFACE(SelfadjointProductMatrix)
inline int rows() const { return m_lhs.rows(); }
inline int cols() const { return m_rhs.cols(); }
typedef typename Lhs::Scalar Scalar;
typedef typename Lhs::Nested LhsNested;
typedef typename ei_cleantype<LhsNested>::type _LhsNested;
typedef ei_blas_traits<_LhsNested> LhsBlasTraits;
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
typedef typename ei_cleantype<ActualLhsType>::type _ActualLhsType;
typedef typename Rhs::Nested RhsNested;
typedef typename ei_cleantype<RhsNested>::type _RhsNested;
typedef ei_blas_traits<_RhsNested> RhsBlasTraits;
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
typedef typename ei_cleantype<ActualRhsType>::type _ActualRhsType;
SelfadjointProductMatrix(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
enum {
LhsUpLo = LhsMode&(UpperTriangularBit|LowerTriangularBit),
@ -311,18 +263,7 @@ struct SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,RhsMode,false>
RhsIsSelfAdjoint = (RhsMode&SelfAdjointBit)==SelfAdjointBit
};
template<typename Dest> inline void addToDense(Dest& dst) const
{ evalTo(dst,1); }
template<typename Dest> inline void subToDense(Dest& dst) const
{ evalTo(dst,-1); }
template<typename Dest> void evalToDense(Dest& dst) const
{
dst.setZero();
evalTo(dst,1);
}
template<typename Dest> void evalTo(Dest& dst, Scalar alpha) const
template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
{
ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
@ -348,9 +289,6 @@ struct SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,RhsMode,false>
actualAlpha // alpha
);
}
const LhsNested m_lhs;
const RhsNested m_rhs;
};
/***************************************************************************

View File

@ -322,47 +322,18 @@ struct ei_product_triangular_matrix_matrix<Scalar,Mode,false,
template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
struct ei_traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false> >
: ei_traits<Matrix<typename ei_traits<Rhs>::Scalar,Lhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> >
: ei_traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false>, Lhs, Rhs> >
{};
template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
struct TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false>
: public AnyMatrixBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false> >
: public ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false>, Lhs, Rhs >
{
TriangularProduct(const Lhs& lhs, const Rhs& rhs)
: m_lhs(lhs), m_rhs(rhs)
{}
EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct)
inline int rows() const { return m_lhs.rows(); }
inline int cols() const { return m_rhs.cols(); }
TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
typedef typename Lhs::Scalar Scalar;
typedef typename Lhs::Nested LhsNested;
typedef typename ei_cleantype<LhsNested>::type _LhsNested;
typedef ei_blas_traits<_LhsNested> LhsBlasTraits;
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
typedef typename ei_cleantype<ActualLhsType>::type _ActualLhsType;
typedef typename Rhs::Nested RhsNested;
typedef typename ei_cleantype<RhsNested>::type _RhsNested;
typedef ei_blas_traits<_RhsNested> RhsBlasTraits;
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
typedef typename ei_cleantype<ActualRhsType>::type _ActualRhsType;
template<typename Dest> inline void addToDense(Dest& dst) const
{ evalTo(dst,1); }
template<typename Dest> inline void subToDense(Dest& dst) const
{ evalTo(dst,-1); }
template<typename Dest> void evalToDense(Dest& dst) const
{
dst.resize(m_lhs.rows(), m_rhs.cols());
dst.setZero();
evalTo(dst,1);
}
template<typename Dest> void evalTo(Dest& dst, Scalar alpha) const
template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
{
const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs);
const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs);
@ -383,9 +354,6 @@ struct TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false>
actualAlpha // alpha
);
}
const LhsNested m_lhs;
const RhsNested m_rhs;
};
#endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_H

View File

@ -119,46 +119,18 @@ struct ei_product_triangular_vector_selector<Lhs,Rhs,Result,Mode,ConjLhs,ConjRhs
template<int Mode, /*bool LhsIsTriangular, */typename Lhs, typename Rhs>
struct ei_traits<TriangularProduct<Mode,true,Lhs,false,Rhs,true> >
: ei_traits<Matrix<typename ei_traits<Rhs>::Scalar,Lhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> >
: ei_traits<ProductBase<TriangularProduct<Mode,true,Lhs,false,Rhs,true>, Lhs, Rhs> >
{};
template<int Mode, /*bool LhsIsTriangular, */typename Lhs, typename Rhs>
struct TriangularProduct<Mode,true,Lhs,false,Rhs,true>
: public AnyMatrixBase<TriangularProduct<Mode,true,Lhs,false,Rhs,true> >
: public ProductBase<TriangularProduct<Mode,true,Lhs,false,Rhs,true>, Lhs, Rhs >
{
typedef typename Lhs::Scalar Scalar;
EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct)
typedef typename Lhs::Nested LhsNested;
typedef typename ei_cleantype<LhsNested>::type _LhsNested;
typedef ei_blas_traits<_LhsNested> LhsBlasTraits;
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
typedef typename ei_cleantype<ActualLhsType>::type _ActualLhsType;
TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
typedef typename Rhs::Nested RhsNested;
typedef typename ei_cleantype<RhsNested>::type _RhsNested;
typedef ei_blas_traits<_RhsNested> RhsBlasTraits;
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
typedef typename ei_cleantype<ActualRhsType>::type _ActualRhsType;
TriangularProduct(const Lhs& lhs, const Rhs& rhs)
: m_lhs(lhs), m_rhs(rhs)
{}
inline int rows() const { return m_lhs.rows(); }
inline int cols() const { return m_rhs.cols(); }
template<typename Dest> inline void addToDense(Dest& dst) const
{ evalTo(dst,1); }
template<typename Dest> inline void subToDense(Dest& dst) const
{ evalTo(dst,-1); }
template<typename Dest> void evalToDense(Dest& dst) const
{
dst.setZero();
evalTo(dst,1);
}
template<typename Dest> void evalTo(Dest& dst, Scalar alpha) const
template<typename Dest> void addTo(Dest& dst, Scalar alpha) const
{
ei_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
@ -176,9 +148,6 @@ struct TriangularProduct<Mode,true,Lhs,false,Rhs,true>
ei_traits<Lhs>::Flags&RowMajorBit>
::run(lhs,rhs,dst,actualAlpha);
}
const LhsNested m_lhs;
const RhsNested m_rhs;
};
#endif // EIGEN_TRIANGULARMATRIXVECTOR_H

View File

@ -48,6 +48,7 @@ template<typename NullaryOp, typename MatrixType> class CwiseNullaryOp;
template<typename UnaryOp, typename MatrixType> class CwiseUnaryOp;
template<typename ViewOp, typename MatrixType> class CwiseUnaryView;
template<typename BinaryOp, typename Lhs, typename Rhs> class CwiseBinaryOp;
template<typename Derived, typename Lhs, typename Rhs> class ProductBase;
template<typename Lhs, typename Rhs, int ProductMode> class Product;
template<typename Derived> class DiagonalBase;

View File

@ -70,8 +70,10 @@ template<typename MatrixType> void product_notemporary(const MatrixType& m)
VERIFY_EVALUATION_COUNT( m3 = (m1 * m2.adjoint()), 1);
VERIFY_EVALUATION_COUNT( m3 = (m1 * m2.adjoint()).lazy(), 0);
// NOTE in this case the slow product is used:
VERIFY_EVALUATION_COUNT( m3 = s1 * (m1 * m2.transpose()).lazy(), 0);
VERIFY_EVALUATION_COUNT( m3 = (s1 * m1 * s2 * m2.adjoint()).lazy(), 0);
VERIFY_EVALUATION_COUNT( m3 = (s1 * m1 * s2 * (m1*s3+m2*s2).adjoint()).lazy(), 1);
VERIFY_EVALUATION_COUNT( m3 = ((s1 * m1).adjoint() * s2 * m2).lazy(), 0);
@ -80,6 +82,7 @@ template<typename MatrixType> void product_notemporary(const MatrixType& m)
VERIFY_EVALUATION_COUNT(( m3.block(r0,r0,r1,r1) += (-m1.block(r0,c0,r1,c1) * (s2*m2.block(r0,c0,r1,c1)).adjoint()).lazy() ), 0);
VERIFY_EVALUATION_COUNT(( m3.block(r0,r0,r1,r1) -= (s1 * m1.block(r0,c0,r1,c1) * m2.block(c0,r0,c1,r1)).lazy() ), 0);
// NOTE this is because the Block expression is not handled yet by our expression analyser
VERIFY_EVALUATION_COUNT(( m3.block(r0,r0,r1,r1) = (s1 * m1.block(r0,c0,r1,c1) * (s1*m2).block(c0,r0,c1,r1)).lazy() ), 1);
@ -90,8 +93,7 @@ template<typename MatrixType> void product_notemporary(const MatrixType& m)
VERIFY_EVALUATION_COUNT( rm3.col(c0) = (s1 * m1.adjoint()).template triangularView<UnitUpperTriangular>() * (s2*m2.row(c0)).adjoint(), 0);
VERIFY_EVALUATION_COUNT( m1.template triangularView<LowerTriangular>().solveInPlace(m3), 0);
// FIXME this is because the rhs/result must be column major:
VERIFY_EVALUATION_COUNT( m1.adjoint().template triangularView<LowerTriangular>().solveInPlace(m3.transpose()), 1);
VERIFY_EVALUATION_COUNT( m1.adjoint().template triangularView<LowerTriangular>().solveInPlace(m3.transpose()), 0);
VERIFY_EVALUATION_COUNT( m3 -= (s1 * m1).adjoint().template selfadjointView<LowerTriangular>() * (-m2*s3).adjoint(), 0);
VERIFY_EVALUATION_COUNT( m3 = s2 * m2.adjoint() * (s1 * m1.adjoint()).template selfadjointView<UpperTriangular>(), 0);