├── output └── README.md ├── data └── README.md ├── generate.py ├── T5.py ├── dataset.py ├── post-processing.ipynb ├── README.md └── train.py /output/README.md: -------------------------------------------------------------------------------- 1 | ## My result for training 2 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | ## Please download train.tsv and dev.tsv below the link and put the files in this directory. 2 | 3 | https://drive.google.com/drive/folders/1n9yzmli8YLHq7bjQD1yvQRUCO0J1uv0w?usp=sharing 4 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import T5 3 | import dataset 4 | from torch.utils.data import DataLoader 5 | from tqdm import tqdm 6 | import os 7 | import random 8 | import numpy as np 9 | 10 | #Random Seed 11 | random_seed = 1 12 | torch.manual_seed(random_seed) 13 | torch.backends.cudnn.deterministic = True 14 | torch.backends.cudnn.benchmark = False 15 | np.random.seed(random_seed) 16 | random.seed(random_seed) 17 | torch.cuda.manual_seed(random_seed) 18 | 19 | # device 20 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 21 | 22 | # Data file path 23 | dev_path = 'data/test.tsv' 24 | output_path = 'output_question_infilling_and_delete_wh_backtranslation_fine' 25 | 26 | if not os.path.exists(output_path): 27 | os.makedir(output_path) 28 | 29 | # config 30 | batch_size = 16 31 | 32 | # model. 
tokenizer init 33 | model = T5.T5ConditionalGeneration().to(device) 34 | tokenizer = model.tokenizer 35 | 36 | # dataset 37 | dev_dataset = dataset.T5QGDataset(dev_path, tokenizer) 38 | dev_dataloader = DataLoader(dev_dataset, batch_size) 39 | 40 | count = 2 41 | while(True): 42 | count += 1 43 | model.load_state_dict(torch.load(os.path.join(output_path, f't5_epoch_{count}.pth'))) 44 | model.eval() 45 | 46 | with open(os.path.join(output_path, f'output_{count}.txt'), 'w', encoding='utf-8') as f: 47 | for step_index, batch_data in tqdm( enumerate(dev_dataloader), f"[GENERATE]", total=len(dev_dataloader)): 48 | 49 | input_ids, decoder_input_ids, labels = tuple(value.to(device) for value in batch_data.values()) 50 | 51 | output = model.model.generate(input_ids=input_ids, eos_token_id=tokenizer.eos_token_id, max_length=100, num_beams=5) 52 | 53 | for o in output: 54 | o = tokenizer.decode(o, skip_special_tokens=True) 55 | o = o.replace(' ##', '').replace('##', '').strip() 56 | f.write(o+'\n') 57 | -------------------------------------------------------------------------------- /T5.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from transformers import T5ForConditionalGeneration, PreTrainedTokenizerFast 4 | 5 | class T5ConditionalGeneration(nn.Module): 6 | def __init__(self, device = 'cuda' if torch.cuda.is_available() else 'cpu'): 7 | super().__init__() 8 | self.model = T5ForConditionalGeneration.from_pretrained('t5-large').to('cpu') #to(device) 9 | self.model.train() 10 | self.bos_token = '' 11 | self.eos_token = '' 12 | self.sep_token = '[SEP]' 13 | self.pad_token = '' 14 | self.mask_token = '' 15 | self.highlight_token = '' 16 | self.device = device 17 | 18 | ''' 19 | self.tokenizer = PreTrainedTokenizerFast.from_pretrained('t5-large') 20 | self.tokenizer.add_special_tokens({'sep_token': '[SEP]'}) 21 | self.tokenizer.add_special_tokens({'pad_token': ''}) 22 | 
self.tokenizer.add_special_tokens({'bos_token': ''}) 23 | self.tokenizer.add_special_tokens({'eos_token': ''}) 24 | self.tokenizer.add_special_tokens({'mask_token': ''}) 25 | self.tokenizer.SPECIAL_TOKENS_ATTRIBUTES.append('highlight_token') 26 | self.tokenizer.add_special_tokens({'highlight_token': ''}) 27 | self.pad_token_id = self.tokenizer.pad_token_id 28 | 29 | self.model.resize_token_embeddings(len(self.tokenizer.vocab)) 30 | ''' 31 | self.model.resize_token_embeddings(32104) 32 | 33 | def forward(self, input_ids, decoder_input_ids, labels): 34 | attention_mask = input_ids.ne(self.pad_token_id).float() 35 | decoder_attention_mask = decoder_input_ids.ne(self.pad_token_id).float() 36 | 37 | return self.model(input_ids=input_ids, 38 | attention_mask=attention_mask, 39 | decoder_input_ids=decoder_input_ids, 40 | decoder_attention_mask=decoder_attention_mask, 41 | labels=labels, return_dict=True) 42 | 43 | def training_step(self, batch, batch_idx): 44 | outs = self(batch) 45 | loss = outs.loss 46 | self.log('train_loss', loss, prog_bar=True) 47 | return loss 48 | 49 | def validation_step(self, batch, batch_idx): 50 | outs = self(batch) 51 | loss = outs['loss'] 52 | return (loss) 53 | 54 | def validation_epoch_end(self, outputs): 55 | losses = [] 56 | for loss in outputs: 57 | losses.append(loss) 58 | self.log('val_loss', torch.stack(losses).mean(), prog_bar=True) 59 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from torch.utils.data import Dataset 4 | import random 5 | import torch 6 | 7 | #Random Seed 8 | random_seed = 1 9 | torch.manual_seed(random_seed) 10 | torch.backends.cudnn.deterministic = True 11 | torch.backends.cudnn.benchmark = False 12 | np.random.seed(random_seed) 13 | random.seed(random_seed) 14 | torch.cuda.manual_seed(random_seed) 15 | 16 | 17 | class 
T5QGDataset(Dataset): 18 | def __init__(self, file, tokenizer, max_len = 256, ignore_index=-100): 19 | super().__init__() 20 | self.tokenizer = tokenizer 21 | self.max_len = max_len 22 | self.docs = pd.read_csv(file, sep='\t', encoding='utf-8') 23 | self.len = self.docs.shape[0] 24 | 25 | self.pad_index = self.tokenizer.pad_token_id 26 | self.ignore_index = ignore_index 27 | 28 | def add_padding_data(self, inputs): 29 | if len(inputs) < self.max_len: 30 | pad = np.array([self.pad_index] * (self.max_len - len(inputs))) 31 | inputs = np.concatenate([inputs, pad]) 32 | else: 33 | inputs = inputs[:self.max_len] 34 | 35 | return inputs 36 | 37 | def add_ignored_data(self, inputs): 38 | if len(inputs) < self.max_len: 39 | pad = np.array([self.ignore_index] * (self.max_len - len(inputs))) 40 | inputs = np.concatenate([inputs, pad]) 41 | else: 42 | inputs = inputs[:self.max_len] 43 | 44 | return inputs 45 | 46 | def __getitem__(self, idx): 47 | instance = self.docs.iloc[idx] 48 | content = instance['content'] 49 | question = instance['question'].strip() 50 | 51 | sep_index = content.find('[SEP]') 52 | 53 | answer = content[sep_index + 6::].strip() 54 | content = content[:sep_index].strip() 55 | 56 | prefix_content_token_id = self.tokenizer.encode('content:', add_special_tokens=False) 57 | prefix_answer_token_id = self.tokenizer.encode('answer:', add_special_tokens=False) 58 | prefix_question_token_id = self.tokenizer.encode('question:', add_special_tokens=False) 59 | 60 | input_ids = prefix_answer_token_id 61 | input_ids += self.tokenizer.encode(answer, add_special_tokens=False) 62 | input_ids += prefix_content_token_id 63 | input_ids += self.tokenizer.encode(content, add_special_tokens=False) 64 | input_ids = self.add_padding_data(input_ids) 65 | 66 | label_ids = prefix_question_token_id 67 | label_ids += self.tokenizer.encode(question, add_special_tokens=False) 68 | label_ids.append(self.tokenizer.eos_token_id) 69 | dec_input_ids = [self.tokenizer.eos_token_id] 70 | 
dec_input_ids += label_ids[:-1] 71 | dec_input_ids = self.add_padding_data(dec_input_ids) 72 | label_ids = self.add_ignored_data(label_ids) 73 | 74 | return {'input_ids': np.array(input_ids, dtype=np.int_), 75 | 'decoder_input_ids': np.array(dec_input_ids, dtype=np.int_), 76 | 'labels': np.array(label_ids, dtype=np.int_)} 77 | 78 | def __len__(self): 79 | return self.len 80 | -------------------------------------------------------------------------------- /post-processing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 31, 6 | "id": "95a3b55d", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "dataset = []\n", 11 | "path = 'output_9.txt'\n", 12 | "index = path.find('.')\n", 13 | "output_path = path[:index] + '_' + path[index::]\n", 14 | "\n", 15 | "with open(path, encoding='utf-8') as f:\n", 16 | " for line in f.readlines():\n", 17 | " dataset.append(line.strip())\n", 18 | " \n", 19 | "for idx, data in enumerate(dataset):\n", 20 | " dataset[idx] = data[:-1] + ' ?'\n", 21 | " \n", 22 | "for idx, data in enumerate(dataset):\n", 23 | " if data.find('question:') == 0:\n", 24 | " dataset[idx] = data[10::]\n", 25 | " \n", 26 | "with open(output_path, 'w', encoding='utf-8') as f:\n", 27 | " for data in dataset:\n", 28 | " f.write(data.strip() + '\\n')" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "98da82f9", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 53, 42 | "id": "5aeefd04", 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "dataset = []\n", 47 | "path = 'output_9.txt'\n", 48 | "index = path.find('.')\n", 49 | "output_path = path[:index] + '_' + path[index::]\n", 50 | "\n", 51 | "with open(path, encoding='utf-8') as f:\n", 52 | " for line in f.readlines():\n", 53 | " dataset.append(line.strip())\n", 54 | " \n", 
55 | "for idx, data in enumerate(dataset):\n", 56 | " if data.find(\"'?\") != -1:\n", 57 | " dataset[idx] = dataset[idx].replace(\"'?\", '?')\n", 58 | " \n", 59 | "for idx, data in enumerate(dataset):\n", 60 | " if data.find(\"'s\") != -1:\n", 61 | " dataset[idx] = dataset[idx].replace(\"'s\", \" 's\")\n", 62 | " #elif data.find(\"'\") != -1:\n", 63 | " # dataset[idx] = dataset[idx].replace(\"'\", \" ' \")\n", 64 | " \n", 65 | "for idx, data in enumerate(dataset):\n", 66 | " dataset[idx] = data[:-1] + ' ?'\n", 67 | " \n", 68 | "for idx, data in enumerate(dataset):\n", 69 | " dataset[idx] = dataset[idx].replace(' ', ' ')\n", 70 | " \n", 71 | "for idx, data in enumerate(dataset):\n", 72 | " if data.find('question:') == 0:\n", 73 | " dataset[idx] = data[10::]\n", 74 | " \n", 75 | "for idx, data in enumerate(dataset):\n", 76 | " if data[0] == ',':\n", 77 | " dataset[idx] = data[1::]\n", 78 | " \n", 79 | "with open(output_path, 'w', encoding='utf-8') as f:\n", 80 | " for data in dataset:\n", 81 | " f.write(data.strip() + '\\n')" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "id": "b88546cf", 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [] 91 | } 92 | ], 93 | "metadata": { 94 | "kernelspec": { 95 | "display_name": "Python 3 (ipykernel)", 96 | "language": "python", 97 | "name": "python3" 98 | }, 99 | "language_info": { 100 | "codemirror_mode": { 101 | "name": "ipython", 102 | "version": 3 103 | }, 104 | "file_extension": ".py", 105 | "mimetype": "text/x-python", 106 | "name": "python", 107 | "nbconvert_exporter": "python", 108 | "pygments_lexer": "ipython3", 109 | "version": "3.9.7" 110 | } 111 | }, 112 | "nbformat": 4, 113 | "nbformat_minor": 5 114 | } 115 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # T5-Question Generation 2 | 3 | ## Load T5 4 | - using huggingface hub 5 | - 
https://huggingface.co/t5-large 6 | 7 | ## Download binary 8 | ```python 9 | import torch 10 | from transformers import PreTrainedTokenizerFast 11 | from transformers import T5ForConditionalGeneration 12 | 13 | tokenizer = PreTrainedTokenizerFast.from_pretrained('Sehong/t5-large-QuestionGeneration') 14 | model = T5ForConditionalGeneration.from_pretrained('Sehong/t5-large-QuestionGeneration') 15 | 16 | text = """ 17 | answer:Saint Bernadette Soubirous content:Architecturally , the school has a Catholic character . Atop the Main Building ' s gold dome is a golden statue of the Virgin Mary . Immediately in front of the Main Building and facing it , is a copper statue of Christ with arms upraised with the legend "" Venite Ad Me Omnes "" . Next to the Main Building is the Basilica of the Sacred Heart . Immediately behind the basilica is the Grotto , a Marian place of prayer and reflection . It is a replica of the grotto at Lourdes , France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858 . At the end of the main drive ( and in a direct line that connects through 3 statues and the Gold Dome ) , is a simple , modern stone statue of Mary . 18 | """ 19 | 20 | raw_input_ids = tokenizer.encode(text) 21 | input_ids = [tokenizer.bos_token_id] + raw_input_ids + [tokenizer.eos_token_id] 22 | 23 | question_ids = model.generate(torch.tensor([input_ids]), num_beams=4, max_length=512, eos_token_id=1) 24 | decode = tokenizer.decode(question_ids.squeeze().tolist(), skip_special_tokens=True) 25 | decode = decode.replace(' # # ', '').replace(' ', ' ').replace(' ##', '') 26 | 27 | print(decode) 28 | 29 | 'question: Who did Mary appear to in Lourdes ?' 
30 | 31 | ``` 32 | ## Requirements 33 | ``` 34 | torch==1.8.0 35 | transformers==4.18.0 36 | python==2.7 (evaluation scripts only) 37 | ``` 38 | 39 | ## Training Environment 40 | - Ubuntu 41 | - RTX 3090 42 | 43 | ## Data 44 | - SQuAD1.1 45 | - reference: Du et al., 2017 46 | - converted from txt to tsv 47 | - Data Structure 48 | - Train Data : 75,722 49 | - Dev Data : 10,570 50 | - Test Data : 11,877 51 | 52 | 53 | | Prefix token | Answer | Prefix token | content | Prefix token | question | 54 | |:-------:|:--------:|:--------:|:--------:|:--------:|:--------:| 55 | | answer: | answer | content: | content | question: | question | 56 | 57 | ## How to Train 58 | - T5 Question Generation fine-tuning 59 | ```bash 60 | [use gpu] 61 | python train.py 62 | 63 | ``` 64 | 65 | ## How to Run Inference 66 | ```bash 67 | [use gpu] 68 | python generate.py 69 | 70 | ``` 71 | 72 | ## Generation Sample 73 | | ||Text| 74 | |-------|-------|-------| 75 | |1|Answer|Saint Bernadette Soubirous| 76 | |1|Label|To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France ?| 77 | |1|T5-large|question: Who did Mary appear to in Lourdes, France?| 78 | 79 | | ||Text| 80 | |-------|-------|-------| 81 | |2|Answer|a copper statue of Christ| 82 | |2|Label|What is in front of the Notre Dame Main Building ?| 83 | |2|T5-large|question: What is in front of the Main Building?| 84 | 85 | | ||Text| 86 | |-------|-------|-------| 87 | |3|Answer|the Main Building| 88 | |3|Label|The Basilica of the Sacred heart at Notre Dame is beside to which structure ?| 89 | |3|T5-large|question: Where is the Basilica of the Sacred Heart located?| 90 | 91 | 92 | 93 | ## Model Performance 94 | - BLEU, METEOR, and ROUGE-L scores computed on the test data 95 | 96 | | |BLEU-1|BLEU-2|BLEU-3|BLEU-4|METEOR|ROUGE-L| 97 | |------|:-------:|:-------:|:-------:|:-------:|:-------:|:-------:| 98 | |Score|51.333|36.742|28.218|22.289|26.126|51.069| 99 | 100 | ## Demo 101 | 102 | https://huggingface.co/Sehong/t5-large-QuestionGeneration 103 | 104 | 
## Reference 105 | - [Dataset: Du et al., 2017](https://arxiv.org/pdf/1705.00106.pdf) 106 | - [Evaluation Code](https://github.com/microsoft/unilm/tree/master/unilm-v1/src/qg) 107 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import T5 4 | import dataset 5 | import torch 6 | import numpy as np 7 | from torch.utils.data import DataLoader 8 | from transformers.optimization import AdamW, get_cosine_schedule_with_warmup 9 | from tqdm import tqdm 10 | import math 11 | import random 12 | from torch.utils.tensorboard import SummaryWriter 13 | 14 | writer = SummaryWriter() 15 | logger = logging.getLogger() 16 | logger.setLevel(logging.INFO) 17 | 18 | #Random Seed 19 | random_seed = 1 20 | torch.manual_seed(random_seed) 21 | torch.backends.cudnn.deterministic = True 22 | torch.backends.cudnn.benchmark = False 23 | np.random.seed(random_seed) 24 | random.seed(random_seed) 25 | torch.cuda.manual_seed(random_seed) 26 | 27 | 28 | # Config 29 | batch_size = 4 30 | epochs = 10 31 | warmup_ratio = 0.1 32 | learning_rate = 3e-5 33 | grad_clip = 1.0 34 | train_log_interval = 100 35 | 36 | # Model, tokenizer init 37 | model = T5.T5ConditionalGeneration() 38 | 39 | # post-training 40 | model.load_state_dict(torch.load('output_question_infilling_and_delete_wh_backtranslation/t5_epoch_0.pth')) 41 | tokenizer = model.tokenizer 42 | 43 | # Data file path 44 | train_path = 'data/train.tsv' 45 | dev_path = 'data/dev.tsv' 46 | output_path = 'output_question_infilling_and_delete_wh_backtranslation_fine' 47 | 48 | if not os.path.exists(output_path): 49 | os.makedirs(output_path) 50 | 51 | # dataset, dataloader 52 | train_dataset = dataset.T5QGDataset(train_path, tokenizer) 53 | train_dataloader = DataLoader(train_dataset, batch_size) 54 | 55 | dev_dataset = dataset.T5QGDataset(dev_path, tokenizer) 56 | dev_dataloader = 
DataLoader(dev_dataset, batch_size) 57 | 58 | # optimizer 59 | param_optimizer = list(model.named_parameters()) 60 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 61 | optimizer_grouped_parameters = [ 62 | {'params': [p for n, p in param_optimizer if not any( 63 | nd in n for nd in no_decay)], 'weight_decay': 0.01}, 64 | {'params': [p for n, p in param_optimizer if any( 65 | nd in n for nd in no_decay)], 'weight_decay': 0.0} 66 | ] 67 | optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, correct_bias=False) 68 | 69 | # scheduler 70 | data_len = len(train_dataloader) 71 | num_train_steps = int(data_len / batch_size * epochs) 72 | num_warmup_steps = int(num_train_steps * warmup_ratio) 73 | scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_train_steps) 74 | 75 | # logging data info 76 | logging.info(f'data length {data_len}') 77 | logging.info(f'num_train_steps : {num_train_steps}') 78 | logging.info(f'num_warmup_steps : {num_warmup_steps}') 79 | 80 | # device 81 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 82 | 83 | # dev 84 | def _validate( 85 | model: T5.T5ConditionalGeneration, 86 | dev_dataloader: DataLoader, 87 | device: torch.device, 88 | logger: logging.Logger, 89 | global_step: int, 90 | ): 91 | model.eval() 92 | loss_list = [] 93 | for batch_data in tqdm(dev_dataloader, desc="[EVAL]"): 94 | with torch.no_grad(): 95 | input_ids, decoder_input_ids, labels = tuple(value.to(device) for value in batch_data.values()) 96 | model_outputs = model.forward(input_ids, decoder_input_ids, labels) 97 | loss_list.append(model_outputs.loss.item()) 98 | 99 | mean_loss = np.mean(loss_list) 100 | logger.info(f"[EVAL] global_step:{global_step} loss:{mean_loss:.4f} perplexity:{math.exp(mean_loss):.4f}") 101 | model.train() 102 | 103 | return mean_loss 104 | 105 | model.train() 106 | loss_list_between_log_interval = [] 107 | least_dev_loss = 999 108 | for epoch_id in 
range(epochs): 109 | for step_index, batch_data in tqdm(enumerate(train_dataloader), f"[TRAIN] EP:{epoch_id}", total=len(train_dataloader)): 110 | global_step = len(train_dataloader) * epoch_id + step_index + 1 111 | optimizer.zero_grad() 112 | 113 | input_ids, decoder_input_ids, labels = tuple(value.to(device) for value in batch_data.values()) 114 | 115 | model_outputs = model.forward(input_ids, decoder_input_ids, labels) 116 | 117 | writer.add_scalar("Perplexity/train", math.exp(model_outputs.loss.item()), global_step) 118 | writer.flush() 119 | 120 | model_outputs.loss.backward() 121 | 122 | # model_outputs.loss.backward() 123 | torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) 124 | optimizer.step() 125 | scheduler.step() 126 | 127 | # for logging 128 | loss_list_between_log_interval.append(model_outputs.loss.item()) 129 | 130 | if global_step % train_log_interval == 0: 131 | mean_loss = np.mean(loss_list_between_log_interval) 132 | logger.info( 133 | f"EP:{epoch_id} global_step:{global_step} " 134 | f"loss:{mean_loss:.4f} perplexity:{math.exp(mean_loss):.4f}" 135 | ) 136 | loss_list_between_log_interval.clear() 137 | ''' 138 | if global_step % validation_interval == 0: 139 | dev_loss = _validate(model, dev_dataloader, device, logger, global_step) 140 | state_dict = model.state_dict() 141 | if dev_loss.item() < least_dev_loss: 142 | least_dev_loss = dev_loss.item() 143 | model_path = os.path.join('output_post_fine4', f"bart_best.pth") 144 | logger.info(f"Save best model") 145 | torch.save(state_dict, model_path) 146 | 147 | if global_step % save_interval == 0: 148 | state_dict = model.state_dict() 149 | model_path = os.path.join('output_post_fine4', f"bart_step_{global_step}.pth") 150 | logger.info(f"global_step: {global_step} model saved at {model_path}") 151 | torch.save(state_dict, model_path) 152 | 153 | ''' 154 | 155 | dev_loss = _validate(model, dev_dataloader, device, logger, global_step) 156 | state_dict = model.state_dict() 157 | if 
dev_loss.item() < least_dev_loss: 158 | least_dev_loss = dev_loss.item() 159 | model_path = os.path.join(output_path, f"t5_best.pth") 160 | logger.info(f"Save best model") 161 | torch.save(state_dict, model_path) 162 | 163 | model_path = os.path.join(output_path, f"t5_epoch_{epoch_id}.pth") 164 | logger.info(f"epoch: {epoch_id} model saved at {model_path}") 165 | torch.save(state_dict, model_path) 166 | 167 | writer.close() 168 | --------------------------------------------------------------------------------