diff --git a/Eigen/src/Core/Block.h b/Eigen/src/Core/Block.h index 6ef26ca7d..31cd094db 100644 --- a/Eigen/src/Core/Block.h +++ b/Eigen/src/Core/Block.h @@ -17,9 +17,10 @@ namespace Eigen { namespace internal { -template -struct traits > : traits +template +struct traits > : traits { + typedef XprType_ XprType; typedef typename traits::Scalar Scalar; typedef typename traits::StorageKind StorageKind; typedef typename traits::XprKind XprKind; @@ -53,12 +54,13 @@ struct traits > : traits::value ? LvalueBit : 0, FlagsRowMajorBit = IsRowMajor ? RowMajorBit : 0, - Flags = (traits::Flags & (DirectAccessBit | (InnerPanel?CompressedAccessBit:0))) | FlagsLvalueBit | FlagsRowMajorBit, + Flags = (traits::Flags & (DirectAccessBit | (InnerPanel_?CompressedAccessBit:0))) | FlagsLvalueBit | FlagsRowMajorBit, // FIXME DirectAccessBit should not be handled by expressions // // Alignment is needed by MapBase's assertions // We can sefely set it to false here. Internal alignment errors will be detected by an eigen_internal_assert in the respective evaluator - Alignment = 0 + Alignment = 0, + InnerPanel = InnerPanel_ ? 1 : 0 }; }; @@ -107,6 +109,7 @@ template class : public BlockImpl::StorageKind> { typedef BlockImpl::StorageKind> Impl; + using BlockHelper = internal::block_xpr_helper; public: //typedef typename Impl::Base Base; typedef Impl Base; @@ -149,9 +152,25 @@ template class eigen_assert(startRow >= 0 && blockRows >= 0 && startRow <= xpr.rows() - blockRows && startCol >= 0 && blockCols >= 0 && startCol <= xpr.cols() - blockCols); } + + // convert nested blocks (e.g. Block>) to a simple block expression (Block) + + using ConstUnwindReturnType = Block; + using UnwindReturnType = Block; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ConstUnwindReturnType unwind() const { + return ConstUnwindReturnType(BlockHelper::base(*this), BlockHelper::row(*this, 0), BlockHelper::col(*this, 0), + this->rows(), this->cols()); + } + + template ::value>> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE UnwindReturnType unwind() { + return UnwindReturnType(BlockHelper::base(*this), BlockHelper::row(*this, 0), BlockHelper::col(*this, 0), + this->rows(), this->cols()); + } }; -// The generic default implementation for dense block simplu forward to the internal::BlockImpl_dense +// The generic default implementation for dense block simply forward to the internal::BlockImpl_dense // that must be specialized for direct and non-direct access... template class BlockImpl diff --git a/Eigen/src/Core/util/XprHelper.h b/Eigen/src/Core/util/XprHelper.h index dc2b194c7..bb382a7b5 100644 --- a/Eigen/src/Core/util/XprHelper.h +++ b/Eigen/src/Core/util/XprHelper.h @@ -809,6 +809,54 @@ std::string demangle_flags(int f) } #endif +template +struct is_block_xpr : std::false_type {}; + +template +struct is_block_xpr> : std::true_type {}; + +template +struct is_block_xpr> : std::true_type {}; + +// Helper utility for constructing non-recursive block expressions. +template +struct block_xpr_helper { + using BaseType = XprType; + + // For regular block expressions, simply forward along the InnerPanel argument, + // which is set when calling row/column expressions. + static constexpr bool is_inner_panel(bool inner_panel) { return inner_panel; }; + + // Only enable non-const base function if XprType is not const (otherwise we get a duplicate definition). + template::value>> + static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BaseType& base(XprType& xpr) { return xpr; } + static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const BaseType& base(const XprType& xpr) { return xpr; } + static constexpr EIGEN_ALWAYS_INLINE Index row(const XprType& /*xpr*/, Index r) { return r; } + static constexpr EIGEN_ALWAYS_INLINE Index col(const XprType& /*xpr*/, Index c) { return c; } +}; + +template +struct block_xpr_helper> { + using BlockXprType = Block; + // Recursive helper in case of explicit block-of-block expression. + using NestedXprHelper = block_xpr_helper; + using BaseType = typename NestedXprHelper::BaseType; + + // For block-of-block expressions, we need to combine the InnerPannel trait + // with that of the block subexpression. + static constexpr bool is_inner_panel(bool inner_panel) { return InnerPanel && inner_panel; } + + // Only enable non-const base function if XprType is not const (otherwise we get a duplicates definition). + template::value>> + static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BaseType& base(BlockXprType& xpr) { return NestedXprHelper::base(xpr.nestedExpression()); } + static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const BaseType& base(const BlockXprType& xpr) { return NestedXprHelper::base(xpr.nestedExpression()); } + static constexpr EIGEN_ALWAYS_INLINE Index row(const BlockXprType& xpr, Index r) { return xpr.startRow() + NestedXprHelper::row(xpr.nestedExpression(), r); } + static constexpr EIGEN_ALWAYS_INLINE Index col(const BlockXprType& xpr, Index c) { return xpr.startCol() + NestedXprHelper::col(xpr.nestedExpression(), c); } +}; + +template +struct block_xpr_helper> : block_xpr_helper> {}; + } // end namespace internal diff --git a/test/block.cpp b/test/block.cpp index aba089629..867b76952 100644 --- a/test/block.cpp +++ b/test/block.cpp @@ -306,6 +306,43 @@ void data_and_stride(const MatrixType& m) compare_using_data_and_stride(m1.col(c1).transpose()); } + +template +struct unwind_test_impl { + static void run(Xpr& xpr) { + Index startRow = internal::random(0, xpr.rows() / 5); + Index startCol = internal::random(0, xpr.cols() / 6); + Index rows = xpr.rows() / 3; + Index cols = xpr.cols() / 2; + // test equivalence of const expressions + const Block constNestedBlock(xpr, startRow, startCol, rows, cols); + const Block constUnwoundBlock = constNestedBlock.unwind(); + VERIFY_IS_CWISE_EQUAL(constNestedBlock, constUnwoundBlock); + // modify a random element in each representation and test equivalence of non-const expressions + Block nestedBlock(xpr, startRow, startCol, rows, cols); + Block unwoundBlock = nestedBlock.unwind(); + Index r1 = internal::random(0, rows - 1); + Index c1 = internal::random(0, cols - 1); + Index r2 = internal::random(0, rows - 1); + Index c2 = internal::random(0, cols - 1); + nestedBlock.coeffRef(r1, c1) = internal::random::Scalar>(); + unwoundBlock.coeffRef(r2, c2) = internal::random::Scalar>(); + VERIFY_IS_CWISE_EQUAL(nestedBlock, unwoundBlock); + unwind_test_impl, Depth + 1>::run(nestedBlock); + } +}; + +template +struct unwind_test_impl { + static void run(const Xpr&) {} +}; + +template +void unwind_test(const BaseXpr&) { + BaseXpr xpr = BaseXpr::Random(100, 100); + unwind_test_impl::run(xpr); +} + EIGEN_DECLARE_TEST(block) { for(int i = 0; i < g_repeat; i++) { @@ -320,6 +357,7 @@ EIGEN_DECLARE_TEST(block) CALL_SUBTEST_7( block(Matrix(internal::random(2,50), internal::random(2,50))) ); CALL_SUBTEST_8( block(Matrix(3, 4)) ); + CALL_SUBTEST_9( unwind_test(MatrixXf())); #ifndef EIGEN_DEFAULT_TO_ROW_MAJOR CALL_SUBTEST_6( data_and_stride(MatrixXf(internal::random(5,50), internal::random(5,50))) );