Remove assumption of std::complex for complex scalar types.

This commit is contained in:
Antonio Sanchez 2025-02-12 11:21:44 -08:00
parent 6b4881ad48
commit 22cd7307dd
21 changed files with 273 additions and 115 deletions

View File

@ -182,6 +182,35 @@ struct imag_ref_retval {
typedef typename NumTraits<Scalar>::Real& type; typedef typename NumTraits<Scalar>::Real& type;
}; };
} // namespace internal
namespace numext {
template <typename Scalar>
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(real, Scalar) real(const Scalar& x) {
return EIGEN_MATHFUNC_IMPL(real, Scalar)::run(x);
}
template <typename Scalar>
EIGEN_DEVICE_FUNC inline internal::add_const_on_value_type_t<EIGEN_MATHFUNC_RETVAL(real_ref, Scalar)> real_ref(
const Scalar& x) {
return internal::real_ref_impl<Scalar>::run(x);
}
template <typename Scalar>
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(real_ref, Scalar) real_ref(Scalar& x) {
return EIGEN_MATHFUNC_IMPL(real_ref, Scalar)::run(x);
}
template <typename Scalar>
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(imag, Scalar) imag(const Scalar& x) {
return EIGEN_MATHFUNC_IMPL(imag, Scalar)::run(x);
}
} // namespace numext
namespace internal {
/**************************************************************************** /****************************************************************************
* Implementation of conj * * Implementation of conj *
****************************************************************************/ ****************************************************************************/
@ -221,7 +250,9 @@ template <typename Scalar>
struct abs2_impl_default<Scalar, true> // IsComplex struct abs2_impl_default<Scalar, true> // IsComplex
{ {
typedef typename NumTraits<Scalar>::Real RealScalar; typedef typename NumTraits<Scalar>::Real RealScalar;
EIGEN_DEVICE_FUNC static inline RealScalar run(const Scalar& x) { return x.real() * x.real() + x.imag() * x.imag(); } EIGEN_DEVICE_FUNC static inline RealScalar run(const Scalar& x) {
return numext::real(x) * numext::real(x) + numext::imag(x) * numext::imag(x);
}
}; };
template <typename Scalar> template <typename Scalar>
@ -250,16 +281,14 @@ struct sqrt_impl {
}; };
// Complex sqrt defined in MathFunctionsImpl.h. // Complex sqrt defined in MathFunctionsImpl.h.
template <typename T> template <typename ComplexT>
EIGEN_DEVICE_FUNC std::complex<T> complex_sqrt(const std::complex<T>& a_x); EIGEN_DEVICE_FUNC ComplexT complex_sqrt(const ComplexT& a_x);
// Custom implementation is faster than `std::sqrt`, works on // Custom implementation is faster than `std::sqrt`, works on
// GPU, and correctly handles special cases (unlike MSVC). // GPU, and correctly handles special cases (unlike MSVC).
template <typename T> template <typename T>
struct sqrt_impl<std::complex<T>> { struct sqrt_impl<std::complex<T>> {
EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE std::complex<T> run(const std::complex<T>& x) { EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE std::complex<T> run(const std::complex<T>& x) { return complex_sqrt(x); }
return complex_sqrt<T>(x);
}
}; };
template <typename Scalar> template <typename Scalar>
@ -272,13 +301,13 @@ template <typename T>
struct rsqrt_impl; struct rsqrt_impl;
// Complex rsqrt defined in MathFunctionsImpl.h. // Complex rsqrt defined in MathFunctionsImpl.h.
template <typename T> template <typename ComplexT>
EIGEN_DEVICE_FUNC std::complex<T> complex_rsqrt(const std::complex<T>& a_x); EIGEN_DEVICE_FUNC ComplexT complex_rsqrt(const ComplexT& a_x);
template <typename T> template <typename T>
struct rsqrt_impl<std::complex<T>> { struct rsqrt_impl<std::complex<T>> {
EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE std::complex<T> run(const std::complex<T>& x) { EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE std::complex<T> run(const std::complex<T>& x) {
return complex_rsqrt<T>(x); return complex_rsqrt(x);
} }
}; };
@ -299,7 +328,7 @@ struct norm1_default_impl<Scalar, true> {
typedef typename NumTraits<Scalar>::Real RealScalar; typedef typename NumTraits<Scalar>::Real RealScalar;
EIGEN_DEVICE_FUNC static inline RealScalar run(const Scalar& x) { EIGEN_DEVICE_FUNC static inline RealScalar run(const Scalar& x) {
EIGEN_USING_STD(abs); EIGEN_USING_STD(abs);
return abs(x.real()) + abs(x.imag()); return abs(numext::real(x)) + abs(numext::imag(x));
} }
}; };
@ -469,8 +498,8 @@ struct expm1_retval {
****************************************************************************/ ****************************************************************************/
// Complex log defined in MathFunctionsImpl.h. // Complex log defined in MathFunctionsImpl.h.
template <typename T> template <typename ComplexT>
EIGEN_DEVICE_FUNC std::complex<T> complex_log(const std::complex<T>& z); EIGEN_DEVICE_FUNC ComplexT complex_log(const ComplexT& z);
template <typename Scalar> template <typename Scalar>
struct log_impl { struct log_impl {
@ -846,7 +875,7 @@ struct sign_impl<Scalar, true, IsInteger> {
real_type aa = abs(a); real_type aa = abs(a);
if (aa == real_type(0)) return Scalar(0); if (aa == real_type(0)) return Scalar(0);
aa = real_type(1) / aa; aa = real_type(1) / aa;
return Scalar(a.real() * aa, a.imag() * aa); return Scalar(numext::real(a) * aa, numext::imag(a) * aa);
} }
}; };
@ -1042,27 +1071,6 @@ SYCL_SPECIALIZE_FLOATING_TYPES_BINARY(maxi, fmax)
#endif #endif
template <typename Scalar>
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(real, Scalar) real(const Scalar& x) {
return EIGEN_MATHFUNC_IMPL(real, Scalar)::run(x);
}
template <typename Scalar>
EIGEN_DEVICE_FUNC inline internal::add_const_on_value_type_t<EIGEN_MATHFUNC_RETVAL(real_ref, Scalar)> real_ref(
const Scalar& x) {
return internal::real_ref_impl<Scalar>::run(x);
}
template <typename Scalar>
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(real_ref, Scalar) real_ref(Scalar& x) {
return EIGEN_MATHFUNC_IMPL(real_ref, Scalar)::run(x);
}
template <typename Scalar>
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(imag, Scalar) imag(const Scalar& x) {
return EIGEN_MATHFUNC_IMPL(imag, Scalar)::run(x);
}
template <typename Scalar> template <typename Scalar>
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(arg, Scalar) arg(const Scalar& x) { EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(arg, Scalar) arg(const Scalar& x) {
return EIGEN_MATHFUNC_IMPL(arg, Scalar)::run(x); return EIGEN_MATHFUNC_IMPL(arg, Scalar)::run(x);

View File

@ -171,8 +171,8 @@ struct hypot_impl {
// Generic complex sqrt implementation that correctly handles corner cases // Generic complex sqrt implementation that correctly handles corner cases
// according to https://en.cppreference.com/w/cpp/numeric/complex/sqrt // according to https://en.cppreference.com/w/cpp/numeric/complex/sqrt
template <typename T> template <typename ComplexT>
EIGEN_DEVICE_FUNC std::complex<T> complex_sqrt(const std::complex<T>& z) { EIGEN_DEVICE_FUNC ComplexT complex_sqrt(const ComplexT& z) {
// Computes the principal sqrt of the input. // Computes the principal sqrt of the input.
// //
// For a complex square root of the number x + i*y. We want to find real // For a complex square root of the number x + i*y. We want to find real
@ -194,21 +194,21 @@ EIGEN_DEVICE_FUNC std::complex<T> complex_sqrt(const std::complex<T>& z) {
// if x == 0: u = w, v = sign(y) * w // if x == 0: u = w, v = sign(y) * w
// if x > 0: u = w, v = y / (2 * w) // if x > 0: u = w, v = y / (2 * w)
// if x < 0: u = |y| / (2 * w), v = sign(y) * w // if x < 0: u = |y| / (2 * w), v = sign(y) * w
using T = typename NumTraits<ComplexT>::Real;
const T x = numext::real(z); const T x = numext::real(z);
const T y = numext::imag(z); const T y = numext::imag(z);
const T zero = T(0); const T zero = T(0);
const T w = numext::sqrt(T(0.5) * (numext::abs(x) + numext::hypot(x, y))); const T w = numext::sqrt(T(0.5) * (numext::abs(x) + numext::hypot(x, y)));
return (numext::isinf)(y) ? std::complex<T>(NumTraits<T>::infinity(), y) return (numext::isinf)(y) ? ComplexT(NumTraits<T>::infinity(), y)
: numext::is_exactly_zero(x) ? std::complex<T>(w, y < zero ? -w : w) : numext::is_exactly_zero(x) ? ComplexT(w, y < zero ? -w : w)
: x > zero ? std::complex<T>(w, y / (2 * w)) : x > zero ? ComplexT(w, y / (2 * w))
: std::complex<T>(numext::abs(y) / (2 * w), y < zero ? -w : w); : ComplexT(numext::abs(y) / (2 * w), y < zero ? -w : w);
} }
// Generic complex rsqrt implementation. // Generic complex rsqrt implementation.
template <typename T> template <typename ComplexT>
EIGEN_DEVICE_FUNC std::complex<T> complex_rsqrt(const std::complex<T>& z) { EIGEN_DEVICE_FUNC ComplexT complex_rsqrt(const ComplexT& z) {
// Computes the principal reciprocal sqrt of the input. // Computes the principal reciprocal sqrt of the input.
// //
// For a complex reciprocal square root of the number z = x + i*y. We want to // For a complex reciprocal square root of the number z = x + i*y. We want to
@ -230,7 +230,7 @@ EIGEN_DEVICE_FUNC std::complex<T> complex_rsqrt(const std::complex<T>& z) {
// if x == 0: u = w / |z|, v = -sign(y) * w / |z| // if x == 0: u = w / |z|, v = -sign(y) * w / |z|
// if x > 0: u = w / |z|, v = -y / (2 * w * |z|) // if x > 0: u = w / |z|, v = -y / (2 * w * |z|)
// if x < 0: u = |y| / (2 * w * |z|), v = -sign(y) * w / |z| // if x < 0: u = |y| / (2 * w * |z|), v = -sign(y) * w / |z|
using T = typename NumTraits<ComplexT>::Real;
const T x = numext::real(z); const T x = numext::real(z);
const T y = numext::imag(z); const T y = numext::imag(z);
const T zero = T(0); const T zero = T(0);
@ -239,20 +239,21 @@ EIGEN_DEVICE_FUNC std::complex<T> complex_rsqrt(const std::complex<T>& z) {
const T w = numext::sqrt(T(0.5) * (numext::abs(x) + abs_z)); const T w = numext::sqrt(T(0.5) * (numext::abs(x) + abs_z));
const T woz = w / abs_z; const T woz = w / abs_z;
// Corner cases consistent with 1/sqrt(z) on gcc/clang. // Corner cases consistent with 1/sqrt(z) on gcc/clang.
return numext::is_exactly_zero(abs_z) ? std::complex<T>(NumTraits<T>::infinity(), NumTraits<T>::quiet_NaN()) return numext::is_exactly_zero(abs_z) ? ComplexT(NumTraits<T>::infinity(), NumTraits<T>::quiet_NaN())
: ((numext::isinf)(x) || (numext::isinf)(y)) ? std::complex<T>(zero, zero) : ((numext::isinf)(x) || (numext::isinf)(y)) ? ComplexT(zero, zero)
: numext::is_exactly_zero(x) ? std::complex<T>(woz, y < zero ? woz : -woz) : numext::is_exactly_zero(x) ? ComplexT(woz, y < zero ? woz : -woz)
: x > zero ? std::complex<T>(woz, -y / (2 * w * abs_z)) : x > zero ? ComplexT(woz, -y / (2 * w * abs_z))
: std::complex<T>(numext::abs(y) / (2 * w * abs_z), y < zero ? woz : -woz); : ComplexT(numext::abs(y) / (2 * w * abs_z), y < zero ? woz : -woz);
} }
template <typename T> template <typename ComplexT>
EIGEN_DEVICE_FUNC std::complex<T> complex_log(const std::complex<T>& z) { EIGEN_DEVICE_FUNC ComplexT complex_log(const ComplexT& z) {
// Computes complex log. // Computes complex log.
using T = typename NumTraits<ComplexT>::Real;
T a = numext::abs(z); T a = numext::abs(z);
EIGEN_USING_STD(atan2); EIGEN_USING_STD(atan2);
T b = atan2(z.imag(), z.real()); T b = atan2(z.imag(), z.real());
return std::complex<T>(numext::log(a), b); return ComplexT(numext::log(a), b);
} }
} // end namespace internal } // end namespace internal

View File

@ -112,7 +112,7 @@ class MatrixBase : public DenseBase<Derived> {
ConstTransposeReturnType> ConstTransposeReturnType>
AdjointReturnType; AdjointReturnType;
/** \internal Return type of eigenvalues() */ /** \internal Return type of eigenvalues() */
typedef Matrix<std::complex<RealScalar>, internal::traits<Derived>::ColsAtCompileTime, 1, ColMajor> typedef Matrix<internal::make_complex_t<Scalar>, internal::traits<Derived>::ColsAtCompileTime, 1, ColMajor>
EigenvaluesReturnType; EigenvaluesReturnType;
/** \internal the return type of identity */ /** \internal the return type of identity */
typedef CwiseNullaryOp<internal::scalar_identity_op<Scalar>, PlainObject> IdentityReturnType; typedef CwiseNullaryOp<internal::scalar_identity_op<Scalar>, PlainObject> IdentityReturnType;
@ -468,7 +468,7 @@ class MatrixBase : public DenseBase<Derived> {
EIGEN_MATRIX_FUNCTION(MatrixSquareRootReturnValue, sqrt, square root) EIGEN_MATRIX_FUNCTION(MatrixSquareRootReturnValue, sqrt, square root)
EIGEN_MATRIX_FUNCTION(MatrixLogarithmReturnValue, log, logarithm) EIGEN_MATRIX_FUNCTION(MatrixLogarithmReturnValue, log, logarithm)
EIGEN_MATRIX_FUNCTION_1(MatrixPowerReturnValue, pow, power to \c p, const RealScalar& p) EIGEN_MATRIX_FUNCTION_1(MatrixPowerReturnValue, pow, power to \c p, const RealScalar& p)
EIGEN_MATRIX_FUNCTION_1(MatrixComplexPowerReturnValue, pow, power to \c p, const std::complex<RealScalar>& p) EIGEN_MATRIX_FUNCTION_1(MatrixComplexPowerReturnValue, pow, power to \c p, const internal::make_complex_t<Scalar>& p)
protected: protected:
EIGEN_DEFAULT_COPY_CONSTRUCTOR(MatrixBase) EIGEN_DEFAULT_COPY_CONSTRUCTOR(MatrixBase)

View File

@ -497,7 +497,7 @@ class MatrixComplexPowerReturnValue;
namespace internal { namespace internal {
template <typename Scalar> template <typename Scalar>
struct stem_function { struct stem_function {
typedef std::complex<typename NumTraits<Scalar>::Real> ComplexScalar; typedef internal::make_complex_t<Scalar> ComplexScalar;
typedef ComplexScalar type(ComplexScalar, int); typedef ComplexScalar type(ComplexScalar, int);
}; };
} // namespace internal } // namespace internal

View File

@ -745,6 +745,9 @@ using std::is_constant_evaluated;
constexpr bool is_constant_evaluated() { return false; } constexpr bool is_constant_evaluated() { return false; }
#endif #endif
template <typename Scalar>
using make_complex_t = std::conditional_t<NumTraits<Scalar>::IsComplex, Scalar, std::complex<Scalar>>;
} // end namespace internal } // end namespace internal
} // end namespace Eigen } // end namespace Eigen

View File

@ -885,8 +885,12 @@ struct scalar_div_cost {
}; };
template <typename T, bool Vectorized> template <typename T, bool Vectorized>
struct scalar_div_cost<std::complex<T>, Vectorized> { struct scalar_div_cost<T, Vectorized, std::enable_if_t<NumTraits<T>::IsComplex>> {
enum { value = 2 * scalar_div_cost<T>::value + 6 * NumTraits<T>::MulCost + 3 * NumTraits<T>::AddCost }; using RealScalar = typename NumTraits<T>::Real;
enum {
value =
2 * scalar_div_cost<RealScalar>::value + 6 * NumTraits<RealScalar>::MulCost + 3 * NumTraits<RealScalar>::AddCost
};
}; };
template <bool Vectorized> template <bool Vectorized>

View File

@ -70,7 +70,7 @@ class ComplexEigenSolver {
* \c float or \c double) and just \c Scalar if #Scalar is * \c float or \c double) and just \c Scalar if #Scalar is
* complex. * complex.
*/ */
typedef std::complex<RealScalar> ComplexScalar; typedef internal::make_complex_t<Scalar> ComplexScalar;
/** \brief Type for vector of eigenvalues as returned by eigenvalues(). /** \brief Type for vector of eigenvalues as returned by eigenvalues().
* *

View File

@ -75,7 +75,7 @@ class ComplexSchur {
* \c float or \c double) and just \c Scalar if #Scalar is * \c float or \c double) and just \c Scalar if #Scalar is
* complex. * complex.
*/ */
typedef std::complex<RealScalar> ComplexScalar; typedef internal::make_complex_t<Scalar> ComplexScalar;
/** \brief Type for the matrices in the Schur decomposition. /** \brief Type for the matrices in the Schur decomposition.
* *

View File

@ -89,7 +89,7 @@ class EigenSolver {
* \c float or \c double) and just \c Scalar if #Scalar is * \c float or \c double) and just \c Scalar if #Scalar is
* complex. * complex.
*/ */
typedef std::complex<RealScalar> ComplexScalar; typedef internal::make_complex_t<Scalar> ComplexScalar;
/** \brief Type for vector of eigenvalues as returned by eigenvalues(). /** \brief Type for vector of eigenvalues as returned by eigenvalues().
* *

View File

@ -83,7 +83,7 @@ class GeneralizedEigenSolver {
* \c float or \c double) and just \c Scalar if #Scalar is * \c float or \c double) and just \c Scalar if #Scalar is
* complex. * complex.
*/ */
typedef std::complex<RealScalar> ComplexScalar; typedef internal::make_complex_t<Scalar> ComplexScalar;
/** \brief Type for vector of real scalar values eigenvalues as returned by betas(). /** \brief Type for vector of real scalar values eigenvalues as returned by betas().
* *

View File

@ -69,7 +69,7 @@ class RealQZ {
MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime
}; };
typedef typename MatrixType::Scalar Scalar; typedef typename MatrixType::Scalar Scalar;
typedef std::complex<typename NumTraits<Scalar>::Real> ComplexScalar; typedef internal::make_complex_t<Scalar> ComplexScalar;
typedef Eigen::Index Index; ///< \deprecated since Eigen 3.3 typedef Eigen::Index Index; ///< \deprecated since Eigen 3.3
typedef Matrix<ComplexScalar, ColsAtCompileTime, 1, Options & ~RowMajor, MaxColsAtCompileTime, 1> EigenvalueType; typedef Matrix<ComplexScalar, ColsAtCompileTime, 1, Options & ~RowMajor, MaxColsAtCompileTime, 1> EigenvalueType;

View File

@ -66,7 +66,7 @@ class RealSchur {
MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime
}; };
typedef typename MatrixType::Scalar Scalar; typedef typename MatrixType::Scalar Scalar;
typedef std::complex<typename NumTraits<Scalar>::Real> ComplexScalar; typedef internal::make_complex_t<Scalar> ComplexScalar;
typedef Eigen::Index Index; ///< \deprecated since Eigen 3.3 typedef Eigen::Index Index; ///< \deprecated since Eigen 3.3
typedef Matrix<ComplexScalar, ColsAtCompileTime, 1, Options & ~RowMajor, MaxColsAtCompileTime, 1> EigenvalueType; typedef Matrix<ComplexScalar, ColsAtCompileTime, 1, Options & ~RowMajor, MaxColsAtCompileTime, 1> EigenvalueType;

132
test/CustomComplex.h Normal file
View File

@ -0,0 +1,132 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2025 The Eigen Authors.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_TEST_CUSTOM_COMPLEX_H
#define EIGEN_TEST_CUSTOM_COMPLEX_H
#include <ostream>
#include <sstream>
namespace custom_complex {
template <typename Real>
struct CustomComplex {
CustomComplex() : re{0}, im{0} {}
CustomComplex(const CustomComplex& other) = default;
CustomComplex(CustomComplex&& other) = default;
CustomComplex& operator=(const CustomComplex& other) = default;
CustomComplex& operator=(CustomComplex&& other) = default;
CustomComplex(Real x) : re{x}, im{0} {}
CustomComplex(Real x, Real y) : re{x}, im{y} {}
CustomComplex operator+(const CustomComplex& other) const { return CustomComplex(re + other.re, im + other.im); }
CustomComplex operator-() const { return CustomComplex(-re, -im); }
CustomComplex operator-(const CustomComplex& other) const { return CustomComplex(re - other.re, im - other.im); }
CustomComplex operator*(const CustomComplex& other) const {
return CustomComplex(re * other.re - im * other.im, re * other.im + im * other.re);
}
CustomComplex operator/(const CustomComplex& other) const {
// Smith's complex division (https://arxiv.org/pdf/1210.4539.pdf),
// guards against over/under-flow.
const bool scale_imag = numext::abs(other.im) <= numext::abs(other.re);
const Real rscale = scale_imag ? Real(1) : other.re / other.im;
const Real iscale = scale_imag ? other.im / other.re : Real(1);
const Real denominator = other.re * rscale + other.im * iscale;
return CustomComplex((re * rscale + im * iscale) / denominator, (im * rscale - re * iscale) / denominator);
}
CustomComplex& operator+=(const CustomComplex& other) {
*this = *this + other;
return *this;
}
CustomComplex& operator-=(const CustomComplex& other) {
*this = *this - other;
return *this;
}
CustomComplex& operator*=(const CustomComplex& other) {
*this = *this * other;
return *this;
}
CustomComplex& operator/=(const CustomComplex& other) {
*this = *this / other;
return *this;
}
bool operator==(const CustomComplex& other) const {
return numext::equal_strict(re, other.re) && numext::equal_strict(im, other.im);
}
bool operator!=(const CustomComplex& other) const { return !(*this == other); }
friend CustomComplex operator+(const Real& a, const CustomComplex& b) { return CustomComplex(a + b.re, b.im); }
friend CustomComplex operator-(const Real& a, const CustomComplex& b) { return CustomComplex(a - b.re, -b.im); }
friend CustomComplex operator*(const Real& a, const CustomComplex& b) { return CustomComplex(a * b.re, a * b.im); }
friend CustomComplex operator*(const CustomComplex& a, const Real& b) { return CustomComplex(a.re * b, a.im * b); }
friend CustomComplex operator/(const CustomComplex& a, const Real& b) { return CustomComplex(a.re / b, a.im / b); }
friend std::ostream& operator<<(std::ostream& stream, const CustomComplex& x) {
std::stringstream ss;
ss << "(" << x.re << ", " << x.im << ")";
stream << ss.str();
return stream;
}
Real re;
Real im;
};
template <typename Real>
Real real(const CustomComplex<Real>& x) {
return x.re;
}
template <typename Real>
Real imag(const CustomComplex<Real>& x) {
return x.im;
}
template <typename Real>
CustomComplex<Real> conj(const CustomComplex<Real>& x) {
return CustomComplex<Real>(x.re, -x.im);
}
template <typename Real>
CustomComplex<Real> sqrt(const CustomComplex<Real>& x) {
return Eigen::internal::complex_sqrt(x);
}
template <typename Real>
Real abs(const CustomComplex<Real>& x) {
return Eigen::numext::sqrt(x.re * x.re + x.im * x.im);
}
} // namespace custom_complex
template <typename Real>
using CustomComplex = custom_complex::CustomComplex<Real>;
namespace Eigen {
template <typename Real>
struct NumTraits<CustomComplex<Real>> : NumTraits<Real> {
enum { IsComplex = 1 };
};
namespace numext {
template <typename Real>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isfinite)(const CustomComplex<Real>& x) {
return (numext::isfinite)(x.re) && (numext::isfinite)(x.im);
}
} // namespace numext
} // namespace Eigen
#endif // EIGEN_TEST_CUSTOM_COMPLEX_H

View File

@ -12,6 +12,7 @@
#include <limits> #include <limits>
#include <Eigen/Eigenvalues> #include <Eigen/Eigenvalues>
#include <Eigen/LU> #include <Eigen/LU>
#include "CustomComplex.h"
template <typename MatrixType> template <typename MatrixType>
bool find_pivot(typename MatrixType::Scalar tol, MatrixType& diffs, Index col = 0) { bool find_pivot(typename MatrixType::Scalar tol, MatrixType& diffs, Index col = 0) {
@ -165,5 +166,8 @@ EIGEN_DECLARE_TEST(eigensolver_complex) {
// Test problem size constructors // Test problem size constructors
CALL_SUBTEST_5(ComplexEigenSolver<MatrixXf> tmp(s)); CALL_SUBTEST_5(ComplexEigenSolver<MatrixXf> tmp(s));
// Test custom complex scalar type.
CALL_SUBTEST_6(eigensolver(Matrix<CustomComplex<double>, 5, 5>()));
TEST_SET_BUT_UNUSED_VARIABLE(s) TEST_SET_BUT_UNUSED_VARIABLE(s)
} }

View File

@ -15,7 +15,7 @@
namespace Eigen { namespace Eigen {
template <bool NeedUprade> template <bool IsReal>
struct MakeComplex { struct MakeComplex {
template <typename T> template <typename T>
EIGEN_DEVICE_FUNC T operator()(const T& val) const { EIGEN_DEVICE_FUNC T operator()(const T& val) const {
@ -26,16 +26,8 @@ struct MakeComplex {
template <> template <>
struct MakeComplex<true> { struct MakeComplex<true> {
template <typename T> template <typename T>
EIGEN_DEVICE_FUNC std::complex<T> operator()(const T& val) const { EIGEN_DEVICE_FUNC internal::make_complex_t<T> operator()(const T& val) const {
return std::complex<T>(val, 0); return internal::make_complex_t<T>(val, T(0));
}
};
template <>
struct MakeComplex<false> {
template <typename T>
EIGEN_DEVICE_FUNC std::complex<T> operator()(const std::complex<T>& val) const {
return val;
} }
}; };
@ -49,17 +41,17 @@ struct PartOf {
template <> template <>
struct PartOf<RealPart> { struct PartOf<RealPart> {
template <typename T> template <typename T, typename EnableIf = std::enable_if_t<NumTraits<T>::IsComplex>>
T operator()(const std::complex<T>& val) const { typename NumTraits<T>::Real operator()(const T& val) const {
return val.real(); return Eigen::numext::real(val);
} }
}; };
template <> template <>
struct PartOf<ImagPart> { struct PartOf<ImagPart> {
template <typename T> template <typename T, typename EnableIf = std::enable_if_t<NumTraits<T>::IsComplex>>
T operator()(const std::complex<T>& val) const { typename NumTraits<T>::Real operator()(const T& val) const {
return val.imag(); return Eigen::numext::imag(val);
} }
}; };
@ -67,8 +59,9 @@ namespace internal {
template <typename FFT, typename XprType, int FFTResultType, int FFTDir> template <typename FFT, typename XprType, int FFTResultType, int FFTDir>
struct traits<TensorFFTOp<FFT, XprType, FFTResultType, FFTDir> > : public traits<XprType> { struct traits<TensorFFTOp<FFT, XprType, FFTResultType, FFTDir> > : public traits<XprType> {
typedef traits<XprType> XprTraits; typedef traits<XprType> XprTraits;
typedef typename NumTraits<typename XprTraits::Scalar>::Real RealScalar; typedef typename XprTraits::Scalar Scalar;
typedef typename std::complex<RealScalar> ComplexScalar; typedef typename NumTraits<Scalar>::Real RealScalar;
typedef make_complex_t<Scalar> ComplexScalar;
typedef typename XprTraits::Scalar InputScalar; typedef typename XprTraits::Scalar InputScalar;
typedef std::conditional_t<FFTResultType == RealPart || FFTResultType == ImagPart, RealScalar, ComplexScalar> typedef std::conditional_t<FFTResultType == RealPart || FFTResultType == ImagPart, RealScalar, ComplexScalar>
OutputScalar; OutputScalar;
@ -109,7 +102,7 @@ class TensorFFTOp : public TensorBase<TensorFFTOp<FFT, XprType, FFTResultType, F
public: public:
typedef typename Eigen::internal::traits<TensorFFTOp>::Scalar Scalar; typedef typename Eigen::internal::traits<TensorFFTOp>::Scalar Scalar;
typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
typedef typename std::complex<RealScalar> ComplexScalar; typedef internal::make_complex_t<Scalar> ComplexScalar;
typedef std::conditional_t<FFTResultType == RealPart || FFTResultType == ImagPart, RealScalar, ComplexScalar> typedef std::conditional_t<FFTResultType == RealPart || FFTResultType == ImagPart, RealScalar, ComplexScalar>
OutputScalar; OutputScalar;
typedef OutputScalar CoeffReturnType; typedef OutputScalar CoeffReturnType;
@ -137,7 +130,7 @@ struct TensorEvaluator<const TensorFFTOp<FFT, ArgType, FFTResultType, FFTDir>, D
typedef DSizes<Index, NumDims> Dimensions; typedef DSizes<Index, NumDims> Dimensions;
typedef typename XprType::Scalar Scalar; typedef typename XprType::Scalar Scalar;
typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
typedef typename std::complex<RealScalar> ComplexScalar; typedef internal::make_complex_t<Scalar> ComplexScalar;
typedef typename TensorEvaluator<ArgType, Device>::Dimensions InputDimensions; typedef typename TensorEvaluator<ArgType, Device>::Dimensions InputDimensions;
typedef internal::traits<XprType> XprTraits; typedef internal::traits<XprType> XprTraits;
typedef typename XprTraits::Scalar InputScalar; typedef typename XprTraits::Scalar InputScalar;

View File

@ -111,12 +111,13 @@ class DGMRES : public IterativeSolverBase<DGMRES<MatrixType_, Preconditioner_> >
typedef typename MatrixType::Scalar Scalar; typedef typename MatrixType::Scalar Scalar;
typedef typename MatrixType::StorageIndex StorageIndex; typedef typename MatrixType::StorageIndex StorageIndex;
typedef typename MatrixType::RealScalar RealScalar; typedef typename MatrixType::RealScalar RealScalar;
typedef internal::make_complex_t<Scalar> ComplexScalar;
typedef Preconditioner_ Preconditioner; typedef Preconditioner_ Preconditioner;
typedef Matrix<Scalar, Dynamic, Dynamic> DenseMatrix; typedef Matrix<Scalar, Dynamic, Dynamic> DenseMatrix;
typedef Matrix<RealScalar, Dynamic, Dynamic> DenseRealMatrix; typedef Matrix<RealScalar, Dynamic, Dynamic> DenseRealMatrix;
typedef Matrix<Scalar, Dynamic, 1> DenseVector; typedef Matrix<Scalar, Dynamic, 1> DenseVector;
typedef Matrix<RealScalar, Dynamic, 1> DenseRealVector; typedef Matrix<RealScalar, Dynamic, 1> DenseRealVector;
typedef Matrix<std::complex<RealScalar>, Dynamic, 1> ComplexVector; typedef Matrix<ComplexScalar, Dynamic, 1> ComplexVector;
/** Default constructor. */ /** Default constructor. */
DGMRES() DGMRES()
@ -389,15 +390,15 @@ inline typename DGMRES<MatrixType_, Preconditioner_>::ComplexVector DGMRES<Matri
Index j = 0; Index j = 0;
while (j < it - 1) { while (j < it - 1) {
if (T(j + 1, j) == Scalar(0)) { if (T(j + 1, j) == Scalar(0)) {
eig(j) = std::complex<RealScalar>(T(j, j), RealScalar(0)); eig(j) = ComplexScalar(T(j, j), RealScalar(0));
j++; j++;
} else { } else {
eig(j) = std::complex<RealScalar>(T(j, j), T(j + 1, j)); eig(j) = ComplexScalar(T(j, j), T(j + 1, j));
eig(j + 1) = std::complex<RealScalar>(T(j, j + 1), T(j + 1, j + 1)); eig(j + 1) = ComplexScalar(T(j, j + 1), T(j + 1, j + 1));
j++; j++;
} }
} }
if (j < it - 1) eig(j) = std::complex<RealScalar>(T(j, j), RealScalar(0)); if (j < it - 1) eig(j) = ComplexScalar(T(j, j), RealScalar(0));
return eig; return eig;
} }

View File

@ -23,8 +23,10 @@ namespace internal {
* *
* This struct is used by CwiseUnaryOp to scale a matrix by \f$ 2^{-s} \f$. * This struct is used by CwiseUnaryOp to scale a matrix by \f$ 2^{-s} \f$.
*/ */
template <typename RealScalar> template <typename Scalar, bool IsComplex = NumTraits<Scalar>::IsComplex>
struct MatrixExponentialScalingOp { struct MatrixExponentialScalingOp {
using RealScalar = typename NumTraits<Scalar>::Real;
/** \brief Constructor. /** \brief Constructor.
* *
* \param[in] squarings The integer \f$ s \f$ in this document. * \param[in] squarings The integer \f$ s \f$ in this document.
@ -35,20 +37,30 @@ struct MatrixExponentialScalingOp {
* *
* \param[in,out] x The scalar to be scaled, becoming \f$ 2^{-s} x \f$. * \param[in,out] x The scalar to be scaled, becoming \f$ 2^{-s} x \f$.
*/ */
inline const RealScalar operator()(const RealScalar& x) const { inline const Scalar operator()(const Scalar& x) const {
using std::ldexp; using std::ldexp;
return ldexp(x, -m_squarings); return Scalar(ldexp(Eigen::numext::real(x), -m_squarings), ldexp(Eigen::numext::imag(x), -m_squarings));
} }
typedef std::complex<RealScalar> ComplexScalar; private:
int m_squarings;
};
template <typename Scalar>
struct MatrixExponentialScalingOp<Scalar, /*IsComplex=*/false> {
/** \brief Constructor.
*
* \param[in] squarings The integer \f$ s \f$ in this document.
*/
MatrixExponentialScalingOp(int squarings) : m_squarings(squarings) {}
/** \brief Scale a matrix coefficient. /** \brief Scale a matrix coefficient.
* *
* \param[in,out] x The scalar to be scaled, becoming \f$ 2^{-s} x \f$. * \param[in,out] x The scalar to be scaled, becoming \f$ 2^{-s} x \f$.
*/ */
inline const ComplexScalar operator()(const ComplexScalar& x) const { inline const Scalar operator()(const Scalar& x) const {
using std::ldexp; using std::ldexp;
return ComplexScalar(ldexp(x.real(), -m_squarings), ldexp(x.imag(), -m_squarings)); return ldexp(x, -m_squarings);
} }
private: private:
@ -220,6 +232,7 @@ struct matrix_exp_computeUV {
template <typename MatrixType> template <typename MatrixType>
struct matrix_exp_computeUV<MatrixType, float> { struct matrix_exp_computeUV<MatrixType, float> {
using Scalar = typename traits<MatrixType>::Scalar;
template <typename ArgType> template <typename ArgType>
static void run(const ArgType& arg, MatrixType& U, MatrixType& V, int& squarings) { static void run(const ArgType& arg, MatrixType& U, MatrixType& V, int& squarings) {
using std::frexp; using std::frexp;
@ -234,7 +247,7 @@ struct matrix_exp_computeUV<MatrixType, float> {
const float maxnorm = 3.925724783138660f; const float maxnorm = 3.925724783138660f;
frexp(l1norm / maxnorm, &squarings); frexp(l1norm / maxnorm, &squarings);
if (squarings < 0) squarings = 0; if (squarings < 0) squarings = 0;
MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<float>(squarings)); MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<Scalar>(squarings));
matrix_exp_pade7(A, U, V); matrix_exp_pade7(A, U, V);
} }
} }
@ -242,12 +255,12 @@ struct matrix_exp_computeUV<MatrixType, float> {
template <typename MatrixType> template <typename MatrixType>
struct matrix_exp_computeUV<MatrixType, double> { struct matrix_exp_computeUV<MatrixType, double> {
typedef typename NumTraits<typename traits<MatrixType>::Scalar>::Real RealScalar; using Scalar = typename traits<MatrixType>::Scalar;
template <typename ArgType> template <typename ArgType>
static void run(const ArgType& arg, MatrixType& U, MatrixType& V, int& squarings) { static void run(const ArgType& arg, MatrixType& U, MatrixType& V, int& squarings) {
using std::frexp; using std::frexp;
using std::pow; using std::pow;
const RealScalar l1norm = arg.cwiseAbs().colwise().sum().maxCoeff(); const double l1norm = arg.cwiseAbs().colwise().sum().maxCoeff();
squarings = 0; squarings = 0;
if (l1norm < 1.495585217958292e-002) { if (l1norm < 1.495585217958292e-002) {
matrix_exp_pade3(arg, U, V); matrix_exp_pade3(arg, U, V);
@ -258,10 +271,10 @@ struct matrix_exp_computeUV<MatrixType, double> {
} else if (l1norm < 2.097847961257068e+000) { } else if (l1norm < 2.097847961257068e+000) {
matrix_exp_pade9(arg, U, V); matrix_exp_pade9(arg, U, V);
} else { } else {
const RealScalar maxnorm = 5.371920351148152; const double maxnorm = 5.371920351148152;
frexp(l1norm / maxnorm, &squarings); frexp(l1norm / maxnorm, &squarings);
if (squarings < 0) squarings = 0; if (squarings < 0) squarings = 0;
MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<RealScalar>(squarings)); MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<Scalar>(squarings));
matrix_exp_pade13(A, U, V); matrix_exp_pade13(A, U, V);
} }
} }
@ -271,6 +284,7 @@ template <typename MatrixType>
struct matrix_exp_computeUV<MatrixType, long double> { struct matrix_exp_computeUV<MatrixType, long double> {
template <typename ArgType> template <typename ArgType>
static void run(const ArgType& arg, MatrixType& U, MatrixType& V, int& squarings) { static void run(const ArgType& arg, MatrixType& U, MatrixType& V, int& squarings) {
using Scalar = typename traits<MatrixType>::Scalar;
#if LDBL_MANT_DIG == 53 // double precision #if LDBL_MANT_DIG == 53 // double precision
matrix_exp_computeUV<MatrixType, double>::run(arg, U, V, squarings); matrix_exp_computeUV<MatrixType, double>::run(arg, U, V, squarings);
@ -295,7 +309,7 @@ struct matrix_exp_computeUV<MatrixType, long double> {
const long double maxnorm = 4.0246098906697353063L; const long double maxnorm = 4.0246098906697353063L;
frexp(l1norm / maxnorm, &squarings); frexp(l1norm / maxnorm, &squarings);
if (squarings < 0) squarings = 0; if (squarings < 0) squarings = 0;
MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<long double>(squarings)); MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<Scalar>(squarings));
matrix_exp_pade13(A, U, V); matrix_exp_pade13(A, U, V);
} }
@ -315,7 +329,7 @@ struct matrix_exp_computeUV<MatrixType, long double> {
const long double maxnorm = 3.2579440895405400856599663723517L; const long double maxnorm = 3.2579440895405400856599663723517L;
frexp(l1norm / maxnorm, &squarings); frexp(l1norm / maxnorm, &squarings);
if (squarings < 0) squarings = 0; if (squarings < 0) squarings = 0;
MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<long double>(squarings)); MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<Scalar>(squarings));
matrix_exp_pade17(A, U, V); matrix_exp_pade17(A, U, V);
} }
@ -335,7 +349,7 @@ struct matrix_exp_computeUV<MatrixType, long double> {
const long double maxnorm = 2.884233277829519311757165057717815L; const long double maxnorm = 2.884233277829519311757165057717815L;
frexp(l1norm / maxnorm, &squarings); frexp(l1norm / maxnorm, &squarings);
if (squarings < 0) squarings = 0; if (squarings < 0) squarings = 0;
MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<long double>(squarings)); MatrixType A = arg.unaryExpr(MatrixExponentialScalingOp<Scalar>(squarings));
matrix_exp_pade17(A, U, V); matrix_exp_pade17(A, U, V);
} }
@ -382,9 +396,7 @@ template <typename ArgType, typename ResultType>
void matrix_exp_compute(const ArgType& arg, ResultType& result, false_type) // default void matrix_exp_compute(const ArgType& arg, ResultType& result, false_type) // default
{ {
typedef typename ArgType::PlainObject MatrixType; typedef typename ArgType::PlainObject MatrixType;
typedef typename traits<MatrixType>::Scalar Scalar; typedef make_complex_t<typename traits<MatrixType>::Scalar> ComplexScalar;
typedef typename NumTraits<Scalar>::Real RealScalar;
typedef typename std::complex<RealScalar> ComplexScalar;
result = arg.matrixFunction(internal::stem_function_exp<ComplexScalar>); result = arg.matrixFunction(internal::stem_function_exp<ComplexScalar>);
} }

View File

@ -382,7 +382,7 @@ struct matrix_function_compute<MatrixType, 0> {
static const int Rows = Traits::RowsAtCompileTime, Cols = Traits::ColsAtCompileTime; static const int Rows = Traits::RowsAtCompileTime, Cols = Traits::ColsAtCompileTime;
static const int MaxRows = Traits::MaxRowsAtCompileTime, MaxCols = Traits::MaxColsAtCompileTime; static const int MaxRows = Traits::MaxRowsAtCompileTime, MaxCols = Traits::MaxColsAtCompileTime;
typedef std::complex<Scalar> ComplexScalar; typedef internal::make_complex_t<Scalar> ComplexScalar;
typedef Matrix<ComplexScalar, Rows, Cols, 0, MaxRows, MaxCols> ComplexMatrix; typedef Matrix<ComplexScalar, Rows, Cols, 0, MaxRows, MaxCols> ComplexMatrix;
ComplexMatrix CA = A.template cast<ComplexScalar>(); ComplexMatrix CA = A.template cast<ComplexScalar>();
@ -476,7 +476,7 @@ class MatrixFunctionReturnValue : public ReturnByValue<MatrixFunctionReturnValue
typedef typename internal::nested_eval<Derived, 10>::type NestedEvalType; typedef typename internal::nested_eval<Derived, 10>::type NestedEvalType;
typedef internal::remove_all_t<NestedEvalType> NestedEvalTypeClean; typedef internal::remove_all_t<NestedEvalType> NestedEvalTypeClean;
typedef internal::traits<NestedEvalTypeClean> Traits; typedef internal::traits<NestedEvalTypeClean> Traits;
typedef std::complex<typename NumTraits<Scalar>::Real> ComplexScalar; typedef internal::make_complex_t<Scalar> ComplexScalar;
typedef Matrix<ComplexScalar, Dynamic, Dynamic, 0, Traits::RowsAtCompileTime, Traits::ColsAtCompileTime> typedef Matrix<ComplexScalar, Dynamic, Dynamic, 0, Traits::RowsAtCompileTime, Traits::ColsAtCompileTime>
DynMatrixType; DynMatrixType;

View File

@ -330,7 +330,7 @@ class MatrixLogarithmReturnValue : public ReturnByValue<MatrixLogarithmReturnVal
typedef typename internal::nested_eval<Derived, 10>::type DerivedEvalType; typedef typename internal::nested_eval<Derived, 10>::type DerivedEvalType;
typedef internal::remove_all_t<DerivedEvalType> DerivedEvalTypeClean; typedef internal::remove_all_t<DerivedEvalType> DerivedEvalTypeClean;
typedef internal::traits<DerivedEvalTypeClean> Traits; typedef internal::traits<DerivedEvalTypeClean> Traits;
typedef std::complex<typename NumTraits<Scalar>::Real> ComplexScalar; typedef internal::make_complex_t<Scalar> ComplexScalar;
typedef Matrix<ComplexScalar, Dynamic, Dynamic, 0, Traits::RowsAtCompileTime, Traits::ColsAtCompileTime> typedef Matrix<ComplexScalar, Dynamic, Dynamic, 0, Traits::RowsAtCompileTime, Traits::ColsAtCompileTime>
DynMatrixType; DynMatrixType;
typedef internal::MatrixLogarithmAtomic<DynMatrixType> AtomicType; typedef internal::MatrixLogarithmAtomic<DynMatrixType> AtomicType;

View File

@ -91,7 +91,7 @@ class MatrixPowerAtomic : internal::noncopyable {
enum { RowsAtCompileTime = MatrixType::RowsAtCompileTime, MaxRowsAtCompileTime = MatrixType::MaxRowsAtCompileTime }; enum { RowsAtCompileTime = MatrixType::RowsAtCompileTime, MaxRowsAtCompileTime = MatrixType::MaxRowsAtCompileTime };
typedef typename MatrixType::Scalar Scalar; typedef typename MatrixType::Scalar Scalar;
typedef typename MatrixType::RealScalar RealScalar; typedef typename MatrixType::RealScalar RealScalar;
typedef std::complex<RealScalar> ComplexScalar; typedef internal::make_complex_t<Scalar> ComplexScalar;
typedef Block<MatrixType, Dynamic, Dynamic> ResultType; typedef Block<MatrixType, Dynamic, Dynamic> ResultType;
const MatrixType& m_A; const MatrixType& m_A;
@ -380,7 +380,7 @@ class MatrixPower : internal::noncopyable {
Index cols() const { return m_A.cols(); } Index cols() const { return m_A.cols(); }
private: private:
typedef std::complex<RealScalar> ComplexScalar; typedef internal::make_complex_t<Scalar> ComplexScalar;
typedef Matrix<ComplexScalar, Dynamic, Dynamic, 0, MatrixType::RowsAtCompileTime, MatrixType::ColsAtCompileTime> typedef Matrix<ComplexScalar, Dynamic, Dynamic, 0, MatrixType::RowsAtCompileTime, MatrixType::ColsAtCompileTime>
ComplexMatrix; ComplexMatrix;
@ -628,7 +628,7 @@ template <typename Derived>
class MatrixComplexPowerReturnValue : public ReturnByValue<MatrixComplexPowerReturnValue<Derived> > { class MatrixComplexPowerReturnValue : public ReturnByValue<MatrixComplexPowerReturnValue<Derived> > {
public: public:
typedef typename Derived::PlainObject PlainObject; typedef typename Derived::PlainObject PlainObject;
typedef typename std::complex<typename Derived::RealScalar> ComplexScalar; typedef internal::make_complex_t<typename Derived::Scalar> ComplexScalar;
/** /**
* \brief Constructor. * \brief Constructor.
@ -685,7 +685,7 @@ const MatrixPowerReturnValue<Derived> MatrixBase<Derived>::pow(const RealScalar&
} }
template <typename Derived> template <typename Derived>
const MatrixComplexPowerReturnValue<Derived> MatrixBase<Derived>::pow(const std::complex<RealScalar>& p) const { const MatrixComplexPowerReturnValue<Derived> MatrixBase<Derived>::pow(const internal::make_complex_t<Scalar>& p) const {
return MatrixComplexPowerReturnValue<Derived>(derived(), p); return MatrixComplexPowerReturnValue<Derived>(derived(), p);
} }

View File

@ -35,7 +35,7 @@ class PolynomialSolverBase {
typedef Scalar_ Scalar; typedef Scalar_ Scalar;
typedef typename NumTraits<Scalar>::Real RealScalar; typedef typename NumTraits<Scalar>::Real RealScalar;
typedef std::complex<RealScalar> RootType; typedef internal::make_complex_t<Scalar> RootType;
typedef Matrix<RootType, Deg_, 1> RootsType; typedef Matrix<RootType, Deg_, 1> RootsType;
typedef DenseIndex Index; typedef DenseIndex Index;
@ -308,7 +308,7 @@ class PolynomialSolver : public PolynomialSolverBase<Scalar_, Deg_> {
typedef std::conditional_t<NumTraits<Scalar>::IsComplex, ComplexEigenSolver<CompanionMatrixType>, typedef std::conditional_t<NumTraits<Scalar>::IsComplex, ComplexEigenSolver<CompanionMatrixType>,
EigenSolver<CompanionMatrixType> > EigenSolver<CompanionMatrixType> >
EigenSolverType; EigenSolverType;
typedef std::conditional_t<NumTraits<Scalar>::IsComplex, Scalar, std::complex<Scalar> > ComplexScalar; typedef internal::make_complex_t<Scalar_> ComplexScalar;
public: public:
/** Computes the complex roots of a new polynomial. */ /** Computes the complex roots of a new polynomial. */