mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 18:54:00 +08:00
Enable partial support for half floats on Kepler GPUs.
This commit is contained in:
parent
1da10a7358
commit
1032441c6f
@ -17,8 +17,10 @@
|
|||||||
// we'll use on the host side (SSE, AVX, ...)
|
// we'll use on the host side (SSE, AVX, ...)
|
||||||
#if defined(__CUDACC__) && defined(EIGEN_USE_GPU)
|
#if defined(__CUDACC__) && defined(EIGEN_USE_GPU)
|
||||||
|
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
|
||||||
|
|
||||||
|
// The following operations require arch >= 5.3
|
||||||
|
#if __CUDA_ARCH__ >= 530
|
||||||
__device__ half operator + (const half& a, const half& b) {
|
__device__ half operator + (const half& a, const half& b) {
|
||||||
return __hadd(a, b);
|
return __hadd(a, b);
|
||||||
}
|
}
|
||||||
@ -60,6 +62,7 @@ __device__ half abs(const half& a) {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace Eigen {
|
namespace Eigen {
|
||||||
namespace internal {
|
namespace internal {
|
||||||
@ -98,8 +101,79 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pset1<half2>(const half&
|
|||||||
return __half2half2(from);
|
return __half2half2(from);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pload<half2>(const half* from) {
|
||||||
|
return *reinterpret_cast<const half2*>(from);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 ploadu<half2>(const half* from) {
|
||||||
|
return __halves2half2(from[0], from[1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_STRONG_INLINE half2 ploaddup<half2>(const half* from) {
|
||||||
|
return __halves2half2(from[0], from[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstore<half>(half* to, const half2& from) {
|
||||||
|
*reinterpret_cast<half2*>(to) = from;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstoreu<half>(half* to, const half2& from) {
|
||||||
|
to[0] = __low2half(from);
|
||||||
|
to[1] = __high2half(from);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE half2 ploadt_ro<half2, Aligned>(const half* from) {
|
||||||
|
#if __CUDA_ARCH__ >= 320
|
||||||
|
return __ldg((const half2*)from);
|
||||||
|
#else
|
||||||
|
return __halves2half2(*(from+0), *(from+1));
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
template<>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE half2 ploadt_ro<half2, Unaligned>(const half* from) {
|
||||||
|
#if __CUDA_ARCH__ >= 320
|
||||||
|
return __halves2half2(__ldg(from+0), __ldg(from+1));
|
||||||
|
#else
|
||||||
|
return __halves2half2(*(from+0), *(from+1));
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_DEVICE_FUNC inline half2 pgather<half, half2>(const half* from, Index stride) {
|
||||||
|
return __halves2half2(from[0*stride], from[1*stride]);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_DEVICE_FUNC inline void pscatter<half, half2>(half* to, const half2& from, Index stride) {
|
||||||
|
to[stride*0] = __low2half(from);
|
||||||
|
to[stride*1] = __high2half(from);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_DEVICE_FUNC inline half pfirst<half2>(const half2& a) {
|
||||||
|
return __low2half(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_DEVICE_FUNC inline half2 pabs<half2>(const half2& a) {
|
||||||
|
half2 result;
|
||||||
|
result.x = a.x & 0x7FFF7FFF;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC inline void
|
||||||
|
ptranspose(PacketBlock<half2,2>& kernel) {
|
||||||
|
half a1 = __low2half(kernel.packet[0]);
|
||||||
|
half a2 = __high2half(kernel.packet[0]);
|
||||||
|
half b1 = __low2half(kernel.packet[1]);
|
||||||
|
half b2 = __high2half(kernel.packet[1]);
|
||||||
|
kernel.packet[0] = __halves2half2(a1, b1);
|
||||||
|
kernel.packet[1] = __halves2half2(a2, b2);
|
||||||
|
}
|
||||||
|
|
||||||
|
// The following operations require arch >= 5.3
|
||||||
|
#if __CUDA_ARCH__ >= 530
|
||||||
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 plset<half2>(const half& a) {
|
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 plset<half2>(const half& a) {
|
||||||
return __halves2half2(a, __hadd(a, __float2half(1)));
|
return __halves2half2(a, __hadd(a, __float2half(1.0f)));
|
||||||
}
|
}
|
||||||
|
|
||||||
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 padd<half2>(const half2& a, const half2& b) {
|
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 padd<half2>(const half2& a, const half2& b) {
|
||||||
@ -154,50 +228,6 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmax<half2>(const half2&
|
|||||||
return __halves2half2(r1, r2);
|
return __halves2half2(r1, r2);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pload<half2>(const half* from) {
|
|
||||||
return *reinterpret_cast<const half2*>(from);
|
|
||||||
}
|
|
||||||
|
|
||||||
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 ploadu<half2>(const half* from) {
|
|
||||||
return __halves2half2(from[0], from[1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE half2 ploaddup<half2>(const half* from) {
|
|
||||||
return __halves2half2(from[0], from[0]);
|
|
||||||
}
|
|
||||||
|
|
||||||
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstore<half>(half* to, const half2& from) {
|
|
||||||
*reinterpret_cast<half2*>(to) = from;
|
|
||||||
}
|
|
||||||
|
|
||||||
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstoreu<half>(half* to, const half2& from) {
|
|
||||||
to[0] = __low2half(from);
|
|
||||||
to[1] = __high2half(from);
|
|
||||||
}
|
|
||||||
|
|
||||||
template<>
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE half2 ploadt_ro<half2, Aligned>(const half* from) {
|
|
||||||
return __ldg((const half2*)from);
|
|
||||||
}
|
|
||||||
|
|
||||||
template<>
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE half2 ploadt_ro<half2, Unaligned>(const half* from) {
|
|
||||||
return __halves2half2(__ldg(from+0), __ldg(from+1));
|
|
||||||
}
|
|
||||||
|
|
||||||
template<> EIGEN_DEVICE_FUNC inline half2 pgather<half, half2>(const half* from, Index stride) {
|
|
||||||
return __halves2half2(from[0*stride], from[1*stride]);
|
|
||||||
}
|
|
||||||
|
|
||||||
template<> EIGEN_DEVICE_FUNC inline void pscatter<half, half2>(half* to, const half2& from, Index stride) {
|
|
||||||
to[stride*0] = __low2half(from);
|
|
||||||
to[stride*1] = __high2half(from);
|
|
||||||
}
|
|
||||||
|
|
||||||
template<> EIGEN_DEVICE_FUNC inline half pfirst<half2>(const half2& a) {
|
|
||||||
return __low2half(a);
|
|
||||||
}
|
|
||||||
|
|
||||||
template<> EIGEN_DEVICE_FUNC inline half predux<half2>(const half2& a) {
|
template<> EIGEN_DEVICE_FUNC inline half predux<half2>(const half2& a) {
|
||||||
return __hadd(__low2half(a), __high2half(a));
|
return __hadd(__low2half(a), __high2half(a));
|
||||||
}
|
}
|
||||||
@ -217,23 +247,7 @@ template<> EIGEN_DEVICE_FUNC inline half predux_min<half2>(const half2& a) {
|
|||||||
template<> EIGEN_DEVICE_FUNC inline half predux_mul<half2>(const half2& a) {
|
template<> EIGEN_DEVICE_FUNC inline half predux_mul<half2>(const half2& a) {
|
||||||
return __hmul(__low2half(a), __high2half(a));
|
return __hmul(__low2half(a), __high2half(a));
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
template<> EIGEN_DEVICE_FUNC inline half2 pabs<half2>(const half2& a) {
|
|
||||||
half2 result;
|
|
||||||
result.x = a.x & 0x7FFF7FFF;
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC inline void
|
|
||||||
ptranspose(PacketBlock<half2,2>& kernel) {
|
|
||||||
half a1 = __low2half(kernel.packet[0]);
|
|
||||||
half a2 = __high2half(kernel.packet[0]);
|
|
||||||
half b1 = __low2half(kernel.packet[1]);
|
|
||||||
half b2 = __high2half(kernel.packet[1]);
|
|
||||||
kernel.packet[0] = __halves2half2(a1, b1);
|
|
||||||
kernel.packet[1] = __halves2half2(a2, b2);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user