Implemented some of the missing type casting for half floats

This commit is contained in:
Benoit Steiner 2016-03-17 21:45:45 -07:00
parent 95b8961a9b
commit 7b98de1f15

View File

@ -24,8 +24,7 @@ struct scalar_cast_op<float, half> {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
return __float2half(a);
#else
assert(false && "tbd");
return half();
return half(a);
#endif
}
};
@ -43,8 +42,7 @@ struct scalar_cast_op<int, half> {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
return __float2half(static_cast<float>(a));
#else
assert(false && "tbd");
return half();
return half(static_cast<float>(a));
#endif
}
};
@ -62,8 +60,7 @@ struct scalar_cast_op<half, float> {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
return __half2float(a);
#else
assert(false && "tbd");
return 0.0f;
return static_cast<float>(a);
#endif
}
};