├── README.md
├── block.PNG
├── config.py
├── datasets
│   ├── ConvAI2
│   │   └── dummy.txt
│   └── dummy_cache_gpt2
├── gpt2-small
│   ├── bpe.code
│   ├── bpe.vocab
│   ├── config.json
│   ├── merges.txt
│   ├── special_tokens.txt
│   └── vocab.json
├── inference.py
├── metrics
│   ├── mteval-v14c.pl
│   └── multi-bleu.perl
├── model
│   ├── __init__.py
│   ├── common_layer.py
│   ├── dataset.py
│   ├── gpt2_model.py
│   ├── loss.py
│   ├── openai_model.py
│   ├── optim.py
│   ├── postprocessing.py
│   ├── seq2seq.py
│   ├── seq2seq_vocab.py
│   ├── trainer.py
│   ├── transformer_model.py
│   ├── transformer_module.py
│   └── utils.py
├── new_metrics.py
├── openai-gpt
│   ├── config.json
│   ├── merges.txt
│   ├── special_tokens.txt
│   └── vocab.json
├── requirements.txt
└── train.py

/README.md:
--------------------------------------------------------------------------------
1 | # Multi-GPT2
2 | The implementation of the EMNLP 2020 Findings paper
3 | 
4 | **Pretrained Language Models for Dialogue Generation with Multiple Input Sources**
5 | 
6 | ![Block architecture](https://github.com/caoyu-noob/Multi-GPT2/blob/main/block.PNG)
7 | 
8 | ## Requirements
9 | 1. Python 3.7
10 | 2. PyTorch==1.1.0
11 | 3. transformers==2.5.1
12 | 4. tensorboardX==2.0
13 | 5. git-python
14 | 6. tqdm
15 | 
16 | Some other dependencies may be needed; please take requirements.txt as a reference.
17 | 
18 | We run our standard experiments on one 32GB V100 GPU. If you use a GPU with less memory, please increase
19 | `batch_split` or decrease `train_batch_size` as defined in `config.py`.
20 | 
21 | To obtain the automatic metrics, you also need to install `java` 1.8.0, `perl`, and the related Perl libraries
22 | `XML::Twig`, `Sort::Naturally` and `String::Util` (we use `cpanm` to install them on Linux).
23 | 
24 | ## How to run
25 | 
26 | #### Download pretrained models
27 | You need to download the small-size [GPT2 model](https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin)
28 | or [OpenAI GPT model](https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin), rename the file as
29 | `pytorch_model.bin`, and put it under the `gpt2-small` or `openai-gpt` folder, respectively.
30 | 
31 | #### Download datasets
32 | You can download the PersonaChat dataset from [ParlAI](https://github.com/facebookresearch/ParlAI) or use
33 | [our zipped version](https://drive.google.com/file/d/1zQVO5MuEy3wBUfZpM39uYmloD3-Rld4T/view?usp=sharing)
34 | (`train_self_original.txt` contains ~65k training samples while `train_both_original.txt` contains ~131k training samples),
35 | then put these raw txt files under the `datasets/ConvAI2` folder.
36 | 
37 | We also provide a dummy sample file `dummy.txt` under this folder which can be used for testing.
38 | 
39 | #### Run the experiment
40 | Run the experiment on the dummy data using the default mean attention fusion and the GPT2-based encoder-decoder model. It
41 | verifies whether your environment is ready for running the experiments.
42 | ```
43 | python train.py \
44 | --train_datasets datasets/ConvAI2/dummy.txt \
45 | --valid_datasets datasets/ConvAI2/dummy.txt \
46 | --test_datasets datasets/ConvAI2/dummy.txt \
47 | --train_datasets_cache datasets/dummy_cache_gpt2 \
48 | --valid_datasets_cache datasets/dummy_cache_gpt2 \
49 | --test_datasets_cache datasets/dummy_cache_gpt2 \
50 | --model_type gpt2
51 | ```
52 | 
53 | `model_type` can be changed to `gpt2`, `gpt` or `seq2seq`.
54 | 
55 | Run the experiment on the PersonaChat dataset using source-weight attention fusion and the GPT2-based encoder-decoder model.
56 | ```
57 | python train.py \
58 | --train_datasets datasets/ConvAI2/train_self_original.txt \
59 | --valid_datasets datasets/ConvAI2/valid_self_original.txt \
60 | --test_datasets datasets/ConvAI2/test_self_original.txt \
61 | --train_datasets_cache datasets/train_cache_gpt2 \
62 | --valid_datasets_cache datasets/valid_cache_gpt2 \
63 | --test_datasets_cache datasets/test_cache_gpt2 \
64 | --model_type gpt2 \
65 | --attention_fusion_type sw \
66 | --lr 5e-4 \
67 | --extra_module_lr_rate 5.0 \
68 | --shared_module 0 \
69 | --shared_attention 0 \
70 | --max_history_size 9
71 | ```
72 | 
73 | `train_datasets_cache` is only created once, and it differs between base models.
74 | `attention_fusion_type` indicates how the attention outputs from different sources are fused (an illustrative sketch is given at the end of this README).
75 | `extra_module_lr_rate` is the factor by which the learning rate of the extra module (attention fusion) is multiplied relative to the pretrained modules.
76 | 
77 | Run the experiment on PersonaChat using the Transformer-based Seq2seq model. `single_input` indicates that the different
78 | information sources are concatenated into a single input, as in the `SI-` models mentioned in the paper.
79 | ```
80 | python train.py \
81 | --train_datasets datasets/ConvAI2/train_self_original.txt \
82 | --valid_datasets datasets/ConvAI2/valid_self_original.txt \
83 | --test_datasets datasets/ConvAI2/test_self_original.txt \
84 | --train_datasets_cache datasets/train_cache_gpt2 \
85 | --valid_datasets_cache datasets/valid_cache_gpt2 \
86 | --test_datasets_cache datasets/test_cache_gpt2 \
87 | --model_type seq2seq \
88 | --pointer_gen \
89 | --single_input \
90 | --n_epochs 50
91 | ```
92 | 
93 | We also provide an inference script in this repo for running inference with existing models. Here is an example.
94 | ```
95 | python inference.py \
96 | --valid_datasets datasets/ConvAI2/valid_self_original.txt \
97 | --valid_datasets_cache datasets/valid_cache_gpt2 \
98 | --test_datasets datasets/ConvAI2/test_self_original.txt \
99 | --test_datasets_cache datasets/test_cache_gpt2 \
100 | --model_type gpt2 \
101 | --load_last ./test/best_model \
102 | --inference_mode sampling \
103 | --response_k 3
104 | ```
105 | The default inference method is beam search; pass `--inference_mode sampling` (as in the example above) to use top-k sampling instead.
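For reference, the snippet below is a minimal sketch of how source-level weight (`sw`) fusion can combine the attention outputs computed over different sources (e.g. persona and dialogue history) inside a decoder block. It is an illustration only, not the repository's exact implementation (see the `model/` package for that); the class name, tensor shapes and the softmax normalization are assumptions made for the example.
```
import torch
import torch.nn as nn

class SourceWeightFusion(nn.Module):
    """Illustrative source-level weight ('sw') fusion: one learnable scalar per source."""
    def __init__(self, n_sources):
        super().__init__()
        # learnable per-source logits, normalized with a softmax at every forward pass
        self.source_weights = nn.Parameter(torch.zeros(n_sources))

    def forward(self, attn_outputs):
        # attn_outputs: list of tensors, one per source, each of shape (batch, seq_len, hidden)
        stacked = torch.stack(attn_outputs, dim=0)            # (n_sources, batch, seq_len, hidden)
        weights = torch.softmax(self.source_weights, dim=0)   # (n_sources,)
        return (weights.view(-1, 1, 1, 1) * stacked).sum(dim=0)
```
At initialization the softmax yields uniform weights, i.e. the same result as `mean` fusion; since such fusion parameters are newly introduced rather than pretrained, `--extra_module_lr_rate` lets them train with a larger learning rate than the pretrained GPT2 weights.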
106 | -------------------------------------------------------------------------------- /block.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoyu-noob/Multi-GPT2/3fccdd55a5286427558f7669e22b9fa3d710c2e9/block.PNG -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from attrdict import AttrDict 2 | from copy import deepcopy 3 | import torch 4 | from model.utils import openai_transformer_config 5 | import git 6 | import argparse 7 | 8 | 9 | repo = git.Repo(search_parent_directories=True) 10 | 11 | 12 | def cast2(type_): 13 | return lambda val: val if val is None else type_(val) 14 | 15 | 16 | def get_model_config(args): 17 | default_config = openai_transformer_config() 18 | config = AttrDict({'bpe_vocab_path': './parameters/bpe.vocab', 19 | 'bpe_codes_path': './parameters/bpe.code', 20 | 'checkpoint_path': './checkpoints/last_checkpoint', # Keep the checpoint folder for the checkpoints of the agents 21 | 'n_layers': default_config.n_layers, 22 | 'n_pos_embeddings': 512, 23 | 'embeddings_size': default_config.embeddings_size, 24 | 'n_heads': default_config.n_heads, 25 | 'dropout': default_config.dropout, 26 | 'embed_dropout': default_config.embed_dropout, 27 | 'attn_dropout': default_config.attn_dropout, 28 | 'ff_dropout': default_config.ff_dropout, 29 | 'normalize_embeddings': args.normalize_embeddings, 30 | 'max_seq_len': 128, 31 | 'beam_size': args.beam_size, 32 | 'diversity_coef': args.diversity_coef, 33 | 'diversity_groups': args.diversity_groups, 34 | 'annealing_topk': args.annealing_topk, 35 | 'annealing': args.annealing, 36 | 'length_penalty': args.length_penalty, 37 | 'n_segments': None, 38 | 'constant_embedding': args.constant_embedding, 39 | 'multiple_choice_head': args.multiple_choice_head, 40 | 'share_models': True, 41 | 'successive_attention': args.successive_attention, 42 | 'sparse_embeddings': args.sparse_embeddings, 43 | 'shared_attention': args.shared_attention, 44 | 'dialog_embeddings': args.dialog_embeddings, 45 | 'single_input': args.single_input, 46 | 'use_start_end': args.use_start_end, 47 | 'apex_level': args.apex_level, # 'O0', 'O1', 'O2', 'O3', 48 | 'bs_temperature': args.bs_temperature, 49 | 'bs_nucleus_p': args.bs_nucleus_p, 50 | 'same_embedding_lm': args.same_embedding_lm, 51 | }) 52 | 53 | return config 54 | 55 | 56 | def get_trainer_config(args): 57 | config = AttrDict({'n_epochs': args.n_epochs, 58 | 'writer_comment': args.writer_comment, 59 | 'train_batch_size': args.train_batch_size, 60 | 'batch_split': args.batch_split, 61 | 'test_batch_size': args.test_batch_size, 62 | 'lr': args.lr, 63 | 'lr_warmup': args.lr_warmup, # a fraction of total training (epoch * train_set_length) if linear_schedule == True 64 | 'weight_decay': 0.01, 65 | 's2s_weight': args.s2s_weight, 66 | 'lm_weight': args.lm_weight, 67 | 'risk_weight': args.risk_weight, 68 | 'hits_weight': args.hits_weight, 69 | 'negative_samples': args.negative_samples, 70 | 'n_jobs': 4, 71 | 'label_smoothing': args.label_smoothing, 72 | 'clip_grad': args.clip_grad, 73 | 'test_period': 1, 74 | 'seed': args.seed, 75 | 'device': 'cuda', 76 | 'persona_augment': args.persona_augment, 77 | 'persona_aug_syn_proba': args.persona_aug_syn_proba, 78 | 'apex_loss_scale': args.apex_loss_scale, # e.g. 
'128', 'dynamic' 79 | 'linear_schedule': args.linear_schedule, 80 | 'evaluate_full_sequences': args.evaluate_full_sequences, 81 | 'limit_eval_size': args.limit_eval_size, 82 | 'limit_train_size': args.limit_train_size, 83 | 'risk_metric': args.risk_metric, 84 | 'load_last': args.load_last, #./checkpoints/last_checkpoint', # Now that we save several experiments you can put the path of the checpoint file you want to load here 85 | 'repo_id': str(repo), 86 | 'repo_sha': str(repo.head.object.hexsha), 87 | 'repo_branch': str(repo.active_branch), 88 | 'openai_parameters_dir': './parameters', 89 | 'last_checkpoint_path': 'last_checkpoint', # there are now in the ./runs/XXX/ experiments folders 90 | 'eval_references_file': 'eval_references_file', 91 | 'eval_predictions_file': 'eval_predictions_file', 92 | 'test_references_file': 'test_references_file', 93 | 'test_predictions_file_best': 'test_predictions_file_best', 94 | 'test_predictions_file_last': 'test_predictions_file_last', 95 | 'interrupt_checkpoint_path': 'interrupt_checkpoint', # there are now in the ./runs/XXX/ experiments folders 96 | 'train_datasets': args.train_datasets, 97 | 'train_datasets_cache': args.train_datasets_cache, 98 | 'test_datasets': args.test_datasets, 99 | 'test_datasets_cache': args.test_datasets_cache, 100 | 'valid_datasets': args.valid_datasets, 101 | 'valid_datasets_cache': args.valid_datasets_cache, 102 | 'full_input': args.full_input, 103 | 'single_input': args.single_input, 104 | 'max_history_size': args.max_history_size, 105 | 'model_saving_interval': args.model_saving_interval, 106 | 'patience': args.patience, 107 | 'data_type': args.data_type, 108 | 'ignore_train_indices': None, 109 | }) 110 | 111 | local_config = deepcopy(config) 112 | local_config.train_batch_size = 16 113 | local_config.batch_split = 2 114 | local_config.test_batch_size = 4 115 | local_config.n_jobs = 0 116 | local_config.device = 'cpu' 117 | local_config.risk_weight = 0 118 | local_config.zero_shot = False 119 | local_config.fp16 = False 120 | # local_config.train_datasets_cache = './datasets/train_datasets_cache.bin' 121 | # local_config.test_datasets_cache = './datasets/test_datasets_cache.bin' 122 | 123 | return config if torch.cuda.is_available() else local_config 124 | 125 | class InputConfig(): 126 | def __init__(self): 127 | parser = argparse.ArgumentParser() 128 | parser.add_argument('--seed', type=int, default=0) 129 | parser.add_argument('--normalize_embeddings', action='store_true') 130 | parser.add_argument('--beam_size', default=3, type=int) 131 | parser.add_argument('--inference_mode', default='beam', type=str) 132 | parser.add_argument('--response_k', default=1, type=int) 133 | parser.add_argument('--diversity_coef', default=0, type=int) 134 | parser.add_argument('--lr', default=6.25e-5, type=float) 135 | parser.add_argument('--lr_warmup', default=0.002, type=float) 136 | parser.add_argument('--clip_grad', default=None, type=float) 137 | parser.add_argument('--diversity_groups', default=1, type=int) 138 | parser.add_argument('--annealing_topk', default=None, type=int) 139 | parser.add_argument('--annealing', default=0, type=float) 140 | parser.add_argument('--length_penalty', default=0.6, type=float) 141 | parser.add_argument('--bs_temperature', default=1, type=float) 142 | parser.add_argument('--bs_nucleus_p', default=0, type=float) 143 | parser.add_argument('--apex_level', default=None, type=str) 144 | parser.add_argument('--constant_embedding', action='store_true') 145 | parser.add_argument('--multiple_choice_head', 
action='store_true') 146 | parser.add_argument('--successive_attention', action='store_true') 147 | parser.add_argument('--sparse_embeddings', default=False, type=bool) 148 | parser.add_argument('--dialog_embeddings', default=True, type=bool) 149 | parser.add_argument('--single_input', action='store_true') 150 | parser.add_argument('--no_persona', action='store_true') 151 | parser.add_argument('--use_start_end', action='store_true') 152 | parser.add_argument('--persona_augment', action='store_true') 153 | parser.add_argument('--linear_schedule', default=True, type=bool) 154 | parser.add_argument('--evaluate_full_sequences', default=True, type=bool) 155 | parser.add_argument('--n_epochs', default=5, type=int) 156 | parser.add_argument('--patience', default=-1, type=int, help="the training patience if the dev result " 157 | "does not promote then training ends") 158 | parser.add_argument('--train_batch_size', default=256, type=int) 159 | parser.add_argument('--batch_split', default=32, type=int) 160 | parser.add_argument('--test_batch_size', default=8, type=int) 161 | parser.add_argument('--writer_comment', default='', type=str) 162 | parser.add_argument('--s2s_weight', default=2, type=float) 163 | parser.add_argument('--lm_weight', default=1, type=float) 164 | parser.add_argument('--risk_weight', default=0, type=float) 165 | parser.add_argument('--hits_weight', default=0, type=float) 166 | parser.add_argument('--label_smoothing', default=-1, type=float, 167 | help='Config for Seq2Seq model, whether use label smoothing loss, -1 means no smoothing') 168 | parser.add_argument('--negative_samples', default=0, type=int) 169 | parser.add_argument('--persona_aug_syn_proba', default=0, type=float) 170 | parser.add_argument('--apex_loss_scale', default=None, type=str) 171 | parser.add_argument('--limit_eval_size', default=-1, type=int) 172 | parser.add_argument('--limit_train_size', default=-1, type=int) 173 | parser.add_argument('--risk_metric', default='f1', type=str) 174 | parser.add_argument('--load_last', default='', type=str) 175 | parser.add_argument('--load_alpha_last', default='', type=str) 176 | parser.add_argument('--data_type', default='persona', type=str, help='data set types, persona/emoji/daily') 177 | parser.add_argument('--test_data_type', default=None, type=str, help='data set types, persona/emoji/daily') 178 | parser.add_argument('--emb_dim', default=300, type=int, help='Config for Seq2Seq model') 179 | parser.add_argument('--hidden_dim', default=300, type=int, help='Config for Seq2Seq model') 180 | parser.add_argument('--num_layers', default=6, type=int, help='Config for Seq2Seq model') 181 | parser.add_argument('--heads', default=4, type=int, help='Config for Seq2Seq model') 182 | parser.add_argument('--depth_size', default=40, type=int, help='Config for Seq2Seq model') 183 | parser.add_argument('--filter_size', default=50, type=int, help='Config for Seq2Seq model') 184 | parser.add_argument('--pointer_gen', action='store_true', help='Config for Seq2Seq model') 185 | parser.add_argument('--pretrained_emb_file', default='./glove/glove.6B.300d.txt', type=str) 186 | parser.add_argument('--vocab_path', default='./datasets/persona_vocab.bin', type=str) 187 | parser.add_argument('--extend_exist_vocab', default=None, type=str) 188 | parser.add_argument('--train_datasets', default='datasets/ConvAI2/train_self_original.txt', type=str) 189 | parser.add_argument('--valid_datasets', default='datasets/ConvAI2/valid_self_original.txt', type=str) 190 | parser.add_argument('--test_datasets', 
default='datasets/ConvAI2/test_self_original.txt', type=str) 191 | parser.add_argument('--cache_vocab_path', default='datasets/ConvAI2/cached_vocab.pickle', type=str) 192 | parser.add_argument('--train_datasets_cache', default='datasets/train_cache.bin', type=str) 193 | parser.add_argument('--valid_datasets_cache', default='datasets/valid_cache.bin', type=str) 194 | parser.add_argument('--test_datasets_cache', default='datasets/test_cache.bin', type=str) 195 | parser.add_argument('--full_input', action='store_true', help='whether use the concatenated persona, history' 196 | ' and reply as the input ids') 197 | parser.add_argument('--max_history_size', type=int, default=-1, help='max history size in input ids') 198 | parser.add_argument('--same_embedding_lm', type=int, default=1, help='the embedding in transformer and the ' 199 | 'weight in the lm are the same') 200 | parser.add_argument('--uncertainty_loss', action='store_true', help='whether use uncertainty loss') 201 | parser.add_argument('--model_type', type=str, default='gpt2', help='gpt/gpt2/se2seq/rnn-seq2seq') 202 | parser.add_argument('--model_saving_interval', type=int, default=10, help='model saving interval for seq2seq') 203 | parser.add_argument('--shared_module', type=int, default=1) 204 | parser.add_argument('--shared_attention', type=int, default=1) 205 | parser.add_argument('--attention_fusion_type', type=str, default='mean', help='the method to pool attention ' 206 | 'output from different source(mean/min/max/sw/dw/linear/att) ' 207 | 'sw=source level weight, dw=dimension level weight, linear=linear transform for concatenating,' 208 | 'att=extra transformer attention layer to fuse attention output,' 209 | 'dys=dynamic determine the scalar weight for each source by a linear layer' 210 | 'dyd=dynamic determine the vector weight for each dimension for each source by a linear layer' 211 | 'mdys=mutual determine the scalar weight for each source by a linear layer' 212 | 'mdyd=mutual dynamic determine the vector weight for each dimension for each source by a linear layer') 213 | parser.add_argument('--extra_module_lr_rate', type=float, default=1.0, help='The lr mulitply rate for extra module, usually it needs a higher learning rate') 214 | parser.add_argument('--local_rank', type=int, default=-1, help="Distributed training.") 215 | parser.add_argument('--server_ip', type=str, default='', help="Used for debugging on GPU machine.") 216 | parser.add_argument('--server_port', type=str, default='', help="Used for debugging on GPU machine.") 217 | 218 | self.args = parser.parse_args() 219 | -------------------------------------------------------------------------------- /datasets/dummy_cache_gpt2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoyu-noob/Multi-GPT2/3fccdd55a5286427558f7669e22b9fa3d710c2e9/datasets/dummy_cache_gpt2 -------------------------------------------------------------------------------- /gpt2-small/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "GPT2LMHeadModel" 4 | ], 5 | "initializer_range": 0.02, 6 | "layer_norm_epsilon": 1e-05, 7 | "n_ctx": 1024, 8 | "n_embd": 768, 9 | "n_head": 12, 10 | "n_layer": 12, 11 | "n_positions": 1024, 12 | "vocab_size": 50257, 13 | "shared_module": true, 14 | "shared_attention": false, 15 | "context_size": 2 16 | } 17 | -------------------------------------------------------------------------------- /gpt2-small/special_tokens.txt: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import torch 5 | import torch.nn as nn 6 | from transformers.tokenization_gpt2 import GPT2Tokenizer 7 | from transformers.tokenization_openai import OpenAIGPTTokenizer 8 | 9 | from config import get_trainer_config 10 | from config import InputConfig 11 | from model.dataset import FacebookDataset 12 | from model.gpt2_model import GPT2DoubleHeadsModel 13 | from model.gpt2_model import GPT2EncoderDecoderModel 14 | from model.openai_model import OpenAIGPTEncoderDecoderModel 15 | from model.seq2seq import TransformerSeq2Seq 16 | from model.seq2seq_vocab import Seq2seqVocab 17 | from model.trainer import Trainer 18 | from model.utils import config_logger 19 | from model.utils import f1_score 20 | from model.utils import open 21 | from model.utils import set_seed 22 | from new_metrics import nlp_metrics 23 | 24 | PADDING_IDX = 0 25 | 26 | def modify_tokenizer(tokenizer, data_type): 27 | additional_special_tokens = ['', '', '', '', '', 28 | ''] 29 | if data_type == 'emoji': 30 | with open('datasets/emoji_talk/emojis.json', 'r') as f: 31 | emojis = json.load(f)['emojis'] 32 | additional_special_tokens.extend(emojis) 33 | tokenizer.add_special_tokens({'pad_token': '', 'bos_token': '', 'eos_token': '', 34 | 'additional_special_tokens': additional_special_tokens}) 35 | tokenizer.eos_id, tokenizer.bos_id, tokenizer.pad_id = tokenizer.eos_token_id, tokenizer.bos_token_id, tokenizer.pad_token_id 36 | tokenizer.sent_dialog_id = tokenizer.bos_token_id 37 | tokenizer.info_dialog_id, tokenizer.info_bos_id = tokenizer.added_tokens_encoder[''], \ 38 | tokenizer.added_tokens_encoder[ 39 | ''] 40 | tokenizer.info_eos_id = tokenizer.added_tokens_encoder[''] 41 | tokenizer.talker1_dialog_id, tokenizer.talker1_bos_id = tokenizer.added_tokens_encoder[''], \ 42 | tokenizer.added_tokens_encoder[''] 43 | tokenizer.talker1_eos_id = tokenizer.added_tokens_encoder[''] 44 | tokenizer.talker2_dialog_id, tokenizer.talker2_bos_id = tokenizer.added_tokens_encoder[''], \ 45 | tokenizer.added_tokens_encoder[''] 46 | tokenizer.talker2_eos_id = tokenizer.added_tokens_encoder[''] 47 | return tokenizer, len(additional_special_tokens) + 3 48 | 49 | def pad_sequence(sequences, batch_first=False, padding_value=0, left=False): 50 | # assuming trailing dimensions and type of all the Tensors 51 | # in sequences are same and fetching those from sequences[0] 52 | if not len(sequences): 53 | return torch.empty(0) 54 | trailing_dims = sequences[0].size()[1:] 55 | max_len = max([s.size(0) for s in sequences]) 56 | if batch_first: 57 | out_dims = (len(sequences), max_len) + trailing_dims 58 | else: 59 | out_dims = (max_len, len(sequences)) + trailing_dims 60 | 61 | out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value) 62 | for i, tensor in enumerate(sequences): 63 | length = tensor.size(0) 64 | s_slice = slice(-length, None) if left else slice(None, length) 65 | s_slice = (i, s_slice) if batch_first else (s_slice, i) 66 | out_tensor[s_slice] = tensor 67 | 68 | return out_tensor 69 | 70 | def collate_func(data): 71 | persona_info, h, y, distractors_batch = zip(*data) 72 | 73 | contexts = [] 74 | 75 | if max(map(len, persona_info)) > 0: 76 | persona_info = [torch.tensor(d, dtype=torch.long) for d in persona_info] 77 | 
contexts.append(persona_info) 78 | 79 | if max(map(len, h)) > 0: 80 | h = [torch.tensor(d, dtype=torch.long) for d in h] 81 | contexts.append(h) 82 | 83 | y_out = [torch.tensor(d, dtype=torch.long) for d in y] 84 | 85 | distractors = [torch.tensor(d, dtype=torch.long) for distractors in distractors_batch for d in distractors] 86 | 87 | # Pad now so we pad correctly when we have only a single input (context concatenated with y) 88 | y_out = pad_sequence(y_out, batch_first=True, padding_value=PADDING_IDX) 89 | distractors = pad_sequence(distractors, batch_first=True, padding_value=PADDING_IDX) 90 | contexts = [pad_sequence(c, batch_first=True, padding_value=PADDING_IDX) for c in contexts] 91 | 92 | return contexts, y_out, distractors 93 | 94 | def _s2s_loss(targets, enc_contexts, model): 95 | hidden_state, padding_mask = None, None 96 | 97 | nexts = targets[:, 1:].contiguous() if targets.dim() == 2 else targets[:, 1:, 0].contiguous() 98 | outputs = model.decode(targets[:, :-1].contiguous(), enc_contexts) 99 | 100 | outputs = outputs.view(-1, outputs.shape[-1]).float() 101 | nexts = nexts.view(-1) 102 | 103 | lm_criterion = torch.nn.CrossEntropyLoss(ignore_index=PADDING_IDX) 104 | loss = lm_criterion(outputs, nexts) 105 | return loss, hidden_state, padding_mask 106 | 107 | def _lm_loss(contexts, enc_contexts, model, ignore_idxs, device): 108 | batch_lm_loss = torch.tensor(0, dtype=torch.float, device=device) 109 | 110 | for context in contexts: 111 | enc_context = model.encode(context.clone()) 112 | enc_contexts.append(enc_context) 113 | 114 | context_outputs = model.generate(enc_context[0]) 115 | ignore_mask = torch.stack([context == idx for idx in ignore_idxs], dim=-1).any(dim=-1) 116 | context.masked_fill_(ignore_mask, PADDING_IDX) 117 | prevs = context_outputs[:, :-1, :].contiguous() 118 | nexts = context[:, 1:].contiguous() if context.dim() == 2 else context[:, 1:, 0].contiguous() 119 | lm_criterion = torch.nn.CrossEntropyLoss(ignore_index=PADDING_IDX) 120 | batch_lm_loss += lm_criterion(prevs.view(-1, prevs.shape[-1]).float(), nexts.view(-1)) / len(contexts) 121 | return batch_lm_loss 122 | 123 | 124 | 125 | def main(): 126 | args = InputConfig().args 127 | 128 | trainer_config = get_trainer_config(args) 129 | 130 | set_seed(trainer_config.seed) 131 | device = torch.device(trainer_config.device) 132 | save_path = trainer_config.load_last[:trainer_config.load_last.rfind('/')] 133 | logger = config_logger(os.path.join(save_path, 'inference.log')) 134 | 135 | parsed_valid_data, parsed_test_data = None, None 136 | if args.model_type == 'gpt2': 137 | if args.single_input: 138 | model = GPT2DoubleHeadsModel.from_pretrained('./gpt2-small') 139 | else: 140 | model = GPT2EncoderDecoderModel.from_pretrained('./gpt2-small') 141 | tokenizer = GPT2Tokenizer.from_pretrained('./gpt2-small') 142 | elif args.model_type == 'gpt': 143 | model = OpenAIGPTEncoderDecoderModel.from_pretrained('./openai-gpt') 144 | tokenizer = OpenAIGPTTokenizer.from_pretrained('./openai-gpt') 145 | elif args.model_type == 'seq2seq': 146 | seq2seq_vocab = Seq2seqVocab(trainer_config.train_datasets, trainer_config.valid_datasets, 147 | trainer_config.test_datasets, args.vocab_path, data_type=args.data_type) 148 | tokenizer = seq2seq_vocab.vocab 149 | parsed_train_data, parsed_valid_data, parsed_test_data = seq2seq_vocab.all_data[0], seq2seq_vocab.all_data[1], \ 150 | seq2seq_vocab.all_data[2] 151 | model = TransformerSeq2Seq(args.emb_dim, args.hidden_dim, args.num_layers, args.heads, args.depth_size, 152 | args.filter_size, 
tokenizer, args.pretrained_emb_file, args.pointer_gen, logger, 153 | multi_input=not args.single_input, attention_fusion_type=args.attention_fusion_type, 154 | is_eval=True) 155 | args.dialog_embeddings = False 156 | 157 | model.shared_attention = (args.shared_attention == 1) 158 | model.shared_module = (args.shared_module == 1) 159 | model.attention_fusion_type = args.attention_fusion_type 160 | if args.model_type in ['gpt', 'dialogpt', 'gpt2', 'gpt2_darts']: 161 | tokenizer, additional_length = modify_tokenizer(tokenizer, args.data_type) 162 | model.embeddings_size = 768 163 | model.n_embeddings = len(tokenizer) 164 | model.shared_attention = (args.shared_attention == 1) 165 | model.shared_module = (args.shared_module == 1) 166 | model.attention_fusion_type = args.attention_fusion_type 167 | model.single_input = args.single_input 168 | if args.model_type == 'gpt': 169 | model_embedding_weight = model.transformer.tokens_embed.weight 170 | model.transformer.tokens_embed = nn.Embedding(model.n_embeddings, 768) 171 | model.lm_head = nn.Linear(768, model.n_embeddings, bias=False) 172 | model.transformer.tokens_embed.weight.data[:-additional_length, :] = model_embedding_weight.data 173 | model.transformer.tokens_embed.weight.data[-additional_length:, :] = 0 174 | model.lm_head.weight = model.transformer.tokens_embed.weight 175 | else: 176 | model_embedding_weight = model.transformer.wte.weight 177 | model.transformer.wte = nn.Embedding(model.n_embeddings, 768) 178 | model.lm_head = nn.Linear(768, model.n_embeddings, bias=False) 179 | model.transformer.wte.weight.data[:-additional_length, :] = model_embedding_weight.data 180 | model.transformer.wte.weight.data[-additional_length:, :] = 0 181 | model.lm_head.weight = model.transformer.wte.weight 182 | 183 | if not args.single_input: 184 | model.reload_module_dict() 185 | model.sent_dialog_id = tokenizer.sent_dialog_id 186 | 187 | model.padding_idx = tokenizer.pad_id 188 | model.n_pos_embeddings = 512 189 | 190 | model.talker1_id = tokenizer.talker1_bos_id 191 | model.talker2_id = tokenizer.talker2_bos_id 192 | model.bos_id = tokenizer.bos_id 193 | model.eos_id = tokenizer.eos_id 194 | model.beam_size = args.beam_size 195 | model.diversity_groups = 1 196 | model.max_seq_len = 32 197 | model.dialog_embeddings = args.dialog_embeddings 198 | model.bs_temperature = args.bs_temperature 199 | model.bs_nucleus_p = args.bs_nucleus_p 200 | model.annealing_topk = args.annealing_topk 201 | model.length_penalty_coef = args.length_penalty 202 | model.vocab = None 203 | model.annealing = args.annealing 204 | model.diversity_coef = args.diversity_coef 205 | model.sample = False 206 | model.inference_mode = args.inference_mode 207 | model.response_k = args.response_k 208 | 209 | logger.info('loading datasets') 210 | valid_dataset = FacebookDataset(trainer_config.valid_datasets, tokenizer, 211 | max_lengths=(model.n_pos_embeddings - 1) // (3 if args.single_input else 1), # A bit restrictive here 212 | dialog_embeddings=args.dialog_embeddings, 213 | cache=trainer_config.valid_datasets_cache, 214 | use_start_end=args.use_start_end, 215 | negative_samples=0, # Keep all negative samples 216 | augment=False, 217 | aug_syn_proba=0.0, 218 | limit_size=trainer_config.limit_eval_size, 219 | single_input=args.single_input, 220 | data_type=args.data_type, 221 | parsed_data=parsed_valid_data) 222 | test_dataset = FacebookDataset(trainer_config.test_datasets, tokenizer, 223 | max_lengths=(model.n_pos_embeddings - 1) // (3 if args.single_input else 1), # A bit restrictive here 
224 | dialog_embeddings=args.dialog_embeddings, 225 | cache=trainer_config.test_datasets_cache, 226 | use_start_end=args.use_start_end, 227 | negative_samples=0, # Keep all negative samples 228 | augment=False, 229 | aug_syn_proba=0.0, 230 | limit_size=trainer_config.limit_eval_size, 231 | single_input=args.single_input, 232 | data_type=args.data_type, 233 | parsed_data=parsed_test_data) 234 | logger.info(f'valid dataset {len(valid_dataset)} test dataset {(len(test_dataset))}') 235 | 236 | model.to(device) 237 | logger.info('Weights loaded from {}'.format(trainer_config.load_last)) 238 | 239 | trainer = Trainer(model, 240 | valid_dataset, 241 | None, 242 | logger=logger, 243 | valid_dataset=valid_dataset, 244 | test_dataset=test_dataset, 245 | train_batch_size=trainer_config.train_batch_size, 246 | batch_split=trainer_config.batch_split, 247 | test_batch_size=trainer_config.test_batch_size, 248 | single_input=args.single_input, 249 | n_jobs=trainer_config.n_jobs, 250 | clip_grad=trainer_config.clip_grad, 251 | device=device, 252 | ignore_idxs=tokenizer.all_special_ids, 253 | local_rank=args.local_rank, 254 | apex_level=None, 255 | apex_loss_scale=trainer_config.apex_loss_scale, 256 | linear_schedule=trainer_config.linear_schedule, 257 | n_epochs=trainer_config.n_epochs, 258 | evaluate_full_sequences=trainer_config.evaluate_full_sequences, 259 | full_input=trainer_config.full_input, 260 | uncertainty_loss=args.uncertainty_loss) 261 | 262 | def external_metrics_func(full_references, full_predictions, epoch, metric=None): 263 | if epoch == -1: 264 | references_file_path = os.path.join(save_path, 'test_references_file.txt') 265 | predictions_file_path = os.path.join(save_path, 'test_predictions_file.txt') 266 | else: 267 | references_file_path = os.path.join(save_path, 'eval_references_file.txt') 268 | predictions_file_path = os.path.join(save_path, 'eval_predictions_file.txt') 269 | with open(references_file_path, 'w', encoding='utf-8') as f: 270 | f.write('\n'.join(full_references)) 271 | with open(predictions_file_path, 'w', encoding='utf-8') as f: 272 | f.write('\n'.join(full_predictions)) 273 | 274 | bleu, bleu_list, nist, nist_list, nist_bleu, nist_bleu_list, s_dist, c_dist, entropy, meteor, \ 275 | rouge_l, f1_score, avg_length = nlp_metrics(references_file_path, predictions_file_path) 276 | 277 | metrics = {'meteor': meteor, 'avg_len': avg_length, 'rouge-l': rouge_l, 'bleu': bleu, 'nist': nist, 278 | 'nist-bleu': nist_bleu, 'f1': f1_score} 279 | for name, metric in ( 280 | ('bleu', bleu_list), ('nist', nist_list), ('nist_bleu', nist_bleu_list), ('entropy', entropy), 281 | ('sentence_div', s_dist), ('corpus_div', c_dist)): 282 | for i, m in enumerate(metric, 1): 283 | metrics['{}_{}'.format(name, i)] = m 284 | 285 | return metrics 286 | 287 | metric_funcs = {'f1_score': f1_score} 288 | # trainer.test(metric_funcs, external_metrics_func, epoch=0, inference=True) 289 | trainer.test(metric_funcs, external_metrics_func, epoch=-1, inference=True) 290 | 291 | 292 | if __name__ == '__main__': 293 | main() 294 | -------------------------------------------------------------------------------- /metrics/multi-bleu.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 
5 | 6 | # $Id$ 7 | use warnings; 8 | use strict; 9 | 10 | my $lowercase = 0; 11 | if ($ARGV[0] eq "-lc") { 12 | $lowercase = 1; 13 | shift; 14 | } 15 | 16 | my $stem = $ARGV[0]; 17 | if (!defined $stem) { 18 | print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n"; 19 | print STDERR "Reads the references from reference or reference0, reference1, ...\n"; 20 | exit(1); 21 | } 22 | 23 | $stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0"; 24 | 25 | my @REF; 26 | my $ref=0; 27 | while(-e "$stem$ref") { 28 | &add_to_ref("$stem$ref",\@REF); 29 | $ref++; 30 | } 31 | &add_to_ref($stem,\@REF) if -e $stem; 32 | die("ERROR: could not find reference file $stem") unless scalar @REF; 33 | 34 | # add additional references explicitly specified on the command line 35 | shift; 36 | foreach my $stem (@ARGV) { 37 | &add_to_ref($stem,\@REF) if -e $stem; 38 | } 39 | 40 | 41 | 42 | sub add_to_ref { 43 | my ($file,$REF) = @_; 44 | my $s=0; 45 | if ($file =~ /.gz$/) { 46 | open(REF,"gzip -dc $file|") or die "Can't read $file"; 47 | } else { 48 | open(REF,$file) or die "Can't read $file"; 49 | } 50 | while() { 51 | chop; 52 | push @{$$REF[$s++]}, $_; 53 | } 54 | close(REF); 55 | } 56 | 57 | my(@CORRECT,@TOTAL,$length_translation,$length_reference); 58 | my $s=0; 59 | while() { 60 | chop; 61 | $_ = lc if $lowercase; 62 | my @WORD = split; 63 | my %REF_NGRAM = (); 64 | my $length_translation_this_sentence = scalar(@WORD); 65 | my ($closest_diff,$closest_length) = (9999,9999); 66 | foreach my $reference (@{$REF[$s]}) { 67 | # print "$s $_ <=> $reference\n"; 68 | $reference = lc($reference) if $lowercase; 69 | my @WORD = split(' ',$reference); 70 | my $length = scalar(@WORD); 71 | my $diff = abs($length_translation_this_sentence-$length); 72 | if ($diff < $closest_diff) { 73 | $closest_diff = $diff; 74 | $closest_length = $length; 75 | # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n"; 76 | } elsif ($diff == $closest_diff) { 77 | $closest_length = $length if $length < $closest_length; 78 | # from two references with the same closeness to me 79 | # take the *shorter* into account, not the "first" one. 80 | } 81 | print "$length \n"; 82 | for(my $n=1;$n<=4;$n++) { 83 | my %REF_NGRAM_N = (); 84 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 85 | my $ngram = "$n"; 86 | for(my $w=0;$w<$n;$w++) { 87 | $ngram .= " ".$WORD[$start+$w]; 88 | } 89 | $REF_NGRAM_N{$ngram}++; 90 | } 91 | foreach my $ngram (keys %REF_NGRAM_N) { 92 | if (!defined($REF_NGRAM{$ngram}) || 93 | $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) { 94 | $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram}; 95 | # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}
\n"; 96 | } 97 | } 98 | } 99 | } 100 | $length_translation += $length_translation_this_sentence; 101 | $length_reference += $closest_length; 102 | for(my $n=1;$n<=4;$n++) { 103 | my %T_NGRAM = (); 104 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 105 | my $ngram = "$n"; 106 | for(my $w=0;$w<$n;$w++) { 107 | $ngram .= " ".$WORD[$start+$w]; 108 | } 109 | $T_NGRAM{$ngram}++; 110 | } 111 | foreach my $ngram (keys %T_NGRAM) { 112 | $ngram =~ /^(\d+) /; 113 | # my $n = $1; 114 | # my $corr = 0; 115 | # print "$i e $ngram $T_NGRAM{$ngram}
\n"; 116 | $TOTAL[$n] += $T_NGRAM{$ngram}; 117 | if (defined($REF_NGRAM{$ngram})) { 118 | if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) { 119 | $CORRECT[$n] += $T_NGRAM{$ngram}; 120 | # $corr = $T_NGRAM{$ngram}; 121 | # print "$i e correct1 $T_NGRAM{$ngram}
\n"; 122 | } 123 | else { 124 | $CORRECT[$n] += $REF_NGRAM{$ngram}; 125 | # $corr = $REF_NGRAM{$ngram}; 126 | # print "$i e correct2 $REF_NGRAM{$ngram}
\n"; 127 | } 128 | } 129 | # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram}; 130 | # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n" 131 | } 132 | } 133 | $s++; 134 | } 135 | my $brevity_penalty = 1; 136 | my $bleu = 0; 137 | 138 | my @bleu=(); 139 | 140 | for(my $n=1;$n<=4;$n++) { 141 | if (defined ($TOTAL[$n])){ 142 | $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0; 143 | # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n"; 144 | }else{ 145 | $bleu[$n]=0; 146 | } 147 | } 148 | 149 | if ($length_reference==0){ 150 | printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n"; 151 | exit(1); 152 | } 153 | 154 | if ($length_translation<$length_reference) { 155 | $brevity_penalty = exp(1-$length_reference/$length_translation); 156 | } 157 | $bleu = $brevity_penalty * exp((my_log( $bleu[1] ) + 158 | my_log( $bleu[2] ) + 159 | my_log( $bleu[3] ) + 160 | my_log( $bleu[4] ) ) / 4) ; 161 | printf "BLEU = %.5f, %.5f/%.5f/%.5f/%.5f (BP=%.5f, ratio=%.5f, hyp_len=%d, ref_len=%d)\n", 162 | 100*$bleu, 163 | 100*$bleu[1], 164 | 100*$bleu[2], 165 | 100*$bleu[3], 166 | 100*$bleu[4], 167 | $brevity_penalty, 168 | $length_translation / $length_reference, 169 | $length_translation, 170 | $length_reference; 171 | 172 | 173 | print STDERR "It is in-advisable to publish scores from multi-bleu.perl. The scores depend on your tokenizer, which is unlikely to be reproducible from your paper or consistent across research groups. Instead you should detokenize then use mteval-v14.pl, which has a standard tokenization. Scores from multi-bleu.perl can still be used for internal purposes when you have a consistent tokenizer.\n"; 174 | 175 | sub my_log { 176 | return -9999999999 unless $_[0]; 177 | return log($_[0]); 178 | } -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoyu-noob/Multi-GPT2/3fccdd55a5286427558f7669e22b9fa3d710c2e9/model/__init__.py -------------------------------------------------------------------------------- /model/dataset.py: -------------------------------------------------------------------------------- 1 | # transformer_chatbot 2 | # Copyright (C) 2018 Golovanov, Tselousov 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 
16 | 17 | import json 18 | import os 19 | import random 20 | 21 | import torch 22 | from torch.utils.data import Dataset 23 | from tqdm import tqdm 24 | from transformers.tokenization_gpt2 import GPT2Tokenizer 25 | from transformers.tokenization_openai import OpenAIGPTTokenizer 26 | 27 | from model.seq2seq_vocab import Seq2seqTokenizer 28 | from .postprocessing import augment_replica 29 | 30 | SPECIAL_TOKENS = ['.', ',', '?', '!', ':'] 31 | 32 | class FacebookDataset(Dataset): 33 | @staticmethod 34 | def parse_data(path): 35 | with open(path, 'r', encoding='utf-8') as file: 36 | data = [] 37 | for line in file.readlines(): 38 | line = line.strip() 39 | 40 | if len(line) == 0: 41 | continue 42 | 43 | space_idx = line.find(' ') 44 | if space_idx == -1: 45 | dialog_idx = int(line) 46 | else: 47 | dialog_idx = int(line[:space_idx]) 48 | 49 | if int(dialog_idx) == 1: 50 | data.append({'persona_info': [], 'dialog': [], 'candidates': []}) 51 | 52 | dialog_line = line[space_idx + 1:].split('\t') 53 | dialog_line = [l.strip() for l in dialog_line] 54 | 55 | if dialog_line[0].startswith('your persona:'): 56 | persona_info = dialog_line[0].replace('your persona: ', '') 57 | if persona_info[-1] == '.' and persona_info[-2] != ' ': 58 | persona_info = persona_info[:-1] + ' .' 59 | data[-1]['persona_info'].append(persona_info) 60 | if dialog_line[0].startswith('partner\'s person'): 61 | if not data[-1].__contains__('partner_persona_info'): 62 | data[-1]['partner_persona_info'] = [] 63 | persona_info = dialog_line[0].replace('partner\'s persona: ', '') 64 | if persona_info[-1] == '.' and persona_info[-2] != ' ': 65 | persona_info = persona_info[:-1] + ' .' 66 | data[-1]['partner_persona_info'].append(persona_info) 67 | 68 | elif len(dialog_line) > 1: 69 | data[-1]['dialog'].append(dialog_line[0]) 70 | data[-1]['dialog'].append(dialog_line[1]) 71 | if len(dialog_line) == 4: 72 | data[-1]['candidates'].append(dialog_line[3].split('|')[:-1]) # the last candidate is a duplicate of the good answer (dialog_line[1]) 73 | 74 | return data 75 | 76 | @staticmethod 77 | def parse_data_emoji(path): 78 | with open(path, 'r', encoding='utf-8') as f: 79 | data = [] 80 | for line in f.readlines(): 81 | line = line.strip() 82 | items = line.split('\t') 83 | data.append({'persona_info': [], 'dialog': [], 'candidates': []}) 84 | data[-1]['persona_info'].append(items[0]) 85 | data[-1]['dialog'].append(items[1]) 86 | data[-1]['dialog'].append(items[2]) 87 | return data 88 | 89 | @staticmethod 90 | def parse_data_daily(path): 91 | with open(path, 'r', encoding='utf-8') as f: 92 | data = [] 93 | for line in f.readlines(): 94 | line = line.strip() 95 | items = line.split('\t') 96 | data.append({'persona_info': [], 'dialog': [], 'candidates': []}) 97 | data[-1]['persona_info'].append(items[0]) 98 | for i in range(1, len(items)): 99 | data[-1]['dialog'].append(items[i]) 100 | return data 101 | 102 | @staticmethod 103 | def parse_data_weibo(path): 104 | with open(path, 'r', encoding='utf-8') as f: 105 | data = [] 106 | for line in f.readlines(): 107 | line = line.strip() 108 | items = line.split('\t') 109 | data.append({'persona_info': [], 'dialog': [], 'candidates': []}) 110 | data[-1]['dialog'].append(items[0]) 111 | data[-1]['dialog'].append(items[1]) 112 | return data 113 | 114 | 115 | @staticmethod 116 | def make_dataset(data, vocab, only_final=False): 117 | dataset = [] 118 | if isinstance(vocab, OpenAIGPTTokenizer) or isinstance(vocab, GPT2Tokenizer) or isinstance(vocab, Seq2seqTokenizer): 119 | for chat in tqdm(data): 120 | 
persona_info = [vocab.encode(vocab.tokenize(s)) for s in chat['persona_info']] 121 | 122 | dialog = [] 123 | if only_final: 124 | for utterance in chat['dialog']: 125 | dialog.append(vocab.encode(vocab.tokenize(utterance))) 126 | dataset.append((persona_info, dialog[:], [])) 127 | else: 128 | for i, replica in enumerate(chat['dialog'], 1): 129 | dialog.append(vocab.encode(vocab.tokenize(replica))) 130 | if not i % 2: 131 | if chat['candidates']: 132 | candidates_ids = [vocab.encode(vocab.tokenize(c)) for c in chat['candidates'][(i - 1) // 2]] 133 | dataset.append((persona_info, dialog[:], candidates_ids)) 134 | else: 135 | dataset.append((persona_info, dialog[:], [])) 136 | if chat.__contains__('partner_persona_info'): 137 | persona_info = [vocab.encode(vocab.tokenize(s)) for s in chat['partner_persona_info']] 138 | dialog = [] 139 | for i, replica in enumerate(chat['dialog'], 1): 140 | dialog.append(vocab.encode(vocab.tokenize(replica))) 141 | if i % 2 and i > 2: 142 | dataset.append((persona_info, dialog[:], [])) 143 | else: 144 | for chat in tqdm(data): 145 | persona_info = [vocab.string2ids(s) for s in chat['persona_info']] 146 | 147 | dialog = [] 148 | for i, replica in enumerate(chat['dialog'], 1): 149 | dialog.append(vocab.string2ids(replica)) 150 | if not i % 2: 151 | if chat['candidates']: 152 | candidates_ids = [vocab.string2ids(c) for c in chat['candidates'][(i-1)//2]] 153 | dataset.append((persona_info, dialog[:], candidates_ids)) 154 | else: 155 | dataset.append((persona_info, dialog[:], [])) 156 | 157 | return dataset 158 | 159 | def __init__(self, paths, vocab, *, max_lengths=512, max_y_length=80, min_infos=2, dialog_embeddings=False, 160 | use_start_end=True, negative_samples=0, limit_size=-1, 161 | cache=None, augment=False, aug_syn_proba=0.1, aug_vary_length=True, max_history_size=-1, 162 | single_input=False, data_type='persona', parsed_data=None, few_shot=False, task_map_path=None): 163 | assert min_infos > 0 164 | 165 | if isinstance(paths, str): 166 | paths = [paths] 167 | 168 | self.augment = augment 169 | self.aug_syn_proba = aug_syn_proba 170 | self.aug_vary_length = aug_vary_length 171 | 172 | self.vocab = vocab 173 | self.max_lengths = max_lengths 174 | self.max_y_length = max_y_length 175 | self.min_infos = min_infos 176 | self.dialog_embeddings = dialog_embeddings 177 | self.use_start_end = use_start_end 178 | self.negative_samples = negative_samples # -1 => include all candidates in data instance 179 | self.max_history_size = max_history_size 180 | self.single_input = single_input 181 | self.data_type = data_type 182 | 183 | if cache and os.path.exists(cache): 184 | self.data = torch.load(cache) 185 | else: 186 | self.data = self._parse_data(paths, vocab, data_type, parsed_data) 187 | if cache: 188 | torch.save(self.data, cache) 189 | if limit_size > 0: 190 | self.data = self.data[:limit_size] 191 | if few_shot and task_map_path is not None: 192 | with open(task_map_path, 'r') as f: 193 | self.task_map = json.load(f) 194 | 195 | def __len__(self): 196 | return len(self.data) 197 | 198 | def _parse_data(self, paths, vocab, data_type, parsed_data): 199 | data = None 200 | if data_type == 'persona': 201 | if not parsed_data: 202 | parsed_data = sum([FacebookDataset.parse_data(path) for path in paths], []) 203 | data = FacebookDataset.make_dataset(parsed_data, vocab) 204 | elif data_type == 'emoji': 205 | if not parsed_data: 206 | parsed_data = sum([FacebookDataset.parse_data_emoji(path) for path in paths], []) 207 | data = 
FacebookDataset.make_dataset(parsed_data, vocab) 208 | elif data_type == 'daily': 209 | if not parsed_data: 210 | parsed_data = sum([FacebookDataset.parse_data_daily(path) for path in paths], []) 211 | data = FacebookDataset.make_dataset(parsed_data, vocab) 212 | return data 213 | 214 | def _augment(self, sentences, info=False): 215 | 216 | if not self.augment: 217 | return sentences 218 | 219 | if info: 220 | n_info_samples = max(self.min_infos, random.randint(1, len(sentences))) 221 | n_info_samples = min(n_info_samples, len(sentences)) 222 | sentences = random.sample(sentences, n_info_samples) 223 | random.shuffle(sentences) 224 | else: 225 | if self.aug_vary_length: 226 | begin = random.randrange(0, len(sentences) - 1, 2) 227 | end = random.randrange(begin + 2, len(sentences) + 1, 2) 228 | 229 | sentences = sentences[begin:end] 230 | 231 | def _try2augment(sent): 232 | if random.uniform(0, 1) < self.aug_syn_proba: 233 | sent = self.vocab.ids2string(sent) 234 | sent = augment_replica(sent) 235 | sent = self.vocab.string2ids(sent) 236 | return sent 237 | 238 | sentences = list(map(_try2augment, sentences)) if self.aug_syn_proba > 0 else sentences 239 | 240 | return sentences 241 | 242 | def _get_distractors(self, candidates): 243 | if self.negative_samples == 0: 244 | return [] 245 | if self.negative_samples == -1: # => include all candidates in data instance 246 | return candidates 247 | if len(candidates) >= self.negative_samples: 248 | distractors = random.sample(candidates, k=self.negative_samples) 249 | else: # not enought candidates, sample from train dataset instead (we may sample the gold y but quite unlikely) 250 | distractors = random.sample(range(len(self.data)), k=self.negative_samples) 251 | distractors = [self.data[ids][1][-1] for ids in distractors] 252 | return distractors 253 | 254 | def get_tasks_dataset(self): 255 | tasks = [] 256 | for k, v in self.task_map.items(): 257 | tasks.append((k, v['ids'])) 258 | return TaskDataset(tasks) 259 | 260 | def __getitem__(self, idx): 261 | persona_info, dialog, candidates = self.data[idx] 262 | 263 | if len(persona_info): 264 | persona_info = self._augment(persona_info, info=True) 265 | persona_info = sum(persona_info, []) 266 | if self.single_input: 267 | persona_info = [self.vocab.bos_id] + persona_info 268 | if self.dialog_embeddings: 269 | persona_info = [[tok, self.vocab.talker1_bos_id] for tok in persona_info] 270 | elif not self.single_input and not self.dialog_embeddings: 271 | persona_info = [self.vocab.bos_id] + persona_info[:self.max_lengths-2] 272 | else: 273 | persona_info = [self.vocab.info_bos_id] + persona_info[:self.max_lengths-2] + \ 274 | [self.vocab.info_eos_id] if self.use_start_end else persona_info[:self.max_lengths] 275 | if self.dialog_embeddings: 276 | persona_info = [[tok, self.vocab.info_dialog_id] for tok in persona_info] 277 | 278 | dialog = self._augment(dialog) 279 | candidates = self._get_distractors(candidates) 280 | 281 | h = [] 282 | history_start = 0 283 | if self.max_history_size != -1: 284 | history_start = -1 - self.max_history_size 285 | dialog_history = dialog[history_start: -1] 286 | if self.single_input: 287 | for i, ids in enumerate(dialog_history): 288 | if (len(dialog_history) - i) % 2 == 0: 289 | ids = [self.vocab.talker1_bos_id] + ids 290 | else: 291 | ids = [self.vocab.talker2_bos_id] + ids 292 | if self.dialog_embeddings: 293 | ids = [[tok, self.vocab.talker1_bos_id if (len(dialog_history) - i) % 2 == 0 294 | else self.vocab.talker2_bos_id] for tok in ids] 295 | h.extend(ids) 296 | 
elif not self.single_input and not self.dialog_embeddings: 297 | for i, ids in enumerate(dialog_history): 298 | if (len(dialog_history) - i) % 2 == 0: 299 | ids = [self.vocab.talker1_bos_id] + ids 300 | else: 301 | ids = [self.vocab.talker2_bos_id] + ids 302 | h.extend(ids) 303 | else: 304 | for i, ids in enumerate(dialog_history): 305 | if (len(dialog_history) - i) % 2 == 0 and self.use_start_end: 306 | ids = [self.vocab.talker1_bos_id] + ids + [self.vocab.talker1_eos_id] 307 | elif self.use_start_end: 308 | ids = [self.vocab.talker2_bos_id] + ids + [self.vocab.talker2_eos_id] 309 | if self.dialog_embeddings: 310 | ids = [[tok, self.vocab.talker1_dialog_id if (len(dialog_history) - i) % 2 == 0 311 | else self.vocab.talker2_dialog_id] for tok in ids] 312 | h.extend(ids) 313 | h = h[-self.max_lengths:] 314 | 315 | sentences = [] 316 | for y in (dialog[-1:] + candidates): 317 | if self.single_input: 318 | y = [self.vocab.talker1_bos_id] + y + [self.vocab.eos_id] 319 | if self.dialog_embeddings: 320 | y = [[tok, self.vocab.talker1_bos_id] for tok in y] 321 | sentences.append(y) 322 | elif not self.single_input and not self.dialog_embeddings: 323 | y = [self.vocab.talker1_bos_id] + y + [self.vocab.eos_id] 324 | sentences.append(y) 325 | else: 326 | y = [self.vocab.bos_id] + y + [self.vocab.eos_id] 327 | if self.dialog_embeddings: 328 | y = [[tok, self.vocab.sent_dialog_id] for tok in y] 329 | sentences.append(y) 330 | 331 | return persona_info, h, sentences[0], sentences[1:] 332 | 333 | class TaskDataset(Dataset): 334 | def __init__(self, data_list): 335 | self.data_list = data_list 336 | 337 | def __getitem__(self, idx): 338 | return self.data_list[idx] 339 | 340 | def __len__(self): 341 | return len(self.data_list) 342 | -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | # transformer_chatbot 2 | # Copyright (C) 2018 Golovanov, Tselousov 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 
16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | 22 | class LabelSmoothingLoss(nn.Module): 23 | def __init__(self, n_labels, smoothing=0.0, ignore_index=-100, size_average=True): 24 | super(LabelSmoothingLoss, self).__init__() 25 | assert 0 <= smoothing <= 1 26 | 27 | self.ignore_index = ignore_index 28 | self.confidence = 1 - smoothing 29 | 30 | if smoothing > 0: 31 | self.criterion = nn.KLDivLoss(size_average=size_average) 32 | n_ignore_idxs = 1 + (ignore_index >= 0) 33 | one_hot = torch.full((1, n_labels), fill_value=(smoothing / (n_labels - n_ignore_idxs))) 34 | if ignore_index >= 0: 35 | one_hot[0, ignore_index] = 0 36 | self.register_buffer('one_hot', one_hot) 37 | else: 38 | self.criterion = nn.NLLLoss(size_average=size_average, ignore_index=ignore_index) 39 | 40 | def forward(self, log_inputs, targets): 41 | if self.confidence < 1: 42 | tdata = targets.data 43 | 44 | tmp = self.one_hot.repeat(targets.shape[0], 1) 45 | tmp.scatter_(1, tdata.unsqueeze(1), self.confidence) 46 | 47 | if self.ignore_index >= 0: 48 | mask = torch.nonzero(tdata.eq(self.ignore_index)).squeeze(-1) 49 | if mask.numel() > 0: 50 | tmp.index_fill_(0, mask, 0) 51 | 52 | targets = tmp 53 | 54 | return self.criterion(log_inputs, targets) 55 | -------------------------------------------------------------------------------- /model/optim.py: -------------------------------------------------------------------------------- 1 | # transformer_chatbot 2 | # Copyright (C) 2018 Golovanov, Tselousov 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 16 | 17 | import logging 18 | import math 19 | 20 | import torch 21 | 22 | logger = logging.getLogger(__file__) 23 | 24 | 25 | class Adam(torch.optim.Optimizer): 26 | """Implements Adam algorithm. 27 | This implementation is modified from torch.optim.Adam based on: 28 | `Fixed Weight Decay Regularization in Adam` 29 | (see https://arxiv.org/abs/1711.05101) 30 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 31 | Arguments: 32 | params (iterable): iterable of parameters to optimize or dicts defining 33 | parameter groups 34 | lr (float, optional): learning rate (default: 1e-3) 35 | betas (Tuple[float, float], optional): coefficients used for computing 36 | running averages of gradient and its square (default: (0.9, 0.999)) 37 | eps (float, optional): term added to the denominator to improve 38 | numerical stability (default: 1e-8) 39 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 40 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 41 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 42 | .. _Adam\: A Method for Stochastic Optimization: 43 | https://arxiv.org/abs/1412.6980 44 | .. 
_On the Convergence of Adam and Beyond: 45 | https://openreview.net/forum?id=ryQu7f-RZ 46 | """ 47 | 48 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False): 49 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad) 50 | super(Adam, self).__init__(params, defaults) 51 | 52 | def step(self, closure=None): 53 | """Performs a single optimization step. 54 | Arguments: 55 | closure (callable, optional): A closure that reevaluates the model 56 | and returns the loss. 57 | """ 58 | loss = None 59 | if closure is not None: 60 | loss = closure() 61 | 62 | for group in self.param_groups: 63 | for p in group['params']: 64 | if p.grad is None: 65 | continue 66 | grad = p.grad.data 67 | 68 | amsgrad = group['amsgrad'] 69 | 70 | state = self.state[p] 71 | 72 | # State initialization 73 | if len(state) == 0: 74 | state['step'] = 0 75 | # Exponential moving average of gradient values 76 | state['exp_avg'] = torch.zeros_like(p.data) 77 | # Exponential moving average of squared gradient values 78 | state['exp_avg_sq'] = torch.zeros_like(p.data) 79 | if amsgrad: 80 | # Maintains max of all exp. moving avg. of sq. grad. values 81 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 82 | 83 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 84 | if amsgrad: 85 | max_exp_avg_sq = state['max_exp_avg_sq'] 86 | beta1, beta2 = group['betas'] 87 | 88 | state['step'] += 1 89 | 90 | if grad.is_sparse: 91 | grad = grad.coalesce() # the update is non-linear so indices must be unique 92 | grad_indices = grad._indices() 93 | grad_values = grad._values() 94 | size = grad.size() 95 | 96 | def make_sparse(values): 97 | constructor = grad.new 98 | if grad_indices.dim() == 0 or values.dim() == 0: 99 | return constructor().resize_as_(grad) 100 | return constructor(grad_indices, values, size) 101 | 102 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 103 | beta1, beta2 = group['betas'] 104 | 105 | # Decay the first and second moment running average coefficient 106 | # old <- b * old + (1 - b) * new 107 | # <==> old += (1 - b) * (new - old) 108 | old_exp_avg_values = exp_avg.sparse_mask(grad)._values() 109 | exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1) 110 | exp_avg.add_(make_sparse(exp_avg_update_values)) 111 | old_exp_avg_sq_values = exp_avg_sq.sparse_mask(grad)._values() 112 | exp_avg_sq_update_values = grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2) 113 | exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values)) 114 | 115 | # Dense addition again is intended, avoiding another sparse_mask 116 | numer = exp_avg_update_values.add_(old_exp_avg_values) 117 | exp_avg_sq_update_values.add_(old_exp_avg_sq_values) 118 | denom = exp_avg_sq_update_values.sqrt_().add_(group['eps']) 119 | del exp_avg_update_values, exp_avg_sq_update_values 120 | 121 | bias_correction1 = 1 - beta1 ** state['step'] 122 | bias_correction2 = 1 - beta2 ** state['step'] 123 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 124 | 125 | p.data.add_(make_sparse(-step_size * numer.div_(denom))) 126 | else: 127 | # Decay the first and second moment running average coefficient 128 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 129 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 130 | if amsgrad: 131 | # Maintains the maximum of all 2nd moment running avg. till now 132 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 133 | # Use the max. for normalizing running avg. 
of gradient 134 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 135 | else: 136 | denom = exp_avg_sq.sqrt().add_(group['eps']) 137 | 138 | bias_correction1 = 1 - beta1 ** state['step'] 139 | bias_correction2 = 1 - beta2 ** state['step'] 140 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 141 | 142 | if group['weight_decay'] != 0: 143 | p.data.add_(-group['weight_decay'] * group['lr'], p.data) 144 | 145 | p.data.addcdiv_(-step_size, exp_avg, denom) 146 | 147 | return loss 148 | 149 | def backward(self, losses): 150 | with torch.autograd.set_detect_anomaly(True): 151 | if not isinstance(losses, (tuple, list)): 152 | losses = [losses] 153 | full_loss = sum(losses, 0) 154 | full_loss.backward() 155 | return full_loss 156 | 157 | class NoamOpt: 158 | def __init__(self, embeddings_size, warmup, optimizer, linear_schedule=False, lr=None, total_steps=None, 159 | apex_level=None, loss_weight=None, extra_module_lr_rate=1.0): 160 | self.embeddings_size = embeddings_size 161 | self.warmup = warmup 162 | self.optimizer = optimizer 163 | self.linear_schedule = linear_schedule 164 | self.apex_level = apex_level 165 | self.lr = lr 166 | self.total_steps = total_steps 167 | self.loss_weight = loss_weight 168 | self.extra_module_lr_rate = extra_module_lr_rate 169 | 170 | self._step = 0 171 | 172 | def state_dict(self): 173 | return {'step': self._step, 174 | 'optimizer': self.optimizer.state_dict()} 175 | 176 | def load_state_dict(self, state_dict): 177 | self._step = state_dict['step'] 178 | try: 179 | self.optimizer.load_state_dict(state_dict['optimizer']) 180 | except ValueError as e: 181 | logger.info("Optimizer cannot be loaded from checkpoint: {}".format(e)) 182 | except KeyError as e: 183 | logger.info("Optimizer cannot be loaded from checkpoint: {}".format(e)) 184 | 185 | def backward(self, losses): 186 | if not isinstance(losses, (tuple, list)): 187 | losses = [losses] 188 | if self.loss_weight is None: 189 | full_loss = sum(losses, 0) 190 | else: 191 | full_loss = torch.sum(torch.stack(losses, 0) * torch.exp(self.loss_weight[1])) + torch.sum(self.loss_weight[1]) 192 | 193 | if self.apex_level is not None: 194 | try: 195 | from apex.amp import scale_loss 196 | except ImportError: 197 | raise ImportError("Please install apex.") 198 | 199 | for loss_id, loss in enumerate(losses): 200 | with scale_loss(loss, self.optimizer, loss_id=loss_id) as scaled_loss: 201 | scaled_loss.backward() 202 | else: 203 | full_loss.backward() 204 | return full_loss 205 | 206 | def zero_grad(self): 207 | return self.optimizer.zero_grad() 208 | 209 | def get_lr(self): 210 | return self.optimizer.param_groups[0]['lr'] 211 | 212 | @property 213 | def param_groups(self): 214 | return self.optimizer.param_groups 215 | 216 | def step(self): 217 | self._step += 1 218 | rate = self.rate_linear() if self.linear_schedule else self.rate() 219 | for p in self.optimizer.param_groups: 220 | if p.__contains__('extra'): 221 | p['lr'] = rate * self.extra_module_lr_rate 222 | else: 223 | p['lr'] = rate 224 | self.optimizer.step() 225 | 226 | def rate(self, step=None): 227 | if step is None: 228 | step = self._step 229 | 230 | return self.lr * (self.embeddings_size ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5))) 231 | 232 | @staticmethod 233 | def warmup_linear(x, warmup=0.002): 234 | if x < warmup: 235 | return x/warmup 236 | return 1.0 - x 237 | 238 | def rate_linear(self, step=None): 239 | if step is None: 240 | step = self._step 241 | assert self.lr is not None and self.total_steps is not 
None 242 | 243 | return self.lr * self.warmup_linear(step/self.total_steps, self.warmup) 244 | -------------------------------------------------------------------------------- /model/postprocessing.py: -------------------------------------------------------------------------------- 1 | # transformer_chatbot 2 | # Copyright (C) 2018 Golovanov, Tselousov 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 16 | 17 | import random 18 | from collections import defaultdict 19 | 20 | import nltk 21 | from nltk.corpus import wordnet 22 | 23 | 24 | def augment_replica(seq): 25 | _exceptions = ['your', 'persona'] 26 | pos2wn = {'NN': wordnet.NOUN, 27 | 'JJ': wordnet.ADJ, 28 | 'VBP': wordnet.VERB, 29 | 'RB': wordnet.ADV} 30 | 31 | synonyms = defaultdict(list) 32 | 33 | tagged_seq = seq.replace('i ', 'I ') 34 | tagged_seq = nltk.pos_tag(nltk.word_tokenize(tagged_seq)) 35 | 36 | for word, pos in tagged_seq: 37 | if pos not in pos2wn or word in _exceptions: 38 | continue 39 | 40 | pos = pos2wn[pos] 41 | synnets = wordnet.synsets(word, pos=pos) 42 | 43 | for synnet in synnets: 44 | for syn in synnet.lemma_names(): 45 | if syn != word: 46 | synonyms[word].append(syn.replace('_', ' ')) 47 | break 48 | if synonyms: 49 | for key, values in synonyms.items(): 50 | seq = seq.replace(key, random.choice(list(values))) 51 | 52 | return seq 53 | -------------------------------------------------------------------------------- /model/seq2seq.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from model.common_layer import _gen_bias_mask 8 | from model.common_layer import _gen_timing_signal 9 | from model.common_layer import _get_attn_subsequent_mask 10 | from model.common_layer import DecoderLayer 11 | from model.common_layer import EncoderLayer 12 | from model.common_layer import LabelSmoothing 13 | from model.common_layer import LayerNorm 14 | from .utils import repeat_along_dim1 15 | 16 | 17 | class Encoder(nn.Module): 18 | """ 19 | A Transformer Encoder module. 20 | Inputs should be in the shape [batch_size, length, hidden_size] 21 | Outputs will have the shape [batch_size, length, hidden_size] 22 | Refer Fig.1 in https://arxiv.org/pdf/1706.03762.pdf 23 | """ 24 | 25 | def __init__(self, embedding_size, hidden_size, num_layers, num_heads, total_key_depth, total_value_depth, 26 | filter_size, max_length=512, input_dropout=0.0, layer_dropout=0.0, 27 | attention_dropout=0.0, relu_dropout=0.0, use_mask=False, universal=False): 28 | """ 29 | Parameters: 30 | embedding_size: Size of embeddings 31 | hidden_size: Hidden size 32 | num_layers: Total layers in the Encoder 33 | num_heads: Number of attention heads 34 | total_key_depth: Size of last dimension of keys. Must be divisible by num_head 35 | total_value_depth: Size of last dimension of values. 
Must be divisible by num_head 36 | output_depth: Size last dimension of the final output 37 | filter_size: Hidden size of the middle layer in FFN 38 | max_length: Max sequence length (required for timing signal) 39 | input_dropout: Dropout just after embedding 40 | layer_dropout: Dropout for each layer 41 | attention_dropout: Dropout probability after attention (Should be non-zero only during training) 42 | relu_dropout: Dropout probability after relu in FFN (Should be non-zero only during training) 43 | use_mask: Set to True to turn on future value masking 44 | """ 45 | 46 | super(Encoder, self).__init__() 47 | self.universal = universal 48 | self.num_layers = num_layers 49 | self.timing_signal = _gen_timing_signal(max_length, hidden_size) 50 | 51 | if (self.universal): 52 | ## for t 53 | self.position_signal = _gen_timing_signal(num_layers, hidden_size) 54 | 55 | params = (hidden_size, 56 | total_key_depth or hidden_size, 57 | total_value_depth or hidden_size, 58 | filter_size, 59 | num_heads, 60 | _gen_bias_mask(max_length) if use_mask else None, 61 | layer_dropout, 62 | attention_dropout, 63 | relu_dropout) 64 | 65 | self.embedding_proj = nn.Linear(embedding_size, hidden_size, bias=False) 66 | if (self.universal): 67 | self.enc = EncoderLayer(*params) 68 | else: 69 | self.enc = nn.ModuleList([EncoderLayer(*params) for _ in range(num_layers)]) 70 | 71 | self.layer_norm = LayerNorm(hidden_size) 72 | self.input_dropout = nn.Dropout(input_dropout) 73 | 74 | def forward(self, inputs, mask): 75 | # Add input dropout 76 | x = self.input_dropout(inputs) 77 | 78 | # Project to hidden size 79 | x = self.embedding_proj(x) 80 | 81 | # Add timing signal 82 | x += self.timing_signal[:, :inputs.shape[1], :].type_as(inputs.data) 83 | 84 | for i in range(self.num_layers): 85 | x = self.enc[i](x, mask) 86 | 87 | y = self.layer_norm(x) 88 | return y 89 | 90 | 91 | class Decoder(nn.Module): 92 | """ 93 | A Transformer Decoder module. 94 | Inputs should be in the shape [batch_size, length, hidden_size] 95 | Outputs will have the shape [batch_size, length, hidden_size] 96 | Refer Fig.1 in https://arxiv.org/pdf/1706.03762.pdf 97 | """ 98 | 99 | def __init__(self, embedding_size, hidden_size, num_layers, num_heads, total_key_depth, total_value_depth, 100 | filter_size, max_length=512, input_dropout=0.0, layer_dropout=0.0, 101 | attention_dropout=0.0, relu_dropout=0.0, universal=False, multi_input=False, context_size=1, 102 | attention_fusion_type='mean'): 103 | """ 104 | Parameters: 105 | embedding_size: Size of embeddings 106 | hidden_size: Hidden size 107 | num_layers: Total layers in the Encoder 108 | num_heads: Number of attention heads 109 | total_key_depth: Size of last dimension of keys. Must be divisible by num_head 110 | total_value_depth: Size of last dimension of values. 
Must be divisible by num_head 111 | output_depth: Size last dimension of the final output 112 | filter_size: Hidden size of the middle layer in FFN 113 | max_length: Max sequence length (required for timing signal) 114 | input_dropout: Dropout just after embedding 115 | layer_dropout: Dropout for each layer 116 | attention_dropout: Dropout probability after attention (Should be non-zero only during training) 117 | relu_dropout: Dropout probability after relu in FFN (Should be non-zero only during training) 118 | multi_input: Whether use multiple attention modules in the decoder 119 | context_size: The number of multiple inputs 120 | """ 121 | 122 | super(Decoder, self).__init__() 123 | self.universal = universal 124 | self.num_layers = num_layers 125 | self.timing_signal = _gen_timing_signal(max_length, hidden_size) 126 | 127 | if (self.universal): 128 | ## for t 129 | self.position_signal = _gen_timing_signal(num_layers, hidden_size) 130 | 131 | self.mask = _get_attn_subsequent_mask(max_length) 132 | 133 | params = (hidden_size, 134 | total_key_depth or hidden_size, 135 | total_value_depth or hidden_size, 136 | filter_size, 137 | num_heads, 138 | _gen_bias_mask(max_length), # mandatory 139 | layer_dropout, 140 | attention_dropout, 141 | relu_dropout, 142 | multi_input, 143 | context_size, 144 | attention_fusion_type) 145 | 146 | self.embedding_proj = nn.Linear(embedding_size, hidden_size, bias=False) 147 | if (self.universal): 148 | self.dec = DecoderLayer(*params) 149 | else: 150 | self.dec = nn.Sequential(*[DecoderLayer(*params) for l in range(num_layers)]) 151 | 152 | self.layer_norm = LayerNorm(hidden_size) 153 | self.input_dropout = nn.Dropout(input_dropout) 154 | self.multi_input = multi_input 155 | self.context_size = context_size 156 | 157 | def forward(self, inputs, encoder_output, mask_src, mask_trg): 158 | dec_mask = torch.gt(mask_trg + self.mask[:, :mask_trg.size(-1), :mask_trg.size(-1)], 0) 159 | # Add input dropout 160 | x = self.input_dropout(inputs) 161 | # Project to hidden size 162 | x = self.embedding_proj(x) 163 | 164 | # Add timing signal 165 | x += self.timing_signal[:, :inputs.shape[1], :].type_as(inputs.data) 166 | 167 | # Run decoder 168 | y, _, attn_dist, _ = self.dec((x, encoder_output, [], (mask_src, dec_mask))) 169 | 170 | # Final layer normalization 171 | y = self.layer_norm(y) 172 | return y, attn_dist 173 | 174 | class Generator(nn.Module): 175 | "Define standard linear + softmax generation step." 
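    # Editor's note (added comment): when pointer_gen is enabled, forward() below mixes
    # the output-vocabulary softmax with the copy distribution given by the decoder
    # attention, roughly
    #     final_dist = p_gen * softmax(proj(x)) + (1 - p_gen) * scatter(attn_dist, src_ids)
    # and returns log(final_dist). With a list of encoded sources, each attention
    # distribution is divided by 2, which assumes two input sources (e.g. persona and history).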
176 | def __init__(self, hidden_size, vocab, pointer_gen): 177 | super(Generator, self).__init__() 178 | self.proj = nn.Linear(hidden_size, vocab) 179 | self.p_gen_linear = nn.Linear(hidden_size, 1) 180 | self.pointer_gen = pointer_gen 181 | 182 | def forward(self, x, attn_dist=None, enc_batch_extend_vocab=None, temp=1, beam_search=False): 183 | 184 | if self.pointer_gen: 185 | p_gen = self.p_gen_linear(x) 186 | p_gen = torch.sigmoid(p_gen) 187 | 188 | logit = self.proj(x) 189 | 190 | if self.pointer_gen: 191 | vocab_dist = F.softmax(logit/temp, dim=2) 192 | vocab_dist_ = p_gen * vocab_dist 193 | 194 | if isinstance(attn_dist, list): 195 | enc_batch_extend_vocab_ = [torch.cat([sub_vocab.unsqueeze(1)] * x.size(1), 196 | 1) for sub_vocab in enc_batch_extend_vocab] ## extend for all seq 197 | if (beam_search): 198 | enc_batch_extend_vocab_ = [torch.cat([sub_vocab[0].unsqueeze(0)] * x.size(0), 199 | 0) for sub_vocab in enc_batch_extend_vocab_] ## extend for all seq 200 | attn_dist = [F.softmax(a / temp, dim=-1) for a in attn_dist] 201 | attn_dist_ = [(1 - p_gen) * a / 2 for a in attn_dist] 202 | for i in range(len(attn_dist_)): 203 | vocab_dist_.scatter_add(2, enc_batch_extend_vocab_[i], attn_dist_[i]) 204 | logit = torch.log(vocab_dist_ + 1e-40) 205 | else: 206 | enc_batch_extend_vocab_ = torch.cat([enc_batch_extend_vocab.unsqueeze(1)] * x.size(1), 207 | 1) ## extend for all seq 208 | if (beam_search): 209 | enc_batch_extend_vocab_ = torch.cat([enc_batch_extend_vocab_[0].unsqueeze(0)] * x.size(0), 210 | 0) ## extend for all seq 211 | attn_dist = F.softmax(attn_dist / temp, dim=-1) 212 | attn_dist_ = (1 - p_gen) * attn_dist 213 | logit = torch.log(vocab_dist_.scatter_add(2, enc_batch_extend_vocab_, attn_dist_) + 1e-40) 214 | return logit 215 | else: 216 | return F.log_softmax(logit, dim=-1) 217 | 218 | class Embedding: 219 | def __init__(self, tokenizer, emb_size, pretrained_file, logger): 220 | self.emb_size = emb_size 221 | self.embedding = nn.Embedding(tokenizer.n_words, emb_size) 222 | self.logger = logger 223 | self.tokenizer = tokenizer 224 | self.get_pretrained_embedding(pretrained_file) 225 | 226 | def get_pretrained_embedding(self, pretrained_file): 227 | self.logger.info('Loding embedding from %s', pretrained_file) 228 | for line in open(pretrained_file, encoding='utf-8').readlines(): 229 | items = line.split() 230 | if (len(items) == self.emb_size + 1): 231 | if self.tokenizer.word2idx.__contains__(items[0]): 232 | self.embedding.weight.data[self.tokenizer.word2idx[items[0]]] = \ 233 | torch.tensor([float(x) for x in items[1:]]) 234 | self.embedding.weight.data.requires_grad = True 235 | 236 | def get_embedding(self): 237 | return self.embedding 238 | 239 | class EncoderRNN(nn.Module): 240 | def __init__(self, input_size, hidden_size, rnn_type='lstm', num_layers=1): 241 | super(EncoderRNN, self).__init__() 242 | self.hidden_size = hidden_size 243 | self.embedding = nn.Embedding(input_size, hidden_size) 244 | if rnn_type == 'lstm': 245 | self.rnn = nn.LSTM(hidden_size, hidden_size, num_layers=num_layers) 246 | elif rnn_type == 'gru': 247 | self.rnn = nn.GRU(hidden_size, hidden_size, num_layers=num_layers) 248 | 249 | def forward(self, input_emb, hidden): 250 | output, hidden = self.rnn(input_emb, hidden) 251 | return output, hidden 252 | 253 | class AttnDecoderRNN(nn.Module): 254 | def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=64, rnn_type='lstm', num_layers=1): 255 | super(AttnDecoderRNN, self).__init__() 256 | self.hidden_size = hidden_size 257 | 
self.output_size = output_size 258 | self.dropout_p = dropout_p 259 | self.max_length = max_length 260 | 261 | self.embedding = nn.Embedding(self.output_size, self.hidden_size) 262 | self.attn = nn.Linear(self.hidden_size * 2, self.max_length) 263 | self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size) 264 | self.dropout = nn.Dropout(self.dropout_p) 265 | if rnn_type == 'lstm': 266 | self.rnn = nn.LSTM(self.hidden_size, self.hidden_size, num_layers=num_layers) 267 | elif rnn_type == 'gru': 268 | self.rnn = nn.GRU(self.hidden_size, self.hidden_size, num_layers=num_layers) 269 | self.out = nn.Linear(self.hidden_size, self.output_size) 270 | 271 | def forward(self, input_emb, hidden_state, encoder_outputs): 272 | embedded = self.dropout(input_emb) 273 | attn_weights = F.softmax( 274 | self.attn(torch.cat((embedded[0], hidden_state[0]), 1)), dim=1) 275 | attn_applied = torch.bmm(attn_weights.unsqueeze(0), encoder_outputs.unsqueeze(0)) 276 | 277 | output = torch.cat((embedded[0], attn_applied[0]), 1) 278 | output = self.attn_combine(output).unsqueeze(0) 279 | 280 | output = F.relu(output) 281 | output, hidden = self.rnn(output, hidden_state) 282 | 283 | output = F.log_softmax(self.out(output[0]), dim=1) 284 | return output, hidden, attn_weights 285 | 286 | class RNNSeq2Seq(nn.Module): 287 | 288 | def __init__(self, emb_dim, hidden_dim, num_layers, tokenizer, pretrained_file, logger, rnn_type='gru'): 289 | super(RNNSeq2Seq, self).__init__() 290 | self.tokenizer = tokenizer 291 | self.vocab_size = tokenizer.n_words 292 | self.embed_obj = Embedding(tokenizer, emb_dim, pretrained_file, logger) 293 | self.embedding = self.embed_obj.get_embedding() 294 | self.encoder = EncoderRNN(emb_dim, hidden_dim, rnn_type=rnn_type, num_layers=num_layers) 295 | # self.decoder = AttnDecoderRNN(hi) 296 | 297 | class TransformerSeq2Seq(nn.Module): 298 | 299 | def __init__(self, emb_dim, hidden_dim, num_layers, heads, depth_size, filter_size, tokenizer, pretrained_file, 300 | pointer_gen, logger, weight_sharing=True, model_file_path=None, is_eval=False, 301 | load_optim=False, label_smoothing=False, multi_input=False, context_size=2, 302 | attention_fusion_type='mean'): 303 | super(TransformerSeq2Seq, self).__init__() 304 | self.tokenizer = tokenizer 305 | self.vocab_size = tokenizer.n_words 306 | 307 | self.embed_obj = Embedding(tokenizer, emb_dim, pretrained_file, logger) 308 | 309 | self.embedding = self.embed_obj.get_embedding() 310 | self.encoder = Encoder(emb_dim, hidden_dim, num_layers=num_layers, num_heads=heads, total_key_depth=depth_size, 311 | total_value_depth=depth_size, filter_size=filter_size) 312 | 313 | self.decoder = Decoder(emb_dim, hidden_dim, num_layers=num_layers, num_heads=heads, total_key_depth=depth_size, 314 | total_value_depth=depth_size, filter_size=filter_size, multi_input=multi_input, 315 | context_size=context_size, attention_fusion_type=attention_fusion_type) 316 | self.generator = Generator(hidden_dim, self.vocab_size, pointer_gen) 317 | self.pad_id = tokenizer.pad_id 318 | self.n_embeddings = tokenizer.n_words 319 | self.embeddings_size = emb_dim 320 | self.multi_input = multi_input 321 | 322 | if weight_sharing: 323 | # Share the weight matrix between target word embedding & the final logit dense layer 324 | self.generator.proj.weight = self.embedding.weight 325 | 326 | self.criterion = nn.NLLLoss(ignore_index=self.pad_id) 327 | if label_smoothing: 328 | self.criterion = LabelSmoothing(size=self.vocab_size, padding_idx=self.pad_id, smoothing=0.1) 329 | 
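        # Editor's note (added comment): when label smoothing is enabled the smoothed
        # criterion above replaces plain NLL for the training loss, while criterion_ppl
        # below keeps an unsmoothed NLLLoss, presumably so perplexity can still be
        # reported on true log-likelihoods.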
self.criterion_ppl = nn.NLLLoss(ignore_index=self.pad_id) 330 | if is_eval: 331 | self.encoder = self.encoder.eval() 332 | self.decoder = self.decoder.eval() 333 | self.generator = self.generator.eval() 334 | self.embedding = self.embedding.eval() 335 | 336 | def forward(self, input_ids, labels, train=True, return_encoded=False): 337 | label_embeddings = self.embedding(labels) 338 | mask_target = labels.data.eq(self.pad_id).unsqueeze(1) 339 | if self.multi_input: 340 | input_embeddings = [self.embedding(sub_input) for sub_input in input_ids] 341 | mask_enc = [sub_input.data.eq(self.pad_id).unsqueeze(1) for sub_input in input_ids] 342 | encoder_outputs = [self.encoder(sub_embeddings, sub_mask) for sub_embeddings, sub_mask in 343 | zip(input_embeddings, mask_enc)] 344 | pre_logits, attn_dist = self.decoder(label_embeddings, encoder_outputs, mask_enc, mask_target) 345 | logits = self.generator(pre_logits, attn_dist, enc_batch_extend_vocab=input_ids) 346 | else: 347 | input_embeddings = self.embedding(input_ids) 348 | mask_enc = input_ids.data.eq(self.pad_id).unsqueeze(1) 349 | encoder_outputs = self.encoder(input_embeddings, mask_enc) 350 | pre_logits, attn_dist = self.decoder(label_embeddings, encoder_outputs, mask_enc, mask_target) 351 | logits = self.generator(pre_logits, attn_dist, enc_batch_extend_vocab=input_ids) 352 | if train: 353 | shifted_logits = logits[:, :-1, :].contiguous() 354 | shifted_labels = labels[:, 1:].contiguous() 355 | loss = self.criterion(shifted_logits.view(-1, shifted_logits.size(-1)), shifted_labels.view(-1)) 356 | if return_encoded: 357 | return loss, encoder_outputs 358 | return loss 359 | else: 360 | if return_encoded: 361 | return logits, encoder_outputs 362 | return logits 363 | 364 | def _get_proba_with_temperature(self, logits): 365 | if self.bs_temperature != 1: 366 | logits /= self.bs_temperature 367 | return torch.nn.functional.softmax(logits, dim=-1) 368 | 369 | def _get_beam_scores(self, probas, beam_scores, is_end): 370 | skip_mask = None 371 | 372 | if self.bs_nucleus_p > 0: 373 | assert self.annealing_topk is None 374 | 375 | sorted_probas, idxs = torch.sort(probas, descending=True, dim=-1) 376 | skip_mask = torch.cumsum(sorted_probas.cumsum(dim=-1) > self.bs_nucleus_p, dim=-1) > 1 377 | sorted_probas.masked_fill_(skip_mask, 0.0) 378 | _, idxs = torch.sort(idxs, dim=-1) 379 | probas = torch.gather(sorted_probas, -1, idxs) 380 | skip_mask = torch.gather(skip_mask, -1, idxs) 381 | beam_scores = beam_scores.unsqueeze(-1) + torch.log(probas + 1e-20) * (1 - is_end.float().unsqueeze(-1)) 382 | 383 | if skip_mask is not None: 384 | beam_scores.masked_fill_(skip_mask, float('-inf')) 385 | 386 | return beam_scores 387 | 388 | def _length_penalty(self, sequence_lengths): 389 | """https://arxiv.org/abs/1609.08144""" 390 | return (5 + sequence_lengths) ** self.length_penalty_coef / (5 + 1) ** self.length_penalty_coef 391 | 392 | def _sample(self, beam_scores, num_samples, sample_prob=1.): 393 | if random.random() < sample_prob: 394 | beam_probas = torch.nn.functional.softmax(beam_scores, dim=-1) 395 | if self.annealing_topk is not None: 396 | beam_probas, sample_idxs = beam_probas.topk(self.annealing_topk, dim=-1) 397 | idxs = torch.multinomial(beam_probas, num_samples) 398 | idxs = torch.gather(sample_idxs, 1, idxs) 399 | else: 400 | idxs = torch.multinomial(beam_probas, num_samples) 401 | scores = torch.gather(beam_scores, 1, idxs) 402 | else: 403 | scores, idxs = beam_scores.topk(num_samples, dim=-1) 404 | 405 | return scores, idxs 406 | 407 | def 
inference(self, input_ids, encoder_outputs=None): 408 | if self.inference_mode == 'beam': 409 | return self.beam_search(input_ids, encoder_outputs) 410 | elif self.inference_mode == 'sampling': 411 | return self.sampling_inference(input_ids, encoder_outputs) 412 | 413 | def sampling_inference(self, input_ids, encoder_outputs): 414 | with torch.no_grad(): 415 | batch_size = input_ids[0].shape[0] if isinstance(input_ids, list) else input_ids.shape[0] 416 | device = next(self.parameters()).device 417 | scores = torch.zeros(batch_size, self.response_k, device=device) 418 | predicts = [] 419 | if self.multi_input: 420 | mask_enc = [sub_input.data.eq(self.pad_id).unsqueeze(1) for sub_input in input_ids] 421 | if encoder_outputs is None: 422 | input_embeddings = [self.embedding(sub_input) for sub_input in input_ids] 423 | encoder_outputs = [self.encoder(sub_embeddings, sub_mask) for sub_embeddings, sub_mask in 424 | zip(input_embeddings, mask_enc)] 425 | else: 426 | mask_enc = input_ids.data.eq(self.pad_id).unsqueeze(1) 427 | if encoder_outputs is None: 428 | input_embeddings = self.embedding(input_ids) 429 | encoder_outputs = self.encoder(input_embeddings, mask_enc) 430 | for k in range(self.response_k): 431 | prevs = torch.full((batch_size, 1), fill_value=self.talker1_id, dtype=torch.long, device=device) 432 | sample_scores = torch.zeros(batch_size, 1, device=device) 433 | lens = torch.ones(batch_size, 1, dtype=torch.long, device=device) 434 | is_end = torch.zeros(batch_size, 1, dtype=torch.uint8, device=device) 435 | for i in range(self.max_seq_len): 436 | label_embeddings = self.embedding(prevs) 437 | mask_target = prevs.data.eq(self.pad_id).unsqueeze(1) 438 | pre_logits, attn_dist = self.decoder(label_embeddings, encoder_outputs, mask_enc, mask_target) 439 | logits = self.generator(pre_logits, attn_dist, enc_batch_extend_vocab=input_ids)[:, -1:, :] 440 | probs = self._get_proba_with_temperature(logits.float()).squeeze(1) 441 | cur_idxs = torch.multinomial(probs, 1) 442 | prevs = torch.cat([prevs, cur_idxs], 1) 443 | is_end[cur_idxs == self.eos_id] = 1 444 | lens[~is_end] += 1 445 | cur_scores = torch.gather(probs, 1, cur_idxs) 446 | sample_scores += torch.log(cur_scores) 447 | sample_scores /= self._length_penalty(lens.float()) 448 | scores[:, k] = sample_scores.squeeze(1) 449 | cur_predict = [] 450 | for i in range(batch_size): 451 | length = lens[i] 452 | cur_predict.append(prevs[i, 1: length].tolist()) 453 | predicts.append(cur_predict) 454 | best_idx = scores.argmax(dim=1) 455 | final_predicts = [] 456 | for i in range(batch_size): 457 | final_predicts.append(predicts[best_idx[i]][i]) 458 | return final_predicts 459 | 460 | def beam_search(self, input_ids, encoder_outputs): 461 | with torch.no_grad(): 462 | batch_size = input_ids[0].shape[0] if isinstance(input_ids, list) else input_ids.shape[0] 463 | device = next(self.parameters()).device 464 | prevs = torch.full((batch_size * self.beam_size, 1), fill_value=self.talker1_id, dtype=torch.long, 465 | device=device) 466 | if self.multi_input: 467 | mask_enc = [sub_input.data.eq(self.pad_id).unsqueeze(1) for sub_input in input_ids] 468 | if encoder_outputs is None: 469 | input_embeddings = [self.embedding(sub_input) for sub_input in input_ids] 470 | encoder_outputs = [self.encoder(sub_embeddings, sub_mask) for sub_embeddings, sub_mask in 471 | zip(input_embeddings, mask_enc)] 472 | encoder_outputs = [repeat_along_dim1(sub_outputs, self.beam_size) for sub_outputs in encoder_outputs] 473 | beam_input_ids = [repeat_along_dim1(sub_input_ids, 
self.beam_size) for sub_input_ids in input_ids] 474 | beam_mask_enc = [sub_beam_input_ids.data.eq(self.pad_id).unsqueeze(1) for sub_beam_input_ids in 475 | beam_input_ids] 476 | else: 477 | mask_enc = input_ids.data.eq(self.pad_id).unsqueeze(1) 478 | if encoder_outputs is None: 479 | input_embeddings = self.embedding(input_ids) 480 | encoder_outputs = self.encoder(input_embeddings, mask_enc) 481 | encoder_outputs = repeat_along_dim1(encoder_outputs, self.beam_size) 482 | beam_input_ids = repeat_along_dim1(input_ids, self.beam_size) 483 | beam_mask_enc = beam_input_ids.data.eq(self.pad_id).unsqueeze(1) 484 | 485 | beam_scores = torch.zeros(batch_size, self.beam_size, device=device) 486 | beam_lens = torch.ones(batch_size, self.beam_size, dtype=torch.long, device=device) 487 | is_end = torch.zeros(batch_size, self.beam_size, dtype=torch.uint8, device=device) 488 | 489 | current_sample_prob = 1 490 | group_size = self.beam_size // self.diversity_groups 491 | diversity_penalty = torch.zeros((batch_size, self.n_embeddings), device=device) 492 | 493 | for i in range(self.max_seq_len): 494 | label_embeddings = self.embedding(prevs) 495 | mask_target = prevs.data.eq(self.pad_id).unsqueeze(1) 496 | pre_logits, attn_dist = self.decoder(label_embeddings, encoder_outputs, beam_mask_enc, mask_target) 497 | logits = self.generator(pre_logits, attn_dist, enc_batch_extend_vocab=beam_input_ids)[:, -1:, :] 498 | 499 | probs = self._get_proba_with_temperature(logits.float()) 500 | probs = probs.view(batch_size, self.beam_size, -1) 501 | 502 | beam_scores = self._get_beam_scores(probs, beam_scores, is_end) 503 | penalty = self._length_penalty(beam_lens.float() + 1 - is_end.float()).unsqueeze(-1) 504 | beam_scores = beam_scores / penalty 505 | 506 | if i == 0: 507 | penalty = penalty[:, 0, :] 508 | beam_scores = beam_scores[:, 0, :] 509 | 510 | beam_scores, idxs = beam_scores.topk(self.beam_size, dim=-1) 511 | beam_idxs = torch.zeros((batch_size, self.beam_size), dtype=torch.long, device=device) 512 | else: 513 | penalty = penalty.view(batch_size, self.diversity_groups, group_size, -1) 514 | beam_scores = beam_scores.view(batch_size, self.diversity_groups, group_size, -1) 515 | 516 | all_scores, all_idxs = [], [] 517 | for g in range(self.diversity_groups): 518 | g_beam_scores = beam_scores[:, g, :, :] 519 | g_penalty = penalty[:, g, :, :] 520 | g_beam_scores -= self.diversity_coef * diversity_penalty.unsqueeze(1) / g_penalty 521 | g_beam_scores = g_beam_scores.view(batch_size, -1) 522 | 523 | g_scores, g_idxs = self._sample(g_beam_scores, group_size, sample_prob=current_sample_prob) 524 | g_idxs += g * group_size * self.n_embeddings 525 | 526 | all_scores.append(g_scores) 527 | all_idxs.append(g_idxs) 528 | 529 | diversity_penalty.scatter_add_(1, 530 | torch.fmod(g_idxs, self.n_embeddings), 531 | torch.ones((batch_size, group_size), device=device)) 532 | 533 | diversity_penalty.fill_(0) 534 | penalty = penalty.view(batch_size, -1) 535 | beam_scores = torch.cat(all_scores, dim=-1) 536 | idxs = torch.cat(all_idxs, dim=-1) 537 | 538 | beam_idxs = (idxs.float() / self.n_embeddings).long() 539 | 540 | sym_idxs = torch.fmod(idxs, probs.shape[-1]) 541 | is_end = torch.gather(is_end, 1, beam_idxs) 542 | beam_lens = torch.gather(beam_lens, 1, beam_idxs) 543 | 544 | sym_idxs[is_end] = self.padding_idx 545 | beam_lens[~is_end] += 1 546 | is_end[sym_idxs == self.eos_id] = 1 547 | 548 | sym_idxs = sym_idxs.view(batch_size * self.beam_size, 1) 549 | prevs = prevs.view(batch_size, self.beam_size, -1) 550 | prevs = 
torch.gather(prevs, 1, beam_idxs.unsqueeze(-1).repeat(1, 1, prevs.shape[-1])) 551 | prevs = prevs.view(batch_size * self.beam_size, -1) 552 | prevs = torch.cat([prevs, sym_idxs], dim=1) 553 | 554 | if all(is_end.view(-1)): 555 | break 556 | 557 | beam_scores *= penalty 558 | current_sample_prob *= self.annealing 559 | 560 | predicts = [] 561 | result = prevs.view(batch_size, self.beam_size, -1) 562 | 563 | if self.sample: 564 | probs = torch.nn.functional.softmax(beam_scores, dim=-1) 565 | bests = torch.multinomial(probs, 1).view(-1) 566 | else: 567 | bests = beam_scores.argmax(dim=-1) 568 | 569 | for i in range(batch_size): 570 | best_len = beam_lens[i, bests[i]] 571 | best_seq = result[i, bests[i], 1:best_len - 1] 572 | predicts.append(best_seq.tolist()) 573 | 574 | return predicts 575 | -------------------------------------------------------------------------------- /model/seq2seq_vocab.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | class Seq2seqTokenizer: 5 | def __init__(self): 6 | self.word2idx = {"": 0, "": 1, "": 2, "": 3, '': 4, '': 5} 7 | self.idx2word = {0: "", 1: "", 2: "", 3: "", 4: "", 5: ""} 8 | self.n_words = 6 9 | self.all_special_ids = [0, 1, 2, 3, 4, 5] 10 | self.pad_id = 1 11 | self.bos_id = 2 12 | self.eos_id = 3 13 | self.talker1_bos_id = 4 14 | self.talker2_bos_id = 5 15 | 16 | def tokenize(self, str): 17 | res = str.strip().split(' ') 18 | res = [x.lower() for x in res] 19 | return res 20 | 21 | def encode(self, tokenized_str): 22 | res = [] 23 | for token in tokenized_str: 24 | if self.word2idx.__contains__(token): 25 | res.append(self.word2idx[token]) 26 | return res 27 | 28 | def decode(self, ids, skip_special_tokens=True, clean_up_tokenization_spaces=False): 29 | res = [] 30 | for id in ids: 31 | if skip_special_tokens and id in self.all_special_ids: 32 | continue 33 | res.append(self.idx2word[id]) 34 | text = ' '.join(res) 35 | return text 36 | 37 | def index_words(self, sentence): 38 | for word in sentence.split(' '): 39 | self.index_word(word) 40 | 41 | def index_word(self, word): 42 | if not self.word2idx.__contains__(word): 43 | self.word2idx[word] = self.n_words 44 | self.idx2word[self.n_words] = word 45 | self.n_words += 1 46 | 47 | class Seq2seqVocab: 48 | def __init__(self, train_dataset_path, valid_dataset_path, test_dataset_path, vocab_path, data_type='persona'): 49 | if (os.path.exists(vocab_path)): 50 | with open(vocab_path, 'rb') as f: 51 | cached_data = pickle.load(f) 52 | self.vocab = cached_data[0] 53 | self.all_data = cached_data[1] 54 | else: 55 | self.vocab = Seq2seqTokenizer() 56 | if data_type == 'persona': 57 | self.all_data = self.parse_data(train_dataset_path, valid_dataset_path, test_dataset_path) 58 | elif data_type == 'emoji': 59 | self.all_data = self.parse_data_emoji(train_dataset_path, valid_dataset_path, test_dataset_path) 60 | self.parse_vocab(self.all_data, self.vocab) 61 | with open(vocab_path, 'wb') as f: 62 | pickle.dump([self.vocab, self.all_data], f) 63 | 64 | def parse_data(self, train_dataset_path, valid_dataset_path, test_dataset_path): 65 | subsets = [train_dataset_path, valid_dataset_path, test_dataset_path] 66 | all_data = [] 67 | for subset in subsets: 68 | data = [] 69 | with open(subset, 'r') as f: 70 | for line in f.readlines(): 71 | line = line.strip() 72 | space_idx = line.find(' ') 73 | if space_idx == -1: 74 | dialog_idx = int(line) 75 | else: 76 | dialog_idx = int(line[:space_idx]) 77 | 78 | if int(dialog_idx) == 1: 79 | 
data.append({'persona_info': [], 'dialog': [], 'candidates': []}) 80 | 81 | dialog_line = line[space_idx + 1:].split('\t') 82 | dialog_line = [l.strip() for l in dialog_line] 83 | 84 | if dialog_line[0].startswith('your persona:'): 85 | persona_info = dialog_line[0].replace('your persona: ', '') 86 | persona_info = persona_info.replace('.', ' .') 87 | data[-1]['persona_info'].append(persona_info) 88 | 89 | elif len(dialog_line) > 1: 90 | data[-1]['dialog'].append(dialog_line[0]) 91 | data[-1]['dialog'].append(dialog_line[1]) 92 | if len(dialog_line) == 4: 93 | data[-1]['candidates'].append(dialog_line[3].split('|')[:-1]) 94 | 95 | all_data.append(data) 96 | return all_data 97 | 98 | def parse_data_emoji(self, train_dataset_path, valid_dataset_path, test_dataset_path): 99 | subsets = [train_dataset_path, valid_dataset_path, test_dataset_path] 100 | all_data = [] 101 | for subset in subsets: 102 | data = [] 103 | with open(subset, 'r') as f: 104 | for line in f.readlines(): 105 | line = line.strip() 106 | items = line.split('\t') 107 | data.append({'persona_info': [], 'dialog': [], 'candidates': []}) 108 | data[-1]['persona_info'].append(items[0]) 109 | data[-1]['dialog'].append(items[1]) 110 | data[-1]['dialog'].append(items[2]) 111 | all_data.append(data) 112 | return all_data 113 | 114 | def parse_vocab(self, all_data, vocab): 115 | for data in all_data: 116 | for p in data: 117 | for s in p['persona_info']: 118 | vocab.index_words(s) 119 | for s in p['dialog']: 120 | vocab.index_words(s) 121 | for c in p['candidates']: 122 | for s in c: 123 | vocab.index_words(s) -------------------------------------------------------------------------------- /model/transformer_model.py: -------------------------------------------------------------------------------- 1 | # transformer_chatbot 2 | # Copyright (C) 2018 Golovanov, Tselousov 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 
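# Editor's note (added, not in the original file): apex_model() below is a thin wrapper
# around apex.amp.initialize for mixed-precision training. A hedged usage sketch, with
# illustrative argument values:
#
#   model, optimizer = apex_model(model, apex_level='O1', optimizer=optimizer,
#                                 apex_loss_scale='dynamic', num_losses=4)
#
# With apex_level=None the call is a no-op and simply returns the model (and the
# optimizer, if one was passed).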
16 | 17 | import logging 18 | import random 19 | 20 | import torch 21 | import torch.nn as nn 22 | import torch.nn.functional as F 23 | 24 | from .transformer_module import TransformerModule 25 | from .utils import repeat_along_dim1 26 | from copy import deepcopy 27 | 28 | logger = logging.getLogger(__file__) 29 | 30 | 31 | def apex_model(model, *, apex_level=None, optimizer=None, apex_loss_scale=None, num_losses=4): 32 | if apex_level is not None: 33 | assert apex_level == 'O0' or model.sparse_embeddings == False, 'Apex doesn\'t support sparse tensors' 34 | 35 | try: 36 | from apex.amp import initialize 37 | except ImportError: 38 | raise ImportError("Please install apex.") 39 | 40 | return initialize(model, optimizer, opt_level=apex_level, loss_scale=apex_loss_scale, num_losses=num_losses) 41 | 42 | return model if optimizer is None else (model, optimizer) 43 | 44 | 45 | class MultipleChoiceHead(nn.Module): 46 | """ Classifier Head for the transformer """ 47 | 48 | def __init__(self, in_features, dropout): 49 | super(MultipleChoiceHead, self).__init__() 50 | self.dropout = nn.Dropout(dropout) 51 | self.linear = nn.Linear(in_features, 1) 52 | 53 | self._init_weights() 54 | 55 | def _init_weights(self): 56 | nn.init.normal_(self.linear.weight, std=0.02) 57 | nn.init.normal_(self.linear.bias, 0) 58 | 59 | def forward(self, hidden_state, padding_mask): 60 | # Get classification logits as the last logit and apply a Linear layer on them 61 | # hidden_state (bsz, seq_length, hidden_size) 62 | # padding_mask (bsz, seq_length) 63 | last_token_idx = torch.sum(~padding_mask, dim=-1) - 1 # (bsz) 64 | last_token_idx = last_token_idx.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, hidden_state.size(-1)) # (bsz, 1, hidden_size) 65 | multiple_choice_h = hidden_state.gather(dim=-2, index=last_token_idx).squeeze(-2) # (bsz, hidden_size) 66 | multiple_choice_logits = self.linear(multiple_choice_h).squeeze(-1) # (bsz) 67 | return multiple_choice_logits 68 | 69 | 70 | class TransformerModel(nn.Module): 71 | def __init__(self, n_layers, n_embeddings, n_pos_embeddings, embeddings_size, 72 | padding_idx, n_heads, dropout, embed_dropout, attn_dropout, ff_dropout, 73 | bos_id, eos_id, sent_dialog_id, max_seq_len=256, beam_size=5, sample=False, 74 | length_penalty=0.8, annealing_topk=None, annealing=0, normalize_embeddings=True, 75 | diversity_coef=0, diversity_groups=1, n_segments=None, multiple_choice_head=False, 76 | single_input=False, dialog_embeddings=False, vocab=None, constant_embedding=False, 77 | share_models=True, successive_attention=False, sparse_embeddings=False, 78 | shared_attention=True, context_size=2, bs_temperature=1, bs_nucleus_p=0, same_embedding_lm=1, 79 | model_class=TransformerModule): 80 | 81 | super(TransformerModel, self).__init__() 82 | 83 | self.n_embeddings = n_embeddings 84 | self.n_pos_embeddings = n_pos_embeddings 85 | self.embeddings_size = embeddings_size 86 | self.sparse_embeddings = sparse_embeddings 87 | 88 | self.bos_id = bos_id 89 | self.padding_idx = padding_idx 90 | self.eos_id = eos_id 91 | self.sent_dialog_id = sent_dialog_id 92 | 93 | self.max_seq_len = max_seq_len 94 | self.beam_size = beam_size 95 | self.sample = sample 96 | self.length_penalty_coef = length_penalty 97 | self.annealing = annealing 98 | self.annealing_topk = annealing_topk 99 | self.diversity_coef = diversity_coef 100 | self.diversity_groups = diversity_groups 101 | 102 | self.single_input = single_input 103 | self.dialog_embeddings = dialog_embeddings 104 | self.share_models = share_models 105 | 106 | 
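        # Editor's note (added comment): bs_temperature rescales the logits before the
        # softmax used during beam search / sampling, and bs_nucleus_p > 0 switches on
        # nucleus (top-p) filtering of candidate tokens in _get_beam_scores().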
self.bs_temperature = bs_temperature 107 | self.bs_nucleus_p = bs_nucleus_p 108 | 109 | self.vocab = vocab 110 | 111 | self.transformer_module = model_class(n_layers, n_embeddings, n_pos_embeddings, embeddings_size, 112 | padding_idx, n_heads, dropout, embed_dropout, attn_dropout, 113 | ff_dropout, normalize_embeddings, n_segments, 114 | constant_embedding=constant_embedding, 115 | successive_attention=successive_attention, 116 | sparse_embeddings=sparse_embeddings, 117 | shared_attention=shared_attention, 118 | context_size=context_size) 119 | if not share_models: 120 | self.encoder_module = model_class(n_layers, n_embeddings, n_pos_embeddings, embeddings_size, 121 | padding_idx, n_heads, dropout, embed_dropout, attn_dropout, 122 | ff_dropout, normalize_embeddings, n_segments, 123 | constant_embedding=constant_embedding, 124 | successive_attention=successive_attention, 125 | sparse_embeddings=sparse_embeddings, 126 | shared_attention=shared_attention) 127 | self.pre_softmax = nn.Linear(embeddings_size, n_embeddings, bias=False) 128 | if same_embedding_lm == 1: 129 | self.pre_softmax.weight = self.transformer_module.embeddings.weight 130 | if same_embedding_lm == -1: 131 | self.pre_softmax.weight = deepcopy(self.transformer_module.embeddings.weight) 132 | self.multiple_choice_head = MultipleChoiceHead(self.embeddings_size, dropout) if multiple_choice_head else None 133 | 134 | # def distribute(self, device, device_ids): 135 | # try: 136 | # from torch.nn.parallel import DistributedDataParallel 137 | # except ImportError: 138 | # raise ImportError("Please install apex.") 139 | # 140 | # def _distribute(model): 141 | # return DistributedDataParallel(model.to(device), device_ids=device_ids, find_unused_parameters=True, broadcast_buffers=False) 142 | # 143 | # self.pre_softmax = _distribute(self.pre_softmax) 144 | # self.transformer_module = _distribute(self.transformer_module) 145 | # if hasattr(self, 'encoder_module'): 146 | # self.encoder_module = _distribute(self.encoder_module) 147 | # self.multiple_choice_head = _distribute(self.multiple_choice_head) \ 148 | # if self.multiple_choice_head is not None else None 149 | 150 | def distribute(self, device): 151 | try: 152 | from apex.parallel import DistributedDataParallel, convert_syncbn_model 153 | except ImportError: 154 | raise ImportError("Please install apex.") 155 | 156 | def _distributed(module): 157 | return DistributedDataParallel(convert_syncbn_model(module)) 158 | 159 | self.transformer_module = _distributed(self.transformer_module.to(device)) 160 | if hasattr(self, 'encoder_module'): 161 | self.encoder_module = _distributed(self.encoder_module.to(device)) 162 | self.pre_softmax = _distributed(self.pre_softmax.to(device)) 163 | self.multiple_choice_head = _distributed(self.multiple_choice_head.to(device)) \ 164 | if self.multiple_choice_head is not None else None 165 | 166 | def state_dict(self): 167 | state_dict = {} 168 | for k in dir(self): 169 | module = getattr(self, k) 170 | if isinstance(module, nn.Module): 171 | if isinstance(module, nn.parallel.DistributedDataParallel): 172 | module = module.module 173 | 174 | state_dict[k] = module.state_dict() 175 | 176 | return state_dict 177 | 178 | def load_state_dict(self, state_dict, strict=True): 179 | for k, v in state_dict.items(): 180 | assert hasattr(self, k), f'Model does not have {k} submodule' 181 | module = getattr(self, k) 182 | if isinstance(module, nn.parallel.DistributedDataParallel): 183 | module = module.module 184 | module.load_state_dict(v, strict) 185 | 186 | def 
forward(self, persona, dialog): 187 | enc_persona, enc_persona_generated = None, None 188 | if persona is not None: 189 | enc_persona, enc_persona_mask = self.encode(persona) 190 | enc_persona_generated = self.pre_softmax(enc_persona) 191 | enc_dialog, enc_dialog_mask = self.encode(dialog) 192 | enc_dialog_generated = self.pre_softmax(enc_dialog) 193 | return enc_persona, enc_persona_generated, enc_dialog, enc_dialog_generated 194 | 195 | def encode(self, x): 196 | " Returns a tuple(x, padding_mask)" 197 | x, padding_mask, _ = self.transformer_module(x) if self.share_models else self.encoder_module(x) 198 | return x, padding_mask 199 | 200 | def generate(self, enc_x): 201 | return self.pre_softmax(enc_x) 202 | 203 | def classify(self, x, padding_mask): 204 | return self.multiple_choice_head(x, padding_mask) 205 | 206 | def decode_classify(self, x, enc_contexts=[]): 207 | x, padding_mask, _ = self.transformer_module(x, enc_contexts) 208 | return self.classify(x, padding_mask) 209 | 210 | def decode(self, x, enc_contexts=[]): 211 | x, _, _ = self.transformer_module(x, enc_contexts) 212 | return self.generate(x) 213 | 214 | def predict(self, contexts=[]): 215 | if self.single_input: 216 | assert isinstance(contexts, torch.Tensor) 217 | enc_contexts = [] 218 | beam_starts = contexts 219 | else: 220 | enc_contexts = [self.encode(c) for c in contexts] 221 | beam_starts = None 222 | prediction = self.beam_search(enc_contexts=enc_contexts, beam_starts=beam_starts) 223 | 224 | return prediction 225 | 226 | def _length_penalty(self, sequence_lengths): 227 | """https://arxiv.org/abs/1609.08144""" 228 | return (5 + sequence_lengths) ** self.length_penalty_coef / (5 + 1) ** self.length_penalty_coef 229 | 230 | def _get_proba_with_temperature(self, logits): 231 | if self.bs_temperature != 1: 232 | logits /= self.bs_temperature 233 | 234 | return F.softmax(logits, dim=-1) 235 | 236 | def _get_beam_scores(self, probas, beam_scores, is_end): 237 | skip_mask = None 238 | 239 | if self.bs_nucleus_p > 0: 240 | assert self.annealing_topk is None 241 | 242 | sorted_probas, idxs = torch.sort(probas, descending=True, dim=-1) 243 | skip_mask = torch.cumsum(sorted_probas.cumsum(dim=-1) > self.bs_nucleus_p, dim=-1) > 1 244 | sorted_probas.masked_fill_(skip_mask, 0.0) 245 | _, idxs = torch.sort(idxs, dim=-1) 246 | probas = torch.gather(sorted_probas, -1, idxs) 247 | skip_mask = torch.gather(skip_mask, -1, idxs) 248 | 249 | beam_scores = beam_scores.unsqueeze(-1) + torch.log(probas) * (1 - is_end.float().unsqueeze(-1)) 250 | 251 | if skip_mask is not None: 252 | beam_scores.masked_fill_(skip_mask, float('-inf')) 253 | 254 | return beam_scores 255 | 256 | def _sample(self, beam_scores, num_samples, sample_prob=1.): 257 | if random.random() < sample_prob: 258 | beam_probas = F.softmax(beam_scores, dim=-1) 259 | if self.annealing_topk is not None: 260 | beam_probas, sample_idxs = beam_probas.topk(self.annealing_topk, dim=-1) 261 | idxs = torch.multinomial(beam_probas, num_samples) 262 | idxs = torch.gather(sample_idxs, 1, idxs) 263 | else: 264 | idxs = torch.multinomial(beam_probas, num_samples) 265 | scores = torch.gather(beam_scores, 1, idxs) 266 | else: 267 | scores, idxs = beam_scores.topk(num_samples, dim=-1) 268 | 269 | return scores, idxs 270 | 271 | def _fix_past(self, past, beam_idxs): 272 | for layer_output in past: 273 | for context in layer_output: 274 | for v in context: 275 | size_ = v.size() 276 | tile_size = size_[-2] * size_[-1] 277 | new_v = v.contiguous().view(-1, self.beam_size, tile_size) 278 | new_v 
= new_v.gather(1, beam_idxs.unsqueeze(-1).repeat([1, 1, tile_size])) 279 | v[...] = new_v.view(*size_) 280 | return past 281 | 282 | def beam_search(self, enc_contexts=[], return_beams=False, beam_starts=None): 283 | with torch.no_grad(): 284 | if len(enc_contexts) == 0 and beam_starts is None: 285 | return [] 286 | 287 | batch_size = enc_contexts[0][0].shape[0] if beam_starts is None else beam_starts.shape[0] 288 | device = next(self.parameters()).device 289 | 290 | prevs = torch.full((batch_size * self.beam_size, 1), fill_value=self.bos_id, dtype=torch.long, device=device) 291 | 292 | beam_scores = torch.zeros(batch_size, self.beam_size, device=device) 293 | beam_lens = torch.ones(batch_size, self.beam_size, dtype=torch.long, device=device) 294 | is_end = torch.zeros(batch_size, self.beam_size, dtype=torch.uint8, device=device) 295 | 296 | if beam_starts is not None: 297 | beam_starts = repeat_along_dim1(beam_starts, self.beam_size) 298 | beam_enc_contexts = repeat_along_dim1(enc_contexts, self.beam_size) 299 | 300 | current_sample_prob = 1 301 | group_size = self.beam_size // self.diversity_groups 302 | diversity_penalty = torch.zeros((batch_size, self.n_embeddings), device=device) 303 | past = None 304 | 305 | max_seq_len = min(self.n_pos_embeddings - prevs.shape[1] - (beam_starts.shape[1] if beam_starts is not None else 0), 306 | self.max_seq_len) 307 | 308 | for i in range(max_seq_len): 309 | inputs = prevs[:, -1:, ...] # only use the last token (rest is in past) 310 | if self.dialog_embeddings and inputs.dim() < 3: 311 | inputs = torch.stack((inputs, torch.full_like(inputs, self.sent_dialog_id)), dim=inputs.dim()) 312 | if i == 0 and beam_starts is not None: 313 | inputs = torch.cat((beam_starts, inputs), dim=1) 314 | 315 | outputs, _, past = self.transformer_module(inputs, beam_enc_contexts, past=past) 316 | 317 | logits = self.generate(outputs[:, -1, :]) 318 | 319 | probs = self._get_proba_with_temperature(logits.float()) 320 | probs = probs.view(batch_size, self.beam_size, -1) 321 | 322 | beam_scores = self._get_beam_scores(probs, beam_scores, is_end) 323 | penalty = self._length_penalty(beam_lens.float() + 1 - is_end.float()).unsqueeze(-1) 324 | beam_scores = beam_scores / penalty 325 | 326 | if i == 0: 327 | penalty = penalty[:, 0, :] 328 | beam_scores = beam_scores[:, 0, :] 329 | 330 | beam_scores, idxs = beam_scores.topk(self.beam_size, dim=-1) 331 | beam_idxs = torch.zeros((batch_size, self.beam_size), dtype=torch.long, device=device) 332 | else: 333 | penalty = penalty.view(batch_size, self.diversity_groups, group_size, -1) 334 | beam_scores = beam_scores.view(batch_size, self.diversity_groups, group_size, -1) 335 | 336 | all_scores, all_idxs = [], [] 337 | for g in range(self.diversity_groups): 338 | g_beam_scores = beam_scores[:, g, :, :] 339 | g_penalty = penalty[:, g, :, :] 340 | g_beam_scores -= self.diversity_coef * diversity_penalty.unsqueeze(1) / g_penalty 341 | g_beam_scores = g_beam_scores.view(batch_size, -1) 342 | 343 | g_scores, g_idxs = self._sample(g_beam_scores, group_size, sample_prob=current_sample_prob) 344 | g_idxs += g * group_size * self.n_embeddings 345 | 346 | all_scores.append(g_scores) 347 | all_idxs.append(g_idxs) 348 | 349 | diversity_penalty.scatter_add_(1, 350 | torch.fmod(g_idxs, self.n_embeddings), 351 | torch.ones((batch_size, group_size), device=device)) 352 | 353 | diversity_penalty.fill_(0) 354 | penalty = penalty.view(batch_size, -1) 355 | beam_scores = torch.cat(all_scores, dim=-1) 356 | idxs = torch.cat(all_idxs, dim=-1) 357 | 358 | 
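                    # Editor's note (added comment): each sampled flat index encodes a
                    # (parent beam, token id) pair: the parent beam is recovered as
                    # idxs // n_embeddings and the token id as idxs mod vocab size
                    # just below.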
beam_idxs = (idxs.float() / self.n_embeddings).long() 359 | 360 | sym_idxs = torch.fmod(idxs, probs.shape[-1]) 361 | is_end = torch.gather(is_end, 1, beam_idxs) 362 | beam_lens = torch.gather(beam_lens, 1, beam_idxs) 363 | 364 | if self.vocab is not None: 365 | logger.info('\nbeams:\n' + '\n'.join(self.vocab.ids2string(t.detach().cpu().tolist()) for t in prevs)) 366 | logger.info('\ntop-options:\n' + '\n'.join(self.vocab.ids2string(t.detach().cpu().tolist()) 367 | + str(bi.detach().cpu().tolist()) for t, bi in zip(sym_idxs, beam_idxs))) 368 | 369 | sym_idxs[is_end] = self.padding_idx 370 | beam_lens[~is_end] += 1 371 | is_end[sym_idxs == self.eos_id] = 1 372 | 373 | sym_idxs = sym_idxs.view(batch_size * self.beam_size, 1) 374 | prevs = prevs.view(batch_size, self.beam_size, -1) 375 | prevs = torch.gather(prevs, 1, beam_idxs.unsqueeze(-1).repeat(1, 1, prevs.shape[-1])) 376 | prevs = prevs.view(batch_size * self.beam_size, -1) 377 | prevs = torch.cat([prevs, sym_idxs], dim=1) 378 | 379 | past = self._fix_past(past, beam_idxs) 380 | 381 | if all(is_end.view(-1)): 382 | break 383 | 384 | beam_scores *= penalty 385 | current_sample_prob *= self.annealing 386 | 387 | predicts = [] 388 | result = prevs.view(batch_size, self.beam_size, -1) 389 | 390 | if return_beams: 391 | return result, beam_lens 392 | 393 | if self.sample: 394 | probs = F.softmax(beam_scores, dim=-1) 395 | bests = torch.multinomial(probs, 1).view(-1) 396 | else: 397 | bests = beam_scores.argmax(dim=-1) 398 | 399 | for i in range(batch_size): 400 | best_len = beam_lens[i, bests[i]] 401 | best_seq = result[i, bests[i], 1:best_len-1] 402 | predicts.append(best_seq.tolist()) 403 | 404 | return predicts 405 | -------------------------------------------------------------------------------- /model/transformer_module.py: -------------------------------------------------------------------------------- 1 | # transformer_chatbot 2 | # Copyright (C) 2018 Golovanov, Tselousov 3 | # 4 | # This program is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU Affero General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # This program is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU Affero General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU Affero General Public License 15 | # along with this program. If not, see . 
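# Editor's note (added, not in the original file): ConstantPositionalEmbedding below
# precomputes fairseq-style sinusoidal position vectors: for position p the first half
# of each vector holds sin(p * w_i) and the second half cos(p * w_i), with
# w_i = exp(-i * log(10000) / (half_dim - 1)). The cached buffer is grown lazily in
# forward() whenever longer positions than previously seen are requested.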
16 | 17 | import logging 18 | import math 19 | 20 | import torch 21 | import torch.nn as nn 22 | import torch.nn.functional as F 23 | 24 | from .utils import checkpoint_sequential 25 | 26 | logger = logging.getLogger(__file__) 27 | 28 | try: 29 | from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm 30 | except ImportError: 31 | print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.") 32 | from torch.nn import LayerNorm 33 | 34 | 35 | class ConstantPositionalEmbedding(nn.Module): 36 | def __init__(self, embedding_dim, padding_idx): 37 | super(ConstantPositionalEmbedding, self).__init__() 38 | 39 | self.embedding_dim = embedding_dim 40 | self.padding_idx = padding_idx 41 | self.register_buffer('_position_embedding', 42 | ConstantPositionalEmbedding.get_embedding(1024, 43 | self.embedding_dim)) 44 | 45 | @classmethod 46 | def get_embedding(cls, seq_len, embedding_dim, device=None): 47 | seq_len += 1 48 | 49 | half_dim = embedding_dim // 2 50 | 51 | emb = math.log(10000) / (half_dim - 1) 52 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=device) * -emb) 53 | emb = torch.arange(seq_len, dtype=torch.float32, device=device).unsqueeze(1) * emb.unsqueeze(0) 54 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(seq_len, -1) 55 | 56 | if embedding_dim % 2: 57 | emb = torch.cat([emb, torch.zeros(seq_len, 1)], dim=1) 58 | 59 | return emb 60 | 61 | def forward(self, positions): 62 | batch_size, seq_len = positions.size() 63 | 64 | cur_seq_len = max(seq_len, torch.max(positions).item()) 65 | 66 | if cur_seq_len >= self._position_embedding.size(0): 67 | self._position_embedding = ConstantPositionalEmbedding.get_embedding(cur_seq_len, 68 | self.embedding_dim, 69 | positions.device) 70 | 71 | return self._position_embedding.index_select(0, positions.view(-1)).view(batch_size, seq_len, -1) 72 | 73 | 74 | class MultiheadAttention(nn.Module): 75 | @classmethod 76 | def _get_future_mask(cls, size, device): 77 | nd, ns = size 78 | max_size = max(nd, ns) 79 | if not hasattr(cls, '_future_mask') or cls._future_mask.device != device or any(s. 16 | 17 | import copy 18 | import io 19 | import json 20 | import logging 21 | import os 22 | import random 23 | import re 24 | import sys 25 | from collections import Counter 26 | 27 | import numpy as np 28 | import torch 29 | import torch.nn as nn 30 | from scipy.interpolate import RectBivariateSpline 31 | from torch.utils.checkpoint import checkpoint 32 | 33 | py_version = sys.version.split('.')[0] 34 | if py_version == '2': 35 | open = io.open 36 | unicode = unicode 37 | else: 38 | unicode = str 39 | open = open 40 | 41 | 42 | def set_seed(seed): 43 | torch.manual_seed(seed) 44 | torch.cuda.manual_seed(seed) 45 | random.seed(seed) 46 | 47 | 48 | def repeat_along_dim1(obj, repetitions): 49 | """ repeat (a possibly nested object of) tensors from (batch, ...) to (batch * repetitions, ...) 
""" 50 | if isinstance(obj, tuple): 51 | return tuple(repeat_along_dim1(o, repetitions) for o in obj) 52 | if isinstance(obj, list): 53 | return list(repeat_along_dim1(o, repetitions) for o in obj) 54 | 55 | obj = obj.unsqueeze(1).repeat([1, repetitions] + [1] * len(obj.size()[1:])) 56 | return obj.view(-1, *obj.size()[2:]) 57 | 58 | 59 | def pad_sequence(sequences, batch_first=False, padding_value=0, left=False): 60 | # assuming trailing dimensions and type of all the Tensors 61 | # in sequences are same and fetching those from sequences[0] 62 | if not len(sequences): 63 | return torch.empty(0) 64 | trailing_dims = sequences[0].size()[1:] 65 | max_len = max([s.size(0) for s in sequences]) 66 | if batch_first: 67 | out_dims = (len(sequences), max_len) + trailing_dims 68 | else: 69 | out_dims = (max_len, len(sequences)) + trailing_dims 70 | 71 | out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value) 72 | for i, tensor in enumerate(sequences): 73 | length = tensor.size(0) 74 | s_slice = slice(-length, None) if left else slice(None, length) 75 | s_slice = (i, s_slice) if batch_first else (s_slice, i) 76 | out_tensor[s_slice] = tensor 77 | 78 | return out_tensor 79 | 80 | 81 | def checkpoint_sequential(functions, segments, *inputs): 82 | def run_function(start, end, functions): 83 | def forward(*inputs): 84 | for j in range(start, end + 1): 85 | inputs = functions[j](*inputs) 86 | return inputs 87 | return forward 88 | 89 | if isinstance(functions, torch.nn.Sequential): 90 | functions = list(functions.children()) 91 | 92 | segment_size = len(functions) // segments 93 | # the last chunk has to be non-volatile 94 | end = -1 95 | for start in range(0, segment_size * (segments - 1), segment_size): 96 | end = start + segment_size - 1 97 | inputs = checkpoint(run_function(start, end, functions), *inputs) 98 | if not isinstance(inputs, tuple): 99 | inputs = (inputs,) 100 | return run_function(end + 1, len(functions) - 1, functions)(*inputs) 101 | 102 | 103 | def f1_score(predictions, targets, average=True): 104 | def f1_score_items(pred_items, gold_items): 105 | common = Counter(gold_items) & Counter(pred_items) 106 | num_same = sum(common.values()) 107 | 108 | if num_same == 0: 109 | return 0 110 | 111 | precision = num_same / len(pred_items) 112 | recall = num_same / len(gold_items) 113 | f1 = (2 * precision * recall) / (precision + recall) 114 | 115 | return f1 116 | 117 | scores = [f1_score_items(p, t) for p, t in zip(predictions, targets)] 118 | 119 | if average: 120 | return sum(scores) / len(scores) 121 | 122 | return scores 123 | 124 | 125 | def openai_transformer_config(): 126 | class dotdict(dict): 127 | __getattr__ = dict.get 128 | __setattr__ = dict.__setitem__ 129 | __delattr__ = dict.__delitem__ 130 | 131 | cfg = dotdict({'n_layers': 12, 'n_embeddings': 40477, 'n_pos_embeddings': 512, 132 | 'embeddings_size': 768, 'n_heads': 12, 'dropout': 0.1, 133 | 'embed_dropout': 0.1, 'attn_dropout': 0.1, 'ff_dropout': 0.1}) 134 | 135 | return cfg 136 | 137 | 138 | def load_openai_weights(model, directory, n_special_tokens=0, use_tokenizer=False): 139 | # TODO: add check of shapes 140 | 141 | parameters_names_path = os.path.join(directory, 'parameters_names.json') 142 | parameters_shapes_path = os.path.join(directory, 'parameters_shapes.json') 143 | parameters_weights_paths = [os.path.join(directory, 'params_{}.npy'.format(n)) for n in range(10)] 144 | 145 | with open(parameters_names_path, 'r') as parameters_names_file: 146 | parameters_names = json.load(parameters_names_file) 147 | 148 
| with open(parameters_shapes_path, 'r') as parameters_shapes_file: 149 | parameters_shapes = json.load(parameters_shapes_file) 150 | 151 | parameters_weights = [np.load(path) for path in parameters_weights_paths] 152 | parameters_offsets = np.cumsum([np.prod(shape) for shape in parameters_shapes]) 153 | parameters_weights = np.split(np.concatenate(parameters_weights, 0), parameters_offsets)[:-1] 154 | parameters_weights = [p.reshape(s) for p, s in zip(parameters_weights, parameters_shapes)] 155 | 156 | if not use_tokenizer: 157 | parameters_weights[1] = parameters_weights[1][1:] # skip 0 - 158 | 159 | if isinstance(model.pos_embeddings, nn.Embedding): 160 | if model.pos_embeddings.num_embeddings - 1 > parameters_weights[0].shape[0]: 161 | xx = np.linspace(0, parameters_weights[0].shape[0], model.pos_embeddings.num_embeddings - 1) 162 | new_kernel = RectBivariateSpline(np.arange(parameters_weights[0].shape[0]), 163 | np.arange(parameters_weights[0].shape[1]), 164 | parameters_weights[0]) 165 | parameters_weights[0] = new_kernel(xx, np.arange(parameters_weights[0].shape[1])) 166 | 167 | # parameters_weights[0] = parameters_weights[0][:model.pos_embeddings.num_embeddings - 1] 168 | # model.pos_embeddings.weight.data[1:] = torch.from_numpy(parameters_weights[0]) 169 | model.pos_embeddings.weight.data = torch.from_numpy(parameters_weights[0]) 170 | 171 | 172 | if use_tokenizer: 173 | model.embeddings.weight.data[-n_special_tokens + 1:] = 0 174 | model.embeddings.weight.data[: -n_special_tokens + 1] = torch.from_numpy(parameters_weights[1]) 175 | else: 176 | parameters_weights[1] = parameters_weights[1][:model.embeddings.num_embeddings - n_special_tokens] 177 | model.embeddings.weight.data[:n_special_tokens] = 0 178 | model.embeddings.weight.data[n_special_tokens:] = torch.from_numpy(parameters_weights[1]) 179 | 180 | parameters_weights = parameters_weights[2:] 181 | 182 | for name, weights in zip(parameters_names, parameters_weights): 183 | name = name[6:] # skip "model/" 184 | assert name[-2:] == ':0' 185 | name = name[:-2] 186 | name = name.split('/') 187 | 188 | pointer = model 189 | for m_name in name: 190 | if re.fullmatch(r'[A-Za-z]+\d+', m_name): 191 | l = re.split(r'(\d+)', m_name) 192 | else: 193 | l = [m_name] 194 | 195 | pointer = getattr(pointer, l[0]) 196 | 197 | if len(l) >= 2: 198 | num = int(l[1]) 199 | pointer = pointer[num] 200 | 201 | if len(weights.shape) == 3: # conv1d to linear 202 | weights = weights[0].transpose((1, 0)) 203 | 204 | pointer.data[...] 
= torch.from_numpy(weights) 205 | 206 | # Initialize shared attention layer is necessary 207 | for layer in model.layers: 208 | attn_state = layer.attn.state_dict() 209 | for context_attn in layer.context_attns: 210 | context_attn.load_state_dict(copy.deepcopy(attn_state), strict=False) 211 | 212 | def config_logger(log_path): 213 | logger = logging.getLogger() 214 | logging.basicConfig(format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s', 215 | level=logging.INFO) 216 | file_handler = logging.FileHandler(log_path, mode='w') 217 | file_handler.setLevel(logging.INFO) 218 | file_handler.setFormatter( 219 | logging.Formatter('%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s')) 220 | logger.addHandler(file_handler) 221 | return logger 222 | -------------------------------------------------------------------------------- /new_metrics.py: -------------------------------------------------------------------------------- 1 | # author: Xiang Gao @ Microsoft Research, Oct 2018 2 | # compute NLP evaluation metrics 3 | 4 | import io 5 | import re 6 | import subprocess 7 | import sys 8 | import time 9 | import itertools 10 | from collections import defaultdict, Counter 11 | 12 | import numpy as np 13 | 14 | py_version = sys.version.split('.')[0] 15 | if py_version == '2': 16 | open = io.open 17 | else: 18 | unicode = str 19 | 20 | def str2bool(s): 21 | # to avoid issue like this: https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 22 | if s.lower() in ['t', 'true', '1', 'y']: 23 | return True 24 | elif s.lower() in ['f', 'false', '0', 'n']: 25 | return False 26 | else: 27 | raise ValueError 28 | 29 | 30 | def calc_nist(path_refs, path_hyp, fld_out='temp', n_lines=None): 31 | return calc_nist_bleu(path_refs, path_hyp, fld_out, n_lines)[0] 32 | 33 | 34 | def calc_bleu(path_refs, path_hyp, fld_out='temp', n_lines=None): 35 | return calc_nist_bleu(path_refs, path_hyp, fld_out, n_lines)[1] 36 | 37 | 38 | def calc_nist_bleu(path_refs, path_hyp, fld_out='temp', n_lines=None): 39 | # call mteval-v14c.pl 40 | # ftp://jaguar.ncsl.nist.gov/mt/resources/mteval-v14c.pl 41 | # you may need to cpan install XML:Twig Sort:Naturally String:Util 42 | 43 | if n_lines is None: 44 | n_lines = len(open(path_hyp, encoding='utf-8').readlines()) 45 | if fld_out is None: 46 | fld_out = 'temp' 47 | _write_xml([''], fld_out + '/src.xml', 'src', n_lines=n_lines) 48 | _write_xml([path_hyp], fld_out + '/hyp.xml', 'hyp', n_lines=n_lines) 49 | _write_xml(path_refs, fld_out + '/ref.xml', 'ref', n_lines=n_lines) 50 | 51 | time.sleep(1) 52 | cmd = [ 53 | 'perl','metrics/mteval-v14c.pl', 54 | '-s', '%s/src.xml'%fld_out, 55 | '-t', '%s/hyp.xml'%fld_out, 56 | '-r', '%s/ref.xml'%fld_out, 57 | ] 58 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE) 59 | output, error = process.communicate() 60 | 61 | lines = output.decode().split('\n') 62 | try: 63 | nist_score = lines[-22].strip('\r').split()[3] 64 | bleu_score = lines[-22].strip('\r').split()[7] 65 | nist = lines[-6].strip('\r').split()[1:5] 66 | bleu = lines[-4].strip('\r').split()[1:5] 67 | return float(nist_score), float(bleu_score), [float(x) for x in nist], [float(x) for x in bleu] 68 | 69 | except Exception: 70 | print('mteval-v14c.pl returns unexpected message') 71 | print('cmd = '+str(cmd)) 72 | print(output.decode()) 73 | print(error.decode()) 74 | return [-1]*4, [-1]*4 75 | 76 | 77 | def calc_cum_bleu(path_refs, path_hyp): 78 | # call multi-bleu.pl 79 | # 
https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/multi-bleu.perl 80 | # the 4-gram cum BLEU returned by this one should be very close to calc_nist_bleu 81 | # however multi-bleu.pl doesn't return cum BLEU of lower rank, so in nlp_metrics we preferr calc_nist_bleu 82 | # NOTE: this func doesn't support n_lines argument and output is not parsed yet 83 | 84 | # process = subprocess.Popen( 85 | # ['perl', 'metrics/multi-bleu.perl'] + path_refs, 86 | # stdout=subprocess.PIPE, 87 | # stdin=subprocess.PIPE 88 | # ) 89 | process = subprocess.Popen( 90 | ['perl', 'metrics/multi-bleu.perl'] + path_refs, 91 | stdout=subprocess.PIPE, 92 | stdin=open(path_hyp, encoding='utf-8') 93 | ) 94 | # with open(path_hyp, encoding='utf-8') as f: 95 | # lines = f.readlines() 96 | # for i,line in enumerate(lines): 97 | # process.stdin.write(line.encode()) 98 | # print(i) 99 | output, error = process.communicate() 100 | return output.decode() 101 | 102 | 103 | def calc_meteor(path_refs, path_hyp, fld_out='temp', n_lines=None, pretokenized=True): 104 | # Call METEOR code. 105 | # http://www.cs.cmu.edu/~alavie/METEOR/index.html 106 | 107 | path_merged_refs = fld_out + '/refs_merged.txt' 108 | _write_merged_refs(path_refs, path_merged_refs) 109 | 110 | cmd = [ 111 | 'java', '-Xmx1g', # heapsize of 1G to avoid OutOfMemoryError 112 | '-jar', 'metrics/meteor-1.5/meteor-1.5.jar', 113 | path_hyp, path_merged_refs, 114 | '-r', str(len(path_refs)), # refCount 115 | '-l', 'en', '-norm' # also supports language: cz de es fr ar 116 | ] 117 | 118 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 119 | output, error = process.communicate() 120 | for line in output.decode().split('\n'): 121 | if "Final score:" in line: 122 | return float(line.split()[-1]) 123 | 124 | print('meteor-1.5.jar returns unexpected message') 125 | print("cmd = " + " ".join(cmd)) 126 | print(output.decode()) 127 | print(error.decode()) 128 | return -1 129 | 130 | 131 | def calc_entropy(path_hyp, n_lines=None): 132 | # based on Yizhe Zhang's code 133 | etp_score = [0.0,0.0,0.0,0.0] 134 | counter = [defaultdict(int),defaultdict(int),defaultdict(int),defaultdict(int)] 135 | i = 0 136 | for line in open(path_hyp, encoding='utf-8'): 137 | i += 1 138 | words = line.strip('\n').split() 139 | for n in range(4): 140 | for idx in range(len(words)-n): 141 | ngram = ' '.join(words[idx:idx+n+1]) 142 | counter[n][ngram] += 1 143 | if i == n_lines: 144 | break 145 | 146 | for n in range(4): 147 | total = sum(counter[n].values()) 148 | for v in counter[n].values(): 149 | etp_score[n] += - v /total * (np.log(v) - np.log(total)) 150 | 151 | return etp_score 152 | 153 | 154 | def calc_avg_len(path, n_lines=None): 155 | l = [] 156 | for line in open(path, encoding='utf8'): 157 | l.append(len(line.strip('\n').split())) 158 | if len(l) == n_lines: 159 | break 160 | return np.mean(l) 161 | 162 | 163 | def calc_div(path_hyp): 164 | tokens = [0.0, 0.0] 165 | types = [defaultdict(int), defaultdict(int)] 166 | for line in open(path_hyp, encoding='utf-8'): 167 | words = line.strip('\n').split() 168 | for n in range(2): 169 | for idx in range(len(words)-n): 170 | ngram = ' '.join(words[idx:idx+n+1]) 171 | types[n][ngram] = 1 172 | tokens[n] += 1 173 | div1 = len(types[0].keys())/tokens[0] if tokens[0] != 0 else 0 174 | div2 = len(types[1].keys())/tokens[1] if tokens[1] != 0 else 0 175 | return [div1, div2] 176 | 177 | 178 | def nlp_metrics(path_refs, path_hyp, fld_out='temp', n_lines=None): 179 | nist, bleu = 
calc_nist_bleu(path_refs, path_hyp, fld_out, n_lines) 180 | meteor = calc_meteor(path_refs, path_hyp, fld_out, n_lines) 181 | entropy = calc_entropy(path_hyp, n_lines) 182 | div = calc_div(path_hyp) 183 | avg_len = calc_avg_len(path_hyp, n_lines) 184 | 185 | return nist, bleu, meteor, entropy, div, avg_len 186 | 187 | 188 | def specified_nlp_metric(path_refs, path_hyp, metric): 189 | i = None 190 | 191 | m = re.search('_[\d]\Z', metric) 192 | if m: 193 | metric, i = metric[:m.span()[0]], int(metric[m.span()[0]+1:]) - 1 194 | 195 | try: 196 | res = eval(f'calc_{metric}(path_refs, path_hyp)') 197 | except: 198 | res = eval(f'calc_{metric}(path_hyp)') 199 | 200 | return res if i is None else res[i] 201 | 202 | 203 | def _write_merged_refs(paths_in, path_out, n_lines=None): 204 | # prepare merged ref file for meteor-1.5.jar (calc_meteor) 205 | # lines[i][j] is the ref from i-th ref set for the j-th query 206 | 207 | lines = [] 208 | for path_in in paths_in: 209 | lines.append([line.strip('\n') for line in open(path_in, encoding='utf-8')]) 210 | 211 | with open(path_out, 'w', encoding='utf-8') as f: 212 | for j in range(len(lines[0])): 213 | for i in range(len(paths_in)): 214 | f.write(unicode(lines[i][j]) + "\n") 215 | 216 | 217 | def _write_xml(paths_in, path_out, role, n_lines=None): 218 | # prepare .xml files for mteval-v14c.pl (calc_nist_bleu) 219 | # role = 'src', 'hyp' or 'ref' 220 | 221 | lines = [ 222 | '', 223 | '', 224 | '', 225 | ''%paths_in, 226 | '', 227 | '', 228 | ] 229 | 230 | for i_in, path_in in enumerate(paths_in): 231 | 232 | # header ---- 233 | 234 | if role == 'src': 235 | lines.append('') 236 | set_ending = '' 237 | elif role == 'hyp': 238 | lines.append('') 239 | set_ending = '' 240 | elif role == 'ref': 241 | lines.append(''%i_in) 242 | set_ending = '' 243 | 244 | lines.append('') 245 | 246 | # body ----- 247 | 248 | if role == 'src': 249 | body = [''] * n_lines 250 | else: 251 | with open(path_in, 'r', encoding='utf-8') as f: 252 | body = f.readlines() 253 | if n_lines is not None: 254 | body = body[:n_lines] 255 | for i in range(len(body)): 256 | line = body[i].strip('\n') 257 | line = line.replace('&',' ').replace('<',' ') # remove illegal xml char 258 | if len(line) == 0: 259 | line = '__empty__' 260 | lines.append('

<seg id="%i"> %s </seg>'%(i + 1, line)) 261 | 262 | # ending ----- 263 | 264 | lines.append('</doc>') 265 | if role == 'src': 266 | lines.append('</srcset>') 267 | elif role == 'hyp': 268 | lines.append('</tstset>') 269 | elif role == 'ref': 270 | lines.append('</refset>') 271 | 272 | lines.append('</mteval>
') 273 | with open(path_out, 'w', encoding='utf-8') as f: 274 | f.write(unicode('\n'.join(lines))) 275 | 276 | def normalize_answer(s): 277 | re_art = re.compile(r'\b(a|an|the)\b') 278 | re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']') 279 | 280 | def remove_articles(text): 281 | return re_art.sub(' ', text) 282 | 283 | def white_space_fix(text): 284 | return ' '.join(text.split()) 285 | 286 | def remove_punc(text): 287 | return re_punc.sub(' ', text) # convert punctuation to spaces 288 | 289 | def lower(text): 290 | return text.lower() 291 | 292 | # return white_space_fix(remove_articles(remove_punc(lower(s)))) 293 | return white_space_fix(remove_articles(lower(s))) 294 | 295 | def _f1_score(ref, pred): 296 | """ 297 | Compute precision, recall and f1 given a set of gold and prediction items. 298 | 299 | :param pred_items: iterable of predicted values 300 | :param gold_items: iterable of gold values 301 | 302 | :return: tuple (p, r, f1) for precision, recall, f1 303 | """ 304 | # ref_items = normalize_answer(ref).split() 305 | # pred_items = normalize_answer(pred).split() 306 | ref_items = ref.split() 307 | pred_items = pred.split() 308 | common = Counter(ref_items) & Counter(pred_items) 309 | num_same = sum(common.values()) 310 | if num_same == 0: 311 | return 0, 0, 0 312 | precision = 1.0 * num_same / len(pred_items) 313 | recall = 1.0 * num_same / len(ref_items) 314 | f1 = (2 * precision * recall) / (precision + recall) 315 | return precision, recall, f1 316 | 317 | def get_f1_score(refs_list, preds_list): 318 | f1 = 0 319 | for i in range(len(refs_list)): 320 | f1 += _f1_score(refs_list[i], preds_list[i])[2] 321 | return f1 / len(refs_list) 322 | 323 | def _split_into_words(sentences): 324 | """Splits multiple sentences into words and flattens the result""" 325 | return list(itertools.chain(*[_.split(" ") for _ in sentences])) 326 | 327 | def _get_ngrams(n, text): 328 | """Calcualtes n-grams. 329 | Args: 330 | n: which n-grams to calculate 331 | text: An array of tokens 332 | Returns: 333 | A set of n-grams 334 | """ 335 | ngram_set = set() 336 | text_length = len(text) 337 | max_index_ngram_start = text_length - n 338 | for i in range(max_index_ngram_start + 1): 339 | ngram_set.add(tuple(text[i:i + n])) 340 | return ngram_set 341 | 342 | def _get_word_ngrams(n, sentences): 343 | """Calculates word n-grams for multiple sentences. 344 | """ 345 | assert len(sentences) > 0 346 | assert n > 0 347 | 348 | words = _split_into_words(sentences) 349 | return _get_ngrams(n, words) 350 | 351 | def rouge_n(evaluated_sentences, reference_sentences, n=2): 352 | """ 353 | Computes ROUGE-N of two text collections of sentences. 354 | Sourece: http://research.microsoft.com/en-us/um/people/cyl/download/ 355 | papers/rouge-working-note-v1.3.1.pdf 356 | Args: 357 | evaluated_sentences: The sentences that have been picked by the summarizer 358 | reference_sentences: The sentences from the referene set 359 | n: Size of ngram. Defaults to 2. 
360 | Returns: 361 | A tuple (f1, precision, recall) for ROUGE-N 362 | Raises: 363 | ValueError: raises exception if a param has len <= 0 364 | """ 365 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: 366 | raise ValueError("Collections must contain at least 1 sentence.") 367 | 368 | evaluated_ngrams = _get_word_ngrams(n, evaluated_sentences) 369 | reference_ngrams = _get_word_ngrams(n, reference_sentences) 370 | reference_count = len(reference_ngrams) 371 | evaluated_count = len(evaluated_ngrams) 372 | 373 | # Gets the overlapping ngrams between evaluated and reference 374 | overlapping_ngrams = evaluated_ngrams.intersection(reference_ngrams) 375 | overlapping_count = len(overlapping_ngrams) 376 | 377 | # Handle edge case. This isn't mathematically correct, but it's good enough 378 | if evaluated_count == 0: 379 | precision = 0.0 380 | else: 381 | precision = overlapping_count / evaluated_count 382 | 383 | if reference_count == 0: 384 | recall = 0.0 385 | else: 386 | recall = overlapping_count / reference_count 387 | 388 | f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8)) 389 | 390 | # return overlapping_count / reference_count 391 | return f1_score, precision, recall 392 | 393 | def _f_p_r_lcs(llcs, m, n): 394 | """ 395 | Computes the LCS-based F-measure score 396 | Source: http://research.microsoft.com/en-us/um/people/cyl/download/papers/ 397 | rouge-working-note-v1.3.1.pdf 398 | Args: 399 | llcs: Length of LCS 400 | m: number of words in reference summary 401 | n: number of words in candidate summary 402 | Returns: 403 | Float. LCS-based F-measure score 404 | """ 405 | r_lcs = llcs / m 406 | p_lcs = llcs / n 407 | beta = p_lcs / (r_lcs + 1e-12) 408 | num = (1 + (beta**2)) * r_lcs * p_lcs 409 | denom = r_lcs + ((beta**2) * p_lcs) 410 | f_lcs = num / (denom + 1e-12) 411 | return f_lcs, p_lcs, r_lcs 412 | 413 | def _len_lcs(x, y): 414 | """ 415 | Returns the length of the Longest Common Subsequence between sequences x 416 | and y. 417 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 418 | Args: 419 | x: sequence of words 420 | y: sequence of words 421 | Returns 422 | integer: Length of LCS between x and y 423 | """ 424 | table = _lcs(x, y) 425 | n, m = len(x), len(y) 426 | return table[n, m] 427 | 428 | def _lcs(x, y): 429 | """ 430 | Computes the length of the longest common subsequence (lcs) between two 431 | strings. The implementation below uses a DP programming algorithm and runs 432 | in O(nm) time where n = len(x) and m = len(y). 433 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 434 | Args: 435 | x: collection of words 436 | y: collection of words 437 | Returns: 438 | Table of dictionary of coord and len lcs 439 | """ 440 | n, m = len(x), len(y) 441 | table = dict() 442 | for i in range(n + 1): 443 | for j in range(m + 1): 444 | if i == 0 or j == 0: 445 | table[i, j] = 0 446 | elif x[i - 1] == y[j - 1]: 447 | table[i, j] = table[i - 1, j - 1] + 1 448 | else: 449 | table[i, j] = max(table[i - 1, j], table[i, j - 1]) 450 | return table 451 | 452 | def rouge_l_sentence_level(evaluated_sentences, reference_sentences): 453 | """ 454 | Computes ROUGE-L (sentence level) of two text collections of sentences. 
455 | http://research.microsoft.com/en-us/um/people/cyl/download/papers/ 456 | rouge-working-note-v1.3.1.pdf 457 | Calculated according to: 458 | R_lcs = LCS(X,Y)/m 459 | P_lcs = LCS(X,Y)/n 460 | F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs) 461 | where: 462 | X = reference summary 463 | Y = Candidate summary 464 | m = length of reference summary 465 | n = length of candidate summary 466 | Args: 467 | evaluated_sentences: The sentences that have been picked by the summarizer 468 | reference_sentences: The sentences from the referene set 469 | Returns: 470 | A float: F_lcs 471 | Raises: 472 | ValueError: raises exception if a param has len <= 0 473 | """ 474 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: 475 | raise ValueError("Collections must contain at least 1 sentence.") 476 | reference_words = _split_into_words(reference_sentences) 477 | evaluated_words = _split_into_words(evaluated_sentences) 478 | m = len(reference_words) 479 | n = len(evaluated_words) 480 | lcs = _len_lcs(evaluated_words, reference_words) 481 | return _f_p_r_lcs(lcs, m, n) 482 | 483 | def get_rouge(refs_list, preds_list): 484 | """Calculates average rouge scores for a list of hypotheses and 485 | references""" 486 | 487 | rouge_1 = [ 488 | rouge_n([pred], [ref], 1) for pred, ref in zip(preds_list, refs_list) 489 | ] 490 | rouge_1_f, rouge_1_p, rouge_1_r = map(np.mean, zip(*rouge_1)) 491 | 492 | # Calculate ROUGE-2 F1, precision, recall scores 493 | rouge_2 = [ 494 | rouge_n([pred], [ref], 2) for pred, ref in zip(preds_list, refs_list) 495 | ] 496 | rouge_2_f, rouge_2_p, rouge_2_r = map(np.mean, zip(*rouge_2)) 497 | 498 | # Calculate ROUGE-L F1, precision, recall scores 499 | rouge_l = [ 500 | rouge_l_sentence_level([hyp], [ref]) 501 | for hyp, ref in zip(preds_list, refs_list) 502 | ] 503 | rouge_l_f, rouge_l_p, rouge_l_r = map(np.mean, zip(*rouge_l)) 504 | return np.mean([rouge_1_f,rouge_2_f,rouge_l_f]), rouge_1_f, rouge_2_f, rouge_l_f 505 | 506 | def cal_dist(preds_list): 507 | ngram1, ngram2 = set(), set() 508 | sentence_dist1, sentence_dist2 = 0, 0 509 | total_length = 0 510 | for pred in preds_list: 511 | length = len(pred.split(' ')) 512 | cur_gram_1 = _get_word_ngrams(1, [pred]) 513 | cur_gram_2 = _get_word_ngrams(2, [pred]) 514 | if length > 0: 515 | sentence_dist1 += len(cur_gram_1) / length 516 | if length > 1: 517 | sentence_dist2 += len(cur_gram_2) / (length - 1) 518 | ngram1 = ngram1.union(cur_gram_1) 519 | ngram2 = ngram2.union(cur_gram_2) 520 | total_length += length 521 | if total_length == len(preds_list): 522 | return sentence_dist1 / len(preds_list), sentence_dist2 / len(preds_list), 0, 0 523 | return sentence_dist1 / len(preds_list), sentence_dist2 / len(preds_list), len(ngram1) / total_length, \ 524 | len(ngram2) / (total_length - len(preds_list)) 525 | 526 | def calc_entropy(preds, n_lines=None): 527 | # based on Yizhe Zhang's code 528 | entropy_score = [0.0, 0.0, 0.0, 0.0] 529 | counter = [defaultdict(int),defaultdict(int),defaultdict(int),defaultdict(int)] 530 | i = 0 531 | for line in preds: 532 | i += 1 533 | words = line.strip('\n').split() 534 | for n in range(4): 535 | for idx in range(len(words)-n): 536 | ngram = ' '.join(words[idx:idx+n+1]) 537 | counter[n][ngram] += 1 538 | if i == n_lines: 539 | break 540 | 541 | for n in range(4): 542 | total = sum(counter[n].values()) 543 | for v in counter[n].values(): 544 | entropy_score[n] += - v /total * (np.log(v) - np.log(total)) 545 | 546 | return entropy_score 547 | 548 | def nlp_metrics(ref_file, 
pred_file, root_path=None): 549 | preds_list = [] 550 | with open(pred_file, 'r', encoding='utf-8') as f: 551 | lines = f.readlines() 552 | for line in lines: 553 | preds_list.append(normalize_answer(line)) 554 | refs_list = [] 555 | with open(ref_file, 'r', encoding='utf-8') as f: 556 | lines = f.readlines() 557 | for line in lines: 558 | refs_list.append(normalize_answer(line)) 559 | 560 | sentence_dist1, sentence_dist2, corpus_dist1, corpus_dist2 = cal_dist(preds_list) 561 | s_dist= [sentence_dist1, sentence_dist2] 562 | c_dist = [corpus_dist1, corpus_dist2] 563 | entropy_scores = calc_entropy(preds_list) 564 | 565 | nist, nist_bleu, nist_list, nist_bleu_list = calc_nist_bleu([ref_file], pred_file, fld_out=root_path) 566 | 567 | # meteor_score = calc_meteor([ref_file], pred_file) 568 | 569 | bleu_output = calc_cum_bleu([ref_file], pred_file) 570 | bleu_text = re.search(r"BLEU = (.+?), (.+?)/(.+?)/(.+?)/(.+?) \(", bleu_output) 571 | 572 | bleu = float(bleu_text.group(1)) 573 | bleu1 = float(bleu_text.group(2)) 574 | bleu2 = float(bleu_text.group(3)) 575 | bleu3 = float(bleu_text.group(4)) 576 | bleu4 = float(bleu_text.group(5)) 577 | bleu_list = [bleu1, bleu2, bleu3, bleu4] 578 | 579 | f1_score = get_f1_score(refs_list, preds_list) 580 | _, _, _, rouge_l = get_rouge(refs_list, preds_list) 581 | avg_pred_length = np.mean([len(x.split(' ')) for x in preds_list]) 582 | 583 | return bleu, bleu_list, nist, nist_list, nist_bleu, nist_bleu_list, s_dist, c_dist, entropy_scores, 0.0, \ 584 | rouge_l, f1_score, avg_pred_length 585 | -------------------------------------------------------------------------------- /openai-gpt/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "OpenAIGPTLMHeadModel" 4 | ], 5 | "afn": "gelu", 6 | "attn_pdrop": 0.1, 7 | "embd_pdrop": 0.1, 8 | "initializer_range": 0.02, 9 | "layer_norm_epsilon": 1e-05, 10 | "n_ctx": 512, 11 | "n_embd": 768, 12 | "n_head": 12, 13 | "n_layer": 12, 14 | "n_positions": 512, 15 | "n_special": 0, 16 | "resid_pdrop": 0.1, 17 | "vocab_size": 40478, 18 | "shared_module": true, 19 | "shared_attention": false, 20 | "context_size": 2 21 | } 22 | -------------------------------------------------------------------------------- /openai-gpt/special_tokens.txt: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.5.0 2 | tqdm 3 | transformers==2.5.1 4 | attrdict 5 | nltk 6 | ftfy 7 | tensorboardX==2.0 8 | future 9 | pillow 10 | git-python -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import random 5 | import sys 6 | 7 | import torch 8 | import torch.nn as nn 9 | from tensorboardX import SummaryWriter 10 | from transformers.tokenization_gpt2 import GPT2Tokenizer 11 | from transformers.tokenization_openai import OpenAIGPTTokenizer 12 | 13 | from config import get_trainer_config 14 | from config import InputConfig 15 | from model.dataset import FacebookDataset 16 | from model.gpt2_model import GPT2DoubleHeadsModel 17 | from model.gpt2_model import GPT2EncoderDecoderModel 18 | from model.openai_model import OpenAIGPTEncoderDecoderModel 19 | from model.openai_model 
import OpenAIGPTLMHeadModel 20 | from model.trainer import Trainer 21 | from model.utils import config_logger 22 | from model.utils import f1_score 23 | from model.utils import open 24 | from model.utils import set_seed 25 | from model.seq2seq import TransformerSeq2Seq 26 | from model.seq2seq_vocab import Seq2seqVocab 27 | from new_metrics import nlp_metrics 28 | 29 | 30 | class DummyWriter: 31 | """ Used for distributed training (from NVIDIA apex example). 32 | A dummy logger used so that only the main process write and log informations. 33 | """ 34 | def __init__(self, *input, **kwargs): 35 | self.log_dir = "runs/dummy_logs/" 36 | 37 | def add_scalar(self, *input, **kwargs): 38 | pass 39 | 40 | def modify_tokenizer(tokenizer, data_type): 41 | additional_special_tokens = ['', '', '', '', '', 42 | ''] 43 | if data_type == 'emoji': 44 | with open('datasets/emoji_talk/emojis.json', 'r') as f: 45 | emojis = json.load(f)['emojis'] 46 | additional_special_tokens.extend(emojis) 47 | if data_type == 'daily': 48 | with open('datasets/DailyDialog/daily.json', 'r') as f: 49 | topic_tokens = json.load(f) 50 | additional_special_tokens.extend(topic_tokens) 51 | tokenizer.add_special_tokens({'pad_token': '', 'bos_token': '', 'eos_token': '', 52 | 'additional_special_tokens': additional_special_tokens}) 53 | tokenizer.eos_id, tokenizer.bos_id, tokenizer.pad_id = tokenizer.eos_token_id, tokenizer.bos_token_id, tokenizer.pad_token_id 54 | tokenizer.sent_dialog_id = tokenizer.bos_token_id 55 | tokenizer.info_dialog_id, tokenizer.info_bos_id = tokenizer.added_tokens_encoder[''], \ 56 | tokenizer.added_tokens_encoder[ 57 | ''] 58 | tokenizer.info_eos_id = tokenizer.added_tokens_encoder[''] 59 | tokenizer.talker1_dialog_id, tokenizer.talker1_bos_id = tokenizer.added_tokens_encoder[''], \ 60 | tokenizer.added_tokens_encoder[''] 61 | tokenizer.talker1_eos_id = tokenizer.added_tokens_encoder[''] 62 | tokenizer.talker2_dialog_id, tokenizer.talker2_bos_id = tokenizer.added_tokens_encoder[''], \ 63 | tokenizer.added_tokens_encoder[''] 64 | tokenizer.talker2_eos_id = tokenizer.added_tokens_encoder[''] 65 | return tokenizer, len(additional_special_tokens) + 3 66 | 67 | def get_model_and_tokenizer(args, trainer_config, logger): 68 | if args.model_type == 'gpt': 69 | if args.single_input: 70 | model = OpenAIGPTLMHeadModel.from_pretrained('./openai-gpt') 71 | else: 72 | model = OpenAIGPTEncoderDecoderModel.from_pretrained('./openai-gpt') 73 | tokenizer = OpenAIGPTTokenizer.from_pretrained('./openai-gpt') 74 | elif args.model_type == 'dialogpt': 75 | if args.single_input: 76 | model = GPT2DoubleHeadsModel.from_pretrained('./dialogpt') 77 | else: 78 | model = GPT2EncoderDecoderModel.from_pretrained('./dialogpt') 79 | tokenizer = GPT2Tokenizer.from_pretrained('./dialogpt') 80 | elif args.model_type == 'seq2seq': 81 | seq2seq_vocab = Seq2seqVocab(trainer_config.train_datasets, trainer_config.valid_datasets, 82 | trainer_config.test_datasets, args.vocab_path, data_type=trainer_config.data_type, 83 | extend_exist_vocab=args.extend_exist_vocab) 84 | tokenizer = seq2seq_vocab.vocab 85 | args.dialog_embeddings = False 86 | model = TransformerSeq2Seq(args.emb_dim, args.hidden_dim, args.num_layers, args.heads, args.depth_size, 87 | args.filter_size, tokenizer, args.pretrained_emb_file, args.pointer_gen, logger, 88 | multi_input=not args.single_input, attention_pooling_type=args.attention_pooling_type, 89 | label_smoothing=args.label_smoothing) 90 | else: 91 | if args.single_input: 92 | model = 
GPT2DoubleHeadsModel.from_pretrained('./gpt2-small') 93 | else: 94 | model = GPT2EncoderDecoderModel.from_pretrained('./gpt2-small') 95 | tokenizer = GPT2Tokenizer.from_pretrained('./gpt2-small') 96 | return model, tokenizer 97 | 98 | '''Modify the model to make it fit the data''' 99 | def modify_model(args, model, tokenizer): 100 | if args.model_type in ['gpt', 'dialogpt', 'gpt2']: 101 | tokenizer, additional_length = modify_tokenizer(tokenizer, args.data_type) 102 | model.embeddings_size = 768 103 | model.n_embeddings = len(tokenizer) 104 | model.shared_attention = (args.shared_attention == 1) 105 | model.shared_module = (args.shared_module == 1) 106 | model.attention_fusion_type = args.attention_fusion_type 107 | model.single_input = args.single_input 108 | if args.model_type == 'gpt': 109 | model_embedding_weight = model.transformer.tokens_embed.weight 110 | model.transformer.tokens_embed = nn.Embedding(model.n_embeddings, 768) 111 | model.lm_head = nn.Linear(768, model.n_embeddings, bias=False) 112 | model.transformer.tokens_embed.weight.data[:-additional_length, :] = model_embedding_weight.data 113 | model.transformer.tokens_embed.weight.data[-additional_length:, :] = 0 114 | model.lm_head.weight = model.transformer.tokens_embed.weight 115 | else: 116 | model_embedding_weight = model.transformer.wte.weight 117 | model.transformer.wte = nn.Embedding(model.n_embeddings, 768) 118 | model.lm_head = nn.Linear(768, model.n_embeddings, bias=False) 119 | model.transformer.wte.weight.data[:-additional_length, :] = model_embedding_weight.data 120 | model.transformer.wte.weight.data[-additional_length:, :] = 0 121 | model.lm_head.weight = model.transformer.wte.weight 122 | 123 | if not args.single_input: 124 | model.reload_module_dict() 125 | model.sent_dialog_id = tokenizer.sent_dialog_id 126 | model.talker1_id = tokenizer.talker1_bos_id 127 | model.talker2_id = tokenizer.talker2_bos_id 128 | 129 | model.padding_idx = tokenizer.pad_id 130 | model.n_pos_embeddings = 512 131 | 132 | model.bos_id = tokenizer.bos_id 133 | model.eos_id = tokenizer.eos_id 134 | model.beam_size = args.beam_size 135 | model.diversity_groups = 1 136 | model.max_seq_len = 32 137 | model.dialog_embeddings = args.dialog_embeddings 138 | model.bs_temperature = args.bs_temperature 139 | model.bs_nucleus_p = args.bs_nucleus_p 140 | model.annealing_topk = args.annealing_topk 141 | model.length_penalty_coef = args.length_penalty 142 | model.vocab = None 143 | model.annealing = args.annealing 144 | model.diversity_coef = args.diversity_coef 145 | model.sample = False 146 | model.inference_mode = args.inference_mode 147 | model.response_k = args.response_k 148 | 149 | def training_procedure(args, trainer_config, model, tokenizer, device, writer, logger, best_checkpoint_path, 150 | last_checkpoint_path, interrupt_checkpoint_path, log_dir, test_data_type=None): 151 | logger.info("trainer config: {}".format(trainer_config)) 152 | logger.info('loading datasets') 153 | train_dataset = FacebookDataset(trainer_config.train_datasets, tokenizer, 154 | max_lengths=model.n_pos_embeddings - 1, # A bit restrictive here 155 | dialog_embeddings=args.dialog_embeddings, 156 | cache=trainer_config.train_datasets_cache, 157 | use_start_end=False, 158 | negative_samples=trainer_config.negative_samples, 159 | augment=trainer_config.persona_augment, 160 | aug_syn_proba=trainer_config.persona_aug_syn_proba, 161 | limit_size=trainer_config.limit_train_size, 162 | max_history_size=trainer_config.max_history_size, 163 | single_input=args.single_input, 
164 | data_type=trainer_config.data_type) 165 | valid_dataset = FacebookDataset(trainer_config.valid_datasets, tokenizer, 166 | max_lengths=model.n_pos_embeddings - 1, # A bit restrictive here 167 | dialog_embeddings=args.dialog_embeddings, 168 | cache=trainer_config.valid_datasets_cache, 169 | use_start_end=False, 170 | negative_samples=-1, # Keep all negative samples 171 | augment=False, 172 | aug_syn_proba=0.0, 173 | limit_size=trainer_config.limit_eval_size, 174 | max_history_size=trainer_config.max_history_size, 175 | single_input=args.single_input, 176 | data_type=trainer_config.data_type) 177 | if test_data_type is None: 178 | test_data_type = trainer_config.data_type 179 | test_dataset = FacebookDataset(trainer_config.test_datasets, tokenizer, 180 | max_lengths=model.n_pos_embeddings - 1, # A bit restrictive here 181 | dialog_embeddings=args.dialog_embeddings, 182 | cache=trainer_config.test_datasets_cache, 183 | use_start_end=False, 184 | negative_samples=-1, # Keep all negative samples 185 | augment=False, 186 | aug_syn_proba=0.0, 187 | limit_size=trainer_config.limit_eval_size, 188 | max_history_size=trainer_config.max_history_size, 189 | single_input=args.single_input, 190 | data_type=test_data_type) 191 | logger.info('train dataset {} valid dataset {} test dataset {}' 192 | .format(len(train_dataset), len(valid_dataset), len(test_dataset))) 193 | 194 | '''Normal training will use normal trainer''' 195 | model_trainer = Trainer(model, 196 | train_dataset, 197 | trainer_config, 198 | writer, 199 | logger=logger, 200 | valid_dataset=valid_dataset, 201 | test_dataset=test_dataset, 202 | n_jobs=trainer_config.n_jobs, 203 | device=device, 204 | ignore_idxs=tokenizer.all_special_ids, 205 | local_rank=args.local_rank, 206 | apex_level=None, 207 | apex_loss_scale=trainer_config.apex_loss_scale, 208 | evaluate_full_sequences=trainer_config.evaluate_full_sequences, 209 | full_input=trainer_config.full_input, 210 | uncertainty_loss=args.uncertainty_loss, 211 | best_model_path=best_checkpoint_path, 212 | extra_module_lr_rate=args.extra_module_lr_rate) 213 | 214 | if args.load_last: 215 | state_dict = torch.load(trainer_config.load_last, map_location=device) 216 | model_trainer.load_state_dict(state_dict) 217 | 218 | # helpers ----------------------------------------------------- 219 | def external_metrics_func(full_references, full_predictions, epoch, is_best=False): 220 | if epoch == -1: 221 | if is_best: 222 | references_file_path = os.path.join(writer.logdir, trainer_config.test_references_file) 223 | predictions_file_path = os.path.join(writer.logdir, trainer_config.test_predictions_file_best) 224 | else: 225 | references_file_path = os.path.join(writer.logdir, trainer_config.test_references_file) 226 | predictions_file_path = os.path.join(writer.logdir, trainer_config.test_predictions_file_last) 227 | else: 228 | references_file_path = os.path.join(writer.logdir, trainer_config.eval_references_file) 229 | predictions_file_path = os.path.join(writer.logdir, 230 | trainer_config.eval_predictions_file + "_{}".format(epoch)) 231 | 232 | if not os.path.exists(references_file_path): 233 | with open(references_file_path, 'w', encoding='utf-8') as f: 234 | f.write('\n'.join(full_references)) 235 | # print(len(full_predictions)) 236 | with open(os.path.join(writer.logdir, 'tt.json'), 'w') as f: 237 | json.dump(full_predictions, f) 238 | with open(predictions_file_path, 'w', encoding='utf-8') as f: 239 | if len(full_predictions[-1]) == 0: 240 | full_predictions[-1] = 'a ' 241 | 
f.write('\n'.join(full_predictions)) 242 | 243 | bleu, bleu_list, nist, nist_list, nist_bleu, nist_bleu_list, s_dist, c_dist, entropy, meteor, \ 244 | rouge_l, f1_score, avg_length = nlp_metrics(references_file_path, predictions_file_path, root_path=log_dir) 245 | 246 | metrics = {'meteor': meteor * 100, 'avg_len': avg_length, 'rouge-l': rouge_l * 100, 'bleu': bleu, 'nist': nist, 247 | 'nist-bleu': nist_bleu, 'f1': f1_score * 100} 248 | for name, metric in ( 249 | ('bleu', bleu_list), ('nist', nist_list), ('nist_bleu', nist_bleu_list), ('entropy', entropy), 250 | ('sentence_div', s_dist), ('corpus_div', c_dist)): 251 | for i, m in enumerate(metric, 1): 252 | if name == 'sentence_div' or name == 'corpus_div': 253 | metrics['{}_{}'.format(name, i)] = m * 100 254 | else: 255 | metrics['{}_{}'.format(name, i)] = m 256 | for k, v in metrics.items(): 257 | metrics[k] = round(v, 6) 258 | 259 | return metrics 260 | 261 | def save_func(epoch): 262 | if epoch != -1: 263 | torch.save(model_trainer.model.state_dict(), last_checkpoint_path) 264 | logger.info('Model on Epoch %d has been saved', epoch) 265 | 266 | def sample_text_func(epoch): 267 | n_samples = 0 268 | model_trainer.model.eval() 269 | samples_idxs = random.sample(range(len(valid_dataset)), n_samples) 270 | samples = [valid_dataset[idx] for idx in samples_idxs] 271 | for persona_info, dialog, target, _ in samples: 272 | contexts = [torch.tensor([c], dtype=torch.long, device=model_trainer.device) for c in [persona_info, dialog] 273 | if len(c) > 0] 274 | prediction = model_trainer.model.predict(contexts)[0] 275 | 276 | persona_info_str = tokenizer.ids2string(persona_info[1:-1]) 277 | dialog_str = tokenizer.ids2string(dialog) 278 | dialog_str = dialog_str.replace(tokenizer.talker1_bos, '\n\t- ').replace(tokenizer.talker2_bos, '\n\t- ') 279 | dialog_str = dialog_str.replace(tokenizer.talker1_eos, '').replace(tokenizer.talker2_eos, '') 280 | target_str = tokenizer.ids2string(target[1:-1]) 281 | prediction_str = tokenizer.ids2string(prediction) 282 | 283 | logger.info('\n') 284 | logger.info('Persona info:\n\t{}'.format(persona_info_str)) 285 | logger.info('Dialog:{}'.format(dialog_str)) 286 | logger.info('Target:\n\t{}'.format(target_str)) 287 | logger.info('Prediction:\n\t{}'.format(prediction_str)) 288 | 289 | def test_func(epoch): 290 | if (epoch + 1) % trainer_config.test_period == 0: 291 | metric_funcs = {'f1_score': f1_score} 292 | model_trainer.test(metric_funcs, external_metrics_func, epoch) 293 | 294 | def f1_risk(predictions, targets): 295 | scores = f1_score(predictions, targets, average=False) 296 | assert all([0 <= s <= 1.0 for s in scores]) 297 | return [1 - s for s in scores] 298 | 299 | def get_risk_metric_func(risk_metric): 300 | """ risk_metric selected in: 301 | f1, meteor, avg_len, nist_{1, 2, 3, 4}, entropy_{1, 2, 3, 4}, div_{1, 2}, bleu_{1, 2, 3, 4} 302 | """ 303 | 304 | def external_metric_risk(predictions, targets): 305 | string_targets = list(tokenizer.ids2string(t) for t in targets) 306 | string_predictions = list(tokenizer.ids2string(t) for t in predictions) 307 | metrics = [external_metrics_func([t], [p], epoch=-1, metric=risk_metric) for p, t in 308 | zip(string_predictions, string_targets)] 309 | 310 | if any([s in risk_metric for s in ['entropy', 'nist', 'avg_len']]): 311 | return [-m for m in metrics] 312 | 313 | assert all([0 <= s <= 1.0 for s in metrics]), metrics 314 | 315 | return [1 - m for m in metrics] 316 | 317 | if risk_metric == 'f1': 318 | return f1_risk 319 | 320 | return external_metric_risk 321 | 
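As a quick standalone check of the `f1_risk` convention defined above, which turns a bounded score into a minimization target (the token-overlap F1 from `model/utils.py` is restated inline so the snippet runs on its own):
```
from collections import Counter

def f1_items(pred_items, gold_items):
    # same token-overlap F1 as f1_score in model/utils.py, restated so the snippet runs alone
    common = Counter(gold_items) & Counter(pred_items)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_items)
    recall = num_same / len(gold_items)
    return 2 * precision * recall / (precision + recall)

scores = [f1_items(p, t) for p, t in [(['hi', 'there'], ['hi', 'there']),
                                      (['good', 'bye'], ['see', 'you'])]]
risks = [1 - s for s in scores]   # the f1_risk convention defined above
print(risks)                      # [0.0, 1.0]
```
Metrics without a fixed [0, 1] range (entropy, NIST, average length) are instead negated in `external_metric_risk`, so lower risk always means a better response.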
322 | # helpers ----------------------------------------------------- 323 | 324 | try: 325 | model_trainer.train(after_epoch_funcs=[save_func, sample_text_func, test_func], 326 | risk_func=get_risk_metric_func(trainer_config.risk_metric)) 327 | except (KeyboardInterrupt, Exception, RuntimeError) as e: 328 | if args.local_rank in [-1, 0]: 329 | torch.save(model_trainer.state_dict(), interrupt_checkpoint_path) 330 | raise e 331 | 332 | def main(): 333 | args = InputConfig().args 334 | 335 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 336 | datefmt = '%m/%d/%Y %H:%M:%S', 337 | level = logging.INFO if args.local_rank in [-1, 0] else logging.ERROR) 338 | if args.server_ip and args.server_port and args.local_rank in [-1, 0]: 339 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script 340 | import ptvsd 341 | print("Waiting for debugger attach") 342 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) 343 | ptvsd.wait_for_attach() 344 | 345 | trainer_config = get_trainer_config(args) 346 | 347 | # Log only on main process 348 | if args.local_rank not in [-1, 0]: 349 | sys.stdout = open("./runs/log_distributed_{}".format(args.local_rank), "w") # dump sdtout 350 | writer = DummyWriter() 351 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 352 | datefmt='%m/%d/%Y %H:%M:%S', level=logging.ERROR) 353 | logger = logging.getLogger(__file__) 354 | else: 355 | from datetime import datetime 356 | current_time = datetime.now().strftime('%b%d_%H-%M-%S') 357 | if args.single_input: 358 | comment = '_{}_{}_single'.format(args.model_type, args.data_type) 359 | else: 360 | if args.model_type == 'seq2seq': 361 | comment = '_seq2seq_multi_{}_{}'.format(args.data_type, args.attention_fusion_type) 362 | else: 363 | comment = '_{}_{}_{}_{}_{}'.format(args.model_type, args.data_type, args.attention_fusion_type, 364 | ('sm' if args.shared_module == 1 else 'nm'), ('sa' if args.shared_attention == 1 else 'na')) 365 | logdir = os.path.join('runs', current_time + comment) 366 | writer = SummaryWriter(logdir=logdir) 367 | logger = config_logger(os.path.join(logdir, 'train.log')) 368 | 369 | log_dir = writer.logdir 370 | logger.info("Training args: {}".format(args)) 371 | interrupt_checkpoint_path = os.path.join(log_dir, trainer_config.interrupt_checkpoint_path) 372 | last_checkpoint_path = os.path.join(log_dir, trainer_config.last_checkpoint_path) 373 | best_checkpoint_path = os.path.join(log_dir, 'best_model') 374 | logger.info("Logging to {}".format(log_dir)) # Let's save everything on an experiment in the ./runs/XXX/directory 375 | if args.local_rank in [-1, 0]: 376 | with open(os.path.join(log_dir, "trainer_config.json"), "w") as f: 377 | json.dump(trainer_config, f) 378 | 379 | set_seed(trainer_config.seed) 380 | device = torch.device(trainer_config.device) 381 | 382 | model, tokenizer = get_model_and_tokenizer(args, trainer_config, logger) 383 | logger.info('Load tokenizer, vocab size is %d', tokenizer.vocab_size if hasattr(tokenizer, 'vocab_size') else 384 | tokenizer.n_words) 385 | modify_model(args, model, tokenizer) 386 | training_procedure(args, trainer_config, model, tokenizer, device, writer, logger, best_checkpoint_path, 387 | last_checkpoint_path, interrupt_checkpoint_path, log_dir, test_data_type=args.test_data_type) 388 | 389 | if __name__ == '__main__': 390 | main() 391 | 
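For reference, here is a small standalone sketch of how `main()` above composes the TensorBoard run directory name; the argument values are made up for illustration:
```
import os
from datetime import datetime

# hypothetical argument values
model_type, data_type, fusion = 'gpt2', 'persona', 'sw'
shared_module, shared_attention, single_input = 0, 0, False

current_time = datetime.now().strftime('%b%d_%H-%M-%S')
if single_input:
    comment = '_{}_{}_single'.format(model_type, data_type)
elif model_type == 'seq2seq':
    comment = '_seq2seq_multi_{}_{}'.format(data_type, fusion)
else:
    comment = '_{}_{}_{}_{}_{}'.format(model_type, data_type, fusion,
                                       'sm' if shared_module == 1 else 'nm',
                                       'sa' if shared_attention == 1 else 'na')
print(os.path.join('runs', current_time + comment))
# e.g. runs/Nov05_10-30-00_gpt2_persona_sw_nm_na
```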
--------------------------------------------------------------------------------
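Finally, a minimal sketch of calling the evaluation entry point in `new_metrics.py` directly on two plain-text files (one response per line). The paths are hypothetical, and the NIST/BLEU part shells out to the perl scripts under `metrics/`, so perl and its XML libraries must be installed:
```
from new_metrics import nlp_metrics

# the 10th returned value is a hard-coded 0.0 placeholder for METEOR
bleu, bleu_list, nist, nist_list, nist_bleu, nist_bleu_list, s_dist, c_dist, \
    entropy, meteor, rouge_l, f1, avg_len = nlp_metrics(
        ref_file='runs/example/references_file',
        pred_file='runs/example/predictions_file_last',
        root_path='runs/example')

print('BLEU {:.2f}  NIST {:.2f}  ROUGE-L {:.2f}  F1 {:.2f}  avg len {:.1f}'.format(
    bleu, nist, rouge_l * 100, f1 * 100, avg_len))
```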