Make reductions compatible with evaluators

This commit is contained in:
Gael Guennebaud 2013-12-02 17:54:38 +01:00
parent 6f1a0479b3
commit f0b82c3ab9

View File

@ -203,7 +203,7 @@ struct redux_impl<Func, Derived, LinearVectorizedTraversal, NoUnrolling>
static Scalar run(const Derived &mat, const Func& func)
{
const Index size = mat.size();
eigen_assert(size && "you are using an empty matrix");
const Index packetSize = packet_traits<Scalar>::size;
const Index alignedStart = internal::first_aligned(mat);
enum {
@ -302,7 +302,6 @@ struct redux_impl<Func, Derived, LinearVectorizedTraversal, CompleteUnrolling>
};
static EIGEN_STRONG_INLINE Scalar run(const Derived &mat, const Func& func)
{
eigen_assert(mat.rows()>0 && mat.cols()>0 && "you are using an empty matrix");
Scalar res = func.predux(redux_vec_unroller<Func, Derived, 0, Size / PacketSize>::run(mat,func));
if (VectorizedSize != Size)
res = func(res,redux_novec_unroller<Func, Derived, VectorizedSize, Size-VectorizedSize>::run(mat,func));
@ -310,6 +309,64 @@ struct redux_impl<Func, Derived, LinearVectorizedTraversal, CompleteUnrolling>
}
};
#ifdef EIGEN_ENABLE_EVALUATORS
// evaluator adaptor
template<typename XprType>
class redux_evaluator
{
public:
redux_evaluator(const XprType &xpr) : m_evaluator(xpr), m_xpr(xpr) {}
typedef typename XprType::Index Index;
typedef typename XprType::Scalar Scalar;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::PacketScalar PacketScalar;
typedef typename XprType::PacketReturnType PacketReturnType;
enum {
MaxRowsAtCompileTime = XprType::MaxRowsAtCompileTime,
MaxColsAtCompileTime = XprType::MaxColsAtCompileTime,
// TODO we should not remove DirectAccessBit and rather find an elegant way to query the alignment offset at runtime from the evaluator
Flags = XprType::Flags & ~DirectAccessBit,
IsRowMajor = XprType::IsRowMajor,
SizeAtCompileTime = XprType::SizeAtCompileTime,
InnerSizeAtCompileTime = XprType::InnerSizeAtCompileTime,
CoeffReadCost = XprType::CoeffReadCost
};
Index rows() const { return m_xpr.rows(); }
Index cols() const { return m_xpr.cols(); }
Index size() const { return m_xpr.size(); }
Index innerSize() const { return m_xpr.innerSize(); }
Index outerSize() const { return m_xpr.outerSize(); }
CoeffReturnType coeff(Index row, Index col) const
{ return m_evaluator.coeff(row, col); }
CoeffReturnType coeff(Index index) const
{ return m_evaluator.coeff(index); }
template<int LoadMode>
PacketReturnType packet(Index row, Index col) const
{ return m_evaluator.template packet<LoadMode>(row, col); }
template<int LoadMode>
PacketReturnType packet(Index index) const
{ return m_evaluator.template packet<LoadMode>(index); }
CoeffReturnType coeffByOuterInner(Index outer, Index inner) const
{ return m_evaluator.coeff(IsRowMajor ? outer : inner, IsRowMajor ? inner : outer); }
template<int LoadMode>
PacketReturnType packetByOuterInner(Index outer, Index inner) const
{ return m_evaluator.template packet<LoadMode>(IsRowMajor ? outer : inner, IsRowMajor ? inner : outer); }
protected:
typename internal::evaluator<XprType>::nestedType m_evaluator;
const XprType &m_xpr;
};
#endif
} // end namespace internal
/***************************************************************************
@ -320,7 +377,7 @@ struct redux_impl<Func, Derived, LinearVectorizedTraversal, CompleteUnrolling>
/** \returns the result of a full redux operation on the whole matrix or vector using \a func
*
* The template parameter \a BinaryOp is the type of the functor \a func which must be
* an associative operator. Both current STL and TR1 functor styles are handled.
* an associative operator. Both current C++98 and C++11 functor styles are handled.
*
* \sa DenseBase::sum(), DenseBase::minCoeff(), DenseBase::maxCoeff(), MatrixBase::colwise(), MatrixBase::rowwise()
*/
@ -329,9 +386,19 @@ template<typename Func>
EIGEN_STRONG_INLINE typename internal::result_of<Func(typename internal::traits<Derived>::Scalar)>::type
DenseBase<Derived>::redux(const Func& func) const
{
eigen_assert(this->rows()>0 && this->cols()>0 && "you are using an empty matrix");
#ifdef EIGEN_TEST_EVALUATORS
typedef typename internal::redux_evaluator<Derived> ThisEvaluator;
ThisEvaluator thisEval(derived());
return internal::redux_impl<Func, ThisEvaluator>::run(thisEval, func);
#else
typedef typename internal::remove_all<typename Derived::Nested>::type ThisNested;
return internal::redux_impl<Func, ThisNested>
::run(derived(), func);
#endif
}
/** \returns the minimum of all coefficients of \c *this.