├── Readme.markdown
├── cpp
    ├── AdaptiveArithmeticCompress.cpp
    ├── AdaptiveArithmeticDecompress.cpp
    ├── ArithmeticCoder.cpp
    ├── ArithmeticCoder.hpp
    ├── ArithmeticCompress.cpp
    ├── ArithmeticDecompress.cpp
    ├── BitIoStream.cpp
    ├── BitIoStream.hpp
    ├── FrequencyTable.cpp
    ├── FrequencyTable.hpp
    ├── Makefile
    ├── PpmCompress.cpp
    ├── PpmDecompress.cpp
    ├── PpmModel.cpp
    └── PpmModel.hpp
├── java
    ├── src
    │   ├── AdaptiveArithmeticCompress.java
    │   ├── AdaptiveArithmeticDecompress.java
    │   ├── ArithmeticCoderBase.java
    │   ├── ArithmeticCompress.java
    │   ├── ArithmeticDecoder.java
    │   ├── ArithmeticDecompress.java
    │   ├── ArithmeticEncoder.java
    │   ├── BitInputStream.java
    │   ├── BitOutputStream.java
    │   ├── CheckedFrequencyTable.java
    │   ├── FlatFrequencyTable.java
    │   ├── FrequencyTable.java
    │   ├── PpmCompress.java
    │   ├── PpmDecompress.java
    │   ├── PpmModel.java
    │   └── SimpleFrequencyTable.java
    └── test
    │   ├── AdaptiveArithmeticCompressTest.java
    │   ├── ArithmeticCodingTest.java
    │   ├── ArithmeticCompressTest.java
    │   └── PpmCompressTest.java
└── python
    ├── adaptive-arithmetic-compress.py
    ├── adaptive-arithmetic-decompress.py
    ├── arithmetic-compress.py
    ├── arithmetic-decompress.py
    ├── arithmeticcoding.py
    ├── ppm-compress.py
    ├── ppm-decompress.py
    └── ppmmodel.py


/Readme.markdown:
--------------------------------------------------------------------------------
Reference arithmetic coding
===========================

This project is a clear implementation of arithmetic coding, suitable as a reference for
educational purposes. It is provided separately in Java, Python, C++, and is open source.

The code can be used for study, and as a solid basis for modification and extension.
Consequently, the codebase optimizes for readability and avoids fancy logic,
and does not target the best speed/memory/performance.

Home page with detailed description: [https://www.nayuki.io/page/reference-arithmetic-coding](https://www.nayuki.io/page/reference-arithmetic-coding)


License
-------

Copyright © 2023 Project Nayuki. (MIT License)

Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:

* The above copyright notice and this permission notice shall be included in
  all copies or substantial portions of the Software.

* The Software is provided "as is", without warranty of any kind, express or
  implied, including but not limited to the warranties of merchantability,
  fitness for a particular purpose and noninfringement. In no event shall the
  authors or copyright holders be liable for any claim, damages or other
  liability, whether in an action of contract, tort or otherwise, arising from,
  out of or in connection with the Software or the use or other dealings in the
  Software.

--------------------------------------------------------------------------------
/cpp/AdaptiveArithmeticCompress.cpp:
--------------------------------------------------------------------------------
/*
 * Compression application using adaptive arithmetic coding
 *
 * Usage: AdaptiveArithmeticCompress InputFile OutputFile
 * Then use the corresponding "AdaptiveArithmeticDecompress" application to recreate the original input file.
 * Note that the application starts with a flat frequency table of 257 symbols (all set to a frequency of 1),
 * and updates it after each byte encoded.
The corresponding decompressor program also starts with a flat 8 | * frequency table and updates it after each byte decoded. It is by design that the compressor and 9 | * decompressor have synchronized states, so that the data can be decompressed properly. 10 | * 11 | * Copyright (c) Project Nayuki 12 | * MIT License. See readme file. 13 | * https://www.nayuki.io/page/reference-arithmetic-coding 14 | */ 15 | 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include "ArithmeticCoder.hpp" 23 | #include "BitIoStream.hpp" 24 | #include "FrequencyTable.hpp" 25 | 26 | using std::uint32_t; 27 | 28 | 29 | int main(int argc, char *argv[]) { 30 | // Handle command line arguments 31 | if (argc != 3) { 32 | std::cerr << "Usage: " << argv[0] << " InputFile OutputFile" << std::endl; 33 | return EXIT_FAILURE; 34 | } 35 | const char *inputFile = argv[1]; 36 | const char *outputFile = argv[2]; 37 | 38 | // Perform file compression 39 | std::ifstream in(inputFile, std::ios::binary); 40 | std::ofstream out(outputFile, std::ios::binary); 41 | BitOutputStream bout(out); 42 | try { 43 | 44 | SimpleFrequencyTable freqs(FlatFrequencyTable(257)); 45 | ArithmeticEncoder enc(32, bout); 46 | while (true) { 47 | // Read and encode one byte 48 | int symbol = in.get(); 49 | if (symbol == std::char_traits::eof()) 50 | break; 51 | if (!(0 <= symbol && symbol <= 255)) 52 | throw std::logic_error("Assertion error"); 53 | enc.write(freqs, static_cast(symbol)); 54 | freqs.increment(static_cast(symbol)); 55 | } 56 | 57 | enc.write(freqs, 256); // EOF 58 | enc.finish(); // Flush remaining code bits 59 | bout.finish(); 60 | return EXIT_SUCCESS; 61 | 62 | } catch (const char *msg) { 63 | std::cerr << msg << std::endl; 64 | return EXIT_FAILURE; 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /cpp/AdaptiveArithmeticDecompress.cpp: -------------------------------------------------------------------------------- 1 | /* 
2 | * Decompression application using adaptive arithmetic coding 3 | * 4 | * Usage: AdaptiveArithmeticDecompress InputFile OutputFile 5 | * This decompresses files generated by the "AdaptiveArithmeticCompress" application. 6 | * 7 | * Copyright (c) Project Nayuki 8 | * MIT License. See readme file. 9 | * https://www.nayuki.io/page/reference-arithmetic-coding 10 | */ 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include "ArithmeticCoder.hpp" 18 | #include "BitIoStream.hpp" 19 | #include "FrequencyTable.hpp" 20 | 21 | using std::uint32_t; 22 | 23 | 24 | int main(int argc, char *argv[]) { 25 | // Handle command line arguments 26 | if (argc != 3) { 27 | std::cerr << "Usage: " << argv[0] << " InputFile OutputFile" << std::endl; 28 | return EXIT_FAILURE; 29 | } 30 | const char *inputFile = argv[1]; 31 | const char *outputFile = argv[2]; 32 | 33 | // Perform file decompression 34 | std::ifstream in(inputFile, std::ios::binary); 35 | std::ofstream out(outputFile, std::ios::binary); 36 | BitInputStream bin(in); 37 | try { 38 | 39 | SimpleFrequencyTable freqs(FlatFrequencyTable(257)); 40 | ArithmeticDecoder dec(32, bin); 41 | while (true) { 42 | // Decode and write one byte 43 | uint32_t symbol = dec.read(freqs); 44 | if (symbol == 256) // EOF symbol 45 | break; 46 | int b = static_cast(symbol); 47 | if (std::numeric_limits::is_signed) 48 | b -= (b >> 7) << 8; 49 | out.put(static_cast(b)); 50 | freqs.increment(symbol); 51 | } 52 | return EXIT_SUCCESS; 53 | 54 | } catch (const char *msg) { 55 | std::cerr << msg << std::endl; 56 | return EXIT_FAILURE; 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /cpp/ArithmeticCoder.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Reference arithmetic coding 3 | * 4 | * Copyright (c) Project Nayuki 5 | * MIT License. See readme file. 
6 | * https://www.nayuki.io/page/reference-arithmetic-coding 7 | */ 8 | 9 | #include 10 | #include 11 | #include "ArithmeticCoder.hpp" 12 | 13 | using std::uint32_t; 14 | using std::uint64_t; 15 | 16 | 17 | ArithmeticCoderBase::ArithmeticCoderBase(int numBits) { 18 | if (!(1 <= numBits && numBits <= 63)) 19 | throw std::domain_error("State size out of range"); 20 | numStateBits = numBits; 21 | fullRange = static_cast(1) << numStateBits; 22 | halfRange = fullRange >> 1; // Non-zero 23 | quarterRange = halfRange >> 1; // Can be zero 24 | minimumRange = quarterRange + 2; // At least 2 25 | maximumTotal = std::min(std::numeric_limits::max() / fullRange, minimumRange); 26 | stateMask = fullRange - 1; 27 | low = 0; 28 | high = stateMask; 29 | } 30 | 31 | 32 | ArithmeticCoderBase::~ArithmeticCoderBase() {} 33 | 34 | 35 | void ArithmeticCoderBase::update(const FrequencyTable &freqs, uint32_t symbol) { 36 | // State check 37 | if (low >= high || (low & stateMask) != low || (high & stateMask) != high) 38 | throw std::logic_error("Assertion error: Low or high out of range"); 39 | uint64_t range = high - low + 1; 40 | if (!(minimumRange <= range && range <= fullRange)) 41 | throw std::logic_error("Assertion error: Range out of range"); 42 | 43 | // Frequency table values check 44 | uint32_t total = freqs.getTotal(); 45 | uint32_t symLow = freqs.getLow(symbol); 46 | uint32_t symHigh = freqs.getHigh(symbol); 47 | if (symLow == symHigh) 48 | throw std::invalid_argument("Symbol has zero frequency"); 49 | if (total > maximumTotal) 50 | throw std::invalid_argument("Cannot code symbol because total is too large"); 51 | 52 | // Update range 53 | uint64_t newLow = low + symLow * range / total; 54 | uint64_t newHigh = low + symHigh * range / total - 1; 55 | low = newLow; 56 | high = newHigh; 57 | 58 | // While low and high have the same top bit value, shift them out 59 | while (((low ^ high) & halfRange) == 0) { 60 | shift(); 61 | low = ((low << 1) & stateMask); 62 | high = ((high << 1) 
& stateMask) | 1; 63 | } 64 | // Now low's top bit must be 0 and high's top bit must be 1 65 | 66 | // While low's top two bits are 01 and high's are 10, delete the second highest bit of both 67 | while ((low & ~high & quarterRange) != 0) { 68 | underflow(); 69 | low = (low << 1) ^ halfRange; 70 | high = ((high ^ halfRange) << 1) | halfRange | 1; 71 | } 72 | } 73 | 74 | 75 | ArithmeticDecoder::ArithmeticDecoder(int numBits, BitInputStream &in) : 76 | ArithmeticCoderBase(numBits), 77 | input(in), 78 | code(0) { 79 | for (int i = 0; i < numStateBits; i++) 80 | code = code << 1 | readCodeBit(); 81 | } 82 | 83 | 84 | uint32_t ArithmeticDecoder::read(const FrequencyTable &freqs) { 85 | // Translate from coding range scale to frequency table scale 86 | uint32_t total = freqs.getTotal(); 87 | if (total > maximumTotal) 88 | throw std::invalid_argument("Cannot decode symbol because total is too large"); 89 | uint64_t range = high - low + 1; 90 | uint64_t offset = code - low; 91 | uint64_t value = ((offset + 1) * total - 1) / range; 92 | if (value * range / total > offset) 93 | throw std::logic_error("Assertion error"); 94 | if (value >= total) 95 | throw std::logic_error("Assertion error"); 96 | 97 | // A kind of binary search. Find highest symbol such that freqs.getLow(symbol) <= value. 
98 | uint32_t start = 0; 99 | uint32_t end = freqs.getSymbolLimit(); 100 | while (end - start > 1) { 101 | uint32_t middle = (start + end) >> 1; 102 | if (freqs.getLow(middle) > value) 103 | end = middle; 104 | else 105 | start = middle; 106 | } 107 | if (start + 1 != end) 108 | throw std::logic_error("Assertion error"); 109 | 110 | uint32_t symbol = start; 111 | if (!(freqs.getLow(symbol) * range / total <= offset && offset < freqs.getHigh(symbol) * range / total)) 112 | throw std::logic_error("Assertion error"); 113 | update(freqs, symbol); 114 | if (!(low <= code && code <= high)) 115 | throw std::logic_error("Assertion error: Code out of range"); 116 | return symbol; 117 | } 118 | 119 | 120 | void ArithmeticDecoder::shift() { 121 | code = ((code << 1) & stateMask) | readCodeBit(); 122 | } 123 | 124 | 125 | void ArithmeticDecoder::underflow() { 126 | code = (code & halfRange) | ((code << 1) & (stateMask >> 1)) | readCodeBit(); 127 | } 128 | 129 | 130 | int ArithmeticDecoder::readCodeBit() { 131 | int temp = input.read(); 132 | if (temp == -1) 133 | temp = 0; 134 | return temp; 135 | } 136 | 137 | 138 | ArithmeticEncoder::ArithmeticEncoder(int numBits, BitOutputStream &out) : 139 | ArithmeticCoderBase(numBits), 140 | output(out), 141 | numUnderflow(0) {} 142 | 143 | 144 | void ArithmeticEncoder::write(const FrequencyTable &freqs, uint32_t symbol) { 145 | update(freqs, symbol); 146 | } 147 | 148 | 149 | void ArithmeticEncoder::finish() { 150 | output.write(1); 151 | } 152 | 153 | 154 | void ArithmeticEncoder::shift() { 155 | int bit = static_cast(low >> (numStateBits - 1)); 156 | output.write(bit); 157 | 158 | // Write out the saved underflow bits 159 | for (; numUnderflow > 0; numUnderflow--) 160 | output.write(bit ^ 1); 161 | } 162 | 163 | 164 | void ArithmeticEncoder::underflow() { 165 | if (numUnderflow == std::numeric_limits::max()) 166 | throw std::overflow_error("Maximum underflow reached"); 167 | numUnderflow++; 168 | } 169 | 
-------------------------------------------------------------------------------- /cpp/ArithmeticCoder.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Reference arithmetic coding 3 | * 4 | * Copyright (c) Project Nayuki 5 | * MIT License. See readme file. 6 | * https://www.nayuki.io/page/reference-arithmetic-coding 7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | #include 13 | #include "BitIoStream.hpp" 14 | #include "FrequencyTable.hpp" 15 | 16 | 17 | /* 18 | * Provides the state and behaviors that arithmetic coding encoders and decoders share. 19 | */ 20 | class ArithmeticCoderBase { 21 | 22 | /*---- Configuration fields ----*/ 23 | 24 | // Number of bits for the 'low' and 'high' state variables. Must be in the range [1, 63]. 25 | // - For state sizes less than the midpoint of around 32, larger values are generally better - 26 | // they allow a larger maximum frequency total (maximumTotal), and they reduce the approximation 27 | // error inherent in adapting fractions to integers; both effects reduce the data encoding loss 28 | // and asymptotically approach the efficiency of arithmetic coding using exact fractions. 29 | // - But for state sizes greater than the midpoint, because intermediate computations are limited 30 | // to the long integer type's 63-bit unsigned precision, larger state sizes will decrease the 31 | // maximum frequency total, which might constrain the user-supplied probability model. 32 | // - Therefore numStateBits=32 is recommended as the most versatile setting 33 | // because it maximizes maximumTotal (which ends up being slightly over 2^30). 34 | // - Note that numStateBits=63 is legal but useless because it implies maximumTotal=1, 35 | // which means a frequency table can only support one symbol with non-zero frequency. 36 | protected: int numStateBits; 37 | 38 | // Maximum range (high+1-low) during coding (trivial), which is 2^numStateBits = 1000...000. 
39 | protected: std::uint64_t fullRange; 40 | 41 | // The top bit at width numStateBits, which is 0100...000. 42 | protected: std::uint64_t halfRange; 43 | 44 | // The second highest bit at width numStateBits, which is 0010...000. This is zero when numStateBits=1. 45 | protected: std::uint64_t quarterRange; 46 | 47 | // Minimum range (high+1-low) during coding (non-trivial), which is 0010...010. 48 | protected: std::uint64_t minimumRange; 49 | 50 | // Maximum allowed total from a frequency table at all times during coding. 51 | protected: std::uint64_t maximumTotal; 52 | 53 | // Bit mask of numStateBits ones, which is 0111...111. 54 | protected: std::uint64_t stateMask; 55 | 56 | 57 | /*---- State fields ----*/ 58 | 59 | // Low end of this arithmetic coder's current range. Conceptually has an infinite number of trailing 0s. 60 | protected: std::uint64_t low; 61 | 62 | // High end of this arithmetic coder's current range. Conceptually has an infinite number of trailing 1s. 63 | protected: std::uint64_t high; 64 | 65 | 66 | /*---- Constructor ----*/ 67 | 68 | // Constructs an arithmetic coder, which initializes the code range. 69 | public: explicit ArithmeticCoderBase(int numBits); 70 | 71 | 72 | public: virtual ~ArithmeticCoderBase() = 0; 73 | 74 | 75 | /*---- Methods ----*/ 76 | 77 | // Updates the code range (low and high) of this arithmetic coder as a result 78 | // of processing the given symbol with the given frequency table. 79 | // Invariants that are true before and after encoding/decoding each symbol 80 | // (letting fullRange = 2^numStateBits): 81 | // * 0 <= low <= code <= high < fullRange. ('code' exists only in the decoder.) 82 | // Therefore these variables are unsigned integers of numStateBits bits. 83 | // * low < 1/2 * fullRange <= high. 84 | // In other words, they are in different halves of the full range. 85 | // * (low < 1/4 * fullRange) || (high >= 3/4 * fullRange). 86 | // In other words, they are not both in the middle two quarters. 
87 | // * Let range = high - low + 1, then fullRange/4 < minimumRange <= range <= fullRange. 88 | // These invariants for 'range' essentially dictate the maximum total that the incoming 89 | // frequency table can have, such that intermediate calculations don't overflow. 90 | protected: virtual void update(const FrequencyTable &freqs, std::uint32_t symbol); 91 | 92 | 93 | // Called to handle the situation when the top bit of 'low' and 'high' are equal. 94 | protected: virtual void shift() = 0; 95 | 96 | 97 | // Called to handle the situation when low=01(...) and high=10(...). 98 | protected: virtual void underflow() = 0; 99 | 100 | }; 101 | 102 | 103 | 104 | /* 105 | * Reads from an arithmetic-coded bit stream and decodes symbols. 106 | */ 107 | class ArithmeticDecoder final : private ArithmeticCoderBase { 108 | 109 | /*---- Fields ----*/ 110 | 111 | // The underlying bit input stream. 112 | private: BitInputStream &input; 113 | 114 | // The current raw code bits being buffered, which is always in the range [low, high]. 115 | private: std::uint64_t code; 116 | 117 | 118 | /*---- Constructor ----*/ 119 | 120 | // Constructs an arithmetic coding decoder based on the 121 | // given bit input stream, and fills the code bits. 122 | public: explicit ArithmeticDecoder(int numBits, BitInputStream &in); 123 | 124 | 125 | /*---- Methods ----*/ 126 | 127 | // Decodes the next symbol based on the given frequency table and returns it. 128 | // Also updates this arithmetic coder's state and may read in some bits. 129 | public: std::uint32_t read(const FrequencyTable &freqs); 130 | 131 | 132 | protected: void shift() override; 133 | 134 | 135 | protected: void underflow() override; 136 | 137 | 138 | // Returns the next bit (0 or 1) from the input stream. The end 139 | // of stream is treated as an infinite number of trailing zeros. 140 | private: int readCodeBit(); 141 | 142 | }; 143 | 144 | 145 | 146 | /* 147 | * Encodes symbols and writes to an arithmetic-coded bit stream. 
148 | */ 149 | class ArithmeticEncoder final : private ArithmeticCoderBase { 150 | 151 | /*---- Fields ----*/ 152 | 153 | // The underlying bit output stream. 154 | private: BitOutputStream &output; 155 | 156 | // Number of saved underflow bits. This value can grow without bound, 157 | // so a truly correct implementation would use a bigint. 158 | private: unsigned long numUnderflow; 159 | 160 | 161 | /*---- Constructor ----*/ 162 | 163 | // Constructs an arithmetic coding encoder based on the given bit output stream. 164 | public: explicit ArithmeticEncoder(int numBits, BitOutputStream &out); 165 | 166 | 167 | /*---- Methods ----*/ 168 | 169 | // Encodes the given symbol based on the given frequency table. 170 | // Also updates this arithmetic coder's state and may write out some bits. 171 | public: void write(const FrequencyTable &freqs, std::uint32_t symbol); 172 | 173 | 174 | // Terminates the arithmetic coding by flushing any buffered bits, so that the output can be decoded properly. 175 | // It is important that this method must be called at the end of the each encoding process. 176 | // Note that this method merely writes data to the underlying output stream but does not close it. 177 | public: void finish(); 178 | 179 | 180 | protected: void shift() override; 181 | 182 | 183 | protected: void underflow() override; 184 | 185 | }; 186 | -------------------------------------------------------------------------------- /cpp/ArithmeticCompress.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Compression application using static arithmetic coding 3 | * 4 | * Usage: ArithmeticCompress InputFile OutputFile 5 | * Then use the corresponding "ArithmeticDecompress" application to recreate the original input file. 6 | * Note that the application uses an alphabet of 257 symbols - 256 symbols for the byte 7 | * values and 1 symbol for the EOF marker. 
The compressed file format starts with a list 8 | * of 256 symbol frequencies, and then followed by the arithmetic-coded data. 9 | * 10 | * Copyright (c) Project Nayuki 11 | * MIT License. See readme file. 12 | * https://www.nayuki.io/page/reference-arithmetic-coding 13 | */ 14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include "ArithmeticCoder.hpp" 23 | #include "BitIoStream.hpp" 24 | #include "FrequencyTable.hpp" 25 | 26 | using std::uint32_t; 27 | 28 | 29 | int main(int argc, char *argv[]) { 30 | // Handle command line arguments 31 | if (argc != 3) { 32 | std::cerr << "Usage: " << argv[0] << " InputFile OutputFile" << std::endl; 33 | return EXIT_FAILURE; 34 | } 35 | const char *inputFile = argv[1]; 36 | const char *outputFile = argv[2]; 37 | 38 | // Read input file once to compute symbol frequencies 39 | std::ifstream in(inputFile, std::ios::binary); 40 | SimpleFrequencyTable freqs(std::vector(257, 0)); 41 | freqs.increment(256); // EOF symbol gets a frequency of 1 42 | while (true) { 43 | int b = in.get(); 44 | if (b == std::char_traits::eof()) 45 | break; 46 | if (b < 0 || b > 255) 47 | throw std::logic_error("Assertion error"); 48 | freqs.increment(static_cast(b)); 49 | } 50 | 51 | // Read input file again, compress with arithmetic coding, and write output file 52 | in.clear(); 53 | in.seekg(0); 54 | std::ofstream out(outputFile, std::ios::binary); 55 | BitOutputStream bout(out); 56 | try { 57 | 58 | // Write frequency table 59 | for (uint32_t i = 0; i < 256; i++) { 60 | uint32_t freq = freqs.get(i); 61 | for (int j = 31; j >= 0; j--) 62 | bout.write(static_cast((freq >> j) & 1)); // Big endian 63 | } 64 | 65 | ArithmeticEncoder enc(32, bout); 66 | while (true) { 67 | // Read and encode one byte 68 | int symbol = in.get(); 69 | if (symbol == std::char_traits::eof()) 70 | break; 71 | if (!(0 <= symbol && symbol <= 255)) 72 | throw std::logic_error("Assertion error"); 73 | enc.write(freqs, 
static_cast(symbol)); 74 | } 75 | 76 | enc.write(freqs, 256); // EOF 77 | enc.finish(); // Flush remaining code bits 78 | bout.finish(); 79 | return EXIT_SUCCESS; 80 | 81 | } catch (const char *msg) { 82 | std::cerr << msg << std::endl; 83 | return EXIT_FAILURE; 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /cpp/ArithmeticDecompress.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Decompression application using static arithmetic coding 3 | * 4 | * Usage: ArithmeticDecompress InputFile OutputFile 5 | * This decompresses files generated by the "ArithmeticCompress" application. 6 | * 7 | * Copyright (c) Project Nayuki 8 | * MIT License. See readme file. 9 | * https://www.nayuki.io/page/reference-arithmetic-coding 10 | */ 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include "ArithmeticCoder.hpp" 19 | #include "BitIoStream.hpp" 20 | #include "FrequencyTable.hpp" 21 | 22 | using std::uint32_t; 23 | 24 | 25 | int main(int argc, char *argv[]) { 26 | // Handle command line arguments 27 | if (argc != 3) { 28 | std::cerr << "Usage: " << argv[0] << " InputFile OutputFile" << std::endl; 29 | return EXIT_FAILURE; 30 | } 31 | const char *inputFile = argv[1]; 32 | const char *outputFile = argv[2]; 33 | 34 | // Perform file decompression 35 | std::ifstream in(inputFile, std::ios::binary); 36 | std::ofstream out(outputFile, std::ios::binary); 37 | BitInputStream bin(in); 38 | try { 39 | 40 | // Read frequency table 41 | SimpleFrequencyTable freqs(std::vector(257, 0)); 42 | for (uint32_t i = 0; i < 256; i++) { 43 | uint32_t freq = 0; 44 | for (int j = 0; j < 32; j++) 45 | freq = (freq << 1) | bin.readNoEof(); // Big endian 46 | freqs.set(i, freq); 47 | } 48 | freqs.increment(256); // EOF symbol 49 | 50 | ArithmeticDecoder dec(32, bin); 51 | while (true) { 52 | uint32_t symbol = dec.read(freqs); 53 | if (symbol == 256) // EOF symbol 
54 | break; 55 | int b = static_cast(symbol); 56 | if (std::numeric_limits::is_signed) 57 | b -= (b >> 7) << 8; 58 | out.put(static_cast(b)); 59 | } 60 | return EXIT_SUCCESS; 61 | 62 | } catch (const char *msg) { 63 | std::cerr << msg << std::endl; 64 | return EXIT_FAILURE; 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /cpp/BitIoStream.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Reference arithmetic coding 3 | * 4 | * Copyright (c) Project Nayuki 5 | * MIT License. See readme file. 6 | * https://www.nayuki.io/page/reference-arithmetic-coding 7 | */ 8 | 9 | #include 10 | #include 11 | #include 12 | #include "BitIoStream.hpp" 13 | 14 | 15 | BitInputStream::BitInputStream(std::istream &in) : 16 | input(in), 17 | currentByte(0), 18 | numBitsRemaining(0) {} 19 | 20 | 21 | int BitInputStream::read() { 22 | if (currentByte == std::char_traits::eof()) 23 | return -1; 24 | if (numBitsRemaining == 0) { 25 | currentByte = input.get(); // Note: istream.get() returns int, not char 26 | if (currentByte == std::char_traits::eof()) 27 | return -1; 28 | if (!(0 <= currentByte && currentByte <= 255)) 29 | throw std::logic_error("Assertion error"); 30 | numBitsRemaining = 8; 31 | } 32 | if (numBitsRemaining <= 0) 33 | throw std::logic_error("Assertion error"); 34 | numBitsRemaining--; 35 | return (currentByte >> numBitsRemaining) & 1; 36 | } 37 | 38 | 39 | int BitInputStream::readNoEof() { 40 | int result = read(); 41 | if (result != -1) 42 | return result; 43 | else 44 | throw std::runtime_error("End of stream"); 45 | } 46 | 47 | 48 | BitOutputStream::BitOutputStream(std::ostream &out) : 49 | output(out), 50 | currentByte(0), 51 | numBitsFilled(0) {} 52 | 53 | 54 | void BitOutputStream::write(int b) { 55 | if (b != 0 && b != 1) 56 | throw std::domain_error("Argument must be 0 or 1"); 57 | currentByte = (currentByte << 1) | b; 58 | numBitsFilled++; 59 | if (numBitsFilled == 8) 
{ 60 | // Note: ostream.put() takes char, which may be signed/unsigned 61 | if (std::numeric_limits::is_signed) 62 | currentByte -= (currentByte >> 7) << 8; 63 | output.put(static_cast(currentByte)); 64 | currentByte = 0; 65 | numBitsFilled = 0; 66 | } 67 | } 68 | 69 | 70 | void BitOutputStream::finish() { 71 | while (numBitsFilled != 0) 72 | write(0); 73 | } 74 | -------------------------------------------------------------------------------- /cpp/BitIoStream.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Reference arithmetic coding 3 | * 4 | * Copyright (c) Project Nayuki 5 | * MIT License. See readme file. 6 | * https://www.nayuki.io/page/reference-arithmetic-coding 7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | #include 13 | 14 | 15 | /* 16 | * A stream of bits that can be read. Because they come from an underlying byte stream, 17 | * the total number of bits is always a multiple of 8. The bits are read in big endian. 18 | */ 19 | class BitInputStream final { 20 | 21 | /*---- Fields ----*/ 22 | 23 | // The underlying byte stream to read from. 24 | private: std::istream &input; 25 | 26 | // Either in the range [0x00, 0xFF] if bits are available, or EOF if end of stream is reached. 27 | private: int currentByte; 28 | 29 | // Number of remaining bits in the current byte, always between 0 and 7 (inclusive). 30 | private: int numBitsRemaining; 31 | 32 | 33 | /*---- Constructor ----*/ 34 | 35 | // Constructs a bit input stream based on the given byte input stream. 36 | public: explicit BitInputStream(std::istream &in); 37 | 38 | 39 | /*---- Methods ----*/ 40 | 41 | // Reads a bit from this stream. Returns 0 or 1 if a bit is available, or -1 if 42 | // the end of stream is reached. The end of stream always occurs on a byte boundary. 43 | public: int read(); 44 | 45 | 46 | // Reads a bit from this stream. Returns 0 or 1 if a bit is available, or throws an exception 47 | // if the end of stream is reached. 
The end of stream always occurs on a byte boundary. 48 | public: int readNoEof(); 49 | 50 | }; 51 | 52 | 53 | 54 | /* 55 | * A stream where bits can be written to. Because they are written to an underlying 56 | * byte stream, the end of the stream is padded with 0's up to a multiple of 8 bits. 57 | * The bits are written in big endian. 58 | */ 59 | class BitOutputStream final { 60 | 61 | /*---- Fields ----*/ 62 | 63 | // The underlying byte stream to write to. 64 | private: std::ostream &output; 65 | 66 | // The accumulated bits for the current byte, always in the range [0x00, 0xFF]. 67 | private: int currentByte; 68 | 69 | // Number of accumulated bits in the current byte, always between 0 and 7 (inclusive). 70 | private: int numBitsFilled; 71 | 72 | 73 | /*---- Constructor ----*/ 74 | 75 | // Constructs a bit output stream based on the given byte output stream. 76 | public: explicit BitOutputStream(std::ostream &out); 77 | 78 | 79 | /*---- Methods ----*/ 80 | 81 | // Writes a bit to the stream. The given bit must be 0 or 1. 82 | public: void write(int b); 83 | 84 | 85 | // Writes the minimum number of "0" bits (between 0 and 7 of them) as padding to 86 | // reach the next byte boundary. Most applications will require the bits in the last 87 | // partial byte to be written before the underlying stream is closed. Note that this 88 | // method merely writes data to the underlying output stream but does not close it. 89 | public: void finish(); 90 | 91 | }; 92 | -------------------------------------------------------------------------------- /cpp/FrequencyTable.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Reference arithmetic coding 3 | * 4 | * Copyright (c) Project Nayuki 5 | * MIT License. See readme file. 
6 | * https://www.nayuki.io/page/reference-arithmetic-coding 7 | */ 8 | 9 | #include 10 | #include "FrequencyTable.hpp" 11 | 12 | using std::uint32_t; 13 | 14 | 15 | FrequencyTable::~FrequencyTable() {} 16 | 17 | 18 | FlatFrequencyTable::FlatFrequencyTable(uint32_t numSyms) : 19 | numSymbols(numSyms) { 20 | if (numSyms < 1) 21 | throw std::domain_error("Number of symbols must be positive"); 22 | } 23 | 24 | 25 | uint32_t FlatFrequencyTable::getSymbolLimit() const { 26 | return numSymbols; 27 | } 28 | 29 | 30 | uint32_t FlatFrequencyTable::get(uint32_t symbol) const { 31 | checkSymbol(symbol); 32 | return 1; 33 | } 34 | 35 | 36 | uint32_t FlatFrequencyTable::getTotal() const { 37 | return numSymbols; 38 | } 39 | 40 | 41 | uint32_t FlatFrequencyTable::getLow(uint32_t symbol) const { 42 | checkSymbol(symbol); 43 | return symbol; 44 | } 45 | 46 | 47 | uint32_t FlatFrequencyTable::getHigh(uint32_t symbol) const { 48 | checkSymbol(symbol); 49 | return symbol + 1; 50 | } 51 | 52 | 53 | void FlatFrequencyTable::set(uint32_t, uint32_t) { 54 | throw std::logic_error("Unsupported operation"); 55 | } 56 | 57 | 58 | void FlatFrequencyTable::increment(uint32_t) { 59 | throw std::logic_error("Unsupported operation"); 60 | } 61 | 62 | 63 | void FlatFrequencyTable::checkSymbol(uint32_t symbol) const { 64 | if (symbol >= numSymbols) 65 | throw std::domain_error("Symbol out of range"); 66 | } 67 | 68 | 69 | SimpleFrequencyTable::SimpleFrequencyTable(const std::vector &freqs) { 70 | if (freqs.size() > UINT32_MAX - 1) 71 | throw std::length_error("Too many symbols"); 72 | uint32_t size = static_cast(freqs.size()); 73 | if (size < 1) 74 | throw std::invalid_argument("At least 1 symbol needed"); 75 | 76 | frequencies = freqs; 77 | cumulative.reserve(size + 1); 78 | initCumulative(false); 79 | total = getHigh(size - 1); 80 | } 81 | 82 | 83 | SimpleFrequencyTable::SimpleFrequencyTable(const FrequencyTable &freqs) { 84 | uint32_t size = freqs.getSymbolLimit(); 85 | if (size < 1) 86 | throw 
std::invalid_argument("At least 1 symbol needed"); 87 | if (size > UINT32_MAX - 1) 88 | throw std::length_error("Too many symbols"); 89 | 90 | frequencies.reserve(size + 1); 91 | for (uint32_t i = 0; i < size; i++) 92 | frequencies.push_back(freqs.get(i)); 93 | 94 | cumulative.reserve(size + 1); 95 | initCumulative(false); 96 | total = getHigh(size - 1); 97 | } 98 | 99 | 100 | uint32_t SimpleFrequencyTable::getSymbolLimit() const { 101 | return static_cast(frequencies.size()); 102 | } 103 | 104 | 105 | uint32_t SimpleFrequencyTable::get(uint32_t symbol) const { 106 | return frequencies.at(symbol); 107 | } 108 | 109 | 110 | void SimpleFrequencyTable::set(uint32_t symbol, uint32_t freq) { 111 | if (total < frequencies.at(symbol)) 112 | throw std::logic_error("Assertion error"); 113 | uint32_t temp = total - frequencies.at(symbol); 114 | total = checkedAdd(temp, freq); 115 | frequencies.at(symbol) = freq; 116 | cumulative.clear(); 117 | } 118 | 119 | 120 | void SimpleFrequencyTable::increment(uint32_t symbol) { 121 | if (frequencies.at(symbol) == UINT32_MAX) 122 | throw std::overflow_error("Arithmetic overflow"); 123 | total = checkedAdd(total, 1); 124 | frequencies.at(symbol)++; 125 | cumulative.clear(); 126 | } 127 | 128 | 129 | uint32_t SimpleFrequencyTable::getTotal() const { 130 | return total; 131 | } 132 | 133 | 134 | uint32_t SimpleFrequencyTable::getLow(uint32_t symbol) const { 135 | initCumulative(); 136 | return cumulative.at(symbol); 137 | } 138 | 139 | 140 | uint32_t SimpleFrequencyTable::getHigh(uint32_t symbol) const { 141 | initCumulative(); 142 | return cumulative.at(symbol + 1); 143 | } 144 | 145 | 146 | void SimpleFrequencyTable::initCumulative(bool checkTotal) const { 147 | if (!cumulative.empty()) 148 | return; 149 | uint32_t sum = 0; 150 | cumulative.push_back(sum); 151 | for (uint32_t freq : frequencies) { 152 | // This arithmetic should not throw an exception, because invariants are being maintained 153 | // elsewhere in the data structure. 
This implementation is just a defensive measure. 154 | sum = checkedAdd(freq, sum); 155 | cumulative.push_back(sum); 156 | } 157 | if (checkTotal && sum != total) 158 | throw std::logic_error("Assertion error"); 159 | } 160 | 161 | 162 | uint32_t SimpleFrequencyTable::checkedAdd(uint32_t x, uint32_t y) { 163 | if (x > UINT32_MAX - y) 164 | throw std::overflow_error("Arithmetic overflow"); 165 | return x + y; 166 | } 167 | -------------------------------------------------------------------------------- /cpp/FrequencyTable.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Reference arithmetic coding 3 | * 4 | * Copyright (c) Project Nayuki 5 | * MIT License. See readme file. 6 | * https://www.nayuki.io/page/reference-arithmetic-coding 7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | #include 13 | 14 | 15 | /* 16 | * A table of symbol frequencies. The table holds data for symbols numbered from 0 17 | * to getSymbolLimit()-1. Each symbol has a frequency, which is a non-negative integer. 18 | * Frequency table objects are primarily used for getting cumulative symbol 19 | * frequencies. These objects can be mutable depending on the implementation. 20 | * The total of all symbol frequencies must not exceed UINT32_MAX. 21 | */ 22 | class FrequencyTable { 23 | 24 | public: virtual ~FrequencyTable() = 0; 25 | 26 | 27 | // Returns the number of symbols in this frequency table, which is a positive number. 28 | public: virtual std::uint32_t getSymbolLimit() const = 0; 29 | 30 | 31 | // Returns the frequency of the given symbol. 32 | public: virtual std::uint32_t get(std::uint32_t symbol) const = 0; 33 | 34 | 35 | // Sets the frequency of the given symbol to the given value. 36 | public: virtual void set(std::uint32_t symbol, std::uint32_t freq) = 0; 37 | 38 | 39 | // Increments the frequency of the given symbol. 
40 | public: virtual void increment(std::uint32_t symbol) = 0; 41 | 42 | 43 | // Returns the total of all symbol frequencies. The returned 44 | // value is always equal to getHigh(getSymbolLimit() - 1). 45 | public: virtual std::uint32_t getTotal() const = 0; 46 | 47 | 48 | // Returns the sum of the frequencies of all the symbols strictly below the given symbol value. 49 | public: virtual std::uint32_t getLow(std::uint32_t symbol) const = 0; 50 | 51 | 52 | // Returns the sum of the frequencies of the given symbol and all the symbols below. 53 | public: virtual std::uint32_t getHigh(std::uint32_t symbol) const = 0; 54 | 55 | }; 56 | 57 | 58 | 59 | class FlatFrequencyTable final : public FrequencyTable { 60 | 61 | /*---- Fields ----*/ 62 | 63 | // Total number of symbols, which is at least 1. 64 | private: std::uint32_t numSymbols; 65 | 66 | 67 | /*---- Constructor ----*/ 68 | 69 | // Constructs a flat frequency table with the given number of symbols. 70 | public: explicit FlatFrequencyTable(std::uint32_t numSyms); 71 | 72 | 73 | /*---- Methods ----*/ 74 | 75 | public: std::uint32_t getSymbolLimit() const override; 76 | 77 | 78 | public: std::uint32_t get(std::uint32_t symbol) const override; 79 | 80 | 81 | public: std::uint32_t getTotal() const override; 82 | 83 | 84 | public: std::uint32_t getLow(std::uint32_t symbol) const override; 85 | 86 | 87 | public: std::uint32_t getHigh(std::uint32_t symbol) const override; 88 | 89 | 90 | public: void set(std::uint32_t symbol, std::uint32_t freq) override; 91 | 92 | 93 | public: void increment(std::uint32_t symbol) override; 94 | 95 | 96 | private: void checkSymbol(std::uint32_t symbol) const; 97 | 98 | }; 99 | 100 | 101 | 102 | /* 103 | * A mutable table of symbol frequencies. The number of symbols cannot be changed 104 | * after construction. The current algorithm for calculating cumulative frequencies 105 | * takes linear time, but there exist faster algorithms such as Fenwick trees. 
106 | */ 107 | class SimpleFrequencyTable final : public FrequencyTable { 108 | 109 | /*---- Fields ----*/ 110 | 111 | // The frequency for each symbol. Its length is at least 1. 112 | private: std::vector frequencies; 113 | 114 | // cumulative[i] is the sum of 'frequencies' from 0 (inclusive) to i (exclusive). 115 | // Initialized lazily. When its length is not zero, the data is valid. 116 | private: mutable std::vector cumulative; 117 | 118 | // Always equal to the sum of 'frequencies'. 119 | private: std::uint32_t total; 120 | 121 | 122 | /*---- Constructors ----*/ 123 | 124 | // Constructs a frequency table from the given array of symbol frequencies. 125 | // There must be at least 1 symbol, and the total must not exceed UINT32_MAX. 126 | public: explicit SimpleFrequencyTable(const std::vector &freqs); 127 | 128 | 129 | // Constructs a frequency table by copying the given frequency table. 130 | public: explicit SimpleFrequencyTable(const FrequencyTable &freqs); 131 | 132 | 133 | /*---- Methods ----*/ 134 | 135 | public: std::uint32_t getSymbolLimit() const override; 136 | 137 | 138 | public: std::uint32_t get(std::uint32_t symbol) const override; 139 | 140 | 141 | public: void set(std::uint32_t symbol, std::uint32_t freq) override; 142 | 143 | 144 | public: void increment(std::uint32_t symbol) override; 145 | 146 | 147 | public: std::uint32_t getTotal() const override; 148 | 149 | 150 | public: std::uint32_t getLow(std::uint32_t symbol) const override; 151 | 152 | 153 | public: std::uint32_t getHigh(std::uint32_t symbol) const override; 154 | 155 | 156 | // Recomputes the array of cumulative symbol frequencies. 157 | private: void initCumulative(bool checkTotal=true) const; 158 | 159 | 160 | // Adds the given integers, or throws an exception if the result cannot be represented as a uint32_t (i.e. overflow). 
161 | private: static std::uint32_t checkedAdd(std::uint32_t x, std::uint32_t y); 162 | 163 | }; 164 | -------------------------------------------------------------------------------- /cpp/Makefile: -------------------------------------------------------------------------------- 1 | # 2 | # Reference arithmetic coding 3 | # 4 | # Copyright (c) Project Nayuki 5 | # MIT License. See readme file. 6 | # https://www.nayuki.io/page/reference-arithmetic-coding 7 | # 8 | 9 | 10 | CXXFLAGS += -std=c++11 -O1 -Wall -Wextra -fsanitize=undefined 11 | 12 | 13 | .SUFFIXES: 14 | 15 | .SECONDARY: 16 | 17 | .DEFAULT_GOAL = all 18 | .PHONY: all clean 19 | 20 | 21 | OBJ = ArithmeticCoder.o BitIoStream.o FrequencyTable.o PpmModel.o 22 | MAINS = AdaptiveArithmeticCompress AdaptiveArithmeticDecompress ArithmeticCompress ArithmeticDecompress PpmCompress PpmDecompress 23 | 24 | all: $(MAINS) 25 | 26 | clean: 27 | rm -f -- $(OBJ) $(MAINS:=.o) $(MAINS) 28 | rm -rf .deps 29 | 30 | %: %.o $(OBJ) 31 | $(CXX) $(CXXFLAGS) -o $@ $^ 32 | 33 | %.o: %.cpp .deps/timestamp 34 | $(CXX) $(CXXFLAGS) -c -o $@ -MMD -MF .deps/$*.d $< 35 | 36 | .deps/timestamp: 37 | mkdir -p .deps 38 | touch .deps/timestamp 39 | 40 | -include .deps/*.d 41 | -------------------------------------------------------------------------------- /cpp/PpmCompress.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Compression application using prediction by partial matching (PPM) with arithmetic coding 3 | * 4 | * Usage: PpmCompress InputFile OutputFile 5 | * Then use the corresponding "PpmDecompress" application to recreate the original input file. 6 | * Note that both the compressor and decompressor need to use the same PPM context modeling logic. 7 | * The PPM algorithm can be thought of as a powerful generalization of adaptive arithmetic coding. 8 | * 9 | * Copyright (c) Project Nayuki 10 | * MIT License. See readme file. 
11 | * https://www.nayuki.io/page/reference-arithmetic-coding 12 | */ 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include "ArithmeticCoder.hpp" 23 | #include "BitIoStream.hpp" 24 | #include "PpmModel.hpp" 25 | 26 | using std::uint32_t; 27 | using std::vector; 28 | 29 | 30 | // Must be at least -1 and match PpmDecompress. Warning: Exponential memory usage at O(257^n). 31 | static constexpr int MODEL_ORDER = 3; 32 | 33 | 34 | static void compress(std::ifstream &in, BitOutputStream &out); 35 | static void encodeSymbol(PpmModel &model, const vector &history, uint32_t symbol, ArithmeticEncoder &enc); 36 | 37 | 38 | int main(int argc, char *argv[]) { 39 | // Handle command line arguments 40 | if (argc != 3) { 41 | std::cerr << "Usage: " << argv[0] << " InputFile OutputFile" << std::endl; 42 | return EXIT_FAILURE; 43 | } 44 | const char *inputFile = argv[1]; 45 | const char *outputFile = argv[2]; 46 | 47 | // Perform file compression 48 | std::ifstream in(inputFile, std::ios::binary); 49 | std::ofstream out(outputFile, std::ios::binary); 50 | BitOutputStream bout(out); 51 | try { 52 | compress(in, bout); 53 | bout.finish(); 54 | return EXIT_SUCCESS; 55 | } catch (const char *msg) { 56 | std::cerr << msg << std::endl; 57 | return EXIT_FAILURE; 58 | } 59 | } 60 | 61 | 62 | static void compress(std::ifstream &in, BitOutputStream &out) { 63 | // Set up encoder and model. In this PPM model, symbol 256 represents EOF; 64 | // its frequency is 1 in the order -1 context but its frequency 65 | // is 0 in all other contexts (which have non-negative order). 
66 | ArithmeticEncoder enc(32, out); 67 | PpmModel model(MODEL_ORDER, 257, 256); 68 | vector history; 69 | 70 | while (true) { 71 | // Read and encode one byte 72 | int symbol = in.get(); 73 | if (symbol == std::char_traits::eof()) 74 | break; 75 | if (!(0 <= symbol && symbol <= 255)) 76 | throw std::logic_error("Assertion error"); 77 | uint32_t sym = static_cast(symbol); 78 | encodeSymbol(model, history, sym, enc); 79 | model.incrementContexts(history, sym); 80 | 81 | if (model.modelOrder >= 1) { 82 | // Prepend current symbol, dropping oldest symbol if necessary 83 | if (history.size() >= static_cast(model.modelOrder)) 84 | history.erase(history.end() - 1); 85 | history.insert(history.begin(), sym); 86 | } 87 | } 88 | 89 | encodeSymbol(model, history, 256, enc); // EOF 90 | enc.finish(); // Flush remaining code bits 91 | } 92 | 93 | 94 | static void encodeSymbol(PpmModel &model, const vector &history, uint32_t symbol, ArithmeticEncoder &enc) { 95 | // Try to use highest order context that exists based on the history suffix, such 96 | // that the next symbol has non-zero frequency. When symbol 256 is produced at a context 97 | // at any non-negative order, it means "escape to the next lower order with non-empty 98 | // context". When symbol 256 is produced at the order -1 context, it means "EOF". 
99 | for (int order = static_cast(history.size()); order >= 0; order--) { 100 | PpmModel::Context *ctx = model.rootContext.get(); 101 | for (int i = 0; i < order; i++) { 102 | if (ctx->subcontexts.empty()) 103 | throw std::logic_error("Assertion error"); 104 | ctx = ctx->subcontexts.at(history.at(i)).get(); 105 | if (ctx == nullptr) 106 | goto outerEnd; 107 | } 108 | if (symbol != 256 && ctx->frequencies.get(symbol) > 0) { 109 | enc.write(ctx->frequencies, symbol); 110 | return; 111 | } 112 | // Else write context escape symbol and continue decrementing the order 113 | enc.write(ctx->frequencies, 256); 114 | outerEnd:; 115 | } 116 | // Logic for order = -1 117 | enc.write(model.orderMinus1Freqs, symbol); 118 | } 119 | -------------------------------------------------------------------------------- /cpp/PpmDecompress.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Decompression application using prediction by partial matching (PPM) with arithmetic coding 3 | * 4 | * Usage: PpmDecompress InputFile OutputFile 5 | * This decompresses files generated by the "PpmCompress" application. 6 | * 7 | * Copyright (c) Project Nayuki 8 | * MIT License. See readme file. 9 | * https://www.nayuki.io/page/reference-arithmetic-coding 10 | */ 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include "ArithmeticCoder.hpp" 21 | #include "BitIoStream.hpp" 22 | #include "PpmModel.hpp" 23 | 24 | using std::uint32_t; 25 | using std::vector; 26 | 27 | 28 | // Must be at least -1 and match PpmCompress. Warning: Exponential memory usage at O(257^n).
29 | static constexpr int MODEL_ORDER = 3; 30 | 31 | 32 | static void decompress(BitInputStream &in, std::ostream &out); 33 | static uint32_t decodeSymbol(ArithmeticDecoder &dec, PpmModel &model, const vector &history); 34 | 35 | 36 | int main(int argc, char *argv[]) { 37 | // Handle command line arguments 38 | if (argc != 3) { 39 | std::cerr << "Usage: " << argv[0] << " InputFile OutputFile" << std::endl; 40 | return EXIT_FAILURE; 41 | } 42 | const char *inputFile = argv[1]; 43 | const char *outputFile = argv[2]; 44 | 45 | // Perform file decompression 46 | std::ifstream in(inputFile, std::ios::binary); 47 | std::ofstream out(outputFile, std::ios::binary); 48 | BitInputStream bin(in); 49 | try { 50 | decompress(bin, out); 51 | return EXIT_SUCCESS; 52 | } catch (const char *msg) { 53 | std::cerr << msg << std::endl; 54 | return EXIT_FAILURE; 55 | } 56 | } 57 | 58 | 59 | static void decompress(BitInputStream &in, std::ostream &out) { 60 | // Set up decoder and model. In this PPM model, symbol 256 represents EOF; 61 | // its frequency is 1 in the order -1 context but its frequency 62 | // is 0 in all other contexts (which have non-negative order). 
63 | ArithmeticDecoder dec(32, in); 64 | PpmModel model(MODEL_ORDER, 257, 256); 65 | vector history; 66 | 67 | while (true) { 68 | // Decode and write one byte 69 | uint32_t symbol = decodeSymbol(dec, model, history); 70 | if (symbol == 256) // EOF symbol 71 | break; 72 | int b = static_cast(symbol); 73 | if (std::numeric_limits::is_signed) 74 | b -= (b >> 7) << 8; 75 | out.put(static_cast(b)); 76 | model.incrementContexts(history, symbol); 77 | 78 | if (model.modelOrder >= 1) { 79 | // Prepend current symbol, dropping oldest symbol if necessary 80 | if (history.size() >= static_cast(model.modelOrder)) 81 | history.erase(history.end() - 1); 82 | history.insert(history.begin(), symbol); 83 | } 84 | } 85 | } 86 | 87 | 88 | static uint32_t decodeSymbol(ArithmeticDecoder &dec, PpmModel &model, const vector &history) { 89 | // Try to use highest order context that exists based on the history suffix. When symbol 256 90 | // is consumed at a context at any non-negative order, it means "escape to the next lower order 91 | // with non-empty context". When symbol 256 is consumed at the order -1 context, it means "EOF". 
92 | for (int order = static_cast(history.size()); order >= 0; order--) { 93 | PpmModel::Context *ctx = model.rootContext.get(); 94 | for (int i = 0; i < order; i++) { 95 | if (ctx->subcontexts.empty()) 96 | throw std::logic_error("Assertion error"); 97 | ctx = ctx->subcontexts.at(history.at(i)).get(); 98 | if (ctx == nullptr) 99 | goto outerEnd; 100 | } 101 | { 102 | uint32_t symbol = dec.read(ctx->frequencies); 103 | if (symbol < 256) 104 | return symbol; 105 | } 106 | // Else we read the context escape symbol, so continue decrementing the order 107 | outerEnd:; 108 | } 109 | // Logic for order = -1 110 | return dec.read(model.orderMinus1Freqs); 111 | } 112 | -------------------------------------------------------------------------------- /cpp/PpmModel.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Reference arithmetic coding 3 | * 4 | * Copyright (c) Project Nayuki 5 | * MIT License. See readme file. 6 | * https://www.nayuki.io/page/reference-arithmetic-coding 7 | */ 8 | 9 | #include 10 | #include 11 | #include 12 | #include "PpmModel.hpp" 13 | 14 | using std::uint32_t; 15 | using std::vector; 16 | 17 | 18 | PpmModel::Context::Context(uint32_t symbols, bool hasSubctx) : 19 | frequencies(vector(symbols, 0)) { 20 | if (hasSubctx) { 21 | for (uint32_t i = 0; i < symbols; i++) 22 | subcontexts.push_back(std::unique_ptr(nullptr)); 23 | } 24 | } 25 | 26 | 27 | PpmModel::PpmModel(int order, uint32_t symLimit, uint32_t escapeSym) : 28 | modelOrder(order), 29 | symbolLimit(symLimit), 30 | escapeSymbol(escapeSym), 31 | rootContext(std::unique_ptr(nullptr)), 32 | orderMinus1Freqs(FlatFrequencyTable(symbolLimit)) { 33 | if (!(order >= -1 && escapeSym < symLimit)) 34 | throw std::domain_error("Illegal argument"); 35 | if (order >= 0) { 36 | rootContext.reset(new Context(symbolLimit, order >= 1)); 37 | rootContext->frequencies.increment(escapeSymbol); 38 | } 39 | } 40 | 41 | 42 | void PpmModel::incrementContexts(const vector 
&history, uint32_t symbol) { 43 | if (modelOrder == -1) 44 | return; 45 | if (!(history.size() <= static_cast(modelOrder) && symbol < symbolLimit)) 46 | throw std::invalid_argument("Illegal argument"); 47 | 48 | Context *ctx = rootContext.get(); 49 | ctx->frequencies.increment(symbol); 50 | std::size_t i = 0; 51 | for (uint32_t sym : history) { 52 | vector > &subctxs = ctx->subcontexts; 53 | if (subctxs.empty()) 54 | throw std::logic_error("Assertion error"); 55 | 56 | std::unique_ptr &subctx = subctxs.at(sym); 57 | if (subctx.get() == nullptr) { 58 | subctx.reset(new Context(symbolLimit, i + 1 < static_cast(modelOrder))); 59 | subctx->frequencies.increment(escapeSymbol); 60 | } 61 | ctx = subctx.get(); 62 | ctx->frequencies.increment(symbol); 63 | i++; 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /cpp/PpmModel.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Reference arithmetic coding 3 | * 4 | * Copyright (c) Project Nayuki 5 | * MIT License. See readme file. 
6 | * https://www.nayuki.io/page/reference-arithmetic-coding 7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | #include 13 | #include 14 | #include "FrequencyTable.hpp" 15 | 16 | 17 | class PpmModel final { 18 | 19 | /*---- Helper structure ----*/ 20 | 21 | public: class Context final { 22 | 23 | public: SimpleFrequencyTable frequencies; 24 | 25 | public: std::vector > subcontexts; 26 | 27 | 28 | public: explicit Context(std::uint32_t symbols, bool hasSubctx); 29 | 30 | }; 31 | 32 | 33 | 34 | /*---- Fields ----*/ 35 | 36 | public: int modelOrder; 37 | 38 | private: std::uint32_t symbolLimit; 39 | private: std::uint32_t escapeSymbol; 40 | 41 | public: std::unique_ptr rootContext; 42 | public: SimpleFrequencyTable orderMinus1Freqs; 43 | 44 | 45 | /*---- Constructor ----*/ 46 | 47 | public: explicit PpmModel(int order, std::uint32_t symLimit, std::uint32_t escapeSym); 48 | 49 | 50 | /*---- Methods ----*/ 51 | 52 | public: void incrementContexts(const std::vector &history, std::uint32_t symbol); 53 | 54 | 55 | private: static std::vector makeEmpty(std::uint32_t len); 56 | 57 | }; 58 | -------------------------------------------------------------------------------- /java/src/AdaptiveArithmeticCompress.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Reference arithmetic coding 3 | * 4 | * Copyright (c) Project Nayuki 5 | * MIT License. See readme file. 6 | * https://www.nayuki.io/page/reference-arithmetic-coding 7 | */ 8 | 9 | import java.io.BufferedInputStream; 10 | import java.io.BufferedOutputStream; 11 | import java.io.File; 12 | import java.io.FileInputStream; 13 | import java.io.FileOutputStream; 14 | import java.io.IOException; 15 | import java.io.InputStream; 16 | 17 | 18 | /** 19 | * Compression application using adaptive arithmetic coding. 20 | *

Usage: java AdaptiveArithmeticCompress InputFile OutputFile

21 | *

Then use the corresponding "AdaptiveArithmeticDecompress" application to recreate the original input file.

22 | *

Note that the application starts with a flat frequency table of 257 symbols (all set to a frequency of 1), 23 | * and updates it after each byte encoded. The corresponding decompressor program also starts with a flat 24 | * frequency table and updates it after each byte decoded. It is by design that the compressor and 25 | * decompressor have synchronized states, so that the data can be decompressed properly.

26 | */ 27 | public class AdaptiveArithmeticCompress { 28 | 29 | public static void main(String[] args) throws IOException { 30 | // Handle command line arguments 31 | if (args.length != 2) { 32 | System.err.println("Usage: java AdaptiveArithmeticCompress InputFile OutputFile"); 33 | System.exit(1); 34 | return; 35 | } 36 | File inputFile = new File(args[0]); 37 | File outputFile = new File(args[1]); 38 | 39 | // Perform file compression 40 | try (InputStream in = new BufferedInputStream(new FileInputStream(inputFile)); 41 | BitOutputStream out = new BitOutputStream(new BufferedOutputStream(new FileOutputStream(outputFile)))) { 42 | compress(in, out); 43 | } 44 | } 45 | 46 | 47 | // To allow unit testing, this method is package-private instead of private. 48 | static void compress(InputStream in, BitOutputStream out) throws IOException { 49 | FlatFrequencyTable initFreqs = new FlatFrequencyTable(257); 50 | FrequencyTable freqs = new SimpleFrequencyTable(initFreqs); 51 | ArithmeticEncoder enc = new ArithmeticEncoder(32, out); 52 | while (true) { 53 | // Read and encode one byte 54 | int symbol = in.read(); 55 | if (symbol == -1) 56 | break; 57 | enc.write(freqs, symbol); 58 | freqs.increment(symbol); 59 | } 60 | enc.write(freqs, 256); // EOF 61 | enc.finish(); // Flush remaining code bits 62 | } 63 | 64 | } 65 | -------------------------------------------------------------------------------- /java/src/AdaptiveArithmeticDecompress.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Reference arithmetic coding 3 | * 4 | * Copyright (c) Project Nayuki 5 | * MIT License. See readme file. 
6 | * https://www.nayuki.io/page/reference-arithmetic-coding 7 | */ 8 | 9 | import java.io.BufferedInputStream; 10 | import java.io.BufferedOutputStream; 11 | import java.io.File; 12 | import java.io.FileInputStream; 13 | import java.io.FileOutputStream; 14 | import java.io.IOException; 15 | import java.io.OutputStream; 16 | 17 | 18 | /** 19 | * Decompression application using adaptive arithmetic coding. 20 | *

Usage: java AdaptiveArithmeticDecompress InputFile OutputFile

21 | *

This decompresses files generated by the "AdaptiveArithmeticCompress" application.

22 | */ 23 | public class AdaptiveArithmeticDecompress { 24 | 25 | public static void main(String[] args) throws IOException { 26 | // Handle command line arguments 27 | if (args.length != 2) { 28 | System.err.println("Usage: java AdaptiveArithmeticDecompress InputFile OutputFile"); 29 | System.exit(1); 30 | return; 31 | } 32 | File inputFile = new File(args[0]); 33 | File outputFile = new File(args[1]); 34 | 35 | // Perform file decompression 36 | try (BitInputStream in = new BitInputStream(new BufferedInputStream(new FileInputStream(inputFile))); 37 | OutputStream out = new BufferedOutputStream(new FileOutputStream(outputFile))) { 38 | decompress(in, out); 39 | } 40 | } 41 | 42 | 43 | // To allow unit testing, this method is package-private instead of private. 44 | static void decompress(BitInputStream in, OutputStream out) throws IOException { 45 | FlatFrequencyTable initFreqs = new FlatFrequencyTable(257); 46 | FrequencyTable freqs = new SimpleFrequencyTable(initFreqs); 47 | ArithmeticDecoder dec = new ArithmeticDecoder(32, in); 48 | while (true) { 49 | // Decode and write one byte 50 | int symbol = dec.read(freqs); 51 | if (symbol == 256) // EOF symbol 52 | break; 53 | out.write(symbol); 54 | freqs.increment(symbol); 55 | } 56 | } 57 | 58 | } 59 | -------------------------------------------------------------------------------- /java/src/ArithmeticCoderBase.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Reference arithmetic coding 3 | * 4 | * Copyright (c) Project Nayuki 5 | * MIT License. See readme file. 6 | * https://www.nayuki.io/page/reference-arithmetic-coding 7 | */ 8 | 9 | import java.io.IOException; 10 | 11 | 12 | /** 13 | * Provides the state and behaviors that arithmetic coding encoders and decoders share. 
14 | * @see ArithmeticEncoder 15 | * @see ArithmeticDecoder 16 | */ 17 | public abstract class ArithmeticCoderBase { 18 | 19 | /*---- Configuration fields ----*/ 20 | 21 | /** 22 | * Number of bits for the 'low' and 'high' state variables. Must be in the range [1, 62]. 23 | *
    24 | *
  • For state sizes less than the midpoint of around 32, larger values are generally better - 25 | * they allow a larger maximum frequency total (maximumTotal), and they reduce the approximation 26 | * error inherent in adapting fractions to integers; both effects reduce the data encoding loss 27 | * and asymptotically approach the efficiency of arithmetic coding using exact fractions.
  • 28 | *
  • But for state sizes greater than the midpoint, because intermediate computations are limited 29 | * to the long integer type's 63-bit unsigned precision, larger state sizes will decrease the 30 | * maximum frequency total, which might constrain the user-supplied probability model.
  • 31 | *
  • Therefore numStateBits=32 is recommended as the most versatile setting 32 | * because it maximizes maximumTotal (which ends up being slightly over 2^30).
  • 33 | *
  • Note that numStateBits=62 is legal but useless because it implies maximumTotal=1, 34 | * which means a frequency table can only support one symbol with non-zero frequency.
  • 35 | *
36 | */ 37 | protected final int numStateBits; 38 | 39 | /** Maximum range (high+1-low) during coding (trivial), which is 2^numStateBits = 1000...000. */ 40 | protected final long fullRange; 41 | 42 | /** The top bit at width numStateBits, which is 0100...000. */ 43 | protected final long halfRange; 44 | 45 | /** The second highest bit at width numStateBits, which is 0010...000. This is zero when numStateBits=1. */ 46 | protected final long quarterRange; 47 | 48 | /** Minimum range (high+1-low) during coding (non-trivial), which is 0010...010. */ 49 | protected final long minimumRange; 50 | 51 | /** Maximum allowed total from a frequency table at all times during coding. */ 52 | protected final long maximumTotal; 53 | 54 | /** Bit mask of numStateBits ones, which is 0111...111. */ 55 | protected final long stateMask; 56 | 57 | 58 | 59 | /*---- State fields ----*/ 60 | 61 | /** 62 | * Low end of this arithmetic coder's current range. Conceptually has an infinite number of trailing 0s. 63 | */ 64 | protected long low; 65 | 66 | /** 67 | * High end of this arithmetic coder's current range. Conceptually has an infinite number of trailing 1s. 68 | */ 69 | protected long high; 70 | 71 | 72 | 73 | /*---- Constructor ----*/ 74 | 75 | /** 76 | * Constructs an arithmetic coder, which initializes the code range. 
77 | * @param numBits the number of bits for the arithmetic coding range 78 | * @throws IllegalArgumentException if numBits is outside the range [1, 62] 79 | */ 80 | public ArithmeticCoderBase(int numBits) { 81 | if (!(1 <= numBits && numBits <= 62)) 82 | throw new IllegalArgumentException("State size out of range"); 83 | numStateBits = numBits; 84 | fullRange = 1L << numStateBits; 85 | halfRange = fullRange >>> 1; // Non-zero 86 | quarterRange = halfRange >>> 1; // Can be zero 87 | minimumRange = quarterRange + 2; // At least 2 88 | maximumTotal = Math.min(Long.MAX_VALUE / fullRange, minimumRange); 89 | stateMask = fullRange - 1; 90 | 91 | low = 0; 92 | high = stateMask; 93 | } 94 | 95 | 96 | 97 | /*---- Methods ----*/ 98 | 99 | /** 100 | * Updates the code range (low and high) of this arithmetic coder as a result 101 | * of processing the specified symbol with the specified frequency table. 102 | *

Invariants that are true before and after encoding/decoding each symbol 103 | * (letting fullRange = 2^numStateBits):

104 | *
    105 | *
  • 0 ≤ low ≤ code ≤ high < fullRange. ('code' exists only in the decoder.) 106 | * Therefore these variables are unsigned integers of numStateBits bits.
  • 107 | *
  • low < 1/2 × fullRange ≤ high. 108 | * In other words, they are in different halves of the full range.
  • 109 | *
  • (low < 1/4 × fullRange) || (high ≥ 3/4 × fullRange). 110 | * In other words, they are not both in the middle two quarters.
  • 111 | *
  • Let range = high − low + 1, then fullRange/4 < minimumRange ≤ range ≤ 112 | * fullRange. These invariants for 'range' essentially dictate the maximum total that the 113 | * incoming frequency table can have, such that intermediate calculations don't overflow.
  • 114 | *
115 | * @param freqs the frequency table to use 116 | * @param symbol the symbol that was processed 117 | * @throws IllegalArgumentException if the symbol has zero frequency or the frequency table's total is too large 118 | */ 119 | protected void update(CheckedFrequencyTable freqs, int symbol) throws IOException { 120 | // State check 121 | if (low >= high || (low & stateMask) != low || (high & stateMask) != high) 122 | throw new AssertionError("Low or high out of range"); 123 | long range = high - low + 1; 124 | if (!(minimumRange <= range && range <= fullRange)) 125 | throw new AssertionError("Range out of range"); 126 | 127 | // Frequency table values check 128 | long total = freqs.getTotal(); 129 | long symLow = freqs.getLow(symbol); 130 | long symHigh = freqs.getHigh(symbol); 131 | if (symLow == symHigh) 132 | throw new IllegalArgumentException("Symbol has zero frequency"); 133 | if (total > maximumTotal) 134 | throw new IllegalArgumentException("Cannot code symbol because total is too large"); 135 | 136 | // Update range 137 | long newLow = low + symLow * range / total; 138 | long newHigh = low + symHigh * range / total - 1; 139 | low = newLow; 140 | high = newHigh; 141 | 142 | // While low and high have the same top bit value, shift them out 143 | while (((low ^ high) & halfRange) == 0) { 144 | shift(); 145 | low = ((low << 1) & stateMask); 146 | high = ((high << 1) & stateMask) | 1; 147 | } 148 | // Now low's top bit must be 0 and high's top bit must be 1 149 | 150 | // While low's top two bits are 01 and high's are 10, delete the second highest bit of both 151 | while ((low & ~high & quarterRange) != 0) { 152 | underflow(); 153 | low = (low << 1) ^ halfRange; 154 | high = ((high ^ halfRange) << 1) | halfRange | 1; 155 | } 156 | } 157 | 158 | 159 | /** 160 | * Called to handle the situation when the top bit of {@code low} and {@code high} are equal. 
161 | * @throws IOException if an I/O exception occurred 162 | */ 163 | protected abstract void shift() throws IOException; 164 | 165 | 166 | /** 167 | * Called to handle the situation when low=01(...) and high=10(...). 168 | * @throws IOException if an I/O exception occurred 169 | */ 170 | protected abstract void underflow() throws IOException; 171 | 172 | } 173 | -------------------------------------------------------------------------------- /java/src/ArithmeticCompress.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Reference arithmetic coding 3 | * 4 | * Copyright (c) Project Nayuki 5 | * MIT License. See readme file. 6 | * https://www.nayuki.io/page/reference-arithmetic-coding 7 | */ 8 | 9 | import java.io.BufferedInputStream; 10 | import java.io.BufferedOutputStream; 11 | import java.io.File; 12 | import java.io.FileInputStream; 13 | import java.io.FileOutputStream; 14 | import java.io.IOException; 15 | import java.io.InputStream; 16 | 17 | 18 | /** 19 | * Compression application using static arithmetic coding. 20 | *

Usage: java ArithmeticCompress InputFile OutputFile

21 | *

Then use the corresponding "ArithmeticDecompress" application to recreate the original input file.

22 | *

Note that the application uses an alphabet of 257 symbols - 256 symbols for the byte 23 | * values and 1 symbol for the EOF marker. The compressed file format starts with a list 24 | * of 256 symbol frequencies, and then followed by the arithmetic-coded data.

25 | */ 26 | public class ArithmeticCompress { 27 | 28 | public static void main(String[] args) throws IOException { 29 | // Handle command line arguments 30 | if (args.length != 2) { 31 | System.err.println("Usage: java ArithmeticCompress InputFile OutputFile"); 32 | System.exit(1); 33 | return; 34 | } 35 | File inputFile = new File(args[0]); 36 | File outputFile = new File(args[1]); 37 | 38 | // Read input file once to compute symbol frequencies 39 | FrequencyTable freqs = getFrequencies(inputFile); 40 | freqs.increment(256); // EOF symbol gets a frequency of 1 41 | 42 | // Read input file again, compress with arithmetic coding, and write output file 43 | try (InputStream in = new BufferedInputStream(new FileInputStream(inputFile)); 44 | BitOutputStream out = new BitOutputStream(new BufferedOutputStream(new FileOutputStream(outputFile)))) { 45 | writeFrequencies(out, freqs); 46 | compress(freqs, in, out); 47 | } 48 | } 49 | 50 | 51 | // Returns a frequency table based on the bytes in the given file. 52 | // Also contains an extra entry for symbol 256, whose frequency is set to 0. 53 | private static FrequencyTable getFrequencies(File file) throws IOException { 54 | FrequencyTable freqs = new SimpleFrequencyTable(new int[257]); 55 | try (InputStream input = new BufferedInputStream(new FileInputStream(file))) { 56 | while (true) { 57 | int b = input.read(); 58 | if (b == -1) 59 | break; 60 | freqs.increment(b); 61 | } 62 | } 63 | return freqs; 64 | } 65 | 66 | 67 | // To allow unit testing, this method is package-private instead of private. 68 | static void writeFrequencies(BitOutputStream out, FrequencyTable freqs) throws IOException { 69 | for (int i = 0; i < 256; i++) 70 | writeInt(out, 32, freqs.get(i)); 71 | } 72 | 73 | 74 | // To allow unit testing, this method is package-private instead of private. 
75 | static void compress(FrequencyTable freqs, InputStream in, BitOutputStream out) throws IOException { 76 | ArithmeticEncoder enc = new ArithmeticEncoder(32, out); 77 | while (true) { 78 | int symbol = in.read(); 79 | if (symbol == -1) 80 | break; 81 | enc.write(freqs, symbol); 82 | } 83 | enc.write(freqs, 256); // EOF 84 | enc.finish(); // Flush remaining code bits 85 | } 86 | 87 | 88 | // Writes an unsigned integer of the given bit width to the given stream. 89 | private static void writeInt(BitOutputStream out, int numBits, int value) throws IOException { 90 | if (numBits < 0 || numBits > 32) 91 | throw new IllegalArgumentException(); 92 | 93 | for (int i = numBits - 1; i >= 0; i--) 94 | out.write((value >>> i) & 1); // Big endian 95 | } 96 | 97 | } 98 | -------------------------------------------------------------------------------- /java/src/ArithmeticDecoder.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Reference arithmetic coding 3 | * 4 | * Copyright (c) Project Nayuki 5 | * MIT License. See readme file. 6 | * https://www.nayuki.io/page/reference-arithmetic-coding 7 | */ 8 | 9 | import java.io.IOException; 10 | import java.util.Objects; 11 | 12 | 13 | /** 14 | * Reads from an arithmetic-coded bit stream and decodes symbols. Not thread-safe. 15 | * @see ArithmeticEncoder 16 | */ 17 | public final class ArithmeticDecoder extends ArithmeticCoderBase { 18 | 19 | /*---- Fields ----*/ 20 | 21 | // The underlying bit input stream (not null). 22 | private BitInputStream input; 23 | 24 | // The current raw code bits being buffered, which is always in the range [low, high]. 25 | private long code; 26 | 27 | 28 | 29 | /*---- Constructor ----*/ 30 | 31 | /** 32 | * Constructs an arithmetic coding decoder based on the 33 | * specified bit input stream, and fills the code bits. 
34 | * @param numBits the number of bits for the arithmetic coding range 35 | * @param in the bit input stream to read from 36 | * @throws NullPointerException if the input steam is {@code null} 37 | * @throws IllegalArgumentException if stateSize is outside the range [1, 62] 38 | * @throws IOException if an I/O exception occurred 39 | */ 40 | public ArithmeticDecoder(int numBits, BitInputStream in) throws IOException { 41 | super(numBits); 42 | input = Objects.requireNonNull(in); 43 | code = 0; 44 | for (int i = 0; i < numStateBits; i++) 45 | code = code << 1 | readCodeBit(); 46 | } 47 | 48 | 49 | 50 | /*---- Methods ----*/ 51 | 52 | /** 53 | * Decodes the next symbol based on the specified frequency table and returns it. 54 | * Also updates this arithmetic coder's state and may read in some bits. 55 | * @param freqs the frequency table to use 56 | * @return the next symbol 57 | * @throws NullPointerException if the frequency table is {@code null} 58 | * @throws IOException if an I/O exception occurred 59 | */ 60 | public int read(FrequencyTable freqs) throws IOException { 61 | return read(new CheckedFrequencyTable(freqs)); 62 | } 63 | 64 | 65 | /** 66 | * Decodes the next symbol based on the specified frequency table and returns it. 67 | * Also updates this arithmetic coder's state and may read in some bits. 
68 | * @param freqs the frequency table to use 69 | * @return the next symbol 70 | * @throws NullPointerException if the frequency table is {@code null} 71 | * @throws IllegalArgumentException if the frequency table's total is too large 72 | * @throws IOException if an I/O exception occurred 73 | */ 74 | public int read(CheckedFrequencyTable freqs) throws IOException { 75 | // Translate from coding range scale to frequency table scale 76 | long total = freqs.getTotal(); 77 | if (total > maximumTotal) 78 | throw new IllegalArgumentException("Cannot decode symbol because total is too large"); 79 | long range = high - low + 1; 80 | long offset = code - low; 81 | long value = ((offset + 1) * total - 1) / range; 82 | if (value * range / total > offset) 83 | throw new AssertionError(); 84 | if (!(0 <= value && value < total)) 85 | throw new AssertionError(); 86 | 87 | // A kind of binary search. Find highest symbol such that freqs.getLow(symbol) <= value. 88 | int start = 0; 89 | int end = freqs.getSymbolLimit(); 90 | while (end - start > 1) { 91 | int middle = (start + end) >>> 1; 92 | if (freqs.getLow(middle) > value) 93 | end = middle; 94 | else 95 | start = middle; 96 | } 97 | if (start + 1 != end) 98 | throw new AssertionError(); 99 | 100 | int symbol = start; 101 | if (!(freqs.getLow(symbol) * range / total <= offset && offset < freqs.getHigh(symbol) * range / total)) 102 | throw new AssertionError(); 103 | update(freqs, symbol); 104 | if (!(low <= code && code <= high)) 105 | throw new AssertionError("Code out of range"); 106 | return symbol; 107 | } 108 | 109 | 110 | protected void shift() throws IOException { 111 | code = ((code << 1) & stateMask) | readCodeBit(); 112 | } 113 | 114 | 115 | protected void underflow() throws IOException { 116 | code = (code & halfRange) | ((code << 1) & (stateMask >>> 1)) | readCodeBit(); 117 | } 118 | 119 | 120 | // Returns the next bit (0 or 1) from the input stream. 
The end 121 | // of stream is treated as an infinite number of trailing zeros. 122 | private int readCodeBit() throws IOException { 123 | int temp = input.read(); 124 | if (temp == -1) 125 | temp = 0; 126 | return temp; 127 | } 128 | 129 | } 130 | -------------------------------------------------------------------------------- /java/src/ArithmeticDecompress.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Reference arithmetic coding 3 | * 4 | * Copyright (c) Project Nayuki 5 | * MIT License. See readme file. 6 | * https://www.nayuki.io/page/reference-arithmetic-coding 7 | */ 8 | 9 | import java.io.BufferedInputStream; 10 | import java.io.BufferedOutputStream; 11 | import java.io.File; 12 | import java.io.FileInputStream; 13 | import java.io.FileOutputStream; 14 | import java.io.IOException; 15 | import java.io.OutputStream; 16 | 17 | 18 | /** 19 | * Decompression application using static arithmetic coding. 20 | *

Usage: java ArithmeticDecompress InputFile OutputFile

21 | *

This decompresses files generated by the "ArithmeticCompress" application.

22 | */ 23 | public class ArithmeticDecompress { 24 | 25 | public static void main(String[] args) throws IOException { 26 | // Handle command line arguments 27 | if (args.length != 2) { 28 | System.err.println("Usage: java ArithmeticDecompress InputFile OutputFile"); 29 | System.exit(1); 30 | return; 31 | } 32 | File inputFile = new File(args[0]); 33 | File outputFile = new File(args[1]); 34 | 35 | // Perform file decompression 36 | try (BitInputStream in = new BitInputStream(new BufferedInputStream(new FileInputStream(inputFile))); 37 | OutputStream out = new BufferedOutputStream(new FileOutputStream(outputFile))) { 38 | FrequencyTable freqs = readFrequencies(in); 39 | decompress(freqs, in, out); 40 | } 41 | } 42 | 43 | 44 | // To allow unit testing, this method is package-private instead of private. 45 | static FrequencyTable readFrequencies(BitInputStream in) throws IOException { 46 | int[] freqs = new int[257]; 47 | for (int i = 0; i < 256; i++) 48 | freqs[i] = readInt(in, 32); 49 | freqs[256] = 1; // EOF symbol 50 | return new SimpleFrequencyTable(freqs); 51 | } 52 | 53 | 54 | // To allow unit testing, this method is package-private instead of private. 55 | static void decompress(FrequencyTable freqs, BitInputStream in, OutputStream out) throws IOException { 56 | ArithmeticDecoder dec = new ArithmeticDecoder(32, in); 57 | while (true) { 58 | int symbol = dec.read(freqs); 59 | if (symbol == 256) // EOF symbol 60 | break; 61 | out.write(symbol); 62 | } 63 | } 64 | 65 | 66 | // Reads an unsigned integer of the given bit width from the given stream. 
67 | private static int readInt(BitInputStream in, int numBits) throws IOException { 68 | if (!(0 <= numBits && numBits <= 32)) 69 | throw new IllegalArgumentException(); 70 | 71 | int result = 0; 72 | for (int i = 0; i < numBits; i++) 73 | result = (result << 1) | in.readNoEof(); // Big endian 74 | return result; 75 | } 76 | 77 | } 78 | -------------------------------------------------------------------------------- /java/src/ArithmeticEncoder.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Reference arithmetic coding 3 | * 4 | * Copyright (c) Project Nayuki 5 | * MIT License. See readme file. 6 | * https://www.nayuki.io/page/reference-arithmetic-coding 7 | */ 8 | 9 | import java.io.IOException; 10 | import java.util.Objects; 11 | 12 | 13 | /** 14 | * Encodes symbols and writes to an arithmetic-coded bit stream. Not thread-safe. 15 | * @see ArithmeticDecoder 16 | */ 17 | public final class ArithmeticEncoder extends ArithmeticCoderBase { 18 | 19 | /*---- Fields ----*/ 20 | 21 | // The underlying bit output stream (not null). 22 | private BitOutputStream output; 23 | 24 | // Number of saved underflow bits. This value can grow without bound, 25 | // so a truly correct implementation would use a BigInteger. 26 | private int numUnderflow; 27 | 28 | 29 | 30 | /*---- Constructor ----*/ 31 | 32 | /** 33 | * Constructs an arithmetic coding encoder based on the specified bit output stream. 
34 | * @param numBits the number of bits for the arithmetic coding range 35 | * @param out the bit output stream to write to 36 | * @throws NullPointerException if the output stream is {@code null} 37 | * @throws IllegalArgumentException if stateSize is outside the range [1, 62] 38 | */ 39 | public ArithmeticEncoder(int numBits, BitOutputStream out) { 40 | super(numBits); 41 | output = Objects.requireNonNull(out); 42 | numUnderflow = 0; 43 | } 44 | 45 | 46 | 47 | /*---- Methods ----*/ 48 | 49 | /** 50 | * Encodes the specified symbol based on the specified frequency table. 51 | * This updates this arithmetic coder's state and may write out some bits. 52 | * @param freqs the frequency table to use 53 | * @param symbol the symbol to encode 54 | * @throws NullPointerException if the frequency table is {@code null} 55 | * @throws IllegalArgumentException if the symbol has zero frequency 56 | * or the frequency table's total is too large 57 | * @throws IOException if an I/O exception occurred 58 | */ 59 | public void write(FrequencyTable freqs, int symbol) throws IOException { 60 | write(new CheckedFrequencyTable(freqs), symbol); 61 | } 62 | 63 | 64 | /** 65 | * Encodes the specified symbol based on the specified frequency table. 66 | * Also updates this arithmetic coder's state and may write out some bits. 67 | * @param freqs the frequency table to use 68 | * @param symbol the symbol to encode 69 | * @throws NullPointerException if the frequency table is {@code null} 70 | * @throws IllegalArgumentException if the symbol has zero frequency 71 | * or the frequency table's total is too large 72 | * @throws IOException if an I/O exception occurred 73 | */ 74 | public void write(CheckedFrequencyTable freqs, int symbol) throws IOException { 75 | update(freqs, symbol); 76 | } 77 | 78 | 79 | /** 80 | * Terminates the arithmetic coding by flushing any buffered bits, so that the output can be decoded properly. 
81 | * It is important that this method must be called at the end of the each encoding process. 82 | *

Note that this method merely writes data to the underlying output stream but does not close it.

83 | * @throws IOException if an I/O exception occurred 84 | */ 85 | public void finish() throws IOException { 86 | output.write(1); 87 | } 88 | 89 | 90 | protected void shift() throws IOException { 91 | int bit = (int)(low >>> (numStateBits - 1)); 92 | output.write(bit); 93 | 94 | // Write out the saved underflow bits 95 | for (; numUnderflow > 0; numUnderflow--) 96 | output.write(bit ^ 1); 97 | } 98 | 99 | 100 | protected void underflow() { 101 | if (numUnderflow == Integer.MAX_VALUE) 102 | throw new ArithmeticException("Maximum underflow reached"); 103 | numUnderflow++; 104 | } 105 | 106 | } 107 | -------------------------------------------------------------------------------- /java/src/BitInputStream.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Reference arithmetic coding 3 | * 4 | * Copyright (c) Project Nayuki 5 | * MIT License. See readme file. 6 | * https://www.nayuki.io/page/reference-arithmetic-coding 7 | */ 8 | 9 | import java.io.EOFException; 10 | import java.io.IOException; 11 | import java.io.InputStream; 12 | import java.util.Objects; 13 | 14 | 15 | /** 16 | * A stream of bits that can be read. Because they come from an underlying byte stream, 17 | * the total number of bits is always a multiple of 8. The bits are read in big endian. 18 | * Mutable and not thread-safe. 19 | * @see BitOutputStream 20 | */ 21 | public final class BitInputStream implements AutoCloseable { 22 | 23 | /*---- Fields ----*/ 24 | 25 | // The underlying byte stream to read from (not null). 26 | private InputStream input; 27 | 28 | // Either in the range [0x00, 0xFF] if bits are available, or -1 if end of stream is reached. 29 | private int currentByte; 30 | 31 | // Number of remaining bits in the current byte, always between 0 and 7 (inclusive). 32 | private int numBitsRemaining; 33 | 34 | 35 | 36 | /*---- Constructor ----*/ 37 | 38 | /** 39 | * Constructs a bit input stream based on the specified byte input stream. 
40 | * @param in the byte input stream 41 | * @throws NullPointerException if the input stream is {@code null} 42 | */ 43 | public BitInputStream(InputStream in) { 44 | input = Objects.requireNonNull(in); 45 | currentByte = 0; 46 | numBitsRemaining = 0; 47 | } 48 | 49 | 50 | 51 | /*---- Methods ----*/ 52 | 53 | /** 54 | * Reads a bit from this stream. Returns 0 or 1 if a bit is available, or -1 if 55 | * the end of stream is reached. The end of stream always occurs on a byte boundary. 56 | * @return the next bit of 0 or 1, or -1 for the end of stream 57 | * @throws IOException if an I/O exception occurred 58 | */ 59 | public int read() throws IOException { 60 | if (currentByte == -1) 61 | return -1; 62 | if (numBitsRemaining == 0) { 63 | currentByte = input.read(); 64 | if (currentByte == -1) 65 | return -1; 66 | numBitsRemaining = 8; 67 | } 68 | if (numBitsRemaining <= 0) 69 | throw new AssertionError(); 70 | numBitsRemaining--; 71 | return (currentByte >>> numBitsRemaining) & 1; 72 | } 73 | 74 | 75 | /** 76 | * Reads a bit from this stream. Returns 0 or 1 if a bit is available, or throws an {@code EOFException} 77 | * if the end of stream is reached. The end of stream always occurs on a byte boundary. 78 | * @return the next bit of 0 or 1 79 | * @throws IOException if an I/O exception occurred 80 | * @throws EOFException if the end of stream is reached 81 | */ 82 | public int readNoEof() throws IOException { 83 | int result = read(); 84 | if (result != -1) 85 | return result; 86 | else 87 | throw new EOFException(); 88 | } 89 | 90 | 91 | /** 92 | * Closes this stream and the underlying input stream. 
93 | * @throws IOException if an I/O exception occurred 94 | */ 95 | public void close() throws IOException { 96 | input.close(); 97 | currentByte = -1; 98 | numBitsRemaining = 0; 99 | } 100 | 101 | } 102 | -------------------------------------------------------------------------------- /java/src/BitOutputStream.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Reference arithmetic coding 3 | * 4 | * Copyright (c) Project Nayuki 5 | * MIT License. See readme file. 6 | * https://www.nayuki.io/page/reference-arithmetic-coding 7 | */ 8 | 9 | import java.io.IOException; 10 | import java.io.OutputStream; 11 | import java.util.Objects; 12 | 13 | 14 | /** 15 | * A stream where bits can be written to. Because they are written to an underlying 16 | * byte stream, the end of the stream is padded with 0's up to a multiple of 8 bits. 17 | * The bits are written in big endian. Mutable and not thread-safe. 18 | * @see BitInputStream 19 | */ 20 | public final class BitOutputStream implements AutoCloseable { 21 | 22 | /*---- Fields ----*/ 23 | 24 | // The underlying byte stream to write to (not null). 25 | private OutputStream output; 26 | 27 | // The accumulated bits for the current byte, always in the range [0x00, 0xFF]. 28 | private int currentByte; 29 | 30 | // Number of accumulated bits in the current byte, always between 0 and 7 (inclusive). 31 | private int numBitsFilled; 32 | 33 | 34 | 35 | /*---- Constructor ----*/ 36 | 37 | /** 38 | * Constructs a bit output stream based on the specified byte output stream. 39 | * @param out the byte output stream 40 | * @throws NullPointerException if the output stream is {@code null} 41 | */ 42 | public BitOutputStream(OutputStream out) { 43 | output = Objects.requireNonNull(out); 44 | currentByte = 0; 45 | numBitsFilled = 0; 46 | } 47 | 48 | 49 | 50 | /*---- Methods ----*/ 51 | 52 | /** 53 | * Writes a bit to the stream. The specified bit must be 0 or 1. 
54 | * @param b the bit to write, which must be 0 or 1 55 | * @throws IOException if an I/O exception occurred 56 | */ 57 | public void write(int b) throws IOException { 58 | if (b != 0 && b != 1) 59 | throw new IllegalArgumentException("Argument must be 0 or 1"); 60 | currentByte = (currentByte << 1) | b; 61 | numBitsFilled++; 62 | if (numBitsFilled == 8) { 63 | output.write(currentByte); 64 | currentByte = 0; 65 | numBitsFilled = 0; 66 | } 67 | } 68 | 69 | 70 | /** 71 | * Closes this stream and the underlying output stream. If called when this 72 | * bit stream is not at a byte boundary, then the minimum number of "0" bits 73 | * (between 0 and 7 of them) are written as padding to reach the next byte boundary. 74 | * @throws IOException if an I/O exception occurred 75 | */ 76 | public void close() throws IOException { 77 | while (numBitsFilled != 0) 78 | write(0); 79 | output.close(); 80 | } 81 | 82 | } 83 | -------------------------------------------------------------------------------- /java/src/CheckedFrequencyTable.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Reference arithmetic coding 3 | * 4 | * Copyright (c) Project Nayuki 5 | * MIT License. See readme file. 6 | * https://www.nayuki.io/page/reference-arithmetic-coding 7 | */ 8 | 9 | import java.util.Objects; 10 | 11 | 12 | /** 13 | * A wrapper that checks the preconditions (arguments) and postconditions (return value) 14 | * of all the frequency table methods. Useful for finding faults in a frequency table 15 | * implementation. However, arithmetic overflow conditions are not checked. 16 | */ 17 | public final class CheckedFrequencyTable implements FrequencyTable { 18 | 19 | /*---- Fields ----*/ 20 | 21 | // The underlying frequency table that holds the data (not null). 
22 | private FrequencyTable freqTable; 23 | 24 | 25 | 26 | /*---- Constructor ----*/ 27 | 28 | public CheckedFrequencyTable(FrequencyTable freq) { 29 | freqTable = Objects.requireNonNull(freq); 30 | } 31 | 32 | 33 | 34 | /*---- Methods ----*/ 35 | 36 | public int getSymbolLimit() { 37 | int result = freqTable.getSymbolLimit(); 38 | if (result <= 0) 39 | throw new AssertionError("Non-positive symbol limit"); 40 | return result; 41 | } 42 | 43 | 44 | public int get(int symbol) { 45 | int result = freqTable.get(symbol); 46 | if (!isSymbolInRange(symbol)) 47 | throw new AssertionError("IllegalArgumentException expected"); 48 | if (result < 0) 49 | throw new AssertionError("Negative symbol frequency"); 50 | return result; 51 | } 52 | 53 | 54 | public int getTotal() { 55 | int result = freqTable.getTotal(); 56 | if (result < 0) 57 | throw new AssertionError("Negative total frequency"); 58 | return result; 59 | } 60 | 61 | 62 | public int getLow(int symbol) { 63 | if (isSymbolInRange(symbol)) { 64 | int low = freqTable.getLow (symbol); 65 | int high = freqTable.getHigh(symbol); 66 | if (!(0 <= low && low <= high && high <= freqTable.getTotal())) 67 | throw new AssertionError("Symbol low cumulative frequency out of range"); 68 | return low; 69 | } else { 70 | freqTable.getLow(symbol); 71 | throw new AssertionError("IllegalArgumentException expected"); 72 | } 73 | } 74 | 75 | 76 | public int getHigh(int symbol) { 77 | if (isSymbolInRange(symbol)) { 78 | int low = freqTable.getLow (symbol); 79 | int high = freqTable.getHigh(symbol); 80 | if (!(0 <= low && low <= high && high <= freqTable.getTotal())) 81 | throw new AssertionError("Symbol high cumulative frequency out of range"); 82 | return high; 83 | } else { 84 | freqTable.getHigh(symbol); 85 | throw new AssertionError("IllegalArgumentException expected"); 86 | } 87 | } 88 | 89 | 90 | public String toString() { 91 | return "CheckedFrequencyTable (" + freqTable.toString() + ")"; 92 | } 93 | 94 | 95 | public void set(int 
symbol, int freq) { 96 | freqTable.set(symbol, freq); 97 | if (!isSymbolInRange(symbol) || freq < 0) 98 | throw new AssertionError("IllegalArgumentException expected"); 99 | } 100 | 101 | 102 | public void increment(int symbol) { 103 | freqTable.increment(symbol); 104 | if (!isSymbolInRange(symbol)) 105 | throw new AssertionError("IllegalArgumentException expected"); 106 | } 107 | 108 | 109 | private boolean isSymbolInRange(int symbol) { 110 | return 0 <= symbol && symbol < getSymbolLimit(); 111 | } 112 | 113 | } 114 | -------------------------------------------------------------------------------- /java/src/FlatFrequencyTable.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Reference arithmetic coding 3 | * 4 | * Copyright (c) Project Nayuki 5 | * MIT License. See readme file. 6 | * https://www.nayuki.io/page/reference-arithmetic-coding 7 | */ 8 | 9 | 10 | /** 11 | * An immutable frequency table where every symbol has the same frequency of 1. 12 | * Useful as a fallback model when no statistics are available. 13 | */ 14 | public final class FlatFrequencyTable implements FrequencyTable { 15 | 16 | /*---- Fields ----*/ 17 | 18 | // Total number of symbols, which is at least 1. 19 | private final int numSymbols; 20 | 21 | 22 | 23 | /*---- Constructor ----*/ 24 | 25 | /** 26 | * Constructs a flat frequency table with the specified number of symbols. 27 | * @param numSyms the number of symbols, which must be at least 1 28 | * @throws IllegalArgumentException if the number of symbols is less than 1 29 | */ 30 | public FlatFrequencyTable(int numSyms) { 31 | if (numSyms < 1) 32 | throw new IllegalArgumentException("Number of symbols must be positive"); 33 | numSymbols = numSyms; 34 | } 35 | 36 | 37 | 38 | /*---- Methods ----*/ 39 | 40 | /** 41 | * Returns the number of symbols in this table, which is at least 1. 
42 | * @return the number of symbols in this table 43 | */ 44 | public int getSymbolLimit() { 45 | return numSymbols; 46 | } 47 | 48 | 49 | /** 50 | * Returns the frequency of the specified symbol, which is always 1. 51 | * @param symbol the symbol to query 52 | * @return the frequency of the symbol, which is 1 53 | * @throws IllegalArgumentException if {@code symbol} < 0 or {@code symbol} ≥ {@code getSymbolLimit()} 54 | */ 55 | public int get(int symbol) { 56 | checkSymbol(symbol); 57 | return 1; 58 | } 59 | 60 | 61 | /** 62 | * Returns the total of all symbol frequencies, which is 63 | * always equal to the number of symbols in this table. 64 | * @return the total of all symbol frequencies, which is {@code getSymbolLimit()} 65 | */ 66 | public int getTotal() { 67 | return numSymbols; 68 | } 69 | 70 | 71 | /** 72 | * Returns the sum of the frequencies of all the symbols strictly below 73 | * the specified symbol value. The returned value is equal to {@code symbol}. 74 | * @param symbol the symbol to query 75 | * @return the sum of the frequencies of all the symbols below {@code symbol}, which is {@code symbol} 76 | * @throws IllegalArgumentException if {@code symbol} < 0 or {@code symbol} ≥ {@code getSymbolLimit()} 77 | */ 78 | public int getLow(int symbol) { 79 | checkSymbol(symbol); 80 | return symbol; 81 | } 82 | 83 | 84 | /** 85 | * Returns the sum of the frequencies of the specified symbol and all 86 | * the symbols below. The returned value is equal to {@code symbol + 1}. 87 | * @param symbol the symbol to query 88 | * @return the sum of the frequencies of {@code symbol} and all symbols below, which is {@code symbol + 1} 89 | * @throws IllegalArgumentException if {@code symbol} < 0 or {@code symbol} ≥ {@code getSymbolLimit()} 90 | */ 91 | public int getHigh(int symbol) { 92 | checkSymbol(symbol); 93 | return symbol + 1; 94 | } 95 | 96 | 97 | // Returns silently if 0 <= symbol < numSymbols, otherwise throws an exception. 
98 | private void checkSymbol(int symbol) { 99 | if (!(0 <= symbol && symbol < numSymbols)) 100 | throw new IllegalArgumentException("Symbol out of range"); 101 | } 102 | 103 | 104 | /** 105 | * Returns a string representation of this frequency table. The format is subject to change. 106 | * @return a string representation of this frequency table 107 | */ 108 | public String toString() { 109 | return "FlatFrequencyTable=" + numSymbols; 110 | } 111 | 112 | 113 | /** 114 | * Unsupported operation, because this frequency table is immutable. 115 | * @param symbol ignored 116 | * @param freq ignored 117 | * @throws UnsupportedOperationException because this frequency table is immutable 118 | */ 119 | public void set(int symbol, int freq) { 120 | throw new UnsupportedOperationException(); 121 | } 122 | 123 | 124 | /** 125 | * Unsupported operation, because this frequency table is immutable. 126 | * @param symbol ignored 127 | * @throws UnsupportedOperationException because this frequency table is immutable 128 | */ 129 | public void increment(int symbol) { 130 | throw new UnsupportedOperationException(); 131 | } 132 | 133 | } 134 | -------------------------------------------------------------------------------- /java/src/FrequencyTable.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Reference arithmetic coding 3 | * 4 | * Copyright (c) Project Nayuki 5 | * MIT License. See readme file. 6 | * https://www.nayuki.io/page/reference-arithmetic-coding 7 | */ 8 | 9 | 10 | /** 11 | * A table of symbol frequencies. The table holds data for symbols numbered from 0 12 | * to getSymbolLimit()−1. Each symbol has a frequency, which is a non-negative integer. 13 | *

Frequency table objects are primarily used for getting cumulative symbol 14 | * frequencies. These objects can be mutable depending on the implementation. 15 | * The total of all symbol frequencies must not exceed Integer.MAX_VALUE.

16 | */ 17 | public interface FrequencyTable { 18 | 19 | /** 20 | * Returns the number of symbols in this frequency table, which is a positive number. 21 | * @return the number of symbols in this frequency table 22 | */ 23 | public int getSymbolLimit(); 24 | 25 | 26 | /** 27 | * Returns the frequency of the specified symbol. The returned value is at least 0. 28 | * @param symbol the symbol to query 29 | * @return the frequency of the symbol 30 | * @throws IllegalArgumentException if the symbol is out of range 31 | */ 32 | public int get(int symbol); 33 | 34 | 35 | /** 36 | * Sets the frequency of the specified symbol to the specified value. 37 | * The frequency value must be at least 0. 38 | * @param symbol the symbol to set 39 | * @param freq the frequency value to set 40 | * @throws IllegalArgumentException if the frequency is negative or the symbol is out of range 41 | * @throws ArithmeticException if an arithmetic overflow occurs 42 | */ 43 | public void set(int symbol, int freq); 44 | 45 | 46 | /** 47 | * Increments the frequency of the specified symbol. 48 | * @param symbol the symbol whose frequency to increment 49 | * @throws IllegalArgumentException if the symbol is out of range 50 | * @throws ArithmeticException if an arithmetic overflow occurs 51 | */ 52 | public void increment(int symbol); 53 | 54 | 55 | /** 56 | * Returns the total of all symbol frequencies. The returned value is at 57 | * least 0 and is always equal to {@code getHigh(getSymbolLimit() - 1)}. 58 | * @return the total of all symbol frequencies 59 | */ 60 | public int getTotal(); 61 | 62 | 63 | /** 64 | * Returns the sum of the frequencies of all the symbols strictly 65 | * below the specified symbol value. The returned value is at least 0. 
66 | * @param symbol the symbol to query 67 | * @return the sum of the frequencies of all the symbols below {@code symbol} 68 | * @throws IllegalArgumentException if the symbol is out of range 69 | */ 70 | public int getLow(int symbol); 71 | 72 | 73 | /** 74 | * Returns the sum of the frequencies of the specified symbol 75 | * and all the symbols below. The returned value is at least 0. 76 | * @param symbol the symbol to query 77 | * @return the sum of the frequencies of {@code symbol} and all symbols below 78 | * @throws IllegalArgumentException if the symbol is out of range 79 | */ 80 | public int getHigh(int symbol); 81 | 82 | } 83 | -------------------------------------------------------------------------------- /java/src/PpmCompress.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Reference arithmetic coding 3 | * 4 | * Copyright (c) Project Nayuki 5 | * MIT License. See readme file. 6 | * https://www.nayuki.io/page/reference-arithmetic-coding 7 | */ 8 | 9 | import java.io.BufferedInputStream; 10 | import java.io.BufferedOutputStream; 11 | import java.io.File; 12 | import java.io.FileInputStream; 13 | import java.io.FileOutputStream; 14 | import java.io.IOException; 15 | import java.io.InputStream; 16 | import java.util.Arrays; 17 | 18 | 19 | /** 20 | * Compression application using prediction by partial matching (PPM) with arithmetic coding. 21 | *

<p>Usage: java PpmCompress InputFile OutputFile</p>

22 | *

<p>Then use the corresponding "PpmDecompress" application to recreate the original input file.</p>

23 | *

<p>Note that both the compressor and decompressor need to use the same PPM context modeling logic. 24 | * The PPM algorithm can be thought of as a powerful generalization of adaptive arithmetic coding.</p>

25 | */ 26 | public final class PpmCompress { 27 | 28 | // Must be at least -1 and match PpmDecompress. Warning: Exponential memory usage at O(257^n). 29 | private static final int MODEL_ORDER = 3; 30 | 31 | 32 | public static void main(String[] args) throws IOException { 33 | // Handle command line arguments 34 | if (args.length != 2) { 35 | System.err.println("Usage: java PpmCompress InputFile OutputFile"); 36 | System.exit(1); 37 | return; 38 | } 39 | File inputFile = new File(args[0]); 40 | File outputFile = new File(args[1]); 41 | 42 | // Perform file compression 43 | try (InputStream in = new BufferedInputStream(new FileInputStream(inputFile)); 44 | BitOutputStream out = new BitOutputStream(new BufferedOutputStream(new FileOutputStream(outputFile)))) { 45 | compress(in, out); 46 | } 47 | } 48 | 49 | 50 | // To allow unit testing, this method is package-private instead of private. 51 | static void compress(InputStream in, BitOutputStream out) throws IOException { 52 | // Set up encoder and model. In this PPM model, symbol 256 represents EOF; 53 | // its frequency is 1 in the order -1 context but its frequency 54 | // is 0 in all other contexts (which have non-negative order). 
55 | ArithmeticEncoder enc = new ArithmeticEncoder(32, out); 56 | PpmModel model = new PpmModel(MODEL_ORDER, 257, 256); 57 | int[] history = new int[0]; 58 | 59 | while (true) { 60 | // Read and encode one byte 61 | int symbol = in.read(); 62 | if (symbol == -1) 63 | break; 64 | encodeSymbol(model, history, symbol, enc); 65 | model.incrementContexts(history, symbol); 66 | 67 | if (model.modelOrder >= 1) { 68 | // Prepend current symbol, dropping oldest symbol if necessary 69 | if (history.length < model.modelOrder) 70 | history = Arrays.copyOf(history, history.length + 1); 71 | System.arraycopy(history, 0, history, 1, history.length - 1); 72 | history[0] = symbol; 73 | } 74 | } 75 | 76 | encodeSymbol(model, history, 256, enc); // EOF 77 | enc.finish(); // Flush remaining code bits 78 | } 79 | 80 | 81 | private static void encodeSymbol(PpmModel model, int[] history, int symbol, ArithmeticEncoder enc) throws IOException { 82 | // Try to use highest order context that exists based on the history suffix, such 83 | // that the next symbol has non-zero frequency. When symbol 256 is produced at a context 84 | // at any non-negative order, it means "escape to the next lower order with non-empty 85 | // context". When symbol 256 is produced at the order -1 context, it means "EOF". 
86 | outer: 87 | for (int order = history.length; order >= 0; order--) { 88 | PpmModel.Context ctx = model.rootContext; 89 | for (int i = 0; i < order; i++) { 90 | if (ctx.subcontexts == null) 91 | throw new AssertionError(); 92 | ctx = ctx.subcontexts[history[i]]; 93 | if (ctx == null) 94 | continue outer; 95 | } 96 | if (symbol != 256 && ctx.frequencies.get(symbol) > 0) { 97 | enc.write(ctx.frequencies, symbol); 98 | return; 99 | } 100 | // Else write context escape symbol and continue decrementing the order 101 | enc.write(ctx.frequencies, 256); 102 | } 103 | // Logic for order = -1 104 | enc.write(model.orderMinus1Freqs, symbol); 105 | } 106 | 107 | } 108 | -------------------------------------------------------------------------------- /java/src/PpmDecompress.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Reference arithmetic coding 3 | * 4 | * Copyright (c) Project Nayuki 5 | * MIT License. See readme file. 6 | * https://www.nayuki.io/page/reference-arithmetic-coding 7 | */ 8 | 9 | import java.io.BufferedInputStream; 10 | import java.io.BufferedOutputStream; 11 | import java.io.File; 12 | import java.io.FileInputStream; 13 | import java.io.FileOutputStream; 14 | import java.io.IOException; 15 | import java.io.OutputStream; 16 | import java.util.Arrays; 17 | 18 | 19 | /** 20 | * Decompression application using prediction by partial matching (PPM) with arithmetic coding. 21 | *

<p>Usage: java PpmDecompress InputFile OutputFile</p>

22 | *

<p>This decompresses files generated by the "PpmCompress" application.</p>

23 | */ 24 | public final class PpmDecompress { 25 | 26 | // Must be at least -1 and match PpmCompress. Warning: Exponential memory usage at O(257^n). 27 | private static final int MODEL_ORDER = 3; 28 | 29 | 30 | public static void main(String[] args) throws IOException { 31 | // Handle command line arguments 32 | if (args.length != 2) { 33 | System.err.println("Usage: java PpmDecompress InputFile OutputFile"); 34 | System.exit(1); 35 | return; 36 | } 37 | File inputFile = new File(args[0]); 38 | File outputFile = new File(args[1]); 39 | 40 | // Perform file decompression 41 | try (BitInputStream in = new BitInputStream(new BufferedInputStream(new FileInputStream(inputFile))); 42 | OutputStream out = new BufferedOutputStream(new FileOutputStream(outputFile))) { 43 | decompress(in, out); 44 | } 45 | } 46 | 47 | 48 | // To allow unit testing, this method is package-private instead of private. 49 | static void decompress(BitInputStream in, OutputStream out) throws IOException { 50 | // Set up decoder and model. In this PPM model, symbol 256 represents EOF; 51 | // its frequency is 1 in the order -1 context but its frequency 52 | // is 0 in all other contexts (which have non-negative order). 
53 | ArithmeticDecoder dec = new ArithmeticDecoder(32, in); 54 | PpmModel model = new PpmModel(MODEL_ORDER, 257, 256); 55 | int[] history = new int[0]; 56 | 57 | while (true) { 58 | // Decode and write one byte 59 | int symbol = decodeSymbol(dec, model, history); 60 | if (symbol == 256) // EOF symbol 61 | break; 62 | out.write(symbol); 63 | model.incrementContexts(history, symbol); 64 | 65 | if (model.modelOrder >= 1) { 66 | // Prepend current symbol, dropping oldest symbol if necessary 67 | if (history.length < model.modelOrder) 68 | history = Arrays.copyOf(history, history.length + 1); 69 | System.arraycopy(history, 0, history, 1, history.length - 1); 70 | history[0] = symbol; 71 | } 72 | } 73 | } 74 | 75 | 76 | private static int decodeSymbol(ArithmeticDecoder dec, PpmModel model, int[] history) throws IOException { 77 | // Try to use highest order context that exists based on the history suffix. When symbol 256 78 | // is consumed at a context at any non-negative order, it means "escape to the next lower order 79 | // with non-empty context". When symbol 256 is consumed at the order -1 context, it means "EOF". 
80 | outer: 81 | for (int order = history.length; order >= 0; order--) { 82 | PpmModel.Context ctx = model.rootContext; 83 | for (int i = 0; i < order; i++) { 84 | if (ctx.subcontexts == null) 85 | throw new AssertionError(); 86 | ctx = ctx.subcontexts[history[i]]; 87 | if (ctx == null) 88 | continue outer; 89 | } 90 | int symbol = dec.read(ctx.frequencies); 91 | if (symbol < 256) 92 | return symbol; 93 | // Else we read the context escape symbol, so continue decrementing the order 94 | } 95 | // Logic for order = -1 96 | return dec.read(model.orderMinus1Freqs); 97 | } 98 | 99 | } 100 | -------------------------------------------------------------------------------- /java/src/PpmModel.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Reference arithmetic coding 3 | * 4 | * Copyright (c) Project Nayuki 5 | * MIT License. See readme file. 6 | * https://www.nayuki.io/page/reference-arithmetic-coding 7 | */ 8 | 9 | 10 | final class PpmModel { 11 | 12 | /*---- Fields ----*/ 13 | 14 | public final int modelOrder; 15 | 16 | private final int symbolLimit; 17 | private final int escapeSymbol; 18 | 19 | public final Context rootContext; 20 | public final FrequencyTable orderMinus1Freqs; 21 | 22 | 23 | 24 | /*---- Constructors ----*/ 25 | 26 | public PpmModel(int order, int symbolLimit, int escapeSymbol) { 27 | if (!(order >= -1 && 0 <= escapeSymbol && escapeSymbol < symbolLimit)) 28 | throw new IllegalArgumentException(); 29 | this.modelOrder = order; 30 | this.symbolLimit = symbolLimit; 31 | this.escapeSymbol = escapeSymbol; 32 | 33 | if (order >= 0) { 34 | rootContext = new Context(symbolLimit, order >= 1); 35 | rootContext.frequencies.increment(escapeSymbol); 36 | } else 37 | rootContext = null; 38 | orderMinus1Freqs = new FlatFrequencyTable(symbolLimit); 39 | } 40 | 41 | 42 | 43 | /*---- Methods ----*/ 44 | 45 | public void incrementContexts(int[] history, int symbol) { 46 | if (modelOrder == -1) 47 | return; 48 | if 
(!(history.length <= modelOrder && 0 <= symbol && symbol < symbolLimit)) 49 | throw new IllegalArgumentException(); 50 | 51 | Context ctx = rootContext; 52 | ctx.frequencies.increment(symbol); 53 | int i = 0; 54 | for (int sym : history) { 55 | Context[] subctxs = ctx.subcontexts; 56 | if (subctxs == null) 57 | throw new AssertionError(); 58 | 59 | if (subctxs[sym] == null) { 60 | subctxs[sym] = new Context(symbolLimit, i + 1 < modelOrder); 61 | subctxs[sym].frequencies.increment(escapeSymbol); 62 | } 63 | ctx = subctxs[sym]; 64 | ctx.frequencies.increment(symbol); 65 | i++; 66 | } 67 | } 68 | 69 | 70 | 71 | /*---- Helper structure ----*/ 72 | 73 | public static final class Context { 74 | 75 | public final FrequencyTable frequencies; 76 | 77 | public final Context[] subcontexts; 78 | 79 | 80 | public Context(int symbols, boolean hasSubctx) { 81 | frequencies = new SimpleFrequencyTable(new int[symbols]); 82 | if (hasSubctx) 83 | subcontexts = new Context[symbols]; 84 | else 85 | subcontexts = null; 86 | } 87 | 88 | } 89 | 90 | } 91 | -------------------------------------------------------------------------------- /java/src/SimpleFrequencyTable.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Reference arithmetic coding 3 | * 4 | * Copyright (c) Project Nayuki 5 | * MIT License. See readme file. 6 | * https://www.nayuki.io/page/reference-arithmetic-coding 7 | */ 8 | 9 | import java.util.Objects; 10 | 11 | 12 | /** 13 | * A mutable table of symbol frequencies. The number of symbols cannot be changed 14 | * after construction. The current algorithm for calculating cumulative frequencies 15 | * takes linear time, but there exist faster algorithms such as Fenwick trees. 16 | */ 17 | public final class SimpleFrequencyTable implements FrequencyTable { 18 | 19 | /*---- Fields ----*/ 20 | 21 | // The frequency for each symbol. Its length is at least 1, and each element is non-negative. 
22 | private int[] frequencies; 23 | 24 | // cumulative[i] is the sum of 'frequencies' from 0 (inclusive) to i (exclusive). 25 | // Initialized lazily. When this is not null, the data is valid. 26 | private int[] cumulative; 27 | 28 | // Always equal to the sum of 'frequencies'. 29 | private int total; 30 | 31 | 32 | 33 | /*---- Constructors ----*/ 34 | 35 | /** 36 | * Constructs a frequency table from the specified array of symbol frequencies. There must be at least 37 | * 1 symbol, no symbol has a negative frequency, and the total must not exceed {@code Integer.MAX_VALUE}. 38 | * @param freqs the array of symbol frequencies 39 | * @throws NullPointerException if the array is {@code null} 40 | * @throws IllegalArgumentException if {@code freqs.length} < 1, 41 | * {@code freqs.length} = {@code Integer.MAX_VALUE}, or any element {@code freqs[i]} < 0 42 | * @throws ArithmeticException if the total of {@code freqs} exceeds {@code Integer.MAX_VALUE} 43 | */ 44 | public SimpleFrequencyTable(int[] freqs) { 45 | Objects.requireNonNull(freqs); 46 | if (freqs.length < 1) 47 | throw new IllegalArgumentException("At least 1 symbol needed"); 48 | if (freqs.length > Integer.MAX_VALUE - 1) 49 | throw new IllegalArgumentException("Too many symbols"); 50 | 51 | frequencies = freqs.clone(); // Make copy 52 | total = 0; 53 | for (int x : frequencies) { 54 | if (x < 0) 55 | throw new IllegalArgumentException("Negative frequency"); 56 | total = Math.addExact(x, total); 57 | } 58 | cumulative = null; 59 | } 60 | 61 | 62 | /** 63 | * Constructs a frequency table by copying the specified frequency table. 
64 | * @param freqs the frequency table to copy 65 | * @throws NullPointerException if {@code freqs} is {@code null} 66 | * @throws IllegalArgumentException if {@code freqs.getSymbolLimit()} < 1 67 | * or any element {@code freqs.get(i)} < 0 68 | * @throws ArithmeticException if the total of all {@code freqs} elements exceeds {@code Integer.MAX_VALUE} 69 | */ 70 | public SimpleFrequencyTable(FrequencyTable freqs) { 71 | Objects.requireNonNull(freqs); 72 | int numSym = freqs.getSymbolLimit(); 73 | if (numSym < 1) 74 | throw new IllegalArgumentException("At least 1 symbol needed"); 75 | 76 | frequencies = new int[numSym]; 77 | total = 0; 78 | for (int i = 0; i < frequencies.length; i++) { 79 | int x = freqs.get(i); 80 | if (x < 0) 81 | throw new IllegalArgumentException("Negative frequency"); 82 | frequencies[i] = x; 83 | total = Math.addExact(x, total); 84 | } 85 | cumulative = null; 86 | } 87 | 88 | 89 | 90 | /*---- Methods ----*/ 91 | 92 | /** 93 | * Returns the number of symbols in this frequency table, which is at least 1. 94 | * @return the number of symbols in this frequency table 95 | */ 96 | public int getSymbolLimit() { 97 | return frequencies.length; 98 | } 99 | 100 | 101 | /** 102 | * Returns the frequency of the specified symbol. The returned value is at least 0. 103 | * @param symbol the symbol to query 104 | * @return the frequency of the specified symbol 105 | * @throws IllegalArgumentException if {@code symbol} < 0 or {@code symbol} ≥ {@code getSymbolLimit()} 106 | */ 107 | public int get(int symbol) { 108 | checkSymbol(symbol); 109 | return frequencies[symbol]; 110 | } 111 | 112 | 113 | /** 114 | * Sets the frequency of the specified symbol to the specified value. The frequency value 115 | * must be at least 0. If an exception is thrown, then the state is left unchanged. 
116 | * @param symbol the symbol to set 117 | * @param freq the frequency value to set 118 | * @throws IllegalArgumentException if {@code symbol} < 0 or {@code symbol} ≥ {@code getSymbolLimit()} 119 | * @throws ArithmeticException if this set request would cause the total to exceed {@code Integer.MAX_VALUE} 120 | */ 121 | public void set(int symbol, int freq) { 122 | checkSymbol(symbol); 123 | if (freq < 0) 124 | throw new IllegalArgumentException("Negative frequency"); 125 | 126 | int temp = total - frequencies[symbol]; 127 | if (temp < 0) 128 | throw new AssertionError(); 129 | total = Math.addExact(temp, freq); 130 | frequencies[symbol] = freq; 131 | cumulative = null; 132 | } 133 | 134 | 135 | /** 136 | * Increments the frequency of the specified symbol. 137 | * @param symbol the symbol whose frequency to increment 138 | * @throws IllegalArgumentException if {@code symbol} < 0 or {@code symbol} ≥ {@code getSymbolLimit()} 139 | */ 140 | public void increment(int symbol) { 141 | checkSymbol(symbol); 142 | if (frequencies[symbol] == Integer.MAX_VALUE) 143 | throw new ArithmeticException("Arithmetic overflow"); 144 | total = Math.addExact(total, 1); 145 | frequencies[symbol]++; 146 | cumulative = null; 147 | } 148 | 149 | 150 | /** 151 | * Returns the total of all symbol frequencies. The returned value is at 152 | * least 0 and is always equal to {@code getHigh(getSymbolLimit() - 1)}. 153 | * @return the total of all symbol frequencies 154 | */ 155 | public int getTotal() { 156 | return total; 157 | } 158 | 159 | 160 | /** 161 | * Returns the sum of the frequencies of all the symbols strictly 162 | * below the specified symbol value. The returned value is at least 0. 
163 | * @param symbol the symbol to query 164 | * @return the sum of the frequencies of all the symbols below {@code symbol} 165 | * @throws IllegalArgumentException if {@code symbol} < 0 or {@code symbol} ≥ {@code getSymbolLimit()} 166 | */ 167 | public int getLow(int symbol) { 168 | checkSymbol(symbol); 169 | if (cumulative == null) 170 | initCumulative(); 171 | return cumulative[symbol]; 172 | } 173 | 174 | 175 | /** 176 | * Returns the sum of the frequencies of the specified symbol 177 | * and all the symbols below. The returned value is at least 0. 178 | * @param symbol the symbol to query 179 | * @return the sum of the frequencies of {@code symbol} and all symbols below 180 | * @throws IllegalArgumentException if {@code symbol} < 0 or {@code symbol} ≥ {@code getSymbolLimit()} 181 | */ 182 | public int getHigh(int symbol) { 183 | checkSymbol(symbol); 184 | if (cumulative == null) 185 | initCumulative(); 186 | return cumulative[symbol + 1]; 187 | } 188 | 189 | 190 | // Recomputes the array of cumulative symbol frequencies. 191 | private void initCumulative() { 192 | cumulative = new int[frequencies.length + 1]; 193 | int sum = 0; 194 | for (int i = 0; i < frequencies.length; i++) { 195 | // This arithmetic should not throw an exception, because invariants are being maintained 196 | // elsewhere in the data structure. This implementation is just a defensive measure. 197 | sum = Math.addExact(frequencies[i], sum); 198 | cumulative[i + 1] = sum; 199 | } 200 | if (sum != total) 201 | throw new AssertionError(); 202 | } 203 | 204 | 205 | // Returns silently if 0 <= symbol < frequencies.length, otherwise throws an exception. 206 | private void checkSymbol(int symbol) { 207 | if (!(0 <= symbol && symbol < frequencies.length)) 208 | throw new IllegalArgumentException("Symbol out of range"); 209 | } 210 | 211 | 212 | /** 213 | * Returns a string representation of this frequency table, 214 | * useful for debugging only, and the format is subject to change. 
215 | * @return a string representation of this frequency table 216 | */ 217 | public String toString() { 218 | StringBuilder sb = new StringBuilder(); 219 | for (int i = 0; i < frequencies.length; i++) 220 | sb.append(String.format("%d\t%d%n", i, frequencies[i])); 221 | return sb.toString(); 222 | } 223 | 224 | } 225 | -------------------------------------------------------------------------------- /java/test/AdaptiveArithmeticCompressTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Reference arithmetic coding 3 | * 4 | * Copyright (c) Project Nayuki 5 | * MIT License. See readme file. 6 | * https://www.nayuki.io/page/reference-arithmetic-coding 7 | */ 8 | 9 | import java.io.ByteArrayInputStream; 10 | import java.io.ByteArrayOutputStream; 11 | import java.io.IOException; 12 | import java.io.InputStream; 13 | 14 | 15 | /** 16 | * Tests {@link AdaptiveArithmeticCompress} coupled with {@link AdaptiveArithmeticDecompress}. 17 | */ 18 | public class AdaptiveArithmeticCompressTest extends ArithmeticCodingTest { 19 | 20 | protected byte[] compress(byte[] b) throws IOException { 21 | InputStream in = new ByteArrayInputStream(b); 22 | ByteArrayOutputStream out = new ByteArrayOutputStream(); 23 | try (BitOutputStream bitOut = new BitOutputStream(out)) { 24 | AdaptiveArithmeticCompress.compress(in, bitOut); 25 | } 26 | return out.toByteArray(); 27 | } 28 | 29 | 30 | protected byte[] decompress(byte[] b) throws IOException { 31 | InputStream in = new ByteArrayInputStream(b); 32 | ByteArrayOutputStream out = new ByteArrayOutputStream(); 33 | AdaptiveArithmeticDecompress.decompress(new BitInputStream(in), out); 34 | return out.toByteArray(); 35 | } 36 | 37 | } 38 | -------------------------------------------------------------------------------- /java/test/ArithmeticCodingTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Reference arithmetic coding 3 | * 4 | * Copyright 
(c) Project Nayuki 5 | * MIT License. See readme file. 6 | * https://www.nayuki.io/page/reference-arithmetic-coding 7 | */ 8 | 9 | import static org.junit.Assert.assertArrayEquals; 10 | import static org.junit.Assert.assertEquals; 11 | import static org.junit.Assert.fail; 12 | import java.io.EOFException; 13 | import java.io.IOException; 14 | import java.util.Random; 15 | import org.junit.Test; 16 | 17 | 18 | /** 19 | * Tests the compression and decompression of a complete arithmetic coding application, using the JUnit test framework. 20 | */ 21 | public abstract class ArithmeticCodingTest { 22 | 23 | /*---- Test cases ----*/ 24 | 25 | @Test public void testEmpty() { 26 | test(new byte[0]); 27 | } 28 | 29 | 30 | @Test public void testOneSymbol() { 31 | test(new byte[10]); 32 | } 33 | 34 | 35 | @Test public void testSimple() { 36 | test(new byte[]{0, 3, 1, 2}); 37 | } 38 | 39 | 40 | @Test public void testEveryByteValue() { 41 | byte[] b = new byte[256]; 42 | for (int i = 0; i < b.length; i++) 43 | b[i] = (byte)i; 44 | test(b); 45 | } 46 | 47 | 48 | @Test public void testUnderflow() { 49 | test(new byte[]{0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2}); 50 | } 51 | 52 | 53 | @Test public void testUniformRandom() { 54 | for (int i = 0; i < 100; i++) { 55 | byte[] b = new byte[random.nextInt(1000)]; 56 | random.nextBytes(b); 57 | test(b); 58 | } 59 | } 60 | 61 | 62 | @Test public void testRandomDistribution() { 63 | for (int i = 0; i < 1000; i++) { 64 | int m = random.nextInt(255) + 1; // Number of different symbols present 65 | int n = Math.max(random.nextInt(1000), m); // Length of message 66 | 67 | // Create distribution 68 | int[] freqs = new int[m]; 69 | int sum = 0; 70 | for (int j = 0; j < freqs.length; j++) { 71 | freqs[j] = random.nextInt(10000) + 1; 72 | sum += freqs[j]; 73 | } 74 | int total = sum; 75 | 76 | // Rescale frequencies 77 | sum = 0; 78 | int index = 0; 79 | for (int j = 0; j < freqs.length; j++) { 80 | int newsum = sum + freqs[j]; 81 | int newindex 
= (n - m) * newsum / total + j + 1; 82 | freqs[j] = newindex - index; 83 | sum = newsum; 84 | index = newindex; 85 | } 86 | assertEquals(n, index); 87 | 88 | // Create symbols 89 | byte[] message = new byte[n]; 90 | for (int k = 0, j = 0; k < freqs.length; k++) { 91 | for (int l = 0; l < freqs[k]; l++, j++) 92 | message[j] = (byte)k; 93 | } 94 | 95 | // Shuffle message (Durstenfeld algorithm) 96 | for (int j = 0; j < message.length; j++) { 97 | int k = random.nextInt(message.length - j) + j; 98 | byte temp = message[j]; 99 | message[j] = message[k]; 100 | message[k] = temp; 101 | } 102 | 103 | test(message); 104 | } 105 | } 106 | 107 | 108 | 109 | /*---- Utilities ----*/ 110 | 111 | // Tests that the given byte array can be compressed and decompressed to the same data, and not throw any exceptions. 112 | private void test(byte[] b) { 113 | try { 114 | byte[] compressed = compress(b); 115 | byte[] decompressed = decompress(compressed); 116 | assertArrayEquals(b, decompressed); 117 | } catch (EOFException e) { 118 | fail("Unexpected EOF"); 119 | } catch (IOException e) { 120 | throw new AssertionError(e); 121 | } 122 | } 123 | 124 | 125 | private static Random random = new Random(); 126 | 127 | 128 | 129 | /*---- Abstract methods ----*/ 130 | 131 | // Compression method that needs to be supplied by a subclass. 132 | protected abstract byte[] compress(byte[] b) throws IOException; 133 | 134 | // Decompression method that needs to be supplied by a subclass. 135 | protected abstract byte[] decompress(byte[] b) throws IOException; 136 | 137 | } 138 | -------------------------------------------------------------------------------- /java/test/ArithmeticCompressTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Reference arithmetic coding 3 | * 4 | * Copyright (c) Project Nayuki 5 | * MIT License. See readme file. 
6 | * https://www.nayuki.io/page/reference-arithmetic-coding 7 | */ 8 | 9 | import java.io.ByteArrayInputStream; 10 | import java.io.ByteArrayOutputStream; 11 | import java.io.IOException; 12 | import java.io.InputStream; 13 | 14 | 15 | /** 16 | * Tests {@link ArithmeticCompress} coupled with {@link ArithmeticDecompress}. 17 | */ 18 | public class ArithmeticCompressTest extends ArithmeticCodingTest { 19 | 20 | protected byte[] compress(byte[] b) throws IOException { 21 | FrequencyTable freqs = new SimpleFrequencyTable(new int[257]); 22 | for (byte x : b) 23 | freqs.increment(x & 0xFF); 24 | freqs.increment(256); // EOF symbol gets a frequency of 1 25 | 26 | InputStream in = new ByteArrayInputStream(b); 27 | ByteArrayOutputStream out = new ByteArrayOutputStream(); 28 | try (BitOutputStream bitOut = new BitOutputStream(out)) { 29 | ArithmeticCompress.writeFrequencies(bitOut, freqs); 30 | ArithmeticCompress.compress(freqs, in, bitOut); 31 | } 32 | return out.toByteArray(); 33 | } 34 | 35 | 36 | protected byte[] decompress(byte[] b) throws IOException { 37 | InputStream in = new ByteArrayInputStream(b); 38 | ByteArrayOutputStream out = new ByteArrayOutputStream(); 39 | BitInputStream bitIn = new BitInputStream(in); 40 | 41 | FrequencyTable freqs = ArithmeticDecompress.readFrequencies(bitIn); 42 | ArithmeticDecompress.decompress(freqs, bitIn, out); 43 | return out.toByteArray(); 44 | } 45 | 46 | } 47 | -------------------------------------------------------------------------------- /java/test/PpmCompressTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Reference arithmetic coding 3 | * 4 | * Copyright (c) Project Nayuki 5 | * MIT License. See readme file. 
6 | * https://www.nayuki.io/page/reference-arithmetic-coding 7 | */ 8 | 9 | import java.io.ByteArrayInputStream; 10 | import java.io.ByteArrayOutputStream; 11 | import java.io.IOException; 12 | import java.io.InputStream; 13 | 14 | 15 | /** 16 | * Tests {@link PpmCompress} coupled with {@link PpmDecompress}. 17 | */ 18 | public class PpmCompressTest extends ArithmeticCodingTest { 19 | 20 | protected byte[] compress(byte[] b) throws IOException { 21 | InputStream in = new ByteArrayInputStream(b); 22 | ByteArrayOutputStream out = new ByteArrayOutputStream(); 23 | try (BitOutputStream bitOut = new BitOutputStream(out)) { 24 | PpmCompress.compress(in, bitOut); 25 | } 26 | return out.toByteArray(); 27 | } 28 | 29 | 30 | protected byte[] decompress(byte[] b) throws IOException { 31 | InputStream in = new ByteArrayInputStream(b); 32 | ByteArrayOutputStream out = new ByteArrayOutputStream(); 33 | PpmDecompress.decompress(new BitInputStream(in), out); 34 | return out.toByteArray(); 35 | } 36 | 37 | } 38 | -------------------------------------------------------------------------------- /python/adaptive-arithmetic-compress.py: -------------------------------------------------------------------------------- 1 | # 2 | # Compression application using adaptive arithmetic coding 3 | # 4 | # Usage: python adaptive-arithmetic-compress.py InputFile OutputFile 5 | # Then use the corresponding adaptive-arithmetic-decompress.py application to recreate the original input file. 6 | # Note that the application starts with a flat frequency table of 257 symbols (all set to a frequency of 1), 7 | # and updates it after each byte encoded. The corresponding decompressor program also starts with a flat 8 | # frequency table and updates it after each byte decoded. It is by design that the compressor and 9 | # decompressor have synchronized states, so that the data can be decompressed properly. 10 | # 11 | # Copyright (c) Project Nayuki 12 | # MIT License. See readme file. 
13 | # https://www.nayuki.io/page/reference-arithmetic-coding 14 | # 15 | 16 | import contextlib, sys 17 | import arithmeticcoding 18 | 19 | 20 | # Command line main application function. 21 | def main(args): 22 | # Handle command line arguments 23 | if len(args) != 2: 24 | sys.exit("Usage: python adaptive-arithmetic-compress.py InputFile OutputFile") 25 | inputfile, outputfile = args 26 | 27 | # Perform file compression 28 | with open(inputfile, "rb") as inp, \ 29 | contextlib.closing(arithmeticcoding.BitOutputStream(open(outputfile, "wb"))) as bitout: 30 | compress(inp, bitout) 31 | 32 | 33 | def compress(inp, bitout): 34 | initfreqs = arithmeticcoding.FlatFrequencyTable(257) 35 | freqs = arithmeticcoding.SimpleFrequencyTable(initfreqs) 36 | enc = arithmeticcoding.ArithmeticEncoder(32, bitout) 37 | while True: 38 | # Read and encode one byte 39 | symbol = inp.read(1) 40 | if len(symbol) == 0: 41 | break 42 | enc.write(freqs, symbol[0]) 43 | freqs.increment(symbol[0]) 44 | enc.write(freqs, 256) # EOF 45 | enc.finish() # Flush remaining code bits 46 | 47 | 48 | # Main launcher 49 | if __name__ == "__main__": 50 | main(sys.argv[1 : ]) 51 | -------------------------------------------------------------------------------- /python/adaptive-arithmetic-decompress.py: -------------------------------------------------------------------------------- 1 | # 2 | # Decompression application using adaptive arithmetic coding 3 | # 4 | # Usage: python adaptive-arithmetic-decompress.py InputFile OutputFile 5 | # This decompresses files generated by the adaptive-arithmetic-compress.py application. 6 | # 7 | # Copyright (c) Project Nayuki 8 | # MIT License. See readme file. 9 | # https://www.nayuki.io/page/reference-arithmetic-coding 10 | # 11 | 12 | import sys 13 | import arithmeticcoding 14 | 15 | 16 | # Command line main application function. 
17 | def main(args): 18 | # Handle command line arguments 19 | if len(args) != 2: 20 | sys.exit("Usage: python adaptive-arithmetic-decompress.py InputFile OutputFile") 21 | inputfile, outputfile = args 22 | 23 | # Perform file decompression 24 | with open(inputfile, "rb") as inp, open(outputfile, "wb") as out: 25 | bitin = arithmeticcoding.BitInputStream(inp) 26 | decompress(bitin, out) 27 | 28 | 29 | def decompress(bitin, out): 30 | initfreqs = arithmeticcoding.FlatFrequencyTable(257) 31 | freqs = arithmeticcoding.SimpleFrequencyTable(initfreqs) 32 | dec = arithmeticcoding.ArithmeticDecoder(32, bitin) 33 | while True: 34 | # Decode and write one byte 35 | symbol = dec.read(freqs) 36 | if symbol == 256: # EOF symbol 37 | break 38 | out.write(bytes((symbol,))) 39 | freqs.increment(symbol) 40 | 41 | 42 | # Main launcher 43 | if __name__ == "__main__": 44 | main(sys.argv[1 : ]) 45 | -------------------------------------------------------------------------------- /python/arithmetic-compress.py: -------------------------------------------------------------------------------- 1 | # 2 | # Compression application using static arithmetic coding 3 | # 4 | # Usage: python arithmetic-compress.py InputFile OutputFile 5 | # Then use the corresponding arithmetic-decompress.py application to recreate the original input file. 6 | # Note that the application uses an alphabet of 257 symbols - 256 symbols for the byte 7 | # values and 1 symbol for the EOF marker. The compressed file format starts with a list 8 | # of 256 symbol frequencies, and then followed by the arithmetic-coded data. 9 | # 10 | # Copyright (c) Project Nayuki 11 | # MIT License. See readme file. 12 | # https://www.nayuki.io/page/reference-arithmetic-coding 13 | # 14 | 15 | import contextlib, sys 16 | import arithmeticcoding 17 | 18 | 19 | # Command line main application function. 
def main(args):
    """Compress the file args[0] into the file args[1]; exit with a usage message otherwise."""
    # Handle command line arguments
    if len(args) != 2:
        sys.exit("Usage: python arithmetic-compress.py InputFile OutputFile")
    inputfile, outputfile = args

    # Read input file once to compute symbol frequencies
    freqs = get_frequencies(inputfile)
    freqs.increment(256)  # EOF symbol gets a frequency of 1

    # Read input file again, compress with arithmetic coding, and write output file
    with open(inputfile, "rb") as inp:
        # closing() calls close() on the bit stream when the block exits
        with contextlib.closing(arithmeticcoding.BitOutputStream(open(outputfile, "wb"))) as bitout:
            write_frequencies(bitout, freqs)
            compress(freqs, inp, bitout)


# Returns a frequency table based on the bytes in the given file.
# Also contains an extra entry for symbol 256, whose frequency is set to 0.
def get_frequencies(filepath):
    table = arithmeticcoding.SimpleFrequencyTable([0] * 257)
    with open(filepath, "rb") as inp:
        while True:
            chunk = inp.read(1)
            if len(chunk) == 0:
                return table
            table.increment(chunk[0])


# Writes the frequencies of symbols 0 to 255 to the bit stream as 32-bit integers.
# The frequency of symbol 256 is not written.
def write_frequencies(bitout, freqs):
    symbol = 0
    while symbol < 256:
        write_int(bitout, 32, freqs.get(symbol))
        symbol += 1


# Encodes every byte read from inp with the fixed table freqs, then the EOF symbol (256).
def compress(freqs, inp, bitout):
    enc = arithmeticcoding.ArithmeticEncoder(32, bitout)
    while True:
        chunk = inp.read(1)
        if len(chunk) == 0:
            break
        enc.write(freqs, chunk[0])
    enc.write(freqs, 256)  # EOF
    enc.finish()  # Flush remaining code bits


# Writes an unsigned integer of the given bit width to the given stream.
67 | def write_int(bitout, numbits, value): 68 | for i in reversed(range(numbits)): 69 | bitout.write((value >> i) & 1) # Big endian 70 | 71 | 72 | # Main launcher 73 | if __name__ == "__main__": 74 | main(sys.argv[1 : ]) 75 | -------------------------------------------------------------------------------- /python/arithmetic-decompress.py: -------------------------------------------------------------------------------- 1 | # 2 | # Decompression application using static arithmetic coding 3 | # 4 | # Usage: python arithmetic-decompress.py InputFile OutputFile 5 | # This decompresses files generated by the arithmetic-compress.py application. 6 | # 7 | # Copyright (c) Project Nayuki 8 | # MIT License. See readme file. 9 | # https://www.nayuki.io/page/reference-arithmetic-coding 10 | # 11 | 12 | import sys 13 | import arithmeticcoding 14 | 15 | 16 | # Command line main application function. 17 | def main(args): 18 | # Handle command line arguments 19 | if len(args) != 2: 20 | sys.exit("Usage: python arithmetic-decompress.py InputFile OutputFile") 21 | inputfile, outputfile = args 22 | 23 | # Perform file decompression 24 | with open(outputfile, "wb") as out, open(inputfile, "rb") as inp: 25 | bitin = arithmeticcoding.BitInputStream(inp) 26 | freqs = read_frequencies(bitin) 27 | decompress(freqs, bitin, out) 28 | 29 | 30 | def read_frequencies(bitin): 31 | def read_int(n): 32 | result = 0 33 | for _ in range(n): 34 | result = (result << 1) | bitin.read_no_eof() # Big endian 35 | return result 36 | 37 | freqs = [read_int(32) for _ in range(256)] 38 | freqs.append(1) # EOF symbol 39 | return arithmeticcoding.SimpleFrequencyTable(freqs) 40 | 41 | 42 | def decompress(freqs, bitin, out): 43 | dec = arithmeticcoding.ArithmeticDecoder(32, bitin) 44 | while True: 45 | symbol = dec.read(freqs) 46 | if symbol == 256: # EOF symbol 47 | break 48 | out.write(bytes((symbol,))) 49 | 50 | 51 | # Main launcher 52 | if __name__ == "__main__": 53 | main(sys.argv[1 : ]) 54 | 
-------------------------------------------------------------------------------- /python/arithmeticcoding.py: -------------------------------------------------------------------------------- 1 | # 2 | # Reference arithmetic coding 3 | # 4 | # Copyright (c) Project Nayuki 5 | # MIT License. See readme file. 6 | # https://www.nayuki.io/page/reference-arithmetic-coding 7 | # 8 | 9 | 10 | # ---- Arithmetic coding core classes ---- 11 | 12 | # Provides the state and behaviors that arithmetic coding encoders and decoders share. 13 | class ArithmeticCoderBase: 14 | 15 | # Constructs an arithmetic coder, which initializes the code range. 16 | def __init__(self, numbits): 17 | if numbits < 1: 18 | raise ValueError("State size out of range") 19 | 20 | # -- Configuration fields -- 21 | # Number of bits for the 'low' and 'high' state variables. Must be at least 1. 22 | # - Larger values are generally better - they allow a larger maximum frequency total (maximum_total), 23 | # and they reduce the approximation error inherent in adapting fractions to integers; 24 | # both effects reduce the data encoding loss and asymptotically approach the efficiency 25 | # of arithmetic coding using exact fractions. 26 | # - But larger state sizes increase the computation time for integer arithmetic, 27 | # and compression gains beyond ~30 bits essentially zero in real-world applications. 28 | # - Python has native bigint arithmetic, so there is no upper limit to the state size. 29 | # For Java and C++ where using native machine-sized integers makes the most sense, 30 | # they have a recommended value of num_state_bits=32 as the most versatile setting. 31 | self.num_state_bits = numbits 32 | # Maximum range (high+1-low) during coding (trivial), which is 2^num_state_bits = 1000...000. 33 | self.full_range = 1 << self.num_state_bits 34 | # The top bit at width num_state_bits, which is 0100...000. 
35 | self.half_range = self.full_range >> 1 # Non-zero 36 | # The second highest bit at width num_state_bits, which is 0010...000. This is zero when num_state_bits=1. 37 | self.quarter_range = self.half_range >> 1 # Can be zero 38 | # Minimum range (high+1-low) during coding (non-trivial), which is 0010...010. 39 | self.minimum_range = self.quarter_range + 2 # At least 2 40 | # Maximum allowed total from a frequency table at all times during coding. This differs from Java 41 | # and C++ because Python's native bigint avoids constraining the size of intermediate computations. 42 | self.maximum_total = self.minimum_range 43 | # Bit mask of num_state_bits ones, which is 0111...111. 44 | self.state_mask = self.full_range - 1 45 | 46 | # -- State fields -- 47 | # Low end of this arithmetic coder's current range. Conceptually has an infinite number of trailing 0s. 48 | self.low = 0 49 | # High end of this arithmetic coder's current range. Conceptually has an infinite number of trailing 1s. 50 | self.high = self.state_mask 51 | 52 | 53 | # Updates the code range (low and high) of this arithmetic coder as a result 54 | # of processing the given symbol with the given frequency table. 55 | # Invariants that are true before and after encoding/decoding each symbol 56 | # (letting full_range = 2^num_state_bits): 57 | # - 0 <= low <= code <= high < full_range. ('code' exists only in the decoder.) 58 | # Therefore these variables are unsigned integers of num_state_bits bits. 59 | # - low < 1/2 * full_range <= high. 60 | # In other words, they are in different halves of the full range. 61 | # - (low < 1/4 * full_range) || (high >= 3/4 * full_range). 62 | # In other words, they are not both in the middle two quarters. 63 | # - Let range = high - low + 1, then full_range/4 < minimum_range 64 | # <= range <= full_range. These invariants for 'range' essentially 65 | # dictate the maximum total that the incoming frequency table can have. 
def update(self, freqs, symbol):
    # State check: both bounds must fit in the state width, with low strictly below high
    lo = self.low
    hi = self.high
    mask = self.state_mask
    if lo >= hi or (lo & mask) != lo or (hi & mask) != hi:
        raise AssertionError("Low or high out of range")
    span = hi - lo + 1
    if span < self.minimum_range or span > self.full_range:
        raise AssertionError("Range out of range")

    # Fetch the symbol's cumulative frequency interval and validate it
    total = freqs.get_total()
    cumlow = freqs.get_low(symbol)
    cumhigh = freqs.get_high(symbol)
    if cumlow == cumhigh:
        raise ValueError("Symbol has zero frequency")
    if total > self.maximum_total:
        raise ValueError("Cannot code symbol because total is too large")

    # Narrow [low, high] to the slice of the current range assigned to the symbol
    self.low = lo + cumlow * span // total
    self.high = lo + cumhigh * span // total - 1

    # As long as low and high agree on their top bit, hand that bit to shift()
    # and slide both bounds left (low gains a trailing 0, high a trailing 1)
    half = self.half_range
    while ((self.low ^ self.high) & half) == 0:
        self.shift()
        self.low = (self.low << 1) & mask
        self.high = ((self.high << 1) & mask) | 1
    # Here low's top bit is 0 and high's top bit is 1

    # As long as low starts with 01 and high starts with 10, report an underflow
    # and remove the second highest bit from both bounds
    quarter = self.quarter_range
    while (self.low & ~self.high & quarter) != 0:
        self.underflow()
        self.low = (self.low << 1) ^ half
        self.high = ((self.high ^ half) << 1) | half | 1


# Hook invoked by update() when the top bits of 'low' and 'high' are equal.
# This base version always raises NotImplementedError.
def shift(self):
    raise NotImplementedError()


# Hook invoked by update() when low=01(...) and high=10(...).
# This base version always raises NotImplementedError.
def underflow(self):
    raise NotImplementedError()



# Encodes symbols and writes to an arithmetic-coded bit stream.
class ArithmeticEncoder(ArithmeticCoderBase):

    # Creates an encoder with the given state size that emits bits to the given bit output stream.
    def __init__(self, numbits, bitout):
        super().__init__(numbits)
        # Bit stream that receives the coded output.
        self.output = bitout
        # How many underflow bits are pending. Incremented by underflow(), reset by shift().
        self.num_underflow = 0


    # Encodes one symbol according to the given frequency table. The table is wrapped
    # in a CheckedFrequencyTable unless it already is one. Updates the coder state,
    # which may cause bits to be written.
    def write(self, freqs, symbol):
        checked = freqs if isinstance(freqs, CheckedFrequencyTable) else CheckedFrequencyTable(freqs)
        self.update(checked, symbol)


    # Ends the encoding by writing a single 1 bit. Call this once after the last
    # symbol has been written. The underlying output stream is not closed here.
    def finish(self):
        self.output.write(1)


    # Writes the top bit of 'low', followed by all pending underflow bits,
    # each of which is the complement of that top bit.
    def shift(self):
        topbit = self.low >> (self.num_state_bits - 1)
        self.output.write(topbit)

        flipped = topbit ^ 1
        for _ in range(self.num_underflow):
            self.output.write(flipped)
        self.num_underflow = 0


    # Records one more pending underflow bit.
    def underflow(self):
        self.num_underflow += 1



# Reads from an arithmetic-coded bit stream and decodes symbols.
class ArithmeticDecoder(ArithmeticCoderBase):

    # Creates a decoder with the given state size that reads from the given
    # bit input stream, and preloads num_state_bits code bits from it.
    def __init__(self, numbits, bitin):
        super().__init__(numbits)
        # Bit stream that supplies the coded input.
        self.input = bitin
        # Buffered raw code bits; read() verifies that it stays within [low, high].
        self.code = 0
        for _ in range(self.num_state_bits):
            self.code = (self.code << 1) | self.read_code_bit()


    # Decodes and returns the next symbol according to the given frequency table.
    # The table is wrapped in a CheckedFrequencyTable unless it already is one.
    # Updates the coder state, which may cause bits to be read.
    def read(self, freqs):
        if not isinstance(freqs, CheckedFrequencyTable):
            freqs = CheckedFrequencyTable(freqs)

        # Map the code's position within the current range onto the frequency scale
        total = freqs.get_total()
        if total > self.maximum_total:
            raise ValueError("Cannot decode symbol because total is too large")
        span = self.high - self.low + 1
        offset = self.code - self.low
        value = ((offset + 1) * total - 1) // span
        assert value * span // total <= offset
        assert 0 <= value < total

        # Bisect for the highest symbol whose cumulative low does not exceed value
        lo = 0
        hi = freqs.get_symbol_limit()
        while hi - lo > 1:
            mid = (lo + hi) >> 1
            if freqs.get_low(mid) <= value:
                lo = mid
            else:
                hi = mid
        assert lo + 1 == hi

        symbol = lo
        assert freqs.get_low(symbol) * span // total <= offset < freqs.get_high(symbol) * span // total
        self.update(freqs, symbol)
        if not (self.low <= self.code <= self.high):
            raise AssertionError("Code out of range")
        return symbol


    # Drops the top bit of 'code' and appends the next input bit at the bottom.
    def shift(self):
        shifted = (self.code << 1) & self.state_mask
        self.code = shifted | self.read_code_bit()


    # Keeps the top bit of 'code', drops the second highest bit,
    # and appends the next input bit at the bottom.
    def underflow(self):
        top = self.code & self.half_range
        rest = (self.code << 1) & (self.state_mask >> 1)
        self.code = top | rest | self.read_code_bit()


    # Reads and returns the next bit (0 or 1) of the input stream. Once the stream
    # reports its end with -1, every further bit is taken to be 0.
    def read_code_bit(self):
        bit = self.input.read()
        return 0 if bit == -1 else bit



# ---- Frequency table classes ----

# Interface for a table of symbol frequencies. Symbols are numbered from 0 to
# get_symbol_limit()-1, and each has a non-negative integer frequency. The main use
# of such a table is looking up cumulative frequencies. Implementations may or may
# not be mutable. Every method here raises NotImplementedError.
class FrequencyTable:

    # Returns how many symbols this table has, which is a positive number.
    def get_symbol_limit(self):
        raise NotImplementedError()

    # Returns the frequency of the given symbol, which is at least 0.
    def get(self, symbol):
        raise NotImplementedError()

    # Changes the frequency of the given symbol to the given value,
    # which must be at least 0.
    def set(self, symbol, freq):
        raise NotImplementedError()

    # Adds one to the frequency of the given symbol.
    def increment(self, symbol):
        raise NotImplementedError()

    # Returns the sum of all symbol frequencies. The value is at least 0
    # and always equals get_high(get_symbol_limit() - 1).
    def get_total(self):
        raise NotImplementedError()

    # Returns the combined frequency of all symbols strictly below
    # the given symbol, which is at least 0.
    def get_low(self, symbol):
        raise NotImplementedError()

    # Returns the combined frequency of the given symbol and all
    # symbols below it, which is at least 0.
    def get_high(self, symbol):
        raise NotImplementedError()



# An immutable frequency table where every symbol has the same frequency of 1.
# Useful as a fallback model when no statistics are available.
class FlatFrequencyTable(FrequencyTable):

    # Creates a flat table covering the given number of symbols, which must be at least 1.
    def __init__(self, numsyms):
        if numsyms < 1:
            raise ValueError("Number of symbols must be positive")
        self.numsymbols = numsyms  # Symbol count, at least 1

    # Returns the symbol count of this table, which is at least 1.
    def get_symbol_limit(self):
        return self.numsymbols

    # Returns 1 for every valid symbol.
    def get(self, symbol):
        self._check_symbol(symbol)
        return 1

    # Returns the sum of all frequencies, which equals the symbol count.
    def get_total(self):
        return self.numsymbols

    # Returns the combined frequency of all symbols strictly below
    # the given symbol, which equals 'symbol'.
    def get_low(self, symbol):
        self._check_symbol(symbol)
        return symbol


    # Returns the combined frequency of the given symbol and all
    # symbols below it, which equals 'symbol' + 1.
    def get_high(self, symbol):
        self._check_symbol(symbol)
        return symbol + 1


    # Does nothing if 0 <= symbol < numsymbols; raises ValueError otherwise.
    def _check_symbol(self, symbol):
        if 0 <= symbol < self.numsymbols:
            return
        raise ValueError("Symbol out of range")

    # Returns a text description of this table. The format may change.
    def __str__(self):
        return "FlatFrequencyTable={}".format(self.numsymbols)

    # Not supported, since this table is immutable.
    def set(self, symbol, freq):
        raise NotImplementedError()

    # Not supported, since this table is immutable.
    def increment(self, symbol):
        raise NotImplementedError()



# A mutable frequency table with a fixed number of symbols. Cumulative
# frequencies are cached in a list that is discarded on every modification
# and rebuilt on the next cumulative lookup.
class SimpleFrequencyTable(FrequencyTable):

    # Creates a simple frequency table from one of two kinds of argument:
    # - A FrequencyTable object, whose frequencies are copied symbol by symbol.
    # - Any other iterable of frequencies, which is copied into a list.
    # Raises ValueError if there are no symbols or if any frequency is negative.
    def __init__(self, freqs):
        if isinstance(freqs, FrequencyTable):
            limit = freqs.get_symbol_limit()
            self.frequencies = [freqs.get(sym) for sym in range(limit)]
        else:  # Treat it as a sequence
            self.frequencies = list(freqs)  # Independent copy

        # 'frequencies' holds one non-negative frequency per symbol, at least 1 entry.
        if len(self.frequencies) < 1:
            raise ValueError("At least 1 symbol needed")
        if any(f < 0 for f in self.frequencies):
            raise ValueError("Negative frequency")

        # Kept equal to the sum of 'frequencies'
        self.total = sum(self.frequencies)

        # Cache where cumulative[i] is the sum of 'frequencies' from 0 (inclusive)
        # to i (exclusive). None means the cache must be rebuilt before use.
        self.cumulative = None


    # Returns the symbol count of this table, which is at least 1.
    def get_symbol_limit(self):
        return len(self.frequencies)


    # Returns the frequency of the given symbol, which is at least 0.
    def get(self, symbol):
        self._check_symbol(symbol)
        return self.frequencies[symbol]


    # Changes the frequency of the given symbol to the given value, which must be
    # at least 0. All validation happens before any field is modified.
    def set(self, symbol, freq):
        self._check_symbol(symbol)
        if freq < 0:
            raise ValueError("Negative frequency")
        others = self.total - self.frequencies[symbol]
        assert others >= 0
        self.total = others + freq
        self.frequencies[symbol] = freq
        self.cumulative = None  # Invalidate the cache


    # Adds one to the frequency of the given symbol.
    def increment(self, symbol):
        self._check_symbol(symbol)
        self.total += 1
        self.frequencies[symbol] += 1
        self.cumulative = None  # Invalidate the cache


    # Returns the sum of all symbol frequencies, as tracked in 'total'.
    def get_total(self):
        return self.total


    # Returns the combined frequency of all symbols strictly below
    # the given symbol, which is at least 0.
    def get_low(self, symbol):
        self._check_symbol(symbol)
        if self.cumulative is None:
            self._init_cumulative()
        return self.cumulative[symbol]


    # Returns the combined frequency of the given symbol and all
    # symbols below it, which is at least 0.
    def get_high(self, symbol):
        self._check_symbol(symbol)
        if self.cumulative is None:
            self._init_cumulative()
        return self.cumulative[symbol + 1]


    # Recomputes the array of cumulative symbol frequencies.
416 | def _init_cumulative(self): 417 | cumul = [0] 418 | sum = 0 419 | for freq in self.frequencies: 420 | sum += freq 421 | cumul.append(sum) 422 | assert sum == self.total 423 | self.cumulative = cumul 424 | 425 | 426 | # Returns silently if 0 <= symbol < len(frequencies), otherwise raises an exception. 427 | def _check_symbol(self, symbol): 428 | if not (0 <= symbol < len(self.frequencies)): 429 | raise ValueError("Symbol out of range") 430 | 431 | 432 | # Returns a string representation of this frequency table, 433 | # useful for debugging only, and the format is subject to change. 434 | def __str__(self): 435 | result = "" 436 | for (i, freq) in enumerate(self.frequencies): 437 | result += "{}\t{}\n".format(i, freq) 438 | return result 439 | 440 | 441 | 442 | # A wrapper that checks the preconditions (arguments) and postconditions (return value) of all 443 | # the frequency table methods. Useful for finding faults in a frequency table implementation. 444 | class CheckedFrequencyTable(FrequencyTable): 445 | 446 | def __init__(self, freqtab): 447 | # The underlying frequency table that holds the data 448 | self.freqtable = freqtab 449 | 450 | 451 | def get_symbol_limit(self): 452 | result = self.freqtable.get_symbol_limit() 453 | if result <= 0: 454 | raise AssertionError("Non-positive symbol limit") 455 | return result 456 | 457 | 458 | def get(self, symbol): 459 | result = self.freqtable.get(symbol) 460 | if not self._is_symbol_in_range(symbol): 461 | raise AssertionError("ValueError expected") 462 | if result < 0: 463 | raise AssertionError("Negative symbol frequency") 464 | return result 465 | 466 | 467 | def get_total(self): 468 | result = self.freqtable.get_total() 469 | if result < 0: 470 | raise AssertionError("Negative total frequency") 471 | return result 472 | 473 | 474 | def get_low(self, symbol): 475 | if self._is_symbol_in_range(symbol): 476 | low = self.freqtable.get_low (symbol) 477 | high = self.freqtable.get_high(symbol) 478 | if not (0 <= low 
<= high <= self.freqtable.get_total()): 479 | raise AssertionError("Symbol low cumulative frequency out of range") 480 | return low 481 | else: 482 | self.freqtable.get_low(symbol) 483 | raise AssertionError("ValueError expected") 484 | 485 | 486 | def get_high(self, symbol): 487 | if self._is_symbol_in_range(symbol): 488 | low = self.freqtable.get_low (symbol) 489 | high = self.freqtable.get_high(symbol) 490 | if not (0 <= low <= high <= self.freqtable.get_total()): 491 | raise AssertionError("Symbol high cumulative frequency out of range") 492 | return high 493 | else: 494 | self.freqtable.get_high(symbol) 495 | raise AssertionError("ValueError expected") 496 | 497 | 498 | def __str__(self): 499 | return "CheckedFrequencyTable (" + str(self.freqtable) + ")" 500 | 501 | 502 | def set(self, symbol, freq): 503 | self.freqtable.set(symbol, freq) 504 | if not self._is_symbol_in_range(symbol) or freq < 0: 505 | raise AssertionError("ValueError expected") 506 | 507 | 508 | def increment(self, symbol): 509 | self.freqtable.increment(symbol) 510 | if not self._is_symbol_in_range(symbol): 511 | raise AssertionError("ValueError expected") 512 | 513 | 514 | def _is_symbol_in_range(self, symbol): 515 | return 0 <= symbol < self.get_symbol_limit() 516 | 517 | 518 | 519 | # ---- Bit-oriented I/O streams ---- 520 | 521 | # A stream of bits that can be read. Because they come from an underlying byte stream, 522 | # the total number of bits is always a multiple of 8. The bits are read in big endian. 523 | class BitInputStream: 524 | 525 | # Constructs a bit input stream based on the given byte input stream. 
526 | def __init__(self, inp): 527 | # The underlying byte stream to read from 528 | self.input = inp 529 | # Either in the range [0x00, 0xFF] if bits are available, or -1 if end of stream is reached 530 | self.currentbyte = 0 531 | # Number of remaining bits in the current byte, always between 0 and 7 (inclusive) 532 | self.numbitsremaining = 0 533 | 534 | 535 | # Reads a bit from this stream. Returns 0 or 1 if a bit is available, or -1 if 536 | # the end of stream is reached. The end of stream always occurs on a byte boundary. 537 | def read(self): 538 | if self.currentbyte == -1: 539 | return -1 540 | if self.numbitsremaining == 0: 541 | temp = self.input.read(1) 542 | if len(temp) == 0: 543 | self.currentbyte = -1 544 | return -1 545 | self.currentbyte = temp[0] 546 | self.numbitsremaining = 8 547 | assert self.numbitsremaining > 0 548 | self.numbitsremaining -= 1 549 | return (self.currentbyte >> self.numbitsremaining) & 1 550 | 551 | 552 | # Reads a bit from this stream. Returns 0 or 1 if a bit is available, or raises an EOFError 553 | # if the end of stream is reached. The end of stream always occurs on a byte boundary. 554 | def read_no_eof(self): 555 | result = self.read() 556 | if result != -1: 557 | return result 558 | else: 559 | raise EOFError() 560 | 561 | 562 | # Closes this stream and the underlying input stream. 563 | def close(self): 564 | self.input.close() 565 | self.currentbyte = -1 566 | self.numbitsremaining = 0 567 | 568 | 569 | 570 | # A stream where bits can be written to. Because they are written to an underlying 571 | # byte stream, the end of the stream is padded with 0's up to a multiple of 8 bits. 572 | # The bits are written in big endian. 573 | class BitOutputStream: 574 | 575 | # Constructs a bit output stream based on the given byte output stream. 
576 | def __init__(self, out): 577 | self.output = out # The underlying byte stream to write to 578 | self.currentbyte = 0 # The accumulated bits for the current byte, always in the range [0x00, 0xFF] 579 | self.numbitsfilled = 0 # Number of accumulated bits in the current byte, always between 0 and 7 (inclusive) 580 | 581 | 582 | # Writes a bit to the stream. The given bit must be 0 or 1. 583 | def write(self, b): 584 | if b not in (0, 1): 585 | raise ValueError("Argument must be 0 or 1") 586 | self.currentbyte = (self.currentbyte << 1) | b 587 | self.numbitsfilled += 1 588 | if self.numbitsfilled == 8: 589 | towrite = bytes((self.currentbyte,)) 590 | self.output.write(towrite) 591 | self.currentbyte = 0 592 | self.numbitsfilled = 0 593 | 594 | 595 | # Closes this stream and the underlying output stream. If called when this 596 | # bit stream is not at a byte boundary, then the minimum number of "0" bits 597 | # (between 0 and 7 of them) are written as padding to reach the next byte boundary. 598 | def close(self): 599 | while self.numbitsfilled != 0: 600 | self.write(0) 601 | self.output.close() 602 | -------------------------------------------------------------------------------- /python/ppm-compress.py: -------------------------------------------------------------------------------- 1 | # 2 | # Compression application using prediction by partial matching (PPM) with arithmetic coding 3 | # 4 | # Usage: python ppm-compress.py InputFile OutputFile 5 | # Then use the corresponding ppm-decompress.py application to recreate the original input file. 6 | # Note that both the compressor and decompressor need to use the same PPM context modeling logic. 7 | # The PPM algorithm can be thought of as a powerful generalization of adaptive arithmetic coding. 8 | # 9 | # Copyright (c) Project Nayuki 10 | # MIT License. See readme file. 
11 | # https://www.nayuki.io/page/reference-arithmetic-coding 12 | # 13 | 14 | import contextlib, sys 15 | import arithmeticcoding, ppmmodel 16 | 17 | 18 | # Must be at least -1 and match ppm-decompress.py. Warning: Exponential memory usage at O(257^n). 19 | MODEL_ORDER = 3 20 | 21 | 22 | # Command line main application function. 23 | def main(args): 24 | # Handle command line arguments 25 | if len(args) != 2: 26 | sys.exit("Usage: python ppm-compress.py InputFile OutputFile") 27 | inputfile = args[0] 28 | outputfile = args[1] 29 | 30 | # Perform file compression 31 | with open(inputfile, "rb") as inp, \ 32 | contextlib.closing(arithmeticcoding.BitOutputStream(open(outputfile, "wb"))) as bitout: 33 | compress(inp, bitout) 34 | 35 | 36 | def compress(inp, bitout): 37 | # Set up encoder and model. In this PPM model, symbol 256 represents EOF; 38 | # its frequency is 1 in the order -1 context but its frequency 39 | # is 0 in all other contexts (which have non-negative order). 40 | enc = arithmeticcoding.ArithmeticEncoder(32, bitout) 41 | model = ppmmodel.PpmModel(MODEL_ORDER, 257, 256) 42 | history = [] 43 | 44 | while True: 45 | # Read and encode one byte 46 | symbol = inp.read(1) 47 | if len(symbol) == 0: 48 | break 49 | symbol = symbol[0] 50 | encode_symbol(model, history, symbol, enc) 51 | model.increment_contexts(history, symbol) 52 | 53 | if model.model_order >= 1: 54 | # Prepend current symbol, dropping oldest symbol if necessary 55 | if len(history) == model.model_order: 56 | history.pop() 57 | history.insert(0, symbol) 58 | 59 | encode_symbol(model, history, 256, enc) # EOF 60 | enc.finish() # Flush remaining code bits 61 | 62 | 63 | def encode_symbol(model, history, symbol, enc): 64 | # Try to use highest order context that exists based on the history suffix, such 65 | # that the next symbol has non-zero frequency. 
When symbol 256 is produced at a context 66 | # at any non-negative order, it means "escape to the next lower order with non-empty 67 | # context". When symbol 256 is produced at the order -1 context, it means "EOF". 68 | for order in reversed(range(len(history) + 1)): 69 | ctx = model.root_context 70 | for sym in history[ : order]: 71 | assert ctx.subcontexts is not None 72 | ctx = ctx.subcontexts[sym] 73 | if ctx is None: 74 | break 75 | else: # ctx is not None 76 | if symbol != 256 and ctx.frequencies.get(symbol) > 0: 77 | enc.write(ctx.frequencies, symbol) 78 | return 79 | # Else write context escape symbol and continue decrementing the order 80 | enc.write(ctx.frequencies, 256) 81 | # Logic for order = -1 82 | enc.write(model.order_minus1_freqs, symbol) 83 | 84 | 85 | # Main launcher 86 | if __name__ == "__main__": 87 | main(sys.argv[1 : ]) 88 | -------------------------------------------------------------------------------- /python/ppm-decompress.py: -------------------------------------------------------------------------------- 1 | # 2 | # Decompression application using prediction by partial matching (PPM) with arithmetic coding 3 | # 4 | # Usage: python ppm-decompress.py InputFile OutputFile 5 | # This decompresses files generated by the ppm-compress.py application. 6 | # 7 | # Copyright (c) Project Nayuki 8 | # MIT License. See readme file. 9 | # https://www.nayuki.io/page/reference-arithmetic-coding 10 | # 11 | 12 | import sys 13 | import arithmeticcoding, ppmmodel 14 | 15 | 16 | # Must be at least -1 and match ppm-compress.py. Warning: Exponential memory usage at O(257^n). 17 | MODEL_ORDER = 3 18 | 19 | 20 | # Command line main application function. 
21 | def main(args): 22 | # Handle command line arguments 23 | if len(args) != 2: 24 | sys.exit("Usage: python ppm-decompress.py InputFile OutputFile") 25 | inputfile = args[0] 26 | outputfile = args[1] 27 | 28 | # Perform file decompression 29 | with open(inputfile, "rb") as inp, open(outputfile, "wb") as out: 30 | bitin = arithmeticcoding.BitInputStream(inp) 31 | decompress(bitin, out) 32 | 33 | 34 | def decompress(bitin, out): 35 | # Set up decoder and model. In this PPM model, symbol 256 represents EOF; 36 | # its frequency is 1 in the order -1 context but its frequency 37 | # is 0 in all other contexts (which have non-negative order). 38 | dec = arithmeticcoding.ArithmeticDecoder(32, bitin) 39 | model = ppmmodel.PpmModel(MODEL_ORDER, 257, 256) 40 | history = [] 41 | 42 | while True: 43 | # Decode and write one byte 44 | symbol = decode_symbol(dec, model, history) 45 | if symbol == 256: # EOF symbol 46 | break 47 | out.write(bytes((symbol,))) 48 | model.increment_contexts(history, symbol) 49 | 50 | if model.model_order >= 1: 51 | # Prepend current symbol, dropping oldest symbol if necessary 52 | if len(history) == model.model_order: 53 | history.pop() 54 | history.insert(0, symbol) 55 | 56 | 57 | def decode_symbol(dec, model, history): 58 | # Try to use highest order context that exists based on the history suffix. When symbol 256 59 | # is consumed at a context at any non-negative order, it means "escape to the next lower order 60 | # with non-empty context". When symbol 256 is consumed at the order -1 context, it means "EOF". 
61 | for order in reversed(range(len(history) + 1)): 62 | ctx = model.root_context 63 | for sym in history[ : order]: 64 | assert ctx.subcontexts is not None 65 | ctx = ctx.subcontexts[sym] 66 | if ctx is None: 67 | break 68 | else: # ctx is not None 69 | symbol = dec.read(ctx.frequencies) 70 | if symbol < 256: 71 | return symbol 72 | # Else we read the context escape symbol, so continue decrementing the order 73 | # Logic for order = -1 74 | return dec.read(model.order_minus1_freqs) 75 | 76 | 77 | # Main launcher 78 | if __name__ == "__main__": 79 | main(sys.argv[1 : ]) 80 | -------------------------------------------------------------------------------- /python/ppmmodel.py: -------------------------------------------------------------------------------- 1 | # 2 | # Reference arithmetic coding 3 | # 4 | # Copyright (c) Project Nayuki 5 | # MIT License. See readme file. 6 | # https://www.nayuki.io/page/reference-arithmetic-coding 7 | # 8 | 9 | import arithmeticcoding 10 | 11 | 12 | class PpmModel: 13 | 14 | def __init__(self, order, symbollimit, escapesymbol): 15 | if not ((order >= -1) and (0 <= escapesymbol < symbollimit)): 16 | raise ValueError() 17 | self.model_order = order 18 | self.symbol_limit = symbollimit 19 | self.escape_symbol = escapesymbol 20 | 21 | if order >= 0: 22 | self.root_context = PpmModel.Context(symbollimit, order >= 1) 23 | self.root_context.frequencies.increment(escapesymbol) 24 | else: 25 | self.root_context = None 26 | self.order_minus1_freqs = arithmeticcoding.FlatFrequencyTable(symbollimit) 27 | 28 | 29 | def increment_contexts(self, history, symbol): 30 | if self.model_order == -1: 31 | return 32 | if not ((len(history) <= self.model_order) and (0 <= symbol < self.symbol_limit)): 33 | raise ValueError() 34 | 35 | ctx = self.root_context 36 | ctx.frequencies.increment(symbol) 37 | for (i, sym) in enumerate(history): 38 | subctxs = ctx.subcontexts 39 | assert subctxs is not None 40 | 41 | if subctxs[sym] is None: 42 | subctxs[sym] = 
PpmModel.Context(self.symbol_limit, i + 1 < self.model_order) 43 | subctxs[sym].frequencies.increment(self.escape_symbol) 44 | ctx = subctxs[sym] 45 | ctx.frequencies.increment(symbol) 46 | 47 | 48 | 49 | # Helper structure 50 | class Context: 51 | 52 | def __init__(self, symbols, hassubctx): 53 | self.frequencies = arithmeticcoding.SimpleFrequencyTable([0] * symbols) 54 | self.subcontexts = ([None] * symbols) if hassubctx else None 55 | --------------------------------------------------------------------------------