├── Appendix.pdf ├── LICENSE ├── README.md ├── bert └── .gitkeep ├── bert_data └── .gitkeep ├── json_data └── .gitkeep ├── logs └── .gitkeep ├── models └── .gitkeep ├── pretrain_emb └── .gitkeep ├── results └── .gitkeep └── src ├── distributed.py ├── models ├── __init__.py ├── adam.py ├── data_loader.py ├── decoder_rnn.py ├── decoder_tf.py ├── encoder.py ├── generator.py ├── hier_model.py ├── hier_model_trainer.py ├── hier_predictor.py ├── loss.py ├── neural.py ├── optimizers.py ├── reporter.py ├── rl_model.py ├── rl_model_trainer.py ├── rl_predictor.py ├── seq2seq.py ├── seq2seq_predictor.py ├── seq2seq_trainer.py ├── topic.py ├── topic_model.py └── topic_model_trainer.py ├── others ├── __init__.py ├── logging.py ├── tokenization.py ├── utils.py └── vocab_wrapper.py ├── prepro ├── __init__.py └── data_builder.py ├── preprocess.py ├── train.py ├── train_abstractive.py ├── train_emb.py └── translate ├── __init__.py ├── beam.py └── penalties.py /Appendix.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RowitZou/topic-dialog-summ/0de31d97b07be4004e08f9755ee66bea47aa7b10/Appendix.pdf -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Yicheng Zou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Topic-Dialog-Summ 2 | 3 | Pytorch implementation of the AAAI-2021 paper: [Topic-Oriented Spoken Dialogue Summarization for Customer Service with Saliency-Aware Topic Modeling](https://ojs.aaai.org/index.php/AAAI/article/view/17723/17530). 4 | 5 | The code is partially referred to https://github.com/nlpyang/PreSumm. 6 | 7 | ## Requirements 8 | 9 | * Python 3.6 or higher 10 | * torch==1.1.0 11 | * pytorch-transformers==1.1.0 12 | * torchtext==0.4.0 13 | * rouge==0.3.2 14 | * tensorboardX==2.1 15 | * nltk==3.5 16 | * gensim==3.8.3 17 | 18 | ## Environment 19 | 20 | * Tesla V100 16GB GPU 21 | * CUDA 10.2 22 | 23 | ## Data Format 24 | 25 | Each json file is a data list that includes dialogue samples. The format of a dialogue sample is shown as follows: 26 | 27 | ``` 28 | {"session": [ 29 | // Utterance 30 | { 31 | // Chinese characters 32 | "content": ["请", "问", "有", "什", "么", "可", "以", "帮", "您"], 33 | // Chinese Words 34 | "word": ["请问", "有", "什么", "可以", "帮", "您"], 35 | // Role info (Agent) 36 | "type": "客服" 37 | }, 38 | 39 | {"content": ["我", "想", "退", "货"], 40 | "word": ["我", "想", "退货"], 41 | // Role info (Customer) 42 | "type": "客户"}, 43 | 44 | ... 45 | ], 46 | "summary": ["客", "户", "来", "电", "要", "求", "退", "货", "。", ...] 47 | } 48 | ``` 49 | 50 | ## Usage 51 | 52 | 1. Download BERT checkpoints. 53 | 54 | The pretrained BERT checkpoints can be found at: 55 | 56 | * Chinese BERT: https://github.com/ymcui/Chinese-BERT-wwm 57 | * English BERT: https://github.com/google-research/bert 58 | 59 | Put BERT checkpoints into the directory **bert** like this: 60 | 61 | ``` 62 | --- bert 63 | | 64 | |--- chinese_bert 65 | | 66 | |--- config.json 67 | | 68 | |--- pytorch_model.bin 69 | | 70 | |--- vocab.txt 71 | ``` 72 | 73 | 2. Pre-train word2vec embeddings 74 | 75 | ``` 76 | PYTHONPATH=. python ./src/train_emb.py -data_path json_data -emb_size 100 -emb_path pretrain_emb/word2vec 77 | ``` 78 | 79 | 3. Data Processing 80 | 81 | ``` 82 | PYTHONPATH=. python ./src/preprocess.py -raw_path json_data -save_path bert_data -bert_dir bert/chinese_bert -log_file logs/preprocess.log -emb_path pretrain_emb/word2vec -tokenize -truncated -add_ex_label 83 | ``` 84 | 85 | 4. Pre-train the pipeline model (Ext + Abs) 86 | 87 | ``` 88 | PYTHONPATH=. python ./src/train.py -data_path bert_data/ali -bert_dir bert/chinese_bert -log_file logs/pipeline.topic.train.log -sep_optim -topic_model -split_noise -pretrain -model_path models/pipeline_topic 89 | ``` 90 | 91 | 5. Train the whole model with RL 92 | 93 | ``` 94 | PYTHONPATH=. python ./src/train.py -data_path bert_data/ali -bert_dir bert/chinese_bert -log_file logs/rl.topic.train.log -model_path models/rl_topic -topic_model -split_noise -train_from models/pipeline_topic/model_step_80000.pt -train_from_ignore_optim -lr 0.00001 -save_checkpoint_steps 500 -train_steps 30000 95 | ``` 96 | 97 | 6. Validate 98 | 99 | ``` 100 | PYTHONPATH=. python ./src/train.py -mode validate -data_path bert_data/ali -bert_dir bert/chinese_bert -log_file logs/rl.topic.val.log -alpha 0.95 -model_path models/rl_topic -topic_model -split_noise -result_path results/val 101 | ``` 102 | 103 | 7. Test 104 | 105 | ``` 106 | PYTHONPATH=. python ./src/train.py -mode test -data_path bert_data/ali -bert_dir bert/chinese_bert -test_from models/rl_topic/model_step_30000.pt -log_file logs/rl.topic.test.log -alpha 0.95 -topic_model -split_noise -result_path results/test 107 | ``` 108 | 109 | ## Data 110 | 111 | Our dialogue summarization dataset is collected from [Alibaba customer service center](https://114.1688.com/kf/contact.html). All dialogues are incoming calls in Mandarin Chinese that take place between a customer and a service agent. For the security of private information from customers, we performed the data desensitization and converted words to IDs. As a result, the data cannot be directly used in our released codes and other pre-trained models like BERT, but the dataset still provides some statistical information. 112 | 113 | The desensitized data is available at 114 | [Google Drive](https://drive.google.com/file/d/1X3-C9vTYfk43T5NIEvRsdRIJkN1RuG7b/view?usp=sharing) or [Baidu Pan](https://pan.baidu.com/s/1AvkGnerKpQHUNbwkz9kO7A) (extract code: t6nx). 115 | 116 | ## Citation 117 | @article{Zou_Zhao_Kang_Lin_Peng_Jiang_Sun_Zhang_Huang_Liu_2021, 118 | title={Topic-Oriented Spoken Dialogue Summarization for Customer Service with Saliency-Aware Topic Modeling}, 119 | volume={35}, 120 | url={https://ojs.aaai.org/index.php/AAAI/article/view/17723}, 121 | number={16}, 122 | journal={Proceedings of the AAAI Conference on Artificial Intelligence}, 123 | author={Zou, Yicheng and Zhao, Lujun and Kang, Yangyang and Lin, Jun and Peng, Minlong and Jiang, Zhuoren and Sun, Changlong and Zhang, Qi and Huang, Xuanjing and Liu, Xiaozhong}, 124 | year={2021}, 125 | month={May}, 126 | pages={14665-14673} 127 | } 128 | 129 | -------------------------------------------------------------------------------- /bert/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RowitZou/topic-dialog-summ/0de31d97b07be4004e08f9755ee66bea47aa7b10/bert/.gitkeep -------------------------------------------------------------------------------- /bert_data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RowitZou/topic-dialog-summ/0de31d97b07be4004e08f9755ee66bea47aa7b10/bert_data/.gitkeep -------------------------------------------------------------------------------- /json_data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RowitZou/topic-dialog-summ/0de31d97b07be4004e08f9755ee66bea47aa7b10/json_data/.gitkeep -------------------------------------------------------------------------------- /logs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RowitZou/topic-dialog-summ/0de31d97b07be4004e08f9755ee66bea47aa7b10/logs/.gitkeep -------------------------------------------------------------------------------- /models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RowitZou/topic-dialog-summ/0de31d97b07be4004e08f9755ee66bea47aa7b10/models/.gitkeep -------------------------------------------------------------------------------- /pretrain_emb/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RowitZou/topic-dialog-summ/0de31d97b07be4004e08f9755ee66bea47aa7b10/pretrain_emb/.gitkeep -------------------------------------------------------------------------------- /results/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RowitZou/topic-dialog-summ/0de31d97b07be4004e08f9755ee66bea47aa7b10/results/.gitkeep -------------------------------------------------------------------------------- /src/distributed.py: -------------------------------------------------------------------------------- 1 | """ Pytorch Distributed utils 2 | This piece of code was heavily inspired by the equivalent of Fairseq-py 3 | https://github.com/pytorch/fairseq 4 | """ 5 | 6 | 7 | from __future__ import print_function 8 | 9 | import math 10 | import pickle 11 | 12 | import torch.distributed 13 | 14 | from others.logging import logger 15 | 16 | 17 | def is_master(gpu_ranks, device_id): 18 | return gpu_ranks[device_id] == 0 19 | 20 | 21 | def multi_init(device_id, world_size, gpu_ranks): 22 | print(gpu_ranks) 23 | dist_init_method = 'tcp://localhost:10000' 24 | dist_world_size = world_size 25 | torch.distributed.init_process_group( 26 | backend='nccl', init_method=dist_init_method, 27 | world_size=dist_world_size, rank=gpu_ranks[device_id]) 28 | gpu_rank = torch.distributed.get_rank() 29 | if not is_master(gpu_ranks, device_id): 30 | logger.disabled = True 31 | 32 | return gpu_rank 33 | 34 | 35 | def all_reduce_and_rescale_tensors(tensors, rescale_denom, 36 | buffer_size=10485760): 37 | """All-reduce and rescale tensors in chunks of the specified size. 38 | 39 | Args: 40 | tensors: list of Tensors to all-reduce 41 | rescale_denom: denominator for rescaling summed Tensors 42 | buffer_size: all-reduce chunk size in bytes 43 | """ 44 | # buffer size in bytes, determine equiv. # of elements based on data type 45 | buffer_t = tensors[0].new( 46 | math.ceil(buffer_size / tensors[0].element_size())).zero_() 47 | buffer = [] 48 | 49 | def all_reduce_buffer(): 50 | # copy tensors into buffer_t 51 | offset = 0 52 | for t in buffer: 53 | numel = t.numel() 54 | buffer_t[offset:offset+numel].copy_(t.view(-1)) 55 | offset += numel 56 | 57 | # all-reduce and rescale 58 | torch.distributed.all_reduce(buffer_t[:offset]) 59 | buffer_t.div_(rescale_denom) 60 | 61 | # copy all-reduced buffer back into tensors 62 | offset = 0 63 | for t in buffer: 64 | numel = t.numel() 65 | t.view(-1).copy_(buffer_t[offset:offset+numel]) 66 | offset += numel 67 | 68 | filled = 0 69 | for t in tensors: 70 | sz = t.numel() * t.element_size() 71 | if sz > buffer_size: 72 | # tensor is bigger than buffer, all-reduce and rescale directly 73 | torch.distributed.all_reduce(t) 74 | t.div_(rescale_denom) 75 | elif filled + sz > buffer_size: 76 | # buffer is full, all-reduce and replace buffer with grad 77 | all_reduce_buffer() 78 | buffer = [t] 79 | filled = sz 80 | else: 81 | # add tensor to buffer 82 | buffer.append(t) 83 | filled += sz 84 | 85 | if len(buffer) > 0: 86 | all_reduce_buffer() 87 | 88 | 89 | def all_gather_list(data, max_size=4096): 90 | """Gathers arbitrary data from all nodes into a list.""" 91 | world_size = torch.distributed.get_world_size() 92 | if not hasattr(all_gather_list, '_in_buffer') or \ 93 | max_size != all_gather_list._in_buffer.size(): 94 | all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size) 95 | all_gather_list._out_buffers = [ 96 | torch.cuda.ByteTensor(max_size) 97 | for i in range(world_size) 98 | ] 99 | in_buffer = all_gather_list._in_buffer 100 | out_buffers = all_gather_list._out_buffers 101 | 102 | enc = pickle.dumps(data) 103 | enc_size = len(enc) 104 | if enc_size + 2 > max_size: 105 | raise ValueError( 106 | 'encoded data exceeds max_size: {}'.format(enc_size + 2)) 107 | assert max_size < 255*256 108 | in_buffer[0] = enc_size // 255 # this encoding works for max_size < 65k 109 | in_buffer[1] = enc_size % 255 110 | in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc)) 111 | 112 | torch.distributed.all_gather(out_buffers, in_buffer.cuda()) 113 | 114 | results = [] 115 | for i in range(world_size): 116 | out_buffer = out_buffers[i] 117 | size = (255 * out_buffer[0].item()) + out_buffer[1].item() 118 | 119 | bytes_list = bytes(out_buffer[2:size+2].tolist()) 120 | result = pickle.loads(bytes_list) 121 | results.append(result) 122 | return results 123 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RowitZou/topic-dialog-summ/0de31d97b07be4004e08f9755ee66bea47aa7b10/src/models/__init__.py -------------------------------------------------------------------------------- /src/models/adam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.optimizer import Optimizer 3 | 4 | 5 | class Adam(Optimizer): 6 | r"""Implements Adam algorithm. 7 | 8 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 9 | 10 | Arguments: 11 | params (iterable): iterable of parameters to optimize or dicts defining 12 | parameter groups 13 | lr (float, optional): learning rate (default: 1e-3) 14 | betas (Tuple[float, float], optional): coefficients used for computing 15 | running averages of gradient and its square (default: (0.9, 0.999)) 16 | eps (float, optional): term added to the denominator to improve 17 | numerical stability (default: 1e-8) 18 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 19 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 20 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 21 | (default: False) 22 | 23 | .. _Adam\: A Method for Stochastic Optimization: 24 | https://arxiv.org/abs/1412.6980 25 | .. _On the Convergence of Adam and Beyond: 26 | https://openreview.net/forum?id=ryQu7f-RZ 27 | """ 28 | 29 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 30 | weight_decay=0, amsgrad=False): 31 | if not 0.0 <= lr: 32 | raise ValueError("Invalid learning rate: {}".format(lr)) 33 | if not 0.0 <= eps: 34 | raise ValueError("Invalid epsilon value: {}".format(eps)) 35 | if not 0.0 <= betas[0] < 1.0: 36 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 37 | if not 0.0 <= betas[1] < 1.0: 38 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 39 | defaults = dict(lr=lr, betas=betas, eps=eps, 40 | weight_decay=weight_decay, amsgrad=amsgrad) 41 | super(Adam, self).__init__(params, defaults) 42 | 43 | def __setstate__(self, state): 44 | super(Adam, self).__setstate__(state) 45 | for group in self.param_groups: 46 | group.setdefault('amsgrad', False) 47 | 48 | def step(self, closure=None): 49 | """Performs a single optimization step. 50 | Arguments: 51 | closure (callable, optional): A closure that reevaluates the model 52 | and returns the loss. 53 | """ 54 | loss = None 55 | if closure is not None: 56 | loss = closure() 57 | 58 | for group in self.param_groups: 59 | for p in group['params']: 60 | if p.grad is None: 61 | continue 62 | grad = p.grad.data 63 | if grad.is_sparse: 64 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 65 | 66 | state = self.state[p] 67 | 68 | # State initialization 69 | if len(state) == 0: 70 | state['step'] = 0 71 | # Exponential moving average of gradient values 72 | state['next_m'] = torch.zeros_like(p.data) 73 | # Exponential moving average of squared gradient values 74 | state['next_v'] = torch.zeros_like(p.data) 75 | 76 | next_m, next_v = state['next_m'], state['next_v'] 77 | beta1, beta2 = group['betas'] 78 | 79 | # Decay the first and second moment running average coefficient 80 | # In-place operations to update the averages at the same time 81 | next_m.mul_(beta1).add_(1 - beta1, grad) 82 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 83 | update = next_m / (next_v.sqrt() + group['eps']) 84 | 85 | # Just adding the square of the weights to the loss function is *not* 86 | # the correct way of using L2 regularization/weight decay with Adam, 87 | # since that will interact with the m and v parameters in strange ways. 88 | # 89 | # Instead we want to decay the weights in a manner that doesn't interact 90 | # with the m/v parameters. This is equivalent to adding the square 91 | # of the weights to the loss with plain (non-momentum) SGD. 92 | if group['weight_decay'] > 0.0: 93 | update += group['weight_decay'] * p.data 94 | 95 | lr_scheduled = group['lr'] 96 | 97 | update_with_lr = lr_scheduled * update 98 | p.data.add_(-update_with_lr) 99 | 100 | state['step'] += 1 101 | 102 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 103 | # No bias correction 104 | # bias_correction1 = 1 - beta1 ** state['step'] 105 | # bias_correction2 = 1 - beta2 ** state['step'] 106 | 107 | return loss 108 | -------------------------------------------------------------------------------- /src/models/decoder_rnn.py: -------------------------------------------------------------------------------- 1 | """ Base Class and function for Decoders """ 2 | 3 | from __future__ import division 4 | import torch 5 | import torch.nn as nn 6 | 7 | from models.neural import aeq, DecoderState, GlobalAttention 8 | 9 | 10 | class RNNDecoder(nn.Module): 11 | """ 12 | Base recurrent attention-based decoder class. 13 | Specifies the interface used by different decoder types 14 | and required by :obj:`models.NMTModel`. 15 | .. mermaid:: 16 | graph BT 17 | A[Input] 18 | subgraph RNN 19 | C[Pos 1] 20 | D[Pos 2] 21 | E[Pos N] 22 | end 23 | G[Decoder State] 24 | H[Decoder State] 25 | I[Outputs] 26 | F[Memory_Bank] 27 | A--emb-->C 28 | A--emb-->D 29 | A--emb-->E 30 | H-->C 31 | C-- attn --- F 32 | D-- attn --- F 33 | E-- attn --- F 34 | C-->I 35 | D-->I 36 | E-->I 37 | E-->G 38 | F---I 39 | Args: 40 | rnn_type (:obj:`str`): 41 | style of recurrent unit to use, one of [LSTM, GRU] 42 | bidirectional_encoder (bool) : use with a bidirectional encoder 43 | num_layers (int) : number of stacked layers 44 | hidden_size (int) : hidden size of each layer 45 | attn_type (str) : see :obj:`onmt.modules.GlobalAttention` 46 | coverage_attn (str): see :obj:`onmt.modules.GlobalAttention` 47 | copy_attn (bool): setup a separate copy attention mechanism 48 | dropout (float) : dropout value for :obj:`nn.Dropout` 49 | embeddings (:obj:`onmt.modules.Embeddings`): embedding module to use 50 | """ 51 | 52 | def __init__(self, rnn_type, bidirectional_encoder, num_layers, 53 | hidden_size, attn_type="general", 54 | coverage_attn=False, copy_attn=False, 55 | dropout=0.0, embeddings=None, 56 | reuse_copy_attn=False): 57 | super(RNNDecoder, self).__init__() 58 | assert embeddings is not None 59 | # Basic attributes. 60 | self.decoder_type = 'rnn' 61 | self.bidirectional_encoder = bidirectional_encoder 62 | self.num_layers = num_layers 63 | self.hidden_size = hidden_size 64 | self.embeddings = embeddings 65 | self.dropout = nn.Dropout(dropout) 66 | input_size = self.embeddings.embedding_dim + self.hidden_size 67 | 68 | # Build the RNN. 69 | self.rnn = self._build_rnn(rnn_type, 70 | input_size=input_size, 71 | hidden_size=hidden_size, 72 | num_layers=num_layers, 73 | dropout=dropout) 74 | 75 | # Set up the standard attention. 76 | self._coverage = coverage_attn 77 | self.attn = GlobalAttention( 78 | hidden_size, coverage=coverage_attn, 79 | attn_type=attn_type 80 | ) 81 | 82 | # Set up a separated copy attention layer, if needed. 83 | self._copy = False 84 | if copy_attn and not reuse_copy_attn: 85 | self.copy_attn = GlobalAttention( 86 | hidden_size, attn_type=attn_type 87 | ) 88 | if copy_attn: 89 | self._copy = True 90 | self._reuse_copy_attn = reuse_copy_attn 91 | 92 | def _build_rnn(self, rnn_type, input_size, 93 | hidden_size, num_layers, dropout): 94 | if rnn_type == "LSTM": 95 | stacked_cell = StackedLSTM 96 | else: 97 | stacked_cell = StackedGRU 98 | return stacked_cell(num_layers, input_size, 99 | hidden_size, dropout) 100 | 101 | def _run_forward_pass(self, tgt, memory_bank, state, memory_lengths=None): 102 | """ 103 | See StdRNNDecoder._run_forward_pass() for description 104 | of arguments and return values. 105 | """ 106 | # Additional args check. 107 | input_feed = state.input_feed.squeeze(0) 108 | input_feed_batch, _ = input_feed.size() 109 | tgt_batch, _, _ = tgt.size() 110 | aeq(tgt_batch, input_feed_batch) 111 | # END Additional args check. 112 | 113 | # Initialize local and return variables. 114 | decoder_outputs = [] 115 | attns = {"std": []} 116 | if self._copy: 117 | attns["copy"] = [] 118 | if self.training and self._coverage: 119 | attns["coverage"] = [] 120 | 121 | hidden = state.hidden 122 | coverage = state.coverage.squeeze(0) \ 123 | if state.coverage is not None else None 124 | 125 | # Input feed concatenates hidden state with 126 | # input at every time step. 127 | for _, emb_t in enumerate(tgt.transpose(0, 1).split(1)): 128 | emb_t = emb_t.squeeze(0) 129 | decoder_input = torch.cat([emb_t, input_feed], 1) 130 | 131 | rnn_output, hidden = self.rnn(decoder_input, hidden) 132 | decoder_output, p_attn = self.attn( 133 | rnn_output, 134 | memory_bank, 135 | memory_lengths=memory_lengths) 136 | 137 | decoder_output = self.dropout(decoder_output) 138 | input_feed = decoder_output 139 | 140 | decoder_outputs += [decoder_output] 141 | attns["std"] += [p_attn] 142 | 143 | # Update the coverage attention. 144 | if self.training and self._coverage: 145 | coverage = coverage + p_attn \ 146 | if coverage is not None else p_attn 147 | attns["coverage"] += [coverage] 148 | 149 | # Run the forward pass of the copy attention layer. 150 | if self._copy and not self._reuse_copy_attn: 151 | _, copy_attn = self.copy_attn(decoder_output, 152 | memory_bank) 153 | attns["copy"] += [copy_attn] 154 | elif self._copy: 155 | attns["copy"] = attns["std"] 156 | # Return result. 157 | return hidden, decoder_outputs, attns 158 | 159 | def forward(self, tgt, memory_bank, state, memory_masks=None, 160 | step=None): 161 | # Check 162 | assert isinstance(state, RNNDecoderState) 163 | # tgt.size() returns tgt length and batch 164 | tgt_batch, _ = tgt.size() 165 | memory_batch, _, _ = memory_bank.size() 166 | aeq(tgt_batch, memory_batch) 167 | # END 168 | memory_lengths = memory_masks.sum(dim=1) 169 | emb = self.embeddings(tgt) 170 | # Run the forward pass of the RNN. 171 | decoder_final, decoder_outputs, attns = self._run_forward_pass( 172 | emb, memory_bank, state, memory_lengths=memory_lengths) 173 | 174 | # Update the state with the result. 175 | final_output = decoder_outputs[-1] 176 | coverage = None 177 | if "coverage" in attns: 178 | coverage = attns["coverage"][-1].unsqueeze(0) 179 | state.update_state(decoder_final, final_output.unsqueeze(0), coverage) 180 | 181 | # Concatenates sequence of tensors along a new dimension. 182 | # NOTE: v0.3 to 0.4: decoder_outputs / attns[*] may not be list 183 | # (in particular in case of SRU) it was not raising error in 0.3 184 | # since stack(Variable) was allowed. 185 | # In 0.4, SRU returns a tensor that shouldn't be stacke 186 | if type(decoder_outputs) == list: 187 | decoder_outputs = torch.stack(decoder_outputs).transpose(0, 1) 188 | 189 | for k in attns: 190 | if type(attns[k]) == list: 191 | attns[k] = torch.stack(attns[k]).transpose(0, 1) 192 | 193 | return decoder_outputs, state, attns 194 | 195 | def init_decoder_state(self, src, memory_bank, encoder_final): 196 | """ Init decoder state with last state of the encoder """ 197 | def _fix_enc_hidden(hidden): 198 | # The encoder hidden is (layers*directions) x batch x dim. 199 | # We need to convert it to layers x batch x (directions*dim). 200 | if self.bidirectional_encoder: 201 | hidden = torch.cat([hidden[0:hidden.size(0):2], 202 | hidden[1:hidden.size(0):2]], 2) 203 | return hidden 204 | 205 | if isinstance(encoder_final, tuple): # LSTM 206 | return RNNDecoderState(self.hidden_size, 207 | tuple([_fix_enc_hidden(enc_hid) 208 | for enc_hid in encoder_final])) 209 | else: # GRU 210 | return RNNDecoderState(self.hidden_size, 211 | _fix_enc_hidden(encoder_final)) 212 | 213 | 214 | class RNNDecoderState(DecoderState): 215 | """ Base class for RNN decoder state """ 216 | 217 | def __init__(self, hidden_size, rnnstate): 218 | """ 219 | Args: 220 | hidden_size (int): the size of hidden layer of the decoder. 221 | rnnstate: final hidden state from the encoder. 222 | transformed to shape: layers x batch x (directions*dim). 223 | """ 224 | if not isinstance(rnnstate, tuple): 225 | self.hidden = (rnnstate,) 226 | else: 227 | self.hidden = rnnstate 228 | self.coverage = None 229 | 230 | # Init the input feed. 231 | batch_size = self.hidden[0].size(1) 232 | h_size = (batch_size, hidden_size) 233 | self.input_feed = self.hidden[0].data.new(*h_size).zero_() \ 234 | .unsqueeze(0) 235 | 236 | @property 237 | def _all(self): 238 | return self.hidden + (self.input_feed,) 239 | 240 | def update_state(self, rnnstate, input_feed, coverage): 241 | """ Update decoder state """ 242 | if not isinstance(rnnstate, tuple): 243 | self.hidden = (rnnstate,) 244 | else: 245 | self.hidden = rnnstate 246 | self.input_feed = input_feed 247 | self.coverage = coverage 248 | 249 | def repeat_beam_size_times(self, beam_size): 250 | """ Repeat beam_size times along batch dimension. """ 251 | vars = [e.data.repeat(1, beam_size, 1) 252 | for e in self._all] 253 | self.hidden = tuple(vars[:-1]) 254 | self.input_feed = vars[-1] 255 | 256 | def map_batch_fn(self, fn): 257 | self.hidden = tuple(map(lambda x: fn(x, 1), self.hidden)) 258 | self.input_feed = fn(self.input_feed, 1) 259 | 260 | 261 | class StackedLSTM(nn.Module): 262 | """ 263 | Our own implementation of stacked LSTM. 264 | Needed for the decoder, because we do input feeding. 265 | """ 266 | 267 | def __init__(self, num_layers, input_size, rnn_size, dropout): 268 | super(StackedLSTM, self).__init__() 269 | self.dropout = nn.Dropout(dropout) 270 | self.num_layers = num_layers 271 | self.layers = nn.ModuleList() 272 | 273 | for _ in range(num_layers): 274 | self.layers.append(nn.LSTMCell(input_size, rnn_size)) 275 | input_size = rnn_size 276 | 277 | def forward(self, input_feed, hidden): 278 | h_0, c_0 = hidden 279 | h_1, c_1 = [], [] 280 | for i, layer in enumerate(self.layers): 281 | h_1_i, c_1_i = layer(input_feed, (h_0[i], c_0[i])) 282 | input_feed = h_1_i 283 | if i + 1 != self.num_layers: 284 | input_feed = self.dropout(input_feed) 285 | h_1 += [h_1_i] 286 | c_1 += [c_1_i] 287 | 288 | h_1 = torch.stack(h_1) 289 | c_1 = torch.stack(c_1) 290 | 291 | return input_feed, (h_1, c_1) 292 | 293 | 294 | class StackedGRU(nn.Module): 295 | """ 296 | Our own implementation of stacked GRU. 297 | Needed for the decoder, because we do input feeding. 298 | """ 299 | 300 | def __init__(self, num_layers, input_size, rnn_size, dropout): 301 | super(StackedGRU, self).__init__() 302 | self.dropout = nn.Dropout(dropout) 303 | self.num_layers = num_layers 304 | self.layers = nn.ModuleList() 305 | 306 | for _ in range(num_layers): 307 | self.layers.append(nn.GRUCell(input_size, rnn_size)) 308 | input_size = rnn_size 309 | 310 | def forward(self, input_feed, hidden): 311 | h_1 = [] 312 | for i, layer in enumerate(self.layers): 313 | h_1_i = layer(input_feed, hidden[0][i]) 314 | input_feed = h_1_i 315 | if i + 1 != self.num_layers: 316 | input_feed = self.dropout(input_feed) 317 | h_1 += [h_1_i] 318 | 319 | h_1 = torch.stack(h_1) 320 | return input_feed, (h_1,) 321 | -------------------------------------------------------------------------------- /src/models/decoder_tf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of "Attention is All You Need" 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | 9 | from models.encoder import PositionalEncoding 10 | from models.neural import MultiHeadedAttention, PositionwiseFeedForward, DecoderState 11 | 12 | MAX_SIZE = 5000 13 | 14 | 15 | class TransformerDecoderLayer(nn.Module): 16 | """ 17 | Args: 18 | d_model (int): the dimension of keys/values/queries in 19 | MultiHeadedAttention, also the input size of 20 | the first-layer of the PositionwiseFeedForward. 21 | heads (int): the number of heads for MultiHeadedAttention. 22 | d_ff (int): the second-layer of the PositionwiseFeedForward. 23 | dropout (float): dropout probability(0-1.0). 24 | self_attn_type (string): type of self-attention scaled-dot, average 25 | """ 26 | 27 | def __init__(self, d_model, heads, d_ff, dropout, topic=False, topic_dim=300, split_noise=False): 28 | super(TransformerDecoderLayer, self).__init__() 29 | 30 | self.self_attn = MultiHeadedAttention( 31 | heads, d_model, dropout=dropout) 32 | 33 | self.context_attn = MultiHeadedAttention( 34 | heads, d_model, dropout=dropout, topic=topic, topic_dim=topic_dim, split_noise=split_noise) 35 | self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout) 36 | self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6) 37 | self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6) 38 | self.drop = nn.Dropout(dropout) 39 | mask = self._get_attn_subsequent_mask(MAX_SIZE) 40 | # Register self.mask as a buffer in TransformerDecoderLayer, so 41 | # it gets TransformerDecoderLayer's cuda behavior automatically. 42 | self.register_buffer('mask', mask) 43 | 44 | def forward(self, inputs, memory_bank, src_pad_mask, tgt_pad_mask, previous_input=None, 45 | layer_cache=None, topic_vec=None, requires_att=False): 46 | """ 47 | Args: 48 | inputs (`FloatTensor`): `[batch_size x 1 x model_dim]` 49 | memory_bank (`FloatTensor`): `[batch_size x src_len x model_dim]` 50 | src_pad_mask (`LongTensor`): `[batch_size x 1 x src_len]` 51 | tgt_pad_mask (`LongTensor`): `[batch_size x 1 x 1]` 52 | 53 | Returns: 54 | (`FloatTensor`, `FloatTensor`, `FloatTensor`): 55 | 56 | * output `[batch_size x 1 x model_dim]` 57 | * attn `[batch_size x 1 x src_len]` 58 | * all_input `[batch_size x current_step x model_dim]` 59 | 60 | """ 61 | dec_mask = torch.gt(tgt_pad_mask + 62 | self.mask[:, :tgt_pad_mask.size(1), 63 | :tgt_pad_mask.size(1)], 0) 64 | input_norm = self.layer_norm_1(inputs) 65 | all_input = input_norm 66 | if previous_input is not None: 67 | all_input = torch.cat((previous_input, input_norm), dim=1) 68 | dec_mask = None 69 | 70 | query, _ = self.self_attn(all_input, all_input, input_norm, 71 | mask=dec_mask, 72 | layer_cache=layer_cache, 73 | type="self") 74 | 75 | query = self.drop(query) + inputs 76 | 77 | query_norm = self.layer_norm_2(query) 78 | mid, att = self.context_attn(memory_bank, memory_bank, query_norm, 79 | mask=src_pad_mask, 80 | layer_cache=layer_cache, 81 | type="context", 82 | topic_vec=topic_vec, 83 | requires_att=requires_att) 84 | mid = self.drop(mid) + query 85 | 86 | output = self.feed_forward(mid) 87 | 88 | return output, all_input, att 89 | # return output 90 | 91 | def _get_attn_subsequent_mask(self, size): 92 | """ 93 | Get an attention mask to avoid using the subsequent info. 94 | 95 | Args: 96 | size: int 97 | 98 | Returns: 99 | (`LongTensor`): 100 | 101 | * subsequent_mask `[1 x size x size]` 102 | """ 103 | attn_shape = (1, size, size) 104 | subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8') 105 | subsequent_mask = torch.from_numpy(subsequent_mask) 106 | return subsequent_mask 107 | 108 | 109 | class TransformerDecoder(nn.Module): 110 | """ 111 | The Transformer decoder from "Attention is All You Need". 112 | 113 | 114 | .. mermaid:: 115 | 116 | graph BT 117 | A[input] 118 | B[multi-head self-attn] 119 | BB[multi-head src-attn] 120 | C[feed forward] 121 | O[output] 122 | A --> B 123 | B --> BB 124 | BB --> C 125 | C --> O 126 | 127 | 128 | Args: 129 | num_layers (int): number of encoder layers. 130 | d_model (int): size of the model 131 | heads (int): number of heads 132 | d_ff (int): size of the inner FF layer 133 | dropout (float): dropout parameters 134 | embeddings (:obj:`onmt.modules.Embeddings`): 135 | embeddings to use, should have positional encodings 136 | attn_type (str): if using a seperate copy attention 137 | """ 138 | 139 | def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings=None, 140 | topic=False, topic_dim=300, split_noise=False): 141 | super(TransformerDecoder, self).__init__() 142 | 143 | # Basic attributes. 144 | self.decoder_type = 'transformer' 145 | self.num_layers = num_layers 146 | 147 | if embeddings is not None: 148 | self.embeddings = embeddings 149 | self.pos_emb = PositionalEncoding(dropout, self.embeddings.embedding_dim) 150 | 151 | # Build TransformerDecoder. 152 | self.transformer_layers = nn.ModuleList( 153 | [TransformerDecoderLayer(d_model, heads, d_ff, dropout, 154 | topic=topic, topic_dim=topic_dim, split_noise=split_noise) 155 | for _ in range(num_layers)]) 156 | 157 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 158 | 159 | def forward(self, tgt, memory_bank, state, init_tokens=None, 160 | step=None, cache=None, memory_masks=None, tgt_masks=None, 161 | requires_att=False, topic_vec=None): 162 | 163 | if tgt.dim() == 2: 164 | tgt_batch, tgt_len = tgt.size() 165 | 166 | # Run the forward pass of the TransformerDecoder. 167 | emb = self.embeddings(tgt) 168 | if init_tokens is not None: 169 | emb = torch.cat([init_tokens.unsqueeze(1), emb[:, 1:, :]], 1) 170 | assert emb.dim() == 3 # len x batch x embedding_dim 171 | 172 | output = self.pos_emb(emb, step) 173 | else: 174 | tgt_batch, tgt_len, _ = tgt.size() 175 | output = tgt 176 | 177 | if tgt_masks is not None: 178 | tgt_pad_mask = tgt_masks.unsqueeze(1).expand(tgt_batch, tgt_len, tgt_len) 179 | else: 180 | assert tgt.dim() == 2 181 | padding_idx = self.embeddings.padding_idx 182 | tgt_pad_mask = tgt.data.eq(padding_idx).unsqueeze(1) \ 183 | .expand(tgt_batch, tgt_len, tgt_len) 184 | 185 | src_memory_bank = memory_bank 186 | if memory_masks is not None: 187 | src_batch = memory_masks.size(0) 188 | src_len = memory_masks.size(-1) 189 | src_pad_mask = memory_masks.unsqueeze(1).expand(src_batch, tgt_len, src_len) 190 | else: 191 | src_batch = memory_bank.size(0) 192 | src_len = memory_bank.size(1) 193 | src_pad_mask = tgt_pad_mask.new_zeros([src_batch, tgt_len, src_len]) 194 | 195 | if state.cache is None: 196 | saved_inputs = [] 197 | 198 | for i in range(self.num_layers): 199 | prev_layer_input = None 200 | if state.cache is None: 201 | if state.previous_input is not None: 202 | prev_layer_input = state.previous_layer_inputs[i] 203 | output, all_input, last_layer_att \ 204 | = self.transformer_layers[i]( 205 | output, src_memory_bank, 206 | src_pad_mask, tgt_pad_mask, 207 | previous_input=prev_layer_input, 208 | layer_cache=state.cache["layer_{}".format(i)] 209 | if state.cache is not None else None, 210 | topic_vec=topic_vec, 211 | requires_att=False if i < self.num_layers-1 else requires_att) 212 | if state.cache is None: 213 | saved_inputs.append(all_input) 214 | 215 | if state.cache is None: 216 | saved_inputs = torch.stack(saved_inputs) 217 | 218 | output = self.layer_norm(output) 219 | 220 | # Process the result and update the attentions. 221 | 222 | if state.cache is None: 223 | state = state.update_state(tgt, saved_inputs) 224 | 225 | if requires_att and last_layer_att is not None: 226 | return output, state, {"copy": last_layer_att} 227 | else: 228 | return output, state, None 229 | 230 | def init_decoder_state(self, src, memory_bank, enc_hidden=None, 231 | with_cache=False): 232 | """ Init decoder state """ 233 | state = TransformerDecoderState(src) 234 | if with_cache: 235 | state._init_cache(memory_bank, self.num_layers) 236 | return state 237 | 238 | 239 | class TransformerDecoderState(DecoderState): 240 | """ Transformer Decoder state base class """ 241 | 242 | def __init__(self, src): 243 | """ 244 | Args: 245 | src (FloatTensor): a sequence of source words tensors 246 | with optional feature tensors, of size (len x batch). 247 | """ 248 | self.src = src 249 | self.previous_input = None 250 | self.previous_layer_inputs = None 251 | self.cache = None 252 | 253 | @property 254 | def _all(self): 255 | """ 256 | Contains attributes that need to be updated in self.beam_update(). 257 | """ 258 | if (self.previous_input is not None 259 | and self.previous_layer_inputs is not None): 260 | return (self.previous_input, 261 | self.previous_layer_inputs, 262 | self.src) 263 | else: 264 | return (self.src,) 265 | 266 | def detach(self): 267 | if self.previous_input is not None: 268 | self.previous_input = self.previous_input.detach() 269 | if self.previous_layer_inputs is not None: 270 | self.previous_layer_inputs = self.previous_layer_inputs.detach() 271 | self.src = self.src.detach() 272 | 273 | def update_state(self, new_input, previous_layer_inputs): 274 | state = TransformerDecoderState(self.src) 275 | state.previous_input = new_input 276 | state.previous_layer_inputs = previous_layer_inputs 277 | return state 278 | 279 | def _init_cache(self, memory_bank, num_layers): 280 | self.cache = {} 281 | 282 | for l in range(num_layers): 283 | layer_cache = { 284 | "memory_keys": None, 285 | "memory_values": None 286 | } 287 | layer_cache["self_keys"] = None 288 | layer_cache["self_values"] = None 289 | self.cache["layer_{}".format(l)] = layer_cache 290 | 291 | def repeat_beam_size_times(self, beam_size): 292 | """ Repeat beam_size times along batch dimension. """ 293 | self.src = self.src.data.repeat(1, beam_size, 1) 294 | 295 | def map_batch_fn(self, fn): 296 | def _recursive_map(struct, batch_dim=0): 297 | for k, v in struct.items(): 298 | if v is not None: 299 | if isinstance(v, dict): 300 | _recursive_map(v) 301 | else: 302 | struct[k] = fn(v, batch_dim) 303 | 304 | self.src = fn(self.src, 0) 305 | if self.cache is not None: 306 | _recursive_map(self.cache) 307 | -------------------------------------------------------------------------------- /src/models/encoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | from torch.nn.utils.rnn import pack_padded_sequence as pack 6 | from torch.nn.utils.rnn import pad_packed_sequence as unpack 7 | 8 | from pytorch_transformers import BertModel 9 | from models.neural import MultiHeadedAttention, PositionwiseFeedForward, rnn_factory 10 | 11 | 12 | class Classifier(nn.Module): 13 | def __init__(self, hidden_size): 14 | super(Classifier, self).__init__() 15 | self.linear1 = nn.Linear(hidden_size, 1) 16 | self.sigmoid = nn.Sigmoid() 17 | 18 | def forward(self, x, mask_cls): 19 | h = self.linear1(x).squeeze(-1) 20 | sent_scores = self.sigmoid(h) * mask_cls.float() 21 | return sent_scores 22 | 23 | 24 | class Bert(nn.Module): 25 | def __init__(self, temp_dir, finetune=False): 26 | super(Bert, self).__init__() 27 | self.model = BertModel.from_pretrained(temp_dir) 28 | 29 | self.finetune = finetune 30 | 31 | def forward(self, x, segs, mask): 32 | if(self.finetune): 33 | top_vec, _ = self.model(x, segs, attention_mask=mask) 34 | else: 35 | self.eval() 36 | with torch.no_grad(): 37 | top_vec, _ = self.model(x, segs, attention_mask=mask) 38 | return top_vec 39 | 40 | 41 | class PositionalEncoding(nn.Module): 42 | 43 | def __init__(self, dropout, dim, max_len=5000): 44 | 45 | pe = torch.zeros(max_len, dim) 46 | div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) * 47 | -(math.log(10000.0) / dim))) 48 | position = torch.arange(0, max_len).unsqueeze(1) 49 | pe[:, 0::2] = torch.sin(position.float() * div_term) 50 | pe[:, 1::2] = torch.cos(position.float() * div_term) 51 | pe = pe.unsqueeze(0) 52 | 53 | super(PositionalEncoding, self).__init__() 54 | self.register_buffer('pe', pe) 55 | self.dropout = nn.Dropout(p=dropout) 56 | self.dim = dim 57 | 58 | def forward(self, emb, step=None, add_emb=None): 59 | emb = emb * math.sqrt(self.dim) 60 | if add_emb is not None: 61 | emb = emb + add_emb 62 | if (step): 63 | pos = self.pe[:, step][:, None, :] 64 | emb = emb + pos 65 | else: 66 | pos = self.pe[:, :emb.size(1)] 67 | emb = emb + pos 68 | emb = self.dropout(emb) 69 | return emb 70 | 71 | 72 | class DistancePositionalEncoding(nn.Module): 73 | def __init__(self, dim, max_len=5000): 74 | mid_pos = max_len // 2 75 | # absolute position embedding 76 | ape = torch.zeros(max_len, dim // 2) 77 | # distance position embedding 78 | dpe = torch.zeros(max_len, dim // 2) 79 | 80 | ap = torch.arange(0, max_len).unsqueeze(1) 81 | dp = torch.abs(torch.arange(0, max_len).unsqueeze(1) - mid_pos) 82 | 83 | div_term = torch.exp((torch.arange(0, dim//2, 2, dtype=torch.float) * 84 | -(math.log(10000.0) / dim * 2))) 85 | ape[:, 0::2] = torch.sin(ap.float() * div_term) 86 | ape[:, 1::2] = torch.cos(ap.float() * div_term) 87 | dpe[:, 0::2] = torch.sin(dp.float() * div_term) 88 | dpe[:, 1::2] = torch.cos(dp.float() * div_term) 89 | 90 | ape = ape.unsqueeze(0) 91 | super(DistancePositionalEncoding, self).__init__() 92 | self.register_buffer('ape', ape) 93 | self.register_buffer('dpe', dpe) 94 | self.dim = dim 95 | self.mid_pos = mid_pos 96 | 97 | def forward(self, emb, shift): 98 | device = emb.device 99 | _, length, _ = emb.size() 100 | pe_seg = [len(ex) for ex in shift] 101 | medium_pos = [torch.cat([torch.tensor([0], device=device), 102 | (ex[1:] + ex[:-1]) // 2 + 1, 103 | torch.tensor([length], device=device)], 0) 104 | for ex in shift] 105 | shift = torch.cat(shift, 0) 106 | index = torch.arange(self.mid_pos, self.mid_pos + length, device=device).\ 107 | unsqueeze(0).expand(len(shift), length) - shift.unsqueeze(1) 108 | index = torch.split(index, pe_seg) 109 | dp_index = [] 110 | for i in range(len(index)): 111 | dpi = torch.zeros([length], device=device) 112 | for j in range(len(index[i])): 113 | dpi[medium_pos[i][j]:medium_pos[i][j+1]] = index[i][j][medium_pos[i][j]:medium_pos[i][j+1]] 114 | dp_index.append(dpi.unsqueeze(0)) 115 | dp_index = torch.cat(dp_index, 0).long() 116 | 117 | dpe = self.dpe[dp_index] 118 | ape = self.ape[:, :emb.size(1)].expand(emb.size(0), emb.size(1), -1) 119 | pe = torch.cat([dpe, ape], -1) 120 | emb = emb + pe 121 | return emb 122 | 123 | 124 | class RelativePositionalEncoding(nn.Module): 125 | def __init__(self, dim, max_len=5000): 126 | mid_pos = max_len // 2 127 | # relative position embedding 128 | pe = torch.zeros(max_len, dim) 129 | 130 | position = torch.arange(0, max_len).unsqueeze(1) - mid_pos 131 | 132 | div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) * 133 | -(math.log(10000.0) / dim))) 134 | pe[:, 0::2] = torch.sin(position.float() * div_term) 135 | pe[:, 1::2] = torch.cos(position.float() * div_term) 136 | 137 | super(RelativePositionalEncoding, self).__init__() 138 | self.register_buffer('pe', pe) 139 | self.dim = dim 140 | self.mid_pos = mid_pos 141 | 142 | def forward(self, emb, shift): 143 | device = emb.device 144 | bsz, length, _ = emb.size() 145 | index = torch.arange(self.mid_pos, self.mid_pos + emb.size(1), device=device).\ 146 | unsqueeze(0).expand(bsz, length) - shift.unsqueeze(1) 147 | pe = self.pe[index] 148 | emb = emb + pe 149 | return emb 150 | 151 | def get_emb(self, emb, shift): 152 | device = emb.device 153 | index = torch.arange(self.mid_pos, self.mid_pos + emb.size(1), device=device).\ 154 | unsqueeze(0).expand(emb.size(0), emb.size(1)) - shift.unsqueeze(1) 155 | return self.pe[index] 156 | 157 | 158 | class RNNEncoder(nn.Module): 159 | """ A generic recurrent neural network encoder. 160 | Args: 161 | rnn_type (str): 162 | style of recurrent unit to use, one of [RNN, LSTM, GRU] 163 | bidirectional (bool) : use a bidirectional RNN 164 | num_layers (int) : number of stacked layers 165 | hidden_size (int) : hidden size of each layer 166 | dropout (float) : dropout value for :class:`torch.nn.Dropout` 167 | """ 168 | 169 | def __init__(self, rnn_type, bidirectional, num_layers, 170 | hidden_size, dropout=0.0, embeddings=None): 171 | super(RNNEncoder, self).__init__() 172 | assert embeddings is not None 173 | 174 | num_directions = 2 if bidirectional else 1 175 | assert hidden_size % num_directions == 0 176 | hidden_size = hidden_size // num_directions 177 | self.embeddings = embeddings 178 | 179 | self.rnn = rnn_factory(rnn_type, 180 | input_size=embeddings.embedding_dim, 181 | hidden_size=hidden_size, 182 | num_layers=num_layers, 183 | dropout=dropout, 184 | bidirectional=bidirectional, 185 | batch_first=True) 186 | 187 | def forward(self, src, mask): 188 | 189 | emb = self.embeddings(src) 190 | # s_len, batch, emb_dim = emb.size() 191 | lengths = mask.sum(dim=1) 192 | 193 | # Lengths data is wrapped inside a Tensor. 194 | lengths_list = lengths.view(-1).tolist() 195 | packed_emb = pack(emb, lengths_list, batch_first=True, enforce_sorted=False) 196 | 197 | memory_bank, encoder_final = self.rnn(packed_emb) 198 | 199 | memory_bank = unpack(memory_bank, batch_first=True)[0] 200 | 201 | return memory_bank, encoder_final 202 | 203 | 204 | class TransformerEncoderLayer(nn.Module): 205 | def __init__(self, d_model, heads, d_ff, dropout): 206 | super(TransformerEncoderLayer, self).__init__() 207 | 208 | self.self_attn = MultiHeadedAttention( 209 | heads, d_model, dropout=dropout) 210 | self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout) 211 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 212 | self.dropout = nn.Dropout(dropout) 213 | 214 | def forward(self, iter, inputs, mask): 215 | if (iter != 0): 216 | input_norm = self.layer_norm(inputs) 217 | else: 218 | input_norm = inputs 219 | 220 | mask = mask.unsqueeze(1) 221 | context, _ = self.self_attn(input_norm, input_norm, input_norm, 222 | mask=mask, type='self') 223 | out = self.dropout(context) + inputs 224 | return self.feed_forward(out) 225 | 226 | 227 | class TransformerEncoder(nn.Module): 228 | def __init__(self, d_model, d_ff, heads, dropout, num_inter_layers=0): 229 | super(TransformerEncoder, self).__init__() 230 | self.num_inter_layers = num_inter_layers 231 | self.pos_emb = PositionalEncoding(dropout, d_model) 232 | self.transformer = nn.ModuleList( 233 | [TransformerEncoderLayer(d_model, heads, d_ff, dropout) 234 | for _ in range(num_inter_layers)]) 235 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 236 | 237 | def forward(self, top_vecs, mask): 238 | """ See :obj:`EncoderBase.forward()`""" 239 | 240 | x = self.pos_emb(top_vecs) 241 | 242 | for i in range(self.num_inter_layers): 243 | x = self.transformer[i](i, x, mask) # all_sents * max_tokens * dim 244 | 245 | output = self.layer_norm(x) 246 | 247 | return output 248 | -------------------------------------------------------------------------------- /src/models/generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.neural import aeq 4 | from models.neural import gumbel_softmax 5 | 6 | 7 | class Generator(nn.Module): 8 | def __init__(self, vocab_size, dec_hidden_size, pad_idx): 9 | super(Generator, self).__init__() 10 | self.linear = nn.Linear(dec_hidden_size, vocab_size) 11 | self.softmax = nn.LogSoftmax(dim=-1) 12 | self.pad_idx = pad_idx 13 | 14 | def forward(self, x, use_gumbel_softmax=False): 15 | output = self.linear(x) 16 | output[:, self.pad_idx] = -float('inf') 17 | if use_gumbel_softmax: 18 | output = gumbel_softmax(output, log_mode=True, dim=-1) 19 | else: 20 | output = self.softmax(output) 21 | return output 22 | 23 | 24 | class PointerNetGenerator(nn.Module): 25 | def __init__(self, mem_hidden_size, dec_hidden_size, hidden_size): 26 | super(PointerNetGenerator, self).__init__() 27 | self.terminate_state = nn.Parameter(torch.empty(1, mem_hidden_size)) 28 | self.linear_dec = nn.Linear(dec_hidden_size, hidden_size) 29 | self.linear_mem = nn.Linear(mem_hidden_size, hidden_size) 30 | self.score_linear = nn.Linear(hidden_size, 1) 31 | self.tanh = nn.Tanh() 32 | self.softmax = nn.LogSoftmax(dim=-1) 33 | 34 | def forward(self, mem, dec_hid, mem_mask, dec_mask, dup_mask): 35 | 36 | batch_size = mem.size(0) 37 | 38 | # Add terminate state 39 | mem = torch.cat([self.terminate_state.unsqueeze(0).expand(batch_size, 1, -1), mem], 1) 40 | mem_mask = torch.cat([torch.zeros([batch_size, 1], dtype=mem_mask.dtype, device=mem_mask.device), mem_mask], 1) 41 | 42 | mem_len = mem.size(1) 43 | dec_len = dec_hid.size(1) 44 | 45 | # batch * dec_len * mem_len * hid_size 46 | mem_expand = mem.unsqueeze(1).expand(batch_size, dec_len, mem_len, -1) 47 | dec_expand = dec_hid.unsqueeze(2).expand(batch_size, dec_len, mem_len, -1) 48 | mask_expand = mem_mask.unsqueeze(1).expand(batch_size, dec_len, mem_len) 49 | score = self.score_linear(self.tanh(self.linear_mem(mem_expand) + self.linear_dec(dec_expand))).squeeze_(-1) 50 | score[mask_expand] = -float('inf') 51 | 52 | # Avoid duplicate extraction. 53 | dup_mask[dec_mask, :] = 0 54 | if score.requires_grad: 55 | dup_mask = dup_mask.float() 56 | dup_mask[dup_mask == 1] = -float('inf') 57 | score = dup_mask + score 58 | else: 59 | score[dup_mask.byte()] = -float('inf') 60 | 61 | output = self.softmax(score) 62 | return output 63 | 64 | 65 | class CopyGenerator(nn.Module): 66 | """Generator module that additionally considers copying 67 | words directly from the source. 68 | The main idea is that we have an extended "dynamic dictionary". 69 | It contains `|tgt_dict|` words plus an arbitrary number of 70 | additional words introduced by the source sentence. 71 | For each source sentence we have a `src_map` that maps 72 | each source word to an index in `tgt_dict` if it known, or 73 | else to an extra word. 74 | The copy generator is an extended version of the standard 75 | generator that computes three values. 76 | * :math:`p_{softmax}` the standard softmax over `tgt_dict` 77 | * :math:`p(z)` the probability of copying a word from 78 | the source 79 | * :math:`p_{copy}` the probility of copying a particular word. 80 | taken from the attention distribution directly. 81 | The model returns a distribution over the extend dictionary, 82 | computed as 83 | :math:`p(w) = p(z=1) p_{copy}(w) + p(z=0) p_{softmax}(w)` 84 | .. mermaid:: 85 | graph BT 86 | A[input] 87 | S[src_map] 88 | B[softmax] 89 | BB[switch] 90 | C[attn] 91 | D[copy] 92 | O[output] 93 | A --> B 94 | A --> BB 95 | S --> D 96 | C --> D 97 | D --> O 98 | B --> O 99 | BB --> O 100 | Args: 101 | input_size (int): size of input representation 102 | output_size (int): size of output representation 103 | """ 104 | 105 | def __init__(self, output_size, input_size, pad_idx): 106 | super(CopyGenerator, self).__init__() 107 | self.linear = nn.Linear(input_size, output_size) 108 | self.linear_copy = nn.Linear(input_size, 1) 109 | self.softmax = nn.Softmax(dim=1) 110 | self.sigmoid = nn.Sigmoid() 111 | self.padding_idx = pad_idx 112 | 113 | def forward(self, hidden, attn, src_map): 114 | """ 115 | Compute a distribution over the target dictionary 116 | extended by the dynamic dictionary implied by compying 117 | source words. 118 | Args: 119 | hidden (`FloatTensor`): hidden outputs `[batch*tlen, input_size]` 120 | attn (`FloatTensor`): attn for each `[batch*tlen, input_size]` 121 | src_map (`FloatTensor`): 122 | A sparse indicator matrix mapping each source word to 123 | its index in the "extended" vocab containing. 124 | `[src_len, batch, extra_words]` 125 | """ 126 | # CHECKS 127 | batch_by_tlen, _ = hidden.size() 128 | batch_by_tlen_, slen = attn.size() 129 | batch, slen_, cvocab = src_map.size() 130 | aeq(batch_by_tlen, batch_by_tlen_) 131 | aeq(slen, slen_) 132 | 133 | # Original probabilities. 134 | logits = self.linear(hidden) 135 | logits[:, self.padding_idx] = -float('inf') 136 | prob = self.softmax(logits) 137 | 138 | # Probability of copying p(z=1) batch. 139 | p_copy = self.sigmoid(self.linear_copy(hidden)) 140 | # Probibility of not copying: p_{word}(w) * (1 - p(z)) 141 | out_prob = torch.mul(prob, 1 - p_copy.expand_as(prob)) 142 | mul_attn = torch.mul(attn, p_copy.expand_as(attn)) 143 | copy_prob = torch.bmm(mul_attn.view(batch, -1, slen), src_map) 144 | copy_prob = copy_prob.view(-1, cvocab) 145 | return torch.cat([out_prob, copy_prob], 1) 146 | 147 | 148 | def collapse_copy_scores(scores, batch, tgt_vocab, batch_index=None): 149 | """ 150 | Given scores from an expanded dictionary 151 | corresponeding to a batch, sums together copies, 152 | with a dictionary word when it is ambigious. 153 | """ 154 | offset = len(tgt_vocab) 155 | for b in range(scores.size(0)): 156 | blank = [] 157 | fill = [] 158 | 159 | if batch_index is not None: 160 | src_vocab = batch.src_vocabs[batch_index[b]] 161 | else: 162 | src_vocab = batch.src_vocabs[b] 163 | 164 | for i in range(1, len(src_vocab)): 165 | ti = src_vocab.itos[i] 166 | if ti != 0: 167 | blank.append(offset + i) 168 | fill.append(ti) 169 | if blank: 170 | blank = torch.tensor(blank, device=scores.device) 171 | fill = torch.tensor(fill, device=scores.device) 172 | scores[b, :].index_add_(1, fill, 173 | scores[b, :].index_select(1, blank)) 174 | scores[b, :].index_fill_(1, blank, 1e-10) 175 | return scores 176 | -------------------------------------------------------------------------------- /src/models/hier_model_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from tensorboardX import SummaryWriter 5 | 6 | import distributed 7 | from models.reporter import ReportMgr, Statistics 8 | from models.loss import abs_loss 9 | from others.logging import logger 10 | 11 | 12 | def _tally_parameters(model): 13 | n_params = sum([p.nelement() for p in model.parameters()]) 14 | return n_params 15 | 16 | 17 | def build_trainer(args, device_id, model, optims, tokenizer): 18 | """ 19 | Simplify `Trainer` creation based on user `opt`s* 20 | Args: 21 | opt (:obj:`Namespace`): user options (usually from argument parsing) 22 | model (:obj:`onmt.models.NMTModel`): the model to train 23 | fields (dict): dict of fields 24 | optim (:obj:`onmt.utils.Optimizer`): optimizer used during training 25 | data_type (str): string describing the type of data 26 | e.g. "text", "img", "audio" 27 | model_saver(:obj:`onmt.models.ModelSaverBase`): the utility object 28 | used to save the model 29 | """ 30 | device = "cpu" if args.visible_gpus == '-1' else "cuda" 31 | 32 | grad_accum_count = args.accum_count 33 | n_gpu = args.world_size 34 | 35 | if device_id >= 0: 36 | gpu_rank = int(args.gpu_ranks[device_id]) 37 | else: 38 | gpu_rank = 0 39 | n_gpu = 0 40 | 41 | print('gpu_rank %d' % gpu_rank) 42 | 43 | tensorboard_log_dir = args.model_path 44 | 45 | writer = SummaryWriter(tensorboard_log_dir, comment="Unmt") 46 | 47 | report_manager = ReportMgr(args.report_every, start_time=-1, tensorboard_writer=writer) 48 | 49 | symbols = {'BOS': tokenizer.vocab['[unused1]'], 'EOS': tokenizer.vocab['[unused2]'], 50 | 'PAD': tokenizer.vocab['[PAD]'], 'SEG': tokenizer.vocab['[unused3]'], 51 | 'UNK': tokenizer.vocab['[UNK]']} 52 | 53 | gen_loss = abs_loss(args, model.generator, symbols, tokenizer.vocab, device, train=True) 54 | 55 | trainer = Trainer(args, model, optims, tokenizer, gen_loss, 56 | grad_accum_count, n_gpu, gpu_rank, report_manager) 57 | 58 | # print(tr) 59 | if (model): 60 | n_params = _tally_parameters(model) 61 | logger.info('* number of parameters: %d' % n_params) 62 | 63 | return trainer 64 | 65 | 66 | class Trainer(object): 67 | """ 68 | Class that controls the training process. 69 | 70 | Args: 71 | model(:py:class:`onmt.models.model.NMTModel`): translation model 72 | to train 73 | train_loss(:obj:`onmt.utils.loss.LossComputeBase`): 74 | training loss computation 75 | valid_loss(:obj:`onmt.utils.loss.LossComputeBase`): 76 | training loss computation 77 | optim(:obj:`onmt.utils.optimizers.Optimizer`): 78 | the optimizer responsible for update 79 | trunc_size(int): length of truncated back propagation through time 80 | shard_size(int): compute loss in shards of this size for efficiency 81 | data_type(string): type of the source input: [text|img|audio] 82 | norm_method(string): normalization methods: [sents|tokens] 83 | grad_accum_count(int): accumulate gradients this many times. 84 | report_manager(:obj:`onmt.utils.ReportMgrBase`): 85 | the object that creates reports, or None 86 | model_saver(:obj:`onmt.models.ModelSaverBase`): the saver is 87 | used to save a checkpoint. 88 | Thus nothing will be saved if this parameter is None 89 | """ 90 | 91 | def __init__(self, args, model, optims, tokenizer, abs_loss, 92 | grad_accum_count=1, n_gpu=1, gpu_rank=1, 93 | report_manager=None): 94 | # Basic attributes. 95 | self.args = args 96 | self.save_checkpoint_steps = args.save_checkpoint_steps 97 | self.model = model 98 | self.optims = optims 99 | self.tokenizer = tokenizer 100 | self.grad_accum_count = grad_accum_count 101 | self.n_gpu = n_gpu 102 | self.gpu_rank = gpu_rank 103 | self.report_manager = report_manager 104 | 105 | self.abs_loss = abs_loss 106 | 107 | assert grad_accum_count > 0 108 | # Set model in training mode. 109 | if (model): 110 | self.model.train() 111 | 112 | def train(self, train_iter_fct, train_steps, valid_iter_fct=None, valid_steps=-1): 113 | """ 114 | The main training loops. 115 | by iterating over training data (i.e. `train_iter_fct`) 116 | and running validation (i.e. iterating over `valid_iter_fct` 117 | 118 | Args: 119 | train_iter_fct(function): a function that returns the train 120 | iterator. e.g. something like 121 | train_iter_fct = lambda: generator(*args, **kwargs) 122 | valid_iter_fct(function): same as train_iter_fct, for valid data 123 | train_steps(int): 124 | valid_steps(int): 125 | save_checkpoint_steps(int): 126 | 127 | Return: 128 | None 129 | """ 130 | logger.info('Start training...') 131 | 132 | step = self.optims[0]._step + 1 133 | true_batchs = [] 134 | accum = 0 135 | tgt_tokens = 0 136 | src_tokens = 0 137 | sents = 0 138 | examples = 0 139 | 140 | train_iter = train_iter_fct() 141 | total_stats = Statistics() 142 | report_stats = Statistics() 143 | self._start_report_manager(start_time=total_stats.start_time) 144 | 145 | while step <= train_steps: 146 | 147 | for i, batch in enumerate(train_iter): 148 | if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank): 149 | 150 | true_batchs.append(batch) 151 | tgt_tokens += batch.tgt[:, 1:].ne(self.abs_loss.padding_idx).sum().item() 152 | src_tokens += batch.src[:, 1:].ne(self.abs_loss.padding_idx).sum().item() 153 | sents += batch.src.size(0) 154 | examples += batch.tgt.size(0) 155 | accum += 1 156 | if accum == self.grad_accum_count: 157 | if self.n_gpu > 1: 158 | tgt_tokens = sum(distributed.all_gather_list(tgt_tokens)) 159 | src_tokens = sum(distributed.all_gather_list(src_tokens)) 160 | sents = sum(distributed.all_gather_list(sents)) 161 | examples = sum(distributed.all_gather_list(examples)) 162 | 163 | normalization = (tgt_tokens, src_tokens, sents, examples) 164 | self._gradient_calculation( 165 | true_batchs, normalization, total_stats, 166 | report_stats, step) 167 | 168 | report_stats = self._maybe_report_training( 169 | step, train_steps, 170 | self.optims[0].learning_rate, 171 | report_stats) 172 | 173 | true_batchs = [] 174 | accum = 0 175 | src_tokens = 0 176 | tgt_tokens = 0 177 | sents = 0 178 | examples = 0 179 | if (step % self.save_checkpoint_steps == 0 and self.gpu_rank == 0): 180 | self._save(step) 181 | step += 1 182 | if step > train_steps: 183 | break 184 | train_iter = train_iter_fct() 185 | 186 | return total_stats 187 | 188 | def _gradient_calculation(self, true_batchs, normalization, total_stats, 189 | report_stats, step): 190 | self.model.zero_grad() 191 | 192 | for batch in true_batchs: 193 | outputs, _, topic_loss = self.model(batch) 194 | 195 | tgt_tokens, src_tokens, sents, examples = normalization 196 | 197 | if self.args.topic_model: 198 | # Topic Model loss 199 | topic_stats = Statistics(topic_loss=topic_loss.clone().item() / float(examples)) 200 | topic_loss.div(float(examples)).backward(retain_graph=True) 201 | total_stats.update(topic_stats) 202 | report_stats.update(topic_stats) 203 | 204 | # Auto-encoder loss 205 | abs_stats = self.abs_loss(batch, outputs, self.args.generator_shard_size, 206 | tgt_tokens, retain_graph=False) 207 | abs_stats.n_docs = len(batch) 208 | total_stats.update(abs_stats) 209 | report_stats.update(abs_stats) 210 | 211 | # in case of multi step gradient accumulation, 212 | # update only after accum batches 213 | if self.n_gpu > 1: 214 | grads = [p.grad.data for p in self.model.parameters() 215 | if p.requires_grad 216 | and p.grad is not None] 217 | distributed.all_reduce_and_rescale_tensors( 218 | grads, float(1)) 219 | for o in self.optims: 220 | o.step() 221 | 222 | def _save(self, step): 223 | real_model = self.model 224 | 225 | model_state_dict = real_model.state_dict() 226 | # generator_state_dict = real_generator.state_dict() 227 | checkpoint = { 228 | 'model': model_state_dict, 229 | # 'generator': generator_state_dict, 230 | 'opt': self.args, 231 | 'optims': self.optims, 232 | } 233 | checkpoint_path = os.path.join(self.args.model_path, 'model_step_%d.pt' % step) 234 | logger.info("Saving checkpoint %s" % checkpoint_path) 235 | # checkpoint_path = '%s_step_%d.pt' % (FLAGS.model_path, step) 236 | torch.save(checkpoint, checkpoint_path) 237 | return checkpoint, checkpoint_path 238 | 239 | def _start_report_manager(self, start_time=None): 240 | """ 241 | Simple function to start report manager (if any) 242 | """ 243 | if self.report_manager is not None: 244 | if start_time is None: 245 | self.report_manager.start() 246 | else: 247 | self.report_manager.start_time = start_time 248 | 249 | def _maybe_gather_stats(self, stat): 250 | """ 251 | Gather statistics in multi-processes cases 252 | 253 | Args: 254 | stat(:obj:onmt.utils.Statistics): a Statistics object to gather 255 | or None (it returns None in this case) 256 | 257 | Returns: 258 | stat: the updated (or unchanged) stat object 259 | """ 260 | if stat is not None and self.n_gpu > 1: 261 | return Statistics.all_gather_stats(stat) 262 | return stat 263 | 264 | def _maybe_report_training(self, step, num_steps, learning_rate, 265 | report_stats): 266 | """ 267 | Simple function to report training stats (if report_manager is set) 268 | see `onmt.utils.ReportManagerBase.report_training` for doc 269 | """ 270 | if self.report_manager is not None: 271 | return self.report_manager.report_training( 272 | step, num_steps, learning_rate, report_stats, 273 | multigpu=self.n_gpu > 1) 274 | 275 | def _report_step(self, learning_rate, step, train_stats=None, 276 | valid_stats=None): 277 | """ 278 | Simple function to report stats (if report_manager is set) 279 | see `onmt.utils.ReportManagerBase.report_step` for doc 280 | """ 281 | if self.report_manager is not None: 282 | return self.report_manager.report_step( 283 | learning_rate, step, train_stats=train_stats, 284 | valid_stats=valid_stats) 285 | 286 | def _maybe_save(self, step): 287 | """ 288 | Save the model if a model saver is set 289 | """ 290 | if self.model_saver is not None: 291 | self.model_saver.maybe_save(step) 292 | -------------------------------------------------------------------------------- /src/models/hier_predictor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*-coding:utf8-*- 3 | """ Translator Class and builder """ 4 | from __future__ import print_function 5 | import codecs 6 | import torch 7 | 8 | from tensorboardX import SummaryWriter 9 | from others.utils import rouge_results_to_str, test_bleu, test_length 10 | from translate.beam import GNMTGlobalScorer 11 | from rouge import Rouge, FilesRouge 12 | from nltk.translate.bleu_score import sentence_bleu 13 | from nltk.translate.bleu_score import SmoothingFunction 14 | 15 | 16 | def build_predictor(args, tokenizer, model, logger=None): 17 | scorer = GNMTGlobalScorer(args.alpha, length_penalty='wu') 18 | 19 | translator = Translator(args, model, tokenizer, 20 | global_scorer=scorer, logger=logger) 21 | return translator 22 | 23 | 24 | class Translator(object): 25 | """ 26 | Uses a model to translate a batch of sentences. 27 | 28 | 29 | Args: 30 | model (:obj:`onmt.modules.NMTModel`): 31 | NMT model to use for translation 32 | fields (dict of Fields): data fields 33 | beam_size (int): size of beam to use 34 | n_best (int): number of translations produced 35 | max_length (int): maximum length output to produce 36 | global_scores (:obj:`GlobalScorer`): 37 | object to rescore final translations 38 | copy_attn (bool): use copy attention during translation 39 | cuda (bool): use cuda 40 | beam_trace (bool): trace beam search for debugging 41 | logger(logging.Logger): logger. 42 | """ 43 | 44 | def __init__(self, 45 | args, 46 | model, 47 | tokenizer, 48 | global_scorer=None, 49 | logger=None, 50 | dump_beam=""): 51 | self.logger = logger 52 | self.cuda = args.visible_gpus != '-1' 53 | 54 | self.args = args 55 | self.model = model 56 | self.generator = self.model.generator 57 | self.tokenizer = tokenizer 58 | self.vocab = tokenizer.vocab 59 | self.start_token = self.vocab['[unused1]'] 60 | self.end_token = self.vocab['[unused2]'] 61 | self.seg_token = self.vocab['[unused3]'] 62 | 63 | self.global_scorer = global_scorer 64 | self.beam_size = args.beam_size 65 | self.min_length = args.min_length 66 | self.max_length = args.max_length 67 | 68 | self.dump_beam = dump_beam 69 | 70 | # for debugging 71 | self.beam_trace = self.dump_beam != "" 72 | self.beam_accum = None 73 | 74 | tensorboard_log_dir = args.model_path 75 | 76 | self.tensorboard_writer = SummaryWriter( 77 | tensorboard_log_dir, comment="Unmt") 78 | 79 | if self.beam_trace: 80 | self.beam_accum = { 81 | "predicted_ids": [], 82 | "beam_parent_ids": [], 83 | "scores": [], 84 | "log_probs": []} 85 | 86 | def _build_target_tokens(self, pred): 87 | # vocab = self.fields["tgt"].vocab 88 | tokens = [] 89 | for tok in pred: 90 | tok = int(tok) 91 | tokens.append(tok) 92 | if tokens[-1] == self.end_token: 93 | tokens = tokens[:-1] 94 | break 95 | tokens = [t for t in tokens if t < len(self.vocab)] 96 | tokens = self.vocab.DecodeIds(tokens).split(' ') 97 | return tokens 98 | 99 | def from_batch_dev(self, doc_batch, tgt_data): 100 | 101 | translations = [] 102 | 103 | batch_size = len(doc_batch) 104 | 105 | for b in range(batch_size): 106 | 107 | # generated text 108 | pred_summ = self.tokenizer.convert_ids_to_tokens( 109 | [int(n) for n in doc_batch[b]]) 110 | pred_summ = ' '.join(pred_summ) 111 | pred_summ = pred_summ.replace('[unused0]', '').replace('[unused1]', '').\ 112 | replace('[unused2]', '').replace('[unused5]', '#').replace('[UNK]', '').strip() 113 | pred_summ = ' '.join(pred_summ.split()) 114 | 115 | gold_data = ' '.join(tgt_data[b]) 116 | gold_data = gold_data.replace('[PAD]', '').replace('[unused1]', '').\ 117 | replace('[unused2]', '').replace('[unused5]', '#').replace("[UNK]", '').strip() 118 | gold_data = ' '.join(gold_data.split()) 119 | 120 | translations.append((pred_summ, gold_data)) 121 | 122 | return translations 123 | 124 | def from_batch_test(self, batch, output_batch, tgt_data): 125 | 126 | translations = [] 127 | 128 | batch_size = len(batch) 129 | 130 | origin_txt, ex_segs = batch.original_str, batch.ex_segs 131 | 132 | ex_segs = [sum(ex_segs[:i]) for i in range(len(ex_segs)+1)] 133 | 134 | for b in range(batch_size): 135 | # original text 136 | original_sent = ' '.join(origin_txt[ex_segs[b]:ex_segs[b+1]]) 137 | 138 | # long doc context text 139 | pred_summ = self.tokenizer.convert_ids_to_tokens( 140 | [int(n) for n in output_batch[b]]) 141 | pred_summ = ' '.join(pred_summ) 142 | 143 | pred_summ = pred_summ.replace('[unused0]', '').replace('[unused1]', '').\ 144 | replace('[unused2]', '').replace('[unused5]', '#').replace('[UNK]', '').strip() 145 | pred_summ = ' '.join(pred_summ.split()) 146 | 147 | gold_data = ' '.join(tgt_data[b]) 148 | gold_data = gold_data.replace('[PAD]', '').replace('[unused1]', '').replace('[unused2]', '').\ 149 | replace('[unused5]', '#').replace('[UNK]', '').strip() 150 | gold_data = ' '.join(gold_data.split()) 151 | 152 | translation = (original_sent, pred_summ, gold_data) 153 | translations.append(translation) 154 | 155 | return translations 156 | 157 | def validate(self, data_iter, step, attn_debug=False): 158 | 159 | self.model.eval() 160 | gold_path = self.args.result_path + 'step.%d.gold_temp' % step 161 | pred_path = self.args.result_path + 'step.%d.pred_temp' % step 162 | gold_out_file = codecs.open(gold_path, 'w', 'utf-8') 163 | pred_out_file = codecs.open(pred_path, 'w', 'utf-8') 164 | 165 | # pred_results, gold_results = [], [] 166 | ct = 0 167 | with torch.no_grad(): 168 | for batch in data_iter: 169 | output_data, tgt_data = self.translate_batch(batch) 170 | translations = self.from_batch_dev(output_data, tgt_data) 171 | 172 | for idx in range(len(translations)): 173 | if ct % 100 == 0: 174 | print("Processing %d" % ct) 175 | pred_summ, gold_data = translations[idx] 176 | pred_out_file.write(pred_summ + '\n') 177 | gold_out_file.write(gold_data + '\n') 178 | ct += 1 179 | pred_out_file.flush() 180 | gold_out_file.flush() 181 | 182 | pred_out_file.close() 183 | gold_out_file.close() 184 | 185 | if (step != -1): 186 | pred_bleu = test_bleu(pred_path, gold_path) 187 | file_rouge = FilesRouge(hyp_path=pred_path, ref_path=gold_path) 188 | pred_rouges = file_rouge.get_scores(avg=True) 189 | self.logger.info('Gold Length at step %d: %.2f' % 190 | (step, test_length(gold_path, gold_path, ratio=False))) 191 | self.logger.info('Prediction Length ratio at step %d: %.2f' % 192 | (step, test_length(pred_path, gold_path))) 193 | self.logger.info('Prediction Bleu at step %d: %.2f' % 194 | (step, pred_bleu*100)) 195 | self.logger.info('Prediction Rouges at step %d: \n%s\n' % 196 | (step, rouge_results_to_str(pred_rouges))) 197 | rouge_results = (pred_rouges["rouge-1"]['f'], 198 | pred_rouges["rouge-l"]['f']) 199 | return rouge_results 200 | 201 | def translate(self, 202 | data_iter, step, 203 | attn_debug=False): 204 | 205 | self.model.eval() 206 | output_path = self.args.result_path + '.%d.output' % step 207 | output_file = codecs.open(output_path, 'w', 'utf-8') 208 | gold_path = self.args.result_path + '.%d.gold_test' % step 209 | pred_path = self.args.result_path + '.%d.pred_test' % step 210 | gold_out_file = codecs.open(gold_path, 'w', 'utf-8') 211 | pred_out_file = codecs.open(pred_path, 'w', 'utf-8') 212 | # pred_results, gold_results = [], [] 213 | ct = 0 214 | 215 | with torch.no_grad(): 216 | rouge = Rouge() 217 | for batch in data_iter: 218 | output_data, tgt_data = self.translate_batch(batch) 219 | translations = self.from_batch_test(batch, output_data, tgt_data) 220 | 221 | for idx in range(len(translations)): 222 | origin_sent, pred_summ, gold_data = translations[idx] 223 | if ct % 100 == 0: 224 | print("Processing %d" % ct) 225 | output_file.write("ID : %d\n" % ct) 226 | output_file.write("ORIGIN : \n " + origin_sent.replace('', '\n ') + "\n") 227 | output_file.write("GOLD : " + gold_data.strip() + "\n") 228 | output_file.write("DOC_GEN : " + pred_summ.strip() + "\n") 229 | rouge_score = rouge.get_scores(pred_summ, gold_data) 230 | bleu_score = sentence_bleu([gold_data.split()], pred_summ.split(), 231 | smoothing_function=SmoothingFunction().method1) 232 | output_file.write("DOC_GEN bleu & rouge-f 1/l: %.4f & %.4f/%.4f\n\n" % 233 | (bleu_score, rouge_score[0]["rouge-1"]["f"], rouge_score[0]["rouge-l"]["f"])) 234 | pred_out_file.write(pred_summ.strip() + '\n') 235 | gold_out_file.write(gold_data.strip() + '\n') 236 | ct += 1 237 | pred_out_file.flush() 238 | gold_out_file.flush() 239 | output_file.flush() 240 | 241 | pred_out_file.close() 242 | gold_out_file.close() 243 | output_file.close() 244 | 245 | if (step != -1): 246 | pred_bleu = test_bleu(pred_path, gold_path) 247 | file_rouge = FilesRouge(hyp_path=pred_path, ref_path=gold_path) 248 | pred_rouges = file_rouge.get_scores(avg=True) 249 | self.logger.info('Gold Length at step %d: %.2f\n' % 250 | (step, test_length(gold_path, gold_path, ratio=False))) 251 | self.logger.info('Prediction Length ratio at step %d: %.2f' % 252 | (step, test_length(pred_path, gold_path))) 253 | self.logger.info('Prediction Bleu at step %d: %.2f' % 254 | (step, pred_bleu*100)) 255 | self.logger.info('Prediction Rouges at step %d: \n%s' % 256 | (step, rouge_results_to_str(pred_rouges))) 257 | 258 | def translate_batch(self, batch): 259 | """ 260 | Translate a batch of sentences. 261 | 262 | Mostly a wrapper around :obj:`Beam`. 263 | 264 | Args: 265 | batch (:obj:`Batch`): a batch from a dataset object 266 | data (:obj:`Dataset`): the dataset object 267 | fast (bool): enables fast beam search (may not support all features) 268 | 269 | Todo: 270 | Shouldn't need the original dataset. 271 | """ 272 | _, output_data, _ = self.model(batch) 273 | tgt_txt = batch.tgt_txt 274 | return output_data, tgt_txt 275 | 276 | 277 | class Translation(object): 278 | """ 279 | Container for a translated sentence. 280 | 281 | Attributes: 282 | src (`LongTensor`): src word ids 283 | src_raw ([str]): raw src words 284 | 285 | pred_sents ([[str]]): words from the n-best translations 286 | pred_scores ([[float]]): log-probs of n-best translations 287 | attns ([`FloatTensor`]) : attention dist for each translation 288 | gold_sent ([str]): words from gold translation 289 | gold_score ([float]): log-prob of gold translation 290 | 291 | """ 292 | 293 | def __init__(self, fname, src, src_raw, pred_sents, 294 | attn, pred_scores, tgt_sent, gold_score): 295 | self.fname = fname 296 | self.src = src 297 | self.src_raw = src_raw 298 | self.pred_sents = pred_sents 299 | self.attns = attn 300 | self.pred_scores = pred_scores 301 | self.gold_sent = tgt_sent 302 | self.gold_score = gold_score 303 | 304 | def log(self, sent_number): 305 | """ 306 | Log translation. 307 | """ 308 | 309 | output = '\nSENT {}: {}\n'.format(sent_number, self.src_raw) 310 | 311 | best_pred = self.pred_sents[0] 312 | best_score = self.pred_scores[0] 313 | pred_sent = ' '.join(best_pred) 314 | output += 'PRED {}: {}\n'.format(sent_number, pred_sent) 315 | output += "PRED SCORE: {:.4f}\n".format(best_score) 316 | 317 | if self.gold_sent is not None: 318 | tgt_sent = ' '.join(self.gold_sent) 319 | output += 'GOLD {}: {}\n'.format(sent_number, tgt_sent) 320 | output += ("GOLD SCORE: {:.4f}\n".format(self.gold_score)) 321 | if len(self.pred_sents) > 1: 322 | output += '\nBEST HYP:\n' 323 | for score, sent in zip(self.pred_scores, self.pred_sents): 324 | output += "[{:.4f}] {}\n".format(score, sent) 325 | 326 | return output 327 | -------------------------------------------------------------------------------- /src/models/optimizers.py: -------------------------------------------------------------------------------- 1 | """ Optimizers class """ 2 | import torch 3 | import torch.optim as optim 4 | from torch.nn.utils import clip_grad_norm_ 5 | 6 | 7 | def use_gpu(opt): 8 | """ 9 | Creates a boolean if gpu used 10 | """ 11 | return (hasattr(opt, 'gpu_ranks') and len(opt.gpu_ranks) > 0) or \ 12 | (hasattr(opt, 'gpu') and opt.gpu > -1) 13 | 14 | 15 | def build_optim(args, model, checkpoint, generation=False): 16 | """ Build optimizer """ 17 | 18 | if checkpoint is not None: 19 | optim = checkpoint['optims'][0] 20 | saved_optimizer_state_dict = optim.optimizer.state_dict() 21 | optim.optimizer.load_state_dict(saved_optimizer_state_dict) 22 | if args.visible_gpus != '-1': 23 | for state in optim.optimizer.state.values(): 24 | for k, v in state.items(): 25 | if torch.is_tensor(v): 26 | state[k] = v.cuda() 27 | 28 | if (optim.method == 'adam') and (len(optim.optimizer.state) < 1): 29 | raise RuntimeError( 30 | "Error: loaded Adam optimizer from existing model" + 31 | " but optimizer state is empty") 32 | 33 | else: 34 | if generation: 35 | optim = Optimizer( 36 | args.optim, args.lr, args.max_grad_norm, 37 | beta1=args.beta1, beta2=args.beta2, 38 | decay_method='noam', 39 | warmup_steps=args.warmup_steps) 40 | else: 41 | optim = Optimizer( 42 | args.optim, args.lr, args.max_grad_norm, 43 | beta1=args.beta1, beta2=args.beta2, 44 | start_decay_steps=1, 45 | decay_steps=10, 46 | lr_decay=0.9999) 47 | 48 | optim.set_parameters(list(model.named_parameters())) 49 | 50 | return optim 51 | 52 | 53 | def build_optim_bert(args, model, checkpoint): 54 | """ Build optimizer """ 55 | 56 | if checkpoint is not None: 57 | optim = checkpoint['optims'][0] 58 | saved_optimizer_state_dict = optim.optimizer.state_dict() 59 | optim.optimizer.load_state_dict(saved_optimizer_state_dict) 60 | if args.visible_gpus != '-1': 61 | for state in optim.optimizer.state.values(): 62 | for k, v in state.items(): 63 | if torch.is_tensor(v): 64 | state[k] = v.cuda() 65 | 66 | if (optim.method == 'adam') and (len(optim.optimizer.state) < 1): 67 | raise RuntimeError( 68 | "Error: loaded Adam optimizer from existing model" + 69 | " but optimizer state is empty") 70 | 71 | else: 72 | optim = Optimizer( 73 | args.optim, args.lr_bert, args.max_grad_norm, 74 | beta1=args.beta1, beta2=args.beta2, 75 | decay_method='noam', 76 | warmup_steps=args.warmup_steps_bert) 77 | 78 | params = [(n, p) for n, p in list(model.named_parameters()) if n.startswith('encoder.model')] 79 | optim.set_parameters(params) 80 | 81 | return optim 82 | 83 | 84 | def build_optim_other(args, model, checkpoint): 85 | """ Build optimizer """ 86 | 87 | if checkpoint is not None: 88 | optim = checkpoint['optims'][1] 89 | saved_optimizer_state_dict = optim.optimizer.state_dict() 90 | optim.optimizer.load_state_dict(saved_optimizer_state_dict) 91 | if args.visible_gpus != '-1': 92 | for state in optim.optimizer.state.values(): 93 | for k, v in state.items(): 94 | if torch.is_tensor(v): 95 | state[k] = v.cuda() 96 | 97 | if (optim.method == 'adam') and (len(optim.optimizer.state) < 1): 98 | raise RuntimeError( 99 | "Error: loaded Adam optimizer from existing model" + 100 | " but optimizer state is empty") 101 | 102 | else: 103 | optim = Optimizer( 104 | args.optim, args.lr_other, args.max_grad_norm, 105 | beta1=args.beta1, beta2=args.beta2, 106 | decay_method='noam', 107 | warmup_steps=args.warmup_steps_other) 108 | 109 | if args.encoder == 'bert': 110 | params = [(n, p) for n, p in list(model.named_parameters()) 111 | if not n.startswith('encoder.model') and not n.startswith('topic_model')] 112 | else: 113 | params = [(n, p) for n, p in list(model.named_parameters()) 114 | if not n.startswith('topic_model')] 115 | optim.set_parameters(params) 116 | 117 | return optim 118 | 119 | 120 | def build_optim_topic(args, model, checkpoint): 121 | """ Build optimizer """ 122 | 123 | if checkpoint is not None: 124 | optim = checkpoint['optims'][2] 125 | saved_optimizer_state_dict = optim.optimizer.state_dict() 126 | optim.optimizer.load_state_dict(saved_optimizer_state_dict) 127 | if args.visible_gpus != '-1': 128 | for state in optim.optimizer.state.values(): 129 | for k, v in state.items(): 130 | if torch.is_tensor(v): 131 | state[k] = v.cuda() 132 | 133 | if (optim.method == 'adam') and (len(optim.optimizer.state) < 1): 134 | raise RuntimeError( 135 | "Error: loaded Adam optimizer from existing model" + 136 | " but optimizer state is empty") 137 | 138 | else: 139 | optim = Optimizer( 140 | args.optim, args.lr_topic, args.max_grad_norm, 141 | beta1=args.beta1, beta2=args.beta2, 142 | start_decay_steps=1, 143 | decay_steps=20, 144 | lr_decay=0.999) 145 | 146 | params = [(n, p) for n, p in list(model.named_parameters()) if n.startswith('topic_model')] 147 | optim.set_parameters(params) 148 | 149 | return optim 150 | 151 | 152 | class Optimizer(object): 153 | """ 154 | Controller class for optimization. Mostly a thin 155 | wrapper for `optim`, but also useful for implementing 156 | rate scheduling beyond what is currently available. 157 | Also implements necessary methods for training RNNs such 158 | as grad manipulations. 159 | 160 | Args: 161 | method (:obj:`str`): one of [sgd, adagrad, adadelta, adam] 162 | lr (float): learning rate 163 | lr_decay (float, optional): learning rate decay multiplier 164 | start_decay_steps (int, optional): step to start learning rate decay 165 | beta1, beta2 (float, optional): parameters for adam 166 | adagrad_accum (float, optional): initialization parameter for adagrad 167 | decay_method (str, option): custom decay options 168 | warmup_steps (int, option): parameter for `noam` decay 169 | model_size (int, option): parameter for `noam` decay 170 | 171 | We use the default parameters for Adam that are suggested by 172 | the original paper https://arxiv.org/pdf/1412.6980.pdf 173 | These values are also used by other established implementations, 174 | e.g. https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer 175 | https://keras.io/optimizers/ 176 | Recently there are slightly different values used in the paper 177 | "Attention is all you need" 178 | https://arxiv.org/pdf/1706.03762.pdf, particularly the value beta2=0.98 179 | was used there however, beta2=0.999 is still arguably the more 180 | established value, so we use that here as well 181 | """ 182 | 183 | def __init__(self, method, learning_rate, max_grad_norm, 184 | lr_decay=1, start_decay_steps=None, decay_steps=None, 185 | beta1=0.9, beta2=0.999, 186 | adagrad_accum=0.0, 187 | decay_method=None, 188 | warmup_steps=4000, weight_decay=0): 189 | self.last_ppl = None 190 | self.learning_rate = learning_rate 191 | self.original_lr = learning_rate 192 | self.max_grad_norm = max_grad_norm 193 | self.method = method 194 | self.lr_decay = lr_decay 195 | self.start_decay_steps = start_decay_steps 196 | self.decay_steps = decay_steps 197 | self.start_decay = False 198 | self._step = 0 199 | self.betas = [beta1, beta2] 200 | self.adagrad_accum = adagrad_accum 201 | self.decay_method = decay_method 202 | self.warmup_steps = warmup_steps 203 | self.weight_decay = weight_decay 204 | 205 | def set_parameters(self, params): 206 | """ ? """ 207 | self.params = [] 208 | self.sparse_params = [] 209 | for k, p in params: 210 | if p.requires_grad: 211 | if self.method != 'sparseadam' or "embed" not in k: 212 | self.params.append(p) 213 | else: 214 | self.sparse_params.append(p) 215 | if self.method == 'sgd': 216 | self.optimizer = optim.SGD(self.params, lr=self.learning_rate) 217 | elif self.method == 'adagrad': 218 | self.optimizer = optim.Adagrad(self.params, lr=self.learning_rate) 219 | for group in self.optimizer.param_groups: 220 | for p in group['params']: 221 | self.optimizer.state[p]['sum'] = self.optimizer\ 222 | .state[p]['sum'].fill_(self.adagrad_accum) 223 | elif self.method == 'adadelta': 224 | self.optimizer = optim.Adadelta(self.params, lr=self.learning_rate) 225 | elif self.method == 'adam': 226 | self.optimizer = optim.Adam(self.params, lr=self.learning_rate, 227 | betas=self.betas, eps=1e-9) 228 | else: 229 | raise RuntimeError("Invalid optim method: " + self.method) 230 | 231 | def _set_rate(self, learning_rate): 232 | self.learning_rate = learning_rate 233 | if self.method != 'sparseadam': 234 | self.optimizer.param_groups[0]['lr'] = self.learning_rate 235 | else: 236 | for op in self.optimizer.optimizers: 237 | op.param_groups[0]['lr'] = self.learning_rate 238 | 239 | def step(self): 240 | """Update the model parameters based on current gradients. 241 | 242 | Optionally, will employ gradient modification or update learning 243 | rate. 244 | """ 245 | self._step += 1 246 | 247 | # Decay method used in tensor2tensor. 248 | if self.decay_method == "noam": 249 | self._set_rate(self.original_lr * 250 | min(self._step ** (-0.5), 251 | self._step * self.warmup_steps**(-1.5))) 252 | 253 | else: 254 | if ((self.start_decay_steps is not None) and ( 255 | self._step >= self.start_decay_steps)): 256 | self.start_decay = True 257 | if self.start_decay: 258 | if ((self._step - self.start_decay_steps) 259 | % self.decay_steps == 0): 260 | self.learning_rate = self.learning_rate * self.lr_decay 261 | 262 | if self.method != 'sparseadam': 263 | self.optimizer.param_groups[0]['lr'] = self.learning_rate 264 | 265 | if self.max_grad_norm: 266 | clip_grad_norm_(self.params, self.max_grad_norm) 267 | self.optimizer.step() 268 | -------------------------------------------------------------------------------- /src/models/reporter.py: -------------------------------------------------------------------------------- 1 | """ Report manager utility """ 2 | from __future__ import print_function 3 | from datetime import datetime 4 | 5 | import time 6 | import math 7 | import sys 8 | 9 | from distributed import all_gather_list 10 | from others.logging import logger 11 | 12 | 13 | def build_report_manager(opt): 14 | if opt.tensorboard: 15 | from tensorboardX import SummaryWriter 16 | writer = SummaryWriter(opt.tensorboard_log_dir 17 | + datetime.now().strftime("/%b-%d_%H-%M-%S"), 18 | comment="Unmt") 19 | else: 20 | writer = None 21 | 22 | report_mgr = ReportMgr(opt.report_every, start_time=-1, 23 | tensorboard_writer=writer) 24 | return report_mgr 25 | 26 | 27 | class ReportMgrBase(object): 28 | """ 29 | Report Manager Base class 30 | Inherited classes should override: 31 | * `_report_training` 32 | * `_report_step` 33 | """ 34 | 35 | def __init__(self, report_every, start_time=-1.): 36 | """ 37 | Args: 38 | report_every(int): Report status every this many sentences 39 | start_time(float): manually set report start time. Negative values 40 | means that you will need to set it later or use `start()` 41 | """ 42 | self.report_every = report_every 43 | self.progress_step = 0 44 | self.start_time = start_time 45 | 46 | def start(self): 47 | self.start_time = time.time() 48 | 49 | def log(self, *args, **kwargs): 50 | logger.info(*args, **kwargs) 51 | 52 | def report_training(self, step, num_steps, learning_rate, 53 | report_stats, multigpu=False): 54 | """ 55 | This is the user-defined batch-level traing progress 56 | report function. 57 | 58 | Args: 59 | step(int): current step count. 60 | num_steps(int): total number of batches. 61 | learning_rate(float): current learning rate. 62 | report_stats(Statistics): old Statistics instance. 63 | Returns: 64 | report_stats(Statistics): updated Statistics instance. 65 | """ 66 | if self.start_time < 0: 67 | raise ValueError("""ReportMgr needs to be started 68 | (set 'start_time' or use 'start()'""") 69 | 70 | if multigpu: 71 | report_stats = Statistics.all_gather_stats(report_stats) 72 | 73 | if step % self.report_every == 0: 74 | self._report_training( 75 | step, num_steps, learning_rate, report_stats) 76 | self.progress_step += 1 77 | return Statistics() 78 | 79 | def _report_training(self, *args, **kwargs): 80 | """ To be overridden """ 81 | raise NotImplementedError() 82 | 83 | def report_step(self, lr, step, train_stats=None, valid_stats=None): 84 | """ 85 | Report stats of a step 86 | 87 | Args: 88 | train_stats(Statistics): training stats 89 | valid_stats(Statistics): validation stats 90 | lr(float): current learning rate 91 | """ 92 | self._report_step( 93 | lr, step, train_stats=train_stats, valid_stats=valid_stats) 94 | 95 | def _report_step(self, *args, **kwargs): 96 | raise NotImplementedError() 97 | 98 | 99 | class ReportMgr(ReportMgrBase): 100 | def __init__(self, report_every, start_time=-1., tensorboard_writer=None): 101 | """ 102 | A report manager that writes statistics on standard output as well as 103 | (optionally) TensorBoard 104 | 105 | Args: 106 | report_every(int): Report status every this many sentences 107 | tensorboard_writer(:obj:`tensorboard.SummaryWriter`): 108 | The TensorBoard Summary writer to use or None 109 | """ 110 | super(ReportMgr, self).__init__(report_every, start_time) 111 | self.tensorboard_writer = tensorboard_writer 112 | 113 | def maybe_log_tensorboard(self, stats, prefix, learning_rate, step): 114 | if self.tensorboard_writer is not None: 115 | stats.log_tensorboard( 116 | prefix, self.tensorboard_writer, learning_rate, step) 117 | 118 | def _report_training(self, step, num_steps, learning_rate, 119 | report_stats): 120 | """ 121 | See base class method `ReportMgrBase.report_training`. 122 | """ 123 | report_stats.output(step, num_steps, 124 | learning_rate, self.start_time) 125 | 126 | # Log the progress using the number of batches on the x-axis. 127 | self.maybe_log_tensorboard(report_stats, 128 | "progress", 129 | learning_rate, 130 | step) 131 | report_stats = Statistics() 132 | 133 | return report_stats 134 | 135 | def _report_step(self, lr, step, train_stats=None, valid_stats=None): 136 | """ 137 | See base class method `ReportMgrBase.report_step`. 138 | """ 139 | if train_stats is not None: 140 | self.log('Train perplexity: %g' % train_stats.ppl()) 141 | self.log('Train accuracy: %g' % train_stats.accuracy()) 142 | 143 | self.maybe_log_tensorboard(train_stats, 144 | "train", 145 | lr, 146 | step) 147 | 148 | if valid_stats is not None: 149 | self.log('Validation perplexity: %g' % valid_stats.ppl()) 150 | self.log('Validation accuracy: %g' % valid_stats.accuracy()) 151 | 152 | self.maybe_log_tensorboard(valid_stats, 153 | "valid", 154 | lr, 155 | step) 156 | 157 | 158 | class Statistics(object): 159 | """ 160 | Accumulator for loss statistics. 161 | Currently calculates: 162 | 163 | * accuracy 164 | * perplexity 165 | * elapsed time 166 | """ 167 | 168 | def __init__(self, loss=0, n_words=0, n_correct=0, rl_loss=0, topic_loss=0): 169 | self.loss = loss 170 | self.n_words = n_words 171 | self.n_docs = 0 172 | self.n_correct = n_correct 173 | self.n_src_words = 0 174 | self.start_time = time.time() 175 | self.rl_loss = rl_loss 176 | self.topic_loss = topic_loss 177 | 178 | @staticmethod 179 | def all_gather_stats(stat, max_size=4096): 180 | """ 181 | Gather a `Statistics` object accross multiple process/nodes 182 | 183 | Args: 184 | stat(:obj:Statistics): the statistics object to gather 185 | accross all processes/nodes 186 | max_size(int): max buffer size to use 187 | 188 | Returns: 189 | `Statistics`, the update stats object 190 | """ 191 | stats = Statistics.all_gather_stats_list([stat], max_size=max_size) 192 | return stats[0] 193 | 194 | @staticmethod 195 | def all_gather_stats_list(stat_list, max_size=4096): 196 | from torch.distributed import get_rank 197 | 198 | """ 199 | Gather a `Statistics` list accross all processes/nodes 200 | 201 | Args: 202 | stat_list(list([`Statistics`])): list of statistics objects to 203 | gather accross all processes/nodes 204 | max_size(int): max buffer size to use 205 | 206 | Returns: 207 | our_stats(list([`Statistics`])): list of updated stats 208 | """ 209 | # Get a list of world_size lists with len(stat_list) Statistics objects 210 | all_stats = all_gather_list(stat_list, max_size=max_size) 211 | 212 | our_rank = get_rank() 213 | our_stats = all_stats[our_rank] 214 | for other_rank, stats in enumerate(all_stats): 215 | if other_rank == our_rank: 216 | continue 217 | for i, stat in enumerate(stats): 218 | our_stats[i].update(stat, update_n_src_words=True) 219 | return our_stats 220 | 221 | def update(self, stat, update_n_src_words=False): 222 | """ 223 | Update statistics by suming values with another `Statistics` object 224 | 225 | Args: 226 | stat: another statistic object 227 | update_n_src_words(bool): whether to update (sum) `n_src_words` 228 | or not 229 | 230 | """ 231 | self.loss += stat.loss 232 | self.n_words += stat.n_words 233 | self.n_correct += stat.n_correct 234 | self.n_docs += stat.n_docs 235 | self.rl_loss += stat.rl_loss 236 | self.topic_loss += stat.topic_loss 237 | 238 | if update_n_src_words: 239 | self.n_src_words += stat.n_src_words 240 | 241 | def accuracy(self): 242 | """ compute accuracy """ 243 | if self.n_words != 0: 244 | return 100 * (self.n_correct / self.n_words) 245 | else: 246 | return -1 247 | 248 | def xent(self): 249 | """ compute cross entropy """ 250 | if self.n_words != 0: 251 | return self.loss / self.n_words 252 | else: 253 | return -1 254 | 255 | def ppl(self): 256 | """ compute perplexity """ 257 | if self.n_words != 0: 258 | return math.exp(min(self.loss / self.n_words, 100)) 259 | else: 260 | return -1 261 | 262 | def rlloss(self): 263 | return self.rl_loss 264 | 265 | def tploss(self): 266 | return self.topic_loss 267 | 268 | def elapsed_time(self): 269 | """ compute elapsed time """ 270 | return time.time() - self.start_time 271 | 272 | def output(self, step, num_steps, learning_rate, start): 273 | """Write out statistics to stdout. 274 | 275 | Args: 276 | step (int): current step 277 | n_batch (int): total batches 278 | start (int): start time of step. 279 | """ 280 | t = self.elapsed_time() 281 | logger.info( 282 | ("Step %2d/%5d; acc: %6.2f; ppl: %5.2f; xent: %4.2f; rlloss: %4.2f; " + 283 | "tploss: %4.2f; lr: %7.8f; %3.0f/%3.0f tok/s; %6.0f sec") 284 | % (step, num_steps, 285 | self.accuracy(), 286 | self.ppl(), 287 | self.xent(), 288 | self.rlloss(), 289 | self.tploss(), 290 | learning_rate, 291 | self.n_src_words / (t + 1e-5), 292 | self.n_words / (t + 1e-5), 293 | time.time() - start)) 294 | sys.stdout.flush() 295 | 296 | def log_tensorboard(self, prefix, writer, learning_rate, step): 297 | """ display statistics to tensorboard """ 298 | t = self.elapsed_time() 299 | writer.add_scalar(prefix + "/xent", self.xent(), step) 300 | writer.add_scalar(prefix + "/tploss", self.tploss(), step) 301 | writer.add_scalar(prefix + "/rlloss", self.rlloss(), step) 302 | writer.add_scalar(prefix + "/ppl", self.ppl(), step) 303 | writer.add_scalar(prefix + "/accuracy", self.accuracy(), step) 304 | writer.add_scalar(prefix + "/tgtper", self.n_words / t, step) 305 | writer.add_scalar(prefix + "/lr", learning_rate, step) 306 | -------------------------------------------------------------------------------- /src/models/rl_model_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from tensorboardX import SummaryWriter 5 | 6 | import distributed 7 | from models.reporter import ReportMgr, Statistics 8 | from models.loss import abs_loss, CrossEntropyLossCompute 9 | from others.logging import logger 10 | 11 | 12 | def _tally_parameters(model): 13 | n_params = sum([p.nelement() for p in model.parameters()]) 14 | return n_params 15 | 16 | 17 | def build_trainer(args, device_id, model, optims, tokenizer): 18 | """ 19 | Simplify `Trainer` creation based on user `opt`s* 20 | Args: 21 | opt (:obj:`Namespace`): user options (usually from argument parsing) 22 | model (:obj:`onmt.models.NMTModel`): the model to train 23 | fields (dict): dict of fields 24 | optim (:obj:`onmt.utils.Optimizer`): optimizer used during training 25 | data_type (str): string describing the type of data 26 | e.g. "text", "img", "audio" 27 | model_saver(:obj:`onmt.models.ModelSaverBase`): the utility object 28 | used to save the model 29 | """ 30 | device = "cpu" if args.visible_gpus == '-1' else "cuda" 31 | 32 | grad_accum_count = args.accum_count 33 | n_gpu = args.world_size 34 | 35 | if device_id >= 0: 36 | gpu_rank = int(args.gpu_ranks[device_id]) 37 | else: 38 | gpu_rank = 0 39 | n_gpu = 0 40 | 41 | print('gpu_rank %d' % gpu_rank) 42 | 43 | tensorboard_log_dir = args.model_path 44 | 45 | writer = SummaryWriter(tensorboard_log_dir, comment="Unmt") 46 | 47 | report_manager = ReportMgr(args.report_every, start_time=-1, tensorboard_writer=writer) 48 | 49 | symbols = {'BOS': tokenizer.vocab['[unused1]'], 'EOS': tokenizer.vocab['[unused2]'], 50 | 'PAD': tokenizer.vocab['[PAD]'], 'SEG': tokenizer.vocab['[unused3]'], 51 | 'UNK': tokenizer.vocab['[UNK]']} 52 | 53 | gen_loss = abs_loss(args, model.generator, symbols, tokenizer.vocab, device, train=True) 54 | 55 | pn_loss = CrossEntropyLossCompute().to(device) 56 | 57 | trainer = Trainer(args, model, optims, tokenizer, gen_loss, pn_loss, 58 | grad_accum_count, n_gpu, gpu_rank, report_manager) 59 | 60 | # print(tr) 61 | if (model): 62 | n_params = _tally_parameters(model) 63 | logger.info('* number of parameters: %d' % n_params) 64 | 65 | return trainer 66 | 67 | 68 | class Trainer(object): 69 | """ 70 | Class that controls the training process. 71 | 72 | Args: 73 | model(:py:class:`onmt.models.model.NMTModel`): translation model 74 | to train 75 | train_loss(:obj:`onmt.utils.loss.LossComputeBase`): 76 | training loss computation 77 | valid_loss(:obj:`onmt.utils.loss.LossComputeBase`): 78 | training loss computation 79 | optim(:obj:`onmt.utils.optimizers.Optimizer`): 80 | the optimizer responsible for update 81 | trunc_size(int): length of truncated back propagation through time 82 | shard_size(int): compute loss in shards of this size for efficiency 83 | data_type(string): type of the source input: [text|img|audio] 84 | norm_method(string): normalization methods: [sents|tokens] 85 | grad_accum_count(int): accumulate gradients this many times. 86 | report_manager(:obj:`onmt.utils.ReportMgrBase`): 87 | the object that creates reports, or None 88 | model_saver(:obj:`onmt.models.ModelSaverBase`): the saver is 89 | used to save a checkpoint. 90 | Thus nothing will be saved if this parameter is None 91 | """ 92 | 93 | def __init__(self, args, model, optims, tokenizer, abs_loss, pn_loss, 94 | grad_accum_count=1, n_gpu=1, gpu_rank=1, 95 | report_manager=None): 96 | # Basic attributes. 97 | self.args = args 98 | self.save_checkpoint_steps = args.save_checkpoint_steps 99 | self.model = model 100 | self.optims = optims 101 | self.tokenizer = tokenizer 102 | self.grad_accum_count = grad_accum_count 103 | self.n_gpu = n_gpu 104 | self.gpu_rank = gpu_rank 105 | self.report_manager = report_manager 106 | 107 | self.abs_loss = abs_loss 108 | self.pn_loss = pn_loss 109 | 110 | assert grad_accum_count > 0 111 | # Set model in training mode. 112 | if (model): 113 | self.model.train() 114 | 115 | def train(self, train_iter_fct, train_steps, valid_iter_fct=None, valid_steps=-1): 116 | """ 117 | The main training loops. 118 | by iterating over training data (i.e. `train_iter_fct`) 119 | and running validation (i.e. iterating over `valid_iter_fct` 120 | 121 | Args: 122 | train_iter_fct(function): a function that returns the train 123 | iterator. e.g. something like 124 | train_iter_fct = lambda: generator(*args, **kwargs) 125 | valid_iter_fct(function): same as train_iter_fct, for valid data 126 | train_steps(int): 127 | valid_steps(int): 128 | save_checkpoint_steps(int): 129 | 130 | Return: 131 | None 132 | """ 133 | logger.info('Start training...') 134 | 135 | step = self.optims[0]._step + 1 136 | true_batchs = [] 137 | accum = 0 138 | tgt_tokens = 0 139 | src_tokens = 0 140 | tgt_labels = 0 141 | sents = 0 142 | examples = 0 143 | 144 | train_iter = train_iter_fct() 145 | total_stats = Statistics() 146 | report_stats = Statistics() 147 | self._start_report_manager(start_time=total_stats.start_time) 148 | 149 | while step <= train_steps: 150 | 151 | for i, batch in enumerate(train_iter): 152 | if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank): 153 | 154 | true_batchs.append(batch) 155 | tgt_tokens += batch.tgt[:, 1:].ne(self.abs_loss.padding_idx).sum().item() 156 | src_tokens += batch.src[:, 1:].ne(self.abs_loss.padding_idx).sum().item() 157 | tgt_labels += sum([len(l)+1 for l in batch.tgt_labels]) 158 | sents += batch.src.size(0) 159 | examples += batch.tgt.size(0) 160 | accum += 1 161 | if accum == self.grad_accum_count: 162 | if self.n_gpu > 1: 163 | tgt_tokens = sum(distributed.all_gather_list(tgt_tokens)) 164 | src_tokens = sum(distributed.all_gather_list(src_tokens)) 165 | tgt_labels = sum(distributed.all_gather_list(tgt_labels)) 166 | sents = sum(distributed.all_gather_list(sents)) 167 | examples = sum(distributed.all_gather_list(examples)) 168 | 169 | normalization = (tgt_tokens, src_tokens, tgt_labels, sents, examples) 170 | self._gradient_calculation( 171 | true_batchs, normalization, total_stats, 172 | report_stats, step) 173 | 174 | report_stats = self._maybe_report_training( 175 | step, train_steps, 176 | self.optims[0].learning_rate, 177 | report_stats) 178 | 179 | true_batchs = [] 180 | accum = 0 181 | src_tokens = 0 182 | tgt_tokens = 0 183 | tgt_labels = 0 184 | sents = 0 185 | examples = 0 186 | if (step % self.save_checkpoint_steps == 0 and self.gpu_rank == 0): 187 | self._save(step) 188 | step += 1 189 | if step > train_steps: 190 | break 191 | train_iter = train_iter_fct() 192 | 193 | return total_stats 194 | 195 | def _gradient_calculation(self, true_batchs, normalization, total_stats, 196 | report_stats, step): 197 | self.model.zero_grad() 198 | 199 | for batch in true_batchs: 200 | if self.args.pretrain: 201 | pn_output, decode_output, topic_loss, _ = self.model.pretrain(batch) 202 | else: 203 | rl_loss, decode_output, topic_loss, _, _ = self.model(batch) 204 | 205 | tgt_tokens, src_tokens, tgt_labels, sents, examples = normalization 206 | 207 | if self.args.pretrain: 208 | if self.args.topic_model: 209 | # Topic Model loss 210 | topic_stats = Statistics(topic_loss=topic_loss.clone().item() / float(examples)) 211 | topic_loss.div(float(examples)).backward(retain_graph=True) 212 | total_stats.update(topic_stats) 213 | report_stats.update(topic_stats) 214 | 215 | # Extractiton Loss 216 | pn_stats = self.pn_loss(batch.pn_tgt, pn_output, self.args.generator_shard_size, 217 | tgt_labels, retain_graph=True) 218 | total_stats.update(pn_stats) 219 | report_stats.update(pn_stats) 220 | 221 | # Generation loss 222 | abs_stats = self.abs_loss(batch, decode_output, self.args.generator_shard_size, 223 | tgt_tokens, retain_graph=False) 224 | abs_stats.n_docs = len(batch) 225 | total_stats.update(abs_stats) 226 | report_stats.update(abs_stats) 227 | 228 | else: 229 | if self.args.topic_model: 230 | # Topic Model loss 231 | topic_stats = Statistics(topic_loss=topic_loss.clone().item() / float(examples)) 232 | topic_loss.div(float(examples)).backward(retain_graph=True) 233 | total_stats.update(topic_stats) 234 | report_stats.update(topic_stats) 235 | 236 | # RL loss 237 | rl_stats = Statistics(rl_loss=rl_loss.clone().item() / float(examples)) 238 | # critic_stats = Statistics(ct_loss=critic_loss.clone().item() / float(examples)) 239 | rl_loss.div(float(examples)).backward(retain_graph=True) 240 | total_stats.update(rl_stats) 241 | # total_stats.update(critic_stats) 242 | report_stats.update(rl_stats) 243 | # report_stats.update(critic_stats) 244 | 245 | # Generation loss 246 | abs_stats = self.abs_loss(batch, decode_output, self.args.generator_shard_size, 247 | tgt_tokens, retain_graph=False) 248 | abs_stats.n_docs = len(batch) 249 | total_stats.update(abs_stats) 250 | report_stats.update(abs_stats) 251 | 252 | # in case of multi step gradient accumulation, 253 | # update only after accum batches 254 | if self.n_gpu > 1: 255 | grads = [p.grad.data for p in self.model.parameters() 256 | if p.requires_grad 257 | and p.grad is not None] 258 | distributed.all_reduce_and_rescale_tensors( 259 | grads, float(1)) 260 | for o in self.optims: 261 | o.step() 262 | 263 | def _save(self, step): 264 | real_model = self.model 265 | 266 | model_state_dict = real_model.state_dict() 267 | # generator_state_dict = real_generator.state_dict() 268 | checkpoint = { 269 | 'model': model_state_dict, 270 | # 'generator': generator_state_dict, 271 | 'opt': self.args, 272 | 'optims': self.optims, 273 | } 274 | checkpoint_path = os.path.join(self.args.model_path, 'model_step_%d.pt' % step) 275 | logger.info("Saving checkpoint %s" % checkpoint_path) 276 | # checkpoint_path = '%s_step_%d.pt' % (FLAGS.model_path, step) 277 | torch.save(checkpoint, checkpoint_path) 278 | return checkpoint, checkpoint_path 279 | 280 | def _start_report_manager(self, start_time=None): 281 | """ 282 | Simple function to start report manager (if any) 283 | """ 284 | if self.report_manager is not None: 285 | if start_time is None: 286 | self.report_manager.start() 287 | else: 288 | self.report_manager.start_time = start_time 289 | 290 | def _maybe_gather_stats(self, stat): 291 | """ 292 | Gather statistics in multi-processes cases 293 | 294 | Args: 295 | stat(:obj:onmt.utils.Statistics): a Statistics object to gather 296 | or None (it returns None in this case) 297 | 298 | Returns: 299 | stat: the updated (or unchanged) stat object 300 | """ 301 | if stat is not None and self.n_gpu > 1: 302 | return Statistics.all_gather_stats(stat) 303 | return stat 304 | 305 | def _maybe_report_training(self, step, num_steps, learning_rate, 306 | report_stats): 307 | """ 308 | Simple function to report training stats (if report_manager is set) 309 | see `onmt.utils.ReportManagerBase.report_training` for doc 310 | """ 311 | if self.report_manager is not None: 312 | return self.report_manager.report_training( 313 | step, num_steps, learning_rate, report_stats, 314 | multigpu=self.n_gpu > 1) 315 | 316 | def _report_step(self, learning_rate, step, train_stats=None, 317 | valid_stats=None): 318 | """ 319 | Simple function to report stats (if report_manager is set) 320 | see `onmt.utils.ReportManagerBase.report_step` for doc 321 | """ 322 | if self.report_manager is not None: 323 | return self.report_manager.report_step( 324 | learning_rate, step, train_stats=train_stats, 325 | valid_stats=valid_stats) 326 | 327 | def _maybe_save(self, step): 328 | """ 329 | Save the model if a model saver is set 330 | """ 331 | if self.model_saver is not None: 332 | self.model_saver.maybe_save(step) 333 | -------------------------------------------------------------------------------- /src/models/seq2seq_predictor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*-coding:utf8-*- 3 | """ Translator Class and builder """ 4 | from __future__ import print_function 5 | import codecs 6 | import torch 7 | 8 | from tensorboardX import SummaryWriter 9 | from others.utils import rouge_results_to_str, test_bleu, test_length 10 | from translate.beam import GNMTGlobalScorer 11 | from rouge import Rouge, FilesRouge 12 | from nltk.translate.bleu_score import sentence_bleu 13 | from nltk.translate.bleu_score import SmoothingFunction 14 | 15 | 16 | def build_predictor(args, tokenizer, model, logger=None): 17 | scorer = GNMTGlobalScorer(args.alpha, length_penalty='wu') 18 | 19 | translator = Translator(args, model, tokenizer, 20 | global_scorer=scorer, logger=logger) 21 | return translator 22 | 23 | 24 | class Translator(object): 25 | """ 26 | Uses a model to translate a batch of sentences. 27 | 28 | 29 | Args: 30 | model (:obj:`onmt.modules.NMTModel`): 31 | NMT model to use for translation 32 | fields (dict of Fields): data fields 33 | beam_size (int): size of beam to use 34 | n_best (int): number of translations produced 35 | max_length (int): maximum length output to produce 36 | global_scores (:obj:`GlobalScorer`): 37 | object to rescore final translations 38 | copy_attn (bool): use copy attention during translation 39 | cuda (bool): use cuda 40 | beam_trace (bool): trace beam search for debugging 41 | logger(logging.Logger): logger. 42 | """ 43 | 44 | def __init__(self, 45 | args, 46 | model, 47 | tokenizer, 48 | global_scorer=None, 49 | logger=None, 50 | dump_beam=""): 51 | self.logger = logger 52 | self.cuda = args.visible_gpus != '-1' 53 | 54 | self.args = args 55 | self.model = model 56 | self.generator = self.model.generator 57 | self.tokenizer = tokenizer 58 | self.vocab = tokenizer.vocab 59 | self.start_token = self.vocab['[unused1]'] 60 | self.end_token = self.vocab['[unused2]'] 61 | self.seg_token = self.vocab['[unused3]'] 62 | 63 | self.global_scorer = global_scorer 64 | self.beam_size = args.beam_size 65 | self.min_length = args.min_length 66 | self.max_length = args.max_length 67 | 68 | self.dump_beam = dump_beam 69 | 70 | # for debugging 71 | self.beam_trace = self.dump_beam != "" 72 | self.beam_accum = None 73 | 74 | tensorboard_log_dir = args.model_path 75 | 76 | self.tensorboard_writer = SummaryWriter( 77 | tensorboard_log_dir, comment="Unmt") 78 | 79 | if self.beam_trace: 80 | self.beam_accum = { 81 | "predicted_ids": [], 82 | "beam_parent_ids": [], 83 | "scores": [], 84 | "log_probs": []} 85 | 86 | def _build_target_tokens(self, pred): 87 | # vocab = self.fields["tgt"].vocab 88 | tokens = [] 89 | for tok in pred: 90 | tok = int(tok) 91 | tokens.append(tok) 92 | if tokens[-1] == self.end_token: 93 | tokens = tokens[:-1] 94 | break 95 | tokens = [t for t in tokens if t < len(self.vocab)] 96 | tokens = self.vocab.DecodeIds(tokens).split(' ') 97 | return tokens 98 | 99 | def from_batch_dev(self, doc_batch, tgt_data): 100 | 101 | translations = [] 102 | 103 | batch_size = len(doc_batch) 104 | 105 | for b in range(batch_size): 106 | 107 | # generated text 108 | pred_summ = self.tokenizer.convert_ids_to_tokens( 109 | [int(n) for n in doc_batch[b]]) 110 | pred_summ = ' '.join(pred_summ) 111 | pred_summ = pred_summ.replace('[unused0]', '').replace('[unused1]', '').\ 112 | replace('[unused2]', '').replace('[unused5]', '#').replace('[UNK]', '').strip() 113 | pred_summ = ' '.join(pred_summ.split()) 114 | 115 | gold_data = ' '.join(tgt_data[b]) 116 | gold_data = gold_data.replace('[PAD]', '').replace('[unused1]', '').\ 117 | replace('[unused2]', '').replace('[unused5]', '#').replace("[UNK]", '').strip() 118 | gold_data = ' '.join(gold_data.split()) 119 | 120 | translations.append((pred_summ, gold_data)) 121 | 122 | return translations 123 | 124 | def from_batch_test(self, batch, output_batch, tgt_data): 125 | 126 | translations = [] 127 | 128 | batch_size = len(batch) 129 | 130 | origin_txt, ex_segs = batch.original_str, batch.ex_segs 131 | 132 | ex_segs = [sum(ex_segs[:i]) for i in range(len(ex_segs)+1)] 133 | 134 | for b in range(batch_size): 135 | # original text 136 | original_sent = ' '.join(origin_txt[ex_segs[b]:ex_segs[b+1]]) 137 | 138 | # long doc context text 139 | pred_summ = self.tokenizer.convert_ids_to_tokens( 140 | [int(n) for n in output_batch[b]]) 141 | pred_summ = ' '.join(pred_summ) 142 | 143 | pred_summ = pred_summ.replace('[unused0]', '').replace('[unused1]', '').\ 144 | replace('[unused2]', '').replace('[unused5]', '#').replace('[UNK]', '').strip() 145 | pred_summ = ' '.join(pred_summ.split()) 146 | 147 | gold_data = ' '.join(tgt_data[b]) 148 | gold_data = gold_data.replace('[PAD]', '').replace('[unused1]', '').replace('[unused2]', '').\ 149 | replace('[unused5]', '#').replace('[UNK]', '').strip() 150 | gold_data = ' '.join(gold_data.split()) 151 | 152 | translation = (original_sent, pred_summ, gold_data) 153 | translations.append(translation) 154 | 155 | return translations 156 | 157 | def validate(self, data_iter, step, attn_debug=False): 158 | 159 | self.model.eval() 160 | gold_path = self.args.result_path + 'step.%d.gold_temp' % step 161 | pred_path = self.args.result_path + 'step.%d.pred_temp' % step 162 | gold_out_file = codecs.open(gold_path, 'w', 'utf-8') 163 | pred_out_file = codecs.open(pred_path, 'w', 'utf-8') 164 | 165 | # pred_results, gold_results = [], [] 166 | ct = 0 167 | with torch.no_grad(): 168 | for batch in data_iter: 169 | output_data, tgt_data = self.translate_batch(batch) 170 | translations = self.from_batch_dev(output_data, tgt_data) 171 | 172 | for idx in range(len(translations)): 173 | if ct % 100 == 0: 174 | print("Processing %d" % ct) 175 | pred_summ, gold_data = translations[idx] 176 | pred_out_file.write(pred_summ + '\n') 177 | gold_out_file.write(gold_data + '\n') 178 | ct += 1 179 | pred_out_file.flush() 180 | gold_out_file.flush() 181 | 182 | pred_out_file.close() 183 | gold_out_file.close() 184 | 185 | if (step != -1): 186 | pred_bleu = test_bleu(pred_path, gold_path) 187 | file_rouge = FilesRouge(hyp_path=pred_path, ref_path=gold_path) 188 | pred_rouges = file_rouge.get_scores(avg=True) 189 | self.logger.info('Gold Length at step %d: %.2f' % 190 | (step, test_length(gold_path, gold_path, ratio=False))) 191 | self.logger.info('Prediction Length ratio at step %d: %.2f' % 192 | (step, test_length(pred_path, gold_path))) 193 | self.logger.info('Prediction Bleu at step %d: %.2f' % 194 | (step, pred_bleu*100)) 195 | self.logger.info('Prediction Rouges at step %d: \n%s\n' % 196 | (step, rouge_results_to_str(pred_rouges))) 197 | rouge_results = (pred_rouges["rouge-1"]['f'], 198 | pred_rouges["rouge-l"]['f']) 199 | return rouge_results 200 | 201 | def translate(self, 202 | data_iter, step, 203 | attn_debug=False): 204 | 205 | self.model.eval() 206 | output_path = self.args.result_path + '.%d.output' % step 207 | output_file = codecs.open(output_path, 'w', 'utf-8') 208 | gold_path = self.args.result_path + '.%d.gold_test' % step 209 | pred_path = self.args.result_path + '.%d.pred_test' % step 210 | gold_out_file = codecs.open(gold_path, 'w', 'utf-8') 211 | pred_out_file = codecs.open(pred_path, 'w', 'utf-8') 212 | # pred_results, gold_results = [], [] 213 | ct = 0 214 | 215 | with torch.no_grad(): 216 | rouge = Rouge() 217 | for batch in data_iter: 218 | output_data, tgt_data = self.translate_batch(batch) 219 | translations = self.from_batch_test(batch, output_data, tgt_data) 220 | 221 | for idx in range(len(translations)): 222 | origin_sent, pred_summ, gold_data = translations[idx] 223 | if ct % 100 == 0: 224 | print("Processing %d" % ct) 225 | output_file.write("ID : %d\n" % ct) 226 | output_file.write("ORIGIN : \n " + origin_sent.replace('', '\n ') + "\n") 227 | output_file.write("GOLD : " + gold_data.strip() + "\n") 228 | output_file.write("DOC_GEN : " + pred_summ.strip() + "\n") 229 | rouge_score = rouge.get_scores(pred_summ, gold_data) 230 | bleu_score = sentence_bleu([gold_data.split()], pred_summ.split(), 231 | smoothing_function=SmoothingFunction().method1) 232 | output_file.write("DOC_GEN bleu & rouge-f 1/l: %.4f & %.4f/%.4f\n\n" % 233 | (bleu_score, rouge_score[0]["rouge-1"]["f"], rouge_score[0]["rouge-l"]["f"])) 234 | pred_out_file.write(pred_summ.strip() + '\n') 235 | gold_out_file.write(gold_data.strip() + '\n') 236 | ct += 1 237 | pred_out_file.flush() 238 | gold_out_file.flush() 239 | output_file.flush() 240 | 241 | pred_out_file.close() 242 | gold_out_file.close() 243 | output_file.close() 244 | 245 | if (step != -1): 246 | pred_bleu = test_bleu(pred_path, gold_path) 247 | file_rouge = FilesRouge(hyp_path=pred_path, ref_path=gold_path) 248 | pred_rouges = file_rouge.get_scores(avg=True) 249 | self.logger.info('Gold Length at step %d: %.2f\n' % 250 | (step, test_length(gold_path, gold_path, ratio=False))) 251 | self.logger.info('Prediction Length ratio at step %d: %.2f' % 252 | (step, test_length(pred_path, gold_path))) 253 | self.logger.info('Prediction Bleu at step %d: %.2f' % 254 | (step, pred_bleu*100)) 255 | self.logger.info('Prediction Rouges at step %d: \n%s' % 256 | (step, rouge_results_to_str(pred_rouges))) 257 | 258 | def translate_batch(self, batch): 259 | """ 260 | Translate a batch of sentences. 261 | 262 | Mostly a wrapper around :obj:`Beam`. 263 | 264 | Args: 265 | batch (:obj:`Batch`): a batch from a dataset object 266 | data (:obj:`Dataset`): the dataset object 267 | fast (bool): enables fast beam search (may not support all features) 268 | 269 | Todo: 270 | Shouldn't need the original dataset. 271 | """ 272 | _, output_data, _ = self.model(batch) 273 | tgt_txt = batch.tgt_txt 274 | return output_data, tgt_txt 275 | 276 | 277 | class Translation(object): 278 | """ 279 | Container for a translated sentence. 280 | 281 | Attributes: 282 | src (`LongTensor`): src word ids 283 | src_raw ([str]): raw src words 284 | 285 | pred_sents ([[str]]): words from the n-best translations 286 | pred_scores ([[float]]): log-probs of n-best translations 287 | attns ([`FloatTensor`]) : attention dist for each translation 288 | gold_sent ([str]): words from gold translation 289 | gold_score ([float]): log-prob of gold translation 290 | 291 | """ 292 | 293 | def __init__(self, fname, src, src_raw, pred_sents, 294 | attn, pred_scores, tgt_sent, gold_score): 295 | self.fname = fname 296 | self.src = src 297 | self.src_raw = src_raw 298 | self.pred_sents = pred_sents 299 | self.attns = attn 300 | self.pred_scores = pred_scores 301 | self.gold_sent = tgt_sent 302 | self.gold_score = gold_score 303 | 304 | def log(self, sent_number): 305 | """ 306 | Log translation. 307 | """ 308 | 309 | output = '\nSENT {}: {}\n'.format(sent_number, self.src_raw) 310 | 311 | best_pred = self.pred_sents[0] 312 | best_score = self.pred_scores[0] 313 | pred_sent = ' '.join(best_pred) 314 | output += 'PRED {}: {}\n'.format(sent_number, pred_sent) 315 | output += "PRED SCORE: {:.4f}\n".format(best_score) 316 | 317 | if self.gold_sent is not None: 318 | tgt_sent = ' '.join(self.gold_sent) 319 | output += 'GOLD {}: {}\n'.format(sent_number, tgt_sent) 320 | output += ("GOLD SCORE: {:.4f}\n".format(self.gold_score)) 321 | if len(self.pred_sents) > 1: 322 | output += '\nBEST HYP:\n' 323 | for score, sent in zip(self.pred_scores, self.pred_sents): 324 | output += "[{:.4f}] {}\n".format(score, sent) 325 | 326 | return output 327 | -------------------------------------------------------------------------------- /src/models/seq2seq_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from tensorboardX import SummaryWriter 5 | 6 | import distributed 7 | from models.reporter import ReportMgr, Statistics 8 | from models.loss import abs_loss 9 | from others.logging import logger 10 | 11 | 12 | def _tally_parameters(model): 13 | n_params = sum([p.nelement() for p in model.parameters()]) 14 | return n_params 15 | 16 | 17 | def build_trainer(args, device_id, model, optims, tokenizer): 18 | """ 19 | Simplify `Trainer` creation based on user `opt`s* 20 | Args: 21 | opt (:obj:`Namespace`): user options (usually from argument parsing) 22 | model (:obj:`onmt.models.NMTModel`): the model to train 23 | fields (dict): dict of fields 24 | optim (:obj:`onmt.utils.Optimizer`): optimizer used during training 25 | data_type (str): string describing the type of data 26 | e.g. "text", "img", "audio" 27 | model_saver(:obj:`onmt.models.ModelSaverBase`): the utility object 28 | used to save the model 29 | """ 30 | device = "cpu" if args.visible_gpus == '-1' else "cuda" 31 | 32 | grad_accum_count = args.accum_count 33 | n_gpu = args.world_size 34 | 35 | if device_id >= 0: 36 | gpu_rank = int(args.gpu_ranks[device_id]) 37 | else: 38 | gpu_rank = 0 39 | n_gpu = 0 40 | 41 | print('gpu_rank %d' % gpu_rank) 42 | 43 | tensorboard_log_dir = args.model_path 44 | 45 | writer = SummaryWriter(tensorboard_log_dir, comment="Unmt") 46 | 47 | report_manager = ReportMgr(args.report_every, start_time=-1, tensorboard_writer=writer) 48 | 49 | symbols = {'BOS': tokenizer.vocab['[unused1]'], 'EOS': tokenizer.vocab['[unused2]'], 50 | 'PAD': tokenizer.vocab['[PAD]'], 'SEG': tokenizer.vocab['[unused3]'], 51 | 'UNK': tokenizer.vocab['[UNK]']} 52 | 53 | gen_loss = abs_loss(args, model.generator, symbols, tokenizer.vocab, device, train=True) 54 | 55 | trainer = Trainer(args, model, optims, tokenizer, gen_loss, 56 | grad_accum_count, n_gpu, gpu_rank, report_manager) 57 | 58 | # print(tr) 59 | if (model): 60 | n_params = _tally_parameters(model) 61 | logger.info('* number of parameters: %d' % n_params) 62 | 63 | return trainer 64 | 65 | 66 | class Trainer(object): 67 | """ 68 | Class that controls the training process. 69 | 70 | Args: 71 | model(:py:class:`onmt.models.model.NMTModel`): translation model 72 | to train 73 | train_loss(:obj:`onmt.utils.loss.LossComputeBase`): 74 | training loss computation 75 | valid_loss(:obj:`onmt.utils.loss.LossComputeBase`): 76 | training loss computation 77 | optim(:obj:`onmt.utils.optimizers.Optimizer`): 78 | the optimizer responsible for update 79 | trunc_size(int): length of truncated back propagation through time 80 | shard_size(int): compute loss in shards of this size for efficiency 81 | data_type(string): type of the source input: [text|img|audio] 82 | norm_method(string): normalization methods: [sents|tokens] 83 | grad_accum_count(int): accumulate gradients this many times. 84 | report_manager(:obj:`onmt.utils.ReportMgrBase`): 85 | the object that creates reports, or None 86 | model_saver(:obj:`onmt.models.ModelSaverBase`): the saver is 87 | used to save a checkpoint. 88 | Thus nothing will be saved if this parameter is None 89 | """ 90 | 91 | def __init__(self, args, model, optims, tokenizer, abs_loss, 92 | grad_accum_count=1, n_gpu=1, gpu_rank=1, 93 | report_manager=None): 94 | # Basic attributes. 95 | self.args = args 96 | self.save_checkpoint_steps = args.save_checkpoint_steps 97 | self.model = model 98 | self.optims = optims 99 | self.tokenizer = tokenizer 100 | self.grad_accum_count = grad_accum_count 101 | self.n_gpu = n_gpu 102 | self.gpu_rank = gpu_rank 103 | self.report_manager = report_manager 104 | 105 | self.abs_loss = abs_loss 106 | 107 | assert grad_accum_count > 0 108 | # Set model in training mode. 109 | if (model): 110 | self.model.train() 111 | 112 | def train(self, train_iter_fct, train_steps, valid_iter_fct=None, valid_steps=-1): 113 | """ 114 | The main training loops. 115 | by iterating over training data (i.e. `train_iter_fct`) 116 | and running validation (i.e. iterating over `valid_iter_fct` 117 | 118 | Args: 119 | train_iter_fct(function): a function that returns the train 120 | iterator. e.g. something like 121 | train_iter_fct = lambda: generator(*args, **kwargs) 122 | valid_iter_fct(function): same as train_iter_fct, for valid data 123 | train_steps(int): 124 | valid_steps(int): 125 | save_checkpoint_steps(int): 126 | 127 | Return: 128 | None 129 | """ 130 | logger.info('Start training...') 131 | 132 | step = self.optims[0]._step + 1 133 | true_batchs = [] 134 | accum = 0 135 | tgt_tokens = 0 136 | src_tokens = 0 137 | tgt_labels = 0 138 | sents = 0 139 | examples = 0 140 | 141 | train_iter = train_iter_fct() 142 | total_stats = Statistics() 143 | report_stats = Statistics() 144 | self._start_report_manager(start_time=total_stats.start_time) 145 | 146 | while step <= train_steps: 147 | 148 | for i, batch in enumerate(train_iter): 149 | if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank): 150 | 151 | true_batchs.append(batch) 152 | tgt_tokens += batch.tgt[:, 1:].ne(self.abs_loss.padding_idx).sum().item() 153 | src_tokens += batch.src[:, 1:].ne(self.abs_loss.padding_idx).sum().item() 154 | tgt_labels += sum([len(l)+1 for l in batch.tgt_labels]) 155 | sents += batch.src.size(0) 156 | examples += batch.tgt.size(0) 157 | accum += 1 158 | if accum == self.grad_accum_count: 159 | if self.n_gpu > 1: 160 | tgt_tokens = sum(distributed.all_gather_list(tgt_tokens)) 161 | src_tokens = sum(distributed.all_gather_list(src_tokens)) 162 | tgt_labels = sum(distributed.all_gather_list(tgt_labels)) 163 | sents = sum(distributed.all_gather_list(sents)) 164 | examples = sum(distributed.all_gather_list(examples)) 165 | 166 | normalization = (tgt_tokens, src_tokens, tgt_labels, sents, examples) 167 | self._gradient_calculation( 168 | true_batchs, normalization, total_stats, 169 | report_stats, step) 170 | 171 | report_stats = self._maybe_report_training( 172 | step, train_steps, 173 | self.optims[0].learning_rate, 174 | report_stats) 175 | 176 | true_batchs = [] 177 | accum = 0 178 | src_tokens = 0 179 | tgt_tokens = 0 180 | tgt_labels = 0 181 | sents = 0 182 | examples = 0 183 | if (step % self.save_checkpoint_steps == 0 and self.gpu_rank == 0): 184 | self._save(step) 185 | step += 1 186 | if step > train_steps: 187 | break 188 | train_iter = train_iter_fct() 189 | 190 | return total_stats 191 | 192 | def _gradient_calculation(self, true_batchs, normalization, total_stats, 193 | report_stats, step): 194 | self.model.zero_grad() 195 | 196 | for batch in true_batchs: 197 | decode_output, _, attn = self.model(batch) 198 | 199 | tgt_tokens, src_tokens, tgt_labels, sents, examples = normalization 200 | 201 | # Generation loss 202 | abs_stats = self.abs_loss(batch, decode_output, self.args.generator_shard_size, tgt_tokens, attns=attn) 203 | abs_stats.n_docs = len(batch) 204 | total_stats.update(abs_stats) 205 | report_stats.update(abs_stats) 206 | 207 | # in case of multi step gradient accumulation, 208 | # update only after accum batches 209 | if self.n_gpu > 1: 210 | grads = [p.grad.data for p in self.model.parameters() 211 | if p.requires_grad 212 | and p.grad is not None] 213 | distributed.all_reduce_and_rescale_tensors( 214 | grads, float(1)) 215 | for o in self.optims: 216 | o.step() 217 | 218 | def _save(self, step): 219 | real_model = self.model 220 | 221 | model_state_dict = real_model.state_dict() 222 | # generator_state_dict = real_generator.state_dict() 223 | checkpoint = { 224 | 'model': model_state_dict, 225 | # 'generator': generator_state_dict, 226 | 'opt': self.args, 227 | 'optims': self.optims, 228 | } 229 | checkpoint_path = os.path.join(self.args.model_path, 'model_step_%d.pt' % step) 230 | logger.info("Saving checkpoint %s" % checkpoint_path) 231 | # checkpoint_path = '%s_step_%d.pt' % (FLAGS.model_path, step) 232 | torch.save(checkpoint, checkpoint_path) 233 | return checkpoint, checkpoint_path 234 | 235 | def _start_report_manager(self, start_time=None): 236 | """ 237 | Simple function to start report manager (if any) 238 | """ 239 | if self.report_manager is not None: 240 | if start_time is None: 241 | self.report_manager.start() 242 | else: 243 | self.report_manager.start_time = start_time 244 | 245 | def _maybe_gather_stats(self, stat): 246 | """ 247 | Gather statistics in multi-processes cases 248 | 249 | Args: 250 | stat(:obj:onmt.utils.Statistics): a Statistics object to gather 251 | or None (it returns None in this case) 252 | 253 | Returns: 254 | stat: the updated (or unchanged) stat object 255 | """ 256 | if stat is not None and self.n_gpu > 1: 257 | return Statistics.all_gather_stats(stat) 258 | return stat 259 | 260 | def _maybe_report_training(self, step, num_steps, learning_rate, 261 | report_stats): 262 | """ 263 | Simple function to report training stats (if report_manager is set) 264 | see `onmt.utils.ReportManagerBase.report_training` for doc 265 | """ 266 | if self.report_manager is not None: 267 | return self.report_manager.report_training( 268 | step, num_steps, learning_rate, report_stats, 269 | multigpu=self.n_gpu > 1) 270 | 271 | def _report_step(self, learning_rate, step, train_stats=None, 272 | valid_stats=None): 273 | """ 274 | Simple function to report stats (if report_manager is set) 275 | see `onmt.utils.ReportManagerBase.report_step` for doc 276 | """ 277 | if self.report_manager is not None: 278 | return self.report_manager.report_step( 279 | learning_rate, step, train_stats=train_stats, 280 | valid_stats=valid_stats) 281 | 282 | def _maybe_save(self, step): 283 | """ 284 | Save the model if a model saver is set 285 | """ 286 | if self.model_saver is not None: 287 | self.model_saver.maybe_save(step) 288 | -------------------------------------------------------------------------------- /src/models/topic.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class TopicModel(nn.Module): 8 | 9 | def __init__(self, vocab_size, hidden_dim, topic_num, noise_rate=0.5): 10 | 11 | super(TopicModel, self).__init__() 12 | self.hidden_dim = hidden_dim 13 | self.topic_num = topic_num 14 | self.noise_rate = noise_rate 15 | self.mlp = nn.Sequential( 16 | nn.Linear(vocab_size, 2*hidden_dim), 17 | nn.Tanh() 18 | ) 19 | self.mu_linear = nn.Linear(2*hidden_dim, hidden_dim) 20 | self.sigma_linear = nn.Linear(2*hidden_dim, hidden_dim) 21 | self.theta_linear = nn.Linear(hidden_dim, topic_num) 22 | self.topic_emb = nn.Parameter(torch.empty(topic_num, hidden_dim)) 23 | for p in self.parameters(): 24 | if p.dim() > 1: 25 | nn.init.xavier_uniform_(p) 26 | else: 27 | p.data.zero_() 28 | 29 | def forward(self, bow_repre, voc_emb, summ_target=None): 30 | 31 | id_mask = bow_repre.gt(0).float() 32 | # bow_valid = bow_repre.gt(0) 33 | # stopnum = (bow_valid.sum(dim=-1).float() * (1-self.stop_word_rate)).long() 34 | # threshold = bow_repre.sort(dim=-1, descending=True)[0].index_select(-1, stopnum).diagonal(0) 35 | # id_mask = bow_repre.gt(threshold.unsqueeze(-1)).float() 36 | # id_mask = bow_repre.gt(torch.mean(bow_repre[bow_repre.gt(0)])).float() 37 | # id_mask = bow_repre.gt(torch.medium(bow_repre, dim=-1)) 38 | 39 | # Inference Stage 40 | linear_output = self.mlp(bow_repre) 41 | mu = self.mu_linear(linear_output) 42 | log_sigma_sq = self.sigma_linear(linear_output) 43 | 44 | eps = torch.empty_like(mu).float().normal_() 45 | sigma = torch.sqrt(torch.exp(log_sigma_sq)) 46 | 47 | if self.training: 48 | h = mu + sigma * eps 49 | else: 50 | h = mu 51 | 52 | theta_logits = self.theta_linear(h) 53 | 54 | e_loss = -0.5 * torch.sum(1 + log_sigma_sq - mu.pow(2) - torch.exp(log_sigma_sq)) 55 | 56 | # Generation Stage 57 | self.beta = beta = F.softmax(torch.matmul(self.topic_emb, voc_emb.transpose(0, 1)) / math.sqrt(self.hidden_dim), dim=-1) 58 | 59 | if summ_target is not None: 60 | 61 | summ_topic_num = int(self.topic_num * (1-self.noise_rate)) 62 | # build noise target 63 | noise_target = (id_mask != summ_target).float() 64 | 65 | summ_mask = torch.zeros_like(theta_logits) 66 | summ_mask[:, summ_topic_num:] = -float('inf') 67 | 68 | noise_mask = torch.zeros_like(theta_logits) 69 | noise_mask[:, :summ_topic_num] = -float('inf') 70 | 71 | theta_summ = F.softmax(theta_logits + summ_mask, dim=-1) 72 | theta_noise = F.softmax(theta_logits + noise_mask, dim=-1) 73 | 74 | logits_summ = torch.log(torch.matmul(theta_summ, beta) + 1e-40) 75 | logits_noise = torch.log(torch.matmul(theta_noise, beta) + 1e-40) 76 | 77 | g_loss = - torch.sum(logits_summ * summ_target) - torch.sum(logits_noise * noise_target) 78 | # topic_emb = torch.cat([torch.matmul(theta_summ, self.topic_emb), 79 | # torch.matmul(theta_noise, self.topic_emb)], -1) 80 | topic_emb = (torch.matmul(theta_summ, self.topic_emb), torch.matmul(theta_noise, self.topic_emb)) 81 | else: 82 | theta = F.softmax(theta_logits, dim=-1) 83 | logits = torch.log(torch.matmul(theta, beta) + 1e-40) 84 | g_loss = - torch.sum(logits * id_mask) 85 | topic_emb = torch.matmul(theta, self.topic_emb) 86 | 87 | return e_loss + g_loss, topic_emb 88 | 89 | 90 | class MultiTopicModel(nn.Module): 91 | 92 | def __init__(self, vocab_size, hidden_dim, topic_num, noise_rate, embeddings, agent=False, cust=False): 93 | 94 | super(MultiTopicModel, self).__init__() 95 | self.embeddings = nn.Parameter(embeddings) 96 | self.agent = agent 97 | self.cust = cust 98 | self.tm1 = TopicModel(vocab_size, hidden_dim, topic_num, noise_rate) 99 | if cust: 100 | self.tm2 = TopicModel(vocab_size, hidden_dim, topic_num, noise_rate) 101 | if agent: 102 | self.tm3 = TopicModel(vocab_size, hidden_dim, topic_num, noise_rate) 103 | 104 | def forward(self, all_bow, customer_bow, agent_bow, 105 | summ_all_target=None, summ_customer_target=None, summ_agent_target=None): 106 | 107 | loss_all, emb_all = self.tm1(all_bow, self.embeddings, summ_all_target) 108 | if self.cust: 109 | loss_customer, emb_customer = self.tm2(customer_bow, self.embeddings, summ_customer_target) 110 | else: 111 | loss_customer, emb_customer = 0, None 112 | if self.agent: 113 | loss_agent, emb_agent = self.tm3(agent_bow, self.embeddings, summ_agent_target) 114 | else: 115 | loss_agent, emb_agent = 0, None 116 | loss = loss_all + loss_agent + loss_customer 117 | 118 | return loss, (emb_all, emb_customer, emb_agent) 119 | -------------------------------------------------------------------------------- /src/models/topic_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from torch.nn.init import xavier_uniform_ 5 | 6 | from models.topic import MultiTopicModel 7 | from others.vocab_wrapper import VocabWrapper 8 | from others.id_wrapper import VocIDWrapper 9 | 10 | 11 | class Model(nn.Module): 12 | def __init__(self, args, device, vocab, checkpoint=None): 13 | super(Model, self).__init__() 14 | self.args = args 15 | self.device = device 16 | self.voc_id_wrapper = VocIDWrapper('pretrain_emb/id_word2vec.voc.txt') 17 | # Topic Model 18 | # Using Golve or Word2vec embedding 19 | if self.args.tokenize: 20 | self.voc_wrapper = VocabWrapper(self.args.word_emb_mode) 21 | self.voc_wrapper.load_emb(self.args.word_emb_path) 22 | self.voc_emb = torch.tensor(self.voc_wrapper.get_emb()) 23 | else: 24 | self.voc_emb = torch.empty(self.vocab_size, self.args.word_emb_size) 25 | xavier_uniform_(self.voc_emb) 26 | # self.voc_emb.weight = copy.deepcopy(self.encoder.model.embeddings.word_embeddings.weight) 27 | self.topic_model = MultiTopicModel(self.voc_emb.size(0), self.voc_emb.size(-1), 28 | args.topic_num, self.voc_emb, agent=True, cust=True) 29 | 30 | if checkpoint is not None: 31 | self.load_state_dict(checkpoint['model'], strict=True) 32 | 33 | self.to(device) 34 | 35 | def forward(self, batch): 36 | 37 | all_bow, customer_bow, agent_bow = \ 38 | batch.all_bow, batch.customer_bow, batch.agent_bow 39 | topic_loss, _ = self.topic_model(all_bow, customer_bow, agent_bow) 40 | 41 | return topic_loss 42 | -------------------------------------------------------------------------------- /src/models/topic_model_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from tensorboardX import SummaryWriter 5 | 6 | import distributed 7 | from models.reporter import ReportMgr, Statistics 8 | from models.loss import abs_loss, LogisticLossCompute 9 | from others.logging import logger 10 | 11 | 12 | def _tally_parameters(model): 13 | n_params = sum([p.nelement() for p in model.parameters()]) 14 | return n_params 15 | 16 | 17 | def build_trainer(args, device_id, model, optims, tokenizer): 18 | """ 19 | Simplify `Trainer` creation based on user `opt`s* 20 | Args: 21 | opt (:obj:`Namespace`): user options (usually from argument parsing) 22 | model (:obj:`onmt.models.NMTModel`): the model to train 23 | fields (dict): dict of fields 24 | optim (:obj:`onmt.utils.Optimizer`): optimizer used during training 25 | data_type (str): string describing the type of data 26 | e.g. "text", "img", "audio" 27 | model_saver(:obj:`onmt.models.ModelSaverBase`): the utility object 28 | used to save the model 29 | """ 30 | device = "cpu" if args.visible_gpus == '-1' else "cuda" 31 | 32 | grad_accum_count = args.accum_count 33 | n_gpu = args.world_size 34 | 35 | if device_id >= 0: 36 | gpu_rank = int(args.gpu_ranks[device_id]) 37 | else: 38 | gpu_rank = 0 39 | n_gpu = 0 40 | 41 | print('gpu_rank %d' % gpu_rank) 42 | 43 | tensorboard_log_dir = args.model_path 44 | 45 | writer = SummaryWriter(tensorboard_log_dir, comment="Unmt") 46 | 47 | report_manager = ReportMgr(args.report_every, start_time=-1, tensorboard_writer=writer) 48 | 49 | trainer = Trainer(args, model, optims, tokenizer, grad_accum_count, n_gpu, gpu_rank, report_manager) 50 | 51 | # print(tr) 52 | if (model): 53 | n_params = _tally_parameters(model) 54 | logger.info('* number of parameters: %d' % n_params) 55 | 56 | return trainer 57 | 58 | 59 | class Trainer(object): 60 | """ 61 | Class that controls the training process. 62 | 63 | Args: 64 | model(:py:class:`onmt.models.model.NMTModel`): translation model 65 | to train 66 | train_loss(:obj:`onmt.utils.loss.LossComputeBase`): 67 | training loss computation 68 | valid_loss(:obj:`onmt.utils.loss.LossComputeBase`): 69 | training loss computation 70 | optim(:obj:`onmt.utils.optimizers.Optimizer`): 71 | the optimizer responsible for update 72 | trunc_size(int): length of truncated back propagation through time 73 | shard_size(int): compute loss in shards of this size for efficiency 74 | data_type(string): type of the source input: [text|img|audio] 75 | norm_method(string): normalization methods: [sents|tokens] 76 | grad_accum_count(int): accumulate gradients this many times. 77 | report_manager(:obj:`onmt.utils.ReportMgrBase`): 78 | the object that creates reports, or None 79 | model_saver(:obj:`onmt.models.ModelSaverBase`): the saver is 80 | used to save a checkpoint. 81 | Thus nothing will be saved if this parameter is None 82 | """ 83 | 84 | def __init__(self, args, model, optims, tokenizer, 85 | grad_accum_count=1, n_gpu=1, gpu_rank=1, 86 | report_manager=None): 87 | # Basic attributes. 88 | self.args = args 89 | self.save_checkpoint_steps = args.save_checkpoint_steps 90 | self.model = model 91 | self.optims = optims 92 | self.tokenizer = tokenizer 93 | self.grad_accum_count = grad_accum_count 94 | self.n_gpu = n_gpu 95 | self.gpu_rank = gpu_rank 96 | self.report_manager = report_manager 97 | 98 | assert grad_accum_count > 0 99 | # Set model in training mode. 100 | if (model): 101 | self.model.train() 102 | 103 | def train(self, train_iter_fct, train_steps, valid_iter_fct=None, valid_steps=-1): 104 | """ 105 | The main training loops. 106 | by iterating over training data (i.e. `train_iter_fct`) 107 | and running validation (i.e. iterating over `valid_iter_fct` 108 | 109 | Args: 110 | train_iter_fct(function): a function that returns the train 111 | iterator. e.g. something like 112 | train_iter_fct = lambda: generator(*args, **kwargs) 113 | valid_iter_fct(function): same as train_iter_fct, for valid data 114 | train_steps(int): 115 | valid_steps(int): 116 | save_checkpoint_steps(int): 117 | 118 | Return: 119 | None 120 | """ 121 | logger.info('Start training...') 122 | 123 | step = self.optims[0]._step + 1 124 | true_batchs = [] 125 | accum = 0 126 | examples = 0 127 | 128 | train_iter = train_iter_fct() 129 | total_stats = Statistics() 130 | report_stats = Statistics() 131 | self._start_report_manager(start_time=total_stats.start_time) 132 | 133 | while step <= train_steps: 134 | 135 | for i, batch in enumerate(train_iter): 136 | if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank): 137 | 138 | true_batchs.append(batch) 139 | examples += batch.tgt.size(0) 140 | accum += 1 141 | if accum == self.grad_accum_count: 142 | if self.n_gpu > 1: 143 | examples = sum(distributed.all_gather_list(examples)) 144 | 145 | self._gradient_calculation( 146 | true_batchs, examples, total_stats, 147 | report_stats, step) 148 | 149 | report_stats = self._maybe_report_training( 150 | step, train_steps, 151 | self.optims[0].learning_rate, 152 | report_stats) 153 | 154 | true_batchs = [] 155 | accum = 0 156 | examples = 0 157 | if (step % self.save_checkpoint_steps == 0 and self.gpu_rank == 0): 158 | self._save(step) 159 | step += 1 160 | if step > train_steps: 161 | break 162 | train_iter = train_iter_fct() 163 | 164 | return total_stats 165 | 166 | def _gradient_calculation(self, true_batchs, examples, total_stats, 167 | report_stats, step): 168 | self.model.zero_grad() 169 | 170 | for batch in true_batchs: 171 | loss = self.model(batch) 172 | 173 | # Topic Model loss 174 | topic_stats = Statistics(topic_loss=loss.clone().item() / float(examples)) 175 | loss.div(float(examples)).backward(retain_graph=False) 176 | total_stats.update(topic_stats) 177 | report_stats.update(topic_stats) 178 | 179 | if step % 1000 == 0: 180 | for k in range(self.args.topic_num): 181 | logger.info(','.join([self.model.voc_id_wrapper.i2w(i) for i in 182 | self.model.topic_model.tm1.beta.topk(20, dim=-1)[1][k].tolist()])) 183 | # in case of multi step gradient accumulation, 184 | # update only after accum batches 185 | if self.n_gpu > 1: 186 | grads = [p.grad.data for p in self.model.parameters() 187 | if p.requires_grad 188 | and p.grad is not None] 189 | distributed.all_reduce_and_rescale_tensors( 190 | grads, float(1)) 191 | for o in self.optims: 192 | o.step() 193 | 194 | def _save(self, step): 195 | real_model = self.model 196 | 197 | model_state_dict = real_model.state_dict() 198 | # generator_state_dict = real_generator.state_dict() 199 | checkpoint = { 200 | 'model': model_state_dict, 201 | # 'generator': generator_state_dict, 202 | 'opt': self.args, 203 | 'optims': self.optims, 204 | } 205 | checkpoint_path = os.path.join(self.args.model_path, 'model_step_%d.pt' % step) 206 | logger.info("Saving checkpoint %s" % checkpoint_path) 207 | # checkpoint_path = '%s_step_%d.pt' % (FLAGS.model_path, step) 208 | torch.save(checkpoint, checkpoint_path) 209 | return checkpoint, checkpoint_path 210 | 211 | def _start_report_manager(self, start_time=None): 212 | """ 213 | Simple function to start report manager (if any) 214 | """ 215 | if self.report_manager is not None: 216 | if start_time is None: 217 | self.report_manager.start() 218 | else: 219 | self.report_manager.start_time = start_time 220 | 221 | def _maybe_gather_stats(self, stat): 222 | """ 223 | Gather statistics in multi-processes cases 224 | 225 | Args: 226 | stat(:obj:onmt.utils.Statistics): a Statistics object to gather 227 | or None (it returns None in this case) 228 | 229 | Returns: 230 | stat: the updated (or unchanged) stat object 231 | """ 232 | if stat is not None and self.n_gpu > 1: 233 | return Statistics.all_gather_stats(stat) 234 | return stat 235 | 236 | def _maybe_report_training(self, step, num_steps, learning_rate, 237 | report_stats): 238 | """ 239 | Simple function to report training stats (if report_manager is set) 240 | see `onmt.utils.ReportManagerBase.report_training` for doc 241 | """ 242 | if self.report_manager is not None: 243 | return self.report_manager.report_training( 244 | step, num_steps, learning_rate, report_stats, 245 | multigpu=self.n_gpu > 1) 246 | 247 | def _report_step(self, learning_rate, step, train_stats=None, 248 | valid_stats=None): 249 | """ 250 | Simple function to report stats (if report_manager is set) 251 | see `onmt.utils.ReportManagerBase.report_step` for doc 252 | """ 253 | if self.report_manager is not None: 254 | return self.report_manager.report_step( 255 | learning_rate, step, train_stats=train_stats, 256 | valid_stats=valid_stats) 257 | 258 | def _maybe_save(self, step): 259 | """ 260 | Save the model if a model saver is set 261 | """ 262 | if self.model_saver is not None: 263 | self.model_saver.maybe_save(step) 264 | -------------------------------------------------------------------------------- /src/others/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RowitZou/topic-dialog-summ/0de31d97b07be4004e08f9755ee66bea47aa7b10/src/others/__init__.py -------------------------------------------------------------------------------- /src/others/logging.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import absolute_import 3 | 4 | import logging 5 | 6 | logger = logging.getLogger() 7 | 8 | 9 | def init_logger(log_file=None, log_file_level=logging.NOTSET): 10 | log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s") 11 | logger = logging.getLogger() 12 | logger.setLevel(logging.INFO) 13 | 14 | console_handler = logging.StreamHandler() 15 | console_handler.setFormatter(log_format) 16 | logger.handlers = [console_handler] 17 | 18 | if log_file and log_file != '': 19 | file_handler = logging.FileHandler(log_file) 20 | file_handler.setLevel(log_file_level) 21 | file_handler.setFormatter(log_format) 22 | logger.addHandler(file_handler) 23 | 24 | return logger 25 | -------------------------------------------------------------------------------- /src/others/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from nltk.translate.bleu_score import sentence_bleu 3 | from nltk.translate.bleu_score import SmoothingFunction 4 | 5 | REMAP = {"-lrb-": "(", "-rrb-": ")", "-lcb-": "{", "-rcb-": "}", 6 | "-lsb-": "[", "-rsb-": "]", "``": '"', "''": '"'} 7 | 8 | 9 | def clean(x): 10 | return re.sub( 11 | r"-lrb-|-rrb-|-lcb-|-rcb-|-lsb-|-rsb-|``|''", 12 | lambda m: REMAP.get(m.group()), x) 13 | 14 | 15 | def test_bleu(cand, ref): 16 | candidate = [line.strip() for line in open(cand, encoding='utf-8')] 17 | reference = [line.strip() for line in open(ref, encoding='utf-8')] 18 | if len(reference) != len(candidate): 19 | raise ValueError('The number of sentences in both files do not match.') 20 | if len(reference) == 0: 21 | return 0 22 | score = 0. 23 | for i in range(len(reference)): 24 | gold_list = reference[i].split() 25 | cand_list = candidate[i].split() 26 | score += sentence_bleu([gold_list], cand_list, smoothing_function=SmoothingFunction().method1) 27 | score /= len(reference) 28 | return score 29 | 30 | 31 | def test_length(cand, ref, ratio=True): 32 | candidate = [len(line.split()) for line in open(cand, encoding='utf-8')] 33 | if len(candidate) == 0: 34 | return 0 35 | if ratio: 36 | reference = [len(line.split()) for line in open(ref, encoding='utf-8')] 37 | score = sum([candidate[i] / reference[i] for i in range(len(candidate))]) / len(candidate) 38 | else: 39 | score = sum(candidate) / len(candidate) 40 | return score 41 | 42 | 43 | def tile(x, count, dim=0): 44 | """ 45 | Tiles x on dimension dim count times. 46 | """ 47 | if x is None: 48 | return None 49 | perm = list(range(len(x.size()))) 50 | if dim != 0: 51 | perm[0], perm[dim] = perm[dim], perm[0] 52 | x = x.permute(perm).contiguous() 53 | out_size = list(x.size()) 54 | out_size[0] *= count 55 | batch = x.size(0) 56 | x = x.contiguous()\ 57 | .view(batch, -1) \ 58 | .transpose(0, 1) \ 59 | .repeat(count, 1) \ 60 | .transpose(0, 1) \ 61 | .contiguous() \ 62 | .view(*out_size) 63 | if dim != 0: 64 | x = x.permute(perm).contiguous() 65 | return x 66 | 67 | 68 | def test_f1(acc_num, pred_num, gold_num): 69 | p = acc_num / pred_num * 1. 70 | r = acc_num / gold_num * 1. 71 | if p == 0. and r == 0.: 72 | f1 = -1 73 | else: 74 | f1 = (2 * p * r) / (p + r) 75 | return f1, p, r 76 | 77 | 78 | """ 79 | def rouge_results_to_str(results_dict): 80 | if results_dict is None: 81 | return "No Results.\n" 82 | return ">> ROUGE-F(1/2/l): {:.2f}/{:.2f}/{:.2f}\nROUGE-R(1/2/l): {:.2f}/{:.2f}/{:.2f}\nROUGE-P(1/2/l): {:.2f}/{:.2f}/{:.2f}\n".format( 83 | results_dict["rouge_1_f_score"] * 100, 84 | results_dict["rouge_2_f_score"] * 100, 85 | results_dict["rouge_l_f_score"] * 100, 86 | results_dict["rouge_1_recall"] * 100, 87 | results_dict["rouge_2_recall"] * 100, 88 | results_dict["rouge_l_recall"] * 100, 89 | results_dict["rouge_1_precision"] * 100, 90 | results_dict["rouge_2_precision"] * 100, 91 | results_dict["rouge_l_precision"] * 100 92 | ) 93 | """ 94 | 95 | 96 | def rouge_results_to_str(results_dict): 97 | if results_dict is None: 98 | return "No Results.\n" 99 | return ">> ROUGE-F(1/2/l): {:.2f}/{:.2f}/{:.2f}\nROUGE-R(1/2/l): {:.2f}/{:.2f}/{:.2f}\nROUGE-P(1/2/l): {:.2f}/{:.2f}/{:.2f}\n".format( 100 | results_dict["rouge-1"]['f'] * 100, 101 | results_dict["rouge-2"]['f'] * 100, 102 | results_dict["rouge-l"]['f'] * 100, 103 | results_dict["rouge-1"]['r'] * 100, 104 | results_dict["rouge-2"]['r'] * 100, 105 | results_dict["rouge-l"]['r'] * 100, 106 | results_dict["rouge-1"]['p'] * 100, 107 | results_dict["rouge-2"]['p'] * 100, 108 | results_dict["rouge-l"]['p'] * 100 109 | ) 110 | -------------------------------------------------------------------------------- /src/others/vocab_wrapper.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | from gensim.models import Word2Vec 3 | from gensim.models import KeyedVectors 4 | # from glove import Corpus 5 | # from glove import Glove 6 | 7 | 8 | class VocabWrapper(object): 9 | # Glove has not been implemented. 10 | def __init__(self, mode="word2vec", emb_size=100): 11 | self.mode = mode 12 | self.emb_size = emb_size 13 | self.model = None 14 | self.emb = None 15 | 16 | def _glove_init(self): 17 | pass 18 | 19 | def _word2vec_init(self): 20 | self.model = Word2Vec(size=self.emb_size, window=5, min_count=30, workers=4) 21 | 22 | def _glove_train(self, ex): 23 | pass 24 | 25 | def _word2vec_train(self, ex): 26 | if self.model.wv.vectors.size == 0: 27 | self.model.build_vocab(ex, update=False) 28 | else: 29 | self.model.build_vocab(ex, update=True) 30 | self.model.train(ex, total_examples=self.model.corpus_count, epochs=1) 31 | 32 | def _glove_report(self): 33 | pass 34 | 35 | def _word2vec_report(self): 36 | if self.model is not None: 37 | print("Total examples: %d" % self.model.corpus_count) 38 | print("Vocab Size: %d" % len(self.model.wv.vocab)) 39 | else: 40 | print("Vocab Size: %d" % len(self.emb.vocab)) 41 | 42 | def _glove_save_model(self, path): 43 | pass 44 | 45 | def _word2vec_save_model(self, path): 46 | self._word2vec_report() 47 | self.model.save(path) 48 | 49 | def _glove_load_model(self, path): 50 | pass 51 | 52 | def _word2vec_load_model(self, path): 53 | self.model = Word2Vec.load(path) 54 | 55 | def _glove_save_emb(self, path): 56 | pass 57 | 58 | def _word2vec_save_emb(self, path): 59 | self._word2vec_report() 60 | if self.model is not None: 61 | self.model.wv.save(path) 62 | else: 63 | self.emb.save(path) 64 | 65 | def _glove_load_emb(self, path): 66 | pass 67 | 68 | def _word2vec_load_emb(self, path): 69 | self.emb = KeyedVectors.load(path) 70 | self.emb_size = self.emb.vector_size 71 | 72 | def _w2i_glove(self, w): 73 | return None 74 | 75 | def _w2i_word2vec(self, w): 76 | if self.emb is not None: 77 | if w in self.emb.vocab.keys(): 78 | return self.emb.vocab[w].index 79 | if self.model is not None: 80 | if w in self.model.wv.vocab.keys(): 81 | return self.model.wv.vocab[w].index 82 | return None 83 | 84 | def _i2w_glove(self, idx): 85 | return None 86 | 87 | def _i2w_word2vec(self, idx): 88 | if self.emb is not None: 89 | if idx < len(self.emb.vocab): 90 | return self.emb.index2word[idx] 91 | if self.model is not None: 92 | if idx < len(self.model.wv.vocab): 93 | return self.model.wv.index2word[idx] 94 | return None 95 | 96 | def _i2e_glove(self, idx): 97 | return None 98 | 99 | def _i2e_word2vec(self, idx): 100 | if self.emb is not None: 101 | if idx < len(self.emb.vocab): 102 | return self.emb.vectors[idx] 103 | if self.model is not None: 104 | if idx < len(self.model.wv.vocab): 105 | return self.model.wv.vectors[idx] 106 | return None 107 | 108 | def _w2e_glove(self, w): 109 | return None 110 | 111 | def _w2e_word2vec(self, w): 112 | if self.emb is not None: 113 | if w in self.emb.vocab.keys(): 114 | return self.emb[w] 115 | if self.model is not None: 116 | if w in self.model.wv.vocab.keys(): 117 | return self.model.wv[w] 118 | return None 119 | 120 | def _voc_size_glove(self): 121 | return -1 122 | 123 | def _voc_size_word2vec(self): 124 | if self.emb is not None: 125 | return len(self.emb.vocab) 126 | if self.model is not None: 127 | return len(self.model.wv.vocab) 128 | return -1 129 | 130 | def _get_emb_glove(self): 131 | return None 132 | 133 | def _get_emb_word2vec(self): 134 | if self.emb is not None: 135 | return self.emb.vectors 136 | if self.model is not None: 137 | return self.model.wv.vectors 138 | return None 139 | 140 | def init_model(self): 141 | if self.mode == "glove": 142 | self._glove_init() 143 | else: 144 | self._word2vec_init() 145 | 146 | def train(self, ex): 147 | """ 148 | ex: training examples. 149 | [['我', '爱', '中国', '。'], 150 | ['这', '是', '一个', '句子', '。']] 151 | """ 152 | if self.mode == "glove": 153 | self._glove_train(ex) 154 | else: 155 | self._word2vec_train(ex) 156 | 157 | def report(self): 158 | if self.mode == "glove": 159 | self._glove_report() 160 | else: 161 | self._word2vec_report() 162 | 163 | def save_model(self, path): 164 | if self.mode == "glove": 165 | self._glove_save_model(path) 166 | else: 167 | self._word2vec_save_model(path) 168 | 169 | def load_model(self, path): 170 | if self.mode == "glove": 171 | self._glove_load_model(path) 172 | else: 173 | self._word2vec_load_model(path) 174 | 175 | def save_emb(self, path): 176 | if self.mode == "glove": 177 | self._glove_save_emb(path) 178 | else: 179 | self._word2vec_save_emb(path) 180 | 181 | def load_emb(self, path): 182 | if self.mode == "glove": 183 | self._glove_load_emb(path) 184 | else: 185 | self._word2vec_load_emb(path) 186 | 187 | def w2i(self, w): 188 | if self.mode == "glove": 189 | return self._w2i_glove(w) 190 | else: 191 | return self._w2i_word2vec(w) 192 | 193 | def i2w(self, idx): 194 | if self.mode == "glove": 195 | return self._i2w_glove(idx) 196 | else: 197 | return self._i2w_word2vec(idx) 198 | 199 | def w2e(self, w): 200 | if self.mode == "glove": 201 | return self._w2e_glove(w) 202 | else: 203 | return self._w2e_word2vec(w) 204 | 205 | def i2e(self, idx): 206 | if self.mode == "glove": 207 | return self._i2e_glove(idx) 208 | else: 209 | return self._i2e_word2vec(idx) 210 | 211 | def voc_size(self): 212 | if self.mode == "glove": 213 | return self._voc_size_glove() 214 | else: 215 | return self._voc_size_word2vec() 216 | 217 | def get_emb(self): 218 | if self.mode == "glove": 219 | return self._get_emb_glove() 220 | else: 221 | return self._get_emb_word2vec() 222 | -------------------------------------------------------------------------------- /src/prepro/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RowitZou/topic-dialog-summ/0de31d97b07be4004e08f9755ee66bea47aa7b10/src/prepro/__init__.py -------------------------------------------------------------------------------- /src/prepro/data_builder.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | import gc 4 | import glob 5 | import json 6 | import os 7 | import random 8 | import torch 9 | from os.path import join as pjoin 10 | 11 | from collections import Counter 12 | from rouge import Rouge 13 | from others.logging import logger 14 | from others.tokenization import BertTokenizer 15 | from others.vocab_wrapper import VocabWrapper 16 | 17 | 18 | def greedy_selection(doc, summ, summary_size): 19 | 20 | doc_sents = list(map(lambda x: x["original_txt"], doc)) 21 | max_rouge = 0.0 22 | 23 | rouge = Rouge() 24 | selected = [] 25 | while True: 26 | cur_max_rouge = max_rouge 27 | cur_id = -1 28 | for i in range(len(doc_sents)): 29 | if (i in selected): 30 | continue 31 | c = selected + [i] 32 | temp_txt = " ".join([doc_sents[j] for j in c]) 33 | if len(temp_txt.split()) > summary_size: 34 | continue 35 | rouge_score = rouge.get_scores(temp_txt, summ) 36 | rouge_1 = rouge_score[0]["rouge-1"]["r"] 37 | rouge_l = rouge_score[0]["rouge-l"]["r"] 38 | rouge_score = rouge_1 + rouge_l 39 | if rouge_score > cur_max_rouge: 40 | cur_max_rouge = rouge_score 41 | cur_id = i 42 | if (cur_id == -1): 43 | return selected 44 | selected.append(cur_id) 45 | max_rouge = cur_max_rouge 46 | 47 | return selected 48 | 49 | 50 | class BertData(): 51 | def __init__(self, args): 52 | self.args = args 53 | self.tokenizer = BertTokenizer.from_pretrained(args.bert_dir) 54 | 55 | self.sep_token = '[SEP]' 56 | self.cls_token = '[CLS]' 57 | self.pad_token = '[PAD]' 58 | self.unk_token = '[UNK]' 59 | self.tgt_bos = '[unused1]' 60 | self.tgt_eos = '[unused2]' 61 | self.role_1 = '[unused3]' 62 | self.role_2 = '[unused4]' 63 | self.sep_vid = self.tokenizer.vocab[self.sep_token] 64 | self.cls_vid = self.tokenizer.vocab[self.cls_token] 65 | self.pad_vid = self.tokenizer.vocab[self.pad_token] 66 | self.unk_vid = self.tokenizer.vocab[self.unk_token] 67 | 68 | def preprocess_src(self, content, info=None): 69 | if_exceed_length = False 70 | 71 | if not (info == "客服" or info == '客户'): 72 | return None 73 | if len(content) < self.args.min_src_ntokens_per_sent: 74 | return None 75 | if len(content) > self.args.max_src_ntokens_per_sent: 76 | if_exceed_length = True 77 | 78 | original_txt = ' '.join(content) 79 | 80 | if self.args.truncated: 81 | content = content[:self.args.max_src_ntokens_per_sent] 82 | content_text = ' '.join(content).lower() 83 | content_subtokens = self.tokenizer.tokenize(content_text) 84 | 85 | # [CLS] + T0 + T1 + ... + Tn 86 | if info == '客服': 87 | src_subtokens = [self.cls_token, self.role_1] + content_subtokens 88 | else: 89 | src_subtokens = [self.cls_token, self.role_2] + content_subtokens 90 | src_subtoken_idxs = self.tokenizer.convert_tokens_to_ids(src_subtokens) 91 | segments_ids = len(src_subtoken_idxs) * [0] 92 | 93 | return src_subtoken_idxs, segments_ids, original_txt, \ 94 | src_subtokens, if_exceed_length 95 | 96 | def preprocess_summary(self, content): 97 | 98 | original_txt = ' '.join(content) 99 | 100 | content_text = ' '.join(content).lower() 101 | content_subtokens = self.tokenizer.tokenize(content_text) 102 | 103 | content_subtokens = [self.tgt_bos] + content_subtokens + [self.tgt_eos] 104 | subtoken_idxs = self.tokenizer.convert_tokens_to_ids(content_subtokens) 105 | 106 | return subtoken_idxs, original_txt, content_subtokens 107 | 108 | def integrate_dialogue(self, dialogue): 109 | src_tokens = [self.cls_token] 110 | segments_ids = [0] 111 | segment_id = 0 112 | for sent in dialogue: 113 | tokens = sent["src_tokens"][1:] + [self.sep_token] 114 | src_tokens.extend(tokens) 115 | segments_ids.extend([segment_id] * len(tokens)) 116 | segment_id = 1 - segment_id 117 | src_ids = self.tokenizer.convert_tokens_to_ids(src_tokens) 118 | return {"src_id": src_ids, "segs": segments_ids} 119 | 120 | 121 | def topic_info_generate(dialogue, file_counter): 122 | all_counter = Counter() 123 | customer_counter = Counter() 124 | agent_counter = Counter() 125 | 126 | for sent in dialogue: 127 | role = sent["role"] 128 | token_ids = sent["tokenized_id"] 129 | all_counter.update(token_ids) 130 | if role == "客服": 131 | agent_counter.update(token_ids) 132 | else: 133 | customer_counter.update(token_ids) 134 | file_counter['all'].update(all_counter.keys()) 135 | file_counter['customer'].update(customer_counter.keys()) 136 | file_counter['agent'].update(agent_counter.keys()) 137 | file_counter['num'] += 1 138 | return {"all": all_counter, "customer": customer_counter, "agent": agent_counter} 139 | 140 | 141 | def topic_summ_info_generate(dialogue, ex_labels): 142 | all_counter = Counter() 143 | customer_counter = Counter() 144 | agent_counter = Counter() 145 | 146 | for i, sent in enumerate(dialogue): 147 | if i in ex_labels: 148 | role = sent["role"] 149 | token_ids = sent["tokenized_id"] 150 | all_counter.update(token_ids) 151 | if role == "客服": 152 | agent_counter.update(token_ids) 153 | else: 154 | customer_counter.update(token_ids) 155 | return {"all": all_counter, "customer": customer_counter, "agent": agent_counter} 156 | 157 | 158 | def format_to_bert(args, corpus_type=None): 159 | 160 | a_lst = [] 161 | file_counter = {"all": Counter(), "customer": Counter(), "agent": Counter(), "num": 0, "voc_size": 0} 162 | if corpus_type is not None: 163 | for json_f in glob.glob(pjoin(args.raw_path, '*' + corpus_type + '*.json')): 164 | real_name = json_f.split('/')[-1] 165 | a_lst.append((corpus_type, json_f, args, file_counter, pjoin(args.save_path, real_name.replace('json', 'bert.pt')))) 166 | else: 167 | for json_f in glob.glob(pjoin(args.raw_path, '*.json')): 168 | real_name = json_f.split('/')[-1] 169 | corpus_type = real_name.split('.')[1] 170 | a_lst.append((corpus_type, json_f, args, file_counter, pjoin(args.save_path, real_name.replace('json', 'bert.pt')))) 171 | 172 | total_statistic = { 173 | "instances": 0, 174 | "total_turns": 0., 175 | "processed_turns": 0., 176 | "max_turns": -1, 177 | "turns_num": [0] * 11, 178 | "exceed_length_num": 0, 179 | "exceed_turns_num": 0, 180 | "total_src_length": 0., 181 | "src_sent_length_num": [0] * 11, 182 | "src_token_length_num": [0] * 11, 183 | "total_tgt_length": 0 184 | } 185 | for d in a_lst: 186 | statistic = _format_to_bert(d) 187 | if statistic is None: 188 | continue 189 | total_statistic["instances"] += statistic["instances"] 190 | total_statistic["total_turns"] += statistic["total_turns"] 191 | total_statistic["processed_turns"] += statistic["processed_turns"] 192 | total_statistic["max_turns"] = max(total_statistic["max_turns"], statistic["max_turns"]) 193 | total_statistic["exceed_length_num"] += statistic["exceed_length_num"] 194 | total_statistic["exceed_turns_num"] += statistic["exceed_turns_num"] 195 | total_statistic["total_src_length"] += statistic["total_src_length"] 196 | total_statistic["total_tgt_length"] += statistic["total_tgt_length"] 197 | for idx in range(len(total_statistic["turns_num"])): 198 | total_statistic["turns_num"][idx] += statistic["turns_num"][idx] 199 | for idx in range(len(total_statistic["src_sent_length_num"])): 200 | total_statistic["src_sent_length_num"][idx] += statistic["src_sent_length_num"][idx] 201 | for idx in range(len(total_statistic["src_token_length_num"])): 202 | total_statistic["src_token_length_num"][idx] += statistic["src_token_length_num"][idx] 203 | 204 | # save file counter 205 | save_file = pjoin(args.save_path, 'idf_info.pt') 206 | logger.info('Saving file counter to %s' % save_file) 207 | torch.save(file_counter, save_file) 208 | 209 | if total_statistic["instances"] > 0: 210 | logger.info("Total examples: %d" % total_statistic["instances"]) 211 | logger.info("Average sentence number per dialogue: %f" % (total_statistic["total_turns"] / total_statistic["instances"])) 212 | logger.info("Processed average sentence number per dialogue: %f" % (total_statistic["processed_turns"] / total_statistic["instances"])) 213 | logger.info("Total sentences: %d" % total_statistic["total_turns"]) 214 | logger.info("Processed sentences: %d" % total_statistic["processed_turns"]) 215 | logger.info("Exceeded max sentence number dialogues: %d" % total_statistic["exceed_turns_num"]) 216 | logger.info("Max dialogue sentences: %d" % total_statistic["max_turns"]) 217 | for idx, num in enumerate(total_statistic["turns_num"]): 218 | logger.info("Dialogue sentences %d ~ %d: %d, %.2f%%" % (idx * 20, (idx+1) * 20, num, (num / total_statistic["instances"]))) 219 | logger.info("Exceed length sentences number: %d" % total_statistic["exceed_length_num"]) 220 | logger.info("Average src sentence length: %f" % (total_statistic["total_src_length"] / total_statistic["total_turns"])) 221 | for idx, num in enumerate(total_statistic["src_sent_length_num"]): 222 | logger.info("Sent length %d ~ %d: %d, %.2f%%" % (idx * 10, (idx+1) * 10, num, (num / total_statistic["total_turns"]))) 223 | logger.info("Average src token length: %f" % (total_statistic["total_src_length"] / total_statistic["instances"])) 224 | for idx, num in enumerate(total_statistic["src_token_length_num"]): 225 | logger.info("token num %d ~ %d: %d, %.2f%%" % (idx * 300, (idx+1) * 300, num, (num / total_statistic["instances"]))) 226 | logger.info("Average tgt length: %f" % (total_statistic["total_tgt_length"] / total_statistic["instances"])) 227 | 228 | 229 | def _format_to_bert(params): 230 | _, json_file, args, file_counter, save_file = params 231 | if (os.path.exists(save_file)): 232 | logger.info('Ignore %s' % save_file) 233 | return 234 | 235 | bert = BertData(args) 236 | 237 | logger.info('Processing %s' % json_file) 238 | jobs = json.load(open(json_file)) 239 | 240 | if args.tokenize: 241 | voc_wrapper = VocabWrapper(args.emb_mode) 242 | voc_wrapper.load_emb(args.emb_path) 243 | file_counter['voc_size'] = voc_wrapper.voc_size() 244 | 245 | datasets = [] 246 | exceed_length_num = 0 247 | exceed_turns_num = 0 248 | total_src_length = 0. 249 | total_tgt_length = 0. 250 | src_length_sent_num = [0] * 11 251 | src_length_token_num = [0] * 11 252 | max_turns = 0 253 | turns_num = [0] * 11 254 | dialogue_turns = 0. 255 | processed_turns = 0. 256 | 257 | count = 0 258 | 259 | for dialogue in jobs: 260 | dialogue_b_data = [] 261 | dialogue_token_num = 0 262 | for index, sent in enumerate(dialogue['session']): 263 | content = sent['content'] 264 | role = sent['type'] 265 | b_data = bert.preprocess_src(content, role) 266 | if (b_data is None): 267 | continue 268 | src_subtoken_idxs, segments_ids, original_txt, \ 269 | src_subtokens, exceed_length = b_data 270 | b_data_dict = {"index": index, "src_id": src_subtoken_idxs, 271 | "segs": segments_ids, "original_txt": original_txt, 272 | "src_tokens": src_subtokens, "role": role} 273 | if args.tokenize: 274 | ids = map(lambda x: voc_wrapper.w2i(x), sent['word']) 275 | tokenized_id = [x for x in ids if x is not None] 276 | b_data_dict["tokenized_id"] = tokenized_id 277 | else: 278 | b_data_dict["tokenized_id"] = src_subtoken_idxs[2:] 279 | src_length_sent_num[min(len(src_subtoken_idxs) // 10, 10)] += 1 280 | dialogue_token_num += len(src_subtoken_idxs) 281 | total_src_length += len(src_subtoken_idxs) 282 | dialogue_b_data.append(b_data_dict) 283 | if exceed_length: 284 | exceed_length_num += 1 285 | if len(dialogue_b_data) >= args.max_turns: 286 | exceed_turns_num += 1 287 | if args.truncated: 288 | break 289 | dialogue_example = {"session": dialogue_b_data} 290 | dialogue_integrated = bert.integrate_dialogue(dialogue_b_data) 291 | topic_info = topic_info_generate(dialogue_b_data, file_counter) 292 | dialogue_example["dialogue"] = dialogue_integrated 293 | dialogue_example["topic_info"] = topic_info 294 | # test & dev data process 295 | if "summary" in dialogue.keys(): 296 | content = dialogue["summary"] 297 | summ_b_data = bert.preprocess_summary(content) 298 | subtoken_idxs, original_txt, content_subtokens = summ_b_data 299 | total_tgt_length += len(subtoken_idxs) 300 | b_data_dict = {"id": subtoken_idxs, 301 | "original_txt": original_txt, 302 | "content_tokens": content_subtokens} 303 | if args.add_ex_label: 304 | ex_labels = greedy_selection(dialogue_b_data, original_txt, args.ex_max_token_num) 305 | topic_summ_info = topic_summ_info_generate(dialogue_b_data, ex_labels) 306 | b_data_dict["ex_labels"] = ex_labels 307 | b_data_dict["topic_summ_info"] = topic_summ_info 308 | dialogue_example["summary"] = b_data_dict 309 | 310 | if len(dialogue_b_data) >= args.min_turns: 311 | datasets.append(dialogue_example) 312 | turns_num[min(len(dialogue_b_data) // 20, 10)] += 1 313 | src_length_token_num[min(dialogue_token_num // 300, 10)] += 1 314 | max_turns = max(max_turns, len(dialogue_b_data)) 315 | dialogue_turns += len(dialogue['session']) 316 | processed_turns += len(dialogue_b_data) 317 | 318 | count += 1 319 | if count % 50 == 0: 320 | print(count) 321 | 322 | statistic = { 323 | "instances": len(datasets), 324 | "total_turns": dialogue_turns, 325 | "processed_turns": processed_turns, 326 | "max_turns": max_turns, 327 | "turns_num": turns_num, 328 | "exceed_length_num": exceed_length_num, 329 | "exceed_turns_num": exceed_turns_num, 330 | "total_src_length": total_src_length, 331 | "src_sent_length_num": src_length_sent_num, 332 | "src_token_length_num": src_length_token_num, 333 | "total_tgt_length": total_tgt_length 334 | } 335 | 336 | logger.info('Processed instances %d' % len(datasets)) 337 | logger.info('Saving to %s' % save_file) 338 | torch.save(datasets, save_file) 339 | datasets = [] 340 | gc.collect() 341 | return statistic 342 | -------------------------------------------------------------------------------- /src/preprocess.py: -------------------------------------------------------------------------------- 1 | # encoding=utf-8 2 | 3 | import argparse 4 | from others.logging import init_logger 5 | from prepro import data_builder as data_builder 6 | 7 | 8 | def str2bool(v): 9 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 10 | return True 11 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 12 | return False 13 | else: 14 | raise argparse.ArgumentTypeError('Boolean value expected.') 15 | 16 | 17 | if __name__ == '__main__': 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("-pretrained_model", default='bert', type=str) 20 | 21 | parser.add_argument("-type", default='train', type=str) 22 | parser.add_argument("-raw_path", default='json_data') 23 | parser.add_argument("-save_path", default='bert_data') 24 | parser.add_argument("-shard_size", default=2000, type=int) 25 | parser.add_argument("-idlize", nargs='?', const=True, default=False) 26 | 27 | parser.add_argument("-bert_dir", default='bert/chinese_rbt3') 28 | parser.add_argument('-min_src_ntokens', default=1, type=int) 29 | parser.add_argument('-max_src_ntokens', default=3000, type=int) 30 | parser.add_argument('-min_src_ntokens_per_sent', default=1, type=int) 31 | parser.add_argument('-max_src_ntokens_per_sent', default=50, type=int) 32 | parser.add_argument('-min_tgt_ntokens', default=1, type=int) 33 | parser.add_argument('-max_tgt_ntokens', default=500, type=int) 34 | parser.add_argument('-min_turns', default=1, type=int) 35 | parser.add_argument('-max_turns', default=100, type=int) 36 | parser.add_argument("-lower", type=str2bool, nargs='?', const=True, default=True) 37 | parser.add_argument("-tokenize", type=str2bool, nargs='?', const=True, default=False) 38 | parser.add_argument("-emb_mode", default="word2vec", type=str, choices=["glove", "word2vec"]) 39 | parser.add_argument("-emb_path", default="", type=str) 40 | parser.add_argument("-ex_max_token_num", default=500, type=int) 41 | parser.add_argument("-truncated", nargs='?', const=True, default=False) 42 | parser.add_argument("-add_ex_label", nargs='?', const=True, default=False) 43 | 44 | parser.add_argument('-log_file', default='logs/preprocess.log') 45 | parser.add_argument('-dataset', default='') 46 | 47 | args = parser.parse_args() 48 | if args.type not in ["train", "dev", "test"]: 49 | print("Invalid data type! Data type should be 'train', 'dev', or 'test'.") 50 | exit(0) 51 | init_logger(args.log_file) 52 | data_builder.format_to_bert(args) 53 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Main training workflow 4 | """ 5 | from __future__ import division 6 | 7 | import argparse 8 | import os 9 | from others.logging import init_logger 10 | from train_abstractive import validate, train, test_text, baseline 11 | 12 | 13 | model_flags = ['hidden_size', 'ff_size', 'heads', 'emb_size', 'enc_layers', 'enc_hidden_size', 'enc_ff_size', 14 | 'dec_layers', 'dec_hidden_size', 'dec_ff_size', 'encoder', 'ff_actv', 'use_interval'] 15 | 16 | 17 | def str2bool(v): 18 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 19 | return True 20 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 21 | return False 22 | else: 23 | raise argparse.ArgumentTypeError('Boolean value expected.') 24 | 25 | 26 | if __name__ == '__main__': 27 | parser = argparse.ArgumentParser() 28 | # Basic args 29 | parser.add_argument("-mode", default='train', type=str, choices=['train', 'validate', 'test', 'lead', 'oracle']) 30 | parser.add_argument("-test_mode", default='abs', type=str, choices=['ext', 'abs']) 31 | parser.add_argument("-src_data_mode", default='utt', type=str, choices=['utt', 'word']) 32 | parser.add_argument("-data_path", default='bert_data/ali') 33 | parser.add_argument("-model_path", default='models') 34 | parser.add_argument("-result_path", default='results/ali') 35 | parser.add_argument("-bert_dir", default='bert/chinese_bert') 36 | parser.add_argument('-log_file', default='logs/temp.log') 37 | parser.add_argument('-visible_gpus', default='0', type=str) 38 | parser.add_argument('-gpu_ranks', default='0', type=str) 39 | parser.add_argument('-seed', default=666, type=int) 40 | 41 | # Batch sizes 42 | parser.add_argument("-batch_size", default=2000, type=int) 43 | parser.add_argument("-batch_ex_size", default=4, type=int) 44 | parser.add_argument("-test_batch_size", default=20000, type=int) 45 | parser.add_argument("-test_batch_ex_size", default=50, type=int) 46 | 47 | # Model args 48 | parser.add_argument("-encoder", default='bert', type=str, choices=['bert', 'transformer', 'rnn']) 49 | parser.add_argument("-decoder", default='transformer', type=str, choices=['transformer', 'rnn']) 50 | parser.add_argument("-share_emb", type=str2bool, nargs='?', const=True, default=True) 51 | parser.add_argument("-max_pos", default=512, type=int) 52 | parser.add_argument("-finetune_bert", type=str2bool, nargs='?', const=True, default=True) 53 | parser.add_argument("-dec_dropout", default=0.2, type=float) 54 | parser.add_argument("-dec_layers", default=3, type=int) 55 | parser.add_argument("-dec_hidden_size", default=768, type=int) 56 | parser.add_argument("-dec_heads", default=8, type=int) 57 | parser.add_argument("-dec_ff_size", default=2048, type=int) 58 | parser.add_argument("-enc_hidden_size", default=768, type=int) 59 | parser.add_argument("-enc_ff_size", default=2048, type=int) 60 | parser.add_argument("-enc_heads", default=8, type=int) 61 | parser.add_argument("-enc_dropout", default=0.2, type=float) 62 | parser.add_argument("-enc_layers", default=3, type=int) 63 | 64 | # args for copy mechanism and coverage 65 | parser.add_argument("-coverage", type=str2bool, nargs='?', const=True, default=False) 66 | parser.add_argument("-copy_attn", type=str2bool, nargs='?', const=True, default=False) 67 | parser.add_argument("-copy_attn_force", type=str2bool, nargs='?', const=True, default=False) 68 | parser.add_argument("-copy_loss_by_seqlength", type=str2bool, nargs='?', const=True, default=False) 69 | 70 | # args for sent-level encoder 71 | parser.add_argument("-hier_dropout", default=0.2, type=float) 72 | parser.add_argument("-hier_layers", default=2, type=int) 73 | parser.add_argument("-hier_hidden_size", default=768, type=int) 74 | parser.add_argument("-hier_heads", default=8, type=int) 75 | parser.add_argument("-hier_ff_size", default=2048, type=int) 76 | 77 | # args for topic model 78 | parser.add_argument("-topic_model", type=str2bool, nargs='?', const=True, default=False) 79 | parser.add_argument("-loss_lambda", default=0.001, type=float) 80 | parser.add_argument("-tokenize", type=str2bool, nargs='?', const=True, default=True) 81 | parser.add_argument("-idf_info_path", default="bert_data/idf_info.pt") 82 | parser.add_argument("-topic_num", default=50, type=int) 83 | parser.add_argument("-word_emb_size", default=100, type=int) 84 | parser.add_argument("-word_emb_mode", default="word2vec", type=str, choices=["glove", "word2vec"]) 85 | parser.add_argument("-word_emb_path", default="pretrain_emb/word2vec", type=str) 86 | parser.add_argument("-use_idf", type=str2bool, nargs='?', const=True, default=False) 87 | parser.add_argument("-split_noise", type=str2bool, nargs='?', const=True, default=False) 88 | parser.add_argument("-max_word_count", default=6000, type=int) 89 | parser.add_argument("-min_word_count", default=5, type=int) 90 | parser.add_argument("-agent", type=str2bool, nargs='?', const=True, default=True) 91 | parser.add_argument("-cust", type=str2bool, nargs='?', const=True, default=True) 92 | parser.add_argument("-noise_rate", type=float, default=0.5) 93 | 94 | # Training process args 95 | parser.add_argument("-save_checkpoint_steps", default=2000, type=int) 96 | parser.add_argument("-accum_count", default=2, type=int) 97 | parser.add_argument("-report_every", default=5, type=int) 98 | parser.add_argument("-train_steps", default=80000, type=int) 99 | parser.add_argument("-label_smoothing", default=0.1, type=float) 100 | parser.add_argument("-generator_shard_size", default=32, type=int) 101 | parser.add_argument("-max_tgt_len", default=100, type=int) 102 | 103 | # Beam search decoding args 104 | parser.add_argument("-alpha", default=0.6, type=float) 105 | parser.add_argument("-beam_size", default=3, type=int) 106 | parser.add_argument("-min_length", default=10, type=int) 107 | parser.add_argument("-max_length", default=100, type=int) 108 | parser.add_argument("-block_trigram", type=str2bool, nargs='?', const=True, default=True) 109 | 110 | # Optim args 111 | parser.add_argument("-optim", default='adam', type=str) 112 | parser.add_argument("-sep_optim", type=str2bool, nargs='?', const=True, default=False) 113 | parser.add_argument("-lr_bert", default=0.001, type=float) 114 | parser.add_argument("-lr_other", default=0.01, type=float) 115 | parser.add_argument("-lr_topic", default=0.0001, type=float) 116 | parser.add_argument("-lr", default=0.001, type=float) 117 | parser.add_argument("-beta1", default=0.9, type=float) 118 | parser.add_argument("-beta2", default=0.999, type=float) 119 | parser.add_argument("-warmup", type=str2bool, nargs='?', const=True, default=True) 120 | parser.add_argument("-warmup_steps", default=5000, type=int) 121 | parser.add_argument("-warmup_steps_bert", default=5000, type=int) 122 | parser.add_argument("-warmup_steps_other", default=5000, type=int) 123 | parser.add_argument("-max_grad_norm", default=0, type=float) 124 | 125 | # Pretrain args 126 | parser.add_argument("-pretrain", type=str2bool, nargs='?', const=True, default=False) 127 | # Baseline model pretrain args 128 | parser.add_argument("-pretrain_steps", default=80000, type=int) 129 | 130 | # Utility args 131 | parser.add_argument("-test_all", type=str2bool, nargs='?', const=True, default=False) 132 | parser.add_argument("-test_from", default='') 133 | parser.add_argument("-test_start_from", default=-1, type=int) 134 | parser.add_argument("-train_from", default='') 135 | parser.add_argument("-train_from_ignore_optim", type=str2bool, nargs='?', const=True, default=False) 136 | 137 | # args for RL 138 | parser.add_argument("-freeze_step", default=500, type=int) 139 | parser.add_argument("-ex_max_token_num", default=500, type=int) 140 | parser.add_argument("-sent_hidden_size", default=768, type=int) 141 | parser.add_argument("-sent_ff_size", default=2048, type=int) 142 | parser.add_argument("-sent_heads", default=8, type=int) 143 | parser.add_argument("-sent_dropout", default=0.2, type=float) 144 | parser.add_argument("-sent_layers", default=3, type=int) 145 | parser.add_argument("-pn_hidden_size", default=768, type=int) 146 | parser.add_argument("-pn_ff_size", default=2048, type=int) 147 | parser.add_argument("-pn_heads", default=8, type=int) 148 | parser.add_argument("-pn_dropout", default=0.2, type=float) 149 | parser.add_argument("-pn_layers", default=2, type=int) 150 | parser.add_argument("-mask_token_prob", default=0.15, type=float) 151 | parser.add_argument("-select_sent_prob", default=0.90, type=float) 152 | 153 | args = parser.parse_args() 154 | args.gpu_ranks = [int(i) for i in range(len(args.visible_gpus.split(',')))] 155 | args.world_size = len(args.gpu_ranks) 156 | os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_gpus 157 | 158 | init_logger(args.log_file) 159 | device = "cpu" if args.visible_gpus == '-1' else "cuda" 160 | device_id = 0 if device == "cuda" else -1 161 | 162 | if (args.mode == 'train'): 163 | train(args, device_id) 164 | elif (args.mode == 'lead'): 165 | baseline(args, cal_lead=True) 166 | elif (args.mode == 'oracle'): 167 | baseline(args, cal_oracle=True) 168 | elif (args.mode == 'validate'): 169 | validate(args, device_id) 170 | elif (args.mode == 'test'): 171 | cp = args.test_from 172 | try: 173 | step = int(cp.split('.')[-2].split('_')[-1]) 174 | except RuntimeWarning: 175 | print("Unrecognized cp step.") 176 | step = 0 177 | test_text(args, device_id, cp, step) 178 | else: 179 | print("Undefined mode! Please check input.") 180 | -------------------------------------------------------------------------------- /src/train_emb.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | import glob 5 | from os.path import join as pjoin 6 | from others.vocab_wrapper import VocabWrapper 7 | 8 | 9 | def train_emb(args): 10 | data_dir = os.path.abspath(args.data_path) 11 | print("Preparing to process %s ..." % data_dir) 12 | raw_files = glob.glob(pjoin(data_dir, '*.json')) 13 | 14 | ex_num = 0 15 | vocab_wrapper = VocabWrapper(args.mode, args.emb_size) 16 | vocab_wrapper.init_model() 17 | 18 | file_ex = [] 19 | for s in raw_files: 20 | exs = json.load(open(s)) 21 | print("Processing File " + s) 22 | for ex in exs: 23 | example = list(map(lambda x: x['word'], ex['session'])) 24 | file_ex.extend(example) 25 | ex_num += 1 26 | vocab_wrapper.train(file_ex) 27 | vocab_wrapper.report() 28 | print("Datasets size: %d" % ex_num) 29 | vocab_wrapper.save_emb(args.emb_path) 30 | 31 | 32 | if __name__ == "__main__": 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument("-mode", default='word2vec', type=str, choices=['glove', 'word2vec']) 35 | parser.add_argument("-data_path", default="", type=str) 36 | parser.add_argument("-emb_size", default=100, type=int) 37 | parser.add_argument("-emb_path", default="", type=str) 38 | 39 | args = parser.parse_args() 40 | 41 | train_emb(args) 42 | -------------------------------------------------------------------------------- /src/translate/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RowitZou/topic-dialog-summ/0de31d97b07be4004e08f9755ee66bea47aa7b10/src/translate/__init__.py -------------------------------------------------------------------------------- /src/translate/beam.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | from translate import penalties 4 | 5 | 6 | class Beam(object): 7 | """ 8 | Class for managing the internals of the beam search process. 9 | 10 | Takes care of beams, back pointers, and scores. 11 | 12 | Args: 13 | size (int): beam size 14 | pad, bos, eos (int): indices of padding, beginning, and ending. 15 | n_best (int): nbest size to use 16 | cuda (bool): use gpu 17 | global_scorer (:obj:`GlobalScorer`) 18 | """ 19 | 20 | def __init__(self, size, pad, bos, eos, 21 | n_best=1, cuda=False, 22 | global_scorer=None, 23 | min_length=0, 24 | stepwise_penalty=False, 25 | block_ngram_repeat=0, 26 | exclusion_tokens=set()): 27 | 28 | self.size = size 29 | self.tt = torch.cuda if cuda else torch 30 | 31 | # The score for each translation on the beam. 32 | self.scores = self.tt.FloatTensor(size).zero_() 33 | self.all_scores = [] 34 | 35 | # The backpointers at each time-step. 36 | self.prev_ks = [] 37 | 38 | # The outputs at each time-step. 39 | self.next_ys = [self.tt.LongTensor(size) 40 | .fill_(pad)] 41 | self.next_ys[0][0] = bos 42 | 43 | # Has EOS topped the beam yet. 44 | self._eos = eos 45 | self.eos_top = False 46 | 47 | # The attentions (matrix) for each time. 48 | self.attn = [] 49 | 50 | # Time and k pair for finished. 51 | self.finished = [] 52 | self.n_best = n_best 53 | 54 | # Information for global scoring. 55 | self.global_scorer = global_scorer 56 | self.global_state = {} 57 | 58 | # Minimum prediction length 59 | self.min_length = min_length 60 | 61 | # Apply Penalty at every step 62 | self.stepwise_penalty = stepwise_penalty 63 | self.block_ngram_repeat = block_ngram_repeat 64 | self.exclusion_tokens = exclusion_tokens 65 | 66 | def get_current_state(self): 67 | "Get the outputs for the current timestep." 68 | return self.next_ys[-1] 69 | 70 | def get_current_origin(self): 71 | "Get the backpointers for the current timestep." 72 | return self.prev_ks[-1] 73 | 74 | def advance(self, word_probs, attn_out): 75 | """ 76 | Given prob over words for every last beam `wordLk` and attention 77 | `attn_out`: Compute and update the beam search. 78 | 79 | Parameters: 80 | 81 | * `word_probs`- probs of advancing from the last step (K x words) 82 | * `attn_out`- attention at the last step 83 | 84 | Returns: True if beam search is complete. 85 | """ 86 | num_words = word_probs.size(1) 87 | if self.stepwise_penalty: 88 | self.global_scorer.update_score(self, attn_out) 89 | # force the output to be longer than self.min_length 90 | cur_len = len(self.next_ys) 91 | if cur_len < self.min_length: 92 | for k in range(len(word_probs)): 93 | word_probs[k][self._eos] = -1e20 94 | # Sum the previous scores. 95 | if len(self.prev_ks) > 0: 96 | beam_scores = word_probs + \ 97 | self.scores.unsqueeze(1).expand_as(word_probs) 98 | # Don't let EOS have children. 99 | for i in range(self.next_ys[-1].size(0)): 100 | if self.next_ys[-1][i] == self._eos: 101 | beam_scores[i] = -1e20 102 | 103 | # Block ngram repeats 104 | if self.block_ngram_repeat > 0: 105 | ngrams = [] 106 | le = len(self.next_ys) 107 | for j in range(self.next_ys[-1].size(0)): 108 | hyp, _ = self.get_hyp(le - 1, j) 109 | ngrams = set() 110 | fail = False 111 | gram = [] 112 | for i in range(le - 1): 113 | # Last n tokens, n = block_ngram_repeat 114 | gram = (gram + 115 | [hyp[i].item()])[-self.block_ngram_repeat:] 116 | # Skip the blocking if it is in the exclusion list 117 | if set(gram) & self.exclusion_tokens: 118 | continue 119 | if tuple(gram) in ngrams: 120 | fail = True 121 | ngrams.add(tuple(gram)) 122 | if fail: 123 | beam_scores[j] = -10e20 124 | else: 125 | beam_scores = word_probs[0] 126 | flat_beam_scores = beam_scores.view(-1) 127 | best_scores, best_scores_id = flat_beam_scores.topk(self.size, 0, 128 | True, True) 129 | 130 | self.all_scores.append(self.scores) 131 | self.scores = best_scores 132 | 133 | # best_scores_id is flattened beam x word array, so calculate which 134 | # word and beam each score came from 135 | prev_k = best_scores_id / num_words 136 | self.prev_ks.append(prev_k) 137 | self.next_ys.append((best_scores_id - prev_k * num_words)) 138 | self.attn.append(attn_out.index_select(0, prev_k)) 139 | self.global_scorer.update_global_state(self) 140 | 141 | for i in range(self.next_ys[-1].size(0)): 142 | if self.next_ys[-1][i] == self._eos: 143 | global_scores = self.global_scorer.score(self, self.scores) 144 | s = global_scores[i] 145 | self.finished.append((s, len(self.next_ys) - 1, i)) 146 | 147 | # End condition is when top-of-beam is EOS and no global score. 148 | if self.next_ys[-1][0] == self._eos: 149 | self.all_scores.append(self.scores) 150 | self.eos_top = True 151 | 152 | def done(self): 153 | return self.eos_top and len(self.finished) >= self.n_best 154 | 155 | def sort_finished(self, minimum=None): 156 | if minimum is not None: 157 | i = 0 158 | # Add from beam until we have minimum outputs. 159 | while len(self.finished) < minimum: 160 | global_scores = self.global_scorer.score(self, self.scores) 161 | s = global_scores[i] 162 | self.finished.append((s, len(self.next_ys) - 1, i)) 163 | i += 1 164 | 165 | self.finished.sort(key=lambda a: -a[0]) 166 | scores = [sc for sc, _, _ in self.finished] 167 | ks = [(t, k) for _, t, k in self.finished] 168 | return scores, ks 169 | 170 | def get_hyp(self, timestep, k): 171 | """ 172 | Walk back to construct the full hypothesis. 173 | """ 174 | hyp, attn = [], [] 175 | for j in range(len(self.prev_ks[:timestep]) - 1, -1, -1): 176 | hyp.append(self.next_ys[j + 1][k]) 177 | attn.append(self.attn[j][k]) 178 | k = self.prev_ks[j][k] 179 | return hyp[::-1], torch.stack(attn[::-1]) 180 | 181 | 182 | class GNMTGlobalScorer(object): 183 | """ 184 | NMT re-ranking score from 185 | "Google's Neural Machine Translation System" :cite:`wu2016google` 186 | 187 | Args: 188 | alpha (float): length parameter 189 | beta (float): coverage parameter 190 | """ 191 | 192 | def __init__(self, alpha, length_penalty): 193 | self.alpha = alpha 194 | penalty_builder = penalties.PenaltyBuilder(length_penalty) 195 | # Term will be subtracted from probability 196 | # Probability will be divided by this 197 | self.length_penalty = penalty_builder.length_penalty() 198 | 199 | def score(self, beam, logprobs): 200 | """ 201 | Rescores a prediction based on penalty functions 202 | """ 203 | normalized_probs = self.length_penalty(beam, 204 | logprobs, 205 | self.alpha) 206 | 207 | return normalized_probs 208 | -------------------------------------------------------------------------------- /src/translate/penalties.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | 4 | class PenaltyBuilder(object): 5 | """ 6 | Returns the Length and Coverage Penalty function for Beam Search. 7 | 8 | Args: 9 | length_pen (str): option name of length pen 10 | cov_pen (str): option name of cov pen 11 | """ 12 | 13 | def __init__(self, length_pen): 14 | self.length_pen = length_pen 15 | 16 | def length_penalty(self): 17 | if self.length_pen == "wu": 18 | return self.length_wu 19 | elif self.length_pen == "avg": 20 | return self.length_average 21 | else: 22 | return self.length_none 23 | 24 | """ 25 | Below are all the different penalty terms implemented so far 26 | """ 27 | 28 | def length_wu(self, beam, logprobs, alpha=0.): 29 | """ 30 | NMT length re-ranking score from 31 | "Google's Neural Machine Translation System" :cite:`wu2016google`. 32 | """ 33 | 34 | modifier = (((5 + len(beam.next_ys)) ** alpha) / 35 | ((5 + 1) ** alpha)) 36 | return (logprobs / modifier) 37 | 38 | def length_average(self, beam, logprobs, alpha=0.): 39 | """ 40 | Returns the average probability of tokens in a sequence. 41 | """ 42 | return logprobs / len(beam.next_ys) 43 | 44 | def length_none(self, beam, logprobs, alpha=0., beta=0.): 45 | """ 46 | Returns unmodified scores. 47 | """ 48 | return logprobs 49 | --------------------------------------------------------------------------------