From bb4b67cf39b2a0aae2c912e1fbad50c1cf3b1ab6 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Thu, 13 Mar 2014 18:04:19 +0100 Subject: [PATCH] Relax Ref such that Ref accepts a RowVectorXf which can be seen as a degenerate MatrixXf(1,N) --- Eigen/src/Core/Ref.h | 10 +++-- test/ref.cpp | 104 ++++++++++++++++++++++++++----------------- 2 files changed, 69 insertions(+), 45 deletions(-) diff --git a/Eigen/src/Core/Ref.h b/Eigen/src/Core/Ref.h index 00d9e6d2b..cd6d949c4 100644 --- a/Eigen/src/Core/Ref.h +++ b/Eigen/src/Core/Ref.h @@ -101,7 +101,7 @@ struct traits > template struct match { enum { HasDirectAccess = internal::has_direct_access::ret, - StorageOrderMatch = PlainObjectType::IsVectorAtCompileTime || ((PlainObjectType::Flags&RowMajorBit)==(Derived::Flags&RowMajorBit)), + StorageOrderMatch = PlainObjectType::IsVectorAtCompileTime || Derived::IsVectorAtCompileTime || ((PlainObjectType::Flags&RowMajorBit)==(Derived::Flags&RowMajorBit)), InnerStrideMatch = int(StrideType::InnerStrideAtCompileTime)==int(Dynamic) || int(StrideType::InnerStrideAtCompileTime)==int(Derived::InnerStrideAtCompileTime) || (int(StrideType::InnerStrideAtCompileTime)==0 && int(Derived::InnerStrideAtCompileTime)==1), @@ -172,8 +172,12 @@ protected: } else ::new (static_cast(this)) Base(expr.data(), expr.rows(), expr.cols()); - ::new (&m_stride) StrideBase(StrideType::OuterStrideAtCompileTime==0?0:expr.outerStride(), - StrideType::InnerStrideAtCompileTime==0?0:expr.innerStride()); + + if(Expression::IsVectorAtCompileTime && (!PlainObjectType::IsVectorAtCompileTime) && ((Expression::Flags&RowMajorBit)!=(PlainObjectType::Flags&RowMajorBit))) + ::new (&m_stride) StrideBase(expr.innerStride(), StrideType::InnerStrideAtCompileTime==0?0:1); + else + ::new (&m_stride) StrideBase(StrideType::OuterStrideAtCompileTime==0?0:expr.outerStride(), + StrideType::InnerStrideAtCompileTime==0?0:expr.innerStride()); } StrideBase m_stride; diff --git a/test/ref.cpp b/test/ref.cpp index f639d900b..19e81549c 100644 --- a/test/ref.cpp +++ b/test/ref.cpp @@ -154,59 +154,79 @@ template void check_const_correctness(const PlainObjec VERIFY( !(Ref::Flags & LvalueBit) ); } -EIGEN_DONT_INLINE void call_ref_1(Ref ) { } -EIGEN_DONT_INLINE void call_ref_2(const Ref& ) { } -EIGEN_DONT_INLINE void call_ref_3(Ref > ) { } -EIGEN_DONT_INLINE void call_ref_4(const Ref >& ) { } -EIGEN_DONT_INLINE void call_ref_5(Ref > ) { } -EIGEN_DONT_INLINE void call_ref_6(const Ref >& ) { } +template +EIGEN_DONT_INLINE void call_ref_1(Ref a, const B &b) { VERIFY_IS_EQUAL(a,b); } +template +EIGEN_DONT_INLINE void call_ref_2(const Ref& a, const B &b) { VERIFY_IS_EQUAL(a,b); } +template +EIGEN_DONT_INLINE void call_ref_3(Ref > a, const B &b) { VERIFY_IS_EQUAL(a,b); } +template +EIGEN_DONT_INLINE void call_ref_4(const Ref >& a, const B &b) { VERIFY_IS_EQUAL(a,b); } +template +EIGEN_DONT_INLINE void call_ref_5(Ref > a, const B &b) { VERIFY_IS_EQUAL(a,b); } +template +EIGEN_DONT_INLINE void call_ref_6(const Ref >& a, const B &b) { VERIFY_IS_EQUAL(a,b); } +template +EIGEN_DONT_INLINE void call_ref_7(Ref > a, const B &b) { VERIFY_IS_EQUAL(a,b); } void call_ref() { - VectorXcf ca(10); - VectorXf a(10); + VectorXcf ca = VectorXcf::Random(10); + VectorXf a = VectorXf::Random(10); + RowVectorXf b = RowVectorXf::Random(10); + MatrixXf A = MatrixXf::Random(10,10); + RowVector3f c = RowVector3f::Random(); const VectorXf& ac(a); VectorBlock ab(a,0,3); - MatrixXf A(10,10); const VectorBlock abc(a,0,3); + - VERIFY_EVALUATION_COUNT( call_ref_1(a), 0); - //call_ref_1(ac); // does not compile because ac is const - VERIFY_EVALUATION_COUNT( call_ref_1(ab), 0); - VERIFY_EVALUATION_COUNT( call_ref_1(a.head(4)), 0); - VERIFY_EVALUATION_COUNT( call_ref_1(abc), 0); - VERIFY_EVALUATION_COUNT( call_ref_1(A.col(3)), 0); - // call_ref_1(A.row(3)); // does not compile because innerstride!=1 - VERIFY_EVALUATION_COUNT( call_ref_3(A.row(3)), 0); - VERIFY_EVALUATION_COUNT( call_ref_4(A.row(3)), 0); - //call_ref_1(a+a); // does not compile for obvious reason + VERIFY_EVALUATION_COUNT( call_ref_1(a,a), 0); + VERIFY_EVALUATION_COUNT( call_ref_1(b,b.transpose()), 0); +// call_ref_1(ac); // does not compile because ac is const + VERIFY_EVALUATION_COUNT( call_ref_1(ab,ab), 0); + VERIFY_EVALUATION_COUNT( call_ref_1(a.head(4),a.head(4)), 0); + VERIFY_EVALUATION_COUNT( call_ref_1(abc,abc), 0); + VERIFY_EVALUATION_COUNT( call_ref_1(A.col(3),A.col(3)), 0); +// call_ref_1(A.row(3)); // does not compile because innerstride!=1 + VERIFY_EVALUATION_COUNT( call_ref_3(A.row(3),A.row(3).transpose()), 0); + VERIFY_EVALUATION_COUNT( call_ref_4(A.row(3),A.row(3).transpose()), 0); +// call_ref_1(a+a); // does not compile for obvious reason - VERIFY_EVALUATION_COUNT( call_ref_2(A*A.col(1)), 1); // evaluated into a temp - VERIFY_EVALUATION_COUNT( call_ref_2(ac.head(5)), 0); - VERIFY_EVALUATION_COUNT( call_ref_2(ac), 0); - VERIFY_EVALUATION_COUNT( call_ref_2(a), 0); - VERIFY_EVALUATION_COUNT( call_ref_2(ab), 0); - VERIFY_EVALUATION_COUNT( call_ref_2(a.head(4)), 0); - VERIFY_EVALUATION_COUNT( call_ref_2(a+a), 1); // evaluated into a temp - VERIFY_EVALUATION_COUNT( call_ref_2(ca.imag()), 1); // evaluated into a temp + MatrixXf tmp = A*A.col(1); + VERIFY_EVALUATION_COUNT( call_ref_2(A*A.col(1), tmp), 1); // evaluated into a temp + VERIFY_EVALUATION_COUNT( call_ref_2(ac.head(5),ac.head(5)), 0); + VERIFY_EVALUATION_COUNT( call_ref_2(ac,ac), 0); + VERIFY_EVALUATION_COUNT( call_ref_2(a,a), 0); + VERIFY_EVALUATION_COUNT( call_ref_2(ab,ab), 0); + VERIFY_EVALUATION_COUNT( call_ref_2(a.head(4),a.head(4)), 0); + tmp = a+a; + VERIFY_EVALUATION_COUNT( call_ref_2(a+a,tmp), 1); // evaluated into a temp + VERIFY_EVALUATION_COUNT( call_ref_2(ca.imag(),ca.imag()), 1); // evaluated into a temp - VERIFY_EVALUATION_COUNT( call_ref_4(ac.head(5)), 0); - VERIFY_EVALUATION_COUNT( call_ref_4(a+a), 1); // evaluated into a temp - VERIFY_EVALUATION_COUNT( call_ref_4(ca.imag()), 0); + VERIFY_EVALUATION_COUNT( call_ref_4(ac.head(5),ac.head(5)), 0); + tmp = a+a; + VERIFY_EVALUATION_COUNT( call_ref_4(a+a,tmp), 1); // evaluated into a temp + VERIFY_EVALUATION_COUNT( call_ref_4(ca.imag(),ca.imag()), 0); - VERIFY_EVALUATION_COUNT( call_ref_5(a), 0); - VERIFY_EVALUATION_COUNT( call_ref_5(a.head(3)), 0); - VERIFY_EVALUATION_COUNT( call_ref_5(A), 0); - // call_ref_5(A.transpose()); // does not compile - VERIFY_EVALUATION_COUNT( call_ref_5(A.block(1,1,2,2)), 0); + VERIFY_EVALUATION_COUNT( call_ref_5(a,a), 0); + VERIFY_EVALUATION_COUNT( call_ref_5(a.head(3),a.head(3)), 0); + VERIFY_EVALUATION_COUNT( call_ref_5(A,A), 0); +// call_ref_5(A.transpose()); // does not compile + VERIFY_EVALUATION_COUNT( call_ref_5(A.block(1,1,2,2),A.block(1,1,2,2)), 0); + VERIFY_EVALUATION_COUNT( call_ref_5(b,b), 0); // storage order do not match, but this is a degenerate case that should work + VERIFY_EVALUATION_COUNT( call_ref_5(a.row(3),a.row(3)), 0); - VERIFY_EVALUATION_COUNT( call_ref_6(a), 0); - VERIFY_EVALUATION_COUNT( call_ref_6(a.head(3)), 0); - VERIFY_EVALUATION_COUNT( call_ref_6(A.row(3)), 1); // evaluated into a temp thouth it could be avoided by viewing it as a 1xn matrix - VERIFY_EVALUATION_COUNT( call_ref_6(A+A), 1); // evaluated into a temp - VERIFY_EVALUATION_COUNT( call_ref_6(A), 0); - VERIFY_EVALUATION_COUNT( call_ref_6(A.transpose()), 1); // evaluated into a temp because the storage orders do not match - VERIFY_EVALUATION_COUNT( call_ref_6(A.block(1,1,2,2)), 0); + VERIFY_EVALUATION_COUNT( call_ref_6(a,a), 0); + VERIFY_EVALUATION_COUNT( call_ref_6(a.head(3),a.head(3)), 0); + VERIFY_EVALUATION_COUNT( call_ref_6(A.row(3),A.row(3)), 1); // evaluated into a temp thouth it could be avoided by viewing it as a 1xn matrix + tmp = A+A; + VERIFY_EVALUATION_COUNT( call_ref_6(A+A,tmp), 1); // evaluated into a temp + VERIFY_EVALUATION_COUNT( call_ref_6(A,A), 0); + VERIFY_EVALUATION_COUNT( call_ref_6(A.transpose(),A.transpose()), 1); // evaluated into a temp because the storage orders do not match + VERIFY_EVALUATION_COUNT( call_ref_6(A.block(1,1,2,2),A.block(1,1,2,2)), 0); + + VERIFY_EVALUATION_COUNT( call_ref_7(c,c), 0); } void test_ref()