Make Reshaped fall back to MapBase when possible (same storage order and linear access to the nested expression)

This commit is contained in:
Gael Guennebaud 2017-02-11 15:32:53 +01:00
parent 83d6a529c3
commit 4b22048cea
3 changed files with 156 additions and 120 deletions

View File

@ -57,34 +57,37 @@ struct traits<Reshaped<XprType, Rows, Cols, Order> > : traits<XprType>
ColsAtCompileTime = Cols, ColsAtCompileTime = Cols,
MaxRowsAtCompileTime = Rows, MaxRowsAtCompileTime = Rows,
MaxColsAtCompileTime = Cols, MaxColsAtCompileTime = Cols,
XprTypeIsRowMajor = (int(traits<XprType>::Flags) & RowMajorBit) != 0, XpxStorageOrder = ((int(traits<XprType>::Flags) & RowMajorBit) == RowMajorBit) ? RowMajor : ColMajor,
IsRowMajor = (RowsAtCompileTime == 1 && ColsAtCompileTime != 1) ? 1 ReshapedStorageOrder = (RowsAtCompileTime == 1 && ColsAtCompileTime != 1) ? RowMajor
: (ColsAtCompileTime == 1 && RowsAtCompileTime != 1) ? 0 : (ColsAtCompileTime == 1 && RowsAtCompileTime != 1) ? ColMajor
: XprTypeIsRowMajor, : XpxStorageOrder,
HasSameStorageOrderAsXprType = (IsRowMajor == XprTypeIsRowMajor), HasSameStorageOrderAsXprType = (ReshapedStorageOrder == XpxStorageOrder),
InnerSize = IsRowMajor ? int(ColsAtCompileTime) : int(RowsAtCompileTime), InnerSize = (ReshapedStorageOrder==RowMajor) ? int(ColsAtCompileTime) : int(RowsAtCompileTime),
InnerStrideAtCompileTime = HasSameStorageOrderAsXprType InnerStrideAtCompileTime = HasSameStorageOrderAsXprType
? int(inner_stride_at_compile_time<XprType>::ret) ? int(inner_stride_at_compile_time<XprType>::ret)
: int(outer_stride_at_compile_time<XprType>::ret), : Dynamic,
OuterStrideAtCompileTime = HasSameStorageOrderAsXprType OuterStrideAtCompileTime = Dynamic,
? int(outer_stride_at_compile_time<XprType>::ret)
: int(inner_stride_at_compile_time<XprType>::ret), InOrder = Order,
HasDirectAccess = internal::has_direct_access<XprType>::ret
&& (Order==int(AutoOrderValue) || Order==int(XpxStorageOrder))
&& ((evaluator<XprType>::Flags&LinearAccessBit)==LinearAccessBit),
MaskPacketAccessBit = (InnerSize == Dynamic || (InnerSize % packet_traits<Scalar>::size) == 0) MaskPacketAccessBit = (InnerSize == Dynamic || (InnerSize % packet_traits<Scalar>::size) == 0)
&& (InnerStrideAtCompileTime == 1) && (InnerStrideAtCompileTime == 1)
? PacketAccessBit : 0, ? PacketAccessBit : 0,
//MaskAlignedBit = ((OuterStrideAtCompileTime!=Dynamic) && (((OuterStrideAtCompileTime * int(sizeof(Scalar))) % 16) == 0)) ? AlignedBit : 0, //MaskAlignedBit = ((OuterStrideAtCompileTime!=Dynamic) && (((OuterStrideAtCompileTime * int(sizeof(Scalar))) % 16) == 0)) ? AlignedBit : 0,
FlagsLinearAccessBit = (RowsAtCompileTime == 1 || ColsAtCompileTime == 1) ? LinearAccessBit : 0, FlagsLinearAccessBit = (RowsAtCompileTime == 1 || ColsAtCompileTime == 1) ? LinearAccessBit : 0,
FlagsLvalueBit = is_lvalue<XprType>::value ? LvalueBit : 0, FlagsLvalueBit = is_lvalue<XprType>::value ? LvalueBit : 0,
FlagsRowMajorBit = IsRowMajor ? RowMajorBit : 0, FlagsRowMajorBit = (ReshapedStorageOrder==RowMajor) ? RowMajorBit : 0,
Flags0 = traits<XprType>::Flags & ( (HereditaryBits & ~RowMajorBit) | MaskPacketAccessBit) FlagsDirectAccessBit = HasDirectAccess ? DirectAccessBit : 0,
& ~DirectAccessBit, Flags0 = traits<XprType>::Flags & ( (HereditaryBits & ~RowMajorBit) | MaskPacketAccessBit),
Flags = (Flags0 | FlagsLinearAccessBit | FlagsLvalueBit | FlagsRowMajorBit) Flags = (Flags0 | FlagsLinearAccessBit | FlagsLvalueBit | FlagsRowMajorBit | FlagsDirectAccessBit)
}; };
}; };
template<typename XprType, int Rows=Dynamic, int Cols=Dynamic, int Order = 0, template<typename XprType, int Rows, int Cols, int Order, bool HasDirectAccess> class ReshapedImpl_dense;
bool HasDirectAccess = internal::has_direct_access<XprType>::ret> class ReshapedImpl_dense;
} // end namespace internal } // end namespace internal
@ -127,9 +130,9 @@ template<typename XprType, int Rows, int Cols, int Order> class Reshaped
// that must be specialized for direct and non-direct access... // that must be specialized for direct and non-direct access...
template<typename XprType, int Rows, int Cols, int Order> template<typename XprType, int Rows, int Cols, int Order>
class ReshapedImpl<XprType, Rows, Cols, Order, Dense> class ReshapedImpl<XprType, Rows, Cols, Order, Dense>
: public internal::ReshapedImpl_dense<XprType, Rows, Cols, Order> : public internal::ReshapedImpl_dense<XprType, Rows, Cols, Order,internal::traits<Reshaped<XprType,Rows,Cols,Order> >::HasDirectAccess>
{ {
typedef internal::ReshapedImpl_dense<XprType, Rows, Cols, Order> Impl; typedef internal::ReshapedImpl_dense<XprType, Rows, Cols, Order,internal::traits<Reshaped<XprType,Rows,Cols,Order> >::HasDirectAccess> Impl;
public: public:
typedef Impl Base; typedef Impl Base;
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(ReshapedImpl) EIGEN_INHERIT_ASSIGNMENT_OPERATORS(ReshapedImpl)
@ -140,8 +143,9 @@ class ReshapedImpl<XprType, Rows, Cols, Order, Dense>
namespace internal { namespace internal {
/** \internal Internal implementation of dense Reshapeds in the general case. */ /** \internal Internal implementation of dense Reshaped in the general case. */
template<typename XprType, int Rows, int Cols, int Order, bool HasDirectAccess> class ReshapedImpl_dense template<typename XprType, int Rows, int Cols, int Order>
class ReshapedImpl_dense<XprType,Rows,Cols,Order,false>
: public internal::dense_xpr_base<Reshaped<XprType, Rows, Cols, Order> >::type : public internal::dense_xpr_base<Reshaped<XprType, Rows, Cols, Order> >::type
{ {
typedef Reshaped<XprType, Rows, Cols, Order> ReshapedType; typedef Reshaped<XprType, Rows, Cols, Order> ReshapedType;
@ -166,8 +170,7 @@ template<typename XprType, int Rows, int Cols, int Order, bool HasDirectAccess>
/** Dynamic-size constructor /** Dynamic-size constructor
*/ */
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
inline ReshapedImpl_dense(XprType& xpr, inline ReshapedImpl_dense(XprType& xpr, Index nRows, Index nCols)
Index nRows, Index nCols)
: m_xpr(xpr), m_rows(nRows), m_cols(nCols) : m_xpr(xpr), m_rows(nRows), m_cols(nCols)
{} {}
@ -199,8 +202,106 @@ template<typename XprType, int Rows, int Cols, int Order, bool HasDirectAccess>
}; };
/** \internal Internal implementation of dense Reshaped in the direct access case.
  *
  * This is the HasDirectAccess == true specialization: instead of remapping
  * indices coefficient by coefficient, the reshaped view is exposed as a
  * MapBase built directly on top of the nested expression's data pointer
  * (xpr.data()).
  *
  * \tparam XprType the type of the nested expression being reshaped
  * \tparam Rows    number of rows of the reshaped view at compile time (or Dynamic)
  * \tparam Cols    number of columns of the reshaped view at compile time (or Dynamic)
  * \tparam Order   the traversal order requested for the reshape
  */
template<typename XprType, int Rows, int Cols, int Order>
class ReshapedImpl_dense<XprType, Rows, Cols, Order, true>
  : public MapBase<Reshaped<XprType, Rows, Cols, Order> >
{
    typedef Reshaped<XprType, Rows, Cols, Order> ReshapedType;
    // Type used to store the nested expression as a member (see m_xpr).
    typedef typename internal::ref_selector<XprType>::non_const_type XprTypeNested;
  public:

    typedef MapBase<ReshapedType> Base;
    EIGEN_DENSE_PUBLIC_INTERFACE(ReshapedType)
    EIGEN_INHERIT_ASSIGNMENT_OPERATORS(ReshapedImpl_dense)

    /** Fixed-size constructor
      *
      * \param xpr the nested expression; its data pointer becomes the base
      *            address of the reshaped view.
      */
    EIGEN_DEVICE_FUNC
    inline ReshapedImpl_dense(XprType& xpr)
      : Base(xpr.data()), m_xpr(xpr)
    {}

    /** Dynamic-size constructor
      *
      * \param xpr   the nested expression; its data pointer becomes the base
      *              address of the reshaped view.
      * \param nRows number of rows of the reshaped view
      * \param nCols number of columns of the reshaped view
      */
    EIGEN_DEVICE_FUNC
    inline ReshapedImpl_dense(XprType& xpr, Index nRows, Index nCols)
      : Base(xpr.data(), nRows, nCols),
        m_xpr(xpr)
    {}

    /** \returns a const reference to the nested expression */
    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<XprTypeNested>::type& nestedExpression() const
    {
      return m_xpr;
    }

    /** \returns a reference to the nested expression */
    EIGEN_DEVICE_FUNC
    XprType& nestedExpression() { return m_xpr; }

    /** \sa MapBase::innerStride() */
    EIGEN_DEVICE_FUNC
    inline Index innerStride() const
    {
      // The view walks the nested expression's own buffer, so consecutive
      // inner coefficients are as far apart as they are in the nested expression.
      return m_xpr.innerStride();
    }

    /** \sa MapBase::outerStride() */
    EIGEN_DEVICE_FUNC
    inline Index outerStride() const
    {
      // One outer step skips a whole inner vector of the reshaped view. Its
      // length is cols() for a row-major view and rows() otherwise, and each of
      // its coefficients is innerStride() elements apart, so the distance is the
      // product of the two. Returning the inner size alone is only correct when
      // the nested expression has a unit inner stride.
      const Index innerSize = ((Flags&RowMajorBit)==RowMajorBit) ? this->cols() : this->rows();
      return innerSize * m_xpr.innerStride();
    }

  protected:

    // The nested expression, stored as XprTypeNested.
    XprTypeNested m_xpr;
};
// Evaluators
// Forward declaration of the evaluator implementation for Reshaped expressions.
// It is specialized on HasDirectAccess: one specialization for the generic
// (false) case and one for the direct-access (true) case.
template<typename ArgType, int Rows, int Cols, int Order, bool HasDirectAccess> struct reshaped_evaluator;
template<typename ArgType, int Rows, int Cols, int Order> template<typename ArgType, int Rows, int Cols, int Order>
struct unary_evaluator<Reshaped<ArgType, Rows, Cols, Order>, IndexBased> struct evaluator<Reshaped<ArgType, Rows, Cols, Order> >
: reshaped_evaluator<ArgType, Rows, Cols, Order, traits<Reshaped<ArgType,Rows,Cols,Order> >::HasDirectAccess>
{
typedef Reshaped<ArgType, Rows, Cols, Order> XprType;
typedef typename XprType::Scalar Scalar;
// TODO: should check for smaller packet types
typedef typename packet_traits<Scalar>::type PacketScalar;
enum {
CoeffReadCost = evaluator<ArgType>::CoeffReadCost,
HasDirectAccess = traits<XprType>::HasDirectAccess,
// RowsAtCompileTime = traits<XprType>::RowsAtCompileTime,
// ColsAtCompileTime = traits<XprType>::ColsAtCompileTime,
// MaxRowsAtCompileTime = traits<XprType>::MaxRowsAtCompileTime,
// MaxColsAtCompileTime = traits<XprType>::MaxColsAtCompileTime,
//
// InnerStrideAtCompileTime = traits<XprType>::HasSameStorageOrderAsXprType
// ? int(inner_stride_at_compile_time<ArgType>::ret)
// : Dynamic,
// OuterStrideAtCompileTime = Dynamic,
FlagsLinearAccessBit = (traits<XprType>::RowsAtCompileTime == 1 || traits<XprType>::ColsAtCompileTime == 1 || HasDirectAccess) ? LinearAccessBit : 0,
FlagsRowMajorBit = (traits<XprType>::ReshapedStorageOrder==RowMajor) ? RowMajorBit : 0,
FlagsDirectAccessBit = HasDirectAccess ? DirectAccessBit : 0,
Flags0 = evaluator<ArgType>::Flags & (HereditaryBits & ~RowMajorBit),
Flags = Flags0 | FlagsLinearAccessBit | FlagsRowMajorBit | FlagsDirectAccessBit,
PacketAlignment = unpacket_traits<PacketScalar>::alignment,
Alignment = evaluator<ArgType>::Alignment
};
typedef reshaped_evaluator<ArgType, Rows, Cols, Order, HasDirectAccess> reshaped_evaluator_type;
EIGEN_DEVICE_FUNC explicit evaluator(const XprType& xpr) : reshaped_evaluator_type(xpr)
{
EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
}
};
template<typename ArgType, int Rows, int Cols, int Order>
struct reshaped_evaluator<ArgType, Rows, Cols, Order, /* HasDirectAccess */ false>
: evaluator_base<Reshaped<ArgType, Rows, Cols, Order> > : evaluator_base<Reshaped<ArgType, Rows, Cols, Order> >
{ {
typedef Reshaped<ArgType, Rows, Cols, Order> XprType; typedef Reshaped<ArgType, Rows, Cols, Order> XprType;
@ -213,7 +314,7 @@ struct unary_evaluator<Reshaped<ArgType, Rows, Cols, Order>, IndexBased>
Alignment = 0 Alignment = 0
}; };
EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& xpr) : m_argImpl(xpr.nestedExpression()), m_xpr(xpr) EIGEN_DEVICE_FUNC explicit reshaped_evaluator(const XprType& xpr) : m_argImpl(xpr.nestedExpression()), m_xpr(xpr)
{ {
EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
} }
@ -321,103 +422,21 @@ protected:
}; };
template<typename ArgType, int Rows, int Cols, int Order>
struct reshaped_evaluator<ArgType, Rows, Cols, Order, /* HasDirectAccess */ true>
: mapbase_evaluator<Reshaped<ArgType, Rows, Cols, Order>,
typename Reshaped<ArgType, Rows, Cols, Order>::PlainObject>
{
typedef Reshaped<ArgType, Rows, Cols, Order> XprType;
typedef typename XprType::Scalar Scalar;
///** \internal Internal implementation of dense Reshapeds in the direct access case.*/ EIGEN_DEVICE_FUNC explicit reshaped_evaluator(const XprType& xpr)
//template<typename XprType, int Rows, int Cols, int Order> : mapbase_evaluator<XprType, typename XprType::PlainObject>(xpr)
//class ReshapedImpl_dense<XprType,ReshapedRows,ReshapedCols, true> {
// : public MapBase<Reshaped<XprType, Rows, Cols, Order> > // TODO: for the 3.4 release, this should be turned to an internal assertion, but let's keep it as is for the beta lifetime
//{ eigen_assert(((internal::UIntPtr(xpr.data()) % EIGEN_PLAIN_ENUM_MAX(1,evaluator<XprType>::Alignment)) == 0) && "data is not aligned");
// typedef Reshaped<XprType, Rows, Cols, Order> ReshapedType; }
// public: };
//
// typedef MapBase<ReshapedType> Base;
// EIGEN_DENSE_PUBLIC_INTERFACE(ReshapedType)
// EIGEN_INHERIT_ASSIGNMENT_OPERATORS(ReshapedImpl_dense)
//
// /** Column or Row constructor
// */
// EIGEN_DEVICE_FUNC
// inline ReshapedImpl_dense(XprType& xpr, Index i)
// : Base(internal::const_cast_ptr(&xpr.coeffRef(
// (ReshapedRows==1) && (ReshapedCols==XprType::ColsAtCompileTime) ? i : 0,
// (ReshapedRows==XprType::RowsAtCompileTime) && (ReshapedCols==1) ? i : 0)),
// ReshapedRows==1 ? 1 : xpr.rows(),
// ReshapedCols==1 ? 1 : xpr.cols()),
// m_xpr(xpr)
// {
// init();
// }
//
// /** Fixed-size constructor
// */
// EIGEN_DEVICE_FUNC
// inline ReshapedImpl_dense(XprType& xpr)
// : Base(internal::const_cast_ptr(&xpr.coeffRef(0, 0))), m_xpr(xpr)
// {
// init();
// }
//
// /** Dynamic-size constructor
// */
// EIGEN_DEVICE_FUNC
// inline ReshapedImpl_dense(XprType& xpr,
// Index reshapeRows, Index reshapeCols)
// : Base(internal::const_cast_ptr(&xpr.coeffRef(0, 0)), reshapeRows, reshapeCols),
// m_xpr(xpr)
// {
// init();
// }
//
// EIGEN_DEVICE_FUNC
// const typename internal::remove_all<typename XprType::Nested>::type& nestedExpression() const
// {
// return m_xpr;
// }
//
// EIGEN_DEVICE_FUNC
// /** \sa MapBase::innerStride() */
// inline Index innerStride() const
// {
// return internal::traits<ReshapedType>::HasSameStorageOrderAsXprType
// ? m_xpr.innerStride()
// : m_xpr.outerStride();
// }
//
// EIGEN_DEVICE_FUNC
// /** \sa MapBase::outerStride() */
// inline Index outerStride() const
// {
// return m_outerStride;
// }
//
// #ifndef __SUNPRO_CC
// // FIXME sunstudio is not friendly with the above friend...
// // META-FIXME there is no 'friend' keyword around here. Is this obsolete?
// protected:
// #endif
//
// #ifndef EIGEN_PARSED_BY_DOXYGEN
// /** \internal used by allowAligned() */
// EIGEN_DEVICE_FUNC
// inline ReshapedImpl_dense(XprType& xpr, const Scalar* data, Index reshapeRows, Index reshapeCols)
// : Base(data, reshapeRows, reshapeCols), m_xpr(xpr)
// {
// init();
// }
// #endif
//
// protected:
// EIGEN_DEVICE_FUNC
// void init()
// {
// m_outerStride = internal::traits<ReshapedType>::HasSameStorageOrderAsXprType
// ? m_xpr.outerStride()
// : m_xpr.innerStride();
// }
//
// typename XprType::Nested m_xpr;
// Index m_outerStride;
//};
} // end namespace internal } // end namespace internal

View File

@ -265,6 +265,11 @@ static const auto fix(int val);
#endif // EIGEN_PARSED_BY_DOXYGEN #endif // EIGEN_PARSED_BY_DOXYGEN
// Integer code of the "automatic" reshape order.
// NOTE(review): presumably chosen so that it differs from both ColMajor and
// RowMajor -- confirm against the storage-order enum.
const int AutoOrderValue = 2;
// Compile-time order tags, each an internal::FixedInt carrying the matching
// order value; they can be passed as the order argument of reshaped(),
// e.g. m.reshaped(2,8,RowOrder).
const internal::FixedInt<ColMajor> ColOrder;
const internal::FixedInt<RowMajor> RowOrder;
const internal::FixedInt<AutoOrderValue> AutoOrder;
} // end namespace Eigen } // end namespace Eigen
#endif // EIGEN_INTEGRAL_CONSTANT_H #endif // EIGEN_INTEGRAL_CONSTANT_H

View File

@ -48,6 +48,18 @@ void reshape_all_size(MatType m)
), ),
MapMat(m.data(), 4, 4) MapMat(m.data(), 4, 4)
); );
VERIFY_IS_EQUAL(m.reshaped( 1, 16).data(), m.data());
VERIFY_IS_EQUAL(m.reshaped( 1, 16).innerStride(), 1);
VERIFY_IS_EQUAL(m.reshaped( 2, 8).data(), m.data());
VERIFY_IS_EQUAL(m.reshaped( 2, 8).innerStride(), 1);
VERIFY_IS_EQUAL(m.reshaped( 2, 8).outerStride(), 2);
m.reshaped(2,8,ColOrder);
MatrixXi m28r = m.reshaped(2,8,RowOrder);
std::cout << m28r << "\n";
} }
void test_reshape() void test_reshape()