├── BERT_Captum
│   └── Bert_captum.ipynb
├── Bar Chart Race
│   └── bar chart race.ipynb
├── Bottleneck_Adapters
│   └── Bottleneck_Adapters_Medium.ipynb
├── GPT2_TextGeneration
│   └── GPT_2_Medium.ipynb
├── Layout_Parser
│   ├── img
│   │   ├── doc_1.pdf
│   │   └── doc_2.pdf
│   └── layout_parser_ex.ipynb
├── Lime
│   ├── LIME_image_class.ipynb
│   └── panda_00024.jpg
├── NER_BERT
│   ├── .ipynb_checkpoints
│   │   └── NER_with_BERT-checkpoint.ipynb
│   └── NER_with_BERT.ipynb
├── Optuna
│   └── Optuna.ipynb
├── README.md
├── STS_BERT
│   └── STS_BERT.ipynb
├── Spaces_Translation_App
│   └── app.py
├── Text_Classification_BERT
│   └── bert_medium.ipynb
├── Text_Classification_Transformer_Encoders
│   └── Transformer_Encoder.ipynb
└── ViT
    └── Vision_Transformer.ipynb

/BERT_Captum/Bert_captum.ipynb:
--------------------------------------------------------------------------------
1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "f0f52175-a7e8-45ab-a965-0d4a1e8e760e", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%%capture\n", 11 | "!pip install transformers\n", 12 | "!pip install captum" 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "id": "90853bde-95df-4827-98af-70410794c502", 18 | "metadata": {}, 19 | "source": [ 20 | "# Tokenization Example" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "id": "3d970c20-8f80-4421-bcc3-85a581cf9738", 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "from transformers import BertTokenizer\n", 31 | "\n", 32 | "# Instantiate tokenizer\n", 33 | "tokenizer = BertTokenizer.from_pretrained('bert-base-cased')\n", 34 | "\n", 35 | "text = 'The movie is superb'\n", 36 | "\n", 37 | "# Tokenize input text\n", 38 | "text_ids = tokenizer.encode(text, add_special_tokens=True)\n", 39 | "\n", 40 | "# Print the tokens\n", 41 | "print(tokenizer.convert_ids_to_tokens(text_ids))\n", 42 | "# Output: ['[CLS]', 'The', 'movie', 'is', 'superb', '[SEP]']\n", 43 | "\n", 44 | "# Print the ids of the tokens\n", 45 | "print(text_ids)\n", 46 | "# Output: [101, 1109, 2523, 1110, 25876, 102]" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "id": "62e541e1-da3d-47dd-a278-7950b4fb0e54", 52 | "metadata": {}, 53 | "source": [ 54 | "# Minimal Example to Fetch the Embeddings of Tokens" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "id": "2fb12068-b429-4411-bda5-bedff656d387", 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "from transformers import BertModel\n", 65 | "import torch\n", 66 | "# Instantiate BERT model\n", 67 | "model = BertModel.from_pretrained('bert-base-cased')\n", 68 | "\n", 69 | "embeddings = model.embeddings(torch.tensor([text_ids]))\n", 70 | "print(embeddings.size())\n", 71 | "# Output: torch.Size([1, 6, 768]), since there are 6 tokens in text_ids" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "id": "cc1d6810-edf6-4b1d-a48e-44e034a32a90", 77 | "metadata": {},
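
A minimal aside, not a cell from the original notebook: the same encoding as the tokenization example above can be produced with the tokenizer's __call__ API, which additionally returns the attention mask that the classifier's forward pass below accepts. The printed ids are taken from the notebook's own output.

from transformers import BertTokenizer

# Sketch only; assumes the same 'bert-base-cased' checkpoint used above
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
encoded = tokenizer('The movie is superb', return_tensors='pt')

print(encoded['input_ids'])       # tensor([[  101,  1109,  2523,  1110, 25876,   102]])
print(encoded['attention_mask'])  # tensor([[1, 1, 1, 1, 1, 1]])
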
78 | "source": [ 79 | "# Specify Model Architecture" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "id": "2ea24f09-9fef-469c-8884-794fda046a5a", 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "from torch import nn\n", 90 | "\n", 91 | "class BertClassifier(nn.Module):\n", 92 | "\n", 93 | " def __init__(self, dropout=0.5):\n", 94 | "\n", 95 | " super(BertClassifier, self).__init__()\n", 96 | "\n", 97 | " self.bert = BertModel.from_pretrained('bert-base-cased')\n", 98 | " self.dropout = nn.Dropout(dropout)\n", 99 | " self.linear = nn.Linear(768, 2)\n", 100 | " self.relu = nn.ReLU()\n", 101 | "\n", 102 | " def forward(self, input_id, mask = None):\n", 103 | "\n", 104 | " _, pooled_output = self.bert(input_ids= input_id, attention_mask=mask,return_dict=False)\n", 105 | " dropout_output = self.dropout(pooled_output)\n", 106 | " linear_output = self.linear(dropout_output)\n", 107 | " final_layer = self.relu(linear_output)\n", 108 | "\n", 109 | " return final_layer" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "id": "ee6f1929-401c-40c2-a7a2-4a061380f633", 115 | "metadata": {}, 116 | "source": [ 117 | "# Load Model's Parameters " 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "id": "1da8d59e-4ed4-4ca2-bcdb-93857a9295e7", 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "model = BertClassifier()\n", 128 | "model.load_state_dict(torch.load('path/to/bert_model.pt', map_location=torch.device('cpu')))\n", 129 | "model.eval()" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "id": "8d29dd60-b4cc-43d6-b5a7-44f56adca3be", 135 | "metadata": {}, 136 | "source": [ 137 | "# Define Model Input and Output" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "id": "1c8bc95b-99f1-4087-a012-8306fe328daa", 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "# Define model output\n", 148 | "def model_output(inputs):\n", 149 | " return model(inputs)[0]\n", 150 | "\n", 151 | "# Define model input\n", 152 | "model_input = model.bert.embeddings" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "id": "d04dae59-850a-4b11-9684-cdcba55c6e2c", 158 | "metadata": {}, 159 | "source": [ 160 | "# Instantiate Integrated Gradients Method" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "id": "e8e805ad-9c5d-40b6-bb9d-890eecde0945", 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "from captum.attr import LayerIntegratedGradients\n", 171 | "lig = LayerIntegratedGradients(model_output, model_input)" 172 | ] 173 | }, 174 | { 175 | "cell_type": "markdown", 176 | "id": "86492c38-8201-4d54-aba0-de21e4966a42", 177 | "metadata": {}, 178 | "source": [ 179 | "# Construct Original and Baseline Input" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "id": "563b09b7-a6c2-4e67-bc11-2b25aa3a4a4c", 186 | "metadata": { 187 | "tags": [] 188 | }, 189 | "outputs": [], 190 | "source": [ 191 | "def construct_input_and_baseline(text):\n", 192 | "\n", 193 | " max_length = 510\n", 194 | " baseline_token_id = tokenizer.pad_token_id \n", 195 | " sep_token_id = tokenizer.sep_token_id \n", 196 | " cls_token_id = tokenizer.cls_token_id \n", 197 | "\n", 198 | " text_ids = tokenizer.encode(text, max_length=max_length, truncation=True, add_special_tokens=False)\n", 199 | " \n", 200 | " input_ids = [cls_token_id] + text_ids + [sep_token_id]\n", 201 | " token_list = 
tokenizer.convert_ids_to_tokens(input_ids)\n", 202 | "    \n", 203 | "\n", 204 | "    baseline_input_ids = [cls_token_id] + [baseline_token_id] * len(text_ids) + [sep_token_id]\n", 205 | "    return torch.tensor([input_ids], device='cpu'), torch.tensor([baseline_input_ids], device='cpu'), token_list\n", 206 | "\n", 207 | "text = 'The movie is superb'\n", 208 | "input_ids, baseline_input_ids, all_tokens = construct_input_and_baseline(text)\n", 209 | "\n", 210 | "print(f'original text: {input_ids}')\n", 211 | "print(f'baseline text: {baseline_input_ids}')\n", 212 | "\n", 213 | "# Output: original text: tensor([[ 101, 1109, 2523, 1110, 25876, 102]])\n", 214 | "# Output: baseline text: tensor([[101, 0, 0, 0, 0, 102]])" 215 | ] 216 | }, 217 | { 218 | "cell_type": "markdown", 219 | "id": "1aa548b0-e738-45fe-bc0f-0d490427b79b", 220 | "metadata": {}, 221 | "source": [ 222 | "# Compute Attributions" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "id": "3db56682-41f3-46ec-b331-c81617391fec", 229 | "metadata": {}, 230 | "outputs": [], 231 | "source": [ 232 | "attributions, delta = lig.attribute(inputs= input_ids,\n", 233 | "                                    baselines= baseline_input_ids,\n", 234 | "                                    return_convergence_delta=True\n", 235 | "                                    )\n", 236 | "print(attributions.size())\n", 237 | "# Output: torch.Size([1, 6, 768])" 238 | ] 239 | }, 240 | { 241 | "cell_type": "markdown", 242 | "id": "864769c7-2378-42a2-8f35-76bd44af63ac", 243 | "metadata": {}, 244 | "source": [ 245 | "# Compute Attribution for Each Token" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "id": "826a93e2-4b4b-430a-9c20-10cbb5ac08e1", 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "def summarize_attributions(attributions):\n", 256 | "\n", 257 | "    attributions = attributions.sum(dim=-1).squeeze(0)\n", 258 | "    attributions = attributions / torch.norm(attributions)\n", 259 | "    \n", 260 | "    return attributions\n", 261 | "\n", 262 | "attributions_sum = summarize_attributions(attributions)\n", 263 | "print(attributions_sum.size())\n", 264 | "# Output: torch.Size([6])" 265 | ] 266 | }, 267 | { 268 | "cell_type": "markdown", 269 | "id": "0123bbd3-536f-48cb-a0fa-c9a6fd9557df", 270 | "metadata": {}, 271 | "source": [ 272 | "# Visualize the Interpretation" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": null, 278 | "id": "af2eaaa4-0fa4-4928-a7e2-a727f8e56fdd", 279 | "metadata": {}, 280 | "outputs": [], 281 | "source": [ 282 | "from captum.attr import visualization as viz\n", 283 | "\n", 284 | "score_vis = viz.VisualizationDataRecord(\n", 285 | "                        word_attributions = attributions_sum,\n", 286 | "                        pred_prob = torch.max(model(input_ids)[0]),\n", 287 | "                        pred_class = torch.argmax(model(input_ids)[0]).numpy(),\n", 288 | "                        true_class = 1,\n", 289 | "                        attr_class = text,\n", 290 | "                        attr_score = attributions_sum.sum(),\n", 291 | "                        raw_input_ids = all_tokens,\n", 292 | "                        convergence_score = delta)\n", 293 | "\n", 294 | "viz.visualize_text([score_vis])" 295 | ] 296 | }, 297 | { 298 | "cell_type": "markdown", 299 | "id": "a7c16834-2b58-4d96-b56b-f5908468f64a", 300 | "metadata": {}, 301 | "source": [ 302 | "# Encapsulate All the Steps Above" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": null, 308 | "id": "b6c9ec71-84ca-4734-afd0-b006c3a711df", 309 | "metadata": {}, 310 | "outputs": [], 311 | "source": [ 312 | "def interpret_text(text, true_class):\n", 313 | "\n", 314 | "    input_ids, baseline_input_ids, all_tokens = 
construct_input_and_baseline(text)\n", 315 | " attributions, delta = lig.attribute(inputs= input_ids,\n", 316 | " baselines= baseline_input_ids,\n", 317 | " return_convergence_delta=True\n", 318 | " )\n", 319 | " attributions_sum = summarize_attributions(attributions)\n", 320 | "\n", 321 | " score_vis = viz.VisualizationDataRecord(\n", 322 | " word_attributions = attributions_sum,\n", 323 | " pred_prob = torch.max(model(input_ids)[0]),\n", 324 | " pred_class = torch.argmax(model(input_ids)[0]).numpy(),\n", 325 | " true_class = true_class,\n", 326 | " attr_class = text,\n", 327 | " attr_score = attributions_sum.sum(), \n", 328 | " raw_input_ids = all_tokens,\n", 329 | " convergence_score = delta)\n", 330 | "\n", 331 | " viz.visualize_text([score_vis])" 332 | ] 333 | }, 334 | { 335 | "cell_type": "markdown", 336 | "id": "57984aee-32e9-418a-a286-f35e0bc67b70", 337 | "metadata": {}, 338 | "source": [ 339 | "# Interpret Texts" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": null, 345 | "id": "d530db94-52fa-4c9a-a673-27247eac3b99", 346 | "metadata": {}, 347 | "outputs": [], 348 | "source": [ 349 | "text = \"It's a heartfelt film about love, loss, and legacy\"\n", 350 | "true_class = 1\n", 351 | "interpret_text(text, true_class)" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": null, 357 | "id": "8544cb5c-3afd-45a7-ae82-7f9e692ff223", 358 | "metadata": {}, 359 | "outputs": [], 360 | "source": [ 361 | "text = \"A noisy, hideous, and viciously cumbersome movie\"\n", 362 | "true_class = 0\n", 363 | "interpret_text(text, true_class)" 364 | ] 365 | } 366 | ], 367 | "metadata": { 368 | "kernelspec": { 369 | "display_name": "Python 3 (ipykernel)", 370 | "language": "python", 371 | "name": "python3" 372 | }, 373 | "language_info": { 374 | "codemirror_mode": { 375 | "name": "ipython", 376 | "version": 3 377 | }, 378 | "file_extension": ".py", 379 | "mimetype": "text/x-python", 380 | "name": "python", 381 | "nbconvert_exporter": "python", 382 | "pygments_lexer": "ipython3", 383 | "version": "3.9.7" 384 | } 385 | }, 386 | "nbformat": 4, 387 | "nbformat_minor": 5 388 | } 389 | -------------------------------------------------------------------------------- /Bar Chart Race/bar chart race.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "import numpy as np" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 2, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "prem_league = pd.read_csv('D:/PL Dataset/premierLeague_tables_1992-2017.csv')" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 3, 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "data": { 29 | "text/html": [ 30 | "
\n", 31 | "\n", 44 | "\n", 45 | " \n", 46 | " \n", 47 | " \n", 48 | " \n", 49 | " \n", 50 | " \n", 51 | " \n", 52 | " \n", 53 | " \n", 54 | " \n", 55 | " \n", 56 | " \n", 57 | " \n", 58 | " \n", 59 | " \n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | "
seasonteampointswdlgfgagdpld...d_hd_al_hl_agf_hgf_aga_hga_agd_hgd_a
02017-18Manchester City1003242106277938...2211614514134732
12017-18Manchester United81256768284038...242538309192911
22017-18Tottenham Hotspur77238774363838...4425403416202414
32017-18Liverpool752112584384638...7505453910283511
42017-18Chelsea702171062382438...4346303216221410
\n", 194 | "

5 rows × 26 columns

\n", 195 | "
" 196 | ], 197 | "text/plain": [ 198 | " season team points w d l gf ga gd pld ... d_h \\\n", 199 | "0 2017-18 Manchester City 100 32 4 2 106 27 79 38 ... 2 \n", 200 | "1 2017-18 Manchester United 81 25 6 7 68 28 40 38 ... 2 \n", 201 | "2 2017-18 Tottenham Hotspur 77 23 8 7 74 36 38 38 ... 4 \n", 202 | "3 2017-18 Liverpool 75 21 12 5 84 38 46 38 ... 7 \n", 203 | "4 2017-18 Chelsea 70 21 7 10 62 38 24 38 ... 4 \n", 204 | "\n", 205 | " d_a l_h l_a gf_h gf_a ga_h ga_a gd_h gd_a \n", 206 | "0 2 1 1 61 45 14 13 47 32 \n", 207 | "1 4 2 5 38 30 9 19 29 11 \n", 208 | "2 4 2 5 40 34 16 20 24 14 \n", 209 | "3 5 0 5 45 39 10 28 35 11 \n", 210 | "4 3 4 6 30 32 16 22 14 10 \n", 211 | "\n", 212 | "[5 rows x 26 columns]" 213 | ] 214 | }, 215 | "execution_count": 3, 216 | "metadata": {}, 217 | "output_type": "execute_result" 218 | } 219 | ], 220 | "source": [ 221 | "prem_league.head()" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 4, 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [ 230 | "prem_league = prem_league[['season', 'team', 'points']]" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": 5, 236 | "metadata": {}, 237 | "outputs": [ 238 | { 239 | "data": { 240 | "text/html": [ 241 | "
\n", 242 | "\n", 255 | "\n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | "
seasonteampoints
02017-18Manchester City100
12017-18Manchester United81
22017-18Tottenham Hotspur77
32017-18Liverpool75
42017-18Chelsea70
\n", 297 | "
" 298 | ], 299 | "text/plain": [ 300 | " season team points\n", 301 | "0 2017-18 Manchester City 100\n", 302 | "1 2017-18 Manchester United 81\n", 303 | "2 2017-18 Tottenham Hotspur 77\n", 304 | "3 2017-18 Liverpool 75\n", 305 | "4 2017-18 Chelsea 70" 306 | ] 307 | }, 308 | "execution_count": 5, 309 | "metadata": {}, 310 | "output_type": "execute_result" 311 | } 312 | ], 313 | "source": [ 314 | "prem_league.head()" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": 6, 320 | "metadata": {}, 321 | "outputs": [ 322 | { 323 | "data": { 324 | "text/plain": [ 325 | "season object\n", 326 | "team object\n", 327 | "points int64\n", 328 | "dtype: object" 329 | ] 330 | }, 331 | "execution_count": 6, 332 | "metadata": {}, 333 | "output_type": "execute_result" 334 | } 335 | ], 336 | "source": [ 337 | "prem_league.dtypes" 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": 7, 343 | "metadata": {}, 344 | "outputs": [], 345 | "source": [ 346 | "df = prem_league.pivot_table(values = 'points',index = ['season'], columns = 'team')" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": 8, 352 | "metadata": {}, 353 | "outputs": [ 354 | { 355 | "data": { 356 | "text/html": [ 357 | "
\n", 358 | "\n", 371 | "\n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | "
teamArsenalAston VillaBarnsleyBirmingham CityBlackburn RoversBlackpoolBolton WanderersBournemouthBradford CityBrighton and Hove Albion...SunderlandSwansea CitySwindon TownTottenham HotspurWatfordWest Bromwich AlbionWest Ham UnitedWigan AthleticWimbledon FCWolverhampton Wanderers
season
1992-9356.074.0NaNNaN71.0NaNNaNNaNNaNNaN...NaNNaNNaN59.0NaNNaNNaNNaN54.0NaN
1993-9471.057.0NaNNaN84.0NaNNaNNaNNaNNaN...NaNNaN30.045.0NaNNaN52.0NaN65.0NaN
1994-9551.048.0NaNNaN89.0NaNNaNNaNNaNNaN...NaNNaNNaN62.0NaNNaN50.0NaN56.0NaN
1995-9663.063.0NaNNaN61.0NaN29.0NaNNaNNaN...NaNNaNNaN61.0NaNNaN51.0NaN41.0NaN
1996-9768.061.0NaNNaN42.0NaNNaNNaNNaNNaN...40.0NaNNaN46.0NaNNaN42.0NaN56.0NaN
\n", 545 | "

5 rows × 49 columns

\n", 546 | "
" 547 | ], 548 | "text/plain": [ 549 | "team Arsenal Aston Villa Barnsley Birmingham City Blackburn Rovers \\\n", 550 | "season \n", 551 | "1992-93 56.0 74.0 NaN NaN 71.0 \n", 552 | "1993-94 71.0 57.0 NaN NaN 84.0 \n", 553 | "1994-95 51.0 48.0 NaN NaN 89.0 \n", 554 | "1995-96 63.0 63.0 NaN NaN 61.0 \n", 555 | "1996-97 68.0 61.0 NaN NaN 42.0 \n", 556 | "\n", 557 | "team Blackpool Bolton Wanderers Bournemouth Bradford City \\\n", 558 | "season \n", 559 | "1992-93 NaN NaN NaN NaN \n", 560 | "1993-94 NaN NaN NaN NaN \n", 561 | "1994-95 NaN NaN NaN NaN \n", 562 | "1995-96 NaN 29.0 NaN NaN \n", 563 | "1996-97 NaN NaN NaN NaN \n", 564 | "\n", 565 | "team Brighton and Hove Albion ... Sunderland Swansea City \\\n", 566 | "season ... \n", 567 | "1992-93 NaN ... NaN NaN \n", 568 | "1993-94 NaN ... NaN NaN \n", 569 | "1994-95 NaN ... NaN NaN \n", 570 | "1995-96 NaN ... NaN NaN \n", 571 | "1996-97 NaN ... 40.0 NaN \n", 572 | "\n", 573 | "team Swindon Town Tottenham Hotspur Watford West Bromwich Albion \\\n", 574 | "season \n", 575 | "1992-93 NaN 59.0 NaN NaN \n", 576 | "1993-94 30.0 45.0 NaN NaN \n", 577 | "1994-95 NaN 62.0 NaN NaN \n", 578 | "1995-96 NaN 61.0 NaN NaN \n", 579 | "1996-97 NaN 46.0 NaN NaN \n", 580 | "\n", 581 | "team West Ham United Wigan Athletic Wimbledon FC \\\n", 582 | "season \n", 583 | "1992-93 NaN NaN 54.0 \n", 584 | "1993-94 52.0 NaN 65.0 \n", 585 | "1994-95 50.0 NaN 56.0 \n", 586 | "1995-96 51.0 NaN 41.0 \n", 587 | "1996-97 42.0 NaN 56.0 \n", 588 | "\n", 589 | "team Wolverhampton Wanderers \n", 590 | "season \n", 591 | "1992-93 NaN \n", 592 | "1993-94 NaN \n", 593 | "1994-95 NaN \n", 594 | "1995-96 NaN \n", 595 | "1996-97 NaN \n", 596 | "\n", 597 | "[5 rows x 49 columns]" 598 | ] 599 | }, 600 | "execution_count": 8, 601 | "metadata": {}, 602 | "output_type": "execute_result" 603 | } 604 | ], 605 | "source": [ 606 | "df.head()" 607 | ] 608 | }, 609 | { 610 | "cell_type": "code", 611 | "execution_count": 9, 612 | "metadata": {}, 613 | "outputs": [], 614 | "source": [ 615 | "df.fillna(0, inplace=True)\n", 616 | "df.sort_values(list(df.columns),inplace=True)\n", 617 | "df = df.sort_index()" 618 | ] 619 | }, 620 | { 621 | "cell_type": "code", 622 | "execution_count": 10, 623 | "metadata": {}, 624 | "outputs": [ 625 | { 626 | "data": { 627 | "text/html": [ 628 | "
\n", 629 | "\n", 642 | "\n", 643 | " \n", 644 | " \n", 645 | " \n", 646 | " \n", 647 | " \n", 648 | " \n", 649 | " \n", 650 | " \n", 651 | " \n", 652 | " \n", 653 | " \n", 654 | " \n", 655 | " \n", 656 | " \n", 657 | " \n", 658 | " \n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | " \n", 669 | " \n", 670 | " \n", 671 | " \n", 672 | " \n", 673 | " \n", 674 | " \n", 675 | " \n", 676 | " \n", 677 | " \n", 678 | " \n", 679 | " \n", 680 | " \n", 681 | " \n", 682 | " \n", 683 | " \n", 684 | " \n", 685 | " \n", 686 | " \n", 687 | " \n", 688 | " \n", 689 | " \n", 690 | " \n", 691 | " \n", 692 | " \n", 693 | " \n", 694 | " \n", 695 | " \n", 696 | " \n", 697 | " \n", 698 | " \n", 699 | " \n", 700 | " \n", 701 | " \n", 702 | " \n", 703 | " \n", 704 | " \n", 705 | " \n", 706 | " \n", 707 | " \n", 708 | " \n", 709 | " \n", 710 | " \n", 711 | " \n", 712 | " \n", 713 | " \n", 714 | " \n", 715 | " \n", 716 | " \n", 717 | " \n", 718 | " \n", 719 | " \n", 720 | " \n", 721 | " \n", 722 | " \n", 723 | " \n", 724 | " \n", 725 | " \n", 726 | " \n", 727 | " \n", 728 | " \n", 729 | " \n", 730 | " \n", 731 | " \n", 732 | " \n", 733 | " \n", 734 | " \n", 735 | " \n", 736 | " \n", 737 | " \n", 738 | " \n", 739 | " \n", 740 | " \n", 741 | " \n", 742 | " \n", 743 | " \n", 744 | " \n", 745 | " \n", 746 | " \n", 747 | " \n", 748 | " \n", 749 | " \n", 750 | " \n", 751 | " \n", 752 | " \n", 753 | " \n", 754 | " \n", 755 | " \n", 756 | " \n", 757 | " \n", 758 | " \n", 759 | " \n", 760 | " \n", 761 | " \n", 762 | " \n", 763 | " \n", 764 | " \n", 765 | " \n", 766 | " \n", 767 | " \n", 768 | " \n", 769 | " \n", 770 | " \n", 771 | " \n", 772 | " \n", 773 | " \n", 774 | " \n", 775 | " \n", 776 | " \n", 777 | " \n", 778 | " \n", 779 | " \n", 780 | " \n", 781 | " \n", 782 | " \n", 783 | " \n", 784 | " \n", 785 | " \n", 786 | " \n", 787 | " \n", 788 | " \n", 789 | " \n", 790 | " \n", 791 | " \n", 792 | " \n", 793 | " \n", 794 | " \n", 795 | " \n", 796 | " \n", 797 | " \n", 798 | " \n", 799 | " \n", 800 | " \n", 801 | " \n", 802 | " \n", 803 | " \n", 804 | " \n", 805 | " \n", 806 | " \n", 807 | " \n", 808 | " \n", 809 | " \n", 810 | " \n", 811 | " \n", 812 | " \n", 813 | " \n", 814 | " \n", 815 | "
teamArsenalAston VillaBarnsleyBirmingham CityBlackburn RoversBlackpoolBolton WanderersBournemouthBradford CityBrighton and Hove Albion...SunderlandSwansea CitySwindon TownTottenham HotspurWatfordWest Bromwich AlbionWest Ham UnitedWigan AthleticWimbledon FCWolverhampton Wanderers
season
1992-9356.074.00.00.071.00.00.00.00.00.0...0.00.00.059.00.00.00.00.054.00.0
1993-9471.057.00.00.084.00.00.00.00.00.0...0.00.030.045.00.00.052.00.065.00.0
1994-9551.048.00.00.089.00.00.00.00.00.0...0.00.00.062.00.00.050.00.056.00.0
1995-9663.063.00.00.061.00.029.00.00.00.0...0.00.00.061.00.00.051.00.041.00.0
1996-9768.061.00.00.042.00.00.00.00.00.0...40.00.00.046.00.00.042.00.056.00.0
\n", 816 | "

5 rows × 49 columns

\n", 817 | "
" 818 | ], 819 | "text/plain": [ 820 | "team Arsenal Aston Villa Barnsley Birmingham City Blackburn Rovers \\\n", 821 | "season \n", 822 | "1992-93 56.0 74.0 0.0 0.0 71.0 \n", 823 | "1993-94 71.0 57.0 0.0 0.0 84.0 \n", 824 | "1994-95 51.0 48.0 0.0 0.0 89.0 \n", 825 | "1995-96 63.0 63.0 0.0 0.0 61.0 \n", 826 | "1996-97 68.0 61.0 0.0 0.0 42.0 \n", 827 | "\n", 828 | "team Blackpool Bolton Wanderers Bournemouth Bradford City \\\n", 829 | "season \n", 830 | "1992-93 0.0 0.0 0.0 0.0 \n", 831 | "1993-94 0.0 0.0 0.0 0.0 \n", 832 | "1994-95 0.0 0.0 0.0 0.0 \n", 833 | "1995-96 0.0 29.0 0.0 0.0 \n", 834 | "1996-97 0.0 0.0 0.0 0.0 \n", 835 | "\n", 836 | "team Brighton and Hove Albion ... Sunderland Swansea City \\\n", 837 | "season ... \n", 838 | "1992-93 0.0 ... 0.0 0.0 \n", 839 | "1993-94 0.0 ... 0.0 0.0 \n", 840 | "1994-95 0.0 ... 0.0 0.0 \n", 841 | "1995-96 0.0 ... 0.0 0.0 \n", 842 | "1996-97 0.0 ... 40.0 0.0 \n", 843 | "\n", 844 | "team Swindon Town Tottenham Hotspur Watford West Bromwich Albion \\\n", 845 | "season \n", 846 | "1992-93 0.0 59.0 0.0 0.0 \n", 847 | "1993-94 30.0 45.0 0.0 0.0 \n", 848 | "1994-95 0.0 62.0 0.0 0.0 \n", 849 | "1995-96 0.0 61.0 0.0 0.0 \n", 850 | "1996-97 0.0 46.0 0.0 0.0 \n", 851 | "\n", 852 | "team West Ham United Wigan Athletic Wimbledon FC \\\n", 853 | "season \n", 854 | "1992-93 0.0 0.0 54.0 \n", 855 | "1993-94 52.0 0.0 65.0 \n", 856 | "1994-95 50.0 0.0 56.0 \n", 857 | "1995-96 51.0 0.0 41.0 \n", 858 | "1996-97 42.0 0.0 56.0 \n", 859 | "\n", 860 | "team Wolverhampton Wanderers \n", 861 | "season \n", 862 | "1992-93 0.0 \n", 863 | "1993-94 0.0 \n", 864 | "1994-95 0.0 \n", 865 | "1995-96 0.0 \n", 866 | "1996-97 0.0 \n", 867 | "\n", 868 | "[5 rows x 49 columns]" 869 | ] 870 | }, 871 | "execution_count": 10, 872 | "metadata": {}, 873 | "output_type": "execute_result" 874 | } 875 | ], 876 | "source": [ 877 | "df.head()" 878 | ] 879 | }, 880 | { 881 | "cell_type": "code", 882 | "execution_count": 11, 883 | "metadata": {}, 884 | "outputs": [], 885 | "source": [ 886 | "df.iloc[:, 0:-1] = df.iloc[:, 0:-1].cumsum()" 887 | ] 888 | }, 889 | { 890 | "cell_type": "code", 891 | "execution_count": 12, 892 | "metadata": { 893 | "scrolled": true 894 | }, 895 | "outputs": [ 896 | { 897 | "data": { 898 | "text/html": [ 899 | "
\n", 900 | "\n", 913 | "\n", 914 | " \n", 915 | " \n", 916 | " \n", 917 | " \n", 918 | " \n", 919 | " \n", 920 | " \n", 921 | " \n", 922 | " \n", 923 | " \n", 924 | " \n", 925 | " \n", 926 | " \n", 927 | " \n", 928 | " \n", 929 | " \n", 930 | " \n", 931 | " \n", 932 | " \n", 933 | " \n", 934 | " \n", 935 | " \n", 936 | " \n", 937 | " \n", 938 | " \n", 939 | " \n", 940 | " \n", 941 | " \n", 942 | " \n", 943 | " \n", 944 | " \n", 945 | " \n", 946 | " \n", 947 | " \n", 948 | " \n", 949 | " \n", 950 | " \n", 951 | " \n", 952 | " \n", 953 | " \n", 954 | " \n", 955 | " \n", 956 | " \n", 957 | " \n", 958 | " \n", 959 | " \n", 960 | " \n", 961 | " \n", 962 | " \n", 963 | " \n", 964 | " \n", 965 | " \n", 966 | " \n", 967 | " \n", 968 | " \n", 969 | " \n", 970 | " \n", 971 | " \n", 972 | " \n", 973 | " \n", 974 | " \n", 975 | " \n", 976 | " \n", 977 | " \n", 978 | " \n", 979 | " \n", 980 | " \n", 981 | " \n", 982 | " \n", 983 | " \n", 984 | " \n", 985 | " \n", 986 | " \n", 987 | " \n", 988 | " \n", 989 | " \n", 990 | " \n", 991 | " \n", 992 | " \n", 993 | " \n", 994 | " \n", 995 | " \n", 996 | " \n", 997 | " \n", 998 | " \n", 999 | " \n", 1000 | " \n", 1001 | " \n", 1002 | " \n", 1003 | " \n", 1004 | " \n", 1005 | " \n", 1006 | " \n", 1007 | " \n", 1008 | " \n", 1009 | " \n", 1010 | " \n", 1011 | " \n", 1012 | " \n", 1013 | " \n", 1014 | " \n", 1015 | " \n", 1016 | " \n", 1017 | " \n", 1018 | " \n", 1019 | " \n", 1020 | " \n", 1021 | " \n", 1022 | " \n", 1023 | " \n", 1024 | " \n", 1025 | " \n", 1026 | " \n", 1027 | " \n", 1028 | " \n", 1029 | " \n", 1030 | " \n", 1031 | " \n", 1032 | " \n", 1033 | " \n", 1034 | " \n", 1035 | " \n", 1036 | " \n", 1037 | " \n", 1038 | " \n", 1039 | " \n", 1040 | " \n", 1041 | " \n", 1042 | " \n", 1043 | " \n", 1044 | " \n", 1045 | " \n", 1046 | " \n", 1047 | " \n", 1048 | " \n", 1049 | " \n", 1050 | " \n", 1051 | " \n", 1052 | " \n", 1053 | " \n", 1054 | " \n", 1055 | " \n", 1056 | " \n", 1057 | " \n", 1058 | " \n", 1059 | " \n", 1060 | " \n", 1061 | " \n", 1062 | " \n", 1063 | " \n", 1064 | " \n", 1065 | " \n", 1066 | " \n", 1067 | " \n", 1068 | " \n", 1069 | " \n", 1070 | " \n", 1071 | " \n", 1072 | " \n", 1073 | " \n", 1074 | " \n", 1075 | " \n", 1076 | " \n", 1077 | " \n", 1078 | " \n", 1079 | " \n", 1080 | " \n", 1081 | " \n", 1082 | " \n", 1083 | " \n", 1084 | " \n", 1085 | " \n", 1086 | " \n", 1087 | " \n", 1088 | " \n", 1089 | " \n", 1090 | " \n", 1091 | " \n", 1092 | " \n", 1093 | " \n", 1094 | " \n", 1095 | " \n", 1096 | " \n", 1097 | " \n", 1098 | " \n", 1099 | " \n", 1100 | " \n", 1101 | " \n", 1102 | " \n", 1103 | " \n", 1104 | " \n", 1105 | " \n", 1106 | " \n", 1107 | " \n", 1108 | " \n", 1109 | " \n", 1110 | " \n", 1111 | " \n", 1112 | " \n", 1113 | " \n", 1114 | " \n", 1115 | " \n", 1116 | " \n", 1117 | " \n", 1118 | " \n", 1119 | " \n", 1120 | " \n", 1121 | " \n", 1122 | " \n", 1123 | " \n", 1124 | " \n", 1125 | " \n", 1126 | " \n", 1127 | " \n", 1128 | " \n", 1129 | " \n", 1130 | " \n", 1131 | " \n", 1132 | " \n", 1133 | " \n", 1134 | " \n", 1135 | " \n", 1136 | " \n", 1137 | " \n", 1138 | " \n", 1139 | " \n", 1140 | " \n", 1141 | " \n", 1142 | " \n", 1143 | " \n", 1144 | " \n", 1145 | " \n", 1146 | " \n", 1147 | " \n", 1148 | " \n", 1149 | " \n", 1150 | " \n", 1151 | " \n", 1152 | " \n", 1153 | " \n", 1154 | " \n", 1155 | " \n", 1156 | " \n", 1157 | " \n", 1158 | " \n", 1159 | " \n", 1160 | " \n", 1161 | " \n", 1162 | " \n", 1163 | " \n", 1164 | " \n", 1165 | " \n", 1166 | " \n", 1167 | " \n", 1168 | " \n", 1169 | " \n", 1170 | " \n", 1171 | " \n", 
1172 | " \n", 1173 | " \n", 1174 | " \n", 1175 | " \n", 1176 | " \n", 1177 | " \n", 1178 | " \n", 1179 | " \n", 1180 | " \n", 1181 | " \n", 1182 | " \n", 1183 | " \n", 1184 | " \n", 1185 | " \n", 1186 | " \n", 1187 | " \n", 1188 | " \n", 1189 | " \n", 1190 | " \n", 1191 | " \n", 1192 | " \n", 1193 | " \n", 1194 | " \n", 1195 | " \n", 1196 | " \n", 1197 | " \n", 1198 | " \n", 1199 | " \n", 1200 | " \n", 1201 | " \n", 1202 | " \n", 1203 | " \n", 1204 | " \n", 1205 | " \n", 1206 | "
teamArsenalAston VillaBarnsleyBirmingham CityBlackburn RoversBlackpoolBolton WanderersBournemouthBradford CityBrighton and Hove Albion...SunderlandSwansea CitySwindon TownTottenham HotspurWatfordWest Bromwich AlbionWest Ham UnitedWigan AthleticWimbledon FCWolverhampton Wanderers
season
1992-9356.074.00.00.071.00.00.00.00.00.0...0.00.00.059.00.00.00.00.054.00.0
1993-94127.0131.00.00.0155.00.00.00.00.00.0...0.00.030.0104.00.00.052.00.0119.00.0
1994-95178.0179.00.00.0244.00.00.00.00.00.0...0.00.030.0166.00.00.0102.00.0175.00.0
1995-96241.0242.00.00.0305.00.029.00.00.00.0...0.00.030.0227.00.00.0153.00.0216.00.0
1996-97309.0303.00.00.0347.00.029.00.00.00.0...40.00.030.0273.00.00.0195.00.0272.00.0
1997-98387.0360.035.00.0405.00.069.00.00.00.0...40.00.030.0317.00.00.0251.00.0316.00.0
1998-99465.0415.035.00.0440.00.069.00.00.00.0...40.00.030.0364.00.00.0308.00.0358.00.0
1999-00538.0473.035.00.0440.00.069.00.036.00.0...98.00.030.0417.024.00.0363.00.0391.00.0
2000-01608.0527.035.00.0440.00.069.00.062.00.0...155.00.030.0466.024.00.0405.00.0391.00.0
2001-02695.0577.035.00.0486.00.0109.00.062.00.0...195.00.030.0516.024.00.0458.00.0391.00.0
\n", 1207 | "

10 rows × 49 columns

\n", 1208 | "
" 1209 | ], 1210 | "text/plain": [ 1211 | "team Arsenal Aston Villa Barnsley Birmingham City Blackburn Rovers \\\n", 1212 | "season \n", 1213 | "1992-93 56.0 74.0 0.0 0.0 71.0 \n", 1214 | "1993-94 127.0 131.0 0.0 0.0 155.0 \n", 1215 | "1994-95 178.0 179.0 0.0 0.0 244.0 \n", 1216 | "1995-96 241.0 242.0 0.0 0.0 305.0 \n", 1217 | "1996-97 309.0 303.0 0.0 0.0 347.0 \n", 1218 | "1997-98 387.0 360.0 35.0 0.0 405.0 \n", 1219 | "1998-99 465.0 415.0 35.0 0.0 440.0 \n", 1220 | "1999-00 538.0 473.0 35.0 0.0 440.0 \n", 1221 | "2000-01 608.0 527.0 35.0 0.0 440.0 \n", 1222 | "2001-02 695.0 577.0 35.0 0.0 486.0 \n", 1223 | "\n", 1224 | "team Blackpool Bolton Wanderers Bournemouth Bradford City \\\n", 1225 | "season \n", 1226 | "1992-93 0.0 0.0 0.0 0.0 \n", 1227 | "1993-94 0.0 0.0 0.0 0.0 \n", 1228 | "1994-95 0.0 0.0 0.0 0.0 \n", 1229 | "1995-96 0.0 29.0 0.0 0.0 \n", 1230 | "1996-97 0.0 29.0 0.0 0.0 \n", 1231 | "1997-98 0.0 69.0 0.0 0.0 \n", 1232 | "1998-99 0.0 69.0 0.0 0.0 \n", 1233 | "1999-00 0.0 69.0 0.0 36.0 \n", 1234 | "2000-01 0.0 69.0 0.0 62.0 \n", 1235 | "2001-02 0.0 109.0 0.0 62.0 \n", 1236 | "\n", 1237 | "team Brighton and Hove Albion ... Sunderland Swansea City \\\n", 1238 | "season ... \n", 1239 | "1992-93 0.0 ... 0.0 0.0 \n", 1240 | "1993-94 0.0 ... 0.0 0.0 \n", 1241 | "1994-95 0.0 ... 0.0 0.0 \n", 1242 | "1995-96 0.0 ... 0.0 0.0 \n", 1243 | "1996-97 0.0 ... 40.0 0.0 \n", 1244 | "1997-98 0.0 ... 40.0 0.0 \n", 1245 | "1998-99 0.0 ... 40.0 0.0 \n", 1246 | "1999-00 0.0 ... 98.0 0.0 \n", 1247 | "2000-01 0.0 ... 155.0 0.0 \n", 1248 | "2001-02 0.0 ... 195.0 0.0 \n", 1249 | "\n", 1250 | "team Swindon Town Tottenham Hotspur Watford West Bromwich Albion \\\n", 1251 | "season \n", 1252 | "1992-93 0.0 59.0 0.0 0.0 \n", 1253 | "1993-94 30.0 104.0 0.0 0.0 \n", 1254 | "1994-95 30.0 166.0 0.0 0.0 \n", 1255 | "1995-96 30.0 227.0 0.0 0.0 \n", 1256 | "1996-97 30.0 273.0 0.0 0.0 \n", 1257 | "1997-98 30.0 317.0 0.0 0.0 \n", 1258 | "1998-99 30.0 364.0 0.0 0.0 \n", 1259 | "1999-00 30.0 417.0 24.0 0.0 \n", 1260 | "2000-01 30.0 466.0 24.0 0.0 \n", 1261 | "2001-02 30.0 516.0 24.0 0.0 \n", 1262 | "\n", 1263 | "team West Ham United Wigan Athletic Wimbledon FC \\\n", 1264 | "season \n", 1265 | "1992-93 0.0 0.0 54.0 \n", 1266 | "1993-94 52.0 0.0 119.0 \n", 1267 | "1994-95 102.0 0.0 175.0 \n", 1268 | "1995-96 153.0 0.0 216.0 \n", 1269 | "1996-97 195.0 0.0 272.0 \n", 1270 | "1997-98 251.0 0.0 316.0 \n", 1271 | "1998-99 308.0 0.0 358.0 \n", 1272 | "1999-00 363.0 0.0 391.0 \n", 1273 | "2000-01 405.0 0.0 391.0 \n", 1274 | "2001-02 458.0 0.0 391.0 \n", 1275 | "\n", 1276 | "team Wolverhampton Wanderers \n", 1277 | "season \n", 1278 | "1992-93 0.0 \n", 1279 | "1993-94 0.0 \n", 1280 | "1994-95 0.0 \n", 1281 | "1995-96 0.0 \n", 1282 | "1996-97 0.0 \n", 1283 | "1997-98 0.0 \n", 1284 | "1998-99 0.0 \n", 1285 | "1999-00 0.0 \n", 1286 | "2000-01 0.0 \n", 1287 | "2001-02 0.0 \n", 1288 | "\n", 1289 | "[10 rows x 49 columns]" 1290 | ] 1291 | }, 1292 | "execution_count": 12, 1293 | "metadata": {}, 1294 | "output_type": "execute_result" 1295 | } 1296 | ], 1297 | "source": [ 1298 | "df[0:10]" 1299 | ] 1300 | }, 1301 | { 1302 | "cell_type": "code", 1303 | "execution_count": 13, 1304 | "metadata": {}, 1305 | "outputs": [], 1306 | "source": [ 1307 | "top_prem_clubs = set()\n", 1308 | "\n", 1309 | "for index, row in df.iterrows():\n", 1310 | " top_prem_clubs |= set(row[row > 0].sort_values(ascending=False).head(6).index)\n", 1311 | "\n", 1312 | "df = df[top_prem_clubs]" 1313 | ] 1314 | }, 1315 | { 1316 | "cell_type": "code", 1317 | 
"execution_count": 14, 1318 | "metadata": {}, 1319 | "outputs": [ 1320 | { 1321 | "data": { 1322 | "text/html": [ 1323 | "
\n", 1324 | "\n", 1337 | "\n", 1338 | " \n", 1339 | " \n", 1340 | " \n", 1341 | " \n", 1342 | " \n", 1343 | " \n", 1344 | " \n", 1345 | " \n", 1346 | " \n", 1347 | " \n", 1348 | " \n", 1349 | " \n", 1350 | " \n", 1351 | " \n", 1352 | " \n", 1353 | " \n", 1354 | " \n", 1355 | " \n", 1356 | " \n", 1357 | " \n", 1358 | " \n", 1359 | " \n", 1360 | " \n", 1361 | " \n", 1362 | " \n", 1363 | " \n", 1364 | " \n", 1365 | " \n", 1366 | " \n", 1367 | " \n", 1368 | " \n", 1369 | " \n", 1370 | " \n", 1371 | " \n", 1372 | " \n", 1373 | " \n", 1374 | " \n", 1375 | " \n", 1376 | " \n", 1377 | " \n", 1378 | " \n", 1379 | " \n", 1380 | " \n", 1381 | " \n", 1382 | " \n", 1383 | " \n", 1384 | " \n", 1385 | " \n", 1386 | " \n", 1387 | " \n", 1388 | " \n", 1389 | " \n", 1390 | " \n", 1391 | " \n", 1392 | " \n", 1393 | " \n", 1394 | " \n", 1395 | " \n", 1396 | " \n", 1397 | " \n", 1398 | " \n", 1399 | " \n", 1400 | " \n", 1401 | " \n", 1402 | " \n", 1403 | " \n", 1404 | " \n", 1405 | " \n", 1406 | " \n", 1407 | " \n", 1408 | " \n", 1409 | " \n", 1410 | " \n", 1411 | " \n", 1412 | " \n", 1413 | " \n", 1414 | " \n", 1415 | " \n", 1416 | " \n", 1417 | " \n", 1418 | " \n", 1419 | " \n", 1420 | " \n", 1421 | " \n", 1422 | " \n", 1423 | " \n", 1424 | " \n", 1425 | " \n", 1426 | " \n", 1427 | " \n", 1428 | " \n", 1429 | " \n", 1430 | " \n", 1431 | " \n", 1432 | " \n", 1433 | " \n", 1434 | " \n", 1435 | " \n", 1436 | " \n", 1437 | " \n", 1438 | " \n", 1439 | " \n", 1440 | " \n", 1441 | " \n", 1442 | " \n", 1443 | " \n", 1444 | " \n", 1445 | " \n", 1446 | " \n", 1447 | " \n", 1448 | " \n", 1449 | " \n", 1450 | " \n", 1451 | " \n", 1452 | " \n", 1453 | " \n", 1454 | "
teamEvertonQueens Park RangersNewcastle UnitedSheffield WednesdayBlackburn RoversChelseaAston VillaNorwich CityManchester UnitedArsenalLeeds UnitedTottenham HotspurLiverpool
season
1992-9353.063.00.059.071.056.074.072.084.056.051.059.059.0
1993-9497.0123.077.0123.0155.0107.0131.0125.0176.0127.0121.0104.0119.0
1994-95147.0183.0149.0174.0244.0161.0179.0168.0264.0178.0194.0166.0193.0
1995-96208.0216.0227.0214.0305.0211.0242.0168.0346.0241.0237.0227.0264.0
1996-97250.0216.0295.0271.0347.0270.0303.0168.0421.0309.0283.0273.0332.0
\n", 1455 | "
" 1456 | ], 1457 | "text/plain": [ 1458 | "team Everton Queens Park Rangers Newcastle United Sheffield Wednesday \\\n", 1459 | "season \n", 1460 | "1992-93 53.0 63.0 0.0 59.0 \n", 1461 | "1993-94 97.0 123.0 77.0 123.0 \n", 1462 | "1994-95 147.0 183.0 149.0 174.0 \n", 1463 | "1995-96 208.0 216.0 227.0 214.0 \n", 1464 | "1996-97 250.0 216.0 295.0 271.0 \n", 1465 | "\n", 1466 | "team Blackburn Rovers Chelsea Aston Villa Norwich City \\\n", 1467 | "season \n", 1468 | "1992-93 71.0 56.0 74.0 72.0 \n", 1469 | "1993-94 155.0 107.0 131.0 125.0 \n", 1470 | "1994-95 244.0 161.0 179.0 168.0 \n", 1471 | "1995-96 305.0 211.0 242.0 168.0 \n", 1472 | "1996-97 347.0 270.0 303.0 168.0 \n", 1473 | "\n", 1474 | "team Manchester United Arsenal Leeds United Tottenham Hotspur \\\n", 1475 | "season \n", 1476 | "1992-93 84.0 56.0 51.0 59.0 \n", 1477 | "1993-94 176.0 127.0 121.0 104.0 \n", 1478 | "1994-95 264.0 178.0 194.0 166.0 \n", 1479 | "1995-96 346.0 241.0 237.0 227.0 \n", 1480 | "1996-97 421.0 309.0 283.0 273.0 \n", 1481 | "\n", 1482 | "team Liverpool \n", 1483 | "season \n", 1484 | "1992-93 59.0 \n", 1485 | "1993-94 119.0 \n", 1486 | "1994-95 193.0 \n", 1487 | "1995-96 264.0 \n", 1488 | "1996-97 332.0 " 1489 | ] 1490 | }, 1491 | "execution_count": 14, 1492 | "metadata": {}, 1493 | "output_type": "execute_result" 1494 | } 1495 | ], 1496 | "source": [ 1497 | "df.head()" 1498 | ] 1499 | }, 1500 | { 1501 | "cell_type": "code", 1502 | "execution_count": 15, 1503 | "metadata": {}, 1504 | "outputs": [], 1505 | "source": [ 1506 | "import bar_chart_race as bcr" 1507 | ] 1508 | }, 1509 | { 1510 | "cell_type": "code", 1511 | "execution_count": 22, 1512 | "metadata": {}, 1513 | "outputs": [], 1514 | "source": [ 1515 | "bcr.bar_chart_race(df = df, \n", 1516 | " n_bars = 6, \n", 1517 | " sort='desc',\n", 1518 | " title='Premier League Clubs Points Since 1992',\n", 1519 | " period_length = 750,\n", 1520 | " filename = 'pl_clubs.mp4')" 1521 | ] 1522 | }, 1523 | { 1524 | "cell_type": "code", 1525 | "execution_count": 29, 1526 | "metadata": {}, 1527 | "outputs": [ 1528 | { 1529 | "ename": "TypeError", 1530 | "evalue": "bar_chart_race() got an unexpected keyword argument 'img_label_folder'", 1531 | "output_type": "error", 1532 | "traceback": [ 1533 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", 1534 | "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)", 1535 | "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mbcr\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbar_chart_race\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdf\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mimg_label_folder\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;34m'PL clubs'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mn_bars\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m6\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mperiod_length\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m750\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", 1536 | "\u001b[1;31mTypeError\u001b[0m: bar_chart_race() got an unexpected keyword argument 'img_label_folder'" 1537 | ] 1538 | } 1539 | ], 1540 | "source": [ 1541 | "bcr.bar_chart_race(df, img_label_folder = 'PL clubs', n_bars=6, period_length = 750)" 1542 | ] 1543 | }, 1544 | { 1545 | "cell_type": "code", 1546 | "execution_count": null, 1547 | "metadata": {}, 1548 | "outputs": [], 1549 | "source": [] 1550 | } 1551 | ], 1552 | "metadata": { 1553 | "kernelspec": { 1554 | "display_name": "Python 3", 1555 | 
"language": "python", 1556 | "name": "python3" 1557 | }, 1558 | "language_info": { 1559 | "codemirror_mode": { 1560 | "name": "ipython", 1561 | "version": 3 1562 | }, 1563 | "file_extension": ".py", 1564 | "mimetype": "text/x-python", 1565 | "name": "python", 1566 | "nbconvert_exporter": "python", 1567 | "pygments_lexer": "ipython3", 1568 | "version": "3.7.6" 1569 | } 1570 | }, 1571 | "nbformat": 4, 1572 | "nbformat_minor": 4 1573 | } 1574 | -------------------------------------------------------------------------------- /Layout_Parser/img/doc_1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrubenw/medium-resources/c28cb4db83f939014eff9a88730d406476606f50/Layout_Parser/img/doc_1.pdf -------------------------------------------------------------------------------- /Layout_Parser/img/doc_2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrubenw/medium-resources/c28cb4db83f939014eff9a88730d406476606f50/Layout_Parser/img/doc_2.pdf -------------------------------------------------------------------------------- /Layout_Parser/layout_parser_ex.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "3b88065e-45e9-40ed-aa04-8dd7ec36e45c", 6 | "metadata": {}, 7 | "source": [ 8 | "# Install Dependencies" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "8a68afad-3a7a-40c4-8d8b-84b00d35e72f", 14 | "metadata": {}, 15 | "source": [ 16 | "If you work with a Windows machine, it's better to try LayoutParser on Google Colab instead since it's tricky to install Detectron 2 on Windows machine" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "id": "85d97120-8613-464c-95e0-f06eb514b781", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "%%capture\n", 27 | "!sudo apt-get install poppler-utils #pdf2image dependency -- restart runtime/kernel after installation\n", 28 | "!sudo apt-get install tesseract-ocr-eng #install Tesseract OCR Engine --restart runtime/kernel after installation" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "27f229b0-b4c9-46ce-845e-d2cc4fe50eb6", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "%%capture\n", 39 | "!pip install layoutparser torchvision && pip install \"detectron2@git+https://github.com/facebookresearch/detectron2.git@v0.5#egg=detectron2\"\n", 40 | "!pip install pdf2img\n", 41 | "!pip install \"layoutparser[ocr]\"" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "id": "769c14cf-293e-46be-81e0-50ab7acba502", 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "import pdf2image\n", 52 | "import numpy as np\n", 53 | "import layoutparser as lp\n", 54 | "import torchvision.ops.boxes as bops\n", 55 | "import torch" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "id": "6decc0b6-15ab-4673-8266-8203d62c4cd0", 61 | "metadata": {}, 62 | "source": [ 63 | "# Layout Detection " 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "id": "b7b8795f-89e7-4e13-9fea-bacb95a2fef6", 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "pdf_file= '/img/doc_1.pdf' # Adjust the filepath of your input image accordingly\n", 74 | "img = np.asarray(pdf2image.convert_from_path(pdf_file)[0])" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | 
"execution_count": null, 80 | "id": "14b474a7-5ce0-4bfc-a6be-37f2d7705224", 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "model = lp.Detectron2LayoutModel('lp://PubLayNet/mask_rcnn_X_101_32x8d_FPN_3x/config',\n", 85 | " extra_config=[\"MODEL.ROI_HEADS.SCORE_THRESH_TEST\", 0.5],\n", 86 | " label_map={0: \"Text\", 1: \"Title\", 2: \"List\", 3:\"Table\", 4:\"Figure\"})\n", 87 | "\n", 88 | "layout_result = model.detect(img)\n", 89 | "\n", 90 | "lp.draw_box(img, layout_result, box_width=5, box_alpha=0.2, show_element_type=True)" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "id": "038df22f-5b1c-454f-80de-9b2443c3a6f7", 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "text_blocks = lp.Layout([b for b in layout_result if b.type=='Text'])\n", 101 | "\n", 102 | "lp.draw_box(img, text_blocks, box_width=5, box_alpha=0.2, show_element_type=True, show_element_id=True)" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "id": "a8c88c59-a3c9-4959-b4d8-f141b7e24670", 108 | "metadata": {}, 109 | "source": [ 110 | "# OCR Parser with Tesseract" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "id": "881b4809-60a4-4219-acd0-c4919ad25587", 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "ocr_agent = lp.TesseractAgent(languages='eng')" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "id": "bdaf8736-6168-40bc-a2de-e6b76d506336", 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "image_width = len(img[0])\n", 131 | "\n", 132 | "# Sort element ID of the left column based on y1 coordinate\n", 133 | "left_interval = lp.Interval(0, image_width/2, axis='x').put_on_canvas(img)\n", 134 | "left_blocks = text_blocks.filter_by(left_interval, center=True)._blocks\n", 135 | "left_blocks.sort(key = lambda b:b.coordinates[1])\n", 136 | "\n", 137 | "# Sort element ID of the right column based on y1 coordinate\n", 138 | "right_blocks = [b for b in text_blocks if b not in left_blocks]\n", 139 | "right_blocks.sort(key = lambda b:b.coordinates[1])\n", 140 | "\n", 141 | "# Sort the overall element ID starts from left column\n", 142 | "text_blocks = lp.Layout([b.set(id = idx) for idx, b in enumerate(left_blocks + right_blocks)])\n", 143 | "\n", 144 | "lp.draw_box(img, text_blocks, box_width=5, box_alpha=0.2, show_element_type=True, show_element_id=True)" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "id": "db495dc8-4c21-4a99-bd26-c200de49a8f1", 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "for block in text_blocks:\n", 155 | "\n", 156 | " # Crop image around the detected layout\n", 157 | " segment_image = (block\n", 158 | " .pad(left=15, right=15, top=5, bottom=5)\n", 159 | " .crop_image(img))\n", 160 | " \n", 161 | " # Perform OCR\n", 162 | " text = ocr_agent.detect(segment_image)\n", 163 | "\n", 164 | " # Save OCR result\n", 165 | " block.set(text=text, inplace=True)" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "id": "69ea88a0-bdb5-4952-8533-88172d4a3b85", 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "for txt in text_blocks:\n", 176 | " print(txt.text, end='\\n---\\n')" 177 | ] 178 | }, 179 | { 180 | "cell_type": "markdown", 181 | "id": "1a27c5fa-3af9-43d6-8023-8ab3426b691b", 182 | "metadata": {}, 183 | "source": [ 184 | "# Adjusting LayoutParser Result" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 
189 | "execution_count": null, 190 | "id": "47ea2cff-6906-4821-8d85-09c45727a1b8", 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "pdf_file_2= '/img/doc_2.pdf' # Adjust the filepath of your input image accordingly\n", 195 | "img_2 = np.asarray(pdf2image.convert_from_path(pdf_file)[0])" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "id": "dc81a71c-91b0-4d8f-8b07-90be4f5ed67d", 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "layout_result_2 = model.detect(img_2)\n", 206 | "\n", 207 | "text_blocks_2 = lp.Layout([b for b in layout_result_2 if b.type=='Text'])\n", 208 | "\n", 209 | "lp.draw_box(img_2, text_blocks_2, box_width=5, box_alpha=0.2, show_element_type=True, show_element_id=True)" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": null, 215 | "id": "9d6f8412-6f1c-4b18-9bde-a69e6a939329", 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "def set_coordinate(data):\n", 220 | "\n", 221 | " x1 = data.block.x_1\n", 222 | " y1 = data.block.y_1\n", 223 | " x2 = data.block.x_2\n", 224 | " y2 = data.block.y_2\n", 225 | "\n", 226 | " return torch.tensor([[x1, y1, x2, y2]], dtype=torch.float)\n", 227 | "\n", 228 | "def compute_iou(box_1, box_2):\n", 229 | "\n", 230 | " return bops.box_iou(box_1, box_2)\n", 231 | "\n", 232 | "def compute_area(box):\n", 233 | "\n", 234 | " width = box.tolist()[0][2] - box.tolist()[0][0]\n", 235 | " length = box.tolist()[0][3] - box.tolist()[0][1]\n", 236 | " area = width*length\n", 237 | "\n", 238 | " return area\n", 239 | "\n", 240 | "def refine(block_1, block_2):\n", 241 | "\n", 242 | " bb1 = set_coordinate(block_1)\n", 243 | " bb2 = set_coordinate(block_2)\n", 244 | "\n", 245 | " iou = compute_iou(bb1, bb2)\n", 246 | "\n", 247 | " if iou.tolist()[0][0] != 0:\n", 248 | "\n", 249 | " a1 = compute_area(bb1)\n", 250 | " a2 = compute_area(bb2)\n", 251 | "\n", 252 | " block_2.set(type='None', inplace= True) if a1 > a2 else block_1.set(type='None', inplace= True)\n", 253 | " \n", 254 | "\n", 255 | "for layout_i in text_blocks_2:\n", 256 | " \n", 257 | " for layout_j in text_blocks_2:\n", 258 | " \n", 259 | " if layout_i != layout_j: \n", 260 | "\n", 261 | " refine(layout_i, layout_j)\n", 262 | " \n", 263 | "text_blocks_2 = lp.Layout([b for b in text_blocks_2 if b.type=='Text'])\n", 264 | "\n", 265 | "lp.draw_box(img_2, text_blocks_2, box_width=5, box_alpha=0.2, show_element_type=True, show_element_id=True)" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "id": "6e28c664-4930-42f5-a107-edc5eb81a23a", 272 | "metadata": {}, 273 | "outputs": [], 274 | "source": [ 275 | "text_blocks_2 = lp.Layout([b.set(id = idx) for idx, b in enumerate(text_blocks_2)])\n", 276 | "\n", 277 | "# From the visualization, let's say we know that layout \n", 278 | "# with 'Diameter Thickness' text has element ID of 4\n", 279 | "\n", 280 | "text_blocks_2[4].set(type='None', inplace=True)\n", 281 | "text_blocks_2 = lp.Layout([b for b in text_blocks_2 if b.type=='Text'])\n", 282 | "\n", 283 | "lp.draw_box(img_2, text_blocks_2, box_width=5, box_alpha=0.2, show_element_type=True, show_element_id=True)" 284 | ] 285 | } 286 | ], 287 | "metadata": { 288 | "kernelspec": { 289 | "display_name": "Python 3", 290 | "language": "python", 291 | "name": "python3" 292 | }, 293 | "language_info": { 294 | "codemirror_mode": { 295 | "name": "ipython", 296 | "version": 3 297 | }, 298 | "file_extension": ".py", 299 | "mimetype": "text/x-python", 300 | 
"name": "python", 301 | "nbconvert_exporter": "python", 302 | "pygments_lexer": "ipython3", 303 | "version": "3.8.8" 304 | } 305 | }, 306 | "nbformat": 4, 307 | "nbformat_minor": 5 308 | } 309 | -------------------------------------------------------------------------------- /Lime/panda_00024.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrubenw/medium-resources/c28cb4db83f939014eff9a88730d406476606f50/Lime/panda_00024.jpg -------------------------------------------------------------------------------- /NER_BERT/.ipynb_checkpoints/NER_with_BERT-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 9, 6 | "id": "1657ccc8-b9dd-46e7-a08f-b9176ea274ba", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%%capture\n", 11 | "pip install transformers" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "id": "438f352b-1664-4219-b257-855919d467fa", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "import pandas as pd\n", 22 | "import torch \n", 23 | "import numpy as np\n", 24 | "from transformers import BertTokenizerFast, BertForTokenClassification\n", 25 | "from torch.utils.data import DataLoader\n", 26 | "from tqdm import tqdm\n", 27 | "from torch.optim import SGD" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "id": "1a414a26-8c98-4eef-b97e-5d1a47df5b67", 33 | "metadata": {}, 34 | "source": [ 35 | "# Read CSV Data" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 3, 41 | "id": "e5536782-4d9a-4c35-9d22-7da36a08911a", 42 | "metadata": {}, 43 | "outputs": [ 44 | { 45 | "data": { 46 | "text/html": [ 47 | "
\n", 48 | "\n", 61 | "\n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | "
textlabels
0Thousands of demonstrators have marched throug...O O O O O O B-geo O O O O O B-geo O O O O O B-...
1Iranian officials say they expect to get acces...B-gpe O O O O O O O O O O O O O O B-tim O O O ...
2Helicopter gunships Saturday pounded militant ...O O B-tim O O O O O B-geo O O O O O B-org O O ...
3They left after a tense hour-long standoff wit...O O O O O O O O O O O
4U.N. relief coordinator Jan Egeland said Sunda...B-geo O O B-per I-per O B-tim O B-geo O B-gpe ...
\n", 97 | "
" 98 | ], 99 | "text/plain": [ 100 | " text \\\n", 101 | "0 Thousands of demonstrators have marched throug... \n", 102 | "1 Iranian officials say they expect to get acces... \n", 103 | "2 Helicopter gunships Saturday pounded militant ... \n", 104 | "3 They left after a tense hour-long standoff wit... \n", 105 | "4 U.N. relief coordinator Jan Egeland said Sunda... \n", 106 | "\n", 107 | " labels \n", 108 | "0 O O O O O O B-geo O O O O O B-geo O O O O O B-... \n", 109 | "1 B-gpe O O O O O O O O O O O O O O B-tim O O O ... \n", 110 | "2 O O B-tim O O O O O B-geo O O O O O B-org O O ... \n", 111 | "3 O O O O O O O O O O O \n", 112 | "4 B-geo O O B-per I-per O B-tim O B-geo O B-gpe ... " 113 | ] 114 | }, 115 | "execution_count": 3, 116 | "metadata": {}, 117 | "output_type": "execute_result" 118 | } 119 | ], 120 | "source": [ 121 | "df = pd.read_csv('ner.csv')\n", 122 | "df.head()" 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "id": "c2b1bd34-8843-4706-baa5-201f31245183", 128 | "metadata": {}, 129 | "source": [ 130 | "# Initialize Tokenizer" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "id": "d41be369-ee57-4949-aeb8-7960746d2aea", 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')" 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "id": "0e3439c2-580e-4972-8603-3a00bc3be62d", 146 | "metadata": {}, 147 | "source": [ 148 | "# Create Dataset Class " 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "id": "ac7f0682-ea50-4aeb-bcd3-9230df735554", 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "label_all_tokens = False\n", 159 | "\n", 160 | "def align_label(texts, labels):\n", 161 | " tokenized_inputs = tokenizer(texts, padding='max_length', max_length=512, truncation=True)\n", 162 | "\n", 163 | " word_ids = tokenized_inputs.word_ids()\n", 164 | "\n", 165 | " previous_word_idx = None\n", 166 | " label_ids = []\n", 167 | "\n", 168 | " for word_idx in word_ids:\n", 169 | "\n", 170 | " if word_idx is None:\n", 171 | " label_ids.append(-100)\n", 172 | "\n", 173 | " elif word_idx != previous_word_idx:\n", 174 | " try:\n", 175 | " label_ids.append(labels_to_ids[labels[word_idx]])\n", 176 | " except:\n", 177 | " label_ids.append(-100)\n", 178 | " else:\n", 179 | " try:\n", 180 | " label_ids.append(labels_to_ids[labels[word_idx]] if label_all_tokens else -100)\n", 181 | " except:\n", 182 | " label_ids.append(-100)\n", 183 | " previous_word_idx = word_idx\n", 184 | "\n", 185 | " return label_ids\n", 186 | "\n", 187 | "class DataSequence(torch.utils.data.Dataset):\n", 188 | "\n", 189 | " def __init__(self, df):\n", 190 | "\n", 191 | " lb = [i.split() for i in df['labels'].values.tolist()]\n", 192 | " txt = df['text'].values.tolist()\n", 193 | " self.texts = [tokenizer(str(i),\n", 194 | " padding='max_length', max_length = 512, truncation=True, return_tensors=\"pt\") for i in txt]\n", 195 | " self.labels = [align_label(i,j) for i,j in zip(txt, lb)]\n", 196 | "\n", 197 | " def __len__(self):\n", 198 | "\n", 199 | " return len(self.labels)\n", 200 | "\n", 201 | " def get_batch_data(self, idx):\n", 202 | "\n", 203 | " return self.texts[idx]\n", 204 | "\n", 205 | " def get_batch_labels(self, idx):\n", 206 | "\n", 207 | " return torch.LongTensor(self.labels[idx])\n", 208 | "\n", 209 | " def __getitem__(self, idx):\n", 210 | "\n", 211 | " batch_data = self.get_batch_data(idx)\n", 212 | " batch_labels = 
self.get_batch_labels(idx)\n", 213 | "\n", 214 | " return batch_data, batch_labels" 215 | ] 216 | }, 217 | { 218 | "cell_type": "markdown", 219 | "id": "496b3cb5-24c8-4c1b-8382-7c4d0f2339a4", 220 | "metadata": {}, 221 | "source": [ 222 | "# Split Data and Define Unique Labels" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "id": "6599961c-1cda-47bf-8c82-a1f9ebc94a95", 229 | "metadata": {}, 230 | "outputs": [], 231 | "source": [ 232 | "df = df[0:1000]\n", 233 | "\n", 234 | "labels = [i.split() for i in df['labels'].values.tolist()]\n", 235 | "unique_labels = set()\n", 236 | "\n", 237 | "for lb in labels:\n", 238 | " [unique_labels.add(i) for i in lb if i not in unique_labels]\n", 239 | "labels_to_ids = {k: v for v, k in enumerate(unique_labels)}\n", 240 | "ids_to_labels = {v: k for v, k in enumerate(unique_labels)}\n", 241 | "\n", 242 | "df_train, df_val, df_test = np.split(df.sample(frac=1, random_state=42),\n", 243 | " [int(.8 * len(df)), int(.9 * len(df))])" 244 | ] 245 | }, 246 | { 247 | "cell_type": "markdown", 248 | "id": "d54b96c5-6875-4990-9248-5d6ad5b053e9", 249 | "metadata": {}, 250 | "source": [ 251 | "# Build Model" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": null, 257 | "id": "13ebfa5e-c91a-4967-b0cc-23e314c32348", 258 | "metadata": {}, 259 | "outputs": [], 260 | "source": [ 261 | "class BertModel(torch.nn.Module):\n", 262 | "\n", 263 | " def __init__(self):\n", 264 | "\n", 265 | " super(BertModel, self).__init__()\n", 266 | "\n", 267 | " self.bert = BertForTokenClassification.from_pretrained('bert-base-cased', num_labels=len(unique_labels))\n", 268 | "\n", 269 | " def forward(self, input_id, mask, label):\n", 270 | "\n", 271 | " output = self.bert(input_ids=input_id, attention_mask=mask, labels=label, return_dict=False)\n", 272 | "\n", 273 | " return output" 274 | ] 275 | }, 276 | { 277 | "cell_type": "markdown", 278 | "id": "c3a48d06-d343-449b-829d-2bcad4b2af52", 279 | "metadata": {}, 280 | "source": [ 281 | "# Model Training" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": null, 287 | "id": "291bfdad-2df3-4de3-954d-b7e7a9a1b253", 288 | "metadata": {}, 289 | "outputs": [], 290 | "source": [ 291 | "def train_loop(model, df_train, df_val):\n", 292 | "\n", 293 | " train_dataset = DataSequence(df_train)\n", 294 | " val_dataset = DataSequence(df_val)\n", 295 | "\n", 296 | " train_dataloader = DataLoader(train_dataset, num_workers=4, batch_size=BATCH_SIZE, shuffle=True)\n", 297 | " val_dataloader = DataLoader(val_dataset, num_workers=4, batch_size=BATCH_SIZE)\n", 298 | "\n", 299 | " use_cuda = torch.cuda.is_available()\n", 300 | " device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n", 301 | "\n", 302 | " optimizer = SGD(model.parameters(), lr=LEARNING_RATE)\n", 303 | "\n", 304 | " if use_cuda:\n", 305 | " model = model.cuda()\n", 306 | "\n", 307 | " best_acc = 0\n", 308 | " best_loss = 1000\n", 309 | "\n", 310 | " for epoch_num in range(EPOCHS):\n", 311 | "\n", 312 | " total_acc_train = 0\n", 313 | " total_loss_train = 0\n", 314 | "\n", 315 | " model.train()\n", 316 | "\n", 317 | " for train_data, train_label in tqdm(train_dataloader):\n", 318 | "\n", 319 | " train_label = train_label.to(device)\n", 320 | " mask = train_data['attention_mask'].squeeze(1).to(device)\n", 321 | " input_id = train_data['input_ids'].squeeze(1).to(device)\n", 322 | "\n", 323 | " optimizer.zero_grad()\n", 324 | " loss, logits = model(input_id, mask, train_label)\n", 325 | "\n", 326 | " for i in 
range(logits.shape[0]):\n", 327 | "\n", 328 | " logits_clean = logits[i][train_label[i] != -100]\n", 329 | " label_clean = train_label[i][train_label[i] != -100]\n", 330 | "\n", 331 | " predictions = logits_clean.argmax(dim=1)\n", 332 | " acc = (predictions == label_clean).float().mean()\n", 333 | " total_acc_train += acc\n", 334 | " total_loss_train += loss.item()\n", 335 | "\n", 336 | " loss.backward()\n", 337 | " optimizer.step()\n", 338 | "\n", 339 | " model.eval()\n", 340 | "\n", 341 | " total_acc_val = 0\n", 342 | " total_loss_val = 0\n", 343 | "\n", 344 | " for val_data, val_label in val_dataloader:\n", 345 | "\n", 346 | " val_label = val_label.to(device)\n", 347 | " mask = val_data['attention_mask'].squeeze(1).to(device)\n", 348 | " input_id = val_data['input_ids'].squeeze(1).to(device)\n", 349 | "\n", 350 | " loss, logits = model(input_id, mask, val_label)\n", 351 | "\n", 352 | " for i in range(logits.shape[0]):\n", 353 | "\n", 354 | " logits_clean = logits[i][val_label[i] != -100]\n", 355 | " label_clean = val_label[i][val_label[i] != -100]\n", 356 | "\n", 357 | " predictions = logits_clean.argmax(dim=1)\n", 358 | " acc = (predictions == label_clean).float().mean()\n", 359 | " total_acc_val += acc\n", 360 | " total_loss_val += loss.item()\n", 361 | "\n", 362 | " val_accuracy = total_acc_val / len(df_val)\n", 363 | " val_loss = total_loss_val / len(df_val)\n", 364 | "\n", 365 | " print(\n", 366 | " f'Epochs: {epoch_num + 1} | Loss: {total_loss_train / len(df_train): .3f} | Accuracy: {total_acc_train / len(df_train): .3f} | Val_Loss: {total_loss_val / len(df_val): .3f} | Accuracy: {total_acc_val / len(df_val): .3f}')\n", 367 | "\n", 368 | "LEARNING_RATE = 5e-3\n", 369 | "EPOCHS = 5\n", 370 | "BATCH_SIZE = 2\n", 371 | "\n", 372 | "model = BertModel()\n", 373 | "train_loop(model, df_train, df_val)" 374 | ] 375 | }, 376 | { 377 | "cell_type": "markdown", 378 | "id": "69e1af60-33c3-497e-984f-0094f9bc3a4f", 379 | "metadata": {}, 380 | "source": [ 381 | "# Evaluate Model" 382 | ] 383 | }, 384 | { 385 | "cell_type": "code", 386 | "execution_count": null, 387 | "id": "04295796-7033-4bdf-849f-95e030fc94aa", 388 | "metadata": {}, 389 | "outputs": [], 390 | "source": [ 391 | "def evaluate(model, df_test):\n", 392 | "\n", 393 | " test_dataset = DataSequence(df_test)\n", 394 | "\n", 395 | " test_dataloader = DataLoader(test_dataset, num_workers=4, batch_size=1)\n", 396 | "\n", 397 | " use_cuda = torch.cuda.is_available()\n", 398 | " device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n", 399 | "\n", 400 | " if use_cuda:\n", 401 | " model = model.cuda()\n", 402 | "\n", 403 | " total_acc_test = 0.0\n", 404 | "\n", 405 | " for test_data, test_label in test_dataloader:\n", 406 | "\n", 407 | " test_label = test_label.to(device)\n", 408 | " mask = test_data['attention_mask'].squeeze(1).to(device)\n", 409 | "\n", 410 | " input_id = test_data['input_ids'].squeeze(1).to(device)\n", 411 | "\n", 412 | " loss, logits = model(input_id, mask, test_label)\n", 413 | "\n", 414 | " for i in range(logits.shape[0]):\n", 415 | "\n", 416 | " logits_clean = logits[i][test_label[i] != -100]\n", 417 | " label_clean = test_label[i][test_label[i] != -100]\n", 418 | "\n", 419 | " predictions = logits_clean.argmax(dim=1)\n", 420 | " acc = (predictions == label_clean).float().mean()\n", 421 | " total_acc_test += acc\n", 422 | "\n", 423 | " val_accuracy = total_acc_test / len(df_test)\n", 424 | " print(f'Test Accuracy: {total_acc_test / len(df_test): .3f}')\n", 425 | "\n", 426 | "\n", 427 | "evaluate(model, df_test)" 
428 | ] 429 | }, 430 | { 431 | "cell_type": "markdown", 432 | "id": "12edbaf3-cd39-463f-b2ba-d0ce46c246bd", 433 | "metadata": {}, 434 | "source": [ 435 | "# Predict One Sentence" 436 | ] 437 | }, 438 | { 439 | "cell_type": "code", 440 | "execution_count": null, 441 | "id": "99bc3835-a075-4fce-812b-7b9e96778816", 442 | "metadata": {}, 443 | "outputs": [], 444 | "source": [ 445 | "def align_word_ids(texts):\n", 446 | " \n", 447 | " tokenized_inputs = tokenizer(texts, padding='max_length', max_length=512, truncation=True)\n", 448 | "\n", 449 | " word_ids = tokenized_inputs.word_ids()\n", 450 | "\n", 451 | " previous_word_idx = None\n", 452 | " label_ids = []\n", 453 | "\n", 454 | " for word_idx in word_ids:\n", 455 | "\n", 456 | " if word_idx is None:\n", 457 | " label_ids.append(-100)\n", 458 | "\n", 459 | " elif word_idx != previous_word_idx:\n", 460 | " try:\n", 461 | " label_ids.append(1)\n", 462 | " except:\n", 463 | " label_ids.append(-100)\n", 464 | " else:\n", 465 | " try:\n", 466 | " label_ids.append(1 if label_all_tokens else -100)\n", 467 | " except:\n", 468 | " label_ids.append(-100)\n", 469 | " previous_word_idx = word_idx\n", 470 | "\n", 471 | " return label_ids\n", 472 | "\n", 473 | "\n", 474 | "def evaluate_one_text(model, sentence):\n", 475 | "\n", 476 | "\n", 477 | " use_cuda = torch.cuda.is_available()\n", 478 | " device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n", 479 | "\n", 480 | " if use_cuda:\n", 481 | " model = model.cuda()\n", 482 | "\n", 483 | " text = tokenizer(sentence, padding='max_length', max_length = 512, truncation=True, return_tensors=\"pt\")\n", 484 | "\n", 485 | " mask = text['attention_mask'].to(device)\n", 486 | " input_id = text['input_ids'].to(device)\n", 487 | " label_ids = torch.Tensor(align_word_ids(sentence)).unsqueeze(0).to(device)\n", 488 | "\n", 489 | " logits = model(input_id, mask, None)\n", 490 | " logits_clean = logits[0][label_ids != -100]\n", 491 | "\n", 492 | " predictions = logits_clean.argmax(dim=1).tolist()\n", 493 | " prediction_label = [ids_to_labels[i] for i in predictions]\n", 494 | " print(sentence)\n", 495 | " print(prediction_label)\n", 496 | " \n", 497 | "evaluate_one_text(model, 'Bill Gates is the founder of Microsoft')" 498 | ] 499 | }, 500 | { 501 | "cell_type": "code", 502 | "execution_count": null, 503 | "id": "fa186d96-7e3c-4457-a3ab-cabc61f2d261", 504 | "metadata": {}, 505 | "outputs": [], 506 | "source": [] 507 | }, 508 | { 509 | "cell_type": "code", 510 | "execution_count": null, 511 | "id": "723f31f2-12d7-48cc-9f88-1c8fbe860f4c", 512 | "metadata": {}, 513 | "outputs": [], 514 | "source": [] 515 | } 516 | ], 517 | "metadata": { 518 | "kernelspec": { 519 | "display_name": "Python 3 (ipykernel)", 520 | "language": "python", 521 | "name": "python3" 522 | }, 523 | "language_info": { 524 | "codemirror_mode": { 525 | "name": "ipython", 526 | "version": 3 527 | }, 528 | "file_extension": ".py", 529 | "mimetype": "text/x-python", 530 | "name": "python", 531 | "nbconvert_exporter": "python", 532 | "pygments_lexer": "ipython3", 533 | "version": "3.9.7" 534 | } 535 | }, 536 | "nbformat": 4, 537 | "nbformat_minor": 5 538 | } 539 | -------------------------------------------------------------------------------- /NER_BERT/NER_with_BERT.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 9, 6 | "id": "1657ccc8-b9dd-46e7-a08f-b9176ea274ba", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%%capture\n", 11 
| "pip install transformers" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "id": "438f352b-1664-4219-b257-855919d467fa", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "import pandas as pd\n", 22 | "import torch \n", 23 | "import numpy as np\n", 24 | "from transformers import BertTokenizerFast, BertForTokenClassification\n", 25 | "from torch.utils.data import DataLoader\n", 26 | "from tqdm import tqdm\n", 27 | "from torch.optim import SGD" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "id": "1a414a26-8c98-4eef-b97e-5d1a47df5b67", 33 | "metadata": {}, 34 | "source": [ 35 | "# Read CSV Data" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 3, 41 | "id": "e5536782-4d9a-4c35-9d22-7da36a08911a", 42 | "metadata": {}, 43 | "outputs": [ 44 | { 45 | "data": { 46 | "text/html": [ 47 | "
\n", 48 | "\n", 61 | "\n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | "
textlabels
0Thousands of demonstrators have marched throug...O O O O O O B-geo O O O O O B-geo O O O O O B-...
1Iranian officials say they expect to get acces...B-gpe O O O O O O O O O O O O O O B-tim O O O ...
2Helicopter gunships Saturday pounded militant ...O O B-tim O O O O O B-geo O O O O O B-org O O ...
3They left after a tense hour-long standoff wit...O O O O O O O O O O O
4U.N. relief coordinator Jan Egeland said Sunda...B-geo O O B-per I-per O B-tim O B-geo O B-gpe ...
\n", 97 | "
" 98 | ], 99 | "text/plain": [ 100 | " text \\\n", 101 | "0 Thousands of demonstrators have marched throug... \n", 102 | "1 Iranian officials say they expect to get acces... \n", 103 | "2 Helicopter gunships Saturday pounded militant ... \n", 104 | "3 They left after a tense hour-long standoff wit... \n", 105 | "4 U.N. relief coordinator Jan Egeland said Sunda... \n", 106 | "\n", 107 | " labels \n", 108 | "0 O O O O O O B-geo O O O O O B-geo O O O O O B-... \n", 109 | "1 B-gpe O O O O O O O O O O O O O O B-tim O O O ... \n", 110 | "2 O O B-tim O O O O O B-geo O O O O O B-org O O ... \n", 111 | "3 O O O O O O O O O O O \n", 112 | "4 B-geo O O B-per I-per O B-tim O B-geo O B-gpe ... " 113 | ] 114 | }, 115 | "execution_count": 3, 116 | "metadata": {}, 117 | "output_type": "execute_result" 118 | } 119 | ], 120 | "source": [ 121 | "df = pd.read_csv('ner.csv')\n", 122 | "df.head()" 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "id": "c2b1bd34-8843-4706-baa5-201f31245183", 128 | "metadata": {}, 129 | "source": [ 130 | "# Initialize Tokenizer" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "id": "d41be369-ee57-4949-aeb8-7960746d2aea", 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')" 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "id": "0e3439c2-580e-4972-8603-3a00bc3be62d", 146 | "metadata": {}, 147 | "source": [ 148 | "# Create Dataset Class " 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "id": "ac7f0682-ea50-4aeb-bcd3-9230df735554", 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "label_all_tokens = False\n", 159 | "\n", 160 | "def align_label(texts, labels):\n", 161 | " tokenized_inputs = tokenizer(texts, padding='max_length', max_length=512, truncation=True)\n", 162 | "\n", 163 | " word_ids = tokenized_inputs.word_ids()\n", 164 | "\n", 165 | " previous_word_idx = None\n", 166 | " label_ids = []\n", 167 | "\n", 168 | " for word_idx in word_ids:\n", 169 | "\n", 170 | " if word_idx is None:\n", 171 | " label_ids.append(-100)\n", 172 | "\n", 173 | " elif word_idx != previous_word_idx:\n", 174 | " try:\n", 175 | " label_ids.append(labels_to_ids[labels[word_idx]])\n", 176 | " except:\n", 177 | " label_ids.append(-100)\n", 178 | " else:\n", 179 | " try:\n", 180 | " label_ids.append(labels_to_ids[labels[word_idx]] if label_all_tokens else -100)\n", 181 | " except:\n", 182 | " label_ids.append(-100)\n", 183 | " previous_word_idx = word_idx\n", 184 | "\n", 185 | " return label_ids\n", 186 | "\n", 187 | "class DataSequence(torch.utils.data.Dataset):\n", 188 | "\n", 189 | " def __init__(self, df):\n", 190 | "\n", 191 | " lb = [i.split() for i in df['labels'].values.tolist()]\n", 192 | " txt = df['text'].values.tolist()\n", 193 | " self.texts = [tokenizer(str(i),\n", 194 | " padding='max_length', max_length = 512, truncation=True, return_tensors=\"pt\") for i in txt]\n", 195 | " self.labels = [align_label(i,j) for i,j in zip(txt, lb)]\n", 196 | "\n", 197 | " def __len__(self):\n", 198 | "\n", 199 | " return len(self.labels)\n", 200 | "\n", 201 | " def get_batch_data(self, idx):\n", 202 | "\n", 203 | " return self.texts[idx]\n", 204 | "\n", 205 | " def get_batch_labels(self, idx):\n", 206 | "\n", 207 | " return torch.LongTensor(self.labels[idx])\n", 208 | "\n", 209 | " def __getitem__(self, idx):\n", 210 | "\n", 211 | " batch_data = self.get_batch_data(idx)\n", 212 | " batch_labels = 
self.get_batch_labels(idx)\n", 213 | "\n", 214 | " return batch_data, batch_labels" 215 | ] 216 | }, 217 | { 218 | "cell_type": "markdown", 219 | "id": "496b3cb5-24c8-4c1b-8382-7c4d0f2339a4", 220 | "metadata": {}, 221 | "source": [ 222 | "# Split Data and Define Unique Labels" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "id": "6599961c-1cda-47bf-8c82-a1f9ebc94a95", 229 | "metadata": {}, 230 | "outputs": [], 231 | "source": [ 232 | "df = df[0:1000] # Use a subset of the data to keep training time manageable\n", 233 | "\n", 234 | "labels = [i.split() for i in df['labels'].values.tolist()]\n", 235 | "unique_labels = set()\n", 236 | "\n", 237 | "for lb in labels:\n", 238 | " unique_labels.update(lb)\n", 239 | "labels_to_ids = {k: v for v, k in enumerate(unique_labels)}\n", 240 | "ids_to_labels = {v: k for v, k in enumerate(unique_labels)}\n", 241 | "\n", 242 | "df_train, df_val, df_test = np.split(df.sample(frac=1, random_state=42),\n", 243 | " [int(.8 * len(df)), int(.9 * len(df))])" 244 | ] 245 | }, 246 | { 247 | "cell_type": "markdown", 248 | "id": "d54b96c5-6875-4990-9248-5d6ad5b053e9", 249 | "metadata": {}, 250 | "source": [ 251 | "# Build Model" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": null, 257 | "id": "13ebfa5e-c91a-4967-b0cc-23e314c32348", 258 | "metadata": {}, 259 | "outputs": [], 260 | "source": [ 261 | "class BertModel(torch.nn.Module):\n", 262 | "\n", 263 | " def __init__(self):\n", 264 | "\n", 265 | " super(BertModel, self).__init__()\n", 266 | "\n", 267 | " self.bert = BertForTokenClassification.from_pretrained('bert-base-cased', num_labels=len(unique_labels))\n", 268 | "\n", 269 | " def forward(self, input_id, mask, label):\n", 270 | "\n", 271 | " output = self.bert(input_ids=input_id, attention_mask=mask, labels=label, return_dict=False)\n", 272 | "\n", 273 | " return output" 274 | ] 275 | }, 276 | { 277 | "cell_type": "markdown", 278 | "id": "c3a48d06-d343-449b-829d-2bcad4b2af52", 279 | "metadata": {}, 280 | "source": [ 281 | "# Model Training" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": null, 287 | "id": "291bfdad-2df3-4de3-954d-b7e7a9a1b253", 288 | "metadata": {}, 289 | "outputs": [], 290 | "source": [ 291 | "def train_loop(model, df_train, df_val):\n", 292 | "\n", 293 | " train_dataset = DataSequence(df_train)\n", 294 | " val_dataset = DataSequence(df_val)\n", 295 | "\n", 296 | " train_dataloader = DataLoader(train_dataset, num_workers=4, batch_size=BATCH_SIZE, shuffle=True)\n", 297 | " val_dataloader = DataLoader(val_dataset, num_workers=4, batch_size=BATCH_SIZE)\n", 298 | "\n", 299 | " use_cuda = torch.cuda.is_available()\n", 300 | " device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n", 301 | "\n", 302 | " optimizer = SGD(model.parameters(), lr=LEARNING_RATE)\n", 303 | "\n", 304 | " if use_cuda:\n", 305 | " model = model.cuda()\n", 306 | "\n", 307 | " best_acc = 0\n", 308 | " best_loss = 1000\n", 309 | "\n", 310 | " for epoch_num in range(EPOCHS):\n", 311 | "\n", 312 | " total_acc_train = 0\n", 313 | " total_loss_train = 0\n", 314 | "\n", 315 | " model.train()\n", 316 | "\n", 317 | " for train_data, train_label in tqdm(train_dataloader):\n", 318 | "\n", 319 | " train_label = train_label.to(device)\n", 320 | " mask = train_data['attention_mask'].squeeze(1).to(device)\n", 321 | " input_id = train_data['input_ids'].squeeze(1).to(device)\n", 322 | "\n", 323 | " optimizer.zero_grad()\n", 324 | " loss, logits = model(input_id, mask, train_label)\n", 325 | "\n", 326 | " for i in
range(logits.shape[0]):\n", 327 | "\n", 328 | " logits_clean = logits[i][train_label[i] != -100]\n", 329 | " label_clean = train_label[i][train_label[i] != -100]\n", 330 | "\n", 331 | " predictions = logits_clean.argmax(dim=1)\n", 332 | " acc = (predictions == label_clean).float().mean()\n", 333 | " total_acc_train += acc\n", 334 | " total_loss_train += loss.item()\n", 335 | "\n", 336 | " loss.backward()\n", 337 | " optimizer.step()\n", 338 | "\n", 339 | " model.eval()\n", 340 | "\n", 341 | " total_acc_val = 0\n", 342 | " total_loss_val = 0\n", 343 | "\n", 344 | " for val_data, val_label in val_dataloader:\n", 345 | "\n", 346 | " val_label = val_label.to(device)\n", 347 | " mask = val_data['attention_mask'].squeeze(1).to(device)\n", 348 | " input_id = val_data['input_ids'].squeeze(1).to(device)\n", 349 | "\n", 350 | " loss, logits = model(input_id, mask, val_label)\n", 351 | "\n", 352 | " for i in range(logits.shape[0]):\n", 353 | "\n", 354 | " logits_clean = logits[i][val_label[i] != -100]\n", 355 | " label_clean = val_label[i][val_label[i] != -100]\n", 356 | "\n", 357 | " predictions = logits_clean.argmax(dim=1)\n", 358 | " acc = (predictions == label_clean).float().mean()\n", 359 | " total_acc_val += acc\n", 360 | " total_loss_val += loss.item()\n", 361 | "\n", 362 | " val_accuracy = total_acc_val / len(df_val)\n", 363 | " val_loss = total_loss_val / len(df_val)\n", 364 | "\n", 365 | " print(\n", 366 | " f'Epochs: {epoch_num + 1} | Train_Loss: {total_loss_train / len(df_train): .3f} | Train_Accuracy: {total_acc_train / len(df_train): .3f} | Val_Loss: {val_loss: .3f} | Val_Accuracy: {val_accuracy: .3f}')\n", 367 | "\n", 368 | "LEARNING_RATE = 5e-3\n", 369 | "EPOCHS = 5\n", 370 | "BATCH_SIZE = 2\n", 371 | "\n", 372 | "model = BertModel()\n", 373 | "train_loop(model, df_train, df_val)" 374 | ] 375 | }, 376 | { 377 | "cell_type": "markdown", 378 | "id": "69e1af60-33c3-497e-984f-0094f9bc3a4f", 379 | "metadata": {}, 380 | "source": [ 381 | "# Evaluate Model" 382 | ] 383 | }, 384 | { 385 | "cell_type": "code", 386 | "execution_count": null, 387 | "id": "04295796-7033-4bdf-849f-95e030fc94aa", 388 | "metadata": {}, 389 | "outputs": [], 390 | "source": [ 391 | "def evaluate(model, df_test):\n", 392 | "\n", 393 | " test_dataset = DataSequence(df_test)\n", 394 | "\n", 395 | " test_dataloader = DataLoader(test_dataset, num_workers=4, batch_size=1)\n", 396 | "\n", 397 | " use_cuda = torch.cuda.is_available()\n", 398 | " device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n", 399 | "\n", 400 | " if use_cuda:\n", 401 | " model = model.cuda()\n", 402 | "\n", 403 | " total_acc_test = 0.0\n", 404 | "\n", 405 | " for test_data, test_label in test_dataloader:\n", 406 | "\n", 407 | " test_label = test_label.to(device)\n", 408 | " mask = test_data['attention_mask'].squeeze(1).to(device)\n", 409 | "\n", 410 | " input_id = test_data['input_ids'].squeeze(1).to(device)\n", 411 | "\n", 412 | " loss, logits = model(input_id, mask, test_label)\n", 413 | "\n", 414 | " for i in range(logits.shape[0]):\n", 415 | "\n", 416 | " logits_clean = logits[i][test_label[i] != -100]\n", 417 | " label_clean = test_label[i][test_label[i] != -100]\n", 418 | "\n", 419 | " predictions = logits_clean.argmax(dim=1)\n", 420 | " acc = (predictions == label_clean).float().mean()\n", 421 | " total_acc_test += acc\n", 422 | "\n", 423 | " test_accuracy = total_acc_test / len(df_test)\n", 424 | " print(f'Test Accuracy: {test_accuracy: .3f}')\n", 425 | "\n", 426 | "\n", 427 | "evaluate(model, df_test)"
428 | ] 429 | }, 430 | { 431 | "cell_type": "markdown", 432 | "id": "12edbaf3-cd39-463f-b2ba-d0ce46c246bd", 433 | "metadata": {}, 434 | "source": [ 435 | "# Predict One Sentence" 436 | ] 437 | }, 438 | { 439 | "cell_type": "code", 440 | "execution_count": null, 441 | "id": "99bc3835-a075-4fce-812b-7b9e96778816", 442 | "metadata": {}, 443 | "outputs": [], 444 | "source": [ 445 | "def align_word_ids(texts):\n", 446 | " \n", 447 | " tokenized_inputs = tokenizer(texts, padding='max_length', max_length=512, truncation=True)\n", 448 | "\n", 449 | " word_ids = tokenized_inputs.word_ids()\n", 450 | "\n", 451 | " previous_word_idx = None\n", 452 | " label_ids = []\n", 453 | "\n", 454 | " for word_idx in word_ids:\n", 455 | "\n", 456 | " if word_idx is None:\n", 457 | " label_ids.append(-100)\n", 458 | "\n", 459 | " elif word_idx != previous_word_idx:\n", 460 | " try:\n", 461 | " label_ids.append(1)\n", 462 | " except:\n", 463 | " label_ids.append(-100)\n", 464 | " else:\n", 465 | " try:\n", 466 | " label_ids.append(1 if label_all_tokens else -100)\n", 467 | " except:\n", 468 | " label_ids.append(-100)\n", 469 | " previous_word_idx = word_idx\n", 470 | "\n", 471 | " return label_ids\n", 472 | "\n", 473 | "\n", 474 | "def evaluate_one_text(model, sentence):\n", 475 | "\n", 476 | "\n", 477 | " use_cuda = torch.cuda.is_available()\n", 478 | " device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n", 479 | "\n", 480 | " if use_cuda:\n", 481 | " model = model.cuda()\n", 482 | "\n", 483 | " text = tokenizer(sentence, padding='max_length', max_length = 512, truncation=True, return_tensors=\"pt\")\n", 484 | "\n", 485 | " mask = text['attention_mask'].to(device)\n", 486 | " input_id = text['input_ids'].to(device)\n", 487 | " label_ids = torch.Tensor(align_word_ids(sentence)).unsqueeze(0).to(device)\n", 488 | "\n", 489 | " logits = model(input_id, mask, None)\n", 490 | " logits_clean = logits[0][label_ids != -100]\n", 491 | "\n", 492 | " predictions = logits_clean.argmax(dim=1).tolist()\n", 493 | " prediction_label = [ids_to_labels[i] for i in predictions]\n", 494 | " print(sentence)\n", 495 | " print(prediction_label)\n", 496 | " \n", 497 | "evaluate_one_text(model, 'Bill Gates is the founder of Microsoft')" 498 | ] 499 | }, 500 | { 501 | "cell_type": "code", 502 | "execution_count": null, 503 | "id": "fa186d96-7e3c-4457-a3ab-cabc61f2d261", 504 | "metadata": {}, 505 | "outputs": [], 506 | "source": [] 507 | }, 508 | { 509 | "cell_type": "code", 510 | "execution_count": null, 511 | "id": "723f31f2-12d7-48cc-9f88-1c8fbe860f4c", 512 | "metadata": {}, 513 | "outputs": [], 514 | "source": [] 515 | } 516 | ], 517 | "metadata": { 518 | "kernelspec": { 519 | "display_name": "Python 3 (ipykernel)", 520 | "language": "python", 521 | "name": "python3" 522 | }, 523 | "language_info": { 524 | "codemirror_mode": { 525 | "name": "ipython", 526 | "version": 3 527 | }, 528 | "file_extension": ".py", 529 | "mimetype": "text/x-python", 530 | "name": "python", 531 | "nbconvert_exporter": "python", 532 | "pygments_lexer": "ipython3", 533 | "version": "3.9.7" 534 | } 535 | }, 536 | "nbformat": 4, 537 | "nbformat_minor": 5 538 | } 539 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Source code for Medium articles 2 | -------------------------------------------------------------------------------- /STS_BERT/STS_BERT.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "28df765c-ec6b-450b-a8b4-b40c65f73159", 6 | "metadata": {}, 7 | "source": [ 8 | "# Install necessary libraries" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "bc873328-78e8-4e81-bb0b-fc1ba41bcb82", 15 | "metadata": { 16 | "tags": [] 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "%%capture\n", 21 | "\n", 22 | "!pip install datasets\n", 23 | "!pip install sentence-transformers\n", 24 | "!pip install transformers" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "id": "002b8bb8-e806-48a8-ab70-8c8f35d13466", 30 | "metadata": {}, 31 | "source": [ 32 | "# Import libraries" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "id": "d8fe01b5-6a9a-458d-a97f-1fb20b621b0f", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "import torch\n", 43 | "from sentence_transformers import SentenceTransformer, models\n", 44 | "from transformers import BertTokenizer\n", 45 | "from torch.optim import Adam\n", 46 | "from torch.utils.data import DataLoader\n", 47 | "from tqdm import tqdm\n", 48 | "from datasets import load_dataset" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "id": "0bd9f3e6-9625-4c64-bb76-5e90911239a6", 54 | "metadata": {}, 55 | "source": [ 56 | "# Fetch data for training and test, as well as the tokenizer" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "id": "0fc07b38-8951-4d47-a5e1-ecdf810cdd93", 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "# Dataset for training\n", 67 | "dataset = load_dataset(\"stsb_multi_mt\", name=\"en\", split=\"train\")\n", 68 | "similarity = [i['similarity_score'] for i in dataset]\n", 69 | "normalized_similarity = [i/5.0 for i in similarity]\n", 70 | "\n", 71 | "# Dataset for test\n", 72 | "test_dataset = load_dataset(\"stsb_multi_mt\", name=\"en\", split=\"test\")\n", 73 | "\n", 74 | "# Prepare test data\n", 75 | "sentence_1_test = [i['sentence1'] for i in test_dataset]\n", 76 | "sentence_2_test = [i['sentence2'] for i in test_dataset]\n", 77 | "text_cat_test = [[str(x), str(y)] for x,y in zip(sentence_1_test, sentence_2_test)]\n", 78 | "\n", 79 | "# Set the tokenizer\n", 80 | "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "id": "f5299ca8-d00d-4151-8390-86c97403785d", 86 | "metadata": {}, 87 | "source": [ 88 | "# Define Model architecture" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "id": "8f44b7c4-d15c-4b79-bc5e-371d5bd42ffe", 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "class STSBertModel(torch.nn.Module):\n", 99 | "\n", 100 | " def __init__(self):\n", 101 | "\n", 102 | " super(STSBertModel, self).__init__()\n", 103 | "\n", 104 | " word_embedding_model = models.Transformer('bert-base-uncased', max_seq_length=128)\n", 105 | " pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())\n", 106 | " self.sts_model = SentenceTransformer(modules=[word_embedding_model, pooling_model])\n", 107 | "\n", 108 | " def forward(self, input_data):\n", 109 | "\n", 110 | " output = self.sts_model(input_data)\n", 111 | " \n", 112 | " return output" 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "id": "3103d5bd-e4f3-4b98-a1c6-cfc354f9edca", 118 | "metadata": {}, 119 | "source": [ 120 | "# Define Dataloader for 
training" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "id": "98ae2242-af8f-4c76-a21a-4baeeed3ae43", 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "class DataSequence(torch.utils.data.Dataset):\n", 131 | "\n", 132 | " def __init__(self, dataset):\n", 133 | "\n", 134 | " similarity = [i['similarity_score'] for i in dataset]\n", 135 | " self.label = [i/5.0 for i in similarity]\n", 136 | " self.sentence_1 = [i['sentence1'] for i in dataset]\n", 137 | " self.sentence_2 = [i['sentence2'] for i in dataset]\n", 138 | " self.text_cat = [[str(x), str(y)] for x,y in zip(self.sentence_1, self.sentence_2)]\n", 139 | "\n", 140 | " def __len__(self):\n", 141 | "\n", 142 | " return len(self.text_cat)\n", 143 | "\n", 144 | " def get_batch_labels(self, idx):\n", 145 | "\n", 146 | " return torch.tensor(self.label[idx])\n", 147 | "\n", 148 | " def get_batch_texts(self, idx):\n", 149 | "\n", 150 | " return tokenizer(self.text_cat[idx], padding='max_length', max_length = 128, truncation=True, return_tensors=\"pt\")\n", 151 | "\n", 152 | " def __getitem__(self, idx):\n", 153 | "\n", 154 | " batch_texts = self.get_batch_texts(idx)\n", 155 | " batch_y = self.get_batch_labels(idx)\n", 156 | "\n", 157 | " return batch_texts, batch_y\n", 158 | "\n", 159 | "def collate_fn(texts):\n", 160 | "\n", 161 | " num_texts = len(texts['input_ids'])\n", 162 | " features = list()\n", 163 | " for i in range(num_texts):\n", 164 | " features.append({'input_ids':texts['input_ids'][i], 'attention_mask':texts['attention_mask'][i]})\n", 165 | " \n", 166 | " return features" 167 | ] 168 | }, 169 | { 170 | "cell_type": "markdown", 171 | "id": "3fd81992-4a05-4f48-9869-83b38b7a3f90", 172 | "metadata": {}, 173 | "source": [ 174 | "# Define loss function for training" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": null, 180 | "id": "78ca0312-8648-4101-9429-7286d6268bb5", 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "class CosineSimilarityLoss(torch.nn.Module):\n", 185 | "\n", 186 | " def __init__(self, loss_fct = torch.nn.MSELoss(), cos_score_transformation=torch.nn.Identity()):\n", 187 | " \n", 188 | " super(CosineSimilarityLoss, self).__init__()\n", 189 | " self.loss_fct = loss_fct\n", 190 | " self.cos_score_transformation = cos_score_transformation\n", 191 | " self.cos = torch.nn.CosineSimilarity(dim=1)\n", 192 | "\n", 193 | " def forward(self, input, label):\n", 194 | "\n", 195 | " embedding_1 = torch.stack([inp[0] for inp in input])\n", 196 | " embedding_2 = torch.stack([inp[1] for inp in input])\n", 197 | "\n", 198 | " output = self.cos_score_transformation(self.cos(embedding_1, embedding_2))\n", 199 | "\n", 200 | " return self.loss_fct(output, label.squeeze())" 201 | ] 202 | }, 203 | { 204 | "cell_type": "markdown", 205 | "id": "0fc17726-80e4-409d-8183-1ec17d2e05da", 206 | "metadata": {}, 207 | "source": [ 208 | "# Train the Model" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": null, 214 | "id": "bcf41a99-3ce1-4125-b464-d3b8d0d295af", 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [ 218 | "def model_train(dataset, epochs, learning_rate, bs):\n", 219 | "\n", 220 | " use_cuda = torch.cuda.is_available()\n", 221 | " device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n", 222 | "\n", 223 | " model = STSBertModel()\n", 224 | "\n", 225 | " criterion = CosineSimilarityLoss()\n", 226 | " optimizer = Adam(model.parameters(), lr=learning_rate)\n", 227 | "\n", 228 | " train_dataset = 
DataSequence(dataset)\n", 229 | " train_dataloader = DataLoader(train_dataset, num_workers=4, batch_size=bs, shuffle=True)\n", 230 | "\n", 231 | " if use_cuda:\n", 232 | " model = model.cuda()\n", 233 | " criterion = criterion.cuda()\n", 234 | "\n", 235 | " # Make sure the model is in training mode\n", 236 | " model.train()\n", 237 | "\n", 238 | " for i in range(epochs):\n", 239 | "\n", 241 | " total_loss_train = 0.0\n", 242 | "\n", 243 | " for train_data, train_label in tqdm(train_dataloader):\n", 244 | "\n", 245 | " train_data['input_ids'] = train_data['input_ids'].to(device)\n", 246 | " train_data['attention_mask'] = train_data['attention_mask'].to(device)\n", 247 | " del train_data['token_type_ids']\n", 248 | "\n", 249 | " train_data = collate_fn(train_data)\n", 250 | "\n", 251 | " output = [model(feature)['sentence_embedding'] for feature in train_data]\n", 252 | "\n", 253 | " loss = criterion(output, train_label.to(device))\n", 254 | " total_loss_train += loss.item()\n", 255 | "\n", 256 | " loss.backward()\n", 257 | " optimizer.step()\n", 258 | " optimizer.zero_grad()\n", 259 | "\n", 260 | " print(f'Epochs: {i + 1} | Loss: {total_loss_train / len(dataset): .3f}')\n", 262 | "\n", 263 | " return model\n", 264 | "\n", 265 | "EPOCHS = 8\n", 266 | "LEARNING_RATE = 1e-6\n", 267 | "BATCH_SIZE = 8\n", 268 | "\n", 269 | "# Train the model\n", 270 | "trained_model = model_train(dataset, EPOCHS, LEARNING_RATE, BATCH_SIZE)" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": null, 276 | "id": "f8052fa6-8a87-4378-87c1-1714f55a88bc", 277 | "metadata": {}, 278 | "outputs": [], 279 | "source": [ 280 | "# Function to predict the similarity score for a pair of sentences\n", 281 | "def predict_sts(texts):\n", 282 | "\n", 283 | " trained_model.to('cpu')\n", 284 | " trained_model.eval()\n", 285 | " test_input = tokenizer(texts, padding='max_length', max_length = 128, truncation=True, return_tensors=\"pt\")\n", 288 | " del test_input['token_type_ids']\n", 289 | "\n", 290 | " test_output = trained_model(test_input)['sentence_embedding']\n", 291 | " sim = torch.nn.functional.cosine_similarity(test_output[0], test_output[1], dim=0).item()\n", 292 | "\n", 293 | " return sim" 294 | ] 295 | }, 296 | { 297 | "cell_type": "markdown", 298 | "id": "d4f70635-7047-46b5-b313-e2e6963ffdab", 299 | "metadata": {}, 300 | "source": [ 301 | "# Predict on test data" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": null, 307 | "id": "e4fde3ae-db23-4135-aa04-4a79d040b089", 308 | "metadata": {}, 309 | "outputs": [], 310 | "source": [ 311 | "predict_sts(text_cat_test[245])" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": null, 317 | "id": "f757ad3e-3ea6-4706-a4f3-4a090fa6dbd2", 318 | "metadata": {}, 319 | "outputs": [], 320 | "source": [ 321 | "predict_sts(text_cat_test[420])" 322 | ] 323 | } 324 | ], 325 | "metadata": { 326 | "kernelspec": { 327 | "display_name": "Python 3 (ipykernel)", 328 | "language": "python", 329 | "name": "python3" 330 | }, 331 | "language_info": { 332 | "codemirror_mode": { 333 | "name": "ipython", 334 | "version": 3 335 | }, 336 | "file_extension": ".py", 337 | "mimetype": "text/x-python", 338 | "name": "python", 339 | "nbconvert_exporter": "python", 340 | "pygments_lexer": "ipython3", 341 | "version": "3.9.7" 342 | } 343 | }, 344 | "nbformat": 4, 345 | "nbformat_minor": 5 346 | } 347 |
-------------------------------------------------------------------------------- /Spaces_Translation_App/app.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer 3 | import nltk 4 | from nltk import tokenize 5 | nltk.download('punkt') 6 | 7 | tokenizer = AutoTokenizer.from_pretrained("t5-base") 8 | 9 | @st.cache(allow_output_mutation=True) 10 | def load_model(): 11 | model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") 12 | 13 | return model 14 | 15 | model = load_model() 16 | 17 | st.sidebar.subheader('Select your source and target language below.') 18 | source_lang = st.sidebar.selectbox("Source language",['English']) 19 | target_lang = st.sidebar.selectbox("Target language",['German','French']) 20 | 21 | st.title('Simple English ➡️ German/French Translation App') 22 | 23 | st.write('This is a simple machine translation app that will translate\ 24 | your English input text into German or French\ 25 | by leveraging a pre-trained [Text-To-Text Transfer Transformers](https://arxiv.org/abs/1910.10683) model.') 26 | 27 | st.write('You can see the source code used to build this app in the \'Files and versions\' tab.') 28 | 29 | st.subheader('Input Text') 30 | text = st.text_area(' ', height=200) 31 | 32 | if text != '': 33 | 34 | prefix = 'translate '+str(source_lang)+' to '+str(target_lang)+': ' # T5 task prefix, e.g. 'translate English to German: ' 35 | sentence_token = tokenize.sent_tokenize(text) 36 | output = tokenizer([prefix+sentence for sentence in sentence_token], padding=True, return_tensors="pt") 37 | translated_id = model.generate(output["input_ids"], attention_mask=output['attention_mask'], max_length=100) 38 | translated_word = tokenizer.batch_decode(translated_id, skip_special_tokens=True) 39 | 40 | st.subheader('Translated Text') 41 | st.write(' '.join(translated_word)) 42 | -------------------------------------------------------------------------------- /Text_Classification_BERT/bert_medium.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "6fed133d-61b7-4ce6-8a44-fe98acf0eed2", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%%capture\n", 11 | "!pip install transformers" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "2f3a8fd1-25a9-426d-a6be-c93b750cbcb8", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "import pandas as pd\n", 22 | "import torch\n", 23 | "import numpy as np\n", 24 | "from transformers import BertTokenizer, BertModel\n", 25 | "from torch import nn\n", 26 | "from torch.optim import Adam\n", 27 | "from tqdm import tqdm" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "id": "47a53036-31ab-4374-bf15-a4dca17a7cbf", 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "datapath = '/content/drive/My Drive/Medium/bbc-text.csv'\n", 38 | "df = pd.read_csv(datapath)\n", 39 | "df.head()" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "id": "ab965eff-e1eb-416f-b80c-850554d8026c", 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "df.groupby(['category']).size().plot.bar()" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "id": "5074c270-ed3e-4e1a-863d-71737c743cb8", 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "tokenizer =
BertTokenizer.from_pretrained('bert-base-cased')\n", 60 | "labels = {'business':0,\n", 61 | " 'entertainment':1,\n", 62 | " 'sport':2,\n", 63 | " 'tech':3,\n", 64 | " 'politics':4\n", 65 | " }\n", 66 | "\n", 67 | "class Dataset(torch.utils.data.Dataset):\n", 68 | "\n", 69 | " def __init__(self, df):\n", 70 | "\n", 71 | " self.labels = [labels[label] for label in df['category']]\n", 72 | " self.texts = [tokenizer(text, \n", 73 | " padding='max_length', max_length = 512, truncation=True,\n", 74 | " return_tensors=\"pt\") for text in df['text']]\n", 75 | "\n", 76 | " def classes(self):\n", 77 | " return self.labels\n", 78 | "\n", 79 | " def __len__(self):\n", 80 | " return len(self.labels)\n", 81 | "\n", 82 | " def get_batch_labels(self, idx):\n", 83 | " # Fetch a batch of labels\n", 84 | " return np.array(self.labels[idx])\n", 85 | "\n", 86 | " def get_batch_texts(self, idx):\n", 87 | " # Fetch a batch of inputs\n", 88 | " return self.texts[idx]\n", 89 | "\n", 90 | " def __getitem__(self, idx):\n", 91 | "\n", 92 | " batch_texts = self.get_batch_texts(idx)\n", 93 | " batch_y = self.get_batch_labels(idx)\n", 94 | "\n", 95 | " return batch_texts, batch_y" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "id": "0c8a5d0f-80c3-42b3-9f06-ecfc3a21f395", 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "class BertClassifier(nn.Module):\n", 106 | "\n", 107 | " def __init__(self, dropout=0.5):\n", 108 | "\n", 109 | " super(BertClassifier, self).__init__()\n", 110 | "\n", 111 | " self.bert = BertModel.from_pretrained('bert-base-cased')\n", 112 | " self.dropout = nn.Dropout(dropout)\n", 113 | " self.linear = nn.Linear(768, 5)\n", 114 | " self.relu = nn.ReLU()\n", 115 | "\n", 116 | " def forward(self, input_id, mask):\n", 117 | "\n", 118 | " _, pooled_output = self.bert(input_ids= input_id, attention_mask=mask,return_dict=False)\n", 119 | " dropout_output = self.dropout(pooled_output)\n", 120 | " linear_output = self.linear(dropout_output)\n", 121 | " final_layer = self.relu(linear_output)\n", 122 | "\n", 123 | " return final_layer" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "id": "fa1f1cf7-65db-4966-9a55-ba26bd22ed6c", 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "def train(model, train_data, val_data, learning_rate, epochs):\n", 134 | "\n", 135 | " train, val = Dataset(train_data), Dataset(val_data)\n", 136 | "\n", 137 | " train_dataloader = torch.utils.data.DataLoader(train, batch_size=2, shuffle=True)\n", 138 | " val_dataloader = torch.utils.data.DataLoader(val, batch_size=2)\n", 139 | "\n", 140 | " use_cuda = torch.cuda.is_available()\n", 141 | " device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n", 142 | "\n", 143 | " criterion = nn.CrossEntropyLoss()\n", 144 | " optimizer = Adam(model.parameters(), lr= learning_rate)\n", 145 | "\n", 146 | " if use_cuda:\n", 147 | "\n", 148 | " model = model.cuda()\n", 149 | " criterion = criterion.cuda()\n", 150 | "\n", 151 | " for epoch_num in range(epochs):\n", 152 | "\n", 153 | " total_acc_train = 0\n", 154 | " total_loss_train = 0\n", 155 | "\n", 156 | " for train_input, train_label in tqdm(train_dataloader):\n", 157 | "\n", 158 | " train_label = train_label.to(device)\n", 159 | " mask = train_input['attention_mask'].to(device)\n", 160 | " input_id = train_input['input_ids'].squeeze(1).to(device)\n", 161 | "\n", 162 | " output = model(input_id, mask)\n", 163 | " \n", 164 | " batch_loss = criterion(output, train_label.long())\n", 165 | 
" total_loss_train += batch_loss.item()\n", 166 | " \n", 167 | " acc = (output.argmax(dim=1) == train_label).sum().item()\n", 168 | " total_acc_train += acc\n", 169 | "\n", 170 | " model.zero_grad()\n", 171 | " batch_loss.backward()\n", 172 | " optimizer.step()\n", 173 | " \n", 174 | " total_acc_val = 0\n", 175 | " total_loss_val = 0\n", 176 | "\n", 177 | " with torch.no_grad():\n", 178 | "\n", 179 | " for val_input, val_label in val_dataloader:\n", 180 | "\n", 181 | " val_label = val_label.to(device)\n", 182 | " mask = val_input['attention_mask'].to(device)\n", 183 | " input_id = val_input['input_ids'].squeeze(1).to(device)\n", 184 | "\n", 185 | " output = model(input_id, mask)\n", 186 | "\n", 187 | " batch_loss = criterion(output, val_label.long())\n", 188 | " total_loss_val += batch_loss.item()\n", 189 | " \n", 190 | " acc = (output.argmax(dim=1) == val_label).sum().item()\n", 191 | " total_acc_val += acc\n", 192 | " \n", 193 | " print(\n", 194 | " f'Epochs: {epoch_num + 1} | Train Loss: {total_loss_train / len(train_data): .3f} | Train Accuracy: {total_acc_train / len(train_data): .3f} | Val Loss: {total_loss_val / len(val_data): .3f} | Val Accuracy: {total_acc_val / len(val_data): .3f}')\n", 195 | " " 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "id": "cd8a670d-c449-45fe-8f4c-9a5fb27855c1", 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "def evaluate(model, test_data):\n", 206 | "\n", 207 | " test = Dataset(test_data)\n", 208 | "\n", 209 | " test_dataloader = torch.utils.data.DataLoader(test, batch_size=2)\n", 210 | "\n", 211 | " use_cuda = torch.cuda.is_available()\n", 212 | " device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n", 213 | "\n", 214 | " if use_cuda:\n", 215 | "\n", 216 | " model = model.cuda()\n", 217 | "\n", 218 | " total_acc_test = 0\n", 219 | " with torch.no_grad():\n", 220 | "\n", 221 | " for test_input, test_label in test_dataloader:\n", 222 | "\n", 223 | " test_label = test_label.to(device)\n", 224 | " mask = test_input['attention_mask'].to(device)\n", 225 | " input_id = test_input['input_ids'].squeeze(1).to(device)\n", 226 | "\n", 227 | " output = model(input_id, mask)\n", 228 | "\n", 229 | " acc = (output.argmax(dim=1) == test_label).sum().item()\n", 230 | " total_acc_test += acc\n", 231 | " \n", 232 | " print(f'Test Accuracy: {total_acc_test / len(test_data): .3f}')" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": null, 238 | "id": "25d2231d-fef1-42cf-a73e-188cac932727", 239 | "metadata": {}, 240 | "outputs": [], 241 | "source": [ 242 | "np.random.seed(112)\n", 243 | "df_train, df_val, df_test = np.split(df.sample(frac=1, random_state=42), \n", 244 | " [int(.8*len(df)), int(.9*len(df))])\n", 245 | "\n", 246 | "print(len(df_train),len(df_val), len(df_test))" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": null, 252 | "id": "30242239-de70-4c03-8f56-9f5ade43518d", 253 | "metadata": {}, 254 | "outputs": [], 255 | "source": [ 256 | "EPOCHS = 5\n", 257 | "model = BertClassifier()\n", 258 | "LR = 1e-6\n", 259 | " \n", 260 | "train(model, df_train, df_val, LR, EPOCHS)" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": null, 266 | "id": "ccc00f0a-9a15-4942-9c9b-2f9789c8dd22", 267 | "metadata": {}, 268 | "outputs": [], 269 | "source": [ 270 | "evaluate(model, df_test)" 271 | ] 272 | } 273 | ], 274 | "metadata": { 275 | "kernelspec": { 276 | "display_name": "Python 3", 277 | "language": "python", 278 | "name": "python3" 
279 | }, 280 | "language_info": { 281 | "codemirror_mode": { 282 | "name": "ipython", 283 | "version": 3 284 | }, 285 | "file_extension": ".py", 286 | "mimetype": "text/x-python", 287 | "name": "python", 288 | "nbconvert_exporter": "python", 289 | "pygments_lexer": "ipython3", 290 | "version": "3.8.8" 291 | } 292 | }, 293 | "nbformat": 4, 294 | "nbformat_minor": 5 295 | } 296 | --------------------------------------------------------------------------------