From bb1de9dbdede6669c2c86c028a9deff637e3d1f6 Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Thu, 10 Dec 2020 14:05:38 -0800 Subject: [PATCH] Fix Ref Stride checks. The existing `Ref` class failed to consider cases where the Ref's `Stride` setting *could* match the underlying referred object's stride, but **didn't** at runtime. This led to trying to set invalid stride values, causing runtime failures in some cases, and garbage due to mismatched strides in others. Here we add the missing runtime checks. This involves computing the strides necessary to align with the referred object's storage, and verifying we can actually set those strides at runtime. In the `const` case, if it *may* be possible to refer to the original storage at compile-time but fails at runtime, then we defer to the `construct(...)` method that makes a copy. Added more tests to check these cases. Fixes #2093. --- Eigen/src/Core/Ref.h | 135 ++++++++++++++++++++++++++++++++++++++----- test/ref.cpp | 66 +++++++++++++++++++++ 2 files changed, 185 insertions(+), 16 deletions(-) diff --git a/Eigen/src/Core/Ref.h b/Eigen/src/Core/Ref.h index 172c8ffb6..00aa45d34 100644 --- a/Eigen/src/Core/Ref.h +++ b/Eigen/src/Core/Ref.h @@ -93,29 +93,127 @@ protected: typedef Stride StrideBase; - template - EIGEN_DEVICE_FUNC void construct(Expression& expr) - { - EIGEN_STATIC_ASSERT_SAME_MATRIX_SIZE(PlainObjectType,Expression); + // Resolves inner stride if default 0. + static Index resolveInnerStride(Index inner) { + if (inner == 0) { + return 1; + } + return inner; + } + + // Resolves outer stride if default 0. + static Index resolveOuterStride(Index inner, Index outer, Index rows, Index cols, bool isVectorAtCompileTime, bool isRowMajor) { + if (outer == 0) { + if (isVectorAtCompileTime) { + outer = inner * rows * cols; + } else if (isRowMajor) { + outer = inner * cols; + } else { + outer = inner * rows; + } + } + return outer; + } + // Returns true if construction is valid, false if there is a stride mismatch, + // and fails if there is a size mismatch. + template + EIGEN_DEVICE_FUNC bool construct(Expression& expr) + { + // Check matrix sizes. If this is a compile-time vector, we do allow + // implicitly transposing. + EIGEN_STATIC_ASSERT( + EIGEN_PREDICATE_SAME_MATRIX_SIZE(PlainObjectType, Expression) + // If it is a vector, the transpose sizes might match. + || ( PlainObjectType::IsVectorAtCompileTime + && ((int(PlainObjectType::RowsAtCompileTime)==Eigen::Dynamic + || int(Expression::ColsAtCompileTime)==Eigen::Dynamic + || int(PlainObjectType::RowsAtCompileTime)==int(Expression::ColsAtCompileTime)) + && (int(PlainObjectType::ColsAtCompileTime)==Eigen::Dynamic + || int(Expression::RowsAtCompileTime)==Eigen::Dynamic + || int(PlainObjectType::ColsAtCompileTime)==int(Expression::RowsAtCompileTime)))), + YOU_MIXED_MATRICES_OF_DIFFERENT_SIZES + ) + + // Determine runtime rows and columns. + Index rows = expr.rows(); + Index cols = expr.cols(); if(PlainObjectType::RowsAtCompileTime==1) { eigen_assert(expr.rows()==1 || expr.cols()==1); - ::new (static_cast(this)) Base(expr.data(), 1, expr.size()); + rows = 1; + cols = expr.size(); } else if(PlainObjectType::ColsAtCompileTime==1) { eigen_assert(expr.rows()==1 || expr.cols()==1); - ::new (static_cast(this)) Base(expr.data(), expr.size(), 1); + rows = expr.size(); + cols = 1; } - else - ::new (static_cast(this)) Base(expr.data(), expr.rows(), expr.cols()); + // Verify that the sizes are valid. + eigen_assert( + (PlainObjectType::RowsAtCompileTime == Dynamic) || (PlainObjectType::RowsAtCompileTime == rows)); + eigen_assert( + (PlainObjectType::ColsAtCompileTime == Dynamic) || (PlainObjectType::ColsAtCompileTime == cols)); + + + // If this is a vector, we might be transposing, which means that stride should swap. + const bool transpose = PlainObjectType::IsVectorAtCompileTime && (rows != expr.rows()); + // If the storage format differs, we also need to swap the stride. + const bool row_major = ((PlainObjectType::Flags)&RowMajorBit) != 0; + const bool expr_row_major = (Expression::Flags&RowMajorBit) != 0; + const bool storage_differs = (row_major != expr_row_major); + + const bool swap_stride = (transpose != storage_differs); - if(Expression::IsVectorAtCompileTime && (!PlainObjectType::IsVectorAtCompileTime) && ((Expression::Flags&RowMajorBit)!=(PlainObjectType::Flags&RowMajorBit))) - ::new (&m_stride) StrideBase(expr.innerStride(), StrideType::InnerStrideAtCompileTime==0?0:1); - else - ::new (&m_stride) StrideBase(StrideType::OuterStrideAtCompileTime==0?0:expr.outerStride(), - StrideType::InnerStrideAtCompileTime==0?0:expr.innerStride()); + // Determine expr's actual strides, resolving any defaults if zero. + const Index expr_inner_actual = resolveInnerStride(expr.innerStride()); + const Index expr_outer_actual = resolveOuterStride(expr_inner_actual, + expr.outerStride(), + expr.rows(), + expr.cols(), + Expression::IsVectorAtCompileTime != 0, + expr_row_major); + + // If this is a column-major row vector or row-major column vector, the inner-stride + // is arbitrary, so set it to either the compile-time inner stride or 1. + const bool row_vector = (rows == 1); + const bool col_vector = (cols == 1); + const Index inner_stride = + ( (!row_major && row_vector) || (row_major && col_vector) ) ? + ( StrideType::InnerStrideAtCompileTime > 0 ? Index(StrideType::InnerStrideAtCompileTime) : 1) + : swap_stride ? expr_outer_actual : expr_inner_actual; + + // If this is a column-major column vector or row-major row vector, the outer-stride + // is arbitrary, so set it to either the compile-time outer stride or vector size. + const Index outer_stride = + ( (!row_major && col_vector) || (row_major && row_vector) ) ? + ( StrideType::OuterStrideAtCompileTime > 0 ? Index(StrideType::OuterStrideAtCompileTime) : rows * cols * inner_stride) + : swap_stride ? expr_inner_actual : expr_outer_actual; + + // Check if given inner/outer strides are compatible with compile-time strides. + const bool inner_valid = (StrideType::InnerStrideAtCompileTime == Dynamic) + || (resolveInnerStride(Index(StrideType::InnerStrideAtCompileTime)) == inner_stride); + if (!inner_valid) { + return false; + } + + const bool outer_valid = (StrideType::OuterStrideAtCompileTime == Dynamic) + || (resolveOuterStride( + inner_stride, + Index(StrideType::OuterStrideAtCompileTime), + rows, cols, PlainObjectType::IsVectorAtCompileTime != 0, + row_major) + == outer_stride); + if (!outer_valid) { + return false; + } + + ::new (static_cast(this)) Base(expr.data(), rows, cols); + ::new (&m_stride) StrideBase( + (StrideType::OuterStrideAtCompileTime == 0) ? 0 : outer_stride, + (StrideType::InnerStrideAtCompileTime == 0) ? 0 : inner_stride ); + return true; } StrideBase m_stride; @@ -212,7 +310,8 @@ template class Ref typename internal::enable_if::MatchAtCompileTime),Derived>::type* = 0) { EIGEN_STATIC_ASSERT(bool(Traits::template match::MatchAtCompileTime), STORAGE_LAYOUT_DOES_NOT_MATCH); - Base::construct(expr.derived()); + // Construction must pass since we will not create temprary storage in the non-const case. + eigen_assert(Base::construct(expr.derived())); } template EIGEN_DEVICE_FUNC inline Ref(const DenseBase& expr, @@ -226,7 +325,8 @@ template class Ref EIGEN_STATIC_ASSERT(bool(internal::is_lvalue::value), THIS_EXPRESSION_IS_NOT_A_LVALUE__IT_IS_READ_ONLY); EIGEN_STATIC_ASSERT(bool(Traits::template match::MatchAtCompileTime), STORAGE_LAYOUT_DOES_NOT_MATCH); EIGEN_STATIC_ASSERT(!Derived::IsPlainObjectBase,THIS_EXPRESSION_IS_NOT_A_LVALUE__IT_IS_READ_ONLY); - Base::construct(expr.const_cast_derived()); + // Construction must pass since we will not create temporary storage in the non-const case. + eigen_assert(Base::construct(expr.const_cast_derived())); } EIGEN_INHERIT_ASSIGNMENT_OPERATORS(Ref) @@ -267,7 +367,10 @@ template class Ref< template EIGEN_DEVICE_FUNC void construct(const Expression& expr,internal::true_type) { - Base::construct(expr); + // Check if we can use the underlying expr's storage directly, otherwise call the copy version. + if (!Base::construct(expr)) { + construct(expr, internal::false_type()); + } } template diff --git a/test/ref.cpp b/test/ref.cpp index c0b6ffdcf..ebfc70d3d 100644 --- a/test/ref.cpp +++ b/test/ref.cpp @@ -141,6 +141,69 @@ template void ref_vector(const VectorType& m) VERIFY_IS_APPROX(mat1, mat2); } +template +void ref_vector_fixed_sizes() +{ + typedef Matrix RowMajorMatrixType; + typedef Matrix ColMajorMatrixType; + typedef Matrix RowVectorType; + typedef Matrix ColVectorType; + typedef Matrix RowVectorTransposeType; + typedef Matrix ColVectorTransposeType; + typedef Stride DynamicStride; + + RowMajorMatrixType mr = RowMajorMatrixType::Random(); + ColMajorMatrixType mc = ColMajorMatrixType::Random(); + + Index i = internal::random(0,Rows-1); + Index j = internal::random(0,Cols-1); + + // Reference ith row. + Ref mr_ri = mr.row(i); + VERIFY_IS_EQUAL(mr_ri, mr.row(i)); + Ref mc_ri = mc.row(i); + VERIFY_IS_EQUAL(mc_ri, mc.row(i)); + + // Reference jth col. + Ref mr_cj = mr.col(j); + VERIFY_IS_EQUAL(mr_cj, mr.col(j)); + Ref mc_cj = mc.col(j); + VERIFY_IS_EQUAL(mc_cj, mc.col(j)); + + // Reference the transpose of row i. + Ref mr_rit = mr.row(i); + VERIFY_IS_EQUAL(mr_rit, mr.row(i).transpose()); + Ref mc_rit = mc.row(i); + VERIFY_IS_EQUAL(mc_rit, mc.row(i).transpose()); + + // Reference the transpose of col j. + Ref mr_cjt = mr.col(j); + VERIFY_IS_EQUAL(mr_cjt, mr.col(j).transpose()); + Ref mc_cjt = mc.col(j); + VERIFY_IS_EQUAL(mc_cjt, mc.col(j).transpose()); + + // Const references without strides. + Ref cmr_ri = mr.row(i); + VERIFY_IS_EQUAL(cmr_ri, mr.row(i)); + Ref cmc_ri = mc.row(i); + VERIFY_IS_EQUAL(cmc_ri, mc.row(i)); + + Ref cmr_cj = mr.col(j); + VERIFY_IS_EQUAL(cmr_cj, mr.col(j)); + Ref cmc_cj = mc.col(j); + VERIFY_IS_EQUAL(cmc_cj, mc.col(j)); + + Ref cmr_rit = mr.row(i); + VERIFY_IS_EQUAL(cmr_rit, mr.row(i).transpose()); + Ref cmc_rit = mc.row(i); + VERIFY_IS_EQUAL(cmc_rit, mc.row(i).transpose()); + + Ref cmr_cjt = mr.col(j); + VERIFY_IS_EQUAL(cmr_cjt, mr.col(j).transpose()); + Ref cmc_cjt = mc.col(j); + VERIFY_IS_EQUAL(cmc_cjt, mc.col(j).transpose()); +} + template void check_const_correctness(const PlainObjectType&) { // verify that ref-to-const don't have LvalueBit @@ -287,6 +350,9 @@ EIGEN_DECLARE_TEST(ref) CALL_SUBTEST_4( ref_matrix(Matrix,10,15>()) ); CALL_SUBTEST_5( ref_matrix(MatrixXi(internal::random(1,10),internal::random(1,10))) ); CALL_SUBTEST_6( call_ref() ); + + CALL_SUBTEST_8( (ref_vector_fixed_sizes()) ); + CALL_SUBTEST_8( (ref_vector_fixed_sizes()) ); } CALL_SUBTEST_7( test_ref_overloads() );