Update reshaped API to use RowMajor/ColMajor directly as integral values instead of introducing RowOrder/ColOrder types.

The API changed from A.respahed(rows,cols,RowOrder) to A.template reshaped<RowOrder>(rows,cols).
This commit is contained in:
Gael Guennebaud 2018-09-19 11:49:26 +02:00
parent 5c68ba41a8
commit dfa8439e4d
5 changed files with 50 additions and 54 deletions

View File

@ -62,7 +62,7 @@ struct traits<Reshaped<XprType, Rows, Cols, Order> > : traits<XprType>
: (ColsAtCompileTime == 1 && RowsAtCompileTime != 1) ? ColMajor : (ColsAtCompileTime == 1 && RowsAtCompileTime != 1) ? ColMajor
: XpxStorageOrder, : XpxStorageOrder,
HasSameStorageOrderAsXprType = (ReshapedStorageOrder == XpxStorageOrder), HasSameStorageOrderAsXprType = (ReshapedStorageOrder == XpxStorageOrder),
InnerSize = (ReshapedStorageOrder==RowMajor) ? int(ColsAtCompileTime) : int(RowsAtCompileTime), InnerSize = (ReshapedStorageOrder==int(RowMajor)) ? int(ColsAtCompileTime) : int(RowsAtCompileTime),
InnerStrideAtCompileTime = HasSameStorageOrderAsXprType InnerStrideAtCompileTime = HasSameStorageOrderAsXprType
? int(inner_stride_at_compile_time<XprType>::ret) ? int(inner_stride_at_compile_time<XprType>::ret)
: Dynamic, : Dynamic,
@ -78,7 +78,7 @@ struct traits<Reshaped<XprType, Rows, Cols, Order> > : traits<XprType>
//MaskAlignedBit = ((OuterStrideAtCompileTime!=Dynamic) && (((OuterStrideAtCompileTime * int(sizeof(Scalar))) % 16) == 0)) ? AlignedBit : 0, //MaskAlignedBit = ((OuterStrideAtCompileTime!=Dynamic) && (((OuterStrideAtCompileTime * int(sizeof(Scalar))) % 16) == 0)) ? AlignedBit : 0,
FlagsLinearAccessBit = (RowsAtCompileTime == 1 || ColsAtCompileTime == 1) ? LinearAccessBit : 0, FlagsLinearAccessBit = (RowsAtCompileTime == 1 || ColsAtCompileTime == 1) ? LinearAccessBit : 0,
FlagsLvalueBit = is_lvalue<XprType>::value ? LvalueBit : 0, FlagsLvalueBit = is_lvalue<XprType>::value ? LvalueBit : 0,
FlagsRowMajorBit = (ReshapedStorageOrder==RowMajor) ? RowMajorBit : 0, FlagsRowMajorBit = (ReshapedStorageOrder==int(RowMajor)) ? RowMajorBit : 0,
FlagsDirectAccessBit = HasDirectAccess ? DirectAccessBit : 0, FlagsDirectAccessBit = HasDirectAccess ? DirectAccessBit : 0,
Flags0 = traits<XprType>::Flags & ( (HereditaryBits & ~RowMajorBit) | MaskPacketAccessBit), Flags0 = traits<XprType>::Flags & ( (HereditaryBits & ~RowMajorBit) | MaskPacketAccessBit),
@ -284,7 +284,7 @@ struct evaluator<Reshaped<ArgType, Rows, Cols, Order> >
// OuterStrideAtCompileTime = Dynamic, // OuterStrideAtCompileTime = Dynamic,
FlagsLinearAccessBit = (traits<XprType>::RowsAtCompileTime == 1 || traits<XprType>::ColsAtCompileTime == 1 || HasDirectAccess) ? LinearAccessBit : 0, FlagsLinearAccessBit = (traits<XprType>::RowsAtCompileTime == 1 || traits<XprType>::ColsAtCompileTime == 1 || HasDirectAccess) ? LinearAccessBit : 0,
FlagsRowMajorBit = (traits<XprType>::ReshapedStorageOrder==RowMajor) ? RowMajorBit : 0, FlagsRowMajorBit = (traits<XprType>::ReshapedStorageOrder==int(RowMajor)) ? RowMajorBit : 0,
FlagsDirectAccessBit = HasDirectAccess ? DirectAccessBit : 0, FlagsDirectAccessBit = HasDirectAccess ? DirectAccessBit : 0,
Flags0 = evaluator<ArgType>::Flags & (HereditaryBits & ~RowMajorBit), Flags0 = evaluator<ArgType>::Flags & (HereditaryBits & ~RowMajorBit),
Flags = Flags0 | FlagsLinearAccessBit | FlagsRowMajorBit | FlagsDirectAccessBit, Flags = Flags0 | FlagsLinearAccessBit | FlagsRowMajorBit | FlagsDirectAccessBit,

View File

@ -265,11 +265,6 @@ static const auto fix(int val);
#endif // EIGEN_PARSED_BY_DOXYGEN #endif // EIGEN_PARSED_BY_DOXYGEN
const int AutoOrderValue = 2;
const internal::FixedInt<ColMajor> ColOrder;
const internal::FixedInt<RowMajor> RowOrder;
const internal::FixedInt<AutoOrderValue> AutoOrder;
} // end namespace Eigen } // end namespace Eigen
#endif // EIGEN_INTEGRAL_CONSTANT_H #endif // EIGEN_INTEGRAL_CONSTANT_H

View File

@ -14,6 +14,7 @@
namespace Eigen { namespace Eigen {
enum AutoSize_t { AutoSize }; enum AutoSize_t { AutoSize };
const int AutoOrder = 2;
namespace internal { namespace internal {

View File

@ -27,16 +27,16 @@
/// \sa operator()(placeholders::all), class Reshaped, fix, fix<N>(int) /// \sa operator()(placeholders::all), class Reshaped, fix, fix<N>(int)
/// ///
#ifdef EIGEN_PARSED_BY_DOXYGEN #ifdef EIGEN_PARSED_BY_DOXYGEN
template<typename NRowsType, typename NColsType, typename OrderType = ColOrder> template<int Order = ColMajor, typename NRowsType, typename NColsType>
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
inline Reshaped<Derived,...> inline Reshaped<Derived,...>
reshaped(NRowsType nRows, NColsType nCols, OrderType order = ColOrder); reshaped(NRowsType nRows, NColsType nCols);
/** This is the const version of reshaped(NRowsType,NColsType). */ /** This is the const version of reshaped(NRowsType,NColsType). */
template<typename NRowsType, typename NColsType, typename OrderType = ColOrder> template<int Order = ColMajor, typename NRowsType, typename NColsType>
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
inline const Reshaped<const Derived,...> inline const Reshaped<const Derived,...>
reshaped(NRowsType nRows, NColsType nCols, OrderType order = ColOrder) const; reshaped(NRowsType nRows, NColsType nCols) const;
/// \returns as expression of \c *this with columns stacked to a linear column vector /// \returns as expression of \c *this with columns stacked to a linear column vector
/// ///
@ -83,18 +83,18 @@ reshaped(NRowsType nRows, NColsType nCols) EIGEN_RESHAPED_METHOD_CONST
internal::get_runtime_reshape_size(nCols,internal::get_runtime_value(nRows),size())); internal::get_runtime_reshape_size(nCols,internal::get_runtime_value(nRows),size()));
} }
template<typename NRowsType, typename NColsType, typename OrderType> template<int Order, typename NRowsType, typename NColsType>
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
inline Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived, inline Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived,
internal::get_compiletime_reshape_size<NRowsType,NColsType,SizeAtCompileTime>::value, internal::get_compiletime_reshape_size<NRowsType,NColsType,SizeAtCompileTime>::value,
internal::get_compiletime_reshape_size<NColsType,NRowsType,SizeAtCompileTime>::value, internal::get_compiletime_reshape_size<NColsType,NRowsType,SizeAtCompileTime>::value,
OrderType::value==AutoOrderValue?Flags&RowMajorBit:OrderType::value> Order==AutoOrder?Flags&RowMajorBit:Order>
reshaped(NRowsType nRows, NColsType nCols, OrderType) EIGEN_RESHAPED_METHOD_CONST reshaped(NRowsType nRows, NColsType nCols) EIGEN_RESHAPED_METHOD_CONST
{ {
return Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived, return Reshaped<EIGEN_RESHAPED_METHOD_CONST Derived,
internal::get_compiletime_reshape_size<NRowsType,NColsType,SizeAtCompileTime>::value, internal::get_compiletime_reshape_size<NRowsType,NColsType,SizeAtCompileTime>::value,
internal::get_compiletime_reshape_size<NColsType,NRowsType,SizeAtCompileTime>::value, internal::get_compiletime_reshape_size<NColsType,NRowsType,SizeAtCompileTime>::value,
OrderType::value==AutoOrderValue?Flags&RowMajorBit:OrderType::value> Order==AutoOrder?Flags&RowMajorBit:Order>
(derived(), (derived(),
internal::get_runtime_reshape_size(nRows,internal::get_runtime_value(nCols),size()), internal::get_runtime_reshape_size(nRows,internal::get_runtime_value(nCols),size()),
internal::get_runtime_reshape_size(nCols,internal::get_runtime_value(nRows),size())); internal::get_runtime_reshape_size(nCols,internal::get_runtime_value(nRows),size()));

View File

@ -17,8 +17,8 @@ is_same_eq(const T1& a, const T2& b)
return (a.array() == b.array()).all(); return (a.array() == b.array()).all();
} }
template <typename MatType,typename OrderType> template <int Order,typename MatType>
void check_auto_reshape4x4(MatType m,OrderType order) void check_auto_reshape4x4(MatType m)
{ {
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1: 1> v1( 1); internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1: 1> v1( 1);
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1: 2> v2( 2); internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1: 2> v2( 2);
@ -26,27 +26,27 @@ void check_auto_reshape4x4(MatType m,OrderType order)
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1: 8> v8( 8); internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1: 8> v8( 8);
internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1:16> v16(16); internal::VariableAndFixedInt<MatType::SizeAtCompileTime==Dynamic?-1:16> v16(16);
VERIFY(is_same_eq(m.reshaped( 1, AutoSize, order), m.reshaped( 1, 16, order))); VERIFY(is_same_eq(m.template reshaped<Order>( 1, AutoSize), m.template reshaped<Order>( 1, 16)));
VERIFY(is_same_eq(m.reshaped(AutoSize, 16, order), m.reshaped( 1, 16, order))); VERIFY(is_same_eq(m.template reshaped<Order>(AutoSize, 16 ), m.template reshaped<Order>( 1, 16)));
VERIFY(is_same_eq(m.reshaped( 2, AutoSize, order), m.reshaped( 2, 8, order))); VERIFY(is_same_eq(m.template reshaped<Order>( 2, AutoSize), m.template reshaped<Order>( 2, 8)));
VERIFY(is_same_eq(m.reshaped(AutoSize, 8, order), m.reshaped( 2, 8, order))); VERIFY(is_same_eq(m.template reshaped<Order>(AutoSize, 8 ), m.template reshaped<Order>( 2, 8)));
VERIFY(is_same_eq(m.reshaped( 4, AutoSize, order), m.reshaped( 4, 4, order))); VERIFY(is_same_eq(m.template reshaped<Order>( 4, AutoSize), m.template reshaped<Order>( 4, 4)));
VERIFY(is_same_eq(m.reshaped(AutoSize, 4, order), m.reshaped( 4, 4, order))); VERIFY(is_same_eq(m.template reshaped<Order>(AutoSize, 4 ), m.template reshaped<Order>( 4, 4)));
VERIFY(is_same_eq(m.reshaped( 8, AutoSize, order), m.reshaped( 8, 2, order))); VERIFY(is_same_eq(m.template reshaped<Order>( 8, AutoSize), m.template reshaped<Order>( 8, 2)));
VERIFY(is_same_eq(m.reshaped(AutoSize, 2, order), m.reshaped( 8, 2, order))); VERIFY(is_same_eq(m.template reshaped<Order>(AutoSize, 2 ), m.template reshaped<Order>( 8, 2)));
VERIFY(is_same_eq(m.reshaped(16, AutoSize, order), m.reshaped(16, 1, order))); VERIFY(is_same_eq(m.template reshaped<Order>(16, AutoSize), m.template reshaped<Order>(16, 1)));
VERIFY(is_same_eq(m.reshaped(AutoSize, 1, order), m.reshaped(16, 1, order))); VERIFY(is_same_eq(m.template reshaped<Order>(AutoSize, 1 ), m.template reshaped<Order>(16, 1)));
VERIFY(is_same_eq(m.reshaped(fix< 1>, AutoSize, order), m.reshaped(fix< 1>, v16, order))); VERIFY(is_same_eq(m.template reshaped<Order>(fix< 1>, AutoSize), m.template reshaped<Order>(fix< 1>, v16 )));
VERIFY(is_same_eq(m.reshaped(AutoSize, fix<16>, order), m.reshaped( v1, fix<16>, order))); VERIFY(is_same_eq(m.template reshaped<Order>(AutoSize, fix<16> ), m.template reshaped<Order>( v1, fix<16>)));
VERIFY(is_same_eq(m.reshaped(fix< 2>, AutoSize, order), m.reshaped(fix< 2>, v8, order))); VERIFY(is_same_eq(m.template reshaped<Order>(fix< 2>, AutoSize), m.template reshaped<Order>(fix< 2>, v8 )));
VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 8>, order), m.reshaped( v2, fix< 8>, order))); VERIFY(is_same_eq(m.template reshaped<Order>(AutoSize, fix< 8> ), m.template reshaped<Order>( v2, fix< 8>)));
VERIFY(is_same_eq(m.reshaped(fix< 4>, AutoSize, order), m.reshaped(fix< 4>, v4, order))); VERIFY(is_same_eq(m.template reshaped<Order>(fix< 4>, AutoSize), m.template reshaped<Order>(fix< 4>, v4 )));
VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 4>, order), m.reshaped( v4, fix< 4>, order))); VERIFY(is_same_eq(m.template reshaped<Order>(AutoSize, fix< 4> ), m.template reshaped<Order>( v4, fix< 4>)));
VERIFY(is_same_eq(m.reshaped(fix< 8>, AutoSize, order), m.reshaped(fix< 8>, v2, order))); VERIFY(is_same_eq(m.template reshaped<Order>(fix< 8>, AutoSize), m.template reshaped<Order>(fix< 8>, v2 )));
VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 2>, order), m.reshaped( v8, fix< 2>, order))); VERIFY(is_same_eq(m.template reshaped<Order>(AutoSize, fix< 2> ), m.template reshaped<Order>( v8, fix< 2>)));
VERIFY(is_same_eq(m.reshaped(fix<16>, AutoSize, order), m.reshaped(fix<16>, v1, order))); VERIFY(is_same_eq(m.template reshaped<Order>(fix<16>, AutoSize), m.template reshaped<Order>(fix<16>, v1 )));
VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 1>, order), m.reshaped(v16, fix< 1>, order))); VERIFY(is_same_eq(m.template reshaped<Order>(AutoSize, fix< 1> ), m.template reshaped<Order>(v16, fix< 1>)));
} }
// just test a 4x4 matrix, enumerate all combination manually // just test a 4x4 matrix, enumerate all combination manually
@ -117,12 +117,12 @@ void reshape4x4(MatType m)
VERIFY(is_same_eq(m.reshaped(fix<16>, AutoSize), m.reshaped(fix<16>, v1))); VERIFY(is_same_eq(m.reshaped(fix<16>, AutoSize), m.reshaped(fix<16>, v1)));
VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 1>), m.reshaped(v16, fix< 1>))); VERIFY(is_same_eq(m.reshaped(AutoSize, fix< 1>), m.reshaped(v16, fix< 1>)));
check_auto_reshape4x4(m,ColOrder); check_auto_reshape4x4<ColMajor> (m);
check_auto_reshape4x4(m,RowOrder); check_auto_reshape4x4<RowMajor> (m);
check_auto_reshape4x4(m,AutoOrder); check_auto_reshape4x4<AutoOrder>(m);
check_auto_reshape4x4(m.transpose(),ColOrder); check_auto_reshape4x4<ColMajor> (m.transpose());
check_auto_reshape4x4(m.transpose(),RowOrder); check_auto_reshape4x4<ColMajor> (m.transpose());
check_auto_reshape4x4(m.transpose(),AutoOrder); check_auto_reshape4x4<AutoOrder>(m.transpose());
VERIFY_IS_EQUAL(m.reshaped( 1, 16).data(), m.data()); VERIFY_IS_EQUAL(m.reshaped( 1, 16).data(), m.data());
VERIFY_IS_EQUAL(m.reshaped( 1, 16).innerStride(), 1); VERIFY_IS_EQUAL(m.reshaped( 1, 16).innerStride(), 1);
@ -133,20 +133,20 @@ void reshape4x4(MatType m)
if((MatType::Flags&RowMajorBit)==0) if((MatType::Flags&RowMajorBit)==0)
{ {
VERIFY_IS_EQUAL(m.reshaped(2,8,ColOrder),m.reshaped(2,8)); VERIFY_IS_EQUAL(m.template reshaped<ColMajor>(2,8),m.reshaped(2,8));
VERIFY_IS_EQUAL(m.reshaped(2,8,ColOrder),m.reshaped(2,8,AutoOrder)); VERIFY_IS_EQUAL(m.template reshaped<ColMajor>(2,8),m.template reshaped<AutoOrder>(2,8));
VERIFY_IS_EQUAL(m.transpose().reshaped(2,8,RowOrder),m.transpose().reshaped(2,8,AutoOrder)); VERIFY_IS_EQUAL(m.transpose().template reshaped<RowMajor>(2,8),m.transpose().template reshaped<AutoOrder>(2,8));
} }
else else
{ {
VERIFY_IS_EQUAL(m.reshaped(2,8,ColOrder),m.reshaped(2,8)); VERIFY_IS_EQUAL(m.template reshaped<ColMajor>(2,8),m.reshaped(2,8));
VERIFY_IS_EQUAL(m.reshaped(2,8,RowOrder),m.reshaped(2,8,AutoOrder)); VERIFY_IS_EQUAL(m.template reshaped<RowMajor>(2,8),m.template reshaped<AutoOrder>(2,8));
VERIFY_IS_EQUAL(m.transpose().reshaped(2,8,ColOrder),m.transpose().reshaped(2,8,AutoOrder)); VERIFY_IS_EQUAL(m.transpose().template reshaped<ColMajor>(2,8),m.transpose().template reshaped<AutoOrder>(2,8));
VERIFY_IS_EQUAL(m.transpose().reshaped(2,8),m.transpose().reshaped(2,8,AutoOrder)); VERIFY_IS_EQUAL(m.transpose().reshaped(2,8),m.transpose().template reshaped<AutoOrder>(2,8));
} }
MatrixXi m28r1 = m.reshaped(2,8,RowOrder); MatrixXi m28r1 = m.template reshaped<RowMajor>(2,8);
MatrixXi m28r2 = m.transpose().reshaped(8,2,ColOrder).transpose(); MatrixXi m28r2 = m.transpose().template reshaped<ColMajor>(8,2).transpose();
VERIFY_IS_EQUAL( m28r1, m28r2); VERIFY_IS_EQUAL( m28r1, m28r2);
using placeholders::all; using placeholders::all;
@ -158,7 +158,7 @@ void reshape4x4(MatType m)
VERIFY_IS_EQUAL(m(all).reshaped(8,2), m.reshaped(8,2)); VERIFY_IS_EQUAL(m(all).reshaped(8,2), m.reshaped(8,2));
VERIFY(is_same_eq(m.reshaped(AutoSize,fix<1>), m(all))); VERIFY(is_same_eq(m.reshaped(AutoSize,fix<1>), m(all)));
VERIFY_IS_EQUAL(m.reshaped(fix<1>,AutoSize,RowOrder), m.transpose()(all).transpose()); VERIFY_IS_EQUAL(m.template reshaped<RowMajor>(fix<1>,AutoSize), m.transpose()(all).transpose());
} }
void test_reshape() void test_reshape()