├── A B tests with Machine Learning.ipynb
├── README.md
├── data
├── control_data.csv
└── experiment_data.csv
└── multi_label_trainer.ipynb
/README.md:
--------------------------------------------------------------------------------
1 | Recently, I was reading through [A/B Testing with Machine Learning - A Step-by-Step Tutorial](https://www.business-science.io/business/2019/03/11/ab-testing-machine-learning.html) written by [Matt Dancho](https://www.linkedin.com/in/mattdancho/) of [Business Science](https://www.business-science.io). I have been always fascinated by the idea of **A/B Testing** and the amount of impact it can bring in businesses. The tutorial is very definitive and Matt has explained each and every step in the tutorial. He has detailed about each and every decision taken while developing the solution.
2 |
3 | Even though the tutorial is written in `R`, I was able to scram through his code and my knowledge of Data Science helped me to understand the concepts very quickly. I will have to thank Matt for putting together all the key ingredients of the Data Science world and or using them to solve a real problem.
4 |
5 | The notebook in this repository contains my implementation of the solution (presented by Matt) in Python.
6 |
--------------------------------------------------------------------------------
/data/control_data.csv:
--------------------------------------------------------------------------------
1 | Date,Pageviews,Clicks,Enrollments,Payments
2 | "Sat, Oct 11",7723,687,134,70
3 | "Sun, Oct 12",9102,779,147,70
4 | "Mon, Oct 13",10511,909,167,95
5 | "Tue, Oct 14",9871,836,156,105
6 | "Wed, Oct 15",10014,837,163,64
7 | "Thu, Oct 16",9670,823,138,82
8 | "Fri, Oct 17",9008,748,146,76
9 | "Sat, Oct 18",7434,632,110,70
10 | "Sun, Oct 19",8459,691,131,60
11 | "Mon, Oct 20",10667,861,165,97
12 | "Tue, Oct 21",10660,867,196,105
13 | "Wed, Oct 22",9947,838,162,92
14 | "Thu, Oct 23",8324,665,127,56
15 | "Fri, Oct 24",9434,673,220,122
16 | "Sat, Oct 25",8687,691,176,128
17 | "Sun, Oct 26",8896,708,161,104
18 | "Mon, Oct 27",9535,759,233,124
19 | "Tue, Oct 28",9363,736,154,91
20 | "Wed, Oct 29",9327,739,196,86
21 | "Thu, Oct 30",9345,734,167,75
22 | "Fri, Oct 31",8890,706,174,101
23 | "Sat, Nov 1",8460,681,156,93
24 | "Sun, Nov 2",8836,693,206,67
25 | "Mon, Nov 3",9437,788,,
26 | "Tue, Nov 4",9420,781,,
27 | "Wed, Nov 5",9570,805,,
28 | "Thu, Nov 6",9921,830,,
29 | "Fri, Nov 7",9424,781,,
30 | "Sat, Nov 8",9010,756,,
31 | "Sun, Nov 9",9656,825,,
32 | "Mon, Nov 10",10419,874,,
33 | "Tue, Nov 11",9880,830,,
34 | "Wed, Nov 12",10134,801,,
35 | "Thu, Nov 13",9717,814,,
36 | "Fri, Nov 14",9192,735,,
37 | "Sat, Nov 15",8630,743,,
38 | "Sun, Nov 16",8970,722,,
39 |
--------------------------------------------------------------------------------
/data/experiment_data.csv:
--------------------------------------------------------------------------------
1 | Date,Pageviews,Clicks,Enrollments,Payments
2 | "Sat, Oct 11",7716,686,105,34
3 | "Sun, Oct 12",9288,785,116,91
4 | "Mon, Oct 13",10480,884,145,79
5 | "Tue, Oct 14",9867,827,138,92
6 | "Wed, Oct 15",9793,832,140,94
7 | "Thu, Oct 16",9500,788,129,61
8 | "Fri, Oct 17",9088,780,127,44
9 | "Sat, Oct 18",7664,652,94,62
10 | "Sun, Oct 19",8434,697,120,77
11 | "Mon, Oct 20",10496,860,153,98
12 | "Tue, Oct 21",10551,864,143,71
13 | "Wed, Oct 22",9737,801,128,70
14 | "Thu, Oct 23",8176,642,122,68
15 | "Fri, Oct 24",9402,697,194,94
16 | "Sat, Oct 25",8669,669,127,81
17 | "Sun, Oct 26",8881,693,153,101
18 | "Mon, Oct 27",9655,771,213,119
19 | "Tue, Oct 28",9396,736,162,120
20 | "Wed, Oct 29",9262,727,201,96
21 | "Thu, Oct 30",9308,728,207,67
22 | "Fri, Oct 31",8715,722,182,123
23 | "Sat, Nov 1",8448,695,142,100
24 | "Sun, Nov 2",8836,724,182,103
25 | "Mon, Nov 3",9359,789,,
26 | "Tue, Nov 4",9427,743,,
27 | "Wed, Nov 5",9633,808,,
28 | "Thu, Nov 6",9842,831,,
29 | "Fri, Nov 7",9272,767,,
30 | "Sat, Nov 8",8969,760,,
31 | "Sun, Nov 9",9697,850,,
32 | "Mon, Nov 10",10445,851,,
33 | "Tue, Nov 11",9931,831,,
34 | "Wed, Nov 12",10042,802,,
35 | "Thu, Nov 13",9721,829,,
36 | "Fri, Nov 14",9304,770,,
37 | "Sat, Nov 15",8668,724,,
38 | "Sun, Nov 16",8988,710,,
39 |
--------------------------------------------------------------------------------
/multi_label_trainer.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "accelerator": "GPU",
6 | "colab": {
7 | "name": "multi-label-trainer.ipynb",
8 | "provenance": [],
9 | "collapsed_sections": [],
10 | "include_colab_link": true
11 | },
12 | "kernelspec": {
13 | "display_name": "Python 3 (ipykernel)",
14 | "language": "python",
15 | "name": "python3"
16 | },
17 | "language_info": {
18 | "codemirror_mode": {
19 | "name": "ipython",
20 | "version": 3
21 | },
22 | "file_extension": ".py",
23 | "mimetype": "text/x-python",
24 | "name": "python",
25 | "nbconvert_exporter": "python",
26 | "pygments_lexer": "ipython3",
27 | "version": "3.8.2"
28 | }
29 | },
30 | "cells": [
31 | {
32 | "cell_type": "markdown",
33 | "metadata": {
34 | "id": "view-in-github",
35 | "colab_type": "text"
36 | },
37 | "source": [
38 | "
"
39 | ]
40 | },
41 | {
42 | "cell_type": "markdown",
43 | "metadata": {
44 | "id": "5GICQpY-zws7"
45 | },
46 | "source": [
47 | "## Imports"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "metadata": {
53 | "id": "ho5uPff1fLoH"
54 | },
55 | "source": [
56 | "from tensorflow.keras import layers\n",
57 | "from tensorflow import keras\n",
58 | "import tensorflow as tf\n",
59 | "\n",
60 | "from sklearn.preprocessing import MultiLabelBinarizer\n",
61 | "from sklearn.model_selection import train_test_split\n",
62 | "from ast import literal_eval\n",
63 | "import pandas as pd"
64 | ],
65 | "execution_count": null,
66 | "outputs": []
67 | },
68 | {
69 | "cell_type": "markdown",
70 | "metadata": {
71 | "id": "qyeQtKSezymP"
72 | },
73 | "source": [
74 | "## Read data and perform basic EDA"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "metadata": {
80 | "colab": {
81 | "base_uri": "https://localhost:8080/",
82 | "height": 206
83 | },
84 | "id": "yFo2pNYbf2Du",
85 | "outputId": "d487ddac-8552-4543-f7f7-ecc3a1c2f9a8"
86 | },
87 | "source": [
88 | "arxiv_data = pd.read_csv(\n",
89 | " \"https://github.com/soumik12345/multi-label-text-classification/releases/download/v0.2/arxiv_data.csv\"\n",
90 | ")\n",
91 | "arxiv_data.head()"
92 | ],
93 | "execution_count": null,
94 | "outputs": [
95 | {
96 | "output_type": "execute_result",
97 | "data": {
98 | "text/html": [
99 | "
\n",
100 | "\n",
113 | "
\n",
114 | " \n",
115 | " \n",
116 | " | \n",
117 | " titles | \n",
118 | " summaries | \n",
119 | " terms | \n",
120 | "
\n",
121 | " \n",
122 | " \n",
123 | " \n",
124 | " 0 | \n",
125 | " Survey on Semantic Stereo Matching / Semantic ... | \n",
126 | " Stereo matching is one of the widely used tech... | \n",
127 | " ['cs.CV', 'cs.LG'] | \n",
128 | "
\n",
129 | " \n",
130 | " 1 | \n",
131 | " FUTURE-AI: Guiding Principles and Consensus Re... | \n",
132 | " The recent advancements in artificial intellig... | \n",
133 | " ['cs.CV', 'cs.AI', 'cs.LG'] | \n",
134 | "
\n",
135 | " \n",
136 | " 2 | \n",
137 | " Enforcing Mutual Consistency of Hard Regions f... | \n",
138 | " In this paper, we proposed a novel mutual cons... | \n",
139 | " ['cs.CV', 'cs.AI'] | \n",
140 | "
\n",
141 | " \n",
142 | " 3 | \n",
143 | " Parameter Decoupling Strategy for Semi-supervi... | \n",
144 | " Consistency training has proven to be an advan... | \n",
145 | " ['cs.CV'] | \n",
146 | "
\n",
147 | " \n",
148 | " 4 | \n",
149 | " Background-Foreground Segmentation for Interio... | \n",
150 | " To ensure safety in automated driving, the cor... | \n",
151 | " ['cs.CV', 'cs.LG'] | \n",
152 | "
\n",
153 | " \n",
154 | "
\n",
155 | "
"
156 | ],
157 | "text/plain": [
158 | " titles ... terms\n",
159 | "0 Survey on Semantic Stereo Matching / Semantic ... ... ['cs.CV', 'cs.LG']\n",
160 | "1 FUTURE-AI: Guiding Principles and Consensus Re... ... ['cs.CV', 'cs.AI', 'cs.LG']\n",
161 | "2 Enforcing Mutual Consistency of Hard Regions f... ... ['cs.CV', 'cs.AI']\n",
162 | "3 Parameter Decoupling Strategy for Semi-supervi... ... ['cs.CV']\n",
163 | "4 Background-Foreground Segmentation for Interio... ... ['cs.CV', 'cs.LG']\n",
164 | "\n",
165 | "[5 rows x 3 columns]"
166 | ]
167 | },
168 | "metadata": {},
169 | "execution_count": 2
170 | }
171 | ]
172 | },
173 | {
174 | "cell_type": "code",
175 | "metadata": {
176 | "colab": {
177 | "base_uri": "https://localhost:8080/"
178 | },
179 | "id": "Em_8mJvUgKY-",
180 | "outputId": "96ecef3c-6ff5-4e15-9cb3-be8835bee5e8"
181 | },
182 | "source": [
183 | "print(f\"There are {len(arxiv_data)} rows in the dataset.\")"
184 | ],
185 | "execution_count": null,
186 | "outputs": [
187 | {
188 | "output_type": "stream",
189 | "name": "stdout",
190 | "text": [
191 | "There are 51774 rows in the dataset.\n"
192 | ]
193 | }
194 | ]
195 | },
196 | {
197 | "cell_type": "code",
198 | "metadata": {
199 | "id": "k9Vb9jtK5zjg",
200 | "colab": {
201 | "base_uri": "https://localhost:8080/"
202 | },
203 | "outputId": "38aa5188-62a0-480b-d829-0cab6cd141d2"
204 | },
205 | "source": [
206 | "total_duplicate_titles = sum(arxiv_data[\"titles\"].duplicated())\n",
207 | "print(f\"There are {total_duplicate_titles} duplicate titles.\")"
208 | ],
209 | "execution_count": null,
210 | "outputs": [
211 | {
212 | "output_type": "stream",
213 | "name": "stdout",
214 | "text": [
215 | "There are 12802 duplicate titles.\n"
216 | ]
217 | }
218 | ]
219 | },
220 | {
221 | "cell_type": "code",
222 | "metadata": {
223 | "id": "2259X-rf6OLY",
224 | "colab": {
225 | "base_uri": "https://localhost:8080/"
226 | },
227 | "outputId": "309cfa18-56ff-4c4d-b7b8-1a66bec152d5"
228 | },
229 | "source": [
230 | "arxiv_data = arxiv_data[~arxiv_data[\"titles\"].duplicated()]\n",
231 | "print(f\"There are {len(arxiv_data)} rows in the deduplicated dataset.\")"
232 | ],
233 | "execution_count": null,
234 | "outputs": [
235 | {
236 | "output_type": "stream",
237 | "name": "stdout",
238 | "text": [
239 | "There are 38972 rows in the deduplicated dataset.\n"
240 | ]
241 | }
242 | ]
243 | },
244 | {
245 | "cell_type": "code",
246 | "metadata": {
247 | "colab": {
248 | "base_uri": "https://localhost:8080/"
249 | },
250 | "id": "TmgkCCr2g0w5",
251 | "outputId": "3b7268e4-86eb-43e9-852f-8d4da544cdc3"
252 | },
253 | "source": [
254 | "# There are some terms with occurence as low as 1.\n",
255 | "sum(arxiv_data[\"terms\"].value_counts() == 1)"
256 | ],
257 | "execution_count": null,
258 | "outputs": [
259 | {
260 | "output_type": "execute_result",
261 | "data": {
262 | "text/plain": [
263 | "2321"
264 | ]
265 | },
266 | "metadata": {},
267 | "execution_count": 10
268 | }
269 | ]
270 | },
271 | {
272 | "cell_type": "code",
273 | "metadata": {
274 | "colab": {
275 | "base_uri": "https://localhost:8080/"
276 | },
277 | "id": "hyBJNHMdifJ-",
278 | "outputId": "b220cf2b-341f-46da-ad09-0c36ca02f4f8"
279 | },
280 | "source": [
281 | "# How many unique terms?\n",
282 | "arxiv_data[\"terms\"].nunique()"
283 | ],
284 | "execution_count": null,
285 | "outputs": [
286 | {
287 | "output_type": "execute_result",
288 | "data": {
289 | "text/plain": [
290 | "3157"
291 | ]
292 | },
293 | "metadata": {},
294 | "execution_count": 11
295 | }
296 | ]
297 | },
298 | {
299 | "cell_type": "code",
300 | "metadata": {
301 | "colab": {
302 | "base_uri": "https://localhost:8080/"
303 | },
304 | "id": "77ZoCzrMhLxc",
305 | "outputId": "733bee4c-11f8-44ce-9b0e-53cd1bafa637"
306 | },
307 | "source": [
308 | "# Filtering the rare terms.\n",
309 | "arxiv_data_filtered = arxiv_data.groupby(\"terms\").filter(lambda x: len(x) > 1)\n",
310 | "arxiv_data_filtered.shape"
311 | ],
312 | "execution_count": null,
313 | "outputs": [
314 | {
315 | "output_type": "execute_result",
316 | "data": {
317 | "text/plain": [
318 | "(36651, 3)"
319 | ]
320 | },
321 | "metadata": {},
322 | "execution_count": 12
323 | }
324 | ]
325 | },
326 | {
327 | "cell_type": "markdown",
328 | "metadata": {
329 | "id": "MxrG9tim0QNr"
330 | },
331 | "source": [
332 | "## Convert the string labels to list of strings. \n",
333 | "\n",
334 | "The initial labels are represented as raw strings. Here we make them `List[str]` for a more compact representation. "
335 | ]
336 | },
337 | {
338 | "cell_type": "code",
339 | "metadata": {
340 | "colab": {
341 | "base_uri": "https://localhost:8080/"
342 | },
343 | "id": "LIEGLc61iwbQ",
344 | "outputId": "ebddd3c1-8f4b-42cf-e236-356a00c88f51"
345 | },
346 | "source": [
347 | "arxiv_data_filtered[\"terms\"] = arxiv_data_filtered[\"terms\"].apply(\n",
348 | " lambda x: literal_eval(x)\n",
349 | ")\n",
350 | "arxiv_data_filtered[\"terms\"].values[:5]"
351 | ],
352 | "execution_count": null,
353 | "outputs": [
354 | {
355 | "output_type": "execute_result",
356 | "data": {
357 | "text/plain": [
358 | "array([list(['cs.CV', 'cs.LG']), list(['cs.CV', 'cs.AI', 'cs.LG']),\n",
359 | " list(['cs.CV', 'cs.AI']), list(['cs.CV']),\n",
360 | " list(['cs.CV', 'cs.LG'])], dtype=object)"
361 | ]
362 | },
363 | "metadata": {},
364 | "execution_count": 13
365 | }
366 | ]
367 | },
368 | {
369 | "cell_type": "markdown",
370 | "metadata": {
371 | "id": "zjFB8Uoo0cXM"
372 | },
373 | "source": [
374 | "## Stratified splits because of class imbalance"
375 | ]
376 | },
377 | {
378 | "cell_type": "code",
379 | "metadata": {
380 | "colab": {
381 | "base_uri": "https://localhost:8080/"
382 | },
383 | "id": "EbKDVTKPgOKe",
384 | "outputId": "38a3b8b4-0dd1-4181-ee36-e7e3dd8ed953"
385 | },
386 | "source": [
387 | "test_split = 0.1\n",
388 | "\n",
389 | "# Initial train and test split.\n",
390 | "train_df, test_df = train_test_split(\n",
391 | " arxiv_data_filtered,\n",
392 | " test_size=test_split,\n",
393 | " stratify=arxiv_data_filtered[\"terms\"].values,\n",
394 | ")\n",
395 | "\n",
396 | "# Splitting the test set further into validation\n",
397 | "# and new test sets.\n",
398 | "val_df = test_df.sample(frac=0.5)\n",
399 | "test_df.drop(val_df.index, inplace=True)\n",
400 | "\n",
401 | "print(f\"Number of rows in training set: {len(train_df)}\")\n",
402 | "print(f\"Number of rows in validation set: {len(val_df)}\")\n",
403 | "print(f\"Number of rows in test set: {len(test_df)}\")"
404 | ],
405 | "execution_count": null,
406 | "outputs": [
407 | {
408 | "output_type": "stream",
409 | "name": "stdout",
410 | "text": [
411 | "Number of rows in training set: 32985\n",
412 | "Number of rows in validation set: 1833\n",
413 | "Number of rows in test set: 1833\n"
414 | ]
415 | },
416 | {
417 | "output_type": "stream",
418 | "name": "stderr",
419 | "text": [
420 | "/usr/local/lib/python3.7/dist-packages/pandas/core/frame.py:4174: SettingWithCopyWarning: \n",
421 | "A value is trying to be set on a copy of a slice from a DataFrame\n",
422 | "\n",
423 | "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
424 | " errors=errors,\n"
425 | ]
426 | }
427 | ]
428 | },
429 | {
430 | "cell_type": "markdown",
431 | "metadata": {
432 | "id": "96Ew2PPI0lVc"
433 | },
434 | "source": [
435 | "## Multi-label binarization"
436 | ]
437 | },
438 | {
439 | "cell_type": "code",
440 | "metadata": {
441 | "colab": {
442 | "base_uri": "https://localhost:8080/"
443 | },
444 | "id": "1vgxbdwGf07E",
445 | "outputId": "6f2d9d79-cf1d-433d-d1f8-aabaffb103ea"
446 | },
447 | "source": [
448 | "mlb = MultiLabelBinarizer()\n",
449 | "mlb.fit_transform(train_df[\"terms\"])\n",
450 | "mlb.classes_"
451 | ],
452 | "execution_count": null,
453 | "outputs": [
454 | {
455 | "output_type": "execute_result",
456 | "data": {
457 | "text/plain": [
458 | "array(['14J60 (Primary) 14F05, 14J26 (Secondary)', '62H30', '62H35',\n",
459 | " '62H99', '65D19', '68', '68Q32', '68T01', '68T05', '68T07',\n",
460 | " '68T10', '68T30', '68T45', '68T99', '68Txx', '68U01', '68U10',\n",
461 | " 'E.5; E.4; E.2; H.1.1; F.1.1; F.1.3', 'F.2.2; I.2.7', 'G.3',\n",
462 | " 'H.3.1; H.3.3; I.2.6; I.2.7', 'H.3.1; I.2.6; I.2.7', 'I.2',\n",
463 | " 'I.2.0; I.2.6', 'I.2.1', 'I.2.10', 'I.2.10; I.2.6',\n",
464 | " 'I.2.10; I.4.8', 'I.2.10; I.4.8; I.5.4', 'I.2.10; I.4; I.5',\n",
465 | " 'I.2.10; I.5.1; I.4.8', 'I.2.1; J.3', 'I.2.6', 'I.2.6, I.5.4',\n",
466 | " 'I.2.6; I.2.10', 'I.2.6; I.2.7', 'I.2.6; I.2.7; H.3.1; H.3.3',\n",
467 | " 'I.2.6; I.2.8', 'I.2.6; I.2.9', 'I.2.6; I.5.1', 'I.2.6; I.5.4',\n",
468 | " 'I.2.7', 'I.2.8', 'I.2; I.2.6; I.2.7', 'I.2; I.4; I.5', 'I.2; I.5',\n",
469 | " 'I.2; J.2', 'I.4', 'I.4.0', 'I.4.3', 'I.4.4', 'I.4.5', 'I.4.6',\n",
470 | " 'I.4.6; I.4.8', 'I.4.8', 'I.4.9', 'I.4.9; I.5.4', 'I.4; I.5',\n",
471 | " 'I.5.4', 'K.3.2', 'astro-ph.IM', 'cond-mat.dis-nn',\n",
472 | " 'cond-mat.mtrl-sci', 'cond-mat.soft', 'cond-mat.stat-mech',\n",
473 | " 'cs.AI', 'cs.AR', 'cs.CC', 'cs.CE', 'cs.CG', 'cs.CL', 'cs.CR',\n",
474 | " 'cs.CV', 'cs.CY', 'cs.DB', 'cs.DC', 'cs.DM', 'cs.DS', 'cs.ET',\n",
475 | " 'cs.FL', 'cs.GR', 'cs.GT', 'cs.HC', 'cs.IR', 'cs.IT', 'cs.LG',\n",
476 | " 'cs.LO', 'cs.MA', 'cs.MM', 'cs.MS', 'cs.NA', 'cs.NE', 'cs.NI',\n",
477 | " 'cs.PF', 'cs.PL', 'cs.RO', 'cs.SC', 'cs.SD', 'cs.SE', 'cs.SI',\n",
478 | " 'cs.SY', 'econ.EM', 'econ.GN', 'eess.AS', 'eess.IV', 'eess.SP',\n",
479 | " 'eess.SY', 'hep-ex', 'hep-ph', 'math.AP', 'math.AT', 'math.CO',\n",
480 | " 'math.DS', 'math.FA', 'math.IT', 'math.LO', 'math.NA', 'math.OC',\n",
481 | " 'math.PR', 'math.ST', 'nlin.AO', 'nlin.CD', 'physics.ao-ph',\n",
482 | " 'physics.bio-ph', 'physics.chem-ph', 'physics.comp-ph',\n",
483 | " 'physics.data-an', 'physics.flu-dyn', 'physics.geo-ph',\n",
484 | " 'physics.med-ph', 'physics.optics', 'physics.soc-ph', 'q-bio.BM',\n",
485 | " 'q-bio.GN', 'q-bio.MN', 'q-bio.NC', 'q-bio.OT', 'q-bio.QM',\n",
486 | " 'q-bio.TO', 'q-fin.CP', 'q-fin.EC', 'q-fin.GN', 'q-fin.PM',\n",
487 | " 'q-fin.RM', 'q-fin.ST', 'q-fin.TR', 'quant-ph', 'stat.AP',\n",
488 | " 'stat.CO', 'stat.ME', 'stat.ML', 'stat.TH'], dtype=object)"
489 | ]
490 | },
491 | "metadata": {},
492 | "execution_count": 16
493 | }
494 | ]
495 | },
496 | {
497 | "cell_type": "markdown",
498 | "metadata": {
499 | "id": "a2kFVBCG0oXV"
500 | },
501 | "source": [
502 | "## Data preprocessing and `tf.data.Dataset` objects\n",
503 | "\n",
504 | "Get percentile estimates of the sequence lengths. "
505 | ]
506 | },
507 | {
508 | "cell_type": "code",
509 | "metadata": {
510 | "colab": {
511 | "base_uri": "https://localhost:8080/"
512 | },
513 | "id": "kCR-_Iw3gyT-",
514 | "outputId": "f2939ed0-79dc-44c5-931d-af117ce0fc8d"
515 | },
516 | "source": [
517 | "train_df[\"summaries\"].apply(lambda x: len(x.split(\" \"))).describe()"
518 | ],
519 | "execution_count": null,
520 | "outputs": [
521 | {
522 | "output_type": "execute_result",
523 | "data": {
524 | "text/plain": [
525 | "count 32985.000000\n",
526 | "mean 156.502471\n",
527 | "std 41.538054\n",
528 | "min 5.000000\n",
529 | "25% 128.000000\n",
530 | "50% 154.000000\n",
531 | "75% 183.000000\n",
532 | "max 462.000000\n",
533 | "Name: summaries, dtype: float64"
534 | ]
535 | },
536 | "metadata": {},
537 | "execution_count": 17
538 | }
539 | ]
540 | },
541 | {
542 | "cell_type": "markdown",
543 | "metadata": {
544 | "id": "2rdivJop02xG"
545 | },
546 | "source": [
547 | "Notice that 50% of the abstracts have a length of 158. So, any number near that is a good enough approximate for the maximum sequence length. "
548 | ]
549 | },
550 | {
551 | "cell_type": "code",
552 | "metadata": {
553 | "id": "QoMgEZrtVBDS"
554 | },
555 | "source": [
556 | "max_seqlen = 150\n",
557 | "batch_size = 128\n",
558 | "\n",
559 | "\n",
560 | "def unify_text_length(text, label):\n",
561 | " unified_text = tf.strings.substr(text, 0, max_seqlen)\n",
562 | " return tf.expand_dims(unified_text, -1), label\n",
563 | "\n",
564 | "\n",
565 | "def make_dataset(dataframe, train=True):\n",
566 | " label_binarized = mlb.transform(dataframe[\"terms\"].values)\n",
567 | " dataset = tf.data.Dataset.from_tensor_slices(\n",
568 | " (dataframe[\"summaries\"].values, label_binarized)\n",
569 | " )\n",
570 | " if train:\n",
571 | " dataset = dataset.shuffle(batch_size * 10)\n",
572 | " dataset = dataset.map(unify_text_length).cache()\n",
573 | " return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)"
574 | ],
575 | "execution_count": null,
576 | "outputs": []
577 | },
578 | {
579 | "cell_type": "code",
580 | "metadata": {
581 | "id": "F_vrgkCXrWOS"
582 | },
583 | "source": [
584 | "train_dataset = make_dataset(train_df)\n",
585 | "validation_dataset = make_dataset(val_df, False)\n",
586 | "test_dataset = make_dataset(test_df, False)"
587 | ],
588 | "execution_count": null,
589 | "outputs": []
590 | },
591 | {
592 | "cell_type": "markdown",
593 | "metadata": {
594 | "id": "-1Bb4Xnm1EwK"
595 | },
596 | "source": [
597 | "## Dataset preview"
598 | ]
599 | },
600 | {
601 | "cell_type": "code",
602 | "metadata": {
603 | "colab": {
604 | "base_uri": "https://localhost:8080/"
605 | },
606 | "id": "w-8k7gScVoz6",
607 | "outputId": "78ffcba1-795e-4a7e-c945-4b54523d87e8"
608 | },
609 | "source": [
610 | "text_batch, label_batch = next(iter(train_dataset))\n",
611 | "\n",
612 | "for i, text in enumerate(text_batch[:5]):\n",
613 | " label = label_batch[i].numpy()[None, ...]\n",
614 | " print(f\"Abstract: {text[0]}\")\n",
615 | " print(f\"Label(s): {mlb.inverse_transform(label)[0]}\")\n",
616 | " print(\" \")"
617 | ],
618 | "execution_count": null,
619 | "outputs": [
620 | {
621 | "name": "stdout",
622 | "output_type": "stream",
623 | "text": [
624 | "Abstract: b'We study the effect of the stochastic gradient noise on the training of\\ngenerative adversarial networks (GANs) and show that it can prevent the\\nconver'\n",
625 | "Label(s): ('cs.LG', 'math.OC', 'stat.ML')\n",
626 | " \n",
627 | "Abstract: b'Sensitive medical data is often subject to strict usage constraints. In this\\npaper, we trained a generative adversarial network (GAN) on real-world\\nel'\n",
628 | "Label(s): ('cs.LG',)\n",
629 | " \n",
630 | "Abstract: b'Popular rotated detection methods usually use five parameters (coordinates of\\nthe central point, width, height, and rotation angle) to describe the ro'\n",
631 | "Label(s): ('cs.CV',)\n",
632 | " \n",
633 | "Abstract: b'FPN is a common component used in object detectors, it supplements\\nmulti-scale information by adjacent level features interpolation and summation.\\nHow'\n",
634 | "Label(s): ('cs.CV',)\n",
635 | " \n",
636 | "Abstract: b'Data for Image segmentation models can be costly to obtain due to the\\nprecision required by human annotators. We run a series of experiments showing\\nt'\n",
637 | "Label(s): ('cs.CV',)\n",
638 | " \n"
639 | ]
640 | }
641 | ]
642 | },
643 | {
644 | "cell_type": "markdown",
645 | "metadata": {
646 | "id": "Yfbnoi-y1IoB"
647 | },
648 | "source": [
649 | "## Vocabulary size for vectorization"
650 | ]
651 | },
652 | {
653 | "cell_type": "code",
654 | "metadata": {
655 | "colab": {
656 | "base_uri": "https://localhost:8080/"
657 | },
658 | "id": "IIW42KOgyfE7",
659 | "outputId": "b6508f2b-623a-45ee-bfb7-479416ec5c52"
660 | },
661 | "source": [
662 | "train_df[\"total_words\"] = train_df[\"summaries\"].str.split().str.len()\n",
663 | "vocabulary_size = train_df[\"total_words\"].max()\n",
664 | "print(f\"Vocabulary size: {vocabulary_size}\")"
665 | ],
666 | "execution_count": null,
667 | "outputs": [
668 | {
669 | "output_type": "stream",
670 | "name": "stdout",
671 | "text": [
672 | "Vocabulary size: 498\n"
673 | ]
674 | },
675 | {
676 | "output_type": "stream",
677 | "name": "stderr",
678 | "text": [
679 | "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:1: SettingWithCopyWarning: \n",
680 | "A value is trying to be set on a copy of a slice from a DataFrame.\n",
681 | "Try using .loc[row_indexer,col_indexer] = value instead\n",
682 | "\n",
683 | "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
684 | " \"\"\"Entry point for launching an IPython kernel.\n"
685 | ]
686 | }
687 | ]
688 | },
689 | {
690 | "cell_type": "markdown",
691 | "metadata": {
692 | "id": "3TBUFHm11M3G"
693 | },
694 | "source": [
695 | "## Create model with `TextVectorization`"
696 | ]
697 | },
698 | {
699 | "cell_type": "code",
700 | "metadata": {
701 | "id": "4XX7ovyPokNs"
702 | },
703 | "source": [
704 | "text_vectorizer = layers.TextVectorization(\n",
705 | " max_tokens=vocabulary_size, ngrams=2, output_mode=\"tf_idf\"\n",
706 | ")\n",
707 | "\n",
708 | "with tf.device(\"/CPU:0\"):\n",
709 | " text_vectorizer.adapt(train_dataset.map(lambda text, label: text))\n",
710 | "\n",
711 | "\n",
712 | "def make_model():\n",
713 | " shallow_mlp_model = keras.Sequential(\n",
714 | " [\n",
715 | " keras.Input(shape=(), dtype=tf.string),\n",
716 | " text_vectorizer,\n",
717 | " layers.Dense(512, activation=\"relu\"),\n",
718 | " layers.Dense(256, activation=\"relu\"),\n",
719 | " layers.Dense(len(mlb.classes_), activation=\"softmax\"),\n",
720 | " ]\n",
721 | " )\n",
722 | " return shallow_mlp_model"
723 | ],
724 | "execution_count": null,
725 | "outputs": []
726 | },
727 | {
728 | "cell_type": "markdown",
729 | "metadata": {
730 | "id": "i7i71RLtugUs"
731 | },
732 | "source": [
733 | "With the CPU placement, we run into: \n",
734 | "\n",
735 | "```\n",
736 | "(1) Invalid argument: During Variant Host->Device Copy: non-DMA-copy attempted of tensor type: string\n",
737 | "```"
738 | ]
739 | },
740 | {
741 | "cell_type": "code",
742 | "metadata": {
743 | "colab": {
744 | "base_uri": "https://localhost:8080/"
745 | },
746 | "id": "58tUdCsQuI1P",
747 | "outputId": "2c58e032-0da5-434a-b735-44b2d35fa920"
748 | },
749 | "source": [
750 | "shallow_mlp_model = make_model()\n",
751 | "shallow_mlp_model.summary()"
752 | ],
753 | "execution_count": null,
754 | "outputs": [
755 | {
756 | "output_type": "stream",
757 | "name": "stdout",
758 | "text": [
759 | "Model: \"sequential\"\n",
760 | "_________________________________________________________________\n",
761 | "Layer (type) Output Shape Param # \n",
762 | "=================================================================\n",
763 | "text_vectorization (TextVect (None, 498) 1 \n",
764 | "_________________________________________________________________\n",
765 | "dense (Dense) (None, 512) 255488 \n",
766 | "_________________________________________________________________\n",
767 | "dense_1 (Dense) (None, 256) 131328 \n",
768 | "_________________________________________________________________\n",
769 | "dense_2 (Dense) (None, 152) 39064 \n",
770 | "=================================================================\n",
771 | "Total params: 425,881\n",
772 | "Trainable params: 425,880\n",
773 | "Non-trainable params: 1\n",
774 | "_________________________________________________________________\n"
775 | ]
776 | }
777 | ]
778 | },
779 | {
780 | "cell_type": "markdown",
781 | "metadata": {
782 | "id": "y1Hr9D0O1Tw0"
783 | },
784 | "source": [
785 | "## Train the model"
786 | ]
787 | },
788 | {
789 | "cell_type": "code",
790 | "metadata": {
791 | "colab": {
792 | "base_uri": "https://localhost:8080/"
793 | },
794 | "id": "WCoaRkA-wsC3",
795 | "outputId": "198a94a5-04f0-4584-8b3e-27328f894e5e"
796 | },
797 | "source": [
798 | "epochs = 20\n",
799 | "\n",
800 | "shallow_mlp_model.compile(\n",
801 | " loss=\"binary_crossentropy\", optimizer=\"adam\", metrics=[\"categorical_accuracy\"]\n",
802 | ")\n",
803 | "\n",
804 | "shallow_mlp_model.fit(train_dataset, validation_data=validation_dataset, epochs=epochs)"
805 | ],
806 | "execution_count": null,
807 | "outputs": [
808 | {
809 | "output_type": "stream",
810 | "name": "stdout",
811 | "text": [
812 | "Epoch 1/20\n",
813 | "258/258 [==============================] - 4s 8ms/step - loss: 0.0382 - categorical_accuracy: 0.5880 - val_loss: 0.0218 - val_categorical_accuracy: 0.6416\n",
814 | "Epoch 2/20\n",
815 | "258/258 [==============================] - 2s 7ms/step - loss: 0.0224 - categorical_accuracy: 0.6366 - val_loss: 0.0214 - val_categorical_accuracy: 0.6454\n",
816 | "Epoch 3/20\n",
817 | "258/258 [==============================] - 2s 8ms/step - loss: 0.0216 - categorical_accuracy: 0.6410 - val_loss: 0.0213 - val_categorical_accuracy: 0.6465\n",
818 | "Epoch 4/20\n",
819 | "258/258 [==============================] - 2s 7ms/step - loss: 0.0211 - categorical_accuracy: 0.6441 - val_loss: 0.0213 - val_categorical_accuracy: 0.6476\n",
820 | "Epoch 5/20\n",
821 | "258/258 [==============================] - 2s 7ms/step - loss: 0.0205 - categorical_accuracy: 0.6484 - val_loss: 0.0213 - val_categorical_accuracy: 0.6487\n",
822 | "Epoch 6/20\n",
823 | "258/258 [==============================] - 2s 7ms/step - loss: 0.0200 - categorical_accuracy: 0.6535 - val_loss: 0.0214 - val_categorical_accuracy: 0.6503\n",
824 | "Epoch 7/20\n",
825 | "258/258 [==============================] - 2s 7ms/step - loss: 0.0194 - categorical_accuracy: 0.6582 - val_loss: 0.0217 - val_categorical_accuracy: 0.6508\n",
826 | "Epoch 8/20\n",
827 | "258/258 [==============================] - 2s 7ms/step - loss: 0.0187 - categorical_accuracy: 0.6662 - val_loss: 0.0222 - val_categorical_accuracy: 0.6519\n",
828 | "Epoch 9/20\n",
829 | "258/258 [==============================] - 2s 7ms/step - loss: 0.0181 - categorical_accuracy: 0.6742 - val_loss: 0.0229 - val_categorical_accuracy: 0.6481\n",
830 | "Epoch 10/20\n",
831 | "258/258 [==============================] - 2s 7ms/step - loss: 0.0174 - categorical_accuracy: 0.6813 - val_loss: 0.0237 - val_categorical_accuracy: 0.6394\n",
832 | "Epoch 11/20\n",
833 | "258/258 [==============================] - 2s 7ms/step - loss: 0.0168 - categorical_accuracy: 0.6866 - val_loss: 0.0244 - val_categorical_accuracy: 0.6307\n",
834 | "Epoch 12/20\n",
835 | "258/258 [==============================] - 2s 7ms/step - loss: 0.0165 - categorical_accuracy: 0.6861 - val_loss: 0.0244 - val_categorical_accuracy: 0.6290\n",
836 | "Epoch 13/20\n",
837 | "258/258 [==============================] - 2s 7ms/step - loss: 0.0159 - categorical_accuracy: 0.6899 - val_loss: 0.0253 - val_categorical_accuracy: 0.6432\n",
838 | "Epoch 14/20\n",
839 | "258/258 [==============================] - 2s 7ms/step - loss: 0.0150 - categorical_accuracy: 0.6983 - val_loss: 0.0261 - val_categorical_accuracy: 0.6416\n",
840 | "Epoch 15/20\n",
841 | "258/258 [==============================] - 2s 7ms/step - loss: 0.0142 - categorical_accuracy: 0.7049 - val_loss: 0.0271 - val_categorical_accuracy: 0.6350\n",
842 | "Epoch 16/20\n",
843 | "258/258 [==============================] - 2s 7ms/step - loss: 0.0135 - categorical_accuracy: 0.7095 - val_loss: 0.0285 - val_categorical_accuracy: 0.6323\n",
844 | "Epoch 17/20\n",
845 | "258/258 [==============================] - 2s 7ms/step - loss: 0.0128 - categorical_accuracy: 0.7126 - val_loss: 0.0297 - val_categorical_accuracy: 0.6296\n",
846 | "Epoch 18/20\n",
847 | "258/258 [==============================] - 2s 7ms/step - loss: 0.0121 - categorical_accuracy: 0.7162 - val_loss: 0.0312 - val_categorical_accuracy: 0.6247\n",
848 | "Epoch 19/20\n",
849 | "258/258 [==============================] - 2s 7ms/step - loss: 0.0115 - categorical_accuracy: 0.7166 - val_loss: 0.0324 - val_categorical_accuracy: 0.6263\n",
850 | "Epoch 20/20\n",
851 | "258/258 [==============================] - 2s 7ms/step - loss: 0.0109 - categorical_accuracy: 0.7184 - val_loss: 0.0339 - val_categorical_accuracy: 0.6230\n"
852 | ]
853 | },
854 | {
855 | "output_type": "execute_result",
856 | "data": {
857 | "text/plain": [
858 | ""
859 | ]
860 | },
861 | "metadata": {},
862 | "execution_count": 23
863 | }
864 | ]
865 | },
866 | {
867 | "cell_type": "markdown",
868 | "metadata": {
869 | "id": "wBvqeuk88G9r"
870 | },
871 | "source": [
872 | "## Evaluate the model"
873 | ]
874 | },
875 | {
876 | "cell_type": "code",
877 | "metadata": {
878 | "id": "sxz8yDaT8MdL",
879 | "colab": {
880 | "base_uri": "https://localhost:8080/"
881 | },
882 | "outputId": "d83176bd-34ea-4bac-ff75-7af429e0d475"
883 | },
884 | "source": [
885 | "_, categorical_acc = shallow_mlp_model.evaluate(test_dataset)\n",
886 | "print(f\"Categorical accuracy on the test set: {round(categorical_acc * 100, 2)}%.\")"
887 | ],
888 | "execution_count": null,
889 | "outputs": [
890 | {
891 | "output_type": "stream",
892 | "name": "stdout",
893 | "text": [
894 | "15/15 [==============================] - 0s 11ms/step - loss: 0.0339 - categorical_accuracy: 0.6230\n",
895 | "Categorical accuracy on the test set: 62.3%.\n"
896 | ]
897 | }
898 | ]
899 | }
900 | ]
901 | }
--------------------------------------------------------------------------------