├── README.md ├── compression.ipynb └── output.png /README.md: -------------------------------------------------------------------------------- 1 | # Compression using Arithmetic Encoding and Binary Search 2 | 3 | This repo showcases how to compress data given a known probability distribution over the data. 4 | 5 | ![Compression Ratio vs Advantage](./output.png) -------------------------------------------------------------------------------- /compression.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 7, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "compress uniform\n", 13 | "compress with prior\n", 14 | "====================\n", 15 | "Advantage: 1\n", 16 | "Theoretical accuracy@uniform: 9.5367431640625e-07\n", 17 | "Accuracy@uniform: 0.0\n", 18 | "Accuracy@prior: 1.0\n", 19 | "#bits_uncompressed: 400\n", 20 | "#bits_compressed@uniform: 471\n", 21 | "#bits_compressed@prior: 422\n", 22 | "Compression ratio@uniform: -0.1775\n", 23 | "Compression ratio@prior: -0.05499999999999994\n", 24 | "====================\n", 25 | "Advantage: 10\n", 26 | "Theoretical accuracy@uniform: 9.5367431640625e-07\n", 27 | "Accuracy@uniform: 0.0\n", 28 | "Accuracy@prior: 1.0\n", 29 | "#bits_uncompressed: 400\n", 30 | "#bits_compressed@uniform: 471\n", 31 | "#bits_compressed@prior: 372\n", 32 | "Compression ratio@uniform: -0.1775\n", 33 | "Compression ratio@prior: 0.06999999999999995\n", 34 | "====================\n", 35 | "Advantage: 100\n", 36 | "Theoretical accuracy@uniform: 9.5367431640625e-07\n", 37 | "Accuracy@uniform: 0.0\n", 38 | "Accuracy@prior: 1.0\n", 39 | "#bits_uncompressed: 400\n", 40 | "#bits_compressed@uniform: 471\n", 41 | "#bits_compressed@prior: 304\n", 42 | "Compression ratio@uniform: -0.1775\n", 43 | "Compression ratio@prior: 0.24\n", 44 | "====================\n", 45 | "Advantage: 
1000\n", 46 | "Theoretical accuracy@uniform: 9.5367431640625e-07\n", 47 | "Accuracy@uniform: 0.0\n", 48 | "Accuracy@prior: 1.0\n", 49 | "#bits_uncompressed: 400\n", 50 | "#bits_compressed@uniform: 471\n", 51 | "#bits_compressed@prior: 244\n", 52 | "Compression ratio@uniform: -0.1775\n", 53 | "Compression ratio@prior: 0.39\n", 54 | "====================\n", 55 | "Advantage: 10000\n", 56 | "Theoretical accuracy@uniform: 9.5367431640625e-07\n", 57 | "Accuracy@uniform: 0.0\n", 58 | "Accuracy@prior: 1.0\n", 59 | "#bits_uncompressed: 400\n", 60 | "#bits_compressed@uniform: 471\n", 61 | "#bits_compressed@prior: 173\n", 62 | "Compression ratio@uniform: -0.1775\n", 63 | "Compression ratio@prior: 0.5675\n", 64 | "====================\n", 65 | "Advantage: 100000\n", 66 | "Theoretical accuracy@uniform: 9.5367431640625e-07\n", 67 | "Accuracy@uniform: 0.0\n", 68 | "Accuracy@prior: 1.0\n", 69 | "#bits_uncompressed: 400\n", 70 | "#bits_compressed@uniform: 471\n", 71 | "#bits_compressed@prior: 129\n", 72 | "Compression ratio@uniform: -0.1775\n", 73 | "Compression ratio@prior: 0.6775\n", 74 | "====================\n", 75 | "Advantage: 1000000\n", 76 | "Theoretical accuracy@uniform: 9.5367431640625e-07\n", 77 | "Accuracy@uniform: 0.0\n", 78 | "Accuracy@prior: 1.0\n", 79 | "#bits_uncompressed: 400\n", 80 | "#bits_compressed@uniform: 471\n", 81 | "#bits_compressed@prior: 100\n", 82 | "Compression ratio@uniform: -0.1775\n", 83 | "Compression ratio@prior: 0.75\n", 84 | "====================\n", 85 | "Advantage: 10000000\n", 86 | "Theoretical accuracy@uniform: 9.5367431640625e-07\n", 87 | "Accuracy@uniform: 0.0\n", 88 | "Accuracy@prior: 1.0\n", 89 | "#bits_uncompressed: 400\n", 90 | "#bits_compressed@uniform: 471\n", 91 | "#bits_compressed@prior: 100\n", 92 | "Compression ratio@uniform: -0.1775\n", 93 | "Compression ratio@prior: 0.75\n" 94 | ] 95 | }, 96 | { 97 | "data": { 98 | "text/plain": [ 99 | "" 100 | ] 101 | }, 102 | "execution_count": 7, 103 | "metadata": {}, 104 | "output_type": 
"execute_result" 105 | }, 106 | { 107 | "data": { 108 | "image/png": "", 109 | "text/plain": [ 110 | "
" 111 | ] 112 | }, 113 | "metadata": {}, 114 | "output_type": "display_data" 115 | } 116 | ], 117 | "source": [ 118 | "import numpy as np\n", 119 | "import matplotlib.pyplot as plt\n", 120 | "from typing import List\n", 121 | "import seaborn as sns\n", 122 | "\n", 123 | "\n", 124 | "def gen_proba_uniform(rng: np.random.RandomState, n: int, m: int):\n", 125 | " \"\"\"\n", 126 | " :param n: number of tokens\n", 127 | " :param m: codebook size\n", 128 | " \"\"\"\n", 129 | " proba_uniform = rng.uniform(size=(n, m))\n", 130 | " proba_uniform /= proba_uniform.sum(axis=1, keepdims=True)\n", 131 | " return proba_uniform\n", 132 | "\n", 133 | "\n", 134 | "def gen_proba_with_prior(\n", 135 | " rng: np.random.RandomState, tokens: List[int], m: int, advantage: float = 0.5\n", 136 | "):\n", 137 | " \"\"\"\n", 138 | " :param tokens: list of tokens\n", 139 | " :param m: codebook size\n", 140 | " :param advantage: advantage of the prior\n", 141 | " \"\"\"\n", 142 | " n = len(tokens)\n", 143 | " proba_with_prior = rng.uniform(size=(n, m))\n", 144 | " proba_with_prior[np.arange(n), tokens] += advantage\n", 145 | " proba_with_prior /= proba_with_prior.sum(axis=1, keepdims=True)\n", 146 | " return proba_with_prior\n", 147 | "\n", 148 | "\n", 149 | "# Compresssion & Decompression Algorithm\n", 150 | "def encode(proba: List[int], k: int):\n", 151 | " \"\"\"\n", 152 | " Encode data k using binary search\n", 153 | " \"\"\"\n", 154 | "\n", 155 | " pcsum = proba.cumsum()\n", 156 | " assert np.allclose(pcsum[-1], 1)\n", 157 | "\n", 158 | " left, right = 0.0, 1.0\n", 159 | " bisection_sequence = []\n", 160 | " # print(\"target: \", pcsum[k])\n", 161 | " while True:\n", 162 | " mid = (left + right) * 0.5\n", 163 | " # l = np.where(pcsum > mid)\n", 164 | " # a faster way to find the index of the first element that is greater than mid\n", 165 | " idx = np.searchsorted(pcsum, mid, side=\"right\")\n", 166 | "\n", 167 | " # print(\"=\" * 20)\n", 168 | " # print(\"pcsum:\", pcsum)\n", 169 | " # 
print(\"left, right, mid:\", left, right, mid)\n", 170 | " # print(\"bisect idx:\", idx)\n", 171 | " # print(\"k:\", k)\n", 172 | " # print(\"pcsum[idx]:\", pcsum[idx])\n", 173 | "\n", 174 | " if idx < k:\n", 175 | " left = mid\n", 176 | " bisection_sequence.append(1)\n", 177 | " elif idx > k:\n", 178 | " right = mid\n", 179 | " bisection_sequence.append(0)\n", 180 | " else:\n", 181 | " break\n", 182 | " # print(\"bisection_sequence:\", bisection_sequence)\n", 183 | "\n", 184 | " return bisection_sequence\n", 185 | "\n", 186 | "\n", 187 | "def decode(proba: List[int], bisection_sequence: List[int]):\n", 188 | " \"\"\"\n", 189 | " Decode data using binary search\n", 190 | " \"\"\"\n", 191 | "\n", 192 | " pcsum = proba.cumsum()\n", 193 | " assert np.allclose(pcsum[-1], 1)\n", 194 | "\n", 195 | " left, right = 0.0, 1.0\n", 196 | " for bit in bisection_sequence:\n", 197 | " mid = (left + right) * 0.5\n", 198 | " # print(\"=\" * 20)\n", 199 | " # print(\"left, right, mid:\", left, right, mid)\n", 200 | " if bit == 0:\n", 201 | " right = mid\n", 202 | " else:\n", 203 | " left = mid\n", 204 | "\n", 205 | " mid = (left + right) * 0.5\n", 206 | " idx = np.searchsorted(pcsum, mid, side=\"right\")\n", 207 | " return idx\n", 208 | "\n", 209 | "\n", 210 | "def do_compress_and_decompress(tokens: List[int], proba: List[List[float]]):\n", 211 | " \"\"\"\n", 212 | " Compress and decompress a sequence of tokens\n", 213 | " :param tokens: a sequence of tokens\n", 214 | " :param proba: a list of probability distributions of shape (n, m),\n", 215 | " where n is the number of tokens and m is the codebook size\n", 216 | " \"\"\"\n", 217 | "\n", 218 | " # to store \"the number of compressed bits\" of a token, we budget ceil(log2(log2(m))) bits\n", 219 | " # e.g., if codebook size m = 256, the length of the compressed sequence is assumed to be at most 8 bits,\n", 220 | " # so we budget 3 bits to represent the length of the compressed sequence\n", 221 | " num_bits_to_represent_num_compressed_bits = 
int(\n", 222 | " np.ceil(np.log2(np.log2(len(proba[0]))))\n", 223 | " )\n", 224 | "\n", 225 | " num_bits_compressed = 0\n", 226 | "\n", 227 | " for i, k0 in enumerate(tokens):\n", 228 | " bisection_sequence = encode(proba[i], k0)\n", 229 | " k1 = decode(proba[i], bisection_sequence)\n", 230 | "\n", 231 | " assert k0 == k1\n", 232 | "\n", 233 | " num_bits_compressed += num_bits_to_represent_num_compressed_bits\n", 234 | "\n", 235 | " num_bits_compressed += len(bisection_sequence)\n", 236 | "\n", 237 | " return num_bits_compressed\n", 238 | "\n", 239 | "\n", 240 | "def compress_algorithm_test(\n", 241 | " rng: np.random.RandomState,\n", 242 | " seqlen: int,\n", 243 | " bits_per_token: int,\n", 244 | " advantages: List[float],\n", 245 | "):\n", 246 | " \"\"\"\n", 247 | " :param rng: random number generator\n", 248 | " :param seqlen: sequence length\n", 249 | " :param bits_per_token: bits per token\n", 250 | " :param advantages: advantages of the prior\n", 251 | " \"\"\"\n", 252 | "\n", 253 | " # hyperparameters\n", 254 | " n = seqlen\n", 255 | " m = 2**bits_per_token\n", 256 | "\n", 257 | " # generate tokens\n", 258 | " tokens = rng.randint(m, size=n)\n", 259 | "\n", 260 | " # generate uniform probability\n", 261 | " proba_uniform = gen_proba_uniform(rng, n, m)\n", 262 | "\n", 263 | " print(\"compress uniform\")\n", 264 | " num_bits_uncompressed = n * bits_per_token\n", 265 | " num_bits_compressed_uniform = do_compress_and_decompress(tokens, proba_uniform)\n", 266 | "\n", 267 | " print(\"compress with prior\")\n", 268 | " compression_ratios_with_prior = []\n", 269 | " for advantage in advantages:\n", 270 | " # generate probability with prior\n", 271 | " proba_with_prior = gen_proba_with_prior(rng, tokens, m, advantage=advantage)\n", 272 | " print(\"=\" * 20)\n", 273 | " print(\"Advantage:\", advantage)\n", 274 | " print(\"Theoretical accuracy@uniform:\", 1 / m)\n", 275 | " print(\"Accuracy@uniform:\", np.mean(proba_uniform.argmax(axis=1) == tokens))\n", 276 | " 
print(\"Accuracy@prior:\", np.mean(proba_with_prior.argmax(axis=1) == tokens))\n", 277 | "\n", 278 | " num_bits_compressed_with_prior = do_compress_and_decompress(\n", 279 | " tokens, proba_with_prior\n", 280 | " )\n", 281 | "\n", 282 | " print(\"#bits_uncompressed:\", num_bits_uncompressed)\n", 283 | " print(\"#bits_compressed@uniform:\", num_bits_compressed_uniform)\n", 284 | " print(\"#bits_compressed@prior:\", num_bits_compressed_with_prior)\n", 285 | " compression_ratio_uniform = (\n", 286 | " 1 - num_bits_compressed_uniform / num_bits_uncompressed\n", 287 | " )\n", 288 | "\n", 289 | " print(\"Compression ratio@uniform:\", compression_ratio_uniform)\n", 290 | "\n", 291 | " cr_prior = 1 - num_bits_compressed_with_prior / num_bits_uncompressed\n", 292 | " compression_ratios_with_prior.append(cr_prior)\n", 293 | " print(\"Compression ratio@prior:\", cr_prior)\n", 294 | "\n", 295 | " return {\n", 296 | " 'cr_uniform': compression_ratio_uniform,\n", 297 | " 'cr_prior': compression_ratios_with_prior,\n", 298 | " }\n", 299 | "\n", 300 | "\n", 301 | "\n", 302 | "rng = np.random.RandomState(42)\n", 303 | "seqlen = 20\n", 304 | "bits_per_token = 20\n", 305 | "advantages = [10**i for i in range(8)]\n", 306 | "rst = compress_algorithm_test(rng, seqlen, bits_per_token, advantages)\n", 307 | "\n", 308 | "# plot the results\n", 309 | "sns.set()\n", 310 | "plt.figure(figsize=(8, 6))\n", 311 | "\n", 312 | "plt.plot(advantages, np.zeros_like(advantages), label='no compression')\n", 313 | "plt.plot(advantages, rst['cr_uniform'] * np.ones_like(advantages), label='uniform')\n", 314 | "plt.plot(advantages, rst['cr_prior'], label='prior')\n", 315 | "\n", 316 | "plt.xscale('log')\n", 317 | "plt.xlabel('advantage')\n", 318 | "plt.ylabel('compression ratio')\n", 319 | "\n", 320 | "plt.title('Compression Ratio vs. 
Advantage')\n", 321 | "\n", 322 | "plt.legend()\n" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": null, 328 | "metadata": {}, 329 | "outputs": [], 330 | "source": [] 331 | } 332 | ], 333 | "metadata": { 334 | "kernelspec": { 335 | "display_name": "Python 3", 336 | "language": "python", 337 | "name": "python3" 338 | }, 339 | "language_info": { 340 | "codemirror_mode": { 341 | "name": "ipython", 342 | "version": 3 343 | }, 344 | "file_extension": ".py", 345 | "mimetype": "text/x-python", 346 | "name": "python", 347 | "nbconvert_exporter": "python", 348 | "pygments_lexer": "ipython3", 349 | "version": "3.10.10" 350 | }, 351 | "orig_nbformat": 4, 352 | "vscode": { 353 | "interpreter": { 354 | "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" 355 | } 356 | } 357 | }, 358 | "nbformat": 4, 359 | "nbformat_minor": 2 360 | } 361 | -------------------------------------------------------------------------------- /output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxytim/arithmetic-encoding-compression/592633ad2baf59b56117186062caf750a98d2358/output.png --------------------------------------------------------------------------------