├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── chatbot
│   ├── chatbot.py
│   ├── config.yml
│   ├── data.py
│   ├── interact.py
│   ├── processing.py
│   ├── train.py
│   └── utils.py
├── dataset
│   ├── train_dialogues.pickle
│   ├── train_ids.pickle
│   ├── valid_dialogues.pickle
│   └── valid_ids.pickle
└── requirements.txt

/.gitignore:
--------------------------------------------------------------------------------
1 | /.idea
2 | /__pycache__
3 | /models
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.8.1
2 | RUN mkdir /app
3 | COPY . /app
4 | WORKDIR /app/chatbot
5 | RUN pip install -r ../requirements.txt
6 | ENTRYPOINT ["bash"]
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 xcapt0
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ☕ GPT2 Chatbot
2 | 
3 | A GPT-2 chatbot for daily conversations, trained on the `Daily Dialogue`, `Empathetic Dialogues`, `PERSONA-CHAT` and `Blended Skill Talk` datasets. It is built on the GPT-2 transformer with a language modeling head on top.
4 | 
5 | ![chatbot](https://user-images.githubusercontent.com/70326958/151570518-ce70261a-6e8e-47a0-92e5-2d7638e7aa68.jpg)
6 | 
7 | 
8 | ## ⌛ Installation
9 | 
10 | Download the [model](https://gpt2chatbot.s3.us-east-2.amazonaws.com/model.h5) from AWS S3 storage and run the following commands:
11 | 
12 | ```sh
13 | git clone https://github.com/xcapt0/gpt2_chatbot.git && cd gpt2_chatbot
14 | docker build -t gpt2_bot .
15 | ```
16 | 
17 | ## 🤖 Usage
18 | 
19 | Run the Docker container:
20 | ```sh
21 | docker run --rm -it gpt2_bot
22 | ```
23 | 
24 | There are two ways to use the chatbot: `train` mode and `interact` mode.
25 | 
26 | ### Interaction mode
27 | To launch the chatbot, run the following command. Specify the path to your model via `--checkpoint`:
28 | ```sh
29 | python chatbot.py --mode interact --checkpoint path/to/model.h5
30 | ```
31 | 
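The file passed via `--checkpoint` is read with `torch.load` and is expected to be a dictionary with (at least) a `model_state_dict` entry, which is the format `train.py` writes; the downloaded `model.h5` is assumed to follow the same layout. A quick, hedged way to verify a checkpoint before launching:

```python
import torch

# Hedged sketch: what interact mode expects from the --checkpoint file.
# 'model.h5' is a placeholder path for whatever file you downloaded or trained.
checkpoint = torch.load('model.h5', map_location='cpu')
print(checkpoint.keys())  # train.py saves: model_state_dict, optim_state_dict, loss, epoch

# Interaction mode only needs the model weights:
# model.load_state_dict(checkpoint['model_state_dict'])
```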
32 | ### Train mode
33 | To train the model, run the following command. Specify `--checkpoint` if you want to resume training from a saved checkpoint.
34 | ```sh
35 | python chatbot.py --mode train
36 | ```
37 | 
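Both modes can also be driven from Python instead of the CLI. The rough sketch below only mirrors what `chatbot.py` does internally; run it from the `chatbot/` directory so `config.yml` and the local modules resolve, and fetch the `wordnet`, `omw-1.4` and `punkt` NLTK corpora once if they are missing (as `chatbot.py` does at startup):

```python
import yaml
import torch

# Rough sketch of what `python chatbot.py --mode interact` does under the hood.
from chatbot import load_tokenizer, load_model  # helpers defined in chatbot.py
from interact import Chatbot

args = yaml.safe_load(open('config.yml'))
args['mode'] = 'interact'
args['checkpoint'] = 'path/to/model.h5'  # placeholder path, as in the command above
args['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = load_tokenizer(args)  # also registers the speaker/BOS special tokens
model = load_model(args, tokenizer, args['device'])
Chatbot(model, tokenizer, args).run()
```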
38 | ## 📝 License
39 | 
40 | Copyright © 2022 [xcapt0](https://github.com/xcapt0).
41 | This project is [MIT](https://github.com/xcapt0/gpt2_chatbot/blob/master/LICENSE) licensed.
42 | 
--------------------------------------------------------------------------------
/chatbot/chatbot.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import torch
3 | import nltk
4 | from glob import glob
5 | from argparse import ArgumentParser
6 | from transformers import GPT2Tokenizer, GPT2LMHeadModel
7 | 
8 | from data import Dialogues
9 | from utils import set_seed
10 | 
11 | 
12 | def main(args):
13 |     set_seed(args['seed'])
14 | 
15 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16 |     args['device'] = device
17 | 
18 |     tokenizer = load_tokenizer(args)
19 |     model = load_model(args, tokenizer, device)
20 | 
21 |     if dataset_is_missing(args):
22 |         dialogues = Dialogues(tokenizer, args)
23 |         train_dataset, valid_dataset = dialogues.load()
24 | 
25 |         dataset_types = ['train', 'valid']
26 |         datasets = [train_dataset, valid_dataset]
27 | 
28 |         for dataset_type, dataset in zip(dataset_types, datasets):
29 |             dialogues.save(dataset_type, tokenizer, dataset)
30 | 
31 |     if args['mode'] == 'train':
32 |         from train import Trainer
33 |         trainer = Trainer(model, args)
34 |         trainer.train()
35 |     elif args['mode'] == 'interact':
36 |         from interact import Chatbot
37 |         chatbot = Chatbot(model, tokenizer, args)
38 |         chatbot.run()
39 | 
40 | 
41 | def dataset_is_missing(args):
42 |     if len(glob(f'{args["dataset_dir"]}/*.pickle')) == 0:
43 |         return True
44 |     return False
45 | 
46 | 
47 | def load_tokenizer(args):
48 |     tokenizer = GPT2Tokenizer.from_pretrained(args['model_name'])
49 |     special_tokens = ['<sp1>', '<sp2>']  # speaker tokens; the exact strings only need to stay consistent between training and inference
50 |     tokenizer.add_special_tokens({
51 |         'bos_token': '<bos>',
52 |         'additional_special_tokens': special_tokens
53 |     })
54 | 
55 |     # add new token ids to args
56 |     special_tokens += ['<bos>', tokenizer.eos_token]  # reuse GPT-2's default EOS token
57 |     sp1_id, sp2_id, bos_id, eos_id = tokenizer.encode(special_tokens)
58 |     args['sp1_id'] = sp1_id
59 |     args['sp2_id'] = sp2_id
60 |     args['bos_id'] = bos_id
61 |     args['eos_id'] = eos_id
62 | 
63 |     return tokenizer
64 | 
65 | 
66 | def load_model(args, tokenizer, device):
67 |     model = GPT2LMHeadModel.from_pretrained(args['model_name']).to(device)
68 |     model.resize_token_embeddings(len(tokenizer))
69 |     return model
70 | 
71 | 
72 | if __name__ == '__main__':
73 |     nltk.download('wordnet')
74 |     nltk.download('omw-1.4')
75 |     nltk.download('punkt')
76 | 
77 |     parser = ArgumentParser()
78 |     parser.add_argument('--mode', type=str, required=True,
79 |                         help='Pass "train" for training mode and "interact" for interaction mode')
80 |     parser.add_argument('--checkpoint', type=str, default=None, help='Path to checkpoint of the model')
81 | 
82 |     user_args = parser.parse_args()
83 |     arguments = yaml.safe_load(open('config.yml'))
84 |     arguments.update(vars(user_args))
85 | 
86 |     main(arguments)
87 | 
--------------------------------------------------------------------------------
/chatbot/config.yml:
--------------------------------------------------------------------------------
1 | dataset_dir: "../dataset" # name of the directory where files are stored
2 | train_frac: 0.85 # ratio of the conversations to be included in the train set
3 | model_name: "gpt2" # name of the model for tokenizer and transformer
4 | seed: 8459 # random seed
5 | lr: 0.00002 # learning rate
6 | warmup_ratio: 0.1 # ratio of warmup steps to the total training steps
7 | batch_size: 8 # batch size
8 | num_epochs: 10 # number of total epochs
9 | max_len: 100 # maximum length of input sequence
10 | max_history: 5 # 
maximum number of dialogue histories to include 11 | models_dir: "../models" # directory name for saved checkpoints 12 | stop_command: "bye" # command to stop the conversation when inferencing 13 | top_p: 0.9 # top p 14 | top_k: 50 # top k 15 | temperature: 0.7 # randomness of predictions -------------------------------------------------------------------------------- /chatbot/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from itertools import chain 4 | from tqdm.auto import tqdm 5 | from torch.utils.data import Dataset 6 | 7 | from processing import Processing 8 | 9 | 10 | class Dialogues(Processing): 11 | def __init__(self, tokenizer, args): 12 | self.tokenizer = tokenizer 13 | self.args = args 14 | self.dataset_list = ['daily_dialog', 'empathetic_dialogues', 'persona_chat', 'blended_skill_talk'] 15 | super().__init__(tokenizer, args['train_frac']) 16 | 17 | def load(self): 18 | train_dataset = [] 19 | valid_dataset = [] 20 | 21 | for dataset_name in self.dataset_list: 22 | print(f'Loading {dataset_name} dataset...') 23 | 24 | train_dialogues, valid_dialogues = self._load_dialog(dataset=dataset_name) 25 | train_dataset += train_dialogues 26 | valid_dataset += valid_dialogues 27 | 28 | return train_dataset, valid_dataset 29 | 30 | def save(self, prefix, tokenizer, dialogues): 31 | print(f'Saving {prefix} dialogues to file...') 32 | 33 | if not os.path.isdir(self.args["dataset_dir"]): 34 | os.makedirs(self.args["dataset_dir"]) 35 | 36 | dialogues_path = f'{self.args["dataset_dir"]}/{prefix}_dialogues.pickle' 37 | ids_path = f'{self.args["dataset_dir"]}/{prefix}_ids.pickle' 38 | 39 | with open(dialogues_path, 'wb') as f: 40 | pickle.dump(dialogues, f) 41 | 42 | print(f'Saving {prefix} ids to file...') 43 | ids = [] 44 | for dialogue in tqdm(dialogues): 45 | dialogue_ids = [] 46 | for utter in dialogue: 47 | tokens = tokenizer.tokenize(utter) 48 | token_ids = tokenizer.encode(tokens) 49 | dialogue_ids.append(token_ids) 50 | ids.append(dialogue_ids) 51 | 52 | with open(ids_path, 'wb') as f: 53 | pickle.dump(ids, f) 54 | 55 | print('Saving complete!') 56 | 57 | def _load_dialog(self, dataset=None): 58 | if dataset == 'daily_dialog': 59 | return self._load_daily() 60 | elif dataset == 'empathetic_dialogues': 61 | return self._load_empathetic() 62 | elif dataset == 'persona_chat': 63 | return self._load_persona() 64 | elif dataset == 'blended_skill_talk': 65 | return self._load_blended() 66 | 67 | 68 | class DialoguesDataset(Dataset): 69 | def __init__(self, prefix, args): 70 | self.input_ids = [] 71 | self.token_type_ids = [] 72 | self.labels = [] 73 | self._prepare_data(prefix, args) 74 | 75 | def __len__(self): 76 | return len(self.input_ids) 77 | 78 | def __getitem__(self, idx): 79 | return self.input_ids[idx], self.token_type_ids[idx], self.labels[idx] 80 | 81 | def _prepare_data(self, prefix, args): 82 | with open(f'{args["dataset_dir"]}/{prefix}_ids.pickle', 'rb') as f: 83 | dials = pickle.load(f) 84 | 85 | for dial in tqdm(dials): 86 | hists = [] 87 | for i, sentence in enumerate(dial): 88 | if i % 2 == 0: 89 | hists.append([args['sp1_id']] + sentence) 90 | else: 91 | hists.append([args['sp2_id']] + sentence) 92 | 93 | for i in range(len(hists)): 94 | if hists[i][0] == args['sp2_id']: 95 | for j in range(0, i): 96 | contexts = hists[j:i + 1] 97 | if len(contexts) > args['max_history']: 98 | num_exceeded = len(contexts) - args['max_history'] 99 | contexts = contexts[num_exceeded:] 100 | if len(contexts) < 2: 
101 |                             break
102 | 
103 |                         input_ids = [args['bos_id']] + list(chain.from_iterable(contexts)) + [args['eos_id']]
104 |                         if len(input_ids) <= args['max_len']:
105 |                             start_sp_id, next_sp_id = contexts[0][0], contexts[1][0]
106 |                             token_type_ids = [[start_sp_id] * len(ctx) if c % 2 == 0 else [next_sp_id] * len(ctx) for c, ctx in enumerate(contexts)]
107 |                             token_type_ids = [start_sp_id] + list(chain.from_iterable(token_type_ids)) + [args['sp2_id']]
108 | 
109 |                             labels = [[-100] * len(ctx) if c < len(contexts) - 1 else [-100] + ctx[1:] for c, ctx in enumerate(contexts)]
110 |                             labels = [-100] + list(chain.from_iterable(labels)) + [args['eos_id']]
111 | 
112 |                             self.input_ids.append(input_ids)
113 |                             self.token_type_ids.append(token_type_ids)
114 |                             self.labels.append(labels)
115 | 
116 |                         break
117 | 
118 |         del dials
119 | 
--------------------------------------------------------------------------------
/chatbot/interact.py:
--------------------------------------------------------------------------------
1 | import os
2 | from itertools import chain
3 | 
4 | import torch
5 | import torch.nn.functional as F
6 | 
7 | from utils import top_k_filter, lemma_sentence
8 | 
9 | 
10 | class Chatbot:
11 |     def __init__(self, model, tokenizer, args):
12 |         self.model = model
13 |         self.tokenizer = tokenizer
14 |         self.args = args
15 | 
16 |     def run(self):
17 |         assert self.args['checkpoint'], 'Checkpoint was not found. Please specify a valid checkpoint via --checkpoint CHECKPOINT_PATH'
18 |         self._load_checkpoint()
19 | 
20 |         print('Launching the chatbot...')
21 |         print(f'If you want to stop, type the \"{self.args["stop_command"]}\" command')
22 | 
23 |         self.model.eval()
24 | 
25 |         with torch.no_grad():
26 |             input_hists = []
27 | 
28 |             while True:
29 |                 sentence = input('You: ')
30 |                 if sentence == self.args['stop_command']:
31 |                     print('Bot: Good bye.')
32 |                     break
33 | 
34 |                 sentence = lemma_sentence(sentence)
35 | 
36 |                 input_ids = [self.args['sp1_id']] + self.tokenizer.encode(sentence)
37 |                 input_hists.append(input_ids)
38 | 
39 |                 if len(input_hists) >= self.args['max_history']:
40 |                     num_exceeded = len(input_hists) - self.args['max_history']
41 |                     input_hists = input_hists[num_exceeded:]
42 | 
43 |                 input_ids = [self.args['bos_id']] + list(chain.from_iterable(input_hists)) + [self.args['sp2_id']]
44 |                 start_sp_id = input_hists[0][0]
45 |                 next_sp_id = self.args['sp1_id'] if start_sp_id == self.args['sp2_id'] else self.args['sp2_id']
46 |                 token_type_ids = [[start_sp_id] * len(hist) if h % 2 == 0 else [next_sp_id] * len(hist) for h, hist in enumerate(input_hists)]
47 |                 assert len(token_type_ids) == len(input_hists)
48 |                 token_type_ids = [start_sp_id] + list(chain.from_iterable(token_type_ids)) + [self.args['sp2_id']]  # flatten the per-utterance speaker ids, mirroring data.py
49 |                 assert len(input_ids) == len(token_type_ids)
50 |                 input_len = len(input_ids)
51 | 
52 |                 input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(self.args['device'])
53 |                 token_type_ids = torch.LongTensor(token_type_ids).unsqueeze(0).to(self.args['device'])
54 | 
55 |                 output_ids = self._top_filtering(input_ids, token_type_ids)
56 |                 answer = self.tokenizer.decode(output_ids, skip_special_tokens=True)
57 | 
58 |                 print(f'Bot: {answer}')
59 |                 input_hists.append([self.args['sp2_id']] + self.tokenizer.encode(answer))
60 | 
61 |     def _top_filtering(self, input_ids, token_type_ids):
62 |         output_ids = []
63 | 
64 |         for pos in range(self.args['max_len']):
65 |             output = self.model(input_ids=input_ids, token_type_ids=token_type_ids)[0]
66 | 
67 |             logits = output[0, -1, :] / self.args['temperature']
68 |             logits = top_k_filter(logits, 
top_k=self.args['top_k']) 69 | output = F.softmax(logits, dim=-1).unsqueeze(0) 70 | 71 | sorted_probs, sorted_idxs = torch.sort(output, descending=True) 72 | cumsum_probs = torch.cumsum(sorted_probs, dim=-1) 73 | idx_remove = cumsum_probs > self.args['top_p'] 74 | idx_remove[:, 1:] = idx_remove[:, :-1].clone() 75 | idx_remove[:, 0] = False 76 | sorted_probs[idx_remove] = 0.0 77 | sorted_probs /= torch.sum(sorted_probs, dim=-1, keepdim=True) 78 | 79 | probs = torch.zeros(output.shape, device=self.args['device']).scatter_(-1, sorted_idxs, sorted_probs) 80 | idx = torch.multinomial(probs, 1) 81 | 82 | idx_item = idx.squeeze(-1).squeeze(-1).item() 83 | 84 | if idx_item in output_ids: 85 | continue 86 | 87 | output_ids.append(idx_item) 88 | 89 | if idx_item == self.args['eos_id']: 90 | break 91 | 92 | input_ids = torch.cat((input_ids, idx.reshape(1, 1)), dim=-1) 93 | next_type_id = torch.LongTensor([[self.args['sp2_id']]]).to(self.args['device']) 94 | token_type_ids = torch.cat((token_type_ids, next_type_id), dim=-1) 95 | assert input_ids.shape == token_type_ids.shape 96 | 97 | return output_ids 98 | 99 | def _load_checkpoint(self): 100 | path = self.args['checkpoint'] 101 | if os.path.exists(path): 102 | print('Loading checkpoint...') 103 | checkpoint = torch.load(path, map_location=self.args['device']) 104 | self.model.load_state_dict(checkpoint['model_state_dict']) 105 | print(f'Found checkpoint file: {os.path.basename(path)}') 106 | else: 107 | print("Can't find the specified checkpoint") 108 | -------------------------------------------------------------------------------- /chatbot/processing.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from tqdm.auto import tqdm 3 | 4 | 5 | class Processing: 6 | def __init__(self, tokenizer, train_frac): 7 | self.tokenizer = tokenizer 8 | self.train_frac = train_frac 9 | 10 | def _load_daily(self): 11 | dataset = load_dataset('daily_dialog') 12 | train_dialogues = dataset['train']['dialog'] 13 | valid_dialogues = dataset['validation']['dialog'] 14 | test_dialogues = dataset['test']['dialog'] 15 | 16 | all_dialogues = train_dialogues + valid_dialogues + test_dialogues 17 | 18 | for i, dialogue in enumerate(tqdm(all_dialogues)): 19 | new_dialogue = [] 20 | for sentence in dialogue: 21 | token_list = self.tokenizer.tokenize(sentence.strip().replace('’', '\'')) 22 | token_list = self._process_token_list(token_list) 23 | text = self.tokenizer.convert_tokens_to_string(token_list) 24 | new_dialogue.append(text) 25 | 26 | all_dialogues[i] = new_dialogue 27 | 28 | train_dialogues = all_dialogues[:int(len(all_dialogues) * self.train_frac)] 29 | valid_dialogues = all_dialogues[int(len(all_dialogues) * self.train_frac):] 30 | 31 | return train_dialogues, valid_dialogues 32 | 33 | def _load_empathetic(self): 34 | dataset = load_dataset('empathetic_dialogues') 35 | train_data = dataset['train'] 36 | valid_data = dataset['validation'] 37 | test_data = dataset['test'] 38 | 39 | sentences = train_data['utterance'] + valid_data['utterance'] + test_data['utterance'] 40 | total_conv_ids = train_data['conv_id'] + valid_data['conv_id'] + test_data['conv_id'] 41 | total_speaker_ids = train_data['speaker_idx'] + valid_data['speaker_idx'] + test_data['speaker_idx'] 42 | 43 | conv_dict = {} 44 | cur_speaker_idx = -1 45 | for i, sentence in enumerate(tqdm(sentences)): 46 | conv_id = total_conv_ids[i] 47 | speaker_idx = total_speaker_ids[i] 48 | 49 | sentence_modified = 
sentence.strip().replace('_comma_', ',') 50 | new_token_list = self._process_token_list(self.tokenizer.tokenize(sentence_modified)) 51 | text = self.tokenizer.convert_tokens_to_string(new_token_list) 52 | 53 | if '_conv' in sentence: 54 | continue 55 | 56 | if conv_id not in conv_dict: 57 | conv_dict[conv_id] = [] 58 | cur_speaker_idx = -1 59 | 60 | if cur_speaker_idx != speaker_idx: 61 | conv_dict[conv_id].append(text) 62 | cur_speaker_idx = speaker_idx 63 | else: 64 | conv_dict[conv_id][-1] += f" {text}" 65 | 66 | train_dialogues = [] 67 | valid_dialogues = [] 68 | 69 | train_dialogue_num = int(len(conv_dict) * self.train_frac) 70 | for i, (conv_id, utter_list) in enumerate(conv_dict.items()): 71 | if i < train_dialogue_num: 72 | train_dialogues.append(utter_list) 73 | else: 74 | valid_dialogues.append(utter_list) 75 | 76 | return train_dialogues, valid_dialogues 77 | 78 | def _load_persona(self): 79 | import requests 80 | 81 | url = 'https://s3.amazonaws.com/datasets.huggingface.co/personachat/personachat_self_original.json' 82 | response = requests.get(url) 83 | assert response.status_code == 200, 'Error receiving data from server' 84 | dataset = response.json() 85 | 86 | train_data = dataset['train'] 87 | valid_data = dataset['valid'] 88 | all_data = train_data + valid_data 89 | all_dialogues = [] 90 | 91 | for obj in tqdm(all_data): 92 | dialogue = obj['utterances'][-1]['history'] 93 | new_dialogue = [] 94 | 95 | for i, sentence in enumerate(dialogue): 96 | if sentence.strip() != '__ SILENCE __': 97 | token_list = self.tokenizer.tokenize(sentence.strip()) 98 | new_token_list = self._process_token_list(token_list) 99 | text = self.tokenizer.convert_tokens_to_string(new_token_list) 100 | new_dialogue.append(text) 101 | 102 | all_dialogues.append(new_dialogue) 103 | 104 | train_dialogues = all_dialogues[:int(len(all_dialogues) * self.train_frac)] 105 | valid_dialogues = all_dialogues[int(len(all_dialogues) * self.train_frac):] 106 | 107 | return train_dialogues, valid_dialogues 108 | 109 | def _load_blended(self): 110 | dataset = load_dataset('blended_skill_talk') 111 | data_train = dataset['train'] 112 | data_valid = dataset['validation'] 113 | data_test = dataset['test'] 114 | 115 | all_previous_sentences = data_train['previous_utterance'] + \ 116 | data_valid['previous_utterance'] + \ 117 | data_test['previous_utterance'] 118 | all_free_messages = data_train['free_messages'] + \ 119 | data_valid['free_messages'] + \ 120 | data_test['free_messages'] 121 | all_guided_messages = data_train['guided_messages'] + \ 122 | data_valid['guided_messages'] + \ 123 | data_test['guided_messages'] 124 | 125 | all_dialogues = [] 126 | for i, free_message in enumerate(tqdm(all_free_messages)): 127 | free_message_list = [sentence.strip() for sentence in free_message if len(sentence.strip()) > 0] 128 | guided_message_list = [sentence.strip() for sentence in all_guided_messages[i] if len(sentence.strip()) > 0] 129 | dialogue = all_previous_sentences[i] 130 | 131 | for j in range(len(free_message_list)): 132 | token_list = self._process_token_list(self.tokenizer.tokenize(free_message_list[j])) 133 | text = self.tokenizer.convert_tokens_to_string(token_list) 134 | dialogue.append(text) 135 | 136 | if j < len(guided_message_list): 137 | token_list = self._process_token_list(self.tokenizer.tokenize(guided_message_list[j])) 138 | text = self.tokenizer.convert_tokens_to_string(token_list) 139 | dialogue.append(text) 140 | 141 | all_dialogues.append(dialogue) 142 | 143 | train_dialogues = 
all_dialogues[:int(len(all_dialogues) * self.train_frac)] 144 | valid_dialogues = all_dialogues[int(len(all_dialogues) * self.train_frac):] 145 | 146 | return train_dialogues, valid_dialogues 147 | 148 | @staticmethod 149 | def _process_token_list(token_list): 150 | space = 'Ġ' 151 | quotes = ['"', '\''] 152 | end_marks = ['.', ',', '?', '!', '...'] 153 | abbreviations = ['s', 'd', 't', 'm', 're', 'll', 've', 'S', 'D', 'T', 'M', 'Re', 'Ll', 'Ve'] 154 | token_list[0] = token_list[0].capitalize() 155 | quote_count = 0 156 | 157 | for i, token in enumerate(token_list): 158 | if space in token: 159 | if token[1:] in end_marks or token[1:] in abbreviations: 160 | token_list[i] = token[1:] 161 | 162 | if token[1:] == quotes[1]: 163 | if i < len(token_list) - 1: 164 | if token_list[i + 1] in abbreviations or ( 165 | token_list[i + 1][0] == space and token_list[i + 1][1:] in abbreviations): 166 | token_list[i] = token[1:] 167 | 168 | if token[0] == space and token[1:] in quotes: 169 | if quote_count % 2 == 1: 170 | token_list[i] = token[1:] 171 | quote_count = 0 172 | else: 173 | if i < len(token_list) - 1 and token_list[i + 1][0] == space: 174 | token_list[i + 1] = token_list[i + 1][1:] 175 | quote_count += 1 176 | 177 | if token in end_marks or token[1:] in end_marks: 178 | if i < len(token_list) - 1: 179 | if token_list[i + 1][0] != space: 180 | token_list[i + 1] = space + token_list[i + 1].capitalize() 181 | else: 182 | token_list[i + 1] = space + token_list[i + 1][1:].capitalize() 183 | 184 | new_token_list = [token for token in token_list if token != space and len(token) > 0] 185 | if new_token_list[-1] not in end_marks: 186 | new_token_list.append(end_marks[0]) 187 | 188 | return new_token_list 189 | -------------------------------------------------------------------------------- /chatbot/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import numpy as np 4 | from tqdm.auto import tqdm 5 | 6 | import torch 7 | from torch.utils.data import DataLoader 8 | from torch.optim import AdamW 9 | from transformers import get_polynomial_decay_schedule_with_warmup 10 | 11 | from data import DialoguesDataset 12 | from utils import PadCollate 13 | 14 | 15 | class Trainer: 16 | def __init__(self, model, args): 17 | print('Loading the optimizer...') 18 | self.optimizer = AdamW(model.parameters(), lr=args['lr']) 19 | self.best_loss = 1e+10 20 | self.last_epoch = 0 21 | 22 | print('Loading train & valid data...') 23 | train_dataset = DialoguesDataset('train', args) 24 | valid_dataset = DialoguesDataset('valid', args) 25 | pad = PadCollate(args) 26 | 27 | self.train_loader = DataLoader(train_dataset, 28 | collate_fn=pad, 29 | shuffle=True, 30 | batch_size=args['batch_size'], 31 | num_workers=1, 32 | pin_memory=True) 33 | self.valid_loader = DataLoader(valid_dataset, 34 | collate_fn=pad, 35 | batch_size=args['batch_size'], 36 | num_workers=1, 37 | pin_memory=True) 38 | 39 | if not os.path.exists(args['models_dir']): 40 | os.makedirs(args['models_dir']) 41 | 42 | # Calculate total training steps 43 | num_batches = len(self.train_loader) 44 | total_train_steps = args['num_epochs'] * num_batches 45 | warmup_steps = int(args['warmup_ratio'] * total_train_steps) 46 | 47 | self.model = model 48 | self.args = args 49 | self.scheduler = get_polynomial_decay_schedule_with_warmup( 50 | self.optimizer, 51 | num_warmup_steps=warmup_steps, 52 | num_training_steps=total_train_steps, 53 | power=2 54 | ) 55 | 56 | if args['checkpoint']: 57 | 
            self._load_checkpoint()
58 | 
59 |     def train(self):
60 |         print('Launch training...')
61 | 
62 |         start_epoch = self.last_epoch + 1
63 |         for epoch in range(start_epoch, start_epoch + self.args['num_epochs']):
64 |             print('-' * 50 + f'\nEpoch: {epoch}\n' + '-' * 50)
65 | 
66 |             self.model.train()
67 |             train_losses = []
68 |             train_perplexity = []
69 | 
70 |             for i, batch in enumerate(tqdm(self.train_loader)):
71 |                 input_ids, token_type_ids, labels = batch
72 |                 input_ids = input_ids.to(self.args['device'])
73 |                 token_type_ids = token_type_ids.to(self.args['device'])
74 |                 labels = labels.to(self.args['device'])
75 | 
76 |                 outputs = self.model(
77 |                     input_ids=input_ids,
78 |                     token_type_ids=token_type_ids,
79 |                     labels=labels
80 |                 )
81 | 
82 |                 loss, logits = outputs[0], outputs[1]
83 | 
84 |                 self.optimizer.zero_grad()
85 |                 loss.backward()
86 |                 self.optimizer.step()
87 |                 self.scheduler.step()
88 | 
89 |                 train_losses.append(loss.detach())
90 |                 ppx = torch.exp(loss.detach())
91 |                 train_perplexity.append(ppx)
92 | 
93 |             train_losses = [loss.item() for loss in train_losses]
94 |             train_perplexity = [ppx.item() if not math.isinf(ppx.item()) else 1e+8 for ppx in train_perplexity]
95 |             train_loss = np.mean(train_losses)
96 |             train_ppx = np.mean(train_perplexity)
97 |             print(f'Train loss: {train_loss} \nTrain perplexity: {train_ppx}')
98 | 
99 |             self.last_epoch += 1
100 | 
101 |             valid_loss, valid_ppx = self.validate()
102 | 
103 |             if valid_loss < self.best_loss:
104 |                 self.best_loss = valid_loss
105 |                 state_dict = {
106 |                     'model_state_dict': self.model.state_dict(),
107 |                     'optim_state_dict': self.optimizer.state_dict(),
108 |                     'loss': self.best_loss,
109 |                     'epoch': self.last_epoch
110 |                 }
111 | 
112 |                 filename = f"{self.args['models_dir']}/model_best_{round(self.best_loss, 4)}.h5"  # 'models_dir', as defined in config.yml
113 |                 torch.save(state_dict, filename)
114 |                 print(f'Checkpoint saved: {filename}')
115 | 
116 |             print(f'Best valid loss: {self.best_loss}')
117 |             print(f'Valid loss: {valid_loss} \nValid perplexity: {valid_ppx}')
118 | 
119 |         print('Training completed')
120 | 
121 |     def validate(self):
122 |         print('Launch validation...')
123 |         self.model.eval()
124 | 
125 |         valid_losses = []
126 |         valid_ppxs = []
127 |         with torch.no_grad():
128 |             for i, batch in enumerate(tqdm(self.valid_loader)):
129 |                 input_ids, token_type_ids, labels = batch
130 |                 input_ids = input_ids.to(self.args['device'])
131 |                 token_type_ids = token_type_ids.to(self.args['device'])
132 |                 labels = labels.to(self.args['device'])
133 | 
134 |                 outputs = self.model(
135 |                     input_ids=input_ids,
136 |                     token_type_ids=token_type_ids,
137 |                     labels=labels
138 |                 )
139 | 
140 |                 loss, logits = outputs[0], outputs[1]
141 | 
142 |                 valid_losses.append(loss.detach())
143 |                 ppx = torch.exp(loss.detach())
144 |                 valid_ppxs.append(ppx)
145 | 
146 |         valid_losses = [loss.item() for loss in valid_losses]
147 |         valid_ppxs = [ppx.item() if not math.isinf(ppx.item()) else 1e+8 for ppx in valid_ppxs]
148 |         valid_loss = np.mean(valid_losses)
149 |         valid_ppx = np.mean(valid_ppxs)
150 | 
151 |         if math.isnan(valid_ppx):
152 |             valid_ppx = 1e+8
153 | 
154 |         return valid_loss, valid_ppx
155 | 
156 |     def _load_checkpoint(self):
157 |         path = self.args['checkpoint']
158 |         if os.path.exists(path):
159 |             print('Loading checkpoint...')
160 |             checkpoint = torch.load(path, map_location=self.args['device'])
161 |             self.model.load_state_dict(checkpoint['model_state_dict'])
162 | 
163 |             print(f'The training restarts with the specified checkpoint: {os.path.basename(path)}')
164 |             self.optimizer.load_state_dict(checkpoint['optim_state_dict'])
165 |             self.best_loss = checkpoint['loss']
166 |             self.last_epoch = checkpoint['epoch']
167 |         else:
168 |             print("Can't find the specified checkpoint")
--------------------------------------------------------------------------------
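A short aside on the scheduler configured in `Trainer.__init__` above: the warmup length follows from `warmup_ratio`, `num_epochs` and the size of the training loader. A back-of-the-envelope sketch (the batch count below is a made-up illustrative value; the real one depends on the processed dataset and `batch_size`):

```python
# Illustrative only: how Trainer.__init__ sizes the polynomial-decay LR schedule.
num_epochs = 10        # from config.yml
warmup_ratio = 0.1     # from config.yml
num_batches = 5000     # hypothetical len(train_loader)

total_train_steps = num_epochs * num_batches           # 50000 optimizer steps overall
warmup_steps = int(warmup_ratio * total_train_steps)   # 5000 steps of LR warmup
print(total_train_steps, warmup_steps)
```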
/chatbot/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | 
4 | import nltk
5 | from nltk.stem import WordNetLemmatizer
6 | 
7 | import torch
8 | from torch.nn.utils.rnn import pad_sequence
9 | 
10 | 
11 | class PadCollate:
12 |     def __init__(self, args):
13 |         self.args = args
14 | 
15 |     def __call__(self, batch):
16 |         eos_id = self.args['eos_id']
17 |         input_ids, token_type_ids, labels = [], [], []
18 | 
19 |         for idx, seqs in enumerate(batch):
20 |             input_ids.append(torch.LongTensor(seqs[0]))
21 |             token_type_ids.append(torch.LongTensor(seqs[1]))  # token type ids are the second element of each dataset item
22 |             labels.append(torch.LongTensor(seqs[2]))
23 | 
24 |         input_ids = pad_sequence(input_ids, batch_first=True, padding_value=eos_id)
25 |         token_type_ids = pad_sequence(token_type_ids, batch_first=True, padding_value=eos_id)
26 |         labels = pad_sequence(labels, batch_first=True, padding_value=-100)
27 | 
28 |         return input_ids, token_type_ids, labels
29 | 
30 | 
31 | def set_seed(seed):
32 |     np.random.seed(seed)
33 |     torch.manual_seed(seed)
34 |     torch.cuda.manual_seed_all(seed)
35 |     random.seed(seed)
36 | 
37 | 
38 | def top_k_filter(logits, top_k=0., threshold=-float('Inf'), filter_value=-float('Inf')):
39 |     assert logits.dim() == 1
40 |     top_k = min(top_k, logits.size(-1))
41 | 
42 |     if top_k > 0:
43 |         indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
44 |         logits[indices_to_remove] = filter_value
45 | 
46 |     indices_to_remove = logits < threshold
47 |     logits[indices_to_remove] = filter_value
48 | 
49 |     return logits
50 | 
51 | 
52 | def lemma_sentence(text):
53 |     lemmatizer = WordNetLemmatizer()
54 |     stop_words = ['stop', 'the', 'to', 'and', 'a', 'in', 'it', '\'s', 'is', 'I', 'that', 'had', 'on', 'for', 'were', 'was']
55 |     tokenization = [word for word in nltk.word_tokenize(text) if word not in stop_words]
56 |     sentence = ' '.join([lemmatizer.lemmatize(word) for word in tokenization])
57 |     return sentence
58 | 
--------------------------------------------------------------------------------
/dataset/train_dialogues.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xcapt0/gpt2_chatbot/3f63c85a6c58c7c95609c172422e3473eca10294/dataset/train_dialogues.pickle
--------------------------------------------------------------------------------
/dataset/train_ids.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xcapt0/gpt2_chatbot/3f63c85a6c58c7c95609c172422e3473eca10294/dataset/train_ids.pickle
--------------------------------------------------------------------------------
/dataset/valid_dialogues.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xcapt0/gpt2_chatbot/3f63c85a6c58c7c95609c172422e3473eca10294/dataset/valid_dialogues.pickle
--------------------------------------------------------------------------------
/dataset/valid_ids.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xcapt0/gpt2_chatbot/3f63c85a6c58c7c95609c172422e3473eca10294/dataset/valid_ids.pickle
--------------------------------------------------------------------------------
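The four pickle files above are what `Dialogues.save()` in `chatbot/data.py` writes: `*_dialogues.pickle` holds lists of dialogues as cleaned utterance strings, and `*_ids.pickle` holds the same dialogues as GPT-2 token-id lists. A quick way to sanity-check them (paths assume the repository root as the working directory):

```python
import pickle

# Inspect the preprocessed data produced by Dialogues.save() in chatbot/data.py.
with open('dataset/train_dialogues.pickle', 'rb') as f:
    dialogues = pickle.load(f)   # list of dialogues; each dialogue is a list of utterance strings

with open('dataset/train_ids.pickle', 'rb') as f:
    ids = pickle.load(f)         # same nesting, but each utterance is a list of token ids

print(len(dialogues), len(ids))  # both files contain the same number of dialogues
print(dialogues[0][0])           # first utterance of the first dialogue
print(ids[0][0][:10])            # its first ten token ids
```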
/requirements.txt:
--------------------------------------------------------------------------------
1 | aiohttp==3.8.1
2 | aiosignal==1.2.0
3 | async-timeout==4.0.2
4 | attrs==21.3.0
5 | certifi==2021.10.8
6 | charset-normalizer==2.0.9
7 | click==8.0.3
8 | colorama==0.4.4
9 | datasets==1.17.0
10 | dill==0.3.4
11 | filelock==3.4.2
12 | frozenlist==1.2.0
13 | fsspec==2021.11.1
14 | huggingface-hub==0.2.1
15 | idna==3.3
16 | joblib==1.1.0
17 | multidict==5.2.0
18 | multiprocess==0.70.12.2
19 | nltk==3.6.7
20 | numpy==1.21.5
21 | packaging==21.3
22 | pandas==1.3.5
23 | pyarrow==6.0.1
24 | pyparsing==3.0.6
25 | python-dateutil==2.8.2
26 | pytz==2021.3
27 | PyYAML==6.0
28 | regex==2021.11.10
29 | requests==2.26.0
30 | sacremoses==0.0.46
31 | six==1.16.0
32 | tokenizers==0.10.3
33 | torch==1.10.1
34 | tqdm==4.62.3
35 | transformers==4.15.0
36 | typing_extensions==4.0.1
37 | urllib3==1.26.7
38 | xxhash==2.0.2
39 | yarl==1.7.2
40 | 
--------------------------------------------------------------------------------
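Finally, a small illustration of how `PadCollate` in `chatbot/utils.py` batches the variable-length examples from `DialoguesDataset`: input ids and token type ids are right-padded with the EOS id, while labels are padded with `-100` so the padded positions are ignored by the language-modeling loss. A self-contained sketch with toy ids:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Toy ids only; real batches come from DialoguesDataset via PadCollate.
eos_id = 50256  # GPT-2's <|endoftext|> id, used here as the padding value
batch = [
    ([5, 6, 7, 8], [1, 1, 2, 2], [-100, -100, 7, 8]),  # (input_ids, token_type_ids, labels)
    ([5, 6],       [1, 1],       [-100, 6]),
]

input_ids = pad_sequence([torch.LongTensor(s[0]) for s in batch], batch_first=True, padding_value=eos_id)
labels = pad_sequence([torch.LongTensor(s[2]) for s in batch], batch_first=True, padding_value=-100)
print(input_ids)  # shape (2, 4); the shorter example is padded with eos_id
print(labels)     # padded positions are -100, so the loss skips them
```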