├── README.md ├── SeqMix.png ├── active_learn.py ├── data │   ├── test.txt │   ├── train.txt │   └── valid.txt ├── data_load.py ├── model.py ├── requirements.txt └── seqmix.py /README.md: -------------------------------------------------------------------------------- 1 | # SeqMix 2 | This repository contains the code for our EMNLP'20 paper 3 | **SeqMix: Augmenting Active Sequence Labeling via Sequence Mixup** 4 | [[paper]](https://rongzhizhang.org/pdf/emnlp20_SeqMix.pdf) [[slides]](https://rongzhizhang.org/slides/EMNLP20_SeqMix_Slides.pdf) 5 | 6 | ![Illustration of the three variants of SeqMix](SeqMix.png) 7 | 8 | # Requirements 9 | - pytorch-transformers==1.2.0 10 | - torch==1.2.0 11 | - seqeval==0.0.5 12 | - tqdm==4.31.1 13 | - nltk==3.4.5 14 | - Flask==1.1.1 15 | - Flask-Cors==3.0.8 16 | - pytorch_pretrained_bert==0.6.2 17 | 18 | Install the required packages: 19 | ``` 20 | pip install -r requirements.txt 21 | ``` 22 | 23 | # Key Parameters 24 | - `data_dir`: directory of the data files; the CoNLL-03 dataset is provided under `data/` 25 | - `max_seq_length`: maximum length of each sequence 26 | - `num_train_epochs`: number of training epochs 27 | - `train_batch_size`: batch size during model training 28 | - `active_policy`: active learning query policy (`random`, `lc`, or `nte`) 29 | - `augment_method`: augmentation method (`soft`, `slack`, or `lf`) 30 | - `augment_rate`: augmentation rate, i.e. the ratio of generated samples to queried samples 31 | - `hyper_alpha`: shape parameter of the Beta distribution used to sample the mixing coefficient 32 | 33 | # Run 34 | ## Active learning part 35 | Random Sampling 36 | ``` 37 | python active_learn.py --active_policy=random 38 | ``` 39 | Least Confidence Sampling 40 | ``` 41 | python active_learn.py --active_policy=lc 42 | ``` 43 | Normalized Token Entropy Sampling 44 | ``` 45 | python active_learn.py --active_policy=nte 46 | ``` 47 | 48 | ## SeqMix part 49 | Whole-sequence mixup 50 | ``` 51 | python seqmix.py --augment_method=soft 52 | ``` 53 | Sub-sequence mixup 54 | ``` 55 | python seqmix.py --augment_method=slack 56 | ``` 57 | Label-constrained sub-sequence mixup 58 | ``` 59 | python seqmix.py --augment_method=lf 60 | ``` 61 | 62 | -------------------------------------------------------------------------------- /SeqMix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rz-zhang/SeqMix/149e02e079bbe798595eb9fabb6a1c03e5cd4f12/SeqMix.png -------------------------------------------------------------------------------- /active_learn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import json 4 | import logging 5 | import os 6 | import random 7 | import sys 8 | import copy 9 | import math 10 | import time 11 | import numpy as np 12 | import torch 13 | import torch.nn.functional as F 14 | 15 | from prettytable import PrettyTable 16 | from torch.autograd import Variable 17 | from pytorch_transformers import (WEIGHTS_NAME, AdamW, BertConfig, 18 | BertForTokenClassification, BertTokenizer, 19 | WarmupLinearSchedule) 20 | from torch import nn 21 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, 22 | TensorDataset) 23 | from torch.utils.data.distributed import DistributedSampler 24 | from tqdm import tqdm, trange 25 | from seqeval.metrics import classification_report 26 | 27 | from model import Ner 28 | from data_load import readfile, NerProcessor, convert_examples_to_features 29 | 30 | 31 | def get_tr_set(size=None, train_examples=None, batch_size=32, soft_labels=[], args=None): 32 | train_features = convert_examples_to_features(train_examples, label_list, 
args.max_seq_length, tokenizer, logger) 33 | if size: # return part of features 34 | select_idx = np.random.choice(range(len(train_features)), size=size, replace=False) 35 | train_features = list(np.array(train_features)[select_idx]) 36 | 37 | logger.info(" Num examples = %d", len(train_examples)) 38 | all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long) 39 | all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long) 40 | all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long) 41 | all_valid_ids = torch.tensor([f.valid_ids for f in train_features], dtype=torch.long) 42 | all_lmask_ids = torch.tensor([f.label_mask for f in train_features], dtype=torch.long) 43 | if len(soft_labels): 44 | all_label_ids = torch.tensor([soft_label for soft_label in soft_labels], dtype=torch.float64) 45 | else: 46 | all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long) 47 | train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_valid_ids, all_lmask_ids) 48 | if args.local_rank == -1: 49 | train_sampler = RandomSampler(train_data) 50 | else: 51 | train_sampler = DistributedSampler(train_data) 52 | 53 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size) 54 | if size: 55 | return train_dataloader, select_idx 56 | return train_dataloader 57 | 58 | def get_eval_set(eval_on, eval_batch_size=8): 59 | if eval_on == "dev": 60 | eval_examples = processor.get_dev_examples(args.data_dir) 61 | elif eval_on == "test": 62 | eval_examples = processor.get_test_examples(args.data_dir) 63 | else: 64 | raise ValueError("eval on dev or test set only") 65 | eval_features = convert_examples_to_features(eval_examples, label_list, args.max_seq_length, tokenizer, logger) 66 | logger.info("***** Running evaluation *****") 67 | logger.info(" Num examples = %d", len(eval_examples)) 68 | logger.info(" Batch size = %d", args.eval_batch_size) 69 | all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) 70 | all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) 71 | all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) 72 | all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long) 73 | all_valid_ids = torch.tensor([f.valid_ids for f in eval_features], dtype=torch.long) 74 | all_lmask_ids = torch.tensor([f.label_mask for f in eval_features], dtype=torch.long) 75 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids,all_valid_ids,all_lmask_ids) 76 | # Run prediction for full data 77 | eval_sampler = SequentialSampler(eval_data) 78 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=eval_batch_size) 79 | return eval_dataloader 80 | 81 | '''Evaluation''' 82 | def evaluate(prefix=None, model=None, args=None): 83 | eval_dataloader = get_eval_set(eval_on=args.eval_on, eval_batch_size=args.eval_batch_size) 84 | model.to(device) 85 | model.eval() 86 | eval_loss, eval_accuracy = 0, 0 87 | nb_eval_steps, nb_eval_examples = 0, 0 88 | y_true = [] 89 | y_pred = [] 90 | raw_logits = [] 91 | label_map = {i : label for i, label in enumerate(label_list,1)} 92 | for input_ids, input_mask, segment_ids, label_ids,valid_ids,l_mask in tqdm(eval_dataloader, desc="Evaluating"): 93 | input_ids = input_ids.to(device) 94 | input_mask = input_mask.to(device) 95 | segment_ids = 
segment_ids.to(device) 96 | valid_ids = valid_ids.to(device) 97 | label_ids = label_ids.to(device) 98 | l_mask = l_mask.to(device) 99 | 100 | with torch.no_grad(): 101 | logits = model(input_ids, segment_ids, input_mask,valid_ids=valid_ids,attention_mask_label=l_mask) 102 | 103 | #raw_logits.append(logits) 104 | logits = torch.argmax(F.log_softmax(logits,dim=2),dim=2) 105 | logits = logits.detach().cpu().numpy() 106 | label_ids = label_ids.to('cpu').numpy() 107 | input_mask = input_mask.to('cpu').numpy() 108 | 109 | for i, label in enumerate(label_ids): 110 | temp_1 = [] 111 | temp_2 = [] 112 | for j,m in enumerate(label): 113 | if j == 0: 114 | continue 115 | elif label_ids[i][j] == len(label_map): 116 | y_true.append(temp_1) 117 | y_pred.append(temp_2) 118 | break 119 | else: 120 | temp_1.append(label_map[label_ids[i][j]]) 121 | try: 122 | temp_2.append(label_map[logits[i][j]]) 123 | except: 124 | temp_2.append('UKN') 125 | 126 | report = classification_report(y_true, y_pred, digits=4) 127 | logger.info("\n%s", report) 128 | return report 129 | 130 | def save_result(prefix='Active', func_paras=None, report=None, table=None, output_dir=None): 131 | result_path = os.path.join(output_dir, prefix+'.txt') 132 | with open(result_path,'a') as f: 133 | if func_paras: 134 | for para in func_paras: 135 | if(type(func_paras[para]))==np.ndarray: 136 | func_paras[para] = func_paras[para].shape 137 | if(type(func_paras[para]))==list: 138 | func_paras[para] = np.array(func_paras[para]).shape 139 | f.write('\nParameters:\n') 140 | for item in func_paras.items(): 141 | f.write(str(item)+'\n') 142 | if report: 143 | f.write(report) 144 | if table: 145 | table = table.get_string() 146 | f.write(table) 147 | 148 | def multi_argmax(values: np.ndarray, n_instances: int = 1) -> np.ndarray: 149 | """ 150 | Selects the indices of the n_instances highest values. 151 | 152 | Input: 153 | values: Contains the values to be selected from. 154 | n_instances: Specifies how many indices to return. 155 | Output: 156 | The indices of the n_instances largest values. 157 | """ 158 | assert n_instances <= values.shape[0], 'n_instances must be less or equal than the size of utility' 159 | 160 | max_idx = np.argpartition(-values, n_instances-1, axis=0)[:n_instances] 161 | return max_idx 162 | 163 | def uncertainty_sampling(model_instance, pool, size): 164 | ''' 165 | Uncertainty sampling policy. 166 | 167 | Input: 168 | model_instance: the model to do the uncertainty measure by give the labels prediction over unobserved data. 169 | pool: the unobserved data. 170 | size: the number of instances to be sampled in each round. 171 | Output: 172 | query_index: the n_instances index of sampled data. 173 | pool[query_index]: the corresponding data. 
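Note: the uncertainty score of a sequence is computed below as the sum of (1 - max predicted tag probability) over its valid tokens, and the `size` sequences with the largest scores are queried.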
174 | ''' 175 | active_eval_loader = get_tr_set(train_examples=pool, batch_size=1, args=args) 176 | raw_prediction, turncate_list = active_eval(active_eval_loader, model_instance) # predict, get the softmax output 177 | word_prob = np.max(raw_prediction,axis=2) # select the max probability prediction as the word tag 178 | sentence_uncertainty = [] 179 | for i, sentence in enumerate(word_prob): 180 | sentence_uncertainty.append(np.sum(1-sentence[:turncate_list[i]])) 181 | query_index = multi_argmax(np.array(sentence_uncertainty), size) 182 | return query_index, pool[query_index] 183 | 184 | def nte_sampling(model_instance, pool, size): 185 | active_eval_loader = get_tr_set(train_examples=pool, batch_size=1, args=args) 186 | raw_prediction, turncate_list = active_eval(active_eval_loader, model_instance) # predict, get the softmax output 187 | sentence_nte = cal_nte(raw_prediction, turncate_list) 188 | query_index = multi_argmax(np.array(sentence_nte), size) 189 | return query_index, pool[query_index] 190 | 191 | def cal_nte(logits, turncate): 192 | sentence_nte = [] 193 | for idx, sent in enumerate(logits): 194 | sent_sum = 0 195 | for word in sent[:turncate[idx]]: 196 | tag_sum = 0 197 | for tag in word: 198 | tag_sum += tag*math.log(tag) 199 | sent_sum += tag_sum 200 | sentence_nte.append(-sent_sum/turncate[idx]) 201 | return sentence_nte 202 | 203 | def qbc_sampling(model_com, pool, n_instance): 204 | com_pred = [] 205 | active_eval_loader = get_tr_set(train_examples=pool, batch_size=1, args=args) 206 | for _model in model_com: 207 | raw_prediction, turncate_list = active_eval(active_eval_loader, _model) 208 | tag_prediction = result2tag(raw_prediction, turncate_list) 209 | com_pred.append(tag_prediction) 210 | vote_entropy = cal_vote_entropy(com_pred) 211 | query_index = multi_argmax(vote_entropy, n_instance) 212 | return query_index, pool[query_index] 213 | 214 | def cal_vote_entropy(mc_pred): 215 | ''' 216 | Calculate the vote entropy 217 | 218 | Input: 219 | mc_pred: 3d-shape (num_mc_model * num_sentence * max_len * n_tags) 220 | Output: 221 | vote_entropy: 2d-shape (num_sentence * max_len) 222 | ''' 223 | num_mc_model = len(mc_pred) 224 | num_sentence = mc_pred[0].shape[0] 225 | 226 | print('vote_matrix') 227 | vote_matrix = np.zeros((num_sentence, args.max_seq_length, num_labels)) 228 | for model_idx, pred in enumerate(mc_pred): 229 | for s_idx, sentence in enumerate(pred): 230 | for w_idx, word in enumerate(sentence): 231 | vote_matrix[s_idx][w_idx][word] += 1 232 | print('vote_prob_matrix') 233 | vote_prob_matrix = np.zeros((num_sentence, args.max_seq_length, num_labels)) 234 | for s_idx, sentence in enumerate(vote_matrix): 235 | for w_idx, word in enumerate(sentence): 236 | for tag_idx in range(num_labels): 237 | prob_i = np.sum(word==tag_idx) / num_mc_model 238 | vote_prob_matrix[s_idx][w_idx][tag_idx] = prob_i 239 | print('vote_entropy') 240 | vote_entropy = np.zeros(num_sentence) 241 | for s_idx, sentence in enumerate(vote_prob_matrix): 242 | sentence_entropy = 0 243 | for w_idx, word in enumerate(sentence): 244 | word_entropy = 0 245 | for tag_prob in word: 246 | if tag_prob: 247 | word_entropy -= tag_prob*(math.log(tag_prob)) 248 | sentence_entropy += word_entropy 249 | vote_entropy[s_idx] = sentence_entropy 250 | 251 | return vote_entropy 252 | 253 | def result2tag(result, turncate): 254 | ''' 255 | Convert the result with 3-d shape to the tags with 2-d shape. 
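Each word is assigned the tag index with the highest predicted probability, and each sequence is truncated to its valid length given by `turncate` before conversion.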
256 | ''' 257 | sentences = [] 258 | for idx, sentence in enumerate(result): 259 | valid_len = turncate[idx] 260 | words = [] 261 | for word in sentence[:valid_len]: 262 | word = word.tolist() 263 | tag = word.index(max(word)) 264 | words.append(tag) 265 | sentences.append(words) 266 | return np.array(sentences) 267 | 268 | def random_sampling(model_instance, input_data, n_instances): 269 | ''' 270 | Random sampling policy. 271 | 272 | Input: 273 | model_instance: model 274 | input_data: the unobserved data. 275 | n_instances: the number of instances to be sampled in each round. 276 | Output: 277 | query_index: the n_instances index of sampled data. 278 | input_Data[query_index]: the corresponding data. 279 | ''' 280 | query_index = np.random.choice(range(len(input_data)), size=n_instances, replace=False) 281 | return query_index, input_data[query_index] 282 | 283 | def active_train(data_loader=None, model=None, Epochs=5, soft_loader=None, args=None): 284 | config = BertConfig.from_pretrained(args.bert_model, num_labels=num_labels, finetuning_task=args.task_name) 285 | if model==None: 286 | model = Ner.from_pretrained(args.bert_model, from_tf = False, config = config) 287 | return_model = Ner.from_pretrained(args.bert_model, from_tf = False, config = config) 288 | model.to(device) 289 | return_model.to(device) 290 | param_optimizer = list(model.named_parameters()) 291 | no_decay = ['bias','LayerNorm.weight'] 292 | optimizer_grouped_parameters = [ 293 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, 294 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 295 | ] 296 | num_train_optimization_steps = int(len(data_loader.dataset)/args.train_batch_size/args.gradient_accumulation_steps)*args.num_train_epochs #2190 297 | warmup_steps = int(args.warmup_proportion * num_train_optimization_steps) 298 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 299 | scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps, t_total=num_train_optimization_steps) 300 | 301 | current_train_size = 0 302 | if soft_loader: 303 | current_train_size = len(data_loader.dataset) + len(soft_loader.dataset) 304 | else: 305 | current_train_size = len(data_loader.dataset) 306 | print('Training on {} data'.format(current_train_size)) 307 | 308 | model.train() 309 | tr_loss = 2020 310 | for epoch_idx in trange(int(Epochs), desc="Epoch"): 311 | current_loss = 0 312 | nb_tr_examples, nb_tr_steps = 0, 0 313 | for step, batch in enumerate(tqdm(data_loader, desc="Iteration")): 314 | batch = tuple(t.to(device) for t in batch) 315 | input_ids, input_mask, segment_ids, label_ids, valid_ids,l_mask = batch 316 | loss = model(input_ids, segment_ids, input_mask, label_ids, valid_ids, l_mask) 317 | if n_gpu > 1: 318 | loss = loss.mean() # mean() to average on multi-gpu. 
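# Descriptive note: when gradients are accumulated over several mini-batches, the loss is scaled down below so the summed gradient matches a single large-batch update; optimizer.step() only runs every gradient_accumulation_steps batches.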
319 | if args.gradient_accumulation_steps > 1: 320 | loss = loss / args.gradient_accumulation_steps 321 | if args.fp16: 322 | with amp.scale_loss(loss, optimizer) as scaled_loss: 323 | scaled_loss.backward() 324 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) 325 | else: 326 | loss.backward() 327 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 328 | current_loss += loss 329 | nb_tr_examples += input_ids.size(0) 330 | nb_tr_steps += 1 331 | if (step + 1) % args.gradient_accumulation_steps == 0: 332 | optimizer.step() 333 | scheduler.step() # Update learning rate schedule 334 | model.zero_grad() 335 | if soft_loader: 336 | for input_ids, input_mask, segment_ids, soft_labels, valid_ids,l_mask in tqdm(soft_loader, desc="Soft Training"): 337 | input_ids = input_ids.to(device) 338 | input_mask = input_mask.to(device) 339 | segment_ids = segment_ids.to(device) 340 | soft_labels = soft_labels.to(device) 341 | l_mask = l_mask.to(device) 342 | #with torch.no_grad(): 343 | logits = model(input_ids, segment_ids, input_mask,valid_ids=valid_ids,attention_mask_label=l_mask) 344 | #logits = F.softmax(logits, dim=2) 345 | logits = logits.detach().cpu().float() 346 | soft_labels = soft_labels.detach().cpu().float() 347 | pos_weight = torch.ones([num_labels]) # All weights are equal to 1 348 | criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight) 349 | loss = 0 350 | for i in range(len(logits)): 351 | turncate_len = np.count_nonzero(l_mask[i].detach().cpu().numpy()) 352 | logit = logits[i][:turncate_len] 353 | soft_label = soft_labels[i][:turncate_len] 354 | loss += criterion(logit, soft_label) 355 | loss = Variable(loss, requires_grad=True) 356 | current_loss += loss 357 | loss.backward() 358 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 359 | optimizer.step() 360 | scheduler.step() # Update learning rate schedule 361 | model.zero_grad() 362 | if current_loss <= tr_loss: 363 | return_model.load_state_dict(model.state_dict()) 364 | tr_loss = current_loss 365 | 366 | return return_model 367 | 368 | def active_eval(active_data_loader=None, model=None): 369 | model.to(device) 370 | model.eval() 371 | eval_loss, eval_accuracy = 0, 0 372 | nb_eval_steps, nb_eval_examples = 0, 0 373 | y_true = [] 374 | y_pred = [] 375 | raw_logits = [] 376 | turncate_list = [] 377 | label_map = {i : label for i, label in enumerate(label_list,1)} 378 | for input_ids, input_mask, segment_ids, label_ids,valid_ids,l_mask in tqdm(active_data_loader, desc="Evaluating"): 379 | input_ids = input_ids.to(device) 380 | input_mask = input_mask.to(device) 381 | segment_ids = segment_ids.to(device) 382 | valid_ids = valid_ids.to(device) 383 | label_ids = label_ids.to(device) 384 | l_mask = l_mask.to(device) 385 | 386 | with torch.no_grad(): 387 | logits = model(input_ids, segment_ids, input_mask,valid_ids=valid_ids,attention_mask_label=l_mask) 388 | 389 | logits = F.softmax(logits, dim=2) 390 | assert logits.shape[0] == 1 391 | logits = logits.detach().cpu().numpy().reshape((logits.shape[1], logits.shape[2])) 392 | turncate_len = np.count_nonzero(l_mask.detach().cpu().numpy()) 393 | turncate_list.append(turncate_len) 394 | raw_logits.append(logits) 395 | return raw_logits, turncate_list 396 | 397 | 398 | def active_learn(init_flag=None, train_data=None, num_initial=200, 399 | active_policy=None, num_query=5, num_sample=[100, 100, 100, 100, 100], 400 | dev_data=None, fit_only_new_data=False, Epochs=10, prefix='Active', args=None): 401 | ''' 402 | 
Implement active learning initializaiton and learning loop 403 | ''' 404 | func_paras = locals() 405 | # Data Initialization 406 | pool = copy.deepcopy(train_data) 407 | train_data = copy.deepcopy(train_data) 408 | original_datasize = len(train_data) 409 | 410 | initial_idx = np.random.choice(range(len(train_data)), size=num_initial, replace=False) 411 | train_data = np.array(train_data)[initial_idx] 412 | 413 | init_data_loader, query_idx = get_tr_set(size=num_initial, train_examples=train_data, args=args) 414 | pool = np.delete(pool, query_idx, axis=0) 415 | print(np.array(pool).shape) 416 | if init_flag: 417 | init_dir = 'init_dir' 418 | model = Ner.from_pretrained(init_dir) 419 | else: 420 | model = active_train(init_data_loader, None, Epochs, args=args) 421 | 422 | report = evaluate('Intialization', model, args) 423 | print_table = PrettyTable(['Model', 'Number of Query', 'Data Usage', 'Test_F1']) 424 | print_table.add_row(['Active Model', 'Model Initialization', len(train_data)/original_datasize, report.split()[-2]]) 425 | print(print_table) 426 | save_result(prefix=args.prefix, report=report, table=print_table, output_dir=args.output_dir) 427 | 428 | print('Learning loop start') 429 | for idx in range(num_query): 430 | print('\n\n-------Query no. %d--------\n' % (idx + 1)) 431 | query_idx, query_instance = active_policy(model, pool, num_sample[idx]) 432 | 433 | if fit_only_new_data: 434 | train_data = pool[query_idx] 435 | else: 436 | train_data = np.concatenate((train_data, pool[query_idx])) 437 | pool = np.delete(pool, query_idx, axis=0) 438 | active_data_loader = get_tr_set(train_examples=train_data, args=args) 439 | model = active_train(active_data_loader, model, Epochs, args=args) 440 | 441 | report = evaluate('Active Learning', model, args) 442 | print_table.add_row(['Active Model', idx+1, len(train_data)/original_datasize, report.split()[-2]]) 443 | print(print_table) 444 | 445 | save_result(prefix=args.prefix, func_paras=func_paras, report=report, table=print_table, output_dir=args.output_dir) 446 | 447 | return model 448 | 449 | ''' 450 | def active_qbc_learn(init_flag=None, train_data=train_examples, num_initial=200, 451 | active_policy=qbc_sampling, num_com=3, num_query=5, num_sample=[100, 100, 100, 100, 100], 452 | dev_data=dev_examples, fit_only_new_data=False, Epochs=10, prefix='Active'): 453 | #Implement active learning initializaiton and learning loop 454 | func_paras = locals() 455 | # Data Initialization 456 | pool = copy.deepcopy(train_data) 457 | train_data = copy.deepcopy(train_data) 458 | original_datasize = len(train_data) 459 | 460 | initial_idx = np.random.choice(range(len(train_data)), size=num_initial, replace=False) 461 | train_data = np.array(train_data)[initial_idx] 462 | 463 | init_data_loader, query_idx = get_tr_set(size=num_initial, train_examples=train_data) 464 | pool = np.delete(pool, query_idx, axis=0) 465 | print(np.array(pool).shape) 466 | if init_flag: 467 | init_dir = 'init_dir' 468 | model = Ner.from_pretrained(init_dir) 469 | print("Initial model loaded from google drive") 470 | else: 471 | model = active_train(init_data_loader, None, Epochs) 472 | 473 | report = evaluate('Intialization', model) 474 | print_table = PrettyTable(['Model', 'Number of Query', 'Data Usage', 'Test_F1']) 475 | print_table.add_row(['Active Model', 'Model Initialization', len(train_data)/original_datasize, report.split()[-2]]) 476 | print(print_table) 477 | 478 | # Construct the committee 479 | model_com = [] 480 | config = 
BertConfig.from_pretrained(args.bert_model, num_labels=num_labels, finetuning_task=args.task_name) 481 | for i in range(num_com): 482 | _model = Ner.from_pretrained(args.bert_model, from_tf = False, config = config) 483 | _model.load_state_dict(model.state_dict()) 484 | model_com.append(_model) 485 | 486 | 487 | print('Learning loop start') 488 | for idx in range(num_query): 489 | print('\n-------Query no. %d--------\n' % (idx + 1)) 490 | query_idx, query_instance = active_policy(model_com, pool, num_sample[idx]) 491 | 492 | if fit_only_new_data: 493 | train_data = pool[query_idx] 494 | else: 495 | train_data = np.concatenate((train_data, pool[query_idx])) 496 | pool = np.delete(pool, query_idx, axis=0) 497 | active_data_loader = get_tr_set(train_examples=train_data) 498 | for _idx, _model in enumerate(model_com): 499 | print('\n-------Committee no. %d--------\n' % (_idx + 1)) 500 | _model = active_train(active_data_loader, _model, Epochs) 501 | report = evaluate('Active Learning', _model) 502 | print_table.add_row(['Active Model', idx+1, len(train_data)/original_datasize, report.split()[-2]]) 503 | print(print_table) 504 | 505 | save_result(prefix=prefix, func_paras=func_paras, report=report, table=print_table) 506 | 507 | return model 508 | ''' 509 | 510 | if __name__ == "__main__": 511 | parser = argparse.ArgumentParser() 512 | parser.add_argument("--bert_model", type=str, default='bert-base-cased') 513 | parser.add_argument("--data_dir", type=str, default='data/') 514 | parser.add_argument("--do_eval", type=bool, default=True) 515 | parser.add_argument("--do_train", type=bool, default=True) 516 | parser.add_argument("--max_seq_length", type=int, default=128) 517 | parser.add_argument("--num_train_epochs", type=int, default=10) 518 | parser.add_argument("--task_name", type=str, default='ner') 519 | parser.add_argument("--output_dir", type=str, default='CoNLL/result') 520 | parser.add_argument("--warmup_proportion", type=float, default=0.1) 521 | parser.add_argument("--prefix", type=str, default='file_save_name') 522 | parser.add_argument("--active_policy", type=str, default='nte') 523 | 524 | # keep as default 525 | parser.add_argument("--server_ip", type=str, default='') 526 | parser.add_argument("--server_port", type=str, default='') 527 | parser.add_argument("--local_rank", type=int, default=-1) 528 | parser.add_argument("--no_cuda", type=bool, default=False) 529 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1) 530 | parser.add_argument("--train_batch_size", type=int, default=32) 531 | parser.add_argument("--seed", type=int, default=2020) 532 | parser.add_argument("--do_lower_case", type=bool, default=False) 533 | parser.add_argument("--weight_decay", type=float, default=0.01) 534 | parser.add_argument("--adam_epsilon", type=float, default=1e-08) 535 | parser.add_argument("--learning_rate", type=float, default=5e-05) 536 | parser.add_argument("--fp16", type=bool, default=False) 537 | parser.add_argument("--fp16_opt_level", type=str, default='O1') 538 | parser.add_argument("--eval_on", type=str, default='dev') 539 | parser.add_argument("--eval_batch_size", type=int, default=8) 540 | parser.add_argument("--max_grad_norm", type=float, default=1.0) 541 | 542 | 543 | args = parser.parse_args() 544 | # parse args 545 | logging.basicConfig(format='%(asctime)s-%(levelname)s-%(name)s-%(message)s', datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO) 546 | logger = logging.getLogger(__name__) 547 | if args.server_ip and args.server_port: 548 | # Distant debugging - see 
https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script 549 | print("Waiting for debugger attach") 550 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) 551 | ptvsd.wait_for_attach() 552 | 553 | if args.local_rank == -1 or args.no_cuda: 554 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 555 | #n_gpu = torch.cuda.device_count() 556 | n_gpu = 1 557 | else: 558 | torch.cuda.set_device(args.local_rank) 559 | device = torch.device("cuda", args.local_rank) 560 | n_gpu = 1 561 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 562 | torch.distributed.init_process_group(backend='nccl') 563 | 564 | logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( 565 | device, n_gpu, bool(args.local_rank != -1), args.fp16)) 566 | 567 | if args.gradient_accumulation_steps < 1: 568 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(args.gradient_accumulation_steps)) 569 | args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps 570 | 571 | random.seed(args.seed) 572 | np.random.seed(args.seed) 573 | torch.manual_seed(args.seed) 574 | 575 | if not args.do_train and not args.do_eval: 576 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 577 | if not os.path.exists(args.output_dir): 578 | os.makedirs(args.output_dir) 579 | 580 | processor = NerProcessor() 581 | label_list = processor.get_labels() 582 | num_labels = len(label_list) + 1 583 | tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) 584 | 585 | train_examples = None 586 | num_train_optimization_steps = 0 587 | if args.do_train: 588 | train_examples = processor.get_train_examples(args.data_dir) 589 | num_train_optimization_steps = int(len(train_examples)/args.train_batch_size/args.gradient_accumulation_steps)*args.num_train_epochs 590 | if args.local_rank != -1: 591 | num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size() 592 | 593 | if args.do_eval: 594 | if args.eval_on == 'dev': 595 | dev_examples = processor.get_dev_examples(args.data_dir) 596 | if args.eval_on == 'test': 597 | dev_examples = processor.get_test_examples(args.data_dir) 598 | 599 | if args.local_rank not in [-1, 0]: 600 | torch.distributed.barrier() 601 | 602 | # prepare model 603 | config = BertConfig.from_pretrained(args.bert_model, num_labels=num_labels, finetuning_task=args.task_name) 604 | model = Ner.from_pretrained(args.bert_model, from_tf = False, config = config) 605 | 606 | if args.local_rank == 0: 607 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab 608 | 609 | model.to(device) 610 | param_optimizer = list(model.named_parameters()) 611 | no_decay = ['bias','LayerNorm.weight'] 612 | optimizer_grouped_parameters = [ 613 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, 614 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 615 | ] 616 | 617 | warmup_steps = int(args.warmup_proportion * num_train_optimization_steps) 618 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 619 | scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps, t_total=num_train_optimization_steps) 620 | 621 | # 
For our experiment, the following can be ignored 622 | if args.fp16: 623 | try: 624 | from apex import amp 625 | except ImportError: 626 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 627 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) 628 | 629 | # multi-gpu training (should be after apex fp16 initialization) 630 | if n_gpu > 1: 631 | model = torch.nn.DataParallel(model) 632 | 633 | if args.local_rank != -1: 634 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True) 635 | 636 | if args.active_policy=='random': 637 | active_policy = random_sampling 638 | if args.active_policy=='lc': 639 | active_policy = uncertainty_sampling 640 | if args.active_policy=='nte': 641 | active_policy = nte_sampling 642 | 643 | model = active_learn(init_flag=False, train_data=train_examples, dev_data=dev_examples, active_policy=active_policy, prefix=args.prefix, Epochs=args.num_train_epochs, args=args) 644 | -------------------------------------------------------------------------------- /data_load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np  # needed by NerProcessor.get_train_examples when sampling a subset 3 | class InputExample(object): 4 | """A single training/test example for simple sequence classification.""" 5 | 6 | def __init__(self, guid, text_a, text_b=None, label=None): 7 | """Constructs an InputExample. 8 | Args: 9 | guid: Unique id for the example. 10 | text_a: string. The untokenized text of the first sequence. For single 11 | sequence tasks, only this sequence must be specified. 12 | text_b: (Optional) string. The untokenized text of the second sequence. 13 | Must only be specified for sequence pair tasks. 14 | label: (Optional) string. The label of the example. This should be 15 | specified for train and dev examples, but not for test examples. 
16 | """ 17 | self.guid = guid 18 | self.text_a = text_a 19 | self.text_b = text_b 20 | self.label = label 21 | 22 | class InputFeatures(object): 23 | """A single set of features of data.""" 24 | 25 | def __init__(self, input_ids, input_mask, segment_ids, label_id, valid_ids=None, label_mask=None): 26 | self.input_ids = input_ids 27 | self.input_mask = input_mask 28 | self.segment_ids = segment_ids 29 | self.label_id = label_id 30 | self.valid_ids = valid_ids 31 | self.label_mask = label_mask 32 | 33 | def readfile(filename): 34 | '''read file''' 35 | f = open(filename) 36 | data = [] 37 | sentence = [] 38 | label= [] 39 | for line in f: 40 | if len(line)==0 or line.startswith('-DOCSTART') or line[0]=="\n": 41 | if len(sentence) > 0: 42 | data.append((sentence,label)) 43 | sentence = [] 44 | label = [] 45 | continue 46 | splits = line.split(' ') 47 | sentence.append(splits[0]) 48 | label.append(splits[-1][:-1]) 49 | 50 | if len(sentence) >0: 51 | data.append((sentence,label)) 52 | sentence = [] 53 | label = [] 54 | return data 55 | 56 | 57 | class DataProcessor(object): 58 | """Base class for data converters for sequence classification data sets.""" 59 | def get_train_examples(self, data_dir): 60 | """Gets a collection of `InputExample`s for the train set.""" 61 | raise NotImplementedError() 62 | def get_dev_examples(self, data_dir): 63 | """Gets a collection of `InputExample`s for the dev set.""" 64 | raise NotImplementedError() 65 | def get_labels(self): 66 | """Gets the list of labels for this data set.""" 67 | raise NotImplementedError() 68 | @classmethod 69 | def _read_tsv(cls, input_file, quotechar=None): 70 | """Reads a tab separated value file.""" 71 | return readfile(input_file) 72 | 73 | 74 | class NerProcessor(DataProcessor): 75 | """Processor for the CoNLL-2003 data set.""" 76 | 77 | def get_train_examples(self, data_dir, size=None): 78 | """See base class.""" 79 | train_file = self._read_tsv(os.path.join(data_dir, "train.txt")) 80 | return_example = self._create_examples(train_file, "train") 81 | if size: 82 | select_idx = np.random.choice(range(len(return_example)), size=size, replace=False) 83 | return_example = list(np.array(return_example)[select_idx]) 84 | return return_example 85 | 86 | def get_dev_examples(self, data_dir): 87 | """See base class.""" 88 | return self._create_examples( 89 | self._read_tsv(os.path.join(data_dir, "valid.txt")), "dev") 90 | def get_test_examples(self, data_dir): 91 | """See base class.""" 92 | return self._create_examples( 93 | self._read_tsv(os.path.join(data_dir, "test.txt")), "test") 94 | def get_labels(self): 95 | return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "[CLS]", "[SEP]"] 96 | def _create_examples(self,lines,set_type): 97 | examples = [] 98 | for i,(sentence,label) in enumerate(lines): 99 | guid = "%s-%s" % (set_type, i) 100 | text_a = ' '.join(sentence) 101 | text_b = None 102 | label = label 103 | examples.append(InputExample(guid=guid,text_a=text_a,text_b=text_b,label=label)) 104 | return examples 105 | 106 | 107 | def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, logger): 108 | """Loads a data file into a list of `InputBatch`s.""" 109 | 110 | label_map = {label : i for i, label in enumerate(label_list,1)} 111 | features = [] 112 | for (ex_index,example) in enumerate(examples): 113 | textlist = example.text_a.split(' ') 114 | labellist = example.label 115 | tokens = [] 116 | labels = [] 117 | valid = [] 118 | label_mask = [] 119 | for i, word in 
enumerate(textlist): 120 | token = tokenizer.tokenize(word) 121 | tokens.extend(token) # maybe ##__ will be added 122 | label_1 = labellist[i] 123 | for m in range(len(token)): 124 | if m == 0: 125 | labels.append(label_1) 126 | valid.append(1) 127 | label_mask.append(1) 128 | else: 129 | valid.append(0) 130 | if len(tokens) >= max_seq_length - 1: 131 | tokens = tokens[0:(max_seq_length - 2)] 132 | labels = labels[0:(max_seq_length - 2)] 133 | valid = valid[0:(max_seq_length - 2)] 134 | label_mask = label_mask[0:(max_seq_length - 2)] 135 | ntokens = [] 136 | segment_ids = [] 137 | label_ids = [] 138 | ntokens.append("[CLS]") 139 | segment_ids.append(0) 140 | valid.insert(0,1) 141 | label_mask.insert(0,1) 142 | label_ids.append(label_map["[CLS]"]) 143 | for i, token in enumerate(tokens): 144 | ntokens.append(token) 145 | segment_ids.append(0) 146 | if len(labels) > i: 147 | label_ids.append(label_map[labels[i]]) 148 | ntokens.append("[SEP]") 149 | segment_ids.append(0) # segment ids always 150 | valid.append(1) 151 | label_mask.append(1) 152 | label_ids.append(label_map["[SEP]"]) 153 | input_ids = tokenizer.convert_tokens_to_ids(ntokens) 154 | input_mask = [1] * len(input_ids) 155 | label_mask = [1] * len(label_ids) 156 | # padding 157 | while len(input_ids) < max_seq_length: 158 | input_ids.append(0) 159 | input_mask.append(0) 160 | segment_ids.append(0) 161 | label_ids.append(0) 162 | valid.append(1) 163 | label_mask.append(0) 164 | while len(label_ids) < max_seq_length: 165 | label_ids.append(0) 166 | label_mask.append(0) 167 | assert len(input_ids) == max_seq_length 168 | assert len(input_mask) == max_seq_length 169 | assert len(segment_ids) == max_seq_length 170 | assert len(label_ids) == max_seq_length 171 | assert len(valid) == max_seq_length 172 | assert len(label_mask) == max_seq_length 173 | 174 | if ex_index < 1: 175 | logger.info("*** Example ***") 176 | logger.info("guid: %s" % (example.guid)) 177 | logger.info("tokens: %s" % " ".join([str(x) for x in tokens])) 178 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 179 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 180 | logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 181 | features.append(InputFeatures(input_ids=input_ids, 182 | input_mask=input_mask, 183 | segment_ids=segment_ids, 184 | label_id=label_ids, 185 | valid_ids=valid, 186 | label_mask=label_mask)) 187 | return features -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from pytorch_transformers import (WEIGHTS_NAME, AdamW, BertConfig, 4 | BertForTokenClassification, BertTokenizer, 5 | WarmupLinearSchedule) 6 | 7 | 8 | class Ner(BertForTokenClassification): 9 | 10 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,valid_ids=None,attention_mask_label=None): 11 | sequence_output = self.bert(input_ids, token_type_ids, attention_mask, head_mask=None)[0] 12 | batch_size,max_len,feat_dim = sequence_output.shape 13 | valid_output = torch.zeros(batch_size,max_len,feat_dim,dtype=torch.float32,device='cuda') 14 | for i in range(batch_size): 15 | jj = -1 16 | for j in range(max_len): 17 | if valid_ids[i][j].item() == 1: 18 | jj += 1 19 | valid_output[i][jj] = sequence_output[i][j] 20 | sequence_output = self.dropout(valid_output) 21 | logits = self.classifier(sequence_output) 22 | 23 | if labels is not None: 
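# Descriptive note (training path): token-level cross-entropy is computed only over positions where attention_mask_label == 1 when the mask is given, and label id 0 (padding) is ignored by the loss.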
24 | loss_fct = nn.CrossEntropyLoss(ignore_index=0) 25 | # Only keep active parts of the loss 26 | #attention_mask_label = None 27 | if attention_mask_label is not None: 28 | active_loss = attention_mask_label.view(-1) == 1 29 | active_logits = logits.view(-1, self.num_labels)[active_loss] 30 | active_labels = labels.view(-1)[active_loss] 31 | loss = loss_fct(active_logits, active_labels) 32 | else: 33 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 34 | return loss 35 | else: 36 | return logits -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-transformers==1.2.0 2 | torch==1.2.0 3 | # metric 4 | seqeval==0.0.5 5 | # training progressbar 6 | tqdm==4.31.1 7 | # tokeniztion 8 | nltk==3.4.5 9 | # for rest api 10 | Flask==1.1.1 11 | Flask-Cors==3.0.8 12 | pytorch_pretrained_bert==0.6.2 -------------------------------------------------------------------------------- /seqmix.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import json 4 | import logging 5 | import os 6 | import random 7 | import sys 8 | import copy 9 | import time 10 | import math 11 | import numpy as np 12 | import torch 13 | import torch.nn.functional as F 14 | 15 | from prettytable import PrettyTable 16 | from torch.autograd import Variable 17 | from pytorch_transformers import (WEIGHTS_NAME, AdamW, BertConfig, 18 | BertForTokenClassification, BertTokenizer, 19 | WarmupLinearSchedule) 20 | from torch import nn 21 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, 22 | TensorDataset) 23 | from torch.utils.data.distributed import DistributedSampler 24 | from tqdm import tqdm, trange 25 | from seqeval.metrics import classification_report 26 | 27 | from model import Ner 28 | from data_load import readfile, NerProcessor, convert_examples_to_features 29 | from pytorch_pretrained_bert import OpenAIGPTTokenizer, OpenAIGPTModel, OpenAIGPTLMHeadModel 30 | from active_learn import nte_sampling, random_sampling, uncertainty_sampling 31 | 32 | '''Get Word Embedding''' 33 | def get_word_embedding(sp_output_dir=None): 34 | model = Ner.from_pretrained(sp_output_dir) 35 | tokenizer = BertTokenizer.from_pretrained(sp_output_dir, do_lower_case=args.do_lower_case) 36 | 37 | for name, parameters in model.named_parameters(): 38 | #print(name,':',parameters.size()) 39 | if name=='bert.embeddings.word_embeddings.weight': 40 | bert_embedding = parameters.detach().cpu().numpy() 41 | 42 | wordidx2ebd = {idx:bert_embedding[idx] for idx in range(bert_embedding.shape[0])} 43 | ebd2wordidx = {} 44 | for k,v in wordidx2ebd.items(): 45 | ebd2wordidx[tuple(v)] = k 46 | 47 | return wordidx2ebd, ebd2wordidx 48 | 49 | def x_reconstruct(sequence): 50 | ''' 51 | Experiment helper function. 
To reconstruct the sentence from a series of idx of word 52 | ''' 53 | seq_list = [] 54 | seq_str = ' ' 55 | for item in sequence: 56 | if item == n_words-1: 57 | break 58 | seq_list.append(idx2word[item]) 59 | return seq_str.join(seq_list) 60 | 61 | def score_gpt(sentence): 62 | tokenize_input = gpt_tokenizer.tokenize(sentence) 63 | tensor_input = torch.tensor([gpt_tokenizer.convert_tokens_to_ids(tokenize_input)]) 64 | loss=gpt_model(tensor_input, lm_labels=tensor_input) 65 | return math.exp(loss) 66 | 67 | def high_quality(sequence, score_limit_upper=500, score_limit_low=0): 68 | score = score_gpt(sequence) 69 | return (score>=score_limit_low and score= valid_tag_bar: 123 | sub_sequence.append(tuple(sequence[index: index+window_size])) 124 | subseq_start_index.append(index) 125 | return sub_sequence, subseq_start_index 126 | 127 | def soft_pair(candidate): 128 | # valid tag num: 129 | valid_tag_list = [] 130 | for idx, item in enumerate(candidate): 131 | tags = item.label 132 | exclude_label = ['O','[CLS]','[SEP]','[UKN]'] 133 | valid_tag_count = 0 134 | for tag in tags: 135 | if tag not in exclude_label: 136 | valid_tag_count += 1 137 | if valid_tag_count >= valid_tag_bar: 138 | valid_tag_list.append(idx) 139 | # only search in the sentences with enough valid tags 140 | valid_len_list = [] 141 | for item in candidate[valid_tag_list]: 142 | valid_len_list.append(len(item.label)) 143 | # equal length index list (52,) 144 | equal_len_index_list = [] 145 | for length in range(args.max_seq_length): 146 | equal_len_index = np.where(np.array(valid_len_list)==length+1) 147 | equal_len_index_list.append(np.array(valid_tag_list)[list(equal_len_index)]) 148 | return equal_len_index_list 149 | 150 | def soft_pair_index_generator(equal_len_index_list, pair_num, valid_tag_bar): 151 | pair_index_list = [] 152 | len_range = [] 153 | for i in range(len(equal_len_index_list)): 154 | if equal_len_index_list[i].shape[0]>=2: 155 | len_range.append(i+1) 156 | for i in range(pair_num): 157 | temp_len = random.choice(len_range) 158 | pair_index_list.append(random.sample(list(equal_len_index_list[temp_len-1]),2)) 159 | return pair_index_list 160 | 161 | 162 | def lf_mixup(mixdata, sentence1_idx, start_idx1, sentence2_idx, start_idx2, hyper_lambda): 163 | ''' 164 | Label-fixed Mixup. 165 | Note that here the start_idx1 and start_idx2 are only int rather than array. 
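At each position of the aligned sub-sequences, the two BERT word embeddings are interpolated with weight hyper_lambda, the nearest vocabulary token (excluding special tokens and word pieces) is substituted in, and the original label sequences are kept unchanged, which is what makes this variant label-fixed.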
166 | ''' 167 | new_seq1 = list((mixdata[sentence1_idx].text_a).split()) 168 | new_seq2 = list((mixdata[sentence2_idx].text_a).split()) 169 | 170 | mix_seq = [] 171 | # mixup 172 | for i in range(mix_len): 173 | e1 = wordidx2ebd[tokenizer.convert_tokens_to_ids(new_seq1[start_idx1+i])] 174 | e2 = wordidx2ebd[tokenizer.convert_tokens_to_ids(new_seq2[start_idx2+i])] 175 | e_mix = hyper_lambda*e1 + (1-hyper_lambda)*e2 176 | mix_token = most_similar(e_mix, 7) # return candidate word pool 177 | exclude_pool = [new_seq1[start_idx1+i], new_seq2[start_idx2+i], '[UNK]', '[CLS]', '[SEP]', '[PAD]'] 178 | for token in mix_token: 179 | if token not in exclude_pool and token.find('[unused')==-1 and token.find('##')==-1 : 180 | mix_seq.append(str(token)) 181 | #print(token) 182 | break 183 | 184 | # substitution 185 | for i in range(mix_len): 186 | try: 187 | new_seq1[start_idx1+i] = mix_seq[i] 188 | new_seq2[start_idx2+i] = mix_seq[i] 189 | except: 190 | print('\n---NEW SEQ 1 - LENGTH = {}, START IDX = {}---\n'.format(len(new_seq1), start_idx1)) 191 | print('\n---NEW SEQ 2 - LENGTH = {}, START IDX = {}---\n'.format(len(new_seq2), start_idx2)) 192 | continue 193 | new_seq_1 = ' '.join(new_seq1) 194 | new_seq_2 = ' '.join(new_seq2) 195 | new_seq1_tag = mixdata[sentence1_idx].label 196 | new_seq2_tag = mixdata[sentence2_idx].label 197 | 198 | return new_seq_1, new_seq_2, new_seq1_tag, new_seq2_tag 199 | 200 | 201 | def lf_augment(candidate_data, num_mixup, hyper_alpha, score_limit_upper=500, score_limit_low=0): 202 | ''' 203 | Label_fixed augment. 204 | Given the candidate dataset and number of samples to be generated via mixup method, augment the training dataset by implementing mixup. 205 | ''' 206 | global GUID_COUNT 207 | time_start=time.time() 208 | pair_count = 0 209 | stop_flag = 0 210 | new_sample_count = 0 211 | new_candidate_data = list(copy.deepcopy(candidate_data)) 212 | 213 | for i in range(len(candidate_data)-1): 214 | sub_sequence_i, subseq_i_index = find_sub_seq(candidate_data[i].label, window_size, valid_tag_bar) 215 | for j in range(i+1, len(candidate_data)): 216 | sub_sequence_j, subseq_j_index = find_sub_seq(candidate_data[j].label, window_size, valid_tag_bar) 217 | same_subseq = extract_same_elem(sub_sequence_i, sub_sequence_j) 218 | # If the same subsequence exists: 219 | if same_subseq != []: 220 | for ii in range(len(sub_sequence_i)): 221 | for jj in range(len(sub_sequence_j)): 222 | if sub_sequence_i[ii] == sub_sequence_j[jj]: 223 | hyper_lambda = np.random.beta(hyper_alpha, hyper_alpha) 224 | newseq1, newseq2, newseq1_tag, newseq2_tag = lf_mixup(candidate_data, i, subseq_i_index[ii], j, subseq_j_index[jj], hyper_lambda) 225 | # add newseq1 226 | if score_limit_upper < 0: 227 | high_quality_1 = True 228 | high_quality_2 = True 229 | else: 230 | high_quality_1, score_1 = high_quality(newseq1, score_limit_upper, score_limit_low) 231 | high_quality_2, score_2 = high_quality(newseq2, score_limit_upper, score_limit_low) 232 | if high_quality_1: 233 | GUID_COUNT += 1 234 | new_candidate_data.append(InputExample(guid=GUID_COUNT, text_a=newseq1, text_b=None, label=newseq1_tag)) 235 | new_sample_count += 1 236 | # add newseq2 237 | if high_quality_2: 238 | GUID_COUNT += 1 239 | new_candidate_data.append(InputExample(guid=GUID_COUNT, text_a=newseq2, text_b=None, label=newseq2_tag)) 240 | new_sample_count += 1 241 | if high_quality_1 or high_quality_2: 242 | break 243 | break 244 | if new_sample_count >= num_mixup: 245 | stop_flag = 1 246 | break 247 | if stop_flag: 248 | break 249 | 
time_end=time.time() 250 | print('{} extra samples are generated, the time cost is {} s'.format(new_sample_count, time_end - time_start)) 251 | return new_candidate_data, new_sample_count 252 | 253 | 254 | def tag2onehot(tag): 255 | label_map = {label : i for i, label in enumerate(label_list,1)} 256 | tagid = label_map[tag] 257 | onehot_tag = np.zeros(len(label_map)+1) 258 | onehot_tag[tagid] = 1 259 | return onehot_tag 260 | 261 | def onehot2tag(tag): 262 | label_map = {label : i for i, label in enumerate(label_list,1)} 263 | reverse_label_map = {i:label for i, label in enumerate(label_list,1)} 264 | return_tag=[] 265 | for word in tag: 266 | if word.any()!=0: 267 | idx = np.where(word!=0) 268 | if len(idx[0]) > 1: 269 | mixtag = '' 270 | for i, item in enumerate(idx[0]): 271 | tag = reverse_label_map[item] 272 | mixtag += ' '+str(word[item])+ tag 273 | else: 274 | mixtag = reverse_label_map[idx[0][0]] 275 | return_tag.append(mixtag) 276 | return return_tag 277 | 278 | 279 | def slack_mixup(mixdata, sentence1_idx, sentence2_idx, start_idx1, start_idx2, hyper_lambda): 280 | ''' 281 | This function implement sentence-level mixup, it will be called by the function augment(). 282 | ''' 283 | new_seq1 = list((mixdata[sentence1_idx].text_a).split()) 284 | new_seq2 = list((mixdata[sentence2_idx].text_a).split()) 285 | labels_1 = copy.deepcopy(mixdata[sentence1_idx].label) 286 | labels_2 = copy.deepcopy(mixdata[sentence2_idx].label) 287 | labels_1 = np.concatenate((['[CLS]'], labels_1, ['[SEP]'])) 288 | labels_2 = np.concatenate((['[CLS]'], labels_2, ['[SEP]'])) 289 | 290 | # Transfer to one-hot form 291 | new_seq1_tag = [] 292 | new_seq2_tag = [] 293 | for i, item in enumerate(labels_1): 294 | new_seq1_tag.append(tag2onehot(item)) 295 | for i, item in enumerate(labels_2): 296 | new_seq2_tag.append(tag2onehot(item)) 297 | # padding 298 | while len(new_seq1_tag) < args.max_seq_length: 299 | new_seq1_tag.append(np.zeros(12)) 300 | while len(new_seq2_tag) < args.max_seq_length: 301 | new_seq2_tag.append(np.zeros(12)) 302 | 303 | mix_seq = [] 304 | mix_seq_tag = [] 305 | 306 | # mixup 307 | for i in range(mix_len): 308 | e1 = wordidx2ebd[tokenizer.convert_tokens_to_ids(new_seq1[start_idx1[0]+i])] 309 | e2 = wordidx2ebd[tokenizer.convert_tokens_to_ids(new_seq2[start_idx2[0]+i])] 310 | e_mix = hyper_lambda*e1 + (1-hyper_lambda)*e2 311 | mix_token = most_similar(e_mix, 7) # return 1 candidate word 312 | exclude_pool = [new_seq1[start_idx1[0]+i], new_seq2[start_idx2[0]+i], '[UNK]', '[CLS]', '[SEP]', '[PAD]'] 313 | 314 | for token in mix_token: 315 | #if token not in exclude_pool and token.find('[unused')==-1 and token.find('##')==-1: 316 | if token not in exclude_pool: 317 | mix_seq.append(token) 318 | break 319 | tag1 = new_seq1_tag[start_idx1[0]+i] 320 | tag2 = new_seq2_tag[start_idx2[0]+i] 321 | 322 | mix_tag = hyper_lambda*tag1 + (1-hyper_lambda)*tag2 323 | mix_seq_tag.append(mix_tag) 324 | 325 | # substitution 326 | for i in range(mix_len): 327 | new_seq1[start_idx1[0]+i] = mix_seq[i] 328 | new_seq2[start_idx2[0]+i] = mix_seq[i] 329 | new_seq1_tag[start_idx1[0]+i] = mix_seq_tag[i] 330 | new_seq2_tag[start_idx2[0]+i] = mix_seq_tag[i] 331 | 332 | new_seq1 = ' '.join(new_seq1) 333 | new_seq2 = ' '.join(new_seq2) 334 | return new_seq1, new_seq2, new_seq1_tag, new_seq2_tag 335 | 336 | 337 | def slack_augment(candidate_data=None, num_mixup=None, hyper_alpha=8, score_limit_upper=500, score_limit_low=0): 338 | ''' 339 | Given the candidate dataset and number of samples to be generated via mixup method, 
augment the training dataset by implementing mixup. 340 | Implement augmentation via slack-mixup 341 | ''' 342 | global GUID_COUNT 343 | time_start=time.time() 344 | new_sample_count = 0 345 | stop_flag = 0 346 | mixup_data = [] 347 | mixup_label = [] 348 | for i in range(len(candidate_data)-1): 349 | sub_sequence_i, subseq_i_index = find_sub_seq(candidate_data[i].label, window_size, valid_tag_bar) 350 | if len(sub_sequence_i)>0: 351 | for j in range(i+1, len(candidate_data)): 352 | sub_sequence_j, subseq_j_index = find_sub_seq(candidate_data[j].label, window_size, valid_tag_bar) 353 | # If the slack pair exists: 354 | if len(sub_sequence_j)>0: 355 | hyper_lambda = np.random.beta(hyper_alpha, hyper_alpha) # Beta distribution 356 | newseq1, newseq2, newseq1_tag, newseq2_tag = slack_mixup(candidate_data, i, j, subseq_i_index, subseq_j_index, hyper_lambda) 357 | if score_limit_upper < 0: 358 | high_quality_1 = True 359 | high_quality_2 = True 360 | else: 361 | high_quality_1,score_1 = high_quality(newseq1, score_limit_upper, score_limit_low) 362 | high_quality_2,score_2 = high_quality(newseq2, score_limit_upper, score_limit_low) 363 | 364 | if high_quality_1 or high_quality_2: 365 | GUID_COUNT += 1 366 | mixup_data.append(InputExample(guid=GUID_COUNT, text_a=newseq1, text_b=None, label=candidate_data[i].label)) 367 | mixup_label.append(newseq1_tag) 368 | new_sample_count += 1 369 | if new_sample_count >= num_mixup: 370 | stop_flag = 1 371 | break 372 | # add newseq2 373 | if high_quality_2: 374 | GUID_COUNT += 1 375 | mixup_data.append(InputExample(guid=GUID_COUNT, text_a=newseq2, text_b=None, label=candidate_data[j].label)) 376 | mixup_label.append(newseq2_tag) 377 | new_sample_count += 1 378 | if new_sample_count >= num_mixup: 379 | stop_flag = 1 380 | break 381 | if stop_flag: 382 | break 383 | time_end=time.time() 384 | print('{} extra samples are generated, the time cost is {} s'.format(new_sample_count, time_end - time_start)) 385 | return mixup_data, mixup_label, new_sample_count 386 | 387 | def soft_mixup(candidate_1, candidate_2, hyper_lambda): 388 | ''' 389 | This function implement sentence-level mixup, it will be called by the function soft_augment(). 
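The two candidates must have the same valid length; their token embeddings and one-hot label vectors are interpolated with the same hyper_lambda, so the generated sequence carries soft (fractional) labels.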
390 | ''' 391 | # sparse sequence and label 392 | seq1 = list((candidate_1.text_a).split()) 393 | seq2 = list((candidate_2.text_a).split()) 394 | 395 | y1 = copy.deepcopy(candidate_1.label) 396 | y2 = copy.deepcopy(candidate_2.label) 397 | 398 | # Transfer to one-hot form 399 | new_seq1_tag = [] 400 | new_seq2_tag = [] 401 | for i, item in enumerate(y1): 402 | new_seq1_tag.append(tag2onehot(item)) 403 | for i, item in enumerate(y2): 404 | new_seq2_tag.append(tag2onehot(item)) 405 | # padding 406 | while len(new_seq1_tag) < args.max_seq_length: 407 | new_seq1_tag.append(np.zeros(12)) 408 | while len(new_seq2_tag) < args.max_seq_length: 409 | new_seq2_tag.append(np.zeros(12)) 410 | 411 | # prepare the generation form 412 | new_seq = copy.deepcopy(seq1) 413 | new_seq_tag = copy.deepcopy(new_seq1_tag) 414 | 415 | assert len(seq1) == len(seq2), 'The two sequences should be in same valid length' 416 | mix_len_sentence = len(seq1) 417 | mix_seq = [] 418 | mix_seq_tag = [] 419 | 420 | # mixup 421 | for i in range(mix_len_sentence): 422 | e1 = wordidx2ebd[tokenizer.convert_tokens_to_ids(seq1[i])] 423 | e2 = wordidx2ebd[tokenizer.convert_tokens_to_ids(seq2[i])] 424 | e_mix = hyper_lambda*e1 + (1-hyper_lambda)*e2 425 | mix_token = most_similar(e_mix, 7) 426 | exclude_pool = [seq1[i], seq2[i], '[UNK]', '[CLS]', '[SEP]', '[PAD]'] 427 | 428 | for token in mix_token: 429 | if token not in exclude_pool: 430 | mix_seq.append(token) 431 | break 432 | tag1 = new_seq1_tag[i] 433 | tag2 = new_seq2_tag[i] 434 | mix_tag = hyper_lambda*tag1 + (1-hyper_lambda)*tag2 435 | mix_seq_tag.append(mix_tag) 436 | 437 | # substitution 438 | for i in range(mix_len_sentence): 439 | new_seq[i] = mix_seq[i] 440 | new_seq_tag[i] = mix_seq_tag[i] 441 | new_seq = ' '.join(new_seq) 442 | return new_seq, new_seq_tag 443 | 444 | def soft_augment(candidate_data=None, num_mixup=None, hyper_alpha=8, score_limit_upper=500, score_limit_low=0): 445 | global GUID_COUNT 446 | print('Implementing soft mixup augmentation, which may take hundreds of seconds') 447 | time_start=time.time() 448 | new_sample_count = 0 449 | mixup_data = [] 450 | mixup_label = [] 451 | candidate_data = copy.deepcopy(candidate_data) 452 | 453 | equal_len_index_list = soft_pair(candidate_data) 454 | pair_index_list = soft_pair_index_generator(equal_len_index_list, 15*num_mixup, valid_tag_bar) 455 | for index in pair_index_list: 456 | hyper_lambda = np.random.beta(hyper_alpha, hyper_alpha) # Beta distribution 457 | i = index[0] 458 | j = index[1] 459 | new_seq, new_seq_tag = soft_mixup(candidate_data[i], candidate_data[j], hyper_lambda) 460 | # add to the training set 461 | if score_limit_upper < 0: 462 | high_quality_flag = True 463 | else: 464 | high_quality_flag, score = high_quality(new_seq, score_limit_upper, score_limit_low) 465 | if high_quality_flag: 466 | GUID_COUNT += 1 467 | mixup_data.append(InputExample(guid=GUID_COUNT, text_a=new_seq, text_b=None, label=candidate_data[i].label)) 468 | mixup_label.append(new_seq_tag) 469 | new_sample_count += 1 470 | case_util(score, candidate_data[i].text_a, candidate_data[j].text_a, candidate_data[i].label, candidate_data[j].label, 471 | new_seq, new_seq_tag, prefix='Soft_case') 472 | if new_sample_count >= num_mixup: 473 | break 474 | 475 | time_end=time.time() 476 | print('{} extra samples are generated, the time cost is {} s'.format(new_sample_count, time_end - time_start)) 477 | return mixup_data, mixup_label, new_sample_count 478 | 479 | def active_augment_learn(init_flag=None, train_data=None, num_initial=200, 480 
480 |                          active_policy=uncertainty_sampling, augment_method=lf_augment,
481 |                          num_query=5, num_sample=[100, 100, 100, 100, 100],
482 |                          augment_rate=0.2, augment_decay=1,
483 |                          hyper_alpha=8, alpha_decay=1,
484 |                          Epochs=10, score_limit_low=0, score_limit_upper=500, fit_only_new_data=False,
485 |                          mixup_flag=True, single_use=False, prefix='SeqMix'):
486 |     '''
487 |     Implement the active learning initialization and the active learning loop
488 |     '''
489 |     func_paras = locals()
490 |     # data initialization: sample the seed set and remove it from the unlabeled pool
491 |     pool = copy.deepcopy(train_data)
492 |     train_data = copy.deepcopy(train_data)
493 |     original_datasize = len(train_data)
494 | 
495 |     initial_idx = np.random.choice(range(len(train_data)), size=num_initial, replace=False)
496 |     train_data = np.array(train_data)[initial_idx]
497 | 
498 |     init_data_loader, query_idx = get_tr_set(size=num_initial, train_examples=train_data)
499 |     pool = np.delete(pool, query_idx, axis=0)
500 |     print(np.array(pool).shape)  # remaining pool size
501 |     if init_flag:
502 |         init_dir = 'init_dir'
503 |         model = Ner.from_pretrained(init_dir)
504 |         print("Initial model loaded from '{}'".format(init_dir))
505 |     else:
506 |         model = active_train(init_data_loader, None, Epochs)
507 | 
508 |     # report
509 |     report = evaluate('Initialization', model)
510 |     print_table = PrettyTable(['Model', 'Number of Query', 'Data Usage', 'Data Augmented', 'Test_F1'])
511 |     print_table.add_row(['Initial Model', 'Model Initialization', len(train_data)/original_datasize, 0, report.split()[-2]])
512 |     print(print_table)
513 | 
514 |     # augment on the seed set
515 |     test_f1 = []
516 |     dev_f1 = []
517 |     num_augment = int(num_initial*augment_rate)
518 | 
519 |     if augment_method in (slack_augment, soft_augment):  # these variants generate soft labels
520 |         soft_data, soft_labels, new_sample_count = augment_method(train_data, num_augment, hyper_alpha, score_limit_upper, score_limit_low)
521 |         soft_loader = get_tr_set(train_examples=soft_data, soft_labels=soft_labels)
522 |     else:
523 |         mix_data, new_sample_count = augment_method(train_data, num_augment, hyper_alpha, score_limit_upper, score_limit_low)
524 |         soft_loader = None
525 | 
526 |     aug_data_loader = get_tr_set(train_examples=train_data)
527 |     model = active_train(data_loader=aug_data_loader, model=model, Epochs=Epochs, soft_loader=soft_loader)
528 |     #return model
529 |     report = evaluate('SeedSetAug', model)
530 |     aug_total_count = new_sample_count
531 |     print_table.add_row(['Augment Model', 'Seed Set Augmented', len(train_data)/original_datasize, aug_total_count, report.split()[-2]])
532 |     print(print_table)
533 |     save_result(prefix=prefix, func_paras=func_paras, report=report, table=print_table)
534 | 
535 |     # learning loop
536 |     print('Learning loop start')
537 |     for idx in range(num_query):
538 |         num_augment = int(num_sample[idx] * augment_rate * (augment_decay**idx))
539 |         hyper_alpha = hyper_alpha * (alpha_decay**idx)
540 | 
541 |         print('Query no. %d' % (idx + 1))
542 |         query_idx, query_instance = active_policy(model, pool, num_sample[idx])
543 |         mixup_candidate = pool[query_idx]
544 |         pool = np.delete(pool, query_idx, axis=0)
545 | 
546 |         if augment_method in (slack_augment, soft_augment):  # these variants generate soft labels
547 |             new_soft_data, new_soft_labels, new_sample_count = augment_method(mixup_candidate, num_augment, hyper_alpha, score_limit_upper, score_limit_low)
548 |             soft_data = np.concatenate((soft_data, new_soft_data))
549 |             soft_labels = np.concatenate((soft_labels, new_soft_labels))
550 |             soft_loader = get_tr_set(train_examples=soft_data, soft_labels=soft_labels)
551 |             mix_data = mixup_candidate
552 |         else:
553 |             if mixup_flag:  # mixup augment
554 |                 # mix_data consists of the original mixup_candidate plus the new samples generated by SeqMix
555 |                 mix_data, new_sample_count = augment_method(mixup_candidate, num_augment, hyper_alpha, score_limit_upper, score_limit_low)
556 |                 soft_loader = None
557 |             else:  # duplicate the original paired data
558 |                 mix_data, new_sample_count = duplicate_pair_data(mixup_candidate_X, mixup_candidate_y, num_augment)
559 | 
560 |         train_data = np.concatenate((train_data, mix_data))
561 |         aug_total_count += new_sample_count
562 |         aug_data_loader = get_tr_set(train_examples=train_data)
563 |         model = active_train(data_loader=aug_data_loader, model=model, Epochs=Epochs, soft_loader=soft_loader)
564 |         if single_use:
565 |             train_data = train_data[:-new_sample_count]
566 |             aug_total_count = new_sample_count
567 |         report = evaluate('SeqMixAug', model)
568 | 
569 |         data_usage = len(train_data)
570 |         if augment_method == lf_augment:  # lf keeps its generated samples inside train_data, so exclude them from the usage count
571 |             data_usage -= aug_total_count
572 |         print_table.add_row(['Augmented Model', idx+1, data_usage/original_datasize, aug_total_count, report.split()[-2]])
573 |         print(print_table)
574 |         save_result(prefix=prefix, func_paras=func_paras, report=report, table=print_table)
575 | 
576 |     return model
577 | 
578 | 
579 | 
580 | if __name__ == "__main__":
581 |     parser = argparse.ArgumentParser()
582 |     parser.add_argument("--bert_model", type=str, default='bert-base-cased')
583 |     parser.add_argument("--data_dir", type=str, default='data/')
584 |     parser.add_argument("--do_eval", type=bool, default=True)
585 |     parser.add_argument("--do_train", type=bool, default=True)
586 |     parser.add_argument("--max_seq_length", type=int, default=128)
587 |     parser.add_argument("--num_train_epochs", type=int, default=10)
588 |     parser.add_argument("--task_name", type=str, default='ner')
589 |     parser.add_argument("--output_dir", type=str, default='CoNLL/result')
590 |     parser.add_argument("--warmup_proportion", type=float, default=0.1)
591 |     parser.add_argument("--prefix", type=str, default='file_save_name')
592 |     parser.add_argument("--active_policy", type=str, default='nte')
593 |     parser.add_argument("--augment_method", type=str, default='soft')
594 |     parser.add_argument("--augment_rate", type=float, default=0.2)
595 |     parser.add_argument("--hyper_alpha", type=float, default=8)
596 | 
597 | 
598 | 
599 |     # keep as default
600 |     parser.add_argument("--server_ip", type=str, default='')
601 |     parser.add_argument("--server_port", type=str, default='')
602 |     parser.add_argument("--local_rank", type=int, default=-1)
603 |     parser.add_argument("--no_cuda", type=bool, default=False)
604 |     parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
605 |     parser.add_argument("--train_batch_size", type=int, default=32)
606 |     parser.add_argument("--seed", type=int, default=2020)
607 |     parser.add_argument("--do_lower_case", type=bool, default=False)
parser.add_argument("--weight_decay", type=float, default=0.01) 609 | parser.add_argument("--adam_epsilon", type=float, default=1e-08) 610 | parser.add_argument("--learning_rate", type=float, default=5e-05) 611 | parser.add_argument("--fp16", type=bool, default=False) 612 | parser.add_argument("--fp16_opt_level", type=str, default='O1') 613 | parser.add_argument("--eval_on", type=str, default='dev') 614 | parser.add_argument("--eval_batch_size", type=int, default=8) 615 | parser.add_argument("--max_grad_norm", type=float, default=1.0) 616 | 617 | 618 | args = parser.parse_args() 619 | 620 | sp_output_dir = 'out_conll/' 621 | wordidx2ebd, ebd2wordidx = get_word_embedding(sp_output_dir) 622 | mydict_values = np.array(list(wordidx2ebd.values())) 623 | mydict_keys = np.array(list(wordidx2ebd.keys())) 624 | 625 | '''Scoring''' 626 | # Load pre-trained model (weights) 627 | gpt_model = OpenAIGPTLMHeadModel.from_pretrained('openai-gpt') 628 | gpt_model.eval() 629 | # Load pre-trained model tokenizer (vocabulary) 630 | gpt_tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt') 631 | 632 | window_size = 5 633 | valid_tag_bar = 3 634 | mix_len = 5 635 | GUID_COUNT = 14041 636 | 637 | logging.basicConfig(format='%(asctime)s-%(levelname)s-%(name)s-%(message)s', datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO) 638 | logger = logging.getLogger(__name__) 639 | if args.server_ip and args.server_port: 640 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script 641 | print("Waiting for debugger attach") 642 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) 643 | ptvsd.wait_for_attach() 644 | 645 | if args.local_rank == -1 or args.no_cuda: 646 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 647 | #n_gpu = torch.cuda.device_count() 648 | n_gpu = 1 649 | else: 650 | torch.cuda.set_device(args.local_rank) 651 | device = torch.device("cuda", args.local_rank) 652 | n_gpu = 1 653 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 654 | torch.distributed.init_process_group(backend='nccl') 655 | 656 | logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( 657 | device, n_gpu, bool(args.local_rank != -1), args.fp16)) 658 | 659 | if args.gradient_accumulation_steps < 1: 660 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(args.gradient_accumulation_steps)) 661 | args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps 662 | 663 | random.seed(args.seed) 664 | np.random.seed(args.seed) 665 | torch.manual_seed(args.seed) 666 | 667 | if not args.do_train and not args.do_eval: 668 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 669 | if not os.path.exists(args.output_dir): 670 | os.makedirs(args.output_dir) 671 | 672 | processor = NerProcessor() 673 | label_list = processor.get_labels() 674 | num_labels = len(label_list) + 1 675 | tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) 676 | 677 | train_examples = None 678 | num_train_optimization_steps = 0 679 | if args.do_train: 680 | train_examples = processor.get_train_examples(args.data_dir) 681 | num_train_optimization_steps = int(len(train_examples)/args.train_batch_size/args.gradient_accumulation_steps)*args.num_train_epochs 682 | if args.local_rank != -1: 683 | num_train_optimization_steps = num_train_optimization_steps // 
684 | 
685 |     if args.do_eval:
686 |         if args.eval_on == 'dev':
687 |             dev_examples = processor.get_dev_examples(args.data_dir)
688 |         if args.eval_on == 'test':
689 |             dev_examples = processor.get_test_examples(args.data_dir)
690 | 
691 |     if args.local_rank not in [-1, 0]:
692 |         torch.distributed.barrier()
693 | 
694 |     # prepare model
695 |     config = BertConfig.from_pretrained(args.bert_model, num_labels=num_labels, finetuning_task=args.task_name)
696 |     model = Ner.from_pretrained(args.bert_model, from_tf=False, config=config)
697 | 
698 |     if args.local_rank == 0:
699 |         torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
700 | 
701 |     model.to(device)
702 |     param_optimizer = list(model.named_parameters())
703 |     no_decay = ['bias', 'LayerNorm.weight']
704 |     optimizer_grouped_parameters = [
705 |         {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
706 |         {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
707 |     ]
708 | 
709 |     warmup_steps = int(args.warmup_proportion * num_train_optimization_steps)
710 |     optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
711 |     scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps, t_total=num_train_optimization_steps)
712 | 
713 |     # For our experiment, the following can be ignored
714 |     if args.fp16:
715 |         try:
716 |             from apex import amp
717 |         except ImportError:
718 |             raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
719 |         model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
720 | 
721 |     # multi-gpu training (should be after apex fp16 initialization)
722 |     if n_gpu > 1:
723 |         model = torch.nn.DataParallel(model)
724 | 
725 |     if args.local_rank != -1:
726 |         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
727 | 
728 |     if args.active_policy == 'random':
729 |         active_policy = random_sampling
730 |     if args.active_policy == 'lc':
731 |         active_policy = uncertainty_sampling
732 |     if args.active_policy == 'nte':
733 |         active_policy = nte_sampling
734 | 
735 |     if args.augment_method == 'lf':
736 |         augment_method = lf_augment
737 |     if args.augment_method == 'slack':
738 |         augment_method = slack_augment
739 |     if args.augment_method == 'soft':
740 |         augment_method = soft_augment
741 | 
742 |     soft_model = active_augment_learn(init_flag=False, train_data=train_examples, augment_rate=args.augment_rate, hyper_alpha=args.hyper_alpha, active_policy=active_policy, augment_method=augment_method, prefix=args.prefix, Epochs=10)
--------------------------------------------------------------------------------
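
For readers skimming the dump, the core of `soft_mixup` above is one operation: interpolate the embeddings of two aligned tokens with a Beta-sampled coefficient, snap the result back to the nearest vocabulary token, and interpolate their one-hot tags with the same coefficient. The sketch below reproduces that step in isolation as a minimal example; the names `vocab`, `emb`, and `tag2onehot` here are toy stand-ins, not the repository's `wordidx2ebd` / `most_similar` machinery.

```
import numpy as np

rng = np.random.default_rng(0)
vocab = ["the", "city", "london", "paris", "visited"]    # toy vocabulary (stand-in)
emb = rng.normal(size=(len(vocab), 8))                   # toy embedding table (stand-in)
tag2onehot = {"O": np.eye(3)[0], "B-LOC": np.eye(3)[1], "I-LOC": np.eye(3)[2]}

def nearest_token(e_mix, exclude):
    # rank vocabulary tokens by cosine similarity to the mixed embedding
    sims = emb @ e_mix / (np.linalg.norm(emb, axis=1) * np.linalg.norm(e_mix))
    for idx in np.argsort(-sims):
        if vocab[idx] not in exclude:
            return vocab[idx]

def mix_tokens(tok1, tag1, tok2, tag2, alpha=8.0):
    lam = rng.beta(alpha, alpha)                         # mixing coefficient ~ Beta(alpha, alpha)
    e_mix = lam * emb[vocab.index(tok1)] + (1 - lam) * emb[vocab.index(tok2)]
    token = nearest_token(e_mix, exclude={tok1, tok2})   # exclude the parent tokens, as soft_mixup does
    soft_label = lam * tag2onehot[tag1] + (1 - lam) * tag2onehot[tag2]
    return token, soft_label

print(mix_tokens("london", "B-LOC", "paris", "B-LOC"))
```

In `active_learn.py` the same idea runs token by token over a whole paired sequence, the parents and BERT special tokens are excluded from the nearest-neighbour search, and each generated sentence must additionally pass the `high_quality` score filter (apparently computed with the GPT model loaded under the `'''Scoring'''` block) before it is added to the soft-label loader.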