Completed the implementation of vectorized type casting of half floats.

This commit is contained in:
Benoit Steiner 2016-03-18 13:36:28 -07:00
parent 7bd551b3a9
commit 134d750eab

View File

@ -87,8 +87,16 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pcast<half2, float4>(con
float2 r2 = __half22float2(b); float2 r2 = __half22float2(b);
return make_float4(r1.x, r1.y, r2.x, r2.y); return make_float4(r1.x, r1.y, r2.x, r2.y);
#else #else
assert(false && "tbd"); half r1;
return float4(); r1.x = a.x & 0xFFFF;
half r2;
r2.x = (a.x & 0xFFFF0000) >> 16;
half r3;
r3.x = b.x & 0xFFFF;
half r4;
r4.x = (b.x & 0xFFFF0000) >> 16;
return make_float4(static_cast<float>(r1), static_cast<float>(r2),
static_cast<float>(r3), static_cast<float>(r4));
#endif #endif
} }
@ -106,8 +114,13 @@ template<> EIGEN_STRONG_INLINE half2 pcast<float4, half2>(const float4& a) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
return __float22half2_rn(make_float2(a.x, a.y)); return __float22half2_rn(make_float2(a.x, a.y));
#else #else
assert(false && "tbd"); half r1 = a.x;
return half2(); half r2 = a.y;
half2 r;
r.x = 0;
r.x |= r1.x;
r.x |= (static_cast<unsigned int>(r2.x) << 16);
return r;
#endif #endif
} }