// Copyright 2016 The Draco Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // #ifndef DRACO_CORE_RANS_SYMBOL_ENCODER_H_ #define DRACO_CORE_RANS_SYMBOL_ENCODER_H_ #include #include #include #include "core/ans.h" #include "core/encoder_buffer.h" #include "core/rans_symbol_coding.h" namespace draco { // A helper class for encoding symbols using the rANS algorithm (see ans.h). // The class can be used to initialize and encode probability table needed by // rANS, and to perform encoding of symbols into the provided EncoderBuffer. template class RAnsSymbolEncoder { public: RAnsSymbolEncoder() : num_symbols_(0), num_expected_bits_(0), buffer_offset_(0) {} // Creates a probability table needed by the rANS library and encode it into // the provided buffer. bool Create(const uint64_t *frequencies, int num_symbols, EncoderBuffer *buffer); void StartEncoding(EncoderBuffer *buffer); void EncodeSymbol(uint32_t symbol) { ans_.rans_write(&probability_table_[symbol]); } void EndEncoding(EncoderBuffer *buffer); // rANS requires to encode the input symbols in the reverse order. static constexpr bool needs_reverse_encoding() { return true; } private: // Functor used for sorting symbol ids according to their probabilities. // The functor sorts symbol indices that index an underlying map between // symbol ids and their probabilities. We don't sort the probability table // directly, because that would require an additional indirection during the // EncodeSymbol() function. struct ProbabilityLess { explicit ProbabilityLess(const std::vector *probs) : probabilities(probs) {} bool operator()(int i, int j) const { return probabilities->at(i).prob < probabilities->at(j).prob; } const std::vector *probabilities; }; // Encodes the probability table into the output buffer. void EncodeTable(EncoderBuffer *buffer); static constexpr int max_symbols_ = 1 << max_symbol_bit_length_t; static constexpr int rans_precision_bits_ = ComputeRAnsPrecisionFromMaxSymbolBitLength(max_symbol_bit_length_t); static constexpr int rans_precision_ = 1 << rans_precision_bits_; std::vector probability_table_; // The number of symbols in the input alphabet. uint32_t num_symbols_; // Expected number of bits that is needed to encode the input. uint64_t num_expected_bits_; RAnsEncoder ans_; // Initial offset of the encoder buffer before any ans data was encoded. uint64_t buffer_offset_; }; template bool RAnsSymbolEncoder::Create( const uint64_t *frequencies, int num_symbols, EncoderBuffer *buffer) { if (num_symbols > max_symbols_) return false; // Compute the total of the input frequencies. uint64_t total_freq = 0; int max_valid_symbol = 0; for (int i = 0; i < num_symbols; ++i) { total_freq += frequencies[i]; if (frequencies[i] > 0) max_valid_symbol = i; } num_symbols = max_valid_symbol + 1; num_symbols_ = num_symbols; probability_table_.resize(num_symbols); const double total_freq_d = static_cast(total_freq); const double rans_precision_d = static_cast(rans_precision_); // Compute probabilities by rescaling the normalized frequencies into interval // [1, rans_precision - 1]. The total probability needs to be equal to // rans_precision. int total_rans_prob = 0; for (int i = 0; i < num_symbols; ++i) { const uint64_t freq = frequencies[i]; // Normalized probability. const double prob = static_cast(freq) / total_freq_d; // RAns probability in range of [1, rans_precision - 1]. uint32_t rans_prob = static_cast(prob * rans_precision_d + 0.5f); if (rans_prob == 0 && freq > 0) rans_prob = 1; probability_table_[i].prob = rans_prob; total_rans_prob += rans_prob; } // Because of rounding errors, the total precision may not be exactly accurate // and we may need to adjust the entries a little bit. if (total_rans_prob != rans_precision_) { std::vector sorted_probabilities(num_symbols); for (int i = 0; i < num_symbols; ++i) { sorted_probabilities[i] = i; } std::sort(sorted_probabilities.begin(), sorted_probabilities.end(), ProbabilityLess(&probability_table_)); if (total_rans_prob < rans_precision_) { // This happens rather infrequently, just add the extra needed precision // to the most frequent symbol. probability_table_[sorted_probabilities.back()].prob += rans_precision_ - total_rans_prob; } else { // We have over-allocated the precision, which is quite common. // Rescale the probabilities of all symbols. int32_t error = total_rans_prob - rans_precision_; while (error > 0) { const double act_total_prob_d = static_cast(total_rans_prob); const double act_rel_error_d = rans_precision_d / act_total_prob_d; for (int j = num_symbols - 1; j > 0; --j) { int symbol_id = sorted_probabilities[j]; if (probability_table_[symbol_id].prob <= 1) { if (j == num_symbols - 1) return false; // Most frequent symbol would be empty. break; } const int32_t new_prob = floor(act_rel_error_d * static_cast(probability_table_[symbol_id].prob)); int32_t fix = probability_table_[symbol_id].prob - new_prob; if (fix == 0) fix = 1; if (fix >= probability_table_[symbol_id].prob) fix = probability_table_[symbol_id].prob - 1; if (fix > error) fix = error; probability_table_[symbol_id].prob -= fix; total_rans_prob -= fix; error -= fix; if (total_rans_prob == rans_precision_) break; } } } } // Compute the cumulative probability (cdf). uint32_t total_prob = 0; for (int i = 0; i < num_symbols; ++i) { probability_table_[i].cum_prob = total_prob; total_prob += probability_table_[i].prob; } if (total_prob != rans_precision_) return false; // Estimate the number of bits needed to encode the input. // From Shannon entropy the total number of bits N is: // N = -sum{i : all_symbols}(F(i) * log2(P(i))) // where P(i) is the normalized probability of symbol i and F(i) is the // symbol's frequency in the input data. double num_bits = 0; for (int i = 0; i < num_symbols; ++i) { if (probability_table_[i].prob == 0) continue; const double norm_prob = static_cast(probability_table_[i].prob) / rans_precision_d; num_bits += static_cast(frequencies[i]) * log2(norm_prob); } num_expected_bits_ = static_cast(ceil(-num_bits)); EncodeTable(buffer); return true; } template void RAnsSymbolEncoder::EncodeTable( EncoderBuffer *buffer) { buffer->Encode(num_symbols_); // Use varint encoding for the probabilities (first two bits represent the // number of bytes used - 1). for (int i = 0; i < num_symbols_; ++i) { const uint32_t prob = probability_table_[i].prob; int num_extra_bytes = 0; if (prob >= (1 << 6)) { num_extra_bytes++; if (prob >= (1 << 14)) { num_extra_bytes++; if (prob >= (1 << 22)) { num_extra_bytes++; } } } // Encode the first byte (including the number of extra bytes). buffer->Encode(static_cast((prob << 2) | (num_extra_bytes & 3))); // Encode the extra bytes. for (int b = 0; b < num_extra_bytes; ++b) { buffer->Encode(static_cast(prob >> (8 * (b + 1) - 2))); } } } template void RAnsSymbolEncoder::StartEncoding( EncoderBuffer *buffer) { // Allocate extra storage just in case. const uint64_t required_bits = 2 * num_expected_bits_ + 32; buffer_offset_ = buffer->size(); const int64_t required_bytes = (required_bits + 7) / 8; buffer->Resize(buffer_offset_ + required_bytes + sizeof(buffer_offset_)); uint8_t *const data = reinterpret_cast(const_cast(buffer->data())); // Offset the encoding by sizeof(buffer_offset_). We will use this memory to // store the number of encoded bytes. ans_.write_init(data + buffer_offset_ + sizeof(buffer_offset_)); } template void RAnsSymbolEncoder::EndEncoding( EncoderBuffer *buffer) { const int64_t bytes_written = ans_.write_end(); // Store the size of the encoded data. memcpy(const_cast(buffer->data()) + buffer_offset_, &bytes_written, sizeof(bytes_written)); // Resize the buffer to match the number of encoded bytes. buffer->Resize(buffer_offset_ + bytes_written + sizeof(buffer_offset_)); } } // namespace draco #endif // DRACO_CORE_RANS_SYMBOL_ENCODER_H_