12 |     if '</div>' in line:
13 | skip = False
14 | if skip:
15 | continue
16 | long_description_lines.append(line)
17 | return ''.join(long_description_lines)
18 |
19 |
20 | def _get_version():
21 | with open('bin/version.json', 'r') as f:
22 | version_info = json.load(f)
23 | if version_info['used']:
24 | raise ValueError('Version already used!')
25 | return version_info['version']
26 |
27 |
28 | setuptools.setup(
29 | name='torchac',
30 | packages=['torchac'],
31 | version=_get_version(),
32 | author='fab-jul',
33 | author_email='fabianjul@gmail.com',
34 | description='Fast Arithmetic Coding for PyTorch',
35 | long_description=_get_long_description(),
36 | long_description_content_type='text/markdown',
37 | python_requires='>=3.6',
38 | license='GNU General Public License',
39 | url='https://github.com/fab-jul/torchac')
40 |
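For reference, a minimal sketch of the `bin/version.json` schema that `_get_version` above expects; the field names come from the code, the concrete values here are made up:

    import json
    # Hypothetical bin/version.json contents: 'version' is handed to
    # setuptools, and 'used' is flipped to true after a release so that
    # rebuilding the same version raises in _get_version.
    with open('bin/version.json', 'w') as f:
      json.dump({'version': '1.0.0', 'used': False}, f)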
--------------------------------------------------------------------------------
/tests/test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from torchac import torchac
4 |
5 |
6 | def test_out_of_range_symbol():
7 | cdf_float = torch.tensor([0., 1/3, 2/3, 1.], dtype=torch.float32).reshape(1, -1)
8 | assert list(_encode_decode(cdf_float, [10],
9 | needs_normalization=False,
10 | check_input_bounds=False)) == [False]
11 |
12 |
13 | def test_uniform_float():
14 | cdf_float = torch.tensor([0., 1/3, 2/3, 1.], dtype=torch.float32).reshape(1, -1)
15 |
16 | # Check if integer conversion works as expected.
17 | cdf_int = torchac._convert_to_int_and_normalize(cdf_float,
18 | needs_normalization=False)
19 | assert cdf_int[0, 1] == 2**16//3
20 | assert cdf_int[0, -1] == 0
21 |
22 | # Check if we can uniquely encode without normalization.
23 | assert all(_encode_decode(cdf_float,
24 | symbols_to_check=(0, 1, 2),
25 | needs_normalization=False))
26 |
27 |
28 | def test_uniform_float_multidim():
29 | cdf_float = torch.tensor([0., 1/3, 2/3, 1.], dtype=torch.float32).reshape(1, -1)
30 |
31 | L = 3
32 | C, H, W = 5, 8, 9
33 |
34 | Lp = L + 1
35 | cdf_float = torch.cat([cdf_float for _ in range(C*H*W)], dim=0)
36 | cdf_float = cdf_float.reshape(C, H, W, -1)
37 | assert cdf_float.shape[-1] == Lp
38 |
39 | sym = torch.arange(C * H * W, dtype=torch.int16) % L
40 | sym = sym.reshape(C, H, W)
41 |
42 | byte_stream = torchac.encode_float_cdf(
43 | cdf_float,
44 | sym,
45 | needs_normalization=False,
46 | check_input_bounds=True)
47 | sym_out = torchac.decode_float_cdf(
48 | cdf_float,
49 | byte_stream,
50 | needs_normalization=False)
51 | assert sym_out.equal(sym)
52 |
53 |
54 | def test_normalize_float():
55 | # Two times the same value -> needs to be normalized!
56 | cdf_float = torch.tensor([0., 1/3, 1/3, 1.], dtype=torch.float32).reshape(1, -1)
57 | # Check if we can uniquely encode
58 | assert all(_encode_decode(cdf_float,
59 | symbols_to_check=(0, 1, 2),
60 | needs_normalization=True))
61 |
62 | # Should raise because symbol is out of bounds.
63 | with pytest.raises(ValueError):
64 | sym = torch.tensor([3], dtype=torch.int16)
65 | torchac.encode_float_cdf(cdf_float, sym,
66 | needs_normalization=True,
67 | check_input_bounds=True)
68 |
69 |
70 | def test_normalization_sigmoid():
71 | mu = 0
72 | L = 256
73 | Lp = L + 1
74 | x_for_cdf = torch.linspace(-1, 1, Lp)
75 | # Logistic distribution.
76 | for sigma in [0.001, 0.01, 0.1, 1., 10.]:
77 | cdf_float = torch.sigmoid((x_for_cdf-mu)/sigma)
78 |
79 | # Put it into the expected shape.
80 | cdf_float = cdf_float.reshape(1, -1)
81 |
82 | # Check if we can uniquely decode all valid symbols.
83 | assert all(_encode_decode(
84 | cdf_float, symbols_to_check=range(L), needs_normalization=True))
85 |
86 |
87 | def _encode_decode(cdf_float, symbols_to_check,
88 | needs_normalization, check_input_bounds=True):
89 | # Check if we can uniquely encode
90 | for symbol in symbols_to_check:
91 | sym = torch.tensor([symbol], dtype=torch.int16)
92 | byte_stream = torchac.encode_float_cdf(
93 | cdf_float,
94 | sym,
95 | needs_normalization=needs_normalization,
96 | check_input_bounds=check_input_bounds)
97 | sym_out = torchac.decode_float_cdf(
98 | cdf_float,
99 | byte_stream,
100 | needs_normalization=needs_normalization)
101 | yield sym_out == sym
102 |
--------------------------------------------------------------------------------
/torchac/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from torchac.torchac import encode_float_cdf
3 | from torchac.torchac import decode_float_cdf
4 | from torchac.torchac import encode_int16_normalized_cdf
5 | from torchac.torchac import decode_int16_normalized_cdf
6 |
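A minimal round-trip through this public API, condensed from tests/test.py (a uniform CDF over L == 3 symbols, so Lp == 4 entries per row):

    import torch
    import torchac

    # One symbol, with an (unnormalized) float CDF over 3 possible values.
    cdf_float = torch.tensor([0., 1/3, 2/3, 1.]).reshape(1, -1)
    sym = torch.tensor([2], dtype=torch.int16)

    byte_stream = torchac.encode_float_cdf(cdf_float, sym)
    assert torchac.decode_float_cdf(cdf_float, byte_stream).equal(sym)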
--------------------------------------------------------------------------------
/torchac/backend/torchac_backend.cpp:
--------------------------------------------------------------------------------
1 | /**
2 | * COPYRIGHT 2020 ETH Zurich
3 | * BASED on
4 | *
5 | * https://marknelson.us/posts/2014/10/19/data-compression-with-arithmetic-coding.html
6 | */
7 |
8 | #include <torch/extension.h>
9 | 
10 | #include <algorithm>
11 | #include <cassert>
12 | #include <chrono>
13 | #include <cstdint>
14 | #include <fstream>
15 | #include <iostream>
16 | #include <numeric>
17 | #include <string>
18 | #include <vector>
19 | 
20 | #include <bitset>
21 |
22 |
23 | using cdf_t = uint16_t;
24 |
25 | /** Encapsulates a pointer to a CDF tensor */
26 | struct cdf_ptr {
27 | const cdf_t* data; // expected to be a N_sym x Lp matrix, stored in row major.
28 | const int N_sym; // Number of symbols stored by `data`.
29 | const int Lp; // == L+1, where L is the number of possible values a symbol can take.
30 | cdf_ptr(const cdf_t* data,
31 | const int N_sym,
32 | const int Lp) : data(data), N_sym(N_sym), Lp(Lp) {};
33 | };
34 |
35 | /** Class to save output bit by bit to a byte string */
36 | class OutCacheString {
37 | private:
38 | public:
39 | std::string out="";
40 | uint8_t cache=0;
41 | uint8_t count=0;
42 | void append(const int bit) {
43 | cache <<= 1;
44 | cache |= bit;
45 | count += 1;
46 | if (count == 8) {
47 |       out.append(reinterpret_cast<const char*>(&cache), 1);
48 | count = 0;
49 | }
50 | }
51 | void flush() {
52 | if (count > 0) {
53 | for (int i = count; i < 8; ++i) {
54 | append(0);
55 | }
56 | assert(count==0);
57 | }
58 | }
59 | void append_bit_and_pending(const int bit, uint64_t &pending_bits) {
60 | append(bit);
61 | while (pending_bits > 0) {
62 | append(!bit);
63 | pending_bits -= 1;
64 | }
65 | }
66 | };
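The pending-bits mechanism above is the classic carry handling of arithmetic coding. As a rough Python model (illustration only, not part of the package):

    # Bits are packed MSB-first into a byte cache; `append_bit_and_pending`
    # emits a decision bit followed by the deferred bits, which are all the
    # opposite of the decision bit.
    class OutCache:
      def __init__(self):
        self.out = bytearray()
        self.cache = 0   # bits collected so far, MSB-first
        self.count = 0   # number of valid bits in `cache`

      def append(self, bit):
        self.cache = ((self.cache << 1) | bit) & 0xFF
        self.count += 1
        if self.count == 8:
          self.out.append(self.cache)
          self.count = 0

      def append_bit_and_pending(self, bit, pending_bits):
        self.append(bit)
        for _ in range(pending_bits):
          self.append(1 - bit)

      def flush(self):
        while self.count > 0:  # pad the last byte with zeros
          self.append(0)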
67 |
68 | /** Class to read byte string bit by bit */
69 | class InCacheString {
70 | private:
71 | const std::string& in_;
72 |
73 | public:
74 | explicit InCacheString(const std::string& in) : in_(in) {};
75 |
76 | uint8_t cache=0;
77 | uint8_t cached_bits=0;
78 | size_t in_ptr=0;
79 |
80 | void get(uint32_t& value) {
81 | if (cached_bits == 0) {
82 | if (in_ptr == in_.size()){
83 | value <<= 1;
84 | return;
85 | }
86 | /// Read 1 byte
87 | cache = (uint8_t) in_[in_ptr];
88 | in_ptr++;
89 | cached_bits = 8;
90 | }
91 | value <<= 1;
92 | value |= (cache >> (cached_bits - 1)) & 1;
93 | cached_bits--;
94 | }
95 |
96 | void initialize(uint32_t& value) {
97 | for (int i = 0; i < 32; ++i) {
98 | get(value);
99 | }
100 | }
101 | };
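Symmetrically, a Python model of the reader (illustration only): bits are consumed MSB-first, and once the stream is exhausted, zeros are shifted in, which is all the decoder needs to finish the remaining symbols:

    class InCache:
      def __init__(self, data: bytes):
        self.data = data
        self.cache = 0
        self.cached_bits = 0
        self.ptr = 0

      def get(self, value):
        # Shift the next bit of the stream into the 32-bit `value`.
        if self.cached_bits == 0:
          if self.ptr == len(self.data):
            return (value << 1) & 0xFFFFFFFF  # past the end: shift in 0
          self.cache = self.data[self.ptr]
          self.ptr += 1
          self.cached_bits = 8
        self.cached_bits -= 1
        bit = (self.cache >> self.cached_bits) & 1
        return ((value << 1) | bit) & 0xFFFFFFFF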
102 |
103 | void check_sym(const torch::Tensor& sym) {
104 | TORCH_CHECK(sym.sizes().size() == 1,
105 | "Invalid size for sym. Expected just 1 dim.")
106 | }
107 |
108 | /** Get an instance of the `cdf_ptr` struct. */
109 | const struct cdf_ptr get_cdf_ptr(const torch::Tensor& cdf)
110 | {
111 | TORCH_CHECK(!cdf.is_cuda(), "cdf must be on CPU!")
112 | const auto s = cdf.sizes();
113 | TORCH_CHECK(s.size() == 2, "Invalid size for cdf! Expected (N, Lp)")
114 |
115 | const int N_sym = s[0];
116 | const int Lp = s[1];
117 |   const auto cdf_acc = cdf.accessor<int16_t, 2>();
118 | const cdf_t* cdf_ptr = (uint16_t*)cdf_acc.data();
119 |
120 | const struct cdf_ptr res(cdf_ptr, N_sym, Lp);
121 | return res;
122 | }
123 |
124 |
125 | // -----------------------------------------------------------------------------
126 |
127 |
128 | /** Encode symbols `sym` with CDF represented by `cdf_ptr`. NOTE: this is not exposed to Python. */
129 | py::bytes encode(
130 | const cdf_ptr& cdf_ptr,
131 | const torch::Tensor& sym){
132 |
133 | OutCacheString out_cache;
134 |
135 | uint32_t low = 0;
136 | uint32_t high = 0xFFFFFFFFU;
137 | uint64_t pending_bits = 0;
138 |
139 | const int precision = 16;
140 |
141 | const cdf_t* cdf = cdf_ptr.data;
142 | const int N_sym = cdf_ptr.N_sym;
143 | const int Lp = cdf_ptr.Lp;
144 | const int max_symbol = Lp - 2;
145 |
146 |   auto sym_ = sym.accessor<int16_t, 1>();
147 |
148 | for (int i = 0; i < N_sym; ++i) {
149 | const int16_t sym_i = sym_[i];
150 |
151 |     const uint64_t span = static_cast<uint64_t>(high) - static_cast<uint64_t>(low) + 1;
152 |
153 | const int offset = i * Lp;
154 | // Left boundary is at offset + sym_i
155 | const uint32_t c_low = cdf[offset + sym_i];
156 |     // Right boundary is at offset + sym_i + 1, except for the
157 |     // `max_symbol`, for which we hardcode the max value. So if e.g.
158 |     // L == 4, it means that Lp == 5, and the allowed symbols are
159 |     // {0, 1, 2, 3}. The max symbol is thus Lp - 2 == 3. Its probability
160 |     // is then given by c_max - cdf[-2].
161 | const uint32_t c_high = sym_i == max_symbol ? 0x10000U : cdf[offset + sym_i + 1];
162 |
163 |     high = (low - 1) + ((span * static_cast<uint64_t>(c_high)) >> precision);
164 |     low = (low) + ((span * static_cast<uint64_t>(c_low)) >> precision);
165 |
166 | while (true) {
167 | if (high < 0x80000000U) {
168 | out_cache.append_bit_and_pending(0, pending_bits);
169 | low <<= 1;
170 | high <<= 1;
171 | high |= 1;
172 | } else if (low >= 0x80000000U) {
173 | out_cache.append_bit_and_pending(1, pending_bits);
174 | low <<= 1;
175 | high <<= 1;
176 | high |= 1;
177 | } else if (low >= 0x40000000U && high < 0xC0000000U) {
178 | pending_bits++;
179 | low <<= 1;
180 | low &= 0x7FFFFFFF;
181 | high <<= 1;
182 | high |= 0x80000001;
183 | } else {
184 | break;
185 | }
186 | }
187 | }
188 |
189 | pending_bits += 1;
190 |
191 | if (pending_bits) {
192 | if (low < 0x40000000U) {
193 | out_cache.append_bit_and_pending(0, pending_bits);
194 | } else {
195 | out_cache.append_bit_and_pending(1, pending_bits);
196 | }
197 | }
198 |
199 | out_cache.flush();
200 |
201 | #ifdef VERBOSE
202 |   std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
203 |   std::cout << "Time difference (sec) = " << (std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count()) / 1000000.0 << std::endl;
204 | #endif
205 | 
206 |   return py::bytes(out_cache.out);
207 | }
208 | 
209 | 
210 | /** Encode symbols `sym` with the int16 CDF `cdf`. This is the function
211 |  *  that is exposed to Python via the pybind11 module at the bottom. */
212 | py::bytes encode_cdf(
213 |     const torch::Tensor& cdf, /* N_sym x Lp, int16, must be on CPU! */
214 |     const torch::Tensor& sym)
215 | {
216 |   check_sym(sym);
217 |   const auto cdf_ptr = get_cdf_ptr(cdf);
218 |   return encode(cdf_ptr, sym);
219 | }
220 | 
221 | 
222 | /** Binary search over the CDF row starting at `offset`: return the
223 |  *  largest symbol `s` in [0, max_sym] with cdf[offset + s] <= target. */
224 | cdf_t binsearch(const cdf_t* cdf, cdf_t target, cdf_t max_sym,
225 |                 const int offset /* == i * Lp */)
226 | {
227 |   cdf_t left = 0;
228 |   cdf_t right = max_sym + 1;  // A row has Lp == max_sym + 2 entries.
229 | 
230 |   while (left + 1 < right) {
231 |     // left and right stay below 0x10000, so (left + right) cannot
232 |     // overflow before the cast back to cdf_t.
233 |     const auto m = static_cast<cdf_t>((left + right) / 2);
234 | const auto v = cdf[offset + m];
235 | if (v < target) {
236 | left = m;
237 | } else if (v > target) {
238 | right = m;
239 | } else {
240 | return m;
241 | }
242 | }
243 |
244 | return left;
245 | }
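In Python terms, the search above amounts to the following (illustration only): return the largest symbol s in [0, max_sym] with cdf_row[s] <= target, where cdf_row is non-decreasing and has max_sym + 2 entries:

    def binsearch(cdf_row, target, max_sym):
      left, right = 0, max_sym + 1
      while left + 1 < right:
        m = (left + right) // 2
        if cdf_row[m] < target:
          left = m
        elif cdf_row[m] > target:
          right = m
        else:
          return m
      return left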
246 |
247 |
248 | torch::Tensor decode(
249 | const cdf_ptr& cdf_ptr,
250 | const std::string& in) {
251 |
252 | #ifdef VERBOSE
253 | std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now();
254 | #endif
255 |
256 | const cdf_t* cdf = cdf_ptr.data;
257 |   const int N_sym = cdf_ptr.N_sym;  // Number of symbols to decode; comes from the CDF shape, not from the byte stream.
258 | const int Lp = cdf_ptr.Lp; // To calculate offset
259 | const int max_symbol = Lp - 2;
260 |
261 | // 16 bit!
262 | auto out = torch::empty({N_sym}, torch::kShort);
263 |   auto out_ = out.accessor<int16_t, 1>();
264 |
265 | uint32_t low = 0;
266 | uint32_t high = 0xFFFFFFFFU;
267 | uint32_t value = 0;
268 | const uint32_t c_count = 0x10000U;
269 | const int precision = 16;
270 |
271 | InCacheString in_cache(in);
272 | in_cache.initialize(value);
273 |
274 | for (int i = 0; i < N_sym; ++i) {
275 |     const uint64_t span = static_cast<uint64_t>(high) - static_cast<uint64_t>(low) + 1;
276 |     // count < 0x10000 == c_count, since (value - low + 1) <= span.
277 |     const uint16_t count = ((static_cast<uint64_t>(value) - static_cast<uint64_t>(low) + 1) * c_count - 1) / span;
278 |
279 | const int offset = i * Lp;
280 | auto sym_i = binsearch(cdf, count, (cdf_t)max_symbol, offset);
281 |
282 | out_[i] = (int16_t)sym_i;
283 |
284 | if (i == N_sym-1) {
285 | break;
286 | }
287 |
288 | const uint32_t c_low = cdf[offset + sym_i];
289 | const uint32_t c_high = sym_i == max_symbol ? 0x10000U : cdf[offset + sym_i + 1];
290 |
291 |     high = (low - 1) + ((span * static_cast<uint64_t>(c_high)) >> precision);
292 |     low = (low) + ((span * static_cast<uint64_t>(c_low)) >> precision);
293 |
294 | while (true) {
295 | if (low >= 0x80000000U || high < 0x80000000U) {
296 | low <<= 1;
297 | high <<= 1;
298 | high |= 1;
299 | in_cache.get(value);
300 | } else if (low >= 0x40000000U && high < 0xC0000000U) {
301 |       /**
302 |        * 0100 0000 ... <= value < 1100 0000 ...
303 |        * <=>
304 |        * 0100 0000 ... <= value <= 1011 1111 ...
305 |        * <=>
306 |        * value starts with 01 or 10.
307 |        * 01 - 01 == 00 | 10 - 01 == 01
308 |        * i.e., with shifts,
309 |        * 01A -> 0A or 10A -> 1A: we discard the second-most-significant bit,
310 |        * as it is the same for low, high, and value while we are in
311 |        * near convergence.
312 |        */
312 | low <<= 1;
313 | low &= 0x7FFFFFFFU; // make MSB 0
314 | high <<= 1;
315 | high |= 0x80000001U; // add 1 at the end, retain MSB = 1
316 | value -= 0x40000000U;
317 | in_cache.get(value);
318 | } else {
319 | break;
320 | }
321 | }
322 | }
323 |
324 | #ifdef VERBOSE
325 |   std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
326 |   std::cout << "Time difference (sec) = " << (std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count()) / 1000000.0 << std::endl;
327 | #endif
328 | 
329 |   return out;
330 | }
331 | 
332 | 
333 | /** Decode symbols from the byte string `in`, given the int16 CDF `cdf`.
334 |  *  Exposed to Python via the pybind11 module below. */
335 | torch::Tensor decode_cdf(
336 |     const torch::Tensor& cdf, /* N_sym x Lp, int16, must be on CPU! */
337 |     const std::string& in)
338 | {
339 |   const auto cdf_ptr = get_cdf_ptr(cdf);
340 |   return decode(cdf_ptr, in);
341 | }
342 | 
343 | 
344 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
345 |   m.def("encode_cdf", &encode_cdf, "Encode from CDF");
346 |   m.def("decode_cdf", &decode_cdf, "Decode from CDF");
347 | }
--------------------------------------------------------------------------------
/torchac/torchac.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import numpy as np
4 | import torch
5 | from torch.utils.cpp_extension import load
6 | 
7 | # Compile and load the C++ backend (torchac/backend/torchac_backend.cpp)
8 | # via torch's JIT cpp_extension machinery on first import.
9 | _torchac_dir = os.path.dirname(os.path.realpath(__file__))
10 | torchac_backend = load(
11 |     name='torchac_backend',
12 |     sources=[os.path.join(_torchac_dir, 'backend', 'torchac_backend.cpp')])
13 | 
14 | 
15 | # The coder uses 16 bits of precision: CDFs become integers in [0, 2**16).
16 | PRECISION = 16
17 | 
18 | 
19 | def encode_float_cdf(cdf_float,
20 |                      sym,
21 |                      needs_normalization=True,
22 |                      check_input_bounds=False):
23 |   """Encode symbols `sym` with potentially unnormalized float CDF.
24 | 
25 |   Check the README for more details.
26 | 
27 |   :param cdf_float: CDF tensor, float32, on CPU. Shape (N1, ..., Nm, Lp),
28 |     where Lp == L + 1 and L is the number of possible symbol values.
29 |   :param sym: The symbols to encode, int16, on CPU. Shape (N1, ..., Nm).
30 |   :param needs_normalization: if True, assume `cdf_float` is un-normalized and
31 |     needs normalization. Otherwise only convert it, without normalizing.
32 |   :param check_input_bounds: if True, check that the inputs are in the valid
33 |     ranges. This is slow and meant for debugging only.
34 |   :return: byte-string, encoding `sym`.
35 |   """
36 |   if check_input_bounds:
37 |     if cdf_float.min() < 0:
38 |       raise ValueError(f'cdf_float.min() == {cdf_float.min()}, should be >=0.!')
39 | if cdf_float.max() > 1:
40 | raise ValueError(f'cdf_float.max() == {cdf_float.max()}, should be <=1.!')
41 | Lp = cdf_float.shape[-1]
42 |     if sym.max() >= Lp - 1:
43 |       raise ValueError(f'sym.max() == {sym.max()}, must be <= Lp - 2 == {Lp - 2}!')
44 | cdf_int = _convert_to_int_and_normalize(cdf_float, needs_normalization)
45 | return encode_int16_normalized_cdf(cdf_int, sym)
46 |
47 |
48 | def decode_float_cdf(cdf_float, byte_stream, needs_normalization=True):
49 |   """Decode symbols in `byte_stream` with potentially unnormalized float CDF.
50 |
51 | Check the README for more details.
52 |
53 | :param cdf_float: CDF tensor, float32, on CPU. Shape (N1, ..., Nm, Lp).
54 | :param byte_stream: byte-stream, encoding some symbols `sym`.
55 | :param needs_normalization: if True, assume `cdf_float` is un-normalized and
56 | needs normalization. Otherwise only convert it, without normalizing.
57 |
58 | :return: decoded `sym` of shape (N1, ..., Nm).
59 | """
60 | cdf_int = _convert_to_int_and_normalize(cdf_float, needs_normalization)
61 | return decode_int16_normalized_cdf(cdf_int, byte_stream)
62 |
63 |
64 | def encode_int16_normalized_cdf(cdf_int, sym):
65 | """Encode symbols `sym` with a normalized integer cdf `cdf_int`.
66 |
67 | Check the README for more details.
68 |
69 | :param cdf_int: CDF tensor, int16, on CPU. Shape (N1, ..., Nm, Lp).
70 | :param sym: The symbols to encode, int16, on CPU. Shape (N1, ..., Nm).
71 |
72 | :return: byte-string, encoding `sym`
73 | """
74 | cdf_int, sym = _check_and_reshape_inputs(cdf_int, sym)
75 | return torchac_backend.encode_cdf(cdf_int, sym)
76 |
77 |
78 | def decode_int16_normalized_cdf(cdf_int, byte_stream):
79 | """Decode symbols in `byte_stream` with a normalized integer cdf `cdf_int`.
80 |
81 | Check the README for more details.
82 |
83 | :param cdf_int: CDF tensor, int16, on CPU. Shape (N1, ..., Nm, Lp).
84 | :param byte_stream: byte-stream, encoding some symbols `sym`.
85 |
86 | :return: decoded `sym` of shape (N1, ..., Nm).
87 | """
88 |   # Merge the m dimensions into one, as the backend expects flat input.
89 |   cdf_reshaped = _check_and_reshape_inputs(cdf_int)
90 |   sym = torchac_backend.decode_cdf(cdf_reshaped, byte_stream)
91 |   return _reshape_output(cdf_int.shape, sym)
92 |
93 |
94 | def _check_and_reshape_inputs(cdf, sym=None):
95 | """Check device, dtype, and shapes."""
96 | if cdf.is_cuda:
97 | raise ValueError('CDF must be on CPU')
98 | if sym is not None and sym.is_cuda:
99 | raise ValueError('Symbols must be on CPU')
100 | if sym is not None and sym.dtype != torch.int16:
101 | raise ValueError('Symbols must be int16!')
102 | if sym is not None:
103 | if len(cdf.shape) != len(sym.shape) + 1 or cdf.shape[:-1] != sym.shape:
104 | raise ValueError(f'Invalid shapes of cdf={cdf.shape}, sym={sym.shape}! '
105 | 'The first m elements of cdf.shape must be equal to '
106 | 'sym.shape, and cdf should only have one more dimension.')
107 | Lp = cdf.shape[-1]
108 | cdf = cdf.reshape(-1, Lp)
109 | if sym is None:
110 | return cdf
111 | sym = sym.reshape(-1)
112 | return cdf, sym
113 |
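Concretely, mirroring the multi-dimensional test in tests/test.py: a (N1, N2, N3, Lp) CDF is flattened to (N1*N2*N3, Lp) and the symbols to (N1*N2*N3,); `_reshape_output` below undoes this after decoding:

    import torch

    cdf = torch.zeros(5, 8, 9, 4, dtype=torch.int16)  # (N1, N2, N3, Lp)
    sym = torch.zeros(5, 8, 9, dtype=torch.int16)     # (N1, N2, N3)
    cdf_flat, sym_flat = _check_and_reshape_inputs(cdf, sym)
    assert cdf_flat.shape == (5 * 8 * 9, 4)
    assert sym_flat.shape == (5 * 8 * 9,)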
114 |
115 | def _reshape_output(cdf_shape, sym):
116 | """Reshape single dimension `sym` back to the correct spatial dimensions."""
117 | spatial_dimensions = cdf_shape[:-1]
118 |   if len(sym) != np.prod(spatial_dimensions):
119 |     raise ValueError(f'Expected {np.prod(spatial_dimensions)} symbols, got {len(sym)}!')
120 | return sym.reshape(*spatial_dimensions)
121 |
122 |
123 | def _convert_to_int_and_normalize(cdf_float, needs_normalization):
124 |   """Convert floating-point CDF to integers. See README for more info.
125 | 
126 |   The idea is the following:
127 |   When we get the cdf here, it is (assumed to be) between 0 and 1, i.e.,
128 |   cdf \in [0, 1)
129 |   (note that 1 should not be included.)
130 |   We now want to convert this to int16 but make sure we do not get
131 |   the same value twice, as this would break the arithmetic coder
132 |   (you need a strictly monotonically increasing function).
133 |   So, if needs_normalization==True, we multiply the input CDF
134 |   with 2**16 - (Lp - 1). This means that now,
135 |   cdf \in [0, 2**16 - (Lp - 1)].
136 |   Then, in a final step, we add an arange(Lp), which is just a line with
137 |   slope one. This ensures that we get unique, strictly monotonically
138 |   increasing CDFs, which are \in [0, 2**16).
139 |   """
140 | Lp = cdf_float.shape[-1]
141 | factor = torch.tensor(
142 | 2, dtype=torch.float32, device=cdf_float.device).pow_(PRECISION)
143 | new_max_value = factor
144 | if needs_normalization:
145 | new_max_value = new_max_value - (Lp - 1)
146 | cdf_float = cdf_float.mul(new_max_value)
147 | cdf_float = cdf_float.round()
148 | cdf = cdf_float.to(dtype=torch.int16, non_blocking=True)
149 | if needs_normalization:
150 | r = torch.arange(Lp, dtype=torch.int16, device=cdf.device)
151 | cdf.add_(r)
152 | return cdf
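A worked example of the conversion, matching the assertions in test_uniform_float: with needs_normalization=False, the uniform CDF is only scaled by 2**16 and rounded:

    import torch

    cdf_float = torch.tensor([0., 1/3, 2/3, 1.]).reshape(1, -1)
    cdf_int = _convert_to_int_and_normalize(cdf_float, needs_normalization=False)
    # round([0, 1/3, 2/3, 1] * 2**16) == [0, 21845, 43691, 65536]. Values of
    # 2**15 and above wrap around in int16 (65536 -> 0, 43691 -> -21845); the
    # C++ backend reinterprets the data as uint16, so nothing is lost.
    assert cdf_int[0, 1] == 2**16 // 3  # == 21845
    assert cdf_int[0, -1] == 0          # 65536, wrapped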
153 |
--------------------------------------------------------------------------------