Added support for tensor reductions on half floats

This commit is contained in:
Benoit Steiner 2016-02-19 10:05:59 -08:00
parent 5c4901b83a
commit 180156ba1a
3 changed files with 45 additions and 13 deletions

View File

@@ -34,6 +34,26 @@ template<>
 struct functor_traits<scalar_cast_op<float, half> >
 { enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
+
+template<>
+struct scalar_cast_op<int, half> {
+  EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
+  typedef half result_type;
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half operator() (const int& a) const {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
+    return __float2half(static_cast<float>(a));
+#else
+    assert(false && "tbd");
+    return half();
+#endif
+  }
+};
+
+template<>
+struct functor_traits<scalar_cast_op<int, half> >
+{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
+
 template<>
 struct scalar_cast_op<half, float> {
   EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)

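Note: the new scalar_cast_op<int, half> specialization builds a half constant from an int by going through float and the __float2half device intrinsic, since half provides no direct conversion from int; the __CUDA_ARCH__ >= 530 guard restricts that path to compute capability 5.3 and up, the first architecture with native fp16 support. A minimal standalone CUDA sketch of the same guard pattern (the kernel and names below are illustrative, not part of this commit):

#include <cuda_fp16.h>

// Illustrative kernel: seed a half buffer from an int the same way
// scalar_cast_op<int, half> does on supported devices.
__global__ void fill_from_int(__half* out, int value) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  // int -> float -> half, mirroring the operator() added above.
  out[threadIdx.x] = __float2half(static_cast<float>(value));
#else
  // Pre-5.3 devices lack native fp16 support; this branch is a
  // placeholder, just as the fallback above asserts with "tbd".
  out[threadIdx.x] = __half();
#endif
}

int main() {
  __half* d_out = nullptr;
  cudaMalloc(&d_out, 32 * sizeof(__half));
  fill_from_int<<<1, 32>>>(d_out, 1);
  cudaDeviceSynchronize();
  cudaFree(d_out);
  return 0;
}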
View File

@@ -72,11 +72,12 @@ template <typename T> struct SumReducer
   }
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
-    return static_cast<T>(0);
+    internal::scalar_cast_op<int, T> conv;
+    return conv(0);
   }
   template <typename Packet>
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const {
-    return pset1<Packet>(0);
+    return pset1<Packet>(initialize());
   }
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const {
     return accum;
@@ -110,11 +111,12 @@ template <typename T> struct MeanReducer
   }
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
-    return static_cast<T>(0);
+    internal::scalar_cast_op<int, T> conv;
+    return conv(0);
   }
   template <typename Packet>
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const {
-    return pset1<Packet>(0);
+    return pset1<Packet>(initialize());
   }
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const {
     return accum / scalarCount_;
@@ -214,11 +216,12 @@ template <typename T> struct ProdReducer
   }
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
-    return static_cast<T>(1);
+    internal::scalar_cast_op<int, T> conv;
+    return conv(1);
   }
   template <typename Packet>
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const {
-    return pset1<Packet>(1);
+    return pset1<Packet>(initialize());
   }
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const {
     return accum;

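Note: the reducers previously seeded their accumulators with static_cast<T>(0) (or 1 for ProdReducer), which does not compile when T is half because half defines no conversion from int; routing the seed through internal::scalar_cast_op<int, T> picks up the specialization added above, and initializePacket() now delegates to initialize() so packet code takes the same conversion path. A self-contained sketch of that pattern with stand-in names (cast_op and SumReducerSketch are hypothetical, not Eigen's internals):

#include <cassert>
#include <initializer_list>

// Stand-in for internal::scalar_cast_op<From, To>: a customization point
// that types like half can specialize when static_cast is unavailable.
template <typename From, typename To>
struct cast_op {
  To operator()(const From& a) const { return static_cast<To>(a); }
};

template <typename T>
struct SumReducerSketch {
  void reduce(const T& t, T* accum) const { *accum += t; }
  T initialize() const {
    cast_op<int, T> conv;  // specialized for half-like types
    return conv(0);
  }
  T finalize(const T& accum) const { return accum; }
};

int main() {
  SumReducerSketch<float> r;
  float acc = r.initialize();
  for (float v : {1.0f, 2.0f, 3.0f}) r.reduce(v, &acc);
  assert(r.finalize(acc) == 6.0f);
  return 0;
}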
View File

@@ -93,7 +93,6 @@ void test_cuda_elementwise() {
   gpu_device.deallocate(d_res_half);
   gpu_device.deallocate(d_res_float);
 }
-
 /*
 void test_cuda_contractions() {
   Eigen::CudaStreamDevice stream;
@@ -139,7 +138,7 @@ void test_cuda_contractions() {
   gpu_device.deallocate(d_float2);
   gpu_device.deallocate(d_res_half);
   gpu_device.deallocate(d_res_float);
-}
+}*/
 
 void test_cuda_reductions() {
@@ -183,7 +182,7 @@ void test_cuda_reductions() {
   gpu_device.deallocate(d_res_half);
   gpu_device.deallocate(d_res_float);
 }
-*/
+
 
 #endif
@@ -191,9 +190,19 @@ void test_cuda_reductions() {
 void test_cxx11_tensor_of_float16_cuda()
 {
 #ifdef EIGEN_HAS_CUDA_FP16
-  CALL_SUBTEST_1(test_cuda_conversion());
-  CALL_SUBTEST_1(test_cuda_elementwise());
-  // CALL_SUBTEST_2(test_cuda_contractions());
-  // CALL_SUBTEST_3(test_cuda_reductions());
+  Eigen::CudaStreamDevice stream;
+  Eigen::GpuDevice device(&stream);
+  if (device.majorDeviceVersion() > 5 ||
+      (device.majorDeviceVersion() == 5 && device.minorDeviceVersion() >= 3)) {
+    CALL_SUBTEST_1(test_cuda_conversion());
+    CALL_SUBTEST_1(test_cuda_elementwise());
+    // CALL_SUBTEST_2(test_cuda_contractions());
+    CALL_SUBTEST_3(test_cuda_reductions());
+  }
+  else {
+    std::cout << "Half floats require compute capability of at least 5.3. This device only supports " << device.majorDeviceVersion() << "." << device.minorDeviceVersion() << ". Skipping the test" << std::endl;
+  }
+#else
+  std::cout << "Half floats are not supported by this version of cuda: skipping the test" << std::endl;
 #endif
 }
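
Note: with the cast functor and reducer changes in place, a half-precision reduction expression of the following shape runs on devices with compute capability 5.3 or higher. This is a hedged sketch of the pattern the re-enabled test_cuda_reductions exercises; the function name, sizes, and random fill are illustrative:

#define EIGEN_USE_GPU
#include <unsupported/Eigen/CXX11/Tensor>

void reduce_in_half_precision() {
  Eigen::CudaStreamDevice stream;
  Eigen::GpuDevice gpu_device(&stream);

  const int rows = 31, cols = 31;
  float* d_in  = static_cast<float*>(gpu_device.allocate(rows * cols * sizeof(float)));
  float* d_out = static_cast<float*>(gpu_device.allocate(cols * sizeof(float)));

  Eigen::TensorMap<Eigen::Tensor<float, 2> > in(d_in, rows, cols);
  Eigen::TensorMap<Eigen::Tensor<float, 1> > out(d_out, cols);

  in.device(gpu_device) = in.random();  // fill input on the device

  // Cast to half, sum over dimension 0, cast back to float. The sum's
  // half accumulator is seeded via scalar_cast_op<int, half>, which is
  // what this commit enables.
  Eigen::array<int, 1> redux_dim = {{0}};
  out.device(gpu_device) = in.cast<Eigen::half>().sum(redux_dim).cast<float>();

  gpu_device.deallocate(d_in);
  gpu_device.deallocate(d_out);
}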