diff --git a/Eigen/src/Core/CwiseUnaryView.h b/Eigen/src/Core/CwiseUnaryView.h index 725b33710..fdf68f0fc 100644 --- a/Eigen/src/Core/CwiseUnaryView.h +++ b/Eigen/src/Core/CwiseUnaryView.h @@ -18,7 +18,9 @@ namespace Eigen { namespace internal { template struct traits > : traits { - typedef typename result_of::Scalar&)>::type Scalar; + typedef typename std::result_of::Scalar&)>::type ScalarRef; + static_assert(std::is_reference::value, "Views must return a reference type."); + typedef remove_all_t Scalar; typedef typename MatrixType::Nested MatrixTypeNested; typedef remove_all_t MatrixTypeNested_; enum { @@ -112,7 +114,7 @@ class CwiseUnaryViewImpl EIGEN_INHERIT_ASSIGNMENT_OPERATORS(CwiseUnaryViewImpl) EIGEN_DEVICE_FUNC inline Scalar* data() { return &(this->coeffRef(0)); } - EIGEN_DEVICE_FUNC inline const Scalar* data() const { return &(this->coeff(0)); } + EIGEN_DEVICE_FUNC inline const Scalar* data() const { return &(this->coeffRef(0)); } EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index innerStride() const { return StrideType::InnerStrideAtCompileTime != 0 @@ -128,8 +130,25 @@ class CwiseUnaryViewImpl sizeof(Scalar); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index row, Index col) { + return internal::evaluator(derived()).coeffRef(row, col); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { + return internal::evaluator(derived()).coeffRef(index); + } + protected: EIGEN_DEFAULT_EMPTY_CONSTRUCTOR_AND_DESTRUCTOR(CwiseUnaryViewImpl) + + // Allow const access to coeffRef for the case of direct access being enabled. + EIGEN_DEVICE_FUNC inline const Scalar& coeffRef(Index index) const { + return const_cast(this)->coeffRef(index); + } + + EIGEN_DEVICE_FUNC inline const Scalar& coeffRef(Index row, Index col) const { + return const_cast(this)->coeffRef(row, col); + } }; } // end namespace Eigen diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h index a3fc44c1f..4447b82f7 100644 --- a/Eigen/src/Core/functors/UnaryFunctors.h +++ b/Eigen/src/Core/functors/UnaryFunctors.h @@ -286,9 +286,10 @@ struct functor_traits> { template struct scalar_real_ref_op { typedef typename NumTraits::Real result_type; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type& operator()(const Scalar& a) const { - return numext::real_ref(*const_cast(&a)); + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type& operator()(const Scalar& a) const { + return numext::real_ref(a); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type& operator()(Scalar& a) const { return numext::real_ref(a); } }; template struct functor_traits> { @@ -303,8 +304,9 @@ struct functor_traits> { template struct scalar_imag_ref_op { typedef typename NumTraits::Real result_type; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type& operator()(const Scalar& a) const { - return numext::imag_ref(*const_cast(&a)); + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type& operator()(Scalar& a) const { return numext::imag_ref(a); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type& operator()(const Scalar& a) const { + return numext::imag_ref(a); } }; template diff --git a/Eigen/src/plugins/CommonCwiseUnaryOps.inc b/Eigen/src/plugins/CommonCwiseUnaryOps.inc index f20f2f817..64f364884 100644 --- a/Eigen/src/plugins/CommonCwiseUnaryOps.inc +++ b/Eigen/src/plugins/CommonCwiseUnaryOps.inc @@ -118,7 +118,7 @@ EIGEN_DEVICE_FUNC inline const CwiseUnaryOp unaryE return CwiseUnaryOp(derived(), func); } -/// \returns an expression of a custom coefficient-wise unary operator \a func of *this +/// \returns a const expression of a custom coefficient-wise unary operator \a func of *this /// /// The template parameter \a CustomUnaryOp is the type of the functor /// of the custom unary operator. @@ -137,6 +137,21 @@ EIGEN_DEVICE_FUNC inline const CwiseUnaryView unary return CwiseUnaryView(derived(), func); } +/// \returns a non-const expression of a custom coefficient-wise unary view \a func of *this +/// +/// The template parameter \a CustomUnaryOp is the type of the functor +/// of the custom unary operator. +/// +EIGEN_DOC_UNARY_ADDONS(unaryViewExpr, unary function) +/// +/// \sa unaryExpr, binaryExpr class CwiseUnaryOp +/// +template +EIGEN_DEVICE_FUNC inline CwiseUnaryView unaryViewExpr( + const CustomViewOp& func = CustomViewOp()) { + return CwiseUnaryView(derived(), func); +} + /// \returns a non const expression of the real part of \c *this. /// EIGEN_DOC_UNARY_ADDONS(real, real part function) diff --git a/failtest/cwiseunaryview_on_const_type_actually_const.cpp b/failtest/cwiseunaryview_on_const_type_actually_const.cpp index 7ecf5425f..fd3c1d64a 100644 --- a/failtest/cwiseunaryview_on_const_type_actually_const.cpp +++ b/failtest/cwiseunaryview_on_const_type_actually_const.cpp @@ -10,7 +10,7 @@ using namespace Eigen; void foo() { MatrixXf m; - CwiseUnaryView, CV_QUALIFIER MatrixXf>(m).coeffRef(0, 0) = 1.0f; + CwiseUnaryView, CV_QUALIFIER MatrixXf>(m).coeffRef(0, 0) = 1.0f; } int main() {} diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 4c7c3a468..a778d1e75 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -218,7 +218,7 @@ ei_add_test(commainitializer) ei_add_test(smallvectors) ei_add_test(mapped_matrix) ei_add_test(mapstride) -ei_add_test(unaryviewstride) +ei_add_test(unaryview) ei_add_test(mapstaticmethods) ei_add_test(array_cwise) ei_add_test(array_for_matrix) diff --git a/test/unaryview.cpp b/test/unaryview.cpp new file mode 100644 index 000000000..58e95d69d --- /dev/null +++ b/test/unaryview.cpp @@ -0,0 +1,109 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2021 Andrew Johnson +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include "main.h" + +template +void unaryview_stride(const VectorType& m) { + typedef typename VectorType::Scalar Scalar; + Index rows = m.rows(); + Index cols = m.cols(); + VectorType vec = VectorType::Random(rows, cols); + + struct view_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& operator()(const Scalar& v) const { return v; } + }; + + CwiseUnaryView> vec_view(vec); + VERIFY(vec_view.outerStride() == (OuterStride == 0 ? 0 : OuterStride)); + VERIFY(vec_view.innerStride() == (InnerStride == 0 ? 1 : InnerStride)); +} + +void test_mutable_unaryview() { + struct Vec3 { + double x; + double y; + double z; + }; + + Eigen::Vector m; + auto x_view = m.unaryViewExpr([](Vec3& v) -> double& { return v.x; }); + auto y_view = m.unaryViewExpr([](Vec3& v) -> double& { return v.y; }); + auto z_view = m.unaryViewExpr([](Vec3& v) -> double& { return v.z; }); + + x_view.setConstant(1); + y_view.setConstant(2); + z_view.setConstant(3); + + for (int i = 0; i < m.size(); ++i) { + VERIFY_IS_EQUAL(m(i).x, 1); + VERIFY_IS_EQUAL(m(i).y, 2); + VERIFY_IS_EQUAL(m(i).z, 3); + } +} + +void test_unaryview_solve() { + // Random upper-triangular system. + Eigen::MatrixXd A = Eigen::MatrixXd::Random(5, 5); + A.triangularView().setZero(); + A.diagonal().setRandom(); + Eigen::VectorXd b = Eigen::VectorXd::Random(5); + + struct trivial_view_op { + double& operator()(double& x) const { return x; } + const double& operator()(const double& x) const { return x; } + }; + + // Non-const view: + { + auto b_view = b.unaryViewExpr(trivial_view_op()); + b_view(0) = 1; // Allows modification. + Eigen::VectorXd x = A.triangularView().solve(b_view); + VERIFY_IS_APPROX(A * x, b); + } + + // Const view: + { + const auto b_view = b.unaryViewExpr(trivial_view_op()); + Eigen::VectorXd x = A.triangularView().solve(b_view); + VERIFY_IS_APPROX(A * x, b); + } + + // Non-const view of const matrix: + { + const Eigen::VectorXd const_b = b; + auto b_view = const_b.unaryViewExpr(trivial_view_op()); + Eigen::VectorXd x = A.triangularView().solve(b_view); + VERIFY_IS_APPROX(A * x, b); + } + + // Const view of const matrix: + { + const Eigen::VectorXd const_b = b; + const auto b_view = const_b.unaryViewExpr(trivial_view_op()); + Eigen::VectorXd x = A.triangularView().solve(b_view); + VERIFY_IS_APPROX(A * x, b); + } + + // Eigen::MatrixXd out = + // mat_in.real() + // .triangularView() + // .solve(mat_in.unaryViewExpr([&](const auto& x){ return std::real(x); })); +} + +EIGEN_DECLARE_TEST(unaryviewstride) { + CALL_SUBTEST_1((unaryview_stride<1, 2>(MatrixXf()))); + CALL_SUBTEST_1((unaryview_stride<0, 0>(MatrixXf()))); + CALL_SUBTEST_2((unaryview_stride<1, 2>(VectorXf()))); + CALL_SUBTEST_2((unaryview_stride<0, 0>(VectorXf()))); + CALL_SUBTEST_3((unaryview_stride<1, 2>(RowVectorXf()))); + CALL_SUBTEST_3((unaryview_stride<0, 0>(RowVectorXf()))); + CALL_SUBTEST_4(test_mutable_unaryview()); + CALL_SUBTEST_4(test_unaryview_solve()); +} diff --git a/test/unaryviewstride.cpp b/test/unaryviewstride.cpp deleted file mode 100644 index 490a5b7d6..000000000 --- a/test/unaryviewstride.cpp +++ /dev/null @@ -1,35 +0,0 @@ -// This file is part of Eigen, a lightweight C++ template library -// for linear algebra. -// -// Copyright (C) 2021 Andrew Johnson -// -// This Source Code Form is subject to the terms of the Mozilla -// Public License v. 2.0. If a copy of the MPL was not distributed -// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. - -#include "main.h" - -template -void unaryview_stride(const VectorType& m) { - typedef typename VectorType::Scalar Scalar; - Index rows = m.rows(); - Index cols = m.cols(); - VectorType vec = VectorType::Random(rows, cols); - - struct view_op { - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& operator()(const Scalar& v) const { return v; } - }; - - CwiseUnaryView> vec_view(vec); - VERIFY(vec_view.outerStride() == (OuterStride == 0 ? 0 : OuterStride)); - VERIFY(vec_view.innerStride() == (InnerStride == 0 ? 1 : InnerStride)); -} - -EIGEN_DECLARE_TEST(unaryviewstride) { - CALL_SUBTEST_1((unaryview_stride<1, 2>(MatrixXf()))); - CALL_SUBTEST_1((unaryview_stride<0, 0>(MatrixXf()))); - CALL_SUBTEST_2((unaryview_stride<1, 2>(VectorXf()))); - CALL_SUBTEST_2((unaryview_stride<0, 0>(VectorXf()))); - CALL_SUBTEST_3((unaryview_stride<1, 2>(RowVectorXf()))); - CALL_SUBTEST_3((unaryview_stride<0, 0>(RowVectorXf()))); -}