├── .gitignore ├── .gitmodules ├── block.h ├── recode.proto ├── Makefile ├── LICENSE ├── README.md ├── framebuffer.h ├── cabac_code.h ├── test └── arithmetic_code.cpp ├── arithmetic_code.h └── recode.cpp /.gitignore: -------------------------------------------------------------------------------- 1 | *.o 2 | recode 3 | recode.pb.cc 4 | recode.pb.h 5 | test/arithmetic_code 6 | test/arithmetic_code.dSYM 7 | data 8 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "libavcodec-hooks"] 2 | path = ffmpeg 3 | url = https://github.com/dropbox/libavcodec-hooks.git 4 | update = rebase 5 | -------------------------------------------------------------------------------- /block.h: -------------------------------------------------------------------------------- 1 | #ifndef _BLOCK_H_ 2 | #define _BLOCK_H_ 3 | 4 | struct Block { 5 | uint16_t residual[(3 * (16 + 1)) * 16]; 6 | int16_t mv_x[4][4]; 7 | int16_t mv_y[4][4]; 8 | }; 9 | struct BlockMeta{ 10 | int32_t rem_pred_mode[16]; 11 | int32_t prev_pred_mode[16]; 12 | uint8_t sub_mb_type[4]; 13 | uint8_t refIdx[4]; 14 | uint8_t cbp; 15 | uint8_t mb_type; 16 | uint8_t lumai16x16mode; 17 | uint8_t chromai8x8mode; 18 | uint8_t last_mb_qp; 19 | uint8_t luma_qp; 20 | bool is_8x8; 21 | bool coded; 22 | uint8_t num_nonzeros[(3 * (16 + 1))]; 23 | }; 24 | #endif 25 | 26 | -------------------------------------------------------------------------------- /recode.proto: -------------------------------------------------------------------------------- 1 | message Recoded { 2 | message Metadata { 3 | optional bytes version = 1; 4 | optional bytes source_commit = 2; 5 | optional bytes binary_sha256 = 3; 6 | optional int64 binary_timestamp = 4; 7 | }; 8 | optional Metadata metadata = 1; 9 | 10 | message Block { 11 | optional int64 size = 1; 12 | optional bytes literal = 2; 13 | optional bool skip_coded = 3; // 
Always true when present. 14 | optional bytes cabac = 4; 15 | optional bool length_parity = 5; // To detect presence of x264 padding. 16 | optional bytes last_byte = 6; // Last octet (zero or x264 signature bits) 17 | }; 18 | repeated Block block = 2; 19 | }; 20 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | include ffmpeg/config.mak 2 | 3 | # CXXFLAGS += -Wconversion -Wno-sign-conversion 4 | #-O3 5 | CXXFLAGS += -fsanitize=address -std=c++1y -Wall -g -I. -I./ffmpeg \ 6 | $(shell pkg-config --cflags protobuf) 7 | LDLIBS = -fsanitize=address -L./ffmpeg/libavdevice -lavdevice \ 8 | -L./ffmpeg/libavformat -lavformat \ 9 | -L./ffmpeg/libavfilter -lavfilter \ 10 | -L./ffmpeg/libavcodec -lavcodec \ 11 | -L./ffmpeg/libswresample -lswresample \ 12 | -L./ffmpeg/libswscale -lswscale \ 13 | -L./ffmpeg/libavutil -lavutil \ 14 | $(EXTRALIBS) \ 15 | $(shell pkg-config --libs protobuf) \ 16 | -lstdc++ 17 | 18 | recode: recode.o recode.pb.o ffmpeg/libavcodec/libavcodec.a 19 | 20 | recode.o: recode.cpp recode.pb.h arithmetic_code.h cabac_code.h 21 | 22 | recode.pb.cc recode.pb.h: recode.proto 23 | protoc --cpp_out=. $< 24 | 25 | test/arithmetic_code: test/arithmetic_code.o 26 | 27 | test/arithmetic_code.o: test/arithmetic_code.cpp arithmetic_code.h cabac_code.h 28 | 29 | clean: 30 | rm -f recode recode.o recode.pb.{cc,h,o} 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016 Dropbox, Inc. 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 8 | 2. 
Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 9 | 10 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 11 | 12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | avrecode: lossless re-compression of compressed video streams 2 | ============================================================= 3 | 4 | avrecode reads an already-compressed video file and writes a more compressed 5 | file. Unlike transcoding, which loses fidelity, the compression algorithm used 6 | by avrecode is reversible. The decompressed bytes exactly match the original 7 | input file. However, avrecode's compressed format can only be read by avrecode 8 | -- an avrecode-compressed file cannot be played directly by standard software. 9 | 10 | avrecode works by decoding the video stream into symbols using ffmpeg's 11 | libavcodec. 
It tries to predict each symbol as it arrives, and re-encodes the 12 | symbols to the compressed file using arithmetic coding. When avrecode's 13 | predictions are higher quality than the predictions specified by the H.264 14 | standard, it achieves a better compression ratio. 15 | 16 | 17 | Installing 18 | ---------- 19 | 20 | avrecode consists of a compression/decompression program and a fork of the 21 | libavcodec library, part of the ffmpeg project. These live in separate GitHub 22 | repositories: 23 | 24 | - https://github.com/dropbox/avrecode 25 | - https://github.com/dropbox/libavcodec-hooks 26 | 27 | The avrecode repository imports the libavcodec-hooks repository as a submodule, 28 | so the `git submodule` command is used to keep them in sync. 29 | 30 | Download the source: 31 | 32 | ``` 33 | git clone https://github.com/dropbox/avrecode 34 | cd avrecode 35 | git submodule update --init 36 | ``` 37 | 38 | Build and test: 39 | 40 | ``` 41 | brew install protobuf 42 | cd ffmpeg 43 | ./configure 44 | make 45 | cd .. 46 | make 47 | ./recode roundtrip data/GOPR4542.MP4 48 | ``` 49 | 50 | Warning 51 | ------- 52 | This is an experimental test bed for compression research: use on trusted inputs only. 53 | This tool does not validate input. 54 | 55 | License 56 | ------- 57 | 58 | avrecode is released under the BSD 3-clause license. See the LICENSE file for details. 59 | The required libavcodec-hooks patch to ffmpeg is licensed under the LGPL. 60 | 61 | 62 | Contributing 63 | ------------ 64 | 65 | avrecode was originally written by Chris Lesniewski during Dropbox Hack Week 66 | January 2016. It is a redesign of the first version written by Daniel Horn, 67 | Patrick Horn, Chris Lesniewski, and others during Dropbox Hack Week 2015.
68 | We welcome external contributions, but ask that contributors accept our 69 | Contributor License Agreement to grant us a license to distribute the code: 70 | 71 | https://opensource.dropbox.com/cla/ 72 | -------------------------------------------------------------------------------- /framebuffer.h: -------------------------------------------------------------------------------- 1 | #ifndef _FRAMEBUFFER_H_ 2 | #define _FRAMEBUFFER_H_ 3 | #include "block.h" 4 | 5 | class FrameBuffer { 6 | Block *image_; 7 | uint32_t width_; 8 | uint32_t height_; 9 | uint32_t nblocks_; 10 | BlockMeta *meta_; 11 | uint8_t *storage_; 12 | uint8_t *meta_storage_; 13 | uint8_t *mb_types_; 14 | uint16_t *cbp_; 15 | int frame_num_; 16 | FrameBuffer(const FrameBuffer &other) = delete; 17 | FrameBuffer& operator=(const FrameBuffer&other) = delete; 18 | void destroy() { 19 | if (width_ && height_) { 20 | free(storage_); 21 | free(meta_storage_); 22 | } 23 | memset(this, 0, sizeof(*this)); 24 | } 25 | public: 26 | FrameBuffer() { 27 | image_ = nullptr; 28 | storage_ = nullptr; 29 | width_ = 0; 30 | height_ = 0; 31 | nblocks_ = 0; 32 | } 33 | void bzero() { 34 | memset(meta_, 0, sizeof(BlockMeta) * nblocks_); 35 | memset(image_, 0, sizeof(Block) * nblocks_); 36 | } 37 | void set_frame_num(int frame_num) { 38 | frame_num_ = frame_num; 39 | } 40 | bool is_same_frame(int frame_num) const { 41 | return frame_num_ == frame_num && width_ != 0 && height_ != 0; 42 | } 43 | uint32_t width()const { 44 | return width_; 45 | } 46 | uint32_t height()const { 47 | return height_; 48 | } 49 | void init(uint32_t width, uint32_t height, uint32_t nblocks) { 50 | height_ = height; 51 | width_ = width; 52 | nblocks_ = width * height; 53 | storage_ = (uint8_t*)malloc(nblocks_ * sizeof(Block) + 31); 54 | meta_storage_ = (uint8_t*)malloc(nblocks_ * sizeof(BlockMeta) + 31); 55 | size_t offset = storage_ - (uint8_t *)nullptr; 56 | if (offset & 32) { 57 | image_ = (Block*)(storage_ + 32 - (offset &31)); 58 | } else 
{ // already aligned 59 | image_ = (Block*)storage_; 60 | } 61 | offset = meta_storage_ - (uint8_t *)nullptr; 62 | if (offset & 32) { 63 | meta_ = (BlockMeta*)(meta_storage_ + 32 - (offset &31)); 64 | } else { // already aligned 65 | meta_ = (BlockMeta*)meta_storage_; 66 | } 67 | bzero(); 68 | } 69 | ~FrameBuffer() { 70 | destroy(); 71 | } 72 | size_t block_allocated() const { 73 | return nblocks_; 74 | } 75 | Block& at(uint32_t x, uint32_t y) { 76 | return image_[x + y * width_]; 77 | } 78 | const Block& at(uint32_t x, uint32_t y) const{ 79 | return image_[x + y * width_]; 80 | } 81 | BlockMeta& meta_at(uint32_t x, uint32_t y) { 82 | return meta_[x + y * width_]; 83 | } 84 | const BlockMeta& meta_at(uint32_t x, uint32_t y) const{ 85 | return meta_[x + y * width_]; 86 | } 87 | }; 88 | #endif 89 | -------------------------------------------------------------------------------- /cabac_code.h: -------------------------------------------------------------------------------- 1 | // 2 | // Arithmetic coding for H.264's CABAC encoding. 3 | // 4 | 5 | #pragma once 6 | 7 | #include "arithmetic_code.h" 8 | 9 | extern "C" { 10 | #include "libavcodec/cabac.h" 11 | static const uint8_t * const ff_h264_lps_range = ff_h264_cabac_tables + H264_LPS_RANGE_OFFSET; 12 | static const uint8_t * const ff_h264_mlps_state = ff_h264_cabac_tables + H264_MLPS_STATE_OFFSET; 13 | } 14 | 15 | 16 | struct cabac { 17 | // Word size for encoder/decoder state. Reasonable values: uint64_t, uint32_t. 18 | typedef uint32_t fixed_point; 19 | // Word size for compressed data. Reasonable values: uint16_t, uint8_t. 20 | typedef uint16_t compressed_digit; 21 | // min_range must be at least 0x200 so that range/2 never rounds in put_bypass. 22 | static constexpr int min_range = 0x200; 23 | 24 | typedef arithmetic_code cabac_arithmetic_code; 25 | 26 | template 27 | class encoder { 28 | public: 29 | // Initial range is set so that (range >> normalize) == 0x1FE as required by CABAC spec. 
30 | explicit encoder(OutputIterator out) : e(out, (cabac_arithmetic_code::fixed_one/0x200)*0x1FE) {} 31 | 32 | // Translate CABAC tables into generic arithmetic coding. 33 | size_t put(int symbol, uint8_t* state) { 34 | bool is_less_probable_symbol = (symbol != ((*state) & 1)); 35 | size_t retval = e.put(is_less_probable_symbol, [state](fixed_point range) { 36 | // Find the normalizer such that range >> normalize is between 0x100 and 0x200. 37 | int normalize = log2(range / 0x100); 38 | // Use the most significant two bits of range (other than the leading 1) as an index into the table. 39 | int range_approx = int(range >> (normalize-1)); 40 | fixed_point range_of_less_probable_symbol = ff_h264_lps_range[(range_approx & 0x180) + *state]; 41 | return range_of_less_probable_symbol << normalize; 42 | }); 43 | if (is_less_probable_symbol) { 44 | *state = ff_h264_mlps_state[127 - *state]; 45 | } else { 46 | *state = ff_h264_mlps_state[128 + *state]; 47 | } 48 | return retval; 49 | } 50 | 51 | // Simple implementation: put_bypass assumes a symbol probability of exactly 1/2. 52 | size_t put_bypass(int symbol) { 53 | return e.put(symbol, [](fixed_point range) { return range/2; }); 54 | } 55 | 56 | // The end of stream symbol is always assumed to have probability ~2/256. 
57 | size_t put_terminate(int end_of_stream_symbol) { 58 | size_t retval = e.put(end_of_stream_symbol, [](fixed_point range) { 59 | int normalize = log2(range / 0x100); 60 | return fixed_point(2) << normalize; 61 | }); 62 | 63 | if (end_of_stream_symbol) { 64 | e.finish(); 65 | } 66 | return retval; 67 | } 68 | 69 | private: 70 | static int log2(uint64_t x) { 71 | int i = 0; 72 | if (x >> 32) { x >>= 32; i += 32; } 73 | if (x >> 16) { x >>= 16; i += 16; } 74 | if (x >> 8) { x >>= 8; i += 8; } 75 | if (x >> 4) { x >>= 4; i += 4; } 76 | if (x >> 2) { x >>= 2; i += 2; } 77 | if (x >> 1) { x >>= 1; i += 1; } 78 | return i; 79 | } 80 | 81 | cabac_arithmetic_code::encoder e; 82 | }; 83 | 84 | class decoder { 85 | }; 86 | }; 87 | -------------------------------------------------------------------------------- /test/arithmetic_code.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "arithmetic_code.h" 6 | #include "cabac_code.h" 7 | 8 | extern "C" { 9 | #include "libavcodec/cabac.h" 10 | } 11 | 12 | 13 | int main(int argc, char* argv[]) { 14 | #if 0 15 | // Testing a particular input that triggered a CABAC encoder bug. 
16 | std::vector states = {15, 17, 106, 28, 16, 0, 10, 26, 33, 22, 35, 58, 44, 0, 0, 1, 3, 5}; 17 | std::vector bits = { 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1}; 18 | 19 | std::vector out; 20 | cabac::encoder>> encoder(std::back_inserter(out)); 21 | uint8_t state; 22 | 23 | state = states[0]; 24 | encoder.put(bits[0], &state); 25 | encoder.put_terminate(false); 26 | for (int i = 1; i < bits.size(); i++) { 27 | state = states[i]; 28 | encoder.put(bits[i], &state); 29 | } 30 | for (int i = 0; i < 16; i++) { 31 | state = 0; 32 | encoder.put(0, &state); 33 | } 34 | encoder.put_terminate(true); 35 | 36 | std::cout << out.size() << " " << (unsigned int)out[0] << " " << (unsigned int)out[1] << std::endl; 37 | CABACContext ctx; 38 | ff_init_cabac_decoder(&ctx, &out[0], out.size(), nullptr); 39 | state = states[0]; 40 | ff_get_cabac(&ctx, &state); 41 | ff_get_cabac_terminate(&ctx); 42 | for (int i = 1; i < bits.size(); i++) { 43 | state = states[i]; 44 | std::cout << (ff_get_cabac(&ctx, &state)==bits[i]) << std::endl; 45 | } 46 | 47 | return 0; 48 | #else 49 | std::srand(time(nullptr)); 50 | 51 | std::vector probabilities; 52 | for (int i = 0; i < 5; i++) { 53 | probabilities.push_back(std::rand() % 100); 54 | } 55 | 56 | std::vector bits; 57 | std::vector contexts; 58 | for (int i = 0; i < std::stoi(argv[1]); i++) { 59 | int context = std::rand() % probabilities.size(); 60 | contexts.push_back(context); 61 | bits.push_back((std::rand() % 100) > probabilities[context]); 62 | } 63 | 64 | std::vector states(0x400); 65 | std::vector out; 66 | #if 0 67 | cabac::encoder>> encoder(std::back_inserter(out)); 68 | 69 | for (int i = 0; i < bits.size(); i++) { 70 | encoder.put(bits[i], &states[contexts[i]]); 71 | } 72 | encoder.put_terminate(true); 73 | 74 | std::cout << "compressed size: " << out.size() << std::endl; 75 | 76 | states.resize(0); 77 | states.resize(0x400); 78 | 79 | CABACContext ctx; 80 | ff_init_cabac_decoder(&ctx, &out[0], out.size(), nullptr); 81 | for 
(int i = 0; i < bits.size(); i++) { 82 | int bit = ff_get_cabac(&ctx, &states[contexts[i]]); 83 | if (bit != bits[i]) { 84 | std::cerr << "mismatch at bit: " << i << ", " << bit << " != " << bits[i] << std::endl; 85 | return 1; 86 | } 87 | } 88 | if (!ff_get_cabac_terminate(&ctx)) { 89 | std::cerr << "mismatch at terminate" << std::endl; 90 | } 91 | return 0; 92 | #else 93 | typedef arithmetic_code code; 94 | auto encoder = make_encoder(&out); 95 | 96 | for (int i = 0; i < bits.size(); i++) { 97 | encoder.put(bits[i], [](uint64_t range){ return range/2; }); 98 | } 99 | encoder.finish(); 100 | 101 | std::cout << "compressed size: " << out.size() << std::endl; 102 | 103 | auto decoder = make_decoder(out); 104 | for (int i = 0; i < bits.size(); i++) { 105 | int bit = decoder.get([](uint64_t range){ return range/2; }); 106 | if (bit != bits[i]) { 107 | std::cerr << "mismatch at bit: " << i << ", " << bit << " != " << bits[i] << std::endl; 108 | return 1; 109 | } 110 | } 111 | return 0; 112 | #endif 113 | #endif 114 | } 115 | -------------------------------------------------------------------------------- /arithmetic_code.h: -------------------------------------------------------------------------------- 1 | // 2 | // Generic arithmetic coding. Used both for recoded encoding/decoding and for 3 | // CABAC re-encoding. 4 | // 5 | // Some notes on the data representations used by the encoder and decoder. 6 | // Uncompressed data: 7 | // Symbols: b_1 ... b_n \in {0,1} . 8 | // Probabilities: p_1 ... p_n \in [0,1], where p_i estimates the probability that b_i=1. 9 | // Compressed data: 10 | // Arithmetic coding represents a compressed stream of symbols as an 11 | // arbitrary-precision number C \in [0,1] . 12 | // If the compressed digits in base M are c_k \in {0..M-1}, then 13 | // C = \sum_{k=1}^K c_k M^{-k} . 
14 | // Arithmetic coding uses the probabilities p_i to link the symbols b_i with 15 | // the compressed digits c_k: 16 | // C_i = (1-p_i) b_i + p_i C_{i+1} (1-b_i) 17 | // C_i \in [0,1] 18 | // C_1 = C = \sum_{k=1}^K c_k M^{-k} 19 | // C_n is an arbitrary value in [0,1] (normally used to encode a stop bit). 20 | // 21 | 22 | #pragma once 23 | 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | 30 | 31 | template 32 | struct arithmetic_code { 33 | private: 34 | static_assert(std::numeric_limits::is_exact, "integer types only"); 35 | static_assert(!std::numeric_limits::is_signed, "unsigned integer types only"); 36 | 37 | template 38 | static constexpr bool is_power_of_2(T x) { 39 | static_assert(std::numeric_limits::is_exact, "expected integer type"); 40 | return (x != 0) && (x & (x-1)) == 0; 41 | } 42 | template 43 | static constexpr FixedPoint digit_base_for() { 44 | static_assert(std::numeric_limits::is_exact, "integer types only"); 45 | static_assert(!std::numeric_limits::is_signed, "unsigned integer types only"); 46 | static_assert(sizeof(FixedPoint) > sizeof(Digit), "digit must be smaller than fixed point"); 47 | static_assert(sizeof(FixedPoint) % sizeof(Digit) == 0, "digit must divide fixed point evenly"); 48 | static_assert(is_power_of_2(FixedPoint(std::numeric_limits::max()) + 1), "expected power of 2"); 49 | return FixedPoint(std::numeric_limits::max()) + 1; 50 | } 51 | 52 | public: 53 | // The representation of 1.0 in fixed-point, e.g. 0x80000000 for uint32_t. 54 | static constexpr FixedPoint fixed_one = 55 | std::numeric_limits::max()/2 + 1; 56 | // The base for compressed digit outputs, e.g. 0x10000 for uint16_t. 57 | static constexpr FixedPoint digit_base = digit_base_for(); 58 | // The minimum precision for probability estimates, e.g. 0x100 for 8-bit 59 | // probabilities as in CABAC. There is a space-time tradeoff: less precision 60 | // means poorer compression, but more precision causes overflow digits more often. 
61 | static constexpr FixedPoint min_range = 62 | MinRange > 0 ? MinRange : (fixed_one/digit_base) / 16; 63 | // The maximum range to reach when normalizing. 64 | static constexpr FixedPoint max_range = fixed_one; 65 | 66 | static_assert(is_power_of_2(fixed_one), "expected power of 2"); 67 | static_assert(is_power_of_2(min_range), "expected power of 2"); 68 | static_assert((fixed_one/digit_base)*digit_base == fixed_one, 69 | "expected digit_base to divide fixed_one"); 70 | static_assert(min_range > 1, "min_range too small"); 71 | static_assert(min_range < fixed_one/digit_base, "min_range too large"); 72 | 73 | // The encoder object takes an output iterator (e.g. to vector or ostream) to 74 | // emit compressed digits. 75 | // In addition to uncompressed data and compressed digits, the intermediate state is: 76 | // Maximum R (any positive number, typically 2^k) 77 | // Lower and upper bounds x,y \in [0,R) 78 | // Range r = y-x \in [0,R) 79 | // Representation invariant: 80 | // C = \sum_{k=1}^{K_i} c_k M^{-k} + (x_i + r_i C_i) M^{-K_i}/R_i 81 | // Base case: K_1 = 0, x_1 = 0, r_1 = R_1 82 | // In the base case i=1, K_1=0: C=C_1 is represented as a series of future decisions b_i. 83 | // In the final case i=n, K_n=K: C is represented as a string of compressed digits. 84 | // The various encoding methods modify K, x, r, R while keeping C fixed. 
85 | template ::value_type> 87 | class encoder { 88 | static_assert(std::numeric_limits::is_exact, 89 | "integer types only"); 90 | static_assert(!std::numeric_limits::is_signed, 91 | "unsigned integer types only"); 92 | static_assert(sizeof(CompressedDigit) % sizeof(OutputDigit) == 0, 93 | "size of compressed digit must be a multiple of size of output digit"); 94 | 95 | public: 96 | explicit encoder(OutputIterator out) 97 | : encoder(out, fixed_one) {} 98 | encoder(OutputIterator out, FixedPoint initial_range) 99 | : bytes_emitted(0), out(out), low(0), range(initial_range) {} 100 | ~encoder() { finish(); } 101 | size_t get_bytes_emitted()const { 102 | return bytes_emitted; 103 | } 104 | // Symbol is int instead of bool because additional versions of `put()` could 105 | // accept more than two symbols, e.g. one could call `put(2, p1, p2, p3)`. 106 | size_t put(int symbol, std::function probability_of_1) { 107 | FixedPoint range_of_1 = probability_of_1(range); 108 | FixedPoint range_of_0 = range - range_of_1; 109 | if (symbol != 0) { 110 | low += range_of_0; 111 | range = range_of_1; 112 | } else { 113 | range = range_of_0; 114 | } 115 | if (range < min_range) { 116 | if (range == 0) { 117 | throw std::runtime_error("Encoder error: emitted a zero-probability symbol."); 118 | } 119 | size_t emitted_before = get_bytes_emitted(); 120 | while (range < max_range/digit_base) { 121 | renormalize_and_emit_digit(); 122 | } 123 | return get_bytes_emitted() - emitted_before; 124 | } 125 | return 0; 126 | } 127 | 128 | void finish() { 129 | // Find largest stop bit 2^k < range, and x such that 2^k divides x, 130 | // 2^{k+1} doesn't divide x, and x is in [low, low+range). 
131 | for (FixedPoint stop_bit = (fixed_one >> 1); stop_bit > 0; stop_bit >>= 1) { 132 | FixedPoint x = (low | stop_bit) & ~(stop_bit - 1); 133 | if (stop_bit < range && low <= x && x < low + range) { 134 | low = x; 135 | break; 136 | } 137 | } 138 | 139 | while (low != 0) { 140 | range = 1; 141 | renormalize_and_emit_digit(); 142 | } 143 | range = 0; // mark complete 144 | } 145 | 146 | private: 147 | template 148 | void renormalize_and_emit_digit() { 149 | static constexpr FixedPoint base = digit_base_for(); 150 | static constexpr FixedPoint most_significant_digit = fixed_one / base; 151 | static_assert(is_power_of_2(most_significant_digit), "expected power of 2"); 152 | 153 | // Check for a carry bit, and cascade from lowest overflow digit to highest. 154 | if (low >= fixed_one) { 155 | for (int i = overflow.size()-1; i >= 0; i--) { 156 | if (++overflow[i] != 0) break; 157 | } 158 | low -= fixed_one; 159 | } 160 | assert(low < fixed_one); 161 | 162 | // Compare the minimum and maximum possible values of the top digit. 163 | // If different, defer emitting the digit until we're sure we won't have to carry. 164 | Digit digit = Digit(low / most_significant_digit); 165 | if (digit != Digit((low + range - 1) / most_significant_digit)) { 166 | assert(range < most_significant_digit); 167 | overflow.push_back(digit); 168 | } else { 169 | for (CompressedDigit overflow_digit : overflow) { 170 | emit_digit(overflow_digit); 171 | } 172 | overflow.clear(); 173 | emit_digit(digit); 174 | } 175 | 176 | // Subtract away the emitted/overflowed digit and renormalize. 177 | low -= digit * most_significant_digit; 178 | low *= base; 179 | range *= base; 180 | } 181 | 182 | // Emit a CompressedDigit as one or more OutputDigits. Loop should be 183 | // unrolled by the compiler. 
184 | template 185 | void emit_digit(Digit digit) { 186 | for (int i = sizeof(Digit)-sizeof(OutputDigit); i >= 0; i -= sizeof(OutputDigit)) { 187 | *out++ = OutputDigit(digit >> (8*i)); 188 | } 189 | bytes_emitted += sizeof(digit); 190 | } 191 | size_t bytes_emitted; 192 | // Output digits are emitted to this iterator as they are produced. 193 | OutputIterator out; 194 | // The lower bound x, initialized to 0. (When overflow.size() > 0, low is 195 | // the fractional digits of x/R_0.) 196 | FixedPoint low; 197 | // The range r, which starts as fixed-point 1.0. 198 | FixedPoint range; 199 | // High digits of x. If overflow.size() = s, then R = R_0 M^s (where R_0 = fixed_one). 200 | std::vector overflow; 201 | }; 202 | 203 | // The decoder object takes an input iterator (e.g. from vector or istream) 204 | // to read compressed digits. 205 | // In addition to uncompressed data and compressed digits, the intermediate state is: 206 | // TODO(ctl) document the state, representation invariant, and decoding transitions. 207 | template ::value_type> 209 | class decoder { 210 | static_assert(std::numeric_limits::is_exact, 211 | "integer types only"); 212 | static_assert(!std::numeric_limits::is_signed, 213 | "unsigned integer types only"); 214 | static_assert(sizeof(CompressedDigit) % sizeof(InputDigit) == 0, 215 | "size of compressed digit must be a multiple of size of input digit"); 216 | 217 | public: 218 | explicit decoder(InputIterator in, InputIterator end = InputIterator()) 219 | : decoder(in, end, fixed_one) {} 220 | decoder(InputIterator in, InputIterator end, FixedPoint initial_range) 221 | : in(in), end(end) { 222 | // Initialize the decoder state by reading in bits until range ~ initial_range. 
223 | next_digit = consume_digit_aligned(); 224 | low = next_digit / digit_alignment; 225 | range = digit_base / digit_alignment; 226 | while (range < initial_range) { 227 | renormalize_and_consume_digit(); 228 | } 229 | assert(range == initial_range); // Should be true if we set digit_alignment correctly. 230 | } 231 | 232 | int get(std::function probability_of_1) { 233 | FixedPoint range_of_1 = probability_of_1(range); 234 | FixedPoint range_of_0 = range - range_of_1; 235 | int symbol = (low >= range_of_0); 236 | if (symbol != 0) { 237 | low -= range_of_0; 238 | range = range_of_1; 239 | } else { 240 | range = range_of_0; 241 | } 242 | if (range < min_range) { 243 | while (range < max_range/digit_base) { 244 | renormalize_and_consume_digit(); 245 | } 246 | } 247 | return symbol; 248 | } 249 | 250 | private: 251 | static constexpr CompressedDigit digit_alignment = 252 | std::numeric_limits::max()/fixed_one + 1; 253 | static_assert(is_power_of_2(digit_alignment), ""); 254 | static_assert((fixed_one/digit_base)*digit_alignment == (std::numeric_limits::max()/digit_base) + 1, 255 | "expected fixed_one > max/digit_base"); 256 | static_assert(is_power_of_2(digit_base/digit_alignment), 257 | "expected digit_base > digit_alignment"); 258 | 259 | void renormalize_and_consume_digit() { 260 | assert(low < fixed_one/digit_base); 261 | 262 | CompressedDigit digit = consume_digit(); 263 | low = low * digit_base + digit; 264 | range *= digit_base; 265 | } 266 | 267 | // Consume a CompressedDigit. Because our initialization is not 268 | // digit-aligned, we have to bit-align the reads here. 269 | CompressedDigit consume_digit() { 270 | CompressedDigit in_digit = consume_digit_aligned(); 271 | CompressedDigit digit = ((next_digit * (digit_base/digit_alignment)) | 272 | (in_digit / digit_alignment)); 273 | next_digit = in_digit; 274 | return digit; 275 | } 276 | 277 | // Consume a CompressedDigit as one or more InputDigits. Loop should be 278 | // unrolled by the compiler. 
279 | CompressedDigit consume_digit_aligned() { 280 | CompressedDigit digit = 0; 281 | for (int i = sizeof(CompressedDigit)-sizeof(InputDigit); i >= 0; i -= sizeof(InputDigit)) { 282 | digit *= digit_base_for(); 283 | if (in != end) { 284 | digit |= CompressedDigit(InputDigit(*in++)); 285 | } 286 | } 287 | return digit; 288 | } 289 | 290 | // Input digits are read from this iterator. 291 | InputIterator in, end; 292 | // The last digit read from the input - the lower bits are still to be used. 293 | CompressedDigit next_digit; 294 | // The offset z from the lower bound. 295 | FixedPoint low; 296 | // The range r, which is initialized to fixed-point 1.0. 297 | FixedPoint range; 298 | }; 299 | }; 300 | 301 | 302 | template , 303 | typename OutputContainer> 304 | typename Coder::template encoder, 305 | typename OutputContainer::value_type> 306 | make_encoder(OutputContainer* container) { 307 | auto it = std::back_inserter(*container); 308 | typedef typename OutputContainer::value_type OutputDigit; 309 | return typename Coder::template encoder(it); 310 | } 311 | 312 | template , 313 | typename InputContainer> 314 | typename Coder::template decoder 316 | make_decoder(const InputContainer& container) { 317 | auto begin = std::begin(container), end = std::end(container); 318 | typedef typename InputContainer::value_type InputDigit; 319 | return typename Coder::template decoder(begin, end); 320 | } 321 | -------------------------------------------------------------------------------- /recode.cpp: -------------------------------------------------------------------------------- 1 | /* -*-mode:c++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 4 -*- */ 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | extern "C" { 12 | #include "libavcodec/avcodec.h" 13 | #include "libavcodec/cabac.h" 14 | #include "libavcodec/coding_hooks.h" 15 | #include "libavformat/avformat.h" 16 | #include "libavformat/avio.h" 17 | 
#include "libavutil/error.h" 18 | #include "libavutil/file.h" 19 | } 20 | 21 | #include "arithmetic_code.h" 22 | #include "cabac_code.h" 23 | #include "recode.pb.h" 24 | #include "framebuffer.h" 25 | 26 | // CABAC blocks smaller than this will be skipped. 27 | const int SURROGATE_MARKER_BYTES = 8; 28 | //#define DO_NEIGHBOR_LOGGING 29 | #ifdef DO_NEIGHBOR_LOGGING 30 | #define LOG_NEIGHBORS printf 31 | #else 32 | #define LOG_NEIGHBORS(...) 33 | #endif 34 | template 35 | std::unique_ptr> av_unique_ptr(T* p, const std::function& deleter) { 36 | if (p == nullptr) { 37 | throw std::bad_alloc(); 38 | } 39 | return std::unique_ptr>(p, deleter); 40 | } 41 | template 42 | std::unique_ptr> av_unique_ptr(T* p, void (*deleter)(T**)) { 43 | return av_unique_ptr(p, [deleter](T*& to_delete){ deleter(&to_delete); }); 44 | } 45 | template 46 | std::unique_ptr> av_unique_ptr(T* p, void (*deleter)(void*) = av_free) { 47 | return av_unique_ptr(p, [deleter](T*& to_delete){ deleter(to_delete); }); 48 | } 49 | 50 | template > 51 | struct defer { 52 | T to_defer; 53 | explicit defer(const T& to_defer) : to_defer(to_defer) {} 54 | defer(const defer&) = delete; 55 | ~defer() { to_defer(); } 56 | }; 57 | 58 | 59 | int av_check(int return_value, int expected_error = 0, const std::string& message = "") { 60 | if (return_value >= 0 || return_value == expected_error) { 61 | return return_value; 62 | } else { 63 | char err[AV_ERROR_MAX_STRING_SIZE]; 64 | av_make_error_string(err, AV_ERROR_MAX_STRING_SIZE, return_value); 65 | throw std::runtime_error(message + ": " + err); 66 | } 67 | } 68 | bool av_check(int return_value, const std::string& message = "") { 69 | return av_check(return_value, 0, message); 70 | } 71 | 72 | 73 | // Sets up a libavcodec decoder with I/O and decoding hooks. 
74 | template 75 | class av_decoder { 76 | public: 77 | av_decoder(Driver *driver, const std::string& input_filename) : driver(driver) { 78 | const size_t avio_ctx_buffer_size = 1024*1024; 79 | uint8_t *avio_ctx_buffer = static_cast( av_malloc(avio_ctx_buffer_size) ); 80 | 81 | format_ctx = avformat_alloc_context(); 82 | if (avio_ctx_buffer == nullptr || format_ctx == nullptr) throw std::bad_alloc(); 83 | format_ctx->pb = avio_alloc_context( 84 | avio_ctx_buffer, // input buffer 85 | avio_ctx_buffer_size, // input buffer size 86 | false, // stream is not writable 87 | this, // first argument for read_packet() 88 | read_packet, // read callback 89 | nullptr, // write_packet() 90 | nullptr); // seek() 91 | 92 | if (avformat_open_input(&format_ctx, input_filename.c_str(), nullptr, nullptr) < 0) { 93 | throw std::invalid_argument("Failed to initialize decoding context: " + input_filename); 94 | } 95 | } 96 | ~av_decoder() { 97 | for (size_t i = 0; i < format_ctx->nb_streams; i++) { 98 | avcodec_close(format_ctx->streams[i]->codec); 99 | } 100 | av_freep(&format_ctx->pb->buffer); // May no longer be the same buffer we initially malloced. 101 | av_freep(&format_ctx->pb); 102 | avformat_close_input(&format_ctx); 103 | } 104 | 105 | // Read enough frames to display stream diagnostics. Only used by compressor, 106 | // because hooks are not yet set. Reads from already in-memory blocks. 107 | void dump_stream_info() { 108 | av_check( avformat_find_stream_info(format_ctx, nullptr), 109 | "Invalid input stream information" ); 110 | av_dump_format(format_ctx, 0, format_ctx->filename, 0); 111 | } 112 | 113 | // Decode all video frames in the file in single-threaded mode, calling the driver's hooks. 114 | void decode_video() { 115 | auto frame = av_unique_ptr(av_frame_alloc(), av_frame_free); 116 | AVPacket packet; 117 | // TODO(ctl) add better diagnostics to error results. 
118 | while (!av_check( av_read_frame(format_ctx, &packet), AVERROR_EOF, "Failed to read frame" )) { 119 | AVCodecContext *codec = format_ctx->streams[packet.stream_index]->codec; 120 | if (codec->codec_type == AVMEDIA_TYPE_VIDEO) { 121 | if (!avcodec_is_open(codec)) { 122 | codec->thread_count = 1; 123 | codec->hooks = &hooks; 124 | av_check( avcodec_open2(codec, avcodec_find_decoder(codec->codec_id), nullptr), 125 | "Failed to open decoder for stream " + std::to_string(packet.stream_index) ); 126 | } 127 | 128 | int got_frame = 0; 129 | av_check( avcodec_decode_video2(codec, frame.get(), &got_frame, &packet), 130 | "Failed to decode video frame" ); 131 | } 132 | av_packet_unref(&packet); 133 | } 134 | } 135 | 136 | private: 137 | // Hook stubs - wrap driver into opaque pointers. 138 | static int read_packet(void *opaque, uint8_t *buffer_out, int size) { 139 | av_decoder *self = static_cast(opaque); 140 | return self->driver->read_packet(buffer_out, size); 141 | } 142 | struct cabac { 143 | static void* init_decoder(void *opaque, CABACContext *ctx, const uint8_t *buf, int size) { 144 | av_decoder *self = static_cast(opaque); 145 | auto *cabac_decoder = new typename Driver::cabac_decoder(self->driver, ctx, buf, size); 146 | self->cabac_contexts[ctx].reset(cabac_decoder); 147 | return cabac_decoder; 148 | } 149 | static int get(void *opaque, uint8_t *state) { 150 | auto *self = static_cast(opaque); 151 | return self->get(state); 152 | } 153 | static int get_bypass(void *opaque) { 154 | auto *self = static_cast(opaque); 155 | return self->get_bypass(); 156 | } 157 | static int get_terminate(void *opaque) { 158 | auto *self = static_cast(opaque); 159 | return self->get_terminate(); 160 | } 161 | static const uint8_t* skip_bytes(void *opaque, int n) { 162 | throw std::runtime_error("Not implemented: CABAC decoder doesn't use skip_bytes."); 163 | } 164 | }; 165 | struct model_hooks { 166 | static void frame_spec(void *opaque, int frame_num, int mb_width, int mb_height) 
{ 167 | auto *self = static_cast(opaque)->driver->get_model(); 168 | self->update_frame_spec(frame_num, mb_width, mb_height); 169 | } 170 | static void mb_xy(void *opaque, int x, int y) { 171 | auto *self = static_cast(opaque)->driver->get_model(); 172 | self->mb_coord.mb_x = x; 173 | self->mb_coord.mb_y = y; 174 | } 175 | static void begin_sub_mb(void *opaque, int cat, int scan8index, int max_coeff, int is_dc, int chroma422) { 176 | auto *self = static_cast(opaque)->driver->get_model(); 177 | self->sub_mb_cat = cat; 178 | self->mb_coord.scan8_index = scan8index; 179 | self->sub_mb_size = max_coeff; 180 | self->sub_mb_is_dc = is_dc; 181 | self->sub_mb_chroma422 = chroma422; 182 | } 183 | static void end_sub_mb(void *opaque, int cat, int scan8index, int max_coeff, int is_dc, int chroma422) { 184 | auto *self = static_cast(opaque)->driver->get_model(); 185 | assert(self->sub_mb_cat == cat); 186 | assert(self->mb_coord.scan8_index == scan8index); 187 | assert(self->sub_mb_size == max_coeff); 188 | assert(self->sub_mb_is_dc == is_dc); 189 | assert(self->sub_mb_chroma422 == chroma422); 190 | self->sub_mb_cat = -1; 191 | self->mb_coord.scan8_index = -1; 192 | self->sub_mb_size = -1; 193 | self->sub_mb_is_dc = 0; 194 | self->sub_mb_chroma422 = 0; 195 | } 196 | static void begin_coding_type(void *opaque, CodingType ct, 197 | int zigzag_index, int param0, int param1) { 198 | auto &cabac_contexts = static_cast(opaque)->cabac_contexts; 199 | assert(cabac_contexts.size() == 1); 200 | typename Driver::cabac_decoder*self = cabac_contexts.begin()->second.get(); 201 | self->begin_coding_type(ct, zigzag_index, param0, param1); 202 | } 203 | static void end_coding_type(void *opaque, CodingType ct) { 204 | auto &cabac_contexts = static_cast(opaque)->cabac_contexts; 205 | assert(cabac_contexts.size() == 1); 206 | typename Driver::cabac_decoder*self = cabac_contexts.begin()->second.get(); 207 | self->end_coding_type(ct); 208 | } 209 | }; 210 | Driver *driver; 211 | AVFormatContext 
*format_ctx; 212 | AVCodecHooks hooks = { this, { 213 | cabac::init_decoder, 214 | cabac::get, 215 | cabac::get_bypass, 216 | cabac::get_terminate, 217 | cabac::skip_bytes, 218 | }, 219 | { 220 | model_hooks::frame_spec, 221 | model_hooks::mb_xy, 222 | model_hooks::begin_sub_mb, 223 | model_hooks::end_sub_mb, 224 | model_hooks::begin_coding_type, 225 | model_hooks::end_coding_type, 226 | 227 | }, 228 | }; 229 | std::map> cabac_contexts; 230 | }; 231 | 232 | 233 | struct r_scan8 { 234 | uint16_t scan8_index; 235 | bool neighbor_left; 236 | bool neighbor_up; 237 | bool is_invalid() const { 238 | return scan8_index == 0 && neighbor_left && neighbor_up; 239 | } 240 | static constexpr r_scan8 inv() { 241 | return {0, true, true}; 242 | } 243 | }; 244 | /* Scan8 organization: 245 | * 0 1 2 3 4 5 6 7 246 | * 0 DY y y y y y 247 | * 1 y Y Y Y Y 248 | * 2 y Y Y Y Y 249 | * 3 y Y Y Y Y 250 | * 4 du y Y Y Y Y 251 | * 5 DU u u u u u 252 | * 6 u U U U U 253 | * 7 u U U U U 254 | * 8 u U U U U 255 | * 9 dv u U U U U 256 | * 10 DV v v v v v 257 | * 11 v V V V V 258 | * 12 v V V V V 259 | * 13 v V V V V 260 | * 14 v V V V V 261 | * DY/DU/DV are for luma/chroma DC. 
262 | */ 263 | constexpr uint8_t scan_8[16 * 3 + 3] = { 264 | 4 + 1 * 8, 5 + 1 * 8, 4 + 2 * 8, 5 + 2 * 8, 265 | 6 + 1 * 8, 7 + 1 * 8, 6 + 2 * 8, 7 + 2 * 8, 266 | 4 + 3 * 8, 5 + 3 * 8, 4 + 4 * 8, 5 + 4 * 8, 267 | 6 + 3 * 8, 7 + 3 * 8, 6 + 4 * 8, 7 + 4 * 8, 268 | 4 + 6 * 8, 5 + 6 * 8, 4 + 7 * 8, 5 + 7 * 8, 269 | 6 + 6 * 8, 7 + 6 * 8, 6 + 7 * 8, 7 + 7 * 8, 270 | 4 + 8 * 8, 5 + 8 * 8, 4 + 9 * 8, 5 + 9 * 8, 271 | 6 + 8 * 8, 7 + 8 * 8, 6 + 9 * 8, 7 + 9 * 8, 272 | 4 + 11 * 8, 5 + 11 * 8, 4 + 12 * 8, 5 + 12 * 8, 273 | 6 + 11 * 8, 7 + 11 * 8, 6 + 12 * 8, 7 + 12 * 8, 274 | 4 + 13 * 8, 5 + 13 * 8, 4 + 14 * 8, 5 + 14 * 8, 275 | 6 + 13 * 8, 7 + 13 * 8, 6 + 14 * 8, 7 + 14 * 8, 276 | 0 + 0 * 8, 0 + 5 * 8, 0 + 10 * 8 277 | }; 278 | 279 | constexpr r_scan8 reverse_scan_8[15][8] = { 280 | //Y 281 | {{16 * 3, false, false}, r_scan8::inv(), r_scan8::inv(), {15, true, true}, 282 | {10, false, true}, {11, false, true}, {14, false, true}, {15, false, true}}, 283 | {r_scan8::inv(), r_scan8::inv(), r_scan8::inv(), {5, true, false}, 284 | {0, false, false}, {1, false, false}, {4, false, false}, {5, false, false}}, 285 | {r_scan8::inv(), r_scan8::inv(), r_scan8::inv(), {7, true, false}, 286 | {2, false, false}, {3, false, false}, {6, false, false}, {7, false, false}}, 287 | {r_scan8::inv(), r_scan8::inv(), r_scan8::inv(), {13, true, false}, 288 | {8, false, false}, {9, false, false}, {12, false, false}, {13, false, false}}, 289 | {{16 * 3 + 1,false, true}, r_scan8::inv(), r_scan8::inv(), {15, true, false}, 290 | {10, false, false}, {11, false, false}, {14, false, false}, {15, false, false}}, 291 | // U 292 | {{16 * 3 + 1,false, false}, r_scan8::inv(), r_scan8::inv(), {16 + 15, true, true}, 293 | {16 + 10, false, true}, {16 + 11, false, true}, {16 + 14, false, true}, {16 + 15, false, true}}, 294 | {r_scan8::inv(), r_scan8::inv(), r_scan8::inv(), {16 + 5, true, false}, 295 | {16 + 0, false, false}, {16 + 1, false, false}, {16 + 4, false, false}, {16 + 5, false, false}}, 296 | {r_scan8::inv(), 
r_scan8::inv(), r_scan8::inv(), {16 + 7, true, false}, 297 | {16 + 2, false, false}, {16 + 3, false, false}, {16 + 6, false, false}, {16 + 7, false, false}}, 298 | {r_scan8::inv(), r_scan8::inv(), r_scan8::inv(), {16 + 13, true, false}, 299 | {16 + 8, false, false}, {16 + 9, false, false}, {16 + 12, false, false}, {16 + 13, false, false}}, 300 | {{16 * 3 + 2,false, true}, r_scan8::inv(), r_scan8::inv(), {16 + 15, true, false}, 301 | {16 + 10, false, false}, {16 + 11, false, false}, {16 + 14, false, false}, {16 + 15, false, false}}, 302 | // V 303 | {{16 * 3 + 2,false, false}, r_scan8::inv(), r_scan8::inv(), {32 + 15, true, true}, 304 | {32 + 10, false, true}, {32 + 11, false, true}, {32 + 14, false, true}, {32 + 15, false, true}}, 305 | {r_scan8::inv(), r_scan8::inv(), r_scan8::inv(), {32 + 5, true, false}, 306 | {32 + 0, false, false}, {32 + 1, false, false}, {32 + 4, false, false}, {32 + 5, false, false}}, 307 | {r_scan8::inv(), r_scan8::inv(), r_scan8::inv(), {32 + 7, true, false}, 308 | {32 + 2, false, false}, {32 + 3, false, false}, {32 + 6, false, false}, {32 + 7, false, false}}, 309 | {r_scan8::inv(), r_scan8::inv(), r_scan8::inv(), {32 + 13, true, false}, 310 | {32 + 8, false, false}, {32 + 9, false, false}, {32 + 12, false, false}, {32 + 13, false, false}}, 311 | {{32 + 16 * 3 + 1,false, true}, r_scan8::inv(), r_scan8::inv(), {32 + 15, true, false}, 312 | {32 + 10, false, false}, {32 + 11, false, false}, {32 + 14, false, false}, {32 + 15, false, false}}}; 313 | 314 | // Encoder / decoder for recoded CABAC blocks. 
315 | typedef uint64_t range_t; 316 | typedef arithmetic_code recoded_code; 317 | 318 | typedef std::tuple model_key; 319 | /* 320 | not sure these tables are the ones we want to use 321 | constexpr uint8_t unzigzag16[16] = { 322 | 0 + 0 * 4, 0 + 1 * 4, 1 + 0 * 4, 0 + 2 * 4, 323 | 0 + 3 * 4, 1 + 1 * 4, 1 + 2 * 4, 1 + 3 * 4, 324 | 2 + 0 * 4, 2 + 1 * 4, 2 + 2 * 4, 2 + 3 * 4, 325 | 3 + 0 * 4, 3 + 1 * 4, 3 + 2 * 4, 3 + 3 * 4, 326 | }; 327 | constexpr uint8_t zigzag16[16] = { 328 | 0, 2, 8, 12, 329 | 1, 5, 9, 13, 330 | 3, 6, 10, 14, 331 | 4, 7, 11, 15 332 | }; 333 | 334 | constexpr uint8_t zigzag_field64[64] = { 335 | 0 + 0 * 8, 0 + 1 * 8, 0 + 2 * 8, 1 + 0 * 8, 336 | 1 + 1 * 8, 0 + 3 * 8, 0 + 4 * 8, 1 + 2 * 8, 337 | 2 + 0 * 8, 1 + 3 * 8, 0 + 5 * 8, 0 + 6 * 8, 338 | 0 + 7 * 8, 1 + 4 * 8, 2 + 1 * 8, 3 + 0 * 8, 339 | 2 + 2 * 8, 1 + 5 * 8, 1 + 6 * 8, 1 + 7 * 8, 340 | 2 + 3 * 8, 3 + 1 * 8, 4 + 0 * 8, 3 + 2 * 8, 341 | 2 + 4 * 8, 2 + 5 * 8, 2 + 6 * 8, 2 + 7 * 8, 342 | 3 + 3 * 8, 4 + 1 * 8, 5 + 0 * 8, 4 + 2 * 8, 343 | 3 + 4 * 8, 3 + 5 * 8, 3 + 6 * 8, 3 + 7 * 8, 344 | 4 + 3 * 8, 5 + 1 * 8, 6 + 0 * 8, 5 + 2 * 8, 345 | 4 + 4 * 8, 4 + 5 * 8, 4 + 6 * 8, 4 + 7 * 8, 346 | 5 + 3 * 8, 6 + 1 * 8, 6 + 2 * 8, 5 + 4 * 8, 347 | 5 + 5 * 8, 5 + 6 * 8, 5 + 7 * 8, 6 + 3 * 8, 348 | 7 + 0 * 8, 7 + 1 * 8, 6 + 4 * 8, 6 + 5 * 8, 349 | 6 + 6 * 8, 6 + 7 * 8, 7 + 2 * 8, 7 + 3 * 8, 350 | 7 + 4 * 8, 7 + 5 * 8, 7 + 6 * 8, 7 + 7 * 8, 351 | }; 352 | 353 | */ 354 | constexpr uint8_t zigzag4[4] = { 355 | 0, 1, 2, 3 356 | }; 357 | constexpr uint8_t unzigzag4[4] = { 358 | 0, 1, 2, 3 359 | }; 360 | 361 | constexpr uint8_t unzigzag16[16] = { 362 | 0, 1, 4, 8, 363 | 5, 2, 3, 6, 364 | 9, 12, 13, 10, 365 | 7, 11, 14, 15 366 | }; 367 | constexpr uint8_t zigzag16[16] = { 368 | 0, 1, 5, 6, 369 | 2, 4, 7, 12, 370 | 3, 8, 11, 13, 371 | 9, 10, 14, 15 372 | }; 373 | constexpr uint8_t unzigzag64[64] = { 374 | 0, 1, 8, 16, 9, 2, 3, 10, 375 | 17, 24, 32, 25, 18, 11, 4, 5, 376 | 12, 19, 26, 33, 40, 48, 41, 34, 377 | 27, 20, 13, 
6, 7, 14, 21, 28, 378 | 35, 42, 49, 56, 57, 50, 43, 36, 379 | 29, 22, 15, 23, 30, 37, 44, 51, 380 | 58, 59, 52, 45, 38, 31, 39, 46, 381 | 53, 60, 61, 54, 47, 55, 62, 63 382 | }; 383 | 384 | constexpr uint8_t zigzag64[64] = { 385 | 0, 1, 5, 6, 14, 15, 27, 28, 386 | 2, 4, 7, 13, 16, 26, 29, 42, 387 | 3, 8, 12, 17, 25, 30, 41, 43, 388 | 9, 11, 18, 24, 31, 40, 44, 53, 389 | 10, 19, 23, 32, 39, 45, 52, 54, 390 | 20, 22, 33, 38, 46, 51, 55, 60, 391 | 21, 34, 37, 47, 50, 56, 59, 61, 392 | 35, 36, 48, 49, 57, 58, 62, 63 393 | }; 394 | 395 | 396 | int test_reverse_scan8() { 397 | for (size_t i = 0; i < sizeof(scan_8)/ sizeof(scan_8[0]); ++i) { 398 | auto a = reverse_scan_8[scan_8[i] >> 3][scan_8[i] & 7]; 399 | assert(a.neighbor_left == false && a.neighbor_up == false); 400 | assert(a.scan8_index == i); 401 | if (a.scan8_index != i) { 402 | return 1; 403 | } 404 | } 405 | for (int i = 0;i < 16; ++i) { 406 | assert(zigzag16[unzigzag16[i]] == i); 407 | assert(unzigzag16[zigzag16[i]] == i); 408 | } 409 | return 0; 410 | } 411 | int make_sure_reverse_scan8 = test_reverse_scan8(); 412 | struct CoefficientCoord { 413 | int mb_x; 414 | int mb_y; 415 | int scan8_index; 416 | int zigzag_index; 417 | }; 418 | 419 | bool get_neighbor_sub_mb(bool above, int sub_mb_size, 420 | CoefficientCoord input, 421 | CoefficientCoord *output) { 422 | int mb_x = input.mb_x; 423 | int mb_y = input.mb_y; 424 | int scan8_index = input.scan8_index; 425 | output->scan8_index = scan8_index; 426 | output->mb_x = mb_x; 427 | output->mb_y = mb_y; 428 | output->zigzag_index = input.zigzag_index; 429 | if (scan8_index >= 16 * 3) { 430 | if (above) { 431 | if (mb_y > 0) { 432 | output->mb_y -= 1; 433 | return true; 434 | } 435 | return false; 436 | } else { 437 | if (mb_x > 0) { 438 | output->mb_x -= 1; 439 | return true; 440 | } 441 | return false; 442 | } 443 | } 444 | int scan8 = scan_8[scan8_index]; 445 | int left_shift = (above ? 0 : -1); 446 | int above_shift = (above ? 
-1 : 0); 447 | auto neighbor = reverse_scan_8[(scan8 >> 3) + above_shift][(scan8 & 7) + left_shift]; 448 | if (neighbor.neighbor_left) { 449 | if (mb_x == 0){ 450 | return false; 451 | } else { 452 | --mb_x; 453 | } 454 | } 455 | if (neighbor.neighbor_up) { 456 | if (mb_y == 0) { 457 | return false; 458 | } else { 459 | --mb_y; 460 | } 461 | } 462 | output->scan8_index = neighbor.scan8_index; 463 | if (sub_mb_size >= 32) { 464 | output->scan8_index /= 4; 465 | output->scan8_index *= 4; // round down to the nearest multiple of 4 466 | } 467 | output->zigzag_index = input.zigzag_index; 468 | output->mb_x = mb_x; 469 | output->mb_y = mb_y; 470 | return true; 471 | } 472 | int log2(int y) { 473 | int x = -1; 474 | while (y) { 475 | y/=2; 476 | x++; 477 | } 478 | return x; 479 | } 480 | bool get_neighbor(bool above, int sub_mb_size, 481 | CoefficientCoord input, 482 | CoefficientCoord *output) { 483 | int mb_x = input.mb_x; 484 | int mb_y = input.mb_y; 485 | int scan8_index = input.scan8_index; 486 | unsigned int zigzag_index = input.zigzag_index; 487 | int dimension = 2; 488 | if (sub_mb_size > 15) { 489 | dimension = 4; 490 | } 491 | if (sub_mb_size > 32) { 492 | dimension = 8; 493 | } 494 | if (scan8_index >= 16 * 3) { 495 | // we are DC... 
496 | int linear_index = unzigzag4[zigzag_index & 0x3]; 497 | if (sub_mb_size == 16) { 498 | linear_index = unzigzag16[zigzag_index & 0xf]; 499 | } else { 500 | assert(sub_mb_size <= 4); 501 | } 502 | if ((above && linear_index >= dimension) // if is inner 503 | || ((linear_index & (dimension - 1)) && !above)) { 504 | if (above) { 505 | linear_index -= dimension; 506 | } else { 507 | -- linear_index; 508 | } 509 | if (sub_mb_size == 16) { 510 | output->zigzag_index = zigzag16[linear_index]; 511 | } else { 512 | output->zigzag_index = zigzag4[linear_index]; 513 | } 514 | output->mb_x = mb_x; 515 | output->mb_y = mb_y; 516 | output->scan8_index = scan8_index; 517 | return true; 518 | } 519 | if (above) { 520 | if (mb_y == 0) { 521 | return false; 522 | } 523 | linear_index += dimension * (dimension - 1);//go to bottom 524 | --mb_y; 525 | } else { 526 | if (mb_x == 0) { 527 | return false; 528 | } 529 | linear_index += dimension - 1;//go to end of row 530 | --mb_x; 531 | } 532 | if (sub_mb_size == 16) { 533 | output->zigzag_index = zigzag16[linear_index]; 534 | } else { 535 | output->zigzag_index = linear_index; 536 | } 537 | output->mb_x = mb_x; 538 | output->mb_y = mb_y; 539 | output->scan8_index = scan8_index; 540 | return true; 541 | } 542 | int scan8 = scan_8[scan8_index]; 543 | int left_shift = (above ? 0 : -1); 544 | int above_shift = (above ? 
-1 : 0); 545 | auto neighbor = reverse_scan_8[(scan8 >> 3) + above_shift][(scan8 & 7) + left_shift]; 546 | if (neighbor.neighbor_left) { 547 | if (mb_x == 0){ 548 | return false; 549 | } else { 550 | --mb_x; 551 | } 552 | } 553 | if (neighbor.neighbor_up) { 554 | if (mb_y == 0) { 555 | return false; 556 | } else { 557 | --mb_y; 558 | } 559 | } 560 | output->scan8_index = neighbor.scan8_index; 561 | if (sub_mb_size >= 32) { 562 | output->scan8_index /= 4; 563 | output->scan8_index *= 4; // round down to the nearest multiple of 4 564 | } 565 | output->zigzag_index = zigzag_index; 566 | output->mb_x = mb_x; 567 | output->mb_y = mb_y; 568 | return true; 569 | } 570 | 571 | bool get_neighbor_coefficient(bool above, 572 | int sub_mb_size, 573 | CoefficientCoord input, 574 | CoefficientCoord *output) { 575 | if (input.scan8_index >= 16 * 3) { 576 | return get_neighbor(above, sub_mb_size, input, output); 577 | } 578 | int zigzag_addition = 0; 579 | 580 | if ((sub_mb_size & (sub_mb_size - 1)) != 0) { 581 | zigzag_addition = 1;// the DC is not included 582 | } 583 | const uint8_t *zigzag_to_raster = unzigzag16; 584 | const uint8_t *raster_to_zigzag = zigzag16; 585 | int dim = 4; 586 | if (sub_mb_size <= 4) { 587 | dim = 2; 588 | zigzag_to_raster = zigzag4; 589 | raster_to_zigzag = unzigzag4; 590 | } 591 | if (sub_mb_size > 16) { 592 | dim = 16; 593 | zigzag_to_raster = zigzag64; 594 | raster_to_zigzag = unzigzag64; 595 | } 596 | int raster_coord = zigzag_to_raster[input.zigzag_index + zigzag_addition]; 597 | //fprintf(stderr, "%d %d %d -> %d\n", sub_mb_size, zigzag_addition, input.zigzag_index, raster_coord); 598 | if (above) { 599 | if (raster_coord >= dim) { 600 | raster_coord -= dim; 601 | } else { 602 | return false; 603 | } 604 | } else { 605 | if (raster_coord & (dim - 1)) { 606 | raster_coord -= 1; 607 | } else { 608 | return false; 609 | } 610 | } 611 | *output = input; 612 | output->zigzag_index = raster_to_zigzag[raster_coord] - zigzag_addition; 613 | return true; 
614 | } 615 | #define STRINGIFY_COMMA(s) #s , 616 | const char * billing_names [] = {EACH_PIP_CODING_TYPE(STRINGIFY_COMMA)}; 617 | #undef STRINGIFY_COMMA 618 | class h264_model { 619 | public: 620 | CodingType coding_type = PIP_UNKNOWN; 621 | size_t bill[sizeof(billing_names)/sizeof(billing_names[0])]; 622 | size_t cabac_bill[sizeof(billing_names)/sizeof(billing_names[0])]; 623 | FrameBuffer frames[2]; 624 | int cur_frame = 0; 625 | uint8_t STATE_FOR_NUM_NONZERO_BIT[6]; 626 | bool do_print; 627 | public: 628 | h264_model() { reset(); do_print = false; memset(bill, 0, sizeof(bill)); memset(cabac_bill, 0, sizeof(cabac_bill));} 629 | void enable_debug() { 630 | do_print = true; 631 | } 632 | void disable_debug() { 633 | do_print = false; 634 | } 635 | ~h264_model() { 636 | bool first = true; 637 | for (size_t i = 0; i < sizeof(billing_names)/sizeof(billing_names[i]); ++i) { 638 | if (bill[i]) { 639 | if (first) { 640 | fprintf(stderr, "Avrecode Bill\n=============\n"); 641 | } 642 | first = false; 643 | fprintf(stderr, "%s : %ld\n", billing_names[i], bill[i]); 644 | } 645 | } 646 | for (size_t i = 0; i < sizeof(billing_names)/sizeof(billing_names[i]); ++i) { 647 | if (cabac_bill[i]) { 648 | if (first) { 649 | fprintf(stderr, "CABAC Bill\n=============\n"); 650 | } 651 | first = false; 652 | fprintf(stderr, "%s : %ld\n", billing_names[i], cabac_bill[i]); 653 | } 654 | } 655 | } 656 | void billable_bytes(size_t num_bytes_emitted) { 657 | bill[coding_type] += num_bytes_emitted; 658 | } 659 | void billable_cabac_bytes(size_t num_bytes_emitted) { 660 | cabac_bill[coding_type] += num_bytes_emitted; 661 | } 662 | void reset() { 663 | // reset should do nothing as we wish to remember what we've learned 664 | memset(STATE_FOR_NUM_NONZERO_BIT, 0, sizeof(STATE_FOR_NUM_NONZERO_BIT)); 665 | } 666 | bool fetch(bool previous, bool match_type, CoefficientCoord coord, int16_t*output) const{ 667 | if (match_type && (previous || coord.mb_x != mb_coord.mb_x || coord.mb_y != 
mb_coord.mb_y)) { 668 | BlockMeta meta = frames[previous ? !cur_frame : cur_frame].meta_at(coord.mb_x, coord.mb_y); 669 | if (!meta.coded) { // when we populate mb_type in the metadata, then we can use it here 670 | return false; 671 | } 672 | } 673 | *output = frames[previous ? !cur_frame : cur_frame].at(coord.mb_x, coord.mb_y).residual[coord.scan8_index * 16 + coord.zigzag_index]; 674 | return true; 675 | } 676 | model_key get_model_key(const void *context)const { 677 | switch(coding_type) { 678 | case PIP_SIGNIFICANCE_NZ: 679 | return model_key(context, 0, 0); 680 | case PIP_UNKNOWN: 681 | case PIP_UNREACHABLE: 682 | case PIP_RESIDUALS: 683 | return model_key(context, 0, 0); 684 | case PIP_SIGNIFICANCE_MAP: 685 | { 686 | static const uint8_t sig_coeff_flag_offset_8x8[2][63] = { 687 | { 0, 1, 2, 3, 4, 5, 5, 4, 4, 3, 3, 4, 4, 4, 5, 5, 688 | 4, 4, 4, 4, 3, 3, 6, 7, 7, 7, 8, 9,10, 9, 8, 7, 689 | 7, 6,11,12,13,11, 6, 7, 8, 9,14,10, 9, 8, 6,11, 690 | 12,13,11, 6, 9,14,10, 9,11,12,13,11,14,10,12 }, 691 | { 0, 1, 1, 2, 2, 3, 3, 4, 5, 6, 7, 7, 7, 8, 4, 5, 692 | 6, 9,10,10, 8,11,12,11, 9, 9,10,10, 8,11,12,11, 693 | 9, 9,10,10, 8,11,12,11, 9, 9,10,10, 8,13,13, 9, 694 | 9,10,10, 8,13,13, 9, 9,10,10,14,14,14,14,14 } 695 | }; 696 | int cat_lookup[14] = { 105+0, 105+15, 105+29, 105+44, 105+47, 402, 484+0, 484+15, 484+29, 660, 528+0, 528+15, 528+29, 718 }; 697 | static const uint8_t sig_coeff_offset_dc[7] = { 0, 0, 1, 1, 2, 2, 2 }; 698 | int zigzag_offset = mb_coord.zigzag_index; 699 | if (sub_mb_is_dc && sub_mb_chroma422) { 700 | assert(mb_coord.zigzag_index < 7); 701 | zigzag_offset = sig_coeff_offset_dc[mb_coord.zigzag_index]; 702 | } else { 703 | if (sub_mb_size > 32) { assert(mb_coord.zigzag_index < 63); 704 | zigzag_offset = sig_coeff_flag_offset_8x8[0][mb_coord.zigzag_index]; 705 | } 706 | } 707 | assert(sub_mb_cat < (int)(sizeof(cat_lookup)/sizeof(cat_lookup[0]))); 708 | int neighbor_above = 2; 709 | int neighbor_left = 2; 710 | int coeff_neighbor_above = 2; 711 | int 
coeff_neighbor_left = 2; 712 | if (do_print) { 713 | LOG_NEIGHBORS("["); 714 | } 715 | { 716 | CoefficientCoord neighbor_left_coord = {0, 0, 0, 0}; 717 | if (get_neighbor(false, sub_mb_size, mb_coord, &neighbor_left_coord)) { 718 | int16_t tmp = 0; 719 | if (fetch(false, true, neighbor_left_coord, &tmp)){ 720 | neighbor_left = !!tmp; 721 | if (do_print) { 722 | LOG_NEIGHBORS("%d,", tmp); 723 | } 724 | } else { 725 | neighbor_left = 3; 726 | if (do_print) { 727 | LOG_NEIGHBORS("_,"); 728 | } 729 | } 730 | } else { 731 | if (do_print) { 732 | LOG_NEIGHBORS("x,"); 733 | } 734 | } 735 | } 736 | { 737 | CoefficientCoord neighbor_above_coord = {0, 0, 0, 0}; 738 | if (get_neighbor(true, sub_mb_size, mb_coord, &neighbor_above_coord)) { 739 | int16_t tmp = 0; 740 | if (fetch(false, true, neighbor_above_coord, &tmp)){ 741 | neighbor_above = !!tmp; 742 | if (do_print) { 743 | LOG_NEIGHBORS("%d,", tmp); 744 | } 745 | } else { 746 | neighbor_above = 3; 747 | if (do_print) { 748 | LOG_NEIGHBORS("_,"); 749 | } 750 | } 751 | } else { 752 | if (do_print) { 753 | LOG_NEIGHBORS("x,"); 754 | } 755 | } 756 | 757 | } 758 | { 759 | CoefficientCoord neighbor_left_coord = {0, 0, 0, 0}; 760 | if (get_neighbor_coefficient(false, sub_mb_size, mb_coord, &neighbor_left_coord)) { 761 | int16_t tmp = 0; 762 | if (fetch(false, true, neighbor_left_coord, &tmp)){ 763 | coeff_neighbor_left = !!tmp; 764 | } else { 765 | coeff_neighbor_left = 3; 766 | } 767 | } else { 768 | } 769 | } 770 | { 771 | CoefficientCoord neighbor_above_coord = {0, 0, 0, 0}; 772 | if (get_neighbor_coefficient(true, sub_mb_size, mb_coord, &neighbor_above_coord)) { 773 | int16_t tmp = 0; 774 | if (fetch(false, true, neighbor_above_coord, &tmp)){ 775 | coeff_neighbor_above = !!tmp; 776 | } else { 777 | coeff_neighbor_above = 3; 778 | } 779 | } else { 780 | } 781 | } 782 | 783 | // FIXM: why doesn't this prior help at all 784 | { 785 | int16_t output = 0; 786 | if (fetch(true, true, mb_coord, &output)) { 787 | if (do_print) 
LOG_NEIGHBORS("%d] ", output); 788 | } else { 789 | if (do_print) LOG_NEIGHBORS("x] "); 790 | } 791 | } 792 | //const BlockMeta &meta = frames[!cur_frame].meta_at(mb_x, mb_y); 793 | int num_nonzeros = frames[cur_frame].meta_at(mb_coord.mb_x, mb_coord.mb_y).num_nonzeros[mb_coord.scan8_index]; 794 | (void)neighbor_above; 795 | (void)neighbor_left; 796 | (void)coeff_neighbor_above; 797 | (void)coeff_neighbor_left;//haven't found a good way to utilize these priors to make the results better 798 | return model_key(&significance_context, 799 | 64 * num_nonzeros + nonzeros_observed, 800 | sub_mb_is_dc + zigzag_offset * 2 + 16 * 2 * cat_lookup[sub_mb_cat]); 801 | } 802 | case PIP_SIGNIFICANCE_EOB: 803 | { 804 | // FIXME: why doesn't this prior help at all 805 | static int fake_context = 0; 806 | int num_nonzeros = frames[cur_frame].meta_at(mb_coord.mb_x, mb_coord.mb_y).num_nonzeros[mb_coord.scan8_index]; 807 | 808 | return model_key(&fake_context, num_nonzeros == nonzeros_observed, 0); 809 | } 810 | default: 811 | break; 812 | } 813 | assert(false && "Unreachable"); 814 | abort(); 815 | } 816 | range_t probability_for_model_key(range_t range, model_key key) { 817 | auto* e = &estimators[key]; 818 | int total = e->pos + e->neg; 819 | return (range/total) * e->pos; 820 | } 821 | range_t probability_for_state(range_t range, const void *context) { 822 | return probability_for_model_key(range, get_model_key(context)); 823 | } 824 | void update_frame_spec(int frame_num, int mb_width, int mb_height) { 825 | if (frames[cur_frame].width() != (uint32_t)mb_width 826 | || frames[cur_frame].height() != (uint32_t)mb_height 827 | || !frames[cur_frame].is_same_frame(frame_num)) { 828 | cur_frame = !cur_frame; 829 | if (frames[cur_frame].width() != (uint32_t)mb_width 830 | || frames[cur_frame].height() != (uint32_t)mb_height) { 831 | frames[cur_frame].init(mb_width, mb_height, mb_width * mb_height); 832 | if (frames[!cur_frame].width() != (uint32_t)mb_width 833 | || 
frames[!cur_frame].height() != (uint32_t)mb_height) {
          frames[!cur_frame].init(mb_width, mb_height, mb_width * mb_height);
        }
        //fprintf(stderr, "Init(%d=%d) %d x %d\n", frame_num, cur_frame, mb_width, mb_height);
      } else {
        frames[cur_frame].bzero();
        //fprintf(stderr, "Clear (%d=%d)\n", frame_num, cur_frame);
      }
      frames[cur_frame].set_frame_num(frame_num);
    }
  }

  // Codes the nonzero-coefficient count of the current sub-block
  // (meta.num_nonzeros[mb_coord.scan8_index]) as a few binary symbols, one
  // put_or_get(key, &bit) call per bit, lowest bit first. The compressor's
  // callback encodes *bit; the decompressor's callback decodes into *bit.
  // Afterwards the bits are reassembled into meta.num_nonzeros.
  // Only acts when ct == PIP_SIGNIFICANCE_MAP; coding_type is temporarily set
  // to PIP_SIGNIFICANCE_NZ and restored before returning.
  //
  // NOTE(review): this copy of the file has lost its `<...>` spans (it looks
  // HTML-tag-stripped). The template parameter list after `template` is
  // missing, and original lines 876-878 (presumably the `left_nonzero_bit`
  // computation mirroring `above_nonzero_bit`) were swallowed into the
  // `cur_bit` line below. Restore both from upstream.
  template
  void finished_queueing(CodingType ct, const Functor &put_or_get) {

    if (ct == PIP_SIGNIFICANCE_MAP) {
      bool block_of_interest = (sub_mb_cat == 1 || sub_mb_cat == 2);
      CodingType last = coding_type;
      coding_type = PIP_SIGNIFICANCE_NZ;
      BlockMeta &meta = frames[cur_frame].meta_at(mb_coord.mb_x, mb_coord.mb_y);
      // Split the currently stored count into its low six bits.
      int nonzero_bits[6] = {};
      for (int i= 0; i < 6; ++i) {
        nonzero_bits[i] = (meta.num_nonzeros[mb_coord.scan8_index] & (1 << i)) >> i;
      }
#define QUEUE_MODE
#ifdef QUEUE_MODE
      // Only the low `serialized_bits` bits go through put_or_get: 6 for
      // sub-blocks of more than 16 coefficients, 4 for 5..16, 2 otherwise.
      // Higher entries of nonzero_bits keep the values extracted above.
      // The two sides start from different stored counts:
      //  - the compressor calls this after end_coding_type stored the real count;
      //  - the decompressor calls it right after begin_coding_type zeroed it.
      // So a count needing more bits (e.g. 16 in a 16-coefficient sub-block)
      // reassembles to different values on the two sides. end_coding_type's
      // assert tolerates a stored 0.
      // NOTE(review): confirm this asymmetry is intended.
      const uint32_t serialized_bits = sub_mb_size > 16 ? 6 : sub_mb_size > 4 ? 4 : 2;
      {
        uint32_t i = 0;
        // Bits coded so far; part of the context for the next bit.
        uint32_t serialized_so_far = 0;
        CoefficientCoord neighbor;
        // Nonzero counts of the left/above neighbor sub-blocks, 0 if absent.
        uint32_t left_nonzero = 0;
        uint32_t above_nonzero = 0;
        bool has_left = get_neighbor_sub_mb(false, sub_mb_size, mb_coord, &neighbor);
        if (has_left) {
          left_nonzero = frames[cur_frame].meta_at(neighbor.mb_x, neighbor.mb_y).num_nonzeros[neighbor.scan8_index];
        }
        bool has_above = get_neighbor_sub_mb(true, sub_mb_size, mb_coord, &neighbor);
        if (has_above) {
          above_nonzero = frames[cur_frame].meta_at(neighbor.mb_x, neighbor.mb_y).num_nonzeros[neighbor.scan8_index];
        }

        do {
          // NOTE(review): damaged line -- see the note above this method.
          uint32_t cur_bit = (1<= cur_bit);
          }
          // Stays 2 when the neighbor count is 0 (or there is no neighbor);
          // otherwise 1/0 for whether the neighbor count is >= cur_bit.
          int above_nonzero_bit = 2;
          if (above_nonzero) {
            above_nonzero_bit = (above_nonzero >= cur_bit);
          }
          // Context combines:
          //  - the bits coded so far;
          //  - whether the same sub-block in frames[!cur_frame] reached cur_bit;
          //  - the left/above neighbor summaries;
          //  - a sub-block class derived from is_8x8, DC, 4:2:2 chroma and
          //    sub_mb_cat.
          put_or_get(model_key(&(STATE_FOR_NUM_NONZERO_BIT[i]), serialized_so_far + 64 * (frames[!cur_frame].meta_at(mb_coord.mb_x, mb_coord.mb_y).num_nonzeros[mb_coord.scan8_index] >= cur_bit) + 128 * left_nonzero_bit + 384 * above_nonzero_bit, meta.is_8x8 + sub_mb_is_dc * 2 + sub_mb_chroma422 + sub_mb_cat * 4), &nonzero_bits[i]);
          if (nonzero_bits[i]) {
            serialized_so_far |= cur_bit;
          }
        } while (++i < serialized_bits);
        // Debug trace of the context used above (LOG_NEIGHBORS is defined
        // elsewhere); only emitted for sub_mb_cat 1 or 2.
        if (block_of_interest) {
          LOG_NEIGHBORS("<{");
        }
        if (has_left) {
          if (block_of_interest) {
            LOG_NEIGHBORS("%d,", left_nonzero);
          }
        } else {
          if (block_of_interest) {
            LOG_NEIGHBORS("X,");
          }
        }
        if (has_above) {
          if (block_of_interest) {
            LOG_NEIGHBORS("%d,", above_nonzero);
          }
        } else {
          if (block_of_interest) {
            LOG_NEIGHBORS("X,");
          }
        }
        if (frames[!cur_frame].meta_at(mb_coord.mb_x, mb_coord.mb_y).coded) {
          if (block_of_interest) {
            LOG_NEIGHBORS("%d",frames[!cur_frame].meta_at(mb_coord.mb_x, mb_coord.mb_y).num_nonzeros[mb_coord.scan8_index]);
          }
        } else {
          if (block_of_interest) {
            LOG_NEIGHBORS("X");
          }
        }
      }
#endif
      // Reassemble the (possibly just decoded) bits into the stored count.
      meta.num_nonzeros[mb_coord.scan8_index] = 0;
      for (int i= 0; i < 6; ++i) {
        meta.num_nonzeros[mb_coord.scan8_index] |= nonzero_bits[i] << i;
      }
      if (block_of_interest) {
        LOG_NEIGHBORS("} %d> ",meta.num_nonzeros[mb_coord.scan8_index]);
      }
      coding_type = last;
    }
  }

  // Ends a coding type. For PIP_SIGNIFICANCE_MAP it:
  //  - recounts the significant coefficients that update_state_tracking
  //    recorded in `residual` (each entry must be 0 or 1);
  //  - stores the count in the macroblock's BlockMeta;
  //  - marks the macroblock as coded.
  // Always leaves coding_type == PIP_UNKNOWN.
  void end_coding_type(CodingType ct) {
    if (ct == PIP_SIGNIFICANCE_MAP) {
      assert(coding_type == PIP_UNREACHABLE
             || (coding_type == PIP_SIGNIFICANCE_MAP && mb_coord.zigzag_index == 0));
      uint8_t num_nonzeros = 0;
      for (int i = 0; i < sub_mb_size; ++i) {
        int16_t res = frames[cur_frame].at(mb_coord.mb_x, mb_coord.mb_y).residual[mb_coord.scan8_index * 16 + i];
        assert(res == 1 || res == 0);
        if (res != 0) {
          num_nonzeros += 1;
        }
      }
      BlockMeta &meta = frames[cur_frame].meta_at(mb_coord.mb_x, mb_coord.mb_y);
      meta.is_8x8 = meta.is_8x8 || (sub_mb_size > 32); // 8x8 will have DC be 2x2
      meta.coded = true;
      // A previously stored count must either match or be 0.
      assert(meta.num_nonzeros[mb_coord.scan8_index] == 0 || meta.num_nonzeros[mb_coord.scan8_index] == num_nonzeros);
      meta.num_nonzeros[mb_coord.scan8_index] = num_nonzeros;
    }
    coding_type = PIP_UNKNOWN;
  }

  // Starts a coding type. For PIP_SIGNIFICANCE_MAP it clears the stored count
  // and the significance-map trackers, then returns true. Callers use that to
  // start queueing symbols (compressor) or to code the count immediately
  // (decompressor). Returns false for every other coding type.
  // param0/param1 are not used here.
  bool begin_coding_type(CodingType ct, int zz_index, int param0, int param1) {

    bool begin_queueing = false;
    coding_type = ct;
    switch (ct) {
      case PIP_SIGNIFICANCE_MAP:
        {
          BlockMeta &meta = frames[cur_frame].meta_at(mb_coord.mb_x, mb_coord.mb_y);
          meta.num_nonzeros[mb_coord.scan8_index] = 0;
        }
        assert(!zz_index);
        nonzeros_observed = 0;
        // NOTE(review): both branches are identical.
        if (sub_mb_is_dc) {
          mb_coord.zigzag_index = 0;
        } else {
          mb_coord.zigzag_index = 0;
        }
        begin_queueing = true;
        break;
      default:
        break;
    }
    return begin_queueing;
  }

  // Rewinds significance-map tracking to the first coefficient. The
  // compressor's pop_queueing_symbols appears to call this before replaying
  // buffered symbols.
  void reset_mb_significance_state_tracking() {
    mb_coord.zigzag_index = 0;
    nonzeros_observed = 0;
    coding_type = PIP_SIGNIFICANCE_MAP;
  }

  // Advances the significance-map state machine by one symbol.
  //  PIP_SIGNIFICANCE_MAP: `symbol` is the significance flag at zigzag_index
  //    and is recorded in `residual`.
  //    - A 1 moves to PIP_SIGNIFICANCE_EOB.
  //    - A 0 moves to the next coefficient; if that is the last one, it is
  //      recorded as significant and the map ends.
  //  PIP_SIGNIFICANCE_EOB:
  //    - A 1 ends the map.
  //    - A 0 returns to PIP_SIGNIFICANCE_MAP at the next coefficient, or, if
  //      that is the last one, records it as significant and ends the map.
  // The map ends in PIP_UNREACHABLE; a symbol arriving in that state trips an
  // assert. NZ, RESIDUALS and UNKNOWN symbols are ignored.
  void update_state_tracking(int symbol) {
    switch (coding_type) {
      case PIP_SIGNIFICANCE_NZ:
        break;
      case PIP_SIGNIFICANCE_MAP:
        frames[cur_frame].at(mb_coord.mb_x, mb_coord.mb_y).residual[mb_coord.scan8_index * 16 + mb_coord.zigzag_index] = symbol;
        nonzeros_observed += symbol;
        if (mb_coord.zigzag_index + 1 == sub_mb_size) {
          coding_type = PIP_UNREACHABLE;
          mb_coord.zigzag_index = 0;
        } else {
          if (symbol) {
            coding_type = PIP_SIGNIFICANCE_EOB;
          } else {
            ++mb_coord.zigzag_index;
            if (mb_coord.zigzag_index + 1 == sub_mb_size) {
              // if we were a zero and we haven't eob'd then the
              // next and last must be a one
              frames[cur_frame].at(mb_coord.mb_x, mb_coord.mb_y).residual[mb_coord.scan8_index * 16 + mb_coord.zigzag_index] = 1;
              ++nonzeros_observed;
              coding_type = PIP_UNREACHABLE;
              mb_coord.zigzag_index = 0;
            }
          }
        }
        break;
      case PIP_SIGNIFICANCE_EOB:
        if (symbol) {
          mb_coord.zigzag_index = 0;
          coding_type = PIP_UNREACHABLE;
        } else if (mb_coord.zigzag_index + 2 == sub_mb_size) {
          // NOTE(review): unlike the matching PIP_SIGNIFICANCE_MAP branch,
          // this path neither increments nonzeros_observed nor resets
          // zigzag_index.
          frames[cur_frame].at(mb_coord.mb_x, mb_coord.mb_y).residual[mb_coord.scan8_index * 16 + mb_coord.zigzag_index + 1] = 1;
          coding_type = PIP_UNREACHABLE;
        } else {
          coding_type = PIP_SIGNIFICANCE_MAP;
          ++mb_coord.zigzag_index;
        }
        break;
      case PIP_RESIDUALS:
      case PIP_UNKNOWN:
        break;
      case PIP_UNREACHABLE:
        assert(false);
        // (falls through to default)
      default:
        assert(false);
    }
  }

  // Updates the model for `symbol`, which was coded under the context pointer
  // `context` (a CABAC state pointer or &bypass_context / &terminate_context).
  void update_state(int symbol, const void *context) {
    update_state_for_model_key(symbol, get_model_key(context));
  }

  // Updates the symbol counts for `key`, then advances significance-map
  // tracking. In PIP_SIGNIFICANCE_EOB it asserts that the end-of-block symbol
  // equals (stored count == significant coefficients seen so far).
  // Both counts are halved, rounding up, once their sum exceeds 0x60
  // (0x50 while in PIP_SIGNIFICANCE_MAP).
  void update_state_for_model_key(int symbol, model_key key) {
    if (coding_type == PIP_SIGNIFICANCE_EOB) {
      int num_nonzeros = frames[cur_frame].meta_at(mb_coord.mb_x, mb_coord.mb_y).num_nonzeros[mb_coord.scan8_index];
      assert(symbol == (num_nonzeros == nonzeros_observed));
    }
    auto* e = &estimators[key];
    if (symbol) {
      e->pos++;
    } else {
      e->neg++;
    }
    if ((coding_type != PIP_SIGNIFICANCE_MAP && e->pos + e->neg > 0x60)
        || (coding_type == PIP_SIGNIFICANCE_MAP && e->pos + e->neg > 0x50)) {
      e->pos = (e->pos + 1) / 2;
      e->neg = (e->neg + 1) / 2;
    }
    update_state_tracking(symbol);
  }

  // The addresses of bypass_context and terminate_context are passed as the
  // `state` pointer for bypass and terminate symbols.
  const uint8_t bypass_context = 0, terminate_context = 0, significance_context = 0;
  // Position of the sub-block / coefficient currently being coded.
  CoefficientCoord mb_coord;
  // Significant coefficients seen so far in the current significance map.
  int nonzeros_observed = 0;
  // Properties of the current sub-block (set outside this view -- TODO confirm
  // where).
  int sub_mb_cat = -1;
  int sub_mb_size = -1;
  int sub_mb_is_dc = 0;
  int sub_mb_chroma422 = 0;
 private:
  // Adaptive binary estimator: counts of 1 symbols (pos) and 0 symbols (neg).
  struct estimator { int pos = 1, neg = 1; };
  // NOTE(review): template arguments stripped in this copy; looked up by
  // model_key in update_state_for_model_key.
  std::map estimators;
};

// A CABAC symbol captured by the compressor together with the context pointer
// it was decoded under, so that it can be re-encoded later (possibly after
// sitting in a queue).
class h264_symbol {
 public:
  h264_symbol(int symbol, const void*state)
    : symbol(symbol), state(state) {
  }

  // Re-encodes the symbol with `encoder`, then updates `model`.
  //  - Symbols arriving while the model is in PIP_SIGNIFICANCE_EOB are not
  //    emitted; the decompressor derives them from the model key.
  //  - A terminate symbol of 1 flushes the encoder and stores its bytes in
  //    `out`.
  // NOTE(review): the template parameter list and the element type of
  // `encoder_out` were stripped from this copy of the file.
  // NOTE(review): `&encoder_out[0]` is undefined behavior when encoder_out is
  // empty; encoder_out.data() would be safe.
  template
  void execute(T &encoder, h264_model *model,
               Recoded::Block *out, std::vector &encoder_out) {
    bool in_significance_map = (model->coding_type == PIP_SIGNIFICANCE_MAP);
    bool block_of_interest = (model->sub_mb_cat == 1 || model->sub_mb_cat == 2);
    bool print_priors = in_significance_map && block_of_interest;
    if (model->coding_type != PIP_SIGNIFICANCE_EOB) {
      size_t billable_bytes = encoder.put(symbol, [&](range_t range){
          return model->probability_for_state(range, state); });
      if (billable_bytes) {
        model->billable_bytes(billable_bytes);
      }
    }else if (block_of_interest) {
      if (symbol) {
        LOG_NEIGHBORS("\n");
      }
    }
    if (print_priors) {
      model->enable_debug();
    }
    model->update_state(symbol, state);
    if (print_priors) {
      LOG_NEIGHBORS("%d ", symbol);
      model->disable_debug();
    }
    if (state == &model->terminate_context && symbol) {
      encoder.finish();
      out->set_cabac(&encoder_out[0], encoder_out.size());
    }
  }
 private:
  int symbol;
  const void* state;
};

// Compressor: decodes the input with av_decoder while the cabac_decoder hooks
// capture every CABAC symbol, re-encodes those symbols with h264_model, and
// serializes the result as a Recoded message. Bytes that are not re-encoded
// are kept as literal blocks.
class compressor {
 public:
  // Maps `input_filename` into memory; throws std::invalid_argument if that
  // fails.
  // NOTE(review): this class owns the mapping (released in the destructor)
  // but keeps the implicit copy operations, so a copy would unmap the same
  // buffer twice. Consider deleting the copy constructor/assignment.
  compressor(const std::string& input_filename, std::ostream& out_stream)
    : input_filename(input_filename), out_stream(out_stream) {
    if (av_file_map(input_filename.c_str(), &original_bytes, &original_size, 0, NULL) < 0) {
      throw std::invalid_argument("Failed to open file: " + input_filename);
    }
  }

  ~compressor() {
    av_file_unmap(original_bytes, original_size);
  }

  // Decodes the whole input. av_decoder is given `this`, presumably so it can
  // call read_packet and create cabac_decoder hooks. Afterwards the bytes
  // after the last coded block become a final literal block, and the
  // serialized Recoded message is written to out_stream.
  void run() {
    // Run through all the frames in the file, building the output using our hooks.
    av_decoder d(this, input_filename);
    d.dump_stream_info();
    d.decode_video();

    // Flush the final block to the output and write it to out_stream.
    out.add_block()->set_literal(
        &original_bytes[prev_coded_block_end], original_size - prev_coded_block_end);
    out_stream << out.SerializeAsString();
  }

  // Copies up to `size` further bytes of the mapped input into buffer_out.
  // Returns how many were copied (0 once the input is exhausted).
  int read_packet(uint8_t *buffer_out, int size) {
    size = std::min(size, int(original_size - read_offset));
    memcpy(buffer_out, &original_bytes[read_offset], size);
    read_offset += size;
    return size;
  }

  // CABAC hook for one coded buffer during compression. Symbols are decoded
  // with ffmpeg's CABAC decoder on a private copy of the context and
  // re-encoded into the Recoded block returned by
  // find_next_coded_block_and_emit_literal.
  class cabac_decoder {
   public:
    // If the buffer cannot be recoded, the hooks on ctx_in are cleared and the
    // context is reset, so decoding proceeds without calling back here.
    cabac_decoder(compressor *c, CABACContext *ctx_in, const uint8_t *buf, int size) {
      out = c->find_next_coded_block_and_emit_literal(buf, size);
      model = nullptr;
      if (out == nullptr) {
        // We're skipping this block, so disable calls to our hooks.
        ctx_in->coding_hooks = nullptr;
        ctx_in->coding_hooks_opaque = nullptr;
        ::ff_reset_cabac_decoder(ctx_in, buf, size);
        return;
      }

      out->set_size(size);

      ctx = *ctx_in;
      ctx.coding_hooks = nullptr;
      ctx.coding_hooks_opaque = nullptr;
      ::ff_reset_cabac_decoder(&ctx, buf, size);

      this->c = c;
      model = &c->model;
      model->reset();
    }
    ~cabac_decoder() { assert(out == nullptr || out->has_cabac()); }

    // Symbols are buffered while a significance map is queued, or while
    // earlier queued symbols are still pending; in that case only the model's
    // tracking state is advanced. Otherwise the symbol is re-encoded at once.
    void execute_symbol(int symbol, const void* state) {
      h264_symbol sym(symbol, state);
#define QUEUE_MODE
#ifdef QUEUE_MODE
      if (queueing_symbols == PIP_SIGNIFICANCE_MAP || queueing_symbols == PIP_SIGNIFICANCE_EOB || !symbol_buffer.empty()) {
        symbol_buffer.push_back(sym);
        model->update_state_tracking(symbol);
      } else {
#endif
        sym.execute(encoder, model, out, encoder_out);
#ifdef QUEUE_MODE
      }
#endif
    }

    // Decodes one context-coded bin and records it.
    int get(uint8_t *state) {
      int symbol = ::ff_get_cabac(&ctx, state);
      execute_symbol(symbol, state);
      return symbol;
    }

    // Decodes one bypass bin and records it under &model->bypass_context.
    int get_bypass() {
      int symbol = ::ff_get_cabac_bypass(&ctx);
      execute_symbol(symbol, &model->bypass_context);
      return symbol;
    }

    // Decodes the terminate bin (normalized to 0/1) and records it under
    // &model->terminate_context.
    int get_terminate() {
      int n = ::ff_get_cabac_terminate(&ctx);
      int symbol = (n != 0);
      execute_symbol(symbol, &model->terminate_context);
      return symbol;
    }

    // Forwards to the model; starts queueing when the model asks for it
    // (a significance map is starting). No-op for skipped buffers.
    void begin_coding_type(
        CodingType ct, int zigzag_index, int param0, int param1) {
      if (!model) {
        return;
      }
      bool begin_queue = model->begin_coding_type(ct, zigzag_index, param0, param1);
      if (begin_queue && (ct == PIP_SIGNIFICANCE_MAP || ct == PIP_SIGNIFICANCE_EOB)) {
        push_queueing_symbols(ct);
      }
    }
    // When a significance map ends:
    //  1. stop queueing;
    //  2. encode the nonzero count (finished_queueing);
    //  3. replay the buffered map symbols.
    // The static counter prints the stored count to stderr for the first 10
    // calls in the process.
    void end_coding_type(CodingType ct) {
      if (!model) {
        return;
      }
      model->end_coding_type(ct);

      if ((ct == PIP_SIGNIFICANCE_MAP || ct == PIP_SIGNIFICANCE_EOB)) {
        stop_queueing_symbols();
        model->finished_queueing(ct,
          [&](model_key key, int*symbol) {
            size_t billable_bytes = encoder.put(*symbol, [&](range_t range){
                return model->probability_for_model_key(range, key);
              });
            model->update_state_for_model_key(*symbol, key);
            if (billable_bytes) {
              model->billable_bytes(billable_bytes);
            }
          });
        static int i = 0;
        if (i++ < 10) {
          std::cerr << "FINISHED QUEUING DECODE: " << (int)(model->frames[model->cur_frame].meta_at(model->mb_coord.mb_x, model->mb_coord.mb_y).num_nonzeros[model->mb_coord.scan8_index]) << std::endl;
        }
        pop_queueing_symbols(ct);
        model->coding_type = PIP_UNKNOWN;
      }
    }

   private:
    void push_queueing_symbols(CodingType ct) {
      // Does not currently support nested queues.
      assert (queueing_symbols == PIP_UNKNOWN);
      assert (symbol_buffer.empty());
      queueing_symbols = ct;
    }

    void stop_queueing_symbols() {
      assert (queueing_symbols != PIP_UNKNOWN);
      queueing_symbols = PIP_UNKNOWN;
    }

    // Re-encodes the buffered symbols in order and empties the buffer. The
    // surviving text suggests that model->reset_mb_significance_state_tracking()
    // is called first.
    // NOTE(review): original lines 1246-1248 were swallowed into the
    // commented-out std::cerr line below when `<...>` spans were stripped from
    // this copy. Only the tail of the reset call and the two closing braces
    // survive; restore from upstream.
    void pop_queueing_symbols(CodingType ct) {
      //std::cerr<< "FINISHED QUEUEING "<< symbol_buffer.size()<reset_mb_significance_state_tracking();
      }
      }
      for (auto &sym : symbol_buffer) {
        sym.execute(encoder, model, out, encoder_out);
      }
      symbol_buffer.clear();
    }

    Recoded::Block *out;
    CABACContext ctx;

    compressor *c;
    h264_model *model;
    // NOTE(review): the template arguments of encoder_out, encoder and
    // symbol_buffer were stripped from this copy of the file.
    std::vector encoder_out;
    recoded_code::encoder>, uint8_t> encoder{
        std::back_inserter(encoder_out)};

    // Coding type whose symbols are currently being buffered, or PIP_UNKNOWN.
    CodingType queueing_symbols = PIP_UNKNOWN;
    std::vector symbol_buffer;
  };
  h264_model *get_model() {
    return &model;
  }

 private:

  // Searches for `buf` in the original bytes between the end of the previous
  // coded block and read_offset.
  //  - If found and size >= SURROGATE_MARKER_BYTES: emits the gap as a literal
  //    block, then returns a new block (with length parity and, for size > 1,
  //    the last byte recorded) for the re-encoded data.
  //  - Otherwise: appends a skip_coded block of `size` and returns nullptr.
  // NOTE(review): the `static_cast` target type was stripped from this copy.
  Recoded::Block* find_next_coded_block_and_emit_literal(const uint8_t *buf, int size) {
    uint8_t *found = static_cast( memmem(
        &original_bytes[prev_coded_block_end], read_offset - prev_coded_block_end,
        buf, size) );
    if (found && size >= SURROGATE_MARKER_BYTES) {
      size_t gap = found - &original_bytes[prev_coded_block_end];
      out.add_block()->set_literal(&original_bytes[prev_coded_block_end], gap);
      prev_coded_block_end += gap + size;
      Recoded::Block *newBlock = out.add_block();
      newBlock->set_length_parity(size & 1);
      if (size > 1) {
        newBlock->set_last_byte(&(buf[size - 1]), 1);
      }
      return newBlock; // Return a block for the recoder to fill.
    } else {
      // Can't recode this block, probably because it was NAL-escaped. Place
      // a skip marker in the block list.
      Recoded::Block* block = out.add_block();
      block->set_skip_coded(true);
      block->set_size(size);
      return nullptr; // Tell the recoder to ignore this block.
    }
  }

  std::string input_filename;
  std::ostream& out_stream;

  // Memory-mapped input file (owned; released in the destructor).
  uint8_t *original_bytes = nullptr;
  size_t original_size = 0;
  // Bytes handed out by read_packet so far.
  int read_offset = 0;
  // End of the last coded block located in original_bytes.
  int prev_coded_block_end = 0;

  h264_model model;
  Recoded out;
};


// Decompressor: parses a Recoded message and feeds av_decoder a byte stream
// in which every re-encoded block is replaced by a same-sized surrogate block.
// As the decoder reaches each coded block, its original bytes are regenerated
// from the recoded data.
class decompressor {
  // Used to track the decoding state of each block.
  struct block_state {
    bool coded = false;
    std::string surrogate_marker;
    std::string out_bytes;
    bool done = false;
    int8_t length_parity = -1;
    uint8_t last_byte;
  };

 public:
  // Reads the compressed Recoded message from `input_filename`.
  // NOTE(review): the buffer from av_file_map is never released with
  // av_file_unmap, and the return value of ParseFromArray is ignored, so a
  // corrupt input is not reported here.
  decompressor(const std::string& input_filename, std::ostream& out_stream)
    : input_filename(input_filename), out_stream(out_stream) {
    uint8_t *bytes;
    size_t size;
    if (av_file_map(input_filename.c_str(), &bytes, &size, 0, NULL) < 0) {
      throw std::invalid_argument("Failed to open file: " + input_filename);
    }
    in.ParseFromArray(bytes, size);
  }
  // Parses the compressed message from `in_bytes`; input_filename is kept for
  // run().
  decompressor(const std::string& input_filename, const std::string& in_bytes, std::ostream& out_stream)
    : input_filename(input_filename), out_stream(out_stream) {
    in.ParseFromString(in_bytes);
  }

  // Runs av_decoder over the reconstructed stream: read_packet supplies
  // literal bytes and surrogates, and cabac_decoder fills in coded blocks.
  // Then every block's bytes are written to out_stream, after re-applying the
  // recorded length parity / last byte. Throws std::runtime_error if any
  // block was not produced.
  // NOTE(review): if a block with length_parity set decodes to an empty
  // out_bytes and the parity matches, out_bytes[size() - 1] indexes out of
  // range.
  void run() {
    blocks.clear();
    blocks.resize(in.block_size());

    av_decoder d(this, input_filename);
    d.decode_video();

    for (auto& block : blocks) {
      if (!block.done) throw std::runtime_error("Not all blocks were decoded.");
      if (block.length_parity != -1) {
        // Correct for x264 padding: replace last byte or add an extra byte.
        if (block.length_parity != (int)(block.out_bytes.size() & 1)) {
          block.out_bytes.insert(block.out_bytes.end(), block.last_byte);
        } else {
          block.out_bytes[block.out_bytes.size() - 1] = block.last_byte;
        }
      }
      out_stream << block.out_bytes;
    }
  }

  // Fills buffer_out with the next bytes of the reconstructed stream:
  //  - literal blocks are copied verbatim;
  //  - cabac blocks become same-sized surrogate blocks that start with a
  //    unique marker;
  //  - skip_coded markers contribute no bytes.
  // Returns the number of bytes written. Throws std::runtime_error on
  // malformed blocks.
  int read_packet(uint8_t *buffer_out, int size) {
    uint8_t *p = buffer_out;
    while (size > 0 && read_index < in.block_size()) {
      if (read_block.empty()) {
        const Recoded::Block& block = in.block(read_index);
        if (int(block.has_literal()) + int(block.has_cabac()) + int(block.has_skip_coded()) != 1) {
          throw std::runtime_error("Invalid input block: must have exactly one type");
        }
        if (block.has_literal()) {
          // This block is passed through without any re-coding.
          blocks[read_index].out_bytes = block.literal();
          blocks[read_index].done = true;
          read_block = block.literal();
        } else if (block.has_cabac()) {
          // Re-coded CABAC coded block. out_bytes will be filled by cabac_decoder.
          blocks[read_index].coded = true;
          blocks[read_index].surrogate_marker = next_surrogate_marker();
          blocks[read_index].done = false;
          if (!block.has_size()) {
            throw std::runtime_error("CABAC block requires size field.");
          }
          if (block.has_length_parity() && block.has_last_byte() &&
              !block.last_byte().empty()) {
            blocks[read_index].length_parity = block.length_parity();
            blocks[read_index].last_byte = block.last_byte()[0];
          }
          read_block = make_surrogate_block(blocks[read_index].surrogate_marker, block.size());
        } else if (block.has_skip_coded() && block.skip_coded()) {
          // Non-re-coded CABAC coded block. The bytes of this block are
          // emitted in a literal block following this one. This block is
          // a flag to expect a cabac_decoder without a surrogate marker.
          blocks[read_index].coded = true;
          blocks[read_index].done = true;
        } else {
          throw std::runtime_error("Unknown input block type");
        }
      }
      if ((size_t)read_offset < read_block.size()) {
        int n = read_block.copy(reinterpret_cast(p), size, read_offset);
        read_offset += n;
        p += n;
        size -= n;
      }
      if ((size_t)read_offset >= read_block.size()) {
        read_block.clear();
        read_offset = 0;
        read_index++;
      }
    }
    return p - buffer_out;
  }

  // CABAC hook for one coded block during decompression.
  //  - cabac blocks: symbols are decoded from the recoded data with
  //    h264_model and re-emitted with a CABAC encoder, rebuilding the original
  //    block bytes.
  //  - skip_coded blocks: the hooks on ctx_in are disabled instead.
  class cabac_decoder {
   public:
    cabac_decoder(decompressor *d, CABACContext *ctx_in, const uint8_t *buf, int size) {
      index = d->recognize_coded_block(buf, size);
      block = &d->in.block(index);
      out = &d->blocks[index];
      model = nullptr;

      if (block->has_cabac()) {
        model = &d->model;
        model->reset();
        // NOTE(review): template arguments of recoded_code::decoder were
        // stripped from this copy of the file.
        decoder.reset(new recoded_code::decoder(
            block->cabac().data(), block->cabac().data() + block->cabac().size()));
      } else if (block->has_skip_coded() && block->skip_coded()) {
        // We're skipping this block, so disable calls to our hooks.
        ctx_in->coding_hooks = nullptr;
        ctx_in->coding_hooks_opaque = nullptr;
        ::ff_reset_cabac_decoder(ctx_in, buf, size);
      } else {
        throw std::runtime_error("Expected CABAC block.");
      }
    }
    ~cabac_decoder() { assert(out->done); }

    // Produces one context-coded bin and re-emits it into cabac_out.
    // Symbols in PIP_SIGNIFICANCE_EOB are not read from the recoded stream;
    // they are taken from std::get<1> of the model key.
    int get(uint8_t *state) {
      int symbol;
      if (model->coding_type == PIP_SIGNIFICANCE_EOB) {
        symbol = std::get<1>(model->get_model_key(state));
      } else {
        symbol = decoder->get([&](range_t range){
            return model->probability_for_state(range, state); });
      }
      size_t billable_bytes = cabac_encoder.put(symbol, state);
      if (billable_bytes) {
        model->billable_cabac_bytes(billable_bytes);
      }
      model->update_state(symbol, state);
      return symbol;
    }

    // Decodes one bypass bin and re-emits it as a bypass bin.
    int get_bypass() {
      int symbol = decoder->get([&](range_t range){
          return model->probability_for_state(range, &model->bypass_context); });
      model->update_state(symbol, &model->bypass_context);
      size_t billable_bytes = cabac_encoder.put_bypass(symbol);
      if (billable_bytes) {
        model->billable_cabac_bytes(billable_bytes);
      }
      return symbol;
    }

    // Decodes the terminate bin and re-emits it; a 1 completes the block.
    int get_terminate() {
      int symbol = decoder->get([&](range_t range){
          return model->probability_for_state(range, &model->terminate_context); });
      model->update_state(symbol, &model->terminate_context);
      size_t billable_bytes = cabac_encoder.put_terminate(symbol);
      if (billable_bytes) {
        model->billable_cabac_bytes(billable_bytes);
      }
      if (symbol) {
        finish();
      }
      return symbol;
    }

    // When the model starts a significance map, decodes the nonzero count
    // right away, mirroring the order in which the compressor encoded it.
    // The static counter prints the count to stderr for the first 10 calls.
    // NOTE(review): `ct` is tested for truthiness here, while the compressor
    // compares it explicitly against PIP_SIGNIFICANCE_MAP/EOB.
    void begin_coding_type(
        CodingType ct, int zigzag_index, int param0, int param1) {
      bool begin_queue = model && model->begin_coding_type(ct, zigzag_index, param0, param1);
      if (begin_queue && ct) {
        model->finished_queueing(ct,
          [&](model_key key, int * symbol) {
            *symbol = decoder->get([&](range_t range){
                return model->probability_for_model_key(range, key);
              });
            model->update_state_for_model_key(*symbol, key);
          });
        static int i = 0;
        if (i++ < 10) {
          std::cerr << "FINISHED QUEUING RECODE: " << (int)model->frames[model->cur_frame].meta_at(model->mb_coord.mb_x, model->mb_coord.mb_y).num_nonzeros[model->mb_coord.scan8_index] << std::endl;
        }
      }
    }
    void end_coding_type(CodingType ct) {
      if (!model) {
        return;
      }
      model->end_coding_type(ct);
    }

   private:
    // Moves the rebuilt bytes into the block state and marks it done.
    // NOTE(review): cabac_out.back() is undefined behavior if cabac_out is
    // empty.
    void finish() {
      // Omit trailing byte if it's only a stop bit.
      if (cabac_out.back() == 0x80) {
        cabac_out.pop_back();
      }
      out->out_bytes.assign(reinterpret_cast(cabac_out.data()), cabac_out.size());
      out->done = true;
    }

    int index;
    const Recoded::Block *block;
    block_state *out = nullptr;

    h264_model *model;
    // NOTE(review): the template arguments of decoder, cabac_out and
    // cabac_encoder were stripped from this copy of the file.
    std::unique_ptr> decoder;

    std::vector cabac_out;
    cabac::encoder>> cabac_encoder{
        std::back_inserter(cabac_out)};
  };
  h264_model *get_model() {
    return &model;
  }

 private:
  // Return a unique 8-byte string containing no zero bytes (NAL-encoding-safe).
  // Each byte is one base-255 digit of the sequence number plus 1, so no byte
  // is 0.
  std::string next_surrogate_marker() {
    uint64_t n = surrogate_marker_sequence_number++;
    std::string surrogate_marker(SURROGATE_MARKER_BYTES, '\x01');
    for (int i = 0; i < (int)surrogate_marker.size(); i++) {
      surrogate_marker[i] = (n % 255) + 1;
      n /= 255;
    }
    return surrogate_marker;
  }

  // Returns `surrogate_marker` padded with 'X' to `size` bytes; throws
  // std::runtime_error if `size` is shorter than the marker.
  std::string make_surrogate_block(const std::string& surrogate_marker, size_t size) {
    if (size < surrogate_marker.size()) {
      throw std::runtime_error("Invalid coded block size for surrogate: " + std::to_string(size));
    }
    std::string surrogate_block = surrogate_marker;
    surrogate_block.resize(size, 'X');  // NAL-encoding-safe padding.
    return surrogate_block;
  }

  // Returns the index of the next coded block (cabac or skip_coded) at or
  // after next_coded_block, and checks it against the buffer about to be
  // decoded: the size must match and, for cabac blocks, the buffer must start
  // with the block's surrogate marker. Throws std::runtime_error otherwise.
  // NOTE(review): blocks[next_coded_block] is read before the
  // `next_coded_block >= read_index` check, so once every block has been
  // read the loop can index one past the end of `blocks`.
  int recognize_coded_block(const uint8_t* buf, int size) {
    while (!blocks[next_coded_block].coded) {
      if (next_coded_block >= read_index) {
        throw std::runtime_error("Coded block expected, but not recorded in the compressed data.");
      }
      next_coded_block++;
    }
    int index = next_coded_block++;
    // Validate the decoder init call against the coded block's size and surrogate marker.
    const Recoded::Block& block = in.block(index);
    if (block.has_cabac()) {
      if (block.size() != size) {
        throw std::runtime_error("Invalid surrogate block size.");
      }
      std::string buf_header(reinterpret_cast(buf),
                             blocks[index].surrogate_marker.size());
      if (blocks[index].surrogate_marker != buf_header) {
        throw std::runtime_error("Invalid surrogate marker in coded block.");
      }
    } else if (block.has_skip_coded()) {
      if (block.size() != size) {
        throw std::runtime_error("Invalid skip_coded block size.");
      }
    } else {
      throw std::runtime_error("Internal error: expected coded block.");
    }
    return index;
  }

  std::string input_filename;
  std::ostream& out_stream;

  Recoded in;
  // Block currently being served by read_packet and the offset within it.
  int read_index = 0, read_offset = 0;
  std::string read_block;

  // NOTE(review): element type stripped in this copy; block_state per block.
  std::vector blocks;

  // Counter used to generate surrogate markers for coded blocks.
  uint64_t surrogate_marker_sequence_number = 1;
  // Head of the coded block queue - blocks that have been produced by
  // read_packet but not yet decoded. Tail of the queue is read_index.
  int next_coded_block = 0;

  h264_model model;
};


// Compresses `input_filename` in memory, decompresses the result and checks
// that it reproduces the original bytes.
// On success:
//  - writes the compressed data to *out when `out` is non-null;
//  - prints the compression ratio and the share of the output not taken up by
//    literal/cabac payload bytes (protobuf overhead);
//  - returns 0.
// Returns 1 on mismatch.
// NOTE(review): the original is read through a text-mode std::ifstream; on
// platforms that translate line endings the comparison would be unreliable.
int roundtrip(const std::string& input_filename, std::ostream* out) {
  std::stringstream original, compressed, decompressed;
  original << std::ifstream(input_filename).rdbuf();
  compressor c(input_filename, compressed);
  c.run();
  decompressor d(input_filename, compressed.str(), decompressed);
  d.run();

  if (original.str() == decompressed.str()) {
    if (out) {
      (*out) << compressed.str();
    }
    double ratio = compressed.str().size() * 1.0 / original.str().size();

    Recoded compressed_proto;
    compressed_proto.ParseFromString(compressed.str());
    int proto_block_bytes = 0;
    for (const auto& block : compressed_proto.block()) {
      proto_block_bytes += block.literal().size() + block.cabac().size();
    }
    double proto_overhead = (compressed.str().size() - proto_block_bytes) * 1.0 / compressed.str().size();

    std::cout << "Compress-decompress roundtrip succeeded:" << std::endl;
    std::cout << " compression ratio: " << ratio*100. << "%" << std::endl;
    std::cout << " protobuf overhead: " << proto_overhead*100. << "%" << std::endl;
    return 0;
  } else {
    std::cerr << "Compress-decompress roundtrip failed." << std::endl;
    return 1;
  }
}


// Command-line entry point: runs `compress`, `decompress` or `roundtrip` on
// the input file, writing to the optional output path or to std::cout.
// Exceptions are reported on stderr with exit status 1.
// NOTE(review): if argv[3] cannot be opened, is_open() is false and output
// silently goes to std::cout. The output file is also opened in text mode
// although it receives binary protobuf data.
// NOTE(review): the usage string appears to have lost an `<input>`-style
// placeholder when `<...>` spans were stripped from this copy.
int
main(int argc, char **argv) {
  av_register_all();

  if (argc < 3 || argc > 4) {
    std::cerr << "Usage: " << argv[0] << " [compress|decompress|roundtrip] [output]" << std::endl;
    return 1;
  }
  std::string command = argv[1];
  std::string input_filename = argv[2];
  std::ofstream out_file;
  if (argc > 3) {
    out_file.open(argv[3]);
  }

  try {
    if (command == "compress") {
      compressor c(input_filename, out_file.is_open() ? out_file : std::cout);
      c.run();
    } else if (command == "decompress") {
      decompressor d(input_filename, out_file.is_open() ? out_file : std::cout);
      d.run();
    } else if (command == "roundtrip") {
      return roundtrip(input_filename, out_file.is_open() ? &out_file : nullptr);
    } else {
      throw std::invalid_argument("Unknown command: " + command);
    }
  } catch (const std::exception& e) {
    std::cerr << "Exception (" << typeid(e).name() << "): " << e.what() << std::endl;
    return 1;
  }
  return 0;
}

--------------------------------------------------------------------------------