├── BERT_Captum
│ └── Bert_captum.ipynb
├── Bar Chart Race
│ └── bar chart race.ipynb
├── Bottleneck_Adapters
│ └── Bottleneck_Adapters_Medium.ipynb
├── GPT2_TextGeneration
│ └── GPT_2_Medium.ipynb
├── Layout_Parser
│ ├── img
│ │ ├── doc_1.pdf
│ │ └── doc_2.pdf
│ └── layout_parser_ex.ipynb
├── Lime
│ ├── LIME_image_class.ipynb
│ └── panda_00024.jpg
├── NER_BERT
│ ├── .ipynb_checkpoints
│ │ └── NER_with_BERT-checkpoint.ipynb
│ └── NER_with_BERT.ipynb
├── Optuna
│ └── Optuna.ipynb
├── README.md
├── STS_BERT
│ └── STS_BERT.ipynb
├── Spaces_Translation_App
│ └── app.py
├── Text_Classification_BERT
│ └── bert_medium.ipynb
├── Text_Classification_Transformer_Encoders
│ └── Transformer_Encoder.ipynb
└── ViT
    └── Vision_Transformer.ipynb
/BERT_Captum/Bert_captum.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "f0f52175-a7e8-45ab-a965-0d4a1e8e760e",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "%%capture\n",
11 | "!pip install transformers\n",
12 | "!pip install captum"
13 | ]
14 | },
15 | {
16 | "cell_type": "markdown",
17 | "id": "90853bde-95df-4827-98af-70410794c502",
18 | "metadata": {},
19 | "source": [
20 | "# Tokenization Example"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": null,
26 | "id": "3d970c20-8f80-4421-bcc3-85a581cf9738",
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "from transformers import BertTokenizer\n",
31 | "\n",
32 | "# Instantiate tokenizer\n",
33 | "tokenizer = BertTokenizer.from_pretrained('bert-base-cased')\n",
34 | "\n",
35 | "text = 'The movie is superb'\n",
36 | "\n",
37 | "# Tokenize input text\n",
38 | "text_ids = tokenizer.encode(text, add_special_tokens=True)\n",
39 | "\n",
40 | "# Print the tokens\n",
41 | "print(tokenizer.convert_ids_to_tokens(text_ids))\n",
42 | "# Output: ['[CLS]', 'The', 'movie', 'is', 'superb', '[SEP]']\n",
43 | "\n",
44 | "# Print the ids of the tokens\n",
45 | "print(text_ids)\n",
46 | "# Output: [101, 1109, 2523, 1110, 25876, 102]"
47 | ]
48 | },
49 | {
50 | "cell_type": "markdown",
51 | "id": "62e541e1-da3d-47dd-a278-7950b4fb0e54",
52 | "metadata": {},
53 | "source": [
54 | "# Minimal Example to Fetch the Embeddings of Tokens"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": null,
60 | "id": "2fb12068-b429-4411-bda5-bedff656d387",
61 | "metadata": {},
62 | "outputs": [],
63 | "source": [
64 | "from transformers import BertModel\n",
65 | "import torch\n",
66 | "# Instantiate BERT model\n",
67 | "model = BertModel.from_pretrained('bert-base-cased')\n",
68 | "\n",
69 | "embeddings = model.embeddings(torch.tensor([text_ids]))\n",
70 | "print(embeddings.size())\n",
71 | "# Output: torch.Size([1, 6, 768]), since there are 6 tokens in text_ids"
72 | ]
73 | },
74 | {
75 | "cell_type": "markdown",
76 | "id": "cc1d6810-edf6-4b1d-a48e-44e034a32a90",
77 | "metadata": {},
78 | "source": [
79 | "# Specify Model Architecture"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": null,
85 | "id": "2ea24f09-9fef-469c-8884-794fda046a5a",
86 | "metadata": {},
87 | "outputs": [],
88 | "source": [
89 | "from torch import nn\n",
90 | "\n",
91 | "class BertClassifier(nn.Module):\n",
92 | "\n",
93 | " def __init__(self, dropout=0.5):\n",
94 | "\n",
95 | " super(BertClassifier, self).__init__()\n",
96 | "\n",
97 | " self.bert = BertModel.from_pretrained('bert-base-cased')\n",
98 | " self.dropout = nn.Dropout(dropout)\n",
99 | " self.linear = nn.Linear(768, 2)\n",
100 | " self.relu = nn.ReLU()\n",
101 | "\n",
102 | " def forward(self, input_id, mask = None):\n",
103 | "\n",
104 | " _, pooled_output = self.bert(input_ids= input_id, attention_mask=mask,return_dict=False)\n",
105 | " dropout_output = self.dropout(pooled_output)\n",
106 | " linear_output = self.linear(dropout_output)\n",
107 | " final_layer = self.relu(linear_output)\n",
108 | "\n",
109 | " return final_layer"
110 | ]
111 | },
112 | {
113 | "cell_type": "markdown",
114 | "id": "ee6f1929-401c-40c2-a7a2-4a061380f633",
115 | "metadata": {},
116 | "source": [
117 | "# Load Model's Parameters "
118 | ]
119 | },
120 | {
121 | "cell_type": "code",
122 | "execution_count": null,
123 | "id": "1da8d59e-4ed4-4ca2-bcdb-93857a9295e7",
124 | "metadata": {},
125 | "outputs": [],
126 | "source": [
127 | "model = BertClassifier()\n",
128 | "model.load_state_dict(torch.load('path/to/bert_model.pt', map_location=torch.device('cpu')))\n",
129 | "model.eval()"
130 | ]
131 | },
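{
 "cell_type": "markdown",
 "id": "c0ffee01-save-note-0001-000000000001",
 "metadata": {},
 "source": [
  "The cell above assumes a fine-tuned checkpoint already exists at `path/to/bert_model.pt`. A minimal sketch of how such a file could be produced after training (the path is illustrative, not part of this repo):"
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "id": "c0ffee02-save-code-0001-000000000002",
 "metadata": {},
 "outputs": [],
 "source": [
  "# Sketch (assumption): `model` holds fine-tuned parameters and the path\n",
  "# mirrors the one passed to load_state_dict above.\n",
  "# torch.save(model.state_dict(), 'path/to/bert_model.pt')"
 ]
},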
132 | {
133 | "cell_type": "markdown",
134 | "id": "8d29dd60-b4cc-43d6-b5a7-44f56adca3be",
135 | "metadata": {},
136 | "source": [
137 | "# Define Model Input and Output"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": null,
143 | "id": "1c8bc95b-99f1-4087-a012-8306fe328daa",
144 | "metadata": {},
145 | "outputs": [],
146 | "source": [
147 | "# Define model output\n",
148 | "def model_output(inputs):\n",
149 | " return model(inputs)[0]\n",
150 | "\n",
151 | "# Define model input\n",
152 | "model_input = model.bert.embeddings"
153 | ]
154 | },
155 | {
156 | "cell_type": "markdown",
157 | "id": "d04dae59-850a-4b11-9684-cdcba55c6e2c",
158 | "metadata": {},
159 | "source": [
160 | "# Instantiate Integrated Gradients Method"
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": null,
166 | "id": "e8e805ad-9c5d-40b6-bb9d-890eecde0945",
167 | "metadata": {},
168 | "outputs": [],
169 | "source": [
170 | "from captum.attr import LayerIntegratedGradients\n",
171 | "lig = LayerIntegratedGradients(model_output, model_input)"
172 | ]
173 | },
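{
 "cell_type": "markdown",
 "id": "c0ffee03-ig-formula-0001-000000000003",
 "metadata": {},
 "source": [
  "For reference, integrated gradients attributes the output $F$ to each input dimension $i$ by accumulating gradients along the straight line from the baseline $x'$ to the input $x$:\n",
  "\n",
  "$$IG_i(x) = (x_i - x'_i) \\int_0^1 \\frac{\\partial F\\big(x' + \\alpha (x - x')\\big)}{\\partial x_i}\\, d\\alpha$$\n",
  "\n",
  "`LayerIntegratedGradients` computes this with respect to the embedding layer passed as `model_input` above."
 ]
},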
174 | {
175 | "cell_type": "markdown",
176 | "id": "86492c38-8201-4d54-aba0-de21e4966a42",
177 | "metadata": {},
178 | "source": [
179 | "# Construct Original and Baseline Input"
180 | ]
181 | },
182 | {
183 | "cell_type": "code",
184 | "execution_count": null,
185 | "id": "563b09b7-a6c2-4e67-bc11-2b25aa3a4a4c",
186 | "metadata": {
187 | "tags": []
188 | },
189 | "outputs": [],
190 | "source": [
191 | "def construct_input_and_baseline(text):\n",
192 | "\n",
193 | " max_length = 510\n",
194 | " baseline_token_id = tokenizer.pad_token_id \n",
195 | " sep_token_id = tokenizer.sep_token_id \n",
196 | " cls_token_id = tokenizer.cls_token_id \n",
197 | "\n",
198 | " text_ids = tokenizer.encode(text, max_length=max_length, truncation=True, add_special_tokens=False)\n",
199 | " \n",
200 | " input_ids = [cls_token_id] + text_ids + [sep_token_id]\n",
201 | " token_list = tokenizer.convert_ids_to_tokens(input_ids)\n",
202 | " \n",
203 | "\n",
204 | " baseline_input_ids = [cls_token_id] + [baseline_token_id] * len(text_ids) + [sep_token_id]\n",
205 | " return torch.tensor([input_ids], device='cpu'), torch.tensor([baseline_input_ids], device='cpu'), token_list\n",
206 | "\n",
207 | "text = 'This movie is superb'\n",
208 | "input_ids, baseline_input_ids, all_tokens = construct_input_and_baseline(text)\n",
209 | "\n",
210 | "print(f'original text: {input_ids}')\n",
211 | "print(f'baseline text: {baseline_input_ids}')\n",
212 | "\n",
213 | "# Output: original text: tensor([[ 101, 1109, 2523, 1110, 25876, 102]])\n",
214 | "# Output: baseline text: tensor([[101, 0, 0, 0, 0, 102]])"
215 | ]
216 | },
217 | {
218 | "cell_type": "markdown",
219 | "id": "1aa548b0-e738-45fe-bc0f-0d490427b79b",
220 | "metadata": {},
221 | "source": [
222 | "# Compute Attributions"
223 | ]
224 | },
225 | {
226 | "cell_type": "code",
227 | "execution_count": null,
228 | "id": "3db56682-41f3-46ec-b331-c81617391fec",
229 | "metadata": {},
230 | "outputs": [],
231 | "source": [
232 | "attributions, delta = lig.attribute(inputs= input_ids,\n",
233 | " baselines= baseline_input_ids,\n",
234 | " return_convergence_delta=True\n",
235 | " )\n",
236 | "print(attributions.size())\n",
237 | "# Output: torch.Size([1, 6, 768])"
238 | ]
239 | },
240 | {
241 | "cell_type": "markdown",
242 | "id": "864769c7-2378-42a2-8f35-76bd44af63ac",
243 | "metadata": {},
244 | "source": [
245 | "# Compute Attribution for Each Token"
246 | ]
247 | },
248 | {
249 | "cell_type": "code",
250 | "execution_count": null,
251 | "id": "826a93e2-4b4b-430a-9c20-10cbb5ac08e1",
252 | "metadata": {},
253 | "outputs": [],
254 | "source": [
255 | "def summarize_attributions(attributions):\n",
256 | "\n",
257 | " attributions = attributions.sum(dim=-1).squeeze(0)\n",
258 | " attributions = attributions / torch.norm(attributions)\n",
259 | " \n",
260 | " return attributions\n",
261 | "\n",
262 | "attributions_sum = summarize_attributions(attributions)\n",
263 | "print(attributions_sum.size())\n",
264 | "# Output: torch.Size([6])"
265 | ]
266 | },
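{
 "cell_type": "markdown",
 "id": "c0ffee04-attr-math-0001-000000000004",
 "metadata": {},
 "source": [
  "Written out, with $A \\in \\mathbb{R}^{6 \\times 768}$ the squeezed attribution tensor, the score for token $i$ is\n",
  "\n",
  "$$a_i = \\sum_{d=1}^{768} A_{i,d}, \\qquad \\tilde{a}_i = \\frac{a_i}{\\lVert a \\rVert_2}$$\n",
  "\n",
  "so every token ends up with a single scalar, normalized across the sequence."
 ]
},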
267 | {
268 | "cell_type": "markdown",
269 | "id": "0123bbd3-536f-48cb-a0fa-c9a6fd9557df",
270 | "metadata": {},
271 | "source": [
272 | "# Visualize the Interpretation"
273 | ]
274 | },
275 | {
276 | "cell_type": "code",
277 | "execution_count": null,
278 | "id": "af2eaaa4-0fa4-4928-a7e2-a727f8e56fdd",
279 | "metadata": {},
280 | "outputs": [],
281 | "source": [
282 | "from captum.attr import visualization as viz\n",
283 | "\n",
284 | "score_vis = viz.VisualizationDataRecord(\n",
285 | " word_attributions = attributions_sum,\n",
286 | " pred_prob = torch.max(model(input_ids)[0]),\n",
287 | " pred_class = torch.argmax(model(input_ids)[0]).numpy(),\n",
288 | " true_class = 1,\n",
289 | " attr_class = text,\n",
290 | " attr_score = attributions_sum.sum(), \n",
291 | " raw_input_ids = all_tokens,\n",
292 | " convergence_score = delta)\n",
293 | "\n",
294 | "viz.visualize_text([score_vis])"
295 | ]
296 | },
297 | {
298 | "cell_type": "markdown",
299 | "id": "a7c16834-2b58-4d96-b56b-f5908468f64a",
300 | "metadata": {},
301 | "source": [
302 | "# Encapsulate All the Steps Above"
303 | ]
304 | },
305 | {
306 | "cell_type": "code",
307 | "execution_count": null,
308 | "id": "b6c9ec71-84ca-4734-afd0-b006c3a711df",
309 | "metadata": {},
310 | "outputs": [],
311 | "source": [
312 | "def interpret_text(text, true_class):\n",
313 | "\n",
314 | " input_ids, baseline_input_ids, all_tokens = construct_input_and_baseline(text)\n",
315 | " attributions, delta = lig.attribute(inputs= input_ids,\n",
316 | " baselines= baseline_input_ids,\n",
317 | " return_convergence_delta=True\n",
318 | " )\n",
319 | " attributions_sum = summarize_attributions(attributions)\n",
320 | "\n",
321 | " score_vis = viz.VisualizationDataRecord(\n",
322 | " word_attributions = attributions_sum,\n",
323 | " pred_prob = torch.max(model(input_ids)[0]),\n",
324 | " pred_class = torch.argmax(model(input_ids)[0]).numpy(),\n",
325 | " true_class = true_class,\n",
326 | " attr_class = text,\n",
327 | " attr_score = attributions_sum.sum(), \n",
328 | " raw_input_ids = all_tokens,\n",
329 | " convergence_score = delta)\n",
330 | "\n",
331 | " viz.visualize_text([score_vis])"
332 | ]
333 | },
334 | {
335 | "cell_type": "markdown",
336 | "id": "57984aee-32e9-418a-a286-f35e0bc67b70",
337 | "metadata": {},
338 | "source": [
339 | "# Interpret Texts"
340 | ]
341 | },
342 | {
343 | "cell_type": "code",
344 | "execution_count": null,
345 | "id": "d530db94-52fa-4c9a-a673-27247eac3b99",
346 | "metadata": {},
347 | "outputs": [],
348 | "source": [
349 | "text = \"It's a heartfelt film about love, loss, and legacy\"\n",
350 | "true_class = 1\n",
351 | "interpret_text(text, true_class)"
352 | ]
353 | },
354 | {
355 | "cell_type": "code",
356 | "execution_count": null,
357 | "id": "8544cb5c-3afd-45a7-ae82-7f9e692ff223",
358 | "metadata": {},
359 | "outputs": [],
360 | "source": [
361 | "text = \"A noisy, hideous, and viciously cumbersome movie\"\n",
362 | "true_class = 0\n",
363 | "interpret_text(text, true_class)"
364 | ]
365 | }
366 | ],
367 | "metadata": {
368 | "kernelspec": {
369 | "display_name": "Python 3 (ipykernel)",
370 | "language": "python",
371 | "name": "python3"
372 | },
373 | "language_info": {
374 | "codemirror_mode": {
375 | "name": "ipython",
376 | "version": 3
377 | },
378 | "file_extension": ".py",
379 | "mimetype": "text/x-python",
380 | "name": "python",
381 | "nbconvert_exporter": "python",
382 | "pygments_lexer": "ipython3",
383 | "version": "3.9.7"
384 | }
385 | },
386 | "nbformat": 4,
387 | "nbformat_minor": 5
388 | }
389 |
--------------------------------------------------------------------------------
/Bar Chart Race/bar chart race.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import pandas as pd\n",
10 | "import numpy as np"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 2,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "prem_league = pd.read_csv('D:/PL Dataset/premierLeague_tables_1992-2017.csv')"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 3,
25 | "metadata": {},
26 | "outputs": [
27 | {
28 | "data": {
29 | "text/html": [
30 | "
\n",
31 | "\n",
44 | "
\n",
45 | " \n",
46 | " \n",
47 | " | \n",
48 | " season | \n",
49 | " team | \n",
50 | " points | \n",
51 | " w | \n",
52 | " d | \n",
53 | " l | \n",
54 | " gf | \n",
55 | " ga | \n",
56 | " gd | \n",
57 | " pld | \n",
58 | " ... | \n",
59 | " d_h | \n",
60 | " d_a | \n",
61 | " l_h | \n",
62 | " l_a | \n",
63 | " gf_h | \n",
64 | " gf_a | \n",
65 | " ga_h | \n",
66 | " ga_a | \n",
67 | " gd_h | \n",
68 | " gd_a | \n",
69 | "
\n",
70 | " \n",
71 | " \n",
72 | " \n",
73 | " | 0 | \n",
74 | " 2017-18 | \n",
75 | " Manchester City | \n",
76 | " 100 | \n",
77 | " 32 | \n",
78 | " 4 | \n",
79 | " 2 | \n",
80 | " 106 | \n",
81 | " 27 | \n",
82 | " 79 | \n",
83 | " 38 | \n",
84 | " ... | \n",
85 | " 2 | \n",
86 | " 2 | \n",
87 | " 1 | \n",
88 | " 1 | \n",
89 | " 61 | \n",
90 | " 45 | \n",
91 | " 14 | \n",
92 | " 13 | \n",
93 | " 47 | \n",
94 | " 32 | \n",
95 | "
\n",
96 | " \n",
97 | " | 1 | \n",
98 | " 2017-18 | \n",
99 | " Manchester United | \n",
100 | " 81 | \n",
101 | " 25 | \n",
102 | " 6 | \n",
103 | " 7 | \n",
104 | " 68 | \n",
105 | " 28 | \n",
106 | " 40 | \n",
107 | " 38 | \n",
108 | " ... | \n",
109 | " 2 | \n",
110 | " 4 | \n",
111 | " 2 | \n",
112 | " 5 | \n",
113 | " 38 | \n",
114 | " 30 | \n",
115 | " 9 | \n",
116 | " 19 | \n",
117 | " 29 | \n",
118 | " 11 | \n",
119 | "
\n",
120 | " \n",
121 | " | 2 | \n",
122 | " 2017-18 | \n",
123 | " Tottenham Hotspur | \n",
124 | " 77 | \n",
125 | " 23 | \n",
126 | " 8 | \n",
127 | " 7 | \n",
128 | " 74 | \n",
129 | " 36 | \n",
130 | " 38 | \n",
131 | " 38 | \n",
132 | " ... | \n",
133 | " 4 | \n",
134 | " 4 | \n",
135 | " 2 | \n",
136 | " 5 | \n",
137 | " 40 | \n",
138 | " 34 | \n",
139 | " 16 | \n",
140 | " 20 | \n",
141 | " 24 | \n",
142 | " 14 | \n",
143 | "
\n",
144 | " \n",
145 | " | 3 | \n",
146 | " 2017-18 | \n",
147 | " Liverpool | \n",
148 | " 75 | \n",
149 | " 21 | \n",
150 | " 12 | \n",
151 | " 5 | \n",
152 | " 84 | \n",
153 | " 38 | \n",
154 | " 46 | \n",
155 | " 38 | \n",
156 | " ... | \n",
157 | " 7 | \n",
158 | " 5 | \n",
159 | " 0 | \n",
160 | " 5 | \n",
161 | " 45 | \n",
162 | " 39 | \n",
163 | " 10 | \n",
164 | " 28 | \n",
165 | " 35 | \n",
166 | " 11 | \n",
167 | "
\n",
168 | " \n",
169 | " | 4 | \n",
170 | " 2017-18 | \n",
171 | " Chelsea | \n",
172 | " 70 | \n",
173 | " 21 | \n",
174 | " 7 | \n",
175 | " 10 | \n",
176 | " 62 | \n",
177 | " 38 | \n",
178 | " 24 | \n",
179 | " 38 | \n",
180 | " ... | \n",
181 | " 4 | \n",
182 | " 3 | \n",
183 | " 4 | \n",
184 | " 6 | \n",
185 | " 30 | \n",
186 | " 32 | \n",
187 | " 16 | \n",
188 | " 22 | \n",
189 | " 14 | \n",
190 | " 10 | \n",
191 | "
\n",
192 | " \n",
193 | "
\n",
194 | "
5 rows × 26 columns
\n",
195 | "
"
196 | ],
197 | "text/plain": [
198 | " season team points w d l gf ga gd pld ... d_h \\\n",
199 | "0 2017-18 Manchester City 100 32 4 2 106 27 79 38 ... 2 \n",
200 | "1 2017-18 Manchester United 81 25 6 7 68 28 40 38 ... 2 \n",
201 | "2 2017-18 Tottenham Hotspur 77 23 8 7 74 36 38 38 ... 4 \n",
202 | "3 2017-18 Liverpool 75 21 12 5 84 38 46 38 ... 7 \n",
203 | "4 2017-18 Chelsea 70 21 7 10 62 38 24 38 ... 4 \n",
204 | "\n",
205 | " d_a l_h l_a gf_h gf_a ga_h ga_a gd_h gd_a \n",
206 | "0 2 1 1 61 45 14 13 47 32 \n",
207 | "1 4 2 5 38 30 9 19 29 11 \n",
208 | "2 4 2 5 40 34 16 20 24 14 \n",
209 | "3 5 0 5 45 39 10 28 35 11 \n",
210 | "4 3 4 6 30 32 16 22 14 10 \n",
211 | "\n",
212 | "[5 rows x 26 columns]"
213 | ]
214 | },
215 | "execution_count": 3,
216 | "metadata": {},
217 | "output_type": "execute_result"
218 | }
219 | ],
220 | "source": [
221 | "prem_league.head()"
222 | ]
223 | },
224 | {
225 | "cell_type": "code",
226 | "execution_count": 4,
227 | "metadata": {},
228 | "outputs": [],
229 | "source": [
230 | "prem_league = prem_league[['season', 'team', 'points']]"
231 | ]
232 | },
233 | {
234 | "cell_type": "code",
235 | "execution_count": 5,
236 | "metadata": {},
237 | "outputs": [
238 | {
239 | "data": {
240 | "text/html": [
241 | "\n",
242 | "\n",
255 | "
\n",
256 | " \n",
257 | " \n",
258 | " | \n",
259 | " season | \n",
260 | " team | \n",
261 | " points | \n",
262 | "
\n",
263 | " \n",
264 | " \n",
265 | " \n",
266 | " | 0 | \n",
267 | " 2017-18 | \n",
268 | " Manchester City | \n",
269 | " 100 | \n",
270 | "
\n",
271 | " \n",
272 | " | 1 | \n",
273 | " 2017-18 | \n",
274 | " Manchester United | \n",
275 | " 81 | \n",
276 | "
\n",
277 | " \n",
278 | " | 2 | \n",
279 | " 2017-18 | \n",
280 | " Tottenham Hotspur | \n",
281 | " 77 | \n",
282 | "
\n",
283 | " \n",
284 | " | 3 | \n",
285 | " 2017-18 | \n",
286 | " Liverpool | \n",
287 | " 75 | \n",
288 | "
\n",
289 | " \n",
290 | " | 4 | \n",
291 | " 2017-18 | \n",
292 | " Chelsea | \n",
293 | " 70 | \n",
294 | "
\n",
295 | " \n",
296 | "
\n",
297 | "
"
298 | ],
299 | "text/plain": [
300 | " season team points\n",
301 | "0 2017-18 Manchester City 100\n",
302 | "1 2017-18 Manchester United 81\n",
303 | "2 2017-18 Tottenham Hotspur 77\n",
304 | "3 2017-18 Liverpool 75\n",
305 | "4 2017-18 Chelsea 70"
306 | ]
307 | },
308 | "execution_count": 5,
309 | "metadata": {},
310 | "output_type": "execute_result"
311 | }
312 | ],
313 | "source": [
314 | "prem_league.head()"
315 | ]
316 | },
317 | {
318 | "cell_type": "code",
319 | "execution_count": 6,
320 | "metadata": {},
321 | "outputs": [
322 | {
323 | "data": {
324 | "text/plain": [
325 | "season object\n",
326 | "team object\n",
327 | "points int64\n",
328 | "dtype: object"
329 | ]
330 | },
331 | "execution_count": 6,
332 | "metadata": {},
333 | "output_type": "execute_result"
334 | }
335 | ],
336 | "source": [
337 | "prem_league.dtypes"
338 | ]
339 | },
340 | {
341 | "cell_type": "code",
342 | "execution_count": 7,
343 | "metadata": {},
344 | "outputs": [],
345 | "source": [
346 | "df = prem_league.pivot_table(values = 'points',index = ['season'], columns = 'team')"
347 | ]
348 | },
349 | {
350 | "cell_type": "code",
351 | "execution_count": 8,
352 | "metadata": {},
353 | "outputs": [
354 | {
355 | "data": {
356 | "text/html": [
357 | "\n",
358 | "\n",
371 | "
\n",
372 | " \n",
373 | " \n",
374 | " | team | \n",
375 | " Arsenal | \n",
376 | " Aston Villa | \n",
377 | " Barnsley | \n",
378 | " Birmingham City | \n",
379 | " Blackburn Rovers | \n",
380 | " Blackpool | \n",
381 | " Bolton Wanderers | \n",
382 | " Bournemouth | \n",
383 | " Bradford City | \n",
384 | " Brighton and Hove Albion | \n",
385 | " ... | \n",
386 | " Sunderland | \n",
387 | " Swansea City | \n",
388 | " Swindon Town | \n",
389 | " Tottenham Hotspur | \n",
390 | " Watford | \n",
391 | " West Bromwich Albion | \n",
392 | " West Ham United | \n",
393 | " Wigan Athletic | \n",
394 | " Wimbledon FC | \n",
395 | " Wolverhampton Wanderers | \n",
396 | "
\n",
397 | " \n",
398 | " | season | \n",
399 | " | \n",
400 | " | \n",
401 | " | \n",
402 | " | \n",
403 | " | \n",
404 | " | \n",
405 | " | \n",
406 | " | \n",
407 | " | \n",
408 | " | \n",
409 | " | \n",
410 | " | \n",
411 | " | \n",
412 | " | \n",
413 | " | \n",
414 | " | \n",
415 | " | \n",
416 | " | \n",
417 | " | \n",
418 | " | \n",
419 | " | \n",
420 | "
\n",
421 | " \n",
422 | " \n",
423 | " \n",
424 | " | 1992-93 | \n",
425 | " 56.0 | \n",
426 | " 74.0 | \n",
427 | " NaN | \n",
428 | " NaN | \n",
429 | " 71.0 | \n",
430 | " NaN | \n",
431 | " NaN | \n",
432 | " NaN | \n",
433 | " NaN | \n",
434 | " NaN | \n",
435 | " ... | \n",
436 | " NaN | \n",
437 | " NaN | \n",
438 | " NaN | \n",
439 | " 59.0 | \n",
440 | " NaN | \n",
441 | " NaN | \n",
442 | " NaN | \n",
443 | " NaN | \n",
444 | " 54.0 | \n",
445 | " NaN | \n",
446 | "
\n",
447 | " \n",
448 | " | 1993-94 | \n",
449 | " 71.0 | \n",
450 | " 57.0 | \n",
451 | " NaN | \n",
452 | " NaN | \n",
453 | " 84.0 | \n",
454 | " NaN | \n",
455 | " NaN | \n",
456 | " NaN | \n",
457 | " NaN | \n",
458 | " NaN | \n",
459 | " ... | \n",
460 | " NaN | \n",
461 | " NaN | \n",
462 | " 30.0 | \n",
463 | " 45.0 | \n",
464 | " NaN | \n",
465 | " NaN | \n",
466 | " 52.0 | \n",
467 | " NaN | \n",
468 | " 65.0 | \n",
469 | " NaN | \n",
470 | "
\n",
471 | " \n",
472 | " | 1994-95 | \n",
473 | " 51.0 | \n",
474 | " 48.0 | \n",
475 | " NaN | \n",
476 | " NaN | \n",
477 | " 89.0 | \n",
478 | " NaN | \n",
479 | " NaN | \n",
480 | " NaN | \n",
481 | " NaN | \n",
482 | " NaN | \n",
483 | " ... | \n",
484 | " NaN | \n",
485 | " NaN | \n",
486 | " NaN | \n",
487 | " 62.0 | \n",
488 | " NaN | \n",
489 | " NaN | \n",
490 | " 50.0 | \n",
491 | " NaN | \n",
492 | " 56.0 | \n",
493 | " NaN | \n",
494 | "
\n",
495 | " \n",
496 | " | 1995-96 | \n",
497 | " 63.0 | \n",
498 | " 63.0 | \n",
499 | " NaN | \n",
500 | " NaN | \n",
501 | " 61.0 | \n",
502 | " NaN | \n",
503 | " 29.0 | \n",
504 | " NaN | \n",
505 | " NaN | \n",
506 | " NaN | \n",
507 | " ... | \n",
508 | " NaN | \n",
509 | " NaN | \n",
510 | " NaN | \n",
511 | " 61.0 | \n",
512 | " NaN | \n",
513 | " NaN | \n",
514 | " 51.0 | \n",
515 | " NaN | \n",
516 | " 41.0 | \n",
517 | " NaN | \n",
518 | "
\n",
519 | " \n",
520 | " | 1996-97 | \n",
521 | " 68.0 | \n",
522 | " 61.0 | \n",
523 | " NaN | \n",
524 | " NaN | \n",
525 | " 42.0 | \n",
526 | " NaN | \n",
527 | " NaN | \n",
528 | " NaN | \n",
529 | " NaN | \n",
530 | " NaN | \n",
531 | " ... | \n",
532 | " 40.0 | \n",
533 | " NaN | \n",
534 | " NaN | \n",
535 | " 46.0 | \n",
536 | " NaN | \n",
537 | " NaN | \n",
538 | " 42.0 | \n",
539 | " NaN | \n",
540 | " 56.0 | \n",
541 | " NaN | \n",
542 | "
\n",
543 | " \n",
544 | "
\n",
545 | "
5 rows × 49 columns
\n",
546 | "
"
547 | ],
548 | "text/plain": [
549 | "team Arsenal Aston Villa Barnsley Birmingham City Blackburn Rovers \\\n",
550 | "season \n",
551 | "1992-93 56.0 74.0 NaN NaN 71.0 \n",
552 | "1993-94 71.0 57.0 NaN NaN 84.0 \n",
553 | "1994-95 51.0 48.0 NaN NaN 89.0 \n",
554 | "1995-96 63.0 63.0 NaN NaN 61.0 \n",
555 | "1996-97 68.0 61.0 NaN NaN 42.0 \n",
556 | "\n",
557 | "team Blackpool Bolton Wanderers Bournemouth Bradford City \\\n",
558 | "season \n",
559 | "1992-93 NaN NaN NaN NaN \n",
560 | "1993-94 NaN NaN NaN NaN \n",
561 | "1994-95 NaN NaN NaN NaN \n",
562 | "1995-96 NaN 29.0 NaN NaN \n",
563 | "1996-97 NaN NaN NaN NaN \n",
564 | "\n",
565 | "team Brighton and Hove Albion ... Sunderland Swansea City \\\n",
566 | "season ... \n",
567 | "1992-93 NaN ... NaN NaN \n",
568 | "1993-94 NaN ... NaN NaN \n",
569 | "1994-95 NaN ... NaN NaN \n",
570 | "1995-96 NaN ... NaN NaN \n",
571 | "1996-97 NaN ... 40.0 NaN \n",
572 | "\n",
573 | "team Swindon Town Tottenham Hotspur Watford West Bromwich Albion \\\n",
574 | "season \n",
575 | "1992-93 NaN 59.0 NaN NaN \n",
576 | "1993-94 30.0 45.0 NaN NaN \n",
577 | "1994-95 NaN 62.0 NaN NaN \n",
578 | "1995-96 NaN 61.0 NaN NaN \n",
579 | "1996-97 NaN 46.0 NaN NaN \n",
580 | "\n",
581 | "team West Ham United Wigan Athletic Wimbledon FC \\\n",
582 | "season \n",
583 | "1992-93 NaN NaN 54.0 \n",
584 | "1993-94 52.0 NaN 65.0 \n",
585 | "1994-95 50.0 NaN 56.0 \n",
586 | "1995-96 51.0 NaN 41.0 \n",
587 | "1996-97 42.0 NaN 56.0 \n",
588 | "\n",
589 | "team Wolverhampton Wanderers \n",
590 | "season \n",
591 | "1992-93 NaN \n",
592 | "1993-94 NaN \n",
593 | "1994-95 NaN \n",
594 | "1995-96 NaN \n",
595 | "1996-97 NaN \n",
596 | "\n",
597 | "[5 rows x 49 columns]"
598 | ]
599 | },
600 | "execution_count": 8,
601 | "metadata": {},
602 | "output_type": "execute_result"
603 | }
604 | ],
605 | "source": [
606 | "df.head()"
607 | ]
608 | },
609 | {
610 | "cell_type": "code",
611 | "execution_count": 9,
612 | "metadata": {},
613 | "outputs": [],
614 | "source": [
615 | "df.fillna(0, inplace=True)\n",
616 | "df.sort_values(list(df.columns),inplace=True)\n",
617 | "df = df.sort_index()"
618 | ]
619 | },
620 | {
621 | "cell_type": "code",
622 | "execution_count": 10,
623 | "metadata": {},
624 | "outputs": [
625 | {
626 | "data": {
627 | "text/html": [
628 | "\n",
629 | "\n",
642 | "
\n",
643 | " \n",
644 | " \n",
645 | " | team | \n",
646 | " Arsenal | \n",
647 | " Aston Villa | \n",
648 | " Barnsley | \n",
649 | " Birmingham City | \n",
650 | " Blackburn Rovers | \n",
651 | " Blackpool | \n",
652 | " Bolton Wanderers | \n",
653 | " Bournemouth | \n",
654 | " Bradford City | \n",
655 | " Brighton and Hove Albion | \n",
656 | " ... | \n",
657 | " Sunderland | \n",
658 | " Swansea City | \n",
659 | " Swindon Town | \n",
660 | " Tottenham Hotspur | \n",
661 | " Watford | \n",
662 | " West Bromwich Albion | \n",
663 | " West Ham United | \n",
664 | " Wigan Athletic | \n",
665 | " Wimbledon FC | \n",
666 | " Wolverhampton Wanderers | \n",
667 | "
\n",
668 | " \n",
669 | " | season | \n",
670 | " | \n",
671 | " | \n",
672 | " | \n",
673 | " | \n",
674 | " | \n",
675 | " | \n",
676 | " | \n",
677 | " | \n",
678 | " | \n",
679 | " | \n",
680 | " | \n",
681 | " | \n",
682 | " | \n",
683 | " | \n",
684 | " | \n",
685 | " | \n",
686 | " | \n",
687 | " | \n",
688 | " | \n",
689 | " | \n",
690 | " | \n",
691 | "
\n",
692 | " \n",
693 | " \n",
694 | " \n",
695 | " | 1992-93 | \n",
696 | " 56.0 | \n",
697 | " 74.0 | \n",
698 | " 0.0 | \n",
699 | " 0.0 | \n",
700 | " 71.0 | \n",
701 | " 0.0 | \n",
702 | " 0.0 | \n",
703 | " 0.0 | \n",
704 | " 0.0 | \n",
705 | " 0.0 | \n",
706 | " ... | \n",
707 | " 0.0 | \n",
708 | " 0.0 | \n",
709 | " 0.0 | \n",
710 | " 59.0 | \n",
711 | " 0.0 | \n",
712 | " 0.0 | \n",
713 | " 0.0 | \n",
714 | " 0.0 | \n",
715 | " 54.0 | \n",
716 | " 0.0 | \n",
717 | "
\n",
718 | " \n",
719 | " | 1993-94 | \n",
720 | " 71.0 | \n",
721 | " 57.0 | \n",
722 | " 0.0 | \n",
723 | " 0.0 | \n",
724 | " 84.0 | \n",
725 | " 0.0 | \n",
726 | " 0.0 | \n",
727 | " 0.0 | \n",
728 | " 0.0 | \n",
729 | " 0.0 | \n",
730 | " ... | \n",
731 | " 0.0 | \n",
732 | " 0.0 | \n",
733 | " 30.0 | \n",
734 | " 45.0 | \n",
735 | " 0.0 | \n",
736 | " 0.0 | \n",
737 | " 52.0 | \n",
738 | " 0.0 | \n",
739 | " 65.0 | \n",
740 | " 0.0 | \n",
741 | "
\n",
742 | " \n",
743 | " | 1994-95 | \n",
744 | " 51.0 | \n",
745 | " 48.0 | \n",
746 | " 0.0 | \n",
747 | " 0.0 | \n",
748 | " 89.0 | \n",
749 | " 0.0 | \n",
750 | " 0.0 | \n",
751 | " 0.0 | \n",
752 | " 0.0 | \n",
753 | " 0.0 | \n",
754 | " ... | \n",
755 | " 0.0 | \n",
756 | " 0.0 | \n",
757 | " 0.0 | \n",
758 | " 62.0 | \n",
759 | " 0.0 | \n",
760 | " 0.0 | \n",
761 | " 50.0 | \n",
762 | " 0.0 | \n",
763 | " 56.0 | \n",
764 | " 0.0 | \n",
765 | "
\n",
766 | " \n",
767 | " | 1995-96 | \n",
768 | " 63.0 | \n",
769 | " 63.0 | \n",
770 | " 0.0 | \n",
771 | " 0.0 | \n",
772 | " 61.0 | \n",
773 | " 0.0 | \n",
774 | " 29.0 | \n",
775 | " 0.0 | \n",
776 | " 0.0 | \n",
777 | " 0.0 | \n",
778 | " ... | \n",
779 | " 0.0 | \n",
780 | " 0.0 | \n",
781 | " 0.0 | \n",
782 | " 61.0 | \n",
783 | " 0.0 | \n",
784 | " 0.0 | \n",
785 | " 51.0 | \n",
786 | " 0.0 | \n",
787 | " 41.0 | \n",
788 | " 0.0 | \n",
789 | "
\n",
790 | " \n",
791 | " | 1996-97 | \n",
792 | " 68.0 | \n",
793 | " 61.0 | \n",
794 | " 0.0 | \n",
795 | " 0.0 | \n",
796 | " 42.0 | \n",
797 | " 0.0 | \n",
798 | " 0.0 | \n",
799 | " 0.0 | \n",
800 | " 0.0 | \n",
801 | " 0.0 | \n",
802 | " ... | \n",
803 | " 40.0 | \n",
804 | " 0.0 | \n",
805 | " 0.0 | \n",
806 | " 46.0 | \n",
807 | " 0.0 | \n",
808 | " 0.0 | \n",
809 | " 42.0 | \n",
810 | " 0.0 | \n",
811 | " 56.0 | \n",
812 | " 0.0 | \n",
813 | "
\n",
814 | " \n",
815 | "
\n",
816 | "
5 rows × 49 columns
\n",
817 | "
"
818 | ],
819 | "text/plain": [
820 | "team Arsenal Aston Villa Barnsley Birmingham City Blackburn Rovers \\\n",
821 | "season \n",
822 | "1992-93 56.0 74.0 0.0 0.0 71.0 \n",
823 | "1993-94 71.0 57.0 0.0 0.0 84.0 \n",
824 | "1994-95 51.0 48.0 0.0 0.0 89.0 \n",
825 | "1995-96 63.0 63.0 0.0 0.0 61.0 \n",
826 | "1996-97 68.0 61.0 0.0 0.0 42.0 \n",
827 | "\n",
828 | "team Blackpool Bolton Wanderers Bournemouth Bradford City \\\n",
829 | "season \n",
830 | "1992-93 0.0 0.0 0.0 0.0 \n",
831 | "1993-94 0.0 0.0 0.0 0.0 \n",
832 | "1994-95 0.0 0.0 0.0 0.0 \n",
833 | "1995-96 0.0 29.0 0.0 0.0 \n",
834 | "1996-97 0.0 0.0 0.0 0.0 \n",
835 | "\n",
836 | "team Brighton and Hove Albion ... Sunderland Swansea City \\\n",
837 | "season ... \n",
838 | "1992-93 0.0 ... 0.0 0.0 \n",
839 | "1993-94 0.0 ... 0.0 0.0 \n",
840 | "1994-95 0.0 ... 0.0 0.0 \n",
841 | "1995-96 0.0 ... 0.0 0.0 \n",
842 | "1996-97 0.0 ... 40.0 0.0 \n",
843 | "\n",
844 | "team Swindon Town Tottenham Hotspur Watford West Bromwich Albion \\\n",
845 | "season \n",
846 | "1992-93 0.0 59.0 0.0 0.0 \n",
847 | "1993-94 30.0 45.0 0.0 0.0 \n",
848 | "1994-95 0.0 62.0 0.0 0.0 \n",
849 | "1995-96 0.0 61.0 0.0 0.0 \n",
850 | "1996-97 0.0 46.0 0.0 0.0 \n",
851 | "\n",
852 | "team West Ham United Wigan Athletic Wimbledon FC \\\n",
853 | "season \n",
854 | "1992-93 0.0 0.0 54.0 \n",
855 | "1993-94 52.0 0.0 65.0 \n",
856 | "1994-95 50.0 0.0 56.0 \n",
857 | "1995-96 51.0 0.0 41.0 \n",
858 | "1996-97 42.0 0.0 56.0 \n",
859 | "\n",
860 | "team Wolverhampton Wanderers \n",
861 | "season \n",
862 | "1992-93 0.0 \n",
863 | "1993-94 0.0 \n",
864 | "1994-95 0.0 \n",
865 | "1995-96 0.0 \n",
866 | "1996-97 0.0 \n",
867 | "\n",
868 | "[5 rows x 49 columns]"
869 | ]
870 | },
871 | "execution_count": 10,
872 | "metadata": {},
873 | "output_type": "execute_result"
874 | }
875 | ],
876 | "source": [
877 | "df.head()"
878 | ]
879 | },
880 | {
881 | "cell_type": "code",
882 | "execution_count": 11,
883 | "metadata": {},
884 | "outputs": [],
885 | "source": [
886 | "df.iloc[:, 0:-1] = df.iloc[:, 0:-1].cumsum()"
887 | ]
888 | },
889 | {
890 | "cell_type": "code",
891 | "execution_count": 12,
892 | "metadata": {
893 | "scrolled": true
894 | },
895 | "outputs": [
896 | {
897 | "data": {
898 | "text/html": [
899 | "\n",
900 | "\n",
913 | "
\n",
914 | " \n",
915 | " \n",
916 | " | team | \n",
917 | " Arsenal | \n",
918 | " Aston Villa | \n",
919 | " Barnsley | \n",
920 | " Birmingham City | \n",
921 | " Blackburn Rovers | \n",
922 | " Blackpool | \n",
923 | " Bolton Wanderers | \n",
924 | " Bournemouth | \n",
925 | " Bradford City | \n",
926 | " Brighton and Hove Albion | \n",
927 | " ... | \n",
928 | " Sunderland | \n",
929 | " Swansea City | \n",
930 | " Swindon Town | \n",
931 | " Tottenham Hotspur | \n",
932 | " Watford | \n",
933 | " West Bromwich Albion | \n",
934 | " West Ham United | \n",
935 | " Wigan Athletic | \n",
936 | " Wimbledon FC | \n",
937 | " Wolverhampton Wanderers | \n",
938 | "
\n",
939 | " \n",
940 | " | season | \n",
941 | " | \n",
942 | " | \n",
943 | " | \n",
944 | " | \n",
945 | " | \n",
946 | " | \n",
947 | " | \n",
948 | " | \n",
949 | " | \n",
950 | " | \n",
951 | " | \n",
952 | " | \n",
953 | " | \n",
954 | " | \n",
955 | " | \n",
956 | " | \n",
957 | " | \n",
958 | " | \n",
959 | " | \n",
960 | " | \n",
961 | " | \n",
962 | "
\n",
963 | " \n",
964 | " \n",
965 | " \n",
966 | " | 1992-93 | \n",
967 | " 56.0 | \n",
968 | " 74.0 | \n",
969 | " 0.0 | \n",
970 | " 0.0 | \n",
971 | " 71.0 | \n",
972 | " 0.0 | \n",
973 | " 0.0 | \n",
974 | " 0.0 | \n",
975 | " 0.0 | \n",
976 | " 0.0 | \n",
977 | " ... | \n",
978 | " 0.0 | \n",
979 | " 0.0 | \n",
980 | " 0.0 | \n",
981 | " 59.0 | \n",
982 | " 0.0 | \n",
983 | " 0.0 | \n",
984 | " 0.0 | \n",
985 | " 0.0 | \n",
986 | " 54.0 | \n",
987 | " 0.0 | \n",
988 | "
\n",
989 | " \n",
990 | " | 1993-94 | \n",
991 | " 127.0 | \n",
992 | " 131.0 | \n",
993 | " 0.0 | \n",
994 | " 0.0 | \n",
995 | " 155.0 | \n",
996 | " 0.0 | \n",
997 | " 0.0 | \n",
998 | " 0.0 | \n",
999 | " 0.0 | \n",
1000 | " 0.0 | \n",
1001 | " ... | \n",
1002 | " 0.0 | \n",
1003 | " 0.0 | \n",
1004 | " 30.0 | \n",
1005 | " 104.0 | \n",
1006 | " 0.0 | \n",
1007 | " 0.0 | \n",
1008 | " 52.0 | \n",
1009 | " 0.0 | \n",
1010 | " 119.0 | \n",
1011 | " 0.0 | \n",
1012 | "
\n",
1013 | " \n",
1014 | " | 1994-95 | \n",
1015 | " 178.0 | \n",
1016 | " 179.0 | \n",
1017 | " 0.0 | \n",
1018 | " 0.0 | \n",
1019 | " 244.0 | \n",
1020 | " 0.0 | \n",
1021 | " 0.0 | \n",
1022 | " 0.0 | \n",
1023 | " 0.0 | \n",
1024 | " 0.0 | \n",
1025 | " ... | \n",
1026 | " 0.0 | \n",
1027 | " 0.0 | \n",
1028 | " 30.0 | \n",
1029 | " 166.0 | \n",
1030 | " 0.0 | \n",
1031 | " 0.0 | \n",
1032 | " 102.0 | \n",
1033 | " 0.0 | \n",
1034 | " 175.0 | \n",
1035 | " 0.0 | \n",
1036 | "
\n",
1037 | " \n",
1038 | " | 1995-96 | \n",
1039 | " 241.0 | \n",
1040 | " 242.0 | \n",
1041 | " 0.0 | \n",
1042 | " 0.0 | \n",
1043 | " 305.0 | \n",
1044 | " 0.0 | \n",
1045 | " 29.0 | \n",
1046 | " 0.0 | \n",
1047 | " 0.0 | \n",
1048 | " 0.0 | \n",
1049 | " ... | \n",
1050 | " 0.0 | \n",
1051 | " 0.0 | \n",
1052 | " 30.0 | \n",
1053 | " 227.0 | \n",
1054 | " 0.0 | \n",
1055 | " 0.0 | \n",
1056 | " 153.0 | \n",
1057 | " 0.0 | \n",
1058 | " 216.0 | \n",
1059 | " 0.0 | \n",
1060 | "
\n",
1061 | " \n",
1062 | " | 1996-97 | \n",
1063 | " 309.0 | \n",
1064 | " 303.0 | \n",
1065 | " 0.0 | \n",
1066 | " 0.0 | \n",
1067 | " 347.0 | \n",
1068 | " 0.0 | \n",
1069 | " 29.0 | \n",
1070 | " 0.0 | \n",
1071 | " 0.0 | \n",
1072 | " 0.0 | \n",
1073 | " ... | \n",
1074 | " 40.0 | \n",
1075 | " 0.0 | \n",
1076 | " 30.0 | \n",
1077 | " 273.0 | \n",
1078 | " 0.0 | \n",
1079 | " 0.0 | \n",
1080 | " 195.0 | \n",
1081 | " 0.0 | \n",
1082 | " 272.0 | \n",
1083 | " 0.0 | \n",
1084 | "
\n",
1085 | " \n",
1086 | " | 1997-98 | \n",
1087 | " 387.0 | \n",
1088 | " 360.0 | \n",
1089 | " 35.0 | \n",
1090 | " 0.0 | \n",
1091 | " 405.0 | \n",
1092 | " 0.0 | \n",
1093 | " 69.0 | \n",
1094 | " 0.0 | \n",
1095 | " 0.0 | \n",
1096 | " 0.0 | \n",
1097 | " ... | \n",
1098 | " 40.0 | \n",
1099 | " 0.0 | \n",
1100 | " 30.0 | \n",
1101 | " 317.0 | \n",
1102 | " 0.0 | \n",
1103 | " 0.0 | \n",
1104 | " 251.0 | \n",
1105 | " 0.0 | \n",
1106 | " 316.0 | \n",
1107 | " 0.0 | \n",
1108 | "
\n",
1109 | " \n",
1110 | " | 1998-99 | \n",
1111 | " 465.0 | \n",
1112 | " 415.0 | \n",
1113 | " 35.0 | \n",
1114 | " 0.0 | \n",
1115 | " 440.0 | \n",
1116 | " 0.0 | \n",
1117 | " 69.0 | \n",
1118 | " 0.0 | \n",
1119 | " 0.0 | \n",
1120 | " 0.0 | \n",
1121 | " ... | \n",
1122 | " 40.0 | \n",
1123 | " 0.0 | \n",
1124 | " 30.0 | \n",
1125 | " 364.0 | \n",
1126 | " 0.0 | \n",
1127 | " 0.0 | \n",
1128 | " 308.0 | \n",
1129 | " 0.0 | \n",
1130 | " 358.0 | \n",
1131 | " 0.0 | \n",
1132 | "
\n",
1133 | " \n",
1134 | " | 1999-00 | \n",
1135 | " 538.0 | \n",
1136 | " 473.0 | \n",
1137 | " 35.0 | \n",
1138 | " 0.0 | \n",
1139 | " 440.0 | \n",
1140 | " 0.0 | \n",
1141 | " 69.0 | \n",
1142 | " 0.0 | \n",
1143 | " 36.0 | \n",
1144 | " 0.0 | \n",
1145 | " ... | \n",
1146 | " 98.0 | \n",
1147 | " 0.0 | \n",
1148 | " 30.0 | \n",
1149 | " 417.0 | \n",
1150 | " 24.0 | \n",
1151 | " 0.0 | \n",
1152 | " 363.0 | \n",
1153 | " 0.0 | \n",
1154 | " 391.0 | \n",
1155 | " 0.0 | \n",
1156 | "
\n",
1157 | " \n",
1158 | " | 2000-01 | \n",
1159 | " 608.0 | \n",
1160 | " 527.0 | \n",
1161 | " 35.0 | \n",
1162 | " 0.0 | \n",
1163 | " 440.0 | \n",
1164 | " 0.0 | \n",
1165 | " 69.0 | \n",
1166 | " 0.0 | \n",
1167 | " 62.0 | \n",
1168 | " 0.0 | \n",
1169 | " ... | \n",
1170 | " 155.0 | \n",
1171 | " 0.0 | \n",
1172 | " 30.0 | \n",
1173 | " 466.0 | \n",
1174 | " 24.0 | \n",
1175 | " 0.0 | \n",
1176 | " 405.0 | \n",
1177 | " 0.0 | \n",
1178 | " 391.0 | \n",
1179 | " 0.0 | \n",
1180 | "
\n",
1181 | " \n",
1182 | " | 2001-02 | \n",
1183 | " 695.0 | \n",
1184 | " 577.0 | \n",
1185 | " 35.0 | \n",
1186 | " 0.0 | \n",
1187 | " 486.0 | \n",
1188 | " 0.0 | \n",
1189 | " 109.0 | \n",
1190 | " 0.0 | \n",
1191 | " 62.0 | \n",
1192 | " 0.0 | \n",
1193 | " ... | \n",
1194 | " 195.0 | \n",
1195 | " 0.0 | \n",
1196 | " 30.0 | \n",
1197 | " 516.0 | \n",
1198 | " 24.0 | \n",
1199 | " 0.0 | \n",
1200 | " 458.0 | \n",
1201 | " 0.0 | \n",
1202 | " 391.0 | \n",
1203 | " 0.0 | \n",
1204 | "
\n",
1205 | " \n",
1206 | "
\n",
1207 | "
10 rows × 49 columns
\n",
1208 | "
"
1209 | ],
1210 | "text/plain": [
1211 | "team Arsenal Aston Villa Barnsley Birmingham City Blackburn Rovers \\\n",
1212 | "season \n",
1213 | "1992-93 56.0 74.0 0.0 0.0 71.0 \n",
1214 | "1993-94 127.0 131.0 0.0 0.0 155.0 \n",
1215 | "1994-95 178.0 179.0 0.0 0.0 244.0 \n",
1216 | "1995-96 241.0 242.0 0.0 0.0 305.0 \n",
1217 | "1996-97 309.0 303.0 0.0 0.0 347.0 \n",
1218 | "1997-98 387.0 360.0 35.0 0.0 405.0 \n",
1219 | "1998-99 465.0 415.0 35.0 0.0 440.0 \n",
1220 | "1999-00 538.0 473.0 35.0 0.0 440.0 \n",
1221 | "2000-01 608.0 527.0 35.0 0.0 440.0 \n",
1222 | "2001-02 695.0 577.0 35.0 0.0 486.0 \n",
1223 | "\n",
1224 | "team Blackpool Bolton Wanderers Bournemouth Bradford City \\\n",
1225 | "season \n",
1226 | "1992-93 0.0 0.0 0.0 0.0 \n",
1227 | "1993-94 0.0 0.0 0.0 0.0 \n",
1228 | "1994-95 0.0 0.0 0.0 0.0 \n",
1229 | "1995-96 0.0 29.0 0.0 0.0 \n",
1230 | "1996-97 0.0 29.0 0.0 0.0 \n",
1231 | "1997-98 0.0 69.0 0.0 0.0 \n",
1232 | "1998-99 0.0 69.0 0.0 0.0 \n",
1233 | "1999-00 0.0 69.0 0.0 36.0 \n",
1234 | "2000-01 0.0 69.0 0.0 62.0 \n",
1235 | "2001-02 0.0 109.0 0.0 62.0 \n",
1236 | "\n",
1237 | "team Brighton and Hove Albion ... Sunderland Swansea City \\\n",
1238 | "season ... \n",
1239 | "1992-93 0.0 ... 0.0 0.0 \n",
1240 | "1993-94 0.0 ... 0.0 0.0 \n",
1241 | "1994-95 0.0 ... 0.0 0.0 \n",
1242 | "1995-96 0.0 ... 0.0 0.0 \n",
1243 | "1996-97 0.0 ... 40.0 0.0 \n",
1244 | "1997-98 0.0 ... 40.0 0.0 \n",
1245 | "1998-99 0.0 ... 40.0 0.0 \n",
1246 | "1999-00 0.0 ... 98.0 0.0 \n",
1247 | "2000-01 0.0 ... 155.0 0.0 \n",
1248 | "2001-02 0.0 ... 195.0 0.0 \n",
1249 | "\n",
1250 | "team Swindon Town Tottenham Hotspur Watford West Bromwich Albion \\\n",
1251 | "season \n",
1252 | "1992-93 0.0 59.0 0.0 0.0 \n",
1253 | "1993-94 30.0 104.0 0.0 0.0 \n",
1254 | "1994-95 30.0 166.0 0.0 0.0 \n",
1255 | "1995-96 30.0 227.0 0.0 0.0 \n",
1256 | "1996-97 30.0 273.0 0.0 0.0 \n",
1257 | "1997-98 30.0 317.0 0.0 0.0 \n",
1258 | "1998-99 30.0 364.0 0.0 0.0 \n",
1259 | "1999-00 30.0 417.0 24.0 0.0 \n",
1260 | "2000-01 30.0 466.0 24.0 0.0 \n",
1261 | "2001-02 30.0 516.0 24.0 0.0 \n",
1262 | "\n",
1263 | "team West Ham United Wigan Athletic Wimbledon FC \\\n",
1264 | "season \n",
1265 | "1992-93 0.0 0.0 54.0 \n",
1266 | "1993-94 52.0 0.0 119.0 \n",
1267 | "1994-95 102.0 0.0 175.0 \n",
1268 | "1995-96 153.0 0.0 216.0 \n",
1269 | "1996-97 195.0 0.0 272.0 \n",
1270 | "1997-98 251.0 0.0 316.0 \n",
1271 | "1998-99 308.0 0.0 358.0 \n",
1272 | "1999-00 363.0 0.0 391.0 \n",
1273 | "2000-01 405.0 0.0 391.0 \n",
1274 | "2001-02 458.0 0.0 391.0 \n",
1275 | "\n",
1276 | "team Wolverhampton Wanderers \n",
1277 | "season \n",
1278 | "1992-93 0.0 \n",
1279 | "1993-94 0.0 \n",
1280 | "1994-95 0.0 \n",
1281 | "1995-96 0.0 \n",
1282 | "1996-97 0.0 \n",
1283 | "1997-98 0.0 \n",
1284 | "1998-99 0.0 \n",
1285 | "1999-00 0.0 \n",
1286 | "2000-01 0.0 \n",
1287 | "2001-02 0.0 \n",
1288 | "\n",
1289 | "[10 rows x 49 columns]"
1290 | ]
1291 | },
1292 | "execution_count": 12,
1293 | "metadata": {},
1294 | "output_type": "execute_result"
1295 | }
1296 | ],
1297 | "source": [
1298 | "df[0:10]"
1299 | ]
1300 | },
1301 | {
1302 | "cell_type": "code",
1303 | "execution_count": 13,
1304 | "metadata": {},
1305 | "outputs": [],
1306 | "source": [
1307 | "top_prem_clubs = set()\n",
1308 | "\n",
1309 | "for index, row in df.iterrows():\n",
1310 | " top_prem_clubs |= set(row[row > 0].sort_values(ascending=False).head(6).index)\n",
1311 | "\n",
1312 | "df = df[top_prem_clubs]"
1313 | ]
1314 | },
1315 | {
1316 | "cell_type": "code",
1317 | "execution_count": 14,
1318 | "metadata": {},
1319 | "outputs": [
1320 | {
1321 | "data": {
1322 | "text/html": [
1323 | "\n",
1324 | "\n",
1337 | "
\n",
1338 | " \n",
1339 | " \n",
1340 | " | team | \n",
1341 | " Everton | \n",
1342 | " Queens Park Rangers | \n",
1343 | " Newcastle United | \n",
1344 | " Sheffield Wednesday | \n",
1345 | " Blackburn Rovers | \n",
1346 | " Chelsea | \n",
1347 | " Aston Villa | \n",
1348 | " Norwich City | \n",
1349 | " Manchester United | \n",
1350 | " Arsenal | \n",
1351 | " Leeds United | \n",
1352 | " Tottenham Hotspur | \n",
1353 | " Liverpool | \n",
1354 | "
\n",
1355 | " \n",
1356 | " | season | \n",
1357 | " | \n",
1358 | " | \n",
1359 | " | \n",
1360 | " | \n",
1361 | " | \n",
1362 | " | \n",
1363 | " | \n",
1364 | " | \n",
1365 | " | \n",
1366 | " | \n",
1367 | " | \n",
1368 | " | \n",
1369 | " | \n",
1370 | "
\n",
1371 | " \n",
1372 | " \n",
1373 | " \n",
1374 | " | 1992-93 | \n",
1375 | " 53.0 | \n",
1376 | " 63.0 | \n",
1377 | " 0.0 | \n",
1378 | " 59.0 | \n",
1379 | " 71.0 | \n",
1380 | " 56.0 | \n",
1381 | " 74.0 | \n",
1382 | " 72.0 | \n",
1383 | " 84.0 | \n",
1384 | " 56.0 | \n",
1385 | " 51.0 | \n",
1386 | " 59.0 | \n",
1387 | " 59.0 | \n",
1388 | "
\n",
1389 | " \n",
1390 | " | 1993-94 | \n",
1391 | " 97.0 | \n",
1392 | " 123.0 | \n",
1393 | " 77.0 | \n",
1394 | " 123.0 | \n",
1395 | " 155.0 | \n",
1396 | " 107.0 | \n",
1397 | " 131.0 | \n",
1398 | " 125.0 | \n",
1399 | " 176.0 | \n",
1400 | " 127.0 | \n",
1401 | " 121.0 | \n",
1402 | " 104.0 | \n",
1403 | " 119.0 | \n",
1404 | "
\n",
1405 | " \n",
1406 | " | 1994-95 | \n",
1407 | " 147.0 | \n",
1408 | " 183.0 | \n",
1409 | " 149.0 | \n",
1410 | " 174.0 | \n",
1411 | " 244.0 | \n",
1412 | " 161.0 | \n",
1413 | " 179.0 | \n",
1414 | " 168.0 | \n",
1415 | " 264.0 | \n",
1416 | " 178.0 | \n",
1417 | " 194.0 | \n",
1418 | " 166.0 | \n",
1419 | " 193.0 | \n",
1420 | "
\n",
1421 | " \n",
1422 | " | 1995-96 | \n",
1423 | " 208.0 | \n",
1424 | " 216.0 | \n",
1425 | " 227.0 | \n",
1426 | " 214.0 | \n",
1427 | " 305.0 | \n",
1428 | " 211.0 | \n",
1429 | " 242.0 | \n",
1430 | " 168.0 | \n",
1431 | " 346.0 | \n",
1432 | " 241.0 | \n",
1433 | " 237.0 | \n",
1434 | " 227.0 | \n",
1435 | " 264.0 | \n",
1436 | "
\n",
1437 | " \n",
1438 | " | 1996-97 | \n",
1439 | " 250.0 | \n",
1440 | " 216.0 | \n",
1441 | " 295.0 | \n",
1442 | " 271.0 | \n",
1443 | " 347.0 | \n",
1444 | " 270.0 | \n",
1445 | " 303.0 | \n",
1446 | " 168.0 | \n",
1447 | " 421.0 | \n",
1448 | " 309.0 | \n",
1449 | " 283.0 | \n",
1450 | " 273.0 | \n",
1451 | " 332.0 | \n",
1452 | "
\n",
1453 | " \n",
1454 | "
\n",
1455 | "
"
1456 | ],
1457 | "text/plain": [
1458 | "team Everton Queens Park Rangers Newcastle United Sheffield Wednesday \\\n",
1459 | "season \n",
1460 | "1992-93 53.0 63.0 0.0 59.0 \n",
1461 | "1993-94 97.0 123.0 77.0 123.0 \n",
1462 | "1994-95 147.0 183.0 149.0 174.0 \n",
1463 | "1995-96 208.0 216.0 227.0 214.0 \n",
1464 | "1996-97 250.0 216.0 295.0 271.0 \n",
1465 | "\n",
1466 | "team Blackburn Rovers Chelsea Aston Villa Norwich City \\\n",
1467 | "season \n",
1468 | "1992-93 71.0 56.0 74.0 72.0 \n",
1469 | "1993-94 155.0 107.0 131.0 125.0 \n",
1470 | "1994-95 244.0 161.0 179.0 168.0 \n",
1471 | "1995-96 305.0 211.0 242.0 168.0 \n",
1472 | "1996-97 347.0 270.0 303.0 168.0 \n",
1473 | "\n",
1474 | "team Manchester United Arsenal Leeds United Tottenham Hotspur \\\n",
1475 | "season \n",
1476 | "1992-93 84.0 56.0 51.0 59.0 \n",
1477 | "1993-94 176.0 127.0 121.0 104.0 \n",
1478 | "1994-95 264.0 178.0 194.0 166.0 \n",
1479 | "1995-96 346.0 241.0 237.0 227.0 \n",
1480 | "1996-97 421.0 309.0 283.0 273.0 \n",
1481 | "\n",
1482 | "team Liverpool \n",
1483 | "season \n",
1484 | "1992-93 59.0 \n",
1485 | "1993-94 119.0 \n",
1486 | "1994-95 193.0 \n",
1487 | "1995-96 264.0 \n",
1488 | "1996-97 332.0 "
1489 | ]
1490 | },
1491 | "execution_count": 14,
1492 | "metadata": {},
1493 | "output_type": "execute_result"
1494 | }
1495 | ],
1496 | "source": [
1497 | "df.head()"
1498 | ]
1499 | },
1500 | {
1501 | "cell_type": "code",
1502 | "execution_count": 15,
1503 | "metadata": {},
1504 | "outputs": [],
1505 | "source": [
1506 | "import bar_chart_race as bcr"
1507 | ]
1508 | },
1509 | {
1510 | "cell_type": "code",
1511 | "execution_count": 22,
1512 | "metadata": {},
1513 | "outputs": [],
1514 | "source": [
1515 | "bcr.bar_chart_race(df = df, \n",
1516 | " n_bars = 6, \n",
1517 | " sort='desc',\n",
1518 | " title='Premier League Clubs Points Since 1992',\n",
1519 | " period_length = 750,\n",
1520 | " filename = 'pl_clubs.mp4')"
1521 | ]
1522 | },
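{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "Passing `filename` writes the animation to disk as `pl_clubs.mp4`. A sketch of the alternative: with no `filename`, `bar_chart_race` should return the animation for inline display in the notebook instead of saving it."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Sketch: omit `filename` to render the race inline rather than to a file\n",
  "bcr.bar_chart_race(df=df,\n",
  "                   n_bars=6,\n",
  "                   sort='desc',\n",
  "                   title='Premier League Clubs Points Since 1992',\n",
  "                   period_length=750)"
 ]
},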
1523 | {
1524 | "cell_type": "code",
1525 | "execution_count": 29,
1526 | "metadata": {},
1527 | "outputs": [
1528 | {
1529 | "ename": "TypeError",
1530 | "evalue": "bar_chart_race() got an unexpected keyword argument 'img_label_folder'",
1531 | "output_type": "error",
1532 | "traceback": [
1533 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
1534 | "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)",
1535 | "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mbcr\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbar_chart_race\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdf\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mimg_label_folder\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;34m'PL clubs'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mn_bars\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m6\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mperiod_length\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m750\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
1536 | "\u001b[1;31mTypeError\u001b[0m: bar_chart_race() got an unexpected keyword argument 'img_label_folder'"
1537 | ]
1538 | }
1539 | ],
1540 | "source": [
1541 | "bcr.bar_chart_race(df, img_label_folder = 'PL clubs', n_bars=6, period_length = 750)"
1542 | ]
1543 | },
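{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "The `TypeError` above indicates that the installed release of `bar_chart_race` has no `img_label_folder` parameter; at the time of writing that feature appears only in the development version on GitHub. One assumed fix is to install from source, restart the kernel, and rerun:"
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Assumption: the GitHub master branch supports image labels via\n",
  "# `img_label_folder`; restart the kernel after installing.\n",
  "# !pip install git+https://github.com/dexplo/bar_chart_race"
 ]
},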
1544 | {
1545 | "cell_type": "code",
1546 | "execution_count": null,
1547 | "metadata": {},
1548 | "outputs": [],
1549 | "source": []
1550 | }
1551 | ],
1552 | "metadata": {
1553 | "kernelspec": {
1554 | "display_name": "Python 3",
1555 | "language": "python",
1556 | "name": "python3"
1557 | },
1558 | "language_info": {
1559 | "codemirror_mode": {
1560 | "name": "ipython",
1561 | "version": 3
1562 | },
1563 | "file_extension": ".py",
1564 | "mimetype": "text/x-python",
1565 | "name": "python",
1566 | "nbconvert_exporter": "python",
1567 | "pygments_lexer": "ipython3",
1568 | "version": "3.7.6"
1569 | }
1570 | },
1571 | "nbformat": 4,
1572 | "nbformat_minor": 4
1573 | }
1574 |
--------------------------------------------------------------------------------
/Layout_Parser/img/doc_1.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrubenw/medium-resources/c28cb4db83f939014eff9a88730d406476606f50/Layout_Parser/img/doc_1.pdf
--------------------------------------------------------------------------------
/Layout_Parser/img/doc_2.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrubenw/medium-resources/c28cb4db83f939014eff9a88730d406476606f50/Layout_Parser/img/doc_2.pdf
--------------------------------------------------------------------------------
/Layout_Parser/layout_parser_ex.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "3b88065e-45e9-40ed-aa04-8dd7ec36e45c",
6 | "metadata": {},
7 | "source": [
8 | "# Install Dependencies"
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "id": "8a68afad-3a7a-40c4-8d8b-84b00d35e72f",
14 | "metadata": {},
15 | "source": [
16 | "If you work with a Windows machine, it's better to try LayoutParser on Google Colab instead since it's tricky to install Detectron 2 on Windows machine"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": null,
22 | "id": "85d97120-8613-464c-95e0-f06eb514b781",
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "%%capture\n",
27 | "!sudo apt-get install poppler-utils #pdf2image dependency -- restart runtime/kernel after installation\n",
28 | "!sudo apt-get install tesseract-ocr-eng #install Tesseract OCR Engine --restart runtime/kernel after installation"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "id": "27f229b0-b4c9-46ce-845e-d2cc4fe50eb6",
35 | "metadata": {},
36 | "outputs": [],
37 | "source": [
38 | "%%capture\n",
39 | "!pip install layoutparser torchvision && pip install \"detectron2@git+https://github.com/facebookresearch/detectron2.git@v0.5#egg=detectron2\"\n",
40 | "!pip install pdf2img\n",
41 | "!pip install \"layoutparser[ocr]\""
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": null,
47 | "id": "769c14cf-293e-46be-81e0-50ab7acba502",
48 | "metadata": {},
49 | "outputs": [],
50 | "source": [
51 | "import pdf2image\n",
52 | "import numpy as np\n",
53 | "import layoutparser as lp\n",
54 | "import torchvision.ops.boxes as bops\n",
55 | "import torch"
56 | ]
57 | },
58 | {
59 | "cell_type": "markdown",
60 | "id": "6decc0b6-15ab-4673-8266-8203d62c4cd0",
61 | "metadata": {},
62 | "source": [
63 | "# Layout Detection "
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": null,
69 | "id": "b7b8795f-89e7-4e13-9fea-bacb95a2fef6",
70 | "metadata": {},
71 | "outputs": [],
72 | "source": [
73 | "pdf_file= '/img/doc_1.pdf' # Adjust the filepath of your input image accordingly\n",
74 | "img = np.asarray(pdf2image.convert_from_path(pdf_file)[0])"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "id": "14b474a7-5ce0-4bfc-a6be-37f2d7705224",
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "model = lp.Detectron2LayoutModel('lp://PubLayNet/mask_rcnn_X_101_32x8d_FPN_3x/config',\n",
85 | " extra_config=[\"MODEL.ROI_HEADS.SCORE_THRESH_TEST\", 0.5],\n",
86 | " label_map={0: \"Text\", 1: \"Title\", 2: \"List\", 3:\"Table\", 4:\"Figure\"})\n",
87 | "\n",
88 | "layout_result = model.detect(img)\n",
89 | "\n",
90 | "lp.draw_box(img, layout_result, box_width=5, box_alpha=0.2, show_element_type=True)"
91 | ]
92 | },
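{
 "cell_type": "markdown",
 "id": "c0ffee05-score-note-0001-000000000005",
 "metadata": {},
 "source": [
  "Each detected block also carries the Detectron2 confidence score. A quick sketch to inspect what the model found (assumes the `type` and `score` attributes that layoutparser attaches to each block):"
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "id": "c0ffee06-score-code-0001-000000000006",
 "metadata": {},
 "outputs": [],
 "source": [
  "# Print the category and confidence of every detected layout block\n",
  "for block in layout_result:\n",
  "    print(block.type, round(block.score, 3))"
 ]
},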
93 | {
94 | "cell_type": "code",
95 | "execution_count": null,
96 | "id": "038df22f-5b1c-454f-80de-9b2443c3a6f7",
97 | "metadata": {},
98 | "outputs": [],
99 | "source": [
100 | "text_blocks = lp.Layout([b for b in layout_result if b.type=='Text'])\n",
101 | "\n",
102 | "lp.draw_box(img, text_blocks, box_width=5, box_alpha=0.2, show_element_type=True, show_element_id=True)"
103 | ]
104 | },
105 | {
106 | "cell_type": "markdown",
107 | "id": "a8c88c59-a3c9-4959-b4d8-f141b7e24670",
108 | "metadata": {},
109 | "source": [
110 | "# OCR Parser with Tesseract"
111 | ]
112 | },
113 | {
114 | "cell_type": "code",
115 | "execution_count": null,
116 | "id": "881b4809-60a4-4219-acd0-c4919ad25587",
117 | "metadata": {},
118 | "outputs": [],
119 | "source": [
120 | "ocr_agent = lp.TesseractAgent(languages='eng')"
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": null,
126 | "id": "bdaf8736-6168-40bc-a2de-e6b76d506336",
127 | "metadata": {},
128 | "outputs": [],
129 | "source": [
130 | "image_width = len(img[0])\n",
131 | "\n",
132 | "# Sort element ID of the left column based on y1 coordinate\n",
133 | "left_interval = lp.Interval(0, image_width/2, axis='x').put_on_canvas(img)\n",
134 | "left_blocks = text_blocks.filter_by(left_interval, center=True)._blocks\n",
135 | "left_blocks.sort(key = lambda b:b.coordinates[1])\n",
136 | "\n",
137 | "# Sort element ID of the right column based on y1 coordinate\n",
138 | "right_blocks = [b for b in text_blocks if b not in left_blocks]\n",
139 | "right_blocks.sort(key = lambda b:b.coordinates[1])\n",
140 | "\n",
141 | "# Sort the overall element ID starts from left column\n",
142 | "text_blocks = lp.Layout([b.set(id = idx) for idx, b in enumerate(left_blocks + right_blocks)])\n",
143 | "\n",
144 | "lp.draw_box(img, text_blocks, box_width=5, box_alpha=0.2, show_element_type=True, show_element_id=True)"
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": null,
150 | "id": "db495dc8-4c21-4a99-bd26-c200de49a8f1",
151 | "metadata": {},
152 | "outputs": [],
153 | "source": [
154 | "for block in text_blocks:\n",
155 | "\n",
156 | " # Crop image around the detected layout\n",
157 | " segment_image = (block\n",
158 | " .pad(left=15, right=15, top=5, bottom=5)\n",
159 | " .crop_image(img))\n",
160 | " \n",
161 | " # Perform OCR\n",
162 | " text = ocr_agent.detect(segment_image)\n",
163 | "\n",
164 | " # Save OCR result\n",
165 | " block.set(text=text, inplace=True)"
166 | ]
167 | },
168 | {
169 | "cell_type": "code",
170 | "execution_count": null,
171 | "id": "69ea88a0-bdb5-4952-8533-88172d4a3b85",
172 | "metadata": {},
173 | "outputs": [],
174 | "source": [
175 | "for txt in text_blocks:\n",
176 | " print(txt.text, end='\\n---\\n')"
177 | ]
178 | },
179 | {
180 | "cell_type": "markdown",
181 | "id": "1a27c5fa-3af9-43d6-8023-8ab3426b691b",
182 | "metadata": {},
183 | "source": [
184 | "# Adjusting LayoutParser Result"
185 | ]
186 | },
187 | {
188 | "cell_type": "code",
189 | "execution_count": null,
190 | "id": "47ea2cff-6906-4821-8d85-09c45727a1b8",
191 | "metadata": {},
192 | "outputs": [],
193 | "source": [
194 | "pdf_file_2= '/img/doc_2.pdf' # Adjust the filepath of your input image accordingly\n",
195 | "img_2 = np.asarray(pdf2image.convert_from_path(pdf_file)[0])"
196 | ]
197 | },
198 | {
199 | "cell_type": "code",
200 | "execution_count": null,
201 | "id": "dc81a71c-91b0-4d8f-8b07-90be4f5ed67d",
202 | "metadata": {},
203 | "outputs": [],
204 | "source": [
205 | "layout_result_2 = model.detect(img_2)\n",
206 | "\n",
207 | "text_blocks_2 = lp.Layout([b for b in layout_result_2 if b.type=='Text'])\n",
208 | "\n",
209 | "lp.draw_box(img_2, text_blocks_2, box_width=5, box_alpha=0.2, show_element_type=True, show_element_id=True)"
210 | ]
211 | },
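{
 "cell_type": "markdown",
 "id": "c0ffee07-iou-note-0001-000000000007",
 "metadata": {},
 "source": [
  "The refinement below drops the smaller of any two overlapping text blocks, using intersection over union to detect overlap:\n",
  "\n",
  "$$\\mathrm{IoU}(A, B) = \\frac{|A \\cap B|}{|A \\cup B|}$$\n",
  "\n",
  "Any pair with $\\mathrm{IoU} > 0$ overlaps at least partially; the block with the larger area is kept."
 ]
},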
212 | {
213 | "cell_type": "code",
214 | "execution_count": null,
215 | "id": "9d6f8412-6f1c-4b18-9bde-a69e6a939329",
216 | "metadata": {},
217 | "outputs": [],
218 | "source": [
219 | "def set_coordinate(data):\n",
220 | "\n",
221 | " x1 = data.block.x_1\n",
222 | " y1 = data.block.y_1\n",
223 | " x2 = data.block.x_2\n",
224 | " y2 = data.block.y_2\n",
225 | "\n",
226 | " return torch.tensor([[x1, y1, x2, y2]], dtype=torch.float)\n",
227 | "\n",
228 | "def compute_iou(box_1, box_2):\n",
229 | "\n",
230 | " return bops.box_iou(box_1, box_2)\n",
231 | "\n",
232 | "def compute_area(box):\n",
233 | "\n",
234 | " width = box.tolist()[0][2] - box.tolist()[0][0]\n",
235 | " length = box.tolist()[0][3] - box.tolist()[0][1]\n",
236 | " area = width*length\n",
237 | "\n",
238 | " return area\n",
239 | "\n",
240 | "def refine(block_1, block_2):\n",
241 | "\n",
242 | " bb1 = set_coordinate(block_1)\n",
243 | " bb2 = set_coordinate(block_2)\n",
244 | "\n",
245 | " iou = compute_iou(bb1, bb2)\n",
246 | "\n",
247 | " if iou.tolist()[0][0] != 0:\n",
248 | "\n",
249 | " a1 = compute_area(bb1)\n",
250 | " a2 = compute_area(bb2)\n",
251 | "\n",
252 | " block_2.set(type='None', inplace= True) if a1 > a2 else block_1.set(type='None', inplace= True)\n",
253 | " \n",
254 | "\n",
255 | "for layout_i in text_blocks_2:\n",
256 | " \n",
257 | " for layout_j in text_blocks_2:\n",
258 | " \n",
259 | " if layout_i != layout_j: \n",
260 | "\n",
261 | " refine(layout_i, layout_j)\n",
262 | " \n",
263 | "text_blocks_2 = lp.Layout([b for b in text_blocks_2 if b.type=='Text'])\n",
264 | "\n",
265 | "lp.draw_box(img_2, text_blocks_2, box_width=5, box_alpha=0.2, show_element_type=True, show_element_id=True)"
266 | ]
267 | },
268 | {
269 | "cell_type": "code",
270 | "execution_count": null,
271 | "id": "6e28c664-4930-42f5-a107-edc5eb81a23a",
272 | "metadata": {},
273 | "outputs": [],
274 | "source": [
275 | "text_blocks_2 = lp.Layout([b.set(id = idx) for idx, b in enumerate(text_blocks_2)])\n",
276 | "\n",
277 | "# From the visualization, let's say we know that layout \n",
278 | "# with 'Diameter Thickness' text has element ID of 4\n",
279 | "\n",
280 | "text_blocks_2[4].set(type='None', inplace=True)\n",
281 | "text_blocks_2 = lp.Layout([b for b in text_blocks_2 if b.type=='Text'])\n",
282 | "\n",
283 | "lp.draw_box(img_2, text_blocks_2, box_width=5, box_alpha=0.2, show_element_type=True, show_element_id=True)"
284 | ]
285 | }
286 | ],
287 | "metadata": {
288 | "kernelspec": {
289 | "display_name": "Python 3",
290 | "language": "python",
291 | "name": "python3"
292 | },
293 | "language_info": {
294 | "codemirror_mode": {
295 | "name": "ipython",
296 | "version": 3
297 | },
298 | "file_extension": ".py",
299 | "mimetype": "text/x-python",
300 | "name": "python",
301 | "nbconvert_exporter": "python",
302 | "pygments_lexer": "ipython3",
303 | "version": "3.8.8"
304 | }
305 | },
306 | "nbformat": 4,
307 | "nbformat_minor": 5
308 | }
309 |
--------------------------------------------------------------------------------
/Lime/panda_00024.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrubenw/medium-resources/c28cb4db83f939014eff9a88730d406476606f50/Lime/panda_00024.jpg
--------------------------------------------------------------------------------
/NER_BERT/NER_with_BERT.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 9,
6 | "id": "1657ccc8-b9dd-46e7-a08f-b9176ea274ba",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "%%capture\n",
11 | "pip install transformers"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 2,
17 | "id": "438f352b-1664-4219-b257-855919d467fa",
18 | "metadata": {},
19 | "outputs": [],
20 | "source": [
21 | "import pandas as pd\n",
22 | "import torch \n",
23 | "import numpy as np\n",
24 | "from transformers import BertTokenizerFast, BertForTokenClassification\n",
25 | "from torch.utils.data import DataLoader\n",
26 | "from tqdm import tqdm\n",
27 | "from torch.optim import SGD"
28 | ]
29 | },
30 | {
31 | "cell_type": "markdown",
32 | "id": "1a414a26-8c98-4eef-b97e-5d1a47df5b67",
33 | "metadata": {},
34 | "source": [
35 | "# Read CSV Data"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": 3,
41 | "id": "e5536782-4d9a-4c35-9d22-7da36a08911a",
42 | "metadata": {},
43 | "outputs": [
44 | {
45 | "data": {
46 | "text/html": [
47 | "\n",
48 | "\n",
61 | "
\n",
62 | " \n",
63 | " \n",
64 | " | \n",
65 | " text | \n",
66 | " labels | \n",
67 | "
\n",
68 | " \n",
69 | " \n",
70 | " \n",
71 | " | 0 | \n",
72 | " Thousands of demonstrators have marched throug... | \n",
73 | " O O O O O O B-geo O O O O O B-geo O O O O O B-... | \n",
74 | "
\n",
75 | " \n",
76 | " | 1 | \n",
77 | " Iranian officials say they expect to get acces... | \n",
78 | " B-gpe O O O O O O O O O O O O O O B-tim O O O ... | \n",
79 | "
\n",
80 | " \n",
81 | " | 2 | \n",
82 | " Helicopter gunships Saturday pounded militant ... | \n",
83 | " O O B-tim O O O O O B-geo O O O O O B-org O O ... | \n",
84 | "
\n",
85 | " \n",
86 | " | 3 | \n",
87 | " They left after a tense hour-long standoff wit... | \n",
88 | " O O O O O O O O O O O | \n",
89 | "
\n",
90 | " \n",
91 | " | 4 | \n",
92 | " U.N. relief coordinator Jan Egeland said Sunda... | \n",
93 | " B-geo O O B-per I-per O B-tim O B-geo O B-gpe ... | \n",
94 | "
\n",
95 | " \n",
96 | "
\n",
97 | "
"
98 | ],
99 | "text/plain": [
100 | " text \\\n",
101 | "0 Thousands of demonstrators have marched throug... \n",
102 | "1 Iranian officials say they expect to get acces... \n",
103 | "2 Helicopter gunships Saturday pounded militant ... \n",
104 | "3 They left after a tense hour-long standoff wit... \n",
105 | "4 U.N. relief coordinator Jan Egeland said Sunda... \n",
106 | "\n",
107 | " labels \n",
108 | "0 O O O O O O B-geo O O O O O B-geo O O O O O B-... \n",
109 | "1 B-gpe O O O O O O O O O O O O O O B-tim O O O ... \n",
110 | "2 O O B-tim O O O O O B-geo O O O O O B-org O O ... \n",
111 | "3 O O O O O O O O O O O \n",
112 | "4 B-geo O O B-per I-per O B-tim O B-geo O B-gpe ... "
113 | ]
114 | },
115 | "execution_count": 3,
116 | "metadata": {},
117 | "output_type": "execute_result"
118 | }
119 | ],
120 | "source": [
121 | "df = pd.read_csv('ner.csv')\n",
122 | "df.head()"
123 | ]
124 | },
125 | {
126 | "cell_type": "markdown",
127 | "id": "c2b1bd34-8843-4706-baa5-201f31245183",
128 | "metadata": {},
129 | "source": [
130 | "# Initialize Tokenizer"
131 | ]
132 | },
133 | {
134 | "cell_type": "code",
135 | "execution_count": null,
136 | "id": "d41be369-ee57-4949-aeb8-7960746d2aea",
137 | "metadata": {},
138 | "outputs": [],
139 | "source": [
140 | "tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')"
141 | ]
142 | },
143 | {
144 | "cell_type": "markdown",
145 | "id": "0e3439c2-580e-4972-8603-3a00bc3be62d",
146 | "metadata": {},
147 | "source": [
148 | "# Create Dataset Class "
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": null,
154 | "id": "ac7f0682-ea50-4aeb-bcd3-9230df735554",
155 | "metadata": {},
156 | "outputs": [],
157 | "source": [
158 | "label_all_tokens = False\n",
159 | "\n",
160 | "def align_label(texts, labels):\n",
161 | " tokenized_inputs = tokenizer(texts, padding='max_length', max_length=512, truncation=True)\n",
162 | "\n",
163 | " word_ids = tokenized_inputs.word_ids()\n",
164 | "\n",
165 | " previous_word_idx = None\n",
166 | " label_ids = []\n",
167 | "\n",
168 | " for word_idx in word_ids:\n",
169 | "\n",
170 | " if word_idx is None:\n",
171 | " label_ids.append(-100)\n",
172 | "\n",
173 | " elif word_idx != previous_word_idx:\n",
174 | " try:\n",
175 | " label_ids.append(labels_to_ids[labels[word_idx]])\n",
176 | " except:\n",
177 | " label_ids.append(-100)\n",
178 | " else:\n",
179 | " try:\n",
180 | " label_ids.append(labels_to_ids[labels[word_idx]] if label_all_tokens else -100)\n",
181 | " except:\n",
182 | " label_ids.append(-100)\n",
183 | " previous_word_idx = word_idx\n",
184 | "\n",
185 | " return label_ids\n",
186 | "\n",
187 | "class DataSequence(torch.utils.data.Dataset):\n",
188 | "\n",
189 | " def __init__(self, df):\n",
190 | "\n",
191 | " lb = [i.split() for i in df['labels'].values.tolist()]\n",
192 | " txt = df['text'].values.tolist()\n",
193 | " self.texts = [tokenizer(str(i),\n",
194 | " padding='max_length', max_length = 512, truncation=True, return_tensors=\"pt\") for i in txt]\n",
195 | " self.labels = [align_label(i,j) for i,j in zip(txt, lb)]\n",
196 | "\n",
197 | " def __len__(self):\n",
198 | "\n",
199 | " return len(self.labels)\n",
200 | "\n",
201 | " def get_batch_data(self, idx):\n",
202 | "\n",
203 | " return self.texts[idx]\n",
204 | "\n",
205 | " def get_batch_labels(self, idx):\n",
206 | "\n",
207 | " return torch.LongTensor(self.labels[idx])\n",
208 | "\n",
209 | " def __getitem__(self, idx):\n",
210 | "\n",
211 | " batch_data = self.get_batch_data(idx)\n",
212 | " batch_labels = self.get_batch_labels(idx)\n",
213 | "\n",
214 | " return batch_data, batch_labels"
215 | ]
216 | },
217 | {
218 | "cell_type": "markdown",
219 | "id": "496b3cb5-24c8-4c1b-8382-7c4d0f2339a4",
220 | "metadata": {},
221 | "source": [
222 | "# Split Data and Define Unique Labels"
223 | ]
224 | },
225 | {
226 | "cell_type": "code",
227 | "execution_count": null,
228 | "id": "6599961c-1cda-47bf-8c82-a1f9ebc94a95",
229 | "metadata": {},
230 | "outputs": [],
231 | "source": [
232 | "df = df[0:1000]\n",
233 | "\n",
234 | "labels = [i.split() for i in df['labels'].values.tolist()]\n",
235 | "unique_labels = set()\n",
236 | "\n",
237 | "for lb in labels:\n",
238 | " [unique_labels.add(i) for i in lb if i not in unique_labels]\n",
239 | "labels_to_ids = {k: v for v, k in enumerate(unique_labels)}\n",
240 | "ids_to_labels = {v: k for v, k in enumerate(unique_labels)}\n",
241 | "\n",
242 | "df_train, df_val, df_test = np.split(df.sample(frac=1, random_state=42),\n",
243 | " [int(.8 * len(df)), int(.9 * len(df))])"
244 | ]
245 | },
246 | {
247 | "cell_type": "markdown",
248 | "id": "d54b96c5-6875-4990-9248-5d6ad5b053e9",
249 | "metadata": {},
250 | "source": [
251 | "# Build Model"
252 | ]
253 | },
254 | {
255 | "cell_type": "code",
256 | "execution_count": null,
257 | "id": "13ebfa5e-c91a-4967-b0cc-23e314c32348",
258 | "metadata": {},
259 | "outputs": [],
260 | "source": [
261 | "class BertModel(torch.nn.Module):\n",
262 | "\n",
263 | " def __init__(self):\n",
264 | "\n",
265 | " super(BertModel, self).__init__()\n",
266 | "\n",
267 | " self.bert = BertForTokenClassification.from_pretrained('bert-base-cased', num_labels=len(unique_labels))\n",
268 | "\n",
269 | " def forward(self, input_id, mask, label):\n",
270 | "\n",
271 | " output = self.bert(input_ids=input_id, attention_mask=mask, labels=label, return_dict=False)\n",
272 | "\n",
273 | " return output"
274 | ]
275 | },
276 | {
277 | "cell_type": "markdown",
278 | "id": "c3a48d06-d343-449b-829d-2bcad4b2af52",
279 | "metadata": {},
280 | "source": [
281 | "# Model Training"
282 | ]
283 | },
284 | {
285 | "cell_type": "code",
286 | "execution_count": null,
287 | "id": "291bfdad-2df3-4de3-954d-b7e7a9a1b253",
288 | "metadata": {},
289 | "outputs": [],
290 | "source": [
291 | "def train_loop(model, df_train, df_val):\n",
292 | "\n",
293 | " train_dataset = DataSequence(df_train)\n",
294 | " val_dataset = DataSequence(df_val)\n",
295 | "\n",
296 | " train_dataloader = DataLoader(train_dataset, num_workers=4, batch_size=BATCH_SIZE, shuffle=True)\n",
297 | " val_dataloader = DataLoader(val_dataset, num_workers=4, batch_size=BATCH_SIZE)\n",
298 | "\n",
299 | " use_cuda = torch.cuda.is_available()\n",
300 | " device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
301 | "\n",
302 | " optimizer = SGD(model.parameters(), lr=LEARNING_RATE)\n",
303 | "\n",
304 | " if use_cuda:\n",
305 | " model = model.cuda()\n",
306 | "\n",
307 | " best_acc = 0\n",
308 | " best_loss = 1000\n",
309 | "\n",
310 | " for epoch_num in range(EPOCHS):\n",
311 | "\n",
312 | " total_acc_train = 0\n",
313 | " total_loss_train = 0\n",
314 | "\n",
315 | " model.train()\n",
316 | "\n",
317 | " for train_data, train_label in tqdm(train_dataloader):\n",
318 | "\n",
319 | " train_label = train_label.to(device)\n",
320 | " mask = train_data['attention_mask'].squeeze(1).to(device)\n",
321 | " input_id = train_data['input_ids'].squeeze(1).to(device)\n",
322 | "\n",
323 | " optimizer.zero_grad()\n",
324 | " loss, logits = model(input_id, mask, train_label)\n",
325 | "\n",
326 | " for i in range(logits.shape[0]):\n",
327 | "\n",
328 | " logits_clean = logits[i][train_label[i] != -100]\n",
329 | " label_clean = train_label[i][train_label[i] != -100]\n",
330 | "\n",
331 | " predictions = logits_clean.argmax(dim=1)\n",
332 | " acc = (predictions == label_clean).float().mean()\n",
333 | " total_acc_train += acc\n",
334 | " total_loss_train += loss.item()\n",
335 | "\n",
336 | " loss.backward()\n",
337 | " optimizer.step()\n",
338 | "\n",
339 | " model.eval()\n",
340 | "\n",
341 | " total_acc_val = 0\n",
342 | " total_loss_val = 0\n",
343 | "\n",
344 | " for val_data, val_label in val_dataloader:\n",
345 | "\n",
346 | " val_label = val_label.to(device)\n",
347 | " mask = val_data['attention_mask'].squeeze(1).to(device)\n",
348 | " input_id = val_data['input_ids'].squeeze(1).to(device)\n",
349 | "\n",
350 | " loss, logits = model(input_id, mask, val_label)\n",
351 | "\n",
352 | " for i in range(logits.shape[0]):\n",
353 | "\n",
354 | " logits_clean = logits[i][val_label[i] != -100]\n",
355 | " label_clean = val_label[i][val_label[i] != -100]\n",
356 | "\n",
357 | " predictions = logits_clean.argmax(dim=1)\n",
358 | " acc = (predictions == label_clean).float().mean()\n",
359 | " total_acc_val += acc\n",
360 | " total_loss_val += loss.item()\n",
361 | "\n",
362 | " val_accuracy = total_acc_val / len(df_val)\n",
363 | " val_loss = total_loss_val / len(df_val)\n",
364 | "\n",
365 | " print(\n",
366 | " f'Epochs: {epoch_num + 1} | Loss: {total_loss_train / len(df_train): .3f} | Accuracy: {total_acc_train / len(df_train): .3f} | Val_Loss: {total_loss_val / len(df_val): .3f} | Accuracy: {total_acc_val / len(df_val): .3f}')\n",
367 | "\n",
368 | "LEARNING_RATE = 5e-3\n",
369 | "EPOCHS = 5\n",
370 | "BATCH_SIZE = 2\n",
371 | "\n",
372 | "model = BertModel()\n",
373 | "train_loop(model, df_train, df_val)"
374 | ]
375 | },
376 | {
377 | "cell_type": "markdown",
378 | "id": "69e1af60-33c3-497e-984f-0094f9bc3a4f",
379 | "metadata": {},
380 | "source": [
381 | "# Evaluate Model"
382 | ]
383 | },
384 | {
385 | "cell_type": "code",
386 | "execution_count": null,
387 | "id": "04295796-7033-4bdf-849f-95e030fc94aa",
388 | "metadata": {},
389 | "outputs": [],
390 | "source": [
391 | "def evaluate(model, df_test):\n",
392 | "\n",
393 | " test_dataset = DataSequence(df_test)\n",
394 | "\n",
395 | " test_dataloader = DataLoader(test_dataset, num_workers=4, batch_size=1)\n",
396 | "\n",
397 | " use_cuda = torch.cuda.is_available()\n",
398 | " device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
399 | "\n",
400 | " if use_cuda:\n",
401 | " model = model.cuda()\n",
402 | "\n",
403 | " total_acc_test = 0.0\n",
404 | "\n",
405 | " for test_data, test_label in test_dataloader:\n",
406 | "\n",
407 | " test_label = test_label.to(device)\n",
408 | " mask = test_data['attention_mask'].squeeze(1).to(device)\n",
409 | "\n",
410 | " input_id = test_data['input_ids'].squeeze(1).to(device)\n",
411 | "\n",
412 | " loss, logits = model(input_id, mask, test_label)\n",
413 | "\n",
414 | " for i in range(logits.shape[0]):\n",
415 | "\n",
416 | " logits_clean = logits[i][test_label[i] != -100]\n",
417 | " label_clean = test_label[i][test_label[i] != -100]\n",
418 | "\n",
419 | " predictions = logits_clean.argmax(dim=1)\n",
420 | " acc = (predictions == label_clean).float().mean()\n",
421 | " total_acc_test += acc\n",
422 | "\n",
423 | " val_accuracy = total_acc_test / len(df_test)\n",
424 | " print(f'Test Accuracy: {total_acc_test / len(df_test): .3f}')\n",
425 | "\n",
426 | "\n",
427 | "evaluate(model, df_test)"
428 | ]
429 | },
430 | {
431 | "cell_type": "markdown",
432 | "id": "12edbaf3-cd39-463f-b2ba-d0ce46c246bd",
433 | "metadata": {},
434 | "source": [
435 | "# Predict One Sentence"
436 | ]
437 | },
438 | {
439 | "cell_type": "code",
440 | "execution_count": null,
441 | "id": "99bc3835-a075-4fce-812b-7b9e96778816",
442 | "metadata": {},
443 | "outputs": [],
444 | "source": [
445 | "def align_word_ids(texts):\n",
446 | " \n",
447 | " tokenized_inputs = tokenizer(texts, padding='max_length', max_length=512, truncation=True)\n",
448 | "\n",
449 | " word_ids = tokenized_inputs.word_ids()\n",
450 | "\n",
451 | " previous_word_idx = None\n",
452 | " label_ids = []\n",
453 | "\n",
454 | " for word_idx in word_ids:\n",
455 | "\n",
456 | " if word_idx is None:\n",
457 | " label_ids.append(-100)\n",
458 | "\n",
459 | " elif word_idx != previous_word_idx:\n",
460 | " try:\n",
461 | " label_ids.append(1)\n",
462 | " except:\n",
463 | " label_ids.append(-100)\n",
464 | " else:\n",
465 | " try:\n",
466 | " label_ids.append(1 if label_all_tokens else -100)\n",
467 | " except:\n",
468 | " label_ids.append(-100)\n",
469 | " previous_word_idx = word_idx\n",
470 | "\n",
471 | " return label_ids\n",
472 | "\n",
473 | "\n",
474 | "def evaluate_one_text(model, sentence):\n",
475 | "\n",
476 | "\n",
477 | " use_cuda = torch.cuda.is_available()\n",
478 | " device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
479 | "\n",
480 | " if use_cuda:\n",
481 | " model = model.cuda()\n",
482 | "\n",
483 | " text = tokenizer(sentence, padding='max_length', max_length = 512, truncation=True, return_tensors=\"pt\")\n",
484 | "\n",
485 | " mask = text['attention_mask'].to(device)\n",
486 | " input_id = text['input_ids'].to(device)\n",
487 | " label_ids = torch.Tensor(align_word_ids(sentence)).unsqueeze(0).to(device)\n",
488 | "\n",
489 | " logits = model(input_id, mask, None)\n",
490 | " logits_clean = logits[0][label_ids != -100]\n",
491 | "\n",
492 | " predictions = logits_clean.argmax(dim=1).tolist()\n",
493 | " prediction_label = [ids_to_labels[i] for i in predictions]\n",
494 | " print(sentence)\n",
495 | " print(prediction_label)\n",
496 | " \n",
497 | "evaluate_one_text(model, 'Bill Gates is the founder of Microsoft')"
498 | ]
499 | },
500 | {
501 | "cell_type": "code",
502 | "execution_count": null,
503 | "id": "fa186d96-7e3c-4457-a3ab-cabc61f2d261",
504 | "metadata": {},
505 | "outputs": [],
506 | "source": []
507 | },
508 | {
509 | "cell_type": "code",
510 | "execution_count": null,
511 | "id": "723f31f2-12d7-48cc-9f88-1c8fbe860f4c",
512 | "metadata": {},
513 | "outputs": [],
514 | "source": []
515 | }
516 | ],
517 | "metadata": {
518 | "kernelspec": {
519 | "display_name": "Python 3 (ipykernel)",
520 | "language": "python",
521 | "name": "python3"
522 | },
523 | "language_info": {
524 | "codemirror_mode": {
525 | "name": "ipython",
526 | "version": 3
527 | },
528 | "file_extension": ".py",
529 | "mimetype": "text/x-python",
530 | "name": "python",
531 | "nbconvert_exporter": "python",
532 | "pygments_lexer": "ipython3",
533 | "version": "3.9.7"
534 | }
535 | },
536 | "nbformat": 4,
537 | "nbformat_minor": 5
538 | }
539 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Source code for Medium articles
2 |
--------------------------------------------------------------------------------
/STS_BERT/STS_BERT.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "28df765c-ec6b-450b-a8b4-b40c65f73159",
6 | "metadata": {},
7 | "source": [
8 | "# Install necessary libraries"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": null,
14 | "id": "bc873328-78e8-4e81-bb0b-fc1ba41bcb82",
15 | "metadata": {
16 | "tags": []
17 | },
18 | "outputs": [],
19 | "source": [
20 | "%%capture\n",
21 | "\n",
22 | "!pip install datasets\n",
23 | "!pip install sentence-transformers\n",
24 | "!pip install transformers"
25 | ]
26 | },
27 | {
28 | "cell_type": "markdown",
29 | "id": "002b8bb8-e806-48a8-ab70-8c8f35d13466",
30 | "metadata": {},
31 | "source": [
32 | "# Import libraries"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": null,
38 | "id": "d8fe01b5-6a9a-458d-a97f-1fb20b621b0f",
39 | "metadata": {},
40 | "outputs": [],
41 | "source": [
42 | "import torch\n",
43 | "from sentence_transformers import SentenceTransformer, models\n",
44 | "from transformers import BertTokenizer\n",
45 | "from torch.optim import Adam\n",
46 | "from torch.utils.data import DataLoader\n",
47 | "from tqdm import tqdm\n",
48 | "from datasets import load_dataset"
49 | ]
50 | },
51 | {
52 | "cell_type": "markdown",
53 | "id": "0bd9f3e6-9625-4c64-bb76-5e90911239a6",
54 | "metadata": {},
55 | "source": [
56 | "# Fetch data for training and test, as well as the tokenizer"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "id": "0fc07b38-8951-4d47-a5e1-ecdf810cdd93",
63 | "metadata": {},
64 | "outputs": [],
65 | "source": [
66 | "# Dataset for training\n",
67 | "dataset = load_dataset(\"stsb_multi_mt\", name=\"en\", split=\"train\")\n",
68 | "similarity = [i['similarity_score'] for i in dataset]\n",
69 | "normalized_similarity = [i/5.0 for i in similarity]\n",
70 | "\n",
71 | "# Dataset for test\n",
72 | "test_dataset = load_dataset(\"stsb_multi_mt\", name=\"en\", split=\"test\")\n",
73 | "\n",
74 | "# Prepare test data\n",
75 | "sentence_1_test = [i['sentence1'] for i in test_dataset]\n",
76 | "sentence_2_test = [i['sentence2'] for i in test_dataset]\n",
77 | "text_cat_test = [[str(x), str(y)] for x,y in zip(sentence_1_test, sentence_2_test)]\n",
78 | "\n",
79 | "# Set the tokenizer\n",
80 | "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')"
81 | ]
82 | },
83 | {
84 | "cell_type": "markdown",
85 | "id": "f5299ca8-d00d-4151-8390-86c97403785d",
86 | "metadata": {},
87 | "source": [
88 | "# Define Model architecture"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": null,
94 | "id": "8f44b7c4-d15c-4b79-bc5e-371d5bd42ffe",
95 | "metadata": {},
96 | "outputs": [],
97 | "source": [
98 | "class STSBertModel(torch.nn.Module):\n",
99 | "\n",
100 | " def __init__(self):\n",
101 | "\n",
102 | " super(STSBertModel, self).__init__()\n",
103 | "\n",
104 | " word_embedding_model = models.Transformer('bert-base-uncased', max_seq_length=128)\n",
105 | " pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())\n",
106 | " self.sts_model = SentenceTransformer(modules=[word_embedding_model, pooling_model])\n",
107 | "\n",
108 | " def forward(self, input_data):\n",
109 | "\n",
110 | " output = self.sts_model(input_data)\n",
111 | " \n",
112 | " return output"
113 | ]
114 | },
115 | {
116 | "cell_type": "markdown",
117 | "id": "3103d5bd-e4f3-4b98-a1c6-cfc354f9edca",
118 | "metadata": {},
119 | "source": [
120 | "# Define Dataloader for training"
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": null,
126 | "id": "98ae2242-af8f-4c76-a21a-4baeeed3ae43",
127 | "metadata": {},
128 | "outputs": [],
129 | "source": [
130 | "class DataSequence(torch.utils.data.Dataset):\n",
131 | "\n",
132 | " def __init__(self, dataset):\n",
133 | "\n",
134 | " similarity = [i['similarity_score'] for i in dataset]\n",
135 | " self.label = [i/5.0 for i in similarity]\n",
136 | " self.sentence_1 = [i['sentence1'] for i in dataset]\n",
137 | " self.sentence_2 = [i['sentence2'] for i in dataset]\n",
138 | " self.text_cat = [[str(x), str(y)] for x,y in zip(self.sentence_1, self.sentence_2)]\n",
139 | "\n",
140 | " def __len__(self):\n",
141 | "\n",
142 | " return len(self.text_cat)\n",
143 | "\n",
144 | " def get_batch_labels(self, idx):\n",
145 | "\n",
146 | " return torch.tensor(self.label[idx])\n",
147 | "\n",
148 | " def get_batch_texts(self, idx):\n",
149 | "\n",
150 | " return tokenizer(self.text_cat[idx], padding='max_length', max_length = 128, truncation=True, return_tensors=\"pt\")\n",
151 | "\n",
152 | " def __getitem__(self, idx):\n",
153 | "\n",
154 | " batch_texts = self.get_batch_texts(idx)\n",
155 | " batch_y = self.get_batch_labels(idx)\n",
156 | "\n",
157 | " return batch_texts, batch_y\n",
158 | "\n",
159 | "def collate_fn(texts):\n",
160 | "\n",
161 | " num_texts = len(texts['input_ids'])\n",
162 | " features = list()\n",
163 | " for i in range(num_texts):\n",
164 | " features.append({'input_ids':texts['input_ids'][i], 'attention_mask':texts['attention_mask'][i]})\n",
165 | " \n",
166 | " return features"
167 | ]
168 | },
169 | {
170 | "cell_type": "markdown",
171 | "id": "3fd81992-4a05-4f48-9869-83b38b7a3f90",
172 | "metadata": {},
173 | "source": [
174 | "# Define loss function for training"
175 | ]
176 | },
177 | {
178 | "cell_type": "code",
179 | "execution_count": null,
180 | "id": "78ca0312-8648-4101-9429-7286d6268bb5",
181 | "metadata": {},
182 | "outputs": [],
183 | "source": [
184 | "class CosineSimilarityLoss(torch.nn.Module):\n",
185 | "\n",
186 | " def __init__(self, loss_fct = torch.nn.MSELoss(), cos_score_transformation=torch.nn.Identity()):\n",
187 | " \n",
188 | " super(CosineSimilarityLoss, self).__init__()\n",
189 | " self.loss_fct = loss_fct\n",
190 | " self.cos_score_transformation = cos_score_transformation\n",
191 | " self.cos = torch.nn.CosineSimilarity(dim=1)\n",
192 | "\n",
193 | " def forward(self, input, label):\n",
194 | "\n",
195 | " embedding_1 = torch.stack([inp[0] for inp in input])\n",
196 | " embedding_2 = torch.stack([inp[1] for inp in input])\n",
197 | "\n",
198 | " output = self.cos_score_transformation(self.cos(embedding_1, embedding_2))\n",
199 | "\n",
200 | " return self.loss_fct(output, label.squeeze())"
201 | ]
202 | },
203 | {
204 | "cell_type": "markdown",
205 | "id": "0fc17726-80e4-409d-8183-1ec17d2e05da",
206 | "metadata": {},
207 | "source": [
208 | "# Train the Model"
209 | ]
210 | },
211 | {
212 | "cell_type": "code",
213 | "execution_count": null,
214 | "id": "bcf41a99-3ce1-4125-b464-d3b8d0d295af",
215 | "metadata": {},
216 | "outputs": [],
217 | "source": [
218 | "def model_train(dataset, epochs, learning_rate, bs):\n",
219 | "\n",
220 | " use_cuda = torch.cuda.is_available()\n",
221 | " device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
222 | "\n",
223 | " model = STSBertModel()\n",
224 | "\n",
225 | " criterion = CosineSimilarityLoss()\n",
226 | " optimizer = Adam(model.parameters(), lr=learning_rate)\n",
227 | "\n",
228 | " train_dataset = DataSequence(dataset)\n",
229 | " train_dataloader = DataLoader(train_dataset, num_workers=4, batch_size=bs, shuffle=True)\n",
230 | "\n",
231 | " if use_cuda:\n",
232 | " model = model.cuda()\n",
233 | " criterion = criterion.cuda()\n",
234 | "\n",
235 | " best_acc = 0.0\n",
236 | " best_loss = 1000\n",
237 | "\n",
238 | " for i in range(epochs):\n",
239 | "\n",
240 | " total_acc_train = 0\n",
241 | " total_loss_train = 0.0\n",
242 | "\n",
243 | " for train_data, train_label in tqdm(train_dataloader):\n",
244 | "\n",
245 | " train_data['input_ids'] = train_data['input_ids'].to(device)\n",
246 | " train_data['attention_mask'] = train_data['attention_mask'].to(device)\n",
247 | " del train_data['token_type_ids']\n",
248 | "\n",
249 | " train_data = collate_fn(train_data)\n",
250 | "\n",
251 | " output = [model(feature)['sentence_embedding'] for feature in train_data]\n",
252 | "\n",
253 | " loss = criterion(output, train_label.to(device))\n",
254 | " total_loss_train += loss.item()\n",
255 | "\n",
256 | " loss.backward()\n",
257 | " optimizer.step()\n",
258 | " optimizer.zero_grad()\n",
259 | "\n",
260 | " print(f'Epochs: {i + 1} | Loss: {total_loss_train / len(dataset): .3f}')\n",
261 | " model.train()\n",
262 | "\n",
263 | " return model\n",
264 | "\n",
265 | "EPOCHS = 8\n",
266 | "LEARNING_RATE = 1e-6\n",
267 | "BATCH_SIZE = 8\n",
268 | "\n",
269 | "# Train the model\n",
270 | "trained_model = model_train(dataset, EPOCHS, LEARNING_RATE, BATCH_SIZE)"
271 | ]
272 | },
273 | {
274 | "cell_type": "code",
275 | "execution_count": null,
276 | "id": "f8052fa6-8a87-4378-87c1-1714f55a88bc",
277 | "metadata": {},
278 | "outputs": [],
279 | "source": [
280 | "# Function to predict test data\n",
281 | "def predict_sts(texts):\n",
282 | "\n",
283 | " trained_model.to('cpu')\n",
284 | " trained_model.eval()\n",
285 | " test_input = tokenizer(texts, padding='max_length', max_length = 128, truncation=True, return_tensors=\"pt\")\n",
286 | " test_input['input_ids'] = test_input['input_ids']\n",
287 | " test_input['attention_mask'] = test_input['attention_mask']\n",
288 | " del test_input['token_type_ids']\n",
289 | "\n",
290 | " test_output = trained_model(test_input)['sentence_embedding']\n",
291 | " sim = torch.nn.functional.cosine_similarity(test_output[0], test_output[1], dim=0).item()\n",
292 | "\n",
293 | " return sim"
294 | ]
295 | },
296 | {
297 | "cell_type": "markdown",
298 | "id": "d4f70635-7047-46b5-b313-e2e6963ffdab",
299 | "metadata": {},
300 | "source": [
301 | "# Predict on test data"
302 | ]
303 | },
304 | {
305 | "cell_type": "code",
306 | "execution_count": null,
307 | "id": "e4fde3ae-db23-4135-aa04-4a79d040b089",
308 | "metadata": {},
309 | "outputs": [],
310 | "source": [
311 | "predict_sts(text_cat_test[245])"
312 | ]
313 | },
314 | {
315 | "cell_type": "code",
316 | "execution_count": null,
317 | "id": "f757ad3e-3ea6-4706-a4f3-4a090fa6dbd2",
318 | "metadata": {},
319 | "outputs": [],
320 | "source": [
321 | "predict_sts(text_cat_test[420])"
322 | ]
323 | }
324 | ],
325 | "metadata": {
326 | "kernelspec": {
327 | "display_name": "Python 3 (ipykernel)",
328 | "language": "python",
329 | "name": "python3"
330 | },
331 | "language_info": {
332 | "codemirror_mode": {
333 | "name": "ipython",
334 | "version": 3
335 | },
336 | "file_extension": ".py",
337 | "mimetype": "text/x-python",
338 | "name": "python",
339 | "nbconvert_exporter": "python",
340 | "pygments_lexer": "ipython3",
341 | "version": "3.9.7"
342 | }
343 | },
344 | "nbformat": 4,
345 | "nbformat_minor": 5
346 | }
347 |
--------------------------------------------------------------------------------
/Spaces_Translation_App/app.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3 | import nltk
4 | from nltk import tokenize
5 | nltk.download('punkt')
6 |
7 | tokenizer = AutoTokenizer.from_pretrained("t5-base")
8 |
9 | @st.cache(allow_output_mutation=True)
10 | def load_model():
11 | model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
12 |
13 | return model
14 |
15 | model = load_model()
16 |
17 | st.sidebar.subheader('Select your source and target language below.')
18 | source_lang = st.sidebar.selectbox("Source language",['English'])
19 | target_lang = st.sidebar.selectbox("Target language",['German','French'])
20 |
21 | st.title('Simple English ➡️ German/French Translation App')
22 |
23 | st.write('This is a simple machine translation app that will translate\
24 | your English input text into German or French language\
25 | by leveraging a pre-trained [Text-To-Text Transfer Transformer](https://arxiv.org/abs/1910.10683) model.')
26 |
27 | st.write('You can see the source code to build this app in the \'Files and version\' tab.')
28 |
29 | st.subheader('Input Text')
30 | text = st.text_area(' ', height=200)
31 |
32 | if text != '':
33 |
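34 | # T5 expects a task prefix of the form 'translate English to German: '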
34 | prefix = 'translate '+str(source_lang)+' to '+str(target_lang)+': '
35 | sentence_token = tokenize.sent_tokenize(text)
36 | output = tokenizer([prefix+sentence for sentence in sentence_token], padding=True, return_tensors="pt")
37 | translated_id = model.generate(output["input_ids"], attention_mask=output['attention_mask'], max_length=100)
38 | translated_word = tokenizer.batch_decode(translated_id, skip_special_tokens=True)
39 |
40 | st.subheader('Translated Text')
41 | st.write(' '.join(translated_word))
42 |
--------------------------------------------------------------------------------
/Text_Classification_BERT/bert_medium.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "6fed133d-61b7-4ce6-8a44-fe98acf0eed2",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "%%capture\n",
11 | "!pip install transformers"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": null,
17 | "id": "2f3a8fd1-25a9-426d-a6be-c93b750cbcb8",
18 | "metadata": {},
19 | "outputs": [],
20 | "source": [
21 | "import pandas as pd\n",
22 | "import torch\n",
23 | "import numpy as np\n",
24 | "from transformers import BertTokenizer, BertModel\n",
25 | "from torch import nn\n",
26 | "from torch.optim import Adam\n",
27 | "from tqdm import tqdm"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": null,
33 | "id": "47a53036-31ab-4374-bf15-a4dca17a7cbf",
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "datapath = f'/content/drive/My Drive/Medium/bbc-text.csv'\n",
38 | "df = pd.read_csv(datapath)\n",
39 | "df.head()"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": null,
45 | "id": "ab965eff-e1eb-416f-b80c-850554d8026c",
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "df.groupby(['category']).size().plot.bar()"
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": null,
55 | "id": "5074c270-ed3e-4e1a-863d-71737c743cb8",
56 | "metadata": {},
57 | "outputs": [],
58 | "source": [
59 | "tokenizer = BertTokenizer.from_pretrained('bert-base-cased')\n",
60 | "labels = {'business':0,\n",
61 | " 'entertainment':1,\n",
62 | " 'sport':2,\n",
63 | " 'tech':3,\n",
64 | " 'politics':4\n",
65 | " }\n",
66 | "\n",
67 | "class Dataset(torch.utils.data.Dataset):\n",
68 | "\n",
69 | " def __init__(self, df):\n",
70 | "\n",
71 | " self.labels = [labels[label] for label in df['category']]\n",
72 | " self.texts = [tokenizer(text, \n",
73 | " padding='max_length', max_length = 512, truncation=True,\n",
74 | " return_tensors=\"pt\") for text in df['text']]\n",
75 | "\n",
76 | " def classes(self):\n",
77 | " return self.labels\n",
78 | "\n",
79 | " def __len__(self):\n",
80 | " return len(self.labels)\n",
81 | "\n",
82 | " def get_batch_labels(self, idx):\n",
83 | " # Fetch a batch of labels\n",
84 | " return np.array(self.labels[idx])\n",
85 | "\n",
86 | " def get_batch_texts(self, idx):\n",
87 | " # Fetch a batch of inputs\n",
88 | " return self.texts[idx]\n",
89 | "\n",
90 | " def __getitem__(self, idx):\n",
91 | "\n",
92 | " batch_texts = self.get_batch_texts(idx)\n",
93 | " batch_y = self.get_batch_labels(idx)\n",
94 | "\n",
95 | " return batch_texts, batch_y"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": null,
101 | "id": "0c8a5d0f-80c3-42b3-9f06-ecfc3a21f395",
102 | "metadata": {},
103 | "outputs": [],
104 | "source": [
105 | "class BertClassifier(nn.Module):\n",
106 | "\n",
107 | " def __init__(self, dropout=0.5):\n",
108 | "\n",
109 | " super(BertClassifier, self).__init__()\n",
110 | "\n",
111 | " self.bert = BertModel.from_pretrained('bert-base-cased')\n",
112 | " self.dropout = nn.Dropout(dropout)\n",
113 | " self.linear = nn.Linear(768, 5)\n",
114 | " self.relu = nn.ReLU()\n",
115 | "\n",
116 | " def forward(self, input_id, mask):\n",
117 | "\n",
118 | " _, pooled_output = self.bert(input_ids= input_id, attention_mask=mask,return_dict=False)\n",
119 | " dropout_output = self.dropout(pooled_output)\n",
120 | " linear_output = self.linear(dropout_output)\n",
121 | " final_layer = self.relu(linear_output)\n",
122 | "\n",
123 | " return final_layer"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": null,
129 | "id": "fa1f1cf7-65db-4966-9a55-ba26bd22ed6c",
130 | "metadata": {},
131 | "outputs": [],
132 | "source": [
133 | "def train(model, train_data, val_data, learning_rate, epochs):\n",
134 | "\n",
135 | " train, val = Dataset(train_data), Dataset(val_data)\n",
136 | "\n",
137 | " train_dataloader = torch.utils.data.DataLoader(train, batch_size=2, shuffle=True)\n",
138 | " val_dataloader = torch.utils.data.DataLoader(val, batch_size=2)\n",
139 | "\n",
140 | " use_cuda = torch.cuda.is_available()\n",
141 | " device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
142 | "\n",
143 | " criterion = nn.CrossEntropyLoss()\n",
144 | " optimizer = Adam(model.parameters(), lr= learning_rate)\n",
145 | "\n",
146 | " if use_cuda:\n",
147 | "\n",
148 | " model = model.cuda()\n",
149 | " criterion = criterion.cuda()\n",
150 | "\n",
151 | " for epoch_num in range(epochs):\n",
152 | "\n",
153 | " total_acc_train = 0\n",
154 | " total_loss_train = 0\n",
155 | "\n",
156 | " for train_input, train_label in tqdm(train_dataloader):\n",
157 | "\n",
158 | " train_label = train_label.to(device)\n",
159 | " mask = train_input['attention_mask'].to(device)\n",
160 | " input_id = train_input['input_ids'].squeeze(1).to(device)\n",
161 | "\n",
162 | " output = model(input_id, mask)\n",
163 | " \n",
164 | " batch_loss = criterion(output, train_label.long())\n",
165 | " total_loss_train += batch_loss.item()\n",
166 | " \n",
167 | " acc = (output.argmax(dim=1) == train_label).sum().item()\n",
168 | " total_acc_train += acc\n",
169 | "\n",
170 | " model.zero_grad()\n",
171 | " batch_loss.backward()\n",
172 | " optimizer.step()\n",
173 | " \n",
174 | " total_acc_val = 0\n",
175 | " total_loss_val = 0\n",
176 | "\n",
177 | " with torch.no_grad():\n",
178 | "\n",
179 | " for val_input, val_label in val_dataloader:\n",
180 | "\n",
181 | " val_label = val_label.to(device)\n",
182 | " mask = val_input['attention_mask'].to(device)\n",
183 | " input_id = val_input['input_ids'].squeeze(1).to(device)\n",
184 | "\n",
185 | " output = model(input_id, mask)\n",
186 | "\n",
187 | " batch_loss = criterion(output, val_label.long())\n",
188 | " total_loss_val += batch_loss.item()\n",
189 | " \n",
190 | " acc = (output.argmax(dim=1) == val_label).sum().item()\n",
191 | " total_acc_val += acc\n",
192 | " \n",
193 | " print(\n",
194 | " f'Epochs: {epoch_num + 1} | Train Loss: {total_loss_train / len(train_data): .3f} | Train Accuracy: {total_acc_train / len(train_data): .3f} | Val Loss: {total_loss_val / len(val_data): .3f} | Val Accuracy: {total_acc_val / len(val_data): .3f}')\n",
195 | " "
196 | ]
197 | },
198 | {
199 | "cell_type": "code",
200 | "execution_count": null,
201 | "id": "cd8a670d-c449-45fe-8f4c-9a5fb27855c1",
202 | "metadata": {},
203 | "outputs": [],
204 | "source": [
205 | "def evaluate(model, test_data):\n",
206 | "\n",
207 | " test = Dataset(test_data)\n",
208 | "\n",
209 | " test_dataloader = torch.utils.data.DataLoader(test, batch_size=2)\n",
210 | "\n",
211 | " use_cuda = torch.cuda.is_available()\n",
212 | " device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
213 | "\n",
214 | " if use_cuda:\n",
215 | "\n",
216 | " model = model.cuda()\n",
217 | "\n",
218 | " total_acc_test = 0\n",
219 | " with torch.no_grad():\n",
220 | "\n",
221 | " for test_input, test_label in test_dataloader:\n",
222 | "\n",
223 | " test_label = test_label.to(device)\n",
224 | " mask = test_input['attention_mask'].to(device)\n",
225 | " input_id = test_input['input_ids'].squeeze(1).to(device)\n",
226 | "\n",
227 | " output = model(input_id, mask)\n",
228 | "\n",
229 | " acc = (output.argmax(dim=1) == test_label).sum().item()\n",
230 | " total_acc_test += acc\n",
231 | " \n",
232 | " print(f'Test Accuracy: {total_acc_test / len(test_data): .3f}')"
233 | ]
234 | },
235 | {
236 | "cell_type": "code",
237 | "execution_count": null,
238 | "id": "25d2231d-fef1-42cf-a73e-188cac932727",
239 | "metadata": {},
240 | "outputs": [],
241 | "source": [
242 | "np.random.seed(112)\n",
243 | "df_train, df_val, df_test = np.split(df.sample(frac=1, random_state=42), \n",
244 | " [int(.8*len(df)), int(.9*len(df))])\n",
245 | "\n",
246 | "print(len(df_train),len(df_val), len(df_test))"
247 | ]
248 | },
249 | {
250 | "cell_type": "code",
251 | "execution_count": null,
252 | "id": "30242239-de70-4c03-8f56-9f5ade43518d",
253 | "metadata": {},
254 | "outputs": [],
255 | "source": [
256 | "EPOCHS = 5\n",
257 | "model = BertClassifier()\n",
258 | "LR = 1e-6\n",
259 | " \n",
260 | "train(model, df_train, df_val, LR, EPOCHS)"
261 | ]
262 | },
263 | {
264 | "cell_type": "code",
265 | "execution_count": null,
266 | "id": "ccc00f0a-9a15-4942-9c9b-2f9789c8dd22",
267 | "metadata": {},
268 | "outputs": [],
269 | "source": [
270 | "evaluate(model, df_test)"
271 | ]
272 | }
273 | ],
274 | "metadata": {
275 | "kernelspec": {
276 | "display_name": "Python 3",
277 | "language": "python",
278 | "name": "python3"
279 | },
280 | "language_info": {
281 | "codemirror_mode": {
282 | "name": "ipython",
283 | "version": 3
284 | },
285 | "file_extension": ".py",
286 | "mimetype": "text/x-python",
287 | "name": "python",
288 | "nbconvert_exporter": "python",
289 | "pygments_lexer": "ipython3",
290 | "version": "3.8.8"
291 | }
292 | },
293 | "nbformat": 4,
294 | "nbformat_minor": 5
295 | }
296 |
--------------------------------------------------------------------------------