* change Flagged to take into account NestByValue only

* bugfix in Assign and cache friendly product (weird that it worked before)
* improved argument evaluation in Product
This commit is contained in:
Gael Guennebaud 2008-05-28 22:11:47 +00:00
parent 73084dc754
commit 8711e26c8a
4 changed files with 105 additions and 32 deletions

View File

@ -208,12 +208,12 @@ struct ei_assignment_impl<Derived, OtherDerived, true>
for ( ; index<alignedSize ; index+=ei_packet_traits<typename Derived::Scalar>::size) for ( ; index<alignedSize ; index+=ei_packet_traits<typename Derived::Scalar>::size)
{ {
// FIXME the following is not really efficient // FIXME the following is not really efficient
int i = index/dst.rows(); int i = index/dst.cols();
int j = index%dst.rows(); int j = index%dst.cols();
dst.template writePacketCoeff<Aligned>(i, j, src.template packetCoeff<Aligned>(i, j)); dst.template writePacketCoeff<Aligned>(i, j, src.template packetCoeff<Aligned>(i, j));
} }
for(int i = alignedSize/dst.rows(); i < dst.rows(); i++) for(int i = alignedSize/dst.cols(); i < dst.rows(); i++)
for(int j = alignedSize%dst.rows(); j < dst.cols(); j++) for(int j = alignedSize%dst.cols(); j < dst.cols(); j++)
dst.coeffRef(i, j) = src.coeff(i, j); dst.coeffRef(i, j) = src.coeff(i, j);
} }
else else

View File

@ -45,7 +45,7 @@ static void ei_cache_friendly_product(
rhsStride = _lhsStride; rhsStride = _lhsStride;
cols = _rows; cols = _rows;
rows = _cols; rows = _cols;
lhsRowMajor = _rhsRowMajor; lhsRowMajor = !_rhsRowMajor;
ei_assert(_lhsRowMajor); ei_assert(_lhsRowMajor);
} }
else else

View File

@ -43,6 +43,7 @@ template<typename ExpressionType, unsigned int Added, unsigned int Removed>
struct ei_traits<Flagged<ExpressionType, Added, Removed> > struct ei_traits<Flagged<ExpressionType, Added, Removed> >
{ {
typedef typename ExpressionType::Scalar Scalar; typedef typename ExpressionType::Scalar Scalar;
enum { enum {
RowsAtCompileTime = ExpressionType::RowsAtCompileTime, RowsAtCompileTime = ExpressionType::RowsAtCompileTime,
ColsAtCompileTime = ExpressionType::ColsAtCompileTime, ColsAtCompileTime = ExpressionType::ColsAtCompileTime,
@ -59,11 +60,13 @@ template<typename ExpressionType, unsigned int Added, unsigned int Removed> clas
public: public:
EIGEN_GENERIC_PUBLIC_INTERFACE(Flagged) EIGEN_GENERIC_PUBLIC_INTERFACE(Flagged)
typedef typename ei_meta_if<ei_must_nest_by_value<ExpressionType>::ret,
ExpressionType, const ExpressionType&>::ret ExpressionTypeNested;
inline Flagged(const ExpressionType& matrix) : m_matrix(matrix) {} inline Flagged(const ExpressionType& matrix) : m_matrix(matrix) {}
/** \internal */ /** \internal */
inline ExpressionType _expression() const { return m_matrix; } inline const ExpressionType& _expression() const { return m_matrix; }
private: private:
@ -94,7 +97,7 @@ template<typename ExpressionType, unsigned int Added, unsigned int Removed> clas
} }
protected: protected:
typename ExpressionType::Nested m_matrix; ExpressionTypeNested m_matrix;
}; };
/** \returns an expression of *this with added flags /** \returns an expression of *this with added flags

View File

@ -165,12 +165,10 @@ template<typename T> class ei_product_eval_to_column_major
template<typename T, int n=1> struct ei_product_nested_rhs template<typename T, int n=1> struct ei_product_nested_rhs
{ {
typedef typename ei_meta_if< typedef typename ei_meta_if<
ei_must_nest_by_value<T>::ret && (!(ei_traits<T>::Flags & RowMajorBit)) && (int(ei_traits<T>::Flags) & DirectAccessBit), ei_must_nest_by_value<T>::ret,
T, T,
typename ei_meta_if< typename ei_meta_if<
((ei_traits<T>::Flags & EvalBeforeNestingBit) ((ei_traits<T>::Flags & EvalBeforeNestingBit)
|| (ei_traits<T>::Flags & RowMajorBit)
|| (!(ei_traits<T>::Flags & DirectAccessBit))
|| (n+1) * (NumTraits<typename ei_traits<T>::Scalar>::ReadCost) < (n-1) * T::CoeffReadCost), || (n+1) * (NumTraits<typename ei_traits<T>::Scalar>::ReadCost) < (n-1) * T::CoeffReadCost),
typename ei_product_eval_to_column_major<T>::type, typename ei_product_eval_to_column_major<T>::type,
const T& const T&
@ -178,18 +176,37 @@ template<typename T, int n=1> struct ei_product_nested_rhs
>::ret type; >::ret type;
}; };
template<typename T, int n=1> struct ei_product_nested_lhs // template<typename T, int n=1> struct ei_product_nested_lhs
// {
// typedef typename ei_meta_if<
// ei_must_nest_by_value<T>::ret && (int(ei_traits<T>::Flags) & DirectAccessBit),
// T,
// typename ei_meta_if<
// int(ei_traits<T>::Flags) & EvalBeforeNestingBit
// || (!(int(ei_traits<T>::Flags) & DirectAccessBit))
// || (n+1) * int(NumTraits<typename ei_traits<T>::Scalar>::ReadCost) < (n-1) * int(T::CoeffReadCost),
// typename ei_eval<T>::type,
// const T&
// >::ret
// >::ret type;
// };
template<typename T> struct ei_product_copy_rhs
{ {
typedef typename ei_meta_if< typedef typename ei_meta_if<
ei_must_nest_by_value<T>::ret && (int(ei_traits<T>::Flags) & DirectAccessBit), (ei_traits<T>::Flags & RowMajorBit)
T, || (!(ei_traits<T>::Flags & DirectAccessBit)),
typename ei_meta_if< typename ei_product_eval_to_column_major<T>::type,
int(ei_traits<T>::Flags) & EvalBeforeNestingBit const T&
|| (!(int(ei_traits<T>::Flags) & DirectAccessBit)) >::ret type;
|| (n+1) * int(NumTraits<typename ei_traits<T>::Scalar>::ReadCost) < (n-1) * int(T::CoeffReadCost), };
template<typename T> struct ei_product_copy_lhs
{
typedef typename ei_meta_if<
(!(int(ei_traits<T>::Flags) & DirectAccessBit)),
typename ei_eval<T>::type, typename ei_eval<T>::type,
const T& const T&
>::ret
>::ret type; >::ret type;
}; };
@ -199,9 +216,9 @@ struct ei_traits<Product<Lhs, Rhs, EvalMode> >
typedef typename Lhs::Scalar Scalar; typedef typename Lhs::Scalar Scalar;
// the cache friendly product evals lhs once only // the cache friendly product evals lhs once only
// FIXME what to do if we chose to dynamically call the normal product from the cache friendly one for small matrices ? // FIXME what to do if we chose to dynamically call the normal product from the cache friendly one for small matrices ?
typedef typename ei_meta_if<EvalMode==CacheFriendlyProduct, typedef /*typename ei_meta_if<EvalMode==CacheFriendlyProduct,*/
typename ei_product_nested_lhs<Lhs,1>::type, // typename ei_product_nested_lhs<Lhs,1>::type,
typename ei_nested<Lhs,Rhs::ColsAtCompileTime>::type>::ret LhsNested; typename ei_nested<Lhs,Rhs::ColsAtCompileTime>::type/*>::ret*/ LhsNested;
// NOTE that rhs must be ColumnMajor, so we might need a special nested type calculation // NOTE that rhs must be ColumnMajor, so we might need a special nested type calculation
typedef typename ei_meta_if<EvalMode==CacheFriendlyProduct, typedef typename ei_meta_if<EvalMode==CacheFriendlyProduct,
@ -225,10 +242,9 @@ struct ei_traits<Product<Lhs, Rhs, EvalMode> >
_Vectorizable = (_LhsVectorizable || _RhsVectorizable) ? 1 : 0, _Vectorizable = (_LhsVectorizable || _RhsVectorizable) ? 1 : 0,
_RowMajor = (RhsFlags & RowMajorBit) _RowMajor = (RhsFlags & RowMajorBit)
&& (EvalMode==(int)CacheFriendlyProduct ? (int)LhsFlags & RowMajorBit : (!_LhsVectorizable)), && (EvalMode==(int)CacheFriendlyProduct ? (int)LhsFlags & RowMajorBit : (!_LhsVectorizable)),
_LostBits = HereditaryBits & ~( _LostBits = ~((_RowMajor ? 0 : RowMajorBit)
(_RowMajor ? 0 : RowMajorBit)
| ((RowsAtCompileTime == Dynamic || ColsAtCompileTime == Dynamic) ? 0 : LargeBit)), | ((RowsAtCompileTime == Dynamic || ColsAtCompileTime == Dynamic) ? 0 : LargeBit)),
Flags = ((unsigned int)(LhsFlags | RhsFlags) & _LostBits & ~NestedByValue) Flags = ((unsigned int)(LhsFlags | RhsFlags) & HereditaryBits & _LostBits)
| EvalBeforeAssigningBit | EvalBeforeAssigningBit
| EvalBeforeNestingBit | EvalBeforeNestingBit
| (_Vectorizable ? VectorizableBit : 0), | (_Vectorizable ? VectorizableBit : 0),
@ -369,6 +385,7 @@ template<typename Lhs,typename Rhs>
inline Derived& inline Derived&
MatrixBase<Derived>::operator+=(const Flagged<Product<Lhs,Rhs,CacheFriendlyProduct>, 0, EvalBeforeNestingBit | EvalBeforeAssigningBit>& other) MatrixBase<Derived>::operator+=(const Flagged<Product<Lhs,Rhs,CacheFriendlyProduct>, 0, EvalBeforeNestingBit | EvalBeforeAssigningBit>& other)
{ {
std::cout << "_cacheFriendlyEvalAndAdd\n";
other._expression()._cacheFriendlyEvalAndAdd(const_cast_derived()); other._expression()._cacheFriendlyEvalAndAdd(const_cast_derived());
return derived(); return derived();
} }
@ -396,6 +413,7 @@ struct ei_cache_friendly_selector
) )
{ {
res.setZero(); res.setZero();
// typename ei_product_copy_lhs<>::type
ei_cache_friendly_product<Scalar>( ei_cache_friendly_product<Scalar>(
product._rows(), product._cols(), product.m_lhs.cols(), product._rows(), product._cols(), product.m_lhs.cols(),
_LhsNested::Flags&RowMajorBit, &(product.m_lhs.const_cast_derived().coeffRef(0,0)), product.m_lhs.stride(), _LhsNested::Flags&RowMajorBit, &(product.m_lhs.const_cast_derived().coeffRef(0,0)), product.m_lhs.stride(),
@ -452,18 +470,70 @@ template<typename Lhs, typename Rhs, int EvalMode>
template<typename DestDerived> template<typename DestDerived>
inline void Product<Lhs,Rhs,EvalMode>::_cacheFriendlyEval(DestDerived& res) const inline void Product<Lhs,Rhs,EvalMode>::_cacheFriendlyEval(DestDerived& res) const
{ {
ei_cache_friendly_selector<Lhs,Rhs,EvalMode,DestDerived, // ei_cache_friendly_selector<Lhs,Rhs,EvalMode,DestDerived,
_LhsNested::Flags&_RhsNested::Flags&DirectAccessBit> // _LhsNested::Flags&_RhsNested::Flags&DirectAccessBit>
::eval(*this, res); // ::eval(*this, res);
if ( _rows()>=EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD
&& _cols()>=EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD
&& m_lhs.cols()>=EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD
)
{
res.setZero();
// typedef typename ei_eval<_LhsNested>::type LhsCopy;
// typedef typename ei_product_eval_to_column_major<_RhsNested>::type RhsCopy;
typedef typename ei_product_copy_lhs<_LhsNested>::type LhsCopy;
typedef typename ei_unref<LhsCopy>::type _LhsCopy;
typedef typename ei_product_copy_rhs<_RhsNested>::type RhsCopy;
typedef typename ei_unref<RhsCopy>::type _RhsCopy;
LhsCopy lhs(m_lhs);
RhsCopy rhs(m_rhs);
ei_cache_friendly_product<Scalar>(
_rows(), _cols(), lhs.cols(),
_LhsCopy::Flags&RowMajorBit, &(lhs.const_cast_derived().coeffRef(0,0)), lhs.stride(),
_RhsCopy::Flags&RowMajorBit, &(rhs.const_cast_derived().coeffRef(0,0)), rhs.stride(),
Flags&RowMajorBit, &(res.coeffRef(0,0)), res.stride()
);
}
else
{
res = Product<_LhsNested,_RhsNested,NormalProduct>(m_lhs, m_rhs).lazy();
}
} }
template<typename Lhs, typename Rhs, int EvalMode> template<typename Lhs, typename Rhs, int EvalMode>
template<typename DestDerived> template<typename DestDerived>
inline void Product<Lhs,Rhs,EvalMode>::_cacheFriendlyEvalAndAdd(DestDerived& res) const inline void Product<Lhs,Rhs,EvalMode>::_cacheFriendlyEvalAndAdd(DestDerived& res) const
{ {
ei_cache_friendly_selector<Lhs,Rhs,EvalMode,DestDerived, std::cout << "_cacheFriendlyEvalAndAdd\n";
_LhsNested::Flags&_RhsNested::Flags&DirectAccessBit> // ei_cache_friendly_selector<Lhs,Rhs,EvalMode,DestDerived,
::eval_and_add(*this, res); // _LhsNested::Flags&_RhsNested::Flags&DirectAccessBit>
// ::eval_and_add(*this, res);
if ( _rows()>=EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD
&& _cols()>=EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD
&& m_lhs.cols()>=EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD
)
{
typedef typename ei_product_copy_lhs<_LhsNested>::type LhsCopy;
typedef typename ei_unref<LhsCopy>::type _LhsCopy;
typedef typename ei_product_copy_rhs<_RhsNested>::type RhsCopy;
typedef typename ei_unref<RhsCopy>::type _RhsCopy;
LhsCopy lhs(m_lhs);
RhsCopy rhs(m_rhs);
ei_cache_friendly_product<Scalar>(
_rows(), _cols(), lhs.cols(),
_LhsCopy::Flags&RowMajorBit, &(lhs.const_cast_derived().coeffRef(0,0)), lhs.stride(),
_RhsCopy::Flags&RowMajorBit, &(rhs.const_cast_derived().coeffRef(0,0)), rhs.stride(),
Flags&RowMajorBit, &(res.coeffRef(0,0)), res.stride()
);
}
else
{
res += Product<_LhsNested,_RhsNested,NormalProduct>(m_lhs, m_rhs).lazy();
}
} }
#endif // EIGEN_PRODUCT_H #endif // EIGEN_PRODUCT_H