Enable packet segment in partial redux

This commit is contained in:
Charles Schlosser 2025-04-14 17:44:53 +00:00 committed by Rasmus Munk Larsen
parent 6266d430cc
commit 5330960900
6 changed files with 68 additions and 54 deletions

View File

@ -136,8 +136,7 @@ struct copy_using_evaluator_traits {
: Traversal == SliceVectorizedTraversal ? (MayUnrollInner ? InnerUnrolling : NoUnrolling)
#endif
: NoUnrolling;
static constexpr bool UsePacketSegment =
enable_packet_segment<Src>::value && enable_packet_segment<Dst>::value && has_packet_segment<PacketType>::value;
static constexpr bool UsePacketSegment = has_packet_segment<PacketType>::value;
#ifdef EIGEN_DEBUG_ASSIGN
static void debug() {

View File

@ -103,19 +103,36 @@ struct packetwise_redux_impl<Func, Evaluator, NoUnrolling> {
EIGEN_DEVICE_FUNC static PacketType run(const Evaluator& eval, const Func& func, Index size) {
if (size == 0) return packetwise_redux_empty_value<PacketType>(func);
const Index size4 = (size - 1) & (~3);
const Index size4 = 1 + numext::round_down(size - 1, 4);
PacketType p = eval.template packetByOuterInner<Unaligned, PacketType>(0, 0);
Index i = 1;
// This loop is optimized for instruction pipelining:
// - each iteration generates two independent instructions
// - thanks to branch prediction and out-of-order execution we have independent instructions across loops
for (; i < size4; i += 4)
for (Index i = 1; i < size4; i += 4)
p = func.packetOp(
p, func.packetOp(func.packetOp(eval.template packetByOuterInner<Unaligned, PacketType>(i + 0, 0),
eval.template packetByOuterInner<Unaligned, PacketType>(i + 1, 0)),
func.packetOp(eval.template packetByOuterInner<Unaligned, PacketType>(i + 2, 0),
eval.template packetByOuterInner<Unaligned, PacketType>(i + 3, 0))));
for (; i < size; ++i) p = func.packetOp(p, eval.template packetByOuterInner<Unaligned, PacketType>(i, 0));
for (Index i = size4; i < size; ++i)
p = func.packetOp(p, eval.template packetByOuterInner<Unaligned, PacketType>(i, 0));
return p;
}
};
template <typename Func, typename Evaluator>
struct packetwise_segment_redux_impl {
typedef typename Evaluator::Scalar Scalar;
typedef typename redux_traits<Func, Evaluator>::PacketType PacketScalar;
template <typename PacketType>
EIGEN_DEVICE_FUNC static PacketType run(const Evaluator& eval, const Func& func, Index size, Index begin,
Index count) {
if (size == 0) return packetwise_redux_empty_value<PacketType>(func);
PacketType p = eval.template packetSegmentByOuterInner<Unaligned, PacketType>(0, 0, begin, count);
for (Index i = 1; i < size; ++i)
p = func.packetOp(p, eval.template packetSegmentByOuterInner<Unaligned, PacketType>(i, 0, begin, count));
return p;
}
};
@ -174,14 +191,13 @@ struct evaluator<PartialReduxExpr<ArgType, MemberOp, Direction> >
template <int LoadMode, typename PacketType>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC PacketType packet(Index idx) const {
enum { PacketSize = internal::unpacket_traits<PacketType>::size };
typedef Block<const ArgTypeNestedCleaned, Direction == Vertical ? int(ArgType::RowsAtCompileTime) : int(PacketSize),
Direction == Vertical ? int(PacketSize) : int(ArgType::ColsAtCompileTime), true /* InnerPanel */>
PanelType;
PanelType panel(m_arg, Direction == Vertical ? 0 : idx, Direction == Vertical ? idx : 0,
Direction == Vertical ? m_arg.rows() : Index(PacketSize),
Direction == Vertical ? Index(PacketSize) : m_arg.cols());
static constexpr int PacketSize = internal::unpacket_traits<PacketType>::size;
static constexpr int PanelRows = Direction == Vertical ? ArgType::RowsAtCompileTime : PacketSize;
static constexpr int PanelCols = Direction == Vertical ? PacketSize : ArgType::ColsAtCompileTime;
using PanelType = Block<const ArgTypeNestedCleaned, PanelRows, PanelCols, true /* InnerPanel */>;
using PanelEvaluator = typename internal::redux_evaluator<PanelType>;
using BinaryOp = typename MemberOp::BinaryOp;
using Impl = internal::packetwise_redux_impl<BinaryOp, PanelEvaluator>;
// FIXME
// See bug 1612, currently if PacketSize==1 (i.e. complex<double> with 128bits registers) then the storage-order of
@ -189,11 +205,39 @@ struct evaluator<PartialReduxExpr<ArgType, MemberOp, Direction> >
// by pass "vectorization" in this case:
if (PacketSize == 1) return internal::pset1<PacketType>(coeff(idx));
typedef typename internal::redux_evaluator<PanelType> PanelEvaluator;
Index startRow = Direction == Vertical ? 0 : idx;
Index startCol = Direction == Vertical ? idx : 0;
Index numRows = Direction == Vertical ? m_arg.rows() : PacketSize;
Index numCols = Direction == Vertical ? PacketSize : m_arg.cols();
PanelType panel(m_arg, startRow, startCol, numRows, numCols);
PanelEvaluator panel_eval(panel);
typedef typename MemberOp::BinaryOp BinaryOp;
PacketType p = internal::packetwise_redux_impl<BinaryOp, PanelEvaluator>::template run<PacketType>(
panel_eval, m_functor.binaryFunc(), m_arg.outerSize());
PacketType p = Impl::template run<PacketType>(panel_eval, m_functor.binaryFunc(), m_arg.outerSize());
return p;
}
template <int LoadMode, typename PacketType>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packetSegment(Index i, Index j, Index begin, Index count) const {
return packetSegment<LoadMode, PacketType>(Direction == Vertical ? j : i, begin, count);
}
template <int LoadMode, typename PacketType>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC PacketType packetSegment(Index idx, Index begin, Index count) const {
static constexpr int PanelRows = Direction == Vertical ? ArgType::RowsAtCompileTime : Dynamic;
static constexpr int PanelCols = Direction == Vertical ? Dynamic : ArgType::ColsAtCompileTime;
using PanelType = Block<const ArgTypeNestedCleaned, PanelRows, PanelCols, true /* InnerPanel */>;
using PanelEvaluator = typename internal::redux_evaluator<PanelType>;
using BinaryOp = typename MemberOp::BinaryOp;
using Impl = internal::packetwise_segment_redux_impl<BinaryOp, PanelEvaluator>;
Index startRow = Direction == Vertical ? 0 : idx;
Index startCol = Direction == Vertical ? idx : 0;
Index numRows = Direction == Vertical ? m_arg.rows() : begin + count;
Index numCols = Direction == Vertical ? begin + count : m_arg.cols();
PanelType panel(m_arg, startRow, startCol, numRows, numCols);
PanelEvaluator panel_eval(panel);
PacketType p = Impl::template run<PacketType>(panel_eval, m_functor.binaryFunc(), m_arg.outerSize(), begin, count);
return p;
}

View File

@ -414,6 +414,13 @@ class redux_evaluator : public internal::evaluator<XprType_> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packetByOuterInner(Index outer, Index inner) const {
return Base::template packet<LoadMode, PacketType>(IsRowMajor ? outer : inner, IsRowMajor ? inner : outer);
}
template <int LoadMode, typename PacketType>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packetSegmentByOuterInner(Index outer, Index inner, Index begin,
Index count) const {
return Base::template packetSegment<LoadMode, PacketType>(IsRowMajor ? outer : inner, IsRowMajor ? inner : outer,
begin, count);
}
};
} // end namespace internal

View File

@ -37,9 +37,6 @@ class PartialReduxExpr;
namespace internal {
template <typename ArgType, typename MemberOp, int Direction>
struct enable_packet_segment<PartialReduxExpr<ArgType, MemberOp, Direction>> : std::false_type {};
template <typename MatrixType, typename MemberOp, int Direction>
struct traits<PartialReduxExpr<MatrixType, MemberOp, Direction> > : traits<MatrixType> {
typedef typename MemberOp::result_type Scalar;

View File

@ -517,9 +517,6 @@ struct eigen_zero_impl;
template <typename Packet>
struct has_packet_segment : std::false_type {};
template <typename Xpr>
struct enable_packet_segment : std::true_type {};
} // namespace internal
} // end namespace Eigen

View File

@ -996,36 +996,6 @@ struct is_matrix_base_xpr : std::is_base_of<MatrixBase<remove_all_t<XprType>>, r
template <typename XprType>
struct is_permutation_base_xpr : std::is_base_of<PermutationBase<remove_all_t<XprType>>, remove_all_t<XprType>> {};
/*---------------- load/store segment support ----------------*/
// recursively traverse unary, binary, and ternary expressions to determine if packet segments are supported
template <typename Func, typename Xpr>
struct enable_packet_segment<CwiseNullaryOp<Func, Xpr>> : enable_packet_segment<remove_all_t<Xpr>> {};
template <typename Func, typename Xpr>
struct enable_packet_segment<CwiseUnaryOp<Func, Xpr>> : enable_packet_segment<remove_all_t<Xpr>> {};
template <typename Func, typename LhsXpr, typename RhsXpr>
struct enable_packet_segment<CwiseBinaryOp<Func, LhsXpr, RhsXpr>>
: bool_constant<enable_packet_segment<remove_all_t<LhsXpr>>::value &&
enable_packet_segment<remove_all_t<RhsXpr>>::value> {};
template <typename Func, typename LhsXpr, typename MidXpr, typename RhsXpr>
struct enable_packet_segment<CwiseTernaryOp<Func, LhsXpr, MidXpr, RhsXpr>>
: bool_constant<enable_packet_segment<remove_all_t<LhsXpr>>::value &&
enable_packet_segment<remove_all_t<MidXpr>>::value &&
enable_packet_segment<remove_all_t<RhsXpr>>::value> {};
template <typename Xpr>
struct enable_packet_segment<ArrayWrapper<Xpr>> : enable_packet_segment<remove_all_t<Xpr>> {};
template <typename Xpr>
struct enable_packet_segment<MatrixWrapper<Xpr>> : enable_packet_segment<remove_all_t<Xpr>> {};
template <typename Xpr>
struct enable_packet_segment<DiagonalWrapper<Xpr>> : enable_packet_segment<remove_all_t<Xpr>> {};
} // end namespace internal
/** \class ScalarBinaryOpTraits