├── .gitignore
├── Pretraining
│   ├── BiGRU.py
│   ├── data.py
│   ├── lr_scheduler.py
│   ├── metric.py
│   ├── model.py
│   ├── preprocess.py
│   ├── train.py
│   └── utils.py
└── README.md

/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /Pretraining/BiGRU.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class GRU(nn.Module): 5 | 6 | def __init__(self, dim_word, dim_h, num_layers, dropout): 7 | super().__init__() 8 | self.encoder = nn.GRU(input_size=dim_word, 9 | hidden_size=dim_h, 10 | num_layers=num_layers, 11 | dropout=dropout, 12 | batch_first=True, 13 | bidirectional=False) 14 | 15 | def forward_one_step(self, input, last_h): 16 | """ 17 | Args: 18 | - input (bsz, 1, w_dim) 19 | - last_h (num_layers, bsz, h_dim) 20 | """ 21 | hidden, new_h = self.encoder(input, last_h) 22 | return hidden, new_h # (bsz, 1, h_dim), (num_layers, bsz, h_dim) 23 | 24 | 25 | def generate_sequence(self, word_lookup_func, h_0, classifier, vocab, max_step, early_stop=True): 26 | bsz = h_0.size(1) 27 | device = h_0.device 28 | start_id, end_id, pad_id = vocab['<START>'], vocab['<END>'], vocab['<PAD>'] 29 | 30 | latest = torch.LongTensor([start_id]*bsz).to(device) # [bsz, ] 31 | results = [latest] 32 | last_h = h_0 33 | finished = torch.zeros((bsz,)).bool().to(device) # record whether <END> is produced 34 | for i in range(max_step-1): # exclude <START> 35 | word_emb = word_lookup_func(latest).unsqueeze(1) # [bsz, 1, dim_w] 36 | word_h, last_h = self.forward_one_step(word_emb, last_h) # [bsz, 1, dim_h] 37 | 38 | logit = classifier(word_h).squeeze(1) # [bsz, num_func] 39 | latest = torch.argmax(logit, dim=1).long() # [bsz, ] 40 | latest[finished] = pad_id # set to <PAD> after <END> 41 | results.append(latest) 42 | 43 | finished = finished | latest.eq(end_id).bool() 44 | if early_stop and finished.sum().item() == bsz: 45 | # print('finished at step {}'.format(i)) 46 | break 47 | results = torch.stack(results, dim=1) # [bsz, max_len'] 48 | return results 49 | 50 | 51 | def forward(self, input, length, h_0=None): 52 | """ 53 | Args: 54 | - input (bsz, len, w_dim) 55 | - length (bsz, ) 56 | - h_0 (num_layers, bsz, h_dim) 57 | Return: 58 | - hidden (bsz, len, dim) : hidden state of each word 59 | - output (bsz, dim) : sentence embedding 60 | """ 61 | bsz, max_len = input.size(0), input.size(1) 62 | sorted_seq_lengths, indices = torch.sort(length, descending=True) 63 | _, desorted_indices = torch.sort(indices, descending=False) 64 | input = input[indices] 65 | packed_input = nn.utils.rnn.pack_padded_sequence(input, sorted_seq_lengths, batch_first=True) 66 | if h_0 is None: 67 | hidden, h_n = self.encoder(packed_input) 68 | else: 69 | h_0 = h_0[:, indices] 70 | hidden, h_n = self.encoder(packed_input, h_0) 71 | # h_n is (num_layers, bsz, h_dim) 72 | hidden = nn.utils.rnn.pad_packed_sequence(hidden, batch_first=True, total_length=max_len)[0] # (bsz, max_len, h_dim) 73 | 74 | output = h_n[-1, :, :] # (bsz, h_dim), take the last layer's state 75 | 76 | # recover order 77 | hidden = hidden[desorted_indices] 78 | output = output[desorted_indices] 79 | h_n
= h_n[:, desorted_indices] 80 | return hidden, output, h_n 81 | 82 | 83 | 84 | class BiGRU(nn.Module): 85 | 86 | def __init__(self, dim_word, dim_h, num_layers, dropout): 87 | super().__init__() 88 | self.encoder = nn.GRU(input_size=dim_word, 89 | hidden_size=dim_h//2, 90 | num_layers=num_layers, 91 | dropout=dropout, 92 | batch_first=True, 93 | bidirectional=True) 94 | 95 | def forward(self, input, length): 96 | """ 97 | Args: 98 | - input (bsz, len, w_dim) 99 | - length (bsz, ) 100 | Return: 101 | - hidden (bsz, len, dim) : hidden state of each word 102 | - output (bsz, dim) : sentence embedding 103 | - h_n (num_layers * 2, bsz, dim//2) 104 | """ 105 | bsz, max_len = input.size(0), input.size(1) 106 | sorted_seq_lengths, indices = torch.sort(length, descending=True) 107 | _, desorted_indices = torch.sort(indices, descending=False) 108 | input = input[indices] 109 | packed_input = nn.utils.rnn.pack_padded_sequence(input, sorted_seq_lengths, batch_first=True) 110 | hidden, h_n = self.encoder(packed_input) 111 | # h_n is (num_layers * num_directions, bsz, h_dim//2) 112 | hidden = nn.utils.rnn.pad_packed_sequence(hidden, batch_first=True, total_length=max_len)[0] # (bsz, max_len, h_dim) 113 | 114 | output = h_n[-2:, :, :] # (2, bsz, h_dim//2), take the last layer's state 115 | output = output.permute(1, 0, 2).contiguous().view(bsz, -1) # (bsz, h_dim), merge forward and backward h_n 116 | 117 | # recover order 118 | hidden = hidden[desorted_indices] 119 | output = output[desorted_indices] 120 | h_n = h_n[:, desorted_indices] 121 | return hidden, output, h_n -------------------------------------------------------------------------------- /Pretraining/data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import torch 4 | from Pretraining.utils import invert_dict 5 | 6 | def load_vocab(path): 7 | vocab = json.load(open(path)) 8 | vocab['id2function'] = invert_dict(vocab['function2id']) 9 | return vocab 10 | 11 | def collate(batch): 12 | batch = list(zip(*batch)) 13 | input_ids, token_type_ids, attention_mask, function_ids, relation_pos, relation_id = list(map(torch.stack, batch)) 14 | return input_ids, token_type_ids, attention_mask, function_ids, relation_pos, relation_id 15 | 16 | class Dataset(torch.utils.data.Dataset): 17 | def __init__(self, inputs): 18 | self.input_ids, self.token_type_ids, self.attention_mask, self.function_ids, self.relation_pos, self.relation_id = inputs 19 | 20 | def __getitem__(self, index): 21 | input_ids = torch.LongTensor(self.input_ids[index]) 22 | token_type_ids = torch.LongTensor(self.token_type_ids[index]) 23 | attention_mask = torch.LongTensor(self.attention_mask[index]) 24 | function_ids = torch.LongTensor(self.function_ids[index]) 25 | relation_pos = torch.LongTensor(self.relation_pos[index]) 26 | relation_id = torch.LongTensor(self.relation_id[index]) 27 | return input_ids, token_type_ids, attention_mask, function_ids, relation_pos, relation_id 28 | 29 | 30 | def __len__(self): 31 | return len(self.input_ids) 32 | 33 | 34 | class DataLoader(torch.utils.data.DataLoader): 35 | def __init__(self, vocab_json, question_pt, batch_size, training=False): 36 | vocab = load_vocab(vocab_json) 37 | inputs = [] 38 | with open(question_pt, 'rb') as f: 39 | for _ in range(6): 40 | inputs.append(pickle.load(f)) 41 | dataset = Dataset(inputs) 42 | 43 | super().__init__( 44 | dataset, 45 | batch_size=batch_size, 46 | shuffle=training, 47 | collate_fn=collate, 48 | ) 49 | self.vocab = vocab 
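A minimal usage sketch for the modules above (the GRU/BiGRU encoders in BiGRU.py and the DataLoader in data.py). It is illustrative only: the batch size, embedding sizes, and the 'vocab.json' / 'train.pt' paths are assumptions, not values fixed by the repository.

import torch
from Pretraining.BiGRU import BiGRU
from Pretraining.data import DataLoader

# Batches come out of collate() as six stacked LongTensors, in the order
# returned by Dataset.__getitem__.
loader = DataLoader('vocab.json', 'train.pt', batch_size=16, training=True)  # paths are assumptions
for input_ids, token_type_ids, attention_mask, function_ids, relation_pos, relation_id in loader:
    print(input_ids.shape)  # [16, max_seq_length]
    break

# The encoders consume pre-computed word embeddings plus per-example lengths;
# sorting/desorting for pack_padded_sequence happens inside forward().
encoder = BiGRU(dim_word=300, dim_h=1024, num_layers=2, dropout=0.2)
emb = torch.randn(16, 20, 300)         # (bsz, len, w_dim), dummy embeddings
lengths = torch.randint(1, 21, (16,))  # (bsz, )
hidden, sent_emb, h_n = encoder(emb, lengths)
print(hidden.shape, sent_emb.shape)    # (16, 20, 1024), (16, 1024)

Because both GRU and BiGRU sort the batch by length and restore the original order before returning, callers can pass unsorted batches directly.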
-------------------------------------------------------------------------------- /Pretraining/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import warnings 4 | from torch.optim.optimizer import Optimizer 5 | from torch.optim.lr_scheduler import LambdaLR 6 | 7 | def get_constant_schedule(optimizer, last_epoch=-1): 8 | """ Create a schedule with a constant learning rate. 9 | """ 10 | return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch) 11 | 12 | 13 | def get_constant_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch=-1): 14 | """ Create a schedule with a constant learning rate preceded by a warmup 15 | period during which the learning rate increases linearly between 0 and 1. 16 | """ 17 | def lr_lambda(current_step): 18 | if current_step < num_warmup_steps: 19 | return float(current_step) / float(max(1.0, num_warmup_steps)) 20 | return 1. 21 | 22 | return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) 23 | 24 | 25 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): 26 | """ Create a schedule with a learning rate that decreases linearly after 27 | linearly increasing during a warmup period. 28 | """ 29 | def lr_lambda(current_step): 30 | if current_step < num_warmup_steps: 31 | return float(current_step) / float(max(1, num_warmup_steps)) 32 | return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))) 33 | 34 | return LambdaLR(optimizer, lr_lambda, last_epoch) 35 | 36 | 37 | def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=.5, last_epoch=-1): 38 | """ Create a schedule with a learning rate that decreases following the 39 | values of the cosine function between 0 and `pi * cycles` after a warmup 40 | period during which it increases linearly between 0 and 1. 41 | """ 42 | def lr_lambda(current_step): 43 | if current_step < num_warmup_steps: 44 | return float(current_step) / float(max(1, num_warmup_steps)) 45 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 46 | return max(0., 0.5 * (1. + math.cos(math.pi * float(num_cycles) * 2. * progress))) 47 | 48 | return LambdaLR(optimizer, lr_lambda, last_epoch) 49 | 50 | 51 | def get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=1., last_epoch=-1): 52 | """ Create a schedule with a learning rate that decreases following the 53 | values of the cosine function with several hard restarts, after a warmup 54 | period during which it increases linearly between 0 and 1. 55 | """ 56 | def lr_lambda(current_step): 57 | if current_step < num_warmup_steps: 58 | return float(current_step) / float(max(1, num_warmup_steps)) 59 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 60 | if progress >= 1.: 61 | return 0. 62 | return max(0., 0.5 * (1. + math.cos(math.pi * ((float(num_cycles) * progress) % 1.)))) 63 | 64 | return LambdaLR(optimizer, lr_lambda, last_epoch) 65 | 66 | 67 | class CustomDecayLR(object): 68 | ''' 69 | 自定义学习率变化机制 70 | Example: 71 | >>> scheduler = CustomDecayLR(optimizer) 72 | >>> for epoch in range(100): 73 | >>> scheduler.epoch_step() 74 | >>> train(...) 75 | >>> ... 76 | >>> optimizer.zero_grad() 77 | >>> loss.backward() 78 | >>> optimizer.step() 79 | >>> validate(...) 
80 | ''' 81 | def __init__(self,optimizer,lr): 82 | self.optimizer = optimizer 83 | self.lr = lr 84 | 85 | def epoch_step(self,epoch): 86 | lr = self.lr 87 | if epoch > 12: 88 | lr = lr / 1000 89 | elif epoch > 8: 90 | lr = lr / 100 91 | elif epoch > 4: 92 | lr = lr / 10 93 | for param_group in self.optimizer.param_groups: 94 | param_group['lr'] = lr 95 | 96 | class BertLR(object): 97 | ''' 98 | Bert模型内定的学习率变化机制 99 | Example: 100 | >>> scheduler = BertLR(optimizer) 101 | >>> for epoch in range(100): 102 | >>> scheduler.step() 103 | >>> train(...) 104 | >>> ... 105 | >>> optimizer.zero_grad() 106 | >>> loss.backward() 107 | >>> optimizer.step() 108 | >>> scheduler.batch_step() 109 | >>> validate(...) 110 | ''' 111 | def __init__(self,optimizer,learning_rate,t_total,warmup): 112 | self.learning_rate = learning_rate 113 | self.optimizer = optimizer 114 | self.t_total = t_total 115 | self.warmup = warmup 116 | 117 | # 线性预热方式 118 | def warmup_linear(self,x, warmup=0.002): 119 | if x < warmup: 120 | return x / warmup 121 | return 1.0 - x 122 | 123 | def batch_step(self,training_step): 124 | lr_this_step = self.learning_rate * self.warmup_linear(training_step / self.t_total,self.warmup) 125 | for param_group in self.optimizer.param_groups: 126 | param_group['lr'] = lr_this_step 127 | 128 | class CyclicLR(object): 129 | ''' 130 | Cyclical learning rates for training neural networks 131 | Example: 132 | >>> scheduler = CyclicLR(optimizer) 133 | >>> for epoch in range(100): 134 | >>> scheduler.step() 135 | >>> train(...) 136 | >>> ... 137 | >>> optimizer.zero_grad() 138 | >>> loss.backward() 139 | >>> optimizer.step() 140 | >>> scheduler.batch_step() 141 | >>> validate(...) 142 | ''' 143 | def __init__(self, optimizer, base_lr=1e-3, max_lr=6e-3, 144 | step_size=2000, mode='triangular', gamma=1., 145 | scale_fn=None, scale_mode='cycle', last_batch_iteration=-1): 146 | 147 | if not isinstance(optimizer, Optimizer): 148 | raise TypeError('{} is not an Optimizer'.format( 149 | type(optimizer).__name__)) 150 | 151 | self.optimizer = optimizer 152 | 153 | if isinstance(base_lr, list) or isinstance(base_lr, tuple): 154 | if len(base_lr) != len(optimizer.param_groups): 155 | raise ValueError("expected {} base_lr, got {}".format( 156 | len(optimizer.param_groups), len(base_lr))) 157 | self.base_lrs = list(base_lr) 158 | else: 159 | self.base_lrs = [base_lr] * len(optimizer.param_groups) 160 | 161 | if isinstance(max_lr, list) or isinstance(max_lr, tuple): 162 | if len(max_lr) != len(optimizer.param_groups): 163 | raise ValueError("expected {} max_lr, got {}".format( 164 | len(optimizer.param_groups), len(max_lr))) 165 | self.max_lrs = list(max_lr) 166 | else: 167 | self.max_lrs = [max_lr] * len(optimizer.param_groups) 168 | 169 | self.step_size = step_size 170 | 171 | if mode not in ['triangular', 'triangular2', 'exp_range'] \ 172 | and scale_fn is None: 173 | raise ValueError('mode is invalid and scale_fn is None') 174 | 175 | self.mode = mode 176 | self.gamma = gamma 177 | 178 | if scale_fn is None: 179 | if self.mode == 'triangular': 180 | self.scale_fn = self._triangular_scale_fn 181 | self.scale_mode = 'cycle' 182 | elif self.mode == 'triangular2': 183 | self.scale_fn = self._triangular2_scale_fn 184 | self.scale_mode = 'cycle' 185 | elif self.mode == 'exp_range': 186 | self.scale_fn = self._exp_range_scale_fn 187 | self.scale_mode = 'iterations' 188 | else: 189 | self.scale_fn = scale_fn 190 | self.scale_mode = scale_mode 191 | 192 | self.batch_step(last_batch_iteration + 1) 193 | 
self.last_batch_iteration = last_batch_iteration 194 | 195 | def _triangular_scale_fn(self, x): 196 | return 1. 197 | 198 | def _triangular2_scale_fn(self, x): 199 | return 1 / (2. ** (x - 1)) 200 | 201 | def _exp_range_scale_fn(self, x): 202 | return self.gamma**(x) 203 | 204 | def get_lr(self): 205 | step_size = float(self.step_size) 206 | cycle = np.floor(1 + self.last_batch_iteration / (2 * step_size)) 207 | x = np.abs(self.last_batch_iteration / step_size - 2 * cycle + 1) 208 | 209 | lrs = [] 210 | param_lrs = zip(self.optimizer.param_groups, self.base_lrs, self.max_lrs) 211 | for param_group, base_lr, max_lr in param_lrs: 212 | base_height = (max_lr - base_lr) * np.maximum(0, (1 - x)) 213 | if self.scale_mode == 'cycle': 214 | lr = base_lr + base_height * self.scale_fn(cycle) 215 | else: 216 | lr = base_lr + base_height * self.scale_fn(self.last_batch_iteration) 217 | lrs.append(lr) 218 | return lrs 219 | 220 | def batch_step(self, batch_iteration=None): 221 | if batch_iteration is None: 222 | batch_iteration = self.last_batch_iteration + 1 223 | self.last_batch_iteration = batch_iteration 224 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 225 | param_group['lr'] = lr 226 | 227 | class ReduceLROnPlateau(object): 228 | """Reduce learning rate when a metric has stopped improving. 229 | Models often benefit from reducing the learning rate by a factor 230 | of 2-10 once learning stagnates. This scheduler reads a metrics 231 | quantity and if no improvement is seen for a 'patience' number 232 | of epochs, the learning rate is reduced. 233 | 234 | Args: 235 | factor: factor by which the learning rate will 236 | be reduced. new_lr = lr * factor 237 | patience: number of epochs with no improvement 238 | after which learning rate will be reduced. 239 | verbose: int. 0: quiet, 1: update messages. 240 | mode: one of {min, max}. In `min` mode, 241 | lr will be reduced when the quantity 242 | monitored has stopped decreasing; in `max` 243 | mode it will be reduced when the quantity 244 | monitored has stopped increasing. 245 | epsilon: threshold for measuring the new optimum, 246 | to only focus on significant changes. 247 | cooldown: number of epochs to wait before resuming 248 | normal operation after lr has been reduced. 249 | min_lr: lower bound on the learning rate. 250 | 251 | 252 | Example: 253 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) 254 | >>> scheduler = ReduceLROnPlateau(optimizer, 'min') 255 | >>> for epoch in range(10): 256 | >>> train(...) 257 | >>> val_acc, val_loss = validate(...) 258 | >>> scheduler.epoch_step(val_loss, epoch) 259 | """ 260 | 261 | def __init__(self, optimizer, mode='min', factor=0.1, patience=10, 262 | verbose=0, epsilon=1e-4, cooldown=0, min_lr=0,eps=1e-8): 263 | 264 | super(ReduceLROnPlateau, self).__init__() 265 | assert isinstance(optimizer, Optimizer) 266 | if factor >= 1.0: 267 | raise ValueError('ReduceLROnPlateau ' 268 | 'does not support a factor >= 1.0.') 269 | self.factor = factor 270 | self.min_lr = min_lr 271 | self.epsilon = epsilon 272 | self.patience = patience 273 | self.verbose = verbose 274 | self.cooldown = cooldown 275 | self.cooldown_counter = 0 # Cooldown counter. 276 | self.monitor_op = None 277 | self.wait = 0 278 | self.best = 0 279 | self.mode = mode 280 | self.optimizer = optimizer 281 | self.eps = eps 282 | self._reset() 283 | 284 | def _reset(self): 285 | """Resets wait counter and cooldown counter. 
286 | """ 287 | if self.mode not in ['min', 'max']: 288 | raise RuntimeError('Learning Rate Plateau Reducing mode %s is unknown!') 289 | if self.mode == 'min': 290 | self.monitor_op = lambda a, b: np.less(a, b - self.epsilon) 291 | self.best = np.Inf 292 | else: 293 | self.monitor_op = lambda a, b: np.greater(a, b + self.epsilon) 294 | self.best = -np.Inf 295 | self.cooldown_counter = 0 296 | self.wait = 0 297 | 298 | def reset(self): 299 | self._reset() 300 | 301 | def epoch_step(self, metrics, epoch): 302 | current = metrics 303 | if current is None: 304 | warnings.warn('Learning Rate Plateau Reducing requires metrics available!', RuntimeWarning) 305 | else: 306 | if self.in_cooldown(): 307 | self.cooldown_counter -= 1 308 | self.wait = 0 309 | 310 | if self.monitor_op(current, self.best): 311 | self.best = current 312 | self.wait = 0 313 | elif not self.in_cooldown(): 314 | if self.wait >= self.patience: 315 | for param_group in self.optimizer.param_groups: 316 | old_lr = float(param_group['lr']) 317 | if old_lr > self.min_lr + self.eps: 318 | new_lr = old_lr * self.factor 319 | new_lr = max(new_lr, self.min_lr) 320 | param_group['lr'] = new_lr 321 | if self.verbose > 0: 322 | print('\nEpoch %05d: reducing learning rate to %s.' % (epoch, new_lr)) 323 | self.cooldown_counter = self.cooldown 324 | self.wait = 0 325 | self.wait += 1 326 | 327 | def in_cooldown(self): 328 | return self.cooldown_counter > 0 329 | 330 | class ReduceLRWDOnPlateau(ReduceLROnPlateau): 331 | """Reduce learning rate and weight decay when a metric has stopped 332 | improving. Models often benefit from reducing the learning rate by 333 | a factor of 2-10 once learning stagnates. This scheduler reads a metric 334 | quantity and if no improvement is seen for a 'patience' number 335 | of epochs, the learning rate and weight decay factor is reduced for 336 | optimizers that implement the the weight decay method from the paper 337 | `Fixing Weight Decay Regularization in Adam`_. 338 | 339 | .. _Fixing Weight Decay Regularization in Adam: 340 | https://arxiv.org/abs/1711.05101 341 | for AdamW or SGDW 342 | Example: 343 | >>> optimizer = AdamW(model.parameters(), lr=0.1, weight_decay=1e-3) 344 | >>> scheduler = ReduceLRWDOnPlateau(optimizer, 'min') 345 | >>> for epoch in range(10): 346 | >>> train(...) 347 | >>> val_loss = validate(...) 348 | >>> # Note that step should be called after validate() 349 | >>> scheduler.epoch_step(val_loss) 350 | """ 351 | def epoch_step(self, metrics, epoch): 352 | current = metrics 353 | if current is None: 354 | warnings.warn('Learning Rate Plateau Reducing requires metrics available!', RuntimeWarning) 355 | else: 356 | if self.in_cooldown(): 357 | self.cooldown_counter -= 1 358 | self.wait = 0 359 | 360 | if self.monitor_op(current, self.best): 361 | self.best = current 362 | self.wait = 0 363 | elif not self.in_cooldown(): 364 | if self.wait >= self.patience: 365 | for param_group in self.optimizer.param_groups: 366 | old_lr = float(param_group['lr']) 367 | if old_lr > self.min_lr + self.eps: 368 | new_lr = old_lr * self.factor 369 | new_lr = max(new_lr, self.min_lr) 370 | param_group['lr'] = new_lr 371 | if self.verbose > 0: 372 | print('\nEpoch %d: reducing learning rate to %s.' 
% (epoch, new_lr)) 373 | if param_group['weight_decay'] != 0: 374 | old_weight_decay = float(param_group['weight_decay']) 375 | new_weight_decay = max(old_weight_decay * self.factor, self.min_lr) 376 | if old_weight_decay > new_weight_decay + self.eps: 377 | param_group['weight_decay'] = new_weight_decay 378 | if self.verbose: 379 | print('\nEpoch {epoch}: reducing weight decay factor of group {i} to {new_weight_decay:.4e}.') 380 | self.cooldown_counter = self.cooldown 381 | self.wait = 0 382 | self.wait += 1 383 | 384 | class CosineLRWithRestarts(object): 385 | """Decays learning rate with cosine annealing, normalizes weight decay 386 | hyperparameter value, implements restarts. 387 | https://arxiv.org/abs/1711.05101 388 | 389 | Args: 390 | optimizer (Optimizer): Wrapped optimizer. 391 | batch_size: minibatch size 392 | epoch_size: training samples per epoch 393 | restart_period: epoch count in the first restart period 394 | t_mult: multiplication factor by which the next restart period will extend/shrink 395 | 396 | Example: 397 | >>> scheduler = CosineLRWithRestarts(optimizer, 32, 1024, restart_period=5, t_mult=1.2) 398 | >>> for epoch in range(100): 399 | >>> scheduler.step() 400 | >>> train(...) 401 | >>> ... 402 | >>> optimizer.zero_grad() 403 | >>> loss.backward() 404 | >>> optimizer.step() 405 | >>> scheduler.batch_step() 406 | >>> validate(...) 407 | """ 408 | 409 | def __init__(self, optimizer, batch_size, epoch_size, restart_period=100, 410 | t_mult=2, last_epoch=-1, eta_threshold=1000, verbose=False): 411 | if not isinstance(optimizer, Optimizer): 412 | raise TypeError('{} is not an Optimizer'.format( 413 | type(optimizer).__name__)) 414 | self.optimizer = optimizer 415 | if last_epoch == -1: 416 | for group in optimizer.param_groups: 417 | group.setdefault('initial_lr', group['lr']) 418 | else: 419 | for i, group in enumerate(optimizer.param_groups): 420 | if 'initial_lr' not in group: 421 | raise KeyError("param 'initial_lr' is not specified " 422 | "in param_groups[{}] when resuming an" 423 | " optimizer".format(i)) 424 | self.base_lrs = list(map(lambda group: group['initial_lr'], 425 | optimizer.param_groups)) 426 | 427 | self.last_epoch = last_epoch 428 | self.batch_size = batch_size 429 | self.iteration = 0 430 | self.epoch_size = epoch_size 431 | self.eta_threshold = eta_threshold 432 | self.t_mult = t_mult 433 | self.verbose = verbose 434 | self.base_weight_decays = list(map(lambda group: group['weight_decay'], 435 | optimizer.param_groups)) 436 | self.restart_period = restart_period 437 | self.restarts = 0 438 | self.t_epoch = -1 439 | self.batch_increments = [] 440 | self._set_batch_increment() 441 | 442 | def _schedule_eta(self): 443 | """ 444 | Threshold value could be adjusted to shrink eta_min and eta_max values. 445 | """ 446 | eta_min = 0 447 | eta_max = 1 448 | if self.restarts <= self.eta_threshold: 449 | return eta_min, eta_max 450 | else: 451 | d = self.restarts - self.eta_threshold 452 | k = d * 0.09 453 | return (eta_min + k, eta_max - k) 454 | 455 | def get_lr(self, t_cur): 456 | eta_min, eta_max = self._schedule_eta() 457 | 458 | eta_t = (eta_min + 0.5 * (eta_max - eta_min) 459 | * (1. 
+ math.cos(math.pi * 460 | (t_cur / self.restart_period)))) 461 | 462 | weight_decay_norm_multi = math.sqrt(self.batch_size / 463 | (self.epoch_size * 464 | self.restart_period)) 465 | lrs = [base_lr * eta_t for base_lr in self.base_lrs] 466 | weight_decays = [base_weight_decay * eta_t * weight_decay_norm_multi 467 | for base_weight_decay in self.base_weight_decays] 468 | 469 | if self.t_epoch % self.restart_period < self.t_epoch: 470 | if self.verbose: 471 | print("Restart at epoch {}".format(self.last_epoch)) 472 | self.restart_period *= self.t_mult 473 | self.restarts += 1 474 | self.t_epoch = 0 475 | 476 | return zip(lrs, weight_decays) 477 | 478 | def _set_batch_increment(self): 479 | d, r = divmod(self.epoch_size, self.batch_size) 480 | batches_in_epoch = d + 2 if r > 0 else d + 1 481 | self.iteration = 0 482 | self.batch_increments = list(np.linspace(0, 1, batches_in_epoch)) 483 | 484 | def batch_step(self): 485 | self.last_epoch += 1 486 | self.t_epoch += 1 487 | self._set_batch_increment() 488 | try: 489 | t_cur = self.t_epoch + self.batch_increments[self.iteration] 490 | self.iteration += 1 491 | except (IndexError): 492 | raise RuntimeError("Epoch size and batch size used in the " 493 | "training loop and while initializing " 494 | "scheduler should be the same.") 495 | 496 | for param_group, (lr, weight_decay) in zip(self.optimizer.param_groups,self.get_lr(t_cur)): 497 | param_group['lr'] = lr 498 | param_group['weight_decay'] = weight_decay 499 | 500 | 501 | class NoamLR(object): 502 | ''' 503 | 主要参考论文<< Attention Is All You Need>>中的学习更新方式 504 | Example: 505 | >>> scheduler = NoamLR(d_model,factor,warm_up,optimizer) 506 | >>> for epoch in range(100): 507 | >>> scheduler.step() 508 | >>> train(...) 509 | >>> ... 510 | >>> glopab_step += 1 511 | >>> optimizer.zero_grad() 512 | >>> loss.backward() 513 | >>> optimizer.step() 514 | >>> scheduler.batch_step(global_step) 515 | >>> validate(...) 516 | ''' 517 | def __init__(self,d_model,factor,warm_up,optimizer): 518 | self.optimizer = optimizer 519 | self.warm_up = warm_up 520 | self.factor = factor 521 | self.d_model = d_model 522 | self._lr = 0 523 | 524 | def get_lr(self,step): 525 | lr = self.factor * (self.d_model ** (-0.5) * min(step ** (-0.5),step * self.warm_up ** (-1.5))) 526 | return lr 527 | 528 | def batch_step(self,step): 529 | ''' 530 | update parameters and rate 531 | :return: 532 | ''' 533 | lr = self.get_lr(step) 534 | for p in self.optimizer.param_groups: 535 | p['lr'] = lr 536 | self._lr = lr 537 | -------------------------------------------------------------------------------- /Pretraining/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import Counter 3 | 4 | def get_entities(seq,id2label,markup='bios'): 5 | """Gets entities from sequence. 6 | note: BIOS 7 | Args: 8 | seq (list): sequence of labels. 9 | Returns: 10 | list: list of (chunk_type, chunk_start, chunk_end). 
11 | Example: 12 | # >>> seq = ['B-PER', 'I-PER', 'O', 'S-LOC'] 13 | # >>> get_entity_bios(seq) 14 | [['PER', 0,1], ['LOC', 3, 3]] 15 | """ 16 | chunks = [] 17 | chunk = [-1, -1, -1] 18 | for indx, tag in enumerate(seq): 19 | if not isinstance(tag, str): 20 | tag = id2label[tag] 21 | if tag.startswith("S-"): 22 | if chunk[2] != -1: 23 | chunks.append(chunk) 24 | chunk = [-1, -1, -1] 25 | chunk[1] = indx 26 | chunk[2] = indx 27 | chunk[0] = tag.split('-')[1] 28 | chunks.append(chunk) 29 | chunk = (-1, -1, -1) 30 | if tag.startswith("B-"): 31 | if chunk[2] != -1: 32 | chunks.append(chunk) 33 | chunk = [-1, -1, -1] 34 | chunk[1] = indx 35 | chunk[0] = tag.split('-')[1] 36 | elif tag.startswith('I-') and chunk[1] != -1: 37 | _type = tag.split('-')[1] 38 | if _type == chunk[0]: 39 | chunk[2] = indx 40 | if indx == len(seq) - 1: 41 | chunks.append(chunk) 42 | else: 43 | if chunk[2] != -1: 44 | chunks.append(chunk) 45 | chunk = [-1, -1, -1] 46 | return chunks 47 | 48 | class F1(object): 49 | def __init__(self): 50 | self.origins = [] 51 | self.founds = [] 52 | self.rights = [] 53 | def compute(self, origin, found, right): 54 | recall = 0 if origin == 0 else (right / origin) 55 | precision = 0 if found == 0 else (right / found) 56 | f1 = 0. if recall + precision == 0 else (2 * precision * recall) / (precision + recall) 57 | return recall, precision, f1 58 | def result(self): 59 | num_examples = len(self.origins) 60 | origins = 0 61 | founds = 0 62 | rights = 0 63 | for i in range(num_examples): 64 | origin = self.origins[i] 65 | found = self.founds[i] 66 | right = self.rights[i] 67 | origins += len(origin) 68 | founds += len(found) 69 | rights += len(right) 70 | # print('origins: {}'.format(origins)) 71 | # print('founds: {}'.format(founds)) 72 | # print('rights: {}'.format(rights)) 73 | recall, precision, f1 = self.compute(origins, founds, rights) 74 | return {'acc': precision, 'recall': recall, 'f1': f1} 75 | 76 | def update(self, pred, label): 77 | self.origins.append(label) 78 | self.founds.append(pred) 79 | self.rights.append(list(set(pred) & set(label))) 80 | 81 | class FunctionAcc(object): 82 | def __init__(self, end_id): 83 | self.correct = 0 84 | self.tot = 0 85 | self.end_id = end_id 86 | def result(self): 87 | acc = self.correct / self.tot 88 | return acc 89 | def update(self, pred, label): 90 | match = True 91 | for i in range(min(len(pred), len(label))): 92 | if label[i] != pred[i]: 93 | match = False 94 | break 95 | if pred[i] == self.end_id and label[i] == self.end_id: 96 | break 97 | if match: 98 | self.correct += 1 99 | self.tot += 1 100 | 101 | class SeqEntityScore(object): 102 | def __init__(self, id2label,markup='bios'): 103 | self.id2label = id2label 104 | self.markup = markup 105 | self.reset() 106 | 107 | def reset(self): 108 | self.origins = [] 109 | self.founds = [] 110 | self.rights = [] 111 | 112 | def compute(self, origin, found, right): 113 | recall = 0 if origin == 0 else (right / origin) 114 | precision = 0 if found == 0 else (right / found) 115 | f1 = 0. 
if recall + precision == 0 else (2 * precision * recall) / (precision + recall) 116 | return recall, precision, f1 117 | 118 | def result(self): 119 | class_info = {} 120 | origin_counter = Counter([x[0] for x in self.origins]) 121 | found_counter = Counter([x[0] for x in self.founds]) 122 | right_counter = Counter([x[0] for x in self.rights]) 123 | for type_, count in origin_counter.items(): 124 | origin = count 125 | found = found_counter.get(type_, 0) 126 | right = right_counter.get(type_, 0) 127 | recall, precision, f1 = self.compute(origin, found, right) 128 | class_info[type_] = {"acc": round(precision, 4), 'recall': round(recall, 4), 'f1': round(f1, 4)} 129 | origin = len(self.origins) 130 | found = len(self.founds) 131 | right = len(self.rights) 132 | recall, precision, f1 = self.compute(origin, found, right) 133 | return {'acc': precision, 'recall': recall, 'f1': f1}, class_info 134 | 135 | def update(self, label_paths, pred_paths): 136 | ''' 137 | labels_paths: [[],[],[],....] 138 | pred_paths: [[],[],[],.....] 139 | 140 | :param label_paths: 141 | :param pred_paths: 142 | :return: 143 | Example: 144 | >>> labels_paths = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']] 145 | >>> pred_paths = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']] 146 | ''' 147 | for label_path, pre_path in zip(label_paths, pred_paths): 148 | label_entities = get_entities(label_path, self.id2label,self.markup) 149 | pre_entities = get_entities(pre_path, self.id2label,self.markup) 150 | self.origins.extend(label_entities) 151 | self.founds.extend(pre_entities) 152 | self.rights.extend([pre_entity for pre_entity in pre_entities if pre_entity in label_entities]) -------------------------------------------------------------------------------- /Pretraining/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from transformers import * 5 | from Pretraining.utils import * 6 | from Pretraining.BiGRU import GRU, BiGRU 7 | class RelationPT(BertPreTrainedModel): 8 | def __init__(self, config): 9 | super(RelationPT, self).__init__(config) 10 | self.vocab = config.vocab 11 | self.bert = BertModel(config) 12 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 13 | self.num_functions = len(config.vocab['function2id']) 14 | self.function_embeddings = nn.Embedding(self.num_functions, config.hidden_size) 15 | self.function_decoder = GRU(config.hidden_size, config.hidden_size, num_layers = 1, dropout = 0.2) 16 | self.function_classifier = nn.Sequential( 17 | nn.Linear(config.hidden_size, 1024), 18 | nn.ReLU(), 19 | nn.Linear(1024, self.num_functions), 20 | ) 21 | self.word_dropout = nn.Dropout(0.2) 22 | self.max_program_len = 17 23 | # self.relation_inputs = config.relation_inputs 24 | self.relation_classifier = nn.Sequential( 25 | nn.Linear(config.hidden_size, 1024), 26 | nn.ReLU(), 27 | nn.Linear(1024, config.hidden_size), 28 | ) 29 | 30 | self.concept_classifier = nn.Sequential( 31 | nn.Linear(config.hidden_size, 1024), 32 | nn.ReLU(), 33 | nn.Linear(1024, config.hidden_size), 34 | ) 35 | 36 | self.hidden_size = config.hidden_size 37 | self.init_weights() 38 | 39 | def demo(self, input_ids, token_type_ids, attention_mask): 40 | outputs = self.bert(input_ids = input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids) 41 | sequence_output = outputs[0] # [bsz, max_seq_length, hidden_size] 42 | pooler_output = outputs[1] # 
[bsz, hidden_size] 43 | outputs = {} 44 | sequence_output = self.dropout(sequence_output) 45 | bsz = input_ids.size(0) 46 | device = input_ids.device 47 | start_id = self.vocab['function2id'][''] 48 | end_id = self.vocab['function2id'][''] 49 | finished = torch.zeros((bsz,)).byte().to(device) # record whether is produced 50 | latest_func = torch.LongTensor([start_id]*bsz).to(device) # [bsz, ] 51 | programs = [latest_func] 52 | last_h = pooler_output.unsqueeze(0) 53 | for i in range(self.max_program_len): 54 | p_word_emb = self.word_dropout(self.function_embeddings(latest_func)).unsqueeze(1) # [bsz, 1, dim_w] 55 | p_word_h, last_h = self.function_decoder.forward_one_step(p_word_emb, last_h) # [bsz, 1, dim_h] 56 | # attention over question words 57 | attn = torch.softmax(torch.bmm(p_word_h, sequence_output.permute(0, 2, 1)), dim=2) # [bsz, 1, max_q] 58 | attn_word_h = torch.bmm(attn, sequence_output) # [bsz, 1, dim_h] 59 | # sum up 60 | p_word_h = p_word_h + attn_word_h # [bsz, 1, dim_h] 61 | 62 | # predict function 63 | logit_func = self.function_classifier(p_word_h).squeeze(1) # [bsz, num_func] 64 | latest_func = torch.argmax(logit_func, dim=1) # [bsz, ] 65 | programs.append(latest_func) 66 | finished = finished | latest_func.eq(end_id).byte() 67 | if finished.sum().item() == bsz: 68 | # print('finished at step {}'.format(i)) 69 | break 70 | programs = torch.stack(programs, dim=1) # [bsz, max_prog] 71 | outputs['pred_functions'] = programs 72 | return outputs 73 | def forward(self, concept_inputs, relation_inputs, input_ids, token_type_ids, attention_mask, function_ids, relation_info, concept_info): 74 | outputs = self.bert(input_ids = input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids) 75 | sequence_output = outputs[0] # [bsz, max_seq_length, hidden_size] 76 | pooler_output = outputs[1] # [bsz, hidden_size] 77 | outputs = {} 78 | sequence_output = self.dropout(sequence_output) 79 | bsz = input_ids.size(0) 80 | if relation_info is not None and relation_info[1] is not None: 81 | relation_pos, relation_id = relation_info 82 | func_emb = self.word_dropout(self.function_embeddings(function_ids)) 83 | func_lens = function_ids.size(1) - function_ids.eq(0).long().sum(dim=1) 84 | f_word_h, _, _ = self.function_decoder(func_emb, func_lens.cpu(), h_0=pooler_output.unsqueeze(0)) # [bsz, max_prog, dim_h] 85 | attn = torch.softmax(torch.bmm(f_word_h, sequence_output.permute(0, 2, 1)), dim=2) # [bsz, max_prog, max_q] 86 | attn_word_h = torch.bmm(attn, sequence_output) # [bsz, max_prog, dim_h] 87 | f_word_h = f_word_h + attn_word_h # # [bsz, max_prog, dim_h] 88 | function_logits = self.function_classifier(f_word_h) 89 | outputs['function_logits'] = function_logits 90 | outputs['function_loss'] = nn.CrossEntropyLoss()(function_logits.permute(0, 2, 1)[:,:,:-1], function_ids[:,1:]) 91 | # outputs['relation_loss'] = outputs['function_loss'] 92 | dim_h = f_word_h.size(-1) 93 | relation_pos = relation_pos.repeat(1, dim_h).view(bsz, 1, dim_h) 94 | f_word_h = torch.gather(f_word_h, 1, relation_pos).squeeze(1) # [bsz, dim_h] 95 | relation_embeddings = self.bert(input_ids = relation_inputs['input_ids'], \ 96 | attention_mask = relation_inputs['attention_mask'], \ 97 | token_type_ids = relation_inputs['token_type_ids'])[1] # [num_relations, dim_h] 98 | relation_embeddings = self.relation_classifier(relation_embeddings) # [num_relations, dim_h] 99 | relation_logits = f_word_h @ relation_embeddings.t() # [bsz, num_relationis] 100 | outputs['relation_logits'] = relation_logits 101 | relation_id = 
relation_id.squeeze(-1) 102 | # print('relation_logits', relation_logits.size()) 103 | # print('relation_id', relation_id.size()) 104 | outputs['relation_loss'] = nn.CrossEntropyLoss()(relation_logits, relation_id) 105 | 106 | if concept_info is not None and concept_info[1] is not None: 107 | concept_pos, concept_id = concept_info 108 | func_emb = self.word_dropout(self.function_embeddings(function_ids)) 109 | func_lens = function_ids.size(1) - function_ids.eq(0).long().sum(dim=1) 110 | f_word_h, _, _ = self.function_decoder(func_emb, func_lens.cpu(), h_0=pooler_output.unsqueeze(0)) # [bsz, max_prog, dim_h] 111 | attn = torch.softmax(torch.bmm(f_word_h, sequence_output.permute(0, 2, 1)), dim=2) # [bsz, max_prog, max_q] 112 | attn_word_h = torch.bmm(attn, sequence_output) # [bsz, max_prog, dim_h] 113 | f_word_h = f_word_h + attn_word_h # # [bsz, max_prog, dim_h] 114 | function_logits = self.function_classifier(f_word_h) 115 | outputs['function_logits'] = function_logits 116 | outputs['function_loss'] = nn.CrossEntropyLoss()(function_logits.permute(0, 2, 1)[:,:,:-1], function_ids[:,1:]) 117 | # outputs['relation_loss'] = outputs['function_loss'] 118 | dim_h = f_word_h.size(-1) 119 | concept_pos = concept_pos.repeat(1, dim_h).view(bsz, 1, dim_h) 120 | f_word_h = torch.gather(f_word_h, 1, concept_pos).squeeze(1) # [bsz, dim_h] 121 | concept_embeddings = self.bert(input_ids = concept_inputs['input_ids'], \ 122 | attention_mask = concept_inputs['attention_mask'], \ 123 | token_type_ids = concept_inputs['token_type_ids'])[1] # [num_relations, dim_h] 124 | concept_embeddings = self.concept_classifier(concept_embeddings) # [num_relations, dim_h] 125 | concept_logits = f_word_h @ concept_embeddings.t() # [bsz, num_relationis] 126 | outputs['concept_logits'] = concept_logits 127 | concept_id = concept_id.squeeze(-1) 128 | # print('relation_logits', relation_logits.size()) 129 | # print('relation_id', relation_id.size()) 130 | outputs['concept_loss'] = nn.CrossEntropyLoss()(concept_logits, concept_id) 131 | 132 | 133 | if relation_info is not None and relation_info[1] is None: 134 | relation_pos, relation_id = relation_info 135 | bsz = input_ids.size(0) 136 | device = input_ids.device 137 | start_id = self.vocab['function2id'][''] 138 | end_id = self.vocab['function2id'][''] 139 | finished = torch.zeros((bsz,)).byte().to(device) # record whether is produced 140 | latest_func = torch.LongTensor([start_id]*bsz).to(device) # [bsz, ] 141 | programs = [latest_func] 142 | last_h = pooler_output.unsqueeze(0) 143 | for i in range(self.max_program_len): 144 | p_word_emb = self.word_dropout(self.function_embeddings(latest_func)).unsqueeze(1) # [bsz, 1, dim_w] 145 | p_word_h, last_h = self.function_decoder.forward_one_step(p_word_emb, last_h) # [bsz, 1, dim_h] 146 | # attention over question words 147 | attn = torch.softmax(torch.bmm(p_word_h, sequence_output.permute(0, 2, 1)), dim=2) # [bsz, 1, max_q] 148 | attn_word_h = torch.bmm(attn, sequence_output) # [bsz, 1, dim_h] 149 | # sum up 150 | p_word_h = p_word_h + attn_word_h # [bsz, 1, dim_h] 151 | 152 | # predict function 153 | logit_func = self.function_classifier(p_word_h).squeeze(1) # [bsz, num_func] 154 | latest_func = torch.argmax(logit_func, dim=1) # [bsz, ] 155 | programs.append(latest_func) 156 | finished = finished | latest_func.eq(end_id).byte() 157 | if finished.sum().item() == bsz: 158 | # print('finished at step {}'.format(i)) 159 | break 160 | programs = torch.stack(programs, dim=1) # [bsz, max_prog] 161 | outputs['pred_functions'] = programs 162 
| 163 | func_emb = self.word_dropout(self.function_embeddings(function_ids)) 164 | func_lens = function_ids.size(1) - function_ids.eq(0).long().sum(dim=1) 165 | f_word_h, _, _ = self.function_decoder(func_emb, func_lens.cpu(), h_0=pooler_output.unsqueeze(0)) # [bsz, max_prog, dim_h] 166 | attn = torch.softmax(torch.bmm(f_word_h, sequence_output.permute(0, 2, 1)), dim=2) # [bsz, max_prog, max_q] 167 | attn_word_h = torch.bmm(attn, sequence_output) # [bsz, max_prog, dim_h] 168 | f_word_h = f_word_h + attn_word_h # # [bsz, max_prog, dim_h] 169 | # relation_pos = [relation_pos] * self.hidden_size 170 | # a : [bsz, max_prog, dim_h] 171 | # b : [bsz, 1] 172 | # c = b.repeat(1, dim_h).view(bsz,1,dim_h) 173 | # a.gather(1,c).view((bsz, dim_h)) 174 | dim_h = f_word_h.size(-1) 175 | relation_pos = relation_pos.repeat(1, dim_h).view(bsz, 1, dim_h) 176 | f_word_h = torch.gather(f_word_h, 1, relation_pos).squeeze(1) # [bsz, dim_h] 177 | relation_embeddings = self.bert(input_ids = relation_inputs['input_ids'], \ 178 | attention_mask = relation_inputs['attention_mask'], \ 179 | token_type_ids = relation_inputs['token_type_ids'])[1] # [num_relations, dim_h] 180 | relation_embeddings = self.relation_classifier(relation_embeddings) # [num_relations, dim_h] 181 | relation_logits = f_word_h @ relation_embeddings.t() # [bsz, num_relationis] 182 | outputs['pred_relation'] = torch.argmax(relation_logits, dim = 1) 183 | 184 | 185 | if concept_info is not None and concept_info[1] is None: 186 | concept_pos, concept_id = concept_info 187 | bsz = input_ids.size(0) 188 | device = input_ids.device 189 | start_id = self.vocab['function2id'][''] 190 | end_id = self.vocab['function2id'][''] 191 | finished = torch.zeros((bsz,)).byte().to(device) # record whether is produced 192 | latest_func = torch.LongTensor([start_id]*bsz).to(device) # [bsz, ] 193 | programs = [latest_func] 194 | last_h = pooler_output.unsqueeze(0) 195 | for i in range(self.max_program_len): 196 | p_word_emb = self.word_dropout(self.function_embeddings(latest_func)).unsqueeze(1) # [bsz, 1, dim_w] 197 | p_word_h, last_h = self.function_decoder.forward_one_step(p_word_emb, last_h) # [bsz, 1, dim_h] 198 | # attention over question words 199 | attn = torch.softmax(torch.bmm(p_word_h, sequence_output.permute(0, 2, 1)), dim=2) # [bsz, 1, max_q] 200 | attn_word_h = torch.bmm(attn, sequence_output) # [bsz, 1, dim_h] 201 | # sum up 202 | p_word_h = p_word_h + attn_word_h # [bsz, 1, dim_h] 203 | 204 | # predict function 205 | logit_func = self.function_classifier(p_word_h).squeeze(1) # [bsz, num_func] 206 | latest_func = torch.argmax(logit_func, dim=1) # [bsz, ] 207 | programs.append(latest_func) 208 | finished = finished | latest_func.eq(end_id).byte() 209 | if finished.sum().item() == bsz: 210 | # print('finished at step {}'.format(i)) 211 | break 212 | programs = torch.stack(programs, dim=1) # [bsz, max_prog] 213 | outputs['pred_functions'] = programs 214 | 215 | func_emb = self.word_dropout(self.function_embeddings(function_ids)) 216 | func_lens = function_ids.size(1) - function_ids.eq(0).long().sum(dim=1) 217 | f_word_h, _, _ = self.function_decoder(func_emb, func_lens.cpu(), h_0=pooler_output.unsqueeze(0)) # [bsz, max_prog, dim_h] 218 | attn = torch.softmax(torch.bmm(f_word_h, sequence_output.permute(0, 2, 1)), dim=2) # [bsz, max_prog, max_q] 219 | attn_word_h = torch.bmm(attn, sequence_output) # [bsz, max_prog, dim_h] 220 | f_word_h = f_word_h + attn_word_h # # [bsz, max_prog, dim_h] 221 | # concept_pos = [concept_pos] * self.hidden_size 222 | # a : [bsz, 
max_prog, dim_h] 223 | # b : [bsz, 1] 224 | # c = b.repeat(1, dim_h).view(bsz,1,dim_h) 225 | # a.gather(1,c).view((bsz, dim_h)) 226 | dim_h = f_word_h.size(-1) 227 | concept_pos = concept_pos.repeat(1, dim_h).view(bsz, 1, dim_h) 228 | f_word_h = torch.gather(f_word_h, 1, concept_pos).squeeze(1) # [bsz, dim_h] 229 | concept_embeddings = self.bert(input_ids = concept_inputs['input_ids'], \ 230 | attention_mask = concept_inputs['attention_mask'], \ 231 | token_type_ids = concept_inputs['token_type_ids'])[1] # [num_concepts, dim_h] 232 | concept_embeddings = self.concept_classifier(concept_embeddings) # [num_concepts, dim_h] 233 | concept_logits = f_word_h @ concept_embeddings.t() # [bsz, num_conceptis] 234 | outputs['pred_concept'] = torch.argmax(concept_logits, dim = 1) 235 | 236 | return outputs 237 | -------------------------------------------------------------------------------- /Pretraining/preprocess.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | from transformers import * 4 | import argparse 5 | import numpy as np 6 | from tqdm import tqdm 7 | # from fuzzywuzzy import fuzz 8 | import os 9 | import pickle 10 | from Pretraining.utils import * 11 | 12 | tokenizer = BertTokenizer.from_pretrained('/data/csl/resources/Bert/bert-base-cased', do_lower_case = False) 13 | 14 | 15 | 16 | 17 | 18 | def get_vocab(args, vocab): 19 | kb = json.load(open(os.path.join(args.input_dir, 'kb.json'))) 20 | entities = kb['entities'] 21 | for eid in entities: 22 | relations = entities[eid]['relations'] 23 | for relation in relations: 24 | r = relation['predicate'] 25 | if relation['direction'] == 'backward': 26 | r = '[inverse] ' + r 27 | if not r in vocab['relation2id']: 28 | vocab['relation2id'][r] = len(vocab['relation2id']) 29 | vocab['id2relation'] = [relation for relation, id in vocab['relation2id'].items()] 30 | 31 | concepts = kb['concepts'] 32 | for cid in concepts: 33 | concept = concepts[cid]['name'] 34 | if not concept in vocab['concept2id']: 35 | vocab['concept2id'][concept] = len(vocab['concept2id']) 36 | vocab['id2concept'] = [concept for concept, id in vocab['concept2id'].items()] 37 | 38 | train = [json.loads(line.strip()) for line in open(os.path.join(args.input_dir, 'train.json'))] 39 | for item in train: 40 | program = item['program'] 41 | for f in program: 42 | function = f['function'] 43 | if not function in vocab['function2id']: 44 | vocab['function2id'][function] = len(vocab['function2id']) 45 | vocab['id2function'] = [function for function, id in vocab['function2id'].items()] 46 | 47 | 48 | def get_relation_dataset(args, vocab): 49 | # train = json.load(open(os.path.join(args.input_dir, 'train.json'))) 50 | # dev = json.load(open(os.path.join(args.input_dir, 'val.json'))) 51 | train = [json.loads(line.strip()) for line in open(os.path.join(args.input_dir, 'train.json'))] 52 | dev = [json.loads(line.strip()) for line in open(os.path.join(args.input_dir, 'val.json'))] 53 | 54 | for name, raw_data in zip(['train', 'dev'], [train, dev]): 55 | dataset = [] 56 | for item in tqdm(raw_data): 57 | text = item['question'] 58 | program = item['program'] 59 | data = [] 60 | relations = [] 61 | for idx, f in enumerate(program): 62 | function = f['function'] 63 | if function == 'Relate': 64 | inputs = f['inputs'] 65 | r = inputs[0] 66 | if inputs[1] == 'backward': 67 | r = '[inverse] ' + r 68 | if not r in vocab['relation2id']: 69 | continue 70 | r = vocab['relation2id'][r] 71 | relations.append([idx + 1, r]) 72 | 
function_id = vocab['function2id'][function] 73 | data.append({'function': function_id}) 74 | if len(relations) == 0: 75 | relations.append([0, vocab['relation2id']['']]) 76 | dataset.append({'question': text, 'program': data, 'relations': relations}) 77 | # verbose = True 78 | # if verbose: 79 | # for idx in range(100): 80 | # print('*'*10) 81 | # text = dataset[idx]['question'] 82 | # print(text) 83 | # text = tokenizer.tokenize(text) 84 | # for f in dataset[idx]['program']: 85 | # function_id = f['function'] 86 | # print(vocab['id2function'][function_id]) 87 | # for pos, r in dataset[idx]['relations']: 88 | # print(pos, vocab['id2relation'][r]) 89 | 90 | 91 | with open(os.path.join(args.output_dir, 'relation', '%s.json'%(name)), 'w') as f: 92 | for item in dataset: 93 | f.write(json.dumps(item) + '\n') 94 | 95 | def get_concept_dataset(args, vocab): 96 | train = [json.loads(line.strip()) for line in open(os.path.join(args.input_dir, 'train.json'))] 97 | dev = [json.loads(line.strip()) for line in open(os.path.join(args.input_dir, 'val.json'))] 98 | for name, raw_data in zip(['train', 'dev'], [train, dev]): 99 | dataset = [] 100 | for item in tqdm(raw_data): 101 | text = item['question'] 102 | program = item['program'] 103 | data = [] 104 | concepts = [] 105 | for idx, f in enumerate(program): 106 | function = f['function'] 107 | if function == 'FilterConcept': 108 | inputs = f['inputs'] 109 | c = inputs[0] 110 | if not c in vocab['concept2id']: 111 | continue 112 | c = vocab['concept2id'][c] 113 | concepts.append([idx + 1, c]) 114 | function_id = vocab['function2id'][function] 115 | data.append({'function': function_id}) 116 | if len(concepts) == 0: 117 | concepts.append([0, vocab['concept2id']['']]) 118 | dataset.append({'question': text, 'program': data, 'concepts': concepts}) 119 | # verbose = True 120 | # if verbose: 121 | # for idx in range(100): 122 | # print('*'*10) 123 | # text = dataset[idx]['question'] 124 | # print(text) 125 | # text = tokenizer.tokenize(text) 126 | # for f in dataset[idx]['program']: 127 | # function_id = f['function'] 128 | # print(vocab['id2function'][function_id]) 129 | # for pos, r in dataset[idx]['concepts']: 130 | # print(pos, vocab['id2concept'][r]) 131 | 132 | 133 | with open(os.path.join(args.output_dir, 'concept', '%s.json'%(name)), 'w') as f: 134 | for item in dataset: 135 | f.write(json.dumps(item) + '\n') 136 | 137 | def encode_relation(args, vocab): 138 | encoded_inputs = tokenizer(vocab['id2relation'], padding = True) 139 | print(encoded_inputs.keys()) 140 | print(len(encoded_inputs['input_ids'][0])) 141 | print(len(encoded_inputs['token_type_ids'][0])) 142 | print(len(encoded_inputs['attention_mask'][0])) 143 | print(tokenizer.decode(encoded_inputs['input_ids'][0])) 144 | max_seq_length = len(encoded_inputs['input_ids'][0]) 145 | input_ids_list = encoded_inputs['input_ids'] 146 | token_type_ids_list = encoded_inputs['token_type_ids'] 147 | attention_mask_list = encoded_inputs['attention_mask'] 148 | input_ids_list = np.array(input_ids_list, dtype=np.int32) 149 | token_type_ids_list = np.array(token_type_ids_list, dtype=np.int32) 150 | attention_mask_list = np.array(attention_mask_list, dtype=np.int32) 151 | return input_ids_list, token_type_ids_list, attention_mask_list 152 | 153 | def encode_concept(args, vocab): 154 | encoded_inputs = tokenizer(vocab['id2concept'], padding = True) 155 | print(encoded_inputs.keys()) 156 | print(len(encoded_inputs['input_ids'][0])) 157 | print(len(encoded_inputs['token_type_ids'][0])) 158 | 
print(len(encoded_inputs['attention_mask'][0])) 159 | print(tokenizer.decode(encoded_inputs['input_ids'][0])) 160 | max_seq_length = len(encoded_inputs['input_ids'][0]) 161 | input_ids_list = encoded_inputs['input_ids'] 162 | token_type_ids_list = encoded_inputs['token_type_ids'] 163 | attention_mask_list = encoded_inputs['attention_mask'] 164 | input_ids_list = np.array(input_ids_list, dtype=np.int32) 165 | token_type_ids_list = np.array(token_type_ids_list, dtype=np.int32) 166 | attention_mask_list = np.array(attention_mask_list, dtype=np.int32) 167 | return input_ids_list, token_type_ids_list, attention_mask_list 168 | 169 | def encode_relation_dataset(args, vocab, dataset): 170 | def get_function_ids(program): 171 | function_ids = [f['function'] for f in program] 172 | return function_ids 173 | 174 | tmp = [] 175 | for item in dataset: 176 | question = item['question'] 177 | program = item['program'] 178 | relations = item['relations'] 179 | for relation in relations: 180 | tmp.append({'question': question, 'program': program, 'relation': relation}) 181 | print('dataset size: {}'.format(len(dataset))) 182 | dataset = tmp 183 | print('new dataset size: {}'.format(len(dataset))) 184 | questions = [] 185 | for item in dataset: 186 | question = item['question'] 187 | questions.append(question) 188 | encoded_inputs = tokenizer(questions, padding = True) 189 | # print(encoded_inputs.keys()) 190 | # print(len(encoded_inputs['input_ids'][0])) 191 | # print(len(encoded_inputs['token_type_ids'][0])) 192 | # print(len(encoded_inputs['attention_mask'][0])) 193 | # print(tokenizer.decode(encoded_inputs['input_ids'][0])) 194 | max_seq_length = len(encoded_inputs['input_ids'][0]) 195 | function_ids_list = [] 196 | for item in tqdm(dataset): 197 | program = item['program'] 198 | program = [{'function': vocab['function2id']['']}] + program + [{'function': vocab['function2id']['']}] 199 | function_ids = get_function_ids(program) 200 | function_ids_list.append(function_ids) 201 | max_func_len = max([len(function_ids) for function_ids in function_ids_list]) 202 | print('max_func_len: {}'.format(max_func_len)) 203 | for function_ids in function_ids_list: 204 | while len(function_ids) < max_func_len: 205 | function_ids.append(vocab['function2id']['']) 206 | assert len(function_ids) == max_func_len 207 | relation_pos_list = [] 208 | relation_id_list = [] 209 | for item in dataset: 210 | relation = item['relation'] 211 | relation_pos_list.append([relation[0]]) 212 | relation_id_list.append([relation[1]]) 213 | 214 | input_ids_list = encoded_inputs['input_ids'] 215 | token_type_ids_list = encoded_inputs['token_type_ids'] 216 | attention_mask_list = encoded_inputs['attention_mask'] 217 | # verbose = False 218 | # if verbose: 219 | # for idx in range(10): 220 | # question = tokenizer.decode(input_ids_list[idx]) 221 | # functions = [vocab['id2function'][id] for id in function_ids_list[idx]] 222 | # relation_pos = relation_pos_list[idx][0] 223 | # relation_id = vocab['id2relation'][relation_id_list[idx][0]] 224 | # print(question, functions, relation_pos, relation_id) 225 | 226 | input_ids_list = np.array(input_ids_list, dtype=np.int32) 227 | token_type_ids_list = np.array(token_type_ids_list, dtype=np.int32) 228 | attention_mask_list = np.array(attention_mask_list, dtype=np.int32) 229 | function_ids_list = np.array(function_ids_list, dtype=np.int32) 230 | relation_pos_list = np.array(relation_pos_list, dtype=np.int32) 231 | relation_id_list = np.array(relation_id_list, dtype=np.int32) 232 | 233 | return 
input_ids_list, token_type_ids_list, attention_mask_list, function_ids_list, relation_pos_list, relation_id_list 234 | 235 | 236 | def encode_concept_dataset(args, vocab, dataset): 237 | def get_function_ids(program): 238 | function_ids = [f['function'] for f in program] 239 | return function_ids 240 | 241 | tmp = [] 242 | for item in dataset: 243 | question = item['question'] 244 | program = item['program'] 245 | concepts = item['concepts'] 246 | for concept in concepts: 247 | tmp.append({'question': question, 'program': program, 'concept': concept}) 248 | print('dataset size: {}'.format(len(dataset))) 249 | dataset = tmp 250 | print('new dataset size: {}'.format(len(dataset))) 251 | questions = [] 252 | for item in dataset: 253 | question = item['question'] 254 | questions.append(question) 255 | encoded_inputs = tokenizer(questions, padding = True) 256 | # print(encoded_inputs.keys()) 257 | # print(len(encoded_inputs['input_ids'][0])) 258 | # print(len(encoded_inputs['token_type_ids'][0])) 259 | # print(len(encoded_inputs['attention_mask'][0])) 260 | # print(tokenizer.decode(encoded_inputs['input_ids'][0])) 261 | max_seq_length = len(encoded_inputs['input_ids'][0]) 262 | function_ids_list = [] 263 | for item in tqdm(dataset): 264 | program = item['program'] 265 | program = [{'function': vocab['function2id']['<START>']}] + program + [{'function': vocab['function2id']['<END>']}] 266 | function_ids = get_function_ids(program) 267 | function_ids_list.append(function_ids) 268 | max_func_len = max([len(function_ids) for function_ids in function_ids_list]) 269 | print('max_func_len: {}'.format(max_func_len)) 270 | for function_ids in function_ids_list: 271 | while len(function_ids) < max_func_len: 272 | function_ids.append(vocab['function2id']['<PAD>']) 273 | assert len(function_ids) == max_func_len 274 | relation_pos_list = [] 275 | relation_id_list = [] 276 | for item in dataset: 277 | relation = item['concept'] 278 | relation_pos_list.append([relation[0]]) 279 | relation_id_list.append([relation[1]]) 280 | 281 | input_ids_list = encoded_inputs['input_ids'] 282 | token_type_ids_list = encoded_inputs['token_type_ids'] 283 | attention_mask_list = encoded_inputs['attention_mask'] 284 | verbose = False 285 | if verbose: 286 | for idx in range(10): 287 | question = tokenizer.decode(input_ids_list[idx]) 288 | functions = [vocab['id2function'][id] for id in function_ids_list[idx]] 289 | relation_pos = relation_pos_list[idx][0] 290 | relation_id = vocab['id2concept'][relation_id_list[idx][0]] 291 | print(question, functions, relation_pos, relation_id) 292 | 293 | input_ids_list = np.array(input_ids_list, dtype=np.int32) 294 | token_type_ids_list = np.array(token_type_ids_list, dtype=np.int32) 295 | attention_mask_list = np.array(attention_mask_list, dtype=np.int32) 296 | function_ids_list = np.array(function_ids_list, dtype=np.int32) 297 | relation_pos_list = np.array(relation_pos_list, dtype=np.int32) 298 | relation_id_list = np.array(relation_id_list, dtype=np.int32) 299 | 300 | return input_ids_list, token_type_ids_list, attention_mask_list, function_ids_list, relation_pos_list, relation_id_list 301 | 302 | 303 | def main(): 304 | parser = argparse.ArgumentParser() 305 | parser.add_argument('--input_dir', required = True, type = str) 306 | parser.add_argument('--output_dir', required = True, type = str) 307 | args = parser.parse_args() 308 | print(args) 309 | if not os.path.isdir(args.output_dir): 310 | os.makedirs(args.output_dir) 311 | if not os.path.isdir(os.path.join(args.output_dir, 'relation')): 312 | os.makedirs(os.path.join(args.output_dir, 'relation')) 313 | if not os.path.isdir(os.path.join(args.output_dir, 'concept')): 314 | os.makedirs(os.path.join(args.output_dir, 'concept'))
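# NOTE: the vocab below is only seeded with the special tokens; get_vocab() (defined earlier in this file) is expected to fill relation2id, concept2id and function2id from the training programs and to build the id2relation/id2concept/id2function inverse maps used by the encode_* helpers, before everything is dumped to vocab.json.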
315 | vocab = { 316 | 'relation2id': { 317 | '<PAD>': 0 318 | }, 319 | 'concept2id': { 320 | '<PAD>': 0 321 | }, 322 | 'function2id':{ 323 | '<PAD>': 0, 324 | '<START>': 1, 325 | '<END>': 2 326 | } 327 | } 328 | get_vocab(args, vocab) 329 | 330 | for k in vocab: 331 | print('{}:{}'.format(k, len(vocab[k]))) 332 | fn = os.path.join(args.output_dir, 'vocab.json') 333 | print('Dump vocab to {}'.format(fn)) 334 | with open(fn, 'w') as f: 335 | json.dump(vocab, f, indent=2) 336 | 337 | outputs = encode_relation(args, vocab) 338 | with open(os.path.join(args.output_dir, 'relation', 'relation.pt'), 'wb') as f: 339 | for o in outputs: 340 | print(o.shape) 341 | pickle.dump(o, f) 342 | 343 | outputs = encode_concept(args, vocab) 344 | with open(os.path.join(args.output_dir, 'concept', 'concept.pt'), 'wb') as f: 345 | for o in outputs: 346 | print(o.shape) 347 | pickle.dump(o, f) 348 | 349 | get_relation_dataset(args, vocab) 350 | get_concept_dataset(args, vocab) 351 | # vocab = json.load(open(os.path.join(args.output_dir, 'vocab.json'))) 352 | for name in ['train', 'dev']: 353 | dataset = [] 354 | with open(os.path.join(args.output_dir, 'relation', '%s.json'%(name))) as f: 355 | for line in f: 356 | dataset.append(json.loads(line.strip())) 357 | outputs = encode_relation_dataset(args, vocab, dataset) 358 | assert len(outputs) == 6 359 | print('shape of input_ids, token_type_ids, attention_mask, function_ids, relation_pos, relation_id:') 360 | with open(os.path.join(args.output_dir, 'relation', '{}.pt'.format(name)), 'wb') as f: 361 | for o in outputs: 362 | print(o.shape) 363 | pickle.dump(o, f) 364 | 365 | for name in ['train', 'dev']: 366 | dataset = [] 367 | with open(os.path.join(args.output_dir, 'concept', '%s.json'%(name))) as f: 368 | for line in f: 369 | dataset.append(json.loads(line.strip())) 370 | outputs = encode_concept_dataset(args, vocab, dataset) 371 | assert len(outputs) == 6 372 | print('shape of input_ids, token_type_ids, attention_mask, function_ids, relation_pos, relation_id:') 373 | with open(os.path.join(args.output_dir, 'concept', '{}.pt'.format(name)), 'wb') as f: 374 | for o in outputs: 375 | print(o.shape) 376 | pickle.dump(o, f) 377 | 378 | 379 | 380 | if __name__ == "__main__": 381 | main() 382 | -------------------------------------------------------------------------------- /Pretraining/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.optim as optim 4 | import torch.nn as nn 5 | import argparse 6 | import shutil 7 | from tqdm import tqdm 8 | import numpy as np 9 | from Pretraining.utils import * 10 | from Pretraining.model import RelationPT 11 | from Pretraining.data import DataLoader 12 | from transformers import * 13 | from Pretraining.lr_scheduler import get_linear_schedule_with_warmup 14 | from Pretraining.metric import * 15 | import logging 16 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s') 17 | logFormatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s') 18 | rootLogger = logging.getLogger() 19 | import torch.optim as optim 20 | from IPython import embed 21 | 22 | 23 | 24 | 
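# evaluate() loops over the relation and concept dev loaders and logs two numbers for each: exact-match accuracy of the predicted function sequence (via FunctionAcc, initialized with the <END> function id) and accuracy of the predicted relation/concept against the gold ids in batch[5].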
25 | def evaluate(args, concept_inputs, relation_inputs, model, relation_eval_loader, concept_eval_loader, device, prefix = ''): 26 | eval_output_dir = args.output_dir 27 | if not os.path.exists(eval_output_dir): 28 | os.makedirs(eval_output_dir) 29 | # Eval! 30 | nb_eval_steps = 0 31 | func_metric = FunctionAcc(relation_eval_loader.vocab['function2id']['<END>']) 32 | pbar = ProgressBar(n_total=len(relation_eval_loader), desc="Evaluating") 33 | correct = 0 34 | tot = 0 35 | for step, batch in enumerate(relation_eval_loader): 36 | model.eval() 37 | batch = tuple(t.to(device) for t in batch) 38 | # print(batch[4].size()) 39 | with torch.no_grad(): 40 | batch = tuple(t.to(device) for t in batch) 41 | inputs = { 42 | 'concept_inputs': concept_inputs, 43 | 'relation_inputs': relation_inputs, 44 | 'input_ids': batch[0], 45 | 'token_type_ids': batch[1], 46 | 'attention_mask': batch[2], 47 | 'function_ids': batch[3], 48 | 'relation_info': (batch[4], None), 49 | 'concept_info': None 50 | } 51 | outputs = model(**inputs) 52 | pred_functions = outputs['pred_functions'].cpu().tolist() 53 | pred_relation = outputs['pred_relation'] 54 | gt_relation = batch[5] 55 | gt_relation = gt_relation.squeeze(-1) 56 | # print(pred_relation.size(), gt_relation.size(), batch[3].size()) 57 | correct += torch.sum(torch.eq(pred_relation, gt_relation).float()) 58 | # print(correct) 59 | tot += len(pred_relation) 60 | gt_functions = batch[3].cpu().tolist() 61 | for pred, gt in zip(pred_functions, gt_functions): 62 | func_metric.update(pred, gt) 63 | nb_eval_steps += 1 64 | pbar(step) 65 | logging.info('') 66 | acc = func_metric.result() 67 | logging.info('**** function results %s ****', prefix) 68 | info = 'acc: {}'.format(acc) 69 | logging.info(info) 70 | acc = correct.item() / tot 71 | logging.info('**** relation results %s ****', prefix) 72 | logging.info('acc: {}'.format(acc)) 73 | 74 | 75 | 76 | nb_eval_steps = 0 77 | func_metric = FunctionAcc(concept_eval_loader.vocab['function2id']['<END>']) 78 | pbar = ProgressBar(n_total=len(concept_eval_loader), desc="Evaluating") 79 | correct = 0 80 | tot = 0 81 | for step, batch in enumerate(concept_eval_loader): 82 | model.eval() 83 | batch = tuple(t.to(device) for t in batch) 84 | # print(batch[4].size()) 85 | with torch.no_grad(): 86 | batch = tuple(t.to(device) for t in batch) 87 | inputs = { 88 | 'concept_inputs': concept_inputs, 89 | 'relation_inputs': relation_inputs, 90 | 'input_ids': batch[0], 91 | 'token_type_ids': batch[1], 92 | 'attention_mask': batch[2], 93 | 'function_ids': batch[3], 94 | 'relation_info': None, 95 | 'concept_info': (batch[4], None) 96 | } 97 | outputs = model(**inputs) 98 | pred_functions = outputs['pred_functions'].cpu().tolist() 99 | pred_relation = outputs['pred_concept'] 100 | gt_relation = batch[5] 101 | gt_relation = gt_relation.squeeze(-1) 102 | # print(pred_relation.size(), gt_relation.size(), batch[3].size()) 103 | correct += torch.sum(torch.eq(pred_relation, gt_relation).float()) 104 | # print(correct) 105 | tot += len(pred_relation) 106 | gt_functions = batch[3].cpu().tolist() 107 | for pred, gt in zip(pred_functions, gt_functions): 108 | func_metric.update(pred, gt) 109 | nb_eval_steps += 1 110 | pbar(step) 111 | logging.info('') 112 | acc = func_metric.result() 113 | logging.info('**** function results %s ****', prefix) 114 | info = 'acc: {}'.format(acc) 115 | logging.info(info) 116 | acc = correct.item() / tot 117 | logging.info('**** concept results %s ****', prefix) 118 | logging.info('acc: {}'.format(acc)) 119 | 120 | 121 | def train(args): 122 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 123 | 124 | logging.info("Create train_loader and val_loader.........") 125 | vocab_json = 
os.path.join(args.input_dir, 'vocab.json') 126 | relation_train_pt = os.path.join(args.input_dir, 'relation', 'train.pt') 127 | relation_val_pt = os.path.join(args.input_dir, 'relation', 'dev.pt') 128 | relation_train_loader = DataLoader(vocab_json, relation_train_pt, args.batch_size, training=True) 129 | relation_val_loader = DataLoader(vocab_json, relation_val_pt, args.batch_size) 130 | 131 | concept_train_pt = os.path.join(args.input_dir, 'concept', 'train.pt') 132 | concept_val_pt = os.path.join(args.input_dir, 'concept', 'dev.pt') 133 | concept_train_loader = DataLoader(vocab_json, concept_train_pt, args.batch_size, training=True) 134 | concept_val_loader = DataLoader(vocab_json, concept_val_pt, args.batch_size) 135 | 136 | with open(os.path.join(args.input_dir, 'relation', 'relation.pt'), 'rb') as f: 137 | input_ids = pickle.load(f) 138 | token_type_ids = pickle.load(f) 139 | attention_mask = pickle.load(f) 140 | input_ids = torch.LongTensor(input_ids).to(device) 141 | token_type_ids = torch.LongTensor(token_type_ids).to(device) 142 | attention_mask = torch.LongTensor(attention_mask).to(device) 143 | relation_inputs = { 144 | 'input_ids': input_ids, 145 | 'token_type_ids': token_type_ids, 146 | 'attention_mask': attention_mask 147 | } 148 | 149 | 150 | with open(os.path.join(args.input_dir, 'concept', 'concept.pt'), 'rb') as f: 151 | input_ids = pickle.load(f) 152 | token_type_ids = pickle.load(f) 153 | attention_mask = pickle.load(f) 154 | input_ids = torch.LongTensor(input_ids).to(device) 155 | token_type_ids = torch.LongTensor(token_type_ids).to(device) 156 | attention_mask = torch.LongTensor(attention_mask).to(device) 157 | concept_inputs = { 158 | 'input_ids': input_ids, 159 | 'token_type_ids': token_type_ids, 160 | 'attention_mask': attention_mask 161 | } 162 | 163 | vocab = relation_train_loader.vocab 164 | 165 | logging.info("Create model.........") 166 | config_class, model_class, tokenizer_class = (BertConfig, RelationPT, BertTokenizer) 167 | config = config_class.from_pretrained(args.model_name_or_path, num_labels = len(label_list)) 168 | config.update({'vocab': vocab}) 169 | tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, do_lower_case = False) 170 | model = model_class.from_pretrained(args.model_name_or_path, config = config) 171 | model = model.to(device) 172 | # logging.info(model) 173 | 174 | 175 | t_total = (len(relation_train_loader) + len(concept_train_loader)) // args.gradient_accumulation_steps * args.num_train_epochs # Prepare optimizer and schedule (linear warmup and decay) 176 | no_decay = ["bias", "LayerNorm.weight"] 177 | bert_param_optimizer = list(model.bert.named_parameters()) 178 | linear_param_optimizer = list(model.function_embeddings.named_parameters()) + list(model.function_classifier.named_parameters()) + list(model.function_decoder.named_parameters()) + list(model.relation_classifier.named_parameters()) + list(model.concept_classifier.named_parameters()) 179 | optimizer_grouped_parameters = [ 180 | {'params': [p for n, p in bert_param_optimizer if not any(nd in n for nd in no_decay)], 181 | 'weight_decay': args.weight_decay, 'lr': args.learning_rate}, 182 | {'params': [p for n, p in bert_param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 183 | 'lr': args.learning_rate}, 184 | {'params': [p for n, p in linear_param_optimizer if not any(nd in n for nd in no_decay)], 185 | 'weight_decay': args.weight_decay, 'lr': args.crf_learning_rate}, 186 | {'params': [p for n, p in linear_param_optimizer if any(nd in n 
for nd in no_decay)], 'weight_decay': 0.0, 187 | 'lr': args.crf_learning_rate} 188 | ] 189 | args.warmup_steps = int(t_total * args.warmup_proportion) 190 | optimizer = optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 191 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, 192 | num_training_steps=t_total) 193 | # Check if saved optimizer or scheduler states exist 194 | if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile( 195 | os.path.join(args.model_name_or_path, "scheduler.pt")): 196 | # Load in optimizer and scheduler states 197 | optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt"))) 198 | scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt"))) 199 | 200 | # Train! 201 | logging.info("***** Running training *****") 202 | logging.info(" Num examples = %d", len(relation_train_loader.dataset)) 203 | logging.info(" Num Epochs = %d", args.num_train_epochs) 204 | logging.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 205 | logging.info(" Total optimization steps = %d", t_total) 206 | 207 | global_step = 0 208 | steps_trained_in_current_epoch = 0 209 | # Check if continuing training from a checkpoint 210 | if os.path.exists(args.model_name_or_path) and "checkpoint" in args.model_name_or_path: 211 | # set global_step to gobal_step of last saved checkpoint from model path 212 | global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0]) 213 | epochs_trained = global_step // (len(relation_train_loader) // args.gradient_accumulation_steps) 214 | steps_trained_in_current_epoch = global_step % (len(relation_train_loader) // args.gradient_accumulation_steps) 215 | logging.info(" Continuing training from checkpoint, will skip to saved global_step") 216 | logging.info(" Continuing training from epoch %d", epochs_trained) 217 | logging.info(" Continuing training from global step %d", global_step) 218 | logging.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) 219 | logging.info('Checking...') 220 | logging.info("===================Dev==================") 221 | # evaluate(args, concept_inputs, relation_inputs, model, relation_val_loader, concept_val_loader, device) 222 | tr_loss, logging_loss = 0.0, 0.0 223 | model.zero_grad() 224 | for _ in range(int(args.num_train_epochs)): 225 | logging.info('relation training begins') 226 | pbar = ProgressBar(n_total=len(relation_train_loader), desc='Training') 227 | for step, batch in enumerate(relation_train_loader): 228 | # Skip past any already trained steps if resuming training 229 | if steps_trained_in_current_epoch > 0: 230 | steps_trained_in_current_epoch -= 1 231 | continue 232 | model.train() 233 | batch = tuple(t.to(device) for t in batch) 234 | inputs = { 235 | 'concept_inputs': concept_inputs, 236 | 'relation_inputs': relation_inputs, 237 | 'input_ids': batch[0], 238 | 'token_type_ids': batch[1], 239 | 'attention_mask': batch[2], 240 | 'function_ids': batch[3], 241 | 'relation_info': (batch[4], batch[5]), 242 | 'concept_info': None 243 | } 244 | outputs = model(**inputs) 245 | loss = args.func * outputs['function_loss'] + args.rel * outputs['relation_loss'] 246 | loss.backward() 247 | pbar(step, {'loss': loss.item()}) 248 | tr_loss += loss.item() 249 | if (step + 1) % args.gradient_accumulation_steps == 0: 250 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 
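# One optimizer/scheduler step is taken only every gradient_accumulation_steps mini-batches: the gradients accumulated by loss.backward() are clipped to max_grad_norm above, then applied and zeroed below.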
251 | optimizer.step() 252 | scheduler.step() # Update learning rate schedule 253 | model.zero_grad() 254 | global_step += 1 255 | 256 | 257 | logging.info('concept training begins') 258 | pbar = ProgressBar(n_total=len(concept_train_loader), desc='Training') 259 | for step, batch in enumerate(concept_train_loader): 260 | # Skip past any already trained steps if resuming training 261 | if steps_trained_in_current_epoch > 0: 262 | steps_trained_in_current_epoch -= 1 263 | continue 264 | model.train() 265 | batch = tuple(t.to(device) for t in batch) 266 | inputs = { 267 | 'concept_inputs': concept_inputs, 268 | 'relation_inputs': relation_inputs, 269 | 'input_ids': batch[0], 270 | 'token_type_ids': batch[1], 271 | 'attention_mask': batch[2], 272 | 'function_ids': batch[3], 273 | 'relation_info': None, 274 | 'concept_info': (batch[4], batch[5]) 275 | } 276 | outputs = model(**inputs) 277 | loss = args.func * outputs['function_loss'] + args.con * outputs['concept_loss'] 278 | loss.backward() 279 | pbar(step, {'loss': loss.item()}) 280 | tr_loss += loss.item() 281 | if (step + 1) % args.gradient_accumulation_steps == 0: 282 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 283 | optimizer.step() 284 | scheduler.step() # Update learning rate schedule 285 | model.zero_grad() 286 | global_step += 1 287 | 288 | # break 289 | # Save model checkpoint 290 | output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step)) 291 | if not os.path.exists(output_dir): 292 | os.makedirs(output_dir) 293 | model_to_save = ( 294 | model.module if hasattr(model, "module") else model 295 | ) # Take care of distributed/parallel training 296 | model_to_save.save_pretrained(output_dir) 297 | torch.save(args, os.path.join(output_dir, "training_args.bin")) 298 | logging.info("Saving model checkpoint to %s", output_dir) 299 | tokenizer.save_vocabulary(output_dir) 300 | torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) 301 | torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) 302 | logging.info("Saving optimizer and scheduler states to %s", output_dir) 303 | logging.info("\n") 304 | if 'cuda' in str(device): 305 | torch.cuda.empty_cache() 306 | evaluate(args, concept_inputs, relation_inputs, model, relation_val_loader, concept_val_loader, device) 307 | 308 | return global_step, tr_loss / global_step 309 | 310 | 311 | def main(): 312 | parser = argparse.ArgumentParser() 313 | # input and output 314 | parser.add_argument('--input_dir', required=True) 315 | parser.add_argument('--output_dir', required=True) 316 | 317 | parser.add_argument('--save_dir', required=True, help='path to save checkpoints and logs') 318 | # parser.add_argument('--glove_pt', default='/data/csl/resources/word2vec/glove.840B.300d.py36.pt') 319 | parser.add_argument('--model_name_or_path', default = '/data/csl/resources/Bert/bert-base-cased') 320 | # parser.add_argument('--ckpt') 321 | 322 | # training parameters 323 | parser.add_argument('--weight_decay', default=1e-5, type=float) 324 | parser.add_argument('--batch_size', default=128, type=int) 325 | parser.add_argument('--seed', type=int, default=666, help='random seed') 326 | parser.add_argument('--learning_rate', default=3e-5, type = float) 327 | parser.add_argument('--crf_learning_rate', default=1e-3, type = float) 328 | parser.add_argument('--num_train_epochs', default=25, type = int) 329 | parser.add_argument('--save_steps', default=448, type = int) 330 | parser.add_argument('--logging_steps', default=448, 
type = int) 331 | parser.add_argument('--warmup_proportion', default=0.1, type = float, 332 | help="Proportion of training to perform linear learning rate warmup for,E.g., 0.1 = 10% of training.") 333 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 334 | help="Epsilon for Adam optimizer.") 335 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1, 336 | help="Number of updates steps to accumulate before performing a backward/update pass.", ) 337 | parser.add_argument("--max_grad_norm", default=1.0, type=float, 338 | help="Max gradient norm.") 339 | # model hyperparameters 340 | parser.add_argument('--dim_word', default=300, type=int) 341 | parser.add_argument('--dim_hidden', default=1024, type=int) 342 | parser.add_argument('--alpha', default = 1, type = float) 343 | parser.add_argument('--beta', default = 1e-1, type = float) 344 | parser.add_argument('--func', default = 1, type = float) 345 | parser.add_argument('--rel', default = 1, type = float) 346 | parser.add_argument('--con', default = 1, type = float) 347 | 348 | 349 | args = parser.parse_args() 350 | 351 | if not os.path.exists(args.save_dir): 352 | os.makedirs(args.save_dir) 353 | time_ = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime()) 354 | fileHandler = logging.FileHandler(os.path.join(args.save_dir, '{}.log'.format(time_))) 355 | fileHandler.setFormatter(logFormatter) 356 | rootLogger.addHandler(fileHandler) 357 | # args display 358 | for k, v in vars(args).items(): 359 | logging.info(k+':'+str(v)) 360 | 361 | seed_everything(666) 362 | 363 | train(args) 364 | 365 | 366 | if __name__ == '__main__': 367 | main() 368 | -------------------------------------------------------------------------------- /Pretraining/utils.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict, Counter, deque 2 | import torch 3 | import json 4 | import pickle 5 | import numpy as np 6 | import torch.nn as nn 7 | import random 8 | import os 9 | import time 10 | START_RELATION = 'START_RELATION' 11 | NO_OP_RELATION = 'NO_OP_RELATION' 12 | NO_OP_ENTITY = 'NO_OP_ENTITY' 13 | DUMMY_RELATION = 'DUMMY_RELATION' 14 | DUMMY_ENTITY = 'DUMMY_ENTITY' 15 | 16 | DUMMY_RELATION_ID = 0 17 | START_RELATION_ID = 1 18 | NO_OP_RELATION_ID = 2 19 | DUMMY_ENTITY_ID = 0 20 | NO_OP_ENTITY_ID = 1 21 | 22 | EPSILON = float(np.finfo(float).eps) 23 | HUGE_INT = 1e31 24 | 25 | label_list = ['X', 'B-E', 'I-E', 'S-E', 'O'] 26 | label2id = {label: i for i, label in enumerate(label_list)} 27 | id2label = {i: label for i, label in enumerate(label_list)} 28 | 29 | 30 | def safe_log(x): 31 | return torch.log(x + EPSILON) 32 | 33 | def entropy(p): 34 | return torch.sum(- p * safe_log(p), 1) 35 | 36 | def idx_to_one_hot(idx, size, device): 37 | """ 38 | Args: 39 | idx [bsz, 1] 40 | Return: 41 | one_hot [bsz, size] 42 | """ 43 | one_hot = torch.FloatTensor(len(idx), size).to(device) 44 | one_hot.zero_() 45 | one_hot.scatter_(1, idx, 1) 46 | return one_hot 47 | 48 | def format_path(path_trace, id2entity, id2relation): 49 | def get_most_recent_relation(j): 50 | relation_id = int(path_trace[j][0]) 51 | if relation_id == NO_OP_RELATION_ID: 52 | return '' 53 | else: 54 | return id2relation[relation_id] 55 | 56 | def get_most_recent_entity(j): 57 | return id2entity[int(path_trace[j][1])] 58 | 59 | path_str = get_most_recent_entity(0) 60 | for j in range(1, len(path_trace)): 61 | rel = get_most_recent_relation(j) 62 | if not rel.endswith('_inv'): 63 | path_str += ' -{}-> '.format(rel) 
64 | else: 65 | path_str += ' <-{}- '.format(rel[:-4]) 66 | path_str += get_most_recent_entity(j) 67 | return path_str 68 | 69 | def pad_and_cat(a, padding_value, padding_dim=1): 70 | max_dim_size = max([x.size()[padding_dim] for x in a]) 71 | padded_a = [] 72 | for x in a: 73 | if x.size()[padding_dim] < max_dim_size: 74 | res_len = max_dim_size - x.size()[1] 75 | pad = nn.ConstantPad1d((0, res_len), padding_value) 76 | padded_a.append(pad(x)) 77 | else: 78 | padded_a.append(x) 79 | return torch.cat(padded_a, dim=0) 80 | 81 | def safe_log(x): 82 | return torch.log(x + EPSILON) 83 | 84 | def entropy(p): 85 | return torch.sum(- p * safe_log(p), 1) 86 | 87 | def add_item_to_x2id(item, x2id): 88 | if not item in x2id: 89 | x2id[item] = len(x2id) 90 | 91 | def tile_along_beam(v, beam_size, dim=0): 92 | """ 93 | Tile a tensor along a specified dimension for the specified beam size. 94 | :param v: Input tensor. 95 | :param beam_size: Beam size. 96 | """ 97 | if dim == -1: 98 | dim = len(v.size()) - 1 99 | v = v.unsqueeze(dim + 1) 100 | v = torch.cat([v] * beam_size, dim=dim+1) 101 | new_size = [] 102 | for i, d in enumerate(v.size()): 103 | if i == dim + 1: 104 | new_size[-1] *= d 105 | else: 106 | new_size.append(d) 107 | return v.view(new_size) 108 | 109 | def init_vocab(): 110 | return { 111 | '<PAD>': 0, 112 | '<UNK>': 1, 113 | '<START>': 2, 114 | '<END>': 3 115 | } 116 | 117 | def invert_dict(d): 118 | return {v: k for k, v in d.items()} 119 | 120 | def load_glove(glove_pt, idx_to_token): 121 | glove = pickle.load(open(glove_pt, 'rb')) 122 | dim = len(glove['the']) 123 | matrix = [] 124 | for i in range(len(idx_to_token)): 125 | token = idx_to_token[i] 126 | tokens = token.split() 127 | if len(tokens) > 1: 128 | v = np.zeros((dim,)) 129 | for token in tokens: 130 | v = v + glove.get(token, glove['the']) 131 | v = v / len(tokens) 132 | else: 133 | v = glove.get(token, glove['the']) 134 | matrix.append(v) 135 | matrix = np.asarray(matrix) 136 | return matrix 137 | 138 | 139 | class SmoothedValue(object): 140 | """Track a series of values and provide access to smoothed values over a 141 | window or the global series average. 
142 | """ 143 | 144 | def __init__(self, window_size=20): 145 | self.deque = deque(maxlen=window_size) 146 | self.series = [] 147 | self.total = 0.0 148 | self.count = 0 149 | 150 | def update(self, value): 151 | self.deque.append(value) 152 | self.series.append(value) 153 | self.count += 1 154 | self.total += value 155 | 156 | @property 157 | def median(self): 158 | d = torch.tensor(list(self.deque)) 159 | return d.median().item() 160 | 161 | @property 162 | def avg(self): 163 | d = torch.tensor(list(self.deque)) 164 | return d.mean().item() 165 | 166 | @property 167 | def global_avg(self): 168 | return self.total / self.count 169 | 170 | 171 | class MetricLogger(object): 172 | def __init__(self, delimiter="\t"): 173 | self.meters = defaultdict(SmoothedValue) 174 | self.delimiter = delimiter 175 | 176 | def update(self, **kwargs): 177 | for k, v in kwargs.items(): 178 | if isinstance(v, torch.Tensor): 179 | v = v.item() 180 | assert isinstance(v, (float, int)) 181 | self.meters[k].update(v) 182 | 183 | def __getattr__(self, attr): 184 | if attr in self.meters: 185 | return self.meters[attr] 186 | if attr in self.__dict__: 187 | return self.__dict__[attr] 188 | raise AttributeError("'{}' object has no attribute '{}'".format( 189 | type(self).__name__, attr)) 190 | 191 | def __str__(self): 192 | loss_str = [] 193 | for name, meter in self.meters.items(): 194 | loss_str.append( 195 | "{}: {:.4f} ({:.4f})".format(name, meter.median, meter.global_avg) 196 | ) 197 | return self.delimiter.join(loss_str) 198 | 199 | def seed_everything(seed=1029): 200 | ''' 201 | 设置整个开发环境的seed 202 | :param seed: 203 | :param device: 204 | :return: 205 | ''' 206 | random.seed(seed) 207 | os.environ['PYTHONHASHSEED'] = str(seed) 208 | np.random.seed(seed) 209 | torch.manual_seed(seed) 210 | torch.cuda.manual_seed(seed) 211 | torch.cuda.manual_seed_all(seed) 212 | # some cudnn methods can be random even after fixing the seed 213 | # unless you tell it to be deterministic 214 | torch.backends.cudnn.deterministic = True 215 | 216 | 217 | class ProgressBar(object): 218 | ''' 219 | custom progress bar 220 | Example: 221 | >>> pbar = ProgressBar(n_total=30,desc='training') 222 | >>> step = 2 223 | >>> pbar(step=step) 224 | ''' 225 | def __init__(self, n_total,width=30,desc = 'Training'): 226 | self.width = width 227 | self.n_total = n_total 228 | self.start_time = time.time() 229 | self.desc = desc 230 | 231 | def __call__(self, step, info={}): 232 | now = time.time() 233 | current = step + 1 234 | recv_per = current / self.n_total 235 | bar = f'[{self.desc}] {current}/{self.n_total} [' 236 | if recv_per >= 1: 237 | recv_per = 1 238 | prog_width = int(self.width * recv_per) 239 | if prog_width > 0: 240 | bar += '=' * (prog_width - 1) 241 | if current< self.n_total: 242 | bar += ">" 243 | else: 244 | bar += '=' 245 | bar += '.' 
* (self.width - prog_width) 246 | bar += ']' 247 | show_bar = f"\r{bar}" 248 | time_per_unit = (now - self.start_time) / current 249 | if current < self.n_total: 250 | eta = time_per_unit * (self.n_total - current) 251 | if eta > 3600: 252 | eta_format = ('%d:%02d:%02d' % 253 | (eta // 3600, (eta % 3600) // 60, eta % 60)) 254 | elif eta > 60: 255 | eta_format = '%d:%02d' % (eta // 60, eta % 60) 256 | else: 257 | eta_format = '%ds' % eta 258 | time_info = f' - ETA: {eta_format}' 259 | else: 260 | if time_per_unit >= 1: 261 | time_info = f' {time_per_unit:.1f}s/step' 262 | elif time_per_unit >= 1e-3: 263 | time_info = f' {time_per_unit * 1e3:.1f}ms/step' 264 | else: 265 | time_info = f' {time_per_unit * 1e6:.1f}us/step' 266 | 267 | show_bar += time_info 268 | if len(info) != 0: 269 | show_info = f'{show_bar} ' + \ 270 | "-".join([f' {key}: {value:.4f} ' for key, value in info.items()]) 271 | print(show_info, end='') 272 | else: 273 | print(show_bar, end='') 274 | 275 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ProgramTransfer 2 | Official code and data of the ACL 2022 paper "Program Transfer for Complex Question Answering over Knowledge Bases" 3 | --------------------------------------------------------------------------------