diff --git a/Eigen/src/Core/arch/CUDA/Complex.h b/Eigen/src/Core/arch/CUDA/Complex.h index b1618e567..deb4c8694 100644 --- a/Eigen/src/Core/arch/CUDA/Complex.h +++ b/Eigen/src/Core/arch/CUDA/Complex.h @@ -67,27 +67,26 @@ std::complex complex_divide_fast(const std::complex& a, const std::complex const T a_imag = numext::imag(a); const T b_real = numext::real(b); const T b_imag = numext::imag(b); - const T norm = T(1) / (b_real * b_real + b_imag * b_imag); - return std::complex((a_real * b_real + a_imag * b_imag) * norm, - (a_imag * b_real - a_real * b_imag) * norm); + const T norm = (b_real * b_real + b_imag * b_imag); + return std::complex((a_real * b_real + a_imag * b_imag) / norm, + (a_imag * b_real - a_real * b_imag) / norm); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex complex_divide_stable(const std::complex& a, const std::complex& b) { + const T a_real = numext::real(a); + const T a_imag = numext::imag(a); const T b_real = numext::real(b); const T b_imag = numext::imag(b); - // Guard against over/under-flow. - const T scale = T(1) / (numext::abs(b_real) + numext::abs(b_imag)); - const T a_real_scaled = numext::real(a) * scale; - const T a_imag_scaled = numext::imag(a) * scale; - const T b_real_scaled = b_real * scale; - const T b_imag_scaled = b_imag * scale; - - const T b_norm2_scaled = b_real_scaled * b_real_scaled + b_imag_scaled * b_imag_scaled; - return std::complex( - (a_real_scaled * b_real_scaled + a_imag_scaled * b_imag_scaled) / b_norm2_scaled, - (a_imag_scaled * b_real_scaled - a_real_scaled * b_imag_scaled) / b_norm2_scaled); + // Smith's complex division (https://arxiv.org/pdf/1210.4539.pdf), + // guards against over/under-flow. + const bool scale_imag = numext::abs(b_imag) <= numext::abs(b_real); + const T rscale = scale_imag ? T(1) : b_real / b_imag; + const T iscale = scale_imag ? b_imag / b_real : T(1); + const T denominator = b_real * rscale + b_imag * iscale; + return std::complex((a_real * rscale + a_imag * iscale) / denominator, + (a_imag * rscale - a_real * iscale) / denominator); } template