Remove assumption of std::complex for complex scalar types.

This commit is contained in:
Antonio Sanchez 2025-02-12 11:21:44 -08:00
parent 6b4881ad48
commit 22cd7307dd
21 changed files with 273 additions and 115 deletions

View File

@ -182,6 +182,35 @@ struct imag_ref_retval {
typedef typename NumTraits<Scalar>::Real& type;
};
} // namespace internal
namespace numext {
template <typename Scalar>
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(real, Scalar) real(const Scalar& x) {
return EIGEN_MATHFUNC_IMPL(real, Scalar)::run(x);
}
template <typename Scalar>
EIGEN_DEVICE_FUNC inline internal::add_const_on_value_type_t<EIGEN_MATHFUNC_RETVAL(real_ref, Scalar)> real_ref(
const Scalar& x) {
return internal::real_ref_impl<Scalar>::run(x);
}
template <typename Scalar>
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(real_ref, Scalar) real_ref(Scalar& x) {
return EIGEN_MATHFUNC_IMPL(real_ref, Scalar)::run(x);
}
template <typename Scalar>
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(imag, Scalar) imag(const Scalar& x) {
return EIGEN_MATHFUNC_IMPL(imag, Scalar)::run(x);
}
} // namespace numext
namespace internal {
/****************************************************************************
* Implementation of conj *
****************************************************************************/
@ -221,7 +250,9 @@ template <typename Scalar>
struct abs2_impl_default<Scalar, true> // IsComplex
{
typedef typename NumTraits<Scalar>::Real RealScalar;
EIGEN_DEVICE_FUNC static inline RealScalar run(const Scalar& x) { return x.real() * x.real() + x.imag() * x.imag(); }
EIGEN_DEVICE_FUNC static inline RealScalar run(const Scalar& x) {
return numext::real(x) * numext::real(x) + numext::imag(x) * numext::imag(x);
}
};
template <typename Scalar>
@ -250,16 +281,14 @@ struct sqrt_impl {
};
// Complex sqrt defined in MathFunctionsImpl.h.
template <typename T>
EIGEN_DEVICE_FUNC std::complex<T> complex_sqrt(const std::complex<T>& a_x);
template <typename ComplexT>
EIGEN_DEVICE_FUNC ComplexT complex_sqrt(const ComplexT& a_x);
// Custom implementation is faster than `std::sqrt`, works on
// GPU, and correctly handles special cases (unlike MSVC).
template <typename T>
struct sqrt_impl<std::complex<T>> {
EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE std::complex<T> run(const std::complex<T>& x) {
return complex_sqrt<T>(x);
}
EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE std::complex<T> run(const std::complex<T>& x) { return complex_sqrt(x); }
};
template <typename Scalar>
@ -272,13 +301,13 @@ template <typename T>
struct rsqrt_impl;
// Complex rsqrt defined in MathFunctionsImpl.h.
template <typename T>
EIGEN_DEVICE_FUNC std::complex<T> complex_rsqrt(const std::complex<T>& a_x);
template <typename ComplexT>
EIGEN_DEVICE_FUNC ComplexT complex_rsqrt(const ComplexT& a_x);
template <typename T>
struct rsqrt_impl<std::complex<T>> {
EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE std::complex<T> run(const std::complex<T>& x) {
return complex_rsqrt<T>(x);
return complex_rsqrt(x);
}
};
@ -299,7 +328,7 @@ struct norm1_default_impl<Scalar, true> {
typedef typename NumTraits<Scalar>::Real RealScalar;
EIGEN_DEVICE_FUNC static inline RealScalar run(const Scalar& x) {
EIGEN_USING_STD(abs);
return abs(x.real()) + abs(x.imag());
return abs(numext::real(x)) + abs(numext::imag(x));
}
};
@ -469,8 +498,8 @@ struct expm1_retval {
****************************************************************************/
// Complex log defined in MathFunctionsImpl.h.
template <typename T>
EIGEN_DEVICE_FUNC std::complex<T> complex_log(const std::complex<T>& z);
template <typename ComplexT>
EIGEN_DEVICE_FUNC ComplexT complex_log(const ComplexT& z);
template <typename Scalar>
struct log_impl {
@ -846,7 +875,7 @@ struct sign_impl<Scalar, true, IsInteger> {
real_type aa = abs(a);
if (aa == real_type(0)) return Scalar(0);
aa = real_type(1) / aa;
return Scalar(a.real() * aa, a.imag() * aa);
return Scalar(numext::real(a) * aa, numext::imag(a) * aa);
}
};
@ -1042,27 +1071,6 @@ SYCL_SPECIALIZE_FLOATING_TYPES_BINARY(maxi, fmax)
#endif
template <typename Scalar>
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(real, Scalar) real(const Scalar& x) {
return EIGEN_MATHFUNC_IMPL(real, Scalar)::run(x);
}
template <typename Scalar>
EIGEN_DEVICE_FUNC inline internal::add_const_on_value_type_t<EIGEN_MATHFUNC_RETVAL(real_ref, Scalar)> real_ref(
const Scalar& x) {
return internal::real_ref_impl<Scalar>::run(x);
}
template <typename Scalar>
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(real_ref, Scalar) real_ref(Scalar& x) {
return EIGEN_MATHFUNC_IMPL(real_ref, Scalar)::run(x);
}
template <typename Scalar>
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(imag, Scalar) imag(const Scalar& x) {
return EIGEN_MATHFUNC_IMPL(imag, Scalar)::run(x);
}
template <typename Scalar>
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(arg, Scalar) arg(const Scalar& x) {
return EIGEN_MATHFUNC_IMPL(arg, Scalar)::run(x);

View File

@ -171,8 +171,8 @@ struct hypot_impl {
// Generic complex sqrt implementation that correctly handles corner cases
// according to https://en.cppreference.com/w/cpp/numeric/complex/sqrt
template <typename T>
EIGEN_DEVICE_FUNC std::complex<T> complex_sqrt(const std::complex<T>& z) {
template <typename ComplexT>
EIGEN_DEVICE_FUNC ComplexT complex_sqrt(const ComplexT& z) {
// Computes the principal sqrt of the input.
//
// For a complex square root of the number x + i*y. We want to find real
@ -194,21 +194,21 @@ EIGEN_DEVICE_FUNC std::complex<T> complex_sqrt(const std::complex<T>& z) {
// if x == 0: u = w, v = sign(y) * w
// if x > 0: u = w, v = y / (2 * w)
// if x < 0: u = |y| / (2 * w), v = sign(y) * w
using T = typename NumTraits<ComplexT>::Real;
const T x = numext::real(z);
const T y = numext::imag(z);
const T zero = T(0);
const T w = numext::sqrt(T(0.5) * (numext::abs(x) + numext::hypot(x, y)));
return (numext::isinf)(y) ? std::complex<T>(NumTraits<T>::infinity(), y)
: numext::is_exactly_zero(x) ? std::complex<T>(w, y < zero ? -w : w)
: x > zero ? std::complex<T>(w, y / (2 * w))
: std::complex<T>(numext::abs(y) / (2 * w), y < zero ? -w : w);
return (numext::isinf)(y) ? ComplexT(NumTraits<T>::infinity(), y)
: numext::is_exactly_zero(x) ? ComplexT(w, y < zero ? -w : w)
: x > zero ? ComplexT(w, y / (2 * w))
: ComplexT(numext::abs(y) / (2 * w), y < zero ? -w : w);
}
// Generic complex rsqrt implementation.
template <typename T>
EIGEN_DEVICE_FUNC std::complex<T> complex_rsqrt(const std::complex<T>& z) {
template <typename ComplexT>
EIGEN_DEVICE_FUNC ComplexT complex_rsqrt(const ComplexT& z) {
// Computes the principal reciprocal sqrt of the input.
//
// For a complex reciprocal square root of the number z = x + i*y. We want to
@ -230,7 +230,7 @@ EIGEN_DEVICE_FUNC std::complex<T> complex_rsqrt(const std::complex<T>& z) {
// if x == 0: u = w / |z|, v = -sign(y) * w / |z|
// if x > 0: u = w / |z|, v = -y / (2 * w * |z|)
// if x < 0: u = |y| / (2 * w * |z|), v = -sign(y) * w / |z|
using T = typename NumTraits<ComplexT>::Real;
const T x = numext::real(z);
const T y = numext::imag(z);
const T zero = T(0);
@ -239,20 +239,21 @@ EIGEN_DEVICE_FUNC std::complex<T> complex_rsqrt(const std::complex<T>& z) {
const T w = numext::sqrt(T(0.5) * (numext::abs(x) + abs_z));
const T woz = w / abs_z;
// Corner cases consistent with 1/sqrt(z) on gcc/clang.
return numext::is_exactly_zero(abs_z) ? std::complex<T>(NumTraits<T>::infinity(), NumTraits<T>::quiet_NaN())
: ((numext::isinf)(x) || (numext::isinf)(y)) ? std::complex<T>(zero, zero)
: numext::is_exactly_zero(x) ? std::complex<T>(woz, y < zero ? woz : -woz)
: x > zero ? std::complex<T>(woz, -y / (2 * w * abs_z))
: std::complex<T>(numext::abs(y) / (2 * w * abs_z), y < zero ? woz : -woz);
return numext::is_exactly_zero(abs_z) ? ComplexT(NumTraits<T>::infinity(), NumTraits<T>::quiet_NaN())
: ((numext::isinf)(x) || (numext::isinf)(y)) ? ComplexT(zero, zero)
: numext::is_exactly_zero(x) ? ComplexT(woz, y < zero ? woz : -woz)
: x > zero ? ComplexT(woz, -y / (2 * w * abs_z))
: ComplexT(numext::abs(y) / (2 * w * abs_z), y < zero ? woz : -woz);
}
template <typename T>
EIGEN_DEVICE_FUNC std::complex<T> complex_log(const std::complex<T>& z) {
template <typename ComplexT>
EIGEN_DEVICE_FUNC ComplexT complex_log(const ComplexT& z) {
// Computes complex log.
using T = typename NumTraits<ComplexT>::Real;
T a = numext::abs(z);
EIGEN_USING_STD(atan2);
T b = atan2(z.imag(), z.real());
return std::complex<T>(numext::log(a), b);
return ComplexT(numext::log(a), b);
}
} // end namespace internal

View File

@ -112,7 +112,7 @@ class MatrixBase : public DenseBase<Derived> {
ConstTransposeReturnType>
AdjointReturnType;
/** \internal Return type of eigenvalues() */
typedef Matrix<std::complex<RealScalar>, internal::traits<Derived>::ColsAtCompileTime, 1, ColMajor>
typedef Matrix<internal::make_complex_t<Scalar>, internal::traits<Derived>::ColsAtCompileTime, 1, ColMajor>
EigenvaluesReturnType;
/** \internal the return type of identity */
typedef CwiseNullaryOp<internal::scalar_identity_op<Scalar>, PlainObject> IdentityReturnType;
@ -468,7 +468,7 @@ class MatrixBase : public DenseBase<Derived> {
EIGEN_MATRIX_FUNCTION(MatrixSquareRootReturnValue, sqrt, square root)
EIGEN_MATRIX_FUNCTION(MatrixLogarithmReturnValue, log, logarithm)
EIGEN_MATRIX_FUNCTION_1(MatrixPowerReturnValue, pow, power to \c p, const RealScalar& p)
EIGEN_MATRIX_FUNCTION_1(MatrixComplexPowerReturnValue, pow, power to \c p, const std::complex<RealScalar>& p)
EIGEN_MATRIX_FUNCTION_1(MatrixComplexPowerReturnValue, pow, power to \c p, const internal::make_complex_t<Scalar>& p)
protected:
EIGEN_DEFAULT_COPY_CONSTRUCTOR(MatrixBase)

View File

@ -497,7 +497,7 @@ class MatrixComplexPowerReturnValue;
namespace internal {
template <typename Scalar>
struct stem_function {
typedef std::complex<typename NumTraits<Scalar>::Real> ComplexScalar;
typedef internal::make_complex_t<Scalar> ComplexScalar;
typedef ComplexScalar type(ComplexScalar, int);
};
} // namespace internal

View File

@ -745,6 +745,9 @@ using std::is_constant_evaluated;
constexpr bool is_constant_evaluated() { return false; }
#endif
template <typename Scalar>
using make_complex_t = std::conditional_t<NumTraits<Scalar>::IsComplex, Scalar, std::complex<Scalar>>;
} // end namespace internal
} // end namespace Eigen

View File

@ -885,8 +885,12 @@ struct scalar_div_cost {
};
template <typename T, bool Vectorized>
struct scalar_div_cost<std::complex<T>, Vectorized> {
enum { value = 2 * scalar_div_cost<T>::value + 6 * NumTraits<T>::MulCost + 3 * NumTraits<T>::AddCost };
struct scalar_div_cost<T, Vectorized, std::enable_if_t<NumTraits<T>::IsComplex>> {
using RealScalar = typename NumTraits<T>::Real;
enum {
value =
2 * scalar_div_cost<RealScalar>::value + 6 * NumTraits<RealScalar>::MulCost + 3 * NumTraits<RealScalar>::AddCost
};
};
template <bool Vectorized>

View File

@ -70,7 +70,7 @@ class ComplexEigenSolver {
* \c float or \c double) and just \c Scalar if #Scalar is
* complex.
*/
typedef std::complex<RealScalar> ComplexScalar;
typedef internal::make_complex_t<Scalar> ComplexScalar;
/** \brief Type for vector of eigenvalues as returned by eigenvalues().
*

View File

@ -75,7 +75,7 @@ class ComplexSchur {
* \c float or \c double) and just \c Scalar if #Scalar is
* complex.
*/
typedef std::complex<RealScalar> ComplexScalar;
typedef internal::make_complex_t<Scalar> ComplexScalar;
/** \brief Type for the matrices in the Schur decomposition.
*

View File

@ -89,7 +89,7 @@ class EigenSolver {
* \c float or \c double) and just \c Scalar if #Scalar is
* complex.
*/
typedef std::complex<RealScalar> ComplexScalar;
typedef internal::make_complex_t<Scalar> ComplexScalar;
/** \brief Type for vector of eigenvalues as returned by eigenvalues().
*

View File

@ -83,7 +83,7 @@ class GeneralizedEigenSolver {
* \c float or \c double) and just \c Scalar if #Scalar is
* complex.
*/
typedef std::complex<RealScalar> ComplexScalar;
typedef internal::make_complex_t<Scalar> ComplexScalar;
/** \brief Type for vector of real scalar values eigenvalues as returned by betas().
*

View File

@ -69,7 +69,7 @@ class RealQZ {
MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime
};
typedef typename MatrixType::Scalar Scalar;
typedef std::complex<typename NumTraits<Scalar>::Real> ComplexScalar;
typedef internal::make_complex_t<Scalar> ComplexScalar;
typedef Eigen::Index Index; ///< \deprecated since Eigen 3.3
typedef Matrix<ComplexScalar, ColsAtCompileTime, 1, Options & ~RowMajor, MaxColsAtCompileTime, 1> EigenvalueType;

View File

@ -66,7 +66,7 @@ class RealSchur {
MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime
};
typedef typename MatrixType::Scalar Scalar;
typedef std::complex<typename NumTraits<Scalar>::Real> ComplexScalar;
typedef internal::make_complex_t<Scalar> ComplexScalar;
typedef Eigen::Index Index; ///< \deprecated since Eigen 3.3
typedef Matrix<ComplexScalar, ColsAtCompileTime, 1, Options & ~RowMajor, MaxColsAtCompileTime, 1> EigenvalueType;

132
test/CustomComplex.h Normal file
View File

@ -0,0 +1,132 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2025 The Eigen Authors.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_TEST_CUSTOM_COMPLEX_H
#define EIGEN_TEST_CUSTOM_COMPLEX_H
#include <ostream>
#include <sstream>
namespace custom_complex {
template <typename Real>
struct CustomComplex {
CustomComplex() : re{0}, im{0} {}
CustomComplex(const CustomComplex& other) = default;
CustomComplex(CustomComplex&& other) = default;
CustomComplex& operator=(const CustomComplex& other) = default;
CustomComplex& operator=(CustomComplex&& other) = default;
CustomComplex(Real x) : re{x}, im{0} {}
CustomComplex(Real x, Real y) : re{x}, im{y} {}
CustomComplex operator+(const CustomComplex& other) const { return CustomComplex(re + other.re, im + other.im); }
CustomComplex operator-() const { return CustomComplex(-re, -im); }
CustomComplex operator-(const CustomComplex& other) const { return CustomComplex(re - other.re, im - other.im); }
CustomComplex operator*(const CustomComplex& other) const {
return CustomComplex(re * other.re - im * other.im, re * other.im + im * other.re);
}
CustomComplex operator/(const CustomComplex& other) const {
// Smith's complex division (https://arxiv.org/pdf/1210.4539.pdf),
// guards against over/under-flow.
const bool scale_imag = numext::abs(other.im) <= numext::abs(other.re);
const Real rscale = scale_imag ? Real(1) : other.re / other.im;
const Real iscale = scale_imag ? other.im / other.re : Real(1);
const Real denominator = other.re * rscale + other.im * iscale;
return CustomComplex((re * rscale + im * iscale) / denominator, (im * rscale - re * iscale) / denominator);
}
CustomComplex& operator+=(const CustomComplex& other) {
*this = *this + other;
return *this;
}
CustomComplex& operator-=(const CustomComplex& other) {
*this = *this - other;
return *this;
}
CustomComplex& operator*=(const CustomComplex& other) {
*this = *this * other;
return *this;
}
CustomComplex& operator/=(const CustomComplex& other) {
*this = *this / other;
return *this;
}
bool operator==(const CustomComplex& other) const {
return numext::equal_strict(re, other.re) && numext::equal_strict(im, other.im);
}
bool operator!=(const CustomComplex& other) const { return !(*this == other); }
friend CustomComplex operator+(const Real& a, const CustomComplex& b) { return CustomComplex(a + b.re, b.im); }
friend CustomComplex operator-(const Real& a, const CustomComplex& b) { return CustomComplex(a - b.re, -b.im); }
friend CustomComplex operator*(const Real& a, const CustomComplex& b) { return CustomComplex(a * b.re, a * b.im); }
friend CustomComplex operator*(const CustomComplex& a, const Real& b) { return CustomComplex(a.re * b, a.im * b); }
friend CustomComplex operator/(const CustomComplex& a, const Real& b) { return CustomComplex(a.re / b, a.im / b); }
friend std::ostream& operator<<(std::ostream& stream, const CustomComplex& x) {
std::stringstream ss;
ss << "(" << x.re << ", " << x.im << ")";
stream << ss.str();
return stream;
}
Real re;
Real im;
};
template <typename Real>
Real real(const CustomComplex<Real>& x) {
return x.re;
}
template <typename Real>
Real imag(const CustomComplex<Real>& x) {
return x.im;
}
template <typename Real>
CustomComplex<Real> conj(const CustomComplex<Real>& x) {
return CustomComplex<Real>(x.re, -x.im);
}
template <typename Real>
CustomComplex<Real> sqrt(const CustomComplex<Real>& x) {
return Eigen::internal::complex_sqrt(x);
}
template <typename Real>
Real abs(const CustomComplex<Real>& x) {
return Eigen::numext::sqrt(x.re * x.re + x.im * x.im);
}
} // namespace custom_complex
template <typename Real>
using CustomComplex = custom_complex::CustomComplex<Real>;
namespace Eigen {
template <typename Real>
struct NumTraits<CustomComplex<Real>> : NumTraits<Real> {
enum { IsComplex = 1 };
};
namespace numext {
template <typename Real>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isfinite)(const CustomComplex<Real>& x) {
return (numext::isfinite)(x.re) && (numext::isfinite)(x.im);
}
} // namespace numext
} // namespace Eigen
#endif // EIGEN_TEST_CUSTOM_COMPLEX_H

View File

@ -12,6 +12,7 @@
#include <limits>
#include <Eigen/Eigenvalues>
#include <Eigen/LU>
#include "CustomComplex.h"
template <typename MatrixType>
bool find_pivot(typename MatrixType::Scalar tol, MatrixType& diffs, Index col = 0) {
@ -165,5 +166,8 @@ EIGEN_DECLARE_TEST(eigensolver_complex) {
// Test problem size constructors
CALL_SUBTEST_5(ComplexEigenSolver<MatrixXf> tmp(s));
// Test custom complex scalar type.
CALL_SUBTEST_6(eigensolver(Matrix<CustomComplex<double>, 5, 5>()));
TEST_SET_BUT_UNUSED_VARIABLE(s)
}

View File

@ -15,7 +15,7 @@
namespace Eigen {
template <bool NeedUprade>
template <bool IsReal>
struct MakeComplex {
template <typename T>
EIGEN_DEVICE_FUNC T operator()(const T& val) const {
@ -26,16 +26,8 @@ struct MakeComplex {
template <>
struct MakeComplex<true> {
template <typename T>
EIGEN_DEVICE_FUNC std::complex<T> operator()(const T& val) const {
return std::complex<T>(val, 0);
}
};
template <>
struct MakeComplex<false> {
template <typename T>
EIGEN_DEVICE_FUNC std::complex<T> operator()(const std::complex<T>& val) const {
return val;
EIGEN_DEVICE_FUNC internal::make_complex_t<T> operator()(const T& val) const {
return internal::make_complex_t<T>(val, T(0));
}
};
@ -49,17 +41,17 @@ struct PartOf {
template <>
struct PartOf<RealPart> {
template <typename T>
T operator()(const std::complex<T>& val) const {
return val.real();
template <typename T, typename EnableIf = std::enable_if_t<NumTraits<T>::IsComplex>>
typename NumTraits<T>::Real operator()(const T& val) const {
return Eigen::numext::real(val);
}
};
template <>
struct PartOf<ImagPart> {
template <typename T>
T operator()(const std::complex<T>& val) const {
return val.imag();
template <typename T, typename EnableIf = std::enable_if_t<NumTraits<T>::IsComplex>>
typename NumTraits<T>::Real operator()(const T& val) const {
return Eigen::numext::imag(val);
}
};
@ -67,8 +59,9 @@ namespace internal {
template <typename FFT, typename XprType, int FFTResultType, int FFTDir>
struct traits<TensorFFTOp<FFT, XprType, FFTResultType, FFTDir> > : public traits<XprType> {
typedef traits<XprType> XprTraits;
typedef typename NumTraits<typename XprTraits::Scalar>::Real RealScalar;
typedef typename std::complex<RealScalar> ComplexScalar;
typedef typename XprTraits::Scalar Scalar;
typedef typename NumTraits<Scalar>::Real RealScalar;
typedef make_complex_t<Scalar> ComplexScalar;
typedef typename XprTraits::Scalar InputScalar;
typedef std::conditional_t<FFTResultType == RealPart || FFTResultType == ImagPart, RealScalar, ComplexScalar>
OutputScalar;
@ -109,7 +102,7 @@ class TensorFFTOp : public TensorBase<TensorFFTOp<FFT, XprType, FFTResultType, F
public:
typedef typename Eigen::internal::traits<TensorFFTOp>::Scalar Scalar;
typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
typedef typename std::complex<RealScalar> ComplexScalar;
typedef internal::make_complex_t<Scalar> ComplexScalar;
typedef std::conditional_t<FFTResultType == RealPart || FFTResultType == ImagPart, RealScalar, ComplexScalar>
OutputScalar;
typedef OutputScalar CoeffReturnType;
@ -137,7 +130,7 @@ struct TensorEvaluator<const TensorFFTOp<FFT, ArgType, FFTResultType, FFTDir>, D
typedef DSizes<Index, NumDims> Dimensions;
typedef typename XprType::Scalar Scalar;
typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
typedef typename std::complex<RealScalar> ComplexScalar;
typedef internal::make_complex_t<Scalar> ComplexScalar;
typedef typename TensorEvaluator<ArgType, Device>::Dimensions InputDimensions;
typedef internal::traits<XprType> XprTraits;
typedef typename XprTraits::Scalar InputScalar;

View File

@ -111,12 +111,13 @@ class DGMRES : public IterativeSolverBase<DGMRES<MatrixType_, Preconditioner_> >
typedef typename MatrixType::Scalar Scalar;
typedef typename MatrixType::StorageIndex StorageIndex;
typedef typename MatrixType::RealScalar RealScalar;
typedef internal::make_complex_t<Scalar> ComplexScalar;
typedef Preconditioner_ Preconditioner;
typedef Matrix<Scalar, Dynamic, Dynamic> DenseMatrix;
typedef Matrix<RealScalar, Dynamic, Dynamic> DenseRealMatrix;
typedef Matrix<Scalar, Dynamic, 1> DenseVector;
typedef Matrix<RealScalar, Dynamic, 1> DenseRealVector;
typedef Matrix<std::complex<RealScalar>, Dynamic, 1> ComplexVector;
typedef Matrix<ComplexScalar, Dynamic, 1> ComplexVector;
/** Default constructor. */
DGMRES()
@ -389,15 +390,15 @@ inline typename DGMRES<MatrixType_, Preconditioner_>::ComplexVector DGMRES<Matri
Index j = 0;
while (j < it - 1) {
if (T(j + 1, j) == Scalar(0)) {
eig(j) = std::complex<RealScalar>(T(j, j), RealScalar(0));
eig(j) = ComplexScalar(T(j, j), RealScalar(0));
j++;
} else {
eig(j) = std::complex<RealScalar>(T(j, j), T(j + 1, j));
eig(j + 1) = std::complex<RealScalar>(T(j, j + 1), T(j + 1, j + 1));
eig(j) = ComplexScalar(T(j, j), T(j + 1, j));
eig(j + 1) = ComplexScalar(T(j, j + 1), T(j + 1, j + 1));
j++;
}
}
if (j < it - 1) eig(j) = std::complex<RealScalar>(T(j, j), RealScalar(0));
if (j < it - 1) eig(j) = ComplexScalar(T(j, j), RealScalar(0));
return eig;
}

View File

@ -23,8 +23,10 @@ namespace internal {
*
* This struct is used by CwiseUnaryOp to scale a matrix by \f$ 2^{-s} \f$.
*/
template <typename RealScalar>
template <typename Scalar, bool IsComplex = NumTraits<Scalar>::IsComplex>
struct MatrixExponentialScalingOp {
using RealScalar = typename NumTraits<Scalar>::Real;
/** \brief Constructor.
*
* \param[in] squarings The integer \f$ s \f$ in this document.
@ -35,20 +37,30 @@ struct MatrixExponentialScalingOp {
*
* \param[in,out] x The scalar to be scaled, becoming \f$ 2^{-s} x \f$.
*/
inline const RealScalar operator()(const RealScalar& x) const {
inline const Scalar operator()(const Scalar& x) const {
using std::ldexp;
return ldexp(x, -m_squarings);
return Scalar(ldexp(Eigen::numext::real(x), -m_squarings), ldexp(Eigen::numext::imag(x), -m_squarings));
}
typedef std::complex<RealScalar> ComplexScalar;
private:
int m_squarings;
};
template <typename Scalar>
struct MatrixExponentialScalingOp<Scalar, /*IsComplex=*/false> {
/** \brief Constructor.
*
* \param[in] squarings The integer \f$ s \f$ in this document.
*/
MatrixExponentialScalingOp(int squarings) : m_squarings(squarings) {}
/** \brief Scale a matrix coefficient.
*
* \param[in,out] x The scalar to be scaled, becoming \f$ 2^{-s} x \f$.
*/
inline const ComplexScalar operator()(const ComplexScalar& x) const {
inline const Scalar operator()(const Scalar& x) const {
using std::ldexp;
return ComplexScalar(ldexp(x.real(), -m_squarings), ldexp(x.imag(), -m_squarings));
return ldexp(x, -m_squarings);
}
private:
@ -220,6 +232,7 @@ struct matrix_exp_computeUV {
template <typename MatrixType>
struct matrix_exp_computeUV<MatrixType, float> {
using Scalar = typename traits<MatrixType>::Scalar;
template <typename ArgType>
static void run(const ArgType& arg, MatrixType& U, MatrixType& V, int& squarings) {
using std::frexp;
@ -234,7 +247,7 @@ struct matrix_exp_computeUV<MatrixType, float> {
const float maxnorm = 3.925724783138660f;
frexp(l1norm / maxnorm, &squarings);
if (squarings < 0) squarings = 0;
MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<float>(squarings));
MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<Scalar>(squarings));
matrix_exp_pade7(A, U, V);
}
}
@ -242,12 +255,12 @@ struct matrix_exp_computeUV<MatrixType, float> {
template <typename MatrixType>
struct matrix_exp_computeUV<MatrixType, double> {
typedef typename NumTraits<typename traits<MatrixType>::Scalar>::Real RealScalar;
using Scalar = typename traits<MatrixType>::Scalar;
template <typename ArgType>
static void run(const ArgType& arg, MatrixType& U, MatrixType& V, int& squarings) {
using std::frexp;
using std::pow;
const RealScalar l1norm = arg.cwiseAbs().colwise().sum().maxCoeff();
const double l1norm = arg.cwiseAbs().colwise().sum().maxCoeff();
squarings = 0;
if (l1norm < 1.495585217958292e-002) {
matrix_exp_pade3(arg, U, V);
@ -258,10 +271,10 @@ struct matrix_exp_computeUV<MatrixType, double> {
} else if (l1norm < 2.097847961257068e+000) {
matrix_exp_pade9(arg, U, V);
} else {
const RealScalar maxnorm = 5.371920351148152;
const double maxnorm = 5.371920351148152;
frexp(l1norm / maxnorm, &squarings);
if (squarings < 0) squarings = 0;
MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<RealScalar>(squarings));
MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<Scalar>(squarings));
matrix_exp_pade13(A, U, V);
}
}
@ -271,6 +284,7 @@ template <typename MatrixType>
struct matrix_exp_computeUV<MatrixType, long double> {
template <typename ArgType>
static void run(const ArgType& arg, MatrixType& U, MatrixType& V, int& squarings) {
using Scalar = typename traits<MatrixType>::Scalar;
#if LDBL_MANT_DIG == 53 // double precision
matrix_exp_computeUV<MatrixType, double>::run(arg, U, V, squarings);
@ -295,7 +309,7 @@ struct matrix_exp_computeUV<MatrixType, long double> {
const long double maxnorm = 4.0246098906697353063L;
frexp(l1norm / maxnorm, &squarings);
if (squarings < 0) squarings = 0;
MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<long double>(squarings));
MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<Scalar>(squarings));
matrix_exp_pade13(A, U, V);
}
@ -315,7 +329,7 @@ struct matrix_exp_computeUV<MatrixType, long double> {
const long double maxnorm = 3.2579440895405400856599663723517L;
frexp(l1norm / maxnorm, &squarings);
if (squarings < 0) squarings = 0;
MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<long double>(squarings));
MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<Scalar>(squarings));
matrix_exp_pade17(A, U, V);
}
@ -335,7 +349,7 @@ struct matrix_exp_computeUV<MatrixType, long double> {
const long double maxnorm = 2.884233277829519311757165057717815L;
frexp(l1norm / maxnorm, &squarings);
if (squarings < 0) squarings = 0;
MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<long double>(squarings));
MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<Scalar>(squarings));
matrix_exp_pade17(A, U, V);
}
@ -382,9 +396,7 @@ template <typename ArgType, typename ResultType>
void matrix_exp_compute(const ArgType& arg, ResultType& result, false_type) // default
{
typedef typename ArgType::PlainObject MatrixType;
typedef typename traits<MatrixType>::Scalar Scalar;
typedef typename NumTraits<Scalar>::Real RealScalar;
typedef typename std::complex<RealScalar> ComplexScalar;
typedef make_complex_t<typename traits<MatrixType>::Scalar> ComplexScalar;
result = arg.matrixFunction(internal::stem_function_exp<ComplexScalar>);
}

View File

@ -382,7 +382,7 @@ struct matrix_function_compute<MatrixType, 0> {
static const int Rows = Traits::RowsAtCompileTime, Cols = Traits::ColsAtCompileTime;
static const int MaxRows = Traits::MaxRowsAtCompileTime, MaxCols = Traits::MaxColsAtCompileTime;
typedef std::complex<Scalar> ComplexScalar;
typedef internal::make_complex_t<Scalar> ComplexScalar;
typedef Matrix<ComplexScalar, Rows, Cols, 0, MaxRows, MaxCols> ComplexMatrix;
ComplexMatrix CA = A.template cast<ComplexScalar>();
@ -476,7 +476,7 @@ class MatrixFunctionReturnValue : public ReturnByValue<MatrixFunctionReturnValue
typedef typename internal::nested_eval<Derived, 10>::type NestedEvalType;
typedef internal::remove_all_t<NestedEvalType> NestedEvalTypeClean;
typedef internal::traits<NestedEvalTypeClean> Traits;
typedef std::complex<typename NumTraits<Scalar>::Real> ComplexScalar;
typedef internal::make_complex_t<Scalar> ComplexScalar;
typedef Matrix<ComplexScalar, Dynamic, Dynamic, 0, Traits::RowsAtCompileTime, Traits::ColsAtCompileTime>
DynMatrixType;

View File

@ -330,7 +330,7 @@ class MatrixLogarithmReturnValue : public ReturnByValue<MatrixLogarithmReturnVal
typedef typename internal::nested_eval<Derived, 10>::type DerivedEvalType;
typedef internal::remove_all_t<DerivedEvalType> DerivedEvalTypeClean;
typedef internal::traits<DerivedEvalTypeClean> Traits;
typedef std::complex<typename NumTraits<Scalar>::Real> ComplexScalar;
typedef internal::make_complex_t<Scalar> ComplexScalar;
typedef Matrix<ComplexScalar, Dynamic, Dynamic, 0, Traits::RowsAtCompileTime, Traits::ColsAtCompileTime>
DynMatrixType;
typedef internal::MatrixLogarithmAtomic<DynMatrixType> AtomicType;

View File

@ -91,7 +91,7 @@ class MatrixPowerAtomic : internal::noncopyable {
enum { RowsAtCompileTime = MatrixType::RowsAtCompileTime, MaxRowsAtCompileTime = MatrixType::MaxRowsAtCompileTime };
typedef typename MatrixType::Scalar Scalar;
typedef typename MatrixType::RealScalar RealScalar;
typedef std::complex<RealScalar> ComplexScalar;
typedef internal::make_complex_t<Scalar> ComplexScalar;
typedef Block<MatrixType, Dynamic, Dynamic> ResultType;
const MatrixType& m_A;
@ -380,7 +380,7 @@ class MatrixPower : internal::noncopyable {
Index cols() const { return m_A.cols(); }
private:
typedef std::complex<RealScalar> ComplexScalar;
typedef internal::make_complex_t<Scalar> ComplexScalar;
typedef Matrix<ComplexScalar, Dynamic, Dynamic, 0, MatrixType::RowsAtCompileTime, MatrixType::ColsAtCompileTime>
ComplexMatrix;
@ -628,7 +628,7 @@ template <typename Derived>
class MatrixComplexPowerReturnValue : public ReturnByValue<MatrixComplexPowerReturnValue<Derived> > {
public:
typedef typename Derived::PlainObject PlainObject;
typedef typename std::complex<typename Derived::RealScalar> ComplexScalar;
typedef internal::make_complex_t<typename Derived::Scalar> ComplexScalar;
/**
* \brief Constructor.
@ -685,7 +685,7 @@ const MatrixPowerReturnValue<Derived> MatrixBase<Derived>::pow(const RealScalar&
}
template <typename Derived>
const MatrixComplexPowerReturnValue<Derived> MatrixBase<Derived>::pow(const std::complex<RealScalar>& p) const {
const MatrixComplexPowerReturnValue<Derived> MatrixBase<Derived>::pow(const internal::make_complex_t<Scalar>& p) const {
return MatrixComplexPowerReturnValue<Derived>(derived(), p);
}

View File

@ -35,7 +35,7 @@ class PolynomialSolverBase {
typedef Scalar_ Scalar;
typedef typename NumTraits<Scalar>::Real RealScalar;
typedef std::complex<RealScalar> RootType;
typedef internal::make_complex_t<Scalar> RootType;
typedef Matrix<RootType, Deg_, 1> RootsType;
typedef DenseIndex Index;
@ -308,7 +308,7 @@ class PolynomialSolver : public PolynomialSolverBase<Scalar_, Deg_> {
typedef std::conditional_t<NumTraits<Scalar>::IsComplex, ComplexEigenSolver<CompanionMatrixType>,
EigenSolver<CompanionMatrixType> >
EigenSolverType;
typedef std::conditional_t<NumTraits<Scalar>::IsComplex, Scalar, std::complex<Scalar> > ComplexScalar;
typedef internal::make_complex_t<Scalar_> ComplexScalar;
public:
/** Computes the complex roots of a new polynomial. */