├── ACL_2022_TreeMix.pdf ├── Augmentation.py ├── LICENSE ├── README.md ├── batch_train.py ├── online_augmentation ├── __init__.py └── __pycache__ │ ├── __init__.cpython-38.pyc │ └── __init__.cpython-39.pyc ├── process_data ├── Load_data.py ├── __init__.py ├── __pycache__ │ ├── Augmentation.cpython-39.pyc │ ├── Load_data.cpython-38.pyc │ ├── Load_data.cpython-39.pyc │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-39.pyc │ ├── ceshi.cpython-39.pyc │ ├── settings.cpython-38.pyc │ └── settings.cpython-39.pyc ├── get_data.py └── settings.py ├── requirements.txt ├── run.py └── transformers_doc └── pytorch └── task_summary.ipynb /ACL_2022_TreeMix.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lezhang7/TreeMix/016b47df8f028d384aedeaf1605ce24aac48b9b7/ACL_2022_TreeMix.pdf -------------------------------------------------------------------------------- /Augmentation.py: -------------------------------------------------------------------------------- 1 | import random 2 | from nltk import Tree 3 | from tqdm import tqdm 4 | import pandas as pd 5 | import argparse 6 | import numpy as np 7 | import os 8 | from datasets import load_dataset, Dataset, concatenate_datasets 9 | from process_data import settings 10 | from multiprocessing import Pool,cpu_count 11 | 12 | 13 | def modify(commands, debug=0): 14 | commands = commands.split(' ') 15 | verb = ['look', 'jump', 'walk', 'turn', 'run'] 16 | end_sign = ['left', 'right', 'twice', 'thrice'] 17 | add_pos = [] 18 | if debug: 19 | print(commands) 20 | for i in range(len(commands)-1): 21 | if commands[i] in end_sign and commands[i+1] in verb: 22 | add_pos.append(i+1) 23 | # commands.insert(i+1,'and') 24 | if debug: 25 | print(commands) 26 | if commands[i] in verb and commands[i+1] in verb: 27 | add_pos.append(i+1) 28 | # commands.insert(i+1,'and') 29 | if debug: 30 | print(commands) 31 | for i, pos in enumerate(add_pos): 32 | commands.insert(pos+i, 'and') 33 | if debug: 34 | print(commands) 35 | return ' '.join(commands) 36 | 37 | 38 | def c2a(commands, debug=0): 39 | verb = {'look': 'I_LOOK', 'walk': 'I_WALK', 40 | 'run': 'I_RUN', 'jump': 'I_JUMP'} 41 | direction = {'left': 'I_TURN_LEFT', 'right': 'I_TURN_RIGHT'} 42 | times = {'twice': 2, 'thrice': 3} 43 | conjunction = ['and', 'after'] 44 | 45 | commands = commands.split(' ') 46 | actions = [] 47 | previous_command = [] 48 | pre_actions = [] 49 | flag = 0 50 | i = 0 51 | if debug: 52 | print('raw:', commands) 53 | while len(commands) > 0: 54 | current = commands.pop(0) 55 | if debug: 56 | print('-'*50) 57 | print('step ', i) 58 | i += 1 59 | print('current command:', current, len(commands)) 60 | print('curret waiting commands list:', previous_command) 61 | print('already actions:', actions) 62 | print('previous waiting actions:', pre_actions) 63 | if current in verb.keys() or current == 'turn' or current in conjunction: # add new actions 64 | if current == 'and': 65 | continue 66 | if not previous_command: # initialization 67 | previous_command.append(current) 68 | 69 | else: # one conmands over 70 | if debug: 71 | print('##### one commands over#####') 72 | current_action = translate(previous_command) 73 | previous_command = [] 74 | if debug: 75 | print( 76 | '****got new action from previous commandsa list:{}****'.format(current_action[0])) 77 | 78 | if current == 'after': 79 | pre_actions.extend(current_action) 80 | if debug: 81 | print('****this action into pre_actions****') 82 | elif pre_actions: 83 | if debug: 84 | print( 85 | '****pre_actions and current_actions into action list****') 86 | actions.extend(current_action) 87 | actions.extend(pre_actions) 88 | pre_actions = [] 89 | previous_command.append(current) 90 | else: 91 | # current is a verb 92 | previous_command.append(current) 93 | actions.extend(current_action) 94 | else: 95 | previous_command.append(current) 96 | if previous_command: 97 | current_action = translate(previous_command) 98 | actions.extend(current_action) 99 | if pre_actions: 100 | actions.extend(pre_actions) 101 | if debug: 102 | print('-'*50) 103 | print('over') 104 | print('previous_command', previous_command) 105 | print('pre_actions', pre_actions) 106 | print('current action', current_action) 107 | return actions 108 | 109 | 110 | def translate(previous_command): 111 | verb = {'look': 'I_LOOK', 'walk': 'I_WALK', 112 | 'run': 'I_RUN', 'jump': 'I_JUMP'} 113 | direction = {'left': 'I_TURN_LEFT', 'right': 'I_TURN_RIGHT'} 114 | times = {'twice': 2, 'thrice': 3} 115 | conjunction = ['and', 'after'] 116 | if previous_command[-1] in times.keys(): 117 | return translate(previous_command[:-1])*times[previous_command[-1]] 118 | if len(previous_command) == 1: 119 | return [verb[previous_command[0]]] 120 | elif len(previous_command) == 2: 121 | if previous_command[0] == 'turn': 122 | return [direction[previous_command[1]]] 123 | elif previous_command[1] in direction: 124 | return [direction[previous_command[1]], verb[previous_command[0]]] 125 | elif len(previous_command) == 3: 126 | if previous_command[0] == 'turn': 127 | if previous_command[1] == 'opposite': 128 | return [direction[previous_command[2]]]*2 129 | else: 130 | return [direction[previous_command[2]]]*4 131 | elif previous_command[0] in verb.keys(): 132 | if previous_command[1] == 'opposite': 133 | return [direction[previous_command[2]], direction[previous_command[2]], verb[previous_command[0]]] 134 | else: 135 | return [direction[previous_command[2]], verb[previous_command[0]]]*4 136 | def set_seed(seed): 137 | random.seed(seed) 138 | np.random.seed(seed) 139 | def subtree_exchange_scan(args,parsing1,parsing2): 140 | new_sentence=None 141 | try: 142 | if args.debug: 143 | print('check5') 144 | t1 = Tree.fromstring(parsing1) 145 | t2 = Tree.fromstring(parsing2) 146 | t1_len=len(t1.leaves()) 147 | t2_len=len(t2.leaves()) 148 | # ----- restrict label-------------- 149 | # candidate_subtree1=list(t1.subtrees(lambda t: t.label() in ['VP','VB'])) 150 | # candidate_subtree2 = list(t2.subtrees(lambda t: t.label() in ['VP', 'VB'])) 151 | # tree_labels1 = [tree.label() for tree in candidate_subtree1] 152 | # tree_labels2 = [tree.label() for tree in candidate_subtree2] 153 | # same_labels = list(set(tree_labels1) & set(tree_labels2)) 154 | # if not same_labels: 155 | # if args.debug: 156 | # print('no same label') 157 | # return None 158 | # select_label=random.choice(same_labels) 159 | # candidate1 = random.choice( 160 | # [t for t in candidate_subtree1 if t.label() == select_label]) 161 | # candidate2 = random.choice( 162 | # [t for t in candidate_subtree2 if t.label() == select_label]) 163 | candidate_subtree1 = list(t1.subtrees()) 164 | candidate_subtree2 = list(t2.subtrees()) 165 | candidate1 = random.choice( 166 | [t for t in candidate_subtree1]) 167 | candidate2 = random.choice( 168 | [t for t in candidate_subtree2]) 169 | exchanged_span = ' '.join(candidate1.leaves()) 170 | exchanging_span = ' '.join(candidate2.leaves()) 171 | original_sentence = ' '.join(t1.leaves()) 172 | new_sentence = original_sentence.replace(exchanged_span, exchanging_span) 173 | debug=0 174 | if args.debug: 175 | print('check6') 176 | print(new_sentence) 177 | debug=1 178 | modified_sentence=modify(new_sentence,debug) 179 | new_label=c2a(modified_sentence,debug) 180 | if args.showinfo: 181 | print('cand1:', ' '.join(candidate1.leaves()), 182 | 'cand2:', ' '.join(candidate2.leaves())) 183 | # print([' '.join(c.leaves()) for c in cand1]) 184 | # print([' '.join(c.leaves()) for c in cand2]) 185 | print('src1:', parsing1) 186 | print('src2:', parsing2) 187 | print('new:',new_sentence) 188 | return modified_sentence,new_label 189 | except Exception as e: 190 | if args.debug: 191 | print('Error!!') 192 | print(e) 193 | return None 194 | def subtree_exchange_single(args,parsing1,label1,parsing2,label2,lam1,lam2): 195 | """ 196 | For a pair sentence, exchange subtree and return a label based on subtree length 197 | 198 | Find the candidate subtree, and extract correspoding span, and exchange span 199 | 200 | """ 201 | if args.debug: 202 | print('check4') 203 | assert lam1>lam2 204 | t1=Tree.fromstring(parsing1) 205 | original_sentence=' '.join(t1.leaves()) 206 | t1_len=len(t1.leaves()) 207 | candidate_subtree1=list(t1.subtrees(lambda t: lam1>len(t.leaves())/t1_len>lam2)) 208 | t2=Tree.fromstring(parsing2) 209 | candidate_subtree2=list(t2.subtrees(lambda t: lam1>len(t.leaves())/t1_len>lam2)) 210 | if args.debug: 211 | print('check5') 212 | # print('subtree1:',len(candidate_subtree1),'\nsubtree2:',len(candidate_subtree2)) 213 | if len(candidate_subtree1)==0 or len(candidate_subtree2)==0: 214 | 215 | return None 216 | if args.debug: 217 | print('check6') 218 | if args.phrase_label: 219 | if args.debug: 220 | print('phrase_label') 221 | tree_labels1=[tree.label() for tree in candidate_subtree1] 222 | tree_labels2=[tree.label() for tree in candidate_subtree2] 223 | same_labels=list(set(tree_labels1)&set(tree_labels2)) 224 | if not same_labels: 225 | # print('无相同类型的子树') 226 | return None 227 | if args.phrase_length: 228 | if args.debug: 229 | print('phrase_lable_length') 230 | candidate=[(t1,t2) for t1 in candidate_subtree1 for t2 in candidate_subtree2 if len(t1.leaves())==len(t2.leaves()) and t1.label()==t2.label()] 231 | candidate1,candidate2= random.choice(candidate) 232 | else: 233 | if args.debug: 234 | print('phrase_lable') 235 | select_label=random.choice(same_labels) 236 | candidate1=random.choice([t for t in candidate_subtree1 if t.label()==select_label]) 237 | candidate2=random.choice([t for t in candidate_subtree2 if t.label()==select_label]) 238 | else: 239 | if args.debug: 240 | print('no phrase_label') 241 | if args.phrase_length: 242 | if args.debug: 243 | print('phrase_length') 244 | candidate=[(t1,t2) for t1 in candidate_subtree1 for t2 in candidate_subtree2 if len(t1.leaves())==len(t2.leaves())] 245 | candidate1,candidate2= random.choice(candidate) 246 | else: 247 | if args.debug: 248 | print('normal TreeMix') 249 | candidate1=random.choice(candidate_subtree1) 250 | candidate2=random.choice(candidate_subtree2) 251 | 252 | exchanged_span=' '.join(candidate1.leaves()) 253 | exchanged_len=len(candidate1.leaves()) 254 | exchanging_span=' '.join(candidate2.leaves()) 255 | new_sentence=original_sentence.replace(exchanged_span,exchanging_span) 256 | # if args.mixup_cross: 257 | new_label=np.zeros(len(args.label_list)) 258 | 259 | exchanging_len=len(candidate2.leaves()) 260 | new_len=t1_len-exchanged_len+exchanging_len 261 | 262 | new_label[int(label2)]+=exchanging_len/new_len 263 | new_label[int(label1)]+=(new_len-exchanging_len)/new_len 264 | 265 | # else: 266 | # new_label=label1 267 | if args.showinfo: 268 | # print('树1 {}'.format(t1)) 269 | # print('树2 {}'.format(t2)) 270 | print('-'*50) 271 | print('candidate1:{}'.format([' '.join(x.leaves()) for x in candidate_subtree1])) 272 | print('candidate2:{}'.format([' '.join(x.leaves()) for x in candidate_subtree2])) 273 | print('sentence1 ## {} [{}]\nsentence2 ## {} [{}]'.format(original_sentence,label1,' '.join(t2.leaves()),label2)) 274 | print('{} <=========== {}'.format(exchanged_span,exchanging_span)) 275 | print('new sentence: ## {}'.format(new_sentence)) 276 | print('new label:[{}]'.format(new_label)) 277 | return new_sentence,new_label 278 | def subtree_exchange_pair(args,parsing11,parsing12,label1,parsing21,parsing22,label2,lam1,lam2): 279 | """ 280 | For a pair sentence, exchange subtree and return a label based on subtree length 281 | 282 | Find the candidate subtree, and extract correspoding span, and exchange span 283 | 284 | """ 285 | assert lam1>lam2 286 | lam2=lam1-0.2 287 | t11=Tree.fromstring(parsing11) 288 | t12=Tree.fromstring(parsing12) 289 | original_sentence1=' '.join(t11.leaves()) 290 | t11_len=len(t11.leaves()) 291 | original_sentence2=' '.join(t12.leaves()) 292 | t12_len=len(t12.leaves()) 293 | candidate_subtree11=list(t11.subtrees(lambda t: lam1>len(t.leaves())/t11_len>lam2)) 294 | candidate_subtree12=list(t12.subtrees(lambda t: lam1>len(t.leaves())/t12_len>lam2)) 295 | t21=Tree.fromstring(parsing21) 296 | t22=Tree.fromstring(parsing22) 297 | t21_len=len(t21.leaves()) 298 | t22_len=len(t22.leaves()) 299 | candidate_subtree21=list(t21.subtrees(lambda t: lam1>len(t.leaves())/t11_len>lam2)) 300 | candidate_subtree22=list(t22.subtrees(lambda t: lam1>len(t.leaves())/t12_len>lam2)) 301 | if args.showinfo: 302 | print('\n') 303 | print('*'*50) 304 | print('t11_len:{}\tt12_len:{}\tt21_len:{}\tt22_len:{}\ncandidate_subtree11:{}\ncandidate_subtree12:{}\ncandidate_subtree21:{}\ncandidate_subtree21:{}' 305 | .format(t11_len,t12_len,t21_len,t22_len,candidate_subtree11,candidate_subtree12,candidate_subtree21,candidate_subtree22)) 306 | 307 | # print('subtree1:',len(candidate_subtree1),'\nsubtree2:',len(candidate_subtree2)) 308 | if len(candidate_subtree11)==0 or len(candidate_subtree12)==0 or len(candidate_subtree21)==0 or len(candidate_subtree22)==0: 309 | # print("this pair fail",len(candidate_subtree1),len(candidate_subtree2)) 310 | return None 311 | 312 | if args.phrase_label: 313 | tree_labels11=[tree.label() for tree in candidate_subtree11] 314 | tree_labels12=[tree.label() for tree in candidate_subtree12] 315 | tree_labels21=[tree.label() for tree in candidate_subtree21] 316 | tree_labels22=[tree.label() for tree in candidate_subtree22] 317 | same_labels1=list(set(tree_labels11)&set(tree_labels21)) 318 | same_labels2=list(set(tree_labels12)&set(tree_labels22)) 319 | if not (same_labels1 and same_labels2) : 320 | # print('无相同类型的子树') 321 | return None 322 | select_label1=random.choice(same_labels1) 323 | select_label2=random.choice(same_labels2) 324 | displaced1=random.choice([t for t in candidate_subtree11 if t.label()==select_label1]) 325 | displacing1=random.choice([t for t in candidate_subtree21 if t.label()==select_label1]) 326 | displaced2=random.choice([t for t in candidate_subtree12 if t.label()==select_label2]) 327 | displacing2=random.choice([t for t in candidate_subtree22 if t.label()==select_label2]) 328 | else: 329 | displaced1=random.choice(candidate_subtree11) 330 | displacing1=random.choice(candidate_subtree21) 331 | displaced2=random.choice(candidate_subtree12) 332 | displacing2=random.choice(candidate_subtree22) 333 | 334 | 335 | displaced_span1=' '.join(displaced1.leaves()) 336 | displaced_len1=len(displaced1.leaves()) 337 | displacing_span1=' '.join(displacing1.leaves()) 338 | new_sentence1=original_sentence1.replace(displaced_span1,displacing_span1) 339 | 340 | displaced_span2=' '.join(displaced2.leaves()) 341 | displaced_len2=len(displaced2.leaves()) 342 | displacing_span2=' '.join(displacing2.leaves()) 343 | new_sentence2=original_sentence2.replace(displaced_span2,displacing_span2) 344 | 345 | # if args.mixup_cross: 346 | new_label=np.zeros(len(args.label_list)) 347 | displacing_len1=len(displacing1.leaves()) 348 | displacing_len2=len(displacing2.leaves()) 349 | new_len=t11_len+t12_len-displaced_len1-displaced_len2+displacing_len1+displacing_len2 350 | displacing_len=displacing_len1+displacing_len2 351 | new_label[int(label2)]+=displacing_len/new_len 352 | new_label[int(label1)]+=(new_len-displacing_len)/new_len 353 | 354 | 355 | if args.showinfo: 356 | print('Before\nsentence1:{}\nsentence2:{}\nlabel1:{}\tlabel2:{}'.format(original_sentence1,original_sentence2,label1,label2)) 357 | print('replaced1:{} replacing1:{}\nreplaced2:{} replacing2:{}'.format(displaced_span1,displacing_span1,displaced_span2,displacing2)) 358 | print('After\nsentence1:{}\nsentence2:{}\nnew_label:{}'.format(new_sentence1,new_sentence2,new_label)) 359 | print('*'*50) 360 | 361 | # print('被替换的span:{}\n替换的span:{}'.format(exchanged_span,exchanging_span)) 362 | return new_sentence1,new_sentence2,new_label 363 | def augmentation(args,data,seed,dataset,aug_times,lam1=0.1,lam2=0.3): 364 | """ 365 | generate aug_num augmentation dataset 366 | input: 367 | dataset --- pd.dataframe 368 | output: 369 | aug_dataset --- pd.dataframe 370 | """ 371 | generated_list=[] 372 | # print('check2') 373 | if args.debug: 374 | print('check3') 375 | shuffled_dataset=dataset.shuffle() 376 | success=0 377 | total=0 378 | with tqdm(total=int(aug_times)*len(dataset)) as bar: 379 | while success < int(aug_times)*len(dataset): 380 | # for idx in range(len(dataset)): 381 | idx = total % len(dataset) 382 | if args.fraction: 383 | bar.set_description('| Dataset:{:<5} | seed:{} | times:{} | fraction:{} |'.format(data,seed,aug_times,args.fraction)) 384 | else: 385 | bar.set_description('| Dataset:{:<5} | seed:{} | times:{} | '.format(data,seed,aug_times)) 386 | 387 | if args.data_type=='single_cls': 388 | if args.debug: 389 | print('check4') 390 | if 'None' not in [dataset[idx]['parsing1'], shuffled_dataset[idx]['parsing1']]: 391 | aug_sample=subtree_exchange_single( 392 | args,dataset[idx]['parsing1'],dataset[idx][args.label_name], 393 | shuffled_dataset[idx]['parsing1'],shuffled_dataset[idx][args.label_name], 394 | lam1,lam2) 395 | else: 396 | continue 397 | 398 | elif args.data_type=='pair_cls': 399 | # print('check4:pair') 400 | if args.debug: 401 | print('check4') 402 | if 'None' not in [dataset[idx]['parsing1'], dataset[idx]['parsing2'], dataset[idx][args.label_name], 403 | shuffled_dataset[idx]['parsing1'], shuffled_dataset[idx]['parsing2']]: 404 | aug_sample=subtree_exchange_pair( 405 | args,dataset[idx]['parsing1'],dataset[idx]['parsing2'],dataset[idx][args.label_name], 406 | shuffled_dataset[idx]['parsing1'],shuffled_dataset[idx]['parsing2'],shuffled_dataset[idx][args.label_name], 407 | lam1,lam2) 408 | else: 409 | continue 410 | 411 | elif args.data_type=='semantic_parsing': 412 | if args.debug: 413 | print('check4') 414 | aug_sample=subtree_exchange_scan( 415 | args,dataset[idx]['parsing1'], 416 | shuffled_dataset[idx]['parsing1']) 417 | 418 | if args.debug: 419 | print('ok') 420 | print('got one aug_sample : {}'.format(aug_sample)) 421 | if aug_sample: 422 | bar.update(1) 423 | success+=1 424 | generated_list.append(aug_sample) 425 | else: 426 | if args.debug: 427 | print('fail this time') 428 | total+=1 429 | #De-duplication 430 | # generated_list=list(set(generated_list)) 431 | return generated_list 432 | def parse_argument(): 433 | parser=argparse.ArgumentParser() 434 | parser.add_argument('--lam1',type=float,default=0.3) 435 | parser.add_argument('--lam2',type=float,default=0.1) 436 | parser.add_argument('--times',default=[2,5],nargs='+',help='augmentation times list') 437 | parser.add_argument('--min_token',type=int,default=0,help='minimum token numbers of augmentation samples') 438 | parser.add_argument('--label_name',type=str,default='label') 439 | parser.add_argument('--phrase_label',action='store_true',help='subtree lable must be same if set') 440 | parser.add_argument('--phrase_length',action='store_true',help='subtree phrase must be same length if set') 441 | # parser.add_argument('--data_type',type=str,required=True,help='This is a single classification task or pair sentences classification task') 442 | parser.add_argument('--seeds',default=[0,1,2,3,4],nargs='+',help='seed list') 443 | parser.add_argument('--showinfo',action='store_true') 444 | parser.add_argument('--mixup_cross',action='store_false',help="NO mix across different classes if set") 445 | parser.add_argument('--low_resource',action='store_true',help="create low source raw and aug datasets if set") 446 | parser.add_argument('--debug',action='store_true',help="display debug information") 447 | parser.add_argument('--data',nargs='+',required=True,help='data list') 448 | parser.add_argument('--proc',type=int,help='processing number for multiprocessing') 449 | args=parser.parse_args() 450 | if not args.proc: 451 | args.proc=cpu_count() 452 | return args 453 | def create_aug_data(args,dataset,data,seed,times,test_dataset=None): 454 | 455 | if args.phrase_label and not args.phrase_length: 456 | prefix_save_path=os.path.join(args.output_dir,'samephraselabel_times{}_min{}_seed{}_{}_{}'.format(times,args.min_token,seed,args.lam1,args.lam2)) 457 | elif args.phrase_length and not args.phrase_label: 458 | prefix_save_path=os.path.join(args.output_dir,'samephraselength_times{}_min{}_seed{}_{}_{}'.format(times,args.min_token,seed,args.lam1,args.lam2)) 459 | elif args.phrase_length and args.phrase_label: 460 | prefix_save_path=os.path.join(args.output_dir,'samephraselabel_length_times{}_min{}_seed{}_{}_{}'.format(times,args.min_token,seed,args.lam1,args.lam2)) 461 | elif not args.mixup_cross: 462 | prefix_save_path=os.path.join(args.output_dir,'sameclass_times{}_min{}_seed{}_{}_{}'.format(times,args.min_token,seed,args.lam1,args.lam2)) 463 | elif args.data_type == 'semantic_parsing': 464 | prefix_save_path = os.path.join(args.output_dir, 'scan_times{}_seed{}'.format( 465 | times, seed)) 466 | else: 467 | prefix_save_path=os.path.join(args.output_dir,'times{}_min{}_seed{}_{}_{}'.format(times,args.min_token,seed,args.lam1,args.lam2)) 468 | if args.debug: 469 | print('check1') 470 | if not [file_name for file_name in os.listdir(args.output_dir) if file_name.startswith(prefix_save_path)]: 471 | if args.min_token: 472 | dataset=dataset.filter(lambda sample: len(sample[tasksettings.task_to_keys[data][0]].split(' '))>args.min_token) 473 | if tasksettings.task_to_keys[data][1]: 474 | dataset=dataset.filter(lambda sample: len(sample[tasksettings.task_to_keys[data][1]].split(' '))>args.min_token) 475 | if args.data_type=='single_cls': 476 | if args.debug: 477 | print('check2') 478 | if args.mixup_cross: 479 | new_pd=pd.DataFrame(augmentation(args,data,seed,dataset,times,args.lam1,args.lam2),columns=[tasksettings.task_to_keys[data][0],args.label_name]) 480 | else: 481 | if args.debug: 482 | print('label_list',args.label_list) 483 | new_pd=None 484 | for i in args.label_list: 485 | samples=dataset.filter(lambda sample:sample[args.label_name]==i) 486 | dataframe=pd.DataFrame(augmentation(args,data,seed,samples,times,args.lam1,args.lam2),columns=[tasksettings.task_to_keys[data][0],args.label_name]) 487 | new_pd=pd.concat([new_pd,dataframe],axis=0) 488 | elif args.data_type=='pair_cls': 489 | if args.debug: 490 | print('check2') 491 | if args.mixup_cross: 492 | # print('check1') 493 | # print(args, seed, dataset, times,tasksettings.task_to_keys[data][0], tasksettings.task_to_keys[data][1], args.label_name) 494 | new_pd=pd.DataFrame(augmentation(args,data,seed,dataset,times,args.lam1,args.lam2),columns=[tasksettings.task_to_keys[data][0],tasksettings.task_to_keys[data][1],args.label_name]) 495 | else: 496 | new_pd=None 497 | if args.debug: 498 | print('label_list',args.label_list) 499 | for i in args.label_list: 500 | samples=dataset.filter(lambda sample:sample[args.label_name]==i) 501 | dataframe=pd.DataFrame(augmentation(args,data,seed,samples,times,args.lam1,args.lam2),columns=[tasksettings.task_to_keys[data][0],tasksettings.task_to_keys[data][1],args.label_name]) 502 | new_pd=pd.concat([new_pd,dataframe],axis=0) 503 | elif args.data_type=='semantic_parsing': 504 | if args.debug: 505 | print('check2') 506 | new_pd=pd.DataFrame(augmentation(args,data,seed,dataset,times),columns=[tasksettings.task_to_keys[data][0],args.label_name]) 507 | 508 | 509 | new_pd=new_pd.sample(frac=1) 510 | 511 | 512 | 513 | if args.data_type == 'semantic_parsing': 514 | 515 | train_pd=pd.read_csv('DATA/ADDPRIM_JUMP/data/train.csv') 516 | frames = [train_pd,new_pd] 517 | aug_dataset=pd.concat(frames,ignore_index=True) 518 | else: 519 | aug_dataset = Dataset.from_pandas(new_pd) 520 | aug_dataset = aug_dataset.remove_columns("__index_level_0__") 521 | 522 | if args.phrase_label: 523 | save_path = os.path.join(args.output_dir, 'samephraselabel_times{}_min{}_seed{}_{}_{}_{}k'.format( 524 | times, args.min_token, seed, args.lam1, args.lam2, round(len(new_pd)//1000,-1))) 525 | elif args.phrase_length and not args.phrase_label: 526 | save_path=os.path.join(args.output_dir,'samephraselength_times{}_min{}_seed{}_{}_{}_{}k'.format(times,args.min_token,seed,args.lam1,args.lam2,round(len(new_pd)//1000,-1))) 527 | elif args.phrase_length and args.phrase_label: 528 | save_path=os.path.join(args.output_dir,'samephraselabel_length_times{}_min{}_seed{}_{}_{}_{}k'.format(times,args.min_token,seed,args.lam1,args.lam2,round(len(new_pd)//1000,-1))) 529 | elif not args.mixup_cross: 530 | save_path=os.path.join(args.output_dir,'sameclass_times{}_min{}_seed{}_{}_{}_{}k'.format(times,args.min_token,seed,args.lam1,args.lam2,round(len(new_pd)//1000,-1))) 531 | elif args.data_type=='semantic_parsing': 532 | save_path_train = os.path.join(prefix_save_path, 'train.csv') 533 | save_path_test = os.path.join(prefix_save_path, 'test.csv') 534 | else: 535 | save_path=os.path.join(args.output_dir,'times{}_min{}_seed{}_{}_{}_{}k'.format(times,args.min_token,seed,args.lam1,args.lam2,round(len(new_pd)//1000,-1))) 536 | if args.data_type == 'semantic_parsing': 537 | 538 | 539 | if not os.path.exists(prefix_save_path): 540 | os.makedirs(prefix_save_path) 541 | 542 | aug_dataset.to_csv(save_path_train,index=0) 543 | test_dataset.to_csv(save_path_test,index=0) 544 | else: 545 | aug_dataset.save_to_disk(save_path) 546 | else: 547 | print('file {} already exsits!'.format(prefix_save_path)) 548 | 549 | 550 | def main(): 551 | p=Pool(args.proc) 552 | for data in args.data: 553 | path_dir=os.path.join('DATA',data.upper()) 554 | if data in tasksettings.pair_datasets: 555 | args.data_type='pair_cls' 556 | elif data in tasksettings.SCAN: 557 | args.label_name='actions' 558 | args.data_type='semantic_parsing' 559 | testset_path=os.path.join(path_dir,'data','test.csv') 560 | else: 561 | args.data_type='single_cls' 562 | if data=='trec': 563 | try: 564 | assert args.label_name in ['label-fine', 'label-coarse'] 565 | except AssertionError: 566 | raise(AssertionError( 567 | "If you want to train on trec dataset with augmentation, you have to name the label of split in ['label-fine', 'label-coarse']")) 568 | 569 | print(args.label_name,data) 570 | args.output_dir=os.path.join(path_dir,'generated/{}'.format(args.label_name)) 571 | else: 572 | args.output_dir=os.path.join(path_dir,'generated') 573 | args.data_path=os.path.join(path_dir,'data','train_parsing.csv') 574 | 575 | dataset=load_dataset('csv',data_files=[args.data_path],split='train') 576 | if args.data_type=='semantic_parsing': 577 | testset=load_dataset('csv',data_files=[testset_path],split='train') 578 | if args.data_type in ['single_cls','pair_cls']: 579 | args.label_list=list(set(dataset[args.label_name])) #根据data做一个表查找所有的label 580 | for seed in args.seeds: 581 | seed=int(seed) 582 | set_seed(seed) 583 | dataset=dataset.shuffle() 584 | if args.low_resource: 585 | for fraction in tasksettings.low_resource[data]: 586 | args.fraction=fraction 587 | train_dataset=dataset.select(random.sample(range(len(dataset)),int(fraction*len(dataset)))) 588 | low_resource_dir=os.path.join(path_dir,'low_resource','low_resource_{}'.format(fraction),'seed_{}'.format(seed)) 589 | if not os.path.exists(low_resource_dir): 590 | os.makedirs(low_resource_dir) 591 | args.output_dir=low_resource_dir 592 | train_path=os.path.join(args.output_dir,'partial_train') 593 | if not os.path.exists(train_path): 594 | train_dataset.save_to_disk(train_path) 595 | for times in args.times: 596 | times=int(times) 597 | p.apply_async(create_aug_data, args=( 598 | args, train_dataset, data, seed, times)) 599 | else: 600 | args.fraction=None 601 | for times in args.times: 602 | times=int(times) 603 | p.apply_async(create_aug_data, args=( 604 | args, dataset, data, seed, times,testset)) 605 | print('='*20,'Start generating augmentation datsets !',"="*20) 606 | # p.close() 607 | # p.join() 608 | 609 | p.close() 610 | p.join() 611 | print('='*20, 'Augmenatation done !', "="*20) 612 | if __name__=='__main__': 613 | 614 | tasksettings=settings.TaskSettings() 615 | args=parse_argument() 616 | print(args) 617 | main() 618 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Le Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TreeMix: Compositional Constituency-based Data Augmentation for Natural Language Understanding (NAACL 2022) 2 | 3 | Pytorch Implementation of [TreeMix](https://arxiv.org/abs/2205.06153) 4 | 5 | ![treemix_example](https://i.loli.net/2021/11/30/EhX9iZax3s6jpTC.jpg) 6 | 7 | ## Abstract 8 | 9 | Data augmentation is an effective approach to tackle over-fitting. Many previous works have proposed different data augmentations strategies for NLP, such as noise injection, word replacement, back-translation etc. Though effective, they missed one important characteristic of language–compositionality, meaning of a complex expression is built from its subparts. Motivated by this, we propose a compositional data augmentation approach for natural language understanding called TreeMix. Specifically, TreeMix leverages constituency parsing tree to decompose sentences into constituent sub-structures and the Mixup data augmentation technique to recombine them to generate new sentences. Compared with previous approaches, TreeMix introduces greater diversity to the samples generated and encourages models to learn compositionality of NLP data. Extensive experiments on text classification and semantic parsing benchmarks demonstrate that TreeMix outperforms current stateof-the-art data augmentation methods. 10 | 11 | ## Code Structure 12 | 13 | ``` 14 | |__ DATA 15 | |__ SST2 16 | |__ data 17 | |__ train.csv --> raw train dataset 18 | |__ test.csv --> raw test dataset 19 | |__ train_parsing --> consituency parsing results 20 | |__ generated 21 | |__ times2_min0_seed0_0.3_0.1_7k --> augmentation dataset(hugging face dataset format) with 2 times bigger, seed 0, lambda_L=0.1, lambda_U=0.3, total size=7k 22 | |__ logs --> best results log 23 | |__ runs --> tensorboard results 24 | |__ aug --> augmentation only baseline 25 | |__ raw --> standard baseline 26 | |__ raw_aug --> TreeMix results 27 | |__ checkpoints 28 | |__ best.pt --> Best model checkpoints 29 | |__ process_data / --> Download & Semantic parsing using Stanfordcorenlp tooktiks 30 | |__ Load_data.py --> Loading raw dataset and augmentation dataset 31 | |__ get_data.py --> Download all dataset from huggingface dataset and perform constituency parsing to obtain processed dataset 32 | |__ settings.py --> Hyperparameter settings & Task settings 33 | |__ online_augmentation 34 | |__ __init__.py --> Random Mixup 35 | |__ Augmentation.py --> Subtree Exchange augmentation method based on consituency parsing results for all dataset(single sentence classification, sentence relation classification, SCAN dataset) 36 | |__ run.py --> Train for one dataset 37 | |__ batch_train.py -> Train with different datasets and different settings, specified by giving specific arguments 38 | ``` 39 | 40 | ### Getting Started 41 | 42 | ``` 43 | pip install -r requirements.txt 44 | ``` 45 | 46 | Note that to successfully run TreeMix, you must install `stanfordcorenlp`. Please refer to this 47 | 48 | [corenlp]( https://stanfordnlp.github.io/CoreNLP/download.html "stanfordcorenlp") for more information. 49 | 50 | ### Download & Constituency Parsing 51 | 52 | ``` 53 | cd process_data 54 | python get_data.py --data {DATA} --corenlp_dir {CORENLP} 55 | ``` 56 | 57 | `DATA` indicates the dataset name, `CORENLP` indicates the directory of `stanfordcorenlp` . After this process, you could get corresponding data folder in `DATA/` and `train_parsing.csv`. 58 | 59 | ### TreeMix Augmentation 60 | 61 | ``` 62 | python Augmentation.py --data {DATASET} --times {TIMES} --lam1 {LAM1} --lam2 {LAM2} --seeds {SEEDS} 63 | ``` 64 | 65 | Augmentation with different arguments, will generate #(TIMES)×#(SEEDS) extra dataset. `DATASET` could be **list of data name** such as 'sst2 rte', since 'trec' has two versions, you need to input `--label_name {}` to specify whether the trec-fine or trec-coarse set. Besides, by typing `--low_resource` , it will generated partial augmentation dataset as well as partial train set. You can modify the hyperparameter `lambda_U` and `lambda_L` by changing `LAM1` and `LAM2` . `TIMES` could be **a list of intergers** such as 2,5 to assign the size of the augmentation dataset. 66 | 67 | ### Model Training 68 | 69 | ``` 70 | python batch_train.py --mode {MODE} --data {DATASET} 71 | ``` 72 | 73 | Evaluation one dataset for all its augmenation set with specific mode. `MODE` can be 'raw', 'aug', 'raw_aug', which indicates train the model with raw dataset only, augmentation dataset only and combination of raw and augmentation set respectively. `DATASET` should be **one specific dataset** name. This will report all results of a specifc dataset in `/log`. If not specified, the hyperparameters will be set as in `/process_data/settings.py`, please look into this file for more arguments information. 74 | -------------------------------------------------------------------------------- /batch_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from process_data.settings import TaskSettings 4 | def parse_argument(): 5 | parser = argparse.ArgumentParser(description='download and parsing datasets') 6 | parser.add_argument('--data',type=str,required=True,help='data list') 7 | parser.add_argument('--aug_dir',help='Augmentation file directory') 8 | parser.add_argument('--seeds',default=[0,1,2,3,4],nargs='+',help='seed list') 9 | parser.add_argument('--modes',nargs='+',required=True,help='seed list') 10 | parser.add_argument('--label_name',type=str,default='label') 11 | # parser.add_argument('--batch_size',default=128,type=int,help='train examples in each batch') 12 | # parser.add_argument('--aug_batch_size',default=128,type=int,help='train examples in each batch') 13 | parser.add_argument('--random_mix',type=str,choices=['zero_one','zero','one','all'],help="random mixup ") 14 | parser.add_argument('--prefix',type=str,help="only choosing the datasets with the prefix,for ablation study") 15 | parser.add_argument('--GPU',type=int,default=0,help="available GPU number") 16 | parser.add_argument('--low_resource', action='store_true', 17 | help='whther to train low resource dataset') 18 | 19 | args=parser.parse_args() 20 | if args.data=='trec': 21 | try: 22 | assert args.label_name in ['label-fine','label-coarse'] 23 | except AssertionError: 24 | raise( AssertionError("If you want to train on TREC dataset with augmentation, you have to name the label of split either 'label-fine' or 'label-coarse'")) 25 | args.aug_dir = os.path.join('DATA', args.data.upper(), 'generated',args.label_name) 26 | if args.aug_dir is None : 27 | args.aug_dir=os.path.join('DATA',args.data.upper(),'generated') 28 | 29 | if 'aug' in args.modes: 30 | try: 31 | assert [file for file in os.listdir(args.aug_dir) if 'times' in file] 32 | except AssertionError: 33 | raise( AssertionError( "{}".format('This directory has no augmentation file, please input correct aug_dir!') ) ) 34 | if args.low_resource: 35 | try: 36 | args.low_resource = os.path.join('DATA', args.data.upper(),'low_resource') 37 | assert os.path.exists(args.low_resource) 38 | except AssertionError: 39 | raise( AssertionError("There is no any low resource datasets in this data")) 40 | 41 | return args 42 | def batch_train(args): 43 | for seed in args.seeds: 44 | # for aug_file in os.listdir(args.aug_dir): 45 | for mode in args.modes: 46 | if mode=='raw': 47 | # data_path=os.path.join(args.aug_dir,aug_file) 48 | if args.random_mix: 49 | os.system('CUDA_VISIBLE_DEVICES={} python run.py --label_name {} --mode {} --seed {} --data {} --random_mix {} --epoch {epoch} --batch_size {batch_size} --aug_batch_size {aug_batch_size} --val_steps {val_steps} --max_length {max_length} --augweight {augweight} '.format(args.GPU,args.label_name,mode,int(seed),args.data,args.random_mix,**settings[args.data])) 50 | else: 51 | os.system('CUDA_VISIBLE_DEVICES={} python run.py --label_name {} --mode {} --seed {} --data {} --epoch {epoch} --batch_size {batch_size} --aug_batch_size {aug_batch_size} --val_steps {val_steps} --max_length {max_length} --augweight {augweight} '.format(args.GPU,args.label_name,mode,int(seed),args.data,**settings[args.data])) 52 | else: 53 | for aug_file in os.listdir(args.aug_dir): 54 | if args.prefix: 55 | # only train on file with prefix 56 | if aug_file.startswith(args.prefix): 57 | aug_file_path = os.path.join( 58 | args.aug_dir, aug_file) 59 | assert os.path.exists(aug_file_path) 60 | os.system('CUDA_VISIBLE_DEVICES={} python run.py --label_name {} --mode {} --seed {} --data {} --data_path {} --epoch {epoch} --batch_size {batch_size} --aug_batch_size {aug_batch_size} --val_steps {val_steps} --max_length {max_length} --augweight {augweight} '.format( 61 | args.GPU, args.label_name, mode, int(seed), args.data, aug_file_path, **settings[args.data])) 62 | else: 63 | # train on every file in dir 64 | aug_file_path = os.path.join( 65 | args.aug_dir, aug_file) 66 | assert os.path.exists(aug_file_path) 67 | os.system('CUDA_VISIBLE_DEVICES={} python run.py --label_name {} --mode {} --seed {} --data {} --data_path {} --epoch {epoch} --batch_size {batch_size} --aug_batch_size {aug_batch_size} --val_steps {val_steps} --max_length {max_length} --augweight {augweight} '.format( 68 | args.GPU, args.label_name, mode, int(seed), args.data, aug_file_path, **settings[args.data])) 69 | def low_resource_train(args): 70 | for partial_split in os.listdir(args.low_resource): 71 | partial_split_path=os.path.join(args.low_resource,partial_split) 72 | args.output_dir = os.path.join( 73 | args.low_resource_dir, partial_split) 74 | if not os.path.exists(args.output_dir): 75 | os.makedirs(args.output_dir) 76 | for seed_num in os.listdir(partial_split_path): 77 | partial_split_seed_path=os.path.join(partial_split_path,seed_num) 78 | for mode in args.modes: 79 | if mode=='raw': 80 | if args.random_mix: 81 | os.system('CUDA_VISIBLE_DEVICES={} python run.py --low_resource_dir {} --seed {} --output_dir {} --label_name {} --mode {} --data {} --random_mix {} --epoch {epoch} --batch_size {batch_size} --aug_batch_size {aug_batch_size} --val_steps {val_steps} --max_length {max_length} --augweight {augweight} ' 82 | .format(args.GPU, partial_split_seed_path, int(seed_num.split('_')[1]), args.output_dir, args.label_name, mode, args.data, args.random_mix, **settings[args.data])) 83 | else: 84 | os.system('CUDA_VISIBLE_DEVICES={} python run.py --low_resource_dir {} --seed {} --output_dir {} --label_name {} --mode {} --data {} --epoch {epoch} --batch_size {batch_size} --aug_batch_size {aug_batch_size} --val_steps {val_steps} --max_length {max_length} --augweight {augweight} ' 85 | .format(args.GPU, partial_split_seed_path, int(seed_num.split('_')[1]),args.output_dir, args.label_name, mode, args.data, **settings[args.data])) 86 | elif mode=='raw_aug': 87 | for aug_file in [file for file in os.listdir(partial_split_seed_path) if file.startswith('times')]: 88 | aug_file_path=os.path.join(partial_split_seed_path,aug_file) 89 | assert os.path.exists(aug_file_path) 90 | os.system('CUDA_VISIBLE_DEVICES={} python run.py --low_resource_dir {} --seed {} --output_dir {} --label_name {} --mode {} --data {} --data_path {} --epoch {epoch} --batch_size {batch_size} --aug_batch_size {aug_batch_size} --val_steps {val_steps} --max_length {max_length} --augweight {augweight} '.format( 91 | args.GPU, partial_split_seed_path, int(seed_num.split('_')[1]) , args.output_dir, args.label_name, mode, args.data, aug_file_path, **settings[args.data])) 92 | if __name__=='__main__': 93 | args=parse_argument() 94 | tasksettings=TaskSettings() 95 | settings=tasksettings.train_settings 96 | if args.low_resource: 97 | args.low_resource_dir=os.path.join('DATA',args.data.upper(),'runs','low_resource') 98 | if not os.path.exists(args.low_resource_dir): 99 | os.makedirs(args.low_resource_dir) 100 | low_resource_train(args) 101 | else: 102 | batch_train(args) 103 | -------------------------------------------------------------------------------- /online_augmentation/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | def random_mixup_process(args,ids1,lam): 5 | 6 | rand_index=torch.randperm(ids1.shape[0]) 7 | lenlist=[] 8 | # rand_index=torch.randperm(len(ids1)) 9 | for x in ids1: 10 | mask=((x!=101)&(x!=0)&(x!=102)) 11 | lenlist.append(int(mask.sum())) 12 | lenlist2=torch.tensor(lenlist)[rand_index] 13 | spanlen=torch.tensor([int(x*lam) for x in lenlist]) 14 | 15 | beginlist=[1+random.randint(0,x-int(x*lam)) for x in lenlist] 16 | beginlist2=[1+random.randint(0,x-y) for x,y in zip(lenlist2,spanlen)] 17 | if args.difflen: 18 | 19 | spanlen2=torch.tensor([int(x*lam) for x in lenlist2]) 20 | # beginlist2=[1+random.randint(0,x-y) for x,y in zip(lenlist2,spanlen2)] 21 | spanlist2=[(x,int(y)) for x,y in zip(beginlist2,spanlen2)] 22 | else: 23 | # beginlist2=[1+random.randint(0,x-y) for x,y in zip(lenlist2,spanlen)] 24 | spanlist2=[(x,int(y)) for x,y in zip(beginlist2,spanlen)] 25 | spanlist=[(x,int(y)) for x,y in zip(beginlist,spanlen)] 26 | 27 | ids2=ids1.clone() 28 | if args.difflen: 29 | for idx in range(len(ids1)): 30 | tmp=torch.cat((ids1[idx][:spanlist[idx][0]],ids2[rand_index[idx]][spanlist2[idx][0]:spanlist2[idx][0]+spanlist2[idx][1]],ids1[idx][spanlist[idx][0]+spanlist[idx][1]:]),dim=0)[:ids1.shape[1]] 31 | ids1[idx]=torch.cat((tmp,torch.zeros(ids1.shape[1]-len(tmp)))) 32 | else: 33 | for idx in range(len(ids1)): 34 | ids1[idx][spanlist[idx][0]:spanlist[idx][0]+spanlist[idx][1]]=ids2[rand_index[idx]][spanlist2[idx][0]:spanlist2[idx][0]+spanlist2[idx][1]] 35 | assert ids1.shape==ids2.shape 36 | return ids1,rand_index 37 | def mixup_01(args,input_ids,lam,idx1,idx2): 38 | ''' 39 | 01交换 40 | ''' 41 | difflen=False 42 | random_index=torch.zeros(len(idx1)+len(idx2)).long() 43 | random_index[idx1]=torch.tensor(np.random.choice(idx2,size=len(idx1))) 44 | random_index[idx2]=torch.tensor(np.random.choice(idx1,size=len(idx2))) 45 | 46 | len_list1=[] 47 | len_list2=[] 48 | for input_id1 in input_ids: 49 | #计算各个句子的具体token数 50 | mask=((input_id1!=101)&(input_id1!=0)&(input_id1!=102)) 51 | len_list1.append(int(mask.sum())) 52 | # print(len_list1) 53 | len_list2=torch.tensor(len_list1)[random_index] 54 | 55 | spanlen=torch.tensor([int(x*lam) for x in len_list1]) 56 | beginlist=[1+random.randint(0,x-int(x*lam)) for x in len_list1] 57 | beginlist2=[1+random.randint(0,x-y) for x,y in zip(len_list2,spanlen)] 58 | if difflen: 59 | spanlen2=torch.tensor([int(x*lam) for x in len_list2]) 60 | # beginlist2=[1+random.randint(0,x-y) for x,y in zip(lenlist2,spanlen2)] 61 | spanlist2=[(x,int(y)) for x,y in zip(beginlist2,spanlen2)] 62 | else: 63 | # beginlist2=[1+random.randint(0,x-y) for x,y in zip(lenlist2,spanlen)] 64 | spanlist2=[(x,int(y)) for x,y in zip(beginlist2,spanlen)] 65 | spanlist=[(x,int(y)) for x,y in zip(beginlist,spanlen)] 66 | new_ids=input_ids.clone() 67 | 68 | # print(random_index) 69 | if difflen: 70 | for idx in range(len(ids1)): 71 | tmp=torch.cat((ids1[idx][:spanlist[idx][0]],ids2[rand_index[idx]][spanlist2[idx][0]:spanlist2[idx][0]+spanlist2[idx][1]],ids1[idx][spanlist[idx][0]+spanlist[idx][1]:]),dim=0)[:ids1.shape[1]] 72 | ids1[idx]=torch.cat((tmp,torch.zeros(ids1.shape[1]-len(tmp)))) 73 | else: 74 | for idx in range(len(input_ids)): 75 | new_ids[idx][spanlist[idx][0]:spanlist[idx][0]+spanlist[idx][1]]=input_ids[random_index[idx]][spanlist2[idx][0]:spanlist2[idx][0]+spanlist2[idx][1]] 76 | # for i in range(len(input_ids)): 77 | # print('{}:交换的是{}与{},其中第1句选取的是从{}开始到{}的句子,第2句选取的是从{}开始到{}结束的句子'.format( 78 | # i,i,random_index[i],spanlist[i][0],spanlist[i][0]+spanlist[i][1],spanlist2[i][0],spanlist2[i][0]+spanlist2[i][1])) 79 | return new_ids,random_index 80 | def mixup(args,input_ids,lam,idx1,idx2=None): 81 | ''' 82 | 只针对idx1索引对应的sample内部进行交换,如果idx2也给的话就是idx1 idx2进行交换 83 | ''' 84 | select_input_ids=torch.index_select(input_ids,0,idx1) 85 | rand_index=torch.randperm(select_input_ids.shape[0]) 86 | new_idx=torch.tensor(list(range(input_ids.shape[0]))) 87 | len_list1=[] 88 | len_list2=[] 89 | for input_id1 in select_input_ids: 90 | #calculte length of tokens in each sentence 91 | mask=((input_id1!=101)&(input_id1!=0)&(input_id1!=102)) 92 | len_list1.append(int(mask.sum())) 93 | len_list2=torch.tensor(len_list1)[rand_index] 94 | 95 | spanlen=torch.tensor([int(x*lam) for x in len_list1]) 96 | beginlist=[1+random.randint(0,x-y) for x,y in zip(len_list1,spanlen)] 97 | beginlist2=[1+random.randint(0,max(0,x-y)) for x,y in zip(len_list2,spanlen)] 98 | 99 | spanlist=[(x,int(y)) for x,y in zip(beginlist,spanlen)] 100 | spanlist2 = [(x, min(int(y),z)) for x, y, z in zip(beginlist2, spanlen, len_list2)] 101 | new_ids=input_ids.clone() 102 | new_idx[idx1]=idx1[rand_index] 103 | for i,idx in enumerate(idx1): 104 | new_ids[idx][spanlist[i][0]:spanlist[i][0]+spanlist[i][1]]=input_ids[idx1[rand_index[i]]][spanlist2[i][0]:spanlist2[i][0]+spanlist2[i][1]] 105 | 106 | return new_ids,new_idx 107 | def random_mixup(args,ids1,lab1,lam): 108 | """ 109 | function: random select span to exchange based on lam to decide span length and rand_index decide selected candidate exchange sentece 110 | input: 111 | ids1 -- tensors of tensors input_ids 112 | lab1 -- tensors of tensors labels 113 | lam -- span length rate 114 | output: 115 | ids1 -- tensors of tensors , exchanged span 116 | rand_index -- tensors , permutation index 117 | 118 | """ 119 | if args.random_mix=='all': 120 | return mixup(args,ids1,lam,torch.tensor(range(ids1.shape[0]))) 121 | else: 122 | pos_idx=(lab1==1).nonzero().squeeze() 123 | neg_idx=(lab1==0).nonzero().squeeze() 124 | pos_samples=torch.index_select(ids1,0,pos_idx) 125 | neg_samples=torch.index_select(ids1,0,neg_idx) 126 | if args.random_mix=='zero': 127 | return mixup(args,ids1,lam,neg_idx) 128 | if args.random_mix=='one': 129 | return mixup(args,ids1,lam,pos_idx) 130 | if args.random_mix=='zero_one': 131 | return mixup_01(args,ids1,lam,pos_idx,neg_idx) 132 | -------------------------------------------------------------------------------- /online_augmentation/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lezhang7/TreeMix/016b47df8f028d384aedeaf1605ce24aac48b9b7/online_augmentation/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /online_augmentation/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lezhang7/TreeMix/016b47df8f028d384aedeaf1605ce24aac48b9b7/online_augmentation/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /process_data/Load_data.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset, load_from_disk 2 | from transformers import BertTokenizer 3 | import torch 4 | import numpy as np 5 | import os 6 | from . import settings 7 | class DATA_process(object): 8 | def __init__(self, args=None): 9 | if args: 10 | print('Initializing with args') 11 | self.data = args.data if args.data else None 12 | self.task = args.task if args.task else None 13 | self.tokenizer = BertTokenizer.from_pretrained( 14 | args.model, do_lower_case=True) if args.model else None 15 | self.tasksettings = settings.TaskSettings() 16 | self.max_length = args.max_length if args.max_length else None 17 | self.label_name = args.label_name if args.label_name else None 18 | self.batch_size = args.batch_size if args.batch_size else None 19 | self.aug_batch_size=args.aug_batch_size if args.aug_batch_size else None 20 | self.min_train_token = args.min_train_token if args.min_train_token else None 21 | self.max_train_token = args.max_train_token if args.max_train_token else None 22 | self.num_proc = args.num_proc if args.num_proc else None 23 | self.low_resource_dir = args.low_resource_dir if args.low_resource_dir else None 24 | self.data_path = args.data_path if args.data_path else None 25 | self.random_mix = args.random_mix if args.random_mix else None 26 | 27 | def validation_data(self): 28 | validation_set = self.validationset( 29 | data=self.data) 30 | print('='*20,'multiprocess processing test dataset','='*20) 31 | # Process dataset to make dataloader 32 | if self.task == 'single': 33 | validation_set = validation_set.map( 34 | self.encode, batched=True, num_proc=self.num_proc) 35 | else: 36 | validation_set = validation_set.map( 37 | self.encode_pair, batched=True, num_proc=self.num_proc) 38 | # validation_set = validation_set.map(lambda examples: {'labels': examples[args.label_name]}, batched=True) 39 | validation_set = validation_set.rename_column( 40 | self.label_name, "labels") 41 | validation_set.set_format(type='torch', columns=[ 42 | 'input_ids', 'token_type_ids', 'attention_mask', 'labels']) 43 | 44 | val_dataloader = torch.utils.data.DataLoader( 45 | validation_set, batch_size=self.batch_size, shuffle=True) 46 | return val_dataloader 47 | def encode(self, examples): 48 | return self.tokenizer(examples[self.tasksettings.task_to_keys[self.data][0]], max_length=self.max_length, truncation=True, padding='max_length') 49 | def encode_pair(self, examples): 50 | return self.tokenizer(examples[self.tasksettings.task_to_keys[self.data][0]], examples[self.tasksettings.task_to_keys[self.data][1]], max_length=self.max_length, truncation=True, padding='max_length') 51 | 52 | def train_data(self, count_label=False): 53 | train_set, label_num = self.traindataset( 54 | data=self.data, low_resource_dir=self.low_resource_dir, label_num=count_label) 55 | print('='*20,'multiprocess processing train dataset','='*20) 56 | if self.task == 'single': 57 | train_set = train_set.map( 58 | self.encode, batched=True, num_proc=self.num_proc) 59 | else: 60 | train_set = train_set.map( 61 | self.encode_pair, batched=True, num_proc=self.num_proc) 62 | if self.random_mix: 63 | # sort the train dataset 64 | print('-'*20, 'random_mixup', '-'*20) 65 | train_set = train_set.map( 66 | lambda examples: {'token_num': np.sum(np.array(examples['attention_mask']))}) 67 | train_set = train_set.sort('token_num', reverse=True) 68 | # train_set = train_set.map(lambda examples: {'labels': examples[args.label_name]}, batched=True) 69 | train_set = train_set.rename_column(self.label_name, "labels") 70 | if self.min_train_token: 71 | print( 72 | '-'*20, 'filter sample whose sentence shorter than {}'.format(self.min_train_token), '-'*20) 73 | train_set = train_set.filter(lambda example: sum( 74 | example['attention_mask']) > self.min_train_token+2) 75 | if self.max_train_token: 76 | print( 77 | '-'*20, 'filter sample whose sentence longer than {}'.format(self.max_train_token), '-'*20) 78 | train_set = train_set.filter(lambda example: sum( 79 | example['attention_mask']) < self.max_train_token+2) 80 | train_set.set_format(type='torch', columns=[ 81 | 'input_ids', 'token_type_ids', 'attention_mask', 'labels']) 82 | 83 | train_dataloader = torch.utils.data.DataLoader( 84 | train_set, batch_size=self.batch_size, shuffle=True) 85 | if count_label: 86 | return train_dataloader, label_num 87 | else: 88 | return train_dataloader 89 | def augmentation_data(self): 90 | try: 91 | aug_dataset = load_dataset( 92 | 'csv', data_files=[self.data_path])['train'] 93 | except Exception as e: 94 | aug_dataset = load_from_disk(self.data_path) 95 | print('='*20, 'multiprocess processing aug dataset', '='*20) 96 | if self.task == 'single': 97 | aug_dataset = aug_dataset.map( 98 | self.encode, batched=True, num_proc=self.num_proc) 99 | else: 100 | aug_dataset = aug_dataset.map( 101 | self.encode_pair, batched=True, num_proc=self.num_proc) 102 | # if self.mix: 103 | # # label has more than one dimension 104 | # # aug_dataset = aug_dataset.map(lambda examples: {'labels':examples[self.label_name]},batched=True) 105 | # else: 106 | # # aug_dataset = aug_dataset.map(lambda examples: {'labels':int(examples[self.label_name])}) 107 | aug_dataset = aug_dataset.rename_column(self.label_name, 'labels') 108 | 109 | aug_dataset.set_format(type='torch', columns=[ 110 | 'input_ids', 'token_type_ids', 'attention_mask', 'labels']) 111 | aug_dataloader = torch.utils.data.DataLoader( 112 | aug_dataset, batch_size=self.aug_batch_size, shuffle=True) 113 | return aug_dataloader 114 | 115 | def validationset(self,data): 116 | if data in ['sst2', 'rte', 'mrpc', 'qqp', 'mnli', 'qnli']: 117 | if data == 'mnli': 118 | validation_set = load_dataset( 119 | 'glue', data, split='validation_mismatched') 120 | else: 121 | validation_set = load_dataset('glue', data, split='validation') 122 | print('-'*20, 'Test on glue@{}'.format(data), '-'*20) 123 | elif data in ['imdb', 'ag_news', 'trec']: 124 | validation_set = load_dataset(data, split='test') 125 | print('-'*20, 'Test on {}'.format(data), '-'*20) 126 | elif data == 'sst': 127 | validation_set = load_dataset(data, 'default', split='test') 128 | validation_set = validation_set.map(lambda example: {'label': int( 129 | example['label']*10//2)}, remove_columns=['tokens', 'tree'], num_proc=4) 130 | print('-'*20, 'Test on {}'.format(data), '-'*20) 131 | else: 132 | validation_set = load_dataset(data, split='validation') 133 | print('-'*20, 'Test on {}'.format(data), '-'*20) 134 | 135 | return validation_set 136 | 137 | def traindataset(self, data, low_resource_dir=None, split='train', label_num=False): 138 | if low_resource_dir: 139 | train_set = load_from_disk(os.path.join( 140 | low_resource_dir, 'partial_train')) 141 | else: 142 | if data in ['sst2', 'rte', 'mrpc', 'qqp', 'mnli', 'qnli']: 143 | train_set = load_dataset('glue', data, split=split) 144 | elif data == 'sst': 145 | train_set = load_dataset(data, 'default', split=split) 146 | train_set = train_set.map(lambda example: {'label': int( 147 | example['label']*10//2)}, remove_columns=['tokens', 'tree'], num_proc=4) 148 | else: 149 | train_set = load_dataset(data, split=split) 150 | if label_num: 151 | return train_set, len(set(train_set[self.label_name])) 152 | else: 153 | return train_set 154 | if __name__=="__main__": 155 | data_processor=DATA_process() 156 | valset=data_processor.validationset(data='ag_news') 157 | print(valset) 158 | -------------------------------------------------------------------------------- /process_data/__init__.py: -------------------------------------------------------------------------------- 1 | if __name__=='__main__': 2 | print('Using process_data package') -------------------------------------------------------------------------------- /process_data/__pycache__/Augmentation.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lezhang7/TreeMix/016b47df8f028d384aedeaf1605ce24aac48b9b7/process_data/__pycache__/Augmentation.cpython-39.pyc -------------------------------------------------------------------------------- /process_data/__pycache__/Load_data.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lezhang7/TreeMix/016b47df8f028d384aedeaf1605ce24aac48b9b7/process_data/__pycache__/Load_data.cpython-38.pyc -------------------------------------------------------------------------------- /process_data/__pycache__/Load_data.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lezhang7/TreeMix/016b47df8f028d384aedeaf1605ce24aac48b9b7/process_data/__pycache__/Load_data.cpython-39.pyc -------------------------------------------------------------------------------- /process_data/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lezhang7/TreeMix/016b47df8f028d384aedeaf1605ce24aac48b9b7/process_data/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /process_data/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lezhang7/TreeMix/016b47df8f028d384aedeaf1605ce24aac48b9b7/process_data/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /process_data/__pycache__/ceshi.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lezhang7/TreeMix/016b47df8f028d384aedeaf1605ce24aac48b9b7/process_data/__pycache__/ceshi.cpython-39.pyc -------------------------------------------------------------------------------- /process_data/__pycache__/settings.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lezhang7/TreeMix/016b47df8f028d384aedeaf1605ce24aac48b9b7/process_data/__pycache__/settings.cpython-38.pyc -------------------------------------------------------------------------------- /process_data/__pycache__/settings.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lezhang7/TreeMix/016b47df8f028d384aedeaf1605ce24aac48b9b7/process_data/__pycache__/settings.cpython-39.pyc -------------------------------------------------------------------------------- /process_data/get_data.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | import numpy as np 3 | import argparse 4 | import os 5 | import pandas as pd 6 | from stanfordcorenlp import StanfordCoreNLP 7 | from tqdm import tqdm 8 | import time 9 | import settings 10 | from multiprocessing import cpu_count 11 | from pandarallel import pandarallel 12 | def parse_argument(): 13 | parser = argparse.ArgumentParser(description='download and parsing datasets') 14 | parser.add_argument('--data',nargs='+',required=True,help='data list') 15 | parser.add_argument('--corenlp_dir',type=str,default='/remote-home/lzhang/stanford-corenlp-full-2018-10-05/') 16 | parser.add_argument('--proc',type=int,help='multiprocessing num') 17 | args=parser.parse_args() 18 | return args 19 | 20 | 21 | def parsing_stanfordnlp(raw_text): 22 | try: 23 | parsing = snlp.parse(raw_text) 24 | return parsing 25 | except Exception as e: 26 | return 'None' 27 | 28 | def constituency_parsing(args): 29 | if not args.proc: 30 | args.proc = cpu_count() 31 | pandarallel.initialize(nb_workers=args.proc, progress_bar=True) 32 | for dataset in args.data: 33 | DATA_dir=os.path.join(os.path.abspath(os.path.join(os.getcwd(), "..")),'DATA') 34 | path_dir=os.path.join(DATA_dir,dataset.upper()) 35 | output_path=os.path.join(path_dir,'data','train_parsing.csv') 36 | if os.path.exists(output_path): 37 | print('The data {} has already parsed!'.format(dataset.upper())) 38 | continue 39 | train=pd.read_csv(os.path.join(path_dir,'data','train.csv'),encoding="utf-8") 40 | for dataset in args.data: 41 | DATA_dir = os.path.join(os.path.abspath( 42 | os.path.join(os.getcwd(), "..")), 'DATA') 43 | path_dir = os.path.join(DATA_dir, dataset.upper()) 44 | output_path = os.path.join(path_dir, 'data', 'train_parsing.csv') 45 | if os.path.exists(output_path): 46 | print('The data {} has already parsed!'.format(dataset.upper())) 47 | continue 48 | train = pd.read_csv(os.path.join( 49 | path_dir, 'data', 'train.csv'), encoding="utf-8") 50 | for i,text_name in enumerate(task_to_keys[dataset]): 51 | parsing_name = 'parsing{}'.format(i+1) 52 | train[parsing_name] = train[text_name].parallel_apply( 53 | parsing_stanfordnlp) 54 | 55 | for i,text_name in enumerate(task_to_keys[dataset]): 56 | parsing_name='parsing{}'.format(i+1) 57 | train=train.drop(train[train[parsing_name]=='None'].index) 58 | train.to_csv(output_path, index=0) 59 | def download_data(args): 60 | 61 | for dataset in args.data: 62 | DATA_dir=os.path.join(os.path.abspath(os.path.join(os.getcwd(), "..")),'DATA') 63 | path_dir=os.path.join(DATA_dir,dataset.upper()) 64 | if dataset.upper() in os.listdir(DATA_dir): 65 | print('{} directory already exists !'.format(dataset.upper())) 66 | continue 67 | try: 68 | if dataset ==['addprim_jump','addprim_turn_left','simple']: 69 | downloaded_data_list = [load_dataset('scan', dataset)] 70 | if dataset in ['sst2', 'rte', 'mrpc', 'qqp', 'mnli', 'qnli']: 71 | downloaded_data_list=[load_dataset('glue',dataset)] 72 | elif dataset =='sst': 73 | downloaded_data_list = [load_dataset("sst", "default")] 74 | else: 75 | downloaded_data_list=[load_dataset(dataset)] 76 | 77 | if not os.path.exists(path_dir): 78 | if dataset=='trec': 79 | os.makedirs(os.path.join(path_dir,'generated/fine')) 80 | os.makedirs(os.path.join(path_dir,'generated/coarse')) 81 | os.makedirs(os.path.join( 82 | path_dir, 'runs/label-coarse/raw')) 83 | os.makedirs(os.path.join( 84 | path_dir, 'runs/label-coarse/aug')) 85 | os.makedirs(os.path.join( 86 | path_dir, 'runs/label-coarse/raw_aug')) 87 | os.makedirs(os.path.join( 88 | path_dir, 'runs/label-fine/raw')) 89 | os.makedirs(os.path.join( 90 | path_dir, 'runs/label-fine/aug')) 91 | os.makedirs(os.path.join( 92 | path_dir, 'runs/label-fine/raw_aug')) 93 | else: 94 | os.makedirs(os.path.join(path_dir,'generated')) 95 | os.makedirs(os.path.join(path_dir,'runs/raw')) 96 | os.makedirs(os.path.join(path_dir,'runs/aug')) 97 | os.makedirs(os.path.join(path_dir,'runs/raw_aug')) 98 | os.makedirs(os.path.join(path_dir,'logs')) 99 | os.makedirs(os.path.join(path_dir,'data')) 100 | for downloaded_data in downloaded_data_list: 101 | for data_split in downloaded_data: 102 | dataset_split=downloaded_data[data_split] 103 | dataset_split.to_csv(os.path.join(path_dir,'data',data_split+'.csv'),index=0) 104 | except Exception as e: 105 | print('Downloading failed on {}, due to error {}'.format(dataset,e)) 106 | if __name__=='__main__': 107 | args = parse_argument() 108 | tasksettings=settings.TaskSettings() 109 | task_to_keys=tasksettings.task_to_keys 110 | print('='*20,'Start Downloading Datasets','='*20) 111 | download_data(args) 112 | print('='*20,'Start Parsing Datasets','='*20) 113 | snlp = StanfordCoreNLP(args.corenlp_dir) 114 | constituency_parsing(args) 115 | -------------------------------------------------------------------------------- /process_data/settings.py: -------------------------------------------------------------------------------- 1 | class TaskSettings(object): 2 | def __init__(self): 3 | self.train_settings={ 4 | "mnli":{'epoch':5,'batch_size':96,'aug_batch_size':96,'val_steps':100,'max_length':128,'augweight':0.2}, 5 | "mrpc":{'epoch':10,'batch_size':32,'aug_batch_size':32,'val_steps':50,'max_length':128,'augweight':0.2}, 6 | "qnli":{'epoch':5,'batch_size':96,'aug_batch_size':96,'val_steps':100,'max_length':128,'augweight':0.2}, 7 | "qqp": {'epoch':5,'batch_size':96,'aug_batch_size':96,'val_steps':300,'max_length':128,'augweight':0.2}, 8 | "rte": {'epoch':10,'batch_size':32,'aug_batch_size':32,'val_steps':50,'max_length':128,'augweight':-0.2}, 9 | "sst2":{'epoch':5,'batch_size':96,'aug_batch_size':96,'val_steps':100,'max_length':128,'augweight':0.5}, 10 | "trec":{'epoch':20,'batch_size':96,'aug_batch_size':96,'val_steps':100,'max_length':128,'augweight':0.5}, 11 | "imdb":{'epoch':5,'batch_size':8,'aug_batch_size':8,'val_steps':500,'max_length':512,'augweight':0.5}, 12 | "ag_news": {'epoch': 5, 'batch_size': 96, 'aug_batch_size': 96, 'val_steps': 500, 'max_length': 128, 'augweight': 0.5}, 13 | 14 | } 15 | self.task_to_keys = { 16 | "mnli": ["premise", "hypothesis"], 17 | "mrpc": ["sentence1", "sentence2"], 18 | "qnli": ["question", "sentence"], 19 | "qqp": ["question1", "question2"], 20 | "rte": ["sentence1", "sentence2"], 21 | "sst2": ["sentence"], 22 | "trec": ["text"], 23 | "anli": ["premise", "hypothesis"], 24 | "imdb": ["text"], 25 | "ag_news":["text"], 26 | "sst":["sentence"], 27 | "addprim_jump":["commands"], 28 | "addprim_turn_left":["commands"] 29 | } 30 | self.pair_datasets=['qqp','rte','qnli','mrpc','mnli'] 31 | self.SCAN = ['addprim_turn_left', 'addprim_jump','simple'] 32 | self.low_resource={ 33 | "ag_news":[0.01,0.02,0.05,0.1,0.2], 34 | "sst":[0.01,0.02,0.05,0.1,0.2], 35 | "sst2":[0.01,0.02,0.05,0.1,0.2] 36 | } 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | datasets==1.11.0 2 | nltk==3.6.2 3 | pandarallel==1.5.4 4 | stanfordcorenlp==3.9.1.1 5 | torch==1.9.0+cu111 6 | transformers==4.12.5 -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.utils.data.distributed 4 | from torch.utils.data.distributed import DistributedSampler 5 | import torch.nn.parallel 6 | from transformers import BertForSequenceClassification, AdamW 7 | from transformers import get_linear_schedule_with_warmup 8 | import numpy as np 9 | import torch.nn as nn 10 | from sklearn.metrics import f1_score, accuracy_score 11 | from tqdm import tqdm 12 | import os 13 | import re 14 | import torch.nn.functional as F 15 | from torch.utils.tensorboard import SummaryWriter 16 | from itertools import cycle 17 | import argparse 18 | import torch.distributed as dist 19 | import time 20 | import online_augmentation 21 | import logging 22 | from process_data.Load_data import DATA_process 23 | 24 | 25 | 26 | def set_seed(seed): 27 | random.seed(seed) 28 | np.random.seed(seed) 29 | torch.manual_seed(seed) 30 | torch.cuda.manual_seed_all(seed) 31 | torch.backends.cudnn.deterministic = True 32 | 33 | 34 | def cross_entropy(logits, target): 35 | p = F.softmax(logits, dim=1) 36 | log_p = -torch.log(p) 37 | loss = target*log_p 38 | # print(target,p,log_p,loss) 39 | batch_num = logits.shape[0] 40 | return loss.sum()/batch_num 41 | 42 | 43 | 44 | def flat_accuracy(preds, labels): 45 | pred_flat = np.argmax(preds, axis=1).flatten() 46 | labels_flat = labels.flatten() 47 | return accuracy_score(labels_flat, pred_flat) 48 | 49 | 50 | def reduce_tensor(tensor, args): 51 | rt = tensor.clone() 52 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 53 | rt /= args.world_size 54 | return rt 55 | 56 | 57 | def tensorboard_settings(args): 58 | if 'raw' in args.mode: 59 | if args.data_path: 60 | # raw_aug 61 | log_dir = os.path.join(args.output_dir, 'Raw_Aug_{}_{}_{}_{}_{}'.format(args.data_path.split( 62 | '/')[-1], args.seed, args.augweight, args.batch_size, args.aug_batch_size)) 63 | if os.path.exists(log_dir): 64 | raise IOError( 65 | 'This tensorboard file {} already exists! Please do not train the same data repeatedly, if you want to train this dataset, delete corresponding tensorboard file first! '.format(log_dir)) 66 | writer = SummaryWriter(log_dir=log_dir) 67 | else: 68 | # raw 69 | if args.random_mix: 70 | log_dir = os.path.join(args.output_dir, 'Raw_random_mixup_{}_{}_{}'.format( 71 | args.random_mix, args.alpha, args.seed)) 72 | if os.path.exists(log_dir): 73 | raise IOError( 74 | 'This tensorboard file {} already exists! Please do not train the same data repeatedly, if you want to train this dataset, delete corresponding tensorboard file first! '.format(log_dir)) 75 | writer = SummaryWriter(log_dir=log_dir) 76 | else: 77 | log_dir = os.path.join( 78 | args.output_dir, 'Raw_{}'.format(args.seed)) 79 | if os.path.exists(log_dir): 80 | raise IOError( 81 | 'This tensorboard file {} already exists! Please do not train the same data repeatedly, if you want to train this dataset, delete corresponding tensorboard file first! '.format(log_dir)) 82 | writer = SummaryWriter(log_dir=log_dir) 83 | elif args.mode == 'aug': 84 | # aug 85 | log_dir = os.path.join(args.output_dir, 'Aug_{}_{}_{}_{}_{}'.format(args.data_path.split( 86 | '/')[-1], args.seed, args.augweight, args.batch_size, args.aug_batch_size)) 87 | if os.path.exists(log_dir): 88 | raise IOError( 89 | 'This tensorboard file {} already exists! Please do not train the same data repeatedly, if you want to train this dataset, delete corresponding tensorboard file first! '.format(log_dir)) 90 | writer = SummaryWriter(log_dir=log_dir) 91 | return writer 92 | 93 | 94 | def logging_settings(args): 95 | logger = logging.getLogger('result') 96 | logger.setLevel(logging.INFO) 97 | fmt = logging.Formatter( 98 | fmt='%(asctime)s - %(filename)s - %(levelname)s: %(message)s') 99 | if not os.path.exists(os.path.join('DATA', args.data.upper(), 'logs')): 100 | os.makedirs(os.path.join( 101 | 'DATA', args.data.upper(), 'logs')) 102 | if args.low_resource_dir: 103 | log_path = os.path.join('DATA', args.data.upper(),'logs', 'lowresourcebest_result.log') 104 | else: 105 | log_path = os.path.join('DATA', args.data.upper(),'logs', 'best_result.log') 106 | 107 | fh = logging.FileHandler(log_path, mode='a+', encoding='utf-8') 108 | ft=logging.Filter(name='result.a') 109 | fh.setFormatter(fmt) 110 | fh.setLevel(logging.INFO) 111 | fh.addFilter(ft) 112 | logger.addHandler(fh) 113 | result_logger=logging.getLogger('result.a') 114 | return result_logger 115 | def loading_model(args,label_num): 116 | t1 = time.time() 117 | if args.local_rank == -1: 118 | device = torch.device( 119 | "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 120 | args.n_gpu = torch.cuda.device_count() 121 | else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 122 | torch.cuda.set_device(args.local_rank) 123 | device = torch.device("cuda", args.local_rank) 124 | torch.distributed.init_process_group(backend='nccl') 125 | args.n_gpu = 1 # the number of gpu on each proc 126 | args.device = device 127 | if args.local_rank != -1: 128 | args.world_size = torch.cuda.device_count() 129 | else: 130 | args.world_size = 1 131 | print('*'*40, '\nSettings:{}'.format(args)) 132 | print('*'*40) 133 | print('='*20, 'Loading models', '='*20) 134 | model = BertForSequenceClassification.from_pretrained( 135 | args.model, num_labels=label_num) 136 | model.to(device) 137 | t2 = time.time() 138 | print( 139 | '='*20, 'Loading models complete!, cost {:.2f}s'.format(t2-t1), '='*20) 140 | # model parrallel 141 | if args.local_rank != -1: 142 | model = torch.nn.parallel.DistributedDataParallel( 143 | model, device_ids=[args.local_rank]) 144 | elif args.n_gpu > 1: 145 | model = nn.DataParallel(model) 146 | if args.load_model_path is not None: 147 | print("="*20, "Load model from %s", args.load_model_path,) 148 | model.load_state_dict(torch.load(args.load_model_path)) 149 | return model 150 | 151 | def parse_argument(): 152 | parser = argparse.ArgumentParser() 153 | parser.add_argument('--local_rank', default=-1, type=int, 154 | help='node rank for distributed training') 155 | parser.add_argument("--no_cuda", action='store_true', 156 | help="Avoid using CUDA when available") 157 | 158 | parser.add_argument( 159 | '--mode', type=str, choices=['raw', 'aug', 'raw_aug', 'visualize'], required=True) 160 | parser.add_argument('--save_model', action='store_true') 161 | parser.add_argument('--load_model_path', type=str) 162 | parser.add_argument('--data', type=str, required=True) 163 | parser.add_argument('--num_proc', type=int, default=8, 164 | help='multi process number used in dataloader process') 165 | 166 | # training settings 167 | parser.add_argument('--output_dir', type=str, help="tensorboard fileoutput directory") 168 | parser.add_argument('--epoch', type=int, default=5, help='train epochs') 169 | parser.add_argument('--lr', type=float, default=2e-5, help='learning rate') 170 | parser.add_argument('--seed', default=42, type=int, help='seed ') 171 | parser.add_argument('--batch_size', default=128, type=int, 172 | help='train examples in each batch') 173 | parser.add_argument('--val_steps', default=100, type=int, 174 | help='evaluate on dev datasets every steps') 175 | parser.add_argument('--max_length', default=128, 176 | type=int, help='encode max length') 177 | parser.add_argument('--label_name', type=str, default='label') 178 | parser.add_argument('--model', type=str, default='bert-base-uncased') 179 | parser.add_argument('--low_resource_dir', type=str, 180 | help='Low resource data dir') 181 | 182 | # train on augmentation dataset parameters 183 | parser.add_argument('--aug_batch_size', default=128, 184 | type=int, help='train examples in each batch') 185 | parser.add_argument('--augweight', default=0.2, type=float) 186 | parser.add_argument('--data_path', type=str, help="augmentation file path") 187 | parser.add_argument('--min_train_token', type=int, default=0, 188 | help="minimum token num restriction for train dataset") 189 | parser.add_argument('--max_train_token', type=int, default=0, 190 | help="maximum token num restriction for train dataset") 191 | parser.add_argument('--mix', action='store_false', help='train on 01mixup') 192 | 193 | # random mixup 194 | parser.add_argument('--alpha', type=float, default=0.1, 195 | help="online augmentation alpha") 196 | parser.add_argument('--onlyaug', action='store_true', 197 | help="train only on online aug batch") 198 | parser.add_argument('--difflen', action='store_true', 199 | help="train only on online aug batch") 200 | parser.add_argument('--random_mix', type=str, help="random mixup ") 201 | 202 | # visualize dataset 203 | 204 | args = parser.parse_args() 205 | if args.data == 'trec': 206 | try: 207 | assert args.label_name in ['label-fine', 'label-coarse'] 208 | except AssertionError: 209 | raise(AssertionError( 210 | "If you want to train on trec dataset with augmentation, you have to name the label of split")) 211 | if not args.output_dir: 212 | args.output_dir = os.path.join( 213 | 'DATA', args.data.upper(), 'runs', args.label_name, args.mode) 214 | if args.mode == 'raw': 215 | args.batch_size = 128 216 | if 'aug' in args.mode: 217 | assert args.data_path 218 | if args.mode == 'aug': 219 | args.seed = 42 220 | if not args.output_dir: 221 | args.output_dir = os.path.join( 222 | 'DATA', args.data.upper(), 'runs', args.mode) 223 | if not os.path.exists(args.output_dir): 224 | os.makedirs(args.output_dir) 225 | if args.data in ['rte', 'mrpc', 'qqp', 'mnli', 'qnli']: 226 | args.task = 'pair' 227 | else: 228 | args.task = 'single' 229 | 230 | return args 231 | 232 | 233 | def train(args): 234 | # ======================================== 235 | # Tensorboard &Logging 236 | # ======================================== 237 | writer = tensorboard_settings(args) 238 | result_logger = logging_settings(args) 239 | data_process = DATA_process(args) 240 | # ======================================== 241 | # Loading datasets 242 | # ======================================== 243 | print('='*20, 'Start processing dataset', '='*20) 244 | t1 = time.time() 245 | 246 | val_dataloader = data_process.validation_data() 247 | 248 | if args.mode != 'aug': 249 | train_dataloader, label_num = data_process.train_data(count_label=True) 250 | # print('Label_num',label_num) 251 | if args.data_path: 252 | print('='*20, 'Train Augmentation dataset path: {}'.format(args.data_path), '='*20) 253 | aug_dataloader = data_process.augmentation_data() 254 | if args.mode == 'aug': 255 | train_dataloader = aug_dataloader 256 | else: 257 | aug_dataloader = cycle(aug_dataloader) 258 | 259 | t2 = time.time() 260 | print('='*20, 'Dataset process done! cost {:.2f}s'.format(t2-t1), '='*20) 261 | 262 | # ======================================== 263 | # Model 264 | # ======================================== 265 | model=loading_model(args,label_num) 266 | # ======================================== 267 | # Optimizer Settings 268 | # ======================================== 269 | optimizer = AdamW(model.parameters(), lr=args.lr) 270 | all_steps = args.epoch*len(train_dataloader) 271 | scheduler = get_linear_schedule_with_warmup( 272 | optimizer, num_warmup_steps=20, num_training_steps=all_steps) 273 | criterion = nn.CrossEntropyLoss() 274 | model.train() 275 | 276 | # ======================================== 277 | # Train 278 | # ======================================== 279 | print('='*20, 'Start training', '='*20) 280 | best_acc = 0 281 | args.val_steps = min(len(train_dataloader), args.val_steps) 282 | 283 | for epoch in range(args.epoch): 284 | bar = tqdm(enumerate(train_dataloader), total=len( 285 | train_dataloader)//args.world_size) 286 | fail = 0 287 | loss = 0 288 | for step, batch in bar: 289 | model.zero_grad() 290 | 291 | # ---------------------------------------------- 292 | # Train_dataloader 293 | # ---------------------------------------------- 294 | if args.random_mix: 295 | try: 296 | 297 | input_ids, target_a = batch['input_ids'], batch['labels'] 298 | lam = np.random.choice([0, 0.1, 0.2, 0.3]) 299 | exchanged_ids, new_index = online_augmentation.random_mixup( 300 | args, input_ids, target_a, lam) 301 | target_b = target_a[new_index] 302 | outputs = model(exchanged_ids.to(args.device), token_type_ids=None, attention_mask=( 303 | exchanged_ids > 0).to(args.device)) 304 | logits = outputs.logits 305 | loss = criterion(logits.to(args.device), target_a.to( 306 | args.device))*(1-lam)+criterion(logits.to(args.device), target_b.to(args.device))*lam 307 | 308 | 309 | except Exception as e: 310 | fail += 1 311 | batch = {k: v.to(args.device) for k, v in batch.items()} 312 | outputs = model(**batch) 313 | loss = outputs.loss 314 | elif args.model == 'aug': 315 | # train only on augmentation dataset 316 | batch = {k: v.to(args.device) for k, v in batch.items()} 317 | if args.mix: 318 | # train on 01 tree mixup augmentation dataset 319 | mix_label = batch['labels'] 320 | del batch['labels'] 321 | 322 | outputs = model(**batch) 323 | logits = outputs.logits 324 | 325 | loss = cross_entropy(logits, mix_label) 326 | else: 327 | # train on 00&11 tree mixup augmentation dataset 328 | outputs = model(**batch) 329 | loss = outputs.loss 330 | else: 331 | # normal train 332 | 333 | batch = {k: v.to(args.device) for k, v in batch.items()} 334 | 335 | outputs = model(**batch) 336 | loss = outputs.loss 337 | # ---------------------------------------------- 338 | # Aug_dataloader 339 | # ---------------------------------------------- 340 | if args.mode == 'raw_aug': 341 | aug_batch = next(aug_dataloader) 342 | aug_batch = {k: v.to(args.device) for k, v in aug_batch.items()} 343 | 344 | if args.mix: 345 | mix_label = aug_batch['labels'] 346 | del aug_batch['labels'] 347 | aug_outputs = model(**aug_batch) 348 | aug_logits = aug_outputs.logits 349 | 350 | aug_loss = cross_entropy(aug_logits, mix_label) 351 | else: 352 | aug_outputs = model(**aug_batch) 353 | aug_loss = aug_outputs.loss 354 | loss += aug_loss*args.augweight # for sst2,rte reaches best performance 355 | 356 | # Backward propagation 357 | if args.n_gpu > 1: 358 | loss = loss.mean() 359 | loss.backward() 360 | optimizer.step() 361 | scheduler.step() 362 | optimizer.zero_grad() 363 | if args.local_rank == 0 or args.local_rank == -1: 364 | writer.add_scalar("Loss/loss", loss, step + 365 | epoch*len(train_dataloader)) 366 | writer.flush() 367 | if args.random_mix: 368 | bar.set_description( 369 | '| Epoch: {:<2}/{:<2}| Best acc:{:.2f}| Fail:{}|'.format(epoch, args.epoch, best_acc*100, fail)) 370 | else: 371 | bar.set_description( 372 | '| Epoch: {:<2}/{:<2}| Best acc:{:.2f}|'.format(epoch, args.epoch, best_acc*100)) 373 | 374 | # ================================================= 375 | # Validation 376 | # ================================================= 377 | if (epoch*len(train_dataloader)+step+1) % args.val_steps == 0: 378 | total_eval_accuracy = 0 379 | total_val_loss = 0 380 | model.eval() # evaluation after each epoch 381 | for i, batch in enumerate(val_dataloader): 382 | with torch.no_grad(): 383 | batch = {k: v.to(args.device) 384 | for k, v in batch.items()} 385 | outputs = model(**batch) 386 | logits = outputs.logits 387 | loss = outputs.loss 388 | 389 | if args.n_gpu > 1: 390 | loss = loss.mean() 391 | logits = logits.detach().cpu().numpy() 392 | label_ids = batch['labels'].to('cpu').numpy() 393 | 394 | accuracy = flat_accuracy(logits, label_ids) 395 | if args.local_rank != -1: 396 | torch.distributed.barrier() 397 | reduced_loss = reduce_tensor(loss, args) 398 | accuracy = torch.tensor(accuracy).to(args.device) 399 | reduced_acc = reduce_tensor(accuracy, args) 400 | total_val_loss += reduced_loss 401 | total_eval_accuracy += reduced_acc 402 | else: 403 | total_eval_accuracy += accuracy.item() 404 | total_val_loss += loss.item() 405 | avg_val_loss = total_val_loss/len(val_dataloader) 406 | avg_val_accuracy = total_eval_accuracy/len(val_dataloader) 407 | if avg_val_accuracy > best_acc: 408 | best_acc = avg_val_accuracy 409 | bset_steps = (epoch*len(train_dataloader) + 410 | step)*args.batch_size 411 | if args.save_model: 412 | torch.save(model.state_dict(), 'best_model.pt') 413 | if args.local_rank == 0 or args.local_rank == -1: 414 | writer.add_scalar("Test/Loss", avg_val_loss, 415 | epoch*len(train_dataloader)+step) 416 | writer.add_scalar( 417 | "Test/Accuracy", avg_val_accuracy, epoch*len(train_dataloader)+step) 418 | writer.flush() 419 | # print(f'Validation loss: {avg_val_loss}') 420 | # print(f'Accuracy: {avg_val_accuracy:.5f}') 421 | # print('Best Accuracy:{:.5f} Steps:{}\n'.format(best_acc, bset_steps)) 422 | 423 | if args.data_path: 424 | aug_num=args.data_path.split('_')[-1] 425 | 426 | if args.low_resource_dir: 427 | # low resource raw_aug 428 | partial = re.findall(r'low_resource_(0.\d+)', 429 | args.low_resource_dir)[0] 430 | aug_num_seed = aug_num+'_'+str(args.seed) 431 | result_logger.info('-'*160) 432 | result_logger.info('| Data : {} | Mode: {:<8} | #Aug {:<6} | Best acc:{} | Steps:{} | Weight {} |Aug data: {}'.format( 433 | args.data+'_'+partial, args.mode, aug_num_seed, round(best_acc*100, 3), bset_steps, args.augweight, args.data_path)) 434 | else: 435 | # raw_aug 436 | aug_data_seed=re.findall(r'seed(\d)',args.data_path)[0] 437 | aug_num_seed = aug_num+'_'+aug_data_seed 438 | result_logger.info('-'*160) 439 | result_logger.info('| Data : {} | Mode: {:<8} | #Aug {:<6} | Best acc:{} | Steps:{} | Weight {} |Aug data: {}'.format( 440 | args.data, args.mode, aug_num_seed ,round(best_acc*100,3), bset_steps, args.augweight,args.data_path)) 441 | else: 442 | if args.low_resource_dir: 443 | # low resource raw 444 | partial=re.findall(r'low_resource_(0.\d+)',args.low_resource_dir)[0] 445 | result_logger.info('-'*160) 446 | result_logger.info('| Data : {} | Mode: {:.8} | Seed: {} | Best acc:{} | Steps:{} | Randommix: {} | Aug data: {}'.format( 447 | args.data+'-'+partial, args.mode, args.seed, round(best_acc*100,3), bset_steps,bool(args.random_mix) ,args.data_path)) 448 | else: 449 | # raw 450 | result_logger.info('-'*160) 451 | result_logger.info('| Data : {} | Mode: {:.8} | Seed: {} | Best acc:{} | Steps:{} | Randommix: {} | Aug data: {}'.format( 452 | args.data, args.mode, args.seed, round(best_acc*100,3), bset_steps, bool(args.random_mix),args.data_path)) 453 | 454 | 455 | 456 | 457 | 458 | def main(args): 459 | set_seed(args.seed) 460 | if args.mode in ['raw', 'raw_aug', 'aug']: 461 | if args.low_resource_dir: 462 | print("="*20, ' Lowresource ', '='*20) 463 | train(args) 464 | if __name__ == '__main__': 465 | args = parse_argument() 466 | main(args) 467 | -------------------------------------------------------------------------------- /transformers_doc/pytorch/task_summary.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "task_summary.ipynb", 7 | "provenance": [], 8 | "include_colab_link": true 9 | }, 10 | "language_info": { 11 | "name": "python" 12 | }, 13 | "kernelspec": { 14 | "name": "python3", 15 | "display_name": "Python 3" 16 | }, 17 | "widgets": { 18 | "application/vnd.jupyter.widget-state+json": { 19 | "7c08442589064421b77b6aa52ce5b947": { 20 | "model_module": "@jupyter-widgets/controls", 21 | "model_name": "HBoxModel", 22 | "model_module_version": "1.5.0", 23 | "state": { 24 | "_view_name": "HBoxView", 25 | "_dom_classes": [], 26 | "_model_name": "HBoxModel", 27 | "_view_module": "@jupyter-widgets/controls", 28 | "_model_module_version": "1.5.0", 29 | "_view_count": null, 30 | "_view_module_version": "1.5.0", 31 | "box_style": "", 32 | "layout": "IPY_MODEL_23f3f545c30d4226987e7909e8365413", 33 | "_model_module": "@jupyter-widgets/controls", 34 | "children": [ 35 | "IPY_MODEL_3d29469f501c4a0b96c211241459f34a", 36 | "IPY_MODEL_9578e983215a4bc497b5882a4244b8a9", 37 | "IPY_MODEL_47b6aefd82c04f0fb3a3de2c7f37c58f" 38 | ] 39 | } 40 | }, 41 | "23f3f545c30d4226987e7909e8365413": { 42 | "model_module": "@jupyter-widgets/base", 43 | "model_name": "LayoutModel", 44 | "model_module_version": "1.2.0", 45 | "state": { 46 | "_view_name": "LayoutView", 47 | "grid_template_rows": null, 48 | "right": null, 49 | "justify_content": null, 50 | "_view_module": "@jupyter-widgets/base", 51 | "overflow": null, 52 | "_model_module_version": "1.2.0", 53 | "_view_count": null, 54 | "flex_flow": null, 55 | "width": null, 56 | "min_width": null, 57 | "border": null, 58 | "align_items": null, 59 | "bottom": null, 60 | "_model_module": "@jupyter-widgets/base", 61 | "top": null, 62 | "grid_column": null, 63 | "overflow_y": null, 64 | "overflow_x": null, 65 | "grid_auto_flow": null, 66 | "grid_area": null, 67 | "grid_template_columns": null, 68 | "flex": null, 69 | "_model_name": "LayoutModel", 70 | "justify_items": null, 71 | "grid_row": null, 72 | "max_height": null, 73 | "align_content": null, 74 | "visibility": null, 75 | "align_self": null, 76 | "height": null, 77 | "min_height": null, 78 | "padding": null, 79 | "grid_auto_rows": null, 80 | "grid_gap": null, 81 | "max_width": null, 82 | "order": null, 83 | "_view_module_version": "1.2.0", 84 | "grid_template_areas": null, 85 | "object_position": null, 86 | "object_fit": null, 87 | "grid_auto_columns": null, 88 | "margin": null, 89 | "display": null, 90 | "left": null 91 | } 92 | }, 93 | "3d29469f501c4a0b96c211241459f34a": { 94 | "model_module": "@jupyter-widgets/controls", 95 | "model_name": "HTMLModel", 96 | "model_module_version": "1.5.0", 97 | "state": { 98 | "_view_name": "HTMLView", 99 | "style": "IPY_MODEL_b1e066e72ca04508bf50a60961e66bea", 100 | "_dom_classes": [], 101 | "description": "", 102 | "_model_name": "HTMLModel", 103 | "placeholder": "​", 104 | "_view_module": "@jupyter-widgets/controls", 105 | "_model_module_version": "1.5.0", 106 | "value": "Downloading: 100%", 107 | "_view_count": null, 108 | "_view_module_version": "1.5.0", 109 | "description_tooltip": null, 110 | "_model_module": "@jupyter-widgets/controls", 111 | "layout": "IPY_MODEL_96fe7d4c88e541eea33dde3889e3e381" 112 | } 113 | }, 114 | "9578e983215a4bc497b5882a4244b8a9": { 115 | "model_module": "@jupyter-widgets/controls", 116 | "model_name": "FloatProgressModel", 117 | "model_module_version": "1.5.0", 118 | "state": { 119 | "_view_name": "ProgressView", 120 | "style": "IPY_MODEL_65b763eef5184606b28d4673e6ad857d", 121 | "_dom_classes": [], 122 | "description": "", 123 | "_model_name": "FloatProgressModel", 124 | "bar_style": "success", 125 | "max": 629, 126 | "_view_module": "@jupyter-widgets/controls", 127 | "_model_module_version": "1.5.0", 128 | "value": 629, 129 | "_view_count": null, 130 | "_view_module_version": "1.5.0", 131 | "orientation": "horizontal", 132 | "min": 0, 133 | "description_tooltip": null, 134 | "_model_module": "@jupyter-widgets/controls", 135 | "layout": "IPY_MODEL_40984bf56a8b448e82451fa60fc15f7c" 136 | } 137 | }, 138 | "47b6aefd82c04f0fb3a3de2c7f37c58f": { 139 | "model_module": "@jupyter-widgets/controls", 140 | "model_name": "HTMLModel", 141 | "model_module_version": "1.5.0", 142 | "state": { 143 | "_view_name": "HTMLView", 144 | "style": "IPY_MODEL_21bf499e3fb34c02b63934ea6e4f162f", 145 | "_dom_classes": [], 146 | "description": "", 147 | "_model_name": "HTMLModel", 148 | "placeholder": "​", 149 | "_view_module": "@jupyter-widgets/controls", 150 | "_model_module_version": "1.5.0", 151 | "value": " 629/629 [00:00<00:00, 10.2kB/s]", 152 | "_view_count": null, 153 | "_view_module_version": "1.5.0", 154 | "description_tooltip": null, 155 | "_model_module": "@jupyter-widgets/controls", 156 | "layout": "IPY_MODEL_9566729c64304007953c365eb0341df2" 157 | } 158 | }, 159 | "b1e066e72ca04508bf50a60961e66bea": { 160 | "model_module": "@jupyter-widgets/controls", 161 | "model_name": "DescriptionStyleModel", 162 | "model_module_version": "1.5.0", 163 | "state": { 164 | "_view_name": "StyleView", 165 | "_model_name": "DescriptionStyleModel", 166 | "description_width": "", 167 | "_view_module": "@jupyter-widgets/base", 168 | "_model_module_version": "1.5.0", 169 | "_view_count": null, 170 | "_view_module_version": "1.2.0", 171 | "_model_module": "@jupyter-widgets/controls" 172 | } 173 | }, 174 | "96fe7d4c88e541eea33dde3889e3e381": { 175 | "model_module": "@jupyter-widgets/base", 176 | "model_name": "LayoutModel", 177 | "model_module_version": "1.2.0", 178 | "state": { 179 | "_view_name": "LayoutView", 180 | "grid_template_rows": null, 181 | "right": null, 182 | "justify_content": null, 183 | "_view_module": "@jupyter-widgets/base", 184 | "overflow": null, 185 | "_model_module_version": "1.2.0", 186 | "_view_count": null, 187 | "flex_flow": null, 188 | "width": null, 189 | "min_width": null, 190 | "border": null, 191 | "align_items": null, 192 | "bottom": null, 193 | "_model_module": "@jupyter-widgets/base", 194 | "top": null, 195 | "grid_column": null, 196 | "overflow_y": null, 197 | "overflow_x": null, 198 | "grid_auto_flow": null, 199 | "grid_area": null, 200 | "grid_template_columns": null, 201 | "flex": null, 202 | "_model_name": "LayoutModel", 203 | "justify_items": null, 204 | "grid_row": null, 205 | "max_height": null, 206 | "align_content": null, 207 | "visibility": null, 208 | "align_self": null, 209 | "height": null, 210 | "min_height": null, 211 | "padding": null, 212 | "grid_auto_rows": null, 213 | "grid_gap": null, 214 | "max_width": null, 215 | "order": null, 216 | "_view_module_version": "1.2.0", 217 | "grid_template_areas": null, 218 | "object_position": null, 219 | "object_fit": null, 220 | "grid_auto_columns": null, 221 | "margin": null, 222 | "display": null, 223 | "left": null 224 | } 225 | }, 226 | "65b763eef5184606b28d4673e6ad857d": { 227 | "model_module": "@jupyter-widgets/controls", 228 | "model_name": "ProgressStyleModel", 229 | "model_module_version": "1.5.0", 230 | "state": { 231 | "_view_name": "StyleView", 232 | "_model_name": "ProgressStyleModel", 233 | "description_width": "", 234 | "_view_module": "@jupyter-widgets/base", 235 | "_model_module_version": "1.5.0", 236 | "_view_count": null, 237 | "_view_module_version": "1.2.0", 238 | "bar_color": null, 239 | "_model_module": "@jupyter-widgets/controls" 240 | } 241 | }, 242 | "40984bf56a8b448e82451fa60fc15f7c": { 243 | "model_module": "@jupyter-widgets/base", 244 | "model_name": "LayoutModel", 245 | "model_module_version": "1.2.0", 246 | "state": { 247 | "_view_name": "LayoutView", 248 | "grid_template_rows": null, 249 | "right": null, 250 | "justify_content": null, 251 | "_view_module": "@jupyter-widgets/base", 252 | "overflow": null, 253 | "_model_module_version": "1.2.0", 254 | "_view_count": null, 255 | "flex_flow": null, 256 | "width": null, 257 | "min_width": null, 258 | "border": null, 259 | "align_items": null, 260 | "bottom": null, 261 | "_model_module": "@jupyter-widgets/base", 262 | "top": null, 263 | "grid_column": null, 264 | "overflow_y": null, 265 | "overflow_x": null, 266 | "grid_auto_flow": null, 267 | "grid_area": null, 268 | "grid_template_columns": null, 269 | "flex": null, 270 | "_model_name": "LayoutModel", 271 | "justify_items": null, 272 | "grid_row": null, 273 | "max_height": null, 274 | "align_content": null, 275 | "visibility": null, 276 | "align_self": null, 277 | "height": null, 278 | "min_height": null, 279 | "padding": null, 280 | "grid_auto_rows": null, 281 | "grid_gap": null, 282 | "max_width": null, 283 | "order": null, 284 | "_view_module_version": "1.2.0", 285 | "grid_template_areas": null, 286 | "object_position": null, 287 | "object_fit": null, 288 | "grid_auto_columns": null, 289 | "margin": null, 290 | "display": null, 291 | "left": null 292 | } 293 | }, 294 | "21bf499e3fb34c02b63934ea6e4f162f": { 295 | "model_module": "@jupyter-widgets/controls", 296 | "model_name": "DescriptionStyleModel", 297 | "model_module_version": "1.5.0", 298 | "state": { 299 | "_view_name": "StyleView", 300 | "_model_name": "DescriptionStyleModel", 301 | "description_width": "", 302 | "_view_module": "@jupyter-widgets/base", 303 | "_model_module_version": "1.5.0", 304 | "_view_count": null, 305 | "_view_module_version": "1.2.0", 306 | "_model_module": "@jupyter-widgets/controls" 307 | } 308 | }, 309 | "9566729c64304007953c365eb0341df2": { 310 | "model_module": "@jupyter-widgets/base", 311 | "model_name": "LayoutModel", 312 | "model_module_version": "1.2.0", 313 | "state": { 314 | "_view_name": "LayoutView", 315 | "grid_template_rows": null, 316 | "right": null, 317 | "justify_content": null, 318 | "_view_module": "@jupyter-widgets/base", 319 | "overflow": null, 320 | "_model_module_version": "1.2.0", 321 | "_view_count": null, 322 | "flex_flow": null, 323 | "width": null, 324 | "min_width": null, 325 | "border": null, 326 | "align_items": null, 327 | "bottom": null, 328 | "_model_module": "@jupyter-widgets/base", 329 | "top": null, 330 | "grid_column": null, 331 | "overflow_y": null, 332 | "overflow_x": null, 333 | "grid_auto_flow": null, 334 | "grid_area": null, 335 | "grid_template_columns": null, 336 | "flex": null, 337 | "_model_name": "LayoutModel", 338 | "justify_items": null, 339 | "grid_row": null, 340 | "max_height": null, 341 | "align_content": null, 342 | "visibility": null, 343 | "align_self": null, 344 | "height": null, 345 | "min_height": null, 346 | "padding": null, 347 | "grid_auto_rows": null, 348 | "grid_gap": null, 349 | "max_width": null, 350 | "order": null, 351 | "_view_module_version": "1.2.0", 352 | "grid_template_areas": null, 353 | "object_position": null, 354 | "object_fit": null, 355 | "grid_auto_columns": null, 356 | "margin": null, 357 | "display": null, 358 | "left": null 359 | } 360 | }, 361 | "6bf72948e9ce47d296cf9673d5c11ced": { 362 | "model_module": "@jupyter-widgets/controls", 363 | "model_name": "HBoxModel", 364 | "model_module_version": "1.5.0", 365 | "state": { 366 | "_view_name": "HBoxView", 367 | "_dom_classes": [], 368 | "_model_name": "HBoxModel", 369 | "_view_module": "@jupyter-widgets/controls", 370 | "_model_module_version": "1.5.0", 371 | "_view_count": null, 372 | "_view_module_version": "1.5.0", 373 | "box_style": "", 374 | "layout": "IPY_MODEL_e32995501a0243c1ab56ee42ecddd80e", 375 | "_model_module": "@jupyter-widgets/controls", 376 | "children": [ 377 | "IPY_MODEL_e1c3b3a1b0ce49cc9db2bd16be25c929", 378 | "IPY_MODEL_a1f0479b17f8451daaf69b0af81b9adb", 379 | "IPY_MODEL_5f7a3c947c564436bd2dfcc04d4da5ee" 380 | ] 381 | } 382 | }, 383 | "e32995501a0243c1ab56ee42ecddd80e": { 384 | "model_module": "@jupyter-widgets/base", 385 | "model_name": "LayoutModel", 386 | "model_module_version": "1.2.0", 387 | "state": { 388 | "_view_name": "LayoutView", 389 | "grid_template_rows": null, 390 | "right": null, 391 | "justify_content": null, 392 | "_view_module": "@jupyter-widgets/base", 393 | "overflow": null, 394 | "_model_module_version": "1.2.0", 395 | "_view_count": null, 396 | "flex_flow": null, 397 | "width": null, 398 | "min_width": null, 399 | "border": null, 400 | "align_items": null, 401 | "bottom": null, 402 | "_model_module": "@jupyter-widgets/base", 403 | "top": null, 404 | "grid_column": null, 405 | "overflow_y": null, 406 | "overflow_x": null, 407 | "grid_auto_flow": null, 408 | "grid_area": null, 409 | "grid_template_columns": null, 410 | "flex": null, 411 | "_model_name": "LayoutModel", 412 | "justify_items": null, 413 | "grid_row": null, 414 | "max_height": null, 415 | "align_content": null, 416 | "visibility": null, 417 | "align_self": null, 418 | "height": null, 419 | "min_height": null, 420 | "padding": null, 421 | "grid_auto_rows": null, 422 | "grid_gap": null, 423 | "max_width": null, 424 | "order": null, 425 | "_view_module_version": "1.2.0", 426 | "grid_template_areas": null, 427 | "object_position": null, 428 | "object_fit": null, 429 | "grid_auto_columns": null, 430 | "margin": null, 431 | "display": null, 432 | "left": null 433 | } 434 | }, 435 | "e1c3b3a1b0ce49cc9db2bd16be25c929": { 436 | "model_module": "@jupyter-widgets/controls", 437 | "model_name": "HTMLModel", 438 | "model_module_version": "1.5.0", 439 | "state": { 440 | "_view_name": "HTMLView", 441 | "style": "IPY_MODEL_d0d537e5dc744a5794187c5a498358fc", 442 | "_dom_classes": [], 443 | "description": "", 444 | "_model_name": "HTMLModel", 445 | "placeholder": "​", 446 | "_view_module": "@jupyter-widgets/controls", 447 | "_model_module_version": "1.5.0", 448 | "value": "Downloading: 100%", 449 | "_view_count": null, 450 | "_view_module_version": "1.5.0", 451 | "description_tooltip": null, 452 | "_model_module": "@jupyter-widgets/controls", 453 | "layout": "IPY_MODEL_bdbbb4175a6c4bd389755dc0f8d186cb" 454 | } 455 | }, 456 | "a1f0479b17f8451daaf69b0af81b9adb": { 457 | "model_module": "@jupyter-widgets/controls", 458 | "model_name": "FloatProgressModel", 459 | "model_module_version": "1.5.0", 460 | "state": { 461 | "_view_name": "ProgressView", 462 | "style": "IPY_MODEL_902c7115997647b9a5140ba707d74390", 463 | "_dom_classes": [], 464 | "description": "", 465 | "_model_name": "FloatProgressModel", 466 | "bar_style": "success", 467 | "max": 267844284, 468 | "_view_module": "@jupyter-widgets/controls", 469 | "_model_module_version": "1.5.0", 470 | "value": 267844284, 471 | "_view_count": null, 472 | "_view_module_version": "1.5.0", 473 | "orientation": "horizontal", 474 | "min": 0, 475 | "description_tooltip": null, 476 | "_model_module": "@jupyter-widgets/controls", 477 | "layout": "IPY_MODEL_cbff745d3ebb4ba0b9bfd8b1a210550f" 478 | } 479 | }, 480 | "5f7a3c947c564436bd2dfcc04d4da5ee": { 481 | "model_module": "@jupyter-widgets/controls", 482 | "model_name": "HTMLModel", 483 | "model_module_version": "1.5.0", 484 | "state": { 485 | "_view_name": "HTMLView", 486 | "style": "IPY_MODEL_db0e50cf79174b99b7943fe9fa51b738", 487 | "_dom_classes": [], 488 | "description": "", 489 | "_model_name": "HTMLModel", 490 | "placeholder": "​", 491 | "_view_module": "@jupyter-widgets/controls", 492 | "_model_module_version": "1.5.0", 493 | "value": " 255M/255M [00:06<00:00, 35.0MB/s]", 494 | "_view_count": null, 495 | "_view_module_version": "1.5.0", 496 | "description_tooltip": null, 497 | "_model_module": "@jupyter-widgets/controls", 498 | "layout": "IPY_MODEL_91e992390cf94d4bb23164ccbfb2a4d0" 499 | } 500 | }, 501 | "d0d537e5dc744a5794187c5a498358fc": { 502 | "model_module": "@jupyter-widgets/controls", 503 | "model_name": "DescriptionStyleModel", 504 | "model_module_version": "1.5.0", 505 | "state": { 506 | "_view_name": "StyleView", 507 | "_model_name": "DescriptionStyleModel", 508 | "description_width": "", 509 | "_view_module": "@jupyter-widgets/base", 510 | "_model_module_version": "1.5.0", 511 | "_view_count": null, 512 | "_view_module_version": "1.2.0", 513 | "_model_module": "@jupyter-widgets/controls" 514 | } 515 | }, 516 | "bdbbb4175a6c4bd389755dc0f8d186cb": { 517 | "model_module": "@jupyter-widgets/base", 518 | "model_name": "LayoutModel", 519 | "model_module_version": "1.2.0", 520 | "state": { 521 | "_view_name": "LayoutView", 522 | "grid_template_rows": null, 523 | "right": null, 524 | "justify_content": null, 525 | "_view_module": "@jupyter-widgets/base", 526 | "overflow": null, 527 | "_model_module_version": "1.2.0", 528 | "_view_count": null, 529 | "flex_flow": null, 530 | "width": null, 531 | "min_width": null, 532 | "border": null, 533 | "align_items": null, 534 | "bottom": null, 535 | "_model_module": "@jupyter-widgets/base", 536 | "top": null, 537 | "grid_column": null, 538 | "overflow_y": null, 539 | "overflow_x": null, 540 | "grid_auto_flow": null, 541 | "grid_area": null, 542 | "grid_template_columns": null, 543 | "flex": null, 544 | "_model_name": "LayoutModel", 545 | "justify_items": null, 546 | "grid_row": null, 547 | "max_height": null, 548 | "align_content": null, 549 | "visibility": null, 550 | "align_self": null, 551 | "height": null, 552 | "min_height": null, 553 | "padding": null, 554 | "grid_auto_rows": null, 555 | "grid_gap": null, 556 | "max_width": null, 557 | "order": null, 558 | "_view_module_version": "1.2.0", 559 | "grid_template_areas": null, 560 | "object_position": null, 561 | "object_fit": null, 562 | "grid_auto_columns": null, 563 | "margin": null, 564 | "display": null, 565 | "left": null 566 | } 567 | }, 568 | "902c7115997647b9a5140ba707d74390": { 569 | "model_module": "@jupyter-widgets/controls", 570 | "model_name": "ProgressStyleModel", 571 | "model_module_version": "1.5.0", 572 | "state": { 573 | "_view_name": "StyleView", 574 | "_model_name": "ProgressStyleModel", 575 | "description_width": "", 576 | "_view_module": "@jupyter-widgets/base", 577 | "_model_module_version": "1.5.0", 578 | "_view_count": null, 579 | "_view_module_version": "1.2.0", 580 | "bar_color": null, 581 | "_model_module": "@jupyter-widgets/controls" 582 | } 583 | }, 584 | "cbff745d3ebb4ba0b9bfd8b1a210550f": { 585 | "model_module": "@jupyter-widgets/base", 586 | "model_name": "LayoutModel", 587 | "model_module_version": "1.2.0", 588 | "state": { 589 | "_view_name": "LayoutView", 590 | "grid_template_rows": null, 591 | "right": null, 592 | "justify_content": null, 593 | "_view_module": "@jupyter-widgets/base", 594 | "overflow": null, 595 | "_model_module_version": "1.2.0", 596 | "_view_count": null, 597 | "flex_flow": null, 598 | "width": null, 599 | "min_width": null, 600 | "border": null, 601 | "align_items": null, 602 | "bottom": null, 603 | "_model_module": "@jupyter-widgets/base", 604 | "top": null, 605 | "grid_column": null, 606 | "overflow_y": null, 607 | "overflow_x": null, 608 | "grid_auto_flow": null, 609 | "grid_area": null, 610 | "grid_template_columns": null, 611 | "flex": null, 612 | "_model_name": "LayoutModel", 613 | "justify_items": null, 614 | "grid_row": null, 615 | "max_height": null, 616 | "align_content": null, 617 | "visibility": null, 618 | "align_self": null, 619 | "height": null, 620 | "min_height": null, 621 | "padding": null, 622 | "grid_auto_rows": null, 623 | "grid_gap": null, 624 | "max_width": null, 625 | "order": null, 626 | "_view_module_version": "1.2.0", 627 | "grid_template_areas": null, 628 | "object_position": null, 629 | "object_fit": null, 630 | "grid_auto_columns": null, 631 | "margin": null, 632 | "display": null, 633 | "left": null 634 | } 635 | }, 636 | "db0e50cf79174b99b7943fe9fa51b738": { 637 | "model_module": "@jupyter-widgets/controls", 638 | "model_name": "DescriptionStyleModel", 639 | "model_module_version": "1.5.0", 640 | "state": { 641 | "_view_name": "StyleView", 642 | "_model_name": "DescriptionStyleModel", 643 | "description_width": "", 644 | "_view_module": "@jupyter-widgets/base", 645 | "_model_module_version": "1.5.0", 646 | "_view_count": null, 647 | "_view_module_version": "1.2.0", 648 | "_model_module": "@jupyter-widgets/controls" 649 | } 650 | }, 651 | "91e992390cf94d4bb23164ccbfb2a4d0": { 652 | "model_module": "@jupyter-widgets/base", 653 | "model_name": "LayoutModel", 654 | "model_module_version": "1.2.0", 655 | "state": { 656 | "_view_name": "LayoutView", 657 | "grid_template_rows": null, 658 | "right": null, 659 | "justify_content": null, 660 | "_view_module": "@jupyter-widgets/base", 661 | "overflow": null, 662 | "_model_module_version": "1.2.0", 663 | "_view_count": null, 664 | "flex_flow": null, 665 | "width": null, 666 | "min_width": null, 667 | "border": null, 668 | "align_items": null, 669 | "bottom": null, 670 | "_model_module": "@jupyter-widgets/base", 671 | "top": null, 672 | "grid_column": null, 673 | "overflow_y": null, 674 | "overflow_x": null, 675 | "grid_auto_flow": null, 676 | "grid_area": null, 677 | "grid_template_columns": null, 678 | "flex": null, 679 | "_model_name": "LayoutModel", 680 | "justify_items": null, 681 | "grid_row": null, 682 | "max_height": null, 683 | "align_content": null, 684 | "visibility": null, 685 | "align_self": null, 686 | "height": null, 687 | "min_height": null, 688 | "padding": null, 689 | "grid_auto_rows": null, 690 | "grid_gap": null, 691 | "max_width": null, 692 | "order": null, 693 | "_view_module_version": "1.2.0", 694 | "grid_template_areas": null, 695 | "object_position": null, 696 | "object_fit": null, 697 | "grid_auto_columns": null, 698 | "margin": null, 699 | "display": null, 700 | "left": null 701 | } 702 | }, 703 | "a66fb6293209491caf0f05bbab72c1f3": { 704 | "model_module": "@jupyter-widgets/controls", 705 | "model_name": "HBoxModel", 706 | "model_module_version": "1.5.0", 707 | "state": { 708 | "_view_name": "HBoxView", 709 | "_dom_classes": [], 710 | "_model_name": "HBoxModel", 711 | "_view_module": "@jupyter-widgets/controls", 712 | "_model_module_version": "1.5.0", 713 | "_view_count": null, 714 | "_view_module_version": "1.5.0", 715 | "box_style": "", 716 | "layout": "IPY_MODEL_b226e034649d4a27968a24f0215acfa2", 717 | "_model_module": "@jupyter-widgets/controls", 718 | "children": [ 719 | "IPY_MODEL_143a846ef031430a94ca03ce798e1c38", 720 | "IPY_MODEL_02cd0740d8394ee9978e2205bd7c0885", 721 | "IPY_MODEL_fe2d58558248438f89639669b20d044a" 722 | ] 723 | } 724 | }, 725 | "b226e034649d4a27968a24f0215acfa2": { 726 | "model_module": "@jupyter-widgets/base", 727 | "model_name": "LayoutModel", 728 | "model_module_version": "1.2.0", 729 | "state": { 730 | "_view_name": "LayoutView", 731 | "grid_template_rows": null, 732 | "right": null, 733 | "justify_content": null, 734 | "_view_module": "@jupyter-widgets/base", 735 | "overflow": null, 736 | "_model_module_version": "1.2.0", 737 | "_view_count": null, 738 | "flex_flow": null, 739 | "width": null, 740 | "min_width": null, 741 | "border": null, 742 | "align_items": null, 743 | "bottom": null, 744 | "_model_module": "@jupyter-widgets/base", 745 | "top": null, 746 | "grid_column": null, 747 | "overflow_y": null, 748 | "overflow_x": null, 749 | "grid_auto_flow": null, 750 | "grid_area": null, 751 | "grid_template_columns": null, 752 | "flex": null, 753 | "_model_name": "LayoutModel", 754 | "justify_items": null, 755 | "grid_row": null, 756 | "max_height": null, 757 | "align_content": null, 758 | "visibility": null, 759 | "align_self": null, 760 | "height": null, 761 | "min_height": null, 762 | "padding": null, 763 | "grid_auto_rows": null, 764 | "grid_gap": null, 765 | "max_width": null, 766 | "order": null, 767 | "_view_module_version": "1.2.0", 768 | "grid_template_areas": null, 769 | "object_position": null, 770 | "object_fit": null, 771 | "grid_auto_columns": null, 772 | "margin": null, 773 | "display": null, 774 | "left": null 775 | } 776 | }, 777 | "143a846ef031430a94ca03ce798e1c38": { 778 | "model_module": "@jupyter-widgets/controls", 779 | "model_name": "HTMLModel", 780 | "model_module_version": "1.5.0", 781 | "state": { 782 | "_view_name": "HTMLView", 783 | "style": "IPY_MODEL_63e7051bccb64b19a9eb416d6a2aa63d", 784 | "_dom_classes": [], 785 | "description": "", 786 | "_model_name": "HTMLModel", 787 | "placeholder": "​", 788 | "_view_module": "@jupyter-widgets/controls", 789 | "_model_module_version": "1.5.0", 790 | "value": "Downloading: 100%", 791 | "_view_count": null, 792 | "_view_module_version": "1.5.0", 793 | "description_tooltip": null, 794 | "_model_module": "@jupyter-widgets/controls", 795 | "layout": "IPY_MODEL_5cb83ac370b54d9bb51443cfb4b1d324" 796 | } 797 | }, 798 | "02cd0740d8394ee9978e2205bd7c0885": { 799 | "model_module": "@jupyter-widgets/controls", 800 | "model_name": "FloatProgressModel", 801 | "model_module_version": "1.5.0", 802 | "state": { 803 | "_view_name": "ProgressView", 804 | "style": "IPY_MODEL_33843555ad2e49789f8704d6abda7c00", 805 | "_dom_classes": [], 806 | "description": "", 807 | "_model_name": "FloatProgressModel", 808 | "bar_style": "success", 809 | "max": 48, 810 | "_view_module": "@jupyter-widgets/controls", 811 | "_model_module_version": "1.5.0", 812 | "value": 48, 813 | "_view_count": null, 814 | "_view_module_version": "1.5.0", 815 | "orientation": "horizontal", 816 | "min": 0, 817 | "description_tooltip": null, 818 | "_model_module": "@jupyter-widgets/controls", 819 | "layout": "IPY_MODEL_0ce9c5b524074ee9996cfb7ba2c35c83" 820 | } 821 | }, 822 | "fe2d58558248438f89639669b20d044a": { 823 | "model_module": "@jupyter-widgets/controls", 824 | "model_name": "HTMLModel", 825 | "model_module_version": "1.5.0", 826 | "state": { 827 | "_view_name": "HTMLView", 828 | "style": "IPY_MODEL_04653b5fc47244c7abe96622f28b13f7", 829 | "_dom_classes": [], 830 | "description": "", 831 | "_model_name": "HTMLModel", 832 | "placeholder": "​", 833 | "_view_module": "@jupyter-widgets/controls", 834 | "_model_module_version": "1.5.0", 835 | "value": " 48.0/48.0 [00:00<00:00, 493B/s]", 836 | "_view_count": null, 837 | "_view_module_version": "1.5.0", 838 | "description_tooltip": null, 839 | "_model_module": "@jupyter-widgets/controls", 840 | "layout": "IPY_MODEL_40325b42a719450eb4ad85b7c4ff0423" 841 | } 842 | }, 843 | "63e7051bccb64b19a9eb416d6a2aa63d": { 844 | "model_module": "@jupyter-widgets/controls", 845 | "model_name": "DescriptionStyleModel", 846 | "model_module_version": "1.5.0", 847 | "state": { 848 | "_view_name": "StyleView", 849 | "_model_name": "DescriptionStyleModel", 850 | "description_width": "", 851 | "_view_module": "@jupyter-widgets/base", 852 | "_model_module_version": "1.5.0", 853 | "_view_count": null, 854 | "_view_module_version": "1.2.0", 855 | "_model_module": "@jupyter-widgets/controls" 856 | } 857 | }, 858 | "5cb83ac370b54d9bb51443cfb4b1d324": { 859 | "model_module": "@jupyter-widgets/base", 860 | "model_name": "LayoutModel", 861 | "model_module_version": "1.2.0", 862 | "state": { 863 | "_view_name": "LayoutView", 864 | "grid_template_rows": null, 865 | "right": null, 866 | "justify_content": null, 867 | "_view_module": "@jupyter-widgets/base", 868 | "overflow": null, 869 | "_model_module_version": "1.2.0", 870 | "_view_count": null, 871 | "flex_flow": null, 872 | "width": null, 873 | "min_width": null, 874 | "border": null, 875 | "align_items": null, 876 | "bottom": null, 877 | "_model_module": "@jupyter-widgets/base", 878 | "top": null, 879 | "grid_column": null, 880 | "overflow_y": null, 881 | "overflow_x": null, 882 | "grid_auto_flow": null, 883 | "grid_area": null, 884 | "grid_template_columns": null, 885 | "flex": null, 886 | "_model_name": "LayoutModel", 887 | "justify_items": null, 888 | "grid_row": null, 889 | "max_height": null, 890 | "align_content": null, 891 | "visibility": null, 892 | "align_self": null, 893 | "height": null, 894 | "min_height": null, 895 | "padding": null, 896 | "grid_auto_rows": null, 897 | "grid_gap": null, 898 | "max_width": null, 899 | "order": null, 900 | "_view_module_version": "1.2.0", 901 | "grid_template_areas": null, 902 | "object_position": null, 903 | "object_fit": null, 904 | "grid_auto_columns": null, 905 | "margin": null, 906 | "display": null, 907 | "left": null 908 | } 909 | }, 910 | "33843555ad2e49789f8704d6abda7c00": { 911 | "model_module": "@jupyter-widgets/controls", 912 | "model_name": "ProgressStyleModel", 913 | "model_module_version": "1.5.0", 914 | "state": { 915 | "_view_name": "StyleView", 916 | "_model_name": "ProgressStyleModel", 917 | "description_width": "", 918 | "_view_module": "@jupyter-widgets/base", 919 | "_model_module_version": "1.5.0", 920 | "_view_count": null, 921 | "_view_module_version": "1.2.0", 922 | "bar_color": null, 923 | "_model_module": "@jupyter-widgets/controls" 924 | } 925 | }, 926 | "0ce9c5b524074ee9996cfb7ba2c35c83": { 927 | "model_module": "@jupyter-widgets/base", 928 | "model_name": "LayoutModel", 929 | "model_module_version": "1.2.0", 930 | "state": { 931 | "_view_name": "LayoutView", 932 | "grid_template_rows": null, 933 | "right": null, 934 | "justify_content": null, 935 | "_view_module": "@jupyter-widgets/base", 936 | "overflow": null, 937 | "_model_module_version": "1.2.0", 938 | "_view_count": null, 939 | "flex_flow": null, 940 | "width": null, 941 | "min_width": null, 942 | "border": null, 943 | "align_items": null, 944 | "bottom": null, 945 | "_model_module": "@jupyter-widgets/base", 946 | "top": null, 947 | "grid_column": null, 948 | "overflow_y": null, 949 | "overflow_x": null, 950 | "grid_auto_flow": null, 951 | "grid_area": null, 952 | "grid_template_columns": null, 953 | "flex": null, 954 | "_model_name": "LayoutModel", 955 | "justify_items": null, 956 | "grid_row": null, 957 | "max_height": null, 958 | "align_content": null, 959 | "visibility": null, 960 | "align_self": null, 961 | "height": null, 962 | "min_height": null, 963 | "padding": null, 964 | "grid_auto_rows": null, 965 | "grid_gap": null, 966 | "max_width": null, 967 | "order": null, 968 | "_view_module_version": "1.2.0", 969 | "grid_template_areas": null, 970 | "object_position": null, 971 | "object_fit": null, 972 | "grid_auto_columns": null, 973 | "margin": null, 974 | "display": null, 975 | "left": null 976 | } 977 | }, 978 | "04653b5fc47244c7abe96622f28b13f7": { 979 | "model_module": "@jupyter-widgets/controls", 980 | "model_name": "DescriptionStyleModel", 981 | "model_module_version": "1.5.0", 982 | "state": { 983 | "_view_name": "StyleView", 984 | "_model_name": "DescriptionStyleModel", 985 | "description_width": "", 986 | "_view_module": "@jupyter-widgets/base", 987 | "_model_module_version": "1.5.0", 988 | "_view_count": null, 989 | "_view_module_version": "1.2.0", 990 | "_model_module": "@jupyter-widgets/controls" 991 | } 992 | }, 993 | "40325b42a719450eb4ad85b7c4ff0423": { 994 | "model_module": "@jupyter-widgets/base", 995 | "model_name": "LayoutModel", 996 | "model_module_version": "1.2.0", 997 | "state": { 998 | "_view_name": "LayoutView", 999 | "grid_template_rows": null, 1000 | "right": null, 1001 | "justify_content": null, 1002 | "_view_module": "@jupyter-widgets/base", 1003 | "overflow": null, 1004 | "_model_module_version": "1.2.0", 1005 | "_view_count": null, 1006 | "flex_flow": null, 1007 | "width": null, 1008 | "min_width": null, 1009 | "border": null, 1010 | "align_items": null, 1011 | "bottom": null, 1012 | "_model_module": "@jupyter-widgets/base", 1013 | "top": null, 1014 | "grid_column": null, 1015 | "overflow_y": null, 1016 | "overflow_x": null, 1017 | "grid_auto_flow": null, 1018 | "grid_area": null, 1019 | "grid_template_columns": null, 1020 | "flex": null, 1021 | "_model_name": "LayoutModel", 1022 | "justify_items": null, 1023 | "grid_row": null, 1024 | "max_height": null, 1025 | "align_content": null, 1026 | "visibility": null, 1027 | "align_self": null, 1028 | "height": null, 1029 | "min_height": null, 1030 | "padding": null, 1031 | "grid_auto_rows": null, 1032 | "grid_gap": null, 1033 | "max_width": null, 1034 | "order": null, 1035 | "_view_module_version": "1.2.0", 1036 | "grid_template_areas": null, 1037 | "object_position": null, 1038 | "object_fit": null, 1039 | "grid_auto_columns": null, 1040 | "margin": null, 1041 | "display": null, 1042 | "left": null 1043 | } 1044 | }, 1045 | "a56872e4b7fd495f8ba3a1e2a63788bc": { 1046 | "model_module": "@jupyter-widgets/controls", 1047 | "model_name": "HBoxModel", 1048 | "model_module_version": "1.5.0", 1049 | "state": { 1050 | "_view_name": "HBoxView", 1051 | "_dom_classes": [], 1052 | "_model_name": "HBoxModel", 1053 | "_view_module": "@jupyter-widgets/controls", 1054 | "_model_module_version": "1.5.0", 1055 | "_view_count": null, 1056 | "_view_module_version": "1.5.0", 1057 | "box_style": "", 1058 | "layout": "IPY_MODEL_edd0e5ac02764983bcf87308769e20af", 1059 | "_model_module": "@jupyter-widgets/controls", 1060 | "children": [ 1061 | "IPY_MODEL_c847204177474e33a61f49c4db5d8ac1", 1062 | "IPY_MODEL_0d1bdbc672ee4f83b836b363863e223f", 1063 | "IPY_MODEL_ac7b1ed676894875971eadbfc22bfd07" 1064 | ] 1065 | } 1066 | }, 1067 | "edd0e5ac02764983bcf87308769e20af": { 1068 | "model_module": "@jupyter-widgets/base", 1069 | "model_name": "LayoutModel", 1070 | "model_module_version": "1.2.0", 1071 | "state": { 1072 | "_view_name": "LayoutView", 1073 | "grid_template_rows": null, 1074 | "right": null, 1075 | "justify_content": null, 1076 | "_view_module": "@jupyter-widgets/base", 1077 | "overflow": null, 1078 | "_model_module_version": "1.2.0", 1079 | "_view_count": null, 1080 | "flex_flow": null, 1081 | "width": null, 1082 | "min_width": null, 1083 | "border": null, 1084 | "align_items": null, 1085 | "bottom": null, 1086 | "_model_module": "@jupyter-widgets/base", 1087 | "top": null, 1088 | "grid_column": null, 1089 | "overflow_y": null, 1090 | "overflow_x": null, 1091 | "grid_auto_flow": null, 1092 | "grid_area": null, 1093 | "grid_template_columns": null, 1094 | "flex": null, 1095 | "_model_name": "LayoutModel", 1096 | "justify_items": null, 1097 | "grid_row": null, 1098 | "max_height": null, 1099 | "align_content": null, 1100 | "visibility": null, 1101 | "align_self": null, 1102 | "height": null, 1103 | "min_height": null, 1104 | "padding": null, 1105 | "grid_auto_rows": null, 1106 | "grid_gap": null, 1107 | "max_width": null, 1108 | "order": null, 1109 | "_view_module_version": "1.2.0", 1110 | "grid_template_areas": null, 1111 | "object_position": null, 1112 | "object_fit": null, 1113 | "grid_auto_columns": null, 1114 | "margin": null, 1115 | "display": null, 1116 | "left": null 1117 | } 1118 | }, 1119 | "c847204177474e33a61f49c4db5d8ac1": { 1120 | "model_module": "@jupyter-widgets/controls", 1121 | "model_name": "HTMLModel", 1122 | "model_module_version": "1.5.0", 1123 | "state": { 1124 | "_view_name": "HTMLView", 1125 | "style": "IPY_MODEL_9f56b91e83444c8b9272a7a469156ca9", 1126 | "_dom_classes": [], 1127 | "description": "", 1128 | "_model_name": "HTMLModel", 1129 | "placeholder": "​", 1130 | "_view_module": "@jupyter-widgets/controls", 1131 | "_model_module_version": "1.5.0", 1132 | "value": "Downloading: 100%", 1133 | "_view_count": null, 1134 | "_view_module_version": "1.5.0", 1135 | "description_tooltip": null, 1136 | "_model_module": "@jupyter-widgets/controls", 1137 | "layout": "IPY_MODEL_31f205c738f34949a55a20f282e7ce79" 1138 | } 1139 | }, 1140 | "0d1bdbc672ee4f83b836b363863e223f": { 1141 | "model_module": "@jupyter-widgets/controls", 1142 | "model_name": "FloatProgressModel", 1143 | "model_module_version": "1.5.0", 1144 | "state": { 1145 | "_view_name": "ProgressView", 1146 | "style": "IPY_MODEL_a618850b03e949c495390e4c249ae29d", 1147 | "_dom_classes": [], 1148 | "description": "", 1149 | "_model_name": "FloatProgressModel", 1150 | "bar_style": "success", 1151 | "max": 231508, 1152 | "_view_module": "@jupyter-widgets/controls", 1153 | "_model_module_version": "1.5.0", 1154 | "value": 231508, 1155 | "_view_count": null, 1156 | "_view_module_version": "1.5.0", 1157 | "orientation": "horizontal", 1158 | "min": 0, 1159 | "description_tooltip": null, 1160 | "_model_module": "@jupyter-widgets/controls", 1161 | "layout": "IPY_MODEL_4846e7b941924fd8882c82def4094005" 1162 | } 1163 | }, 1164 | "ac7b1ed676894875971eadbfc22bfd07": { 1165 | "model_module": "@jupyter-widgets/controls", 1166 | "model_name": "HTMLModel", 1167 | "model_module_version": "1.5.0", 1168 | "state": { 1169 | "_view_name": "HTMLView", 1170 | "style": "IPY_MODEL_2d2e0dc44d154729adcff6ced9f29042", 1171 | "_dom_classes": [], 1172 | "description": "", 1173 | "_model_name": "HTMLModel", 1174 | "placeholder": "​", 1175 | "_view_module": "@jupyter-widgets/controls", 1176 | "_model_module_version": "1.5.0", 1177 | "value": " 226k/226k [00:00<00:00, 1.04MB/s]", 1178 | "_view_count": null, 1179 | "_view_module_version": "1.5.0", 1180 | "description_tooltip": null, 1181 | "_model_module": "@jupyter-widgets/controls", 1182 | "layout": "IPY_MODEL_25b532f5020a4790b3ec5e4a71e1b91b" 1183 | } 1184 | }, 1185 | "9f56b91e83444c8b9272a7a469156ca9": { 1186 | "model_module": "@jupyter-widgets/controls", 1187 | "model_name": "DescriptionStyleModel", 1188 | "model_module_version": "1.5.0", 1189 | "state": { 1190 | "_view_name": "StyleView", 1191 | "_model_name": "DescriptionStyleModel", 1192 | "description_width": "", 1193 | "_view_module": "@jupyter-widgets/base", 1194 | "_model_module_version": "1.5.0", 1195 | "_view_count": null, 1196 | "_view_module_version": "1.2.0", 1197 | "_model_module": "@jupyter-widgets/controls" 1198 | } 1199 | }, 1200 | "31f205c738f34949a55a20f282e7ce79": { 1201 | "model_module": "@jupyter-widgets/base", 1202 | "model_name": "LayoutModel", 1203 | "model_module_version": "1.2.0", 1204 | "state": { 1205 | "_view_name": "LayoutView", 1206 | "grid_template_rows": null, 1207 | "right": null, 1208 | "justify_content": null, 1209 | "_view_module": "@jupyter-widgets/base", 1210 | "overflow": null, 1211 | "_model_module_version": "1.2.0", 1212 | "_view_count": null, 1213 | "flex_flow": null, 1214 | "width": null, 1215 | "min_width": null, 1216 | "border": null, 1217 | "align_items": null, 1218 | "bottom": null, 1219 | "_model_module": "@jupyter-widgets/base", 1220 | "top": null, 1221 | "grid_column": null, 1222 | "overflow_y": null, 1223 | "overflow_x": null, 1224 | "grid_auto_flow": null, 1225 | "grid_area": null, 1226 | "grid_template_columns": null, 1227 | "flex": null, 1228 | "_model_name": "LayoutModel", 1229 | "justify_items": null, 1230 | "grid_row": null, 1231 | "max_height": null, 1232 | "align_content": null, 1233 | "visibility": null, 1234 | "align_self": null, 1235 | "height": null, 1236 | "min_height": null, 1237 | "padding": null, 1238 | "grid_auto_rows": null, 1239 | "grid_gap": null, 1240 | "max_width": null, 1241 | "order": null, 1242 | "_view_module_version": "1.2.0", 1243 | "grid_template_areas": null, 1244 | "object_position": null, 1245 | "object_fit": null, 1246 | "grid_auto_columns": null, 1247 | "margin": null, 1248 | "display": null, 1249 | "left": null 1250 | } 1251 | }, 1252 | "a618850b03e949c495390e4c249ae29d": { 1253 | "model_module": "@jupyter-widgets/controls", 1254 | "model_name": "ProgressStyleModel", 1255 | "model_module_version": "1.5.0", 1256 | "state": { 1257 | "_view_name": "StyleView", 1258 | "_model_name": "ProgressStyleModel", 1259 | "description_width": "", 1260 | "_view_module": "@jupyter-widgets/base", 1261 | "_model_module_version": "1.5.0", 1262 | "_view_count": null, 1263 | "_view_module_version": "1.2.0", 1264 | "bar_color": null, 1265 | "_model_module": "@jupyter-widgets/controls" 1266 | } 1267 | }, 1268 | "4846e7b941924fd8882c82def4094005": { 1269 | "model_module": "@jupyter-widgets/base", 1270 | "model_name": "LayoutModel", 1271 | "model_module_version": "1.2.0", 1272 | "state": { 1273 | "_view_name": "LayoutView", 1274 | "grid_template_rows": null, 1275 | "right": null, 1276 | "justify_content": null, 1277 | "_view_module": "@jupyter-widgets/base", 1278 | "overflow": null, 1279 | "_model_module_version": "1.2.0", 1280 | "_view_count": null, 1281 | "flex_flow": null, 1282 | "width": null, 1283 | "min_width": null, 1284 | "border": null, 1285 | "align_items": null, 1286 | "bottom": null, 1287 | "_model_module": "@jupyter-widgets/base", 1288 | "top": null, 1289 | "grid_column": null, 1290 | "overflow_y": null, 1291 | "overflow_x": null, 1292 | "grid_auto_flow": null, 1293 | "grid_area": null, 1294 | "grid_template_columns": null, 1295 | "flex": null, 1296 | "_model_name": "LayoutModel", 1297 | "justify_items": null, 1298 | "grid_row": null, 1299 | "max_height": null, 1300 | "align_content": null, 1301 | "visibility": null, 1302 | "align_self": null, 1303 | "height": null, 1304 | "min_height": null, 1305 | "padding": null, 1306 | "grid_auto_rows": null, 1307 | "grid_gap": null, 1308 | "max_width": null, 1309 | "order": null, 1310 | "_view_module_version": "1.2.0", 1311 | "grid_template_areas": null, 1312 | "object_position": null, 1313 | "object_fit": null, 1314 | "grid_auto_columns": null, 1315 | "margin": null, 1316 | "display": null, 1317 | "left": null 1318 | } 1319 | }, 1320 | "2d2e0dc44d154729adcff6ced9f29042": { 1321 | "model_module": "@jupyter-widgets/controls", 1322 | "model_name": "DescriptionStyleModel", 1323 | "model_module_version": "1.5.0", 1324 | "state": { 1325 | "_view_name": "StyleView", 1326 | "_model_name": "DescriptionStyleModel", 1327 | "description_width": "", 1328 | "_view_module": "@jupyter-widgets/base", 1329 | "_model_module_version": "1.5.0", 1330 | "_view_count": null, 1331 | "_view_module_version": "1.2.0", 1332 | "_model_module": "@jupyter-widgets/controls" 1333 | } 1334 | }, 1335 | "25b532f5020a4790b3ec5e4a71e1b91b": { 1336 | "model_module": "@jupyter-widgets/base", 1337 | "model_name": "LayoutModel", 1338 | "model_module_version": "1.2.0", 1339 | "state": { 1340 | "_view_name": "LayoutView", 1341 | "grid_template_rows": null, 1342 | "right": null, 1343 | "justify_content": null, 1344 | "_view_module": "@jupyter-widgets/base", 1345 | "overflow": null, 1346 | "_model_module_version": "1.2.0", 1347 | "_view_count": null, 1348 | "flex_flow": null, 1349 | "width": null, 1350 | "min_width": null, 1351 | "border": null, 1352 | "align_items": null, 1353 | "bottom": null, 1354 | "_model_module": "@jupyter-widgets/base", 1355 | "top": null, 1356 | "grid_column": null, 1357 | "overflow_y": null, 1358 | "overflow_x": null, 1359 | "grid_auto_flow": null, 1360 | "grid_area": null, 1361 | "grid_template_columns": null, 1362 | "flex": null, 1363 | "_model_name": "LayoutModel", 1364 | "justify_items": null, 1365 | "grid_row": null, 1366 | "max_height": null, 1367 | "align_content": null, 1368 | "visibility": null, 1369 | "align_self": null, 1370 | "height": null, 1371 | "min_height": null, 1372 | "padding": null, 1373 | "grid_auto_rows": null, 1374 | "grid_gap": null, 1375 | "max_width": null, 1376 | "order": null, 1377 | "_view_module_version": "1.2.0", 1378 | "grid_template_areas": null, 1379 | "object_position": null, 1380 | "object_fit": null, 1381 | "grid_auto_columns": null, 1382 | "margin": null, 1383 | "display": null, 1384 | "left": null 1385 | } 1386 | } 1387 | } 1388 | } 1389 | }, 1390 | "cells": [ 1391 | { 1392 | "cell_type": "markdown", 1393 | "metadata": { 1394 | "id": "view-in-github", 1395 | "colab_type": "text" 1396 | }, 1397 | "source": [ 1398 | "\"Open" 1399 | ] 1400 | }, 1401 | { 1402 | "cell_type": "markdown", 1403 | "metadata": { 1404 | "id": "4IYErMLarJTr" 1405 | }, 1406 | "source": [ 1407 | "## Example of TreeMix" 1408 | ] 1409 | }, 1410 | { 1411 | "cell_type": "markdown", 1412 | "metadata": { 1413 | "id": "XJEVX6F9rQdI" 1414 | }, 1415 | "source": [ 1416 | "NOTE : This page was originally used by HuggingFace to illustrate the summary of various tasks ([original page](https://colab.research.google.com/github/huggingface/notebooks/blob/master/transformers_doc/pytorch/task_summary.ipynb#scrollTo=XJEVX6F9rQdI)), we use it to show the examples we illustrate in our paper. We follow the orginal settings and just change the sentence in to predict. This is a sequence classification model trained on full SST2 datasets." 1417 | ] 1418 | }, 1419 | { 1420 | "cell_type": "code", 1421 | "metadata": { 1422 | "id": "D4qYNi9oqDXN", 1423 | "colab": { 1424 | "base_uri": "https://localhost:8080/" 1425 | }, 1426 | "outputId": "0aa9ff6c-9b18-4e98-90d6-24057dbbb320" 1427 | }, 1428 | "source": [ 1429 | "# Transformers installation\n", 1430 | "! pip install transformers datasets\n", 1431 | "# To install from source instead of the last release, comment the command above and uncomment the following one.\n", 1432 | "# ! pip install git+https://github.com/huggingface/transformers.git\n" 1433 | ], 1434 | "execution_count": null, 1435 | "outputs": [ 1436 | { 1437 | "output_type": "stream", 1438 | "name": "stdout", 1439 | "text": [ 1440 | "Collecting transformers\n", 1441 | " Downloading transformers-4.12.2-py3-none-any.whl (3.1 MB)\n", 1442 | "\u001b[K |████████████████████████████████| 3.1 MB 4.9 MB/s \n", 1443 | "\u001b[?25hCollecting datasets\n", 1444 | " Downloading datasets-1.14.0-py3-none-any.whl (290 kB)\n", 1445 | "\u001b[K |████████████████████████████████| 290 kB 42.2 MB/s \n", 1446 | "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers) (3.3.0)\n", 1447 | "Collecting huggingface-hub>=0.0.17\n", 1448 | " Downloading huggingface_hub-0.0.19-py3-none-any.whl (56 kB)\n", 1449 | "\u001b[K |████████████████████████████████| 56 kB 4.1 MB/s \n", 1450 | "\u001b[?25hRequirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers) (2.23.0)\n", 1451 | "Collecting sacremoses\n", 1452 | " Downloading sacremoses-0.0.46-py3-none-any.whl (895 kB)\n", 1453 | "\u001b[K |████████████████████████████████| 895 kB 56.3 MB/s \n", 1454 | "\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (1.19.5)\n", 1455 | "Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers) (4.8.1)\n", 1456 | "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (21.0)\n", 1457 | "Collecting pyyaml>=5.1\n", 1458 | " Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)\n", 1459 | "\u001b[K |████████████████████████████████| 596 kB 57.1 MB/s \n", 1460 | "\u001b[?25hRequirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers) (4.62.3)\n", 1461 | "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (2019.12.20)\n", 1462 | "Collecting tokenizers<0.11,>=0.10.1\n", 1463 | " Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)\n", 1464 | "\u001b[K |████████████████████████████████| 3.3 MB 30.8 MB/s \n", 1465 | "\u001b[?25hRequirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from huggingface-hub>=0.0.17->transformers) (3.7.4.3)\n", 1466 | "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->transformers) (2.4.7)\n", 1467 | "Collecting xxhash\n", 1468 | " Downloading xxhash-2.0.2-cp37-cp37m-manylinux2010_x86_64.whl (243 kB)\n", 1469 | "\u001b[K |████████████████████████████████| 243 kB 57.5 MB/s \n", 1470 | "\u001b[?25hRequirement already satisfied: dill in /usr/local/lib/python3.7/dist-packages (from datasets) (0.3.4)\n", 1471 | "Requirement already satisfied: multiprocess in /usr/local/lib/python3.7/dist-packages (from datasets) (0.70.12.2)\n", 1472 | "Collecting aiohttp\n", 1473 | " Downloading aiohttp-3.8.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.1 MB)\n", 1474 | "\u001b[K |████████████████████████████████| 1.1 MB 57.7 MB/s \n", 1475 | "\u001b[?25hCollecting fsspec[http]>=2021.05.0\n", 1476 | " Downloading fsspec-2021.10.1-py3-none-any.whl (125 kB)\n", 1477 | "\u001b[K |████████████████████████████████| 125 kB 60.6 MB/s \n", 1478 | "\u001b[?25hRequirement already satisfied: pyarrow!=4.0.0,>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (3.0.0)\n", 1479 | "Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from datasets) (1.1.5)\n", 1480 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2.10)\n", 1481 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (1.24.3)\n", 1482 | "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (3.0.4)\n", 1483 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2021.5.30)\n", 1484 | "Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (2.0.7)\n", 1485 | "Collecting aiosignal>=1.1.2\n", 1486 | " Downloading aiosignal-1.2.0-py3-none-any.whl (8.2 kB)\n", 1487 | "Collecting yarl<2.0,>=1.0\n", 1488 | " Downloading yarl-1.7.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (271 kB)\n", 1489 | "\u001b[K |████████████████████████████████| 271 kB 58.0 MB/s \n", 1490 | "\u001b[?25hCollecting asynctest==0.13.0\n", 1491 | " Downloading asynctest-0.13.0-py3-none-any.whl (26 kB)\n", 1492 | "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (21.2.0)\n", 1493 | "Collecting frozenlist>=1.1.1\n", 1494 | " Downloading frozenlist-1.2.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (192 kB)\n", 1495 | "\u001b[K |████████████████████████████████| 192 kB 57.0 MB/s \n", 1496 | "\u001b[?25hCollecting multidict<7.0,>=4.5\n", 1497 | " Downloading multidict-5.2.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (160 kB)\n", 1498 | "\u001b[K |████████████████████████████████| 160 kB 57.1 MB/s \n", 1499 | "\u001b[?25hCollecting async-timeout<5.0,>=4.0.0a3\n", 1500 | " Downloading async_timeout-4.0.0a3-py3-none-any.whl (9.5 kB)\n", 1501 | "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers) (3.6.0)\n", 1502 | "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->datasets) (2.8.2)\n", 1503 | "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas->datasets) (2018.9)\n", 1504 | "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas->datasets) (1.15.0)\n", 1505 | "Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (7.1.2)\n", 1506 | "Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (1.0.1)\n", 1507 | "Installing collected packages: multidict, frozenlist, yarl, asynctest, async-timeout, aiosignal, pyyaml, fsspec, aiohttp, xxhash, tokenizers, sacremoses, huggingface-hub, transformers, datasets\n", 1508 | " Attempting uninstall: pyyaml\n", 1509 | " Found existing installation: PyYAML 3.13\n", 1510 | " Uninstalling PyYAML-3.13:\n", 1511 | " Successfully uninstalled PyYAML-3.13\n", 1512 | "Successfully installed aiohttp-3.8.0 aiosignal-1.2.0 async-timeout-4.0.0a3 asynctest-0.13.0 datasets-1.14.0 frozenlist-1.2.0 fsspec-2021.10.1 huggingface-hub-0.0.19 multidict-5.2.0 pyyaml-6.0 sacremoses-0.0.46 tokenizers-0.10.3 transformers-4.12.2 xxhash-2.0.2 yarl-1.7.0\n" 1513 | ] 1514 | } 1515 | ] 1516 | }, 1517 | { 1518 | "cell_type": "markdown", 1519 | "metadata": { 1520 | "id": "1FpvawfuqDXS" 1521 | }, 1522 | "source": [ 1523 | "## Sequence Classification" 1524 | ] 1525 | }, 1526 | { 1527 | "cell_type": "markdown", 1528 | "metadata": { 1529 | "id": "HGQEKVGZqDXS" 1530 | }, 1531 | "source": [ 1532 | "Sequence classification is the task of classifying sequences according to a given number of classes. An example of\n", 1533 | "sequence classification is the GLUE dataset, which is entirely based on that task. If you would like to fine-tune a\n", 1534 | "model on a GLUE sequence classification task, you may leverage the :prefix_link:*run_glue.py\n", 1535 | "*, :prefix_link:*run_tf_glue.py\n", 1536 | "*, :prefix_link:*run_tf_text_classification.py\n", 1537 | "* or :prefix_link:*run_xnli.py\n", 1538 | "* scripts.\n", 1539 | "\n", 1540 | "Here is an example of using pipelines to do sentiment analysis: identifying if a sequence is positive or negative. It\n", 1541 | "leverages a fine-tuned model on sst2, which is a GLUE task.\n", 1542 | "\n", 1543 | "This returns a label (\"POSITIVE\" or \"NEGATIVE\") alongside a score, as follows:" 1544 | ] 1545 | }, 1546 | { 1547 | "cell_type": "code", 1548 | "metadata": { 1549 | "id": "dcBfWoIaqDXT", 1550 | "colab": { 1551 | "base_uri": "https://localhost:8080/", 1552 | "height": 212, 1553 | "referenced_widgets": [ 1554 | "7c08442589064421b77b6aa52ce5b947", 1555 | "23f3f545c30d4226987e7909e8365413", 1556 | "3d29469f501c4a0b96c211241459f34a", 1557 | "9578e983215a4bc497b5882a4244b8a9", 1558 | "47b6aefd82c04f0fb3a3de2c7f37c58f", 1559 | "b1e066e72ca04508bf50a60961e66bea", 1560 | "96fe7d4c88e541eea33dde3889e3e381", 1561 | "65b763eef5184606b28d4673e6ad857d", 1562 | "40984bf56a8b448e82451fa60fc15f7c", 1563 | "21bf499e3fb34c02b63934ea6e4f162f", 1564 | "9566729c64304007953c365eb0341df2", 1565 | "6bf72948e9ce47d296cf9673d5c11ced", 1566 | "e32995501a0243c1ab56ee42ecddd80e", 1567 | "e1c3b3a1b0ce49cc9db2bd16be25c929", 1568 | "a1f0479b17f8451daaf69b0af81b9adb", 1569 | "5f7a3c947c564436bd2dfcc04d4da5ee", 1570 | "d0d537e5dc744a5794187c5a498358fc", 1571 | "bdbbb4175a6c4bd389755dc0f8d186cb", 1572 | "902c7115997647b9a5140ba707d74390", 1573 | "cbff745d3ebb4ba0b9bfd8b1a210550f", 1574 | "db0e50cf79174b99b7943fe9fa51b738", 1575 | "91e992390cf94d4bb23164ccbfb2a4d0", 1576 | "a66fb6293209491caf0f05bbab72c1f3", 1577 | "b226e034649d4a27968a24f0215acfa2", 1578 | "143a846ef031430a94ca03ce798e1c38", 1579 | "02cd0740d8394ee9978e2205bd7c0885", 1580 | "fe2d58558248438f89639669b20d044a", 1581 | "63e7051bccb64b19a9eb416d6a2aa63d", 1582 | "5cb83ac370b54d9bb51443cfb4b1d324", 1583 | "33843555ad2e49789f8704d6abda7c00", 1584 | "0ce9c5b524074ee9996cfb7ba2c35c83", 1585 | "04653b5fc47244c7abe96622f28b13f7", 1586 | "40325b42a719450eb4ad85b7c4ff0423", 1587 | "a56872e4b7fd495f8ba3a1e2a63788bc", 1588 | "edd0e5ac02764983bcf87308769e20af", 1589 | "c847204177474e33a61f49c4db5d8ac1", 1590 | "0d1bdbc672ee4f83b836b363863e223f", 1591 | "ac7b1ed676894875971eadbfc22bfd07", 1592 | "9f56b91e83444c8b9272a7a469156ca9", 1593 | "31f205c738f34949a55a20f282e7ce79", 1594 | "a618850b03e949c495390e4c249ae29d", 1595 | "4846e7b941924fd8882c82def4094005", 1596 | "2d2e0dc44d154729adcff6ced9f29042", 1597 | "25b532f5020a4790b3ec5e4a71e1b91b" 1598 | ] 1599 | }, 1600 | "outputId": "3ae13dd7-c7e5-45ac-eb01-48ecfa8620f3" 1601 | }, 1602 | "source": [ 1603 | "from transformers import pipeline\n", 1604 | "classifier = pipeline(\"sentiment-analysis\")\n", 1605 | "result = classifier(\"This film is good and every one loves it\")[0]\n", 1606 | "print(f\"label: {result['label']}, with score: {round(result['score'], 4)}\")\n", 1607 | "result = classifier(\"The film is poor and I do not like it\")[0]\n", 1608 | "print(f\"label: {result['label']}, with score: {round(result['score'], 4)}\")\n", 1609 | "result = classifier(\"This film is good but I do not like it\")[0]\n", 1610 | "print(f\"label: {result['label']}, with score: {round(result['score'], 4)}\")\n" 1611 | ], 1612 | "execution_count": null, 1613 | "outputs": [ 1614 | { 1615 | "output_type": "stream", 1616 | "name": "stderr", 1617 | "text": [ 1618 | "No model was supplied, defaulted to distilbert-base-uncased-finetuned-sst-2-english (https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english)\n" 1619 | ] 1620 | }, 1621 | { 1622 | "output_type": "display_data", 1623 | "data": { 1624 | "application/vnd.jupyter.widget-view+json": { 1625 | "model_id": "7c08442589064421b77b6aa52ce5b947", 1626 | "version_minor": 0, 1627 | "version_major": 2 1628 | }, 1629 | "text/plain": [ 1630 | "Downloading: 0%| | 0.00/629 [00:00