Fix CUDA compilation error for pselect<half>.

This commit is contained in:
Rasmus Munk Larsen 2019-06-28 12:07:29 -07:00
parent 74a9dd1102
commit 8053eeb51e

View File

@ -180,8 +180,10 @@ template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pselect<half2>(const half2& mask,
const half2& a,
const half2& b) {
half result_low = __low2half(mask) == __half(0) ? __low2half(b) : __low2half(a);
half result_high = __high2half(mask) == __half(0) ? __high2half(b) : __high2half(a);
half mask_low = __low2half(mask);
half mask_high = __high2half(mask);
half result_low = mask_low == half(0) ? __low2half(b) : __low2half(a);
half result_high = mask_high == half(0) ? __high2half(b) : __high2half(a);
return __halves2half2(result_low, result_high);
}