Eliminate unnecessary evaluations

This commit is contained in:
Chen-Pang He 2012-09-23 00:20:19 +08:00
parent 7e64f78f65
commit 963794b04a
5 changed files with 72 additions and 55 deletions

View File

@ -162,6 +162,9 @@ template<typename Derived> class MatrixBase
#ifndef EIGEN_PARSED_BY_DOXYGEN #ifndef EIGEN_PARSED_BY_DOXYGEN
template<typename ProductDerived, typename Lhs, typename Rhs> template<typename ProductDerived, typename Lhs, typename Rhs>
Derived& lazyAssign(const ProductBase<ProductDerived, Lhs,Rhs>& other); Derived& lazyAssign(const ProductBase<ProductDerived, Lhs,Rhs>& other);
template<typename ProductDerived, typename Lhs, typename Rhs>
Derived& lazyAssign(const MatrixPowerProductBase<ProductDerived, Lhs,Rhs>& other);
#endif // not EIGEN_PARSED_BY_DOXYGEN #endif // not EIGEN_PARSED_BY_DOXYGEN
template<typename OtherDerived> template<typename OtherDerived>

View File

@ -81,8 +81,8 @@ class NoAlias
EIGEN_STRONG_INLINE ExpressionType& operator-=(const CoeffBasedProduct<Lhs,Rhs,NestingFlags>& other) EIGEN_STRONG_INLINE ExpressionType& operator-=(const CoeffBasedProduct<Lhs,Rhs,NestingFlags>& other)
{ return m_expression.derived() -= CoeffBasedProduct<Lhs,Rhs,NestByRefBit>(other.lhs(), other.rhs()); } { return m_expression.derived() -= CoeffBasedProduct<Lhs,Rhs,NestByRefBit>(other.lhs(), other.rhs()); }
template<typename Derived> template<typename Derived, typename Lhs, typename Rhs>
EIGEN_STRONG_INLINE ExpressionType& operator=(const MatrixPowerProductBase<Derived>& other) EIGEN_STRONG_INLINE ExpressionType& operator=(const MatrixPowerProductBase<Derived,Lhs,Rhs>& other)
{ other.derived().evalTo(m_expression); return m_expression; } { other.derived().evalTo(m_expression); return m_expression; }
#endif #endif

View File

@ -272,7 +272,7 @@ template<typename Derived> class MatrixFunctionReturnValue;
template<typename Derived> class MatrixSquareRootReturnValue; template<typename Derived> class MatrixSquareRootReturnValue;
template<typename Derived> class MatrixLogarithmReturnValue; template<typename Derived> class MatrixLogarithmReturnValue;
template<typename Derived> class MatrixPowerReturnValue; template<typename Derived> class MatrixPowerReturnValue;
template<typename Derived> class MatrixPowerProductBase; template<typename Derived, typename Lhs, typename Rhs> class MatrixPowerProductBase;
namespace internal { namespace internal {
template <typename Scalar> template <typename Scalar>

View File

@ -55,14 +55,14 @@ template<typename MatrixType> class MatrixPower
RealScalar modfAndInit(RealScalar, RealScalar*); RealScalar modfAndInit(RealScalar, RealScalar*);
template<typename PlainObject, typename ResultType> template<typename Derived, typename ResultType>
void apply(const PlainObject&, ResultType&, bool&); void apply(const Derived&, ResultType&, bool&);
template<typename ResultType> template<typename ResultType>
void computeIntPower(ResultType&, RealScalar); void computeIntPower(ResultType&, RealScalar);
template<typename PlainObject, typename ResultType> template<typename Derived, typename ResultType>
void computeIntPower(const PlainObject&, ResultType&, RealScalar); void computeIntPower(const Derived&, ResultType&, RealScalar);
template<typename ResultType> template<typename ResultType>
void computeFracPower(ResultType&, RealScalar); void computeFracPower(ResultType&, RealScalar);
@ -101,8 +101,8 @@ template<typename MatrixType> class MatrixPower
* \param[out] res \f$ A^p b \f$, where A is specified in the * \param[out] res \f$ A^p b \f$, where A is specified in the
* constructor. * constructor.
*/ */
template<typename PlainObject, typename ResultType> template<typename Derived, typename ResultType>
void compute(const PlainObject& b, ResultType& res, RealScalar p); void compute(const Derived& b, ResultType& res, RealScalar p);
Index rows() const { return m_A.rows(); } Index rows() const { return m_A.rows(); }
Index cols() const { return m_A.cols(); } Index cols() const { return m_A.cols(); }
@ -133,8 +133,8 @@ void MatrixPower<MatrixType>::compute(MatrixType& res, RealScalar p)
} }
template<typename MatrixType> template<typename MatrixType>
template<typename PlainObject, typename ResultType> template<typename Derived, typename ResultType>
void MatrixPower<MatrixType>::compute(const PlainObject& b, ResultType& res, RealScalar p) void MatrixPower<MatrixType>::compute(const Derived& b, ResultType& res, RealScalar p)
{ {
switch (m_A.cols()) { switch (m_A.cols()) {
case 0: case 0:
@ -177,8 +177,8 @@ typename MatrixType::RealScalar MatrixPower<MatrixType>::modfAndInit(RealScalar
} }
template<typename MatrixType> template<typename MatrixType>
template<typename PlainObject, typename ResultType> template<typename Derived, typename ResultType>
void MatrixPower<MatrixType>::apply(const PlainObject& b, ResultType& res, bool& init) void MatrixPower<MatrixType>::apply(const Derived& b, ResultType& res, bool& init)
{ {
if (init) if (init)
res = m_tmp1 * res; res = m_tmp1 * res;
@ -206,8 +206,8 @@ void MatrixPower<MatrixType>::computeIntPower(ResultType& res, RealScalar p)
} }
template<typename MatrixType> template<typename MatrixType>
template<typename PlainObject, typename ResultType> template<typename Derived, typename ResultType>
void MatrixPower<MatrixType>::computeIntPower(const PlainObject& b, ResultType& res, RealScalar p) void MatrixPower<MatrixType>::computeIntPower(const Derived& b, ResultType& res, RealScalar p)
{ {
if (b.cols() >= m_A.cols()) { if (b.cols() >= m_A.cols()) {
m_tmp2 = MatrixType::Identity(m_A.rows(),m_A.cols()); m_tmp2 = MatrixType::Identity(m_A.rows(),m_A.cols());
@ -262,14 +262,13 @@ void MatrixPower<MatrixType>::computeFracPower(ResultType& res, RealScalar p)
} }
} }
template<typename MatrixType, typename PlainObject> template<typename Lhs, typename Rhs>
class MatrixPowerMatrixProduct : public MatrixPowerProductBase<MatrixPowerMatrixProduct<MatrixType,PlainObject> > class MatrixPowerMatrixProduct : public MatrixPowerProductBase<MatrixPowerMatrixProduct<Lhs,Rhs>,Lhs,Rhs>
{ {
public: public:
typedef MatrixPowerProductBase<MatrixPowerMatrixProduct<MatrixType,PlainObject> > Base; EIGEN_MATRIX_POWER_PRODUCT_PUBLIC_INTERFACE(MatrixPowerMatrixProduct)
EIGEN_DENSE_PUBLIC_INTERFACE(MatrixPowerMatrixProduct)
MatrixPowerMatrixProduct(MatrixPower<MatrixType>& pow, const PlainObject& b, RealScalar p) MatrixPowerMatrixProduct(MatrixPower<Lhs>& pow, const Rhs& b, RealScalar p)
: m_pow(pow), m_b(b), m_p(p) { } : m_pow(pow), m_b(b), m_p(p) { }
template<typename ResultType> template<typename ResultType>
@ -280,8 +279,8 @@ class MatrixPowerMatrixProduct : public MatrixPowerProductBase<MatrixPowerMatrix
Index cols() const { return m_b.cols(); } Index cols() const { return m_b.cols(); }
private: private:
MatrixPower<MatrixType>& m_pow; MatrixPower<Lhs>& m_pow;
const PlainObject& m_b; const Rhs& m_b;
const RealScalar m_p; const RealScalar m_p;
MatrixPowerMatrixProduct& operator=(const MatrixPowerMatrixProduct&); MatrixPowerMatrixProduct& operator=(const MatrixPowerMatrixProduct&);
}; };
@ -323,7 +322,7 @@ class MatrixPowerReturnValue : public ReturnByValue<MatrixPowerReturnValue<Deriv
*/ */
template<typename ResultType> template<typename ResultType>
inline void evalTo(ResultType& res) const inline void evalTo(ResultType& res) const
{ MatrixPower<typename Derived::PlainObject>(m_A).compute(res, m_p); } { MatrixPower<typename Derived::PlainObject>(m_A.eval()).compute(res, m_p); }
Index rows() const { return m_A.rows(); } Index rows() const { return m_A.rows(); }
Index cols() const { return m_A.cols(); } Index cols() const { return m_A.cols(); }
@ -350,8 +349,8 @@ class MatrixPowerEvaluator
{ m_pow.compute(res, m_p); } { m_pow.compute(res, m_p); }
template<typename Derived> template<typename Derived>
const MatrixPowerMatrixProduct<MatrixType, typename Derived::PlainObject> operator*(const MatrixBase<Derived>& b) const const MatrixPowerMatrixProduct<MatrixType, Derived> operator*(const MatrixBase<Derived>& b) const
{ return MatrixPowerMatrixProduct<MatrixType, typename Derived::PlainObject>(m_pow, b.derived(), m_p); } { return MatrixPowerMatrixProduct<MatrixType, Derived>(m_pow, b.derived(), m_p); }
Index rows() const { return m_pow.rows(); } Index rows() const { return m_pow.rows(); }
Index cols() const { return m_pow.cols(); } Index cols() const { return m_pow.cols(); }
@ -363,9 +362,9 @@ class MatrixPowerEvaluator
}; };
namespace internal { namespace internal {
template<typename MatrixType, typename PlainObject> template<typename MatrixType, typename Derived>
struct nested<MatrixPowerMatrixProduct<MatrixType,PlainObject> > struct nested<MatrixPowerMatrixProduct<MatrixType,Derived> >
{ typedef PlainObject const& type; }; { typedef typename MatrixPowerMatrixProduct<MatrixType,Derived>::PlainObject const& type; };
template<typename Derived> template<typename Derived>
struct traits<MatrixPowerReturnValue<Derived> > struct traits<MatrixPowerReturnValue<Derived> >
@ -375,28 +374,10 @@ template<typename MatrixType>
struct traits<MatrixPowerEvaluator<MatrixType> > struct traits<MatrixPowerEvaluator<MatrixType> >
{ typedef MatrixType ReturnType; }; { typedef MatrixType ReturnType; };
template<typename MatrixType, typename PlainObject> template<typename Lhs, typename Rhs>
struct traits<MatrixPowerMatrixProduct<MatrixType,PlainObject> > struct traits<MatrixPowerMatrixProduct<Lhs,Rhs> >
{ : traits<MatrixPowerProductBase<MatrixPowerMatrixProduct<Lhs,Rhs>,Lhs,Rhs> >
typedef MatrixXpr XprKind; { };
typedef typename scalar_product_traits<typename MatrixType::Scalar, typename PlainObject::Scalar>::ReturnType Scalar;
typedef typename promote_storage_type<typename traits<MatrixType>::StorageKind,
typename traits<PlainObject>::StorageKind>::ret StorageKind;
typedef typename promote_index_type<typename traits<MatrixType>::Index,
typename traits<PlainObject>::Index>::type Index;
enum {
RowsAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(traits<MatrixType>::RowsAtCompileTime,
traits<PlainObject>::RowsAtCompileTime),
ColsAtCompileTime = traits<PlainObject>::ColsAtCompileTime,
MaxRowsAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(traits<MatrixType>::MaxRowsAtCompileTime,
traits<PlainObject>::MaxRowsAtCompileTime),
MaxColsAtCompileTime = traits<PlainObject>::MaxColsAtCompileTime,
Flags = (MaxRowsAtCompileTime==1 ? RowMajorBit : 0)
| EvalBeforeNestingBit | EvalBeforeAssigningBit | NestByRefBit,
CoeffReadCost = 0
};
};
} }
template<typename Derived> template<typename Derived>

View File

@ -29,9 +29,29 @@ struct recompose_complex_schur<0>
{ res = (U * (T.template triangularView<Upper>() * U.adjoint())).real(); } { res = (U * (T.template triangularView<Upper>() * U.adjoint())).real(); }
}; };
template<typename Derived> template<typename Derived, typename _Lhs, typename _Rhs>
struct traits<MatrixPowerProductBase<Derived> > : traits<Derived> struct traits<MatrixPowerProductBase<Derived,_Lhs,_Rhs> >
{ }; {
typedef MatrixXpr XprKind;
typedef typename remove_all<_Lhs>::type Lhs;
typedef typename remove_all<_Rhs>::type Rhs;
typedef typename remove_all<Derived>::type PlainObject;
typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
typedef typename promote_storage_type<typename traits<Lhs>::StorageKind,
typename traits<Rhs>::StorageKind>::ret StorageKind;
typedef typename promote_index_type<typename traits<Lhs>::Index,
typename traits<Rhs>::Index>::type Index;
enum {
RowsAtCompileTime = traits<Lhs>::RowsAtCompileTime,
ColsAtCompileTime = traits<Rhs>::ColsAtCompileTime,
MaxRowsAtCompileTime = traits<Lhs>::MaxRowsAtCompileTime,
MaxColsAtCompileTime = traits<Rhs>::MaxColsAtCompileTime,
Flags = (MaxRowsAtCompileTime==1 ? RowMajorBit : 0)
| EvalBeforeNestingBit | EvalBeforeAssigningBit | NestByRefBit,
CoeffReadCost = 0
};
};
template<typename T> template<typename T>
inline int binary_powering_cost(T p, int* squarings) inline int binary_powering_cost(T p, int* squarings)
@ -219,13 +239,18 @@ void MatrixPowerTriangularAtomic<MatrixType,UpLo>::computeBig(MatrixType& res, R
compute2x2(res, p); compute2x2(res, p);
} }
template<typename Derived> #define EIGEN_MATRIX_POWER_PRODUCT_PUBLIC_INTERFACE(Derived) \
typedef MatrixPowerProductBase<Derived, Lhs, Rhs > Base; \
EIGEN_DENSE_PUBLIC_INTERFACE(Derived)
template<typename Derived, typename Lhs, typename Rhs>
class MatrixPowerProductBase : public MatrixBase<Derived> class MatrixPowerProductBase : public MatrixBase<Derived>
{ {
public: public:
typedef MatrixBase<Derived> Base; typedef MatrixBase<Derived> Base;
typedef typename Base::PlainObject PlainObject;
EIGEN_DENSE_PUBLIC_INTERFACE(MatrixPowerProductBase) EIGEN_DENSE_PUBLIC_INTERFACE(MatrixPowerProductBase)
typedef typename Base::PlainObject PlainObject;
inline Index rows() const { return derived().rows(); } inline Index rows() const { return derived().rows(); }
inline Index cols() const { return derived().cols(); } inline Index cols() const { return derived().cols(); }
@ -247,6 +272,14 @@ class MatrixPowerProductBase : public MatrixBase<Derived>
mutable PlainObject m_result; mutable PlainObject m_result;
}; };
template<typename Derived>
template<typename ProductDerived, typename Lhs, typename Rhs>
Derived& MatrixBase<Derived>::lazyAssign(const MatrixPowerProductBase<ProductDerived,Lhs,Rhs>& other)
{
other.derived().evalTo(derived());
return derived();
}
} // namespace Eigen } // namespace Eigen
#endif // EIGEN_MATRIX_POWER #endif // EIGEN_MATRIX_POWER