mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-05-01 16:24:28 +08:00
Make reductions compatible with evaluators
This commit is contained in:
parent
6f1a0479b3
commit
f0b82c3ab9
@ -174,7 +174,7 @@ struct redux_impl<Func, Derived, DefaultTraversal, NoUnrolling>
|
|||||||
typedef typename Derived::Scalar Scalar;
|
typedef typename Derived::Scalar Scalar;
|
||||||
typedef typename Derived::Index Index;
|
typedef typename Derived::Index Index;
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
static EIGEN_STRONG_INLINE Scalar run(const Derived& mat, const Func& func)
|
static EIGEN_STRONG_INLINE Scalar run(const Derived &mat, const Func& func)
|
||||||
{
|
{
|
||||||
eigen_assert(mat.rows()>0 && mat.cols()>0 && "you are using an empty matrix");
|
eigen_assert(mat.rows()>0 && mat.cols()>0 && "you are using an empty matrix");
|
||||||
Scalar res;
|
Scalar res;
|
||||||
@ -200,10 +200,10 @@ struct redux_impl<Func, Derived, LinearVectorizedTraversal, NoUnrolling>
|
|||||||
typedef typename packet_traits<Scalar>::type PacketScalar;
|
typedef typename packet_traits<Scalar>::type PacketScalar;
|
||||||
typedef typename Derived::Index Index;
|
typedef typename Derived::Index Index;
|
||||||
|
|
||||||
static Scalar run(const Derived& mat, const Func& func)
|
static Scalar run(const Derived &mat, const Func& func)
|
||||||
{
|
{
|
||||||
const Index size = mat.size();
|
const Index size = mat.size();
|
||||||
eigen_assert(size && "you are using an empty matrix");
|
|
||||||
const Index packetSize = packet_traits<Scalar>::size;
|
const Index packetSize = packet_traits<Scalar>::size;
|
||||||
const Index alignedStart = internal::first_aligned(mat);
|
const Index alignedStart = internal::first_aligned(mat);
|
||||||
enum {
|
enum {
|
||||||
@ -258,7 +258,7 @@ struct redux_impl<Func, Derived, SliceVectorizedTraversal, NoUnrolling>
|
|||||||
typedef typename packet_traits<Scalar>::type PacketScalar;
|
typedef typename packet_traits<Scalar>::type PacketScalar;
|
||||||
typedef typename Derived::Index Index;
|
typedef typename Derived::Index Index;
|
||||||
|
|
||||||
static Scalar run(const Derived& mat, const Func& func)
|
static Scalar run(const Derived &mat, const Func& func)
|
||||||
{
|
{
|
||||||
eigen_assert(mat.rows()>0 && mat.cols()>0 && "you are using an empty matrix");
|
eigen_assert(mat.rows()>0 && mat.cols()>0 && "you are using an empty matrix");
|
||||||
const Index innerSize = mat.innerSize();
|
const Index innerSize = mat.innerSize();
|
||||||
@ -300,9 +300,8 @@ struct redux_impl<Func, Derived, LinearVectorizedTraversal, CompleteUnrolling>
|
|||||||
Size = Derived::SizeAtCompileTime,
|
Size = Derived::SizeAtCompileTime,
|
||||||
VectorizedSize = (Size / PacketSize) * PacketSize
|
VectorizedSize = (Size / PacketSize) * PacketSize
|
||||||
};
|
};
|
||||||
static EIGEN_STRONG_INLINE Scalar run(const Derived& mat, const Func& func)
|
static EIGEN_STRONG_INLINE Scalar run(const Derived &mat, const Func& func)
|
||||||
{
|
{
|
||||||
eigen_assert(mat.rows()>0 && mat.cols()>0 && "you are using an empty matrix");
|
|
||||||
Scalar res = func.predux(redux_vec_unroller<Func, Derived, 0, Size / PacketSize>::run(mat,func));
|
Scalar res = func.predux(redux_vec_unroller<Func, Derived, 0, Size / PacketSize>::run(mat,func));
|
||||||
if (VectorizedSize != Size)
|
if (VectorizedSize != Size)
|
||||||
res = func(res,redux_novec_unroller<Func, Derived, VectorizedSize, Size-VectorizedSize>::run(mat,func));
|
res = func(res,redux_novec_unroller<Func, Derived, VectorizedSize, Size-VectorizedSize>::run(mat,func));
|
||||||
@ -310,6 +309,64 @@ struct redux_impl<Func, Derived, LinearVectorizedTraversal, CompleteUnrolling>
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#ifdef EIGEN_ENABLE_EVALUATORS
|
||||||
|
// evaluator adaptor
|
||||||
|
template<typename XprType>
|
||||||
|
class redux_evaluator
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
redux_evaluator(const XprType &xpr) : m_evaluator(xpr), m_xpr(xpr) {}
|
||||||
|
|
||||||
|
typedef typename XprType::Index Index;
|
||||||
|
typedef typename XprType::Scalar Scalar;
|
||||||
|
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
||||||
|
typedef typename XprType::PacketScalar PacketScalar;
|
||||||
|
typedef typename XprType::PacketReturnType PacketReturnType;
|
||||||
|
|
||||||
|
enum {
|
||||||
|
MaxRowsAtCompileTime = XprType::MaxRowsAtCompileTime,
|
||||||
|
MaxColsAtCompileTime = XprType::MaxColsAtCompileTime,
|
||||||
|
// TODO we should not remove DirectAccessBit and rather find an elegant way to query the alignment offset at runtime from the evaluator
|
||||||
|
Flags = XprType::Flags & ~DirectAccessBit,
|
||||||
|
IsRowMajor = XprType::IsRowMajor,
|
||||||
|
SizeAtCompileTime = XprType::SizeAtCompileTime,
|
||||||
|
InnerSizeAtCompileTime = XprType::InnerSizeAtCompileTime,
|
||||||
|
CoeffReadCost = XprType::CoeffReadCost
|
||||||
|
};
|
||||||
|
|
||||||
|
Index rows() const { return m_xpr.rows(); }
|
||||||
|
Index cols() const { return m_xpr.cols(); }
|
||||||
|
Index size() const { return m_xpr.size(); }
|
||||||
|
Index innerSize() const { return m_xpr.innerSize(); }
|
||||||
|
Index outerSize() const { return m_xpr.outerSize(); }
|
||||||
|
|
||||||
|
CoeffReturnType coeff(Index row, Index col) const
|
||||||
|
{ return m_evaluator.coeff(row, col); }
|
||||||
|
|
||||||
|
CoeffReturnType coeff(Index index) const
|
||||||
|
{ return m_evaluator.coeff(index); }
|
||||||
|
|
||||||
|
template<int LoadMode>
|
||||||
|
PacketReturnType packet(Index row, Index col) const
|
||||||
|
{ return m_evaluator.template packet<LoadMode>(row, col); }
|
||||||
|
|
||||||
|
template<int LoadMode>
|
||||||
|
PacketReturnType packet(Index index) const
|
||||||
|
{ return m_evaluator.template packet<LoadMode>(index); }
|
||||||
|
|
||||||
|
CoeffReturnType coeffByOuterInner(Index outer, Index inner) const
|
||||||
|
{ return m_evaluator.coeff(IsRowMajor ? outer : inner, IsRowMajor ? inner : outer); }
|
||||||
|
|
||||||
|
template<int LoadMode>
|
||||||
|
PacketReturnType packetByOuterInner(Index outer, Index inner) const
|
||||||
|
{ return m_evaluator.template packet<LoadMode>(IsRowMajor ? outer : inner, IsRowMajor ? inner : outer); }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
typename internal::evaluator<XprType>::nestedType m_evaluator;
|
||||||
|
const XprType &m_xpr;
|
||||||
|
};
|
||||||
|
#endif
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
/***************************************************************************
|
/***************************************************************************
|
||||||
@ -320,7 +377,7 @@ struct redux_impl<Func, Derived, LinearVectorizedTraversal, CompleteUnrolling>
|
|||||||
/** \returns the result of a full redux operation on the whole matrix or vector using \a func
|
/** \returns the result of a full redux operation on the whole matrix or vector using \a func
|
||||||
*
|
*
|
||||||
* The template parameter \a BinaryOp is the type of the functor \a func which must be
|
* The template parameter \a BinaryOp is the type of the functor \a func which must be
|
||||||
* an associative operator. Both current STL and TR1 functor styles are handled.
|
* an associative operator. Both current C++98 and C++11 functor styles are handled.
|
||||||
*
|
*
|
||||||
* \sa DenseBase::sum(), DenseBase::minCoeff(), DenseBase::maxCoeff(), MatrixBase::colwise(), MatrixBase::rowwise()
|
* \sa DenseBase::sum(), DenseBase::minCoeff(), DenseBase::maxCoeff(), MatrixBase::colwise(), MatrixBase::rowwise()
|
||||||
*/
|
*/
|
||||||
@ -329,9 +386,19 @@ template<typename Func>
|
|||||||
EIGEN_STRONG_INLINE typename internal::result_of<Func(typename internal::traits<Derived>::Scalar)>::type
|
EIGEN_STRONG_INLINE typename internal::result_of<Func(typename internal::traits<Derived>::Scalar)>::type
|
||||||
DenseBase<Derived>::redux(const Func& func) const
|
DenseBase<Derived>::redux(const Func& func) const
|
||||||
{
|
{
|
||||||
|
eigen_assert(this->rows()>0 && this->cols()>0 && "you are using an empty matrix");
|
||||||
|
#ifdef EIGEN_TEST_EVALUATORS
|
||||||
|
|
||||||
|
typedef typename internal::redux_evaluator<Derived> ThisEvaluator;
|
||||||
|
ThisEvaluator thisEval(derived());
|
||||||
|
return internal::redux_impl<Func, ThisEvaluator>::run(thisEval, func);
|
||||||
|
|
||||||
|
#else
|
||||||
typedef typename internal::remove_all<typename Derived::Nested>::type ThisNested;
|
typedef typename internal::remove_all<typename Derived::Nested>::type ThisNested;
|
||||||
|
|
||||||
return internal::redux_impl<Func, ThisNested>
|
return internal::redux_impl<Func, ThisNested>
|
||||||
::run(derived(), func);
|
::run(derived(), func);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
/** \returns the minimum of all coefficients of \c *this.
|
/** \returns the minimum of all coefficients of \c *this.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user