├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── chatbot
│   ├── chatbot.py
│   ├── config.yml
│   ├── data.py
│   ├── interact.py
│   ├── processing.py
│   ├── train.py
│   └── utils.py
├── dataset
│   ├── train_dialogues.pickle
│   ├── train_ids.pickle
│   ├── valid_dialogues.pickle
│   └── valid_ids.pickle
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | /.idea
2 | /__pycache__
3 | /models
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.8.1
2 | RUN mkdir /app
3 | COPY . /app
4 | WORKDIR /app/chatbot
5 | RUN pip install -r /app/requirements.txt
6 | ENTRYPOINT ["bash"]
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 xcapt0
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ☕ GPT2 Chatbot
2 |
3 | GPT-2 chatbot for daily conversations, trained on the `DailyDialog`, `Empathetic Dialogues`, `PERSONA-CHAT`, and `Blended Skill Talk` datasets. The chatbot is built on the GPT-2 transformer model with a language modeling head on top.
4 |
5 | 
6 |
7 |
8 | ## ⌛ Installation
9 |
10 | Download the [model](https://gpt2chatbot.s3.us-east-2.amazonaws.com/model.h5) from AWS S3 storage, then clone the repository and build the Docker image:
11 |
12 | ```sh
13 | git clone https://github.com/xcapt0/gpt2_chatbot.git && cd gpt2_chatbot
14 | docker build -t gpt2_bot .
15 | ```
16 |
17 | ## 🤖 Usage
18 |
19 | Run the docker container:
20 | ```sh
21 | docker run --rm -it gpt2_bot
22 | ```
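
The downloaded `model.h5` is not part of the image, so mount it into the container if you want to chat with the pretrained model. A minimal sketch, assuming the file sits in your current directory (the in-container path is just an example that matches the `WORKDIR` from the Dockerfile):

```sh
docker run --rm -it -v "$(pwd)/model.h5:/app/chatbot/model.h5" gpt2_bot
```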
23 |
24 | There are two ways to use the chatbot: `train` mode and `interact` mode.
25 |
26 | ### Interaction mode
27 | To launch the chatbot, run the following command, specifying the `--checkpoint` path to your model:
28 | ```sh
29 | python chatbot.py --mode interact --checkpoint path/to/model.h5
30 | ```
31 |
32 | ### Train mode
33 | To train the model, run the following command. Specify `--checkpoint` if you need to resume training from a saved model:
34 | ```sh
35 | python chatbot.py --mode train
36 | ```
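
Passing `--checkpoint` resumes training from a saved checkpoint: `train.py` restores the model weights, optimizer state, best loss and last epoch. For example:

```sh
python chatbot.py --mode train --checkpoint path/to/model.h5
```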
37 |
38 | ## 📝 License
39 |
40 | Copyright © 2022 [xcapt0](https://github.com/xcapt0).
41 | This project is [MIT](https://github.com/xcapt0/gpt2_chatbot/blob/master/LICENSE) licensed.
42 |
--------------------------------------------------------------------------------
/chatbot/chatbot.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import torch
3 | import nltk
4 | from glob import glob
5 | from argparse import ArgumentParser
6 | from transformers import GPT2Tokenizer, GPT2LMHeadModel
7 |
8 | from data import Dialogues
9 | from utils import set_seed
10 |
11 |
12 | def main(args):
13 | set_seed(args['seed'])
14 |
15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16 | args['device'] = device
17 |
18 | tokenizer = load_tokenizer(args)
19 | model = load_model(args, tokenizer, device)
20 |
21 | if dataset_is_missing(args):
22 | dialogues = Dialogues(tokenizer, args)
23 | train_dataset, valid_dataset = dialogues.load()
24 |
25 | dataset_types = ['train', 'valid']
26 | datasets = [train_dataset, valid_dataset]
27 |
28 | for dataset_type, dataset in zip(dataset_types, datasets):
29 | dialogues.save(dataset_type, tokenizer, dataset)
30 |
31 | if args['mode'] == 'train':
32 | from train import Trainer
33 | trainer = Trainer(model, args)
34 | trainer.train()
35 | elif args['mode'] == 'interact':
36 | from interact import Chatbot
37 | chatbot = Chatbot(model, tokenizer, args)
38 | chatbot.run()
39 |
40 |
41 | def dataset_is_missing(args):
42 | if len(glob(f'{args["dataset_dir"]}/*.pickle')) == 0:
43 | return True
44 | return False
45 |
46 |
47 | def load_tokenizer(args):
48 | tokenizer = GPT2Tokenizer.from_pretrained(args['model_name'])
49 | special_tokens = ['<sp1>', '<sp2>']
50 | tokenizer.add_special_tokens({
51 | 'bos_token': '<bos>',
52 | 'additional_special_tokens': special_tokens
53 | })
54 |
55 | # add new token ids to args
56 | special_tokens += ['<bos>', tokenizer.eos_token]
57 | sp1_id, sp2_id, bos_id, eos_id = tokenizer.encode(special_tokens)
58 | args['sp1_id'] = sp1_id
59 | args['sp2_id'] = sp2_id
60 | args['bos_id'] = bos_id
61 | args['eos_id'] = eos_id
62 |
63 | return tokenizer
64 |
65 |
66 | def load_model(args, tokenizer, device):
67 | model = GPT2LMHeadModel.from_pretrained(args['model_name']).to(device)
68 | model.resize_token_embeddings(len(tokenizer))
69 | return model
70 |
71 |
72 | if __name__ == '__main__':
73 | nltk.download('wordnet')
74 | nltk.download('omw-1.4')
75 | nltk.download('punkt')
76 |
77 | parser = ArgumentParser()
78 | parser.add_argument('--mode', type=str, required=True,
79 | help='Pass "train" for training mode and "interact" for interaction mode')
80 | parser.add_argument('--checkpoint', type=str, default=None, help='Path to checkpoint of the model')
81 |
82 | user_args = parser.parse_args()
83 | arguments = yaml.safe_load(open('config.yml'))
84 | arguments.update(vars(user_args))
85 |
86 | main(arguments)
87 |
--------------------------------------------------------------------------------
/chatbot/config.yml:
--------------------------------------------------------------------------------
1 | dataset_dir: "../dataset" # name of the directory where files are stored
2 | train_frac: 0.85 # ratio of the conversations to be included in the train set
3 | model_name: "gpt2" # name of the model for tokenizer and transformer
4 | seed: 8459 # random seed
5 | lr: 0.00002 # learning rate
6 | warmup_ratio: 0.1 # ratio of warmup steps to the total training steps
7 | batch_size: 8 # batch size
8 | num_epochs: 10 # number of total epochs
9 | max_len: 100 # maximum length of input sequence
10 | max_history: 5 # maximum number of dialogue histories to include
11 | models_dir: "../models" # directory name for saved checkpoints
12 | stop_command: "bye" # command that stops the conversation in interaction mode
13 | top_p: 0.9 # nucleus (top-p) sampling threshold
14 | top_k: 50 # top-k sampling cutoff
15 | temperature: 0.7 # randomness of predictions
--------------------------------------------------------------------------------
/chatbot/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from itertools import chain
4 | from tqdm.auto import tqdm
5 | from torch.utils.data import Dataset
6 |
7 | from processing import Processing
8 |
9 |
10 | class Dialogues(Processing):
11 | def __init__(self, tokenizer, args):
12 | self.tokenizer = tokenizer
13 | self.args = args
14 | self.dataset_list = ['daily_dialog', 'empathetic_dialogues', 'persona_chat', 'blended_skill_talk']
15 | super().__init__(tokenizer, args['train_frac'])
16 |
17 | def load(self):
18 | train_dataset = []
19 | valid_dataset = []
20 |
21 | for dataset_name in self.dataset_list:
22 | print(f'Loading {dataset_name} dataset...')
23 |
24 | train_dialogues, valid_dialogues = self._load_dialog(dataset=dataset_name)
25 | train_dataset += train_dialogues
26 | valid_dataset += valid_dialogues
27 |
28 | return train_dataset, valid_dataset
29 |
30 | def save(self, prefix, tokenizer, dialogues):
31 | print(f'Saving {prefix} dialogues to file...')
32 |
33 | if not os.path.isdir(self.args["dataset_dir"]):
34 | os.makedirs(self.args["dataset_dir"])
35 |
36 | dialogues_path = f'{self.args["dataset_dir"]}/{prefix}_dialogues.pickle'
37 | ids_path = f'{self.args["dataset_dir"]}/{prefix}_ids.pickle'
38 |
39 | with open(dialogues_path, 'wb') as f:
40 | pickle.dump(dialogues, f)
41 |
42 | print(f'Saving {prefix} ids to file...')
43 | ids = []
44 | for dialogue in tqdm(dialogues):
45 | dialogue_ids = []
46 | for utter in dialogue:
47 | tokens = tokenizer.tokenize(utter)
48 | token_ids = tokenizer.encode(tokens)
49 | dialogue_ids.append(token_ids)
50 | ids.append(dialogue_ids)
51 |
52 | with open(ids_path, 'wb') as f:
53 | pickle.dump(ids, f)
54 |
55 | print('Saving complete!')
56 |
57 | def _load_dialog(self, dataset=None):
58 | if dataset == 'daily_dialog':
59 | return self._load_daily()
60 | elif dataset == 'empathetic_dialogues':
61 | return self._load_empathetic()
62 | elif dataset == 'persona_chat':
63 | return self._load_persona()
64 | elif dataset == 'blended_skill_talk':
65 | return self._load_blended()
66 |
67 |
68 | class DialoguesDataset(Dataset):
69 | def __init__(self, prefix, args):
70 | self.input_ids = []
71 | self.token_type_ids = []
72 | self.labels = []
73 | self._prepare_data(prefix, args)
74 |
75 | def __len__(self):
76 | return len(self.input_ids)
77 |
78 | def __getitem__(self, idx):
79 | return self.input_ids[idx], self.token_type_ids[idx], self.labels[idx]
80 |
81 | def _prepare_data(self, prefix, args):
82 | with open(f'{args["dataset_dir"]}/{prefix}_ids.pickle', 'rb') as f:
83 | dials = pickle.load(f)
84 |
85 | for dial in tqdm(dials):
86 | hists = []
87 | for i, sentence in enumerate(dial):
88 | if i % 2 == 0:
89 | hists.append([args['sp1_id']] + sentence)
90 | else:
91 | hists.append([args['sp2_id']] + sentence)
92 |
93 | for i in range(len(hists)):
94 | if hists[i][0] == args['sp2_id']:
95 | for j in range(0, i):
96 | contexts = hists[j:i + 1]
97 | if len(contexts) > args['max_history']:
98 | num_exceeded = len(contexts) - args['max_history']
99 | contexts = contexts[num_exceeded:]
100 | if len(contexts) < 2:
101 | break
102 |
103 | input_ids = [args['bos_id']] + list(chain.from_iterable(contexts)) + [args['eos_id']]
104 | if len(input_ids) <= args['max_len']:
105 | start_sp_id, next_sp_id = contexts[0][0], contexts[1][0]
106 | token_type_ids = [[start_sp_id] * len(ctx) if c % 2 == 0 else [next_sp_id] * len(ctx) for c, ctx in enumerate(contexts)]
107 | token_type_ids = [start_sp_id] + list(chain.from_iterable(token_type_ids)) + [args['sp2_id']]
108 |
109 | labels = [[-100] * len(ctx) if c < len(contexts) - 1 else [-100] + ctx[1:] for c, ctx in enumerate(contexts)]
110 | labels = [-100] + list(chain.from_iterable(labels)) + [args['eos_id']]
111 |
112 | self.input_ids.append(input_ids)
113 | self.token_type_ids.append(token_type_ids)
114 | self.labels.append(labels)
115 |
116 | break
117 |
118 | del dials
119 |
--------------------------------------------------------------------------------
/chatbot/interact.py:
--------------------------------------------------------------------------------
1 | import os
2 | from itertools import chain
3 |
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | from utils import top_k_filter, lemma_sentence
8 |
9 |
10 | class Chatbot:
11 | def __init__(self, model, tokenizer, args):
12 | self.model = model
13 | self.tokenizer = tokenizer
14 | self.args = args
15 |
16 | def run(self):
17 | assert self.args['checkpoint'], 'Checkpoint was not found. Please specify a valid checkpoint via --checkpoint CHECKPOINT_PATH'
18 | self._load_checkpoint()
19 |
20 | print('Launching the chatbot...')
21 | print(f'If you want to stop, type the \"{self.args["stop_command"]}\" command')
22 |
23 | self.model.eval()
24 |
25 | with torch.no_grad():
26 | input_hists = []
27 |
28 | while True:
29 | sentence = input('You: ')
30 | if sentence == self.args['stop_command']:
31 | print('Bot: Good bye.')
32 | break
33 |
34 | sentence = lemma_sentence(sentence)
35 |
36 | input_ids = [self.args['sp1_id']] + self.tokenizer.encode(sentence)
37 | input_hists.append(input_ids)
38 |
39 | if len(input_hists) >= self.args['max_history']:
40 | num_exceeded = len(input_hists) - self.args['max_history']
41 | input_hists = input_hists[num_exceeded:]
42 |
43 | input_ids = [self.args['bos_id']] + list(chain.from_iterable(input_hists)) + [self.args['sp2_id']]
44 | start_sp_id = input_hists[0][0]
45 | next_sp_id = self.args['sp1_id'] if start_sp_id == self.args['sp2_id'] else self.args['sp2_id']
46 | token_type_ids = [[start_sp_id] * len(hist) if h % 2 == 0 else [next_sp_id] * len(hist) for h, hist in enumerate(input_hists)]
47 | assert len(token_type_ids) == len(input_hists)
48 | token_type_ids = [start_sp_id] + list(chain.from_iterable(token_type_ids)) + [self.args['sp2_id']]
49 | assert len(input_ids) == len(token_type_ids)
50 | input_len = len(input_ids)
51 |
52 | input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(self.args['device'])
53 | token_type_ids = torch.LongTensor(token_type_ids).unsqueeze(0).to(self.args['device'])
54 |
55 | output_ids = self._top_filtering(input_ids, token_type_ids)
56 | answer = self.tokenizer.decode(output_ids, skip_special_tokens=True)
57 |
58 | print(f'Bot: {answer}')
59 | input_hists.append([self.args['sp2_id']] + self.tokenizer.encode(answer))
60 |
61 | def _top_filtering(self, input_ids, token_type_ids):
62 | output_ids = []
63 |
64 | for pos in range(self.args['max_len']):
65 | output = self.model(input_ids=input_ids, token_type_ids=token_type_ids)[0]
66 |
67 | logits = output[0, -1, :] / self.args['temperature']
68 | logits = top_k_filter(logits, top_k=self.args['top_k'])
69 | output = F.softmax(logits, dim=-1).unsqueeze(0)
70 |
71 | sorted_probs, sorted_idxs = torch.sort(output, descending=True)
72 | cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
73 | idx_remove = cumsum_probs > self.args['top_p']
74 | idx_remove[:, 1:] = idx_remove[:, :-1].clone()
75 | idx_remove[:, 0] = False
76 | sorted_probs[idx_remove] = 0.0
77 | sorted_probs /= torch.sum(sorted_probs, dim=-1, keepdim=True)
78 |
79 | probs = torch.zeros(output.shape, device=self.args['device']).scatter_(-1, sorted_idxs, sorted_probs)
80 | idx = torch.multinomial(probs, 1)
81 |
82 | idx_item = idx.squeeze(-1).squeeze(-1).item()
83 |
84 | if idx_item in output_ids:
85 | continue
86 |
87 | output_ids.append(idx_item)
88 |
89 | if idx_item == self.args['eos_id']:
90 | break
91 |
92 | input_ids = torch.cat((input_ids, idx.reshape(1, 1)), dim=-1)
93 | next_type_id = torch.LongTensor([[self.args['sp2_id']]]).to(self.args['device'])
94 | token_type_ids = torch.cat((token_type_ids, next_type_id), dim=-1)
95 | assert input_ids.shape == token_type_ids.shape
96 |
97 | return output_ids
98 |
99 | def _load_checkpoint(self):
100 | path = self.args['checkpoint']
101 | if os.path.exists(path):
102 | print('Loading checkpoint...')
103 | checkpoint = torch.load(path, map_location=self.args['device'])
104 | self.model.load_state_dict(checkpoint['model_state_dict'])
105 | print(f'Found checkpoint file: {os.path.basename(path)}')
106 | else:
107 | print("Can't find the specified checkpoint")
108 |
--------------------------------------------------------------------------------
/chatbot/processing.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | from tqdm.auto import tqdm
3 |
4 |
5 | class Processing:
6 | def __init__(self, tokenizer, train_frac):
7 | self.tokenizer = tokenizer
8 | self.train_frac = train_frac
9 |
10 | def _load_daily(self):
11 | dataset = load_dataset('daily_dialog')
12 | train_dialogues = dataset['train']['dialog']
13 | valid_dialogues = dataset['validation']['dialog']
14 | test_dialogues = dataset['test']['dialog']
15 |
16 | all_dialogues = train_dialogues + valid_dialogues + test_dialogues
17 |
18 | for i, dialogue in enumerate(tqdm(all_dialogues)):
19 | new_dialogue = []
20 | for sentence in dialogue:
21 | token_list = self.tokenizer.tokenize(sentence.strip().replace('’', '\''))
22 | token_list = self._process_token_list(token_list)
23 | text = self.tokenizer.convert_tokens_to_string(token_list)
24 | new_dialogue.append(text)
25 |
26 | all_dialogues[i] = new_dialogue
27 |
28 | train_dialogues = all_dialogues[:int(len(all_dialogues) * self.train_frac)]
29 | valid_dialogues = all_dialogues[int(len(all_dialogues) * self.train_frac):]
30 |
31 | return train_dialogues, valid_dialogues
32 |
33 | def _load_empathetic(self):
34 | dataset = load_dataset('empathetic_dialogues')
35 | train_data = dataset['train']
36 | valid_data = dataset['validation']
37 | test_data = dataset['test']
38 |
39 | sentences = train_data['utterance'] + valid_data['utterance'] + test_data['utterance']
40 | total_conv_ids = train_data['conv_id'] + valid_data['conv_id'] + test_data['conv_id']
41 | total_speaker_ids = train_data['speaker_idx'] + valid_data['speaker_idx'] + test_data['speaker_idx']
42 |
43 | conv_dict = {}
44 | cur_speaker_idx = -1
45 | for i, sentence in enumerate(tqdm(sentences)):
46 | conv_id = total_conv_ids[i]
47 | speaker_idx = total_speaker_ids[i]
48 |
49 | sentence_modified = sentence.strip().replace('_comma_', ',')
50 | new_token_list = self._process_token_list(self.tokenizer.tokenize(sentence_modified))
51 | text = self.tokenizer.convert_tokens_to_string(new_token_list)
52 |
53 | if '_conv' in sentence:
54 | continue
55 |
56 | if conv_id not in conv_dict:
57 | conv_dict[conv_id] = []
58 | cur_speaker_idx = -1
59 |
60 | if cur_speaker_idx != speaker_idx:
61 | conv_dict[conv_id].append(text)
62 | cur_speaker_idx = speaker_idx
63 | else:
64 | conv_dict[conv_id][-1] += f" {text}"
65 |
66 | train_dialogues = []
67 | valid_dialogues = []
68 |
69 | train_dialogue_num = int(len(conv_dict) * self.train_frac)
70 | for i, (conv_id, utter_list) in enumerate(conv_dict.items()):
71 | if i < train_dialogue_num:
72 | train_dialogues.append(utter_list)
73 | else:
74 | valid_dialogues.append(utter_list)
75 |
76 | return train_dialogues, valid_dialogues
77 |
78 | def _load_persona(self):
79 | import requests
80 |
81 | url = 'https://s3.amazonaws.com/datasets.huggingface.co/personachat/personachat_self_original.json'
82 | response = requests.get(url)
83 | assert response.status_code == 200, 'Error receiving data from server'
84 | dataset = response.json()
85 |
86 | train_data = dataset['train']
87 | valid_data = dataset['valid']
88 | all_data = train_data + valid_data
89 | all_dialogues = []
90 |
91 | for obj in tqdm(all_data):
92 | dialogue = obj['utterances'][-1]['history']
93 | new_dialogue = []
94 |
95 | for i, sentence in enumerate(dialogue):
96 | if sentence.strip() != '__ SILENCE __':
97 | token_list = self.tokenizer.tokenize(sentence.strip())
98 | new_token_list = self._process_token_list(token_list)
99 | text = self.tokenizer.convert_tokens_to_string(new_token_list)
100 | new_dialogue.append(text)
101 |
102 | all_dialogues.append(new_dialogue)
103 |
104 | train_dialogues = all_dialogues[:int(len(all_dialogues) * self.train_frac)]
105 | valid_dialogues = all_dialogues[int(len(all_dialogues) * self.train_frac):]
106 |
107 | return train_dialogues, valid_dialogues
108 |
109 | def _load_blended(self):
110 | dataset = load_dataset('blended_skill_talk')
111 | data_train = dataset['train']
112 | data_valid = dataset['validation']
113 | data_test = dataset['test']
114 |
115 | all_previous_sentences = data_train['previous_utterance'] + \
116 | data_valid['previous_utterance'] + \
117 | data_test['previous_utterance']
118 | all_free_messages = data_train['free_messages'] + \
119 | data_valid['free_messages'] + \
120 | data_test['free_messages']
121 | all_guided_messages = data_train['guided_messages'] + \
122 | data_valid['guided_messages'] + \
123 | data_test['guided_messages']
124 |
125 | all_dialogues = []
126 | for i, free_message in enumerate(tqdm(all_free_messages)):
127 | free_message_list = [sentence.strip() for sentence in free_message if len(sentence.strip()) > 0]
128 | guided_message_list = [sentence.strip() for sentence in all_guided_messages[i] if len(sentence.strip()) > 0]
129 | dialogue = all_previous_sentences[i]
130 |
131 | for j in range(len(free_message_list)):
132 | token_list = self._process_token_list(self.tokenizer.tokenize(free_message_list[j]))
133 | text = self.tokenizer.convert_tokens_to_string(token_list)
134 | dialogue.append(text)
135 |
136 | if j < len(guided_message_list):
137 | token_list = self._process_token_list(self.tokenizer.tokenize(guided_message_list[j]))
138 | text = self.tokenizer.convert_tokens_to_string(token_list)
139 | dialogue.append(text)
140 |
141 | all_dialogues.append(dialogue)
142 |
143 | train_dialogues = all_dialogues[:int(len(all_dialogues) * self.train_frac)]
144 | valid_dialogues = all_dialogues[int(len(all_dialogues) * self.train_frac):]
145 |
146 | return train_dialogues, valid_dialogues
147 |
148 | @staticmethod
149 | def _process_token_list(token_list):
150 | space = 'Ġ'
151 | quotes = ['"', '\'']
152 | end_marks = ['.', ',', '?', '!', '...']
153 | abbreviations = ['s', 'd', 't', 'm', 're', 'll', 've', 'S', 'D', 'T', 'M', 'Re', 'Ll', 'Ve']
154 | token_list[0] = token_list[0].capitalize()
155 | quote_count = 0
156 |
157 | for i, token in enumerate(token_list):
158 | if space in token:
159 | if token[1:] in end_marks or token[1:] in abbreviations:
160 | token_list[i] = token[1:]
161 |
162 | if token[1:] == quotes[1]:
163 | if i < len(token_list) - 1:
164 | if token_list[i + 1] in abbreviations or (
165 | token_list[i + 1][0] == space and token_list[i + 1][1:] in abbreviations):
166 | token_list[i] = token[1:]
167 |
168 | if token[0] == space and token[1:] in quotes:
169 | if quote_count % 2 == 1:
170 | token_list[i] = token[1:]
171 | quote_count = 0
172 | else:
173 | if i < len(token_list) - 1 and token_list[i + 1][0] == space:
174 | token_list[i + 1] = token_list[i + 1][1:]
175 | quote_count += 1
176 |
177 | if token in end_marks or token[1:] in end_marks:
178 | if i < len(token_list) - 1:
179 | if token_list[i + 1][0] != space:
180 | token_list[i + 1] = space + token_list[i + 1].capitalize()
181 | else:
182 | token_list[i + 1] = space + token_list[i + 1][1:].capitalize()
183 |
184 | new_token_list = [token for token in token_list if token != space and len(token) > 0]
185 | if new_token_list[-1] not in end_marks:
186 | new_token_list.append(end_marks[0])
187 |
188 | return new_token_list
189 |
--------------------------------------------------------------------------------
/chatbot/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import numpy as np
4 | from tqdm.auto import tqdm
5 |
6 | import torch
7 | from torch.utils.data import DataLoader
8 | from torch.optim import AdamW
9 | from transformers import get_polynomial_decay_schedule_with_warmup
10 |
11 | from data import DialoguesDataset
12 | from utils import PadCollate
13 |
14 |
15 | class Trainer:
16 | def __init__(self, model, args):
17 | print('Loading the optimizer...')
18 | self.optimizer = AdamW(model.parameters(), lr=args['lr'])
19 | self.best_loss = 1e+10
20 | self.last_epoch = 0
21 |
22 | print('Loading train & valid data...')
23 | train_dataset = DialoguesDataset('train', args)
24 | valid_dataset = DialoguesDataset('valid', args)
25 | pad = PadCollate(args)
26 |
27 | self.train_loader = DataLoader(train_dataset,
28 | collate_fn=pad,
29 | shuffle=True,
30 | batch_size=args['batch_size'],
31 | num_workers=1,
32 | pin_memory=True)
33 | self.valid_loader = DataLoader(valid_dataset,
34 | collate_fn=pad,
35 | batch_size=args['batch_size'],
36 | num_workers=1,
37 | pin_memory=True)
38 |
39 | if not os.path.exists(args['models_dir']):
40 | os.makedirs(args['models_dir'])
41 |
42 | # Calculate total training steps
43 | num_batches = len(self.train_loader)
44 | total_train_steps = args['num_epochs'] * num_batches
45 | warmup_steps = int(args['warmup_ratio'] * total_train_steps)
46 |
47 | self.model = model
48 | self.args = args
49 | self.scheduler = get_polynomial_decay_schedule_with_warmup(
50 | self.optimizer,
51 | num_warmup_steps=warmup_steps,
52 | num_training_steps=total_train_steps,
53 | power=2
54 | )
55 |
56 | if args['checkpoint']:
57 | self._load_checkpoint()
58 |
59 | def train(self):
60 | print('Launch training...')
61 |
62 | start_epoch = self.last_epoch + 1
63 | for epoch in range(start_epoch, start_epoch + self.args['num_epochs']):
64 | print('-' * 50 + f'\nEpoch: {epoch}\n' + '-' * 50)
65 |
66 | self.model.train()
67 | train_losses = []
68 | train_perplexity = []
69 |
70 | for i, batch in enumerate(tqdm(self.train_loader)):
71 | input_ids, token_type_ids, labels = batch
72 | input_ids = input_ids.to(self.args['device'])
73 | token_type_ids = token_type_ids.to(self.args['device'])
74 | labels = labels.to(self.args['device'])
75 |
76 | outputs = self.model(
77 | input_ids=input_ids,
78 | token_type_ids=token_type_ids,
79 | labels=labels
80 | )
81 |
82 | loss, logits = outputs[0], outputs[1]
83 |
84 | self.optimizer.zero_grad()
85 | loss.backward()
86 | self.optimizer.step()
87 | self.scheduler.step()
88 |
89 | train_losses.append(loss.detach())
90 | ppx = torch.exp(loss.detach())
91 | train_perplexity.append(ppx)
92 |
93 | train_losses = [loss.item() for loss in train_losses]
94 | train_perplexity = [ppx.item() if not math.isinf(ppx.item()) else 1e+8 for ppx in train_perplexity]
95 | train_loss = np.mean(train_losses)
96 | train_ppx = np.mean(train_perplexity)
97 | print(f'Train loss: {train_loss} \nTrain perplexity: {train_ppx}')
98 |
99 | self.last_epoch += 1
100 |
101 | valid_loss, valid_ppx = self.validate()
102 |
103 | if valid_loss < self.best_loss:
104 | self.best_loss = valid_loss
105 | state_dict = {
106 | 'model_state_dict': self.model.state_dict(),
107 | 'optim_state_dict': self.optimizer.state_dict(),
108 | 'loss': self.best_loss,
109 | 'epoch': self.last_epoch
110 | }
111 |
112 | filename = f"{self.args['models_dir']}/model_best_{round(self.best_loss, 4)}.h5"
113 | torch.save(state_dict, filename)
114 | print(f'Checkpoint saved: {filename}')
115 |
116 | print(f'Best valid loss: {self.best_loss}')
117 | print(f'Valid loss: {valid_loss} \nValid perplexity: {valid_ppx}')
118 |
119 | print('Training completed')
120 |
121 | def validate(self):
122 | print('Launch validation...')
123 | self.model.eval()
124 |
125 | valid_losses = []
126 | valid_ppxs = []
127 | with torch.no_grad():
128 | for i, batch in enumerate(tqdm(self.valid_loader)):
129 | input_ids, token_type_ids, labels = batch
130 | input_ids = input_ids.to(self.args['device'])
131 | token_type_ids = token_type_ids.to(self.args['device'])
132 | labels = labels.to(self.args['device'])
133 |
134 | outputs = self.model(
135 | input_ids=input_ids,
136 | token_type_ids=token_type_ids,
137 | labels=labels
138 | )
139 |
140 | loss, logits = outputs[0], outputs[1]
141 |
142 | valid_losses.append(loss.detach())
143 | ppx = torch.exp(loss.detach())
144 | valid_ppxs.append(ppx)
145 |
146 | valid_losses = [loss.item() for loss in valid_losses]
147 | valid_ppxs = [ppx.item() if not math.isinf(ppx.item()) else 1e+8 for ppx in valid_ppxs]
148 | valid_loss = np.mean(valid_losses)
149 | valid_ppx = np.mean(valid_ppxs)
150 |
151 | if math.isnan(valid_ppx):
152 | valid_ppx = 1e+8
153 |
154 | return valid_loss, valid_ppx
155 |
156 | def _load_checkpoint(self):
157 | path = self.args['checkpoint']
158 | if os.path.exists(path):
159 | print('Loading checkpoint...')
160 | checkpoint = torch.load(path, map_location=self.args['device'])
161 | self.model.load_state_dict(checkpoint['model_state_dict'])
162 |
163 | print(f'Training resumes from the specified checkpoint: {os.path.basename(path)}')
164 | self.optimizer.load_state_dict(checkpoint['optim_state_dict'])
165 | self.best_loss = checkpoint['loss']
166 | self.last_epoch = checkpoint['epoch']
167 | else:
168 | print("Can't find the specified checkpoint")
--------------------------------------------------------------------------------
/chatbot/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 |
4 | import nltk
5 | from nltk.stem import WordNetLemmatizer
6 |
7 | import torch
8 | from torch.nn.utils.rnn import pad_sequence
9 |
10 |
11 | class PadCollate:
12 | def __init__(self, args):
13 | self.args = args
14 |
15 | def __call__(self, batch):
16 | eos_id = self.args['eos_id']
17 | input_ids, token_type_ids, labels = [], [], []
18 |
19 | for idx, seqs in enumerate(batch):
20 | input_ids.append(torch.LongTensor(seqs[0]))
21 | token_type_ids.append(torch.LongTensor(seqs[1]))
22 | labels.append(torch.LongTensor(seqs[2]))
23 |
24 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=eos_id)
25 | token_type_ids = pad_sequence(token_type_ids, batch_first=True, padding_value=eos_id)
26 | labels = pad_sequence(labels, batch_first=True, padding_value=-100)
27 |
28 | return input_ids, token_type_ids, labels
29 |
30 |
31 | def set_seed(seed):
32 | np.random.seed(seed)
33 | torch.manual_seed(seed)
34 | torch.cuda.manual_seed_all(seed)
35 | random.seed(seed)
36 |
37 |
38 | def top_k_filter(logits, top_k=0., threshold=-float('Inf'), filter_value=-float('Inf')):
39 | assert logits.dim() == 1
40 | top_k = min(top_k, logits.size(-1))
41 |
42 | if top_k > 0:
43 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
44 | logits[indices_to_remove] = filter_value
45 |
46 | indices_to_remove = logits < threshold
47 | logits[indices_to_remove] = filter_value
48 |
49 | return logits
50 |
51 |
52 | def lemma_sentence(text):
53 | lemmatizer = WordNetLemmatizer()
54 | stop_words = ['stop', 'the', 'to', 'and', 'a', 'in', 'it', '\'s', 'is', 'I', 'that', 'had', 'on', 'for', 'were', 'was']
55 | tokenization = [word for word in nltk.word_tokenize(text) if word not in stop_words]
56 | sentence = ' '.join([lemmatizer.lemmatize(word) for word in tokenization])
57 | return sentence
58 |
--------------------------------------------------------------------------------
/dataset/train_dialogues.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xcapt0/gpt2_chatbot/3f63c85a6c58c7c95609c172422e3473eca10294/dataset/train_dialogues.pickle
--------------------------------------------------------------------------------
/dataset/train_ids.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xcapt0/gpt2_chatbot/3f63c85a6c58c7c95609c172422e3473eca10294/dataset/train_ids.pickle
--------------------------------------------------------------------------------
/dataset/valid_dialogues.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xcapt0/gpt2_chatbot/3f63c85a6c58c7c95609c172422e3473eca10294/dataset/valid_dialogues.pickle
--------------------------------------------------------------------------------
/dataset/valid_ids.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xcapt0/gpt2_chatbot/3f63c85a6c58c7c95609c172422e3473eca10294/dataset/valid_ids.pickle
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | aiohttp==3.8.1
2 | aiosignal==1.2.0
3 | async-timeout==4.0.2
4 | attrs==21.3.0
5 | certifi==2021.10.8
6 | charset-normalizer==2.0.9
7 | click==8.0.3
8 | colorama==0.4.4
9 | datasets==1.17.0
10 | dill==0.3.4
11 | filelock==3.4.2
12 | frozenlist==1.2.0
13 | fsspec==2021.11.1
14 | huggingface-hub==0.2.1
15 | idna==3.3
16 | joblib==1.1.0
17 | multidict==5.2.0
18 | multiprocess==0.70.12.2
19 | nltk==3.6.7
20 | numpy==1.21.5
21 | packaging==21.3
22 | pandas==1.3.5
23 | pyarrow==6.0.1
24 | pyparsing==3.0.6
25 | python-dateutil==2.8.2
26 | pytz==2021.3
27 | PyYAML==6.0
28 | regex==2021.11.10
29 | requests==2.26.0
30 | sacremoses==0.0.46
31 | six==1.16.0
32 | tokenizers==0.10.3
33 | torch==1.10.1
34 | tqdm==4.62.3
35 | transformers==4.15.0
36 | typing_extensions==4.0.1
37 | urllib3==1.26.7
38 | xxhash==2.0.2
39 | yarl==1.7.2
40 |
--------------------------------------------------------------------------------