diff --git a/Eigen/Core b/Eigen/Core index cc003b075..90bcbc391 100644 --- a/Eigen/Core +++ b/Eigen/Core @@ -319,6 +319,7 @@ using std::ptrdiff_t; #include "src/Core/Product.h" #include "src/Core/CoreEvaluators.h" #include "src/Core/AssignEvaluator.h" +#include "src/Core/RealView.h" #include "src/Core/Assign.h" #include "src/Core/ArrayBase.h" diff --git a/Eigen/src/Core/DenseBase.h b/Eigen/src/Core/DenseBase.h index 4f6894280..0333ad167 100644 --- a/Eigen/src/Core/DenseBase.h +++ b/Eigen/src/Core/DenseBase.h @@ -367,7 +367,12 @@ class DenseBase EIGEN_DEVICE_FUNC inline bool allFinite() const; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator*=(const Scalar& other); + template ::value, typename = std::enable_if_t> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator*=(const RealScalar& other); + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator/=(const Scalar& other); + template ::value, typename = std::enable_if_t> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator/=(const RealScalar& other); typedef internal::add_const_on_value_type_t::type> EvalReturnType; /** \returns the matrix or vector obtained by evaluating this expression. @@ -597,6 +602,13 @@ class DenseBase inline const_iterator end() const; inline const_iterator cend() const; + using RealViewReturnType = std::conditional_t::IsComplex, RealView, Derived&>; + using ConstRealViewReturnType = + std::conditional_t::IsComplex, RealView, const Derived&>; + + EIGEN_DEVICE_FUNC RealViewReturnType realView(); + EIGEN_DEVICE_FUNC ConstRealViewReturnType realView() const; + #define EIGEN_CURRENT_STORAGE_BASE_CLASS Eigen::DenseBase #define EIGEN_DOC_BLOCK_ADDONS_NOT_INNER_PANEL #define EIGEN_DOC_BLOCK_ADDONS_INNER_PANEL_IF(COND) diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h index de599a15c..21a1bfc41 100644 --- a/Eigen/src/Core/GenericPacketMath.h +++ b/Eigen/src/Core/GenericPacketMath.h @@ -253,6 +253,12 @@ struct preinterpret_generic { static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& a) { return a; } }; +template +struct preinterpret_generic::as_real, ComplexPacket, false> { + using RealPacket = typename unpacket_traits::as_real; + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE RealPacket run(const ComplexPacket& a) { return a.v; } +}; + /** \internal \returns reinterpret_cast(a) */ template EIGEN_DEVICE_FUNC inline Target preinterpret(const Packet& a) { diff --git a/Eigen/src/Core/RealView.h b/Eigen/src/Core/RealView.h new file mode 100644 index 000000000..7ba42f9a1 --- /dev/null +++ b/Eigen/src/Core/RealView.h @@ -0,0 +1,250 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2025 Charlie Schlosser +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_REALVIEW_H +#define EIGEN_REALVIEW_H + +// IWYU pragma: private +#include "./InternalHeaderCheck.h" + +namespace Eigen { + +namespace internal { + +// Vectorized assignment to RealView requires array-oriented access to the real and imaginary components. +// From https://en.cppreference.com/w/cpp/numeric/complex.html: +// For any pointer to an element of an array of std::complex named p and any valid array index i, +// reinterpret_cast(p)[2 * i] is the real part of the complex number p[i], and +// reinterpret_cast(p)[2 * i + 1] is the imaginary part of the complex number p[i]. + +template +struct complex_array_access : std::false_type {}; +template <> +struct complex_array_access> : std::true_type {}; +template <> +struct complex_array_access> : std::true_type {}; +template <> +struct complex_array_access> : std::true_type {}; + +template +struct traits> : public traits { + template + static constexpr int double_size(T size, bool times_two) { + int size_as_int = int(size); + if (size_as_int == Dynamic) return Dynamic; + return times_two ? (2 * size_as_int) : size_as_int; + } + using Base = traits; + using ComplexScalar = typename Base::Scalar; + using Scalar = typename NumTraits::Real; + static constexpr int ActualDirectAccessBit = complex_array_access::value ? DirectAccessBit : 0; + static constexpr int ActualPacketAccessBit = packet_traits::Vectorizable ? PacketAccessBit : 0; + static constexpr int FlagMask = + ActualDirectAccessBit | ActualPacketAccessBit | HereditaryBits | LinearAccessBit | LvalueBit; + static constexpr int BaseFlags = int(evaluator::Flags) | int(Base::Flags); + static constexpr int Flags = BaseFlags & FlagMask; + static constexpr bool IsRowMajor = Flags & RowMajorBit; + static constexpr int RowsAtCompileTime = double_size(Base::RowsAtCompileTime, !IsRowMajor); + static constexpr int ColsAtCompileTime = double_size(Base::ColsAtCompileTime, IsRowMajor); + static constexpr int SizeAtCompileTime = size_at_compile_time(RowsAtCompileTime, ColsAtCompileTime); + static constexpr int MaxRowsAtCompileTime = double_size(Base::MaxRowsAtCompileTime, !IsRowMajor); + static constexpr int MaxColsAtCompileTime = double_size(Base::MaxColsAtCompileTime, IsRowMajor); + static constexpr int MaxSizeAtCompileTime = size_at_compile_time(MaxRowsAtCompileTime, MaxColsAtCompileTime); + static constexpr int OuterStrideAtCompileTime = double_size(outer_stride_at_compile_time::ret, true); + static constexpr int InnerStrideAtCompileTime = inner_stride_at_compile_time::ret; +}; + +template +struct evaluator> : private evaluator { + using BaseEvaluator = evaluator; + using XprType = RealView; + using ExpressionTraits = traits; + using ComplexScalar = typename ExpressionTraits::ComplexScalar; + using ComplexCoeffReturnType = typename BaseEvaluator::CoeffReturnType; + using Scalar = typename ExpressionTraits::Scalar; + + static constexpr bool IsRowMajor = ExpressionTraits::IsRowMajor; + static constexpr int Flags = ExpressionTraits::Flags; + static constexpr int CoeffReadCost = BaseEvaluator::CoeffReadCost; + static constexpr int Alignment = BaseEvaluator::Alignment; + + EIGEN_DEVICE_FUNC explicit evaluator(XprType realView) : BaseEvaluator(realView.m_xpr) {} + + template ::value, typename = std::enable_if_t> + constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(Index row, Index col) const { + ComplexCoeffReturnType cscalar = BaseEvaluator::coeff(IsRowMajor ? row : row / 2, IsRowMajor ? col / 2 : col); + Index p = (IsRowMajor ? col : row) & 1; + return p ? numext::real(cscalar) : numext::imag(cscalar); + } + + template ::value, typename = std::enable_if_t> + constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(Index row, Index col) const { + ComplexCoeffReturnType cscalar = BaseEvaluator::coeff(IsRowMajor ? row : row / 2, IsRowMajor ? col / 2 : col); + Index p = (IsRowMajor ? col : row) & 1; + return reinterpret_cast(cscalar)[p]; + } + + constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index row, Index col) { + ComplexScalar& cscalar = BaseEvaluator::coeffRef(IsRowMajor ? row : row / 2, IsRowMajor ? col / 2 : col); + Index p = (IsRowMajor ? col : row) & 1; + return reinterpret_cast(cscalar)[p]; + } + + template ::value, typename = std::enable_if_t> + constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(Index index) const { + ComplexCoeffReturnType cscalar = BaseEvaluator::coeff(index / 2); + Index p = index & 1; + return p ? numext::real(cscalar) : numext::imag(cscalar); + } + + template ::value, typename = std::enable_if_t> + constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(Index index) const { + ComplexCoeffReturnType cscalar = BaseEvaluator::coeff(index / 2); + Index p = index & 1; + return reinterpret_cast(cscalar)[p]; + } + + constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { + ComplexScalar& cscalar = BaseEvaluator::coeffRef(index / 2); + Index p = index & 1; + return reinterpret_cast(cscalar)[p]; + } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packet(Index row, Index col) const { + constexpr int RealPacketSize = unpacket_traits::size; + using ComplexPacket = typename find_packet_by_size::type; + EIGEN_STATIC_ASSERT((find_packet_by_size::value), + MISSING COMPATIBLE COMPLEX PACKET TYPE) + eigen_assert(((IsRowMajor ? col : row) % 2 == 0) && "the inner index must be even"); + + Index crow = IsRowMajor ? row : row / 2; + Index ccol = IsRowMajor ? col / 2 : col; + ComplexPacket cpacket = BaseEvaluator::template packet(crow, ccol); + return preinterpret(cpacket); + } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packet(Index index) const { + constexpr int RealPacketSize = unpacket_traits::size; + using ComplexPacket = typename find_packet_by_size::type; + EIGEN_STATIC_ASSERT((find_packet_by_size::value), + MISSING COMPATIBLE COMPLEX PACKET TYPE) + eigen_assert((index % 2 == 0) && "the index must be even"); + + Index cindex = index / 2; + ComplexPacket cpacket = BaseEvaluator::template packet(cindex); + return preinterpret(cpacket); + } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packetSegment(Index row, Index col, Index begin, Index count) const { + constexpr int RealPacketSize = unpacket_traits::size; + using ComplexPacket = typename find_packet_by_size::type; + EIGEN_STATIC_ASSERT((find_packet_by_size::value), + MISSING COMPATIBLE COMPLEX PACKET TYPE) + eigen_assert(((IsRowMajor ? col : row) % 2 == 0) && "the inner index must be even"); + eigen_assert((begin % 2 == 0) && (count % 2 == 0) && "begin and count must be even"); + + Index crow = IsRowMajor ? row : row / 2; + Index ccol = IsRowMajor ? col / 2 : col; + Index cbegin = begin / 2; + Index ccount = count / 2; + ComplexPacket cpacket = BaseEvaluator::template packetSegment(crow, ccol, cbegin, ccount); + return preinterpret(cpacket); + } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packetSegment(Index index, Index begin, Index count) const { + constexpr int RealPacketSize = unpacket_traits::size; + using ComplexPacket = typename find_packet_by_size::type; + EIGEN_STATIC_ASSERT((find_packet_by_size::value), + MISSING COMPATIBLE COMPLEX PACKET TYPE) + eigen_assert((index % 2 == 0) && "the index must be even"); + eigen_assert((begin % 2 == 0) && (count % 2 == 0) && "begin and count must be even"); + + Index cindex = index / 2; + Index cbegin = begin / 2; + Index ccount = count / 2; + ComplexPacket cpacket = BaseEvaluator::template packetSegment(cindex, cbegin, ccount); + return preinterpret(cpacket); + } +}; + +} // namespace internal + +template +class RealView : public internal::dense_xpr_base>::type { + using ExpressionTraits = internal::traits; + EIGEN_STATIC_ASSERT(NumTraits::IsComplex, SCALAR MUST BE COMPLEX) + public: + using Scalar = typename ExpressionTraits::Scalar; + using Nested = RealView; + + EIGEN_DEVICE_FUNC explicit RealView(Xpr& xpr) : m_xpr(xpr) {} + EIGEN_DEVICE_FUNC constexpr Index rows() const noexcept { return Xpr::IsRowMajor ? m_xpr.rows() : 2 * m_xpr.rows(); } + EIGEN_DEVICE_FUNC constexpr Index cols() const noexcept { return Xpr::IsRowMajor ? 2 * m_xpr.cols() : m_xpr.cols(); } + EIGEN_DEVICE_FUNC constexpr Index size() const noexcept { return 2 * m_xpr.size(); } + EIGEN_DEVICE_FUNC constexpr Index innerStride() const noexcept { return m_xpr.innerStride(); } + EIGEN_DEVICE_FUNC constexpr Index outerStride() const noexcept { return 2 * m_xpr.outerStride(); } + EIGEN_DEVICE_FUNC void resize(Index rows, Index cols) { + m_xpr.resize(Xpr::IsRowMajor ? rows : rows / 2, Xpr::IsRowMajor ? cols / 2 : cols); + } + EIGEN_DEVICE_FUNC void resize(Index size) { m_xpr.resize(size / 2); } + EIGEN_DEVICE_FUNC Scalar* data() { return reinterpret_cast(m_xpr.data()); } + EIGEN_DEVICE_FUNC const Scalar* data() const { return reinterpret_cast(m_xpr.data()); } + + EIGEN_DEVICE_FUNC RealView(const RealView&) = default; + + EIGEN_DEVICE_FUNC RealView& operator=(const RealView& other); + + template + EIGEN_DEVICE_FUNC RealView& operator=(const RealView& other); + + template + EIGEN_DEVICE_FUNC RealView& operator=(const DenseBase& other); + + protected: + friend struct internal::evaluator>; + Xpr& m_xpr; +}; + +template +EIGEN_DEVICE_FUNC RealView& RealView::operator=(const RealView& other) { + internal::call_assignment(*this, other); + return *this; +} + +template +template +EIGEN_DEVICE_FUNC RealView& RealView::operator=(const RealView& other) { + internal::call_assignment(*this, other); + return *this; +} + +template +template +EIGEN_DEVICE_FUNC RealView& RealView::operator=(const DenseBase& other) { + internal::call_assignment(*this, other.derived()); + return *this; +} + +template +EIGEN_DEVICE_FUNC typename DenseBase::RealViewReturnType DenseBase::realView() { + return RealViewReturnType(derived()); +} + +template +EIGEN_DEVICE_FUNC typename DenseBase::ConstRealViewReturnType DenseBase::realView() const { + return ConstRealViewReturnType(derived()); +} + +} // namespace Eigen + +#endif // EIGEN_REALVIEW_H diff --git a/Eigen/src/Core/SelfCwiseBinaryOp.h b/Eigen/src/Core/SelfCwiseBinaryOp.h index f73ceb400..1bc03737e 100644 --- a/Eigen/src/Core/SelfCwiseBinaryOp.h +++ b/Eigen/src/Core/SelfCwiseBinaryOp.h @@ -15,19 +15,33 @@ namespace Eigen { -// TODO generalize the scalar type of 'other' - template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& DenseBase::operator*=(const Scalar& other) { - internal::call_assignment(this->derived(), PlainObject::Constant(rows(), cols(), other), - internal::mul_assign_op()); + using ConstantExpr = typename internal::plain_constant_type::type; + using Op = internal::mul_assign_op; + internal::call_assignment(derived(), ConstantExpr(rows(), cols(), other), Op()); + return derived(); +} + +template +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& DenseBase::operator*=(const RealScalar& other) { + realView() *= other; return derived(); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& DenseBase::operator/=(const Scalar& other) { - internal::call_assignment(this->derived(), PlainObject::Constant(rows(), cols(), other), - internal::div_assign_op()); + using ConstantExpr = typename internal::plain_constant_type::type; + using Op = internal::div_assign_op; + internal::call_assignment(derived(), ConstantExpr(rows(), cols(), other), Op()); + return derived(); +} + +template +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& DenseBase::operator/=(const RealScalar& other) { + realView() /= other; return derived(); } diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h index 4eb134f36..e0bc57eab 100644 --- a/Eigen/src/Core/util/ForwardDeclarations.h +++ b/Eigen/src/Core/util/ForwardDeclarations.h @@ -171,6 +171,8 @@ template class TriangularView; template class SelfAdjointView; +template +class RealView; template class SparseView; template diff --git a/test/realview.cpp b/test/realview.cpp new file mode 100644 index 000000000..8658a3f94 --- /dev/null +++ b/test/realview.cpp @@ -0,0 +1,110 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2025 The Eigen Authors +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include "main.h" + +template +void test_realview(const T&) { + using Scalar = typename T::Scalar; + using RealScalar = typename NumTraits::Real; + + constexpr Index minRows = T::RowsAtCompileTime == Dynamic ? 1 : T::RowsAtCompileTime; + constexpr Index maxRows = T::MaxRowsAtCompileTime == Dynamic ? (EIGEN_TEST_MAX_SIZE / 2) : T::MaxRowsAtCompileTime; + constexpr Index minCols = T::ColsAtCompileTime == Dynamic ? 1 : T::ColsAtCompileTime; + constexpr Index maxCols = T::MaxColsAtCompileTime == Dynamic ? (EIGEN_TEST_MAX_SIZE / 2) : T::MaxColsAtCompileTime; + + constexpr Index rowFactor = (NumTraits::IsComplex && !T::IsRowMajor) ? 2 : 1; + constexpr Index colFactor = (NumTraits::IsComplex && T::IsRowMajor) ? 2 : 1; + constexpr Index sizeFactor = NumTraits::IsComplex ? 2 : 1; + + Index rows = internal::random(minRows, maxRows); + Index cols = internal::random(minCols, maxCols); + + T A(rows, cols), B, C; + + VERIFY(A.realView().rows() == rowFactor * A.rows()); + VERIFY(A.realView().cols() == colFactor * A.cols()); + VERIFY(A.realView().size() == sizeFactor * A.size()); + + RealScalar alpha = internal::random(RealScalar(1), RealScalar(2)); + A.setRandom(); + + VERIFY_IS_APPROX(A.matrix().squaredNorm(), A.realView().matrix().squaredNorm()); + + // test re-sizing realView during assignment + B.realView() = A.realView(); + VERIFY_IS_APPROX(A, B); + VERIFY_IS_APPROX(A.realView(), B.realView()); + + // B = A * alpha + for (Index r = 0; r < rows; r++) { + for (Index c = 0; c < cols; c++) { + B.coeffRef(r, c) = A.coeff(r, c) * Scalar(alpha); + } + } + + VERIFY_IS_APPROX(B.realView(), A.realView() * alpha); + C = A; + C.realView() *= alpha; + VERIFY_IS_APPROX(B, C); + + alpha = internal::random(RealScalar(1), RealScalar(2)); + A.setRandom(); + + // B = A / alpha + for (Index r = 0; r < rows; r++) { + for (Index c = 0; c < cols; c++) { + B.coeffRef(r, c) = A.coeff(r, c) / Scalar(alpha); + } + } + + VERIFY_IS_APPROX(B.realView(), A.realView() / alpha); + C = A; + C.realView() /= alpha; + VERIFY_IS_APPROX(B, C); +} + +template +void test_realview_driver() { + // if Rows == 1, don't test ColMajor as it is not a valid array + using ColMajorMatrixType = Matrix; + using ColMajorArrayType = Array; + // if Cols == 1, don't test RowMajor as it is not a valid array + using RowMajorMatrixType = Matrix; + using RowMajorArrayType = Array; + test_realview(ColMajorMatrixType()); + test_realview(ColMajorArrayType()); + test_realview(RowMajorMatrixType()); + test_realview(RowMajorArrayType()); +} + +template +void test_realview_driver_complex() { + test_realview_driver(); + test_realview_driver, Rows, Cols, MaxRows, MaxCols>(); + test_realview_driver(); + test_realview_driver, Rows, Cols, MaxRows, MaxCols>(); + test_realview_driver(); + test_realview_driver, Rows, Cols, MaxRows, MaxCols>(); +} + +EIGEN_DECLARE_TEST(realview) { + for (int i = 0; i < g_repeat; i++) { + CALL_SUBTEST_1((test_realview_driver_complex())); + CALL_SUBTEST_2((test_realview_driver_complex())); + CALL_SUBTEST_3((test_realview_driver_complex())); + CALL_SUBTEST_4((test_realview_driver_complex())); + CALL_SUBTEST_5((test_realview_driver_complex<17, Dynamic, 17, Dynamic>())); + CALL_SUBTEST_6((test_realview_driver_complex())); + CALL_SUBTEST_7((test_realview_driver_complex<17, 19, 17, 19>())); + CALL_SUBTEST_8((test_realview_driver_complex())); + CALL_SUBTEST_9((test_realview_driver_complex<1, Dynamic>())); + CALL_SUBTEST_10((test_realview_driver_complex<1, 1>())); + } +}