├── .gitignore
├── README.md
├── arxiv_scrape.ipynb
├── beam_arxiv_scrape.ipynb
└── multi_label_trainer_tfidf.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | venv/
2 | **.csv
3 | **pycache**
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multi-label Text Classification
2 |
3 | Holds code for collecting data from arXiv to build a multi-label text classification dataset and a simpler classifier on
4 | top of that. Our dataset is now [available on Kaggle](https://www.kaggle.com/spsayakpaul/arxiv-paper-abstracts). The dataset collection process
5 | has been shown in [this notebook](https://github.com/soumik12345/multi-label-text-classification/blob/master/beam_arxiv_scrape.ipynb). We leverage
6 | Apache Beam to design our data collection pipeline and our pipeline can be run on [Dataflow](https://cloud.google.com/dataflow) at scale. We hope
7 | the data will be a useful benchmark for building multi-label text classification systems.
8 |
9 | Here's an accompanying blog post on keras.io discussing the motivation behind this dataset, building a simple
10 | baseline model, etc.: [Large-scale multi-label text classification](https://keras.io/examples/nlp/multi_label_classification/).
11 |
12 | ## Acknowledgements
13 |
14 | We would like to thank [Matt Watson](https://github.com/mattdangerw) for helping us build the simple baseline classifier model. Thanks to
15 | [Lukas Schwab](https://github.com/lukasschwab) (author of [`arxiv.py`](https://github.com/lukasschwab/arxiv.py)) for helping us build
16 | our initial data collection utilities. Thanks to [Robert Bradshaw](https://www.linkedin.com/in/robert-bradshaw-1b48a07/) for his inputs
17 | on the Apache Beam pipeline. Thanks to the [ML-GDE program](https://developers.google.com/programs/experts/) for providing GCP credits
18 | that allowed us to run the Beam pipeline at scale on Dataflow.
19 |
--------------------------------------------------------------------------------
/arxiv_scrape.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "arxiv-scrape.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": []
9 | },
10 | "kernelspec": {
11 | "name": "python3",
12 | "display_name": "Python 3"
13 | },
14 | "language_info": {
15 | "name": "python"
16 | }
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "code",
21 | "metadata": {
22 | "colab": {
23 | "base_uri": "https://localhost:8080/"
24 | },
25 | "id": "ndfglIjXYcBU",
26 | "outputId": "dcfc54c0-70c6-452e-ac68-2bb396744183"
27 | },
28 | "source": [
29 | "!pip install arxiv"
30 | ],
31 | "execution_count": null,
32 | "outputs": [
33 | {
34 | "output_type": "stream",
35 | "name": "stdout",
36 | "text": [
37 | "Collecting arxiv\n",
38 | " Downloading arxiv-1.4.2-py3-none-any.whl (11 kB)\n",
39 | "Collecting feedparser\n",
40 | " Downloading feedparser-6.0.8-py3-none-any.whl (81 kB)\n",
41 | "\u001b[K |████████████████████████████████| 81 kB 5.3 MB/s \n",
42 | "\u001b[?25hCollecting sgmllib3k\n",
43 | " Downloading sgmllib3k-1.0.0.tar.gz (5.8 kB)\n",
44 | "Building wheels for collected packages: sgmllib3k\n",
45 | " Building wheel for sgmllib3k (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
46 | " Created wheel for sgmllib3k: filename=sgmllib3k-1.0.0-py3-none-any.whl size=6065 sha256=1a262bf6597c0a6d7c9fea438a07ef9abce303bce87c66dc0a19ff6bfc1ebe83\n",
47 | " Stored in directory: /root/.cache/pip/wheels/73/ad/a4/0dff4a6ef231fc0dfa12ffbac2a36cebfdddfe059f50e019aa\n",
48 | "Successfully built sgmllib3k\n",
49 | "Installing collected packages: sgmllib3k, feedparser, arxiv\n",
50 | "Successfully installed arxiv-1.4.2 feedparser-6.0.8 sgmllib3k-1.0.0\n"
51 | ]
52 | }
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "metadata": {
58 | "id": "ApsL6p9CWTBX"
59 | },
60 | "source": [
61 | "import arxiv\n",
62 | "import pandas as pd\n",
63 | "from tqdm import tqdm\n",
64 | "\n",
65 | "query_keywords = [\n",
66 | " \"\\\"image segmentation\\\"\",\n",
67 | " \"\\\"self-supervised learning\\\"\",\n",
68 | " \"\\\"representation learning\\\"\",\n",
69 | " \"\\\"image generation\\\"\",\n",
70 | " \"\\\"object detection\\\"\",\n",
71 | " \"\\\"transfer learning\\\"\",\n",
72 | " \"\\\"transformers\\\"\",\n",
73 | " \"\\\"adversarial training\",\n",
74 | " \"\\\"generative adversarial networks\\\"\",\n",
75 | " \"\\\"model compressions\\\"\",\n",
76 | " \"\\\"image segmentation\\\"\",\n",
77 | " \"\\\"few-shot learning\\\"\",\n",
78 | " \"\\\"natural language\\\"\",\n",
79 | " \"\\\"graph\\\"\",\n",
80 | " \"\\\"colorization\\\"\",\n",
81 | " \"\\\"depth estimation\\\"\",\n",
82 | " \"\\\"point cloud\\\"\",\n",
83 | " \"\\\"structured data\\\"\",\n",
84 | " \"\\\"optical flow\\\"\",\n",
85 | " \"\\\"reinforcement learning\\\"\",\n",
86 | " \"\\\"super resolution\\\"\",\n",
87 | " \"\\\"attention\\\"\",\n",
88 | " \"\\\"tabular\\\"\",\n",
89 | " \"\\\"unsupervised learning\\\"\",\n",
90 | " \"\\\"semi-supervised learning\\\"\",\n",
91 | " \"\\\"explainable\\\"\",\n",
92 | " \"\\\"radiance field\\\"\",\n",
93 | " \"\\\"decision tree\\\"\",\n",
94 | " \"\\\"time series\\\"\",\n",
95 | " \"\\\"molecule\\\"\",\n",
96 | "]"
97 | ],
98 | "execution_count": null,
99 | "outputs": []
100 | },
101 | {
102 | "cell_type": "code",
103 | "metadata": {
104 | "id": "TIct2UKLbL31"
105 | },
106 | "source": [
107 | "# Reuse a client with increased number of retries (3 -> 10) and increased page\n",
108 | "# size (100->500).\n",
109 | "client = arxiv.Client(num_retries=20, page_size=500)\n",
110 | "\n",
111 | "def query_with_keywords(query):\n",
112 | " search = arxiv.Search(\n",
113 | " query=query,\n",
114 | " max_results=20000,\n",
115 | " sort_by=arxiv.SortCriterion.LastUpdatedDate\n",
116 | " )\n",
117 | " terms = []\n",
118 | " titles = []\n",
119 | " abstracts = []\n",
120 | " for res in tqdm(client.results(search), desc=query):\n",
121 | " if res.primary_category in [\"cs.CV\", \"stat.ML\", \"cs.LG\"]:\n",
122 | " terms.append(res.categories)\n",
123 | " titles.append(res.title)\n",
124 | " abstracts.append(res.summary)\n",
125 | " return terms, titles, abstracts"
126 | ],
127 | "execution_count": null,
128 | "outputs": []
129 | },
130 | {
131 | "cell_type": "code",
132 | "metadata": {
133 | "colab": {
134 | "base_uri": "https://localhost:8080/"
135 | },
136 | "id": "L73h-A7RYmqR",
137 | "outputId": "0e3bd550-1a12-4fd2-efa7-2e5e286c2c3f"
138 | },
139 | "source": [
140 | "all_titles = []\n",
141 | "all_summaries = []\n",
142 | "all_terms = []\n",
143 | "\n",
144 | "for query in query_keywords:\n",
145 | " terms, titles, abstracts = query_with_keywords(query)\n",
146 | " all_titles.extend(titles)\n",
147 | " all_summaries.extend(abstracts)\n",
148 | " all_terms.extend(terms)"
149 | ],
150 | "execution_count": null,
151 | "outputs": [
152 | {
153 | "output_type": "stream",
154 | "name": "stderr",
155 | "text": [
156 | "\"image segmentation\": 2082it [00:34, 60.41it/s]\n",
157 | "\"self-supervised learning\": 0it [00:03, ?it/s]\n",
158 | "\"representation learning\": 3690it [01:07, 54.33it/s]\n",
159 | "\"image generation\": 1241it [00:22, 54.30it/s]\n",
160 | "\"object detection\": 4439it [01:16, 57.97it/s]\n",
161 | "\"transfer learning\": 3456it [00:57, 59.61it/s]\n",
162 | "\"transformers\": 20000it [06:18, 52.89it/s]\n",
163 | "\"adversarial training: 0it [00:03, ?it/s]\n",
164 | "\"generative adversarial networks\": 4185it [01:15, 55.67it/s]\n",
165 | "\"model compressions\": 497it [00:08, 59.19it/s]\n",
166 | "\"image segmentation\": 2082it [00:37, 55.76it/s]\n",
167 | "\"few-shot learning\": 0it [00:03, ?it/s]\n",
168 | "\"natural language\": 8976it [02:37, 57.01it/s]\n",
169 | "\"graph\": 20000it [06:09, 54.17it/s]\n",
170 | "\"colorization\": 20000it [06:46, 49.24it/s]\n",
171 | "\"depth estimation\": 798it [00:14, 55.08it/s]\n",
172 | "\"point cloud\": 2699it [00:45, 59.18it/s]\n",
173 | "\"structured data\": 1458it [00:29, 50.17it/s]\n",
174 | "\"optical flow\": 1136it [00:24, 46.47it/s]\n",
175 | "\"reinforcement learning\": 10880it [03:09, 57.54it/s]\n",
176 | "\"super resolution\": 2107it [00:36, 57.58it/s]\n",
177 | "\"attention\": 20000it [05:51, 56.82it/s]\n",
178 | "\"tabular\": 824it [00:14, 56.69it/s]\n",
179 | "\"unsupervised learning\": 2139it [00:37, 57.74it/s]\n",
180 | "\"semi-supervised learning\": 0it [00:03, ?it/s]\n",
181 | "\"explainable\": 20000it [06:16, 53.13it/s]\n",
182 | "\"radiance field\": 82it [00:04, 19.99it/s]\n",
183 | "\"decision tree\": 1869it [00:31, 58.45it/s]\n",
184 | "\"time series\": 12496it [04:03, 51.28it/s]\n",
185 | "\"molecule\": 20000it [06:51, 48.56it/s]\n"
186 | ]
187 | }
188 | ]
189 | },
190 | {
191 | "cell_type": "code",
192 | "metadata": {
193 | "colab": {
194 | "base_uri": "https://localhost:8080/",
195 | "height": 205
196 | },
197 | "id": "GX7UXcI0a1sR",
198 | "outputId": "0ed04b65-ab62-4d11-91a9-5d73d4e67ca9"
199 | },
200 | "source": [
201 | "data = pd.DataFrame({\n",
202 | " 'titles': all_titles,\n",
203 | " 'summaries': all_summaries,\n",
204 | " 'terms': all_terms\n",
205 | "})\n",
206 | "data.head()"
207 | ],
208 | "execution_count": null,
209 | "outputs": [
210 | {
211 | "output_type": "execute_result",
212 | "data": {
213 | "text/html": [
214 | "
\n",
215 | "\n",
228 | "
\n",
229 | " \n",
230 | " \n",
231 | " | \n",
232 | " titles | \n",
233 | " summaries | \n",
234 | " terms | \n",
235 | "
\n",
236 | " \n",
237 | " \n",
238 | " \n",
239 | " 0 | \n",
240 | " Survey on Semantic Stereo Matching / Semantic ... | \n",
241 | " Stereo matching is one of the widely used tech... | \n",
242 | " [cs.CV, cs.LG] | \n",
243 | "
\n",
244 | " \n",
245 | " 1 | \n",
246 | " FUTURE-AI: Guiding Principles and Consensus Re... | \n",
247 | " The recent advancements in artificial intellig... | \n",
248 | " [cs.CV, cs.AI, cs.LG] | \n",
249 | "
\n",
250 | " \n",
251 | " 2 | \n",
252 | " Enforcing Mutual Consistency of Hard Regions f... | \n",
253 | " In this paper, we proposed a novel mutual cons... | \n",
254 | " [cs.CV, cs.AI] | \n",
255 | "
\n",
256 | " \n",
257 | " 3 | \n",
258 | " Parameter Decoupling Strategy for Semi-supervi... | \n",
259 | " Consistency training has proven to be an advan... | \n",
260 | " [cs.CV] | \n",
261 | "
\n",
262 | " \n",
263 | " 4 | \n",
264 | " Background-Foreground Segmentation for Interio... | \n",
265 | " To ensure safety in automated driving, the cor... | \n",
266 | " [cs.CV, cs.LG] | \n",
267 | "
\n",
268 | " \n",
269 | "
\n",
270 | "
"
271 | ],
272 | "text/plain": [
273 | " titles ... terms\n",
274 | "0 Survey on Semantic Stereo Matching / Semantic ... ... [cs.CV, cs.LG]\n",
275 | "1 FUTURE-AI: Guiding Principles and Consensus Re... ... [cs.CV, cs.AI, cs.LG]\n",
276 | "2 Enforcing Mutual Consistency of Hard Regions f... ... [cs.CV, cs.AI]\n",
277 | "3 Parameter Decoupling Strategy for Semi-supervi... ... [cs.CV]\n",
278 | "4 Background-Foreground Segmentation for Interio... ... [cs.CV, cs.LG]\n",
279 | "\n",
280 | "[5 rows x 3 columns]"
281 | ]
282 | },
283 | "metadata": {},
284 | "execution_count": 5
285 | }
286 | ]
287 | },
288 | {
289 | "cell_type": "code",
290 | "metadata": {
291 | "id": "LRjWIApOdTE0"
292 | },
293 | "source": [
294 | "data.to_csv('arxiv_data.csv', index=False)"
295 | ],
296 | "execution_count": null,
297 | "outputs": []
298 | },
299 | {
300 | "cell_type": "code",
301 | "metadata": {
302 | "id": "YrcBGeBldxgc",
303 | "colab": {
304 | "base_uri": "https://localhost:8080/",
305 | "height": 17
306 | },
307 | "outputId": "0a3d21a2-34e1-4dfb-fe50-79f9ef723829"
308 | },
309 | "source": [
310 | "from google.colab import files\n",
311 | "files.download('arxiv_data.csv') "
312 | ],
313 | "execution_count": null,
314 | "outputs": [
315 | {
316 | "output_type": "display_data",
317 | "data": {
318 | "application/javascript": [
319 | "\n",
320 | " async function download(id, filename, size) {\n",
321 | " if (!google.colab.kernel.accessAllowed) {\n",
322 | " return;\n",
323 | " }\n",
324 | " const div = document.createElement('div');\n",
325 | " const label = document.createElement('label');\n",
326 | " label.textContent = `Downloading \"${filename}\": `;\n",
327 | " div.appendChild(label);\n",
328 | " const progress = document.createElement('progress');\n",
329 | " progress.max = size;\n",
330 | " div.appendChild(progress);\n",
331 | " document.body.appendChild(div);\n",
332 | "\n",
333 | " const buffers = [];\n",
334 | " let downloaded = 0;\n",
335 | "\n",
336 | " const channel = await google.colab.kernel.comms.open(id);\n",
337 | " // Send a message to notify the kernel that we're ready.\n",
338 | " channel.send({})\n",
339 | "\n",
340 | " for await (const message of channel.messages) {\n",
341 | " // Send a message to notify the kernel that we're ready.\n",
342 | " channel.send({})\n",
343 | " if (message.buffers) {\n",
344 | " for (const buffer of message.buffers) {\n",
345 | " buffers.push(buffer);\n",
346 | " downloaded += buffer.byteLength;\n",
347 | " progress.value = downloaded;\n",
348 | " }\n",
349 | " }\n",
350 | " }\n",
351 | " const blob = new Blob(buffers, {type: 'application/binary'});\n",
352 | " const a = document.createElement('a');\n",
353 | " a.href = window.URL.createObjectURL(blob);\n",
354 | " a.download = filename;\n",
355 | " div.appendChild(a);\n",
356 | " a.click();\n",
357 | " div.remove();\n",
358 | " }\n",
359 | " "
360 | ],
361 | "text/plain": [
362 | ""
363 | ]
364 | },
365 | "metadata": {}
366 | },
367 | {
368 | "output_type": "display_data",
369 | "data": {
370 | "application/javascript": [
371 | "download(\"download_f004bf7a-a2b5-4190-8106-a09da6b00401\", \"arxiv_data.csv\", 67411252)"
372 | ],
373 | "text/plain": [
374 | ""
375 | ]
376 | },
377 | "metadata": {}
378 | }
379 | ]
380 | },
381 | {
382 | "cell_type": "code",
383 | "metadata": {
384 | "id": "dqojCAC4uV8u"
385 | },
386 | "source": [
387 | ""
388 | ],
389 | "execution_count": null,
390 | "outputs": []
391 | }
392 | ]
393 | }
--------------------------------------------------------------------------------
/beam_arxiv_scrape.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "66abf160",
7 | "metadata": {
8 | "id": "66abf160"
9 | },
10 | "outputs": [],
11 | "source": [
12 | "import apache_beam as beam\n",
13 | "import arxiv \n",
14 | "\n",
15 | "from apache_beam.dataframe.convert import to_dataframe\n",
16 | "from datetime import datetime"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 2,
22 | "id": "5443193c",
23 | "metadata": {
24 | "id": "5443193c"
25 | },
26 | "outputs": [],
27 | "source": [
28 | "query_keywords = [\n",
29 | " \"\\\"image segmentation\\\"\",\n",
30 | " \"\\\"self-supervised learning\\\"\",\n",
31 | " \"\\\"representation learning\\\"\",\n",
32 | " \"\\\"image generation\\\"\",\n",
33 | " \"\\\"object detection\\\"\",\n",
34 | " \"\\\"transfer learning\\\"\",\n",
35 | " \"\\\"transformers\\\"\",\n",
36 | " \"\\\"adversarial training\",\n",
37 | " \"\\\"generative adversarial networks\\\"\",\n",
38 | " \"\\\"model compressions\\\"\",\n",
39 | " \"\\\"image segmentation\\\"\",\n",
40 | " \"\\\"few-shot learning\\\"\",\n",
41 | " \"\\\"natural language\\\"\",\n",
42 | " \"\\\"graph\\\"\",\n",
43 | " \"\\\"colorization\\\"\",\n",
44 | " \"\\\"depth estimation\\\"\",\n",
45 | " \"\\\"point cloud\\\"\",\n",
46 | " \"\\\"structured data\\\"\",\n",
47 | " \"\\\"optical flow\\\"\",\n",
48 | " \"\\\"reinforcement learning\\\"\",\n",
49 | " \"\\\"super resolution\\\"\",\n",
50 | " \"\\\"attention\\\"\",\n",
51 | " \"\\\"tabular\\\"\",\n",
52 | " \"\\\"unsupervised learning\\\"\",\n",
53 | " \"\\\"semi-supervised learning\\\"\",\n",
54 | " \"\\\"explainable\\\"\",\n",
55 | " \"\\\"radiance field\\\"\",\n",
56 | " \"\\\"decision tree\\\"\",\n",
57 | " \"\\\"time series\\\"\",\n",
58 | " \"\\\"molecule\\\"\",\n",
59 | " \"\\\"physics\\\"\",\n",
60 | " \"\\\"graphics\\\"\",\n",
61 | " \"\\\"ray tracing\\\"\",\n",
62 | " \"\\\"optical flow\\\"\",\n",
63 | " \"\\\"photogrametry\\\"\",\n",
64 | "]"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": 3,
70 | "id": "DI84CCdnY5Ek",
71 | "metadata": {
72 | "id": "DI84CCdnY5Ek"
73 | },
74 | "outputs": [],
75 | "source": [
76 | "import typing\n",
77 | "\n",
78 | "\n",
79 | "class ArxivEntries(typing.NamedTuple):\n",
80 | " terms: typing.List[str]\n",
81 | " titles: str\n",
82 | " abstracts: str"
83 | ]
84 | },
85 | {
86 | "cell_type": "code",
87 | "execution_count": 4,
88 | "id": "83ddf9d6",
89 | "metadata": {
90 | "id": "83ddf9d6"
91 | },
92 | "outputs": [],
93 | "source": [
94 | "client = arxiv.Client(num_retries=20, page_size=500)\n",
95 | "\n",
96 | "\n",
97 | "def query_with_keywords(query):\n",
98 | " search = arxiv.Search(\n",
99 | " query=query, max_results=20000, sort_by=arxiv.SortCriterion.LastUpdatedDate,\n",
100 | " )\n",
101 | "\n",
102 | " for res in client.results(search):\n",
103 | " if res.primary_category in [\"cs.CV\", \"stat.ML\", \"cs.LG\"]:\n",
104 | " yield beam.Row(\n",
105 | " terms=res.categories, titles=res.title, abstracts=res.summary\n",
106 | " )"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": 5,
112 | "id": "SqeUCclZWwF1",
113 | "metadata": {
114 | "colab": {
115 | "base_uri": "https://localhost:8080/"
116 | },
117 | "id": "SqeUCclZWwF1",
118 | "outputId": "e086cff2-8660-4079-9128-66839909a009"
119 | },
120 | "outputs": [
121 | {
122 | "name": "stdout",
123 | "output_type": "stream",
124 | "text": [
125 | "Overwriting setup.py\n"
126 | ]
127 | }
128 | ],
129 | "source": [
130 | "%%writefile setup.py\n",
131 | "\n",
132 | "import setuptools\n",
133 | "\n",
134 | "\n",
135 | "NAME = \"gather_arxiv_data\"\n",
136 | "VERSION = \"0.1.0\"\n",
137 | "REQUIRED_PACKAGES = [\n",
138 | " \"apache_beam==2.32.0\",\n",
139 | " \"pandas==1.3.2\",\n",
140 | " \"arxiv==1.4.2\",\n",
141 | " \"google_cloud_storage==1.42.1\",\n",
142 | "]\n",
143 | "\n",
144 | "\n",
145 | "setuptools.setup(\n",
146 | " name=NAME,\n",
147 | " version=VERSION,\n",
148 | " install_requires=REQUIRED_PACKAGES,\n",
149 | " packages=setuptools.find_packages(),\n",
150 | " include_package_data=True,\n",
151 | ")"
152 | ]
153 | },
154 | {
155 | "cell_type": "code",
156 | "execution_count": 6,
157 | "id": "2196906b",
158 | "metadata": {
159 | "id": "2196906b"
160 | },
161 | "outputs": [],
162 | "source": [
163 | "gcs_bucket_name = \"arxiv-data-nlp\"\n",
164 | "gcp_project = \"####\" # Specify this.\n",
165 | "\n",
166 | "pipeline_args = {\n",
167 | " \"job_name\": f'arxiv-data-{datetime.utcnow().strftime(\"%y%m%d-%H%M%S\")}',\n",
168 | " \"num_workers\": \"4\",\n",
169 | " \"runner\": \"DataflowRunner\",\n",
170 | " \"setup_file\": \"./setup.py\",\n",
171 | " \"project\": gcp_project,\n",
172 | " \"region\": \"us-central1\",\n",
173 | " \"gcs_location\": f\"gs://{gcs_bucket_name}\",\n",
174 | " \"temp_location\": f\"gs://{gcs_bucket_name}/temp\",\n",
175 | " \"staging_location\": f\"gs://{gcs_bucket_name}/staging\",\n",
176 | " \"save_main_session\": \"True\",\n",
177 | "}\n",
178 | "\n",
179 | "# Convert the dictionary to a list of (argument, value) tuples and then flatten the list.\n",
180 | "pipeline_args = [(f\"--{k}\", v) for k, v in pipeline_args.items()]\n",
181 | "pipeline_args = [x for y in pipeline_args for x in y]"
182 | ]
183 | },
184 | {
185 | "cell_type": "code",
186 | "execution_count": 7,
187 | "id": "d58affea",
188 | "metadata": {
189 | "colab": {
190 | "base_uri": "https://localhost:8080/",
191 | "height": 766
192 | },
193 | "id": "d58affea",
194 | "outputId": "23852f9e-bf43-4b2a-cd3a-6180761a904f"
195 | },
196 | "outputs": [
197 | {
198 | "name": "stderr",
199 | "output_type": "stream",
200 | "text": [
201 | "WARNING:apache_beam.runners.interactive.interactive_environment:Dependencies required for Interactive Beam PCollection visualization are not available, please use: `pip install apache-beam[interactive]` to install necessary dependencies to enable all data visualization features.\n"
202 | ]
203 | },
204 | {
205 | "data": {
206 | "application/javascript": [
207 | "\n",
208 | " if (typeof window.interactive_beam_jquery == 'undefined') {\n",
209 | " var jqueryScript = document.createElement('script');\n",
210 | " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n",
211 | " jqueryScript.type = 'text/javascript';\n",
212 | " jqueryScript.onload = function() {\n",
213 | " var datatableScript = document.createElement('script');\n",
214 | " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n",
215 | " datatableScript.type = 'text/javascript';\n",
216 | " datatableScript.onload = function() {\n",
217 | " window.interactive_beam_jquery = jQuery.noConflict(true);\n",
218 | " window.interactive_beam_jquery(document).ready(function($){\n",
219 | " \n",
220 | " });\n",
221 | " }\n",
222 | " document.head.appendChild(datatableScript);\n",
223 | " };\n",
224 | " document.head.appendChild(jqueryScript);\n",
225 | " } else {\n",
226 | " window.interactive_beam_jquery(document).ready(function($){\n",
227 | " \n",
228 | " });\n",
229 | " }"
230 | ]
231 | },
232 | "metadata": {},
233 | "output_type": "display_data"
234 | },
235 | {
236 | "name": "stderr",
237 | "output_type": "stream",
238 | "text": [
239 | "/Users/sayakpaul/.local/bin/.virtualenvs/tf/lib/python3.8/site-packages/apache_beam/dataframe/io.py:566: FutureWarning: WriteToFiles is experimental.\n",
240 | " return pcoll | fileio.WriteToFiles(\n",
241 | "/Users/sayakpaul/.local/bin/.virtualenvs/tf/lib/python3.8/site-packages/apache_beam/io/fileio.py:535: BeamDeprecationWarning: options is deprecated since First stable release. References to .options will not be supported\n",
242 | " p.options.view_as(GoogleCloudOptions).temp_location or\n",
243 | "warning: sdist: standard file not found: should have one of README, README.rst, README.txt, README.md\n",
244 | "\n",
245 | "warning: check: missing required meta-data: url\n",
246 | "\n",
247 | "warning: check: missing meta-data: either (author and author_email) or (maintainer and maintainer_email) must be supplied\n",
248 | "\n",
249 | "WARNING:root:Make sure that locally built Python SDK docker image has Python 3.8 interpreter.\n",
250 | "WARNING:apache_beam.options.pipeline_options:Discarding unparseable args: ['--gcs_location', 'gs://arxiv-data-nlp', 'True']\n",
251 | "WARNING:apache_beam.options.pipeline_options:Discarding unparseable args: ['--gcs_location', 'gs://arxiv-data-nlp', 'True']\n"
252 | ]
253 | }
254 | ],
255 | "source": [
256 | "with beam.Pipeline(argv=pipeline_args) as pipeline:\n",
257 | " keywords = pipeline | beam.Create(query_keywords)\n",
258 | " records = keywords | beam.FlatMap(query_with_keywords).with_output_types(ArxivEntries)\n",
259 | " _ = to_dataframe(records).to_csv(\n",
260 | " f\"gs://{gcs_bucket_name}/arxiv/sample.csv\", index=False\n",
261 | " )"
262 | ]
263 | },
264 | {
265 | "cell_type": "code",
266 | "execution_count": 8,
267 | "id": "cb97c245",
268 | "metadata": {},
269 | "outputs": [
270 | {
271 | "name": "stdout",
272 | "output_type": "stream",
273 | "text": [
274 | "\n",
275 | "\n",
276 | "Updates are available for some Cloud SDK components. To install them,\n",
277 | "please run:\n",
278 | " $ gcloud components update\n",
279 | "\n",
280 | "gs://arxiv-data-nlp/arxiv/:\n",
281 | "gs://arxiv-data-nlp/arxiv/sample.csv-00000-of-00020\n",
282 | "gs://arxiv-data-nlp/arxiv/sample.csv-00001-of-00020\n",
283 | "gs://arxiv-data-nlp/arxiv/sample.csv-00002-of-00020\n",
284 | "gs://arxiv-data-nlp/arxiv/sample.csv-00003-of-00020\n",
285 | "gs://arxiv-data-nlp/arxiv/sample.csv-00004-of-00020\n",
286 | "gs://arxiv-data-nlp/arxiv/sample.csv-00005-of-00020\n",
287 | "gs://arxiv-data-nlp/arxiv/sample.csv-00006-of-00020\n",
288 | "gs://arxiv-data-nlp/arxiv/sample.csv-00007-of-00020\n",
289 | "gs://arxiv-data-nlp/arxiv/sample.csv-00008-of-00020\n",
290 | "gs://arxiv-data-nlp/arxiv/sample.csv-00009-of-00020\n",
291 | "gs://arxiv-data-nlp/arxiv/sample.csv-00010-of-00020\n",
292 | "gs://arxiv-data-nlp/arxiv/sample.csv-00011-of-00020\n",
293 | "gs://arxiv-data-nlp/arxiv/sample.csv-00012-of-00020\n",
294 | "gs://arxiv-data-nlp/arxiv/sample.csv-00013-of-00020\n",
295 | "gs://arxiv-data-nlp/arxiv/sample.csv-00014-of-00020\n",
296 | "gs://arxiv-data-nlp/arxiv/sample.csv-00015-of-00020\n",
297 | "gs://arxiv-data-nlp/arxiv/sample.csv-00016-of-00020\n",
298 | "gs://arxiv-data-nlp/arxiv/sample.csv-00017-of-00020\n",
299 | "gs://arxiv-data-nlp/arxiv/sample.csv-00018-of-00020\n",
300 | "gs://arxiv-data-nlp/arxiv/sample.csv-00019-of-00020\n"
301 | ]
302 | }
303 | ],
304 | "source": [
305 | "!gsutil ls -R gs://{gcs_bucket_name}/arxiv/"
306 | ]
307 | },
308 | {
309 | "cell_type": "code",
310 | "execution_count": 9,
311 | "id": "f98d917f",
312 | "metadata": {},
313 | "outputs": [
314 | {
315 | "name": "stdout",
316 | "output_type": "stream",
317 | "text": [
318 | "Copying gs://arxiv-data-nlp/arxiv/sample.csv-00000-of-00020...\n",
319 | "| [1 files][ 1.9 MiB/ 1.9 MiB] 347.4 KiB/s \n",
320 | "Operation completed over 1 objects/1.9 MiB. \n"
321 | ]
322 | }
323 | ],
324 | "source": [
325 | "!gsutil cp gs://arxiv-data-nlp/arxiv/sample.csv-00000-of-00020 ."
326 | ]
327 | },
328 | {
329 | "cell_type": "code",
330 | "execution_count": 10,
331 | "id": "daedd023",
332 | "metadata": {
333 | "id": "daedd023",
334 | "outputId": "e59c8622-02ca-4bfe-adfc-cec0208602fc"
335 | },
336 | "outputs": [
337 | {
338 | "data": {
339 | "text/html": [
340 | "\n",
341 | "\n",
354 | "
\n",
355 | " \n",
356 | " \n",
357 | " | \n",
358 | " terms | \n",
359 | " titles | \n",
360 | " abstracts | \n",
361 | "
\n",
362 | " \n",
363 | " \n",
364 | " \n",
365 | " 0 | \n",
366 | " ['cs.LG', 'cs.AI'] | \n",
367 | " Self-supervised Learning on Graphs: Contrastiv... | \n",
368 | " Deep learning on graphs has recently achieved ... | \n",
369 | "
\n",
370 | " \n",
371 | " 1 | \n",
372 | " ['cs.CV', 'cs.CL'] | \n",
373 | " Contrastive Video-Language Segmentation | \n",
374 | " We focus on the problem of segmenting a certai... | \n",
375 | "
\n",
376 | " \n",
377 | " 2 | \n",
378 | " ['cs.LG'] | \n",
379 | " What to Prioritize? Natural Language Processin... | \n",
380 | " Managing large numbers of incoming bug reports... | \n",
381 | "
\n",
382 | " \n",
383 | " 3 | \n",
384 | " ['cs.CV', 'cs.RO'] | \n",
385 | " The VVAD-LRS3 Dataset for Visual Voice Activit... | \n",
386 | " Robots are becoming everyday devices, increasi... | \n",
387 | "
\n",
388 | " \n",
389 | " 4 | \n",
390 | " ['cs.CV'] | \n",
391 | " UTNet: A Hybrid Transformer Architecture for M... | \n",
392 | " Transformer architecture has emerged to be suc... | \n",
393 | "
\n",
394 | " \n",
395 | "
\n",
396 | "
"
397 | ],
398 | "text/plain": [
399 | " terms titles \\\n",
400 | "0 ['cs.LG', 'cs.AI'] Self-supervised Learning on Graphs: Contrastiv... \n",
401 | "1 ['cs.CV', 'cs.CL'] Contrastive Video-Language Segmentation \n",
402 | "2 ['cs.LG'] What to Prioritize? Natural Language Processin... \n",
403 | "3 ['cs.CV', 'cs.RO'] The VVAD-LRS3 Dataset for Visual Voice Activit... \n",
404 | "4 ['cs.CV'] UTNet: A Hybrid Transformer Architecture for M... \n",
405 | "\n",
406 | " abstracts \n",
407 | "0 Deep learning on graphs has recently achieved ... \n",
408 | "1 We focus on the problem of segmenting a certai... \n",
409 | "2 Managing large numbers of incoming bug reports... \n",
410 | "3 Robots are becoming everyday devices, increasi... \n",
411 | "4 Transformer architecture has emerged to be suc... "
412 | ]
413 | },
414 | "execution_count": 10,
415 | "metadata": {},
416 | "output_type": "execute_result"
417 | }
418 | ],
419 | "source": [
420 | "import pandas as pd\n",
421 | "\n",
422 | "\n",
423 | "df = pd.read_csv(\"sample.csv-00000-of-00020\")\n",
424 | "df.head()"
425 | ]
426 | },
427 | {
428 | "cell_type": "markdown",
429 | "id": "7cba00e9",
430 | "metadata": {
431 | "id": "7cba00e9"
432 | },
433 | "source": [
434 | "## Acknowledgements\n",
435 | "\n",
436 | "* [Lukas Schwab](https://github.com/lukasschwab)\n",
437 | "* [Robert Bradshaw](https://www.linkedin.com/in/robert-bradshaw-1b48a07/)"
438 | ]
439 | }
440 | ],
441 | "metadata": {
442 | "colab": {
443 | "collapsed_sections": [],
444 | "name": "beam_arxiv.ipynb",
445 | "provenance": []
446 | },
447 | "kernelspec": {
448 | "display_name": "Python 3 (ipykernel)",
449 | "language": "python",
450 | "name": "python3"
451 | },
452 | "language_info": {
453 | "codemirror_mode": {
454 | "name": "ipython",
455 | "version": 3
456 | },
457 | "file_extension": ".py",
458 | "mimetype": "text/x-python",
459 | "name": "python",
460 | "nbconvert_exporter": "python",
461 | "pygments_lexer": "ipython3",
462 | "version": "3.8.2"
463 | }
464 | },
465 | "nbformat": 4,
466 | "nbformat_minor": 5
467 | }
468 |
--------------------------------------------------------------------------------
/multi_label_trainer_tfidf.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "Q9ZG7WNCHNqw"
7 | },
8 | "source": [
9 | "## Introduction\n",
10 | "\n",
11 | "In this example, we will build a multi-label text classifier to predict the subject areas of arXiv papers from their abstract bodies. This type of classifier can be useful for conference submission portals like [OpenReview](https://openreview.net/). Given a paper abstract, the portal could provide suggestions on which areas the underlying paper would best belong to.\n",
12 | "\n",
13 | "The dataset was collected using the [`arXiv` Python library](https://github.com/lukasschwab/arxiv.py) that provides a wrapper around the [original arXiv API](http://arxiv.org/help/api/index). To know more, please refer to [this notebook](https://github.com/soumik12345/multi-label-text-classification/blob/master/arxiv_scrape.ipynb). "
14 | ]
15 | },
16 | {
17 | "cell_type": "markdown",
18 | "metadata": {
19 | "id": "5GICQpY-zws7"
20 | },
21 | "source": [
22 | "## Imports"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 28,
28 | "metadata": {
29 | "id": "ho5uPff1fLoH"
30 | },
31 | "outputs": [],
32 | "source": [
33 | "from tensorflow.keras import layers\n",
34 | "from tensorflow import keras\n",
35 | "import tensorflow as tf\n",
36 | "\n",
37 | "from sklearn.preprocessing import MultiLabelBinarizer\n",
38 | "from sklearn.model_selection import train_test_split\n",
39 | "import matplotlib.pyplot as plt\n",
40 | "from ast import literal_eval\n",
41 | "import pandas as pd\n",
42 | "import numpy as np"
43 | ]
44 | },
45 | {
46 | "cell_type": "markdown",
47 | "metadata": {
48 | "id": "qyeQtKSezymP"
49 | },
50 | "source": [
51 | "## Read data and perform basic EDA\n",
52 | "\n",
53 | "In this section, we first load the dataset into a `pandas` dataframe and then perform some basic exploratory data analysis (EDA)."
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": null,
59 | "metadata": {
60 | "colab": {
61 | "base_uri": "https://localhost:8080/",
62 | "height": 287
63 | },
64 | "id": "yFo2pNYbf2Du",
65 | "outputId": "a0478ffb-4612-4f5b-846b-8b33a502bfed"
66 | },
67 | "outputs": [
68 | {
69 | "data": {
70 | "text/html": [
71 | "\n",
72 | "\n",
85 | "
\n",
86 | " \n",
87 | " \n",
88 | " | \n",
89 | " titles | \n",
90 | " summaries | \n",
91 | " terms | \n",
92 | "
\n",
93 | " \n",
94 | " \n",
95 | " \n",
96 | " 0 | \n",
97 | " Survey on Semantic Stereo Matching / Semantic ... | \n",
98 | " Stereo matching is one of the widely used tech... | \n",
99 | " ['cs.CV', 'cs.LG'] | \n",
100 | "
\n",
101 | " \n",
102 | " 1 | \n",
103 | " FUTURE-AI: Guiding Principles and Consensus Re... | \n",
104 | " The recent advancements in artificial intellig... | \n",
105 | " ['cs.CV', 'cs.AI', 'cs.LG'] | \n",
106 | "
\n",
107 | " \n",
108 | " 2 | \n",
109 | " Enforcing Mutual Consistency of Hard Regions f... | \n",
110 | " In this paper, we proposed a novel mutual cons... | \n",
111 | " ['cs.CV', 'cs.AI'] | \n",
112 | "
\n",
113 | " \n",
114 | " 3 | \n",
115 | " Parameter Decoupling Strategy for Semi-supervi... | \n",
116 | " Consistency training has proven to be an advan... | \n",
117 | " ['cs.CV'] | \n",
118 | "
\n",
119 | " \n",
120 | " 4 | \n",
121 | " Background-Foreground Segmentation for Interio... | \n",
122 | " To ensure safety in automated driving, the cor... | \n",
123 | " ['cs.CV', 'cs.LG'] | \n",
124 | "
\n",
125 | " \n",
126 | "
\n",
127 | "
"
128 | ],
129 | "text/plain": [
130 | " titles ... terms\n",
131 | "0 Survey on Semantic Stereo Matching / Semantic ... ... ['cs.CV', 'cs.LG']\n",
132 | "1 FUTURE-AI: Guiding Principles and Consensus Re... ... ['cs.CV', 'cs.AI', 'cs.LG']\n",
133 | "2 Enforcing Mutual Consistency of Hard Regions f... ... ['cs.CV', 'cs.AI']\n",
134 | "3 Parameter Decoupling Strategy for Semi-supervi... ... ['cs.CV']\n",
135 | "4 Background-Foreground Segmentation for Interio... ... ['cs.CV', 'cs.LG']\n",
136 | "\n",
137 | "[5 rows x 3 columns]"
138 | ]
139 | },
140 | "execution_count": 2,
141 | "metadata": {},
142 | "output_type": "execute_result"
143 | }
144 | ],
145 | "source": [
146 | "arxiv_data = pd.read_csv(\n",
147 | " \"https://github.com/soumik12345/multi-label-text-classification/releases/download/v0.2/arxiv_data.csv\"\n",
148 | ")\n",
149 | "arxiv_data.head()"
150 | ]
151 | },
152 | {
153 | "cell_type": "markdown",
154 | "metadata": {
155 | "id": "djk-kXWvHNq3"
156 | },
157 | "source": [
158 | "Our text features are present in the `summaries` column and their corresponding labels are in `terms`. As we can notice there are multiple categories associated with a particular entry. "
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "execution_count": null,
164 | "metadata": {
165 | "colab": {
166 | "base_uri": "https://localhost:8080/"
167 | },
168 | "id": "Em_8mJvUgKY-",
169 | "outputId": "8ce024c1-2745-4ab2-dd18-2b153784b4fc"
170 | },
171 | "outputs": [
172 | {
173 | "name": "stdout",
174 | "output_type": "stream",
175 | "text": [
176 | "There are 51774 rows in the dataset.\n"
177 | ]
178 | }
179 | ],
180 | "source": [
181 | "print(f\"There are {len(arxiv_data)} rows in the dataset.\")"
182 | ]
183 | },
184 | {
185 | "cell_type": "markdown",
186 | "metadata": {
187 | "id": "P3bYjP_0HNq4"
188 | },
189 | "source": [
190 | "Real-world data is noisy. One of the most commonly observed such noise is data duplication. Here we notice that our initial dataset has got about 13k duplicate entries. "
191 | ]
192 | },
193 | {
194 | "cell_type": "code",
195 | "execution_count": null,
196 | "metadata": {
197 | "colab": {
198 | "base_uri": "https://localhost:8080/"
199 | },
200 | "id": "k9Vb9jtK5zjg",
201 | "outputId": "92e336c0-921a-4412-87ba-e3872333e3ff"
202 | },
203 | "outputs": [
204 | {
205 | "name": "stdout",
206 | "output_type": "stream",
207 | "text": [
208 | "There are 12802 duplicate titles.\n"
209 | ]
210 | }
211 | ],
212 | "source": [
213 | "total_duplicate_titles = sum(arxiv_data[\"titles\"].duplicated())\n",
214 | "print(f\"There are {total_duplicate_titles} duplicate titles.\")"
215 | ]
216 | },
217 | {
218 | "cell_type": "markdown",
219 | "metadata": {
220 | "id": "sEBbarGSHNq5"
221 | },
222 | "source": [
223 | "Before proceeding further we first drop these entries. "
224 | ]
225 | },
226 | {
227 | "cell_type": "code",
228 | "execution_count": null,
229 | "metadata": {
230 | "colab": {
231 | "base_uri": "https://localhost:8080/"
232 | },
233 | "id": "2259X-rf6OLY",
234 | "outputId": "08d73f4b-f0a0-46bd-9744-01b0e22f6a92"
235 | },
236 | "outputs": [
237 | {
238 | "name": "stdout",
239 | "output_type": "stream",
240 | "text": [
241 | "There are 38972 rows in the deduplicated dataset.\n"
242 | ]
243 | }
244 | ],
245 | "source": [
246 | "arxiv_data = arxiv_data[~arxiv_data[\"titles\"].duplicated()]\n",
247 | "print(f\"There are {len(arxiv_data)} rows in the deduplicated dataset.\")"
248 | ]
249 | },
250 | {
251 | "cell_type": "code",
252 | "execution_count": null,
253 | "metadata": {
254 | "colab": {
255 | "base_uri": "https://localhost:8080/"
256 | },
257 | "id": "TmgkCCr2g0w5",
258 | "outputId": "18a6bd58-e511-405f-c19e-15c40c8955c9"
259 | },
260 | "outputs": [
261 | {
262 | "name": "stdout",
263 | "output_type": "stream",
264 | "text": [
265 | "2321\n"
266 | ]
267 | }
268 | ],
269 | "source": [
270 | "# There are some terms with occurrence as low as 1.\n",
271 | "print(sum(arxiv_data[\"terms\"].value_counts() == 1))"
272 | ]
273 | },
274 | {
275 | "cell_type": "code",
276 | "execution_count": null,
277 | "metadata": {
278 | "colab": {
279 | "base_uri": "https://localhost:8080/"
280 | },
281 | "id": "hyBJNHMdifJ-",
282 | "outputId": "0a098fa6-2723-42b9-b70b-f18bc9c11f3b"
283 | },
284 | "outputs": [
285 | {
286 | "name": "stdout",
287 | "output_type": "stream",
288 | "text": [
289 | "3157\n"
290 | ]
291 | }
292 | ],
293 | "source": [
294 | "# How many unique terms?\n",
295 | "print(arxiv_data[\"terms\"].nunique())"
296 | ]
297 | },
298 | {
299 | "cell_type": "markdown",
300 | "metadata": {
301 | "id": "dPvNGph0HNq7"
302 | },
303 | "source": [
304 | "As observed above, out of 3157 unique combinations of `terms`, 2321 entries have the lowest occurrence. To prepare our train, validation, and test sets with [stratification](https://en.wikipedia.org/wiki/Stratified_sampling), we need to drop these terms. "
305 | ]
306 | },
307 | {
308 | "cell_type": "code",
309 | "execution_count": null,
310 | "metadata": {
311 | "colab": {
312 | "base_uri": "https://localhost:8080/"
313 | },
314 | "id": "77ZoCzrMhLxc",
315 | "outputId": "ee113f1e-23e3-4fd8-8ecd-4231df22da22"
316 | },
317 | "outputs": [
318 | {
319 | "data": {
320 | "text/plain": [
321 | "(36651, 3)"
322 | ]
323 | },
324 | "execution_count": 8,
325 | "metadata": {},
326 | "output_type": "execute_result"
327 | }
328 | ],
329 | "source": [
330 | "# Filtering the rare terms.\n",
331 | "arxiv_data_filtered = arxiv_data.groupby(\"terms\").filter(lambda x: len(x) > 1)\n",
332 | "arxiv_data_filtered.shape"
333 | ]
334 | },
335 | {
336 | "cell_type": "markdown",
337 | "metadata": {
338 | "id": "MxrG9tim0QNr"
339 | },
340 | "source": [
341 | "## Convert the string labels to list of strings\n",
342 | "\n",
343 | "The initial labels are represented as raw strings. Here we make them `List[str]` for a more compact representation. "
344 | ]
345 | },
346 | {
347 | "cell_type": "code",
348 | "execution_count": null,
349 | "metadata": {
350 | "colab": {
351 | "base_uri": "https://localhost:8080/"
352 | },
353 | "id": "LIEGLc61iwbQ",
354 | "outputId": "e5612b8b-145c-45b7-f6ed-1273ec8bd538"
355 | },
356 | "outputs": [
357 | {
358 | "data": {
359 | "text/plain": [
360 | "array([list(['cs.CV', 'cs.LG']), list(['cs.CV', 'cs.AI', 'cs.LG']),\n",
361 | " list(['cs.CV', 'cs.AI']), list(['cs.CV']),\n",
362 | " list(['cs.CV', 'cs.LG'])], dtype=object)"
363 | ]
364 | },
365 | "execution_count": 9,
366 | "metadata": {},
367 | "output_type": "execute_result"
368 | }
369 | ],
370 | "source": [
371 | "arxiv_data_filtered[\"terms\"] = arxiv_data_filtered[\"terms\"].apply(\n",
372 | " lambda x: literal_eval(x)\n",
373 | ")\n",
374 | "arxiv_data_filtered[\"terms\"].values[:5]"
375 | ]
376 | },
377 | {
378 | "cell_type": "markdown",
379 | "metadata": {
380 | "id": "zjFB8Uoo0cXM"
381 | },
382 | "source": [
383 | "## Stratified splits because of class imbalance\n",
384 | "\n",
385 | "The dataset has a [class imbalance problem](https://developers.google.com/machine-learning/glossary/#class-imbalanced-dataset). So, to have a fair evaluation result, we need to ensure the datasets are sampled with stratification. To know more about different strategies to deal with the class imbalance problem, you can follow [this tutorial](https://www.tensorflow.org/tutorials/structured_data/imbalanced_data). For an end-to-end demonstration of classification with imbablanced data, refer to [Imbalanced classification: credit card fraud detection](https://keras.io/examples/structured_data/imbalanced_classification/)."
386 | ]
387 | },
388 | {
389 | "cell_type": "code",
390 | "execution_count": null,
391 | "metadata": {
392 | "colab": {
393 | "base_uri": "https://localhost:8080/"
394 | },
395 | "id": "EbKDVTKPgOKe",
396 | "outputId": "f16da845-9641-42dd-c387-649facc014a9"
397 | },
398 | "outputs": [
399 | {
400 | "name": "stdout",
401 | "output_type": "stream",
402 | "text": [
403 | "Number of rows in training set: 32985\n",
404 | "Number of rows in validation set: 1833\n",
405 | "Number of rows in test set: 1833\n"
406 | ]
407 | },
408 | {
409 | "name": "stderr",
410 | "output_type": "stream",
411 | "text": [
412 | "/usr/local/lib/python3.7/dist-packages/pandas/core/frame.py:4174: SettingWithCopyWarning: \n",
413 | "A value is trying to be set on a copy of a slice from a DataFrame\n",
414 | "\n",
415 | "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
416 | " errors=errors,\n"
417 | ]
418 | }
419 | ],
420 | "source": [
421 | "test_split = 0.1\n",
422 | "\n",
423 | "# Initial train and test split.\n",
424 | "train_df, test_df = train_test_split(\n",
425 | " arxiv_data_filtered,\n",
426 | " test_size=test_split,\n",
427 | " stratify=arxiv_data_filtered[\"terms\"].values,\n",
428 | ")\n",
429 | "\n",
430 | "# Splitting the test set further into validation\n",
431 | "# and new test sets.\n",
432 | "val_df = test_df.sample(frac=0.5)\n",
433 | "test_df.drop(val_df.index, inplace=True)\n",
434 | "\n",
435 | "print(f\"Number of rows in training set: {len(train_df)}\")\n",
436 | "print(f\"Number of rows in validation set: {len(val_df)}\")\n",
437 | "print(f\"Number of rows in test set: {len(test_df)}\")"
438 | ]
439 | },
440 | {
441 | "cell_type": "markdown",
442 | "metadata": {
443 | "id": "96Ew2PPI0lVc"
444 | },
445 | "source": [
446 | "## Multi-label binarization\n",
447 | "\n",
448 | "Now we preprocess our labels using [`MultiLabelBinarizer`](http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MultiLabelBinarizer.html). "
449 | ]
450 | },
451 | {
452 | "cell_type": "code",
453 | "execution_count": null,
454 | "metadata": {
455 | "colab": {
456 | "base_uri": "https://localhost:8080/"
457 | },
458 | "id": "1vgxbdwGf07E",
459 | "outputId": "1786e58c-9f56-4c14-f9cb-1f4da7fdad85"
460 | },
461 | "outputs": [
462 | {
463 | "data": {
464 | "text/plain": [
465 | "array(['14J60 (Primary) 14F05, 14J26 (Secondary)', '62H30', '62H35',\n",
466 | " '62H99', '65D19', '68', '68Q32', '68T01', '68T05', '68T07',\n",
467 | " '68T10', '68T30', '68T45', '68T99', '68Txx', '68U01', '68U10',\n",
468 | " 'E.5; E.4; E.2; H.1.1; F.1.1; F.1.3', 'F.2.2; I.2.7', 'G.3',\n",
469 | " 'H.3.1; H.3.3; I.2.6; I.2.7', 'H.3.1; I.2.6; I.2.7', 'I.2',\n",
470 | " 'I.2.0; I.2.6', 'I.2.1', 'I.2.10', 'I.2.10; I.2.6',\n",
471 | " 'I.2.10; I.4.8', 'I.2.10; I.4.8; I.5.4', 'I.2.10; I.4; I.5',\n",
472 | " 'I.2.10; I.5.1; I.4.8', 'I.2.1; J.3', 'I.2.6', 'I.2.6, I.5.4',\n",
473 | " 'I.2.6; I.2.10', 'I.2.6; I.2.7', 'I.2.6; I.2.7; H.3.1; H.3.3',\n",
474 | " 'I.2.6; I.2.8', 'I.2.6; I.2.9', 'I.2.6; I.5.1', 'I.2.6; I.5.4',\n",
475 | " 'I.2.7', 'I.2.8', 'I.2; I.2.6; I.2.7', 'I.2; I.4; I.5', 'I.2; I.5',\n",
476 | " 'I.2; J.2', 'I.4', 'I.4.0', 'I.4.3', 'I.4.4', 'I.4.5', 'I.4.6',\n",
477 | " 'I.4.6; I.4.8', 'I.4.8', 'I.4.9', 'I.4.9; I.5.4', 'I.4; I.5',\n",
478 | " 'I.5.4', 'K.3.2', 'astro-ph.IM', 'cond-mat.dis-nn',\n",
479 | " 'cond-mat.mtrl-sci', 'cond-mat.soft', 'cond-mat.stat-mech',\n",
480 | " 'cs.AI', 'cs.AR', 'cs.CC', 'cs.CE', 'cs.CG', 'cs.CL', 'cs.CR',\n",
481 | " 'cs.CV', 'cs.CY', 'cs.DB', 'cs.DC', 'cs.DM', 'cs.DS', 'cs.ET',\n",
482 | " 'cs.FL', 'cs.GR', 'cs.GT', 'cs.HC', 'cs.IR', 'cs.IT', 'cs.LG',\n",
483 | " 'cs.LO', 'cs.MA', 'cs.MM', 'cs.MS', 'cs.NA', 'cs.NE', 'cs.NI',\n",
484 | " 'cs.PF', 'cs.PL', 'cs.RO', 'cs.SC', 'cs.SD', 'cs.SE', 'cs.SI',\n",
485 | " 'cs.SY', 'econ.EM', 'econ.GN', 'eess.AS', 'eess.IV', 'eess.SP',\n",
486 | " 'eess.SY', 'hep-ex', 'hep-ph', 'math.AP', 'math.AT', 'math.CO',\n",
487 | " 'math.DS', 'math.FA', 'math.IT', 'math.LO', 'math.NA', 'math.OC',\n",
488 | " 'math.PR', 'math.ST', 'nlin.AO', 'nlin.CD', 'physics.ao-ph',\n",
489 | " 'physics.bio-ph', 'physics.chem-ph', 'physics.comp-ph',\n",
490 | " 'physics.data-an', 'physics.flu-dyn', 'physics.geo-ph',\n",
491 | " 'physics.med-ph', 'physics.optics', 'physics.soc-ph', 'q-bio.BM',\n",
492 | " 'q-bio.GN', 'q-bio.MN', 'q-bio.NC', 'q-bio.OT', 'q-bio.QM',\n",
493 | " 'q-bio.TO', 'q-fin.CP', 'q-fin.EC', 'q-fin.GN', 'q-fin.PM',\n",
494 | " 'q-fin.RM', 'q-fin.ST', 'q-fin.TR', 'quant-ph', 'stat.AP',\n",
495 | " 'stat.CO', 'stat.ME', 'stat.ML', 'stat.TH'], dtype=object)"
496 | ]
497 | },
498 | "execution_count": 11,
499 | "metadata": {},
500 | "output_type": "execute_result"
501 | }
502 | ],
503 | "source": [
504 | "mlb = MultiLabelBinarizer()\n",
505 | "mlb.fit_transform(train_df[\"terms\"])\n",
506 | "mlb.classes_"
507 | ]
508 | },
509 | {
510 | "cell_type": "markdown",
511 | "metadata": {
512 | "id": "HTcmRnP9HNq9"
513 | },
514 | "source": [
515 | "`MultiLabelBinarizer`separates out the individual unique classes available from the label pool and then uses this information to represent a given label set with 0's and 1's. Below is an example. "
516 | ]
517 | },
518 | {
519 | "cell_type": "code",
520 | "execution_count": null,
521 | "metadata": {
522 | "colab": {
523 | "base_uri": "https://localhost:8080/"
524 | },
525 | "id": "yfMO31s8HNq9",
526 | "outputId": "7a6b1cb9-761b-409c-fbc5-9534f5168629"
527 | },
528 | "outputs": [
529 | {
530 | "name": "stdout",
531 | "output_type": "stream",
532 | "text": [
533 | "Original label: ['cs.LG', 'cs.AI', 'stat.ML']\n",
534 | "Label-binarized representation: [[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
535 | " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0\n",
536 | " 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
537 | " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
538 | " 0 0 0 0 0 0 1 0]]\n"
539 | ]
540 | }
541 | ],
542 | "source": [
543 | "sample_label = train_df[\"terms\"].iloc[0]\n",
544 | "print(f\"Original label: {sample_label}\")\n",
545 | "\n",
546 | "label_binarized = mlb.transform([sample_label])\n",
547 | "print(f\"Label-binarized representation: {label_binarized}\")"
548 | ]
549 | },
550 | {
551 | "cell_type": "markdown",
552 | "metadata": {
553 | "id": "a2kFVBCG0oXV"
554 | },
555 | "source": [
556 | "## Data preprocessing and `tf.data.Dataset` objects\n",
557 | "\n",
558 | "We first get percentile estimates of the sequence lengths. The purpose will be clear in a moment."
559 | ]
560 | },
561 | {
562 | "cell_type": "code",
563 | "execution_count": null,
564 | "metadata": {
565 | "colab": {
566 | "base_uri": "https://localhost:8080/"
567 | },
568 | "id": "kCR-_Iw3gyT-",
569 | "outputId": "1e00b7b1-b70a-4f2f-f97a-4defa37d3646"
570 | },
571 | "outputs": [
572 | {
573 | "data": {
574 | "text/plain": [
575 | "count 32985.000000\n",
576 | "mean 156.629650\n",
577 | "std 41.521087\n",
578 | "min 5.000000\n",
579 | "25% 128.000000\n",
580 | "50% 155.000000\n",
581 | "75% 183.000000\n",
582 | "max 462.000000\n",
583 | "Name: summaries, dtype: float64"
584 | ]
585 | },
586 | "execution_count": 13,
587 | "metadata": {},
588 | "output_type": "execute_result"
589 | }
590 | ],
591 | "source": [
592 | "train_df[\"summaries\"].apply(lambda x: len(x.split(\" \"))).describe()"
593 | ]
594 | },
595 | {
596 | "cell_type": "markdown",
597 | "metadata": {
598 | "id": "2rdivJop02xG"
599 | },
600 | "source": [
601 | "Notice that 50% of the abstracts have a length of 154 (you may get a different number based on the split). So, any number near that is a good enough approximate for the maximum sequence length. \n",
602 | "\n",
603 | "Now, we write utilities to prepare our datasets that would go straight to the text classifier model. "
604 | ]
605 | },
606 | {
607 | "cell_type": "code",
608 | "execution_count": null,
609 | "metadata": {
610 | "id": "QoMgEZrtVBDS"
611 | },
612 | "outputs": [],
613 | "source": [
614 | "max_seqlen = 150\n",
615 | "batch_size = 128\n",
616 | "padding_token = \"\"\n",
617 | "\n",
618 | "\n",
619 | "def unify_text_length(text, label):\n",
620 | " # Split the given abstract and calculate its length.\n",
621 | " word_splits = tf.strings.split(text, sep=\" \")\n",
622 | " sequence_length = tf.shape(word_splits)[0]\n",
623 | " \n",
624 | " # Calculate the padding amount.\n",
625 | " padding_amount = max_seqlen - sequence_length\n",
626 | " \n",
627 | " # Check if we need to pad or truncate.\n",
628 | " if padding_amount > 0:\n",
629 | " unified_text = tf.pad([text], [[0, padding_amount]], constant_values=\"\")\n",
630 | " unified_text = tf.strings.reduce_join(unified_text, separator=\"\")\n",
631 | " else:\n",
632 | " unified_text = tf.strings.reduce_join(word_splits[:max_seqlen], separator=\" \")\n",
633 | " \n",
634 | " # The expansion is needed for subsequent vectorization.\n",
635 | " return tf.expand_dims(unified_text, -1), label\n",
636 | "\n",
637 | "\n",
638 | "def make_dataset(dataframe, is_train=True):\n",
639 | " label_binarized = mlb.transform(dataframe[\"terms\"].values)\n",
640 | " dataset = tf.data.Dataset.from_tensor_slices(\n",
641 | " (dataframe[\"summaries\"].values, label_binarized)\n",
642 | " )\n",
643 | " dataset = dataset.shuffle(batch_size * 10) if is_train else dataset\n",
644 | " dataset = dataset.map(unify_text_length).cache()\n",
645 | " return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)"
646 | ]
647 | },
648 | {
649 | "cell_type": "markdown",
650 | "metadata": {
651 | "id": "fi5g9v3HHNq-"
652 | },
653 | "source": [
654 | "Now we can prepare the `tf.data.Dataset` objects. "
655 | ]
656 | },
657 | {
658 | "cell_type": "code",
659 | "execution_count": null,
660 | "metadata": {
661 | "id": "F_vrgkCXrWOS"
662 | },
663 | "outputs": [],
664 | "source": [
665 | "train_dataset = make_dataset(train_df, is_train=True)\n",
666 | "validation_dataset = make_dataset(val_df, is_train=False)\n",
667 | "test_dataset = make_dataset(test_df, is_train=False)"
668 | ]
669 | },
670 | {
671 | "cell_type": "markdown",
672 | "metadata": {
673 | "id": "-1Bb4Xnm1EwK"
674 | },
675 | "source": [
676 | "## Dataset preview"
677 | ]
678 | },
679 | {
680 | "cell_type": "code",
681 | "execution_count": null,
682 | "metadata": {
683 | "colab": {
684 | "base_uri": "https://localhost:8080/"
685 | },
686 | "id": "w-8k7gScVoz6",
687 | "outputId": "8ea5d01c-e0ac-4e8b-b15b-b4b1975cd979"
688 | },
689 | "outputs": [
690 | {
691 | "name": "stdout",
692 | "output_type": "stream",
693 | "text": [
694 | "Abstract: b'A highly desirable property of a reinforcement learning (RL) agent -- and a\\nmajor difficulty for deep RL approaches -- is the ability to generalize\\npolicies learned on a few tasks over a high-dimensional observation space to\\nsimilar tasks not seen during training. Many promising approaches to this\\nchallenge consider RL as a process of training two functions simultaneously: a\\ncomplex nonlinear encoder that maps high-dimensional observations to a latent\\nrepresentation space, and a simple linear policy over this space. We posit that\\na superior encoder for zero-shot generalization in RL can be trained by using\\nsolely an auxiliary SSL objective if the training process encourages the\\nencoder to map behaviorally similar observations to similar representations, as\\nreward-based signal can cause overfitting in the encoder (Raileanu et al.,\\n2021). We propose Cross-Trajectory Representation Learning (CTRL), a method\\nthat runs within an RL agent and conditions its encoder to recognize behavioral\\nsimilarity in observations by applying a novel SSL objective to pairs of\\ntrajectories from'\n",
695 | "Label(s): ('cs.AI', 'cs.LG')\n",
696 | " \n",
697 | "Abstract: b'We introduce AndroidEnv, an open-source platform for Reinforcement Learning\\n(RL) research built on top of the Android ecosystem. AndroidEnv allows RL\\nagents to interact with a wide variety of apps and services commonly used by\\nhumans through a universal touchscreen interface. Since agents train on a\\nrealistic simulation of an Android device, they have the potential to be\\ndeployed on real devices. In this report, we give an overview of the\\nenvironment, highlighting the significant features it provides for research,\\nand we present an empirical evaluation of some popular reinforcement learning\\nagents on a set of tasks built on this platform.'\n",
698 | "Label(s): ('cs.AI', 'cs.LG')\n",
699 | " \n",
700 | "Abstract: b'Learning to imitate expert behavior from demonstrations can be challenging,\\nespecially in environments with high-dimensional, continuous observations and\\nunknown dynamics. Supervised learning methods based on behavioral cloning (BC)\\nsuffer from distribution shift: because the agent greedily imitates\\ndemonstrated actions, it can drift away from demonstrated states due to error\\naccumulation. Recent methods based on reinforcement learning (RL), such as\\ninverse RL and generative adversarial imitation learning (GAIL), overcome this\\nissue by training an RL agent to match the demonstrations over a long horizon.\\nSince the true reward function for the task is unknown, these methods learn a\\nreward function from the demonstrations, often using complex and brittle\\napproximation techniques that involve adversarial training. We propose a simple\\nalternative that still uses RL, but does not require learning a reward\\nfunction. The key idea is to provide the agent with an incentive to match the\\ndemonstrations over a long horizon, by encouraging it to return to demonstrated\\nstates upon encountering new, out-of-distribution states. We accomplish'\n",
701 | "Label(s): ('cs.LG', 'stat.ML')\n",
702 | " \n",
703 | "Abstract: b'Knowledge graph completion is an important task that aims to predict the\\nmissing relational link between entities. Knowledge graph embedding methods\\nperform this task by representing entities and relations as embedding vectors\\nand modeling their interactions to compute the matching score of each triple.\\nPrevious work has usually treated each embedding as a whole and has modeled the\\ninteractions between these whole embeddings, potentially making the model\\nexcessively expensive or requiring specially designed interaction mechanisms.\\nIn this work, we propose the multi-partition embedding interaction (MEI) model\\nwith block term format to systematically address this problem. MEI divides each\\nembedding into a multi-partition vector to efficiently restrict the\\ninteractions. Each local interaction is modeled with the Tucker tensor format\\nand the full interaction is modeled with the block term tensor format, enabling\\nMEI to control the trade-off between expressiveness and computational cost,\\nlearn the interaction mechanisms from data automatically, and achieve\\nstate-of-the-art performance on the link prediction task. In addition, we\\ntheoretically study the parameter efficiency'\n",
704 | "Label(s): ('cs.AI', 'cs.CL', 'cs.LG', 'stat.ML')\n",
705 | " \n",
706 | "Abstract: b'Quantization has been proven to be an effective method for reducing the\\ncomputing and/or storage cost of DNNs. However, the trade-off between the\\nquantization bitwidth and final accuracy is complex and non-convex, which makes\\nit difficult to be optimized directly. Minimizing direct quantization loss\\n(DQL) of the coefficient data is an effective local optimization method, but\\nprevious works often neglect the accurate control of the DQL, resulting in a\\nhigher loss of the final DNN model accuracy. In this paper, we propose a novel\\nmetric called Vector Loss. Based on this new metric, we develop a new\\nquantization solution called VecQ, which can guarantee minimal direct\\nquantization loss and better model accuracy. In addition, in order to speed up\\nthe proposed quantization process during model training, we accelerate the\\nquantization process with a parameterized probability estimation method and\\ntemplate-based derivation calculation. We evaluate our proposed algorithm on\\nMNIST, CIFAR, ImageNet, IMDB movie review and THUCNews text data sets with\\nnumerical DNN models. The results'\n",
707 | "Label(s): ('cs.CV',)\n",
708 | " \n"
709 | ]
710 | }
711 | ],
712 | "source": [
713 | "text_batch, label_batch = next(iter(train_dataset))\n",
714 | "\n",
715 | "for i, text in enumerate(text_batch[:5]):\n",
716 | " label = label_batch[i].numpy()[None, ...]\n",
717 | " print(f\"Abstract: {text[0]}\")\n",
718 | " print(f\"Label(s): {mlb.inverse_transform(label)[0]}\")\n",
719 | " print(\" \")"
720 | ]
721 | },
722 | {
723 | "cell_type": "markdown",
724 | "metadata": {
725 | "id": "Yfbnoi-y1IoB"
726 | },
727 | "source": [
728 | "## Vocabulary size for vectorization\n",
729 | "\n",
730 | "Before we feed the data to our model we need to represent them as numbers. For that purpose, we will use the [`TextVectorization` layer](https://keras.io/api/layers/preprocessing_layers/text/text_vectorization). It can operate as a part of your main model so that the model is excluded from the core preprocessing logic. This greatly reduces the chances of training and serving skew. \n",
731 | "\n",
732 | "We first calculate the number of unique words present in the abstracts."
733 | ]
734 | },
735 | {
736 | "cell_type": "code",
737 | "execution_count": null,
738 | "metadata": {
739 | "colab": {
740 | "base_uri": "https://localhost:8080/"
741 | },
742 | "id": "IIW42KOgyfE7",
743 | "outputId": "69e504dc-a55e-41ec-e2b9-11b35750d9ca"
744 | },
745 | "outputs": [
746 | {
747 | "name": "stdout",
748 | "output_type": "stream",
749 | "text": [
750 | "Vocabulary size: 498\n"
751 | ]
752 | },
753 | {
754 | "name": "stderr",
755 | "output_type": "stream",
756 | "text": [
757 | "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:1: SettingWithCopyWarning: \n",
758 | "A value is trying to be set on a copy of a slice from a DataFrame.\n",
759 | "Try using .loc[row_indexer,col_indexer] = value instead\n",
760 | "\n",
761 | "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
762 | " \"\"\"Entry point for launching an IPython kernel.\n"
763 | ]
764 | }
765 | ],
766 | "source": [
767 | "train_df[\"total_words\"] = train_df[\"summaries\"].str.split().str.len()\n",
768 | "vocabulary_size = train_df[\"total_words\"].max()\n",
769 | "print(f\"Vocabulary size: {vocabulary_size}\")"
770 | ]
771 | },
772 | {
773 | "cell_type": "markdown",
774 | "metadata": {
775 | "id": "nis2d-xFHNq_"
776 | },
777 | "source": [
778 | "Now we can create our text classifier model with the `TextVectorization` layer present inside it. "
779 | ]
780 | },
781 | {
782 | "cell_type": "markdown",
783 | "metadata": {
784 | "id": "3TBUFHm11M3G"
785 | },
786 | "source": [
787 | "## Create model with `TextVectorization`\n",
788 | "\n",
789 | "A batch of raw text will first go through the `TextVectorization` layer and it will generate their integer representations. Internally, the `TextVectorization` layer will first create bi-grams out of the sequences and then represent them using [TF-IDF](https://wikipedia.org/wiki/Tf%E2%80%93idf). The output representations will then be passed to the shallow model responsible for text classification. \n",
790 | "\n",
791 | "To know more about other possible configurations with `TextVectorizer`, please consult the [official documentation](https://keras.io/api/layers/preprocessing_layers/text/text_vectorization). "
792 | ]
793 | },
794 | {
795 | "cell_type": "code",
796 | "execution_count": null,
797 | "metadata": {
798 | "id": "4XX7ovyPokNs"
799 | },
800 | "outputs": [],
801 | "source": [
802 | "text_vectorizer = layers.TextVectorization(\n",
803 | " max_tokens=vocabulary_size, ngrams=2, output_mode=\"tf_idf\"\n",
804 | ")\n",
805 | "\n",
806 | "# `TextVectorization` needs to be adapted as per the vocabulary from our\n",
807 | "# training set.\n",
808 | "with tf.device(\"/CPU:0\"):\n",
809 | " text_vectorizer.adapt(train_dataset.map(lambda text, label: text))\n",
810 | "\n",
811 | "\n",
812 | "def make_model():\n",
813 | " shallow_mlp_model = keras.Sequential(\n",
814 | " [\n",
815 | " text_vectorizer,\n",
816 | " layers.Dense(512, activation=\"relu\"),\n",
817 | " layers.Dense(256, activation=\"relu\"),\n",
818 | " layers.Dense(len(mlb.classes_), activation=\"sigmoid\"),\n",
819 | " ]\n",
820 | " )\n",
821 | " return shallow_mlp_model"
822 | ]
823 | },
824 | {
825 | "cell_type": "markdown",
826 | "metadata": {
827 | "id": "i7i71RLtugUs"
828 | },
829 | "source": [
830 | "Without the CPU placement, we run into: \n",
831 | "\n",
832 | "```\n",
833 | "(1) Invalid argument: During Variant Host->Device Copy: non-DMA-copy attempted of tensor type: string\n",
834 | "```"
835 | ]
836 | },
837 | {
838 | "cell_type": "code",
839 | "execution_count": null,
840 | "metadata": {
841 | "colab": {
842 | "base_uri": "https://localhost:8080/"
843 | },
844 | "id": "58tUdCsQuI1P",
845 | "outputId": "47a81fa0-d83c-4dfe-8745-8122ebbc4da9"
846 | },
847 | "outputs": [
848 | {
849 | "name": "stdout",
850 | "output_type": "stream",
851 | "text": [
852 | "Model: \"sequential\"\n",
853 | "_________________________________________________________________\n",
854 | "Layer (type) Output Shape Param # \n",
855 | "=================================================================\n",
856 | "text_vectorization (TextVect (None, 498) 1 \n",
857 | "_________________________________________________________________\n",
858 | "dense (Dense) (None, 512) 255488 \n",
859 | "_________________________________________________________________\n",
860 | "dense_1 (Dense) (None, 256) 131328 \n",
861 | "_________________________________________________________________\n",
862 | "dense_2 (Dense) (None, 152) 39064 \n",
863 | "=================================================================\n",
864 | "Total params: 425,881\n",
865 | "Trainable params: 425,880\n",
866 | "Non-trainable params: 1\n",
867 | "_________________________________________________________________\n"
868 | ]
869 | }
870 | ],
871 | "source": [
872 | "shallow_mlp_model = make_model()\n",
873 | "shallow_mlp_model.summary()"
874 | ]
875 | },
876 | {
877 | "cell_type": "markdown",
878 | "metadata": {
879 | "id": "y1Hr9D0O1Tw0"
880 | },
881 | "source": [
882 | "## Train the model\n",
883 | "\n",
884 | "We will train our model using the binary cross-entropy loss. This is because the labels are not disjoint. For a given abstract, we may have multiple categories. So, we will divide the prediction task into a series of multiple binary classification problems. This is also why we kept the activation function of the classification layer in our model to sigmoid. Researchers have used other combinations of loss function and activation function as well. For example, in [Exploring the Limits of Weakly Supervised Pretraining](https://arxiv.org/abs/1805.00932), Mahajan et al. used the softmax activation function and cross-entropy loss to train their models. "
885 | ]
886 | },
887 | {
888 | "cell_type": "code",
889 | "execution_count": null,
890 | "metadata": {
891 | "colab": {
892 | "base_uri": "https://localhost:8080/"
893 | },
894 | "id": "WCoaRkA-wsC3",
895 | "outputId": "d8f31c73-6f8b-43db-e226-c6fd55e446df"
896 | },
897 | "outputs": [
898 | {
899 | "name": "stdout",
900 | "output_type": "stream",
901 | "text": [
902 | "Epoch 1/20\n",
903 | "258/258 [==============================] - 4s 12ms/step - loss: 0.0778 - categorical_accuracy: 0.6148 - val_loss: 0.0227 - val_categorical_accuracy: 0.6519\n",
904 | "Epoch 2/20\n",
905 | "258/258 [==============================] - 3s 10ms/step - loss: 0.0229 - categorical_accuracy: 0.6810 - val_loss: 0.0215 - val_categorical_accuracy: 0.6759\n",
906 | "Epoch 3/20\n",
907 | "258/258 [==============================] - 3s 11ms/step - loss: 0.0218 - categorical_accuracy: 0.6855 - val_loss: 0.0207 - val_categorical_accuracy: 0.6830\n",
908 | "Epoch 4/20\n",
909 | "258/258 [==============================] - 3s 11ms/step - loss: 0.0209 - categorical_accuracy: 0.6904 - val_loss: 0.0205 - val_categorical_accuracy: 0.6994\n",
910 | "Epoch 5/20\n",
911 | "258/258 [==============================] - 3s 11ms/step - loss: 0.0203 - categorical_accuracy: 0.6934 - val_loss: 0.0200 - val_categorical_accuracy: 0.6945\n",
912 | "Epoch 6/20\n",
913 | "258/258 [==============================] - 3s 11ms/step - loss: 0.0197 - categorical_accuracy: 0.6949 - val_loss: 0.0198 - val_categorical_accuracy: 0.6972\n",
914 | "Epoch 7/20\n",
915 | "258/258 [==============================] - 3s 11ms/step - loss: 0.0193 - categorical_accuracy: 0.6960 - val_loss: 0.0197 - val_categorical_accuracy: 0.6972\n",
916 | "Epoch 8/20\n",
917 | "258/258 [==============================] - 3s 11ms/step - loss: 0.0189 - categorical_accuracy: 0.6987 - val_loss: 0.0195 - val_categorical_accuracy: 0.6929\n",
918 | "Epoch 9/20\n",
919 | "258/258 [==============================] - 3s 11ms/step - loss: 0.0185 - categorical_accuracy: 0.6991 - val_loss: 0.0195 - val_categorical_accuracy: 0.6918\n",
920 | "Epoch 10/20\n",
921 | "258/258 [==============================] - 3s 11ms/step - loss: 0.0182 - categorical_accuracy: 0.6992 - val_loss: 0.0196 - val_categorical_accuracy: 0.6945\n",
922 | "Epoch 11/20\n",
923 | "258/258 [==============================] - 3s 11ms/step - loss: 0.0178 - categorical_accuracy: 0.7002 - val_loss: 0.0195 - val_categorical_accuracy: 0.6939\n",
924 | "Epoch 12/20\n",
925 | "258/258 [==============================] - 3s 10ms/step - loss: 0.0175 - categorical_accuracy: 0.7024 - val_loss: 0.0196 - val_categorical_accuracy: 0.6956\n",
926 | "Epoch 13/20\n",
927 | "258/258 [==============================] - 3s 11ms/step - loss: 0.0172 - categorical_accuracy: 0.7048 - val_loss: 0.0197 - val_categorical_accuracy: 0.6978\n",
928 | "Epoch 14/20\n",
929 | "258/258 [==============================] - 3s 11ms/step - loss: 0.0168 - categorical_accuracy: 0.7065 - val_loss: 0.0200 - val_categorical_accuracy: 0.6994\n",
930 | "Epoch 15/20\n",
931 | "258/258 [==============================] - 3s 11ms/step - loss: 0.0166 - categorical_accuracy: 0.7082 - val_loss: 0.0201 - val_categorical_accuracy: 0.7005\n",
932 | "Epoch 16/20\n",
933 | "258/258 [==============================] - 3s 11ms/step - loss: 0.0163 - categorical_accuracy: 0.7104 - val_loss: 0.0201 - val_categorical_accuracy: 0.6901\n",
934 | "Epoch 17/20\n",
935 | "258/258 [==============================] - 3s 11ms/step - loss: 0.0160 - categorical_accuracy: 0.7111 - val_loss: 0.0202 - val_categorical_accuracy: 0.6863\n",
936 | "Epoch 18/20\n",
937 | "258/258 [==============================] - 3s 11ms/step - loss: 0.0157 - categorical_accuracy: 0.7117 - val_loss: 0.0202 - val_categorical_accuracy: 0.6879\n",
938 | "Epoch 19/20\n",
939 | "258/258 [==============================] - 3s 11ms/step - loss: 0.0154 - categorical_accuracy: 0.7127 - val_loss: 0.0205 - val_categorical_accuracy: 0.6863\n",
940 | "Epoch 20/20\n",
941 | "258/258 [==============================] - 3s 10ms/step - loss: 0.0150 - categorical_accuracy: 0.7138 - val_loss: 0.0206 - val_categorical_accuracy: 0.6836\n"
942 | ]
943 | }
944 | ],
945 | "source": [
946 | "epochs = 20\n",
947 | "\n",
948 | "shallow_mlp_model.compile(\n",
949 | " loss=\"binary_crossentropy\", optimizer=\"adam\", metrics=[\"categorical_accuracy\"]\n",
950 | ")\n",
951 | "\n",
952 | "history = shallow_mlp_model.fit(train_dataset, validation_data=validation_dataset, epochs=epochs)"
953 | ]
954 | },
955 | {
956 | "cell_type": "code",
957 | "execution_count": null,
958 | "metadata": {
959 | "colab": {
960 | "base_uri": "https://localhost:8080/",
961 | "height": 575
962 | },
963 | "id": "bOG4RF0KxUpq",
964 | "outputId": "4b809603-f59c-4c4b-9f89-2d03cc42ee4a"
965 | },
966 | "outputs": [
967 | {
968 | "data": {
969 | "image/png": "\n",
970 | "text/plain": [
971 | ""
972 | ]
973 | },
974 | "metadata": {
975 | "needs_background": "light"
976 | },
977 | "output_type": "display_data"
978 | },
979 | {
980 | "data": {
981 | "image/png": "\n",
982 | "text/plain": [
983 | ""
984 | ]
985 | },
986 | "metadata": {
987 | "needs_background": "light"
988 | },
989 | "output_type": "display_data"
990 | }
991 | ],
992 | "source": [
993 | "def plot_result(item):\n",
994 | " plt.plot(history.history[item], label=item)\n",
995 | " plt.plot(history.history[\"val_\" + item], label=\"val_\" + item)\n",
996 | " plt.xlabel(\"Epochs\")\n",
997 | " plt.ylabel(item)\n",
998 | " plt.title(\"Train and Validation {} Over Epochs\".format(item), fontsize=14)\n",
999 | " plt.legend()\n",
1000 | " plt.grid()\n",
1001 | " plt.show()\n",
1002 | "\n",
1003 | "\n",
1004 | "plot_result(\"loss\")\n",
1005 | "plot_result(\"categorical_accuracy\")"
1006 | ]
1007 | },
1008 | {
1009 | "cell_type": "markdown",
1010 | "metadata": {
1011 | "id": "XkwRAIcsNeXj"
1012 | },
1013 | "source": [
1014 | "While training, we notice an initial sharp fall in the loss followed by a gradual decay."
1015 | ]
1016 | },
1017 | {
1018 | "cell_type": "markdown",
1019 | "metadata": {
1020 | "id": "wBvqeuk88G9r"
1021 | },
1022 | "source": [
1023 | "### Evaluate the model"
1024 | ]
1025 | },
1026 | {
1027 | "cell_type": "code",
1028 | "execution_count": null,
1029 | "metadata": {
1030 | "colab": {
1031 | "base_uri": "https://localhost:8080/"
1032 | },
1033 | "id": "sxz8yDaT8MdL",
1034 | "outputId": "04cbcc20-9c7d-417a-fc0f-f67236533711"
1035 | },
1036 | "outputs": [
1037 | {
1038 | "name": "stdout",
1039 | "output_type": "stream",
1040 | "text": [
1041 | "15/15 [==============================] - 0s 23ms/step - loss: 0.0200 - categorical_accuracy: 0.6978\n",
1042 | "Categorical accuracy on the test set: 69.78%.\n"
1043 | ]
1044 | }
1045 | ],
1046 | "source": [
1047 | "_, categorical_acc = shallow_mlp_model.evaluate(test_dataset)\n",
1048 | "print(f\"Categorical accuracy on the test set: {round(categorical_acc * 100, 2)}%.\")"
1049 | ]
1050 | },
1051 | {
1052 | "cell_type": "markdown",
1053 | "metadata": {
1054 | "id": "of0EtUhKNjv6"
1055 | },
1056 | "source": [
1057 | "The trained model gives us a validation accuracy of ~70%."
1058 | ]
1059 | },
1060 | {
1061 | "cell_type": "markdown",
1062 | "metadata": {
1063 | "id": "SeCfK4daFGJq"
1064 | },
1065 | "source": [
1066 | "## Inference"
1067 | ]
1068 | },
1069 | {
1070 | "cell_type": "code",
1071 | "execution_count": 69,
1072 | "metadata": {
1073 | "colab": {
1074 | "base_uri": "https://localhost:8080/"
1075 | },
1076 | "id": "axBPAMpe-sp2",
1077 | "outputId": "3ba99084-0491-47ce-dce4-f74b69708e7e"
1078 | },
1079 | "outputs": [
1080 | {
1081 | "name": "stdout",
1082 | "output_type": "stream",
1083 | "text": [
1084 | "Abstract: b'Parametric causal modelling techniques rarely provide functionality for\\ncounterfactual estimation, often at the expense of modelling complexity. Since\\ncausal estimations depend on the family of functions used to model the data,\\nsimplistic models could entail imprecise characterizations of the generative\\nmechanism, and, consequently, unreliable results. This limits their\\napplicability to real-life datasets, with non-linear relationships and high\\ninteraction between variables. We propose Deep Causal Graphs, an abstract\\nspecification of the required functionality for a neural network to model\\ncausal distributions, and provide a model that satisfies this contract:\\nNormalizing Causal Flows. We demonstrate its expressive power in modelling\\ncomplex interactions and showcase applications of the method to machine\\nlearning explainability and fairness, using true causal counterfactuals.'\n",
1085 | "Label(s): ('cs.LG', 'cs.NE', 'stat.ML')\n",
1086 | "Predicted Label(s): (cs.LG, stat.ML, physics.data-an)\n",
1087 | " \n",
1088 | "Abstract: b\"Deep learning inference on embedded devices is a burgeoning field with myriad\\napplications because tiny embedded devices are omnipresent. But we must\\novercome major challenges before we can benefit from this opportunity. Embedded\\nprocessors are severely resource constrained. Their nearest mobile counterparts\\nexhibit at least a 100 -- 1,000x difference in compute capability, memory\\navailability, and power consumption. As a result, the machine-learning (ML)\\nmodels and associated ML inference framework must not only execute efficiently\\nbut also operate in a few kilobytes of memory. Also, the embedded devices'\\necosystem is heavily fragmented. To maximize efficiency, system vendors often\\nomit many features that commonly appear in mainstream systems, including\\ndynamic memory allocation and virtual memory, that allow for cross-platform\\ninteroperability. The hardware comes in many flavors (e.g., instruction-set\\narchitecture and FPU support, or lack thereof). We introduce TensorFlow Lite\\nMicro (TF Micro), an open-source ML inference framework for running\\ndeep-learning models on embedded systems. TF Micro tackles the efficiency\\nrequirements imposed by embedded-system resource constraints and\"\n",
1089 | "Label(s): ('cs.AI', 'cs.LG')\n",
1090 | "Predicted Label(s): (cs.LG, stat.ML, cs.AI)\n",
1091 | " \n",
1092 | "Abstract: b'Lidar odometry (LO) is a key technology in numerous reliable and accurate\\nlocalization and mapping systems of autonomous driving. The state-of-the-art LO\\nmethods generally leverage geometric information to perform point cloud\\nregistration. Furthermore, obtaining point cloud semantic information which can\\ndescribe the environment more abundantly will help for the registration. We\\npresent a novel semantic lidar odometry method based on self-designed\\nparameterized semantic features (PSFs) to achieve low-drift ego-motion\\nestimation for autonomous vehicle in realtime. We first use a convolutional\\nneural network-based algorithm to obtain point-wise semantics from the input\\nlaser point cloud, and then use semantic labels to separate the road, building,\\ntraffic sign and pole-like point cloud and fit them separately to obtain\\ncorresponding PSFs. A fast PSF-based matching enable us to refine geometric\\nfeatures (GeFs) registration, reducing the impact of blurred submap surface on\\nthe accuracy of GeFs matching. Besides, we design an efficient method to\\naccurately recognize and remove the dynamic objects while retaining static ones\\nin the semantic point cloud,'\n",
1093 | "Label(s): ('cs.CV',)\n",
1094 | "Predicted Label(s): (cs.CV, cs.RO, cs.GR)\n",
1095 | " \n",
1096 | "Abstract: b'Deep neural networks have been shown to suffer from poor generalization when\\nsmall perturbations are added (like Gaussian noise), yet little work has been\\ndone to evaluate their robustness to more natural image transformations like\\nphoto filters. This paper presents a study on how popular pretrained models are\\naffected by commonly used Instagram filters. To this end, we introduce\\nImageNet-Instagram, a filtered version of ImageNet, where 20 popular Instagram\\nfilters are applied to each image in ImageNet. Our analysis suggests that\\nsimple structure preserving filters which only alter the global appearance of\\nan image can lead to large differences in the convolutional feature space. To\\nimprove generalization, we introduce a lightweight de-stylization module that\\npredicts parameters used for scaling and shifting feature maps to \"undo\" the\\nchanges incurred by filters, inverting the process of style transfer tasks. We\\nfurther demonstrate the module can be readily plugged into modern CNN\\narchitectures together with skip connections. We conduct extensive studies on\\nImageNet-Instagram, and show quantitatively and'\n",
1097 | "Label(s): ('cs.CV', 'cs.LG', 'eess.IV')\n",
1098 | "Predicted Label(s): (cs.CV, eess.IV, cs.LG)\n",
1099 | " \n",
1100 | "Abstract: b'Video summarization is an effective way to facilitate video searching and\\nbrowsing. Most of existing systems employ encoder-decoder based recurrent\\nneural networks, which fail to explicitly diversify the system-generated\\nsummary frames while requiring intensive computations. In this paper, we\\npropose an efficient convolutional neural network architecture for video\\nSUMmarization via Global Diverse Attention called SUM-GDA, which adapts\\nattention mechanism in a global perspective to consider pairwise temporal\\nrelations of video frames. Particularly, the GDA module has two advantages: 1)\\nit models the relations within paired frames as well as the relations among all\\npairs, thus capturing the global attention across all frames of one video; 2)\\nit reflects the importance of each frame to the whole video, leading to diverse\\nattention on these frames. Thus, SUM-GDA is beneficial for generating diverse\\nframes to form satisfactory video summary. Extensive experiments on three data\\nsets, i.e., SumMe, TVSum, and VTW, have demonstrated that SUM-GDA and its\\nextension outperform other competing state-of-the-art methods with remarkable\\nimprovements. In addition,'\n",
1101 | "Label(s): ('cs.CV', 'cs.MM')\n",
1102 | "Predicted Label(s): (cs.CV, cs.LG, eess.IV)\n",
1103 | " \n"
1104 | ]
1105 | }
1106 | ],
1107 | "source": [
1108 | "text_batch, label_batch = next(iter(test_dataset))\n",
1109 | "predicted_probabilities = shallow_mlp_model.predict(text_batch)\n",
1110 | "\n",
1111 | "for i, text in enumerate(text_batch[:5]):\n",
1112 | " label = label_batch[i].numpy()[None, ...]\n",
1113 | " print(f\"Abstract: {text[0]}\")\n",
1114 | " print(f\"Label(s): {mlb.inverse_transform(label)[0]}\")\n",
1115 | " predicted_proba = [proba for proba in predicted_probabilities[i]]\n",
1116 | " top_3_labels = [x for _, x in sorted(zip(predicted_probabilities[i], mlb.classes_), key=lambda pair: pair[0], reverse=True)][:3]\n",
1117 | " print(f\"Predicted Label(s): ({', '.join([label for label in top_3_labels])})\")\n",
1118 | " print(\" \")"
1119 | ]
1120 | },
1121 | {
1122 | "cell_type": "markdown",
1123 | "metadata": {
1124 | "id": "5-gQfc11SwAv"
1125 | },
1126 | "source": [
1127 | "The prediction results are not that great but not below the par for a simple model like ours. We can improve this performance with models that consider word order like LSTM or even those that use Transformers ([Vaswani et al.](https://arxiv.org/abs/1706.03762))."
1128 | ]
1129 | }
1130 | ],
1131 | "metadata": {
1132 | "accelerator": "GPU",
1133 | "colab": {
1134 | "collapsed_sections": [],
1135 | "name": "multi_label_trainer_tfidf.ipynb",
1136 | "provenance": []
1137 | },
1138 | "kernelspec": {
1139 | "display_name": "Python 3 (ipykernel)",
1140 | "language": "python",
1141 | "name": "python3"
1142 | },
1143 | "language_info": {
1144 | "codemirror_mode": {
1145 | "name": "ipython",
1146 | "version": 3
1147 | },
1148 | "file_extension": ".py",
1149 | "mimetype": "text/x-python",
1150 | "name": "python",
1151 | "nbconvert_exporter": "python",
1152 | "pygments_lexer": "ipython3",
1153 | "version": "3.8.2"
1154 | }
1155 | },
1156 | "nbformat": 4,
1157 | "nbformat_minor": 1
1158 | }
1159 |
--------------------------------------------------------------------------------