Device implementation of log for std::complex types.

(cherry picked from commit 7e6a1c129c201db4eff46f4dd68acdc7e935eaf2)
This commit is contained in:
Nathan Luehr 2021-04-19 18:05:27 -05:00 committed by Rasmus Munk Larsen
parent d9288f078d
commit d1825cbb68
2 changed files with 36 additions and 3 deletions

View File

@ -2,6 +2,7 @@
// for linear algebra.
//
// Copyright (C) 2006-2010 Benoit Jacob <jacob.benoit.1@gmail.com>
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
@ -687,6 +688,30 @@ struct expm1_retval
typedef Scalar type;
};
/****************************************************************************
* Implementation of log *
****************************************************************************/
// Complex log defined in MathFunctionsImpl.h.
template<typename T> EIGEN_DEVICE_FUNC std::complex<T> complex_log(const std::complex<T>& z);
template<typename Scalar>
struct log_impl {
EIGEN_DEVICE_FUNC static inline Scalar run(const Scalar& x)
{
EIGEN_USING_STD(log);
return static_cast<Scalar>(log(x));
}
};
template<typename Scalar>
struct log_impl<std::complex<Scalar> > {
EIGEN_DEVICE_FUNC static inline std::complex<Scalar> run(const std::complex<Scalar>& z)
{
return complex_log(z);
}
};
/****************************************************************************
* Implementation of log1p *
****************************************************************************/
@ -700,7 +725,7 @@ namespace std_fallback {
typedef typename NumTraits<Scalar>::Real RealScalar;
EIGEN_USING_STD(log);
Scalar x1p = RealScalar(1) + x;
Scalar log_1p = log(x1p);
Scalar log_1p = log_impl<Scalar>::run(x1p);
const bool is_small = numext::equal_strict(x1p, Scalar(1));
const bool is_inf = numext::equal_strict(x1p, log_1p);
return (is_small || is_inf) ? x : x * (log_1p / (x1p - RealScalar(1)));
@ -1460,8 +1485,7 @@ T rsqrt(const T& x)
template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T log(const T &x) {
EIGEN_USING_STD(log);
return static_cast<T>(log(x));
return internal::log_impl<T>::run(x);
}
#if defined(SYCL_DEVICE_ONLY)

View File

@ -184,6 +184,15 @@ EIGEN_DEVICE_FUNC std::complex<T> complex_rsqrt(const std::complex<T>& z) {
: std::complex<T>(numext::abs(y) / (2 * w * abs_z), y < zero ? woz : -woz );
}
template<typename T>
EIGEN_DEVICE_FUNC std::complex<T> complex_log(const std::complex<T>& z) {
// Computes complex log.
T a = numext::abs(z);
EIGEN_USING_STD(atan2);
T b = atan2(z.imag(), z.real());
return std::complex<T>(numext::log(a), b);
}
} // end namespace internal
} // end namespace Eigen