From 7244a74ab0324bc847ef1fcb8c852fc8408bf1ff Mon Sep 17 00:00:00 2001 From: Lingzhu Xiang Date: Sat, 1 Jan 2022 16:47:22 +0800 Subject: [PATCH] Add bounds checking to Eigen serializer --- Eigen/src/Core/util/Serializer.h | 59 ++++++++++++++--------- test/gpu_test_helper.h | 15 ++++-- test/serializer.cpp | 82 +++++++++++++++++++++++++++----- 3 files changed, 117 insertions(+), 39 deletions(-) diff --git a/Eigen/src/Core/util/Serializer.h b/Eigen/src/Core/util/Serializer.h index 4f7e06213..b77c5de57 100644 --- a/Eigen/src/Core/util/Serializer.h +++ b/Eigen/src/Core/util/Serializer.h @@ -45,11 +45,14 @@ class Serializer end)) return nullptr; EIGEN_USING_STD(memcpy) memcpy(dest, &value, sizeof(value)); return dest + sizeof(value); @@ -57,11 +60,14 @@ class Serializer end)) return nullptr; EIGEN_USING_STD(memcpy) memcpy(&value, src, sizeof(value)); return src + sizeof(value); @@ -84,7 +90,9 @@ class Serializer, void> { return sizeof(Header) + sizeof(Scalar) * value.size(); } - EIGEN_DEVICE_FUNC uint8_t* serialize(uint8_t* dest, const Derived& value) { + EIGEN_DEVICE_FUNC uint8_t* serialize(uint8_t* dest, uint8_t* end, const Derived& value) { + if (EIGEN_PREDICT_FALSE(dest == nullptr)) return nullptr; + if (EIGEN_PREDICT_FALSE(dest + size(value) > end)) return nullptr; const size_t header_bytes = sizeof(Header); const size_t data_bytes = sizeof(Scalar) * value.size(); Header header = {value.rows(), value.cols()}; @@ -95,14 +103,17 @@ class Serializer, void> { return dest + data_bytes; } - EIGEN_DEVICE_FUNC uint8_t* deserialize(uint8_t* src, Derived& value) const { + EIGEN_DEVICE_FUNC const uint8_t* deserialize(const uint8_t* src, const uint8_t* end, Derived& value) const { + if (EIGEN_PREDICT_FALSE(src == nullptr)) return nullptr; + if (EIGEN_PREDICT_FALSE(src + sizeof(Header) > end)) return nullptr; const size_t header_bytes = sizeof(Header); Header header; EIGEN_USING_STD(memcpy) memcpy(&header, src, header_bytes); src += header_bytes; - value.resize(header.rows, header.cols); const size_t data_bytes = sizeof(Scalar) * header.rows * header.cols; + if (EIGEN_PREDICT_FALSE(src + data_bytes > end)) return nullptr; + value.resize(header.rows, header.cols); memcpy(value.data(), src, data_bytes); return src + data_bytes; } @@ -134,17 +145,17 @@ struct serialize_impl { } static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - uint8_t* serialize(uint8_t* dest, const T1& value, const Ts&... args) { + uint8_t* serialize(uint8_t* dest, uint8_t* end, const T1& value, const Ts&... args) { Serializer serializer; - dest = serializer.serialize(dest, value); - return serialize_impl::serialize(dest, args...); + dest = serializer.serialize(dest, end, value); + return serialize_impl::serialize(dest, end, args...); } static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - uint8_t* deserialize(uint8_t* src, T1& value, Ts&... args) { + const uint8_t* deserialize(const uint8_t* src, const uint8_t* end, T1& value, Ts&... args) { Serializer serializer; - src = serializer.deserialize(src, value); - return serialize_impl::deserialize(src, args...); + src = serializer.deserialize(src, end, value); + return serialize_impl::deserialize(src, end, args...); } }; @@ -155,10 +166,10 @@ struct serialize_impl<0> { size_t serialize_size() { return 0; } static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - uint8_t* serialize(uint8_t* dest) { return dest; } + uint8_t* serialize(uint8_t* dest, uint8_t* /*end*/) { return dest; } static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - uint8_t* deserialize(uint8_t* src) { return src; } + const uint8_t* deserialize(const uint8_t* src, const uint8_t* /*end*/) { return src; } }; } // namespace internal @@ -179,27 +190,29 @@ size_t serialize_size(const Args&... args) { /** * Serialize a set of values to the byte buffer. * - * \param dest output byte buffer. + * \param dest output byte buffer; if this is nullptr, does nothing. + * \param end the end of the output byte buffer. * \param args ... arguments to serialize in sequence. * \return the next address after all serialized values. */ template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE -uint8_t* serialize(uint8_t* dest, const Args&... args) { - return internal::serialize_impl::serialize(dest, args...); +uint8_t* serialize(uint8_t* dest, uint8_t* end, const Args&... args) { + return internal::serialize_impl::serialize(dest, end, args...); } /** * Deserialize a set of values from the byte buffer. * - * \param src input byte buffer. + * \param src input byte buffer; if this is nullptr, does nothing. + * \param end the end of input byte buffer. * \param args ... arguments to deserialize in sequence. - * \return the next address after all parsed values. + * \return the next address after all parsed values; nullptr if parsing errors are detected. */ template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE -uint8_t* deserialize(uint8_t* src, Args&... args) { - return internal::serialize_impl::deserialize(src, args...); +const uint8_t* deserialize(const uint8_t* src, const uint8_t* end, Args&... args) { + return internal::serialize_impl::deserialize(src, end, args...); } } // namespace Eigen diff --git a/test/gpu_test_helper.h b/test/gpu_test_helper.h index f796eeba5..af1fff6c9 100644 --- a/test/gpu_test_helper.h +++ b/test/gpu_test_helper.h @@ -141,12 +141,14 @@ void run_serialized(index_sequence, index_sequence using test_detail::tuple; // Deserialize input size and inputs. size_t input_size; - uint8_t* buff_ptr = Eigen::deserialize(buffer, input_size); + const uint8_t* read_ptr = buffer; + const uint8_t* read_end = buffer + capacity; + read_ptr = Eigen::deserialize(read_ptr, read_end, input_size); // Create value-type instances to populate. auto args = make_tuple(typename std::decay::type{}...); EIGEN_UNUSED_VARIABLE(args) // Avoid NVCC compile warning. // NVCC 9.1 requires us to spell out the template parameters explicitly. - buff_ptr = Eigen::deserialize(buff_ptr, get::type...>(args)...); + read_ptr = Eigen::deserialize(read_ptr, read_end, get::type...>(args)...); // Call function, with void->Void conversion so we are guaranteed a complete // output type. @@ -158,12 +160,15 @@ void run_serialized(index_sequence, index_sequence output_size += Eigen::serialize_size(result); // Always serialize required buffer size. - buff_ptr = Eigen::serialize(buffer, output_size); + uint8_t* write_ptr = buffer; + uint8_t* write_end = buffer + capacity; + write_ptr = Eigen::serialize(write_ptr, write_end, output_size); + // Null `write_ptr` can be safely passed along. // Serialize outputs if they fit in the buffer. if (output_size <= capacity) { // Collect outputs and result. - buff_ptr = Eigen::serialize(buff_ptr, get::type...>(args)...); - buff_ptr = Eigen::serialize(buff_ptr, result); + write_ptr = Eigen::serialize(write_ptr, write_end, get::type...>(args)...); + write_ptr = Eigen::serialize(write_ptr, write_end, result); } } diff --git a/test/serializer.cpp b/test/serializer.cpp index f5c0d67b3..76b9083df 100644 --- a/test/serializer.cpp +++ b/test/serializer.cpp @@ -31,15 +31,35 @@ void test_pod_type() { // Serialize. std::vector buffer(buffer_size); - uint8_t* dest = serializer.serialize(buffer.data(), initial); - VERIFY_IS_EQUAL(dest - buffer.data(), buffer_size); + uint8_t* begin = buffer.data(); + uint8_t* end = buffer.data() + buffer.size(); + uint8_t* dest = serializer.serialize(begin, end, initial); + VERIFY(dest != nullptr); + VERIFY_IS_EQUAL(dest - begin, buffer_size); // Deserialize. - uint8_t* src = serializer.deserialize(buffer.data(), clone); - VERIFY_IS_EQUAL(src - buffer.data(), buffer_size); + const uint8_t* src = serializer.deserialize(begin, end, clone); + VERIFY(src != nullptr); + VERIFY_IS_EQUAL(src - begin, buffer_size); VERIFY_IS_EQUAL(clone.x, initial.x); VERIFY_IS_EQUAL(clone.y, initial.y); VERIFY_IS_EQUAL(clone.z, initial.z); + + // Serialize with bounds checking errors. + dest = serializer.serialize(begin, end - 1, initial); + VERIFY(dest == nullptr); + dest = serializer.serialize(begin, begin, initial); + VERIFY(dest == nullptr); + dest = serializer.serialize(nullptr, nullptr, initial); + VERIFY(dest == nullptr); + + // Deserialize with bounds checking errors. + src = serializer.deserialize(begin, end - 1, clone); + VERIFY(src == nullptr); + src = serializer.deserialize(begin, begin, clone); + VERIFY(src == nullptr); + src = serializer.deserialize(nullptr, nullptr, clone); + VERIFY(src == nullptr); } // Matrix, Vector, Array @@ -54,14 +74,34 @@ void test_eigen_type(const T& type) { Eigen::Serializer serializer; size_t buffer_size = serializer.size(initial); std::vector buffer(buffer_size); - uint8_t* dest = serializer.serialize(buffer.data(), initial); - VERIFY_IS_EQUAL(dest - buffer.data(), buffer_size); + uint8_t* begin = buffer.data(); + uint8_t* end = buffer.data() + buffer.size(); + uint8_t* dest = serializer.serialize(begin, end, initial); + VERIFY(dest != nullptr); + VERIFY_IS_EQUAL(dest - begin, buffer_size); // Deserialize. T clone; - uint8_t* src = serializer.deserialize(buffer.data(), clone); - VERIFY_IS_EQUAL(src - buffer.data(), buffer_size); + const uint8_t* src = serializer.deserialize(begin, end, clone); + VERIFY(src != nullptr); + VERIFY_IS_EQUAL(src - begin, buffer_size); VERIFY_IS_CWISE_EQUAL(clone, initial); + + // Serialize with bounds checking errors. + dest = serializer.serialize(begin, end - 1, initial); + VERIFY(dest == nullptr); + dest = serializer.serialize(begin, begin, initial); + VERIFY(dest == nullptr); + dest = serializer.serialize(nullptr, nullptr, initial); + VERIFY(dest == nullptr); + + // Deserialize with bounds checking errors. + src = serializer.deserialize(begin, end - 1, clone); + VERIFY(src == nullptr); + src = serializer.deserialize(begin, begin, clone); + VERIFY(src == nullptr); + src = serializer.deserialize(nullptr, nullptr, clone); + VERIFY(src == nullptr); } // Test a collection of dense types. @@ -76,18 +116,38 @@ void test_dense_types(const T1& type1, const T2& type2, const T3& type3) { // Allocate buffer and serialize. size_t buffer_size = Eigen::serialize_size(x1, x2, x3); std::vector buffer(buffer_size); - Eigen::serialize(buffer.data(), x1, x2, x3); + uint8_t* begin = buffer.data(); + uint8_t* end = buffer.data() + buffer.size(); + uint8_t* dest = Eigen::serialize(begin, end, x1, x2, x3); + VERIFY(dest != nullptr); // Clone everything. T1 y1; T2 y2; T3 y3; - Eigen::deserialize(buffer.data(), y1, y2, y3); - + const uint8_t* src = Eigen::deserialize(begin, end, y1, y2, y3); + VERIFY(src != nullptr); + // Verify they equal. VERIFY_IS_CWISE_EQUAL(y1, x1); VERIFY_IS_CWISE_EQUAL(y2, x2); VERIFY_IS_CWISE_EQUAL(y3, x3); + + // Serialize everything with bounds checking errors. + dest = Eigen::serialize(begin, end - 1, y1, y2, y3); + VERIFY(dest == nullptr); + dest = Eigen::serialize(begin, begin, y1, y2, y3); + VERIFY(dest == nullptr); + dest = Eigen::serialize(nullptr, nullptr, y1, y2, y3); + VERIFY(dest == nullptr); + + // Deserialize everything with bounds checking errors. + src = Eigen::deserialize(begin, end - 1, y1, y2, y3); + VERIFY(src == nullptr); + src = Eigen::deserialize(begin, begin, y1, y2, y3); + VERIFY(src == nullptr); + src = Eigen::deserialize(nullptr, nullptr, y1, y2, y3); + VERIFY(src == nullptr); } EIGEN_DECLARE_TEST(serializer)