mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-05-06 19:29:08 +08:00
Fix potential race condition in the CUDA reduction code.
This commit is contained in:
parent
cbb14ed47e
commit
08348b4e48
@ -121,6 +121,7 @@ __global__ void FullReductionKernel(Reducer reducer, const Self input, Index num
|
|||||||
// Initialize the output value if it wasn't initialized by the ReductionInitKernel
|
// Initialize the output value if it wasn't initialized by the ReductionInitKernel
|
||||||
if (gridDim.x == 1 && first_index == 0) {
|
if (gridDim.x == 1 && first_index == 0) {
|
||||||
*output = reducer.initialize();
|
*output = reducer.initialize();
|
||||||
|
__syncthreads();
|
||||||
}
|
}
|
||||||
|
|
||||||
typename Self::CoeffReturnType accum = reducer.initialize();
|
typename Self::CoeffReturnType accum = reducer.initialize();
|
||||||
@ -172,6 +173,7 @@ static __global__ void FullReductionKernelHalfFloat(Reducer reducer, const Self
|
|||||||
} else {
|
} else {
|
||||||
*scratch = reducer.template initializePacket<half2>();
|
*scratch = reducer.template initializePacket<half2>();
|
||||||
}
|
}
|
||||||
|
__syncthreads();
|
||||||
}
|
}
|
||||||
|
|
||||||
half2 accum = reducer.template initializePacket<half2>();
|
half2 accum = reducer.template initializePacket<half2>();
|
||||||
@ -316,6 +318,7 @@ __global__ void InnerReductionKernel(Reducer reducer, const Self input, Index nu
|
|||||||
for (Index i = thread_id; i < num_preserved_coeffs; i += num_threads) {
|
for (Index i = thread_id; i < num_preserved_coeffs; i += num_threads) {
|
||||||
output[i] = reducer.initialize();
|
output[i] = reducer.initialize();
|
||||||
}
|
}
|
||||||
|
_syncthreads();
|
||||||
}
|
}
|
||||||
|
|
||||||
for (Index i = blockIdx.x; i < num_input_blocks; i += gridDim.x) {
|
for (Index i = blockIdx.x; i < num_input_blocks; i += gridDim.x) {
|
||||||
@ -420,6 +423,7 @@ __global__ void OuterReductionKernel(Reducer reducer, const Self input, Index nu
|
|||||||
for (Index i = thread_id; i < num_preserved_coeffs; i += num_threads) {
|
for (Index i = thread_id; i < num_preserved_coeffs; i += num_threads) {
|
||||||
output[i] = reducer.initialize();
|
output[i] = reducer.initialize();
|
||||||
}
|
}
|
||||||
|
__syncthreads();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do the reduction.
|
// Do the reduction.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user