mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-12 11:49:02 +08:00
Added support for tensor reductions on half floats
This commit is contained in:
parent
5c4901b83a
commit
180156ba1a
@ -34,6 +34,26 @@ template<>
|
|||||||
struct functor_traits<scalar_cast_op<float, half> >
|
struct functor_traits<scalar_cast_op<float, half> >
|
||||||
{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
|
{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
|
||||||
|
|
||||||
|
|
||||||
|
template<>
|
||||||
|
struct scalar_cast_op<int, half> {
|
||||||
|
EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
|
||||||
|
typedef half result_type;
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half operator() (const int& a) const {
|
||||||
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
|
||||||
|
return __float2half(static_cast<float>(a));
|
||||||
|
#else
|
||||||
|
assert(false && "tbd");
|
||||||
|
return half();
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<>
|
||||||
|
struct functor_traits<scalar_cast_op<int, half> >
|
||||||
|
{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
|
||||||
|
|
||||||
|
|
||||||
template<>
|
template<>
|
||||||
struct scalar_cast_op<half, float> {
|
struct scalar_cast_op<half, float> {
|
||||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
|
EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
|
||||||
|
@ -72,11 +72,12 @@ template <typename T> struct SumReducer
|
|||||||
}
|
}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
|
||||||
return static_cast<T>(0);
|
internal::scalar_cast_op<int, T> conv;
|
||||||
|
return conv(0);
|
||||||
}
|
}
|
||||||
template <typename Packet>
|
template <typename Packet>
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const {
|
||||||
return pset1<Packet>(0);
|
return pset1<Packet>(initialize());
|
||||||
}
|
}
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const {
|
||||||
return accum;
|
return accum;
|
||||||
@ -110,11 +111,12 @@ template <typename T> struct MeanReducer
|
|||||||
}
|
}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
|
||||||
return static_cast<T>(0);
|
internal::scalar_cast_op<int, T> conv;
|
||||||
|
return conv(0);
|
||||||
}
|
}
|
||||||
template <typename Packet>
|
template <typename Packet>
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const {
|
||||||
return pset1<Packet>(0);
|
return pset1<Packet>(initialize());
|
||||||
}
|
}
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const {
|
||||||
return accum / scalarCount_;
|
return accum / scalarCount_;
|
||||||
@ -214,11 +216,12 @@ template <typename T> struct ProdReducer
|
|||||||
}
|
}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
|
||||||
return static_cast<T>(1);
|
internal::scalar_cast_op<int, T> conv;
|
||||||
|
return conv(1);
|
||||||
}
|
}
|
||||||
template <typename Packet>
|
template <typename Packet>
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const {
|
||||||
return pset1<Packet>(1);
|
return pset1<Packet>(initialize());
|
||||||
}
|
}
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const {
|
||||||
return accum;
|
return accum;
|
||||||
|
@ -93,7 +93,6 @@ void test_cuda_elementwise() {
|
|||||||
gpu_device.deallocate(d_res_half);
|
gpu_device.deallocate(d_res_half);
|
||||||
gpu_device.deallocate(d_res_float);
|
gpu_device.deallocate(d_res_float);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
void test_cuda_contractions() {
|
void test_cuda_contractions() {
|
||||||
Eigen::CudaStreamDevice stream;
|
Eigen::CudaStreamDevice stream;
|
||||||
@ -139,7 +138,7 @@ void test_cuda_contractions() {
|
|||||||
gpu_device.deallocate(d_float2);
|
gpu_device.deallocate(d_float2);
|
||||||
gpu_device.deallocate(d_res_half);
|
gpu_device.deallocate(d_res_half);
|
||||||
gpu_device.deallocate(d_res_float);
|
gpu_device.deallocate(d_res_float);
|
||||||
}
|
}*/
|
||||||
|
|
||||||
|
|
||||||
void test_cuda_reductions() {
|
void test_cuda_reductions() {
|
||||||
@ -183,7 +182,7 @@ void test_cuda_reductions() {
|
|||||||
gpu_device.deallocate(d_res_half);
|
gpu_device.deallocate(d_res_half);
|
||||||
gpu_device.deallocate(d_res_float);
|
gpu_device.deallocate(d_res_float);
|
||||||
}
|
}
|
||||||
*/
|
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@ -191,9 +190,19 @@ void test_cuda_reductions() {
|
|||||||
void test_cxx11_tensor_of_float16_cuda()
|
void test_cxx11_tensor_of_float16_cuda()
|
||||||
{
|
{
|
||||||
#ifdef EIGEN_HAS_CUDA_FP16
|
#ifdef EIGEN_HAS_CUDA_FP16
|
||||||
CALL_SUBTEST_1(test_cuda_conversion());
|
Eigen::CudaStreamDevice stream;
|
||||||
CALL_SUBTEST_1(test_cuda_elementwise());
|
Eigen::GpuDevice device(&stream);
|
||||||
// CALL_SUBTEST_2(test_cuda_contractions());
|
if (device.majorDeviceVersion() > 5 ||
|
||||||
// CALL_SUBTEST_3(test_cuda_reductions());
|
(device.majorDeviceVersion() == 5 && device.minorDeviceVersion() >= 3)) {
|
||||||
|
CALL_SUBTEST_1(test_cuda_conversion());
|
||||||
|
CALL_SUBTEST_1(test_cuda_elementwise());
|
||||||
|
// CALL_SUBTEST_2(test_cuda_contractions());
|
||||||
|
CALL_SUBTEST_3(test_cuda_reductions());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
std::cout << "Half floats require compute capability of at least 5.3. This device only supports " << device.majorDeviceVersion() << "." << device.minorDeviceVersion() << ". Skipping the test" << std::endl;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
std::cout << "Half floats are not supported by this version of cuda: skipping the test" << std::endl;
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user