mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-09-12 17:33:15 +08:00
Cleaned up a regression test
This commit is contained in:
parent
4860727ac2
commit
5266ff8966
@ -116,10 +116,10 @@ void test_cuda_argmax_dim()
|
|||||||
assert(cudaMemcpyAsync(tensor_arg.data(), d_out, out_bytes, cudaMemcpyDeviceToHost, gpu_device.stream()) == cudaSuccess);
|
assert(cudaMemcpyAsync(tensor_arg.data(), d_out, out_bytes, cudaMemcpyDeviceToHost, gpu_device.stream()) == cudaSuccess);
|
||||||
assert(cudaStreamSynchronize(gpu_device.stream()) == cudaSuccess);
|
assert(cudaStreamSynchronize(gpu_device.stream()) == cudaSuccess);
|
||||||
|
|
||||||
VERIFY_IS_EQUAL(tensor_arg.dimensions().TotalSize(),
|
VERIFY_IS_EQUAL(tensor_arg.size(),
|
||||||
size_t(2*3*5*7 / tensor.dimension(dim)));
|
size_t(2*3*5*7 / tensor.dimension(dim)));
|
||||||
|
|
||||||
for (size_t n = 0; n < tensor_arg.dimensions().TotalSize(); ++n) {
|
for (DenseIndex n = 0; n < tensor_arg.size(); ++n) {
|
||||||
// Expect max to be in the first index of the reduced dimension
|
// Expect max to be in the first index of the reduced dimension
|
||||||
VERIFY_IS_EQUAL(tensor_arg.data()[n], 0);
|
VERIFY_IS_EQUAL(tensor_arg.data()[n], 0);
|
||||||
}
|
}
|
||||||
@ -144,7 +144,7 @@ void test_cuda_argmax_dim()
|
|||||||
assert(cudaMemcpyAsync(tensor_arg.data(), d_out, out_bytes, cudaMemcpyDeviceToHost, gpu_device.stream()) == cudaSuccess);
|
assert(cudaMemcpyAsync(tensor_arg.data(), d_out, out_bytes, cudaMemcpyDeviceToHost, gpu_device.stream()) == cudaSuccess);
|
||||||
assert(cudaStreamSynchronize(gpu_device.stream()) == cudaSuccess);
|
assert(cudaStreamSynchronize(gpu_device.stream()) == cudaSuccess);
|
||||||
|
|
||||||
for (size_t n = 0; n < tensor_arg.dimensions().TotalSize(); ++n) {
|
for (DenseIndex n = 0; n < tensor_arg.size(); ++n) {
|
||||||
// Expect max to be in the last index of the reduced dimension
|
// Expect max to be in the last index of the reduced dimension
|
||||||
VERIFY_IS_EQUAL(tensor_arg.data()[n], tensor.dimension(dim) - 1);
|
VERIFY_IS_EQUAL(tensor_arg.data()[n], tensor.dimension(dim) - 1);
|
||||||
}
|
}
|
||||||
@ -205,10 +205,10 @@ void test_cuda_argmin_dim()
|
|||||||
assert(cudaMemcpyAsync(tensor_arg.data(), d_out, out_bytes, cudaMemcpyDeviceToHost, gpu_device.stream()) == cudaSuccess);
|
assert(cudaMemcpyAsync(tensor_arg.data(), d_out, out_bytes, cudaMemcpyDeviceToHost, gpu_device.stream()) == cudaSuccess);
|
||||||
assert(cudaStreamSynchronize(gpu_device.stream()) == cudaSuccess);
|
assert(cudaStreamSynchronize(gpu_device.stream()) == cudaSuccess);
|
||||||
|
|
||||||
VERIFY_IS_EQUAL(tensor_arg.dimensions().TotalSize(),
|
VERIFY_IS_EQUAL(tensor_arg.size(),
|
||||||
size_t(2*3*5*7 / tensor.dimension(dim)));
|
2*3*5*7 / tensor.dimension(dim));
|
||||||
|
|
||||||
for (size_t n = 0; n < tensor_arg.dimensions().TotalSize(); ++n) {
|
for (DenseIndex n = 0; n < tensor_arg.size(); ++n) {
|
||||||
// Expect min to be in the first index of the reduced dimension
|
// Expect min to be in the first index of the reduced dimension
|
||||||
VERIFY_IS_EQUAL(tensor_arg.data()[n], 0);
|
VERIFY_IS_EQUAL(tensor_arg.data()[n], 0);
|
||||||
}
|
}
|
||||||
@ -233,7 +233,7 @@ void test_cuda_argmin_dim()
|
|||||||
assert(cudaMemcpyAsync(tensor_arg.data(), d_out, out_bytes, cudaMemcpyDeviceToHost, gpu_device.stream()) == cudaSuccess);
|
assert(cudaMemcpyAsync(tensor_arg.data(), d_out, out_bytes, cudaMemcpyDeviceToHost, gpu_device.stream()) == cudaSuccess);
|
||||||
assert(cudaStreamSynchronize(gpu_device.stream()) == cudaSuccess);
|
assert(cudaStreamSynchronize(gpu_device.stream()) == cudaSuccess);
|
||||||
|
|
||||||
for (size_t n = 0; n < tensor_arg.dimensions().TotalSize(); ++n) {
|
for (DenseIndex n = 0; n < tensor_arg.size(); ++n) {
|
||||||
// Expect max to be in the last index of the reduced dimension
|
// Expect max to be in the last index of the reduced dimension
|
||||||
VERIFY_IS_EQUAL(tensor_arg.data()[n], tensor.dimension(dim) - 1);
|
VERIFY_IS_EQUAL(tensor_arg.data()[n], tensor.dimension(dim) - 1);
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user