From f2c9c2d2f7f16f6cf5f503d0cc383221f3d17b64 Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen Date: Wed, 20 Oct 2021 16:58:01 +0000 Subject: [PATCH] Vectorize Visitor.h. --- Eigen/src/Core/Visitor.h | 309 +++++++++++++++++++++++---------------- 1 file changed, 182 insertions(+), 127 deletions(-) diff --git a/Eigen/src/Core/Visitor.h b/Eigen/src/Core/Visitor.h index cf4e06a9b..b9d7b612b 100644 --- a/Eigen/src/Core/Visitor.h +++ b/Eigen/src/Core/Visitor.h @@ -16,8 +16,11 @@ namespace Eigen { namespace internal { +template::PacketAccess)> +struct visitor_impl; + template -struct visitor_impl +struct visitor_impl { enum { col = (UnrollCount-1) / Derived::RowsAtCompileTime, @@ -33,7 +36,7 @@ struct visitor_impl }; template -struct visitor_impl +struct visitor_impl { EIGEN_DEVICE_FUNC static inline void run(const Derived &mat, Visitor& visitor) @@ -44,14 +47,14 @@ struct visitor_impl // This specialization enables visitors on empty matrices at compile-time template -struct visitor_impl { +struct visitor_impl { EIGEN_DEVICE_FUNC static inline void run(const Derived &/*mat*/, Visitor& /*visitor*/) {} }; template -struct visitor_impl +struct visitor_impl { EIGEN_DEVICE_FUNC static inline void run(const Derived& mat, Visitor& visitor) @@ -65,21 +68,62 @@ struct visitor_impl } }; +template +struct visitor_impl +{ + typedef typename Derived::Scalar Scalar; + typedef typename packet_traits::type Packet; + + EIGEN_DEVICE_FUNC + static inline void run(const Derived& mat, Visitor& visitor) + { + const Index PacketSize = packet_traits::size; + visitor.init(mat.coeff(0,0), 0, 0); + if (Derived::IsRowMajor) { + for(Index i = 0; i < mat.rows(); ++i) { + Index j = i == 0 ? 1 : 0; + for(; j+PacketSize-1 < mat.cols(); j += PacketSize) { + Packet p = mat.packet(i, j); + visitor.packet(p, i, j); + } + for(; j < mat.cols(); ++j) + visitor(mat.coeff(i, j), i, j); + } + } else { + for(Index j = 0; j < mat.cols(); ++j) { + Index i = j == 0 ? 1 : 0; + for(; i+PacketSize-1 < mat.rows(); i += PacketSize) { + Packet p = mat.packet(i, j); + visitor.packet(p, i, j); + } + for(; i < mat.rows(); ++i) + visitor(mat.coeff(i, j), i, j); + } + } + } +}; + // evaluator adaptor template class visitor_evaluator { public: + typedef internal::evaluator Evaluator; + + enum { + PacketAccess = Evaluator::Flags & PacketAccessBit, + IsRowMajor = XprType::IsRowMajor, + RowsAtCompileTime = XprType::RowsAtCompileTime, + CoeffReadCost = Evaluator::CoeffReadCost + }; + + EIGEN_DEVICE_FUNC - explicit visitor_evaluator(const XprType &xpr) : m_evaluator(xpr), m_xpr(xpr) {} + explicit visitor_evaluator(const XprType &xpr) : m_evaluator(xpr), m_xpr(xpr) { } typedef typename XprType::Scalar Scalar; typedef typename XprType::CoeffReturnType CoeffReturnType; - - enum { - RowsAtCompileTime = XprType::RowsAtCompileTime, - CoeffReadCost = internal::evaluator::CoeffReadCost - }; + typedef typename XprType::PacketReturnType PacketReturnType; EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT { return m_xpr.rows(); } EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT { return m_xpr.cols(); } @@ -87,11 +131,14 @@ public: EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index row, Index col) const { return m_evaluator.coeff(row, col); } + EIGEN_DEVICE_FUNC PacketReturnType packet(Index row, Index col) const + { return m_evaluator.template packet(row, col); } protected: - internal::evaluator m_evaluator; + Evaluator m_evaluator; const XprType &m_xpr; }; + } // end namespace internal /** Applies the visitor \a visitor to the whole coefficients of the matrix or vector. @@ -154,123 +201,131 @@ struct coeff_visitor } }; -/** \internal - * \brief Visitor computing the min coefficient with its value and coordinates - * - * \sa DenseBase::minCoeff(Index*, Index*) - */ -template -struct min_coeff_visitor : coeff_visitor -{ - typedef typename Derived::Scalar Scalar; - EIGEN_DEVICE_FUNC - void operator() (const Scalar& value, Index i, Index j) - { - if(value < this->res) - { - this->res = value; - this->row = i; - this->col = j; - } - } -}; -template -struct min_coeff_visitor : coeff_visitor -{ - typedef typename Derived::Scalar Scalar; - EIGEN_DEVICE_FUNC - void operator() (const Scalar& value, Index i, Index j) - { - if((numext::isnan)(this->res) || (!(numext::isnan)(value) && value < this->res)) - { - this->res = value; - this->row = i; - this->col = j; - } - } -}; - -template -struct min_coeff_visitor : coeff_visitor -{ - typedef typename Derived::Scalar Scalar; - EIGEN_DEVICE_FUNC - void operator() (const Scalar& value, Index i, Index j) - { - if((numext::isnan)(value) || value < this->res) - { - this->res = value; - this->row = i; - this->col = j; - } - } +template +struct minmax_compare { + typedef typename packet_traits::type Packet; + static EIGEN_DEVICE_FUNC inline bool compare(Scalar a, Scalar b) { return a < b; } + static EIGEN_DEVICE_FUNC inline Scalar predux(const Packet& p) { return predux_min(p);} }; template - struct functor_traits > { +struct minmax_compare { + typedef typename packet_traits::type Packet; + static EIGEN_DEVICE_FUNC inline bool compare(Scalar a, Scalar b) { return a > b; } + static EIGEN_DEVICE_FUNC inline Scalar predux(const Packet& p) { return predux_max(p);} +}; + +template +struct minmax_coeff_visitor : coeff_visitor +{ + using Scalar = typename Derived::Scalar; + using Packet = typename packet_traits::type; + using Comparator = minmax_compare; + + EIGEN_DEVICE_FUNC inline + void operator() (const Scalar& value, Index i, Index j) + { + if(Comparator::compare(value, this->res)) { + this->res = value; + this->row = i; + this->col = j; + } + } + + EIGEN_DEVICE_FUNC inline + void packet(const Packet& p, Index i, Index j) { + const Index PacketSize = packet_traits::size; + Scalar value = Comparator::predux(p); + if (Comparator::compare(value, this->res)) { + const Packet range = preverse(plset(Scalar(1))); + Packet mask = pcmp_eq(pset1(value), p); + Index max_idx = PacketSize - static_cast(predux_max(pand(range, mask))); + this->res = value; + this->row = Derived::IsRowMajor ? i : i + max_idx;; + this->col = Derived::IsRowMajor ? j + max_idx : j; + } + } +}; + +// Suppress NaN. The only case in which we return NaN is if the matrix is all NaN, in which case, +// the row=0, col=0 is returned for the location. +template +struct minmax_coeff_visitor : coeff_visitor +{ + typedef typename Derived::Scalar Scalar; + using Packet = typename packet_traits::type; + using Comparator = minmax_compare; + + EIGEN_DEVICE_FUNC inline + void operator() (const Scalar& value, Index i, Index j) + { + if ((!(numext::isnan)(value) && (numext::isnan)(this->res)) || Comparator::compare(value, this->res)) { + this->res = value; + this->row = i; + this->col = j; + } + } + + EIGEN_DEVICE_FUNC inline + void packet(const Packet& p, Index i, Index j) { + const Index PacketSize = packet_traits::size; + Scalar value = Comparator::predux(p); + if ((!(numext::isnan)(value) && (numext::isnan)(this->res)) || Comparator::compare(value, this->res)) { + const Packet range = preverse(plset(Scalar(1))); + /* mask will be zero for NaNs, so they will be ignored. */ + Packet mask = pcmp_eq(pset1(value), p); + Index max_idx = PacketSize - static_cast(predux_max(pand(range, mask))); + this->res = value; + this->row = Derived::IsRowMajor ? i : i + max_idx;; + this->col = Derived::IsRowMajor ? j + max_idx : j; + } + } + +}; + +// Propagate NaN. If the matrix contains NaN, the location of the first NaN will be returned in +// row and col. +template +struct minmax_coeff_visitor : coeff_visitor +{ + typedef typename Derived::Scalar Scalar; + using Packet = typename packet_traits::type; + using Comparator = minmax_compare; + + EIGEN_DEVICE_FUNC inline + void operator() (const Scalar& value, Index i, Index j) + { + const bool value_is_nan = (numext::isnan)(value); + if ((value_is_nan && !(numext::isnan)(this->res)) || Comparator::compare(value, this->res)) { + this->res = value; + this->row = i; + this->col = j; + } + } + + EIGEN_DEVICE_FUNC inline + void packet(const Packet& p, Index i, Index j) { + const Index PacketSize = packet_traits::size; + Scalar value = Comparator::predux(p); + const bool value_is_nan = (numext::isnan)(value); + if ((value_is_nan && !(numext::isnan)(this->res)) || Comparator::compare(value, this->res)) { + const Packet range = preverse(plset(Scalar(1))); + // If the value is NaN, pick the first position of a NaN, otherwise pick the first extremal value. + Packet mask = value_is_nan ? pnot(pcmp_eq(p, p)) : pcmp_eq(pset1(value), p); + Index max_idx = PacketSize - static_cast(predux_max(pand(range, mask))); + this->res = value; + this->row = Derived::IsRowMajor ? i : i + max_idx;; + this->col = Derived::IsRowMajor ? j + max_idx : j; + } + } +}; + +template +struct functor_traits > { enum { - Cost = NumTraits::AddCost - }; -}; - -/** \internal - * \brief Visitor computing the max coefficient with its value and coordinates - * - * \sa DenseBase::maxCoeff(Index*, Index*) - */ -template -struct max_coeff_visitor : coeff_visitor -{ - typedef typename Derived::Scalar Scalar; - EIGEN_DEVICE_FUNC - void operator() (const Scalar& value, Index i, Index j) - { - if(value > this->res) - { - this->res = value; - this->row = i; - this->col = j; - } - } -}; - -template -struct max_coeff_visitor : coeff_visitor -{ - typedef typename Derived::Scalar Scalar; - EIGEN_DEVICE_FUNC - void operator() (const Scalar& value, Index i, Index j) - { - if((numext::isnan)(this->res) || (!(numext::isnan)(value) && value > this->res)) - { - this->res = value; - this->row = i; - this->col = j; - } - } -}; - -template -struct max_coeff_visitor : coeff_visitor -{ - typedef typename Derived::Scalar Scalar; - EIGEN_DEVICE_FUNC - void operator() (const Scalar& value, Index i, Index j) - { - if((numext::isnan)(value) || value > this->res) - { - this->res = value; - this->row = i; - this->col = j; - } - } -}; - -template -struct functor_traits > { - enum { - Cost = NumTraits::AddCost + Cost = NumTraits::AddCost, + PacketAccess = true }; }; @@ -295,7 +350,7 @@ DenseBase::minCoeff(IndexType* rowId, IndexType* colId) const { eigen_assert(this->rows()>0 && this->cols()>0 && "you are using an empty matrix"); - internal::min_coeff_visitor minVisitor; + internal::minmax_coeff_visitor minVisitor; this->visit(minVisitor); *rowId = minVisitor.row; if (colId) *colId = minVisitor.col; @@ -321,7 +376,7 @@ DenseBase::minCoeff(IndexType* index) const eigen_assert(this->rows()>0 && this->cols()>0 && "you are using an empty matrix"); EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) - internal::min_coeff_visitor minVisitor; + internal::minmax_coeff_visitor minVisitor; this->visit(minVisitor); *index = IndexType((RowsAtCompileTime==1) ? minVisitor.col : minVisitor.row); return minVisitor.res; @@ -346,7 +401,7 @@ DenseBase::maxCoeff(IndexType* rowPtr, IndexType* colPtr) const { eigen_assert(this->rows()>0 && this->cols()>0 && "you are using an empty matrix"); - internal::max_coeff_visitor maxVisitor; + internal::minmax_coeff_visitor maxVisitor; this->visit(maxVisitor); *rowPtr = maxVisitor.row; if (colPtr) *colPtr = maxVisitor.col; @@ -372,7 +427,7 @@ DenseBase::maxCoeff(IndexType* index) const eigen_assert(this->rows()>0 && this->cols()>0 && "you are using an empty matrix"); EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) - internal::max_coeff_visitor maxVisitor; + internal::minmax_coeff_visitor maxVisitor; this->visit(maxVisitor); *index = (RowsAtCompileTime==1) ? maxVisitor.col : maxVisitor.row; return maxVisitor.res;