Added support for mean reductions on fp16

This commit is contained in:
Benoit Steiner 2016-06-01 11:12:07 -07:00
parent cd221a62ee
commit d27b0ad4c8

View File

@ -230,13 +230,15 @@ __global__ void ReductionCleanupKernelHalfFloat(Op& reducer, half* output, half2
#endif
template <typename Self, typename Op>
template <typename Self, typename Op, typename OutputType, bool PacketAccess>
struct FullReductionLauncher {
template <typename OutputType>
static void run(const Self&, Op&, const GpuDevice&, OutputType*, typename Self::Index) {
assert(false && "Should only be called on floats and half floats");
}
};
template <typename Self, typename Op, bool PacketAccess>
struct FullReductionLauncher<Self, Op, float, PacketAccess> {
static void run(const Self& self, Op& reducer, const GpuDevice& device, float* output, typename Self::Index num_coeffs) {
typedef typename Self::Index Index;
typedef typename Self::CoeffReturnType Scalar;
@ -254,8 +256,18 @@ struct FullReductionLauncher {
LAUNCH_CUDA_KERNEL((FullReductionKernel<block_size, num_per_thread, Self, Op, Index>),
num_blocks, block_size, 0, device, reducer, self, num_coeffs, output);
}
};
#ifdef EIGEN_HAS_CUDA_FP16
template <typename Self, typename Op>
struct FullReductionLauncher<Self, Op, Eigen::half, false> {
static void run(const Self&, Op&, const GpuDevice&, half*, typename Self::Index) {
assert(false && "Should not be called since there is no packet accessor");
}
};
template <typename Self, typename Op>
struct FullReductionLauncher<Self, Op, Eigen::half, true> {
static void run(const Self& self, Op& reducer, const GpuDevice& device, half* output, typename Self::Index num_coeffs) {
typedef typename Self::Index Index;
@ -279,8 +291,8 @@ struct FullReductionLauncher {
1, 1, 0, device, reducer, output, scratch);
}
}
#endif
};
#endif
template <typename Self, typename Op, bool Vectorizable>
@ -306,7 +318,7 @@ struct FullReducer<Self, Op, GpuDevice, Vectorizable> {
return;
}
FullReductionLauncher<Self, Op>::run(self, reducer, device, output, num_coeffs);
FullReductionLauncher<Self, Op, OutputType, Op::PacketAccess>::run(self, reducer, device, output, num_coeffs);
}
};
@ -473,14 +485,16 @@ __global__ void InnerReductionKernelHalfFloat(Reducer reducer, const Self input,
#endif
template <typename Self, typename Op>
template <typename Self, typename Op, typename OutputType, bool PacketAccess>
struct InnerReductionLauncher {
template <typename OutputType>
static EIGEN_DEVICE_FUNC bool run(const Self&, Op&, const GpuDevice&, OutputType*, typename Self::Index, typename Self::Index) {
assert(false && "Should only be called to reduce floats and half floats on a gpu device");
return true;
}
};
template <typename Self, typename Op, bool PacketAccess>
struct InnerReductionLauncher<Self, Op, float, PacketAccess> {
static bool run(const Self& self, Op& reducer, const GpuDevice& device, float* output, typename Self::Index num_coeffs_to_reduce, typename Self::Index num_preserved_vals) {
typedef typename Self::Index Index;
@ -509,8 +523,18 @@ struct InnerReductionLauncher {
return false;
}
};
#ifdef EIGEN_HAS_CUDA_FP16
template <typename Self, typename Op>
struct InnerReductionLauncher<Self, Op, Eigen::half, false> {
static bool run(const Self&, Op&, const GpuDevice&, half*, typename Self::Index, typename Self::Index) {
assert(false && "Should not be called since there is no packet accessor");
}
};
template <typename Self, typename Op>
struct InnerReductionLauncher<Self, Op, Eigen::half, true> {
static bool run(const Self& self, Op& reducer, const GpuDevice& device, half* output, typename Self::Index num_coeffs_to_reduce, typename Self::Index num_preserved_vals) {
typedef typename Self::Index Index;
@ -543,8 +567,8 @@ struct InnerReductionLauncher {
return false;
}
#endif
};
#endif
template <typename Self, typename Op>
@ -574,7 +598,7 @@ struct InnerReducer<Self, Op, GpuDevice> {
return true;
}
return InnerReductionLauncher<Self, Op>::run(self, reducer, device, output, num_coeffs_to_reduce, num_preserved_vals);
return InnerReductionLauncher<Self, Op, OutputType, Op::PacketAccess>::run(self, reducer, device, output, num_coeffs_to_reduce, num_preserved_vals);
}
};