Better CUDA complex division.

The original produced NaNs when dividing 0/b for subnormal b.
The `complex_divide_stable` was changed to use the more common
Smith's algorithm.
This commit is contained in:
Antonio Sanchez 2021-04-23 16:04:01 -07:00 committed by Rasmus Munk Larsen
parent 172db7bfc3
commit 1c013be2cc

View File

@ -67,27 +67,26 @@ std::complex<T> complex_divide_fast(const std::complex<T>& a, const std::complex
const T a_imag = numext::imag(a);
const T b_real = numext::real(b);
const T b_imag = numext::imag(b);
const T norm = T(1) / (b_real * b_real + b_imag * b_imag);
return std::complex<T>((a_real * b_real + a_imag * b_imag) * norm,
(a_imag * b_real - a_real * b_imag) * norm);
const T norm = (b_real * b_real + b_imag * b_imag);
return std::complex<T>((a_real * b_real + a_imag * b_imag) / norm,
(a_imag * b_real - a_real * b_imag) / norm);
}
template<typename T>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
std::complex<T> complex_divide_stable(const std::complex<T>& a, const std::complex<T>& b) {
const T a_real = numext::real(a);
const T a_imag = numext::imag(a);
const T b_real = numext::real(b);
const T b_imag = numext::imag(b);
// Guard against over/under-flow.
const T scale = T(1) / (numext::abs(b_real) + numext::abs(b_imag));
const T a_real_scaled = numext::real(a) * scale;
const T a_imag_scaled = numext::imag(a) * scale;
const T b_real_scaled = b_real * scale;
const T b_imag_scaled = b_imag * scale;
const T b_norm2_scaled = b_real_scaled * b_real_scaled + b_imag_scaled * b_imag_scaled;
return std::complex<T>(
(a_real_scaled * b_real_scaled + a_imag_scaled * b_imag_scaled) / b_norm2_scaled,
(a_imag_scaled * b_real_scaled - a_real_scaled * b_imag_scaled) / b_norm2_scaled);
// Smith's complex division (https://arxiv.org/pdf/1210.4539.pdf),
// guards against over/under-flow.
const bool scale_imag = numext::abs(b_imag) <= numext::abs(b_real);
const T rscale = scale_imag ? T(1) : b_real / b_imag;
const T iscale = scale_imag ? b_imag / b_real : T(1);
const T denominator = b_real * rscale + b_imag * iscale;
return std::complex<T>((a_real * rscale + a_imag * iscale) / denominator,
(a_imag * rscale - a_real * iscale) / denominator);
}
template<typename T>