Made the fp16 code more portable.

This commit is contained in:
Benoit Steiner 2016-04-06 13:44:08 -07:00
parent cf7e73addd
commit 58c1dbff19

View File

@ -55,9 +55,9 @@ namespace Eigen {
namespace internal { namespace internal {
static inline EIGEN_DEVICE_FUNC __half raw_uint16_to_half(unsigned short x); static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half raw_uint16_to_half(unsigned short x);
static inline EIGEN_DEVICE_FUNC __half float_to_half_rtne(float ff); static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half float_to_half_rtne(float ff);
static inline EIGEN_DEVICE_FUNC float half_to_float(__half h); static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half h);
} // end namespace internal } // end namespace internal
@ -192,55 +192,55 @@ __device__ bool operator >= (const half& a, const half& b) {
// Definitions for CPUs and older CUDA, mostly working through conversion // Definitions for CPUs and older CUDA, mostly working through conversion
// to/from fp32. // to/from fp32.
static inline EIGEN_DEVICE_FUNC half operator + (const half& a, const half& b) { static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator + (const half& a, const half& b) {
return half(float(a) + float(b)); return half(float(a) + float(b));
} }
static inline EIGEN_DEVICE_FUNC half operator * (const half& a, const half& b) { static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator * (const half& a, const half& b) {
return half(float(a) * float(b)); return half(float(a) * float(b));
} }
static inline EIGEN_DEVICE_FUNC half operator - (const half& a, const half& b) { static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator - (const half& a, const half& b) {
return half(float(a) - float(b)); return half(float(a) - float(b));
} }
static inline EIGEN_DEVICE_FUNC half operator / (const half& a, const half& b) { static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator / (const half& a, const half& b) {
return half(float(a) / float(b)); return half(float(a) / float(b));
} }
static inline EIGEN_DEVICE_FUNC half operator - (const half& a) { static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator - (const half& a) {
half result; half result;
result.x = a.x ^ 0x8000; result.x = a.x ^ 0x8000;
return result; return result;
} }
static inline EIGEN_DEVICE_FUNC half& operator += (half& a, const half& b) { static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator += (half& a, const half& b) {
a = half(float(a) + float(b)); a = half(float(a) + float(b));
return a; return a;
} }
static inline EIGEN_DEVICE_FUNC half& operator *= (half& a, const half& b) { static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator *= (half& a, const half& b) {
a = half(float(a) * float(b)); a = half(float(a) * float(b));
return a; return a;
} }
static inline EIGEN_DEVICE_FUNC half& operator -= (half& a, const half& b) { static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator -= (half& a, const half& b) {
a = half(float(a) - float(b)); a = half(float(a) - float(b));
return a; return a;
} }
static inline EIGEN_DEVICE_FUNC half& operator /= (half& a, const half& b) { static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator /= (half& a, const half& b) {
a = half(float(a) / float(b)); a = half(float(a) / float(b));
return a; return a;
} }
static inline EIGEN_DEVICE_FUNC bool operator == (const half& a, const half& b) { static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator == (const half& a, const half& b) {
return float(a) == float(b); return float(a) == float(b);
} }
static inline EIGEN_DEVICE_FUNC bool operator != (const half& a, const half& b) { static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator != (const half& a, const half& b) {
return float(a) != float(b); return float(a) != float(b);
} }
static inline EIGEN_DEVICE_FUNC bool operator < (const half& a, const half& b) { static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator < (const half& a, const half& b) {
return float(a) < float(b); return float(a) < float(b);
} }
static inline EIGEN_DEVICE_FUNC bool operator <= (const half& a, const half& b) { static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator <= (const half& a, const half& b) {
return float(a) <= float(b); return float(a) <= float(b);
} }
static inline EIGEN_DEVICE_FUNC bool operator > (const half& a, const half& b) { static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator > (const half& a, const half& b) {
return float(a) > float(b); return float(a) > float(b);
} }
static inline EIGEN_DEVICE_FUNC bool operator >= (const half& a, const half& b) { static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator >= (const half& a, const half& b) {
return float(a) >= float(b); return float(a) >= float(b);
} }
@ -248,7 +248,7 @@ static inline EIGEN_DEVICE_FUNC bool operator >= (const half& a, const half& b)
// Division by an index. Do it in full float precision to avoid accuracy // Division by an index. Do it in full float precision to avoid accuracy
// issues in converting the denominator to half. // issues in converting the denominator to half.
static inline EIGEN_DEVICE_FUNC half operator / (const half& a, Index b) { static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator / (const half& a, Index b) {
return Eigen::half(static_cast<float>(a) / static_cast<float>(b)); return Eigen::half(static_cast<float>(a) / static_cast<float>(b));
} }
@ -259,7 +259,7 @@ static inline EIGEN_DEVICE_FUNC half operator / (const half& a, Index b) {
namespace internal { namespace internal {
static inline EIGEN_DEVICE_FUNC __half raw_uint16_to_half(unsigned short x) { static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half raw_uint16_to_half(unsigned short x) {
__half h; __half h;
h.x = x; h.x = x;
return h; return h;
@ -270,7 +270,7 @@ union FP32 {
float f; float f;
}; };
static inline EIGEN_DEVICE_FUNC __half float_to_half_rtne(float ff) { static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half float_to_half_rtne(float ff) {
#if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300 #if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
return __float2half(ff); return __float2half(ff);
#else #else
@ -318,7 +318,7 @@ static inline EIGEN_DEVICE_FUNC __half float_to_half_rtne(float ff) {
#endif #endif
} }
static inline EIGEN_DEVICE_FUNC float half_to_float(__half h) { static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half h) {
#if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300 #if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
return __half2float(h); return __half2float(h);
#else #else
@ -356,11 +356,11 @@ template<> struct is_arithmetic<half> { enum { value = true }; };
template<> struct NumTraits<Eigen::half> template<> struct NumTraits<Eigen::half>
: GenericNumTraits<Eigen::half> : GenericNumTraits<Eigen::half>
{ {
EIGEN_DEVICE_FUNC static inline float dummy_precision() { return 1e-3f; } EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE float dummy_precision() { return 1e-3f; }
EIGEN_DEVICE_FUNC static inline Eigen::half highest() { EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::half highest() {
return internal::raw_uint16_to_half(0x7bff); return internal::raw_uint16_to_half(0x7bff);
} }
EIGEN_DEVICE_FUNC static inline Eigen::half lowest() { EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Eigen::half lowest() {
return internal::raw_uint16_to_half(0xfbff); return internal::raw_uint16_to_half(0xfbff);
} }
}; };
@ -369,10 +369,10 @@ template<> struct NumTraits<Eigen::half>
namespace numext { namespace numext {
static inline EIGEN_DEVICE_FUNC bool (isinf)(const Eigen::half& a) { static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isinf)(const Eigen::half& a) {
return (a.x & 0x7fff) == 0x7c00; return (a.x & 0x7fff) == 0x7c00;
} }
static inline EIGEN_DEVICE_FUNC bool (isnan)(const Eigen::half& a) { static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isnan)(const Eigen::half& a) {
#if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 #if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __hisnan(a); return __hisnan(a);
#else #else
@ -385,33 +385,33 @@ static inline EIGEN_DEVICE_FUNC bool (isnan)(const Eigen::half& a) {
} // end namespace Eigen } // end namespace Eigen
// Standard mathematical functions and trancendentals. // Standard mathematical functions and trancendentals.
static inline EIGEN_DEVICE_FUNC Eigen::half abs(const Eigen::half& a) { static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half fabsh(const Eigen::half& a) {
Eigen::half result; Eigen::half result;
result.x = a.x & 0x7FFF; result.x = a.x & 0x7FFF;
return result; return result;
} }
static inline EIGEN_DEVICE_FUNC Eigen::half exp(const Eigen::half& a) { static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half exph(const Eigen::half& a) {
return Eigen::half(::expf(float(a))); return Eigen::half(::expf(float(a)));
} }
static inline EIGEN_DEVICE_FUNC Eigen::half log(const Eigen::half& a) { static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half logh(const Eigen::half& a) {
return Eigen::half(::logf(float(a))); return Eigen::half(::logf(float(a)));
} }
static inline EIGEN_DEVICE_FUNC Eigen::half sqrt(const Eigen::half& a) { static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half sqrth(const Eigen::half& a) {
return Eigen::half(::sqrtf(float(a))); return Eigen::half(::sqrtf(float(a)));
} }
static inline EIGEN_DEVICE_FUNC Eigen::half floor(const Eigen::half& a) { static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half floorh(const Eigen::half& a) {
return Eigen::half(::floorf(float(a))); return Eigen::half(::floorf(float(a)));
} }
static inline EIGEN_DEVICE_FUNC Eigen::half ceil(const Eigen::half& a) { static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half ceilh(const Eigen::half& a) {
return Eigen::half(::ceilf(float(a))); return Eigen::half(::ceilf(float(a)));
} }
static inline EIGEN_DEVICE_FUNC bool (isnan)(const Eigen::half& a) { static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC int (isnan)(const Eigen::half& a) {
return (Eigen::numext::isnan)(a); return (Eigen::numext::isnan)(a);
} }
static inline EIGEN_DEVICE_FUNC bool (isinf)(const Eigen::half& a) { static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC int (isinf)(const Eigen::half& a) {
return (Eigen::numext::isinf)(a); return (Eigen::numext::isinf)(a);
} }
static inline EIGEN_DEVICE_FUNC bool (isfinite)(const Eigen::half& a) { static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC int (isfinite)(const Eigen::half& a) {
return !(Eigen::numext::isinf)(a) && !(Eigen::numext::isnan)(a); return !(Eigen::numext::isinf)(a) && !(Eigen::numext::isnan)(a);
} }
@ -420,19 +420,39 @@ namespace std {
// Import the standard mathematical functions and trancendentals into the // Import the standard mathematical functions and trancendentals into the
// into the std namespace. // into the std namespace.
using ::abs; static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half abs(const Eigen::half& a) {
using ::exp; return ::fabsh(a);
using ::log; }
using ::sqrt; static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half exp(const Eigen::half& a) {
using ::floor; return ::exph(a);
using ::ceil; }
using ::isfinite; static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half log(const Eigen::half& a) {
return ::logh(a);
}
static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half sqrt(const Eigen::half& a) {
return ::sqrth(a);
}
static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half floor(const Eigen::half& a) {
return ::floorh(a);
}
static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half ceil(const Eigen::half& a) {
return ::ceilh(a);
}
static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC int (isnan)(const Eigen::half& a) {
return (Eigen::numext::isnan)(a);
}
static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC int (isinf)(const Eigen::half& a) {
return (Eigen::numext::isinf)(a);
}
static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isfinite)(const Eigen::half& a) {
return !(Eigen::numext::isinf)(a) && !(Eigen::numext::isnan)(a);
}
#if __cplusplus > 199711L #if __cplusplus > 199711L
template <> template <>
struct hash<Eigen::half> { struct hash<Eigen::half> {
size_t operator()(const Eigen::half& a) const { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t operator()(const Eigen::half& a) const {
return std::hash<unsigned short>()(a.x); return static_cast<std::size_t>(a.x);
} }
}; };
#endif #endif
@ -442,14 +462,14 @@ struct hash<Eigen::half> {
// Add the missing shfl_xor intrinsic // Add the missing shfl_xor intrinsic
#if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300 #if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
__device__ inline Eigen::half __shfl_xor(Eigen::half var, int laneMask, int width=warpSize) { __device__ EIGEN_STRONG_INLINE Eigen::half __shfl_xor(Eigen::half var, int laneMask, int width=warpSize) {
return static_cast<Eigen::half>(__shfl_xor(static_cast<float>(var), laneMask, width)); return static_cast<Eigen::half>(__shfl_xor(static_cast<float>(var), laneMask, width));
} }
#endif #endif
// ldg() has an overload for __half, but we also need one for Eigen::half. // ldg() has an overload for __half, but we also need one for Eigen::half.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 320 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 320
static inline EIGEN_DEVICE_FUNC Eigen::half __ldg(const Eigen::half* ptr) { static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half __ldg(const Eigen::half* ptr) {
return Eigen::internal::raw_uint16_to_half( return Eigen::internal::raw_uint16_to_half(
__ldg(reinterpret_cast<const unsigned short*>(ptr))); __ldg(reinterpret_cast<const unsigned short*>(ptr)));
} }