Making Draco decoder more robust when handling tampered input data.

Adding support for unit tests to Windows builds.
This commit is contained in:
Ondrej Stava 2017-01-19 15:23:10 -08:00
parent 27cf67cd0f
commit b0215d525b
16 changed files with 104 additions and 25 deletions

View File

@ -107,6 +107,11 @@ else ()
endif () endif ()
if (ENABLE_TESTS) if (ENABLE_TESTS)
if (MSVC)
# Default runtime selected by cmake collides with Googletest settings,
# just force all Draco builds to be static (instead of dll/msvcrt).
include("${draco_root}/cmake/msvc_runtime.cmake")
endif ()
# Googletest defaults. # Googletest defaults.
set(GTEST_SOURCE_DIR set(GTEST_SOURCE_DIR
"${draco_root}/../googletest" CACHE STRING "${draco_root}/../googletest" CACHE STRING

View File

@ -47,7 +47,8 @@ class MeshTraversalSequencer : public PointsSequencer {
bool UpdatePointToAttributeIndexMapping(PointAttribute *attribute) override { bool UpdatePointToAttributeIndexMapping(PointAttribute *attribute) override {
const auto *corner_table = traverser_.corner_table(); const auto *corner_table = traverser_.corner_table();
attribute->SetExplicitMapping(mesh_->num_points()); attribute->SetExplicitMapping(mesh_->num_points());
const int32_t num_faces = mesh_->num_faces(); const size_t num_faces = mesh_->num_faces();
const size_t num_points = mesh_->num_points();
for (FaceIndex f(0); f < num_faces; ++f) { for (FaceIndex f(0); f < num_faces; ++f) {
const auto &face = mesh_->face(f); const auto &face = mesh_->face(f);
for (int p = 0; p < 3; ++p) { for (int p = 0; p < 3; ++p) {
@ -57,6 +58,10 @@ class MeshTraversalSequencer : public PointsSequencer {
const AttributeValueIndex att_entry_id( const AttributeValueIndex att_entry_id(
encoding_data_ encoding_data_
->vertex_to_encoded_attribute_value_index_map[vert_id.value()]); ->vertex_to_encoded_attribute_value_index_map[vert_id.value()]);
if (att_entry_id.value() >= num_points) {
// There cannot be more attribute values than the number of points.
return false;
}
attribute->SetPointMapEntry(point_id, att_entry_id); attribute->SetPointMapEntry(point_id, att_entry_id);
} }
} }

View File

@ -126,7 +126,7 @@ class PredictionSchemeNormalOctahedronTransform
orig = orig - t; orig = orig - t;
pred = pred - t; pred = pred - t;
if (!IsInDiamond( max_value_, pred[0], pred[1])) { if (!IsInDiamond(max_value_, pred[0], pred[1])) {
InvertRepresentation(max_value_, &orig[0], &orig[1]); InvertRepresentation(max_value_, &orig[0], &orig[1]);
InvertRepresentation(max_value_, &pred[0], &pred[1]); InvertRepresentation(max_value_, &pred[0], &pred[1]);
} }
@ -141,7 +141,7 @@ class PredictionSchemeNormalOctahedronTransform
const Point2 t(max_value_, max_value_); const Point2 t(max_value_, max_value_);
pred = pred - t; pred = pred - t;
const bool pred_is_in_diamond = IsInDiamond( max_value_, pred[0], pred[1]); const bool pred_is_in_diamond = IsInDiamond(max_value_, pred[0], pred[1]);
if (!pred_is_in_diamond) { if (!pred_is_in_diamond) {
InvertRepresentation(max_value_, &pred[0], &pred[1]); InvertRepresentation(max_value_, &pred[0], &pred[1]);
} }

View File

@ -27,6 +27,9 @@ bool SequentialNormalAttributeDecoder::Initialize(PointCloudDecoder *decoder,
// Currently, this encoder works only for 3-component normal vectors. // Currently, this encoder works only for 3-component normal vectors.
if (attribute()->components_count() != 3) if (attribute()->components_count() != 3)
return false; return false;
// Also the data type must be DT_FLOAT32.
if (attribute()->data_type() != DT_FLOAT32)
return false;
return true; return true;
} }

View File

@ -100,6 +100,9 @@ bool MeshEdgeBreakerDecoderImpl<TraversalDecoder>::CreateAttributesDecoder(
return false; return false;
if (att_data_id >= 0) { if (att_data_id >= 0) {
if (att_data_id >= attribute_data_.size()) {
return false; // Unexpected attribute data.
}
attribute_data_[att_data_id].decoder_id = att_decoder_id; attribute_data_[att_data_id].decoder_id = att_decoder_id;
} }
@ -138,6 +141,9 @@ bool MeshEdgeBreakerDecoderImpl<TraversalDecoder>::CreateAttributesDecoder(
sequencer = std::move(traversal_sequencer); sequencer = std::move(traversal_sequencer);
} else { } else {
if (att_data_id < 0)
return false; // Attribute data must be specified.
// Per-corner attribute decoder. // Per-corner attribute decoder.
typedef CornerTableTraversalProcessor<MeshAttributeCornerTable> typedef CornerTableTraversalProcessor<MeshAttributeCornerTable>
AttProcessor; AttProcessor;
@ -225,6 +231,13 @@ bool MeshEdgeBreakerDecoderImpl<TraversalDecoder>::DecodeConnectivity() {
if (!decoder_->buffer()->Decode(&num_encoded_symbols)) if (!decoder_->buffer()->Decode(&num_encoded_symbols))
return false; return false;
if (num_faces < num_encoded_symbols) {
// Number of faces needs to be the same or greater than the number of
// symbols (it can be greater because the initial face may not be encoded as
// a symbol).
return false;
}
uint32_t num_encoded_split_symbols; uint32_t num_encoded_split_symbols;
if (!decoder_->buffer()->Decode(&num_encoded_split_symbols)) if (!decoder_->buffer()->Decode(&num_encoded_split_symbols))
return false; return false;
@ -252,7 +265,9 @@ bool MeshEdgeBreakerDecoderImpl<TraversalDecoder>::DecodeConnectivity() {
traversal_decoder_.SetNumEncodedVertices(num_encoded_vertices_); traversal_decoder_.SetNumEncodedVertices(num_encoded_vertices_);
traversal_decoder_.SetNumAttributeData(num_attribute_data); traversal_decoder_.SetNumAttributeData(num_attribute_data);
const DecoderBuffer traversal_end_buffer = traversal_decoder_.Start(); DecoderBuffer traversal_end_buffer;
if (!traversal_decoder_.Start(&traversal_end_buffer))
return false;
const int num_connectivity_verts = DecodeConnectivity(num_encoded_symbols); const int num_connectivity_verts = DecodeConnectivity(num_encoded_symbols);
if (num_connectivity_verts == -1) if (num_connectivity_verts == -1)
@ -340,6 +355,7 @@ int MeshEdgeBreakerDecoderImpl<TraversalDecoder>::DecodeConnectivity(
std::unordered_map<int, CornerIndex> topology_split_active_corners; std::unordered_map<int, CornerIndex> topology_split_active_corners;
int num_vertices = 0; int num_vertices = 0;
int max_num_vertices = is_vert_hole_.size();
int num_faces = 0; int num_faces = 0;
for (int symbol_id = 0; symbol_id < num_symbols; ++symbol_id) { for (int symbol_id = 0; symbol_id < num_symbols; ++symbol_id) {
const FaceIndex face(num_faces++); const FaceIndex face(num_faces++);
@ -384,6 +400,8 @@ int MeshEdgeBreakerDecoderImpl<TraversalDecoder>::DecodeConnectivity(
corner + 1, corner_table_->Vertex(corner_table_->Next(corner_b))); corner + 1, corner_table_->Vertex(corner_table_->Next(corner_b)));
corner_table_->MapCornerToVertex( corner_table_->MapCornerToVertex(
corner + 2, corner_table_->Vertex(corner_table_->Previous(corner_a))); corner + 2, corner_table_->Vertex(corner_table_->Previous(corner_a)));
if (num_vertices > max_num_vertices)
return -1; // Unexpected number of decoded vertices.
// Mark the vertex |x| as interior. // Mark the vertex |x| as interior.
is_vert_hole_[vertex_x.value()] = false; is_vert_hole_[vertex_x.value()] = false;
// Update the corner on the active stack. // Update the corner on the active stack.
@ -484,6 +502,9 @@ int MeshEdgeBreakerDecoderImpl<TraversalDecoder>::DecodeConnectivity(
// Add the tip corner to the active stack. // Add the tip corner to the active stack.
active_corner_stack.push_back(corner); active_corner_stack.push_back(corner);
check_topology_split = true; check_topology_split = true;
} else {
// Error. Unknown symbol decoded.
return -1;
} }
// Inform the traversal decoder that a new corner has been reached. // Inform the traversal decoder that a new corner has been reached.
traversal_decoder_.NewActiveCornerReached(active_corner_stack.back()); traversal_decoder_.NewActiveCornerReached(active_corner_stack.back());
@ -503,6 +524,8 @@ int MeshEdgeBreakerDecoderImpl<TraversalDecoder>::DecodeConnectivity(
int encoder_split_symbol_id; int encoder_split_symbol_id;
while (IsTopologySplit(encoder_symbol_id, &split_edge, while (IsTopologySplit(encoder_symbol_id, &split_edge,
&encoder_split_symbol_id)) { &encoder_split_symbol_id)) {
if (encoder_split_symbol_id < 0)
return -1; // Wrong split symbol id.
// Symbol was part of a topology split. Now we need to determine which // Symbol was part of a topology split. Now we need to determine which
// edge should be added to the active edges stack. // edge should be added to the active edges stack.
const CornerIndex act_top_corner = active_corner_stack.back(); const CornerIndex act_top_corner = active_corner_stack.back();
@ -530,7 +553,8 @@ int MeshEdgeBreakerDecoderImpl<TraversalDecoder>::DecodeConnectivity(
} }
} }
} }
if (num_vertices > max_num_vertices)
return -1; // Unexpected number of decoded vertices.
// Decode start faces and connect them to the faces from the active stack. // Decode start faces and connect them to the faces from the active stack.
while (active_corner_stack.size() > 0) { while (active_corner_stack.size() > 0) {
const CornerIndex corner = active_corner_stack.back(); const CornerIndex corner = active_corner_stack.back();
@ -598,6 +622,8 @@ int MeshEdgeBreakerDecoderImpl<TraversalDecoder>::DecodeConnectivity(
init_corners_.push_back(corner); init_corners_.push_back(corner);
} }
} }
if (num_faces != corner_table_->num_faces())
return -1; // Unexcpected number of decoded faces.
vertex_id_map_.resize(num_vertices); vertex_id_map_.resize(num_vertices);
return num_vertices; return num_vertices;
} }

View File

@ -79,7 +79,15 @@ class MeshEdgeBreakerDecoderImpl : public MeshEdgeBreakerDecoderImplInterface {
int *out_encoder_split_symbol_id) { int *out_encoder_split_symbol_id) {
if (topology_split_data_.size() == 0) if (topology_split_data_.size() == 0)
return false; return false;
DCHECK_LE(topology_split_data_.back().source_symbol_id, encoder_symbol_id); if (topology_split_data_.back().source_symbol_id > encoder_symbol_id) {
// Something is wrong; if the desired source symbol is greater than the
// current encoder_symbol_id, we missed it, or the input was tampered
// (|encoder_symbol_id| keeps decreasing).
// Return invalid symbol id to notify the decoder that there was an
// error.
*out_encoder_split_symbol_id = -1;
return true;
}
if (topology_split_data_.back().source_symbol_id != encoder_symbol_id) if (topology_split_data_.back().source_symbol_id != encoder_symbol_id)
return false; return false;
*out_face_edge = *out_face_edge =

View File

@ -28,9 +28,7 @@ namespace draco {
class MeshEdgebreakerEncodingTest : public ::testing::Test { class MeshEdgebreakerEncodingTest : public ::testing::Test {
protected: protected:
void TestFile(const std::string &file_name) { void TestFile(const std::string &file_name) { TestFile(file_name, -1); }
TestFile(file_name, -1);
}
void TestFile(const std::string &file_name, int compression_level) { void TestFile(const std::string &file_name, int compression_level) {
const std::string path = GetTestFileFullPath(file_name); const std::string path = GetTestFileFullPath(file_name);

View File

@ -49,28 +49,36 @@ class MeshEdgeBreakerTraversalDecoder {
// Called before the traversal decoding is started. // Called before the traversal decoding is started.
// Returns a buffer decoder that points to data that was encoded after the // Returns a buffer decoder that points to data that was encoded after the
// traversal. // traversal.
DecoderBuffer Start() { bool Start(DecoderBuffer *out_buffer) {
// Decode symbols from the main buffer decoder and face configurations from // Decode symbols from the main buffer decoder and face configurations from
// the start_face_buffer decoder. // the start_face_buffer decoder.
uint64_t traversal_size; uint64_t traversal_size;
buffer_.StartBitDecoding(true, &traversal_size); if (!buffer_.StartBitDecoding(true, &traversal_size))
return false;
start_face_buffer_.Init(buffer_.data_head(), buffer_.remaining_size()); start_face_buffer_.Init(buffer_.data_head(), buffer_.remaining_size());
if (traversal_size > start_face_buffer_.remaining_size())
return false;
start_face_buffer_.Advance(traversal_size); start_face_buffer_.Advance(traversal_size);
start_face_buffer_.StartBitDecoding(true, &traversal_size); if (!start_face_buffer_.StartBitDecoding(true, &traversal_size))
return false;
// Create a decoder that is set to the end of the encoded traversal data. // Create a decoder that is set to the end of the encoded traversal data.
DecoderBuffer ret; DecoderBuffer ret;
ret.Init(start_face_buffer_.data_head(), ret.Init(start_face_buffer_.data_head(),
start_face_buffer_.remaining_size()); start_face_buffer_.remaining_size());
if (traversal_size > ret.remaining_size())
return false;
ret.Advance(traversal_size); ret.Advance(traversal_size);
// Prepare attribute decoding. // Prepare attribute decoding.
if (num_attribute_data_ > 0) { if (num_attribute_data_ > 0) {
attribute_connectivity_decoders_ = std::unique_ptr<BinaryDecoder[]>( attribute_connectivity_decoders_ = std::unique_ptr<BinaryDecoder[]>(
new BinaryDecoder[num_attribute_data_]); new BinaryDecoder[num_attribute_data_]);
for (int i = 0; i < num_attribute_data_; ++i) { for (int i = 0; i < num_attribute_data_; ++i) {
attribute_connectivity_decoders_[i].StartDecoding(&ret); if (!attribute_connectivity_decoders_[i].StartDecoding(&ret))
return false;
} }
} }
return ret; *out_buffer = ret;
return true;
} }
// Returns the configuration of a new initial face. // Returns the configuration of a new initial face.

View File

@ -37,16 +37,19 @@ class MeshEdgeBreakerTraversalPredictiveDecoder
} }
void SetNumEncodedVertices(int num_vertices) { num_vertices_ = num_vertices; } void SetNumEncodedVertices(int num_vertices) { num_vertices_ = num_vertices; }
DecoderBuffer Start() { bool Start(DecoderBuffer *out_buffer) {
DecoderBuffer buffer = MeshEdgeBreakerTraversalDecoder::Start(); if (!MeshEdgeBreakerTraversalDecoder::Start(out_buffer))
return false;
int32_t num_split_symbols; int32_t num_split_symbols;
buffer.Decode(&num_split_symbols); if (!out_buffer->Decode(&num_split_symbols))
return false;
// Add one vertex for each split symbol. // Add one vertex for each split symbol.
num_vertices_ += num_split_symbols; num_vertices_ += num_split_symbols;
// Set the valences of all initial vertices to 0. // Set the valences of all initial vertices to 0.
vertex_valences_.resize(num_vertices_, 0); vertex_valences_.resize(num_vertices_, 0);
prediction_decoder_.StartDecoding(&buffer); if (!prediction_decoder_.StartDecoding(out_buffer))
return buffer; return false;
return true;
} }
inline uint32_t DecodeSymbol() { inline uint32_t DecodeSymbol() {

View File

@ -466,7 +466,8 @@ class RAnsDecoder {
} }
// Construct a look up table with |rans_precision| number of entries. // Construct a look up table with |rans_precision| number of entries.
inline void rans_build_look_up_table(const uint32_t token_probs[], // Returns false if the table couldn't be built (because of wrong input data).
inline bool rans_build_look_up_table(const uint32_t token_probs[],
uint32_t num_symbols) { uint32_t num_symbols) {
lut_table_.resize(rans_precision); lut_table_.resize(rans_precision);
probability_table_.resize(num_symbols); probability_table_.resize(num_symbols);
@ -476,12 +477,16 @@ class RAnsDecoder {
probability_table_[i].prob = token_probs[i]; probability_table_[i].prob = token_probs[i];
probability_table_[i].cum_prob = cum_prob; probability_table_[i].cum_prob = cum_prob;
cum_prob += token_probs[i]; cum_prob += token_probs[i];
if (cum_prob > rans_precision) {
return false;
}
for (uint32_t j = act_prob; j < cum_prob; ++j) { for (uint32_t j = act_prob; j < cum_prob; ++j) {
lut_table_[j] = i; lut_table_[j] = i;
} }
act_prob = cum_prob; act_prob = cum_prob;
} }
assert(cum_prob == rans_precision); assert(cum_prob == rans_precision);
return true;
} }
private: private:

View File

@ -27,7 +27,8 @@ void DecoderBuffer::Init(const char *data, size_t data_size) {
bool DecoderBuffer::StartBitDecoding(bool decode_size, uint64_t *out_size) { bool DecoderBuffer::StartBitDecoding(bool decode_size, uint64_t *out_size) {
if (decode_size) { if (decode_size) {
Decode(out_size); if (!Decode(out_size))
return false;
} }
bit_mode_ = true; bit_mode_ = true;
bit_decoder_.reset(data_head(), remaining_size()); bit_decoder_.reset(data_head(), remaining_size());

View File

@ -1,6 +1,6 @@
#include "core/draco_test_base.h" #include "core/draco_test_base.h"
int main(int argc, char* argv[]) { int main(int argc, char *argv[]) {
::testing::InitGoogleTest(&argc, argv); ::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS(); return RUN_ALL_TESTS();
} }

View File

@ -122,7 +122,7 @@ RAnsBitDecoder::RAnsBitDecoder() : prob_zero_(0) {}
RAnsBitDecoder::~RAnsBitDecoder() { Clear(); } RAnsBitDecoder::~RAnsBitDecoder() { Clear(); }
void RAnsBitDecoder::StartDecoding(DecoderBuffer *source_buffer) { bool RAnsBitDecoder::StartDecoding(DecoderBuffer *source_buffer) {
Clear(); Clear();
source_buffer->Decode(&prob_zero_); source_buffer->Decode(&prob_zero_);
@ -130,10 +130,14 @@ void RAnsBitDecoder::StartDecoding(DecoderBuffer *source_buffer) {
uint32_t size_in_bytes; uint32_t size_in_bytes;
source_buffer->Decode(&size_in_bytes); source_buffer->Decode(&size_in_bytes);
if (size_in_bytes > source_buffer->remaining_size())
return false;
ans_read_init(&ans_decoder_, reinterpret_cast<uint8_t *>(const_cast<char *>( ans_read_init(&ans_decoder_, reinterpret_cast<uint8_t *>(const_cast<char *>(
source_buffer->data_head())), source_buffer->data_head())),
size_in_bytes); size_in_bytes);
source_buffer->Advance(size_in_bytes); source_buffer->Advance(size_in_bytes);
return true;
} }
bool RAnsBitDecoder::DecodeNextBit() { bool RAnsBitDecoder::DecodeNextBit() {

View File

@ -61,7 +61,8 @@ class RAnsBitDecoder {
~RAnsBitDecoder(); ~RAnsBitDecoder();
// Sets |source_buffer| as the buffer to decode bits from. // Sets |source_buffer| as the buffer to decode bits from.
void StartDecoding(DecoderBuffer *source_buffer); // Returns false when the data is invalid.
bool StartDecoding(DecoderBuffer *source_buffer);
// Decode one bit. Returns true if the bit is a 1, otherwsie false. // Decode one bit. Returns true if the bit is a 1, otherwsie false.
bool DecodeNextBit(); bool DecodeNextBit();

View File

@ -32,6 +32,8 @@ class RAnsSymbolDecoder {
// Initialize the decoder and decode the probability table. // Initialize the decoder and decode the probability table.
bool Create(DecoderBuffer *buffer); bool Create(DecoderBuffer *buffer);
uint32_t num_symbols() const { return num_symbols_; }
// Starts decoding from the buffer. The buffer will be advanced past the // Starts decoding from the buffer. The buffer will be advanced past the
// encoded data after this call. // encoded data after this call.
void StartDecoding(DecoderBuffer *buffer); void StartDecoding(DecoderBuffer *buffer);
@ -55,6 +57,8 @@ bool RAnsSymbolDecoder<max_symbol_bit_length_t>::Create(DecoderBuffer *buffer) {
if (!buffer->Decode(&num_symbols_)) if (!buffer->Decode(&num_symbols_))
return false; return false;
probability_table_.resize(num_symbols_); probability_table_.resize(num_symbols_);
if (num_symbols_ == 0)
return true;
// Decode the table. // Decode the table.
for (uint32_t i = 0; i < num_symbols_; ++i) { for (uint32_t i = 0; i < num_symbols_; ++i) {
uint32_t prob = 0; uint32_t prob = 0;
@ -74,7 +78,8 @@ bool RAnsSymbolDecoder<max_symbol_bit_length_t>::Create(DecoderBuffer *buffer) {
} }
probability_table_[i] = prob; probability_table_[i] = prob;
} }
ans_.rans_build_look_up_table(&probability_table_[0], num_symbols_); if (!ans_.rans_build_look_up_table(&probability_table_[0], num_symbols_))
return false;
return true; return true;
} }

View File

@ -73,6 +73,9 @@ bool DecodeTaggedSymbols(int num_values, int num_components,
tag_decoder.StartDecoding(src_buffer); tag_decoder.StartDecoding(src_buffer);
if (num_values > 0 && tag_decoder.num_symbols() == 0)
return false; // Wrong number of symbols.
// src_buffer now points behind the encoded tag data (to the place where the // src_buffer now points behind the encoded tag data (to the place where the
// values are encoded). // values are encoded).
src_buffer->StartBitDecoding(false, nullptr); src_buffer->StartBitDecoding(false, nullptr);
@ -99,6 +102,10 @@ bool DecodeRawSymbolsInternal(int num_values, DecoderBuffer *src_buffer,
SymbolDecoderT decoder; SymbolDecoderT decoder;
if (!decoder.Create(src_buffer)) if (!decoder.Create(src_buffer))
return false; return false;
if (num_values > 0 && decoder.num_symbols() == 0)
return false; // Wrong number of symbols.
decoder.StartDecoding(src_buffer); decoder.StartDecoding(src_buffer);
for (int i = 0; i < num_values; ++i) { for (int i = 0; i < num_values; ++i) {
// Decode a symbol into the value. // Decode a symbol into the value.