├── README.md ├── config.yaml ├── lr_scheduler.py ├── models ├── __init__.py ├── attention.py ├── beam.py ├── optims.py ├── rnn.py └── seq2seq.py ├── opts.py ├── predict.py ├── preprocess.py ├── train.py └── utils ├── __init__.py ├── data_helper.py ├── dict_helper.py ├── metrics.py └── misc_utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Sequence Generation Model for Multi-label Classification 2 | This is the code for our paper *SGM: Sequence Generation Model for Multi-label Classification* [[pdf]](https://arxiv.org/abs/1806.04822) 3 | 4 | *********************************************************** 5 | 6 | ## Note 7 | In general, this code is more suitable for the following application scenarios: 8 | * **The dataset is relatively large:** 9 | * The performance of the seq2seq model depends on the size of the dataset. 10 | * **There exist some orders or dependencies between labels:** 11 | * A reasonable prior order of labels tends to be helpful. 12 | 13 | *********************************************************** 14 | 15 | ## Requirements 16 | * Ubuntu 16.04 17 | * Python version >= 3.5 18 | * [PyTorch](http://pytorch.org/) version >= 1.0.0 19 | 20 | 21 | *************************************************************** 22 | 23 | ## Dataset 24 | The RCV1-V2 dataset used in our experiments can be downloaded from Google Drive via [this link](https://drive.google.com/drive/folders/1lBt2MjEoh4CG2jjt4nMgHro2h3k2gwUH?usp=sharing). The folder structure on the drive is: 25 | ``` 26 | Google Drive Root # The compressed zip file 27 | |-- data # The unprocessed raw data files 28 | | |-- train.src 29 | | |-- train.tgt 30 | | |-- valid.src 31 | | |-- valid.tgt 32 | | |-- test.src 33 | | |-- test.tgt 34 | | |-- topic_sorted.json # The json file of label set for evaluation 35 | |-- checkpoints # The pre-trained model checkpoints 36 | | |-- sgm.pt 37 | | |-- sgmge.pt 38 | ``` 39 | 40 | We found that the validation set in the previous release was so small that the model tended to overfit it, resulting in unstable performance. We have therefore expanded the validation set. In addition, we filtered out samples containing more than 500 words from the original RCV1-V2 dataset. 41 | 42 | *************************************************************** 43 | 44 | ## Reproducibility 45 | We provide the pretrained checkpoints of the SGM model and the SGM+GE model on the RCV1-V2 dataset to help you reproduce our reported experimental results. The detailed reproduction steps are as follows: 46 | 47 | * Please first download the RCV1-V2 dataset and checkpoints via the [link](https://drive.google.com/drive/folders/1lBt2MjEoh4CG2jjt4nMgHro2h3k2gwUH?usp=sharing), then put them in the same directory as the code. The correct folder structure is: 48 | ``` 49 | Root 50 | |-- data 51 | | |-- ... 52 | |-- checkpoints 53 | | |-- ... 54 | |-- models 55 | | |-- ... 56 | |-- utils 57 | | |-- ... 58 | |-- preprocess.py 59 | |-- train.py 60 | |-- ... 
61 | ``` 62 | 63 | * Preprocess the downloaded data: 64 | ```bash 65 | python3 preprocess.py -load_data ./data/ -save_data ./data/save_data/ -src_vocab_size 50000 66 | ``` 67 | All the preprocessed data will be stored in the folder `./data/save_data/`. 68 | 69 | * Perform prediction and evaluation: 70 | ```bash 71 | python3 predict.py -gpus gpu_id -data ./data/save_data/ -batch_size 64 -restore ./checkpoints/sgm.pt -log results/ 72 | ``` 73 | The predicted labels and evaluation scores will be stored in the folder `results/`. 74 | 75 | *************************************************************** 76 | 77 | ## Training from scratch 78 | 79 | ### Preprocessing 80 | You can preprocess the dataset with the following command: 81 | ``` 82 | python3 preprocess.py \ 83 | -load_data load_data_path \ # input file dir for the data 84 | -save_data save_data_path \ # output file dir for the processed data 85 | -src_vocab_size 50000 # size of the source vocabulary 86 | ``` 87 | Note that all data paths must end with `/`. Descriptions of the other parameters can be found in `preprocess.py`. 88 | 89 | *************************************************************** 90 | 91 | ### Training 92 | You can perform model training with the following command: 93 | ``` 94 | python3 train.py -gpus gpu_id -config model_config -log save_path 95 | ``` 96 | 97 | All log files and checkpoints generated during training will be saved in `save_path`. The detailed parameter descriptions can be found in `train.py`. 98 | 99 | **************************************************************** 100 | 101 | ### Testing 102 | You can perform testing with the following command: 103 | ``` 104 | python3 predict.py -gpus gpu_id -data save_data_path -batch_size batch_size -log log_path 105 | ``` 106 | 107 | The predicted labels and evaluation scores will be stored in the folder `log_path`. 
The detailed parameter descriptions can be found in `predict.py` 108 | 109 | ******************************************************************* 110 | 111 | ## Citation 112 | If you use the above code for your research, please cite the paper: 113 | 114 | ``` 115 | @inproceedings{YangCOLING2018, 116 | author = {Pengcheng Yang and 117 | Xu Sun and 118 | Wei Li and 119 | Shuming Ma and 120 | Wei Wu and 121 | Houfeng Wang}, 122 | title = {{SGM:} Sequence Generation Model for Multi-label Classification}, 123 | booktitle = {Proceedings of the 27th International Conference on Computational 124 | Linguistics, {COLING} 2018, Santa Fe, New Mexico, USA, August 20-26, 125 | 2018}, 126 | pages = {3915--3926}, 127 | year = {2018} 128 | } 129 | ``` 130 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | data: './data/save_data/' 2 | logF: 'experiments/' 3 | epoch: 20 4 | batch_size: 64 5 | optim: 'adam' 6 | cell: 'lstm' 7 | attention: 'luong_gate' 8 | learning_rate: 0.0003 9 | max_grad_norm: 10 10 | learning_rate_decay: 0.5 11 | start_decay_at: 2 12 | emb_size: 512 13 | hidden_size: 512 14 | dec_num_layers: 3 15 | enc_num_layers: 3 16 | bidirectional: True 17 | dropout: 0.1 18 | max_time_step: 30 19 | eval_interval: 100 20 | save_interval: 5000 21 | unk: False 22 | schedule: False 23 | schesamp: False 24 | length_norm: True 25 | metrics: ['micro_f1'] 26 | shared_vocab: False 27 | beam_size: 1 28 | eval_time: 10 29 | mask: True 30 | global_emb: False 31 | tau: 0.1 -------------------------------------------------------------------------------- /lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from bisect import bisect_right 3 | from torch.optim.optimizer import Optimizer 4 | 5 | 6 | class _LRScheduler(object): 7 | def __init__(self, optimizer, last_epoch=-1): 8 | if not isinstance(optimizer, Optimizer): 9 | raise TypeError('{} is not an Optimizer'.format( 10 | type(optimizer).__name__)) 11 | self.optimizer = optimizer 12 | if last_epoch == -1: 13 | for group in optimizer.param_groups: 14 | group.setdefault('initial_lr', group['lr']) 15 | else: 16 | for i, group in enumerate(optimizer.param_groups): 17 | if 'initial_lr' not in group: 18 | raise KeyError("param 'initial_lr' is not specified " 19 | "in param_groups[{}] when resuming an optimizer".format(i)) 20 | self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) 21 | self.step(last_epoch + 1) 22 | self.last_epoch = last_epoch 23 | 24 | def get_lr(self): 25 | raise NotImplementedError 26 | 27 | def step(self, epoch=None): 28 | if epoch is None: 29 | epoch = self.last_epoch + 1 30 | self.last_epoch = epoch 31 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 32 | param_group['lr'] = lr 33 | 34 | 35 | class LambdaLR(_LRScheduler): 36 | """Sets the learning rate of each parameter group to the initial lr 37 | times a given function. When last_epoch=-1, sets initial lr as lr. 38 | Args: 39 | optimizer (Optimizer): Wrapped optimizer. 40 | lr_lambda (function or list): A function which computes a multiplicative 41 | factor given an integer parameter epoch, or a list of such 42 | functions, one for each group in optimizer.param_groups. 43 | last_epoch (int): The index of last epoch. Default: -1. 44 | Example: 45 | >>> # Assuming optimizer has two groups. 
46 | >>> lambda1 = lambda epoch: epoch // 30 47 | >>> lambda2 = lambda epoch: 0.95 ** epoch 48 | >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2]) 49 | >>> for epoch in range(100): 50 | >>> scheduler.step() 51 | >>> train(...) 52 | >>> validate(...) 53 | """ 54 | def __init__(self, optimizer, lr_lambda, last_epoch=-1): 55 | self.optimizer = optimizer 56 | if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple): 57 | self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) 58 | else: 59 | if len(lr_lambda) != len(optimizer.param_groups): 60 | raise ValueError("Expected {} lr_lambdas, but got {}".format( 61 | len(optimizer.param_groups), len(lr_lambda))) 62 | self.lr_lambdas = list(lr_lambda) 63 | self.last_epoch = last_epoch 64 | super(LambdaLR, self).__init__(optimizer, last_epoch) 65 | 66 | def get_lr(self): 67 | return [base_lr * lmbda(self.last_epoch) 68 | for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)] 69 | 70 | 71 | class StepLR(_LRScheduler): 72 | """Sets the learning rate of each parameter group to the initial lr 73 | decayed by gamma every step_size epochs. When last_epoch=-1, sets 74 | initial lr as lr. 75 | Args: 76 | optimizer (Optimizer): Wrapped optimizer. 77 | step_size (int): Period of learning rate decay. 78 | gamma (float): Multiplicative factor of learning rate decay. 79 | Default: 0.1. 80 | last_epoch (int): The index of last epoch. Default: -1. 81 | Example: 82 | >>> # Assuming optimizer uses lr = 0.5 for all groups 83 | >>> # lr = 0.05 if epoch < 30 84 | >>> # lr = 0.005 if 30 <= epoch < 60 85 | >>> # lr = 0.0005 if 60 <= epoch < 90 86 | >>> # ... 87 | >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1) 88 | >>> for epoch in range(100): 89 | >>> scheduler.step() 90 | >>> train(...) 91 | >>> validate(...) 92 | """ 93 | 94 | def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1): 95 | self.step_size = step_size 96 | self.gamma = gamma 97 | super(StepLR, self).__init__(optimizer, last_epoch) 98 | 99 | def get_lr(self): 100 | return [base_lr * self.gamma ** (self.last_epoch // self.step_size) 101 | for base_lr in self.base_lrs] 102 | 103 | 104 | class MultiStepLR(_LRScheduler): 105 | """Set the learning rate of each parameter group to the initial lr decayed 106 | by gamma once the number of epoch reaches one of the milestones. When 107 | last_epoch=-1, sets initial lr as lr. 108 | Args: 109 | optimizer (Optimizer): Wrapped optimizer. 110 | milestones (list): List of epoch indices. Must be increasing. 111 | gamma (float): Multiplicative factor of learning rate decay. 112 | Default: 0.1. 113 | last_epoch (int): The index of last epoch. Default: -1. 114 | Example: 115 | >>> # Assuming optimizer uses lr = 0.5 for all groups 116 | >>> # lr = 0.05 if epoch < 30 117 | >>> # lr = 0.005 if 30 <= epoch < 80 118 | >>> # lr = 0.0005 if epoch >= 80 119 | >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1) 120 | >>> for epoch in range(100): 121 | >>> scheduler.step() 122 | >>> train(...) 123 | >>> validate(...) 124 | """ 125 | 126 | def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1): 127 | if not list(milestones) == sorted(milestones): 128 | raise ValueError('Milestones should be a list of' 129 | ' increasing integers. 
Got {}', milestones) 130 | self.milestones = milestones 131 | self.gamma = gamma 132 | super(MultiStepLR, self).__init__(optimizer, last_epoch) 133 | 134 | def get_lr(self): 135 | return [base_lr * self.gamma ** bisect_right(self.milestones, self.last_epoch) 136 | for base_lr in self.base_lrs] 137 | 138 | 139 | class ExponentialLR(_LRScheduler): 140 | """Set the learning rate of each parameter group to the initial lr decayed 141 | by gamma every epoch. When last_epoch=-1, sets initial lr as lr. 142 | Args: 143 | optimizer (Optimizer): Wrapped optimizer. 144 | gamma (float): Multiplicative factor of learning rate decay. 145 | last_epoch (int): The index of last epoch. Default: -1. 146 | """ 147 | 148 | def __init__(self, optimizer, gamma, last_epoch=-1): 149 | self.gamma = gamma 150 | super(ExponentialLR, self).__init__(optimizer, last_epoch) 151 | 152 | def get_lr(self): 153 | return [base_lr * self.gamma ** self.last_epoch 154 | for base_lr in self.base_lrs] 155 | 156 | 157 | class CosineAnnealingLR(_LRScheduler): 158 | """Set the learning rate of each parameter group using a cosine annealing 159 | schedule, where :math:`\eta_{max}` is set to the initial lr and 160 | :math:`T_{cur}` is the number of epochs since the last restart in SGDR: 161 | .. math:: 162 | \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 + 163 | \cos(\frac{T_{cur}}{T_{max}}\pi)) 164 | When last_epoch=-1, sets initial lr as lr. 165 | It has been proposed in 166 | `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only 167 | implements the cosine annealing part of SGDR, and not the restarts. 168 | Args: 169 | optimizer (Optimizer): Wrapped optimizer. 170 | T_max (int): Maximum number of iterations. 171 | eta_min (float): Minimum learning rate. Default: 0. 172 | last_epoch (int): The index of last epoch. Default: -1. 173 | .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: 174 | https://arxiv.org/abs/1608.03983 175 | """ 176 | 177 | def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1): 178 | self.T_max = T_max 179 | self.eta_min = eta_min 180 | super(CosineAnnealingLR, self).__init__(optimizer, last_epoch) 181 | 182 | def get_lr(self): 183 | return [self.eta_min + (base_lr - self.eta_min) * 184 | (1 + math.cos(self.last_epoch / self.T_max * math.pi)) / 2 185 | for base_lr in self.base_lrs] 186 | 187 | 188 | class ReduceLROnPlateau(object): 189 | """Reduce learning rate when a metric has stopped improving. 190 | Models often benefit from reducing the learning rate by a factor 191 | of 2-10 once learning stagnates. This scheduler reads a metrics 192 | quantity and if no improvement is seen for a 'patience' number 193 | of epochs, the learning rate is reduced. 194 | Args: 195 | optimizer (Optimizer): Wrapped optimizer. 196 | mode (str): One of `min`, `max`. In `min` mode, lr will 197 | be reduced when the quantity monitored has stopped 198 | decreasing; in `max` mode it will be reduced when the 199 | quantity monitored has stopped increasing. Default: 'min'. 200 | factor (float): Factor by which the learning rate will be 201 | reduced. new_lr = lr * factor. Default: 0.1. 202 | patience (int): Number of epochs with no improvement after 203 | which learning rate will be reduced. Default: 10. 204 | verbose (bool): If True, prints a message to stdout for 205 | each update. Default: False. 206 | threshold (float): Threshold for measuring the new optimum, 207 | to only focus on significant changes. Default: 1e-4. 208 | threshold_mode (str): One of `rel`, `abs`. 
In `rel` mode, 209 | dynamic_threshold = best * ( 1 + threshold ) in 'max' 210 | mode or best * ( 1 - threshold ) in `min` mode. 211 | In `abs` mode, dynamic_threshold = best + threshold in 212 | `max` mode or best - threshold in `min` mode. Default: 'rel'. 213 | cooldown (int): Number of epochs to wait before resuming 214 | normal operation after lr has been reduced. Default: 0. 215 | min_lr (float or list): A scalar or a list of scalars. A 216 | lower bound on the learning rate of all param groups 217 | or each group respectively. Default: 0. 218 | eps (float): Minimal decay applied to lr. If the difference 219 | between new and old lr is smaller than eps, the update is 220 | ignored. Default: 1e-8. 221 | Example: 222 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) 223 | >>> scheduler = ReduceLROnPlateau(optimizer, 'min') 224 | >>> for epoch in range(10): 225 | >>> train(...) 226 | >>> val_loss = validate(...) 227 | >>> # Note that step should be called after validate() 228 | >>> scheduler.step(val_loss) 229 | """ 230 | 231 | def __init__(self, optimizer, mode='min', factor=0.1, patience=10, 232 | verbose=False, threshold=1e-4, threshold_mode='rel', 233 | cooldown=0, min_lr=0, eps=1e-8): 234 | 235 | if factor >= 1.0: 236 | raise ValueError('Factor should be < 1.0.') 237 | self.factor = factor 238 | 239 | if not isinstance(optimizer, Optimizer): 240 | raise TypeError('{} is not an Optimizer'.format( 241 | type(optimizer).__name__)) 242 | self.optimizer = optimizer 243 | 244 | if isinstance(min_lr, list) or isinstance(min_lr, tuple): 245 | if len(min_lr) != len(optimizer.param_groups): 246 | raise ValueError("expected {} min_lrs, got {}".format( 247 | len(optimizer.param_groups), len(min_lr))) 248 | self.min_lrs = list(min_lr) 249 | else: 250 | self.min_lrs = [min_lr] * len(optimizer.param_groups) 251 | 252 | self.patience = patience 253 | self.verbose = verbose 254 | self.cooldown = cooldown 255 | self.cooldown_counter = 0 256 | self.mode = mode 257 | self.threshold = threshold 258 | self.threshold_mode = threshold_mode 259 | self.best = None 260 | self.num_bad_epochs = None 261 | self.mode_worse = None # the worse value for the chosen mode 262 | self.is_better = None 263 | self.eps = eps 264 | self.last_epoch = -1 265 | self._init_is_better(mode=mode, threshold=threshold, 266 | threshold_mode=threshold_mode) 267 | self._reset() 268 | 269 | def _reset(self): 270 | """Resets num_bad_epochs counter and cooldown counter.""" 271 | self.best = self.mode_worse 272 | self.cooldown_counter = 0 273 | self.num_bad_epochs = 0 274 | 275 | def step(self, metrics, epoch=None): 276 | current = metrics 277 | if epoch is None: 278 | epoch = self.last_epoch = self.last_epoch + 1 279 | self.last_epoch = epoch 280 | 281 | if self.is_better(current, self.best): 282 | self.best = current 283 | self.num_bad_epochs = 0 284 | else: 285 | self.num_bad_epochs += 1 286 | 287 | if self.in_cooldown: 288 | self.cooldown_counter -= 1 289 | self.num_bad_epochs = 0 # ignore any bad epochs in cooldown 290 | 291 | if self.num_bad_epochs > self.patience: 292 | self._reduce_lr(epoch) 293 | self.cooldown_counter = self.cooldown 294 | self.num_bad_epochs = 0 295 | 296 | def _reduce_lr(self, epoch): 297 | for i, param_group in enumerate(self.optimizer.param_groups): 298 | old_lr = float(param_group['lr']) 299 | new_lr = max(old_lr * self.factor, self.min_lrs[i]) 300 | if old_lr - new_lr > self.eps: 301 | param_group['lr'] = new_lr 302 | if self.verbose: 303 | print('Epoch {:5d}: reducing learning 
rate' 304 | ' of group {} to {:.4e}.'.format(epoch, i, new_lr)) 305 | 306 | @property 307 | def in_cooldown(self): 308 | return self.cooldown_counter > 0 309 | 310 | def _init_is_better(self, mode, threshold, threshold_mode): 311 | if mode not in {'min', 'max'}: 312 | raise ValueError('mode ' + mode + ' is unknown!') 313 | if threshold_mode not in {'rel', 'abs'}: 314 | raise ValueError('threshold mode ' + mode + ' is unknown!') 315 | if mode == 'min' and threshold_mode == 'rel': 316 | rel_epsilon = 1. - threshold 317 | self.is_better = lambda a, best: a < best * rel_epsilon 318 | self.mode_worse = float('Inf') 319 | elif mode == 'min' and threshold_mode == 'abs': 320 | self.is_better = lambda a, best: a < best - threshold 321 | self.mode_worse = float('Inf') 322 | elif mode == 'max' and threshold_mode == 'rel': 323 | rel_epsilon = threshold + 1. 324 | self.is_better = lambda a, best: a > best * rel_epsilon 325 | self.mode_worse = -float('Inf') 326 | else: # mode == 'max' and epsilon_mode == 'abs': 327 | self.is_better = lambda a, best: a > best + threshold 328 | self.mode_worse = -float('Inf') 329 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .attention import * 2 | from .optims import * 3 | from .rnn import * 4 | from .seq2seq import * 5 | from .beam import * -------------------------------------------------------------------------------- /models/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | 5 | 6 | class luong_attention(nn.Module): 7 | 8 | def __init__(self, hidden_size, emb_size, pool_size=0): 9 | super(luong_attention, self).__init__() 10 | self.hidden_size, self.emb_size, self.pool_size = hidden_size, emb_size, pool_size 11 | self.linear_in = nn.Linear(hidden_size, hidden_size) 12 | if pool_size > 0: 13 | self.linear_out = maxout(2*hidden_size + emb_size, hidden_size, pool_size) 14 | else: 15 | self.linear_out = nn.Sequential(nn.Linear(2*hidden_size + emb_size, hidden_size), nn.SELU(), 16 | nn.Linear(hidden_size, hidden_size), nn.Tanh()) 17 | self.softmax = nn.Softmax(dim=1) 18 | 19 | def init_context(self, context): 20 | self.context = context.transpose(0, 1) 21 | 22 | def forward(self, h, x): 23 | gamma_h = self.linear_in(h).unsqueeze(2) # batch * size * 1 24 | weights = torch.bmm(self.context, gamma_h).squeeze(2) # batch * time 25 | weights = self.softmax(weights) # batch * time 26 | c_t = torch.bmm(weights.unsqueeze(1), self.context).squeeze(1) # batch * size 27 | output = self.linear_out(torch.cat([c_t, h, x], 1)) 28 | 29 | return output, weights 30 | 31 | 32 | class luong_gate_attention(nn.Module): 33 | 34 | def __init__(self, hidden_size, emb_size, prob=0.1): 35 | super(luong_gate_attention, self).__init__() 36 | self.linear_in = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.SELU(), nn.Dropout(p=prob), 37 | nn.Linear(hidden_size, hidden_size), nn.SELU(), nn.Dropout(p=prob)) 38 | self.linear_out = nn.Sequential(nn.Linear(2*hidden_size, hidden_size), nn.SELU(), nn.Dropout(p=prob), 39 | nn.Linear(hidden_size, hidden_size), nn.SELU(), nn.Dropout(p=prob)) 40 | self.softmax = nn.Softmax(dim=-1) 41 | self.dropout = nn.Dropout(p=prob) 42 | 43 | def init_context(self, context): 44 | self.context = context.transpose(0, 1) 45 | 46 | def forward(self, h): 47 | gamma_h = self.linear_in(h).unsqueeze(2) 48 | 
weights = self.dropout(torch.bmm(self.context, gamma_h).squeeze(2)) 49 | weights = self.softmax(weights) 50 | c_t = torch.bmm(weights.unsqueeze(1), self.context).squeeze(1) 51 | output = self.linear_out(torch.cat([h, c_t], 1)) 52 | 53 | return output, weights 54 | 55 | 56 | class bahdanau_attention(nn.Module): 57 | 58 | def __init__(self, hidden_size, emb_size): 59 | super(bahdanau_attention, self).__init__() 60 | self.linear_encoder = nn.Linear(hidden_size, hidden_size) 61 | self.linear_decoder = nn.Linear(hidden_size, hidden_size) 62 | self.linear_v = nn.Linear(hidden_size, 1) 63 | self.linear_r = nn.Linear(hidden_size*2+emb_size, hidden_size*2) 64 | self.hidden_size = hidden_size 65 | self.emb_size = emb_size 66 | self.softmax = nn.Softmax(dim=1) 67 | self.tanh = nn.Tanh() 68 | 69 | def init_context(self, context): 70 | self.context = context.transpose(0, 1) 71 | 72 | def forward(self, h, x): 73 | gamma_encoder = self.linear_encoder(self.context) # batch * time * size 74 | gamma_decoder = self.linear_decoder(h).unsqueeze(1) # batch * 1 * size 75 | weights = self.linear_v(self.tanh(gamma_encoder+gamma_decoder)).squeeze(2) # batch * time 76 | weights = self.softmax(weights) # batch * time 77 | c_t = torch.bmm(weights.unsqueeze(1), self.context).squeeze(1) # batch * size 78 | r_t = self.linear_r(torch.cat([c_t, h, x], dim=1)) 79 | output = r_t.view(-1, self.hidden_size, 2).max(2)[0] 80 | 81 | return output, weights 82 | 83 | 84 | class maxout(nn.Module): 85 | 86 | def __init__(self, in_feature, out_feature, pool_size): 87 | super(maxout, self).__init__() 88 | self.in_feature = in_feature 89 | self.out_feature = out_feature 90 | self.pool_size = pool_size 91 | self.linear = nn.Linear(in_feature, out_feature*pool_size) 92 | 93 | def forward(self, x): 94 | output = self.linear(x) 95 | output = output.view(-1, self.out_feature, self.pool_size) 96 | output = output.max(2)[0] 97 | 98 | return output 99 | -------------------------------------------------------------------------------- /models/beam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import utils 3 | 4 | class Beam(object): 5 | 6 | def __init__(self, size, n_best=1, cuda=True, length_norm=False, minimum_length=0): 7 | 8 | self.size = size 9 | self.tt = torch.cuda if cuda else torch 10 | 11 | # The score for each translation on the beam. 12 | self.scores = self.tt.FloatTensor(size).zero_() 13 | self.allScores = [] 14 | 15 | # The backpointers at each time-step. 16 | self.prevKs = [] 17 | 18 | # The outputs at each time-step. 19 | self.nextYs = [self.tt.LongTensor(size) 20 | .fill_(utils.EOS)] 21 | self.nextYs[0][0] = utils.BOS 22 | 23 | # Has EOS topped the beam yet. 24 | self._eos = utils.EOS 25 | self.eosTop = False 26 | 27 | # The attentions (matrix) for each time. 28 | self.attn = [] 29 | 30 | # Time and k pair for finished. 31 | self.finished = [] 32 | self.n_best = n_best 33 | 34 | self.length_norm = length_norm 35 | self.minimum_length = minimum_length 36 | 37 | def getCurrentState(self): 38 | "Get the outputs for the current timestep." 39 | return self.nextYs[-1] 40 | 41 | def getCurrentOrigin(self): 42 | "Get the backpointers for the current timestep." 43 | return self.prevKs[-1] 44 | 45 | def advance(self, wordLk, attnOut): 46 | """ 47 | Given prob over words for every last beam `wordLk` and attention 48 | `attnOut`: Compute and update the beam search. 
49 | Parameters: 50 | * `wordLk`- probs of advancing from the last step (K x words) 51 | * `attnOut`- attention at the last step 52 | Returns: True if beam search is complete. 53 | """ 54 | numWords = wordLk.size(1) 55 | 56 | # Sum the previous scores. 57 | if len(self.prevKs) > 0: 58 | beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk) 59 | 60 | # Don't let EOS have children. 61 | for i in range(self.nextYs[-1].size(0)): 62 | if self.nextYs[-1][i] == self._eos: 63 | beamLk[i] = -1e20 64 | ngrams = [] 65 | le = len(self.nextYs) 66 | for j in range(self.nextYs[-1].size(0)): 67 | hyp, _ = self.getHyp(le-1, j) 68 | ngrams = set() 69 | fail = False 70 | gram = [] 71 | for i in range(le-1): 72 | # last n tokens, n = block_ngram_repeat 73 | gram = (gram + [hyp[i]])[-3:] 74 | # skip the blocking if it is in the exclusion list 75 | #if set(gram) & self.exclusion_tokens: 76 | #continue 77 | if tuple(gram) in ngrams: 78 | fail = True 79 | ngrams.add(tuple(gram)) 80 | if fail: 81 | beamLk[j] = -1e20 82 | else: 83 | beamLk = wordLk[0] 84 | flatBeamLk = beamLk.view(-1) 85 | bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True) 86 | 87 | self.allScores.append(self.scores) 88 | self.scores = bestScores 89 | 90 | # bestScoresId is flattened beam x word array, so calculate which 91 | # word and beam each score came from 92 | prevK = bestScoresId / numWords 93 | self.prevKs.append(prevK) 94 | self.nextYs.append((bestScoresId - prevK * numWords)) 95 | self.attn.append(attnOut.index_select(0, prevK)) 96 | 97 | for i in range(self.nextYs[-1].size(0)): 98 | if self.nextYs[-1][i] == self._eos: 99 | s = self.scores[i] 100 | if self.length_norm: 101 | s /= len(self.nextYs) 102 | if len(self.nextYs) - 1 >= self.minimum_length: 103 | self.finished.append((s, len(self.nextYs) - 1, i)) 104 | 105 | # End condition is when top-of-beam is EOS and no global score. 106 | if self.nextYs[-1][0] == utils.EOS: 107 | self.allScores.append(self.scores) 108 | self.eosTop = True 109 | 110 | def done(self): 111 | return self.eosTop and len(self.finished) >= self.n_best 112 | 113 | def beam_update(self, state, idx): 114 | positions = self.getCurrentOrigin() 115 | for e in state: 116 | a, br, d = e.size() 117 | e = e.view(a, self.size, br // self.size, d) 118 | sentStates = e[:, :, idx] 119 | sentStates.copy_(sentStates.index_select(1, positions)) 120 | 121 | def beam_update_gru(self, state, idx): 122 | positions = self.getCurrentOrigin() 123 | for e in state: 124 | br, d = e.size() 125 | e = e.view(self.size, br // self.size, d) 126 | sentStates = e[:, idx] 127 | sentStates.copy_(sentStates.index_select(0, positions)) 128 | 129 | def beam_update_memory(self, state, idx): 130 | positions = self.getCurrentOrigin() 131 | e = state 132 | br, d = e.size() 133 | e = e.view(self.size, br // self.size, d) 134 | sentStates = e[:, idx] 135 | sentStates.copy_(sentStates.index_select(0, positions)) 136 | 137 | def sortFinished(self, minimum=None): 138 | if minimum is not None: 139 | i = 0 140 | # Add from beam until we have minimum outputs. 141 | while len(self.finished) < minimum: 142 | s = self.scores[i].item() 143 | self.finished.append((s, len(self.nextYs) - 1, i)) 144 | i += 1 145 | 146 | self.finished.sort(key=lambda a: -a[0]) 147 | scores = [sc for sc, _, _ in self.finished] 148 | ks = [(t, k) for _, t, k in self.finished] 149 | return scores, ks 150 | 151 | def getHyp(self, timestep, k): 152 | """ 153 | Walk back to construct the full hypothesis. 
154 | """ 155 | hyp, attn = [], [] 156 | for j in range(len(self.prevKs[:timestep]) - 1, -1, -1): 157 | hyp.append(self.nextYs[j+1][k].item()) 158 | attn.append(self.attn[j][k]) 159 | k = self.prevKs[j][k].item() 160 | return hyp[::-1], torch.stack(attn[::-1]) 161 | -------------------------------------------------------------------------------- /models/optims.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | from torch.nn.utils import clip_grad_norm_ 3 | 4 | 5 | class Optim(object): 6 | 7 | def set_parameters(self, params): 8 | self.params = list(params) # careful: params may be a generator 9 | if self.method == 'sgd': 10 | self.optimizer = optim.SGD(self.params, lr=self.lr) 11 | elif self.method == 'adagrad': 12 | self.optimizer = optim.Adagrad(self.params, lr=self.lr) 13 | elif self.method == 'adadelta': 14 | self.optimizer = optim.Adadelta(self.params, lr=self.lr) 15 | elif self.method == 'adam': 16 | self.optimizer = optim.Adam(self.params, lr=self.lr) 17 | else: 18 | raise RuntimeError("Invalid optim method: " + self.method) 19 | 20 | def __init__(self, method, lr, max_grad_norm, lr_decay=1, start_decay_at=None, max_decay_times=2): 21 | self.last_score = None 22 | self.decay_times = 0 23 | self.max_decay_times = max_decay_times 24 | self.lr = lr 25 | self.max_grad_norm = max_grad_norm 26 | self.method = method 27 | self.lr_decay = lr_decay 28 | self.start_decay_at = start_decay_at 29 | self.start_decay = False 30 | 31 | def step(self): 32 | # Compute gradients norm. 33 | if self.max_grad_norm: 34 | clip_grad_norm_(self.params, self.max_grad_norm) 35 | self.optimizer.step() 36 | 37 | # decay learning rate if val perf does not improve or we hit the start_decay_at limit 38 | def updateLearningRate(self, score, epoch): 39 | if self.start_decay_at is not None and epoch >= self.start_decay_at: 40 | self.start_decay = True 41 | 42 | if self.start_decay: 43 | self.lr = self.lr * self.lr_decay 44 | print("Decaying learning rate to %g" % self.lr) 45 | 46 | self.last_score = score 47 | self.optimizer.param_groups[0]['lr'] = self.lr 48 | -------------------------------------------------------------------------------- /models/rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | from torch.nn.utils.rnn import pack_padded_sequence as pack 5 | from torch.nn.utils.rnn import pad_packed_sequence as unpack 6 | import models 7 | 8 | 9 | class rnn_encoder(nn.Module): 10 | 11 | def __init__(self, config, embedding=None): 12 | super(rnn_encoder, self).__init__() 13 | 14 | self.config = config 15 | self.hidden_size = config.hidden_size 16 | self.embedding = embedding if embedding is not None else nn.Embedding(config.src_vocab_size, config.emb_size) 17 | if config.cell == 'gru': 18 | self.rnn = nn.GRU(input_size=config.emb_size, hidden_size=config.hidden_size, 19 | num_layers=config.enc_num_layers, dropout=config.dropout, 20 | bidirectional=config.bidirectional) 21 | else: 22 | self.rnn = nn.LSTM(input_size=config.emb_size, hidden_size=config.hidden_size, 23 | num_layers=config.enc_num_layers, dropout=config.dropout, 24 | bidirectional=config.bidirectional) 25 | 26 | def forward(self, inputs, lengths): 27 | embs = pack(self.embedding(inputs), lengths) 28 | outputs, state = self.rnn(embs) 29 | outputs = unpack(outputs)[0] 30 | 31 | if self.config.bidirectional: 32 | # outputs: [max_src_len, batch_size, hidden_size] 33 | 
outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:] 34 | if self.config.cell == 'gru': 35 | state = state[:self.config.dec_num_layers] 36 | else: 37 | state = (state[0][::2], state[1][::2]) 38 | 39 | return outputs, state 40 | 41 | 42 | class rnn_decoder(nn.Module): 43 | 44 | def __init__(self, config, embedding=None, use_attention=True): 45 | super(rnn_decoder, self).__init__() 46 | 47 | self.config = config 48 | self.hidden_size = config.hidden_size 49 | self.embedding = embedding if embedding is not None else nn.Embedding(config.tgt_vocab_size, config.emb_size) 50 | input_size = 2 * config.emb_size if config.global_emb else config.emb_size 51 | 52 | if config.cell == 'gru': 53 | self.rnn = StackedGRU(input_size=input_size, hidden_size=config.hidden_size, 54 | num_layers=config.dec_num_layers, dropout=config.dropout) 55 | else: 56 | self.rnn = StackedLSTM(input_size=input_size, hidden_size=config.hidden_size, 57 | num_layers=config.dec_num_layers, dropout=config.dropout) 58 | 59 | self.linear = nn.Linear(config.hidden_size, config.tgt_vocab_size) 60 | 61 | if not use_attention or config.attention == 'None': 62 | self.attention = None 63 | elif config.attention == 'bahdanau': 64 | self.attention = models.bahdanau_attention(config.hidden_size, input_size) 65 | elif config.attention == 'luong': 66 | self.attention = models.luong_attention(config.hidden_size, input_size, config.pool_size) 67 | elif config.attention == 'luong_gate': 68 | self.attention = models.luong_gate_attention(config.hidden_size, input_size) 69 | 70 | self.dropout = nn.Dropout(config.dropout) 71 | 72 | if config.global_emb: 73 | self.ge_proj1 = nn.Linear(config.emb_size, config.emb_size) 74 | self.ge_proj2 = nn.Linear(config.emb_size, config.emb_size) 75 | self.softmax = nn.Softmax(dim=1) 76 | 77 | def forward(self, input, state, output=None, mask=None): 78 | embs = self.embedding(input) 79 | 80 | if self.config.global_emb: 81 | if output is None: 82 | output = embs.new_zeros(embs.size(0), self.config.tgt_vocab_size) 83 | probs = self.softmax(output / self.config.tau) 84 | emb_avg = torch.matmul(probs, self.embedding.weight) 85 | H = torch.sigmoid(self.ge_proj1(embs) + self.ge_proj2(emb_avg)) 86 | emb_glb = H * embs + (1 - H) * emb_avg 87 | embs = torch.cat((embs, emb_glb), dim=-1) 88 | 89 | output, state = self.rnn(embs, state) 90 | if self.attention is not None: 91 | if self.config.attention == 'luong_gate': 92 | output, attn_weights = self.attention(output) 93 | else: 94 | output, attn_weights = self.attention(output, embs) 95 | else: 96 | attn_weights = None 97 | output = self.compute_score(output) 98 | 99 | if self.config.mask and mask: 100 | mask = torch.stack(mask, dim=1).long() 101 | output.scatter_(dim=1, index=mask, value=-1e7) 102 | 103 | return output, state, attn_weights 104 | 105 | def compute_score(self, hiddens): 106 | scores = self.linear(hiddens) 107 | return scores 108 | 109 | 110 | class StackedLSTM(nn.Module): 111 | def __init__(self, num_layers, input_size, hidden_size, dropout): 112 | super(StackedLSTM, self).__init__() 113 | self.dropout = nn.Dropout(dropout) 114 | self.num_layers = num_layers 115 | self.layers = nn.ModuleList() 116 | 117 | for _ in range(num_layers): 118 | lstm = nn.LSTMCell(input_size, hidden_size) 119 | self.layers.append(lstm) 120 | input_size = hidden_size 121 | 122 | def forward(self, input, hidden): 123 | h_0, c_0 = hidden 124 | h_1, c_1 = [], [] 125 | for i, layer in enumerate(self.layers): 126 | h_1_i, c_1_i = layer(input, (h_0[i], c_0[i])) 127 | 
input = h_1_i 128 | if i + 1 != self.num_layers: 129 | input = self.dropout(input) 130 | h_1 += [h_1_i] 131 | c_1 += [c_1_i] 132 | 133 | h_1 = torch.stack(h_1) 134 | c_1 = torch.stack(c_1) 135 | 136 | return input, (h_1, c_1) 137 | 138 | 139 | class StackedGRU(nn.Module): 140 | def __init__(self, num_layers, input_size, hidden_size, dropout): 141 | super(StackedGRU, self).__init__() 142 | self.dropout = nn.Dropout(dropout) 143 | self.num_layers = num_layers 144 | self.layers = nn.ModuleList() 145 | 146 | for _ in range(num_layers): 147 | self.layers.append(nn.GRUCell(input_size, hidden_size)) 148 | input_size = hidden_size 149 | 150 | def forward(self, input, hidden): 151 | h_0 = hidden 152 | h_1 = [] 153 | for i, layer in enumerate(self.layers): 154 | h_1_i = layer(input, h_0[i]) 155 | input = h_1_i 156 | if i + 1 != self.num_layers: 157 | input = self.dropout(input) 158 | h_1 += [h_1_i] 159 | 160 | h_1 = torch.stack(h_1) 161 | 162 | return input, h_1 163 | -------------------------------------------------------------------------------- /models/seq2seq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import utils 4 | import models 5 | 6 | 7 | class seq2seq(nn.Module): 8 | 9 | def __init__(self, config, use_attention=True, encoder=None, decoder=None): 10 | super(seq2seq, self).__init__() 11 | 12 | if encoder is not None: 13 | self.encoder = encoder 14 | else: 15 | self.encoder = models.rnn_encoder(config) 16 | tgt_embedding = self.encoder.embedding if config.shared_vocab else None 17 | if decoder is not None: 18 | self.decoder = decoder 19 | else: 20 | self.decoder = models.rnn_decoder(config, embedding=tgt_embedding, use_attention=use_attention) 21 | self.log_softmax = nn.LogSoftmax(dim=-1) 22 | self.use_cuda = config.use_cuda 23 | self.config = config 24 | self.criterion = nn.CrossEntropyLoss(ignore_index=utils.PAD, reduction='none') 25 | if config.use_cuda: 26 | self.criterion.cuda() 27 | 28 | def compute_loss(self, scores, targets): 29 | scores = scores.view(-1, scores.size(2)) 30 | loss = self.criterion(scores, targets.contiguous().view(-1)) 31 | return loss 32 | 33 | def forward(self, src, src_len, dec, targets): 34 | """ 35 | Args: 36 | src: [bs, src_len] 37 | src_len: [bs] 38 | dec: [bs, tgt_len] (bos, x1, ..., xn) 39 | targets: [bs, tgt_len] (x1, ..., xn, eos) 40 | """ 41 | 42 | src = src.t() 43 | dec = dec.t() 44 | targets = targets.t() 45 | 46 | contexts, state = self.encoder(src, src_len.tolist()) 47 | 48 | if self.decoder.attention is not None: 49 | self.decoder.attention.init_context(context=contexts) 50 | 51 | outputs = [] 52 | output = None 53 | 54 | for input in dec.split(1): 55 | output, state, _ = self.decoder(input.squeeze(0), state, output) 56 | outputs.append(output) 57 | outputs = torch.stack(outputs) 58 | 59 | loss = self.compute_loss(outputs, targets) 60 | return loss, outputs 61 | 62 | def sample(self, src, src_len): 63 | 64 | lengths, indices = torch.sort(src_len, dim=0, descending=True) 65 | _, reverse_indices = torch.sort(indices) 66 | src = torch.index_select(src, dim=0, index=indices) 67 | bos = torch.ones(src.size(0)).long().fill_(utils.BOS) 68 | src = src.t() 69 | 70 | if self.use_cuda: 71 | bos = bos.cuda() 72 | 73 | contexts, state = self.encoder(src, lengths.tolist()) 74 | 75 | if self.decoder.attention is not None: 76 | self.decoder.attention.init_context(context=contexts) 77 | 78 | inputs, outputs, attn_matrix = [bos], [], [] 79 | output = None 80 | 81 | for i in 
range(self.config.max_time_step): 82 | output, state, attn_weights = self.decoder(inputs[i], state, output, outputs) 83 | predicted = output.max(1)[1] 84 | inputs += [predicted] 85 | outputs += [predicted] 86 | attn_matrix += [attn_weights] 87 | 88 | outputs = torch.stack(outputs) 89 | sample_ids = torch.index_select(outputs, dim=1, index=reverse_indices).t() 90 | 91 | if self.decoder.attention is not None: 92 | attn_matrix = torch.stack(attn_matrix) 93 | alignments = attn_matrix.max(2)[1] 94 | alignments = torch.index_select(alignments, dim=1, index=reverse_indices).t() 95 | else: 96 | alignments = None 97 | 98 | return sample_ids, alignments 99 | 100 | def beam_sample(self, src, src_len, beam_size=1, eval_=False): 101 | 102 | # (1) Run the encoder on the src. 103 | 104 | lengths, indices = torch.sort(src_len, dim=0, descending=True) 105 | _, ind = torch.sort(indices) 106 | src = torch.index_select(src, dim=0, index=indices) 107 | src = src.t() 108 | batch_size = src.size(1) 109 | contexts, encState = self.encoder(src, lengths.tolist()) 110 | 111 | # (1b) Initialize for the decoder. 112 | def var(a): 113 | return torch.tensor(a, requires_grad=False) 114 | 115 | def rvar(a): 116 | return var(a.repeat(1, beam_size, 1)) 117 | 118 | def bottle(m): 119 | return m.view(batch_size * beam_size, -1) 120 | 121 | def unbottle(m): 122 | return m.view(beam_size, batch_size, -1) 123 | 124 | # Repeat everything beam_size times. 125 | contexts = rvar(contexts) 126 | 127 | if self.config.cell == 'lstm': 128 | decState = (rvar(encState[0]), rvar(encState[1])) 129 | else: 130 | decState = rvar(encState) 131 | 132 | beam = [models.Beam(beam_size, n_best=1, 133 | cuda=self.use_cuda, length_norm=self.config.length_norm) 134 | for __ in range(batch_size)] 135 | if self.decoder.attention is not None: 136 | self.decoder.attention.init_context(contexts) 137 | 138 | # (2) run the decoder to generate sentences, using beam search. 139 | 140 | for i in range(self.config.max_time_step): 141 | 142 | if all((b.done() for b in beam)): 143 | break 144 | 145 | # Construct batch x beam_size nxt words. 146 | # Get all the pending current beam words and arrange for forward. 147 | inp = var(torch.stack([b.getCurrentState() for b in beam]) 148 | .t().contiguous().view(-1)) 149 | 150 | # Run one step. 151 | output, decState, attn = self.decoder(inp, decState) 152 | # decOut: beam x rnn_size 153 | 154 | # (b) Compute a vector of batch*beam word scores. 155 | output = unbottle(self.log_softmax(output)) 156 | attn = unbottle(attn) 157 | # beam x tgt_vocab 158 | 159 | # (c) Advance each beam. 160 | # update state 161 | for j, b in enumerate(beam): 162 | b.advance(output[:, j], attn[:, j]) 163 | if self.config.cell == 'lstm': 164 | b.beam_update(decState, j) 165 | else: 166 | b.beam_update_gru(decState, j) 167 | 168 | # (3) Package everything up. 
169 | allHyps, allScores, allAttn = [], [], [] 170 | if eval_: 171 | allWeight = [] 172 | 173 | for j in ind: 174 | b = beam[j] 175 | n_best = 1 176 | scores, ks = b.sortFinished(minimum=n_best) 177 | hyps, attn = [], [] 178 | if eval_: 179 | weight = [] 180 | for i, (times, k) in enumerate(ks[:n_best]): 181 | hyp, att = b.getHyp(times, k) 182 | hyps.append(hyp) 183 | attn.append(att.max(1)[1]) 184 | if eval_: 185 | weight.append(att) 186 | allHyps.append(hyps[0]) 187 | allScores.append(scores[0]) 188 | allAttn.append(attn[0]) 189 | if eval_: 190 | allWeight.append(weight[0]) 191 | 192 | if eval_: 193 | return allHyps, allAttn, allWeight 194 | 195 | return allHyps, allAttn -------------------------------------------------------------------------------- /opts.py: -------------------------------------------------------------------------------- 1 | def model_opts(parser): 2 | 3 | parser.add_argument('-config', default='config.yaml', type=str, 4 | help="config file") 5 | parser.add_argument('-gpus', default=[], nargs='+', type=int, 6 | help="use CUDA on the listed devices.") 7 | parser.add_argument('-restore', default='./checkpoints/sgmge.pt', type=str, 8 | help="restore checkpoint") 9 | parser.add_argument('-seed', default=1234, type=int, 10 | help="random seed") 11 | parser.add_argument('-model', default='seq2seq', type=str, 12 | help="model selection") 13 | parser.add_argument('-mode', default='train', type=str, 14 | help="mode selection") 15 | parser.add_argument('-module', default='seq2seq', type=str, 16 | help="module selection") 17 | parser.add_argument('-log', default='', type=str, 18 | help="log directory") 19 | parser.add_argument('-num_processes', type=int, default=4, 20 | help="number of processes") 21 | parser.add_argument('-refF', default='', type=str, 22 | help="reference file") 23 | parser.add_argument('-unk', action='store_true', 24 | help='replace unk') 25 | parser.add_argument('-char', action='store_true', 26 | help='char level decoding') 27 | parser.add_argument('-length_norm', action='store_true', 28 | help='replace unk') 29 | parser.add_argument('-pool_size', type=int, default=0, 30 | help="pool size of maxout layer") 31 | parser.add_argument('-scale', type=float, default=1, 32 | help="proportion of the training set") 33 | parser.add_argument('-max_split', type=int, default=0, 34 | help="max generator time steps for memory efficiency") 35 | parser.add_argument('-split_num', type=int, default=0, 36 | help="split number for splitres") 37 | parser.add_argument('-pretrain', default='', type=str, 38 | help="load pretrain encoder") 39 | parser.add_argument('-label_dict_file', default='./data/topic_sorted.json', type=str, 40 | help="label_dict") 41 | 42 | 43 | def convert_to_config(opt, config): 44 | opt = vars(opt) 45 | for key in opt: 46 | if key not in config: 47 | config[key] = opt[key] 48 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | 4 | import os 5 | import argparse 6 | import pickle 7 | import codecs 8 | import json 9 | import random 10 | import numpy as np 11 | 12 | import opts 13 | import models 14 | import utils 15 | 16 | 17 | parser = argparse.ArgumentParser(description='predict.py') 18 | 19 | opts.model_opts(parser) 20 | parser.add_argument('-data', type=str, default='./data/save_data/', 21 | help="the processed data dir") 22 | parser.add_argument('-batch_size', type=int, default=64, 23 | 
help="the batch size for testing") 24 | 25 | opt = parser.parse_args() 26 | 27 | 28 | if not os.path.exists(opt.log): 29 | os.makedirs(opt.log) 30 | 31 | # load checkpoint 32 | assert opt.restore 33 | print('loading checkpoint...\n') 34 | checkpoints = torch.load(opt.restore) 35 | config = checkpoints['config'] 36 | 37 | # set seed 38 | torch.manual_seed(opt.seed) 39 | random.seed(opt.seed) 40 | np.random.seed(opt.seed) 41 | 42 | # set cuda 43 | use_cuda = torch.cuda.is_available() and len(opt.gpus) > 0 44 | config.use_cuda = use_cuda 45 | if use_cuda: 46 | torch.cuda.set_device(opt.gpus[0]) 47 | torch.cuda.manual_seed(opt.seed) 48 | torch.backends.cudnn.deterministic = True 49 | 50 | # load label_dict 51 | with codecs.open(opt.label_dict_file, 'r', 'utf-8') as f: 52 | label_dict = json.load(f) 53 | 54 | 55 | def load_data(): 56 | print('loading data...\n') 57 | data = pickle.load(open(opt.data+'data.pkl', 'rb')) 58 | src_vocab = data['dict']['src'] 59 | tgt_vocab = data['dict']['tgt'] 60 | config.src_vocab_size = src_vocab.size() 61 | config.tgt_vocab_size = tgt_vocab.size() 62 | testset = utils.BiDataset(data['test'], char=config.char) 63 | testloader = torch.utils.data.DataLoader(dataset=testset, 64 | batch_size=opt.batch_size, 65 | shuffle=False, 66 | num_workers=0, 67 | collate_fn=utils.padding) 68 | return {'testset':testset, 'testloader': testloader, 69 | 'src_vocab': src_vocab, 'tgt_vocab': tgt_vocab} 70 | 71 | 72 | # load data 73 | data = load_data() 74 | 75 | # build model 76 | print('building model...\n') 77 | model = getattr(models, opt.model)(config) 78 | model.load_state_dict(checkpoints['model']) 79 | if use_cuda: 80 | model.cuda() 81 | 82 | 83 | def eval_model(model, data): 84 | 85 | model.eval() 86 | reference, candidate, source, alignments = [], [], [], [] 87 | tgt_vocab = data['tgt_vocab'] 88 | count, total_count = 0, len(data['testset']) 89 | dataloader = data['testloader'] 90 | 91 | for src, tgt, src_len, tgt_len, original_src, original_tgt in dataloader: 92 | 93 | if config.use_cuda: 94 | src = src.cuda() 95 | src_len = src_len.cuda() 96 | 97 | with torch.no_grad(): 98 | if config.beam_size > 1 and (not config.global_emb): 99 | samples, alignment, _ = model.beam_sample(src, src_len, beam_size=config.beam_size, eval_=True) 100 | else: 101 | samples, alignment = model.sample(src, src_len) 102 | 103 | candidate += [tgt_vocab.convertToLabels(s.tolist(), utils.EOS) for s in samples] 104 | source += original_src 105 | reference += original_tgt 106 | if alignment is not None: 107 | alignments += [align for align in alignment] 108 | 109 | count += len(original_src) 110 | utils.progress_bar(count, total_count) 111 | 112 | if config.unk and config.attention != 'None': 113 | cands = [] 114 | for s, c, align in zip(source, candidate, alignments): 115 | cand = [] 116 | for word, idx in zip(c, align): 117 | if word == utils.UNK_WORD and idx < len(s): 118 | try: 119 | cand.append(s[idx]) 120 | except: 121 | cand.append(word) 122 | print("%d %d\n" % (len(s), idx)) 123 | else: 124 | cand.append(word) 125 | cands.append(cand) 126 | if len(cand) == 0: 127 | print('Error!') 128 | candidate = cands 129 | 130 | results = utils.eval_metrics(reference, candidate, label_dict, opt.log) 131 | results = [('%s: %.5f' % item + '\n') for item in results.items()] 132 | with codecs.open(opt.log+'results.txt', 'w', 'utf-8') as f: 133 | f.writelines(results) 134 | 135 | 136 | if __name__ == '__main__': 137 | eval_model(model, data) 
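The evaluation above hands the per-document reference and candidate label lists to `utils.eval_metrics(reference, candidate, label_dict, opt.log)`, and `config.yaml` sets `metrics: ['micro_f1']`; the actual implementation lives in `utils/metrics.py`, which is not included in this listing. Purely as an illustration of what micro-F1 over such label lists computes, here is a minimal, self-contained sketch — the function name, signature, and toy labels are assumptions for illustration, not the repository's code.

```python
# Illustrative sketch only: micro-averaged F1 over multi-label predictions in the
# same shape eval_model() builds them (one list of label strings per document).
# The repository's real metric code is in utils/metrics.py (not shown here).
from typing import Iterable, List


def micro_f1(reference: List[List[str]], candidate: List[List[str]], label_set: Iterable[str]) -> float:
    """Pool true positives, false positives and false negatives over all labels, then compute F1."""
    labels = set(label_set)
    tp = fp = fn = 0
    for ref, cand in zip(reference, candidate):
        ref_set = set(ref) & labels    # gold labels restricted to the evaluation label set
        cand_set = set(cand) & labels  # drop anything the decoder emitted outside the label set
        tp += len(ref_set & cand_set)
        fp += len(cand_set - ref_set)
        fn += len(ref_set - cand_set)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0


if __name__ == '__main__':
    # Toy check: one of two gold labels recovered plus one spurious label -> P = R = F1 = 0.5.
    print(micro_f1([['earn', 'acq']], [['earn', 'trade']], {'earn', 'acq', 'trade'}))
```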
-------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import codecs 4 | import utils 5 | import pickle 6 | 7 | 8 | parser = argparse.ArgumentParser(description='preprocess.py') 9 | 10 | parser.add_argument('-load_data', type=str, required=True, 11 | help="input file dir for the data") 12 | parser.add_argument('-save_data', type=str, required=True, 13 | help="output file dir for the processed data") 14 | 15 | parser.add_argument('-src_vocab_size', type=int, default=50000, 16 | help="size of the source vocabulary") 17 | parser.add_argument('-tgt_vocab_size', type=int, default=50000, 18 | help="size of the target vocabulary") 19 | parser.add_argument('-src_filter', type=int, default=0, 20 | help="maximum source sequence length") 21 | parser.add_argument('-tgt_filter', type=int, default=0, 22 | help="maximum target sequence length") 23 | parser.add_argument('-src_trun', type=int, default=0, 24 | help="truncate source sequence length") 25 | parser.add_argument('-tgt_trun', type=int, default=0, 26 | help="truncate target sequence length") 27 | parser.add_argument('-src_char', action='store_true', 28 | help='character based encoding') 29 | parser.add_argument('-tgt_char', action='store_true', 30 | help='character based decoding') 31 | parser.add_argument('-src_suf', default='src', 32 | help="the suffix of the source filename") 33 | parser.add_argument('-tgt_suf', default='tgt', 34 | help="the suffix of the target filename") 35 | 36 | parser.add_argument('-share', action='store_true', 37 | help='share the vocabulary between source and target') 38 | 39 | parser.add_argument('-report_every', type=int, default=100000, 40 | help="report status every this many sentences") 41 | 42 | opt = parser.parse_args() 43 | 44 | 45 | def makeVocabulary(filename, trun_length, filter_length, char, vocab, size): 46 | 47 | print("%s: length limit = %d, truncate length = %d" % (filename, filter_length, trun_length)) 48 | max_length = 0 49 | with codecs.open(filename, 'r', 'utf-8') as f: 50 | for sent in f.readlines(): 51 | if char: 52 | tokens = list(sent.strip()) 53 | else: 54 | tokens = sent.strip().split() 55 | if 0 < filter_length < len(sent.strip().split()): 56 | continue 57 | max_length = max(max_length, len(tokens)) 58 | if trun_length > 0: 59 | tokens = tokens[:trun_length] 60 | for word in tokens: 61 | vocab.add(word) 62 | 63 | print('Max length of %s = %d' % (filename, max_length)) 64 | 65 | if size > 0: 66 | originalSize = vocab.size() 67 | vocab = vocab.prune(size) 68 | print('Created dictionary of size %d (pruned from %d)' % 69 | (vocab.size(), originalSize)) 70 | 71 | return vocab 72 | 73 | 74 | def saveVocabulary(name, vocab, file): 75 | print('Saving ' + name + ' vocabulary to \'' + file + '\'...') 76 | vocab.writeFile(file) 77 | 78 | 79 | def makeData(srcFile, tgtFile, srcDicts, tgtDicts, save_srcFile, save_tgtFile, lim=0): 80 | sizes = 0 81 | count, empty_ignored, limit_ignored = 0, 0, 0 82 | 83 | print('Processing %s & %s ...' 
% (srcFile, tgtFile)) 84 | srcF = open(srcFile, encoding='utf8') 85 | tgtF = open(tgtFile, encoding='utf8') 86 | 87 | srcIdF = open(save_srcFile + '.id', 'w') 88 | tgtIdF = open(save_tgtFile + '.id', 'w') 89 | srcStrF = open(save_srcFile + '.str', 'w', encoding='utf8') 90 | tgtStrF = open(save_tgtFile + '.str', 'w', encoding='utf8') 91 | 92 | while True: 93 | sline = srcF.readline() 94 | tline = tgtF.readline() 95 | 96 | # normal end of file 97 | if sline == "" and tline == "": 98 | break 99 | 100 | # source or target does not have same number of lines 101 | if sline == "" or tline == "": 102 | print('WARNING: source and target do not have the same number of sentences') 103 | break 104 | 105 | sline = sline.strip() 106 | tline = tline.strip() 107 | 108 | # source and/or target are empty 109 | if sline == "" or tline == "": 110 | print('WARNING: ignoring an empty line ('+str(count+1)+')') 111 | empty_ignored += 1 112 | continue 113 | 114 | sline = sline.lower() 115 | tline = tline.lower() 116 | 117 | srcWords = sline.split() if not opt.src_char else list(sline) 118 | tgtWords = tline.split() if not opt.tgt_char else list(tline) 119 | 120 | 121 | if (opt.src_filter == 0 or len(sline.split()) <= opt.src_filter) and \ 122 | (opt.tgt_filter == 0 or len(tline.split()) <= opt.tgt_filter): 123 | 124 | if opt.src_trun > 0: 125 | srcWords = srcWords[:opt.src_trun] 126 | if opt.tgt_trun > 0: 127 | tgtWords = tgtWords[:opt.tgt_trun] 128 | 129 | srcIds = srcDicts.convertToIdx(srcWords, utils.UNK_WORD) 130 | tgtIds = tgtDicts.convertToIdx(tgtWords, utils.UNK_WORD, utils.BOS_WORD, utils.EOS_WORD) 131 | 132 | srcIdF.write(" ".join(list(map(str, srcIds)))+'\n') 133 | tgtIdF.write(" ".join(list(map(str, tgtIds)))+'\n') 134 | if not opt.src_char: 135 | srcStrF.write(" ".join(srcWords)+'\n') 136 | else: 137 | srcStrF.write("".join(srcWords) + '\n') 138 | if not opt.tgt_char: 139 | tgtStrF.write(" ".join(tgtWords)+'\n') 140 | else: 141 | tgtStrF.write("".join(tgtWords) + '\n') 142 | 143 | sizes += 1 144 | else: 145 | limit_ignored += 1 146 | 147 | count += 1 148 | 149 | if count % opt.report_every == 0: 150 | print('... %d sentences prepared' % count) 151 | 152 | srcF.close() 153 | tgtF.close() 154 | srcStrF.close() 155 | tgtStrF.close() 156 | srcIdF.close() 157 | tgtIdF.close() 158 | 159 | print('Prepared %d sentences (%d and %d ignored due to length == 0 or > )' % 160 | (sizes, empty_ignored, limit_ignored)) 161 | 162 | return {'srcF': save_srcFile + '.id', 'tgtF': save_tgtFile + '.id', 163 | 'original_srcF': save_srcFile + '.str', 'original_tgtF': save_tgtFile + '.str', 164 | 'length': sizes} 165 | 166 | 167 | def main(): 168 | 169 | if not os.path.exists(opt.save_data): 170 | os.makedirs(opt.save_data) 171 | 172 | dicts = {} 173 | 174 | train_src, train_tgt = opt.load_data + 'train.' + opt.src_suf, opt.load_data + 'train.' + opt.tgt_suf 175 | valid_src, valid_tgt = opt.load_data + 'valid.' + opt.src_suf, opt.load_data + 'valid.' + opt.tgt_suf 176 | test_src, test_tgt = opt.load_data + 'test.' + opt.src_suf, opt.load_data + 'test.' + opt.tgt_suf 177 | 178 | save_train_src, save_train_tgt = opt.save_data + 'train.' + opt.src_suf, opt.save_data + 'train.' + opt.tgt_suf 179 | save_valid_src, save_valid_tgt = opt.save_data + 'valid.' + opt.src_suf, opt.save_data + 'valid.' + opt.tgt_suf 180 | save_test_src, save_test_tgt = opt.save_data + 'test.' + opt.src_suf, opt.save_data + 'test.' 
+ opt.tgt_suf 181 | 182 | src_dict, tgt_dict = opt.save_data + 'src.dict', opt.save_data + 'tgt.dict' 183 | 184 | if opt.share: 185 | assert opt.src_vocab_size == opt.tgt_vocab_size 186 | print('Building source and target vocabulary...') 187 | dicts['src'] = dicts['tgt'] = utils.Dict([utils.PAD_WORD, utils.UNK_WORD, utils.BOS_WORD, utils.EOS_WORD]) 188 | dicts['src'] = makeVocabulary(train_src, opt.src_trun, opt.src_filter, opt.src_char, dicts['src'], opt.src_vocab_size) 189 | dicts['src'] = dicts['tgt'] = makeVocabulary(train_tgt, opt.tgt_trun, opt.tgt_filter, opt.tgt_char, dicts['tgt'], opt.tgt_vocab_size) 190 | else: 191 | print('Building source vocabulary...') 192 | dicts['src'] = utils.Dict([utils.PAD_WORD, utils.UNK_WORD, utils.BOS_WORD, utils.EOS_WORD]) 193 | dicts['src'] = makeVocabulary(train_src, opt.src_trun, opt.src_filter, opt.src_char, dicts['src'], opt.src_vocab_size) 194 | print('Building target vocabulary...') 195 | dicts['tgt'] = utils.Dict([utils.PAD_WORD, utils.UNK_WORD, utils.BOS_WORD, utils.EOS_WORD]) 196 | dicts['tgt'] = makeVocabulary(train_tgt, opt.tgt_trun, opt.tgt_filter, opt.tgt_char, dicts['tgt'], opt.tgt_vocab_size) 197 | 198 | print('Preparing training ...') 199 | train = makeData(train_src, train_tgt, dicts['src'], dicts['tgt'], save_train_src, save_train_tgt) 200 | 201 | print('Preparing validation ...') 202 | valid = makeData(valid_src, valid_tgt, dicts['src'], dicts['tgt'], save_valid_src, save_valid_tgt) 203 | 204 | print('Preparing test ...') 205 | test = makeData(test_src, test_tgt, dicts['src'], dicts['tgt'], save_test_src, save_test_tgt) 206 | 207 | print('Saving source vocabulary to \'' + src_dict + '\'...') 208 | dicts['src'].writeFile(src_dict) 209 | 210 | print('Saving source vocabulary to \'' + tgt_dict + '\'...') 211 | dicts['tgt'].writeFile(tgt_dict) 212 | 213 | data = {'train': train, 'valid': valid, 214 | 'test': test, 'dict': dicts} 215 | pickle.dump(data, open(opt.save_data+'data.pkl', 'wb')) 216 | 217 | 218 | if __name__ == "__main__": 219 | main() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import lr_scheduler as L 4 | 5 | import os 6 | import argparse 7 | import pickle 8 | import time 9 | import random 10 | import numpy as np 11 | from collections import OrderedDict 12 | 13 | import opts 14 | import models 15 | import utils 16 | import codecs 17 | import json 18 | 19 | 20 | parser = argparse.ArgumentParser(description='train.py') 21 | opts.model_opts(parser) 22 | opt = parser.parse_args() 23 | 24 | config = utils.read_config(opt.config) 25 | torch.manual_seed(opt.seed) 26 | random.seed(opt.seed) 27 | np.random.seed(opt.seed) 28 | opts.convert_to_config(opt, config) 29 | 30 | # cuda 31 | use_cuda = torch.cuda.is_available() and len(opt.gpus) > 0 32 | config.use_cuda = use_cuda 33 | if use_cuda: 34 | torch.cuda.set_device(opt.gpus[0]) 35 | torch.cuda.manual_seed(opt.seed) 36 | torch.backends.cudnn.deterministic = True 37 | 38 | with codecs.open(opt.label_dict_file, 'r', 'utf-8') as f: 39 | label_dict = json.load(f) 40 | 41 | 42 | def load_data(): 43 | print('loading data...\n') 44 | data = pickle.load(open(config.data+'data.pkl', 'rb')) 45 | data['train']['length'] = int(data['train']['length'] * opt.scale) 46 | 47 | trainset = utils.BiDataset(data['train'], char=config.char) 48 | validset = utils.BiDataset(data['valid'], char=config.char) 49 | 50 | src_vocab = 
data['dict']['src'] 51 | tgt_vocab = data['dict']['tgt'] 52 | config.src_vocab_size = src_vocab.size() 53 | config.tgt_vocab_size = tgt_vocab.size() 54 | 55 | trainloader = torch.utils.data.DataLoader(dataset=trainset, 56 | batch_size=config.batch_size, 57 | shuffle=True, 58 | num_workers=0, 59 | collate_fn=utils.padding) 60 | if hasattr(config, 'valid_batch_size'): 61 | valid_batch_size = config.valid_batch_size 62 | else: 63 | valid_batch_size = config.batch_size 64 | validloader = torch.utils.data.DataLoader(dataset=validset, 65 | batch_size=valid_batch_size, 66 | shuffle=False, 67 | num_workers=0, 68 | collate_fn=utils.padding) 69 | 70 | return {'trainset': trainset, 'validset': validset, 71 | 'trainloader': trainloader, 'validloader': validloader, 72 | 'src_vocab': src_vocab, 'tgt_vocab': tgt_vocab} 73 | 74 | 75 | def build_model(checkpoints, print_log): 76 | for k, v in config.items(): 77 | print_log("%s:\t%s\n" % (str(k), str(v))) 78 | 79 | # model 80 | print('building model...\n') 81 | model = getattr(models, opt.model)(config) 82 | if checkpoints is not None: 83 | model.load_state_dict(checkpoints['model']) 84 | if opt.pretrain: 85 | print('loading checkpoint from %s' % opt.pretrain) 86 | pre_ckpt = torch.load(opt.pretrain)['model'] 87 | pre_ckpt = OrderedDict({key[8:]: pre_ckpt[key] for key in pre_ckpt if key.startswith('encoder')}) 88 | print(model.encoder.state_dict().keys()) 89 | print(pre_ckpt.keys()) 90 | model.encoder.load_state_dict(pre_ckpt) 91 | if use_cuda: 92 | model.cuda() 93 | 94 | # optimizer 95 | if checkpoints is not None: 96 | optim = checkpoints['optim'] 97 | else: 98 | optim = models.Optim(config.optim, config.learning_rate, config.max_grad_norm, 99 | lr_decay=config.learning_rate_decay, start_decay_at=config.start_decay_at) 100 | optim.set_parameters(model.parameters()) 101 | 102 | # print log 103 | param_count = 0 104 | for param in model.parameters(): 105 | param_count += param.view(-1).size()[0] 106 | for k, v in config.items(): 107 | print_log("%s:\t%s\n" % (str(k), str(v))) 108 | print_log("\n") 109 | print_log(repr(model) + "\n\n") 110 | print_log('total number of parameters: %d\n\n' % param_count) 111 | 112 | return model, optim, print_log 113 | 114 | 115 | def train_model(model, data, optim, epoch, params): 116 | 117 | model.train() 118 | trainloader = data['trainloader'] 119 | 120 | for src, tgt, src_len, tgt_len, original_src, original_tgt in trainloader: 121 | 122 | model.zero_grad() 123 | 124 | if config.use_cuda: 125 | src = src.cuda() 126 | tgt = tgt.cuda() 127 | src_len = src_len.cuda() 128 | 129 | lengths, indices = torch.sort(src_len, dim=0, descending=True) 130 | src = torch.index_select(src, dim=0, index=indices) 131 | tgt = torch.index_select(tgt, dim=0, index=indices) 132 | dec = tgt[:, :-1] 133 | targets = tgt[:, 1:] 134 | 135 | try: 136 | if config.schesamp: 137 | if epoch > 8: 138 | e = epoch - 8 139 | loss, outputs = model(src, lengths, dec, targets, teacher_ratio=0.9**e) 140 | else: 141 | loss, outputs = model(src, lengths, dec, targets) 142 | else: 143 | loss, outputs = model(src, lengths, dec, targets) 144 | 145 | pred = outputs.max(2)[1] 146 | targets = targets.t() 147 | num_correct = pred.eq(targets).masked_select(targets.ne(utils.PAD)).sum().item() 148 | num_total = targets.ne(utils.PAD).sum().item() 149 | if config.max_split == 0: 150 | loss = torch.sum(loss) / num_total 151 | loss.backward() 152 | optim.step() 153 | 154 | params['report_loss'] += loss.item() 155 | params['report_correct'] += num_correct 156 | 
params['report_total'] += num_total 157 | 158 | except RuntimeError as e: 159 | if 'out of memory' in str(e): 160 | print('| WARNING: ran out of memory') 161 | if hasattr(torch.cuda, 'empty_cache'): 162 | torch.cuda.empty_cache() 163 | else: 164 | raise e 165 | 166 | utils.progress_bar(params['updates'], config.eval_interval) 167 | params['updates'] += 1 168 | 169 | if params['updates'] % config.eval_interval == 0: 170 | params['log']("epoch: %3d, loss: %6.3f, time: %6.3f, updates: %8d, accuracy: %2.2f\n" 171 | % (epoch, params['report_loss'], time.time()-params['report_time'], 172 | params['updates'], params['report_correct'] * 100.0 / params['report_total'])) 173 | print('evaluating after %d updates...\r' % params['updates']) 174 | score = eval_model(model, data, params) 175 | for metric in config.metrics: 176 | params[metric].append(score[metric]) 177 | if score[metric] >= max(params[metric]): 178 | with codecs.open(params['log_path']+'best_'+metric+'_prediction.txt','w','utf-8') as f: 179 | f.write(codecs.open(params['log_path']+'candidate.txt','r','utf-8').read()) 180 | save_model(params['log_path']+'best_'+metric+'_checkpoint.pt', model, optim, params['updates']) 181 | model.train() 182 | params['report_loss'], params['report_time'] = 0, time.time() 183 | params['report_correct'], params['report_total'] = 0, 0 184 | 185 | if params['updates'] % config.save_interval == 0: 186 | save_model(params['log_path']+'checkpoint.pt', model, optim, params['updates']) 187 | 188 | optim.updateLearningRate(score=0, epoch=epoch) 189 | 190 | 191 | def eval_model(model, data, params): 192 | 193 | model.eval() 194 | reference, candidate, source, alignments = [], [], [], [] 195 | count, total_count = 0, len(data['validset']) 196 | validloader = data['validloader'] 197 | tgt_vocab = data['tgt_vocab'] 198 | 199 | for src, tgt, src_len, tgt_len, original_src, original_tgt in validloader: 200 | 201 | if config.use_cuda: 202 | src = src.cuda() 203 | src_len = src_len.cuda() 204 | 205 | with torch.no_grad(): 206 | if config.beam_size > 1 and (not config.global_emb): 207 | samples, alignment, _ = model.beam_sample(src, src_len, beam_size=config.beam_size, eval_=True) 208 | else: 209 | samples, alignment = model.sample(src, src_len) 210 | 211 | candidate += [tgt_vocab.convertToLabels(s.tolist(), utils.EOS) for s in samples] 212 | source += original_src 213 | reference += original_tgt 214 | if alignment is not None: 215 | alignments += [align for align in alignment] 216 | 217 | count += len(original_src) 218 | utils.progress_bar(count, total_count) 219 | 220 | if config.unk and config.attention != 'None': 221 | cands = [] 222 | for s, c, align in zip(source, candidate, alignments): 223 | cand = [] 224 | for word, idx in zip(c, align): 225 | if word == utils.UNK_WORD and idx < len(s): 226 | try: 227 | cand.append(s[idx]) 228 | except: 229 | cand.append(word) 230 | print("%d %d\n" % (len(s), idx)) 231 | else: 232 | cand.append(word) 233 | cands.append(cand) 234 | if len(cand) == 0: 235 | print('Error!') 236 | candidate = cands 237 | 238 | with codecs.open(params['log_path']+'candidate.txt','w+','utf-8') as f: 239 | for i in range(len(candidate)): 240 | f.write(" ".join(candidate[i])+'\n') 241 | 242 | results = utils.eval_metrics(reference, candidate, label_dict, params['log_path']) 243 | score = {} 244 | result_line = "" 245 | for metric in config.metrics: 246 | score[metric] = results[metric] 247 | result_line += metric + ": %s " % str(score[metric]) 248 | result_line += '\n' 249 | 250 | 
params['log'](result_line) 251 | 252 | return score 253 | 254 | 255 | def save_model(path, model, optim, updates): 256 | model_state_dict = model.state_dict() 257 | checkpoints = { 258 | 'model': model_state_dict, 259 | 'config': config, 260 | 'optim': optim, 261 | 'updates': updates} 262 | torch.save(checkpoints, path) 263 | 264 | 265 | def build_log(): 266 | # log 267 | if not os.path.exists(config.logF): 268 | os.makedirs(config.logF) 269 | if opt.log == '': 270 | log_path = config.logF + str(int(time.time() * 1000)) + '/' 271 | else: 272 | log_path = config.logF + opt.log + '/' 273 | if not os.path.exists(log_path): 274 | os.makedirs(log_path) 275 | print_log = utils.print_log(log_path + 'log.txt') 276 | return print_log, log_path 277 | 278 | 279 | def main(): 280 | # checkpoint 281 | if opt.restore: 282 | print('loading checkpoint...\n') 283 | checkpoints = torch.load(opt.restore) 284 | else: 285 | checkpoints = None 286 | 287 | data = load_data() 288 | print_log, log_path = build_log() 289 | model, optim, print_log = build_model(checkpoints, print_log) 290 | # scheduler 291 | if config.schedule: 292 | scheduler = L.CosineAnnealingLR(optim.optimizer, T_max=config.epoch) 293 | params = {'updates': 0, 'report_loss': 0, 'report_total': 0, 294 | 'report_correct': 0, 'report_time': time.time(), 295 | 'log': print_log, 'log_path': log_path} 296 | 297 | for metric in config.metrics: 298 | params[metric] = [] 299 | if opt.restore: 300 | params['updates'] = checkpoints['updates'] 301 | 302 | if opt.mode == "train": 303 | for i in range(1, config.epoch + 1): 304 | if config.schedule: 305 | scheduler.step() 306 | print("Decaying learning rate to %g" % scheduler.get_lr()[0]) 307 | train_model(model, data, optim, i, params) 308 | for metric in config.metrics: 309 | print_log("Best %s score: %.2f\n" % (metric, max(params[metric]))) 310 | else: 311 | score = eval_model(model, data, params) 312 | 313 | 314 | if __name__ == '__main__': 315 | main() -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_helper import * 2 | from .dict_helper import * 3 | from .misc_utils import * 4 | from .metrics import * 5 | -------------------------------------------------------------------------------- /utils/data_helper.py: -------------------------------------------------------------------------------- 1 | import linecache 2 | import torch 3 | import torch.utils.data as torch_data 4 | from random import Random 5 | import utils 6 | 7 | num_samples = 1 8 | 9 | 10 | class MonoDataset(torch_data.Dataset): 11 | 12 | def __init__(self, infos, indexes=None): 13 | 14 | self.srcF = infos['srcF'] 15 | self.original_srcF = infos['original_srcF'] 16 | self.length = infos['length'] 17 | self.infos = infos 18 | if indexes is None: 19 | self.indexes = list(range(self.length)) 20 | else: 21 | self.indexes = indexes 22 | 23 | def __getitem__(self, index): 24 | index = self.indexes[index] 25 | src = list(map(int, linecache.getline(self.srcF, index+1).strip().split())) 26 | original_src = linecache.getline(self.original_srcF, index+1).strip().split() 27 | 28 | return src, original_src 29 | 30 | def __len__(self): 31 | return len(self.indexes) 32 | 33 | 34 | class BiDataset(torch_data.Dataset): 35 | 36 | def __init__(self, infos, indexes=None, char=False): 37 | 38 | self.srcF = infos['srcF'] 39 | self.tgtF = infos['tgtF'] 40 | self.original_srcF = infos['original_srcF'] 41 | self.original_tgtF 
= infos['original_tgtF'] 42 | self.length = infos['length'] 43 | self.infos = infos 44 | self.char = char 45 | if indexes is None: 46 | self.indexes = list(range(self.length)) 47 | else: 48 | self.indexes = indexes 49 | 50 | def __getitem__(self, index): 51 | index = self.indexes[index] 52 | src = list(map(int, linecache.getline(self.srcF, index+1).strip().split())) 53 | tgt = list(map(int, linecache.getline(self.tgtF, index+1).strip().split())) 54 | original_src = linecache.getline(self.original_srcF, index+1).strip().split() 55 | original_tgt = linecache.getline(self.original_tgtF, index+1).strip().split() if not self.char else \ 56 | list(linecache.getline(self.original_tgtF, index + 1).strip()) 57 | 58 | return src, tgt, original_src, original_tgt 59 | 60 | def __len__(self): 61 | return len(self.indexes) 62 | 63 | 64 | def splitDataset(data_set, sizes): 65 | length = len(data_set) 66 | indexes = list(range(length)) 67 | rng = Random() 68 | rng.seed(1234) 69 | rng.shuffle(indexes) 70 | 71 | data_sets = [] 72 | part_len = int(length / sizes) 73 | for i in range(sizes-1): 74 | data_sets.append(BiDataset(data_set.infos, indexes[0:part_len])) 75 | indexes = indexes[part_len:] 76 | data_sets.append(BiDataset(data_set.infos, indexes)) 77 | return data_sets 78 | 79 | 80 | def padding(data): 81 | src, tgt, original_src, original_tgt = zip(*data) 82 | 83 | src_len = [len(s) for s in src] 84 | src_pad = torch.zeros(len(src), max(src_len)).long() 85 | for i, s in enumerate(src): 86 | end = src_len[i] 87 | src_pad[i, :end] = torch.LongTensor(s[end-1::-1]) 88 | 89 | tgt_len = [len(s) for s in tgt] 90 | tgt_pad = torch.zeros(len(tgt), max(tgt_len)).long() 91 | for i, s in enumerate(tgt): 92 | end = tgt_len[i] 93 | tgt_pad[i, :end] = torch.LongTensor(s)[:end] 94 | 95 | return src_pad, tgt_pad, \ 96 | torch.LongTensor(src_len), torch.LongTensor(tgt_len), \ 97 | original_src, original_tgt 98 | 99 | 100 | def ae_padding(data): 101 | src, tgt, original_src, original_tgt = zip(*data) 102 | 103 | src_len = [len(s) for s in src] 104 | src_pad = torch.zeros(len(src), max(src_len)).long() 105 | for i, s in enumerate(src): 106 | end = src_len[i] 107 | src_pad[i, :end] = torch.LongTensor(s)[:end] 108 | 109 | tgt_len = [len(s) for s in tgt] 110 | tgt_pad = torch.zeros(len(tgt), max(tgt_len)).long() 111 | for i, s in enumerate(tgt): 112 | end = tgt_len[i] 113 | tgt_pad[i, :end] = torch.LongTensor(s)[:end] 114 | 115 | ae_len = [len(s)+2 for s in src] 116 | ae_pad = torch.zeros(len(src), max(ae_len)).long() 117 | for i, s in enumerate(src): 118 | end = ae_len[i] 119 | ae_pad[i, 0] = utils.BOS 120 | ae_pad[i, 1:end-1] = torch.LongTensor(s)[:end-2] 121 | ae_pad[i, end-1] = utils.EOS 122 | 123 | return src_pad, tgt_pad, ae_pad, \ 124 | torch.LongTensor(src_len), torch.LongTensor(tgt_len), torch.LongTensor(ae_len), \ 125 | original_src, original_tgt 126 | 127 | 128 | def split_padding(data): 129 | src, tgt, original_src, original_tgt = zip(*data) 130 | 131 | split_samples = [] 132 | num_per_sample = int(len(src) / utils.num_samples) 133 | 134 | for i in range(utils.num_samples): 135 | split_src = src[i*num_per_sample:(i+1)*num_per_sample] 136 | split_tgt = tgt[i*num_per_sample:(i+1)*num_per_sample] 137 | split_original_src = original_src[i * num_per_sample:(i + 1) * num_per_sample] 138 | split_original_tgt = original_tgt[i * num_per_sample:(i + 1) * num_per_sample] 139 | 140 | src_len = [len(s) for s in split_src] 141 | src_pad = torch.zeros(len(split_src), max(src_len)).long() 142 | for i, s in enumerate(split_src): 
143 | end = src_len[i] 144 | src_pad[i, :end] = torch.LongTensor(s)[:end] 145 | 146 | tgt_len = [len(s) for s in split_tgt] 147 | tgt_pad = torch.zeros(len(split_tgt), max(tgt_len)).long() 148 | for i, s in enumerate(split_tgt): 149 | end = tgt_len[i] 150 | tgt_pad[i, :end] = torch.LongTensor(s)[:end] 151 | 152 | split_samples.append([src_pad, tgt_pad, 153 | torch.LongTensor(src_len), torch.LongTensor(tgt_len), 154 | split_original_src, split_original_tgt]) 155 | 156 | return split_samples -------------------------------------------------------------------------------- /utils/dict_helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 3 | 4 | PAD = 0 5 | UNK = 1 6 | BOS = 2 7 | EOS = 3 8 | 9 | PAD_WORD = '' 10 | UNK_WORD = ' ' 11 | BOS_WORD = '' 12 | EOS_WORD = '' 13 | 14 | 15 | class Dict(object): 16 | def __init__(self, data=None, lower=True): 17 | self.idxToLabel = {} 18 | self.labelToIdx = {} 19 | self.frequencies = {} 20 | self.lower = lower 21 | # Special entries will not be pruned. 22 | self.special = [] 23 | 24 | if data is not None: 25 | if type(data) == str: 26 | self.loadFile(data) 27 | else: 28 | self.addSpecials(data) 29 | 30 | def size(self): 31 | return len(self.idxToLabel) 32 | 33 | # Load entries from a file. 34 | def loadFile(self, filename): 35 | for line in open(filename): 36 | fields = line.split() 37 | label = fields[0] 38 | idx = int(fields[1]) 39 | self.add(label, idx) 40 | 41 | # Write entries to a file. 42 | def writeFile(self, filename): 43 | with open(filename, 'w') as file: 44 | for i in range(self.size()): 45 | label = self.idxToLabel[i] 46 | file.write('%s %d\n' % (label, i)) 47 | 48 | file.close() 49 | 50 | def loadDict(self, idxToLabel): 51 | for i in range(len(idxToLabel)): 52 | label = idxToLabel[i] 53 | self.add(label, i) 54 | 55 | def lookup(self, key, default=None): 56 | key = key.lower() if self.lower else key 57 | try: 58 | return self.labelToIdx[key] 59 | except KeyError: 60 | return default 61 | 62 | def getLabel(self, idx, default=None): 63 | try: 64 | return self.idxToLabel[idx] 65 | except KeyError: 66 | return default 67 | 68 | # Mark this `label` and `idx` as special (i.e. will not be pruned). 69 | def addSpecial(self, label, idx=None): 70 | idx = self.add(label, idx) 71 | self.special += [idx] 72 | 73 | # Mark all labels in `labels` as specials (i.e. will not be pruned). 74 | def addSpecials(self, labels): 75 | for label in labels: 76 | self.addSpecial(label) 77 | 78 | # Add `label` in the dictionary. Use `idx` as its index if given. 79 | def add(self, label, idx=None): 80 | label = label.lower() if self.lower else label 81 | if idx is not None: 82 | self.idxToLabel[idx] = label 83 | self.labelToIdx[label] = idx 84 | else: 85 | if label in self.labelToIdx: 86 | idx = self.labelToIdx[label] 87 | else: 88 | idx = len(self.idxToLabel) 89 | self.idxToLabel[idx] = label 90 | self.labelToIdx[label] = idx 91 | 92 | if idx not in self.frequencies: 93 | self.frequencies[idx] = 1 94 | else: 95 | self.frequencies[idx] += 1 96 | 97 | return idx 98 | 99 | # Return a new dictionary with the `size` most frequent entries. 100 | def prune(self, size): 101 | if size > self.size(): 102 | return self 103 | 104 | # Only keep the `size` most frequent entries. 
105 | freq = torch.tensor( 106 | [self.frequencies[i] for i in range(len(self.frequencies))]) 107 | _, idx = torch.sort(freq, 0, True) 108 | idx = idx.tolist() 109 | 110 | newDict = Dict() 111 | newDict.lower = self.lower 112 | 113 | # Add special entries in all cases. 114 | for i in self.special: 115 | newDict.addSpecial(self.idxToLabel[i]) 116 | 117 | for i in idx[:size]: 118 | newDict.add(self.idxToLabel[i]) 119 | 120 | return newDict 121 | 122 | # Convert `labels` to indices. Use `unkWord` if not found. 123 | # Optionally insert `bosWord` at the beginning and `eosWord` at the . 124 | def convertToIdx(self, labels, unkWord, bosWord=None, eosWord=None): 125 | vec = [] 126 | 127 | if bosWord is not None: 128 | vec += [self.lookup(bosWord)] 129 | 130 | unk = self.lookup(unkWord) 131 | vec += [self.lookup(label, default=unk) for label in labels] 132 | 133 | if eosWord is not None: 134 | vec += [self.lookup(eosWord)] 135 | 136 | return vec 137 | 138 | def convertToIdxandOOVs(self, labels, unkWord, bosWord=None, eosWord=None): 139 | vec = [] 140 | oovs = OrderedDict() 141 | 142 | if bosWord is not None: 143 | vec += [self.lookup(bosWord)] 144 | 145 | unk = self.lookup(unkWord) 146 | for label in labels: 147 | id = self.lookup(label, default=unk) 148 | if id != unk: 149 | vec += [id] 150 | else: 151 | if label not in oovs: 152 | oovs[label] = len(oovs)+self.size() 153 | oov_num = oovs[label] 154 | vec += [oov_num] 155 | 156 | if eosWord is not None: 157 | vec += [self.lookup(eosWord)] 158 | 159 | return torch.LongTensor(vec), oovs 160 | 161 | def convertToIdxwithOOVs(self, labels, unkWord, bosWord=None, eosWord=None, oovs=None): 162 | vec = [] 163 | 164 | if bosWord is not None: 165 | vec += [self.lookup(bosWord)] 166 | 167 | unk = self.lookup(unkWord) 168 | for label in labels: 169 | id = self.lookup(label, default=unk) 170 | if id == unk and label in oovs: 171 | vec += [oovs[label]] 172 | else: 173 | vec += [id] 174 | 175 | if eosWord is not None: 176 | vec += [self.lookup(eosWord)] 177 | 178 | return torch.LongTensor(vec) 179 | 180 | # Convert `idx` to labels. If index `stop` is reached, convert it and return. 
181 | def convertToLabels(self, idx, stop, oovs=None): 182 | labels = [] 183 | 184 | for i in idx: 185 | if i == stop: 186 | break 187 | if i < self.size(): 188 | labels += [self.getLabel(i)] 189 | else: 190 | labels += [oovs[i-self.size()]] 191 | 192 | return labels 193 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import codecs 3 | import numpy as np 4 | from sklearn import metrics 5 | 6 | 7 | def eval_metrics(reference, candidate, label_dict, log_path): 8 | 9 | ref_dir = log_path + 'reference/' 10 | cand_dir = log_path + 'candidate/' 11 | if not os.path.exists(ref_dir): 12 | os.makedirs(ref_dir) 13 | if not os.path.exists(cand_dir): 14 | os.makedirs(cand_dir) 15 | ref_file = ref_dir + 'reference' 16 | cand_file = cand_dir + 'candidate' 17 | 18 | for i in range(len(reference)): 19 | with codecs.open(ref_file+str(i), 'w', 'utf-8') as f: 20 | f.write(" ".join(reference[i]) + '\n') 21 | with codecs.open(cand_file+str(i), 'w', 'utf-8') as f: 22 | f.write(" ".join(candidate[i]) + '\n') 23 | 24 | def make_label(l, label_dict): 25 | length = len(label_dict) 26 | result = np.zeros(length) 27 | indices = [label_dict.get(label.strip().lower(), 0) for label in l] 28 | result[indices] = 1 29 | return result 30 | 31 | def prepare_label(y_list, y_pre_list, label_dict): 32 | reference = np.array([make_label(y, label_dict) for y in y_list]) 33 | candidate = np.array([make_label(y_pre, label_dict) for y_pre in y_pre_list]) 34 | return reference, candidate 35 | 36 | def get_metrics(y, y_pre): 37 | hamming_loss = metrics.hamming_loss(y, y_pre) 38 | micro_f1 = metrics.f1_score(y, y_pre, average='micro') 39 | micro_precision = metrics.precision_score(y, y_pre, average='micro') 40 | micro_recall = metrics.recall_score(y, y_pre, average='micro') 41 | instance_f1 = metrics.f1_score(y, y_pre, average='samples') 42 | instance_precision = metrics.precision_score(y, y_pre, average='samples') 43 | instance_recall = metrics.recall_score(y, y_pre, average='samples') 44 | return hamming_loss, \ 45 | micro_f1, micro_precision, micro_recall, \ 46 | instance_f1, instance_precision, instance_recall 47 | 48 | y, y_pre = prepare_label(reference, candidate, label_dict) 49 | hamming_loss, micro_f1, micro_precision, micro_recall, instance_f1, instance_precision, instance_recall = get_metrics(y, y_pre) 50 | 51 | return {'hamming_loss': hamming_loss, 52 | 'micro_f1': micro_f1, 53 | 'micro_precision': micro_precision, 54 | 'micro_recall': micro_recall, 55 | 'instance_f1': instance_f1, 56 | 'instance_precision': instance_precision, 57 | 'instance_recall': instance_recall} -------------------------------------------------------------------------------- /utils/misc_utils.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import os 3 | import time 4 | import sys 5 | 6 | 7 | class AttrDict(dict): 8 | def __init__(self, *args, **kwargs): 9 | super(AttrDict, self).__init__(*args, **kwargs) 10 | self.__dict__ = self 11 | 12 | 13 | def read_config(path): 14 | return AttrDict(yaml.load(open(path, 'r'))) 15 | 16 | 17 | def print_log(file): 18 | def write_log(s): 19 | print(s, end='') 20 | with open(file, 'a') as f: 21 | f.write(s) 22 | return write_log 23 | 24 | 25 | _, term_width = os.popen('stty size', 'r').read().split() 26 | term_width = int(term_width) 27 | 28 | TOTAL_BAR_LENGTH = 86. 
29 | last_time = time.time() 30 | begin_time = last_time 31 | 32 | 33 | def progress_bar(current, total, msg=None): 34 | global last_time, begin_time 35 | current = current % total 36 | if current == 0: 37 | begin_time = time.time() # Reset for new bar. 38 | 39 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 40 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 41 | 42 | sys.stdout.write(' [') 43 | for i in range(cur_len): 44 | sys.stdout.write('=') 45 | sys.stdout.write('>') 46 | for i in range(rest_len): 47 | sys.stdout.write('.') 48 | sys.stdout.write(']') 49 | 50 | cur_time = time.time() 51 | step_time = cur_time - last_time 52 | last_time = cur_time 53 | tot_time = cur_time - begin_time 54 | 55 | L = [] 56 | L.append(' Step: %s' % format_time(step_time)) 57 | L.append(' | Tot: %s' % format_time(tot_time)) 58 | if msg: 59 | L.append(' | ' + msg) 60 | 61 | msg = ''.join(L) 62 | sys.stdout.write(msg) 63 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 64 | sys.stdout.write(' ') 65 | 66 | # Go back to the center of the bar. 67 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)): 68 | sys.stdout.write('\b') 69 | sys.stdout.write(' %d/%d ' % (current+1, total)) 70 | 71 | if current < total-1: 72 | sys.stdout.write('\r') 73 | else: 74 | sys.stdout.write('\n') 75 | sys.stdout.flush() 76 | 77 | 78 | def format_time(seconds): 79 | days = int(seconds / 3600/24) 80 | seconds = seconds - days*3600*24 81 | hours = int(seconds / 3600) 82 | seconds = seconds - hours*3600 83 | minutes = int(seconds / 60) 84 | seconds = seconds - minutes*60 85 | secondsf = int(seconds) 86 | seconds = seconds - secondsf 87 | millis = int(seconds*1000) 88 | 89 | f = '' 90 | i = 1 91 | if days > 0: 92 | f += str(days) + 'D' 93 | i += 1 94 | if hours > 0 and i <= 2: 95 | f += str(hours) + 'h' 96 | i += 1 97 | if minutes > 0 and i <= 2: 98 | f += str(minutes) + 'm' 99 | i += 1 100 | if secondsf > 0 and i <= 2: 101 | f += str(secondsf) + 's' 102 | i += 1 103 | if millis > 0 and i <= 2: 104 | f += str(millis) + 'ms' 105 | i += 1 106 | if f == '': 107 | f = '0ms' 108 | return f --------------------------------------------------------------------------------
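
For orientation, here is a minimal sketch (not part of the repository) of how the pieces listed above fit together: it loads the `data.pkl` written by `preprocess.py`, batches the validation split with `utils.BiDataset` and `utils.padding`, and maps a target index sequence back to label strings with the `Dict` helper. The data path and batch size are illustrative assumptions, and it assumes you run it from the repo root so the `utils` package is importable.

```python
# Minimal usage sketch, assuming ./data/save_data/ was the -save_data
# directory passed to preprocess.py (illustrative path, not fixed by the code).
import pickle

import torch.utils.data

import utils  # repo package: data_helper, dict_helper, metrics, misc_utils

# data.pkl holds {'train', 'valid', 'test', 'dict'} as written by preprocess.py
data = pickle.load(open('./data/save_data/data.pkl', 'rb'))
tgt_vocab = data['dict']['tgt']  # a utils.Dict instance

# Wrap the validation split and batch it with the same collate_fn as train.py
validset = utils.BiDataset(data['valid'])
validloader = torch.utils.data.DataLoader(dataset=validset,
                                          batch_size=8,  # illustrative value
                                          shuffle=False,
                                          num_workers=0,
                                          collate_fn=utils.padding)

src, tgt, src_len, tgt_len, original_src, original_tgt = next(iter(validloader))

# Drop the leading BOS id and convert ids back to label strings, stopping at EOS,
# mirroring what eval_model does with sampled model outputs.
labels = tgt_vocab.convertToLabels(tgt[0, 1:].tolist(), utils.EOS)
print(original_tgt[0], labels)
```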