├── semantic-unit-based
│   ├── utils
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-35.pyc
│   │   │   ├── metrics.cpython-35.pyc
│   │   │   ├── misc_utils.cpython-35.pyc
│   │   │   ├── data_helper.cpython-35.pyc
│   │   │   └── dict_helper.cpython-35.pyc
│   │   ├── metrics.py
│   │   ├── misc_utils.py
│   │   ├── data_helper.py
│   │   └── dict_helper.py
│   ├── models
│   │   ├── __pycache__
│   │   │   ├── beam.cpython-35.pyc
│   │   │   ├── loss.cpython-35.pyc
│   │   │   ├── rnn.cpython-35.pyc
│   │   │   ├── optims.cpython-35.pyc
│   │   │   ├── seq2seq.cpython-35.pyc
│   │   │   ├── __init__.cpython-35.pyc
│   │   │   └── attention.cpython-35.pyc
│   │   ├── __init__.py
│   │   ├── loss.py
│   │   ├── optims.py
│   │   ├── attention.py
│   │   ├── beam.py
│   │   ├── rnn.py
│   │   └── seq2seq.py
│   ├── config.yaml
│   ├── opts.py
│   ├── preprocess.py
│   ├── train.py
│   └── lr_scheduler.py
└── README.md

/semantic-unit-based/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_helper import *
2 | from .dict_helper import *
3 | from .misc_utils import *
4 | from .metrics import *
5 | 
--------------------------------------------------------------------------------
/semantic-unit-based/models/__pycache__/beam.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lancopku/SU4MLC/HEAD/semantic-unit-based/models/__pycache__/beam.cpython-35.pyc
--------------------------------------------------------------------------------
/semantic-unit-based/models/__pycache__/loss.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lancopku/SU4MLC/HEAD/semantic-unit-based/models/__pycache__/loss.cpython-35.pyc
--------------------------------------------------------------------------------
/semantic-unit-based/models/__pycache__/rnn.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lancopku/SU4MLC/HEAD/semantic-unit-based/models/__pycache__/rnn.cpython-35.pyc
--------------------------------------------------------------------------------
/semantic-unit-based/models/__pycache__/optims.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lancopku/SU4MLC/HEAD/semantic-unit-based/models/__pycache__/optims.cpython-35.pyc
--------------------------------------------------------------------------------
/semantic-unit-based/models/__pycache__/seq2seq.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lancopku/SU4MLC/HEAD/semantic-unit-based/models/__pycache__/seq2seq.cpython-35.pyc
--------------------------------------------------------------------------------
/semantic-unit-based/utils/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lancopku/SU4MLC/HEAD/semantic-unit-based/utils/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/semantic-unit-based/utils/__pycache__/metrics.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lancopku/SU4MLC/HEAD/semantic-unit-based/utils/__pycache__/metrics.cpython-35.pyc
--------------------------------------------------------------------------------
/semantic-unit-based/models/__pycache__/__init__.cpython-35.pyc:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/lancopku/SU4MLC/HEAD/semantic-unit-based/models/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /semantic-unit-based/models/__pycache__/attention.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lancopku/SU4MLC/HEAD/semantic-unit-based/models/__pycache__/attention.cpython-35.pyc -------------------------------------------------------------------------------- /semantic-unit-based/utils/__pycache__/misc_utils.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lancopku/SU4MLC/HEAD/semantic-unit-based/utils/__pycache__/misc_utils.cpython-35.pyc -------------------------------------------------------------------------------- /semantic-unit-based/utils/__pycache__/data_helper.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lancopku/SU4MLC/HEAD/semantic-unit-based/utils/__pycache__/data_helper.cpython-35.pyc -------------------------------------------------------------------------------- /semantic-unit-based/utils/__pycache__/dict_helper.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lancopku/SU4MLC/HEAD/semantic-unit-based/utils/__pycache__/dict_helper.cpython-35.pyc -------------------------------------------------------------------------------- /semantic-unit-based/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .attention import * 2 | from .loss import * 3 | from .optims import * 4 | from .rnn import * 5 | from .seq2seq import * 6 | from .beam import * 7 | -------------------------------------------------------------------------------- /semantic-unit-based/models/loss.py: -------------------------------------------------------------------------------- 1 | import utils 2 | 3 | def cross_entropy_loss(scores, targets, criterion, config): 4 | loss = criterion(scores, targets.view(-1)) 5 | pred = scores.max(1)[1] 6 | # num_correct = pred.data.eq(targets.data).masked_select(targets.ne(utils.PAD).data).sum() 7 | num_correct = pred.eq(targets).masked_select(targets.ne(utils.PAD)).sum() 8 | # num_total = targets.ne(utils.PAD).data.sum() 9 | num_total = targets.ne(utils.PAD).sum() 10 | 11 | loss = loss / num_total 12 | 13 | return loss, num_total, num_correct 14 | -------------------------------------------------------------------------------- /semantic-unit-based/config.yaml: -------------------------------------------------------------------------------- 1 | data: '/home/linjunyang/multilabel_rcv/data/' 2 | logF: 'experiments/rcv/' 3 | epoch: 20 4 | batch_size: 64 5 | optim: 'adam' 6 | cell: 'lstm' 7 | attention: 'luong_gate' 8 | learning_rate: 0.0003 9 | max_grad_norm: 10 10 | learning_rate_decay: 0.5 11 | start_decay_at: 2 12 | emb_size: 512 13 | hidden_size: 512 14 | dec_num_layers: 2 15 | enc_num_layers: 2 16 | bidirectional: True 17 | dropout: 0.2 18 | max_time_step: 30 19 | eval_interval: 500 20 | save_interval: 1000 21 | unk: False 22 | schedule: False 23 | schesamp: False 24 | length_norm: True 25 | metrics: ['hamming_loss', 'macro_f1', 'micro_f1'] 26 | shared_vocab: False 27 | beam_size: 5 28 | dilated: True 
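
The file above is consumed by `utils.misc_utils.read_config` (shown later in this dump), which wraps the parsed YAML in an `AttrDict` so keys read as attributes. A minimal sketch of that loading path — `yaml.safe_load` is substituted here for the repo's bare `yaml.load`, which newer PyYAML versions reject without a `Loader` argument:

```
import yaml

class AttrDict(dict):
    # dict whose keys double as attributes, as in utils/misc_utils.py
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self

def read_config(path):
    with open(path, 'r') as f:
        return AttrDict(yaml.safe_load(f))

config = read_config('config.yaml')
print(config.cell, config.hidden_size)  # lstm 512
```

`opts.convert_to_config` later copies any command-line option missing from the YAML into the same `config` object, so downstream code can treat the two sources uniformly.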
--------------------------------------------------------------------------------
/semantic-unit-based/models/optims.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 | from torch.nn.utils import clip_grad_norm_
3 | 
4 | 
5 | class Optim(object):
6 | 
7 |     def set_parameters(self, params):
8 |         self.params = list(params)  # careful: params may be a generator
9 |         if self.method == 'sgd':
10 |             self.optimizer = optim.SGD(self.params, lr=self.lr)
11 |         elif self.method == 'adagrad':
12 |             self.optimizer = optim.Adagrad(self.params, lr=self.lr)
13 |         elif self.method == 'adadelta':
14 |             self.optimizer = optim.Adadelta(self.params, lr=self.lr)
15 |         elif self.method == 'adam':
16 |             self.optimizer = optim.Adam(self.params, lr=self.lr)
17 |         else:
18 |             raise RuntimeError("Invalid optim method: " + self.method)
19 | 
20 |     def __init__(self, method, lr, max_grad_norm, lr_decay=1, start_decay_at=None, max_decay_times=2):
21 |         self.last_score = None
22 |         self.decay_times = 0
23 |         self.max_decay_times = max_decay_times
24 |         self.lr = lr
25 |         self.max_grad_norm = max_grad_norm
26 |         self.method = method
27 |         self.lr_decay = lr_decay
28 |         self.start_decay_at = start_decay_at
29 |         self.start_decay = False
30 | 
31 |     def step(self):
32 |         # Clip the gradient norm before the update.
33 |         if self.max_grad_norm:
34 |             clip_grad_norm_(self.params, self.max_grad_norm)
35 |         self.optimizer.step()
36 | 
37 |     # decay the learning rate on every call once the start_decay_at epoch is reached
38 |     def updateLearningRate(self, score, epoch):
39 |         if self.start_decay_at is not None and epoch >= self.start_decay_at:
40 |             self.start_decay = True
41 | 
42 |         if self.start_decay:
43 |             self.lr = self.lr * self.lr_decay
44 |             print("Decaying learning rate to %g" % self.lr)
45 | 
46 |         self.last_score = score
47 |         self.optimizer.param_groups[0]['lr'] = self.lr
48 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Semantic-Unit-for-Multi-label-Text-Classification
2 | Code for the article "Semantic-Unit-Based Dilated Convolution for Multi-Label Text Classification" (EMNLP 2018).
3 | 
4 | ***********************************************************
5 | 
6 | ## Requirements
7 | * Ubuntu 16.04
8 | * Python 3.5
9 | * PyTorch 0.4.1 (updated)
10 | 
11 | **************************************************************
12 | 
13 | ## Data
14 | Our preprocessed RCV1-V2 dataset can be retrieved through [this link](https://drive.google.com/open?id=1oQ5_gPoRwAl7UGWTDNu4qATNtJ1l1kXd). (The JSON file of the label set for evaluation is included for convenience.)
15 | 
16 | ***************************************************************
17 | 
18 | ## Preprocessing
19 | ```
20 | python3 preprocess.py -load_data path_to_data -save_data path_to_store_data (-src_filter 500)
21 | ```
22 | Put the data (plain text files) into one folder, name them *train.src*, *train.tgt*, *valid.src*, *valid.tgt*, *test.src* and *test.tgt*, and create a new subfolder named *data* inside it.
23 | 
24 | ***************************************************************
25 | 
26 | ## Training
27 | ```
28 | python3 train.py -log log_name -config config_yaml -gpus id (-label_dict_file path_to_your_label_set)
29 | ```
30 | Create your own YAML file for the hyperparameter settings, for example as sketched below.
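The shipped *config.yaml* is a good template: copy it, keep the remaining keys, and adjust at least the entries below (the paths are placeholders for your own setup):

```
data: '/path/to/your/preprocessed/data/'   # folder containing data.pkl
logF: 'experiments/your_run/'              # where logs and checkpoints go
epoch: 20
batch_size: 64
learning_rate: 0.0003
```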
31 | 
32 | ****************************************************************
33 | 
34 | ## Evaluation
35 | ```
36 | python3 train.py -log log_name -config config_yaml -gpus id -restore checkpoint -mode eval
37 | ```
38 | 
39 | *******************************************************************
40 | 
41 | # Citation
42 | If you use this code for your research, please kindly cite our paper:
43 | ```
44 | @inproceedings{DBLP:conf/emnlp/LinSYM018,
45 |   author    = {Junyang Lin and
46 |                Qi Su and
47 |                Pengcheng Yang and
48 |                Shuming Ma and
49 |                Xu Sun},
50 |   title     = {Semantic-Unit-Based Dilated Convolution for Multi-Label Text Classification},
51 |   booktitle = {Proceedings of the 2018 Conference on Empirical Methods in Natural
52 |                Language Processing, Brussels, Belgium, October 31 - November 4, 2018},
53 |   pages     = {4554--4564},
54 |   year      = {2018}
55 | }
56 | ```
57 | 
58 | 
--------------------------------------------------------------------------------
/semantic-unit-based/opts.py:
--------------------------------------------------------------------------------
1 | def model_opts(parser):
2 | 
3 |     parser.add_argument('-config', default='default.yaml', type=str,
4 |                         help="config file")
5 |     parser.add_argument('-gpus', default=[], nargs='+', type=int,
6 |                         help="Use CUDA on the listed devices.")
7 |     parser.add_argument('-restore', default='', type=str,
8 |                         help="restore checkpoint")
9 |     parser.add_argument('-seed', type=int, default=1234,
10 |                         help="Random seed")
11 |     parser.add_argument('-model', default='seq2seq', type=str,
12 |                         help="Model selection")
13 |     parser.add_argument('-mode', default='train', type=str,
14 |                         help="Mode selection")
15 |     parser.add_argument('-module', default='seq2seq', type=str,
16 |                         help="Module selection")
17 |     parser.add_argument('-log', default='', type=str,
18 |                         help="log directory")
19 |     parser.add_argument('-num_processes', type=int, default=4,
20 |                         help="number of processes")
21 |     parser.add_argument('-refF', default='', type=str,
22 |                         help="reference file")
23 |     parser.add_argument('-unk', action='store_true', help='replace unk')
24 |     parser.add_argument('-char', action='store_true', help='char level decoding')
25 |     parser.add_argument('-length_norm', action='store_true', help='length normalization for beam search')
26 |     parser.add_argument('-pool_size', type=int, default=0, help="pool size of maxout layer")
27 |     parser.add_argument('-scale', type=float, default=1, help="proportion of the training set")
28 |     parser.add_argument('-max_split', type=int, default=0, help="max generator time steps for memory efficiency")
29 |     parser.add_argument('-split_num', type=int, default=0, help="split number for splitres")
30 |     parser.add_argument('-pretrain', default='', type=str, help="load pretrain encoder")
31 |     parser.add_argument('-label_dict_file', default='~/multilabel_rcv/topic_sorted.json', type=str,
32 |                         help="label_dict")
33 | 
34 | 
35 | def convert_to_config(opt, config):
36 |     opt = vars(opt)
37 |     for key in opt:
38 |         if key not in config:
39 |             config[key] = opt[key]
--------------------------------------------------------------------------------
/semantic-unit-based/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import os
2 | import csv
3 | import codecs
4 | import yaml
5 | import time
6 | import numpy as np
7 | 
8 | from sklearn import metrics
9 | 
10 | 
11 | def eval_metrics(reference, candidate, label_dict, log_path):
12 |     ref_dir = log_path + 'reference/'
13 |     cand_dir = log_path + 'candidate/'
14 |     if not os.path.exists(ref_dir):
15 | 
os.mkdir(ref_dir) 16 | if not os.path.exists(cand_dir): 17 | os.mkdir(cand_dir) 18 | ref_file = ref_dir+'reference' 19 | cand_file = cand_dir+'candidate' 20 | 21 | for i in range(len(reference)): 22 | with codecs.open(ref_file+str(i), 'w', 'utf-8') as f: 23 | f.write("".join(reference[i])+'\n') 24 | with codecs.open(cand_file+str(i), 'w', 'utf-8') as f: 25 | f.write("".join(candidate[i])+'\n') 26 | 27 | def make_label(l, label_dict): 28 | length = len(label_dict) 29 | result = np.zeros(length) 30 | indices = [label_dict.get(label.strip().upper(), 0) for label in l] 31 | result[indices] = 1 32 | return result 33 | 34 | def prepare_label(y_list, y_pre_list, label_dict): 35 | reference = np.array([make_label(y, label_dict) for y in y_list]) 36 | candidate = np.array([make_label(y_pre, label_dict) for y_pre in y_pre_list]) 37 | return reference, candidate 38 | 39 | def get_one_error(y, candidate, label_dict): 40 | idx = [label_dict.get(c[0].strip().upper(), 0) for c in candidate] 41 | result = [(y[i, idx[i]] == 1) for i in range(len(idx))] 42 | return (1 - np.array(result).mean()) 43 | 44 | def get_metrics(y, y_pre): 45 | hamming_loss = metrics.hamming_loss(y, y_pre) 46 | macro_f1 = metrics.f1_score(y, y_pre, average='macro') 47 | macro_precision = metrics.precision_score(y, y_pre, average='macro') 48 | macro_recall = metrics.recall_score(y, y_pre, average='macro') 49 | micro_f1 = metrics.f1_score(y, y_pre, average='micro') 50 | micro_precision = metrics.precision_score(y, y_pre, average='micro') 51 | micro_recall = metrics.recall_score(y, y_pre, average='micro') 52 | return hamming_loss, macro_f1, macro_precision, macro_recall, micro_f1, micro_precision, micro_recall 53 | 54 | y, y_pre = prepare_label(reference, candidate, label_dict) 55 | # one_error = get_one_error(y, candidate, label_dict) 56 | hamming_loss, macro_f1, macro_precision, macro_recall, micro_f1, micro_precision, micro_recall = get_metrics(y, y_pre) 57 | return {'hamming_loss': hamming_loss, 58 | 'macro_f1': macro_f1, 59 | 'macro_precision': macro_precision, 60 | 'macro_recall': macro_recall, 61 | 'micro_f1': micro_f1, 62 | 'micro_precision': micro_precision, 63 | 'micro_recall': micro_recall} -------------------------------------------------------------------------------- /semantic-unit-based/utils/misc_utils.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import os 3 | import time 4 | import sys 5 | 6 | class AttrDict(dict): 7 | def __init__(self, *args, **kwargs): 8 | super(AttrDict, self).__init__(*args, **kwargs) 9 | self.__dict__ = self 10 | 11 | 12 | def read_config(path): 13 | return AttrDict(yaml.load(open(path, 'r'))) 14 | 15 | 16 | def print_log(file): 17 | def write_log(s): 18 | print(s, end='') 19 | with open(file, 'a') as f: 20 | f.write(s) 21 | return write_log 22 | 23 | 24 | 25 | _, term_width = os.popen('stty size', 'r').read().split() 26 | term_width = int(term_width) 27 | 28 | TOTAL_BAR_LENGTH = 86. 29 | last_time = time.time() 30 | begin_time = last_time 31 | def progress_bar(current, total, msg=None): 32 | global last_time, begin_time 33 | current = current % total 34 | if current == 0: 35 | begin_time = time.time() # Reset for new bar. 
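# The bar rendered below is TOTAL_BAR_LENGTH characters wide: cur_len '='
# marks for the finished fraction, a single '>' head, then '.' padding for
# the rest; after the timing message is printed, the cursor is backspaced
# so the "current/total" counter lands in the middle of the bar.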
36 | 37 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 38 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 39 | 40 | sys.stdout.write(' [') 41 | for i in range(cur_len): 42 | sys.stdout.write('=') 43 | sys.stdout.write('>') 44 | for i in range(rest_len): 45 | sys.stdout.write('.') 46 | sys.stdout.write(']') 47 | 48 | cur_time = time.time() 49 | step_time = cur_time - last_time 50 | last_time = cur_time 51 | tot_time = cur_time - begin_time 52 | 53 | L = [] 54 | L.append(' Step: %s' % format_time(step_time)) 55 | L.append(' | Tot: %s' % format_time(tot_time)) 56 | if msg: 57 | L.append(' | ' + msg) 58 | 59 | msg = ''.join(L) 60 | sys.stdout.write(msg) 61 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 62 | sys.stdout.write(' ') 63 | 64 | # Go back to the center of the bar. 65 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)): 66 | sys.stdout.write('\b') 67 | sys.stdout.write(' %d/%d ' % (current+1, total)) 68 | 69 | if current < total-1: 70 | sys.stdout.write('\r') 71 | else: 72 | sys.stdout.write('\n') 73 | sys.stdout.flush() 74 | 75 | def format_time(seconds): 76 | days = int(seconds / 3600/24) 77 | seconds = seconds - days*3600*24 78 | hours = int(seconds / 3600) 79 | seconds = seconds - hours*3600 80 | minutes = int(seconds / 60) 81 | seconds = seconds - minutes*60 82 | secondsf = int(seconds) 83 | seconds = seconds - secondsf 84 | millis = int(seconds*1000) 85 | 86 | f = '' 87 | i = 1 88 | if days > 0: 89 | f += str(days) + 'D' 90 | i += 1 91 | if hours > 0 and i <= 2: 92 | f += str(hours) + 'h' 93 | i += 1 94 | if minutes > 0 and i <= 2: 95 | f += str(minutes) + 'm' 96 | i += 1 97 | if secondsf > 0 and i <= 2: 98 | f += str(secondsf) + 's' 99 | i += 1 100 | if millis > 0 and i <= 2: 101 | f += str(millis) + 'ms' 102 | i += 1 103 | if f == '': 104 | f = '0ms' 105 | return f -------------------------------------------------------------------------------- /semantic-unit-based/models/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | 5 | 6 | class luong_attention(nn.Module): 7 | 8 | def __init__(self, hidden_size, emb_size, pool_size=0): 9 | super(luong_attention, self).__init__() 10 | self.hidden_size, self.emb_size, self.pool_size = hidden_size, emb_size, pool_size 11 | self.linear_in = nn.Linear(hidden_size, hidden_size) 12 | if pool_size > 0: 13 | self.linear_out = maxout(2*hidden_size + emb_size, hidden_size, pool_size) 14 | else: 15 | self.linear_out = nn.Sequential(nn.Linear(2*hidden_size + emb_size, hidden_size), nn.SELU(), nn.Linear(hidden_size, hidden_size), nn.Tanh()) 16 | self.softmax = nn.Softmax(dim=1) 17 | 18 | def init_context(self, context): 19 | self.context = context.transpose(0, 1) 20 | 21 | def forward(self, h, x): 22 | gamma_h = self.linear_in(h).unsqueeze(2) # batch * size * 1 23 | weights = torch.bmm(self.context, gamma_h).squeeze(2) # batch * time 24 | weights = self.softmax(weights) # batch * time 25 | c_t = torch.bmm(weights.unsqueeze(1), self.context).squeeze(1) # batch * size 26 | output = self.linear_out(torch.cat([c_t, h, x], 1)) 27 | 28 | return output, weights 29 | 30 | 31 | class luong_gate_attention(nn.Module): 32 | 33 | def __init__(self, hidden_size, emb_size, prob=0.2): 34 | super(luong_gate_attention, self).__init__() 35 | self.hidden_size, self.emb_size = hidden_size, emb_size 36 | self.linear_in_conv = nn.Linear(hidden_size, hidden_size) 37 | self.linear_out_conv = nn.Linear(2*hidden_size, hidden_size) 38 | 
self.selu_out_conv = nn.Sequential(nn.SELU()) 39 | self.linear_in = nn.Linear(hidden_size, hidden_size) 40 | self.linear_out = nn.Linear(2*hidden_size, hidden_size) 41 | self.selu_out = nn.Sequential(nn.SELU()) 42 | self.softmax = nn.Softmax(dim=1) 43 | 44 | def init_context(self, context): 45 | self.context = context.transpose(0, 1) 46 | 47 | def forward(self, h, conv): 48 | gamma_h = self.linear_in_conv(h).unsqueeze(2) 49 | weights = torch.bmm(conv, gamma_h).squeeze(2) 50 | weights = self.softmax(weights) 51 | c_t_conv = torch.bmm(weights.unsqueeze(1), conv).squeeze(1) 52 | output_conv = self.selu_out_conv(self.linear_out_conv(torch.cat((h, c_t_conv), -1))) 53 | 54 | gamma_h = self.linear_in(output_conv).unsqueeze(2) 55 | weights = torch.bmm(self.context, gamma_h).squeeze(2) 56 | weights = self.softmax(weights) 57 | c_t = torch.bmm(weights.unsqueeze(1), self.context).squeeze(1) 58 | output = self.selu_out(self.linear_out(torch.cat([output_conv, c_t], -1))) 59 | 60 | output = output_conv + output 61 | 62 | return output, weights 63 | 64 | 65 | class bahdanau_attention(nn.Module): 66 | 67 | def __init__(self, hidden_size, emb_size, pool_size=0): 68 | super(bahdanau_attention, self).__init__() 69 | self.linear_encoder = nn.Linear(hidden_size, hidden_size) 70 | self.linear_decoder = nn.Linear(hidden_size, hidden_size) 71 | self.linear_v = nn.Linear(hidden_size, 1) 72 | self.linear_r = nn.Linear(hidden_size*2+emb_size, hidden_size*2) 73 | self.hidden_size = hidden_size 74 | self.emb_size = emb_size 75 | self.softmax = nn.Softmax(dim=1) 76 | self.tanh = nn.Tanh() 77 | 78 | def init_context(self, context): 79 | self.context = context.transpose(0, 1) 80 | 81 | def forward(self, h, x): 82 | gamma_encoder = self.linear_encoder(self.context) # batch * time * size 83 | gamma_decoder = self.linear_decoder(h).unsqueeze(1) # batch * 1 * size 84 | weights = self.linear_v(self.tanh(gamma_encoder+gamma_decoder)).squeeze(2) # batch * time 85 | weights = self.softmax(weights) # batch * time 86 | c_t = torch.bmm(weights.unsqueeze(1), self.context).squeeze(1) # batch * size 87 | r_t = self.linear_r(torch.cat([c_t, h, x], dim=1)) 88 | output = r_t.view(-1, self.hidden_size, 2).max(2)[0] 89 | 90 | return output, weights 91 | 92 | 93 | class maxout(nn.Module): 94 | 95 | def __init__(self, in_feature, out_feature, pool_size): 96 | super(maxout, self).__init__() 97 | self.in_feature = in_feature 98 | self.out_feature = out_feature 99 | self.pool_size = pool_size 100 | self.linear = nn.Linear(in_feature, out_feature*pool_size) 101 | 102 | def forward(self, x): 103 | output = self.linear(x) 104 | output = output.view(-1, self.out_feature, self.pool_size) 105 | output = output.max(2)[0] 106 | 107 | return output 108 | -------------------------------------------------------------------------------- /semantic-unit-based/utils/data_helper.py: -------------------------------------------------------------------------------- 1 | import linecache 2 | import torch 3 | import torch.utils.data as torch_data 4 | from random import Random 5 | import utils 6 | 7 | num_samples = 1 8 | 9 | 10 | class MonoDataset(torch_data.Dataset): 11 | 12 | def __init__(self, infos, indexes=None): 13 | 14 | self.srcF = infos['srcF'] 15 | self.original_srcF = infos['original_srcF'] 16 | self.length = infos['length'] 17 | self.infos = infos 18 | if indexes is None: 19 | self.indexes = list(range(self.length)) 20 | else: 21 | self.indexes = indexes 22 | 23 | def __getitem__(self, index): 24 | index = self.indexes[index] 25 | src = list(map(int, 
linecache.getline(self.srcF, index+1).strip().split())) 26 | original_src = linecache.getline(self.original_srcF, index+1).strip().split() 27 | 28 | return src, original_src 29 | 30 | def __len__(self): 31 | return len(self.indexes) 32 | 33 | 34 | class BiDataset(torch_data.Dataset): 35 | 36 | def __init__(self, infos, indexes=None, char=False): 37 | 38 | self.srcF = infos['srcF'] 39 | self.tgtF = infos['tgtF'] 40 | self.original_srcF = infos['original_srcF'] 41 | self.original_tgtF = infos['original_tgtF'] 42 | self.length = infos['length'] 43 | self.infos = infos 44 | self.char = char 45 | if indexes is None: 46 | self.indexes = list(range(self.length)) 47 | else: 48 | self.indexes = indexes 49 | 50 | def __getitem__(self, index): 51 | index = self.indexes[index] 52 | src = list(map(int, linecache.getline(self.srcF, index+1).strip().split())) 53 | tgt = list(map(int, linecache.getline(self.tgtF, index+1).strip().split())) 54 | original_src = linecache.getline(self.original_srcF, index+1).strip().split() 55 | original_tgt = linecache.getline(self.original_tgtF, index+1).strip().split() if not self.char else \ 56 | list(linecache.getline(self.original_tgtF, index + 1).strip()) 57 | 58 | return src, tgt, original_src, original_tgt 59 | 60 | def __len__(self): 61 | return len(self.indexes) 62 | 63 | 64 | def splitDataset(data_set, sizes): 65 | length = len(data_set) 66 | indexes = list(range(length)) 67 | rng = Random() 68 | rng.seed(1234) 69 | rng.shuffle(indexes) 70 | 71 | data_sets = [] 72 | part_len = int(length / sizes) 73 | for i in range(sizes-1): 74 | data_sets.append(BiDataset(data_set.infos, indexes[0:part_len])) 75 | indexes = indexes[part_len:] 76 | data_sets.append(BiDataset(data_set.infos, indexes)) 77 | return data_sets 78 | 79 | 80 | def padding(data): 81 | src, tgt, original_src, original_tgt = zip(*data) 82 | 83 | src_len = [len(s) for s in src] 84 | src_pad = torch.zeros(len(src), max(src_len)).long() 85 | for i, s in enumerate(src): 86 | end = src_len[i] 87 | src_pad[i, :end] = torch.LongTensor(s[end-1::-1]) 88 | 89 | tgt_len = [len(s) for s in tgt] 90 | tgt_pad = torch.zeros(len(tgt), max(tgt_len)).long() 91 | for i, s in enumerate(tgt): 92 | end = tgt_len[i] 93 | tgt_pad[i, :end] = torch.LongTensor(s)[:end] 94 | 95 | return src_pad, tgt_pad, \ 96 | torch.LongTensor(src_len), torch.LongTensor(tgt_len), \ 97 | original_src, original_tgt 98 | 99 | 100 | def ae_padding(data): 101 | src, tgt, original_src, original_tgt = zip(*data) 102 | 103 | src_len = [len(s) for s in src] 104 | src_pad = torch.zeros(len(src), max(src_len)).long() 105 | for i, s in enumerate(src): 106 | end = src_len[i] 107 | src_pad[i, :end] = torch.LongTensor(s)[:end] 108 | 109 | tgt_len = [len(s) for s in tgt] 110 | tgt_pad = torch.zeros(len(tgt), max(tgt_len)).long() 111 | for i, s in enumerate(tgt): 112 | end = tgt_len[i] 113 | tgt_pad[i, :end] = torch.LongTensor(s)[:end] 114 | 115 | ae_len = [len(s)+2 for s in src] 116 | ae_pad = torch.zeros(len(src), max(ae_len)).long() 117 | for i, s in enumerate(src): 118 | end = ae_len[i] 119 | ae_pad[i, 0] = utils.BOS 120 | ae_pad[i, 1:end-1] = torch.LongTensor(s)[:end-2] 121 | ae_pad[i, end-1] = utils.EOS 122 | 123 | return src_pad, tgt_pad, ae_pad, \ 124 | torch.LongTensor(src_len), torch.LongTensor(tgt_len), torch.LongTensor(ae_len), \ 125 | original_src, original_tgt 126 | 127 | 128 | def split_padding(data): 129 | src, tgt, original_src, original_tgt = zip(*data) 130 | 131 | split_samples = [] 132 | num_per_sample = int(len(src) / utils.num_samples) 133 | 
134 | for i in range(utils.num_samples): 135 | split_src = src[i*num_per_sample:(i+1)*num_per_sample] 136 | split_tgt = tgt[i*num_per_sample:(i+1)*num_per_sample] 137 | split_original_src = original_src[i * num_per_sample:(i + 1) * num_per_sample] 138 | split_original_tgt = original_tgt[i * num_per_sample:(i + 1) * num_per_sample] 139 | 140 | src_len = [len(s) for s in split_src] 141 | src_pad = torch.zeros(len(split_src), max(src_len)).long() 142 | for i, s in enumerate(split_src): 143 | end = src_len[i] 144 | src_pad[i, :end] = torch.LongTensor(s)[:end] 145 | 146 | tgt_len = [len(s) for s in split_tgt] 147 | tgt_pad = torch.zeros(len(split_tgt), max(tgt_len)).long() 148 | for i, s in enumerate(split_tgt): 149 | end = tgt_len[i] 150 | tgt_pad[i, :end] = torch.LongTensor(s)[:end] 151 | 152 | split_samples.append([src_pad, tgt_pad, 153 | torch.LongTensor(src_len), torch.LongTensor(tgt_len), 154 | split_original_src, split_original_tgt]) 155 | 156 | return split_samples -------------------------------------------------------------------------------- /semantic-unit-based/models/beam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import utils 3 | 4 | class Beam(object): 5 | def __init__(self, size, n_best=1, cuda=True, length_norm=False, minimum_length=0): 6 | 7 | self.size = size 8 | self.tt = torch.cuda if cuda else torch 9 | 10 | # The score for each translation on the beam. 11 | self.scores = self.tt.FloatTensor(size).zero_() 12 | self.allScores = [] 13 | 14 | # The backpointers at each time-step. 15 | self.prevKs = [] 16 | 17 | # The outputs at each time-step. 18 | self.nextYs = [self.tt.LongTensor(size) 19 | .fill_(utils.EOS)] 20 | self.nextYs[0][0] = utils.BOS 21 | 22 | # Has EOS topped the beam yet. 23 | self._eos = utils.EOS 24 | self.eosTop = False 25 | 26 | # The attentions (matrix) for each time. 27 | self.attn = [] 28 | 29 | # Time and k pair for finished. 30 | self.finished = [] 31 | self.n_best = n_best 32 | 33 | self.length_norm = length_norm 34 | self.minimum_length = minimum_length 35 | 36 | 37 | def getCurrentState(self): 38 | "Get the outputs for the current timestep." 39 | return self.nextYs[-1] 40 | 41 | def getCurrentOrigin(self): 42 | "Get the backpointers for the current timestep." 43 | return self.prevKs[-1] 44 | 45 | def advance(self, wordLk, attnOut): 46 | """ 47 | Given prob over words for every last beam `wordLk` and attention 48 | `attnOut`: Compute and update the beam search. 49 | Parameters: 50 | * `wordLk`- probs of advancing from the last step (K x words) 51 | * `attnOut`- attention at the last step 52 | Returns: True if beam search is complete. 53 | """ 54 | numWords = wordLk.size(1) 55 | 56 | # Sum the previous scores. 57 | if len(self.prevKs) > 0: 58 | beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk) 59 | 60 | # Don't let EOS have children. 
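# A beam whose newest token is EOS has its whole expansion row forced to
# -1e20 below, so finished hypotheses are never extended; the block that
# follows applies the same penalty to any hypothesis that would repeat a
# trigram it already contains (a fixed block_ngram_repeat of 3).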
61 | for i in range(self.nextYs[-1].size(0)): 62 | if self.nextYs[-1][i] == self._eos: 63 | beamLk[i] = -1e20 64 | ngrams = [] 65 | le = len(self.nextYs) 66 | for j in range(self.nextYs[-1].size(0)): 67 | hyp, _ = self.getHyp(le-1, j) 68 | ngrams = set() 69 | fail = False 70 | gram = [] 71 | for i in range(le-1): 72 | # last n tokens, n = block_ngram_repeat 73 | gram = (gram + [hyp[i]])[-3:] 74 | # skip the blocking if it is in the exclusion list 75 | #if set(gram) & self.exclusion_tokens: 76 | #continue 77 | if tuple(gram) in ngrams: 78 | fail = True 79 | ngrams.add(tuple(gram)) 80 | if fail: 81 | beamLk[j] = -1e20 82 | else: 83 | beamLk = wordLk[0] 84 | flatBeamLk = beamLk.view(-1) 85 | bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True) 86 | 87 | self.allScores.append(self.scores) 88 | self.scores = bestScores 89 | 90 | # bestScoresId is flattened beam x word array, so calculate which 91 | # word and beam each score came from 92 | prevK = bestScoresId / numWords 93 | self.prevKs.append(prevK) 94 | self.nextYs.append((bestScoresId - prevK * numWords)) 95 | self.attn.append(attnOut.index_select(0, prevK)) 96 | 97 | for i in range(self.nextYs[-1].size(0)): 98 | if self.nextYs[-1][i] == self._eos: 99 | s = self.scores[i] 100 | if self.length_norm: 101 | s /= len(self.nextYs) 102 | if len(self.nextYs) - 1 >= self.minimum_length: 103 | self.finished.append((s, len(self.nextYs) - 1, i)) 104 | 105 | # End condition is when top-of-beam is EOS and no global score. 106 | if self.nextYs[-1][0] == utils.EOS: 107 | self.allScores.append(self.scores) 108 | self.eosTop = True 109 | 110 | def done(self): 111 | return self.eosTop and len(self.finished) >= self.n_best 112 | 113 | def beam_update(self, state, idx): 114 | positions = self.getCurrentOrigin() 115 | for e in state: 116 | a, br, d = e.size() 117 | e = e.view(a, self.size, br // self.size, d) 118 | sentStates = e[:, :, idx] 119 | sentStates.copy_(sentStates.index_select(1, positions)) 120 | 121 | def beam_update_gru(self, state, idx): 122 | positions = self.getCurrentOrigin() 123 | for e in state: 124 | br, d = e.size() 125 | e = e.view(self.size, br // self.size, d) 126 | sentStates = e[:, idx] 127 | sentStates.copy_(sentStates.index_select(0, positions)) 128 | 129 | def beam_update_memory(self, state, idx): 130 | positions = self.getCurrentOrigin() 131 | e = state 132 | br, d = e.size() 133 | e = e.view(self.size, br // self.size, d) 134 | sentStates = e[:, idx] 135 | sentStates.copy_(sentStates.index_select(0, positions)) 136 | 137 | def sortFinished(self, minimum=None): 138 | if minimum is not None: 139 | i = 0 140 | # Add from beam until we have minimum outputs. 141 | while len(self.finished) < minimum: 142 | s = self.scores[i].item() 143 | self.finished.append((s, len(self.nextYs) - 1, i)) 144 | i += 1 145 | 146 | self.finished.sort(key=lambda a: -a[0]) 147 | scores = [sc for sc, _, _ in self.finished] 148 | ks = [(t, k) for _, t, k in self.finished] 149 | return scores, ks 150 | 151 | def getHyp(self, timestep, k): 152 | """ 153 | Walk back to construct the full hypothesis. 
154 |         """
155 |         hyp, attn = [], []
156 |         for j in range(len(self.prevKs[:timestep]) - 1, -1, -1):
157 |             hyp.append(self.nextYs[j+1][k].item())
158 |             attn.append(self.attn[j][k])
159 |             k = self.prevKs[j][k].item()
160 |         return hyp[::-1], torch.stack(attn[::-1])
161 | 
--------------------------------------------------------------------------------
/semantic-unit-based/utils/dict_helper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from collections import OrderedDict
3 | 
4 | PAD = 0
5 | UNK = 1
6 | BOS = 2
7 | EOS = 3
8 | 
9 | PAD_WORD = '<blank>'
10 | UNK_WORD = '<unk>'
11 | BOS_WORD = '<s>'
12 | EOS_WORD = '</s>'
13 | 
14 | 
15 | class Dict(object):
16 |     def __init__(self, data=None, lower=True):
17 |         self.idxToLabel = {}
18 |         self.labelToIdx = {}
19 |         self.frequencies = {}
20 |         self.lower = lower
21 |         # Special entries will not be pruned.
22 |         self.special = []
23 | 
24 |         if data is not None:
25 |             if type(data) == str:
26 |                 self.loadFile(data)
27 |             else:
28 |                 self.addSpecials(data)
29 | 
30 |     def size(self):
31 |         return len(self.idxToLabel)
32 | 
33 |     # Load entries from a file.
34 |     def loadFile(self, filename):
35 |         for line in open(filename):
36 |             fields = line.split()
37 |             label = fields[0]
38 |             idx = int(fields[1])
39 |             self.add(label, idx)
40 | 
41 |     # Write entries to a file.
42 |     def writeFile(self, filename):
43 |         with open(filename, 'w') as file:
44 |             for i in range(self.size()):
45 |                 label = self.idxToLabel[i]
46 |                 file.write('%s %d\n' % (label, i))
47 | 
48 |         file.close()
49 | 
50 |     def loadDict(self, idxToLabel):
51 |         for i in range(len(idxToLabel)):
52 |             label = idxToLabel[i]
53 |             self.add(label, i)
54 | 
55 |     def lookup(self, key, default=None):
56 |         key = key.lower() if self.lower else key
57 |         try:
58 |             return self.labelToIdx[key]
59 |         except KeyError:
60 |             return default
61 | 
62 |     def getLabel(self, idx, default=None):
63 |         try:
64 |             return self.idxToLabel[idx]
65 |         except KeyError:
66 |             return default
67 | 
68 |     # Mark this `label` and `idx` as special (i.e. will not be pruned).
69 |     def addSpecial(self, label, idx=None):
70 |         idx = self.add(label, idx)
71 |         self.special += [idx]
72 | 
73 |     # Mark all labels in `labels` as specials (i.e. will not be pruned).
74 |     def addSpecials(self, labels):
75 |         for label in labels:
76 |             self.addSpecial(label)
77 | 
78 |     # Add `label` in the dictionary. Use `idx` as its index if given.
79 |     def add(self, label, idx=None):
80 |         label = label.lower() if self.lower else label
81 |         if idx is not None:
82 |             self.idxToLabel[idx] = label
83 |             self.labelToIdx[label] = idx
84 |         else:
85 |             if label in self.labelToIdx:
86 |                 idx = self.labelToIdx[label]
87 |             else:
88 |                 idx = len(self.idxToLabel)
89 |                 self.idxToLabel[idx] = label
90 |                 self.labelToIdx[label] = idx
91 | 
92 |         if idx not in self.frequencies:
93 |             self.frequencies[idx] = 1
94 |         else:
95 |             self.frequencies[idx] += 1
96 | 
97 |         return idx
98 | 
99 |     # Return a new dictionary with the `size` most frequent entries.
100 |     def prune(self, size):
101 |         if size > self.size():
102 |             return self
103 | 
104 |         # Only keep the `size` most frequent entries.
105 |         freq = torch.tensor(
106 |             [self.frequencies[i] for i in range(len(self.frequencies))])
107 |         _, idx = torch.sort(freq, 0, True)
108 |         idx = idx.tolist()
109 | 
110 |         newDict = Dict()
111 |         newDict.lower = self.lower
112 | 
113 |         # Add special entries in all cases.
114 |         for i in self.special:
115 |             newDict.addSpecial(self.idxToLabel[i])
116 | 
117 |         for i in idx[:size]:
118 |             newDict.add(self.idxToLabel[i])
119 | 
120 |         return newDict
121 | 
122 |     # Convert `labels` to indices. Use `unkWord` if not found.
123 |     # Optionally insert `bosWord` at the beginning and `eosWord` at the end.
124 |     def convertToIdx(self, labels, unkWord, bosWord=None, eosWord=None):
125 |         vec = []
126 | 
127 |         if bosWord is not None:
128 |             vec += [self.lookup(bosWord)]
129 | 
130 |         unk = self.lookup(unkWord)
131 |         vec += [self.lookup(label, default=unk) for label in labels]
132 | 
133 |         if eosWord is not None:
134 |             vec += [self.lookup(eosWord)]
135 | 
136 |         return vec
137 | 
138 | 
139 |     def convertToIdxandOOVs(self, labels, unkWord, bosWord=None, eosWord=None):
140 |         vec = []
141 |         oovs = OrderedDict()
142 | 
143 |         if bosWord is not None:
144 |             vec += [self.lookup(bosWord)]
145 | 
146 |         unk = self.lookup(unkWord)
147 |         for label in labels:
148 |             id = self.lookup(label, default=unk)
149 |             if id != unk:
150 |                 vec += [id]
151 |             else:
152 |                 if label not in oovs:
153 |                     oovs[label] = len(oovs)+self.size()
154 |                 oov_num = oovs[label]
155 |                 vec += [oov_num]
156 | 
157 |         if eosWord is not None:
158 |             vec += [self.lookup(eosWord)]
159 | 
160 |         return torch.LongTensor(vec), oovs
161 | 
162 |     def convertToIdxwithOOVs(self, labels, unkWord, bosWord=None, eosWord=None, oovs=None):
163 |         vec = []
164 | 
165 |         if bosWord is not None:
166 |             vec += [self.lookup(bosWord)]
167 | 
168 |         unk = self.lookup(unkWord)
169 |         for label in labels:
170 |             id = self.lookup(label, default=unk)
171 |             if id == unk and label in oovs:
172 |                 vec += [oovs[label]]
173 |             else:
174 |                 vec += [id]
175 | 
176 |         if eosWord is not None:
177 |             vec += [self.lookup(eosWord)]
178 | 
179 |         return torch.LongTensor(vec)
180 | 
181 | 
182 |     # Convert `idx` to labels. If index `stop` is reached, convert it and return.
183 | def convertToLabels(self, idx, stop, oovs=None): 184 | labels = [] 185 | 186 | for i in idx: 187 | if i == stop: 188 | break 189 | if i < self.size(): 190 | labels += [self.getLabel(i)] 191 | else: 192 | labels += [oovs[i-self.size()]] 193 | 194 | return labels 195 | -------------------------------------------------------------------------------- /semantic-unit-based/models/rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | from torch.nn.utils.rnn import pack_padded_sequence as pack 5 | from torch.nn.utils.rnn import pad_packed_sequence as unpack 6 | import models 7 | 8 | 9 | class rnn_encoder(nn.Module): 10 | 11 | def __init__(self, config, embedding=None): 12 | super(rnn_encoder, self).__init__() 13 | 14 | self.embedding = embedding if embedding is not None else nn.Embedding(config.src_vocab_size, config.emb_size) 15 | self.hidden_size = config.hidden_size 16 | self.config = config 17 | if config.cell == 'gru': 18 | self.rnn = nn.GRU(input_size=config.emb_size, hidden_size=config.hidden_size, 19 | num_layers=config.enc_num_layers, dropout=config.dropout, 20 | bidirectional=config.bidirectional) 21 | else: 22 | self.rnn = nn.LSTM(input_size=config.emb_size, hidden_size=config.hidden_size, 23 | num_layers=config.enc_num_layers, dropout=config.dropout, 24 | bidirectional=config.bidirectional) 25 | 26 | self.dconv = nn.Sequential(nn.Conv1d(config.hidden_size, config.hidden_size, kernel_size=3, padding=1, dilation=1), 27 | nn.SELU(), nn.AlphaDropout(p=0.05), 28 | nn.Conv1d(config.hidden_size, config.hidden_size, kernel_size=3, padding=1, dilation=2), 29 | nn.SELU(), nn.AlphaDropout(p=0.05), 30 | nn.Conv1d(config.hidden_size, config.hidden_size, kernel_size=3, padding=1, dilation=3), 31 | nn.SELU(), nn.AlphaDropout(p=0.05)) 32 | self.linear = nn.Linear(2*config.hidden_size, 2*config.hidden_size) 33 | self.glu = nn.GLU() 34 | 35 | def forward(self, inputs, lengths): 36 | embs = pack(self.embedding(inputs), lengths) 37 | embeds = unpack(embs)[0] 38 | outputs, state = self.rnn(embeds) 39 | 40 | if self.config.bidirectional: 41 | outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:] 42 | 43 | conv = outputs.transpose(0,1).transpose(1,2) 44 | conv = self.dconv(conv) 45 | conv = conv.transpose(1,2) 46 | 47 | if self.config.cell == 'gru': 48 | state = state[:self.config.dec_num_layers] 49 | else: 50 | state = (state[0][::2], state[1][::2]) 51 | 52 | return outputs, state, conv 53 | 54 | 55 | class rnn_decoder(nn.Module): 56 | 57 | def __init__(self, config, embedding=None, use_attention=True): 58 | super(rnn_decoder, self).__init__() 59 | self.embedding = embedding if embedding is not None else nn.Embedding(config.tgt_vocab_size, config.emb_size) 60 | 61 | input_size = config.emb_size 62 | 63 | if config.cell == 'gru': 64 | self.rnn = StackedGRU(input_size=input_size, hidden_size=config.hidden_size, 65 | num_layers=config.dec_num_layers, dropout=config.dropout) 66 | else: 67 | self.rnn = StackedLSTM(input_size=input_size, hidden_size=config.hidden_size, 68 | num_layers=config.dec_num_layers, dropout=config.dropout) 69 | 70 | self.linear = nn.Linear(config.hidden_size, config.tgt_vocab_size) 71 | 72 | if not use_attention or config.attention == 'None': 73 | self.attention = None 74 | elif config.attention == 'bahdanau': 75 | self.attention = models.bahdanau_attention(config.hidden_size, config.emb_size, config.pool_size) 76 | elif config.attention == 
'luong': 77 | self.attention = models.luong_attention(config.hidden_size, config.emb_size, config.pool_size) 78 | elif config.attention == 'luong_gate': 79 | self.attention = models.luong_gate_attention(config.hidden_size, config.emb_size) 80 | 81 | self.hidden_size = config.hidden_size 82 | self.dropout = nn.Dropout(config.dropout) 83 | self.config = config 84 | 85 | def forward(self, input, state, conv): 86 | embs = self.embedding(input) 87 | output, state = self.rnn(embs, state) 88 | if self.attention is not None: 89 | if self.config.attention == 'luong_gate': 90 | # print(output.size(), conv.size()) 91 | output, attn_weights = self.attention(output, conv) 92 | else: 93 | output, attn_weights = self.attention(output, embs) 94 | else: 95 | attn_weights = None 96 | output = self.compute_score(output) 97 | 98 | return output, state, attn_weights 99 | 100 | def compute_score(self, hiddens): 101 | scores = self.linear(hiddens) 102 | return scores 103 | 104 | 105 | class StackedLSTM(nn.Module): 106 | def __init__(self, num_layers, input_size, hidden_size, dropout): 107 | super(StackedLSTM, self).__init__() 108 | self.dropout = nn.Dropout(dropout) 109 | self.num_layers = num_layers 110 | self.layers = nn.ModuleList() 111 | 112 | for i in range(num_layers): 113 | lstm = nn.LSTMCell(input_size, hidden_size) 114 | self.layers.append(lstm) 115 | input_size = hidden_size 116 | 117 | def forward(self, input, hidden): 118 | h_0, c_0 = hidden 119 | h_1, c_1 = [], [] 120 | for i, layer in enumerate(self.layers): 121 | h_1_i, c_1_i = layer(input, (h_0[i], c_0[i])) 122 | input = h_1_i 123 | if i + 1 != self.num_layers: 124 | input = self.dropout(input) 125 | h_1 += [h_1_i] 126 | c_1 += [c_1_i] 127 | 128 | h_1 = torch.stack(h_1) 129 | c_1 = torch.stack(c_1) 130 | 131 | return input, (h_1, c_1) 132 | 133 | 134 | class StackedGRU(nn.Module): 135 | def __init__(self, num_layers, input_size, hidden_size, dropout): 136 | super(StackedGRU, self).__init__() 137 | self.dropout = nn.Dropout(dropout) 138 | self.num_layers = num_layers 139 | self.layers = nn.ModuleList() 140 | 141 | for i in range(num_layers): 142 | self.layers.append(nn.GRUCell(input_size, hidden_size)) 143 | input_size = hidden_size 144 | 145 | def forward(self, input, hidden): 146 | h_0 = hidden 147 | h_1 = [] 148 | for i, layer in enumerate(self.layers): 149 | h_1_i = layer(input, h_0[i]) 150 | input = h_1_i 151 | if i + 1 != self.num_layers: 152 | input = self.dropout(input) 153 | h_1 += [h_1_i] 154 | 155 | h_1 = torch.stack(h_1) 156 | 157 | return input, h_1 158 | -------------------------------------------------------------------------------- /semantic-unit-based/models/seq2seq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | # from torch.autograd import Variable 4 | import utils 5 | import models 6 | import random 7 | 8 | 9 | class seq2seq(nn.Module): 10 | 11 | def __init__(self, config, use_attention=True, encoder=None, decoder=None): 12 | super(seq2seq, self).__init__() 13 | 14 | if encoder is not None: 15 | self.encoder = encoder 16 | else: 17 | self.encoder = models.rnn_encoder(config) 18 | tgt_embedding = self.encoder.embedding if config.shared_vocab else None 19 | if decoder is not None: 20 | self.decoder = decoder 21 | else: 22 | self.decoder = models.rnn_decoder(config, embedding=tgt_embedding, use_attention=use_attention) 23 | self.log_softmax = nn.LogSoftmax(dim=-1) 24 | self.use_cuda = config.use_cuda 25 | self.config = config 26 | self.criterion = 
nn.CrossEntropyLoss(ignore_index=utils.PAD, reduction='none')
27 |         if config.use_cuda:
28 |             self.criterion.cuda()
29 | 
30 |     def compute_loss(self, scores, targets):
31 |         scores = scores.view(-1, scores.size(2))
32 |         loss = self.criterion(scores, targets.contiguous().view(-1))
33 |         return loss
34 | 
35 |     def forward(self, src, src_len, dec, targets, teacher_ratio=1.0):
36 |         src = src.t()
37 |         dec = dec.t()
38 |         targets = targets.t()
39 |         teacher = random.random() < teacher_ratio
40 | 
41 |         contexts, state, conv = self.encoder(src, src_len.tolist())
42 | 
43 |         if self.decoder.attention is not None:
44 |             self.decoder.attention.init_context(context=contexts)
45 |         outputs = []
46 |         if teacher:
47 |             for input in dec.split(1):
48 |                 output, state, attn_weights = self.decoder(input.squeeze(0), state, conv)
49 |                 outputs.append(output)
50 |             outputs = torch.stack(outputs)
51 |         else:
52 |             inputs = [dec.split(1)[0].squeeze(0)]
53 |             for i, _ in enumerate(dec.split(1)):
54 |                 output, state, attn_weights = self.decoder(inputs[i], state, conv)
55 |                 predicted = output.max(1)[1]
56 |                 inputs += [predicted]
57 |                 outputs.append(output)
58 |             outputs = torch.stack(outputs)
59 | 
60 |         loss = self.compute_loss(outputs, targets)
61 |         return loss, outputs
62 | 
63 |     def sample(self, src, src_len):
64 | 
65 |         lengths, indices = torch.sort(src_len, dim=0, descending=True)
66 |         _, reverse_indices = torch.sort(indices)
67 |         src = torch.index_select(src, dim=0, index=indices)
68 |         bos = torch.ones(src.size(0)).long().fill_(utils.BOS)
69 |         src = src.t()
70 | 
71 |         if self.use_cuda:
72 |             bos = bos.cuda()
73 | 
74 |         contexts, state, conv = self.encoder(src, lengths.tolist())
75 | 
76 |         if self.decoder.attention is not None:
77 |             self.decoder.attention.init_context(context=contexts)
78 |         inputs, outputs, attn_matrix = [bos], [], []
79 |         for i in range(self.config.max_time_step):
80 |             output, state, attn_weights = self.decoder(inputs[i], state, conv)
81 |             predicted = output.max(1)[1]
82 |             inputs += [predicted]
83 |             outputs += [predicted]
84 |             attn_matrix += [attn_weights]
85 | 
86 |         outputs = torch.stack(outputs)
87 |         sample_ids = torch.index_select(outputs, dim=1, index=reverse_indices).t()
88 | 
89 |         if self.decoder.attention is not None:
90 |             attn_matrix = torch.stack(attn_matrix)
91 |             alignments = attn_matrix.max(2)[1]
92 |             alignments = torch.index_select(alignments, dim=1, index=reverse_indices).t()
93 |         else:
94 |             alignments = None
95 | 
96 |         return sample_ids, alignments
97 | 
98 |     def beam_sample(self, src, src_len, beam_size=1, eval_=False):
99 | 
100 |         # (1) Run the encoder on the src.
101 | 
102 |         lengths, indices = torch.sort(src_len, dim=0, descending=True)
103 |         _, ind = torch.sort(indices)
104 |         src = torch.index_select(src, dim=0, index=indices)
105 |         src = src.t()
106 |         batch_size = src.size(1)
107 |         contexts, encState, conv = self.encoder(src, lengths.tolist())
108 | 
109 |         # (1b) Initialize for the decoder.
110 |         def var(a):
111 |             return torch.tensor(a, requires_grad=False)
112 | 
113 |         def rvar(a):
114 |             return var(a.repeat(1, beam_size, 1))
115 | 
116 |         def bottle(m):
117 |             return m.view(batch_size * beam_size, -1)
118 | 
119 |         def unbottle(m):
120 |             return m.view(beam_size, batch_size, -1)
121 | 
122 |         # Repeat everything beam_size times.
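# rvar tiles a tensor beam_size times along the batch dimension, so one
# decoder call advances all batch_size * beam_size hypotheses at once;
# bottle/unbottle convert between the flat (batch*beam, -1) layout the
# decoder consumes and the (beam, batch, -1) layout that the per-sentence
# Beam objects index into.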
123 | contexts = rvar(contexts) 124 | conv = rvar(conv.transpose(0,1)).transpose(0,1) 125 | 126 | if self.config.cell == 'lstm': 127 | decState = (rvar(encState[0]), rvar(encState[1])) 128 | else: 129 | decState = rvar(encState) 130 | 131 | beam = [models.Beam(beam_size, n_best=1, 132 | cuda=self.use_cuda, length_norm=self.config.length_norm) 133 | for __ in range(batch_size)] 134 | if self.decoder.attention is not None: 135 | self.decoder.attention.init_context(contexts) 136 | 137 | # (2) run the decoder to generate sentences, using beam search. 138 | 139 | for i in range(self.config.max_time_step): 140 | 141 | if all((b.done() for b in beam)): 142 | break 143 | 144 | # Construct batch x beam_size nxt words. 145 | # Get all the pending current beam words and arrange for forward. 146 | inp = var(torch.stack([b.getCurrentState() for b in beam]) 147 | .t().contiguous().view(-1)) 148 | 149 | # Run one step. 150 | output, decState, attn = self.decoder(inp, decState, conv) 151 | # decOut: beam x rnn_size 152 | 153 | # (b) Compute a vector of batch*beam word scores. 154 | output = unbottle(self.log_softmax(output)) 155 | attn = unbottle(attn) 156 | # beam x tgt_vocab 157 | 158 | # (c) Advance each beam. 159 | # update state 160 | for j, b in enumerate(beam): 161 | b.advance(output[:, j], attn[:, j]) 162 | if self.config.cell == 'lstm': 163 | b.beam_update(decState, j) 164 | else: 165 | b.beam_update_gru(decState, j) 166 | 167 | # (3) Package everything up. 168 | allHyps, allScores, allAttn = [], [], [] 169 | if eval_: 170 | allWeight = [] 171 | 172 | for j in ind: 173 | b = beam[j] 174 | n_best = 1 175 | scores, ks = b.sortFinished(minimum=n_best) 176 | hyps, attn = [], [] 177 | if eval_: 178 | weight = [] 179 | for i, (times, k) in enumerate(ks[:n_best]): 180 | hyp, att = b.getHyp(times, k) 181 | hyps.append(hyp) 182 | attn.append(att.max(1)[1]) 183 | if eval_: 184 | weight.append(att) 185 | allHyps.append(hyps[0]) 186 | allScores.append(scores[0]) 187 | allAttn.append(attn[0]) 188 | if eval_: 189 | allWeight.append(weight[0]) 190 | 191 | if eval_: 192 | return allHyps, allAttn, allWeight 193 | 194 | return allHyps, allAttn 195 | -------------------------------------------------------------------------------- /semantic-unit-based/preprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import utils 3 | import pickle 4 | 5 | parser = argparse.ArgumentParser(description='preprocess.py') 6 | 7 | parser.add_argument('-load_data', required=True, 8 | help="input file for the data") 9 | 10 | parser.add_argument('-save_data', required=True, 11 | help="Output file for the prepared data") 12 | 13 | parser.add_argument('-src_vocab_size', type=int, default=50000, 14 | help="Size of the source vocabulary") 15 | parser.add_argument('-tgt_vocab_size', type=int, default=50000, 16 | help="Size of the target vocabulary") 17 | parser.add_argument('-src_filter', type=int, default=0, 18 | help="Maximum source sequence length") 19 | parser.add_argument('-tgt_filter', type=int, default=0, 20 | help="Maximum target sequence length") 21 | parser.add_argument('-src_trun', type=int, default=0, 22 | help="Truncate source sequence length") 23 | parser.add_argument('-tgt_trun', type=int, default=0, 24 | help="Truncate target sequence length") 25 | parser.add_argument('-src_char', action='store_true', help='character based encoding') 26 | parser.add_argument('-tgt_char', action='store_true', help='character based decoding') 27 | parser.add_argument('-src_suf', 
default='src', 28 | help="the suffix of the source filename") 29 | parser.add_argument('-tgt_suf', default='tgt', 30 | help="the suffix of the target filename") 31 | 32 | parser.add_argument('-share', action='store_true', help='share the vocabulary between source and target') 33 | 34 | parser.add_argument('-report_every', type=int, default=100000, 35 | help="Report status every this many sentences") 36 | 37 | opt = parser.parse_args() 38 | 39 | 40 | def makeVocabulary(filename, trun_length, filter_length, char, vocab, size): 41 | 42 | print("%s: length limit = %d, truncate length = %d" % (filename, filter_length, trun_length)) 43 | max_length = 0 44 | with open(filename, encoding='utf8') as f: 45 | for sent in f.readlines(): 46 | if char: 47 | tokens = list(sent.strip()) 48 | else: 49 | tokens = sent.strip().split() 50 | if 0 < filter_length < len(sent.strip().split()): 51 | continue 52 | max_length = max(max_length, len(tokens)) 53 | if trun_length > 0: 54 | tokens = tokens[:trun_length] 55 | for word in tokens: 56 | vocab.add(word) 57 | 58 | print('Max length of %s = %d' % (filename, max_length)) 59 | 60 | if size > 0: 61 | originalSize = vocab.size() 62 | vocab = vocab.prune(size) 63 | print('Created dictionary of size %d (pruned from %d)' % 64 | (vocab.size(), originalSize)) 65 | 66 | return vocab 67 | 68 | 69 | def saveVocabulary(name, vocab, file): 70 | print('Saving ' + name + ' vocabulary to \'' + file + '\'...') 71 | vocab.writeFile(file) 72 | 73 | 74 | def makeData(srcFile, tgtFile, srcDicts, tgtDicts, save_srcFile, save_tgtFile, lim=0): 75 | sizes = 0 76 | count, empty_ignored, limit_ignored = 0, 0, 0 77 | 78 | print('Processing %s & %s ...' % (srcFile, tgtFile)) 79 | srcF = open(srcFile, encoding='utf8') 80 | tgtF = open(tgtFile, encoding='utf8') 81 | 82 | srcIdF = open(save_srcFile + '.id', 'w') 83 | tgtIdF = open(save_tgtFile + '.id', 'w') 84 | srcStrF = open(save_srcFile + '.str', 'w', encoding='utf8') 85 | tgtStrF = open(save_tgtFile + '.str', 'w', encoding='utf8') 86 | 87 | while True: 88 | sline = srcF.readline() 89 | tline = tgtF.readline() 90 | 91 | # normal end of file 92 | if sline == "" and tline == "": 93 | break 94 | 95 | # source or target does not have same number of lines 96 | if sline == "" or tline == "": 97 | print('WARNING: source and target do not have the same number of sentences') 98 | break 99 | 100 | sline = sline.strip() 101 | tline = tline.strip() 102 | 103 | # source and/or target are empty 104 | if sline == "" or tline == "": 105 | print('WARNING: ignoring an empty line ('+str(count+1)+')') 106 | empty_ignored += 1 107 | continue 108 | 109 | sline = sline.lower() 110 | tline = tline.lower() 111 | 112 | srcWords = sline.split() if not opt.src_char else list(sline) 113 | tgtWords = tline.split() if not opt.tgt_char else list(tline) 114 | 115 | 116 | if (opt.src_filter == 0 or len(sline.split()) <= opt.src_filter) and \ 117 | (opt.tgt_filter == 0 or len(tline.split()) <= opt.tgt_filter): 118 | 119 | if opt.src_trun > 0: 120 | srcWords = srcWords[:opt.src_trun] 121 | if opt.tgt_trun > 0: 122 | tgtWords = tgtWords[:opt.tgt_trun] 123 | 124 | srcIds = srcDicts.convertToIdx(srcWords, utils.UNK_WORD) 125 | tgtIds = tgtDicts.convertToIdx(tgtWords, utils.UNK_WORD, utils.BOS_WORD, utils.EOS_WORD) 126 | 127 | srcIdF.write(" ".join(list(map(str, srcIds)))+'\n') 128 | tgtIdF.write(" ".join(list(map(str, tgtIds)))+'\n') 129 | if not opt.src_char: 130 | srcStrF.write(" ".join(srcWords)+'\n') 131 | else: 132 | srcStrF.write("".join(srcWords) + '\n') 133 | if not 
opt.tgt_char:
134 |                 tgtStrF.write(" ".join(tgtWords)+'\n')
135 |             else:
136 |                 tgtStrF.write("".join(tgtWords) + '\n')
137 | 
138 |             sizes += 1
139 |         else:
140 |             limit_ignored += 1
141 | 
142 |         count += 1
143 | 
144 |         if count % opt.report_every == 0:
145 |             print('... %d sentences prepared' % count)
146 | 
147 |     srcF.close()
148 |     tgtF.close()
149 |     srcStrF.close()
150 |     tgtStrF.close()
151 |     srcIdF.close()
152 |     tgtIdF.close()
153 | 
154 |     print('Prepared %d sentences (%d ignored as empty, %d ignored for exceeding the length limit)' %
155 |           (sizes, empty_ignored, limit_ignored))
156 | 
157 |     return {'srcF': save_srcFile + '.id', 'tgtF': save_tgtFile + '.id',
158 |             'original_srcF': save_srcFile + '.str', 'original_tgtF': save_tgtFile + '.str',
159 |             'length': sizes}
160 | 
161 | 
162 | def main():
163 | 
164 |     dicts = {}
165 | 
166 |     train_src, train_tgt = opt.load_data + 'train.' + opt.src_suf, opt.load_data + 'train.' + opt.tgt_suf
167 |     valid_src, valid_tgt = opt.load_data + 'valid.' + opt.src_suf, opt.load_data + 'valid.' + opt.tgt_suf
168 |     test_src, test_tgt = opt.load_data + 'test.' + opt.src_suf, opt.load_data + 'test.' + opt.tgt_suf
169 | 
170 |     save_train_src, save_train_tgt = opt.save_data + 'train.' + opt.src_suf, opt.save_data + 'train.' + opt.tgt_suf
171 |     save_valid_src, save_valid_tgt = opt.save_data + 'valid.' + opt.src_suf, opt.save_data + 'valid.' + opt.tgt_suf
172 |     save_test_src, save_test_tgt = opt.save_data + 'test.' + opt.src_suf, opt.save_data + 'test.' + opt.tgt_suf
173 | 
174 |     src_dict, tgt_dict = opt.save_data + 'src.dict', opt.save_data + 'tgt.dict'
175 | 
176 |     if opt.share:
177 |         assert opt.src_vocab_size == opt.tgt_vocab_size
178 |         print('Building source and target vocabulary...')
179 |         dicts['src'] = dicts['tgt'] = utils.Dict([utils.PAD_WORD, utils.UNK_WORD, utils.BOS_WORD, utils.EOS_WORD])
180 |         dicts['src'] = makeVocabulary(train_src, opt.src_trun, opt.src_filter, opt.src_char, dicts['src'], opt.src_vocab_size)
181 |         dicts['src'] = dicts['tgt'] = makeVocabulary(train_tgt, opt.tgt_trun, opt.tgt_filter, opt.tgt_char, dicts['tgt'], opt.tgt_vocab_size)
182 |     else:
183 |         print('Building source vocabulary...')
184 |         dicts['src'] = utils.Dict([utils.PAD_WORD, utils.UNK_WORD, utils.BOS_WORD, utils.EOS_WORD])
185 |         dicts['src'] = makeVocabulary(train_src, opt.src_trun, opt.src_filter, opt.src_char, dicts['src'], opt.src_vocab_size)
186 |         print('Building target vocabulary...')
187 |         dicts['tgt'] = utils.Dict([utils.PAD_WORD, utils.UNK_WORD, utils.BOS_WORD, utils.EOS_WORD])
188 |         dicts['tgt'] = makeVocabulary(train_tgt, opt.tgt_trun, opt.tgt_filter, opt.tgt_char, dicts['tgt'], opt.tgt_vocab_size)
189 | 
190 |     print('Preparing training ...')
191 |     train = makeData(train_src, train_tgt, dicts['src'], dicts['tgt'], save_train_src, save_train_tgt)
192 | 
193 |     print('Preparing validation ...')
194 |     valid = makeData(valid_src, valid_tgt, dicts['src'], dicts['tgt'], save_valid_src, save_valid_tgt)
195 | 
196 |     print('Preparing test ...')
197 |     test = makeData(test_src, test_tgt, dicts['src'], dicts['tgt'], save_test_src, save_test_tgt)
198 | 
199 |     print('Saving source vocabulary to \'' + src_dict + '\'...')
200 |     dicts['src'].writeFile(src_dict)
201 | 
202 |     print('Saving target vocabulary to \'' + tgt_dict + '\'...')
203 |     dicts['tgt'].writeFile(tgt_dict)
204 | 
205 |     data = {'train': train, 'valid': valid,
206 |             'test': test, 'dict': dicts}
207 |     pickle.dump(data, open(opt.save_data+'data.pkl', 'wb'))
208 | 
209 | 
210 | if __name__ == "__main__":
211 |     main()
212 | 
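
After preprocessing, a quick way to sanity-check the result is to open the pickled bundle directly; a minimal sketch (the path is whatever was passed as -save_data):

```
import pickle

with open('path_to_store_data/data.pkl', 'rb') as f:
    data = pickle.load(f)

# 'train'/'valid'/'test' each hold the file paths written by makeData(),
# plus the number of prepared sentence pairs.
print(data['train']['length'])
print(data['train']['srcF'], data['train']['tgtF'])   # *.id files
print(data['train']['original_srcF'])                 # *.str file

# 'dict' holds the two utils.Dict vocabularies built above.
print(data['dict']['src'].size(), data['dict']['tgt'].size())
```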
-------------------------------------------------------------------------------- /semantic-unit-based/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | # from torch.autograd import Variable 4 | import lr_scheduler as L 5 | 6 | import os 7 | import argparse 8 | import pickle 9 | import time 10 | from collections import OrderedDict 11 | 12 | import opts 13 | import models 14 | import utils 15 | import codecs 16 | import json 17 | 18 | import numpy as np 19 | import matplotlib 20 | matplotlib.use('Agg') 21 | import matplotlib.pyplot as plt 22 | import matplotlib.ticker as ticker 23 | import random 24 | 25 | parser = argparse.ArgumentParser(description='train.py') 26 | opts.model_opts(parser) 27 | 28 | opt = parser.parse_args() 29 | config = utils.read_config(opt.config) 30 | torch.manual_seed(opt.seed) 31 | random.seed(opt.seed) 32 | np.random.seed(opt.seed) 33 | opts.convert_to_config(opt, config) 34 | 35 | # cuda 36 | use_cuda = torch.cuda.is_available() and len(opt.gpus) > 0 37 | config.use_cuda = use_cuda 38 | if use_cuda: 39 | torch.cuda.set_device(opt.gpus[0]) 40 | torch.cuda.manual_seed(opt.seed) 41 | torch.backends.cudnn.deterministic = True 42 | 43 | with open(opt.label_dict_file, 'r') as f: 44 | label_dict = json.load(f) 45 | 46 | 47 | def load_data(): 48 | print('loading data...\n') 49 | data = pickle.load(open(config.data+'data.pkl', 'rb')) 50 | data['train']['length'] = int(data['train']['length'] * opt.scale) 51 | 52 | trainset = utils.BiDataset(data['train'], char=config.char) 53 | validset = utils.BiDataset(data['valid'], char=config.char) 54 | 55 | src_vocab = data['dict']['src'] 56 | tgt_vocab = data['dict']['tgt'] 57 | config.src_vocab_size = src_vocab.size() 58 | config.tgt_vocab_size = tgt_vocab.size() 59 | 60 | trainloader = torch.utils.data.DataLoader(dataset=trainset, 61 | batch_size=config.batch_size, 62 | shuffle=True, 63 | num_workers=0, 64 | collate_fn=utils.padding) 65 | if hasattr(config, 'valid_batch_size'): 66 | valid_batch_size = config.valid_batch_size 67 | else: 68 | valid_batch_size = config.batch_size 69 | validloader = torch.utils.data.DataLoader(dataset=validset, 70 | batch_size=valid_batch_size, 71 | shuffle=False, 72 | num_workers=0, 73 | collate_fn=utils.padding) 74 | 75 | return {'trainset': trainset, 'validset': validset, 76 | 'trainloader': trainloader, 'validloader': validloader, 77 | 'src_vocab': src_vocab, 'tgt_vocab': tgt_vocab} 78 | 79 | 80 | 81 | def build_model(checkpoints, print_log): 82 | for k, v in config.items(): 83 | print_log("%s:\t%s\n" % (str(k), str(v))) 84 | 85 | # model 86 | print('building model...\n') 87 | model = getattr(models, opt.model)(config) 88 | if checkpoints is not None: 89 | model.load_state_dict(checkpoints['model']) 90 | if opt.pretrain: 91 | print('loading checkpoint from %s' % opt.pretrain) 92 | pre_ckpt = torch.load(opt.pretrain)['model'] 93 | pre_ckpt = OrderedDict({key[8:]: pre_ckpt[key] for key in pre_ckpt if key.startswith('encoder')}) 94 | print(model.encoder.state_dict().keys()) 95 | print(pre_ckpt.keys()) 96 | model.encoder.load_state_dict(pre_ckpt) 97 | if use_cuda: 98 | model.cuda() 99 | 100 | # optimizer 101 | if checkpoints is not None: 102 | optim = checkpoints['optim'] 103 | else: 104 | optim = models.Optim(config.optim, config.learning_rate, config.max_grad_norm, 105 | lr_decay=config.learning_rate_decay, start_decay_at=config.start_decay_at) 106 | optim.set_parameters(model.parameters()) 107 | 108 | # print 
log 109 | param_count = 0 110 | for param in model.parameters(): 111 | param_count += param.view(-1).size()[0] 112 | for k, v in config.items(): 113 | print_log("%s:\t%s\n" % (str(k), str(v))) 114 | print_log("\n") 115 | print_log(repr(model) + "\n\n") 116 | print_log('total number of parameters: %d\n\n' % param_count) 117 | 118 | return model, optim, print_log 119 | 120 | 121 | def train_model(model, data, optim, epoch, params): 122 | 123 | model.train() 124 | trainloader = data['trainloader'] 125 | 126 | for src, tgt, src_len, tgt_len, original_src, original_tgt in trainloader: 127 | 128 | model.zero_grad() 129 | 130 | if config.use_cuda: 131 | src = src.cuda() 132 | tgt = tgt.cuda() 133 | src_len = src_len.cuda() 134 | lengths, indices = torch.sort(src_len, dim=0, descending=True) 135 | src = torch.index_select(src, dim=0, index=indices) 136 | tgt = torch.index_select(tgt, dim=0, index=indices) 137 | dec = tgt[:, :-1] 138 | targets = tgt[:, 1:] 139 | 140 | try: 141 | if config.schesamp: 142 | if epoch > 8: 143 | e = epoch - 8 144 | loss, outputs = model(src, lengths, dec, targets, teacher_ratio=0.9**e) 145 | else: 146 | loss, outputs = model(src, lengths, dec, targets) 147 | else: 148 | loss, outputs = model(src, lengths, dec, targets) 149 | pred = outputs.max(2)[1] 150 | targets = targets.t() 151 | num_correct = pred.eq(targets).masked_select(targets.ne(utils.PAD)).sum().item() 152 | num_total = targets.ne(utils.PAD).sum().item() 153 | if config.max_split == 0: 154 | loss = torch.sum(loss) / num_total 155 | loss.backward() 156 | optim.step() 157 | 158 | params['report_loss'] += loss.item() 159 | params['report_correct'] += num_correct 160 | params['report_total'] += num_total 161 | 162 | except RuntimeError as e: 163 | if 'out of memory' in str(e): 164 | print('| WARNING: ran out of memory') 165 | if hasattr(torch.cuda, 'empty_cache'): 166 | torch.cuda.empty_cache() 167 | else: 168 | raise e 169 | 170 | utils.progress_bar(params['updates'], config.eval_interval) 171 | params['updates'] += 1 172 | 173 | if params['updates'] % config.eval_interval == 0: 174 | params['log']("epoch: %3d, loss: %6.3f, time: %6.3f, updates: %8d, accuracy: %2.2f\n" 175 | % (epoch, params['report_loss'], time.time()-params['report_time'], 176 | params['updates'], params['report_correct'] * 100.0 / params['report_total'])) 177 | print('evaluating after %d updates...\r' % params['updates']) 178 | score = eval_model(model, data, params) 179 | for metric in config.metrics: 180 | params[metric].append(score[metric]) 181 | if score[metric] >= max(params[metric]): 182 | with codecs.open(params['log_path']+'best_'+metric+'_prediction.txt','w','utf-8') as f: 183 | f.write(codecs.open(params['log_path']+'candidate.txt','r','utf-8').read()) 184 | save_model(params['log_path']+'best_'+metric+'_checkpoint.pt', model, optim, params['updates']) 185 | model.train() 186 | params['report_loss'], params['report_time'] = 0, time.time() 187 | params['report_correct'], params['report_total'] = 0, 0 188 | 189 | if params['updates'] % config.save_interval == 0: 190 | save_model(params['log_path']+'checkpoint.pt', model, optim, params['updates']) 191 | 192 | optim.updateLearningRate(score=0, epoch=epoch) 193 | 194 | 195 | def eval_model(model, data, params): 196 | 197 | model.eval() 198 | reference, candidate, source, alignments = [], [], [], [] 199 | count, total_count = 0, len(data['validset']) 200 | validloader = data['validloader'] 201 | tgt_vocab = data['tgt_vocab'] 202 | 203 | 204 | for src, tgt, src_len, tgt_len, original_src, 
original_tgt in validloader: 205 | 206 | if config.use_cuda: 207 | src = src.cuda() 208 | src_len = src_len.cuda() 209 | 210 | with torch.no_grad(): 211 | if config.beam_size > 1: 212 | samples, alignment, weight = model.beam_sample(src, src_len, beam_size=config.beam_size, eval_=True) 213 | else: 214 | samples, alignment = model.sample(src, src_len) 215 | 216 | candidate += [tgt_vocab.convertToLabels(s, utils.EOS) for s in samples] 217 | source += original_src 218 | reference += original_tgt 219 | if alignment is not None: 220 | alignments += [align for align in alignment] 221 | 222 | count += len(original_src) 223 | utils.progress_bar(count, total_count) 224 | 225 | if config.unk and config.attention != 'None': 226 | cands = [] 227 | for s, c, align in zip(source, candidate, alignments): 228 | cand = [] 229 | for word, idx in zip(c, align): 230 | if word == utils.UNK_WORD and idx < len(s): 231 | try: 232 | cand.append(s[idx]) 233 | except: 234 | cand.append(word) 235 | print("%d %d\n" % (len(s), idx)) 236 | else: 237 | cand.append(word) 238 | cands.append(cand) 239 | if len(cand) == 0: 240 | print('Error!') 241 | candidate = cands 242 | 243 | with codecs.open(params['log_path']+'candidate.txt','w+','utf-8') as f: 244 | for i in range(len(candidate)): 245 | f.write(" ".join(candidate[i])+'\n') 246 | 247 | results = utils.eval_metrics(reference, candidate, label_dict, params['log_path']) 248 | score = {} 249 | result_line = "" 250 | for metric in config.metrics: 251 | score[metric] = results[metric] 252 | result_line += metric + ": %s " % str(score[metric]) 253 | result_line += '\n' 254 | 255 | params['log'](result_line) 256 | 257 | return score 258 | 259 | 260 | def save_model(path, model, optim, updates): 261 | model_state_dict = model.state_dict() 262 | checkpoints = { 263 | 'model': model_state_dict, 264 | 'config': config, 265 | 'optim': optim, 266 | 'updates': updates} 267 | torch.save(checkpoints, path) 268 | 269 | 270 | def build_log(): 271 | # log 272 | if not os.path.exists(config.logF): 273 | os.mkdir(config.logF) 274 | if opt.log == '': 275 | log_path = config.logF + str(int(time.time() * 1000)) + '/' 276 | else: 277 | log_path = config.logF + opt.log + '/' 278 | if not os.path.exists(log_path): 279 | os.mkdir(log_path) 280 | print_log = utils.print_log(log_path + 'log.txt') 281 | return print_log, log_path 282 | 283 | 284 | def showAttention(path, s, c, attentions, index): 285 | # Set up figure with colorbar 286 | fig = plt.figure() 287 | ax = fig.add_subplot(111) 288 | cax = ax.matshow(attentions.numpy(), cmap='bone') 289 | fig.colorbar(cax) 290 | # Set up axes 291 | ax.set_xticklabels([''] + s, rotation=90) 292 | ax.set_yticklabels([''] + c) 293 | # Show label at every tick 294 | ax.xaxis.set_major_locator(ticker.MultipleLocator(1)) 295 | ax.yaxis.set_major_locator(ticker.MultipleLocator(1)) 296 | plt.show() 297 | plt.savefig(path + str(index) + '.jpg') 298 | 299 | 300 | def main(): 301 | # checkpoint 302 | if opt.restore: 303 | print('loading checkpoint...\n') 304 | checkpoints = torch.load(opt.restore) 305 | else: 306 | checkpoints = None 307 | 308 | data = load_data() 309 | print_log, log_path = build_log() 310 | model, optim, print_log = build_model(checkpoints, print_log) 311 | # scheduler 312 | if config.schedule: 313 | scheduler = L.CosineAnnealingLR(optim.optimizer, T_max=config.epoch) 314 | params = {'updates': 0, 'report_loss': 0, 'report_total': 0, 315 | 'report_correct': 0, 'report_time': time.time(), 316 | 'log': print_log, 'log_path': log_path} 317 | for metric 
in config.metrics: 318 | params[metric] = [] 319 | if opt.restore: 320 | params['updates'] = checkpoints['updates'] 321 | 322 | if opt.mode == "train": 323 | for i in range(1, config.epoch + 1): 324 | if config.schedule: 325 | scheduler.step() 326 | print("Decaying learning rate to %g" % scheduler.get_lr()[0]) 327 | train_model(model, data, optim, i, params) 328 | for metric in config.metrics: 329 | print_log("Best %s score: %.2f\n" % (metric, max(params[metric]))) 330 | else: 331 | score = eval_model(model, data, params) 332 | 333 | 334 | if __name__ == '__main__': 335 | main() 336 | -------------------------------------------------------------------------------- /semantic-unit-based/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from bisect import bisect_right 3 | from torch.optim.optimizer import Optimizer 4 | 5 | 6 | class _LRScheduler(object): 7 | def __init__(self, optimizer, last_epoch=-1): 8 | if not isinstance(optimizer, Optimizer): 9 | raise TypeError('{} is not an Optimizer'.format( 10 | type(optimizer).__name__)) 11 | self.optimizer = optimizer 12 | if last_epoch == -1: 13 | for group in optimizer.param_groups: 14 | group.setdefault('initial_lr', group['lr']) 15 | else: 16 | for i, group in enumerate(optimizer.param_groups): 17 | if 'initial_lr' not in group: 18 | raise KeyError("param 'initial_lr' is not specified " 19 | "in param_groups[{}] when resuming an optimizer".format(i)) 20 | self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) 21 | self.step(last_epoch + 1) 22 | self.last_epoch = last_epoch 23 | 24 | def get_lr(self): 25 | raise NotImplementedError 26 | 27 | def step(self, epoch=None): 28 | if epoch is None: 29 | epoch = self.last_epoch + 1 30 | self.last_epoch = epoch 31 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 32 | param_group['lr'] = lr 33 | 34 | 35 | class LambdaLR(_LRScheduler): 36 | """Sets the learning rate of each parameter group to the initial lr 37 | times a given function. When last_epoch=-1, sets initial lr as lr. 38 | Args: 39 | optimizer (Optimizer): Wrapped optimizer. 40 | lr_lambda (function or list): A function which computes a multiplicative 41 | factor given an integer parameter epoch, or a list of such 42 | functions, one for each group in optimizer.param_groups. 43 | last_epoch (int): The index of last epoch. Default: -1. 44 | Example: 45 | >>> # Assuming optimizer has two groups. 46 | >>> lambda1 = lambda epoch: epoch // 30 47 | >>> lambda2 = lambda epoch: 0.95 ** epoch 48 | >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2]) 49 | >>> for epoch in range(100): 50 | >>> scheduler.step() 51 | >>> train(...) 52 | >>> validate(...) 
53 | """ 54 | def __init__(self, optimizer, lr_lambda, last_epoch=-1): 55 | self.optimizer = optimizer 56 | if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple): 57 | self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) 58 | else: 59 | if len(lr_lambda) != len(optimizer.param_groups): 60 | raise ValueError("Expected {} lr_lambdas, but got {}".format( 61 | len(optimizer.param_groups), len(lr_lambda))) 62 | self.lr_lambdas = list(lr_lambda) 63 | self.last_epoch = last_epoch 64 | super(LambdaLR, self).__init__(optimizer, last_epoch) 65 | 66 | def get_lr(self): 67 | return [base_lr * lmbda(self.last_epoch) 68 | for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)] 69 | 70 | 71 | class StepLR(_LRScheduler): 72 | """Sets the learning rate of each parameter group to the initial lr 73 | decayed by gamma every step_size epochs. When last_epoch=-1, sets 74 | initial lr as lr. 75 | Args: 76 | optimizer (Optimizer): Wrapped optimizer. 77 | step_size (int): Period of learning rate decay. 78 | gamma (float): Multiplicative factor of learning rate decay. 79 | Default: 0.1. 80 | last_epoch (int): The index of last epoch. Default: -1. 81 | Example: 82 | >>> # Assuming optimizer uses lr = 0.5 for all groups 83 | >>> # lr = 0.05 if epoch < 30 84 | >>> # lr = 0.005 if 30 <= epoch < 60 85 | >>> # lr = 0.0005 if 60 <= epoch < 90 86 | >>> # ... 87 | >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1) 88 | >>> for epoch in range(100): 89 | >>> scheduler.step() 90 | >>> train(...) 91 | >>> validate(...) 92 | """ 93 | 94 | def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1): 95 | self.step_size = step_size 96 | self.gamma = gamma 97 | super(StepLR, self).__init__(optimizer, last_epoch) 98 | 99 | def get_lr(self): 100 | return [base_lr * self.gamma ** (self.last_epoch // self.step_size) 101 | for base_lr in self.base_lrs] 102 | 103 | 104 | class MultiStepLR(_LRScheduler): 105 | """Set the learning rate of each parameter group to the initial lr decayed 106 | by gamma once the number of epoch reaches one of the milestones. When 107 | last_epoch=-1, sets initial lr as lr. 108 | Args: 109 | optimizer (Optimizer): Wrapped optimizer. 110 | milestones (list): List of epoch indices. Must be increasing. 111 | gamma (float): Multiplicative factor of learning rate decay. 112 | Default: 0.1. 113 | last_epoch (int): The index of last epoch. Default: -1. 114 | Example: 115 | >>> # Assuming optimizer uses lr = 0.5 for all groups 116 | >>> # lr = 0.05 if epoch < 30 117 | >>> # lr = 0.005 if 30 <= epoch < 80 118 | >>> # lr = 0.0005 if epoch >= 80 119 | >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1) 120 | >>> for epoch in range(100): 121 | >>> scheduler.step() 122 | >>> train(...) 123 | >>> validate(...) 124 | """ 125 | 126 | def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1): 127 | if not list(milestones) == sorted(milestones): 128 | raise ValueError('Milestones should be a list of' 129 | ' increasing integers. Got {}', milestones) 130 | self.milestones = milestones 131 | self.gamma = gamma 132 | super(MultiStepLR, self).__init__(optimizer, last_epoch) 133 | 134 | def get_lr(self): 135 | return [base_lr * self.gamma ** bisect_right(self.milestones, self.last_epoch) 136 | for base_lr in self.base_lrs] 137 | 138 | 139 | class ExponentialLR(_LRScheduler): 140 | """Set the learning rate of each parameter group to the initial lr decayed 141 | by gamma every epoch. When last_epoch=-1, sets initial lr as lr. 
142 |     Args:
143 |         optimizer (Optimizer): Wrapped optimizer.
144 |         gamma (float): Multiplicative factor of learning rate decay.
145 |         last_epoch (int): The index of last epoch. Default: -1.
146 |     """
147 | 
148 |     def __init__(self, optimizer, gamma, last_epoch=-1):
149 |         self.gamma = gamma
150 |         super(ExponentialLR, self).__init__(optimizer, last_epoch)
151 | 
152 |     def get_lr(self):
153 |         return [base_lr * self.gamma ** self.last_epoch
154 |                 for base_lr in self.base_lrs]
155 | 
156 | 
157 | class CosineAnnealingLR(_LRScheduler):
158 |     """Set the learning rate of each parameter group using a cosine annealing
159 |     schedule, where :math:`\eta_{max}` is set to the initial lr and
160 |     :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
161 |     .. math::
162 |         \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
163 |         \cos(\frac{T_{cur}}{T_{max}}\pi))
164 |     When last_epoch=-1, sets initial lr as lr.
165 |     It has been proposed in
166 |     `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
167 |     implements the cosine annealing part of SGDR, and not the restarts.
168 |     Args:
169 |         optimizer (Optimizer): Wrapped optimizer.
170 |         T_max (int): Maximum number of iterations.
171 |         eta_min (float): Minimum learning rate. Default: 0.
172 |         last_epoch (int): The index of last epoch. Default: -1.
173 |     .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
174 |         https://arxiv.org/abs/1608.03983
175 |     """
176 | 
177 |     def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
178 |         self.T_max = T_max
179 |         self.eta_min = eta_min
180 |         super(CosineAnnealingLR, self).__init__(optimizer, last_epoch)
181 | 
182 |     def get_lr(self):
183 |         return [self.eta_min + (base_lr - self.eta_min) *
184 |                 (1 + math.cos(self.last_epoch / self.T_max * math.pi)) / 2
185 |                 for base_lr in self.base_lrs]
186 | 
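# NOTE: CosineAnnealingLR is the variant train.py actually selects when
# config.schedule is enabled. A minimal sketch of that wiring, following the
# pre-1.0 PyTorch convention this module uses (step() is called once per
# epoch, before training):
#
#     import lr_scheduler as L
#     scheduler = L.CosineAnnealingLR(optim.optimizer, T_max=config.epoch)
#     for epoch in range(1, config.epoch + 1):
#         scheduler.step()                    # sets this epoch's learning rate
#         lr = scheduler.get_lr()[0]          # inspect it, as train.py's log does
#         train_model(model, data, optim, epoch, params)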
187 | 
188 | class ReduceLROnPlateau(object):
189 |     """Reduce learning rate when a metric has stopped improving.
190 |     Models often benefit from reducing the learning rate by a factor
191 |     of 2-10 once learning stagnates. This scheduler reads a metrics
192 |     quantity and if no improvement is seen for a 'patience' number
193 |     of epochs, the learning rate is reduced.
194 |     Args:
195 |         optimizer (Optimizer): Wrapped optimizer.
196 |         mode (str): One of `min`, `max`. In `min` mode, lr will
197 |             be reduced when the quantity monitored has stopped
198 |             decreasing; in `max` mode it will be reduced when the
199 |             quantity monitored has stopped increasing. Default: 'min'.
200 |         factor (float): Factor by which the learning rate will be
201 |             reduced. new_lr = lr * factor. Default: 0.1.
202 |         patience (int): Number of epochs with no improvement after
203 |             which learning rate will be reduced. Default: 10.
204 |         verbose (bool): If True, prints a message to stdout for
205 |             each update. Default: False.
206 |         threshold (float): Threshold for measuring the new optimum,
207 |             to only focus on significant changes. Default: 1e-4.
208 |         threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
209 |             dynamic_threshold = best * ( 1 + threshold ) in 'max'
210 |             mode or best * ( 1 - threshold ) in `min` mode.
211 |             In `abs` mode, dynamic_threshold = best + threshold in
212 |             `max` mode or best - threshold in `min` mode. Default: 'rel'.
213 |         cooldown (int): Number of epochs to wait before resuming
214 |             normal operation after lr has been reduced. Default: 0.
215 |         min_lr (float or list): A scalar or a list of scalars. A
216 |             lower bound on the learning rate of all param groups
217 |             or each group respectively. Default: 0.
218 |         eps (float): Minimal decay applied to lr. If the difference
219 |             between new and old lr is smaller than eps, the update is
220 |             ignored. Default: 1e-8.
221 |     Example:
222 |         >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
223 |         >>> scheduler = ReduceLROnPlateau(optimizer, 'min')
224 |         >>> for epoch in range(10):
225 |         >>>     train(...)
226 |         >>>     val_loss = validate(...)
227 |         >>>     # Note that step should be called after validate()
228 |         >>>     scheduler.step(val_loss)
229 |     """
230 | 
231 |     def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
232 |                  verbose=False, threshold=1e-4, threshold_mode='rel',
233 |                  cooldown=0, min_lr=0, eps=1e-8):
234 | 
235 |         if factor >= 1.0:
236 |             raise ValueError('Factor should be < 1.0.')
237 |         self.factor = factor
238 | 
239 |         if not isinstance(optimizer, Optimizer):
240 |             raise TypeError('{} is not an Optimizer'.format(
241 |                 type(optimizer).__name__))
242 |         self.optimizer = optimizer
243 | 
244 |         if isinstance(min_lr, list) or isinstance(min_lr, tuple):
245 |             if len(min_lr) != len(optimizer.param_groups):
246 |                 raise ValueError("expected {} min_lrs, got {}".format(
247 |                     len(optimizer.param_groups), len(min_lr)))
248 |             self.min_lrs = list(min_lr)
249 |         else:
250 |             self.min_lrs = [min_lr] * len(optimizer.param_groups)
251 | 
252 |         self.patience = patience
253 |         self.verbose = verbose
254 |         self.cooldown = cooldown
255 |         self.cooldown_counter = 0
256 |         self.mode = mode
257 |         self.threshold = threshold
258 |         self.threshold_mode = threshold_mode
259 |         self.best = None
260 |         self.num_bad_epochs = None
261 |         self.mode_worse = None  # the worse value for the chosen mode
262 |         self.is_better = None
263 |         self.eps = eps
264 |         self.last_epoch = -1
265 |         self._init_is_better(mode=mode, threshold=threshold,
266 |                              threshold_mode=threshold_mode)
267 |         self._reset()
268 | 
269 |     def _reset(self):
270 |         """Resets num_bad_epochs counter and cooldown counter."""
271 |         self.best = self.mode_worse
272 |         self.cooldown_counter = 0
273 |         self.num_bad_epochs = 0
274 | 
275 |     def step(self, metrics, epoch=None):
276 |         current = metrics
277 |         if epoch is None:
278 |             epoch = self.last_epoch = self.last_epoch + 1
279 |         self.last_epoch = epoch
280 | 
281 |         if self.is_better(current, self.best):
282 |             self.best = current
283 |             self.num_bad_epochs = 0
284 |         else:
285 |             self.num_bad_epochs += 1
286 | 
287 |         if self.in_cooldown:
288 |             self.cooldown_counter -= 1
289 |             self.num_bad_epochs = 0  # ignore any bad epochs in cooldown
290 | 
291 |         if self.num_bad_epochs > self.patience:
292 |             self._reduce_lr(epoch)
293 |             self.cooldown_counter = self.cooldown
294 |             self.num_bad_epochs = 0
295 | 
296 |     def _reduce_lr(self, epoch):
297 |         for i, param_group in enumerate(self.optimizer.param_groups):
298 |             old_lr = float(param_group['lr'])
299 |             new_lr = max(old_lr * self.factor, self.min_lrs[i])
300 |             if old_lr - new_lr > self.eps:
301 |                 param_group['lr'] = new_lr
302 |                 if self.verbose:
303 |                     print('Epoch {:5d}: reducing learning rate'
304 |                           ' of group {} to {:.4e}.'.format(epoch, i, new_lr))
305 | 
306 |     @property
307 |     def in_cooldown(self):
308 |         return self.cooldown_counter > 0
309 | 
310 |     def _init_is_better(self, mode, threshold, threshold_mode):
311 |         if mode not in {'min', 'max'}:
312 |             raise ValueError('mode ' + mode + ' is unknown!')
313 |         if threshold_mode not in {'rel', 'abs'}:
314 |             raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')
315 |         if mode == 'min' and threshold_mode == 'rel':
316 |             rel_epsilon = 1. - threshold
317 |             self.is_better = lambda a, best: a < best * rel_epsilon
318 |             self.mode_worse = float('Inf')
319 |         elif mode == 'min' and threshold_mode == 'abs':
320 |             self.is_better = lambda a, best: a < best - threshold
321 |             self.mode_worse = float('Inf')
322 |         elif mode == 'max' and threshold_mode == 'rel':
323 |             rel_epsilon = threshold + 1.
324 |             self.is_better = lambda a, best: a > best * rel_epsilon
325 |             self.mode_worse = -float('Inf')
326 |         else:  # mode == 'max' and threshold_mode == 'abs'
327 |             self.is_better = lambda a, best: a > best + threshold
328 |             self.mode_worse = -float('Inf')
329 | 
--------------------------------------------------------------------------------
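To close, a small self-contained smoke test for the two schedulers this project exercises (train.py drives CosineAnnealingLR; ReduceLROnPlateau mirrors its own docstring example). It is a sketch that assumes only a working torch install and that this file is importable as lr_scheduler:

    import torch
    from lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau

    p = torch.nn.Parameter(torch.zeros(1))      # one dummy parameter suffices

    # Cosine annealing: lr decays from 0.1 toward eta_min=0 over T_max epochs.
    opt1 = torch.optim.SGD([p], lr=0.1)
    sched = CosineAnnealingLR(opt1, T_max=20)   # train.py passes T_max=config.epoch
    for epoch in range(20):
        sched.step()
        print(epoch, opt1.param_groups[0]['lr'])

    # Plateau: lr is multiplied by `factor` after `patience` epochs without
    # a significant improvement of the monitored metric.
    opt2 = torch.optim.SGD([p], lr=0.1)
    plateau = ReduceLROnPlateau(opt2, mode='min', factor=0.5, patience=2)
    for val_loss in [1.0, 0.9, 0.9, 0.9, 0.9]:  # improvement stalls after epoch 1
        plateau.step(val_loss)
    print(opt2.param_groups[0]['lr'])           # -> 0.05 once the plateau triggers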