mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 18:54:00 +08:00
Update the padding computation for PADDING_SAME to be consistent with TensorFlow.
This commit is contained in:
parent
393b7c4959
commit
3122477c86
@ -265,6 +265,10 @@ struct TensorEvaluator<const TensorImagePatchOp<Rows, Cols, ArgType>, Device>
|
|||||||
// Calculate the padding
|
// Calculate the padding
|
||||||
m_rowPaddingTop = ((m_outputRows - 1) * m_row_strides + m_patch_rows_eff - m_input_rows_eff) / 2;
|
m_rowPaddingTop = ((m_outputRows - 1) * m_row_strides + m_patch_rows_eff - m_input_rows_eff) / 2;
|
||||||
m_colPaddingLeft = ((m_outputCols - 1) * m_col_strides + m_patch_cols_eff - m_input_cols_eff) / 2;
|
m_colPaddingLeft = ((m_outputCols - 1) * m_col_strides + m_patch_cols_eff - m_input_cols_eff) / 2;
|
||||||
|
// The padding size calculation for PADDING_SAME has been updated to
|
||||||
|
// be consistent with how TensorFlow extracts its paddings.
|
||||||
|
m_rowPaddingTop = numext::maxi<Index>(0, m_rowPaddingTop);
|
||||||
|
m_colPaddingLeft = numext::maxi<Index>(0, m_colPaddingLeft);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
eigen_assert(false && "unexpected padding");
|
eigen_assert(false && "unexpected padding");
|
||||||
|
@ -405,6 +405,57 @@ void test_patch_padding_same()
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Verifies that SAME padding, when computed as negative values, will be clipped
|
||||||
|
// to zero.
|
||||||
|
void test_patch_padding_same_negative_padding_clip_to_zero() {
|
||||||
|
int input_depth = 1;
|
||||||
|
int input_rows = 15;
|
||||||
|
int input_cols = 1;
|
||||||
|
int input_batches = 1;
|
||||||
|
int ksize = 1; // Corresponds to the Rows and Cols for
|
||||||
|
// tensor.extract_image_patches<>.
|
||||||
|
int row_stride = 5;
|
||||||
|
int col_stride = 1;
|
||||||
|
// ColMajor
|
||||||
|
Tensor<float, 4> tensor(input_depth, input_rows, input_cols, input_batches);
|
||||||
|
// Initializes tensor with incrementing numbers.
|
||||||
|
for (int i = 0; i < tensor.size(); ++i) {
|
||||||
|
tensor.data()[i] = i + 1;
|
||||||
|
}
|
||||||
|
Tensor<float, 5> result = tensor.extract_image_patches(
|
||||||
|
ksize, ksize, row_stride, col_stride, 1, 1, PADDING_SAME);
|
||||||
|
// row padding will be computed as -2 originally and then be clipped to 0.
|
||||||
|
VERIFY_IS_EQUAL(result.coeff(0), 1.0f);
|
||||||
|
VERIFY_IS_EQUAL(result.coeff(1), 6.0f);
|
||||||
|
VERIFY_IS_EQUAL(result.coeff(2), 11.0f);
|
||||||
|
|
||||||
|
VERIFY_IS_EQUAL(result.dimension(0), input_depth); // depth
|
||||||
|
VERIFY_IS_EQUAL(result.dimension(1), ksize); // kernel rows
|
||||||
|
VERIFY_IS_EQUAL(result.dimension(2), ksize); // kernel cols
|
||||||
|
VERIFY_IS_EQUAL(result.dimension(3), 3); // number of patches
|
||||||
|
VERIFY_IS_EQUAL(result.dimension(4), input_batches); // number of batches
|
||||||
|
|
||||||
|
// RowMajor
|
||||||
|
Tensor<float, 4, RowMajor> tensor_row_major = tensor.swap_layout();
|
||||||
|
VERIFY_IS_EQUAL(tensor.dimension(0), tensor_row_major.dimension(3));
|
||||||
|
VERIFY_IS_EQUAL(tensor.dimension(1), tensor_row_major.dimension(2));
|
||||||
|
VERIFY_IS_EQUAL(tensor.dimension(2), tensor_row_major.dimension(1));
|
||||||
|
VERIFY_IS_EQUAL(tensor.dimension(3), tensor_row_major.dimension(0));
|
||||||
|
|
||||||
|
Tensor<float, 5, RowMajor> result_row_major =
|
||||||
|
tensor_row_major.extract_image_patches(ksize, ksize, row_stride,
|
||||||
|
col_stride, 1, 1, PADDING_SAME);
|
||||||
|
VERIFY_IS_EQUAL(result_row_major.coeff(0), 1.0f);
|
||||||
|
VERIFY_IS_EQUAL(result_row_major.coeff(1), 6.0f);
|
||||||
|
VERIFY_IS_EQUAL(result_row_major.coeff(2), 11.0f);
|
||||||
|
|
||||||
|
VERIFY_IS_EQUAL(result.dimension(0), result_row_major.dimension(4));
|
||||||
|
VERIFY_IS_EQUAL(result.dimension(1), result_row_major.dimension(3));
|
||||||
|
VERIFY_IS_EQUAL(result.dimension(2), result_row_major.dimension(2));
|
||||||
|
VERIFY_IS_EQUAL(result.dimension(3), result_row_major.dimension(1));
|
||||||
|
VERIFY_IS_EQUAL(result.dimension(4), result_row_major.dimension(0));
|
||||||
|
}
|
||||||
|
|
||||||
void test_patch_no_extra_dim()
|
void test_patch_no_extra_dim()
|
||||||
{
|
{
|
||||||
Tensor<float, 3> tensor(2,3,5);
|
Tensor<float, 3> tensor(2,3,5);
|
||||||
@ -754,4 +805,5 @@ void test_cxx11_tensor_image_patch()
|
|||||||
CALL_SUBTEST_4(test_patch_padding_valid_same_value());
|
CALL_SUBTEST_4(test_patch_padding_valid_same_value());
|
||||||
CALL_SUBTEST_5(test_patch_padding_same());
|
CALL_SUBTEST_5(test_patch_padding_same());
|
||||||
CALL_SUBTEST_6(test_imagenet_patches());
|
CALL_SUBTEST_6(test_imagenet_patches());
|
||||||
|
CALL_SUBTEST_7(test_patch_padding_same_negative_padding_clip_to_zero());
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user