├── README.md ├── mbart_finetuning ├── run.sh ├── model │ ├── dataloader.py │ └── model.py └── train.py ├── Transformers_multilabel_distilbert.ipynb ├── bert-pretraining.ipynb └── sentiment_analysis_using_roberta.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # NLP scripts 2 | 3 | Hi, 4 | 5 | This repo contains notebooks related to various transformers based models for different nlp based tasks. I will be adding more notebooks in this repo with time. 6 | 7 | I am open for collabarations as well. If you want to contribute, create a new pull request after adding new scripts. Will merge them if they are usable. 8 | -------------------------------------------------------------------------------- /mbart_finetuning/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # mbart 4 | python train.py --train_path train.json --val_path val.json --test_path test.json --tokenizer facebook/mbart-large-50 --model facebook/mbart-large-50 --exp_name mbart-finetuning --save_dir ./ --num_epochs 20 --train_batch_size 4 --val_batch_size 4 --test_batch_size 4 --max_source_length 512 --max_target_length 512 --n_gpus 4 --strategy ddp --sanity_run no -------------------------------------------------------------------------------- /mbart_finetuning/model/dataloader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | import pytorch_lightning as pl 3 | from transformers import AutoTokenizer 4 | import pandas as pd 5 | import json 6 | import torch 7 | 8 | class Dataset1(Dataset): 9 | def __init__(self, data_path, tokenizer, max_source_length, max_target_length, target_lang): 10 | fp = open(data_path, 'r') 11 | self.df = [json.loads(line, strict=False) for line in fp.readlines()] 12 | self.tokenizer = tokenizer 13 | self.max_source_length = max_source_length 14 | self.max_target_length = 
max_target_length 15 | self.languages_map = { 16 | 'bn': 'bn_IN', 17 | 'en': 'en_XX', 18 | 'hi': 'hi_IN', 19 | 'ml': 'ml_IN', 20 | 'mr': 'mr_IN', 21 | 'or': 'or_IN', 22 | 'pa': 'pa_IN', 23 | 'ta': 'ta_IN', 24 | } 25 | 26 | def __len__(self): 27 | return len(self.df) 28 | 29 | def __getitem__(self, idx): 30 | input_text = ' '.join(self.df[idx]['input_text']) 31 | target_text = self.df[idx]['target_text'] 32 | src_lang_code = self.df[idx]['src_lang'] 33 | tgt_lang_code = self.df[idx]['tgt_lang'] 34 | if src_lang_code not in self.languages_map: 35 | src_lang_code='en' 36 | if tgt_lang_code not in self.languages_map: 37 | tgt_lang_code='en' 38 | src_lang = self.languages_map[src_lang_code] 39 | tgt_lang = self.languages_map[tgt_lang_code] 40 | 41 | input_encoding = self.tokenizer(src_lang + ' ' + input_text + ' ', return_tensors='pt', max_length=self.max_source_length ,padding='max_length', truncation=True) 42 | 43 | target_encoding = self.tokenizer(tgt_lang + ' ' + target_text + ' ', return_tensors='pt', max_length=self.max_target_length ,padding='max_length', truncation=True) 44 | 45 | input_ids, attention_mask = input_encoding['input_ids'], input_encoding['attention_mask'] 46 | labels = target_encoding['input_ids'] 47 | 48 | return {'input_ids': input_ids.squeeze(), 'attention_mask': attention_mask.squeeze(), 'labels': labels.squeeze(), 'src_lang': src_lang, 'tgt_lang': tgt_lang} 49 | 50 | class DataModule(pl.LightningDataModule): 51 | def __init__(self, *args, **kwargs): 52 | super().__init__() 53 | self.save_hyperparameters() 54 | self.tokenizer = AutoTokenizer.from_pretrained(self.hparams.tokenizer_name_or_path) 55 | 56 | def setup(self, stage=None): 57 | self.train = Dataset1(self.hparams.train_path, self.tokenizer, self.hparams.max_source_length, self.hparams.max_target_length, self.hparams.target_lang) 58 | self.val = Dataset1(self.hparams.val_path, self.tokenizer, self.hparams.max_source_length, self.hparams.max_target_length, self.hparams.target_lang) 59 | 
self.test = Dataset1(self.hparams.test_path, self.tokenizer, self.hparams.max_source_length, self.hparams.max_target_length, self.hparams.target_lang) 60 | 61 | def train_dataloader(self): 62 | return DataLoader(self.train, batch_size=self.hparams.train_batch_size, num_workers=1,shuffle=True) 63 | 64 | def val_dataloader(self): 65 | return DataLoader(self.val, batch_size=self.hparams.val_batch_size, num_workers=1,shuffle=False) 66 | 67 | def test_dataloader(self): 68 | return DataLoader(self.test, batch_size=self.hparams.test_batch_size, num_workers=1,shuffle=False) 69 | 70 | def predict_dataloader(self): 71 | return self.test_dataloader() 72 | -------------------------------------------------------------------------------- /mbart_finetuning/train.py: -------------------------------------------------------------------------------- 1 | from model.model import Summarizer 2 | from model.dataloader import DataModule 3 | 4 | import pytorch_lightning as pl 5 | from pytorch_lightning.loggers import WandbLogger 6 | from pytorch_lightning.callbacks import ModelCheckpoint 7 | from pytorch_lightning.plugins import DDPPlugin 8 | 9 | import os 10 | import sys 11 | import glob 12 | import time 13 | import argparse 14 | os.environ["WANDB_SILENT"] = "True" 15 | 16 | def main(args): 17 | 18 | train_path = args.train_path 19 | val_path = args.val_path 20 | test_path = args.test_path 21 | 22 | tokenizer_name_or_path = args.tokenizer 23 | model_name_or_path = args.model 24 | 25 | 26 | if args.config is not None: 27 | config = args.config 28 | else: 29 | config = model_name_or_path 30 | 31 | if not os.path.exists(args.prediction_path): 32 | os.system(f'mkdir -p {args.prediction_path}') 33 | 34 | n_gpus = args.n_gpus 35 | strategy = args.strategy 36 | EXP_NAME = args.exp_name 37 | save_dir = args.save_dir 38 | target_lang = args.target_lang 39 | num_epochs = args.num_epochs 40 | train_batch_size = args.train_batch_size 41 | val_batch_size = args.val_batch_size 42 | test_batch_size = 
args.test_batch_size 43 | max_source_length = args.max_source_length 44 | max_target_length = args.max_target_length 45 | prediction_path = args.prediction_path 46 | 47 | dm_hparams = dict( 48 | train_path=train_path, 49 | val_path=val_path, 50 | test_path=test_path, 51 | tokenizer_name_or_path=tokenizer_name_or_path, 52 | max_source_length=max_source_length, 53 | max_target_length=max_target_length, 54 | train_batch_size=train_batch_size, 55 | val_batch_size=val_batch_size, 56 | test_batch_size=test_batch_size, 57 | target_lang=target_lang 58 | ) 59 | dm = DataModule(**dm_hparams) 60 | 61 | model_hparams = dict( 62 | learning_rate=2e-5, 63 | model_name_or_path=model_name_or_path, 64 | config = config, 65 | eval_beams=4, 66 | tgt_max_seq_len=max_target_length, 67 | tokenizer=dm.tokenizer, 68 | target_lang=target_lang, 69 | prediction_path=prediction_path 70 | ) 71 | 72 | model = Summarizer(**model_hparams) 73 | 74 | if args.sanity_run=='yes': 75 | log_model = False 76 | limit_train_batches = 4 77 | limit_val_batches = 4 78 | limit_test_batches = 4 79 | else: 80 | log_model = True 81 | limit_train_batches = 1.0 82 | limit_val_batches = 1.0 83 | limit_test_batches = 1.0 84 | 85 | checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min', 86 | dirpath=os.path.join(save_dir+EXP_NAME, 'lightning-checkpoints'), 87 | filename='best_checkpoint', 88 | save_top_k=1, 89 | verbose=True, 90 | save_last=False, 91 | save_weights_only=False) 92 | 93 | trainer_hparams = dict( 94 | gpus=n_gpus, 95 | strategy=strategy, 96 | max_epochs=num_epochs, 97 | num_sanity_val_steps=3, 98 | logger=WandbLogger(name=model_name_or_path.split('/')[-1], save_dir=save_dir+EXP_NAME, project=EXP_NAME, log_model=False), 99 | check_val_every_n_epoch=1, 100 | val_check_interval=1.0, 101 | enable_checkpointing=True, 102 | callbacks=[checkpoint_callback], 103 | limit_train_batches=limit_train_batches, 104 | limit_val_batches=limit_val_batches, 105 | limit_test_batches=limit_test_batches 106 | ) 
107 | trainer = pl.Trainer(**trainer_hparams) 108 | 109 | trainer.fit(model, dm) 110 | 111 | ckpt_path = save_dir+EXP_NAME + 'lightning-checkpoints/best_checkpoint.ckpt' 112 | model = model.load_from_checkpoint(ckpt_path) 113 | results = trainer.test(model=model, datamodule=dm, verbose=True) 114 | 115 | 116 | if __name__ == '__main__': 117 | 118 | parser = argparse.ArgumentParser(description='Input parameters for extractive stage') 119 | parser.add_argument('--n_gpus', default=1, type=int, help='number of gpus to use') 120 | parser.add_argument('--train_path', help='path to input json file for a given domain in given language') 121 | parser.add_argument('--val_path', help='path to intermediate output json file for a given domain in given language') 122 | parser.add_argument('--test_path', help='path to output json file for a given domain in given language') 123 | parser.add_argument('--config', default=None, help='which config file to use') 124 | parser.add_argument('--tokenizer', default='facebook/mbart-large-50', help='which tokenizer to use') 125 | parser.add_argument('--model', default='facebook/mbart-large-50', help='which model to use') 126 | parser.add_argument('--exp_name', default='mbart-basline', help='experiment name') 127 | parser.add_argument('--save_dir', default='checkpoints/', help='where to save the logs and checkpoints') 128 | parser.add_argument('--target_lang', default='hi', help='what is the target language') 129 | parser.add_argument('--num_epochs', default=5, type=int, help='number of epochs') 130 | parser.add_argument('--train_batch_size', default=4, type=int, help='train batch size') 131 | parser.add_argument('--val_batch_size', default=4, type=int, help='val batch size') 132 | parser.add_argument('--test_batch_size', default=4, type=int, help='test batch size') 133 | parser.add_argument('--max_source_length', default=1024, type=int, help='max source length') 134 | parser.add_argument('--max_target_length', default=1024, type=int, help='max 
target length') 135 | parser.add_argument('--strategy', default='dp', help='which strategy to use') 136 | parser.add_argument('--sanity_run', default='no', help='which strategy to use') 137 | 138 | args = parser.parse_args() 139 | 140 | main(args) 141 | -------------------------------------------------------------------------------- /mbart_finetuning/model/model.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from transformers import MBartForConditionalGeneration, AutoConfig, AutoModelForSeq2SeqLM, MBartTokenizer 3 | import torch 4 | from rouge import Rouge 5 | import json 6 | from indicnlp.transliterate import unicode_transliterate 7 | import pandas as pd 8 | 9 | class Summarizer(pl.LightningModule): 10 | def __init__(self, *args, **kwargs): 11 | super().__init__() 12 | self.save_hyperparameters() 13 | self.rouge = Rouge() 14 | self.config = AutoConfig.from_pretrained(self.hparams.config) 15 | 16 | self.model = MBartForConditionalGeneration.from_pretrained(self.hparams.model_name_or_path) 17 | 18 | self.languages_map = { 19 | 'bn': 'bn_IN', 20 | 'en': 'en_XX', 21 | 'hi': 'hi_IN', 22 | 'ml': 'ml_IN', 23 | 'mr': 'mr_IN', 24 | 'or': 'or_IN', 25 | 'pa': 'pa_IN', 26 | 'ta': 'ta_IN', 27 | } 28 | 29 | def forward(self, input_ids, attention_mask, labels): 30 | outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) 31 | return outputs 32 | 33 | def _step(self, batch): 34 | input_ids, attention_mask, labels, src_lang, tgt_lang = batch['input_ids'], batch['attention_mask'], batch['labels'], batch['src_lang'], batch['tgt_lang'] 35 | outputs = self(input_ids, attention_mask, labels) 36 | loss = outputs[0] 37 | return loss 38 | 39 | def _generative_step(self, batch): 40 | token_id = self.hparams.tokenizer.lang_code_to_id[batch['tgt_lang']] 41 | self.hparams.tokenizer.tgt_lang = batch['tgt_lang'] 42 | generated_ids = self.model.generate( 43 | input_ids=batch['input_ids'], 
44 | attention_mask=batch['attention_mask'], 45 | use_cache=True, 46 | num_beams=self.hparams.eval_beams, 47 | forced_bos_token_id=token_id, 48 | max_length=self.hparams.tgt_max_seq_len 49 | ) 50 | 51 | input_text = self.hparams.tokenizer.batch_decode(batch['input_ids'], skip_special_tokens=True) 52 | pred_text = self.hparams.tokenizer.batch_decode(generated_ids, skip_special_tokens=True) 53 | ref_text = self.hparams.tokenizer.batch_decode(batch['labels'], skip_special_tokens=True) 54 | 55 | return input_text, pred_text, ref_text 56 | 57 | def training_step(self, batch, batch_idx): 58 | loss = self._step(batch) 59 | self.log("train_loss", loss, on_epoch=True) 60 | return {'loss': loss} 61 | 62 | def validation_step(self, batch, batch_idx): 63 | loss = self._step(batch) 64 | input_text, pred_text, ref_text = self._generative_step(batch) 65 | self.log("val_loss", loss, on_epoch=True) 66 | return loss 67 | 68 | def validation_epoch_end(self, outputs): 69 | pred_text = [] 70 | ref_text = [] 71 | for x in outputs: 72 | pred = x['pred_text'] 73 | if pred[0] == '': 74 | pred[0] = 'default text' 75 | pred_text.extend(pred) 76 | else: 77 | pred_text.extend(pred) 78 | 79 | ref = x['ref_text'] 80 | if ref[0] == '': 81 | ref[0] = 'default text' 82 | ref_text.extend(ref) 83 | else: 84 | ref_text.extend(ref) 85 | 86 | rouge = self.rouge.get_scores(pred_text, ref_text, avg=True) 87 | 88 | self.log("val_rouge-1_prec", rouge['rouge-1']['p']) 89 | self.log("val_rouge-1_rec", rouge['rouge-1']['r']) 90 | self.log("val_rouge-1_f1", rouge['rouge-1']['f']) 91 | 92 | self.log("val_rouge-2_prec", rouge['rouge-2']['p']) 93 | self.log("val_rouge-2_rec", rouge['rouge-2']['r']) 94 | self.log("val_rouge-2_f1", rouge['rouge-2']['f']) 95 | 96 | self.log("val_rouge-l_prec", rouge['rouge-l']['p']) 97 | self.log("val_rouge-l_rec", rouge['rouge-l']['r']) 98 | self.log("val_rouge-l_f1", rouge['rouge-l']['f']) 99 | return 100 | 101 | 102 | def predict_step(self, batch, batch_idx): 103 | input_text, 
pred_text, ref_text = self._generative_step(batch) 104 | return {'input_text': input_text, 'pred_text': pred_text, 'ref_text': ref_text} 105 | 106 | def test_step(self, batch, batch_idx): 107 | loss = self._step(batch) 108 | input_text, pred_text, ref_text = self._generative_step(batch) 109 | return {'test_loss': loss, 'input_text': input_text, 'pred_text': pred_text, 'ref_text': ref_text} 110 | 111 | def test_epoch_end(self, outputs): 112 | df_to_write = pd.DataFrame(columns=['lang', 'input_text', 'ref_text', 'pred_text', 'rouge']) 113 | input_text = [] 114 | langs = [] 115 | pred_text = [] 116 | ref_text = [] 117 | langs = [] 118 | for x in outputs: 119 | input_texts.extend(x['input_text']) 120 | pred_texts.extend(x['pred_text']) 121 | ref_texts.extend(x['ref_text']) 122 | langs.extend(x['lang']) 123 | 124 | for key in self.languages_map: 125 | self.languages_map[key]['original_pred_text'] = [self.process_for_rouge(pred_text, self.lang_id_map[lang]) for pred_text, lang in zip(pred_texts, langs) if lang == self.languages_map[key]['id']] 126 | self.languages_map[key]['original_ref_text'] = [self.process_for_rouge(ref_text, self.lang_id_map[lang]) for ref_text, lang in zip(ref_texts, langs) if lang == self.languages_map[key]['id']] 127 | self.languages_map[key]['original_input_text'] = [self.process_for_rouge(input_text, self.lang_id_map[lang]) for input_text, lang in zip(input_texts, langs) if lang == self.languages_map[key]['id']] 128 | 129 | overall_rouge = 0 130 | for key in self.languages_map: 131 | try: 132 | self.languages_map[key]['rouge'] = self.rouge.get_scores(self.languages_map[key]['original_pred_text'], [self.languages_map[key]['original_ref_text']]).score 133 | self.log(f"test_rouge_{key}", self.languages_map[key]['rouge']) 134 | overall_rouge += self.languages_map[key]['rouge'] 135 | except: 136 | pass 137 | 138 | self.log("test_rouge", overall_rouge/len(self.languages_map)) 139 | 140 | for key in self.languages_map: 141 | l = 
len(self.languages_map[key]['original_pred_text']) 142 | self.languages_map[key]['rouges'] = [self.cal_bleu.corpus_score([self.languages_map[key]['original_pred_text'][i]], [[self.languages_map[key]['original_ref_text'][i]]]).score for i in range(len(self.languages_map[key]['original_pred_text']))] 143 | df_key = pd.DataFrame({ 144 | 'lang':[key for i in range(l)], 145 | 'input_text':[self.languages_map[key]['original_input_text'][i] for i in range(l)], 146 | 'pred_text':[self.languages_map[key]['original_pred_text'][i] for i in range(l)], 147 | 'ref_text':[self.languages_map[key]['original_ref_text'][i] for i in range(l)], 148 | 'rouge':[self.languages_map[key]['rouges'][i] for i in range(l)] 149 | }) 150 | df_to_write = pd.concat([df_to_write, df_key]) 151 | 152 | df_to_write.to_csv(self.hparams.prediction_path + 'preds_mbart.csv', index=False) 153 | 154 | return 155 | 156 | def configure_optimizers(self): 157 | return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) 158 | 159 | @staticmethod 160 | def add_model_specific_args(parent_parser): 161 | parser = parent_parser.add_argument_group('Bart Fine-tuning Parameters') 162 | parser.add_argument('--learning_rate', default=2e-5, type=float) 163 | parser.add_argument('--model_name_or_path', default='bart-base', type=str) 164 | parser.add_argument('--eval_beams', default=4, type=int) 165 | parser.add_argument('--tgt_max_seq_len', default=128, type=int) 166 | parser.add_argument('--tokenizer', default='bart-base', type=str) 167 | return parent_parser -------------------------------------------------------------------------------- /Transformers_multilabel_distilbert.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Transformers_multilabel_distilbert.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "toc_visible": true, 10 | "mount_file_id": 
"15ZXLCtJmBmByqWLkJsyuRuidGunSQt3c", 11 | "authorship_tag": "ABX9TyMYFn2H5CH0K8Hce4nsXrwt", 12 | "include_colab_link": true 13 | }, 14 | "kernelspec": { 15 | "name": "python3", 16 | "display_name": "Python 3" 17 | }, 18 | "accelerator": "GPU" 19 | }, 20 | "cells": [ 21 | { 22 | "cell_type": "markdown", 23 | "metadata": { 24 | "id": "view-in-github", 25 | "colab_type": "text" 26 | }, 27 | "source": [ 28 | "\"Open" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": { 34 | "id": "kT5-oqMPB6vp", 35 | "colab_type": "text" 36 | }, 37 | "source": [ 38 | "# Fine Tuning DistilBERT for MultiLabel Text Classification" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": { 44 | "id": "I4R39UTxNKTk", 45 | "colab_type": "text" 46 | }, 47 | "source": [ 48 | "### Introduction\n", 49 | "\n", 50 | "In this tutorial we will be fine tuning a transformer model for the **Multilabel text classification** problem. \n", 51 | "This is one of the most common business problems, where a given piece of text/sentence/document needs to be classified into one or more categories from a given list. For example, a movie can be categorized into one or more genres.\n", 52 | "\n", 53 | "#### Flow of the notebook\n", 54 | "\n", 55 | "The notebook is divided into separate sections to provide an organized walkthrough of the process used. This process can be modified for individual use cases. The sections are:\n", 56 | "\n", 57 | "1. [Importing Python Libraries and preparing the environment](#section01)\n", 58 | "2. [Importing and Pre-Processing the domain data](#section02)\n", 59 | "3. [Preparing the Dataset and Dataloader](#section03)\n", 60 | "4. [Creating the Neural Network for Fine Tuning](#section04)\n", 61 | "5. [Fine Tuning the Model](#section05)\n", 62 | "6. [Validating the Model Performance](#section06)\n", 63 | "7.
[Saving the model and artifacts for Inference in Future](#section07)\n", 64 | "\n", 65 | "#### Technical Details\n", 66 | "\n", 67 | "This script leverages multiple tools designed by other teams. Details of the tools used are given below. Please ensure that these elements are present in your setup to successfully run this script.\n", 68 | "\n", 69 | " - Data: \n", 70 | "\t - We are using the Jigsaw toxic data from [Kaggle](https://www.kaggle.com/)\n", 71 | " - This competition provides the source dataset: [Toxic Comment Competition](https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge)\n", 72 | "\t - We are referring only to the first csv file from the data dump: `train.csv`\n", 73 | "\t - There are 159,571 rows of data. Each row has the following data points: \n", 74 | "\t\t - Comment Text\n", 75 | "\t\t - `toxic`\n", 76 | "\t\t - `severe_toxic`\n", 77 | "\t\t - `obscene`\n", 78 | "\t\t - `threat`\n", 79 | "\t\t - `insult`\n", 80 | "\t\t - `identity_hate`\n", 81 | "\n", 82 | "Each comment can be marked for multiple categories. If the comment is `toxic` and `obscene`, then for both those headers the value will be `1` and for the others it will be `0`.\n", 83 | "\n", 84 | "\n", 85 | " - Language Model Used:\n", 86 | "\t - DistilBERT is a smaller transformer model compared to BERT or RoBERTa. It is created by applying distillation to BERT.
\n", 87 | "\t - [Blog-Post](https://medium.com/huggingface/distilbert-8cf3380435b5)\n", 88 | "\t - [Research Paper](https://arxiv.org/pdf/1910.01108)\n", 89 | " - [Documentation for python](https://huggingface.co/transformers/model_doc/distilbert.html)\n", 90 | "\n", 91 | "\n", 92 | " - Requirements:\n", 93 | "\t - Python 3.6 and above\n", 94 | "\t - Pytorch, Transformers and All the stock Python ML Libraries\n", 95 | "\t - GPU enabled setup \n", 96 | "\n", 97 | "\n", 98 | " - Script Objective:\n", 99 | "\t - The objective of this script is to fine tune DistilBERT to be able to label a comment into the following categories:\n", 100 | "\t\t - `toxic`\n", 101 | "\t\t - `severe_toxic`\n", 102 | "\t\t - `obscene`\n", 103 | "\t\t - `threat`\n", 104 | "\t\t - `insult`\n", 105 | "\t\t - `identity_hate`\n", 106 | "\n", 107 | "---\n", 108 | "***NOTE***\n", 109 | "- *It is to be noted that the overall mechanisms for multiclass and multilabel problems are similar, except for a few differences, namely:*\n", 110 | "\t- *The loss function evaluates the probability of each category individually rather than relative to the other categories. Hence the use of `BCE` rather than `Cross Entropy` when defining the loss.*\n", 111 | "\t- *The Sigmoid of the outputs is calculated rather than the Softmax, for the same reason given in the previous point.*\n", 112 | "\t- *The [loss metrics](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.hamming_loss.html) and **Hamming Score** are used for direct comparison of expected vs predicted*\n", 113 | "---" 114 | ] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "metadata": { 119 | "id": "-FA0wthINIsH", 120 | "colab_type": "text" 121 | }, 122 | "source": [ 123 | "\n", 124 | "### Importing Python Libraries and preparing the environment\n", 125 | "\n", 126 | "At this step we will be importing the libraries and modules needed to run our script.
Libraries are:\n", 127 | "* warnings\n", 128 | "* Numpy\n", 129 | "* Pandas\n", 130 | "* tqdm\n", 131 | "* scikit-learn metrics\n", 132 | "* Pytorch\n", 133 | "* Pytorch Utils for Dataset and Dataloader\n", 134 | "* Transformers\n", 135 | "* DistilBERT Model and Tokenizer\n", 136 | "* logging\n", 137 | "\n", 138 | "After that, we will prepare the device for CUDA execution. This configuration is needed if you want to leverage an onboard GPU. " 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "metadata": { 144 | "id": "XTkCDy7NPuOx", 145 | "colab_type": "code", 146 | "colab": {} 147 | }, 148 | "source": [ 149 | "! pip install transformers==3.0.2" 150 | ], 151 | "execution_count": null, 152 | "outputs": [] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "metadata": { 157 | "id": "zHxRRzqpBf76", 158 | "colab_type": "code", 159 | "colab": {} 160 | }, 161 | "source": [ 162 | "# Importing stock ml libraries\n", 163 | "import warnings\n", 164 | "warnings.simplefilter('ignore')\n", 165 | "import numpy as np\n", 166 | "import pandas as pd\n", 167 | "from tqdm import tqdm\n", 168 | "from sklearn import metrics\n", 169 | "import transformers\n", 170 | "import torch\n", 171 | "from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler\n", 172 | "from transformers import DistilBertTokenizer, DistilBertModel\n", 173 | "import logging\n", 174 | "logging.basicConfig(level=logging.ERROR)" 175 | ], 176 | "execution_count": 1, 177 | "outputs": [] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "metadata": { 182 | "id": "B7N-SkxWC7zT", 183 | "colab_type": "code", 184 | "colab": {} 185 | }, 186 | "source": [ 187 | "# # Setting up the device for GPU usage\n", 188 | "\n", 189 | "from torch import cuda\n", 190 | "device = 'cuda' if cuda.is_available() else 'cpu'" 191 | ], 192 | "execution_count": 2, 193 | "outputs": [] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "metadata": { 198 | "id": "tr6XeiZW3YbT", 199 | "colab_type": "code", 200 |
"colab": {} 201 | }, 202 | "source": [ 203 | "def hamming_score(y_true, y_pred, normalize=True, sample_weight=None):\n", 204 | " acc_list = []\n", 205 | " for i in range(y_true.shape[0]):\n", 206 | " set_true = set( np.where(y_true[i])[0] )\n", 207 | " set_pred = set( np.where(y_pred[i])[0] )\n", 208 | " tmp_a = None\n", 209 | " if len(set_true) == 0 and len(set_pred) == 0:\n", 210 | " tmp_a = 1\n", 211 | " else:\n", 212 | " tmp_a = len(set_true.intersection(set_pred))/\\\n", 213 | " float( len(set_true.union(set_pred)) )\n", 214 | " acc_list.append(tmp_a)\n", 215 | " return np.mean(acc_list)" 216 | ], 217 | "execution_count": 46, 218 | "outputs": [] 219 | }, 220 | { 221 | "cell_type": "markdown", 222 | "metadata": { 223 | "id": "6RJt2gnlNKT3", 224 | "colab_type": "text" 225 | }, 226 | "source": [ 227 | "\n", 228 | "### Importing and Pre-Processing the domain data\n", 229 | "\n", 230 | "We will be working with the data and preparing it for fine tuning. \n", 231 | "*This assumes that `train.csv` is already downloaded, unzipped and saved in the working directory (the code below reads `train.csv` directly)*\n", 232 | "\n", 233 | "* The first step is to remove the **id** column from the data.\n", 234 | "* A new dataframe is made and the input text is stored in the **text** column.\n", 235 | "* The values of all the categories for each row are collected and converted into a list.\n", 236 | "* The list is appended as a new column named **labels**."
237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "metadata": { 242 | "id": "9w86kD1qC_kb", 243 | "colab_type": "code", 244 | "colab": {} 245 | }, 246 | "source": [ 247 | "data = pd.read_csv('train.csv')" 248 | ], 249 | "execution_count": 3, 250 | "outputs": [] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "metadata": { 255 | "id": "goM0DYRyvDP_", 256 | "colab_type": "code", 257 | "colab": {} 258 | }, 259 | "source": [ 260 | "data.drop(['id'], inplace=True, axis=1)" 261 | ], 262 | "execution_count": 4, 263 | "outputs": [] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "metadata": { 268 | "id": "Svs-AUlrvaMA", 269 | "colab_type": "code", 270 | "colab": {} 271 | }, 272 | "source": [ 273 | "new_df = pd.DataFrame()\n", 274 | "new_df['text'] = data['comment_text']\n", 275 | "new_df['labels'] = data.iloc[:, 1:].values.tolist()" 276 | ], 277 | "execution_count": 5, 278 | "outputs": [] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "metadata": { 283 | "id": "Lfj9oQc1v6Dc", 284 | "colab_type": "code", 285 | "colab": { 286 | "base_uri": "https://localhost:8080/", 287 | "height": 204 288 | }, 289 | "outputId": "9a3a5d3f-c8cb-4b7b-9a79-5274329aba05" 290 | }, 291 | "source": [ 292 | "new_df.head()" 293 | ], 294 | "execution_count": 6, 295 | "outputs": [ 296 | { 297 | "output_type": "execute_result", 298 | "data": { 299 | "text/html": [ 300 | "
\n", 301 | "\n", 314 | "\n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | "
textlabels
0Explanation\\nWhy the edits made under my usern...[0, 0, 0, 0, 0, 0]
1D'aww! He matches this background colour I'm s...[0, 0, 0, 0, 0, 0]
2Hey man, I'm really not trying to edit war. It...[0, 0, 0, 0, 0, 0]
3\"\\nMore\\nI can't make any real suggestions on ...[0, 0, 0, 0, 0, 0]
4You, sir, are my hero. Any chance you remember...[0, 0, 0, 0, 0, 0]
\n", 350 | "
" 351 | ], 352 | "text/plain": [ 353 | " text labels\n", 354 | "0 Explanation\\nWhy the edits made under my usern... [0, 0, 0, 0, 0, 0]\n", 355 | "1 D'aww! He matches this background colour I'm s... [0, 0, 0, 0, 0, 0]\n", 356 | "2 Hey man, I'm really not trying to edit war. It... [0, 0, 0, 0, 0, 0]\n", 357 | "3 \"\\nMore\\nI can't make any real suggestions on ... [0, 0, 0, 0, 0, 0]\n", 358 | "4 You, sir, are my hero. Any chance you remember... [0, 0, 0, 0, 0, 0]" 359 | ] 360 | }, 361 | "metadata": { 362 | "tags": [] 363 | }, 364 | "execution_count": 6 365 | } 366 | ] 367 | }, 368 | { 369 | "cell_type": "markdown", 370 | "metadata": { 371 | "id": "yobavIDmNKT7", 372 | "colab_type": "text" 373 | }, 374 | "source": [ 375 | "\n", 376 | "### Preparing the Dataset and Dataloader\n", 377 | "\n", 378 | "We will start with defining few key variables that will be used later during the training/fine tuning stage.\n", 379 | "Followed by creation of MultiLabelDataset class - This defines how the text is pre-processed before sending it to the neural network. We will also define the Dataloader that will feed the data in batches to the neural network for suitable training and processing. \n", 380 | "Dataset and Dataloader are constructs of the PyTorch library for defining and controlling the data pre-processing and its passage to neural network. For further reading into Dataset and Dataloader read the [docs at PyTorch](https://pytorch.org/docs/stable/data.html)\n", 381 | "\n", 382 | "#### *MultiLabelDataset* Dataset Class\n", 383 | "- This class is defined to accept the `tokenizer`, `dataframe` and `max_length` as input and generate tokenized output and tags that is used by the BERT model for training. 
\n", 384 | "- We are using the DistilBERT tokenizer to tokenize the data in the `text` column of the dataframe.\n", 385 | "- The tokenizer uses the `encode_plus` method to perform tokenization and generate the necessary outputs, namely: `ids`, `attention_mask`, `token_type_ids`\n", 386 | "\n", 387 | "- To read further into the tokenizer, [refer to this document](https://huggingface.co/transformers/model_doc/distilbert.html#distilberttokenizer)\n", 388 | "- `targets` is the list of categories labeled as `0` or `1` in the dataframe. \n", 389 | "- The *MultiLabelDataset* class is used to create 2 datasets, for training and for validation.\n", 390 | "- *Training Dataset* is used to fine tune the model: **80% of the original data**\n", 391 | "- *Validation Dataset* is used to evaluate the performance of the model. The model has not seen this data during training. \n", 392 | "\n", 393 | "#### Dataloader\n", 394 | "- Dataloader is used to create the training and validation dataloaders that load data to the neural network in a defined manner.
This is needed because all the data from the dataset cannot be loaded to the memory at once, hence the amount of dataloaded to the memory and then passed to the neural network needs to be controlled.\n", 395 | "- This control is achieved using the parameters such as `batch_size` and `max_len`.\n", 396 | "- Training and Validation dataloaders are used in the training and validation part of the flow respectively" 397 | ] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "metadata": { 402 | "id": "VpF4ZAaxC_OJ", 403 | "colab_type": "code", 404 | "colab": {} 405 | }, 406 | "source": [ 407 | "# Sections of config\n", 408 | "\n", 409 | "# Defining some key variables that will be used later on in the training\n", 410 | "MAX_LEN = 128\n", 411 | "TRAIN_BATCH_SIZE = 4\n", 412 | "VALID_BATCH_SIZE = 4\n", 413 | "EPOCHS = 1\n", 414 | "LEARNING_RATE = 1e-05\n", 415 | "tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased', truncation=True, do_lower_case=True)" 416 | ], 417 | "execution_count": 7, 418 | "outputs": [] 419 | }, 420 | { 421 | "cell_type": "code", 422 | "metadata": { 423 | "id": "fC7zA4nYDGX3", 424 | "colab_type": "code", 425 | "colab": {} 426 | }, 427 | "source": [ 428 | "class MultiLabelDataset(Dataset):\n", 429 | "\n", 430 | " def __init__(self, dataframe, tokenizer, max_len):\n", 431 | " self.tokenizer = tokenizer\n", 432 | " self.data = dataframe\n", 433 | " self.text = dataframe.text\n", 434 | " self.targets = self.data.labels\n", 435 | " self.max_len = max_len\n", 436 | "\n", 437 | " def __len__(self):\n", 438 | " return len(self.text)\n", 439 | "\n", 440 | " def __getitem__(self, index):\n", 441 | " text = str(self.text[index])\n", 442 | " text = \" \".join(text.split())\n", 443 | "\n", 444 | " inputs = self.tokenizer.encode_plus(\n", 445 | " text,\n", 446 | " None,\n", 447 | " add_special_tokens=True,\n", 448 | " max_length=self.max_len,\n", 449 | " pad_to_max_length=True,\n", 450 | " return_token_type_ids=True\n", 451 | " )\n", 452 | " 
ids = inputs['input_ids']\n", 453 | " mask = inputs['attention_mask']\n", 454 | " token_type_ids = inputs[\"token_type_ids\"]\n", 455 | "\n", 456 | "\n", 457 | " return {\n", 458 | " 'ids': torch.tensor(ids, dtype=torch.long),\n", 459 | " 'mask': torch.tensor(mask, dtype=torch.long),\n", 460 | " 'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),\n", 461 | " 'targets': torch.tensor(self.targets[index], dtype=torch.float)\n", 462 | " }" 463 | ], 464 | "execution_count": 8, 465 | "outputs": [] 466 | }, 467 | { 468 | "cell_type": "code", 469 | "metadata": { 470 | "id": "qOZWsLfGDOBG", 471 | "colab_type": "code", 472 | "colab": { 473 | "base_uri": "https://localhost:8080/", 474 | "height": 68 475 | }, 476 | "outputId": "8ebb8e94-4f9f-4db6-89f4-20a30217cc15" 477 | }, 478 | "source": [ 479 | "# Creating the dataset and dataloader for the neural network\n", 480 | "\n", 481 | "train_size = 0.8\n", 482 | "train_data=new_df.sample(frac=train_size,random_state=200)\n", 483 | "test_data=new_df.drop(train_data.index).reset_index(drop=True)\n", 484 | "train_data = train_data.reset_index(drop=True)\n", 485 | "\n", 486 | "\n", 487 | "print(\"FULL Dataset: {}\".format(new_df.shape))\n", 488 | "print(\"TRAIN Dataset: {}\".format(train_data.shape))\n", 489 | "print(\"TEST Dataset: {}\".format(test_data.shape))\n", 490 | "\n", 491 | "training_set = MultiLabelDataset(train_data, tokenizer, MAX_LEN)\n", 492 | "testing_set = MultiLabelDataset(test_data, tokenizer, MAX_LEN)" 493 | ], 494 | "execution_count": 9, 495 | "outputs": [ 496 | { 497 | "output_type": "stream", 498 | "text": [ 499 | "FULL Dataset: (159571, 2)\n", 500 | "TRAIN Dataset: (127657, 2)\n", 501 | "TEST Dataset: (31914, 2)\n" 502 | ], 503 | "name": "stdout" 504 | } 505 | ] 506 | }, 507 | { 508 | "cell_type": "code", 509 | "metadata": { 510 | "id": "StbPlIyKDP9E", 511 | "colab_type": "code", 512 | "colab": {} 513 | }, 514 | "source": [ 515 | "train_params = {'batch_size': TRAIN_BATCH_SIZE,\n", 516 | " 
'shuffle': True,\n", 517 | " 'num_workers': 0\n", 518 | " }\n", 519 | "\n", 520 | "test_params = {'batch_size': VALID_BATCH_SIZE,\n", 521 | " 'shuffle': True,\n", 522 | " 'num_workers': 0\n", 523 | " }\n", 524 | "\n", 525 | "training_loader = DataLoader(training_set, **train_params)\n", 526 | "testing_loader = DataLoader(testing_set, **test_params)" 527 | ], 528 | "execution_count": 10, 529 | "outputs": [] 530 | }, 531 | { 532 | "cell_type": "markdown", 533 | "metadata": { 534 | "id": "yU4TWUBtNKUN", 535 | "colab_type": "text" 536 | }, 537 | "source": [ 538 | "\n", 539 | "### Creating the Neural Network for Fine Tuning\n", 540 | "\n", 541 | "#### Neural Network\n", 542 | " - We will be creating a neural network with the `DistilBERTClass`. \n", 543 | " - This network will have the `DistilBERT` model, followed by a `Dropout` and a `Linear Layer`. They are added for the purpose of **Regularization** and **Classification** respectively. \n", 544 | " - In the forward pass, the first element of the `DistilBertModel` output (`output_1[0]`) holds the hidden states of all tokens.\n", 545 | " - The hidden state of the first token (`hidden_state[:, 0]`) is passed through the `pre_classifier` linear layer and a Tanh activation, then to the `Dropout` layer, and the result is given to the final `Linear` layer. \n", 546 | " - Note that the output dimension of the final `Linear Layer` is **6** because that is the total number of categories into which we want our model to classify the text.\n", 547 | " - The data will be fed to the `DistilBERTClass` as defined in the dataset. \n", 548 | " - The final layer outputs are what will be used to calculate the loss and to determine the accuracy of the model's predictions. \n", 549 | " - We will create an instance of the network called `model`. This instance will be used for training and then to save the final trained model for future inference. 
\n", 550 | " \n", 551 | "#### Loss Function and Optimizer\n", 552 | " - The Loss is defined in the next cell as `loss_fn`.\n", 553 | " - As defined above, the loss function is Binary Cross Entropy combined with a sigmoid, which is implemented as [BCEWithLogitsLoss](https://pytorch.org/docs/stable/nn.html#bcewithlogitsloss) in PyTorch\n", 554 | " - `Optimizer` is defined in the next cell.\n", 555 | " - `Optimizer` is used to update the weights of the neural network to improve its performance." 556 | ] 557 | }, 558 | { 559 | "cell_type": "code", 560 | "metadata": { 561 | "id": "FeftvDhjDSPp", 562 | "colab_type": "code", 563 | "colab": { 564 | "base_uri": "https://localhost:8080/", 565 | "height": 1000 566 | }, 567 | "outputId": "4ba915de-3a1e-4650-d253-2b43857f2d99" 568 | }, 569 | "source": [ 570 | "# Creating the customized model, by adding a drop out and a dense layer on top of distil bert to get the final output for the model. \n", 571 | "\n", 572 | "class DistilBERTClass(torch.nn.Module):\n", 573 | " def __init__(self):\n", 574 | " super(DistilBERTClass, self).__init__()\n", 575 | " self.l1 = DistilBertModel.from_pretrained(\"distilbert-base-uncased\")\n", 576 | " self.pre_classifier = torch.nn.Linear(768, 768)\n", 577 | " self.dropout = torch.nn.Dropout(0.1)\n", 578 | " self.classifier = torch.nn.Linear(768, 6)\n", 579 | "\n", 580 | " def forward(self, input_ids, attention_mask, token_type_ids):\n", 581 | " output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask)\n", 582 | " hidden_state = output_1[0]\n", 583 | " pooler = hidden_state[:, 0]\n", 584 | " pooler = self.pre_classifier(pooler)\n", 585 | " pooler = torch.nn.Tanh()(pooler)\n", 586 | " pooler = self.dropout(pooler)\n", 587 | " output = self.classifier(pooler)\n", 588 | " return output\n", 589 | "\n", 590 | "model = DistilBERTClass()\n", 591 | "model.to(device)" 592 | ], 593 | "execution_count": 11, 594 | "outputs": [ 595 | { 596 | "output_type": "execute_result", 597 | "data": { 598 | 
"text/plain": [ 599 | "DistilBERTClass(\n", 600 | " (l1): DistilBertModel(\n", 601 | " (embeddings): Embeddings(\n", 602 | " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n", 603 | " (position_embeddings): Embedding(512, 768)\n", 604 | " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 605 | " (dropout): Dropout(p=0.1, inplace=False)\n", 606 | " )\n", 607 | " (transformer): Transformer(\n", 608 | " (layer): ModuleList(\n", 609 | " (0): TransformerBlock(\n", 610 | " (attention): MultiHeadSelfAttention(\n", 611 | " (dropout): Dropout(p=0.1, inplace=False)\n", 612 | " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", 613 | " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", 614 | " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", 615 | " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", 616 | " )\n", 617 | " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 618 | " (ffn): FFN(\n", 619 | " (dropout): Dropout(p=0.1, inplace=False)\n", 620 | " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", 621 | " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", 622 | " )\n", 623 | " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 624 | " )\n", 625 | " (1): TransformerBlock(\n", 626 | " (attention): MultiHeadSelfAttention(\n", 627 | " (dropout): Dropout(p=0.1, inplace=False)\n", 628 | " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", 629 | " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", 630 | " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", 631 | " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", 632 | " )\n", 633 | " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 634 | " (ffn): FFN(\n", 635 | " (dropout): Dropout(p=0.1, inplace=False)\n", 636 | " (lin1): Linear(in_features=768, out_features=3072, 
bias=True)\n", 637 | " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", 638 | " )\n", 639 | " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 640 | " )\n", 641 | " (2): TransformerBlock(\n", 642 | " (attention): MultiHeadSelfAttention(\n", 643 | " (dropout): Dropout(p=0.1, inplace=False)\n", 644 | " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", 645 | " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", 646 | " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", 647 | " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", 648 | " )\n", 649 | " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 650 | " (ffn): FFN(\n", 651 | " (dropout): Dropout(p=0.1, inplace=False)\n", 652 | " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", 653 | " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", 654 | " )\n", 655 | " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 656 | " )\n", 657 | " (3): TransformerBlock(\n", 658 | " (attention): MultiHeadSelfAttention(\n", 659 | " (dropout): Dropout(p=0.1, inplace=False)\n", 660 | " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", 661 | " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", 662 | " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", 663 | " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", 664 | " )\n", 665 | " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 666 | " (ffn): FFN(\n", 667 | " (dropout): Dropout(p=0.1, inplace=False)\n", 668 | " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", 669 | " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", 670 | " )\n", 671 | " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 672 | " )\n", 673 | " (4): TransformerBlock(\n", 674 | " 
(attention): MultiHeadSelfAttention(\n", 675 | " (dropout): Dropout(p=0.1, inplace=False)\n", 676 | " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", 677 | " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", 678 | " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", 679 | " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", 680 | " )\n", 681 | " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 682 | " (ffn): FFN(\n", 683 | " (dropout): Dropout(p=0.1, inplace=False)\n", 684 | " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", 685 | " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", 686 | " )\n", 687 | " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 688 | " )\n", 689 | " (5): TransformerBlock(\n", 690 | " (attention): MultiHeadSelfAttention(\n", 691 | " (dropout): Dropout(p=0.1, inplace=False)\n", 692 | " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", 693 | " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", 694 | " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", 695 | " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", 696 | " )\n", 697 | " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 698 | " (ffn): FFN(\n", 699 | " (dropout): Dropout(p=0.1, inplace=False)\n", 700 | " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", 701 | " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", 702 | " )\n", 703 | " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 704 | " )\n", 705 | " )\n", 706 | " )\n", 707 | " )\n", 708 | " (pre_classifier): Linear(in_features=768, out_features=768, bias=True)\n", 709 | " (dropout): Dropout(p=0.1, inplace=False)\n", 710 | " (classifier): Linear(in_features=768, out_features=6, bias=True)\n", 711 | ")" 712 | ] 713 | }, 714 | 
"metadata": { 715 | "tags": [] 716 | }, 717 | "execution_count": 11 718 | } 719 | ] 720 | }, 721 | { 722 | "cell_type": "code", 723 | "metadata": { 724 | "id": "MZ_wI0YwDVJZ", 725 | "colab_type": "code", 726 | "colab": {} 727 | }, 728 | "source": [ 729 | "def loss_fn(outputs, targets):\n", 730 | " return torch.nn.BCEWithLogitsLoss()(outputs, targets)" 731 | ], 732 | "execution_count": 12, 733 | "outputs": [] 734 | }, 735 | { 736 | "cell_type": "code", 737 | "metadata": { 738 | "id": "oO49FuR9DXsW", 739 | "colab_type": "code", 740 | "colab": {} 741 | }, 742 | "source": [ 743 | "optimizer = torch.optim.Adam(params = model.parameters(), lr=LEARNING_RATE)" 744 | ], 745 | "execution_count": 13, 746 | "outputs": [] 747 | }, 748 | { 749 | "cell_type": "markdown", 750 | "metadata": { 751 | "id": "_frCoix1NKUZ", 752 | "colab_type": "text" 753 | }, 754 | "source": [ 755 | "\n", 756 | "### Fine Tuning the Model\n", 757 | "\n", 758 | "After all the effort of loading and preparing the data and datasets, creating the model and defining its loss and optimizer, this is probably the easiest step in the process. \n", 759 | "\n", 760 | "Here we define a training function that trains the model on the training dataset created above a specified number of times (`EPOCHS`). The number of epochs defines how many times the complete data will be passed through the network. \n", 761 | "\n", 762 | "The following events happen in this function to fine-tune the neural network:\n", 763 | "- The dataloader passes data to the model based on the batch size. \n", 764 | "- The output from the model and the actual categories are compared to calculate the loss. \n", 765 | "- The loss value is used to optimize the weights of the neurons in the network.\n", 766 | "- Every 5000 steps the loss value is printed in the console.\n", 767 | "\n", 768 | "As you can see, within just 1 epoch, by the final step the model was working with a minuscule loss of 0.05, i.e. 
the network output is extremely close to the actual output." 
769 | ] 770 | }, 771 | { 772 | "cell_type": "code", 773 | "metadata": { 774 | "id": "fb9-Yr9YDZqo", 775 | "colab_type": "code", 776 | "colab": {} 777 | }, 778 | "source": [ 779 | "def train(epoch):\n", 780 | " model.train()\n", 781 | " for _,data in tqdm(enumerate(training_loader, 0)):\n", 782 | " ids = data['ids'].to(device, dtype = torch.long)\n", 783 | " mask = data['mask'].to(device, dtype = torch.long)\n", 784 | " token_type_ids = data['token_type_ids'].to(device, dtype = torch.long)\n", 785 | " targets = data['targets'].to(device, dtype = torch.float)\n", 786 | "\n", 787 | " outputs = model(ids, mask, token_type_ids)\n", 788 | "\n", 789 | " optimizer.zero_grad()\n", 790 | " loss = loss_fn(outputs, targets)\n", 791 | " if _%5000==0:\n", 792 | " print(f'Epoch: {epoch}, Loss: {loss.item()}')\n", 793 | " \n", 794 | " loss.backward()\n", 795 | " optimizer.step()" 796 | ], 797 | "execution_count": 14, 798 | "outputs": [] 799 | }, 800 | { 801 | "cell_type": "code", 802 | "metadata": { 803 | "id": "Reta6H84DcJq", 804 | "colab_type": "code", 805 | "colab": { 806 | "base_uri": "https://localhost:8080/", 807 | "height": 153 808 | }, 809 | "outputId": "dd66c704-9c33-4815-828c-b552e7b5cf5e" 810 | }, 811 | "source": [ 812 | "for epoch in range(EPOCHS):\n", 813 | " train(epoch)" 814 | ], 815 | "execution_count": 15, 816 | "outputs": [ 817 | { 818 | "output_type": "stream", 819 | "text": [ 820 | "1it [00:00, 9.65it/s]" 821 | ], 822 | "name": "stderr" 823 | }, 824 | { 825 | "output_type": "stream", 826 | "text": [ 827 | "Epoch: 0, Loss: 0.6973826289176941\n" 828 | ], 829 | "name": "stdout" 830 | }, 831 | { 832 | "output_type": "stream", 833 | "text": [ 834 | "5003it [05:32, 14.64it/s]" 835 | ], 836 | "name": "stderr" 837 | }, 838 | { 839 | "output_type": "stream", 840 | "text": [ 841 | "Epoch: 0, Loss: 0.0016139214858412743\n" 842 | ], 843 | "name": "stdout" 844 | }, 845 | { 846 | "output_type": "stream", 847 | "text": [ 848 | "10003it [11:14, 14.59it/s]" 849 | ], 850 | 
"name": "stderr" 851 | }, 852 | { 853 | "output_type": "stream", 854 | "text": [ 855 | "Epoch: 0, Loss: 0.031105097383260727\n" 856 | ], 857 | "name": "stdout" 858 | }, 859 | { 860 | "output_type": "stream", 861 | "text": [ 862 | "15003it [16:56, 14.58it/s]" 863 | ], 864 | "name": "stderr" 865 | }, 866 | { 867 | "output_type": "stream", 868 | "text": [ 869 | "Epoch: 0, Loss: 0.0019174569752067327\n" 870 | ], 871 | "name": "stdout" 872 | }, 873 | { 874 | "output_type": "stream", 875 | "text": [ 876 | "20003it [22:38, 14.60it/s]" 877 | ], 878 | "name": "stderr" 879 | }, 880 | { 881 | "output_type": "stream", 882 | "text": [ 883 | "Epoch: 0, Loss: 0.0015925116604194045\n" 884 | ], 885 | "name": "stdout" 886 | }, 887 | { 888 | "output_type": "stream", 889 | "text": [ 890 | "25003it [28:20, 14.58it/s]" 891 | ], 892 | "name": "stderr" 893 | }, 894 | { 895 | "output_type": "stream", 896 | "text": [ 897 | "Epoch: 0, Loss: 0.08796875923871994\n" 898 | ], 899 | "name": "stdout" 900 | }, 901 | { 902 | "output_type": "stream", 903 | "text": [ 904 | "30003it [34:03, 14.58it/s]" 905 | ], 906 | "name": "stderr" 907 | }, 908 | { 909 | "output_type": "stream", 910 | "text": [ 911 | "Epoch: 0, Loss: 0.05103427171707153\n" 912 | ], 913 | "name": "stdout" 914 | }, 915 | { 916 | "output_type": "stream", 917 | "text": [ 918 | "31915it [36:14, 14.68it/s]\n" 919 | ], 920 | "name": "stderr" 921 | } 922 | ] 923 | }, 924 | { 925 | "cell_type": "markdown", 926 | "metadata": { 927 | "id": "AFv7mNcuNKUh", 928 | "colab_type": "text" 929 | }, 930 | "source": [ 931 | "\n", 932 | "### Validating the Model\n", 933 | "\n", 934 | "During the validation stage we pass the unseen data (Testing Dataset) to the model. This step determines how well the model performs on the unseen data. \n", 935 | "\n", 936 | "This unseen data is the 20% of `train.csv` which was separated during the Dataset creation stage. \n", 937 | "During the validation stage the weights of the model are not updated. 
Only the final output is compared to the actual value. This comparison is then used to calculate the accuracy of the model. \n", 938 | "\n", 939 | "As defined above, to measure our model's performance we are using the following metrics. \n", 940 | "- Hamming Score\n", 941 | "- Hamming Loss\n" 942 | ] 943 | }, 944 | { 945 | "cell_type": "code", 946 | "metadata": { 947 | "id": "aPqZKQ1BDfLW", 948 | "colab_type": "code", 949 | "colab": {} 950 | }, 951 | "source": [ 952 | "def validation(testing_loader):\n", 953 | " model.eval()\n", 954 | " fin_targets=[]\n", 955 | " fin_outputs=[]\n", 956 | " with torch.no_grad():\n", 957 | " for _, data in tqdm(enumerate(testing_loader, 0)):\n", 958 | " ids = data['ids'].to(device, dtype = torch.long)\n", 959 | " mask = data['mask'].to(device, dtype = torch.long)\n", 960 | " token_type_ids = data['token_type_ids'].to(device, dtype = torch.long)\n", 961 | " targets = data['targets'].to(device, dtype = torch.float)\n", 962 | " outputs = model(ids, mask, token_type_ids)\n", 963 | " fin_targets.extend(targets.cpu().detach().numpy().tolist())\n", 964 | " fin_outputs.extend(torch.sigmoid(outputs).cpu().detach().numpy().tolist())\n", 965 | " return fin_outputs, fin_targets" 966 | ], 967 | "execution_count": 17, 968 | "outputs": [] 969 | }, 970 | { 971 | "cell_type": "code", 972 | "metadata": { 973 | "id": "cQmJd7RvB4A-", 974 | "colab_type": "code", 975 | "colab": { 976 | "base_uri": "https://localhost:8080/", 977 | "height": 34 978 | }, 979 | "outputId": "ce9b6e23-94d3-4e11-bce0-2ad3ae5ab38e" 980 | }, 981 | "source": [ 982 | "outputs, targets = validation(testing_loader)\n", 983 | "\n", 984 | "final_outputs = np.array(outputs) >=0.5" 985 | ], 986 | "execution_count": 55, 987 | "outputs": [ 988 | { 989 | "output_type": "stream", 990 | "text": [ 991 | "7979it [02:39, 50.13it/s]\n" 992 | ], 993 | "name": "stderr" 994 | } 995 | ] 996 | }, 997 | { 998 | "cell_type": "code", 999 | "metadata": { 1000 | "id": "qekc58XPDhIp", 1001 | 
"colab_type": "code", 1002 | "colab": { 1003 | "base_uri": "https://localhost:8080/", 1004 | "height": 51 1005 | }, 1006 | "outputId": "950ef928-9171-46a1-f36d-69183bc369f3" 1007 | }, 1008 | "source": [ 1009 | "val_hamming_loss = metrics.hamming_loss(targets, final_outputs)\n", 1010 | "val_hamming_score = hamming_score(np.array(targets), np.array(final_outputs))\n", 1011 | "\n", 1012 | "print(f\"Hamming Score = {val_hamming_score}\")\n", 1013 | "print(f\"Hamming Loss = {val_hamming_loss}\")" 1014 | ], 1015 | "execution_count": 57, 1016 | "outputs": [ 1017 | { 1018 | "output_type": "stream", 1019 | "text": [ 1020 | "Hamming Score = 0.9452533893171231\n", 1021 | "Hamming Loss = 0.017207704037935284\n" 1022 | ], 1023 | "name": "stdout" 1024 | } 1025 | ] 1026 | }, 1027 | { 1028 | "cell_type": "markdown", 1029 | "metadata": { 1030 | "id": "nA59gbi0NItL", 1031 | "colab_type": "text" 1032 | }, 1033 | "source": [ 1034 | "\n", 1035 | "### Saving the Trained Model for inference\n", 1036 | "\n", 1037 | "This is the final step in the process of fine-tuning the model. \n", 1038 | "\n", 1039 | "The model and its vocabulary are saved locally. These files can then be used in the future to make inference on new text inputs." 
1040 | ] 1041 | }, 1042 | { 1043 | "cell_type": "code", 1044 | "metadata": { 1045 | "id": "9yzLQmsdNItM", 1046 | "colab_type": "code", 1047 | "colab": {} 1048 | }, 1049 | "source": [ 1050 | "# Saving the files for inference\n", 1051 | "\n", 1052 | "output_model_file = './models/pytorch_distilbert_news.bin'\n", 1053 | "output_vocab_file = './models/vocab_distilbert_news.bin'\n", 1054 | "\n", 1055 | "torch.save(model, output_model_file)\n", 1056 | "tokenizer.save_vocabulary(output_vocab_file)\n", 1057 | "\n", 1058 | "print('Saved')" 1059 | ], 1060 | "execution_count": null, 1061 | "outputs": [] 1062 | } 1063 | ] 1064 | } 1065 | -------------------------------------------------------------------------------- /bert-pretraining.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "73e17100", 7 | "metadata": { 8 | "execution": { 9 | "iopub.execute_input": "2021-11-19T14:06:36.986013Z", 10 | "iopub.status.busy": "2021-11-19T14:06:36.984474Z", 11 | "iopub.status.idle": "2021-11-19T14:06:43.995598Z", 12 | "shell.execute_reply": "2021-11-19T14:06:43.996716Z" 13 | }, 14 | "papermill": { 15 | "duration": 7.036687, 16 | "end_time": "2021-11-19T14:06:43.997089", 17 | "exception": false, 18 | "start_time": "2021-11-19T14:06:36.960402", 19 | "status": "completed" 20 | }, 21 | "tags": [] 22 | }, 23 | "outputs": [], 24 | "source": [ 25 | "# Importing the libraries needed\n", 26 | "import pandas as pd\n", 27 | "import numpy as np\n", 28 | "import torch\n", 29 | "import transformers\n", 30 | "import tokenizers\n", 31 | "from tqdm import tqdm\n", 32 | "from torch.utils.data import Dataset, DataLoader\n", 33 | "from transformers import BertTokenizer, LineByLineTextDataset, BertModel, BertConfig, BertForMaskedLM, DataCollatorForLanguageModeling\n", 34 | "from transformers import Trainer, TrainingArguments\n", 35 | "import logging\n", 36 | 
"logging.basicConfig(level=logging.ERROR)" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 2, 42 | "id": "cc9c6fa3", 43 | "metadata": { 44 | "execution": { 45 | "iopub.execute_input": "2021-11-19T14:06:44.065083Z", 46 | "iopub.status.busy": "2021-11-19T14:06:44.064315Z", 47 | "iopub.status.idle": "2021-11-19T14:06:44.188518Z", 48 | "shell.execute_reply": "2021-11-19T14:06:44.189633Z", 49 | "shell.execute_reply.started": "2021-11-19T07:49:25.870010Z" 50 | }, 51 | "papermill": { 52 | "duration": 0.164149, 53 | "end_time": "2021-11-19T14:06:44.189845", 54 | "exception": false, 55 | "start_time": "2021-11-19T14:06:44.025696", 56 | "status": "completed" 57 | }, 58 | "tags": [] 59 | }, 60 | "outputs": [], 61 | "source": [ 62 | "train_data = pd.read_csv('../input/query-wellformedness/train.tsv', sep='\\t', header=None)\n", 63 | "val_data = pd.read_csv('../input/query-wellformedness/dev.tsv', sep='\\t', header=None)\n", 64 | "test_data = pd.read_csv('../input/query-wellformedness/test.tsv', sep='\\t', header=None)\n", 65 | "\n", 66 | "train_data.columns = ['query', 'label']\n", 67 | "val_data.columns = ['query', 'label']\n", 68 | "test_data.columns = ['query', 'label']\n", 69 | "\n", 70 | "train_data['label'] = [1 if label>=0.8 else 0 for label in train_data['label']]\n", 71 | "val_data['label'] = [1 if label>=0.8 else 0 for label in val_data['label']]\n", 72 | "test_data['label'] = [1 if label>=0.8 else 0 for label in test_data['label']]\n", 73 | "\n", 74 | "pretraining_data = train_data['query'].tolist() + val_data['query'].tolist() + test_data['query'].tolist()\n", 75 | "\n", 76 | "with open('/kaggle/working/pretraining_data.txt', 'w') as f:\n", 77 | " for sent in pretraining_data:\n", 78 | " f.write(\"%s\\n\" % sent)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 3, 84 | "id": "de6d6438", 85 | "metadata": { 86 | "execution": { 87 | "iopub.execute_input": "2021-11-19T14:06:44.260509Z", 88 | "iopub.status.busy": 
"2021-11-19T14:06:44.259622Z", 89 | "iopub.status.idle": "2021-11-19T14:06:45.340056Z", 90 | "shell.execute_reply": "2021-11-19T14:06:45.340469Z", 91 | "shell.execute_reply.started": "2021-11-19T08:11:25.306607Z" 92 | }, 93 | "papermill": { 94 | "duration": 1.122382, 95 | "end_time": "2021-11-19T14:06:45.340632", 96 | "exception": false, 97 | "start_time": "2021-11-19T14:06:44.218250", 98 | "status": "completed" 99 | }, 100 | "tags": [] 101 | }, 102 | "outputs": [ 103 | { 104 | "name": "stdout", 105 | "output_type": "stream", 106 | "text": [ 107 | "\n", 108 | "\n", 109 | "\n" 110 | ] 111 | }, 112 | { 113 | "data": { 114 | "text/plain": [ 115 | "['/kaggle/working/vocab.txt']" 116 | ] 117 | }, 118 | "execution_count": 3, 119 | "metadata": {}, 120 | "output_type": "execute_result" 121 | } 122 | ], 123 | "source": [ 124 | "bwpt = tokenizers.BertWordPieceTokenizer()\n", 125 | " \n", 126 | "filepath = \"/kaggle/working/pretraining_data.txt\"\n", 127 | "\n", 128 | "bwpt.train(\n", 129 | " files=[filepath],\n", 130 | " vocab_size=50000,\n", 131 | " min_frequency=3,\n", 132 | " limit_alphabet=1000\n", 133 | ")\n", 134 | "\n", 135 | "bwpt.save_model('/kaggle/working/')" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 4, 141 | "id": "37438ef5", 142 | "metadata": { 143 | "execution": { 144 | "iopub.execute_input": "2021-11-19T14:06:45.380517Z", 145 | "iopub.status.busy": "2021-11-19T14:06:45.379873Z", 146 | "iopub.status.idle": "2021-11-19T14:06:45.398186Z", 147 | "shell.execute_reply": "2021-11-19T14:06:45.398559Z", 148 | "shell.execute_reply.started": "2021-11-19T08:11:33.412758Z" 149 | }, 150 | "papermill": { 151 | "duration": 0.04008, 152 | "end_time": "2021-11-19T14:06:45.398711", 153 | "exception": false, 154 | "start_time": "2021-11-19T14:06:45.358631", 155 | "status": "completed" 156 | }, 157 | "tags": [] 158 | }, 159 | "outputs": [ 160 | { 161 | "name": "stderr", 162 | "output_type": "stream", 163 | "text": [ 164 | 
"/opt/conda/lib/python3.7/site-packages/transformers/tokenization_utils_base.py:1621: FutureWarning: Calling BertTokenizer.from_pretrained() with the path to a single file or url is deprecated and won't be possible anymore in v5. Use a model identifier or the path to a directory instead.\n", 165 | " FutureWarning,\n" 166 | ] 167 | } 168 | ], 169 | "source": [ 170 | "# Load the tokenizer\n", 171 | "\n", 172 | "vocab_file_dir = '/kaggle/working/vocab.txt'\n", 173 | "\n", 174 | "tokenizer = BertTokenizer.from_pretrained(vocab_file_dir)" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 5, 180 | "id": "11cde739", 181 | "metadata": { 182 | "execution": { 183 | "iopub.execute_input": "2021-11-19T14:06:45.438876Z", 184 | "iopub.status.busy": "2021-11-19T14:06:45.438218Z", 185 | "iopub.status.idle": "2021-11-19T14:06:53.320512Z", 186 | "shell.execute_reply": "2021-11-19T14:06:53.320091Z", 187 | "shell.execute_reply.started": "2021-11-19T08:11:37.079042Z" 188 | }, 189 | "papermill": { 190 | "duration": 7.903844, 191 | "end_time": "2021-11-19T14:06:53.320661", 192 | "exception": false, 193 | "start_time": "2021-11-19T14:06:45.416817", 194 | "status": "completed" 195 | }, 196 | "tags": [] 197 | }, 198 | "outputs": [ 199 | { 200 | "name": "stderr", 201 | "output_type": "stream", 202 | "text": [ 203 | "/opt/conda/lib/python3.7/site-packages/transformers/data/datasets/language_modeling.py:124: FutureWarning: This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets library. You can have a look at this example script for pointers: https://github.com/huggingface/transformers/blob/master/examples/language-modeling/run_mlm.py\n", 204 | " FutureWarning,\n" 205 | ] 206 | }, 207 | { 208 | "name": "stdout", 209 | "output_type": "stream", 210 | "text": [ 211 | "No. 
of lines: 25100\n" 212 | ] 213 | } 214 | ], 215 | "source": [ 216 | "dataset= LineByLineTextDataset(\n", 217 | " tokenizer = tokenizer,\n", 218 | " file_path = '/kaggle/working/pretraining_data.txt',\n", 219 | " block_size = 128\n", 220 | ")\n", 221 | "\n", 222 | "print('No. of lines: ', len(dataset))" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 6, 228 | "id": "4c1a0938", 229 | "metadata": { 230 | "execution": { 231 | "iopub.execute_input": "2021-11-19T14:06:53.367897Z", 232 | "iopub.status.busy": "2021-11-19T14:06:53.364309Z", 233 | "iopub.status.idle": "2021-11-19T14:06:55.530232Z", 234 | "shell.execute_reply": "2021-11-19T14:06:55.530654Z", 235 | "shell.execute_reply.started": "2021-11-19T08:11:59.925287Z" 236 | }, 237 | "papermill": { 238 | "duration": 2.190946, 239 | "end_time": "2021-11-19T14:06:55.530814", 240 | "exception": false, 241 | "start_time": "2021-11-19T14:06:53.339868", 242 | "status": "completed" 243 | }, 244 | "tags": [] 245 | }, 246 | "outputs": [ 247 | { 248 | "name": "stdout", 249 | "output_type": "stream", 250 | "text": [ 251 | "No of parameters: 81965648\n" 252 | ] 253 | } 254 | ], 255 | "source": [ 256 | "config = BertConfig(\n", 257 | " vocab_size=50000,\n", 258 | " hidden_size=768, \n", 259 | " num_hidden_layers=6, \n", 260 | " num_attention_heads=12,\n", 261 | " max_position_embeddings=512\n", 262 | ")\n", 263 | " \n", 264 | "model = BertForMaskedLM(config)\n", 265 | "print('No of parameters: ', model.num_parameters())\n", 266 | "\n", 267 | "\n", 268 | "data_collator = DataCollatorForLanguageModeling(\n", 269 | " tokenizer=tokenizer, mlm=True, mlm_probability=0.15\n", 270 | ")" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 7, 276 | "id": "ff28c9ec", 277 | "metadata": { 278 | "execution": { 279 | "iopub.execute_input": "2021-11-19T14:06:55.573883Z", 280 | "iopub.status.busy": "2021-11-19T14:06:55.573052Z", 281 | "iopub.status.idle": "2021-11-19T14:07:02.065895Z", 282 | 
"shell.execute_reply": "2021-11-19T14:07:02.065349Z", 283 | "shell.execute_reply.started": "2021-11-19T08:18:37.451077Z" 284 | }, 285 | "papermill": { 286 | "duration": 6.516251, 287 | "end_time": "2021-11-19T14:07:02.066032", 288 | "exception": false, 289 | "start_time": "2021-11-19T14:06:55.549781", 290 | "status": "completed" 291 | }, 292 | "tags": [] 293 | }, 294 | "outputs": [ 295 | { 296 | "name": "stdout", 297 | "output_type": "stream", 298 | "text": [ 299 | "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", 300 | "To disable this warning, you can either:\n", 301 | "\t- Avoid using `tokenizers` before the fork if possible\n", 302 | "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", 303 | "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", 304 | "To disable this warning, you can either:\n", 305 | "\t- Avoid using `tokenizers` before the fork if possible\n", 306 | "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", 307 | "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. 
Disabling parallelism to avoid deadlocks...\n", 308 | "To disable this warning, you can either:\n", 309 | "\t- Avoid using `tokenizers` before the fork if possible\n", 310 | "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" 311 | ] 312 | } 313 | ], 314 | "source": [ 315 | "training_args = TrainingArguments(\n", 316 | " output_dir='/kaggle/working/',\n", 317 | " overwrite_output_dir=True,\n", 318 | " num_train_epochs=7,\n", 319 | " per_device_train_batch_size=32,\n", 320 | " save_steps=10_000,\n", 321 | " save_total_limit=2,\n", " report_to='none', # disable wandb and other auto-detected loggers; wandb otherwise prompts for an API key, which a non-interactive kernel cannot answer\n", 322 | ")\n", 323 | "\n", 324 | "trainer = Trainer(\n", 325 | " model=model,\n", 326 | " args=training_args,\n", 327 | " data_collator=data_collator,\n", 328 | " train_dataset=dataset,\n", 329 | "# prediction_loss_only=True,\n", 330 | ")" 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": 8, 336 | "id": "d0536f16", 337 | "metadata": { 338 | "execution": { 339 | "iopub.execute_input": "2021-11-19T14:07:02.296536Z", 340 | "iopub.status.busy": "2021-11-19T14:07:02.108451Z", 341 | "iopub.status.idle": "2021-11-19T14:07:06.244558Z", 342 | "shell.execute_reply": "2021-11-19T14:07:06.242409Z", 343 | "shell.execute_reply.started": "2021-11-19T08:18:38.086060Z" 344 | }, 345 | "papermill": { 346 | "duration": 4.158679, 347 | "end_time": "2021-11-19T14:07:06.244963", 348 | "exception": true, 349 | "start_time": "2021-11-19T14:07:02.086284", 350 | "status": "failed" 351 | }, 352 | "tags": [] 353 | }, 354 | "outputs": [ 355 | { 356 | "name": "stderr", 357 | "output_type": "stream", 358 | "text": [ 359 | "\u001b[34m\u001b[1mwandb\u001b[0m: You can find your API key in your browser here: https://wandb.ai/authorize\n", 360 | "Traceback (most recent call last):\n", 361 | " File \"/opt/conda/lib/python3.7/site-packages/wandb/sdk/wandb_init.py\", line 867, in init\n", 362 | " wi.setup(kwargs)\n", 363 | " File \"/opt/conda/lib/python3.7/site-packages/wandb/sdk/wandb_init.py\", line 171, in setup\n", 
364 | " _silent=(settings._quiet or settings._silent) is True,\n", 365 | " File \"/opt/conda/lib/python3.7/site-packages/wandb/sdk/wandb_login.py\", line 274, in _login\n", 366 | " wlogin.prompt_api_key()\n", 367 | " File \"/opt/conda/lib/python3.7/site-packages/wandb/sdk/wandb_login.py\", line 202, in prompt_api_key\n", 368 | " key, status = self._prompt_api_key()\n", 369 | " File \"/opt/conda/lib/python3.7/site-packages/wandb/sdk/wandb_login.py\", line 185, in _prompt_api_key\n", 370 | " no_create=self._settings.force if self._settings else None,\n", 371 | " File \"/opt/conda/lib/python3.7/site-packages/wandb/sdk/lib/apikey.py\", line 123, in prompt_api_key\n", 372 | " key = input_callback(api_ask).strip()\n", 373 | " File \"/opt/conda/lib/python3.7/site-packages/ipykernel/kernelbase.py\", line 978, in getpass\n", 374 | " \"getpass was called, but this frontend does not support input requests.\"\n", 375 | "IPython.core.error.StdinNotImplementedError: getpass was called, but this frontend does not support input requests.\n", 376 | "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[32m\u001b[41mERROR\u001b[0m Abnormal program exit\n" 377 | ] 378 | }, 379 | { 380 | "ename": "Exception", 381 | "evalue": "problem", 382 | "output_type": "error", 383 | "traceback": [ 384 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 385 | "\u001b[0;31mStdinNotImplementedError\u001b[0m Traceback (most recent call last)", 386 | "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/wandb/sdk/wandb_init.py\u001b[0m in \u001b[0;36minit\u001b[0;34m(job_type, dir, config, project, entity, reinit, tags, group, name, notes, magic, config_exclude_keys, config_include_keys, anonymous, mode, allow_val_change, resume, force, tensorboard, sync_tensorboard, monitor_gym, save_code, id, settings)\u001b[0m\n\u001b[1;32m 866\u001b[0m \u001b[0mwi\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0m_WandbInit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 867\u001b[0;31m \u001b[0mwi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msetup\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 868\u001b[0m \u001b[0mexcept_exit\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mwi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msettings\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_except_exit\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 387 | "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/wandb/sdk/wandb_init.py\u001b[0m in \u001b[0;36msetup\u001b[0;34m(self, kwargs)\u001b[0m\n\u001b[1;32m 170\u001b[0m \u001b[0m_disable_warning\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 171\u001b[0;31m \u001b[0m_silent\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msettings\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_quiet\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0msettings\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_silent\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 172\u001b[0m )\n", 388 | "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/wandb/sdk/wandb_login.py\u001b[0m in \u001b[0;36m_login\u001b[0;34m(anonymous, key, relogin, host, force, timeout, _backend, _silent, _disable_warning)\u001b[0m\n\u001b[1;32m 273\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 274\u001b[0;31m 
\u001b[0mwlogin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprompt_api_key\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 275\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 389 | "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/wandb/sdk/wandb_login.py\u001b[0m in \u001b[0;36mprompt_api_key\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 201\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mprompt_api_key\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 202\u001b[0;31m \u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstatus\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_prompt_api_key\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 203\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mstatus\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mApiKeyStatus\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mNOTTY\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 390 | "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/wandb/sdk/wandb_login.py\u001b[0m in \u001b[0;36m_prompt_api_key\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 184\u001b[0m \u001b[0mno_offline\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_settings\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforce\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_settings\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 185\u001b[0;31m \u001b[0mno_create\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_settings\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforce\u001b[0m \u001b[0;32mif\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_settings\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 186\u001b[0m )\n", 391 | "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/wandb/sdk/lib/apikey.py\u001b[0m in \u001b[0;36mprompt_api_key\u001b[0;34m(settings, api, input_callback, browser_callback, no_offline, no_create, local)\u001b[0m\n\u001b[1;32m 122\u001b[0m )\n\u001b[0;32m--> 123\u001b[0;31m \u001b[0mkey\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput_callback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mapi_ask\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstrip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 124\u001b[0m \u001b[0mwrite_key\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msettings\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mapi\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mapi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 392 | "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/ipykernel/kernelbase.py\u001b[0m in \u001b[0;36mgetpass\u001b[0;34m(self, prompt, stream)\u001b[0m\n\u001b[1;32m 977\u001b[0m raise StdinNotImplementedError(\n\u001b[0;32m--> 978\u001b[0;31m \u001b[0;34m\"getpass was called, but this frontend does not support input requests.\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 979\u001b[0m )\n", 393 | "\u001b[0;31mStdinNotImplementedError\u001b[0m: getpass was called, but this frontend does not support input requests.", 394 | "\nThe above exception was the direct cause of the following exception:\n", 395 | "\u001b[0;31mException\u001b[0m Traceback (most recent call last)", 396 | "\u001b[0;32m/tmp/ipykernel_24/638423469.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m 
\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'/kaggle/working/'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 397 | "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/transformers/trainer.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, resume_from_checkpoint, trial, **kwargs)\u001b[0m\n\u001b[1;32m 1067\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1068\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1069\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontrol\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcallback_handler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_train_begin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontrol\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1070\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1071\u001b[0m \u001b[0;31m# Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 398 | "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/transformers/trainer_callback.py\u001b[0m in \u001b[0;36mon_train_begin\u001b[0;34m(self, args, state, control)\u001b[0m\n\u001b[1;32m 338\u001b[0m \u001b[0;32mdef\u001b[0m 
\u001b[0mon_train_begin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTrainingArguments\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTrainerState\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontrol\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTrainerControl\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 339\u001b[0m \u001b[0mcontrol\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshould_training_stop\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 340\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcall_event\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"on_train_begin\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontrol\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 341\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 342\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mon_train_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTrainingArguments\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTrainerState\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontrol\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTrainerControl\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 399 | "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/transformers/trainer_callback.py\u001b[0m in \u001b[0;36mcall_event\u001b[0;34m(self, event, args, state, control, **kwargs)\u001b[0m\n\u001b[1;32m 386\u001b[0m 
\u001b[0mtrain_dataloader\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_dataloader\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 387\u001b[0m \u001b[0meval_dataloader\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meval_dataloader\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 388\u001b[0;31m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 389\u001b[0m )\n\u001b[1;32m 390\u001b[0m \u001b[0;31m# A Callback can skip the return of `control` if it doesn't change it.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 400 | "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/transformers/integrations.py\u001b[0m in \u001b[0;36mon_train_begin\u001b[0;34m(self, args, state, control, model, **kwargs)\u001b[0m\n\u001b[1;32m 627\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_wandb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfinish\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 628\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_initialized\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 629\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msetup\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 630\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 631\u001b[0m \u001b[0;32mdef\u001b[0m 
\u001b[0mon_train_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontrol\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtokenizer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 401 | "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/transformers/integrations.py\u001b[0m in \u001b[0;36msetup\u001b[0;34m(self, args, state, model, **kwargs)\u001b[0m\n\u001b[1;32m 604\u001b[0m \u001b[0mproject\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgetenv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"WANDB_PROJECT\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"huggingface\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 605\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrun_name\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 606\u001b[0;31m \u001b[0;34m**\u001b[0m\u001b[0minit_args\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 607\u001b[0m )\n\u001b[1;32m 608\u001b[0m \u001b[0;31m# add config parameters (run may have been created manually)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 402 | "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/wandb/sdk/wandb_init.py\u001b[0m in \u001b[0;36minit\u001b[0;34m(job_type, dir, config, project, entity, reinit, tags, group, name, notes, magic, config_exclude_keys, config_include_keys, anonymous, mode, allow_val_change, resume, force, tensorboard, sync_tensorboard, monitor_gym, 
save_code, id, settings)\u001b[0m\n\u001b[1;32m 906\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mexcept_exit\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 907\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_exit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 908\u001b[0;31m \u001b[0msix\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mraise_from\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mException\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"problem\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0merror_seen\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 909\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mrun\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 403 | "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/six.py\u001b[0m in \u001b[0;36mraise_from\u001b[0;34m(value, from_value)\u001b[0m\n", 404 | "\u001b[0;31mException\u001b[0m: problem" 405 | ] 406 | } 407 | ], 408 | "source": [ 409 | "trainer.train()\n", 410 | "trainer.save_model('/kaggle/working/')" 411 | ] 412 | }, 413 | { 414 | "cell_type": "code", 415 | "execution_count": null, 416 | "id": "586d6c33", 417 | "metadata": { 418 | "papermill": { 419 | "duration": null, 420 | "end_time": null, 421 | "exception": null, 422 | "start_time": null, 423 | "status": "pending" 424 | }, 425 | "tags": [] 426 | }, 427 | "outputs": [], 428 | "source": [ 429 | "MAX_LEN = 512\n", 430 | "TRAIN_BATCH_SIZE = 8\n", 431 | "VALID_BATCH_SIZE = 4\n", 432 | "LEARNING_RATE = 1e-05\n", 433 | "\n", 434 | "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', truncation=True, do_lower_case=True)" 435 | ] 436 | }, 437 | { 438 | "cell_type": "code", 439 | "execution_count": null, 440 | "id": "62ccd6e4", 441 | "metadata": { 442 | "papermill": { 443 | "duration": null, 444 | 
"end_time": null, 445 | "exception": null, 446 | "start_time": null, 447 | "status": "pending" 448 | }, 449 | "tags": [] 450 | }, 451 | "outputs": [], 452 | "source": [ 453 | "class QueryData(Dataset):\n", 454 | " def __init__(self, dataframe, tokenizer, max_len):\n", 455 | " self.tokenizer = tokenizer\n", 456 | " self.text = dataframe['query']\n", 457 | " self.targets = dataframe['label']\n", 458 | " self.max_len = max_len\n", 459 | "\n", 460 | " def __len__(self):\n", 461 | " return len(self.text)\n", 462 | "\n", 463 | " def __getitem__(self, index):\n", 464 | " text = str(self.text[index])\n", 465 | " text = \" \".join(text.split())\n", 466 | "\n", 467 | " inputs = self.tokenizer.encode_plus(\n", 468 | " text,\n", 469 | " None,\n", 470 | " add_special_tokens=True,\n", 471 | " max_length=self.max_len,\n", 472 | " pad_to_max_length=True,\n", 473 | " return_token_type_ids=True\n", 474 | " )\n", 475 | " ids = inputs['input_ids']\n", 476 | " mask = inputs['attention_mask']\n", 477 | " token_type_ids = inputs[\"token_type_ids\"]\n", 478 | "\n", 479 | "\n", 480 | " return {\n", 481 | " 'ids': torch.tensor(ids, dtype=torch.long),\n", 482 | " 'mask': torch.tensor(mask, dtype=torch.long),\n", 483 | " 'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),\n", 484 | " 'targets': torch.tensor(self.targets[index], dtype=torch.float)\n", 485 | " }" 486 | ] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "execution_count": null, 491 | "id": "1e08f85d", 492 | "metadata": { 493 | "papermill": { 494 | "duration": null, 495 | "end_time": null, 496 | "exception": null, 497 | "start_time": null, 498 | "status": "pending" 499 | }, 500 | "tags": [] 501 | }, 502 | "outputs": [], 503 | "source": [ 504 | "print(\"Train Dataset: {}\".format(train_data.shape))\n", 505 | "print(\"Validation Dataset: {}\".format(val_data.shape))\n", 506 | "print(\"Test Dataset: {}\".format(test_data.shape))\n", 507 | "\n", 508 | "training_set = QueryData(train_data, tokenizer, MAX_LEN)\n", 509 | 
"val_set = QueryData(val_data, tokenizer, MAX_LEN)" 510 | ] 511 | }, 512 | { 513 | "cell_type": "code", 514 | "execution_count": null, 515 | "id": "88eddc56", 516 | "metadata": { 517 | "papermill": { 518 | "duration": null, 519 | "end_time": null, 520 | "exception": null, 521 | "start_time": null, 522 | "status": "pending" 523 | }, 524 | "tags": [] 525 | }, 526 | "outputs": [], 527 | "source": [ 528 | "train_params = {'batch_size': TRAIN_BATCH_SIZE,\n", 529 | " 'shuffle': True,\n", 530 | " 'num_workers': 0\n", 531 | " }\n", 532 | "\n", 533 | "test_params = {'batch_size': VALID_BATCH_SIZE,\n", 534 | " 'shuffle': True,\n", 535 | " 'num_workers': 0\n", 536 | " }\n", 537 | "\n", 538 | "training_loader = DataLoader(training_set, **train_params)\n", 539 | "val_loader = DataLoader(val_set, **test_params)" 540 | ] 541 | }, 542 | { 543 | "cell_type": "code", 544 | "execution_count": null, 545 | "id": "2c315701", 546 | "metadata": { 547 | "papermill": { 548 | "duration": null, 549 | "end_time": null, 550 | "exception": null, 551 | "start_time": null, 552 | "status": "pending" 553 | }, 554 | "tags": [] 555 | }, 556 | "outputs": [], 557 | "source": [ 558 | "class BertClass(torch.nn.Module):\n", 559 | " def __init__(self):\n", 560 | " super(BertClass, self).__init__()\n", 561 | " self.l1 = BertModel.from_pretrained(\"pytorch_model.bin\")\n", 562 | " self.pre_classifier = torch.nn.Linear(768, 768)\n", 563 | " self.dropout = torch.nn.Dropout(0.1)\n", 564 | " self.classifier = torch.nn.Linear(768, 2)\n", 565 | " self.relu = torch.nn.ReLU()\n", 566 | "\n", 567 | " def forward(self, input_ids, attention_mask, token_type_ids):\n", 568 | " output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)\n", 569 | " hidden_state = output_1[0]\n", 570 | " pooler = hidden_state[:, 0]\n", 571 | " pooler = self.pre_classifier(pooler)\n", 572 | " pooler = self.relu(pooler)\n", 573 | " pooler = self.dropout(pooler)\n", 574 | " output = 
self.classifier(pooler)\n", 575 | " return output" 576 | ] 577 | }, 578 | { 579 | "cell_type": "code", 580 | "execution_count": null, 581 | "id": "8c5a9896", 582 | "metadata": { 583 | "papermill": { 584 | "duration": null, 585 | "end_time": null, 586 | "exception": null, 587 | "start_time": null, 588 | "status": "pending" 589 | }, 590 | "tags": [] 591 | }, 592 | "outputs": [], 593 | "source": [ 594 | "model = BertClass()\n", 595 | "model.to(device)" 596 | ] 597 | }, 598 | { 599 | "cell_type": "code", 600 | "execution_count": null, 601 | "id": "a81e8648", 602 | "metadata": { 603 | "papermill": { 604 | "duration": null, 605 | "end_time": null, 606 | "exception": null, 607 | "start_time": null, 608 | "status": "pending" 609 | }, 610 | "tags": [] 611 | }, 612 | "outputs": [], 613 | "source": [ 614 | "# Creating the loss function and optimizer\n", 615 | "loss_function = torch.nn.CrossEntropyLoss()\n", 616 | "optimizer = torch.optim.Adam(params = model.parameters(), lr=LEARNING_RATE)" 617 | ] 618 | }, 619 | { 620 | "cell_type": "code", 621 | "execution_count": null, 622 | "id": "fd03c4c6", 623 | "metadata": { 624 | "papermill": { 625 | "duration": null, 626 | "end_time": null, 627 | "exception": null, 628 | "start_time": null, 629 | "status": "pending" 630 | }, 631 | "tags": [] 632 | }, 633 | "outputs": [], 634 | "source": [ 635 | "def calcuate_accuracy(preds, targets):\n", 636 | " n_correct = (preds==targets).sum().item()\n", 637 | " return n_correct" 638 | ] 639 | }, 640 | { 641 | "cell_type": "code", 642 | "execution_count": null, 643 | "id": "6b3cb621", 644 | "metadata": { 645 | "papermill": { 646 | "duration": null, 647 | "end_time": null, 648 | "exception": null, 649 | "start_time": null, 650 | "status": "pending" 651 | }, 652 | "tags": [] 653 | }, 654 | "outputs": [], 655 | "source": [ 656 | "def train(epoch, training_loader):\n", 657 | " tr_loss = 0\n", 658 | " n_correct = 0\n", 659 | " nb_tr_steps = 0\n", 660 | " nb_tr_examples = 0\n", 661 | " model.train()\n", 
662 | " for _,data in tqdm(enumerate(training_loader, 0)):\n", 663 | " ids = data['ids'].to(device, dtype = torch.long)\n", 664 | " mask = data['mask'].to(device, dtype = torch.long)\n", 665 | " token_type_ids = data['token_type_ids'].to(device, dtype = torch.long)\n", 666 | " targets = data['targets'].to(device, dtype = torch.long)\n", 667 | "\n", 668 | " outputs = model(ids, mask, token_type_ids)\n", 669 | " loss = loss_function(outputs, targets)\n", 670 | " tr_loss += loss.item()\n", 671 | " big_val, big_idx = torch.max(outputs.data, dim=1)\n", 672 | " n_correct += calcuate_accuracy(big_idx, targets)\n", 673 | "\n", 674 | " nb_tr_steps += 1\n", 675 | " nb_tr_examples+=targets.size(0)\n", 676 | " \n", 677 | " if _%500==0:\n", 678 | " loss_step = tr_loss/nb_tr_steps\n", 679 | " accu_step = (n_correct*100)/nb_tr_examples \n", 680 | " print(f\"Training Loss per 500 steps: {loss_step}\")\n", 681 | " print(f\"Training Accuracy per 500 steps: {accu_step}\")\n", 682 | "\n", 683 | " optimizer.zero_grad()\n", 684 | " loss.backward()\n", 685 | " # # When using GPU\n", 686 | " optimizer.step()\n", 687 | "\n", 688 | " print(f'The Total Accuracy for Epoch {epoch}: {(n_correct*100)/nb_tr_examples}')\n", 689 | " epoch_loss = tr_loss/nb_tr_steps\n", 690 | " epoch_accu = (n_correct*100)/nb_tr_examples\n", 691 | " print(f\"Training Loss Epoch: {epoch_loss}\")\n", 692 | " print(f\"Training Accuracy Epoch: {epoch_accu}\")\n", 693 | "\n", 694 | " return " 695 | ] 696 | }, 697 | { 698 | "cell_type": "code", 699 | "execution_count": null, 700 | "id": "e1b5902d", 701 | "metadata": { 702 | "papermill": { 703 | "duration": null, 704 | "end_time": null, 705 | "exception": null, 706 | "start_time": null, 707 | "status": "pending" 708 | }, 709 | "tags": [] 710 | }, 711 | "outputs": [], 712 | "source": [ 713 | "EPOCHS = 2\n", 714 | "for epoch in range(EPOCHS):\n", 715 | " train(epoch, training_loader)" 716 | ] 717 | }, 718 | { 719 | "cell_type": "code", 720 | "execution_count": null, 721 | 
"id": "7f9a0369", 722 | "metadata": { 723 | "papermill": { 724 | "duration": null, 725 | "end_time": null, 726 | "exception": null, 727 | "start_time": null, 728 | "status": "pending" 729 | }, 730 | "tags": [] 731 | }, 732 | "outputs": [], 733 | "source": [ 734 | "def valid(model, testing_loader):\n", 735 | " model.eval()\n", 736 | " n_correct = 0; n_wrong = 0; total = 0; tr_loss=0; nb_tr_steps=0; nb_tr_examples=0\n", 737 | " with torch.no_grad():\n", 738 | " for _, data in tqdm(enumerate(testing_loader, 0)):\n", 739 | " ids = data['ids'].to(device, dtype = torch.long)\n", 740 | " mask = data['mask'].to(device, dtype = torch.long)\n", 741 | " token_type_ids = data['token_type_ids'].to(device, dtype=torch.long)\n", 742 | " targets = data['targets'].to(device, dtype = torch.long)\n", 743 | " outputs = model(ids, mask, token_type_ids)\n", 744 | " loss = loss_function(outputs, targets)\n", 745 | " tr_loss += loss.item()\n", 746 | " big_val, big_idx = torch.max(outputs.data, dim=1)\n", 747 | " n_correct += calcuate_accuracy(big_idx, targets)\n", 748 | "\n", 749 | " nb_tr_steps += 1\n", 750 | " nb_tr_examples+=targets.size(0)\n", 751 | " \n", 752 | " if _%5000==0:\n", 753 | " loss_step = tr_loss/nb_tr_steps\n", 754 | " accu_step = (n_correct*100)/nb_tr_examples\n", 755 | " print(f\"Validation Loss per 100 steps: {loss_step}\")\n", 756 | " print(f\"Validation Accuracy per 100 steps: {accu_step}\")\n", 757 | " epoch_loss = tr_loss/nb_tr_steps\n", 758 | " epoch_accu = (n_correct*100)/nb_tr_examples\n", 759 | " print(f\"Validation Loss Epoch: {epoch_loss}\")\n", 760 | " print(f\"Validation Accuracy Epoch: {epoch_accu}\")\n", 761 | " \n", 762 | " return epoch_accu\n" 763 | ] 764 | }, 765 | { 766 | "cell_type": "code", 767 | "execution_count": null, 768 | "id": "0e777814", 769 | "metadata": { 770 | "papermill": { 771 | "duration": null, 772 | "end_time": null, 773 | "exception": null, 774 | "start_time": null, 775 | "status": "pending" 776 | }, 777 | "tags": [] 778 | }, 779 | 
"outputs": [], 780 | "source": [ 781 | "acc = valid(model, val_loader)\n", 782 | "print(\"Accuracy on validation data = %0.2f%%\" % acc)" 783 | ] 784 | }, 785 | { 786 | "cell_type": "code", 787 | "execution_count": null, 788 | "id": "8d8cc3d5", 789 | "metadata": { 790 | "papermill": { 791 | "duration": null, 792 | "end_time": null, 793 | "exception": null, 794 | "start_time": null, 795 | "status": "pending" 796 | }, 797 | "tags": [] 798 | }, 799 | "outputs": [], 800 | "source": [ 801 | "class QueryData(Dataset):\n", 802 | " def __init__(self, dataframe, tokenizer, max_len):\n", 803 | " self.tokenizer = tokenizer\n", 804 | " self.data = dataframe\n", 805 | " self.text = dataframe['query']\n", 806 | "# self.targets = self.data.citation_influence_label\n", 807 | " self.max_len = max_len\n", 808 | "\n", 809 | " def __len__(self):\n", 810 | " return len(self.text)\n", 811 | "\n", 812 | " def __getitem__(self, index):\n", 813 | " text = str(self.text[index])\n", 814 | " text = \" \".join(text.split())\n", 815 | "\n", 816 | " inputs = self.tokenizer.encode_plus(\n", 817 | " text,\n", 818 | " None,\n", 819 | " add_special_tokens=True,\n", 820 | " max_length=self.max_len,\n", 821 | " pad_to_max_length=True,\n", 822 | " return_token_type_ids=True\n", 823 | " )\n", 824 | " ids = inputs['input_ids']\n", 825 | " mask = inputs['attention_mask']\n", 826 | " token_type_ids = inputs[\"token_type_ids\"]\n", 827 | "\n", 828 | "\n", 829 | " return {\n", 830 | " 'ids': torch.tensor(ids, dtype=torch.long),\n", 831 | " 'mask': torch.tensor(mask, dtype=torch.long),\n", 832 | " 'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long)\n", 833 | "# 'targets': torch.tensor(self.targets[index], dtype=torch.float)\n", 834 | " }" 835 | ] 836 | }, 837 | { 838 | "cell_type": "code", 839 | "execution_count": null, 840 | "id": "4f2f335b", 841 | "metadata": { 842 | "papermill": { 843 | "duration": null, 844 | "end_time": null, 845 | "exception": null, 846 | "start_time": null, 847 | "status": 
"pending" 848 | }, 849 | "tags": [] 850 | }, 851 | "outputs": [], 852 | "source": [ 853 | "test_data.head(2)" 854 | ] 855 | }, 856 | { 857 | "cell_type": "code", 858 | "execution_count": null, 859 | "id": "037db795", 860 | "metadata": { 861 | "papermill": { 862 | "duration": null, 863 | "end_time": null, 864 | "exception": null, 865 | "start_time": null, 866 | "status": "pending" 867 | }, 868 | "tags": [] 869 | }, 870 | "outputs": [], 871 | "source": [ 872 | "data_to_test = QueryData(test_data[['query']], tokenizer, MAX_LEN)\n", 873 | "\n", 874 | "test_params = {'batch_size': 4,\n", 875 | " 'shuffle': False,\n", 876 | " 'num_workers': 0\n", 877 | " }\n", 878 | "\n", 879 | "testing_loader_f = DataLoader(data_to_test, **test_params)" 880 | ] 881 | }, 882 | { 883 | "cell_type": "code", 884 | "execution_count": null, 885 | "id": "d7a98abe", 886 | "metadata": { 887 | "papermill": { 888 | "duration": null, 889 | "end_time": null, 890 | "exception": null, 891 | "start_time": null, 892 | "status": "pending" 893 | }, 894 | "tags": [] 895 | }, 896 | "outputs": [], 897 | "source": [ 898 | "def test(model, testing_loader):\n", 899 | " res = []\n", 900 | " model.eval()\n", 901 | " n_correct = 0; n_wrong = 0; total = 0; tr_loss=0; nb_tr_steps=0; nb_tr_examples=0\n", 902 | " with torch.no_grad():\n", 903 | " for _, data in tqdm(enumerate(testing_loader, 0)):\n", 904 | " ids = data['ids'].to(device, dtype = torch.long)\n", 905 | " mask = data['mask'].to(device, dtype = torch.long)\n", 906 | " token_type_ids = data['token_type_ids'].to(device, dtype=torch.long)\n", 907 | "# targets = data['targets'].to(device, dtype = torch.long)\n", 908 | " outputs = model(ids, mask, token_type_ids)\n", 909 | " big_val, big_idx = torch.max(outputs, dim=1)\n", 910 | " res.extend(big_idx.tolist())\n", 911 | " \n", 912 | " return res" 913 | ] 914 | }, 915 | { 916 | "cell_type": "code", 917 | "execution_count": null, 918 | "id": "9254ae08", 919 | "metadata": { 920 | "papermill": { 921 | "duration": 
null, 922 | "end_time": null, 923 | "exception": null, 924 | "start_time": null, 925 | "status": "pending" 926 | }, 927 | "tags": [] 928 | }, 929 | "outputs": [], 930 | "source": [ 931 | "res = test(model, testing_loader_f)" 932 | ] 933 | }, 934 | { 935 | "cell_type": "code", 936 | "execution_count": null, 937 | "id": "6e39e57c", 938 | "metadata": { 939 | "papermill": { 940 | "duration": null, 941 | "end_time": null, 942 | "exception": null, 943 | "start_time": null, 944 | "status": "pending" 945 | }, 946 | "tags": [] 947 | }, 948 | "outputs": [], 949 | "source": [ 950 | "test_data['label'] = [1 if label>=0.8 else 0 for label in test_data['label']]" 951 | ] 952 | }, 953 | { 954 | "cell_type": "code", 955 | "execution_count": null, 956 | "id": "14266fb0", 957 | "metadata": { 958 | "papermill": { 959 | "duration": null, 960 | "end_time": null, 961 | "exception": null, 962 | "start_time": null, 963 | "status": "pending" 964 | }, 965 | "tags": [] 966 | }, 967 | "outputs": [], 968 | "source": [ 969 | "correct = [1 if pred==lab else 0 for pred, lab in zip(res, test_data['label'].tolist())]\n", 970 | "print('accuracy on test set is - ', sum(correct)/len(test_data['label'].tolist())*100, '%')" 971 | ] 972 | }, 973 | { 974 | "cell_type": "code", 975 | "execution_count": null, 976 | "id": "f5cf1ca4", 977 | "metadata": { 978 | "papermill": { 979 | "duration": null, 980 | "end_time": null, 981 | "exception": null, 982 | "start_time": null, 983 | "status": "pending" 984 | }, 985 | "tags": [] 986 | }, 987 | "outputs": [], 988 | "source": [ 989 | "output_model_file = 'pytorch_bert_pretrained.bin'\n", 990 | "output_vocab_file = './'\n", 991 | "\n", 992 | "model_to_save = model\n", 993 | "torch.save(model_to_save, output_model_file)\n", 994 | "tokenizer.save_vocabulary(output_vocab_file)\n", 995 | "\n", 996 | "print('All files saved')\n", 997 | "print('This tutorial is completed')" 998 | ] 999 | } 1000 | ], 1001 | "metadata": { 1002 | "kernelspec": { 1003 | "display_name": "Python 
3", 1004 | "language": "python", 1005 | "name": "python3" 1006 | }, 1007 | "language_info": { 1008 | "codemirror_mode": { 1009 | "name": "ipython", 1010 | "version": 3 1011 | }, 1012 | "file_extension": ".py", 1013 | "mimetype": "text/x-python", 1014 | "name": "python", 1015 | "nbconvert_exporter": "python", 1016 | "pygments_lexer": "ipython3", 1017 | "version": "3.7.10" 1018 | }, 1019 | "papermill": { 1020 | "default_parameters": {}, 1021 | "duration": 40.486835, 1022 | "end_time": "2021-11-19T14:07:09.957930", 1023 | "environment_variables": {}, 1024 | "exception": true, 1025 | "input_path": "__notebook__.ipynb", 1026 | "output_path": "__notebook__.ipynb", 1027 | "parameters": {}, 1028 | "start_time": "2021-11-19T14:06:29.471095", 1029 | "version": "2.3.3" 1030 | } 1031 | }, 1032 | "nbformat": 4, 1033 | "nbformat_minor": 5 1034 | } 1035 | -------------------------------------------------------------------------------- /sentiment_analysis_using_roberta.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "kernelspec": { 6 | "display_name": "Python 3", 7 | "language": "python", 8 | "name": "python3" 9 | }, 10 | "language_info": { 11 | "name": "python", 12 | "version": "3.6.6", 13 | "mimetype": "text/x-python", 14 | "codemirror_mode": { 15 | "name": "ipython", 16 | "version": 3 17 | }, 18 | "pygments_lexer": "ipython3", 19 | "nbconvert_exporter": "python", 20 | "file_extension": ".py" 21 | }, 22 | "colab": { 23 | "name": "sentiment-analysis-using-roberta.ipynb", 24 | "provenance": [], 25 | "collapsed_sections": [], 26 | "include_colab_link": true 27 | }, 28 | "accelerator": "GPU", 29 | "widgets": { 30 | "application/vnd.jupyter.widget-state+json": { 31 | "30864762e7f242c281b72862c5c08a33": { 32 | "model_module": "@jupyter-widgets/controls", 33 | "model_name": "HBoxModel", 34 | "state": { 35 | "_view_name": "HBoxView", 36 | "_dom_classes": [], 37 | 
"_model_name": "HBoxModel", 38 | "_view_module": "@jupyter-widgets/controls", 39 | "_model_module_version": "1.5.0", 40 | "_view_count": null, 41 | "_view_module_version": "1.5.0", 42 | "box_style": "", 43 | "layout": "IPY_MODEL_dd12a39995584ba79f0e786b370b1a99", 44 | "_model_module": "@jupyter-widgets/controls", 45 | "children": [ 46 | "IPY_MODEL_b89c9e76b5594a8ea601b9c5d2af4fa6", 47 | "IPY_MODEL_f65c7649640b4e87a7819a3da2f54fe0" 48 | ] 49 | } 50 | }, 51 | "dd12a39995584ba79f0e786b370b1a99": { 52 | "model_module": "@jupyter-widgets/base", 53 | "model_name": "LayoutModel", 54 | "state": { 55 | "_view_name": "LayoutView", 56 | "grid_template_rows": null, 57 | "right": null, 58 | "justify_content": null, 59 | "_view_module": "@jupyter-widgets/base", 60 | "overflow": null, 61 | "_model_module_version": "1.2.0", 62 | "_view_count": null, 63 | "flex_flow": null, 64 | "width": null, 65 | "min_width": null, 66 | "border": null, 67 | "align_items": null, 68 | "bottom": null, 69 | "_model_module": "@jupyter-widgets/base", 70 | "top": null, 71 | "grid_column": null, 72 | "overflow_y": null, 73 | "overflow_x": null, 74 | "grid_auto_flow": null, 75 | "grid_area": null, 76 | "grid_template_columns": null, 77 | "flex": null, 78 | "_model_name": "LayoutModel", 79 | "justify_items": null, 80 | "grid_row": null, 81 | "max_height": null, 82 | "align_content": null, 83 | "visibility": null, 84 | "align_self": null, 85 | "height": null, 86 | "min_height": null, 87 | "padding": null, 88 | "grid_auto_rows": null, 89 | "grid_gap": null, 90 | "max_width": null, 91 | "order": null, 92 | "_view_module_version": "1.2.0", 93 | "grid_template_areas": null, 94 | "object_position": null, 95 | "object_fit": null, 96 | "grid_auto_columns": null, 97 | "margin": null, 98 | "display": null, 99 | "left": null 100 | } 101 | }, 102 | "b89c9e76b5594a8ea601b9c5d2af4fa6": { 103 | "model_module": "@jupyter-widgets/controls", 104 | "model_name": "FloatProgressModel", 105 | "state": { 106 | "_view_name": 
"ProgressView", 107 | "style": "IPY_MODEL_80a6b6c9c4d5436ebe3b90b791c6fd93", 108 | "_dom_classes": [], 109 | "description": "Downloading: 100%", 110 | "_model_name": "FloatProgressModel", 111 | "bar_style": "success", 112 | "max": 898823, 113 | "_view_module": "@jupyter-widgets/controls", 114 | "_model_module_version": "1.5.0", 115 | "value": 898823, 116 | "_view_count": null, 117 | "_view_module_version": "1.5.0", 118 | "orientation": "horizontal", 119 | "min": 0, 120 | "description_tooltip": null, 121 | "_model_module": "@jupyter-widgets/controls", 122 | "layout": "IPY_MODEL_8c1f6e94723842faa6bd3dcd9ff4ea82" 123 | } 124 | }, 125 | "f65c7649640b4e87a7819a3da2f54fe0": { 126 | "model_module": "@jupyter-widgets/controls", 127 | "model_name": "HTMLModel", 128 | "state": { 129 | "_view_name": "HTMLView", 130 | "style": "IPY_MODEL_39b5ca071fd3452e9d9145dd2b366da1", 131 | "_dom_classes": [], 132 | "description": "", 133 | "_model_name": "HTMLModel", 134 | "placeholder": "​", 135 | "_view_module": "@jupyter-widgets/controls", 136 | "_model_module_version": "1.5.0", 137 | "value": " 899k/899k [00:02<00:00, 333kB/s]", 138 | "_view_count": null, 139 | "_view_module_version": "1.5.0", 140 | "description_tooltip": null, 141 | "_model_module": "@jupyter-widgets/controls", 142 | "layout": "IPY_MODEL_f8dfd3ea6bb7413592115195bc6e0b83" 143 | } 144 | }, 145 | "80a6b6c9c4d5436ebe3b90b791c6fd93": { 146 | "model_module": "@jupyter-widgets/controls", 147 | "model_name": "ProgressStyleModel", 148 | "state": { 149 | "_view_name": "StyleView", 150 | "_model_name": "ProgressStyleModel", 151 | "description_width": "initial", 152 | "_view_module": "@jupyter-widgets/base", 153 | "_model_module_version": "1.5.0", 154 | "_view_count": null, 155 | "_view_module_version": "1.2.0", 156 | "bar_color": null, 157 | "_model_module": "@jupyter-widgets/controls" 158 | } 159 | }, 160 | "8c1f6e94723842faa6bd3dcd9ff4ea82": { 161 | "model_module": "@jupyter-widgets/base", 162 | "model_name": "LayoutModel", 
163 | "state": { 164 | "_view_name": "LayoutView", 165 | "grid_template_rows": null, 166 | "right": null, 167 | "justify_content": null, 168 | "_view_module": "@jupyter-widgets/base", 169 | "overflow": null, 170 | "_model_module_version": "1.2.0", 171 | "_view_count": null, 172 | "flex_flow": null, 173 | "width": null, 174 | "min_width": null, 175 | "border": null, 176 | "align_items": null, 177 | "bottom": null, 178 | "_model_module": "@jupyter-widgets/base", 179 | "top": null, 180 | "grid_column": null, 181 | "overflow_y": null, 182 | "overflow_x": null, 183 | "grid_auto_flow": null, 184 | "grid_area": null, 185 | "grid_template_columns": null, 186 | "flex": null, 187 | "_model_name": "LayoutModel", 188 | "justify_items": null, 189 | "grid_row": null, 190 | "max_height": null, 191 | "align_content": null, 192 | "visibility": null, 193 | "align_self": null, 194 | "height": null, 195 | "min_height": null, 196 | "padding": null, 197 | "grid_auto_rows": null, 198 | "grid_gap": null, 199 | "max_width": null, 200 | "order": null, 201 | "_view_module_version": "1.2.0", 202 | "grid_template_areas": null, 203 | "object_position": null, 204 | "object_fit": null, 205 | "grid_auto_columns": null, 206 | "margin": null, 207 | "display": null, 208 | "left": null 209 | } 210 | }, 211 | "39b5ca071fd3452e9d9145dd2b366da1": { 212 | "model_module": "@jupyter-widgets/controls", 213 | "model_name": "DescriptionStyleModel", 214 | "state": { 215 | "_view_name": "StyleView", 216 | "_model_name": "DescriptionStyleModel", 217 | "description_width": "", 218 | "_view_module": "@jupyter-widgets/base", 219 | "_model_module_version": "1.5.0", 220 | "_view_count": null, 221 | "_view_module_version": "1.2.0", 222 | "_model_module": "@jupyter-widgets/controls" 223 | } 224 | }, 225 | "f8dfd3ea6bb7413592115195bc6e0b83": { 226 | "model_module": "@jupyter-widgets/base", 227 | "model_name": "LayoutModel", 228 | "state": { 229 | "_view_name": "LayoutView", 230 | "grid_template_rows": null, 231 | 
"right": null, 232 | "justify_content": null, 233 | "_view_module": "@jupyter-widgets/base", 234 | "overflow": null, 235 | "_model_module_version": "1.2.0", 236 | "_view_count": null, 237 | "flex_flow": null, 238 | "width": null, 239 | "min_width": null, 240 | "border": null, 241 | "align_items": null, 242 | "bottom": null, 243 | "_model_module": "@jupyter-widgets/base", 244 | "top": null, 245 | "grid_column": null, 246 | "overflow_y": null, 247 | "overflow_x": null, 248 | "grid_auto_flow": null, 249 | "grid_area": null, 250 | "grid_template_columns": null, 251 | "flex": null, 252 | "_model_name": "LayoutModel", 253 | "justify_items": null, 254 | "grid_row": null, 255 | "max_height": null, 256 | "align_content": null, 257 | "visibility": null, 258 | "align_self": null, 259 | "height": null, 260 | "min_height": null, 261 | "padding": null, 262 | "grid_auto_rows": null, 263 | "grid_gap": null, 264 | "max_width": null, 265 | "order": null, 266 | "_view_module_version": "1.2.0", 267 | "grid_template_areas": null, 268 | "object_position": null, 269 | "object_fit": null, 270 | "grid_auto_columns": null, 271 | "margin": null, 272 | "display": null, 273 | "left": null 274 | } 275 | }, 276 | "611dfdca86f4498e8aa1491ed6ffb13d": { 277 | "model_module": "@jupyter-widgets/controls", 278 | "model_name": "HBoxModel", 279 | "state": { 280 | "_view_name": "HBoxView", 281 | "_dom_classes": [], 282 | "_model_name": "HBoxModel", 283 | "_view_module": "@jupyter-widgets/controls", 284 | "_model_module_version": "1.5.0", 285 | "_view_count": null, 286 | "_view_module_version": "1.5.0", 287 | "box_style": "", 288 | "layout": "IPY_MODEL_be4857f17c244fb39a771f2c97283fd5", 289 | "_model_module": "@jupyter-widgets/controls", 290 | "children": [ 291 | "IPY_MODEL_2fe41e1db18b4295a6907771462a0fce", 292 | "IPY_MODEL_0b29a9e1a275451bbc2114807532f91e" 293 | ] 294 | } 295 | }, 296 | "be4857f17c244fb39a771f2c97283fd5": { 297 | "model_module": "@jupyter-widgets/base", 298 | "model_name": 
"LayoutModel", 299 | "state": { 300 | "_view_name": "LayoutView", 301 | "grid_template_rows": null, 302 | "right": null, 303 | "justify_content": null, 304 | "_view_module": "@jupyter-widgets/base", 305 | "overflow": null, 306 | "_model_module_version": "1.2.0", 307 | "_view_count": null, 308 | "flex_flow": null, 309 | "width": null, 310 | "min_width": null, 311 | "border": null, 312 | "align_items": null, 313 | "bottom": null, 314 | "_model_module": "@jupyter-widgets/base", 315 | "top": null, 316 | "grid_column": null, 317 | "overflow_y": null, 318 | "overflow_x": null, 319 | "grid_auto_flow": null, 320 | "grid_area": null, 321 | "grid_template_columns": null, 322 | "flex": null, 323 | "_model_name": "LayoutModel", 324 | "justify_items": null, 325 | "grid_row": null, 326 | "max_height": null, 327 | "align_content": null, 328 | "visibility": null, 329 | "align_self": null, 330 | "height": null, 331 | "min_height": null, 332 | "padding": null, 333 | "grid_auto_rows": null, 334 | "grid_gap": null, 335 | "max_width": null, 336 | "order": null, 337 | "_view_module_version": "1.2.0", 338 | "grid_template_areas": null, 339 | "object_position": null, 340 | "object_fit": null, 341 | "grid_auto_columns": null, 342 | "margin": null, 343 | "display": null, 344 | "left": null 345 | } 346 | }, 347 | "2fe41e1db18b4295a6907771462a0fce": { 348 | "model_module": "@jupyter-widgets/controls", 349 | "model_name": "FloatProgressModel", 350 | "state": { 351 | "_view_name": "ProgressView", 352 | "style": "IPY_MODEL_115c8809853d410fac6e7f69af5a5488", 353 | "_dom_classes": [], 354 | "description": "Downloading: 100%", 355 | "_model_name": "FloatProgressModel", 356 | "bar_style": "success", 357 | "max": 456318, 358 | "_view_module": "@jupyter-widgets/controls", 359 | "_model_module_version": "1.5.0", 360 | "value": 456318, 361 | "_view_count": null, 362 | "_view_module_version": "1.5.0", 363 | "orientation": "horizontal", 364 | "min": 0, 365 | "description_tooltip": null, 366 | 
"_model_module": "@jupyter-widgets/controls", 367 | "layout": "IPY_MODEL_390827d7d2cb4b4fbfc0c022f015f7ed" 368 | } 369 | }, 370 | "0b29a9e1a275451bbc2114807532f91e": { 371 | "model_module": "@jupyter-widgets/controls", 372 | "model_name": "HTMLModel", 373 | "state": { 374 | "_view_name": "HTMLView", 375 | "style": "IPY_MODEL_5d305d4db08f47ba91461edb343874a4", 376 | "_dom_classes": [], 377 | "description": "", 378 | "_model_name": "HTMLModel", 379 | "placeholder": "​", 380 | "_view_module": "@jupyter-widgets/controls", 381 | "_model_module_version": "1.5.0", 382 | "value": " 456k/456k [00:00<00:00, 472kB/s]", 383 | "_view_count": null, 384 | "_view_module_version": "1.5.0", 385 | "description_tooltip": null, 386 | "_model_module": "@jupyter-widgets/controls", 387 | "layout": "IPY_MODEL_773af3ca0add4e7cac0a036fb8b55632" 388 | } 389 | }, 390 | "115c8809853d410fac6e7f69af5a5488": { 391 | "model_module": "@jupyter-widgets/controls", 392 | "model_name": "ProgressStyleModel", 393 | "state": { 394 | "_view_name": "StyleView", 395 | "_model_name": "ProgressStyleModel", 396 | "description_width": "initial", 397 | "_view_module": "@jupyter-widgets/base", 398 | "_model_module_version": "1.5.0", 399 | "_view_count": null, 400 | "_view_module_version": "1.2.0", 401 | "bar_color": null, 402 | "_model_module": "@jupyter-widgets/controls" 403 | } 404 | }, 405 | "390827d7d2cb4b4fbfc0c022f015f7ed": { 406 | "model_module": "@jupyter-widgets/base", 407 | "model_name": "LayoutModel", 408 | "state": { 409 | "_view_name": "LayoutView", 410 | "grid_template_rows": null, 411 | "right": null, 412 | "justify_content": null, 413 | "_view_module": "@jupyter-widgets/base", 414 | "overflow": null, 415 | "_model_module_version": "1.2.0", 416 | "_view_count": null, 417 | "flex_flow": null, 418 | "width": null, 419 | "min_width": null, 420 | "border": null, 421 | "align_items": null, 422 | "bottom": null, 423 | "_model_module": "@jupyter-widgets/base", 424 | "top": null, 425 | "grid_column": null, 
426 | "overflow_y": null, 427 | "overflow_x": null, 428 | "grid_auto_flow": null, 429 | "grid_area": null, 430 | "grid_template_columns": null, 431 | "flex": null, 432 | "_model_name": "LayoutModel", 433 | "justify_items": null, 434 | "grid_row": null, 435 | "max_height": null, 436 | "align_content": null, 437 | "visibility": null, 438 | "align_self": null, 439 | "height": null, 440 | "min_height": null, 441 | "padding": null, 442 | "grid_auto_rows": null, 443 | "grid_gap": null, 444 | "max_width": null, 445 | "order": null, 446 | "_view_module_version": "1.2.0", 447 | "grid_template_areas": null, 448 | "object_position": null, 449 | "object_fit": null, 450 | "grid_auto_columns": null, 451 | "margin": null, 452 | "display": null, 453 | "left": null 454 | } 455 | }, 456 | "5d305d4db08f47ba91461edb343874a4": { 457 | "model_module": "@jupyter-widgets/controls", 458 | "model_name": "DescriptionStyleModel", 459 | "state": { 460 | "_view_name": "StyleView", 461 | "_model_name": "DescriptionStyleModel", 462 | "description_width": "", 463 | "_view_module": "@jupyter-widgets/base", 464 | "_model_module_version": "1.5.0", 465 | "_view_count": null, 466 | "_view_module_version": "1.2.0", 467 | "_model_module": "@jupyter-widgets/controls" 468 | } 469 | }, 470 | "773af3ca0add4e7cac0a036fb8b55632": { 471 | "model_module": "@jupyter-widgets/base", 472 | "model_name": "LayoutModel", 473 | "state": { 474 | "_view_name": "LayoutView", 475 | "grid_template_rows": null, 476 | "right": null, 477 | "justify_content": null, 478 | "_view_module": "@jupyter-widgets/base", 479 | "overflow": null, 480 | "_model_module_version": "1.2.0", 481 | "_view_count": null, 482 | "flex_flow": null, 483 | "width": null, 484 | "min_width": null, 485 | "border": null, 486 | "align_items": null, 487 | "bottom": null, 488 | "_model_module": "@jupyter-widgets/base", 489 | "top": null, 490 | "grid_column": null, 491 | "overflow_y": null, 492 | "overflow_x": null, 493 | "grid_auto_flow": null, 494 | 
"grid_area": null, 495 | "grid_template_columns": null, 496 | "flex": null, 497 | "_model_name": "LayoutModel", 498 | "justify_items": null, 499 | "grid_row": null, 500 | "max_height": null, 501 | "align_content": null, 502 | "visibility": null, 503 | "align_self": null, 504 | "height": null, 505 | "min_height": null, 506 | "padding": null, 507 | "grid_auto_rows": null, 508 | "grid_gap": null, 509 | "max_width": null, 510 | "order": null, 511 | "_view_module_version": "1.2.0", 512 | "grid_template_areas": null, 513 | "object_position": null, 514 | "object_fit": null, 515 | "grid_auto_columns": null, 516 | "margin": null, 517 | "display": null, 518 | "left": null 519 | } 520 | } 521 | } 522 | } 523 | }, 524 | "cells": [ 525 | { 526 | "cell_type": "markdown", 527 | "metadata": { 528 | "id": "view-in-github", 529 | "colab_type": "text" 530 | }, 531 | "source": [ 532 | "\"Open" 533 | ] 534 | }, 535 | { 536 | "cell_type": "markdown", 537 | "metadata": { 538 | "id": "8OhO4xlwqExT" 539 | }, 540 | "source": [ 541 | "# Fine Tuning Roberta for Sentiment Analysis\n", 542 | "\n", 543 | "\n", 544 | "\n" 545 | ] 546 | }, 547 | { 548 | "cell_type": "markdown", 549 | "metadata": { 550 | "id": "WTdfPjhFqExX" 551 | }, 552 | "source": [ 553 | "### Introduction\n", 554 | "\n", 555 | "In this tutorial I will be fine-tuning a RoBERTa model for the **Sentiment Analysis** problem. \n", 556 | "\n", 557 | "#### Flow of the notebook\n", 558 | "\n", 559 | "The notebook is divided into separate sections to provide an organized walkthrough of the process used. This process can be modified for individual use cases. The sections are:\n", 560 | "\n", 561 | "1. [Importing Python Libraries and preparing the environment](#section01)\n", 562 | "2. [Importing and Pre-Processing the domain data](#section02)\n", 563 | "3. [Preparing the Dataset and Dataloader](#section03)\n", 564 | "4. [Creating the Neural Network for Fine Tuning](#section04)\n", 565 | "5.
[Fine Tuning the Model](#section05)\n", 566 | "6. [Validating the Model Performance](#section06)\n", 567 | "7. [Saving the model and artifacts for Inference in Future](#section07)\n", 568 | "\n", 569 | "#### Technical Details\n", 570 | "\n", 571 | "This script leverages multiple tools designed by other teams; details of the tools used are given below. Please ensure that these elements are present in your setup to successfully run this script.\n", 572 | "\n", 573 | " - Data: \n", 574 | "\t - I will be using the dataset available at [Kaggle Competition](https://www.kaggle.com/c/movie-review-sentiment-analysis-kernels-only)\n", 575 | "\t - I will be referring only to the first csv file from the data dump: `train.tsv`\n", 576 | "\n", 577 | " - Language Model Used:\n", 578 | "\t - The RoBERTa model was proposed in RoBERTa: A Robustly Optimized BERT Pretraining Approach by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. It is based on Google’s BERT model released in 2018.\n", 579 | "\t - [Blog-Post](https://ai.facebook.com/blog/roberta-an-optimized-method-for-pretraining-self-supervised-nlp-systems/)\n", 580 | "\t - [Research Paper](https://arxiv.org/pdf/1907.11692)\n", 581 | " - [Documentation for python](https://huggingface.co/transformers/model_doc/roberta.html)\n", 582 | "\n", 583 | "\n", 584 | " - Software and Hardware Requirements:\n", 585 | "\t - Python 3.6 and above\n", 586 | "\t - Pytorch, Transformers and All the stock Python ML Libraries\n", 587 | "\t - GPU enabled setup " 588 | ] 589 | }, 590 | { 591 | "cell_type": "markdown", 592 | "metadata": { 593 | "id": "97CEi-bdqExb" 594 | }, 595 | "source": [ 596 | "\n", 597 | "### Importing Python Libraries and preparing the environment\n", 598 | "\n", 599 | "At this step we will be importing the libraries and modules needed to run our script.
Libraries are:\n", 600 | "* Pandas\n", 601 | "* Pytorch\n", 602 | "* Pytorch Utils for Dataset and Dataloader\n", 603 | "* Transformers\n", 604 | "* tqdm\n", 605 | "* sklearn\n", 606 | "* RoBERTa Model and Tokenizer\n", 607 | "\n", 608 | "After that we will prepare the device for CUDA execution. This configuration is needed if you want to leverage an onboard GPU. " 609 | ] 610 | }, 611 | { 612 | "cell_type": "code", 613 | "metadata": { 614 | "id": "a-GlywkSFegL" 615 | }, 616 | "source": [ 617 | "!pip install transformers==3.0.2" 618 | ], 619 | "execution_count": null, 620 | "outputs": [] 621 | }, 622 | { 623 | "cell_type": "code", 624 | "metadata": { 625 | "trusted": true, 626 | "_uuid": "e7b5f5ab6f8f300c8900321a91b9340376c986f2", 627 | "id": "979OUro5Eac3" 628 | }, 629 | "source": [ 630 | "# Importing the libraries needed\n", 631 | "import pandas as pd\n", 632 | "import numpy as np\n", 633 | "from sklearn.model_selection import train_test_split\n", 634 | "import torch\n", 635 | "import seaborn as sns\n", 636 | "import transformers\n", 637 | "import json\n", 638 | "from tqdm import tqdm\n", 639 | "from torch.utils.data import Dataset, DataLoader\n", 640 | "from transformers import RobertaModel, RobertaTokenizer\n", 641 | "import logging\n", 642 | "logging.basicConfig(level=logging.ERROR)" 643 | ], 644 | "execution_count": null, 645 | "outputs": [] 646 | }, 647 | { 648 | "cell_type": "code", 649 | "metadata": { 650 | "id": "sb1Q5N6LGK7z" 651 | }, 652 | "source": [ 653 | "# Setting up the device for GPU usage\n", 654 | "\n", 655 | "from torch import cuda\n", 656 | "device = 'cuda' if cuda.is_available() else 'cpu'" 657 | ], 658 | "execution_count": null, 659 | "outputs": [] 660 | }, 661 | { 662 | "cell_type": "code", 663 | "metadata": { 664 | "_cell_guid": "79c7e3d0-c299-4dcb-8224-4455121ee9b0", 665 | "_uuid": "d629ff2d2480ee46fbb7e2d37f6b5fab8052498a", 666 | "trusted": true, 667 | "id": "J3FzcAlgEac8" 668 | }, 669 | "source": [ 670 | "train =
pd.read_csv('train.tsv', delimiter='\\t')" 671 | ], 672 | "execution_count": null, 673 | "outputs": [] 674 | }, 675 | { 676 | "cell_type": "code", 677 | "metadata": { 678 | "id": "TFIoIjucGjJw", 679 | "outputId": "52fda1fa-cb54-4eb2-c142-c90b5f958edf", 680 | "colab": { 681 | "base_uri": "https://localhost:8080/", 682 | "height": 34 683 | } 684 | }, 685 | "source": [ 686 | "train.shape" 687 | ], 688 | "execution_count": null, 689 | "outputs": [ 690 | { 691 | "output_type": "execute_result", 692 | "data": { 693 | "text/plain": [ 694 | "(156060, 4)" 695 | ] 696 | }, 697 | "metadata": { 698 | "tags": [] 699 | }, 700 | "execution_count": 15 701 | } 702 | ] 703 | }, 704 | { 705 | "cell_type": "code", 706 | "metadata": { 707 | "trusted": true, 708 | "_uuid": "c8dee062192ea016c0d306d3441ae2c573e2183c", 709 | "id": "aTsOsl4MEadB", 710 | "outputId": "ada765d8-2547-4196-9d73-c33c2b3bda39", 711 | "colab": { 712 | "base_uri": "https://localhost:8080/", 713 | "height": 204 714 | } 715 | }, 716 | "source": [ 717 | "train.head()" 718 | ], 719 | "execution_count": null, 720 | "outputs": [ 721 | { 722 | "output_type": "execute_result", 723 | "data": { 724 | "text/html": [ 725 | "
\n", 726 | "\n", 739 | "\n", 740 | " \n", 741 | " \n", 742 | " \n", 743 | " \n", 744 | " \n", 745 | " \n", 746 | " \n", 747 | " \n", 748 | " \n", 749 | " \n", 750 | " \n", 751 | " \n", 752 | " \n", 753 | " \n", 754 | " \n", 755 | " \n", 756 | " \n", 757 | " \n", 758 | " \n", 759 | " \n", 760 | " \n", 761 | " \n", 762 | " \n", 763 | " \n", 764 | " \n", 765 | " \n", 766 | " \n", 767 | " \n", 768 | " \n", 769 | " \n", 770 | " \n", 771 | " \n", 772 | " \n", 773 | " \n", 774 | " \n", 775 | " \n", 776 | " \n", 777 | " \n", 778 | " \n", 779 | " \n", 780 | " \n", 781 | " \n", 782 | " \n", 783 | " \n", 784 | " \n", 785 | " \n", 786 | "
PhraseIdSentenceIdPhraseSentiment
011A series of escapades demonstrating the adage ...1
121A series of escapades demonstrating the adage ...2
231A series2
341A2
451series2
\n", 787 | "
" 788 | ], 789 | "text/plain": [ 790 | " PhraseId ... Sentiment\n", 791 | "0 1 ... 1\n", 792 | "1 2 ... 2\n", 793 | "2 3 ... 2\n", 794 | "3 4 ... 2\n", 795 | "4 5 ... 2\n", 796 | "\n", 797 | "[5 rows x 4 columns]" 798 | ] 799 | }, 800 | "metadata": { 801 | "tags": [] 802 | }, 803 | "execution_count": 5 804 | } 805 | ] 806 | }, 807 | { 808 | "cell_type": "code", 809 | "metadata": { 810 | "id": "lGcvxwWXIbfq", 811 | "outputId": "e5fe3e91-acba-4249-9ce9-d18fb30bc575", 812 | "colab": { 813 | "base_uri": "https://localhost:8080/", 814 | "height": 34 815 | } 816 | }, 817 | "source": [ 818 | "train['Sentiment'].unique()" 819 | ], 820 | "execution_count": null, 821 | "outputs": [ 822 | { 823 | "output_type": "execute_result", 824 | "data": { 825 | "text/plain": [ 826 | "array([1, 2, 3, 4, 0])" 827 | ] 828 | }, 829 | "metadata": { 830 | "tags": [] 831 | }, 832 | "execution_count": 21 833 | } 834 | ] 835 | }, 836 | { 837 | "cell_type": "code", 838 | "metadata": { 839 | "trusted": true, 840 | "_uuid": "4cc9d80f5b9969346c8f5ff24e3ce8de25dfc93d", 841 | "id": "y43HcyWgEadG", 842 | "outputId": "38ac5b75-2b34-4d13-8cbf-c56739ac6d02", 843 | "colab": { 844 | "base_uri": "https://localhost:8080/", 845 | "height": 297 846 | } 847 | }, 848 | "source": [ 849 | "train.describe()" 850 | ], 851 | "execution_count": null, 852 | "outputs": [ 853 | { 854 | "output_type": "execute_result", 855 | "data": { 856 | "text/html": [ 857 | "
\n", 858 | "\n", 871 | "\n", 872 | " \n", 873 | " \n", 874 | " \n", 875 | " \n", 876 | " \n", 877 | " \n", 878 | " \n", 879 | " \n", 880 | " \n", 881 | " \n", 882 | " \n", 883 | " \n", 884 | " \n", 885 | " \n", 886 | " \n", 887 | " \n", 888 | " \n", 889 | " \n", 890 | " \n", 891 | " \n", 892 | " \n", 893 | " \n", 894 | " \n", 895 | " \n", 896 | " \n", 897 | " \n", 898 | " \n", 899 | " \n", 900 | " \n", 901 | " \n", 902 | " \n", 903 | " \n", 904 | " \n", 905 | " \n", 906 | " \n", 907 | " \n", 908 | " \n", 909 | " \n", 910 | " \n", 911 | " \n", 912 | " \n", 913 | " \n", 914 | " \n", 915 | " \n", 916 | " \n", 917 | " \n", 918 | " \n", 919 | " \n", 920 | " \n", 921 | " \n", 922 | " \n", 923 | " \n", 924 | " \n", 925 | " \n", 926 | " \n", 927 | " \n", 928 | " \n", 929 | " \n", 930 | "
PhraseIdSentenceIdSentiment
count156060.000000156060.000000156060.000000
mean78030.5000004079.7327442.063578
std45050.7858422502.7643940.893832
min1.0000001.0000000.000000
25%39015.7500001861.7500002.000000
50%78030.5000004017.0000002.000000
75%117045.2500006244.0000003.000000
max156060.0000008544.0000004.000000
\n", 931 | "
" 932 | ], 933 | "text/plain": [ 934 | " PhraseId SentenceId Sentiment\n", 935 | "count 156060.000000 156060.000000 156060.000000\n", 936 | "mean 78030.500000 4079.732744 2.063578\n", 937 | "std 45050.785842 2502.764394 0.893832\n", 938 | "min 1.000000 1.000000 0.000000\n", 939 | "25% 39015.750000 1861.750000 2.000000\n", 940 | "50% 78030.500000 4017.000000 2.000000\n", 941 | "75% 117045.250000 6244.000000 3.000000\n", 942 | "max 156060.000000 8544.000000 4.000000" 943 | ] 944 | }, 945 | "metadata": { 946 | "tags": [] 947 | }, 948 | "execution_count": 6 949 | } 950 | ] 951 | }, 952 | { 953 | "cell_type": "code", 954 | "metadata": { 955 | "trusted": true, 956 | "_uuid": "01da38cc4626a85b73fbb526d9a8d128d1fd9338", 957 | "id": "baSmeDdIEadM" 958 | }, 959 | "source": [ 960 | "new_df = train[['Phrase', 'Sentiment']]" 961 | ], 962 | "execution_count": null, 963 | "outputs": [] 964 | }, 965 | { 966 | "cell_type": "markdown", 967 | "metadata": { 968 | "id": "c3Q9NDdmqEyo" 969 | }, 970 | "source": [ 971 | "\n", 972 | "### Preparing the Dataset and Dataloader\n", 973 | "\n", 974 | "I will start with defining few key variables that will be used later during the training/fine tuning stage.\n", 975 | "Followed by creation of Dataset class - This defines how the text is pre-processed before sending it to the neural network. I will also define the Dataloader that will feed the data in batches to the neural network for suitable training and processing. \n", 976 | "Dataset and Dataloader are constructs of the PyTorch library for defining and controlling the data pre-processing and its passage to neural network. For further reading into Dataset and Dataloader read the [docs at PyTorch](https://pytorch.org/docs/stable/data.html)\n", 977 | "\n", 978 | "#### *SentimentData* Dataset Class\n", 979 | "- This class is defined to accept the Dataframe as input and generate tokenized output that is used by the Roberta model for training. 
\n", 980 | "- I am using the Roberta tokenizer to tokenize the data in the `Phrase` column of the dataframe. \n", 981 | "- The tokenizer uses the `encode_plus` method to perform tokenization and generate the necessary outputs, namely: `ids`, `mask` (the attention mask) and `token_type_ids`\n", 982 | "- To read further into the tokenizer, [refer to this document](https://huggingface.co/transformers/model_doc/roberta.html#robertatokenizer)\n", 983 | "- `targets` is the sentiment label (0-4) of the phrase, taken from the `Sentiment` column. \n", 984 | "- The *SentimentData* class is used to create 2 datasets, for training and for validation.\n", 985 | "- *Training Dataset* is used to fine tune the model: **80% of the original data**\n", 986 | "- *Validation Dataset* is used to evaluate the performance of the model. The model has not seen this data during training. \n", 987 | "\n", 988 | "#### Dataloader\n", 989 | "- Dataloaders are used to create the training and validation loaders that feed data to the neural network in a defined manner. This is needed because all the data from the dataset cannot be loaded into memory at once, hence the amount of data loaded into memory and then passed to the neural network needs to be controlled.\n", 990 | "- This control is achieved using the parameters such as `batch_size` and `max_len`.\n", 991 | "- Training and Validation dataloaders are used in the training and validation part of the flow respectively" 992 | ] 993 | }, 994 | { 995 | "cell_type": "code", 996 | "metadata": { 997 | "id": "nvXxpfNCGER2", 998 | "outputId": "d7281fe1-0dbf-42d7-c1e0-b51c4231c9c0", 999 | "colab": { 1000 | "base_uri": "https://localhost:8080/", 1001 | "height": 115, 1002 | "referenced_widgets": [ 1003 | "30864762e7f242c281b72862c5c08a33", 1004 | "dd12a39995584ba79f0e786b370b1a99", 1005 | "b89c9e76b5594a8ea601b9c5d2af4fa6", 1006 | "f65c7649640b4e87a7819a3da2f54fe0", 1007 | "80a6b6c9c4d5436ebe3b90b791c6fd93", 1008 | "8c1f6e94723842faa6bd3dcd9ff4ea82", 1009 | "39b5ca071fd3452e9d9145dd2b366da1", 1010 |
"f8dfd3ea6bb7413592115195bc6e0b83", 1011 | "611dfdca86f4498e8aa1491ed6ffb13d", 1012 | "be4857f17c244fb39a771f2c97283fd5", 1013 | "2fe41e1db18b4295a6907771462a0fce", 1014 | "0b29a9e1a275451bbc2114807532f91e", 1015 | "115c8809853d410fac6e7f69af5a5488", 1016 | "390827d7d2cb4b4fbfc0c022f015f7ed", 1017 | "5d305d4db08f47ba91461edb343874a4", 1018 | "773af3ca0add4e7cac0a036fb8b55632" 1019 | ] 1020 | } 1021 | }, 1022 | "source": [ 1023 | "# Defining some key variables that will be used later on in the training\n", 1024 | "MAX_LEN = 256\n", 1025 | "TRAIN_BATCH_SIZE = 8\n", 1026 | "VALID_BATCH_SIZE = 4\n", 1027 | "# EPOCHS = 1\n", 1028 | "LEARNING_RATE = 1e-05\n", 1029 | "tokenizer = RobertaTokenizer.from_pretrained('roberta-base', truncation=True, do_lower_case=True)" 1030 | ], 1031 | "execution_count": null, 1032 | "outputs": [ 1033 | { 1034 | "output_type": "display_data", 1035 | "data": { 1036 | "application/vnd.jupyter.widget-view+json": { 1037 | "model_id": "30864762e7f242c281b72862c5c08a33", 1038 | "version_minor": 0, 1039 | "version_major": 2 1040 | }, 1041 | "text/plain": [ 1042 | "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=898823.0, style=ProgressStyle(descripti…" 1043 | ] 1044 | }, 1045 | "metadata": { 1046 | "tags": [] 1047 | } 1048 | }, 1049 | { 1050 | "output_type": "stream", 1051 | "text": [ 1052 | "\n" 1053 | ], 1054 | "name": "stdout" 1055 | }, 1056 | { 1057 | "output_type": "display_data", 1058 | "data": { 1059 | "application/vnd.jupyter.widget-view+json": { 1060 | "model_id": "611dfdca86f4498e8aa1491ed6ffb13d", 1061 | "version_minor": 0, 1062 | "version_major": 2 1063 | }, 1064 | "text/plain": [ 1065 | "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…" 1066 | ] 1067 | }, 1068 | "metadata": { 1069 | "tags": [] 1070 | } 1071 | }, 1072 | { 1073 | "output_type": "stream", 1074 | "text": [ 1075 | "\n" 1076 | ], 1077 | "name": "stdout" 1078 | } 1079 | ] 1080 | }, 1081 | { 
1082 | "cell_type": "code", 1083 | "metadata": { 1084 | "id": "3vWRDemOGxJD" 1085 | }, 1086 | "source": [ 1087 | "class SentimentData(Dataset):\n", 1088 | " def __init__(self, dataframe, tokenizer, max_len):\n", 1089 | " self.tokenizer = tokenizer\n", 1090 | " self.data = dataframe\n", 1091 | " self.text = dataframe.Phrase\n", 1092 | " self.targets = self.data.Sentiment\n", 1093 | " self.max_len = max_len\n", 1094 | "\n", 1095 | " def __len__(self):\n", 1096 | " return len(self.text)\n", 1097 | "\n", 1098 | " def __getitem__(self, index):\n", 1099 | " text = str(self.text[index])\n", 1100 | " text = \" \".join(text.split())\n", 1101 | "\n", 1102 | " inputs = self.tokenizer.encode_plus(\n", 1103 | " text,\n", 1104 | " None,\n", 1105 | " add_special_tokens=True,\n", 1106 | " max_length=self.max_len,\n", 1107 | " pad_to_max_length=True,\n", 1108 | " return_token_type_ids=True\n", 1109 | " )\n", 1110 | " ids = inputs['input_ids']\n", 1111 | " mask = inputs['attention_mask']\n", 1112 | " token_type_ids = inputs[\"token_type_ids\"]\n", 1113 | "\n", 1114 | "\n", 1115 | " return {\n", 1116 | " 'ids': torch.tensor(ids, dtype=torch.long),\n", 1117 | " 'mask': torch.tensor(mask, dtype=torch.long),\n", 1118 | " 'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),\n", 1119 | " 'targets': torch.tensor(self.targets[index], dtype=torch.float)\n", 1120 | " }" 1121 | ], 1122 | "execution_count": null, 1123 | "outputs": [] 1124 | }, 1125 | { 1126 | "cell_type": "code", 1127 | "metadata": { 1128 | "id": "7Gpe9D1QHoCd", 1129 | "outputId": "7fc7fc2e-68a2-44b7-8e80-6bb6ce6c178b", 1130 | "colab": { 1131 | "base_uri": "https://localhost:8080/", 1132 | "height": 68 1133 | } 1134 | }, 1135 | "source": [ 1136 | "train_size = 0.8\n", 1137 | "train_data=new_df.sample(frac=train_size,random_state=200)\n", 1138 | "test_data=new_df.drop(train_data.index).reset_index(drop=True)\n", 1139 | "train_data = train_data.reset_index(drop=True)\n", 1140 | "\n", 1141 | "\n", 1142 | "print(\"FULL 
Dataset: {}\".format(new_df.shape))\n", 1143 | "print(\"TRAIN Dataset: {}\".format(train_data.shape))\n", 1144 | "print(\"TEST Dataset: {}\".format(test_data.shape))\n", 1145 | "\n", 1146 | "training_set = SentimentData(train_data, tokenizer, MAX_LEN)\n", 1147 | "testing_set = SentimentData(test_data, tokenizer, MAX_LEN)" 1148 | ], 1149 | "execution_count": null, 1150 | "outputs": [ 1151 | { 1152 | "output_type": "stream", 1153 | "text": [ 1154 | "FULL Dataset: (156060, 2)\n", 1155 | "TRAIN Dataset: (124848, 2)\n", 1156 | "TEST Dataset: (31212, 2)\n" 1157 | ], 1158 | "name": "stdout" 1159 | } 1160 | ] 1161 | }, 1162 | { 1163 | "cell_type": "code", 1164 | "metadata": { 1165 | "trusted": true, 1166 | "_uuid": "9fc198d13d7f33dc70588c3f22bc7b7c4f4ebb45", 1167 | "id": "c1tInLk2Eadt" 1168 | }, 1169 | "source": [ 1170 | "train_params = {'batch_size': TRAIN_BATCH_SIZE,\n", 1171 | " 'shuffle': True,\n", 1172 | " 'num_workers': 0\n", 1173 | " }\n", 1174 | "\n", 1175 | "test_params = {'batch_size': VALID_BATCH_SIZE,\n", 1176 | " 'shuffle': True,\n", 1177 | " 'num_workers': 0\n", 1178 | " }\n", 1179 | "\n", 1180 | "training_loader = DataLoader(training_set, **train_params)\n", 1181 | "testing_loader = DataLoader(testing_set, **test_params)" 1182 | ], 1183 | "execution_count": null, 1184 | "outputs": [] 1185 | }, 1186 | { 1187 | "cell_type": "markdown", 1188 | "metadata": { 1189 | "id": "yZk0A9K8qE0C" 1190 | }, 1191 | "source": [ 1192 | "\n", 1193 | "### Creating the Neural Network for Fine Tuning\n", 1194 | "\n", 1195 | "#### Neural Network\n", 1196 | " - We will be creating a neural network with the `RobertaClass`. \n", 1197 | " - This network will have the Roberta Language model followed by a `dropout` and finally a `Linear` layer to obtain the final outputs. \n", 1198 | " - The data will be fed to the Roberta Language model as defined in the dataset. 
\n", 1199 | " - The final layer's outputs are compared to the `Sentiment` category to determine the accuracy of the model's predictions. \n", 1200 | " - We will create an instance of the network called `model`. This instance will be used for training and then to save the final trained model for future inference. \n", 1201 | " \n", 1202 | "#### Loss Function and Optimizer\n", 1203 | " - The `Loss Function` and `Optimizer` are defined in the next cell.\n", 1204 | " - The `Loss Function` is used to calculate the difference between the output created by the model and the actual output. \n", 1205 | " - The `Optimizer` is used to update the weights of the neural network to improve its performance." 1206 | ] 1207 | }, 1208 | { 1209 | "cell_type": "code", 1210 | "metadata": { 1211 | "trusted": true, 1212 | "_uuid": "cb8f194ee79d76356be0002b0e18f947e1412d66", 1213 | "id": "HMqQTafXEaei" 1214 | }, 1215 | "source": [ 1216 | "class RobertaClass(torch.nn.Module):\n", 1217 | " def __init__(self):\n", 1218 | " super(RobertaClass, self).__init__()\n", 1219 | " self.l1 = RobertaModel.from_pretrained(\"roberta-base\")\n", 1220 | " self.pre_classifier = torch.nn.Linear(768, 768)\n", 1221 | " self.dropout = torch.nn.Dropout(0.3)\n", 1222 | " self.classifier = torch.nn.Linear(768, 5)\n", 1223 | "\n", 1224 | " def forward(self, input_ids, attention_mask, token_type_ids):\n", 1225 | " output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)\n", 1226 | " hidden_state = output_1[0]\n", 1227 | " pooler = hidden_state[:, 0]\n", 1228 | " pooler = self.pre_classifier(pooler)\n", 1229 | " pooler = torch.nn.ReLU()(pooler)\n", 1230 | " pooler = self.dropout(pooler)\n", 1231 | " output = self.classifier(pooler)\n", 1232 | " return output" 1233 | ], 1234 | "execution_count": null, 1235 | "outputs": [] 1236 | }, 1237 | { 1238 | "cell_type": "code", 1239 | "metadata": { 1240 | "id": "sZ55mIPZIkp_", 1241 | "outputId": "35048672-7bf0-44bc-8fd5-b9ffbd207a17", 1242 |
"colab": { 1243 | "base_uri": "https://localhost:8080/", 1244 | "height": 1000 1245 | } 1246 | }, 1247 | "source": [ 1248 | "model = RobertaClass()\n", 1249 | "model.to(device)" 1250 | ], 1251 | "execution_count": null, 1252 | "outputs": [ 1253 | { 1254 | "output_type": "execute_result", 1255 | "data": { 1256 | "text/plain": [ 1257 | "RobertaClass(\n", 1258 | " (l1): RobertaModel(\n", 1259 | " (embeddings): RobertaEmbeddings(\n", 1260 | " (word_embeddings): Embedding(50265, 768, padding_idx=1)\n", 1261 | " (position_embeddings): Embedding(514, 768, padding_idx=1)\n", 1262 | " (token_type_embeddings): Embedding(1, 768)\n", 1263 | " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 1264 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1265 | " )\n", 1266 | " (encoder): BertEncoder(\n", 1267 | " (layer): ModuleList(\n", 1268 | " (0): BertLayer(\n", 1269 | " (attention): BertAttention(\n", 1270 | " (self): BertSelfAttention(\n", 1271 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 1272 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 1273 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 1274 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1275 | " )\n", 1276 | " (output): BertSelfOutput(\n", 1277 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 1278 | " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 1279 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1280 | " )\n", 1281 | " )\n", 1282 | " (intermediate): BertIntermediate(\n", 1283 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 1284 | " )\n", 1285 | " (output): BertOutput(\n", 1286 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 1287 | " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 1288 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1289 | " )\n", 1290 | " )\n", 1291 | " (1): BertLayer(\n", 1292 | " (attention): 
BertAttention(\n", 1293 | " (self): BertSelfAttention(\n", 1294 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 1295 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 1296 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 1297 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1298 | " )\n", 1299 | " (output): BertSelfOutput(\n", 1300 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 1301 | " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 1302 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1303 | " )\n", 1304 | " )\n", 1305 | " (intermediate): BertIntermediate(\n", 1306 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 1307 | " )\n", 1308 | " (output): BertOutput(\n", 1309 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 1310 | " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 1311 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1312 | " )\n", 1313 | " )\n", 1314 | " (2): BertLayer(\n", 1315 | " (attention): BertAttention(\n", 1316 | " (self): BertSelfAttention(\n", 1317 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 1318 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 1319 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 1320 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1321 | " )\n", 1322 | " (output): BertSelfOutput(\n", 1323 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 1324 | " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 1325 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1326 | " )\n", 1327 | " )\n", 1328 | " (intermediate): BertIntermediate(\n", 1329 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 1330 | " )\n", 1331 | " (output): BertOutput(\n", 1332 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 1333 | " (LayerNorm): 
LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 1334 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1335 | " )\n", 1336 | " )\n", 1337 | " (3): BertLayer(\n", 1338 | " (attention): BertAttention(\n", 1339 | " (self): BertSelfAttention(\n", 1340 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 1341 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 1342 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 1343 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1344 | " )\n", 1345 | " (output): BertSelfOutput(\n", 1346 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 1347 | " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 1348 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1349 | " )\n", 1350 | " )\n", 1351 | " (intermediate): BertIntermediate(\n", 1352 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 1353 | " )\n", 1354 | " (output): BertOutput(\n", 1355 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 1356 | " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 1357 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1358 | " )\n", 1359 | " )\n", 1360 | " (4): BertLayer(\n", 1361 | " (attention): BertAttention(\n", 1362 | " (self): BertSelfAttention(\n", 1363 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 1364 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 1365 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 1366 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1367 | " )\n", 1368 | " (output): BertSelfOutput(\n", 1369 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 1370 | " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 1371 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1372 | " )\n", 1373 | " )\n", 1374 | " (intermediate): BertIntermediate(\n", 1375 | " (dense): 
Linear(in_features=768, out_features=3072, bias=True)\n", 1376 | " )\n", 1377 | " (output): BertOutput(\n", 1378 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 1379 | " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 1380 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1381 | " )\n", 1382 | " )\n", 1383 | " (5): BertLayer(\n", 1384 | " (attention): BertAttention(\n", 1385 | " (self): BertSelfAttention(\n", 1386 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 1387 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 1388 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 1389 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1390 | " )\n", 1391 | " (output): BertSelfOutput(\n", 1392 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 1393 | " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 1394 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1395 | " )\n", 1396 | " )\n", 1397 | " (intermediate): BertIntermediate(\n", 1398 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 1399 | " )\n", 1400 | " (output): BertOutput(\n", 1401 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 1402 | " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 1403 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1404 | " )\n", 1405 | " )\n", 1406 | " (6): BertLayer(\n", 1407 | " (attention): BertAttention(\n", 1408 | " (self): BertSelfAttention(\n", 1409 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 1410 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 1411 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 1412 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1413 | " )\n", 1414 | " (output): BertSelfOutput(\n", 1415 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 1416 | " (LayerNorm): LayerNorm((768,), 
eps=1e-05, elementwise_affine=True)\n", 1417 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1418 | " )\n", 1419 | " )\n", 1420 | " (intermediate): BertIntermediate(\n", 1421 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 1422 | " )\n", 1423 | " (output): BertOutput(\n", 1424 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 1425 | " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 1426 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1427 | " )\n", 1428 | " )\n", 1429 | " (7): BertLayer(\n", 1430 | " (attention): BertAttention(\n", 1431 | " (self): BertSelfAttention(\n", 1432 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 1433 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 1434 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 1435 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1436 | " )\n", 1437 | " (output): BertSelfOutput(\n", 1438 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 1439 | " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 1440 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1441 | " )\n", 1442 | " )\n", 1443 | " (intermediate): BertIntermediate(\n", 1444 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 1445 | " )\n", 1446 | " (output): BertOutput(\n", 1447 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 1448 | " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 1449 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1450 | " )\n", 1451 | " )\n", 1452 | " (8): BertLayer(\n", 1453 | " (attention): BertAttention(\n", 1454 | " (self): BertSelfAttention(\n", 1455 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 1456 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 1457 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 1458 | " (dropout): Dropout(p=0.1, 
inplace=False)\n", 1459 | " )\n", 1460 | " (output): BertSelfOutput(\n", 1461 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 1462 | " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 1463 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1464 | " )\n", 1465 | " )\n", 1466 | " (intermediate): BertIntermediate(\n", 1467 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 1468 | " )\n", 1469 | " (output): BertOutput(\n", 1470 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 1471 | " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 1472 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1473 | " )\n", 1474 | " )\n", 1475 | " (9): BertLayer(\n", 1476 | " (attention): BertAttention(\n", 1477 | " (self): BertSelfAttention(\n", 1478 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 1479 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 1480 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 1481 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1482 | " )\n", 1483 | " (output): BertSelfOutput(\n", 1484 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 1485 | " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 1486 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1487 | " )\n", 1488 | " )\n", 1489 | " (intermediate): BertIntermediate(\n", 1490 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 1491 | " )\n", 1492 | " (output): BertOutput(\n", 1493 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 1494 | " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 1495 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1496 | " )\n", 1497 | " )\n", 1498 | " (10): BertLayer(\n", 1499 | " (attention): BertAttention(\n", 1500 | " (self): BertSelfAttention(\n", 1501 | " (query): Linear(in_features=768, out_features=768, 
bias=True)\n", 1502 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 1503 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 1504 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1505 | " )\n", 1506 | " (output): BertSelfOutput(\n", 1507 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 1508 | " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 1509 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1510 | " )\n", 1511 | " )\n", 1512 | " (intermediate): BertIntermediate(\n", 1513 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 1514 | " )\n", 1515 | " (output): BertOutput(\n", 1516 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 1517 | " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 1518 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1519 | " )\n", 1520 | " )\n", 1521 | " (11): BertLayer(\n", 1522 | " (attention): BertAttention(\n", 1523 | " (self): BertSelfAttention(\n", 1524 | " (query): Linear(in_features=768, out_features=768, bias=True)\n", 1525 | " (key): Linear(in_features=768, out_features=768, bias=True)\n", 1526 | " (value): Linear(in_features=768, out_features=768, bias=True)\n", 1527 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1528 | " )\n", 1529 | " (output): BertSelfOutput(\n", 1530 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 1531 | " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 1532 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1533 | " )\n", 1534 | " )\n", 1535 | " (intermediate): BertIntermediate(\n", 1536 | " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", 1537 | " )\n", 1538 | " (output): BertOutput(\n", 1539 | " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", 1540 | " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 1541 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1542 | " 
)\n", 1543 | " )\n", 1544 | " )\n", 1545 | " )\n", 1546 | " (pooler): BertPooler(\n", 1547 | " (dense): Linear(in_features=768, out_features=768, bias=True)\n", 1548 | " (activation): Tanh()\n", 1549 | " )\n", 1550 | " )\n", 1551 | " (pre_classifier): Linear(in_features=768, out_features=768, bias=True)\n", 1552 | " (dropout): Dropout(p=0.3, inplace=False)\n", 1553 | " (classifier): Linear(in_features=768, out_features=5, bias=True)\n", 1554 | ")" 1555 | ] 1556 | }, 1557 | "metadata": { 1558 | "tags": [] 1559 | }, 1560 | "execution_count": 37 1561 | } 1562 | ] 1563 | }, 1564 | { 1565 | "cell_type": "markdown", 1566 | "metadata": { 1567 | "id": "gsRa7gY3qE0n" 1568 | }, 1569 | "source": [ 1570 | "\n", 1571 | "### Fine Tuning the Model\n", 1572 | "\n", 1573 | "After all the effort of loading and preparing the data and datasets, creating the model and defining its loss and optimizer, this is probably the easiest step in the process. \n", 1574 | "\n", 1575 | "Here we define a training function that trains the model on the training dataset created above for a specified number of epochs (`EPOCHS`). An epoch is one complete pass of the training data through the network. \n", 1576 | "\n", 1577 | "The following events happen in this function to fine tune the neural network:\n", 1578 | "- The dataloader passes data to the model based on the batch size. \n", 1579 | "- The output from the model and the actual category are compared to calculate the loss. \n", 1580 | "- The loss value is used to optimize the weights of the neurons in the network.\n", 1581 | "- Every 5000 steps, the running average loss and accuracy are printed to the console.\n", 1582 | "\n", 1583 | "As you can see, after just 1 epoch the model reached an average training loss of 0.8141926634122427."
1584 | ] 1585 | }, 1586 | { 1587 | "cell_type": "code", 1588 | "metadata": { 1589 | "id": "XYZ7YuJ5InOS" 1590 | }, 1591 | "source": [ 1592 | "# Creating the loss function and optimizer\n", 1593 | "loss_function = torch.nn.CrossEntropyLoss()\n", 1594 | "optimizer = torch.optim.Adam(params = model.parameters(), lr=LEARNING_RATE)" 1595 | ], 1596 | "execution_count": null, 1597 | "outputs": [] 1598 | }, 1599 | { 1600 | "cell_type": "code", 1601 | "metadata": { 1602 | "id": "yPhA2V3iIpzN" 1603 | }, 1604 | "source": [ 1605 | "def calcuate_accuracy(preds, targets):\n", 1606 | " n_correct = (preds==targets).sum().item()\n", 1607 | " return n_correct" 1608 | ], 1609 | "execution_count": null, 1610 | "outputs": [] 1611 | }, 1612 | { 1613 | "cell_type": "code", 1614 | "metadata": { 1615 | "id": "mhqvtY2SIup7" 1616 | }, 1617 | "source": [ 1618 | "# Defining the training function on the 80% of the dataset for tuning the distilbert model\n", 1619 | "\n", 1620 | "def train(epoch):\n", 1621 | " tr_loss = 0\n", 1622 | " n_correct = 0\n", 1623 | " nb_tr_steps = 0\n", 1624 | " nb_tr_examples = 0\n", 1625 | " model.train()\n", 1626 | " for _,data in tqdm(enumerate(training_loader, 0)):\n", 1627 | " ids = data['ids'].to(device, dtype = torch.long)\n", 1628 | " mask = data['mask'].to(device, dtype = torch.long)\n", 1629 | " token_type_ids = data['token_type_ids'].to(device, dtype = torch.long)\n", 1630 | " targets = data['targets'].to(device, dtype = torch.long)\n", 1631 | "\n", 1632 | " outputs = model(ids, mask, token_type_ids)\n", 1633 | " loss = loss_function(outputs, targets)\n", 1634 | " tr_loss += loss.item()\n", 1635 | " big_val, big_idx = torch.max(outputs.data, dim=1)\n", 1636 | " n_correct += calcuate_accuracy(big_idx, targets)\n", 1637 | "\n", 1638 | " nb_tr_steps += 1\n", 1639 | " nb_tr_examples+=targets.size(0)\n", 1640 | " \n", 1641 | " if _%5000==0:\n", 1642 | " loss_step = tr_loss/nb_tr_steps\n", 1643 | " accu_step = (n_correct*100)/nb_tr_examples \n", 1644 | " 
print(f\"Training Loss per 5000 steps: {loss_step}\")\n", 1645 | " print(f\"Training Accuracy per 5000 steps: {accu_step}\")\n", 1646 | "\n", 1647 | " optimizer.zero_grad()\n", 1648 | " loss.backward()\n", 1649 | " # # When using GPU\n", 1650 | " optimizer.step()\n", 1651 | "\n", 1652 | " print(f'The Total Accuracy for Epoch {epoch}: {(n_correct*100)/nb_tr_examples}')\n", 1653 | " epoch_loss = tr_loss/nb_tr_steps\n", 1654 | " epoch_accu = (n_correct*100)/nb_tr_examples\n", 1655 | " print(f\"Training Loss Epoch: {epoch_loss}\")\n", 1656 | " print(f\"Training Accuracy Epoch: {epoch_accu}\")\n", 1657 | "\n", 1658 | " return " 1659 | ], 1660 | "execution_count": null, 1661 | "outputs": [] 1662 | }, 1663 | { 1664 | "cell_type": "code", 1665 | "metadata": { 1666 | "id": "Afn7xaunJHnI", 1667 | "outputId": "4ca0a58f-3d9f-432f-9da4-0c74f3b1cb02", 1668 | "colab": { 1669 | "base_uri": "https://localhost:8080/", 1670 | "height": 221 1671 | } 1672 | }, 1673 | "source": [ 1674 | "EPOCHS = 1\n", 1675 | "for epoch in range(EPOCHS):\n", 1676 | " train(epoch)" 1677 | ], 1678 | "execution_count": null, 1679 | "outputs": [ 1680 | { 1681 | "output_type": "stream", 1682 | "text": [ 1683 | "1it [00:00, 4.04it/s]" 1684 | ], 1685 | "name": "stderr" 1686 | }, 1687 | { 1688 | "output_type": "stream", 1689 | "text": [ 1690 | "Training Loss per 5000 steps: 1.2416878938674927\n", 1691 | "Training Accuracy per 5000 steps: 62.5\n" 1692 | ], 1693 | "name": "stdout" 1694 | }, 1695 | { 1696 | "output_type": "stream", 1697 | "text": [ 1698 | "5001it [18:40, 4.42it/s]" 1699 | ], 1700 | "name": "stderr" 1701 | }, 1702 | { 1703 | "output_type": "stream", 1704 | "text": [ 1705 | "Training Loss per 5000 steps: 0.8735729315070672\n", 1706 | "Training Accuracy per 5000 steps: 64.37212557488502\n" 1707 | ], 1708 | "name": "stdout" 1709 | }, 1710 | { 1711 | "output_type": "stream", 1712 | "text": [ 1713 | "10001it [37:22, 4.40it/s]" 1714 | ], 1715 | "name": "stderr" 1716 | }, 1717 | { 1718 | "output_type": 
"stream", 1719 | "text": [ 1720 | "Training Loss per 5000 steps: 0.8366646968724489\n", 1721 | "Training Accuracy per 5000 steps: 65.3947105289471\n" 1722 | ], 1723 | "name": "stdout" 1724 | }, 1725 | { 1726 | "output_type": "stream", 1727 | "text": [ 1728 | "15001it [56:03, 4.43it/s]" 1729 | ], 1730 | "name": "stderr" 1731 | }, 1732 | { 1733 | "output_type": "stream", 1734 | "text": [ 1735 | "Training Loss per 5000 steps: 0.8169289541139411\n", 1736 | "Training Accuracy per 5000 steps: 66.141423905073\n" 1737 | ], 1738 | "name": "stdout" 1739 | }, 1740 | { 1741 | "output_type": "stream", 1742 | "text": [ 1743 | "15606it [58:18, 4.46it/s]" 1744 | ], 1745 | "name": "stderr" 1746 | }, 1747 | { 1748 | "output_type": "stream", 1749 | "text": [ 1750 | "The Total Accuracy for Epoch 0: 66.24695629885942\n", 1751 | "Training Loss Epoch: 0.8141926634122427\n", 1752 | "Training Accuracy Epoch: 66.24695629885942\n" 1753 | ], 1754 | "name": "stdout" 1755 | }, 1756 | { 1757 | "output_type": "stream", 1758 | "text": [ 1759 | "\n" 1760 | ], 1761 | "name": "stderr" 1762 | } 1763 | ] 1764 | }, 1765 | { 1766 | "cell_type": "markdown", 1767 | "metadata": { 1768 | "id": "vOcgTsovqE1A" 1769 | }, 1770 | "source": [ 1771 | "\n", 1772 | "### Validating the Model\n", 1773 | "\n", 1774 | "During the validation stage we pass the unseen data (Testing Dataset) to the model. This step determines how well the model performs on the unseen data. \n", 1775 | "\n", 1776 | "This unseen data is the 20% of `train.tsv` which was separated during the Dataset creation stage. \n", 1777 | "During the validation stage the weights of the model are not updated. Only the final output is compared to the actual value. This comparison is then used to calculate the accuracy of the model. \n", 1778 | "\n", 1779 | "As you can see, the model predicts the correct category of a given sample with 69.47% accuracy, which can be improved further by training for more epochs."
1780 | ] 1781 | }, 1782 | { 1783 | "cell_type": "code", 1784 | "metadata": { 1785 | "id": "bFiNcy16JLwt" 1786 | }, 1787 | "source": [ 1788 | "def valid(model, testing_loader):\n", 1789 | " model.eval()\n", 1790 | " n_correct = 0; n_wrong = 0; total = 0; tr_loss=0; nb_tr_steps=0; nb_tr_examples=0\n", 1791 | " with torch.no_grad():\n", 1792 | " for _, data in tqdm(enumerate(testing_loader, 0)):\n", 1793 | " ids = data['ids'].to(device, dtype = torch.long)\n", 1794 | " mask = data['mask'].to(device, dtype = torch.long)\n", 1795 | " token_type_ids = data['token_type_ids'].to(device, dtype=torch.long)\n", 1796 | " targets = data['targets'].to(device, dtype = torch.long)\n", 1797 | " outputs = model(ids, mask, token_type_ids).squeeze()\n", 1798 | " loss = loss_function(outputs, targets)\n", 1799 | " tr_loss += loss.item()\n", 1800 | " big_val, big_idx = torch.max(outputs.data, dim=1)\n", 1801 | " n_correct += calcuate_accuracy(big_idx, targets)\n", 1802 | "\n", 1803 | " nb_tr_steps += 1\n", 1804 | " nb_tr_examples+=targets.size(0)\n", 1805 | " \n", 1806 | " if _%5000==0:\n", 1807 | " loss_step = tr_loss/nb_tr_steps\n", 1808 | " accu_step = (n_correct*100)/nb_tr_examples\n", 1809 | " print(f\"Validation Loss per 100 steps: {loss_step}\")\n", 1810 | " print(f\"Validation Accuracy per 100 steps: {accu_step}\")\n", 1811 | " epoch_loss = tr_loss/nb_tr_steps\n", 1812 | " epoch_accu = (n_correct*100)/nb_tr_examples\n", 1813 | " print(f\"Validation Loss Epoch: {epoch_loss}\")\n", 1814 | " print(f\"Validation Accuracy Epoch: {epoch_accu}\")\n", 1815 | " \n", 1816 | " return epoch_accu\n" 1817 | ], 1818 | "execution_count": null, 1819 | "outputs": [] 1820 | }, 1821 | { 1822 | "cell_type": "code", 1823 | "metadata": { 1824 | "id": "UcUylInzKdV-", 1825 | "outputId": "ebb887fc-04a6-4a44-cb50-9fb8b17853ea", 1826 | "colab": { 1827 | "base_uri": "https://localhost:8080/", 1828 | "height": 153 1829 | } 1830 | }, 1831 | "source": [ 1832 | "acc = valid(model, testing_loader)\n", 1833 | 
"print(\"Accuracy on test data = %0.2f%%\" % acc)" 1834 | ], 1835 | "execution_count": null, 1836 | "outputs": [ 1837 | { 1838 | "output_type": "stream", 1839 | "text": [ 1840 | "3it [00:00, 23.17it/s]" 1841 | ], 1842 | "name": "stderr" 1843 | }, 1844 | { 1845 | "output_type": "stream", 1846 | "text": [ 1847 | "Validation Loss per 100 steps: 0.6547155380249023\n", 1848 | "Validation Accuracy per 100 steps: 75.0\n" 1849 | ], 1850 | "name": "stdout" 1851 | }, 1852 | { 1853 | "output_type": "stream", 1854 | "text": [ 1855 | "5004it [03:13, 25.87it/s]" 1856 | ], 1857 | "name": "stderr" 1858 | }, 1859 | { 1860 | "output_type": "stream", 1861 | "text": [ 1862 | "Validation Loss per 100 steps: 0.736690901120772\n", 1863 | "Validation Accuracy per 100 steps: 69.22115576884623\n" 1864 | ], 1865 | "name": "stdout" 1866 | }, 1867 | { 1868 | "output_type": "stream", 1869 | "text": [ 1870 | "7803it [05:01, 25.89it/s]" 1871 | ], 1872 | "name": "stderr" 1873 | }, 1874 | { 1875 | "output_type": "stream", 1876 | "text": [ 1877 | "Validation Loss Epoch: 0.7332612214877096\n", 1878 | "Validation Accuracy Epoch: 69.46687171600666\n", 1879 | "Accuracy on test data = 69.47%\n" 1880 | ], 1881 | "name": "stdout" 1882 | }, 1883 | { 1884 | "output_type": "stream", 1885 | "text": [ 1886 | "\n" 1887 | ], 1888 | "name": "stderr" 1889 | } 1890 | ] 1891 | }, 1892 | { 1893 | "cell_type": "markdown", 1894 | "metadata": { 1895 | "id": "tZgO6C1BqE1a" 1896 | }, 1897 | "source": [ 1898 | "\n", 1899 | "### Saving the Trained Model Artifacts for inference\n", 1900 | "\n", 1901 | "This is the final step in the process of fine tuning the model. \n", 1902 | "\n", 1903 | "The model and the tokenizer's vocabulary are saved locally. These files can then be used in the future to run inference on new text inputs."
1904 | ] 1905 | }, 1906 | { 1907 | "cell_type": "code", 1908 | "metadata": { 1909 | "id": "8eKt004BKjyT", 1910 | "outputId": "8f43c6b5-8772-4158-f8cc-f5bbd72b2f14", 1911 | "colab": { 1912 | "base_uri": "https://localhost:8080/", 1913 | "height": 51 1914 | } 1915 | }, 1916 | "source": [ 1917 | "output_model_file = 'pytorch_roberta_sentiment.bin'\n", 1918 | "output_vocab_file = './'\n", 1919 | "\n", 1920 | "model_to_save = model\n", 1921 | "torch.save(model_to_save, output_model_file)\n", 1922 | "tokenizer.save_vocabulary(output_vocab_file)\n", 1923 | "\n", 1924 | "print('All files saved')\n", 1925 | "print('This tutorial is completed')" 1926 | ], 1927 | "execution_count": null, 1928 | "outputs": [ 1929 | { 1930 | "output_type": "stream", 1931 | "text": [ 1932 | "All files saved\n", 1933 | "This tutorial is completed\n" 1934 | ], 1935 | "name": "stdout" 1936 | } 1937 | ] 1938 | }, 1939 | { 1940 | "cell_type": "code", 1941 | "metadata": { 1942 | "id": "IetKrn_SY-OT" 1943 | }, 1944 | "source": [ 1945 | "" 1946 | ], 1947 | "execution_count": null, 1948 | "outputs": [] 1949 | } 1950 | ] 1951 | } --------------------------------------------------------------------------------