├── 0_baseline.ipynb
├── 0_some-concepts.ipynb
├── 1_1_nn_plus_gzip_original.ipynb
├── 1_2_caching-multiprocessing.py
├── 1_2_nn_plus_gzip_fix-tie-breaking.ipynb
├── 2_nn_countvecs.ipynb
├── 3_distilbert.ipynb
├── 4_r8-dataset.ipynb
├── LICENSE
├── README.md
├── figures
│   ├── pseudocode.png
│   └── r8.png
└── local_dataset_utilities.py
/0_baseline.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "f79f00e6-88bc-4d4d-930d-9f346eba5955",
6 | "metadata": {},
7 | "source": [
8 | "# Baseline accuracy"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 1,
14 | "id": "6e5e34ee-34b1-472d-8da5-09f50ca5a23e",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import gzip\n",
19 | "import os.path as op\n",
20 | "\n",
21 | "import numpy as np\n",
22 | "import pandas as pd\n",
23 | "\n",
24 | "from local_dataset_utilities import download_dataset, load_dataset_into_to_dataframe, partition_dataset"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": 2,
30 | "id": "aceb0005-1dcd-4735-8cff-fc3b10baae4f",
31 | "metadata": {},
32 | "outputs": [
33 | {
34 | "name": "stderr",
35 | "output_type": "stream",
36 | "text": [
37 | "100%|███████████████████████████████████| 50000/50000 [00:19<00:00, 2542.67it/s]\n"
38 | ]
39 | },
40 | {
41 | "name": "stdout",
42 | "output_type": "stream",
43 | "text": [
44 | "Class distribution:\n"
45 | ]
46 | }
47 | ],
48 | "source": [
49 | "if not (op.isfile(\"train.csv\") and op.isfile(\"val.csv\") and op.isfile(\"test.csv\")):\n",
50 | " download_dataset()\n",
51 | "\n",
52 | " df = load_dataset_into_to_dataframe()\n",
53 | " partition_dataset(df)"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": 3,
59 | "id": "47535727-bbc5-44ba-ae42-bcd34781adcb",
60 | "metadata": {},
61 | "outputs": [],
62 | "source": [
63 | "df_train = pd.read_csv(\"train.csv\")\n",
64 | "df_val = pd.read_csv(\"val.csv\")\n",
65 | "df_test = pd.read_csv(\"test.csv\")"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": 4,
71 | "id": "0bde64f1-a5d0-4269-a9dc-eaa6d2159872",
72 | "metadata": {},
73 | "outputs": [
74 | {
75 | "data": {
76 | "text/plain": [
77 | "(35000, 3)"
78 | ]
79 | },
80 | "execution_count": 4,
81 | "metadata": {},
82 | "output_type": "execute_result"
83 | }
84 | ],
85 | "source": [
86 | "df_train.shape"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": 5,
92 | "id": "8697ead4-4986-45b8-abe0-1fab35afc4ca",
93 | "metadata": {},
94 | "outputs": [
95 | {
96 | "data": {
97 | "text/plain": [
98 | "(10000, 3)"
99 | ]
100 | },
101 | "execution_count": 5,
102 | "metadata": {},
103 | "output_type": "execute_result"
104 | }
105 | ],
106 | "source": [
107 | "df_test.shape"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": 6,
113 | "id": "db443420-3875-4b24-adaa-9549aa98a536",
114 | "metadata": {},
115 | "outputs": [
116 | {
117 | "data": {
118 | "text/plain": [
119 | "array([5006, 4994])"
120 | ]
121 | },
122 | "execution_count": 6,
123 | "metadata": {},
124 | "output_type": "execute_result"
125 | }
126 | ],
127 | "source": [
128 | "bcnt = np.bincount(df_test[\"label\"].values)\n",
129 | "bcnt"
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": 7,
135 | "id": "3703d434-c14d-4562-bbe4-bd3841791235",
136 | "metadata": {},
137 | "outputs": [
138 | {
139 | "name": "stdout",
140 | "output_type": "stream",
141 | "text": [
142 | "Baseline accuracy: 0.5006\n"
143 | ]
144 | }
145 | ],
146 | "source": [
147 | "print(\"Baseline accuracy:\", np.max(bcnt)/ bcnt.sum())"
148 | ]
149 | }
150 | ],
151 | "metadata": {
152 | "kernelspec": {
153 | "display_name": "Python 3 (ipykernel)",
154 | "language": "python",
155 | "name": "python3"
156 | },
157 | "language_info": {
158 | "codemirror_mode": {
159 | "name": "ipython",
160 | "version": 3
161 | },
162 | "file_extension": ".py",
163 | "mimetype": "text/x-python",
164 | "name": "python",
165 | "nbconvert_exporter": "python",
166 | "pygments_lexer": "ipython3",
167 | "version": "3.10.6"
168 | }
169 | },
170 | "nbformat": 4,
171 | "nbformat_minor": 5
172 | }
173 |
--------------------------------------------------------------------------------
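The baseline notebook above boils down to a single computation: always predict the most frequent class, whose accuracy is the largest class count divided by the total. A minimal self-contained sketch (the helper name `majority_baseline` is ours, not part of the repo):

```python
import numpy as np

def majority_baseline(labels) -> float:
    """Accuracy of always predicting the most frequent class."""
    counts = np.bincount(np.asarray(labels))
    return counts.max() / counts.sum()

# Mirrors the notebook: test labels split 5006 vs. 4994 -> ~0.5006
print(majority_baseline([0] * 5006 + [1] * 4994))
```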
/0_some-concepts.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "33075d8c-6c74-46bd-a19e-89f75262ff72",
6 | "metadata": {},
7 | "source": [
8 | "# Compression"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 1,
14 | "id": "b3cf065e-0871-46dc-917a-637ac6590b31",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import gzip\n",
19 | "\n",
20 | "txt_1 = \"hello world\"\n",
21 | "txt_2 = \"some text some text some text\""
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 2,
27 | "id": "0ddb828f-2480-491d-ae66-c1a2dc902523",
28 | "metadata": {},
29 | "outputs": [
30 | {
31 | "data": {
32 | "text/plain": [
33 | "31"
34 | ]
35 | },
36 | "execution_count": 2,
37 | "metadata": {},
38 | "output_type": "execute_result"
39 | }
40 | ],
41 | "source": [
42 | "len(gzip.compress(txt_1.encode()))"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": 3,
48 | "id": "fac22ce9-a243-433a-81ed-0bad131ec7ae",
49 | "metadata": {},
50 | "outputs": [
51 | {
52 | "data": {
53 | "text/plain": [
54 | "33"
55 | ]
56 | },
57 | "execution_count": 3,
58 | "metadata": {},
59 | "output_type": "execute_result"
60 | }
61 | ],
62 | "source": [
63 | "len(gzip.compress(txt_2.encode()))"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": 4,
69 | "id": "c99131b6-9ef1-45de-af26-c47f525956d9",
70 | "metadata": {},
71 | "outputs": [
72 | {
73 | "data": {
74 | "text/plain": [
75 | "43"
76 | ]
77 | },
78 | "execution_count": 4,
79 | "metadata": {},
80 | "output_type": "execute_result"
81 | }
82 | ],
83 | "source": [
84 | "len(gzip.compress(\" \".join([txt_1, txt_2]).encode()))"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": 5,
90 | "id": "771a37e0-67fe-4274-962f-5a99911c451a",
91 | "metadata": {},
92 | "outputs": [
93 | {
94 | "data": {
95 | "text/plain": [
96 | "34"
97 | ]
98 | },
99 | "execution_count": 5,
100 | "metadata": {},
101 | "output_type": "execute_result"
102 | }
103 | ],
104 | "source": [
105 | "len(gzip.compress(\" \".join([txt_1, txt_1]).encode()))"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": 6,
111 | "id": "36303cb5-147c-41b0-9758-fb16753774ff",
112 | "metadata": {},
113 | "outputs": [
114 | {
115 | "data": {
116 | "text/plain": [
117 | "33"
118 | ]
119 | },
120 | "execution_count": 6,
121 | "metadata": {},
122 | "output_type": "execute_result"
123 | }
124 | ],
125 | "source": [
126 | "len(gzip.compress(\" \".join([txt_2, txt_2]).encode()))"
127 | ]
128 | },
129 | {
130 | "cell_type": "markdown",
131 | "id": "b9327048-17ef-4401-b3fd-f462f1e57888",
132 | "metadata": {},
133 | "source": [
134 | "# Tie breaking"
135 | ]
136 | },
137 | {
138 | "cell_type": "markdown",
139 | "id": "71b93107-32b8-4b14-b7b1-d98d6049f154",
140 | "metadata": {},
141 | "source": [
142 | "Original code always selects index with lowest label in case of a tie"
143 | ]
144 | },
145 | {
146 | "cell_type": "code",
147 | "execution_count": 7,
148 | "id": "97b3e05d-27da-4c77-b2d1-e97f28b11d15",
149 | "metadata": {},
150 | "outputs": [
151 | {
152 | "data": {
153 | "text/plain": [
154 | "0"
155 | ]
156 | },
157 | "execution_count": 7,
158 | "metadata": {},
159 | "output_type": "execute_result"
160 | }
161 | ],
162 | "source": [
163 | "top_k_class = [0, 1]\n",
164 | "max(set(top_k_class), key=top_k_class.count)"
165 | ]
166 | },
167 | {
168 | "cell_type": "code",
169 | "execution_count": 8,
170 | "id": "bbeba7c4-492e-438d-960d-f887fd4ac455",
171 | "metadata": {},
172 | "outputs": [
173 | {
174 | "data": {
175 | "text/plain": [
176 | "0"
177 | ]
178 | },
179 | "execution_count": 8,
180 | "metadata": {},
181 | "output_type": "execute_result"
182 | }
183 | ],
184 | "source": [
185 | "top_k_class = [1, 0]\n",
186 | "max(set(top_k_class), key=top_k_class.count)"
187 | ]
188 | },
189 | {
190 | "cell_type": "code",
191 | "execution_count": 9,
192 | "id": "c8acd613-1f2a-4759-955b-800513024d1d",
193 | "metadata": {},
194 | "outputs": [
195 | {
196 | "data": {
197 | "text/plain": [
198 | "0"
199 | ]
200 | },
201 | "execution_count": 9,
202 | "metadata": {},
203 | "output_type": "execute_result"
204 | }
205 | ],
206 | "source": [
207 | "top_k_class = [1, 0, 2]\n",
208 | "max(set(top_k_class), key=top_k_class.count)"
209 | ]
210 | },
211 | {
212 | "cell_type": "markdown",
213 | "id": "5a9f3761-e957-4f5e-acc9-aa13df736cf2",
214 | "metadata": {},
215 | "source": [
216 | "We can prevent this using Counter, which selects the first label in case of a tie. If labels are sorted by distance, we can ensure it's picking the closest neighbor in case of a tie, which is a more reasonable choice than always selecting the lowest-index class:"
217 | ]
218 | },
219 | {
220 | "cell_type": "code",
221 | "execution_count": 10,
222 | "id": "1bcc6ca8-1acb-48fa-9526-f470ea7da06f",
223 | "metadata": {},
224 | "outputs": [],
225 | "source": [
226 | "from collections import Counter"
227 | ]
228 | },
229 | {
230 | "cell_type": "code",
231 | "execution_count": 11,
232 | "id": "b40df133-47c8-4006-80a8-217e717d95d6",
233 | "metadata": {},
234 | "outputs": [
235 | {
236 | "data": {
237 | "text/plain": [
238 | "0"
239 | ]
240 | },
241 | "execution_count": 11,
242 | "metadata": {},
243 | "output_type": "execute_result"
244 | }
245 | ],
246 | "source": [
247 | "top_k_class = [0, 1]\n",
248 | "\n",
249 | "Counter(top_k_class).most_common()[0][0]"
250 | ]
251 | },
252 | {
253 | "cell_type": "code",
254 | "execution_count": 12,
255 | "id": "43e250e8-36db-4f9b-952f-35e4faf793c2",
256 | "metadata": {},
257 | "outputs": [
258 | {
259 | "data": {
260 | "text/plain": [
261 | "1"
262 | ]
263 | },
264 | "execution_count": 12,
265 | "metadata": {},
266 | "output_type": "execute_result"
267 | }
268 | ],
269 | "source": [
270 | "top_k_class = [1, 0]\n",
271 | "\n",
272 | "Counter(top_k_class).most_common()[0][0]"
273 | ]
274 | },
275 | {
276 | "cell_type": "code",
277 | "execution_count": 16,
278 | "id": "1a2ccec3-a2db-40b6-bddc-ff2cbe721b5c",
279 | "metadata": {},
280 | "outputs": [
281 | {
282 | "data": {
283 | "text/plain": [
284 | "1"
285 | ]
286 | },
287 | "execution_count": 16,
288 | "metadata": {},
289 | "output_type": "execute_result"
290 | }
291 | ],
292 | "source": [
293 | "top_k_class = [1, 2, 0]\n",
294 | "\n",
295 | "Counter(top_k_class).most_common()[0][0]"
296 | ]
297 | },
298 | {
299 | "cell_type": "code",
300 | "execution_count": 14,
301 | "id": "141f5817-2157-454c-82b7-b5024b2f4018",
302 | "metadata": {},
303 | "outputs": [
304 | {
305 | "data": {
306 | "text/plain": [
307 | "2"
308 | ]
309 | },
310 | "execution_count": 14,
311 | "metadata": {},
312 | "output_type": "execute_result"
313 | }
314 | ],
315 | "source": [
316 | "top_k_class = [1, 0, 2, 2]\n",
317 | "\n",
318 | "Counter(top_k_class).most_common()[0][0]"
319 | ]
320 | },
321 | {
322 | "cell_type": "markdown",
323 | "id": "e50515d5-b4cd-4c56-9a2c-9b0416367859",
324 | "metadata": {},
325 | "source": [
326 | "### Count vectors"
327 | ]
328 | },
329 | {
330 | "cell_type": "code",
331 | "execution_count": 19,
332 | "id": "efea1520-aed1-48b6-a9f5-6ac304162f52",
333 | "metadata": {},
334 | "outputs": [
335 | {
336 | "name": "stdout",
337 | "output_type": "stream",
338 | "text": [
339 | "[0. 0.75 0.25]\n",
340 | "[0. 0.75 0.25]\n",
341 | "[0. 0.75 0.25]\n"
342 | ]
343 | }
344 | ],
345 | "source": [
346 | "import numpy as np\n",
347 | "\n",
348 | "text_1 = np.array([0., 3., 1.]) \n",
349 | "text_2 = np.array([0., 3., 1.])\n",
350 | "\n",
351 | "text_1 /= np.sum(text_1)\n",
352 | "text_2 /= np.sum(text_2)\n",
353 | "\n",
354 | "print(text_1)\n",
355 | "print(text_2)\n",
356 | "\n",
357 | "added = text_1 + text_2\n",
358 | "\n",
359 | "print(added / np.sum(added))"
360 | ]
361 | },
362 | {
363 | "cell_type": "code",
364 | "execution_count": null,
365 | "id": "89f20d55-a460-4de3-ab5d-40322be8ddd1",
366 | "metadata": {},
367 | "outputs": [],
368 | "source": []
369 | }
370 | ],
371 | "metadata": {
372 | "kernelspec": {
373 | "display_name": "Python 3 (ipykernel)",
374 | "language": "python",
375 | "name": "python3"
376 | },
377 | "language_info": {
378 | "codemirror_mode": {
379 | "name": "ipython",
380 | "version": 3
381 | },
382 | "file_extension": ".py",
383 | "mimetype": "text/x-python",
384 | "name": "python",
385 | "nbconvert_exporter": "python",
386 | "pygments_lexer": "ipython3",
387 | "version": "3.10.6"
388 | }
389 | },
390 | "nbformat": 4,
391 | "nbformat_minor": 5
392 | }
393 |
--------------------------------------------------------------------------------
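The compression lengths explored above combine into the normalized compression distance (NCD) used by the later notebooks. A minimal sketch under the same gzip setup (the function names `clen` and `ncd` are ours, not from the repo):

```python
import gzip

def clen(text: str) -> int:
    """Length of the gzip-compressed text in bytes."""
    return len(gzip.compress(text.encode()))

def ncd(x: str, y: str) -> float:
    """Normalized compression distance: ~0 for redundant pairs, ~1 for unrelated ones."""
    cx, cy = clen(x), clen(y)
    cxy = clen(" ".join([x, y]))
    return (cxy - min(cx, cy)) / max(cx, cy)

print(ncd("hello world", "hello world"))                    # small: the second copy adds little
print(ncd("hello world", "some text some text some text"))  # larger: little shared structure
```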
/1_1_nn_plus_gzip_original.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "965aa954-7744-4ecb-8b38-a023f3c1b9af",
6 | "metadata": {},
7 | "source": [
8 | "# NN + Gzip on IMDB Movie Review Dataset"
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "id": "275f3c53-b5c7-4856-9a15-656a98b33fd8",
14 | "metadata": {},
15 | "source": [
16 | "# NN + Gzip on IMDB Movie Review Dataset\n",
17 | "\n",
18 | "Reimplementation of the pseudocode in the *\"Low-Resource\" Text Classification: A Parameter-Free Classification Method with Compressors* paper ([https://aclanthology.org/2023.findings-acl.426/](https://aclanthology.org/2023.findings-acl.426/)) \n",
19 | "\n",
20 | "\n",
21 | ""
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 1,
27 | "id": "54b93603-f41c-4016-87aa-59998990075c",
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "import gzip\n",
32 | "import os.path as op\n",
33 | "\n",
34 | "import numpy as np\n",
35 | "import pandas as pd\n",
36 | "\n",
37 | "from local_dataset_utilities import download_dataset, load_dataset_into_to_dataframe, partition_dataset"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 2,
43 | "id": "a03e71dc-e09e-4907-bc06-f8250b97005e",
44 | "metadata": {},
45 | "outputs": [],
46 | "source": [
47 | "if not op.isfile(\"train.csv\") and not op.isfile(\"val.csv\") and not op.isfile(\"test.csv\"):\n",
48 | " download_dataset()\n",
49 | "\n",
50 | " df = load_dataset_into_to_dataframe()\n",
51 | " partition_dataset(df)"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": 3,
57 | "id": "bfff472d-57c1-4310-8a1b-9a3e4339646a",
58 | "metadata": {},
59 | "outputs": [],
60 | "source": [
61 | "df_train = pd.read_csv(\"train.csv\")\n",
62 | "df_val = pd.read_csv(\"val.csv\")\n",
63 | "df_test = pd.read_csv(\"test.csv\")"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": 4,
69 | "id": "29fbe3bd-c873-4372-9c79-ea38b751608c",
70 | "metadata": {},
71 | "outputs": [
72 | {
73 | "name": "stderr",
74 | "output_type": "stream",
75 | "text": [
76 | "100%|██████████████████████████████████| 10000/10000 [21:16:23<00:00, 7.66s/it]\n"
77 | ]
78 | }
79 | ],
80 | "source": [
81 | "from tqdm import tqdm\n",
82 | "\n",
83 | "k = 2\n",
84 | "\n",
85 | "predicted_classes = []\n",
86 | "\n",
87 | "for row_test in tqdm(df_test.iterrows(), total=df_test.shape[0]):\n",
88 | " test_text = row_test[1][\"text\"]\n",
89 | " test_label = row_test[1][\"label\"]\n",
90 | " c_test_text = len(gzip.compress(test_text.encode()))\n",
91 | " distance_from_test_instance = []\n",
92 | " \n",
93 | " for row_train in df_train.iterrows():\n",
94 | " train_text = row_train[1][\"text\"]\n",
95 | " train_label = row_train[1][\"label\"]\n",
96 | " c_train_text = len(gzip.compress(train_text.encode()))\n",
97 | " \n",
98 | " train_plus_test = \" \".join([test_text, train_text])\n",
99 | " c_train_plus_test = len(gzip.compress(train_plus_test.encode()))\n",
100 | " \n",
101 | " ncd = ( (c_train_plus_test - min(c_train_text, c_test_text))\n",
102 | " / max(c_test_text, c_train_text) )\n",
103 | " distance_from_test_instance.append(ncd)\n",
104 | " \n",
105 | " sorted_idx = np.argsort(np.array(distance_from_test_instance))\n",
106 | " \n",
107 | " #top_k_class = list(df_train.iloc[sorted_idx[:k]][\"label\"].values)\n",
108 | " #predicted_class = max(set(top_k_class), key=top_k_class.count)\n",
109 | " top_k_class = df_train.iloc[sorted_idx[:k]][\"label\"].values\n",
110 | " predicted_class = np.argmax(np.bincount(top_k_class))\n",
111 | " \n",
112 | " predicted_classes.append(predicted_class)\n",
113 | " "
114 | ]
115 | },
116 | {
117 | "cell_type": "code",
118 | "execution_count": 5,
119 | "id": "1f44b0d2-8303-409c-b40f-7910ab415da1",
120 | "metadata": {},
121 | "outputs": [
122 | {
123 | "name": "stdout",
124 | "output_type": "stream",
125 | "text": [
126 | "Accuracy: 0.7005\n"
127 | ]
128 | }
129 | ],
130 | "source": [
131 | "print(\"Accuracy:\", np.mean(np.array(predicted_classes) == df_test[\"label\"].values))"
132 | ]
133 | }
134 | ],
135 | "metadata": {
136 | "kernelspec": {
137 | "display_name": "Python 3 (ipykernel)",
138 | "language": "python",
139 | "name": "python3"
140 | },
141 | "language_info": {
142 | "codemirror_mode": {
143 | "name": "ipython",
144 | "version": 3
145 | },
146 | "file_extension": ".py",
147 | "mimetype": "text/x-python",
148 | "name": "python",
149 | "nbconvert_exporter": "python",
150 | "pygments_lexer": "ipython3",
151 | "version": "3.10.6"
152 | }
153 | },
154 | "nbformat": 4,
155 | "nbformat_minor": 5
156 | }
157 |
--------------------------------------------------------------------------------
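With k=2, a 1-1 vote between the two nearest neighbors is common, and the `np.bincount`/`np.argmax` combination used above (like the paper's `max(set(...), key=...)`) resolves such ties to the smallest label rather than to the closest neighbor. A small sketch of that behavior:

```python
import numpy as np

top_k_class = np.array([1, 0])  # nearest neighbor has label 1, runner-up has label 0
print(np.argmax(np.bincount(top_k_class)))  # prints 0: the tie ignores distance order
```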
/1_2_caching-multiprocessing.py:
--------------------------------------------------------------------------------
1 | # Parallel processing version of 1_2_nn_plus_gzip_fix-tie-breaking.ipynb
2 | # On a 2020 MacBook Air, it runs about 4 times faster (~1 s/iter)
3 | # than the non-parallel version (~4 s/iter)
4 |
5 | # It should finish in about 2-3 h compared to ~12 h before
6 |
7 | from collections import Counter
8 | import gzip
9 | import multiprocessing as mp
10 | import os.path as op
11 |
12 | from joblib import Parallel, delayed
13 | import numpy as np
14 | import pandas as pd
15 | from tqdm import tqdm
16 |
17 | from local_dataset_utilities import download_dataset, load_dataset_into_to_dataframe, partition_dataset
18 |
19 |
20 | def process_dataset_subset(df_train_subset, test_text, c_test_text, d):
21 |
22 | distances_to_test = []
23 | for row_train in df_train_subset.iterrows():
24 | index = row_train[0]
25 | train_text = row_train[1]["text"]
26 | c_train_text = d[index]
27 |
28 | train_plus_test = " ".join([test_text, train_text])
29 | c_train_plus_test = len(gzip.compress(train_plus_test.encode()))
30 |
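        # Normalized compression distance (NCD): close to 0 when gzip can reuse
        # structure shared by the two texts, close to 1 when they are unrelated.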
31 | ncd = ( (c_train_plus_test - min(c_train_text, c_test_text))
32 | / max(c_test_text, c_train_text) )
33 |
34 | distances_to_test.append(ncd)
35 |
36 | return distances_to_test
37 |
38 |
39 | def divide_range_into_chunks(start, end, num_chunks):
40 | chunk_size = (end - start) // num_chunks
41 | ranges = [(i, i + chunk_size) for i in range(start, end, chunk_size)]
42 | ranges[-1] = (ranges[-1][0], end) # Ensure the last chunk includes the end
43 | return ranges
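# Note: when (end - start) is not evenly divisible by num_chunks, range() emits one
# extra start value, so this returns num_chunks + 1 ranges, e.g.
# divide_range_into_chunks(0, 10, 3) -> [(0, 3), (3, 6), (6, 9), (9, 10)];
# joblib simply schedules the extra task, so the results remain correct.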
44 |
45 |
46 | if __name__ == '__main__':
47 |
48 |     if not (op.isfile("train.csv") and op.isfile("val.csv") and op.isfile("test.csv")):
49 | download_dataset()
50 |
51 | df = load_dataset_into_to_dataframe()
52 | partition_dataset(df)
53 |
54 | df_train = pd.read_csv("train.csv")
55 | df_val = pd.read_csv("val.csv")
56 | df_test = pd.read_csv("test.csv")
57 |
58 | num_processes = mp.cpu_count()
59 | k = 2
60 | predicted_classes = []
61 |
62 | start = 0
63 | end = df_train.shape[0]
64 | ranges = divide_range_into_chunks(start, end, num_chunks=num_processes)
65 |
66 |
67 | # caching compressed training examples
68 | d = {}
69 | for i, row_train in enumerate(df_train.iterrows()):
70 | train_text = row_train[1]["text"]
71 | train_label = row_train[1]["label"]
72 | c_train_text = len(gzip.compress(train_text.encode()))
73 |
74 | d[i] = c_train_text
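        # The cache keys are positional indices; they match the default RangeIndex
        # that read_csv assigns to df_train, which process_dataset_subset relies on
        # when it looks up d[row_train[0]].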
75 |
76 | # main loop
77 | for row_test in tqdm(df_test.iterrows(), total=df_test.shape[0]):
78 |
79 | test_text = row_test[1]["text"]
80 | test_label = row_test[1]["label"]
81 | c_test_text = len(gzip.compress(test_text.encode()))
82 | all_train_distances_to_test = []
83 |
84 | # parallelize iteration over training set into num_processes chunks
85 | with Parallel(n_jobs=num_processes, backend="loky") as parallel:
86 |
87 | results = parallel(
88 | delayed(process_dataset_subset)(df_train[range_start:range_end], test_text, c_test_text, d)
89 | for range_start, range_end in ranges
90 | )
91 | for p in results:
92 | all_train_distances_to_test.extend(p)
93 |
94 |         sorted_idx = np.argsort(np.array(all_train_distances_to_test))
95 | top_k_class = np.array(df_train["label"])[sorted_idx[:k]]
96 | predicted_class = Counter(top_k_class).most_common()[0][0]
97 |
98 | predicted_classes.append(predicted_class)
99 |
100 | print("Accuracy:", np.mean(np.array(predicted_classes) == df_test["label"].values))
--------------------------------------------------------------------------------
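The chunk-and-scatter pattern in the script generalizes; here it is in isolation with toy data (a sketch, not part of the repo). `Parallel` returns results in task-submission order, which is why concatenating the per-chunk lists preserves the training-row order that the later `argsort` depends on:

```python
from joblib import Parallel, delayed

def work(chunk):
    return [x * x for x in chunk]

chunks = [range(0, 5), range(5, 10)]
results = Parallel(n_jobs=2, backend="loky")(delayed(work)(c) for c in chunks)
flat = [x for part in results for x in part]  # order matches the original range
print(flat)  # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
```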
/1_2_nn_plus_gzip_fix-tie-breaking.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "965aa954-7744-4ecb-8b38-a023f3c1b9af",
6 | "metadata": {},
7 | "source": [
8 | "# NN + Gzip on IMDB Movie Review Dataset\n",
9 | "\n",
10 | "Reimplementation of the pseudocode in the *\"Low-Resource\" Text Classification: A Parameter-Free Classification Method with Compressors* paper ([https://aclanthology.org/2023.findings-acl.426/](https://aclanthology.org/2023.findings-acl.426/)) \n",
11 | "\n",
12 | "\n",
13 | "\n",
14 | "\n",
15 | "\n",
16 | "**Modified to break ties based on choosing the closest neighbors** instead of the lowest index (see explanation in [0_some-concepts.ipynb](0_some-concepts.ipynb)).\n",
17 | "\n",
18 | "
"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 1,
24 | "id": "54b93603-f41c-4016-87aa-59998990075c",
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "import gzip\n",
29 | "import os.path as op\n",
30 | "\n",
31 | "import numpy as np\n",
32 | "import pandas as pd\n",
33 | "\n",
34 | "from local_dataset_utilities import download_dataset, load_dataset_into_to_dataframe, partition_dataset"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": 2,
40 | "id": "a03e71dc-e09e-4907-bc06-f8250b97005e",
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "if not op.isfile(\"train.csv\") and not op.isfile(\"val.csv\") and not op.isfile(\"test.csv\"):\n",
45 | " download_dataset()\n",
46 | "\n",
47 | " df = load_dataset_into_to_dataframe()\n",
48 | " partition_dataset(df)"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": 3,
54 | "id": "bfff472d-57c1-4310-8a1b-9a3e4339646a",
55 | "metadata": {},
56 | "outputs": [],
57 | "source": [
58 | "df_train = pd.read_csv(\"train.csv\")\n",
59 | "df_val = pd.read_csv(\"val.csv\")\n",
60 | "df_test = pd.read_csv(\"test.csv\")"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": 4,
66 | "id": "689b4673-db95-4dd6-ad3a-d2d31cad1a16",
67 | "metadata": {},
68 | "outputs": [
69 | {
70 | "name": "stderr",
71 | "output_type": "stream",
72 | "text": [
73 | "100%|██████████████████████████████████| 10000/10000 [11:40:18<00:00, 4.20s/it]"
74 | ]
75 | },
76 | {
77 | "name": "stdout",
78 | "output_type": "stream",
79 | "text": [
80 | "Accuracy: 0.7191\n"
81 | ]
82 | },
83 | {
84 | "name": "stderr",
85 | "output_type": "stream",
86 | "text": [
87 | "\n"
88 | ]
89 | }
90 | ],
91 | "source": [
92 | "from tqdm import tqdm\n",
93 | "from collections import Counter\n",
94 | "\n",
95 | "k = 2\n",
96 | "\n",
97 | "predicted_classes = []\n",
98 | "\n",
99 | "for row_test in tqdm(df_test.iterrows(), total=df_test.shape[0]):\n",
100 | " test_text = row_test[1][\"text\"]\n",
101 | " test_label = row_test[1][\"label\"]\n",
102 | " c_test_text = len(gzip.compress(test_text.encode()))\n",
103 | " distance_from_test_instance = []\n",
104 | " \n",
105 | " for row_train in df_train.iterrows():\n",
106 | " train_text = row_train[1][\"text\"]\n",
107 | " train_label = row_train[1][\"label\"]\n",
108 | " c_train_text = len(gzip.compress(train_text.encode()))\n",
109 | " \n",
110 | " train_plus_test = \" \".join([test_text, train_text])\n",
111 | " c_train_plus_test = len(gzip.compress(train_plus_test.encode()))\n",
112 | " \n",
113 | " ncd = ( (c_train_plus_test - min(c_train_text, c_test_text))\n",
114 | " / max(c_test_text, c_train_text) )\n",
115 | " distance_from_test_instance.append(ncd)\n",
116 | " \n",
117 | " sorted_idx = np.argsort(np.array(distance_from_test_instance))\n",
118 | " top_k_class = np.array(df_train[\"label\"])[sorted_idx[:k]]\n",
119 | " predicted_class = Counter(top_k_class).most_common()[0][0]\n",
120 | " \n",
121 | " predicted_classes.append(predicted_class)\n",
122 | " \n",
123 | "print(\"Accuracy:\", np.mean(np.array(predicted_classes) == df_test[\"label\"].values))"
124 | ]
125 | }
126 | ],
127 | "metadata": {
128 | "kernelspec": {
129 | "display_name": "Python 3 (ipykernel)",
130 | "language": "python",
131 | "name": "python3"
132 | },
133 | "language_info": {
134 | "codemirror_mode": {
135 | "name": "ipython",
136 | "version": 3
137 | },
138 | "file_extension": ".py",
139 | "mimetype": "text/x-python",
140 | "name": "python",
141 | "nbconvert_exporter": "python",
142 | "pygments_lexer": "ipython3",
143 | "version": "3.10.6"
144 | }
145 | },
146 | "nbformat": 4,
147 | "nbformat_minor": 5
148 | }
149 |
--------------------------------------------------------------------------------
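The accuracy gain over the original notebook (0.7191 vs. 0.7005) comes entirely from the tie-breaking change. Because `Counter.most_common` orders equal counts by first insertion, feeding it labels sorted by ascending distance resolves a tie in favor of the closest neighbor, as this small sketch illustrates:

```python
from collections import Counter

labels_by_distance = [1, 0]  # label 1 belongs to the nearest neighbor
print(Counter(labels_by_distance).most_common()[0][0])  # 1, not the lowest label
```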
/2_nn_countvecs.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "fb879955-ad5e-4c0d-a342-8772d119598e",
6 | "metadata": {},
7 | "source": [
8 | "# NN + Cosine Distance on IMDB Movie Review Dataset"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 1,
14 | "id": "836be7ab-cdc0-4376-ab23-27d54f486f39",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import gzip\n",
19 | "import os.path as op\n",
20 | "\n",
21 | "import numpy as np\n",
22 | "import pandas as pd\n",
23 | "\n",
24 | "from local_dataset_utilities import download_dataset, load_dataset_into_to_dataframe, partition_dataset"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": 2,
30 | "id": "e23aba95-b18a-411d-9ad2-152c06071575",
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "if not op.isfile(\"train.csv\") and not op.isfile(\"val.csv\") and not op.isfile(\"test.csv\"):\n",
35 | " download_dataset()\n",
36 | "\n",
37 | " df = load_dataset_into_to_dataframe()\n",
38 | " partition_dataset(df)"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": 3,
44 | "id": "9e01dd62-601f-4eb5-8a64-ed8fc39cd719",
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "df_train = pd.read_csv(\"train.csv\")\n",
49 | "df_val = pd.read_csv(\"val.csv\")\n",
50 | "df_test = pd.read_csv(\"test.csv\")"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": 4,
56 | "id": "14c67c25-6275-4ec0-9596-73a014adfc8f",
57 | "metadata": {},
58 | "outputs": [],
59 | "source": [
60 | "from sklearn.feature_extraction.text import CountVectorizer\n",
61 | "\n",
62 | "\n",
63 | "cv = CountVectorizer(lowercase=True, max_features=10_000, stop_words=\"english\")\n",
64 | "\n",
65 | "cv.fit(df_train[\"text\"])\n",
66 | "\n",
67 | "X_train = cv.transform(df_train[\"text\"])\n",
68 | "X_val = cv.transform(df_val[\"text\"])\n",
69 | "X_test = cv.transform(df_test[\"text\"])"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": 5,
75 | "id": "e22430d5-96e5-49ff-8aeb-1b7c5be57e26",
76 | "metadata": {},
77 | "outputs": [
78 | {
79 | "name": "stderr",
80 | "output_type": "stream",
81 | "text": [
82 | "100%|███████████████████████████████████| 10000/10000 [9:10:43<00:00, 3.30s/it]"
83 | ]
84 | },
85 | {
86 | "name": "stdout",
87 | "output_type": "stream",
88 | "text": [
89 | "Accuracy: 0.6801\n"
90 | ]
91 | },
92 | {
93 | "name": "stderr",
94 | "output_type": "stream",
95 | "text": [
96 | "\n"
97 | ]
98 | }
99 | ],
100 | "source": [
101 | "from collections import Counter\n",
102 | "from tqdm import tqdm\n",
103 | "from numpy.linalg import norm\n",
104 | "\n",
105 | "\n",
106 | "k = 2\n",
107 | "\n",
108 | "predicted_classes = []\n",
109 | "\n",
110 | "for i in tqdm(range(df_test.shape[0]), total=df_test.shape[0]):\n",
111 | "\n",
112 | " test_vec = X_test[i].toarray().reshape(-1)\n",
113 | " test_label = df_test.iloc[i][\"label\"]\n",
114 | " distance_from_test_instance = []\n",
115 | " \n",
116 | " for j in range(df_train.shape[0]):\n",
117 | " train_vec = X_train[j].toarray().reshape(-1)\n",
118 | " train_label = df_train.iloc[j][\"label\"]\n",
119 | " \n",
120 | " cosine = 1 - np.dot(test_vec, train_vec)/(norm(test_vec)*norm(train_vec))\n",
121 | " distance_from_test_instance.append(cosine)\n",
122 | " \n",
123 | " sorted_idx = np.argsort(np.array(distance_from_test_instance))\n",
124 | " top_k_class = np.array(df_train[\"label\"])[sorted_idx[:k]]\n",
125 | " predicted_class = Counter(top_k_class).most_common()[0][0]\n",
126 | " \n",
127 | " predicted_classes.append(predicted_class)\n",
128 | " \n",
129 | "print(\"Accuracy:\", np.mean(np.array(predicted_classes) == df_test[\"label\"].values))"
130 | ]
131 | }
132 | ],
133 | "metadata": {
134 | "kernelspec": {
135 | "display_name": "Python 3 (ipykernel)",
136 | "language": "python",
137 | "name": "python3"
138 | },
139 | "language_info": {
140 | "codemirror_mode": {
141 | "name": "ipython",
142 | "version": 3
143 | },
144 | "file_extension": ".py",
145 | "mimetype": "text/x-python",
146 | "name": "python",
147 | "nbconvert_exporter": "python",
148 | "pygments_lexer": "ipython3",
149 | "version": "3.10.6"
150 | }
151 | },
152 | "nbformat": 4,
153 | "nbformat_minor": 5
154 | }
155 |
--------------------------------------------------------------------------------
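The double loop above re-densifies rows and recomputes norms for every test/train pair, which is why it takes ~9 hours. A sketch of an equivalent vectorized variant (assuming the same `X_train`/`X_test` count vectors and labels as in the notebook; `knn_predict` is our own name) that computes all pairwise cosine distances with one sparse-aware matrix product:

```python
from collections import Counter

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def knn_predict(X_test, X_train, y_train, k=2):
    sims = cosine_similarity(X_test, X_train)    # shape (n_test, n_train), accepts sparse input
    sorted_idx = np.argsort(1.0 - sims, axis=1)  # cosine distances, ascending
    preds = [Counter(y_train[row]).most_common()[0][0] for row in sorted_idx[:, :k]]
    return np.array(preds)

# Usage sketch: predicted = knn_predict(X_test, X_train, df_train["label"].values)
```

Note the full distance matrix is materialized in memory here (10,000 x 35,000 floats), which trades RAM for a large speedup over the per-pair Python loop.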
/3_distilbert.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "3c5d72f4",
6 | "metadata": {},
7 | "source": [
8 | "# DistilBERT Finetuning"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 1,
14 | "id": "6fd9cda8",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "# pip install transformers"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 2,
24 | "id": "92ea5612",
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "# pip install datasets"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 3,
34 | "id": "fe7191cf-62ed-4793-8358-bee70b233d05",
35 | "metadata": {},
36 | "outputs": [],
37 | "source": [
38 | "# pip install lightning"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": 4,
44 | "id": "033b75c5",
45 | "metadata": {},
46 | "outputs": [
47 | {
48 | "name": "stdout",
49 | "output_type": "stream",
50 | "text": [
51 | "torch : 2.0.0\n",
52 | "transformers: 4.27.4\n",
53 | "datasets : 2.11.0\n",
54 | "lightning : 2.0.1\n",
55 | "\n",
56 | "conda environment: finetuning-blog\n",
57 | "\n"
58 | ]
59 | }
60 | ],
61 | "source": [
62 | "%load_ext watermark\n",
63 | "%watermark --conda -p torch,transformers,datasets,lightning"
64 | ]
65 | },
66 | {
67 | "cell_type": "markdown",
68 | "id": "09213821-b2b4-402e-adf8-7c7fe4ec57cb",
69 | "metadata": {
70 | "tags": []
71 | },
72 | "source": [
73 | "# 1 Loading the dataset into DataFrames"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": 5,
79 | "id": "e39e2228-5f0b-4fb9-b762-df26c2052b45",
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "# pip install datasets\n",
84 | "\n",
85 | "import os.path as op\n",
86 | "\n",
87 | "from datasets import load_dataset\n",
88 | "\n",
89 | "import lightning as L\n",
90 | "from lightning.pytorch.loggers import CSVLogger\n",
91 | "from lightning.pytorch.callbacks import ModelCheckpoint\n",
92 | "\n",
93 | "import numpy as np\n",
94 | "import pandas as pd\n",
95 | "import torch\n",
96 | "\n",
97 | "from sklearn.feature_extraction.text import CountVectorizer\n",
98 | "\n",
99 | "from local_dataset_utilities import download_dataset, load_dataset_into_to_dataframe, partition_dataset\n",
100 | "from local_dataset_utilities import IMDBDataset"
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": 6,
106 | "id": "fb31ac90-9e3a-41d0-baf1-8e613043924b",
107 | "metadata": {},
108 | "outputs": [
109 | {
110 | "name": "stderr",
111 | "output_type": "stream",
112 | "text": [
113 | "100%|███████████████████████████████████████████| 50000/50000 [00:24<00:00, 2023.24it/s]\n"
114 | ]
115 | },
116 | {
117 | "name": "stdout",
118 | "output_type": "stream",
119 | "text": [
120 | "Class distribution:\n"
121 | ]
122 | }
123 | ],
124 | "source": [
125 | "download_dataset()\n",
126 | "\n",
127 | "df = load_dataset_into_to_dataframe()\n",
128 | "partition_dataset(df)"
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": 7,
134 | "id": "221f30a1-b433-4304-a18d-8d03abd42b58",
135 | "metadata": {},
136 | "outputs": [],
137 | "source": [
138 | "df_train = pd.read_csv(\"train.csv\")\n",
139 | "df_val = pd.read_csv(\"val.csv\")\n",
140 | "df_test = pd.read_csv(\"test.csv\")"
141 | ]
142 | },
143 | {
144 | "cell_type": "markdown",
145 | "id": "876736c1-ae27-491c-850b-050507fa02b5",
146 | "metadata": {},
147 | "source": [
148 | "# 2 Tokenization and Numericalization"
149 | ]
150 | },
151 | {
152 | "cell_type": "markdown",
153 | "id": "afe0cca0-bac4-49ed-982c-14c998e578d1",
154 | "metadata": {},
155 | "source": [
156 | "**Load the dataset via `load_dataset`**"
157 | ]
158 | },
159 | {
160 | "cell_type": "code",
161 | "execution_count": 8,
162 | "id": "a1aa66c7",
163 | "metadata": {},
164 | "outputs": [
165 | {
166 | "name": "stdout",
167 | "output_type": "stream",
168 | "text": [
169 | "Downloading and preparing dataset csv/default to /home/sebastian/.cache/huggingface/datasets/csv/default-3e50991f5e7f1651/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1...\n"
170 | ]
171 | },
172 | {
173 | "data": {
174 | "application/vnd.jupyter.widget-view+json": {
175 | "model_id": "9d9091423f5c4c7f8ce30a4208df97ce",
176 | "version_major": 2,
177 | "version_minor": 0
178 | },
179 | "text/plain": [
180 | "Downloading data files: 0%| | 0/3 [00:00, ?it/s]"
181 | ]
182 | },
183 | "metadata": {},
184 | "output_type": "display_data"
185 | },
186 | {
187 | "data": {
188 | "application/vnd.jupyter.widget-view+json": {
189 | "model_id": "c620f73d0ffe4e94a4336899859b5003",
190 | "version_major": 2,
191 | "version_minor": 0
192 | },
193 | "text/plain": [
194 | "Extracting data files: 0%| | 0/3 [00:00, ?it/s]"
195 | ]
196 | },
197 | "metadata": {},
198 | "output_type": "display_data"
199 | },
200 | {
201 | "data": {
202 | "application/vnd.jupyter.widget-view+json": {
203 | "model_id": "",
204 | "version_major": 2,
205 | "version_minor": 0
206 | },
207 | "text/plain": [
208 | "Generating train split: 0 examples [00:00, ? examples/s]"
209 | ]
210 | },
211 | "metadata": {},
212 | "output_type": "display_data"
213 | },
214 | {
215 | "data": {
216 | "application/vnd.jupyter.widget-view+json": {
217 | "model_id": "",
218 | "version_major": 2,
219 | "version_minor": 0
220 | },
221 | "text/plain": [
222 | "Generating validation split: 0 examples [00:00, ? examples/s]"
223 | ]
224 | },
225 | "metadata": {},
226 | "output_type": "display_data"
227 | },
228 | {
229 | "data": {
230 | "application/vnd.jupyter.widget-view+json": {
231 | "model_id": "",
232 | "version_major": 2,
233 | "version_minor": 0
234 | },
235 | "text/plain": [
236 | "Generating test split: 0 examples [00:00, ? examples/s]"
237 | ]
238 | },
239 | "metadata": {},
240 | "output_type": "display_data"
241 | },
242 | {
243 | "name": "stdout",
244 | "output_type": "stream",
245 | "text": [
246 | "Dataset csv downloaded and prepared to /home/sebastian/.cache/huggingface/datasets/csv/default-3e50991f5e7f1651/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1. Subsequent calls will reuse this data.\n"
247 | ]
248 | },
249 | {
250 | "data": {
251 | "application/vnd.jupyter.widget-view+json": {
252 | "model_id": "dff20829bf3c4a51ac3e59007045d98c",
253 | "version_major": 2,
254 | "version_minor": 0
255 | },
256 | "text/plain": [
257 | " 0%| | 0/3 [00:00, ?it/s]"
258 | ]
259 | },
260 | "metadata": {},
261 | "output_type": "display_data"
262 | },
263 | {
264 | "name": "stdout",
265 | "output_type": "stream",
266 | "text": [
267 | "DatasetDict({\n",
268 | " train: Dataset({\n",
269 | " features: ['index', 'text', 'label'],\n",
270 | " num_rows: 35000\n",
271 | " })\n",
272 | " validation: Dataset({\n",
273 | " features: ['index', 'text', 'label'],\n",
274 | " num_rows: 5000\n",
275 | " })\n",
276 | " test: Dataset({\n",
277 | " features: ['index', 'text', 'label'],\n",
278 | " num_rows: 10000\n",
279 | " })\n",
280 | "})\n"
281 | ]
282 | }
283 | ],
284 | "source": [
285 | "imdb_dataset = load_dataset(\n",
286 | " \"csv\",\n",
287 | " data_files={\n",
288 | " \"train\": \"train.csv\",\n",
289 | " \"validation\": \"val.csv\",\n",
290 | " \"test\": \"test.csv\",\n",
291 | " },\n",
292 | ")\n",
293 | "\n",
294 | "print(imdb_dataset)"
295 | ]
296 | },
297 | {
298 | "cell_type": "markdown",
299 | "id": "8b201159-f3fa-4649-8076-eff8bc5535d3",
300 | "metadata": {},
301 | "source": [
302 | "**Tokenize the dataset**"
303 | ]
304 | },
305 | {
306 | "cell_type": "code",
307 | "execution_count": 9,
308 | "id": "5ea762ba",
309 | "metadata": {},
310 | "outputs": [
311 | {
312 | "name": "stdout",
313 | "output_type": "stream",
314 | "text": [
315 | "Tokenizer input max length: 512\n",
316 | "Tokenizer vocabulary size: 30522\n"
317 | ]
318 | }
319 | ],
320 | "source": [
321 | "from transformers import AutoTokenizer\n",
322 | "\n",
323 | "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased\")\n",
324 | "print(\"Tokenizer input max length:\", tokenizer.model_max_length)\n",
325 | "print(\"Tokenizer vocabulary size:\", tokenizer.vocab_size)"
326 | ]
327 | },
328 | {
329 | "cell_type": "code",
330 | "execution_count": 10,
331 | "id": "8432c15c",
332 | "metadata": {},
333 | "outputs": [],
334 | "source": [
335 | "def tokenize_text(batch):\n",
336 | " return tokenizer(batch[\"text\"], truncation=True, padding=True)"
337 | ]
338 | },
339 | {
340 | "cell_type": "code",
341 | "execution_count": 11,
342 | "id": "0bb392cf",
343 | "metadata": {},
344 | "outputs": [
345 | {
346 | "data": {
347 | "application/vnd.jupyter.widget-view+json": {
348 | "model_id": "",
349 | "version_major": 2,
350 | "version_minor": 0
351 | },
352 | "text/plain": [
353 | "Map: 0%| | 0/35000 [00:00, ? examples/s]"
354 | ]
355 | },
356 | "metadata": {},
357 | "output_type": "display_data"
358 | },
359 | {
360 | "data": {
361 | "application/vnd.jupyter.widget-view+json": {
362 | "model_id": "",
363 | "version_major": 2,
364 | "version_minor": 0
365 | },
366 | "text/plain": [
367 | "Map: 0%| | 0/5000 [00:00, ? examples/s]"
368 | ]
369 | },
370 | "metadata": {},
371 | "output_type": "display_data"
372 | },
373 | {
374 | "data": {
375 | "application/vnd.jupyter.widget-view+json": {
376 | "model_id": "",
377 | "version_major": 2,
378 | "version_minor": 0
379 | },
380 | "text/plain": [
381 | "Map: 0%| | 0/10000 [00:00, ? examples/s]"
382 | ]
383 | },
384 | "metadata": {},
385 | "output_type": "display_data"
386 | }
387 | ],
388 | "source": [
389 | "imdb_tokenized = imdb_dataset.map(tokenize_text, batched=True, batch_size=None)"
390 | ]
391 | },
392 | {
393 | "cell_type": "code",
394 | "execution_count": 12,
395 | "id": "6d4103c3",
396 | "metadata": {},
397 | "outputs": [],
398 | "source": [
399 | "del imdb_dataset"
400 | ]
401 | },
402 | {
403 | "cell_type": "code",
404 | "execution_count": 13,
405 | "id": "89ef894c-978f-47f2-9d61-cb6a9f38e745",
406 | "metadata": {},
407 | "outputs": [],
408 | "source": [
409 | "imdb_tokenized.set_format(\"torch\", columns=[\"input_ids\", \"attention_mask\", \"label\"])"
410 | ]
411 | },
412 | {
413 | "cell_type": "code",
414 | "execution_count": 14,
415 | "id": "0ea67091-aeb7-46c1-871f-638ce58d8a0e",
416 | "metadata": {},
417 | "outputs": [],
418 | "source": [
419 | "import os\n",
420 | "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\""
421 | ]
422 | },
423 | {
424 | "cell_type": "markdown",
425 | "id": "7ff16488-abe6-48af-9b03-868b457b0ea3",
426 | "metadata": {},
427 | "source": [
428 | "# 3 Set Up DataLoaders"
429 | ]
430 | },
431 | {
432 | "cell_type": "code",
433 | "execution_count": 15,
434 | "id": "0807b068-7d8f-4055-a26a-177e07dea4c7",
435 | "metadata": {},
436 | "outputs": [],
437 | "source": [
438 | "from torch.utils.data import DataLoader, Dataset\n",
439 | "\n",
440 | "\n",
441 | "class IMDBDataset(Dataset):\n",
442 | " def __init__(self, dataset_dict, partition_key=\"train\"):\n",
443 | " self.partition = dataset_dict[partition_key]\n",
444 | "\n",
445 | " def __getitem__(self, index):\n",
446 | " return self.partition[index]\n",
447 | "\n",
448 | " def __len__(self):\n",
449 | " return self.partition.num_rows"
450 | ]
451 | },
452 | {
453 | "cell_type": "code",
454 | "execution_count": 16,
455 | "id": "90cb08f3-ef77-4351-8b19-42d99dd24f98",
456 | "metadata": {},
457 | "outputs": [],
458 | "source": [
459 | "train_dataset = IMDBDataset(imdb_tokenized, partition_key=\"train\")\n",
460 | "val_dataset = IMDBDataset(imdb_tokenized, partition_key=\"validation\")\n",
461 | "test_dataset = IMDBDataset(imdb_tokenized, partition_key=\"test\")\n",
462 | "\n",
463 | "train_loader = DataLoader(\n",
464 | " dataset=train_dataset,\n",
465 | " batch_size=12,\n",
466 | " shuffle=True, \n",
467 | " num_workers=4\n",
468 | ")\n",
469 | "\n",
470 | "val_loader = DataLoader(\n",
471 | " dataset=val_dataset,\n",
472 | " batch_size=12,\n",
473 | " num_workers=4\n",
474 | ")\n",
475 | "\n",
476 | "test_loader = DataLoader(\n",
477 | " dataset=test_dataset,\n",
478 | " batch_size=12,\n",
479 | " num_workers=4\n",
480 | ")"
481 | ]
482 | },
483 | {
484 | "cell_type": "markdown",
485 | "id": "78e774ab-45a0-4c48-ad61-a3d0e1927ef4",
486 | "metadata": {},
487 | "source": [
488 | "# 4 Initializing DistilBERT"
489 | ]
490 | },
491 | {
492 | "cell_type": "code",
493 | "execution_count": 17,
494 | "id": "dc28ddbe-1a96-4c24-9f5c-40ffdca4a572",
495 | "metadata": {},
496 | "outputs": [
497 | {
498 | "name": "stderr",
499 | "output_type": "stream",
500 | "text": [
501 | "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_projector.bias']\n",
502 | "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
503 | "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
504 | "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight', 'classifier.bias']\n",
505 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
506 | ]
507 | }
508 | ],
509 | "source": [
510 | "from transformers import AutoModelForSequenceClassification\n",
511 | "\n",
512 | "model = AutoModelForSequenceClassification.from_pretrained(\n",
513 | " \"distilbert-base-uncased\", num_labels=2)"
514 | ]
515 | },
516 | {
517 | "cell_type": "markdown",
518 | "id": "def1cf25-0a7d-4bb2-9419-b7a8fe1c1eab",
519 | "metadata": {},
520 | "source": [
521 | "## 5 Finetuning"
522 | ]
523 | },
524 | {
525 | "cell_type": "markdown",
526 | "id": "534f7a59-2c86-4895-ad7c-2cdd675b003a",
527 | "metadata": {},
528 | "source": [
529 | "**Wrap in LightningModule for Training**"
530 | ]
531 | },
532 | {
533 | "cell_type": "code",
534 | "execution_count": 18,
535 | "id": "9f2c474d",
536 | "metadata": {},
537 | "outputs": [],
538 | "source": [
539 | "import lightning as L\n",
540 | "import torch\n",
541 | "import torchmetrics\n",
542 | "\n",
543 | "\n",
544 | "class CustomLightningModule(L.LightningModule):\n",
545 | " def __init__(self, model, learning_rate=5e-5):\n",
546 | " super().__init__()\n",
547 | "\n",
548 | " self.learning_rate = learning_rate\n",
549 | " self.model = model\n",
550 | "\n",
551 | " self.val_acc = torchmetrics.Accuracy(task=\"multiclass\", num_classes=2)\n",
552 | " self.test_acc = torchmetrics.Accuracy(task=\"multiclass\", num_classes=2)\n",
553 | "\n",
554 | " def forward(self, input_ids, attention_mask, labels):\n",
555 | " return self.model(input_ids, attention_mask=attention_mask, labels=labels)\n",
556 | " \n",
557 | " def training_step(self, batch, batch_idx):\n",
558 | " outputs = self(batch[\"input_ids\"], attention_mask=batch[\"attention_mask\"],\n",
559 | " labels=batch[\"label\"]) \n",
560 | " self.log(\"train_loss\", outputs[\"loss\"])\n",
561 | " return outputs[\"loss\"] # this is passed to the optimizer for training\n",
562 | "\n",
563 | " def validation_step(self, batch, batch_idx):\n",
564 | " outputs = self(batch[\"input_ids\"], attention_mask=batch[\"attention_mask\"],\n",
565 | " labels=batch[\"label\"]) \n",
566 | " self.log(\"val_loss\", outputs[\"loss\"], prog_bar=True)\n",
567 | " \n",
568 | " logits = outputs[\"logits\"]\n",
569 | " predicted_labels = torch.argmax(logits, 1)\n",
570 | " self.val_acc(predicted_labels, batch[\"label\"])\n",
571 | " self.log(\"val_acc\", self.val_acc, prog_bar=True)\n",
572 | " \n",
573 | " def test_step(self, batch, batch_idx):\n",
574 | " outputs = self(batch[\"input_ids\"], attention_mask=batch[\"attention_mask\"],\n",
575 | " labels=batch[\"label\"]) \n",
576 | " \n",
577 | " logits = outputs[\"logits\"]\n",
578 | " predicted_labels = torch.argmax(logits, 1)\n",
579 | " self.test_acc(predicted_labels, batch[\"label\"])\n",
580 | " self.log(\"accuracy\", self.test_acc, prog_bar=True)\n",
581 | "\n",
582 | " def configure_optimizers(self):\n",
583 | " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
584 | " return optimizer\n",
585 | " \n",
586 | "\n",
587 | "lightning_model = CustomLightningModule(model)"
588 | ]
589 | },
590 | {
591 | "cell_type": "code",
592 | "execution_count": 19,
593 | "id": "e6dab813-e1fc-47cd-87a1-5eb8070699c6",
594 | "metadata": {},
595 | "outputs": [],
596 | "source": [
597 | "from lightning.pytorch.callbacks import ModelCheckpoint\n",
598 | "from lightning.pytorch.loggers import CSVLogger\n",
599 | "\n",
600 | "\n",
601 | "callbacks = [\n",
602 | " ModelCheckpoint(\n",
603 | " save_top_k=1, mode=\"max\", monitor=\"val_acc\"\n",
604 | " ) # save top 1 model\n",
605 | "]\n",
606 | "logger = CSVLogger(save_dir=\"logs/\", name=\"my-model\")"
607 | ]
608 | },
609 | {
610 | "cell_type": "code",
611 | "execution_count": 20,
612 | "id": "492aa043-02da-459e-a266-091b34254ac6",
613 | "metadata": {},
614 | "outputs": [
615 | {
616 | "name": "stderr",
617 | "output_type": "stream",
618 | "text": [
619 | "Using 16bit Automatic Mixed Precision (AMP)\n",
620 | "GPU available: True (cuda), used: True\n",
621 | "TPU available: False, using: 0 TPU cores\n",
622 | "IPU available: False, using: 0 IPUs\n",
623 | "HPU available: False, using: 0 HPUs\n"
624 | ]
625 | }
626 | ],
627 | "source": [
628 | "trainer = L.Trainer(\n",
629 | " max_epochs=3,\n",
630 | " callbacks=callbacks,\n",
631 | " accelerator=\"gpu\",\n",
632 | " precision=\"16-mixed\",\n",
633 | " devices=[2],\n",
634 | " logger=logger,\n",
635 | " log_every_n_steps=10,\n",
636 | ")"
637 | ]
638 | },
639 | {
640 | "cell_type": "code",
641 | "execution_count": 21,
642 | "id": "f18bf9b5-d247-405f-86c4-513a52238a14",
643 | "metadata": {},
644 | "outputs": [
645 | {
646 | "name": "stderr",
647 | "output_type": "stream",
648 | "text": [
649 | "You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
650 | "/home/sebastian/miniforge3/envs/finetuning-blog/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:612: UserWarning: Checkpoint directory logs/my-model/version_0/checkpoints exists and is not empty.\n",
651 | " rank_zero_warn(f\"Checkpoint directory {dirpath} exists and is not empty.\")\n",
652 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]\n",
653 | "\n",
654 | " | Name | Type | Params\n",
655 | "-----------------------------------------------------------------\n",
656 | "0 | model | DistilBertForSequenceClassification | 67.0 M\n",
657 | "1 | val_acc | MulticlassAccuracy | 0 \n",
658 | "2 | test_acc | MulticlassAccuracy | 0 \n",
659 | "-----------------------------------------------------------------\n",
660 | "67.0 M Trainable params\n",
661 | "0 Non-trainable params\n",
662 | "67.0 M Total params\n",
663 | "267.820 Total estimated model params size (MB)\n",
664 | "/home/sebastian/miniforge3/envs/finetuning-blog/lib/python3.9/site-packages/lightning/fabric/loggers/csv_logs.py:188: UserWarning: Experiment logs directory logs/my-model/version_0 exists and is not empty. Previous log files in this directory will be deleted when the new ones are saved!\n",
665 | " rank_zero_warn(\n"
666 | ]
667 | },
668 | {
669 | "data": {
670 | "application/vnd.jupyter.widget-view+json": {
671 | "model_id": "",
672 | "version_major": 2,
673 | "version_minor": 0
674 | },
675 | "text/plain": [
676 | "Sanity Checking: 0it [00:00, ?it/s]"
677 | ]
678 | },
679 | "metadata": {},
680 | "output_type": "display_data"
681 | },
682 | {
683 | "data": {
684 | "application/vnd.jupyter.widget-view+json": {
685 | "model_id": "2964f4eb454849f0ab94c10e80c6e947",
686 | "version_major": 2,
687 | "version_minor": 0
688 | },
689 | "text/plain": [
690 | "Training: 0it [00:00, ?it/s]"
691 | ]
692 | },
693 | "metadata": {},
694 | "output_type": "display_data"
695 | },
696 | {
697 | "data": {
698 | "application/vnd.jupyter.widget-view+json": {
699 | "model_id": "",
700 | "version_major": 2,
701 | "version_minor": 0
702 | },
703 | "text/plain": [
704 | "Validation: 0it [00:00, ?it/s]"
705 | ]
706 | },
707 | "metadata": {},
708 | "output_type": "display_data"
709 | },
710 | {
711 | "data": {
712 | "application/vnd.jupyter.widget-view+json": {
713 | "model_id": "",
714 | "version_major": 2,
715 | "version_minor": 0
716 | },
717 | "text/plain": [
718 | "Validation: 0it [00:00, ?it/s]"
719 | ]
720 | },
721 | "metadata": {},
722 | "output_type": "display_data"
723 | },
724 | {
725 | "data": {
726 | "application/vnd.jupyter.widget-view+json": {
727 | "model_id": "",
728 | "version_major": 2,
729 | "version_minor": 0
730 | },
731 | "text/plain": [
732 | "Validation: 0it [00:00, ?it/s]"
733 | ]
734 | },
735 | "metadata": {},
736 | "output_type": "display_data"
737 | },
738 | {
739 | "name": "stderr",
740 | "output_type": "stream",
741 | "text": [
742 | "`Trainer.fit` stopped: `max_epochs=3` reached.\n"
743 | ]
744 | },
745 | {
746 | "name": "stdout",
747 | "output_type": "stream",
748 | "text": [
749 | "Time elapsed 7.21 min\n"
750 | ]
751 | }
752 | ],
753 | "source": [
754 | "import time\n",
755 | "start = time.time()\n",
756 | "\n",
757 | "trainer.fit(model=lightning_model,\n",
758 | " train_dataloaders=train_loader,\n",
759 | " val_dataloaders=val_loader)\n",
760 | "\n",
761 | "end = time.time()\n",
762 | "elapsed = end - start\n",
763 | "print(f\"Time elapsed {elapsed/60:.2f} min\")"
764 | ]
765 | },
766 | {
767 | "cell_type": "code",
768 | "execution_count": 22,
769 | "id": "d795778a-70d2-4b04-96fb-598eccbcd1be",
770 | "metadata": {},
771 | "outputs": [
772 | {
773 | "name": "stderr",
774 | "output_type": "stream",
775 | "text": [
776 | "You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
777 | "Restoring states from the checkpoint path at logs/my-model/version_0/checkpoints/epoch=2-step=8751-v1.ckpt\n",
778 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]\n",
779 | "Loaded model weights from the checkpoint at logs/my-model/version_0/checkpoints/epoch=2-step=8751-v1.ckpt\n",
780 | "/home/sebastian/miniforge3/envs/finetuning-blog/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:478: PossibleUserWarning: Your `test_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.\n",
781 | " rank_zero_warn(\n"
782 | ]
783 | },
784 | {
785 | "data": {
786 | "application/vnd.jupyter.widget-view+json": {
787 | "model_id": "8b4da131e12047d1a5405c80a54bd209",
788 | "version_major": 2,
789 | "version_minor": 0
790 | },
791 | "text/plain": [
792 | "Testing: 0it [00:00, ?it/s]"
793 | ]
794 | },
795 | "metadata": {},
796 | "output_type": "display_data"
797 | },
798 | {
799 | "data": {
800 | "text/html": [
801 | "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
802 | "┃        Test metric        ┃       DataLoader 0        ┃\n",
803 | "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
804 | "│         accuracy          │    0.9919999837875366     │\n",
805 | "└───────────────────────────┴───────────────────────────┘\n",
806 | "\n"
807 | ],
808 | "text/plain": [
809 | "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
810 | "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
811 | "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
812 | "│\u001b[36m \u001b[0m\u001b[36m        accuracy         \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.9919999837875366    \u001b[0m\u001b[35m \u001b[0m│\n",
813 | "└───────────────────────────┴───────────────────────────┘\n"
814 | ]
815 | },
816 | "metadata": {},
817 | "output_type": "display_data"
818 | },
819 | {
820 | "data": {
821 | "text/plain": [
822 | "[{'accuracy': 0.9919999837875366}]"
823 | ]
824 | },
825 | "execution_count": 22,
826 | "metadata": {},
827 | "output_type": "execute_result"
828 | }
829 | ],
830 | "source": [
831 | "trainer.test(lightning_model, dataloaders=train_loader, ckpt_path=\"best\")"
832 | ]
833 | },
834 | {
835 | "cell_type": "code",
836 | "execution_count": 23,
837 | "id": "10ca0af1-106e-4ef7-9793-478d580af827",
838 | "metadata": {},
839 | "outputs": [
840 | {
841 | "name": "stderr",
842 | "output_type": "stream",
843 | "text": [
844 | "You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
845 | "Restoring states from the checkpoint path at logs/my-model/version_0/checkpoints/epoch=2-step=8751-v1.ckpt\n",
846 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]\n",
847 | "Loaded model weights from the checkpoint at logs/my-model/version_0/checkpoints/epoch=2-step=8751-v1.ckpt\n"
848 | ]
849 | },
850 | {
851 | "data": {
852 | "application/vnd.jupyter.widget-view+json": {
853 | "model_id": "1b3d7837d2c149c0ad84347b03c0179e",
854 | "version_major": 2,
855 | "version_minor": 0
856 | },
857 | "text/plain": [
858 | "Testing: 0it [00:00, ?it/s]"
859 | ]
860 | },
861 | "metadata": {},
862 | "output_type": "display_data"
863 | },
864 | {
865 | "data": {
866 | "text/html": [
867 | "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
868 | "┃        Test metric        ┃       DataLoader 0        ┃\n",
869 | "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
870 | "│         accuracy          │    0.9251999855041504     │\n",
871 | "└───────────────────────────┴───────────────────────────┘\n",
872 | "\n"
873 | ],
874 | "text/plain": [
875 | "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
876 | "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
877 | "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
878 | "│\u001b[36m \u001b[0m\u001b[36m        accuracy         \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.9251999855041504    \u001b[0m\u001b[35m \u001b[0m│\n",
879 | "└───────────────────────────┴───────────────────────────┘\n"
880 | ]
881 | },
882 | "metadata": {},
883 | "output_type": "display_data"
884 | },
885 | {
886 | "data": {
887 | "text/plain": [
888 | "[{'accuracy': 0.9251999855041504}]"
889 | ]
890 | },
891 | "execution_count": 23,
892 | "metadata": {},
893 | "output_type": "execute_result"
894 | }
895 | ],
896 | "source": [
897 | "trainer.test(lightning_model, dataloaders=val_loader, ckpt_path=\"best\")"
898 | ]
899 | },
900 | {
901 | "cell_type": "code",
902 | "execution_count": 24,
903 | "id": "eeb92de4-d483-4627-b9f3-f0bba0cddd9c",
904 | "metadata": {},
905 | "outputs": [
906 | {
907 | "name": "stderr",
908 | "output_type": "stream",
909 | "text": [
910 | "You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
911 | "Restoring states from the checkpoint path at logs/my-model/version_0/checkpoints/epoch=2-step=8751-v1.ckpt\n",
912 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]\n",
913 | "Loaded model weights from the checkpoint at logs/my-model/version_0/checkpoints/epoch=2-step=8751-v1.ckpt\n"
914 | ]
915 | },
916 | {
917 | "data": {
918 | "application/vnd.jupyter.widget-view+json": {
919 | "model_id": "3a05c398964e469c928cac221541e4fd",
920 | "version_major": 2,
921 | "version_minor": 0
922 | },
923 | "text/plain": [
924 | "Testing: 0it [00:00, ?it/s]"
925 | ]
926 | },
927 | "metadata": {},
928 | "output_type": "display_data"
929 | },
930 | {
931 | "data": {
932 | "text/html": [
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", 934 | "┃ Test metric ┃ DataLoader 0 ┃\n", 935 | "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", 936 | "│ accuracy │ 0.9214000105857849 │\n", 937 | "└───────────────────────────┴───────────────────────────┘\n", 938 | "\n" 939 | ], 940 | "text/plain": [ 941 | "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", 942 | "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", 943 | "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", 944 | "│\u001b[36m \u001b[0m\u001b[36m accuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9214000105857849 \u001b[0m\u001b[35m \u001b[0m│\n", 945 | "└───────────────────────────┴───────────────────────────┘\n" 946 | ] 947 | }, 948 | "metadata": {}, 949 | "output_type": "display_data" 950 | }, 951 | { 952 | "data": { 953 | "text/plain": [ 954 | "[{'accuracy': 0.9214000105857849}]" 955 | ] 956 | }, 957 | "execution_count": 24, 958 | "metadata": {}, 959 | "output_type": "execute_result" 960 | } 961 | ], 962 | "source": [ 963 | "trainer.test(lightning_model, dataloaders=test_loader, ckpt_path=\"best\")" 964 | ] 965 | } 966 | ], 967 | "metadata": { 968 | "kernelspec": { 969 | "display_name": "Python 3 (ipykernel)", 970 | "language": "python", 971 | "name": "python3" 972 | }, 973 | "language_info": { 974 | "codemirror_mode": { 975 | "name": "ipython", 976 | "version": 3 977 | }, 978 | "file_extension": ".py", 979 | "mimetype": "text/x-python", 980 | "name": "python", 981 | "nbconvert_exporter": "python", 982 | "pygments_lexer": "ipython3", 983 | "version": "3.10.6" 984 | } 985 | }, 986 | "nbformat": 4, 987 | "nbformat_minor": 5 988 | } 989 | -------------------------------------------------------------------------------- /4_r8-dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "9a8d1036-9af3-425b-af79-542cb5698183", 6 | "metadata": {}, 7 | "source": [ 8 | "## Experiments on R8 dataset" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "81168bb4-e182-4cf0-9eff-72f0aa495401", 14 | "metadata": {}, 15 | "source": [ 16 | "This notebooks runs the proposed method on the R8 dataset that was reported in the original paper:" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "id": "cce35398-f627-4a50-a0eb-af2ad98ff75e", 22 | "metadata": {}, 23 | "source": [ 24 | "" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "id": "73cc2896-1dc1-4bc2-8d06-55bb0f6813bd", 30 | "metadata": {}, 31 | "source": [ 32 | "Note that the scores in the original paper are inflated or overly optimistic because of a bug in their code repository, which was described on [https://kenschutte.com/gzip-knn-paper/](https://kenschutte.com/gzip-knn-paper/)." 
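The effect of the bug referenced in the cell above is easy to reproduce in miniature. The sketch below is an editorial illustration, not code from this repository or from the paper's repo; `ncd`, `top_k_labels`, `buggy_is_correct`, `knn_predict`, and the `k=2` default are illustrative names reconstructed from the description in the linked write-up. It contrasts the original scoring, which counts a test point as correct if *either* of its two nearest neighbors carries the true label, with an ordinary 2-NN prediction that must commit to a single class:

```python
# Illustrative sketch of the scoring issue (assumed behavior per the
# kenschutte.com write-up), not the evaluation code used in this repository.
import gzip

import numpy as np


def ncd(x: str, y: str) -> float:
    # Normalized compression distance from the paper:
    # NCD(x, y) = (C(xy) - min(C(x), C(y))) / max(C(x), C(y))
    cx = len(gzip.compress(x.encode()))
    cy = len(gzip.compress(y.encode()))
    cxy = len(gzip.compress(" ".join([x, y]).encode()))
    return (cxy - min(cx, cy)) / max(cx, cy)


def top_k_labels(test_text, train_texts, train_labels, k=2):
    # Labels of the k nearest training points under NCD, nearest first.
    dists = np.array([ncd(test_text, t) for t in train_texts])
    return np.asarray(train_labels)[np.argsort(dists)[:k]]


def buggy_is_correct(top_labels, true_label):
    # What the original evaluation effectively computed: a hit if the true
    # label appears *anywhere* among the k=2 nearest neighbors, i.e.,
    # top-2 accuracy rather than ordinary accuracy.
    return true_label in top_labels


def knn_predict(top_labels):
    # Ordinary k-NN: majority vote; on a tie (the typical case for k=2),
    # fall back to the single nearest neighbor.
    values, counts = np.unique(top_labels, return_counts=True)
    if (counts == counts.max()).sum() > 1:
        return top_labels[0]
    return values[np.argmax(counts)]
```

Whenever the two nearest neighbors disagree, `buggy_is_correct` effectively grants two guesses while `knn_predict` commits to one, so the buggy scoring can only match or exceed the fair one. The `1_2_nn_plus_gzip_fix-tie-breaking.ipynb` notebook in this repository evaluates the method with proper tie-breaking.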
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 1,
38 | "id": "6122b98c-af6f-424c-a498-cee4cd008477",
39 | "metadata": {},
40 | "outputs": [],
41 | "source": [
42 | "import gzip\n",
43 | "import os.path as op\n",
44 | "\n",
45 | "import numpy as np\n",
46 | "import pandas as pd"
47 | ]
48 | },
49 | {
50 | "cell_type": "markdown",
51 | "id": "502672d8-2a4c-4504-ac47-684962fa7bc2",
52 | "metadata": {},
53 | "source": [
54 | "### Load dataset"
55 | ]
56 | },
57 | {
58 | "cell_type": "markdown",
59 | "id": "018be464-1ee4-4048-a8bc-0b1ecf8c9a76",
60 | "metadata": {},
61 | "source": [
62 | "Before running the code below, make sure to download the dataset from https://www.kaggle.com/datasets/weipengfei/ohr8r52"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": 2,
68 | "id": "f11f439a-1e31-4881-a7c2-a9633693f202",
69 | "metadata": {},
70 | "outputs": [],
71 | "source": [
72 | "df_train = pd.read_csv(\"r8-train-stemmed.csv\")\n",
73 | "df_test = pd.read_csv(\"r8-test-stemmed.csv\")"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": 3,
79 | "id": "fe2f6e18-b3d5-4333-93aa-782a287d0350",
80 | "metadata": {},
81 | "outputs": [
82 | {
83 | "data": {
84 | "text/plain": [
85 | "{'money-fx': 0,\n",
86 | " 'crude': 1,\n",
87 | " 'interest': 2,\n",
88 | " 'trade': 3,\n",
89 | " 'earn': 4,\n",
90 | " 'grain': 5,\n",
91 | " 'ship': 6,\n",
92 | " 'acq': 7}"
93 | ]
94 | },
95 | "execution_count": 3,
96 | "metadata": {},
97 | "output_type": "execute_result"
98 | }
99 | ],
100 | "source": [
101 | "uniq = list(set(df_train[\"intent\"].values))\n",
102 | "labels = {label: idx for idx, label in enumerate(uniq)}  # map each intent string to an integer class\n",
103 | "labels"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 4,
109 | "id": "5b9bf7a6-cc26-4d9f-a2c8-ea1d736b7997",
110 | "metadata": {},
111 | "outputs": [],
112 | "source": [
113 | "df_train[\"label\"] = df_train[\"intent\"].apply(lambda x: labels[x])\n",
114 | "df_test[\"label\"] = df_test[\"intent\"].apply(lambda x: labels[x])"
115 | ]
116 | },
117 | {
118 | "cell_type": "markdown",
119 | "id": "5978af4b-20e2-4031-8c7f-3c2bbd6eb906",
120 | "metadata": {},
121 | "source": [
122 | "## Original"
123 | ]
124 | },
125 | {
126 | "cell_type": "markdown",
127 | "id": "2f1c8f54-82bf-45cf-a273-af4106323998",
128 | "metadata": {},
129 | "source": [
130 | "Reimplementation of the pseudocode in the *\"Low-Resource\" Text Classification: A Parameter-Free Classification Method with Compressors* paper ([https://aclanthology.org/2023.findings-acl.426/](https://aclanthology.org/2023.findings-acl.426/)) \n",
131 | "\n",
132 | "\n",
133 | "