├── NLP_Competitions.md ├── notebooks ├── transformers_multiclass_classification.ipynb ├── transformers_ner.ipynb └── transformers_summarization_wandb.ipynb ├── readme.md └── researcher.md /NLP_Competitions.md: -------------------------------------------------------------------------------- 1 | # Artificial Intelligence Competitions 2 | 3 | 4 | - [KDD Cup](https://www.kdd.org/kdd2019/calls/view/kdd-cup-2019-call-for-proposals) 5 | - Starts April 2 6 | - Basic machine learning / Auto ML 7 | - [CLEF 2019](http://clef2019.clef-initiative.eu/index.php?page=Pages/labs.html) 8 | - [TREC](https://trec.nist.gov/pubs/call2019.html) 9 | - answer retrieval / news / medicine / information retrieval 10 | - [TAC](https://tac.nist.gov/) 11 | - TAC2018 12 | - 2018 registration deadline: July 15 13 | - ACE has been discontinued 14 | - [SemEval](http://alt.qcri.org/semeval2020/) 15 | - Introduction in Chinese: https://zhuanlan.zhihu.com/p/81062536?utm_source=wechat_session&utm_medium=social&utm_oi=30536802238464 16 | - semantic parsing 17 | - emotion detection / offensive language detection 18 | - *suggestion mining from reviews and forums (ended)* 19 | - fact checking / rumour detection 20 | - [Alibaba Cloud Tianchi](https://tianchi.aliyun.com/home/) 21 | - Open competitions: https://tianchi.aliyun.com/competition/gameList/activeList 22 | - [Kaggle data competitions](https://www.kaggle.com/competitions) 23 | - Open competitions: https://www.kaggle.com/competitions 24 | - Predict stock price via news 25 | - Customer Transaction Prediction 26 | - entity linking: https://www.kaggle.com/c/gendered-pronoun-resolution 27 | - text classification: https://www.kaggle.com/c/transfer-learning-on-stack-exchange-tags 28 | - [DataFountain](https://www.datafountain.cn/) 29 | - Competitions open for registration: https://www.datafountain.cn/competitions?state=can_signup 30 | - User profiling 31 | - [AI Challenger](https://challenger.ai/) 32 | - 8/29: training, validation, and test sets released 33 | - 12/19: deadline 34 | -------------------------------------------------------------------------------- /notebooks/transformers_multiclass_classification.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "Mmzp8U1LD4ZY" 7 | }, 8 | "source": [ 9 | "# Fine Tuning Transformer for MultiClass Text Classification" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": { 15 | "id": "fyvd41hpD4Zb" 16 | }, 17 | "source": [ 18 | "### Introduction\n", 19 | "\n", 20 | "In this tutorial we will be fine tuning a transformer model for the **Multiclass text classification** problem.\n", 21 | "This is one of the most common business problems, where a given piece of text/sentence/document needs to be classified into one category from a given list.\n", 22 | "\n", 23 | "#### Flow of the notebook\n", 24 | "\n", 25 | "The notebook is divided into separate sections to provide an organized walkthrough of the process used. This process can be modified for individual use cases. The sections are:\n", 26 | "\n", 27 | "1. [Importing Python Libraries and preparing the environment](#section01)\n", 28 | "2. [Importing and Pre-Processing the domain data](#section02)\n", 29 | "3. [Preparing the Dataset and Dataloader](#section03)\n", 30 | "4. [Creating the Neural Network for Fine Tuning](#section04)\n", 31 | "5. [Fine Tuning the Model](#section05)\n", 32 | "6. [Validating the Model Performance](#section06)\n", 33 | "7. [Saving the model and artifacts for Inference in Future](#section07)\n", 34 | "\n", 35 | "#### Technical Details\n", 36 | "\n", 37 | "This script leverages multiple tools designed by other teams. Details of the tools used are given below. 
Please ensure that these elements are present in your setup to successfully implement this script.\n", 38 | "\n", 39 | " - Data:\n", 40 | "\t - We are using the News Aggregator dataset available from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/News+Aggregator)\n", 41 | "\t - We are referring only to the first csv file from the data dump: `newsCorpora.csv`\n", 42 | "\t - There are `422937` rows of data. Each row has the following fields:\n", 43 | "\t\t - ID: Numeric ID \n", 44 | "\t\t - TITLE: News title \n", 45 | "\t\t - URL: URL \n", 46 | "\t\t - PUBLISHER: Publisher name \n", 47 | "\t\t - CATEGORY: News category (b = business, t = science and technology, e = entertainment, m = health) \n", 48 | "\t\t - STORY: Alphanumeric ID of the cluster that includes news about the same story \n", 49 | "\t\t - HOSTNAME: URL hostname \n", 50 | "\t\t - TIMESTAMP: Approximate time the news was published, as the number of milliseconds since the epoch 00:00:00 GMT, January 1, 1970\n", 51 | "\n", 52 | "\n", 53 | " - Language Model Used:\n", 54 | "\t - DistilBERT is a smaller transformer model than BERT or RoBERTa. 
It is created by applying distillation to BERT.\n", 55 | "\t - [Blog-Post](https://medium.com/huggingface/distilbert-8cf3380435b5)\n", 56 | "\t - [Research Paper](https://arxiv.org/abs/1910.01108)\n", 57 | "\t - [Documentation for Python](https://huggingface.co/transformers/model_doc/distilbert.html)\n", 58 | "\n", 59 | "\n", 60 | " - Hardware Requirements:\n", 61 | "\t - Python 3.6 and above\n", 62 | "\t - PyTorch, Transformers and all the stock Python ML libraries\n", 63 | "\t - GPU enabled setup\n", 64 | "\n", 65 | "\n", 66 | " - Script Objective:\n", 67 | "\t - The objective of this script is to fine tune DistilBERT to be able to classify a news headline into the following categories:\n", 68 | "\t\t - Business\n", 69 | "\t\t - Technology\n", 70 | "\t\t - Health\n", 71 | "\t\t - Entertainment\n" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": { 77 | "id": "7-sC5FY1D4Zd" 78 | }, 79 | "source": [ 80 | "\n", 81 | "### Importing Python Libraries and preparing the environment\n", 82 | "\n", 83 | "At this step we will be importing the libraries and modules needed to run our script. Libraries are:\n", 84 | "* Pandas\n", 85 | "* PyTorch\n", 86 | "* PyTorch Utils for Dataset and Dataloader\n", 87 | "* Transformers\n", 88 | "* DistilBERT Model and Tokenizer\n", 89 | "\n", 90 | "After that we will prepare the device for CUDA execution. This configuration is needed if you want to use an onboard GPU."
91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": { 97 | "id": "wuMlXT80GAMK" 98 | }, 99 | "outputs": [], 100 | "source": [ 101 | "# Importing the libraries needed\n", 102 | "import pandas as pd\n", 103 | "import torch\n", 104 | "import transformers\n", 105 | "from torch.utils.data import Dataset, DataLoader\n", 106 | "from transformers import DistilBertModel, DistilBertTokenizer" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": { 113 | "id": "xQMKTZ4ARk12" 114 | }, 115 | "outputs": [], 116 | "source": [ 117 | "# Setting up the device for GPU usage\n", 118 | "\n", 119 | "from torch import cuda\n", 120 | "device = 'cuda' if cuda.is_available() else 'cpu'" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "metadata": { 126 | "id": "ABWbKP_qD4Zf" 127 | }, 128 | "source": [ 129 | "\n", 130 | "### Importing and Pre-Processing the domain data\n", 131 | "\n", 132 | "We will be working with the data and preparing it for fine tuning.\n", 133 | "*Assuming that `newsCorpora.csv` is already downloaded in your `data` folder*\n", 134 | "\n", 135 | "Import the file into a dataframe and give it the headers as per the documentation.\n", 136 | "Clean the dataframe by removing the unwanted columns and creating an additional column for training.\n", 137 | "The final Dataframe will be something like this:\n", 138 | "\n", 139 | "|TITLE|CATEGORY|ENCODE_CAT|\n", 140 | "|--|--|--|\n", 141 | "| title_1|Entertainment | 1 |\n", 142 | "| title_2|Entertainment | 1 |\n", 143 | "| title_3|Business| 2 |\n", 144 | "| title_4|Science| 3 |\n", 145 | "| title_5|Science| 3 |\n", 146 | "| title_6|Health| 4 |" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": { 153 | "id": "iNCaZ2epNcSO" 154 | }, 155 | "outputs": [], 156 | "source": [ 157 | "# Import the csv into pandas dataframe and add the headers\n", 158 | "df = 
pd.read_csv('./data/newsCorpora.csv', sep='\\t', names=['ID','TITLE', 'URL', 'PUBLISHER', 'CATEGORY', 'STORY', 'HOSTNAME', 'TIMESTAMP'])\n", 159 | "# df.head()\n", 160 | "# # Removing unwanted columns and only leaving title of news and the category which will be the target\n", 161 | "df = df[['TITLE','CATEGORY']]\n", 162 | "# df.head()\n", 163 | "\n", 164 | "# # Converting the codes to appropriate categories using a dictionary\n", 165 | "my_dict = {\n", 166 | " 'e':'Entertainment',\n", 167 | " 'b':'Business',\n", 168 | " 't':'Science',\n", 169 | " 'm':'Health'\n", 170 | "}\n", 171 | "\n", 172 | "def update_cat(x):\n", 173 | " return my_dict[x]\n", 174 | "\n", 175 | "df['CATEGORY'] = df['CATEGORY'].apply(lambda x: update_cat(x))\n", 176 | "\n", 177 | "encode_dict = {}\n", 178 | "\n", 179 | "def encode_cat(x):\n", 180 | " if x not in encode_dict.keys():\n", 181 | " encode_dict[x]=len(encode_dict)\n", 182 | " return encode_dict[x]\n", 183 | "\n", 184 | "df['ENCODE_CAT'] = df['CATEGORY'].apply(lambda x: encode_cat(x))" 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "metadata": { 190 | "id": "9WK8-eswD4Zh" 191 | }, 192 | "source": [ 193 | "\n", 194 | "### Preparing the Dataset and Dataloader\n", 195 | "\n", 196 | "We will start by defining a few key variables that will be used later during the training/fine tuning stage.\n", 197 | "Next we create the Dataset class, which defines how the text is pre-processed before it is sent to the neural network. We will also define the Dataloader that will feed the data in batches to the neural network for training and processing.\n", 198 | "Dataset and Dataloader are constructs of the PyTorch library for defining and controlling the data pre-processing and its passage to the neural network. 
For further reading on Dataset and Dataloader, see the [docs at PyTorch](https://pytorch.org/docs/stable/data.html)\n", 199 | "\n", 200 | "#### *Triage* Dataset Class\n", 201 | "- This class is defined to accept the Dataframe as input and generate tokenized output that is used by the DistilBERT model for training.\n", 202 | "- We are using the DistilBERT tokenizer to tokenize the data in the `TITLE` column of the dataframe.\n", 203 | "- The tokenizer uses the `encode_plus` method to perform tokenization and generate the necessary outputs, namely: `ids`, `attention_mask`\n", 204 | "- To read further into the tokenizer, [refer to this document](https://huggingface.co/transformers/model_doc/distilbert.html#distilberttokenizer)\n", 205 | "- `targets` is the encoded category of the news headline.\n", 206 | "- The *Triage* class is used to create 2 datasets, for training and for validation.\n", 207 | "- *Training Dataset* is used to fine tune the model: **80% of the original data**\n", 208 | "- *Validation Dataset* is used to evaluate the performance of the model. The model has not seen this data during training.\n", 209 | "\n", 210 | "#### Dataloader\n", 211 | "- Dataloader is used to create the training and validation dataloaders, which load data to the neural network in a defined manner. 
This is needed because all the data from the dataset cannot be loaded into memory at once, so the amount of data loaded into memory and then passed to the neural network needs to be controlled.\n", 212 | "- This control is achieved using parameters such as `batch_size` and `max_len`.\n", 213 | "- Training and Validation dataloaders are used in the training and validation parts of the flow respectively" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": null, 219 | "metadata": { 220 | "id": "JrBr2YesGdO_" 221 | }, 222 | "outputs": [], 223 | "source": [ 224 | "# Defining some key variables that will be used later on in the training\n", 225 | "MAX_LEN = 512\n", 226 | "TRAIN_BATCH_SIZE = 4\n", 227 | "VALID_BATCH_SIZE = 2\n", 228 | "EPOCHS = 1\n", 229 | "LEARNING_RATE = 1e-05\n", 230 | "tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": null, 236 | "metadata": { 237 | "id": "2vX7kzaAHu39" 238 | }, 239 | "outputs": [], 240 | "source": [ 241 | "class Triage(Dataset):\n", 242 | " def __init__(self, dataframe, tokenizer, max_len):\n", 243 | " self.len = len(dataframe)\n", 244 | " self.data = dataframe\n", 245 | " self.tokenizer = tokenizer\n", 246 | " self.max_len = max_len\n", 247 | "\n", 248 | " def __getitem__(self, index):\n", 249 | " title = str(self.data.TITLE[index])\n", 250 | " title = \" \".join(title.split())\n", 251 | " inputs = self.tokenizer.encode_plus(\n", 252 | " title,\n", 253 | " None,\n", 254 | " add_special_tokens=True,\n", 255 | " max_length=self.max_len,\n", 256 | " padding='max_length',\n", 257 | " return_token_type_ids=True,\n", 258 | " truncation=True\n", 259 | " )\n", 260 | " ids = inputs['input_ids']\n", 261 | " mask = inputs['attention_mask']\n", 262 | "\n", 263 | " return {\n", 264 | " 'ids': torch.tensor(ids, dtype=torch.long),\n", 265 | " 'mask': torch.tensor(mask, dtype=torch.long),\n", 266 | " 'targets': 
torch.tensor(self.data.ENCODE_CAT[index], dtype=torch.long)\n", 267 | " }\n", 268 | "\n", 269 | " def __len__(self):\n", 270 | " return self.len" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": null, 276 | "metadata": { 277 | "id": "Zcwq13c0NE9c", 278 | "outputId": "c66ec4e9-a194-47f8-f3ab-de6a9d180a8b" 279 | }, 280 | "outputs": [ 281 | { 282 | "name": "stdout", 283 | "output_type": "stream", 284 | "text": [ 285 | "FULL Dataset: (422419, 3)\n", 286 | "TRAIN Dataset: (337935, 3)\n", 287 | "TEST Dataset: (84484, 3)\n" 288 | ] 289 | } 290 | ], 291 | "source": [ 292 | "# Creating the dataset and dataloader for the neural network\n", 293 | "\n", 294 | "train_size = 0.8\n", 295 | "train_dataset=df.sample(frac=train_size,random_state=200)\n", 296 | "test_dataset=df.drop(train_dataset.index).reset_index(drop=True)\n", 297 | "train_dataset = train_dataset.reset_index(drop=True)\n", 298 | "\n", 299 | "\n", 300 | "print(\"FULL Dataset: {}\".format(df.shape))\n", 301 | "print(\"TRAIN Dataset: {}\".format(train_dataset.shape))\n", 302 | "print(\"TEST Dataset: {}\".format(test_dataset.shape))\n", 303 | "\n", 304 | "training_set = Triage(train_dataset, tokenizer, MAX_LEN)\n", 305 | "testing_set = Triage(test_dataset, tokenizer, MAX_LEN)" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": null, 311 | "metadata": { 312 | "id": "l1BgA1CkQSYa" 313 | }, 314 | "outputs": [], 315 | "source": [ 316 | "train_params = {'batch_size': TRAIN_BATCH_SIZE,\n", 317 | " 'shuffle': True,\n", 318 | " 'num_workers': 0\n", 319 | " }\n", 320 | "\n", 321 | "test_params = {'batch_size': VALID_BATCH_SIZE,\n", 322 | " 'shuffle': True,\n", 323 | " 'num_workers': 0\n", 324 | " }\n", 325 | "\n", 326 | "training_loader = DataLoader(training_set, **train_params)\n", 327 | "testing_loader = DataLoader(testing_set, **test_params)" 328 | ] 329 | }, 330 | { 331 | "cell_type": "markdown", 332 | "metadata": { 333 | "id": "P2HaCpRlD4Zj" 334 | }, 335 | "source": [ 
336 | "\n", 337 | "### Creating the Neural Network for Fine Tuning\n", 338 | "\n", 339 | "#### Neural Network\n", 340 | " - We will be creating a neural network with the `DistillBERTClass`.\n", 341 | " - This network will have the DistilBERT Language model followed by a `Linear` pre-classifier layer with a `ReLU` activation, a `dropout`, and finally a `Linear` layer to obtain the final outputs.\n", 342 | " - The data will be fed to the DistilBERT Language model as defined in the dataset.\n", 343 | " - The final layer outputs are what will be compared to the `encoded category` to determine the accuracy of the model's predictions.\n", 344 | " - We will initiate an instance of the network called `model`. This instance will be used for training and then to save the final trained model for future inference.\n", 345 | "\n", 346 | "#### Loss Function and Optimizer\n", 347 | " - `Loss Function` and `Optimizer` are defined in the next cell.\n", 348 | " - The `Loss Function` is used to calculate the difference between the output created by the model and the actual output.\n", 349 | " - `Optimizer` is used to update the weights of the neural network to improve its performance.\n", 350 | "\n", 351 | "#### Further Reading\n", 352 | "- You can refer to my [Pytorch Tutorials](https://github.com/abhimishra91/pytorch-tutorials) to get an intuition of Loss Function and Optimizer.\n", 353 | "- [Pytorch Documentation for Loss Function](https://pytorch.org/docs/stable/nn.html#loss-functions)\n", 354 | "- [Pytorch Documentation for Optimizer](https://pytorch.org/docs/stable/optim.html)\n", 355 | "- Refer to the links provided on the top of the notebook to read more about DistilBERT." 
356 | ] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "execution_count": null, 361 | "metadata": { 362 | "id": "H8yQV2evD4Zk" 363 | }, 364 | "outputs": [], 365 | "source": [ 366 | "# Creating the customized model, by adding a drop out and a dense layer on top of distil bert to get the final output for the model.\n", 367 | "\n", 368 | "class DistillBERTClass(torch.nn.Module):\n", 369 | " def __init__(self):\n", 370 | " super(DistillBERTClass, self).__init__()\n", 371 | " self.l1 = DistilBertModel.from_pretrained(\"distilbert-base-uncased\")\n", 372 | " self.pre_classifier = torch.nn.Linear(768, 768)\n", 373 | " self.dropout = torch.nn.Dropout(0.3)\n", 374 | " self.classifier = torch.nn.Linear(768, 4)\n", 375 | "\n", 376 | " def forward(self, input_ids, attention_mask):\n", 377 | " output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask)\n", 378 | " hidden_state = output_1[0]\n", 379 | " pooler = hidden_state[:, 0]\n", 380 | " pooler = self.pre_classifier(pooler)\n", 381 | " pooler = torch.nn.ReLU()(pooler)\n", 382 | " pooler = self.dropout(pooler)\n", 383 | " output = self.classifier(pooler)\n", 384 | " return output" 385 | ] 386 | }, 387 | { 388 | "cell_type": "code", 389 | "execution_count": null, 390 | "metadata": { 391 | "collapsed": true, 392 | "id": "DuXft-9lD4Zk", 393 | "outputId": "605740ea-3c00-4fe0-86ae-3faee1abad7e" 394 | }, 395 | "outputs": [ 396 | { 397 | "data": { 398 | "text/plain": [ 399 | "DistillBERTClass(\n", 400 | " (l1): DistilBertModel(\n", 401 | " (embeddings): Embeddings(\n", 402 | " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n", 403 | " (position_embeddings): Embedding(512, 768)\n", 404 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 405 | " (dropout): Dropout(p=0.1, inplace=False)\n", 406 | " )\n", 407 | " (transformer): Transformer(\n", 408 | " (layer): ModuleList(\n", 409 | " (0): TransformerBlock(\n", 410 | " (attention): MultiHeadSelfAttention(\n", 411 | " (dropout): 
Dropout(p=0.1, inplace=False)\n", 412 | " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", 413 | " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", 414 | " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", 415 | " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", 416 | " )\n", 417 | " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 418 | " (ffn): FFN(\n", 419 | " (dropout): Dropout(p=0.1, inplace=False)\n", 420 | " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", 421 | " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", 422 | " )\n", 423 | " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 424 | " )\n", 425 | " (1): TransformerBlock(\n", 426 | " (attention): MultiHeadSelfAttention(\n", 427 | " (dropout): Dropout(p=0.1, inplace=False)\n", 428 | " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", 429 | " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", 430 | " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", 431 | " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", 432 | " )\n", 433 | " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 434 | " (ffn): FFN(\n", 435 | " (dropout): Dropout(p=0.1, inplace=False)\n", 436 | " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", 437 | " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", 438 | " )\n", 439 | " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 440 | " )\n", 441 | " (2): TransformerBlock(\n", 442 | " (attention): MultiHeadSelfAttention(\n", 443 | " (dropout): Dropout(p=0.1, inplace=False)\n", 444 | " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", 445 | " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", 446 | " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", 
447 | " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", 448 | " )\n", 449 | " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 450 | " (ffn): FFN(\n", 451 | " (dropout): Dropout(p=0.1, inplace=False)\n", 452 | " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", 453 | " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", 454 | " )\n", 455 | " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 456 | " )\n", 457 | " (3): TransformerBlock(\n", 458 | " (attention): MultiHeadSelfAttention(\n", 459 | " (dropout): Dropout(p=0.1, inplace=False)\n", 460 | " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", 461 | " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", 462 | " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", 463 | " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", 464 | " )\n", 465 | " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 466 | " (ffn): FFN(\n", 467 | " (dropout): Dropout(p=0.1, inplace=False)\n", 468 | " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", 469 | " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", 470 | " )\n", 471 | " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 472 | " )\n", 473 | " (4): TransformerBlock(\n", 474 | " (attention): MultiHeadSelfAttention(\n", 475 | " (dropout): Dropout(p=0.1, inplace=False)\n", 476 | " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", 477 | " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", 478 | " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", 479 | " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", 480 | " )\n", 481 | " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 482 | " (ffn): FFN(\n", 483 | " (dropout): Dropout(p=0.1, inplace=False)\n", 484 | 
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", 485 | " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", 486 | " )\n", 487 | " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 488 | " )\n", 489 | " (5): TransformerBlock(\n", 490 | " (attention): MultiHeadSelfAttention(\n", 491 | " (dropout): Dropout(p=0.1, inplace=False)\n", 492 | " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", 493 | " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", 494 | " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", 495 | " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", 496 | " )\n", 497 | " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 498 | " (ffn): FFN(\n", 499 | " (dropout): Dropout(p=0.1, inplace=False)\n", 500 | " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", 501 | " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", 502 | " )\n", 503 | " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 504 | " )\n", 505 | " )\n", 506 | " )\n", 507 | " )\n", 508 | " (l2): Dropout(p=0.3, inplace=False)\n", 509 | " (l3): Linear(in_features=768, out_features=1, bias=True)\n", 510 | ")" 511 | ] 512 | }, 513 | "execution_count": 9, 514 | "metadata": {}, 515 | "output_type": "execute_result" 516 | } 517 | ], 518 | "source": [ 519 | "model = DistillBERTClass()\n", 520 | "model.to(device)" 521 | ] 522 | }, 523 | { 524 | "cell_type": "code", 525 | "execution_count": null, 526 | "metadata": { 527 | "id": "6Ty28GEbD4Zk" 528 | }, 529 | "outputs": [], 530 | "source": [ 531 | "# Creating the loss function and optimizer\n", 532 | "loss_function = torch.nn.CrossEntropyLoss()\n", 533 | "optimizer = torch.optim.Adam(params = model.parameters(), lr=LEARNING_RATE)" 534 | ] 535 | }, 536 | { 537 | "cell_type": "markdown", 538 | "metadata": { 539 | "id": "KanYVPObD4Zk" 540 | }, 541 | "source": 
[ 542 | "\n", 543 | "### Fine Tuning the Model\n", 544 | "\n", 545 | "After all the effort of loading and preparing the data and datasets, creating the model, and defining its loss and optimizer, this is probably the easiest step in the process.\n", 546 | "\n", 547 | "Here we define a training function that trains the model on the training dataset created above for a specified number of epochs (`EPOCHS`). The number of epochs defines how many times the complete data will be passed through the network.\n", 548 | "\n", 549 | "The following events happen in this function to fine tune the neural network:\n", 550 | "- The dataloader passes data to the model based on the batch size.\n", 551 | "- The output from the model and the actual category are compared to calculate the loss.\n", 552 | "- The loss value is used to optimize the weights of the neurons in the network.\n", 553 | "- After every 5000 steps the loss value is printed in the console.\n", 554 | "\n", 555 | "As you can see, within just 1 epoch, by the final step the model was working with a minuscule loss of 0.0002485, i.e. the output is extremely close to the actual output." 
556 | ] 557 | }, 558 | { 559 | "cell_type": "code", 560 | "execution_count": null, 561 | "metadata": { 562 | "id": "S6oCWa5RD4Zl" 563 | }, 564 | "outputs": [], 565 | "source": [ 566 | "# Function to count the correct predictions, used to calculate the accuracy of the model\n", 567 | "\n", 568 | "def calcuate_accu(big_idx, targets):\n", 569 | " n_correct = (big_idx==targets).sum().item()\n", 570 | " return n_correct" 571 | ] 572 | }, 573 | { 574 | "cell_type": "code", 575 | "execution_count": null, 576 | "metadata": { 577 | "id": "GCSFxHHQD4Zl" 578 | }, 579 | "outputs": [], 580 | "source": [ 581 | "# Defining the training function on the 80% of the dataset for tuning the distilbert model\n", 582 | "\n", 583 | "def train(epoch):\n", 584 | " tr_loss = 0\n", 585 | " n_correct = 0\n", 586 | " nb_tr_steps = 0\n", 587 | " nb_tr_examples = 0\n", 588 | " model.train()\n", 589 | " for _,data in enumerate(training_loader, 0):\n", 590 | " ids = data['ids'].to(device, dtype = torch.long)\n", 591 | " mask = data['mask'].to(device, dtype = torch.long)\n", 592 | " targets = data['targets'].to(device, dtype = torch.long)\n", 593 | "\n", 594 | " outputs = model(ids, mask)\n", 595 | " loss = loss_function(outputs, targets)\n", 596 | " tr_loss += loss.item()\n", 597 | " big_val, big_idx = torch.max(outputs.data, dim=1)\n", 598 | " n_correct += calcuate_accu(big_idx, targets)\n", 599 | "\n", 600 | " nb_tr_steps += 1\n", 601 | " nb_tr_examples+=targets.size(0)\n", 602 | "\n", 603 | " if _%5000==0:\n", 604 | " loss_step = tr_loss/nb_tr_steps\n", 605 | " accu_step = (n_correct*100)/nb_tr_examples\n", 606 | " print(f\"Training Loss per 5000 steps: {loss_step}\")\n", 607 | " print(f\"Training Accuracy per 5000 steps: {accu_step}\")\n", 608 | "\n", 609 | " optimizer.zero_grad()\n", 610 | " loss.backward()\n", 611 | " # # When using GPU\n", 612 | " optimizer.step()\n", 613 | "\n", 614 | " print(f'The Total Accuracy for Epoch {epoch}: {(n_correct*100)/nb_tr_examples}')\n", 615 | " epoch_loss = tr_loss/nb_tr_steps\n", 616 | " 
epoch_accu = (n_correct*100)/nb_tr_examples\n", 617 | " print(f\"Training Loss Epoch: {epoch_loss}\")\n", 618 | " print(f\"Training Accuracy Epoch: {epoch_accu}\")\n", 619 | "\n", 620 | " return" 621 | ] 622 | }, 623 | { 624 | "cell_type": "code", 625 | "execution_count": null, 626 | "metadata": { 627 | "id": "lyGRtQeUD4Zl", 628 | "outputId": "f0dabe2a-2fbd-4089-8fb7-6ebd817170f3" 629 | }, 630 | "outputs": [ 631 | { 632 | "name": "stdout", 633 | "output_type": "stream", 634 | "text": [ 635 | "Epoch: 0, Loss: 6.332988739013672\n", 636 | "Epoch: 0, Loss: 0.0013066530227661133\n", 637 | "Epoch: 0, Loss: 0.0029534101486206055\n", 638 | "Epoch: 0, Loss: 0.005258679389953613\n", 639 | "Epoch: 0, Loss: 0.0020235776901245117\n", 640 | "Epoch: 0, Loss: 0.0023298263549804688\n", 641 | "Epoch: 0, Loss: 0.0034378767013549805\n", 642 | "Epoch: 0, Loss: 0.004993081092834473\n", 643 | "Epoch: 0, Loss: 0.008559942245483398\n", 644 | "Epoch: 0, Loss: 0.0014510154724121094\n", 645 | "Epoch: 0, Loss: 0.0028634071350097656\n", 646 | "Epoch: 0, Loss: 0.0006411075592041016\n", 647 | "Epoch: 0, Loss: 0.0012137889862060547\n", 648 | "Epoch: 0, Loss: 0.002307891845703125\n", 649 | "Epoch: 0, Loss: 0.00028586387634277344\n", 650 | "Epoch: 0, Loss: 0.0029143095016479492\n", 651 | "Epoch: 0, Loss: 0.0002485513687133789\n" 652 | ] 653 | } 654 | ], 655 | "source": [ 656 | "for epoch in range(EPOCHS):\n", 657 | " train(epoch)" 658 | ] 659 | }, 660 | { 661 | "cell_type": "markdown", 662 | "metadata": { 663 | "id": "8kpH_KTgD4Zl" 664 | }, 665 | "source": [ 666 | "\n", 667 | "### Validating the Model\n", 668 | "\n", 669 | "During the validation stage we pass the unseen data (Testing Dataset) to the model. This step determines how well the model performs on the unseen data.\n", 670 | "\n", 671 | "This unseen data is the 20% of `newsCorpora.csv` which was separated during the Dataset creation stage.\n", 672 | "During the validation stage the weights of the model are not updated. 
Only the final output is compared to the actual value. This comparison is then used to calculate the accuracy of the model.\n", 673 | "\n", 674 | "As you can see the model is predicting the correct category of a given headline with 99.9% accuracy." 675 | ] 676 | }, 677 | { 678 | "cell_type": "code", 679 | "execution_count": null, 680 | "metadata": { 681 | "id": "X_d98YhqD4Zl" 682 | }, 683 | "outputs": [], 684 | "source": [ 685 | "def valid(model, testing_loader):\n", 686 | " model.eval()\n", 687 | " tr_loss = 0; n_correct = 0; nb_tr_steps = 0; nb_tr_examples = 0\n", 688 | " with torch.no_grad():\n", 689 | " for _, data in enumerate(testing_loader, 0):\n", 690 | " ids = data['ids'].to(device, dtype = torch.long)\n", 691 | " mask = data['mask'].to(device, dtype = torch.long)\n", 692 | " targets = data['targets'].to(device, dtype = torch.long)\n", 693 | " outputs = model(ids, mask)\n", 694 | " loss = loss_function(outputs, targets)\n", 695 | " tr_loss += loss.item()\n", 696 | " big_val, big_idx = torch.max(outputs.data, dim=1)\n", 697 | " n_correct += calcuate_accu(big_idx, targets)\n", 698 | "\n", 699 | " nb_tr_steps += 1\n", 700 | " nb_tr_examples+=targets.size(0)\n", 701 | "\n", 702 | " if _%5000==0:\n", 703 | " loss_step = tr_loss/nb_tr_steps\n", 704 | " accu_step = (n_correct*100)/nb_tr_examples\n", 705 | " print(f\"Validation Loss per 5000 steps: {loss_step}\")\n", 706 | " print(f\"Validation Accuracy per 5000 steps: {accu_step}\")\n", 707 | " epoch_loss = tr_loss/nb_tr_steps\n", 708 | " epoch_accu = (n_correct*100)/nb_tr_examples\n", 709 | " print(f\"Validation Loss Epoch: {epoch_loss}\")\n", 710 | " print(f\"Validation Accuracy Epoch: {epoch_accu}\")\n", 711 | "\n", 712 | " return epoch_accu\n" 713 | ] 714 | }, 715 | { 716 | "cell_type": "code", 717 | "execution_count": null, 718 | "metadata": { 719 | "id": "FoBIToCvD4Zl", 720 | "outputId": "5fa23d2c-8b60-4567-b142-d12b71c991d1" 721 | }, 722 | "outputs": [ 723 | { 724 | "name": "stdout", 725 | "output_type": "stream", 726 | 
"text": [ 727 | "This is the validation section to print the accuracy and see how it performs\n", 728 | "Here we are leveraging on the dataloader crearted for the validation dataset, the approcah is using more of pytorch\n", 729 | "Accuracy on test data = 99.99%\n" 730 | ] 731 | } 732 | ], 733 | "source": [ 734 | "print('This is the validation section to print the accuracy and see how it performs')\n", 735 | "print('Here we are leveraging on the dataloader crearted for the validation dataset, the approcah is using more of pytorch')\n", 736 | "\n", 737 | "acc = valid(model, testing_loader)\n", 738 | "print(\"Accuracy on test data = %0.2f%%\" % acc)" 739 | ] 740 | }, 741 | { 742 | "cell_type": "markdown", 743 | "metadata": { 744 | "id": "ERqVMUsiD4Zl" 745 | }, 746 | "source": [ 747 | "\n", 748 | "### Saving the Trained Model Artifacts for inference\n", 749 | "\n", 750 | "This is the final step in the process of fine tuning the model.\n", 751 | "\n", 752 | "The model and its vocabulary are saved locally. These files are then used in the future to make inference on new inputs of news headlines.\n", 753 | "\n", 754 | "Please remember that a trained neural network is only useful when used in actual inference after its training.\n", 755 | "\n", 756 | "In the lifecycle of an ML projects this is only half the job done. We will leave the inference of these models for some other day." 
757 | ] 758 | }, 759 | { 760 | "cell_type": "code", 761 | "execution_count": null, 762 | "metadata": { 763 | "id": "D7JGtlxjD4Zl", 764 | "outputId": "4eea8e62-afbf-4a03-d37d-15775e1576a7" 765 | }, 766 | "outputs": [ 767 | { 768 | "name": "stdout", 769 | "output_type": "stream", 770 | "text": [ 771 | "All files saved\n", 772 | "This tutorial is completed\n" 773 | ] 774 | } 775 | ], 776 | "source": [ 777 | "# Saving the files for re-use\n", 778 | "\n", 779 | "output_model_file = './models/pytorch_distilbert_news.bin'\n", 780 | "output_vocab_file = './models/vocab_distilbert_news.bin'\n", 781 | "\n", 782 | "model_to_save = model\n", 783 | "torch.save(model_to_save, output_model_file)\n", 784 | "tokenizer.save_vocabulary(output_vocab_file)\n", 785 | "\n", 786 | "print('All files saved')\n", 787 | "print('This tutorial is completed')" 788 | ] 789 | } 790 | ], 791 | "metadata": { 792 | "colab": { 793 | "name": "01_transformers_multiclass_classification.ipynb", 794 | "provenance": [], 795 | "gpuType": "T4" 796 | }, 797 | "kernelspec": { 798 | "display_name": "Python 3", 799 | "name": "python3" 800 | }, 801 | "varInspector": { 802 | "cols": { 803 | "lenName": 16, 804 | "lenType": 16, 805 | "lenVar": 40 806 | }, 807 | "kernels_config": { 808 | "python": { 809 | "delete_cmd_postfix": "", 810 | "delete_cmd_prefix": "del ", 811 | "library": "var_list.py", 812 | "varRefreshCmd": "print(var_dic_list())" 813 | }, 814 | "r": { 815 | "delete_cmd_postfix": ") ", 816 | "delete_cmd_prefix": "rm(", 817 | "library": "var_list.r", 818 | "varRefreshCmd": "cat(var_dic_list()) " 819 | } 820 | }, 821 | "types_to_exclude": [ 822 | "module", 823 | "function", 824 | "builtin_function_or_method", 825 | "instance", 826 | "_Feature" 827 | ], 828 | "window_display": false 829 | }, 830 | "accelerator": "GPU" 831 | }, 832 | "nbformat": 4, 833 | "nbformat_minor": 0 834 | } -------------------------------------------------------------------------------- /notebooks/transformers_ner.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "rFvcIe4qz_2t" 7 | }, 8 | "source": [ 9 | "# Fine Tuning Transformer for Named Entity Recognition" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": { 15 | "id": "Zt90X_Dw0B_T" 16 | }, 17 | "source": [ 18 | "### Introduction\n", 19 | "\n", 20 | "In this tutorial we will be fine tuning a transformer model for the **Named Entity Recognition** problem.\n", 21 | "This is one of the most common business problems: in a given piece of text/sentence/document, different entities need to be identified, such as Name, Location, Number, Entity etc.\n", 22 | "\n", 23 | "#### Flow of the notebook\n", 24 | "\n", 25 | "The notebook is divided into separate sections to provide an organized walkthrough of the process used. This process can be modified for individual use cases. The sections are:\n", 26 | "\n", 27 | "1. [Installing packages for preparing the system](#section00)\n", 28 | "2. [Importing Python Libraries and preparing the environment](#section01)\n", 29 | "3. [Importing and Pre-Processing the domain data](#section02)\n", 30 | "4. [Preparing the Dataset and Dataloader](#section03)\n", 31 | "5. [Creating the Neural Network for Fine Tuning](#section04)\n", 32 | "6. [Fine Tuning the Model](#section05)\n", 33 | "7. [Validating the Model Performance](#section06)\n", 34 | "\n", 35 | "#### Technical Details\n", 36 | "\n", 37 | "This script leverages multiple tools designed by other teams. Details of the tools used are given below. 
Please ensure that these elements are present in your setup to successfully run this script.\n", 38 | "\n", 39 | " - Data:\n", 40 | "\t- We are working from a dataset available on [Kaggle](https://www.kaggle.com/)\n", 41 | " - This NER annotated dataset is available at the following [link](https://www.kaggle.com/abhinavwalia95/entity-annotated-corpus)\n", 42 | " - We will be working with the file `ner.csv` from the dataset.\n", 43 | " - In the given file we will be looking at the following columns for the purpose of this fine tuning:\n", 44 | " - `sentence_idx` : Identifier of the sentence that the word in the row belongs to\n", 45 | " - `word` : Word in the sentence\n", 46 | " - `tag` : Label that identifies the entity type of the word.\n", 47 | " - The various entities tagged in this dataset are listed below:\n", 48 | " - geo = Geographical Entity\n", 49 | " - org = Organization\n", 50 | " - per = Person\n", 51 | " - gpe = Geopolitical Entity\n", 52 | " - tim = Time indicator\n", 53 | " - art = Artifact\n", 54 | " - eve = Event\n", 55 | " - nat = Natural Phenomenon\n", 56 | "\n", 57 | "\n", 58 | " - Language Model Used:\n", 59 | "\t - We are using BERT for this project. The Hugging Face team has created a customized model for token classification, called **BertForTokenClassification**. We will be using it in our custom model class for training.\n", 60 | "\t - [Blog-Post](https://ai.googleblog.com/2018/11/open-sourcing-bert-state-of-art-pre.html)\n", 61 | " - [Documentation for python](https://huggingface.co/transformers/model_doc/bert.html#bertfortokenclassification)\n", 62 | "\n", 63 | "\n", 64 | " - Hardware Requirements:\n", 65 | "\t - Python 3.6 and above\n", 66 | "\t - Pytorch, Transformers and all the stock Python ML libraries\n", 67 | "\t - TPU enabled setup. 
This can also be executed over GPU but the code base will need some changes.\n", 68 | "\n", 69 | "\n", 70 | " - Script Objective:\n", 71 | "\t - The objective of this script is to fine tune **BertForTokenClassification** to be able to identify the entities in the given test dataset. The entities labeled in the dataset are the ones listed under Data above." 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": { 77 | "id": "qFDNndJfD6s1" 78 | }, 79 | "source": [ 80 | "\n", 81 | "### Installing packages for preparing the system\n", 82 | "\n", 83 | "We install two packages: one for TPU execution (`torch_xla`) and one for F1 score calculation (`seqeval`).\n", 84 | "*You can skip this step if you already have these libraries installed in your environment*" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": { 91 | "colab": { 92 | "base_uri": "https://localhost:8080/", 93 | "height": 773 94 | }, 95 | "id": "pWbkd8Ld8MwL", 96 | "outputId": "b44f7ea3-2c0a-4e7c-f7ed-19f43d62de28" 97 | }, 98 | "outputs": [ 99 | { 100 | "name": "stdout", 101 | "output_type": "stream", 102 | "text": [ 103 | " % Total % Received % Xferd Average Speed Time Time Time Current\n", 104 | " Dload Upload Total Spent Left Speed\n", 105 | "\r", 106 | " 0 0 0 0 0 0 0 0 --:--:-- --:--:-- --:--:-- 0\r", 107 | "100 3727 100 3727 0 0 41876 0 --:--:-- --:--:-- --:--:-- 41876\n", 108 | "Updating TPU and VM. This may take around 2 minutes.\n", 109 | "Updating TPU runtime to pytorch-dev20200325 ...\n", 110 | "Done updating TPU runtime: \n", 111 | "Uninstalling torch-1.5.0a0+d6149a7:\n", 112 | " Successfully uninstalled torch-1.5.0a0+d6149a7\n", 113 | "Uninstalling torchvision-0.6.0a0+3c254fb:\n", 114 | " Successfully uninstalled torchvision-0.6.0a0+3c254fb\n", 115 | "Copying gs://tpu-pytorch/wheels/torch-nightly+20200325-cp36-cp36m-linux_x86_64.whl...\n", 116 | "- [1 files][ 83.4 MiB/ 83.4 MiB] \n", 117 | "Operation completed over 1 objects/83.4 MiB. 
\n", 118 | "Copying gs://tpu-pytorch/wheels/torch_xla-nightly+20200325-cp36-cp36m-linux_x86_64.whl...\n", 119 | "\\ [1 files][114.5 MiB/114.5 MiB] \n", 120 | "Operation completed over 1 objects/114.5 MiB. \n", 121 | "Copying gs://tpu-pytorch/wheels/torchvision-nightly+20200325-cp36-cp36m-linux_x86_64.whl...\n", 122 | "/ [1 files][ 2.5 MiB/ 2.5 MiB] \n", 123 | "Operation completed over 1 objects/2.5 MiB. \n", 124 | "Processing ./torch-nightly+20200325-cp36-cp36m-linux_x86_64.whl\n", 125 | "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch==nightly+20200325) (1.18.3)\n", 126 | "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch==nightly+20200325) (0.16.0)\n", 127 | "\u001b[31mERROR: fastai 1.0.61 requires torchvision, which is not installed.\u001b[0m\n", 128 | "Installing collected packages: torch\n", 129 | "Successfully installed torch-1.5.0a0+d6149a7\n", 130 | "Processing ./torch_xla-nightly+20200325-cp36-cp36m-linux_x86_64.whl\n", 131 | "Installing collected packages: torch-xla\n", 132 | " Found existing installation: torch-xla 1.6+e788e5b\n", 133 | " Uninstalling torch-xla-1.6+e788e5b:\n", 134 | " Successfully uninstalled torch-xla-1.6+e788e5b\n", 135 | "Successfully installed torch-xla-1.6+e788e5b\n", 136 | "Processing ./torchvision-nightly+20200325-cp36-cp36m-linux_x86_64.whl\n", 137 | "Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (from torchvision==nightly+20200325) (1.5.0a0+d6149a7)\n", 138 | "Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision==nightly+20200325) (7.0.0)\n", 139 | "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from torchvision==nightly+20200325) (1.12.0)\n", 140 | "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torchvision==nightly+20200325) (1.18.3)\n", 141 | "Requirement already satisfied: future in 
/usr/local/lib/python3.6/dist-packages (from torch->torchvision==nightly+20200325) (0.16.0)\n", 142 | "Installing collected packages: torchvision\n", 143 | "Successfully installed torchvision-0.6.0a0+3c254fb\n", 144 | "Reading package lists... Done\n", 145 | "Building dependency tree \n", 146 | "Reading state information... Done\n", 147 | "libomp5 is already the newest version (5.0.1-1).\n", 148 | "libopenblas-dev is already the newest version (0.2.20+ds-4).\n", 149 | "0 upgraded, 0 newly installed, 0 to remove and 29 not upgraded.\n" 150 | ] 151 | } 152 | ], 153 | "source": [ 154 | "!curl -q https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py\n", 155 | "!python pytorch-xla-env-setup.py --apt-packages libomp5 libopenblas-dev\n", 156 | "!pip -q install seqeval" 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "metadata": { 162 | "id": "6ts5noAi0a6K" 163 | }, 164 | "source": [ 165 | "\n", 166 | "### Importing Python Libraries and preparing the environment\n", 167 | "\n", 168 | "At this step we will be importing the libraries and modules needed to run our script. Libraries are:\n", 169 | "* Pandas\n", 170 | "* Pytorch\n", 171 | "* Pytorch Utils for Dataset and Dataloader\n", 172 | "* Transformers\n", 173 | "* BERT Model and Tokenizer\n", 174 | "\n", 175 | "After that we will prepare the device for TPU execution. This configuration is needed if you want to use the onboard TPU."
176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | "metadata": { 182 | "id": "NYUqKiOZdR1H" 183 | }, 184 | "outputs": [], 185 | "source": [ 186 | "# Importing pytorch and the library for TPU execution\n", 187 | "\n", 188 | "import torch\n", 189 | "import torch_xla\n", 190 | "import torch_xla.core.xla_model as xm" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "metadata": { 197 | "id": "y3jTWir2cBlN" 198 | }, 199 | "outputs": [], 200 | "source": [ 201 | "# Importing stock ml libraries\n", 202 | "\n", 203 | "import numpy as np\n", 204 | "import pandas as pd\n", 205 | "import transformers\n", 206 | "from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler\n", 207 | "from transformers import BertForTokenClassification, BertTokenizer, BertConfig, BertModel\n", 208 | "\n", 209 | "# Preparing for TPU usage\n", 210 | "dev = xm.xla_device()" 211 | ] 212 | }, 213 | { 214 | "cell_type": "markdown", 215 | "metadata": { 216 | "id": "JrBFYkn2D6s6" 217 | }, 218 | "source": [ 219 | "\n", 220 | "### Importing and Pre-Processing the domain data\n", 221 | "\n", 222 | "We will be working with the data and preparing it for fine tuning.\n", 223 | "*Assuming that the `ner.csv` is already downloaded in your `data` folder*\n", 224 | "\n", 225 | "* Import the file into a dataframe and give it the headers as per the documentation.\n", 226 | "* Clean the file to remove the unwanted columns.\n", 227 | "* Create a class `SentenceGetter` that pulls the words from the columns and assembles them into sentences.\n", 228 | "* Create some additional lists and dicts to hold the data that will be used in later processing." 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": null, 234 | "metadata": { 235 | "colab": { 236 | "base_uri": "https://localhost:8080/", 237 | "height": 212 238 | }, 239 | "id": "81kDZbz2cDn7", 240 | "outputId": 
"1312c879-5c24-4964-c6ab-ba1b9b189964" 241 | }, 242 | "outputs": [ 243 | { 244 | "name": "stderr", 245 | "output_type": "stream", 246 | "text": [ 247 | "b'Skipping line 281837: expected 25 fields, saw 34\\n'\n" 248 | ] 249 | }, 250 | { 251 | "data": { 252 | "text/html": [ 253 | "<div>\n", 254 | "<style scoped>\n", 255 | "    .dataframe tbody tr th:only-of-type {\n", 256 | "        vertical-align: middle;\n", 257 | "    }\n", 258 | "\n", 259 | "    .dataframe tbody tr th {\n", 260 | "        vertical-align: top;\n", 261 | "    }\n", 262 | "\n", 263 | "    .dataframe thead th {\n", 264 | "        text-align: right;\n", 265 | "    }\n", 266 | "</style>\n", 267 | "<table border=\"1\" class=\"dataframe\">\n", 268 | "  <thead>\n", 269 | "    <tr style=\"text-align: right;\">\n", 270 | "      <th></th>\n", 271 | "      <th>pos</th>\n", 272 | "      <th>sentence_idx</th>\n", 273 | "      <th>word</th>\n", 274 | "      <th>tag</th>\n", 275 | "    </tr>\n", 276 | "  </thead>\n", 277 | "  <tbody>\n", 278 | "    <tr>\n", 279 | "      <th>0</th>\n", 280 | "      <td>NNS</td>\n", 281 | "      <td>1.0</td>\n", 282 | "      <td>Thousands</td>\n", 283 | "      <td>O</td>\n", 284 | "    </tr>\n", 285 | "    <tr>\n", 286 | "      <th>1</th>\n", 287 | "      <td>IN</td>\n", 288 | "      <td>1.0</td>\n", 289 | "      <td>of</td>\n", 290 | "      <td>O</td>\n", 291 | "    </tr>\n", 292 | "    <tr>\n", 293 | "      <th>2</th>\n", 294 | "      <td>NNS</td>\n", 295 | "      <td>1.0</td>\n", 296 | "      <td>demonstrators</td>\n", 297 | "      <td>O</td>\n", 298 | "    </tr>\n", 299 | "    <tr>\n", 300 | "      <th>3</th>\n", 301 | "      <td>VBP</td>\n", 302 | "      <td>1.0</td>\n", 303 | "      <td>have</td>\n", 304 | "      <td>O</td>\n", 305 | "    </tr>\n", 306 | "    <tr>\n", 307 | "      <th>4</th>\n", 308 | "      <td>VBN</td>\n", 309 | "      <td>1.0</td>\n", 310 | "      <td>marched</td>\n", 311 | "      <td>O</td>\n", 312 | "    </tr>\n", 313 | "  </tbody>\n", 314 | "</table>\n", 315 | "</div>" 316 | ], 317 | "text/plain": [ 318 | " pos sentence_idx word tag\n", 319 | "0 NNS 1.0 Thousands O\n", 320 | "1 IN 1.0 of O\n", 321 | "2 NNS 1.0 demonstrators O\n", 322 | "3 VBP 1.0 have O\n", 323 | "4 VBN 1.0 marched O" 324 | ] 325 | }, 326 | "execution_count": 5, 327 | "metadata": { 328 | "tags": [] 329 | }, 330 | "output_type": "execute_result" 331 | } 332 | ], 333 | "source": [ 334 | "df = pd.read_csv(\"./data/ner.csv\", encoding = \"ISO-8859-1\", error_bad_lines=False)  # error_bad_lines is deprecated since pandas 1.3; newer versions use on_bad_lines='skip'\n", 335 | "dataset=df.drop(['Unnamed: 0', 'lemma', 'next-lemma', 'next-next-lemma', 'next-next-pos',\n", 336 | "       'next-next-shape', 'next-next-word', 'next-pos', 'next-shape',\n", 337 | "       'next-word', 'prev-iob', 'prev-lemma', 'prev-pos',\n", 338 | "       'prev-prev-iob', 'prev-prev-lemma', 'prev-prev-pos', 'prev-prev-shape',\n", 339 | "       'prev-prev-word', 'prev-shape', 'prev-word','shape'],axis=1)\n", 340 | "dataset.head()" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": null, 346 | "metadata": { 347 | "id": "sdqhHeAqcLnO" 348 | }, 349 | "outputs": [], 350 | "source": [ 351 | "# Creating a class to pull the words from the columns and assemble them into sentences\n", 352 | "\n", 353 | "class SentenceGetter(object):\n", 354 | "\n", 355 | "    def __init__(self, dataset):\n", 356 | "        self.n_sent = 1\n", 357 | "        self.dataset = dataset\n", 358 | "        self.empty = False\n", 359 | "        agg_func = lambda s: [(w,p, t) for w,p, t in zip(s[\"word\"].values.tolist(),\n", 360 | "                                                       s['pos'].values.tolist(),\n", 361 | "                                                       s[\"tag\"].values.tolist())]\n", 362 | "        self.grouped = self.dataset.groupby(\"sentence_idx\").apply(agg_func)\n", 363 | "        self.sentences = [s for s in self.grouped]\n", 364 | "\n", 365 | "    def get_next(self):\n", 366 | "        try:\n", 367 | "            s = self.grouped.loc[self.n_sent]  # groups are keyed by the numeric sentence_idx\n", 368 | "            self.n_sent += 1\n", 369 | "            return s\n", 370 | "        except KeyError:\n", 371 | "            return None\n", 372 | "\n", 373 | "getter = SentenceGetter(dataset)" 374 | ] 375 | }, 376 | { 377 | 
"cell_type": "code", 378 | "execution_count": null, 379 | "metadata": { 380 | "id": "Cqln3QJecNsJ" 381 | }, 382 | "outputs": [], 383 | "source": [ 384 | "# Creating new lists and dicts that will be used at a later stage for reference and processing\n", 385 | "\n", 386 | "tags_vals = list(set(dataset[\"tag\"].values))\n", 387 | "tag2idx = {t: i for i, t in enumerate(tags_vals)}\n", 388 | "sentences = [' '.join([s[0] for s in sent]) for sent in getter.sentences]\n", 389 | "labels = [[s[2] for s in sent] for sent in getter.sentences]\n", 390 | "labels = [[tag2idx.get(l) for l in lab] for lab in labels]" 391 | ] 392 | }, 393 | { 394 | "cell_type": "markdown", 395 | "metadata": { 396 | "id": "iA1zZ7zND6s8" 397 | }, 398 | "source": [ 399 | "\n", 400 | "### Preparing the Dataset and Dataloader\n", 401 | "\n", 402 | "We will start with defining few key variables that will be used later during the training/fine tuning stage.\n", 403 | "Followed by creation of Dataset class - This defines how the text is pre-processed before sending it to the neural network. We will also define the Dataloader that will feed the data in batches to the neural network for suitable training and processing.\n", 404 | "Dataset and Dataloader are constructs of the PyTorch library for defining and controlling the data pre-processing and its passage to neural network. 
For further reading on Dataset and Dataloader, see the [docs at PyTorch](https://pytorch.org/docs/stable/data.html)\n", 405 | "\n", 406 | "#### *CustomDataset* Dataset Class\n", 407 | "- This class is defined to accept the `tokenizer`, `sentences` and `labels` as input and generate the tokenized output and tags that are used by the BERT model for training.\n", 408 | "- We are using the BERT tokenizer to tokenize the data in the `sentences` list for encoding.\n", 409 | "- The tokenizer uses the `encode_plus` method to perform tokenization and generate the necessary outputs, namely: `ids`, `attention_mask`\n", 410 | "- To read further into the tokenizer, [refer to this document](https://huggingface.co/transformers/model_doc/bert.html#berttokenizer)\n", 411 | "- `tags` is the encoded entity from the annotated dataset.\n", 412 | "- The *CustomDataset* class is used to create 2 datasets, for training and for validation.\n", 413 | "- *Training Dataset* is used to fine tune the model: **80% of the original data**\n", 414 | "- *Validation Dataset* is used to evaluate the performance of the model. The model has not seen this data during training.\n", 415 | "\n", 416 | "#### Dataloader\n", 417 | "- Dataloader is used to create the training and validation dataloaders that load data into the neural network in a defined manner. 
This is needed because all the data from the dataset cannot be loaded into memory at once, hence the amount of data loaded into memory and then passed to the neural network needs to be controlled.\n", 418 | "- This control is achieved using parameters such as `batch_size` and `max_len`.\n", 419 | "- Training and Validation dataloaders are used in the training and validation part of the flow respectively" 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": null, 425 | "metadata": { 426 | "id": "kL0b1VIQcTAC" 427 | }, 428 | "outputs": [], 429 | "source": [ 430 | "# Defining some key variables that will be used later on in the training\n", 431 | "\n", 432 | "MAX_LEN = 200\n", 433 | "TRAIN_BATCH_SIZE = 32\n", 434 | "VALID_BATCH_SIZE = 16\n", 435 | "EPOCHS = 5\n", 436 | "LEARNING_RATE = 2e-05\n", 437 | "tokenizer = BertTokenizer.from_pretrained('bert-base-cased')" 438 | ] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": null, 443 | "metadata": { 444 | "id": "IV72GFgq_ZYb" 445 | }, 446 | "outputs": [], 447 | "source": [ 448 | "class CustomDataset(Dataset):\n", 449 | "    def __init__(self, tokenizer, sentences, labels, max_len):\n", 450 | "        self.len = len(sentences)\n", 451 | "        self.sentences = sentences\n", 452 | "        self.labels = labels\n", 453 | "        self.tokenizer = tokenizer\n", 454 | "        self.max_len = max_len\n", 455 | "\n", 456 | "    def __getitem__(self, index):\n", 457 | "        sentence = str(self.sentences[index])\n", 458 | "        inputs = self.tokenizer.encode_plus(\n", 459 | "            sentence,\n", 460 | "            None,\n", 461 | "            add_special_tokens=True,\n", 462 | "            max_length=self.max_len,\n", 463 | "            pad_to_max_length=True,\n", 464 | "            return_token_type_ids=True\n", 465 | "        )\n", 466 | "        ids = inputs['input_ids']\n", 467 | "        mask = inputs['attention_mask']\n", 468 | "        label = list(self.labels[index])  # copy so the stored labels are not mutated on every access\n", 469 | "        label.extend([4]*self.max_len)  # pad with label index 4\n", 470 | "        label = label[:self.max_len]\n", 471 | "\n", 472 | "        return {\n", 473 | "            'ids': torch.tensor(ids, dtype=torch.long),\n", 474 | "            'mask': torch.tensor(mask, dtype=torch.long),\n", 475 | "            'tags': torch.tensor(label, dtype=torch.long)\n", 476 | "        }\n", 477 | "\n", 478 | "    def __len__(self):\n", 479 | "        return self.len" 480 | ] 481 | }, 482 | { 483 | "cell_type": "code", 484 | "execution_count": null, 485 | "metadata": { 486 | "colab": { 487 | "base_uri": "https://localhost:8080/", 488 | "height": 67 489 | }, 490 | "id": "7BvniiyvX-rB", 491 | "outputId": "61fdc431-4544-4a19-a176-46f19ca58f0f" 492 | }, 493 | "outputs": [ 494 | { 495 | "name": "stdout", 496 | "output_type": "stream", 497 | "text": [ 498 | "FULL Dataset: 35177\n", 499 | "TRAIN Dataset: 28141\n", 500 | "TEST Dataset: 7036\n" 501 | ] 502 | } 503 | ], 504 | "source": [ 505 | "# Creating the dataset and dataloader for the neural network\n", 506 | "\n", 507 | "train_percent = 0.8\n", 508 | "train_size = int(train_percent*len(sentences))\n", 509 | "# train_dataset=df.sample(frac=train_size,random_state=200).reset_index(drop=True)\n", 510 | "# test_dataset=df.drop(train_dataset.index).reset_index(drop=True)\n", 511 | "train_sentences = sentences[0:train_size]\n", 512 | "train_labels = labels[0:train_size]\n", 513 | "\n", 514 | "test_sentences = sentences[train_size:]\n", 515 | "test_labels = labels[train_size:]\n", 516 | "\n", 517 | "print(\"FULL Dataset: {}\".format(len(sentences)))\n", 518 | "print(\"TRAIN Dataset: {}\".format(len(train_sentences)))\n", 519 | "print(\"TEST Dataset: {}\".format(len(test_sentences)))\n", 520 | "\n", 521 | "training_set = CustomDataset(tokenizer, train_sentences, train_labels, MAX_LEN)\n", 522 | "testing_set = CustomDataset(tokenizer, test_sentences, test_labels, MAX_LEN)" 523 | ] 524 | }, 525 | { 526 | "cell_type": "code", 527 | "execution_count": null, 528 | "metadata": { 529 | "id": "F05bou5zZYtV" 530 | }, 531 | "outputs": [], 532 | "source": [ 533 | "train_params = {'batch_size': TRAIN_BATCH_SIZE,\n", 534 | "                'shuffle': True,\n", 535 | "                'num_workers': 0\n", 536 | "                
}\n", 537 | "\n", 538 | "test_params = {'batch_size': VALID_BATCH_SIZE,\n", 539 | "               'shuffle': True,\n", 540 | "               'num_workers': 0\n", 541 | "               }\n", 542 | "\n", 543 | "training_loader = DataLoader(training_set, **train_params)\n", 544 | "testing_loader = DataLoader(testing_set, **test_params)" 545 | ] 546 | }, 547 | { 548 | "cell_type": "markdown", 549 | "metadata": { 550 | "id": "X201uu_hD6s9" 551 | }, 552 | "source": [ 553 | "\n", 554 | "### Creating the Neural Network for Fine Tuning\n", 555 | "\n", 556 | "#### Neural Network\n", 557 | " - We will be creating a neural network with the `BERTClass`.\n", 558 | " - This network will have the `BertForTokenClassification` model.\n", 559 | " - The data will be fed to the `BertForTokenClassification` as defined in the dataset.\n", 560 | " - The final layer's outputs are used to calculate the loss and to determine the accuracy of the model's predictions.\n", 561 | " - We will create an instance of the network called `model`. This instance will be used for training and then to save the final trained model for future inference.\n", 562 | "\n", 563 | "#### Loss Function and Optimizer\n", 564 | " - `Optimizer` is defined in the next cell.\n", 565 | " - We do not define any `Loss function` since the specified model already outputs `Loss` for a given input.\n", 566 | " - `Optimizer` is used to update the weights of the neural network to improve its performance.\n", 567 | "\n", 568 | "#### Further Reading\n", 569 | "- You can refer to my [Pytorch Tutorials](https://github.com/abhimishra91/pytorch-tutorials) to get an intuition of Loss Function and Optimizer.\n", 570 | "- [Pytorch Documentation for Loss Function](https://pytorch.org/docs/stable/nn.html#loss-functions)\n", 571 | "- [Pytorch Documentation for Optimizer](https://pytorch.org/docs/stable/optim.html)\n", 572 | "- Refer to the links provided at the top of the notebook to read more about `BertForTokenClassification`."
573 | ] 574 | }, 575 | { 576 | "cell_type": "code", 577 | "execution_count": null, 578 | "metadata": { 579 | "id": "9vuIJrvSZble" 580 | }, 581 | "outputs": [], 582 | "source": [ 583 | "# Creating the customized model: a thin wrapper around BertForTokenClassification that returns the model output (including the loss when labels are passed).\n", 584 | "\n", 585 | "class BERTClass(torch.nn.Module):\n", 586 | "    def __init__(self):\n", 587 | "        super(BERTClass, self).__init__()\n", 588 | "        self.l1 = transformers.BertForTokenClassification.from_pretrained('bert-base-cased', num_labels=18)\n", 589 | "        # self.l2 = torch.nn.Dropout(0.3)\n", 590 | "        # self.l3 = torch.nn.Linear(768, 200)\n", 591 | "\n", 592 | "    def forward(self, ids, mask, labels):\n", 593 | "        output_1= self.l1(ids, mask, labels = labels)\n", 594 | "        # output_2 = self.l2(output_1[0])\n", 595 | "        # output = self.l3(output_2)\n", 596 | "        return output_1" 597 | ] 598 | }, 599 | { 600 | "cell_type": "code", 601 | "execution_count": null, 602 | "metadata": { 603 | "colab": { 604 | "base_uri": "https://localhost:8080/", 605 | "height": 1000 606 | }, 607 | "collapsed": true, 608 | "id": "CflOeT2-ZoV6", 609 | "outputId": "8045068c-9aa0-48b8-cad1-8ac7c3b1e807" 610 | }, 611 | "outputs": [ 612 | { 613 | "data": { 614 | "text/plain": [ 615 | "BertForTokenClassification(\n", 616 | " (bert): BertModel(\n", 617 | " (embeddings): BertEmbeddings(\n", 618 | " (word_embeddings): Embedding(28996, 768, padding_idx=0)\n", 619 | " (position_embeddings): Embedding(512, 768)\n", 620 | " (token_type_embeddings): Embedding(2, 768)\n", 621 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 622 | " (dropout): Dropout(p=0.1, inplace=False)\n", 623 | " )\n", 624 | " (encoder): BertEncoder(\n", 625 | " (layer): ModuleList(\n", 626 | " (0): BertLayer(\n", 627 | " (attention): BertAttention(\n", 628 | " (self): BertSelfAttention(\n", 629 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 630 | " (key): 
Linear(in_features=768, out_features=768, bias=True)\n", 631 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 632 | " (dropout): Dropout(p=0.1, inplace=False)\n", 633 | " )\n", 634 | " (output): BertSelfOutput(\n", 635 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 636 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 637 | " (dropout): Dropout(p=0.1, inplace=False)\n", 638 | " )\n", 639 | " )\n", 640 | " (intermediate): BertIntermediate(\n", 641 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 642 | " )\n", 643 | " (output): BertOutput(\n", 644 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 645 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 646 | " (dropout): Dropout(p=0.1, inplace=False)\n", 647 | " )\n", 648 | " )\n", 649 | " (1): BertLayer(\n", 650 | " (attention): BertAttention(\n", 651 | " (self): BertSelfAttention(\n", 652 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 653 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 654 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 655 | " (dropout): Dropout(p=0.1, inplace=False)\n", 656 | " )\n", 657 | " (output): BertSelfOutput(\n", 658 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 659 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 660 | " (dropout): Dropout(p=0.1, inplace=False)\n", 661 | " )\n", 662 | " )\n", 663 | " (intermediate): BertIntermediate(\n", 664 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 665 | " )\n", 666 | " (output): BertOutput(\n", 667 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 668 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 669 | " (dropout): Dropout(p=0.1, inplace=False)\n", 670 | " )\n", 671 | " )\n", 672 | " (2): BertLayer(\n", 673 | " (attention): 
BertAttention(\n", 674 | " (self): BertSelfAttention(\n", 675 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 676 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 677 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 678 | " (dropout): Dropout(p=0.1, inplace=False)\n", 679 | " )\n", 680 | " (output): BertSelfOutput(\n", 681 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 682 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 683 | " (dropout): Dropout(p=0.1, inplace=False)\n", 684 | " )\n", 685 | " )\n", 686 | " (intermediate): BertIntermediate(\n", 687 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 688 | " )\n", 689 | " (output): BertOutput(\n", 690 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 691 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 692 | " (dropout): Dropout(p=0.1, inplace=False)\n", 693 | " )\n", 694 | " )\n", 695 | " (3): BertLayer(\n", 696 | " (attention): BertAttention(\n", 697 | " (self): BertSelfAttention(\n", 698 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 699 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 700 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 701 | " (dropout): Dropout(p=0.1, inplace=False)\n", 702 | " )\n", 703 | " (output): BertSelfOutput(\n", 704 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 705 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 706 | " (dropout): Dropout(p=0.1, inplace=False)\n", 707 | " )\n", 708 | " )\n", 709 | " (intermediate): BertIntermediate(\n", 710 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 711 | " )\n", 712 | " (output): BertOutput(\n", 713 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 714 | " (LayerNorm): LayerNorm((768,), eps=1e-12, 
elementwise_affine=True)\n", 715 | " (dropout): Dropout(p=0.1, inplace=False)\n", 716 | " )\n", 717 | " )\n", 718 | " (4): BertLayer(\n", 719 | " (attention): BertAttention(\n", 720 | " (self): BertSelfAttention(\n", 721 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 722 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 723 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 724 | " (dropout): Dropout(p=0.1, inplace=False)\n", 725 | " )\n", 726 | " (output): BertSelfOutput(\n", 727 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 728 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 729 | " (dropout): Dropout(p=0.1, inplace=False)\n", 730 | " )\n", 731 | " )\n", 732 | " (intermediate): BertIntermediate(\n", 733 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 734 | " )\n", 735 | " (output): BertOutput(\n", 736 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 737 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 738 | " (dropout): Dropout(p=0.1, inplace=False)\n", 739 | " )\n", 740 | " )\n", 741 | " (5): BertLayer(\n", 742 | " (attention): BertAttention(\n", 743 | " (self): BertSelfAttention(\n", 744 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 745 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 746 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 747 | " (dropout): Dropout(p=0.1, inplace=False)\n", 748 | " )\n", 749 | " (output): BertSelfOutput(\n", 750 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 751 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 752 | " (dropout): Dropout(p=0.1, inplace=False)\n", 753 | " )\n", 754 | " )\n", 755 | " (intermediate): BertIntermediate(\n", 756 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 757 | " )\n", 758 | " (output): 
BertOutput(\n", 759 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 760 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 761 | " (dropout): Dropout(p=0.1, inplace=False)\n", 762 | " )\n", 763 | " )\n", 764 | " (6): BertLayer(\n", 765 | " (attention): BertAttention(\n", 766 | " (self): BertSelfAttention(\n", 767 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 768 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 769 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 770 | " (dropout): Dropout(p=0.1, inplace=False)\n", 771 | " )\n", 772 | " (output): BertSelfOutput(\n", 773 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 774 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 775 | " (dropout): Dropout(p=0.1, inplace=False)\n", 776 | " )\n", 777 | " )\n", 778 | " (intermediate): BertIntermediate(\n", 779 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 780 | " )\n", 781 | " (output): BertOutput(\n", 782 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 783 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 784 | " (dropout): Dropout(p=0.1, inplace=False)\n", 785 | " )\n", 786 | " )\n", 787 | " (7): BertLayer(\n", 788 | " (attention): BertAttention(\n", 789 | " (self): BertSelfAttention(\n", 790 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 791 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 792 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 793 | " (dropout): Dropout(p=0.1, inplace=False)\n", 794 | " )\n", 795 | " (output): BertSelfOutput(\n", 796 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 797 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 798 | " (dropout): Dropout(p=0.1, inplace=False)\n", 799 | " )\n", 800 | " )\n", 801 | " 
(intermediate): BertIntermediate(\n", 802 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 803 | " )\n", 804 | " (output): BertOutput(\n", 805 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 806 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 807 | " (dropout): Dropout(p=0.1, inplace=False)\n", 808 | " )\n", 809 | " )\n", 810 | " (8): BertLayer(\n", 811 | " (attention): BertAttention(\n", 812 | " (self): BertSelfAttention(\n", 813 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 814 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 815 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 816 | " (dropout): Dropout(p=0.1, inplace=False)\n", 817 | " )\n", 818 | " (output): BertSelfOutput(\n", 819 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 820 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 821 | " (dropout): Dropout(p=0.1, inplace=False)\n", 822 | " )\n", 823 | " )\n", 824 | " (intermediate): BertIntermediate(\n", 825 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 826 | " )\n", 827 | " (output): BertOutput(\n", 828 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 829 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 830 | " (dropout): Dropout(p=0.1, inplace=False)\n", 831 | " )\n", 832 | " )\n", 833 | " (9): BertLayer(\n", 834 | " (attention): BertAttention(\n", 835 | " (self): BertSelfAttention(\n", 836 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 837 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 838 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 839 | " (dropout): Dropout(p=0.1, inplace=False)\n", 840 | " )\n", 841 | " (output): BertSelfOutput(\n", 842 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 843 | " (LayerNorm): 
LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 844 | " (dropout): Dropout(p=0.1, inplace=False)\n", 845 | " )\n", 846 | " )\n", 847 | " (intermediate): BertIntermediate(\n", 848 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 849 | " )\n", 850 | " (output): BertOutput(\n", 851 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 852 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 853 | " (dropout): Dropout(p=0.1, inplace=False)\n", 854 | " )\n", 855 | " )\n", 856 | " (10): BertLayer(\n", 857 | " (attention): BertAttention(\n", 858 | " (self): BertSelfAttention(\n", 859 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 860 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 861 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 862 | " (dropout): Dropout(p=0.1, inplace=False)\n", 863 | " )\n", 864 | " (output): BertSelfOutput(\n", 865 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 866 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 867 | " (dropout): Dropout(p=0.1, inplace=False)\n", 868 | " )\n", 869 | " )\n", 870 | " (intermediate): BertIntermediate(\n", 871 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 872 | " )\n", 873 | " (output): BertOutput(\n", 874 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 875 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 876 | " (dropout): Dropout(p=0.1, inplace=False)\n", 877 | " )\n", 878 | " )\n", 879 | " (11): BertLayer(\n", 880 | " (attention): BertAttention(\n", 881 | " (self): BertSelfAttention(\n", 882 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 883 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 884 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 885 | " (dropout): Dropout(p=0.1, inplace=False)\n", 886 
| " )\n", 887 | " (output): BertSelfOutput(\n", 888 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 889 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 890 | " (dropout): Dropout(p=0.1, inplace=False)\n", 891 | " )\n", 892 | " )\n", 893 | " (intermediate): BertIntermediate(\n", 894 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 895 | " )\n", 896 | " (output): BertOutput(\n", 897 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 898 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 899 | " (dropout): Dropout(p=0.1, inplace=False)\n", 900 | " )\n", 901 | " )\n", 902 | " )\n", 903 | " )\n", 904 | " (pooler): BertPooler(\n", 905 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 906 | " (activation): Tanh()\n", 907 | " )\n", 908 | " )\n", 909 | " (dropout): Dropout(p=0.1, inplace=False)\n", 910 | " (classifier): Linear(in_features=768, out_features=18, bias=True)\n", 911 | ")" 912 | ] 913 | }, 914 | "execution_count": 16, 915 | "metadata": { 916 | "tags": [] 917 | }, 918 | "output_type": "execute_result" 919 | } 920 | ], 921 | "source": [ 922 | "model = BERTClass()\n", 923 | "model.to(dev)" 924 | ] 925 | }, 926 | { 927 | "cell_type": "code", 928 | "execution_count": null, 929 | "metadata": { 930 | "id": "DN_u9NC5aaa_" 931 | }, 932 | "outputs": [], 933 | "source": [ 934 | "optimizer = torch.optim.Adam(params = model.parameters(), lr=LEARNING_RATE)" 935 | ] 936 | }, 937 | { 938 | "cell_type": "markdown", 939 | "metadata": { 940 | "id": "CBg-f8rYD6s-" 941 | }, 942 | "source": [ 943 | "\n", 944 | "### Fine Tuning the Model\n", 945 | "\n", 946 | "After all the effort of loading and preparing the data and datasets, creating the model and defining its loss and optimizer. 
This is probably the easiest step in the process.\n", 947 | "\n", 948 | "Here we define a training function that trains the model on the training dataset created above for a specified number of epochs. One epoch is one complete pass of the training data through the network.\n", 949 | "\n", 950 | "The following events happen in this function to fine tune the neural network:\n", 951 | "- The dataloader passes data to the model based on the batch size.\n", 952 | "- The output from the model and the actual tags are compared to calculate the loss.\n", 953 | "- The loss value is used to optimize the weights of the neurons in the network.\n", 954 | "- After every 500 steps the loss value is printed in the console.\n", 955 | "\n", 956 | "In the run recorded below, the last logged step of the final epoch (epoch 4 of 5) shows a loss of 0.08503091335296631, i.e. the predictions for that batch are close to the actual tags. The logged loss fluctuates from batch to batch, so a single value should be read as a rough indicator only." 957 | ] 958 | }, 959 | { 960 | "cell_type": "code", 961 | "execution_count": null, 962 | "metadata": { 963 | "id": "8aQ6WCk2a-Vd" 964 | }, 965 | "outputs": [], 966 | "source": [ 967 | "def train(epoch):\n", 968 | " model.train()\n", 969 | " for _,data in enumerate(training_loader, 0):\n", 970 | " ids = data['ids'].to(dev, dtype = torch.long)\n", 971 | " mask = data['mask'].to(dev, dtype = torch.long)\n", 972 | " targets = data['tags'].to(dev, dtype = torch.long)\n", 973 | "\n", 974 | " loss = model(ids, mask, labels = targets)[0]\n", 975 | "\n", 976 | " # optimizer.zero_grad()\n", 977 | " if _%500==0:\n", 978 | " print(f'Epoch: {epoch}, Loss: {loss.item()}')\n", 979 | "\n", 980 | " optimizer.zero_grad()\n", 981 | " loss.backward()\n", 982 | " xm.optimizer_step(optimizer)\n", 983 | " xm.mark_step()" 984 | ] 985 | }, 986 | { 987 | "cell_type": "code", 988 | "execution_count": null, 989 | "metadata": { 990 | "colab": { 991 | "base_uri": "https://localhost:8080/", 992 | "height": 689 993 | }, 994 | "collapsed": true, 995 | "id": "50oMGTe0bvl0", 996 | 
"outputId": "292ede16-6da3-460f-d174-f7ee10c729e0" 997 | }, 998 | "outputs": [ 999 | { 1000 | "name": "stdout", 1001 | "output_type": "stream", 1002 | "text": [ 1003 | "Epoch: 0, Loss: 0.21416641771793365\n", 1004 | "Epoch: 0, Loss: 0.08791390806436539\n", 1005 | "Epoch: 0, Loss: 0.1277497559785843\n", 1006 | "Epoch: 0, Loss: 0.25511449575424194\n", 1007 | "Epoch: 0, Loss: 0.11072967946529388\n", 1008 | "Epoch: 0, Loss: 0.1202322468161583\n", 1009 | "Epoch: 0, Loss: 0.16198261082172394\n", 1010 | "Epoch: 0, Loss: 0.31682807207107544\n", 1011 | "Epoch: 1, Loss: 0.09211093187332153\n", 1012 | "Epoch: 1, Loss: 0.15079179406166077\n", 1013 | "Epoch: 1, Loss: 0.1959223747253418\n", 1014 | "Epoch: 1, Loss: 0.09143798053264618\n", 1015 | "Epoch: 1, Loss: 0.29411888122558594\n", 1016 | "Epoch: 1, Loss: 0.11708520352840424\n", 1017 | "Epoch: 1, Loss: 0.11245028674602509\n", 1018 | "Epoch: 1, Loss: 0.14728033542633057\n", 1019 | "Epoch: 2, Loss: 0.1607980579137802\n", 1020 | "Epoch: 2, Loss: 0.08060580492019653\n", 1021 | "Epoch: 2, Loss: 0.14363577961921692\n", 1022 | "Epoch: 2, Loss: 0.12225533276796341\n", 1023 | "Epoch: 2, Loss: 0.10335233807563782\n", 1024 | "Epoch: 2, Loss: 0.04923604056239128\n", 1025 | "Epoch: 2, Loss: 0.09237729012966156\n", 1026 | "Epoch: 2, Loss: 0.12473192811012268\n", 1027 | "Epoch: 3, Loss: 0.09085617959499359\n", 1028 | "Epoch: 3, Loss: 0.09351193159818649\n", 1029 | "Epoch: 3, Loss: 0.06728512048721313\n", 1030 | "Epoch: 3, Loss: 0.1666068434715271\n", 1031 | "Epoch: 3, Loss: 0.19255675375461578\n", 1032 | "Epoch: 3, Loss: 0.16131675243377686\n", 1033 | "Epoch: 3, Loss: 0.15462705492973328\n", 1034 | "Epoch: 3, Loss: 0.18679684400558472\n", 1035 | "Epoch: 4, Loss: 0.11378277838230133\n", 1036 | "Epoch: 4, Loss: 0.025372153148055077\n", 1037 | "Epoch: 4, Loss: 0.08231651782989502\n", 1038 | "Epoch: 4, Loss: 0.2682102620601654\n", 1039 | "Epoch: 4, Loss: 0.05264609679579735\n", 1040 | "Epoch: 4, Loss: 0.056522976607084274\n", 1041 | "Epoch: 4, 
Loss: 0.15710100531578064\n", 1042 | "Epoch: 4, Loss: 0.08503091335296631\n" 1043 | ] 1044 | } 1045 | ], 1046 | "source": [ 1047 | "for epoch in range(5):\n", 1048 | " train(epoch)" 1049 | ] 1050 | }, 1051 | { 1052 | "cell_type": "markdown", 1053 | "metadata": { 1054 | "id": "GTskyuG1D6tB" 1055 | }, 1056 | "source": [ 1057 | "\n", 1058 | "### Validating the Model\n", 1059 | "\n", 1060 | "During the validation stage we pass the unseen data (Testing Dataset) to the model. This step determines how well the model performs on unseen data.\n", 1061 | "\n", 1062 | "This unseen data is the 30% of `ner.csv` which was separated during the Dataset creation stage.\n", 1063 | "During the validation stage the weights of the model are not updated. Only the final output is compared to the actual value. This comparison is then used to calculate the accuracy of the model.\n", 1064 | "\n", 1065 | "The metric used for measuring model performance on this kind of problem is the F1 score. We import `f1_score` from the `seqeval` library and also define a helper function that computes token-level accuracy." 1066 | ] 1067 | }, 1068 | { 1069 | "cell_type": "code", 1070 | "execution_count": null, 1071 | "metadata": { 1072 | "colab": { 1073 | "base_uri": "https://localhost:8080/", 1074 | "height": 34 1075 | }, 1076 | "id": "6OckC0XNkWWm", 1077 | "outputId": "d27682ee-c34f-4811-d791-0c922afa8b05" 1078 | }, 1079 | "outputs": [ 1080 | { 1081 | "name": "stdout", 1082 | "output_type": "stream", 1083 | "text": [ 1084 | " Building wheel for seqeval (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n" 1085 | ] 1086 | } 1087 | ], 1088 | "source": [ 1089 | "from seqeval.metrics import f1_score\n", 1090 | "\n", 1091 | "def flat_accuracy(preds, labels):\n", 1092 | " flat_preds = np.argmax(preds, axis=2).flatten()\n", 1093 | " flat_labels = labels.flatten()\n", 1094 | " return np.sum(flat_preds == flat_labels)/len(flat_labels)" 1095 | ] 1096 | }, 1097 | { 1098 | "cell_type": "code", 1099 | "execution_count": null, 1100 | "metadata": { 1101 | "id": "9zwFzzBriLMO" 1102 | }, 1103 | "outputs": [], 1104 | "source": [ 1105 | "def valid(model, testing_loader):\n", 1106 | " model.eval()\n", 1107 | " eval_loss = 0; eval_accuracy = 0\n", 1108 | " n_correct = 0; n_wrong = 0; total = 0\n", 1109 | " predictions , true_labels = [], []\n", 1110 | " nb_eval_steps, nb_eval_examples = 0, 0\n", 1111 | " with torch.no_grad():\n", 1112 | " for _, data in enumerate(testing_loader, 0):\n", 1113 | " ids = data['ids'].to(dev, dtype = torch.long)\n", 1114 | " mask = data['mask'].to(dev, dtype = torch.long)\n", 1115 | " targets = data['tags'].to(dev, dtype = torch.long)\n", 1116 | "\n", 1117 | " output = model(ids, mask, labels=targets)\n", 1118 | " loss, logits = output[:2]\n", 1119 | " logits = logits.detach().cpu().numpy()\n", 1120 | " label_ids = targets.to('cpu').numpy()\n", 1121 | " predictions.extend([list(p) for p in np.argmax(logits, axis=2)])\n", 1122 | " true_labels.append(label_ids)\n", 1123 | " accuracy = flat_accuracy(logits, label_ids)\n", 1124 | " eval_loss += loss.mean().item()\n", 1125 | " eval_accuracy += accuracy\n", 1126 | " nb_eval_examples += ids.size(0)\n", 1127 | " nb_eval_steps += 1\n", 1128 | " eval_loss = eval_loss/nb_eval_steps\n", 1129 | " print(\"Validation loss: {}\".format(eval_loss))\n", 1130 | " print(\"Validation Accuracy: {}\".format(eval_accuracy/nb_eval_steps))\n", 1131 | " pred_tags = [tags_vals[p_i] for p in predictions for p_i in p]\n", 1132 | " valid_tags = [tags_vals[l_ii] for l in true_labels for l_i in l for l_ii in 
l_i]\n", 1133 | " print(\"F1-Score: {}\".format(f1_score(pred_tags, valid_tags)))" 1134 | ] 1135 | }, 1136 | { 1137 | "cell_type": "code", 1138 | "execution_count": null, 1139 | "metadata": { 1140 | "id": "A7UAHVEED6tB" 1141 | }, 1142 | "outputs": [], 1143 | "source": [ 1144 | "# To get the results on the validation set. This data is not seen by the model\n", 1145 | "\n", 1146 | "valid(model, testing_loader)" 1147 | ] 1148 | } 1149 | ], 1150 | "metadata": { 1151 | "accelerator": "TPU", 1152 | "colab": { 1153 | "name": "transformers_ner.ipynb", 1154 | "provenance": [] 1155 | }, 1156 | "kernelspec": { 1157 | "display_name": "Python 3", 1158 | "language": "python", 1159 | "name": "python3" 1160 | }, 1161 | "language_info": { 1162 | "codemirror_mode": { 1163 | "name": "ipython", 1164 | "version": 3 1165 | }, 1166 | "file_extension": ".py", 1167 | "mimetype": "text/x-python", 1168 | "name": "python", 1169 | "nbconvert_exporter": "python", 1170 | "pygments_lexer": "ipython3", 1171 | "version": "3.7.6" 1172 | } 1173 | }, 1174 | "nbformat": 4, 1175 | "nbformat_minor": 0 1176 | } -------------------------------------------------------------------------------- /notebooks/transformers_summarization_wandb.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "wJNsDzk6D7X-" 7 | }, 8 | "source": [ 9 | "# Fine Tuning Transformer for Summary Generation" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": { 15 | "id": "7axYY6_iD7YB" 16 | }, 17 | "source": [ 18 | "\n", 19 | "### Introduction\n", 20 | "\n", 21 | "In this tutorial we will be fine tuning a transformer model for **Summarization Task**.\n", 22 | "In this task a summary of a given article/document is generated when passed through a network. There are 2 types of summary generation mechanisms:\n", 23 | "\n", 24 | "1. 
***Extractive Summary:*** The network selects the most important sentences from the article and combines them to convey its most meaningful information.\n", 25 | "2. ***Abstractive Summary***: The network creates new sentences that capture the gist of the article and generates them as output. The sentences in the summary may or may not be contained in the article.\n", 26 | "\n", 27 | "In this tutorial we will be generating an ***Abstractive Summary***.\n", 28 | "\n", 29 | "#### Flow of the notebook\n", 30 | "\n", 31 | "* As with all the previous tutorials, this notebook follows a set of easy steps, making the process of fine tuning and training a Transformers model a straightforward task.\n", 32 | "* However, unlike the other notebooks, in this tutorial most of the sections have been turned into functions, and they are called from `main()` at the end of the notebook.\n", 33 | "* This is done to leverage the [Weights and Biases Service](https://www.wandb.com/), WandB for short.\n", 34 | "* It is an experiment tracking, parameter optimization and artifact management service that can be integrated easily with most deep learning or machine learning frameworks.\n", 35 | "\n", 36 | "The notebook is divided into separate sections to provide an organized walk through of the process used. This process can be modified for individual use cases. The sections are:\n", 37 | "\n", 38 | "1. [Preparing Environment and Importing Libraries](#section01)\n", 39 | "2. [Preparing the Dataset for data processing: Class](#section02)\n", 40 | "3. [Fine Tuning the Model: Function](#section03)\n", 41 | "4. [Validating the Model Performance: Function](#section04)\n", 42 | "5. 
[Main Function](#section05)\n", 43 | " * [Initializing WandB](#section501)\n", 44 | " * [Importing and Pre-Processing the domain data](#section502)\n", 45 | " * [Creation of Dataset and Dataloader](#section503)\n", 46 | " * [Neural Network and Optimizer](#section504)\n", 47 | " * [Training Model and Logging to WandB](#section505)\n", 48 | " * [Validation and generation of Summary](#section506)\n", 49 | "6. [Examples of the Summary Generated from the model](#section06)\n", 50 | "\n", 51 | "\n", 52 | "#### Technical Details\n", 53 | "\n", 54 | "This script leverages multiple tools designed by other teams. Details of the tools used are given below. Please ensure that these elements are present in your setup to successfully implement this script.\n", 55 | "\n", 56 | "- **Data**:\n", 57 | "\t- We are using the News Summary dataset available at [Kaggle](https://www.kaggle.com/sunnysai12345/news-summary)\n", 58 | "\t- This dataset was collected from newspapers published in India, extracting the details listed below. We use only the first CSV file from the data dump: `news_summary.csv`\n", 59 | "\t- There are `4514` rows of data, where each row has the following data points:\n", 60 | "\t\t- **author** : Author of the article\n", 61 | "\t\t- **date** : Date the article was published\n", 62 | "\t\t- **headline**: Headline for the published article\n", 63 | "\t\t- **read_more** : URL for the article to follow online\n", 64 | "\t\t- **text**: This is the summary of the article\n", 65 | "\t\t- **ctext**: This is the complete article\n", 66 | "\n", 67 | "\n", 68 | "- **Language Model Used**:\n", 69 | " - This notebook uses ***T5***, one of the more recent transformer models at the time of writing. 
[Research Paper](https://arxiv.org/abs/1910.10683) \n", 70 | " - ***T5*** is in many ways a one-of-a-kind transformer architecture that not only gives state of the art results in many NLP tasks, but also takes a radically different approach to them.\n", 71 | " - **Text-2-Text** - As the graphic taken from the T5 paper shows, all NLP tasks are converted to a **text-to-text** problem. Tasks such as translation, classification, summarization and question answering are all treated as text-to-text conversion problems, rather than as separate problem statements.\n", 72 | " - **Unified approach for NLP Deep Learning** - Since the task is reflected purely in the text input and output, you can use the same model, objective, training procedure, and decoding process for any task, for example Q&A or summarization.\n", 73 | " - We will follow the conventions of the T5 paper to prepare our dataset prior to fine tuning and training. \n", 74 | " - [Documentation for python](https://huggingface.co/transformers/model_doc/t5.html)\n", 75 | "\n", 76 | "![**Each NLP problem as a “text-to-text” problem** - input: text, output: text](https://miro.medium.com/max/4006/1*D0J1gNQf8vrrUpKeyD8wPA.png)\n", 77 | "\n", 78 | "\n", 79 | "\n", 80 | "- **Hardware and Software Requirements**:\n", 81 | "\t- Python 3.6 and above\n", 82 | "\t- PyTorch and Transformers\n", 83 | "\t- The standard Python ML libraries\n", 84 | "\t- GPU enabled setup\n", 85 | " \n", 86 | "\n", 87 | "- **Script Objective**:\n", 88 | "\t- The objective of this script is to fine tune ***T5*** to generate a summary that is close to or better than the actual summary, while ensuring the important information from the article is not lost.\n", 89 | "\n", 90 | "---\n", 91 | "NOTE:\n", 92 | "We are using the Weights and Biases tool-set in this tutorial. The different components will be explained as we go through the article." 
93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "metadata": { 98 | "id": "KXwvPRsbD7YD" 99 | }, 100 | "source": [ 101 | "\n", 102 | "### Preparing Environment and Importing Libraries\n", 103 | "\n", 104 | "In this step we install the necessary libraries and then import the libraries and modules needed to run our script.\n", 105 | "We will be installing:\n", 106 | "* transformers\n", 107 | "* wandb\n", 108 | "\n", 109 | "Libraries imported are:\n", 110 | "* Pandas\n", 111 | "* Pytorch\n", 112 | "* Pytorch Utils for Dataset and Dataloader\n", 113 | "* Transformers\n", 114 | "* T5 Model and Tokenizer\n", 115 | "* wandb\n", 116 | "\n", 117 | "After that we prepare the device for CUDA execution. This configuration is needed if you want to use an onboard GPU. First we check the GPU available to us using the `nvidia-smi` command, then we define our device.\n", 118 | "\n", 119 | "Finally, we log in to the [wandb](https://www.wandb.com/) service using the `wandb login` command." 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": { 126 | "colab": { 127 | "base_uri": "https://localhost:8080/", 128 | "height": 316 129 | }, 130 | "id": "WD_vnyLXZQzD", 131 | "outputId": "b2ff57b8-a147-4893-80bd-e40d18042f98" 132 | }, 133 | "outputs": [ 134 | { 135 | "name": "stdout", 136 | "output_type": "stream", 137 | "text": [ 138 | "\u001b[K |████████████████████████████████| 665kB 3.5MB/s \n", 139 | "\u001b[K |████████████████████████████████| 890kB 9.6MB/s \n", 140 | "\u001b[K |████████████████████████████████| 1.1MB 22.0MB/s \n", 141 | "\u001b[K |████████████████████████████████| 3.8MB 30.8MB/s \n", 142 | "\u001b[?25h Building wheel for sacremoses (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", 143 | "\u001b[K |████████████████████████████████| 1.4MB 3.4MB/s \n", 144 | "\u001b[K |████████████████████████████████| 102kB 9.8MB/s \n", 145 | "\u001b[K |████████████████████████████████| 460kB 18.8MB/s \n", 146 | "\u001b[K |████████████████████████████████| 102kB 9.9MB/s \n", 147 | "\u001b[K |████████████████████████████████| 112kB 23.9MB/s \n", 148 | "\u001b[K |████████████████████████████████| 71kB 8.8MB/s \n", 149 | "\u001b[K |████████████████████████████████| 71kB 7.5MB/s \n", 150 | "\u001b[?25h Building wheel for subprocess32 (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 151 | " Building wheel for watchdog (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 152 | " Building wheel for gql (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 153 | " Building wheel for pathtools (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 154 | " Building wheel for graphql-core (setup.py) ... \u001b[?25l\u001b[?25hdone\n" 155 | ] 156 | } 157 | ], 158 | "source": [ 159 | "!pip install transformers -q\n", 160 | "!pip install wandb -q\n", 161 | "\n", 162 | "# Code for TPU packages install\n", 163 | "# !curl -q https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py\n", 164 | "# !python pytorch-xla-env-setup.py --apt-packages libomp5 libopenblas-dev" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "metadata": { 171 | "colab": { 172 | "base_uri": "https://localhost:8080/", 173 | "height": 34 174 | }, 175 | "id": "pzM1_ykHaFur", 176 | "outputId": "58fa0ba8-b486-4b26-aaea-c0331b343b70" 177 | }, 178 | "outputs": [ 179 | { 180 | "name": "stderr", 181 | "output_type": "stream", 182 | "text": [ 183 | "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m W&B installed but not logged in. 
Run `wandb login` or set the WANDB_API_KEY env variable.\n" 184 | ] 185 | } 186 | ], 187 | "source": [ 188 | "# Importing stock libraries\n", 189 | "import numpy as np\n", 190 | "import pandas as pd\n", 191 | "import torch\n", 192 | "import torch.nn.functional as F\n", 193 | "from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler\n", 194 | "\n", 195 | "# Importing the T5 modules from huggingface/transformers\n", 196 | "from transformers import T5Tokenizer, T5ForConditionalGeneration\n", 197 | "\n", 198 | "# WandB – Import the wandb library\n", 199 | "import wandb" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "metadata": { 206 | "colab": { 207 | "base_uri": "https://localhost:8080/", 208 | "height": 316 209 | }, 210 | "id": "KvPxXdKJguYB", 211 | "outputId": "6c523635-a25a-429b-cbd8-7b8bf9636972" 212 | }, 213 | "outputs": [ 214 | { 215 | "name": "stdout", 216 | "output_type": "stream", 217 | "text": [ 218 | "Mon Jun 1 03:23:27 2020 \n", 219 | "+-----------------------------------------------------------------------------+\n", 220 | "| NVIDIA-SMI 440.82 Driver Version: 418.67 CUDA Version: 10.1 |\n", 221 | "|-------------------------------+----------------------+----------------------+\n", 222 | "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", 223 | "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", 224 | "|===============================+======================+======================|\n", 225 | "| 0 Tesla P100-PCIE... 
Off | 00000000:00:04.0 Off | 0 |\n", 226 | "| N/A 39C P0 25W / 250W | 0MiB / 16280MiB | 0% Default |\n", 227 | "+-------------------------------+----------------------+----------------------+\n", 228 | " \n", 229 | "+-----------------------------------------------------------------------------+\n", 230 | "| Processes: GPU Memory |\n", 231 | "| GPU PID Type Process name Usage |\n", 232 | "|=============================================================================|\n", 233 | "| No running processes found |\n", 234 | "+-----------------------------------------------------------------------------+\n" 235 | ] 236 | } 237 | ], 238 | "source": [ 239 | "# Checking out the GPU we have access to. This is output is from the google colab version.\n", 240 | "!nvidia-smi" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": null, 246 | "metadata": { 247 | "id": "NLxxwd1scQNv" 248 | }, 249 | "outputs": [], 250 | "source": [ 251 | "# # Setting up the device for GPU usage\n", 252 | "from torch import cuda\n", 253 | "device = 'cuda' if cuda.is_available() else 'cpu'\n", 254 | "\n", 255 | "# Preparing for TPU usage\n", 256 | "# import torch_xla\n", 257 | "# import torch_xla.core.xla_model as xm\n", 258 | "# device = xm.xla_device()" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": null, 264 | "metadata": { 265 | "colab": { 266 | "base_uri": "https://localhost:8080/", 267 | "height": 87 268 | }, 269 | "id": "L-ePh9dEKXMw", 270 | "outputId": "a35fd305-1c09-48ff-978c-fa1d0762c5e2" 271 | }, 272 | "outputs": [ 273 | { 274 | "name": "stdout", 275 | "output_type": "stream", 276 | "text": [ 277 | "\u001b[34m\u001b[1mwandb\u001b[0m: You can find your API key in your browser here: https://app.wandb.ai/authorize\n", 278 | "\u001b[34m\u001b[1mwandb\u001b[0m: Paste an API key from your profile and hit enter: \n", 279 | "\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n", 280 | 
"\u001b[32mSuccessfully logged in to Weights & Biases!\u001b[0m\n" 281 | ] 282 | } 283 | ], 284 | "source": [ 285 | "# Login to wandb to log the model run and all the parameters\n", 286 | "!wandb login" 287 | ] 288 | }, 289 | { 290 | "cell_type": "markdown", 291 | "metadata": { 292 | "id": "tonuYAzGD7YI" 293 | }, 294 | "source": [ 295 | "\n", 296 | "### Preparing the Dataset for data processing: Class\n", 297 | "\n", 298 | "We will start with creation of Dataset class - This defines how the text is pre-processed before sending it to the neural network. This dataset will be used the the Dataloader method that will feed the data in batches to the neural network for suitable training and processing.\n", 299 | "The Dataloader and Dataset will be used inside the `main()`.\n", 300 | "Dataset and Dataloader are constructs of the PyTorch library for defining and controlling the data pre-processing and its passage to neural network. For further reading into Dataset and Dataloader read the [docs at PyTorch](https://pytorch.org/docs/stable/data.html)\n", 301 | "\n", 302 | "#### *CustomDataset* Dataset Class\n", 303 | "- This class is defined to accept the Dataframe as input and generate tokenized output that is used by the **T5** model for training.\n", 304 | "- We are using the **T5** tokenizer to tokenize the data in the `text` and `ctext` column of the dataframe.\n", 305 | "- The tokenizer uses the ` batch_encode_plus` method to perform tokenization and generate the necessary outputs, namely: `source_id`, `source_mask` from the actual text and `target_id` and `target_mask` from the summary text.\n", 306 | "- To read further into the tokenizer, [refer to this document](https://huggingface.co/transformers/model_doc/t5.html#t5tokenizer)\n", 307 | "- The *CustomDataset* class is used to create 2 datasets, for training and for validation.\n", 308 | "- *Training Dataset* is used to fine tune the model: **80% of the original data**\n", 309 | "- *Validation Dataset* is used to 
evaluate the performance of the model. The model has not seen this data during training.\n", 310 | "\n", 311 | "#### Dataloader: Called inside the `main()`\n", 312 | "- Dataloader is used to create the training and validation dataloaders that load data to the neural network in a defined manner. This is needed because all the data from the dataset cannot be loaded into memory at once, hence the amount of data loaded into memory and then passed to the neural network needs to be controlled.\n", 313 | "- This control is achieved using parameters such as `batch_size` and `max_len`.\n", 314 | "- The training and validation dataloaders are used in the training and validation parts of the flow respectively." 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": null, 320 | "metadata": { 321 | "id": "932p8NhxeNw4" 322 | }, 323 | "outputs": [], 324 | "source": [ 325 | "# Creating a custom dataset for reading the dataframe and loading it into the dataloader to pass it to the neural network at a later stage for finetuning the model and to prepare it for predictions\n", 326 | "\n", 327 | "class CustomDataset(Dataset):\n", 328 | "\n", 329 | " def __init__(self, dataframe, tokenizer, source_len, summ_len):\n", 330 | " self.tokenizer = tokenizer\n", 331 | " self.data = dataframe\n", 332 | " self.source_len = source_len\n", 333 | " self.summ_len = summ_len\n", 334 | " self.text = self.data.text\n", 335 | " self.ctext = self.data.ctext\n", 336 | "\n", 337 | " def __len__(self):\n", 338 | " return len(self.text)\n", 339 | "\n", 340 | " def __getitem__(self, index):\n", 341 | " ctext = str(self.ctext[index])\n", 342 | " ctext = ' '.join(ctext.split())\n", 343 | "\n", 344 | " text = str(self.text[index])\n", 345 | " text = ' '.join(text.split())\n", 346 | "\n", 347 | " source = self.tokenizer.batch_encode_plus([ctext], max_length= self.source_len, pad_to_max_length=True,return_tensors='pt')\n", 348 | " target = self.tokenizer.batch_encode_plus([text], 
max_length=self.summ_len, pad_to_max_length=True, return_tensors='pt')\n", 349 | "\n", 350 | "        source_ids = source['input_ids'].squeeze()\n", 351 | "        source_mask = source['attention_mask'].squeeze()\n", 352 | "        target_ids = target['input_ids'].squeeze()\n", 353 | "        target_mask = target['attention_mask'].squeeze()\n", 354 | "\n", 355 | "        return {\n", 356 | "            'source_ids': source_ids.to(dtype=torch.long),\n", 357 | "            'source_mask': source_mask.to(dtype=torch.long),\n", 358 | "            'target_ids': target_ids.to(dtype=torch.long),\n", 359 | "            'target_ids_y': target_ids.to(dtype=torch.long)\n", 360 | "        }" 361 | ] 362 | }, 363 | { 364 | "cell_type": "markdown", 365 | "metadata": { 366 | "id": "_beHF0NDD7YJ" 367 | }, 368 | "source": [ 369 | "\n", 370 | "### Fine Tuning the Model: Function\n", 371 | "\n", 372 | "Here we define a training function that trains the model on the training dataset created above for a specified number of epochs (EPOCH). The number of epochs defines how many times the complete data will be passed through the network.\n", 373 | "\n", 374 | "This function is called in the `main()`\n", 375 | "\n", 376 | "The following events happen in this function to fine tune the neural network:\n", 377 | "- The epoch, tokenizer, model, device details, training dataloader and optimizer are passed to `train()` when it is called from the `main()`\n", 378 | "- The dataloader passes data to the model based on the batch size.\n", 379 | "- The language model labels (`lm_labels`) are calculated from the `target_ids`; the `source_ids` and `source_mask` are also extracted.\n", 380 | "- The first element of the model output gives the loss for the forward pass.\n", 381 | "- The loss value is used to optimize the weights of the neurons in the network.\n", 382 | "- After every 10 steps the loss value is logged in the wandb service. This log is then used to generate graphs for analysis. 
See [these charts](https://app.wandb.ai/abhimishra-91/transformers_tutorials_summarization?workspace=user-abhimishra-91) for an example.\n", 383 | "- After every 500 steps the loss value is printed in the console." 384 | ] 385 | }, 386 | { 387 | "cell_type": "code", 388 | "execution_count": null, 389 | "metadata": { 390 | "id": "SaPAR7TWmxoM" 391 | }, 392 | "outputs": [], 393 | "source": [ 394 | "# Creating the training function. This will be called in the main function. It is run depending on the epoch value.\n", 395 | "# The model is put into train mode, then we enumerate over the training loader and pass each batch to the defined network\n", 396 | "\n", 397 | "def train(epoch, tokenizer, model, device, loader, optimizer):\n", 398 | "    model.train()\n", 399 | "    for _, data in enumerate(loader, 0):\n", 400 | "        y = data['target_ids'].to(device, dtype = torch.long)\n", 401 | "        y_ids = y[:, :-1].contiguous()\n", 402 | "        lm_labels = y[:, 1:].clone().detach()\n", 403 | "        lm_labels[y[:, 1:] == tokenizer.pad_token_id] = -100\n", 404 | "        ids = data['source_ids'].to(device, dtype = torch.long)\n", 405 | "        mask = data['source_mask'].to(device, dtype = torch.long)\n", 406 | "\n", 407 | "        outputs = model(input_ids = ids, attention_mask = mask, decoder_input_ids=y_ids, lm_labels=lm_labels)\n", 408 | "        loss = outputs[0]\n", 409 | "\n", 410 | "        if _%10 == 0:\n", 411 | "            wandb.log({\"Training Loss\": loss.item()})\n", 412 | "\n", 413 | "        if _%500==0:\n", 414 | "            print(f'Epoch: {epoch}, Loss: {loss.item()}')\n", 415 | "\n", 416 | "        optimizer.zero_grad()\n", 417 | "        loss.backward()\n", 418 | "        optimizer.step()\n", 419 | "        # xm.optimizer_step(optimizer)\n", 420 | "        # xm.mark_step()" 421 | ] 422 | }, 423 | { 424 | "cell_type": "markdown", 425 | "metadata": { 426 | "id": "_ObCrYi_D7YK" 427 | }, 428 | "source": [ 429 | "\n", 430 | "### Validating the Model Performance: Function\n", 431 | "\n", 432 | "During the validation stage we pass the unseen data (Testing Dataset), trained model, tokenizer and device details 
to the function to perform the validation run. This step generates new summaries for data that the model has not seen during the training session.\n", 433 | "\n", 434 | "This function is called in the `main()`\n", 435 | "\n", 436 | "This unseen data is the 20% of `news_summary.csv` which was separated during the Dataset creation stage.\n", 437 | "During the validation stage the weights of the model are not updated. We use the `generate` method for generating new text for the summary.\n", 438 | "\n", 439 | "It depends on the beam-search decoding method developed for sequence generation for models with an LM head.\n", 440 | "\n", 441 | "The generated text and original summary are decoded from tokens to text and returned to the `main()`" 442 | ] 443 | }, 444 | { 445 | "cell_type": "code", 446 | "execution_count": null, 447 | "metadata": { 448 | "id": "j9TNdHlQ0CLz" 449 | }, 450 | "outputs": [], 451 | "source": [ 452 | "def validate(epoch, tokenizer, model, device, loader):\n", 453 | "    model.eval()\n", 454 | "    predictions = []\n", 455 | "    actuals = []\n", 456 | "    with torch.no_grad():\n", 457 | "        for _, data in enumerate(loader, 0):\n", 458 | "            y = data['target_ids'].to(device, dtype = torch.long)\n", 459 | "            ids = data['source_ids'].to(device, dtype = torch.long)\n", 460 | "            mask = data['source_mask'].to(device, dtype = torch.long)\n", 461 | "\n", 462 | "            generated_ids = model.generate(\n", 463 | "                input_ids = ids,\n", 464 | "                attention_mask = mask,\n", 465 | "                max_length=150,\n", 466 | "                num_beams=2,\n", 467 | "                repetition_penalty=2.5,\n", 468 | "                length_penalty=1.0,\n", 469 | "                early_stopping=True\n", 470 | "                )\n", 471 | "            preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]\n", 472 | "            target = [tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True) for t in y]\n", 473 | "            if _%100==0:\n", 474 | "                print(f'Completed {_}')\n", 475 | "\n", 476 | "            predictions.extend(preds)\n", 477 | "            
actuals.extend(target)\n", 478 | "    return predictions, actuals" 479 | ] 480 | }, 481 | { 482 | "cell_type": "markdown", 483 | "metadata": { 484 | "id": "FOuUKIBQD7YK" 485 | }, 486 | "source": [ 487 | "\n", 488 | "### Main Function\n", 489 | "\n", 490 | "The `main()`, as the name suggests, is the central location to execute all the functions/flows created above in the notebook. The following steps are executed in the `main()`:\n", 491 | "\n", 492 | "\n", 493 | "\n", 494 | "#### Initializing WandB\n", 495 | "\n", 496 | "* The `main()` begins by initializing a WandB run under a specific project. A new run is created each time this command is executed.\n", 497 | "\n", 498 | "* Before we proceed any further I will give a brief overview of the **[WandB Service](https://www.wandb.com/)**\n", 499 | "\n", 500 | "* This service has been created to track ML experiments, optimize the experiments and save artifacts. It is designed to seamlessly integrate with all the Machine Learning and Deep Learning Frameworks. 
Each script can be organized into a *project* and each execution of the script will be registered as a *run* in the respective project.\n", 501 | "\n", 502 | "* The service can be configured to log several default metrics, such as hardware usage, gradients and weights of the network.\n", 503 | "\n", 504 | "* It can also be used to log user defined metrics, such as the loss in the `train()`.\n", 505 | "\n", 506 | "* This particular tutorial is logged in the project: **[transformers_tutorials_summarization](https://app.wandb.ai/abhimishra-91/transformers_tutorials_summarization?workspace=user-abhimishra-91)**\n", 507 | "\n", 508 | "**One of the charts from the project**\n", 509 | "![](https://github.com/abhimishra91/transformers-tutorials/blob/master/meta/wandb.png?raw=1)\n", 510 | "\n", 511 | "* Visit the project page to see the details of different runs and what information is logged by the service.\n", 512 | "\n", 513 | "* Following the initialization of the WandB service we define configuration parameters that will be used across the tutorial, such as `batch_size`, `epoch`, `learning_rate` etc.\n", 514 | "\n", 515 | "* These parameters are also passed to the WandB config. The config construct with all the parameters can be optimized using the Sweep service from WandB. 
Currently, that is out of scope for this tutorial.\n", 516 | "\n", 517 | "* Next we define seed values so that the experiment and results can be reproduced.\n", 518 | "\n", 519 | "\n", 520 | "\n", 521 | "#### Importing and Pre-Processing the domain data\n", 522 | "\n", 523 | "We will be working with the data and preparing it for fine tuning purposes.\n", 524 | "*Assuming that the `news_summary.csv` is already downloaded in your `data` folder*\n", 525 | "\n", 526 | "* The file is imported as a dataframe with the headers as per the documentation.\n", 527 | "* The file is cleaned to remove the unwanted columns.\n", 528 | "* The string `summarize: ` is added to the main article column in front of the actual article. This is done because **T5** had similar formatting for the summarization dataset.\n", 529 | "* The final Dataframe will be something like this:\n", 530 | "\n", 531 | "|text|ctext|\n", 532 | "|--|--|\n", 533 | "|summary-1|summarize: article 1|\n", 534 | "|summary-2|summarize: article 2|\n", 535 | "|summary-3|summarize: article 3|\n", 536 | "\n", 537 | "* Top 5 rows of the dataframe are printed on the console.\n", 538 | "\n", 539 | "\n", 540 | "#### Creation of Dataset and Dataloader\n", 541 | "\n", 542 | "* The updated dataframe is divided in an 80-20 ratio for training and validation.\n", 543 | "* Both the dataframes are passed to the `CustomDataset` class for tokenization of the news articles and their summaries.\n", 544 | "* The tokenization is done using the length parameters passed to the class.\n", 545 | "* Train and Validation parameters are defined and passed to the PyTorch `DataLoader` construct to create `train` and `validation` data loaders.\n", 546 | "* These dataloaders will be passed to `train()` and `validate()` respectively for training and validation action.\n", 547 | "* The shape of the datasets is printed in the console.\n", 548 | "\n", 549 | "\n", 550 | "\n", 551 | "#### Neural Network and Optimizer\n", 552 | "\n", 553 | "* In this stage we 
define the model and optimizer that will be used for training and to update the weights of the network.\n", 554 | "* We are using the `t5-base` transformer model for our project. You can read about the `T5 model` and its features above.\n", 555 | "* We use the `T5ForConditionalGeneration.from_pretrained(\"t5-base\")` command to define our model. `T5ForConditionalGeneration` adds a Language Model head to our `T5 model`. The Language Model head allows us to generate text based on the training of the `T5 model`.\n", 556 | "* We are using the `Adam` optimizer for our project. This has been the standard for all our tutorials and can be changed to see how different optimizers perform with different learning rates.\n", 557 | "* There is also scope for doing more with the optimizer, such as decay and momentum, to dynamically update the learning rate and other parameters. All those concepts have been kept out of scope for these tutorials.\n", 558 | "\n", 559 | "\n", 560 | "\n", 561 | "#### Training Model and Logging to WandB\n", 562 | "\n", 563 | "* Now we log all the metrics in the WandB project that we initialized above.\n", 564 | "* After that we call `train()` with all the necessary parameters.\n", 565 | "* Loss at every 500th step is printed on the console.\n", 566 | "* Loss at every 10th step is logged as `Training Loss` in the WandB service.\n", 567 | "\n", 568 | "\n", 569 | "\n", 570 | "#### Validation and generation of Summary\n", 571 | "\n", 572 | "* After the training is completed, the validation step is initiated.\n", 573 | "* As defined in the validation function, the model weights are not updated. 
We use the fine tuned model to generate new summaries based on the article text.\n", 574 | "* A progress message is printed on the console every 100 steps, giving a count of how many steps are complete.\n", 575 | "* The original summaries and generated summaries are collected into lists and returned to the main function.\n", 576 | "* Both the lists are used to create the final dataframe with 2 columns **Generated Text** and **Actual Text**\n", 577 | "* The dataframe is saved as a csv file in the local drive.\n", 578 | "* A qualitative analysis can be done with the Dataframe." 579 | ] 580 | }, 581 | { 582 | "cell_type": "code", 583 | "execution_count": null, 584 | "metadata": { 585 | "colab": { 586 | "base_uri": "https://localhost:8080/", 587 | "height": 691, 588 | "referenced_widgets": [ 589 | "694ec243104f470093820e5e0dbbfc8e", 590 | "730be2945e39401dac1ac1247cf2d5fb", 591 | "71a4e2121f554668b2fd6461de3b2dcb", 592 | "9218011c0e1544ffb8128e0add0439c6", 593 | "9754d7526eb249d6ba849cb30f415ffc", 594 | "6afe5bad5a2b4e8c94791949b0f08ead", 595 | "4fca290286b3402fae9ddd0b1ce54504", 596 | "ae78c7bf57934f03a0f553b8829e44d4", 597 | "5c5331a892a64b95a2a130d7cb953e27", 598 | "a84d01823be4413db834ac23cffd9c26", 599 | "7b6f595e9c6a45f9b1edc6bd5512d205", 600 | "662b4ba823df409d8169696c81dccb46", 601 | "94d2a8fa47a3440a8332cb36036ea68e", 602 | "98459f90e8e94b29b93f576b5fdebe58", 603 | "e46fe62d895043878985a1014c3d853a", 604 | "a7edcb37126443c298eca2251194bb66", 605 | "54502d44bc774e2cb8067b59faa3f1bf", 606 | "1e6cc5a0d54c4fdab326fcaa2660e781", 607 | "ec8ba72354fd4b1d92c3b7d061e3c464", 608 | "d9627b05dcda4043aa149c940d7d2f57", 609 | "caf8968694c64b4fbbf5a15ed4948c34", 610 | "bb96976987544a5f800e71f2124b97dd", 611 | "4ee0316176b245239434bd88fe8f8572", 612 | "a0002b369f23475ca440c59df649e450" 613 | ] 614 | }, 615 | "id": "ZtNs9ytpCow2", 616 | "outputId": "80545587-0a82-455a-a9ba-13eb3fcb1550" 617 | }, 618 | "outputs": [ 619 | { 620 | "data": { 621 | "text/html": [ 622 | "\n", 623 | 
" Logging results to Weights & Biases (Documentation).
\n", 624 | " Project page: https://app.wandb.ai/abhimishra-91/transformers_tutorials_summarization
\n", 625 | " Run page: https://app.wandb.ai/abhimishra-91/transformers_tutorials_summarization/runs/2erhbv26
\n", 626 | " " 627 | ], 628 | "text/plain": [ 629 | "" 630 | ] 631 | }, 632 | "metadata": { 633 | "tags": [] 634 | }, 635 | "output_type": "display_data" 636 | }, 637 | { 638 | "data": { 639 | "application/vnd.jupyter.widget-view+json": { 640 | "model_id": "694ec243104f470093820e5e0dbbfc8e", 641 | "version_major": 2, 642 | "version_minor": 0 643 | }, 644 | "text/plain": [ 645 | "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=791656.0, style=ProgressStyle(descripti…" 646 | ] 647 | }, 648 | "metadata": { 649 | "tags": [] 650 | }, 651 | "output_type": "display_data" 652 | }, 653 | { 654 | "name": "stdout", 655 | "output_type": "stream", 656 | "text": [ 657 | "\n", 658 | " text ctext\n", 659 | "0 The Administration of Union Territory Daman an... summarize: The Daman and Diu administration on...\n", 660 | "1 Malaika Arora slammed an Instagram user who tr... summarize: From her special numbers to TV?appe...\n", 661 | "2 The Indira Gandhi Institute of Medical Science... summarize: The Indira Gandhi Institute of Medi...\n", 662 | "3 Lashkar-e-Taiba's Kashmir commander Abu Dujana... summarize: Lashkar-e-Taiba's Kashmir commander...\n", 663 | "4 Hotels in Maharashtra will train their staff t... 
summarize: Hotels in Mumbai and other Indian c...\n", 664 | "FULL Dataset: (4514, 2)\n", 665 | "TRAIN Dataset: (3611, 2)\n", 666 | "TEST Dataset: (903, 2)\n" 667 | ] 668 | }, 669 | { 670 | "data": { 671 | "application/vnd.jupyter.widget-view+json": { 672 | "model_id": "5c5331a892a64b95a2a130d7cb953e27", 673 | "version_major": 2, 674 | "version_minor": 0 675 | }, 676 | "text/plain": [ 677 | "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1199.0, style=ProgressStyle(description…" 678 | ] 679 | }, 680 | "metadata": { 681 | "tags": [] 682 | }, 683 | "output_type": "display_data" 684 | }, 685 | { 686 | "name": "stdout", 687 | "output_type": "stream", 688 | "text": [ 689 | "\n" 690 | ] 691 | }, 692 | { 693 | "data": { 694 | "application/vnd.jupyter.widget-view+json": { 695 | "model_id": "54502d44bc774e2cb8067b59faa3f1bf", 696 | "version_major": 2, 697 | "version_minor": 0 698 | }, 699 | "text/plain": [ 700 | "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=891691430.0, style=ProgressStyle(descri…" 701 | ] 702 | }, 703 | "metadata": { 704 | "tags": [] 705 | }, 706 | "output_type": "display_data" 707 | }, 708 | { 709 | "name": "stdout", 710 | "output_type": "stream", 711 | "text": [ 712 | "\n", 713 | "Initiating Fine-Tuning for the model on our dataset\n", 714 | "/n\n", 715 | "Epoch: 0, Loss: 5.861971378326416\n", 716 | "Epoch: 0, Loss: 1.923175573348999\n", 717 | "Epoch: 0, Loss: 1.489040493965149\n", 718 | "Epoch: 0, Loss: 1.9674493074417114\n", 719 | "Epoch: 1, Loss: 2.0224180221557617\n", 720 | "Epoch: 1, Loss: 1.2057034969329834\n", 721 | "Epoch: 1, Loss: 1.2185782194137573\n", 722 | "Epoch: 1, Loss: 1.6962525844573975\n", 723 | "Now generating summaries on our fine tuned model for the validation dataset and saving it in a dataframe\n", 724 | "/n\n", 725 | "Completed 0\n", 726 | "Completed 100\n", 727 | "Completed 200\n", 728 | "Completed 300\n", 729 | "Completed 400\n", 730 | "Output Files generated for review\n" 731 | ] 
732 | } 733 | ], 734 | "source": [ 735 | "def main():\n", 736 | "    # WandB – Initialize a new run\n", 737 | "    wandb.init(project=\"transformers_tutorials_summarization\")\n", 738 | "\n", 739 | "    # WandB – Config is a variable that holds and saves hyperparameters and inputs\n", 740 | "    # Defining some key variables that will be used later on in the training\n", 741 | "    config = wandb.config  # Initialize config\n", 742 | "    config.TRAIN_BATCH_SIZE = 2  # input batch size for training (default: 64)\n", 743 | "    config.VALID_BATCH_SIZE = 2  # input batch size for testing (default: 1000)\n", 744 | "    config.TRAIN_EPOCHS = 2  # number of epochs to train (default: 10)\n", 745 | "    config.VAL_EPOCHS = 1\n", 746 | "    config.LEARNING_RATE = 1e-4  # learning rate (default: 0.01)\n", 747 | "    config.SEED = 42  # random seed (default: 42)\n", 748 | "    config.MAX_LEN = 512\n", 749 | "    config.SUMMARY_LEN = 150\n", 750 | "\n", 751 | "    # Set random seeds and deterministic pytorch for reproducibility\n", 752 | "    torch.manual_seed(config.SEED)  # pytorch random seed\n", 753 | "    np.random.seed(config.SEED)  # numpy random seed\n", 754 | "    torch.backends.cudnn.deterministic = True\n", 755 | "\n", 756 | "    # tokenizer for encoding the text\n", 757 | "    tokenizer = T5Tokenizer.from_pretrained(\"t5-base\")\n", 758 | "\n", 759 | "\n", 760 | "    # Importing and Pre-Processing the domain data\n", 761 | "    # Selecting the needed columns only.\n", 762 | "    # Adding the summarize text in front of the text. This is to format the dataset similar to how T5 model was trained for summarization task.\n", 763 | "    df = pd.read_csv('./data/news_summary.csv', encoding='latin-1')\n", 764 | "    df = df[['text','ctext']]\n", 765 | "    df.ctext = 'summarize: ' + df.ctext\n", 766 | "    print(df.head())\n", 767 | "\n", 768 | "\n", 769 | "    # Creation of Dataset and Dataloader\n", 770 | "    # Defining the train size. 
So 80% of the data will be used for training and the rest will be used for validation.\n", 771 | "    train_size = 0.8\n", 772 | "    train_dataset = df.sample(frac=train_size, random_state=config.SEED)\n", 773 | "    val_dataset = df.drop(train_dataset.index).reset_index(drop=True)\n", 774 | "    train_dataset = train_dataset.reset_index(drop=True)\n", 775 | "\n", 776 | "    print(\"FULL Dataset: {}\".format(df.shape))\n", 777 | "    print(\"TRAIN Dataset: {}\".format(train_dataset.shape))\n", 778 | "    print(\"TEST Dataset: {}\".format(val_dataset.shape))\n", 779 | "\n", 780 | "\n", 781 | "    # Creating the Training and Validation dataset for further creation of Dataloader\n", 782 | "    training_set = CustomDataset(train_dataset, tokenizer, config.MAX_LEN, config.SUMMARY_LEN)\n", 783 | "    val_set = CustomDataset(val_dataset, tokenizer, config.MAX_LEN, config.SUMMARY_LEN)\n", 784 | "\n", 785 | "    # Defining the parameters for creation of dataloaders\n", 786 | "    train_params = {\n", 787 | "        'batch_size': config.TRAIN_BATCH_SIZE,\n", 788 | "        'shuffle': True,\n", 789 | "        'num_workers': 0\n", 790 | "        }\n", 791 | "\n", 792 | "    val_params = {\n", 793 | "        'batch_size': config.VALID_BATCH_SIZE,\n", 794 | "        'shuffle': False,\n", 795 | "        'num_workers': 0\n", 796 | "        }\n", 797 | "\n", 798 | "    # Creation of Dataloaders for training and validation. These will be used below in the training and validation stages for the model.\n", 799 | "    training_loader = DataLoader(training_set, **train_params)\n", 800 | "    val_loader = DataLoader(val_set, **val_params)\n", 801 | "\n", 802 | "\n", 803 | "\n", 804 | "    # Defining the model. 
We are using the t5-base model, with a Language Model layer added on top for generation of the summary.\n", 805 | "    # Further this model is sent to device (GPU/TPU) for using the hardware.\n", 806 | "    model = T5ForConditionalGeneration.from_pretrained(\"t5-base\")\n", 807 | "    model = model.to(device)\n", 808 | "\n", 809 | "    # Defining the optimizer that will be used to tune the weights of the network in the training session.\n", 810 | "    optimizer = torch.optim.Adam(params = model.parameters(), lr=config.LEARNING_RATE)\n", 811 | "\n", 812 | "    # Log metrics with wandb\n", 813 | "    wandb.watch(model, log=\"all\")\n", 814 | "    # Training loop\n", 815 | "    print('Initiating Fine-Tuning for the model on our dataset')\n", 816 | "\n", 817 | "    for epoch in range(config.TRAIN_EPOCHS):\n", 818 | "        train(epoch, tokenizer, model, device, training_loader, optimizer)\n", 819 | "\n", 820 | "\n", 821 | "    # Validation loop and saving the resulting file with predictions and actuals in a dataframe.\n", 822 | "    # Saving the dataframe as predictions.csv\n", 823 | "    print('Now generating summaries on our fine tuned model for the validation dataset and saving it in a dataframe')\n", 824 | "    for epoch in range(config.VAL_EPOCHS):\n", 825 | "        predictions, actuals = validate(epoch, tokenizer, model, device, val_loader)\n", 826 | "        final_df = pd.DataFrame({'Generated Text':predictions,'Actual Text':actuals})\n", 827 | "        final_df.to_csv('./models/predictions.csv')\n", 828 | "        print('Output Files generated for review')\n", 829 | "\n", 830 | "if __name__ == '__main__':\n", 831 | "    main()" 832 | ] 833 | }, 834 | { 835 | "cell_type": "markdown", 836 | "metadata": { 837 | "id": "j8gvVcM4D7YL" 838 | }, 839 | "source": [ 840 | "\n", 841 | "### Examples of the Summary Generated from the model\n", 842 | "\n", 843 | "##### Example 1\n", 844 | "\n", 845 | "**Original Text**\n", 846 | "New Delhi, Apr 25 (PTI) Union minister Vijay Goel today batted for the unification of the three municipal corporations in the national 
capital saying a discussion over the issue was pertinent. The BJP leader, who was confident of a good show by his party in the MCD polls, the results of which will be declared tomorrow, said the civic bodies needed to be \"revamped\" in order to deliver the services to the people more effectively. The first thing needed was a discussion on the unification of the three municipal corporations and there should also be an end to the practice of sending Delhi government officials to serve in the civic bodies, said the Union Minister of State (Independent Charge) for Youth Affairs and Sports. \"Barring one, the two other civic bodies have been incurring losses. It would be more fruitful and efficient if all the three were merged,\" he said, referring to the north, south and east Delhi municipal corporations. The erstwhile Municipal Corporation of Delhi (MCD) was trifurcated into NDMC, SDMC and EDMC by the then Sheila Dikshit-led Delhi government in 2012. Goel predicted a \"thumping\" victory for the BJP in the MCD polls. He said the newly-elected BJP councillors will be trained on the functioning of the civic bodies and dealing with the bureaucracy.\n", 847 | "\n", 848 | "\n", 849 | "**Original Summary**\n", 850 | "Union Minister Vijay Goel has favoured unification of three MCDs ? North, South and East ? in order to deliver the services more effectively. \"Barring one, the two other civic bodies have been incurring losses. It would be more fruitful and efficient if all the three were merged,\" he said. MCD was trifurcated into EDMC, NDMC and SDMC in 2012.\n", 851 | "\n", 852 | "**Generated Summary**\n", 853 | "BJP leader Vijay Goel on Saturday batted for the unification of three municipal corporations in the national capital saying a discussion over this was pertinent. \"Barring one, two other civic bodies have been incurring losses,\" said Goels. 
The erstwhile Municipal Corporations of Delhi (MCD) were trifurcated into NDMC and SDMC by the then Sheilha Dikshi-led government in 2012. Notably, the MCD poll results will be declared tomorrow." 854 | ] 855 | }, 856 | { 857 | "cell_type": "markdown", 858 | "metadata": { 859 | "id": "b5IcsdrLD7YL" 860 | }, 861 | "source": [ 862 | "##### Example 2\n", 863 | "\n", 864 | "**Original Text**\n", 865 | "After much wait, the first UDAN flight took off from Shimla today after being flagged off by Prime Minister Narendra Modi.The flight will be operated by Alliance Air, the regional arm of Air India. PM Narendra Modi handed over boarding passes to some of passengers travelling via the first UDAN flight at the Shimla airport.Tomorrow PM @narendramodi will flag off the first UDAN flight under the Regional Connectivity Scheme, on Shimla-Delhi sector.Air India yesterday opened bookings for the first launch flight from Shimla to Delhi with all inclusive fares starting at Rs2,036.THE GREAT 'UDAN'The UDAN (Ude Desh ka Aam Naagrik) scheme seeks to make flying more affordable for the common people, holding a plan to connect over 45 unserved and under-served airports.Under UDAN, 50 per cent of the seats on each flight would have a cap of Rs 2,500 per seat/hour. 
The government has also extended subsidy in the form of viability gap funding to the operators flying on these routes.The scheme was launched to \"make air travel accessible to citizens in regionally important cities,\" and has been described as \"a first-of-its-kind scheme globally to stimulate regional connectivity through a market-based mechanism.\" Report have it the first flight today will not be flying at full capacity on its 70-seater ATR airplane because of payload restrictions related to the short Shimla airfield.|| Read more ||Udan scheme: Now you can fly to these 43 cities, see the full list hereUDAN scheme to fly hour-long flights capped at Rs 2,500 to smaller cities\n", 866 | "\n", 867 | "\n", 868 | "**Original Summary**\n", 869 | "PM Narendra Modi on Thursday launched Ude Desh ka Aam Nagrik (UDAN) scheme for regional flight connectivity by flagging off the inaugural flight from Shimla to Delhi. Under UDAN, government will connect small towns by air with 50% plane seats' fare capped at?2,500 for a one-hour journey of 500 kilometres. UDAN will connect over 45 unserved and under-served airports.\n", 870 | "\n", 871 | "**Generated Summary**\n", 872 | "UDAN (Ude Desh Ka Aam Naagrik) scheme, launched to make air travel accessible in regionally important cities under the Regional Connectivity Scheme, took off from Shimla on Tuesday. The first flight will be operated by Alliance Air, which is the regional arm of India's Air India. Under the scheme, 50% seats would have?2,500 per seat/hour and 50% of the seats would have capped at this rate. It was also extended subsidy in form-based funding for operators flying these routes as well." 
873 | ] 874 | }, 875 | { 876 | "cell_type": "markdown", 877 | "metadata": { 878 | "id": "Zmku4YWKD7YL" 879 | }, 880 | "source": [ 881 | "##### Example 3\n", 882 | "\n", 883 | "**Original Text**\n", 884 | "New Delhi, Apr 25 (PTI) The Income Tax department has issued a Rs 24,646 crore tax demand notice to Sahara Groups Aamby Valley Limited (AVL) after conducting a special audit of the company. The department, as part of a special investigation and audit into the account books of AVL, found that an income of over Rs 48,000 crore for a particular assessment year was allegedly not reflected in the record books of the firm and hence it raised a fresh tax demand and penalty amount on it. A Sahara Group spokesperson confirmed the development to PTI. \"Yes, the Income Tax Department has raised Rs 48,085.79 crores to the income of the Aamby Valley Limited with a total demand of income tax of Rs 24,646.96 crores on the Aamby Valley Limited,\" the spokesperson said in a brief statement. Officials said the notice was issued by the taxman in January this year after the special audit of AVLs income for the Assessment Year 2012-13 found that the parent firm had allegedly floated a clutch of Special Purpose Vehicles whose incomes were later accounted on the account of AVL as they were merged with the former in due course of time. The AVL, in its income return filed for AY 2012-13, had reflected a loss of few crores but the special I-T audit brought up the added income, a senior official said. The Supreme Court, last week, had asked the Bombay High Courts official liquidator to sell the Rs 34,000 crore worth of properties of Aamby Valley owned by the Sahara Group and directed its chief Subrata Roy to personally appear before it on April 28. \n", 885 | "\n", 886 | "\n", 887 | "**Original Summary**\n", 888 | "The Income Tax Department has issued a ?24,646 crore tax demand notice to Sahara Group's Aamby Valley Limited. 
The department's audit found that an income of over ?48,000 crore for the assessment year 2012-13 was not reflected in the record books of the firm. A week ago, the SC ordered Bombay HC to auction Sahara's Aamby Valley worth ?34,000 crore.\n", 889 | "\n", 890 | "**Generated Summary**\n", 891 | "the Income Tax department has issued a?24,646 crore tax demand notice to Sahara Groups Aamby Valley Limited (AVL) after conducting an audit of the company. The notice was issued in January this year after the special audit found that the parent firm had floated Special Purpose Vehicle income for the Assessment Year 2012-13 and later accounted on its account as they were merged with the former. \"Yes...the Income Tax Department raised Rs48,085.79 crores to the income,\" he added earlier said at the notice." 892 | ] 893 | } 894 | ], 895 | "metadata": { 896 | "accelerator": "GPU", 897 | "colab": { 898 | "name": "transformers_summarization_wandb.ipynb", 899 | "provenance": [] 900 | }, 901 | "kernelspec": { 902 | "display_name": "Python 3", 903 | "language": "python", 904 | "name": "python3" 905 | }, 906 | "language_info": { 907 | "codemirror_mode": { 908 | "name": "ipython", 909 | "version": 3 910 | }, 911 | "file_extension": ".py", 912 | "mimetype": "text/x-python", 913 | "name": "python", 914 | "nbconvert_exporter": "python", 915 | "pygments_lexer": "ipython3", 916 | "version": "3.7.7" 917 | }, 918 | "varInspector": { 919 | "cols": { 920 | "lenName": 16, 921 | "lenType": 16, 922 | "lenVar": 40 923 | }, 924 | "kernels_config": { 925 | "python": { 926 | "delete_cmd_postfix": "", 927 | "delete_cmd_prefix": "del ", 928 | "library": "var_list.py", 929 | "varRefreshCmd": "print(var_dic_list())" 930 | }, 931 | "r": { 932 | "delete_cmd_postfix": ") ", 933 | "delete_cmd_prefix": "rm(", 934 | "library": "var_list.r", 935 | "varRefreshCmd": "cat(var_dic_list()) " 936 | } 937 | }, 938 | "types_to_exclude": [ 939 | "module", 940 | "function", 941 | "builtin_function_or_method", 942 | 
"instance", 943 | "_Feature" 944 | ], 945 | "window_display": false 946 | }, 947 | "widgets": { 948 | "application/vnd.jupyter.widget-state+json": { 949 | "1e6cc5a0d54c4fdab326fcaa2660e781": { 950 | "model_module": "@jupyter-widgets/base", 951 | "model_name": "LayoutModel", 952 | "state": { 953 | "_model_module": "@jupyter-widgets/base", 954 | "_model_module_version": "1.2.0", 955 | "_model_name": "LayoutModel", 956 | "_view_count": null, 957 | "_view_module": "@jupyter-widgets/base", 958 | "_view_module_version": "1.2.0", 959 | "_view_name": "LayoutView", 960 | "align_content": null, 961 | "align_items": null, 962 | "align_self": null, 963 | "border": null, 964 | "bottom": null, 965 | "display": null, 966 | "flex": null, 967 | "flex_flow": null, 968 | "grid_area": null, 969 | "grid_auto_columns": null, 970 | "grid_auto_flow": null, 971 | "grid_auto_rows": null, 972 | "grid_column": null, 973 | "grid_gap": null, 974 | "grid_row": null, 975 | "grid_template_areas": null, 976 | "grid_template_columns": null, 977 | "grid_template_rows": null, 978 | "height": null, 979 | "justify_content": null, 980 | "justify_items": null, 981 | "left": null, 982 | "margin": null, 983 | "max_height": null, 984 | "max_width": null, 985 | "min_height": null, 986 | "min_width": null, 987 | "object_fit": null, 988 | "object_position": null, 989 | "order": null, 990 | "overflow": null, 991 | "overflow_x": null, 992 | "overflow_y": null, 993 | "padding": null, 994 | "right": null, 995 | "top": null, 996 | "visibility": null, 997 | "width": null 998 | }, 999 | "model_module_version": "1.2.0" 1000 | }, 1001 | "4ee0316176b245239434bd88fe8f8572": { 1002 | "model_module": "@jupyter-widgets/controls", 1003 | "model_name": "DescriptionStyleModel", 1004 | "state": { 1005 | "_model_module": "@jupyter-widgets/controls", 1006 | "_model_module_version": "1.5.0", 1007 | "_model_name": "DescriptionStyleModel", 1008 | "_view_count": null, 1009 | "_view_module": "@jupyter-widgets/base", 1010 | 
"_view_module_version": "1.2.0", 1011 | "_view_name": "StyleView", 1012 | "description_width": "" 1013 | }, 1014 | "model_module_version": "1.5.0" 1015 | }, 1016 | "4fca290286b3402fae9ddd0b1ce54504": { 1017 | "model_module": "@jupyter-widgets/controls", 1018 | "model_name": "DescriptionStyleModel", 1019 | "state": { 1020 | "_model_module": "@jupyter-widgets/controls", 1021 | "_model_module_version": "1.5.0", 1022 | "_model_name": "DescriptionStyleModel", 1023 | "_view_count": null, 1024 | "_view_module": "@jupyter-widgets/base", 1025 | "_view_module_version": "1.2.0", 1026 | "_view_name": "StyleView", 1027 | "description_width": "" 1028 | }, 1029 | "model_module_version": "1.5.0" 1030 | }, 1031 | "54502d44bc774e2cb8067b59faa3f1bf": { 1032 | "model_module": "@jupyter-widgets/controls", 1033 | "model_name": "HBoxModel", 1034 | "state": { 1035 | "_dom_classes": [], 1036 | "_model_module": "@jupyter-widgets/controls", 1037 | "_model_module_version": "1.5.0", 1038 | "_model_name": "HBoxModel", 1039 | "_view_count": null, 1040 | "_view_module": "@jupyter-widgets/controls", 1041 | "_view_module_version": "1.5.0", 1042 | "_view_name": "HBoxView", 1043 | "box_style": "", 1044 | "children": [ 1045 | "IPY_MODEL_ec8ba72354fd4b1d92c3b7d061e3c464", 1046 | "IPY_MODEL_d9627b05dcda4043aa149c940d7d2f57" 1047 | ], 1048 | "layout": "IPY_MODEL_1e6cc5a0d54c4fdab326fcaa2660e781" 1049 | }, 1050 | "model_module_version": "1.5.0" 1051 | }, 1052 | "5c5331a892a64b95a2a130d7cb953e27": { 1053 | "model_module": "@jupyter-widgets/controls", 1054 | "model_name": "HBoxModel", 1055 | "state": { 1056 | "_dom_classes": [], 1057 | "_model_module": "@jupyter-widgets/controls", 1058 | "_model_module_version": "1.5.0", 1059 | "_model_name": "HBoxModel", 1060 | "_view_count": null, 1061 | "_view_module": "@jupyter-widgets/controls", 1062 | "_view_module_version": "1.5.0", 1063 | "_view_name": "HBoxView", 1064 | "box_style": "", 1065 | "children": [ 1066 | "IPY_MODEL_7b6f595e9c6a45f9b1edc6bd5512d205", 1067 
| "IPY_MODEL_662b4ba823df409d8169696c81dccb46" 1068 | ], 1069 | "layout": "IPY_MODEL_a84d01823be4413db834ac23cffd9c26" 1070 | }, 1071 | "model_module_version": "1.5.0" 1072 | }, 1073 | "662b4ba823df409d8169696c81dccb46": { 1074 | "model_module": "@jupyter-widgets/controls", 1075 | "model_name": "HTMLModel", 1076 | "state": { 1077 | "_dom_classes": [], 1078 | "_model_module": "@jupyter-widgets/controls", 1079 | "_model_module_version": "1.5.0", 1080 | "_model_name": "HTMLModel", 1081 | "_view_count": null, 1082 | "_view_module": "@jupyter-widgets/controls", 1083 | "_view_module_version": "1.5.0", 1084 | "_view_name": "HTMLView", 1085 | "description": "", 1086 | "description_tooltip": null, 1087 | "layout": "IPY_MODEL_a7edcb37126443c298eca2251194bb66", 1088 | "placeholder": "​", 1089 | "style": "IPY_MODEL_e46fe62d895043878985a1014c3d853a", 1090 | "value": " 1.20k/1.20k [00:16<00:00, 73.7B/s]" 1091 | }, 1092 | "model_module_version": "1.5.0" 1093 | }, 1094 | "694ec243104f470093820e5e0dbbfc8e": { 1095 | "model_module": "@jupyter-widgets/controls", 1096 | "model_name": "HBoxModel", 1097 | "state": { 1098 | "_dom_classes": [], 1099 | "_model_module": "@jupyter-widgets/controls", 1100 | "_model_module_version": "1.5.0", 1101 | "_model_name": "HBoxModel", 1102 | "_view_count": null, 1103 | "_view_module": "@jupyter-widgets/controls", 1104 | "_view_module_version": "1.5.0", 1105 | "_view_name": "HBoxView", 1106 | "box_style": "", 1107 | "children": [ 1108 | "IPY_MODEL_71a4e2121f554668b2fd6461de3b2dcb", 1109 | "IPY_MODEL_9218011c0e1544ffb8128e0add0439c6" 1110 | ], 1111 | "layout": "IPY_MODEL_730be2945e39401dac1ac1247cf2d5fb" 1112 | }, 1113 | "model_module_version": "1.5.0" 1114 | }, 1115 | "6afe5bad5a2b4e8c94791949b0f08ead": { 1116 | "model_module": "@jupyter-widgets/base", 1117 | "model_name": "LayoutModel", 1118 | "state": { 1119 | "_model_module": "@jupyter-widgets/base", 1120 | "_model_module_version": "1.2.0", 1121 | "_model_name": "LayoutModel", 1122 | "_view_count": 
null, 1123 | "_view_module": "@jupyter-widgets/base", 1124 | "_view_module_version": "1.2.0", 1125 | "_view_name": "LayoutView", 1126 | "align_content": null, 1127 | "align_items": null, 1128 | "align_self": null, 1129 | "border": null, 1130 | "bottom": null, 1131 | "display": null, 1132 | "flex": null, 1133 | "flex_flow": null, 1134 | "grid_area": null, 1135 | "grid_auto_columns": null, 1136 | "grid_auto_flow": null, 1137 | "grid_auto_rows": null, 1138 | "grid_column": null, 1139 | "grid_gap": null, 1140 | "grid_row": null, 1141 | "grid_template_areas": null, 1142 | "grid_template_columns": null, 1143 | "grid_template_rows": null, 1144 | "height": null, 1145 | "justify_content": null, 1146 | "justify_items": null, 1147 | "left": null, 1148 | "margin": null, 1149 | "max_height": null, 1150 | "max_width": null, 1151 | "min_height": null, 1152 | "min_width": null, 1153 | "object_fit": null, 1154 | "object_position": null, 1155 | "order": null, 1156 | "overflow": null, 1157 | "overflow_x": null, 1158 | "overflow_y": null, 1159 | "padding": null, 1160 | "right": null, 1161 | "top": null, 1162 | "visibility": null, 1163 | "width": null 1164 | }, 1165 | "model_module_version": "1.2.0" 1166 | }, 1167 | "71a4e2121f554668b2fd6461de3b2dcb": { 1168 | "model_module": "@jupyter-widgets/controls", 1169 | "model_name": "FloatProgressModel", 1170 | "state": { 1171 | "_dom_classes": [], 1172 | "_model_module": "@jupyter-widgets/controls", 1173 | "_model_module_version": "1.5.0", 1174 | "_model_name": "FloatProgressModel", 1175 | "_view_count": null, 1176 | "_view_module": "@jupyter-widgets/controls", 1177 | "_view_module_version": "1.5.0", 1178 | "_view_name": "ProgressView", 1179 | "bar_style": "success", 1180 | "description": "Downloading: 100%", 1181 | "description_tooltip": null, 1182 | "layout": "IPY_MODEL_6afe5bad5a2b4e8c94791949b0f08ead", 1183 | "max": 791656, 1184 | "min": 0, 1185 | "orientation": "horizontal", 1186 | "style": "IPY_MODEL_9754d7526eb249d6ba849cb30f415ffc", 
1187 | "value": 791656 1188 | }, 1189 | "model_module_version": "1.5.0" 1190 | }, 1191 | "730be2945e39401dac1ac1247cf2d5fb": { 1192 | "model_module": "@jupyter-widgets/base", 1193 | "model_name": "LayoutModel", 1194 | "state": { 1195 | "_model_module": "@jupyter-widgets/base", 1196 | "_model_module_version": "1.2.0", 1197 | "_model_name": "LayoutModel", 1198 | "_view_count": null, 1199 | "_view_module": "@jupyter-widgets/base", 1200 | "_view_module_version": "1.2.0", 1201 | "_view_name": "LayoutView", 1202 | "align_content": null, 1203 | "align_items": null, 1204 | "align_self": null, 1205 | "border": null, 1206 | "bottom": null, 1207 | "display": null, 1208 | "flex": null, 1209 | "flex_flow": null, 1210 | "grid_area": null, 1211 | "grid_auto_columns": null, 1212 | "grid_auto_flow": null, 1213 | "grid_auto_rows": null, 1214 | "grid_column": null, 1215 | "grid_gap": null, 1216 | "grid_row": null, 1217 | "grid_template_areas": null, 1218 | "grid_template_columns": null, 1219 | "grid_template_rows": null, 1220 | "height": null, 1221 | "justify_content": null, 1222 | "justify_items": null, 1223 | "left": null, 1224 | "margin": null, 1225 | "max_height": null, 1226 | "max_width": null, 1227 | "min_height": null, 1228 | "min_width": null, 1229 | "object_fit": null, 1230 | "object_position": null, 1231 | "order": null, 1232 | "overflow": null, 1233 | "overflow_x": null, 1234 | "overflow_y": null, 1235 | "padding": null, 1236 | "right": null, 1237 | "top": null, 1238 | "visibility": null, 1239 | "width": null 1240 | }, 1241 | "model_module_version": "1.2.0" 1242 | }, 1243 | "7b6f595e9c6a45f9b1edc6bd5512d205": { 1244 | "model_module": "@jupyter-widgets/controls", 1245 | "model_name": "FloatProgressModel", 1246 | "state": { 1247 | "_dom_classes": [], 1248 | "_model_module": "@jupyter-widgets/controls", 1249 | "_model_module_version": "1.5.0", 1250 | "_model_name": "FloatProgressModel", 1251 | "_view_count": null, 1252 | "_view_module": "@jupyter-widgets/controls", 1253 | 
"_view_module_version": "1.5.0", 1254 | "_view_name": "ProgressView", 1255 | "bar_style": "success", 1256 | "description": "Downloading: 100%", 1257 | "description_tooltip": null, 1258 | "layout": "IPY_MODEL_98459f90e8e94b29b93f576b5fdebe58", 1259 | "max": 1199, 1260 | "min": 0, 1261 | "orientation": "horizontal", 1262 | "style": "IPY_MODEL_94d2a8fa47a3440a8332cb36036ea68e", 1263 | "value": 1199 1264 | }, 1265 | "model_module_version": "1.5.0" 1266 | }, 1267 | "9218011c0e1544ffb8128e0add0439c6": { 1268 | "model_module": "@jupyter-widgets/controls", 1269 | "model_name": "HTMLModel", 1270 | "state": { 1271 | "_dom_classes": [], 1272 | "_model_module": "@jupyter-widgets/controls", 1273 | "_model_module_version": "1.5.0", 1274 | "_model_name": "HTMLModel", 1275 | "_view_count": null, 1276 | "_view_module": "@jupyter-widgets/controls", 1277 | "_view_module_version": "1.5.0", 1278 | "_view_name": "HTMLView", 1279 | "description": "", 1280 | "description_tooltip": null, 1281 | "layout": "IPY_MODEL_ae78c7bf57934f03a0f553b8829e44d4", 1282 | "placeholder": "​", 1283 | "style": "IPY_MODEL_4fca290286b3402fae9ddd0b1ce54504", 1284 | "value": " 792k/792k [00:00<00:00, 2.85MB/s]" 1285 | }, 1286 | "model_module_version": "1.5.0" 1287 | }, 1288 | "94d2a8fa47a3440a8332cb36036ea68e": { 1289 | "model_module": "@jupyter-widgets/controls", 1290 | "model_name": "ProgressStyleModel", 1291 | "state": { 1292 | "_model_module": "@jupyter-widgets/controls", 1293 | "_model_module_version": "1.5.0", 1294 | "_model_name": "ProgressStyleModel", 1295 | "_view_count": null, 1296 | "_view_module": "@jupyter-widgets/base", 1297 | "_view_module_version": "1.2.0", 1298 | "_view_name": "StyleView", 1299 | "bar_color": null, 1300 | "description_width": "initial" 1301 | }, 1302 | "model_module_version": "1.5.0" 1303 | }, 1304 | "9754d7526eb249d6ba849cb30f415ffc": { 1305 | "model_module": "@jupyter-widgets/controls", 1306 | "model_name": "ProgressStyleModel", 1307 | "state": { 1308 | "_model_module": 
"@jupyter-widgets/controls", 1309 | "_model_module_version": "1.5.0", 1310 | "_model_name": "ProgressStyleModel", 1311 | "_view_count": null, 1312 | "_view_module": "@jupyter-widgets/base", 1313 | "_view_module_version": "1.2.0", 1314 | "_view_name": "StyleView", 1315 | "bar_color": null, 1316 | "description_width": "initial" 1317 | }, 1318 | "model_module_version": "1.5.0" 1319 | }, 1320 | "98459f90e8e94b29b93f576b5fdebe58": { 1321 | "model_module": "@jupyter-widgets/base", 1322 | "model_name": "LayoutModel", 1323 | "state": { 1324 | "_model_module": "@jupyter-widgets/base", 1325 | "_model_module_version": "1.2.0", 1326 | "_model_name": "LayoutModel", 1327 | "_view_count": null, 1328 | "_view_module": "@jupyter-widgets/base", 1329 | "_view_module_version": "1.2.0", 1330 | "_view_name": "LayoutView", 1331 | "align_content": null, 1332 | "align_items": null, 1333 | "align_self": null, 1334 | "border": null, 1335 | "bottom": null, 1336 | "display": null, 1337 | "flex": null, 1338 | "flex_flow": null, 1339 | "grid_area": null, 1340 | "grid_auto_columns": null, 1341 | "grid_auto_flow": null, 1342 | "grid_auto_rows": null, 1343 | "grid_column": null, 1344 | "grid_gap": null, 1345 | "grid_row": null, 1346 | "grid_template_areas": null, 1347 | "grid_template_columns": null, 1348 | "grid_template_rows": null, 1349 | "height": null, 1350 | "justify_content": null, 1351 | "justify_items": null, 1352 | "left": null, 1353 | "margin": null, 1354 | "max_height": null, 1355 | "max_width": null, 1356 | "min_height": null, 1357 | "min_width": null, 1358 | "object_fit": null, 1359 | "object_position": null, 1360 | "order": null, 1361 | "overflow": null, 1362 | "overflow_x": null, 1363 | "overflow_y": null, 1364 | "padding": null, 1365 | "right": null, 1366 | "top": null, 1367 | "visibility": null, 1368 | "width": null 1369 | }, 1370 | "model_module_version": "1.2.0" 1371 | }, 1372 | "a0002b369f23475ca440c59df649e450": { 1373 | "model_module": "@jupyter-widgets/base", 1374 | 
"model_name": "LayoutModel", 1375 | "state": { 1376 | "_model_module": "@jupyter-widgets/base", 1377 | "_model_module_version": "1.2.0", 1378 | "_model_name": "LayoutModel", 1379 | "_view_count": null, 1380 | "_view_module": "@jupyter-widgets/base", 1381 | "_view_module_version": "1.2.0", 1382 | "_view_name": "LayoutView", 1383 | "align_content": null, 1384 | "align_items": null, 1385 | "align_self": null, 1386 | "border": null, 1387 | "bottom": null, 1388 | "display": null, 1389 | "flex": null, 1390 | "flex_flow": null, 1391 | "grid_area": null, 1392 | "grid_auto_columns": null, 1393 | "grid_auto_flow": null, 1394 | "grid_auto_rows": null, 1395 | "grid_column": null, 1396 | "grid_gap": null, 1397 | "grid_row": null, 1398 | "grid_template_areas": null, 1399 | "grid_template_columns": null, 1400 | "grid_template_rows": null, 1401 | "height": null, 1402 | "justify_content": null, 1403 | "justify_items": null, 1404 | "left": null, 1405 | "margin": null, 1406 | "max_height": null, 1407 | "max_width": null, 1408 | "min_height": null, 1409 | "min_width": null, 1410 | "object_fit": null, 1411 | "object_position": null, 1412 | "order": null, 1413 | "overflow": null, 1414 | "overflow_x": null, 1415 | "overflow_y": null, 1416 | "padding": null, 1417 | "right": null, 1418 | "top": null, 1419 | "visibility": null, 1420 | "width": null 1421 | }, 1422 | "model_module_version": "1.2.0" 1423 | }, 1424 | "a7edcb37126443c298eca2251194bb66": { 1425 | "model_module": "@jupyter-widgets/base", 1426 | "model_name": "LayoutModel", 1427 | "state": { 1428 | "_model_module": "@jupyter-widgets/base", 1429 | "_model_module_version": "1.2.0", 1430 | "_model_name": "LayoutModel", 1431 | "_view_count": null, 1432 | "_view_module": "@jupyter-widgets/base", 1433 | "_view_module_version": "1.2.0", 1434 | "_view_name": "LayoutView", 1435 | "align_content": null, 1436 | "align_items": null, 1437 | "align_self": null, 1438 | "border": null, 1439 | "bottom": null, 1440 | "display": null, 1441 | "flex": 
null, 1442 | "flex_flow": null, 1443 | "grid_area": null, 1444 | "grid_auto_columns": null, 1445 | "grid_auto_flow": null, 1446 | "grid_auto_rows": null, 1447 | "grid_column": null, 1448 | "grid_gap": null, 1449 | "grid_row": null, 1450 | "grid_template_areas": null, 1451 | "grid_template_columns": null, 1452 | "grid_template_rows": null, 1453 | "height": null, 1454 | "justify_content": null, 1455 | "justify_items": null, 1456 | "left": null, 1457 | "margin": null, 1458 | "max_height": null, 1459 | "max_width": null, 1460 | "min_height": null, 1461 | "min_width": null, 1462 | "object_fit": null, 1463 | "object_position": null, 1464 | "order": null, 1465 | "overflow": null, 1466 | "overflow_x": null, 1467 | "overflow_y": null, 1468 | "padding": null, 1469 | "right": null, 1470 | "top": null, 1471 | "visibility": null, 1472 | "width": null 1473 | }, 1474 | "model_module_version": "1.2.0" 1475 | }, 1476 | "a84d01823be4413db834ac23cffd9c26": { 1477 | "model_module": "@jupyter-widgets/base", 1478 | "model_name": "LayoutModel", 1479 | "state": { 1480 | "_model_module": "@jupyter-widgets/base", 1481 | "_model_module_version": "1.2.0", 1482 | "_model_name": "LayoutModel", 1483 | "_view_count": null, 1484 | "_view_module": "@jupyter-widgets/base", 1485 | "_view_module_version": "1.2.0", 1486 | "_view_name": "LayoutView", 1487 | "align_content": null, 1488 | "align_items": null, 1489 | "align_self": null, 1490 | "border": null, 1491 | "bottom": null, 1492 | "display": null, 1493 | "flex": null, 1494 | "flex_flow": null, 1495 | "grid_area": null, 1496 | "grid_auto_columns": null, 1497 | "grid_auto_flow": null, 1498 | "grid_auto_rows": null, 1499 | "grid_column": null, 1500 | "grid_gap": null, 1501 | "grid_row": null, 1502 | "grid_template_areas": null, 1503 | "grid_template_columns": null, 1504 | "grid_template_rows": null, 1505 | "height": null, 1506 | "justify_content": null, 1507 | "justify_items": null, 1508 | "left": null, 1509 | "margin": null, 1510 | "max_height": 
null, 1511 | "max_width": null, 1512 | "min_height": null, 1513 | "min_width": null, 1514 | "object_fit": null, 1515 | "object_position": null, 1516 | "order": null, 1517 | "overflow": null, 1518 | "overflow_x": null, 1519 | "overflow_y": null, 1520 | "padding": null, 1521 | "right": null, 1522 | "top": null, 1523 | "visibility": null, 1524 | "width": null 1525 | }, 1526 | "model_module_version": "1.2.0" 1527 | }, 1528 | "ae78c7bf57934f03a0f553b8829e44d4": { 1529 | "model_module": "@jupyter-widgets/base", 1530 | "model_name": "LayoutModel", 1531 | "state": { 1532 | "_model_module": "@jupyter-widgets/base", 1533 | "_model_module_version": "1.2.0", 1534 | "_model_name": "LayoutModel", 1535 | "_view_count": null, 1536 | "_view_module": "@jupyter-widgets/base", 1537 | "_view_module_version": "1.2.0", 1538 | "_view_name": "LayoutView", 1539 | "align_content": null, 1540 | "align_items": null, 1541 | "align_self": null, 1542 | "border": null, 1543 | "bottom": null, 1544 | "display": null, 1545 | "flex": null, 1546 | "flex_flow": null, 1547 | "grid_area": null, 1548 | "grid_auto_columns": null, 1549 | "grid_auto_flow": null, 1550 | "grid_auto_rows": null, 1551 | "grid_column": null, 1552 | "grid_gap": null, 1553 | "grid_row": null, 1554 | "grid_template_areas": null, 1555 | "grid_template_columns": null, 1556 | "grid_template_rows": null, 1557 | "height": null, 1558 | "justify_content": null, 1559 | "justify_items": null, 1560 | "left": null, 1561 | "margin": null, 1562 | "max_height": null, 1563 | "max_width": null, 1564 | "min_height": null, 1565 | "min_width": null, 1566 | "object_fit": null, 1567 | "object_position": null, 1568 | "order": null, 1569 | "overflow": null, 1570 | "overflow_x": null, 1571 | "overflow_y": null, 1572 | "padding": null, 1573 | "right": null, 1574 | "top": null, 1575 | "visibility": null, 1576 | "width": null 1577 | }, 1578 | "model_module_version": "1.2.0" 1579 | }, 1580 | "bb96976987544a5f800e71f2124b97dd": { 1581 | "model_module": 
"@jupyter-widgets/base", 1582 | "model_name": "LayoutModel", 1583 | "state": { 1584 | "_model_module": "@jupyter-widgets/base", 1585 | "_model_module_version": "1.2.0", 1586 | "_model_name": "LayoutModel", 1587 | "_view_count": null, 1588 | "_view_module": "@jupyter-widgets/base", 1589 | "_view_module_version": "1.2.0", 1590 | "_view_name": "LayoutView", 1591 | "align_content": null, 1592 | "align_items": null, 1593 | "align_self": null, 1594 | "border": null, 1595 | "bottom": null, 1596 | "display": null, 1597 | "flex": null, 1598 | "flex_flow": null, 1599 | "grid_area": null, 1600 | "grid_auto_columns": null, 1601 | "grid_auto_flow": null, 1602 | "grid_auto_rows": null, 1603 | "grid_column": null, 1604 | "grid_gap": null, 1605 | "grid_row": null, 1606 | "grid_template_areas": null, 1607 | "grid_template_columns": null, 1608 | "grid_template_rows": null, 1609 | "height": null, 1610 | "justify_content": null, 1611 | "justify_items": null, 1612 | "left": null, 1613 | "margin": null, 1614 | "max_height": null, 1615 | "max_width": null, 1616 | "min_height": null, 1617 | "min_width": null, 1618 | "object_fit": null, 1619 | "object_position": null, 1620 | "order": null, 1621 | "overflow": null, 1622 | "overflow_x": null, 1623 | "overflow_y": null, 1624 | "padding": null, 1625 | "right": null, 1626 | "top": null, 1627 | "visibility": null, 1628 | "width": null 1629 | }, 1630 | "model_module_version": "1.2.0" 1631 | }, 1632 | "caf8968694c64b4fbbf5a15ed4948c34": { 1633 | "model_module": "@jupyter-widgets/controls", 1634 | "model_name": "ProgressStyleModel", 1635 | "state": { 1636 | "_model_module": "@jupyter-widgets/controls", 1637 | "_model_module_version": "1.5.0", 1638 | "_model_name": "ProgressStyleModel", 1639 | "_view_count": null, 1640 | "_view_module": "@jupyter-widgets/base", 1641 | "_view_module_version": "1.2.0", 1642 | "_view_name": "StyleView", 1643 | "bar_color": null, 1644 | "description_width": "initial" 1645 | }, 1646 | "model_module_version": "1.5.0" 1647 
| }, 1648 | "d9627b05dcda4043aa149c940d7d2f57": { 1649 | "model_module": "@jupyter-widgets/controls", 1650 | "model_name": "HTMLModel", 1651 | "state": { 1652 | "_dom_classes": [], 1653 | "_model_module": "@jupyter-widgets/controls", 1654 | "_model_module_version": "1.5.0", 1655 | "_model_name": "HTMLModel", 1656 | "_view_count": null, 1657 | "_view_module": "@jupyter-widgets/controls", 1658 | "_view_module_version": "1.5.0", 1659 | "_view_name": "HTMLView", 1660 | "description": "", 1661 | "description_tooltip": null, 1662 | "layout": "IPY_MODEL_a0002b369f23475ca440c59df649e450", 1663 | "placeholder": "​", 1664 | "style": "IPY_MODEL_4ee0316176b245239434bd88fe8f8572", 1665 | "value": " 892M/892M [00:15<00:00, 57.8MB/s]" 1666 | }, 1667 | "model_module_version": "1.5.0" 1668 | }, 1669 | "e46fe62d895043878985a1014c3d853a": { 1670 | "model_module": "@jupyter-widgets/controls", 1671 | "model_name": "DescriptionStyleModel", 1672 | "state": { 1673 | "_model_module": "@jupyter-widgets/controls", 1674 | "_model_module_version": "1.5.0", 1675 | "_model_name": "DescriptionStyleModel", 1676 | "_view_count": null, 1677 | "_view_module": "@jupyter-widgets/base", 1678 | "_view_module_version": "1.2.0", 1679 | "_view_name": "StyleView", 1680 | "description_width": "" 1681 | }, 1682 | "model_module_version": "1.5.0" 1683 | }, 1684 | "ec8ba72354fd4b1d92c3b7d061e3c464": { 1685 | "model_module": "@jupyter-widgets/controls", 1686 | "model_name": "FloatProgressModel", 1687 | "state": { 1688 | "_dom_classes": [], 1689 | "_model_module": "@jupyter-widgets/controls", 1690 | "_model_module_version": "1.5.0", 1691 | "_model_name": "FloatProgressModel", 1692 | "_view_count": null, 1693 | "_view_module": "@jupyter-widgets/controls", 1694 | "_view_module_version": "1.5.0", 1695 | "_view_name": "ProgressView", 1696 | "bar_style": "success", 1697 | "description": "Downloading: 100%", 1698 | "description_tooltip": null, 1699 | "layout": "IPY_MODEL_bb96976987544a5f800e71f2124b97dd", 1700 | "max": 
891691430, 1701 | "min": 0, 1702 | "orientation": "horizontal", 1703 | "style": "IPY_MODEL_caf8968694c64b4fbbf5a15ed4948c34", 1704 | "value": 891691430 1705 | }, 1706 | "model_module_version": "1.5.0" 1707 | } 1708 | } 1709 | } 1710 | }, 1711 | "nbformat": 4, 1712 | "nbformat_minor": 0 1713 | } -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # NLP Learning Roadmap 2 | 3 | ## Guide 4 | 5 | ################### 6 | 7 | **Competitions** 8 | 9 | An overview of the many competitions in the NLP field (continuously updated): [Competitions](https://github.com/nine09/NLP-Syllabus/blob/master/NLP_Competitions.md) 10 | 11 | ################## 12 | 13 | **Advanced guide:** 14 | 15 | For when graduate study formally begins. The NLP research directions and the guide to required background are here: [Advanced](https://github.com/nine09/NLP-Syllabus/blob/master/researcher.md) 16 | 17 | ################### 18 | 19 | What follows is the introductory guide for graduate students. 20 | 21 | ## Fundamentals 22 | 23 | 1. Mathematics 24 | - Calculus 25 | - Probability theory 26 | - Linear algebra 27 | 2. Python 28 | - virtual environments 29 | - numpy 30 | - sklearn 31 | 3. 
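Deep learning frameworks: **learn them when you need them** 32 | - TensorFlow 33 | - PyTorch 34 | 35 | ## Coursework 36 | 37 | **Stanford CS224n** 38 | 39 | Download the slides and PDF lecture notes from: http://web.stanford.edu/class/cs224n/ 40 | 41 | Videos: https://www.bilibili.com/video/av13383754?from=search&seid=5889103122225870394 42 | 43 | Complete the course and course assignment 3 44 | 45 | ## One: Text Classification 46 | 47 | Text classification is the most basic and simplest form of NLP task, namely sentence classification: given an input sentence, predict the category it belongs to. Solving it usually takes two steps: 1) encode the input text as a vector in a continuous space, then 2) classify based on that vector. Step 1 can use methods such as word2vec or neural networks; step 2 mostly uses linear layers. 48 | 49 | Understanding this process builds a basic grasp of representation learning and shows that the encoding step and the prediction step are relatively independent. 50 | 51 | Prerequisites: tensorflow & sklearn; word2vec; CNN & RNN; keras 52 | 53 | Reference URL: https://zhuanlan.zhihu.com/p/26729228 54 | 55 | Download Data Set: http://www.sogou.com/labs/resource/cs.php 56 | 57 | Compare the classification performance of the following models: 58 | 59 | - CNN 60 | - LSTM 61 | - Naive Bayes 62 | - SVM 63 | 64 | The deep learning models should be initialized with word2vec embeddings. 65 | 66 | ## Two: Sequence Labeling 67 | 68 | Sequence labeling is a common approach to many tasks such as part-of-speech tagging, entity recognition, and information extraction. It aims to classify every word/character in a sentence. Studying it shows how real NLP problems are decomposed or aggregated to fit a framework that is convenient to implement; for example, how entity recognition is carried out through character-by-character classification. 69 | 70 | Sequence labeling can use classical probabilistic graphical models such as HMM and CRF: http://fancyerii.github.io/books/sequential_labeling/ 71 | 72 | But going straight to neural-network-based methods is more direct (recommended): https://zhuanlan.zhihu.com/p/34828874 73 | 74 | Further reading: https://zhuanlan.zhihu.com/p/268579769 75 | 76 | ## Three: seq2seq 77 | 78 | Reference URL: https://github.com/NELSONZHAO/zhihu/tree/master/basic_seq2seq?1521452873816 79 | 80 | Implement a seq2seq model with TensorFlow / PyTorch 81 | 82 | Further topics: 83 | 84 | - Attention mechanism (important) 85 | - BiRNN + Attention machine translation model: https://zhuanlan.zhihu.com/p/37290775 86 | - Recommended paper: Attention based documents classification: http://www.aclweb.org/anthology/N16-1174 87 | - Recommended project: OpenNMT 88 | - url: http://opennmt.net/ 89 | 90 | ## Four: Transformer 91 | 92 | Attention is all you need: https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf 93 | 94 | Chinese: https://zhuanlan.zhihu.com/p/48508221 95 | 96 | The Transformer has essentially unified the deep learning field, so this paper counts as required reading. 97 | 98 | ## Five: BERT & GPT 99 | 100 | 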
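Pre-trained language models are the biggest breakthrough in NLP in recent years. Pre-training plus fine-tuning is now the standard workflow for researchers. The two classic pre-trained language models are BERT and GPT. Both have very simple core ideas, so reading the original papers is not recommended. 101 | 102 | BERT: https://zhuanlan.zhihu.com/p/51413773 103 | 104 | GPT: https://zhuanlan.zhihu.com/p/125139937 105 | 106 | The best way to understand BERT is to use it to solve a sequence labeling problem, which makes its strengths and weaknesses very clear. 107 | The transformers example below is a good way to do this, and it doubles as an introduction to the transformers library. 108 | https://github.com/huggingface/transformers/blob/master/examples/pytorch/token-classification/run_ner.py 109 | 110 | 111 | ## Extension One: Knowledge Graphs 112 | 113 | Introductory course on knowledge graphs: 114 | 115 | Baidu Netdisk link: https://pan.baidu.com/s/1NzUdiIkIk330VxbWEGL3MQ Extraction code: 32q3 116 | 117 | Learn the basics of knowledge graphs, the common tools, and the research directions -------------------------------------------------------------------------------- /researcher.md: -------------------------------------------------------------------------------- 1 | # NLP Research Directions & Background Knowledge 2 | 3 | ## Required 4 | 5 | ### Networks & Pre-trained Models 6 | 7 | **Transformer** The most popular neural network architecture in NLP today. It outperforms RNNs and CNNs on most tasks. Compared with RNNs, its advantage is much faster training on large datasets, and it allows multi-GPU parallelism. It is used by almost all current NLP models. 8 | - Paper: https://arxiv.org/abs/1706.03762 9 | - Harvard's PyTorch implementation of the Transformer: https://nlp.seas.harvard.edu/2018/04/03/attention.html 10 | 11 | **BERT** A large-scale pre-trained language model for NLP. Unlike a unidirectional language model, BERT uses a masked language model (MLM) objective so that each position receives bidirectional context. Using BERT as the base model improves results on most downstream tasks. 12 | - Paper: https://arxiv.org/abs/1810.04805 13 | - Google Research TensorFlow implementation: https://github.com/google-research/bert 14 | 15 | **GPT & GPT-2** Large-scale pre-trained language models from OpenAI. BERT's masked language model setup makes it hard to apply to generation tasks (there is also recent work on generation with BERT; see below). GPT uses a classic unidirectional language model pre-trained on a large corpus. GPT has made large gains not only on generation tasks but on many others as well. 16 | - GPT-2: https://openai.com/blog/better-language-models/ 17 | - GPT: https://openai.com/blog/language-unsupervised/ 18 | 19 | ### Reinforcement Learning & GAN 20 | 21 | Reinforcement learning has begun to make its mark in NLP over the last two years, so learning it is well worth the effort. 22 | - For an introduction, watch: https://www.coursera.org/learn/practical-rl?specialization=aml 23 | - Or read the textbook draft (not recommended): http://incompleteideas.net/book/bookdraft2017nov5.pdf 24 | - An Introduction to Deep Reinforcement Learning: https://arxiv.org/abs/1811.12560 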
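25 | - Introductory reinforcement learning tutorial: https://simoninithomas.github.io/Deep_reinforcement_learning_Course/ 26 | 27 | GANs are used in many NLP tasks, so learning them is necessary. Start with the applications of GANs in computer vision, then read the GAN papers from the NLP field. 28 | Introduction to GANs: https://zhuanlan.zhihu.com/p/58812258 29 | 30 | GAN for text generation: 31 | 32 | - GANs for Sequences of Discrete Elements with the Gumbel-softmax Distribution 33 | - Generating Text via Adversarial Training 34 | - SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient 35 | - Adversarial Feature Matching for Text Generation 36 | - Long Text Generation via Adversarial Training with Leaked Information AAAI2018 37 | 38 | ### Transfer Learning 39 | 40 | The idea of transfer learning shows up in the methods for many NLP tasks. Understanding it means you will not hit obstacles on this front when reading papers. 41 | 42 | A collection of transfer learning resources, including papers and code: 43 | 44 | Reference URL: http://transferlearning.xyz/ 45 | 46 | 47 | ## Research Direction: Text Generation 48 | Text generation has a long research history and many sub-tasks, for example machine translation, dialogue generation, and text summarization. It is also divided into unconditional text generation and conditional text generation. 49 | 50 | Conditional text generation means generating a specific result (for example an answer, or Chinese text) from a specific condition (for example a question, or English text). Unconditional text generation is the basic form of text generation and can be used, for example, to continue an article. 51 | 52 | To get to know text generation, see the survey: https://arxiv.org/abs/1703.09902 53 | 54 | 55 | ### Sub-direction One: Dialogue Generation 56 | 57 | Dialogue generation generally means generating a reply to a given question. 58 | 59 | Recommended papers: 60 | 61 | - Generating Informative Responses with Controlled Sentence Function (AAAI2018, Tsinghua) 62 | - Learning to Ask Questions in Open-domain Conversational Systems with Typed Decoders 63 | - Commonsense Knowledge Aware Conversation Generation with Graph Attention 64 | - Adversarial learning for neural dialogue generation (via adversarial learning) 65 | 66 | ### Sub-direction Two: Machine Translation 67 | 68 | Recommended papers: 69 | 70 | - Neural Machine Translation by Jointly Learning to Align and Translate 71 | - Adam: A Method for Stochastic Optimization 72 | - Neural Machine Translation of Rare Words with Subword Units 73 | - Attention is All You Need (Transformer) 74 | 75 | 76 | ## Research Direction: Relation Extraction 77 | 78 | Judging by the ACL 2019 acceptances, relation extraction is currently the hottest research direction and also the one with the most accepted papers. 79 | 80 | Survey in Chinese, with an introduction to simple models and datasets: https://shomy.top/2018/02/28/relation-extraction/ 81 | 82 | Recommended papers: 83 | 84 | - 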
The common baseline in current papers: Distant Supervision for Relation Extraction via Piecewise Convolutional Neural Networks 85 | - State of the art in 2018: Extracting Relational Facts by an End-to-End Neural Model with Copy Mechanism 86 | - SOTA in 2016, still performs well: End-to-End Relation Extraction using LSTMs on Sequences and Tree Structures 87 | - Handling noise in distantly supervised data with reinforcement learning, one of the important current research directions (AAAI 2018): Reinforcement Learning for Relation Classification from Noisy Data 88 | - Joint extraction of entities and relations 89 | - Joint Extraction of Entities and Relations Based on a Novel Tagging Scheme 90 | - Going out on a limb: Joint Extraction of Entity Mentions and Relations without Dependency Trees 91 | - Multi-instance, multi-label extraction: Multi-instance Multi-label Learning for Relation Extraction 92 | - An interpretable Generative Adversarial Approach to Classification of Latent Entity Relations in Unstructured Sentences 93 | - Distant supervision for relation extraction without labeled data 94 | 95 | ## Research Direction: KBQA 96 | 97 | Question answering systems built on a knowledge base. 98 | 99 | Recommended papers: 100 | 101 | - [ACL15]Semantic Parsing via Staged Query Graph Generation: Question Answering with Knowledge Base 102 | - [ACL17]Improved Neural Relation Detection for Knowledge Base Question Answering --------------------------------------------------------------------------------
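The Transformer, BERT, and GPT entries in this guide all rest on the same attention operation. As a rough, non-authoritative sketch (pure Python with made-up toy numbers; the names `softmax`, `attention`, and `causal` are illustrative and are not taken from any of the linked papers or libraries), the following shows scaled dot-product attention with an optional left-to-right mask. With the mask off, every position sees the whole sequence, as in BERT-style bidirectional encoding; with it on, each position sees only itself and earlier positions, as in a GPT-style unidirectional language model.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention(queries, keys, values, causal=False):
    """Scaled dot-product attention over lists of equal-length vectors.

    causal=True hides later positions (unidirectional, left to right);
    causal=False lets every position attend to the whole sequence.
    Returns (outputs, weights).
    """
    scale = math.sqrt(len(keys[0]))
    outputs, weights = [], []
    for i, q in enumerate(queries):
        # Positions this query is allowed to look at.
        visible = range(i + 1) if causal else range(len(keys))
        scores = [sum(a * b for a, b in zip(q, keys[j])) / scale for j in visible]
        row = softmax(scores)
        # Masked (future) positions get exactly zero weight.
        row = row + [0.0] * (len(keys) - len(row))
        weights.append(row)
        # Output is the attention-weighted average of the value vectors.
        outputs.append([
            sum(w * v[d] for w, v in zip(row, values))
            for d in range(len(values[0]))
        ])
    return outputs, weights


# Toy input: three tokens with 2-dimensional vectors (hypothetical values).
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
_, full = attention(x, x, x)
_, left_to_right = attention(x, x, x, causal=True)
print(full[0])
print(left_to_right[0])
```

With `causal=True` the first token can only attend to itself, so its weight row is `[1.0, 0.0, 0.0]`, while the unmasked row spreads weight over all three tokens. Real implementations batch this with matrix operations and several attention heads; this loop version only aims to make the masking difference visible.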