diff --git a/Eigen/src/Core/Ref.h b/Eigen/src/Core/Ref.h index 172c8ffb6..00aa45d34 100644 --- a/Eigen/src/Core/Ref.h +++ b/Eigen/src/Core/Ref.h @@ -93,29 +93,127 @@ protected: typedef Stride StrideBase; - template - EIGEN_DEVICE_FUNC void construct(Expression& expr) - { - EIGEN_STATIC_ASSERT_SAME_MATRIX_SIZE(PlainObjectType,Expression); + // Resolves inner stride if default 0. + static Index resolveInnerStride(Index inner) { + if (inner == 0) { + return 1; + } + return inner; + } + + // Resolves outer stride if default 0. + static Index resolveOuterStride(Index inner, Index outer, Index rows, Index cols, bool isVectorAtCompileTime, bool isRowMajor) { + if (outer == 0) { + if (isVectorAtCompileTime) { + outer = inner * rows * cols; + } else if (isRowMajor) { + outer = inner * cols; + } else { + outer = inner * rows; + } + } + return outer; + } + // Returns true if construction is valid, false if there is a stride mismatch, + // and fails if there is a size mismatch. + template + EIGEN_DEVICE_FUNC bool construct(Expression& expr) + { + // Check matrix sizes. If this is a compile-time vector, we do allow + // implicitly transposing. + EIGEN_STATIC_ASSERT( + EIGEN_PREDICATE_SAME_MATRIX_SIZE(PlainObjectType, Expression) + // If it is a vector, the transpose sizes might match. + || ( PlainObjectType::IsVectorAtCompileTime + && ((int(PlainObjectType::RowsAtCompileTime)==Eigen::Dynamic + || int(Expression::ColsAtCompileTime)==Eigen::Dynamic + || int(PlainObjectType::RowsAtCompileTime)==int(Expression::ColsAtCompileTime)) + && (int(PlainObjectType::ColsAtCompileTime)==Eigen::Dynamic + || int(Expression::RowsAtCompileTime)==Eigen::Dynamic + || int(PlainObjectType::ColsAtCompileTime)==int(Expression::RowsAtCompileTime)))), + YOU_MIXED_MATRICES_OF_DIFFERENT_SIZES + ) + + // Determine runtime rows and columns. + Index rows = expr.rows(); + Index cols = expr.cols(); if(PlainObjectType::RowsAtCompileTime==1) { eigen_assert(expr.rows()==1 || expr.cols()==1); - ::new (static_cast(this)) Base(expr.data(), 1, expr.size()); + rows = 1; + cols = expr.size(); } else if(PlainObjectType::ColsAtCompileTime==1) { eigen_assert(expr.rows()==1 || expr.cols()==1); - ::new (static_cast(this)) Base(expr.data(), expr.size(), 1); + rows = expr.size(); + cols = 1; } - else - ::new (static_cast(this)) Base(expr.data(), expr.rows(), expr.cols()); + // Verify that the sizes are valid. + eigen_assert( + (PlainObjectType::RowsAtCompileTime == Dynamic) || (PlainObjectType::RowsAtCompileTime == rows)); + eigen_assert( + (PlainObjectType::ColsAtCompileTime == Dynamic) || (PlainObjectType::ColsAtCompileTime == cols)); + + + // If this is a vector, we might be transposing, which means that stride should swap. + const bool transpose = PlainObjectType::IsVectorAtCompileTime && (rows != expr.rows()); + // If the storage format differs, we also need to swap the stride. + const bool row_major = ((PlainObjectType::Flags)&RowMajorBit) != 0; + const bool expr_row_major = (Expression::Flags&RowMajorBit) != 0; + const bool storage_differs = (row_major != expr_row_major); + + const bool swap_stride = (transpose != storage_differs); - if(Expression::IsVectorAtCompileTime && (!PlainObjectType::IsVectorAtCompileTime) && ((Expression::Flags&RowMajorBit)!=(PlainObjectType::Flags&RowMajorBit))) - ::new (&m_stride) StrideBase(expr.innerStride(), StrideType::InnerStrideAtCompileTime==0?0:1); - else - ::new (&m_stride) StrideBase(StrideType::OuterStrideAtCompileTime==0?0:expr.outerStride(), - StrideType::InnerStrideAtCompileTime==0?0:expr.innerStride()); + // Determine expr's actual strides, resolving any defaults if zero. + const Index expr_inner_actual = resolveInnerStride(expr.innerStride()); + const Index expr_outer_actual = resolveOuterStride(expr_inner_actual, + expr.outerStride(), + expr.rows(), + expr.cols(), + Expression::IsVectorAtCompileTime != 0, + expr_row_major); + + // If this is a column-major row vector or row-major column vector, the inner-stride + // is arbitrary, so set it to either the compile-time inner stride or 1. + const bool row_vector = (rows == 1); + const bool col_vector = (cols == 1); + const Index inner_stride = + ( (!row_major && row_vector) || (row_major && col_vector) ) ? + ( StrideType::InnerStrideAtCompileTime > 0 ? Index(StrideType::InnerStrideAtCompileTime) : 1) + : swap_stride ? expr_outer_actual : expr_inner_actual; + + // If this is a column-major column vector or row-major row vector, the outer-stride + // is arbitrary, so set it to either the compile-time outer stride or vector size. + const Index outer_stride = + ( (!row_major && col_vector) || (row_major && row_vector) ) ? + ( StrideType::OuterStrideAtCompileTime > 0 ? Index(StrideType::OuterStrideAtCompileTime) : rows * cols * inner_stride) + : swap_stride ? expr_inner_actual : expr_outer_actual; + + // Check if given inner/outer strides are compatible with compile-time strides. + const bool inner_valid = (StrideType::InnerStrideAtCompileTime == Dynamic) + || (resolveInnerStride(Index(StrideType::InnerStrideAtCompileTime)) == inner_stride); + if (!inner_valid) { + return false; + } + + const bool outer_valid = (StrideType::OuterStrideAtCompileTime == Dynamic) + || (resolveOuterStride( + inner_stride, + Index(StrideType::OuterStrideAtCompileTime), + rows, cols, PlainObjectType::IsVectorAtCompileTime != 0, + row_major) + == outer_stride); + if (!outer_valid) { + return false; + } + + ::new (static_cast(this)) Base(expr.data(), rows, cols); + ::new (&m_stride) StrideBase( + (StrideType::OuterStrideAtCompileTime == 0) ? 0 : outer_stride, + (StrideType::InnerStrideAtCompileTime == 0) ? 0 : inner_stride ); + return true; } StrideBase m_stride; @@ -212,7 +310,8 @@ template class Ref typename internal::enable_if::MatchAtCompileTime),Derived>::type* = 0) { EIGEN_STATIC_ASSERT(bool(Traits::template match::MatchAtCompileTime), STORAGE_LAYOUT_DOES_NOT_MATCH); - Base::construct(expr.derived()); + // Construction must pass since we will not create temprary storage in the non-const case. + eigen_assert(Base::construct(expr.derived())); } template EIGEN_DEVICE_FUNC inline Ref(const DenseBase& expr, @@ -226,7 +325,8 @@ template class Ref EIGEN_STATIC_ASSERT(bool(internal::is_lvalue::value), THIS_EXPRESSION_IS_NOT_A_LVALUE__IT_IS_READ_ONLY); EIGEN_STATIC_ASSERT(bool(Traits::template match::MatchAtCompileTime), STORAGE_LAYOUT_DOES_NOT_MATCH); EIGEN_STATIC_ASSERT(!Derived::IsPlainObjectBase,THIS_EXPRESSION_IS_NOT_A_LVALUE__IT_IS_READ_ONLY); - Base::construct(expr.const_cast_derived()); + // Construction must pass since we will not create temporary storage in the non-const case. + eigen_assert(Base::construct(expr.const_cast_derived())); } EIGEN_INHERIT_ASSIGNMENT_OPERATORS(Ref) @@ -267,7 +367,10 @@ template class Ref< template EIGEN_DEVICE_FUNC void construct(const Expression& expr,internal::true_type) { - Base::construct(expr); + // Check if we can use the underlying expr's storage directly, otherwise call the copy version. + if (!Base::construct(expr)) { + construct(expr, internal::false_type()); + } } template diff --git a/test/ref.cpp b/test/ref.cpp index c0b6ffdcf..ebfc70d3d 100644 --- a/test/ref.cpp +++ b/test/ref.cpp @@ -141,6 +141,69 @@ template void ref_vector(const VectorType& m) VERIFY_IS_APPROX(mat1, mat2); } +template +void ref_vector_fixed_sizes() +{ + typedef Matrix RowMajorMatrixType; + typedef Matrix ColMajorMatrixType; + typedef Matrix RowVectorType; + typedef Matrix ColVectorType; + typedef Matrix RowVectorTransposeType; + typedef Matrix ColVectorTransposeType; + typedef Stride DynamicStride; + + RowMajorMatrixType mr = RowMajorMatrixType::Random(); + ColMajorMatrixType mc = ColMajorMatrixType::Random(); + + Index i = internal::random(0,Rows-1); + Index j = internal::random(0,Cols-1); + + // Reference ith row. + Ref mr_ri = mr.row(i); + VERIFY_IS_EQUAL(mr_ri, mr.row(i)); + Ref mc_ri = mc.row(i); + VERIFY_IS_EQUAL(mc_ri, mc.row(i)); + + // Reference jth col. + Ref mr_cj = mr.col(j); + VERIFY_IS_EQUAL(mr_cj, mr.col(j)); + Ref mc_cj = mc.col(j); + VERIFY_IS_EQUAL(mc_cj, mc.col(j)); + + // Reference the transpose of row i. + Ref mr_rit = mr.row(i); + VERIFY_IS_EQUAL(mr_rit, mr.row(i).transpose()); + Ref mc_rit = mc.row(i); + VERIFY_IS_EQUAL(mc_rit, mc.row(i).transpose()); + + // Reference the transpose of col j. + Ref mr_cjt = mr.col(j); + VERIFY_IS_EQUAL(mr_cjt, mr.col(j).transpose()); + Ref mc_cjt = mc.col(j); + VERIFY_IS_EQUAL(mc_cjt, mc.col(j).transpose()); + + // Const references without strides. + Ref cmr_ri = mr.row(i); + VERIFY_IS_EQUAL(cmr_ri, mr.row(i)); + Ref cmc_ri = mc.row(i); + VERIFY_IS_EQUAL(cmc_ri, mc.row(i)); + + Ref cmr_cj = mr.col(j); + VERIFY_IS_EQUAL(cmr_cj, mr.col(j)); + Ref cmc_cj = mc.col(j); + VERIFY_IS_EQUAL(cmc_cj, mc.col(j)); + + Ref cmr_rit = mr.row(i); + VERIFY_IS_EQUAL(cmr_rit, mr.row(i).transpose()); + Ref cmc_rit = mc.row(i); + VERIFY_IS_EQUAL(cmc_rit, mc.row(i).transpose()); + + Ref cmr_cjt = mr.col(j); + VERIFY_IS_EQUAL(cmr_cjt, mr.col(j).transpose()); + Ref cmc_cjt = mc.col(j); + VERIFY_IS_EQUAL(cmc_cjt, mc.col(j).transpose()); +} + template void check_const_correctness(const PlainObjectType&) { // verify that ref-to-const don't have LvalueBit @@ -287,6 +350,9 @@ EIGEN_DECLARE_TEST(ref) CALL_SUBTEST_4( ref_matrix(Matrix,10,15>()) ); CALL_SUBTEST_5( ref_matrix(MatrixXi(internal::random(1,10),internal::random(1,10))) ); CALL_SUBTEST_6( call_ref() ); + + CALL_SUBTEST_8( (ref_vector_fixed_sizes()) ); + CALL_SUBTEST_8( (ref_vector_fixed_sizes()) ); } CALL_SUBTEST_7( test_ref_overloads() );