mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-24 02:29:33 +08:00
Eliminate unnecessary evaluations
This commit is contained in:
parent
7e64f78f65
commit
963794b04a
@ -162,6 +162,9 @@ template<typename Derived> class MatrixBase
|
|||||||
#ifndef EIGEN_PARSED_BY_DOXYGEN
|
#ifndef EIGEN_PARSED_BY_DOXYGEN
|
||||||
template<typename ProductDerived, typename Lhs, typename Rhs>
|
template<typename ProductDerived, typename Lhs, typename Rhs>
|
||||||
Derived& lazyAssign(const ProductBase<ProductDerived, Lhs,Rhs>& other);
|
Derived& lazyAssign(const ProductBase<ProductDerived, Lhs,Rhs>& other);
|
||||||
|
|
||||||
|
template<typename ProductDerived, typename Lhs, typename Rhs>
|
||||||
|
Derived& lazyAssign(const MatrixPowerProductBase<ProductDerived, Lhs,Rhs>& other);
|
||||||
#endif // not EIGEN_PARSED_BY_DOXYGEN
|
#endif // not EIGEN_PARSED_BY_DOXYGEN
|
||||||
|
|
||||||
template<typename OtherDerived>
|
template<typename OtherDerived>
|
||||||
|
@ -81,8 +81,8 @@ class NoAlias
|
|||||||
EIGEN_STRONG_INLINE ExpressionType& operator-=(const CoeffBasedProduct<Lhs,Rhs,NestingFlags>& other)
|
EIGEN_STRONG_INLINE ExpressionType& operator-=(const CoeffBasedProduct<Lhs,Rhs,NestingFlags>& other)
|
||||||
{ return m_expression.derived() -= CoeffBasedProduct<Lhs,Rhs,NestByRefBit>(other.lhs(), other.rhs()); }
|
{ return m_expression.derived() -= CoeffBasedProduct<Lhs,Rhs,NestByRefBit>(other.lhs(), other.rhs()); }
|
||||||
|
|
||||||
template<typename Derived>
|
template<typename Derived, typename Lhs, typename Rhs>
|
||||||
EIGEN_STRONG_INLINE ExpressionType& operator=(const MatrixPowerProductBase<Derived>& other)
|
EIGEN_STRONG_INLINE ExpressionType& operator=(const MatrixPowerProductBase<Derived,Lhs,Rhs>& other)
|
||||||
{ other.derived().evalTo(m_expression); return m_expression; }
|
{ other.derived().evalTo(m_expression); return m_expression; }
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -272,7 +272,7 @@ template<typename Derived> class MatrixFunctionReturnValue;
|
|||||||
template<typename Derived> class MatrixSquareRootReturnValue;
|
template<typename Derived> class MatrixSquareRootReturnValue;
|
||||||
template<typename Derived> class MatrixLogarithmReturnValue;
|
template<typename Derived> class MatrixLogarithmReturnValue;
|
||||||
template<typename Derived> class MatrixPowerReturnValue;
|
template<typename Derived> class MatrixPowerReturnValue;
|
||||||
template<typename Derived> class MatrixPowerProductBase;
|
template<typename Derived, typename Lhs, typename Rhs> class MatrixPowerProductBase;
|
||||||
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
template <typename Scalar>
|
template <typename Scalar>
|
||||||
|
@ -55,14 +55,14 @@ template<typename MatrixType> class MatrixPower
|
|||||||
|
|
||||||
RealScalar modfAndInit(RealScalar, RealScalar*);
|
RealScalar modfAndInit(RealScalar, RealScalar*);
|
||||||
|
|
||||||
template<typename PlainObject, typename ResultType>
|
template<typename Derived, typename ResultType>
|
||||||
void apply(const PlainObject&, ResultType&, bool&);
|
void apply(const Derived&, ResultType&, bool&);
|
||||||
|
|
||||||
template<typename ResultType>
|
template<typename ResultType>
|
||||||
void computeIntPower(ResultType&, RealScalar);
|
void computeIntPower(ResultType&, RealScalar);
|
||||||
|
|
||||||
template<typename PlainObject, typename ResultType>
|
template<typename Derived, typename ResultType>
|
||||||
void computeIntPower(const PlainObject&, ResultType&, RealScalar);
|
void computeIntPower(const Derived&, ResultType&, RealScalar);
|
||||||
|
|
||||||
template<typename ResultType>
|
template<typename ResultType>
|
||||||
void computeFracPower(ResultType&, RealScalar);
|
void computeFracPower(ResultType&, RealScalar);
|
||||||
@ -101,8 +101,8 @@ template<typename MatrixType> class MatrixPower
|
|||||||
* \param[out] res \f$ A^p b \f$, where A is specified in the
|
* \param[out] res \f$ A^p b \f$, where A is specified in the
|
||||||
* constructor.
|
* constructor.
|
||||||
*/
|
*/
|
||||||
template<typename PlainObject, typename ResultType>
|
template<typename Derived, typename ResultType>
|
||||||
void compute(const PlainObject& b, ResultType& res, RealScalar p);
|
void compute(const Derived& b, ResultType& res, RealScalar p);
|
||||||
|
|
||||||
Index rows() const { return m_A.rows(); }
|
Index rows() const { return m_A.rows(); }
|
||||||
Index cols() const { return m_A.cols(); }
|
Index cols() const { return m_A.cols(); }
|
||||||
@ -133,8 +133,8 @@ void MatrixPower<MatrixType>::compute(MatrixType& res, RealScalar p)
|
|||||||
}
|
}
|
||||||
|
|
||||||
template<typename MatrixType>
|
template<typename MatrixType>
|
||||||
template<typename PlainObject, typename ResultType>
|
template<typename Derived, typename ResultType>
|
||||||
void MatrixPower<MatrixType>::compute(const PlainObject& b, ResultType& res, RealScalar p)
|
void MatrixPower<MatrixType>::compute(const Derived& b, ResultType& res, RealScalar p)
|
||||||
{
|
{
|
||||||
switch (m_A.cols()) {
|
switch (m_A.cols()) {
|
||||||
case 0:
|
case 0:
|
||||||
@ -177,8 +177,8 @@ typename MatrixType::RealScalar MatrixPower<MatrixType>::modfAndInit(RealScalar
|
|||||||
}
|
}
|
||||||
|
|
||||||
template<typename MatrixType>
|
template<typename MatrixType>
|
||||||
template<typename PlainObject, typename ResultType>
|
template<typename Derived, typename ResultType>
|
||||||
void MatrixPower<MatrixType>::apply(const PlainObject& b, ResultType& res, bool& init)
|
void MatrixPower<MatrixType>::apply(const Derived& b, ResultType& res, bool& init)
|
||||||
{
|
{
|
||||||
if (init)
|
if (init)
|
||||||
res = m_tmp1 * res;
|
res = m_tmp1 * res;
|
||||||
@ -206,8 +206,8 @@ void MatrixPower<MatrixType>::computeIntPower(ResultType& res, RealScalar p)
|
|||||||
}
|
}
|
||||||
|
|
||||||
template<typename MatrixType>
|
template<typename MatrixType>
|
||||||
template<typename PlainObject, typename ResultType>
|
template<typename Derived, typename ResultType>
|
||||||
void MatrixPower<MatrixType>::computeIntPower(const PlainObject& b, ResultType& res, RealScalar p)
|
void MatrixPower<MatrixType>::computeIntPower(const Derived& b, ResultType& res, RealScalar p)
|
||||||
{
|
{
|
||||||
if (b.cols() >= m_A.cols()) {
|
if (b.cols() >= m_A.cols()) {
|
||||||
m_tmp2 = MatrixType::Identity(m_A.rows(),m_A.cols());
|
m_tmp2 = MatrixType::Identity(m_A.rows(),m_A.cols());
|
||||||
@ -262,14 +262,13 @@ void MatrixPower<MatrixType>::computeFracPower(ResultType& res, RealScalar p)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename MatrixType, typename PlainObject>
|
template<typename Lhs, typename Rhs>
|
||||||
class MatrixPowerMatrixProduct : public MatrixPowerProductBase<MatrixPowerMatrixProduct<MatrixType,PlainObject> >
|
class MatrixPowerMatrixProduct : public MatrixPowerProductBase<MatrixPowerMatrixProduct<Lhs,Rhs>,Lhs,Rhs>
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
typedef MatrixPowerProductBase<MatrixPowerMatrixProduct<MatrixType,PlainObject> > Base;
|
EIGEN_MATRIX_POWER_PRODUCT_PUBLIC_INTERFACE(MatrixPowerMatrixProduct)
|
||||||
EIGEN_DENSE_PUBLIC_INTERFACE(MatrixPowerMatrixProduct)
|
|
||||||
|
|
||||||
MatrixPowerMatrixProduct(MatrixPower<MatrixType>& pow, const PlainObject& b, RealScalar p)
|
MatrixPowerMatrixProduct(MatrixPower<Lhs>& pow, const Rhs& b, RealScalar p)
|
||||||
: m_pow(pow), m_b(b), m_p(p) { }
|
: m_pow(pow), m_b(b), m_p(p) { }
|
||||||
|
|
||||||
template<typename ResultType>
|
template<typename ResultType>
|
||||||
@ -280,8 +279,8 @@ class MatrixPowerMatrixProduct : public MatrixPowerProductBase<MatrixPowerMatrix
|
|||||||
Index cols() const { return m_b.cols(); }
|
Index cols() const { return m_b.cols(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
MatrixPower<MatrixType>& m_pow;
|
MatrixPower<Lhs>& m_pow;
|
||||||
const PlainObject& m_b;
|
const Rhs& m_b;
|
||||||
const RealScalar m_p;
|
const RealScalar m_p;
|
||||||
MatrixPowerMatrixProduct& operator=(const MatrixPowerMatrixProduct&);
|
MatrixPowerMatrixProduct& operator=(const MatrixPowerMatrixProduct&);
|
||||||
};
|
};
|
||||||
@ -323,7 +322,7 @@ class MatrixPowerReturnValue : public ReturnByValue<MatrixPowerReturnValue<Deriv
|
|||||||
*/
|
*/
|
||||||
template<typename ResultType>
|
template<typename ResultType>
|
||||||
inline void evalTo(ResultType& res) const
|
inline void evalTo(ResultType& res) const
|
||||||
{ MatrixPower<typename Derived::PlainObject>(m_A).compute(res, m_p); }
|
{ MatrixPower<typename Derived::PlainObject>(m_A.eval()).compute(res, m_p); }
|
||||||
|
|
||||||
Index rows() const { return m_A.rows(); }
|
Index rows() const { return m_A.rows(); }
|
||||||
Index cols() const { return m_A.cols(); }
|
Index cols() const { return m_A.cols(); }
|
||||||
@ -350,8 +349,8 @@ class MatrixPowerEvaluator
|
|||||||
{ m_pow.compute(res, m_p); }
|
{ m_pow.compute(res, m_p); }
|
||||||
|
|
||||||
template<typename Derived>
|
template<typename Derived>
|
||||||
const MatrixPowerMatrixProduct<MatrixType, typename Derived::PlainObject> operator*(const MatrixBase<Derived>& b) const
|
const MatrixPowerMatrixProduct<MatrixType, Derived> operator*(const MatrixBase<Derived>& b) const
|
||||||
{ return MatrixPowerMatrixProduct<MatrixType, typename Derived::PlainObject>(m_pow, b.derived(), m_p); }
|
{ return MatrixPowerMatrixProduct<MatrixType, Derived>(m_pow, b.derived(), m_p); }
|
||||||
|
|
||||||
Index rows() const { return m_pow.rows(); }
|
Index rows() const { return m_pow.rows(); }
|
||||||
Index cols() const { return m_pow.cols(); }
|
Index cols() const { return m_pow.cols(); }
|
||||||
@ -363,9 +362,9 @@ class MatrixPowerEvaluator
|
|||||||
};
|
};
|
||||||
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
template<typename MatrixType, typename PlainObject>
|
template<typename MatrixType, typename Derived>
|
||||||
struct nested<MatrixPowerMatrixProduct<MatrixType,PlainObject> >
|
struct nested<MatrixPowerMatrixProduct<MatrixType,Derived> >
|
||||||
{ typedef PlainObject const& type; };
|
{ typedef typename MatrixPowerMatrixProduct<MatrixType,Derived>::PlainObject const& type; };
|
||||||
|
|
||||||
template<typename Derived>
|
template<typename Derived>
|
||||||
struct traits<MatrixPowerReturnValue<Derived> >
|
struct traits<MatrixPowerReturnValue<Derived> >
|
||||||
@ -375,28 +374,10 @@ template<typename MatrixType>
|
|||||||
struct traits<MatrixPowerEvaluator<MatrixType> >
|
struct traits<MatrixPowerEvaluator<MatrixType> >
|
||||||
{ typedef MatrixType ReturnType; };
|
{ typedef MatrixType ReturnType; };
|
||||||
|
|
||||||
template<typename MatrixType, typename PlainObject>
|
template<typename Lhs, typename Rhs>
|
||||||
struct traits<MatrixPowerMatrixProduct<MatrixType,PlainObject> >
|
struct traits<MatrixPowerMatrixProduct<Lhs,Rhs> >
|
||||||
{
|
: traits<MatrixPowerProductBase<MatrixPowerMatrixProduct<Lhs,Rhs>,Lhs,Rhs> >
|
||||||
typedef MatrixXpr XprKind;
|
{ };
|
||||||
typedef typename scalar_product_traits<typename MatrixType::Scalar, typename PlainObject::Scalar>::ReturnType Scalar;
|
|
||||||
typedef typename promote_storage_type<typename traits<MatrixType>::StorageKind,
|
|
||||||
typename traits<PlainObject>::StorageKind>::ret StorageKind;
|
|
||||||
typedef typename promote_index_type<typename traits<MatrixType>::Index,
|
|
||||||
typename traits<PlainObject>::Index>::type Index;
|
|
||||||
|
|
||||||
enum {
|
|
||||||
RowsAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(traits<MatrixType>::RowsAtCompileTime,
|
|
||||||
traits<PlainObject>::RowsAtCompileTime),
|
|
||||||
ColsAtCompileTime = traits<PlainObject>::ColsAtCompileTime,
|
|
||||||
MaxRowsAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(traits<MatrixType>::MaxRowsAtCompileTime,
|
|
||||||
traits<PlainObject>::MaxRowsAtCompileTime),
|
|
||||||
MaxColsAtCompileTime = traits<PlainObject>::MaxColsAtCompileTime,
|
|
||||||
Flags = (MaxRowsAtCompileTime==1 ? RowMajorBit : 0)
|
|
||||||
| EvalBeforeNestingBit | EvalBeforeAssigningBit | NestByRefBit,
|
|
||||||
CoeffReadCost = 0
|
|
||||||
};
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename Derived>
|
template<typename Derived>
|
||||||
|
@ -29,9 +29,29 @@ struct recompose_complex_schur<0>
|
|||||||
{ res = (U * (T.template triangularView<Upper>() * U.adjoint())).real(); }
|
{ res = (U * (T.template triangularView<Upper>() * U.adjoint())).real(); }
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename Derived>
|
template<typename Derived, typename _Lhs, typename _Rhs>
|
||||||
struct traits<MatrixPowerProductBase<Derived> > : traits<Derived>
|
struct traits<MatrixPowerProductBase<Derived,_Lhs,_Rhs> >
|
||||||
{ };
|
{
|
||||||
|
typedef MatrixXpr XprKind;
|
||||||
|
typedef typename remove_all<_Lhs>::type Lhs;
|
||||||
|
typedef typename remove_all<_Rhs>::type Rhs;
|
||||||
|
typedef typename remove_all<Derived>::type PlainObject;
|
||||||
|
typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
|
||||||
|
typedef typename promote_storage_type<typename traits<Lhs>::StorageKind,
|
||||||
|
typename traits<Rhs>::StorageKind>::ret StorageKind;
|
||||||
|
typedef typename promote_index_type<typename traits<Lhs>::Index,
|
||||||
|
typename traits<Rhs>::Index>::type Index;
|
||||||
|
|
||||||
|
enum {
|
||||||
|
RowsAtCompileTime = traits<Lhs>::RowsAtCompileTime,
|
||||||
|
ColsAtCompileTime = traits<Rhs>::ColsAtCompileTime,
|
||||||
|
MaxRowsAtCompileTime = traits<Lhs>::MaxRowsAtCompileTime,
|
||||||
|
MaxColsAtCompileTime = traits<Rhs>::MaxColsAtCompileTime,
|
||||||
|
Flags = (MaxRowsAtCompileTime==1 ? RowMajorBit : 0)
|
||||||
|
| EvalBeforeNestingBit | EvalBeforeAssigningBit | NestByRefBit,
|
||||||
|
CoeffReadCost = 0
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
inline int binary_powering_cost(T p, int* squarings)
|
inline int binary_powering_cost(T p, int* squarings)
|
||||||
@ -219,13 +239,18 @@ void MatrixPowerTriangularAtomic<MatrixType,UpLo>::computeBig(MatrixType& res, R
|
|||||||
compute2x2(res, p);
|
compute2x2(res, p);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename Derived>
|
#define EIGEN_MATRIX_POWER_PRODUCT_PUBLIC_INTERFACE(Derived) \
|
||||||
|
typedef MatrixPowerProductBase<Derived, Lhs, Rhs > Base; \
|
||||||
|
EIGEN_DENSE_PUBLIC_INTERFACE(Derived)
|
||||||
|
|
||||||
|
template<typename Derived, typename Lhs, typename Rhs>
|
||||||
class MatrixPowerProductBase : public MatrixBase<Derived>
|
class MatrixPowerProductBase : public MatrixBase<Derived>
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
typedef MatrixBase<Derived> Base;
|
typedef MatrixBase<Derived> Base;
|
||||||
typedef typename Base::PlainObject PlainObject;
|
|
||||||
EIGEN_DENSE_PUBLIC_INTERFACE(MatrixPowerProductBase)
|
EIGEN_DENSE_PUBLIC_INTERFACE(MatrixPowerProductBase)
|
||||||
|
|
||||||
|
typedef typename Base::PlainObject PlainObject;
|
||||||
|
|
||||||
inline Index rows() const { return derived().rows(); }
|
inline Index rows() const { return derived().rows(); }
|
||||||
inline Index cols() const { return derived().cols(); }
|
inline Index cols() const { return derived().cols(); }
|
||||||
@ -247,6 +272,14 @@ class MatrixPowerProductBase : public MatrixBase<Derived>
|
|||||||
mutable PlainObject m_result;
|
mutable PlainObject m_result;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template<typename Derived>
|
||||||
|
template<typename ProductDerived, typename Lhs, typename Rhs>
|
||||||
|
Derived& MatrixBase<Derived>::lazyAssign(const MatrixPowerProductBase<ProductDerived,Lhs,Rhs>& other)
|
||||||
|
{
|
||||||
|
other.derived().evalTo(derived());
|
||||||
|
return derived();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace Eigen
|
} // namespace Eigen
|
||||||
|
|
||||||
#endif // EIGEN_MATRIX_POWER
|
#endif // EIGEN_MATRIX_POWER
|
||||||
|
Loading…
x
Reference in New Issue
Block a user