fix reshape flag and test case

This commit is contained in:
yoco 2014-02-10 22:49:13 +08:00
parent b64a09acc1
commit 15f273b63c
2 changed files with 21 additions and 28 deletions

View File

@ -53,20 +53,16 @@ struct traits<Reshape<XprType, ReshapeRows, ReshapeCols> > : traits<XprType>
typedef typename traits<XprType>::Scalar Scalar; typedef typename traits<XprType>::Scalar Scalar;
typedef typename traits<XprType>::StorageKind StorageKind; typedef typename traits<XprType>::StorageKind StorageKind;
typedef typename traits<XprType>::XprKind XprKind; typedef typename traits<XprType>::XprKind XprKind;
typedef typename nested<XprType>::type XprTypeNested;
typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
enum{ enum{
MatrixRows = traits<XprType>::RowsAtCompileTime, MatrixRows = traits<XprType>::RowsAtCompileTime,
MatrixCols = traits<XprType>::ColsAtCompileTime, MatrixCols = traits<XprType>::ColsAtCompileTime,
RowsAtCompileTime = MatrixRows == 0 ? 0 : ReshapeRows, RowsAtCompileTime = ReshapeRows,
ColsAtCompileTime = MatrixCols == 0 ? 0 : ReshapeCols, ColsAtCompileTime = ReshapeCols,
MaxRowsAtCompileTime = ReshapeRows==0 ? 0 MaxRowsAtCompileTime = ReshapeRows,
: int(RowsAtCompileTime), MaxColsAtCompileTime = ReshapeCols,
MaxColsAtCompileTime = ReshapeCols==0 ? 0 XprTypeIsRowMajor = (int(traits<XprType>::Flags) & RowMajorBit) != 0,
: int(ColsAtCompileTime), IsRowMajor = (RowsAtCompileTime == 1 && ColsAtCompileTime != 1) ? 1
XprTypeIsRowMajor = (int(traits<XprType>::Flags)&RowMajorBit) != 0, : (ColsAtCompileTime == 1 && RowsAtCompileTime != 1) ? 0
IsRowMajor = (MaxRowsAtCompileTime==1&&MaxColsAtCompileTime!=1) ? 1
: (MaxColsAtCompileTime==1&&MaxRowsAtCompileTime!=1) ? 0
: XprTypeIsRowMajor, : XprTypeIsRowMajor,
HasSameStorageOrderAsXprType = (IsRowMajor == XprTypeIsRowMajor), HasSameStorageOrderAsXprType = (IsRowMajor == XprTypeIsRowMajor),
InnerSize = IsRowMajor ? int(ColsAtCompileTime) : int(RowsAtCompileTime), InnerSize = IsRowMajor ? int(ColsAtCompileTime) : int(RowsAtCompileTime),
@ -83,14 +79,11 @@ struct traits<Reshape<XprType, ReshapeRows, ReshapeCols> > : traits<XprType>
FlagsLinearAccessBit = (RowsAtCompileTime == 1 || ColsAtCompileTime == 1) ? LinearAccessBit : 0, FlagsLinearAccessBit = (RowsAtCompileTime == 1 || ColsAtCompileTime == 1) ? LinearAccessBit : 0,
FlagsLvalueBit = is_lvalue<XprType>::value ? LvalueBit : 0, FlagsLvalueBit = is_lvalue<XprType>::value ? LvalueBit : 0,
FlagsRowMajorBit = IsRowMajor ? RowMajorBit : 0, FlagsRowMajorBit = IsRowMajor ? RowMajorBit : 0,
IsSameShapeAtCompileTime = RowsAtCompileTime == ReshapeRows
&& ColsAtCompileTime == ReshapeCols
&& RowsAtCompileTime != Dynamic
&& ColsAtCompileTime != Dynamic,
Flags0 = traits<XprType>::Flags & ( (HereditaryBits & ~RowMajorBit) | Flags0 = traits<XprType>::Flags & ( (HereditaryBits & ~RowMajorBit) |
(traits<XprType>::Flags & ~DirectAccessBit) |
MaskPacketAccessBit | MaskPacketAccessBit |
MaskAlignedBit), MaskAlignedBit)
& ~DirectAccessBit,
Flags = (Flags0 | FlagsLinearAccessBit | FlagsLvalueBit | FlagsRowMajorBit) Flags = (Flags0 | FlagsLinearAccessBit | FlagsLvalueBit | FlagsRowMajorBit)
}; };
}; };

View File

@ -18,18 +18,18 @@ template <typename MatType>
void reshape_all_size(MatType m) { void reshape_all_size(MatType m) {
typedef Eigen::Map<MatrixXi> MapMat; typedef Eigen::Map<MatrixXi> MapMat;
// dynamic // dynamic
VERIFY_IS_EQUAL((m.template reshape( 1, 16)), MapMat(m.eval().data(), 1, 16)); VERIFY_IS_EQUAL((m.template reshape( 1, 16)), MapMat(m.data(), 1, 16));
VERIFY_IS_EQUAL((m.template reshape( 2, 8)), MapMat(m.eval().data(), 2, 8)); VERIFY_IS_EQUAL((m.template reshape( 2, 8)), MapMat(m.data(), 2, 8));
VERIFY_IS_EQUAL((m.template reshape( 4, 4)), MapMat(m.eval().data(), 4, 4)); VERIFY_IS_EQUAL((m.template reshape( 4, 4)), MapMat(m.data(), 4, 4));
VERIFY_IS_EQUAL((m.template reshape( 8, 2)), MapMat(m.eval().data(), 8, 2)); VERIFY_IS_EQUAL((m.template reshape( 8, 2)), MapMat(m.data(), 8, 2));
VERIFY_IS_EQUAL((m.template reshape(16, 1)), MapMat(m.eval().data(), 16, 1)); VERIFY_IS_EQUAL((m.template reshape(16, 1)), MapMat(m.data(), 16, 1));
// static // static
VERIFY_IS_EQUAL((m.template reshape< 1, 16>()), MapMat(m.eval().data(), 1, 16)); VERIFY_IS_EQUAL((m.template reshape< 1, 16>()), MapMat(m.data(), 1, 16));
VERIFY_IS_EQUAL((m.template reshape< 2, 8>()), MapMat(m.eval().data(), 2, 8)); VERIFY_IS_EQUAL((m.template reshape< 2, 8>()), MapMat(m.data(), 2, 8));
VERIFY_IS_EQUAL((m.template reshape< 4, 4>()), MapMat(m.eval().data(), 4, 4)); VERIFY_IS_EQUAL((m.template reshape< 4, 4>()), MapMat(m.data(), 4, 4));
VERIFY_IS_EQUAL((m.template reshape< 8, 2>()), MapMat(m.eval().data(), 8, 2)); VERIFY_IS_EQUAL((m.template reshape< 8, 2>()), MapMat(m.data(), 8, 2));
VERIFY_IS_EQUAL((m.template reshape<16, 1>()), MapMat(m.eval().data(), 16, 1)); VERIFY_IS_EQUAL((m.template reshape<16, 1>()), MapMat(m.data(), 16, 1));
// reshape chain // reshape chain
VERIFY_IS_EQUAL( VERIFY_IS_EQUAL(
@ -45,7 +45,7 @@ void reshape_all_size(MatType m) {
.template reshape( 8, 2) .template reshape( 8, 2)
.template reshape< 4, 4>() .template reshape< 4, 4>()
), ),
MapMat(m.eval().data(), 4, 4) MapMat(m.data(), 4, 4)
); );
} }