├── ACL_2022_TreeMix.pdf
├── Augmentation.py
├── LICENSE
├── README.md
├── batch_train.py
├── online_augmentation
├── __init__.py
└── __pycache__
│ ├── __init__.cpython-38.pyc
│ └── __init__.cpython-39.pyc
├── process_data
├── Load_data.py
├── __init__.py
├── __pycache__
│ ├── Augmentation.cpython-39.pyc
│ ├── Load_data.cpython-38.pyc
│ ├── Load_data.cpython-39.pyc
│ ├── __init__.cpython-38.pyc
│ ├── __init__.cpython-39.pyc
│ ├── ceshi.cpython-39.pyc
│ ├── settings.cpython-38.pyc
│ └── settings.cpython-39.pyc
├── get_data.py
└── settings.py
├── requirements.txt
├── run.py
└── transformers_doc
└── pytorch
└── task_summary.ipynb
/ACL_2022_TreeMix.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lezhang7/TreeMix/016b47df8f028d384aedeaf1605ce24aac48b9b7/ACL_2022_TreeMix.pdf
--------------------------------------------------------------------------------
/Augmentation.py:
--------------------------------------------------------------------------------
1 | import random
2 | from nltk import Tree
3 | from tqdm import tqdm
4 | import pandas as pd
5 | import argparse
6 | import numpy as np
7 | import os
8 | from datasets import load_dataset, Dataset, concatenate_datasets
9 | from process_data import settings
10 | from multiprocessing import Pool,cpu_count
11 |
12 |
def modify(commands, debug=0):
    """Insert the conjunction 'and' between adjacent SCAN command clauses.

    A new clause starts wherever a verb directly follows either a clause
    terminator ('left'/'right'/'twice'/'thrice') or another verb.

    Args:
        commands: space-separated SCAN command string.
        debug: when truthy, print the token list as positions are found
            and after each insertion.

    Returns:
        The command string with 'and' inserted between clauses.
    """
    tokens = commands.split(' ')
    verbs = ['look', 'jump', 'walk', 'turn', 'run']
    terminators = ['left', 'right', 'twice', 'thrice']
    insert_positions = []
    if debug:
        print(tokens)
    for idx, (cur, nxt) in enumerate(zip(tokens, tokens[1:])):
        # A clause boundary sits between a terminator/verb and a verb.
        if (cur in terminators and nxt in verbs) or (cur in verbs and nxt in verbs):
            insert_positions.append(idx + 1)
            if debug:
                print(tokens)
    # Earlier insertions shift later positions right, hence the offset.
    for shift, position in enumerate(insert_positions):
        tokens.insert(position + shift, 'and')
        if debug:
            print(tokens)
    return ' '.join(tokens)
36 |
37 |
def c2a(commands, debug=0):
    """Convert a SCAN command string into its action sequence (commands-to-actions).

    Tokens are consumed left to right and buffered in `previous_command`
    until a conjunction or the start of a new clause is seen; the buffered
    clause is then converted with translate(). 'and' joins clauses in
    order, while 'after' defers the current clause's actions (parked in
    `pre_actions`) until the following clause has been emitted.

    Args:
        commands: space-separated SCAN command string, e.g. 'jump and walk left'.
        debug: when truthy, print a step-by-step trace of the parse state.

    Returns:
        list of primitive action tokens (e.g. 'I_JUMP', 'I_TURN_LEFT').
    """
    # Token vocabularies for the SCAN grammar.
    verb = {'look': 'I_LOOK', 'walk': 'I_WALK',
            'run': 'I_RUN', 'jump': 'I_JUMP'}
    direction = {'left': 'I_TURN_LEFT', 'right': 'I_TURN_RIGHT'}
    times = {'twice': 2, 'thrice': 3}
    conjunction = ['and', 'after']

    commands = commands.split(' ')
    actions = []            # finished action sequence, in emission order
    previous_command = []   # tokens of the clause currently being buffered
    pre_actions = []        # actions deferred by a preceding 'after'
    flag = 0                # NOTE(review): unused; kept as-is
    i = 0                   # step counter, used for the debug trace only
    if debug:
        print('raw:', commands)
    while len(commands) > 0:
        current = commands.pop(0)
        if debug:
            print('-'*50)
            print('step ', i)
            i += 1
            print('current command:', current, len(commands))
            print('curret waiting commands list:', previous_command)
            print('already actions:', actions)
            print('previous waiting actions:', pre_actions)
        if current in verb.keys() or current == 'turn' or current in conjunction: # a new clause (or a conjunction) starts here
            if current == 'and':
                # 'and' only separates clauses; the clause flush happens
                # when the next verb arrives.
                continue
            if not previous_command: # first clause: just start buffering
                previous_command.append(current)

            else: # the buffered clause is complete; flush it
                if debug:
                    print('##### one commands over#####')
                current_action = translate(previous_command)
                previous_command = []
                if debug:
                    print(
                        '****got new action from previous commandsa list:{}****'.format(current_action[0]))

                if current == 'after':
                    # 'X after Y': X's actions wait until Y has run.
                    pre_actions.extend(current_action)
                    if debug:
                        print('****this action into pre_actions****')
                elif pre_actions:
                    # A deferred clause is pending: emit this clause, then
                    # the deferred one.
                    if debug:
                        print(
                            '****pre_actions and current_actions into action list****')
                    actions.extend(current_action)
                    actions.extend(pre_actions)
                    pre_actions = []
                    previous_command.append(current)
                else:
                    # current is a verb starting the next clause
                    previous_command.append(current)
                    actions.extend(current_action)
        else:
            # modifier token ('left', 'opposite', 'twice', ...): extend the clause
            previous_command.append(current)
    # Flush whatever clause is still buffered at end of input.
    if previous_command:
        current_action = translate(previous_command)
        actions.extend(current_action)
    # Emit any actions still deferred by a trailing 'after'.
    if pre_actions:
        actions.extend(pre_actions)
    if debug:
        print('-'*50)
        print('over')
        print('previous_command', previous_command)
        print('pre_actions', pre_actions)
        print('current action', current_action)
    return actions
108 |
109 |
def translate(previous_command):
    """Translate one SCAN clause (a list of tokens) into its action sequence.

    Supported clause shapes:
        VERB | 'turn' DIR | VERB DIR |
        'turn' ('opposite'|'around') DIR | VERB ('opposite'|'around') DIR
    with an optional trailing 'twice'/'thrice' that repeats the clause.

    Args:
        previous_command: non-empty list of clause tokens, e.g. ['walk', 'left'].

    Returns:
        list of primitive action tokens, or None for a clause that matches
        none of the patterns above (callers treat that as a failure).
    """
    verb = {'look': 'I_LOOK', 'walk': 'I_WALK',
            'run': 'I_RUN', 'jump': 'I_JUMP'}
    direction = {'left': 'I_TURN_LEFT', 'right': 'I_TURN_RIGHT'}
    times = {'twice': 2, 'thrice': 3}
    # A trailing 'twice'/'thrice' repeats the translation of the rest.
    if previous_command[-1] in times.keys():
        return translate(previous_command[:-1])*times[previous_command[-1]]
    if len(previous_command) == 1:
        return [verb[previous_command[0]]]
    elif len(previous_command) == 2:
        if previous_command[0] == 'turn':
            return [direction[previous_command[1]]]
        elif previous_command[1] in direction:
            # e.g. 'walk left': turn first, then perform the verb.
            return [direction[previous_command[1]], verb[previous_command[0]]]
    elif len(previous_command) == 3:
        if previous_command[0] == 'turn':
            if previous_command[1] == 'opposite':
                # 180 degrees = two quarter turns.
                return [direction[previous_command[2]]]*2
            else:
                # 'turn around DIR': a full 360 = four quarter turns.
                return [direction[previous_command[2]]]*4
        elif previous_command[0] in verb.keys():
            if previous_command[1] == 'opposite':
                return [direction[previous_command[2]], direction[previous_command[2]], verb[previous_command[0]]]
            else:
                # 'VERB around DIR': (turn + verb) repeated four times.
                return [direction[previous_command[2]], verb[previous_command[0]]]*4
def set_seed(seed):
    """Seed both the builtin `random` and NumPy RNGs for reproducibility."""
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
def subtree_exchange_scan(args, parsing1, parsing2):
    """TreeMix for SCAN: swap a random subtree of parsing2 into parsing1.

    Any subtree may be exchanged (no constituency-label restriction). The
    resulting command string is repaired with modify() and re-labelled by
    executing it with c2a().

    Args:
        args: namespace providing `debug` and `showinfo` flags.
        parsing1, parsing2: bracketed constituency-parse strings.

    Returns:
        (modified_sentence, new_actions) on success, None on any failure.
    """
    new_sentence = None
    try:
        if args.debug:
            print('check5')
        t1 = Tree.fromstring(parsing1)
        t2 = Tree.fromstring(parsing2)
        candidate_subtree1 = list(t1.subtrees())
        candidate_subtree2 = list(t2.subtrees())
        candidate1 = random.choice(candidate_subtree1)
        candidate2 = random.choice(candidate_subtree2)
        exchanged_span = ' '.join(candidate1.leaves())
        exchanging_span = ' '.join(candidate2.leaves())
        original_sentence = ' '.join(t1.leaves())
        # NOTE(review): str.replace substitutes every occurrence of the
        # span, not only the chosen subtree's position -- confirm intended.
        new_sentence = original_sentence.replace(exchanged_span, exchanging_span)
        debug = 0
        if args.debug:
            print('check6')
            print(new_sentence)
            debug = 1
        modified_sentence = modify(new_sentence, debug)
        new_label = c2a(modified_sentence, debug)
        if args.showinfo:
            print('cand1:', ' '.join(candidate1.leaves()),
                  'cand2:', ' '.join(candidate2.leaves()))
            print('src1:', parsing1)
            print('src2:', parsing2)
            print('new:', new_sentence)
        return modified_sentence, new_label
    except Exception as e:
        # Deliberate best-effort: any failure (malformed parse, a command
        # c2a/translate cannot execute, ...) just skips this pair.
        if args.debug:
            print('Error!!')
            print(e)
        return None
def subtree_exchange_single(args, parsing1, label1, parsing2, label2, lam1, lam2):
    """Exchange a subtree between two parsed sentences (single-sentence task).

    A candidate subtree is one whose leaf count, relative to its own
    sentence length, lies strictly between lam2 and lam1. One candidate is
    drawn from each tree (optionally constrained to share the phrase label
    and/or leaf count), the span from sentence 2 replaces the span in
    sentence 1, and the new soft label mixes label1/label2 in proportion
    to how many tokens each source contributes.

    Args:
        args: namespace with `debug`, `phrase_label`, `phrase_length`,
            `showinfo` and `label_list`.
        parsing1, parsing2: bracketed constituency-parse strings.
        label1, label2: integer class labels of the two sentences.
        lam1, lam2: upper/lower bounds on the subtree/sentence length ratio
            (requires lam1 > lam2).

    Returns:
        (new_sentence, new_label) where new_label is a probability vector
        over args.label_list, or None when no candidate pair exists.
    """
    if args.debug:
        print('check4')
    assert lam1 > lam2
    t1 = Tree.fromstring(parsing1)
    original_sentence = ' '.join(t1.leaves())
    t1_len = len(t1.leaves())
    candidate_subtree1 = list(t1.subtrees(lambda t: lam1 > len(t.leaves())/t1_len > lam2))
    t2 = Tree.fromstring(parsing2)
    t2_len = len(t2.leaves())
    # BUGFIX: t2's candidates were filtered by t1's sentence length
    # (copy-paste); normalize by t2's own length instead.
    candidate_subtree2 = list(t2.subtrees(lambda t: lam1 > len(t.leaves())/t2_len > lam2))
    if args.debug:
        print('check5')
    if len(candidate_subtree1) == 0 or len(candidate_subtree2) == 0:
        return None
    if args.debug:
        print('check6')
    if args.phrase_label:
        if args.debug:
            print('phrase_label')
        tree_labels1 = [tree.label() for tree in candidate_subtree1]
        tree_labels2 = [tree.label() for tree in candidate_subtree2]
        same_labels = list(set(tree_labels1) & set(tree_labels2))
        if not same_labels:
            # no subtree pair shares a constituency label
            return None
        if args.phrase_length:
            if args.debug:
                print('phrase_lable_length')
            # Same label AND same leaf count.
            candidate = [(t1, t2) for t1 in candidate_subtree1 for t2 in candidate_subtree2
                         if len(t1.leaves()) == len(t2.leaves()) and t1.label() == t2.label()]
            candidate1, candidate2 = random.choice(candidate)
        else:
            if args.debug:
                print('phrase_lable')
            select_label = random.choice(same_labels)
            candidate1 = random.choice([t for t in candidate_subtree1 if t.label() == select_label])
            candidate2 = random.choice([t for t in candidate_subtree2 if t.label() == select_label])
    else:
        if args.debug:
            print('no phrase_label')
        if args.phrase_length:
            if args.debug:
                print('phrase_length')
            candidate = [(t1, t2) for t1 in candidate_subtree1 for t2 in candidate_subtree2
                         if len(t1.leaves()) == len(t2.leaves())]
            candidate1, candidate2 = random.choice(candidate)
        else:
            if args.debug:
                print('normal TreeMix')
            candidate1 = random.choice(candidate_subtree1)
            candidate2 = random.choice(candidate_subtree2)

    exchanged_span = ' '.join(candidate1.leaves())
    exchanged_len = len(candidate1.leaves())
    exchanging_span = ' '.join(candidate2.leaves())
    # NOTE(review): str.replace substitutes every occurrence of the span,
    # not just the chosen subtree's position -- confirm acceptable.
    new_sentence = original_sentence.replace(exchanged_span, exchanging_span)
    # Soft label: weight each source label by its share of the new tokens.
    new_label = np.zeros(len(args.label_list))
    exchanging_len = len(candidate2.leaves())
    new_len = t1_len-exchanged_len+exchanging_len
    new_label[int(label2)] += exchanging_len/new_len
    new_label[int(label1)] += (new_len-exchanging_len)/new_len

    if args.showinfo:
        print('-'*50)
        print('candidate1:{}'.format([' '.join(x.leaves()) for x in candidate_subtree1]))
        print('candidate2:{}'.format([' '.join(x.leaves()) for x in candidate_subtree2]))
        print('sentence1 ## {} [{}]\nsentence2 ## {} [{}]'.format(original_sentence, label1, ' '.join(t2.leaves()), label2))
        print('{} <=========== {}'.format(exchanged_span, exchanging_span))
        print('new sentence: ## {}'.format(new_sentence))
        print('new label:[{}]'.format(new_label))
    return new_sentence, new_label
def subtree_exchange_pair(args, parsing11, parsing12, label1, parsing21, parsing22, label2, lam1, lam2):
    """Exchange subtrees between two sentence PAIRS (pair classification task).

    Sample 1 is (parsing11, parsing12) with label1; sample 2 is
    (parsing21, parsing22) with label2. One subtree swap is performed in
    each sentence slot (sample 2 span replaces sample 1 span), and the
    mixed soft label weights label1/label2 by the token share each sample
    contributes to the result.

    Returns:
        (new_sentence1, new_sentence2, new_label) with new_label a
        probability vector over args.label_list, or None when any slot has
        no candidate subtree (or, with --phrase_label, no shared label).
    """
    assert lam1 > lam2
    # NOTE(review): lam2 is overwritten so the ratio window is always 0.2
    # wide below lam1, ignoring the caller's lam2 -- confirm intended.
    lam2 = lam1-0.2
    t11 = Tree.fromstring(parsing11)
    t12 = Tree.fromstring(parsing12)
    original_sentence1 = ' '.join(t11.leaves())
    t11_len = len(t11.leaves())
    original_sentence2 = ' '.join(t12.leaves())
    t12_len = len(t12.leaves())
    candidate_subtree11 = list(t11.subtrees(lambda t: lam1 > len(t.leaves())/t11_len > lam2))
    candidate_subtree12 = list(t12.subtrees(lambda t: lam1 > len(t.leaves())/t12_len > lam2))
    t21 = Tree.fromstring(parsing21)
    t22 = Tree.fromstring(parsing22)
    t21_len = len(t21.leaves())
    t22_len = len(t22.leaves())
    # BUGFIX: t21/t22 candidates were filtered by t11/t12's lengths
    # (copy-paste); normalize each tree by its own sentence length.
    candidate_subtree21 = list(t21.subtrees(lambda t: lam1 > len(t.leaves())/t21_len > lam2))
    candidate_subtree22 = list(t22.subtrees(lambda t: lam1 > len(t.leaves())/t22_len > lam2))
    if args.showinfo:
        print('\n')
        print('*'*50)
        # BUGFIX (display): the last field was mislabelled 'candidate_subtree21'.
        print('t11_len:{}\tt12_len:{}\tt21_len:{}\tt22_len:{}\ncandidate_subtree11:{}\ncandidate_subtree12:{}\ncandidate_subtree21:{}\ncandidate_subtree22:{}'
              .format(t11_len, t12_len, t21_len, t22_len, candidate_subtree11, candidate_subtree12, candidate_subtree21, candidate_subtree22))

    if len(candidate_subtree11) == 0 or len(candidate_subtree12) == 0 or len(candidate_subtree21) == 0 or len(candidate_subtree22) == 0:
        return None

    if args.phrase_label:
        # Restrict each slot's swap to subtrees sharing a constituency label.
        tree_labels11 = [tree.label() for tree in candidate_subtree11]
        tree_labels12 = [tree.label() for tree in candidate_subtree12]
        tree_labels21 = [tree.label() for tree in candidate_subtree21]
        tree_labels22 = [tree.label() for tree in candidate_subtree22]
        same_labels1 = list(set(tree_labels11) & set(tree_labels21))
        same_labels2 = list(set(tree_labels12) & set(tree_labels22))
        if not (same_labels1 and same_labels2):
            # no shared subtree label in one of the slots
            return None
        select_label1 = random.choice(same_labels1)
        select_label2 = random.choice(same_labels2)
        displaced1 = random.choice([t for t in candidate_subtree11 if t.label() == select_label1])
        displacing1 = random.choice([t for t in candidate_subtree21 if t.label() == select_label1])
        displaced2 = random.choice([t for t in candidate_subtree12 if t.label() == select_label2])
        displacing2 = random.choice([t for t in candidate_subtree22 if t.label() == select_label2])
    else:
        displaced1 = random.choice(candidate_subtree11)
        displacing1 = random.choice(candidate_subtree21)
        displaced2 = random.choice(candidate_subtree12)
        displacing2 = random.choice(candidate_subtree22)

    displaced_span1 = ' '.join(displaced1.leaves())
    displaced_len1 = len(displaced1.leaves())
    displacing_span1 = ' '.join(displacing1.leaves())
    # NOTE(review): str.replace substitutes every occurrence of the span.
    new_sentence1 = original_sentence1.replace(displaced_span1, displacing_span1)

    displaced_span2 = ' '.join(displaced2.leaves())
    displaced_len2 = len(displaced2.leaves())
    displacing_span2 = ' '.join(displacing2.leaves())
    new_sentence2 = original_sentence2.replace(displaced_span2, displacing_span2)

    # Soft label: weight label2 by the fraction of tokens from sample 2.
    new_label = np.zeros(len(args.label_list))
    displacing_len1 = len(displacing1.leaves())
    displacing_len2 = len(displacing2.leaves())
    new_len = t11_len+t12_len-displaced_len1-displaced_len2+displacing_len1+displacing_len2
    displacing_len = displacing_len1+displacing_len2
    new_label[int(label2)] += displacing_len/new_len
    new_label[int(label1)] += (new_len-displacing_len)/new_len

    if args.showinfo:
        print('Before\nsentence1:{}\nsentence2:{}\nlabel1:{}\tlabel2:{}'.format(original_sentence1, original_sentence2, label1, label2))
        # BUGFIX (display): print the span, not the Tree object, for replacing2.
        print('replaced1:{} replacing1:{}\nreplaced2:{} replacing2:{}'.format(displaced_span1, displacing_span1, displaced_span2, displacing_span2))
        print('After\nsentence1:{}\nsentence2:{}\nnew_label:{}'.format(new_sentence1, new_sentence2, new_label))
        print('*'*50)

    return new_sentence1, new_sentence2, new_label
def augmentation(args, data, seed, dataset, aug_times, lam1=0.3, lam2=0.1):
    """Generate int(aug_times) * len(dataset) augmented samples.

    Pairs each row with a row from a shuffled copy of the dataset and
    applies the subtree-exchange routine matching args.data_type
    ('single_cls', 'pair_cls' or 'semantic_parsing'). Failed exchanges are
    skipped and the cursor advances until enough samples are produced.

    BUGFIXES vs. original:
    - defaults were lam1=0.1, lam2=0.3, violating the `assert lam1 > lam2`
      inside the exchange routines; they now match the CLI defaults.
    - rows with 'None' parses used `continue` without advancing `total`,
      looping forever on the same index; the cursor now advances first.

    Args:
        data: dataset name (progress-bar text only).
        seed: seed value (progress-bar text only).
        dataset: HuggingFace Dataset with 'parsing1' (and 'parsing2') columns.
        aug_times: multiplier for the number of generated samples.
        lam1, lam2: subtree-ratio bounds forwarded to the exchange routines.

    Returns:
        list of per-sample tuples as produced by the exchange routine.
    """
    generated_list = []
    if args.debug:
        print('check3')
    shuffled_dataset = dataset.shuffle()
    success = 0
    total = 0
    target = int(aug_times)*len(dataset)
    with tqdm(total=target) as bar:
        while success < target:
            idx = total % len(dataset)
            if args.fraction:
                bar.set_description('| Dataset:{:<5} | seed:{} | times:{} | fraction:{} |'.format(data, seed, aug_times, args.fraction))
            else:
                bar.set_description('| Dataset:{:<5} | seed:{} | times:{} | '.format(data, seed, aug_times))

            if args.data_type == 'single_cls':
                if args.debug:
                    print('check4')
                if 'None' not in [dataset[idx]['parsing1'], shuffled_dataset[idx]['parsing1']]:
                    aug_sample = subtree_exchange_single(
                        args, dataset[idx]['parsing1'], dataset[idx][args.label_name],
                        shuffled_dataset[idx]['parsing1'], shuffled_dataset[idx][args.label_name],
                        lam1, lam2)
                else:
                    # Skip un-parsable rows, but advance the cursor.
                    total += 1
                    continue

            elif args.data_type == 'pair_cls':
                if args.debug:
                    print('check4')
                if 'None' not in [dataset[idx]['parsing1'], dataset[idx]['parsing2'], dataset[idx][args.label_name],
                                  shuffled_dataset[idx]['parsing1'], shuffled_dataset[idx]['parsing2']]:
                    aug_sample = subtree_exchange_pair(
                        args, dataset[idx]['parsing1'], dataset[idx]['parsing2'], dataset[idx][args.label_name],
                        shuffled_dataset[idx]['parsing1'], shuffled_dataset[idx]['parsing2'], shuffled_dataset[idx][args.label_name],
                        lam1, lam2)
                else:
                    total += 1
                    continue

            elif args.data_type == 'semantic_parsing':
                if args.debug:
                    print('check4')
                aug_sample = subtree_exchange_scan(
                    args, dataset[idx]['parsing1'],
                    shuffled_dataset[idx]['parsing1'])

            if args.debug:
                print('ok')
                print('got one aug_sample : {}'.format(aug_sample))
            if aug_sample:
                bar.update(1)
                success += 1
                generated_list.append(aug_sample)
            else:
                if args.debug:
                    print('fail this time')
            total += 1
    return generated_list
def parse_argument():
    """Parse the command-line options for augmentation generation.

    Returns:
        argparse.Namespace with all options; `proc` falls back to the
        machine's CPU count when not supplied (or given as 0).
    """
    parser = argparse.ArgumentParser()
    option_table = [
        (['--lam1'], dict(type=float, default=0.3)),
        (['--lam2'], dict(type=float, default=0.1)),
        (['--times'], dict(default=[2, 5], nargs='+', help='augmentation times list')),
        (['--min_token'], dict(type=int, default=0, help='minimum token numbers of augmentation samples')),
        (['--label_name'], dict(type=str, default='label')),
        (['--phrase_label'], dict(action='store_true', help='subtree lable must be same if set')),
        (['--phrase_length'], dict(action='store_true', help='subtree phrase must be same length if set')),
        (['--seeds'], dict(default=[0, 1, 2, 3, 4], nargs='+', help='seed list')),
        (['--showinfo'], dict(action='store_true')),
        (['--mixup_cross'], dict(action='store_false', help="NO mix across different classes if set")),
        (['--low_resource'], dict(action='store_true', help="create low source raw and aug datasets if set")),
        (['--debug'], dict(action='store_true', help="display debug information")),
        (['--data'], dict(nargs='+', required=True, help='data list')),
        (['--proc'], dict(type=int, help='processing number for multiprocessing')),
    ]
    for flags, options in option_table:
        parser.add_argument(*flags, **options)
    parsed = parser.parse_args()
    if not parsed.proc:
        parsed.proc = cpu_count()
    return parsed
def create_aug_data(args, dataset, data, seed, times, test_dataset=None):
    """Generate one augmentation dataset for (data, seed, times) and save it.

    The output path prefix encodes the augmentation settings; generation
    is skipped when a matching artifact already exists in args.output_dir.
    Classification output is saved in HuggingFace on-disk format; SCAN
    ('semantic_parsing') output is merged with the raw train split and
    written as CSV alongside the given test split. Relies on the
    module-level `tasksettings` for per-task column names.

    Args:
        dataset: (possibly pre-filtered) training Dataset to augment.
        data: dataset name key into tasksettings.task_to_keys.
        seed, times: seed and size multiplier, encoded in the output name.
        test_dataset: test split, required only for 'semantic_parsing'.
    """
    # Choose an output prefix that encodes the augmentation settings.
    if args.phrase_label and not args.phrase_length:
        prefix_save_path = os.path.join(args.output_dir, 'samephraselabel_times{}_min{}_seed{}_{}_{}'.format(times, args.min_token, seed, args.lam1, args.lam2))
    elif args.phrase_length and not args.phrase_label:
        prefix_save_path = os.path.join(args.output_dir, 'samephraselength_times{}_min{}_seed{}_{}_{}'.format(times, args.min_token, seed, args.lam1, args.lam2))
    elif args.phrase_length and args.phrase_label:
        prefix_save_path = os.path.join(args.output_dir, 'samephraselabel_length_times{}_min{}_seed{}_{}_{}'.format(times, args.min_token, seed, args.lam1, args.lam2))
    elif not args.mixup_cross:
        prefix_save_path = os.path.join(args.output_dir, 'sameclass_times{}_min{}_seed{}_{}_{}'.format(times, args.min_token, seed, args.lam1, args.lam2))
    elif args.data_type == 'semantic_parsing':
        prefix_save_path = os.path.join(args.output_dir, 'scan_times{}_seed{}'.format(
            times, seed))
    else:
        prefix_save_path = os.path.join(args.output_dir, 'times{}_min{}_seed{}_{}_{}'.format(times, args.min_token, seed, args.lam1, args.lam2))
    if args.debug:
        print('check1')
    # BUGFIX: os.listdir yields bare file names while prefix_save_path is
    # a full path, so the original startswith() check never matched and
    # already-generated datasets were always regenerated.
    prefix_name = os.path.basename(prefix_save_path)
    if not [file_name for file_name in os.listdir(args.output_dir) if file_name.startswith(prefix_name)]:
        if args.min_token:
            # Drop samples whose sentence(s) are too short to mix.
            dataset = dataset.filter(lambda sample: len(sample[tasksettings.task_to_keys[data][0]].split(' ')) > args.min_token)
            if tasksettings.task_to_keys[data][1]:
                dataset = dataset.filter(lambda sample: len(sample[tasksettings.task_to_keys[data][1]].split(' ')) > args.min_token)
        if args.data_type == 'single_cls':
            if args.debug:
                print('check2')
            if args.mixup_cross:
                new_pd = pd.DataFrame(augmentation(args, data, seed, dataset, times, args.lam1, args.lam2), columns=[tasksettings.task_to_keys[data][0], args.label_name])
            else:
                # Mix only within each class.
                if args.debug:
                    print('label_list', args.label_list)
                new_pd = None
                for i in args.label_list:
                    samples = dataset.filter(lambda sample: sample[args.label_name] == i)
                    dataframe = pd.DataFrame(augmentation(args, data, seed, samples, times, args.lam1, args.lam2), columns=[tasksettings.task_to_keys[data][0], args.label_name])
                    new_pd = pd.concat([new_pd, dataframe], axis=0)
        elif args.data_type == 'pair_cls':
            if args.debug:
                print('check2')
            if args.mixup_cross:
                new_pd = pd.DataFrame(augmentation(args, data, seed, dataset, times, args.lam1, args.lam2), columns=[tasksettings.task_to_keys[data][0], tasksettings.task_to_keys[data][1], args.label_name])
            else:
                new_pd = None
                if args.debug:
                    print('label_list', args.label_list)
                for i in args.label_list:
                    samples = dataset.filter(lambda sample: sample[args.label_name] == i)
                    dataframe = pd.DataFrame(augmentation(args, data, seed, samples, times, args.lam1, args.lam2), columns=[tasksettings.task_to_keys[data][0], tasksettings.task_to_keys[data][1], args.label_name])
                    new_pd = pd.concat([new_pd, dataframe], axis=0)
        elif args.data_type == 'semantic_parsing':
            if args.debug:
                print('check2')
            new_pd = pd.DataFrame(augmentation(args, data, seed, dataset, times), columns=[tasksettings.task_to_keys[data][0], args.label_name])

        # Shuffle the generated samples.
        new_pd = new_pd.sample(frac=1)

        if args.data_type == 'semantic_parsing':
            # SCAN: merge the generated samples with the raw train split.
            train_pd = pd.read_csv('DATA/ADDPRIM_JUMP/data/train.csv')
            frames = [train_pd, new_pd]
            aug_dataset = pd.concat(frames, ignore_index=True)
        else:
            aug_dataset = Dataset.from_pandas(new_pd)
            aug_dataset = aug_dataset.remove_columns("__index_level_0__")

        # NOTE(review): unlike the prefix chain above, this first branch
        # does not exclude --phrase_length, so the combined
        # 'samephraselabel_length' name is never used for save_path.
        if args.phrase_label:
            save_path = os.path.join(args.output_dir, 'samephraselabel_times{}_min{}_seed{}_{}_{}_{}k'.format(
                times, args.min_token, seed, args.lam1, args.lam2, round(len(new_pd)//1000, -1)))
        elif args.phrase_length and not args.phrase_label:
            save_path = os.path.join(args.output_dir, 'samephraselength_times{}_min{}_seed{}_{}_{}_{}k'.format(times, args.min_token, seed, args.lam1, args.lam2, round(len(new_pd)//1000, -1)))
        elif args.phrase_length and args.phrase_label:
            save_path = os.path.join(args.output_dir, 'samephraselabel_length_times{}_min{}_seed{}_{}_{}_{}k'.format(times, args.min_token, seed, args.lam1, args.lam2, round(len(new_pd)//1000, -1)))
        elif not args.mixup_cross:
            save_path = os.path.join(args.output_dir, 'sameclass_times{}_min{}_seed{}_{}_{}_{}k'.format(times, args.min_token, seed, args.lam1, args.lam2, round(len(new_pd)//1000, -1)))
        elif args.data_type == 'semantic_parsing':
            save_path_train = os.path.join(prefix_save_path, 'train.csv')
            save_path_test = os.path.join(prefix_save_path, 'test.csv')
        else:
            save_path = os.path.join(args.output_dir, 'times{}_min{}_seed{}_{}_{}_{}k'.format(times, args.min_token, seed, args.lam1, args.lam2, round(len(new_pd)//1000, -1)))
        if args.data_type == 'semantic_parsing':
            if not os.path.exists(prefix_save_path):
                os.makedirs(prefix_save_path)
            aug_dataset.to_csv(save_path_train, index=0)
            test_dataset.to_csv(save_path_test, index=0)
        else:
            aug_dataset.save_to_disk(save_path)
    else:
        print('file {} already exsits!'.format(prefix_save_path))
548 |
549 |
def main():
    """Launch augmentation jobs for every dataset/seed/times combination.

    Reads the module-level `args` and `tasksettings` globals, loads each
    dataset's parsed train split, and fans create_aug_data() calls out
    over a multiprocessing pool.
    """
    p = Pool(args.proc)
    for data in args.data:
        path_dir = os.path.join('DATA', data.upper())
        # BUGFIX: `testset` was only assigned on the semantic_parsing path,
        # so other dataset types hit a NameError (or reused a stale value)
        # when it was forwarded to create_aug_data below.
        testset = None
        if data in tasksettings.pair_datasets:
            args.data_type = 'pair_cls'
        elif data in tasksettings.SCAN:
            args.label_name = 'actions'
            args.data_type = 'semantic_parsing'
            testset_path = os.path.join(path_dir, 'data', 'test.csv')
        else:
            args.data_type = 'single_cls'
        if data == 'trec':
            # trec ships two label granularities; the caller must pick one.
            try:
                assert args.label_name in ['label-fine', 'label-coarse']
            except AssertionError:
                raise(AssertionError(
                    "If you want to train on trec dataset with augmentation, you have to name the label of split in ['label-fine', 'label-coarse']"))

            print(args.label_name, data)
            args.output_dir = os.path.join(path_dir, 'generated/{}'.format(args.label_name))
        else:
            args.output_dir = os.path.join(path_dir, 'generated')
        args.data_path = os.path.join(path_dir, 'data', 'train_parsing.csv')

        dataset = load_dataset('csv', data_files=[args.data_path], split='train')
        if args.data_type == 'semantic_parsing':
            testset = load_dataset('csv', data_files=[testset_path], split='train')
        if args.data_type in ['single_cls', 'pair_cls']:
            # Collect every label value present, for the soft-label vectors.
            args.label_list = list(set(dataset[args.label_name]))
        for seed in args.seeds:
            seed = int(seed)
            set_seed(seed)
            dataset = dataset.shuffle()
            if args.low_resource:
                # Sample fractions of the train set and augment each subset.
                for fraction in tasksettings.low_resource[data]:
                    args.fraction = fraction
                    train_dataset = dataset.select(random.sample(range(len(dataset)), int(fraction*len(dataset))))
                    low_resource_dir = os.path.join(path_dir, 'low_resource', 'low_resource_{}'.format(fraction), 'seed_{}'.format(seed))
                    if not os.path.exists(low_resource_dir):
                        os.makedirs(low_resource_dir)
                    args.output_dir = low_resource_dir
                    train_path = os.path.join(args.output_dir, 'partial_train')
                    if not os.path.exists(train_path):
                        train_dataset.save_to_disk(train_path)
                    for times in args.times:
                        times = int(times)
                        p.apply_async(create_aug_data, args=(
                            args, train_dataset, data, seed, times))
            else:
                args.fraction = None
                for times in args.times:
                    times = int(times)
                    p.apply_async(create_aug_data, args=(
                        args, dataset, data, seed, times, testset))
        print('='*20, 'Start generating augmentation datsets !', "="*20)

    p.close()
    p.join()
    print('='*20, 'Augmenatation done !', "="*20)
# Script entry point. `tasksettings` and `args` are intentionally
# module-level globals: they are read by main() and create_aug_data().
if __name__=='__main__':

    # Per-task metadata (column names, dataset groups) from process_data.settings.
    tasksettings=settings.TaskSettings()
    args=parse_argument()
    print(args)
    main()
618 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Le Zhang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TreeMix: Compositional Constituency-based Data Augmentation for Natural Language Understanding (NAACL 2022)
2 |
3 | Pytorch Implementation of [TreeMix](https://arxiv.org/abs/2205.06153)
4 |
5 | 
6 |
7 | ## Abstract
8 |
9 | Data augmentation is an effective approach to tackle over-fitting. Many previous works have proposed different data augmentation strategies for NLP, such as noise injection, word replacement, back-translation, etc. Though effective, they missed one important characteristic of language — compositionality: the meaning of a complex expression is built from its subparts. Motivated by this, we propose a compositional data augmentation approach for natural language understanding called TreeMix. Specifically, TreeMix leverages constituency parsing trees to decompose sentences into constituent sub-structures and the Mixup data augmentation technique to recombine them to generate new sentences. Compared with previous approaches, TreeMix introduces greater diversity to the samples generated and encourages models to learn the compositionality of NLP data. Extensive experiments on text classification and semantic parsing benchmarks demonstrate that TreeMix outperforms current state-of-the-art data augmentation methods.
10 |
11 | ## Code Structure
12 |
13 | ```
14 | |__ DATA
15 | |__ SST2
16 | |__ data
17 | |__ train.csv --> raw train dataset
18 | |__ test.csv --> raw test dataset
19 |             |__ train_parsing --> constituency parsing results
20 | |__ generated
21 | |__ times2_min0_seed0_0.3_0.1_7k --> augmentation dataset(hugging face dataset format) with 2 times bigger, seed 0, lambda_L=0.1, lambda_U=0.3, total size=7k
22 | |__ logs --> best results log
23 | |__ runs --> tensorboard results
24 | |__ aug --> augmentation only baseline
25 | |__ raw --> standard baseline
26 | |__ raw_aug --> TreeMix results
27 | |__ checkpoints
28 | |__ best.pt --> Best model checkpoints
29 |     |__ process_data / --> Download & constituency parsing using the Stanford CoreNLP toolkit
30 | |__ Load_data.py --> Loading raw dataset and augmentation dataset
31 | |__ get_data.py --> Download all dataset from huggingface dataset and perform constituency parsing to obtain processed dataset
32 | |__ settings.py --> Hyperparameter settings & Task settings
33 | |__ online_augmentation
34 | |__ __init__.py --> Random Mixup
35 |     |__ Augmentation.py --> Subtree Exchange augmentation method based on constituency parsing results for all datasets (single sentence classification, sentence relation classification, SCAN dataset)
36 | |__ run.py --> Train for one dataset
37 | |__ batch_train.py -> Train with different datasets and different settings, specified by giving specific arguments
38 | ```
39 |
40 | ### Getting Started
41 |
42 | ```
43 | pip install -r requirements.txt
44 | ```
45 |
46 | Note that to successfully run TreeMix, you must install `stanfordcorenlp`. Please refer to this
47 |
48 | [corenlp]( https://stanfordnlp.github.io/CoreNLP/download.html "stanfordcorenlp") for more information.
49 |
50 | ### Download & Constituency Parsing
51 |
52 | ```
53 | cd process_data
54 | python get_data.py --data {DATA} --corenlp_dir {CORENLP}
55 | ```
56 |
57 | `DATA` indicates the dataset name, `CORENLP` indicates the directory of `stanfordcorenlp` . After this process, you could get corresponding data folder in `DATA/` and `train_parsing.csv`.
58 |
59 | ### TreeMix Augmentation
60 |
61 | ```
62 | python Augmentation.py --data {DATASET} --times {TIMES} --lam1 {LAM1} --lam2 {LAM2} --seeds {SEEDS}
63 | ```
64 |
65 | Augmentation with different arguments will generate #(TIMES)×#(SEEDS) extra datasets. `DATASET` could be a **list of data names** such as 'sst2 rte'; since 'trec' has two versions, you need to input `--label_name {}` to specify whether to use the trec-fine or trec-coarse set. Besides, by typing `--low_resource`, it will generate a partial augmentation dataset as well as a partial train set. You can modify the hyperparameters `lambda_U` and `lambda_L` by changing `LAM1` and `LAM2`. `TIMES` could be **a list of integers** such as 2,5 to assign the size of the augmentation dataset.
66 |
67 | ### Model Training
68 |
69 | ```
70 | python batch_train.py --mode {MODE} --data {DATASET}
71 | ```
72 |
73 | Evaluate one dataset on all its augmentation sets with a specific mode. `MODE` can be 'raw', 'aug', 'raw_aug', which indicates training the model with the raw dataset only, the augmentation dataset only, or the combination of the raw and augmentation sets respectively. `DATASET` should be **one specific dataset** name. This will report all results of a specific dataset in `/log`. If not specified, the hyperparameters will be set as in `/process_data/settings.py`; please look into this file for more argument information.
74 |
--------------------------------------------------------------------------------
/batch_train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from process_data.settings import TaskSettings
def parse_argument():
    """Parse the batch-training command line and derive augmentation paths.

    Returns:
        argparse.Namespace with, in addition to the raw CLI options, a
        resolved ``aug_dir`` (and, for ``--low_resource``, the path to the
        low-resource splits in ``args.low_resource``).

    Raises:
        AssertionError: for TREC without a valid ``--label_name``, or when
            ``--low_resource`` is given but no low-resource data exists.
    """
    parser = argparse.ArgumentParser(description='download and parsing datasets')
    parser.add_argument('--data',type=str,required=True,help='data list')
    parser.add_argument('--aug_dir',help='Augmentation file directory')
    parser.add_argument('--seeds',default=[0,1,2,3,4],nargs='+',help='seed list')
    # NOTE(review): help text says 'seed list' but this is the list of training
    # modes ('raw'/'aug'/'raw_aug') — looks like a copy-paste slip.
    parser.add_argument('--modes',nargs='+',required=True,help='seed list')
    parser.add_argument('--label_name',type=str,default='label')
    # parser.add_argument('--batch_size',default=128,type=int,help='train examples in each batch')
    # parser.add_argument('--aug_batch_size',default=128,type=int,help='train examples in each batch')
    parser.add_argument('--random_mix',type=str,choices=['zero_one','zero','one','all'],help="random mixup ")
    parser.add_argument('--prefix',type=str,help="only choosing the datasets with the prefix,for ablation study")
    parser.add_argument('--GPU',type=int,default=0,help="available GPU number")
    parser.add_argument('--low_resource', action='store_true',
                        help='whther to train low resource dataset')

    args=parser.parse_args()
    if args.data=='trec':
        # TREC has two label granularities; the augmentation directory is
        # keyed by the chosen one.
        try:
            assert args.label_name in ['label-fine','label-coarse']
        except AssertionError:
            raise( AssertionError("If you want to train on TREC dataset with augmentation, you have to name the label of split either 'label-fine' or 'label-coarse'"))
        # NOTE(review): this overrides any user-supplied --aug_dir for TREC —
        # confirm that is intended.
        args.aug_dir = os.path.join('DATA', args.data.upper(), 'generated',args.label_name)
    if args.aug_dir is None :
        # default layout produced by process_data/get_data.py
        args.aug_dir=os.path.join('DATA',args.data.upper(),'generated')

    if 'aug' in args.modes:
        # augmentation files are named 'times...' by Augmentation.py
        try:
            assert [file for file in os.listdir(args.aug_dir) if 'times' in file]
        except AssertionError:
            raise( AssertionError( "{}".format('This directory has no augmentation file, please input correct aug_dir!') ) )
    if args.low_resource:
        # repurpose the boolean flag as the path of the low-resource splits
        try:
            args.low_resource = os.path.join('DATA', args.data.upper(),'low_resource')
            assert os.path.exists(args.low_resource)
        except AssertionError:
            raise( AssertionError("There is no any low resource datasets in this data"))

    return args
def batch_train(args):
    """Launch run.py once per (seed, mode[, augmentation file]) combination.

    Each training run is spawned as a separate process via os.system so that
    GPU memory is fully released between runs.  Per-dataset hyperparameters
    come from the module-level ``settings`` dict (set in ``__main__``).
    """
    for seed in args.seeds:
        # for aug_file in os.listdir(args.aug_dir):
        for mode in args.modes:
            if mode=='raw':
                # 'raw' trains on the original dataset only; no data_path needed.
                # data_path=os.path.join(args.aug_dir,aug_file)
                if args.random_mix:
                    os.system('CUDA_VISIBLE_DEVICES={} python run.py --label_name {} --mode {} --seed {} --data {} --random_mix {} --epoch {epoch} --batch_size {batch_size} --aug_batch_size {aug_batch_size} --val_steps {val_steps} --max_length {max_length} --augweight {augweight} '.format(args.GPU,args.label_name,mode,int(seed),args.data,args.random_mix,**settings[args.data]))
                else:
                    os.system('CUDA_VISIBLE_DEVICES={} python run.py --label_name {} --mode {} --seed {} --data {} --epoch {epoch} --batch_size {batch_size} --aug_batch_size {aug_batch_size} --val_steps {val_steps} --max_length {max_length} --augweight {augweight} '.format(args.GPU,args.label_name,mode,int(seed),args.data,**settings[args.data]))
            else:
                # 'aug'/'raw_aug' need one run per generated augmentation file
                for aug_file in os.listdir(args.aug_dir):
                    if args.prefix:
                        # only train on file with prefix
                        if aug_file.startswith(args.prefix):
                            aug_file_path = os.path.join(
                                args.aug_dir, aug_file)
                            assert os.path.exists(aug_file_path)
                            os.system('CUDA_VISIBLE_DEVICES={} python run.py --label_name {} --mode {} --seed {} --data {} --data_path {} --epoch {epoch} --batch_size {batch_size} --aug_batch_size {aug_batch_size} --val_steps {val_steps} --max_length {max_length} --augweight {augweight} '.format(
                                args.GPU, args.label_name, mode, int(seed), args.data, aug_file_path, **settings[args.data]))
                    else:
                        # train on every file in dir
                        aug_file_path = os.path.join(
                            args.aug_dir, aug_file)
                        assert os.path.exists(aug_file_path)
                        os.system('CUDA_VISIBLE_DEVICES={} python run.py --label_name {} --mode {} --seed {} --data {} --data_path {} --epoch {epoch} --batch_size {batch_size} --aug_batch_size {aug_batch_size} --val_steps {val_steps} --max_length {max_length} --augweight {augweight} '.format(
                            args.GPU, args.label_name, mode, int(seed), args.data, aug_file_path, **settings[args.data]))
def low_resource_train(args):
    """Launch run.py for every low-resource partial split / seed / mode.

    Directory layout (produced elsewhere): ``args.low_resource`` contains one
    folder per partial-split ratio, each holding per-seed folders named like
    ``seed_<k>``; augmentation files inside those start with 'times'.
    Results go to ``args.low_resource_dir/<partial_split>``.
    """
    for partial_split in os.listdir(args.low_resource):
        partial_split_path=os.path.join(args.low_resource,partial_split)
        args.output_dir = os.path.join(
            args.low_resource_dir, partial_split)
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
        for seed_num in os.listdir(partial_split_path):
            partial_split_seed_path=os.path.join(partial_split_path,seed_num)
            for mode in args.modes:
                if mode=='raw':
                    if args.random_mix:
                        os.system('CUDA_VISIBLE_DEVICES={} python run.py --low_resource_dir {} --seed {} --output_dir {} --label_name {} --mode {} --data {} --random_mix {} --epoch {epoch} --batch_size {batch_size} --aug_batch_size {aug_batch_size} --val_steps {val_steps} --max_length {max_length} --augweight {augweight} '
                                  .format(args.GPU, partial_split_seed_path, int(seed_num.split('_')[1]), args.output_dir, args.label_name, mode, args.data, args.random_mix, **settings[args.data]))
                    else:
                        os.system('CUDA_VISIBLE_DEVICES={} python run.py --low_resource_dir {} --seed {} --output_dir {} --label_name {} --mode {} --data {} --epoch {epoch} --batch_size {batch_size} --aug_batch_size {aug_batch_size} --val_steps {val_steps} --max_length {max_length} --augweight {augweight} '
                                  .format(args.GPU, partial_split_seed_path, int(seed_num.split('_')[1]),args.output_dir, args.label_name, mode, args.data, **settings[args.data]))
                elif mode=='raw_aug':
                    # the seed id is encoded in the folder name, e.g. 'seed_3'
                    for aug_file in [file for file in os.listdir(partial_split_seed_path) if file.startswith('times')]:
                        aug_file_path=os.path.join(partial_split_seed_path,aug_file)
                        assert os.path.exists(aug_file_path)
                        os.system('CUDA_VISIBLE_DEVICES={} python run.py --low_resource_dir {} --seed {} --output_dir {} --label_name {} --mode {} --data {} --data_path {} --epoch {epoch} --batch_size {batch_size} --aug_batch_size {aug_batch_size} --val_steps {val_steps} --max_length {max_length} --augweight {augweight} '.format(
                            args.GPU, partial_split_seed_path, int(seed_num.split('_')[1]) , args.output_dir, args.label_name, mode, args.data, aug_file_path, **settings[args.data]))
if __name__=='__main__':
    args=parse_argument()
    # per-dataset hyperparameters, read as a module-level global by
    # batch_train() and low_resource_train()
    tasksettings=TaskSettings()
    settings=tasksettings.train_settings
    if args.low_resource:
        args.low_resource_dir=os.path.join('DATA',args.data.upper(),'runs','low_resource')
        if not os.path.exists(args.low_resource_dir):
            os.makedirs(args.low_resource_dir)
        low_resource_train(args)
    else:
        batch_train(args)
--------------------------------------------------------------------------------
/online_augmentation/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
def random_mixup_process(args,ids1,lam):
    """Swap a random token span between each sample and a random partner.

    Args:
        args: namespace; only ``args.difflen`` is read (variable-length swap).
        ids1: LongTensor (batch, seq_len) of BERT-style token ids
            (101=[CLS], 102=[SEP], 0=[PAD]) — mutated IN PLACE.
        lam: float ratio; span length is ``int(token_count * lam)``.

    Returns:
        (ids1, rand_index): the mixed ids (same tensor object) and the
        permutation giving each sample's partner index.
    """

    rand_index=torch.randperm(ids1.shape[0])
    lenlist=[]
    # rand_index=torch.randperm(len(ids1))
    for x in ids1:
        # count real tokens, excluding [CLS]=101, [PAD]=0, [SEP]=102
        mask=((x!=101)&(x!=0)&(x!=102))
        lenlist.append(int(mask.sum()))
    # partner lengths, aligned to each sample via the permutation
    lenlist2=torch.tensor(lenlist)[rand_index]
    spanlen=torch.tensor([int(x*lam) for x in lenlist])

    # span start offsets; +1 skips the [CLS] token
    beginlist=[1+random.randint(0,x-int(x*lam)) for x in lenlist]
    beginlist2=[1+random.randint(0,x-y) for x,y in zip(lenlist2,spanlen)]
    if args.difflen:
        # donor span length is proportional to the PARTNER's sentence length
        spanlen2=torch.tensor([int(x*lam) for x in lenlist2])
        # beginlist2=[1+random.randint(0,x-y) for x,y in zip(lenlist2,spanlen2)]
        spanlist2=[(x,int(y)) for x,y in zip(beginlist2,spanlen2)]
    else:
        # beginlist2=[1+random.randint(0,x-y) for x,y in zip(lenlist2,spanlen)]
        spanlist2=[(x,int(y)) for x,y in zip(beginlist2,spanlen)]
    spanlist=[(x,int(y)) for x,y in zip(beginlist,spanlen)]

    # snapshot of the originals so partners donate un-mixed spans
    ids2=ids1.clone()
    if args.difflen:
        # splice in a possibly different-sized span, truncate to seq_len,
        # then right-pad with zeros back to the original width
        for idx in range(len(ids1)):
            tmp=torch.cat((ids1[idx][:spanlist[idx][0]],ids2[rand_index[idx]][spanlist2[idx][0]:spanlist2[idx][0]+spanlist2[idx][1]],ids1[idx][spanlist[idx][0]+spanlist[idx][1]:]),dim=0)[:ids1.shape[1]]
            ids1[idx]=torch.cat((tmp,torch.zeros(ids1.shape[1]-len(tmp))))
    else:
        # equal-length spans: overwrite in place
        for idx in range(len(ids1)):
            ids1[idx][spanlist[idx][0]:spanlist[idx][0]+spanlist[idx][1]]=ids2[rand_index[idx]][spanlist2[idx][0]:spanlist2[idx][0]+spanlist2[idx][1]]
    assert ids1.shape==ids2.shape
    return ids1,rand_index
37 | def mixup_01(args,input_ids,lam,idx1,idx2):
38 | '''
39 | 01交换
40 | '''
41 | difflen=False
42 | random_index=torch.zeros(len(idx1)+len(idx2)).long()
43 | random_index[idx1]=torch.tensor(np.random.choice(idx2,size=len(idx1)))
44 | random_index[idx2]=torch.tensor(np.random.choice(idx1,size=len(idx2)))
45 |
46 | len_list1=[]
47 | len_list2=[]
48 | for input_id1 in input_ids:
49 | #计算各个句子的具体token数
50 | mask=((input_id1!=101)&(input_id1!=0)&(input_id1!=102))
51 | len_list1.append(int(mask.sum()))
52 | # print(len_list1)
53 | len_list2=torch.tensor(len_list1)[random_index]
54 |
55 | spanlen=torch.tensor([int(x*lam) for x in len_list1])
56 | beginlist=[1+random.randint(0,x-int(x*lam)) for x in len_list1]
57 | beginlist2=[1+random.randint(0,x-y) for x,y in zip(len_list2,spanlen)]
58 | if difflen:
59 | spanlen2=torch.tensor([int(x*lam) for x in len_list2])
60 | # beginlist2=[1+random.randint(0,x-y) for x,y in zip(lenlist2,spanlen2)]
61 | spanlist2=[(x,int(y)) for x,y in zip(beginlist2,spanlen2)]
62 | else:
63 | # beginlist2=[1+random.randint(0,x-y) for x,y in zip(lenlist2,spanlen)]
64 | spanlist2=[(x,int(y)) for x,y in zip(beginlist2,spanlen)]
65 | spanlist=[(x,int(y)) for x,y in zip(beginlist,spanlen)]
66 | new_ids=input_ids.clone()
67 |
68 | # print(random_index)
69 | if difflen:
70 | for idx in range(len(ids1)):
71 | tmp=torch.cat((ids1[idx][:spanlist[idx][0]],ids2[rand_index[idx]][spanlist2[idx][0]:spanlist2[idx][0]+spanlist2[idx][1]],ids1[idx][spanlist[idx][0]+spanlist[idx][1]:]),dim=0)[:ids1.shape[1]]
72 | ids1[idx]=torch.cat((tmp,torch.zeros(ids1.shape[1]-len(tmp))))
73 | else:
74 | for idx in range(len(input_ids)):
75 | new_ids[idx][spanlist[idx][0]:spanlist[idx][0]+spanlist[idx][1]]=input_ids[random_index[idx]][spanlist2[idx][0]:spanlist2[idx][0]+spanlist2[idx][1]]
76 | # for i in range(len(input_ids)):
77 | # print('{}:交换的是{}与{},其中第1句选取的是从{}开始到{}的句子,第2句选取的是从{}开始到{}结束的句子'.format(
78 | # i,i,random_index[i],spanlist[i][0],spanlist[i][0]+spanlist[i][1],spanlist2[i][0],spanlist2[i][0]+spanlist2[i][1]))
79 | return new_ids,random_index
80 | def mixup(args,input_ids,lam,idx1,idx2=None):
81 | '''
82 | 只针对idx1索引对应的sample内部进行交换,如果idx2也给的话就是idx1 idx2进行交换
83 | '''
84 | select_input_ids=torch.index_select(input_ids,0,idx1)
85 | rand_index=torch.randperm(select_input_ids.shape[0])
86 | new_idx=torch.tensor(list(range(input_ids.shape[0])))
87 | len_list1=[]
88 | len_list2=[]
89 | for input_id1 in select_input_ids:
90 | #calculte length of tokens in each sentence
91 | mask=((input_id1!=101)&(input_id1!=0)&(input_id1!=102))
92 | len_list1.append(int(mask.sum()))
93 | len_list2=torch.tensor(len_list1)[rand_index]
94 |
95 | spanlen=torch.tensor([int(x*lam) for x in len_list1])
96 | beginlist=[1+random.randint(0,x-y) for x,y in zip(len_list1,spanlen)]
97 | beginlist2=[1+random.randint(0,max(0,x-y)) for x,y in zip(len_list2,spanlen)]
98 |
99 | spanlist=[(x,int(y)) for x,y in zip(beginlist,spanlen)]
100 | spanlist2 = [(x, min(int(y),z)) for x, y, z in zip(beginlist2, spanlen, len_list2)]
101 | new_ids=input_ids.clone()
102 | new_idx[idx1]=idx1[rand_index]
103 | for i,idx in enumerate(idx1):
104 | new_ids[idx][spanlist[i][0]:spanlist[i][0]+spanlist[i][1]]=input_ids[idx1[rand_index[i]]][spanlist2[i][0]:spanlist2[i][0]+spanlist2[i][1]]
105 |
106 | return new_ids,new_idx
def random_mixup(args, ids1, lab1, lam):
    """Dispatch to the span-exchange variant selected by ``args.random_mix``.

    Randomly selects spans to exchange: ``lam`` decides the span length and a
    random permutation decides each sample's exchange partner.

    Args:
        args: namespace; ``args.random_mix`` in {'all','zero','one','zero_one'}.
        ids1: LongTensor (batch, seq_len) of input token ids.
        lab1: LongTensor of binary labels (0/1), used to split the batch.
        lam: float span length ratio.

    Returns:
        (ids, index): mixed token ids and the per-sample partner index, as
        produced by mixup()/mixup_01().  Returns None for any other mode
        (argparse restricts the choices, so this should not happen).
    """
    if args.random_mix == 'all':
        # mix every sample with a random partner regardless of label
        return mixup(args, ids1, lam, torch.tensor(range(ids1.shape[0])))
    pos_idx = (lab1 == 1).nonzero().squeeze()
    neg_idx = (lab1 == 0).nonzero().squeeze()
    # (removed: the original also built pos_samples/neg_samples via
    # index_select here, but never used them — wasted copies.)
    if args.random_mix == 'zero':
        # mix only within the label-0 subset
        return mixup(args, ids1, lam, neg_idx)
    if args.random_mix == 'one':
        # mix only within the label-1 subset
        return mixup(args, ids1, lam, pos_idx)
    if args.random_mix == 'zero_one':
        # mix across the two label groups
        return mixup_01(args, ids1, lam, pos_idx, neg_idx)
--------------------------------------------------------------------------------
/online_augmentation/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lezhang7/TreeMix/016b47df8f028d384aedeaf1605ce24aac48b9b7/online_augmentation/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/online_augmentation/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lezhang7/TreeMix/016b47df8f028d384aedeaf1605ce24aac48b9b7/online_augmentation/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/process_data/Load_data.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset, load_from_disk
2 | from transformers import BertTokenizer
3 | import torch
4 | import numpy as np
5 | import os
6 | from . import settings
class DATA_process(object):
    """Loads, tokenizes and batches the raw / augmentation / validation data.

    Thin wrapper over Hugging Face ``datasets`` + a BERT tokenizer that turns
    the project's datasets into PyTorch DataLoaders.  All configuration comes
    from the argparse namespace passed to ``__init__``; an argument-less
    instance only supports the raw ``validationset``/``traindataset`` helpers.
    """
    def __init__(self, args=None):
        # NOTE(review): the `X if X else None` pattern treats *falsy* values
        # (e.g. batch_size=0) as missing — confirm that is acceptable.
        if args:
            print('Initializing with args')
            self.data = args.data if args.data else None          # dataset name, e.g. 'sst2'
            self.task = args.task if args.task else None          # 'single' or sentence-pair
            self.tokenizer = BertTokenizer.from_pretrained(
                args.model, do_lower_case=True) if args.model else None
            self.tasksettings = settings.TaskSettings()
            self.max_length = args.max_length if args.max_length else None
            self.label_name = args.label_name if args.label_name else None
            self.batch_size = args.batch_size if args.batch_size else None
            self.aug_batch_size=args.aug_batch_size if args.aug_batch_size else None
            self.min_train_token = args.min_train_token if args.min_train_token else None
            self.max_train_token = args.max_train_token if args.max_train_token else None
            self.num_proc = args.num_proc if args.num_proc else None
            self.low_resource_dir = args.low_resource_dir if args.low_resource_dir else None
            self.data_path = args.data_path if args.data_path else None
            self.random_mix = args.random_mix if args.random_mix else None

    def validation_data(self):
        """Return a shuffled DataLoader over the tokenized validation split."""
        validation_set = self.validationset(
            data=self.data)
        print('='*20,'multiprocess processing test dataset','='*20)
        # Process dataset to make dataloader
        if self.task == 'single':
            validation_set = validation_set.map(
                self.encode, batched=True, num_proc=self.num_proc)
        else:
            validation_set = validation_set.map(
                self.encode_pair, batched=True, num_proc=self.num_proc)
        # validation_set = validation_set.map(lambda examples: {'labels': examples[args.label_name]}, batched=True)
        validation_set = validation_set.rename_column(
            self.label_name, "labels")
        validation_set.set_format(type='torch', columns=[
            'input_ids', 'token_type_ids', 'attention_mask', 'labels'])

        val_dataloader = torch.utils.data.DataLoader(
            validation_set, batch_size=self.batch_size, shuffle=True)
        return val_dataloader
    def encode(self, examples):
        """Tokenize a single-sentence batch (column from task_to_keys)."""
        return self.tokenizer(examples[self.tasksettings.task_to_keys[self.data][0]], max_length=self.max_length, truncation=True, padding='max_length')
    def encode_pair(self, examples):
        """Tokenize a sentence-pair batch (both columns from task_to_keys)."""
        return self.tokenizer(examples[self.tasksettings.task_to_keys[self.data][0]], examples[self.tasksettings.task_to_keys[self.data][1]], max_length=self.max_length, truncation=True, padding='max_length')

    def train_data(self, count_label=False):
        """Return a shuffled DataLoader over the tokenized train split.

        When ``count_label`` is True also returns the number of distinct
        labels.  Applies optional length filtering and, for random_mix,
        sorts samples by token count (longest first).
        """
        train_set, label_num = self.traindataset(
            data=self.data, low_resource_dir=self.low_resource_dir, label_num=count_label)
        print('='*20,'multiprocess processing train dataset','='*20)
        if self.task == 'single':
            train_set = train_set.map(
                self.encode, batched=True, num_proc=self.num_proc)
        else:
            train_set = train_set.map(
                self.encode_pair, batched=True, num_proc=self.num_proc)
        if self.random_mix:
            # sort the train dataset by real token count (attention_mask sum)
            print('-'*20, 'random_mixup', '-'*20)
            train_set = train_set.map(
                lambda examples: {'token_num': np.sum(np.array(examples['attention_mask']))})
            train_set = train_set.sort('token_num', reverse=True)
        # train_set = train_set.map(lambda examples: {'labels': examples[args.label_name]}, batched=True)
        train_set = train_set.rename_column(self.label_name, "labels")
        if self.min_train_token:
            # +2 accounts for the [CLS]/[SEP] tokens counted by attention_mask
            print(
                '-'*20, 'filter sample whose sentence shorter than {}'.format(self.min_train_token), '-'*20)
            train_set = train_set.filter(lambda example: sum(
                example['attention_mask']) > self.min_train_token+2)
        if self.max_train_token:
            print(
                '-'*20, 'filter sample whose sentence longer than {}'.format(self.max_train_token), '-'*20)
            train_set = train_set.filter(lambda example: sum(
                example['attention_mask']) < self.max_train_token+2)
        train_set.set_format(type='torch', columns=[
            'input_ids', 'token_type_ids', 'attention_mask', 'labels'])

        train_dataloader = torch.utils.data.DataLoader(
            train_set, batch_size=self.batch_size, shuffle=True)
        if count_label:
            return train_dataloader, label_num
        else:
            return train_dataloader
    def augmentation_data(self):
        """Return a shuffled DataLoader over the augmentation dataset.

        ``self.data_path`` may be either a CSV file or a dataset saved with
        ``Dataset.save_to_disk`` — the CSV loader is tried first.
        """
        try:
            aug_dataset = load_dataset(
                'csv', data_files=[self.data_path])['train']
        except Exception as e:
            # not a CSV: fall back to the on-disk arrow format
            aug_dataset = load_from_disk(self.data_path)
        print('='*20, 'multiprocess processing aug dataset', '='*20)
        if self.task == 'single':
            aug_dataset = aug_dataset.map(
                self.encode, batched=True, num_proc=self.num_proc)
        else:
            aug_dataset = aug_dataset.map(
                self.encode_pair, batched=True, num_proc=self.num_proc)
        # if self.mix:
        #     # label has more than one dimension
        #     # aug_dataset = aug_dataset.map(lambda examples: {'labels':examples[self.label_name]},batched=True)
        # else:
        #     # aug_dataset = aug_dataset.map(lambda examples: {'labels':int(examples[self.label_name])})
        aug_dataset = aug_dataset.rename_column(self.label_name, 'labels')

        aug_dataset.set_format(type='torch', columns=[
            'input_ids', 'token_type_ids', 'attention_mask', 'labels'])
        aug_dataloader = torch.utils.data.DataLoader(
            aug_dataset, batch_size=self.aug_batch_size, shuffle=True)
        return aug_dataloader

    def validationset(self,data):
        """Download and return the raw validation/test split for ``data``."""
        if data in ['sst2', 'rte', 'mrpc', 'qqp', 'mnli', 'qnli']:
            if data == 'mnli':
                # MNLI has two dev sets; the mismatched one is used here
                validation_set = load_dataset(
                    'glue', data, split='validation_mismatched')
            else:
                validation_set = load_dataset('glue', data, split='validation')
            print('-'*20, 'Test on glue@{}'.format(data), '-'*20)
        elif data in ['imdb', 'ag_news', 'trec']:
            validation_set = load_dataset(data, split='test')
            print('-'*20, 'Test on {}'.format(data), '-'*20)
        elif data == 'sst':
            # SST regression scores in [0,1] are bucketed into 5 classes
            validation_set = load_dataset(data, 'default', split='test')
            validation_set = validation_set.map(lambda example: {'label': int(
                example['label']*10//2)}, remove_columns=['tokens', 'tree'], num_proc=4)
            print('-'*20, 'Test on {}'.format(data), '-'*20)
        else:
            validation_set = load_dataset(data, split='validation')
            print('-'*20, 'Test on {}'.format(data), '-'*20)

        return validation_set

    def traindataset(self, data, low_resource_dir=None, split='train', label_num=False):
        """Load the raw train split (or a saved low-resource partial split).

        Returns ``(train_set, n_labels)`` when ``label_num`` is True, else
        just ``train_set``.
        """
        if low_resource_dir:
            train_set = load_from_disk(os.path.join(
                low_resource_dir, 'partial_train'))
        else:
            if data in ['sst2', 'rte', 'mrpc', 'qqp', 'mnli', 'qnli']:
                train_set = load_dataset('glue', data, split=split)
            elif data == 'sst':
                # same 5-class bucketing as in validationset()
                train_set = load_dataset(data, 'default', split=split)
                train_set = train_set.map(lambda example: {'label': int(
                    example['label']*10//2)}, remove_columns=['tokens', 'tree'], num_proc=4)
            else:
                train_set = load_dataset(data, split=split)
        if label_num:
            return train_set, len(set(train_set[self.label_name]))
        else:
            return train_set
154 | if __name__=="__main__":
155 | data_processor=DATA_process()
156 | valset=data_processor.validationset(data='ag_news')
157 | print(valset)
158 |
--------------------------------------------------------------------------------
/process_data/__init__.py:
--------------------------------------------------------------------------------
# NOTE(review): this guard never fires when the package is imported; the
# message only prints if this __init__.py is executed directly as a script.
if __name__=='__main__':
    print('Using process_data package')
--------------------------------------------------------------------------------
/process_data/__pycache__/Augmentation.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lezhang7/TreeMix/016b47df8f028d384aedeaf1605ce24aac48b9b7/process_data/__pycache__/Augmentation.cpython-39.pyc
--------------------------------------------------------------------------------
/process_data/__pycache__/Load_data.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lezhang7/TreeMix/016b47df8f028d384aedeaf1605ce24aac48b9b7/process_data/__pycache__/Load_data.cpython-38.pyc
--------------------------------------------------------------------------------
/process_data/__pycache__/Load_data.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lezhang7/TreeMix/016b47df8f028d384aedeaf1605ce24aac48b9b7/process_data/__pycache__/Load_data.cpython-39.pyc
--------------------------------------------------------------------------------
/process_data/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lezhang7/TreeMix/016b47df8f028d384aedeaf1605ce24aac48b9b7/process_data/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/process_data/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lezhang7/TreeMix/016b47df8f028d384aedeaf1605ce24aac48b9b7/process_data/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/process_data/__pycache__/ceshi.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lezhang7/TreeMix/016b47df8f028d384aedeaf1605ce24aac48b9b7/process_data/__pycache__/ceshi.cpython-39.pyc
--------------------------------------------------------------------------------
/process_data/__pycache__/settings.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lezhang7/TreeMix/016b47df8f028d384aedeaf1605ce24aac48b9b7/process_data/__pycache__/settings.cpython-38.pyc
--------------------------------------------------------------------------------
/process_data/__pycache__/settings.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lezhang7/TreeMix/016b47df8f028d384aedeaf1605ce24aac48b9b7/process_data/__pycache__/settings.cpython-39.pyc
--------------------------------------------------------------------------------
/process_data/get_data.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | import numpy as np
3 | import argparse
4 | import os
5 | import pandas as pd
6 | from stanfordcorenlp import StanfordCoreNLP
7 | from tqdm import tqdm
8 | import time
9 | import settings
10 | from multiprocessing import cpu_count
11 | from pandarallel import pandarallel
def parse_argument():
    """Parse the command line for dataset download and constituency parsing.

    Returns:
        argparse.Namespace with ``data`` (list of dataset names),
        ``corenlp_dir`` (Stanford CoreNLP install path) and ``proc``
        (worker count, optional).
    """
    cli = argparse.ArgumentParser(description='download and parsing datasets')
    cli.add_argument('--data', nargs='+', required=True, help='data list')
    cli.add_argument('--corenlp_dir', type=str, default='/remote-home/lzhang/stanford-corenlp-full-2018-10-05/')
    cli.add_argument('--proc', type=int, help='multiprocessing num')
    return cli.parse_args()
19 |
20 |
def parsing_stanfordnlp(raw_text):
    """Constituency-parse one text via the module-level ``snlp`` client.

    Returns the parse string, or the sentinel string 'None' on any failure
    (the caller later drops rows whose parsing equals 'None').
    """
    try:
        parsing = snlp.parse(raw_text)
        return parsing
    except Exception as e:
        # deliberate best-effort: failed rows are filtered out downstream
        return 'None'
27 |
def constituency_parsing(args):
    """Run Stanford CoreNLP constituency parsing over each dataset's train split.

    For every dataset in ``args.data``, reads ``DATA/<NAME>/data/train.csv``,
    parses each text column in parallel (pandarallel), drops rows whose
    parsing failed, and writes ``DATA/<NAME>/data/train_parsing.csv``.
    Datasets that already have a parsing file are skipped.

    Note: the original contained two identical ``for dataset in args.data``
    loops; the first read train.csv and did nothing with it (wasted I/O and
    duplicate "already parsed" messages).  It has been removed.
    """
    if not args.proc:
        args.proc = cpu_count()
    pandarallel.initialize(nb_workers=args.proc, progress_bar=True)
    for dataset in args.data:
        DATA_dir = os.path.join(os.path.abspath(
            os.path.join(os.getcwd(), "..")), 'DATA')
        path_dir = os.path.join(DATA_dir, dataset.upper())
        output_path = os.path.join(path_dir, 'data', 'train_parsing.csv')
        if os.path.exists(output_path):
            print('The data {} has already parsed!'.format(dataset.upper()))
            continue
        train = pd.read_csv(os.path.join(
            path_dir, 'data', 'train.csv'), encoding="utf-8")
        # parse every text column of the task (one or two columns)
        for i, text_name in enumerate(task_to_keys[dataset]):
            parsing_name = 'parsing{}'.format(i+1)
            train[parsing_name] = train[text_name].parallel_apply(
                parsing_stanfordnlp)
        # drop rows whose parsing failed ('None' sentinel from parsing_stanfordnlp)
        for i, text_name in enumerate(task_to_keys[dataset]):
            parsing_name = 'parsing{}'.format(i+1)
            train = train.drop(train[train[parsing_name] == 'None'].index)
        train.to_csv(output_path, index=0)
def download_data(args):
    """Download each requested dataset and dump every split to CSV.

    Creates the project directory layout under ``../DATA/<NAME>/`` (generated/,
    runs/, logs/, data/ — with per-label-granularity subfolders for TREC) and
    writes each split as ``data/<split>.csv``.  Datasets whose directory
    already exists are skipped; download failures are reported, not raised.
    """
    for dataset in args.data:
        DATA_dir=os.path.join(os.path.abspath(os.path.join(os.getcwd(), "..")),'DATA')
        path_dir=os.path.join(DATA_dir,dataset.upper())
        if dataset.upper() in os.listdir(DATA_dir):
            print('{} directory already exists !'.format(dataset.upper()))
            continue
        try:
            # Fixed: the original compared the dataset *name* to a list with
            # '==' (always False), so SCAN subsets fell through to the
            # generic loader and failed.  Use membership + an elif chain.
            if dataset in ['addprim_jump', 'addprim_turn_left', 'simple']:
                downloaded_data_list = [load_dataset('scan', dataset)]
            elif dataset in ['sst2', 'rte', 'mrpc', 'qqp', 'mnli', 'qnli']:
                downloaded_data_list = [load_dataset('glue', dataset)]
            elif dataset == 'sst':
                downloaded_data_list = [load_dataset("sst", "default")]
            else:
                downloaded_data_list = [load_dataset(dataset)]

            if not os.path.exists(path_dir):
                if dataset=='trec':
                    # TREC keeps fine- and coarse-grained results separately
                    os.makedirs(os.path.join(path_dir,'generated/fine'))
                    os.makedirs(os.path.join(path_dir,'generated/coarse'))
                    os.makedirs(os.path.join(
                        path_dir, 'runs/label-coarse/raw'))
                    os.makedirs(os.path.join(
                        path_dir, 'runs/label-coarse/aug'))
                    os.makedirs(os.path.join(
                        path_dir, 'runs/label-coarse/raw_aug'))
                    os.makedirs(os.path.join(
                        path_dir, 'runs/label-fine/raw'))
                    os.makedirs(os.path.join(
                        path_dir, 'runs/label-fine/aug'))
                    os.makedirs(os.path.join(
                        path_dir, 'runs/label-fine/raw_aug'))
                else:
                    os.makedirs(os.path.join(path_dir,'generated'))
                    os.makedirs(os.path.join(path_dir,'runs/raw'))
                    os.makedirs(os.path.join(path_dir,'runs/aug'))
                    os.makedirs(os.path.join(path_dir,'runs/raw_aug'))
                os.makedirs(os.path.join(path_dir,'logs'))
                os.makedirs(os.path.join(path_dir,'data'))
            for downloaded_data in downloaded_data_list:
                for data_split in downloaded_data:
                    dataset_split=downloaded_data[data_split]
                    dataset_split.to_csv(os.path.join(path_dir,'data',data_split+'.csv'),index=0)
        except Exception as e:
            # best-effort: keep downloading the remaining datasets
            print('Downloading failed on {}, due to error {}'.format(dataset,e))
if __name__=='__main__':
    args = parse_argument()
    # task_to_keys and snlp are module-level globals read by
    # constituency_parsing() and parsing_stanfordnlp()
    tasksettings=settings.TaskSettings()
    task_to_keys=tasksettings.task_to_keys
    print('='*20,'Start Downloading Datasets','='*20)
    download_data(args)
    print('='*20,'Start Parsing Datasets','='*20)
    # start the CoreNLP client only after downloads, since it holds a JVM
    snlp = StanfordCoreNLP(args.corenlp_dir)
    constituency_parsing(args)
--------------------------------------------------------------------------------
/process_data/settings.py:
--------------------------------------------------------------------------------
class TaskSettings(object):
    """Central registry of per-dataset hyperparameters and column names."""
    def __init__(self):
        # Training hyperparameters per dataset, consumed by batch_train.py
        # when it formats the run.py command line.
        self.train_settings={
            "mnli":{'epoch':5,'batch_size':96,'aug_batch_size':96,'val_steps':100,'max_length':128,'augweight':0.2},
            "mrpc":{'epoch':10,'batch_size':32,'aug_batch_size':32,'val_steps':50,'max_length':128,'augweight':0.2},
            "qnli":{'epoch':5,'batch_size':96,'aug_batch_size':96,'val_steps':100,'max_length':128,'augweight':0.2},
            "qqp": {'epoch':5,'batch_size':96,'aug_batch_size':96,'val_steps':300,'max_length':128,'augweight':0.2},
            # NOTE(review): rte is the only dataset with a *negative*
            # augweight (-0.2) — confirm this is intentional and not a typo.
            "rte": {'epoch':10,'batch_size':32,'aug_batch_size':32,'val_steps':50,'max_length':128,'augweight':-0.2},
            "sst2":{'epoch':5,'batch_size':96,'aug_batch_size':96,'val_steps':100,'max_length':128,'augweight':0.5},
            "trec":{'epoch':20,'batch_size':96,'aug_batch_size':96,'val_steps':100,'max_length':128,'augweight':0.5},
            "imdb":{'epoch':5,'batch_size':8,'aug_batch_size':8,'val_steps':500,'max_length':512,'augweight':0.5},
            "ag_news": {'epoch': 5, 'batch_size': 96, 'aug_batch_size': 96, 'val_steps': 500, 'max_length': 128, 'augweight': 0.5},

        }
        # Text column name(s) per dataset: one entry for single-sentence
        # tasks, two for sentence-pair tasks.
        self.task_to_keys = {
            "mnli": ["premise", "hypothesis"],
            "mrpc": ["sentence1", "sentence2"],
            "qnli": ["question", "sentence"],
            "qqp": ["question1", "question2"],
            "rte": ["sentence1", "sentence2"],
            "sst2": ["sentence"],
            "trec": ["text"],
            "anli": ["premise", "hypothesis"],
            "imdb": ["text"],
            "ag_news":["text"],
            "sst":["sentence"],
            "addprim_jump":["commands"],
            "addprim_turn_left":["commands"]
        }
        # datasets whose inputs are sentence pairs
        self.pair_datasets=['qqp','rte','qnli','mrpc','mnli']
        # SCAN semantic-parsing subsets
        self.SCAN = ['addprim_turn_left', 'addprim_jump','simple']
        # fractions of the train set used in low-resource experiments
        self.low_resource={
            "ag_news":[0.01,0.02,0.05,0.1,0.2],
            "sst":[0.01,0.02,0.05,0.1,0.2],
            "sst2":[0.01,0.02,0.05,0.1,0.2]
        }
37 |
38 |
39 |
40 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | datasets==1.11.0
2 | nltk==3.6.2
3 | pandarallel==1.5.4
4 | stanfordcorenlp==3.9.1.1
5 | torch==1.9.0+cu111
6 | transformers==4.12.5
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | import torch.utils.data.distributed
4 | from torch.utils.data.distributed import DistributedSampler
5 | import torch.nn.parallel
6 | from transformers import BertForSequenceClassification, AdamW
7 | from transformers import get_linear_schedule_with_warmup
8 | import numpy as np
9 | import torch.nn as nn
10 | from sklearn.metrics import f1_score, accuracy_score
11 | from tqdm import tqdm
12 | import os
13 | import re
14 | import torch.nn.functional as F
15 | from torch.utils.tensorboard import SummaryWriter
16 | from itertools import cycle
17 | import argparse
18 | import torch.distributed as dist
19 | import time
20 | import online_augmentation
21 | import logging
22 | from process_data.Load_data import DATA_process
23 |
24 |
25 |
def set_seed(seed):
    """Seed every RNG source (Python, NumPy, PyTorch CPU/CUDA) for reproducibility."""
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Force deterministic cuDNN kernels (at some speed cost).
    torch.backends.cudnn.deterministic = True
32 |
33 |
def cross_entropy(logits, target):
    """Soft-label cross entropy, averaged over the batch.

    Unlike ``nn.CrossEntropyLoss`` this accepts a *distribution* over classes
    as the target (needed for TreeMix's mixed labels), not a class index.

    Args:
        logits: raw model scores, shape (batch, num_classes).
        target: per-class target weights, same shape as ``logits``.

    Returns:
        Scalar tensor: sum of target-weighted negative log-probabilities
        divided by the batch size.
    """
    # BUG FIX: the original computed log(softmax(logits)) in two steps, which
    # underflows to -inf for confident logits; log_softmax is the numerically
    # stable fused form and is mathematically identical.
    log_p = -F.log_softmax(logits, dim=1)
    loss = target * log_p
    batch_num = logits.shape[0]
    return loss.sum() / batch_num
41 |
42 |
43 |
def flat_accuracy(preds, labels):
    """Classification accuracy from raw prediction scores.

    Args:
        preds: array of shape (batch, num_classes) with per-class scores.
        labels: array of gold class indices.

    Returns:
        NumPy scalar in [0, 1] — fraction of rows whose argmax matches the
        label. Kept as a numpy scalar (not a plain float) because callers
        invoke ``.item()`` on the result.
    """
    # The sklearn accuracy_score call was overkill here: a vectorized
    # comparison + mean is equivalent and drops the heavyweight dependency.
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return (pred_flat == labels_flat).mean()
48 |
49 |
def reduce_tensor(tensor, args):
    """Average *tensor* across all distributed workers.

    Clones the input (so the caller's tensor is untouched), all-reduces the
    clone with a SUM, and divides by the world size to get the mean.
    """
    summed = tensor.clone()
    dist.all_reduce(summed, op=dist.ReduceOp.SUM)
    summed /= args.world_size
    return summed
55 |
56 |
def tensorboard_settings(args):
    """Create the SummaryWriter for this run, refusing to reuse a log dir.

    The log-dir name encodes the run configuration (mode, seed, aug weight,
    batch sizes) so each configuration gets its own tensorboard curve.

    Raises:
        IOError: if the computed log dir already exists (prevents silently
            mixing two runs' curves in one tensorboard file).
    """
    def _open_writer(log_dir):
        # Single existence check + writer creation; the original repeated
        # this (and the long error message) four times verbatim.
        if os.path.exists(log_dir):
            raise IOError(
                'This tensorboard file {} already exists! Please do not train the same data repeatedly, if you want to train this dataset, delete corresponding tensorboard file first! '.format(log_dir))
        return SummaryWriter(log_dir=log_dir)

    if 'raw' in args.mode:  # matches both 'raw' and 'raw_aug'
        if args.data_path:
            # raw_aug
            writer = _open_writer(os.path.join(
                args.output_dir, 'Raw_Aug_{}_{}_{}_{}_{}'.format(
                    args.data_path.split('/')[-1], args.seed, args.augweight,
                    args.batch_size, args.aug_batch_size)))
        elif args.random_mix:
            # raw + random mixup baseline
            writer = _open_writer(os.path.join(
                args.output_dir, 'Raw_random_mixup_{}_{}_{}'.format(
                    args.random_mix, args.alpha, args.seed)))
        else:
            # plain raw training
            writer = _open_writer(os.path.join(
                args.output_dir, 'Raw_{}'.format(args.seed)))
    elif args.mode == 'aug':
        # aug-only training
        writer = _open_writer(os.path.join(
            args.output_dir, 'Aug_{}_{}_{}_{}_{}'.format(
                args.data_path.split('/')[-1], args.seed, args.augweight,
                args.batch_size, args.aug_batch_size)))
    # NOTE(review): as in the original, modes other than raw/raw_aug/aug
    # leave `writer` unbound — callers only invoke this for those modes.
    return writer
92 |
93 |
def logging_settings(args):
    """Set up file logging for results and return the 'result.a' logger.

    Appends to DATA/<TASK>/logs/best_result.log (or
    lowresourcebest_result.log for low-resource runs), creating the
    directory if needed. Only records propagated through the child logger
    'result.a' pass the handler's filter.
    """
    logger = logging.getLogger('result')
    logger.setLevel(logging.INFO)

    log_dir = os.path.join('DATA', args.data.upper(), 'logs')
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    filename = ('lowresourcebest_result.log' if args.low_resource_dir
                else 'best_result.log')
    log_path = os.path.join(log_dir, filename)

    handler = logging.FileHandler(log_path, mode='a+', encoding='utf-8')
    handler.setFormatter(logging.Formatter(
        fmt='%(asctime)s - %(filename)s - %(levelname)s: %(message)s'))
    handler.setLevel(logging.INFO)
    # Only let records originating from the 'result.a' child through.
    handler.addFilter(logging.Filter(name='result.a'))
    logger.addHandler(handler)

    return logging.getLogger('result.a')
def loading_model(args, label_num):
    """Build a BertForSequenceClassification model on the right device.

    Side effects: sets ``args.device``, ``args.n_gpu`` and
    ``args.world_size``. Wraps the model in DistributedDataParallel (when
    launched with a local_rank) or DataParallel (multi-GPU single process),
    and optionally restores weights from ``args.load_model_path``.

    Args:
        args: parsed command-line namespace.
        label_num: number of output classes for the classification head.

    Returns:
        The (possibly parallel-wrapped) model, already moved to the device.
    """
    t1 = time.time()
    if args.local_rank == -1:
        # Single-process mode: one device, possibly DataParallel over all GPUs.
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1  # the number of gpu on each proc
    args.device = device
    if args.local_rank != -1:
        args.world_size = torch.cuda.device_count()
    else:
        args.world_size = 1
    print('*'*40, '\nSettings:{}'.format(args))
    print('*'*40)
    print('='*20, 'Loading models', '='*20)
    model = BertForSequenceClassification.from_pretrained(
        args.model, num_labels=label_num)
    model.to(device)
    t2 = time.time()
    print(
        '='*20, 'Loading models complete!, cost {:.2f}s'.format(t2-t1), '='*20)
    # model parallel
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank])
    elif args.n_gpu > 1:
        model = nn.DataParallel(model)
    if args.load_model_path is not None:
        # BUG FIX: the original passed "%s" as a separate print() argument, so
        # the literal text "%s" was printed instead of the checkpoint path.
        print("="*20, "Load model from {}".format(args.load_model_path))
        model.load_state_dict(torch.load(args.load_model_path))
    return model
150 |
def parse_argument():
    """Parse command-line arguments and derive run defaults.

    Post-processing after parsing: validates the trec label split, fills in
    ``output_dir`` from data/mode, pins batch size / seed for certain modes,
    and classifies the task as sentence-'pair' or 'single'.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', default=-1, type=int,
                        help='node rank for distributed training')
    parser.add_argument("--no_cuda", action='store_true',
                        help="Avoid using CUDA when available")

    parser.add_argument(
        '--mode', type=str, choices=['raw', 'aug', 'raw_aug', 'visualize'], required=True)
    parser.add_argument('--save_model', action='store_true')
    parser.add_argument('--load_model_path', type=str)
    parser.add_argument('--data', type=str, required=True)
    parser.add_argument('--num_proc', type=int, default=8,
                        help='multi process number used in dataloader process')

    # training settings
    parser.add_argument('--output_dir', type=str, help="tensorboard fileoutput directory")
    parser.add_argument('--epoch', type=int, default=5, help='train epochs')
    parser.add_argument('--lr', type=float, default=2e-5, help='learning rate')
    parser.add_argument('--seed', default=42, type=int, help='seed ')
    parser.add_argument('--batch_size', default=128, type=int,
                        help='train examples in each batch')
    parser.add_argument('--val_steps', default=100, type=int,
                        help='evaluate on dev datasets every steps')
    parser.add_argument('--max_length', default=128,
                        type=int, help='encode max length')
    parser.add_argument('--label_name', type=str, default='label')
    parser.add_argument('--model', type=str, default='bert-base-uncased')
    parser.add_argument('--low_resource_dir', type=str,
                        help='Low resource data dir')

    # train on augmentation dataset parameters
    parser.add_argument('--aug_batch_size', default=128,
                        type=int, help='train examples in each batch')
    parser.add_argument('--augweight', default=0.2, type=float)
    parser.add_argument('--data_path', type=str, help="augmentation file path")
    parser.add_argument('--min_train_token', type=int, default=0,
                        help="minimum token num restriction for train dataset")
    parser.add_argument('--max_train_token', type=int, default=0,
                        help="maximum token num restriction for train dataset")
    # NOTE(review): store_false means --mix defaults to True and passing the
    # flag *disables* mixup — counter-intuitive naming; confirm before relying on it.
    parser.add_argument('--mix', action='store_false', help='train on 01mixup')

    # random mixup
    parser.add_argument('--alpha', type=float, default=0.1,
                        help="online augmentation alpha")
    parser.add_argument('--onlyaug', action='store_true',
                        help="train only on online aug batch")
    parser.add_argument('--difflen', action='store_true',
                        help="train only on online aug batch")
    parser.add_argument('--random_mix', type=str, help="random mixup ")

    # visualize dataset

    args = parser.parse_args()
    if args.data == 'trec':
        # trec ships two label granularities; the split must be named explicitly.
        try:
            assert args.label_name in ['label-fine', 'label-coarse']
        except AssertionError:
            raise(AssertionError(
                "If you want to train on trec dataset with augmentation, you have to name the label of split"))
        if not args.output_dir:
            # trec runs are additionally keyed by the label granularity.
            args.output_dir = os.path.join(
                'DATA', args.data.upper(), 'runs', args.label_name, args.mode)
    if args.mode == 'raw':
        # Raw baseline always trains with the full batch size.
        args.batch_size = 128
    if 'aug' in args.mode:
        # Both 'aug' and 'raw_aug' require an augmentation file.
        assert args.data_path
    if args.mode == 'aug':
        args.seed = 42
    if not args.output_dir:
        args.output_dir = os.path.join(
            'DATA', args.data.upper(), 'runs', args.mode)
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    # Sentence-pair tasks are encoded differently downstream.
    if args.data in ['rte', 'mrpc', 'qqp', 'mnli', 'qnli']:
        args.task = 'pair'
    else:
        args.task = 'single'

    return args
231 |
232 |
def train(args):
    """Full training loop: load data, build the model, train, validate, log.

    Depending on ``args.mode`` the loop trains on raw data ('raw'), on the
    augmented set only ('aug'), or on both with the augmentation loss scaled
    by ``args.augweight`` ('raw_aug'). Best dev accuracy and the step it was
    reached at are appended to the task's result log; curves go to
    tensorboard.
    """
    # ========================================
    #        Tensorboard &Logging
    # ========================================
    writer = tensorboard_settings(args)
    result_logger = logging_settings(args)
    data_process = DATA_process(args)
    # ========================================
    #            Loading datasets
    # ========================================
    print('='*20, 'Start processing dataset', '='*20)
    t1 = time.time()

    val_dataloader = data_process.validation_data()

    if args.mode != 'aug':
        train_dataloader, label_num = data_process.train_data(count_label=True)
    # NOTE(review): when args.mode == 'aug', label_num is never assigned and
    # loading_model() below would raise NameError — same as the original;
    # presumably 'aug' runs are expected to provide it another way; confirm.
    if args.data_path:
        print('='*20, 'Train Augmentation dataset path: {}'.format(args.data_path), '='*20)
        aug_dataloader = data_process.augmentation_data()
        if args.mode == 'aug':
            train_dataloader = aug_dataloader
        else:
            # raw_aug: draw augmentation batches forever alongside raw batches.
            aug_dataloader = cycle(aug_dataloader)

    t2 = time.time()
    print('='*20, 'Dataset process done! cost {:.2f}s'.format(t2-t1), '='*20)

    # ========================================
    #                 Model
    # ========================================
    model = loading_model(args, label_num)
    # ========================================
    #           Optimizer Settings
    # ========================================
    optimizer = AdamW(model.parameters(), lr=args.lr)
    all_steps = args.epoch*len(train_dataloader)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=20, num_training_steps=all_steps)
    criterion = nn.CrossEntropyLoss()
    model.train()

    # ========================================
    #                 Train
    # ========================================
    print('='*20, 'Start training', '='*20)
    best_acc = 0
    # Initialized so the result logging below cannot hit an unbound name even
    # if no validation step ever improves on the initial accuracy.
    best_steps = 0
    args.val_steps = min(len(train_dataloader), args.val_steps)

    for epoch in range(args.epoch):
        bar = tqdm(enumerate(train_dataloader), total=len(
            train_dataloader)//args.world_size)
        fail = 0
        for step, batch in bar:
            model.zero_grad()

            # ----------------------------------------------
            #               Train_dataloader
            # ----------------------------------------------
            if args.random_mix:
                # Random-mixup baseline: mix token ids of two examples and
                # interpolate the two labels' losses by lam.
                try:
                    input_ids, target_a = batch['input_ids'], batch['labels']
                    lam = np.random.choice([0, 0.1, 0.2, 0.3])
                    exchanged_ids, new_index = online_augmentation.random_mixup(
                        args, input_ids, target_a, lam)
                    target_b = target_a[new_index]
                    outputs = model(exchanged_ids.to(args.device), token_type_ids=None, attention_mask=(
                        exchanged_ids > 0).to(args.device))
                    logits = outputs.logits
                    loss = criterion(logits.to(args.device), target_a.to(
                        args.device))*(1-lam)+criterion(logits.to(args.device), target_b.to(args.device))*lam
                except Exception as e:
                    # Mixup failed for this batch: fall back to plain training.
                    fail += 1
                    batch = {k: v.to(args.device) for k, v in batch.items()}
                    outputs = model(**batch)
                    loss = outputs.loss
            elif args.mode == 'aug':
                # BUG FIX: the original compared args.model (the checkpoint
                # name, e.g. 'bert-base-uncased') to 'aug', so this branch
                # was unreachable and aug-only runs trained without the
                # soft-label loss. args.mode is the intended field.
                # train only on augmentation dataset
                batch = {k: v.to(args.device) for k, v in batch.items()}
                if args.mix:
                    # train on 01 tree mixup augmentation dataset
                    mix_label = batch['labels']
                    del batch['labels']

                    outputs = model(**batch)
                    logits = outputs.logits

                    loss = cross_entropy(logits, mix_label)
                else:
                    # train on 00&11 tree mixup augmentation dataset
                    outputs = model(**batch)
                    loss = outputs.loss
            else:
                # normal train
                batch = {k: v.to(args.device) for k, v in batch.items()}

                outputs = model(**batch)
                loss = outputs.loss
            # ----------------------------------------------
            #               Aug_dataloader
            # ----------------------------------------------
            if args.mode == 'raw_aug':
                aug_batch = next(aug_dataloader)
                aug_batch = {k: v.to(args.device) for k, v in aug_batch.items()}

                if args.mix:
                    mix_label = aug_batch['labels']
                    del aug_batch['labels']
                    aug_outputs = model(**aug_batch)
                    aug_logits = aug_outputs.logits

                    aug_loss = cross_entropy(aug_logits, mix_label)
                else:
                    aug_outputs = model(**aug_batch)
                    aug_loss = aug_outputs.loss
                loss += aug_loss*args.augweight  # for sst2,rte reaches best performance

            # Backward propagation
            if args.n_gpu > 1:
                loss = loss.mean()
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            if args.local_rank == 0 or args.local_rank == -1:
                writer.add_scalar("Loss/loss", loss, step +
                                  epoch*len(train_dataloader))
                writer.flush()
                if args.random_mix:
                    bar.set_description(
                        '| Epoch: {:<2}/{:<2}| Best acc:{:.2f}| Fail:{}|'.format(epoch, args.epoch, best_acc*100, fail))
                else:
                    bar.set_description(
                        '| Epoch: {:<2}/{:<2}| Best acc:{:.2f}|'.format(epoch, args.epoch, best_acc*100))

            # =================================================
            #                   Validation
            # =================================================
            if (epoch*len(train_dataloader)+step+1) % args.val_steps == 0:
                total_eval_accuracy = 0
                total_val_loss = 0
                model.eval()  # evaluation after each epoch
                for i, batch in enumerate(val_dataloader):
                    with torch.no_grad():
                        batch = {k: v.to(args.device)
                                 for k, v in batch.items()}
                        outputs = model(**batch)
                        logits = outputs.logits
                        loss = outputs.loss

                        if args.n_gpu > 1:
                            loss = loss.mean()
                        logits = logits.detach().cpu().numpy()
                        label_ids = batch['labels'].to('cpu').numpy()

                        accuracy = flat_accuracy(logits, label_ids)
                        if args.local_rank != -1:
                            # Distributed: average loss/accuracy across workers.
                            torch.distributed.barrier()
                            reduced_loss = reduce_tensor(loss, args)
                            accuracy = torch.tensor(accuracy).to(args.device)
                            reduced_acc = reduce_tensor(accuracy, args)
                            total_val_loss += reduced_loss
                            total_eval_accuracy += reduced_acc
                        else:
                            total_eval_accuracy += accuracy.item()
                            total_val_loss += loss.item()
                avg_val_loss = total_val_loss/len(val_dataloader)
                avg_val_accuracy = total_eval_accuracy/len(val_dataloader)
                if avg_val_accuracy > best_acc:
                    best_acc = avg_val_accuracy
                    best_steps = (epoch*len(train_dataloader) +
                                  step)*args.batch_size
                    if args.save_model:
                        torch.save(model.state_dict(), 'best_model.pt')
                if args.local_rank == 0 or args.local_rank == -1:
                    writer.add_scalar("Test/Loss", avg_val_loss,
                                      epoch*len(train_dataloader)+step)
                    writer.add_scalar(
                        "Test/Accuracy", avg_val_accuracy, epoch*len(train_dataloader)+step)
                    writer.flush()

    # =================================================
    #           Final result logging
    # =================================================
    if args.data_path:
        aug_num = args.data_path.split('_')[-1]

        if args.low_resource_dir:
            # low resource raw_aug
            partial = re.findall(r'low_resource_(0.\d+)',
                                 args.low_resource_dir)[0]
            aug_num_seed = aug_num+'_'+str(args.seed)
            result_logger.info('-'*160)
            result_logger.info('| Data : {} | Mode: {:<8} | #Aug {:<6} | Best acc:{} | Steps:{} | Weight {} |Aug data: {}'.format(
                args.data+'_'+partial, args.mode, aug_num_seed, round(best_acc*100, 3), best_steps, args.augweight, args.data_path))
        else:
            # raw_aug
            aug_data_seed = re.findall(r'seed(\d)', args.data_path)[0]
            aug_num_seed = aug_num+'_'+aug_data_seed
            result_logger.info('-'*160)
            result_logger.info('| Data : {} | Mode: {:<8} | #Aug {:<6} | Best acc:{} | Steps:{} | Weight {} |Aug data: {}'.format(
                args.data, args.mode, aug_num_seed, round(best_acc*100, 3), best_steps, args.augweight, args.data_path))
    else:
        if args.low_resource_dir:
            # low resource raw
            partial = re.findall(r'low_resource_(0.\d+)', args.low_resource_dir)[0]
            result_logger.info('-'*160)
            result_logger.info('| Data : {} | Mode: {:.8} | Seed: {} | Best acc:{} | Steps:{} | Randommix: {} | Aug data: {}'.format(
                args.data+'-'+partial, args.mode, args.seed, round(best_acc*100, 3), best_steps, bool(args.random_mix), args.data_path))
        else:
            # raw
            result_logger.info('-'*160)
            result_logger.info('| Data : {} | Mode: {:.8} | Seed: {} | Best acc:{} | Steps:{} | Randommix: {} | Aug data: {}'.format(
                args.data, args.mode, args.seed, round(best_acc*100, 3), best_steps, bool(args.random_mix), args.data_path))
455 |
456 |
457 |
def main(args):
    """Seed all RNGs, then launch training for the selected mode."""
    set_seed(args.seed)
    if args.mode not in ['raw', 'raw_aug', 'aug']:
        # 'visualize' (and any future non-training mode) skips training.
        return
    if args.low_resource_dir:
        print("="*20, ' Lowresource ', '='*20)
    train(args)
if __name__ == '__main__':
    # Script entry point: parse CLI arguments and run.
    main(parse_argument())
467 |
--------------------------------------------------------------------------------
/transformers_doc/pytorch/task_summary.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "task_summary.ipynb",
7 | "provenance": [],
8 | "include_colab_link": true
9 | },
10 | "language_info": {
11 | "name": "python"
12 | },
13 | "kernelspec": {
14 | "name": "python3",
15 | "display_name": "Python 3"
16 | },
17 | "widgets": {
18 | "application/vnd.jupyter.widget-state+json": {
19 | "7c08442589064421b77b6aa52ce5b947": {
20 | "model_module": "@jupyter-widgets/controls",
21 | "model_name": "HBoxModel",
22 | "model_module_version": "1.5.0",
23 | "state": {
24 | "_view_name": "HBoxView",
25 | "_dom_classes": [],
26 | "_model_name": "HBoxModel",
27 | "_view_module": "@jupyter-widgets/controls",
28 | "_model_module_version": "1.5.0",
29 | "_view_count": null,
30 | "_view_module_version": "1.5.0",
31 | "box_style": "",
32 | "layout": "IPY_MODEL_23f3f545c30d4226987e7909e8365413",
33 | "_model_module": "@jupyter-widgets/controls",
34 | "children": [
35 | "IPY_MODEL_3d29469f501c4a0b96c211241459f34a",
36 | "IPY_MODEL_9578e983215a4bc497b5882a4244b8a9",
37 | "IPY_MODEL_47b6aefd82c04f0fb3a3de2c7f37c58f"
38 | ]
39 | }
40 | },
41 | "23f3f545c30d4226987e7909e8365413": {
42 | "model_module": "@jupyter-widgets/base",
43 | "model_name": "LayoutModel",
44 | "model_module_version": "1.2.0",
45 | "state": {
46 | "_view_name": "LayoutView",
47 | "grid_template_rows": null,
48 | "right": null,
49 | "justify_content": null,
50 | "_view_module": "@jupyter-widgets/base",
51 | "overflow": null,
52 | "_model_module_version": "1.2.0",
53 | "_view_count": null,
54 | "flex_flow": null,
55 | "width": null,
56 | "min_width": null,
57 | "border": null,
58 | "align_items": null,
59 | "bottom": null,
60 | "_model_module": "@jupyter-widgets/base",
61 | "top": null,
62 | "grid_column": null,
63 | "overflow_y": null,
64 | "overflow_x": null,
65 | "grid_auto_flow": null,
66 | "grid_area": null,
67 | "grid_template_columns": null,
68 | "flex": null,
69 | "_model_name": "LayoutModel",
70 | "justify_items": null,
71 | "grid_row": null,
72 | "max_height": null,
73 | "align_content": null,
74 | "visibility": null,
75 | "align_self": null,
76 | "height": null,
77 | "min_height": null,
78 | "padding": null,
79 | "grid_auto_rows": null,
80 | "grid_gap": null,
81 | "max_width": null,
82 | "order": null,
83 | "_view_module_version": "1.2.0",
84 | "grid_template_areas": null,
85 | "object_position": null,
86 | "object_fit": null,
87 | "grid_auto_columns": null,
88 | "margin": null,
89 | "display": null,
90 | "left": null
91 | }
92 | },
93 | "3d29469f501c4a0b96c211241459f34a": {
94 | "model_module": "@jupyter-widgets/controls",
95 | "model_name": "HTMLModel",
96 | "model_module_version": "1.5.0",
97 | "state": {
98 | "_view_name": "HTMLView",
99 | "style": "IPY_MODEL_b1e066e72ca04508bf50a60961e66bea",
100 | "_dom_classes": [],
101 | "description": "",
102 | "_model_name": "HTMLModel",
103 | "placeholder": "",
104 | "_view_module": "@jupyter-widgets/controls",
105 | "_model_module_version": "1.5.0",
106 | "value": "Downloading: 100%",
107 | "_view_count": null,
108 | "_view_module_version": "1.5.0",
109 | "description_tooltip": null,
110 | "_model_module": "@jupyter-widgets/controls",
111 | "layout": "IPY_MODEL_96fe7d4c88e541eea33dde3889e3e381"
112 | }
113 | },
114 | "9578e983215a4bc497b5882a4244b8a9": {
115 | "model_module": "@jupyter-widgets/controls",
116 | "model_name": "FloatProgressModel",
117 | "model_module_version": "1.5.0",
118 | "state": {
119 | "_view_name": "ProgressView",
120 | "style": "IPY_MODEL_65b763eef5184606b28d4673e6ad857d",
121 | "_dom_classes": [],
122 | "description": "",
123 | "_model_name": "FloatProgressModel",
124 | "bar_style": "success",
125 | "max": 629,
126 | "_view_module": "@jupyter-widgets/controls",
127 | "_model_module_version": "1.5.0",
128 | "value": 629,
129 | "_view_count": null,
130 | "_view_module_version": "1.5.0",
131 | "orientation": "horizontal",
132 | "min": 0,
133 | "description_tooltip": null,
134 | "_model_module": "@jupyter-widgets/controls",
135 | "layout": "IPY_MODEL_40984bf56a8b448e82451fa60fc15f7c"
136 | }
137 | },
138 | "47b6aefd82c04f0fb3a3de2c7f37c58f": {
139 | "model_module": "@jupyter-widgets/controls",
140 | "model_name": "HTMLModel",
141 | "model_module_version": "1.5.0",
142 | "state": {
143 | "_view_name": "HTMLView",
144 | "style": "IPY_MODEL_21bf499e3fb34c02b63934ea6e4f162f",
145 | "_dom_classes": [],
146 | "description": "",
147 | "_model_name": "HTMLModel",
148 | "placeholder": "",
149 | "_view_module": "@jupyter-widgets/controls",
150 | "_model_module_version": "1.5.0",
151 | "value": " 629/629 [00:00<00:00, 10.2kB/s]",
152 | "_view_count": null,
153 | "_view_module_version": "1.5.0",
154 | "description_tooltip": null,
155 | "_model_module": "@jupyter-widgets/controls",
156 | "layout": "IPY_MODEL_9566729c64304007953c365eb0341df2"
157 | }
158 | },
159 | "b1e066e72ca04508bf50a60961e66bea": {
160 | "model_module": "@jupyter-widgets/controls",
161 | "model_name": "DescriptionStyleModel",
162 | "model_module_version": "1.5.0",
163 | "state": {
164 | "_view_name": "StyleView",
165 | "_model_name": "DescriptionStyleModel",
166 | "description_width": "",
167 | "_view_module": "@jupyter-widgets/base",
168 | "_model_module_version": "1.5.0",
169 | "_view_count": null,
170 | "_view_module_version": "1.2.0",
171 | "_model_module": "@jupyter-widgets/controls"
172 | }
173 | },
174 | "96fe7d4c88e541eea33dde3889e3e381": {
175 | "model_module": "@jupyter-widgets/base",
176 | "model_name": "LayoutModel",
177 | "model_module_version": "1.2.0",
178 | "state": {
179 | "_view_name": "LayoutView",
180 | "grid_template_rows": null,
181 | "right": null,
182 | "justify_content": null,
183 | "_view_module": "@jupyter-widgets/base",
184 | "overflow": null,
185 | "_model_module_version": "1.2.0",
186 | "_view_count": null,
187 | "flex_flow": null,
188 | "width": null,
189 | "min_width": null,
190 | "border": null,
191 | "align_items": null,
192 | "bottom": null,
193 | "_model_module": "@jupyter-widgets/base",
194 | "top": null,
195 | "grid_column": null,
196 | "overflow_y": null,
197 | "overflow_x": null,
198 | "grid_auto_flow": null,
199 | "grid_area": null,
200 | "grid_template_columns": null,
201 | "flex": null,
202 | "_model_name": "LayoutModel",
203 | "justify_items": null,
204 | "grid_row": null,
205 | "max_height": null,
206 | "align_content": null,
207 | "visibility": null,
208 | "align_self": null,
209 | "height": null,
210 | "min_height": null,
211 | "padding": null,
212 | "grid_auto_rows": null,
213 | "grid_gap": null,
214 | "max_width": null,
215 | "order": null,
216 | "_view_module_version": "1.2.0",
217 | "grid_template_areas": null,
218 | "object_position": null,
219 | "object_fit": null,
220 | "grid_auto_columns": null,
221 | "margin": null,
222 | "display": null,
223 | "left": null
224 | }
225 | },
226 | "65b763eef5184606b28d4673e6ad857d": {
227 | "model_module": "@jupyter-widgets/controls",
228 | "model_name": "ProgressStyleModel",
229 | "model_module_version": "1.5.0",
230 | "state": {
231 | "_view_name": "StyleView",
232 | "_model_name": "ProgressStyleModel",
233 | "description_width": "",
234 | "_view_module": "@jupyter-widgets/base",
235 | "_model_module_version": "1.5.0",
236 | "_view_count": null,
237 | "_view_module_version": "1.2.0",
238 | "bar_color": null,
239 | "_model_module": "@jupyter-widgets/controls"
240 | }
241 | },
242 | "40984bf56a8b448e82451fa60fc15f7c": {
243 | "model_module": "@jupyter-widgets/base",
244 | "model_name": "LayoutModel",
245 | "model_module_version": "1.2.0",
246 | "state": {
247 | "_view_name": "LayoutView",
248 | "grid_template_rows": null,
249 | "right": null,
250 | "justify_content": null,
251 | "_view_module": "@jupyter-widgets/base",
252 | "overflow": null,
253 | "_model_module_version": "1.2.0",
254 | "_view_count": null,
255 | "flex_flow": null,
256 | "width": null,
257 | "min_width": null,
258 | "border": null,
259 | "align_items": null,
260 | "bottom": null,
261 | "_model_module": "@jupyter-widgets/base",
262 | "top": null,
263 | "grid_column": null,
264 | "overflow_y": null,
265 | "overflow_x": null,
266 | "grid_auto_flow": null,
267 | "grid_area": null,
268 | "grid_template_columns": null,
269 | "flex": null,
270 | "_model_name": "LayoutModel",
271 | "justify_items": null,
272 | "grid_row": null,
273 | "max_height": null,
274 | "align_content": null,
275 | "visibility": null,
276 | "align_self": null,
277 | "height": null,
278 | "min_height": null,
279 | "padding": null,
280 | "grid_auto_rows": null,
281 | "grid_gap": null,
282 | "max_width": null,
283 | "order": null,
284 | "_view_module_version": "1.2.0",
285 | "grid_template_areas": null,
286 | "object_position": null,
287 | "object_fit": null,
288 | "grid_auto_columns": null,
289 | "margin": null,
290 | "display": null,
291 | "left": null
292 | }
293 | },
294 | "21bf499e3fb34c02b63934ea6e4f162f": {
295 | "model_module": "@jupyter-widgets/controls",
296 | "model_name": "DescriptionStyleModel",
297 | "model_module_version": "1.5.0",
298 | "state": {
299 | "_view_name": "StyleView",
300 | "_model_name": "DescriptionStyleModel",
301 | "description_width": "",
302 | "_view_module": "@jupyter-widgets/base",
303 | "_model_module_version": "1.5.0",
304 | "_view_count": null,
305 | "_view_module_version": "1.2.0",
306 | "_model_module": "@jupyter-widgets/controls"
307 | }
308 | },
309 | "9566729c64304007953c365eb0341df2": {
310 | "model_module": "@jupyter-widgets/base",
311 | "model_name": "LayoutModel",
312 | "model_module_version": "1.2.0",
313 | "state": {
314 | "_view_name": "LayoutView",
315 | "grid_template_rows": null,
316 | "right": null,
317 | "justify_content": null,
318 | "_view_module": "@jupyter-widgets/base",
319 | "overflow": null,
320 | "_model_module_version": "1.2.0",
321 | "_view_count": null,
322 | "flex_flow": null,
323 | "width": null,
324 | "min_width": null,
325 | "border": null,
326 | "align_items": null,
327 | "bottom": null,
328 | "_model_module": "@jupyter-widgets/base",
329 | "top": null,
330 | "grid_column": null,
331 | "overflow_y": null,
332 | "overflow_x": null,
333 | "grid_auto_flow": null,
334 | "grid_area": null,
335 | "grid_template_columns": null,
336 | "flex": null,
337 | "_model_name": "LayoutModel",
338 | "justify_items": null,
339 | "grid_row": null,
340 | "max_height": null,
341 | "align_content": null,
342 | "visibility": null,
343 | "align_self": null,
344 | "height": null,
345 | "min_height": null,
346 | "padding": null,
347 | "grid_auto_rows": null,
348 | "grid_gap": null,
349 | "max_width": null,
350 | "order": null,
351 | "_view_module_version": "1.2.0",
352 | "grid_template_areas": null,
353 | "object_position": null,
354 | "object_fit": null,
355 | "grid_auto_columns": null,
356 | "margin": null,
357 | "display": null,
358 | "left": null
359 | }
360 | },
361 | "6bf72948e9ce47d296cf9673d5c11ced": {
362 | "model_module": "@jupyter-widgets/controls",
363 | "model_name": "HBoxModel",
364 | "model_module_version": "1.5.0",
365 | "state": {
366 | "_view_name": "HBoxView",
367 | "_dom_classes": [],
368 | "_model_name": "HBoxModel",
369 | "_view_module": "@jupyter-widgets/controls",
370 | "_model_module_version": "1.5.0",
371 | "_view_count": null,
372 | "_view_module_version": "1.5.0",
373 | "box_style": "",
374 | "layout": "IPY_MODEL_e32995501a0243c1ab56ee42ecddd80e",
375 | "_model_module": "@jupyter-widgets/controls",
376 | "children": [
377 | "IPY_MODEL_e1c3b3a1b0ce49cc9db2bd16be25c929",
378 | "IPY_MODEL_a1f0479b17f8451daaf69b0af81b9adb",
379 | "IPY_MODEL_5f7a3c947c564436bd2dfcc04d4da5ee"
380 | ]
381 | }
382 | },
383 | "e32995501a0243c1ab56ee42ecddd80e": {
384 | "model_module": "@jupyter-widgets/base",
385 | "model_name": "LayoutModel",
386 | "model_module_version": "1.2.0",
387 | "state": {
388 | "_view_name": "LayoutView",
389 | "grid_template_rows": null,
390 | "right": null,
391 | "justify_content": null,
392 | "_view_module": "@jupyter-widgets/base",
393 | "overflow": null,
394 | "_model_module_version": "1.2.0",
395 | "_view_count": null,
396 | "flex_flow": null,
397 | "width": null,
398 | "min_width": null,
399 | "border": null,
400 | "align_items": null,
401 | "bottom": null,
402 | "_model_module": "@jupyter-widgets/base",
403 | "top": null,
404 | "grid_column": null,
405 | "overflow_y": null,
406 | "overflow_x": null,
407 | "grid_auto_flow": null,
408 | "grid_area": null,
409 | "grid_template_columns": null,
410 | "flex": null,
411 | "_model_name": "LayoutModel",
412 | "justify_items": null,
413 | "grid_row": null,
414 | "max_height": null,
415 | "align_content": null,
416 | "visibility": null,
417 | "align_self": null,
418 | "height": null,
419 | "min_height": null,
420 | "padding": null,
421 | "grid_auto_rows": null,
422 | "grid_gap": null,
423 | "max_width": null,
424 | "order": null,
425 | "_view_module_version": "1.2.0",
426 | "grid_template_areas": null,
427 | "object_position": null,
428 | "object_fit": null,
429 | "grid_auto_columns": null,
430 | "margin": null,
431 | "display": null,
432 | "left": null
433 | }
434 | },
435 | "e1c3b3a1b0ce49cc9db2bd16be25c929": {
436 | "model_module": "@jupyter-widgets/controls",
437 | "model_name": "HTMLModel",
438 | "model_module_version": "1.5.0",
439 | "state": {
440 | "_view_name": "HTMLView",
441 | "style": "IPY_MODEL_d0d537e5dc744a5794187c5a498358fc",
442 | "_dom_classes": [],
443 | "description": "",
444 | "_model_name": "HTMLModel",
445 | "placeholder": "",
446 | "_view_module": "@jupyter-widgets/controls",
447 | "_model_module_version": "1.5.0",
448 | "value": "Downloading: 100%",
449 | "_view_count": null,
450 | "_view_module_version": "1.5.0",
451 | "description_tooltip": null,
452 | "_model_module": "@jupyter-widgets/controls",
453 | "layout": "IPY_MODEL_bdbbb4175a6c4bd389755dc0f8d186cb"
454 | }
455 | },
456 | "a1f0479b17f8451daaf69b0af81b9adb": {
457 | "model_module": "@jupyter-widgets/controls",
458 | "model_name": "FloatProgressModel",
459 | "model_module_version": "1.5.0",
460 | "state": {
461 | "_view_name": "ProgressView",
462 | "style": "IPY_MODEL_902c7115997647b9a5140ba707d74390",
463 | "_dom_classes": [],
464 | "description": "",
465 | "_model_name": "FloatProgressModel",
466 | "bar_style": "success",
467 | "max": 267844284,
468 | "_view_module": "@jupyter-widgets/controls",
469 | "_model_module_version": "1.5.0",
470 | "value": 267844284,
471 | "_view_count": null,
472 | "_view_module_version": "1.5.0",
473 | "orientation": "horizontal",
474 | "min": 0,
475 | "description_tooltip": null,
476 | "_model_module": "@jupyter-widgets/controls",
477 | "layout": "IPY_MODEL_cbff745d3ebb4ba0b9bfd8b1a210550f"
478 | }
479 | },
480 | "5f7a3c947c564436bd2dfcc04d4da5ee": {
481 | "model_module": "@jupyter-widgets/controls",
482 | "model_name": "HTMLModel",
483 | "model_module_version": "1.5.0",
484 | "state": {
485 | "_view_name": "HTMLView",
486 | "style": "IPY_MODEL_db0e50cf79174b99b7943fe9fa51b738",
487 | "_dom_classes": [],
488 | "description": "",
489 | "_model_name": "HTMLModel",
490 | "placeholder": "",
491 | "_view_module": "@jupyter-widgets/controls",
492 | "_model_module_version": "1.5.0",
493 | "value": " 255M/255M [00:06<00:00, 35.0MB/s]",
494 | "_view_count": null,
495 | "_view_module_version": "1.5.0",
496 | "description_tooltip": null,
497 | "_model_module": "@jupyter-widgets/controls",
498 | "layout": "IPY_MODEL_91e992390cf94d4bb23164ccbfb2a4d0"
499 | }
500 | },
501 | "d0d537e5dc744a5794187c5a498358fc": {
502 | "model_module": "@jupyter-widgets/controls",
503 | "model_name": "DescriptionStyleModel",
504 | "model_module_version": "1.5.0",
505 | "state": {
506 | "_view_name": "StyleView",
507 | "_model_name": "DescriptionStyleModel",
508 | "description_width": "",
509 | "_view_module": "@jupyter-widgets/base",
510 | "_model_module_version": "1.5.0",
511 | "_view_count": null,
512 | "_view_module_version": "1.2.0",
513 | "_model_module": "@jupyter-widgets/controls"
514 | }
515 | },
516 | "bdbbb4175a6c4bd389755dc0f8d186cb": {
517 | "model_module": "@jupyter-widgets/base",
518 | "model_name": "LayoutModel",
519 | "model_module_version": "1.2.0",
520 | "state": {
521 | "_view_name": "LayoutView",
522 | "grid_template_rows": null,
523 | "right": null,
524 | "justify_content": null,
525 | "_view_module": "@jupyter-widgets/base",
526 | "overflow": null,
527 | "_model_module_version": "1.2.0",
528 | "_view_count": null,
529 | "flex_flow": null,
530 | "width": null,
531 | "min_width": null,
532 | "border": null,
533 | "align_items": null,
534 | "bottom": null,
535 | "_model_module": "@jupyter-widgets/base",
536 | "top": null,
537 | "grid_column": null,
538 | "overflow_y": null,
539 | "overflow_x": null,
540 | "grid_auto_flow": null,
541 | "grid_area": null,
542 | "grid_template_columns": null,
543 | "flex": null,
544 | "_model_name": "LayoutModel",
545 | "justify_items": null,
546 | "grid_row": null,
547 | "max_height": null,
548 | "align_content": null,
549 | "visibility": null,
550 | "align_self": null,
551 | "height": null,
552 | "min_height": null,
553 | "padding": null,
554 | "grid_auto_rows": null,
555 | "grid_gap": null,
556 | "max_width": null,
557 | "order": null,
558 | "_view_module_version": "1.2.0",
559 | "grid_template_areas": null,
560 | "object_position": null,
561 | "object_fit": null,
562 | "grid_auto_columns": null,
563 | "margin": null,
564 | "display": null,
565 | "left": null
566 | }
567 | },
568 | "902c7115997647b9a5140ba707d74390": {
569 | "model_module": "@jupyter-widgets/controls",
570 | "model_name": "ProgressStyleModel",
571 | "model_module_version": "1.5.0",
572 | "state": {
573 | "_view_name": "StyleView",
574 | "_model_name": "ProgressStyleModel",
575 | "description_width": "",
576 | "_view_module": "@jupyter-widgets/base",
577 | "_model_module_version": "1.5.0",
578 | "_view_count": null,
579 | "_view_module_version": "1.2.0",
580 | "bar_color": null,
581 | "_model_module": "@jupyter-widgets/controls"
582 | }
583 | },
584 | "cbff745d3ebb4ba0b9bfd8b1a210550f": {
585 | "model_module": "@jupyter-widgets/base",
586 | "model_name": "LayoutModel",
587 | "model_module_version": "1.2.0",
588 | "state": {
589 | "_view_name": "LayoutView",
590 | "grid_template_rows": null,
591 | "right": null,
592 | "justify_content": null,
593 | "_view_module": "@jupyter-widgets/base",
594 | "overflow": null,
595 | "_model_module_version": "1.2.0",
596 | "_view_count": null,
597 | "flex_flow": null,
598 | "width": null,
599 | "min_width": null,
600 | "border": null,
601 | "align_items": null,
602 | "bottom": null,
603 | "_model_module": "@jupyter-widgets/base",
604 | "top": null,
605 | "grid_column": null,
606 | "overflow_y": null,
607 | "overflow_x": null,
608 | "grid_auto_flow": null,
609 | "grid_area": null,
610 | "grid_template_columns": null,
611 | "flex": null,
612 | "_model_name": "LayoutModel",
613 | "justify_items": null,
614 | "grid_row": null,
615 | "max_height": null,
616 | "align_content": null,
617 | "visibility": null,
618 | "align_self": null,
619 | "height": null,
620 | "min_height": null,
621 | "padding": null,
622 | "grid_auto_rows": null,
623 | "grid_gap": null,
624 | "max_width": null,
625 | "order": null,
626 | "_view_module_version": "1.2.0",
627 | "grid_template_areas": null,
628 | "object_position": null,
629 | "object_fit": null,
630 | "grid_auto_columns": null,
631 | "margin": null,
632 | "display": null,
633 | "left": null
634 | }
635 | },
636 | "db0e50cf79174b99b7943fe9fa51b738": {
637 | "model_module": "@jupyter-widgets/controls",
638 | "model_name": "DescriptionStyleModel",
639 | "model_module_version": "1.5.0",
640 | "state": {
641 | "_view_name": "StyleView",
642 | "_model_name": "DescriptionStyleModel",
643 | "description_width": "",
644 | "_view_module": "@jupyter-widgets/base",
645 | "_model_module_version": "1.5.0",
646 | "_view_count": null,
647 | "_view_module_version": "1.2.0",
648 | "_model_module": "@jupyter-widgets/controls"
649 | }
650 | },
651 | "91e992390cf94d4bb23164ccbfb2a4d0": {
652 | "model_module": "@jupyter-widgets/base",
653 | "model_name": "LayoutModel",
654 | "model_module_version": "1.2.0",
655 | "state": {
656 | "_view_name": "LayoutView",
657 | "grid_template_rows": null,
658 | "right": null,
659 | "justify_content": null,
660 | "_view_module": "@jupyter-widgets/base",
661 | "overflow": null,
662 | "_model_module_version": "1.2.0",
663 | "_view_count": null,
664 | "flex_flow": null,
665 | "width": null,
666 | "min_width": null,
667 | "border": null,
668 | "align_items": null,
669 | "bottom": null,
670 | "_model_module": "@jupyter-widgets/base",
671 | "top": null,
672 | "grid_column": null,
673 | "overflow_y": null,
674 | "overflow_x": null,
675 | "grid_auto_flow": null,
676 | "grid_area": null,
677 | "grid_template_columns": null,
678 | "flex": null,
679 | "_model_name": "LayoutModel",
680 | "justify_items": null,
681 | "grid_row": null,
682 | "max_height": null,
683 | "align_content": null,
684 | "visibility": null,
685 | "align_self": null,
686 | "height": null,
687 | "min_height": null,
688 | "padding": null,
689 | "grid_auto_rows": null,
690 | "grid_gap": null,
691 | "max_width": null,
692 | "order": null,
693 | "_view_module_version": "1.2.0",
694 | "grid_template_areas": null,
695 | "object_position": null,
696 | "object_fit": null,
697 | "grid_auto_columns": null,
698 | "margin": null,
699 | "display": null,
700 | "left": null
701 | }
702 | },
703 | "a66fb6293209491caf0f05bbab72c1f3": {
704 | "model_module": "@jupyter-widgets/controls",
705 | "model_name": "HBoxModel",
706 | "model_module_version": "1.5.0",
707 | "state": {
708 | "_view_name": "HBoxView",
709 | "_dom_classes": [],
710 | "_model_name": "HBoxModel",
711 | "_view_module": "@jupyter-widgets/controls",
712 | "_model_module_version": "1.5.0",
713 | "_view_count": null,
714 | "_view_module_version": "1.5.0",
715 | "box_style": "",
716 | "layout": "IPY_MODEL_b226e034649d4a27968a24f0215acfa2",
717 | "_model_module": "@jupyter-widgets/controls",
718 | "children": [
719 | "IPY_MODEL_143a846ef031430a94ca03ce798e1c38",
720 | "IPY_MODEL_02cd0740d8394ee9978e2205bd7c0885",
721 | "IPY_MODEL_fe2d58558248438f89639669b20d044a"
722 | ]
723 | }
724 | },
725 | "b226e034649d4a27968a24f0215acfa2": {
726 | "model_module": "@jupyter-widgets/base",
727 | "model_name": "LayoutModel",
728 | "model_module_version": "1.2.0",
729 | "state": {
730 | "_view_name": "LayoutView",
731 | "grid_template_rows": null,
732 | "right": null,
733 | "justify_content": null,
734 | "_view_module": "@jupyter-widgets/base",
735 | "overflow": null,
736 | "_model_module_version": "1.2.0",
737 | "_view_count": null,
738 | "flex_flow": null,
739 | "width": null,
740 | "min_width": null,
741 | "border": null,
742 | "align_items": null,
743 | "bottom": null,
744 | "_model_module": "@jupyter-widgets/base",
745 | "top": null,
746 | "grid_column": null,
747 | "overflow_y": null,
748 | "overflow_x": null,
749 | "grid_auto_flow": null,
750 | "grid_area": null,
751 | "grid_template_columns": null,
752 | "flex": null,
753 | "_model_name": "LayoutModel",
754 | "justify_items": null,
755 | "grid_row": null,
756 | "max_height": null,
757 | "align_content": null,
758 | "visibility": null,
759 | "align_self": null,
760 | "height": null,
761 | "min_height": null,
762 | "padding": null,
763 | "grid_auto_rows": null,
764 | "grid_gap": null,
765 | "max_width": null,
766 | "order": null,
767 | "_view_module_version": "1.2.0",
768 | "grid_template_areas": null,
769 | "object_position": null,
770 | "object_fit": null,
771 | "grid_auto_columns": null,
772 | "margin": null,
773 | "display": null,
774 | "left": null
775 | }
776 | },
777 | "143a846ef031430a94ca03ce798e1c38": {
778 | "model_module": "@jupyter-widgets/controls",
779 | "model_name": "HTMLModel",
780 | "model_module_version": "1.5.0",
781 | "state": {
782 | "_view_name": "HTMLView",
783 | "style": "IPY_MODEL_63e7051bccb64b19a9eb416d6a2aa63d",
784 | "_dom_classes": [],
785 | "description": "",
786 | "_model_name": "HTMLModel",
787 | "placeholder": "",
788 | "_view_module": "@jupyter-widgets/controls",
789 | "_model_module_version": "1.5.0",
790 | "value": "Downloading: 100%",
791 | "_view_count": null,
792 | "_view_module_version": "1.5.0",
793 | "description_tooltip": null,
794 | "_model_module": "@jupyter-widgets/controls",
795 | "layout": "IPY_MODEL_5cb83ac370b54d9bb51443cfb4b1d324"
796 | }
797 | },
798 | "02cd0740d8394ee9978e2205bd7c0885": {
799 | "model_module": "@jupyter-widgets/controls",
800 | "model_name": "FloatProgressModel",
801 | "model_module_version": "1.5.0",
802 | "state": {
803 | "_view_name": "ProgressView",
804 | "style": "IPY_MODEL_33843555ad2e49789f8704d6abda7c00",
805 | "_dom_classes": [],
806 | "description": "",
807 | "_model_name": "FloatProgressModel",
808 | "bar_style": "success",
809 | "max": 48,
810 | "_view_module": "@jupyter-widgets/controls",
811 | "_model_module_version": "1.5.0",
812 | "value": 48,
813 | "_view_count": null,
814 | "_view_module_version": "1.5.0",
815 | "orientation": "horizontal",
816 | "min": 0,
817 | "description_tooltip": null,
818 | "_model_module": "@jupyter-widgets/controls",
819 | "layout": "IPY_MODEL_0ce9c5b524074ee9996cfb7ba2c35c83"
820 | }
821 | },
822 | "fe2d58558248438f89639669b20d044a": {
823 | "model_module": "@jupyter-widgets/controls",
824 | "model_name": "HTMLModel",
825 | "model_module_version": "1.5.0",
826 | "state": {
827 | "_view_name": "HTMLView",
828 | "style": "IPY_MODEL_04653b5fc47244c7abe96622f28b13f7",
829 | "_dom_classes": [],
830 | "description": "",
831 | "_model_name": "HTMLModel",
832 | "placeholder": "",
833 | "_view_module": "@jupyter-widgets/controls",
834 | "_model_module_version": "1.5.0",
835 | "value": " 48.0/48.0 [00:00<00:00, 493B/s]",
836 | "_view_count": null,
837 | "_view_module_version": "1.5.0",
838 | "description_tooltip": null,
839 | "_model_module": "@jupyter-widgets/controls",
840 | "layout": "IPY_MODEL_40325b42a719450eb4ad85b7c4ff0423"
841 | }
842 | },
843 | "63e7051bccb64b19a9eb416d6a2aa63d": {
844 | "model_module": "@jupyter-widgets/controls",
845 | "model_name": "DescriptionStyleModel",
846 | "model_module_version": "1.5.0",
847 | "state": {
848 | "_view_name": "StyleView",
849 | "_model_name": "DescriptionStyleModel",
850 | "description_width": "",
851 | "_view_module": "@jupyter-widgets/base",
852 | "_model_module_version": "1.5.0",
853 | "_view_count": null,
854 | "_view_module_version": "1.2.0",
855 | "_model_module": "@jupyter-widgets/controls"
856 | }
857 | },
858 | "5cb83ac370b54d9bb51443cfb4b1d324": {
859 | "model_module": "@jupyter-widgets/base",
860 | "model_name": "LayoutModel",
861 | "model_module_version": "1.2.0",
862 | "state": {
863 | "_view_name": "LayoutView",
864 | "grid_template_rows": null,
865 | "right": null,
866 | "justify_content": null,
867 | "_view_module": "@jupyter-widgets/base",
868 | "overflow": null,
869 | "_model_module_version": "1.2.0",
870 | "_view_count": null,
871 | "flex_flow": null,
872 | "width": null,
873 | "min_width": null,
874 | "border": null,
875 | "align_items": null,
876 | "bottom": null,
877 | "_model_module": "@jupyter-widgets/base",
878 | "top": null,
879 | "grid_column": null,
880 | "overflow_y": null,
881 | "overflow_x": null,
882 | "grid_auto_flow": null,
883 | "grid_area": null,
884 | "grid_template_columns": null,
885 | "flex": null,
886 | "_model_name": "LayoutModel",
887 | "justify_items": null,
888 | "grid_row": null,
889 | "max_height": null,
890 | "align_content": null,
891 | "visibility": null,
892 | "align_self": null,
893 | "height": null,
894 | "min_height": null,
895 | "padding": null,
896 | "grid_auto_rows": null,
897 | "grid_gap": null,
898 | "max_width": null,
899 | "order": null,
900 | "_view_module_version": "1.2.0",
901 | "grid_template_areas": null,
902 | "object_position": null,
903 | "object_fit": null,
904 | "grid_auto_columns": null,
905 | "margin": null,
906 | "display": null,
907 | "left": null
908 | }
909 | },
910 | "33843555ad2e49789f8704d6abda7c00": {
911 | "model_module": "@jupyter-widgets/controls",
912 | "model_name": "ProgressStyleModel",
913 | "model_module_version": "1.5.0",
914 | "state": {
915 | "_view_name": "StyleView",
916 | "_model_name": "ProgressStyleModel",
917 | "description_width": "",
918 | "_view_module": "@jupyter-widgets/base",
919 | "_model_module_version": "1.5.0",
920 | "_view_count": null,
921 | "_view_module_version": "1.2.0",
922 | "bar_color": null,
923 | "_model_module": "@jupyter-widgets/controls"
924 | }
925 | },
926 | "0ce9c5b524074ee9996cfb7ba2c35c83": {
927 | "model_module": "@jupyter-widgets/base",
928 | "model_name": "LayoutModel",
929 | "model_module_version": "1.2.0",
930 | "state": {
931 | "_view_name": "LayoutView",
932 | "grid_template_rows": null,
933 | "right": null,
934 | "justify_content": null,
935 | "_view_module": "@jupyter-widgets/base",
936 | "overflow": null,
937 | "_model_module_version": "1.2.0",
938 | "_view_count": null,
939 | "flex_flow": null,
940 | "width": null,
941 | "min_width": null,
942 | "border": null,
943 | "align_items": null,
944 | "bottom": null,
945 | "_model_module": "@jupyter-widgets/base",
946 | "top": null,
947 | "grid_column": null,
948 | "overflow_y": null,
949 | "overflow_x": null,
950 | "grid_auto_flow": null,
951 | "grid_area": null,
952 | "grid_template_columns": null,
953 | "flex": null,
954 | "_model_name": "LayoutModel",
955 | "justify_items": null,
956 | "grid_row": null,
957 | "max_height": null,
958 | "align_content": null,
959 | "visibility": null,
960 | "align_self": null,
961 | "height": null,
962 | "min_height": null,
963 | "padding": null,
964 | "grid_auto_rows": null,
965 | "grid_gap": null,
966 | "max_width": null,
967 | "order": null,
968 | "_view_module_version": "1.2.0",
969 | "grid_template_areas": null,
970 | "object_position": null,
971 | "object_fit": null,
972 | "grid_auto_columns": null,
973 | "margin": null,
974 | "display": null,
975 | "left": null
976 | }
977 | },
978 | "04653b5fc47244c7abe96622f28b13f7": {
979 | "model_module": "@jupyter-widgets/controls",
980 | "model_name": "DescriptionStyleModel",
981 | "model_module_version": "1.5.0",
982 | "state": {
983 | "_view_name": "StyleView",
984 | "_model_name": "DescriptionStyleModel",
985 | "description_width": "",
986 | "_view_module": "@jupyter-widgets/base",
987 | "_model_module_version": "1.5.0",
988 | "_view_count": null,
989 | "_view_module_version": "1.2.0",
990 | "_model_module": "@jupyter-widgets/controls"
991 | }
992 | },
993 | "40325b42a719450eb4ad85b7c4ff0423": {
994 | "model_module": "@jupyter-widgets/base",
995 | "model_name": "LayoutModel",
996 | "model_module_version": "1.2.0",
997 | "state": {
998 | "_view_name": "LayoutView",
999 | "grid_template_rows": null,
1000 | "right": null,
1001 | "justify_content": null,
1002 | "_view_module": "@jupyter-widgets/base",
1003 | "overflow": null,
1004 | "_model_module_version": "1.2.0",
1005 | "_view_count": null,
1006 | "flex_flow": null,
1007 | "width": null,
1008 | "min_width": null,
1009 | "border": null,
1010 | "align_items": null,
1011 | "bottom": null,
1012 | "_model_module": "@jupyter-widgets/base",
1013 | "top": null,
1014 | "grid_column": null,
1015 | "overflow_y": null,
1016 | "overflow_x": null,
1017 | "grid_auto_flow": null,
1018 | "grid_area": null,
1019 | "grid_template_columns": null,
1020 | "flex": null,
1021 | "_model_name": "LayoutModel",
1022 | "justify_items": null,
1023 | "grid_row": null,
1024 | "max_height": null,
1025 | "align_content": null,
1026 | "visibility": null,
1027 | "align_self": null,
1028 | "height": null,
1029 | "min_height": null,
1030 | "padding": null,
1031 | "grid_auto_rows": null,
1032 | "grid_gap": null,
1033 | "max_width": null,
1034 | "order": null,
1035 | "_view_module_version": "1.2.0",
1036 | "grid_template_areas": null,
1037 | "object_position": null,
1038 | "object_fit": null,
1039 | "grid_auto_columns": null,
1040 | "margin": null,
1041 | "display": null,
1042 | "left": null
1043 | }
1044 | },
1045 | "a56872e4b7fd495f8ba3a1e2a63788bc": {
1046 | "model_module": "@jupyter-widgets/controls",
1047 | "model_name": "HBoxModel",
1048 | "model_module_version": "1.5.0",
1049 | "state": {
1050 | "_view_name": "HBoxView",
1051 | "_dom_classes": [],
1052 | "_model_name": "HBoxModel",
1053 | "_view_module": "@jupyter-widgets/controls",
1054 | "_model_module_version": "1.5.0",
1055 | "_view_count": null,
1056 | "_view_module_version": "1.5.0",
1057 | "box_style": "",
1058 | "layout": "IPY_MODEL_edd0e5ac02764983bcf87308769e20af",
1059 | "_model_module": "@jupyter-widgets/controls",
1060 | "children": [
1061 | "IPY_MODEL_c847204177474e33a61f49c4db5d8ac1",
1062 | "IPY_MODEL_0d1bdbc672ee4f83b836b363863e223f",
1063 | "IPY_MODEL_ac7b1ed676894875971eadbfc22bfd07"
1064 | ]
1065 | }
1066 | },
1067 | "edd0e5ac02764983bcf87308769e20af": {
1068 | "model_module": "@jupyter-widgets/base",
1069 | "model_name": "LayoutModel",
1070 | "model_module_version": "1.2.0",
1071 | "state": {
1072 | "_view_name": "LayoutView",
1073 | "grid_template_rows": null,
1074 | "right": null,
1075 | "justify_content": null,
1076 | "_view_module": "@jupyter-widgets/base",
1077 | "overflow": null,
1078 | "_model_module_version": "1.2.0",
1079 | "_view_count": null,
1080 | "flex_flow": null,
1081 | "width": null,
1082 | "min_width": null,
1083 | "border": null,
1084 | "align_items": null,
1085 | "bottom": null,
1086 | "_model_module": "@jupyter-widgets/base",
1087 | "top": null,
1088 | "grid_column": null,
1089 | "overflow_y": null,
1090 | "overflow_x": null,
1091 | "grid_auto_flow": null,
1092 | "grid_area": null,
1093 | "grid_template_columns": null,
1094 | "flex": null,
1095 | "_model_name": "LayoutModel",
1096 | "justify_items": null,
1097 | "grid_row": null,
1098 | "max_height": null,
1099 | "align_content": null,
1100 | "visibility": null,
1101 | "align_self": null,
1102 | "height": null,
1103 | "min_height": null,
1104 | "padding": null,
1105 | "grid_auto_rows": null,
1106 | "grid_gap": null,
1107 | "max_width": null,
1108 | "order": null,
1109 | "_view_module_version": "1.2.0",
1110 | "grid_template_areas": null,
1111 | "object_position": null,
1112 | "object_fit": null,
1113 | "grid_auto_columns": null,
1114 | "margin": null,
1115 | "display": null,
1116 | "left": null
1117 | }
1118 | },
1119 | "c847204177474e33a61f49c4db5d8ac1": {
1120 | "model_module": "@jupyter-widgets/controls",
1121 | "model_name": "HTMLModel",
1122 | "model_module_version": "1.5.0",
1123 | "state": {
1124 | "_view_name": "HTMLView",
1125 | "style": "IPY_MODEL_9f56b91e83444c8b9272a7a469156ca9",
1126 | "_dom_classes": [],
1127 | "description": "",
1128 | "_model_name": "HTMLModel",
1129 | "placeholder": "",
1130 | "_view_module": "@jupyter-widgets/controls",
1131 | "_model_module_version": "1.5.0",
1132 | "value": "Downloading: 100%",
1133 | "_view_count": null,
1134 | "_view_module_version": "1.5.0",
1135 | "description_tooltip": null,
1136 | "_model_module": "@jupyter-widgets/controls",
1137 | "layout": "IPY_MODEL_31f205c738f34949a55a20f282e7ce79"
1138 | }
1139 | },
1140 | "0d1bdbc672ee4f83b836b363863e223f": {
1141 | "model_module": "@jupyter-widgets/controls",
1142 | "model_name": "FloatProgressModel",
1143 | "model_module_version": "1.5.0",
1144 | "state": {
1145 | "_view_name": "ProgressView",
1146 | "style": "IPY_MODEL_a618850b03e949c495390e4c249ae29d",
1147 | "_dom_classes": [],
1148 | "description": "",
1149 | "_model_name": "FloatProgressModel",
1150 | "bar_style": "success",
1151 | "max": 231508,
1152 | "_view_module": "@jupyter-widgets/controls",
1153 | "_model_module_version": "1.5.0",
1154 | "value": 231508,
1155 | "_view_count": null,
1156 | "_view_module_version": "1.5.0",
1157 | "orientation": "horizontal",
1158 | "min": 0,
1159 | "description_tooltip": null,
1160 | "_model_module": "@jupyter-widgets/controls",
1161 | "layout": "IPY_MODEL_4846e7b941924fd8882c82def4094005"
1162 | }
1163 | },
1164 | "ac7b1ed676894875971eadbfc22bfd07": {
1165 | "model_module": "@jupyter-widgets/controls",
1166 | "model_name": "HTMLModel",
1167 | "model_module_version": "1.5.0",
1168 | "state": {
1169 | "_view_name": "HTMLView",
1170 | "style": "IPY_MODEL_2d2e0dc44d154729adcff6ced9f29042",
1171 | "_dom_classes": [],
1172 | "description": "",
1173 | "_model_name": "HTMLModel",
1174 | "placeholder": "",
1175 | "_view_module": "@jupyter-widgets/controls",
1176 | "_model_module_version": "1.5.0",
1177 | "value": " 226k/226k [00:00<00:00, 1.04MB/s]",
1178 | "_view_count": null,
1179 | "_view_module_version": "1.5.0",
1180 | "description_tooltip": null,
1181 | "_model_module": "@jupyter-widgets/controls",
1182 | "layout": "IPY_MODEL_25b532f5020a4790b3ec5e4a71e1b91b"
1183 | }
1184 | },
1185 | "9f56b91e83444c8b9272a7a469156ca9": {
1186 | "model_module": "@jupyter-widgets/controls",
1187 | "model_name": "DescriptionStyleModel",
1188 | "model_module_version": "1.5.0",
1189 | "state": {
1190 | "_view_name": "StyleView",
1191 | "_model_name": "DescriptionStyleModel",
1192 | "description_width": "",
1193 | "_view_module": "@jupyter-widgets/base",
1194 | "_model_module_version": "1.5.0",
1195 | "_view_count": null,
1196 | "_view_module_version": "1.2.0",
1197 | "_model_module": "@jupyter-widgets/controls"
1198 | }
1199 | },
1200 | "31f205c738f34949a55a20f282e7ce79": {
1201 | "model_module": "@jupyter-widgets/base",
1202 | "model_name": "LayoutModel",
1203 | "model_module_version": "1.2.0",
1204 | "state": {
1205 | "_view_name": "LayoutView",
1206 | "grid_template_rows": null,
1207 | "right": null,
1208 | "justify_content": null,
1209 | "_view_module": "@jupyter-widgets/base",
1210 | "overflow": null,
1211 | "_model_module_version": "1.2.0",
1212 | "_view_count": null,
1213 | "flex_flow": null,
1214 | "width": null,
1215 | "min_width": null,
1216 | "border": null,
1217 | "align_items": null,
1218 | "bottom": null,
1219 | "_model_module": "@jupyter-widgets/base",
1220 | "top": null,
1221 | "grid_column": null,
1222 | "overflow_y": null,
1223 | "overflow_x": null,
1224 | "grid_auto_flow": null,
1225 | "grid_area": null,
1226 | "grid_template_columns": null,
1227 | "flex": null,
1228 | "_model_name": "LayoutModel",
1229 | "justify_items": null,
1230 | "grid_row": null,
1231 | "max_height": null,
1232 | "align_content": null,
1233 | "visibility": null,
1234 | "align_self": null,
1235 | "height": null,
1236 | "min_height": null,
1237 | "padding": null,
1238 | "grid_auto_rows": null,
1239 | "grid_gap": null,
1240 | "max_width": null,
1241 | "order": null,
1242 | "_view_module_version": "1.2.0",
1243 | "grid_template_areas": null,
1244 | "object_position": null,
1245 | "object_fit": null,
1246 | "grid_auto_columns": null,
1247 | "margin": null,
1248 | "display": null,
1249 | "left": null
1250 | }
1251 | },
1252 | "a618850b03e949c495390e4c249ae29d": {
1253 | "model_module": "@jupyter-widgets/controls",
1254 | "model_name": "ProgressStyleModel",
1255 | "model_module_version": "1.5.0",
1256 | "state": {
1257 | "_view_name": "StyleView",
1258 | "_model_name": "ProgressStyleModel",
1259 | "description_width": "",
1260 | "_view_module": "@jupyter-widgets/base",
1261 | "_model_module_version": "1.5.0",
1262 | "_view_count": null,
1263 | "_view_module_version": "1.2.0",
1264 | "bar_color": null,
1265 | "_model_module": "@jupyter-widgets/controls"
1266 | }
1267 | },
1268 | "4846e7b941924fd8882c82def4094005": {
1269 | "model_module": "@jupyter-widgets/base",
1270 | "model_name": "LayoutModel",
1271 | "model_module_version": "1.2.0",
1272 | "state": {
1273 | "_view_name": "LayoutView",
1274 | "grid_template_rows": null,
1275 | "right": null,
1276 | "justify_content": null,
1277 | "_view_module": "@jupyter-widgets/base",
1278 | "overflow": null,
1279 | "_model_module_version": "1.2.0",
1280 | "_view_count": null,
1281 | "flex_flow": null,
1282 | "width": null,
1283 | "min_width": null,
1284 | "border": null,
1285 | "align_items": null,
1286 | "bottom": null,
1287 | "_model_module": "@jupyter-widgets/base",
1288 | "top": null,
1289 | "grid_column": null,
1290 | "overflow_y": null,
1291 | "overflow_x": null,
1292 | "grid_auto_flow": null,
1293 | "grid_area": null,
1294 | "grid_template_columns": null,
1295 | "flex": null,
1296 | "_model_name": "LayoutModel",
1297 | "justify_items": null,
1298 | "grid_row": null,
1299 | "max_height": null,
1300 | "align_content": null,
1301 | "visibility": null,
1302 | "align_self": null,
1303 | "height": null,
1304 | "min_height": null,
1305 | "padding": null,
1306 | "grid_auto_rows": null,
1307 | "grid_gap": null,
1308 | "max_width": null,
1309 | "order": null,
1310 | "_view_module_version": "1.2.0",
1311 | "grid_template_areas": null,
1312 | "object_position": null,
1313 | "object_fit": null,
1314 | "grid_auto_columns": null,
1315 | "margin": null,
1316 | "display": null,
1317 | "left": null
1318 | }
1319 | },
1320 | "2d2e0dc44d154729adcff6ced9f29042": {
1321 | "model_module": "@jupyter-widgets/controls",
1322 | "model_name": "DescriptionStyleModel",
1323 | "model_module_version": "1.5.0",
1324 | "state": {
1325 | "_view_name": "StyleView",
1326 | "_model_name": "DescriptionStyleModel",
1327 | "description_width": "",
1328 | "_view_module": "@jupyter-widgets/base",
1329 | "_model_module_version": "1.5.0",
1330 | "_view_count": null,
1331 | "_view_module_version": "1.2.0",
1332 | "_model_module": "@jupyter-widgets/controls"
1333 | }
1334 | },
1335 | "25b532f5020a4790b3ec5e4a71e1b91b": {
1336 | "model_module": "@jupyter-widgets/base",
1337 | "model_name": "LayoutModel",
1338 | "model_module_version": "1.2.0",
1339 | "state": {
1340 | "_view_name": "LayoutView",
1341 | "grid_template_rows": null,
1342 | "right": null,
1343 | "justify_content": null,
1344 | "_view_module": "@jupyter-widgets/base",
1345 | "overflow": null,
1346 | "_model_module_version": "1.2.0",
1347 | "_view_count": null,
1348 | "flex_flow": null,
1349 | "width": null,
1350 | "min_width": null,
1351 | "border": null,
1352 | "align_items": null,
1353 | "bottom": null,
1354 | "_model_module": "@jupyter-widgets/base",
1355 | "top": null,
1356 | "grid_column": null,
1357 | "overflow_y": null,
1358 | "overflow_x": null,
1359 | "grid_auto_flow": null,
1360 | "grid_area": null,
1361 | "grid_template_columns": null,
1362 | "flex": null,
1363 | "_model_name": "LayoutModel",
1364 | "justify_items": null,
1365 | "grid_row": null,
1366 | "max_height": null,
1367 | "align_content": null,
1368 | "visibility": null,
1369 | "align_self": null,
1370 | "height": null,
1371 | "min_height": null,
1372 | "padding": null,
1373 | "grid_auto_rows": null,
1374 | "grid_gap": null,
1375 | "max_width": null,
1376 | "order": null,
1377 | "_view_module_version": "1.2.0",
1378 | "grid_template_areas": null,
1379 | "object_position": null,
1380 | "object_fit": null,
1381 | "grid_auto_columns": null,
1382 | "margin": null,
1383 | "display": null,
1384 | "left": null
1385 | }
1386 | }
1387 | }
1388 | }
1389 | },
1390 | "cells": [
1391 | {
1392 | "cell_type": "markdown",
1393 | "metadata": {
1394 | "id": "view-in-github",
1395 | "colab_type": "text"
1396 | },
1397 | "source": [
1398 | "
"
1399 | ]
1400 | },
1401 | {
1402 | "cell_type": "markdown",
1403 | "metadata": {
1404 | "id": "4IYErMLarJTr"
1405 | },
1406 | "source": [
1407 | "## Example of TreeMix"
1408 | ]
1409 | },
1410 | {
1411 | "cell_type": "markdown",
1412 | "metadata": {
1413 | "id": "XJEVX6F9rQdI"
1414 | },
1415 | "source": [
1416 | "NOTE : This page was originally used by HuggingFace to illustrate the summary of various tasks ([original page](https://colab.research.google.com/github/huggingface/notebooks/blob/master/transformers_doc/pytorch/task_summary.ipynb#scrollTo=XJEVX6F9rQdI)), we use it to show the examples we illustrate in our paper. We follow the original settings and just change the sentences to predict. This is a sequence classification model trained on full SST2 datasets."
1417 | ]
1418 | },
1419 | {
1420 | "cell_type": "code",
1421 | "metadata": {
1422 | "id": "D4qYNi9oqDXN",
1423 | "colab": {
1424 | "base_uri": "https://localhost:8080/"
1425 | },
1426 | "outputId": "0aa9ff6c-9b18-4e98-90d6-24057dbbb320"
1427 | },
1428 | "source": [
1429 | "# Transformers installation\n",
1430 | "! pip install transformers datasets\n",
1431 | "# To install from source instead of the last release, comment the command above and uncomment the following one.\n",
1432 | "# ! pip install git+https://github.com/huggingface/transformers.git\n"
1433 | ],
1434 | "execution_count": null,
1435 | "outputs": [
1436 | {
1437 | "output_type": "stream",
1438 | "name": "stdout",
1439 | "text": [
1440 | "Collecting transformers\n",
1441 | " Downloading transformers-4.12.2-py3-none-any.whl (3.1 MB)\n",
1442 | "\u001b[K |████████████████████████████████| 3.1 MB 4.9 MB/s \n",
1443 | "\u001b[?25hCollecting datasets\n",
1444 | " Downloading datasets-1.14.0-py3-none-any.whl (290 kB)\n",
1445 | "\u001b[K |████████████████████████████████| 290 kB 42.2 MB/s \n",
1446 | "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers) (3.3.0)\n",
1447 | "Collecting huggingface-hub>=0.0.17\n",
1448 | " Downloading huggingface_hub-0.0.19-py3-none-any.whl (56 kB)\n",
1449 | "\u001b[K |████████████████████████████████| 56 kB 4.1 MB/s \n",
1450 | "\u001b[?25hRequirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers) (2.23.0)\n",
1451 | "Collecting sacremoses\n",
1452 | " Downloading sacremoses-0.0.46-py3-none-any.whl (895 kB)\n",
1453 | "\u001b[K |████████████████████████████████| 895 kB 56.3 MB/s \n",
1454 | "\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (1.19.5)\n",
1455 | "Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers) (4.8.1)\n",
1456 | "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (21.0)\n",
1457 | "Collecting pyyaml>=5.1\n",
1458 | " Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)\n",
1459 | "\u001b[K |████████████████████████████████| 596 kB 57.1 MB/s \n",
1460 | "\u001b[?25hRequirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers) (4.62.3)\n",
1461 | "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (2019.12.20)\n",
1462 | "Collecting tokenizers<0.11,>=0.10.1\n",
1463 | " Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)\n",
1464 | "\u001b[K |████████████████████████████████| 3.3 MB 30.8 MB/s \n",
1465 | "\u001b[?25hRequirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from huggingface-hub>=0.0.17->transformers) (3.7.4.3)\n",
1466 | "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->transformers) (2.4.7)\n",
1467 | "Collecting xxhash\n",
1468 | " Downloading xxhash-2.0.2-cp37-cp37m-manylinux2010_x86_64.whl (243 kB)\n",
1469 | "\u001b[K |████████████████████████████████| 243 kB 57.5 MB/s \n",
1470 | "\u001b[?25hRequirement already satisfied: dill in /usr/local/lib/python3.7/dist-packages (from datasets) (0.3.4)\n",
1471 | "Requirement already satisfied: multiprocess in /usr/local/lib/python3.7/dist-packages (from datasets) (0.70.12.2)\n",
1472 | "Collecting aiohttp\n",
1473 | " Downloading aiohttp-3.8.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.1 MB)\n",
1474 | "\u001b[K |████████████████████████████████| 1.1 MB 57.7 MB/s \n",
1475 | "\u001b[?25hCollecting fsspec[http]>=2021.05.0\n",
1476 | " Downloading fsspec-2021.10.1-py3-none-any.whl (125 kB)\n",
1477 | "\u001b[K |████████████████████████████████| 125 kB 60.6 MB/s \n",
1478 | "\u001b[?25hRequirement already satisfied: pyarrow!=4.0.0,>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (3.0.0)\n",
1479 | "Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from datasets) (1.1.5)\n",
1480 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2.10)\n",
1481 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (1.24.3)\n",
1482 | "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (3.0.4)\n",
1483 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2021.5.30)\n",
1484 | "Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (2.0.7)\n",
1485 | "Collecting aiosignal>=1.1.2\n",
1486 | " Downloading aiosignal-1.2.0-py3-none-any.whl (8.2 kB)\n",
1487 | "Collecting yarl<2.0,>=1.0\n",
1488 | " Downloading yarl-1.7.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (271 kB)\n",
1489 | "\u001b[K |████████████████████████████████| 271 kB 58.0 MB/s \n",
1490 | "\u001b[?25hCollecting asynctest==0.13.0\n",
1491 | " Downloading asynctest-0.13.0-py3-none-any.whl (26 kB)\n",
1492 | "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (21.2.0)\n",
1493 | "Collecting frozenlist>=1.1.1\n",
1494 | " Downloading frozenlist-1.2.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (192 kB)\n",
1495 | "\u001b[K |████████████████████████████████| 192 kB 57.0 MB/s \n",
1496 | "\u001b[?25hCollecting multidict<7.0,>=4.5\n",
1497 | " Downloading multidict-5.2.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (160 kB)\n",
1498 | "\u001b[K |████████████████████████████████| 160 kB 57.1 MB/s \n",
1499 | "\u001b[?25hCollecting async-timeout<5.0,>=4.0.0a3\n",
1500 | " Downloading async_timeout-4.0.0a3-py3-none-any.whl (9.5 kB)\n",
1501 | "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers) (3.6.0)\n",
1502 | "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->datasets) (2.8.2)\n",
1503 | "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas->datasets) (2018.9)\n",
1504 | "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas->datasets) (1.15.0)\n",
1505 | "Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (7.1.2)\n",
1506 | "Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (1.0.1)\n",
1507 | "Installing collected packages: multidict, frozenlist, yarl, asynctest, async-timeout, aiosignal, pyyaml, fsspec, aiohttp, xxhash, tokenizers, sacremoses, huggingface-hub, transformers, datasets\n",
1508 | " Attempting uninstall: pyyaml\n",
1509 | " Found existing installation: PyYAML 3.13\n",
1510 | " Uninstalling PyYAML-3.13:\n",
1511 | " Successfully uninstalled PyYAML-3.13\n",
1512 | "Successfully installed aiohttp-3.8.0 aiosignal-1.2.0 async-timeout-4.0.0a3 asynctest-0.13.0 datasets-1.14.0 frozenlist-1.2.0 fsspec-2021.10.1 huggingface-hub-0.0.19 multidict-5.2.0 pyyaml-6.0 sacremoses-0.0.46 tokenizers-0.10.3 transformers-4.12.2 xxhash-2.0.2 yarl-1.7.0\n"
1513 | ]
1514 | }
1515 | ]
1516 | },
1517 | {
1518 | "cell_type": "markdown",
1519 | "metadata": {
1520 | "id": "1FpvawfuqDXS"
1521 | },
1522 | "source": [
1523 | "## Sequence Classification"
1524 | ]
1525 | },
1526 | {
1527 | "cell_type": "markdown",
1528 | "metadata": {
1529 | "id": "HGQEKVGZqDXS"
1530 | },
1531 | "source": [
1532 | "Sequence classification is the task of classifying sequences according to a given number of classes. An example of\n",
1533 | "sequence classification is the GLUE dataset, which is entirely based on that task. If you would like to fine-tune a\n",
1534 | "model on a GLUE sequence classification task, you may leverage the :prefix_link:*run_glue.py\n",
1535 | "*, :prefix_link:*run_tf_glue.py\n",
1536 | "*, :prefix_link:*run_tf_text_classification.py\n",
1537 | "* or :prefix_link:*run_xnli.py\n",
1538 | "* scripts.\n",
1539 | "\n",
1540 | "Here is an example of using pipelines to do sentiment analysis: identifying if a sequence is positive or negative. It\n",
1541 | "leverages a fine-tuned model on sst2, which is a GLUE task.\n",
1542 | "\n",
1543 | "This returns a label (\"POSITIVE\" or \"NEGATIVE\") alongside a score, as follows:"
1544 | ]
1545 | },
1546 | {
1547 | "cell_type": "code",
1548 | "metadata": {
1549 | "id": "dcBfWoIaqDXT",
1550 | "colab": {
1551 | "base_uri": "https://localhost:8080/",
1552 | "height": 212,
1553 | "referenced_widgets": [
1554 | "7c08442589064421b77b6aa52ce5b947",
1555 | "23f3f545c30d4226987e7909e8365413",
1556 | "3d29469f501c4a0b96c211241459f34a",
1557 | "9578e983215a4bc497b5882a4244b8a9",
1558 | "47b6aefd82c04f0fb3a3de2c7f37c58f",
1559 | "b1e066e72ca04508bf50a60961e66bea",
1560 | "96fe7d4c88e541eea33dde3889e3e381",
1561 | "65b763eef5184606b28d4673e6ad857d",
1562 | "40984bf56a8b448e82451fa60fc15f7c",
1563 | "21bf499e3fb34c02b63934ea6e4f162f",
1564 | "9566729c64304007953c365eb0341df2",
1565 | "6bf72948e9ce47d296cf9673d5c11ced",
1566 | "e32995501a0243c1ab56ee42ecddd80e",
1567 | "e1c3b3a1b0ce49cc9db2bd16be25c929",
1568 | "a1f0479b17f8451daaf69b0af81b9adb",
1569 | "5f7a3c947c564436bd2dfcc04d4da5ee",
1570 | "d0d537e5dc744a5794187c5a498358fc",
1571 | "bdbbb4175a6c4bd389755dc0f8d186cb",
1572 | "902c7115997647b9a5140ba707d74390",
1573 | "cbff745d3ebb4ba0b9bfd8b1a210550f",
1574 | "db0e50cf79174b99b7943fe9fa51b738",
1575 | "91e992390cf94d4bb23164ccbfb2a4d0",
1576 | "a66fb6293209491caf0f05bbab72c1f3",
1577 | "b226e034649d4a27968a24f0215acfa2",
1578 | "143a846ef031430a94ca03ce798e1c38",
1579 | "02cd0740d8394ee9978e2205bd7c0885",
1580 | "fe2d58558248438f89639669b20d044a",
1581 | "63e7051bccb64b19a9eb416d6a2aa63d",
1582 | "5cb83ac370b54d9bb51443cfb4b1d324",
1583 | "33843555ad2e49789f8704d6abda7c00",
1584 | "0ce9c5b524074ee9996cfb7ba2c35c83",
1585 | "04653b5fc47244c7abe96622f28b13f7",
1586 | "40325b42a719450eb4ad85b7c4ff0423",
1587 | "a56872e4b7fd495f8ba3a1e2a63788bc",
1588 | "edd0e5ac02764983bcf87308769e20af",
1589 | "c847204177474e33a61f49c4db5d8ac1",
1590 | "0d1bdbc672ee4f83b836b363863e223f",
1591 | "ac7b1ed676894875971eadbfc22bfd07",
1592 | "9f56b91e83444c8b9272a7a469156ca9",
1593 | "31f205c738f34949a55a20f282e7ce79",
1594 | "a618850b03e949c495390e4c249ae29d",
1595 | "4846e7b941924fd8882c82def4094005",
1596 | "2d2e0dc44d154729adcff6ced9f29042",
1597 | "25b532f5020a4790b3ec5e4a71e1b91b"
1598 | ]
1599 | },
1600 | "outputId": "3ae13dd7-c7e5-45ac-eb01-48ecfa8620f3"
1601 | },
1602 | "source": [
1603 | "from transformers import pipeline\n",
1604 | "classifier = pipeline(\"sentiment-analysis\")\n",
1605 | "result = classifier(\"This film is good and every one loves it\")[0]\n",
1606 | "print(f\"label: {result['label']}, with score: {round(result['score'], 4)}\")\n",
1607 | "result = classifier(\"The film is poor and I do not like it\")[0]\n",
1608 | "print(f\"label: {result['label']}, with score: {round(result['score'], 4)}\")\n",
1609 | "result = classifier(\"This film is good but I do not like it\")[0]\n",
1610 | "print(f\"label: {result['label']}, with score: {round(result['score'], 4)}\")\n"
1611 | ],
1612 | "execution_count": null,
1613 | "outputs": [
1614 | {
1615 | "output_type": "stream",
1616 | "name": "stderr",
1617 | "text": [
1618 | "No model was supplied, defaulted to distilbert-base-uncased-finetuned-sst-2-english (https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english)\n"
1619 | ]
1620 | },
1621 | {
1622 | "output_type": "display_data",
1623 | "data": {
1624 | "application/vnd.jupyter.widget-view+json": {
1625 | "model_id": "7c08442589064421b77b6aa52ce5b947",
1626 | "version_minor": 0,
1627 | "version_major": 2
1628 | },
1629 | "text/plain": [
1630 | "Downloading: 0%| | 0.00/629 [00:00, ?B/s]"
1631 | ]
1632 | },
1633 | "metadata": {}
1634 | },
1635 | {
1636 | "output_type": "display_data",
1637 | "data": {
1638 | "application/vnd.jupyter.widget-view+json": {
1639 | "model_id": "6bf72948e9ce47d296cf9673d5c11ced",
1640 | "version_minor": 0,
1641 | "version_major": 2
1642 | },
1643 | "text/plain": [
1644 | "Downloading: 0%| | 0.00/255M [00:00, ?B/s]"
1645 | ]
1646 | },
1647 | "metadata": {}
1648 | },
1649 | {
1650 | "output_type": "display_data",
1651 | "data": {
1652 | "application/vnd.jupyter.widget-view+json": {
1653 | "model_id": "a66fb6293209491caf0f05bbab72c1f3",
1654 | "version_minor": 0,
1655 | "version_major": 2
1656 | },
1657 | "text/plain": [
1658 | "Downloading: 0%| | 0.00/48.0 [00:00, ?B/s]"
1659 | ]
1660 | },
1661 | "metadata": {}
1662 | },
1663 | {
1664 | "output_type": "display_data",
1665 | "data": {
1666 | "application/vnd.jupyter.widget-view+json": {
1667 | "model_id": "a56872e4b7fd495f8ba3a1e2a63788bc",
1668 | "version_minor": 0,
1669 | "version_major": 2
1670 | },
1671 | "text/plain": [
1672 | "Downloading: 0%| | 0.00/226k [00:00, ?B/s]"
1673 | ]
1674 | },
1675 | "metadata": {}
1676 | },
1677 | {
1678 | "output_type": "stream",
1679 | "name": "stdout",
1680 | "text": [
1681 | "label: POSITIVE, with score: 0.9999\n",
1682 | "label: NEGATIVE, with score: 0.9996\n",
1683 | "label: POSITIVE, with score: 0.7956\n"
1684 | ]
1685 | }
1686 | ]
1687 | },
1688 | {
1689 | "cell_type": "markdown",
1690 | "metadata": {
1691 | "id": "q-nYFRzzsLD2"
1692 | },
1693 | "source": [
1694 | "The first two examples are correctly classified. However, the last one, created by combining fragments from the first two, is wrongly classified. The model fails in the last example because it can’t recognize the two fragments in the first two and instead assigns a probability to the entire sentence, indicating the model’s poor compositional generalization capability. We humans can distinguish the two parts in the last example. Although they are contradictory, from the perspective of sentiment classification, we will assign a certain score to both positive and negative, and the negative score should be higher, so the ideal score may be positive 40% and negative 60%. Unfortunately, such soft labels are difficult to obtain in the actual labeling process, although such complex examples are countless in real life."
1695 | ]
1696 | }
1697 | ]
1698 | }
--------------------------------------------------------------------------------