├── README.md
├── collators
│   ├── __init__.py
│   ├── classification.py
│   ├── gpt2.py
│   ├── gpt2_contrast.py
│   ├── gpt2_eval.py
│   ├── text2text.py
│   ├── text2text_contrast.py
│   ├── text2text_dexperts.py
│   ├── text2text_labels.py
│   └── text2text_padding.py
├── eval_bad.py
├── eval_dist.py
├── eval_div.py
├── eval_mauve.py
├── eval_ppl_blender.py
├── eval_ppl_gpt2.py
├── eval_senti.py
├── generate.py
├── infer.py
├── metrics
│   └── __init__.py
├── models
│   ├── __init__.py
│   ├── blender.py
│   ├── blender_contrast_05.py
│   ├── blender_cringe_02.py
│   ├── blender_dexperts.py
│   ├── blender_director_02.py
│   ├── blender_gedi.py
│   ├── blender_unlikelihood_01.py
│   ├── distilbert.py
│   ├── gpt2.py
│   ├── gpt2_contrast_01.py
│   ├── gpt2_contrast_03.py
│   └── roberta.py
├── results_bad.py
├── results_bad_test.txt
├── results_senti.py
├── results_senti_neg.txt
├── results_senti_pos.txt
├── results_wiki.py
├── results_wiki_test.txt
├── scripts_bad
│   ├── blender
│   │   ├── augment.sh
│   │   └── generate.sh
│   ├── contrast
│   │   ├── augment.sh
│   │   ├── generate.sh
│   │   └── train.sh
│   ├── cringe
│   │   ├── generate.sh
│   │   └── train.sh
│   ├── dexperts
│   │   ├── antiexpert.sh
│   │   ├── expert.sh
│   │   └── generate.sh
│   ├── director
│   │   ├── generate.sh
│   │   └── train.sh
│   ├── eval
│   │   ├── dist.sh
│   │   ├── ppl_large.sh
│   │   ├── ppl_self.sh
│   │   ├── predict.sh
│   │   └── predict_train.sh
│   ├── ft
│   │   └── generate.sh
│   ├── gedi
│   │   └── generate.sh
│   ├── stats.sh
│   └── unlikelihood
│       ├── generate.sh
│       └── train.sh
├── scripts_cls
│   └── bad
│       ├── test.sh
│       └── train.sh
├── scripts_senti
│   ├── eval
│   │   ├── dist.sh
│   │   ├── ppl_large.sh
│   │   ├── ppl_self.sh
│   │   ├── predict.sh
│   │   └── predict_train.sh
│   ├── gpt2
│   │   ├── augment.sh
│   │   └── train.sh
│   ├── neg_contrast
│   │   ├── generate.sh
│   │   └── train.sh
│   ├── pos_contrast
│   │   ├── generate.sh
│   │   └── train.sh
│   ├── stats_neg.sh
│   └── stats_pos.sh
├── scripts_wiki
│   ├── contrast
│   │   ├── generate.sh
│   │   └── train.sh
│   ├── eval
│   │   ├── div.sh
│   │   └── mavue.sh
│   ├── gpt2
│   │   ├── augment_0.5.sh
│   │   ├── augment_0.7.sh
│   │   ├── augment_0.9.sh
│   │   └── train.sh
│   └── stats.sh
├── train.py
└── utils
    ├── __init__.py
    ├── building_utils.py
    ├── dataloader_utils.py
    ├── eval_utils.py
    └── model_utils.py

/README.md:
--------------------------------------------------------------------------------
# Click
This is the official repository for the Findings of ACL 2023 paper "[Click: Controllable Text Generation with Sequence Likelihood Contrastive Learning](https://arxiv.org/abs/2306.03350)".

## Data, Models, and Generation Results

You can download our data and trained models from [HuggingFace](https://huggingface.co/chujiezheng/click) (these files may be a bit large).

To reproduce the experiments, you need to download the `data_*` folders and put them in the project path.

## Requirements

**Note**: The listed packages may not be exhaustive. You may need to install additional packages according to the error messages.

```
torch
transformers
accelerate
```

## Detoxification Experiments

We additionally release the implementation of Click as well as the baselines (see `models`). The running scripts are in `scripts_bad`. We take Click as an example to illustrate how to run the code.

**NOTE**: `scripts_cls`, `data_cls`, and `checkpoints_cls` contain the scripts, data, and checkpoints of our trained BAD classifier used for evaluation.

```bash
# process raw data
# you may need to download the original datasets and change the data path
cd /{project_path}/data_bad/raw
python process.py

# process train/valid/test data
cd /{project_path}/data_bad/blender
python process.py

# generate multiple continuations
cd /{project_path}
bash scripts_bad/blender/augment.sh
bash scripts_bad/eval/ppl_self.sh
bash scripts_bad/eval/predict_train.sh

# process contrastive learning data
cd /{project_path}/data_bad/contrast
python process.py

# run training and generation
# you may use scripts for different models
cd /{project_path}
bash scripts_bad/contrast/train.sh
bash scripts_bad/contrast/generate.sh
# ... and you may generate multiple continuations for the next iteration
# bash scripts_bad/contrast/augment.sh

# run evaluation
# you may change the arguments as needed
cd /{project_path}
bash scripts_bad/eval/dist.sh
bash scripts_bad/eval/ppl_large.sh
bash scripts_bad/eval/predict.sh

# gather results
cd /{project_path}
bash scripts_bad/stats.sh

# then look into `results_bad_test.txt` for results
# the results are ordered as presented in the paper
# you can paste them into Excel :)
```

## Sentiment Experiments

```bash
# process raw data
# you may need to change the data path
cd /{project_path}/data_senti/gpt2
python process.py

# generate multiple continuations
cd /{project_path}
bash scripts_senti/gpt2/train.sh
bash scripts_senti/gpt2/augment.sh
bash scripts_senti/eval/ppl_self.sh
bash scripts_senti/eval/predict_train.sh

# process contrastive learning data
cd /{project_path}/data_senti/contrast
python process.py

# you may use pos/neg to specify the domain
# run training and generation
cd /{project_path}
bash scripts_senti/pos_contrast/train.sh
bash scripts_senti/pos_contrast/generate.sh

# run evaluation
# you may change the arguments as needed
cd /{project_path}
bash scripts_senti/eval/dist.sh
bash scripts_senti/eval/ppl_large.sh
bash scripts_senti/eval/predict.sh

# gather results
cd /{project_path}
bash scripts_senti/stats_pos.sh

# then look into `results_senti_pos.txt` for results
```

## Repetition Experiments

**NOTE**: Since the training files for Click are too large (WikiText-103), we did not upload them to HuggingFace. You can reproduce the data by running the following commands (the trained models used to produce the Click training data are uploaded).

```bash
# process raw data
# you may need to change the data path
cd /{project_path}/data_wiki/gpt2
python process.py

# generate multiple continuations
cd /{project_path}
bash scripts_wiki/gpt2/train.sh
# note: you should reproduce the Click training data from here
bash scripts_wiki/gpt2/augment_0.5.sh
bash scripts_wiki/gpt2/augment_0.7.sh
bash scripts_wiki/gpt2/augment_0.9.sh
bash scripts_wiki/eval/div.sh

# process contrastive learning data
cd /{project_path}/data_wiki/contrast
python process.py

# run training and generation
cd /{project_path}
bash scripts_wiki/contrast/train.sh
bash scripts_wiki/contrast/generate.sh

# run evaluation
# you may change the arguments as needed
cd /{project_path}
bash scripts_wiki/eval/mauve.sh

# gather results
cd /{project_path}
bash scripts_wiki/stats.sh

# then look into `results_wiki_test.txt` for results
```

## Citation

Please kindly cite our paper if you find the paper and code helpful.

```bib
@inproceedings{zheng-etal-2023-click,
  title={Click: Controllable Text Generation with Sequence Likelihood Contrastive Learning},
  author={Zheng, Chujie and
    Ke, Pei and
    Zhang, Zheng and
    Huang, Minlie},
  booktitle={Findings of ACL},
  year={2023}
}
```

--------------------------------------------------------------------------------
/collators/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chujiezheng/Click/494e90434e76432a4b1da65fb2efe1dd992ebeba/collators/__init__.py
--------------------------------------------------------------------------------
/collators/classification.py:
--------------------------------------------------------------------------------
import json
from typing import List
import torch
from tqdm import tqdm
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers import logging

logging.set_verbosity_error()


def collate_fn(data_list, toker: PreTrainedTokenizer, max_input_length=None, max_decoder_input_length=None, infer=False):
    features: List[Feature] = [convert_data_to_feature(e, toker, max_input_length, max_decoder_input_length) for e in data_list]

    input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.float)
    token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
    labels = torch.tensor([f.label for f in features], dtype=torch.long)

    res = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'token_type_ids': token_type_ids,
        'labels': labels,  # [bs] for single-label, [bs, label_num] for multi-label
    }
    return res


# `Feature` is a class that contains all the information needed to train a model
class Feature(object):
    def __init__(
        self, input_ids,
        attention_mask, token_type_ids, label,
        input_len,
    ):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.token_type_ids = token_type_ids
        self.label = label
        self.input_len = input_len


def convert_data_to_feature(data, toker: PreTrainedTokenizer, max_input_length, max_decoder_input_length=None) -> Feature:
    processed_data = toker(
        data['text'],
        data['text2'] if 'text2' in data else None,
        padding='max_length',
        truncation='longest_first',
        max_length=max_input_length,
        return_attention_mask=True,
        return_token_type_ids=True,
    )

    feature = Feature(
        processed_data['input_ids'],
        processed_data['attention_mask'],
        processed_data['token_type_ids'],
        data['label'],
        len([e for e in processed_data['input_ids'] if e != toker.pad_token_id])
    )
    return feature
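
For reference, a minimal sketch of how this collator can be driven; the tokenizer checkpoint and example data below are illustrative assumptions, not taken from the repo:

```python
# Hypothetical usage sketch of collators/classification.py (example data made up).
from transformers import AutoTokenizer

toker = AutoTokenizer.from_pretrained('bert-base-uncased')  # assumed classifier backbone
data_list = [
    {'text': 'Human: hello there\nBot: hi, how are you?', 'label': 0},
    {'text': 'Human: you are awful\nBot: so are you!', 'label': 1},
]
batch = collate_fn(data_list, toker, max_input_length=48)
# batch['input_ids'].shape == (2, 48); batch['labels'].shape == (2,)
```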
--------------------------------------------------------------------------------
/collators/gpt2.py:
--------------------------------------------------------------------------------
import json
from typing import List
import torch
from tqdm import tqdm
from transformers.tokenization_utils import PreTrainedTokenizer


def collate_fn(data_list, toker: PreTrainedTokenizer, max_input_length=None, max_decoder_input_length=None, infer=False):
    features: List[Feature] = [convert_data_to_feature(e, toker, max_input_length, max_decoder_input_length) for e in data_list]

    assert not infer
    input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.float)
    labels = torch.tensor([f.labels for f in features], dtype=torch.long)

    res = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
    }
    return res


class Feature(object):
    def __init__(
        self,
        input_ids, attention_mask,
        labels,
    ):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.labels = labels


def convert_data_to_feature(data, toker: PreTrainedTokenizer, max_input_length, max_decoder_input_length=None) -> Feature:
    process = lambda x: toker.convert_tokens_to_ids(toker.tokenize(x))

    assert 'target' in data
    target = process(data['target'])
    eos = toker.eos_token_id

    input_ids = target[:-1][:max_input_length]
    labels = target[1:][:max_input_length]
    attention_mask = [1.] * len(input_ids) + [0.] * (max_input_length - len(input_ids))
    input_ids = input_ids + [eos] * (max_input_length - len(input_ids))
    labels = labels + [-100] * (max_input_length - len(labels))

    feature = Feature(
        input_ids, attention_mask,
        labels,
    )
    return feature
--------------------------------------------------------------------------------
/collators/gpt2_contrast.py:
--------------------------------------------------------------------------------
import json
from typing import List
import torch
from tqdm import tqdm
from transformers.tokenization_utils import PreTrainedTokenizer


def collate_fn(data_list, toker: PreTrainedTokenizer, max_input_length=None, max_decoder_input_length=None, infer=False):
    features: List[Feature] = [convert_data_to_feature(e, toker, max_input_length, max_decoder_input_length) for e in data_list]

    assert not infer
    input_ids = torch.tensor([f.input_ids for f in features if f.input_ids is not None], dtype=torch.long)
    attention_mask = torch.tensor([f.attention_mask for f in features if f.input_ids is not None], dtype=torch.float)
    labels = torch.tensor([f.labels for f in features if f.input_ids is not None], dtype=torch.long)

    if input_ids.size(0) == 0:
        input_ids = torch.tensor([[0]], dtype=torch.long)
        attention_mask = torch.tensor([[0.]], dtype=torch.float)
        labels = torch.tensor([[-100]], dtype=torch.long)

    res = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
    }

    neg_input_ids = torch.tensor([e for f in features for e in f.neg_input_ids], dtype=torch.long)
    if len(neg_input_ids) == 0:
        return res

    pos_input_ids = torch.tensor([e for f in features for e in f.pos_input_ids], dtype=torch.long)
    pos_labels = torch.tensor([e for f in features for e in f.pos_labels], dtype=torch.long)
    neg_labels = torch.tensor([e for f in features for e in f.neg_labels], dtype=torch.long)

    res = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,

        'pos_input_ids': pos_input_ids,
        'pos_labels': pos_labels,
        'neg_input_ids': neg_input_ids,
        'neg_labels': neg_labels,
    }
    return res


class Feature(object):
    def __init__(
        self,
        input_ids, attention_mask,
        labels,
        pos_input_ids, pos_labels,
        neg_input_ids, neg_labels,
    ):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.labels = labels

        self.pos_input_ids = pos_input_ids
        self.pos_labels = pos_labels
        self.neg_input_ids = neg_input_ids
        self.neg_labels = neg_labels


def convert_data_to_feature(data, toker: PreTrainedTokenizer, max_input_length, max_decoder_input_length=None) -> Feature:
    process = lambda x: toker.convert_tokens_to_ids(toker.tokenize(x))
    eos = toker.eos_token_id

    if 'target' in data:
        target = process(data['target'])

        input_ids = target[:-1][:max_input_length]
        labels = target[1:][:max_input_length]
        attention_mask = [1.] * len(input_ids) + [0.] * (max_input_length - len(input_ids))
        input_ids = input_ids + [eos] * (max_input_length - len(input_ids))
        labels = labels + [-100] * (max_input_length - len(labels))
    else:
        input_ids = None
        attention_mask = None
        labels = None

    pos_targets = [process(e) for e in data['pos_targets']]
    neg_targets = [process(e) for e in data['neg_targets']]

    # we use max_decoder_input_length as the max length of calibration data
    pos_input_ids = [e[:-1][:max_decoder_input_length] for e in pos_targets]
    pos_labels = [e[1:][:max_decoder_input_length] for e in pos_targets]
    pos_input_ids = [e + [eos] * (max_decoder_input_length - len(e)) for e in pos_input_ids]
    pos_labels = [e + [-100] * (max_decoder_input_length - len(e)) for e in pos_labels]

    neg_input_ids = [e[:-1][:max_decoder_input_length] for e in neg_targets]
    neg_labels = [e[1:][:max_decoder_input_length] for e in neg_targets]
    neg_input_ids = [e + [eos] * (max_decoder_input_length - len(e)) for e in neg_input_ids]
    neg_labels = [e + [-100] * (max_decoder_input_length - len(e)) for e in neg_labels]

    feature = Feature(
        input_ids, attention_mask,
        labels,
        pos_input_ids, pos_labels,
        neg_input_ids, neg_labels,
    )
    return feature
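
A sketch of the data format this collator expects: an ordinary language-modeling `target` plus `pos_targets`/`neg_targets` lists that the contrastive objective ranks against each other. The strings and the `gpt2` checkpoint below are made-up placeholders, not from the repo's data:

```python
# Hypothetical usage sketch of collators/gpt2_contrast.py (example data made up).
from transformers import AutoTokenizer

toker = AutoTokenizer.from_pretrained('gpt2')
data_list = [{
    'target': 'the movie was great' + toker.eos_token,          # standard LM example
    'pos_targets': ['the movie was great' + toker.eos_token],   # sequences to rank higher
    'neg_targets': ['the movie was awful' + toker.eos_token],   # sequences to rank lower
}]
batch = collate_fn(data_list, toker, max_input_length=32, max_decoder_input_length=32)
# besides input_ids/labels, the batch carries pos_input_ids/pos_labels and
# neg_input_ids/neg_labels, which the Click model contrasts via sequence likelihoods
```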
--------------------------------------------------------------------------------
/collators/gpt2_eval.py:
--------------------------------------------------------------------------------
import json
from typing import List
import torch
from tqdm import tqdm
from transformers.tokenization_utils import PreTrainedTokenizer


def collate_fn(data_list, toker: PreTrainedTokenizer, max_input_length=None, max_decoder_input_length=None, infer=False):
    features: List[Feature] = [convert_data_to_feature(e, toker, max_input_length, max_decoder_input_length, infer) for e in data_list]

    input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.float)
    if not infer:
        labels = torch.tensor([f.labels for f in features], dtype=torch.long)
        res = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
        }
    else:
        references = [f.reference for f in features]
        res = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'references': references,
        }

    return res


class Feature(object):
    def __init__(
        self,
        input_ids, attention_mask,
        labels,
        reference=None,
    ):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.labels = labels
        self.reference = reference


def convert_data_to_feature(data, toker: PreTrainedTokenizer, max_input_length, max_decoder_input_length=None, infer=False) -> Feature:
    process = lambda x: toker.convert_tokens_to_ids(toker.tokenize(x))

    source = process(data['source'])
    target = process(data['target'])
    eos = toker.eos_token_id
    reference = data['target']

    if infer:
        input_ids = source[-max_input_length:]
        attention_mask = [1.] * len(input_ids)
        input_ids = input_ids[:-1] + [eos] * (max_input_length - len(input_ids)) + input_ids[-1:]
        attention_mask = attention_mask[:-1] + [0.] * (max_input_length - len(attention_mask)) + attention_mask[-1:]
        labels = None
    else:
        source = source[-max_input_length:]
        input_ids = source + target[:-1][:max_decoder_input_length]
        labels = [-100] * (len(source) - 1) + target[:max_decoder_input_length + 1]
        attention_mask = [1.] * len(input_ids) + [0.] * (max_input_length + max_decoder_input_length - len(input_ids))
        input_ids = input_ids + [eos] * (max_input_length + max_decoder_input_length - len(input_ids))
        labels = labels + [-100] * (max_input_length + max_decoder_input_length - len(labels))

    feature = Feature(
        input_ids, attention_mask,
        labels,
        reference,
    )
    return feature
--------------------------------------------------------------------------------
/collators/text2text.py:
--------------------------------------------------------------------------------
import json
from typing import List
import torch
from tqdm import tqdm
from transformers.tokenization_utils import PreTrainedTokenizer


def collate_fn(data_list, toker: PreTrainedTokenizer, max_input_length=None, max_decoder_input_length=None, infer=False):
    features: List[Feature] = [convert_data_to_feature(e, toker, max_input_length, max_decoder_input_length) for e in data_list]

    input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.float)
    if not infer:
        decoder_input_ids = torch.tensor([f.decoder_input_ids for f in features], dtype=torch.long)
        labels = torch.tensor([f.labels for f in features], dtype=torch.long)
    else:
        decoder_input_ids = None
        references = None
        if all(f.reference is not None for f in features):
            references = [f.reference for f in features]

    res = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'decoder_input_ids': decoder_input_ids,
    }
    if not infer:
        res['labels'] = labels
    elif references is not None:
        res['references'] = references
    return res


class Feature(object):
    def __init__(
        self,
        input_ids, attention_mask,
        decoder_input_ids, labels,
        reference=None,
    ):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.decoder_input_ids = decoder_input_ids
        self.labels = labels
        self.reference = reference


def convert_data_to_feature(data, toker: PreTrainedTokenizer, max_input_length, max_decoder_input_length=None) -> Feature:
    process = lambda x: toker.convert_tokens_to_ids(toker.tokenize(x))
    source = process(data['source'])
    target = process(data['target'])
    reference = toker.decode(target[1:-1], skip_special_tokens=True)

    pad_token_id = toker.pad_token_id

    input_ids = source[-max_input_length:]
    decoder_input_ids = target[:-1][:max_decoder_input_length]
    labels = target[1:][:max_decoder_input_length]
    assert decoder_input_ids[1:] == labels[:-1]

    attention_mask = [1.] * len(input_ids) + [0.] * (max_input_length - len(input_ids))
    input_ids = input_ids + [pad_token_id] * (max_input_length - len(input_ids))
    decoder_input_ids = decoder_input_ids + [pad_token_id] * (max_decoder_input_length - len(decoder_input_ids))
    labels = labels + [-100] * (max_decoder_input_length - len(labels))

    feature = Feature(
        input_ids, attention_mask,
        decoder_input_ids, labels,
        reference,
    )
    return feature
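
The `assert decoder_input_ids[1:] == labels[:-1]` above encodes the standard teacher-forcing shift: decoder inputs drop the target's last token and labels drop its first, so position t predicts token t+1. A toy check with made-up token IDs:

```python
# Toy illustration of the decoder_input_ids/labels shift (hypothetical IDs).
target = [0, 11, 12, 13, 2]        # e.g. <bos> w1 w2 w3 <eos>
decoder_input_ids = target[:-1]    # [0, 11, 12, 13]
labels = target[1:]                # [11, 12, 13, 2]
assert decoder_input_ids[1:] == labels[:-1]
```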
--------------------------------------------------------------------------------
/collators/text2text_contrast.py:
--------------------------------------------------------------------------------
import json
from typing import List
import torch
from tqdm import tqdm
from transformers.tokenization_utils import PreTrainedTokenizer


def collate_fn(data_list, toker: PreTrainedTokenizer, max_input_length=None, max_decoder_input_length=None, infer=False):
    features: List[Feature] = [convert_data_to_feature(e, toker, max_input_length, max_decoder_input_length) for e in data_list]

    assert not infer
    input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.float)
    decoder_input_ids = torch.tensor([f.decoder_input_ids for f in features], dtype=torch.long)
    labels = torch.tensor([f.labels for f in features], dtype=torch.long)

    res = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'decoder_input_ids': decoder_input_ids,
        'labels': labels,
    }

    neg_decoder_input_ids = torch.tensor([e for f in features for e in f.neg_decoder_input_ids], dtype=torch.long)
    if len(neg_decoder_input_ids) == 0:
        return res

    pos_decoder_input_ids = torch.tensor([e for f in features for e in f.pos_decoder_input_ids], dtype=torch.long)
    pos_labels = torch.tensor([e for f in features for e in f.pos_labels], dtype=torch.long)
    neg_labels = torch.tensor([e for f in features for e in f.neg_labels], dtype=torch.long)
    selected_indices = torch.tensor([i for i, f in enumerate(features) for _ in range(len(f.neg_decoder_input_ids))], dtype=torch.long)

    res = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'decoder_input_ids': decoder_input_ids,
        'labels': labels,

        'pos_decoder_input_ids': pos_decoder_input_ids,
        'pos_labels': pos_labels,
        'neg_decoder_input_ids': neg_decoder_input_ids,
        'neg_labels': neg_labels,
        'selected_indices': selected_indices,
    }
    return res


class Feature(object):
    def __init__(
        self,
        input_ids, attention_mask,
        decoder_input_ids, labels,
        pos_decoder_input_ids, pos_labels,
        neg_decoder_input_ids, neg_labels,
    ):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.decoder_input_ids = decoder_input_ids
        self.labels = labels

        self.pos_decoder_input_ids = pos_decoder_input_ids
        self.pos_labels = pos_labels
        self.neg_decoder_input_ids = neg_decoder_input_ids
        self.neg_labels = neg_labels


def convert_data_to_feature(data, toker: PreTrainedTokenizer, max_input_length, max_decoder_input_length=None) -> Feature:
    process = lambda x: toker.convert_tokens_to_ids(toker.tokenize(x))
    source = process(data['source'])
    target = process(data['target'])

    pad_token_id = toker.pad_token_id

    input_ids = source[-max_input_length:]
    decoder_input_ids = target[:-1][:max_decoder_input_length]
    labels = target[1:][:max_decoder_input_length]
    assert decoder_input_ids[1:] == labels[:-1]

    attention_mask = [1.] * len(input_ids) + [0.] * (max_input_length - len(input_ids))
    input_ids = input_ids + [pad_token_id] * (max_input_length - len(input_ids))
    decoder_input_ids = decoder_input_ids + [pad_token_id] * (max_decoder_input_length - len(decoder_input_ids))
    labels = labels + [-100] * (max_decoder_input_length - len(labels))

    pos_targets = [process(e) for e in data['pos_targets']]
    neg_targets = [process(e) for e in data['neg_targets']]

    pos_decoder_input_ids = [e[:-1][:max_decoder_input_length] for e in pos_targets]
    pos_labels = [e[1:][:max_decoder_input_length] for e in pos_targets]
    pos_decoder_input_ids = [e + [pad_token_id] * (max_decoder_input_length - len(e)) for e in pos_decoder_input_ids]
    pos_labels = [e + [-100] * (max_decoder_input_length - len(e)) for e in pos_labels]

    neg_decoder_input_ids = [e[:-1][:max_decoder_input_length] for e in neg_targets]
    neg_labels = [e[1:][:max_decoder_input_length] for e in neg_targets]
    neg_decoder_input_ids = [e + [pad_token_id] * (max_decoder_input_length - len(e)) for e in neg_decoder_input_ids]
    neg_labels = [e + [-100] * (max_decoder_input_length - len(e)) for e in neg_labels]

    feature = Feature(
        input_ids, attention_mask,
        decoder_input_ids, labels,
        pos_decoder_input_ids, pos_labels,
        neg_decoder_input_ids, neg_labels,
    )
    return feature
--------------------------------------------------------------------------------
/collators/text2text_dexperts.py:
--------------------------------------------------------------------------------
import json
from typing import List
import torch
from tqdm import tqdm
from transformers.tokenization_utils import PreTrainedTokenizer

MAX_AUX_LENGTH = 32


def collate_fn(data_list, toker: PreTrainedTokenizer, max_input_length=None, max_decoder_input_length=None, infer=False):
    features: List[Feature] = [convert_data_to_feature(e, toker, max_input_length, max_decoder_input_length) for e in data_list]

    input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.float)
    aux_input_ids = torch.tensor([f.aux_input_ids for f in features], dtype=torch.long)
    aux_attention_mask = torch.tensor([f.aux_attention_mask for f in features], dtype=torch.float)
    if not infer:
        decoder_input_ids = torch.tensor([f.decoder_input_ids for f in features], dtype=torch.long)
        labels = torch.tensor([f.labels for f in features], dtype=torch.long)
    else:
        decoder_input_ids = None
        references = None
        if all(f.reference is not None for f in features):
            references = [f.reference for f in features]

    res = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'aux_input_ids': aux_input_ids,
        'aux_attention_mask': aux_attention_mask,
        'decoder_input_ids': decoder_input_ids,
    }
    if not infer:
        res['labels'] = labels
    elif references is not None:
        res['references'] = references
    return res


class Feature(object):
    def __init__(
        self,
        input_ids, attention_mask,
        aux_input_ids, aux_attention_mask,
        decoder_input_ids, labels,
        reference=None,
    ):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.aux_input_ids = aux_input_ids
        self.aux_attention_mask = aux_attention_mask
        self.decoder_input_ids = decoder_input_ids
        self.labels = labels
        self.reference = reference


def convert_data_to_feature(data, toker: PreTrainedTokenizer, max_input_length, max_decoder_input_length=None) -> Feature:
    process = lambda x: toker.convert_tokens_to_ids(toker.tokenize(x))
    source = process(data['source'])
    target = process(data['target'])
    reference = toker.decode(target[1:-1], skip_special_tokens=True)
    if 'aux_source' in data:
        aux_source = process(data['aux_source'])
        max_aux_length = MAX_AUX_LENGTH
    else:
        aux_source = source[:]
        max_aux_length = max_input_length

    pad_token_id = toker.pad_token_id

    #assert len(source) <= max_input_length
    #assert len(target) - 1 <= max_decoder_input_length
    input_ids = source[-max_input_length:]
    aux_input_ids = aux_source[:max_aux_length]
    decoder_input_ids = target[:-1][:max_decoder_input_length]
    labels = target[1:][:max_decoder_input_length]
    assert decoder_input_ids[1:] == labels[:-1]

    attention_mask = [1.] * len(input_ids) + [0.] * (max_input_length - len(input_ids))
    input_ids = input_ids + [pad_token_id] * (max_input_length - len(input_ids))
    aux_attention_mask = [1.] * len(aux_input_ids) + [0.] * (max_aux_length - len(aux_input_ids))
    aux_input_ids = aux_input_ids + [pad_token_id] * (max_aux_length - len(aux_input_ids))
    decoder_input_ids = decoder_input_ids + [pad_token_id] * (max_decoder_input_length - len(decoder_input_ids))
    labels = labels + [-100] * (max_decoder_input_length - len(labels))

    feature = Feature(
        input_ids, attention_mask,
        aux_input_ids, aux_attention_mask,
        decoder_input_ids, labels,
        reference,
    )
    return feature
--------------------------------------------------------------------------------
/collators/text2text_labels.py:
--------------------------------------------------------------------------------
import json
from typing import List
import torch
from tqdm import tqdm
from transformers.tokenization_utils import PreTrainedTokenizer


def collate_fn(data_list, toker: PreTrainedTokenizer, max_input_length=None, max_decoder_input_length=None, infer=False):
    features: List[Feature] = [convert_data_to_feature(e, toker, max_input_length, max_decoder_input_length) for e in data_list]

    input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.float)
    if not infer:
        decoder_input_ids = torch.tensor([f.decoder_input_ids for f in features], dtype=torch.long)
        labels = torch.tensor([f.labels for f in features], dtype=torch.long)
        cls_labels = torch.tensor([f.cls_label for f in features], dtype=torch.float)
    else:
        decoder_input_ids = None
        references = None
        if all(f.reference is not None for f in features):
            references = [f.reference for f in features]

    res = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'decoder_input_ids': decoder_input_ids,
    }
    if not infer:
        res['labels'] = labels
        res['cls_labels'] = cls_labels
    elif references is not None:
        res['references'] = references
    return res


class Feature(object):
    def __init__(
        self,
        input_ids, attention_mask,
        decoder_input_ids, labels,
        cls_label=None,
        reference=None,
    ):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.decoder_input_ids = decoder_input_ids
        self.labels = labels
        self.cls_label = cls_label
        self.reference = reference


def convert_data_to_feature(data, toker: PreTrainedTokenizer, max_input_length, max_decoder_input_length=None) -> Feature:
    process = lambda x: toker.convert_tokens_to_ids(toker.tokenize(x))
    source = process(data['source'])
    target = process(data['target'])
    reference = toker.decode(target[1:-1], skip_special_tokens=True)
    cls_label = None
    if 'cls_label' in data:
        cls_label = data['cls_label']

    pad_token_id = toker.pad_token_id

    #assert len(source) <= max_input_length
    #assert len(target) - 1 <= max_decoder_input_length
    input_ids = source[-max_input_length:]
    decoder_input_ids = target[:-1][:max_decoder_input_length]
    labels = target[1:][:max_decoder_input_length]
    assert decoder_input_ids[1:] == labels[:-1]

    attention_mask = [1.] * len(input_ids) + [0.] * (max_input_length - len(input_ids))
    input_ids = input_ids + [pad_token_id] * (max_input_length - len(input_ids))
    decoder_input_ids = decoder_input_ids + [pad_token_id] * (max_decoder_input_length - len(decoder_input_ids))
    labels = labels + [-100] * (max_decoder_input_length - len(labels))

    feature = Feature(
        input_ids, attention_mask,
        decoder_input_ids, labels,
        cls_label,
        reference,
    )
    return feature
--------------------------------------------------------------------------------
/collators/text2text_padding.py:
--------------------------------------------------------------------------------
import json
from typing import List
import torch
from tqdm import tqdm
from transformers.tokenization_utils import PreTrainedTokenizer

MAX_SRC_NUM = 5
MAX_AUX_LENGTH = 32


def collate_fn(data_list, toker: PreTrainedTokenizer, max_input_length=None, max_decoder_input_length=None, infer=False):
    features: List[Feature] = [convert_data_to_feature(e, toker, max_input_length, max_decoder_input_length) for e in data_list]

    input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.float)
    aux_input_ids = torch.tensor([e for f in features for e in f.aux_input_ids], dtype=torch.long)
    aux_attention_mask = torch.tensor([e for f in features for e in f.aux_attention_mask], dtype=torch.float)
    if not infer:
        decoder_input_ids = torch.tensor([f.decoder_input_ids for f in features], dtype=torch.long)
        labels = torch.tensor([f.labels for f in features], dtype=torch.long)
    else:
        decoder_input_ids = None
        references = None
        if all(f.reference is not None for f in features):
            references = [f.reference for f in features]

    res = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'aux_input_ids': aux_input_ids,
        'aux_attention_mask': aux_attention_mask,
        'decoder_input_ids': decoder_input_ids,
    }
    if not infer:
        res['labels'] = labels
    elif references is not None:
        res['references'] = references
    return res


class Feature(object):
    def __init__(
        self,
        input_ids, attention_mask,
        aux_input_ids, aux_attention_mask,
        decoder_input_ids, labels,
        reference=None,
    ):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.aux_input_ids = aux_input_ids
        self.aux_attention_mask = aux_attention_mask
        self.decoder_input_ids = decoder_input_ids
        self.labels = labels
        self.reference = reference


def convert_data_to_feature(data, toker: PreTrainedTokenizer, max_input_length, max_decoder_input_length=None) -> Feature:
    process = lambda x: toker.convert_tokens_to_ids(toker.tokenize(x))
    source = process(data['source'])
    target = process(data['target'])
    reference = toker.decode(target[1:-1], skip_special_tokens=True)
    aux_sources = [process(e) for e in data['aux_sources']]

    pad_token_id = toker.pad_token_id

    #assert len(source) <= max_input_length
    #assert len(target) - 1 <= max_decoder_input_length
    input_ids = source[-max_input_length:]
    aux_input_ids = [e[:MAX_AUX_LENGTH] for e in aux_sources] + [[] for _ in range(MAX_SRC_NUM - len(aux_sources))]
    decoder_input_ids = target[:-1][:max_decoder_input_length]
    labels = target[1:][:max_decoder_input_length]
    assert decoder_input_ids[1:] == labels[:-1]

    attention_mask = [1.] * len(input_ids) + [0.] * (max_input_length - len(input_ids))
    input_ids = input_ids + [pad_token_id] * (max_input_length - len(input_ids))
    aux_attention_mask = [[1.] * len(e) + [0.] * (MAX_AUX_LENGTH - len(e)) for e in aux_input_ids]
    aux_input_ids = [e + [pad_token_id] * (MAX_AUX_LENGTH - len(e)) for e in aux_input_ids]
    decoder_input_ids = decoder_input_ids + [pad_token_id] * (max_decoder_input_length - len(decoder_input_ids))
    labels = labels + [-100] * (max_decoder_input_length - len(labels))

    feature = Feature(
        input_ids, attention_mask,
        aux_input_ids, aux_attention_mask,
        decoder_input_ids, labels,
        reference,
    )
    return feature
--------------------------------------------------------------------------------
/eval_bad.py:
--------------------------------------------------------------------------------
import argparse
import json
import logging
import os
from importlib import import_module
import numpy as np
from tqdm import tqdm
from functools import partial

import torch
from torch import Tensor
from transformers.trainer_utils import set_seed
from accelerate import Accelerator

from utils.building_utils import boolean_string, build_model
from utils.dataloader_utils import _norm, BatchDataLoader


def _norm(x):
    return " ".join(x.strip().split())


def make_sample(args, toker, context, generation):
    utterances = context + [generation]
    dialogue_context = ''
    for i, utterance in enumerate(utterances):
        text = _norm(utterance)
        if i % 2 == 0:
            dialogue_context += f'Human: {text}\n'
        else:
            dialogue_context += f'Bot: {text}\n'

    dialogue_context = dialogue_context.strip()
    dialogue_context = toker.convert_tokens_to_string(toker.tokenize(dialogue_context)[-args.max_input_length+2:])
    d = {
        'text': dialogue_context,
        'label': 0,
    }
    return d


def run(args):
    logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt = '%m/%d/%Y %H:%M:%S',
                        level = logging.INFO)
    logger = logging.getLogger(__name__)

    accelerator = Accelerator()
    set_seed(args.seed)
    n_gpu = torch.cuda.device_count()

    toker, model = build_model(args)
    model = accelerator.prepare(model)
    if n_gpu > 1:
        logger.info('use `torch.nn.DataParallel`')
        model = torch.nn.DataParallel(model)
    model.eval()

    contexts = [json.loads(e) for e in open(args.context_file)]
    contexts = [e['context'] for e in contexts]
    for infer_data_path in args.infer_data_paths:
        generation_file = infer_data_path + '/gen.txt'
        if not os.path.exists(generation_file):
            continue
        if os.path.exists(infer_data_path + '/pred_list.txt'):
            print('prediction have existed')
            continue
        generations = [json.loads(e)['generation'] for e in open(generation_file)]
        assert len(contexts) == len(generations)

        data_list = []
        for context, generation in tqdm(zip(contexts, generations), total=len(contexts), ncols=0, leave=False):
            if not isinstance(generation, list):
                generation = [generation]
            for g in generation:
                tmp_data = make_sample(args, toker, context, g)
                data_list.append(tmp_data)

        collate_fn = getattr(import_module('collators.' + args.collator_name), 'collate_fn')
        infer_dataloader = BatchDataLoader(
            data_list=data_list,
            batch_size=args.batch_size,
            collate_fn=partial(
                collate_fn,
                toker=toker,
                max_input_length=args.max_input_length,
                infer=True,
            ),
            shuffle=False,
        )
        infer_dataloader = accelerator.prepare(infer_dataloader)

        preds = []
        for batch in tqdm(infer_dataloader, total=len(infer_dataloader), desc='inference', dynamic_ncols=True):
            batch.pop('labels')
            batch['inference'] = True
            with torch.no_grad():
                encoded_info = model(**batch)
            preds.extend(encoded_info['preds_dist'].tolist())

        with open(infer_data_path + '/pred_list.txt', 'w') as f:
            for d in preds:
                f.write(json.dumps(d) + '\n')


def main():
    #########################################################################
    # Prepare Parser
    ##########################################################################

    parser = argparse.ArgumentParser()
    parser.add_argument('--collator_name', type=str, required=True)
    parser.add_argument('--model_name', type=str, required=True)
    parser.add_argument('--pretrained_model_path', type=str, required=True)
    parser.add_argument('--model_args', type=str, nargs='+', default=[])

    parser.add_argument('--context_file', type=str, required=True)
    parser.add_argument('--infer_data_paths', type=str, nargs='+')
    parser.add_argument("--max_input_length", type=int, default=48)

    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--batch_size", type=int, default=128)

    args = parser.parse_args()
    run(args)


if __name__ == '__main__':
    main()
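
`make_sample` linearizes a dialogue into alternating `Human:`/`Bot:` turns before classification. A standalone restatement of that formatting (without the tokenizer truncation and label), for illustration only:

```python
# Restatement of make_sample's turn formatting (illustrative, not part of the repo).
def format_dialogue(context, generation):
    turns = context + [generation]
    return '\n'.join(
        ('Human: ' if i % 2 == 0 else 'Bot: ') + ' '.join(u.strip().split())
        for i, u in enumerate(turns)
    )

print(format_dialogue(['hello there'], 'hi, how are you?'))
# Human: hello there
# Bot: hi, how are you?
```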
--------------------------------------------------------------------------------
/eval_dist.py:
--------------------------------------------------------------------------------
import argparse
import json
import logging
import os
from importlib import import_module
import nltk
import numpy as np
from tqdm import tqdm
from collections import defaultdict
from functools import partial
import multiprocessing as mp
from collections import Counter

import torch
from torch import Tensor
from transformers.trainer_utils import set_seed
from accelerate import Accelerator

from metrics import Metrics


def _norm(x):
    return " ".join(x.strip().split())


def tokenize(x):
    return nltk.word_tokenize(_norm(x))


def calculate_distinct(generation):
    generation = [tokenize(e) for e in generation]
    lengths = [len(e) for e in generation]
    distinct = []
    for n in range(1, 4):
        ngrams = []
        for g in generation:
            tmp_ngrams = list(zip(*[g[i:] for i in range(n)]))
            ngrams.extend(tmp_ngrams)
        ngrams = Counter(ngrams)
        #distinct.append(len(ngrams) / sum(ngrams.values()))
        distinct.append(len(ngrams) / sum(lengths))
    return distinct


def evaluate(args):
    logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt = '%m/%d/%Y %H:%M:%S',
                        level = logging.INFO)
    logger = logging.getLogger(__name__)

    pool = mp.Pool(mp.cpu_count() * 2)

    for infer_data_path in args.infer_data_paths:
        generation_file = infer_data_path + '/gen.txt'
        if not os.path.exists(generation_file):
            continue
        if os.path.exists(f'{infer_data_path}/dist_list.txt'):
            print('prediction have existed')
            continue

        generations = [json.loads(e)['generation'] for e in open(generation_file)]
        distinct = [e for e in pool.imap(calculate_distinct, tqdm(generations, dynamic_ncols=True, total=len(generations)))]

        with open(infer_data_path + '/dist_list.txt', 'w') as f:
            for d in distinct:
                f.write(json.dumps(d) + '\n')


def main():
    #########################################################################
    # Prepare Parser
    ##########################################################################

    parser = argparse.ArgumentParser()
    parser.add_argument('--context_file', type=str, required=True)
    parser.add_argument('--infer_data_paths', type=str, nargs='+')

    args = parser.parse_args()
    evaluate(args)


if __name__ == '__main__':
    main()
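
Note that `calculate_distinct` normalizes the number of distinct n-grams by the total token count `sum(lengths)` rather than by the n-gram count (the commented-out variant). A quick worked check with made-up generations, assuming NLTK's tokenizer data is available:

```python
# Worked example for calculate_distinct (hypothetical generations).
gens = ['the cat sat', 'the cat ran']
# 6 tokens in total; distinct 1/2/3-grams: 4, 3, 2
print(calculate_distinct(gens))  # ~[0.667, 0.5, 0.333]
```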
--------------------------------------------------------------------------------
/eval_div.py:
--------------------------------------------------------------------------------
import argparse
import json
import logging
import os
from importlib import import_module
import numpy as np
from tqdm import tqdm
from functools import partial
import multiprocessing as mp
from collections import Counter
import nltk


def nltk_repetition(text):
    tokens = nltk.word_tokenize(text)
    if len(tokens) <= 4:
        tokens = [e for e in list(text) if e.strip()]
    repn = {}
    for k in range(2, 5):
        ngrams = list(zip(*[tokens[i:] for i in range(k)]))
        ngrams = Counter(ngrams)
        repn[k] = 1. - len(ngrams) / max(sum(ngrams.values()), 1e-10)
    div = (1. - repn[2]) * (1. - repn[3]) * (1. - repn[4])
    return div


def infer(args):
    logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt = '%m/%d/%Y %H:%M:%S',
                        level = logging.INFO)
    logger = logging.getLogger(__name__)

    pool = mp.Pool(mp.cpu_count() * 2)

    contexts = []
    for e in open(args.context_file):
        e = json.loads(e)
        contexts.append(e['source'])

    for infer_data_path in args.infer_data_paths:
        generation_file = infer_data_path + '/gen.txt'
        if not os.path.exists(generation_file):
            continue
        if os.path.exists(infer_data_path + '/div_list.txt'):
            print('prediction have existed')
            #continue
        generations = []
        lines = open(generation_file).readlines()
        assert len(lines) == len(contexts)
        for idx, e in enumerate(lines):
            g = json.loads(e)['generation']
            if not isinstance(g, list):
                g = [g]
            for gg in g:
                #generations.append(contexts[idx] + gg)
                generations.append(gg)
        preds_new = [e for e in pool.imap(nltk_repetition, tqdm(generations, total=len(generations), ncols=0))]

        with open(infer_data_path + '/div_list.txt', 'w') as f:
            for d in preds_new:
                f.write(json.dumps(d) + '\n')

    pool.close()


def main():
    #########################################################################
    # Prepare Parser
    ##########################################################################

    parser = argparse.ArgumentParser()
    parser.add_argument('--infer_data_paths', type=str, nargs='+')
    parser.add_argument('--context_file', type=str, required=True)

    args = parser.parse_args()
    infer(args)


if __name__ == '__main__':
    main()
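
Despite its name, `nltk_repetition` returns a diversity score: the product of the (1 - rep-k) terms for k = 2..4, i.e. the product of (unique k-grams / total k-grams), so lower values mean more repetition. A worked example with a made-up string:

```python
# Worked example for nltk_repetition (hypothetical text; needs NLTK tokenizer data).
print(nltk_repetition('the cat the cat the cat'))
# unique/total k-grams: 2/5 (k=2), 2/4 (k=3), 2/3 (k=4)
# div = 0.4 * 0.5 * 0.667 ~= 0.133
```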
-------------------------------------------------------------------------------- /eval_mauve.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | from importlib import import_module 6 | import numpy as np 7 | from tqdm import tqdm 8 | from functools import partial 9 | import multiprocessing as mp 10 | from collections import Counter 11 | import nltk 12 | from mauve import compute_mauve 13 | from transformers import AutoTokenizer 14 | 15 | 16 | def infer(args): 17 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 18 | datefmt = '%m/%d/%Y %H:%M:%S', 19 | level = logging.INFO) 20 | logger = logging.getLogger(__name__) 21 | 22 | gpt2_toker = AutoTokenizer.from_pretrained('/home/zhengchujie/pretrained-models/gpt2-small', use_fast=True) 23 | 24 | contexts = [] 25 | p_text = [] 26 | for e in open(args.context_file): 27 | e = json.loads(e) 28 | contexts.append(gpt2_toker.decode(gpt2_toker.encode(e['source']))) 29 | p_text.append(contexts[-1] + e['target']) 30 | 31 | p_features = None 32 | 33 | if args.save_name == 'large': 34 | featurize_model_name = f'/home/zhengchujie/pretrained-models-large/gpt2-large' 35 | else: 36 | featurize_model_name = f'/home/zhengchujie/pretrained-models/gpt2-{args.save_name}' 37 | 38 | for infer_data_path in args.infer_data_paths: 39 | generation_file = infer_data_path + '/gen.txt' 40 | if not os.path.exists(generation_file): 41 | continue 42 | if os.path.exists(infer_data_path + f'/mauve_{args.save_name}.txt'): 43 | print('predictions already exist') 44 | continue 45 | q_text = [] 46 | lines = open(generation_file).readlines() 47 | assert len(lines) == len(contexts) 48 | for idx, e in enumerate(lines): 49 | g = json.loads(e)['generation'] 50 | q_text.append(contexts[idx] + g) 51 | 52 | if p_features is None: 53 | out = compute_mauve( 54 | p_text=p_text, 55 | q_text=q_text, 56 | max_text_length=160, 57 | featurize_model_name=featurize_model_name, 58 | batch_size=args.batch_size, 59 | verbose=False 60 | ) 61 | p_features = out.p_features 62 | else: 63 | out = compute_mauve( 64 | p_features=p_features, 65 | q_text=q_text, 66 | max_text_length=160, 67 | featurize_model_name=featurize_model_name, 68 | batch_size=args.batch_size, 69 | verbose=False 70 | ) 71 | 72 | res = out.mauve 73 | with open(infer_data_path + f'/mauve_{args.save_name}.txt', 'w') as f: 74 | f.write(str(res) + '\n') 75 | 76 | 77 | def main(): 78 | ######################################################################### 79 | # Prepare Parser 80 | ########################################################################## 81 | 82 | parser = argparse.ArgumentParser() 83 | parser.add_argument('--infer_data_paths', type=str, nargs='+') 84 | parser.add_argument('--save_name', type=str, required=True, choices=['small', 'medium', 'large']) 85 | parser.add_argument('--context_file', type=str, required=True) 86 | parser.add_argument('--batch_size', type=int, required=True) 87 | 88 | args = parser.parse_args() 89 | infer(args) 90 | 91 | 92 | if __name__ == '__main__': 93 | main() 94 |
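Note how eval_mauve.py featurizes the human reference text only once: `compute_mauve` accepts either raw `p_text` or precomputed `p_features`, so the reference features from the first call are reused for every subsequent system. A sketch of that caching pattern in isolation (assumes the `mauve-text` package; `model_path` is a placeholder for a local GPT-2 directory):

```python
from mauve import compute_mauve

def mauve_for_systems(p_text, systems, model_path, batch_size=8):
    # score several generation lists against one reference list,
    # featurizing the reference text only on the first call
    p_features, scores = None, {}
    for name, q_text in systems.items():
        if p_features is None:
            out = compute_mauve(p_text=p_text, q_text=q_text,
                                featurize_model_name=model_path,
                                batch_size=batch_size, verbose=False)
            p_features = out.p_features  # cache the reference features
        else:
            out = compute_mauve(p_features=p_features, q_text=q_text,
                                featurize_model_name=model_path,
                                batch_size=batch_size, verbose=False)
        scores[name] = out.mauve
    return scores
```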
-------------------------------------------------------------------------------- /eval_ppl_blender.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import json 4 | import logging 5 | import os 6 | from importlib import import_module 7 | import numpy as np 8 | from tqdm import tqdm 9 | from functools import partial 10 | 11 | import torch 12 | from torch import Tensor 13 | from transformers.trainer_utils import set_seed 14 | from accelerate import Accelerator 15 | 16 | from utils.building_utils import boolean_string, build_model 17 | from utils.dataloader_utils import _norm, BatchDataLoader 18 | from utils.eval_utils import eval_model_loss 19 | 20 | 21 | def _norm(x): 22 | return " ".join(x.strip().split()) 23 | 24 | 25 | def make_source(toker, x): 26 | if isinstance(x, list): 27 | x = [' ' + e.strip() for e in x] 28 | x = ' '.join(x) + toker.eos_token 29 | return x 30 | 31 | 32 | def make_sample(toker, context, generation): 33 | target = toker.bos_token + ' ' + generation.strip() + toker.eos_token 34 | d = {'source': context, 'target': target} 35 | return d 36 | 37 | 38 | def run(args): 39 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 40 | datefmt = '%m/%d/%Y %H:%M:%S', 41 | level = logging.INFO) 42 | logger = logging.getLogger(__name__) 43 | 44 | accelerator = Accelerator() 45 | set_seed(args.seed) 46 | n_gpu = torch.cuda.device_count() 47 | 48 | toker, model = build_model(args) 49 | model = accelerator.prepare(model) 50 | if n_gpu > 1: 51 | logger.info('use `torch.nn.DataParallel`') 52 | model = torch.nn.DataParallel(model) 53 | model.eval() 54 | 55 | contexts = [json.loads(e) for e in open(args.context_file)] 56 | contexts = [make_source(toker, e['context'] if 'context' in e else e['source']) for e in contexts] 57 | for infer_data_path in args.infer_data_paths: 58 | generation_file = infer_data_path + '/gen.txt' 59 | if not os.path.exists(generation_file): 60 | continue 61 | if os.path.exists(f'{infer_data_path}/loss_{args.save_name}_list.txt'): 62 | print('predictions already exist') 63 | continue 64 | generations = [json.loads(e)['generation'] for e in open(generation_file)] 65 | assert len(contexts) == len(generations) 66 | 67 | data_list = [] 68 | for context, generation in tqdm(zip(contexts, generations), total=len(contexts), ncols=0, leave=False): 69 | if not isinstance(generation, list): 70 | generation = [generation] 71 | for g in generation: 72 | tmp_data = make_sample(toker, context, g) 73 | data_list.append(tmp_data) 74 | 75 | collate_fn = getattr(import_module('collators.' + args.collator_name), 'collate_fn') 76 | eval_dataloader = BatchDataLoader( 77 | data_list=data_list, 78 | batch_size=args.batch_size, 79 | collate_fn=partial( 80 | collate_fn, 81 | toker=toker, 82 | max_input_length=args.max_input_length, 83 | max_decoder_input_length=args.max_decoder_input_length, 84 | infer=False, 85 | ), 86 | shuffle=False, 87 | ) 88 | eval_dataloader = accelerator.prepare(eval_dataloader) 89 | 90 | _, eval_ppl_micro, *_, pointwise_loss, pointwise_sample = eval_model_loss( 91 | accelerator=accelerator, 92 | model=model, 93 | eval_dataloader=eval_dataloader, 94 | epoch_id=0, 95 | infer=True, 96 | ) 97 | eval_loss_list = [{'loss': np.sum(x), 'num_tokens': y} for x, y in zip(pointwise_loss, pointwise_sample)] 98 | with open(f'{infer_data_path}/loss_{args.save_name}_list.txt', 'w') as f: 99 | for d in eval_loss_list: 100 | f.write(json.dumps(d) + '\n') 101 | 102 | 103 | def main(): 104 | ######################################################################### 105 | # Prepare Parser 106 | ########################################################################## 107 | 108 | parser = argparse.ArgumentParser() 109 | parser.add_argument('--collator_name', type=str, default='text2text') 110 | parser.add_argument('--save_name', type=str, required=True, choices=['self', 'large']) 111 | parser.add_argument('--model_name', type=str, default='blender') 112 | parser.add_argument('--pretrained_model_path', type=str, required=True) 113 | parser.add_argument('--model_args', type=str, nargs='+', default=[]) 114 | 115 | parser.add_argument('--context_file', type=str, required=True) 116 | parser.add_argument('--infer_data_paths', type=str, nargs='+') 117 | parser.add_argument("--max_input_length", type=int, default=128) 118 | parser.add_argument("--max_decoder_input_length", type=int, default=24) 119 | 120 | parser.add_argument("--seed", type=int, default=42) 121 | parser.add_argument("--batch_size", type=int, default=128) 122 | 123 | args = parser.parse_args() 124 | run(args) 125 | 126 | 127 | if __name__ == '__main__': 128 | main() 129 |
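Each record in the `loss_{save_name}_list.txt` files written above is `{'loss': <summed token NLL>, 'num_tokens': <target length>}`, from which both micro and macro perplexity can be recovered (the same two aggregations generate.py reports as `ppl_micro`/`ppl_macro`). A small sketch of the aggregation:

```python
import json
import numpy as np

def ppl_from_loss_file(path):
    records = [json.loads(line) for line in open(path)]
    losses = np.array([r['loss'] for r in records])
    tokens = np.array([r['num_tokens'] for r in records])
    ppl_micro = np.exp(losses.sum() / tokens.sum())  # corpus-level perplexity
    ppl_macro = np.exp(losses / tokens).mean()       # mean of per-sample perplexities
    return ppl_micro, ppl_macro
```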
-------------------------------------------------------------------------------- /eval_ppl_gpt2.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import json 4 | import logging 5 | import os 6 | from importlib import import_module 7 | import numpy as np 8 | from tqdm import tqdm 9 | from functools import partial 10 | 11 | import torch 12 | from torch import Tensor 13 | from transformers.trainer_utils import set_seed 14 | from accelerate import Accelerator 15 | 16 | from utils.building_utils import boolean_string, build_model 17 | from utils.dataloader_utils import _norm, BatchDataLoader 18 | from utils.eval_utils import eval_model_loss 19 | 20 | 21 | def _norm(x): 22 | return " ".join(x.strip().split()) 23 | 24 | 25 | def make_sample(toker, context, generation): 26 | d = {'source': context, 'target': generation} 27 | return d 28 | 29 | 30 | def run(args): 31 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 32 | datefmt = '%m/%d/%Y %H:%M:%S', 33 | level = logging.INFO) 34 | logger = logging.getLogger(__name__) 35 | 36 | accelerator = Accelerator() 37 | set_seed(args.seed) 38 | n_gpu = torch.cuda.device_count() 39 | 40 | toker, model = build_model(args) 41 | model = accelerator.prepare(model) 42 | if n_gpu > 1: 43 | logger.info('use `torch.nn.DataParallel`') 44 | model = torch.nn.DataParallel(model) 45 | model.eval() 46 | 47 | contexts = [json.loads(e) for e in open(args.context_file)] 48 | contexts = [e['source'] for e in contexts] 49 | for infer_data_path in args.infer_data_paths: 50 | if not os.path.exists(f'{infer_data_path}/gen.txt'): 51 | continue 52 | if os.path.exists(f'{infer_data_path}/loss_{args.save_name}_list.txt'): 53 | print('predictions already exist') 54 | continue 55 | generations = [json.loads(e)['generation'] for e in open(f'{infer_data_path}/gen.txt')] 56 | assert len(contexts) == len(generations) 57 | 58 | data_list = [] 59 | for context, generation in zip(contexts, generations): 60 | if not isinstance(generation, list): 61 | generation = [generation] 62 | for g in generation: 63 | tmp_data = make_sample(toker, context, g) 64 | data_list.append(tmp_data) 65 | 66 | collate_fn = getattr(import_module('collators.' + args.collator_name), 'collate_fn') 67 | eval_dataloader = BatchDataLoader( 68 | data_list=data_list, 69 | batch_size=args.batch_size, 70 | collate_fn=partial( 71 | collate_fn, 72 | toker=toker, 73 | max_input_length=args.max_input_length, 74 | max_decoder_input_length=args.max_decoder_input_length, 75 | infer=False, 76 | ), 77 | shuffle=False, 78 | ) 79 | eval_dataloader = accelerator.prepare(eval_dataloader) 80 | 81 | _, eval_ppl_micro, *_, pointwise_loss, pointwise_sample = eval_model_loss( 82 | accelerator=accelerator, 83 | model=model, 84 | eval_dataloader=eval_dataloader, 85 | epoch_id=0, 86 | infer=True, 87 | ) 88 | eval_loss_list = [{'loss': np.sum(x), 'num_tokens': y} for x, y in zip(pointwise_loss, pointwise_sample)] 89 | with open(f'{infer_data_path}/loss_{args.save_name}_list.txt', 'w') as f: 90 | for d in eval_loss_list: 91 | f.write(json.dumps(d) + '\n') 92 | 93 | 94 | def main(): 95 | ######################################################################### 96 | # Prepare Parser 97 | ########################################################################## 98 | 99 | parser = argparse.ArgumentParser() 100 | parser.add_argument('--collator_name', type=str, default='gpt2_eval') 101 | parser.add_argument('--save_name', type=str, required=True, choices=['self', 'large']) 102 | parser.add_argument('--model_name', type=str, default='gpt2') 103 | parser.add_argument('--pretrained_model_path', type=str, required=True) 104 | parser.add_argument('--model_args', type=str, nargs='+', default=[]) 105 | 106 | parser.add_argument('--context_file', type=str, required=True) 107 | parser.add_argument('--infer_data_paths', type=str, nargs='+') 108 | parser.add_argument("--max_input_length", type=int, default=128) 109 | parser.add_argument("--max_decoder_input_length", type=int, default=24) 110 | 111 | parser.add_argument("--seed", type=int, default=42) 112 | parser.add_argument('--batch_size', type=int, default=128) 113 | 114 | args = parser.parse_args() 115 | run(args) 116 | 117 | 118 | if __name__ == '__main__': 119 | main() 120 | --------------------------------------------------------------------------------
/eval_senti.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | from importlib import import_module 6 | import numpy as np 7 | from tqdm import tqdm 8 | from functools import partial 9 | 10 | import torch 11 | from torch import Tensor 12 | from transformers.trainer_utils import set_seed 13 | from accelerate import Accelerator 14 | 15 | from utils.building_utils import boolean_string, build_model 16 | from utils.dataloader_utils import _norm, BatchDataLoader 17 | 18 | 19 | def _norm(x): 20 | return " ".join(x.strip().split()) 21 | 22 | 23 | def is_negative(x): 24 | return x[0] > 0.5 25 | 26 | 27 | def make_sample(args, toker, context, generation): 28 | if context != '<|endoftext|>': 29 | ret = { 30 | 'text': context + generation, 31 | 'label': 0, 32 | } 33 | else: 34 | ret = { 35 | 'text': generation, 36 | 'label': 0, 37 | } 38 | return ret 39 | 40 | 41 | def run(args): 42 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 43 | datefmt = '%m/%d/%Y %H:%M:%S', 44 | level = logging.INFO) 45 | logger = logging.getLogger(__name__) 46 | 47 | accelerator = Accelerator() 48 | set_seed(args.seed) 49 | n_gpu = torch.cuda.device_count() 50 | 51 | toker, model = build_model(args) 52 | model = accelerator.prepare(model) 53 | if n_gpu > 1: 54 | logger.info('use `torch.nn.DataParallel`') 55 | model = torch.nn.DataParallel(model) 56 | model.eval() 57 | 58 | contexts = [json.loads(e) for e in open(args.context_file)] 59 | contexts = [e['source'] for e in contexts] 60 | for infer_data_path in args.infer_data_paths: 61 | generation_file = infer_data_path + '/gen.txt' 62 | if not os.path.exists(generation_file): 63 | continue 64 | if os.path.exists(infer_data_path + '/pred_list.txt'): 65 | print('predictions already exist') 66 | continue 67 | generations = [json.loads(e)['generation'] for e in open(generation_file)] 68 | assert len(contexts) == len(generations) 69 | 70 | data_list = [] 71 | for context, generation in tqdm(zip(contexts, generations), total=len(contexts), ncols=0, leave=False): 72 | if not isinstance(generation, list): 73 | generation = [generation] 74 | for g in generation: 75 | tmp_data = make_sample(args, toker, context, g) 76 | data_list.append(tmp_data) 77 | 78 | collate_fn = getattr(import_module('collators.'
+ args.collator_name), 'collate_fn') 79 | infer_dataloader = BatchDataLoader( 80 | data_list=data_list, 81 | batch_size=args.batch_size, 82 | collate_fn=partial( 83 | collate_fn, 84 | toker=toker, 85 | max_input_length=args.max_input_length, 86 | infer=True, 87 | ), 88 | shuffle=False, 89 | ) 90 | infer_dataloader = accelerator.prepare(infer_dataloader) 91 | 92 | preds = [] 93 | for batch in tqdm(infer_dataloader, total=len(infer_dataloader), desc='inference', dynamic_ncols=True): 94 | batch.pop('labels') 95 | batch['inference'] = True 96 | with torch.no_grad(): 97 | encoded_info = model(**batch) 98 | preds.extend(encoded_info['preds_dist'].tolist()) 99 | 100 | with open(infer_data_path + '/pred_list.txt', 'w') as f: 101 | for d in preds: 102 | f.write(json.dumps(d) + '\n') 103 | 104 | 105 | def main(): 106 | ######################################################################### 107 | # Prepare Parser 108 | ########################################################################## 109 | 110 | parser = argparse.ArgumentParser() 111 | parser.add_argument('--collator_name', type=str, required=True) 112 | parser.add_argument('--model_name', type=str, required=True) 113 | parser.add_argument('--pretrained_model_path', type=str, required=True) 114 | parser.add_argument('--model_args', type=str, nargs='+', default=[]) 115 | 116 | parser.add_argument('--context_file', type=str, required=True) 117 | parser.add_argument('--infer_data_paths', type=str, nargs='+') 118 | parser.add_argument("--max_input_length", type=int, default=48) 119 | 120 | parser.add_argument("--seed", type=int, default=42) 121 | parser.add_argument("--batch_size", type=int, default=128) 122 | 123 | args = parser.parse_args() 124 | run(args) 125 | 126 | 127 | if __name__ == '__main__': 128 | main() 129 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | from importlib import import_module 6 | import nltk 7 | import numpy as np 8 | from tqdm import tqdm 9 | from functools import partial 10 | from collections import defaultdict 11 | 12 | import torch 13 | from torch import Tensor 14 | from transformers.trainer_utils import set_seed 15 | from accelerate import Accelerator 16 | 17 | from utils.building_utils import boolean_string, build_model 18 | from utils.eval_utils import eval_model_loss 19 | from utils.dataloader_utils import _norm, BatchDataLoader 20 | 21 | 22 | def cut_sequence_to_eos(seq, eos_token_id): 23 | ret = [] 24 | for t in seq: 25 | if len(ret) > 0 and t == eos_token_id: 26 | break 27 | ret.append(t) 28 | return ret 29 | 30 | 31 | def cut_label_to_golden(seq): 32 | ret = [] 33 | for t in seq: 34 | if t == -100: 35 | if len(ret) == 0: 36 | continue 37 | else: 38 | break 39 | ret.append(t) 40 | return ret 41 | 42 | 43 | def generate(args): 44 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 45 | datefmt = '%m/%d/%Y %H:%M:%S', 46 | level = logging.INFO) 47 | logger = logging.getLogger(__name__) 48 | 49 | assert len(args.infer_data_paths) == len(args.infer_names) 50 | if args.only_evaluate: 51 | args.only_generate = False 52 | 53 | accelerator = Accelerator() 54 | set_seed(args.seed) 55 | 56 | #logger.info('Input Argument Information') 57 | #args_dict = vars(args) 58 | #for a in args_dict: 59 | # logger.info('%-28s %s' % (a, args_dict[a])) 60 | 61 | 
######################################################################### 62 | # Prepare Data Set 63 | ########################################################################## 64 | 65 | toker, model = build_model(args, checkpoint=args.load_checkpoint) 66 | model = accelerator.prepare(model) 67 | model.eval() 68 | 69 | model_parameters = list(filter(lambda p: p.requires_grad, model.parameters())) 70 | total_params = sum([np.prod(p.size()) for p in model_parameters]) 71 | logger.info('Number of parameters = {}'.format(total_params)) 72 | 73 | eos_token_id = toker.eos_token_id # if model.config.is_encoder_decoder else toker.convert_tokens_to_ids(toker.tokenize('\n'))[0] 74 | generation_kwargs = { 75 | 'max_new_tokens': args.max_length, 76 | 'min_length': args.min_length, 77 | 'do_sample': True if (args.top_k > 0 or args.top_p < 1) else False, 78 | 'temperature': args.temperature, 79 | 'top_k': args.top_k, 80 | 'top_p': args.top_p, 81 | 'num_beams': args.num_beams, 82 | 'num_return_sequences': args.num_return_sequences, 83 | 'length_penalty': args.length_penalty, 84 | 'repetition_penalty': args.repetition_penalty, 85 | 'no_repeat_ngram_size': args.no_repeat_ngram_size, 86 | 'encoder_no_repeat_ngram_size': args.encoder_no_repeat_ngram_size, 87 | 'pad_token_id': eos_token_id, 88 | 'eos_token_id': eos_token_id, 89 | } 90 | if not args.only_evaluate: 91 | print(json.dumps(generation_kwargs, indent=2, ensure_ascii=False)) 92 | 93 | ######################################################################### 94 | # Inference ! 95 | ########################################################################## 96 | 97 | for infer_data_path, infer_name in zip(args.infer_data_paths, args.infer_names): 98 | if not os.path.exists(infer_data_path): 99 | logger.info(f'file {infer_data_path} does not exist') 100 | continue 101 | set_seed(args.seed) 102 | 103 | if args.save_path is not None: 104 | os.makedirs(args.save_path, exist_ok=True) 105 | save_path = f'{args.save_path}/{infer_name}' 106 | os.makedirs(save_path, exist_ok=True) 107 | else: 108 | assert args.load_checkpoint is not None 109 | checkpoint_dir_path = '/'.join(args.load_checkpoint.split('/')[:-1]) 110 | save_path = f'{checkpoint_dir_path}/{infer_name}' 111 | os.makedirs(save_path, exist_ok=True) 112 | 113 | gen_file = os.path.join(save_path, 'gen.txt') 114 | gen_exist = os.path.exists(gen_file) 115 | metric_file = os.path.join(save_path, f'metric.json') 116 | metric_exist = os.path.exists(metric_file) 117 | 118 | if gen_exist and metric_exist: 119 | print('all outputs already exist') 120 | #continue 121 | elif gen_exist and args.only_generate: 122 | print('gen.txt already exists and metrics are not required') 123 | #continue 124 | elif metric_exist and args.only_evaluate: 125 | print('metric.json already exists and generation is not required') 126 | #continue 127 | 128 | metric_res = {} 129 | collate_fn = getattr(import_module('collators.'
+ args.collator_name), 'collate_fn') 130 | if not args.only_generate: 131 | eval_dataloader = BatchDataLoader( 132 | data_path=infer_data_path, 133 | batch_size=args.batch_size, 134 | collate_fn=partial( 135 | collate_fn, 136 | toker=toker, 137 | max_input_length=args.max_input_length, 138 | max_decoder_input_length=args.max_decoder_input_length, 139 | infer=False, 140 | ), 141 | shuffle=False, 142 | ) 143 | eval_dataloader = accelerator.prepare(eval_dataloader) 144 | 145 | _, eval_ppl_micro, eval_acc, eval_rep, eval_wrep, pointwise_loss, pointwise_sample = eval_model_loss( 146 | accelerator=accelerator, 147 | model=model, 148 | eval_dataloader=eval_dataloader, 149 | epoch_id=0, 150 | infer=True, 151 | ) 152 | metric_res['acc'] = float(eval_acc) 153 | metric_res['rep'] = float(eval_rep) 154 | metric_res['wrep'] = float(eval_wrep) 155 | metric_res['ppl_micro'] = float(eval_ppl_micro) 156 | eval_ppl_list = [np.exp(np.sum(x) / y) for x, y in zip(pointwise_loss, pointwise_sample)] 157 | eval_ppl_macro = np.mean(eval_ppl_list) 158 | metric_res['ppl_macro'] = float(eval_ppl_macro) 159 | assert len(pointwise_loss) == len(pointwise_sample) 160 | ptr = 0 161 | 162 | if not args.only_evaluate: 163 | infer_dataloader = BatchDataLoader( 164 | data_path=infer_data_path, 165 | batch_size=args.batch_size, 166 | collate_fn=partial( 167 | collate_fn, 168 | toker=toker, 169 | max_input_length=args.max_input_length, 170 | max_decoder_input_length=args.max_decoder_input_length, 171 | infer=True, 172 | ), 173 | shuffle=False, 174 | ) 175 | infer_dataloader = accelerator.prepare(infer_dataloader) 176 | 177 | if not args.only_generate: 178 | from metrics import Metrics 179 | metrics = Metrics(toker) 180 | 181 | res = [] 182 | decode = lambda x: toker.decode(x, skip_special_tokens=False) 183 | for batch in tqdm(infer_dataloader, total=len(infer_dataloader), desc='inference', dynamic_ncols=True): 184 | if 'references' in batch: 185 | references = batch.pop('references') 186 | batch.update(generation_kwargs) 187 | generations = model.generate(**batch) 188 | generations = [cut_sequence_to_eos(each, eos_token_id) for each in generations.tolist()] 189 | batch_size = len(generations) // args.num_return_sequences 190 | 191 | for idx in range(batch_size): 192 | if args.num_return_sequences > 1: 193 | generation = generations[idx * args.num_return_sequences: (idx+1) * args.num_return_sequences] 194 | generation = [decode(g) for g in generation] 195 | else: 196 | generation = generations[idx] 197 | generation = decode(generation) 198 | tmp_res_to_append = {'generation': generation} 199 | 200 | if not args.only_generate: 201 | if args.num_return_sequences == 1: 202 | g = generation 203 | else: 204 | g = generation[0] 205 | reference = references[idx] 206 | metrics.forword([reference], g, lower=args.lower, chinese=args.chinese) 207 | 208 | ptr_loss = pointwise_loss[ptr] 209 | ptr_sample = pointwise_sample[ptr] 210 | turn_loss = sum(ptr_loss) 211 | turn_ppl = np.exp(turn_loss / ptr_sample) 212 | tmp_res_to_append['token_num'] = int(ptr_sample) 213 | tmp_res_to_append['loss'] = turn_loss 214 | tmp_res_to_append['ppl'] = turn_ppl 215 | ptr += 1 216 | 217 | res.append(tmp_res_to_append) 218 | 219 | if not args.only_generate: 220 | assert ptr == len(pointwise_loss) 221 | 222 | if not args.only_evaluate: 223 | with open(os.path.join(save_path, f'gen.txt'), 'w') as f: 224 | for line in res: 225 | f.write(json.dumps(line, ensure_ascii=False) + '\n') 226 | 227 | metric_res_list = None 228 | if not args.only_evaluate and not 
args.only_generate: 229 | metric_res_list = {} 230 | closed_res = metrics.close() 231 | metric_res.update(closed_res[0]) 232 | metric_res_list.update(closed_res[1]) 233 | 234 | if not args.only_generate: 235 | with open(os.path.join(save_path, f'metric.json'), 'w') as f: 236 | json.dump(metric_res, f, ensure_ascii=False, indent=2, sort_keys=True) 237 | if metric_res_list is not None: 238 | with open(os.path.join(save_path, f'metric_list.json'), 'w') as f: 239 | json.dump(metric_res_list, f) 240 | 241 | 242 | def main(): 243 | ######################################################################### 244 | # Prepare Parser 245 | ########################################################################## 246 | 247 | parser = argparse.ArgumentParser() 248 | parser.add_argument('--collator_name', type=str, required=True) 249 | parser.add_argument('--model_name', type=str, required=True) 250 | parser.add_argument('--pretrained_model_path', type=str, required=True) 251 | parser.add_argument('--save_path', type=str, required=True) 252 | parser.add_argument('--model_args', type=str, nargs='+', default=[]) 253 | 254 | parser.add_argument('--infer_data_paths', type=str, nargs='+', required=True) 255 | parser.add_argument('--infer_names', type=str, nargs='+', required=True) 256 | parser.add_argument("--max_input_length", type=int, default=128) 257 | parser.add_argument("--max_decoder_input_length", type=int, default=48) 258 | 259 | parser.add_argument('--only_evaluate', action='store_true', help='only do evaluation and no inference') 260 | parser.add_argument('--only_generate', action='store_true', help='do not conduct evaluations') 261 | parser.add_argument('--chinese', action='store_true') 262 | parser.add_argument('--lower', action='store_true') 263 | 264 | parser.add_argument("--seed", type=int, default=42) 265 | parser.add_argument("--load_checkpoint", type=str, default=None) 266 | parser.add_argument("--batch_size", type=int, default=16) 267 | parser.add_argument("--min_length", type=int, default=5) 268 | parser.add_argument("--max_length", type=int, default=64) 269 | parser.add_argument("--num_return_sequences", type=int, default=1) 270 | parser.add_argument("--temperature", type=float, default=1) 271 | parser.add_argument("--top_k", type=int, default=0) 272 | parser.add_argument("--top_p", type=float, default=1) 273 | parser.add_argument('--num_beams', type=int, default=1) 274 | parser.add_argument("--length_penalty", type=float, default=1.0) 275 | parser.add_argument("--repetition_penalty", type=float, default=1.0) 276 | parser.add_argument("--encoder_no_repeat_ngram_size", type=int, default=0) 277 | parser.add_argument("--no_repeat_ngram_size", type=int, default=0) 278 | 279 | args = parser.parse_args() 280 | generate(args) 281 | 282 | 283 | if __name__ == '__main__': 284 | main() 285 | -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | from importlib import import_module 6 | import nltk 7 | import numpy as np 8 | from tqdm import tqdm 9 | from functools import partial 10 | from collections import defaultdict 11 | 12 | from sklearn.metrics import classification_report, f1_score, confusion_matrix 13 | 14 | import torch 15 | from torch import Tensor 16 | from transformers.trainer_utils import set_seed 17 | from accelerate import Accelerator 18 | 19 | from utils.building_utils import boolean_string, 
build_model 20 | from utils.dataloader_utils import _norm, BatchDataLoader 21 | 22 | 23 | def infer(args): 24 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 25 | datefmt = '%m/%d/%Y %H:%M:%S', 26 | level = logging.INFO) 27 | logger = logging.getLogger(__name__) 28 | 29 | assert len(args.infer_data_paths) == len(args.infer_names) 30 | 31 | accelerator = Accelerator() 32 | set_seed(args.seed) 33 | n_gpu = torch.cuda.device_count() 34 | 35 | #logger.info('Input Argument Information') 36 | #args_dict = vars(args) 37 | #for a in args_dict: 38 | # logger.info('%-28s %s' % (a, args_dict[a])) 39 | 40 | ######################################################################### 41 | # Prepare Data Set 42 | ########################################################################## 43 | 44 | toker, model = build_model(args, checkpoint=args.load_checkpoint) 45 | model = accelerator.prepare(model) 46 | if n_gpu > 1: 47 | logger.info('use `torch.nn.DataParallel`') 48 | model = torch.nn.DataParallel(model) 49 | model.eval() 50 | 51 | model_parameters = list(filter(lambda p: p.requires_grad, model.parameters())) 52 | total_params = sum([np.prod(p.size()) for p in model_parameters]) 53 | logger.info('Number of parameters = {}'.format(total_params)) 54 | 55 | ######################################################################### 56 | # Inference ! 57 | ########################################################################## 58 | 59 | for infer_data_path, infer_name in zip(args.infer_data_paths, args.infer_names): 60 | if not os.path.exists(infer_data_path): 61 | logger.info(f'file {infer_data_path} does not exist') 62 | continue 63 | set_seed(args.seed) 64 | 65 | if args.save_path is not None: 66 | os.makedirs(args.save_path, exist_ok=True) 67 | save_path = f'{args.save_path}/{infer_name}' 68 | os.makedirs(save_path, exist_ok=True) 69 | else: 70 | assert args.load_checkpoint is not None 71 | checkpoint_dir_path = '/'.join(args.load_checkpoint.split('/')[:-1]) 72 | save_path = f'{checkpoint_dir_path}/{infer_name}' 73 | os.makedirs(save_path, exist_ok=True) 74 | 75 | gen_file = os.path.join(save_path, 'gen.txt') 76 | gen_exist = os.path.exists(gen_file) 77 | metric_file = os.path.join(save_path, f'metric.json') 78 | metric_exist = os.path.exists(metric_file) 79 | 80 | if gen_exist and metric_exist: 81 | print('all outputs already exist') 82 | continue 83 | elif gen_exist and args.only_infer: 84 | print('gen.txt already exists and metrics are not required') 85 | continue 86 | 87 | metric_res = {} 88 | collate_fn = getattr(import_module('collators.'
+ args.collator_name), 'collate_fn') 89 | infer_dataloader = BatchDataLoader( 90 | data_path=infer_data_path, 91 | batch_size=args.batch_size, 92 | collate_fn=partial( 93 | collate_fn, 94 | toker=toker, 95 | max_input_length=args.max_input_length, 96 | infer=True, 97 | ), 98 | shuffle=False, 99 | ) 100 | infer_dataloader = accelerator.prepare(infer_dataloader) 101 | 102 | res = [] 103 | other_res = defaultdict(list) 104 | for batch in tqdm(infer_dataloader, total=len(infer_dataloader), desc='inference', dynamic_ncols=True): 105 | labels = batch.pop('labels') 106 | batch['inference'] = True 107 | with torch.no_grad(): 108 | encoded_info = model(**batch) 109 | if not args.only_infer: 110 | encoded_info['labels'] = labels 111 | 112 | for key in ['labels', 'preds', 'preds_top3', 'preds_dist']: 113 | if key in encoded_info: 114 | encoded_info[key] = encoded_info[key].tolist() 115 | other_res[key].extend(encoded_info[key]) 116 | 117 | for idx in range(labels.size(0)): 118 | tmp_res_to_append = {} 119 | for key in ['labels', 'preds', 'preds_top3', 'preds_dist']: 120 | if key in encoded_info: 121 | tmp_res_to_append[key.replace('labels', 'label').replace('preds', 'pred')] = encoded_info[key][idx] if 'dist' not in key else ' '.join(map(str, encoded_info[key][idx])) 122 | 123 | res.append(tmp_res_to_append) 124 | 125 | with open(os.path.join(save_path, f'gen.txt'), 'w') as f: 126 | for line in res: 127 | f.write(json.dumps(line, ensure_ascii=False) + '\n') 128 | 129 | if not args.only_infer: 130 | metric_res_list = {} 131 | labels = np.array(other_res['labels'], dtype=int) 132 | preds = np.array(other_res['preds'], dtype=int) 133 | print(f'classification_report\t{save_path}\n', classification_report(labels, preds, digits=4)) 134 | with open(os.path.join(save_path, f'confusion_matrix.json'), 'w') as f: 135 | json.dump(confusion_matrix(labels, preds).tolist(), f) 136 | #print('confusion_matrix\n', confusion_matrix(labels, preds)) 137 | 138 | metric_res['acc'] = np.mean(labels == preds) 139 | metric_res['f1_micro'] = f1_score(labels, preds, average='micro') 140 | metric_res['f1_macro'] = f1_score(labels, preds, average='macro') 141 | metric_res_list['acc'] = (labels == preds).astype(int).tolist() 142 | 143 | if 'preds_top3' in other_res: 144 | preds_top3 = np.array(other_res['preds_top3'], dtype=int) 145 | metric_res['acc_top3'] = np.mean(np.sum((labels.reshape(-1, 1) - preds_top3) == 0, axis=-1) != 0) 146 | metric_res_list['acc_top3'] = (np.sum((labels.reshape(-1, 1) - preds_top3) == 0, axis=-1) != 0).astype(int).tolist() 147 | 148 | with open(os.path.join(save_path, f'metric.json'), 'w') as f: 149 | json.dump(metric_res, f, ensure_ascii=False, indent=2, sort_keys=True) 150 | 151 | if metric_res_list is not None: 152 | with open(os.path.join(save_path, f'metric_list.json'), 'w') as f: 153 | json.dump(metric_res_list, f) 154 | 155 | 156 | def main(): 157 | ######################################################################### 158 | # Prepare Parser 159 | ########################################################################## 160 | 161 | parser = argparse.ArgumentParser() 162 | parser.add_argument('--collator_name', type=str, required=True) 163 | parser.add_argument('--model_name', type=str, required=True) 164 | parser.add_argument('--pretrained_model_path', type=str, required=True) 165 | parser.add_argument('--save_path', type=str, required=True) 166 | parser.add_argument('--model_args', type=str, nargs='+', default=[]) 167 | 168 | parser.add_argument('--infer_data_paths', type=str, nargs='+', 
required=True) 169 | parser.add_argument('--infer_names', type=str, nargs='+', required=True) 170 | parser.add_argument("--max_input_length", type=int, default=256) 171 | parser.add_argument('--only_infer', action='store_true', help='do not conduct evaluations') 172 | 173 | parser.add_argument("--seed", type=int, default=42) 174 | parser.add_argument("--load_checkpoint", type=str, default=None) 175 | parser.add_argument("--batch_size", type=int, default=16) 176 | 177 | args = parser.parse_args() 178 | infer(args) 179 | 180 | 181 | if __name__ == '__main__': 182 | main() 183 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import json 4 | import warnings 5 | import numpy as np 6 | import nltk 7 | import copy 8 | from typing import List 9 | from collections import Counter 10 | from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction 11 | 12 | 13 | def my_lcs(string, sub): 14 | """ 15 | Calculates longest common subsequence for a pair of tokenized strings 16 | :param string : list of str : tokens from a string split using whitespace 17 | :param sub : list of str : shorter string, also split using whitespace 18 | :returns: length (list of int): length of the longest common subsequence between the two strings 19 | 20 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS 21 | """ 22 | if len(string) < len(sub): 23 | sub, string = string, sub 24 | 25 | lengths = [[0 for _ in range(0, len(sub)+1)] for _ in range(0, len(string)+1)] 26 | 27 | for j in range(1,len(sub)+1): 28 | for i in range(1, len(string) + 1): 29 | if string[i-1] == sub[j-1]: 30 | lengths[i][j] = lengths[i-1][j-1] + 1 31 | else: 32 | lengths[i][j] = max(lengths[i-1][j], lengths[i][j-1]) 33 | 34 | return lengths[len(string)][len(sub)] 35 | 36 | 37 | class Metrics(object): 38 | def __init__(self, toker=None): 39 | self.refs = [] 40 | self.hyps = [] 41 | self.toker = toker 42 | 43 | def forword(self, refs: str, hyp: str, lower=False, chinese=False): # TODO: only applicable to English 44 | if not chinese: 45 | self.refs.append([nltk.word_tokenize(e.lower() if lower else e) for e in refs]) 46 | self.hyps.append(nltk.word_tokenize(hyp.lower() if lower else hyp)) 47 | else: 48 | self.refs.append([self.toker.tokenize(e) for e in refs]) 49 | self.hyps.append(self.toker.tokenize(hyp)) 50 | 51 | def set_refs(self, refs): 52 | self.refs = copy.deepcopy(refs) 53 | 54 | def set_hyps(self, hyps): 55 | self.hyps = copy.deepcopy(hyps) 56 | 57 | def calculate_bleu_k(self, k): 58 | weights = [1. / k] * k + (4 - k) * [0.] 59 | try: 60 | bleu = corpus_bleu(self.refs, self.hyps, weights=weights, 61 | smoothing_function=SmoothingFunction().method3) 62 | except ZeroDivisionError as _: 63 | warnings.warn('the bleu is invalid') 64 | bleu = 0. 65 | return bleu 66 | 67 | def calculate_distinct_k(self, k): 68 | ngrams = [] 69 | for sen in self.hyps: 70 | tmp_ngrams = list(zip(*[sen[i:] for i in range(k)])) 71 | ngrams.extend(tmp_ngrams) 72 | ngrams = Counter(ngrams) 73 | dist = len(ngrams) / max(sum(ngrams.values()), 1e-10) 74 | return dist 75 | 76 | """ 77 | def calculate_repetition_k(self, k): 78 | repetitions = [] 79 | for sen in self.hyps: 80 | tmp_ngrams = list(zip(*[sen[i:] for i in range(k)])) 81 | tmp_ngrams = Counter(tmp_ngrams) 82 | tmp_dist = len(tmp_ngrams) / max(sum(tmp_ngrams.values()), 1e-10) 83 | repetitions.append(1. 
- tmp_dist) 84 | return np.mean(repetitions), repetitions 85 | """ 86 | 87 | def calculate_repetition_k(self, k): 88 | count = [0, 0] 89 | for sen in self.hyps: 90 | tmp_ngrams = list(zip(*[sen[i:] for i in range(k)])) 91 | count[1] += len(tmp_ngrams) 92 | tmp_ngrams = Counter(tmp_ngrams) 93 | count[0] += len(tmp_ngrams) 94 | return 1. - count[0] / count[1] 95 | 96 | def calculate_unigram_f1(self): 97 | f1_scores = [] 98 | for hyp, refs in zip(self.hyps, self.refs): 99 | scores = [] 100 | for ref in refs: 101 | cross = Counter(hyp) & Counter(ref) 102 | cross = sum(cross.values()) 103 | p = cross / max(len(hyp), 1e-10) 104 | r = cross / max(len(ref), 1e-10) 105 | f1 = 2 * p * r / max(p + r, 1e-10) 106 | scores.append(f1) 107 | f1_scores.append(max(scores)) 108 | return np.mean(f1_scores), f1_scores 109 | 110 | def calculate_rouge_k(self, k): 111 | scores = [] 112 | for hyp, refs in zip(self.hyps, self.refs): 113 | rec = [] 114 | hyp_kgrams = Counter(zip(*(hyp[i:] for i in range(k)))) 115 | for ref in refs: 116 | ref_kgrams = Counter(zip(*(ref[i:] for i in range(k)))) 117 | cross_kgrams = hyp_kgrams & ref_kgrams 118 | rec.append(sum(cross_kgrams.values()) / max(sum(ref_kgrams.values()), 1e-10)) 119 | score = max(rec) 120 | scores.append(score) 121 | return np.mean(scores), scores 122 | 123 | def calculate_rouge_l(self, beta=1.2): 124 | scores = [] 125 | for hyp, refs in zip(self.hyps, self.refs): 126 | prec = [] 127 | rec = [] 128 | for ref in refs: 129 | lcs = my_lcs(ref, hyp) 130 | prec.append(lcs / max(len(hyp), 1e-10)) 131 | rec.append(lcs / max(len(ref), 1e-10)) 132 | prec_max = max(prec) 133 | rec_max = max(rec) 134 | if prec_max != 0 and rec_max !=0: 135 | score = ((1 + beta**2) * prec_max * rec_max)/float(rec_max + beta**2 * prec_max) 136 | else: 137 | score = 0.0 138 | scores.append(score) 139 | return np.mean(scores), scores 140 | 141 | def close(self): 142 | result = {} 143 | result_list = {} 144 | 145 | result['length'] = np.mean(list(map(len, self.hyps))) 146 | 147 | for k in range(1, 5): 148 | bleu = self.calculate_bleu_k(k) 149 | result[f'bleu-{k}'] = 100 * bleu 150 | 151 | for k in range(1, 4): 152 | dist = self.calculate_distinct_k(k) 153 | result[f'dist-{k}'] = 100 * dist 154 | 155 | for k in range(1, 3): 156 | rouge, scores = self.calculate_rouge_k(k) 157 | result[f'rouge-{k}'] = 100 * rouge 158 | result_list[f'rouge-{k}'] = scores 159 | 160 | for k in range(2, 5): 161 | rep = self.calculate_repetition_k(k) 162 | result[f'rep-{k}'] = 100 * rep 163 | result['diversity'] = (1. - result['rep-2'] / 100) * (1. - result['rep-3'] / 100) * (1. 
- result['rep-4'] / 100) 164 | 165 | f1, scores = self.calculate_unigram_f1() 166 | result['f1'] = 100 * f1 167 | result_list['f1-l'] = scores 168 | 169 | rl, scores = self.calculate_rouge_l() 170 | result['rouge-l'] = 100 * rl 171 | result_list['rouge-l'] = scores 172 | 173 | return result, result_list 174 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chujiezheng/Click/494e90434e76432a4b1da65fb2efe1dd992ebeba/models/__init__.py
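For orientation, a minimal usage sketch of the `Metrics` class from `metrics/__init__.py` above (toy data; requires nltk with its `punkt` tokenizer data downloaded; note the accumulator method is spelled `forword` throughout the repo):

```python
from metrics import Metrics

metrics = Metrics()
pairs = [
    (["the cat sat on the mat"], "a cat sat on the mat"),
    (["hello there general kenobi"], "hello there"),
]
for refs, hyp in pairs:
    metrics.forword(refs, hyp)  # refs is a list of references, hyp a single hypothesis

result, result_list = metrics.close()  # corpus-level metrics and per-sample score lists
print(result['bleu-1'], result['dist-1'], result['rouge-l'])
```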
-------------------------------------------------------------------------------- /models/blender.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from transformers.models.blenderbot.modeling_blenderbot import (BlenderbotConfig, BlenderbotForConditionalGeneration,) 7 | from utils.model_utils import BaseModel 8 | 9 | 10 | class Model(BaseModel, BlenderbotForConditionalGeneration): 11 | def __init__(self, config: BlenderbotConfig): 12 | super().__init__(config) 13 | 14 | def forward( 15 | self, 16 | input_ids=None, 17 | attention_mask=None, 18 | decoder_input_ids=None, 19 | encoder_outputs=None, 20 | past_key_values=None, 21 | labels=None, 22 | use_cache=None, 23 | return_dict=True, 24 | validation=False, 25 | **kwargs 26 | ): 27 | assert self.toker is not None 28 | assert not (self.training and validation) 29 | if self.training or validation: 30 | assert labels is not None 31 | use_cache = False 32 | 33 | outputs = super().forward( 34 | input_ids=input_ids, 35 | attention_mask=attention_mask, 36 | decoder_input_ids=decoder_input_ids, 37 | encoder_outputs=encoder_outputs, 38 | past_key_values=past_key_values, 39 | labels=None, 40 | return_dict=return_dict, 41 | use_cache=use_cache, 42 | **kwargs 43 | ) 44 | 45 | lm_logits = outputs.logits 46 | masked_lm_loss = None 47 | if self.training or validation: 48 | loss = F.cross_entropy(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1), reduction='none') 49 | loss = loss.view(labels.size(0), labels.size(1)) 50 | label_mask = labels.ne(-100).type_as(loss) 51 | label_size = label_mask.sum(1).type_as(loss) 52 | masked_lm_loss = loss.sum() / torch.clamp(label_size.sum(), min=1e-5) 53 | ppl_value = masked_lm_loss.exp() 54 | 55 | if validation: 56 | preds = torch.argmax(lm_logits, dim=-1) 57 | acc = ((preds == labels) * label_mask).type_as(loss) # [bs, length] 58 | occurrence = torch.tril(preds.unsqueeze(-1) == labels.unsqueeze(-2), -1) # [bs, length, length] 59 | rep = (occurrence.sum(dim=-1) > 0).type_as(loss) * label_mask # [bs, length] 60 | wrep = rep * (1.
- acc) 61 | 62 | outputs.loss = masked_lm_loss 63 | 64 | if not self.training and not validation: # inference 65 | return outputs 66 | elif self.training: # training 67 | res = {'all': masked_lm_loss, 'ppl': ppl_value, } 68 | return res 69 | else: # validation 70 | assert not self.training 71 | return loss, label_size, acc, rep, wrep 72 | -------------------------------------------------------------------------------- /models/blender_contrast_05.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from transformers.models.blenderbot.modeling_blenderbot import (BlenderbotConfig, BlenderbotForConditionalGeneration,) 7 | from utils.model_utils import BaseModel 8 | from transformers.modeling_outputs import (BaseModelOutput, Seq2SeqLMOutput) 9 | 10 | GAMMA = 0.5 11 | 12 | 13 | class Model(BaseModel, BlenderbotForConditionalGeneration): 14 | def __init__(self, config: BlenderbotConfig, alpha): 15 | super().__init__(config) 16 | self.alpha = float(alpha) 17 | 18 | def forward( 19 | self, 20 | input_ids=None, 21 | attention_mask=None, 22 | decoder_input_ids=None, 23 | #neg_input_ids=None, 24 | #neg_attention_mask=None, 25 | pos_decoder_input_ids=None, 26 | neg_decoder_input_ids=None, 27 | encoder_outputs=None, 28 | past_key_values=None, 29 | labels=None, 30 | pos_labels=None, 31 | neg_labels=None, 32 | selected_indices=None, 33 | use_cache=None, 34 | return_dict=True, 35 | validation=False, 36 | **kwargs 37 | ): 38 | assert self.toker is not None 39 | assert self.training and not validation 40 | assert labels is not None 41 | use_cache = False 42 | 43 | encoder_outputs: BaseModelOutput = self.model.encoder( 44 | input_ids=input_ids, 45 | attention_mask=attention_mask, 46 | return_dict=return_dict, 47 | ) 48 | outputs: Seq2SeqLMOutput = super().forward( 49 | input_ids=input_ids, 50 | attention_mask=attention_mask, 51 | decoder_input_ids=decoder_input_ids, 52 | encoder_outputs=encoder_outputs, 53 | past_key_values=past_key_values, 54 | labels=None, 55 | return_dict=return_dict, 56 | use_cache=use_cache, 57 | **kwargs 58 | ) 59 | 60 | lm_logits = outputs.logits 61 | loss = F.cross_entropy(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1), reduction='none') 62 | loss = loss.view(labels.size(0), labels.size(1)) 63 | label_size = labels.ne(-100).sum(1).type_as(loss) 64 | masked_lm_loss = loss.sum() / torch.clamp(label_size.sum(), min=1e-5) 65 | ppl_value = masked_lm_loss.exp() 66 | 67 | outputs.loss = masked_lm_loss 68 | 69 | if neg_decoder_input_ids is None: 70 | res = {'all': masked_lm_loss, 'ppl': ppl_value, } 71 | return res 72 | 73 | # alpha 74 | encoder_outputs.last_hidden_state = torch.index_select(encoder_outputs.last_hidden_state, 0, selected_indices) 75 | attention_mask = torch.index_select(attention_mask, 0, selected_indices) 76 | 77 | pos_outputs = super().forward( 78 | input_ids=None, 79 | attention_mask=None, 80 | decoder_input_ids=pos_decoder_input_ids, 81 | encoder_outputs=encoder_outputs, 82 | labels=None, 83 | return_dict=return_dict, 84 | use_cache=use_cache, 85 | **kwargs 86 | ) 87 | pos_lm_logits = pos_outputs.logits 88 | pos_loss = F.cross_entropy(pos_lm_logits.view(-1, pos_lm_logits.size(-1)), pos_labels.view(-1), reduction='none') 89 | pos_loss = pos_loss.view(pos_labels.size(0), pos_labels.size(1)).sum(-1) 90 | 91 | neg_outputs = super().forward( 92 | input_ids=None, 93 | attention_mask=None, 94 | decoder_input_ids=neg_decoder_input_ids, 
95 | encoder_outputs=encoder_outputs, 96 | labels=None, 97 | return_dict=return_dict, 98 | use_cache=use_cache, 99 | **kwargs 100 | ) 101 | neg_lm_logits = neg_outputs.logits 102 | neg_loss = F.cross_entropy(neg_lm_logits.view(-1, neg_lm_logits.size(-1)), neg_labels.view(-1), reduction='none') 103 | neg_loss = neg_loss.view(neg_labels.size(0), neg_labels.size(1)).sum(-1) 104 | 105 | # we have pos_loss < neg_loss 106 | loss1 = torch.clamp(self.alpha + pos_loss - neg_loss, min=0.) 107 | loss1 = loss1.mean() 108 | 109 | res = {'all': masked_lm_loss + GAMMA * loss1, 'ppl': ppl_value, 'loss': masked_lm_loss, 'loss1': loss1, } 110 | return res 111 | -------------------------------------------------------------------------------- /models/blender_cringe_02.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import CrossEntropyLoss 7 | from torch.distributions import Categorical 8 | from transformers.models.blenderbot.modeling_blenderbot import (BlenderbotConfig, BlenderbotForConditionalGeneration,) 9 | from utils.model_utils import BaseModel 10 | 11 | GAMMA = 0.2 12 | 13 | 14 | class CringeLoss(CrossEntropyLoss): 15 | def __init__(self, k=5, **kwargs): 16 | super().__init__(**kwargs) 17 | self.k = k 18 | 19 | def __call__(self, x, y, classifier_labels, **kwargs): 20 | 21 | # Compute the CrossEntropy loss for the positive labels and mask 22 | # with classifier labels to not train with negative feedback (0) 23 | ce_loss = super().__call__(x, y, **kwargs) 24 | ce_loss *= classifier_labels 25 | 26 | # compute the contrastive loss part for the negative labels 27 | # first, get the positives as the top predictions != target 28 | preds = torch.topk(x, k=self.k + 1, axis=-1) 29 | y_rep = y.unsqueeze(1).repeat(1, self.k + 1) 30 | logits = preds.values - (preds.indices == y_rep) * 1e10 31 | 32 | # if the positive is not in the first k predictions, mask out 33 | # the final (k+1)'s logit 34 | prediction_mask = torch.cat( 35 | (torch.zeros_like(logits)[:, :-1], 36 | torch.abs((preds.indices == y_rep).sum(-1).unsqueeze(1) - 1),), 37 | 1,) 38 | logits -= prediction_mask * 1e10 39 | 40 | # Sample from the categorical distribution of the top-k predictions 41 | # (with the label masked out). 42 | preds_dist = Categorical(logits=logits) 43 | idx_sample = preds_dist.sample() 44 | sample_preds_values = preds.values[torch.arange(x.shape[0]), idx_sample] 45 | 46 | # Concatenate the logits of the preds with the negative label's logits. 47 | x_negative_target = x[torch.arange(x.shape[0]), y] 48 | x_cr = torch.concat( 49 | [x_negative_target.unsqueeze(1), sample_preds_values.unsqueeze(1)], -1) 50 | 51 | # Create the y's for the x_cr (the correct label is always index 1). 52 | y_cr = torch.ones(y.shape).type(y.dtype).to(x_cr.device) 53 | 54 | # Compute the Cringe loss as cross entropy loss between x_cr, y_cr 55 | # and mask out the positive labels. 
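# (annotation: the masking mentioned above happens in Model.forward — cr_loss
# is kept only where cls_labels == 0, i.e. on negative examples, mirroring how
# ce_loss was zeroed for them via `ce_loss *= classifier_labels`; unlike
# ce_loss, which inherits ignore_index=-100, cr_loss is computed at every
# position, padding included.)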
56 | cr_loss = super().__call__(x_cr, y_cr, **kwargs) 57 | 58 | return ce_loss, cr_loss 59 | 60 | 61 | class Model(BaseModel, BlenderbotForConditionalGeneration): 62 | def __init__(self, config: BlenderbotConfig): 63 | super().__init__(config) 64 | 65 | def forward( 66 | self, 67 | input_ids=None, 68 | attention_mask=None, 69 | decoder_input_ids=None, 70 | encoder_outputs=None, 71 | past_key_values=None, 72 | labels=None, 73 | cls_labels=None, 74 | use_cache=None, 75 | return_dict=True, 76 | validation=False, 77 | **kwargs 78 | ): 79 | assert self.toker is not None 80 | assert not (self.training and validation) 81 | if self.training or validation: 82 | assert labels is not None 83 | assert cls_labels is not None 84 | cls_labels = cls_labels.unsqueeze(-1).expand_as(labels).contiguous() 85 | use_cache = False 86 | 87 | outputs = super().forward( 88 | input_ids=input_ids, 89 | attention_mask=attention_mask, 90 | decoder_input_ids=decoder_input_ids, 91 | encoder_outputs=encoder_outputs, 92 | past_key_values=past_key_values, 93 | labels=None, 94 | return_dict=return_dict, 95 | use_cache=use_cache, 96 | **kwargs 97 | ) 98 | 99 | lm_logits = outputs.logits 100 | masked_lm_loss = None 101 | if self.training or validation: 102 | loss, negative_loss = CringeLoss(reduction='none')(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1), cls_labels.view(-1)) 103 | loss = loss.view(labels.size(0), labels.size(1)) * cls_labels 104 | label_size = (labels.ne(-100) * cls_labels).sum(1).type_as(loss) 105 | masked_lm_loss = loss.sum() / torch.clamp(label_size.sum(), min=1e-5) 106 | ppl_value = masked_lm_loss.exp() 107 | 108 | negative_loss = negative_loss.view(labels.size(0), labels.size(1)) * (1. - cls_labels) 109 | negative_label_size = (labels.ne(-100) * (1. - cls_labels)).sum(1).type_as(loss) 110 | negative_lm_loss = negative_loss.sum() / negative_label_size.sum() 111 | 112 | outputs.loss = masked_lm_loss 113 | 114 | if not self.training and not validation: # inference 115 | return outputs 116 | elif self.training: # training 117 | assert not validation 118 | res = {'all': masked_lm_loss + GAMMA * negative_lm_loss, 'ppl': ppl_value, } 119 | return res 120 | else: # validation 121 | assert not self.training 122 | return loss, label_size 123 | -------------------------------------------------------------------------------- /models/blender_director_02.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from transformers.models.blenderbot.modeling_blenderbot import (BlenderbotConfig, BlenderbotForConditionalGeneration,) 7 | from utils.model_utils import BaseModel 8 | from transformers.modeling_outputs import ( 9 | Seq2SeqLMOutput, 10 | ) 11 | 12 | GAMMA = 0.2 13 | 14 | 15 | class Model(BaseModel, BlenderbotForConditionalGeneration): 16 | def __init__(self, config: BlenderbotConfig, alpha='1.0'): 17 | super().__init__(config) 18 | self.cls_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) 19 | self._init_weights(self.cls_head) 20 | self.alpha = float(alpha) 21 | 22 | def forward( 23 | self, 24 | input_ids=None, 25 | attention_mask=None, 26 | decoder_input_ids=None, 27 | encoder_outputs=None, 28 | past_key_values=None, 29 | labels=None, 30 | cls_labels=None, 31 | use_cache=None, 32 | return_dict=True, 33 | validation=False, 34 | **kwargs 35 | ): 36 | assert self.toker is not None 37 | assert not (self.training and validation) 38 | if 
self.training: 39 | assert labels is not None 40 | assert cls_labels is not None 41 | cls_labels = cls_labels.unsqueeze(-1).expand(*labels.size()).contiguous() 42 | use_cache = False 43 | 44 | outputs = self.model( 45 | input_ids, 46 | attention_mask=attention_mask, 47 | decoder_input_ids=decoder_input_ids, 48 | encoder_outputs=encoder_outputs, 49 | past_key_values=past_key_values, 50 | use_cache=use_cache, 51 | return_dict=return_dict, 52 | ) 53 | lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias 54 | cls_logits = self.cls_head(outputs[0]) 55 | 56 | masked_lm_loss = None 57 | if self.training or validation: 58 | loss = F.cross_entropy(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1), reduction='none') 59 | loss = loss.view(labels.size(0), labels.size(1)) 60 | label_mask = labels.ne(-100).type_as(loss) 61 | label_size = label_mask.sum(1).type_as(loss) 62 | masked_lm_loss = loss.sum() / torch.clamp(label_size.sum(), min=1e-5) 63 | ppl_value = masked_lm_loss.exp() 64 | 65 | if validation: 66 | preds = torch.argmax(lm_logits, dim=-1) 67 | acc = ((preds == labels) * label_mask).type_as(loss) # [bs, length] 68 | occurrence = torch.tril(preds.unsqueeze(-1) == labels.unsqueeze(-2), -1) # [bs, length, length] 69 | rep = (occurrence.sum(dim=-1) > 0).type_as(loss) * label_mask # [bs, length] 70 | wrep = rep * (1. - acc) 71 | 72 | if self.training: 73 | cls_logits = cls_logits.view(-1, self.model.shared.num_embeddings) 74 | cls_tgt_logits = cls_logits[range(cls_logits.size(0)), labels.view(-1)] # [bs * tgt_len] 75 | cls_loss = F.binary_cross_entropy_with_logits(cls_tgt_logits, cls_labels.view(-1), reduction='none') 76 | cls_loss = cls_loss.view(labels.size(0), labels.size(1)) * labels.ne(-100) # [bs, tgt_len] 77 | cls_loss = cls_loss.sum() / labels.ne(-100).sum() 78 | 79 | else: 80 | cls_logits = cls_logits.view(lm_logits.size(0), -1, self.model.shared.num_embeddings) 81 | lm_logits = torch.log_softmax(lm_logits, dim=-1) + torch.sigmoid(cls_logits) * self.alpha 82 | 83 | outputs = Seq2SeqLMOutput( 84 | loss=masked_lm_loss, 85 | logits=lm_logits, 86 | past_key_values=outputs.past_key_values, 87 | decoder_hidden_states=outputs.decoder_hidden_states, 88 | decoder_attentions=outputs.decoder_attentions, 89 | cross_attentions=outputs.cross_attentions, 90 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, 91 | encoder_hidden_states=outputs.encoder_hidden_states, 92 | encoder_attentions=outputs.encoder_attentions, 93 | ) 94 | 95 | if not self.training and not validation: # inference 96 | return outputs 97 | elif self.training: # training 98 | assert not validation 99 | res = {'all': masked_lm_loss + GAMMA * cls_loss, 'ppl': ppl_value, 'cls_loss': cls_loss, } 100 | return res 101 | else: # validation 102 | assert not self.training 103 | return loss, label_size, acc, rep, wrep 104 | -------------------------------------------------------------------------------- /models/blender_gedi.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import copy 4 | import torch 5 | import torch.nn as nn 6 | import torch.distributed as dist 7 | import torch.nn.functional as F 8 | from models.blender_dexperts import Model as DExperts 9 | from transformers.models.blenderbot.modeling_blenderbot import (BlenderbotConfig, BlenderbotForConditionalGeneration,) 10 | 11 | import inspect 12 | import warnings 13 | from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union 14 | 15 | from transformers.generation.beam_constraints 
import Constraint 16 | from transformers.generation.logits_process import ( 17 | LogitsProcessorList, 18 | ) 19 | from transformers.generation.stopping_criteria import ( 20 | StoppingCriteriaList, 21 | validate_stopping_criteria, 22 | ) 23 | from transformers.utils import ModelOutput, logging 24 | from transformers.generation.utils import SampleOutput, SampleDecoderOnlyOutput, SampleEncoderDecoderOutput 25 | 26 | 27 | logger = logging.get_logger(__name__) 28 | 29 | class Model(DExperts): 30 | def __init__(self, config: BlenderbotConfig, expert_path, antiexpert_path, alpha): 31 | super().__init__(config, expert_path, antiexpert_path, alpha) 32 | 33 | def sample( 34 | self, 35 | input_ids: torch.LongTensor, 36 | expert_input_ids: torch.LongTensor, 37 | antiexpert_input_ids: torch.LongTensor, 38 | num_return_sequences: int = 1, 39 | logits_processor: Optional[LogitsProcessorList] = None, 40 | stopping_criteria: Optional[StoppingCriteriaList] = None, 41 | logits_warper: Optional[LogitsProcessorList] = None, 42 | max_length: Optional[int] = None, 43 | pad_token_id: Optional[int] = None, 44 | eos_token_id: Optional[int] = None, 45 | output_attentions: Optional[bool] = None, 46 | output_hidden_states: Optional[bool] = None, 47 | output_scores: Optional[bool] = None, 48 | return_dict_in_generate: Optional[bool] = None, 49 | synced_gpus: Optional[bool] = False, 50 | model_kwargs: dict = None, 51 | expert_kwargs: dict = None, 52 | antiexpert_kwargs: dict = None, 53 | ) -> Union[SampleOutput, torch.LongTensor]: 54 | 55 | # init values 56 | logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() 57 | stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() 58 | if max_length is not None: 59 | warnings.warn( 60 | "`max_length` is deprecated in this function, use" 61 | " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", 62 | UserWarning, 63 | ) 64 | stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) 65 | logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() 66 | pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id 67 | eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id 68 | output_scores = output_scores if output_scores is not None else self.config.output_scores 69 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 70 | output_hidden_states = ( 71 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 72 | ) 73 | return_dict_in_generate = ( 74 | return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate 75 | ) 76 | 77 | # init attention / hidden states / scores tuples 78 | scores = () if (return_dict_in_generate and output_scores) else None 79 | decoder_attentions = () if (return_dict_in_generate and output_attentions) else None 80 | cross_attentions = () if (return_dict_in_generate and output_attentions) else None 81 | decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None 82 | 83 | # if model is an encoder-decoder, retrieve encoder attention weights and hidden states 84 | if return_dict_in_generate and self.config.is_encoder_decoder: 85 | encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None 86 | encoder_hidden_states = ( 87 | 
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None 88 | ) 89 | 90 | # keep track of which sequences are already finished 91 | unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) 92 | cur_len = input_ids.shape[-1] 93 | 94 | assert expert_input_ids.shape[0] == antiexpert_input_ids.shape[0] 95 | assert antiexpert_input_ids.shape[0] % input_ids.shape[0] == 0 96 | batch_size = input_ids.shape[0] 97 | multiple = antiexpert_input_ids.shape[0] // input_ids.shape[0] 98 | 99 | expert_scores = input_ids.new_ones((batch_size, ), dtype=torch.float) 100 | antiexpert_scores = input_ids.new_ones((batch_size, ), dtype=torch.float) 101 | 102 | this_peer_finished = False # used by synced_gpus only 103 | # auto-regressive generation 104 | while True: 105 | # prepare model inputs 106 | model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) 107 | expert_inputs = self.expert.prepare_inputs_for_generation(expert_input_ids, **expert_kwargs) 108 | antiexpert_inputs = self.antiexpert.prepare_inputs_for_generation(antiexpert_input_ids, **antiexpert_kwargs) 109 | 110 | # forward pass to get next token 111 | outputs = self( 112 | **model_inputs, 113 | return_dict=True, 114 | output_attentions=output_attentions, 115 | output_hidden_states=output_hidden_states, 116 | ) 117 | expert_outputs = self.expert( 118 | **expert_inputs, 119 | return_dict=True, 120 | output_attentions=output_attentions, 121 | output_hidden_states=output_hidden_states, 122 | ) 123 | antiexpert_outputs = self.antiexpert( 124 | **antiexpert_inputs, 125 | return_dict=True, 126 | output_attentions=output_attentions, 127 | output_hidden_states=output_hidden_states, 128 | ) 129 | 130 | next_token_logits = outputs.logits[:, -1, :] 131 | expert_next_token_logits = expert_outputs.logits[:, -1, :] 132 | antiexpert_next_token_logits = antiexpert_outputs.logits[:, -1, :] 133 | 134 | if multiple > 1: 135 | raise ValueError 136 | # reshape and average 137 | aux_attention_mask = antiexpert_inputs['attention_mask'].view(batch_size // num_return_sequences, multiple, num_return_sequences, -1) 138 | aux_attention_mask = (aux_attention_mask.sum(dim=-1, keepdims=True) > 0).type_as(aux_attention_mask) 139 | 140 | expert_next_token_logits = expert_next_token_logits.view(batch_size // num_return_sequences, multiple, num_return_sequences, -1) 141 | expert_next_token_logits = (expert_next_token_logits * aux_attention_mask).sum(dim=1) / aux_attention_mask.sum(dim=1) 142 | expert_next_token_logits = expert_next_token_logits.view(batch_size, -1) 143 | 144 | antiexpert_next_token_logits = antiexpert_next_token_logits.view(batch_size // num_return_sequences, multiple, num_return_sequences, -1) 145 | antiexpert_next_token_logits = (antiexpert_next_token_logits * aux_attention_mask).sum(dim=1) / aux_attention_mask.sum(dim=1) 146 | antiexpert_next_token_logits = antiexpert_next_token_logits.view(batch_size, -1) 147 | 148 | # apply modification 149 | expert_next_token_prob = expert_scores.unsqueeze(-1) * torch.softmax(expert_next_token_logits, dim=-1) / cur_len 150 | antiexpert_next_token_prob = antiexpert_scores.unsqueeze(-1) * torch.softmax(antiexpert_next_token_logits, dim=-1) / cur_len 151 | next_token_logits = next_token_logits + self.alpha * ( 152 | expert_next_token_logits - torch.clamp(expert_next_token_prob + antiexpert_next_token_prob, min=1e-5).log()) 153 | 154 | # pre-process distribution 155 | next_token_scores = logits_processor(input_ids, next_token_logits) 156 | next_token_scores = 
logits_warper(input_ids, next_token_scores) 157 | 158 | # Store scores, attentions and hidden_states when required 159 | if return_dict_in_generate: 160 | if output_scores: 161 | scores += (next_token_scores,) 162 | if output_attentions: 163 | decoder_attentions += ( 164 | (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) 165 | ) 166 | if self.config.is_encoder_decoder: 167 | cross_attentions += (outputs.cross_attentions,) 168 | 169 | if output_hidden_states: 170 | decoder_hidden_states += ( 171 | (outputs.decoder_hidden_states,) 172 | if self.config.is_encoder_decoder 173 | else (outputs.hidden_states,) 174 | ) 175 | 176 | # sample 177 | probs = nn.functional.softmax(next_token_scores, dim=-1) 178 | next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) 179 | 180 | # finished sentences should have their next token be a padding token 181 | if eos_token_id is not None: 182 | if pad_token_id is None: 183 | raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") 184 | next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) 185 | 186 | # update generated ids, model inputs, and length for next step 187 | input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) 188 | expert_input_ids = input_ids.unsqueeze(1).repeat(1, multiple, 1).view(-1, input_ids.shape[1]) 189 | antiexpert_input_ids = input_ids.unsqueeze(1).repeat(1, multiple, 1).view(-1, input_ids.shape[1]) 190 | model_kwargs = self._update_model_kwargs_for_generation( 191 | outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder 192 | ) 193 | expert_kwargs = self.expert._update_model_kwargs_for_generation( 194 | expert_outputs, expert_kwargs, is_encoder_decoder=self.config.is_encoder_decoder 195 | ) 196 | antiexpert_kwargs = self.antiexpert._update_model_kwargs_for_generation( 197 | antiexpert_outputs, antiexpert_kwargs, is_encoder_decoder=self.config.is_encoder_decoder 198 | ) 199 | expert_scores = expert_scores * expert_next_token_prob[range(next_tokens.size(0)), next_tokens] 200 | antiexpert_scores = antiexpert_scores * antiexpert_next_token_prob[range(next_tokens.size(0)), next_tokens] 201 | cur_len = cur_len + 1 202 | 203 | # if eos_token was found in one sentence, set sentence to finished 204 | if eos_token_id is not None: 205 | unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) 206 | 207 | # stop when each sentence is finished, or if we exceed the maximum length 208 | if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): 209 | break 210 | 211 | if return_dict_in_generate: 212 | if self.config.is_encoder_decoder: 213 | return SampleEncoderDecoderOutput( 214 | sequences=input_ids, 215 | scores=scores, 216 | encoder_attentions=encoder_attentions, 217 | encoder_hidden_states=encoder_hidden_states, 218 | decoder_attentions=decoder_attentions, 219 | cross_attentions=cross_attentions, 220 | decoder_hidden_states=decoder_hidden_states, 221 | ) 222 | else: 223 | return SampleDecoderOnlyOutput( 224 | sequences=input_ids, 225 | scores=scores, 226 | attentions=decoder_attentions, 227 | hidden_states=decoder_hidden_states, 228 | ) 229 | else: 230 | return input_ids 231 | -------------------------------------------------------------------------------- /models/blender_unlikelihood_01.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import torch 4 | import torch.nn as nn 5 | import 
torch.nn.functional as F 6 | from transformers.models.blenderbot.modeling_blenderbot import (BlenderbotConfig, BlenderbotForConditionalGeneration,) 7 | from utils.model_utils import BaseModel 8 | 9 | GAMMA = 0.1 10 | 11 | 12 | class Model(BaseModel, BlenderbotForConditionalGeneration): 13 | def __init__(self, config: BlenderbotConfig): 14 | super().__init__(config) 15 | 16 | def forward( 17 | self, 18 | input_ids=None, 19 | attention_mask=None, 20 | decoder_input_ids=None, 21 | encoder_outputs=None, 22 | past_key_values=None, 23 | labels=None, 24 | cls_labels=None, 25 | use_cache=None, 26 | return_dict=True, 27 | validation=False, 28 | **kwargs 29 | ): 30 | assert self.toker is not None 31 | assert not (self.training and validation) 32 | if self.training or validation: 33 | assert labels is not None 34 | assert cls_labels is not None 35 | cls_labels = cls_labels.unsqueeze(-1).expand_as(labels).contiguous() 36 | use_cache = False 37 | 38 | outputs = super().forward( 39 | input_ids=input_ids, 40 | attention_mask=attention_mask, 41 | decoder_input_ids=decoder_input_ids, 42 | encoder_outputs=encoder_outputs, 43 | past_key_values=past_key_values, 44 | labels=None, 45 | return_dict=return_dict, 46 | use_cache=use_cache, 47 | **kwargs 48 | ) 49 | 50 | lm_logits = outputs.logits 51 | masked_lm_loss = None 52 | if self.training or validation: 53 | loss = F.cross_entropy(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1), reduction='none') 54 | loss = loss.view(labels.size(0), labels.size(1)) * cls_labels 55 | label_size = (labels.ne(-100) * cls_labels).sum(1).type_as(loss) 56 | masked_lm_loss = loss.sum() / torch.clamp(label_size.sum(), min=1e-5) 57 | ppl_value = masked_lm_loss.exp() 58 | 59 | negative_lm_logits = torch.clamp(1. - F.softmax(lm_logits, dim=-1), min=1e-5).log() 60 | negative_loss = F.cross_entropy(negative_lm_logits.view(-1, negative_lm_logits.size(-1)), labels.view(-1), reduction='none') 61 | negative_loss = negative_loss.view(labels.size(0), labels.size(1)) * (1. - cls_labels) 62 | negative_label_size = (labels.ne(-100) * (1. 
- cls_labels)).sum(1).type_as(loss) 63 | negative_lm_loss = negative_loss.sum() / negative_label_size.sum() 64 | 65 | outputs.loss = masked_lm_loss 66 | 67 | if not self.training and not validation: # inference 68 | return outputs 69 | elif self.training: # training 70 | assert not validation 71 | res = {'all': masked_lm_loss + GAMMA * negative_lm_loss, 'ppl': ppl_value, } 72 | return res 73 | else: # validation 74 | assert not self.training 75 | return loss, label_size 76 | -------------------------------------------------------------------------------- /models/distilbert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from transformers.models.distilbert.modeling_distilbert import (DistilBertConfig, DistilBertPreTrainedModel, DistilBertModel, DistilBertForSequenceClassification) 7 | from utils.model_utils import BaseModel 8 | 9 | 10 | class Model(BaseModel, DistilBertPreTrainedModel): 11 | def __init__(self, config: DistilBertConfig, num_labels): 12 | super().__init__(config) 13 | self.num_labels = config.num_labels = int(num_labels) 14 | self.config = config 15 | 16 | self.distilbert = DistilBertModel(config) 17 | self.pre_classifier = nn.Linear(config.dim, config.dim) 18 | self.classifier = nn.Linear(config.dim, config.num_labels) 19 | self.dropout = nn.Dropout(config.seq_classif_dropout) 20 | 21 | # Initialize weights and apply final processing 22 | self.post_init() 23 | 24 | def forward( 25 | self, 26 | input_ids=None, 27 | attention_mask=None, 28 | return_dict=None, 29 | validation=False, 30 | inference=False, 31 | **kwargs, 32 | ): 33 | assert (not validation and not inference) == self.training 34 | encoded_info = kwargs 35 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 36 | 37 | distilbert_output = self.distilbert( 38 | input_ids=input_ids, 39 | attention_mask=attention_mask, 40 | return_dict=return_dict, 41 | ) 42 | hidden_state = distilbert_output[0] # (bs, seq_len, dim) 43 | pooled_output = hidden_state[:, 0] # (bs, dim) 44 | pooled_output = self.pre_classifier(pooled_output) # (bs, dim) 45 | pooled_output = nn.ReLU()(pooled_output) # (bs, dim) 46 | pooled_output = self.dropout(pooled_output) # (bs, dim) 47 | logits = self.classifier(pooled_output) # (bs, num_labels) 48 | 49 | loss = self.predict_label(logits, encoded_info) 50 | 51 | if inference: 52 | return encoded_info.copy() 53 | else: 54 | ppl_value = loss.new_tensor([0.], dtype=loss.dtype) 55 | if not validation: 56 | res = { 57 | 'all': loss, 58 | 'ppl': ppl_value, 59 | 'loss': loss, 60 | } 61 | return res 62 | else: 63 | return loss * input_ids.size(0), loss.new_tensor([input_ids.size(0)]) 64 | 65 | def predict_label(self, logits, encoded_info): 66 | preds = torch.argmax(logits, dim=-1) 67 | if self.num_labels > 3: 68 | preds_top3 = torch.topk(logits, k=3, dim=-1)[1] 69 | loss = None 70 | if 'labels' in encoded_info: 71 | loss = F.cross_entropy(logits, encoded_info.get('labels'), reduction='mean') 72 | encoded_info['preds'] = preds 73 | if self.num_labels > 3: 74 | encoded_info['preds_top3'] = preds_top3 75 | encoded_info['preds_dist'] = F.softmax(logits, dim=-1) 76 | return loss 77 | -------------------------------------------------------------------------------- /models/gpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import torch 4 | import torch.nn as nn 5 | import 
torch.nn.functional as F 6 | from transformers.models.gpt2 import (GPT2LMHeadModel, GPT2Config) 7 | from transformers.models.gpt_neo import (GPTNeoForCausalLM, GPTNeoConfig) 8 | from transformers.models.gptj import (GPTJForCausalLM, GPTJConfig) 9 | from transformers.models.opt import (OPTForCausalLM, OPTConfig) 10 | from utils.model_utils import BaseModel 11 | 12 | 13 | class Model(BaseModel, GPT2LMHeadModel): 14 | def __init__(self, config: GPT2Config): 15 | super().__init__(config) 16 | 17 | def forward( 18 | self, 19 | input_ids=None, 20 | attention_mask=None, 21 | past_key_values=None, 22 | labels=None, 23 | use_cache=None, 24 | return_dict=True, 25 | validation=False, 26 | **kwargs 27 | ): 28 | assert self.toker is not None 29 | assert not (self.training and validation) 30 | if self.training or validation: 31 | assert labels is not None 32 | use_cache = False 33 | 34 | outputs = super().forward( 35 | input_ids=input_ids, 36 | attention_mask=attention_mask, 37 | past_key_values=past_key_values, 38 | labels=None, 39 | return_dict=return_dict, 40 | use_cache=use_cache, 41 | **kwargs 42 | ) 43 | 44 | lm_logits = outputs.logits 45 | masked_lm_loss = None 46 | if self.training or validation: 47 | loss = F.cross_entropy(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1), reduction='none') 48 | loss = loss.view(labels.size(0), labels.size(1)) 49 | label_mask = labels.ne(-100).type_as(loss) 50 | label_size = label_mask.sum(1).type_as(loss) 51 | masked_lm_loss = loss.sum() / torch.clamp(label_size.sum(), min=1e-5) 52 | ppl_value = masked_lm_loss.exp() 53 | 54 | if validation: 55 | preds = torch.argmax(lm_logits, dim=-1) 56 | acc = ((preds == labels) * label_mask).type_as(loss) # [bs, length] 57 | occurrence = torch.tril(preds.unsqueeze(-1) == labels.unsqueeze(-2), -1) # [bs, length, length] 58 | rep = (occurrence.sum(dim=-1) > 0).type_as(loss) * label_mask # [bs, length] 59 | wrep = rep * (1. 
- acc) 60 | 61 | outputs.loss = masked_lm_loss 62 | 63 | if not self.training and not validation: # inference 64 | return outputs 65 | elif self.training: # training 66 | res = {'all': masked_lm_loss, 'ppl': ppl_value, } 67 | return res 68 | else: # validation 69 | assert not self.training 70 | return loss, label_size, acc, rep, wrep 71 | -------------------------------------------------------------------------------- /models/gpt2_contrast_01.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from transformers.models.gpt2 import (GPT2LMHeadModel, GPT2Config) 7 | from transformers.models.gpt_neo import (GPTNeoForCausalLM, GPTNeoConfig) 8 | from transformers.models.gptj import (GPTJForCausalLM, GPTJConfig) 9 | from transformers.models.opt import (OPTForCausalLM, OPTConfig) 10 | from utils.model_utils import BaseModel 11 | 12 | GAMMA = 0.1 13 | 14 | 15 | class Model(BaseModel, GPT2LMHeadModel): 16 | def __init__(self, config: GPT2Config, alpha): 17 | super().__init__(config) 18 | self.alpha = float(alpha) 19 | 20 | def forward( 21 | self, 22 | input_ids=None, 23 | attention_mask=None, 24 | pos_input_ids=None, 25 | neg_input_ids=None, 26 | past_key_values=None, 27 | labels=None, 28 | pos_labels=None, 29 | neg_labels=None, 30 | use_cache=None, 31 | return_dict=True, 32 | validation=False, 33 | **kwargs 34 | ): 35 | assert self.toker is not None 36 | assert self.training and not validation 37 | assert labels is not None 38 | use_cache = False 39 | 40 | outputs = super().forward( 41 | input_ids=input_ids, 42 | attention_mask=attention_mask, 43 | past_key_values=past_key_values, 44 | labels=None, 45 | return_dict=return_dict, 46 | use_cache=use_cache, 47 | **kwargs 48 | ) 49 | 50 | lm_logits = outputs.logits 51 | loss = F.cross_entropy(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1), reduction='none') 52 | loss = loss.view(labels.size(0), labels.size(1)) 53 | label_size = labels.ne(-100).sum(1).type_as(loss) 54 | masked_lm_loss = loss.sum() / torch.clamp(label_size.sum(), min=1e-5) 55 | ppl_value = masked_lm_loss.exp() 56 | 57 | outputs.loss = masked_lm_loss 58 | 59 | if neg_input_ids is None: 60 | res = {'all': masked_lm_loss, 'ppl': ppl_value, } 61 | return res 62 | 63 | # alpha 64 | pos_outputs = super().forward( 65 | input_ids=pos_input_ids, 66 | attention_mask=None, 67 | labels=None, 68 | return_dict=return_dict, 69 | use_cache=use_cache, 70 | **kwargs 71 | ) 72 | pos_lm_logits = pos_outputs.logits 73 | pos_loss = F.cross_entropy(pos_lm_logits.view(-1, pos_lm_logits.size(-1)), pos_labels.view(-1), reduction='none') 74 | pos_loss = pos_loss.view(pos_labels.size(0), pos_labels.size(1)).sum(-1) 75 | 76 | neg_outputs = super().forward( 77 | input_ids=neg_input_ids, 78 | attention_mask=None, 79 | labels=None, 80 | return_dict=return_dict, 81 | use_cache=use_cache, 82 | **kwargs 83 | ) 84 | neg_lm_logits = neg_outputs.logits 85 | neg_loss = F.cross_entropy(neg_lm_logits.view(-1, neg_lm_logits.size(-1)), neg_labels.view(-1), reduction='none') 86 | neg_loss = neg_loss.view(neg_labels.size(0), neg_labels.size(1)).sum(-1) 87 | 88 | # we have pos_loss < neg_loss 89 | loss1 = torch.clamp(self.alpha + pos_loss - neg_loss, min=0.) 
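# (annotation: hinge/margin form — pos_loss and neg_loss are summed NLLs, so
#  loss1 is zero exactly when log p(pos) - log p(neg) >= alpha, i.e. the
#  positive continuation outscores the negative one by the alpha margin)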
90 | loss1 = loss1.mean() 91 | 92 | res = {'all': masked_lm_loss + GAMMA * loss1, 'ppl': ppl_value, 'loss': masked_lm_loss, 'loss1': loss1, } 93 | return res 94 | -------------------------------------------------------------------------------- /models/gpt2_contrast_03.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from transformers.models.gpt2 import (GPT2LMHeadModel, GPT2Config) 7 | from transformers.models.gpt_neo import (GPTNeoForCausalLM, GPTNeoConfig) 8 | from transformers.models.gptj import (GPTJForCausalLM, GPTJConfig) 9 | from transformers.models.opt import (OPTForCausalLM, OPTConfig) 10 | from utils.model_utils import BaseModel 11 | 12 | GAMMA = 0.3 13 | 14 | 15 | class Model(BaseModel, GPT2LMHeadModel): 16 | def __init__(self, config: GPT2Config, alpha): 17 | super().__init__(config) 18 | self.alpha = float(alpha) 19 | 20 | def forward( 21 | self, 22 | input_ids=None, 23 | attention_mask=None, 24 | pos_input_ids=None, 25 | neg_input_ids=None, 26 | past_key_values=None, 27 | labels=None, 28 | pos_labels=None, 29 | neg_labels=None, 30 | use_cache=None, 31 | return_dict=True, 32 | validation=False, 33 | **kwargs 34 | ): 35 | assert self.toker is not None 36 | assert self.training and not validation 37 | assert labels is not None 38 | use_cache = False 39 | 40 | outputs = super().forward( 41 | input_ids=input_ids, 42 | attention_mask=attention_mask, 43 | past_key_values=past_key_values, 44 | labels=None, 45 | return_dict=return_dict, 46 | use_cache=use_cache, 47 | **kwargs 48 | ) 49 | 50 | lm_logits = outputs.logits 51 | loss = F.cross_entropy(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1), reduction='none') 52 | loss = loss.view(labels.size(0), labels.size(1)) 53 | label_size = labels.ne(-100).sum(1).type_as(loss) 54 | masked_lm_loss = loss.sum() / torch.clamp(label_size.sum(), min=1e-5) 55 | ppl_value = masked_lm_loss.exp() 56 | 57 | outputs.loss = masked_lm_loss 58 | 59 | if neg_input_ids is None: 60 | res = {'all': masked_lm_loss, 'ppl': ppl_value, } 61 | return res 62 | 63 | # alpha 64 | pos_outputs = super().forward( 65 | input_ids=pos_input_ids, 66 | attention_mask=None, 67 | labels=None, 68 | return_dict=return_dict, 69 | use_cache=use_cache, 70 | **kwargs 71 | ) 72 | pos_lm_logits = pos_outputs.logits 73 | pos_loss = F.cross_entropy(pos_lm_logits.view(-1, pos_lm_logits.size(-1)), pos_labels.view(-1), reduction='none') 74 | pos_loss = pos_loss.view(pos_labels.size(0), pos_labels.size(1)).sum(-1) 75 | 76 | neg_outputs = super().forward( 77 | input_ids=neg_input_ids, 78 | attention_mask=None, 79 | labels=None, 80 | return_dict=return_dict, 81 | use_cache=use_cache, 82 | **kwargs 83 | ) 84 | neg_lm_logits = neg_outputs.logits 85 | neg_loss = F.cross_entropy(neg_lm_logits.view(-1, neg_lm_logits.size(-1)), neg_labels.view(-1), reduction='none') 86 | neg_loss = neg_loss.view(neg_labels.size(0), neg_labels.size(1)).sum(-1) 87 | 88 | # we have pos_loss < neg_loss 89 | loss1 = torch.clamp(self.alpha + pos_loss - neg_loss, min=0.) 
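# (annotation: same margin construction as gpt2_contrast_01; this variant only
#  raises the contrastive weight to GAMMA = 0.3, the literal 0.3 used in the
#  total below)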
90 | loss1 = loss1.mean() 91 | 92 | res = {'all': masked_lm_loss + 0.3 * loss1, 'ppl': ppl_value, 'loss': masked_lm_loss, 'loss1': loss1, } 93 | return res 94 | -------------------------------------------------------------------------------- /models/roberta.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from transformers.models.roberta.modeling_roberta import (RobertaConfig, RobertaPreTrainedModel, RobertaModel, RobertaClassificationHead) 7 | from utils.model_utils import BaseModel 8 | 9 | 10 | class Model(BaseModel, RobertaPreTrainedModel): 11 | def __init__(self, config: RobertaConfig, num_labels): 12 | super().__init__(config) 13 | self.num_labels = config.num_labels = int(num_labels) 14 | self.config = config 15 | 16 | self.roberta = RobertaModel(config, add_pooling_layer=False) 17 | self.classifier = RobertaClassificationHead(config) 18 | 19 | # Initialize weights and apply final processing 20 | self.post_init() 21 | 22 | def forward( 23 | self, 24 | input_ids=None, 25 | attention_mask=None, 26 | token_type_ids=None, 27 | position_ids=None, 28 | return_dict=None, 29 | validation=False, 30 | inference=False, 31 | **kwargs, 32 | ): 33 | assert (not validation and not inference) == self.training 34 | encoded_info = kwargs 35 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 36 | 37 | outputs = self.roberta( 38 | input_ids, 39 | attention_mask=attention_mask, 40 | token_type_ids=token_type_ids, 41 | position_ids=position_ids, 42 | return_dict=return_dict, 43 | ) 44 | sequence_output = outputs[0] 45 | logits = self.classifier(sequence_output) 46 | loss = self.predict_label(logits, encoded_info) 47 | 48 | if inference: 49 | return encoded_info.copy() 50 | else: 51 | ppl_value = loss.new_tensor([0.], dtype=loss.dtype) 52 | if not validation: 53 | res = { 54 | 'all': loss, 55 | 'ppl': ppl_value, 56 | 'loss': loss, 57 | } 58 | return res 59 | else: 60 | return loss * input_ids.size(0), loss.new_tensor([input_ids.size(0)]) 61 | 62 | def predict_label(self, logits, encoded_info): 63 | preds = torch.argmax(logits, dim=-1) 64 | if self.num_labels > 3: 65 | preds_top3 = torch.topk(logits, k=3, dim=-1)[1] 66 | loss = None 67 | if 'labels' in encoded_info: 68 | loss = F.cross_entropy(logits, encoded_info.get('labels'), reduction='mean') 69 | encoded_info['preds'] = preds 70 | if self.num_labels > 3: 71 | encoded_info['preds_top3'] = preds_top3 72 | encoded_info['preds_dist'] = F.softmax(logits, dim=-1) 73 | return loss 74 | -------------------------------------------------------------------------------- /results_bad.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import json 4 | import multiprocessing as mp 5 | import argparse 6 | import csv 7 | 8 | 9 | def is_negative(x): 10 | return x[0] > 0.5 11 | 12 | 13 | def process(infer_data_path): 14 | #print('\t', infer_data_path) 15 | if not os.path.exists(f'{infer_data_path}/gen.txt'): 16 | return [] 17 | 18 | stats_line = [] 19 | 20 | first_gen = json.loads(open(f'{infer_data_path}/gen.txt').readline())['generation'] 21 | if isinstance(first_gen, str): 22 | multiple = 1 23 | else: 24 | multiple = len(first_gen) 25 | 26 | if os.path.exists(f'{infer_data_path}/pred_list.txt'): 27 | preds = [json.loads(e) for e in open(f'{infer_data_path}/pred_list.txt')] 28 | preds = np.array(preds).reshape(-1, 
multiple, 2).transpose(2, 0, 1) 29 | #max_negative = preds[0].max(axis=-1).mean() 30 | emp_prob = (is_negative(preds).sum(axis=-1) > 0).mean() 31 | stats_line.extend([str(emp_prob)]) 32 | else: 33 | stats_line.extend(['']) 34 | 35 | if os.path.exists(f'{infer_data_path}/loss_large_list.txt'): 36 | losses = [json.loads(e) for e in open(f'{infer_data_path}/loss_large_list.txt')] 37 | ppls = [np.exp(e['loss'] / e['num_tokens']) for e in losses] 38 | ppl = np.mean(ppls) 39 | stats_line.append(ppl) 40 | else: 41 | stats_line.extend(['', ]) 42 | 43 | if os.path.exists(f'{infer_data_path}/dist_list.txt'): 44 | dists = [json.loads(e) for e in open(f'{infer_data_path}/dist_list.txt')] 45 | dists = np.array(dists).mean(0) 46 | dists = [str(e) for e in dists][-2:] 47 | stats_line.extend(dists) 48 | else: 49 | stats_line.extend(['', '', ]) 50 | 51 | return stats_line 52 | 53 | 54 | def main(args): 55 | if not args.test: 56 | save_name = './results_bad.txt' 57 | else: 58 | save_name = './results_bad_test.txt' 59 | stats_writer = csv.writer(open(save_name, 'w'), delimiter='\t') 60 | 61 | pool = mp.Pool(mp.cpu_count()) 62 | for stats_line in pool.imap(process, args.infer_data_paths): 63 | stats_writer.writerow(stats_line) 64 | 65 | 66 | if __name__ == '__main__': 67 | 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument('--infer_data_paths', type=str, nargs='+') 70 | parser.add_argument('--test', action='store_true') 71 | 72 | args = parser.parse_args() 73 | 74 | main(args) 75 | -------------------------------------------------------------------------------- /results_bad_test.txt: -------------------------------------------------------------------------------- 1 | 0.4503464203233256 5.226265091834621 0.39591065401690667 0.46189286012783837 2 | 0.45265588914549654 6.319413270050275 0.421233081007264 0.4876807977878095 3 | 0.18706697459584296 7.104138651268213 0.1416165430750332 0.15087858327286904 4 | 0.30331023864511164 8.918080006617199 0.3979275107555768 0.4412089461160174 5 | 0.16397228637413394 7.931808082637668 0.2558327931046866 0.2868374112114443 6 | 0.4372594303310239 9.06198458534322 0.4226993784226805 0.49154224198261953 7 | 0.08391070053887606 6.4830424547548775 0.48837646701603316 0.5591401157652692 8 | -------------------------------------------------------------------------------- /results_senti.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import json 4 | import multiprocessing as mp 5 | import argparse 6 | import csv 7 | from functools import partial 8 | 9 | 10 | def calculate_positive_ratio(file_path): 11 | preds = [json.loads(e) for e in open(file_path)] 12 | preds = np.array(preds).transpose(1, 0) 13 | mean = (preds[0] < 0.5).mean() * 100 14 | return mean 15 | 16 | 17 | def process(infer_data_path, key): 18 | #print('\t', infer_data_path) 19 | if not (os.path.exists(f'{infer_data_path}/neutral/gen.txt') and 20 | os.path.exists(f'{infer_data_path}/{key}/gen.txt')): 21 | return [] 22 | 23 | stats_line = [] 24 | 25 | if key == 'positive': 26 | if os.path.exists(f'{infer_data_path}/positive/pred_list.txt'): 27 | value = 100. 
- calculate_positive_ratio(f'{infer_data_path}/positive/pred_list.txt') 28 | stats_line.append(str(value)) 29 | else: 30 | stats_line.append('') 31 | else: 32 | if os.path.exists(f'{infer_data_path}/negative/pred_list.txt'): 33 | value = calculate_positive_ratio(f'{infer_data_path}/negative/pred_list.txt') 34 | stats_line.append(str(value)) 35 | else: 36 | stats_line.append('') 37 | 38 | if os.path.exists(f'{infer_data_path}/neutral/pred_list.txt'): 39 | value = calculate_positive_ratio(f'{infer_data_path}/neutral/pred_list.txt') 40 | if key == 'positive': 41 | value = 100. - value 42 | stats_line.append(str(value)) 43 | else: 44 | stats_line.append('') 45 | 46 | if os.path.exists(f'{infer_data_path}/neutral/loss_large_list.txt') and os.path.exists(f'{infer_data_path}/{key}/loss_large_list.txt'): 47 | losses = [] 48 | if os.path.exists(f'{infer_data_path}/neutral/loss_large_list.txt'): 49 | losses += [json.loads(e) for e in open(f'{infer_data_path}/neutral/loss_large_list.txt')] 50 | if os.path.exists(f'{infer_data_path}/{key}/loss_large_list.txt'): 51 | losses += [json.loads(e) for e in open(f'{infer_data_path}/{key}/loss_large_list.txt')] 52 | ppls = [np.exp(e['loss'] / e['num_tokens']) for e in losses] 53 | ppls = [e for e in ppls if e < 1e4] 54 | ppl = np.mean(ppls) 55 | #ppl = np.exp(np.sum([e['loss'] for e in losses]) / np.sum([e['num_tokens'] for e in losses])) 56 | stats_line.append(ppl) 57 | else: 58 | stats_line.append('') 59 | 60 | if os.path.exists(f'{infer_data_path}/neutral/dist_list.txt') and os.path.exists(f'{infer_data_path}/{key}/dist_list.txt'): 61 | dists = [] 62 | if os.path.exists(f'{infer_data_path}/neutral/dist_list.txt'): 63 | dists += [json.loads(e) for e in open(f'{infer_data_path}/neutral/dist_list.txt')] 64 | if os.path.exists(f'{infer_data_path}/{key}/dist_list.txt'): 65 | dists += [json.loads(e) for e in open(f'{infer_data_path}/{key}/dist_list.txt')] 66 | dists = np.array(dists).mean(0) 67 | dists = [str(e) for e in dists][-2:] 68 | stats_line.extend(dists) 69 | else: 70 | stats_line.extend([ '', '', ]) 71 | 72 | return stats_line 73 | 74 | 75 | def main(args): 76 | if args.positive: 77 | save_name = './results_senti_pos.txt' 78 | else: 79 | save_name = './results_senti_neg.txt' 80 | stats_writer = csv.writer(open(save_name, 'w'), delimiter='\t') 81 | 82 | pool = mp.Pool(mp.cpu_count()) 83 | for stats_line in pool.imap( 84 | partial(process, key='negative' if args.positive else 'positive'), 85 | args.infer_data_paths 86 | ): 87 | stats_writer.writerow(stats_line) 88 | 89 | 90 | if __name__ == '__main__': 91 | 92 | parser = argparse.ArgumentParser() 93 | parser.add_argument('--infer_data_paths', type=str, nargs='+') 94 | parser.add_argument('--positive', action='store_true') 95 | parser.add_argument('--negative', action='store_true') 96 | 97 | args = parser.parse_args() 98 | 99 | assert args.positive or args.negative 100 | assert not (args.positive and args.negative) 101 | 102 | main(args) 103 | -------------------------------------------------------------------------------- /results_senti_neg.txt: -------------------------------------------------------------------------------- 1 | 90.6176 95.4248 51.46155644863876 0.8125300232572494 0.8475403835478289 2 | -------------------------------------------------------------------------------- /results_senti_pos.txt: -------------------------------------------------------------------------------- 1 | 85.776 96.69919999999999 57.43047246008308 0.8034099486358759 0.8388055530484299 2 | 
-------------------------------------------------------------------------------- /results_wiki.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import json 4 | import argparse 5 | import csv 6 | 7 | 8 | def process(infer_data_path): 9 | print('\t', infer_data_path) 10 | stats_line = [] 11 | 12 | if os.path.exists(f'{infer_data_path}/metric.json'): 13 | metric = json.load(open(f'{infer_data_path}/metric.json')) 14 | for metric_name in [ 15 | #'dist-1', 16 | 'ppl_macro', 17 | 'acc', 'rep', 'wrep', 18 | 'rep-2', 'rep-3', 19 | 'diversity', 20 | ]: 21 | if metric_name in metric: 22 | stats_line.append(str(metric[metric_name])) 23 | else: 24 | stats_line.append('') 25 | 26 | if os.path.exists(f'{infer_data_path}/mauve.txt'): 27 | metric = open(f'{infer_data_path}/mauve.txt').readline().strip() 28 | stats_line.append(metric) 29 | else: 30 | stats_line.append('') 31 | 32 | return stats_line 33 | 34 | 35 | def main(args): 36 | if not args.test: 37 | save_name = './results_wiki.txt' 38 | else: 39 | save_name = './results_wiki_test.txt' 40 | stats_writer = csv.writer(open(save_name, 'w'), delimiter='\t') 41 | 42 | for infer_data_path in args.infer_data_paths: 43 | stats_line = process(infer_data_path) 44 | stats_writer.writerow(stats_line) 45 | 46 | 47 | if __name__ == '__main__': 48 | 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument('--infer_data_paths', type=str, nargs='+') 51 | parser.add_argument('--test', action='store_true') 52 | 53 | args = parser.parse_args() 54 | 55 | main(args) 56 | -------------------------------------------------------------------------------- /results_wiki_test.txt: -------------------------------------------------------------------------------- 1 | 31.797867581652913 38.830918073654175 43.86789798736572 24.725477397441864 20.228296947945058 7.425708879541526 0.7228203201668153 0.9343364298980311 2 | -------------------------------------------------------------------------------- /scripts_bad/blender/augment.sh: -------------------------------------------------------------------------------- 1 | 2 | CUDA_VISIBLE_DEVICES=0 accelerate launch \ 3 | --mixed_precision no \ 4 | --num_processes 1 \ 5 | --num_machines 1 \ 6 | --num_cpu_threads_per_process 32 \ 7 | generate.py \ 8 | --collator_name text2text \ 9 | --model_name blender \ 10 | --pretrained_model_path /home/zhengchujie/pretrained-models/facebook/blenderbot-400M-distill \ 11 | --save_path checkpoints_bad/blender \ 12 | --infer_data_paths data_bad/blender/train.txt \ 13 | --infer_names train \ 14 | --only_generate \ 15 | --max_input_length 128 \ 16 | --max_decoder_input_length 32 \ 17 | --seed 0 \ 18 | --lower \ 19 | --max_length 32 \ 20 | --min_length 5 \ 21 | --batch_size 32 \ 22 | --temperature 1 \ 23 | --top_k 0 \ 24 | --top_p 0.9 \ 25 | --num_beams 1 \ 26 | --num_return_sequences 20 \ 27 | --length_penalty 1 \ 28 | --repetition_penalty 1 \ 29 | --no_repeat_ngram_size 0 \ 30 | --encoder_no_repeat_ngram_size 0 31 | 32 | -------------------------------------------------------------------------------- /scripts_bad/blender/generate.sh: -------------------------------------------------------------------------------- 1 | 2 | CUDA_VISIBLE_DEVICES=0 accelerate launch \ 3 | --mixed_precision no \ 4 | --num_processes 1 \ 5 | --num_machines 1 \ 6 | --num_cpu_threads_per_process 32 \ 7 | generate.py \ 8 | --collator_name text2text \ 9 | --model_name blender \ 10 | --pretrained_model_path 
/home/zhengchujie/pretrained-models/facebook/blenderbot-400M-distill \ 11 | --save_path checkpoints_bad/blender \ 12 | --infer_data_paths data_bad/blender/valid.txt data_bad/blender/test.txt \ 13 | --infer_names valid test \ 14 | --only_generate \ 15 | --max_input_length 128 \ 16 | --max_decoder_input_length 32 \ 17 | --seed 0 \ 18 | --lower \ 19 | --max_length 32 \ 20 | --min_length 5 \ 21 | --batch_size 6 \ 22 | --temperature 1 \ 23 | --top_k 0 \ 24 | --top_p 0.9 \ 25 | --num_beams 1 \ 26 | --num_return_sequences 25 \ 27 | --length_penalty 1 \ 28 | --repetition_penalty 1 \ 29 | --no_repeat_ngram_size 0 \ 30 | --encoder_no_repeat_ngram_size 0 31 | -------------------------------------------------------------------------------- /scripts_bad/contrast/augment.sh: -------------------------------------------------------------------------------- 1 | 2 | CUDA_VISIBLE_DEVICES=0 accelerate launch \ 3 | --mixed_precision no \ 4 | --num_processes 1 \ 5 | --num_machines 1 \ 6 | --num_cpu_threads_per_process 32 \ 7 | generate.py \ 8 | --collator_name text2text \ 9 | --model_name blender \ 10 | --pretrained_model_path checkpoints_bad/contrast_05/20.0 \ 11 | --save_path checkpoints_bad/contrast_05/20.0 \ 12 | --infer_data_paths data_bad/blender/train.txt \ 13 | --infer_names train \ 14 | --only_generate \ 15 | --max_input_length 128 \ 16 | --max_decoder_input_length 32 \ 17 | --seed 0 \ 18 | --lower \ 19 | --max_length 32 \ 20 | --min_length 5 \ 21 | --batch_size 32 \ 22 | --temperature 1 \ 23 | --top_k 0 \ 24 | --top_p 0.9 \ 25 | --num_beams 1 \ 26 | --num_return_sequences 20 \ 27 | --length_penalty 1 \ 28 | --repetition_penalty 1 \ 29 | --no_repeat_ngram_size 0 \ 30 | --encoder_no_repeat_ngram_size 0 31 | 32 | -------------------------------------------------------------------------------- /scripts_bad/contrast/generate.sh: -------------------------------------------------------------------------------- 1 | 2 | cuda=4 3 | 4 | for model in 05; do 5 | for alpha in 20.0; do 6 | 7 | CUDA_VISIBLE_DEVICES=${cuda} accelerate launch \ 8 | --mixed_precision no \ 9 | --num_processes 1 \ 10 | --num_machines 1 \ 11 | --num_cpu_threads_per_process 32 \ 12 | generate.py \ 13 | --collator_name text2text \ 14 | --model_name blender \ 15 | --pretrained_model_path checkpoints_bad/contrast_${model}/${alpha} \ 16 | --save_path checkpoints_bad/contrast_${model}/${alpha} \ 17 | --infer_data_paths data_bad/blender/valid.txt data_bad/blender/test.txt \ 18 | --infer_names valid test \ 19 | --only_generate \ 20 | --max_input_length 128 \ 21 | --max_decoder_input_length 32 \ 22 | --seed 0 \ 23 | --lower \ 24 | --max_length 32 \ 25 | --min_length 5 \ 26 | --batch_size 6 \ 27 | --temperature 1 \ 28 | --top_k 0 \ 29 | --top_p 0.9 \ 30 | --num_beams 1 \ 31 | --num_return_sequences 25 \ 32 | --length_penalty 1 \ 33 | --repetition_penalty 1 \ 34 | --no_repeat_ngram_size 0 \ 35 | --encoder_no_repeat_ngram_size 0 36 | 37 | done 38 | done 39 | -------------------------------------------------------------------------------- /scripts_bad/contrast/train.sh: -------------------------------------------------------------------------------- 1 | 2 | cuda=3 3 | 4 | for alpha in 20.0; do 5 | for model in 05; do 6 | 7 | CUDA_VISIBLE_DEVICES=${cuda} accelerate launch \ 8 | --mixed_precision fp16 \ 9 | --num_processes 1 \ 10 | --num_machines 1 \ 11 | --num_cpu_threads_per_process 32 \ 12 | train.py \ 13 | --collator_name text2text_contrast \ 14 | --model_name blender_contrast_${model} \ 15 | --pretrained_model_path 
/home/zhengchujie/pretrained-models/facebook/blenderbot-400M-distill \ 16 | --model_args ${alpha} \ 17 | --save_path checkpoints_bad/contrast_${model}/${alpha} \ 18 | --train_data_path data_bad/contrast/train.txt \ 19 | --max_input_length 128 \ 20 | --max_decoder_input_length 32 \ 21 | --seed 42 \ 22 | --adafactor \ 23 | --batch_size 63 \ 24 | --gradient_accumulation_steps 3 \ 25 | --learning_rate 5e-5 \ 26 | --num_epochs 2 \ 27 | --warmup_steps 50 28 | 29 | done 30 | done 31 | -------------------------------------------------------------------------------- /scripts_bad/cringe/generate.sh: -------------------------------------------------------------------------------- 1 | 2 | cuda=1 3 | 4 | for gamma in 02; do 5 | 6 | CUDA_VISIBLE_DEVICES=${cuda} accelerate launch \ 7 | --mixed_precision no \ 8 | --num_processes 1 \ 9 | --num_machines 1 \ 10 | --num_cpu_threads_per_process 32 \ 11 | generate.py \ 12 | --collator_name text2text \ 13 | --model_name blender \ 14 | --pretrained_model_path checkpoints_bad/cringe_${gamma} \ 15 | --save_path checkpoints_bad/cringe_${gamma} \ 16 | --infer_data_paths data_bad/blender/valid.txt data_bad/blender/test.txt \ 17 | --infer_names valid test \ 18 | --only_generate \ 19 | --max_input_length 128 \ 20 | --max_decoder_input_length 32 \ 21 | --seed 0 \ 22 | --lower \ 23 | --max_length 32 \ 24 | --min_length 5 \ 25 | --batch_size 6 \ 26 | --temperature 1 \ 27 | --top_k 0 \ 28 | --top_p 0.9 \ 29 | --num_beams 1 \ 30 | --num_return_sequences 25 \ 31 | --length_penalty 1 \ 32 | --repetition_penalty 1 \ 33 | --no_repeat_ngram_size 0 \ 34 | --encoder_no_repeat_ngram_size 0 35 | 36 | done 37 | -------------------------------------------------------------------------------- /scripts_bad/cringe/train.sh: -------------------------------------------------------------------------------- 1 | 2 | cuda=0 3 | 4 | for gamma in 02; do 5 | 6 | CUDA_VISIBLE_DEVICES=${cuda} accelerate launch \ 7 | --mixed_precision fp16 \ 8 | --num_processes 1 \ 9 | --num_machines 1 \ 10 | --num_cpu_threads_per_process 32 \ 11 | train.py \ 12 | --collator_name text2text_labels \ 13 | --model_name blender_cringe_${gamma} \ 14 | --pretrained_model_path /home/zhengchujie/pretrained-models/facebook/blenderbot-400M-distill \ 15 | --save_path checkpoints_bad/cringe_${gamma} \ 16 | --train_data_path data_bad/labels/train.txt \ 17 | --max_input_length 128 \ 18 | --max_decoder_input_length 32 \ 19 | --seed 42 \ 20 | --adafactor \ 21 | --batch_size 64 \ 22 | --gradient_accumulation_steps 1 \ 23 | --learning_rate 5e-5 \ 24 | --num_epochs 2 \ 25 | --warmup_steps 50 26 | 27 | done 28 | -------------------------------------------------------------------------------- /scripts_bad/dexperts/antiexpert.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=2 accelerate launch \ 2 | --mixed_precision fp16 \ 3 | --num_processes 1 \ 4 | --num_machines 1 \ 5 | --num_cpu_threads_per_process 32 \ 6 | train.py \ 7 | --collator_name text2text \ 8 | --model_name blender \ 9 | --pretrained_model_path /home/zhengchujie/pretrained-models/facebook/blenderbot-400M-distill \ 10 | --save_path checkpoints_bad/experts/antiexpert \ 11 | --train_data_path data_bad/experts/antiexpert.txt \ 12 | --max_input_length 128 \ 13 | --max_decoder_input_length 32 \ 14 | --seed 42 \ 15 | --adafactor \ 16 | --batch_size 64 \ 17 | --gradient_accumulation_steps 1 \ 18 | --learning_rate 5e-5 \ 19 | --num_epochs 2 \ 20 | --warmup_steps 50 21 | 
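22 | # note: both experts are plain MLE fine-tunes of the base Blenderbot; this
23 | # script presumably trains on the unsafe-labeled half (antiexpert.txt) and
24 | # expert.sh on the safe half. models/blender_dexperts.py then combines the
25 | # three at decode time, roughly base_logits + alpha * (expert_logits -
26 | # antiexpert_logits) as in the DExperts formulation.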
-------------------------------------------------------------------------------- /scripts_bad/dexperts/expert.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=4 accelerate launch \ 2 | --mixed_precision fp16 \ 3 | --num_processes 1 \ 4 | --num_machines 1 \ 5 | --num_cpu_threads_per_process 32 \ 6 | train.py \ 7 | --collator_name text2text \ 8 | --model_name blender \ 9 | --pretrained_model_path /home/zhengchujie/pretrained-models/facebook/blenderbot-400M-distill \ 10 | --save_path checkpoints_bad/experts/expert \ 11 | --train_data_path data_bad/experts/expert.txt \ 12 | --max_input_length 128 \ 13 | --max_decoder_input_length 32 \ 14 | --seed 42 \ 15 | --adafactor \ 16 | --batch_size 64 \ 17 | --gradient_accumulation_steps 1 \ 18 | --learning_rate 5e-5 \ 19 | --num_epochs 2 \ 20 | --warmup_steps 50 21 | -------------------------------------------------------------------------------- /scripts_bad/dexperts/generate.sh: -------------------------------------------------------------------------------- 1 | 2 | cuda=6 3 | 4 | for alpha in 5.0; do 5 | 6 | CUDA_VISIBLE_DEVICES=${cuda} accelerate launch \ 7 | --mixed_precision no \ 8 | --num_processes 1 \ 9 | --num_machines 1 \ 10 | --num_cpu_threads_per_process 32 \ 11 | generate.py \ 12 | --collator_name text2text_dexperts \ 13 | --model_name blender_dexperts \ 14 | --pretrained_model_path /home/zhengchujie/pretrained-models/facebook/blenderbot-400M-distill \ 15 | --model_args checkpoints_bad/experts/expert checkpoints_bad/experts/antiexpert ${alpha} \ 16 | --save_path checkpoints_bad/dexperts \ 17 | --infer_data_paths data_bad/dexperts/valid.txt data_bad/dexperts/test.txt \ 18 | --infer_names valid_${alpha} test_${alpha} \ 19 | --only_generate \ 20 | --max_input_length 128 \ 21 | --max_decoder_input_length 32 \ 22 | --seed 0 \ 23 | --lower \ 24 | --max_length 32 \ 25 | --min_length 5 \ 26 | --batch_size 6 \ 27 | --temperature 1 \ 28 | --top_k 0 \ 29 | --top_p 0.9 \ 30 | --num_beams 1 \ 31 | --num_return_sequences 25 \ 32 | --length_penalty 1 \ 33 | --repetition_penalty 1 \ 34 | --no_repeat_ngram_size 0 \ 35 | --encoder_no_repeat_ngram_size 0 36 | 37 | done 38 | -------------------------------------------------------------------------------- /scripts_bad/director/generate.sh: -------------------------------------------------------------------------------- 1 | 2 | cuda=0 3 | 4 | for gamma in 02; do 5 | for alpha in 10.0; do 6 | 7 | CUDA_VISIBLE_DEVICES=${cuda} accelerate launch \ 8 | --mixed_precision no \ 9 | --num_processes 1 \ 10 | --num_machines 1 \ 11 | --num_cpu_threads_per_process 32 \ 12 | generate.py \ 13 | --collator_name text2text \ 14 | --model_name blender_director_${gamma} \ 15 | --pretrained_model_path checkpoints_bad/director_${gamma} \ 16 | --model_args ${alpha} \ 17 | --save_path checkpoints_bad/director_${gamma} \ 18 | --infer_data_paths data_bad/blender/valid.txt data_bad/blender/test.txt \ 19 | --infer_names valid_${alpha} test_${alpha} \ 20 | --max_input_length 128 \ 21 | --max_decoder_input_length 32 \ 22 | --only_generate \ 23 | --seed 0 \ 24 | --lower \ 25 | --max_length 32 \ 26 | --min_length 5 \ 27 | --batch_size 6 \ 28 | --temperature 1 \ 29 | --top_k 0 \ 30 | --top_p 0.9 \ 31 | --num_beams 1 \ 32 | --num_return_sequences 25 \ 33 | --length_penalty 1 \ 34 | --repetition_penalty 1 \ 35 | --no_repeat_ngram_size 0 \ 36 | --encoder_no_repeat_ngram_size 0 37 | 38 | done 39 | 40 | done 41 | 
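Director couples the language-model head with a token-level classifier head. Below is a minimal sketch of the generation-time fusion that `models/blender_director_*.py` presumably performs, where `alpha` is the `--model_args ${alpha}` weight above; the names and the exact mixing form are assumptions based on the Director paper (Arora et al., 2022):

```python
import torch
import torch.nn.functional as F

def director_next_token_scores(
    lm_logits: torch.Tensor,   # (batch, vocab) language-model head
    cls_logits: torch.Tensor,  # (batch, vocab) classifier head: "is this token desirable?"
    alpha: float,              # generation-time classifier weight, e.g. 10.0
) -> torch.Tensor:
    # Combine LM log-probabilities with the log-probability that each
    # candidate token is classified as desirable.
    return F.log_softmax(lm_logits, dim=-1) + alpha * F.logsigmoid(cls_logits)
```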
-------------------------------------------------------------------------------- /scripts_bad/director/train.sh: -------------------------------------------------------------------------------- 1 | 2 | cuda=6 3 | 4 | for gamma in 02; do 5 | 6 | CUDA_VISIBLE_DEVICES=${cuda} accelerate launch \ 7 | --mixed_precision fp16 \ 8 | --num_processes 1 \ 9 | --num_machines 1 \ 10 | --num_cpu_threads_per_process 32 \ 11 | train.py \ 12 | --collator_name text2text_labels \ 13 | --model_name blender_director_${gamma} \ 14 | --pretrained_model_path /home/zhengchujie/pretrained-models/facebook/blenderbot-400M-distill \ 15 | --save_path checkpoints_bad/director_${gamma} \ 16 | --train_data_path data_bad/labels/train.txt \ 17 | --max_input_length 128 \ 18 | --max_decoder_input_length 32 \ 19 | --seed 42 \ 20 | --adafactor \ 21 | --batch_size 64 \ 22 | --gradient_accumulation_steps 1 \ 23 | --learning_rate 5e-5 \ 24 | --num_epochs 3 \ 25 | --warmup_steps 50 26 | 27 | done 28 | -------------------------------------------------------------------------------- /scripts_bad/eval/dist.sh: -------------------------------------------------------------------------------- 1 | 2 | paths=() 3 | paths[${#paths[*]}]="checkpoints_bad/ft/test" 4 | paths[${#paths[*]}]="checkpoints_bad/unlikelihood_01/test" 5 | paths[${#paths[*]}]="checkpoints_bad/gedi/test_20.0" 6 | paths[${#paths[*]}]="checkpoints_bad/dexperts/test_5.0" 7 | paths[${#paths[*]}]="checkpoints_bad/director_02/test_10.0" 8 | paths[${#paths[*]}]="checkpoints_bad/cringe_02/test" 9 | paths[${#paths[*]}]="checkpoints_bad/contrast_05/20.0/test" 10 | 11 | python eval_dist.py \ 12 | --context_file data_bad/blender/test.txt \ 13 | --infer_data_paths ${paths[*]} \ 14 | 15 | -------------------------------------------------------------------------------- /scripts_bad/eval/ppl_large.sh: -------------------------------------------------------------------------------- 1 | 2 | paths=() 3 | paths[${#paths[*]}]="checkpoints_bad/ft/test" 4 | paths[${#paths[*]}]="checkpoints_bad/unlikelihood_01/test" 5 | paths[${#paths[*]}]="checkpoints_bad/gedi/test_20.0" 6 | paths[${#paths[*]}]="checkpoints_bad/dexperts/test_5.0" 7 | paths[${#paths[*]}]="checkpoints_bad/director_02/test_10.0" 8 | paths[${#paths[*]}]="checkpoints_bad/cringe_02/test" 9 | paths[${#paths[*]}]="checkpoints_bad/contrast_05/20.0/test" 10 | 11 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch \ 12 | --mixed_precision no \ 13 | --num_processes 1 \ 14 | --num_machines 1 \ 15 | --num_cpu_threads_per_process 32 \ 16 | eval_ppl_blender.py \ 17 | --save_name large \ 18 | --pretrained_model_path /home/zhengchujie/pretrained-models-large/facebook/blenderbot-1B-distill \ 19 | --context_file data_bad/raw/test.txt \ 20 | --infer_data_paths ${paths[*]} \ 21 | --max_input_length 128 \ 22 | --max_decoder_input_length 32 \ 23 | --seed 42 \ 24 | --batch_size 800 25 | -------------------------------------------------------------------------------- /scripts_bad/eval/ppl_self.sh: -------------------------------------------------------------------------------- 1 | 2 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch \ 3 | --mixed_precision no \ 4 | --num_processes 1 \ 5 | --num_machines 1 \ 6 | --num_cpu_threads_per_process 32 \ 7 | eval_ppl_blender.py \ 8 | --save_name self \ 9 | --pretrained_model_path checkpoints_bad/contrast_05/20.0 \ 10 | --context_file data_bad/blender/train.txt \ 11 | --infer_data_paths checkpoints_bad/contrast_05/20.0/train \ 12 | --max_input_length 128 \ 13 | --max_decoder_input_length 32 \ 14 | 
--seed 42 \ 15 | --batch_size 1600 16 | 17 | -------------------------------------------------------------------------------- /scripts_bad/eval/predict.sh: -------------------------------------------------------------------------------- 1 | 2 | paths=() 3 | paths[${#paths[*]}]="checkpoints_bad/ft/test" 4 | paths[${#paths[*]}]="checkpoints_bad/unlikelihood_01/test" 5 | paths[${#paths[*]}]="checkpoints_bad/gedi/test_20.0" 6 | paths[${#paths[*]}]="checkpoints_bad/dexperts/test_5.0" 7 | paths[${#paths[*]}]="checkpoints_bad/director_02/test_10.0" 8 | paths[${#paths[*]}]="checkpoints_bad/cringe_02/test" 9 | paths[${#paths[*]}]="checkpoints_bad/contrast_05/20.0/test" 10 | 11 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch \ 12 | --mixed_precision no \ 13 | --num_processes 1 \ 14 | --num_machines 1 \ 15 | --num_cpu_threads_per_process 32 \ 16 | eval_bad.py \ 17 | --collator_name classification \ 18 | --model_name roberta \ 19 | --pretrained_model_path checkpoints_cls/bad \ 20 | --model_args 2 \ 21 | --context_file data_bad/raw/test.txt \ 22 | --infer_data_paths ${paths[*]} \ 23 | --max_input_length 192 \ 24 | --seed 42 \ 25 | --batch_size 800 26 | -------------------------------------------------------------------------------- /scripts_bad/eval/predict_train.sh: -------------------------------------------------------------------------------- 1 | 2 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch \ 3 | --mixed_precision no \ 4 | --num_processes 1 \ 5 | --num_machines 1 \ 6 | --num_cpu_threads_per_process 32 \ 7 | eval_bad.py \ 8 | --collator_name classification \ 9 | --model_name roberta \ 10 | --pretrained_model_path checkpoints_cls/bad \ 11 | --model_args 2 \ 12 | --context_file data_bad/raw/train.txt \ 13 | --infer_data_paths checkpoints_bad/contrast_05/20.0/train \ 14 | --max_input_length 192 \ 15 | --seed 42 \ 16 | --batch_size 1600 17 | -------------------------------------------------------------------------------- /scripts_bad/ft/generate.sh: -------------------------------------------------------------------------------- 1 | 2 | CUDA_VISIBLE_DEVICES=0 accelerate launch \ 3 | --mixed_precision no \ 4 | --num_processes 1 \ 5 | --num_machines 1 \ 6 | --num_cpu_threads_per_process 32 \ 7 | generate.py \ 8 | --collator_name text2text \ 9 | --model_name blender \ 10 | --pretrained_model_path checkpoints_bad/experts/expert \ 11 | --save_path checkpoints_bad/ft \ 12 | --infer_data_paths data_bad/blender/valid.txt data_bad/blender/test.txt \ 13 | --infer_names valid test \ 14 | --only_generate \ 15 | --max_input_length 128 \ 16 | --max_decoder_input_length 32 \ 17 | --seed 0 \ 18 | --lower \ 19 | --max_length 32 \ 20 | --min_length 5 \ 21 | --batch_size 6 \ 22 | --temperature 1 \ 23 | --top_k 0 \ 24 | --top_p 0.9 \ 25 | --num_beams 1 \ 26 | --num_return_sequences 25 \ 27 | --length_penalty 1 \ 28 | --repetition_penalty 1 \ 29 | --no_repeat_ngram_size 0 \ 30 | --encoder_no_repeat_ngram_size 0 31 | 32 | -------------------------------------------------------------------------------- /scripts_bad/gedi/generate.sh: -------------------------------------------------------------------------------- 1 | 2 | cuda=0 3 | 4 | for alpha in 20.0; do 5 | 6 | CUDA_VISIBLE_DEVICES=${cuda} accelerate launch \ 7 | --mixed_precision no \ 8 | --num_processes 1 \ 9 | --num_machines 1 \ 10 | --num_cpu_threads_per_process 32 \ 11 | generate.py \ 12 | --collator_name text2text_dexperts \ 13 | --model_name blender_gedinew \ 14 | --pretrained_model_path 
/home/zhengchujie/pretrained-models/facebook/blenderbot-400M-distill \ 15 | --model_args checkpoints_bad/experts/expert checkpoints_bad/experts/antiexpert ${alpha} \ 16 | --save_path checkpoints_bad/gedinew \ 17 | --infer_data_paths data_bad/dexperts/valid.txt data_bad/dexperts/test.txt \ 18 | --infer_names valid_${alpha} test_${alpha} \ 19 | --only_generate \ 20 | --max_input_length 128 \ 21 | --max_decoder_input_length 32 \ 22 | --seed 0 \ 23 | --lower \ 24 | --max_length 32 \ 25 | --min_length 5 \ 26 | --batch_size 8 \ 27 | --temperature 1 \ 28 | --top_k 0 \ 29 | --top_p 0.9 \ 30 | --num_beams 1 \ 31 | --num_return_sequences 25 \ 32 | --length_penalty 1 \ 33 | --repetition_penalty 1 \ 34 | --no_repeat_ngram_size 0 \ 35 | --encoder_no_repeat_ngram_size 0 36 | 37 | done 38 | -------------------------------------------------------------------------------- /scripts_bad/stats.sh: -------------------------------------------------------------------------------- 1 | 2 | paths=() 3 | paths[${#paths[*]}]="checkpoints_bad/ft/test" 4 | paths[${#paths[*]}]="checkpoints_bad/unlikelihood_01/test" 5 | paths[${#paths[*]}]="checkpoints_bad/gedi/test_20.0" 6 | paths[${#paths[*]}]="checkpoints_bad/dexperts/test_5.0" 7 | paths[${#paths[*]}]="checkpoints_bad/director_02/test_10.0" 8 | paths[${#paths[*]}]="checkpoints_bad/cringe_02/test" 9 | paths[${#paths[*]}]="checkpoints_bad/contrast_05/20.0/test" 10 | 11 | python results_bad.py --infer_data_paths ${paths[*]} --test 12 | -------------------------------------------------------------------------------- /scripts_bad/unlikelihood/generate.sh: -------------------------------------------------------------------------------- 1 | 2 | cuda=1 3 | 4 | for gamma in 01; do 5 | 6 | CUDA_VISIBLE_DEVICES=${cuda} accelerate launch \ 7 | --mixed_precision no \ 8 | --num_processes 1 \ 9 | --num_machines 1 \ 10 | --num_cpu_threads_per_process 32 \ 11 | generate.py \ 12 | --collator_name text2text \ 13 | --model_name blender \ 14 | --pretrained_model_path checkpoints_bad/unlikelihood_${gamma} \ 15 | --save_path checkpoints_bad/unlikelihood_${gamma} \ 16 | --infer_data_paths data_bad/blender/valid.txt data_bad/blender/test.txt \ 17 | --infer_names valid test \ 18 | --only_generate \ 19 | --max_input_length 128 \ 20 | --max_decoder_input_length 32 \ 21 | --seed 0 \ 22 | --lower \ 23 | --max_length 32 \ 24 | --min_length 5 \ 25 | --batch_size 6 \ 26 | --temperature 1 \ 27 | --top_k 0 \ 28 | --top_p 0.9 \ 29 | --num_beams 1 \ 30 | --num_return_sequences 25 \ 31 | --length_penalty 1 \ 32 | --repetition_penalty 1 \ 33 | --no_repeat_ngram_size 0 \ 34 | --encoder_no_repeat_ngram_size 0 35 | 36 | done 37 | -------------------------------------------------------------------------------- /scripts_bad/unlikelihood/train.sh: -------------------------------------------------------------------------------- 1 | 2 | cuda=7 3 | 4 | for gamma in 01; do 5 | 6 | CUDA_VISIBLE_DEVICES=${cuda} accelerate launch \ 7 | --mixed_precision fp16 \ 8 | --num_processes 1 \ 9 | --num_machines 1 \ 10 | --num_cpu_threads_per_process 32 \ 11 | train.py \ 12 | --collator_name text2text_labels \ 13 | --model_name blender_unlikelihood_${gamma} \ 14 | --pretrained_model_path /home/zhengchujie/pretrained-models/facebook/blenderbot-400M-distill \ 15 | --save_path checkpoints_bad/unlikelihood_${gamma} \ 16 | --train_data_path data_bad/labels/train.txt \ 17 | --max_input_length 128 \ 18 | --max_decoder_input_length 32 \ 19 | --seed 42 \ 20 | --adafactor \ 21 | --batch_size 64 \ 22 | --gradient_accumulation_steps 1 \ 23 
| --learning_rate 5e-5 \ 24 | --num_epochs 2 \ 25 | --warmup_steps 50 26 | 27 | done 28 | -------------------------------------------------------------------------------- /scripts_cls/bad/test.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=7 accelerate launch \ 2 | --mixed_precision no \ 3 | --num_processes 1 \ 4 | --num_machines 1 \ 5 | --num_cpu_threads_per_process 32 \ 6 | infer.py \ 7 | --collator_name classification \ 8 | --model_name roberta \ 9 | --pretrained_model_path checkpoints_cls/bad \ 10 | --model_args 2 \ 11 | --save_path checkpoints_cls/bad \ 12 | --max_input_length 192 \ 13 | --seed 0 \ 14 | --batch_size 256 \ 15 | --infer_data_paths data_cls/bad/valid.txt data_cls/bad/test.txt \ 16 | --infer_names valid test 17 | -------------------------------------------------------------------------------- /scripts_cls/bad/train.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=7 accelerate launch \ 2 | --mixed_precision fp16 \ 3 | --num_processes 1 \ 4 | --num_machines 1 \ 5 | --num_cpu_threads_per_process 32 \ 6 | train.py \ 7 | --collator_name classification \ 8 | --model_name roberta \ 9 | --pretrained_model_path /home/zhengchujie/pretrained-models/roberta-base \ 10 | --model_args 2 \ 11 | --save_path checkpoints_cls/bad \ 12 | --train_data_path data_cls/bad/train.txt \ 13 | --max_input_length 192 \ 14 | --seed 42 \ 15 | --adafactor \ 16 | --batch_size 64 \ 17 | --gradient_accumulation_steps 1 \ 18 | --learning_rate 1e-5 \ 19 | --num_epochs 1 \ 20 | --warmup_steps 50 21 | -------------------------------------------------------------------------------- /scripts_senti/eval/dist.sh: -------------------------------------------------------------------------------- 1 | 2 | for domain in neutral positive negative; do 3 | 4 | paths=() 5 | for setting in pos neg; do 6 | paths[${#paths[*]}]="checkpoints_senti/${setting}_contrast_01/15.0/${domain}" 7 | done 8 | 9 | python eval_dist.py \ 10 | --context_file data_senti/gpt2/${domain}.txt \ 11 | --infer_data_paths ${paths[*]} \ 12 | 13 | done 14 | -------------------------------------------------------------------------------- /scripts_senti/eval/ppl_large.sh: -------------------------------------------------------------------------------- 1 | 2 | for domain in positive negative neutral; do 3 | 4 | paths=() 5 | for setting in pos neg; do 6 | paths[${#paths[*]}]="checkpoints_senti/${setting}_contrast_01/15.0/${domain}" 7 | done 8 | 9 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch \ 10 | --mixed_precision no \ 11 | --num_processes 1 \ 12 | --num_machines 1 \ 13 | --num_cpu_threads_per_process 32 \ 14 | eval_ppl_gpt2.py \ 15 | --save_name large \ 16 | --pretrained_model_path /home/zhengchujie/pretrained-models-large/gpt2-xl \ 17 | --context_file data_senti/gpt2/${domain}.txt \ 18 | --infer_data_paths ${paths[*]} \ 19 | --max_input_length 15 \ 20 | --max_decoder_input_length 20 \ 21 | --seed 42 \ 22 | --batch_size 1600 23 | 24 | done 25 | -------------------------------------------------------------------------------- /scripts_senti/eval/ppl_self.sh: -------------------------------------------------------------------------------- 1 | 2 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch \ 3 | --mixed_precision no \ 4 | --num_processes 1 \ 5 | --num_machines 1 \ 6 | --num_cpu_threads_per_process 32 \ 7 | eval_ppl_gpt2.py \ 8 | --save_name self \ 9 | --pretrained_model_path checkpoints_senti/gpt2_both \ 10 | 
--context_file data_senti/gpt2/augment.txt \ 11 | --infer_data_paths checkpoints_senti/gpt2_both/augment \ 12 | --max_input_length 2 \ 13 | --max_decoder_input_length 33 \ 14 | --seed 42 \ 15 | --batch_size 1600 16 | -------------------------------------------------------------------------------- /scripts_senti/eval/predict.sh: -------------------------------------------------------------------------------- 1 | 2 | for domain in neutral positive negative; do 3 | 4 | paths=() 5 | for setting in pos neg; do 6 | paths[${#paths[*]}]="checkpoints_senti/${setting}_contrast_01/15.0/${domain}" 7 | done 8 | 9 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch \ 10 | --mixed_precision no \ 11 | --num_processes 1 \ 12 | --num_machines 1 \ 13 | --num_cpu_threads_per_process 32 \ 14 | eval_senti.py \ 15 | --collator_name classification \ 16 | --model_name distilbert \ 17 | --pretrained_model_path /home/zhengchujie/pretrained-models/distilbert-base-uncased-finetuned-sst-2-english \ 18 | --model_args 2 \ 19 | --context_file data_senti/gpt2/${domain}.txt \ 20 | --infer_data_paths ${paths[*]} \ 21 | --max_input_length 45 \ 22 | --seed 42 \ 23 | --batch_size 1600 24 | 25 | done 26 | -------------------------------------------------------------------------------- /scripts_senti/eval/predict_train.sh: -------------------------------------------------------------------------------- 1 | 2 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,7 accelerate launch \ 3 | --mixed_precision no \ 4 | --num_processes 1 \ 5 | --num_machines 1 \ 6 | --num_cpu_threads_per_process 32 \ 7 | eval_senti.py \ 8 | --collator_name classification \ 9 | --model_name distilbert \ 10 | --pretrained_model_path /home/zhengchujie/pretrained-models/distilbert-base-uncased-finetuned-sst-2-english \ 11 | --model_args 2 \ 12 | --context_file data_senti/gpt2/augment.txt \ 13 | --infer_data_paths checkpoints_senti/gpt2_both/augment \ 14 | --max_input_length 45 \ 15 | --seed 42 \ 16 | --batch_size 800 17 | -------------------------------------------------------------------------------- /scripts_senti/gpt2/augment.sh: -------------------------------------------------------------------------------- 1 | 2 | cuda=6 3 | seed=0 4 | 5 | CUDA_VISIBLE_DEVICES=${cuda} accelerate launch \ 6 | --mixed_precision no \ 7 | --num_processes 1 \ 8 | --num_machines 1 \ 9 | --num_cpu_threads_per_process 32 \ 10 | generate.py \ 11 | --collator_name gpt2_eval \ 12 | --model_name gpt2 \ 13 | --pretrained_model_path checkpoints_senti/gpt2_both \ 14 | --save_path checkpoints_senti/gpt2_both \ 15 | --infer_data_paths data_senti/gpt2/augment.txt \ 16 | --infer_names augment \ 17 | --only_generate \ 18 | --max_input_length 2 \ 19 | --max_decoder_input_length 33 \ 20 | --seed ${seed} \ 21 | --max_length 33 \ 22 | --min_length 20 \ 23 | --batch_size 16 \ 24 | --temperature 1 \ 25 | --top_k 0 \ 26 | --top_p 0.9 \ 27 | --num_beams 1 \ 28 | --num_return_sequences 20 \ 29 | --length_penalty 1 \ 30 | --repetition_penalty 1 \ 31 | --no_repeat_ngram_size 0 \ 32 | --encoder_no_repeat_ngram_size 0 33 | -------------------------------------------------------------------------------- /scripts_senti/gpt2/train.sh: -------------------------------------------------------------------------------- 1 | 2 | CUDA_VISIBLE_DEVICES=6 accelerate launch \ 3 | --mixed_precision fp16 \ 4 | --num_processes 1 \ 5 | --num_machines 1 \ 6 | --num_cpu_threads_per_process 32 \ 7 | train.py \ 8 | --collator_name gpt2 \ 9 | --model_name gpt2 \ 10 | --pretrained_model_path
/home/zhengchujie/pretrained-models-large/gpt2-large \ 11 | --save_path checkpoints_senti/gpt2_both \ 12 | --train_data_path data_senti/gpt2/train_both.txt \ 13 | --max_input_length 35 \ 14 | --seed 42 \ 15 | --adafactor \ 16 | --batch_size 32 \ 17 | --gradient_accumulation_steps 1 \ 18 | --learning_rate 5e-5 \ 19 | --num_epochs 2 \ 20 | --warmup_steps 50 21 | -------------------------------------------------------------------------------- /scripts_senti/neg_contrast/generate.sh: -------------------------------------------------------------------------------- 1 | 2 | cuda=2 3 | 4 | for model in 01; do 5 | for alpha in 15.0; do 6 | 7 | CUDA_VISIBLE_DEVICES=${cuda} accelerate launch \ 8 | --mixed_precision no \ 9 | --num_processes 1 \ 10 | --num_machines 1 \ 11 | --num_cpu_threads_per_process 32 \ 12 | generate.py \ 13 | --collator_name gpt2_eval \ 14 | --model_name gpt2 \ 15 | --pretrained_model_path checkpoints_senti/neg_contrast2_${model}/${alpha} \ 16 | --save_path checkpoints_senti/neg_contrast2_${model}/${alpha} \ 17 | --infer_data_paths data_senti/gpt2/neutral.txt data_senti/gpt2/positive.txt \ 18 | --infer_names neutral positive \ 19 | --only_generate \ 20 | --max_input_length 15 \ 21 | --max_decoder_input_length 20 \ 22 | --seed 0 \ 23 | --max_length 20 \ 24 | --min_length 15 \ 25 | --batch_size 16 \ 26 | --temperature 1 \ 27 | --top_k 0 \ 28 | --top_p 0.9 \ 29 | --num_beams 1 \ 30 | --num_return_sequences 25 \ 31 | --length_penalty 1 \ 32 | --repetition_penalty 1 \ 33 | --no_repeat_ngram_size 0 \ 34 | --encoder_no_repeat_ngram_size 0 35 | 36 | done 37 | done 38 | -------------------------------------------------------------------------------- /scripts_senti/neg_contrast/train.sh: -------------------------------------------------------------------------------- 1 | 2 | cuda=0 3 | 4 | for model in 01; do 5 | for alpha in 15.0; do 6 | 7 | CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \ 8 | --mixed_precision fp16 \ 9 | --num_processes 4 \ 10 | --multi_gpu \ 11 | --num_machines 1 \ 12 | --num_cpu_threads_per_process 32 \ 13 | --main_process_port 29504 \ 14 | train.py \ 15 | --collator_name gpt2_contrast \ 16 | --model_name gpt2_contrast_${model} \ 17 | --pretrained_model_path /home/zhengchujie/pretrained-models-large/gpt2-large \ 18 | --model_args ${alpha} \ 19 | --save_path checkpoints_senti/neg_contrast2_${model}/${alpha} \ 20 | --train_data_path data_senti/neg_contrast2/train.txt \ 21 | --max_input_length 35 \ 22 | --seed 42 \ 23 | --adafactor \ 24 | --batch_size 8 \ 25 | --gradient_accumulation_steps 2 \ 26 | --learning_rate 5e-5 \ 27 | --num_epochs 2 \ 28 | --warmup_steps 50 29 | 30 | done 31 | done 32 | -------------------------------------------------------------------------------- /scripts_senti/pos_contrast/generate.sh: -------------------------------------------------------------------------------- 1 | 2 | cuda=1 3 | 4 | for model in 01; do 5 | for alpha in 15.0; do 6 | 7 | CUDA_VISIBLE_DEVICES=${cuda} accelerate launch \ 8 | --mixed_precision no \ 9 | --num_processes 1 \ 10 | --num_machines 1 \ 11 | --num_cpu_threads_per_process 32 \ 12 | generate.py \ 13 | --collator_name gpt2_eval \ 14 | --model_name gpt2 \ 15 | --pretrained_model_path checkpoints_senti/pos_contrast2_${model}/${alpha} \ 16 | --save_path checkpoints_senti/pos_contrast2_${model}/${alpha} \ 17 | --infer_data_paths data_senti/gpt2/neutral.txt data_senti/gpt2/negative.txt \ 18 | --infer_names neutral negative \ 19 | --only_generate \ 20 | --max_input_length 15 \ 21 | --max_decoder_input_length 20 \ 22 | --seed 0 
\ 23 | --max_length 20 \ 24 | --min_length 15 \ 25 | --batch_size 16 \ 26 | --temperature 1 \ 27 | --top_k 0 \ 28 | --top_p 0.9 \ 29 | --num_beams 1 \ 30 | --num_return_sequences 25 \ 31 | --length_penalty 1 \ 32 | --repetition_penalty 1 \ 33 | --no_repeat_ngram_size 0 \ 34 | --encoder_no_repeat_ngram_size 0 35 | 36 | done 37 | done 38 | -------------------------------------------------------------------------------- /scripts_senti/pos_contrast/train.sh: -------------------------------------------------------------------------------- 1 | 2 | cuda=0 3 | 4 | for model in 01; do 5 | for alpha in 15.0; do 6 | 7 | CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \ 8 | --mixed_precision fp16 \ 9 | --num_processes 4 \ 10 | --multi_gpu \ 11 | --num_machines 1 \ 12 | --num_cpu_threads_per_process 32 \ 13 | --main_process_port 29504 \ 14 | train.py \ 15 | --collator_name gpt2_contrast \ 16 | --model_name gpt2_contrast_${model} \ 17 | --pretrained_model_path /home/zhengchujie/pretrained-models-large/gpt2-large \ 18 | --model_args ${alpha} \ 19 | --save_path checkpoints_senti/pos_contrast2_${model}/${alpha} \ 20 | --train_data_path data_senti/pos_contrast2/train.txt \ 21 | --max_input_length 35 \ 22 | --seed 42 \ 23 | --adafactor \ 24 | --batch_size 8 \ 25 | --gradient_accumulation_steps 2 \ 26 | --learning_rate 5e-5 \ 27 | --num_epochs 2 \ 28 | --warmup_steps 50 29 | 30 | done 31 | done 32 | 33 | bash scripts_senti/neg_contrast/train.sh 34 | -------------------------------------------------------------------------------- /scripts_senti/stats_neg.sh: -------------------------------------------------------------------------------- 1 | 2 | setting=neg 3 | paths=() 4 | paths[${#paths[*]}]="checkpoints_senti/${setting}_contrast_01/15.0" 5 | 6 | python results_senti.py --infer_data_paths ${paths[*]} --negative 7 | -------------------------------------------------------------------------------- /scripts_senti/stats_pos.sh: -------------------------------------------------------------------------------- 1 | 2 | setting=pos 3 | paths=() 4 | paths[${#paths[*]}]="checkpoints_senti/${setting}_contrast_01/15.0" 5 | 6 | python results_senti.py --infer_data_paths ${paths[*]} --positive 7 | -------------------------------------------------------------------------------- /scripts_wiki/contrast/generate.sh: -------------------------------------------------------------------------------- 1 | 2 | cuda=3 3 | 4 | for model in 03; do 5 | for alpha in 15.0; do 6 | 7 | CUDA_VISIBLE_DEVICES=${cuda} accelerate launch \ 8 | --mixed_precision no \ 9 | --num_processes 1 \ 10 | --num_machines 1 \ 11 | --num_cpu_threads_per_process 32 \ 12 | generate.py \ 13 | --collator_name gpt2_eval \ 14 | --model_name gpt2 \ 15 | --pretrained_model_path checkpoints_wiki/contrast_${model}/${alpha} \ 16 | --save_path checkpoints_wiki/contrast_${model}/${alpha} \ 17 | --infer_data_paths data_wiki/gpt2/valid.txt data_wiki/gpt2/test.txt \ 18 | --infer_names valid_greedy test_greedy \ 19 | --max_input_length 35 \ 20 | --max_decoder_input_length 128 \ 21 | --seed 0 \ 22 | --max_length 128 \ 23 | --min_length 64 \ 24 | --batch_size 64 \ 25 | --temperature 1 \ 26 | --top_k 0 \ 27 | --top_p 1 \ 28 | --num_beams 1 \ 29 | --num_return_sequences 1 \ 30 | --length_penalty 1 \ 31 | --repetition_penalty 1 \ 32 | --no_repeat_ngram_size 0 \ 33 | --encoder_no_repeat_ngram_size 0 34 | 35 | done 36 | done 37 | -------------------------------------------------------------------------------- /scripts_wiki/contrast/train.sh:
-------------------------------------------------------------------------------- 1 | 2 | cuda=0 3 | 4 | for model in 03; do 5 | for alpha in 15.0; do 6 | 7 | CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \ 8 | --mixed_precision fp16 \ 9 | --num_processes 4 \ 10 | --multi_gpu \ 11 | --num_machines 1 \ 12 | --num_cpu_threads_per_process 32 \ 13 | --main_process_port 29504 \ 14 | train.py \ 15 | --collator_name gpt2_contrast \ 16 | --model_name gpt2_contrast_${model} \ 17 | --pretrained_model_path checkpoints_wiki/gpt2 \ 18 | --model_args ${alpha} \ 19 | --save_path checkpoints_wiki/contrast_${model}/${alpha} \ 20 | --train_data_path data_wiki/contrast/train.txt \ 21 | --max_input_length 256 \ 22 | --max_decoder_input_length 160 \ 23 | --seed 42 \ 24 | --adafactor \ 25 | --batch_size 16 \ 26 | --gradient_accumulation_steps 2 \ 27 | --learning_rate 5e-5 \ 28 | --num_optim_steps 5000 \ 29 | --warmup_steps 200 30 | 31 | done 32 | done 33 | -------------------------------------------------------------------------------- /scripts_wiki/eval/div.sh: -------------------------------------------------------------------------------- 1 | 2 | paths=() 3 | paths[${#paths[*]}]="checkpoints_wiki/gpt2/augment/train_0.5" 4 | paths[${#paths[*]}]="checkpoints_wiki/gpt2/augment/train_0.7" 5 | paths[${#paths[*]}]="checkpoints_wiki/gpt2/augment/train_0.9" 6 | 7 | python eval_div.py \ 8 | --context_file data_wiki/gpt2/train_augment.txt \ 9 | --infer_data_paths ${paths[*]} 10 | -------------------------------------------------------------------------------- /scripts_wiki/eval/mavue.sh: -------------------------------------------------------------------------------- 1 | 2 | paths=() 3 | paths[${#paths[*]}]="checkpoints_wiki/contrast_03/15.0/test_greedy" 4 | 5 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python eval_mauve.py \ 6 | --context_file data_wiki/gpt2/test.txt \ 7 | --infer_data_paths ${paths[*]} \ 8 | --batch_size 320 9 | -------------------------------------------------------------------------------- /scripts_wiki/gpt2/augment_0.5.sh: -------------------------------------------------------------------------------- 1 | 2 | cuda=7 3 | p=0.5 4 | num=3 5 | 6 | CUDA_VISIBLE_DEVICES=${cuda} accelerate launch \ 7 | --mixed_precision no \ 8 | --num_processes 1 \ 9 | --num_machines 1 \ 10 | --num_cpu_threads_per_process 32 \ 11 | generate.py \ 12 | --collator_name gpt2_eval \ 13 | --model_name gpt2 \ 14 | --pretrained_model_path checkpoints_wiki/gpt2 \ 15 | --save_path checkpoints_wiki/gpt2/augment/train_${p} \ 16 | --infer_data_paths data_wiki/gpt2/train_augment.txt \ 17 | --infer_names train \ 18 | --only_generate \ 19 | --max_input_length 32 \ 20 | --max_decoder_input_length 128 \ 21 | --seed 0 \ 22 | --max_length 128 \ 23 | --min_length 64 \ 24 | --batch_size 128 \ 25 | --temperature 1 \ 26 | --top_k 0 \ 27 | --top_p ${p} \ 28 | --num_beams 1 \ 29 | --num_return_sequences ${num} \ 30 | --length_penalty 1 \ 31 | --repetition_penalty 1 \ 32 | --no_repeat_ngram_size 0 \ 33 | --encoder_no_repeat_ngram_size 0 34 | -------------------------------------------------------------------------------- /scripts_wiki/gpt2/augment_0.7.sh: -------------------------------------------------------------------------------- 1 | 2 | cuda=7 3 | p=0.7 4 | num=4 5 | 6 | CUDA_VISIBLE_DEVICES=${cuda} accelerate launch \ 7 | --mixed_precision no \ 8 | --num_processes 1 \ 9 | --num_machines 1 \ 10 | --num_cpu_threads_per_process 32 \ 11 | generate.py \ 12 | --collator_name gpt2_eval \ 13 | --model_name gpt2 \ 14 | --pretrained_model_path
checkpoints_wiki/gpt2 \ 15 | --save_path checkpoints_wiki/gpt2/augment/train_${p} \ 16 | --infer_data_paths data_wiki/gpt2/train_augment.txt \ 17 | --infer_names train \ 18 | --only_generate \ 19 | --max_input_length 32 \ 20 | --max_decoder_input_length 128 \ 21 | --seed 0 \ 22 | --max_length 128 \ 23 | --min_length 64 \ 24 | --batch_size 128 \ 25 | --temperature 1 \ 26 | --top_k 0 \ 27 | --top_p ${p} \ 28 | --num_beams 1 \ 29 | --num_return_sequences ${num} \ 30 | --length_penalty 1 \ 31 | --repetition_penalty 1 \ 32 | --no_repeat_ngram_size 0 \ 33 | --encoder_no_repeat_ngram_size 0 34 | -------------------------------------------------------------------------------- /scripts_wiki/gpt2/augment_0.9.sh: -------------------------------------------------------------------------------- 1 | 2 | cuda=7 3 | p=0.9 4 | num=5 5 | 6 | CUDA_VISIBLE_DEVICES=${cuda} accelerate launch \ 7 | --mixed_precision no \ 8 | --num_processes 1 \ 9 | --num_machines 1 \ 10 | --num_cpu_threads_per_process 32 \ 11 | generate.py \ 12 | --collator_name gpt2_eval \ 13 | --model_name gpt2 \ 14 | --pretrained_model_path checkpoints_wiki/gpt2 \ 15 | --save_path checkpoints_wiki/gpt2/augment/train_${p} \ 16 | --infer_data_paths data_wiki/gpt2/train_augment.txt \ 17 | --infer_names train \ 18 | --only_generate \ 19 | --max_input_length 32 \ 20 | --max_decoder_input_length 128 \ 21 | --seed 0 \ 22 | --max_length 128 \ 23 | --min_length 64 \ 24 | --batch_size 128 \ 25 | --temperature 1 \ 26 | --top_k 0 \ 27 | --top_p ${p} \ 28 | --num_beams 1 \ 29 | --num_return_sequences ${num} \ 30 | --length_penalty 1 \ 31 | --repetition_penalty 1 \ 32 | --no_repeat_ngram_size 0 \ 33 | --encoder_no_repeat_ngram_size 0 34 | -------------------------------------------------------------------------------- /scripts_wiki/gpt2/train.sh: -------------------------------------------------------------------------------- 1 | 2 | CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch \ 3 | --mixed_precision fp16 \ 4 | --num_processes 4 \ 5 | --multi_gpu \ 6 | --num_machines 1 \ 7 | --num_cpu_threads_per_process 32 \ 8 | --main_process_port 29504 \ 9 | train.py \ 10 | --collator_name gpt2 \ 11 | --model_name gpt2 \ 12 | --pretrained_model_path /home/zhengchujie/pretrained-models/gpt2-small \ 13 | --save_path checkpoints_wiki/gpt2 \ 14 | --train_data_path data_wiki/gpt2/train.txt \ 15 | --max_input_length 256 \ 16 | --seed 42 \ 17 | --adafactor \ 18 | --batch_size 32 \ 19 | --gradient_accumulation_steps 1 \ 20 | --learning_rate 2e-5 \ 21 | --num_optim_steps 40000 \ 22 | --warmup_steps 4000 23 | -------------------------------------------------------------------------------- /scripts_wiki/stats.sh: -------------------------------------------------------------------------------- 1 | 2 | paths=() 3 | paths[${#paths[*]}]="checkpoints_wiki/contrast_03/15.0/test_greedy" 4 | 5 | python results_wiki.py --infer_data_paths ${paths[*]} --test 6 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chujiezheng/Click/494e90434e76432a4b1da65fb2efe1dd992ebeba/utils/__init__.py -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chujiezheng/Click/494e90434e76432a4b1da65fb2efe1dd992ebeba/utils/__pycache__/__init__.cpython-39.pyc
-------------------------------------------------------------------------------- /utils/__pycache__/building_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chujiezheng/Click/494e90434e76432a4b1da65fb2efe1dd992ebeba/utils/__pycache__/building_utils.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/cuda_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chujiezheng/Click/494e90434e76432a4b1da65fb2efe1dd992ebeba/utils/__pycache__/cuda_utils.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dataloader_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chujiezheng/Click/494e90434e76432a4b1da65fb2efe1dd992ebeba/utils/__pycache__/dataloader_utils.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/eval_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chujiezheng/Click/494e90434e76432a4b1da65fb2efe1dd992ebeba/utils/__pycache__/eval_utils.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/model_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chujiezheng/Click/494e90434e76432a4b1da65fb2efe1dd992ebeba/utils/__pycache__/model_utils.cpython-39.pyc -------------------------------------------------------------------------------- /utils/building_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import json 4 | import logging 5 | import os 6 | 7 | from importlib import import_module 8 | import torch 9 | from transformers import (AutoTokenizer, AutoConfig) 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def boolean_string(s): 15 | if s.lower() not in {'false', 'true'}: 16 | raise ValueError('Not a valid boolean string') 17 | return s.lower() == 'true' 18 | 19 | 20 | def build_model(args, only_toker=False, checkpoint=None, process_index=-1): 21 | # blenderbot tokenizer would add a mask token by default, so we abandon it 22 | if 'blenderbot-' in args.pretrained_model_path: 23 | toker = AutoTokenizer.from_pretrained(args.pretrained_model_path, mask_token=None, use_fast=False) 24 | else: 25 | toker = AutoTokenizer.from_pretrained(args.pretrained_model_path, use_fast=False) 26 | if only_toker: 27 | return toker 28 | 29 | # import the model from ``models'' 30 | Model = getattr(import_module('models.' 
+ args.model_name), 'Model') 31 | model = Model.from_pretrained(args.pretrained_model_path, *args.model_args) 32 | if hasattr(args, 'gradient_checkpointing') and args.gradient_checkpointing: 33 | model.gradient_checkpointing_enable() 34 | model.tie_tokenizer_and_post_init(toker, process_index) 35 | 36 | if checkpoint is not None and os.path.exists(checkpoint): 37 | if process_index == 0: 38 | logger.info('loading finetuned model from %s' % checkpoint) 39 | model.load_state_dict(torch.load(checkpoint, map_location=torch.device('cpu')), strict=False) 40 | 41 | return toker, model 42 | -------------------------------------------------------------------------------- /utils/dataloader_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import math 4 | from functools import partial 5 | import nltk 6 | from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler, BatchSampler 7 | 8 | 9 | def _norm(s): 10 | return ' '.join(s.strip().split()) 11 | 12 | 13 | def _norm1(x): 14 | x = " ".join(x.strip().split()) 15 | xs = [] 16 | for e in nltk.sent_tokenize(x): 17 | xs.append(e.capitalize()) 18 | x = ' '.join(xs) 19 | return x 20 | 21 | 22 | def _norm2(x: str): 23 | return " ".join(x.strip().split()).replace(" l o l ", " lol ").replace(" ' m ", "'m ")\ 24 | .replace(" ’ t ", "'t ").replace(" ’ ll ", "'ll ").replace(" ’ s ", "'s ").replace(" ’ ve ", "'ve ").replace(" ’ re ", "'re ")\ 25 | .replace(" ' t ", "'t ").replace(" ' ll ", "'ll ").replace(" ' s ", "'s ").replace(" ' ve ", "'ve ").replace(" ' re ", "'re ")\ 26 | .replace(" .", ".").replace(" ,", ",").replace(" ?", '?').replace(" !", '!') 27 | 28 | 29 | class BasicDataset(Dataset): 30 | def __init__(self, data_list): 31 | self.data_list = data_list 32 | 33 | def __getitem__(self, i): 34 | return self.data_list[i] 35 | 36 | def __len__(self): 37 | return len(self.data_list) 38 | 39 | 40 | class BatchDataLoader(DataLoader): 41 | def __init__(self, 42 | data_list=None, data_path=None, batch_size=None, 43 | collate_fn=None, shuffle=True, num_workers=16, 44 | ): 45 | if data_list is None: 46 | data_list = [json.loads(e) for e in open(data_path)] 47 | dataset = BasicDataset(data_list) 48 | basic_sampler = RandomSampler if shuffle else SequentialSampler 49 | sampler = BatchSampler(basic_sampler(dataset), batch_size=batch_size, drop_last=False) 50 | super().__init__(dataset, batch_sampler=sampler, num_workers=num_workers, collate_fn=collate_fn) 51 | -------------------------------------------------------------------------------- /utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | import torch 7 | from tqdm import tqdm 8 | from torch import Tensor 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | @torch.no_grad() 14 | def eval_model_loss(accelerator, model, eval_dataloader, epoch_id, infer): 15 | # use the same signature with eval_model_generation 16 | if accelerator.process_index == 0: 17 | logger.info('compute eval model loss, using eval mode, ' 18 | 'please change it back to train after calling this function') 19 | model.eval() 20 | tot_loss = 0. 21 | tot_acc = 0. 22 | tot_rep = 0. 23 | tot_wrep = 0. 
24 | tot_sample = 0 25 | pointwise_loss = [] 26 | pointwise_sample = [] 27 | with torch.no_grad(): 28 | if accelerator.process_index == 0: 29 | pbar = tqdm(eval_dataloader, total=len(eval_dataloader), desc='evaluation', dynamic_ncols=True, leave=True) 30 | else: 31 | pbar = eval_dataloader 32 | for batch in pbar: 33 | loss_sample, n_sample, acc, rep, wrep, *_ = model( 34 | validation=True, 35 | **batch 36 | ) 37 | if torch.isnan(loss_sample).sum().cpu().long().numpy() > 0: 38 | logger.info(f'process_index {accelerator.process_index}: NaN occurring!') 39 | exit() 40 | tot_loss += loss_sample.sum() 41 | tot_acc += acc.sum() 42 | tot_rep += rep.sum() 43 | tot_wrep += wrep.sum() 44 | tot_sample += n_sample.sum() 45 | if infer: 46 | pointwise_loss.extend(loss_sample.cpu().tolist()) 47 | pointwise_sample.extend(n_sample.cpu().tolist()) 48 | 49 | if accelerator.process_index == 0: 50 | tot_loss = accelerator.reduce(tot_loss) 51 | tot_sample = accelerator.reduce(tot_sample) 52 | 53 | tot_loss = np.sum(tot_loss.cpu().float().numpy()) 54 | tot_acc = np.sum(tot_acc.cpu().float().numpy()) 55 | tot_rep = np.sum(tot_rep.cpu().float().numpy()) 56 | tot_wrep = np.sum(tot_wrep.cpu().float().numpy()) 57 | tot_sample = np.sum(tot_sample.cpu().float().numpy()) 58 | mean_loss = tot_loss / tot_sample 59 | mean_ppl = np.exp(mean_loss) 60 | mean_acc = tot_acc / tot_sample * 100 61 | mean_rep = tot_rep / tot_sample * 100 62 | mean_wrep = tot_wrep / tot_sample * 100 63 | if accelerator.process_index == 0: 64 | logger.info(f"Epoch {epoch_id}: Val loss {mean_loss} Val ppl {mean_ppl}") 65 | return mean_loss, mean_ppl, mean_acc, mean_rep, mean_wrep, pointwise_loss, pointwise_sample 66 | -------------------------------------------------------------------------------- /utils/model_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import logging 3 | 4 | import torch 5 | from transformers.tokenization_utils import PreTrainedTokenizer 6 | from transformers.configuration_utils import PretrainedConfig 7 | from transformers.modeling_utils import PreTrainedModel 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class BaseModel(PreTrainedModel): 13 | def __init__(self, config: PretrainedConfig): 14 | super().__init__(config) 15 | self.toker = None 16 | 17 | def tie_tokenizer_and_post_init(self, toker: PreTrainedTokenizer, process_index=0): 18 | # tying tokenizer is useful 19 | self.toker = toker 20 | old_num_tokens, _ = self.get_input_embeddings().weight.size() 21 | if len(self.toker) != old_num_tokens: 22 | self.resize_token_embeddings(len(self.toker)) 23 | if process_index == 0: 24 | logger.info(f'resize token embeddings from {old_num_tokens} to {len(self.toker)}') 25 | self.resize_token_embeddings(len(self.toker)) 26 | #self.init_new_tokens() 27 | self.init_new_tokens_with_semantic() 28 | self.init_new_layers() 29 | 30 | def init_new_tokens(self): 31 | # if we add new tokens, initialize them here 32 | pass 33 | 34 | def init_new_tokens_with_semantic(self): 35 | # we may need to initialize newly added tokens, with semantic initialization 36 | process = lambda x: self.toker.convert_tokens_to_ids(self.toker.tokenize(x)) 37 | for i in range(self.toker.vocab_size, len(self.toker)): 38 | token = self.toker.convert_ids_to_tokens([i])[0] 39 | token = token[1:-1] 40 | ids = torch.LongTensor(process(token)) 41 | embeds = torch.index_select(self.get_input_embeddings().weight.data.detach(), 0, ids) 42 | embeds = torch.mean(embeds, dim=0) 43 | 
self.get_input_embeddings().weight.data[i] = embeds 44 | self.tie_weights() 45 | 46 | def init_new_layers(self): 47 | # if we add new layers, initialize them here 48 | pass 49 | 50 | @torch.no_grad() 51 | def generate( 52 | self, 53 | input_ids=None, 54 | attention_mask=None, 55 | decoder_input_ids=None, 56 | **kwargs 57 | ): 58 | """ 59 | (input_ids, attention_mask, decoder_input_ids) 60 | (input_ids, attention_mask) 61 | """ 62 | assert not self.training 63 | assert self.toker is not None 64 | assert input_ids.size(0) == 1 or ((input_ids is None) == (attention_mask is None)) 65 | if not kwargs.get('min_length', None): 66 | raise KeyError 67 | if not kwargs.get('max_new_tokens', None) and not kwargs.get('max_length', None): 68 | raise KeyError 69 | kwargs['use_cache'] = True 70 | 71 | # bad_words_ids 72 | bad_words_ids = kwargs.get('bad_words_ids', []) 73 | for e in [self.toker.pad_token_id, self.toker.unk_token_id, self.toker.bos_token_id]: 74 | if e is not None and e != self.toker.eos_token_id: 75 | bad_words_ids.append([e]) 76 | if len(self.toker) > self.toker.vocab_size: 77 | bad_words_ids.extend([[i] for i in range(self.toker.vocab_size, len(self.toker))]) 78 | if bad_words_ids: 79 | kwargs['bad_words_ids'] = bad_words_ids 80 | 81 | # prepare the prefix ids for generation, and use prefix_length to truncate the generation output 82 | if self.config.is_encoder_decoder: 83 | if decoder_input_ids is not None: 84 | kwargs['decoder_input_ids'] = decoder_input_ids 85 | prefix_length = decoder_input_ids.size(1) 86 | else: 87 | prefix_length = 1 # removing bos 88 | else: 89 | prefix_length = input_ids.size(1) if input_ids is not None else 1 90 | 91 | # generation length 92 | kwargs['min_length'] = prefix_length + kwargs.get('min_length', 1) 93 | if kwargs.get('max_new_tokens', None): 94 | kwargs['max_length'] = prefix_length + kwargs['max_new_tokens'] 95 | kwargs.pop('max_new_tokens') 96 | 97 | # generate! 98 | generations = super().generate( 99 | input_ids=input_ids, 100 | attention_mask=attention_mask, 101 | **kwargs 102 | ) 103 | if kwargs.get('return_dict_in_generate', False): 104 | generations.sequences = generations.sequences[:, prefix_length:] 105 | return generations 106 | else: 107 | return generations[:, prefix_length:] 108 | --------------------------------------------------------------------------------
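As a usage note for the `BaseModel.generate` wrapper in `utils/model_utils.py`: it insists on `min_length` plus one of `max_new_tokens`/`max_length` (raising `KeyError` otherwise) and strips the prompt/prefix from the returned ids. A hypothetical call, assuming `toker` and `model` come from `utils.building_utils.build_model`:

```python
# Hypothetical usage sketch; `toker` and `model` are assumed to come from
# utils.building_utils.build_model(args).
inputs = toker(["i watched this movie"], return_tensors="pt")
generations = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    do_sample=True,
    top_p=0.9,
    min_length=5,        # required: the wrapper raises KeyError if missing
    max_new_tokens=32,   # converted internally to max_length = prefix + 32
    num_return_sequences=25,
)
# The prefix has already been removed by the wrapper, so the ids decode
# directly into continuations.
texts = toker.batch_decode(generations, skip_special_tokens=True)
```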