From 22cd7307dd0320e53399c6b3e8dd15e3aeb13442 Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Wed, 12 Feb 2025 11:21:44 -0800 Subject: [PATCH] Remove assumption of std::complex for complex scalar types. --- Eigen/src/Core/MathFunctions.h | 76 +++++----- Eigen/src/Core/MathFunctionsImpl.h | 37 ++--- Eigen/src/Core/MatrixBase.h | 4 +- Eigen/src/Core/util/ForwardDeclarations.h | 2 +- Eigen/src/Core/util/Meta.h | 3 + Eigen/src/Core/util/XprHelper.h | 8 +- Eigen/src/Eigenvalues/ComplexEigenSolver.h | 2 +- Eigen/src/Eigenvalues/ComplexSchur.h | 2 +- Eigen/src/Eigenvalues/EigenSolver.h | 2 +- .../src/Eigenvalues/GeneralizedEigenSolver.h | 2 +- Eigen/src/Eigenvalues/RealQZ.h | 2 +- Eigen/src/Eigenvalues/RealSchur.h | 2 +- test/CustomComplex.h | 132 ++++++++++++++++++ test/eigensolver_complex.cpp | 4 + .../Eigen/CXX11/src/Tensor/TensorFFT.h | 35 ++--- .../Eigen/src/IterativeSolvers/DGMRES.h | 11 +- .../src/MatrixFunctions/MatrixExponential.h | 46 +++--- .../src/MatrixFunctions/MatrixFunction.h | 4 +- .../src/MatrixFunctions/MatrixLogarithm.h | 2 +- .../Eigen/src/MatrixFunctions/MatrixPower.h | 8 +- .../Eigen/src/Polynomials/PolynomialSolver.h | 4 +- 21 files changed, 273 insertions(+), 115 deletions(-) create mode 100644 test/CustomComplex.h diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h index c081499ee..528aed2c2 100644 --- a/Eigen/src/Core/MathFunctions.h +++ b/Eigen/src/Core/MathFunctions.h @@ -182,6 +182,35 @@ struct imag_ref_retval { typedef typename NumTraits::Real& type; }; +} // namespace internal + +namespace numext { + +template +EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(real, Scalar) real(const Scalar& x) { + return EIGEN_MATHFUNC_IMPL(real, Scalar)::run(x); +} + +template +EIGEN_DEVICE_FUNC inline internal::add_const_on_value_type_t real_ref( + const Scalar& x) { + return internal::real_ref_impl::run(x); +} + +template +EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(real_ref, Scalar) real_ref(Scalar& x) { + return EIGEN_MATHFUNC_IMPL(real_ref, Scalar)::run(x); +} + +template +EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(imag, Scalar) imag(const Scalar& x) { + return EIGEN_MATHFUNC_IMPL(imag, Scalar)::run(x); +} + +} // namespace numext + +namespace internal { + /**************************************************************************** * Implementation of conj * ****************************************************************************/ @@ -221,7 +250,9 @@ template struct abs2_impl_default // IsComplex { typedef typename NumTraits::Real RealScalar; - EIGEN_DEVICE_FUNC static inline RealScalar run(const Scalar& x) { return x.real() * x.real() + x.imag() * x.imag(); } + EIGEN_DEVICE_FUNC static inline RealScalar run(const Scalar& x) { + return numext::real(x) * numext::real(x) + numext::imag(x) * numext::imag(x); + } }; template @@ -250,16 +281,14 @@ struct sqrt_impl { }; // Complex sqrt defined in MathFunctionsImpl.h. -template -EIGEN_DEVICE_FUNC std::complex complex_sqrt(const std::complex& a_x); +template +EIGEN_DEVICE_FUNC ComplexT complex_sqrt(const ComplexT& a_x); // Custom implementation is faster than `std::sqrt`, works on // GPU, and correctly handles special cases (unlike MSVC). template struct sqrt_impl> { - EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE std::complex run(const std::complex& x) { - return complex_sqrt(x); - } + EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE std::complex run(const std::complex& x) { return complex_sqrt(x); } }; template @@ -272,13 +301,13 @@ template struct rsqrt_impl; // Complex rsqrt defined in MathFunctionsImpl.h. -template -EIGEN_DEVICE_FUNC std::complex complex_rsqrt(const std::complex& a_x); +template +EIGEN_DEVICE_FUNC ComplexT complex_rsqrt(const ComplexT& a_x); template struct rsqrt_impl> { EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE std::complex run(const std::complex& x) { - return complex_rsqrt(x); + return complex_rsqrt(x); } }; @@ -299,7 +328,7 @@ struct norm1_default_impl { typedef typename NumTraits::Real RealScalar; EIGEN_DEVICE_FUNC static inline RealScalar run(const Scalar& x) { EIGEN_USING_STD(abs); - return abs(x.real()) + abs(x.imag()); + return abs(numext::real(x)) + abs(numext::imag(x)); } }; @@ -469,8 +498,8 @@ struct expm1_retval { ****************************************************************************/ // Complex log defined in MathFunctionsImpl.h. -template -EIGEN_DEVICE_FUNC std::complex complex_log(const std::complex& z); +template +EIGEN_DEVICE_FUNC ComplexT complex_log(const ComplexT& z); template struct log_impl { @@ -846,7 +875,7 @@ struct sign_impl { real_type aa = abs(a); if (aa == real_type(0)) return Scalar(0); aa = real_type(1) / aa; - return Scalar(a.real() * aa, a.imag() * aa); + return Scalar(numext::real(a) * aa, numext::imag(a) * aa); } }; @@ -1042,27 +1071,6 @@ SYCL_SPECIALIZE_FLOATING_TYPES_BINARY(maxi, fmax) #endif -template -EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(real, Scalar) real(const Scalar& x) { - return EIGEN_MATHFUNC_IMPL(real, Scalar)::run(x); -} - -template -EIGEN_DEVICE_FUNC inline internal::add_const_on_value_type_t real_ref( - const Scalar& x) { - return internal::real_ref_impl::run(x); -} - -template -EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(real_ref, Scalar) real_ref(Scalar& x) { - return EIGEN_MATHFUNC_IMPL(real_ref, Scalar)::run(x); -} - -template -EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(imag, Scalar) imag(const Scalar& x) { - return EIGEN_MATHFUNC_IMPL(imag, Scalar)::run(x); -} - template EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(arg, Scalar) arg(const Scalar& x) { return EIGEN_MATHFUNC_IMPL(arg, Scalar)::run(x); diff --git a/Eigen/src/Core/MathFunctionsImpl.h b/Eigen/src/Core/MathFunctionsImpl.h index 10ddabd79..8e2705ba9 100644 --- a/Eigen/src/Core/MathFunctionsImpl.h +++ b/Eigen/src/Core/MathFunctionsImpl.h @@ -171,8 +171,8 @@ struct hypot_impl { // Generic complex sqrt implementation that correctly handles corner cases // according to https://en.cppreference.com/w/cpp/numeric/complex/sqrt -template -EIGEN_DEVICE_FUNC std::complex complex_sqrt(const std::complex& z) { +template +EIGEN_DEVICE_FUNC ComplexT complex_sqrt(const ComplexT& z) { // Computes the principal sqrt of the input. // // For a complex square root of the number x + i*y. We want to find real @@ -194,21 +194,21 @@ EIGEN_DEVICE_FUNC std::complex complex_sqrt(const std::complex& z) { // if x == 0: u = w, v = sign(y) * w // if x > 0: u = w, v = y / (2 * w) // if x < 0: u = |y| / (2 * w), v = sign(y) * w - + using T = typename NumTraits::Real; const T x = numext::real(z); const T y = numext::imag(z); const T zero = T(0); const T w = numext::sqrt(T(0.5) * (numext::abs(x) + numext::hypot(x, y))); - return (numext::isinf)(y) ? std::complex(NumTraits::infinity(), y) - : numext::is_exactly_zero(x) ? std::complex(w, y < zero ? -w : w) - : x > zero ? std::complex(w, y / (2 * w)) - : std::complex(numext::abs(y) / (2 * w), y < zero ? -w : w); + return (numext::isinf)(y) ? ComplexT(NumTraits::infinity(), y) + : numext::is_exactly_zero(x) ? ComplexT(w, y < zero ? -w : w) + : x > zero ? ComplexT(w, y / (2 * w)) + : ComplexT(numext::abs(y) / (2 * w), y < zero ? -w : w); } // Generic complex rsqrt implementation. -template -EIGEN_DEVICE_FUNC std::complex complex_rsqrt(const std::complex& z) { +template +EIGEN_DEVICE_FUNC ComplexT complex_rsqrt(const ComplexT& z) { // Computes the principal reciprocal sqrt of the input. // // For a complex reciprocal square root of the number z = x + i*y. We want to @@ -230,7 +230,7 @@ EIGEN_DEVICE_FUNC std::complex complex_rsqrt(const std::complex& z) { // if x == 0: u = w / |z|, v = -sign(y) * w / |z| // if x > 0: u = w / |z|, v = -y / (2 * w * |z|) // if x < 0: u = |y| / (2 * w * |z|), v = -sign(y) * w / |z| - + using T = typename NumTraits::Real; const T x = numext::real(z); const T y = numext::imag(z); const T zero = T(0); @@ -239,20 +239,21 @@ EIGEN_DEVICE_FUNC std::complex complex_rsqrt(const std::complex& z) { const T w = numext::sqrt(T(0.5) * (numext::abs(x) + abs_z)); const T woz = w / abs_z; // Corner cases consistent with 1/sqrt(z) on gcc/clang. - return numext::is_exactly_zero(abs_z) ? std::complex(NumTraits::infinity(), NumTraits::quiet_NaN()) - : ((numext::isinf)(x) || (numext::isinf)(y)) ? std::complex(zero, zero) - : numext::is_exactly_zero(x) ? std::complex(woz, y < zero ? woz : -woz) - : x > zero ? std::complex(woz, -y / (2 * w * abs_z)) - : std::complex(numext::abs(y) / (2 * w * abs_z), y < zero ? woz : -woz); + return numext::is_exactly_zero(abs_z) ? ComplexT(NumTraits::infinity(), NumTraits::quiet_NaN()) + : ((numext::isinf)(x) || (numext::isinf)(y)) ? ComplexT(zero, zero) + : numext::is_exactly_zero(x) ? ComplexT(woz, y < zero ? woz : -woz) + : x > zero ? ComplexT(woz, -y / (2 * w * abs_z)) + : ComplexT(numext::abs(y) / (2 * w * abs_z), y < zero ? woz : -woz); } -template -EIGEN_DEVICE_FUNC std::complex complex_log(const std::complex& z) { +template +EIGEN_DEVICE_FUNC ComplexT complex_log(const ComplexT& z) { // Computes complex log. + using T = typename NumTraits::Real; T a = numext::abs(z); EIGEN_USING_STD(atan2); T b = atan2(z.imag(), z.real()); - return std::complex(numext::log(a), b); + return ComplexT(numext::log(a), b); } } // end namespace internal diff --git a/Eigen/src/Core/MatrixBase.h b/Eigen/src/Core/MatrixBase.h index 5466a57c4..8d5c47e47 100644 --- a/Eigen/src/Core/MatrixBase.h +++ b/Eigen/src/Core/MatrixBase.h @@ -112,7 +112,7 @@ class MatrixBase : public DenseBase { ConstTransposeReturnType> AdjointReturnType; /** \internal Return type of eigenvalues() */ - typedef Matrix, internal::traits::ColsAtCompileTime, 1, ColMajor> + typedef Matrix, internal::traits::ColsAtCompileTime, 1, ColMajor> EigenvaluesReturnType; /** \internal the return type of identity */ typedef CwiseNullaryOp, PlainObject> IdentityReturnType; @@ -468,7 +468,7 @@ class MatrixBase : public DenseBase { EIGEN_MATRIX_FUNCTION(MatrixSquareRootReturnValue, sqrt, square root) EIGEN_MATRIX_FUNCTION(MatrixLogarithmReturnValue, log, logarithm) EIGEN_MATRIX_FUNCTION_1(MatrixPowerReturnValue, pow, power to \c p, const RealScalar& p) - EIGEN_MATRIX_FUNCTION_1(MatrixComplexPowerReturnValue, pow, power to \c p, const std::complex& p) + EIGEN_MATRIX_FUNCTION_1(MatrixComplexPowerReturnValue, pow, power to \c p, const internal::make_complex_t& p) protected: EIGEN_DEFAULT_COPY_CONSTRUCTOR(MatrixBase) diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h index fcde64afe..2488be46f 100644 --- a/Eigen/src/Core/util/ForwardDeclarations.h +++ b/Eigen/src/Core/util/ForwardDeclarations.h @@ -497,7 +497,7 @@ class MatrixComplexPowerReturnValue; namespace internal { template struct stem_function { - typedef std::complex::Real> ComplexScalar; + typedef internal::make_complex_t ComplexScalar; typedef ComplexScalar type(ComplexScalar, int); }; } // namespace internal diff --git a/Eigen/src/Core/util/Meta.h b/Eigen/src/Core/util/Meta.h index 2e3f59e76..072247358 100644 --- a/Eigen/src/Core/util/Meta.h +++ b/Eigen/src/Core/util/Meta.h @@ -745,6 +745,9 @@ using std::is_constant_evaluated; constexpr bool is_constant_evaluated() { return false; } #endif +template +using make_complex_t = std::conditional_t::IsComplex, Scalar, std::complex>; + } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/Core/util/XprHelper.h b/Eigen/src/Core/util/XprHelper.h index cecbee865..a42bb0f73 100644 --- a/Eigen/src/Core/util/XprHelper.h +++ b/Eigen/src/Core/util/XprHelper.h @@ -885,8 +885,12 @@ struct scalar_div_cost { }; template -struct scalar_div_cost, Vectorized> { - enum { value = 2 * scalar_div_cost::value + 6 * NumTraits::MulCost + 3 * NumTraits::AddCost }; +struct scalar_div_cost::IsComplex>> { + using RealScalar = typename NumTraits::Real; + enum { + value = + 2 * scalar_div_cost::value + 6 * NumTraits::MulCost + 3 * NumTraits::AddCost + }; }; template diff --git a/Eigen/src/Eigenvalues/ComplexEigenSolver.h b/Eigen/src/Eigenvalues/ComplexEigenSolver.h index 60a24a899..50fa3b809 100644 --- a/Eigen/src/Eigenvalues/ComplexEigenSolver.h +++ b/Eigen/src/Eigenvalues/ComplexEigenSolver.h @@ -70,7 +70,7 @@ class ComplexEigenSolver { * \c float or \c double) and just \c Scalar if #Scalar is * complex. */ - typedef std::complex ComplexScalar; + typedef internal::make_complex_t ComplexScalar; /** \brief Type for vector of eigenvalues as returned by eigenvalues(). * diff --git a/Eigen/src/Eigenvalues/ComplexSchur.h b/Eigen/src/Eigenvalues/ComplexSchur.h index a33e46ee7..22433f2bd 100644 --- a/Eigen/src/Eigenvalues/ComplexSchur.h +++ b/Eigen/src/Eigenvalues/ComplexSchur.h @@ -75,7 +75,7 @@ class ComplexSchur { * \c float or \c double) and just \c Scalar if #Scalar is * complex. */ - typedef std::complex ComplexScalar; + typedef internal::make_complex_t ComplexScalar; /** \brief Type for the matrices in the Schur decomposition. * diff --git a/Eigen/src/Eigenvalues/EigenSolver.h b/Eigen/src/Eigenvalues/EigenSolver.h index f73d58f87..9dba7bd18 100644 --- a/Eigen/src/Eigenvalues/EigenSolver.h +++ b/Eigen/src/Eigenvalues/EigenSolver.h @@ -89,7 +89,7 @@ class EigenSolver { * \c float or \c double) and just \c Scalar if #Scalar is * complex. */ - typedef std::complex ComplexScalar; + typedef internal::make_complex_t ComplexScalar; /** \brief Type for vector of eigenvalues as returned by eigenvalues(). * diff --git a/Eigen/src/Eigenvalues/GeneralizedEigenSolver.h b/Eigen/src/Eigenvalues/GeneralizedEigenSolver.h index b114cfab5..c0a61dcd4 100644 --- a/Eigen/src/Eigenvalues/GeneralizedEigenSolver.h +++ b/Eigen/src/Eigenvalues/GeneralizedEigenSolver.h @@ -83,7 +83,7 @@ class GeneralizedEigenSolver { * \c float or \c double) and just \c Scalar if #Scalar is * complex. */ - typedef std::complex ComplexScalar; + typedef internal::make_complex_t ComplexScalar; /** \brief Type for vector of real scalar values eigenvalues as returned by betas(). * diff --git a/Eigen/src/Eigenvalues/RealQZ.h b/Eigen/src/Eigenvalues/RealQZ.h index 3466f51c1..591538729 100644 --- a/Eigen/src/Eigenvalues/RealQZ.h +++ b/Eigen/src/Eigenvalues/RealQZ.h @@ -69,7 +69,7 @@ class RealQZ { MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime }; typedef typename MatrixType::Scalar Scalar; - typedef std::complex::Real> ComplexScalar; + typedef internal::make_complex_t ComplexScalar; typedef Eigen::Index Index; ///< \deprecated since Eigen 3.3 typedef Matrix EigenvalueType; diff --git a/Eigen/src/Eigenvalues/RealSchur.h b/Eigen/src/Eigenvalues/RealSchur.h index 5cef6587b..54a74e2f5 100644 --- a/Eigen/src/Eigenvalues/RealSchur.h +++ b/Eigen/src/Eigenvalues/RealSchur.h @@ -66,7 +66,7 @@ class RealSchur { MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime }; typedef typename MatrixType::Scalar Scalar; - typedef std::complex::Real> ComplexScalar; + typedef internal::make_complex_t ComplexScalar; typedef Eigen::Index Index; ///< \deprecated since Eigen 3.3 typedef Matrix EigenvalueType; diff --git a/test/CustomComplex.h b/test/CustomComplex.h new file mode 100644 index 000000000..048f65b48 --- /dev/null +++ b/test/CustomComplex.h @@ -0,0 +1,132 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2025 The Eigen Authors. +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_TEST_CUSTOM_COMPLEX_H +#define EIGEN_TEST_CUSTOM_COMPLEX_H + +#include +#include + +namespace custom_complex { + +template +struct CustomComplex { + CustomComplex() : re{0}, im{0} {} + CustomComplex(const CustomComplex& other) = default; + CustomComplex(CustomComplex&& other) = default; + CustomComplex& operator=(const CustomComplex& other) = default; + CustomComplex& operator=(CustomComplex&& other) = default; + CustomComplex(Real x) : re{x}, im{0} {} + CustomComplex(Real x, Real y) : re{x}, im{y} {} + + CustomComplex operator+(const CustomComplex& other) const { return CustomComplex(re + other.re, im + other.im); } + + CustomComplex operator-() const { return CustomComplex(-re, -im); } + + CustomComplex operator-(const CustomComplex& other) const { return CustomComplex(re - other.re, im - other.im); } + + CustomComplex operator*(const CustomComplex& other) const { + return CustomComplex(re * other.re - im * other.im, re * other.im + im * other.re); + } + + CustomComplex operator/(const CustomComplex& other) const { + // Smith's complex division (https://arxiv.org/pdf/1210.4539.pdf), + // guards against over/under-flow. + const bool scale_imag = numext::abs(other.im) <= numext::abs(other.re); + const Real rscale = scale_imag ? Real(1) : other.re / other.im; + const Real iscale = scale_imag ? other.im / other.re : Real(1); + const Real denominator = other.re * rscale + other.im * iscale; + return CustomComplex((re * rscale + im * iscale) / denominator, (im * rscale - re * iscale) / denominator); + } + + CustomComplex& operator+=(const CustomComplex& other) { + *this = *this + other; + return *this; + } + CustomComplex& operator-=(const CustomComplex& other) { + *this = *this - other; + return *this; + } + CustomComplex& operator*=(const CustomComplex& other) { + *this = *this * other; + return *this; + } + CustomComplex& operator/=(const CustomComplex& other) { + *this = *this / other; + return *this; + } + + bool operator==(const CustomComplex& other) const { + return numext::equal_strict(re, other.re) && numext::equal_strict(im, other.im); + } + bool operator!=(const CustomComplex& other) const { return !(*this == other); } + + friend CustomComplex operator+(const Real& a, const CustomComplex& b) { return CustomComplex(a + b.re, b.im); } + + friend CustomComplex operator-(const Real& a, const CustomComplex& b) { return CustomComplex(a - b.re, -b.im); } + + friend CustomComplex operator*(const Real& a, const CustomComplex& b) { return CustomComplex(a * b.re, a * b.im); } + + friend CustomComplex operator*(const CustomComplex& a, const Real& b) { return CustomComplex(a.re * b, a.im * b); } + + friend CustomComplex operator/(const CustomComplex& a, const Real& b) { return CustomComplex(a.re / b, a.im / b); } + + friend std::ostream& operator<<(std::ostream& stream, const CustomComplex& x) { + std::stringstream ss; + ss << "(" << x.re << ", " << x.im << ")"; + stream << ss.str(); + return stream; + } + + Real re; + Real im; +}; + +template +Real real(const CustomComplex& x) { + return x.re; +} +template +Real imag(const CustomComplex& x) { + return x.im; +} +template +CustomComplex conj(const CustomComplex& x) { + return CustomComplex(x.re, -x.im); +} +template +CustomComplex sqrt(const CustomComplex& x) { + return Eigen::internal::complex_sqrt(x); +} +template +Real abs(const CustomComplex& x) { + return Eigen::numext::sqrt(x.re * x.re + x.im * x.im); +} + +} // namespace custom_complex + +template +using CustomComplex = custom_complex::CustomComplex; + +namespace Eigen { +template +struct NumTraits> : NumTraits { + enum { IsComplex = 1 }; +}; + +namespace numext { +template +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isfinite)(const CustomComplex& x) { + return (numext::isfinite)(x.re) && (numext::isfinite)(x.im); +} + +} // namespace numext +} // namespace Eigen + +#endif // EIGEN_TEST_CUSTOM_COMPLEX_H diff --git a/test/eigensolver_complex.cpp b/test/eigensolver_complex.cpp index afb24b9af..76846a933 100644 --- a/test/eigensolver_complex.cpp +++ b/test/eigensolver_complex.cpp @@ -12,6 +12,7 @@ #include #include #include +#include "CustomComplex.h" template bool find_pivot(typename MatrixType::Scalar tol, MatrixType& diffs, Index col = 0) { @@ -165,5 +166,8 @@ EIGEN_DECLARE_TEST(eigensolver_complex) { // Test problem size constructors CALL_SUBTEST_5(ComplexEigenSolver tmp(s)); + // Test custom complex scalar type. + CALL_SUBTEST_6(eigensolver(Matrix, 5, 5>())); + TEST_SET_BUT_UNUSED_VARIABLE(s) } diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h b/unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h index 7521559aa..b9d6f376b 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h @@ -15,7 +15,7 @@ namespace Eigen { -template +template struct MakeComplex { template EIGEN_DEVICE_FUNC T operator()(const T& val) const { @@ -26,16 +26,8 @@ struct MakeComplex { template <> struct MakeComplex { template - EIGEN_DEVICE_FUNC std::complex operator()(const T& val) const { - return std::complex(val, 0); - } -}; - -template <> -struct MakeComplex { - template - EIGEN_DEVICE_FUNC std::complex operator()(const std::complex& val) const { - return val; + EIGEN_DEVICE_FUNC internal::make_complex_t operator()(const T& val) const { + return internal::make_complex_t(val, T(0)); } }; @@ -49,17 +41,17 @@ struct PartOf { template <> struct PartOf { - template - T operator()(const std::complex& val) const { - return val.real(); + template ::IsComplex>> + typename NumTraits::Real operator()(const T& val) const { + return Eigen::numext::real(val); } }; template <> struct PartOf { - template - T operator()(const std::complex& val) const { - return val.imag(); + template ::IsComplex>> + typename NumTraits::Real operator()(const T& val) const { + return Eigen::numext::imag(val); } }; @@ -67,8 +59,9 @@ namespace internal { template struct traits > : public traits { typedef traits XprTraits; - typedef typename NumTraits::Real RealScalar; - typedef typename std::complex ComplexScalar; + typedef typename XprTraits::Scalar Scalar; + typedef typename NumTraits::Real RealScalar; + typedef make_complex_t ComplexScalar; typedef typename XprTraits::Scalar InputScalar; typedef std::conditional_t OutputScalar; @@ -109,7 +102,7 @@ class TensorFFTOp : public TensorBase::Scalar Scalar; typedef typename Eigen::NumTraits::Real RealScalar; - typedef typename std::complex ComplexScalar; + typedef internal::make_complex_t ComplexScalar; typedef std::conditional_t OutputScalar; typedef OutputScalar CoeffReturnType; @@ -137,7 +130,7 @@ struct TensorEvaluator, D typedef DSizes Dimensions; typedef typename XprType::Scalar Scalar; typedef typename Eigen::NumTraits::Real RealScalar; - typedef typename std::complex ComplexScalar; + typedef internal::make_complex_t ComplexScalar; typedef typename TensorEvaluator::Dimensions InputDimensions; typedef internal::traits XprTraits; typedef typename XprTraits::Scalar InputScalar; diff --git a/unsupported/Eigen/src/IterativeSolvers/DGMRES.h b/unsupported/Eigen/src/IterativeSolvers/DGMRES.h index 182bd2e07..6f6df3edd 100644 --- a/unsupported/Eigen/src/IterativeSolvers/DGMRES.h +++ b/unsupported/Eigen/src/IterativeSolvers/DGMRES.h @@ -111,12 +111,13 @@ class DGMRES : public IterativeSolverBase > typedef typename MatrixType::Scalar Scalar; typedef typename MatrixType::StorageIndex StorageIndex; typedef typename MatrixType::RealScalar RealScalar; + typedef internal::make_complex_t ComplexScalar; typedef Preconditioner_ Preconditioner; typedef Matrix DenseMatrix; typedef Matrix DenseRealMatrix; typedef Matrix DenseVector; typedef Matrix DenseRealVector; - typedef Matrix, Dynamic, 1> ComplexVector; + typedef Matrix ComplexVector; /** Default constructor. */ DGMRES() @@ -389,15 +390,15 @@ inline typename DGMRES::ComplexVector DGMRES(T(j, j), RealScalar(0)); + eig(j) = ComplexScalar(T(j, j), RealScalar(0)); j++; } else { - eig(j) = std::complex(T(j, j), T(j + 1, j)); - eig(j + 1) = std::complex(T(j, j + 1), T(j + 1, j + 1)); + eig(j) = ComplexScalar(T(j, j), T(j + 1, j)); + eig(j + 1) = ComplexScalar(T(j, j + 1), T(j + 1, j + 1)); j++; } } - if (j < it - 1) eig(j) = std::complex(T(j, j), RealScalar(0)); + if (j < it - 1) eig(j) = ComplexScalar(T(j, j), RealScalar(0)); return eig; } diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixExponential.h b/unsupported/Eigen/src/MatrixFunctions/MatrixExponential.h index ff955e1d7..a28aa9695 100644 --- a/unsupported/Eigen/src/MatrixFunctions/MatrixExponential.h +++ b/unsupported/Eigen/src/MatrixFunctions/MatrixExponential.h @@ -23,8 +23,10 @@ namespace internal { * * This struct is used by CwiseUnaryOp to scale a matrix by \f$ 2^{-s} \f$. */ -template +template ::IsComplex> struct MatrixExponentialScalingOp { + using RealScalar = typename NumTraits::Real; + /** \brief Constructor. * * \param[in] squarings The integer \f$ s \f$ in this document. @@ -35,20 +37,30 @@ struct MatrixExponentialScalingOp { * * \param[in,out] x The scalar to be scaled, becoming \f$ 2^{-s} x \f$. */ - inline const RealScalar operator()(const RealScalar& x) const { + inline const Scalar operator()(const Scalar& x) const { using std::ldexp; - return ldexp(x, -m_squarings); + return Scalar(ldexp(Eigen::numext::real(x), -m_squarings), ldexp(Eigen::numext::imag(x), -m_squarings)); } - typedef std::complex ComplexScalar; + private: + int m_squarings; +}; + +template +struct MatrixExponentialScalingOp { + /** \brief Constructor. + * + * \param[in] squarings The integer \f$ s \f$ in this document. + */ + MatrixExponentialScalingOp(int squarings) : m_squarings(squarings) {} /** \brief Scale a matrix coefficient. * * \param[in,out] x The scalar to be scaled, becoming \f$ 2^{-s} x \f$. */ - inline const ComplexScalar operator()(const ComplexScalar& x) const { + inline const Scalar operator()(const Scalar& x) const { using std::ldexp; - return ComplexScalar(ldexp(x.real(), -m_squarings), ldexp(x.imag(), -m_squarings)); + return ldexp(x, -m_squarings); } private: @@ -220,6 +232,7 @@ struct matrix_exp_computeUV { template struct matrix_exp_computeUV { + using Scalar = typename traits::Scalar; template static void run(const ArgType& arg, MatrixType& U, MatrixType& V, int& squarings) { using std::frexp; @@ -234,7 +247,7 @@ struct matrix_exp_computeUV { const float maxnorm = 3.925724783138660f; frexp(l1norm / maxnorm, &squarings); if (squarings < 0) squarings = 0; - MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp(squarings)); + MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp(squarings)); matrix_exp_pade7(A, U, V); } } @@ -242,12 +255,12 @@ struct matrix_exp_computeUV { template struct matrix_exp_computeUV { - typedef typename NumTraits::Scalar>::Real RealScalar; + using Scalar = typename traits::Scalar; template static void run(const ArgType& arg, MatrixType& U, MatrixType& V, int& squarings) { using std::frexp; using std::pow; - const RealScalar l1norm = arg.cwiseAbs().colwise().sum().maxCoeff(); + const double l1norm = arg.cwiseAbs().colwise().sum().maxCoeff(); squarings = 0; if (l1norm < 1.495585217958292e-002) { matrix_exp_pade3(arg, U, V); @@ -258,10 +271,10 @@ struct matrix_exp_computeUV { } else if (l1norm < 2.097847961257068e+000) { matrix_exp_pade9(arg, U, V); } else { - const RealScalar maxnorm = 5.371920351148152; + const double maxnorm = 5.371920351148152; frexp(l1norm / maxnorm, &squarings); if (squarings < 0) squarings = 0; - MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp(squarings)); + MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp(squarings)); matrix_exp_pade13(A, U, V); } } @@ -271,6 +284,7 @@ template struct matrix_exp_computeUV { template static void run(const ArgType& arg, MatrixType& U, MatrixType& V, int& squarings) { + using Scalar = typename traits::Scalar; #if LDBL_MANT_DIG == 53 // double precision matrix_exp_computeUV::run(arg, U, V, squarings); @@ -295,7 +309,7 @@ struct matrix_exp_computeUV { const long double maxnorm = 4.0246098906697353063L; frexp(l1norm / maxnorm, &squarings); if (squarings < 0) squarings = 0; - MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp(squarings)); + MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp(squarings)); matrix_exp_pade13(A, U, V); } @@ -315,7 +329,7 @@ struct matrix_exp_computeUV { const long double maxnorm = 3.2579440895405400856599663723517L; frexp(l1norm / maxnorm, &squarings); if (squarings < 0) squarings = 0; - MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp(squarings)); + MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp(squarings)); matrix_exp_pade17(A, U, V); } @@ -335,7 +349,7 @@ struct matrix_exp_computeUV { const long double maxnorm = 2.884233277829519311757165057717815L; frexp(l1norm / maxnorm, &squarings); if (squarings < 0) squarings = 0; - MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp(squarings)); + MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp(squarings)); matrix_exp_pade17(A, U, V); } @@ -382,9 +396,7 @@ template void matrix_exp_compute(const ArgType& arg, ResultType& result, false_type) // default { typedef typename ArgType::PlainObject MatrixType; - typedef typename traits::Scalar Scalar; - typedef typename NumTraits::Real RealScalar; - typedef typename std::complex ComplexScalar; + typedef make_complex_t::Scalar> ComplexScalar; result = arg.matrixFunction(internal::stem_function_exp); } diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixFunction.h b/unsupported/Eigen/src/MatrixFunctions/MatrixFunction.h index 68336a525..0c18ad66a 100644 --- a/unsupported/Eigen/src/MatrixFunctions/MatrixFunction.h +++ b/unsupported/Eigen/src/MatrixFunctions/MatrixFunction.h @@ -382,7 +382,7 @@ struct matrix_function_compute { static const int Rows = Traits::RowsAtCompileTime, Cols = Traits::ColsAtCompileTime; static const int MaxRows = Traits::MaxRowsAtCompileTime, MaxCols = Traits::MaxColsAtCompileTime; - typedef std::complex ComplexScalar; + typedef internal::make_complex_t ComplexScalar; typedef Matrix ComplexMatrix; ComplexMatrix CA = A.template cast(); @@ -476,7 +476,7 @@ class MatrixFunctionReturnValue : public ReturnByValue::type NestedEvalType; typedef internal::remove_all_t NestedEvalTypeClean; typedef internal::traits Traits; - typedef std::complex::Real> ComplexScalar; + typedef internal::make_complex_t ComplexScalar; typedef Matrix DynMatrixType; diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h b/unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h index 4228166d1..398971ebb 100644 --- a/unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h +++ b/unsupported/Eigen/src/MatrixFunctions/MatrixLogarithm.h @@ -330,7 +330,7 @@ class MatrixLogarithmReturnValue : public ReturnByValue::type DerivedEvalType; typedef internal::remove_all_t DerivedEvalTypeClean; typedef internal::traits Traits; - typedef std::complex::Real> ComplexScalar; + typedef internal::make_complex_t ComplexScalar; typedef Matrix DynMatrixType; typedef internal::MatrixLogarithmAtomic AtomicType; diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixPower.h b/unsupported/Eigen/src/MatrixFunctions/MatrixPower.h index bff619a9c..a420ee709 100644 --- a/unsupported/Eigen/src/MatrixFunctions/MatrixPower.h +++ b/unsupported/Eigen/src/MatrixFunctions/MatrixPower.h @@ -91,7 +91,7 @@ class MatrixPowerAtomic : internal::noncopyable { enum { RowsAtCompileTime = MatrixType::RowsAtCompileTime, MaxRowsAtCompileTime = MatrixType::MaxRowsAtCompileTime }; typedef typename MatrixType::Scalar Scalar; typedef typename MatrixType::RealScalar RealScalar; - typedef std::complex ComplexScalar; + typedef internal::make_complex_t ComplexScalar; typedef Block ResultType; const MatrixType& m_A; @@ -380,7 +380,7 @@ class MatrixPower : internal::noncopyable { Index cols() const { return m_A.cols(); } private: - typedef std::complex ComplexScalar; + typedef internal::make_complex_t ComplexScalar; typedef Matrix ComplexMatrix; @@ -628,7 +628,7 @@ template class MatrixComplexPowerReturnValue : public ReturnByValue > { public: typedef typename Derived::PlainObject PlainObject; - typedef typename std::complex ComplexScalar; + typedef internal::make_complex_t ComplexScalar; /** * \brief Constructor. @@ -685,7 +685,7 @@ const MatrixPowerReturnValue MatrixBase::pow(const RealScalar& } template -const MatrixComplexPowerReturnValue MatrixBase::pow(const std::complex& p) const { +const MatrixComplexPowerReturnValue MatrixBase::pow(const internal::make_complex_t& p) const { return MatrixComplexPowerReturnValue(derived(), p); } diff --git a/unsupported/Eigen/src/Polynomials/PolynomialSolver.h b/unsupported/Eigen/src/Polynomials/PolynomialSolver.h index 8c0ce3b71..aa357a41a 100644 --- a/unsupported/Eigen/src/Polynomials/PolynomialSolver.h +++ b/unsupported/Eigen/src/Polynomials/PolynomialSolver.h @@ -35,7 +35,7 @@ class PolynomialSolverBase { typedef Scalar_ Scalar; typedef typename NumTraits::Real RealScalar; - typedef std::complex RootType; + typedef internal::make_complex_t RootType; typedef Matrix RootsType; typedef DenseIndex Index; @@ -308,7 +308,7 @@ class PolynomialSolver : public PolynomialSolverBase { typedef std::conditional_t::IsComplex, ComplexEigenSolver, EigenSolver > EigenSolverType; - typedef std::conditional_t::IsComplex, Scalar, std::complex > ComplexScalar; + typedef internal::make_complex_t ComplexScalar; public: /** Computes the complex roots of a new polynomial. */