├── 0_baseline.ipynb ├── 0_some-concepts.ipynb ├── 1_1_nn_plus_gzip_original.ipynb ├── 1_2_caching-multiprocessing.py ├── 1_2_nn_plus_gzip_fix-tie-breaking.ipynb ├── 2_nn_countvecs.ipynb ├── 3_distilbert.ipynb ├── 4_r8-dataset.ipynb ├── LICENSE ├── README.md ├── figures ├── pseudocode.png └── r8.png └── local_dataset_utilities.py /0_baseline.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "f79f00e6-88bc-4d4d-930d-9f346eba5955", 6 | "metadata": {}, 7 | "source": [ 8 | "# Baseline accuracy" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "6e5e34ee-34b1-472d-8da5-09f50ca5a23e", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import gzip\n", 19 | "import os.path as op\n", 20 | "\n", 21 | "import numpy as np\n", 22 | "import pandas as pd\n", 23 | "\n", 24 | "from local_dataset_utilities import download_dataset, load_dataset_into_to_dataframe, partition_dataset" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 2, 30 | "id": "aceb0005-1dcd-4735-8cff-fc3b10baae4f", 31 | "metadata": {}, 32 | "outputs": [ 33 | { 34 | "name": "stderr", 35 | "output_type": "stream", 36 | "text": [ 37 | "100%|███████████████████████████████████| 50000/50000 [00:19<00:00, 2542.67it/s]\n" 38 | ] 39 | }, 40 | { 41 | "name": "stdout", 42 | "output_type": "stream", 43 | "text": [ 44 | "Class distribution:\n" 45 | ] 46 | } 47 | ], 48 | "source": [ 49 | "if not (op.isfile(\"train.csv\") and op.isfile(\"val.csv\") and op.isfile(\"test.csv\")):\n", 50 | " download_dataset()\n", 51 | "\n", 52 | " df = load_dataset_into_to_dataframe()\n", 53 | " partition_dataset(df)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 3, 59 | "id": "47535727-bbc5-44ba-ae42-bcd34781adcb", 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "df_train = pd.read_csv(\"train.csv\")\n", 64 | "df_val = pd.read_csv(\"val.csv\")\n", 65 | "df_test = pd.read_csv(\"test.csv\")" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 4, 71 | "id": "0bde64f1-a5d0-4269-a9dc-eaa6d2159872", 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "data": { 76 | "text/plain": [ 77 | "(35000, 3)" 78 | ] 79 | }, 80 | "execution_count": 4, 81 | "metadata": {}, 82 | "output_type": "execute_result" 83 | } 84 | ], 85 | "source": [ 86 | "df_train.shape" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 5, 92 | "id": "8697ead4-4986-45b8-abe0-1fab35afc4ca", 93 | "metadata": {}, 94 | "outputs": [ 95 | { 96 | "data": { 97 | "text/plain": [ 98 | "(10000, 3)" 99 | ] 100 | }, 101 | "execution_count": 5, 102 | "metadata": {}, 103 | "output_type": "execute_result" 104 | } 105 | ], 106 | "source": [ 107 | "df_test.shape" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 6, 113 | "id": "db443420-3875-4b24-adaa-9549aa98a536", 114 | "metadata": {}, 115 | "outputs": [ 116 | { 117 | "data": { 118 | "text/plain": [ 119 | "array([5006, 4994])" 120 | ] 121 | }, 122 | "execution_count": 6, 123 | "metadata": {}, 124 | "output_type": "execute_result" 125 | } 126 | ], 127 | "source": [ 128 | "bcnt = np.bincount(df_test[\"label\"].values)\n", 129 | "bcnt" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 7, 135 | "id": "3703d434-c14d-4562-bbe4-bd3841791235", 136 | "metadata": {}, 137 | "outputs": [ 138 | { 139 | "name": "stdout", 140 | "output_type": "stream", 141 | "text": [ 142 | "Baseline 
accuracy: 0.5006\n" 143 | ] 144 | } 145 | ], 146 | "source": [ 147 | "print(\"Baseline accuracy:\", np.max(bcnt)/ bcnt.sum())" 148 | ] 149 | } 150 | ], 151 | "metadata": { 152 | "kernelspec": { 153 | "display_name": "Python 3 (ipykernel)", 154 | "language": "python", 155 | "name": "python3" 156 | }, 157 | "language_info": { 158 | "codemirror_mode": { 159 | "name": "ipython", 160 | "version": 3 161 | }, 162 | "file_extension": ".py", 163 | "mimetype": "text/x-python", 164 | "name": "python", 165 | "nbconvert_exporter": "python", 166 | "pygments_lexer": "ipython3", 167 | "version": "3.10.6" 168 | } 169 | }, 170 | "nbformat": 4, 171 | "nbformat_minor": 5 172 | } 173 | -------------------------------------------------------------------------------- /0_some-concepts.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "33075d8c-6c74-46bd-a19e-89f75262ff72", 6 | "metadata": {}, 7 | "source": [ 8 | "# Compression" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "b3cf065e-0871-46dc-917a-637ac6590b31", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import gzip\n", 19 | "\n", 20 | "txt_1 = \"hello world\"\n", 21 | "txt_2 = \"some text some text some text\"" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "id": "0ddb828f-2480-491d-ae66-c1a2dc902523", 28 | "metadata": {}, 29 | "outputs": [ 30 | { 31 | "data": { 32 | "text/plain": [ 33 | "31" 34 | ] 35 | }, 36 | "execution_count": 2, 37 | "metadata": {}, 38 | "output_type": "execute_result" 39 | } 40 | ], 41 | "source": [ 42 | "len(gzip.compress(txt_1.encode()))" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 3, 48 | "id": "fac22ce9-a243-433a-81ed-0bad131ec7ae", 49 | "metadata": {}, 50 | "outputs": [ 51 | { 52 | "data": { 53 | "text/plain": [ 54 | "33" 55 | ] 56 | }, 57 | "execution_count": 3, 58 | "metadata": {}, 59 | "output_type": "execute_result" 60 | } 61 | ], 62 | "source": [ 63 | "len(gzip.compress(txt_2.encode()))" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 4, 69 | "id": "c99131b6-9ef1-45de-af26-c47f525956d9", 70 | "metadata": {}, 71 | "outputs": [ 72 | { 73 | "data": { 74 | "text/plain": [ 75 | "43" 76 | ] 77 | }, 78 | "execution_count": 4, 79 | "metadata": {}, 80 | "output_type": "execute_result" 81 | } 82 | ], 83 | "source": [ 84 | "len(gzip.compress(\" \".join([txt_1, txt_2]).encode()))" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 5, 90 | "id": "771a37e0-67fe-4274-962f-5a99911c451a", 91 | "metadata": {}, 92 | "outputs": [ 93 | { 94 | "data": { 95 | "text/plain": [ 96 | "34" 97 | ] 98 | }, 99 | "execution_count": 5, 100 | "metadata": {}, 101 | "output_type": "execute_result" 102 | } 103 | ], 104 | "source": [ 105 | "len(gzip.compress(\" \".join([txt_1, txt_1]).encode()))" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 6, 111 | "id": "36303cb5-147c-41b0-9758-fb16753774ff", 112 | "metadata": {}, 113 | "outputs": [ 114 | { 115 | "data": { 116 | "text/plain": [ 117 | "33" 118 | ] 119 | }, 120 | "execution_count": 6, 121 | "metadata": {}, 122 | "output_type": "execute_result" 123 | } 124 | ], 125 | "source": [ 126 | "len(gzip.compress(\" \".join([txt_2, txt_2]).encode()))" 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "id": "b9327048-17ef-4401-b3fd-f462f1e57888", 132 | "metadata": {}, 133 | "source": [ 134 | "# Tie breaking" 135 | ] 136 
| },
137 |   {
138 |    "cell_type": "markdown",
139 |    "id": "71b93107-32b8-4b14-b7b1-d98d6049f154",
140 |    "metadata": {},
141 |    "source": [
142 |     "The original code always selects the lowest class label in case of a tie: when the counts are tied, `max` returns the first element it encounters while iterating over the set, which for small integer labels is the lowest one:"
143 |    ]
144 |   },
145 |   {
146 |    "cell_type": "code",
147 |    "execution_count": 7,
148 |    "id": "97b3e05d-27da-4c77-b2d1-e97f28b11d15",
149 |    "metadata": {},
150 |    "outputs": [
151 |     {
152 |      "data": {
153 |       "text/plain": [
154 |        "0"
155 |       ]
156 |      },
157 |      "execution_count": 7,
158 |      "metadata": {},
159 |      "output_type": "execute_result"
160 |     }
161 |    ],
162 |    "source": [
163 |     "top_k_class = [0, 1]\n",
164 |     "max(set(top_k_class), key=top_k_class.count)"
165 |    ]
166 |   },
167 |   {
168 |    "cell_type": "code",
169 |    "execution_count": 8,
170 |    "id": "bbeba7c4-492e-438d-960d-f887fd4ac455",
171 |    "metadata": {},
172 |    "outputs": [
173 |     {
174 |      "data": {
175 |       "text/plain": [
176 |        "0"
177 |       ]
178 |      },
179 |      "execution_count": 8,
180 |      "metadata": {},
181 |      "output_type": "execute_result"
182 |     }
183 |    ],
184 |    "source": [
185 |     "top_k_class = [1, 0]\n",
186 |     "max(set(top_k_class), key=top_k_class.count)"
187 |    ]
188 |   },
189 |   {
190 |    "cell_type": "code",
191 |    "execution_count": 9,
192 |    "id": "c8acd613-1f2a-4759-955b-800513024d1d",
193 |    "metadata": {},
194 |    "outputs": [
195 |     {
196 |      "data": {
197 |       "text/plain": [
198 |        "0"
199 |       ]
200 |      },
201 |      "execution_count": 9,
202 |      "metadata": {},
203 |      "output_type": "execute_result"
204 |     }
205 |    ],
206 |    "source": [
207 |     "top_k_class = [1, 0, 2]\n",
208 |     "max(set(top_k_class), key=top_k_class.count)"
209 |    ]
210 |   },
211 |   {
212 |    "cell_type": "markdown",
213 |    "id": "5a9f3761-e957-4f5e-acc9-aa13df736cf2",
214 |    "metadata": {},
215 |    "source": [
216 |     "We can prevent this using `Counter`, which selects the first label in case of a tie. 
If labels are sorted by distance, we can ensure it's picking the closest neighbor in case of a tie, which is a more reasonable choice than always selecting the lowest-index class:" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": 10, 222 | "id": "1bcc6ca8-1acb-48fa-9526-f470ea7da06f", 223 | "metadata": {}, 224 | "outputs": [], 225 | "source": [ 226 | "from collections import Counter" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": 11, 232 | "id": "b40df133-47c8-4006-80a8-217e717d95d6", 233 | "metadata": {}, 234 | "outputs": [ 235 | { 236 | "data": { 237 | "text/plain": [ 238 | "0" 239 | ] 240 | }, 241 | "execution_count": 11, 242 | "metadata": {}, 243 | "output_type": "execute_result" 244 | } 245 | ], 246 | "source": [ 247 | "top_k_class = [0, 1]\n", 248 | "\n", 249 | "Counter(top_k_class).most_common()[0][0]" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": 12, 255 | "id": "43e250e8-36db-4f9b-952f-35e4faf793c2", 256 | "metadata": {}, 257 | "outputs": [ 258 | { 259 | "data": { 260 | "text/plain": [ 261 | "1" 262 | ] 263 | }, 264 | "execution_count": 12, 265 | "metadata": {}, 266 | "output_type": "execute_result" 267 | } 268 | ], 269 | "source": [ 270 | "top_k_class = [1, 0]\n", 271 | "\n", 272 | "Counter(top_k_class).most_common()[0][0]" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": 16, 278 | "id": "1a2ccec3-a2db-40b6-bddc-ff2cbe721b5c", 279 | "metadata": {}, 280 | "outputs": [ 281 | { 282 | "data": { 283 | "text/plain": [ 284 | "1" 285 | ] 286 | }, 287 | "execution_count": 16, 288 | "metadata": {}, 289 | "output_type": "execute_result" 290 | } 291 | ], 292 | "source": [ 293 | "top_k_class = [1, 2, 0]\n", 294 | "\n", 295 | "Counter(top_k_class).most_common()[0][0]" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 14, 301 | "id": "141f5817-2157-454c-82b7-b5024b2f4018", 302 | "metadata": {}, 303 | "outputs": [ 304 | { 305 | "data": { 306 | "text/plain": [ 307 | "2" 308 | ] 309 | }, 310 | "execution_count": 14, 311 | "metadata": {}, 312 | "output_type": "execute_result" 313 | } 314 | ], 315 | "source": [ 316 | "top_k_class = [1, 0, 2, 2]\n", 317 | "\n", 318 | "Counter(top_k_class).most_common()[0][0]" 319 | ] 320 | }, 321 | { 322 | "cell_type": "markdown", 323 | "id": "e50515d5-b4cd-4c56-9a2c-9b0416367859", 324 | "metadata": {}, 325 | "source": [ 326 | "### Count vectors" 327 | ] 328 | }, 329 | { 330 | "cell_type": "code", 331 | "execution_count": 19, 332 | "id": "efea1520-aed1-48b6-a9f5-6ac304162f52", 333 | "metadata": {}, 334 | "outputs": [ 335 | { 336 | "name": "stdout", 337 | "output_type": "stream", 338 | "text": [ 339 | "[0. 0.75 0.25]\n", 340 | "[0. 0.75 0.25]\n", 341 | "[0. 
0.75 0.25]\n"
342 |      ]
343 |     }
344 |    ],
345 |    "source": [
346 |     "import numpy as np\n",
347 |     "\n",
348 |     "text_1 = np.array([0., 3., 1.]) \n",
349 |     "text_2 = np.array([0., 3., 1.])\n",
350 |     "\n",
351 |     "text_1 /= np.sum(text_1)\n",
352 |     "text_2 /= np.sum(text_2)\n",
353 |     "\n",
354 |     "print(text_1)\n",
355 |     "print(text_2)\n",
356 |     "\n",
357 |     "added = text_1 + text_2\n",
358 |     "\n",
359 |     "print(added / np.sum(added))"
360 |    ]
361 |   },
362 |   {
363 |    "cell_type": "code",
364 |    "execution_count": null,
365 |    "id": "89f20d55-a460-4de3-ab5d-40322be8ddd1",
366 |    "metadata": {},
367 |    "outputs": [],
368 |    "source": []
369 |   }
370 |  ],
371 |  "metadata": {
372 |   "kernelspec": {
373 |    "display_name": "Python 3 (ipykernel)",
374 |    "language": "python",
375 |    "name": "python3"
376 |   },
377 |   "language_info": {
378 |    "codemirror_mode": {
379 |     "name": "ipython",
380 |     "version": 3
381 |    },
382 |    "file_extension": ".py",
383 |    "mimetype": "text/x-python",
384 |    "name": "python",
385 |    "nbconvert_exporter": "python",
386 |    "pygments_lexer": "ipython3",
387 |    "version": "3.10.6"
388 |   }
389 |  },
390 |  "nbformat": 4,
391 |  "nbformat_minor": 5
392 | }
393 | 
--------------------------------------------------------------------------------
/1_1_nn_plus_gzip_original.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "id": "275f3c53-b5c7-4856-9a15-656a98b33fd8",
6 |    "metadata": {},
7 |    "source": [
8 |     "# NN + Gzip on IMDB Movie Review Dataset\n",
9 |     "\n",
10 |     "Reimplementation of the pseudocode in the *\"Low-Resource\" Text Classification: A Parameter-Free Classification Method with Compressors* paper ([https://aclanthology.org/2023.findings-acl.426/](https://aclanthology.org/2023.findings-acl.426/)) \n",
11 |     "\n",
12 |     "\n",
13 |     ""
14 |    ]
15 |   },
16 |   {
17 |    "cell_type": "code",
18 |    "execution_count": 1,
19 |    "id": "54b93603-f41c-4016-87aa-59998990075c",
20 |    "metadata": {},
21 |    "outputs": [],
22 |    "source": [
23 |     "import gzip\n",
24 |     "import os.path as op\n",
25 |     "\n",
26 |     "import numpy as np\n",
27 |     "import pandas as pd\n",
28 |     "\n",
29 |     "from local_dataset_utilities import download_dataset, load_dataset_into_to_dataframe, partition_dataset"
30 |    ]
31 |   },
32 |   {
33 |    "cell_type": "code",
34 |    "execution_count": 2,
35 |    "id": "a03e71dc-e09e-4907-bc06-f8250b97005e",
36 |    "metadata": {},
37 |    "outputs": [],
38 |    "source": [
39 |     "if not (op.isfile(\"train.csv\") and op.isfile(\"val.csv\") and op.isfile(\"test.csv\")):\n",
40 |     "    download_dataset()\n",
41 |     "\n",
42 |     "    df = load_dataset_into_to_dataframe()\n",
43 |     "    partition_dataset(df)"
44 |    ]
45 |   },
46 |   {
47 |    "cell_type": "code",
48 |    "execution_count": 3,
49 |    "id": "bfff472d-57c1-4310-8a1b-9a3e4339646a",
50 |    "metadata": {},
51 |    "outputs": [],
52 |    "source": [
53 |     "df_train = pd.read_csv(\"train.csv\")\n",
54 |     "df_val = pd.read_csv(\"val.csv\")\n",
55 |     "df_test = pd.read_csv(\"test.csv\")"
56 |    ]
57 |   },
58 |   {
59 |    "cell_type": "code",
60 |    "execution_count": 4,
61 |    "id": "29fbe3bd-c873-4372-9c79-ea38b751608c",
62 |    "metadata": {},
63 |    "outputs": [
64 |     {
65 |      "name": "stderr",
66 |      "output_type": "stream",
67 |      "text": [
68 |       "100%|██████████████████████████████████| 10000/10000 [21:16:23<00:00,  7.66s/it]\n"
69 |      ]
70 |     }
71 |    ],
72 | 
"source": [ 81 | "from tqdm import tqdm\n", 82 | "\n", 83 | "k = 2\n", 84 | "\n", 85 | "predicted_classes = []\n", 86 | "\n", 87 | "for row_test in tqdm(df_test.iterrows(), total=df_test.shape[0]):\n", 88 | " test_text = row_test[1][\"text\"]\n", 89 | " test_label = row_test[1][\"label\"]\n", 90 | " c_test_text = len(gzip.compress(test_text.encode()))\n", 91 | " distance_from_test_instance = []\n", 92 | " \n", 93 | " for row_train in df_train.iterrows():\n", 94 | " train_text = row_train[1][\"text\"]\n", 95 | " train_label = row_train[1][\"label\"]\n", 96 | " c_train_text = len(gzip.compress(train_text.encode()))\n", 97 | " \n", 98 | " train_plus_test = \" \".join([test_text, train_text])\n", 99 | " c_train_plus_test = len(gzip.compress(train_plus_test.encode()))\n", 100 | " \n", 101 | " ncd = ( (c_train_plus_test - min(c_train_text, c_test_text))\n", 102 | " / max(c_test_text, c_train_text) )\n", 103 | " distance_from_test_instance.append(ncd)\n", 104 | " \n", 105 | " sorted_idx = np.argsort(np.array(distance_from_test_instance))\n", 106 | " \n", 107 | " #top_k_class = list(df_train.iloc[sorted_idx[:k]][\"label\"].values)\n", 108 | " #predicted_class = max(set(top_k_class), key=top_k_class.count)\n", 109 | " top_k_class = df_train.iloc[sorted_idx[:k]][\"label\"].values\n", 110 | " predicted_class = np.argmax(np.bincount(top_k_class))\n", 111 | " \n", 112 | " predicted_classes.append(predicted_class)\n", 113 | " " 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 5, 119 | "id": "1f44b0d2-8303-409c-b40f-7910ab415da1", 120 | "metadata": {}, 121 | "outputs": [ 122 | { 123 | "name": "stdout", 124 | "output_type": "stream", 125 | "text": [ 126 | "Accuracy: 0.7005\n" 127 | ] 128 | } 129 | ], 130 | "source": [ 131 | "print(\"Accuracy:\", np.mean(np.array(predicted_classes) == df_test[\"label\"].values))" 132 | ] 133 | } 134 | ], 135 | "metadata": { 136 | "kernelspec": { 137 | "display_name": "Python 3 (ipykernel)", 138 | "language": "python", 139 | "name": "python3" 140 | }, 141 | "language_info": { 142 | "codemirror_mode": { 143 | "name": "ipython", 144 | "version": 3 145 | }, 146 | "file_extension": ".py", 147 | "mimetype": "text/x-python", 148 | "name": "python", 149 | "nbconvert_exporter": "python", 150 | "pygments_lexer": "ipython3", 151 | "version": "3.10.6" 152 | } 153 | }, 154 | "nbformat": 4, 155 | "nbformat_minor": 5 156 | } 157 | -------------------------------------------------------------------------------- /1_2_caching-multiprocessing.py: -------------------------------------------------------------------------------- 1 | # Parallel processing version of 1_2_nn_plus_gzip_fix-tie-breaking.ipynb 2 | # On a 2020 MacBook Air, it runs about 4 times faster ~1 iter/sec 3 | # than the non-parallel version (~4 iter/sec) 4 | 5 | # It should finish in about 2-3 h compared to ~12 h before 6 | 7 | from collections import Counter 8 | import gzip 9 | import multiprocessing as mp 10 | import os.path as op 11 | 12 | from joblib import Parallel, delayed 13 | import numpy as np 14 | import pandas as pd 15 | from tqdm import tqdm 16 | 17 | from local_dataset_utilities import download_dataset, load_dataset_into_to_dataframe, partition_dataset 18 | 19 | 20 | def process_dataset_subset(df_train_subset, test_text, c_test_text, d): 21 | 22 | distances_to_test = [] 23 | for row_train in df_train_subset.iterrows(): 24 | index = row_train[0] 25 | train_text = row_train[1]["text"] 26 | c_train_text = d[index] 27 | 28 | train_plus_test = " ".join([test_text, train_text]) 29 | 
c_train_plus_test = len(gzip.compress(train_plus_test.encode())) 30 | 31 | ncd = ( (c_train_plus_test - min(c_train_text, c_test_text)) 32 | / max(c_test_text, c_train_text) ) 33 | 34 | distances_to_test.append(ncd) 35 | 36 | return distances_to_test 37 | 38 | 39 | def divide_range_into_chunks(start, end, num_chunks): 40 | chunk_size = (end - start) // num_chunks 41 | ranges = [(i, i + chunk_size) for i in range(start, end, chunk_size)] 42 | ranges[-1] = (ranges[-1][0], end) # Ensure the last chunk includes the end 43 | return ranges 44 | 45 | 46 | if __name__ == '__main__': 47 | 48 | if not op.isfile("train.csv") and not op.isfile("val.csv") and not op.isfile("test.csv"): 49 | download_dataset() 50 | 51 | df = load_dataset_into_to_dataframe() 52 | partition_dataset(df) 53 | 54 | df_train = pd.read_csv("train.csv") 55 | df_val = pd.read_csv("val.csv") 56 | df_test = pd.read_csv("test.csv") 57 | 58 | num_processes = mp.cpu_count() 59 | k = 2 60 | predicted_classes = [] 61 | 62 | start = 0 63 | end = df_train.shape[0] 64 | ranges = divide_range_into_chunks(start, end, num_chunks=num_processes) 65 | 66 | 67 | # caching compressed training examples 68 | d = {} 69 | for i, row_train in enumerate(df_train.iterrows()): 70 | train_text = row_train[1]["text"] 71 | train_label = row_train[1]["label"] 72 | c_train_text = len(gzip.compress(train_text.encode())) 73 | 74 | d[i] = c_train_text 75 | 76 | # main loop 77 | for row_test in tqdm(df_test.iterrows(), total=df_test.shape[0]): 78 | 79 | test_text = row_test[1]["text"] 80 | test_label = row_test[1]["label"] 81 | c_test_text = len(gzip.compress(test_text.encode())) 82 | all_train_distances_to_test = [] 83 | 84 | # parallelize iteration over training set into num_processes chunks 85 | with Parallel(n_jobs=num_processes, backend="loky") as parallel: 86 | 87 | results = parallel( 88 | delayed(process_dataset_subset)(df_train[range_start:range_end], test_text, c_test_text, d) 89 | for range_start, range_end in ranges 90 | ) 91 | for p in results: 92 | all_train_distances_to_test.extend(p) 93 | 94 | sorted_idx = np.argsort(np.array(all_train_distances_to_test.extend)) 95 | top_k_class = np.array(df_train["label"])[sorted_idx[:k]] 96 | predicted_class = Counter(top_k_class).most_common()[0][0] 97 | 98 | predicted_classes.append(predicted_class) 99 | 100 | print("Accuracy:", np.mean(np.array(predicted_classes) == df_test["label"].values)) -------------------------------------------------------------------------------- /1_2_nn_plus_gzip_fix-tie-breaking.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "965aa954-7744-4ecb-8b38-a023f3c1b9af", 6 | "metadata": {}, 7 | "source": [ 8 | "# NN + Gzip on IMDB Movie Review Dataset\n", 9 | "\n", 10 | "Reimplementation of the pseudocode in the *\"Low-Resource\" Text Classification: A Parameter-Free Classification Method with Compressors* paper ([https://aclanthology.org/2023.findings-acl.426/](https://aclanthology.org/2023.findings-acl.426/)) \n", 11 | "\n", 12 | "\n", 13 | "\n", 14 | "\n", 15 | "\n", 16 | "**Modified to break ties based on choosing the closest neighbors** instead of the lowest index (see explanation in [0_some-concepts.ipynb](0_some-concepts.ipynb)).\n", 17 | "\n", 18 | "" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 1, 24 | "id": "54b93603-f41c-4016-87aa-59998990075c", 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "import gzip\n", 29 | "import os.path 
as op\n", 30 | "\n", 31 | "import numpy as np\n", 32 | "import pandas as pd\n", 33 | "\n", 34 | "from local_dataset_utilities import download_dataset, load_dataset_into_to_dataframe, partition_dataset" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "id": "a03e71dc-e09e-4907-bc06-f8250b97005e", 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "if not op.isfile(\"train.csv\") and not op.isfile(\"val.csv\") and not op.isfile(\"test.csv\"):\n", 45 | " download_dataset()\n", 46 | "\n", 47 | " df = load_dataset_into_to_dataframe()\n", 48 | " partition_dataset(df)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 3, 54 | "id": "bfff472d-57c1-4310-8a1b-9a3e4339646a", 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "df_train = pd.read_csv(\"train.csv\")\n", 59 | "df_val = pd.read_csv(\"val.csv\")\n", 60 | "df_test = pd.read_csv(\"test.csv\")" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 4, 66 | "id": "689b4673-db95-4dd6-ad3a-d2d31cad1a16", 67 | "metadata": {}, 68 | "outputs": [ 69 | { 70 | "name": "stderr", 71 | "output_type": "stream", 72 | "text": [ 73 | "100%|██████████████████████████████████| 10000/10000 [11:40:18<00:00, 4.20s/it]" 74 | ] 75 | }, 76 | { 77 | "name": "stdout", 78 | "output_type": "stream", 79 | "text": [ 80 | "Accuracy: 0.7191\n" 81 | ] 82 | }, 83 | { 84 | "name": "stderr", 85 | "output_type": "stream", 86 | "text": [ 87 | "\n" 88 | ] 89 | } 90 | ], 91 | "source": [ 92 | "from tqdm import tqdm\n", 93 | "from collections import Counter\n", 94 | "\n", 95 | "k = 2\n", 96 | "\n", 97 | "predicted_classes = []\n", 98 | "\n", 99 | "for row_test in tqdm(df_test.iterrows(), total=df_test.shape[0]):\n", 100 | " test_text = row_test[1][\"text\"]\n", 101 | " test_label = row_test[1][\"label\"]\n", 102 | " c_test_text = len(gzip.compress(test_text.encode()))\n", 103 | " distance_from_test_instance = []\n", 104 | " \n", 105 | " for row_train in df_train.iterrows():\n", 106 | " train_text = row_train[1][\"text\"]\n", 107 | " train_label = row_train[1][\"label\"]\n", 108 | " c_train_text = len(gzip.compress(train_text.encode()))\n", 109 | " \n", 110 | " train_plus_test = \" \".join([test_text, train_text])\n", 111 | " c_train_plus_test = len(gzip.compress(train_plus_test.encode()))\n", 112 | " \n", 113 | " ncd = ( (c_train_plus_test - min(c_train_text, c_test_text))\n", 114 | " / max(c_test_text, c_train_text) )\n", 115 | " distance_from_test_instance.append(ncd)\n", 116 | " \n", 117 | " sorted_idx = np.argsort(np.array(distance_from_test_instance))\n", 118 | " top_k_class = np.array(df_train[\"label\"])[sorted_idx[:k]]\n", 119 | " predicted_class = Counter(top_k_class).most_common()[0][0]\n", 120 | " \n", 121 | " predicted_classes.append(predicted_class)\n", 122 | " \n", 123 | "print(\"Accuracy:\", np.mean(np.array(predicted_classes) == df_test[\"label\"].values))" 124 | ] 125 | } 126 | ], 127 | "metadata": { 128 | "kernelspec": { 129 | "display_name": "Python 3 (ipykernel)", 130 | "language": "python", 131 | "name": "python3" 132 | }, 133 | "language_info": { 134 | "codemirror_mode": { 135 | "name": "ipython", 136 | "version": 3 137 | }, 138 | "file_extension": ".py", 139 | "mimetype": "text/x-python", 140 | "name": "python", 141 | "nbconvert_exporter": "python", 142 | "pygments_lexer": "ipython3", 143 | "version": "3.10.6" 144 | } 145 | }, 146 | "nbformat": 4, 147 | "nbformat_minor": 5 148 | } 149 | -------------------------------------------------------------------------------- 
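All of the gzip-based variants above revolve around the same two ingredients: the normalized compression distance (NCD) and a k-nearest-neighbor vote over the closest training examples. Below is a minimal, self-contained sketch that condenses this logic into two functions; the function names and the toy texts and labels are illustrative assumptions, not code from this repository:

```python
import gzip
from collections import Counter

import numpy as np


def ncd(x: str, y: str) -> float:
    # Normalized compression distance:
    # (C(xy) - min(C(x), C(y))) / max(C(x), C(y))
    cx = len(gzip.compress(x.encode()))
    cy = len(gzip.compress(y.encode()))
    cxy = len(gzip.compress(" ".join([x, y]).encode()))
    return (cxy - min(cx, cy)) / max(cx, cy)


def predict(test_text, train_texts, train_labels, k=2):
    # Distance from the test example to every training example
    distances = np.array([ncd(test_text, t) for t in train_texts])
    # Labels of the k nearest neighbors, ordered from closest to farthest
    sorted_idx = np.argsort(distances)
    top_k = [train_labels[i] for i in sorted_idx[:k]]
    # Counter breaks count ties in favor of the first, i.e., closest, label
    return Counter(top_k).most_common(1)[0][0]


train_texts = [
    "what a wonderful film",
    "the movie was great fun",
    "terrible acting and a boring plot",
    "a complete waste of time",
]
train_labels = [1, 1, 0, 0]

print(predict("a truly wonderful and fun movie", train_texts, train_labels, k=2))
```

With k=2, a one-vote-per-class tie between the two nearest neighbors is resolved in favor of the closer one, which is exactly the behavior the tie-breaking fix above is meant to guarantee.
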
/2_nn_countvecs.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "fb879955-ad5e-4c0d-a342-8772d119598e", 6 | "metadata": {}, 7 | "source": [ 8 | "# NN + Cosine Distance on IMDB Movie Review Dataset" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "836be7ab-cdc0-4376-ab23-27d54f486f39", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import gzip\n", 19 | "import os.path as op\n", 20 | "\n", 21 | "import numpy as np\n", 22 | "import pandas as pd\n", 23 | "\n", 24 | "from local_dataset_utilities import download_dataset, load_dataset_into_to_dataframe, partition_dataset" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 2, 30 | "id": "e23aba95-b18a-411d-9ad2-152c06071575", 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "if not op.isfile(\"train.csv\") and not op.isfile(\"val.csv\") and not op.isfile(\"test.csv\"):\n", 35 | " download_dataset()\n", 36 | "\n", 37 | " df = load_dataset_into_to_dataframe()\n", 38 | " partition_dataset(df)" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 3, 44 | "id": "9e01dd62-601f-4eb5-8a64-ed8fc39cd719", 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "df_train = pd.read_csv(\"train.csv\")\n", 49 | "df_val = pd.read_csv(\"val.csv\")\n", 50 | "df_test = pd.read_csv(\"test.csv\")" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 4, 56 | "id": "14c67c25-6275-4ec0-9596-73a014adfc8f", 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "from sklearn.feature_extraction.text import CountVectorizer\n", 61 | "\n", 62 | "\n", 63 | "cv = CountVectorizer(lowercase=True, max_features=10_000, stop_words=\"english\")\n", 64 | "\n", 65 | "cv.fit(df_train[\"text\"])\n", 66 | "\n", 67 | "X_train = cv.transform(df_train[\"text\"])\n", 68 | "X_val = cv.transform(df_val[\"text\"])\n", 69 | "X_test = cv.transform(df_test[\"text\"])" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 5, 75 | "id": "e22430d5-96e5-49ff-8aeb-1b7c5be57e26", 76 | "metadata": {}, 77 | "outputs": [ 78 | { 79 | "name": "stderr", 80 | "output_type": "stream", 81 | "text": [ 82 | "100%|███████████████████████████████████| 10000/10000 [9:10:43<00:00, 3.30s/it]" 83 | ] 84 | }, 85 | { 86 | "name": "stdout", 87 | "output_type": "stream", 88 | "text": [ 89 | "Accuracy: 0.6801\n" 90 | ] 91 | }, 92 | { 93 | "name": "stderr", 94 | "output_type": "stream", 95 | "text": [ 96 | "\n" 97 | ] 98 | } 99 | ], 100 | "source": [ 101 | "from collections import Counter\n", 102 | "from tqdm import tqdm\n", 103 | "from numpy.linalg import norm\n", 104 | "\n", 105 | "\n", 106 | "k = 2\n", 107 | "\n", 108 | "predicted_classes = []\n", 109 | "\n", 110 | "for i in tqdm(range(df_test.shape[0]), total=df_test.shape[0]):\n", 111 | "\n", 112 | " test_vec = X_test[i].toarray().reshape(-1)\n", 113 | " test_label = df_test.iloc[i][\"label\"]\n", 114 | " distance_from_test_instance = []\n", 115 | " \n", 116 | " for j in range(df_train.shape[0]):\n", 117 | " train_vec = X_train[j].toarray().reshape(-1)\n", 118 | " train_label = df_train.iloc[j][\"label\"]\n", 119 | " \n", 120 | " cosine = 1 - np.dot(test_vec, train_vec)/(norm(test_vec)*norm(train_vec))\n", 121 | " distance_from_test_instance.append(cosine)\n", 122 | " \n", 123 | " sorted_idx = np.argsort(np.array(distance_from_test_instance))\n", 124 | " top_k_class = 
np.array(df_train[\"label\"])[sorted_idx[:k]]\n", 125 | " predicted_class = Counter(top_k_class).most_common()[0][0]\n", 126 | " \n", 127 | " predicted_classes.append(predicted_class)\n", 128 | " \n", 129 | "print(\"Accuracy:\", np.mean(np.array(predicted_classes) == df_test[\"label\"].values))" 130 | ] 131 | } 132 | ], 133 | "metadata": { 134 | "kernelspec": { 135 | "display_name": "Python 3 (ipykernel)", 136 | "language": "python", 137 | "name": "python3" 138 | }, 139 | "language_info": { 140 | "codemirror_mode": { 141 | "name": "ipython", 142 | "version": 3 143 | }, 144 | "file_extension": ".py", 145 | "mimetype": "text/x-python", 146 | "name": "python", 147 | "nbconvert_exporter": "python", 148 | "pygments_lexer": "ipython3", 149 | "version": "3.10.6" 150 | } 151 | }, 152 | "nbformat": 4, 153 | "nbformat_minor": 5 154 | } 155 | -------------------------------------------------------------------------------- /3_distilbert.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "3c5d72f4", 6 | "metadata": {}, 7 | "source": [ 8 | "# DistilBERT Finetuning" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "6fd9cda8", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "# pip install transformers" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "id": "92ea5612", 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "# pip install datasets" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 3, 34 | "id": "fe7191cf-62ed-4793-8358-bee70b233d05", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "# pip install lightning" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 4, 44 | "id": "033b75c5", 45 | "metadata": {}, 46 | "outputs": [ 47 | { 48 | "name": "stdout", 49 | "output_type": "stream", 50 | "text": [ 51 | "torch : 2.0.0\n", 52 | "transformers: 4.27.4\n", 53 | "datasets : 2.11.0\n", 54 | "lightning : 2.0.1\n", 55 | "\n", 56 | "conda environment: finetuning-blog\n", 57 | "\n" 58 | ] 59 | } 60 | ], 61 | "source": [ 62 | "%load_ext watermark\n", 63 | "%watermark --conda -p torch,transformers,datasets,lightning" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "id": "09213821-b2b4-402e-adf8-7c7fe4ec57cb", 69 | "metadata": { 70 | "tags": [] 71 | }, 72 | "source": [ 73 | "# 1 Loading the dataset into DataFrames" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 5, 79 | "id": "e39e2228-5f0b-4fb9-b762-df26c2052b45", 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "# pip install datasets\n", 84 | "\n", 85 | "import os.path as op\n", 86 | "\n", 87 | "from datasets import load_dataset\n", 88 | "\n", 89 | "import lightning as L\n", 90 | "from lightning.pytorch.loggers import CSVLogger\n", 91 | "from lightning.pytorch.callbacks import ModelCheckpoint\n", 92 | "\n", 93 | "import numpy as np\n", 94 | "import pandas as pd\n", 95 | "import torch\n", 96 | "\n", 97 | "from sklearn.feature_extraction.text import CountVectorizer\n", 98 | "\n", 99 | "from local_dataset_utilities import download_dataset, load_dataset_into_to_dataframe, partition_dataset\n", 100 | "from local_dataset_utilities import IMDBDataset" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 6, 106 | "id": "fb31ac90-9e3a-41d0-baf1-8e613043924b", 107 | "metadata": {}, 108 | "outputs": [ 109 | { 110 | "name": "stderr", 111 | 
"output_type": "stream", 112 | "text": [ 113 | "100%|███████████████████████████████████████████| 50000/50000 [00:24<00:00, 2023.24it/s]\n" 114 | ] 115 | }, 116 | { 117 | "name": "stdout", 118 | "output_type": "stream", 119 | "text": [ 120 | "Class distribution:\n" 121 | ] 122 | } 123 | ], 124 | "source": [ 125 | "download_dataset()\n", 126 | "\n", 127 | "df = load_dataset_into_to_dataframe()\n", 128 | "partition_dataset(df)" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 7, 134 | "id": "221f30a1-b433-4304-a18d-8d03abd42b58", 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "df_train = pd.read_csv(\"train.csv\")\n", 139 | "df_val = pd.read_csv(\"val.csv\")\n", 140 | "df_test = pd.read_csv(\"test.csv\")" 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "id": "876736c1-ae27-491c-850b-050507fa02b5", 146 | "metadata": {}, 147 | "source": [ 148 | "# 2 Tokenization and Numericalization" 149 | ] 150 | }, 151 | { 152 | "cell_type": "markdown", 153 | "id": "afe0cca0-bac4-49ed-982c-14c998e578d1", 154 | "metadata": {}, 155 | "source": [ 156 | "**Load the dataset via `load_dataset`**" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 8, 162 | "id": "a1aa66c7", 163 | "metadata": {}, 164 | "outputs": [ 165 | { 166 | "name": "stdout", 167 | "output_type": "stream", 168 | "text": [ 169 | "Downloading and preparing dataset csv/default to /home/sebastian/.cache/huggingface/datasets/csv/default-3e50991f5e7f1651/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1...\n" 170 | ] 171 | }, 172 | { 173 | "data": { 174 | "application/vnd.jupyter.widget-view+json": { 175 | "model_id": "9d9091423f5c4c7f8ce30a4208df97ce", 176 | "version_major": 2, 177 | "version_minor": 0 178 | }, 179 | "text/plain": [ 180 | "Downloading data files: 0%| | 0/3 [00:00┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", 802 | "┃ Test metric DataLoader 0 ┃\n", 803 | "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", 804 | "│ accuracy 0.9919999837875366 │\n", 805 | "└───────────────────────────┴───────────────────────────┘\n", 806 | "\n" 807 | ], 808 | "text/plain": [ 809 | "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", 810 | "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", 811 | "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", 812 | "│\u001b[36m \u001b[0m\u001b[36m accuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9919999837875366 \u001b[0m\u001b[35m \u001b[0m│\n", 813 | "└───────────────────────────┴───────────────────────────┘\n" 814 | ] 815 | }, 816 | "metadata": {}, 817 | "output_type": "display_data" 818 | }, 819 | { 820 | "data": { 821 | "text/plain": [ 822 | "[{'accuracy': 0.9919999837875366}]" 823 | ] 824 | }, 825 | "execution_count": 22, 826 | "metadata": {}, 827 | "output_type": "execute_result" 828 | } 829 | ], 830 | "source": [ 831 | "trainer.test(lightning_model, dataloaders=train_loader, ckpt_path=\"best\")" 832 | ] 833 | }, 834 | { 835 | "cell_type": "code", 836 | "execution_count": 23, 837 | "id": "10ca0af1-106e-4ef7-9793-478d580af827", 838 | "metadata": {}, 839 | "outputs": [ 840 | { 841 | "name": "stderr", 842 | "output_type": "stream", 843 | "text": [ 844 | "You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. 
To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", 845 | "Restoring states from the checkpoint path at logs/my-model/version_0/checkpoints/epoch=2-step=8751-v1.ckpt\n", 846 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]\n", 847 | "Loaded model weights from the checkpoint at logs/my-model/version_0/checkpoints/epoch=2-step=8751-v1.ckpt\n" 848 | ] 849 | }, 850 | { 851 | "data": { 852 | "application/vnd.jupyter.widget-view+json": { 853 | "model_id": "1b3d7837d2c149c0ad84347b03c0179e", 854 | "version_major": 2, 855 | "version_minor": 0 856 | }, 857 | "text/plain": [ 858 | "Testing: 0it [00:00, ?it/s]" 859 | ] 860 | }, 861 | "metadata": {}, 862 | "output_type": "display_data" 863 | }, 864 | { 865 | "data": { 866 | "text/html": [ 867 | "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
868 |        "┃        Test metric               DataLoader 0        ┃\n",
869 |        "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
870 |        "│         accuracy              0.9251999855041504     │\n",
871 |        "└───────────────────────────┴───────────────────────────┘\n",
872 |        "
\n" 873 | ], 874 | "text/plain": [ 875 | "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", 876 | "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", 877 | "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", 878 | "│\u001b[36m \u001b[0m\u001b[36m accuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9251999855041504 \u001b[0m\u001b[35m \u001b[0m│\n", 879 | "└───────────────────────────┴───────────────────────────┘\n" 880 | ] 881 | }, 882 | "metadata": {}, 883 | "output_type": "display_data" 884 | }, 885 | { 886 | "data": { 887 | "text/plain": [ 888 | "[{'accuracy': 0.9251999855041504}]" 889 | ] 890 | }, 891 | "execution_count": 23, 892 | "metadata": {}, 893 | "output_type": "execute_result" 894 | } 895 | ], 896 | "source": [ 897 | "trainer.test(lightning_model, dataloaders=val_loader, ckpt_path=\"best\")" 898 | ] 899 | }, 900 | { 901 | "cell_type": "code", 902 | "execution_count": 24, 903 | "id": "eeb92de4-d483-4627-b9f3-f0bba0cddd9c", 904 | "metadata": {}, 905 | "outputs": [ 906 | { 907 | "name": "stderr", 908 | "output_type": "stream", 909 | "text": [ 910 | "You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", 911 | "Restoring states from the checkpoint path at logs/my-model/version_0/checkpoints/epoch=2-step=8751-v1.ckpt\n", 912 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]\n", 913 | "Loaded model weights from the checkpoint at logs/my-model/version_0/checkpoints/epoch=2-step=8751-v1.ckpt\n" 914 | ] 915 | }, 916 | { 917 | "data": { 918 | "application/vnd.jupyter.widget-view+json": { 919 | "model_id": "3a05c398964e469c928cac221541e4fd", 920 | "version_major": 2, 921 | "version_minor": 0 922 | }, 923 | "text/plain": [ 924 | "Testing: 0it [00:00, ?it/s]" 925 | ] 926 | }, 927 | "metadata": {}, 928 | "output_type": "display_data" 929 | }, 930 | { 931 | "data": { 932 | "text/html": [ 933 | "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
934 |        "┃        Test metric               DataLoader 0        ┃\n",
935 |        "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
936 |        "│         accuracy              0.9214000105857849     │\n",
937 |        "└───────────────────────────┴───────────────────────────┘\n",
938 |        "
\n" 939 | ], 940 | "text/plain": [ 941 | "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", 942 | "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", 943 | "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", 944 | "│\u001b[36m \u001b[0m\u001b[36m accuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9214000105857849 \u001b[0m\u001b[35m \u001b[0m│\n", 945 | "└───────────────────────────┴───────────────────────────┘\n" 946 | ] 947 | }, 948 | "metadata": {}, 949 | "output_type": "display_data" 950 | }, 951 | { 952 | "data": { 953 | "text/plain": [ 954 | "[{'accuracy': 0.9214000105857849}]" 955 | ] 956 | }, 957 | "execution_count": 24, 958 | "metadata": {}, 959 | "output_type": "execute_result" 960 | } 961 | ], 962 | "source": [ 963 | "trainer.test(lightning_model, dataloaders=test_loader, ckpt_path=\"best\")" 964 | ] 965 | } 966 | ], 967 | "metadata": { 968 | "kernelspec": { 969 | "display_name": "Python 3 (ipykernel)", 970 | "language": "python", 971 | "name": "python3" 972 | }, 973 | "language_info": { 974 | "codemirror_mode": { 975 | "name": "ipython", 976 | "version": 3 977 | }, 978 | "file_extension": ".py", 979 | "mimetype": "text/x-python", 980 | "name": "python", 981 | "nbconvert_exporter": "python", 982 | "pygments_lexer": "ipython3", 983 | "version": "3.10.6" 984 | } 985 | }, 986 | "nbformat": 4, 987 | "nbformat_minor": 5 988 | } 989 | -------------------------------------------------------------------------------- /4_r8-dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "9a8d1036-9af3-425b-af79-542cb5698183", 6 | "metadata": {}, 7 | "source": [ 8 | "## Experiments on R8 dataset" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "81168bb4-e182-4cf0-9eff-72f0aa495401", 14 | "metadata": {}, 15 | "source": [ 16 | "This notebooks runs the proposed method on the R8 dataset that was reported in the original paper:" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "id": "cce35398-f627-4a50-a0eb-af2ad98ff75e", 22 | "metadata": {}, 23 | "source": [ 24 | "![](figures/r8.png)" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "id": "73cc2896-1dc1-4bc2-8d06-55bb0f6813bd", 30 | "metadata": {}, 31 | "source": [ 32 | "Note that the scores in the original paper are inflated or overly optimistic because of a bug in their code repository, which was described on [https://kenschutte.com/gzip-knn-paper/](https://kenschutte.com/gzip-knn-paper/)." 
33 |    ]
34 |   },
35 |   {
36 |    "cell_type": "code",
37 |    "execution_count": 1,
38 |    "id": "6122b98c-af6f-424c-a498-cee4cd008477",
39 |    "metadata": {},
40 |    "outputs": [],
41 |    "source": [
42 |     "import gzip\n",
43 |     "import os.path as op\n",
44 |     "\n",
45 |     "import numpy as np\n",
46 |     "import pandas as pd"
47 |    ]
48 |   },
49 |   {
50 |    "cell_type": "markdown",
51 |    "id": "502672d8-2a4c-4504-ac47-684962fa7bc2",
52 |    "metadata": {},
53 |    "source": [
54 |     "### Load dataset"
55 |    ]
56 |   },
57 |   {
58 |    "cell_type": "markdown",
59 |    "id": "018be464-1ee4-4048-a8bc-0b1ecf8c9a76",
60 |    "metadata": {},
61 |    "source": [
62 |     "Before running the code below, make sure to download the dataset from here: https://www.kaggle.com/datasets/weipengfei/ohr8r52"
63 |    ]
64 |   },
65 |   {
66 |    "cell_type": "code",
67 |    "execution_count": 2,
68 |    "id": "f11f439a-1e31-4881-a7c2-a9633693f202",
69 |    "metadata": {},
70 |    "outputs": [],
71 |    "source": [
72 |     "df_train = pd.read_csv(\"r8-train-stemmed.csv\")\n",
73 |     "df_test = pd.read_csv(\"r8-test-stemmed.csv\")"
74 |    ]
75 |   },
76 |   {
77 |    "cell_type": "code",
78 |    "execution_count": 3,
79 |    "id": "fe2f6e18-b3d5-4333-93aa-782a287d0350",
80 |    "metadata": {},
81 |    "outputs": [
82 |     {
83 |      "data": {
84 |       "text/plain": [
85 |        "{'money-fx': 0,\n",
86 |        " 'crude': 1,\n",
87 |        " 'interest': 2,\n",
88 |        " 'trade': 3,\n",
89 |        " 'earn': 4,\n",
90 |        " 'grain': 5,\n",
91 |        " 'ship': 6,\n",
92 |        " 'acq': 7}"
93 |       ]
94 |      },
95 |      "execution_count": 3,
96 |      "metadata": {},
97 |      "output_type": "execute_result"
98 |     }
99 |    ],
100 |    "source": [
101 |     "uniq = list(set(df_train[\"intent\"].values))\n",
102 |     "labels = {j:i for i,j in zip(range(len(uniq)), uniq)}\n",
103 |     "labels"
104 |    ]
105 |   },
106 |   {
107 |    "cell_type": "code",
108 |    "execution_count": 4,
109 |    "id": "5b9bf7a6-cc26-4d9f-a2c8-ea1d736b7997",
110 |    "metadata": {},
111 |    "outputs": [],
112 |    "source": [
113 |     "df_train[\"label\"] = df_train[\"intent\"].apply(lambda x: labels[x])\n",
114 |     "df_test[\"label\"] = df_test[\"intent\"].apply(lambda x: labels[x])"
115 |    ]
116 |   },
117 |   {
118 |    "cell_type": "markdown",
119 |    "id": "5978af4b-20e2-4031-8c7f-3c2bbd6eb906",
120 |    "metadata": {},
121 |    "source": [
122 |     "## Original"
123 |    ]
124 |   },
125 |   {
126 |    "cell_type": "markdown",
127 |    "id": "2f1c8f54-82bf-45cf-a273-af4106323998",
128 |    "metadata": {},
129 |    "source": [
130 |     "Reimplementation of the pseudocode in the *\"Low-Resource\" Text Classification: A Parameter-Free Classification Method with Compressors* paper ([https://aclanthology.org/2023.findings-acl.426/](https://aclanthology.org/2023.findings-acl.426/)) \n",
131 |     "\n",
132 |     "\n",
133 |     "\n",
134 |     "\n",
135 |     "- Same code as [1_1_nn_plus_gzip_original.ipynb](1_1_nn_plus_gzip_original.ipynb)"
136 |    ]
137 |   },
138 |   {
139 |    "cell_type": "code",
140 |    "execution_count": 6,
141 |    "id": "ce436c63-2eea-46a2-95ca-cfaab0acf348",
142 |    "metadata": {},
143 |    "outputs": [
144 |     {
145 |      "name": "stderr",
146 |      "output_type": "stream",
147 |      "text": [
148 |       "100%|███████████████████████████████████████| 2189/2189 [09:44<00:00,  3.74it/s]"
149 |      ]
150 |     },
151 |     {
152 |      "name": "stdout",
153 |      "output_type": "stream",
154 |      "text": [
155 |       "Accuracy: 0.8889904065783463\n"
156 |      ]
157 |     },
158 |     {
159 |      "name": "stderr",
160 |      "output_type": "stream",
161 |      "text": [
162 |       "\n"
163 |      ]
164 |     }
165 |    ],
166 |    "source": [
167 |     "from tqdm import tqdm\n",
168 |     "\n",
169 |     "k = 2\n",
170 |     "\n",
171 |     "predicted_classes = []\n",
172 |     "for row_test in tqdm(df_test.iterrows(), total=df_test.shape[0]):\n",
173 | " test_text = row_test[1][\"text\"]\n", 174 | " test_label = row_test[1][\"label\"]\n", 175 | " c_test_text = len(gzip.compress(test_text.encode()))\n", 176 | " distance_from_test_instance = []\n", 177 | " \n", 178 | " for row_train in df_train.iterrows():\n", 179 | " train_text = row_train[1][\"text\"]\n", 180 | " train_label = row_train[1][\"label\"]\n", 181 | " c_train_text = len(gzip.compress(train_text.encode()))\n", 182 | " \n", 183 | " train_plus_test = \" \".join([test_text, train_text])\n", 184 | " c_train_plus_test = len(gzip.compress(train_plus_test.encode()))\n", 185 | " \n", 186 | " ncd = ( (c_train_plus_test - min(c_train_text, c_test_text))\n", 187 | " / max(c_test_text, c_train_text) )\n", 188 | " distance_from_test_instance.append(ncd)\n", 189 | " \n", 190 | " sorted_idx = np.argsort(np.array(distance_from_test_instance))\n", 191 | " \n", 192 | " #top_k_class = list(df_train.iloc[sorted_idx[:k]][\"label\"].values)\n", 193 | " #predicted_class = max(set(top_k_class), key=top_k_class.count)\n", 194 | " top_k_class = df_train.iloc[sorted_idx[:k]][\"label\"].values\n", 195 | " predicted_class = np.argmax(np.bincount(top_k_class))\n", 196 | " \n", 197 | " predicted_classes.append(predicted_class)\n", 198 | " \n", 199 | "print(\"Accuracy:\", np.mean(np.array(predicted_classes) == df_test[\"label\"].values))" 200 | ] 201 | }, 202 | { 203 | "cell_type": "markdown", 204 | "id": "98eecfd5-0eaf-4274-b48c-7f9c27b0d0f4", 205 | "metadata": {}, 206 | "source": [ 207 | "## With Tie-Breaking Fix" 208 | ] 209 | }, 210 | { 211 | "cell_type": "markdown", 212 | "id": "fc2bd458-a7f8-46d0-9c08-cecbd3daae8c", 213 | "metadata": {}, 214 | "source": [ 215 | "With improved tie breaking using `Counter` as described in [0_some-concepts.ipynb](0_some-concepts.ipynb). 
\n", 216 | "\n", 217 | "- Same code as [1_2_nn_plus_gzip_fix-tie-breaking.ipynb](1_2_nn_plus_gzip_fix-tie-breaking.ipynb)" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 5, 223 | "id": "fa2bee8b-0ac4-42f4-95d2-5abd1ce6cebe", 224 | "metadata": {}, 225 | "outputs": [ 226 | { 227 | "name": "stderr", 228 | "output_type": "stream", 229 | "text": [ 230 | "100%|███████████████████████████████████████| 2189/2189 [09:49<00:00, 3.71it/s]" 231 | ] 232 | }, 233 | { 234 | "name": "stdout", 235 | "output_type": "stream", 236 | "text": [ 237 | "Accuracy: 0.912745545911375\n" 238 | ] 239 | }, 240 | { 241 | "name": "stderr", 242 | "output_type": "stream", 243 | "text": [ 244 | "\n" 245 | ] 246 | } 247 | ], 248 | "source": [ 249 | "from tqdm import tqdm\n", 250 | "from collections import Counter\n", 251 | "\n", 252 | "k = 2\n", 253 | "\n", 254 | "predicted_classes = []\n", 255 | "\n", 256 | "for row_test in tqdm(df_test.iterrows(), total=df_test.shape[0]):\n", 257 | " test_text = row_test[1][\"text\"]\n", 258 | " test_label = row_test[1][\"label\"]\n", 259 | " c_test_text = len(gzip.compress(test_text.encode()))\n", 260 | " distance_from_test_instance = []\n", 261 | " \n", 262 | " for row_train in df_train.iterrows():\n", 263 | " train_text = row_train[1][\"text\"]\n", 264 | " train_label = row_train[1][\"label\"]\n", 265 | " c_train_text = len(gzip.compress(train_text.encode()))\n", 266 | " \n", 267 | " train_plus_test = \" \".join([test_text, train_text])\n", 268 | " c_train_plus_test = len(gzip.compress(train_plus_test.encode()))\n", 269 | " \n", 270 | " ncd = ( (c_train_plus_test - min(c_train_text, c_test_text))\n", 271 | " / max(c_test_text, c_train_text) )\n", 272 | " distance_from_test_instance.append(ncd)\n", 273 | " \n", 274 | " sorted_idx = np.argsort(np.array(distance_from_test_instance))\n", 275 | " top_k_class = np.array(df_train[\"label\"])[sorted_idx[:k]]\n", 276 | " predicted_class = Counter(top_k_class).most_common()[0][0]\n", 277 | " \n", 278 | " predicted_classes.append(predicted_class)\n", 279 | " \n", 280 | "print(\"Accuracy:\", np.mean(np.array(predicted_classes) == df_test[\"label\"].values))" 281 | ] 282 | } 283 | ], 284 | "metadata": { 285 | "kernelspec": { 286 | "display_name": "Python 3 (ipykernel)", 287 | "language": "python", 288 | "name": "python3" 289 | }, 290 | "language_info": { 291 | "codemirror_mode": { 292 | "name": "ipython", 293 | "version": 3 294 | }, 295 | "file_extension": ".py", 296 | "mimetype": "text/x-python", 297 | "name": "python", 298 | "nbconvert_exporter": "python", 299 | "pygments_lexer": "ipython3", 300 | "version": "3.10.6" 301 | } 302 | }, 303 | "nbformat": 4, 304 | "nbformat_minor": 5 305 | } 306 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at
194 | 
195 | http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # nn_plus_gzip
2 | Minimalist reimplementation of the *Gzip and nearest neighbors* method for text classification based on the paper:
3 | 
4 | - *“Low-Resource” Text Classification: A Parameter-Free Classification Method with Compressors* ([https://aclanthology.org/2023.findings-acl.426/](https://aclanthology.org/2023.findings-acl.426/)) by Zhiying Jiang, Matthew Yang, Mikhail Tsirlin, Raphael Tang, Yiqin Dai, and Jimmy Lin
5 | 
--------------------------------------------------------------------------------
/figures/pseudocode.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rasbt/nn_plus_gzip/2b2a456e4fb51512e4825a443c1ddd3461ee675e/figures/pseudocode.png
--------------------------------------------------------------------------------
/figures/r8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rasbt/nn_plus_gzip/2b2a456e4fb51512e4825a443c1ddd3461ee675e/figures/r8.png
--------------------------------------------------------------------------------
/local_dataset_utilities.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import tarfile
4 | import time
5 | 
6 | import numpy as np
7 | import pandas as pd
8 | from packaging import version
9 | from torch.utils.data import Dataset
10 | from tqdm import tqdm
11 | import urllib.request  # import the submodule explicitly; a bare "import urllib" does not expose urllib.request
12 | 
13 | 
14 | def reporthook(count, block_size, total_size):  # progress callback for urllib.request.urlretrieve
15 |     global start_time
16 |     if count == 0:
17 |         start_time = time.time()
18 |         return
19 |     duration = time.time() - start_time
20 |     progress_size = int(count * block_size)
21 |     speed = progress_size / (1024.0**2 * duration)
22 |     percent = count * block_size * 100.0 / total_size
23 | 
24 |     sys.stdout.write(
25 |         f"\r{int(percent)}% | {progress_size / (1024.0**2):.2f} MB "
26 |         f"| {speed:.2f} MB/s | {duration:.2f} sec elapsed"
27 |     )
28 |     sys.stdout.flush()
29 | 
30 | 
31 | def download_dataset():
32 |     source = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"
33 |     target = "aclImdb_v1.tar.gz"
34 | 
35 |     if os.path.exists(target):
36 |         os.remove(target)  # remove any stale archive so a fresh copy is downloaded
37 | 
38 |     if not os.path.isdir("aclImdb"):  # the archive was just removed above, so only the directory check matters
39 |         urllib.request.urlretrieve(source, target, reporthook)
40 | 
41 |     if not os.path.isdir("aclImdb"):
42 | 
43 |         with tarfile.open(target, "r:gz") as tar:
44 |             tar.extractall()
45 | 
46 | 
47 | def load_dataset_into_to_dataframe():
48 |     basepath = "aclImdb"
49 | 
50 |     labels = {"pos": 1, "neg": 0}
51 | 
52 |     df = pd.DataFrame()
53 | 
54 |     with tqdm(total=50000) as pbar:
55 |         for s in ("test", "train"):
56 |             for l in ("pos", "neg"):
57 |                 path = os.path.join(basepath, s, l)
58 |                 for file in sorted(os.listdir(path)):
59 |                     with open(os.path.join(path, file), "r", encoding="utf-8") as infile:
60 |                         txt = infile.read()
61 | 
62 |                     if version.parse(pd.__version__) >= version.parse("1.3.2"):  # DataFrame.append was removed in pandas 2.0
63 |                         x = pd.DataFrame(
64 |                             [[txt, labels[l]]], columns=["review", "sentiment"]
65 |                         )
66 |                         df = pd.concat([df, x], ignore_index=True)  # True gives each row a unique index; the reindex shuffle below fails on duplicate index labels
67 | 
68 |                     else:
69 |                         df = df.append([[txt, labels[l]]], ignore_index=True)
70 |                     pbar.update()
71 |     df.columns = ["text", "label"]
72 | 
73 |     np.random.seed(0)
74 |     df = df.reindex(np.random.permutation(df.index))  # shuffle the rows
75 | 
76 |     print("Class distribution:")
77 |     print(np.bincount(df["label"].values))  # show the per-class counts
78 | 
79 |     return df
80 | 
81 | 
82 | def partition_dataset(df):
83 |     df_shuffled = df.sample(frac=1, random_state=1).reset_index()  # keeps the pre-shuffle index as an extra "index" column in the CSVs
84 | 
85 |     df_train = df_shuffled.iloc[:35_000]
86 |     df_val = df_shuffled.iloc[35_000:40_000]
87 |     df_test = df_shuffled.iloc[40_000:]
88 | 
89 |     df_train.to_csv("train.csv", index=False, encoding="utf-8")
90 |     df_val.to_csv("val.csv", index=False, encoding="utf-8")
91 |     df_test.to_csv("test.csv", index=False, encoding="utf-8")
92 | 
93 | 
94 | class IMDBDataset(Dataset):
95 |     def __init__(self, dataset_dict, partition_key="train"):
96 |         self.partition = dataset_dict[partition_key]
97 | 
98 |     def __getitem__(self, index):
99 |         return self.partition[index]
100 | 
101 |     def __len__(self):
102 |         return self.partition.num_rows
--------------------------------------------------------------------------------
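For quick orientation, the core of the method the notebooks implement fits in a few lines: compress each text with gzip, measure the normalized compression distance (NCD) between a test text and every training text, and classify by a k-nearest-neighbor vote. The sketch below illustrates the paper's pseudocode (cf. `figures/pseudocode.png`); it is a minimal standalone example rather than the notebooks' exact code, and the helper names `ncd` and `knn_predict` are made up for illustration.

```python
import gzip
from collections import Counter

def ncd(x: str, y: str) -> float:
    # Normalized compression distance:
    # NCD(x, y) = (C(xy) - min(C(x), C(y))) / max(C(x), C(y)),
    # where C(s) is the gzip-compressed length of s and xy is the
    # space-joined concatenation of x and y.
    cx = len(gzip.compress(x.encode()))
    cy = len(gzip.compress(y.encode()))
    cxy = len(gzip.compress(" ".join([x, y]).encode()))
    return (cxy - min(cx, cy)) / max(cx, cy)

def knn_predict(test_text, train_texts, train_labels, k=2):
    # Label a test text by majority vote among its k nearest
    # training texts under NCD (the paper uses k=2).
    dists = [ncd(test_text, t) for t in train_texts]
    top_k = sorted(range(len(train_texts)), key=lambda i: dists[i])[:k]
    votes = Counter(train_labels[i] for i in top_k)
    return votes.most_common(1)[0][0]
```

Note that with k=2 a one-vote-each tie is common, and `Counter.most_common` resolves equal counts by insertion order, so the sketch above effectively falls back to the nearest neighbor's label; this tie-breaking behavior is the detail that `1_2_nn_plus_gzip_fix-tie-breaking.ipynb` examines.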