More clever evaluation of arguments: now it occurs in earlier, in operator*,

before the Product<> type is constructed. This resets template depth on each
intermediate evaluation, and gives simpler code. Introducing
ei_eval_if_expensive<Derived, n> which evaluates Derived if it's worth it
given that each of its coeffs will be accessed n times. Operator*
uses this with adequate values of n to evaluate args exactly when needed.
This commit is contained in:
Benoit Jacob 2008-04-03 14:17:56 +00:00
parent 4448f2620d
commit b8900d0b80
2 changed files with 21 additions and 7 deletions

View File

@ -27,6 +27,7 @@
template<typename T> struct ei_traits; template<typename T> struct ei_traits;
template<typename Lhs, typename Rhs> struct ei_product_eval_mode; template<typename Lhs, typename Rhs> struct ei_product_eval_mode;
template<typename T> struct NumTraits;
template<typename _Scalar, int _Rows, int _Cols, unsigned int _Flags, int _MaxRows, int _MaxCols> class Matrix; template<typename _Scalar, int _Rows, int _Cols, unsigned int _Flags, int _MaxRows, int _MaxCols> class Matrix;
template<typename ExpressionType> class Lazy; template<typename ExpressionType> class Lazy;
@ -89,6 +90,13 @@ template<typename T> struct ei_eval
ei_traits<T>::MaxColsAtCompileTime> type; ei_traits<T>::MaxColsAtCompileTime> type;
}; };
template<typename T, int n> struct ei_eval_if_expensive
{
enum { eval = n * NumTraits<typename T::Scalar>::ReadCost < (n-1) * T::CoeffReadCost };
typedef typename ei_meta_if<eval, typename T::Eval, T>::ret type;
typedef typename ei_meta_if<eval, typename T::Eval, T&>::ret reftype;
};
template<typename T> struct ei_eval_unless_lazy template<typename T> struct ei_eval_unless_lazy
{ {
typedef typename ei_meta_if<ei_traits<T>::Flags & LazyBit, typedef typename ei_meta_if<ei_traits<T>::Flags & LazyBit,

View File

@ -78,6 +78,7 @@ template<typename Lhs, typename Rhs, int EvalMode>
struct ei_traits<Product<Lhs, Rhs, EvalMode> > struct ei_traits<Product<Lhs, Rhs, EvalMode> >
{ {
typedef typename Lhs::Scalar Scalar; typedef typename Lhs::Scalar Scalar;
#if 0
typedef typename ei_meta_if< typedef typename ei_meta_if<
(int)NumTraits<Scalar>::ReadCost < (int)Lhs::CoeffReadCost, (int)NumTraits<Scalar>::ReadCost < (int)Lhs::CoeffReadCost,
typename Lhs::Eval, typename Lhs::Eval,
@ -95,6 +96,7 @@ struct ei_traits<Product<Lhs, Rhs, EvalMode> >
(int)NumTraits<Scalar>::ReadCost < (int)Rhs::CoeffReadCost, (int)NumTraits<Scalar>::ReadCost < (int)Rhs::CoeffReadCost,
typename Rhs::Eval, typename Rhs::Eval,
typename Rhs::XprCopy>::ret ActualRhsXprCopy; typename Rhs::XprCopy>::ret ActualRhsXprCopy;
#endif
enum { enum {
RowsAtCompileTime = Lhs::RowsAtCompileTime, RowsAtCompileTime = Lhs::RowsAtCompileTime,
ColsAtCompileTime = Rhs::ColsAtCompileTime, ColsAtCompileTime = Rhs::ColsAtCompileTime,
@ -107,7 +109,7 @@ struct ei_traits<Product<Lhs, Rhs, EvalMode> >
= Lhs::ColsAtCompileTime == Dynamic = Lhs::ColsAtCompileTime == Dynamic
? Dynamic ? Dynamic
: Lhs::ColsAtCompileTime : Lhs::ColsAtCompileTime
* (NumTraits<Scalar>::MulCost + ActualLhs::CoeffReadCost + ActualRhs::CoeffReadCost) * (NumTraits<Scalar>::MulCost + Lhs::CoeffReadCost + Rhs::CoeffReadCost)
+ (Lhs::ColsAtCompileTime - 1) * NumTraits<Scalar>::AddCost + (Lhs::ColsAtCompileTime - 1) * NumTraits<Scalar>::AddCost
}; };
}; };
@ -115,7 +117,7 @@ struct ei_traits<Product<Lhs, Rhs, EvalMode> >
template<typename Lhs, typename Rhs> struct ei_product_eval_mode template<typename Lhs, typename Rhs> struct ei_product_eval_mode
{ {
enum{ value = Lhs::MaxRowsAtCompileTime == Dynamic || Rhs::MaxColsAtCompileTime == Dynamic enum{ value = Lhs::MaxRowsAtCompileTime == Dynamic || Rhs::MaxColsAtCompileTime == Dynamic
? CacheOptimal : UnrolledDotProduct }; ? CacheOptimal : UnrolledDotProduct };
}; };
template<typename Lhs, typename Rhs, int EvalMode> class Product : ei_no_assignment_operator, template<typename Lhs, typename Rhs, int EvalMode> class Product : ei_no_assignment_operator,
@ -124,11 +126,12 @@ template<typename Lhs, typename Rhs, int EvalMode> class Product : ei_no_assignm
public: public:
EIGEN_GENERIC_PUBLIC_INTERFACE(Product) EIGEN_GENERIC_PUBLIC_INTERFACE(Product)
#if 0
typedef typename ei_traits<Product>::ActualLhs ActualLhs; typedef typename ei_traits<Product>::ActualLhs ActualLhs;
typedef typename ei_traits<Product>::ActualRhs ActualRhs; typedef typename ei_traits<Product>::ActualRhs ActualRhs;
typedef typename ei_traits<Product>::ActualLhsXprCopy ActualLhsXprCopy; typedef typename ei_traits<Product>::ActualLhsXprCopy ActualLhsXprCopy;
typedef typename ei_traits<Product>::ActualRhsXprCopy ActualRhsXprCopy; typedef typename ei_traits<Product>::ActualRhsXprCopy ActualRhsXprCopy;
#endif
Product(const Lhs& lhs, const Rhs& rhs) Product(const Lhs& lhs, const Rhs& rhs)
: m_lhs(lhs), m_rhs(rhs) : m_lhs(lhs), m_rhs(rhs)
{ {
@ -153,7 +156,7 @@ template<typename Lhs, typename Rhs, int EvalMode> class Product : ei_no_assignm
ei_product_unroller<Lhs::ColsAtCompileTime-1, ei_product_unroller<Lhs::ColsAtCompileTime-1,
Lhs::ColsAtCompileTime <= EIGEN_UNROLLING_LIMIT Lhs::ColsAtCompileTime <= EIGEN_UNROLLING_LIMIT
? Lhs::ColsAtCompileTime : Dynamic, ? Lhs::ColsAtCompileTime : Dynamic,
ActualLhs, ActualRhs> Lhs, Rhs>
::run(row, col, m_lhs, m_rhs, res); ::run(row, col, m_lhs, m_rhs, res);
else else
{ {
@ -165,8 +168,8 @@ template<typename Lhs, typename Rhs, int EvalMode> class Product : ei_no_assignm
} }
protected: protected:
const ActualLhsXprCopy m_lhs; const typename Lhs::XprCopy m_lhs;
const ActualRhsXprCopy m_rhs; const typename Rhs::XprCopy m_rhs;
}; };
/** \returns the matrix product of \c *this and \a other. /** \returns the matrix product of \c *this and \a other.
@ -181,7 +184,10 @@ template<typename OtherDerived>
const typename ei_eval_unless_lazy<Product<Derived, OtherDerived> >::type const typename ei_eval_unless_lazy<Product<Derived, OtherDerived> >::type
MatrixBase<Derived>::operator*(const MatrixBase<OtherDerived> &other) const MatrixBase<Derived>::operator*(const MatrixBase<OtherDerived> &other) const
{ {
return Product<Derived, OtherDerived>(derived(), other.derived()).eval(); typedef ei_eval_if_expensive<Derived, OtherDerived::ColsAtCompileTime> Lhs;
typedef ei_eval_if_expensive<OtherDerived, Derived::RowsAtCompileTime> Rhs;
return Product<typename Lhs::type, typename Rhs::type>
(typename Lhs::reftype(derived()), typename Rhs::reftype(other.derived())).eval();
} }
/** replaces \c *this by \c *this * \a other. /** replaces \c *this by \c *this * \a other.