mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-23 18:19:34 +08:00
Eliminate unnecessary evaluations
This commit is contained in:
parent
7e64f78f65
commit
963794b04a
@ -162,6 +162,9 @@ template<typename Derived> class MatrixBase
|
||||
#ifndef EIGEN_PARSED_BY_DOXYGEN
|
||||
template<typename ProductDerived, typename Lhs, typename Rhs>
|
||||
Derived& lazyAssign(const ProductBase<ProductDerived, Lhs,Rhs>& other);
|
||||
|
||||
template<typename ProductDerived, typename Lhs, typename Rhs>
|
||||
Derived& lazyAssign(const MatrixPowerProductBase<ProductDerived, Lhs,Rhs>& other);
|
||||
#endif // not EIGEN_PARSED_BY_DOXYGEN
|
||||
|
||||
template<typename OtherDerived>
|
||||
|
@ -81,8 +81,8 @@ class NoAlias
|
||||
EIGEN_STRONG_INLINE ExpressionType& operator-=(const CoeffBasedProduct<Lhs,Rhs,NestingFlags>& other)
|
||||
{ return m_expression.derived() -= CoeffBasedProduct<Lhs,Rhs,NestByRefBit>(other.lhs(), other.rhs()); }
|
||||
|
||||
template<typename Derived>
|
||||
EIGEN_STRONG_INLINE ExpressionType& operator=(const MatrixPowerProductBase<Derived>& other)
|
||||
template<typename Derived, typename Lhs, typename Rhs>
|
||||
EIGEN_STRONG_INLINE ExpressionType& operator=(const MatrixPowerProductBase<Derived,Lhs,Rhs>& other)
|
||||
{ other.derived().evalTo(m_expression); return m_expression; }
|
||||
#endif
|
||||
|
||||
|
@ -272,7 +272,7 @@ template<typename Derived> class MatrixFunctionReturnValue;
|
||||
template<typename Derived> class MatrixSquareRootReturnValue;
|
||||
template<typename Derived> class MatrixLogarithmReturnValue;
|
||||
template<typename Derived> class MatrixPowerReturnValue;
|
||||
template<typename Derived> class MatrixPowerProductBase;
|
||||
template<typename Derived, typename Lhs, typename Rhs> class MatrixPowerProductBase;
|
||||
|
||||
namespace internal {
|
||||
template <typename Scalar>
|
||||
|
@ -55,14 +55,14 @@ template<typename MatrixType> class MatrixPower
|
||||
|
||||
RealScalar modfAndInit(RealScalar, RealScalar*);
|
||||
|
||||
template<typename PlainObject, typename ResultType>
|
||||
void apply(const PlainObject&, ResultType&, bool&);
|
||||
template<typename Derived, typename ResultType>
|
||||
void apply(const Derived&, ResultType&, bool&);
|
||||
|
||||
template<typename ResultType>
|
||||
void computeIntPower(ResultType&, RealScalar);
|
||||
|
||||
template<typename PlainObject, typename ResultType>
|
||||
void computeIntPower(const PlainObject&, ResultType&, RealScalar);
|
||||
template<typename Derived, typename ResultType>
|
||||
void computeIntPower(const Derived&, ResultType&, RealScalar);
|
||||
|
||||
template<typename ResultType>
|
||||
void computeFracPower(ResultType&, RealScalar);
|
||||
@ -101,8 +101,8 @@ template<typename MatrixType> class MatrixPower
|
||||
* \param[out] res \f$ A^p b \f$, where A is specified in the
|
||||
* constructor.
|
||||
*/
|
||||
template<typename PlainObject, typename ResultType>
|
||||
void compute(const PlainObject& b, ResultType& res, RealScalar p);
|
||||
template<typename Derived, typename ResultType>
|
||||
void compute(const Derived& b, ResultType& res, RealScalar p);
|
||||
|
||||
Index rows() const { return m_A.rows(); }
|
||||
Index cols() const { return m_A.cols(); }
|
||||
@ -133,8 +133,8 @@ void MatrixPower<MatrixType>::compute(MatrixType& res, RealScalar p)
|
||||
}
|
||||
|
||||
template<typename MatrixType>
|
||||
template<typename PlainObject, typename ResultType>
|
||||
void MatrixPower<MatrixType>::compute(const PlainObject& b, ResultType& res, RealScalar p)
|
||||
template<typename Derived, typename ResultType>
|
||||
void MatrixPower<MatrixType>::compute(const Derived& b, ResultType& res, RealScalar p)
|
||||
{
|
||||
switch (m_A.cols()) {
|
||||
case 0:
|
||||
@ -177,8 +177,8 @@ typename MatrixType::RealScalar MatrixPower<MatrixType>::modfAndInit(RealScalar
|
||||
}
|
||||
|
||||
template<typename MatrixType>
|
||||
template<typename PlainObject, typename ResultType>
|
||||
void MatrixPower<MatrixType>::apply(const PlainObject& b, ResultType& res, bool& init)
|
||||
template<typename Derived, typename ResultType>
|
||||
void MatrixPower<MatrixType>::apply(const Derived& b, ResultType& res, bool& init)
|
||||
{
|
||||
if (init)
|
||||
res = m_tmp1 * res;
|
||||
@ -206,8 +206,8 @@ void MatrixPower<MatrixType>::computeIntPower(ResultType& res, RealScalar p)
|
||||
}
|
||||
|
||||
template<typename MatrixType>
|
||||
template<typename PlainObject, typename ResultType>
|
||||
void MatrixPower<MatrixType>::computeIntPower(const PlainObject& b, ResultType& res, RealScalar p)
|
||||
template<typename Derived, typename ResultType>
|
||||
void MatrixPower<MatrixType>::computeIntPower(const Derived& b, ResultType& res, RealScalar p)
|
||||
{
|
||||
if (b.cols() >= m_A.cols()) {
|
||||
m_tmp2 = MatrixType::Identity(m_A.rows(),m_A.cols());
|
||||
@ -262,14 +262,13 @@ void MatrixPower<MatrixType>::computeFracPower(ResultType& res, RealScalar p)
|
||||
}
|
||||
}
|
||||
|
||||
template<typename MatrixType, typename PlainObject>
|
||||
class MatrixPowerMatrixProduct : public MatrixPowerProductBase<MatrixPowerMatrixProduct<MatrixType,PlainObject> >
|
||||
template<typename Lhs, typename Rhs>
|
||||
class MatrixPowerMatrixProduct : public MatrixPowerProductBase<MatrixPowerMatrixProduct<Lhs,Rhs>,Lhs,Rhs>
|
||||
{
|
||||
public:
|
||||
typedef MatrixPowerProductBase<MatrixPowerMatrixProduct<MatrixType,PlainObject> > Base;
|
||||
EIGEN_DENSE_PUBLIC_INTERFACE(MatrixPowerMatrixProduct)
|
||||
EIGEN_MATRIX_POWER_PRODUCT_PUBLIC_INTERFACE(MatrixPowerMatrixProduct)
|
||||
|
||||
MatrixPowerMatrixProduct(MatrixPower<MatrixType>& pow, const PlainObject& b, RealScalar p)
|
||||
MatrixPowerMatrixProduct(MatrixPower<Lhs>& pow, const Rhs& b, RealScalar p)
|
||||
: m_pow(pow), m_b(b), m_p(p) { }
|
||||
|
||||
template<typename ResultType>
|
||||
@ -280,8 +279,8 @@ class MatrixPowerMatrixProduct : public MatrixPowerProductBase<MatrixPowerMatrix
|
||||
Index cols() const { return m_b.cols(); }
|
||||
|
||||
private:
|
||||
MatrixPower<MatrixType>& m_pow;
|
||||
const PlainObject& m_b;
|
||||
MatrixPower<Lhs>& m_pow;
|
||||
const Rhs& m_b;
|
||||
const RealScalar m_p;
|
||||
MatrixPowerMatrixProduct& operator=(const MatrixPowerMatrixProduct&);
|
||||
};
|
||||
@ -323,7 +322,7 @@ class MatrixPowerReturnValue : public ReturnByValue<MatrixPowerReturnValue<Deriv
|
||||
*/
|
||||
template<typename ResultType>
|
||||
inline void evalTo(ResultType& res) const
|
||||
{ MatrixPower<typename Derived::PlainObject>(m_A).compute(res, m_p); }
|
||||
{ MatrixPower<typename Derived::PlainObject>(m_A.eval()).compute(res, m_p); }
|
||||
|
||||
Index rows() const { return m_A.rows(); }
|
||||
Index cols() const { return m_A.cols(); }
|
||||
@ -350,8 +349,8 @@ class MatrixPowerEvaluator
|
||||
{ m_pow.compute(res, m_p); }
|
||||
|
||||
template<typename Derived>
|
||||
const MatrixPowerMatrixProduct<MatrixType, typename Derived::PlainObject> operator*(const MatrixBase<Derived>& b) const
|
||||
{ return MatrixPowerMatrixProduct<MatrixType, typename Derived::PlainObject>(m_pow, b.derived(), m_p); }
|
||||
const MatrixPowerMatrixProduct<MatrixType, Derived> operator*(const MatrixBase<Derived>& b) const
|
||||
{ return MatrixPowerMatrixProduct<MatrixType, Derived>(m_pow, b.derived(), m_p); }
|
||||
|
||||
Index rows() const { return m_pow.rows(); }
|
||||
Index cols() const { return m_pow.cols(); }
|
||||
@ -363,9 +362,9 @@ class MatrixPowerEvaluator
|
||||
};
|
||||
|
||||
namespace internal {
|
||||
template<typename MatrixType, typename PlainObject>
|
||||
struct nested<MatrixPowerMatrixProduct<MatrixType,PlainObject> >
|
||||
{ typedef PlainObject const& type; };
|
||||
template<typename MatrixType, typename Derived>
|
||||
struct nested<MatrixPowerMatrixProduct<MatrixType,Derived> >
|
||||
{ typedef typename MatrixPowerMatrixProduct<MatrixType,Derived>::PlainObject const& type; };
|
||||
|
||||
template<typename Derived>
|
||||
struct traits<MatrixPowerReturnValue<Derived> >
|
||||
@ -375,28 +374,10 @@ template<typename MatrixType>
|
||||
struct traits<MatrixPowerEvaluator<MatrixType> >
|
||||
{ typedef MatrixType ReturnType; };
|
||||
|
||||
template<typename MatrixType, typename PlainObject>
|
||||
struct traits<MatrixPowerMatrixProduct<MatrixType,PlainObject> >
|
||||
{
|
||||
typedef MatrixXpr XprKind;
|
||||
typedef typename scalar_product_traits<typename MatrixType::Scalar, typename PlainObject::Scalar>::ReturnType Scalar;
|
||||
typedef typename promote_storage_type<typename traits<MatrixType>::StorageKind,
|
||||
typename traits<PlainObject>::StorageKind>::ret StorageKind;
|
||||
typedef typename promote_index_type<typename traits<MatrixType>::Index,
|
||||
typename traits<PlainObject>::Index>::type Index;
|
||||
|
||||
enum {
|
||||
RowsAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(traits<MatrixType>::RowsAtCompileTime,
|
||||
traits<PlainObject>::RowsAtCompileTime),
|
||||
ColsAtCompileTime = traits<PlainObject>::ColsAtCompileTime,
|
||||
MaxRowsAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(traits<MatrixType>::MaxRowsAtCompileTime,
|
||||
traits<PlainObject>::MaxRowsAtCompileTime),
|
||||
MaxColsAtCompileTime = traits<PlainObject>::MaxColsAtCompileTime,
|
||||
Flags = (MaxRowsAtCompileTime==1 ? RowMajorBit : 0)
|
||||
| EvalBeforeNestingBit | EvalBeforeAssigningBit | NestByRefBit,
|
||||
CoeffReadCost = 0
|
||||
};
|
||||
};
|
||||
template<typename Lhs, typename Rhs>
|
||||
struct traits<MatrixPowerMatrixProduct<Lhs,Rhs> >
|
||||
: traits<MatrixPowerProductBase<MatrixPowerMatrixProduct<Lhs,Rhs>,Lhs,Rhs> >
|
||||
{ };
|
||||
}
|
||||
|
||||
template<typename Derived>
|
||||
|
@ -29,9 +29,29 @@ struct recompose_complex_schur<0>
|
||||
{ res = (U * (T.template triangularView<Upper>() * U.adjoint())).real(); }
|
||||
};
|
||||
|
||||
template<typename Derived>
|
||||
struct traits<MatrixPowerProductBase<Derived> > : traits<Derived>
|
||||
{ };
|
||||
template<typename Derived, typename _Lhs, typename _Rhs>
|
||||
struct traits<MatrixPowerProductBase<Derived,_Lhs,_Rhs> >
|
||||
{
|
||||
typedef MatrixXpr XprKind;
|
||||
typedef typename remove_all<_Lhs>::type Lhs;
|
||||
typedef typename remove_all<_Rhs>::type Rhs;
|
||||
typedef typename remove_all<Derived>::type PlainObject;
|
||||
typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
|
||||
typedef typename promote_storage_type<typename traits<Lhs>::StorageKind,
|
||||
typename traits<Rhs>::StorageKind>::ret StorageKind;
|
||||
typedef typename promote_index_type<typename traits<Lhs>::Index,
|
||||
typename traits<Rhs>::Index>::type Index;
|
||||
|
||||
enum {
|
||||
RowsAtCompileTime = traits<Lhs>::RowsAtCompileTime,
|
||||
ColsAtCompileTime = traits<Rhs>::ColsAtCompileTime,
|
||||
MaxRowsAtCompileTime = traits<Lhs>::MaxRowsAtCompileTime,
|
||||
MaxColsAtCompileTime = traits<Rhs>::MaxColsAtCompileTime,
|
||||
Flags = (MaxRowsAtCompileTime==1 ? RowMajorBit : 0)
|
||||
| EvalBeforeNestingBit | EvalBeforeAssigningBit | NestByRefBit,
|
||||
CoeffReadCost = 0
|
||||
};
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
inline int binary_powering_cost(T p, int* squarings)
|
||||
@ -219,13 +239,18 @@ void MatrixPowerTriangularAtomic<MatrixType,UpLo>::computeBig(MatrixType& res, R
|
||||
compute2x2(res, p);
|
||||
}
|
||||
|
||||
template<typename Derived>
|
||||
#define EIGEN_MATRIX_POWER_PRODUCT_PUBLIC_INTERFACE(Derived) \
|
||||
typedef MatrixPowerProductBase<Derived, Lhs, Rhs > Base; \
|
||||
EIGEN_DENSE_PUBLIC_INTERFACE(Derived)
|
||||
|
||||
template<typename Derived, typename Lhs, typename Rhs>
|
||||
class MatrixPowerProductBase : public MatrixBase<Derived>
|
||||
{
|
||||
public:
|
||||
typedef MatrixBase<Derived> Base;
|
||||
typedef typename Base::PlainObject PlainObject;
|
||||
EIGEN_DENSE_PUBLIC_INTERFACE(MatrixPowerProductBase)
|
||||
|
||||
typedef typename Base::PlainObject PlainObject;
|
||||
|
||||
inline Index rows() const { return derived().rows(); }
|
||||
inline Index cols() const { return derived().cols(); }
|
||||
@ -247,6 +272,14 @@ class MatrixPowerProductBase : public MatrixBase<Derived>
|
||||
mutable PlainObject m_result;
|
||||
};
|
||||
|
||||
template<typename Derived>
|
||||
template<typename ProductDerived, typename Lhs, typename Rhs>
|
||||
Derived& MatrixBase<Derived>::lazyAssign(const MatrixPowerProductBase<ProductDerived,Lhs,Rhs>& other)
|
||||
{
|
||||
other.derived().evalTo(derived());
|
||||
return derived();
|
||||
}
|
||||
|
||||
} // namespace Eigen
|
||||
|
||||
#endif // EIGEN_MATRIX_POWER
|
||||
|
Loading…
x
Reference in New Issue
Block a user