├── code ├── generation │ ├── __init__.py │ ├── model │ │ ├── __init__.py │ │ ├── loss.py │ │ ├── text.py │ │ ├── optim.py │ │ ├── utils.py │ │ ├── dataset.py │ │ ├── transformer_module.py │ │ ├── trainer.py │ │ ├── postprocessing.py │ │ ├── trainer_kd.py │ │ └── transformer_model.py │ ├── run_dialog_kd.py │ ├── run.py │ ├── run_dialog.py │ ├── metrics.py │ ├── train.py │ ├── train_kd.py │ └── config.py └── retrieval │ ├── readme │ ├── run.sh │ ├── pytorch_pretrained_bert │ ├── __init__.py │ ├── convert_tf_checkpoint_to_pytorch.py │ ├── convert_gpt2_checkpoint_to_pytorch.py │ ├── convert_openai_checkpoint_to_pytorch.py │ ├── __main__.py │ ├── convert_transfo_xl_checkpoint_to_pytorch.py │ ├── optimization_openai.py │ ├── file_utils.py │ ├── tokenization_gpt2.py │ ├── tokenization_openai.py │ ├── optimization.py │ └── tokenization.py │ └── metrics.py └── README.md /code/generation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /code/generation/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /code/retrieval/readme: -------------------------------------------------------------------------------- 1 | Modified from https://github.com/huggingface/transformers. -------------------------------------------------------------------------------- /code/retrieval/run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ## for training 4 | python run_distilling.py \ 5 | --task_name buy_data \ 6 | --do_train \ 7 | --do_eval \ 8 | --do_lower_case \ 9 | --data_dir $data_dir \ 10 | --teacher_bert_model $teacher \ 11 | --student_bert_model $student \ 12 | --max_seq_length 64 \ 13 | --train_batch_size 128 \ 14 | --eval_batch_size 128 \ 15 | --learning_rate 2e-5 \ 16 | --num_train_epochs 10.0 \ 17 | --output_dir $output \ 18 | --temperature 1 --alpha 0.5 19 | 20 | ## for testing 21 | python run_ranker_test.py \ 22 | --task_name buy_data \ 23 | --do_eval \ 24 | --do_lower_case \ 25 | --data_dir $data_dir \ 26 | --bert_model $model \ 27 | --max_seq_length 64 \ 28 | --train_batch_size 128 \ 29 | --eval_batch_size 128 \ 30 | --learning_rate 2e-5 \ 31 | --num_train_epochs 10.0 \ 32 | --output_dir $output 33 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dialogue Distillation 2 | 3 | Code and data for the EMNLP 2020 long paper "[Dialogue Distillation: Open-domain Dialogue Augmentation Using Unpaired Data](https://arxiv.org/abs/2009.09427)" 4 | 5 | ## Code 6 | 7 | `code/generation`: code for the generation-based dialogue model 8 | 9 | `code/retrieval`: code for the retrieval-based dialogue model 10 | 11 | ## Data 12 | 13 | The data we used can be downloaded from this [link](https://drive.google.com/file/d/1mNQf7QydWGhxPE1-1IW0yfwSLJJ9zVG7/view?usp=sharing). 14 | 15 | ## Citation 16 | 17 | Please cite our EMNLP paper if you find our work useful :) 18 | 19 | @inproceedings{zhang2020distill, 20 | title={Dialogue Distillation: Open-domain Dialogue Augmentation Using Unpaired Data}, 21 | author={Zhang, Rongsheng and Zheng, Yinhe and Shao, Jianzhi and Mao, Xiaoxi and Xi, Yadong and Huang, Minlie}, 22 | booktitle={EMNLP}, 23 | year={2020}, 24 | url={https://arxiv.org/abs/2009.09427} 25 
| } -------------------------------------------------------------------------------- /code/retrieval/pytorch_pretrained_bert/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.6.1" 2 | from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer 3 | from .tokenization_openai import OpenAIGPTTokenizer 4 | from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus) 5 | # from .tokenization_gpt2 import GPT2Tokenizer # omit regex install error 6 | 7 | from .modeling import (BertConfig, BertModel, BertForPreTraining, 8 | BertForMaskedLM, BertForNextSentencePrediction, 9 | BertForSequenceClassification, BertForMultipleChoice, 10 | BertForTokenClassification, BertForQuestionAnswering, 11 | load_tf_weights_in_bert) 12 | from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel, 13 | OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, 14 | load_tf_weights_in_openai_gpt) 15 | from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel, 16 | load_tf_weights_in_transfo_xl) 17 | from .modeling_gpt2 import (GPT2Config, GPT2Model, 18 | GPT2LMHeadModel, GPT2DoubleHeadsModel, 19 | load_tf_weights_in_gpt2) 20 | 21 | from .optimization import BertAdam 22 | from .optimization_openai import OpenAIAdam 23 | 24 | from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path 25 | -------------------------------------------------------------------------------- /code/generation/model/loss.py: -------------------------------------------------------------------------------- 1 | # transformer_chatbot 2 | # Copyright (C) 2018 Golovanov, Tselousov 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 
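# Illustrative sketch (not from this repo): run.sh above passes --temperature and
# --alpha to run_distilling.py, which is not included in this dump.  A standard
# Hinton-style distillation objective consistent with those two flags is shown
# below; every name in this snippet is an assumption made for illustration only.
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, temperature=1.0, alpha=0.5):
    # hard-label cross entropy on the student predictions
    ce = F.cross_entropy(student_logits, labels)
    # soft-label KL divergence between temperature-scaled teacher and student distributions
    kd = F.kl_div(F.log_softmax(student_logits / temperature, dim=-1),
                  F.softmax(teacher_logits / temperature, dim=-1),
                  reduction='batchmean') * (temperature ** 2)
    # alpha blends the hard-label and soft-label terms
    return alpha * ce + (1.0 - alpha) * kd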
16 | 17 | import torch 18 | import torch.nn as nn 19 | 20 | 21 | class LabelSmoothingLoss(nn.Module): 22 | def __init__(self, n_labels, smoothing=0.0, ignore_index=-100, size_average=True): 23 | super(LabelSmoothingLoss, self).__init__() 24 | assert 0 <= smoothing <= 1 25 | 26 | self.ignore_index = ignore_index 27 | self.confidence = 1 - smoothing 28 | 29 | if smoothing > 0: 30 | print('using label smoothing') 31 | self.criterion = nn.KLDivLoss(size_average=size_average) 32 | n_ignore_idxs = 1 + (ignore_index >= 0) 33 | one_hot = torch.full((1, n_labels), fill_value=(smoothing / (n_labels - n_ignore_idxs))) 34 | if ignore_index >= 0: 35 | one_hot[0, ignore_index] = 0 36 | self.register_buffer('one_hot', one_hot) 37 | else: 38 | self.criterion = nn.NLLLoss(size_average=size_average, ignore_index=ignore_index) 39 | 40 | def forward(self, log_inputs, targets): 41 | if self.confidence < 1: 42 | tdata = targets.data 43 | 44 | tmp = self.one_hot.repeat(targets.shape[0], 1) 45 | tmp.scatter_(1, tdata.unsqueeze(1), self.confidence) 46 | 47 | if self.ignore_index >= 0: 48 | mask = torch.nonzero(tdata.eq(self.ignore_index)).squeeze(-1) 49 | if mask.numel() > 0: 50 | tmp.index_fill_(0, mask, 0) 51 | 52 | targets = tmp 53 | 54 | return self.criterion(log_inputs, targets) 55 | -------------------------------------------------------------------------------- /code/generation/model/text.py: -------------------------------------------------------------------------------- 1 | # transformer_chatbot 2 | # Copyright (C) 2018 Golovanov, Tselousov 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 16 | 17 | class myVocab: 18 | spl = '
<spl>' 19 | pad = '<pad>' 20 | eos = '<eos>' 21 | unk = '<unk>' 22 | 23 | def __init__(self, vocab_file): 24 | # TODO: add check for special tokens 25 | self.spec_tokens = [myVocab.spl, myVocab.pad, myVocab.eos, myVocab.unk] 26 | with open(vocab_file, 'r', encoding='utf8') as fr: 27 | vocab = [line.strip('\n').split()[0] for line in fr.readlines()] 28 | vocab = self.spec_tokens + vocab 29 | self.token2id = {t: i for i, t in enumerate(vocab)} 30 | self.id2token = {i: t for i, t in enumerate(vocab)} 31 | 32 | def __len__(self): 33 | return len(self.token2id) 34 | 35 | @property 36 | def n_special_tokens(self): 37 | return len(self.spec_tokens) 38 | 39 | @property 40 | def special_tokens_ids(self): 41 | return [self.token2id[t] for t in self.spec_tokens] 42 | 43 | @property 44 | def pad_id(self): 45 | return self.token2id[myVocab.pad] 46 | 47 | @property 48 | def spl_id(self): 49 | return self.token2id[myVocab.spl] 50 | 51 | @property 52 | def bos_id(self): 53 | return self.token2id[myVocab.eos] 54 | 55 | @property 56 | def eos_id(self): 57 | return self.token2id[myVocab.eos] 58 | 59 | def string2ids(self, string): 60 | tokens = string.split() 61 | ids = [self.token2id[t] for t in tokens if t in self.token2id] 62 | return ids 63 | 64 | def ids2string(self, ids): 65 | tokens = [self.id2token[id] for id in ids] 66 | return ''.join(tokens) 67 | -------------------------------------------------------------------------------- /code/retrieval/pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
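# Illustrative sketch (not from this repo): how the generation code uses the
# myVocab class above.  The vocab file path and the example sentence are
# assumptions; run.py imports the class in the same way.
from model.text import myVocab

vocab = myVocab('dataset/dialog/vocab.txt')
ids = [vocab.bos_id] + vocab.string2ids('你 好 呀') + [vocab.eos_id]  # whitespace-tokenized text in
text = vocab.ids2string(ids[1:-1])                                   # characters joined back without spaces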
15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import re 23 | import argparse 24 | import tensorflow as tf 25 | import torch 26 | import numpy as np 27 | 28 | from pytorch_pretrained_bert.modeling import BertConfig, BertForPreTraining, load_tf_weights_in_bert 29 | 30 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 31 | # Initialise PyTorch model 32 | config = BertConfig.from_json_file(bert_config_file) 33 | print("Building PyTorch model from configuration: {}".format(str(config))) 34 | model = BertForPreTraining(config) 35 | 36 | # Load weights from tf checkpoint 37 | load_tf_weights_in_bert(model, tf_checkpoint_path) 38 | 39 | # Save pytorch-model 40 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 41 | torch.save(model.state_dict(), pytorch_dump_path) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | ## Required parameters 47 | parser.add_argument("--tf_checkpoint_path", 48 | default = None, 49 | type = str, 50 | required = True, 51 | help = "Path the TensorFlow checkpoint path.") 52 | parser.add_argument("--bert_config_file", 53 | default = None, 54 | type = str, 55 | required = True, 56 | help = "The config json file corresponding to the pre-trained BERT model. \n" 57 | "This specifies the model architecture.") 58 | parser.add_argument("--pytorch_dump_path", 59 | default = None, 60 | type = str, 61 | required = True, 62 | help = "Path to the output PyTorch model.") 63 | args = parser.parse_args() 64 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 65 | args.bert_config_file, 66 | args.pytorch_dump_path) 67 | -------------------------------------------------------------------------------- /code/generation/run_dialog_kd.py: -------------------------------------------------------------------------------- 1 | from run_dialog_interface import dialog 2 | from metrics import calc_f1, calc_bleu, calc_distinct, calc_avg_len 3 | 4 | bz1 = 5; lp1 = 1.6 5 | bz2 = 5; lp2 = 1.0 6 | bz3 = 5; lp3 = 2.0 7 | bz4 = 5; lp4 = 1.6 8 | bz5 = 5; lp5 = 1.6 9 | 10 | crowded = dialog('/root/generation_with_augmentation/checkpoints/dialog_v2/crowded/last_checkpoint15', bz1, lp1) 11 | fake = dialog('/root/generation_with_augmentation/checkpoints/dialog_fake_v2/last_checkpoint7', bz2, lp2) 12 | fake_kd = dialog('/root/generation_with_augmentation/checkpoints/dialog_kd/fake/last_checkpoint28', bz3, lp3) 13 | crowded_fake = dialog('/root/generation_with_augmentation/checkpoints/dialog_v2/crowded_fake/last_checkpoint18', bz4, lp4) 14 | crowded_fake_kd = dialog('/root/generation_with_augmentation/checkpoints/dialog_kd/crowded_fake/last_checkpoint29', bz5, lp5) 15 | 16 | 17 | crowded_pairs = [] 18 | fake_pairs = [] 19 | fake_kd_pairs = [] 20 | crowded_fake_pairs = [] 21 | crowded_fake_kd_pairs = [] 22 | 23 | if __name__ == '__main__': 24 | with open('dataset/dialog/test_1k.txt', 'r', encoding='utf8') as fr: 25 | lines = fr.readlines() 26 | cnt = 0 27 | with open('dataset/dialog/test_1k_output_kd_bm5_lp_diff_v2.txt', 'w', encoding='utf8') as fw: 28 | fw.write('question\tanswer\tcrowded\tfake\tfake_kd\tcrowded_fake\tcrowded_fake_kd\n') 29 | for line in lines: 30 | cnt += 1 31 | if cnt%100==0: 32 | print(cnt) 33 | q,a = line.strip('\n').split('\t') 34 | a1 = crowded.answer_beams(q) 35 | crowded_pairs.append([list(a1), list(a)]) 36 | a2 = fake.answer_beams(q) 37 | 
fake_pairs.append([list(a2), list(a)]) 38 | a3 = fake_kd.answer_beams(q) 39 | fake_kd_pairs.append([list(a3), list(a)]) 40 | a4 = crowded_fake.answer_beams(q) 41 | crowded_fake_pairs.append([list(a4), list(a)]) 42 | a5 = crowded_fake_kd.answer_beams(q) 43 | crowded_fake_kd_pairs.append([list(a5), list(a)]) 44 | res = [q, a, a1, a2, a3, a4, a5] 45 | fw.write('\t'.join(res) + '\n') 46 | for res in [crowded_pairs, fake_pairs, fake_kd_pairs, crowded_fake_pairs, crowded_fake_kd_pairs]: 47 | f1 = calc_f1(res) 48 | bleu = calc_bleu(res) 49 | distinct = calc_distinct(res) 50 | avg_len = calc_avg_len(res) 51 | print('f1: ', f1) 52 | print('bleu: ', bleu) 53 | print('distinct: ', distinct) 54 | print('avg_len: ', avg_len) 55 | ''' 56 | while True: 57 | message = input('>') 58 | print('crowded', crowded.answer_beams(message)) 59 | print('crowded_fake', crowded_fake.answer_beams(message)) 60 | print('fake', fake.answer_beams(message)) 61 | ''' -------------------------------------------------------------------------------- /code/retrieval/pytorch_pretrained_bert/convert_gpt2_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
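# Illustrative sketch (not from this repo): the evaluation loop in
# run_dialog_kd.py above collects [prediction_chars, reference_chars] pairs and
# feeds them to the character-level metrics imported from metrics.py.  The two
# toy pairs below are made up for illustration.
from metrics import calc_f1, calc_bleu, calc_distinct, calc_avg_len

pairs = [[list('今天天气不错'), list('今天天气很好')],
         [list('我也喜欢'), list('我也是')]]
print(calc_f1(pairs))        # character-level F1
print(calc_bleu(pairs))      # [bleu1, bleu2]
print(calc_distinct(pairs))  # [distinct-1, distinct-2]
print(calc_avg_len(pairs))   # mean prediction length in characters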
15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_pretrained_bert.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME, 25 | GPT2Config, 26 | GPT2Model, 27 | load_tf_weights_in_gpt2) 28 | 29 | 30 | def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path): 31 | # Construct model 32 | if gpt2_config_file == "": 33 | config = GPT2Config() 34 | else: 35 | config = GPT2Config(gpt2_config_file) 36 | model = GPT2Model(config) 37 | 38 | # Load weights from numpy 39 | load_tf_weights_in_gpt2(model, gpt2_checkpoint_path) 40 | 41 | # Save pytorch-model 42 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 43 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 44 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 45 | torch.save(model.state_dict(), pytorch_weights_dump_path) 46 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 48 | f.write(config.to_json_string()) 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | ## Required parameters 54 | parser.add_argument("--gpt2_checkpoint_path", 55 | default = None, 56 | type = str, 57 | required = True, 58 | help = "Path the TensorFlow checkpoint path.") 59 | parser.add_argument("--pytorch_dump_folder_path", 60 | default = None, 61 | type = str, 62 | required = True, 63 | help = "Path to the output PyTorch model.") 64 | parser.add_argument("--gpt2_config_file", 65 | default = "", 66 | type = str, 67 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n" 68 | "This specifies the model architecture.") 69 | args = parser.parse_args() 70 | convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, 71 | args.gpt2_config_file, 72 | args.pytorch_dump_folder_path) 73 | -------------------------------------------------------------------------------- /code/retrieval/pytorch_pretrained_bert/convert_openai_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
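# Illustrative sketch (not from this repo): the GPT-2 converter above writes
# WEIGHTS_NAME and CONFIG_NAME into --pytorch_dump_folder_path; loading the
# result back mirrors its save logic.  'converted_gpt2/' is an assumed folder.
import torch
from pytorch_pretrained_bert.modeling_gpt2 import CONFIG_NAME, WEIGHTS_NAME, GPT2Config, GPT2Model

config = GPT2Config('converted_gpt2/' + CONFIG_NAME)   # config json written by the converter
model = GPT2Model(config)
model.load_state_dict(torch.load('converted_gpt2/' + WEIGHTS_NAME, map_location='cpu'))
model.eval()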
15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_pretrained_bert.modeling_openai import (CONFIG_NAME, WEIGHTS_NAME, 25 | OpenAIGPTConfig, 26 | OpenAIGPTModel, 27 | load_tf_weights_in_openai_gpt) 28 | 29 | 30 | def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): 31 | # Construct model 32 | if openai_config_file == "": 33 | config = OpenAIGPTConfig() 34 | else: 35 | config = OpenAIGPTConfig(openai_config_file) 36 | model = OpenAIGPTModel(config) 37 | 38 | # Load weights from numpy 39 | load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path) 40 | 41 | # Save pytorch-model 42 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 43 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 44 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 45 | torch.save(model.state_dict(), pytorch_weights_dump_path) 46 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 48 | f.write(config.to_json_string()) 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | ## Required parameters 54 | parser.add_argument("--openai_checkpoint_folder_path", 55 | default = None, 56 | type = str, 57 | required = True, 58 | help = "Path the TensorFlow checkpoint path.") 59 | parser.add_argument("--pytorch_dump_folder_path", 60 | default = None, 61 | type = str, 62 | required = True, 63 | help = "Path to the output PyTorch model.") 64 | parser.add_argument("--openai_config_file", 65 | default = "", 66 | type = str, 67 | help = "An optional config json file corresponding to the pre-trained OpenAI model. 
\n" 68 | "This specifies the model architecture.") 69 | args = parser.parse_args() 70 | convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path, 71 | args.openai_config_file, 72 | args.pytorch_dump_folder_path) 73 | -------------------------------------------------------------------------------- /code/retrieval/metrics.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | 4 | def mean_average_precision(sort_data): 5 | # to do 6 | count_1 = 0 7 | sum_precision = 0 8 | for index in range(len(sort_data)): 9 | if sort_data[index][1] == 1: 10 | count_1 += 1 11 | sum_precision += 1.0 * count_1 / (index + 1) 12 | return sum_precision / count_1 if count_1 != 0 else 0.0 13 | 14 | 15 | def mean_reciprocal_rank(sort_data): 16 | sort_label = [s_d[1] for s_d in sort_data] 17 | assert 1 in sort_label 18 | return 1.0 / (1 + sort_label.index(1)) 19 | 20 | 21 | def precision_at_position_1(sort_data): 22 | if sort_data[0][1] == 1: 23 | return 1 24 | else: 25 | return 0 26 | 27 | 28 | def recall_at_position_k_in_10(sort_data, k): 29 | sort_label = [s_d[1] for s_d in sort_data] 30 | select_label = sort_label[:k] 31 | return 1.0 * select_label.count(1) / sort_label.count(1) 32 | 33 | 34 | def evaluation_one_session(data): 35 | sort_data = sorted(data, key=lambda x: x[0], reverse=True) 36 | m_a_p = mean_average_precision(sort_data) 37 | m_r_r = mean_reciprocal_rank(sort_data) 38 | p_1 = precision_at_position_1(sort_data) 39 | r_1 = recall_at_position_k_in_10(sort_data, 1) 40 | r_2 = recall_at_position_k_in_10(sort_data, 2) 41 | r_5 = recall_at_position_k_in_10(sort_data, 5) 42 | return m_a_p, m_r_r, p_1, r_1, r_2, r_5 43 | 44 | 45 | def evaluate(file_path, n=10): 46 | sum_m_a_p = 0 47 | sum_m_r_r = 0 48 | sum_p_1 = 0 49 | sum_r_1 = 0 50 | sum_r_2 = 0 51 | sum_r_5 = 0 52 | 53 | i = 0 54 | total_num = 0 55 | data = None 56 | with open(file_path, 'r') as infile: 57 | for line in infile: 58 | if i % n == 0: 59 | data = [] 60 | 61 | tokens = line.strip().split('\t') 62 | if len(tokens) < 2: 63 | print('i', i, 'tokens', tokens) 64 | data.append((float(tokens[0]), int(tokens[1]))) 65 | 66 | if i % n == n - 1: 67 | if 1 not in [s_d[1] for s_d in data]: 68 | i += 1; continue 69 | total_num += 1 70 | m_a_p, m_r_r, p_1, r_1, r_2, r_5 = evaluation_one_session(data) 71 | sum_m_a_p += m_a_p 72 | sum_m_r_r += m_r_r 73 | sum_p_1 += p_1 74 | sum_r_1 += r_1 75 | sum_r_2 += r_2 76 | sum_r_5 += r_5 77 | 78 | i += 1 79 | 80 | print('total num: %s' % total_num) 81 | print('MAP: {}'.format(1.0*sum_m_a_p/total_num)) 82 | print('MRR: {}'.format(1.0*sum_m_r_r/total_num)) 83 | print('P@1: {}'.format(1.0*sum_p_1/total_num)) 84 | print('R{}@1: {}'.format(n, str(1.0*sum_r_1/total_num))) 85 | print('R{}@2: {}'.format(n, str(1.0*sum_r_2/total_num))) 86 | print('R{}@5: {}'.format(n, str(1.0*sum_r_5/total_num))) 87 | # print('R10@1: %s' %(1.0*sum_r_1/total_num)) 88 | # print('R10@2: %s' %(1.0*sum_r_2/total_num)) 89 | # print('R10@5: %s' %(1.0*sum_r_5/total_num)) 90 | return (1.0 * sum_m_a_p / total_num, 1.0 * sum_m_r_r / total_num, 1.0 * sum_p_1 / total_num, 91 | 1.0 * sum_r_1 / total_num, 1.0 * sum_r_2 / total_num, 1.0 * sum_r_5 / total_num) 92 | 93 | 94 | if __name__ == '__main__': 95 | print(sys.argv) 96 | result = evaluate(sys.argv[1], int(sys.argv[2])) 97 | # for r in result: 98 | # print(r) 99 | 100 | 101 | """ 102 | i = 0 103 | cnt = 0 104 | utterances, responses, labels = [], [], [] 105 | for line in open('data/test.txt'): 106 | contexts = line.strip().split('\t') 107 | uttes, 
resp, l = contexts[1:-1], contexts[-1], contexts[0] 108 | uttes = [utte.split() for utte in uttes] 109 | labels.append(int(l)) 110 | if i % 10 == 9: 111 | if 1 in labels[i-9:]: 112 | cnt += 10 113 | i += 1 114 | output: 115 | i: 10000 116 | cnt: 6670 117 | """ 118 | -------------------------------------------------------------------------------- /code/generation/run.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | from model.utils import load_openai_weights_chinese, set_seed, f1_score 4 | from model.transformer_model import TransformerModel 5 | from model.text import myVocab 6 | from config import get_model_config_poem, get_test_config_poem 7 | import readline 8 | 9 | 10 | def main(): 11 | model_config = get_model_config_poem() 12 | test_config = get_test_config_poem() 13 | 14 | set_seed(test_config.seed) 15 | device = torch.device(test_config.device) 16 | 17 | vocab = myVocab(model_config.vocab_path) 18 | 19 | transformer = TransformerModel(n_layers=model_config.n_layers, 20 | n_embeddings=len(vocab), 21 | n_pos_embeddings=model_config.n_pos_embeddings, 22 | embeddings_size=model_config.embeddings_size, 23 | padding_idx=vocab.pad_id, 24 | n_heads=model_config.n_heads, 25 | dropout=model_config.dropout, 26 | embed_dropout=model_config.embed_dropout, 27 | attn_dropout=model_config.attn_dropout, 28 | ff_dropout=model_config.ff_dropout, 29 | bos_id=vocab.bos_id, 30 | eos_id=vocab.eos_id, 31 | max_seq_len=model_config.max_seq_len, 32 | beam_size=model_config.beam_size, 33 | length_penalty=model_config.length_penalty, 34 | n_segments=model_config.n_segments, 35 | annealing_topk=model_config.annealing_topk, 36 | annealing=model_config.annealing, 37 | diversity_coef=model_config.diversity_coef, 38 | diversity_groups=model_config.diversity_groups) 39 | 40 | transformer = transformer.to(device) 41 | state_dict = torch.load(test_config.last_checkpoint_path, map_location=device) 42 | temp = dict(state_dict['model']) 43 | keys = list(temp.keys()) 44 | for key in keys: 45 | # new_key = '.'.join([i for i in key.split('.') if i != 'module']) 46 | new_key = key.replace('.module', '') 47 | temp[new_key] = temp.pop(key) 48 | transformer.load_state_dict(temp) 49 | transformer.eval() 50 | print('Weights loaded from {}'.format(test_config.last_checkpoint_path)) 51 | 52 | 53 | def answer(message): 54 | message = ' '.join(message) 55 | message = vocab.string2ids(message) 56 | message = [vocab.bos_id] + message + [vocab.eos_id] 57 | message = message[:60] 58 | # print(message) 59 | contexts = [torch.tensor([c], dtype=torch.long, device=device) for c in [message] if len(c) > 0] 60 | prediction = transformer.predict(contexts)[0] 61 | prediction_str = vocab.ids2string(prediction) 62 | return prediction_str 63 | 64 | def answer_beams(message): 65 | message = ' '.join(message) 66 | message = vocab.string2ids(message) 67 | message = [vocab.bos_id] + message + [vocab.eos_id] 68 | message = message[:20] 69 | # print(message) 70 | contexts = [torch.tensor([c], dtype=torch.long, device=device) for c in [message] if len(c) > 0] 71 | predictions = transformer.predict_beams(contexts)[0] 72 | prediction_strs = [vocab.ids2string(prediction) for prediction in predictions] 73 | return prediction_strs 74 | 75 | ''' 76 | with open('data/test200_output_noinit_noweight.txt', 'w', encoding='utf8') as fw: 77 | with open('data/test200.txt', 'r', encoding='utf8') as fr: 78 | lines = fr.readlines() 79 | for line in lines: 80 | post, response = 
line.strip('\n').replace(' ', '').split('\t') 81 | ans = answer(post) 82 | fw.write('source:' + post + '\t' + 'target:' + response + '\t' + 'answer:' + ans + '\n') 83 | ''' 84 | ''' 85 | while True: 86 | message = input('>') 87 | ans = answer(message) 88 | print(ans) 89 | ''' 90 | 91 | while True: 92 | message = input('>') 93 | ans = answer_beams(message) 94 | for i in ans: 95 | print(i) 96 | 97 | 98 | if __name__ == '__main__': 99 | main() -------------------------------------------------------------------------------- /code/generation/run_dialog.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | from model.utils import load_openai_weights_chinese, set_seed, f1_score 4 | from model.transformer_model import TransformerModel 5 | from model.text import myVocab 6 | from config import get_model_config_dialog, get_test_config_dialog 7 | import readline 8 | 9 | 10 | def main(): 11 | model_config = get_model_config_dialog() 12 | test_config = get_test_config_dialog() 13 | 14 | set_seed(test_config.seed) 15 | device = torch.device(test_config.device) 16 | 17 | vocab = myVocab(model_config.vocab_path) 18 | 19 | transformer = TransformerModel(n_layers=model_config.n_layers, 20 | n_embeddings=len(vocab), 21 | n_pos_embeddings=model_config.n_pos_embeddings, 22 | embeddings_size=model_config.embeddings_size, 23 | padding_idx=vocab.pad_id, 24 | n_heads=model_config.n_heads, 25 | dropout=model_config.dropout, 26 | embed_dropout=model_config.embed_dropout, 27 | attn_dropout=model_config.attn_dropout, 28 | ff_dropout=model_config.ff_dropout, 29 | bos_id=vocab.bos_id, 30 | eos_id=vocab.eos_id, 31 | max_seq_len=model_config.max_seq_len, 32 | beam_size=model_config.beam_size, 33 | length_penalty=model_config.length_penalty, 34 | n_segments=model_config.n_segments, 35 | annealing_topk=model_config.annealing_topk, 36 | annealing=model_config.annealing, 37 | diversity_coef=model_config.diversity_coef, 38 | diversity_groups=model_config.diversity_groups) 39 | 40 | transformer = transformer.to(device) 41 | state_dict = torch.load(test_config.last_checkpoint_path, map_location=device) 42 | temp = dict(state_dict['model']) 43 | keys = list(temp.keys()) 44 | for key in keys: 45 | # new_key = '.'.join([i for i in key.split('.') if i != 'module']) 46 | new_key = key.replace('.module', '') 47 | temp[new_key] = temp.pop(key) 48 | transformer.load_state_dict(temp) 49 | transformer.eval() 50 | print('Weights loaded from {}'.format(test_config.last_checkpoint_path)) 51 | 52 | 53 | def answer(message): 54 | message = ' '.join(message) 55 | message = vocab.string2ids(message) 56 | message = [vocab.bos_id] + message + [vocab.eos_id] 57 | message = message[:60] 58 | # print(message) 59 | contexts = [torch.tensor([c], dtype=torch.long, device=device) for c in [message] if len(c) > 0] 60 | prediction = transformer.predict(contexts)[0] 61 | prediction_str = vocab.ids2string(prediction) 62 | return prediction_str 63 | 64 | def answer_beams(message): 65 | message = ' '.join(message) 66 | message = vocab.string2ids(message) 67 | message = [vocab.bos_id] + message + [vocab.eos_id] 68 | message = message[:30] 69 | # print(message) 70 | contexts = [torch.tensor([c], dtype=torch.long, device=device) for c in [message] if len(c) > 0] 71 | predictions = transformer.predict_beams(contexts)[0] 72 | prediction_strs = [vocab.ids2string(prediction) for prediction in predictions] 73 | return prediction_strs 74 | 75 | ''' 76 | with 
open('data/test200_output_noinit_noweight.txt', 'w', encoding='utf8') as fw: 77 | with open('data/test200.txt', 'r', encoding='utf8') as fr: 78 | lines = fr.readlines() 79 | for line in lines: 80 | post, response = line.strip('\n').replace(' ', '').split('\t') 81 | ans = answer(post) 82 | fw.write('source:' + post + '\t' + 'target:' + response + '\t' + 'answer:' + ans + '\n') 83 | ''' 84 | ''' 85 | while True: 86 | message = input('>') 87 | ans = answer(message) 88 | print(ans) 89 | ''' 90 | 91 | while True: 92 | message = input('>') 93 | ans = answer_beams(message) 94 | for i in ans: 95 | print(i) 96 | 97 | 98 | if __name__ == '__main__': 99 | main() -------------------------------------------------------------------------------- /code/retrieval/pytorch_pretrained_bert/__main__.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | def main(): 3 | import sys 4 | if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [ 5 | "convert_tf_checkpoint_to_pytorch", 6 | "convert_openai_checkpoint", 7 | "convert_transfo_xl_checkpoint", 8 | "convert_gpt2_checkpoint", 9 | ]: 10 | print( 11 | "Should be used as one of: \n" 12 | ">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n" 13 | ">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n" 14 | ">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n" 15 | ">> `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`") 16 | else: 17 | if sys.argv[1] == "convert_tf_checkpoint_to_pytorch": 18 | try: 19 | from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch 20 | except ImportError: 21 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 22 | "In that case, it requires TensorFlow to be installed. Please see " 23 | "https://www.tensorflow.org/install/ for installation instructions.") 24 | raise 25 | 26 | if len(sys.argv) != 5: 27 | # pylint: disable=line-too-long 28 | print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`") 29 | else: 30 | PYTORCH_DUMP_OUTPUT = sys.argv.pop() 31 | TF_CONFIG = sys.argv.pop() 32 | TF_CHECKPOINT = sys.argv.pop() 33 | convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 34 | elif sys.argv[1] == "convert_openai_checkpoint": 35 | from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch 36 | OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2] 37 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 38 | if len(sys.argv) == 5: 39 | OPENAI_GPT_CONFIG = sys.argv[4] 40 | else: 41 | OPENAI_GPT_CONFIG = "" 42 | convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH, 43 | OPENAI_GPT_CONFIG, 44 | PYTORCH_DUMP_OUTPUT) 45 | elif sys.argv[1] == "convert_transfo_xl_checkpoint": 46 | try: 47 | from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch 48 | except ImportError: 49 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 50 | "In that case, it requires TensorFlow to be installed. 
Please see " 51 | "https://www.tensorflow.org/install/ for installation instructions.") 52 | raise 53 | 54 | if 'ckpt' in sys.argv[2].lower(): 55 | TF_CHECKPOINT = sys.argv[2] 56 | TF_DATASET_FILE = "" 57 | else: 58 | TF_DATASET_FILE = sys.argv[2] 59 | TF_CHECKPOINT = "" 60 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 61 | if len(sys.argv) == 5: 62 | TF_CONFIG = sys.argv[4] 63 | else: 64 | TF_CONFIG = "" 65 | convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE) 66 | else: 67 | try: 68 | from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch 69 | except ImportError: 70 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 71 | "In that case, it requires TensorFlow to be installed. Please see " 72 | "https://www.tensorflow.org/install/ for installation instructions.") 73 | raise 74 | 75 | TF_CHECKPOINT = sys.argv[2] 76 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 77 | if len(sys.argv) == 5: 78 | TF_CONFIG = sys.argv[4] 79 | else: 80 | TF_CONFIG = "" 81 | convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 82 | if __name__ == '__main__': 83 | main() 84 | -------------------------------------------------------------------------------- /code/generation/metrics.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Counter 3 | 4 | def get_dict(tokens, ngram, gdict=None): 5 | """ 6 | get_dict 7 | """ 8 | token_dict = {} 9 | if gdict is not None: 10 | token_dict = gdict 11 | tlen = len(tokens) 12 | for i in range(0, tlen - ngram + 1): 13 | ngram_token = "".join(tokens[i:(i + ngram)]) 14 | if token_dict.get(ngram_token) is not None: 15 | token_dict[ngram_token] += 1 16 | else: 17 | token_dict[ngram_token] = 1 18 | return token_dict 19 | 20 | 21 | def count(pred_tokens, gold_tokens, ngram, result): 22 | """ 23 | count 24 | """ 25 | cover_count, total_count = result 26 | pred_dict = get_dict(pred_tokens, ngram) 27 | gold_dict = get_dict(gold_tokens, ngram) 28 | cur_cover_count = 0 29 | cur_total_count = 0 30 | for token, freq in pred_dict.items(): 31 | if gold_dict.get(token) is not None: 32 | gold_freq = gold_dict[token] 33 | cur_cover_count += min(freq, gold_freq) 34 | cur_total_count += freq 35 | result[0] += cur_cover_count 36 | result[1] += cur_total_count 37 | 38 | 39 | def calc_bp(pair_list): 40 | """ 41 | calc_bp 42 | """ 43 | c_count = 0.0 44 | r_count = 0.0 45 | for pair in pair_list: 46 | pred_tokens, gold_tokens = pair 47 | c_count += len(pred_tokens) 48 | r_count += len(gold_tokens) 49 | bp = 1 50 | if c_count < r_count: 51 | bp = math.exp(1 - r_count / c_count) 52 | return bp 53 | 54 | 55 | def calc_cover_rate(pair_list, ngram): 56 | """ 57 | calc_cover_rate 58 | """ 59 | result = [0.0, 0.0] # [cover_count, total_count] 60 | for pair in pair_list: 61 | pred_tokens, gold_tokens = pair 62 | count(pred_tokens, gold_tokens, ngram, result) 63 | if result[1] == 0: 64 | cover_rate = 0 65 | else: 66 | cover_rate = result[0] / result[1] 67 | return cover_rate 68 | 69 | 70 | def calc_bleu(pair_list): 71 | """ 72 | calc_bleu: [[predict, golden], ...] 
73 | """ 74 | bp = calc_bp(pair_list) 75 | print('bp: ', bp) 76 | cover_rate1 = calc_cover_rate(pair_list, 1) 77 | cover_rate2 = calc_cover_rate(pair_list, 2) 78 | cover_rate3 = calc_cover_rate(pair_list, 3) 79 | bleu1 = 0 80 | bleu2 = 0 81 | bleu3 = 0 82 | if cover_rate1 > 0: 83 | bleu1 = bp * math.exp(math.log(cover_rate1)) 84 | if cover_rate2 > 0: 85 | bleu2 = bp * math.exp((math.log(cover_rate1) + math.log(cover_rate2)) / 2) 86 | if cover_rate3 > 0: 87 | bleu3 = bp * math.exp((math.log(cover_rate1) + math.log(cover_rate2) + math.log(cover_rate3)) / 3) 88 | return [bleu1, bleu2] 89 | 90 | 91 | def calc_distinct_ngram(pair_list, ngram): 92 | """ 93 | calc_distinct_ngram 94 | """ 95 | ngram_total = 0.0 96 | ngram_distinct_count = 0.0 97 | pred_dict = {} 98 | for predict_tokens, _ in pair_list: 99 | get_dict(predict_tokens, ngram, pred_dict) 100 | for key, freq in pred_dict.items(): 101 | ngram_total += freq 102 | ngram_distinct_count += 1 103 | #if freq == 1: 104 | # ngram_distinct_count += freq 105 | if ngram_total == 0: 106 | return 0 107 | return ngram_distinct_count / ngram_total 108 | 109 | 110 | def calc_distinct(pair_list): 111 | """ 112 | calc_distinct 113 | """ 114 | distinct1 = calc_distinct_ngram(pair_list, 1) 115 | distinct2 = calc_distinct_ngram(pair_list, 2) 116 | return [distinct1, distinct2] 117 | 118 | 119 | def calc_f1(data): 120 | """ 121 | calc_f1 122 | """ 123 | golden_char_total = 0.0 124 | pred_char_total = 0.0 125 | hit_char_total = 0.0 126 | for response, golden_response in data: 127 | #golden_response = "".join(golden_response).decode("utf8") 128 | #response = "".join(response).decode("utf8") 129 | golden_response = "".join(golden_response) 130 | response = "".join(response) 131 | common = Counter(response) & Counter(golden_response) 132 | hit_char_total += sum(common.values()) 133 | golden_char_total += len(golden_response) 134 | pred_char_total += len(response) 135 | if pred_char_total == 0: 136 | p = 0 137 | else: 138 | p = hit_char_total / pred_char_total 139 | if golden_char_total == 0: 140 | r = 0 141 | else: 142 | r = hit_char_total / golden_char_total 143 | if p + r == 0: 144 | f1 = 0 145 | else: 146 | f1 = 2 * p * r / (p + r) 147 | return f1 148 | 149 | def calc_avg_len(data): 150 | """ 151 | calc average length 152 | """ 153 | all_len = 0 154 | for response, golden_response in data: 155 | all_len += len(response) 156 | if len(data) == 0: 157 | return 0 158 | return all_len/len(data) -------------------------------------------------------------------------------- /code/retrieval/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
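# Illustrative sketch (not from this repo, made-up numbers): calc_bleu in
# code/generation/metrics.py above combines a brevity penalty with the geometric
# mean of n-gram precisions, i.e. bleu2 = bp * exp((log p1 + log p2) / 2).
import math

p1, p2, bp = 0.6, 0.3, 1.0   # assumed 1-gram / 2-gram precisions and brevity penalty
bleu2 = bp * math.exp((math.log(p1) + math.log(p2)) / 2)
print(round(bleu2, 4))       # 0.4243, the geometric mean of p1 and p2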
15 | """Convert Transformer XL checkpoint and datasets.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | import os 21 | import sys 22 | from io import open 23 | 24 | import torch 25 | 26 | import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils 27 | from pytorch_pretrained_bert.modeling_transfo_xl import (CONFIG_NAME, 28 | WEIGHTS_NAME, 29 | TransfoXLConfig, 30 | TransfoXLLMHeadModel, 31 | load_tf_weights_in_transfo_xl) 32 | from pytorch_pretrained_bert.tokenization_transfo_xl import (CORPUS_NAME, 33 | VOCAB_NAME) 34 | 35 | if sys.version_info[0] == 2: 36 | import cPickle as pickle 37 | else: 38 | import pickle 39 | 40 | # We do this to be able to load python 2 datasets pickles 41 | # See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918 42 | data_utils.Vocab = data_utils.TransfoXLTokenizer 43 | data_utils.Corpus = data_utils.TransfoXLCorpus 44 | sys.modules['data_utils'] = data_utils 45 | sys.modules['vocabulary'] = data_utils 46 | 47 | def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, 48 | transfo_xl_config_file, 49 | pytorch_dump_folder_path, 50 | transfo_xl_dataset_file): 51 | if transfo_xl_dataset_file: 52 | # Convert a pre-processed corpus (see original TensorFlow repo) 53 | with open(transfo_xl_dataset_file, "rb") as fp: 54 | corpus = pickle.load(fp, encoding="latin1") 55 | # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term) 56 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME 57 | print("Save vocabulary to {}".format(pytorch_vocab_dump_path)) 58 | corpus_vocab_dict = corpus.vocab.__dict__ 59 | torch.save(corpus_vocab_dict, pytorch_vocab_dump_path) 60 | 61 | corpus_dict_no_vocab = corpus.__dict__ 62 | corpus_dict_no_vocab.pop('vocab', None) 63 | pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME 64 | print("Save dataset to {}".format(pytorch_dataset_dump_path)) 65 | torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path) 66 | 67 | if tf_checkpoint_path: 68 | # Convert a pre-trained TensorFlow model 69 | config_path = os.path.abspath(transfo_xl_config_file) 70 | tf_path = os.path.abspath(tf_checkpoint_path) 71 | 72 | print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path)) 73 | # Initialise PyTorch model 74 | if transfo_xl_config_file == "": 75 | config = TransfoXLConfig() 76 | else: 77 | config = TransfoXLConfig(transfo_xl_config_file) 78 | print("Building PyTorch model from configuration: {}".format(str(config))) 79 | model = TransfoXLLMHeadModel(config) 80 | 81 | model = load_tf_weights_in_transfo_xl(model, config, tf_path) 82 | # Save pytorch-model 83 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) 84 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) 85 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path))) 86 | torch.save(model.state_dict(), pytorch_weights_dump_path) 87 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path))) 88 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 89 | f.write(config.to_json_string()) 90 | 91 | 92 | if __name__ == "__main__": 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument("--pytorch_dump_folder_path", 95 | default = None, 96 | type = str, 97 | required = True, 98 | help = "Path to the folder 
to store the PyTorch model or dataset/vocab.") 99 | parser.add_argument("--tf_checkpoint_path", 100 | default = "", 101 | type = str, 102 | help = "An optional path to a TensorFlow checkpoint path to be converted.") 103 | parser.add_argument("--transfo_xl_config_file", 104 | default = "", 105 | type = str, 106 | help = "An optional config json file corresponding to the pre-trained BERT model. \n" 107 | "This specifies the model architecture.") 108 | parser.add_argument("--transfo_xl_dataset_file", 109 | default = "", 110 | type = str, 111 | help = "An optional dataset file to be converted in a vocabulary.") 112 | args = parser.parse_args() 113 | convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path, 114 | args.transfo_xl_config_file, 115 | args.pytorch_dump_folder_path, 116 | args.transfo_xl_dataset_file) 117 | -------------------------------------------------------------------------------- /code/generation/model/optim.py: -------------------------------------------------------------------------------- 1 | # transformer_chatbot 2 | # Copyright (C) 2018 Golovanov, Tselousov 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 16 | 17 | import math 18 | import torch 19 | 20 | 21 | class Adam(torch.optim.Optimizer): 22 | """Implements Adam algorithm. 23 | This implementation is modified from torch.optim.Adam based on: 24 | `Fixed Weight Decay Regularization in Adam` 25 | (see https://arxiv.org/abs/1711.05101) 26 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 27 | Arguments: 28 | params (iterable): iterable of parameters to optimize or dicts defining 29 | parameter groups 30 | lr (float, optional): learning rate (default: 1e-3) 31 | betas (Tuple[float, float], optional): coefficients used for computing 32 | running averages of gradient and its square (default: (0.9, 0.999)) 33 | eps (float, optional): term added to the denominator to improve 34 | numerical stability (default: 1e-8) 35 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 36 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 37 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 38 | .. _Adam\: A Method for Stochastic Optimization: 39 | https://arxiv.org/abs/1412.6980 40 | .. _On the Convergence of Adam and Beyond: 41 | https://openreview.net/forum?id=ryQu7f-RZ 42 | """ 43 | 44 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False): 45 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad) 46 | super(Adam, self).__init__(params, defaults) 47 | 48 | def step(self, closure=None): 49 | """Performs a single optimization step. 50 | Arguments: 51 | closure (callable, optional): A closure that reevaluates the model 52 | and returns the loss. 
53 | """ 54 | loss = None 55 | if closure is not None: 56 | loss = closure() 57 | 58 | for group in self.param_groups: 59 | for p in group['params']: 60 | if p.grad is None: 61 | continue 62 | grad = p.grad.data 63 | if grad.is_sparse: 64 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 65 | amsgrad = group['amsgrad'] 66 | 67 | state = self.state[p] 68 | 69 | # State initialization 70 | if len(state) == 0: 71 | state['step'] = 0 72 | # Exponential moving average of gradient values 73 | state['exp_avg'] = torch.zeros_like(p.data) 74 | # Exponential moving average of squared gradient values 75 | state['exp_avg_sq'] = torch.zeros_like(p.data) 76 | if amsgrad: 77 | # Maintains max of all exp. moving avg. of sq. grad. values 78 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 79 | 80 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 81 | if amsgrad: 82 | max_exp_avg_sq = state['max_exp_avg_sq'] 83 | beta1, beta2 = group['betas'] 84 | 85 | state['step'] += 1 86 | 87 | # Decay the first and second moment running average coefficient 88 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 89 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 90 | if amsgrad: 91 | # Maintains the maximum of all 2nd moment running avg. till now 92 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 93 | # Use the max. for normalizing running avg. of gradient 94 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 95 | else: 96 | denom = exp_avg_sq.sqrt().add_(group['eps']) 97 | 98 | bias_correction1 = 1 - beta1 ** state['step'] 99 | bias_correction2 = 1 - beta2 ** state['step'] 100 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 101 | 102 | if group['weight_decay'] != 0: 103 | p.data.add_(-group['weight_decay'] * group['lr'], p.data) 104 | 105 | p.data.addcdiv_(-step_size, exp_avg, denom) 106 | 107 | return loss 108 | 109 | 110 | class NoamOpt: 111 | def __init__(self, embeddings_size, factor, warmup, optimizer): 112 | self.embeddings_size = embeddings_size 113 | self.factor = factor 114 | self.warmup = warmup 115 | self.optimizer = optimizer 116 | 117 | self._step = 1 118 | 119 | def state_dict(self): 120 | return {'step': self._step, 121 | 'optimizer': self.optimizer.state_dict()} 122 | 123 | def load_state_dict(self, state_dict): 124 | self._step = state_dict['step'] 125 | self.optimizer.load_state_dict(state_dict['optimizer']) 126 | 127 | def zero_grad(self): 128 | return self.optimizer.zero_grad() 129 | 130 | @property 131 | def param_groups(self): 132 | return self.optimizer.param_groups 133 | 134 | def step(self): 135 | self._step += 1 136 | rate = self.rate() 137 | for p in self.optimizer.param_groups: 138 | p['lr'] = rate 139 | self.optimizer.step() 140 | 141 | def rate(self, step=None): 142 | if step is None: 143 | step = self._step 144 | 145 | return self.factor * (self.embeddings_size ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5))) -------------------------------------------------------------------------------- /code/retrieval/pytorch_pretrained_bert/optimization_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for OpenAI GPT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | import logging 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | def warmup_cosine(x, warmup=0.002): 27 | if x < warmup: 28 | return x/warmup 29 | return 0.5 * (1.0 + torch.cos(math.pi * x)) 30 | 31 | def warmup_constant(x, warmup=0.002): 32 | """ Linearly increases learning rate over `warmup`*`t_total` (as provided to OpenAIAdam) training steps. 33 | Learning rate is 1. afterwards. """ 34 | if x < warmup: 35 | return x/warmup 36 | return 1.0 37 | 38 | def warmup_linear(x, warmup=0.002): 39 | """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to OpenAIAdam) training step. 40 | After `t_total`-th training step, learning rate is zero. """ 41 | if x < warmup: 42 | return x/warmup 43 | return max((x-1.)/(warmup-1.), 0) 44 | 45 | SCHEDULES = { 46 | 'warmup_cosine':warmup_cosine, 47 | 'warmup_constant':warmup_constant, 48 | 'warmup_linear':warmup_linear, 49 | } 50 | 51 | 52 | class OpenAIAdam(Optimizer): 53 | """Implements Open AI version of Adam algorithm with weight decay fix. 54 | """ 55 | def __init__(self, params, lr=required, schedule='warmup_linear', warmup=-1, t_total=-1, 56 | b1=0.9, b2=0.999, e=1e-8, weight_decay=0, 57 | vector_l2=False, max_grad_norm=-1, **kwargs): 58 | if lr is not required and lr < 0.0: 59 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 60 | if schedule not in SCHEDULES: 61 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 62 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 63 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 64 | if not 0.0 <= b1 < 1.0: 65 | raise ValueError("Invalid b1 parameter: {}".format(b1)) 66 | if not 0.0 <= b2 < 1.0: 67 | raise ValueError("Invalid b2 parameter: {}".format(b2)) 68 | if not e >= 0.0: 69 | raise ValueError("Invalid epsilon value: {}".format(e)) 70 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 71 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2, 72 | max_grad_norm=max_grad_norm) 73 | super(OpenAIAdam, self).__init__(params, defaults) 74 | 75 | def get_lr(self): 76 | lr = [] 77 | for group in self.param_groups: 78 | for p in group['params']: 79 | state = self.state[p] 80 | if len(state) == 0: 81 | return [0] 82 | if group['t_total'] != -1: 83 | schedule_fct = SCHEDULES[group['schedule']] 84 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 85 | else: 86 | lr_scheduled = group['lr'] 87 | lr.append(lr_scheduled) 88 | return lr 89 | 90 | def step(self, closure=None): 91 | """Performs a single optimization step. 92 | 93 | Arguments: 94 | closure (callable, optional): A closure that reevaluates the model 95 | and returns the loss. 
96 | """ 97 | loss = None 98 | if closure is not None: 99 | loss = closure() 100 | 101 | warned_for_t_total = False 102 | 103 | for group in self.param_groups: 104 | for p in group['params']: 105 | if p.grad is None: 106 | continue 107 | grad = p.grad.data 108 | if grad.is_sparse: 109 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 110 | 111 | state = self.state[p] 112 | 113 | # State initialization 114 | if len(state) == 0: 115 | state['step'] = 0 116 | # Exponential moving average of gradient values 117 | state['exp_avg'] = torch.zeros_like(p.data) 118 | # Exponential moving average of squared gradient values 119 | state['exp_avg_sq'] = torch.zeros_like(p.data) 120 | 121 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 122 | beta1, beta2 = group['b1'], group['b2'] 123 | 124 | state['step'] += 1 125 | 126 | # Add grad clipping 127 | if group['max_grad_norm'] > 0: 128 | clip_grad_norm_(p, group['max_grad_norm']) 129 | 130 | # Decay the first and second moment running average coefficient 131 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 132 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 133 | denom = exp_avg_sq.sqrt().add_(group['e']) 134 | 135 | bias_correction1 = 1 - beta1 ** state['step'] 136 | bias_correction2 = 1 - beta2 ** state['step'] 137 | 138 | if group['t_total'] != -1: 139 | schedule_fct = SCHEDULES[group['schedule']] 140 | progress = state['step']/group['t_total'] 141 | lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup']) 142 | # warning for exceeding t_total (only active with warmup_linear 143 | if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total: 144 | logger.warning( 145 | "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. " 146 | "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__)) 147 | warned_for_t_total = True 148 | # end warning 149 | else: 150 | lr_scheduled = group['lr'] 151 | 152 | step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 153 | 154 | p.data.addcdiv_(-step_size, exp_avg, denom) 155 | 156 | # Add weight decay at the end (fixed version) 157 | if (len(p.size()) > 1 or group['vector_l2']) and group['weight_decay'] > 0: 158 | p.data.add_(-lr_scheduled * group['weight_decay'], p.data) 159 | 160 | return loss 161 | -------------------------------------------------------------------------------- /code/generation/model/utils.py: -------------------------------------------------------------------------------- 1 | # transformer_chatbot 2 | # Copyright (C) 2018 Golovanov, Tselousov 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 
16 | 17 | import re 18 | import os 19 | import json 20 | import random 21 | from collections import namedtuple, Counter 22 | 23 | import torch 24 | import numpy as np 25 | from scipy.interpolate import RectBivariateSpline 26 | from torch.utils.checkpoint import checkpoint 27 | 28 | 29 | def set_seed(seed): 30 | torch.manual_seed(seed) 31 | torch.cuda.manual_seed(seed) 32 | random.seed(seed) 33 | 34 | 35 | def pad_sequence(sequences, batch_first=False, padding_value=0): 36 | # assuming trailing dimensions and type of all the Tensors 37 | # in sequences are same and fetching those from sequences[0] 38 | max_size = sequences[0].size() 39 | trailing_dims = max_size[1:] 40 | max_len = max([s.size(0) for s in sequences]) 41 | if batch_first: 42 | out_dims = (len(sequences), max_len) + trailing_dims 43 | else: 44 | out_dims = (max_len, len(sequences)) + trailing_dims 45 | 46 | out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value) 47 | for i, tensor in enumerate(sequences): 48 | length = tensor.size(0) 49 | # use index notation to prevent duplicate references to the tensor 50 | if batch_first: 51 | out_tensor[i, :length, ...] = tensor 52 | else: 53 | out_tensor[:length, i, ...] = tensor 54 | 55 | return out_tensor 56 | 57 | 58 | def checkpoint_sequential(functions, segments, *inputs): 59 | def run_function(start, end, functions): 60 | def forward(*inputs): 61 | for j in range(start, end + 1): 62 | inputs = functions[j](*inputs) 63 | return inputs 64 | return forward 65 | 66 | if isinstance(functions, torch.nn.Sequential): 67 | functions = list(functions.children()) 68 | 69 | segment_size = len(functions) // segments 70 | # the last chunk has to be non-volatile 71 | end = -1 72 | for start in range(0, segment_size * (segments - 1), segment_size): 73 | end = start + segment_size - 1 74 | inputs = checkpoint(run_function(start, end, functions), *inputs) 75 | if not isinstance(inputs, tuple): 76 | inputs = (inputs,) 77 | return run_function(end + 1, len(functions) - 1, functions)(*inputs) 78 | 79 | 80 | def f1_score(predictions, targets, average=True): 81 | def f1_score_items(pred_items, gold_items): 82 | common = Counter(gold_items) & Counter(pred_items) 83 | num_same = sum(common.values()) 84 | 85 | if num_same == 0: 86 | return 0 87 | 88 | precision = num_same / len(pred_items) 89 | recall = num_same / len(gold_items) 90 | f1 = (2 * precision * recall) / (precision + recall) 91 | 92 | return f1 93 | 94 | scores = [f1_score_items(p, t) for p, t in zip(predictions, targets)] 95 | 96 | if average: 97 | return sum(scores) / len(scores) 98 | 99 | return scores 100 | 101 | 102 | def openai_transformer_config(): 103 | class dotdict(dict): 104 | __getattr__ = dict.get 105 | __setattr__ = dict.__setitem__ 106 | __delattr__ = dict.__delitem__ 107 | 108 | cfg = dotdict({'n_layers': 12, 'n_embeddings': 40477, 'n_pos_embeddings': 512, 109 | 'embeddings_size': 768, 'n_heads': 12, 'dropout': 0.1, 110 | 'embed_dropout': 0.1, 'attn_dropout': 0.1, 'ff_dropout': 0.1}) 111 | 112 | return cfg 113 | 114 | 115 | def load_openai_weights_chinese(model, directory): 116 | openai_model = torch.load(directory) 117 | openai_model.pop('decoder.pre_softmax.weight') 118 | b = list(openai_model.keys()) 119 | for i in b: 120 | openai_model['decoder.' 
+ i] = openai_model.pop(i) 121 | model.load_state_dict(openai_model) 122 | 123 | 124 | def load_openai_weights(model, directory, n_special_tokens=0): 125 | # TODO: add check of shapes 126 | 127 | parameters_names_path = os.path.join(directory, 'parameters_names.json') 128 | parameters_shapes_path = os.path.join(directory, 'parameters_shapes.json') 129 | parameters_weights_paths = [os.path.join(directory, 'params_{}.npy'.format(n)) for n in range(10)] 130 | 131 | with open(parameters_names_path, 'r') as parameters_names_file: 132 | parameters_names = json.load(parameters_names_file) 133 | 134 | with open(parameters_shapes_path, 'r') as parameters_shapes_file: 135 | parameters_shapes = json.load(parameters_shapes_file) 136 | 137 | parameters_weights = [np.load(path) for path in parameters_weights_paths] 138 | parameters_offsets = np.cumsum([np.prod(shape) for shape in parameters_shapes]) 139 | parameters_weights = np.split(np.concatenate(parameters_weights, 0), parameters_offsets)[:-1] 140 | parameters_weights = [p.reshape(s) for p, s in zip(parameters_weights, parameters_shapes)] 141 | 142 | parameters_weights[1] = parameters_weights[1][1:] # skip 0 - 143 | 144 | 145 | if model.pos_embeddings.num_embeddings - 1 > parameters_weights[0].shape[0]: 146 | xx = np.linspace(0, parameters_weights[0].shape[0], model.pos_embeddings.num_embeddings - 1) 147 | new_kernel = RectBivariateSpline(np.arange(parameters_weights[0].shape[0]), 148 | np.arange(parameters_weights[0].shape[1]), 149 | parameters_weights[0]) 150 | parameters_weights[0] = new_kernel(xx, np.arange(parameters_weights[0].shape[1])) 151 | 152 | parameters_weights[0] = parameters_weights[0][:model.pos_embeddings.num_embeddings - 1] 153 | parameters_weights[1] = parameters_weights[1][:model.embeddings.num_embeddings - n_special_tokens] 154 | 155 | model.pos_embeddings.weight.data[1:] = torch.from_numpy(parameters_weights[0]) 156 | model.embeddings.weight.data[n_special_tokens:] = torch.from_numpy(parameters_weights[1]) 157 | 158 | 159 | parameters_weights = parameters_weights[2:] 160 | 161 | for name, weights in zip(parameters_names, parameters_weights): 162 | name = name[6:] # skip "model/" 163 | assert name[-2:] == ':0' 164 | name = name[:-2] 165 | name = name.split('/') 166 | 167 | pointer = model 168 | for m_name in name: 169 | if re.fullmatch(r'[A-Za-z]+\d+', m_name): 170 | l = re.split(r'(\d+)', m_name) 171 | else: 172 | l = [m_name] 173 | 174 | pointer = getattr(pointer, l[0]) 175 | 176 | if len(l) >= 2: 177 | num = int(l[1]) 178 | pointer = pointer[num] 179 | 180 | if len(weights.shape) == 3: # conv1d to linear 181 | weights = weights[0].transpose((1, 0)) 182 | 183 | pointer.data[...] = torch.from_numpy(weights) 184 | -------------------------------------------------------------------------------- /code/generation/model/dataset.py: -------------------------------------------------------------------------------- 1 | # transformer_chatbot 2 | # Copyright (C) 2018 Golovanov, Tselousov 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 
13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 16 | 17 | from torch.utils.data import Dataset 18 | import random 19 | 20 | class S2sDataset_dialog(Dataset): 21 | def __init__(self, paths, vocab, max_lengths=2048): 22 | if isinstance(paths, str): 23 | print('path is str') 24 | paths = [paths] 25 | 26 | self.vocab = vocab 27 | self.max_lengths = max_lengths 28 | self.data = S2sDataset_dialog.make_dataset(paths, vocab, max_lengths) 29 | print(len(self.data)) 30 | 31 | @staticmethod 32 | def make_dataset(paths, vocab, max_lengths): 33 | dataset = [] 34 | for path in paths: 35 | with open(path, 'r', encoding='utf8') as fr: 36 | lines = fr.readlines() 37 | for line in lines: 38 | context, response = line.strip('\n').split('\t') 39 | context = vocab.string2ids(' '.join(context)) 40 | response = vocab.string2ids(' '.join(response)) 41 | dataset.append((context, response)) 42 | random.shuffle(dataset) 43 | return dataset 44 | 45 | def __len__(self): 46 | return len(self.data) 47 | 48 | def __getitem__(self, idx): 49 | context, response = self.data[idx] 50 | context = [self.vocab.bos_id] + context + [self.vocab.eos_id] 51 | response = [self.vocab.bos_id] + response + [self.vocab.eos_id] 52 | context = context[-40:] 53 | response = response[:40] 54 | return context, response 55 | 56 | class S2sDataset_dialog_overlap(Dataset): 57 | def __init__(self, paths, vocab, max_lengths=2048): 58 | if isinstance(paths, str): 59 | print('path is str') 60 | paths = [paths] 61 | 62 | self.vocab = vocab 63 | self.max_lengths = max_lengths 64 | self.data = S2sDataset_dialog_overlap.make_dataset(paths, vocab, max_lengths) 65 | print(len(self.data)) 66 | print(self.data[0]) 67 | 68 | @staticmethod 69 | def make_dataset(paths, vocab, max_lengths): 70 | dataset = [] 71 | for path in paths: 72 | with open(path, 'r', encoding='utf8') as fr: 73 | lines = fr.readlines() 74 | for line in lines: 75 | context, response, overlap = line.strip('\n').split('\t') 76 | context = vocab.string2ids(' '.join(context)) 77 | response = vocab.string2ids(' '.join(response)) 78 | overlap = int(overlap) 79 | if overlap < 2: 80 | overlap_sym = vocab.token2id[''] 81 | elif overlap < 4: 82 | overlap_sym = vocab.token2id[''] 83 | elif overlap < 6: 84 | overlap_sym = vocab.token2id[''] 85 | else: 86 | overlap_sym = vocab.token2id[''] 87 | dataset.append((context, response, overlap_sym)) 88 | random.shuffle(dataset) 89 | return dataset 90 | 91 | def __len__(self): 92 | return len(self.data) 93 | 94 | def __getitem__(self, idx): 95 | context, response, overlap = self.data[idx] 96 | context = [self.vocab.bos_id] + context[-37:] + [overlap] + [self.vocab.eos_id] 97 | response = [self.vocab.bos_id] + response[-38:] + [self.vocab.eos_id] 98 | #context = context[-40:] 99 | #response = response[:40] 100 | return context, response 101 | 102 | class S2sDataset_poem(Dataset): 103 | def __init__(self, paths, vocab, max_lengths=2048): 104 | if isinstance(paths, str): 105 | paths = [paths] 106 | 107 | self.vocab = vocab 108 | self.max_lengths = max_lengths 109 | self.data = S2sDataset_poem.make_dataset_wu(paths[0], vocab, max_lengths) 110 | 111 | @staticmethod 112 | def make_dataset_xiandai(paths, vocab, max_lengths): 113 | dataset = [] 114 | with open(paths, 'r', encoding='utf8') as fr: 115 | lines = fr.readlines() 116 | for line in lines: 117 | context, response = line.strip('\n').split('\t') 118 | context = vocab.string2ids(context) 119 | response = 
vocab.string2ids(response) 120 | dataset.append((context[:18], response[:118])) 121 | return dataset 122 | 123 | @staticmethod 124 | def make_dataset_wu(paths, vocab, max_lengths): 125 | dataset = [] 126 | with open(paths, 'r', encoding='utf8') as fr: 127 | lines = fr.readlines() 128 | for line in lines: 129 | target, source = line.strip('\n').split('\t') 130 | source = '

'.join([' '.join(i) for i in source.split()]) 131 | target = ' '.join(target) 132 | context = vocab.string2ids(source) 133 | response = vocab.string2ids(target) 134 | dataset.append((context[:18], response[:30])) 135 | return dataset 136 | 137 | def __len__(self): 138 | return len(self.data) 139 | 140 | def __getitem__(self, idx): 141 | context, response = self.data[idx] 142 | context = [self.vocab.bos_id] + context + [self.vocab.eos_id] 143 | response = [self.vocab.bos_id] + response + [self.vocab.eos_id] 144 | # context = context[:20] 145 | # response = response[:120] 146 | return context, response 147 | 148 | class S2sDataset_meme(Dataset): 149 | def __init__(self, paths, vocab, max_lengths=2048): 150 | if isinstance(paths, str): 151 | paths = [paths] 152 | 153 | self.vocab = vocab 154 | self.max_lengths = max_lengths 155 | self.data = S2sDataset_meme.make_dataset(paths[0], vocab, max_lengths) 156 | 157 | @staticmethod 158 | def make_dataset(paths, vocab, max_lengths): 159 | dataset = [] 160 | with open(paths, 'r', encoding='utf8') as fr: 161 | lines = fr.readlines() 162 | for line in lines: 163 | context, response = line.strip('\n').split('\t') 164 | context = vocab.string2ids(context) 165 | response = vocab.string2ids(response) 166 | dataset.append((context, response)) 167 | return dataset 168 | 169 | def __len__(self): 170 | return len(self.data) 171 | 172 | def __getitem__(self, idx): 173 | context, response = self.data[idx] 174 | context = [self.vocab.bos_id] + context + [self.vocab.eos_id] 175 | response = [self.vocab.bos_id] + response + [self.vocab.eos_id] 176 | context = context[:30] 177 | response = response[:30] 178 | return context, response -------------------------------------------------------------------------------- /code/generation/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import os 4 | from model.utils import load_openai_weights_chinese, set_seed, f1_score 5 | from model.transformer_model import TransformerModel 6 | from model.trainer import Trainer 7 | from model.text import myVocab 8 | from model.dataset import S2sDataset_dialog 9 | from config import get_model_config_dialog, get_trainer_config_dialog 10 | from torch.nn.parallel import DistributedDataParallel 11 | import argparse 12 | 13 | 14 | def main(): 15 | model_config = get_model_config_dialog() 16 | trainer_config = get_trainer_config_dialog() 17 | 18 | set_seed(trainer_config.seed) 19 | device = torch.device(trainer_config.device) 20 | # zrs 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--local_rank", type=int, default=-1) 23 | args = parser.parse_args() 24 | distributed = (args.local_rank != -1) 25 | if distributed: 26 | print(args.local_rank) 27 | torch.cuda.set_device(args.local_rank) 28 | device = torch.device("cuda", args.local_rank) 29 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 30 | 31 | vocab = myVocab(model_config.vocab_path) 32 | 33 | transformer = TransformerModel(n_layers=model_config.n_layers, 34 | n_embeddings=len(vocab), 35 | n_pos_embeddings=model_config.n_pos_embeddings, 36 | embeddings_size=model_config.embeddings_size, 37 | padding_idx=vocab.pad_id, 38 | n_heads=model_config.n_heads, 39 | dropout=model_config.dropout, 40 | embed_dropout=model_config.embed_dropout, 41 | attn_dropout=model_config.attn_dropout, 42 | ff_dropout=model_config.ff_dropout, 43 | bos_id=vocab.bos_id, 44 | eos_id=vocab.eos_id, 45 | max_seq_len=model_config.max_seq_len, 46 | 
beam_size=model_config.beam_size, 47 | length_penalty=model_config.length_penalty, 48 | n_segments=model_config.n_segments, 49 | annealing_topk=model_config.annealing_topk, 50 | temperature=model_config.temperature, 51 | annealing=model_config.annealing, 52 | diversity_coef=model_config.diversity_coef, 53 | diversity_groups=model_config.diversity_groups) 54 | 55 | if not trainer_config.load_last: 56 | openai_model = torch.load(trainer_config.openai_parameters_dir, map_location=device) 57 | openai_model.pop('decoder.pre_softmax.weight') 58 | b = list(openai_model.keys()) 59 | for i in b: 60 | temp = i.split('.') 61 | keep = True 62 | for j in range(model_config.n_layers, 12): 63 | if str(j) in temp: 64 | keep = False 65 | break 66 | if keep: 67 | openai_model[i.split('.', 1)[1]] = openai_model.pop(i) 68 | else: 69 | print(i) 70 | openai_model.pop(i) 71 | #openai_model[i.split('.', 1)[1]] = openai_model.pop(i) 72 | transformer.transformer_module.load_state_dict(openai_model, strict=True) 73 | # load_openai_weights_chinese(transformer.transformer_module, trainer_config.openai_parameters_dir) 74 | print('OpenAI weights chinese loaded from {}'.format(trainer_config.openai_parameters_dir)) 75 | 76 | train_dataset = S2sDataset_dialog(trainer_config.train_datasets, vocab, transformer.n_pos_embeddings - 1) 77 | test_dataset = S2sDataset_dialog(trainer_config.test_datasets, vocab, transformer.n_pos_embeddings - 1) 78 | 79 | model_trainer = Trainer(transformer, 80 | train_dataset, 81 | test_dataset, 82 | batch_size=trainer_config.batch_size, 83 | batch_split=trainer_config.batch_split, 84 | lr=trainer_config.lr, 85 | lr_warmup=trainer_config.lr_warmup, 86 | lm_weight=trainer_config.lm_weight, 87 | risk_weight=trainer_config.risk_weight, 88 | n_jobs=trainer_config.n_jobs, 89 | clip_grad=trainer_config.clip_grad, 90 | # label_smoothing=trainer_config.label_smoothing, 91 | device=device, 92 | ignore_idxs=vocab.special_tokens_ids, 93 | distributed=distributed) 94 | if distributed: 95 | model_trainer.model.transformer_module = DistributedDataParallel(model_trainer.model.transformer_module, 96 | device_ids=[args.local_rank], output_device=args.local_rank) 97 | 98 | start_epoch = 0 99 | init_epoch = 0 100 | 101 | if trainer_config.load_last: 102 | state_dict = torch.load(trainer_config.last_checkpoint_path + str(init_epoch - 1), map_location=device) 103 | model_trainer.load_state_dict(state_dict) 104 | # start_epoch = int(cop.sub('', trainer_config.last_checkpoint_path.split('/')[-1])) + 1 105 | start_epoch = init_epoch 106 | print('Weights loaded from {}'.format(trainer_config.last_checkpoint_path + str(init_epoch - 1))) 107 | 108 | 109 | # helpers ----------------------------------------------------- 110 | def save_func(epoch): 111 | dirs = '/'.join(trainer_config.last_checkpoint_path.split('/')[:-1]) 112 | if not os.path.exists(dirs): 113 | os.makedirs(dirs) 114 | torch.save(model_trainer.state_dict(), trainer_config.last_checkpoint_path) 115 | torch.save(model_trainer.state_dict(), trainer_config.last_checkpoint_path + str(epoch)) 116 | if os.path.exists(trainer_config.last_checkpoint_path + str(epoch - 100)): 117 | os.remove(trainer_config.last_checkpoint_path + str(epoch - 100)) 118 | 119 | def sample_text_func(epoch): 120 | n_samples = 5 121 | samples_idxs = random.sample(range(len(test_dataset)), n_samples) 122 | samples = [test_dataset[idx] for idx in samples_idxs] 123 | for source, target in samples: 124 | contexts = [torch.tensor([c], dtype=torch.long, device=model_trainer.device) for c in 
[source] if len(c) > 0] 125 | prediction = model_trainer.model.predict(contexts)[0] 126 | source_str = vocab.ids2string(source) 127 | target_str = vocab.ids2string(target[1:-1]) 128 | prediction_str = vocab.ids2string(prediction) 129 | print('\n') 130 | print('Source:{}'.format(source_str)) 131 | print('Target:\n\t{}'.format(target_str)) 132 | print('Prediction:\n\t{}'.format(prediction_str)) 133 | 134 | def test_func(epoch): 135 | if (epoch+1) % trainer_config.test_period == 0: 136 | metric_funcs = {'f1_score': f1_score} 137 | model_trainer.test(metric_funcs) 138 | 139 | def f1_risk(predictions, targets): 140 | scores = f1_score(predictions, targets, average=False) 141 | return [1-s for s in scores] 142 | 143 | # helpers ----------------------------------------------------- 144 | 145 | # model_trainer.model.transformer_module = nn.DataParallel(model_trainer.model.transformer_module, device_ids=[0, 1]) 146 | try: 147 | if args.local_rank in [-1, 0]: 148 | model_trainer.train(start_epoch, trainer_config.n_epochs, after_epoch_funcs=[save_func, sample_text_func, test_func], 149 | risk_func=f1_risk) 150 | else: 151 | model_trainer.train(start_epoch, trainer_config.n_epochs) 152 | except (KeyboardInterrupt, Exception, RuntimeError) as e: 153 | torch.save(model_trainer.state_dict(), trainer_config.interrupt_checkpoint_path) 154 | raise e 155 | 156 | 157 | if __name__ == '__main__': 158 | main() 159 | 160 | -------------------------------------------------------------------------------- /code/generation/model/transformer_module.py: -------------------------------------------------------------------------------- 1 | # transformer_chatbot 2 | # Copyright (C) 2018 Golovanov, Tselousov 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 
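# Core transformer blocks for the generation model: MultiheadAttention with a
# fused QKV projection, an optional future (causal) mask for self-attention, and
# key padding masks; a position-wise FeedForward layer with GELU; TransformerBlock,
# which averages attention over the decoder input and any encoder contexts and
# applies residual connections with LayerNorm; and TransformerModule, which stacks
# the blocks on top of token and learned position embeddings, optionally running
# them through checkpoint_sequential to trade compute for memory.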
16 | 17 | import math 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | from .utils import checkpoint_sequential 22 | 23 | 24 | class MultiheadAttention(nn.Module): 25 | @classmethod 26 | def _get_future_mask(cls, size, device): 27 | if not hasattr(cls, '_future_mask') or cls._future_mask.device != device or cls._future_mask.shape < size: 28 | cls._future_mask = torch.triu(torch.ones(size[0], size[1], dtype=torch.uint8, device=device), 1) 29 | 30 | mask = cls._future_mask[:size[0], :size[1]] 31 | 32 | return mask 33 | 34 | def __init__(self, n_features, n_heads, dropout): 35 | super(MultiheadAttention, self).__init__() 36 | assert n_features % n_heads == 0 37 | 38 | self.n_features = n_features 39 | self.n_heads = n_heads 40 | self.qkv_proj = nn.Linear(n_features, 3 * n_features) 41 | self.out_proj = nn.Linear(n_features, n_features) 42 | self.dropout = nn.Dropout(dropout) 43 | 44 | self._init_weights() 45 | 46 | def _init_weights(self): 47 | nn.init.normal_(self.qkv_proj.weight, std=0.02) 48 | nn.init.normal_(self.out_proj.weight, std=0.02) 49 | 50 | def _split_heads(self, x, is_key=False): 51 | x = x.view(x.shape[0], x.shape[1], self.n_heads, self.n_features // self.n_heads) 52 | x = x.permute(0, 2, 3, 1) if is_key else x.permute(0, 2, 1, 3) 53 | 54 | return x 55 | 56 | def _attn(self, q, k, v, apply_future_mask=True, padding_mask=None): 57 | w = torch.matmul(q, k) / math.sqrt(self.n_features // self.n_heads) 58 | 59 | if apply_future_mask: 60 | future_mask = MultiheadAttention._get_future_mask(w.shape[-2:], w.device).unsqueeze(0).unsqueeze(0) 61 | w.masked_fill_(future_mask, float('-inf')) 62 | 63 | if padding_mask is not None: 64 | w.masked_fill_(padding_mask.unsqueeze(1).unsqueeze(2), float('-inf')) 65 | 66 | w = F.softmax(w, dim=-1) 67 | w = self.dropout(w) 68 | 69 | if padding_mask is not None: 70 | w.masked_fill_(padding_mask.all(dim=-1).unsqueeze(1).unsqueeze(2).unsqueeze(3), 0) 71 | 72 | out = torch.matmul(w, v) 73 | 74 | return out 75 | 76 | def _merge_heads(self, x): 77 | x = x.permute(0, 2, 1, 3).contiguous() 78 | x = x.view(x.shape[0], x.shape[1], self.n_features) 79 | 80 | return x 81 | 82 | def forward(self, query, key, value, padding_mask): 83 | qkv_same = (query.data_ptr() == key.data_ptr() == value.data_ptr()) 84 | kv_same = (key.data_ptr() == value.data_ptr()) 85 | 86 | if qkv_same: 87 | query, key, value = self.qkv_proj(query).split(self.n_features, dim=-1) 88 | apply_future_mask = True # self-attention 89 | elif kv_same: 90 | q_w, q_b = self.qkv_proj.weight[:self.n_features, :], self.qkv_proj.bias[:self.n_features] 91 | query = F.linear(query, q_w, q_b) 92 | kv_w, kv_b = self.qkv_proj.weight[self.n_features:, :], self.qkv_proj.bias[self.n_features:] 93 | key, value = F.linear(key, kv_w, kv_b).split(self.n_features, dim=-1) 94 | apply_future_mask = False 95 | else: 96 | assert False 97 | 98 | query = self._split_heads(query) 99 | key = self._split_heads(key, is_key=True) 100 | value = self._split_heads(value) 101 | 102 | x = self._attn(query, key, value, apply_future_mask, padding_mask) 103 | x = self._merge_heads(x) 104 | 105 | x = self.out_proj(x) 106 | 107 | return x 108 | 109 | 110 | class FeedForward(nn.Module): 111 | @staticmethod 112 | def gelu(x): 113 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 114 | 115 | def __init__(self, in_features, middle_features, dropout): 116 | super(FeedForward, self).__init__() 117 | 118 | self.layer_1 = nn.Linear(in_features, middle_features) 119 
| self.layer_2 = nn.Linear(middle_features, in_features) 120 | self.dropout = nn.Dropout(dropout) 121 | 122 | self._init_weights() 123 | 124 | def _init_weights(self): 125 | nn.init.normal_(self.layer_1.weight, std=0.02) 126 | nn.init.normal_(self.layer_2.weight, std=0.02) 127 | 128 | def forward(self, x): 129 | x = FeedForward.gelu(self.layer_1(x)) 130 | x = self.dropout(x) 131 | x = self.layer_2(x) 132 | 133 | return x 134 | 135 | 136 | class TransformerBlock(nn.Module): 137 | def __init__(self, n_features, n_heads, dropout, attn_dropout, ff_dropout): 138 | super(TransformerBlock, self).__init__() 139 | 140 | self.attn = MultiheadAttention(n_features, n_heads, attn_dropout) 141 | self.attn_norm = nn.LayerNorm(n_features) 142 | self.ff = FeedForward(n_features, 4 * n_features, ff_dropout) 143 | self.ff_norm = nn.LayerNorm(n_features) 144 | self.dropout = nn.Dropout(dropout) 145 | 146 | def forward(self, x, padding_mask, *contexts): 147 | '''contexts = [(context1, padding_mask1), ...]''' 148 | 149 | inputs = (x, padding_mask) + contexts 150 | 151 | full_attn = 0 152 | n_attn = len(inputs) // 2 153 | for i in range(0, len(inputs), 2): 154 | c, m = inputs[i], inputs[i+1].byte() 155 | a = self.attn(x, c, c, m) 156 | full_attn += (a / n_attn) 157 | 158 | full_attn = self.dropout(full_attn) 159 | x = self.attn_norm(x + full_attn) 160 | 161 | f = self.ff(x) 162 | f = self.dropout(f) 163 | x = self.ff_norm(x + f) 164 | 165 | return (x, padding_mask) + contexts 166 | 167 | 168 | class TransformerModule(nn.Module): 169 | def __init__(self, n_layers, n_embeddings, n_pos_embeddings, embeddings_size, 170 | padding_idx, n_heads, dropout, embed_dropout, attn_dropout, ff_dropout, 171 | n_segments=None): 172 | super(TransformerModule, self).__init__() 173 | 174 | self.embeddings = nn.Embedding(n_embeddings, embeddings_size, padding_idx=padding_idx) 175 | self.pos_embeddings = nn.Embedding(n_pos_embeddings + 1, embeddings_size, padding_idx=0) 176 | self.embed_dropout = nn.Dropout(embed_dropout) 177 | self.layers = nn.ModuleList([TransformerBlock(embeddings_size, n_heads, dropout, attn_dropout, ff_dropout) for _ in range(n_layers)]) 178 | self.n_segments = n_segments 179 | 180 | self._init_weights() 181 | 182 | def _init_weights(self): 183 | nn.init.normal_(self.embeddings.weight, std=0.02) 184 | nn.init.normal_(self.pos_embeddings.weight, std=0.02) 185 | 186 | def forward(self, x, enc_contexts=[]): 187 | padding_mask = x.eq(self.embeddings.padding_idx) 188 | 189 | positions = torch.cumsum(~padding_mask, dim=-1, dtype=torch.long) 190 | positions.masked_fill_(padding_mask, self.pos_embeddings.padding_idx) 191 | 192 | x = self.embeddings(x) * math.sqrt(self.embeddings.embedding_dim) + self.pos_embeddings(positions) 193 | x = self.embed_dropout(x) 194 | 195 | enc_contexts = sum(enc_contexts, ()) 196 | 197 | if self.n_segments is not None: 198 | padding_mask = padding_mask.float() # fucking checkpoint_sequential 199 | padding_mask.requires_grad_() # fucking checkpoint_sequential 200 | out = checkpoint_sequential(self.layers, self.n_segments, x, padding_mask, *enc_contexts) 201 | x = out[0] 202 | else: 203 | for layer in self.layers: 204 | out = layer(x, padding_mask, *enc_contexts) 205 | x = out[0] 206 | 207 | return x, padding_mask 208 | -------------------------------------------------------------------------------- /code/retrieval/pytorch_pretrained_bert/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset 
cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 5 | """ 6 | from __future__ import (absolute_import, division, print_function, unicode_literals) 7 | 8 | import json 9 | import logging 10 | import os 11 | import shutil 12 | import tempfile 13 | from functools import wraps 14 | from hashlib import sha256 15 | import sys 16 | from io import open 17 | 18 | import boto3 19 | import requests 20 | from botocore.exceptions import ClientError 21 | from tqdm import tqdm 22 | 23 | try: 24 | from urllib.parse import urlparse 25 | except ImportError: 26 | from urlparse import urlparse 27 | 28 | try: 29 | from pathlib import Path 30 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 31 | Path.home() / '.pytorch_pretrained_bert')) 32 | except (AttributeError, ImportError): 33 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 34 | os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) 35 | 36 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 37 | 38 | 39 | def url_to_filename(url, etag=None): 40 | """ 41 | Convert `url` into a hashed filename in a repeatable way. 42 | If `etag` is specified, append its hash to the url's, delimited 43 | by a period. 44 | """ 45 | url_bytes = url.encode('utf-8') 46 | url_hash = sha256(url_bytes) 47 | filename = url_hash.hexdigest() 48 | 49 | if etag: 50 | etag_bytes = etag.encode('utf-8') 51 | etag_hash = sha256(etag_bytes) 52 | filename += '.' + etag_hash.hexdigest() 53 | 54 | return filename 55 | 56 | 57 | def filename_to_url(filename, cache_dir=None): 58 | """ 59 | Return the url and etag (which may be ``None``) stored for `filename`. 60 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 61 | """ 62 | if cache_dir is None: 63 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 64 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 65 | cache_dir = str(cache_dir) 66 | 67 | cache_path = os.path.join(cache_dir, filename) 68 | if not os.path.exists(cache_path): 69 | raise EnvironmentError("file {} not found".format(cache_path)) 70 | 71 | meta_path = cache_path + '.json' 72 | if not os.path.exists(meta_path): 73 | raise EnvironmentError("file {} not found".format(meta_path)) 74 | 75 | with open(meta_path, encoding="utf-8") as meta_file: 76 | metadata = json.load(meta_file) 77 | url = metadata['url'] 78 | etag = metadata['etag'] 79 | 80 | return url, etag 81 | 82 | 83 | def cached_path(url_or_filename, cache_dir=None): 84 | """ 85 | Given something that might be a URL (or might be a local path), 86 | determine which. If it's a URL, download the file and cache it, and 87 | return the path to the cached file. If it's already a local path, 88 | make sure the file exists and then return the path. 89 | """ 90 | if cache_dir is None: 91 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 92 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 93 | url_or_filename = str(url_or_filename) 94 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 95 | cache_dir = str(cache_dir) 96 | 97 | parsed = urlparse(url_or_filename) 98 | 99 | if parsed.scheme in ('http', 'https', 's3'): 100 | # URL, so get it from the cache (downloading if necessary) 101 | return get_from_cache(url_or_filename, cache_dir) 102 | elif os.path.exists(url_or_filename): 103 | # File, and it exists. 104 | return url_or_filename 105 | elif parsed.scheme == '': 106 | # File, but it doesn't exist. 
107 | raise EnvironmentError("file {} not found".format(url_or_filename)) 108 | else: 109 | # Something unknown 110 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 111 | 112 | 113 | def split_s3_path(url): 114 | """Split a full s3 path into the bucket name and path.""" 115 | parsed = urlparse(url) 116 | if not parsed.netloc or not parsed.path: 117 | raise ValueError("bad s3 path {}".format(url)) 118 | bucket_name = parsed.netloc 119 | s3_path = parsed.path 120 | # Remove '/' at beginning of path. 121 | if s3_path.startswith("/"): 122 | s3_path = s3_path[1:] 123 | return bucket_name, s3_path 124 | 125 | 126 | def s3_request(func): 127 | """ 128 | Wrapper function for s3 requests in order to create more helpful error 129 | messages. 130 | """ 131 | 132 | @wraps(func) 133 | def wrapper(url, *args, **kwargs): 134 | try: 135 | return func(url, *args, **kwargs) 136 | except ClientError as exc: 137 | if int(exc.response["Error"]["Code"]) == 404: 138 | raise EnvironmentError("file {} not found".format(url)) 139 | else: 140 | raise 141 | 142 | return wrapper 143 | 144 | 145 | @s3_request 146 | def s3_etag(url): 147 | """Check ETag on S3 object.""" 148 | s3_resource = boto3.resource("s3") 149 | bucket_name, s3_path = split_s3_path(url) 150 | s3_object = s3_resource.Object(bucket_name, s3_path) 151 | return s3_object.e_tag 152 | 153 | 154 | @s3_request 155 | def s3_get(url, temp_file): 156 | """Pull a file directly from S3.""" 157 | s3_resource = boto3.resource("s3") 158 | bucket_name, s3_path = split_s3_path(url) 159 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 160 | 161 | 162 | def http_get(url, temp_file): 163 | req = requests.get(url, stream=True) 164 | content_length = req.headers.get('Content-Length') 165 | total = int(content_length) if content_length is not None else None 166 | progress = tqdm(unit="B", total=total) 167 | for chunk in req.iter_content(chunk_size=1024): 168 | if chunk: # filter out keep-alive new chunks 169 | progress.update(len(chunk)) 170 | temp_file.write(chunk) 171 | progress.close() 172 | 173 | 174 | def get_from_cache(url, cache_dir=None): 175 | """ 176 | Given a URL, look for the corresponding dataset in the local cache. 177 | If it's not there, download it. Then return the path to the cached file. 178 | """ 179 | if cache_dir is None: 180 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 181 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 182 | cache_dir = str(cache_dir) 183 | 184 | if not os.path.exists(cache_dir): 185 | os.makedirs(cache_dir) 186 | 187 | # Get eTag to add to filename, if it exists. 188 | if url.startswith("s3://"): 189 | etag = s3_etag(url) 190 | else: 191 | response = requests.head(url, allow_redirects=True) 192 | if response.status_code != 200: 193 | raise IOError("HEAD request failed for url {} with status code {}" 194 | .format(url, response.status_code)) 195 | etag = response.headers.get("ETag") 196 | 197 | filename = url_to_filename(url, etag) 198 | 199 | # get cache path to put the file 200 | cache_path = os.path.join(cache_dir, filename) 201 | 202 | if not os.path.exists(cache_path): 203 | # Download to temporary file, then copy to cache dir once finished. 204 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
205 | with tempfile.NamedTemporaryFile() as temp_file: 206 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 207 | 208 | # GET file object 209 | if url.startswith("s3://"): 210 | s3_get(url, temp_file) 211 | else: 212 | http_get(url, temp_file) 213 | 214 | # we are copying the file before closing it, so flush to avoid truncation 215 | temp_file.flush() 216 | # shutil.copyfileobj() starts at the current position, so go to the start 217 | temp_file.seek(0) 218 | 219 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 220 | with open(cache_path, 'wb') as cache_file: 221 | shutil.copyfileobj(temp_file, cache_file) 222 | 223 | logger.info("creating metadata file for %s", cache_path) 224 | meta = {'url': url, 'etag': etag} 225 | meta_path = cache_path + '.json' 226 | with open(meta_path, 'w', encoding="utf-8") as meta_file: 227 | json.dump(meta, meta_file) 228 | 229 | logger.info("removing temp file %s", temp_file.name) 230 | 231 | return cache_path 232 | 233 | 234 | def read_set_from_file(filename): 235 | ''' 236 | Extract a de-duped collection (set) of text from a file. 237 | Expected file format is one item per line. 238 | ''' 239 | collection = set() 240 | with open(filename, 'r', encoding='utf-8') as file_: 241 | for line in file_: 242 | collection.add(line.rstrip()) 243 | return collection 244 | 245 | 246 | def get_file_extension(path, dot=True, lower=True): 247 | ext = os.path.splitext(path)[1] 248 | ext = ext if dot else ext[1:] 249 | return ext.lower() if lower else ext 250 | -------------------------------------------------------------------------------- /code/retrieval/pytorch_pretrained_bert/tokenization_gpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for OpenAI GPT.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import json 20 | import logging 21 | import os 22 | # import regex as re 23 | import re 24 | from io import open 25 | 26 | try: 27 | from functools import lru_cache 28 | except ImportError: 29 | # Just a dummy decorator to get the checks to run on python2 30 | # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. 
31 | def lru_cache(): 32 | return lambda func: func 33 | 34 | from .file_utils import cached_path 35 | 36 | logger = logging.getLogger(__name__) 37 | 38 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 39 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", 40 | } 41 | PRETRAINED_MERGES_ARCHIVE_MAP = { 42 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", 43 | } 44 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 45 | 'gpt2': 1024, 46 | } 47 | VOCAB_NAME = 'vocab.json' 48 | MERGES_NAME = 'merges.txt' 49 | 50 | @lru_cache() 51 | def bytes_to_unicode(): 52 | """ 53 | Returns list of utf-8 byte and a corresponding list of unicode strings. 54 | The reversible bpe codes work on unicode strings. 55 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 56 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 57 | This is a signficant percentage of your normal, say, 32K bpe vocab. 58 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 59 | And avoids mapping to whitespace/control characters the bpe code barfs on. 60 | """ 61 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 62 | cs = bs[:] 63 | n = 0 64 | for b in range(2**8): 65 | if b not in bs: 66 | bs.append(b) 67 | cs.append(2**8+n) 68 | n += 1 69 | cs = [chr(n) for n in cs] 70 | return dict(zip(bs, cs)) 71 | 72 | def get_pairs(word): 73 | """Return set of symbol pairs in a word. 74 | 75 | Word is represented as tuple of symbols (symbols being variable-length strings). 76 | """ 77 | pairs = set() 78 | prev_char = word[0] 79 | for char in word[1:]: 80 | pairs.add((prev_char, char)) 81 | prev_char = char 82 | return pairs 83 | 84 | class GPT2Tokenizer(object): 85 | """ 86 | GPT-2 BPE tokenizer. Peculiarities: 87 | - Byte-level BPE 88 | """ 89 | @classmethod 90 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 91 | """ 92 | Instantiate a PreTrainedBertModel from a pre-trained model file. 93 | Download and cache the pre-trained model file if needed. 94 | """ 95 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 96 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 97 | merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path] 98 | else: 99 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) 100 | merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) 101 | # redirect to the cache, if necessary 102 | try: 103 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 104 | resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) 105 | except EnvironmentError: 106 | logger.error( 107 | "Model name '{}' was not found in model name list ({}). 
" 108 | "We assumed '{}' was a path or url but couldn't find files {} and {} " 109 | "at this path or url.".format( 110 | pretrained_model_name_or_path, 111 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 112 | pretrained_model_name_or_path, 113 | vocab_file, merges_file)) 114 | return None 115 | if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: 116 | logger.info("loading vocabulary file {}".format(vocab_file)) 117 | logger.info("loading merges file {}".format(merges_file)) 118 | else: 119 | logger.info("loading vocabulary file {} from cache at {}".format( 120 | vocab_file, resolved_vocab_file)) 121 | logger.info("loading merges file {} from cache at {}".format( 122 | merges_file, resolved_merges_file)) 123 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 124 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 125 | # than the number of positional embeddings 126 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] 127 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 128 | # Instantiate tokenizer. 129 | tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs) 130 | return tokenizer 131 | 132 | def __init__(self, vocab_file, merges_file, errors='replace', max_len=None): 133 | self.max_len = max_len if max_len is not None else int(1e12) 134 | self.encoder = json.load(open(vocab_file)) 135 | self.decoder = {v:k for k,v in self.encoder.items()} 136 | self.errors = errors # how to handle errors in decoding 137 | self.byte_encoder = bytes_to_unicode() 138 | self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} 139 | bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] 140 | bpe_merges = [tuple(merge.split()) for merge in bpe_data] 141 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 142 | self.cache = {} 143 | 144 | # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions 145 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") 146 | 147 | def __len__(self): 148 | return len(self.encoder) 149 | 150 | def bpe(self, token): 151 | if token in self.cache: 152 | return self.cache[token] 153 | word = tuple(token) 154 | pairs = get_pairs(word) 155 | 156 | if not pairs: 157 | return token 158 | 159 | while True: 160 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 161 | if bigram not in self.bpe_ranks: 162 | break 163 | first, second = bigram 164 | new_word = [] 165 | i = 0 166 | while i < len(word): 167 | try: 168 | j = word.index(first, i) 169 | new_word.extend(word[i:j]) 170 | i = j 171 | except: 172 | new_word.extend(word[i:]) 173 | break 174 | 175 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 176 | new_word.append(first+second) 177 | i += 2 178 | else: 179 | new_word.append(word[i]) 180 | i += 1 181 | new_word = tuple(new_word) 182 | word = new_word 183 | if len(word) == 1: 184 | break 185 | else: 186 | pairs = get_pairs(word) 187 | word = ' '.join(word) 188 | self.cache[token] = word 189 | return word 190 | 191 | def encode(self, text): 192 | bpe_tokens = [] 193 | for token in re.findall(self.pat, text): 194 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 195 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 196 | if len(bpe_tokens) > self.max_len: 197 | 
logger.warning( 198 | "Token indices sequence length is longer than the specified maximum " 199 | " sequence length for this OpenAI GPT-2 model ({} > {}). Running this" 200 | " sequence through the model will result in indexing errors".format(len(bpe_tokens), self.max_len) 201 | ) 202 | return bpe_tokens 203 | 204 | def decode(self, tokens): 205 | text = ''.join([self.decoder[token] for token in tokens]) 206 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) 207 | return text 208 | -------------------------------------------------------------------------------- /code/generation/model/trainer.py: -------------------------------------------------------------------------------- 1 | # transformer_chatbot 2 | # Copyright (C) 2018 Golovanov, Tselousov 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | import random 21 | from torch.utils.data import DataLoader 22 | from tqdm import tqdm 23 | from .utils import pad_sequence 24 | from .optim import Adam, NoamOpt 25 | from .loss import LabelSmoothingLoss 26 | import json 27 | import logging 28 | import math 29 | logger = logging.getLogger('s2s') 30 | logger.setLevel(logging.INFO) 31 | fh = logging.FileHandler('s2s.log', encoding='utf-8') 32 | fh.setLevel(logging.INFO) 33 | ch = logging.StreamHandler() 34 | ch.setLevel(logging.INFO) 35 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 36 | fh.setFormatter(formatter) 37 | ch.setFormatter(formatter) 38 | logger.addHandler(fh) 39 | logger.addHandler(ch) 40 | # 记录一条日志 41 | logger.info('python logging test') 42 | 43 | 44 | class Trainer: 45 | def __init__(self, model, train_dataset, test_dataset=None, batch_size=8, 46 | batch_split=1, lm_weight=0.5, risk_weight=0, lr=6.25e-5, lr_warmup=2000, 47 | n_jobs=0, clip_grad=None, label_smoothing=0, device=torch.device('cuda'), 48 | ignore_idxs=[], distributed=False): 49 | self.model = model.to(device) 50 | self.lm_criterion = nn.CrossEntropyLoss(ignore_index=self.model.padding_idx).to(device) 51 | self.criterion = LabelSmoothingLoss(n_labels=self.model.n_embeddings, smoothing=label_smoothing, 52 | ignore_index=self.model.padding_idx).to(device) 53 | base_optimizer = Adam(self.model.parameters(), lr=lr, weight_decay=0.01) 54 | self.optimizer = NoamOpt(self.model.embeddings_size, 0.1, lr_warmup, base_optimizer) 55 | 56 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if distributed else None 57 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset) if distributed else None 58 | self.train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size//batch_split, 59 | shuffle=(not distributed), num_workers=n_jobs, collate_fn=self.collate_func_cn) 60 | if test_dataset is not None: 61 | 
self.test_dataloader = DataLoader(test_dataset, sampler=test_sampler, batch_size=batch_size//batch_split, 62 | shuffle=False, num_workers=n_jobs, collate_fn=self.collate_func_cn) 63 | 64 | self.batch_split = batch_split 65 | self.lm_weight = lm_weight 66 | self.risk_weight = risk_weight 67 | self.clip_grad = clip_grad 68 | self.device = device 69 | self.ignore_idxs = ignore_idxs 70 | 71 | def state_dict(self): 72 | return {'model': self.model.state_dict(), 73 | 'optimizer': self.optimizer.state_dict()} 74 | 75 | def load_state_dict(self, state_dict): 76 | self.model.load_state_dict(state_dict['model'], strict=True) 77 | self.optimizer.load_state_dict(state_dict['optimizer']) 78 | 79 | def collate_func_cn(self, data): 80 | x, y = zip(*data) 81 | contexts = [] 82 | 83 | if max(map(len, x)) > 0: 84 | x = [torch.tensor(d, dtype=torch.long) for d in x] 85 | x = pad_sequence(x, batch_first=True, padding_value=self.model.padding_idx) 86 | contexts.append(x) 87 | y = [torch.tensor(d, dtype=torch.long) for d in y] 88 | y = pad_sequence(y, batch_first=True, padding_value=self.model.padding_idx) 89 | return contexts, y 90 | 91 | def _eval_train(self, epoch, risk_func=None): 92 | self.model.train() 93 | 94 | tqdm_data = tqdm(self.train_dataloader, desc='Train (epoch #{})'.format(epoch)) 95 | loss = 0 96 | lm_loss = 0 97 | for i, (contexts, targets) in enumerate(tqdm_data): 98 | contexts, targets = [c.to(self.device) for c in contexts], targets.to(self.device) 99 | 100 | enc_contexts = [] 101 | 102 | # lm loss 103 | batch_lm_loss = torch.tensor(0, dtype=torch.float, device=self.device) 104 | for context in contexts: 105 | enc_context = self.model.encode(context.clone()) 106 | enc_contexts.append(enc_context) 107 | 108 | if self.lm_weight > 0: 109 | context_outputs = self.model.generate(enc_context[0]) 110 | ignore_mask = torch.stack([context == idx for idx in self.ignore_idxs], dim=-1).any(dim=-1) 111 | context.masked_fill_(ignore_mask, self.model.padding_idx) 112 | prevs, nexts = context_outputs[:, :-1, :].contiguous(), context[:, 1:].contiguous() 113 | batch_lm_loss += (self.lm_criterion(prevs.view(-1, prevs.shape[-1]), nexts.view(-1)) / len(contexts)) 114 | 115 | # s2s loss 116 | prevs, nexts = targets[:, :-1].contiguous(), targets[:, 1:].contiguous() 117 | outputs = self.model.decode(prevs, enc_contexts) 118 | outputs = F.log_softmax(outputs, dim=-1) 119 | batch_loss = self.criterion(outputs.view(-1, outputs.shape[-1]), nexts.view(-1)) 120 | 121 | # optimization 122 | full_loss = (batch_lm_loss * self.lm_weight + batch_loss) / self.batch_split 123 | full_loss.backward() 124 | 125 | if (i + 1) % self.batch_split == 0: 126 | if self.clip_grad is not None: 127 | for group in self.optimizer.param_groups: 128 | nn.utils.clip_grad_norm_(group['params'], self.clip_grad) 129 | 130 | self.optimizer.step() 131 | self.optimizer.zero_grad() 132 | 133 | lm_loss = (i * lm_loss + batch_lm_loss.item()) / (i + 1) 134 | loss = (i * loss + batch_loss.item()) / (i + 1) 135 | tqdm_data.set_postfix({'lm_loss': lm_loss, 'loss': loss, 'ppl': math.exp(loss), 'loss_step': batch_loss.item(), 136 | 'lr': self.optimizer.rate(), 'step': self.optimizer._step}) 137 | 138 | log_dict = {'epoch': epoch, 'lm_loss': lm_loss, 'loss': loss, 'ppl': math.exp(loss), 'lr': self.optimizer.rate(), 139 | 'step': self.optimizer._step} 140 | log_dict_json = json.dumps(log_dict, ensure_ascii=False) 141 | logger.info(log_dict_json) 142 | 143 | def _eval_test(self, metric_funcs={}): 144 | self.model.eval() 145 | 146 | tqdm_data = 
tqdm(self.test_dataloader, desc='Test') 147 | loss = 0 148 | lm_loss = 0 149 | metrics = {name: 0 for name in metric_funcs.keys()} 150 | for i, (contexts, targets) in enumerate(tqdm_data): 151 | contexts, targets = [c.to(self.device) for c in contexts], targets.to(self.device) 152 | enc_contexts = [] 153 | 154 | # lm loss 155 | batch_lm_loss = torch.tensor(0, dtype=torch.float, device=self.device) 156 | for context in contexts: 157 | enc_context = self.model.encode(context.clone()) 158 | enc_contexts.append(enc_context) 159 | 160 | if self.lm_weight > 0: 161 | context_outputs = self.model.generate(enc_context[0]) 162 | ignore_mask = torch.stack([context == idx for idx in self.ignore_idxs], dim=-1).any(dim=-1) 163 | context.masked_fill_(ignore_mask, self.model.padding_idx) 164 | prevs, nexts = context_outputs[:, :-1, :].contiguous(), context[:, 1:].contiguous() 165 | batch_lm_loss += (self.lm_criterion(prevs.view(-1, prevs.shape[-1]), nexts.view(-1)) / len(contexts)) 166 | 167 | # s2s loss 168 | prevs, nexts = targets[:, :-1].contiguous(), targets[:, 1:].contiguous() 169 | outputs = self.model.decode(prevs, enc_contexts) 170 | outputs = F.log_softmax(outputs, dim=-1) 171 | batch_loss = self.criterion(outputs.view(-1, outputs.shape[-1]), nexts.view(-1)) 172 | 173 | predictions = self.model.beam_search(enc_contexts) 174 | target_lens = targets.ne(self.model.padding_idx).sum(dim=-1) 175 | targets = [t[1:l-1].tolist() for t, l in zip(targets, target_lens)] 176 | 177 | lm_loss = (i * lm_loss + batch_lm_loss.item()) / (i + 1) 178 | loss = (i * loss + batch_loss.item()) / (i + 1) 179 | for name, func in metric_funcs.items(): 180 | score = func(predictions, targets) 181 | metrics[name] = (metrics[name] * i + score) / (i + 1) 182 | 183 | tqdm_data.set_postfix(dict({'lm_loss': lm_loss, 'loss': loss, 'ppl': math.exp(loss)}, **metrics)) 184 | log_dict = dict({'lm_loss': lm_loss, 'loss': loss, 'ppl': math.exp(loss)}, **metrics) 185 | log_dict_json = json.dumps(log_dict, ensure_ascii=False) 186 | logger.info(log_dict_json) 187 | 188 | def test(self, metric_funcs={}): 189 | if hasattr(self, 'test_dataloader'): 190 | self._eval_test(metric_funcs) 191 | 192 | def train(self, start_epoch, epochs, after_epoch_funcs=[], risk_func=None): 193 | for epoch in range(start_epoch, epochs): 194 | self._eval_train(epoch, risk_func) 195 | for func in after_epoch_funcs: 196 | func(epoch) 197 | -------------------------------------------------------------------------------- /code/generation/model/postprocessing.py: -------------------------------------------------------------------------------- 1 | # transformer_chatbot 2 | # Copyright (C) 2018 Golovanov, Tselousov 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 
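# Reply post-processing for generated English text: syntax_fix and detokenize use
# language_check and the Moses detokenizer to clean up surface form; ReplyChecker
# filters out replies (or individual sentences) that are too similar to recent
# history and falls back to a retrieval-based question or one of the
# STANDARD_ANSWERS; get_syn and equal_phrases build paraphrases via WordNet
# synonyms and contraction swaps.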
16 | 17 | import language_check 18 | from mosestokenizer import * 19 | from collections import deque 20 | from nltk import ngrams 21 | from nltk.corpus import wordnet 22 | import nltk 23 | import difflib 24 | import random 25 | from .retrieval import RetrievalBot 26 | import re 27 | import numpy as np 28 | from itertools import combinations 29 | 30 | 31 | SPELL_EXCEPTIONS = ['lol'] 32 | STANDARD_ANSWERS = ['do you wanna talk about something else? ', 33 | 'tell me something about yourself.', 34 | 'it is interesting. how is it outside?', 35 | 'do you like walking outside?', 36 | 'cats... do you like cats!', 37 | 'how do you spend your free time?', 38 | 'how do you usually spend your weekend?', 39 | 'i think you are interesting person. tell me something about yourself.'] 40 | 41 | 42 | def syntax_fix(text): 43 | def _i_replace(text): 44 | text = text.split() 45 | for i in range(len(text)): 46 | if text[i] == 'i': 47 | text[i] = 'I' 48 | if text[i] == 'i\'m': 49 | text[i] = 'I\'m' 50 | 51 | text = ' '.join(text) 52 | 53 | return text 54 | 55 | tool = language_check.LanguageTool('en-US') 56 | 57 | matches = tool.check(text) 58 | matches = [m for m in matches if text[m.fromx:m.tox].lower() not in SPELL_EXCEPTIONS] 59 | 60 | return _i_replace(language_check.correct(text, matches)) 61 | 62 | 63 | def detokenize(text): 64 | text = text.split(' ') 65 | text[0] = text[0].title() 66 | 67 | with MosesDetokenizer('en') as detokenize: 68 | text = detokenize(text) 69 | 70 | text = syntax_fix(text) 71 | 72 | return text 73 | 74 | 75 | class ReplyChecker: 76 | def __init__(self, max_len=10, theshold=0.8, correct_generative=True, split_into_sentences=True): 77 | self._replies = deque([], maxlen=max_len) 78 | self._theshold = theshold 79 | self._retrieval = RetrievalBot() 80 | self._info = None 81 | self._max_len = max_len 82 | 83 | self._correct_generative = correct_generative 84 | self._split_into_sentences = split_into_sentences 85 | 86 | self._reset_prob() 87 | 88 | def _reset_prob(self): 89 | self._def_prob = np.ones(len(STANDARD_ANSWERS)) / len(STANDARD_ANSWERS) 90 | 91 | def _ratio(self, seq1, seq2): 92 | # todo: only works good for same sequences 93 | return difflib.SequenceMatcher(None, seq1, seq2).ratio() 94 | 95 | def _sentence_max_coincidence_drop(self, reply): 96 | history = sum([re.split(r' *[\?\.\!][\'"\)\]]* *', r) for r in self._replies], []) 97 | 98 | split_reply = re.split(r' *[\?\.\!][\'"\)\]]* *', reply) 99 | punc = list(re.finditer(r' *[\?\.\!][\'"\)\]]* *', reply)) 100 | 101 | # ratio = 0 102 | drop = [] 103 | 104 | for i, r in enumerate(split_reply): 105 | for h in history: 106 | if h and r: 107 | ratio = self._ratio(r, h) 108 | if ratio > self._theshold: 109 | drop.append(i) 110 | 111 | drop = sorted(set(drop), reverse=True) 112 | for d in drop: 113 | split_reply.pop(d) 114 | punc.pop(d) 115 | 116 | original_text = '' 117 | 118 | for s, m in zip(split_reply, punc): 119 | original_text += s + m.group() 120 | if len(split_reply) > len(punc): 121 | original_text += split_reply[-1] 122 | 123 | return original_text.strip() 124 | 125 | def _max_coincidence(self, reply): 126 | if not self._replies: 127 | return None, reply 128 | 129 | if self._split_into_sentences: 130 | reply = self._sentence_max_coincidence_drop(reply) 131 | if not reply: 132 | return 1.0, reply 133 | 134 | mc = max(self._replies, key=lambda x: self._ratio(x, reply)) 135 | 136 | ratio = self._ratio(mc, reply) 137 | 138 | return ratio, reply 139 | 140 | def _replase_reply(self, reply, request, info): 141 | dialog = 2 * 
['None'] + [request] 142 | res = self._retrieval.generate_question(dialog, info) 143 | if res is None: 144 | if self._info is None: 145 | self._info = self._retrieval.get_reply_info(info) 146 | 147 | if not self._info: 148 | idx = np.random.choice(range(len(STANDARD_ANSWERS)), p=self._def_prob) 149 | self._def_prob[idx] = 0 150 | 151 | if np.sum(self._def_prob) == 0: 152 | self._reset_prob() 153 | else: 154 | self._def_prob /= np.sum(self._def_prob) 155 | 156 | return STANDARD_ANSWERS[idx] 157 | 158 | res = random.choice(list(self._info.keys())) 159 | del self._info[res] 160 | 161 | return res 162 | 163 | @staticmethod 164 | def _correct_repeated_sentences(text): 165 | split_text = re.split(r' *[\?\.\!][\'"\)\]]* *', text) 166 | matches = list(re.finditer(r' *[\?\.\!][\'"\)\]]* *', text)) 167 | 168 | drop = [] 169 | for i, j in combinations(range(len(split_text)), 2): 170 | if split_text[j] and split_text[j] in split_text[i]: 171 | drop.append(j) 172 | drop = set(drop) 173 | drop = sorted(drop, reverse=True) 174 | 175 | for d in drop: 176 | split_text.pop(d) 177 | matches.pop(d) 178 | 179 | original_text = '' 180 | 181 | for s, m in zip(split_text, matches): 182 | original_text += s + m.group() 183 | if len(split_text) > len(matches): 184 | original_text += split_text[-1] 185 | return original_text 186 | 187 | def check_reply(self, reply, request, info): 188 | log = [reply] 189 | log_names = ['IN: ', 'RL: ', 'RS: '] 190 | 191 | try: 192 | if self._correct_generative: 193 | reply = ReplyChecker._correct_repeated_sentences(reply) 194 | 195 | ratio, reply = self._max_coincidence(reply) 196 | log.append(reply) 197 | if ratio is not None: 198 | # ratio = self._ratio(mc, reply) 199 | 200 | if ratio > self._theshold: 201 | reply = self._replase_reply(reply, request, info) 202 | log.append(reply) 203 | 204 | except Exception as e: 205 | print('ERROR: ', e) 206 | reply = log[0] 207 | 208 | # print('[' + ' | '.join([n + str(v) for n, v in zip(log_names, log) ]) + ']') 209 | self._replies.append(reply) 210 | 211 | return reply 212 | 213 | def clean(self): 214 | self._info = None 215 | self._replies = deque([], maxlen=self._max_len) 216 | self._reset_prob() 217 | 218 | 219 | def get_syn(seq): 220 | seq = seq.replace('i ', 'I ') 221 | seq = nltk.pos_tag(nltk.word_tokenize(seq)) 222 | 223 | synonyms = {} 224 | 225 | for w, s_p in seq: 226 | if len(w) < 3: 227 | continue 228 | if s_p not in ['VBP', 'NN', 'NNS']: 229 | continue 230 | 231 | pos = wordnet.VERB if s_p == 'VBP' else wordnet.NOUN 232 | 233 | s = wordnet.synsets(w, pos=pos) 234 | for word in s: 235 | for l in word.lemma_names(): 236 | if l != w: 237 | synonyms[l.replace('_', ' ')] = w 238 | break 239 | 240 | if not synonyms: 241 | return None 242 | 243 | key = random.choice(list(synonyms.keys())) 244 | return synonyms[key], key 245 | 246 | 247 | def equal_phrases(phrases): 248 | matches = {' am ': '\'m ', 249 | ' are ': '\'re ', 250 | ' have ': '\'ve ', 251 | ' has ': '\'s ', 252 | 'do not': 'don\'t', 253 | 'does not': 'doesn\'t' 254 | } 255 | 256 | replasments = [] 257 | 258 | for ph in phrases: 259 | a = ph 260 | for o, r in matches.items(): 261 | if o in a: 262 | a = a.replace(o, r) 263 | break 264 | if r in a: 265 | a = a.replace(r, o) 266 | break 267 | 268 | if a == ph: 269 | # todo: find synonims 270 | syn = get_syn(a) 271 | if syn is None: 272 | a = a.split(' ') 273 | a[-2], a[-1] = a[-1], a[-2] 274 | a = ' '.join(a) 275 | else: 276 | a = a.replace(syn[0], syn[1]) 277 | 278 | replasments.append(a) 279 | 280 | return replasments 281 | 282 
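# Illustrative examples (not exhaustive) of how equal_phrases above rewrites the matched phrases that
# ngram_replaser below finds: contraction pairs are swapped when present, e.g. 'i am happy' -> "i'm happy"
# and "don't know" -> 'do not know'; otherwise a WordNet synonym from get_syn is substituted, and as a
# last resort the final two tokens are swapped. The synonym choice is randomized and depends on the
# installed wordnet data.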
| 283 | def ngram_replaser(info, reply, n=3): 284 | if info is None: 285 | return reply 286 | 287 | org_reply = reply 288 | 289 | info = re.split(r' *[\?\.\!][\'"\)\]]* *', info.strip().lower()) 290 | reply = re.split(r' *[\?\.\!][\'"\)\]]* *', reply.strip().lower()) 291 | 292 | info = sum([list(ngrams(i.split(), n=n)) for i in info if i], []) 293 | reply = sum([list(ngrams(r.split(), n=n)) for r in reply if r], []) 294 | 295 | phrases = [] 296 | 297 | for i in info: 298 | for r in reply: 299 | if i == r: 300 | phrases.append(' '.join(r)) 301 | 302 | replasments = equal_phrases(phrases) 303 | 304 | for o, r in zip(phrases, replasments): 305 | org_reply = org_reply.replace(o, r) 306 | 307 | return org_reply 308 | -------------------------------------------------------------------------------- /code/generation/train_kd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import os 4 | from model.utils import load_openai_weights_chinese, set_seed, f1_score 5 | from model.transformer_model import TransformerModel 6 | from model.trainer_kd import Trainer 7 | from model.text import myVocab 8 | from model.dataset import S2sDataset_dialog 9 | from config import get_model_config_dialog, get_trainer_config_dialog 10 | from torch.nn.parallel import DistributedDataParallel 11 | import argparse 12 | 13 | 14 | def main(): 15 | model_config = get_model_config_dialog() 16 | trainer_config = get_trainer_config_dialog() 17 | 18 | set_seed(trainer_config.seed) 19 | device = torch.device(trainer_config.device) 20 | # zrs 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--local_rank", type=int, default=0) 23 | args = parser.parse_args() 24 | distributed = (args.local_rank != -1) 25 | if distributed: 26 | print(args.local_rank) 27 | torch.cuda.set_device(args.local_rank) 28 | device = torch.device("cuda", args.local_rank) 29 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 30 | 31 | vocab = myVocab(model_config.vocab_path) 32 | 33 | transformer = TransformerModel(n_layers=model_config.n_layers, 34 | n_embeddings=len(vocab), 35 | n_pos_embeddings=model_config.n_pos_embeddings, 36 | embeddings_size=model_config.embeddings_size, 37 | padding_idx=vocab.pad_id, 38 | n_heads=model_config.n_heads, 39 | dropout=model_config.dropout, 40 | embed_dropout=model_config.embed_dropout, 41 | attn_dropout=model_config.attn_dropout, 42 | ff_dropout=model_config.ff_dropout, 43 | bos_id=vocab.bos_id, 44 | eos_id=vocab.eos_id, 45 | max_seq_len=model_config.max_seq_len, 46 | beam_size=model_config.beam_size, 47 | length_penalty=model_config.length_penalty, 48 | n_segments=model_config.n_segments, 49 | annealing_topk=model_config.annealing_topk, 50 | temperature=model_config.temperature, 51 | annealing=model_config.annealing, 52 | diversity_coef=model_config.diversity_coef, 53 | diversity_groups=model_config.diversity_groups) 54 | 55 | transformer_T = TransformerModel(n_layers=model_config.n_layers, 56 | n_embeddings=len(vocab), 57 | n_pos_embeddings=model_config.n_pos_embeddings, 58 | embeddings_size=model_config.embeddings_size, 59 | padding_idx=vocab.pad_id, 60 | n_heads=model_config.n_heads, 61 | dropout=model_config.dropout, 62 | embed_dropout=model_config.embed_dropout, 63 | attn_dropout=model_config.attn_dropout, 64 | ff_dropout=model_config.ff_dropout, 65 | bos_id=vocab.bos_id, 66 | eos_id=vocab.eos_id, 67 | max_seq_len=model_config.max_seq_len, 68 | beam_size=model_config.beam_size, 69 | 
length_penalty=model_config.length_penalty, 70 | n_segments=model_config.n_segments, 71 | annealing_topk=model_config.annealing_topk, 72 | temperature=model_config.temperature, 73 | annealing=model_config.annealing, 74 | diversity_coef=model_config.diversity_coef, 75 | diversity_groups=model_config.diversity_groups) 76 | 77 | state_dict = torch.load(trainer_config.teacher_checkpoint_path, map_location=device) 78 | print(state_dict['model'].keys()) 79 | temp = dict(state_dict['model']) 80 | keys = list(temp.keys()) 81 | for key in keys: 82 | # new_key = '.'.join([i for i in key.split('.') if i != 'module']) 83 | new_key = key.replace('.module', '') 84 | temp[new_key] = temp.pop(key) 85 | transformer_T.load_state_dict(temp, strict=True) 86 | print('load teacher model ok') 87 | 88 | if not trainer_config.load_last: 89 | openai_model = torch.load(trainer_config.openai_parameters_dir, map_location=device) 90 | openai_model.pop('decoder.pre_softmax.weight') 91 | b = list(openai_model.keys()) 92 | for i in b: 93 | temp = i.split('.') 94 | keep = True 95 | for j in range(model_config.n_layers, 12): 96 | if str(j) in temp: 97 | keep = False 98 | break 99 | if keep: 100 | openai_model[i.split('.', 1)[1]] = openai_model.pop(i) 101 | else: 102 | print(i) 103 | openai_model.pop(i) 104 | #openai_model[i.split('.', 1)[1]] = openai_model.pop(i) 105 | transformer.transformer_module.load_state_dict(openai_model, strict=True) 106 | # load_openai_weights_chinese(transformer.transformer_module, trainer_config.openai_parameters_dir) 107 | print('OpenAI weights chinese loaded from {}'.format(trainer_config.openai_parameters_dir)) 108 | 109 | train_dataset = S2sDataset_dialog(trainer_config.train_datasets, vocab, transformer.n_pos_embeddings - 1) 110 | test_dataset = S2sDataset_dialog(trainer_config.test_datasets, vocab, transformer.n_pos_embeddings - 1) 111 | 112 | model_trainer = Trainer(transformer, 113 | transformer_T, 114 | train_dataset, 115 | test_dataset, 116 | batch_size=trainer_config.batch_size, 117 | batch_split=trainer_config.batch_split, 118 | lr=trainer_config.lr, 119 | lr_warmup=trainer_config.lr_warmup, 120 | lm_weight=trainer_config.lm_weight, 121 | risk_weight=trainer_config.risk_weight, 122 | n_jobs=trainer_config.n_jobs, 123 | clip_grad=trainer_config.clip_grad, 124 | # label_smoothing=trainer_config.label_smoothing, 125 | device=device, 126 | ignore_idxs=vocab.special_tokens_ids, 127 | distributed=distributed) 128 | if distributed: 129 | model_trainer.model.transformer_module = DistributedDataParallel(model_trainer.model.transformer_module, 130 | device_ids=[args.local_rank], output_device=args.local_rank) 131 | model_trainer.model_T.transformer_module = DistributedDataParallel(model_trainer.model_T.transformer_module, 132 | device_ids=[args.local_rank], output_device=args.local_rank) 133 | 134 | start_epoch = 0 135 | init_epoch = 35 136 | 137 | if trainer_config.load_last: 138 | state_dict = torch.load(trainer_config.last_checkpoint_path + str(init_epoch - 1), map_location=device) 139 | model_trainer.load_state_dict(state_dict) 140 | # start_epoch = int(cop.sub('', trainer_config.last_checkpoint_path.split('/')[-1])) + 1 141 | start_epoch = init_epoch 142 | print('Weights loaded from {}'.format(trainer_config.last_checkpoint_path + str(init_epoch - 1))) 143 | 144 | 145 | # helpers ----------------------------------------------------- 146 | def save_func(epoch): 147 | dirs = '/'.join(trainer_config.last_checkpoint_path.split('/')[:-1]) 148 | if not os.path.exists(dirs): 149 | 
os.makedirs(dirs) 150 | torch.save(model_trainer.state_dict(), trainer_config.last_checkpoint_path) 151 | torch.save(model_trainer.state_dict(), trainer_config.last_checkpoint_path + str(epoch)) 152 | if os.path.exists(trainer_config.last_checkpoint_path + str(epoch - 150)): 153 | os.remove(trainer_config.last_checkpoint_path + str(epoch - 150)) 154 | 155 | def sample_text_func(epoch): 156 | n_samples = 5 157 | samples_idxs = random.sample(range(len(test_dataset)), n_samples) 158 | samples = [test_dataset[idx] for idx in samples_idxs] 159 | for source, target in samples: 160 | contexts = [torch.tensor([c], dtype=torch.long, device=model_trainer.device) for c in [source] if len(c) > 0] 161 | prediction = model_trainer.model.predict(contexts)[0] 162 | source_str = vocab.ids2string(source) 163 | target_str = vocab.ids2string(target[1:-1]) 164 | prediction_str = vocab.ids2string(prediction) 165 | print('\n') 166 | print('Source:{}'.format(source_str)) 167 | print('Target:\n\t{}'.format(target_str)) 168 | print('Prediction:\n\t{}'.format(prediction_str)) 169 | 170 | def test_func(epoch): 171 | if (epoch+1) % trainer_config.test_period == 0: 172 | metric_funcs = {'f1_score': f1_score} 173 | model_trainer.test(metric_funcs) 174 | 175 | def f1_risk(predictions, targets): 176 | scores = f1_score(predictions, targets, average=False) 177 | return [1-s for s in scores] 178 | 179 | # helpers ----------------------------------------------------- 180 | 181 | # model_trainer.model.transformer_module = nn.DataParallel(model_trainer.model.transformer_module, device_ids=[0, 1]) 182 | try: 183 | if args.local_rank in [-1, 0]: 184 | model_trainer.train(start_epoch, trainer_config.n_epochs, after_epoch_funcs=[save_func, sample_text_func, test_func], 185 | risk_func=f1_risk) 186 | else: 187 | model_trainer.train(start_epoch, trainer_config.n_epochs) 188 | except (KeyboardInterrupt, Exception, RuntimeError) as e: 189 | torch.save(model_trainer.state_dict(), trainer_config.interrupt_checkpoint_path) 190 | raise e 191 | 192 | 193 | if __name__ == '__main__': 194 | main() 195 | 196 | -------------------------------------------------------------------------------- /code/generation/model/trainer_kd.py: -------------------------------------------------------------------------------- 1 | # transformer_chatbot 2 | # Copyright (C) 2018 Golovanov, Tselousov 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 
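# A minimal sketch of the distillation objective implemented by the Trainer below (names and shapes
# follow _eval_train in this file; B, T, V denote batch size, target length and vocabulary size, and
# lm_loss / lm_weight are the auxiliary language-model term and its weight):
#
#   log_p_s = F.log_softmax(model.decode(prevs, enc_contexts), dim=-1)      # student log-probs, (B, T, V)
#   p_t     = F.softmax(model_T.decode(prevs, enc_contexts_t), dim=-1)      # teacher probs,     (B, T, V)
#   ce_loss = criterion(log_p_s.view(-1, V), nexts.view(-1))                # label-smoothed NLL vs. gold tokens
#   kd_loss = kl_loss(log_p_s.view(-1, V), p_t.view(-1, V))                 # KLDivLoss(student || teacher)
#   full_loss = (lm_weight * lm_loss + ce_loss + kd_loss) / batch_split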
16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | import random 21 | from torch.utils.data import DataLoader 22 | from tqdm import tqdm 23 | from .utils import pad_sequence 24 | from .optim import Adam, NoamOpt 25 | from .loss import LabelSmoothingLoss 26 | import json 27 | import logging 28 | import math 29 | logger = logging.getLogger('s2s-dialog-fake-kd') 30 | logger.setLevel(logging.INFO) 31 | fh = logging.FileHandler('s2s-dialog-fake-kd.log', encoding='utf-8') 32 | fh.setLevel(logging.INFO) 33 | ch = logging.StreamHandler() 34 | ch.setLevel(logging.INFO) 35 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 36 | fh.setFormatter(formatter) 37 | ch.setFormatter(formatter) 38 | logger.addHandler(fh) 39 | logger.addHandler(ch) 40 | # log a test message 41 | logger.info('python logging test') 42 | 43 | 44 | class Trainer: 45 | def __init__(self, model, model_T, train_dataset, test_dataset=None, batch_size=8, 46 | batch_split=1, lm_weight=0.5, risk_weight=0, lr=6.25e-5, lr_warmup=2000, 47 | n_jobs=0, clip_grad=None, label_smoothing=0, device=torch.device('cuda'), 48 | ignore_idxs=[], distributed=False): 49 | self.model = model.to(device) 50 | self.model_T = model_T.to(device) 51 | self.model_T.requires_grad = False 52 | # self.model_T.eval() 53 | self.lm_criterion = nn.CrossEntropyLoss(ignore_index=self.model.padding_idx).to(device) 54 | self.criterion = LabelSmoothingLoss(n_labels=self.model.n_embeddings, smoothing=label_smoothing, 55 | ignore_index=self.model.padding_idx).to(device) 56 | self.kl_loss = torch.nn.KLDivLoss(reduction='batchmean').to(device) 57 | #base_optimizer = Adam(self.model.parameters(), lr=lr, weight_decay=0.01) 58 | base_optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=0.01) 59 | self.optimizer = NoamOpt(self.model.embeddings_size, 0.05, lr_warmup, base_optimizer) 60 | 61 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if distributed else None 62 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset) if distributed else None 63 | self.train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size//batch_split, 64 | shuffle=(not distributed), num_workers=n_jobs, collate_fn=self.collate_func_cn) 65 | if test_dataset is not None: 66 | self.test_dataloader = DataLoader(test_dataset, sampler=test_sampler, batch_size=batch_size//batch_split, 67 | shuffle=False, num_workers=n_jobs, collate_fn=self.collate_func_cn) 68 | 69 | self.batch_split = batch_split 70 | self.lm_weight = lm_weight 71 | self.risk_weight = risk_weight 72 | self.clip_grad = clip_grad 73 | self.device = device 74 | self.ignore_idxs = ignore_idxs 75 | 76 | def state_dict(self): 77 | return {'model': self.model.state_dict(), 78 | 'optimizer': self.optimizer.state_dict()} 79 | 80 | def load_state_dict(self, state_dict): 81 | self.model.load_state_dict(state_dict['model'], strict=True) 82 | self.optimizer.load_state_dict(state_dict['optimizer']) 83 | 84 | def collate_func_cn(self, data): 85 | x, y = zip(*data) 86 | contexts = [] 87 | 88 | if max(map(len, x)) > 0: 89 | x = [torch.tensor(d, dtype=torch.long) for d in x] 90 | x = pad_sequence(x, batch_first=True, padding_value=self.model.padding_idx) 91 | contexts.append(x) 92 | y = [torch.tensor(d, dtype=torch.long) for d in y] 93 | y = pad_sequence(y, batch_first=True, padding_value=self.model.padding_idx) 94 | return contexts, y 95 | 96 | def _eval_train(self, 
risk_func=None): 97 | self.model.train() 98 | 99 | tqdm_data = tqdm(self.train_dataloader, desc='Train (epoch #{})'.format(epoch)) 100 | loss = 0 101 | lm_loss = 0 102 | kl_loss = 0 103 | for i, (contexts, targets) in enumerate(tqdm_data): 104 | contexts, targets = [c.to(self.device) for c in contexts], targets.to(self.device) 105 | 106 | enc_contexts = [] 107 | enc_contexts_t = [] 108 | 109 | # lm loss 110 | batch_lm_loss = torch.tensor(0, dtype=torch.float, device=self.device) 111 | for context in contexts: 112 | enc_context = self.model.encode(context.clone()) 113 | enc_context_t = self.model_T.encode(context.clone()) 114 | enc_contexts.append(enc_context) 115 | enc_contexts_t.append(enc_context_t) 116 | # print('enc_context: ', enc_context) 117 | # print('enc_context_t: ', enc_context_t) 118 | 119 | if self.lm_weight > 0: 120 | context_outputs = self.model.generate(enc_context[0]) 121 | ignore_mask = torch.stack([context == idx for idx in self.ignore_idxs], dim=-1).any(dim=-1) 122 | context.masked_fill_(ignore_mask, self.model.padding_idx) 123 | prevs, nexts = context_outputs[:, :-1, :].contiguous(), context[:, 1:].contiguous() 124 | batch_lm_loss += (self.lm_criterion(prevs.view(-1, prevs.shape[-1]), nexts.view(-1)) / len(contexts)) 125 | 126 | # s2s loss 127 | prevs, nexts = targets[:, :-1].contiguous(), targets[:, 1:].contiguous() 128 | outputs = self.model.decode(prevs, enc_contexts) 129 | outputs = F.log_softmax(outputs, dim=-1) 130 | batch_loss = self.criterion(outputs.view(-1, outputs.shape[-1]), nexts.view(-1)) 131 | outputs_t = self.model_T.decode(prevs, enc_contexts_t) 132 | outputs_t = F.softmax(outputs_t, dim=-1) 133 | batch_kl_loss = self.kl_loss(outputs.view(-1, outputs.shape[-1]), outputs_t.view(-1, outputs_t.shape[-1])) 134 | ''' 135 | # teacher model 136 | prevs_t = targets[:, :-1].contiguous() 137 | enc_contexts_t = [] 138 | for context in contexts: 139 | temp = context.clone() 140 | # print('temp: ', temp) 141 | enc_context_t = self.model_T.encode(temp) 142 | enc_contexts_t.append(enc_context_t) 143 | print('enc_context_t: ', enc_context_t) 144 | outputs_t = self.model_T.decode(prevs_t, enc_contexts_t) 145 | # print('outputs_t: ', outputs_t) 146 | outputs_t = F.softmax(outputs_t, dim=-1) 147 | batch_kl_loss = self.kl_loss(outputs.view(-1, outputs.shape[-1]), outputs_t.view(-1, outputs_t.shape[-1]).data) 148 | print(outputs.shape, outputs_t.shape, batch_kl_loss) 149 | ''' 150 | 151 | # optimization 152 | full_loss = (batch_lm_loss * self.lm_weight + batch_loss + batch_kl_loss) / self.batch_split 153 | full_loss.backward() 154 | 155 | if (i + 1) % self.batch_split == 0: 156 | if self.clip_grad is not None: 157 | for group in self.optimizer.param_groups: 158 | nn.utils.clip_grad_norm_(group['params'], self.clip_grad) 159 | 160 | self.optimizer.step() 161 | self.optimizer.zero_grad() 162 | 163 | lm_loss = (i * lm_loss + batch_lm_loss.item()) / (i + 1) 164 | loss = (i * loss + batch_loss.item()) / (i + 1) 165 | kl_loss = (i * kl_loss + batch_kl_loss.item()) / (i + 1) 166 | tqdm_data.set_postfix({'lm_loss': lm_loss, 'loss': loss, 'kl_loss': kl_loss, 'ppl': math.exp(loss), 'loss_step': batch_loss.item(), 167 | 'lr': self.optimizer.rate(), 'step': self.optimizer._step}) 168 | 169 | log_dict = {'epoch': epoch, 'lm_loss': lm_loss, 'loss': loss, 'kl_loss': kl_loss, 'ppl': math.exp(loss), 'lr': self.optimizer.rate(), 170 | 'step': self.optimizer._step} 171 | log_dict_json = json.dumps(log_dict, ensure_ascii=False) 172 | logger.info(log_dict_json) 173 | 174 | def _eval_test(self, 
metric_funcs={}): 175 | self.model.eval() 176 | 177 | tqdm_data = tqdm(self.test_dataloader, desc='Test') 178 | loss = 0 179 | lm_loss = 0 180 | metrics = {name: 0 for name in metric_funcs.keys()} 181 | for i, (contexts, targets) in enumerate(tqdm_data): 182 | contexts, targets = [c.to(self.device) for c in contexts], targets.to(self.device) 183 | enc_contexts = [] 184 | 185 | # lm loss 186 | batch_lm_loss = torch.tensor(0, dtype=torch.float, device=self.device) 187 | for context in contexts: 188 | enc_context = self.model.encode(context.clone()) 189 | enc_contexts.append(enc_context) 190 | 191 | if self.lm_weight > 0: 192 | context_outputs = self.model.generate(enc_context[0]) 193 | ignore_mask = torch.stack([context == idx for idx in self.ignore_idxs], dim=-1).any(dim=-1) 194 | context.masked_fill_(ignore_mask, self.model.padding_idx) 195 | prevs, nexts = context_outputs[:, :-1, :].contiguous(), context[:, 1:].contiguous() 196 | batch_lm_loss += (self.lm_criterion(prevs.view(-1, prevs.shape[-1]), nexts.view(-1)) / len(contexts)) 197 | 198 | # s2s loss 199 | prevs, nexts = targets[:, :-1].contiguous(), targets[:, 1:].contiguous() 200 | outputs = self.model.decode(prevs, enc_contexts) 201 | outputs = F.log_softmax(outputs, dim=-1) 202 | batch_loss = self.criterion(outputs.view(-1, outputs.shape[-1]), nexts.view(-1)) 203 | 204 | predictions = self.model.beam_search(enc_contexts) 205 | target_lens = targets.ne(self.model.padding_idx).sum(dim=-1) 206 | targets = [t[1:l-1].tolist() for t, l in zip(targets, target_lens)] 207 | 208 | lm_loss = (i * lm_loss + batch_lm_loss.item()) / (i + 1) 209 | loss = (i * loss + batch_loss.item()) / (i + 1) 210 | for name, func in metric_funcs.items(): 211 | score = func(predictions, targets) 212 | metrics[name] = (metrics[name] * i + score) / (i + 1) 213 | 214 | tqdm_data.set_postfix(dict({'lm_loss': lm_loss, 'loss': loss, 'ppl': math.exp(loss)}, **metrics)) 215 | log_dict = dict({'lm_loss': lm_loss, 'loss': loss, 'ppl': math.exp(loss)}, **metrics) 216 | log_dict_json = json.dumps(log_dict, ensure_ascii=False) 217 | logger.info(log_dict_json) 218 | 219 | def test(self, metric_funcs={}): 220 | if hasattr(self, 'test_dataloader'): 221 | self._eval_test(metric_funcs) 222 | 223 | def train(self, start_epoch, epochs, after_epoch_funcs=[], risk_func=None): 224 | for epoch in range(start_epoch, epochs): 225 | self._eval_train(epoch, risk_func) 226 | for func in after_epoch_funcs: 227 | func(epoch) 228 | -------------------------------------------------------------------------------- /code/retrieval/pytorch_pretrained_bert/tokenization_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
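# A minimal usage sketch of the OpenAIGPTTokenizer defined in this file, assuming either the
# 'openai-gpt' shortcut name (with network access) or a local directory containing vocab.json and
# merges.txt; ftfy/spacy are optional and BERT's BasicTokenizer is used as a fallback:
#
#   tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
#   tokens = tokenizer.tokenize("do you like cats?")                      # BPE sub-word tokens
#   ids = tokenizer.convert_tokens_to_ids(tokens)                         # vocabulary indices
#   text = tokenizer.decode(ids, clean_up_tokenization_spaces=True)       # back to a plain string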
15 | """Tokenization classes for OpenAI GPT.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import json 20 | import logging 21 | import os 22 | import re 23 | import sys 24 | from io import open 25 | 26 | from tqdm import tqdm 27 | 28 | from .file_utils import cached_path 29 | from .tokenization import BasicTokenizer 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 34 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json", 35 | } 36 | PRETRAINED_MERGES_ARCHIVE_MAP = { 37 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt", 38 | } 39 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 40 | 'openai-gpt': 512, 41 | } 42 | VOCAB_NAME = 'vocab.json' 43 | MERGES_NAME = 'merges.txt' 44 | 45 | def get_pairs(word): 46 | """ 47 | Return set of symbol pairs in a word. 48 | word is represented as tuple of symbols (symbols being variable-length strings) 49 | """ 50 | pairs = set() 51 | prev_char = word[0] 52 | for char in word[1:]: 53 | pairs.add((prev_char, char)) 54 | prev_char = char 55 | return pairs 56 | 57 | def text_standardize(text): 58 | """ 59 | fixes some issues the spacy tokenizer had on books corpus 60 | also does some whitespace standardization 61 | """ 62 | text = text.replace('—', '-') 63 | text = text.replace('–', '-') 64 | text = text.replace('―', '-') 65 | text = text.replace('…', '...') 66 | text = text.replace('´', "'") 67 | text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text) 68 | text = re.sub(r'\s*\n\s*', ' \n ', text) 69 | text = re.sub(r'[^\S\n]+', ' ', text) 70 | return text.strip() 71 | 72 | class OpenAIGPTTokenizer(object): 73 | """ 74 | BPE tokenizer. Peculiarities: 75 | - lower case all inputs 76 | - uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, fallback to BERT's BasicTokenizer if not. 77 | - argument special_tokens and function set_special_tokens: 78 | can be used to add additional symbols (ex: "__classify__") to a vocabulary. 79 | """ 80 | @classmethod 81 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 82 | """ 83 | Instantiate a PreTrainedBertModel from a pre-trained model file. 84 | Download and cache the pre-trained model file if needed. 85 | """ 86 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 87 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 88 | merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path] 89 | else: 90 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) 91 | merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) 92 | # redirect to the cache, if necessary 93 | try: 94 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 95 | resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) 96 | except EnvironmentError: 97 | logger.error( 98 | "Model name '{}' was not found in model name list ({}). 
" 99 | "We assumed '{}' was a path or url but couldn't find files {} and {} " 100 | "at this path or url.".format( 101 | pretrained_model_name_or_path, 102 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 103 | pretrained_model_name_or_path, 104 | vocab_file, merges_file)) 105 | return None 106 | if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: 107 | logger.info("loading vocabulary file {}".format(vocab_file)) 108 | logger.info("loading merges file {}".format(merges_file)) 109 | else: 110 | logger.info("loading vocabulary file {} from cache at {}".format( 111 | vocab_file, resolved_vocab_file)) 112 | logger.info("loading merges file {} from cache at {}".format( 113 | merges_file, resolved_merges_file)) 114 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 115 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 116 | # than the number of positional embeddings 117 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] 118 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 119 | # Instantiate tokenizer. 120 | tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs) 121 | return tokenizer 122 | 123 | def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None): 124 | try: 125 | import ftfy 126 | import spacy 127 | self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat']) 128 | self.fix_text = ftfy.fix_text 129 | except ImportError: 130 | logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.") 131 | self.nlp = BasicTokenizer(do_lower_case=True, 132 | never_split=special_tokens if special_tokens is not None else []) 133 | self.fix_text = None 134 | 135 | self.max_len = max_len if max_len is not None else int(1e12) 136 | self.encoder = json.load(open(vocab_file, encoding="utf-8")) 137 | self.decoder = {v:k for k,v in self.encoder.items()} 138 | merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] 139 | merges = [tuple(merge.split()) for merge in merges] 140 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 141 | self.cache = {} 142 | self.set_special_tokens(special_tokens) 143 | 144 | def __len__(self): 145 | return len(self.encoder) + len(self.special_tokens) 146 | 147 | def set_special_tokens(self, special_tokens): 148 | """ Add a list of additional tokens to the encoder. 149 | The additional tokens are indexed starting from the last index of the 150 | current vocabulary in the order of the `special_tokens` list. 
151 | """ 152 | if not special_tokens: 153 | self.special_tokens = {} 154 | self.special_tokens_decoder = {} 155 | return 156 | self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)) 157 | self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()} 158 | if self.fix_text is None: 159 | # Using BERT's BasicTokenizer: we can update the tokenizer 160 | self.nlp.never_split = special_tokens 161 | logger.info("Special tokens {}".format(self.special_tokens)) 162 | 163 | def bpe(self, token): 164 | word = tuple(token[:-1]) + (token[-1] + '',) 165 | if token in self.cache: 166 | return self.cache[token] 167 | pairs = get_pairs(word) 168 | 169 | if not pairs: 170 | return token+'' 171 | 172 | while True: 173 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) 174 | if bigram not in self.bpe_ranks: 175 | break 176 | first, second = bigram 177 | new_word = [] 178 | i = 0 179 | while i < len(word): 180 | try: 181 | j = word.index(first, i) 182 | new_word.extend(word[i:j]) 183 | i = j 184 | except: 185 | new_word.extend(word[i:]) 186 | break 187 | 188 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 189 | new_word.append(first+second) 190 | i += 2 191 | else: 192 | new_word.append(word[i]) 193 | i += 1 194 | new_word = tuple(new_word) 195 | word = new_word 196 | if len(word) == 1: 197 | break 198 | else: 199 | pairs = get_pairs(word) 200 | word = ' '.join(word) 201 | if word == '\n ': 202 | word = '\n' 203 | self.cache[token] = word 204 | return word 205 | 206 | def tokenize(self, text): 207 | """ Tokenize a string. """ 208 | split_tokens = [] 209 | if self.fix_text is None: 210 | # Using BERT's BasicTokenizer 211 | text = self.nlp.tokenize(text) 212 | for token in text: 213 | split_tokens.extend([t for t in self.bpe(token).split(' ')]) 214 | else: 215 | # Using SpaCy & ftfy (original tokenization process of OpenAI GPT) 216 | text = self.nlp(text_standardize(self.fix_text(text))) 217 | for token in text: 218 | split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')]) 219 | return split_tokens 220 | 221 | def convert_tokens_to_ids(self, tokens): 222 | """ Converts a sequence of tokens into ids using the vocab. """ 223 | ids = [] 224 | if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)): 225 | if tokens in self.special_tokens: 226 | return self.special_tokens[tokens] 227 | else: 228 | return self.encoder.get(tokens, 0) 229 | for token in tokens: 230 | if token in self.special_tokens: 231 | ids.append(self.special_tokens[token]) 232 | else: 233 | ids.append(self.encoder.get(token, 0)) 234 | if len(ids) > self.max_len: 235 | logger.warning( 236 | "Token indices sequence length is longer than the specified maximum " 237 | " sequence length for this OpenAI GPT model ({} > {}). 
Running this" 238 | " sequence through the model will result in indexing errors".format(len(ids), self.max_len) 239 | ) 240 | return ids 241 | 242 | def convert_ids_to_tokens(self, ids, skip_special_tokens=False): 243 | """Converts a sequence of ids in BPE tokens using the vocab.""" 244 | tokens = [] 245 | for i in ids: 246 | if i in self.special_tokens_decoder: 247 | if not skip_special_tokens: 248 | tokens.append(self.special_tokens_decoder[i]) 249 | else: 250 | tokens.append(self.decoder[i]) 251 | return tokens 252 | 253 | def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=False): 254 | """Converts a sequence of ids in a string.""" 255 | tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens) 256 | out_string = ''.join(tokens).replace('', ' ').strip() 257 | if clean_up_tokenization_spaces: 258 | out_string = out_string.replace('', '') 259 | out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',').replace(' ,', ',' 260 | ).replace(" n't", "n't").replace(" 'm", "'m").replace(" 're", "'re").replace(" do not", " don't" 261 | ).replace(" 's", "'s").replace(" t ", "'t ").replace(" s ", "'s ").replace(" m ", "'m " 262 | ).replace(" 've", "'ve") 263 | return out_string 264 | -------------------------------------------------------------------------------- /code/generation/model/transformer_model.py: -------------------------------------------------------------------------------- 1 | # transformer_chatbot 2 | # Copyright (C) 2018 Golovanov, Tselousov 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 
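# A minimal inference sketch for the TransformerModel defined below, following the way
# sample_text_func in train_kd.py drives it (vocab, device, a trained model and a
# S2sDataset_dialog test_dataset are assumed to be set up as in that script):
#
#   source, target = test_dataset[idx]                                    # token-id sequences
#   contexts = [torch.tensor([source], dtype=torch.long, device=device)]
#   prediction = model.predict(contexts)[0]                               # beam search over encoded contexts
#   print(vocab.ids2string(prediction))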
16 | 17 | import random 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | 22 | from .transformer_module import TransformerModule 23 | 24 | 25 | class TransformerModel(nn.Module): 26 | def __init__(self, n_layers, n_embeddings, n_pos_embeddings, embeddings_size, 27 | padding_idx, n_heads, dropout, embed_dropout, attn_dropout, ff_dropout, 28 | bos_id, eos_id, max_seq_len=256, beam_size=5, sample=False, 29 | length_penalty=0.8, annealing_topk=None, temperature=1.0, annealing=0, 30 | diversity_coef=0, diversity_groups=1, n_segments=None): 31 | 32 | super(TransformerModel, self).__init__() 33 | 34 | self.padding_idx = padding_idx 35 | self.n_embeddings = n_embeddings 36 | self.n_pos_embeddings = n_pos_embeddings 37 | self.embeddings_size = embeddings_size 38 | 39 | self.bos_id = bos_id 40 | self.eos_id = eos_id 41 | 42 | self.max_seq_len = max_seq_len 43 | self.beam_size = beam_size 44 | self.sample = sample 45 | self.length_penalty_coef = length_penalty 46 | self.annealing = annealing 47 | self.annealing_topk = annealing_topk 48 | self.temperature = temperature 49 | self.diversity_coef = diversity_coef 50 | self.diversity_groups = diversity_groups 51 | 52 | self.transformer_module = TransformerModule(n_layers, n_embeddings, n_pos_embeddings, embeddings_size, 53 | padding_idx, n_heads, dropout, embed_dropout, attn_dropout, 54 | ff_dropout, n_segments) 55 | self.pre_softmax = nn.Linear(embeddings_size, n_embeddings, bias=False) 56 | self.pre_softmax.weight = self.transformer_module.embeddings.weight 57 | 58 | def forward(self, x, contexts=[]): 59 | enc_contexts = [self.encode(c) for c in contexts] 60 | return self.decode(x, enc_contexts) 61 | 62 | def encode(self, x): 63 | return self.transformer_module(x) 64 | 65 | def generate(self, enc_x): 66 | return self.pre_softmax(enc_x) 67 | 68 | def decode(self, x, enc_contexts=[]): 69 | x, _ = self.transformer_module(x, enc_contexts) 70 | return self.generate(x) 71 | 72 | def predict(self, contexts=[]): 73 | enc_contexts = [self.encode(c) for c in contexts] 74 | prediction = self.beam_search(enc_contexts) 75 | return prediction 76 | 77 | def predict_beams(self, contexts=[]): 78 | enc_contexts = [self.encode(c) for c in contexts] 79 | prediction = self.beam_search(enc_contexts, return_beams=True) 80 | return prediction 81 | 82 | def _length_penalty(self, sequence_lengths): 83 | """https://arxiv.org/abs/1609.08144""" 84 | return (5 + sequence_lengths) ** self.length_penalty_coef / (5 + 1) ** self.length_penalty_coef 85 | 86 | def beam_search(self, enc_contexts=[], return_beams=False): 87 | with torch.no_grad(): 88 | if len(enc_contexts) == 0: 89 | return [] 90 | 91 | batch_size = enc_contexts[0][0].shape[0] 92 | device = next(self.parameters()).device 93 | 94 | prevs = torch.full((batch_size * self.beam_size, 1), fill_value=self.bos_id, dtype=torch.long, device=device) 95 | 96 | beam_scores = torch.zeros(batch_size, self.beam_size, device=device) 97 | beam_lens = torch.ones(batch_size, self.beam_size, dtype=torch.long, device=device) 98 | is_end = torch.zeros(batch_size, self.beam_size, dtype=torch.uint8, device=device) 99 | 100 | beam_enc_contexts = [] 101 | for c, p in enc_contexts: 102 | c = c.unsqueeze(1).repeat(1, self.beam_size, 1, 1) 103 | c = c.view(-1, c.shape[2], c.shape[3]) 104 | p = p.unsqueeze(1).repeat(1, self.beam_size, 1) 105 | p = p.view(-1, p.shape[2]) 106 | beam_enc_contexts.append((c, p)) 107 | 108 | current_sample_prob = 1 109 | group_size = self.beam_size // self.diversity_groups 110 | 
diversity_penalty = torch.zeros((batch_size, self.n_embeddings), device=device) 111 | 112 | # zrs: 113 | repeat = [{} for i in range(batch_size * self.beam_size)] 114 | # ********** 115 | for i in range(self.max_seq_len): 116 | outputs, _ = self.transformer_module(prevs, beam_enc_contexts) 117 | 118 | logits = self.generate(outputs[:, -1, :]) 119 | log_probs = F.log_softmax(logits, dim=-1) 120 | ''' 121 | # zrs: remove n repeat. prevs: (batch_size*beam_size, 1) 122 | for idx in range(batch_size * self.beam_size): 123 | for key in repeat[idx]: 124 | for value in repeat[idx][key]: 125 | log_probs[idx][value] = -1000 126 | # ********** 127 | ''' 128 | # zrs 129 | prevs_list_temp = prevs.tolist() 130 | for idx in range(batch_size * self.beam_size): 131 | b_list = prevs_list_temp[idx] 132 | if len(b_list) > 1 and b_list[-1] != self.padding_idx and b_list[-1] != self.eos_id: 133 | key = (int(b_list[-2]), int(b_list[-1])) 134 | if key in repeat[idx]: 135 | for value in repeat[idx][key]: 136 | log_probs[idx][value] = -1000 137 | # ********** 138 | 139 | log_probs = log_probs.view(batch_size, self.beam_size, -1) 140 | 141 | beam_scores = beam_scores.unsqueeze(-1) + log_probs * (1 - is_end.float().unsqueeze(-1)) 142 | 143 | # zrs, log_probs: batch * beam * dim 144 | ba, be, dim = beam_scores.shape 145 | for ba_idx in range(ba): 146 | for be_idx in range(be): 147 | if int(torch.max(beam_scores[ba_idx][be_idx]) == torch.min(beam_scores[ba_idx][be_idx])): 148 | temp = float(beam_scores[ba_idx][be_idx][0]) 149 | beam_scores[ba_idx][be_idx] = -float('inf') 150 | beam_scores[ba_idx][be_idx][0] = temp 151 | # ********** 152 | 153 | 154 | penalty = self._length_penalty(beam_lens.float() + 1 - is_end.float()) 155 | penalty = penalty.unsqueeze(-1).repeat(1, 1, self.n_embeddings) 156 | beam_scores = beam_scores / penalty 157 | 158 | if i == 0: 159 | penalty = penalty[:, 0, :] 160 | beam_scores = beam_scores[:, 0, :] 161 | 162 | beam_scores, idxs = beam_scores.topk(self.beam_size, dim=-1) 163 | beam_idxs = torch.zeros((batch_size, self.beam_size), dtype=torch.long, device=device) 164 | else: 165 | penalty = penalty.view(batch_size, self.diversity_groups, group_size, -1) 166 | beam_scores = beam_scores.view(batch_size, self.diversity_groups, group_size, -1) 167 | 168 | all_scores, all_idxs = [], [] 169 | for g in range(self.diversity_groups): 170 | g_beam_scores = beam_scores[:, g, :, :] 171 | g_penalty = penalty[:, g, :, :] 172 | g_beam_scores -= self.diversity_coef * diversity_penalty.unsqueeze(1) / g_penalty 173 | g_beam_scores = g_beam_scores.view(batch_size, -1) 174 | 175 | if random.random() < current_sample_prob: 176 | beam_probas = F.softmax(g_beam_scores/self.temperature, dim=-1) 177 | if self.annealing_topk is not None: 178 | beam_probas, sample_idxs = beam_probas.topk(self.annealing_topk, dim=-1) 179 | g_idxs = torch.multinomial(beam_probas, group_size) 180 | g_idxs = torch.gather(sample_idxs, 1, g_idxs) 181 | else: 182 | g_idxs = torch.multinomial(beam_probas, group_size) 183 | else: 184 | _, g_idxs = g_beam_scores.topk(group_size, dim=-1) 185 | 186 | g_scores = torch.gather(beam_scores[:, g, :, :].view(batch_size, -1), 1, g_idxs) 187 | g_idxs += g * group_size * self.n_embeddings 188 | 189 | all_scores.append(g_scores) 190 | all_idxs.append(g_idxs) 191 | 192 | diversity_penalty.scatter_add_(1, torch.fmod(g_idxs, self.n_embeddings), torch.ones((batch_size, group_size), device=device)) 193 | 194 | diversity_penalty.fill_(0) 195 | penalty = penalty.view(batch_size, -1) 196 | beam_scores = 
torch.cat(all_scores, dim=-1) 197 | idxs = torch.cat(all_idxs, dim=-1) 198 | 199 | beam_idxs = (idxs.float() / self.n_embeddings).long() 200 | # print('beam_scores: ', beam_scores) 201 | 202 | penalty = torch.gather(penalty, 1, idxs) 203 | sym_idxs = torch.fmod(idxs, log_probs.shape[-1]) 204 | is_end = torch.gather(is_end, 1, beam_idxs) 205 | beam_lens = torch.gather(beam_lens, 1, beam_idxs) 206 | 207 | sym_idxs[is_end] = self.padding_idx 208 | beam_lens[~is_end] += 1 209 | is_end[sym_idxs == self.eos_id] = 1 210 | 211 | sym_idxs = sym_idxs.view(batch_size * self.beam_size, 1) 212 | prevs = prevs.view(batch_size, self.beam_size, -1) 213 | prevs = torch.gather(prevs, 1, beam_idxs.unsqueeze(-1).repeat(1, 1, prevs.shape[-1])) 214 | prevs = prevs.view(batch_size * self.beam_size, -1) 215 | prevs = torch.cat([prevs, sym_idxs], dim=1) 216 | 217 | # zrs: 218 | prevs_list = prevs.tolist() 219 | for b in range(batch_size * self.beam_size): 220 | b_list = prevs_list[b] 221 | if len(b_list) > 2 and b_list[-1] != self.padding_idx and b_list[-1] != self.eos_id: 222 | key = (int(b_list[-3]), int(b_list[-2])) 223 | if key in repeat[b]: 224 | repeat[b][key].append(int(b_list[-1])) 225 | else: 226 | repeat[b][key] = [int(b_list[-1])] 227 | # ******** 228 | 229 | if all(is_end.view(-1)): 230 | break 231 | # print(beam_scores.shape) 232 | beam_scores *= penalty 233 | current_sample_prob *= self.annealing 234 | 235 | predicts = [] 236 | result = prevs.view(batch_size, self.beam_size, -1) 237 | 238 | # if return_beams: 239 | # return result, beam_lens 240 | if return_beams: 241 | bests = torch.argsort(beam_scores, dim=-1, descending=True) 242 | for i in range(batch_size): 243 | temp = [] 244 | for j in range(self.beam_size): 245 | best_len = beam_lens[i, bests[i][j]] 246 | best_seq = result[i, bests[i][j], 1:best_len - 1] 247 | temp.append(best_seq.tolist()) 248 | predicts.append(temp) 249 | return predicts 250 | 251 | if self.sample: 252 | probs = F.softmax(beam_scores, dim=-1) 253 | bests = torch.multinomial(probs, 1).view(-1) 254 | else: 255 | bests = beam_scores.argmax(dim=-1) 256 | 257 | for i in range(batch_size): 258 | best_len = beam_lens[i, bests[i]] 259 | best_seq = result[i, bests[i], 1:best_len-1] 260 | predicts.append(best_seq.tolist()) 261 | 262 | return predicts 263 | -------------------------------------------------------------------------------- /code/retrieval/pytorch_pretrained_bert/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
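# A minimal usage sketch of BertAdam with its default 'warmup_linear' schedule as defined below
# (the values are illustrative, not taken from any experiment in this repository):
#
#   optimizer = BertAdam(model.parameters(), lr=2e-5, warmup=0.1, t_total=1000)
#   # warmup_linear multiplies lr by progress/warmup during warmup and then decays it linearly
#   # to 0: at step 50 the multiplier is 0.5, at step 550 it is 0.5 again, and at step 1000 it is 0.
#   loss.backward(); optimizer.step(); optimizer.zero_grad()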
15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | import logging 23 | import abc 24 | import sys 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | if sys.version_info >= (3, 4): 30 | ABC = abc.ABC 31 | else: 32 | ABC = abc.ABCMeta('ABC', (), {}) 33 | 34 | 35 | class _LRSchedule(ABC): 36 | """ Parent of all LRSchedules here. """ 37 | warn_t_total = False # is set to True for schedules where progressing beyond t_total steps doesn't make sense 38 | def __init__(self, warmup=0.002, t_total=-1, **kw): 39 | """ 40 | :param warmup: what fraction of t_total steps will be used for linear warmup 41 | :param t_total: how many training steps (updates) are planned 42 | :param kw: 43 | """ 44 | super(_LRSchedule, self).__init__(**kw) 45 | if t_total < 0: 46 | logger.warning("t_total value of {} results in schedule not being applied".format(t_total)) 47 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 48 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 49 | warmup = max(warmup, 0.) 50 | self.warmup, self.t_total = float(warmup), float(t_total) 51 | self.warned_for_t_total_at_progress = -1 52 | 53 | def get_lr(self, step, nowarn=False): 54 | """ 55 | :param step: which of t_total steps we're on 56 | :param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps 57 | :return: learning rate multiplier for current update 58 | """ 59 | if self.t_total < 0: 60 | return 1. 61 | progress = float(step) / self.t_total 62 | ret = self.get_lr_(progress) 63 | # warning for exceeding t_total (only active with warmup_linear 64 | if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress: 65 | logger.warning( 66 | "Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly." 67 | .format(ret, self.__class__.__name__)) 68 | self.warned_for_t_total_at_progress = progress 69 | # end warning 70 | return ret 71 | 72 | @abc.abstractmethod 73 | def get_lr_(self, progress): 74 | """ 75 | :param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress 76 | :return: learning rate multiplier for current update 77 | """ 78 | return 1. 79 | 80 | 81 | class ConstantLR(_LRSchedule): 82 | def get_lr_(self, progress): 83 | return 1. 84 | 85 | 86 | class WarmupCosineSchedule(_LRSchedule): 87 | """ 88 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 89 | Decreases learning rate from 1. to 0. over remaining `1 - warmup` steps following a cosine curve. 90 | If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup. 91 | """ 92 | warn_t_total = True 93 | def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw): 94 | """ 95 | :param warmup: see LRSchedule 96 | :param t_total: see LRSchedule 97 | :param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1. 
98 | :param kw: 99 | """ 100 | super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw) 101 | self.cycles = cycles 102 | 103 | def get_lr_(self, progress): 104 | if progress < self.warmup: 105 | return progress / self.warmup 106 | else: 107 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 108 | return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress)) 109 | 110 | 111 | class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule): 112 | """ 113 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 114 | If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying 115 | learning rate (with hard restarts). 116 | """ 117 | def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): 118 | super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) 119 | assert(cycles >= 1.) 120 | 121 | def get_lr_(self, progress): 122 | if progress < self.warmup: 123 | return progress / self.warmup 124 | else: 125 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 126 | ret = 0.5 * (1. + math.cos(math.pi * ((self.cycles * progress) % 1))) 127 | return ret 128 | 129 | 130 | class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule): 131 | """ 132 | All training progress is divided in `cycles` (default=1.) parts of equal length. 133 | Every part follows a schedule with the first `warmup` fraction of the training steps linearly increasing from 0. to 1., 134 | followed by a learning rate decreasing from 1. to 0. following a cosine curve. 135 | """ 136 | def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): 137 | assert(warmup * cycles < 1.) 138 | warmup = warmup * cycles if warmup >= 0 else warmup 139 | super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) 140 | 141 | def get_lr_(self, progress): 142 | progress = progress * self.cycles % 1. 143 | if progress < self.warmup: 144 | return progress / self.warmup 145 | else: 146 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 147 | ret = 0.5 * (1. + math.cos(math.pi * progress)) 148 | return ret 149 | 150 | 151 | class WarmupConstantSchedule(_LRSchedule): 152 | """ 153 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 154 | Keeps learning rate equal to 1. after warmup. 155 | """ 156 | def get_lr_(self, progress): 157 | if progress < self.warmup: 158 | return progress / self.warmup 159 | return 1. 160 | 161 | 162 | class WarmupLinearSchedule(_LRSchedule): 163 | """ 164 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 165 | Linearly decreases learning rate from 1. to 0. over remaining `1 - warmup` steps. 166 | """ 167 | warn_t_total = True 168 | def get_lr_(self, progress): 169 | if progress < self.warmup: 170 | return progress / self.warmup 171 | return max((progress - 1.) / (self.warmup - 1.), 0.) 172 | 173 | 174 | SCHEDULES = { 175 | None: ConstantLR, 176 | "none": ConstantLR, 177 | "warmup_cosine": WarmupCosineSchedule, 178 | "warmup_constant": WarmupConstantSchedule, 179 | "warmup_linear": WarmupLinearSchedule 180 | } 181 | 182 | 183 | class BertAdam(Optimizer): 184 | """Implements BERT version of Adam algorithm with weight decay fix. 
185 | Params: 186 | lr: learning rate 187 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 188 | t_total: total number of training steps for the learning 189 | rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1 190 | schedule: schedule to use for the warmup (see above). 191 | Can be `'warmup_linear'`, `'warmup_constant'`, `'warmup_cosine'`, `'none'`, `None` or a `_LRSchedule` object (see below). 192 | If `None` or `'none'`, learning rate is always kept constant. 193 | Default : `'warmup_linear'` 194 | b1: Adams b1. Default: 0.9 195 | b2: Adams b2. Default: 0.999 196 | e: Adams epsilon. Default: 1e-6 197 | weight_decay: Weight decay. Default: 0.01 198 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 199 | """ 200 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 201 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs): 202 | if lr is not required and lr < 0.0: 203 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 204 | if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES: 205 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 206 | if not 0.0 <= b1 < 1.0: 207 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 208 | if not 0.0 <= b2 < 1.0: 209 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 210 | if not e >= 0.0: 211 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 212 | # initialize schedule object 213 | if not isinstance(schedule, _LRSchedule): 214 | schedule_type = SCHEDULES[schedule] 215 | schedule = schedule_type(warmup=warmup, t_total=t_total) 216 | else: 217 | if warmup != -1 or t_total != -1: 218 | logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. " 219 | "Please specify custom warmup and t_total in _LRSchedule object.") 220 | defaults = dict(lr=lr, schedule=schedule, 221 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 222 | max_grad_norm=max_grad_norm) 223 | super(BertAdam, self).__init__(params, defaults) 224 | 225 | def get_lr(self): 226 | lr = [] 227 | for group in self.param_groups: 228 | for p in group['params']: 229 | state = self.state[p] 230 | if len(state) == 0: 231 | return [0] 232 | lr_scheduled = group['lr'] 233 | lr_scheduled *= group['schedule'].get_lr(state['step']) 234 | lr.append(lr_scheduled) 235 | return lr 236 | 237 | def step(self, closure=None): 238 | """Performs a single optimization step. 239 | Arguments: 240 | closure (callable, optional): A closure that reevaluates the model 241 | and returns the loss. 
242 | """ 243 | loss = None 244 | if closure is not None: 245 | loss = closure() 246 | 247 | for group in self.param_groups: 248 | for p in group['params']: 249 | if p.grad is None: 250 | continue 251 | grad = p.grad.data 252 | if grad.is_sparse: 253 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 254 | 255 | state = self.state[p] 256 | 257 | # State initialization 258 | if len(state) == 0: 259 | state['step'] = 0 260 | # Exponential moving average of gradient values 261 | state['next_m'] = torch.zeros_like(p.data) 262 | # Exponential moving average of squared gradient values 263 | state['next_v'] = torch.zeros_like(p.data) 264 | 265 | next_m, next_v = state['next_m'], state['next_v'] 266 | beta1, beta2 = group['b1'], group['b2'] 267 | 268 | # Add grad clipping 269 | if group['max_grad_norm'] > 0: 270 | clip_grad_norm_(p, group['max_grad_norm']) 271 | 272 | # Decay the first and second moment running average coefficient 273 | # In-place operations to update the averages at the same time 274 | next_m.mul_(beta1).add_(1 - beta1, grad) 275 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 276 | update = next_m / (next_v.sqrt() + group['e']) 277 | 278 | # Just adding the square of the weights to the loss function is *not* 279 | # the correct way of using L2 regularization/weight decay with Adam, 280 | # since that will interact with the m and v parameters in strange ways. 281 | # 282 | # Instead we want to decay the weights in a manner that doesn't interact 283 | # with the m/v parameters. This is equivalent to adding the square 284 | # of the weights to the loss with plain (non-momentum) SGD. 285 | if group['weight_decay'] > 0.0: 286 | update += group['weight_decay'] * p.data 287 | 288 | lr_scheduled = group['lr'] 289 | lr_scheduled *= group['schedule'].get_lr(state['step']) 290 | 291 | update_with_lr = lr_scheduled * update 292 | p.data.add_(-update_with_lr) 293 | 294 | state['step'] += 1 295 | 296 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 297 | # No bias correction 298 | # bias_correction1 = 1 - beta1 ** state['step'] 299 | # bias_correction2 = 1 - beta2 ** state['step'] 300 | 301 | return loss 302 | -------------------------------------------------------------------------------- /code/generation/config.py: -------------------------------------------------------------------------------- 1 | from attrdict import AttrDict 2 | from model.utils import openai_transformer_config 3 | 4 | 5 | # transformer config 6 | def get_model_config_dialog(): 7 | default_config = openai_transformer_config() 8 | config = AttrDict({'vocab_path': '/root/generation_with_augmentation/parameters/vocab.txt', 9 | 'checkpoint_path': '/root/generation_with_augmentation/checkpoints/dialog_300k/crowded_extend_both_kd/last_checkpoint', 10 | 'n_layers': 12, 11 | 'n_pos_embeddings': 512, 12 | 'embeddings_size': default_config.embeddings_size, 13 | 'n_heads': default_config.n_heads, 14 | 'dropout': default_config.dropout, 15 | 'embed_dropout': default_config.embed_dropout, 16 | 'attn_dropout': default_config.attn_dropout, 17 | 'ff_dropout': default_config.ff_dropout, 18 | 'max_seq_len': 32, 19 | 'beam_size': 1, 20 | 'diversity_coef': 0, 21 | 'diversity_groups': 1, 22 | 'temperature': 1.0, 23 | 'annealing_topk': None, 24 | 'annealing': 0, 25 | 'length_penalty': 1.0, 26 | 'n_segments': None}) 27 | 28 | return config 29 | 30 | 31 | def get_trainer_config_dialog(): 32 | config = AttrDict({'n_epochs': 100, 33 | 'batch_size': 256, 34 | 
'batch_split': 32, 35 | 'lr': 6.25e-5, 36 | 'lr_warmup': 1000, 37 | 'lm_weight': 0.5, 38 | 'risk_weight': 0, 39 | 'n_jobs': 4, 40 | 'label_smoothing': 0.1, 41 | 'clip_grad': None, 42 | 'test_period': 1, 43 | 'seed': 0, 44 | 'device': 'cuda', 45 | 'load_last': True, 46 | 'openai_parameters_dir': '/root/generation_with_augmentation/parameters/chinese_pretrain.pt', 47 | 'last_checkpoint_path': '/root/generation_with_augmentation/checkpoints/dialog_300k/crowded_extend_both_kd/last_checkpoint', 48 | 'teacher_checkpoint_path': '/root/generation_with_augmentation/checkpoints/dialog_300k/crowded/last_checkpoint15', 49 | 'interrupt_checkpoint_path': '/root/generation_with_augmentation/checkpoints/dialog_300k/crowded_extend_both_kd/interrupt_checkpoint', 50 | 'train_datasets': [#'/root/generation_with_augmentation/dataset/dialog/crowded_500k.txt', 51 | #'/root/generation_with_augmentation/dataset/dialog/crowded_300k.txt', 52 | #'/root/generation_with_augmentation/dataset/dialog/crowded_100k.txt', 53 | #'/root/generation_with_augmentation/dataset/dialog/crowded_unpaired_500k.txt', 54 | #'/root/generation_with_augmentation/dataset/dialog/fake_unique.txt', 55 | #'/root/generation_with_augmentation/dataset/dialog/crowded_300k_eda_v2.txt', 56 | #'/root/generation_with_augmentation/dataset/dialog/crowded_300k_cvae_post.txt', 57 | #'/root/generation_with_augmentation/dataset/dialog/crowded_dirty_300k.txt', 58 | #'/root/generation_with_augmentation/dataset/dialog/crowded_300k_bt.txt', 59 | #'/root/generation_with_augmentation/dataset/dialog/crowded_300k_extend_post.txt', 60 | #'/root/generation_with_augmentation/dataset/dialog/crowded_300k_extend_resp.txt', 61 | '/root/generation_with_augmentation/dataset/dialog/crowded_300k_extend_both.txt', 62 | #'/root/generation_with_augmentation/dataset/dialog/fake_wo_matching_generation.txt', 63 | ], 64 | 'test_datasets': ['/root/generation_with_augmentation/dataset/dialog/valid_9k.txt']}) 65 | return config 66 | 67 | 68 | def get_test_config_dialog(): 69 | config = AttrDict({'seed': 0, 70 | 'device': 'cuda', 71 | 'load_last': True, 72 | 'openai_parameters_dir': '/root/generation_with_augmentation/parameters/chinese_pretrain.pt', 73 | 'last_checkpoint_path': '/root/generation_with_augmentation/checkpoints/dialog_kd/crowded_fake/last_checkpoint11'}) 74 | return config 75 | 76 | 77 | # transformer config overlap 78 | def get_model_config_dialog_overlap(): 79 | default_config = openai_transformer_config() 80 | config = AttrDict({'vocab_path': '/root/generation_with_augmentation/parameters/vocab_overlap.txt', 81 | 'checkpoint_path': '/root/generation_with_augmentation/checkpoints/dialog_overlap/v1/last_checkpoint', 82 | 'n_layers': 12, 83 | 'n_pos_embeddings': 512, 84 | 'embeddings_size': default_config.embeddings_size, 85 | 'n_heads': default_config.n_heads, 86 | 'dropout': default_config.dropout, 87 | 'embed_dropout': default_config.embed_dropout, 88 | 'attn_dropout': default_config.attn_dropout, 89 | 'ff_dropout': default_config.ff_dropout, 90 | 'max_seq_len': 32, 91 | 'beam_size': 5, 92 | 'diversity_coef': 0, 93 | 'diversity_groups': 1, 94 | 'temperature': 1.0, 95 | 'annealing_topk': None, 96 | 'annealing': 0, 97 | 'length_penalty': 1.5, 98 | 'n_segments': None}) 99 | 100 | return config 101 | 102 | 103 | def get_trainer_config_dialog_overlap(): 104 | config = AttrDict({'n_epochs': 30, 105 | 'batch_size': 256, 106 | 'batch_split': 16, 107 | 'lr': 6.25e-5, 108 | 'lr_warmup': 1000, 109 | 'lm_weight': 0.5, 110 | 'risk_weight': 0, 111 | 'n_jobs': 4, 112 | 
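                       # 'label_smoothing' below follows the usual convention (assumed; the exact
                       # variant is implemented in model/loss.py): with smoothing eps, the target
                       # distribution puts 1 - eps on the gold token and spreads eps over the
                       # remaining vocabulary entries.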
'label_smoothing': 0.1, 113 | 'clip_grad': None, 114 | 'test_period': 1, 115 | 'seed': 0, 116 | 'device': 'cuda', 117 | 'load_last': False, 118 | 'openai_parameters_dir': '/root/generation_with_augmentation/parameters/chinese_pretrain.pt', 119 | 'last_checkpoint_path': '/root/generation_with_augmentation/checkpoints/dialog_overlap/v1/last_checkpoint', 120 | 'interrupt_checkpoint_path': '/root/generation_with_augmentation/checkpoints/dialog_overlap/v1/interrupt_checkpoint', 121 | 'train_datasets': ['/root/generation_with_augmentation/dataset/dialog/crowded_500k_overlap.txt', 122 | #'/root/generation_with_augmentation/dataset/dialog/crowded_unpaired_500k.txt', 123 | #'/root/generation_with_augmentation/dataset/dialog/fake_unique.txt', 124 | ], 125 | 'test_datasets': ['/root/generation_with_augmentation/dataset/dialog/valid_9k_overlap.txt']}) 126 | return config 127 | 128 | def get_test_config_dialog_overlap(): 129 | config = AttrDict({'seed': 0, 130 | 'device': 'cuda', 131 | 'load_last': True, 132 | 'openai_parameters_dir': '/root/generation_with_augmentation/parameters/chinese_pretrain.pt', 133 | 'last_checkpoint_path': '/root/generation_with_augmentation/checkpoints/dialog_overlap/v1/last_checkpoint15'}) 134 | return config 135 | 136 | # transformer config 137 | def get_model_config_poem(): 138 | default_config = openai_transformer_config() 139 | config = AttrDict({'vocab_path': '/root/generation_with_augmentation/parameters/vocab.txt', 140 | 'checkpoint_path': '/root/generation_with_augmentation/checkpoints/poem_wu/last_checkpoint', 141 | 'n_layers': 12, 142 | 'n_pos_embeddings': 512, 143 | 'embeddings_size': default_config.embeddings_size, 144 | 'n_heads': default_config.n_heads, 145 | 'dropout': default_config.dropout, 146 | 'embed_dropout': default_config.embed_dropout, 147 | 'attn_dropout': default_config.attn_dropout, 148 | 'ff_dropout': default_config.ff_dropout, 149 | 'max_seq_len': 128, 150 | 'beam_size': 15, 151 | 'diversity_coef': 0.5, 152 | 'diversity_groups': 5, 153 | 'temperature': 0.8, 154 | 'annealing_topk': 20, 155 | 'annealing': 1.0, 156 | 'length_penalty': 0.6, 157 | 'n_segments': None}) 158 | 159 | return config 160 | 161 | 162 | def get_trainer_config_poem(): 163 | config = AttrDict({'n_epochs': 30, 164 | 'batch_size': 256, 165 | 'batch_split': 8, 166 | 'lr': 6.25e-5, 167 | 'lr_warmup': 1000, 168 | 'lm_weight': 0, 169 | 'risk_weight': 0, 170 | 'n_jobs': 4, 171 | 'label_smoothing': 0, 172 | 'clip_grad': None, 173 | 'test_period': 1, 174 | 'seed': 0, 175 | 'device': 'cuda', 176 | 'load_last': False, 177 | 'openai_parameters_dir': '/root/generation_with_augmentation/parameters/chinese_pretrain.pt', 178 | 'last_checkpoint_path': '/root/generation_with_augmentation/checkpoints/poem_wu/last_checkpoint', 179 | 'interrupt_checkpoint_path': '/root/generation_with_augmentation/checkpoints/poem_wu/interrupt_checkpoint', 180 | 'train_datasets': ['/root/generation_with_augmentation/dataset/poem/train_wu.txt'], 181 | 'test_datasets': ['/root/generation_with_augmentation/dataset/poem/valid_wu.txt']}) 182 | return config 183 | 184 | 185 | def get_test_config_poem(): 186 | config = AttrDict({'seed': 0, 187 | 'device': 'cuda', 188 | 'load_last': True, 189 | 'openai_parameters_dir': '/root/generation_with_augmentation/parameters/chinese_pretrain.pt', 190 | 'last_checkpoint_path': '/root/generation_with_augmentation/checkpoints/poem/last_checkpoint20'}) 191 | 192 | return config 193 | 194 | # transformer config 195 | def get_model_config_meme(): 196 | default_config = 
openai_transformer_config() 197 | config = AttrDict({'vocab_path': '/root/generation_with_augmentation/parameters/vocab.txt', 198 | 'checkpoint_path': '/root/generation_with_augmentation/checkpoints/meme_all/last_checkpoint', 199 | 'n_layers': 6, 200 | 'n_pos_embeddings': 512, 201 | 'embeddings_size': default_config.embeddings_size, 202 | 'n_heads': default_config.n_heads, 203 | 'dropout': default_config.dropout, 204 | 'embed_dropout': default_config.embed_dropout, 205 | 'attn_dropout': default_config.attn_dropout, 206 | 'ff_dropout': default_config.ff_dropout, 207 | 'max_seq_len': 30, 208 | 'beam_size': 5, 209 | 'diversity_coef': 0, 210 | 'diversity_groups': 1, 211 | 'temperature': 1.0, 212 | 'annealing_topk': None, 213 | 'annealing': 0, 214 | 'length_penalty': 1.0, 215 | 'n_segments': None}) 216 | 217 | return config 218 | 219 | 220 | def get_trainer_config_meme(): 221 | config = AttrDict({'n_epochs': 50, 222 | 'batch_size': 256, 223 | 'batch_split': 8, 224 | 'lr': 6.25e-5, 225 | 'lr_warmup': 1000, 226 | 'lm_weight': 0, 227 | 'risk_weight': 0, 228 | 'n_jobs': 4, 229 | 'label_smoothing': 0, 230 | 'clip_grad': None, 231 | 'test_period': 1, 232 | 'seed': 0, 233 | 'device': 'cuda', 234 | 'load_last': False, 235 | 'openai_parameters_dir': '/root/generation_with_augmentation/parameters/chinese_pretrain.pt', 236 | 'last_checkpoint_path': '/root/generation_with_augmentation/checkpoints/meme_all/last_checkpoint', 237 | 'interrupt_checkpoint_path': '/root/generation_with_augmentation/checkpoints/meme_all/interrupt_checkpoint', 238 | 'train_datasets': ['/root/generation_with_augmentation/dataset/meme/train_all.txt'], 239 | 'test_datasets': ['/root/generation_with_augmentation/dataset/meme/valid_all.txt']}) 240 | return config 241 | 242 | 243 | def get_test_config_meme(): 244 | config = AttrDict({'seed': 0, 245 | 'device': 'cuda', 246 | 'load_last': True, 247 | 'openai_parameters_dir': '/root/generation_with_augmentation/parameters/chinese_pretrain.pt', 248 | 'last_checkpoint_path': '/root/generation_with_augmentation/checkpoints/meme_all/last_checkpoint19'}) 249 | return config -------------------------------------------------------------------------------- /code/retrieval/pytorch_pretrained_bert/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
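# Quick usage sketch (illustrative only; `text` is a placeholder and the model name is one of the
# entries in PRETRAINED_VOCAB_ARCHIVE_MAP below):
#
#     tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
#     tokens = tokenizer.tokenize(text)
#     ids = tokenizer.convert_tokens_to_ids(tokens)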
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import collections 20 | import logging 21 | import os 22 | import unicodedata 23 | from io import open 24 | 25 | from .file_utils import cached_path 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 30 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 31 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 32 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 33 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 34 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 35 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 36 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 37 | } 38 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 39 | 'bert-base-uncased': 512, 40 | 'bert-large-uncased': 512, 41 | 'bert-base-cased': 512, 42 | 'bert-large-cased': 512, 43 | 'bert-base-multilingual-uncased': 512, 44 | 'bert-base-multilingual-cased': 512, 45 | 'bert-base-chinese': 512, 46 | } 47 | VOCAB_NAME = 'vocab.txt' 48 | 49 | 50 | def load_vocab(vocab_file): 51 | """Loads a vocabulary file into a dictionary.""" 52 | vocab = collections.OrderedDict() 53 | index = 0 54 | with open(vocab_file, "r", encoding="utf-8") as reader: 55 | while True: 56 | token = reader.readline() 57 | if not token: 58 | break 59 | token = token.strip() 60 | vocab[token] = index 61 | index += 1 62 | return vocab 63 | 64 | 65 | def whitespace_tokenize(text): 66 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 67 | text = text.strip() 68 | if not text: 69 | return [] 70 | tokens = text.split() 71 | return tokens 72 | 73 | 74 | class BertTokenizer(object): 75 | """Runs end-to-end tokenization: punctuation splitting + wordpiece""" 76 | 77 | def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, 78 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 79 | """Constructs a BertTokenizer. 80 | 81 | Args: 82 | vocab_file: Path to a one-wordpiece-per-line vocabulary file 83 | do_lower_case: Whether to lower case the input 84 | Only has an effect when do_wordpiece_only=False 85 | do_basic_tokenize: Whether to do basic tokenization before wordpiece. 86 | max_len: An artificial maximum length to truncate tokenized sequences to; 87 | Effective maximum length is always the minimum of this 88 | value (if specified) and the underlying BERT model's 89 | sequence length. 90 | never_split: List of tokens which will never be split during tokenization. 91 | Only has an effect when do_wordpiece_only=False 92 | """ 93 | if not os.path.isfile(vocab_file): 94 | raise ValueError( 95 | "Can't find a vocabulary file at path '{}'. 
To load the vocabulary from a Google pretrained " 96 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) 97 | self.vocab = load_vocab(vocab_file) 98 | self.ids_to_tokens = collections.OrderedDict( 99 | [(ids, tok) for tok, ids in self.vocab.items()]) 100 | self.do_basic_tokenize = do_basic_tokenize 101 | if do_basic_tokenize: 102 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, 103 | never_split=never_split) 104 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 105 | self.max_len = max_len if max_len is not None else int(1e12) 106 | 107 | def tokenize(self, text): 108 | if self.do_basic_tokenize: 109 | split_tokens = [] 110 | for token in self.basic_tokenizer.tokenize(text): 111 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 112 | split_tokens.append(sub_token) 113 | else: 114 | split_tokens = self.wordpiece_tokenizer.tokenize(text) 115 | return split_tokens 116 | 117 | def convert_tokens_to_ids(self, tokens): 118 | """Converts a sequence of tokens into ids using the vocab.""" 119 | ids = [] 120 | for token in tokens: 121 | ids.append(self.vocab[token]) 122 | if len(ids) > self.max_len: 123 | logger.warning( 124 | "Token indices sequence length is longer than the specified maximum " 125 | " sequence length for this BERT model ({} > {}). Running this" 126 | " sequence through BERT will result in indexing errors".format(len(ids), self.max_len) 127 | ) 128 | return ids 129 | 130 | def convert_ids_to_tokens(self, ids): 131 | """Converts a sequence of ids in wordpiece tokens using the vocab.""" 132 | tokens = [] 133 | for i in ids: 134 | tokens.append(self.ids_to_tokens[i]) 135 | return tokens 136 | 137 | @classmethod 138 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 139 | """ 140 | Instantiate a PreTrainedBertModel from a pre-trained model file. 141 | Download and cache the pre-trained model file if needed. 142 | """ 143 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 144 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 145 | else: 146 | vocab_file = pretrained_model_name_or_path 147 | if os.path.isdir(vocab_file): 148 | vocab_file = os.path.join(vocab_file, VOCAB_NAME) 149 | # redirect to the cache, if necessary 150 | try: 151 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 152 | except EnvironmentError: 153 | logger.error( 154 | "Model name '{}' was not found in model name list ({}). " 155 | "We assumed '{}' was a path or url but couldn't find any file " 156 | "associated to this path or url.".format( 157 | pretrained_model_name_or_path, 158 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 159 | vocab_file)) 160 | return None 161 | if resolved_vocab_file == vocab_file: 162 | logger.info("loading vocabulary file {}".format(vocab_file)) 163 | else: 164 | logger.info("loading vocabulary file {} from cache at {}".format( 165 | vocab_file, resolved_vocab_file)) 166 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 167 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 168 | # than the number of positional embeddings 169 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] 170 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 171 | # Instantiate tokenizer. 
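        # At this point `resolved_vocab_file` is a local path: the cached download when a known
        # model name (e.g. 'bert-base-chinese') was given, `<dir>/vocab.txt` when a directory was
        # given, or the user-supplied file path itself; `max_len` has also been capped by the
        # model's positional-embedding size for known pretrained names.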
172 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) 173 | return tokenizer 174 | 175 | 176 | class BasicTokenizer(object): 177 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 178 | 179 | def __init__(self, 180 | do_lower_case=True, 181 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 182 | """Constructs a BasicTokenizer. 183 | 184 | Args: 185 | do_lower_case: Whether to lower case the input. 186 | """ 187 | self.do_lower_case = do_lower_case 188 | self.never_split = never_split 189 | 190 | def tokenize(self, text): 191 | """Tokenizes a piece of text.""" 192 | text = self._clean_text(text) 193 | # This was added on November 1st, 2018 for the multilingual and Chinese 194 | # models. This is also applied to the English models now, but it doesn't 195 | # matter since the English models were not trained on any Chinese data 196 | # and generally don't have any Chinese data in them (there are Chinese 197 | # characters in the vocabulary because Wikipedia does have some Chinese 198 | # words in the English Wikipedia.). 199 | text = self._tokenize_chinese_chars(text) 200 | orig_tokens = whitespace_tokenize(text) 201 | split_tokens = [] 202 | for token in orig_tokens: 203 | if self.do_lower_case and token not in self.never_split: 204 | token = token.lower() 205 | token = self._run_strip_accents(token) 206 | split_tokens.extend(self._run_split_on_punc(token)) 207 | 208 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 209 | return output_tokens 210 | 211 | def _run_strip_accents(self, text): 212 | """Strips accents from a piece of text.""" 213 | text = unicodedata.normalize("NFD", text) 214 | output = [] 215 | for char in text: 216 | cat = unicodedata.category(char) 217 | if cat == "Mn": 218 | continue 219 | output.append(char) 220 | return "".join(output) 221 | 222 | def _run_split_on_punc(self, text): 223 | """Splits punctuation on a piece of text.""" 224 | if text in self.never_split: 225 | return [text] 226 | chars = list(text) 227 | i = 0 228 | start_new_word = True 229 | output = [] 230 | while i < len(chars): 231 | char = chars[i] 232 | if _is_punctuation(char): 233 | output.append([char]) 234 | start_new_word = True 235 | else: 236 | if start_new_word: 237 | output.append([]) 238 | start_new_word = False 239 | output[-1].append(char) 240 | i += 1 241 | 242 | return ["".join(x) for x in output] 243 | 244 | def _tokenize_chinese_chars(self, text): 245 | """Adds whitespace around any CJK character.""" 246 | output = [] 247 | for char in text: 248 | cp = ord(char) 249 | if self._is_chinese_char(cp): 250 | output.append(" ") 251 | output.append(char) 252 | output.append(" ") 253 | else: 254 | output.append(char) 255 | return "".join(output) 256 | 257 | def _is_chinese_char(self, cp): 258 | """Checks whether CP is the codepoint of a CJK character.""" 259 | # This defines a "chinese character" as anything in the CJK Unicode block: 260 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 261 | # 262 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 263 | # despite its name. The modern Korean Hangul alphabet is a different block, 264 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 265 | # space-separated words, so they are not treated specially and handled 266 | # like the all of the other languages. 
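        # For example, ord(u"中") == 0x4E2D falls inside the first (CJK Unified Ideographs) range
        # below and therefore gets surrounding spaces from _tokenize_chinese_chars, while Hiragana,
        # Katakana and Hangul code points fall outside every listed range and are left to the
        # normal whitespace/punctuation handling.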
267 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 268 | (cp >= 0x3400 and cp <= 0x4DBF) or # 269 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 270 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 271 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 272 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 273 | (cp >= 0xF900 and cp <= 0xFAFF) or # 274 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 275 | return True 276 | 277 | return False 278 | 279 | def _clean_text(self, text): 280 | """Performs invalid character removal and whitespace cleanup on text.""" 281 | output = [] 282 | for char in text: 283 | cp = ord(char) 284 | if cp == 0 or cp == 0xfffd or _is_control(char): 285 | continue 286 | if _is_whitespace(char): 287 | output.append(" ") 288 | else: 289 | output.append(char) 290 | return "".join(output) 291 | 292 | 293 | class WordpieceTokenizer(object): 294 | """Runs WordPiece tokenization.""" 295 | 296 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 297 | self.vocab = vocab 298 | self.unk_token = unk_token 299 | self.max_input_chars_per_word = max_input_chars_per_word 300 | 301 | def tokenize(self, text): 302 | """Tokenizes a piece of text into its word pieces. 303 | 304 | This uses a greedy longest-match-first algorithm to perform tokenization 305 | using the given vocabulary. 306 | 307 | For example: 308 | input = "unaffable" 309 | output = ["un", "##aff", "##able"] 310 | 311 | Args: 312 | text: A single token or whitespace separated tokens. This should have 313 | already been passed through `BasicTokenizer`. 314 | 315 | Returns: 316 | A list of wordpiece tokens. 317 | """ 318 | 319 | output_tokens = [] 320 | for token in whitespace_tokenize(text): 321 | chars = list(token) 322 | if len(chars) > self.max_input_chars_per_word: 323 | output_tokens.append(self.unk_token) 324 | continue 325 | 326 | is_bad = False 327 | start = 0 328 | sub_tokens = [] 329 | while start < len(chars): 330 | end = len(chars) 331 | cur_substr = None 332 | while start < end: 333 | substr = "".join(chars[start:end]) 334 | if start > 0: 335 | substr = "##" + substr 336 | if substr in self.vocab: 337 | cur_substr = substr 338 | break 339 | end -= 1 340 | if cur_substr is None: 341 | is_bad = True 342 | break 343 | sub_tokens.append(cur_substr) 344 | start = end 345 | 346 | if is_bad: 347 | output_tokens.append(self.unk_token) 348 | else: 349 | output_tokens.extend(sub_tokens) 350 | return output_tokens 351 | 352 | 353 | def _is_whitespace(char): 354 | """Checks whether `chars` is a whitespace character.""" 355 | # \t, \n, and \r are technically contorl characters but we treat them 356 | # as whitespace since they are generally considered as such. 357 | if char == " " or char == "\t" or char == "\n" or char == "\r": 358 | return True 359 | cat = unicodedata.category(char) 360 | if cat == "Zs": 361 | return True 362 | return False 363 | 364 | 365 | def _is_control(char): 366 | """Checks whether `chars` is a control character.""" 367 | # These are technically control characters but we count them as whitespace 368 | # characters. 369 | if char == "\t" or char == "\n" or char == "\r": 370 | return False 371 | cat = unicodedata.category(char) 372 | if cat.startswith("C"): 373 | return True 374 | return False 375 | 376 | 377 | def _is_punctuation(char): 378 | """Checks whether `chars` is a punctuation character.""" 379 | cp = ord(char) 380 | # We treat all non-letter/number ASCII as punctuation. 
381 | # Characters such as "^", "$", and "`" are not in the Unicode 382 | # Punctuation class but we treat them as punctuation anyways, for 383 | # consistency. 384 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 385 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 386 | return True 387 | cat = unicodedata.category(char) 388 | if cat.startswith("P"): 389 | return True 390 | return False 391 | --------------------------------------------------------------------------------