Enable half-packet in reductions.

This commit is contained in:
Gael Guennebaud 2016-04-13 13:02:34 +02:00
parent e9b12cc1f7
commit bbb8854bf7

View File

@ -27,8 +27,9 @@ template<typename Func, typename Derived>
struct redux_traits struct redux_traits
{ {
public: public:
typedef typename find_best_packet<typename Derived::Scalar,Derived::SizeAtCompileTime>::type PacketType;
enum { enum {
PacketSize = packet_traits<typename Derived::Scalar>::size, PacketSize = unpacket_traits<PacketType>::size,
InnerMaxSize = int(Derived::IsRowMajor) InnerMaxSize = int(Derived::IsRowMajor)
? Derived::MaxColsAtCompileTime ? Derived::MaxColsAtCompileTime
: Derived::MaxRowsAtCompileTime : Derived::MaxRowsAtCompileTime
@ -137,12 +138,12 @@ template<typename Func, typename Derived, int Start, int Length>
struct redux_vec_unroller struct redux_vec_unroller
{ {
enum { enum {
PacketSize = packet_traits<typename Derived::Scalar>::size, PacketSize = redux_traits<Func, Derived>::PacketSize,
HalfLength = Length/2 HalfLength = Length/2
}; };
typedef typename Derived::Scalar Scalar; typedef typename Derived::Scalar Scalar;
typedef typename packet_traits<Scalar>::type PacketScalar; typedef typename redux_traits<Func, Derived>::PacketType PacketScalar;
static EIGEN_STRONG_INLINE PacketScalar run(const Derived &mat, const Func& func) static EIGEN_STRONG_INLINE PacketScalar run(const Derived &mat, const Func& func)
{ {
@ -156,14 +157,14 @@ template<typename Func, typename Derived, int Start>
struct redux_vec_unroller<Func, Derived, Start, 1> struct redux_vec_unroller<Func, Derived, Start, 1>
{ {
enum { enum {
index = Start * packet_traits<typename Derived::Scalar>::size, index = Start * redux_traits<Func, Derived>::PacketSize,
outer = index / int(Derived::InnerSizeAtCompileTime), outer = index / int(Derived::InnerSizeAtCompileTime),
inner = index % int(Derived::InnerSizeAtCompileTime), inner = index % int(Derived::InnerSizeAtCompileTime),
alignment = Derived::Alignment alignment = Derived::Alignment
}; };
typedef typename Derived::Scalar Scalar; typedef typename Derived::Scalar Scalar;
typedef typename packet_traits<Scalar>::type PacketScalar; typedef typename redux_traits<Func, Derived>::PacketType PacketScalar;
static EIGEN_STRONG_INLINE PacketScalar run(const Derived &mat, const Func&) static EIGEN_STRONG_INLINE PacketScalar run(const Derived &mat, const Func&)
{ {
@ -209,13 +210,13 @@ template<typename Func, typename Derived>
struct redux_impl<Func, Derived, LinearVectorizedTraversal, NoUnrolling> struct redux_impl<Func, Derived, LinearVectorizedTraversal, NoUnrolling>
{ {
typedef typename Derived::Scalar Scalar; typedef typename Derived::Scalar Scalar;
typedef typename packet_traits<Scalar>::type PacketScalar; typedef typename redux_traits<Func, Derived>::PacketType PacketScalar;
static Scalar run(const Derived &mat, const Func& func) static Scalar run(const Derived &mat, const Func& func)
{ {
const Index size = mat.size(); const Index size = mat.size();
const Index packetSize = packet_traits<Scalar>::size; const Index packetSize = redux_traits<Func, Derived>::PacketSize;
const int packetAlignment = unpacket_traits<PacketScalar>::alignment; const int packetAlignment = unpacket_traits<PacketScalar>::alignment;
enum { enum {
alignment0 = (bool(Derived::Flags & DirectAccessBit) && bool(packet_traits<Scalar>::AlignedOnScalar)) ? int(packetAlignment) : int(Unaligned), alignment0 = (bool(Derived::Flags & DirectAccessBit) && bool(packet_traits<Scalar>::AlignedOnScalar)) ? int(packetAlignment) : int(Unaligned),
@ -268,7 +269,7 @@ template<typename Func, typename Derived, int Unrolling>
struct redux_impl<Func, Derived, SliceVectorizedTraversal, Unrolling> struct redux_impl<Func, Derived, SliceVectorizedTraversal, Unrolling>
{ {
typedef typename Derived::Scalar Scalar; typedef typename Derived::Scalar Scalar;
typedef typename packet_traits<Scalar>::type PacketType; typedef typename redux_traits<Func, Derived>::PacketType PacketType;
EIGEN_DEVICE_FUNC static Scalar run(const Derived &mat, const Func& func) EIGEN_DEVICE_FUNC static Scalar run(const Derived &mat, const Func& func)
{ {
@ -276,7 +277,7 @@ struct redux_impl<Func, Derived, SliceVectorizedTraversal, Unrolling>
const Index innerSize = mat.innerSize(); const Index innerSize = mat.innerSize();
const Index outerSize = mat.outerSize(); const Index outerSize = mat.outerSize();
enum { enum {
packetSize = packet_traits<Scalar>::size packetSize = redux_traits<Func, Derived>::PacketSize
}; };
const Index packetedInnerSize = ((innerSize)/packetSize)*packetSize; const Index packetedInnerSize = ((innerSize)/packetSize)*packetSize;
Scalar res; Scalar res;
@ -306,9 +307,10 @@ template<typename Func, typename Derived>
struct redux_impl<Func, Derived, LinearVectorizedTraversal, CompleteUnrolling> struct redux_impl<Func, Derived, LinearVectorizedTraversal, CompleteUnrolling>
{ {
typedef typename Derived::Scalar Scalar; typedef typename Derived::Scalar Scalar;
typedef typename packet_traits<Scalar>::type PacketScalar;
typedef typename redux_traits<Func, Derived>::PacketType PacketScalar;
enum { enum {
PacketSize = packet_traits<Scalar>::size, PacketSize = redux_traits<Func, Derived>::PacketSize,
Size = Derived::SizeAtCompileTime, Size = Derived::SizeAtCompileTime,
VectorizedSize = (Size / PacketSize) * PacketSize VectorizedSize = (Size / PacketSize) * PacketSize
}; };
@ -367,11 +369,11 @@ public:
{ return m_evaluator.coeff(index); } { return m_evaluator.coeff(index); }
template<int LoadMode, typename PacketType> template<int LoadMode, typename PacketType>
PacketReturnType packet(Index row, Index col) const PacketType packet(Index row, Index col) const
{ return m_evaluator.template packet<LoadMode,PacketType>(row, col); } { return m_evaluator.template packet<LoadMode,PacketType>(row, col); }
template<int LoadMode, typename PacketType> template<int LoadMode, typename PacketType>
PacketReturnType packet(Index index) const PacketType packet(Index index) const
{ return m_evaluator.template packet<LoadMode,PacketType>(index); } { return m_evaluator.template packet<LoadMode,PacketType>(index); }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
@ -379,7 +381,7 @@ public:
{ return m_evaluator.coeff(IsRowMajor ? outer : inner, IsRowMajor ? inner : outer); } { return m_evaluator.coeff(IsRowMajor ? outer : inner, IsRowMajor ? inner : outer); }
template<int LoadMode, typename PacketType> template<int LoadMode, typename PacketType>
PacketReturnType packetByOuterInner(Index outer, Index inner) const PacketType packetByOuterInner(Index outer, Index inner) const
{ return m_evaluator.template packet<LoadMode,PacketType>(IsRowMajor ? outer : inner, IsRowMajor ? inner : outer); } { return m_evaluator.template packet<LoadMode,PacketType>(IsRowMajor ? outer : inner, IsRowMajor ? inner : outer); }
const XprType & nestedExpression() const { return m_xpr; } const XprType & nestedExpression() const { return m_xpr; }