Eigen/src/Core/arch/Default/Half.h: Eigen::half overloads of the warp shuffle
intrinsics (the section guarded by
EIGEN_HAS_CUDA_FP16 && EIGEN_CUDA_SDK_VER >= 90000).

In __shfl_sync, __shfl_up_sync, __shfl_down_sync and __shfl_xor_sync the
argument `var` is now copied into a `const __half` local through an implicit
conversion, and that local is passed to the inner shuffle call. This replaces
the inline `static_cast<__half>(var)`. The shuffled value is still converted
back with `static_cast<Eigen::half>`. Function names, parameter lists and the
`width=warpSize` defaults are unchanged.

NOTE(review): the patch itself does not state why the explicit cast was
dropped; presumably some toolchain rejected or mis-resolved it. Confirm
against the upstream commit message.

NOTE(review): this patch was recovered from a copy in which all line breaks
and the `<Eigen::half>` template arguments had been stripped. Blank context
lines and two-space indentation were restored so that the hunk matches its
header counts (19 old lines, 23 new lines); verify against the upstream file
before applying. The truncated function name in the hunk header is left as
found.

diff --git a/Eigen/src/Core/arch/Default/Half.h b/Eigen/src/Core/arch/Default/Half.h
index 3eb91a091..eb3030aa7 100644
--- a/Eigen/src/Core/arch/Default/Half.h
+++ b/Eigen/src/Core/arch/Default/Half.h
@@ -848,19 +848,23 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast(c
 #if defined(EIGEN_HAS_CUDA_FP16) && EIGEN_CUDA_SDK_VER >= 90000
 
 __device__ EIGEN_STRONG_INLINE Eigen::half __shfl_sync(unsigned mask, Eigen::half var, int srcLane, int width=warpSize) {
-  return static_cast<Eigen::half>(__shfl_sync(mask, static_cast<__half>(var), srcLane, width));
+  const __half h = var;
+  return static_cast<Eigen::half>(__shfl_sync(mask, h, srcLane, width));
 }
 
 __device__ EIGEN_STRONG_INLINE Eigen::half __shfl_up_sync(unsigned mask, Eigen::half var, unsigned int delta, int width=warpSize) {
-  return static_cast<Eigen::half>(__shfl_up_sync(mask, static_cast<__half>(var), delta, width));
+  const __half h = var;
+  return static_cast<Eigen::half>(__shfl_up_sync(mask, h, delta, width));
 }
 
 __device__ EIGEN_STRONG_INLINE Eigen::half __shfl_down_sync(unsigned mask, Eigen::half var, unsigned int delta, int width=warpSize) {
-  return static_cast<Eigen::half>(__shfl_down_sync(mask, static_cast<__half>(var), delta, width));
+  const __half h = var;
+  return static_cast<Eigen::half>(__shfl_down_sync(mask, h, delta, width));
 }
 
 __device__ EIGEN_STRONG_INLINE Eigen::half __shfl_xor_sync(unsigned mask, Eigen::half var, int laneMask, int width=warpSize) {
-  return static_cast<Eigen::half>(__shfl_xor_sync(mask, static_cast<__half>(var), laneMask, width));
+  const __half h = var;
+  return static_cast<Eigen::half>(__shfl_xor_sync(mask, h, laneMask, width));
 }
 
 #else // HIP or CUDA SDK < 9.0