├── README.md
├── _utils
│   ├── utils.py
│   └── would_like_to_pr.py
├── fine_tunning
│   ├── SentiWSP_fine_tunning_ASBA.py
│   ├── SentiWSP_fine_tunning_SA.py
│   └── absa_data_utils.py
├── pretrain
│   ├── SentiWSP_Pretrain_ANCE_GEN.py
│   ├── SentiWSP_Pretrain_ANCE_TRAIN.py
│   ├── SentiWSP_Pretrain_Warmup_inbatch.py
│   └── SentiWSP_Pretrain_Word.py
└── sentiment_vocab
    ├── senti_vector.npy
    └── senti_vocab.npy

/README.md:
--------------------------------------------------------------------------------

# SentiWSP
## For paper: Sentiment-Aware Word and Sentence Level Pre-training for Sentiment Analysis
Shuai Fan, Chen Lin, Haonan Li, Zhenghao Lin, Jinsong Su, Hang Zhang, Yeyun Gong, Jian Guo, Nan Duan

Xiamen University, The University of Melbourne, IDEA Research, Microsoft Research Asia

Paper link: [https://arxiv.org/abs/2210.09803](https://arxiv.org/abs/2210.09803)


## Dependencies
- python>=3.6
- torch>=1.7.1
- datasets>=1.12.1
- transformers>=4.9.2 (Huggingface)
- fastcore>=1.3.29
- fastai<=2.2.0
- hugdatafast>=1.0.0
- huggingface-hub>=0.0.19

## Quick Start for Fine-tuning
Our experiments cover sentence-level sentiment classification (e.g. SST-5 / MR / IMDB / Yelp-2 / Yelp-5) and aspect-level sentiment analysis (e.g. Lap14 / Res14).
### Load our model (large)
You can download the pre-trained model from [Google Drive](https://drive.google.com/drive/folders/1Azx30v2TdenuziOZB_ob3UfniO0yoLqa?usp=sharing) and load it with:
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

save_path = "./SentiWSP"  # local directory containing the downloaded checkpoint
tokenizer = AutoTokenizer.from_pretrained(save_path)
model = AutoModelForSequenceClassification.from_pretrained(save_path)
```
You can also load our model from the Hugging Face Hub ([https://huggingface.co/shuaifan/SentiWSP](https://huggingface.co/shuaifan/SentiWSP)):
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

tokenizer = AutoTokenizer.from_pretrained("shuaifan/SentiWSP")
model = AutoModelForSequenceClassification.from_pretrained("shuaifan/SentiWSP")
```
### Load our model (base)
You can also load our base model from the Hugging Face Hub ([https://huggingface.co/shuaifan/SentiWSP-base](https://huggingface.co/shuaifan/SentiWSP-base)):
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

tokenizer = AutoTokenizer.from_pretrained("shuaifan/SentiWSP-base")
model = AutoModelForSequenceClassification.from_pretrained("shuaifan/SentiWSP-base")
```

### Download downstream datasets
You can download the downstream datasets from [huggingface/datasets](https://github.com/huggingface/datasets); the download code is in SentiWSP_fine_tunning_SA.py. We also provide some downstream datasets in [Google Drive](https://drive.google.com/drive/folders/1Azx30v2TdenuziOZB_ob3UfniO0yoLqa?usp=sharing).
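For reference, this is essentially what that script does for IMDB (the cache and save paths below are the defaults it uses; adjust them as needed):
```python
import datasets

# First run: download IMDB from the Hub and cache it locally
dataset = datasets.load_dataset('imdb', cache_dir='./datasets')
dataset.save_to_disk('./datasets/imdb')

# Later runs: load the saved copy directly from disk
dataset = datasets.load_from_disk('./datasets/imdb')
```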

### Fine-tuning
We show an example of fine-tuning SentiWSP on sentence-level sentiment classification (IMDB) below:
```bash
python SentiWSP_fine_tunning_SA.py \
    --dataset=imdb \
    --gpu_num=1 \
    --loadmodel=True \
    --loadmodelpath=SentiWSP \
    --batch_size=8 \
    --max_epoch=5 \
    --model_size=large \
    --num_class=2
```
An example of fine-tuning SentiWSP on aspect-level sentiment analysis (Lap14) is as follows:
```bash
python SentiWSP_fine_tunning_ASBA.py \
    --dataset=laptop \
    --model_name=SentiWSP \
    --batch_size=32 \
    --max_epoch=10 \
    --max_len=128
```
For SentiWSP and SentiWSP-base, we fine-tune 3-5 epochs for sentence-level sentiment classification tasks and 7-10 epochs for aspect-level sentiment classification tasks. We use a learning rate of 2e-5 for SA tasks and 1e-5 for ABSA tasks. We use a different batch_size for each model size:

| model size | batch_size | max_sentence_length |
| ---------- | ---------- | ------------------- |
| base       | 32         | 512                 |
| large      | 8          | 512                 |

## Pre-training
If you want to run pre-training yourself instead of directly using the checkpoint we provide, this part may help you pre-process the pre-training dataset and run the pre-training scripts. You should train the model on several NVIDIA Tesla A100 GPUs.

### Word-level pre-training

```bash
python -m torch.distributed.launch \
    --nproc_per_node=4 \
    --master_port=9999 \
    SentiWSP_Pretrain_Word.py \
    --dataset=wiki \
    --size=large \
    --gpu_num=4 \
    --save_pretrain_model=./word5_large_model/ \
    --max_len=128 \
    --batch_size=64 \
    --sentimask_prob=0.5
```
### Sentence-level pre-training
1. Warm-up
```bash
python -m torch.distributed.launch \
    --nproc_per_node=4 \
    --master_port=9999 \
    SentiWSP_Pretrain_Warmup_inbatch.py \
    --load_model=word5_large_model \
    --gpu_num=4 \
    --batch_size=32 \
    --max_len=128 \
    --save_model=./word_sen_model/
```
2. Cross-batch
- ANN Index Build:
```bash
python SentiWSP_Pretrain_ANCE_GEN.py \
    --gpu_num=1 \
    --sentimask_prob=0.7 \
    --max_length=128 \
    --model_path=word_sen_model
```
- Train:
```bash
python -m torch.distributed.launch \
    --nproc_per_node=4 \
    --master_port=9999 \
    SentiWSP_Pretrain_ANCE_TRAIN.py \
    --load_model=word_sen_model \
    --gpu_num=4 \
    --batch_size=32 \
    --max_len=128 \
    --save_model=./word_sen_model_iter_1/
```
You should run the "ANN Index Build" and "Train" steps alternately, changing the save_model name on each iteration, or write a shell script that loops over the two steps (see the appendix at the end of this README).


## Thanks
Many thanks to the Hugging Face Transformers GitHub repository; our code is based on their framework.
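### Appendix: example cross-batch loop
A minimal sketch of the loop described in the cross-batch step above; the iteration count and model-directory names are placeholders, so adjust them to your setup:
```bash
#!/bin/bash
# Alternate ANN index generation and ANCE training, feeding each round's
# checkpoint into the next round's index build.
prev_model=word_sen_model
for i in 1 2 3; do
    python SentiWSP_Pretrain_ANCE_GEN.py \
        --gpu_num=1 \
        --sentimask_prob=0.7 \
        --max_length=128 \
        --model_path=${prev_model}
    python -m torch.distributed.launch \
        --nproc_per_node=4 \
        --master_port=9999 \
        SentiWSP_Pretrain_ANCE_TRAIN.py \
        --load_model=${prev_model} \
        --gpu_num=4 \
        --batch_size=32 \
        --max_len=128 \
        --save_model=./word_sen_model_iter_${i}/
    prev_model=word_sen_model_iter_${i}
done
```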
--------------------------------------------------------------------------------
/_utils/utils.py:
--------------------------------------------------------------------------------
1 | import random, re, os
2 | from functools import partial
3 | from fastai.text.all import *
4 | from hugdatafast.transform import CombineTransform
5 | 
6 | class MyConfig(dict):
7 |     def __getattr__(self, name): return self[name]
8 |     def __setattr__(self, name, value): self[name] = value
9 | 
10 | def adam_no_correction_step(p, lr, mom, step, sqr_mom, grad_avg, sqr_avg, eps, **kwargs):
11 |     p.data.addcdiv_(grad_avg, (sqr_avg).sqrt() + eps, value = -lr)
12 |     return p
13 | 
14 | def Adam_no_bias_correction(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0.01, decouple_wd=True):
15 |     "An `Optimizer` for Adam with `lr`, `mom`, `sqr_mom`, `eps` and `params`"
16 |     cbs = [weight_decay] if decouple_wd else [l2_reg]
17 |     cbs += [partial(average_grad, dampening=True), average_sqr_grad, step_stat, adam_no_correction_step]
18 |     return Optimizer(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)
19 | 
20 | def linear_warmup_and_decay(pct, lr_max, total_steps, warmup_steps=None, warmup_pct=None, end_lr=0.0, decay_power=1):
21 |     """ pct (float): fastai computes it as ith_step/(num_epoch*len(dl)), so we can't rely on pct directly when our num_epoch is fake. ith_step is counted from 0. """
22 |     if warmup_pct: warmup_steps = int(warmup_pct * total_steps)
23 |     step_i = round(pct * total_steps)
24 |     # According to the original source code, the two schedules take effect at the same time, but the decaying schedule is negligible in the early steps.
25 |     decayed_lr = (lr_max-end_lr) * (1 - step_i/total_steps) ** decay_power + end_lr # https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/polynomial_decay
26 |     warmed_lr = decayed_lr * min(1.0, step_i/warmup_steps) # https://github.com/google-research/electra/blob/81f7e5fc98b0ad8bfd20b641aa8bc9e6ac00c8eb/model/optimization.py#L44
27 |     return warmed_lr
28 | 
29 | def linear_warmup_and_then_decay(pct, lr_max, total_steps, warmup_steps=None, warmup_pct=None, end_lr=0.0, decay_power=1):
30 |     """ pct (float): fastai computes it as ith_step/(num_epoch*len(dl)), so we can't rely on pct directly when our num_epoch is fake. ith_step is counted from 0. """
31 |     if warmup_pct: warmup_steps = int(warmup_pct * total_steps)
32 |     step_i = round(pct * total_steps)
33 |     if step_i <= warmup_steps: # warm up
34 |         return lr_max * min(1.0, step_i/warmup_steps)
35 |     else: # decay
36 |         return (lr_max-end_lr) * (1 - (step_i-warmup_steps)/(total_steps-warmup_steps)) ** decay_power + end_lr
37 | 
38 | def load_part_model(file, model, prefix, device=None, strict=True):
39 |     "Assume `model` is a part (a child attribute at any level) of the model whose state is saved in `file`."
40 |     distrib_barrier()
41 |     if prefix[-1] != '.': prefix += '.'
42 |     if isinstance(device, int): device = torch.device('cuda', device)
43 |     elif device is None: device = 'cpu'
44 |     state = torch.load(file, map_location=device)
45 |     hasopt = set(state)=={'model', 'opt'}
46 |     model_state = state['model'] if hasopt else state
47 |     model_state = {k[len(prefix):] : v for k,v in model_state.items() if k.startswith(prefix)}
48 |     get_model(model).load_state_dict(model_state, strict=strict)
49 | 
50 | def load_model_(learn, files, device=None, **kwargs):
51 |     "If multiple files are passed, load them and create an ensemble.
Load normally otherwise" 52 | merge_out_fc = kwargs.pop('merge_out_fc', None) 53 | if not isinstance(files, list): 54 | learn.load(files, device=device, **kwargs) 55 | return 56 | if device is None: device = learn.dls.device 57 | model = learn.model.cpu() 58 | models = [model, *(deepcopy(model) for _ in range(len(files)-1)) ] 59 | for f,m in zip(files, models): 60 | file = join_path_file(f, learn.path/learn.model_dir, ext='.pth') 61 | load_model(file, m, learn.opt, device='cpu', **kwargs) 62 | learn.model = Ensemble(models, device, merge_out_fc) 63 | return learn 64 | 65 | class ConcatTransform(CombineTransform): 66 | def __init__(self, hf_dset, hf_tokenizer, max_length, text_col='text', book='multi'): 67 | super().__init__(hf_dset, in_cols=[text_col], out_cols=['input_ids', 'sentA_length']) 68 | self.max_length = max_length 69 | self.hf_tokenizer = hf_tokenizer 70 | self.book = book 71 | 72 | def reset_states(self): 73 | self.input_ids = [self.hf_tokenizer.cls_token_id] 74 | self.sent_lens = [] 75 | 76 | def accumulate(self, sentence): 77 | if 'isbn' in sentence: return 78 | tokens = self.hf_tokenizer.convert_tokens_to_ids(self.hf_tokenizer.tokenize(sentence)) 79 | tokens = tokens[:self.max_length-2] # trim sentence to max length if needed 80 | if self.book == 'single' or \ 81 | (len(self.input_ids) + len(tokens) + 1 > self.max_length) or \ 82 | (self.book == 'bi' and len(self.sent_lens)==2) : 83 | self.commit_example(self.create_example()) 84 | self.reset_states() 85 | self.input_ids += [*tokens, self.hf_tokenizer.sep_token_id] 86 | self.sent_lens.append(len(tokens)+1) 87 | 88 | def create_example(self): 89 | if not self.sent_lens: return None 90 | self.sent_lens[0] += 1 # cls 91 | if self.book == 'multi': 92 | diff= 99999999 93 | for i in range(len(self.sent_lens)): 94 | current_diff = abs(sum(self.sent_lens[:i+1]) - sum(self.sent_lens[i+1:])) 95 | if current_diff > diff: break 96 | diff = current_diff 97 | return {'input_ids': self.input_ids, 'sentA_length': sum(self.sent_lens[:i])} 98 | else: 99 | return {'input_ids': self.input_ids, 'sentA_length': self.sent_lens[0]} 100 | 101 | class ELECTRADataProcessor(object): 102 | """Given a stream of input text, creates pretraining examples.""" 103 | 104 | def __init__(self, hf_dset, hf_tokenizer, max_length, text_col='text', lines_delimiter='\n', minimize_data_size=True, apply_cleaning=True): 105 | self.hf_tokenizer = hf_tokenizer 106 | self._current_sentences = [] 107 | self._current_length = 0 108 | self._max_length = max_length 109 | self._target_length = max_length 110 | 111 | self.hf_dset = hf_dset 112 | self.text_col = text_col 113 | self.lines_delimiter = lines_delimiter 114 | self.minimize_data_size = minimize_data_size 115 | self.apply_cleaning = apply_cleaning 116 | 117 | def map(self, **kwargs): 118 | "Some settings of datasets.Dataset.map for ELECTRA data processing" 119 | num_proc = kwargs.pop('num_proc', os.cpu_count()) 120 | return self.hf_dset.my_map( 121 | function=self, 122 | batched=True, 123 | remove_columns=self.hf_dset.column_names, # this is must b/c we will return different number of rows 124 | disable_nullable=True, 125 | input_columns=[self.text_col], 126 | writer_batch_size=10**4, 127 | num_proc=num_proc, 128 | **kwargs 129 | ) 130 | 131 | def __call__(self, texts): 132 | if self.minimize_data_size: new_example = {'input_ids':[], 'sentA_length':[]} 133 | else: new_example = {'input_ids':[], 'input_mask': [], 'segment_ids': []} 134 | 135 | for text in texts: # for every doc 136 | 137 | for line in 
re.split(self.lines_delimiter, text): # for every paragraph 138 | 139 | if re.fullmatch(r'\s*', line): continue # empty string or string with all space characters 140 | if self.apply_cleaning and self.filter_out(line): continue 141 | 142 | example = self.add_line(line) 143 | if example: 144 | for k,v in example.items(): new_example[k].append(v) 145 | 146 | if self._current_length != 0: 147 | example = self._create_example() 148 | for k,v in example.items(): new_example[k].append(v) 149 | 150 | return new_example 151 | 152 | def filter_out(self, line): 153 | if len(line) < 80: return True 154 | return False 155 | 156 | def clean(self, line): 157 | # () is remainder after link in it filtered out 158 | return line.strip().replace("\n", " ").replace("()","") 159 | 160 | def add_line(self, line): 161 | """Adds a line of text to the current example being built.""" 162 | line = self.clean(line) 163 | tokens = self.hf_tokenizer.tokenize(line) 164 | tokids = self.hf_tokenizer.convert_tokens_to_ids(tokens) 165 | self._current_sentences.append(tokids) 166 | self._current_length += len(tokids) 167 | if self._current_length >= self._target_length: 168 | return self._create_example() 169 | return None 170 | 171 | def _create_example(self): 172 | """Creates a pre-training example from the current list of sentences.""" 173 | # small chance to only have one segment as in classification tasks 174 | if random.random() < 0.1: 175 | first_segment_target_length = 100000 176 | else: 177 | # -3 due to not yet having [CLS]/[SEP] tokens in the input text 178 | first_segment_target_length = (self._target_length - 3) // 2 179 | 180 | first_segment = [] 181 | second_segment = [] 182 | for sentence in self._current_sentences: 183 | # the sentence goes to the first segment if (1) the first segment is 184 | # empty, (2) the sentence doesn't put the first segment over length or 185 | # (3) 50% of the time when it does put the first segment over length 186 | if (len(first_segment) == 0 or 187 | len(first_segment) + len(sentence) < first_segment_target_length or 188 | (len(second_segment) == 0 and 189 | len(first_segment) < first_segment_target_length and 190 | random.random() < 0.5)): 191 | first_segment += sentence 192 | else: 193 | second_segment += sentence 194 | 195 | # trim to max_length while accounting for not-yet-added [CLS]/[SEP] tokens 196 | first_segment = first_segment[:self._max_length - 2] 197 | second_segment = second_segment[:max(0, self._max_length - 198 | len(first_segment) - 3)] 199 | 200 | # prepare to start building the next example 201 | self._current_sentences = [] 202 | self._current_length = 0 203 | # small chance for random-length instead of max_length-length example 204 | if random.random() < 0.05: 205 | self._target_length = random.randint(5, self._max_length) 206 | else: 207 | self._target_length = self._max_length 208 | 209 | return self._make_example(first_segment, second_segment) 210 | 211 | def _make_example(self, first_segment, second_segment): 212 | """Converts two "segments" of text into a tf.train.Example.""" 213 | input_ids = [self.hf_tokenizer.cls_token_id] + first_segment + [self.hf_tokenizer.sep_token_id] 214 | sentA_length = len(input_ids) 215 | segment_ids = [0] * sentA_length 216 | if second_segment: 217 | input_ids += second_segment + [self.hf_tokenizer.sep_token_id] 218 | segment_ids += [1] * (len(second_segment) + 1) 219 | 220 | if self.minimize_data_size: 221 | return { 222 | 'input_ids': input_ids, 223 | 'sentA_length': sentA_length, 224 | } 225 | else: 226 | input_mask = 
[1] * len(input_ids)
227 |             input_ids += [0] * (self._max_length - len(input_ids))
228 |             input_mask += [0] * (self._max_length - len(input_mask))
229 |             segment_ids += [0] * (self._max_length - len(segment_ids))
230 |             return {
231 |                 'input_ids': input_ids,
232 |                 'input_mask': input_mask,
233 |                 'segment_ids': segment_ids,
234 |             }
--------------------------------------------------------------------------------
/_utils/would_like_to_pr.py:
--------------------------------------------------------------------------------
1 | import time
2 | from statistics import mean, stdev
3 | import torch
4 | from torch import nn
5 | from fastai.text.all import *
6 | 
7 | """
8 | I would like a more uniform way to pass the metrics: no matter loss_func or metric,
9 | instantiate it and then pass it.
10 | This uniform way also makes things like `metrics=[m() for m in TASK_METRICS[task]]` possible.
11 | """
12 | def Accuracy(axis=-1):
13 |     return AvgMetric(partial(accuracy, axis=axis))
14 | 
15 | @delegates()
16 | class MyMSELossFlat(BaseLoss):
17 |     def __init__(self,*args, axis=-1, floatify=True, low=None, high=None, **kwargs):
18 |         super().__init__(nn.MSELoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)
19 |         self.low, self.high = low, high
20 |     def decodes(self, x):
21 |         if self.low is not None: x = torch.max(x, x.new_full(x.shape, self.low))
22 |         if self.high is not None: x = torch.min(x, x.new_full(x.shape, self.high))
23 |         return x
24 | 
25 | class GradientClipping(Callback):
26 |     def __init__(self, clip:float = 0.1):
27 |         self.clip = clip
28 |         assert self.clip
29 |     def after_backward(self):
30 |         if hasattr(self, 'scaler'): self.scaler.unscale_(self.opt)
31 |         torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
32 | 
33 | class RunSteps(Callback):
34 |     toward_end = True
35 | 
36 |     def __init__(self, n_steps, save_points=None, base_name=None, no_val=True):
37 |         """
38 |         Args:
39 |             `n_steps` (`Int`): Run how many steps; could be larger or smaller than `len(dls.train)`
40 |             `save_points`
41 |             - (`List[Float]`): save when reaching one of the percentages specified.
42 |             - (`List[Int]`): save when reaching one of the steps specified.
43 |             `base_name` (`String`): a format string with `{percent}` to be passed to `learn.save`.
44 |         """
45 |         if save_points is None: save_points = []
46 |         else:
47 |             assert '{percent}' in base_name
48 |             save_points = [ s if isinstance(s,int) else int(n_steps*s) for s in save_points ]
49 |             for sp in save_points: assert sp != 1, "Are you sure you want to save after 1 step, instead of at 1.0 * num_steps?"
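        # Illustration: save_points=[0.5, 1.0] with n_steps=10000 saves checkpoints at steps 5000 and 10000.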
50 | assert max(save_points) <= n_steps 51 | store_attr('n_steps,save_points,base_name,no_val', self) 52 | 53 | def before_train(self): 54 | # fix pct_train (cuz we'll set `n_epoch` larger than we need) 55 | self.learn.pct_train = self.train_iter/self.n_steps 56 | 57 | def after_batch(self): 58 | # fix pct_train (cuz we'll set `n_epoch` larger than we need) 59 | self.learn.pct_train = self.train_iter/self.n_steps 60 | # when to save 61 | if self.train_iter in self.save_points: 62 | percent = (self.train_iter/self.n_steps)*100 63 | self.learn.save(self.base_name.format(percent=f'{percent}%')) 64 | # when to interrupt 65 | if self.train_iter == self.n_steps: 66 | raise CancelFitException 67 | 68 | def after_train(self): 69 | if self.no_val: 70 | if self.train_iter == self.n_steps: 71 | pass # CancelFit is raised, don't overlap it with CancelEpoch 72 | else: 73 | raise CancelEpochException 74 | 75 | _MESSAGE = [ 76 | 'dl.train load a batch + before_batch', 77 | 'forward + after_pred', 78 | 'loss calculation + after_loss', 79 | 'backward + after_backward', 80 | 'parameter updating + after_step', 81 | 'after_batch', 82 | ] 83 | 84 | @delegates() 85 | class Timer(RunSteps): 86 | toward_end=True 87 | 88 | def __init__(self, n_steps, ignore_first_n=1, break_after=None, precision=3, **kwargs): 89 | """ 90 | Args: 91 | `n_steps`: Average on how many training steps. 92 | `ignore_first_n`: Not use first n steps to average. Setting it at least 1 to avoid counting initilization time of dataloader is suggested. 93 | `break_after`: one of ['before_batch',...'after_batch'] 94 | `precision` 95 | """ 96 | steps = ignore_first_n + n_steps 97 | super().__init__(steps, **kwargs) 98 | store_attr('steps,break_after,ignore_first_n,precision', self) 99 | 100 | def time_delta(self): 101 | delta = time.time() - self.timepoint 102 | self.timepoint = time.time() 103 | return delta 104 | 105 | def before_fit(self): 106 | self.times = [ [] for _ in range(6)] 107 | self.timepoint = time.time() 108 | def before_batch(self): 109 | self.times[0].append(self.time_delta()) 110 | if self.break_after=='before_batch': raise CancelBatchException 111 | def after_pred(self): 112 | self.times[1].append(self.time_delta()) 113 | if self.break_after=='after_pred': raise CancelBatchException 114 | def after_loss(self): 115 | self.times[2].append(self.time_delta()) 116 | if self.break_after=='after_loss': raise CancelBatchException 117 | def after_backward(self): 118 | self.times[3].append(self.time_delta()) 119 | if self.break_after=='after_backward': raise CancelBatchException 120 | def after_step(self): 121 | self.times[4].append(self.time_delta()) 122 | if self.break_after=='after_step': raise CancelBatchException 123 | def after_batch(self): 124 | if self.break_after=='after_batch' or not self.break_after: 125 | self.times[5].append(self.time_delta()) 126 | if self.train_iter == self.steps: 127 | self.show() 128 | super().after_batch() 129 | 130 | def show(self): 131 | print(f"show average and standard deviation of step {self.ignore_first_n+1} ~ step {self.train_iter} (total {self.steps-self.ignore_first_n} training steps)") 132 | # print for each stage 133 | for i, deltas in enumerate(self.times): 134 | if len(deltas)==0: time_message = "Skipped or Exception raised by callbacks ran before Timer." 
135 |             else:
136 |                 m,s = mean(deltas[self.ignore_first_n:]), stdev(deltas[self.ignore_first_n:])
137 |                 time_message = f"avg {round(m, self.precision)} secs ± stdev {round(s, self.precision)}"
138 |             print(f"{(_MESSAGE[i]+':'):36} {time_message}")
139 |         # print for total
140 |         times = list(filter(None, self.times))
141 |         ## Some callbacks (e.g. MixedPrecisionCallback) might skip some stages "sometimes", so the lengths might not be equal
142 |         max_len = max( len(deltas) for deltas in times )
143 |         for i, deltas in enumerate(times):
144 |             if len(deltas) < max_len: times[i] += [0]*(max_len - len(deltas))
145 |         ## calculate
146 |         times = torch.tensor(times)[:,self.ignore_first_n:]
147 |         total_m, total_s = times.sum(0).mean().item(), times.sum(0).std().item()
148 |         print(f"Total: avg {round(total_m, self.precision)} secs ± stdev {round(total_s, self.precision)} secs")
149 | 
150 | class Ensemble(nn.Module):
151 |     def __init__(self, models, device='cuda:0', merge_out_fc=None):
152 |         super().__init__()
153 |         self.models = nn.ModuleList( m.cpu() for m in models )
154 |         self.device = device
155 |         self.merge_out_fc = merge_out_fc
156 | 
157 |     def to(self, device):
158 |         self.device = device
159 |         return self
160 |     def getitem(self, i): return self.models[i]
161 | 
162 |     def forward(self, *args, **kwargs):
163 |         outs = []
164 |         for m in self.models:
165 |             m.to(self.device)
166 |             out = m(*args, **kwargs)
167 |             m.cpu()
168 |             outs.append(out)
169 |         if self.merge_out_fc:
170 |             outs = self.merge_out_fc(outs)
171 |         else:
172 |             outs = torch.stack(outs)
173 |             outs = outs.mean(dim=0)
174 |         return outs
--------------------------------------------------------------------------------
/fine_tunning/SentiWSP_fine_tunning_ASBA.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | import json, os
4 | 
5 | from tqdm import tqdm
6 | import numpy as np
7 | import torch
8 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
9 | from sklearn.metrics import precision_recall_curve,classification_report
10 | from transformers import ElectraTokenizer, ElectraForSequenceClassification, ElectraConfig, get_linear_schedule_with_warmup,ElectraTokenizerFast
11 | from transformers import AutoTokenizer, AutoModelForSequenceClassification
12 | import absa_data_utils as data_utils
13 | from transformers import WEIGHTS_NAME, CONFIG_NAME
14 | 
15 | # random seed
16 | random_seed = 2022
17 | random.seed(random_seed)
18 | torch.manual_seed(random_seed)
19 | 
20 | # parser config
21 | parser = argparse.ArgumentParser(description='Pre training model configuration')
22 | parser.add_argument('--data_path', default="./jsondata/",help='finetune data path')
23 | parser.add_argument('--dataset', default="rest",help='finetune data name')
24 | parser.add_argument('--inference', default="valid",help='Choose whether to use a test set or a validation set')
25 | parser.add_argument('--model_path', default='./pretrain_model/',help='finetune model path')
26 | parser.add_argument('--model_name', default='ance7word5',help='finetune model name')
27 | parser.add_argument('--tokenizer_path', default="./pretrain_model/electra/large",help='finetune tokenizer path')
28 | parser.add_argument('--single_gpunum', default='0',help='GPU number used in single GPU environment')
29 | parser.add_argument('--save_model', type=bool ,default = False ,help='whether you want to save the model')
30 | parser.add_argument('--save_model_path', default='./model_result/',help='the save dir
of fine-tunning model path') 31 | parser.add_argument('--result_path', default='./result/',help='the save dir of fine-tunning experiment result') 32 | parser.add_argument('--save_model_name', default='ABSA_checkpoint',help='model name') 33 | parser.add_argument('--batch_size',type=int ,default=32,help='the batch size of training process') 34 | parser.add_argument('--max_epoch', type=int,default=10,help='the max epoch of training process') 35 | parser.add_argument('--lr', type=float,default = 1e-5,help='the learning rate of fine-tunning process') 36 | parser.add_argument('--eps', default = 1e-8,help='the eps of fine-tunning process') 37 | parser.add_argument('--max_len', type=int, default = 128,help='the max_length of input sequence') 38 | 39 | args = parser.parse_args() 40 | 41 | def validation(model,test_loader): 42 | 43 | model.eval() 44 | y_true=[] 45 | y_pred=[] 46 | total_eval_loss=0 47 | 48 | for batch in test_loader: 49 | 50 | with torch.no_grad(): 51 | batch = tuple(t.to(device) for t in batch) 52 | input_ids, segment_ids, input_mask, label_ids = batch 53 | model_input = {'input_ids':input_ids, 54 | 'token_type_ids':segment_ids, 55 | 'attention_mask':input_mask, 56 | 'labels':label_ids} 57 | outputs = model(**model_input) 58 | 59 | loss = outputs.loss 60 | logits = outputs.logits 61 | 62 | total_eval_loss += loss.item() 63 | prediction = torch.argmax(logits, 1) 64 | prediction = prediction.detach().cpu().tolist() 65 | label_ids = label_ids.to('cpu').tolist() 66 | y_pred+=prediction 67 | y_true+=label_ids 68 | target_names = ['positive', 'negative', 'neutral'] 69 | print("-------------------------------") 70 | print(classification_report(y_true, y_pred, target_names=target_names,digits=4)) 71 | file = open(args.result_path + args.model_name+'_'+args.dataset + "_result.txt", "a+") 72 | file.write(classification_report(y_true, y_pred, target_names=target_names,digits=4) + '\n' + '\n') 73 | file.close() 74 | print("Average valid loss: %.4f"%(total_eval_loss/len(test_loader))) 75 | print("-------------------------------") 76 | 77 | 78 | def train(model, train_loader, max_epoch, test_loader): 79 | optimizer = torch.optim.AdamW(params=model.parameters(), lr=args.lr,eps = args.eps) 80 | total_steps = len(train_loader) * max_epoch 81 | warmup_steps=0.1*len(train_loader) 82 | scheduler = get_linear_schedule_with_warmup(optimizer, 83 | num_warmup_steps = warmup_steps, # Default value in run_glue.py 84 | num_training_steps = total_steps) 85 | model.train() 86 | 87 | for epoch in range(max_epoch): 88 | 89 | total_train_loss = 0 90 | 91 | for iter_num, batch in enumerate(tqdm(train_loader)): 92 | 93 | batch = tuple(t.to(device) for t in batch) 94 | input_ids, segment_ids, input_mask, label_ids = batch 95 | model_input = {'input_ids':input_ids, 96 | 'token_type_ids':segment_ids, 97 | 'attention_mask':input_mask, 98 | 'labels':label_ids} 99 | 100 | outputs = model(**model_input) 101 | loss = outputs.loss 102 | total_train_loss += loss.item() 103 | loss.backward() 104 | 105 | # clip 106 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm = 1.0) 107 | optimizer.step() 108 | scheduler.step() 109 | optimizer.zero_grad() 110 | 111 | if iter_num % 20 == 0: 112 | # print(label_ids) 113 | # print(torch.argmax(outputs.logits, 1)) 114 | print("epoth: %d, iter_num: %d, loss: %.4f" % (epoch, iter_num, loss.item())) 115 | 116 | validation(model, test_loader) 117 | if args.save_model: 118 | save_model(model, args.save_model_path + str(epoch+1) + "epoch_" + args.save_model_name) 119 | print("Epoch: %d, 
Average training loss: %.4f" %(epoch, total_train_loss/len(train_loader)))
120 | 
121 | def save_model(model,output_dir):
122 |     if os.path.exists(output_dir) == False:
123 |         os.makedirs(output_dir)
124 |     model_to_save = model.module if hasattr(model, 'module') else model
125 |     output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
126 |     output_config_file = os.path.join(output_dir, CONFIG_NAME)
127 |     torch.save(model_to_save.state_dict(), output_model_file)
128 |     model_to_save.config.to_json_file(output_config_file)
129 |     print("save model in: "+output_dir)
130 | 
131 | 
132 | if __name__ == "__main__":
133 | 
134 |     # Define the operating environment
135 |     device = torch.device("cuda:" + str(args.single_gpunum) if torch.cuda.is_available() else "cpu")
136 |     data_path=args.data_path+args.dataset
137 |     # load data
138 |     processor = data_utils.AscProcessor()
139 |     label_list = processor.get_labels()
140 |     train_examples = processor.get_train_examples(data_path)
141 |     print("load dataset:",args.dataset)
142 |     # ABSA standard data input format conversion
143 |     tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
144 | 
145 |     train_features = data_utils.convert_examples_to_features(
146 |         train_examples, label_list, args.max_len, tokenizer)
147 | 
148 |     # turn to tensor
149 |     all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
150 |     all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
151 |     all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
152 |     all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
153 |     train_data = TensorDataset(all_input_ids, all_segment_ids, all_input_mask, all_label_ids)
154 | 
155 |     # Random Sampler DataLoader
156 |     train_sampler = RandomSampler(train_data)
157 |     train_loader = DataLoader(train_data, sampler=train_sampler, batch_size=args.batch_size)
158 | 
159 |     # load model
160 |     model_path = args.model_path+args.model_name
161 |     model = AutoModelForSequenceClassification.from_pretrained(model_path, num_labels=len(label_list))
162 |     model.to(device)
163 |     print("load model:", args.model_name)
164 | 
165 |     # create directory
166 |     if os.path.exists(args.result_path) == False:
167 |         os.makedirs(args.result_path)
168 |     if args.save_model and os.path.exists(args.save_model_path) == False:
169 |         os.makedirs(args.save_model_path)
170 | 
171 |     # Load valid dataset
172 |     if args.inference == "valid":
173 |         valid_examples = processor.get_test_examples(data_path)
174 |         valid_features = data_utils.convert_examples_to_features(
175 |             valid_examples, label_list, args.max_len, tokenizer)
176 |         valid_all_input_ids = torch.tensor([f.input_ids for f in valid_features], dtype=torch.long)
177 |         valid_all_segment_ids = torch.tensor([f.segment_ids for f in valid_features], dtype=torch.long)
178 |         valid_all_input_mask = torch.tensor([f.input_mask for f in valid_features], dtype=torch.long)
179 |         valid_all_label_ids = torch.tensor([f.label_id for f in valid_features], dtype=torch.long)
180 |         valid_data = TensorDataset(valid_all_input_ids, valid_all_segment_ids, valid_all_input_mask, valid_all_label_ids)
181 |         valid_sampler = SequentialSampler(valid_data)
182 |         valid_loader = DataLoader(valid_data, sampler=valid_sampler, batch_size=args.batch_size)
183 |         # start to train
184 |         train(model, train_loader, args.max_epoch, valid_loader)
185 |     # use test dataset
186 |     elif args.inference == "test":
187 |         test_examples = processor.get_test_examples(data_path)
188 | test_features = data_utils.convert_examples_to_features( 189 | test_examples, label_list, args.max_len, tokenizer) 190 | test_all_input_ids = torch.tensor([f.input_ids for f in test_features], dtype=torch.long) 191 | test_all_segment_ids = torch.tensor([f.segment_ids for f in test_features], dtype=torch.long) 192 | test_all_input_mask = torch.tensor([f.input_mask for f in test_features], dtype=torch.long) 193 | test_all_label_ids = torch.tensor([f.label_id for f in test_features], dtype=torch.long) 194 | test_data = TensorDataset(test_all_input_ids, test_all_segment_ids, test_all_input_mask, test_all_label_ids) 195 | test_sampler = SequentialSampler(test_data) 196 | test_loader = DataLoader(test_data, sampler=test_sampler, batch_size=args.batch_size) 197 | # start to train 198 | train(model, train_loader, args.max_epoch, test_loader) 199 | else: 200 | print("inference must be test or valid") 201 | exit(0) 202 | 203 | -------------------------------------------------------------------------------- /fine_tunning/SentiWSP_fine_tunning_SA.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import torch 3 | torch.multiprocessing.set_start_method('spawn') 4 | import sys 5 | import os 6 | import argparse 7 | from tqdm import tqdm 8 | import random 9 | from transformers import ElectraTokenizer, ElectraForSequenceClassification, ElectraConfig, get_linear_schedule_with_warmup 10 | from transformers import AutoTokenizer, AutoModelForSequenceClassification 11 | from sklearn.metrics import precision_recall_curve,classification_report 12 | from transformers import BertForSequenceClassification, BertTokenizer, BertConfig 13 | from torch.utils.data.distributed import DistributedSampler 14 | from torch.nn.parallel import DistributedDataParallel 15 | import torch.distributed as dist 16 | from torch.utils.data import Dataset, DataLoader 17 | from transformers import WEIGHTS_NAME, CONFIG_NAME 18 | 19 | 20 | random_seed = 2022 21 | random.seed(random_seed) 22 | torch.manual_seed(random_seed) 23 | 24 | parser = argparse.ArgumentParser(description='Pre training model configuration') 25 | parser.add_argument('--model', default='electra',help='pre train model type') 26 | parser.add_argument('--model_size', default='Small',help='pre train model size') 27 | # dataset :【imdb、yelp2、sst5、yelp5、mr、custom、】(custom Custom data file path) 28 | # The custom dataset must contain (text key, label key), and must contain the train and valid files 29 | parser.add_argument('--dataset', default='custom',help='pre train dataset') 30 | parser.add_argument('--inference', default="valid",help='Choose whether to use a test set or a validation set') 31 | parser.add_argument('--gpu_num', type=int ,default=4,help='pre train gpu num') 32 | parser.add_argument('--data_path', default='./datasets/',help='the save dir of dataset') 33 | parser.add_argument('--save_model', default='./save_models/',help='the save dir of fine-tunning model path') 34 | parser.add_argument('--pretrain_model', default='./pretrain_model/',help='the save dir of original model path') 35 | parser.add_argument('--batch_size',type=int ,default=32,help='the batch size of training process') 36 | parser.add_argument('--max_epoch', type=int,default=10,help='the max epoch of training process') 37 | parser.add_argument('--lr', default = 2e-5,help='the learning rate of fine-tunning process') 38 | parser.add_argument('--eps', default = 1e-8,help='the eps of fine-tunning process') 39 | 
parser.add_argument('--max_length', type=int,default = 512,help='the max_length of input sequence') 40 | parser.add_argument("--local_rank", type=int,default=-1, help="local rank") 41 | parser.add_argument("--rank", type=int,default=-1, help="rank") 42 | parser.add_argument("--result_path", default = './result/', help="result path") 43 | parser.add_argument("--loadmodel", type=bool, default = False, help="isload model") 44 | parser.add_argument("--loadmodelpath", type=str, default='ance7word5',help="load model path") 45 | parser.add_argument('--jsondata', default='./jsondata/custom/',help='the save dir of custom json dataset(if you use custom data)') 46 | parser.add_argument("--issavemodel", type=bool, default = False, help="is save model?") 47 | parser.add_argument("--fewshot", type=bool, default = False, help="is use few shot data") 48 | parser.add_argument("--fewshotnum", type=int, default = 8, help="few shot data size") 49 | parser.add_argument("--zeroshot", type=bool, default = False, help="is use zero shot?") 50 | parser.add_argument("--fewshotdatapath", type=str, default = './fewdata/', help="few shot data path") 51 | parser.add_argument("--num_class", type=int, default = 2, help="is classify num labels") 52 | args = parser.parse_args() 53 | 54 | torch.cuda.current_device() 55 | # Global definition to determine whether multi machine multi GPU or single machine multi GPU 56 | if args.gpu_num>1 and args.gpu_num<=8: 57 | # os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2, 3, 4, 5, 6, 7" 58 | torch.distributed.init_process_group(backend="nccl") 59 | local_rank = torch.distributed.get_rank() 60 | torch.cuda.set_device(local_rank) 61 | device = torch.device("cuda", local_rank) 62 | elif args.gpu_num>8: 63 | dist.init_process_group(backend='nccl') 64 | rank=torch.distributed.get_rank() 65 | local_rank = torch.distributed.get_rank() 66 | torch.cuda.set_device(local_rank) 67 | world_size = torch.distributed.get_world_size() 68 | device = torch.device("cuda", local_rank) 69 | else: 70 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 71 | 72 | max_epoch = args.max_epoch 73 | lr = args.lr 74 | if args.model_size=='large': 75 | lr=1e-5 76 | eps = args.eps 77 | max_length = args.max_length 78 | batch_size = args.batch_size 79 | dataset_path = args.data_path + args.dataset 80 | 81 | if args.loadmodel: 82 | loadmodelpath=args.pretrain_model+args.loadmodelpath 83 | print("loadmodelpath:",loadmodelpath) 84 | else: 85 | print("--loadmodel=false") 86 | sys.exit() 87 | print("lr:",lr) 88 | 89 | 90 | def train(model, dataset, test_loader, max_epoch): 91 | 92 | # model.to(device) 93 | if args.gpu_num>1 and args.gpu_num<=8: 94 | sampler = DistributedSampler(dataset['train']) 95 | train_loader = DataLoader(dataset['train'], batch_size=batch_size,sampler=sampler) 96 | elif args.gpu_num>8: 97 | sampler = DistributedSampler(dataset['train'],num_replicas=world_size, rank=rank) 98 | train_loader = DataLoader(dataset['train'], batch_size=batch_size,sampler=sampler) 99 | else: 100 | train_loader = DataLoader(dataset['train'], batch_size=batch_size,shuffle=True) 101 | optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr,eps = eps) 102 | total_steps = len(train_loader) * max_epoch*args.gpu_num 103 | warmup_steps=0 104 | if args.model_size=='large': 105 | warmup_steps=0.1*len(train_loader)/args.gpu_num 106 | scheduler = get_linear_schedule_with_warmup(optimizer, 107 | num_warmup_steps = warmup_steps, # Default value in run_glue.py 108 | num_training_steps = total_steps) 109 | 
model.train() 110 | 111 | for epoch in range(max_epoch): 112 | # dataset = dataset.shuffle() 113 | # train_loader = DataLoader(dataset['train'], batch_size=batch_size) 114 | if args.gpu_num>1: 115 | sampler.set_epoch(epoch) 116 | total_train_loss = 0 117 | for iter_num, batch in enumerate(tqdm(train_loader)): 118 | batch = {k: v.to(device) for k, v in batch.items()} 119 | outputs = model.forward(**batch) 120 | loss = outputs.loss 121 | total_train_loss += loss.item() 122 | 123 | loss.backward() 124 | # clip 125 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm = 1.0) 126 | optimizer.step() 127 | scheduler.step() 128 | optimizer.zero_grad() 129 | 130 | if iter_num % 20 == 0: 131 | print("epoth: %d, iter_num: %d, loss: %.4f" % (epoch, iter_num, loss.item())) 132 | 133 | if not args.loadmodel: 134 | if args.gpu_num>1: 135 | output_dir=args.save_model+args.model+'-'+args.model_size+'_'+args.dataset+'_epoch_'+str(epoch)+'_multgpu'+str(args.gpu_num) 136 | else: 137 | output_dir=args.save_model+args.model+'-'+args.model_size+'_'+args.dataset+'_epoch_'+str(epoch) 138 | else: 139 | output_dir=args.save_model+args.loadmodelpath+'_'+args.dataset+'_epoch_'+str(epoch)+'trans' 140 | 141 | if args.gpu_num>1: 142 | if dist.get_rank() == 0: 143 | validation(model,test_loader) 144 | else: 145 | validation(model,test_loader) 146 | if args.issavemodel: 147 | if args.gpu_num>1: 148 | if dist.get_rank() == 0: 149 | save_model(model,output_dir) 150 | else: 151 | save_model(model,output_dir) 152 | print("Epoch: %d, Average training loss: %.4f" %(epoch, total_train_loss/len(train_loader))) 153 | 154 | 155 | def validation(model,test_loader): 156 | model.eval() 157 | y_true=[] 158 | y_pred=[] 159 | total_eval_loss=0 160 | 161 | for batch in test_loader: 162 | with torch.no_grad(): 163 | batch = {k: v.to(device) for k, v in batch.items()} 164 | outputs = model.forward(**batch) 165 | 166 | loss = outputs.loss 167 | logits = outputs.logits 168 | 169 | total_eval_loss += loss.item() 170 | prediction = torch.argmax(logits, 1) 171 | prediction = prediction.detach().cpu().tolist() 172 | label_ids = batch['labels'].to('cpu').tolist() 173 | # print(prediction) 174 | # print(label_ids) 175 | y_pred+=prediction 176 | y_true+=label_ids 177 | if args.num_class==2: 178 | target_names = ['Negative', 'Positive'] 179 | elif args.num_class==5: 180 | target_names = ['Very Negative','Negative','Neutral', 'Positive','Very Positive'] 181 | else: 182 | sys.exit() 183 | print(classification_report(y_true, y_pred, target_names=target_names,digits=4)) 184 | if os.path.exists(args.result_path) == False: 185 | os.makedirs(args.result_path) 186 | if not args.loadmodel: 187 | if args.zeroshot: 188 | file = open(args.result_path + args.model+'-'+args.model_size+ "_" + args.dataset+'_zeroshot_result.txt',"a+") 189 | elif args.fewshot: 190 | file = open(args.result_path + args.model+'-'+args.model_size+ "_" + args.dataset+'_fewshot'+str(args.fewshotnum)+'_result.txt',"a+") 191 | else: 192 | file = open(args.result_path + args.model+'-'+args.model_size+ "_" + args.dataset + "_result.txt", "a+") 193 | else: 194 | if args.zeroshot: 195 | file = open(args.result_path + args.loadmodelpath+ "_" + args.dataset+'_zeroshot_result.txt',"a+") 196 | elif args.fewshot: 197 | file = open(args.result_path + args.loadmodelpath+ "_" + args.dataset+'_fewshot'+str(args.fewshotnum)+'_result.txt',"a+") 198 | else: 199 | file = open(args.result_path + args.loadmodelpath+ "_" + args.dataset+'trans' + "_result.txt", "a+") 200 | 201 | 
file.write(classification_report(y_true, y_pred, target_names=target_names,digits=4) + '\n' + '\n') 202 | file.close() 203 | print("Average testing loss: %.4f"%(total_eval_loss/len(test_loader))) 204 | print("-------------------------------") 205 | 206 | def save_model(model,output_dir): 207 | if os.path.exists(output_dir) == False: 208 | os.makedirs(output_dir) 209 | model_to_save = model.module if hasattr(model, 'module') else model 210 | output_model_file = os.path.join(output_dir, WEIGHTS_NAME) 211 | output_config_file = os.path.join(output_dir, CONFIG_NAME) 212 | torch.save(model_to_save.state_dict(), output_model_file) 213 | model_to_save.config.to_json_file(output_config_file) 214 | print("save model in :"+output_dir) 215 | 216 | 217 | if __name__ == "__main__": 218 | #model 219 | print("classify class is "+str(args.num_class)+" and use load model") 220 | model_path = os.path.join(args.pretrain_model, args.model) 221 | save_path = loadmodelpath 222 | tokenizer = AutoTokenizer.from_pretrained(save_path) 223 | model = AutoModelForSequenceClassification.from_pretrained(save_path, num_labels=args.num_class) 224 | #data 225 | if args.fewshot or args.zeroshot: 226 | #few shot and zero shot finetune 227 | dataset_path=args.fewshotdatapath+args.dataset 228 | if args.fewshot: 229 | def encode(examples): 230 | return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=max_length) 231 | dataset = {} 232 | train_path="/train"+str(args.fewshotnum)+".json" 233 | dataset['train'] = datasets.dataset_dict.DatasetDict.from_json(dataset_path + train_path) 234 | if args.inference == "test": 235 | dataset['test'] = datasets.dataset_dict.DatasetDict.from_json(dataset_path+ "test.jsonl") 236 | elif args.inference == "valid": 237 | dataset['test'] = datasets.dataset_dict.DatasetDict.from_json(dataset_path + "valid.jsonl") 238 | else: 239 | print("inference must be test or valid") 240 | exit(0) 241 | print("format " + args.dataset + " dataset....") 242 | dataset['train'] = dataset['train'].map(encode, batched=True) 243 | dataset['test'] = dataset['test'].map(encode, batched=True) 244 | dataset['train'] = dataset['train'].map(lambda examples: {'labels': examples['label']}, batched=True) 245 | dataset['test'] = dataset['test'].map(lambda examples: {'labels': examples['label']}, batched=True) 246 | dataset['train'].set_format(type='torch', 247 | columns=['input_ids', 'token_type_ids', 'attention_mask', 248 | 'labels']) 249 | dataset['test'].set_format(type='torch', 250 | columns=['input_ids', 'token_type_ids', 'attention_mask', 251 | 'labels']) 252 | elif args.zeroshot: 253 | print("zero shot inference") 254 | def encode(examples): 255 | return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=max_length) 256 | dataset = {} 257 | if args.inference == "test": 258 | dataset['test'] = datasets.dataset_dict.DatasetDict.from_json(dataset_path + "test.jsonl") 259 | elif args.inference == "valid": 260 | dataset['test'] = datasets.dataset_dict.DatasetDict.from_json(dataset_path + "valid.jsonl") 261 | else: 262 | print("inference must be test or valid") 263 | exit(0) 264 | print("format " + args.dataset + " dataset....") 265 | dataset['test'] = dataset['test'].map(encode, batched=True) 266 | dataset['test'] = dataset['test'].map(lambda examples: {'labels': examples['label']}, batched=True) 267 | dataset['test'].set_format(type='torch', 268 | columns=['input_ids', 'token_type_ids', 'attention_mask', 269 | 'labels']) 270 | else: 271 | # Fintune loading and processing 
datasets 272 | if args.dataset == 'imdb': 273 | print("load imdb dataset...") 274 | def encode(examples): 275 | return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=max_length) 276 | 277 | if os.path.exists(dataset_path) == False: 278 | dataset = datasets.load_dataset('imdb', cache_dir='./datasets') 279 | dataset.save_to_disk(dataset_path) 280 | else: 281 | print(dataset_path) 282 | dataset = datasets.load_from_disk(dataset_path) 283 | print(dataset_path) 284 | print("format " + args.dataset + " dataset....") 285 | dataset = dataset.map(encode, batched=True) 286 | dataset = dataset.map(lambda examples: {'labels': examples['label']}, batched=True) 287 | dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels']) 288 | 289 | elif args.dataset == 'yelp2': 290 | print("load yelp2 dataset...") 291 | def encode(examples): 292 | return tokenizer(examples['text'], truncation=True, padding='max_length',max_length=max_length) 293 | 294 | if os.path.exists(dataset_path) == False: 295 | dataset = datasets.load_dataset('yelp_polarity', cache_dir='./datasets') 296 | dataset.save_to_disk(dataset_path) 297 | else: 298 | dataset = datasets.load_from_disk(dataset_path) 299 | 300 | print("format " + args.dataset + " dataset....") 301 | dataset = dataset.map(encode, batched=True) 302 | dataset = dataset.map(lambda examples: {'labels': examples['label']}, batched=True) 303 | dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels']) 304 | elif args.dataset == 'custom': 305 | #you can put the data in jsondata path in train.json and test.json to finetunning 306 | # path in args.jsondata 307 | print("Load custom datasets...") 308 | def encode(examples): 309 | return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=max_length) 310 | 311 | dataset = {} 312 | if args.inference == "test": 313 | dataset['test'] = datasets.dataset_dict.DatasetDict.from_json(args.jsondata + "test.jsonl") 314 | elif args.inference == "valid": 315 | dataset['test'] = datasets.dataset_dict.DatasetDict.from_json(args.jsondata + "valid.jsonl") 316 | else: 317 | print("inference must be test or valid") 318 | exit(0) 319 | print("format " + args.dataset + " dataset....") 320 | dataset['train'] = dataset['train'].map(encode, batched=True) 321 | dataset['test'] = dataset['test'].map(encode, batched=True) 322 | dataset['train'] = dataset['train'].map(lambda examples: {'labels': examples['label']}, batched=True) 323 | dataset['test'] = dataset['test'].map(lambda examples: {'labels': examples['label']}, batched=True) 324 | dataset['train'].set_format(type='torch', 325 | columns=['input_ids', 'token_type_ids', 'attention_mask', 326 | 'labels']) 327 | dataset['test'].set_format(type='torch', 328 | columns=['input_ids', 'token_type_ids', 'attention_mask', 329 | 'labels']) 330 | print(len(dataset['train'])) 331 | elif args.dataset == 'mr': 332 | print("Load mr datasets...") 333 | def encode(examples): 334 | return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=max_length) 335 | jsondata_path = './jsondata/MR/' 336 | dataset = {} 337 | dataset['train'] = datasets.dataset_dict.DatasetDict.from_json(jsondata_path+ "train.json") 338 | if args.inference == "test": 339 | dataset['test'] = datasets.dataset_dict.DatasetDict.from_json(jsondata_path + "test.jsonl") 340 | elif args.inference == "valid": 341 | dataset['test'] = datasets.dataset_dict.DatasetDict.from_json(jsondata_path + 
"valid.jsonl") 342 | else: 343 | print("inference must be test or valid") 344 | exit(0) 345 | print("format " + args.dataset + " dataset....") 346 | dataset['train'] = dataset['train'].map(encode, batched=True) 347 | dataset['test'] = dataset['test'].map(encode, batched=True) 348 | dataset['train'] = dataset['train'].map(lambda examples: {'labels': examples['label']}, batched=True) 349 | dataset['test'] = dataset['test'].map(lambda examples: {'labels': examples['label']}, batched=True) 350 | dataset['train'].set_format(type='torch', 351 | columns=['input_ids', 'token_type_ids', 'attention_mask', 352 | 'labels']) 353 | dataset['test'].set_format(type='torch', 354 | columns=['input_ids', 'token_type_ids', 'attention_mask', 355 | 'labels']) 356 | print(len(dataset['train'])) 357 | 358 | elif args.dataset == 'sst5': 359 | print("Load sst5 datasets...") 360 | def encode(examples): 361 | return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=max_length) 362 | jsondata_path = './jsondata/sst5/' 363 | dataset = {} 364 | dataset['train'] = datasets.dataset_dict.DatasetDict.from_json(jsondata_path+ "train.jsonl") 365 | if args.inference == "test": 366 | dataset['test'] = datasets.dataset_dict.DatasetDict.from_json(jsondata_path + "test.jsonl") 367 | elif args.inference == "valid": 368 | dataset['test'] = datasets.dataset_dict.DatasetDict.from_json(jsondata_path + "valid.jsonl") 369 | else: 370 | print("inference must be test or valid") 371 | exit(0) 372 | print("format " + args.dataset + " dataset....") 373 | dataset['train'] = dataset['train'].map(encode, batched=True) 374 | dataset['test'] = dataset['test'].map(encode, batched=True) 375 | dataset['train'] = dataset['train'].map(lambda examples: {'labels': examples['label']}, batched=True) 376 | dataset['test'] = dataset['test'].map(lambda examples: {'labels': examples['label']}, batched=True) 377 | dataset['train'].set_format(type='torch', 378 | columns=['input_ids', 'token_type_ids', 'attention_mask', 379 | 'labels']) 380 | dataset['test'].set_format(type='torch', 381 | columns=['input_ids', 'token_type_ids', 'attention_mask', 382 | 'labels']) 383 | elif args.dataset =='yelp5': 384 | print("load yelp5 dataset...") 385 | print("Load sst5 datasets...") 386 | def encode(examples): 387 | return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=max_length) 388 | jsondata_path = './jsondata/yelp5/' 389 | dataset = {} 390 | dataset['train'] = datasets.dataset_dict.DatasetDict.from_json(jsondata_path+ "train.jsonl") 391 | if args.inference == "test": 392 | dataset['test'] = datasets.dataset_dict.DatasetDict.from_json(jsondata_path+ "test.jsonl") 393 | elif args.inference == "valid": 394 | dataset['test'] = datasets.dataset_dict.DatasetDict.from_json(jsondata_path + "valid.jsonl") 395 | else: 396 | print("inference must be test or valid") 397 | exit(0) 398 | print("format " + args.dataset + " dataset....") 399 | dataset['train'] = dataset['train'].map(encode, batched=True) 400 | dataset['test'] = dataset['test'].map(encode, batched=True) 401 | dataset['train'] = dataset['train'].map(lambda examples: {'labels': examples['label']}, batched=True) 402 | dataset['test'] = dataset['test'].map(lambda examples: {'labels': examples['label']}, batched=True) 403 | dataset['train'].set_format(type='torch', 404 | columns=['input_ids', 'token_type_ids', 'attention_mask', 405 | 'labels']) 406 | dataset['test'].set_format(type='torch', 407 | columns=['input_ids', 'token_type_ids', 'attention_mask', 408 | 
'labels']) 409 | 410 | else: 411 | print("no dataset name " + args.dataset) 412 | sys.exit() 413 | 414 | test_loader = DataLoader(dataset['test'], batch_size=batch_size) 415 | # if args.inference == "test": 416 | # test_loader = DataLoader(dataset['test'], batch_size=batch_size) 417 | # elif args.inference == "valid": 418 | # test_loader = DataLoader(dataset['valid'], batch_size=batch_size) 419 | # else: 420 | # print("inference must be test or valid") 421 | # exit(0) 422 | 423 | model.to(device) 424 | #model 425 | if args.gpu_num>1: 426 | if torch.cuda.device_count() > 1: 427 | print("Let's use", torch.cuda.device_count(), "GPUs!") 428 | model = DistributedDataParallel(model, 429 | device_ids=[local_rank], 430 | output_device=local_rank) 431 | if not args.zeroshot: 432 | if not args.fewshot: 433 | # train model 434 | print("finetune " + args.model + " model...") 435 | train(model,dataset,test_loader,max_epoch) 436 | else: 437 | print("few shot train " + args.model + " model...") 438 | 439 | train(model,dataset,test_loader,max_epoch) 440 | else: 441 | print("zero shot inference " + args.model + " model...") 442 | validation(model,test_loader) 443 | -------------------------------------------------------------------------------- /fine_tunning/absa_data_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | 5 | class InputExample(object): 6 | 7 | def __init__(self, data_type, text_a, text_b=None, label=None): 8 | self.type = data_type 9 | self.text_a = text_a 10 | self.text_b = text_b 11 | self.label = label 12 | 13 | class InputFeatures(object): 14 | """A single set of features of data.""" 15 | 16 | def __init__(self, input_ids, input_mask, segment_ids, label_id): 17 | self.input_ids = input_ids 18 | self.input_mask = input_mask 19 | self.segment_ids = segment_ids 20 | self.label_id = label_id 21 | 22 | 23 | class DataProcessor(object): 24 | """Base class for data converters for sequence classification data sets.""" 25 | 26 | def get_train_examples(self, data_dir): 27 | """Gets a collection of `InputExample`s for the train set.""" 28 | raise NotImplementedError() 29 | 30 | def get_dev_examples(self, data_dir): 31 | """Gets a collection of `InputExample`s for the dev set.""" 32 | raise NotImplementedError() 33 | 34 | def get_test_examples(self, data_dir): 35 | """Gets a collection of `InputExample`s for the test set.""" 36 | raise NotImplementedError() 37 | 38 | def get_labels(self): 39 | """Gets the list of labels for this data set.""" 40 | raise NotImplementedError() 41 | 42 | @classmethod 43 | def _read_json(cls, input_file): 44 | """Reads a json file for tasks in sentiment analysis.""" 45 | with open(input_file) as f: 46 | return json.load(f) 47 | 48 | class AscProcessor(DataProcessor): 49 | """Processor for the SemEval Aspect Sentiment Classification.""" 50 | 51 | def get_train_examples(self, data_dir, fn="train.json"): 52 | """See base class.""" 53 | return self._create_examples( 54 | self._read_json(os.path.join(data_dir, fn)), "train") 55 | 56 | def get_dev_examples(self, data_dir, fn="dev.json"): 57 | """See base class.""" 58 | return self._create_examples( 59 | self._read_json(os.path.join(data_dir, fn)), "dev") 60 | 61 | def get_test_examples(self, data_dir, fn="test.json"): 62 | """See base class.""" 63 | return self._create_examples( 64 | self._read_json(os.path.join(data_dir, fn)), "test") 65 | 66 | def get_labels(self): 67 | """See base class.""" 68 | return ["positive", "negative", "neutral"] 69 | 70 | 
def _create_examples(self, lines, set_type): 71 | """Creates examples for the training and dev sets.""" 72 | examples = [] 73 | for (i, ids) in enumerate(lines): 74 | data_type = set_type 75 | text_a = lines[ids]['term'] 76 | text_b = lines[ids]['sentence'] 77 | label = lines[ids]['polarity'] 78 | examples.append( 79 | InputExample(data_type=data_type, text_a=text_a, text_b=text_b, label=label)) 80 | return examples 81 | 82 | def convert_examples_to_features(examples, label_list, max_len, tokenizer): 83 | """Loads a data file into a list of `InputBatch`s.""" #check later if we can merge this function with the SQuAD preprocessing 84 | label_map = {} 85 | for (i, label) in enumerate(label_list): 86 | label_map[label] = i 87 | 88 | features = [] 89 | for (ex_index, example) in enumerate(examples): 90 | 91 | tokens_a = tokenizer.tokenize(example.text_a) 92 | 93 | tokens_b = None 94 | if example.text_b: 95 | tokens_b = tokenizer.tokenize(example.text_b) 96 | 97 | if tokens_b: 98 | # Modifies `tokens_a` and `tokens_b` in place so that the total 99 | # length is less than the specified length. 100 | # Account for [CLS], [SEP], [SEP] with "- 3" 101 | _truncate_seq_pair(tokens_a, tokens_b, max_len - 3) 102 | else: 103 | # Account for [CLS] and [SEP] with "- 2" 104 | if len(tokens_a) > max_len - 2: 105 | tokens_a = tokens_a[0:(max_len - 2)] 106 | 107 | # target: "[CLS]aspect[SEP]sentence[SEP]" 108 | # segment_ids=0 for "[CLS]aspect[SEP]", segment_ids=1 for "sentence[SEP]", and segment_ids=0 for "[PAD]" 109 | tokens = [] 110 | segment_ids = [] 111 | tokens.append("[CLS]") 112 | segment_ids.append(0) 113 | for token in tokens_a: 114 | tokens.append(token) 115 | segment_ids.append(0) 116 | tokens.append("[SEP]") 117 | segment_ids.append(0) 118 | 119 | if tokens_b: 120 | for token in tokens_b: 121 | tokens.append(token) 122 | segment_ids.append(1) 123 | tokens.append("[SEP]") 124 | segment_ids.append(1) 125 | 126 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 127 | 128 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 129 | # tokens are attended to. 130 | input_mask = [1] * len(input_ids) 131 | 132 | # Zero-pad up to the sequence length. 133 | while len(input_ids) < max_len: 134 | input_ids.append(0) 135 | input_mask.append(0) 136 | segment_ids.append(0) 137 | 138 | assert len(input_ids) == max_len 139 | assert len(input_mask) == max_len 140 | assert len(segment_ids) == max_len 141 | 142 | label_id = label_map[example.label] 143 | 144 | features.append( 145 | InputFeatures( 146 | input_ids=input_ids, 147 | input_mask=input_mask, 148 | segment_ids=segment_ids, 149 | label_id=label_id)) 150 | return features 151 | 152 | # Truncate tokens_a and tokens_b together until the pair fits the length budget 153 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 154 | """Truncates a sequence pair in place to the maximum length.""" 155 | 156 | # This is a simple heuristic which will always truncate the longer sequence 157 | # one token at a time. This makes more sense than truncating an equal percent 158 | # of tokens from each, since if one sequence is very short then each token 159 | # that's truncated likely contains more information than a longer sequence.
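# Worked example: with max_length=6, len(tokens_a)=5 and len(tokens_b)=3, the
# loop below pops from tokens_a twice (total 8 -> 7 -> 6) and leaves tokens_b intact.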
160 | while True: 161 | total_length = len(tokens_a) + len(tokens_b) 162 | if total_length <= max_length: 163 | break 164 | if len(tokens_a) > len(tokens_b): 165 | tokens_a.pop() 166 | else: 167 | tokens_b.pop() -------------------------------------------------------------------------------- /pretrain/SentiWSP_Pretrain_ANCE_GEN.py: -------------------------------------------------------------------------------- 1 | import os, sys, random 2 | from pathlib import Path 3 | from datetime import datetime, timezone, timedelta 4 | import numpy as np 5 | import torch 6 | import torch.distributed as dist 7 | from torch.utils.data.distributed import DistributedSampler 8 | from torch import nn 9 | import torch.nn.functional as F 10 | import datasets 11 | from transformers import ElectraConfig, ElectraTokenizerFast, ElectraForMaskedLM, ElectraForPreTraining, get_linear_schedule_with_warmup, ElectraModel 12 | from _utils.utils import * 13 | from _utils.would_like_to_pr import * 14 | from tqdm import tqdm 15 | import json 16 | import faiss 17 | import argparse 18 | from transformers import WEIGHTS_NAME, CONFIG_NAME 19 | 20 | # parser config 21 | parser = argparse.ArgumentParser(description='ANN data generation configuration') 22 | parser.add_argument('--model', default='electra',help='pre-train model type') 23 | parser.add_argument('--pretrain_model', type=str,default='./pretrain_model/',help='pre-train model path') 24 | parser.add_argument('--model_path', default='largesen',help='model path required to generate ANN data (warm-up model)') 25 | parser.add_argument('--tokenizer_path', default='./pretrain_model/electra/large/',help='tokenizer path required to generate ANN data') 26 | parser.add_argument('--size', type=str,default='large',help='the size of model') 27 | parser.add_argument('--max_length', type=int,default = 128,help='the max_length of input sequence') 28 | parser.add_argument('--use_exist_data', type=bool ,default = False,help='Whether to reuse previously sampled data;
if not, it is regenerated.') 29 | parser.add_argument('--data_path', default='./ANCEdata/',help='directory of the existing data used to generate the ANN index') 30 | parser.add_argument('--sentimask_prob', type=float ,default = 0.7 ,help='proportion of sentiment words to mask') 31 | parser.add_argument('--pooling', default ='cls' ,help='pooling method') 32 | parser.add_argument('--ANN_topK', type=int ,default = 100 ,help='number of similar documents returned by the ANN search') 33 | #parser.add_argument('--negative_num', type=int ,default = 10 ,help='Total number of hard negative selected') 34 | parser.add_argument('--negative_num', type=int ,default = 7 ,help='total number of hard negatives selected') 35 | parser.add_argument('--neg_ann_name', default='sen',help='hard negative path') 36 | parser.add_argument('--gpu_num',default='0',help='Number of GPUs used in training') 37 | parser.add_argument('--batch_size',type=int ,default=64,help='the batch size of the embedding process') 38 | parser.add_argument('--wiki_num',type=int ,default=500000,help='the sample size of documents') 39 | parser.add_argument('--loadwikijson',type=bool ,default=True,help='whether to load the preprocessed wiki json file') 40 | args = parser.parse_args() 41 | 42 | 43 | # train on device 44 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 45 | 46 | 47 | # seed 48 | random_seed = 2022 49 | random.seed(random_seed) 50 | torch.manual_seed(random_seed) 51 | 52 | tokenizer_path = args.pretrain_model+args.model+"/"+args.size 53 | model_path=args.pretrain_model+args.model_path 54 | # Tokenizer and settings for loading models 55 | disc_config = ElectraConfig.from_pretrained(model_path) 56 | 57 | if os.path.exists(tokenizer_path) == False: 58 | print("load Electra Tokenizer Fast from hub...") 59 | hf_tokenizer = ElectraTokenizerFast.from_pretrained(f"google/electra-{args.size}-discriminator") 60 | hf_tokenizer.save_pretrained(tokenizer_path) 61 | else: 62 | print("load Electra Tokenizer Fast from tokenizer_path...") 63 | hf_tokenizer = ElectraTokenizerFast.from_pretrained(tokenizer_path) 64 | 65 | 66 | ''' 67 | Mask sentiment words with probability sentimask_prob to build the masked view 68 | ''' 69 | sentivetor = np.load('./sentiment_vocab/senti_vector.npy') 70 | mask_token_index = hf_tokenizer.mask_token_id 71 | special_tok_ids = hf_tokenizer.all_special_ids 72 | vocab_size=hf_tokenizer.vocab_size 73 | 74 | json_wiki_path=args.data_path+'wikijson/wiki_50w_20%_clean.json' 75 | ann_data=args.data_path+args.model_path+'_'+args.size+str(args.sentimask_prob)+'/wiki'+str(args.sentimask_prob)+'_ANN.json' 76 | ann_data_path=args.data_path+args.model_path+'_'+args.size+str(args.sentimask_prob) 77 | neg_ann_data_path=args.data_path+args.model_path+'_'+args.size+str(args.sentimask_prob) 78 | neg_ann_data=args.data_path+args.model_path+'_'+args.size+str(args.sentimask_prob)+'/wiki_ANN_'+args.neg_ann_name+str(args.negative_num)+'_neg.npy' 79 | 80 | def get_senti_mask(example): 81 | newexample = {} 82 | senti_list = [] 83 | new_input_ids = torch.tensor(example['input_ids']).clone() 84 | for ids in example['input_ids']: 85 | if sentivetor[ids] == 1: 86 | senti_list.append(1) 87 | else: 88 | senti_list.append(0) 89 | 90 | senti_probability_matrix = torch.tensor(senti_list).clone() * args.sentimask_prob 91 | senti_mask = torch.bernoulli(senti_probability_matrix).bool() 92 | new_input_ids[senti_mask] = mask_token_index 93 | new_input_ids = new_input_ids.tolist() 94 | 95 | newexample['sentence_len'] = len(example['input_ids']) 96 | 97 | if len(example['input_ids']) < args.max_length: 98 |
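# Right-pad both views to max_length with [PAD]: 'positive' keeps the original
# ids, 'query' keeps the sentiment-masked ids; sentence_len (saved above)
# later drives the attention mask.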
newexample['positive'] = example['input_ids'] + [hf_tokenizer.pad_token_id] * (args.max_length - len(example['input_ids'])) 99 | else: 100 | newexample['positive'] = example['input_ids'] 101 | 102 | if len(new_input_ids) < args.max_length: 103 | newexample['query'] = new_input_ids + [hf_tokenizer.pad_token_id] * (args.max_length - len(new_input_ids)) 104 | else: 105 | newexample['query'] = new_input_ids 106 | 107 | return newexample 108 | 109 | 110 | def SampleData(): 111 | if args.loadwikijson: 112 | e_wiki = datasets.dataset_dict.DatasetDict.from_json(json_wiki_path) 113 | jsonWiki = [] 114 | print("read wiki json!") 115 | for index,data in enumerate(e_wiki): 116 | if index == args.wiki_num: 117 | break 118 | new_data={} 119 | new_data['input_ids']=data['ori_input_ids'] 120 | jsonWiki.append(new_data) 121 | print("index:",index) 122 | print("read wiki json done!") 123 | senti_jsonWiki = [] 124 | 125 | for data in jsonWiki: 126 | newdata = get_senti_mask(data) 127 | senti_jsonWiki.append(newdata) 128 | else: 129 | print('load/download wiki dataset') 130 | if os.path.exists("./datasets/wiki") == False: 131 | wiki = datasets.load_dataset('wikipedia', '20200501.en', cache_dir='./datasets')['train'] 132 | wiki.save_to_disk("./datasets/wiki") 133 | else: 134 | wiki = datasets.load_from_disk("./datasets/wiki") 135 | print('load/create data from wiki dataset for ELECTRA') 136 | ELECTRAProcessor = partial(ELECTRADataProcessor, hf_tokenizer=hf_tokenizer, max_length=args.max_length) 137 | e_wiki = ELECTRAProcessor(wiki).map(cache_file_name=f"electra_wiki_{args.max_length}.arrow", num_proc=8) 138 | 139 | # Sampling method: take the first 500000 items here; modify as you like 140 | jsonWiki = [] 141 | 142 | for index,data in enumerate(e_wiki): 143 | if index == args.wiki_num: 144 | break 145 | jsonWiki.append(data) 146 | 147 | senti_jsonWiki = [] 148 | 149 | for data in jsonWiki: 150 | newdata = get_senti_mask(data) 151 | senti_jsonWiki.append(newdata) 152 | print("write ann file") 153 | if os.path.exists(ann_data_path) == False: 154 | os.makedirs(ann_data_path) 155 | for data in senti_jsonWiki: 156 | with open(ann_data, 'a+', encoding='utf-8') as f_obj: 157 | json_str = json.dumps(data, ensure_ascii=False) 158 | f_obj.write(json_str + '\n') 159 | 160 | 161 | 162 | ''' 163 | ANN model definition 164 | ''' 165 | class ELEDIC_NLL_LN(nn.Module): 166 | 167 | def __init__(self, pretrained_model): 168 | super(ELEDIC_NLL_LN, self).__init__() 169 | self.DModel = pretrained_model 170 | self.norm = nn.LayerNorm(disc_config.hidden_size) 171 | 172 | def get_emb(self, input_ids, attention_mask, pooling='cls'): 173 | out = self.DModel(input_ids=input_ids, 174 | attention_mask=attention_mask) 175 | # Decide how to pool the encoder output into a single vector 176 | if pooling == 'cls': 177 | return self.norm(out.last_hidden_state[:, 0]) # [batch_size, hidden_size] 178 | 179 | if pooling == 'last-avg': 180 | last = out.last_hidden_state.transpose(1, 2) # [batch_size, hidden_size, seq_len] 181 | return self.norm(torch.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(-1)) # [batch_size, hidden_size] 182 | 183 | if pooling == 'first-last-avg': 184 | first = out.hidden_states[1].transpose(1, 2) # [batch_size, hidden_size, seq_len] 185 | last = out.hidden_states[-1].transpose(1, 2) # [batch_size, hidden_size, seq_len] 186 | first_avg = torch.avg_pool1d(first, kernel_size=last.shape[-1]).squeeze(-1) # [batch_size, hidden_size] 187 | last_avg = torch.avg_pool1d(last,
kernel_size=last.shape[-1]).squeeze(-1) # [batch_size, hidden_size] 188 | avg = torch.cat((first_avg.unsqueeze(1), last_avg.unsqueeze(1)), dim=1) # [batch_size,2, hidden_size] 189 | return self.norm(torch.avg_pool1d(avg.transpose(1, 2), kernel_size=2).squeeze(-1)) # [batch_size, hidden_size] 190 | 191 | def forward(self,query_ids,attention_mask_q, 192 | input_ids_pos,attention_mask_pos, 193 | input_ids_neg,attention_mask_neg): 194 | 195 | ''' 196 | query_ids query 197 | input_ids_pos positive_passage 198 | input_ids_neg negative_passage 199 | ''' 200 | q_embs = self.get_emb(query_ids, attention_mask_q) 201 | pos_embs = self.get_emb(input_ids_pos, attention_mask_pos) 202 | neg_embs = self.get_emb(input_ids_neg, attention_mask_neg) 203 | logit_matrix = torch.cat([(q_embs * pos_embs).sum(-1).unsqueeze(1), 204 | (q_embs * neg_embs).sum(-1).unsqueeze(1)], dim=1) # [B, 2] 205 | lsm = F.log_softmax(logit_matrix, dim=1) 206 | loss = -1.0 * lsm[:, 0] # pairwise NLL: the positive pair sits in column 0 207 | return (loss.mean(),) 208 | 209 | def save_DModel(self,output_dir): 210 | if os.path.exists(output_dir) == False: 211 | os.makedirs(output_dir) 212 | model_to_save = self.DModel 213 | output_model_file = os.path.join(output_dir, WEIGHTS_NAME) 214 | output_config_file = os.path.join(output_dir, CONFIG_NAME) 215 | torch.save(model_to_save.state_dict(), output_model_file) 216 | model_to_save.config.to_json_file(output_config_file) 217 | print("save model in :"+output_dir) 218 | 219 | 220 | class ANN_Dataset(torch.utils.data.Dataset): 221 | def __init__(self, data_file,read_data_fn): 222 | 223 | self.example_list = read_data_fn(data_file) 224 | self.size = len(self.example_list) 225 | 226 | def __getitem__(self, index): 227 | item = self.example_list[index] 228 | attention_mask = [1] * item['sentence_len'] + [0] * (args.max_length - item['sentence_len']) 229 | # Convert lists to tensors; the DataLoader collates tensor items 230 | input_id = torch.tensor(item['input_id'], dtype=torch.int) 231 | attention_mask = torch.tensor(attention_mask, dtype=torch.bool) 232 | return input_id, attention_mask 233 | 234 | def __len__(self): 235 | return self.size 236 | 237 | def InferenceEmbedding(annmodel,train_dataloader,pooling = 'cls'): 238 | print("***** Running ANN Embedding Inference *****") 239 | embedding = [] 240 | annmodel.eval() 241 | for batch in tqdm(train_dataloader): 242 | batch = tuple(t.to(device) for t in batch) 243 | 244 | with torch.no_grad(): 245 | 246 | #ignore token_type_ids 247 | inputs = { 248 | "input_ids": batch[0].long(), 249 | "attention_mask": batch[1].long(), 250 | "pooling":pooling 251 | } 252 | # Queries and passages must be embedded with the same pooling method 253 | embs = annmodel.get_emb(**inputs) 254 | 255 | embs = embs.detach().cpu().numpy() 256 | embedding.append(embs) 257 | 258 | embedding = np.concatenate(embedding, axis=0) 259 | return embedding 260 | 261 | def GenerateNegativePassaageID(ANN_Index,negative_sample_num): 262 | query_negative_passage = {} 263 | 264 | for query_idx in range(ANN_Index.shape[0]): 265 | 266 | # The index of the positive document is equal to the index of the query 267 | pos_pid = query_idx 268 | top_ann_pid = ANN_Index[query_idx, :].copy() 269 | 270 | query_negative_passage[query_idx] = [] 271 | 272 | neg_cnt = 0 273 | 274 | for neg_pid in top_ann_pid: 275 | # Skip if the positive example itself is retrieved 276 | if neg_pid == pos_pid: 277 | continue 278 | 279 | if neg_cnt >= negative_sample_num: 280 | break 281 | 282 | query_negative_passage[query_idx].append(neg_pid) 283 | neg_cnt += 1 284 | 285 | return
query_negative_passage 285 | 286 | 287 | if __name__ == "__main__": 288 | if os.path.exists(ann_data) == False: 289 | SampleData() 290 | else: 291 | print("data already in",ann_data) 292 | # load model 293 | print("starting to generate ANN data...") 294 | print("generating ANN data using the checkpoint in "+ model_path) 295 | discmodel = ElectraModel.from_pretrained(model_path) 296 | annmodel = ELEDIC_NLL_LN(discmodel) 297 | annmodel.to(device) 298 | 299 | # inference query emb 300 | print("***** inference of query *****") 301 | def get_query_data(filename): 302 | example_list = [] 303 | for line in open(filename, 'rb'): 304 | example_dict = {} 305 | row = json.loads(line) 306 | example_dict['input_id'] = row['query'] 307 | example_dict['sentence_len'] = row['sentence_len'] 308 | example_list.append(example_dict) 309 | return example_list 310 | 311 | query_dataset = ANN_Dataset(ann_data, get_query_data) 312 | query_dataloader = torch.utils.data.DataLoader(query_dataset,batch_size=args.batch_size) 313 | query_embedding = InferenceEmbedding(annmodel, query_dataloader, args.pooling) 314 | print("***** Done query inference *****") 315 | 316 | # inference passages emb 317 | print("***** inference of passages *****") 318 | def get_passages_data(filename): 319 | example_list = [] 320 | for line in open(filename, 'rb'): 321 | example_dict = {} 322 | row = json.loads(line) 323 | example_dict['input_id'] = row['positive'] 324 | example_dict['sentence_len'] = row['sentence_len'] 325 | example_list.append(example_dict) 326 | return example_list 327 | 328 | passages_dataset = ANN_Dataset(ann_data, get_passages_data) 329 | passages_dataloader = torch.utils.data.DataLoader(passages_dataset,batch_size=args.batch_size) 330 | passages_embedding = InferenceEmbedding(annmodel, passages_dataloader, args.pooling) 331 | print("***** Done passage inference *****") 332 | 333 | # Building the ANN index used to find the top-K neighbours 334 | dim = passages_embedding.shape[1] 335 | print('passage embedding shape: ' + str(passages_embedding.shape)) 336 | 337 | faiss.omp_set_num_threads(16) 338 | cpu_index = faiss.IndexFlatIP(dim) 339 | cpu_index.add(passages_embedding) 340 | print("***** Done ANN Index *****") 341 | 342 | ''' 343 | Search the passage embeddings for the top-K entries most similar to each query. 344 | ''' 345 | # k-NN search over the FAISS index 346 | print("start searching ANN...") 347 | _, ANN_Index = cpu_index.search(query_embedding, args.ANN_topK) 348 | print("searching finished...") 349 | 350 | print("start generating negative data...") 351 | # Build hard negatives based on the returned ANN index 352 | query_negative_passage = GenerateNegativePassaageID(ANN_Index,args.negative_num) 353 | if os.path.exists(neg_ann_data_path) == False: 354 | os.makedirs(neg_ann_data_path) 355 | # Organize the new query_id, pos_pid, neg_pid triples and save 356 | np.save(neg_ann_data, query_negative_passage) 357 | print("finished generating ann data") 358 | -------------------------------------------------------------------------------- /pretrain/SentiWSP_Pretrain_ANCE_TRAIN.py: -------------------------------------------------------------------------------- 1 | import os, sys, random 2 | from pathlib import Path 3 | from datetime import datetime, timezone, timedelta 4 | import numpy as np 5 | import torch 6 | import torch.distributed as dist 7 | from torch.utils.data.distributed import DistributedSampler 8 | from torch import nn 9 | import torch.nn.functional as F 10 | import datasets 11 | from transformers import ElectraConfig, ElectraTokenizerFast, ElectraForMaskedLM, ElectraForPreTraining,
get_linear_schedule_with_warmup, ElectraModel 12 | from _utils.utils import * 13 | from _utils.would_like_to_pr import * 14 | from tqdm import tqdm 15 | import faiss 16 | import json 17 | import argparse 18 | from transformers import WEIGHTS_NAME, CONFIG_NAME 19 | import pickle 20 | # parser config 21 | parser = argparse.ArgumentParser(description='ANN training configuration') 22 | parser.add_argument('--model', default='electra',help='pre-train model type') 23 | parser.add_argument('--model_path', default='rawbase',help='path of the model to be trained on the ANN data') 24 | parser.add_argument('--warm_name', default='largesen',help='name of the warm-up model used in the GEN step (locates the ANN data)') 25 | parser.add_argument('--size', type=str,default='large',help='the size of model') 26 | parser.add_argument('--tokenizer_path', default='./pretrain_model/electra/large/',help='tokenizer path required to generate ANN data') 27 | parser.add_argument('--max_length',type=int, default = 128,help='the max_length of input sequence') 28 | parser.add_argument('--Neg_File_name', default='sen',help='suffix of the previously generated hard-negative file') 29 | parser.add_argument('--negative_num', type=int,default=7,help='number of hard negatives per query (must match the GEN step)') 30 | parser.add_argument('--ANN_File_path', default='./ANCEdata/',help='directory containing the sampled query/positive data files') 31 | parser.add_argument('--gpu_num', type=int ,default = 4 ,help='Number of GPUs used in training') 32 | # parser.add_argument('--dpp_gpu_num', default="0, 1",help='available GPU numbers in a multi GPU environment') 33 | parser.add_argument('--single_gpunum', default='0',help='GPU number used in single GPU environment') 34 | parser.add_argument('--save_model', type=bool ,default = True ,help='whether you want to save the model') 35 | parser.add_argument('--save_model_path', default='./save_ance/',help='directory where checkpoints are saved') 36 | parser.add_argument('--save_model_name', default='ANCE_checkpoint',help='model name') 37 | parser.add_argument('--sentimask_prob', type=float ,default = 0.7 ,help='proportion of sentiment words to mask') 38 | 39 | parser.add_argument('--batch_size',type=int ,default=32,help='the batch size of training process') 40 | parser.add_argument('--max_epoch', type=int,default=10,help='the max epoch of training process') 41 | parser.add_argument('--lr', default = 1e-5,help='the learning rate of training process') 42 | parser.add_argument('--eps', default = 1e-8,help='the eps of training process') 43 | parser.add_argument('--pretrain_model', type=str,default='./pretrain_model/',help='pre-train model path') 44 | parser.add_argument('--simcal', type=str,default='cos',help='similarity used in the loss [dot/cos]') 45 | parser.add_argument('--losscal', type=str,default='nll',help='loss type [nll/triplet]') 46 | 47 | parser.add_argument("--rank", type=int,default=-1, help="rank") 48 | parser.add_argument("--local_rank", type=int,default=-1, help="local rank") 49 | args = parser.parse_args() 50 | 51 | # seed 52 | random_seed = 2022 53 | random.seed(random_seed) 54 | torch.manual_seed(random_seed) 55 | 56 | ann_data=args.ANN_File_path+args.warm_name+'_'+args.size+str(args.sentimask_prob)+'/wiki'+str(args.sentimask_prob)+'_ANN.json' 57 | neg_ann_path=args.ANN_File_path+args.warm_name+'_'+args.size+str(args.sentimask_prob)+'/wiki_ANN_'+args.Neg_File_name+str(args.negative_num)+'_neg.npy' 58 | model_path=args.pretrain_model+args.model_path 59 | # multi-GPU 60 | if args.gpu_num > 1: 61 |
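# Multi-GPU path: assumes launch via `python -m torch.distributed.launch
# --nproc_per_node=<gpu_num>`, which spawns one process per GPU and provides
# the rank consumed below.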
#os.environ["CUDA_VISIBLE_DEVICES"] = args.dpp_gpu_num 62 | torch.distributed.init_process_group(backend="nccl") 63 | 64 | local_rank = torch.distributed.get_rank() 65 | torch.cuda.set_device(local_rank) 66 | device = torch.device("cuda", local_rank) 67 | else: 68 | device = torch.device("cuda:" + str(args.single_gpunum) if torch.cuda.is_available() else "cpu") 69 | 70 | # Load model settings 71 | disc_config = ElectraConfig.from_pretrained(model_path) 72 | 73 | # Define Ann triplet data type 74 | class ANNTripletTrainingData(torch.utils.data.Dataset): 75 | 76 | def __init__(self, ANN_Neg_ID_File, ANN_File): 77 | 78 | # Load negative sample sequence number 79 | self.query_negative_passage = np.load(ANN_Neg_ID_File,allow_pickle=True).item() 80 | 81 | # Ann data after pad loading (including query and positive, and positive is the original data) 82 | self.ann_data = [] 83 | for line in open(ANN_File, 'rb'): 84 | example_dict = {} 85 | row = json.loads(line) 86 | for key,value in row.items(): 87 | example_dict[key] = value 88 | self.ann_data.append(example_dict) 89 | 90 | self.ANN_Triplet_data = [] 91 | 92 | for index,data in enumerate(self.ann_data): 93 | example_dict = {} 94 | example_dict['positive'] = data['positive'] 95 | example_dict['query'] = data['query'] 96 | example_dict['sentence_len'] = data['sentence_len'] 97 | for neg_id in self.query_negative_passage[index]: 98 | example_dict['negative'] = self.ann_data[neg_id]['positive'] 99 | example_dict['neg_sentence_len'] = self.ann_data[neg_id]['sentence_len'] 100 | self.ANN_Triplet_data.append(example_dict) 101 | 102 | self.size = len(self.ANN_Triplet_data) 103 | 104 | def __getitem__(self, index): 105 | item = self.ANN_Triplet_data[index] 106 | 107 | pos_attention_mask = [1] * item['sentence_len'] + [0] * (args.max_length - item['sentence_len']) 108 | q_attention_mask = pos_attention_mask 109 | neg_attention_mask = [1] * item['neg_sentence_len'] + [0] * (args.max_length - item['neg_sentence_len']) 110 | q_input_id = torch.tensor(item['query'], dtype=torch.int) 111 | pos_input_id = torch.tensor(item['positive'], dtype=torch.int) 112 | neg_input_id = torch.tensor(item['negative'], dtype=torch.int) 113 | q_attention_mask = torch.tensor(q_attention_mask, dtype=torch.bool) 114 | pos_attention_mask = torch.tensor(pos_attention_mask, dtype=torch.bool) 115 | neg_attention_mask = torch.tensor(neg_attention_mask, dtype=torch.bool) 116 | return q_input_id, q_attention_mask, pos_input_id, pos_attention_mask, neg_input_id, neg_attention_mask 117 | 118 | def __len__(self): 119 | return self.size 120 | 121 | class DotProductSimilarity(nn.Module): 122 | def __init__(self,scale_output=False): 123 | super(DotProductSimilarity,self).__init__() 124 | self.scale_output=scale_output 125 | def forward(self,tensor_1,tensor_2): 126 | result=(tensor_1*tensor_2).sum(dim=-1) 127 | if(self.scale_output): 128 | result/=math.sqrt(tensor_1.size(-1)) 129 | return result 130 | 131 | ''' 132 | ANN model definition 133 | ''' 134 | class ELEDIC_NLL_LN(nn.Module): 135 | 136 | def __init__(self, pretrained_model,temp=0.05,margin=0.2): 137 | super(ELEDIC_NLL_LN, self).__init__() 138 | self.DModel = pretrained_model 139 | self.norm = nn.LayerNorm(disc_config.hidden_size) 140 | self.temp=temp 141 | self.margin=margin 142 | def get_emb(self, input_ids, attention_mask, pooling='cls'): 143 | out = self.DModel(input_ids=input_ids, 144 | attention_mask=attention_mask) 145 | # Decide how to get the vector after the query passes through the model 146 | if pooling == 'cls': 147 | 
return self.norm(out.last_hidden_state[:, 0]) # [batch_size, hidden_size] 148 | 149 | if pooling == 'last-avg': 150 | last = out.last_hidden_state.transpose(1, 2) # [batch_size, hidden_size, seq_len] 151 | return self.norm(torch.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(-1)) # [batch_size, hidden_size] 152 | 153 | if pooling == 'first-last-avg': 154 | first = out.hidden_states[1].transpose(1, 2) # [batch_size, hidden_size, seq_len] 155 | last = out.hidden_states[-1].transpose(1, 2) # [batch_size, hidden_size, seq_len] 156 | first_avg = torch.avg_pool1d(first, kernel_size=last.shape[-1]).squeeze(-1) # [batch_size, hidden_size] 157 | last_avg = torch.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(-1) # [batch_size, hidden_size] 158 | avg = torch.cat((first_avg.unsqueeze(1), last_avg.unsqueeze(1)), dim=1) # [batch_size,2, hidden_size] 159 | return self.norm(torch.avg_pool1d(avg.transpose(1, 2), kernel_size=2).squeeze(-1)) # [batch_size, hidden_size] 160 | 161 | def forward(self,query_ids,attention_mask_q, 162 | input_ids_pos,attention_mask_pos, 163 | input_ids_neg,attention_mask_neg): 164 | 165 | ''' 166 | query_ids query 167 | input_ids_pos positive_passage 168 | input_ids_neg negative_passage 169 | ''' 170 | q_embs = self.get_emb(query_ids, attention_mask_q) 171 | pos_embs = self.get_emb(input_ids_pos, attention_mask_pos) 172 | neg_embs = self.get_emb(input_ids_neg, attention_mask_neg) 173 | 174 | if args.losscal=='nll': 175 | batch_size=q_embs.size(0) 176 | y_pred = torch.stack([q_embs, pos_embs, neg_embs], dim=1) # rows become [q0, p0, n0, q1, p1, n1, ...] 177 | y_pred = y_pred.reshape(batch_size * 3, -1) 178 | y_true = torch.arange(y_pred.shape[0], device=device) 179 | use_row = torch.where((y_true + 1) % 3 != 0)[0] # keep query and positive rows, drop negative rows 180 | y_true = (use_row - use_row % 3 * 2) + 1 # query row 3k -> label 3k+1 (its positive); positive row 3k+1 -> label 3k (its query) 181 | 182 | if args.simcal=='cos': 183 | sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1) 184 | elif args.simcal=='dot': 185 | dot = DotProductSimilarity() 186 | sim = dot(y_pred.unsqueeze(1),y_pred.unsqueeze(0)) 187 | else: 188 | sys.exit() 189 | 190 | sim = sim - torch.eye(y_pred.shape[0], device=device) * 1e12 # mask self-similarity on the diagonal 191 | sim = torch.index_select(sim, 0, use_row) 192 | sim = sim / self.temp 193 | loss = F.cross_entropy(sim, y_true) 194 | 195 | elif args.losscal=='triplet': 196 | 197 | if args.simcal=='cos': 198 | pos_dist = F.cosine_similarity(q_embs,pos_embs) 199 | neg_dist = F.cosine_similarity(q_embs,neg_embs) 200 | elif args.simcal=='dot': 201 | dot = DotProductSimilarity() 202 | pos_dist = dot(q_embs,pos_embs) 203 | neg_dist = dot(q_embs,neg_embs) 204 | else: 205 | sys.exit() 206 | loss = F.relu(neg_dist - pos_dist + self.margin) # hinge loss: the positive pair must score higher than the negative by at least margin 207 | loss = loss.mean() 208 | else: 209 | sys.exit() 210 | return (loss,) 211 | 212 | def save_DModel(self,output_dir): 213 | if os.path.exists(output_dir) == False: 214 | os.makedirs(output_dir) 215 | model_to_save = self.DModel 216 | output_model_file = os.path.join(output_dir, WEIGHTS_NAME) 217 | output_config_file = os.path.join(output_dir, CONFIG_NAME) 218 | torch.save(model_to_save.state_dict(), output_model_file) 219 | model_to_save.config.to_json_file(output_config_file) 220 | print("save model in :"+output_dir) 221 | 222 | 223 | def train(model, train_loader): 224 | print("***** Running training *****") 225 | model.train() 226 | model.zero_grad() 227 | 228 | optimizer = torch.optim.AdamW(params=model.parameters(), lr=args.lr, eps=args.eps) 229 | total_steps = len(train_loader) * args.max_epoch 230 | warmup_steps = 0.1 * len(train_loader) / args.gpu_num 231 | scheduler =
get_linear_schedule_with_warmup(optimizer, 232 | num_warmup_steps = warmup_steps, # Default value in run_glue.py 233 | num_training_steps = total_steps) 234 | step=0 235 | loss_list=[] 236 | for epoch in range(args.max_epoch): 237 | total_train_loss = 0 238 | for iter_num, batch in enumerate(tqdm(train_loader)): 239 | step += 1 240 | batch = tuple(t.to(device) for t in batch) 241 | inputs = { 242 | "query_ids": batch[0].long(), 243 | "attention_mask_q": batch[1].long(), 244 | "input_ids_pos": batch[2].long(), 245 | "attention_mask_pos": batch[3].long(), 246 | "input_ids_neg": batch[4].long(), 247 | "attention_mask_neg": batch[5].long()} 248 | outputs = model(**inputs) 249 | loss = outputs[0] 250 | total_train_loss += loss.item() 251 | 252 | loss.backward() 253 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm = 1.0) 254 | optimizer.step() 255 | scheduler.step() # Update learning rate schedule 256 | model.zero_grad() 257 | if step % 20 == 0: 258 | if dist.get_rank() == 0: 259 | loss_list.append(loss.item()) 260 | print(">>> epoch: %d, step: %d, loss: %.4f" % (epoch, step, loss.item())) 261 | if step == 200: 262 | if dist.get_rank() == 0: 263 | if args.save_model: 264 | output_dir = args.save_model_path 265 | save_model(model, output_dir, step) 266 | if step % 1000 == 0 and step > 0: 267 | if dist.get_rank() == 0: 268 | if args.save_model: 269 | output_dir = args.save_model_path 270 | save_model(model, output_dir, step) 271 | if step == 2000: 272 | if dist.get_rank() == 0: 273 | with open('./loss_ance.pkl', 'wb') as f: 274 | pickle.dump(loss_list, f, pickle.HIGHEST_PROTOCOL) 275 | sys.exit() 276 | 277 | if dist.get_rank() == 0: 278 | if args.save_model: 279 | output_dir = args.save_model_path 280 | save_model(model, output_dir, epoch + 1) 281 | 282 | print("Epoch: %d, Average training loss: %.4f" %(epoch, total_train_loss/len(train_loader))) 283 | 284 | def save_model(model,output_dir,flag): 285 | if flag>100: 286 | save_dis_path=output_dir + str(flag) + 'iter_'+args.save_model_name 287 | else: 288 | save_dis_path = output_dir + str(flag) + 'epoch_'+args.save_model_name 289 | if os.path.exists(save_dis_path) == False: 290 | os.makedirs(save_dis_path) 291 | if args.gpu_num > 1: 292 | model.module.save_DModel(save_dis_path) 293 | else: 294 | model.save_DModel(save_dis_path) 295 | 296 | if __name__ == "__main__": 297 | ann_dataset = ANNTripletTrainingData(neg_ann_path, ann_data) 298 | # the dataloader is built below (with a DistributedSampler in the multi-GPU case) 299 | 300 | if args.gpu_num > 1: 301 | sampler = DistributedSampler(ann_dataset) 302 | ann_dataloader = torch.utils.data.DataLoader(ann_dataset, batch_size=args.batch_size,sampler=sampler) 303 | else: 304 | ann_dataloader = torch.utils.data.DataLoader(ann_dataset, batch_size=args.batch_size) 305 | 306 | print("starting training with the ANN data...") 307 | print("start training the model in "+ model_path) 308 | discmodel = ElectraModel.from_pretrained(model_path) 309 | annmodel = ELEDIC_NLL_LN(discmodel) 310 | annmodel.to(device) 311 | if args.gpu_num > 1: 312 | if torch.cuda.device_count() > 1: 313 | print("Let's use", torch.cuda.device_count(), "GPUs!") 314 | annmodel = torch.nn.parallel.DistributedDataParallel(annmodel, 315 | device_ids=[local_rank], 316 | output_device=local_rank) 317 | else: 318 | print("There are not enough GPUs available!") 319 | sys.exit() 320 | train(annmodel, ann_dataloader) 321 | --------------------------------------------------------------------------------
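The alternation between the two scripts above is easiest to see end-to-end. Below is a minimal driver sketch (an assumption, not part of the repo; all directory names are illustrative) that loops the "ANN Index Build" and "Train" steps and copies each round's checkpoint under ./pretrain_model/ so the next GEN round can load it. Note the coupling in the scripts: GEN resolves --model_path under --pretrain_model, and TRAIN locates the ANN data via --warm_name, so --warm_name must equal the GEN round's --model_path.

```python
import shutil
import subprocess

warm_name = "word_sen_model"  # warm-up checkpoint directory under ./pretrain_model/
for it in range(1, 4):        # the number of ANCE rounds is a free choice
    # 1) ANN Index Build: embed queries/passages and mine hard negatives
    subprocess.run(["python", "SentiWSP_Pretrain_ANCE_GEN.py",
                    "--gpu_num=1", "--sentimask_prob=0.7", "--max_length=128",
                    f"--model_path={warm_name}"], check=True)
    # 2) Train on the mined triplets (4-GPU DDP launch)
    subprocess.run(["python", "-m", "torch.distributed.launch",
                    "--nproc_per_node=4", "--master_port=9999",
                    "SentiWSP_Pretrain_ANCE_TRAIN.py",
                    f"--model_path={warm_name}", f"--warm_name={warm_name}",
                    "--gpu_num=4", "--save_model_path=./save_ance/",
                    f"--save_model_name=iter{it}"], check=True)
    # 3) Expose the new checkpoint to the next round; the exact subdirectory
    #    written by save_model() depends on when training stopped (here the
    #    step-2000 checkpoint produced before the hard exit in train()).
    next_name = f"word_sen_model_iter_{it}"
    shutil.copytree(f"./save_ance/2000iter_iter{it}", f"./pretrain_model/{next_name}")
    warm_name = next_name
```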
/pretrain/SentiWSP_Pretrain_Warmup_inbatch.py: -------------------------------------------------------------------------------- 1 | import os, sys, random 2 | from pathlib import Path 3 | from datetime import datetime, timezone, timedelta 4 | import numpy as np 5 | import torch 6 | import torch.distributed as dist 7 | from torch.utils.data.distributed import DistributedSampler 8 | from torch import nn 9 | import torch.nn.functional as F 10 | import datasets 11 | from transformers import ElectraConfig, ElectraTokenizerFast, ElectraForMaskedLM, ElectraForPreTraining, get_linear_schedule_with_warmup, ElectraModel 12 | from _utils.utils import * 13 | from _utils.would_like_to_pr import * 14 | from tqdm import tqdm 15 | from transformers import WEIGHTS_NAME, CONFIG_NAME 16 | import argparse 17 | 18 | 19 | parser = argparse.ArgumentParser(description='Pre-training model configuration') 20 | parser.add_argument('--model', default='electra',help='pre-train model type') 21 | parser.add_argument('--size', default='small',help='pre-train model size, choose in [small, base, large]') 22 | parser.add_argument('--dataset', default='wiki',help='choose in [wiki,owt,merg]') 23 | parser.add_argument('--gpu_num', type=int ,default=4,help='pre-train gpu num') 24 | parser.add_argument('--load_model', type=str,default='word_level5',help='checkpoint to continue training from') 25 | parser.add_argument('--pretrain_model', type=str,default='./pretrain_model/',help='pre-train model path') 26 | parser.add_argument("--rank", type=int,default=-1, help="rank") 27 | parser.add_argument("--local_rank", type=int,default=-1, help="local rank") 28 | parser.add_argument('--batch_size', type=int,default = 64,help='the batch_size in pretrain process') 29 | parser.add_argument('--max_len', type=int,default = 128,help='the seq max_len in pretrain process') 30 | parser.add_argument('--save_pretrain_model', type=str ,default='./save_pretrain_model/',help='directory where pre-trained checkpoints are saved (created if missing)') 31 | parser.add_argument('--Negative_type', type=str,default = 'random',help='[random, generate] how negative samples are created') 32 | parser.add_argument('--random_type', type=str,default = 'sentivocab',help='[sentivocab, allvocab] vocabulary used for random replacement') 33 | parser.add_argument("--sentimask_prob", type=float,default=0.5, help="The sentiment word mask probability") 34 | parser.add_argument("--train_type", type=str,default='unsup', help="[unsup or sup]") 35 | parser.add_argument('--save_model', type=bool ,default=True,help='whether to save the model') 36 | parser.add_argument('--use_jsondata', type=bool ,default=True,help='use the preprocessed json data') 37 | parser.add_argument('--jsondata_path', type=str ,default="./datasets/wikijson/wiki_50w_20%.json",help='json data path') 38 | args = parser.parse_args() 39 | 40 | 41 | # Default configuration;
it rarely needs to be changed. 42 | config = MyConfig({ 43 | 'base_run_name': 'ELECTRA', 44 | 'seed': 2022, 45 | 'electra_mask_style': True, 46 | 'config_path':'./config_pretrain/', 47 | 'num_workers': 3 48 | }) 49 | i = ['small', 'base', 'large'].index(args.size) 50 | config.mask_prob = [0.15, 0.15, 0.25][i] 51 | config.lr = [2e-5, 1e-5, 1e-5][i] 52 | config.max_epoch = 10 53 | config.single_gpunum = 0 54 | generator_size_divisor = [4, 3, 4][i] 55 | # continue-training model path 56 | config.model_path = args.pretrain_model+args.load_model+'/'+args.size+'/disc' 57 | print("load raw model from:",config.model_path) 58 | config.pooling = 'cls' 59 | config.temperature = 0.05 60 | # If Negative_type is 'generate', the generator is loaded from this path 61 | config.generator_path = args.pretrain_model+args.load_model+'/'+args.size+'/gen' 62 | 63 | 64 | print("train type:",args.train_type) 65 | print("now learning rate:",config.lr) 66 | if args.train_type=='sup': 67 | print("negative generation method:",args.Negative_type) 68 | if args.Negative_type=='random': 69 | print("random replacement vocabulary:", args.random_type) 70 | else: 71 | print("negative sample in batch") 72 | 73 | # seed 74 | random.seed(config.seed) 75 | np.random.seed(config.seed) 76 | torch.manual_seed(config.seed) 77 | 78 | # multi-GPU 79 | if args.gpu_num > 1: 80 | # os.environ["CUDA_VISIBLE_DEVICES"] = config.dpp_gpu_num 81 | torch.distributed.init_process_group(backend="nccl") 82 | 83 | local_rank = torch.distributed.get_rank() 84 | torch.cuda.set_device(local_rank) 85 | device = torch.device("cuda", local_rank) 86 | else: 87 | device = torch.device("cuda:" + str(config.single_gpunum) if torch.cuda.is_available() else "cpu") 88 | 89 | 90 | if args.use_jsondata == False: 91 | if args.dataset=='wiki' or args.dataset=='merg': 92 | print('load/download wiki dataset') 93 | if os.path.exists("./datasets/wiki") == False: 94 | wiki = datasets.load_dataset('wikipedia', '20200501.en', cache_dir='./datasets')['train'] 95 | wiki.save_to_disk("./datasets/wiki") 96 | else: 97 | wiki = datasets.load_from_disk("./datasets/wiki") 98 | print('load/create data from wiki dataset for ELECTRA') 99 | if args.dataset=='owt' or args.dataset=='merg': 100 | print('load/download OpenWebText Corpus') 101 | if os.path.exists("./datasets/owt") == False: 102 | owt = datasets.load_dataset('openwebtext', cache_dir='./datasets')['train'] 103 | owt.save_to_disk('./datasets/owt') 104 | else: 105 | owt = datasets.load_from_disk('./datasets/owt') 106 | print('load/create data from OpenWebText Corpus for ELECTRA') 107 | 108 | pretrain_path = args.pretrain_model+args.model+"/"+args.size 109 | 110 | if os.path.exists(pretrain_path) == False: 111 | os.makedirs(pretrain_path) 112 | print("load Electra Tokenizer Fast from hub...") 113 | hf_tokenizer = ElectraTokenizerFast.from_pretrained(f"google/electra-{args.size}-discriminator") 114 | hf_tokenizer.save_pretrained(pretrain_path) 115 | else: 116 | print("load Electra Tokenizer Fast from pretrain_path...") 117 | hf_tokenizer = ElectraTokenizerFast.from_pretrained(pretrain_path) 118 | 119 | ELECTRAProcessor = partial(ELECTRADataProcessor, hf_tokenizer=hf_tokenizer, max_length=args.max_len) 120 | dsets = [] 121 | if args.dataset=='wiki' or args.dataset=='merg': 122 | e_wiki = ELECTRAProcessor(wiki).map(cache_file_name=f"./datasets/wiki/electra_wiki_{args.max_len}.arrow", num_proc=8) 123 | print("wiki len", len(e_wiki)) 124 | dsets.append(e_wiki) 125 | if args.dataset=='owt' or
args.dataset=='merg': 126 | e_owt = ELECTRAProcessor(owt, apply_cleaning=False).map(cache_file_name=f"./datasets/owt/electra_owt_{args.max_len}.arrow", num_proc=8) 127 | print("owt len",len(e_owt)) 128 | dsets.append(e_owt) 129 | 130 | if len(dsets)>1: 131 | merged_dsets = datasets.concatenate_datasets(dsets) 132 | print("use merge dataset len:",len(merged_dsets)) 133 | elif args.dataset=='wiki': 134 | merged_dsets=dsets[0] 135 | print("use wiki dataset len:",len(merged_dsets)) 136 | else: 137 | merged_dsets=dsets[0] 138 | print("use owt dataset len:",len(merged_dsets)) 139 | 140 | ''' 141 | The original (unmasked) sentences are kept in each example's 'input_ids' 142 | ''' 143 | #e_wiki_org = e_wiki 144 | pass 145 | 146 | ''' 147 | Mask sentiment words with probability sentimask_prob to build sentimask_ids 148 | ''' 149 | sentivetor = np.load('./sentiment_vocab/senti_vector.npy') 150 | mask_token_index = hf_tokenizer.mask_token_id 151 | special_tok_ids = hf_tokenizer.all_special_ids 152 | vocab_size=hf_tokenizer.vocab_size 153 | 154 | def get_senti_mask(example): 155 | senti_list = [] 156 | 157 | new_input_ids = torch.tensor(example['input_ids']).clone() 158 | 159 | for ids in example['input_ids']: 160 | if sentivetor[ids] == 1: 161 | senti_list.append(1) 162 | else: 163 | senti_list.append(0) 164 | 165 | senti_probability_matrix = torch.tensor(senti_list).clone() * args.sentimask_prob 166 | senti_mask = torch.bernoulli(senti_probability_matrix).bool() 167 | new_input_ids[senti_mask] = mask_token_index 168 | example['sentimask_ids'] = new_input_ids 169 | 170 | return example 171 | print("senti mask map in "+args.dataset) 172 | if args.dataset=="wiki": 173 | merged_dsets = merged_dsets.map(get_senti_mask, 174 | cache_file_name = f"./datasets/wiki/sentence/sentimask/electra_wiki_{args.sentimask_prob}_{args.max_len}_sentimask_map.arrow",num_proc=16) 175 | elif args.dataset=="owt": 176 | merged_dsets = merged_dsets.map(get_senti_mask, 177 | cache_file_name = f"./datasets/owt/sentence/sentimask/electra_owt_{args.sentimask_prob}_{args.max_len}_sentimask_map.arrow",num_proc=16) 178 | else: 179 | merged_dsets = merged_dsets.map(get_senti_mask, 180 | cache_file_name = f"./datasets/merged_dsets/sentence/sentimask/electra_merg_{args.sentimask_prob}_{args.max_len}_sentimask_map.arrow",num_proc=16) 181 | 182 | 183 | 184 | ''' 185 | Pad e_wiki_org / e_wiki_sentimask and build attention_mask and token_type_ids 186 | ''' 187 | def get_org_sentimask_pad_mask_and_token_type(example): 188 | 189 | # PAD sentence 190 | if len(example['input_ids']) < args.max_len: 191 | example['ori_input_ids'] = example['input_ids'] + [hf_tokenizer.pad_token_id] * (args.max_len - len(example['input_ids'])) 192 | else: 193 | example['ori_input_ids'] = example['input_ids'] 194 | 195 | if len(example['sentimask_ids']) < args.max_len: 196 | example['sentimask_input_ids'] = example['sentimask_ids'] + [hf_tokenizer.pad_token_id] * (args.max_len - len(example['sentimask_ids'])) 197 | else: 198 | example['sentimask_input_ids'] = example['sentimask_ids'] 199 | 200 | attention_mask = torch.tensor(example['ori_input_ids']) != hf_tokenizer.pad_token_id 201 | 202 | # sentence A (token_type_ids =0) sentence B (token_type_ids =1) 203 | token_type_ids = torch.tensor([0]*example['sentA_length'] + [1]*(args.max_len-example['sentA_length'])) 204 | example['token_type_ids'] = token_type_ids 205 | example['attention_mask'] = attention_mask 206 | return example 207 | 208 | print("pad mask map in "+args.dataset) 209 | if args.dataset=="wiki": 210 | merged_dsets =
merged_dsets.map(get_org_sentimask_pad_mask_and_token_type, 211 | cache_file_name = f"./datasets/wiki/sentence/padmask/electra_wiki_{args.sentimask_prob}_{args.max_len}_padmask_map.arrow",num_proc=16) 212 | elif args.dataset=="owt": 213 | merged_dsets = merged_dsets.map(get_org_sentimask_pad_mask_and_token_type, 214 | cache_file_name = f"./datasets/owt/sentence/padmask/electra_owt_{args.sentimask_prob}_{args.max_len}_padmask_map.arrow",num_proc=16) 215 | else: 216 | merged_dsets = merged_dsets.map(get_org_sentimask_pad_mask_and_token_type, 217 | cache_file_name = f"./datasets/merged_dsets/sentence/padmask/electra_merg_{args.sentimask_prob}_{args.max_len}_padmask_map.arrow",num_proc=16) 218 | 219 | elif args.use_jsondata == True: 220 | print("use json data file in " + args.jsondata_path) 221 | merged_dsets = datasets.dataset_dict.DatasetDict.from_json(args.jsondata_path) 222 | else: 223 | sys.exit() 224 | # Supervised learning generates negative samples 225 | if args.train_type == 'sup': 226 | senti_vocab = np.load('./sentiment_vocab/senti_vocab.npy') 227 | ''' 228 | Build rep_input_ids: masked positions are filled either with random tokens or by a generator 229 | ''' 230 | if args.Negative_type == 'random': 231 | 232 | def get_replace_word(example): 233 | 234 | rep_input_ids = [] 235 | input_ids = example['sentimask_input_ids'][:] 236 | 237 | for ids in input_ids: 238 | if ids == mask_token_index: 239 | if args.random_type == 'allvocab': 240 | rep_input_ids.append(random.randint(0,vocab_size-1)) 241 | elif args.random_type == 'sentivocab': 242 | rep_input_ids.append(senti_vocab[random.randint(0,len(senti_vocab)-1)]) 243 | else: 244 | rep_input_ids.append(ids) 245 | 246 | # PAD 247 | if len(rep_input_ids) < args.max_len: 248 | example['rep_input_ids'] = rep_input_ids + [hf_tokenizer.pad_token_id] * (args.max_len - len(rep_input_ids)) 249 | else: 250 | example['rep_input_ids'] = rep_input_ids 251 | 252 | return example 253 | print("replace with random tokens") 254 | print("replace map in " + args.dataset) 255 | if args.dataset == "wiki": 256 | merged_dsets = merged_dsets.map(get_replace_word, 257 | cache_file_name=f"./datasets/wiki/sentence/repword/electra_wiki_{args.sentimask_prob}_{args.max_len}_{args.random_type}_repword_map.arrow", 258 | num_proc=16) 259 | elif args.dataset == "owt": 260 | merged_dsets = merged_dsets.map(get_replace_word, 261 | cache_file_name=f"./datasets/owt/sentence/repword/electra_owt_{args.sentimask_prob}_{args.max_len}_{args.random_type}_repword_map.arrow", 262 | num_proc=16) 263 | else: 264 | merged_dsets = merged_dsets.map(get_replace_word, 265 | cache_file_name=f"./datasets/merged_dsets/sentence/repword/electra_merg_{args.sentimask_prob}_{args.max_len}_{args.random_type}_repword_map.arrow", 266 | num_proc=16) 267 | merged_dsets.set_format(type='torch', columns=['ori_input_ids', 'sentimask_input_ids','rep_input_ids','token_type_ids', 'attention_mask']) 268 | 269 | elif args.Negative_type == 'generate': 270 | print("replace with generator-sampled tokens") 271 | merged_dsets.set_format(type='torch', columns=['ori_input_ids', 'sentimask_input_ids','token_type_ids', 'attention_mask']) 272 | 273 | else: 274 | print("no negative type named " + args.Negative_type) 275 | sys.exit() 276 | 277 | elif args.train_type == 'unsup': 278 | 279 | merged_dsets.set_format(type='torch', columns=['ori_input_ids', 'sentimask_input_ids', 'attention_mask']) 280 | 281 | else: 282 | print("no simcse train type named " + args.train_type) 283 | sys.exit() 284 | 285 | 286 | class SimcseModel(nn.Module): 287 | 288 | def __init__(self,
pretrained_model, model_type='unsup', pooling='cls'): 289 | super(SimcseModel, self).__init__() 290 | self.DModel = pretrained_model 291 | self.pooling = pooling 292 | self.model_type = model_type 293 | 294 | ''' 295 | ori_input_ids, sentimask_input_ids and rep_input_ids all have shape 296 | (batch_size, max_seq_len) 297 | ''' 298 | 299 | def forward(self, ori_input_ids, pos_input_ids, attention_mask, token_type_ids=None, neg_input_ids=None): 300 | 301 | if self.model_type == 'sup': 302 | # interleave the three views: rows become [a0, p0, n0, a1, p1, n1, ...] 303 | input_ids = torch.stack([ori_input_ids, pos_input_ids, neg_input_ids], dim=1) 304 | input_ids = input_ids.view(args.batch_size * 3, -1).to(device) 305 | 306 | # attention_mask and token_type_ids copy 3 times 307 | attention_mask = torch.stack([attention_mask, attention_mask, attention_mask], dim=1) 308 | attention_mask = attention_mask.view(args.batch_size * 3, -1).to(device) 309 | # token_type_ids = torch.stack([token_type_ids, token_type_ids, token_type_ids],dim=1) 310 | # token_type_ids = token_type_ids.view(args.batch_size * 3, -1).to(device) 311 | 312 | elif self.model_type == 'unsup': 313 | # interleave the two views: rows become [a0, p0, a1, p1, ...] 314 | input_ids = torch.stack([ori_input_ids, pos_input_ids], dim=1) 315 | input_ids = input_ids.view(args.batch_size * 2, -1).to(device) 316 | 317 | # attention_mask and token_type_ids copy 2 times 318 | attention_mask = torch.stack([attention_mask, attention_mask], dim=1) 319 | attention_mask = attention_mask.view(args.batch_size * 2, -1).to(device) 320 | # token_type_ids = torch.stack([token_type_ids, token_type_ids],dim=1) 321 | # token_type_ids = token_type_ids.view(args.batch_size * 2, -1).to(device) 322 | 323 | out = self.DModel(input_ids, attention_mask) 324 | 325 | if self.pooling == 'cls': 326 | return out.last_hidden_state[:, 0] # [3 * batch_size, hidden_size] 327 | 328 | if self.pooling == 'last-avg': 329 | last = out.last_hidden_state.transpose(1, 2) # [3 * batch_size, hidden_size, seq_len] 330 | return torch.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(-1) # [3 * batch_size, hidden_size] 331 | 332 | if self.pooling == 'first-last-avg': 333 | first = out.hidden_states[1].transpose(1, 2) # [3 * batch_size, hidden_size, seq_len] 334 | last = out.hidden_states[-1].transpose(1, 2) # [3 * batch_size, hidden_size, seq_len] 335 | first_avg = torch.avg_pool1d(first, kernel_size=last.shape[-1]).squeeze(-1) # [3 * batch_size, hidden_size] 336 | last_avg = torch.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(-1) # [3 * batch_size, hidden_size] 337 | avg = torch.cat((first_avg.unsqueeze(1), last_avg.unsqueeze(1)), dim=1) # [3 * batch_size,2, hidden_size] 338 | return torch.avg_pool1d(avg.transpose(1, 2), kernel_size=2).squeeze(-1) # [3 * batch_size, hidden_size] 339 | 340 | def save_DModel(self, output_dir): 341 | if os.path.exists(output_dir) == False: 342 | os.makedirs(output_dir) 343 | model_to_save = self.DModel 344 | output_model_file = os.path.join(output_dir, WEIGHTS_NAME) 345 | output_config_file = os.path.join(output_dir, CONFIG_NAME) 346 | torch.save(model_to_save.state_dict(), output_model_file) 347 | model_to_save.config.to_json_file(output_config_file) 348 | print("save model in :" + output_dir) 349 | 350 | 351 | # load models 352 | if args.Negative_type == 'generate' and args.train_type == 'sup': 353 | rep_generator = ElectraForMaskedLM.from_pretrained(config.generator_path) 354 | rep_generator.generator_lm_head.weight = rep_generator.electra.embeddings.word_embeddings.weight 355 | 356 | DModel = ElectraModel.from_pretrained(config.model_path) 357 | SimCSEModel =
SimcseModel(pretrained_model=DModel, model_type = args.train_type,pooling=config.pooling) 358 | 359 | 360 | # simcse supervised loss function 361 | def simcse_sup_loss(y_pred, temp=0.05): 362 | """Supervised loss function 363 | y_pred (tensor): electra output, [batch_size * 3, hidden_size] 364 | 365 | """ 366 | y_true = torch.arange(y_pred.shape[0], device=device) 367 | use_row = torch.where((y_true + 1) % 3 != 0)[0] # keep anchor and positive rows, drop negative rows 368 | y_true = (use_row - use_row % 3 * 2) + 1 # anchor row 3k -> label 3k+1 (its positive); positive row 3k+1 -> label 3k (its anchor) 369 | sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1) 370 | sim = sim - torch.eye(y_pred.shape[0], device=device) * 1e12 # mask self-similarity on the diagonal 371 | sim = torch.index_select(sim, 0, use_row) 372 | sim = sim / temp 373 | loss = F.cross_entropy(sim, y_true) 374 | return loss 375 | 376 | 377 | # simcse unsupervised loss function 378 | def simcse_unsup_loss(y_pred, temp=0.05): 379 | """Unsupervised loss function 380 | y_pred (tensor): electra output, [batch_size * 2, hidden_size] 381 | 382 | """ 383 | y_true = torch.arange(y_pred.shape[0], device=device) 384 | y_true = (y_true - y_true % 2 * 2) + 1 # rows 2k and 2k+1 label each other 385 | sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1) 386 | sim = sim - torch.eye(y_pred.shape[0], device=device) * 1e12 # mask self-similarity on the diagonal 387 | sim = sim / temp 388 | loss = F.cross_entropy(sim, y_true) 389 | return loss 390 | 391 | 392 | def train(model, dataset, max_epoch, generator = None): 393 | 394 | 395 | optimizer = torch.optim.AdamW(params=model.parameters(), lr=config.lr, eps=1e-8, betas=(0.9,0.999), weight_decay=0.01) 396 | 397 | total_steps = (dataset.num_rows // args.batch_size) * config.max_epoch 398 | 399 | scheduler = get_linear_schedule_with_warmup(optimizer, 400 | num_warmup_steps = 1500, # Default value in run_glue.py 401 | num_training_steps = total_steps) 402 | model.train() 403 | step=0 404 | for epoch in range(max_epoch): 405 | 406 | if args.gpu_num > 1: 407 | sampler = DistributedSampler(dataset) 408 | train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size,sampler=sampler) 409 | else: 410 | train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size) 411 | if args.gpu_num > 1: 412 | sampler.set_epoch(epoch) 413 | total_train_loss = 0 414 | 415 | for iter_num, batch in enumerate(tqdm(train_loader)): 416 | step += 1 417 | batch = {k: v.to(device) for k, v in batch.items()} 418 | if args.train_type == 'sup': 419 | if args.Negative_type == 'generate': 420 | print("generating negative samples with the generator...") 421 | inputs = batch['sentimask_input_ids'].clone().to(device) 422 | attention_mask = batch['attention_mask'].clone().to(device) 423 | token_type_ids = batch['token_type_ids'].clone().to(device) 424 | is_sentimask_applied = batch['sentimask_input_ids'] == mask_token_index 425 | is_sentimask_applied = is_sentimask_applied.to(device) 426 | gen_logits = generator(inputs, attention_mask, token_type_ids)[0] 427 | sentimask_gen_logits = gen_logits[is_sentimask_applied, :] 428 | pred_toks = torch.multinomial(F.softmax(sentimask_gen_logits, dim=-1), 1).squeeze() 429 | rep_input_ids = inputs.clone() 430 | rep_input_ids[is_sentimask_applied] = pred_toks 431 | rep_input_ids = rep_input_ids.to(device) 432 | outputs = model.forward(ori_input_ids = batch['sentimask_input_ids'], 433 | pos_input_ids = batch['ori_input_ids'], 434 | neg_input_ids = rep_input_ids, 435 | attention_mask = batch['attention_mask']) 436 | else: 437 | outputs = model.forward(ori_input_ids = batch['sentimask_input_ids'], 438 | pos_input_ids = batch['ori_input_ids'], 439 | neg_input_ids = batch['rep_input_ids'], 440 | attention_mask =
batch['attention_mask']) 441 | loss = simcse_sup_loss(outputs, temp=config.temperature) 442 | elif args.train_type == 'unsup': 443 | outputs = model.forward(ori_input_ids=batch['sentimask_input_ids'], 444 | pos_input_ids=batch['ori_input_ids'], 445 | attention_mask=batch['attention_mask']) 446 | 447 | loss = simcse_unsup_loss(outputs, temp=config.temperature) 448 | 449 | else: 450 | print("no simcse train type named " + args.train_type) 451 | sys.exit() 452 | total_train_loss += loss.item() 453 | loss.backward() 454 | # gradient clipping 455 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm = 1.0) 456 | optimizer.step() 457 | scheduler.step() 458 | optimizer.zero_grad() 459 | 460 | if step % 20 == 0: 461 | #print(batch['labels']) 462 | #print(torch.argmax(outputs.logits, 1)) 463 | print(">>> epoch: %d, step: %d, loss: %.4f" % (epoch, step, loss.item())) 464 | if step == 200: 465 | if dist.get_rank() == 0: 466 | if args.save_model: 467 | output_dir = args.save_pretrain_model 468 | save_model(model, output_dir, step) 469 | if step % 1000 == 0 and step > 0: 470 | if dist.get_rank() == 0: 471 | if args.save_model: 472 | output_dir = args.save_pretrain_model 473 | save_model(model, output_dir, step) 474 | if step == 10000: 475 | sys.exit() 476 | 477 | if dist.get_rank() == 0: 478 | if args.save_model: 479 | output_dir = args.save_pretrain_model 480 | save_model(model, output_dir, epoch + 1) 481 | 482 | print("Epoch: %d, Average training loss: %.4f" %(epoch, total_train_loss/len(train_loader))) 483 | 484 | def save_model(model,output_dir,flag): 485 | if flag>100: 486 | save_dis_path=output_dir+str(flag)+'iter_discriminator' 487 | else: 488 | save_dis_path = output_dir + str(flag) + 'epoch_discriminator' 489 | if os.path.exists(save_dis_path) == False: 490 | os.makedirs(save_dis_path) 491 | if args.gpu_num > 1: 492 | model.module.save_DModel(save_dis_path) 493 | else: 494 | model.save_DModel(save_dis_path) 495 | SimCSEModel.to(device) 496 | if args.Negative_type == 'generate' and args.train_type == 'sup': 497 | rep_generator.to(device) 498 | 499 | if args.gpu_num > 1: 500 | if torch.cuda.device_count() > 1: 501 | print("Let's use", args.gpu_num, "GPUs!") 502 | SimCSEModel = torch.nn.parallel.DistributedDataParallel(SimCSEModel, 503 | device_ids=[local_rank], 504 | output_device=local_rank) 505 | if args.Negative_type == 'generate' and args.train_type == 'sup': 506 | rep_generator = torch.nn.parallel.DistributedDataParallel(rep_generator, 507 | device_ids=[local_rank], 508 | output_device=local_rank) 509 | else: 510 | print("There are not enough GPUs available!") 511 | sys.exit() 512 | 513 | if args.Negative_type == 'generate' and args.train_type == 'sup': 514 | train(SimCSEModel, merged_dsets, config.max_epoch, rep_generator) 515 | else: 516 | train(SimCSEModel, merged_dsets, config.max_epoch) 517 | -------------------------------------------------------------------------------- /pretrain/SentiWSP_Pretrain_Word.py: -------------------------------------------------------------------------------- 1 | import os, sys, random 2 | from pathlib import Path 3 | from datetime import datetime, timezone, timedelta 4 | import numpy as np 5 | import torch 6 | import torch.distributed as dist 7 | from torch.utils.data.distributed import DistributedSampler 8 | from torch import nn 9 | import torch.nn.functional as F 10 | import datasets 11 | from transformers import ElectraConfig, ElectraTokenizerFast, ElectraForMaskedLM, ElectraForPreTraining, get_linear_schedule_with_warmup 12 | from _utils.utils import * 13 | from
-------------------------------------------------------------------------------- /pretrain/SentiWSP_Pretrain_Word.py: --------------------------------------------------------------------------------
1 | import os, sys, random
2 | from pathlib import Path
3 | from datetime import datetime, timezone, timedelta
4 | import numpy as np
5 | import torch
6 | import torch.distributed as dist
7 | from torch.utils.data.distributed import DistributedSampler
8 | from torch import nn
9 | import torch.nn.functional as F
10 | import datasets
11 | from transformers import ElectraConfig, ElectraTokenizerFast, ElectraForMaskedLM, ElectraForPreTraining, get_linear_schedule_with_warmup
12 | from _utils.utils import *
13 | from _utils.would_like_to_pr import *
14 | from tqdm import tqdm
15 | from transformers import WEIGHTS_NAME, CONFIG_NAME
16 | import pickle
17 | import argparse
18 | 
19 | parser = argparse.ArgumentParser(description='Pre-training model configuration')
20 | parser.add_argument('--model', default='electra', help='pre-train model type')
21 | parser.add_argument('--size', default='large', help='pre-train model size, chosen from [small, base, large]')
22 | parser.add_argument('--dataset', default='wiki', help='chosen from [wiki, owt, merg]')
23 | parser.add_argument('--gpu_num', type=int, default=4, help='number of GPUs for pre-training')
24 | parser.add_argument('--pretrain_model', type=str, default='./pretrain_model/', help='pre-trained model path')
25 | parser.add_argument('--save_model', type=bool, default=True, help='whether to save the model')
26 | parser.add_argument('--save_pretrain_model', type=str, default='./save_pretrain_model/', help='directory where pre-trained checkpoints are saved (created if missing)')
27 | parser.add_argument("--rank", type=int, default=-1, help="rank")
28 | parser.add_argument("--local_rank", type=int, default=-1, help="local rank")
29 | parser.add_argument("--sentimask_prob", type=float, default=0.5, help="the sentiment-word mask probability")
30 | parser.add_argument('--mask_prob', type=float, default=0.15, help='the MLM mask probability in pre-training')
31 | parser.add_argument('--batch_size', type=int, default=64, help='the batch size in pre-training')
32 | parser.add_argument('--max_len', type=int, default=128, help='the max sequence length in pre-training')
33 | args = parser.parse_args()
34 | 
35 | # global configuration
36 | config = MyConfig({
37 | 'base_run_name': 'ELECTRA',
38 | 'seed': 2022,
39 | 'electra_mask_style': True,
40 | 'config_path': './config_pretrain/',
41 | 'num_workers': 3, 'my_model': False  # my_model=False loads the released ELECTRA checkpoints below instead of training from scratch
42 | })
43 | 
44 | i = ['small', 'base', 'large'].index(args.size)
45 | config.lr = [2e-5, 1e-5, 1e-5][i]
46 | print("now learning rate:", config.lr)
47 | print("now mask prob:", args.mask_prob)
48 | config.max_epoch = 10
49 | config.single_gpunum = 0
50 | generator_size_divisor = [4, 3, 4][i]  # used when config.my_model is True to scale the generator down
51 | 
52 | # seed
53 | random.seed(config.seed)
54 | np.random.seed(config.seed)
55 | torch.manual_seed(config.seed)
56 | 
57 | # Global multi-GPU environment definition
58 | if args.gpu_num > 1:
59 | #os.environ["CUDA_VISIBLE_DEVICES"] = config.dpp_gpu_num
60 | torch.distributed.init_process_group(backend="nccl")
61 | local_rank = torch.distributed.get_rank()
62 | torch.cuda.set_device(local_rank)
63 | device = torch.device("cuda", local_rank)
64 | else:
65 | device = torch.device("cuda:" + str(config.single_gpunum) if torch.cuda.is_available() else "cpu")
66 | print("requested gpu num:", args.gpu_num)
67 | print("available gpu num:", torch.cuda.device_count())
68 | 
69 | # load data
70 | if args.dataset == 'wiki' or args.dataset == 'merg':
71 | print('load/download wiki dataset')
72 | if not os.path.exists("./datasets/wiki"):
73 | wiki = datasets.load_dataset('wikipedia', '20200501.en', cache_dir='./datasets')['train']
74 | wiki.save_to_disk("./datasets/wiki")
75 | else:
76 | wiki = datasets.load_from_disk("./datasets/wiki")
77 | print('load/create data from wiki dataset for ELECTRA')
78 | if args.dataset == 'owt' or args.dataset == 'merg':
79 | print('load/download OpenWebText Corpus')
80 | if not os.path.exists("./datasets/owt"):
81 | owt = datasets.load_dataset('openwebtext', cache_dir='./datasets')['train']
82 | owt.save_to_disk('./datasets/owt')
83 | else:
84 | owt = datasets.load_from_disk('./datasets/owt')
85 | print('load/create data from OpenWebText Corpus for ELECTRA')
86 | 
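`MyConfig` comes from `_utils/utils.py`, which is not part of this excerpt. The way it is used here (`config = MyConfig({...})` followed by attribute reads and writes such as `config.lr` and `config.max_epoch`) implies a dict with attribute-style access; a minimal sketch of such a wrapper, assuming only the behavior this script exercises:

```python
class MyConfig(dict):
    """Dict subclass with attribute-style access (illustrative sketch)."""
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as e:
            raise AttributeError(name) from e
    def __setattr__(self, name, value):
        self[name] = value
```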
87 | pretrain_path = args.pretrain_model + args.model + "/" + args.size
88 | 
89 | if not os.path.exists(pretrain_path):
90 | os.makedirs(pretrain_path)
91 | print("load ElectraTokenizerFast from hub...")
92 | hf_tokenizer = ElectraTokenizerFast.from_pretrained(f"google/electra-{args.size}-discriminator")
93 | hf_tokenizer.save_pretrained(pretrain_path)
94 | else:
95 | print("load ElectraTokenizerFast from pretrain_path...")
96 | hf_tokenizer = ElectraTokenizerFast.from_pretrained(pretrain_path)
97 | 
98 | ELECTRAProcessor = partial(ELECTRADataProcessor, hf_tokenizer=hf_tokenizer, max_length=args.max_len)
99 | dsets = []
100 | if args.dataset == 'wiki' or args.dataset == 'merg':
101 | e_wiki = ELECTRAProcessor(wiki).map(cache_file_name=f"./datasets/wiki/electra_wiki_{args.max_len}.arrow", num_proc=8)
102 | print("wiki len", len(e_wiki))
103 | dsets.append(e_wiki)
104 | if args.dataset == 'owt' or args.dataset == 'merg':
105 | e_owt = ELECTRAProcessor(owt, apply_cleaning=False).map(cache_file_name=f"./datasets/owt/electra_owt_{args.max_len}.arrow", num_proc=8)
106 | print("owt len", len(e_owt))
107 | dsets.append(e_owt)
108 | 
109 | if len(dsets) > 1:
110 | merged_dsets = datasets.concatenate_datasets(dsets)
111 | print("use merged dataset len:", len(merged_dsets))
112 | elif args.dataset == 'wiki':
113 | merged_dsets = dsets[0]
114 | print("use wiki dataset len:", len(merged_dsets))
115 | else:
116 | merged_dsets = dsets[0]
117 | print("use owt dataset len:", len(merged_dsets))
118 | 
119 | def get_pad_mask_and_token_type(example):
120 | 
121 | # pad the sentence to max_len
122 | if len(example['input_ids']) < args.max_len:
123 | example['new_input_ids'] = example['input_ids'] + [hf_tokenizer.pad_token_id] * (args.max_len - len(example['input_ids']))
124 | else:
125 | example['new_input_ids'] = example['input_ids']
126 | 
127 | attention_mask = torch.tensor(example['new_input_ids']) != hf_tokenizer.pad_token_id
128 | 
129 | # sentence A (token_type_ids = 0), sentence B (token_type_ids = 1)
130 | token_type_ids = torch.tensor([0] * example['sentA_length'] + [1] * (args.max_len - example['sentA_length']))
131 | example['token_type_ids'] = token_type_ids
132 | example['attention_mask'] = attention_mask
133 | return example
134 | 
135 | print("pad mask map in " + args.dataset)
136 | if args.dataset == "wiki":
137 | merged_dsets = merged_dsets.map(get_pad_mask_and_token_type,
138 | cache_file_name=f"./datasets/wiki/padmask/electra_wiki_{args.max_len}_padmask_map.arrow", num_proc=16)
139 | elif args.dataset == "owt":
140 | merged_dsets = merged_dsets.map(get_pad_mask_and_token_type,
141 | cache_file_name=f"./datasets/owt/padmask/electra_owt_{args.max_len}_padmask_map.arrow", num_proc=16)
142 | else:
143 | merged_dsets = merged_dsets.map(get_pad_mask_and_token_type,
144 | cache_file_name=f"./datasets/merged_dsets/padmask/electra_merg_{args.max_len}_padmask_map.arrow", num_proc=16)
145 | 
146 | senti_vector = np.load('./sentiment_vocab/senti_vector.npy')
147 | 
148 | def get_senti_type(example):
149 | senti_list = []
150 | for ids in example['new_input_ids']:
151 | if senti_vector[ids] == 1:  # token id belongs to the sentiment vocabulary
152 | senti_list.append(1)
153 | else:
154 | senti_list.append(0)
155 | example['senti_type'] = senti_list
156 | return example
157 | 
158 | 
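As used in `get_senti_type` above, `senti_vector.npy` (checked in under `sentiment_vocab/`) acts as a vocabulary-length 0/1 lookup table: entry `i` is 1 exactly when token id `i` belongs to the sentiment vocabulary. A quick way to inspect it (a sketch; it assumes the array is aligned with the ELECTRA tokenizer's vocabulary, as the script implies):

```python
import numpy as np
from transformers import ElectraTokenizerFast

senti_vector = np.load('./sentiment_vocab/senti_vector.npy')
tokenizer = ElectraTokenizerFast.from_pretrained('google/electra-large-discriminator')

ids = tokenizer('a truly wonderful movie', add_special_tokens=False)['input_ids']
for token, token_id in zip(tokenizer.convert_ids_to_tokens(ids), ids):
    # 1 -> sentiment word, 0 -> ordinary token
    print(token, int(senti_vector[token_id]))
```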
f"./datasets/owt/sentimask/electra_owt_{args.max_length}_sentimask_map.arrow",num_proc=16) 165 | else: 166 | merged_dsets = merged_dsets.map(get_senti_type, 167 | cache_file_name = f"./datasets/merged_dsets/sentimask/electra_merg_{args.max_length}_sentimask_map.arrow",num_proc=16) 168 | 169 | # mask config 170 | mlm_probability=args.mask_prob 171 | ignore_index=-100 172 | mask_token_index = hf_tokenizer.mask_token_id 173 | special_tok_ids = hf_tokenizer.all_special_ids 174 | vocab_size=hf_tokenizer.vocab_size 175 | replace_prob=0.0 if config.electra_mask_style else 0.1 176 | orginal_prob=0.15 if config.electra_mask_style else 0.1 177 | sentimask_prob = args.sentimask_prob 178 | 179 | def mask_tokens_map(example): 180 | #introduce sentiment mask process 181 | inputs = torch.tensor(example['new_input_ids']).clone() 182 | device = inputs.device 183 | 184 | labels = inputs.clone() 185 | 186 | probability_matrix = torch.full(labels.shape, mlm_probability, device=device) 187 | 188 | senti_probability_matrix = torch.tensor(example['senti_type']).clone() * sentimask_prob 189 | 190 | special_tokens_mask = torch.full(inputs.shape, False, dtype=torch.bool, device=device) 191 | 192 | for sp_id in special_tok_ids: 193 | special_tokens_mask = special_tokens_mask | (inputs == sp_id) 194 | 195 | probability_matrix.masked_fill_(special_tokens_mask, value=0.0) 196 | senti_probability_matrix.masked_fill_(special_tokens_mask, value=0.0) 197 | 198 | mlm_mask = torch.bernoulli(probability_matrix).bool() 199 | senti_mask = torch.bernoulli(senti_probability_matrix).bool() 200 | 201 | # merge mlm mask and senti mask 202 | mlm_mask = mlm_mask | senti_mask 203 | 204 | labels[~mlm_mask] = ignore_index 205 | 206 | mask_prob = 1 - replace_prob - orginal_prob 207 | mask_token_mask = torch.bernoulli(torch.full(labels.shape, mask_prob, device=device)).bool() & mlm_mask 208 | inputs[mask_token_mask] = mask_token_index 209 | 210 | if int(replace_prob) != 0: 211 | rep_prob = replace_prob / (replace_prob + orginal_prob) 212 | replace_token_mask = torch.bernoulli( 213 | torch.full(labels.shape, rep_prob, device=device)).bool() & mlm_mask & ~mask_token_mask 214 | random_words = torch.randint(vocab_size, labels.shape, dtype=torch.long, device=device) 215 | inputs[replace_token_mask] = random_words[replace_token_mask] 216 | 217 | pass 218 | 219 | example['masked_inputs'] = inputs 220 | example['is_mlm_applied'] = mlm_mask 221 | example['labels'] = labels 222 | return example 223 | 224 | print("mlm mask map in"+args.dataset) 225 | if args.dataset=="wiki": 226 | if args.sentimask_prob == 0: 227 | #no sentiment mask 228 | merged_dsets = merged_dsets.map(mask_tokens_map, 229 | cache_file_name = f"./datasets/wiki/mlmmask/electra_wiki_{args.max_length}_padmask_map.arrow",num_proc=16) 230 | else: 231 | merged_dsets = merged_dsets.map(mask_tokens_map, 232 | cache_file_name=f"./datasets/wiki/mlmmask/electra_{args.sentimask_prob*10}_wiki_{args.max_length}_mlmmask_map.arrow", 233 | num_proc=16) 234 | elif args.dataset=="owt": 235 | if args.sentimask_prob == 0: 236 | merged_dsets = merged_dsets.map(mask_tokens_map, 237 | cache_file_name = f"./datasets/owt/mlmmask/electra_owt_{args.max_length}_padmask_map.arrow",num_proc=16) 238 | else: 239 | merged_dsets = merged_dsets.map(mask_tokens_map, 240 | cache_file_name=f"./datasets/owt/mlmmask/electra_{args.sentimask_prob*10}_owt_{args.max_length}_mlmmask_map.arrow", 241 | num_proc=16) 242 | else: 243 | merged_dsets = merged_dsets.map(mask_tokens_map, 244 | cache_file_name = 
f"./datasets/merged_dsets/mlmmask/electra_merg_{args.max_length}_mlmmask_map.arrow",num_proc=16) 245 | 246 | merged_dsets.set_format(type='torch', columns=['masked_inputs', 'token_type_ids', 'attention_mask','is_mlm_applied','labels']) 247 | print("set format") 248 | 249 | # model config 250 | class ELECTRAModel(nn.Module): 251 | 252 | # Model initialization needs to be specified generator and discriminator 253 | def __init__(self, generator, discriminator, hf_tokenizer): 254 | super().__init__() 255 | self.generator, self.discriminator = generator,discriminator 256 | self.hf_tokenizer = hf_tokenizer 257 | 258 | def forward(self, masked_inputs, is_mlm_applied, labels, attention_mask, token_type_ids): 259 | """ 260 | masked_inputs (Tensor[int]): (batch_size, max_seq_len) 261 | sentA_lenths (Tensor[int]): (batch_size) 262 | is_mlm_applied (Tensor[boolean]): (batch_size, max_seq_len), 值为True代表改位置被MASK 263 | labels (Tensor[int]): (batch_size, max_seq_len), -100 for positions where are not mlm applied 264 | """ 265 | 266 | # gen_logits形状 (batch_size, max_seq_len, vocab_size) 267 | gen_logits = self.generator(masked_inputs, attention_mask, token_type_ids)[0] 268 | # reduce size to save space and speed 269 | # mlm_gen_logits :(Mask_num, vocab_size) 270 | mlm_gen_logits = gen_logits[is_mlm_applied, :] # ( #mlm_positions, vocab_size) 271 | 272 | with torch.no_grad(): 273 | # sampling 274 | pred_toks = torch.multinomial(F.softmax(mlm_gen_logits, dim=-1), 1).squeeze() 275 | 276 | # Fill the predicted token into the input originally masked as the input of the discriminator 277 | generated = masked_inputs.clone() # (B,L) 278 | generated[is_mlm_applied] = pred_toks # (B,L) 279 | 280 | # produce labels for discriminator 281 | is_replaced = is_mlm_applied.clone() # (B,L) 282 | is_replaced[is_mlm_applied] = (pred_toks != labels[is_mlm_applied]) # (B,L) 283 | 284 | disc_logits = self.discriminator(generated, attention_mask, token_type_ids)[0] # (B, L) 285 | 286 | return mlm_gen_logits, generated, disc_logits, is_replaced, attention_mask, is_mlm_applied 287 | 288 | def save_discriminator(self, output_dir): 289 | if os.path.exists(output_dir) == False: 290 | os.makedirs(output_dir) 291 | model_to_save = self.discriminator 292 | output_model_file = os.path.join(output_dir, WEIGHTS_NAME) 293 | output_config_file = os.path.join(output_dir, CONFIG_NAME) 294 | torch.save(model_to_save.state_dict(), output_model_file) 295 | model_to_save.config.to_json_file(output_config_file) 296 | print("save model in :" + output_dir) 297 | 298 | def save_generator(self, output_dir): 299 | if os.path.exists(output_dir) == False: 300 | os.makedirs(output_dir) 301 | model_to_save = self.generator 302 | output_model_file = os.path.join(output_dir, WEIGHTS_NAME) 303 | output_config_file = os.path.join(output_dir, CONFIG_NAME) 304 | torch.save(model_to_save.state_dict(), output_model_file) 305 | model_to_save.config.to_json_file(output_config_file) 306 | print("save model in :" + output_dir) 307 | 308 | if config.my_model: 309 | # 训练配置 310 | config_path=config.config_path 311 | 312 | if os.path.exists(config_path+args.size+'disc') == False: 313 | print("load ele config in hub...") 314 | disc_config = ElectraConfig.from_pretrained(f'google/electra-{args.size}-discriminator') 315 | gen_config = ElectraConfig.from_pretrained(f'google/electra-{args.size}-generator') 316 | os.makedirs(config_path + args.size+'disc') 317 | os.makedirs(config_path + args.size+'gen') 318 | disc_config.save_pretrained(config_path + args.size+'disc') 319 | 
308 | if config.my_model:
309 | # training configuration: build generator/discriminator from config
310 | config_path = config.config_path
311 | 
312 | if not os.path.exists(config_path + args.size + 'disc'):
313 | print("load ELECTRA config from hub...")
314 | disc_config = ElectraConfig.from_pretrained(f'google/electra-{args.size}-discriminator')
315 | gen_config = ElectraConfig.from_pretrained(f'google/electra-{args.size}-generator')
316 | os.makedirs(config_path + args.size + 'disc')
317 | os.makedirs(config_path + args.size + 'gen')
318 | disc_config.save_pretrained(config_path + args.size + 'disc')
319 | gen_config.save_pretrained(config_path + args.size + 'gen')
320 | else:
321 | print("load ELECTRA config from " + config_path)
322 | disc_config = ElectraConfig.from_pretrained(config_path + args.size + 'disc')
323 | gen_config = ElectraConfig.from_pretrained(config_path + args.size + 'gen')
324 | 
325 | # note that the public electra-small model is actually small++ and doesn't scale down the generator size
326 | gen_config.hidden_size = int(disc_config.hidden_size / generator_size_divisor)
327 | gen_config.num_attention_heads = disc_config.num_attention_heads // generator_size_divisor
328 | gen_config.intermediate_size = disc_config.intermediate_size // generator_size_divisor
329 | 
330 | generator = ElectraForMaskedLM(gen_config)
331 | discriminator = ElectraForPreTraining(disc_config)
332 | else:
333 | print("use benchmark generator and discriminator")
334 | gen_path = args.pretrain_model + args.model + '/' + args.size + '/gen'
335 | disc_path = args.pretrain_model + args.model + '/' + args.size + '/disc'
336 | if not os.path.exists(gen_path):
337 | os.makedirs(gen_path)
338 | os.makedirs(disc_path)
339 | generator = ElectraForMaskedLM.from_pretrained(f'google/electra-{args.size}-generator')
340 | discriminator = ElectraForPreTraining.from_pretrained(f'google/electra-{args.size}-discriminator')
341 | generator.save_pretrained(gen_path)
342 | discriminator.save_pretrained(disc_path)
343 | else:
344 | print("load generator and discriminator from " + args.pretrain_model)
345 | generator = ElectraForMaskedLM.from_pretrained(gen_path)
346 | discriminator = ElectraForPreTraining.from_pretrained(disc_path)
347 | 
348 | discriminator.electra.embeddings = generator.electra.embeddings  # share embeddings between generator and discriminator
349 | generator.generator_lm_head.weight = generator.electra.embeddings.word_embeddings.weight  # tie the LM head to the input embeddings
350 | 
351 | electra_model = ELECTRAModel(generator, discriminator, hf_tokenizer)
352 | 
353 | def ELECTRALoss(pred, targ_ids, loss_weights=(1.0, 50.0), gen_label_smooth=False, disc_label_smooth=False):
354 | gen_loss_fc = nn.CrossEntropyLoss()
355 | disc_loss_fc = nn.BCEWithLogitsLoss()
356 | 
357 | mlm_gen_logits, generated, disc_logits, is_replaced, non_pad, is_mlm_applied = pred
358 | gen_loss = gen_loss_fc(mlm_gen_logits.float(), targ_ids[is_mlm_applied])
359 | disc_logits = disc_logits.masked_select(non_pad)  # -> 1d tensor
360 | is_replaced = is_replaced.masked_select(non_pad)  # -> 1d tensor
361 | if disc_label_smooth:
362 | is_replaced = is_replaced.float().masked_fill(~is_replaced, disc_label_smooth)
363 | disc_loss = disc_loss_fc(disc_logits.float(), is_replaced.float())
364 | 
365 | 
366 | return gen_loss * loss_weights[0] + disc_loss * loss_weights[1], gen_loss, disc_loss
367 | 
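The `(1.0, 50.0)` weights follow the ELECTRA recipe: the discriminator's replaced-token-detection loss is a per-token binary cross-entropy and is numerically much smaller than the generator's MLM cross-entropy, so it is up-weighted by 50 to keep both signals comparable. For rough intuition (illustrative numbers, not measured values):

```python
import math

# At initialization the MLM cross-entropy is near ln(vocab_size) and the
# binary RTD loss near ln(2); the 50x weight brings them to a similar scale.
gen_loss = math.log(30522)   # ~10.33 for ELECTRA's ~30k vocab
disc_loss = math.log(2)      # ~0.69
total = 1.0 * gen_loss + 50.0 * disc_loss
print(round(gen_loss, 2), round(50.0 * disc_loss, 2), round(total, 2))
```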
368 | def train(model, dataset, max_epoch):
369 | 
370 | if args.gpu_num > 1:
371 | sampler = DistributedSampler(dataset)
372 | train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, sampler=sampler)
373 | else:
374 | train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size)
375 | 
376 | optimizer = torch.optim.AdamW(params=model.parameters(), lr=config.lr, eps=1e-8, betas=(0.9, 0.999), weight_decay=0.01)
377 | 
378 | total_steps = (dataset.num_rows // args.batch_size) * max_epoch
379 | 
380 | scheduler = get_linear_schedule_with_warmup(optimizer,
381 | num_warmup_steps=1500,  # default value in run_glue.py
382 | num_training_steps=total_steps)
383 | model.train()
384 | step = 0
385 | gen_loss_list, disc_loss_list = [], []
386 | for epoch in range(max_epoch):
387 | 
388 | # optionally reshuffle the dataset each epoch
389 | # dataset = dataset.shuffle()
390 | if args.gpu_num > 1:
391 | sampler.set_epoch(epoch)
392 | total_train_loss = 0
393 | for iter_num, batch in enumerate(tqdm(train_loader)):
394 | 
395 | step += 1
396 | batch = {k: v.to(device) for k, v in batch.items()}
397 | outputs = model(**batch)
398 | loss, gen_loss, disc_loss = ELECTRALoss(outputs, batch['labels'])
399 | total_train_loss += loss.item()
400 | 
401 | loss.backward()
402 | # gradient clipping
403 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
404 | optimizer.step()
405 | scheduler.step()
406 | optimizer.zero_grad()
407 | 
408 | if step % 200 == 0:
409 | if args.gpu_num == 1 or dist.get_rank() == 0:
410 | gen_loss_list.append(gen_loss.item())
411 | disc_loss_list.append(disc_loss.item())
412 | print(">>> epoch: %d, step: %d, loss: %.4f" % (epoch, step, loss.item()))
413 | 
414 | if step == 200:
415 | if args.gpu_num == 1 or dist.get_rank() == 0:
416 | if args.save_model:
417 | output_dir = args.save_pretrain_model
418 | save_model(model, output_dir, step)
419 | if step % 1000 == 0 and step > 0:
420 | if args.gpu_num == 1 or dist.get_rank() == 0:
421 | if args.save_model:
422 | output_dir = args.save_pretrain_model
423 | save_model(model, output_dir, step)
424 | if step == 5000:
425 | if args.gpu_num == 1 or dist.get_rank() == 0:
426 | with open('./loss_word.pkl', 'wb') as f:
427 | pickle.dump(gen_loss_list, f, pickle.HIGHEST_PROTOCOL)
428 | pickle.dump(disc_loss_list, f, pickle.HIGHEST_PROTOCOL)
429 | sys.exit()
430 | 
431 | if args.gpu_num == 1 or dist.get_rank() == 0:
432 | if args.save_model:
433 | output_dir = args.save_pretrain_model
434 | save_model(model, output_dir, epoch + 1)
435 | print("Epoch: %d, Average training loss: %.4f" % (epoch, total_train_loss / len(train_loader)))
436 | 
437 | def save_model(model, output_dir, flag):
438 | if flag > 100:  # flag is a step count
439 | save_dis_path = output_dir + str(flag) + 'iter_discriminator'
440 | save_gen_path = output_dir + str(flag) + 'iter_generator'
441 | else:  # flag is an epoch count
442 | save_dis_path = output_dir + str(flag) + 'epoch_discriminator'
443 | save_gen_path = output_dir + str(flag) + 'epoch_generator'
444 | if not os.path.exists(save_dis_path):
445 | os.makedirs(save_dis_path)
446 | os.makedirs(save_gen_path)
447 | if args.gpu_num > 1:
448 | model.module.save_discriminator(save_dis_path)
449 | model.module.save_generator(save_gen_path)
450 | else:
451 | model.save_discriminator(save_dis_path)
452 | model.save_generator(save_gen_path)
453 | 
454 | electra_model.to(device)
455 | if args.gpu_num > 1:
456 | if torch.cuda.device_count() > 1:
457 | print("Let's use", args.gpu_num, "GPUs!")
458 | electra_model = torch.nn.parallel.DistributedDataParallel(electra_model,
459 | device_ids=[local_rank],
460 | output_device=local_rank)
461 | else:
462 | print("There are not enough GPUs available!")
463 | sys.exit()
464 | train(electra_model, merged_dsets, config.max_epoch)
465 | 
-------------------------------------------------------------------------------- /sentiment_vocab/senti_vector.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XMUDM/SentiWSP/3cd57e171bbce16265f1daee5f5c0a680a7bb078/sentiment_vocab/senti_vector.npy -------------------------------------------------------------------------------- /sentiment_vocab/senti_vocab.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XMUDM/SentiWSP/3cd57e171bbce16265f1daee5f5c0a680a7bb078/sentiment_vocab/senti_vocab.npy --------------------------------------------------------------------------------
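Since `save_discriminator` writes a standard `pytorch_model.bin` (`WEIGHTS_NAME`) plus `config.json` (`CONFIG_NAME`), a checkpoint directory produced by `save_model` can be reloaded directly with `from_pretrained` for the sentence-level stage or for fine-tuning. A sketch, assuming the default `--save_pretrain_model` path and a checkpoint saved at the end of epoch 1:

```python
from transformers import ElectraForPreTraining, ElectraTokenizerFast

# Example path produced by save_model(model, './save_pretrain_model/', flag=1)
ckpt_dir = './save_pretrain_model/1epoch_discriminator'
discriminator = ElectraForPreTraining.from_pretrained(ckpt_dir)
tokenizer = ElectraTokenizerFast.from_pretrained('google/electra-large-discriminator')
```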