mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-11 03:09:01 +08:00
Remove assumption of std::complex for complex scalar types.
This commit is contained in:
parent
6b4881ad48
commit
22cd7307dd
@ -182,6 +182,35 @@ struct imag_ref_retval {
|
||||
typedef typename NumTraits<Scalar>::Real& type;
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
namespace numext {
|
||||
|
||||
template <typename Scalar>
|
||||
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(real, Scalar) real(const Scalar& x) {
|
||||
return EIGEN_MATHFUNC_IMPL(real, Scalar)::run(x);
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
EIGEN_DEVICE_FUNC inline internal::add_const_on_value_type_t<EIGEN_MATHFUNC_RETVAL(real_ref, Scalar)> real_ref(
|
||||
const Scalar& x) {
|
||||
return internal::real_ref_impl<Scalar>::run(x);
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(real_ref, Scalar) real_ref(Scalar& x) {
|
||||
return EIGEN_MATHFUNC_IMPL(real_ref, Scalar)::run(x);
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(imag, Scalar) imag(const Scalar& x) {
|
||||
return EIGEN_MATHFUNC_IMPL(imag, Scalar)::run(x);
|
||||
}
|
||||
|
||||
} // namespace numext
|
||||
|
||||
namespace internal {
|
||||
|
||||
/****************************************************************************
|
||||
* Implementation of conj *
|
||||
****************************************************************************/
|
||||
@ -221,7 +250,9 @@ template <typename Scalar>
|
||||
struct abs2_impl_default<Scalar, true> // IsComplex
|
||||
{
|
||||
typedef typename NumTraits<Scalar>::Real RealScalar;
|
||||
EIGEN_DEVICE_FUNC static inline RealScalar run(const Scalar& x) { return x.real() * x.real() + x.imag() * x.imag(); }
|
||||
EIGEN_DEVICE_FUNC static inline RealScalar run(const Scalar& x) {
|
||||
return numext::real(x) * numext::real(x) + numext::imag(x) * numext::imag(x);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Scalar>
|
||||
@ -250,16 +281,14 @@ struct sqrt_impl {
|
||||
};
|
||||
|
||||
// Complex sqrt defined in MathFunctionsImpl.h.
|
||||
template <typename T>
|
||||
EIGEN_DEVICE_FUNC std::complex<T> complex_sqrt(const std::complex<T>& a_x);
|
||||
template <typename ComplexT>
|
||||
EIGEN_DEVICE_FUNC ComplexT complex_sqrt(const ComplexT& a_x);
|
||||
|
||||
// Custom implementation is faster than `std::sqrt`, works on
|
||||
// GPU, and correctly handles special cases (unlike MSVC).
|
||||
template <typename T>
|
||||
struct sqrt_impl<std::complex<T>> {
|
||||
EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE std::complex<T> run(const std::complex<T>& x) {
|
||||
return complex_sqrt<T>(x);
|
||||
}
|
||||
EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE std::complex<T> run(const std::complex<T>& x) { return complex_sqrt(x); }
|
||||
};
|
||||
|
||||
template <typename Scalar>
|
||||
@ -272,13 +301,13 @@ template <typename T>
|
||||
struct rsqrt_impl;
|
||||
|
||||
// Complex rsqrt defined in MathFunctionsImpl.h.
|
||||
template <typename T>
|
||||
EIGEN_DEVICE_FUNC std::complex<T> complex_rsqrt(const std::complex<T>& a_x);
|
||||
template <typename ComplexT>
|
||||
EIGEN_DEVICE_FUNC ComplexT complex_rsqrt(const ComplexT& a_x);
|
||||
|
||||
template <typename T>
|
||||
struct rsqrt_impl<std::complex<T>> {
|
||||
EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE std::complex<T> run(const std::complex<T>& x) {
|
||||
return complex_rsqrt<T>(x);
|
||||
return complex_rsqrt(x);
|
||||
}
|
||||
};
|
||||
|
||||
@ -299,7 +328,7 @@ struct norm1_default_impl<Scalar, true> {
|
||||
typedef typename NumTraits<Scalar>::Real RealScalar;
|
||||
EIGEN_DEVICE_FUNC static inline RealScalar run(const Scalar& x) {
|
||||
EIGEN_USING_STD(abs);
|
||||
return abs(x.real()) + abs(x.imag());
|
||||
return abs(numext::real(x)) + abs(numext::imag(x));
|
||||
}
|
||||
};
|
||||
|
||||
@ -469,8 +498,8 @@ struct expm1_retval {
|
||||
****************************************************************************/
|
||||
|
||||
// Complex log defined in MathFunctionsImpl.h.
|
||||
template <typename T>
|
||||
EIGEN_DEVICE_FUNC std::complex<T> complex_log(const std::complex<T>& z);
|
||||
template <typename ComplexT>
|
||||
EIGEN_DEVICE_FUNC ComplexT complex_log(const ComplexT& z);
|
||||
|
||||
template <typename Scalar>
|
||||
struct log_impl {
|
||||
@ -846,7 +875,7 @@ struct sign_impl<Scalar, true, IsInteger> {
|
||||
real_type aa = abs(a);
|
||||
if (aa == real_type(0)) return Scalar(0);
|
||||
aa = real_type(1) / aa;
|
||||
return Scalar(a.real() * aa, a.imag() * aa);
|
||||
return Scalar(numext::real(a) * aa, numext::imag(a) * aa);
|
||||
}
|
||||
};
|
||||
|
||||
@ -1042,27 +1071,6 @@ SYCL_SPECIALIZE_FLOATING_TYPES_BINARY(maxi, fmax)
|
||||
|
||||
#endif
|
||||
|
||||
template <typename Scalar>
|
||||
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(real, Scalar) real(const Scalar& x) {
|
||||
return EIGEN_MATHFUNC_IMPL(real, Scalar)::run(x);
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
EIGEN_DEVICE_FUNC inline internal::add_const_on_value_type_t<EIGEN_MATHFUNC_RETVAL(real_ref, Scalar)> real_ref(
|
||||
const Scalar& x) {
|
||||
return internal::real_ref_impl<Scalar>::run(x);
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(real_ref, Scalar) real_ref(Scalar& x) {
|
||||
return EIGEN_MATHFUNC_IMPL(real_ref, Scalar)::run(x);
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(imag, Scalar) imag(const Scalar& x) {
|
||||
return EIGEN_MATHFUNC_IMPL(imag, Scalar)::run(x);
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(arg, Scalar) arg(const Scalar& x) {
|
||||
return EIGEN_MATHFUNC_IMPL(arg, Scalar)::run(x);
|
||||
|
@ -171,8 +171,8 @@ struct hypot_impl {
|
||||
|
||||
// Generic complex sqrt implementation that correctly handles corner cases
|
||||
// according to https://en.cppreference.com/w/cpp/numeric/complex/sqrt
|
||||
template <typename T>
|
||||
EIGEN_DEVICE_FUNC std::complex<T> complex_sqrt(const std::complex<T>& z) {
|
||||
template <typename ComplexT>
|
||||
EIGEN_DEVICE_FUNC ComplexT complex_sqrt(const ComplexT& z) {
|
||||
// Computes the principal sqrt of the input.
|
||||
//
|
||||
// For a complex square root of the number x + i*y. We want to find real
|
||||
@ -194,21 +194,21 @@ EIGEN_DEVICE_FUNC std::complex<T> complex_sqrt(const std::complex<T>& z) {
|
||||
// if x == 0: u = w, v = sign(y) * w
|
||||
// if x > 0: u = w, v = y / (2 * w)
|
||||
// if x < 0: u = |y| / (2 * w), v = sign(y) * w
|
||||
|
||||
using T = typename NumTraits<ComplexT>::Real;
|
||||
const T x = numext::real(z);
|
||||
const T y = numext::imag(z);
|
||||
const T zero = T(0);
|
||||
const T w = numext::sqrt(T(0.5) * (numext::abs(x) + numext::hypot(x, y)));
|
||||
|
||||
return (numext::isinf)(y) ? std::complex<T>(NumTraits<T>::infinity(), y)
|
||||
: numext::is_exactly_zero(x) ? std::complex<T>(w, y < zero ? -w : w)
|
||||
: x > zero ? std::complex<T>(w, y / (2 * w))
|
||||
: std::complex<T>(numext::abs(y) / (2 * w), y < zero ? -w : w);
|
||||
return (numext::isinf)(y) ? ComplexT(NumTraits<T>::infinity(), y)
|
||||
: numext::is_exactly_zero(x) ? ComplexT(w, y < zero ? -w : w)
|
||||
: x > zero ? ComplexT(w, y / (2 * w))
|
||||
: ComplexT(numext::abs(y) / (2 * w), y < zero ? -w : w);
|
||||
}
|
||||
|
||||
// Generic complex rsqrt implementation.
|
||||
template <typename T>
|
||||
EIGEN_DEVICE_FUNC std::complex<T> complex_rsqrt(const std::complex<T>& z) {
|
||||
template <typename ComplexT>
|
||||
EIGEN_DEVICE_FUNC ComplexT complex_rsqrt(const ComplexT& z) {
|
||||
// Computes the principal reciprocal sqrt of the input.
|
||||
//
|
||||
// For a complex reciprocal square root of the number z = x + i*y. We want to
|
||||
@ -230,7 +230,7 @@ EIGEN_DEVICE_FUNC std::complex<T> complex_rsqrt(const std::complex<T>& z) {
|
||||
// if x == 0: u = w / |z|, v = -sign(y) * w / |z|
|
||||
// if x > 0: u = w / |z|, v = -y / (2 * w * |z|)
|
||||
// if x < 0: u = |y| / (2 * w * |z|), v = -sign(y) * w / |z|
|
||||
|
||||
using T = typename NumTraits<ComplexT>::Real;
|
||||
const T x = numext::real(z);
|
||||
const T y = numext::imag(z);
|
||||
const T zero = T(0);
|
||||
@ -239,20 +239,21 @@ EIGEN_DEVICE_FUNC std::complex<T> complex_rsqrt(const std::complex<T>& z) {
|
||||
const T w = numext::sqrt(T(0.5) * (numext::abs(x) + abs_z));
|
||||
const T woz = w / abs_z;
|
||||
// Corner cases consistent with 1/sqrt(z) on gcc/clang.
|
||||
return numext::is_exactly_zero(abs_z) ? std::complex<T>(NumTraits<T>::infinity(), NumTraits<T>::quiet_NaN())
|
||||
: ((numext::isinf)(x) || (numext::isinf)(y)) ? std::complex<T>(zero, zero)
|
||||
: numext::is_exactly_zero(x) ? std::complex<T>(woz, y < zero ? woz : -woz)
|
||||
: x > zero ? std::complex<T>(woz, -y / (2 * w * abs_z))
|
||||
: std::complex<T>(numext::abs(y) / (2 * w * abs_z), y < zero ? woz : -woz);
|
||||
return numext::is_exactly_zero(abs_z) ? ComplexT(NumTraits<T>::infinity(), NumTraits<T>::quiet_NaN())
|
||||
: ((numext::isinf)(x) || (numext::isinf)(y)) ? ComplexT(zero, zero)
|
||||
: numext::is_exactly_zero(x) ? ComplexT(woz, y < zero ? woz : -woz)
|
||||
: x > zero ? ComplexT(woz, -y / (2 * w * abs_z))
|
||||
: ComplexT(numext::abs(y) / (2 * w * abs_z), y < zero ? woz : -woz);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
EIGEN_DEVICE_FUNC std::complex<T> complex_log(const std::complex<T>& z) {
|
||||
template <typename ComplexT>
|
||||
EIGEN_DEVICE_FUNC ComplexT complex_log(const ComplexT& z) {
|
||||
// Computes complex log.
|
||||
using T = typename NumTraits<ComplexT>::Real;
|
||||
T a = numext::abs(z);
|
||||
EIGEN_USING_STD(atan2);
|
||||
T b = atan2(z.imag(), z.real());
|
||||
return std::complex<T>(numext::log(a), b);
|
||||
return ComplexT(numext::log(a), b);
|
||||
}
|
||||
|
||||
} // end namespace internal
|
||||
|
@ -112,7 +112,7 @@ class MatrixBase : public DenseBase<Derived> {
|
||||
ConstTransposeReturnType>
|
||||
AdjointReturnType;
|
||||
/** \internal Return type of eigenvalues() */
|
||||
typedef Matrix<std::complex<RealScalar>, internal::traits<Derived>::ColsAtCompileTime, 1, ColMajor>
|
||||
typedef Matrix<internal::make_complex_t<Scalar>, internal::traits<Derived>::ColsAtCompileTime, 1, ColMajor>
|
||||
EigenvaluesReturnType;
|
||||
/** \internal the return type of identity */
|
||||
typedef CwiseNullaryOp<internal::scalar_identity_op<Scalar>, PlainObject> IdentityReturnType;
|
||||
@ -468,7 +468,7 @@ class MatrixBase : public DenseBase<Derived> {
|
||||
EIGEN_MATRIX_FUNCTION(MatrixSquareRootReturnValue, sqrt, square root)
|
||||
EIGEN_MATRIX_FUNCTION(MatrixLogarithmReturnValue, log, logarithm)
|
||||
EIGEN_MATRIX_FUNCTION_1(MatrixPowerReturnValue, pow, power to \c p, const RealScalar& p)
|
||||
EIGEN_MATRIX_FUNCTION_1(MatrixComplexPowerReturnValue, pow, power to \c p, const std::complex<RealScalar>& p)
|
||||
EIGEN_MATRIX_FUNCTION_1(MatrixComplexPowerReturnValue, pow, power to \c p, const internal::make_complex_t<Scalar>& p)
|
||||
|
||||
protected:
|
||||
EIGEN_DEFAULT_COPY_CONSTRUCTOR(MatrixBase)
|
||||
|
@ -497,7 +497,7 @@ class MatrixComplexPowerReturnValue;
|
||||
namespace internal {
|
||||
template <typename Scalar>
|
||||
struct stem_function {
|
||||
typedef std::complex<typename NumTraits<Scalar>::Real> ComplexScalar;
|
||||
typedef internal::make_complex_t<Scalar> ComplexScalar;
|
||||
typedef ComplexScalar type(ComplexScalar, int);
|
||||
};
|
||||
} // namespace internal
|
||||
|
@ -745,6 +745,9 @@ using std::is_constant_evaluated;
|
||||
constexpr bool is_constant_evaluated() { return false; }
|
||||
#endif
|
||||
|
||||
template <typename Scalar>
|
||||
using make_complex_t = std::conditional_t<NumTraits<Scalar>::IsComplex, Scalar, std::complex<Scalar>>;
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
} // end namespace Eigen
|
||||
|
@ -885,8 +885,12 @@ struct scalar_div_cost {
|
||||
};
|
||||
|
||||
template <typename T, bool Vectorized>
|
||||
struct scalar_div_cost<std::complex<T>, Vectorized> {
|
||||
enum { value = 2 * scalar_div_cost<T>::value + 6 * NumTraits<T>::MulCost + 3 * NumTraits<T>::AddCost };
|
||||
struct scalar_div_cost<T, Vectorized, std::enable_if_t<NumTraits<T>::IsComplex>> {
|
||||
using RealScalar = typename NumTraits<T>::Real;
|
||||
enum {
|
||||
value =
|
||||
2 * scalar_div_cost<RealScalar>::value + 6 * NumTraits<RealScalar>::MulCost + 3 * NumTraits<RealScalar>::AddCost
|
||||
};
|
||||
};
|
||||
|
||||
template <bool Vectorized>
|
||||
|
@ -70,7 +70,7 @@ class ComplexEigenSolver {
|
||||
* \c float or \c double) and just \c Scalar if #Scalar is
|
||||
* complex.
|
||||
*/
|
||||
typedef std::complex<RealScalar> ComplexScalar;
|
||||
typedef internal::make_complex_t<Scalar> ComplexScalar;
|
||||
|
||||
/** \brief Type for vector of eigenvalues as returned by eigenvalues().
|
||||
*
|
||||
|
@ -75,7 +75,7 @@ class ComplexSchur {
|
||||
* \c float or \c double) and just \c Scalar if #Scalar is
|
||||
* complex.
|
||||
*/
|
||||
typedef std::complex<RealScalar> ComplexScalar;
|
||||
typedef internal::make_complex_t<Scalar> ComplexScalar;
|
||||
|
||||
/** \brief Type for the matrices in the Schur decomposition.
|
||||
*
|
||||
|
@ -89,7 +89,7 @@ class EigenSolver {
|
||||
* \c float or \c double) and just \c Scalar if #Scalar is
|
||||
* complex.
|
||||
*/
|
||||
typedef std::complex<RealScalar> ComplexScalar;
|
||||
typedef internal::make_complex_t<Scalar> ComplexScalar;
|
||||
|
||||
/** \brief Type for vector of eigenvalues as returned by eigenvalues().
|
||||
*
|
||||
|
@ -83,7 +83,7 @@ class GeneralizedEigenSolver {
|
||||
* \c float or \c double) and just \c Scalar if #Scalar is
|
||||
* complex.
|
||||
*/
|
||||
typedef std::complex<RealScalar> ComplexScalar;
|
||||
typedef internal::make_complex_t<Scalar> ComplexScalar;
|
||||
|
||||
/** \brief Type for vector of real scalar values eigenvalues as returned by betas().
|
||||
*
|
||||
|
@ -69,7 +69,7 @@ class RealQZ {
|
||||
MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime
|
||||
};
|
||||
typedef typename MatrixType::Scalar Scalar;
|
||||
typedef std::complex<typename NumTraits<Scalar>::Real> ComplexScalar;
|
||||
typedef internal::make_complex_t<Scalar> ComplexScalar;
|
||||
typedef Eigen::Index Index; ///< \deprecated since Eigen 3.3
|
||||
|
||||
typedef Matrix<ComplexScalar, ColsAtCompileTime, 1, Options & ~RowMajor, MaxColsAtCompileTime, 1> EigenvalueType;
|
||||
|
@ -66,7 +66,7 @@ class RealSchur {
|
||||
MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime
|
||||
};
|
||||
typedef typename MatrixType::Scalar Scalar;
|
||||
typedef std::complex<typename NumTraits<Scalar>::Real> ComplexScalar;
|
||||
typedef internal::make_complex_t<Scalar> ComplexScalar;
|
||||
typedef Eigen::Index Index; ///< \deprecated since Eigen 3.3
|
||||
|
||||
typedef Matrix<ComplexScalar, ColsAtCompileTime, 1, Options & ~RowMajor, MaxColsAtCompileTime, 1> EigenvalueType;
|
||||
|
132
test/CustomComplex.h
Normal file
132
test/CustomComplex.h
Normal file
@ -0,0 +1,132 @@
|
||||
// This file is part of Eigen, a lightweight C++ template library
|
||||
// for linear algebra.
|
||||
//
|
||||
// Copyright (C) 2025 The Eigen Authors.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
#ifndef EIGEN_TEST_CUSTOM_COMPLEX_H
|
||||
#define EIGEN_TEST_CUSTOM_COMPLEX_H
|
||||
|
||||
#include <ostream>
|
||||
#include <sstream>
|
||||
|
||||
namespace custom_complex {
|
||||
|
||||
template <typename Real>
|
||||
struct CustomComplex {
|
||||
CustomComplex() : re{0}, im{0} {}
|
||||
CustomComplex(const CustomComplex& other) = default;
|
||||
CustomComplex(CustomComplex&& other) = default;
|
||||
CustomComplex& operator=(const CustomComplex& other) = default;
|
||||
CustomComplex& operator=(CustomComplex&& other) = default;
|
||||
CustomComplex(Real x) : re{x}, im{0} {}
|
||||
CustomComplex(Real x, Real y) : re{x}, im{y} {}
|
||||
|
||||
CustomComplex operator+(const CustomComplex& other) const { return CustomComplex(re + other.re, im + other.im); }
|
||||
|
||||
CustomComplex operator-() const { return CustomComplex(-re, -im); }
|
||||
|
||||
CustomComplex operator-(const CustomComplex& other) const { return CustomComplex(re - other.re, im - other.im); }
|
||||
|
||||
CustomComplex operator*(const CustomComplex& other) const {
|
||||
return CustomComplex(re * other.re - im * other.im, re * other.im + im * other.re);
|
||||
}
|
||||
|
||||
CustomComplex operator/(const CustomComplex& other) const {
|
||||
// Smith's complex division (https://arxiv.org/pdf/1210.4539.pdf),
|
||||
// guards against over/under-flow.
|
||||
const bool scale_imag = numext::abs(other.im) <= numext::abs(other.re);
|
||||
const Real rscale = scale_imag ? Real(1) : other.re / other.im;
|
||||
const Real iscale = scale_imag ? other.im / other.re : Real(1);
|
||||
const Real denominator = other.re * rscale + other.im * iscale;
|
||||
return CustomComplex((re * rscale + im * iscale) / denominator, (im * rscale - re * iscale) / denominator);
|
||||
}
|
||||
|
||||
CustomComplex& operator+=(const CustomComplex& other) {
|
||||
*this = *this + other;
|
||||
return *this;
|
||||
}
|
||||
CustomComplex& operator-=(const CustomComplex& other) {
|
||||
*this = *this - other;
|
||||
return *this;
|
||||
}
|
||||
CustomComplex& operator*=(const CustomComplex& other) {
|
||||
*this = *this * other;
|
||||
return *this;
|
||||
}
|
||||
CustomComplex& operator/=(const CustomComplex& other) {
|
||||
*this = *this / other;
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool operator==(const CustomComplex& other) const {
|
||||
return numext::equal_strict(re, other.re) && numext::equal_strict(im, other.im);
|
||||
}
|
||||
bool operator!=(const CustomComplex& other) const { return !(*this == other); }
|
||||
|
||||
friend CustomComplex operator+(const Real& a, const CustomComplex& b) { return CustomComplex(a + b.re, b.im); }
|
||||
|
||||
friend CustomComplex operator-(const Real& a, const CustomComplex& b) { return CustomComplex(a - b.re, -b.im); }
|
||||
|
||||
friend CustomComplex operator*(const Real& a, const CustomComplex& b) { return CustomComplex(a * b.re, a * b.im); }
|
||||
|
||||
friend CustomComplex operator*(const CustomComplex& a, const Real& b) { return CustomComplex(a.re * b, a.im * b); }
|
||||
|
||||
friend CustomComplex operator/(const CustomComplex& a, const Real& b) { return CustomComplex(a.re / b, a.im / b); }
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& stream, const CustomComplex& x) {
|
||||
std::stringstream ss;
|
||||
ss << "(" << x.re << ", " << x.im << ")";
|
||||
stream << ss.str();
|
||||
return stream;
|
||||
}
|
||||
|
||||
Real re;
|
||||
Real im;
|
||||
};
|
||||
|
||||
template <typename Real>
|
||||
Real real(const CustomComplex<Real>& x) {
|
||||
return x.re;
|
||||
}
|
||||
template <typename Real>
|
||||
Real imag(const CustomComplex<Real>& x) {
|
||||
return x.im;
|
||||
}
|
||||
template <typename Real>
|
||||
CustomComplex<Real> conj(const CustomComplex<Real>& x) {
|
||||
return CustomComplex<Real>(x.re, -x.im);
|
||||
}
|
||||
template <typename Real>
|
||||
CustomComplex<Real> sqrt(const CustomComplex<Real>& x) {
|
||||
return Eigen::internal::complex_sqrt(x);
|
||||
}
|
||||
template <typename Real>
|
||||
Real abs(const CustomComplex<Real>& x) {
|
||||
return Eigen::numext::sqrt(x.re * x.re + x.im * x.im);
|
||||
}
|
||||
|
||||
} // namespace custom_complex
|
||||
|
||||
template <typename Real>
|
||||
using CustomComplex = custom_complex::CustomComplex<Real>;
|
||||
|
||||
namespace Eigen {
|
||||
template <typename Real>
|
||||
struct NumTraits<CustomComplex<Real>> : NumTraits<Real> {
|
||||
enum { IsComplex = 1 };
|
||||
};
|
||||
|
||||
namespace numext {
|
||||
template <typename Real>
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isfinite)(const CustomComplex<Real>& x) {
|
||||
return (numext::isfinite)(x.re) && (numext::isfinite)(x.im);
|
||||
}
|
||||
|
||||
} // namespace numext
|
||||
} // namespace Eigen
|
||||
|
||||
#endif // EIGEN_TEST_CUSTOM_COMPLEX_H
|
@ -12,6 +12,7 @@
|
||||
#include <limits>
|
||||
#include <Eigen/Eigenvalues>
|
||||
#include <Eigen/LU>
|
||||
#include "CustomComplex.h"
|
||||
|
||||
template <typename MatrixType>
|
||||
bool find_pivot(typename MatrixType::Scalar tol, MatrixType& diffs, Index col = 0) {
|
||||
@ -165,5 +166,8 @@ EIGEN_DECLARE_TEST(eigensolver_complex) {
|
||||
// Test problem size constructors
|
||||
CALL_SUBTEST_5(ComplexEigenSolver<MatrixXf> tmp(s));
|
||||
|
||||
// Test custom complex scalar type.
|
||||
CALL_SUBTEST_6(eigensolver(Matrix<CustomComplex<double>, 5, 5>()));
|
||||
|
||||
TEST_SET_BUT_UNUSED_VARIABLE(s)
|
||||
}
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
template <bool NeedUprade>
|
||||
template <bool IsReal>
|
||||
struct MakeComplex {
|
||||
template <typename T>
|
||||
EIGEN_DEVICE_FUNC T operator()(const T& val) const {
|
||||
@ -26,16 +26,8 @@ struct MakeComplex {
|
||||
template <>
|
||||
struct MakeComplex<true> {
|
||||
template <typename T>
|
||||
EIGEN_DEVICE_FUNC std::complex<T> operator()(const T& val) const {
|
||||
return std::complex<T>(val, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MakeComplex<false> {
|
||||
template <typename T>
|
||||
EIGEN_DEVICE_FUNC std::complex<T> operator()(const std::complex<T>& val) const {
|
||||
return val;
|
||||
EIGEN_DEVICE_FUNC internal::make_complex_t<T> operator()(const T& val) const {
|
||||
return internal::make_complex_t<T>(val, T(0));
|
||||
}
|
||||
};
|
||||
|
||||
@ -49,17 +41,17 @@ struct PartOf {
|
||||
|
||||
template <>
|
||||
struct PartOf<RealPart> {
|
||||
template <typename T>
|
||||
T operator()(const std::complex<T>& val) const {
|
||||
return val.real();
|
||||
template <typename T, typename EnableIf = std::enable_if_t<NumTraits<T>::IsComplex>>
|
||||
typename NumTraits<T>::Real operator()(const T& val) const {
|
||||
return Eigen::numext::real(val);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PartOf<ImagPart> {
|
||||
template <typename T>
|
||||
T operator()(const std::complex<T>& val) const {
|
||||
return val.imag();
|
||||
template <typename T, typename EnableIf = std::enable_if_t<NumTraits<T>::IsComplex>>
|
||||
typename NumTraits<T>::Real operator()(const T& val) const {
|
||||
return Eigen::numext::imag(val);
|
||||
}
|
||||
};
|
||||
|
||||
@ -67,8 +59,9 @@ namespace internal {
|
||||
template <typename FFT, typename XprType, int FFTResultType, int FFTDir>
|
||||
struct traits<TensorFFTOp<FFT, XprType, FFTResultType, FFTDir> > : public traits<XprType> {
|
||||
typedef traits<XprType> XprTraits;
|
||||
typedef typename NumTraits<typename XprTraits::Scalar>::Real RealScalar;
|
||||
typedef typename std::complex<RealScalar> ComplexScalar;
|
||||
typedef typename XprTraits::Scalar Scalar;
|
||||
typedef typename NumTraits<Scalar>::Real RealScalar;
|
||||
typedef make_complex_t<Scalar> ComplexScalar;
|
||||
typedef typename XprTraits::Scalar InputScalar;
|
||||
typedef std::conditional_t<FFTResultType == RealPart || FFTResultType == ImagPart, RealScalar, ComplexScalar>
|
||||
OutputScalar;
|
||||
@ -109,7 +102,7 @@ class TensorFFTOp : public TensorBase<TensorFFTOp<FFT, XprType, FFTResultType, F
|
||||
public:
|
||||
typedef typename Eigen::internal::traits<TensorFFTOp>::Scalar Scalar;
|
||||
typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
|
||||
typedef typename std::complex<RealScalar> ComplexScalar;
|
||||
typedef internal::make_complex_t<Scalar> ComplexScalar;
|
||||
typedef std::conditional_t<FFTResultType == RealPart || FFTResultType == ImagPart, RealScalar, ComplexScalar>
|
||||
OutputScalar;
|
||||
typedef OutputScalar CoeffReturnType;
|
||||
@ -137,7 +130,7 @@ struct TensorEvaluator<const TensorFFTOp<FFT, ArgType, FFTResultType, FFTDir>, D
|
||||
typedef DSizes<Index, NumDims> Dimensions;
|
||||
typedef typename XprType::Scalar Scalar;
|
||||
typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
|
||||
typedef typename std::complex<RealScalar> ComplexScalar;
|
||||
typedef internal::make_complex_t<Scalar> ComplexScalar;
|
||||
typedef typename TensorEvaluator<ArgType, Device>::Dimensions InputDimensions;
|
||||
typedef internal::traits<XprType> XprTraits;
|
||||
typedef typename XprTraits::Scalar InputScalar;
|
||||
|
@ -111,12 +111,13 @@ class DGMRES : public IterativeSolverBase<DGMRES<MatrixType_, Preconditioner_> >
|
||||
typedef typename MatrixType::Scalar Scalar;
|
||||
typedef typename MatrixType::StorageIndex StorageIndex;
|
||||
typedef typename MatrixType::RealScalar RealScalar;
|
||||
typedef internal::make_complex_t<Scalar> ComplexScalar;
|
||||
typedef Preconditioner_ Preconditioner;
|
||||
typedef Matrix<Scalar, Dynamic, Dynamic> DenseMatrix;
|
||||
typedef Matrix<RealScalar, Dynamic, Dynamic> DenseRealMatrix;
|
||||
typedef Matrix<Scalar, Dynamic, 1> DenseVector;
|
||||
typedef Matrix<RealScalar, Dynamic, 1> DenseRealVector;
|
||||
typedef Matrix<std::complex<RealScalar>, Dynamic, 1> ComplexVector;
|
||||
typedef Matrix<ComplexScalar, Dynamic, 1> ComplexVector;
|
||||
|
||||
/** Default constructor. */
|
||||
DGMRES()
|
||||
@ -389,15 +390,15 @@ inline typename DGMRES<MatrixType_, Preconditioner_>::ComplexVector DGMRES<Matri
|
||||
Index j = 0;
|
||||
while (j < it - 1) {
|
||||
if (T(j + 1, j) == Scalar(0)) {
|
||||
eig(j) = std::complex<RealScalar>(T(j, j), RealScalar(0));
|
||||
eig(j) = ComplexScalar(T(j, j), RealScalar(0));
|
||||
j++;
|
||||
} else {
|
||||
eig(j) = std::complex<RealScalar>(T(j, j), T(j + 1, j));
|
||||
eig(j + 1) = std::complex<RealScalar>(T(j, j + 1), T(j + 1, j + 1));
|
||||
eig(j) = ComplexScalar(T(j, j), T(j + 1, j));
|
||||
eig(j + 1) = ComplexScalar(T(j, j + 1), T(j + 1, j + 1));
|
||||
j++;
|
||||
}
|
||||
}
|
||||
if (j < it - 1) eig(j) = std::complex<RealScalar>(T(j, j), RealScalar(0));
|
||||
if (j < it - 1) eig(j) = ComplexScalar(T(j, j), RealScalar(0));
|
||||
return eig;
|
||||
}
|
||||
|
||||
|
@ -23,8 +23,10 @@ namespace internal {
|
||||
*
|
||||
* This struct is used by CwiseUnaryOp to scale a matrix by \f$ 2^{-s} \f$.
|
||||
*/
|
||||
template <typename RealScalar>
|
||||
template <typename Scalar, bool IsComplex = NumTraits<Scalar>::IsComplex>
|
||||
struct MatrixExponentialScalingOp {
|
||||
using RealScalar = typename NumTraits<Scalar>::Real;
|
||||
|
||||
/** \brief Constructor.
|
||||
*
|
||||
* \param[in] squarings The integer \f$ s \f$ in this document.
|
||||
@ -35,20 +37,30 @@ struct MatrixExponentialScalingOp {
|
||||
*
|
||||
* \param[in,out] x The scalar to be scaled, becoming \f$ 2^{-s} x \f$.
|
||||
*/
|
||||
inline const RealScalar operator()(const RealScalar& x) const {
|
||||
inline const Scalar operator()(const Scalar& x) const {
|
||||
using std::ldexp;
|
||||
return ldexp(x, -m_squarings);
|
||||
return Scalar(ldexp(Eigen::numext::real(x), -m_squarings), ldexp(Eigen::numext::imag(x), -m_squarings));
|
||||
}
|
||||
|
||||
typedef std::complex<RealScalar> ComplexScalar;
|
||||
private:
|
||||
int m_squarings;
|
||||
};
|
||||
|
||||
template <typename Scalar>
|
||||
struct MatrixExponentialScalingOp<Scalar, /*IsComplex=*/false> {
|
||||
/** \brief Constructor.
|
||||
*
|
||||
* \param[in] squarings The integer \f$ s \f$ in this document.
|
||||
*/
|
||||
MatrixExponentialScalingOp(int squarings) : m_squarings(squarings) {}
|
||||
|
||||
/** \brief Scale a matrix coefficient.
|
||||
*
|
||||
* \param[in,out] x The scalar to be scaled, becoming \f$ 2^{-s} x \f$.
|
||||
*/
|
||||
inline const ComplexScalar operator()(const ComplexScalar& x) const {
|
||||
inline const Scalar operator()(const Scalar& x) const {
|
||||
using std::ldexp;
|
||||
return ComplexScalar(ldexp(x.real(), -m_squarings), ldexp(x.imag(), -m_squarings));
|
||||
return ldexp(x, -m_squarings);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -220,6 +232,7 @@ struct matrix_exp_computeUV {
|
||||
|
||||
template <typename MatrixType>
|
||||
struct matrix_exp_computeUV<MatrixType, float> {
|
||||
using Scalar = typename traits<MatrixType>::Scalar;
|
||||
template <typename ArgType>
|
||||
static void run(const ArgType& arg, MatrixType& U, MatrixType& V, int& squarings) {
|
||||
using std::frexp;
|
||||
@ -234,7 +247,7 @@ struct matrix_exp_computeUV<MatrixType, float> {
|
||||
const float maxnorm = 3.925724783138660f;
|
||||
frexp(l1norm / maxnorm, &squarings);
|
||||
if (squarings < 0) squarings = 0;
|
||||
MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<float>(squarings));
|
||||
MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<Scalar>(squarings));
|
||||
matrix_exp_pade7(A, U, V);
|
||||
}
|
||||
}
|
||||
@ -242,12 +255,12 @@ struct matrix_exp_computeUV<MatrixType, float> {
|
||||
|
||||
template <typename MatrixType>
|
||||
struct matrix_exp_computeUV<MatrixType, double> {
|
||||
typedef typename NumTraits<typename traits<MatrixType>::Scalar>::Real RealScalar;
|
||||
using Scalar = typename traits<MatrixType>::Scalar;
|
||||
template <typename ArgType>
|
||||
static void run(const ArgType& arg, MatrixType& U, MatrixType& V, int& squarings) {
|
||||
using std::frexp;
|
||||
using std::pow;
|
||||
const RealScalar l1norm = arg.cwiseAbs().colwise().sum().maxCoeff();
|
||||
const double l1norm = arg.cwiseAbs().colwise().sum().maxCoeff();
|
||||
squarings = 0;
|
||||
if (l1norm < 1.495585217958292e-002) {
|
||||
matrix_exp_pade3(arg, U, V);
|
||||
@ -258,10 +271,10 @@ struct matrix_exp_computeUV<MatrixType, double> {
|
||||
} else if (l1norm < 2.097847961257068e+000) {
|
||||
matrix_exp_pade9(arg, U, V);
|
||||
} else {
|
||||
const RealScalar maxnorm = 5.371920351148152;
|
||||
const double maxnorm = 5.371920351148152;
|
||||
frexp(l1norm / maxnorm, &squarings);
|
||||
if (squarings < 0) squarings = 0;
|
||||
MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<RealScalar>(squarings));
|
||||
MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<Scalar>(squarings));
|
||||
matrix_exp_pade13(A, U, V);
|
||||
}
|
||||
}
|
||||
@ -271,6 +284,7 @@ template <typename MatrixType>
|
||||
struct matrix_exp_computeUV<MatrixType, long double> {
|
||||
template <typename ArgType>
|
||||
static void run(const ArgType& arg, MatrixType& U, MatrixType& V, int& squarings) {
|
||||
using Scalar = typename traits<MatrixType>::Scalar;
|
||||
#if LDBL_MANT_DIG == 53 // double precision
|
||||
matrix_exp_computeUV<MatrixType, double>::run(arg, U, V, squarings);
|
||||
|
||||
@ -295,7 +309,7 @@ struct matrix_exp_computeUV<MatrixType, long double> {
|
||||
const long double maxnorm = 4.0246098906697353063L;
|
||||
frexp(l1norm / maxnorm, &squarings);
|
||||
if (squarings < 0) squarings = 0;
|
||||
MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<long double>(squarings));
|
||||
MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<Scalar>(squarings));
|
||||
matrix_exp_pade13(A, U, V);
|
||||
}
|
||||
|
||||
@ -315,7 +329,7 @@ struct matrix_exp_computeUV<MatrixType, long double> {
|
||||
const long double maxnorm = 3.2579440895405400856599663723517L;
|
||||
frexp(l1norm / maxnorm, &squarings);
|
||||
if (squarings < 0) squarings = 0;
|
||||
MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<long double>(squarings));
|
||||
MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<Scalar>(squarings));
|
||||
matrix_exp_pade17(A, U, V);
|
||||
}
|
||||
|
||||
@ -335,7 +349,7 @@ struct matrix_exp_computeUV<MatrixType, long double> {
|
||||
const long double maxnorm = 2.884233277829519311757165057717815L;
|
||||
frexp(l1norm / maxnorm, &squarings);
|
||||
if (squarings < 0) squarings = 0;
|
||||
MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<long double>(squarings));
|
||||
MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<Scalar>(squarings));
|
||||
matrix_exp_pade17(A, U, V);
|
||||
}
|
||||
|
||||
@ -382,9 +396,7 @@ template <typename ArgType, typename ResultType>
|
||||
void matrix_exp_compute(const ArgType& arg, ResultType& result, false_type) // default
|
||||
{
|
||||
typedef typename ArgType::PlainObject MatrixType;
|
||||
typedef typename traits<MatrixType>::Scalar Scalar;
|
||||
typedef typename NumTraits<Scalar>::Real RealScalar;
|
||||
typedef typename std::complex<RealScalar> ComplexScalar;
|
||||
typedef make_complex_t<typename traits<MatrixType>::Scalar> ComplexScalar;
|
||||
result = arg.matrixFunction(internal::stem_function_exp<ComplexScalar>);
|
||||
}
|
||||
|
||||
|
@ -382,7 +382,7 @@ struct matrix_function_compute<MatrixType, 0> {
|
||||
static const int Rows = Traits::RowsAtCompileTime, Cols = Traits::ColsAtCompileTime;
|
||||
static const int MaxRows = Traits::MaxRowsAtCompileTime, MaxCols = Traits::MaxColsAtCompileTime;
|
||||
|
||||
typedef std::complex<Scalar> ComplexScalar;
|
||||
typedef internal::make_complex_t<Scalar> ComplexScalar;
|
||||
typedef Matrix<ComplexScalar, Rows, Cols, 0, MaxRows, MaxCols> ComplexMatrix;
|
||||
|
||||
ComplexMatrix CA = A.template cast<ComplexScalar>();
|
||||
@ -476,7 +476,7 @@ class MatrixFunctionReturnValue : public ReturnByValue<MatrixFunctionReturnValue
|
||||
typedef typename internal::nested_eval<Derived, 10>::type NestedEvalType;
|
||||
typedef internal::remove_all_t<NestedEvalType> NestedEvalTypeClean;
|
||||
typedef internal::traits<NestedEvalTypeClean> Traits;
|
||||
typedef std::complex<typename NumTraits<Scalar>::Real> ComplexScalar;
|
||||
typedef internal::make_complex_t<Scalar> ComplexScalar;
|
||||
typedef Matrix<ComplexScalar, Dynamic, Dynamic, 0, Traits::RowsAtCompileTime, Traits::ColsAtCompileTime>
|
||||
DynMatrixType;
|
||||
|
||||
|
@ -330,7 +330,7 @@ class MatrixLogarithmReturnValue : public ReturnByValue<MatrixLogarithmReturnVal
|
||||
typedef typename internal::nested_eval<Derived, 10>::type DerivedEvalType;
|
||||
typedef internal::remove_all_t<DerivedEvalType> DerivedEvalTypeClean;
|
||||
typedef internal::traits<DerivedEvalTypeClean> Traits;
|
||||
typedef std::complex<typename NumTraits<Scalar>::Real> ComplexScalar;
|
||||
typedef internal::make_complex_t<Scalar> ComplexScalar;
|
||||
typedef Matrix<ComplexScalar, Dynamic, Dynamic, 0, Traits::RowsAtCompileTime, Traits::ColsAtCompileTime>
|
||||
DynMatrixType;
|
||||
typedef internal::MatrixLogarithmAtomic<DynMatrixType> AtomicType;
|
||||
|
@ -91,7 +91,7 @@ class MatrixPowerAtomic : internal::noncopyable {
|
||||
enum { RowsAtCompileTime = MatrixType::RowsAtCompileTime, MaxRowsAtCompileTime = MatrixType::MaxRowsAtCompileTime };
|
||||
typedef typename MatrixType::Scalar Scalar;
|
||||
typedef typename MatrixType::RealScalar RealScalar;
|
||||
typedef std::complex<RealScalar> ComplexScalar;
|
||||
typedef internal::make_complex_t<Scalar> ComplexScalar;
|
||||
typedef Block<MatrixType, Dynamic, Dynamic> ResultType;
|
||||
|
||||
const MatrixType& m_A;
|
||||
@ -380,7 +380,7 @@ class MatrixPower : internal::noncopyable {
|
||||
Index cols() const { return m_A.cols(); }
|
||||
|
||||
private:
|
||||
typedef std::complex<RealScalar> ComplexScalar;
|
||||
typedef internal::make_complex_t<Scalar> ComplexScalar;
|
||||
typedef Matrix<ComplexScalar, Dynamic, Dynamic, 0, MatrixType::RowsAtCompileTime, MatrixType::ColsAtCompileTime>
|
||||
ComplexMatrix;
|
||||
|
||||
@ -628,7 +628,7 @@ template <typename Derived>
|
||||
class MatrixComplexPowerReturnValue : public ReturnByValue<MatrixComplexPowerReturnValue<Derived> > {
|
||||
public:
|
||||
typedef typename Derived::PlainObject PlainObject;
|
||||
typedef typename std::complex<typename Derived::RealScalar> ComplexScalar;
|
||||
typedef internal::make_complex_t<typename Derived::Scalar> ComplexScalar;
|
||||
|
||||
/**
|
||||
* \brief Constructor.
|
||||
@ -685,7 +685,7 @@ const MatrixPowerReturnValue<Derived> MatrixBase<Derived>::pow(const RealScalar&
|
||||
}
|
||||
|
||||
template <typename Derived>
|
||||
const MatrixComplexPowerReturnValue<Derived> MatrixBase<Derived>::pow(const std::complex<RealScalar>& p) const {
|
||||
const MatrixComplexPowerReturnValue<Derived> MatrixBase<Derived>::pow(const internal::make_complex_t<Scalar>& p) const {
|
||||
return MatrixComplexPowerReturnValue<Derived>(derived(), p);
|
||||
}
|
||||
|
||||
|
@ -35,7 +35,7 @@ class PolynomialSolverBase {
|
||||
|
||||
typedef Scalar_ Scalar;
|
||||
typedef typename NumTraits<Scalar>::Real RealScalar;
|
||||
typedef std::complex<RealScalar> RootType;
|
||||
typedef internal::make_complex_t<Scalar> RootType;
|
||||
typedef Matrix<RootType, Deg_, 1> RootsType;
|
||||
|
||||
typedef DenseIndex Index;
|
||||
@ -308,7 +308,7 @@ class PolynomialSolver : public PolynomialSolverBase<Scalar_, Deg_> {
|
||||
typedef std::conditional_t<NumTraits<Scalar>::IsComplex, ComplexEigenSolver<CompanionMatrixType>,
|
||||
EigenSolver<CompanionMatrixType> >
|
||||
EigenSolverType;
|
||||
typedef std::conditional_t<NumTraits<Scalar>::IsComplex, Scalar, std::complex<Scalar> > ComplexScalar;
|
||||
typedef internal::make_complex_t<Scalar_> ComplexScalar;
|
||||
|
||||
public:
|
||||
/** Computes the complex roots of a new polynomial. */
|
||||
|
Loading…
x
Reference in New Issue
Block a user