# Install the libraries first (notebook / shell):
#   pip install sentencepiece transformers torch rich[jupyter]

# Importing libraries
import os

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler

# Importing the T5 modules from huggingface/transformers
from transformers import T5Tokenizer, T5ForConditionalGeneration

# rich: for a better display on terminal
from rich.table import Column, Table
from rich import box
from rich.console import Console

# Module-level console; record=True keeps everything printed through it so
# it can be exported later with console.save_text().
console = Console(record=True)


def display_df(df):
    """Print the first two columns of *df* as an ASCII table.

    The columns are shown under the headers ``source_text`` and
    ``target_text``. Every cell is converted to ``str`` before it is handed
    to rich, because ``Table.add_row`` only accepts renderables (strings or
    ``None``) and raises on plain numbers such as a float NaN. ``None`` is
    kept as an empty cell.

    Args:
        df (pandas.DataFrame): frame with at least two columns.
    """

    def _as_cell(value):
        # None renders as an empty cell in rich; everything else is stringified.
        return "" if value is None else str(value)

    # A dedicated, non-recording console: the sample table is printed but is
    # not captured by the module-level recording ``console``.
    plain_console = Console()
    table = Table(
        Column("source_text", justify="center"),
        Column("target_text", justify="center"),
        title="Sample Data",
        pad_edge=False,
        box=box.ASCII,
    )

    for row in df.values.tolist():
        table.add_row(_as_cell(row[0]), _as_cell(row[1]))

    plain_console.print(table)


# Table that accumulates one row per logged training step (epoch, step, loss).
training_logger = Table(
    Column("Epoch", justify="center"),
    Column("Steps", justify="center"),
    Column("Loss", justify="center"),
    title="Training Status",
    pad_edge=False,
    box=box.ASCII,
)
# Setting up the device for GPU usage
from torch import cuda

device = "cuda" if cuda.is_available() else "cpu"


class YourDataSetClass(Dataset):
    """Dataset that tokenizes (source, target) text pairs on demand.

    Each item is built from one row of the dataframe: both texts are
    whitespace-normalized, tokenized, truncated and padded to a fixed
    length, and returned as ``torch.long`` tensors ready for a DataLoader.
    """

    def __init__(
        self, dataframe, tokenizer, source_len, target_len, source_text, target_text
    ):
        """
        Initializes a Dataset class

        Args:
            dataframe (pandas.DataFrame): Input dataframe
            tokenizer (transformers.tokenizer): Transformers tokenizer
            source_len (int): Max length of source text
            target_len (int): Max length of target text
            source_text (str): column name of source text
            target_text (str): column name of target text
        """
        self.tokenizer = tokenizer
        self.data = dataframe
        self.source_len = source_len
        self.summ_len = target_len
        self.target_text = self.data[target_text]
        self.source_text = self.data[source_text]

    def __len__(self):
        """returns the length of dataframe"""
        return len(self.target_text)

    def _encode(self, text, max_length):
        """Tokenize one string, truncated and padded to exactly *max_length*."""
        return self.tokenizer(
            [text],
            max_length=max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )

    def __getitem__(self, index):
        """Return the input ids, attention mask and target ids of row *index*.

        *index* is positional (0 .. len-1), so the dataset works regardless
        of the dataframe's index labels.

        Returns:
            dict: ``source_ids``, ``source_mask``, ``target_ids`` and
            ``target_ids_y`` (same values as ``target_ids``), each a 1-D
            ``torch.long`` tensor.
        """
        # .iloc: positional lookup; plain [] is label-based and breaks when
        # the frame's index was not reset.
        source_text = str(self.source_text.iloc[index])
        target_text = str(self.target_text.iloc[index])

        # collapse every run of whitespace (newlines, tabs, ...) to one space
        source_text = " ".join(source_text.split())
        target_text = " ".join(target_text.split())

        source = self._encode(source_text, self.source_len)
        target = self._encode(target_text, self.summ_len)

        # squeeze(0) drops only the batch dimension added by the tokenizer
        source_ids = source["input_ids"].squeeze(0)
        source_mask = source["attention_mask"].squeeze(0)
        target_ids = target["input_ids"].squeeze(0)

        return {
            "source_ids": source_ids.to(dtype=torch.long),
            "source_mask": source_mask.to(dtype=torch.long),
            "target_ids": target_ids.to(dtype=torch.long),
            "target_ids_y": target_ids.to(dtype=torch.long),
        }


def train(epoch, tokenizer, model, device, loader, optimizer):
    """Run one training epoch over *loader*.

    Only ``labels`` are passed to the model. According to the Transformers
    T5 documentation the model then builds the decoder inputs itself by
    shifting the labels right and prepending the decoder start token, which
    is also how ``generate()`` starts decoding. Padding positions are set to
    -100 so they are ignored by the loss.

    Every 10th step a row (epoch, step, loss) is added to the module-level
    ``training_logger`` table, which is then printed on the module-level
    ``console``.

    Args:
        epoch (int): epoch number, used for logging only.
        tokenizer: tokenizer providing ``pad_token_id``.
        model: the model being fine-tuned.
        device: device the batches are moved to.
        loader: iterable of batches with ``source_ids``, ``source_mask``
            and ``target_ids``.
        optimizer: optimizer updating the model parameters.
    """
    model.train()
    for step, data in enumerate(loader):
        ids = data["source_ids"].to(device, dtype=torch.long)
        mask = data["source_mask"].to(device, dtype=torch.long)

        # clone before masking so the batch tensor itself is left untouched
        labels = data["target_ids"].to(device, dtype=torch.long).clone()
        labels[labels == tokenizer.pad_token_id] = -100

        outputs = model(
            input_ids=ids,
            attention_mask=mask,
            labels=labels,
        )
        loss = outputs[0]

        if step % 10 == 0:
            # log the scalar value, not the tensor repr
            training_logger.add_row(str(epoch), str(step), f"{loss.item():.4f}")
            console.print(training_logger)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def validate(
    epoch,
    tokenizer,
    model,
    device,
    loader,
    max_length=150,
    num_beams=2,
    repetition_penalty=2.5,
    length_penalty=1.0,
    early_stopping=True,
):
    """Generate predictions for every batch in *loader*.

    Args:
        epoch (int): unused; kept for interface compatibility.
        tokenizer: tokenizer used to decode generated and target ids.
        model: model providing ``generate``.
        device: device the batches are moved to.
        loader: iterable of batches with ``source_ids``, ``source_mask``
            and ``target_ids``.
        max_length, num_beams, repetition_penalty, length_penalty,
        early_stopping: forwarded to ``model.generate``; the defaults are
            the values that used to be hard-coded here.

    Returns:
        tuple[list[str], list[str]]: decoded predictions and decoded
        reference texts, in loader order.
    """
    model.eval()
    predictions = []
    actuals = []
    with torch.no_grad():
        for step, data in enumerate(loader):
            y = data["target_ids"].to(device, dtype=torch.long)
            ids = data["source_ids"].to(device, dtype=torch.long)
            mask = data["source_mask"].to(device, dtype=torch.long)

            generated_ids = model.generate(
                input_ids=ids,
                attention_mask=mask,
                max_length=max_length,
                num_beams=num_beams,
                repetition_penalty=repetition_penalty,
                length_penalty=length_penalty,
                early_stopping=early_stopping,
            )
            preds = [
                tokenizer.decode(
                    g, skip_special_tokens=True, clean_up_tokenization_spaces=True
                )
                for g in generated_ids
            ]
            target = [
                tokenizer.decode(
                    t, skip_special_tokens=True, clean_up_tokenization_spaces=True
                )
                for t in y
            ]
            if step % 10 == 0:
                console.print(f"Completed {step}")

            predictions.extend(preds)
            actuals.extend(target)
    return predictions, actuals


def T5Trainer(
    dataframe, source_text, target_text, model_params, output_dir="./outputs/"
):
    """Fine-tune a T5 model on *dataframe* and write the results to disk.

    Steps: seed the RNGs, load tokenizer and model, split the data 80/20
    into train and validation sets, train for ``TRAIN_EPOCHS`` epochs, save
    model and tokenizer to ``<output_dir>/model_files``, generate
    predictions for the validation set into ``<output_dir>/predictions.csv``
    and export the console log to ``<output_dir>/logs.txt``.

    Args:
        dataframe (pandas.DataFrame): data containing both text columns.
        source_text (str): column name of the input text.
        target_text (str): column name of the target text.
        model_params (dict): must provide ``MODEL``, ``TRAIN_BATCH_SIZE``,
            ``VALID_BATCH_SIZE``, ``TRAIN_EPOCHS``, ``VAL_EPOCHS``,
            ``LEARNING_RATE``, ``MAX_SOURCE_TEXT_LENGTH``,
            ``MAX_TARGET_TEXT_LENGTH`` and ``SEED``.
        output_dir (str): directory receiving all outputs; created when
            it does not exist.
    """

    # Set random seeds and deterministic pytorch for reproducibility
    torch.manual_seed(model_params["SEED"])  # pytorch random seed
    np.random.seed(model_params["SEED"])  # numpy random seed
    torch.backends.cudnn.deterministic = True

    # Create the output directory up front instead of relying on a later
    # call to create it as a side effect.
    os.makedirs(output_dir, exist_ok=True)

    # logging
    console.log(f"""[Model]: Loading {model_params["MODEL"]}...\n""")

    # tokenizer for encoding the text
    tokenizer = T5Tokenizer.from_pretrained(model_params["MODEL"])

    # Load the pretrained model and move it to the selected device.
    model = T5ForConditionalGeneration.from_pretrained(model_params["MODEL"])
    model = model.to(device)

    # logging
    console.log("[Data]: Reading data...\n")

    # Keep only the two relevant columns
    dataframe = dataframe[[source_text, target_text]]
    display_df(dataframe.head(2))

    # 80% of the rows are sampled for training, the remainder is used for
    # validation; both frames get a fresh 0..n-1 index.
    train_size = 0.8
    train_dataset = dataframe.sample(frac=train_size, random_state=model_params["SEED"])
    val_dataset = dataframe.drop(train_dataset.index).reset_index(drop=True)
    train_dataset = train_dataset.reset_index(drop=True)

    console.print(f"FULL Dataset: {dataframe.shape}")
    console.print(f"TRAIN Dataset: {train_dataset.shape}")
    console.print(f"TEST Dataset: {val_dataset.shape}\n")

    # Creating the Training and Validation dataset for further creation of Dataloader
    training_set = YourDataSetClass(
        train_dataset,
        tokenizer,
        model_params["MAX_SOURCE_TEXT_LENGTH"],
        model_params["MAX_TARGET_TEXT_LENGTH"],
        source_text,
        target_text,
    )
    val_set = YourDataSetClass(
        val_dataset,
        tokenizer,
        model_params["MAX_SOURCE_TEXT_LENGTH"],
        model_params["MAX_TARGET_TEXT_LENGTH"],
        source_text,
        target_text,
    )

    # Defining the parameters for creation of dataloaders
    train_params = {
        "batch_size": model_params["TRAIN_BATCH_SIZE"],
        "shuffle": True,
        "num_workers": 0,
    }

    val_params = {
        "batch_size": model_params["VALID_BATCH_SIZE"],
        "shuffle": False,
        "num_workers": 0,
    }

    # Dataloaders for the training and validation stages
    training_loader = DataLoader(training_set, **train_params)
    val_loader = DataLoader(val_set, **val_params)

    # Optimizer used to tune the weights of the network during training
    optimizer = torch.optim.Adam(
        params=model.parameters(), lr=model_params["LEARNING_RATE"]
    )

    # Training loop
    console.log("[Initiating Fine Tuning]...\n")

    for epoch in range(model_params["TRAIN_EPOCHS"]):
        train(epoch, tokenizer, model, device, training_loader, optimizer)

    console.log("[Saving Model]...\n")
    # Saving the model after training
    path = os.path.join(output_dir, "model_files")
    model.save_pretrained(path)
    tokenizer.save_pretrained(path)

    # evaluating test dataset
    console.log("[Initiating Validation]...\n")
    for epoch in range(model_params["VAL_EPOCHS"]):
        predictions, actuals = validate(epoch, tokenizer, model, device, val_loader)
        final_df = pd.DataFrame({"Generated Text": predictions, "Actual Text": actuals})
        # each validation epoch overwrites the same file
        final_df.to_csv(os.path.join(output_dir, "predictions.csv"))

    console.save_text(os.path.join(output_dir, "logs.txt"))

    console.log("[Validation Completed.]\n")
    console.print(
        f"""[Model] Model saved @ {os.path.join(output_dir, "model_files")}\n"""
    )
    console.print(
        f"""[Validation] Generation on Validation data saved @ {os.path.join(output_dir,'predictions.csv')}\n"""
    )
    console.print(f"""[Logs] Logs saved @ {os.path.join(output_dir,'logs.txt')}\n""")
# Hyper-parameters read by T5Trainer.
model_params = dict(
    MODEL="t5-base",  # model_type: t5-base/t5-large
    TRAIN_BATCH_SIZE=8,  # training batch size
    VALID_BATCH_SIZE=8,  # validation batch size
    TRAIN_EPOCHS=3,  # number of training epochs
    VAL_EPOCHS=1,  # number of validation epochs
    LEARNING_RATE=1e-4,  # learning rate
    MAX_SOURCE_TEXT_LENGTH=512,  # max length of source text
    MAX_TARGET_TEXT_LENGTH=50,  # max length of target text
    SEED=42,  # set seed for reproducibility
)

# News summary dataset with two columns:
#   - text: long article content
#   - headlines: one line summary of news
path = "https://raw.githubusercontent.com/Shivanandroy/T5-Finetuning-PyTorch/main/data/news_summary.csv"
df = pd.read_csv(path)

# T5 is told which task to perform through a prefix on the input text;
# for summarization every article is prefixed with "summarize: ".
df["text"] = "summarize: " + df["text"]

# Fine-tune on article -> headline pairs and write all results to ./outputs
T5Trainer(
    df,
    "text",
    "headlines",
    model_params,
    output_dir="outputs",
)
30 | tensor(3.2338, device='cuda:0', grad_fn=)| 389 | | 0 | 40 | tensor(2.5963, device='cuda:0', grad_fn=)| 390 | | 1 | 0 | tensor(2.2411, device='cuda:0', grad_fn=)| 391 | | 1 | 10 | tensor(1.9470, device='cuda:0', grad_fn=)| 392 | | 1 | 20 | tensor(1.9091, device='cuda:0', grad_fn=)| 393 | | 1 | 30 | tensor(2.0122, device='cuda:0', grad_fn=)| 394 | | 1 | 40 | tensor(1.5261, device='cuda:0', grad_fn=)| 395 | | 2 | 0 | tensor(1.6496, device='cuda:0', grad_fn=)| 396 | | 2 | 10 | tensor(1.1971, device='cuda:0', grad_fn=)| 397 | | 2 | 20 | tensor(1.6908, device='cuda:0', grad_fn=)| 398 | | 2 | 30 | tensor(1.4069, device='cuda:0', grad_fn=)| 399 | | 2 | 40 | tensor(2.1261, device='cuda:0', grad_fn=)| 400 | +--------------------------------------------------------------------------+ 401 | ``` -------------------------------------------------------------------------------- /notebook/T5_Fine_tuning_with_PyTorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "T5 Fine tuning with PyTorch.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "authorship_tag": "ABX9TyM+sqU2Hgca8RM/Wjv+9kvQ", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "accelerator": "GPU" 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "metadata": { 22 | "id": "view-in-github", 23 | "colab_type": "text" 24 | }, 25 | "source": [ 26 | "\"Open" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "metadata": { 32 | "colab": { 33 | "base_uri": "https://localhost:8080/" 34 | }, 35 | "id": "Bbwl6E1E205R", 36 | "outputId": "7ab0c99b-8f25-4f48-a74b-ba54bf93c6bf" 37 | }, 38 | "source": [ 39 | "!pip install sentencepiece\r\n", 40 | "!pip install transformers\r\n", 41 | "!pip install rich[jupyter]" 42 | ], 43 | "execution_count": 1, 44 | "outputs": [ 45 | { 46 | 
"output_type": "stream", 47 | "text": [ 48 | "Requirement already satisfied: sentencepiece in /usr/local/lib/python3.6/dist-packages (0.1.95)\n", 49 | "Requirement already satisfied: transformers in /usr/local/lib/python3.6/dist-packages (4.3.2)\n", 50 | "Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from transformers) (3.4.0)\n", 51 | "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers) (2019.12.20)\n", 52 | "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.6/dist-packages (from transformers) (4.41.1)\n", 53 | "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.6/dist-packages (from transformers) (1.19.5)\n", 54 | "Requirement already satisfied: tokenizers<0.11,>=0.10.1 in /usr/local/lib/python3.6/dist-packages (from transformers) (0.10.1)\n", 55 | "Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from transformers) (3.0.12)\n", 56 | "Requirement already satisfied: packaging in /usr/local/lib/python3.6/dist-packages (from transformers) (20.9)\n", 57 | "Requirement already satisfied: dataclasses; python_version < \"3.7\" in /usr/local/lib/python3.6/dist-packages (from transformers) (0.8)\n", 58 | "Requirement already satisfied: sacremoses in /usr/local/lib/python3.6/dist-packages (from transformers) (0.0.43)\n", 59 | "Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers) (2.23.0)\n", 60 | "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version < \"3.8\"->transformers) (3.4.0)\n", 61 | "Requirement already satisfied: typing-extensions>=3.6.4; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version < \"3.8\"->transformers) (3.7.4.3)\n", 62 | "Requirement already satisfied: pyparsing>=2.0.2 in 
/usr/local/lib/python3.6/dist-packages (from packaging->transformers) (2.4.7)\n", 63 | "Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (1.0.0)\n", 64 | "Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (7.1.2)\n", 65 | "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (1.15.0)\n", 66 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2020.12.5)\n", 67 | "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (3.0.4)\n", 68 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (1.24.3)\n", 69 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2.10)\n", 70 | "Requirement already satisfied: rich[jupyter] in /usr/local/lib/python3.6/dist-packages (9.10.0)\n", 71 | "Requirement already satisfied: commonmark<0.10.0,>=0.9.0 in /usr/local/lib/python3.6/dist-packages (from rich[jupyter]) (0.9.1)\n", 72 | "Requirement already satisfied: pygments<3.0.0,>=2.6.0 in /usr/local/lib/python3.6/dist-packages (from rich[jupyter]) (2.6.1)\n", 73 | "Requirement already satisfied: dataclasses<0.9,>=0.7; python_version >= \"3.6\" and python_version < \"3.7\" in /usr/local/lib/python3.6/dist-packages (from rich[jupyter]) (0.8)\n", 74 | "Requirement already satisfied: typing-extensions<4.0.0,>=3.7.4 in /usr/local/lib/python3.6/dist-packages (from rich[jupyter]) (3.7.4.3)\n", 75 | "Requirement already satisfied: colorama<0.5.0,>=0.4.0 in /usr/local/lib/python3.6/dist-packages (from rich[jupyter]) (0.4.4)\n", 76 | "Requirement already satisfied: ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\" in 
/usr/local/lib/python3.6/dist-packages (from rich[jupyter]) (7.6.3)\n", 77 | "Requirement already satisfied: nbformat>=4.2.0 in /usr/local/lib/python3.6/dist-packages (from ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (5.1.2)\n", 78 | "Requirement already satisfied: jupyterlab-widgets>=1.0.0; python_version >= \"3.6\" in /usr/local/lib/python3.6/dist-packages (from ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (1.0.0)\n", 79 | "Requirement already satisfied: ipython>=4.0.0; python_version >= \"3.3\" in /usr/local/lib/python3.6/dist-packages (from ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (5.5.0)\n", 80 | "Requirement already satisfied: widgetsnbextension~=3.5.0 in /usr/local/lib/python3.6/dist-packages (from ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (3.5.1)\n", 81 | "Requirement already satisfied: traitlets>=4.3.1 in /usr/local/lib/python3.6/dist-packages (from ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (4.3.3)\n", 82 | "Requirement already satisfied: ipykernel>=4.5.1 in /usr/local/lib/python3.6/dist-packages (from ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (4.10.1)\n", 83 | "Requirement already satisfied: jupyter-core in /usr/local/lib/python3.6/dist-packages (from nbformat>=4.2.0->ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (4.7.1)\n", 84 | "Requirement already satisfied: jsonschema!=2.5.0,>=2.4 in /usr/local/lib/python3.6/dist-packages (from nbformat>=4.2.0->ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (2.6.0)\n", 85 | "Requirement already satisfied: ipython-genutils in /usr/local/lib/python3.6/dist-packages (from nbformat>=4.2.0->ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (0.2.0)\n", 86 | "Requirement already satisfied: decorator in /usr/local/lib/python3.6/dist-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) 
(4.4.2)\n", 87 | "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.6/dist-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (53.0.0)\n", 88 | "Requirement already satisfied: prompt-toolkit<2.0.0,>=1.0.4 in /usr/local/lib/python3.6/dist-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (1.0.18)\n", 89 | "Requirement already satisfied: simplegeneric>0.8 in /usr/local/lib/python3.6/dist-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (0.8.1)\n", 90 | "Requirement already satisfied: pickleshare in /usr/local/lib/python3.6/dist-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (0.7.5)\n", 91 | "Requirement already satisfied: pexpect; sys_platform != \"win32\" in /usr/local/lib/python3.6/dist-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (4.8.0)\n", 92 | "Requirement already satisfied: notebook>=4.4.1 in /usr/local/lib/python3.6/dist-packages (from widgetsnbextension~=3.5.0->ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (5.3.1)\n", 93 | "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from traitlets>=4.3.1->ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (1.15.0)\n", 94 | "Requirement already satisfied: tornado>=4.0 in /usr/local/lib/python3.6/dist-packages (from ipykernel>=4.5.1->ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (5.1.1)\n", 95 | "Requirement already satisfied: jupyter-client in /usr/local/lib/python3.6/dist-packages (from ipykernel>=4.5.1->ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (5.3.5)\n", 96 | "Requirement already satisfied: wcwidth in /usr/local/lib/python3.6/dist-packages 
(from prompt-toolkit<2.0.0,>=1.0.4->ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (0.2.5)\n", 97 | "Requirement already satisfied: ptyprocess>=0.5 in /usr/local/lib/python3.6/dist-packages (from pexpect; sys_platform != \"win32\"->ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (0.7.0)\n", 98 | "Requirement already satisfied: nbconvert in /usr/local/lib/python3.6/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (5.6.1)\n", 99 | "Requirement already satisfied: Send2Trash in /usr/local/lib/python3.6/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (1.5.0)\n", 100 | "Requirement already satisfied: jinja2 in /usr/local/lib/python3.6/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (2.11.3)\n", 101 | "Requirement already satisfied: terminado>=0.8.1 in /usr/local/lib/python3.6/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (0.9.2)\n", 102 | "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from jupyter-client->ipykernel>=4.5.1->ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (2.8.1)\n", 103 | "Requirement already satisfied: pyzmq>=13 in /usr/local/lib/python3.6/dist-packages (from jupyter-client->ipykernel>=4.5.1->ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (22.0.2)\n", 104 | "Requirement already satisfied: entrypoints>=0.2.2 in /usr/local/lib/python3.6/dist-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (0.3)\n", 105 | "Requirement already satisfied: bleach in 
/usr/local/lib/python3.6/dist-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (3.3.0)\n", 106 | "Requirement already satisfied: mistune<2,>=0.8.1 in /usr/local/lib/python3.6/dist-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (0.8.4)\n", 107 | "Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/lib/python3.6/dist-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (1.4.3)\n", 108 | "Requirement already satisfied: defusedxml in /usr/local/lib/python3.6/dist-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (0.6.0)\n", 109 | "Requirement already satisfied: testpath in /usr/local/lib/python3.6/dist-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (0.4.4)\n", 110 | "Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.6/dist-packages (from jinja2->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (1.1.1)\n", 111 | "Requirement already satisfied: packaging in /usr/local/lib/python3.6/dist-packages (from bleach->nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (20.9)\n", 112 | "Requirement already satisfied: webencodings in /usr/local/lib/python3.6/dist-packages (from bleach->nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (0.5.1)\n", 113 | "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from 
packaging->bleach->nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"->rich[jupyter]) (2.4.7)\n" 114 | ], 115 | "name": "stdout" 116 | } 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "metadata": { 122 | "id": "y6nEben93JAk" 123 | }, 124 | "source": [ 125 | "import pandas as pd\r\n", 126 | "df = pd.read_csv(\"https://raw.githubusercontent.com/Shivanandroy/T5-Finetuning-PyTorch/main/data/news_summary.csv\")" 127 | ], 128 | "execution_count": 2, 129 | "outputs": [] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "metadata": { 134 | "colab": { 135 | "base_uri": "https://localhost:8080/", 136 | "height": 359 137 | }, 138 | "id": "Suxgy7wC4IqL", 139 | "outputId": "5725be18-918f-4fad-c6f4-587fd50899af" 140 | }, 141 | "source": [ 142 | "df.sample(10)" 143 | ], 144 | "execution_count": 3, 145 | "outputs": [ 146 | { 147 | "output_type": "execute_result", 148 | "data": { 149 | "text/html": [ 150 | "
\n", 151 | "\n", 164 | "\n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | "
headlinestext
23331US flag replaced with Nazi flag at Wyoming parkAn investigation is underway after someone rep...
60219IOC petrol pump in Nagpur gets electric chargi...The Indian Oil Corporation, in collaboration w...
75911Vitamin B3 could reduce miscarriages, birth de...In a first, a 12-year research by Australian s...
57713US airline honours veteran by putting his name...United Airlines honoured veteran Mark Lehman, ...
93367Emma Watson auditioned eight times for role in...British actress Emma Watson auditioned for the...
26177IDBI Bank officers threaten 6-day strike from ...A section of IDBI Bank officers has threatened...
28164UP CM Yogi refuses to wear cap at Sant Kabir's...Uttar Pradesh CM Yogi Adityanath on Wednesday ...
94552Researchers develop self-learning artificial n...France-based researchers have developed a syna...
36972Farmers protest as banker seeks sexual favour ...Farmers in Maharashtra on Saturday staged a pr...
67957Maharashtra plans to convert INS Viraat into n...The Maharashtra government is planning to acqu...
\n", 225 | "
" 226 | ], 227 | "text/plain": [ 228 | " headlines text\n", 229 | "23331 US flag replaced with Nazi flag at Wyoming park An investigation is underway after someone rep...\n", 230 | "60219 IOC petrol pump in Nagpur gets electric chargi... The Indian Oil Corporation, in collaboration w...\n", 231 | "75911 Vitamin B3 could reduce miscarriages, birth de... In a first, a 12-year research by Australian s...\n", 232 | "57713 US airline honours veteran by putting his name... United Airlines honoured veteran Mark Lehman, ...\n", 233 | "93367 Emma Watson auditioned eight times for role in... British actress Emma Watson auditioned for the...\n", 234 | "26177 IDBI Bank officers threaten 6-day strike from ... A section of IDBI Bank officers has threatened...\n", 235 | "28164 UP CM Yogi refuses to wear cap at Sant Kabir's... Uttar Pradesh CM Yogi Adityanath on Wednesday ...\n", 236 | "94552 Researchers develop self-learning artificial n... France-based researchers have developed a syna...\n", 237 | "36972 Farmers protest as banker seeks sexual favour ... Farmers in Maharashtra on Saturday staged a pr...\n", 238 | "67957 Maharashtra plans to convert INS Viraat into n... The Maharashtra government is planning to acqu..." 239 | ] 240 | }, 241 | "metadata": { 242 | "tags": [] 243 | }, 244 | "execution_count": 3 245 | } 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "metadata": { 251 | "id": "AYfBicZQ59Jf" 252 | }, 253 | "source": [ 254 | "df[\"text\"] = \"summarize: \"+df[\"text\"]" 255 | ], 256 | "execution_count": 4, 257 | "outputs": [] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "metadata": { 262 | "id": "81f4PKa1F6aM", 263 | "colab": { 264 | "base_uri": "https://localhost:8080/", 265 | "height": 204 266 | }, 267 | "outputId": "fcf57854-d194-4670-c783-35da1574ec5c" 268 | }, 269 | "source": [ 270 | "df.head()" 271 | ], 272 | "execution_count": 5, 273 | "outputs": [ 274 | { 275 | "output_type": "execute_result", 276 | "data": { 277 | "text/html": [ 278 | "
\n", 279 | "\n", 292 | "\n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | "
headlinestext
0upGrad learner switches to career in ML & Al w...summarize: Saurav Kant, an alumnus of upGrad a...
1Delhi techie wins free food from Swiggy for on...summarize: Kunal Shah's credit card bill payme...
2New Zealand end Rohit Sharma-led India's 12-ma...summarize: New Zealand defeated India by 8 wic...
3Aegon life iTerm insurance plan helps customer...summarize: With Aegon Life iTerm Insurance pla...
4Have known Hirani for yrs, what if MeToo claim...summarize: Speaking about the sexual harassmen...
\n", 328 | "
" 329 | ], 330 | "text/plain": [ 331 | " headlines text\n", 332 | "0 upGrad learner switches to career in ML & Al w... summarize: Saurav Kant, an alumnus of upGrad a...\n", 333 | "1 Delhi techie wins free food from Swiggy for on... summarize: Kunal Shah's credit card bill payme...\n", 334 | "2 New Zealand end Rohit Sharma-led India's 12-ma... summarize: New Zealand defeated India by 8 wic...\n", 335 | "3 Aegon life iTerm insurance plan helps customer... summarize: With Aegon Life iTerm Insurance pla...\n", 336 | "4 Have known Hirani for yrs, what if MeToo claim... summarize: Speaking about the sexual harassmen..." 337 | ] 338 | }, 339 | "metadata": { 340 | "tags": [] 341 | }, 342 | "execution_count": 5 343 | } 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "metadata": { 349 | "id": "wB441x104K-o" 350 | }, 351 | "source": [ 352 | "# Importing libraries\r\n", 353 | "import os\r\n", 354 | "import numpy as np\r\n", 355 | "import pandas as pd\r\n", 356 | "import torch\r\n", 357 | "import torch.nn.functional as F\r\n", 358 | "from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler\r\n", 359 | "import os\r\n", 360 | "\r\n", 361 | "# Importing the T5 modules from huggingface/transformers\r\n", 362 | "from transformers import T5Tokenizer, T5ForConditionalGeneration\r\n", 363 | "\r\n", 364 | "from rich.table import Column, Table\r\n", 365 | "from rich import box\r\n", 366 | "from rich.console import Console\r\n", 367 | "\r\n", 368 | "# define a rich console logger\r\n", 369 | "console=Console(record=True)\r\n", 370 | "\r\n", 371 | "def display_df(df):\r\n", 372 | " \"\"\"display dataframe in ASCII format\"\"\"\r\n", 373 | "\r\n", 374 | " console=Console()\r\n", 375 | " table = Table(Column(\"source_text\", justify=\"center\" ), Column(\"target_text\", justify=\"center\"), title=\"Sample Data\",pad_edge=False, box=box.ASCII)\r\n", 376 | "\r\n", 377 | " for i, row in enumerate(df.values.tolist()):\r\n", 378 | " table.add_row(row[0], 
row[1])\r\n", 379 | "\r\n", 380 | " console.print(table)\r\n", 381 | "\r\n", 382 | "training_logger = Table(Column(\"Epoch\", justify=\"center\" ), \r\n", 383 | " Column(\"Steps\", justify=\"center\"),\r\n", 384 | " Column(\"Loss\", justify=\"center\"), \r\n", 385 | " title=\"Training Status\",pad_edge=False, box=box.ASCII)\r\n" 386 | ], 387 | "execution_count": 6, 388 | "outputs": [] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "metadata": { 393 | "id": "tlYaKW9h4ai_" 394 | }, 395 | "source": [ 396 | "# Setting up the device for GPU usage\r\n", 397 | "from torch import cuda\r\n", 398 | "device = 'cuda' if cuda.is_available() else 'cpu'" 399 | ], 400 | "execution_count": 7, 401 | "outputs": [] 402 | }, 403 | { 404 | "cell_type": "code", 405 | "metadata": { 406 | "id": "8vLQPGAn4v17" 407 | }, 408 | "source": [ 409 | "class YourDataSetClass(Dataset):\r\n", 410 | " \"\"\"\r\n", 411 | " Creating a custom dataset for reading the dataset and \r\n", 412 | " loading it into the dataloader to pass it to the neural network for finetuning the model\r\n", 413 | "\r\n", 414 | " \"\"\"\r\n", 415 | "\r\n", 416 | " def __init__(self, dataframe, tokenizer, source_len, target_len, source_text, target_text):\r\n", 417 | " self.tokenizer = tokenizer\r\n", 418 | " self.data = dataframe\r\n", 419 | " self.source_len = source_len\r\n", 420 | " self.summ_len = target_len\r\n", 421 | " self.target_text = self.data[target_text]\r\n", 422 | " self.source_text = self.data[source_text]\r\n", 423 | "\r\n", 424 | " def __len__(self):\r\n", 425 | " return len(self.target_text)\r\n", 426 | "\r\n", 427 | " def __getitem__(self, index):\r\n", 428 | " source_text = str(self.source_text[index])\r\n", 429 | " target_text = str(self.target_text[index])\r\n", 430 | "\r\n", 431 | " #cleaning data so as to ensure data is in string type\r\n", 432 | " source_text = ' '.join(source_text.split())\r\n", 433 | " target_text = ' '.join(target_text.split())\r\n", 434 | "\r\n", 435 | " source = 
self.tokenizer.batch_encode_plus([source_text], max_length= self.source_len, pad_to_max_length=True, truncation=True, padding=\"max_length\", return_tensors='pt')\r\n", 436 | " target = self.tokenizer.batch_encode_plus([target_text], max_length= self.summ_len, pad_to_max_length=True, truncation=True, padding=\"max_length\", return_tensors='pt')\r\n", 437 | "\r\n", 438 | " source_ids = source['input_ids'].squeeze()\r\n", 439 | " source_mask = source['attention_mask'].squeeze()\r\n", 440 | " target_ids = target['input_ids'].squeeze()\r\n", 441 | " target_mask = target['attention_mask'].squeeze()\r\n", 442 | "\r\n", 443 | " return {\r\n", 444 | " 'source_ids': source_ids.to(dtype=torch.long), \r\n", 445 | " 'source_mask': source_mask.to(dtype=torch.long), \r\n", 446 | " 'target_ids': target_ids.to(dtype=torch.long),\r\n", 447 | " 'target_ids_y': target_ids.to(dtype=torch.long)\r\n", 448 | " }" 449 | ], 450 | "execution_count": 8, 451 | "outputs": [] 452 | }, 453 | { 454 | "cell_type": "code", 455 | "metadata": { 456 | "id": "Nkj6wIMt40RK" 457 | }, 458 | "source": [ 459 | "def train(epoch, tokenizer, model, device, loader, optimizer):\r\n", 460 | "\r\n", 461 | " \"\"\"\r\n", 462 | " Function to be called for training with the parameters passed from main function\r\n", 463 | "\r\n", 464 | " \"\"\"\r\n", 465 | "\r\n", 466 | " model.train()\r\n", 467 | " for _,data in enumerate(loader, 0):\r\n", 468 | " y = data['target_ids'].to(device, dtype = torch.long)\r\n", 469 | " y_ids = y[:, :-1].contiguous()\r\n", 470 | " lm_labels = y[:, 1:].clone().detach()\r\n", 471 | " lm_labels[y[:, 1:] == tokenizer.pad_token_id] = -100\r\n", 472 | " ids = data['source_ids'].to(device, dtype = torch.long)\r\n", 473 | " mask = data['source_mask'].to(device, dtype = torch.long)\r\n", 474 | "\r\n", 475 | " outputs = model(input_ids = ids, attention_mask = mask, decoder_input_ids=y_ids, labels=lm_labels)\r\n", 476 | " loss = outputs[0]\r\n", 477 | "\r\n", 478 | " if _%10==0:\r\n", 479 | " 
training_logger.add_row(str(epoch), str(_), str(loss))\r\n", 480 | " console.print(training_logger)\r\n", 481 | "\r\n", 482 | " optimizer.zero_grad()\r\n", 483 | " loss.backward()\r\n", 484 | " optimizer.step()" 485 | ], 486 | "execution_count": 9, 487 | "outputs": [] 488 | }, 489 | { 490 | "cell_type": "code", 491 | "metadata": { 492 | "id": "GUBykK-A43DF" 493 | }, 494 | "source": [ 495 | "def validate(epoch, tokenizer, model, device, loader):\r\n", 496 | "\r\n", 497 | " \"\"\"\r\n", 498 | " Function to evaluate model for predictions\r\n", 499 | "\r\n", 500 | " \"\"\"\r\n", 501 | " model.eval()\r\n", 502 | " predictions = []\r\n", 503 | " actuals = []\r\n", 504 | " with torch.no_grad():\r\n", 505 | " for _, data in enumerate(loader, 0):\r\n", 506 | " y = data['target_ids'].to(device, dtype = torch.long)\r\n", 507 | " ids = data['source_ids'].to(device, dtype = torch.long)\r\n", 508 | " mask = data['source_mask'].to(device, dtype = torch.long)\r\n", 509 | "\r\n", 510 | " generated_ids = model.generate(\r\n", 511 | " input_ids = ids,\r\n", 512 | " attention_mask = mask, \r\n", 513 | " max_length=150, \r\n", 514 | " num_beams=2,\r\n", 515 | " repetition_penalty=2.5, \r\n", 516 | " length_penalty=1.0, \r\n", 517 | " early_stopping=True\r\n", 518 | " )\r\n", 519 | " preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]\r\n", 520 | " target = [tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True)for t in y]\r\n", 521 | " if _%10==0:\r\n", 522 | " console.print(f'Completed {_}')\r\n", 523 | "\r\n", 524 | " predictions.extend(preds)\r\n", 525 | " actuals.extend(target)\r\n", 526 | " return predictions, actuals" 527 | ], 528 | "execution_count": 10, 529 | "outputs": [] 530 | }, 531 | { 532 | "cell_type": "code", 533 | "metadata": { 534 | "id": "V5L4wr3h4612" 535 | }, 536 | "source": [ 537 | "" 538 | ], 539 | "execution_count": 10, 540 | "outputs": [] 541 | }, 542 | { 543 | "cell_type": 
"code", 544 | "metadata": { 545 | "id": "Tw4RW_qO4_8T" 546 | }, 547 | "source": [ 548 | "def T5Trainer(dataframe, source_text, target_text, model_params, output_dir=\"./outputs/\" ):\r\n", 549 | " \r\n", 550 | " \"\"\"\r\n", 551 | " T5 trainer\r\n", 552 | "\r\n", 553 | " \"\"\"\r\n", 554 | "\r\n", 555 | " # Set random seeds and deterministic pytorch for reproducibility\r\n", 556 | " torch.manual_seed(model_params[\"SEED\"]) # pytorch random seed\r\n", 557 | " np.random.seed(model_params[\"SEED\"]) # numpy random seed\r\n", 558 | " torch.backends.cudnn.deterministic = True\r\n", 559 | "\r\n", 560 | " # logging\r\n", 561 | " console.log(f\"\"\"[Model]: Loading {model_params[\"MODEL\"]}...\\n\"\"\")\r\n", 562 | "\r\n", 563 | " # tokenizer for encoding the text\r\n", 564 | " tokenizer = T5Tokenizer.from_pretrained(model_params[\"MODEL\"])\r\n", 565 | "\r\n", 566 | " # Defining the model. We are using t5-base model and added a Language model layer on top for generation of Summary. \r\n", 567 | " # Further this model is sent to device (GPU/TPU) for using the hardware.\r\n", 568 | " model = T5ForConditionalGeneration.from_pretrained(model_params[\"MODEL\"])\r\n", 569 | " model = model.to(device)\r\n", 570 | " \r\n", 571 | " # logging\r\n", 572 | " console.log(f\"[Data]: Reading data...\\n\")\r\n", 573 | "\r\n", 574 | " # Importing the raw dataset\r\n", 575 | " dataframe = dataframe[[source_text,target_text]]\r\n", 576 | " display_df(dataframe.head(2))\r\n", 577 | "\r\n", 578 | " \r\n", 579 | " # Creation of Dataset and Dataloader\r\n", 580 | " # Defining the train size. So 80% of the data will be used for training and the rest for validation. 
\r\n", 581 | " train_size = 0.8\r\n", 582 | " train_dataset=dataframe.sample(frac=train_size,random_state = model_params[\"SEED\"])\r\n", 583 | " val_dataset=dataframe.drop(train_dataset.index).reset_index(drop=True)\r\n", 584 | " train_dataset = train_dataset.reset_index(drop=True)\r\n", 585 | "\r\n", 586 | " console.print(f\"FULL Dataset: {dataframe.shape}\")\r\n", 587 | " console.print(f\"TRAIN Dataset: {train_dataset.shape}\")\r\n", 588 | " console.print(f\"TEST Dataset: {val_dataset.shape}\\n\")\r\n", 589 | "\r\n", 590 | "\r\n", 591 | " # Creating the Training and Validation dataset for further creation of Dataloader\r\n", 592 | " training_set = YourDataSetClass(train_dataset, tokenizer, model_params[\"MAX_SOURCE_TEXT_LENGTH\"], model_params[\"MAX_TARGET_TEXT_LENGTH\"], source_text, target_text)\r\n", 593 | " val_set = YourDataSetClass(val_dataset, tokenizer, model_params[\"MAX_SOURCE_TEXT_LENGTH\"], model_params[\"MAX_TARGET_TEXT_LENGTH\"], source_text, target_text)\r\n", 594 | "\r\n", 595 | "\r\n", 596 | " # Defining the parameters for creation of dataloaders\r\n", 597 | " train_params = {\r\n", 598 | " 'batch_size': model_params[\"TRAIN_BATCH_SIZE\"],\r\n", 599 | " 'shuffle': True,\r\n", 600 | " 'num_workers': 0\r\n", 601 | " }\r\n", 602 | "\r\n", 603 | "\r\n", 604 | " val_params = {\r\n", 605 | " 'batch_size': model_params[\"VALID_BATCH_SIZE\"],\r\n", 606 | " 'shuffle': False,\r\n", 607 | " 'num_workers': 0\r\n", 608 | " }\r\n", 609 | "\r\n", 610 | "\r\n", 611 | " # Creation of Dataloaders for training and validation. These will be used below in the training and validation stages of the model.\r\n", 612 | " training_loader = DataLoader(training_set, **train_params)\r\n", 613 | " val_loader = DataLoader(val_set, **val_params)\r\n", 614 | "\r\n", 615 | "\r\n", 616 | " # Defining the optimizer that will be used to tune the weights of the network in the training session. 
\r\n", 617 | " optimizer = torch.optim.Adam(params = model.parameters(), lr=model_params[\"LEARNING_RATE\"])\r\n", 618 | "\r\n", 619 | "\r\n", 620 | " # Training loop\r\n", 621 | " console.log(f'[Initiating Fine Tuning]...\\n')\r\n", 622 | "\r\n", 623 | " for epoch in range(model_params[\"TRAIN_EPOCHS\"]):\r\n", 624 | " train(epoch, tokenizer, model, device, training_loader, optimizer)\r\n", 625 | " \r\n", 626 | " console.log(f\"[Saving Model]...\\n\")\r\n", 627 | " #Saving the model after training\r\n", 628 | " path = os.path.join(output_dir, \"model_files\")\r\n", 629 | " model.save_pretrained(path)\r\n", 630 | " tokenizer.save_pretrained(path)\r\n", 631 | "\r\n", 632 | "\r\n", 633 | " # evaluating test dataset\r\n", 634 | " console.log(f\"[Initiating Validation]...\\n\")\r\n", 635 | " for epoch in range(model_params[\"VAL_EPOCHS\"]):\r\n", 636 | " predictions, actuals = validate(epoch, tokenizer, model, device, val_loader)\r\n", 637 | " final_df = pd.DataFrame({'Generated Text':predictions,'Actual Text':actuals})\r\n", 638 | " final_df.to_csv(os.path.join(output_dir,'predictions.csv'))\r\n", 639 | " \r\n", 640 | " console.save_text(os.path.join(output_dir,'logs.txt'))\r\n", 641 | " \r\n", 642 | " console.log(f\"[Validation Completed.]\\n\")\r\n", 643 | " console.print(f\"\"\"[Model] Model saved @ {os.path.join(output_dir, \"model_files\")}\\n\"\"\")\r\n", 644 | " console.print(f\"\"\"[Validation] Generation on Validation data saved @ {os.path.join(output_dir,'predictions.csv')}\\n\"\"\")\r\n", 645 | " console.print(f\"\"\"[Logs] Logs saved @ {os.path.join(output_dir,'logs.txt')}\\n\"\"\")" 646 | ], 647 | "execution_count": 11, 648 | "outputs": [] 649 | }, 650 | { 651 | "cell_type": "code", 652 | "metadata": { 653 | "id": "PxCpQwD8PDIs" 654 | }, 655 | "source": [ 656 | "model_params={\r\n", 657 | " \"MODEL\":\"t5-base\", # model_type: t5-base/t5-large\r\n", 658 | " \"TRAIN_BATCH_SIZE\":8, # training batch size\r\n", 659 | " \"VALID_BATCH_SIZE\":8, # validation 
batch size\r\n", 660 | " \"TRAIN_EPOCHS\":3, # number of training epochs\r\n", 661 | " \"VAL_EPOCHS\":1, # number of validation epochs\r\n", 662 | " \"LEARNING_RATE\":1e-4, # learning rate\r\n", 663 | " \"MAX_SOURCE_TEXT_LENGTH\":512, # max length of source text\r\n", 664 | " \"MAX_TARGET_TEXT_LENGTH\":50, # max length of target text\r\n", 665 | " \"SEED\": 42 # set seed for reproducibility \r\n", 666 | "\r\n", 667 | "}" 668 | ], 669 | "execution_count": 12, 670 | "outputs": [] 671 | }, 672 | { 673 | "cell_type": "code", 674 | "metadata": { 675 | "colab": { 676 | "base_uri": "https://localhost:8080/" 677 | }, 678 | "id": "qijZoYeI55fM", 679 | "outputId": "69c68bb6-4fba-47e4-9e74-73f2579aa3c8" 680 | }, 681 | "source": [ 682 | "T5Trainer(dataframe=df[:500], source_text=\"text\", target_text=\"headlines\", model_params=model_params, output_dir=\"outputs\")" 683 | ], 684 | "execution_count": 13, 685 | "outputs": [ 686 | { 687 | "output_type": "stream", 688 | "text": [ 689 | "[12:10:09] [Model]: Loading t5-base... :14\n", 690 | " \n", 691 | "[12:10:21] [Data]: Reading data... :25\n", 692 | " \n", 693 | " Sample Data \n", 694 | "+------------------------------------------------------------------------------+\n", 695 | "| source_text | target_text |\n", 696 | "|--------------------------------------+---------------------------------------|\n", 697 | "|summarize: Saurav Kant, an alumnus of | upGrad learner switches to career in |\n", 698 | "| upGrad and IIIT-B's PG Program in | ML & Al with 90% salary hike |\n", 699 | "| Machine learning and Artificial | |\n", 700 | "| Intelligence, was a Sr Systems | |\n", 701 | "| Engineer at Infosys with almost 5 | |\n", 702 | "|years of work experience. The program | |\n", 703 | "| and upGrad's 360-degree career | |\n", 704 | "| support helped him transition to a | |\n", 705 | "|Data Scientist at Tech Mahindra with | |\n", 706 | "| 90% salary hike. 
upGrad's Online | |\n", 707 | "| Power Learning has powered 3 lakh+ | |\n", 708 | "| careers. | |\n", 709 | "| summarize: Kunal Shah's credit card | Delhi techie wins free food from |\n", 710 | "| bill payment platform, CRED, gave | Swiggy for one year on CRED |\n", 711 | "|users a chance to win free food from | |\n", 712 | "|Swiggy for one year. Pranav Kaushik, | |\n", 713 | "| a Delhi techie, bagged this reward | |\n", 714 | "|after spending 2000 CRED coins. Users | |\n", 715 | "| get one CRED coin per rupee of bill | |\n", 716 | "| paid, which can be used to avail | |\n", 717 | "| rewards from brands like Ixigo, | |\n", 718 | "| BookMyShow, UberEats, Cult.Fit and | |\n", 719 | "| more. | |\n", 720 | "+------------------------------------------------------------------------------+\n", 721 | "FULL Dataset: (500, 2)\n", 722 | "TRAIN Dataset: (400, 2)\n", 723 | "TEST Dataset: (100, 2)\n", 724 | "\n", 725 | " [Initiating Fine Tuning]... :74\n", 726 | " \n", 727 | " Training Status \n", 728 | "+--------------------------------------------------------------------------+\n", 729 | "|Epoch | Steps | Loss |\n", 730 | "|------+-------+-----------------------------------------------------------|\n", 731 | "| 0 | 0 | tensor(8.5338, device='cuda:0', grad_fn=)|\n", 732 | "+--------------------------------------------------------------------------+\n", 733 | " Training Status \n", 734 | "+--------------------------------------------------------------------------+\n", 735 | "|Epoch | Steps | Loss |\n", 736 | "|------+-------+-----------------------------------------------------------|\n", 737 | "| 0 | 0 | tensor(8.5338, device='cuda:0', grad_fn=)|\n", 738 | "| 0 | 10 | tensor(3.4278, device='cuda:0', grad_fn=)|\n", 739 | "+--------------------------------------------------------------------------+\n", 740 | " Training Status \n", 741 | "+--------------------------------------------------------------------------+\n", 742 | "|Epoch | Steps | Loss |\n", 743 | 
"|------+-------+-----------------------------------------------------------|\n", 744 | "| 0 | 0 | tensor(8.5338, device='cuda:0', grad_fn=)|\n", 745 | "| 0 | 10 | tensor(3.4278, device='cuda:0', grad_fn=)|\n", 746 | "| 0 | 20 | tensor(3.0148, device='cuda:0', grad_fn=)|\n", 747 | "+--------------------------------------------------------------------------+\n", 748 | " Training Status \n", 749 | "+--------------------------------------------------------------------------+\n", 750 | "|Epoch | Steps | Loss |\n", 751 | "|------+-------+-----------------------------------------------------------|\n", 752 | "| 0 | 0 | tensor(8.5338, device='cuda:0', grad_fn=)|\n", 753 | "| 0 | 10 | tensor(3.4278, device='cuda:0', grad_fn=)|\n", 754 | "| 0 | 20 | tensor(3.0148, device='cuda:0', grad_fn=)|\n", 755 | "| 0 | 30 | tensor(3.2338, device='cuda:0', grad_fn=)|\n", 756 | "+--------------------------------------------------------------------------+\n", 757 | " Training Status \n", 758 | "+--------------------------------------------------------------------------+\n", 759 | "|Epoch | Steps | Loss |\n", 760 | "|------+-------+-----------------------------------------------------------|\n", 761 | "| 0 | 0 | tensor(8.5338, device='cuda:0', grad_fn=)|\n", 762 | "| 0 | 10 | tensor(3.4278, device='cuda:0', grad_fn=)|\n", 763 | "| 0 | 20 | tensor(3.0148, device='cuda:0', grad_fn=)|\n", 764 | "| 0 | 30 | tensor(3.2338, device='cuda:0', grad_fn=)|\n", 765 | "| 0 | 40 | tensor(2.5963, device='cuda:0', grad_fn=)|\n", 766 | "+--------------------------------------------------------------------------+\n", 767 | " Training Status \n", 768 | "+--------------------------------------------------------------------------+\n", 769 | "|Epoch | Steps | Loss |\n", 770 | "|------+-------+-----------------------------------------------------------|\n", 771 | "| 0 | 0 | tensor(8.5338, device='cuda:0', grad_fn=)|\n", 772 | "| 0 | 10 | tensor(3.4278, device='cuda:0', grad_fn=)|\n", 773 | "| 0 | 20 | 
tensor(3.0148, device='cuda:0', grad_fn=)|\n", 774 | "| 0 | 30 | tensor(3.2338, device='cuda:0', grad_fn=)|\n", 775 | "| 0 | 40 | tensor(2.5963, device='cuda:0', grad_fn=)|\n", 776 | "| 1 | 0 | tensor(2.2411, device='cuda:0', grad_fn=)|\n", 777 | "+--------------------------------------------------------------------------+\n", 778 | " Training Status \n", 779 | "+--------------------------------------------------------------------------+\n", 780 | "|Epoch | Steps | Loss |\n", 781 | "|------+-------+-----------------------------------------------------------|\n", 782 | "| 0 | 0 | tensor(8.5338, device='cuda:0', grad_fn=)|\n", 783 | "| 0 | 10 | tensor(3.4278, device='cuda:0', grad_fn=)|\n", 784 | "| 0 | 20 | tensor(3.0148, device='cuda:0', grad_fn=)|\n", 785 | "| 0 | 30 | tensor(3.2338, device='cuda:0', grad_fn=)|\n", 786 | "| 0 | 40 | tensor(2.5963, device='cuda:0', grad_fn=)|\n", 787 | "| 1 | 0 | tensor(2.2411, device='cuda:0', grad_fn=)|\n", 788 | "| 1 | 10 | tensor(1.9470, device='cuda:0', grad_fn=)|\n", 789 | "+--------------------------------------------------------------------------+\n", 790 | " Training Status \n", 791 | "+--------------------------------------------------------------------------+\n", 792 | "|Epoch | Steps | Loss |\n", 793 | "|------+-------+-----------------------------------------------------------|\n", 794 | "| 0 | 0 | tensor(8.5338, device='cuda:0', grad_fn=)|\n", 795 | "| 0 | 10 | tensor(3.4278, device='cuda:0', grad_fn=)|\n", 796 | "| 0 | 20 | tensor(3.0148, device='cuda:0', grad_fn=)|\n", 797 | "| 0 | 30 | tensor(3.2338, device='cuda:0', grad_fn=)|\n", 798 | "| 0 | 40 | tensor(2.5963, device='cuda:0', grad_fn=)|\n", 799 | "| 1 | 0 | tensor(2.2411, device='cuda:0', grad_fn=)|\n", 800 | "| 1 | 10 | tensor(1.9470, device='cuda:0', grad_fn=)|\n", 801 | "| 1 | 20 | tensor(1.9091, device='cuda:0', grad_fn=)|\n", 802 | "+--------------------------------------------------------------------------+\n", 803 | " Training Status \n", 804 | 
"+--------------------------------------------------------------------------+\n", 805 | "|Epoch | Steps | Loss |\n", 806 | "|------+-------+-----------------------------------------------------------|\n", 807 | "| 0 | 0 | tensor(8.5338, device='cuda:0', grad_fn=)|\n", 808 | "| 0 | 10 | tensor(3.4278, device='cuda:0', grad_fn=)|\n", 809 | "| 0 | 20 | tensor(3.0148, device='cuda:0', grad_fn=)|\n", 810 | "| 0 | 30 | tensor(3.2338, device='cuda:0', grad_fn=)|\n", 811 | "| 0 | 40 | tensor(2.5963, device='cuda:0', grad_fn=)|\n", 812 | "| 1 | 0 | tensor(2.2411, device='cuda:0', grad_fn=)|\n", 813 | "| 1 | 10 | tensor(1.9470, device='cuda:0', grad_fn=)|\n", 814 | "| 1 | 20 | tensor(1.9091, device='cuda:0', grad_fn=)|\n", 815 | "| 1 | 30 | tensor(2.0122, device='cuda:0', grad_fn=)|\n", 816 | "+--------------------------------------------------------------------------+\n", 817 | " Training Status \n", 818 | "+--------------------------------------------------------------------------+\n", 819 | "|Epoch | Steps | Loss |\n", 820 | "|------+-------+-----------------------------------------------------------|\n", 821 | "| 0 | 0 | tensor(8.5338, device='cuda:0', grad_fn=)|\n", 822 | "| 0 | 10 | tensor(3.4278, device='cuda:0', grad_fn=)|\n", 823 | "| 0 | 20 | tensor(3.0148, device='cuda:0', grad_fn=)|\n", 824 | "| 0 | 30 | tensor(3.2338, device='cuda:0', grad_fn=)|\n", 825 | "| 0 | 40 | tensor(2.5963, device='cuda:0', grad_fn=)|\n", 826 | "| 1 | 0 | tensor(2.2411, device='cuda:0', grad_fn=)|\n", 827 | "| 1 | 10 | tensor(1.9470, device='cuda:0', grad_fn=)|\n", 828 | "| 1 | 20 | tensor(1.9091, device='cuda:0', grad_fn=)|\n", 829 | "| 1 | 30 | tensor(2.0122, device='cuda:0', grad_fn=)|\n", 830 | "| 1 | 40 | tensor(1.5261, device='cuda:0', grad_fn=)|\n", 831 | "+--------------------------------------------------------------------------+\n", 832 | " Training Status \n", 833 | "+--------------------------------------------------------------------------+\n", 834 | "|Epoch | Steps | Loss 
|\n", 835 | "|------+-------+-----------------------------------------------------------|\n", 836 | "| 0 | 0 | tensor(8.5338, device='cuda:0', grad_fn=)|\n", 837 | "| 0 | 10 | tensor(3.4278, device='cuda:0', grad_fn=)|\n", 838 | "| 0 | 20 | tensor(3.0148, device='cuda:0', grad_fn=)|\n", 839 | "| 0 | 30 | tensor(3.2338, device='cuda:0', grad_fn=)|\n", 840 | "| 0 | 40 | tensor(2.5963, device='cuda:0', grad_fn=)|\n", 841 | "| 1 | 0 | tensor(2.2411, device='cuda:0', grad_fn=)|\n", 842 | "| 1 | 10 | tensor(1.9470, device='cuda:0', grad_fn=)|\n", 843 | "| 1 | 20 | tensor(1.9091, device='cuda:0', grad_fn=)|\n", 844 | "| 1 | 30 | tensor(2.0122, device='cuda:0', grad_fn=)|\n", 845 | "| 1 | 40 | tensor(1.5261, device='cuda:0', grad_fn=)|\n", 846 | "| 2 | 0 | tensor(1.6496, device='cuda:0', grad_fn=)|\n", 847 | "+--------------------------------------------------------------------------+\n", 848 | " Training Status \n", 849 | "+--------------------------------------------------------------------------+\n", 850 | "|Epoch | Steps | Loss |\n", 851 | "|------+-------+-----------------------------------------------------------|\n", 852 | "| 0 | 0 | tensor(8.5338, device='cuda:0', grad_fn=)|\n", 853 | "| 0 | 10 | tensor(3.4278, device='cuda:0', grad_fn=)|\n", 854 | "| 0 | 20 | tensor(3.0148, device='cuda:0', grad_fn=)|\n", 855 | "| 0 | 30 | tensor(3.2338, device='cuda:0', grad_fn=)|\n", 856 | "| 0 | 40 | tensor(2.5963, device='cuda:0', grad_fn=)|\n", 857 | "| 1 | 0 | tensor(2.2411, device='cuda:0', grad_fn=)|\n", 858 | "| 1 | 10 | tensor(1.9470, device='cuda:0', grad_fn=)|\n", 859 | "| 1 | 20 | tensor(1.9091, device='cuda:0', grad_fn=)|\n", 860 | "| 1 | 30 | tensor(2.0122, device='cuda:0', grad_fn=)|\n", 861 | "| 1 | 40 | tensor(1.5261, device='cuda:0', grad_fn=)|\n", 862 | "| 2 | 0 | tensor(1.6496, device='cuda:0', grad_fn=)|\n", 863 | "| 2 | 10 | tensor(1.1971, device='cuda:0', grad_fn=)|\n", 864 | "+--------------------------------------------------------------------------+\n", 
865 | " Training Status \n", 866 | "+--------------------------------------------------------------------------+\n", 867 | "|Epoch | Steps | Loss |\n", 868 | "|------+-------+-----------------------------------------------------------|\n", 869 | "| 0 | 0 | tensor(8.5338, device='cuda:0', grad_fn=)|\n", 870 | "| 0 | 10 | tensor(3.4278, device='cuda:0', grad_fn=)|\n", 871 | "| 0 | 20 | tensor(3.0148, device='cuda:0', grad_fn=)|\n", 872 | "| 0 | 30 | tensor(3.2338, device='cuda:0', grad_fn=)|\n", 873 | "| 0 | 40 | tensor(2.5963, device='cuda:0', grad_fn=)|\n", 874 | "| 1 | 0 | tensor(2.2411, device='cuda:0', grad_fn=)|\n", 875 | "| 1 | 10 | tensor(1.9470, device='cuda:0', grad_fn=)|\n", 876 | "| 1 | 20 | tensor(1.9091, device='cuda:0', grad_fn=)|\n", 877 | "| 1 | 30 | tensor(2.0122, device='cuda:0', grad_fn=)|\n", 878 | "| 1 | 40 | tensor(1.5261, device='cuda:0', grad_fn=)|\n", 879 | "| 2 | 0 | tensor(1.6496, device='cuda:0', grad_fn=)|\n", 880 | "| 2 | 10 | tensor(1.1971, device='cuda:0', grad_fn=)|\n", 881 | "| 2 | 20 | tensor(1.6908, device='cuda:0', grad_fn=)|\n", 882 | "+--------------------------------------------------------------------------+\n", 883 | " Training Status \n", 884 | "+--------------------------------------------------------------------------+\n", 885 | "|Epoch | Steps | Loss |\n", 886 | "|------+-------+-----------------------------------------------------------|\n", 887 | "| 0 | 0 | tensor(8.5338, device='cuda:0', grad_fn=)|\n", 888 | "| 0 | 10 | tensor(3.4278, device='cuda:0', grad_fn=)|\n", 889 | "| 0 | 20 | tensor(3.0148, device='cuda:0', grad_fn=)|\n", 890 | "| 0 | 30 | tensor(3.2338, device='cuda:0', grad_fn=)|\n", 891 | "| 0 | 40 | tensor(2.5963, device='cuda:0', grad_fn=)|\n", 892 | "| 1 | 0 | tensor(2.2411, device='cuda:0', grad_fn=)|\n", 893 | "| 1 | 10 | tensor(1.9470, device='cuda:0', grad_fn=)|\n", 894 | "| 1 | 20 | tensor(1.9091, device='cuda:0', grad_fn=)|\n", 895 | "| 1 | 30 | tensor(2.0122, device='cuda:0', grad_fn=)|\n", 896 | 
"| 1 | 40 | tensor(1.5261, device='cuda:0', grad_fn=)|\n", 897 | "| 2 | 0 | tensor(1.6496, device='cuda:0', grad_fn=)|\n", 898 | "| 2 | 10 | tensor(1.1971, device='cuda:0', grad_fn=)|\n", 899 | "| 2 | 20 | tensor(1.6908, device='cuda:0', grad_fn=)|\n", 900 | "| 2 | 30 | tensor(1.4069, device='cuda:0', grad_fn=)|\n", 901 | "+--------------------------------------------------------------------------+\n", 902 | " Training Status \n", 903 | "+--------------------------------------------------------------------------+\n", 904 | "|Epoch | Steps | Loss |\n", 905 | "|------+-------+-----------------------------------------------------------|\n", 906 | "| 0 | 0 | tensor(8.5338, device='cuda:0', grad_fn=)|\n", 907 | "| 0 | 10 | tensor(3.4278, device='cuda:0', grad_fn=)|\n", 908 | "| 0 | 20 | tensor(3.0148, device='cuda:0', grad_fn=)|\n", 909 | "| 0 | 30 | tensor(3.2338, device='cuda:0', grad_fn=)|\n", 910 | "| 0 | 40 | tensor(2.5963, device='cuda:0', grad_fn=)|\n", 911 | "| 1 | 0 | tensor(2.2411, device='cuda:0', grad_fn=)|\n", 912 | "| 1 | 10 | tensor(1.9470, device='cuda:0', grad_fn=)|\n", 913 | "| 1 | 20 | tensor(1.9091, device='cuda:0', grad_fn=)|\n", 914 | "| 1 | 30 | tensor(2.0122, device='cuda:0', grad_fn=)|\n", 915 | "| 1 | 40 | tensor(1.5261, device='cuda:0', grad_fn=)|\n", 916 | "| 2 | 0 | tensor(1.6496, device='cuda:0', grad_fn=)|\n", 917 | "| 2 | 10 | tensor(1.1971, device='cuda:0', grad_fn=)|\n", 918 | "| 2 | 20 | tensor(1.6908, device='cuda:0', grad_fn=)|\n", 919 | "| 2 | 30 | tensor(1.4069, device='cuda:0', grad_fn=)|\n", 920 | "| 2 | 40 | tensor(2.1261, device='cuda:0', grad_fn=)|\n", 921 | "+--------------------------------------------------------------------------+\n", 922 | "[12:12:58] [Saving Model]... :79\n", 923 | " \n", 924 | "[12:13:02] [Initiating Validation]... :87\n", 925 | " \n", 926 | "Completed 0\n", 927 | "Completed 10\n", 928 | "[12:13:41] [Validation Completed.] 
:95\n", 929 | " \n", 930 | "[Model] Model saved @ outputs/model_files\n", 931 | "\n", 932 | "[Validation] Validation data saved @ outputs/predictions.csv\n", 933 | "\n", 934 | "[Logs] Logs saved @ outputs/logs.txt\n", 935 | "\n" 936 | ], 937 | "name": "stdout" 938 | } 939 | ] 940 | }, 941 | { 942 | "cell_type": "code", 943 | "metadata": { 944 | "id": "XD2qL87Wsn19" 945 | }, 946 | "source": [ 947 | "" 948 | ], 949 | "execution_count": 14, 950 | "outputs": [] 951 | }, 952 | { 953 | "cell_type": "code", 954 | "metadata": { 955 | "id": "u7SFz-6usqym" 956 | }, 957 | "source": [ 958 | "" 959 | ], 960 | "execution_count": 14, 961 | "outputs": [] 962 | } 963 | ] 964 | } --------------------------------------------------------------------------------