Enable packet segment in partial redux

This commit is contained in:
Charles Schlosser 2025-04-14 17:44:53 +00:00 committed by Rasmus Munk Larsen
parent 6266d430cc
commit 5330960900
6 changed files with 68 additions and 54 deletions

View File

@ -136,8 +136,7 @@ struct copy_using_evaluator_traits {
: Traversal == SliceVectorizedTraversal ? (MayUnrollInner ? InnerUnrolling : NoUnrolling) : Traversal == SliceVectorizedTraversal ? (MayUnrollInner ? InnerUnrolling : NoUnrolling)
#endif #endif
: NoUnrolling; : NoUnrolling;
static constexpr bool UsePacketSegment = static constexpr bool UsePacketSegment = has_packet_segment<PacketType>::value;
enable_packet_segment<Src>::value && enable_packet_segment<Dst>::value && has_packet_segment<PacketType>::value;
#ifdef EIGEN_DEBUG_ASSIGN #ifdef EIGEN_DEBUG_ASSIGN
static void debug() { static void debug() {

View File

@ -103,19 +103,36 @@ struct packetwise_redux_impl<Func, Evaluator, NoUnrolling> {
EIGEN_DEVICE_FUNC static PacketType run(const Evaluator& eval, const Func& func, Index size) { EIGEN_DEVICE_FUNC static PacketType run(const Evaluator& eval, const Func& func, Index size) {
if (size == 0) return packetwise_redux_empty_value<PacketType>(func); if (size == 0) return packetwise_redux_empty_value<PacketType>(func);
const Index size4 = (size - 1) & (~3); const Index size4 = 1 + numext::round_down(size - 1, 4);
PacketType p = eval.template packetByOuterInner<Unaligned, PacketType>(0, 0); PacketType p = eval.template packetByOuterInner<Unaligned, PacketType>(0, 0);
Index i = 1;
// This loop is optimized for instruction pipelining: // This loop is optimized for instruction pipelining:
// - each iteration generates two independent instructions // - each iteration generates two independent instructions
// - thanks to branch prediction and out-of-order execution we have independent instructions across loops // - thanks to branch prediction and out-of-order execution we have independent instructions across loops
for (; i < size4; i += 4) for (Index i = 1; i < size4; i += 4)
p = func.packetOp( p = func.packetOp(
p, func.packetOp(func.packetOp(eval.template packetByOuterInner<Unaligned, PacketType>(i + 0, 0), p, func.packetOp(func.packetOp(eval.template packetByOuterInner<Unaligned, PacketType>(i + 0, 0),
eval.template packetByOuterInner<Unaligned, PacketType>(i + 1, 0)), eval.template packetByOuterInner<Unaligned, PacketType>(i + 1, 0)),
func.packetOp(eval.template packetByOuterInner<Unaligned, PacketType>(i + 2, 0), func.packetOp(eval.template packetByOuterInner<Unaligned, PacketType>(i + 2, 0),
eval.template packetByOuterInner<Unaligned, PacketType>(i + 3, 0)))); eval.template packetByOuterInner<Unaligned, PacketType>(i + 3, 0))));
for (; i < size; ++i) p = func.packetOp(p, eval.template packetByOuterInner<Unaligned, PacketType>(i, 0)); for (Index i = size4; i < size; ++i)
p = func.packetOp(p, eval.template packetByOuterInner<Unaligned, PacketType>(i, 0));
return p;
}
};
template <typename Func, typename Evaluator>
struct packetwise_segment_redux_impl {
typedef typename Evaluator::Scalar Scalar;
typedef typename redux_traits<Func, Evaluator>::PacketType PacketScalar;
template <typename PacketType>
EIGEN_DEVICE_FUNC static PacketType run(const Evaluator& eval, const Func& func, Index size, Index begin,
Index count) {
if (size == 0) return packetwise_redux_empty_value<PacketType>(func);
PacketType p = eval.template packetSegmentByOuterInner<Unaligned, PacketType>(0, 0, begin, count);
for (Index i = 1; i < size; ++i)
p = func.packetOp(p, eval.template packetSegmentByOuterInner<Unaligned, PacketType>(i, 0, begin, count));
return p; return p;
} }
}; };
@ -174,14 +191,13 @@ struct evaluator<PartialReduxExpr<ArgType, MemberOp, Direction> >
template <int LoadMode, typename PacketType> template <int LoadMode, typename PacketType>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC PacketType packet(Index idx) const { EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC PacketType packet(Index idx) const {
enum { PacketSize = internal::unpacket_traits<PacketType>::size }; static constexpr int PacketSize = internal::unpacket_traits<PacketType>::size;
typedef Block<const ArgTypeNestedCleaned, Direction == Vertical ? int(ArgType::RowsAtCompileTime) : int(PacketSize), static constexpr int PanelRows = Direction == Vertical ? ArgType::RowsAtCompileTime : PacketSize;
Direction == Vertical ? int(PacketSize) : int(ArgType::ColsAtCompileTime), true /* InnerPanel */> static constexpr int PanelCols = Direction == Vertical ? PacketSize : ArgType::ColsAtCompileTime;
PanelType; using PanelType = Block<const ArgTypeNestedCleaned, PanelRows, PanelCols, true /* InnerPanel */>;
using PanelEvaluator = typename internal::redux_evaluator<PanelType>;
PanelType panel(m_arg, Direction == Vertical ? 0 : idx, Direction == Vertical ? idx : 0, using BinaryOp = typename MemberOp::BinaryOp;
Direction == Vertical ? m_arg.rows() : Index(PacketSize), using Impl = internal::packetwise_redux_impl<BinaryOp, PanelEvaluator>;
Direction == Vertical ? Index(PacketSize) : m_arg.cols());
// FIXME // FIXME
// See bug 1612, currently if PacketSize==1 (i.e. complex<double> with 128bits registers) then the storage-order of // See bug 1612, currently if PacketSize==1 (i.e. complex<double> with 128bits registers) then the storage-order of
@ -189,11 +205,39 @@ struct evaluator<PartialReduxExpr<ArgType, MemberOp, Direction> >
// by pass "vectorization" in this case: // by pass "vectorization" in this case:
if (PacketSize == 1) return internal::pset1<PacketType>(coeff(idx)); if (PacketSize == 1) return internal::pset1<PacketType>(coeff(idx));
typedef typename internal::redux_evaluator<PanelType> PanelEvaluator; Index startRow = Direction == Vertical ? 0 : idx;
Index startCol = Direction == Vertical ? idx : 0;
Index numRows = Direction == Vertical ? m_arg.rows() : PacketSize;
Index numCols = Direction == Vertical ? PacketSize : m_arg.cols();
PanelType panel(m_arg, startRow, startCol, numRows, numCols);
PanelEvaluator panel_eval(panel); PanelEvaluator panel_eval(panel);
typedef typename MemberOp::BinaryOp BinaryOp; PacketType p = Impl::template run<PacketType>(panel_eval, m_functor.binaryFunc(), m_arg.outerSize());
PacketType p = internal::packetwise_redux_impl<BinaryOp, PanelEvaluator>::template run<PacketType>( return p;
panel_eval, m_functor.binaryFunc(), m_arg.outerSize()); }
template <int LoadMode, typename PacketType>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packetSegment(Index i, Index j, Index begin, Index count) const {
return packetSegment<LoadMode, PacketType>(Direction == Vertical ? j : i, begin, count);
}
template <int LoadMode, typename PacketType>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC PacketType packetSegment(Index idx, Index begin, Index count) const {
static constexpr int PanelRows = Direction == Vertical ? ArgType::RowsAtCompileTime : Dynamic;
static constexpr int PanelCols = Direction == Vertical ? Dynamic : ArgType::ColsAtCompileTime;
using PanelType = Block<const ArgTypeNestedCleaned, PanelRows, PanelCols, true /* InnerPanel */>;
using PanelEvaluator = typename internal::redux_evaluator<PanelType>;
using BinaryOp = typename MemberOp::BinaryOp;
using Impl = internal::packetwise_segment_redux_impl<BinaryOp, PanelEvaluator>;
Index startRow = Direction == Vertical ? 0 : idx;
Index startCol = Direction == Vertical ? idx : 0;
Index numRows = Direction == Vertical ? m_arg.rows() : begin + count;
Index numCols = Direction == Vertical ? begin + count : m_arg.cols();
PanelType panel(m_arg, startRow, startCol, numRows, numCols);
PanelEvaluator panel_eval(panel);
PacketType p = Impl::template run<PacketType>(panel_eval, m_functor.binaryFunc(), m_arg.outerSize(), begin, count);
return p; return p;
} }

View File

@ -414,6 +414,13 @@ class redux_evaluator : public internal::evaluator<XprType_> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packetByOuterInner(Index outer, Index inner) const { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packetByOuterInner(Index outer, Index inner) const {
return Base::template packet<LoadMode, PacketType>(IsRowMajor ? outer : inner, IsRowMajor ? inner : outer); return Base::template packet<LoadMode, PacketType>(IsRowMajor ? outer : inner, IsRowMajor ? inner : outer);
} }
template <int LoadMode, typename PacketType>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packetSegmentByOuterInner(Index outer, Index inner, Index begin,
Index count) const {
return Base::template packetSegment<LoadMode, PacketType>(IsRowMajor ? outer : inner, IsRowMajor ? inner : outer,
begin, count);
}
}; };
} // end namespace internal } // end namespace internal

View File

@ -37,9 +37,6 @@ class PartialReduxExpr;
namespace internal { namespace internal {
template <typename ArgType, typename MemberOp, int Direction>
struct enable_packet_segment<PartialReduxExpr<ArgType, MemberOp, Direction>> : std::false_type {};
template <typename MatrixType, typename MemberOp, int Direction> template <typename MatrixType, typename MemberOp, int Direction>
struct traits<PartialReduxExpr<MatrixType, MemberOp, Direction> > : traits<MatrixType> { struct traits<PartialReduxExpr<MatrixType, MemberOp, Direction> > : traits<MatrixType> {
typedef typename MemberOp::result_type Scalar; typedef typename MemberOp::result_type Scalar;

View File

@ -517,9 +517,6 @@ struct eigen_zero_impl;
template <typename Packet> template <typename Packet>
struct has_packet_segment : std::false_type {}; struct has_packet_segment : std::false_type {};
template <typename Xpr>
struct enable_packet_segment : std::true_type {};
} // namespace internal } // namespace internal
} // end namespace Eigen } // end namespace Eigen

View File

@ -996,36 +996,6 @@ struct is_matrix_base_xpr : std::is_base_of<MatrixBase<remove_all_t<XprType>>, r
template <typename XprType> template <typename XprType>
struct is_permutation_base_xpr : std::is_base_of<PermutationBase<remove_all_t<XprType>>, remove_all_t<XprType>> {}; struct is_permutation_base_xpr : std::is_base_of<PermutationBase<remove_all_t<XprType>>, remove_all_t<XprType>> {};
/*---------------- load/store segment support ----------------*/
// recursively traverse unary, binary, and ternary expressions to determine if packet segments are supported
template <typename Func, typename Xpr>
struct enable_packet_segment<CwiseNullaryOp<Func, Xpr>> : enable_packet_segment<remove_all_t<Xpr>> {};
template <typename Func, typename Xpr>
struct enable_packet_segment<CwiseUnaryOp<Func, Xpr>> : enable_packet_segment<remove_all_t<Xpr>> {};
template <typename Func, typename LhsXpr, typename RhsXpr>
struct enable_packet_segment<CwiseBinaryOp<Func, LhsXpr, RhsXpr>>
: bool_constant<enable_packet_segment<remove_all_t<LhsXpr>>::value &&
enable_packet_segment<remove_all_t<RhsXpr>>::value> {};
template <typename Func, typename LhsXpr, typename MidXpr, typename RhsXpr>
struct enable_packet_segment<CwiseTernaryOp<Func, LhsXpr, MidXpr, RhsXpr>>
: bool_constant<enable_packet_segment<remove_all_t<LhsXpr>>::value &&
enable_packet_segment<remove_all_t<MidXpr>>::value &&
enable_packet_segment<remove_all_t<RhsXpr>>::value> {};
template <typename Xpr>
struct enable_packet_segment<ArrayWrapper<Xpr>> : enable_packet_segment<remove_all_t<Xpr>> {};
template <typename Xpr>
struct enable_packet_segment<MatrixWrapper<Xpr>> : enable_packet_segment<remove_all_t<Xpr>> {};
template <typename Xpr>
struct enable_packet_segment<DiagonalWrapper<Xpr>> : enable_packet_segment<remove_all_t<Xpr>> {};
} // end namespace internal } // end namespace internal
/** \class ScalarBinaryOpTraits /** \class ScalarBinaryOpTraits