├── LICENSE
├── README.md
├── classify
│   ├── eval.py
│   ├── indep_eval.sh
│   └── logs
│       └── AE_final.txt
├── multicontrol
│   ├── generate_config_final.json
│   ├── generate_multi.py
│   ├── generate_multi.sh
│   ├── generation_utils.py
│   ├── model.py
│   ├── requirements.txt
│   ├── train_multi.py
│   └── train_multi.sh
├── priorcontrol
│   ├── combine_eval.sh
│   ├── generate_combine.py
│   ├── generate_combine.sh
│   ├── generate_combine_optim.py
│   ├── generate_config_combine.json
│   ├── generate_config_combine_optim.json
│   ├── generate_prior.py
│   ├── generate_prior.sh
│   ├── latentops_modules.py
│   ├── model.py
│   ├── requirements.txt
│   ├── single_eval.sh
│   ├── train_prior_only.py
│   └── train_prior_only.sh
└── res
    ├── multicontrol
    │   └── predict_final.txt
    └── priorcontrol
        ├── generate_combination.txt
        ├── generate_combination_optim.txt
        ├── generate_combination_optimcons.txt
        ├── generate_prior.txt
        └── generate_prior_extend.txt

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 HappyGu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MultiControl & PriorControl
**MultiControl**: [A Distributional Lens for Multi-Aspect Controllable Text Generation](https://arxiv.org/pdf/2210.02889.pdf) *EMNLP 2022 Oral*

**PriorControl**: [Controllable Text Generation via Probability Density Estimation in the Latent Space](https://arxiv.org/pdf/2212.08307.pdf) *ACL 2023*


## File structure
```
├── LICENSE
├── README.md
├── classify
│   ├── eval.py                            # Evaluation metric
│   ├── indep_eval.sh                      # Evaluation shell
│   ├── logs
│   │   └── AE_final.txt                   # Final Score for MultiControl
│   └── model                              # Evaluator checkpoints
│       ├── AGnews-checkpoint-6000         # Topic Evaluation
│       │   ├── config.json
│       │   └── pytorch_model.bin
│       ├── Toxic-checkpoint-3000          # Toxic Evaluation
│       │                                  # Note that this is only for validation.
│       │                                  # We use Perspective API for the final toxic evaluation.
│       │   ├── config.json
│       │   └── pytorch_model.bin
│       └── Yelp2-checkpoint-64000         # Sentiment Evaluation
│           ├── config.json
│           └── pytorch_model.bin
├── data                                   # Training data
├── model
│   ├── multicontrol                       # Model Folder for MultiControl
│   │   └── checkpoint-30000               # Model Checkpoint
│   │       │                              # We preserve the parameters of the Encoder, Fixed Decoder, and Mapping Layer together for simplicity.
│   │       │                              # You can keep only the searched Intersection Prefixes and drop the other parts.
│   │       └── pytorch_model.bin
│   └── priorcontrol                       # Model Folder for PriorControl
│       └── All_notricks_checkpoint-300000
│           │                              # We preserve the parameters of the Encoder, Fixed Decoder, Mapping Layer, and Normalizing Flows together for simplicity.
│           │                              # You need to keep the Prior Heads and Mapping Layer before dropping the other parts.
│           └── pytorch_model.bin
├── multicontrol                           # Code and Script Folder for MultiControl
│   ├── generate_config_final.json         # Config for Multi-Aspect Controllable Generation
│   ├── generate_multi.py                  # Generation file
│   ├── generate_multi.sh                  # Generation shell
│   ├── generation_utils.py                # Intersection Searching Algorithm
│   ├── model.py                           # Our model with AE structure, PrefixTuning strategy, and all losses for the Attribute Space
│   ├── requirements.txt                   # All based on huggingface/transformers
│   ├── train_multi.py                     # Training file
│   └── train_multi.sh                     # Training shell
├── priorcontrol                           # Code and Script Folder for PriorControl
│   ├── logs                               # Folder for Final Scores
│   ├── combine_eval.sh                    # Evaluation Script for Multi-Attribute Control
│   ├── generate_combine_optim.py          # Code for Multi-Attribute Control via Optimization
│   ├── generate_combine.py                # Code for Multi-Attribute Control via Interpolation
│   ├── generate_combine.sh                # Script for Multi-Attribute Control
│   ├── generate_config_combine_optim.json # Config for Multi-Attribute Control via Optimization
│   ├── generate_config_combine.json       # Config for Multi-Attribute Control
│   ├── generate_prior.py                  # Code for Single-Attribute Control
│   ├── generate_prior.sh                  # Script for Single-Attribute Control
│   ├── latentops_modules.py               # Code modified from LatentOps, which provides the ODE used for Optimization
│   ├── model.py                           # Our model with AE structure, PrefixTuning strategy, Normalizing Flows, and all losses for the Attribute and Prior Space
│   ├── requirements.txt                   # Framework: huggingface/transformers, Normalizing Flows: FrEIA, ODE Optimization: torchdiffeq
│   ├── single_eval.sh                     # Evaluation Script for Single-Attribute Control
│   ├── train_prior_only.py                # Training file for Normalizing Flows with a fixed AE
│   └── train_prior_only.sh                # Training Script
└── res
    ├── multicontrol
    │   └── predict_final.txt              # Final Generated Sentences for MultiControl
    └── priorcontrol                       # Final Generated Sentences for PriorControl
        ├── generate_combination_optim.txt
        ├── generate_combination_optimcons.txt
        ├── generate_combination.txt
        ├── generate_prior_extend.txt
        └── generate_prior.txt

```
## MultiControl
### Train
```
sh train_multi.sh
```
### Generate & Test
```
sh generate_multi.sh
```

## PriorControl
### Train
```
sh train_prior_only.sh
```
### Generate & Test
```
### Single-Attribute Control ###

sh generate_prior.sh
sh single_eval.sh

### Multi-Attribute Control ###

sh generate_combine.sh
sh combine_eval.sh
```

## Model & Data

Checkpoint available at:
https://drive.google.com/drive/folders/14XHSG4IAGlAL9t-SYoTUKnAs5ARqHd5f?usp=sharing


## Cite Us
### MultiControl
```
@inproceedings{gu-etal-2022-distributional,
    title = "A Distributional Lens for Multi-Aspect Controllable Text Generation",
    author = "Gu, Yuxuan  and
      Feng, Xiaocheng  and
      Ma, Sicheng  and
      Zhang, Lingyuan  and
      Gong, Heng  and
      Qin, Bing",
    booktitle = "Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing",
    month = dec,
    year = "2022",
    address = "Abu Dhabi, United Arab Emirates",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2022.emnlp-main.67",
    pages = "1023--1043",
}
```
### PriorControl
```
@inproceedings{gu-etal-2023-controllable,
    title = "Controllable Text Generation via Probability Density Estimation in the Latent Space",
    author = "Gu, Yuxuan  and
      Feng, Xiaocheng  and
      Ma, Sicheng  and
      Zhang, Lingyuan  and
      Gong, Heng  and
      Zhong, Weihong  and
      Qin, Bing",
    booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
    month = jul,
    year = "2023",
    address = "Toronto, Canada",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2023.acl-long.704",
    pages = "12590--12616",
}
```

--------------------------------------------------------------------------------
/classify/eval.py:
--------------------------------------------------------------------------------
import json
import torch
import transformers
from transformers import DebertaV2Tokenizer, DebertaV2ForSequenceClassification
import argparse
import datasets
from datasets import load_dataset, load_metric, concatenate_datasets, Dataset
from transformers import Trainer, TrainingArguments
from sklearn.metrics import accuracy_score, precision_recall_fscore_support


parser = argparse.ArgumentParser()
parser.add_argument("--file_loc", type=str, default="../res/AE/predict.txt")
parser.add_argument(
    "--specify",
    type=str,
    default=None
    # e.g. json.dumps([1,1,1])
)

args = parser.parse_args()

def compute_metrics(pred):
    labels = torch.tensor(pred.label_ids).long()
    preds = torch.softmax(torch.tensor(pred.predictions), dim=-1)
    probs = torch.gather(preds, 1, labels.view(-1, 1))
    acc = torch.mean(probs).item()
    #print(labels)
    #print(preds)
    #print(probs)
    #print(acc)
    #raise Exception('test')
    #preds = pred.predictions.argmax(-1)

    #precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='macro')
    #acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        #'f1': f1,
        #'precision': precision,
        #'recall': recall
    }

dataset = {'label':[], 'sent':[]}
with open(args.file_loc, 'r') as f:
    for line in f.readlines():
        label, sent = json.loads(line.strip())
        dataset['label'].append(label)
        dataset['sent'].append(sent.strip())


dataset = Dataset.from_dict(dataset)

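
# Each line of --file_loc is expected to hold one JSON pair written by the
# generation scripts. An illustrative (made-up) example:
#   [[1, 2, 1], "The movie turned out to be a pleasant surprise ..."]
# The label triple is [sentiment, topic, detoxification]; --specify restricts
# evaluation to one such triple, and a -1 entry skips that aspect's classifier.
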
tokenizer = DebertaV2Tokenizer.from_pretrained('microsoft/deberta-v3-large') 56 | 57 | 58 | model_list = ['model/Yelp2-checkpoint-64000', 'model/AGnews-checkpoint-6000', 'model/Toxic-checkpoint-3000'] 59 | topics = ["world","sports","business","science"] 60 | task_list = ['sentiment', 'topic', 'detoxification'] 61 | 62 | test_args = TrainingArguments( 63 | output_dir='logs', 64 | do_train = False, 65 | do_predict = True, 66 | no_cuda = False, 67 | per_device_eval_batch_size=64, 68 | dataloader_drop_last = False, 69 | report_to='none' 70 | ) 71 | 72 | train_out = {} 73 | 74 | if args.specify is not None: 75 | specify = json.loads(args.specify) 76 | dataset = dataset.filter(lambda e: e['label'] == specify) 77 | 78 | for i in range(3): 79 | if args.specify is not None: 80 | if specify[i] == -1: 81 | continue 82 | model = DebertaV2ForSequenceClassification.from_pretrained(model_list[i], num_labels=2) 83 | eval_dataset = None 84 | if 'AGnews' in model_list[i]: 85 | eval_dataset = dataset.map(lambda e: tokenizer(topics[e['label'][i]]+'[SEP]'+e['sent'], truncation=True, padding='max_length', max_length=100)) 86 | eval_dataset = eval_dataset.map(lambda e: {'labels': 1}) 87 | else: 88 | eval_dataset = dataset.map(lambda e: tokenizer(e['sent'], truncation=True, padding='max_length', max_length=100), batched=True) 89 | eval_dataset = eval_dataset.map(lambda e: {'labels': e['label'][i]}) 90 | eval_dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels']) 91 | 92 | trainer = Trainer( 93 | model=model, 94 | args = test_args, 95 | compute_metrics=compute_metrics, 96 | ) 97 | train_out[task_list[i]] = trainer.evaluate(eval_dataset)['eval_accuracy'] 98 | 99 | 100 | 101 | print(train_out) -------------------------------------------------------------------------------- /classify/indep_eval.sh: -------------------------------------------------------------------------------- 1 | if [ $# == 2 ]; then 2 | file_loc=$1 3 | output_loc=$2 4 | else 5 | file_loc=../res/AE/predict.txt 6 | output_loc=logs/indep_AE.txt 7 | fi 8 | 9 | for i in 0 1 10 | do 11 | for j in 0 1 2 3 12 | do 13 | echo [$i,$j,1] >> $output_loc 14 | python eval.py --file_loc $file_loc --specify [$i,$j,1] >> $output_loc 15 | done; 16 | done; 17 | echo average >> $output_loc 18 | python eval.py --file_loc $file_loc >> $output_loc -------------------------------------------------------------------------------- /classify/logs/AE_final.txt: -------------------------------------------------------------------------------- 1 | [0,0,1] 2 | {'sentiment': 0.6965930461883545, 'topic': 0.7166248559951782, 'detoxification': 0.941453754901886} 3 | [0,1,1] 4 | {'sentiment': 0.7853208780288696, 'topic': 0.8000810742378235, 'detoxification': 0.839763343334198} 5 | [0,2,1] 6 | {'sentiment': 0.9996450543403625, 'topic': 0.9672621488571167, 'detoxification': 0.9942843914031982} 7 | [0,3,1] 8 | {'sentiment': 0.9282599687576294, 'topic': 0.9800723791122437, 'detoxification': 0.9244822263717651} 9 | [1,0,1] 10 | {'sentiment': 0.8048356175422668, 'topic': 0.5795329809188843, 'detoxification': 0.986760675907135} 11 | [1,1,1] 12 | {'sentiment': 0.8474857211112976, 'topic': 0.8657174110412598, 'detoxification': 0.9999216794967651} 13 | [1,2,1] 14 | {'sentiment': 0.8755072951316833, 'topic': 0.9169918298721313, 'detoxification': 0.9999983310699463} 15 | [1,3,1] 16 | {'sentiment': 0.996782660484314, 'topic': 0.9614540338516235, 'detoxification': 0.9961794018745422} 17 | average 18 | {'sentiment': 0.866803765296936, 'topic': 
0.8484671711921692, 'detoxification': 0.9603554606437683} -------------------------------------------------------------------------------- /multicontrol/generate_config_final.json: -------------------------------------------------------------------------------- 1 | { 2 | "weight":{ 3 | "default":[2,7,1], 4 | "01":[2,4,1], 5 | "02":[2,8,1], 6 | "03":[3,1,3], 7 | "10":[2,12,1], 8 | "11":[3,5.5,1], 9 | "12":[2,9,1], 10 | "13":[3,1,1] 11 | }, 12 | 13 | "num_output_centers":[ 14 | [1,1,5,1], 15 | [10,10,5,1] 16 | ] 17 | } -------------------------------------------------------------------------------- /multicontrol/generate_multi.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import transformers 4 | from transformers import GPT2LMHeadModel, BertModel, GPT2Tokenizer, BertTokenizer 5 | import datasets 6 | from datasets import load_dataset, load_metric, concatenate_datasets, Dataset 7 | from tqdm import tqdm 8 | import json 9 | from sklearn.cluster import KMeans 10 | import random 11 | import numpy as np 12 | 13 | from generation_utils import KCenters 14 | from model import AE 15 | 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--pretrained_encoder", type=str, default="bert-base-uncased") 19 | parser.add_argument("--pretrained_decoder", type=str, default="gpt2-medium") 20 | parser.add_argument("--no_cuda", action="store_true") 21 | parser.add_argument("--latent_size", type=int, default=768) 22 | parser.add_argument("--latent_num",type=int, default=1) 23 | parser.add_argument("--seq_len_per_latent",type=int, default=20) 24 | parser.add_argument("--model_path", type=str, default='../model/multicontrol/checkpoint-30000/pytorch_model.bin') 25 | parser.add_argument("--output_dir", type=str, default="../res/multicontrol/predict_final.txt") 26 | parser.add_argument("--batch_size", type=int, default=60) 27 | parser.add_argument("--pre_tokens", 28 | type=str, 29 | default=json.dumps( 30 | ['In summary','This essay discusses','Views on','The connection','Foundational to this is', 31 | 'To review,','In brief,','An illustration of','Furthermore,','The central theme', 32 | 'To conclude,','The key aspect','Prior to this','Emphasised are','To summarise', 33 | 'The relationship','More importantly,','It has been shown','The issue focused on','In this essay', 34 | 'Once upon a time','The book','The chicken','The city','The country', 35 | 'The horse','The lake','The last time','The movie','The painting', 36 | 'The pizza','The potato','The president of the country','The road','The year is 1910'] 37 | ) 38 | ) 39 | parser.add_argument("--max_length", type=int, default=100) 40 | parser.add_argument("--seed", type=int, default=0) 41 | 42 | parser.add_argument("--variation", type=float, default=1e-3) 43 | 44 | #Parameters for KCenters 45 | parser.add_argument("--num_centers", type=int, default=1000) 46 | parser.add_argument("--num_output_centers", type=int, default=10) 47 | parser.add_argument("--topk", type=int, default=200) 48 | parser.add_argument("--batch", type=int, default=5) 49 | parser.add_argument("--max_iter", type=int, default=15) 50 | parser.add_argument("--strategy", type=str, default='none', choices=('none', 'weight')) 51 | parser.add_argument("--temperature", type=float, default=50) 52 | parser.add_argument("--SDM_reinit", type=bool, default=True) 53 | parser.add_argument("--weight", 54 | type=str, 55 | default=json.dumps( 56 | [1,5,1] 57 | ) 58 | ) 59 | parser.add_argument("--config", type=str, 
default=None)

args = parser.parse_args()

weight = json.loads(args.weight)

if args.config is not None:
    with open(args.config, 'r') as f:
        config = json.loads(f.read())
    for keys in config:
        if keys == 'weight':
            weight = config['weight']
        if keys == 'num_output_centers':
            args.num_output_centers = config['num_output_centers']


if isinstance(weight, dict):
    default_weight = weight['default']
    weight_dict = [[default_weight for jt in range(4)] for it in range(2)]
    for keys in weight:
        if keys != 'default':
            tmp_i = int(keys[0])
            tmp_j = int(keys[1])
            weight_dict[tmp_i][tmp_j] = weight[keys]
else:
    weight_dict = [[weight for jt in range(4)] for it in range(2)]


if isinstance(args.num_output_centers, int):
    args.num_output_centers = [[args.num_output_centers]*4]*2


encoder_tokenizer = BertTokenizer.from_pretrained(args.pretrained_encoder)
encoder = BertModel.from_pretrained(args.pretrained_encoder)
decoder_tokenizer = GPT2Tokenizer.from_pretrained(args.pretrained_decoder)
decoder = GPT2LMHeadModel.from_pretrained(args.pretrained_decoder)
decoder_tokenizer.pad_token = decoder_tokenizer.eos_token

model = AE(encoder=encoder, decoder=decoder, args=args)
model.load_state_dict(torch.load(args.model_path), strict=False)
model.eval()

random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)


if args.no_cuda:
    device='cpu'
else:
    device='cuda'

model.to(device)

imdb_dataset = [{'sent':[]} for i in range(2)]
ag_dataset = [{'sent':[]} for i in range(4)]
toxic_dataset = [{'sent':[]} for i in range(2)]

with open('../data/IMDb/IMDb.txt', 'r') as f:
    for line in f.readlines():
        line = json.loads(line)
        label = int(line[0])
        imdb_dataset[label]['sent'].append(line[1].strip())

with open('../data/ToxicComment/Toxic.txt', 'r') as f:
    for line in f.readlines():
        line = json.loads(line)
        label = int(line[0])
        toxic_dataset[label]['sent'].append(line[1].strip())

with open('../data/AGnews/AG-data.txt', 'r') as f:
    for line in f.readlines():
        line = json.loads(line)
        label = int(line[0])
        ag_dataset[label]['sent'].append(line[1].strip())


imdb_dataset = [Dataset.from_dict(i) for i in imdb_dataset]
ag_dataset = [Dataset.from_dict(i) for i in ag_dataset]
toxic_dataset = [Dataset.from_dict(i) for i in toxic_dataset]



imdb_dataloader = []
for dataset in imdb_dataset:
    tmp_dataset = dataset.map(lambda e: encoder_tokenizer(e['sent'], max_length=args.max_length, padding='max_length', truncation=True), batched=True)
    tmp_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'token_type_ids'])
    imdb_dataloader.append(torch.utils.data.DataLoader(tmp_dataset, batch_size=32))

ag_dataloader = []
for dataset in ag_dataset:
    tmp_dataset = dataset.map(lambda e: encoder_tokenizer(e['sent'], max_length=args.max_length, padding='max_length', truncation=True), batched=True)
    tmp_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'token_type_ids'])
    ag_dataloader.append(torch.utils.data.DataLoader(tmp_dataset, 
batch_size=32)) 158 | 159 | toxic_dataloader = [] 160 | for dataset in toxic_dataset: 161 | tmp_dataset = dataset.map(lambda e: encoder_tokenizer(e['sent'], max_length=args.max_length, padding='max_length', truncation=True), batched=True) 162 | tmp_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'token_type_ids']) 163 | toxic_dataloader.append(torch.utils.data.DataLoader(tmp_dataset, batch_size=32)) 164 | 165 | 166 | 167 | not_latents = None 168 | sentiment_latents = {0:None, 1:None} 169 | topic_latents = {0:None, 1:None, 2:None, 3:None} 170 | 171 | for i in range(2): 172 | for cnt in tqdm(iter(imdb_dataloader[i])): 173 | encoder_input_ids = cnt['input_ids'] 174 | encoder_attention_mask = cnt['attention_mask'] 175 | encoder_token_type_ids = cnt['token_type_ids'] 176 | 177 | latent, encoder_output, past_key_values = model.encode(encoder_input_ids, encoder_attention_mask, encoder_token_type_ids) 178 | if sentiment_latents[i] is None: 179 | sentiment_latents[i] = latent.squeeze().detach() 180 | else: 181 | sentiment_latents[i] = torch.cat((sentiment_latents[i], latent.squeeze().detach()), dim=0) 182 | 183 | for i in range(4): 184 | for cnt in tqdm(iter(ag_dataloader[i])): 185 | encoder_input_ids = cnt['input_ids'] 186 | encoder_attention_mask = cnt['attention_mask'] 187 | encoder_token_type_ids = cnt['token_type_ids'] 188 | 189 | latent, encoder_output, past_key_values = model.encode(encoder_input_ids, encoder_attention_mask, encoder_token_type_ids) 190 | if topic_latents[i] is None: 191 | topic_latents[i] = latent.squeeze().detach() 192 | else: 193 | topic_latents[i] = torch.cat((topic_latents[i], latent.squeeze().detach()), dim=0) 194 | 195 | 196 | for cnt in tqdm(iter(toxic_dataloader[1])): 197 | encoder_input_ids = cnt['input_ids'] 198 | encoder_attention_mask = cnt['attention_mask'] 199 | encoder_token_type_ids = cnt['token_type_ids'] 200 | 201 | latent, encoder_output, past_key_values = model.encode(encoder_input_ids, encoder_attention_mask, encoder_token_type_ids) 202 | if not_latents is None: 203 | not_latents = latent.squeeze().detach() 204 | else: 205 | not_latents = torch.cat((not_latents, latent.squeeze().detach()), dim=0) 206 | 207 | 208 | 209 | kcmodel = KCenters(num_centers=args.num_centers, latent_size=args.latent_size, num_output_centers=args.num_output_centers, device='cuda') 210 | 211 | output_text = [] 212 | labels = [] 213 | 214 | 215 | 216 | for i in range(2): 217 | for j in range(4): 218 | weight = weight_dict[i][j] 219 | num_output_centers = args.num_output_centers[i][j] 220 | print(weight) 221 | print(num_output_centers) 222 | centers = kcmodel.train( 223 | [sentiment_latents[i].to('cuda'), topic_latents[j].to('cuda'), not_latents.to('cuda')], 224 | weight=weight, 225 | topk=args.topk, 226 | SDM_reinit=args.SDM_reinit, 227 | max_iter=args.max_iter, 228 | strategy=args.strategy, 229 | temperature=args.temperature, 230 | num_output_centers=num_output_centers 231 | ).cpu().numpy() 232 | centers = [torch.FloatTensor(k).unsqueeze(0) for k in centers] 233 | 234 | 235 | for prompts in tqdm(json.loads(args.pre_tokens)): 236 | tokens = decoder_tokenizer(prompts, return_tensors='pt') 237 | input_ids = tokens.input_ids 238 | attention_mask = tokens.attention_mask 239 | input_ids = input_ids.expand(args.batch_size, -1) 240 | attention_mask = attention_mask.expand(args.batch_size, -1) 241 | 242 | output = model.generate( 243 | input_latent=random.choice(centers), 244 | input_ids=input_ids, 245 | attention_mask=attention_mask, 246 | 
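
                # `variation` perturbs the chosen center with small Gaussian noise inside model.generate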
                variation=args.variation,
                max_len=50,
                rp=1.2
            )

            output_text.extend(decoder_tokenizer.batch_decode(output.cpu(), skip_special_tokens=True))
            labels.extend([[i,j,1]] * args.batch_size)
            assert len(labels) == len(output_text)



with open(args.output_dir, 'w') as f:
    for i in tqdm(range(len(output_text))):
        f.write(json.dumps([labels[i], output_text[i]])+'\n')

--------------------------------------------------------------------------------
/multicontrol/generate_multi.sh:
--------------------------------------------------------------------------------
python generate_multi.py --config generate_config_final.json --batch_size 5 --topk 200 --strategy none --model_path ../model/multicontrol/checkpoint-30000/pytorch_model.bin --output_dir ../res/multicontrol/predict_final.txt --latent_size 768 --variation 1e-3

cd ..
cd classify

sh indep_eval.sh ../res/multicontrol/predict_final.txt logs/AE_final.txt
cd ..
cd multicontrol

--------------------------------------------------------------------------------
/multicontrol/generation_utils.py:
--------------------------------------------------------------------------------
import torch
from tqdm import tqdm

class KCenters:
    def __init__(
        self,
        num_centers,
        latent_size,
        num_output_centers,
        device,
    ):
        '''
        num_centers:
            Number of clusters for searching.
        '''
        self.num_centers = num_centers
        self.num_output_centers = num_output_centers
        self.device = device
        self.latent_size = latent_size
        self.centers = None #self.init_cluster_center(self.num_centers, self.latent_size).to(device)
        self.score = None


    def init_cluster_center(self, num_centers, latent_size):
        '''
        '''
        if num_centers == 0:
            clusters = None
        else:
            clusters = torch.rand(num_centers, latent_size) * 2 - 1
        return clusters


    def Sparse_Distributed_Memory_Reinitalization(self, X, topk):
        length = len(X)
        self.centers = None
        for i in range(length):
            query_matrix = X[i]
            query_centers = torch.zeros_like(query_matrix).to(self.device)
            for j in range(length):
                if j != i:
                    key_matrix = X[j]
                    query_centers += self.optim(query_matrix, key_matrix, topk, 'none')
            query_centers = (query_centers + query_matrix) / length

            query_score = torch.zeros(query_centers.shape[0]).to(self.device)
            for matrix in X:
                tmp_score = -self.distance(query_centers, matrix)
                tmp_values, tmp_indices = torch.topk(tmp_score, k=topk, dim=-1)
                query_score += torch.mean(tmp_values, dim=-1)
            query_score = query_score/length
            query_values, query_indices = torch.topk(query_score, k=self.num_centers)
            query_centers = torch.index_select(query_centers, 0, query_indices)

            if self.centers is None:
                self.centers = query_centers
            else:
                self.centers = torch.cat([self.centers, query_centers], dim=0)


        scores = torch.zeros(self.centers.shape[0]).to(self.device)
        for matrix in X:
            tmp_score = -self.distance(self.centers, matrix)
            tmp_values, tmp_indices = torch.topk(tmp_score, k=topk, dim=-1)
            scores += torch.mean(tmp_values, dim=-1)
        scores = scores/length
        out_values, out_indices = torch.topk(scores, k=self.num_centers)
        self.centers = torch.index_select(self.centers, 0, out_indices)



    def Kernel_Density_Estimation(self):
        return



    def train(
        self, 
78 | X, 79 | weight=[1,1,1], 80 | topk=50, 81 | max_iter=1, 82 | strategy='none', 83 | SDM_reinit=False, 84 | tol=1e-10, 85 | temperature=50, 86 | num_output_centers=None 87 | ): 88 | ''' 89 | X: [Tensor(batch, latent_size)] 90 | List of FloatTensors from different aspects 91 | example: X[0] from postive sentiment 92 | X[1] from nontoxic 93 | ''' 94 | 95 | assert strategy in {'none', 'weight'} 96 | length = sum(weight) 97 | 98 | if num_output_centers is not None: 99 | self.num_output_centers = num_output_centers 100 | 101 | if SDM_reinit: 102 | self.Sparse_Distributed_Memory_Reinitalization(X, topk) 103 | 104 | if strategy in {'none', 'weight'}: 105 | 106 | for i in tqdm(range(max_iter)): 107 | new_centers = torch.zeros_like(self.centers).to(self.device) 108 | 109 | for j in range(len(X)): 110 | matrix = X[j] 111 | w = weight[j] 112 | new_centers += w * self.optim(self.centers, matrix, topk, strategy, temperature=temperature) 113 | new_centers = new_centers/length 114 | 115 | 116 | self.centers = new_centers 117 | 118 | 119 | 120 | 121 | 122 | new_score = torch.zeros(self.centers.shape[0]).to(self.device) 123 | for matrix in X: 124 | tmp_score = -self.distance(self.centers, matrix) 125 | tmp_values, tmp_indices = torch.topk(tmp_score, k=topk, dim=-1) 126 | new_score += torch.mean(tmp_values, dim=-1) 127 | self.score = new_score/length 128 | 129 | 130 | out_values, out_indices = torch.topk(self.score, k=self.num_output_centers) 131 | return torch.index_select(self.centers, 0, out_indices) 132 | 133 | def optim(self, centers, matrix, topk, strategy, batch=100, temperature=50): 134 | tmp_score = - self.distance(centers, matrix) 135 | tmp_values, tmp_indices = torch.topk(tmp_score, k=topk, dim=-1) 136 | 137 | tot_num = tmp_indices.shape[0] 138 | epoch = tot_num//batch + (1 if tot_num % batch != 0 else 0) 139 | 140 | new_centers = None 141 | for i in range(epoch): 142 | start = i * batch 143 | end = i * batch + batch 144 | if end > tot_num: 145 | end = tot_num 146 | if strategy == 'none': 147 | tmp_centers = torch.mean(torch.gather(matrix.unsqueeze(0).expand(end-start,-1,-1), 1, tmp_indices[start:end].unsqueeze(-1).expand(-1,-1,self.latent_size)),dim=1).squeeze() 148 | elif strategy == 'weight': 149 | #torch.gather -> [batch_size, topk, latent_size] 150 | #[batch_size, latent_size, topk] * [topk, 1] 151 | weight = torch.softmax(-torch.log(-tmp_values[start:end]) * temperature, dim=-1).unsqueeze(-1) 152 | tmp_c = torch.gather(matrix.unsqueeze(0).expand(end-start,-1,-1), 1, tmp_indices[start:end].unsqueeze(-1).expand(-1,-1,self.latent_size)) 153 | tmp_centers = torch.matmul(tmp_c.permute(0,2,1),weight).squeeze() 154 | if new_centers is None: 155 | new_centers = tmp_centers 156 | else: 157 | new_centers = torch.cat([new_centers, tmp_centers], dim=0) 158 | return new_centers 159 | 160 | 161 | 162 | def distance(self, matrix1, matrix2, batch=100): 163 | ''' 164 | Input: 165 | matrix1: FloatTensor(i * m) 166 | matrix2: FloatTensor(j * m) 167 | Output: 168 | distance matrix: FloatTensor(i * j) 169 | ''' 170 | 171 | assert len(matrix1.shape) == 2 172 | assert len(matrix2.shape) == 2 173 | 174 | dis = None 175 | tot_num = matrix1.shape[0] 176 | epoch = tot_num//batch + (1 if tot_num % batch != 0 else 0) 177 | matrix1 = matrix1.unsqueeze(dim=1) 178 | matrix2 = matrix2.unsqueeze(dim=0) 179 | for i in range(epoch): 180 | start = i * batch 181 | end = i * batch + batch 182 | if end > tot_num: 183 | end = tot_num 184 | tmp_matrix1 = matrix1[start:end] 185 | tmp_dis = (tmp_matrix1 - matrix2) ** 2.0 186 | 
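
            # summing the squared differences over the latent dimension gives
            # the squared Euclidean distance between every pair of rows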
            tmp_dis = torch.sum(tmp_dis, dim=-1).squeeze()
            if dis is None:
                dis = tmp_dis
            else:
                dis = torch.cat([dis, tmp_dis], dim=0)


        return dis

--------------------------------------------------------------------------------
/multicontrol/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import wandb
import json



import logging
logger = logging.getLogger(__name__)


class Reshape(nn.Module):
    '''
    past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads,
        sequence_length, embed_size_per_head)`).
    '''
    def __init__(self, arg_dict):
        super(Reshape, self).__init__()
        self.seq_len = arg_dict['seq_len']
        self.num_layer = arg_dict['num_layer']
        self.hidden_size = arg_dict['hidden_size']
        self.num_head = arg_dict['num_head']
    def forward(self, x):
        batch_size = x.shape[0]
        assert self.hidden_size % self.num_head == 0
        embed_size_per_head = self.hidden_size//self.num_head
        x = x.view(batch_size, self.num_layer, 2, self.num_head, self.seq_len, embed_size_per_head).permute(1,2,0,3,4,5)
        past_key_values = []
        # build one (key, value) pair per decoder layer
        for i in range(self.num_layer):
            past_key_values.append((x[i][0], x[i][1],))
        assert past_key_values[0][0].requires_grad == True
        return tuple(past_key_values)



class AE(nn.Module):
    """
    AE with Decoder Fixed.
    We adopt the connection method from prompt tuning.
    """
    #_keys_to_ignore_on_load_missing = [r"latent_classify_head\.\d+\.weight", r"latent_classify_head\.\d+\.bias"]
    def __init__(self, encoder, decoder, args):
        super(AE, self).__init__()
        self.encoder = encoder #BertModel
        self.decoder = decoder #GPT2LMHeadModel

        self.encoder_config = encoder.config
        self.encoder_hidden_size = self.encoder_config.hidden_size

        self.decoder_config = decoder.config
        self.decoder_num_layer = self.decoder_config.n_layer
        self.decoder_hidden_size = self.decoder_config.n_embd
        self.decoder_num_head = self.decoder_config.n_head

        self.losslist = None
        self.latent_classify_head = None



        self.args = args
        self.latent_size = args.latent_size
        self.seq_len_per_latent = args.seq_len_per_latent
        self.latent_num = args.latent_num
        self.seq_len = self.latent_num * self.seq_len_per_latent
        if 'variation' in args:
            self.variation = args.variation
        else:
            self.variation = 0


        ## connector:
        # 1. from Bert hidden units to the latent space
        # 2. convert latent space to `past_key_values' in GPT
        # [batch_size, bert_hidden_size] -> [batch_size, latent_num * latent_size]
        #   -> [batch_size, latent_num, decoder_layer * len([key,value]) * gpt_hidden_size]
        #   -> (num_layer * (len([key,value]) * tensor[batch_size, num_head, seq_len, embed_size_per_head]))
        '''
        self.trans = torch.nn.Sequential(
            torch.nn.Linear(self.encoder_hidden_size, self.latent_num * self.latent_size),
            torch.nn.Tanh(),
            torch.nn.Linear(self.latent_num * self.latent_size, self.seq_len * self.decoder_num_layer * 2 * self.decoder_hidden_size),
            Reshape({'seq_len':self.seq_len, 'num_layer':self.decoder_num_layer, 'hidden_size':self.decoder_hidden_size, 'num_head':self.decoder_num_head})
        )
        '''
        self.trans1 = torch.nn.Sequential(
            torch.nn.Linear(self.encoder_hidden_size, self.latent_num * self.latent_size),
            torch.nn.Tanh(),
            nn.Dropout(self.decoder_config.attn_pdrop) #added
        )
        self.trans2 = torch.nn.Sequential(
            torch.nn.Linear(self.latent_num * self.latent_size, self.seq_len * self.decoder_num_layer * 2 * self.decoder_hidden_size),
            nn.Dropout(self.decoder_config.attn_pdrop), #added
            Reshape({'seq_len':self.seq_len, 'num_layer':self.decoder_num_layer, 'hidden_size':self.decoder_hidden_size, 'num_head':self.decoder_num_head})
        )


    def fix_decoder(self):
        '''
        Fix the decoder so that training works as prefix tuning.
        '''
        self.decoder.eval()
        for param in self.decoder.parameters():
            param.requires_grad = False


    def connect(self, encoder_output, variation=0):
        '''
        Connect encoder to decoder and get the latent representation.
        '''
        tmp_latent = self.trans1(encoder_output)
        eps = torch.zeros_like(tmp_latent).normal_(std=variation).to(tmp_latent.device)
        past_key_values = self.trans2(tmp_latent + eps)
        latent = tmp_latent.view(-1, self.latent_num, self.latent_size)
        return past_key_values, latent


    def sparse_loss(self, latent, dim=None):
        '''
        Increase the sparsity.
        '''
        if len(latent.shape) == 3 and dim is None:
            raise Exception('Expect latent to be dim 2.')
        loss_func = nn.L1Loss(reduction='mean')
        batch_size = latent.shape[0]
        if dim is not None:
            tmp_latent = latent[:,dim,:].squeeze()
            average = torch.sum(tmp_latent, dim=0)/batch_size
            loss = loss_func(tmp_latent, average.expand(batch_size, -1))
        else:
            average = torch.sum(latent, dim=0)/batch_size
            loss = loss_func(latent, average.expand(batch_size, -1))
        return -loss

    def contrasitive_loss(self, latent1, latent2, loss_func=nn.SmoothL1Loss(reduction='mean'), dim=None):
        '''
        Increase the distance between latent1 and latent2.
        loss_func: nn.L1Loss, nn.SmoothL1Loss, nn.MSELoss, ...
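
        Note that the loss is returned negated, so minimizing the total
        training objective drives latent1 and latent2 apart.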
        '''
        if dim is not None:
            loss = loss_func(latent1[:,dim,:].squeeze(), latent2[:,dim,:].squeeze())
        else:
            loss = loss_func(latent1, latent2)
        return -1 * loss

    def latent_classify_loss(self, latent, pos_label, neg_labels, head_index=None):
        if len(latent.shape) == 3:
            latent = latent.view(-1, self.latent_num*self.latent_size)
        if self.latent_classify_head_type == 'single':
            probs = torch.softmax(self.latent_classify_head(latent), dim=-1)
            batch_size, class_num = probs.shape
            loss = 0
            neg_len = neg_labels.shape[-1]

            for i in range(batch_size):
                pos_prob = probs[i, pos_label[i]]
                if pos_prob < 1/self.head_num:
                    loss += torch.log(pos_prob)
                loss += torch.log(1 - probs[i, neg_labels[i]]).sum()

            return -1 * loss / (batch_size * (neg_len+1))
        elif self.latent_classify_head_type == 'multiple':
            if head_index is None:
                print("UserWarning: head_index not set for multiple classifier head, default to 0")
                head_index = 0
            device = latent.device
            logits = self.latent_classify_head[head_index](latent)
            loss = torch.nn.functional.cross_entropy(logits, pos_label.to(device))
            return loss
        else:
            raise Exception('Wrong latent classifier head type.')

    def aspect_gap_loss(self, latent, head_index):
        if len(latent.shape) == 3:
            latent = latent.view(-1, self.latent_num * self.latent_size)

        mean_latent = torch.mean(latent, dim=0)
        loss = None
        for i in range(self.aspect_head_num):
            if i != head_index and self.aspect_gap_head[i] is not None:
                if loss is None:
                    loss = torch.nn.functional.mse_loss(mean_latent, self.aspect_gap_head[i]) * self.aspect_gap_loss_amplification
                else:
                    loss += torch.nn.functional.mse_loss(mean_latent, self.aspect_gap_head[i]) * self.aspect_gap_loss_amplification
        self.set_aspect_gap_head(mean_latent, head_index)
        return loss

    def set_losslist(self,
            losslist:dict,
            latent_classify_args={'head_num':1, 'class_num_per_head':2, 'mid_size':128, 'head_type':'single'},
            aspect_gap_args={'head_num':2, 'amplification':5}
        ):
        '''
        losslist:
            Sample: {'contrasitive_loss': 0.001, 'sparse_loss': 0.001, 'latent_classify_loss':0.1, 'aspect_gap_loss':0.1}
        '''
        self.losslist = losslist
        if 'latent_classify_loss' in losslist:
            self.head_num = 1
            class_num_per_head = 2
            mid_size = 128
            head_type = 'single'
            if latent_classify_args is not None:
                if 'head_num' in latent_classify_args:
                    self.head_num = latent_classify_args['head_num']
                if 'class_num_per_head' in latent_classify_args:
                    class_num_per_head = latent_classify_args['class_num_per_head']
                if 'mid_size' in latent_classify_args:
                    mid_size = latent_classify_args['mid_size']
                if 'head_type' in latent_classify_args:
                    head_type = latent_classify_args['head_type']

            self.set_latent_classify_head(head_num=self.head_num, class_num_per_head=class_num_per_head, mid_size=mid_size, head_type=head_type)

            self.latent_classify_head_type = head_type

        if 'aspect_gap_loss' in losslist:
            if 'latent_classify_loss' in losslist:
                if self.latent_classify_head_type == 'multiple':
                    self.aspect_head_num = self.head_num
                elif self.latent_classify_head_type == 'single':
                    print(f'set aspect head num to {aspect_gap_args["head_num"]}.')
                    self.aspect_head_num = 
aspect_gap_args['head_num'] 225 | else: 226 | print('set aspect head num to {aspect_head_num}.') 227 | self.aspect_head_num = aspect_gap_args['head_num'] 228 | 229 | self.aspect_gap_loss_amplification = aspect_gap_args['amplification'] 230 | 231 | self.aspect_gap_head = [None for i in range(self.aspect_head_num)] 232 | 233 | def set_latent_classify_head(self, head_num=1, class_num_per_head=2, mid_size=128, head_type='single'): 234 | if head_type == 'single': 235 | self.latent_classify_head = nn.Sequential( 236 | nn.Linear(self.latent_num * self.latent_size, mid_size), 237 | nn.ReLU(), 238 | nn.Linear(mid_size, class_num_per_head * head_num) 239 | ) 240 | elif head_type == 'multiple': 241 | if type(class_num_per_head) is list: 242 | self.latent_classify_head = nn.ModuleList([ 243 | nn.Sequential( 244 | nn.Linear(self.latent_num * self.latent_size, mid_size), 245 | nn.ReLU(), 246 | nn.Linear(mid_size, head_num) 247 | ) for head_num in class_num_per_head] 248 | ) 249 | else: 250 | self.latent_classify_head = nn.ModuleList([ 251 | nn.Sequential( 252 | nn.Linear(self.latent_num * self.latent_size, mid_size), 253 | nn.ReLU(), 254 | nn.Linear(mid_size, class_num_per_head) 255 | ) for i in range(head_num)] 256 | ) 257 | 258 | def set_aspect_gap_head(self, latent, head_index): 259 | if len(latent.shape) == 3: 260 | latent = latent.view(-1, self.latent_num * self.latent_size) 261 | 262 | if len(latent.shape) == 2: 263 | mean_latent = torch.mean(latent.detach(), dim=0) 264 | if self.aspect_gap_head[head_index] is not None: 265 | assert self.aspect_gap_head[head_index].shape == mean_latent.shape 266 | self.aspect_gap_head[head_index] = mean_latent 267 | elif len(latent.shape) == 1: 268 | if self.aspect_gap_head[head_index] is not None: 269 | assert self.aspect_gap_head[head_index].shape == latent.shape 270 | self.aspect_gap_head[head_index] = latent.detach() 271 | 272 | 273 | 274 | 275 | def forward(self, 276 | encoder_input_ids, 277 | encoder_attention_mask, 278 | encoder_token_type_ids, 279 | decoder_input_ids, 280 | decoder_attention_mask, 281 | adv_input_ids=None, 282 | adv_attention_mask=None, 283 | adv_token_type_ids=None, 284 | pos_label=None, 285 | neg_labels=None, 286 | variation=None, 287 | head_index=None 288 | ): 289 | ''' 290 | Forward method for training which returns a reconstruction loss. 
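
        The reconstruction (language modeling) loss is combined into a weighted
        sum with whichever auxiliary losses were configured via `set_losslist`.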
291 | Args: 292 | encoder_input_ids, 293 | encoder_atention_mask, 294 | encoder_token_type_ids: 295 | Outputs of BertTokenizer(List of Strings, return_tensors='pt', padding=True) 296 | decoder_input_ids, 297 | decoder_attention_mask: 298 | Outputs of GPT2Tokenizer(List of Strings, return_tensors='pt', padding=True) 299 | adv_input_ids, 300 | adv_attention_mask, 301 | adv_token_type_ids: 302 | Adversarial text 303 | 304 | ''' 305 | if len(encoder_input_ids.shape) == 3: 306 | encoder_input_ids = encoder_input_ids.view(encoder_input_ids.shape[1], encoder_input_ids.shape[2]) 307 | encoder_attention_mask = encoder_attention_mask.view(encoder_attention_mask.shape[1], encoder_attention_mask.shape[2]) 308 | encoder_token_type_ids = encoder_token_type_ids.view(encoder_token_type_ids.shape[1], encoder_token_type_ids.shape[2]) 309 | decoder_input_ids = decoder_input_ids.view(decoder_input_ids.shape[1], decoder_input_ids.shape[2]) 310 | decoder_attention_mask = decoder_attention_mask.view(decoder_attention_mask.shape[1], decoder_attention_mask.shape[2]) 311 | if head_index is not None: 312 | head_index = head_index.item() 313 | if adv_input_ids is not None: 314 | adv_input_ids = adv_input_ids.view(adv_input_ids.shape[1], adv_input_ids.shape[2]) 315 | if adv_attention_mask is not None: 316 | adv_attention_mask = adv_attention_mask.view(adv_attention_mask.shape[1], adv_attention_mask.shape[2]) 317 | if adv_token_type_ids is not None: 318 | adv_token_type_ids = adv_token_type_ids.view(adv_token_type_ids.shape[1], adv_token_type_ids.shape[2]) 319 | if pos_label is not None: 320 | pos_label = pos_label.view(pos_label.shape[1]) 321 | if neg_labels is not None: 322 | neg_labels = neg_labels.view(neg_labels.shape[1], neg_labels.shape[2]) 323 | 324 | if variation is None: 325 | variation = self.variation 326 | batch_size = decoder_input_ids.shape[0] 327 | infix_attn = torch.ones(batch_size, self.seq_len).bool().to(decoder_input_ids.device) 328 | decoder_attention_mask = torch.cat([infix_attn, decoder_attention_mask], dim=1) 329 | 330 | encoder_output = self.encoder(input_ids=encoder_input_ids, attention_mask=encoder_attention_mask, token_type_ids=encoder_token_type_ids, return_dict=True).pooler_output 331 | past_key_values, latent = self.connect(encoder_output, variation) 332 | outputs = self.decoder(input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, labels=decoder_input_ids, past_key_values=past_key_values, return_dict=True) 333 | lm_loss = outputs.loss 334 | # lm_logits = outputs.logits 335 | 336 | #labels = decoder_input_ids 337 | # Shift so that tokens < n predict n 338 | #shift_logits = lm_logits[..., self.seq_len:-1, :].contiguous()#########lm_logits[..., :-1, :] -> lm_logits[..., self.seq_len:-1, :] 339 | #shift_labels = labels[..., 1:].contiguous() 340 | # Flatten the tokens 341 | #loss_fct = torch.nn.CrossEntropyLoss() 342 | #lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 343 | 344 | #logits = outputs.lm_logits 345 | loss = 0 346 | loss_detail = {"lm_loss":lm_loss.detach()} 347 | w = 1 348 | if self.losslist is not None: 349 | if 'contrasitive_loss' in self.losslist: 350 | if adv_input_ids is None: 351 | raise Exception('Expect adversarial inputs for contrasitive loss.') 352 | adv_encoder_output = self.encoder(input_ids=adv_input_ids, attention_mask=adv_attention_mask, token_type_ids=adv_token_type_ids, return_dict=True).pooler_output 353 | adv_latent = self.trans1(adv_encoder_output) 354 | adv_loss = self.contrasitive_loss(latent, adv_latent) 
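
            # adv_latent encodes the adversarial text; the (negated) contrastive
            # term below pushes it away from the latent of the original input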
            #TO DO: change arg `dim' in the future
            loss += adv_loss * self.losslist['contrasitive_loss']
            w -= self.losslist['contrasitive_loss']
            loss_detail["contrasitive_loss"] = adv_loss.detach()

        if 'sparse_loss' in self.losslist:
            spa_loss = self.sparse_loss(latent)
            #TO DO: change arg `dim' in the future
            loss += spa_loss * self.losslist['sparse_loss']
            w -= self.losslist['sparse_loss']
            loss_detail["sparse_loss"] = spa_loss.detach()

        if 'latent_classify_loss' in self.losslist:
            lac_loss = self.latent_classify_loss(latent, pos_label, neg_labels, head_index)
            if lac_loss.detach().item() < 0.1:
                loss += lac_loss * 0.05
                w -= 0.05
            else:
                loss += lac_loss * self.losslist['latent_classify_loss']
                w -= self.losslist['latent_classify_loss']
            loss_detail["latent_classify_loss"] = lac_loss.detach()


        if 'aspect_gap_loss' in self.losslist:
            agp_loss = self.aspect_gap_loss(latent, head_index)
            if agp_loss is not None:
                loss += agp_loss * self.losslist['aspect_gap_loss']
                w -= self.losslist['aspect_gap_loss']
                loss_detail["aspect_gap_loss"] = agp_loss.detach()

        if wandb.run is not None: # log only when wandb.init has been called (--wandb)
            wandb.log(loss_detail)
        if w < 0:
            w = 1

        loss += w * lm_loss

        return loss, latent, loss_detail


    def encode(self,
            encoder_input_ids,
            encoder_attention_mask=None,
            encoder_token_type_ids=None,
        ):
        '''
        Encode the input text and get the latent representation.
        '''
        device = next(self.parameters()).device
        encoder_input_ids = encoder_input_ids.to(device)
        if encoder_attention_mask is not None:
            encoder_attention_mask = encoder_attention_mask.to(device)
        if encoder_token_type_ids is not None:
            encoder_token_type_ids = encoder_token_type_ids.to(device)
        encoder_output = self.encoder(input_ids=encoder_input_ids, attention_mask=encoder_attention_mask, token_type_ids=encoder_token_type_ids, return_dict=True).pooler_output
        past_key_values, latent = self.connect(encoder_output)
        return latent, encoder_output, past_key_values


    def generate(
        self,
        input_latent,
        input_ids=None,
        attention_mask=None,
        batch_size=None,
        variation=None,
        min_len=30,
        max_len=50,
        do_sample=True,
        topk=5,
        topp=0.9,
        lp=1,
        rp=1.0,
        use_cache=True):
        '''
        Generate text with the given latent representation.
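
        If `variation` is set, Gaussian noise with std `variation` is added to
        `input_latent` before it is mapped to the decoder prefix (`past_key_values`).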
        '''
        device = next(self.parameters()).device
        input_latent = input_latent.to(device)


        if len(input_latent.shape) == 3:
            tmp_batch_size, latent_num, latent_size = input_latent.shape
            input_latent = input_latent.view(tmp_batch_size, latent_num * latent_size)
        elif len(input_latent.shape) != 2:
            raise Exception('Shape of input_latent is expected to be [batch_size, latent_num, latent_size] \
                or [batch_size, latent_num * latent_size]')


        if batch_size is None:
            batch_size = input_latent.shape[0]
        if input_ids is not None:
            if input_ids.shape[0] > batch_size:
                if batch_size == 1:
                    batch_size = input_ids.shape[0]
                    input_latent = input_latent.expand(batch_size, -1)
                else:
                    raise Exception('Batch size of input_latent and input_ids mismatched')
            elif input_ids.shape[0] < batch_size and input_ids.shape[0] == 1:
                input_ids = input_ids.expand(batch_size, -1)



        if input_latent.shape[0] < batch_size:
            input_latent = input_latent.expand(batch_size, -1)


        if variation is not None:
            eps = torch.zeros_like(input_latent).normal_(std=variation).to(input_latent.device)
            input_latent = input_latent + eps


        past_key_values = self.trans2(input_latent)

        if input_ids is None:
            input_ids = self.decoder.generate(input_ids=torch.LongTensor([[50256]]*batch_size).to(device), max_length=3, do_sample=True)[:,1:]
            attention_mask = torch.ones(batch_size, 2).bool()
        else:
            input_ids = input_ids.to(device)
            if attention_mask is None:
                attention_mask = torch.ones_like(input_ids).bool()

        cur_len = input_ids.shape[1]
        infix_attn = torch.ones(batch_size, self.seq_len).bool().to(device)
        attention_mask = torch.cat([infix_attn, attention_mask.to(device)], dim=-1)

        if cur_len < 1:
            raise Exception('input length error')
        if cur_len == 1:
            result = self.decoder.generate(input_ids=input_ids, past=past_key_values, attention_mask=attention_mask, repetition_penalty=rp,\
                do_sample=do_sample, top_k=topk, top_p=topp, length_penalty=lp, max_length=max_len, min_length=min_len, use_cache=use_cache)
        else:
            past_key_values = self.decoder(input_ids=input_ids[:,:-1], attention_mask=attention_mask[:,:-1], past_key_values=past_key_values, return_dict=True, use_cache=True).past_key_values
            result = self.decoder.generate(input_ids=input_ids, past=past_key_values, attention_mask=attention_mask, repetition_penalty=rp,\
                do_sample=do_sample, top_k=topk, top_p=topp, length_penalty=lp, max_length=max_len, min_length=min_len, use_cache=use_cache)

        return result


    def reconstruct(self,
            encoder_input_ids,
            decoder_input_ids=None,
            encoder_attention_mask=None,
            encoder_token_type_ids=None,
            decoder_attention_mask=None,
            do_sample=True,
            max_len=50,
            min_len=30,
            topk=5,
            topp=0.9,
            lp=1.0,
            use_cache=True):
        '''
        Reconstruct the input text.
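
        The text is encoded into a latent vector, mapped to a decoder prefix
        (`past_key_values`), and decoded again, e.g. to inspect reconstruction quality.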
        '''
        device = next(self.parameters()).device
        batch_size = encoder_input_ids.shape[0]
        encoder_input_ids = encoder_input_ids.to(device)
        if encoder_attention_mask is not None:
            encoder_attention_mask = encoder_attention_mask.to(device)
        if encoder_token_type_ids is not None:
            encoder_token_type_ids = encoder_token_type_ids.to(device)
        encoder_output = self.encoder(input_ids=encoder_input_ids, attention_mask=encoder_attention_mask, token_type_ids=encoder_token_type_ids, return_dict=True).pooler_output

        past_key_values, latent = self.connect(encoder_output)
        if decoder_input_ids is None:
            decoder_input_ids = self.decoder.generate(input_ids=torch.LongTensor([[50256]]*batch_size).to(device), max_length=3, do_sample=True)[:,1:]
            decoder_attention_mask = torch.ones(batch_size, 2).bool()
        else:
            decoder_input_ids = decoder_input_ids.to(device)
            if decoder_attention_mask is None:
                decoder_attention_mask = torch.ones_like(decoder_input_ids).bool()

        cur_len = decoder_input_ids.shape[1]

        infix_attn = torch.ones(batch_size, self.seq_len).bool().to(device)
        decoder_attention_mask = torch.cat([infix_attn, decoder_attention_mask.to(device)], dim=-1)

        if cur_len < 1:
            raise Exception('input length error')
        if cur_len == 1:
            result = self.decoder.generate(input_ids=decoder_input_ids, past=past_key_values, attention_mask=decoder_attention_mask,\
                do_sample=do_sample, top_k=topk, top_p=topp, length_penalty=lp, max_length=max_len, min_length=min_len, use_cache=use_cache)
        else:
            past_key_values = self.decoder(input_ids=decoder_input_ids[:,:-1], attention_mask=decoder_attention_mask[:,:-1], past_key_values=past_key_values, return_dict=True, use_cache=True).past_key_values
            result = self.decoder.generate(input_ids=decoder_input_ids, past=past_key_values, attention_mask=decoder_attention_mask,\
                do_sample=do_sample, top_k=topk, top_p=topp, length_penalty=lp, max_length=max_len, min_length=min_len, use_cache=use_cache)
        return result

--------------------------------------------------------------------------------
/multicontrol/requirements.txt:
--------------------------------------------------------------------------------
python=3.7.12
apex==0.1
datasets==2.0.0
huggingface-hub==0.4.0
nltk==3.7
numpy==1.21.5
scikit-learn==1.0.2
scipy==1.7.3
sentencepiece==0.1.96
tokenizers==0.11.6
torch==1.10.0
tqdm==4.63.1
transformers==4.17.0
wandb==0.12.11

--------------------------------------------------------------------------------
/multicontrol/train_multi.py:
--------------------------------------------------------------------------------
import argparse
import torch
import transformers
from transformers import GPT2LMHeadModel, BertModel, GPT2Tokenizer, BertTokenizer
import datasets
from datasets import load_dataset, load_metric, concatenate_datasets, Dataset
from transformers import Trainer, TrainingArguments
from tqdm import tqdm
import json
import wandb

import random

from model import AE


parser = argparse.ArgumentParser()
parser.add_argument("--pretrained_encoder", type=str, default="bert-base-uncased")
parser.add_argument("--pretrained_decoder", type=str, default="gpt2-medium")
parser.add_argument('--model_dir', default='../model/multicontrol/')
parser.add_argument("--no_cuda", 
action="store_true") 22 | parser.add_argument("--latent_size", type=int, default=768) 23 | parser.add_argument("--latent_num",type=int, default=1) 24 | parser.add_argument("--seq_len_per_latent",type=int, default=20) 25 | parser.add_argument("--batch_size", type=int, default=128) 26 | parser.add_argument("--epoch",type=int, default=200) 27 | parser.add_argument("--lr",type=float, default=1e-4) 28 | parser.add_argument("--fp16", action="store_true") 29 | parser.add_argument("--wandb", action="store_true") 30 | parser.add_argument("--no_fix", action="store_true") 31 | parser.add_argument("--max_length", type=int, default=100) 32 | parser.add_argument("--contrasitive_loss", type=float, default=None) 33 | parser.add_argument("--sparse_loss", type=float, default=None) 34 | parser.add_argument("--latent_classify_loss", type=float, default=None) 35 | parser.add_argument("--aspect_gap_loss", type=float, default=None) 36 | parser.add_argument("--variation", type=float, default=0) 37 | 38 | parser.add_argument("--classifier_head_num", type=int, default=3) 39 | parser.add_argument("--classifier_class_num_per_head", type=str, default='[2,2,4]') 40 | parser.add_argument("--classifier_mid_size", type=int, default=128) 41 | parser.add_argument("--classifier_head_type", type=str, default='multiple', choices=('single', 'multiple')) 42 | 43 | parser.add_argument("--aspect_gap_head_num", type=int, default=3) 44 | parser.add_argument("--aspect_gap_amplification", type=int, default=10) 45 | 46 | args = parser.parse_args() 47 | 48 | if args.wandb: 49 | wandb.login() 50 | wandb.init(project="", entity="")#your account 51 | 52 | encoder_tokenizer = BertTokenizer.from_pretrained(args.pretrained_encoder) 53 | encoder = BertModel.from_pretrained(args.pretrained_encoder) 54 | decoder_tokenizer = GPT2Tokenizer.from_pretrained(args.pretrained_decoder) 55 | decoder = GPT2LMHeadModel.from_pretrained(args.pretrained_decoder) 56 | decoder_tokenizer.pad_token = decoder_tokenizer.eos_token 57 | 58 | model = AE(encoder=encoder, decoder=decoder, args=args) 59 | 60 | loss_list = {} 61 | latent_classify_args = None 62 | aspect_gap_args = None 63 | if args.contrasitive_loss is not None: 64 | loss_list['contrasitive_loss'] = args.contrasitive_loss 65 | if args.sparse_loss is not None: 66 | loss_list['sparse_loss'] = args.sparse_loss 67 | if args.latent_classify_loss is not None: 68 | loss_list['latent_classify_loss'] = args.latent_classify_loss 69 | latent_classify_args = { 70 | 'head_num':args.classifier_head_num, 71 | 'class_num_per_head':[2,2,4],#json.loads(args.classifier_class_num_per_head), 72 | 'mid_size':args.classifier_mid_size, 73 | 'head_type':args.classifier_head_type 74 | } 75 | if args.aspect_gap_loss is not None: 76 | loss_list['aspect_gap_loss'] = args.aspect_gap_loss 77 | aspect_gap_args = { 78 | 'head_num':args.aspect_gap_head_num, 79 | 'amplification':args.aspect_gap_amplification 80 | } 81 | 82 | 83 | 84 | if len(loss_list) == 0: 85 | loss_list = None 86 | 87 | model.set_losslist(loss_list, latent_classify_args, aspect_gap_args) 88 | 89 | if not args.no_fix: 90 | model.fix_decoder() 91 | 92 | 93 | if args.classifier_head_type == 'multiple': 94 | dataset = [{'sent':[], 'type':[]} for i in range(3)] 95 | if 'contrasitive_loss' in loss_list: 96 | for i in range(len(dataset)): 97 | dataset[i]['adv_sent'] = [] 98 | 99 | with open('../data/IMDb/IMDb.txt', 'r') as f: 100 | for line in f.readlines(): 101 | line = json.loads(line) 102 | dataset[0]['sent'].append(line[1].strip()) 103 | 
dataset[0]['type'].append(int(line[0])) 104 | 105 | with open('../data/ToxicComment/Toxic.txt', 'r') as f: 106 | for line in f.readlines(): 107 | line = json.loads(line) 108 | dataset[1]['sent'].append(line[1].strip()) 109 | dataset[1]['type'].append(int(line[0])) 110 | 111 | with open('../data/AGnews/AG-data.txt', 'r') as f: 112 | for line in f.readlines(): 113 | line = json.loads(line) 114 | dataset[2]['sent'].append(line[1].strip()) 115 | dataset[2]['type'].append(int(line[0])) 116 | 117 | 118 | 119 | columns = ['encoder_input_ids', 'encoder_attention_mask', 'encoder_token_type_ids', 'type', 'decoder_input_ids', 'decoder_attention_mask'] 120 | if 'contrasitive_loss' in loss_list: 121 | columns.extend(['adv_input_ids', 'adv_attention_mask', 'adv_token_type_ids']) 122 | if 'latent_classify_loss' in loss_list: 123 | columns.extend(['pos_label', 'neg_labels']) 124 | train_dataset = {i:[] for i in columns} 125 | train_dataset['head_index']=[] 126 | for i in range(3): 127 | tmp_dataset = Dataset.from_dict(dataset[i]) 128 | tmp_dataset = tmp_dataset.map(lambda e: encoder_tokenizer(e['sent'], max_length=args.max_length, padding='max_length', truncation=True), batched=True) 129 | tmp_dataset = tmp_dataset.rename_columns({'input_ids':'encoder_input_ids', 'attention_mask':'encoder_attention_mask', 'token_type_ids':'encoder_token_type_ids'}) 130 | tmp_dataset = tmp_dataset.map(lambda e: decoder_tokenizer(e['sent'], max_length=args.max_length, padding='max_length', truncation=True), batched=True) 131 | tmp_dataset = tmp_dataset.rename_columns({'input_ids':'decoder_input_ids', 'attention_mask':'decoder_attention_mask'}) 132 | if 'contrasitive_loss' in loss_list: 133 | tmp_dataset = tmp_dataset.map(lambda e: encoder_tokenizer(e['adv_sent'], max_length=args.max_length, padding='max_length', truncation=True), batched=True) 134 | tmp_dataset = tmp_dataset.rename_columns({'input_ids':'adv_input_ids', 'attention_mask':'adv_attention_mask', 'token_type_ids':'adv_token_type_ids'}) 135 | if 'latent_classify_loss' in loss_list: 136 | tmp_dataset = tmp_dataset.map(lambda e: {'pos_label': e['type'], 'neg_labels': [1 - e['type']]}) 137 | 138 | tmp_dataset.set_format(type='torch', columns=columns) 139 | 140 | tmp_dataloader = torch.utils.data.DataLoader(tmp_dataset, batch_size=args.batch_size) 141 | for cnt in iter(tmp_dataloader): 142 | for k in columns: 143 | train_dataset[k].append(cnt[k]) 144 | train_dataset['head_index'].append(i) 145 | train_dataset = Dataset.from_dict(train_dataset) 146 | train_dataset.set_format(columns=columns+['head_index']) 147 | 148 | training_args = TrainingArguments( 149 | output_dir=args.model_dir, 150 | learning_rate=args.lr, 151 | num_train_epochs=args.epoch, 152 | #gradient_accumulation_steps=4, 153 | per_device_train_batch_size=1, 154 | logging_dir='./logs', 155 | logging_steps=100, 156 | do_train=True, 157 | do_eval=False, 158 | no_cuda=args.no_cuda, 159 | save_strategy="steps", 160 | save_steps=2000, 161 | fp16=args.fp16, 162 | report_to='wandb' if args.wandb else 'none' 163 | ) 164 | 165 | 166 | 167 | 168 | 169 | elif args.classifier_head_type == 'single': 170 | #need implemention 171 | dataset = {'sent':[], 'type':[], 'neg_types':[]} 172 | if 'contrasitive_loss' in loss_list: 173 | dataset['adv_sent'] = [] 174 | 175 | 176 | with open('../data/SST/pos.txt', 'r') as f: 177 | for line in f.readlines(): 178 | dataset['sent'].append(line.strip()) 179 | dataset['type'].append(1) 180 | dataset['neg_types'].append([0]) 181 | 182 | with open('../data/SST/neg.txt', 'r') as f: 183 | 
for line in f.readlines(): 184 | dataset['sent'].append(line.strip()) 185 | dataset['type'].append(0) 186 | dataset['neg_types'].append([1]) 187 | 188 | 189 | with open('../data/topic/military/s_military.json', 'r') as f: 190 | for line in json.loads(f.read()): 191 | dataset['sent'].append(line.strip().split('<|endoftext|>')[0].strip()) 192 | dataset['type'].append(3) 193 | dataset['neg_types'].append([2]) 194 | 195 | with open('../data/topic/military/s_common.json', 'r') as f: 196 | for line in json.loads(f.read()): 197 | dataset['sent'].append(line.strip().split('<|endoftext|>')[0].strip()) 198 | dataset['type'].append(2) 199 | dataset['neg_types'].append([3]) 200 | 201 | train_dataset = Dataset.from_dict(dataset) 202 | 203 | train_dataset = train_dataset.map(lambda e: encoder_tokenizer(e['sent'], max_length=args.max_length, padding='max_length', truncation=True), batched=True) 204 | train_dataset = train_dataset.rename_columns({'input_ids':'encoder_input_ids', 'attention_mask':'encoder_attention_mask', 'token_type_ids':'encoder_token_type_ids'}) 205 | train_dataset = train_dataset.map(lambda e: decoder_tokenizer(e['sent'], max_length=args.max_length, padding='max_length', truncation=True), batched=True) 206 | train_dataset = train_dataset.rename_columns({'input_ids':'decoder_input_ids', 'attention_mask':'decoder_attention_mask'}) 207 | columns = ['encoder_input_ids', 'encoder_attention_mask', 'encoder_token_type_ids', 'type', 'decoder_input_ids', 'decoder_attention_mask'] 208 | if 'contrasitive_loss' in loss_list: 209 | train_dataset = train_dataset.map(lambda e: encoder_tokenizer(e['adv_sent'], max_length=args.max_length, padding='max_length', truncation=True), batched=True) 210 | train_dataset = train_dataset.rename_columns({'input_ids':'adv_input_ids', 'attention_mask':'adv_attention_mask', 'token_type_ids':'adv_token_type_ids'}) 211 | columns.extend(['adv_input_ids', 'adv_attention_mask', 'adv_token_type_ids']) 212 | 213 | if 'latent_classify_loss' in loss_list: 214 | train_dataset = train_dataset.map(lambda e: {'pos_label': e['type'], 'neg_labels': e['neg_types']}) 215 | columns.extend(['pos_label', 'neg_labels']) 216 | 217 | 218 | train_dataset.set_format(type='torch', columns=columns) 219 | 220 | training_args = TrainingArguments( 221 | output_dir=args.model_dir, 222 | learning_rate=args.lr, 223 | num_train_epochs=args.epoch, 224 | per_device_train_batch_size=args.batch_size, 225 | logging_dir='./logs', 226 | logging_steps=100, 227 | do_train=True, 228 | do_eval=False, 229 | no_cuda=args.no_cuda, 230 | save_strategy="steps", 231 | save_steps=2000, 232 | fp16=args.fp16, 233 | report_to='wandb' if args.wandb else 'none' 234 | ) 235 | 236 | 237 | 238 | trainer = Trainer( 239 | model=model, 240 | args=training_args, 241 | train_dataset=train_dataset 242 | ) 243 | train_out = trainer.train() -------------------------------------------------------------------------------- /multicontrol/train_multi.sh: -------------------------------------------------------------------------------- 1 | python train_multi.py --fp16 --wandb --latent_classify_loss 0.2 --aspect_gap_loss 0.3 --variation 1e-3 --classifier_head_type multiple --lr 1e-4 --epoch 300 -------------------------------------------------------------------------------- /priorcontrol/combine_eval.sh: -------------------------------------------------------------------------------- 1 | if [ $# == 2 ]; then 2 | file_loc=$1 3 | output_loc=$2 4 | else 5 | file_loc=../res/generate_combination.txt 6 | output_loc=./logs/generate_combination.txt 
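# Defaults evaluate the interpolation output; swap in one of the
# commented pairs below to score the optimization variants instead.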
7 | #file_loc=../res/generate_combination_optim.txt 8 | #output_loc=./logs/generate_combination_optim.txt 9 | #file_loc=../res/generate_combination_optimcons.txt 10 | #output_loc=./logs/generate_combination_optimcons.txt 11 | fi 12 | 13 | sh ../classify/indep_eval.sh $file_loc $output_loc 14 | -------------------------------------------------------------------------------- /priorcontrol/generate_combine.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import transformers 4 | from transformers import GPT2LMHeadModel, BertModel, GPT2Tokenizer, BertTokenizer 5 | from tqdm import tqdm 6 | import json 7 | import random 8 | import numpy as np 9 | 10 | from model import AE 11 | 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--pretrained_encoder", type=str, default="bert-base-uncased") 15 | parser.add_argument("--pretrained_decoder", type=str, default="gpt2-medium") 16 | parser.add_argument("--no_cuda", action="store_true") 17 | parser.add_argument("--latent_size", type=int, default=768) 18 | parser.add_argument("--latent_num",type=int, default=1) 19 | parser.add_argument("--seq_len_per_latent",type=int, default=20) 20 | parser.add_argument("--model_path", type=str, default='../model/priorcontrol/All_notricks_checkpoint-300000/pytorch_model.bin') 21 | parser.add_argument("--output_dir", type=str, default="../res/priorcontrol/generate_combination.txt") 22 | parser.add_argument("--batch_size", type=int, default=5) 23 | parser.add_argument("--pre_tokens", 24 | type=str, 25 | default=json.dumps( 26 | ['In summary','This essay discusses','Views on','The connection','Foundational to this is', 27 | 'To review,','In brief,','An illustration of','Furthermore,','The central theme', 28 | 'To conclude,','The key aspect','Prior to this','Emphasised are','To summarise', 29 | 'The relationship','More importantly,','It has been shown','The issue focused on','In this essay', 30 | 'Once upon a time','The book','The chicken','The city','The country', 31 | 'The horse','The lake','The last time','The movie','The painting', 32 | 'The pizza','The potato','The president of the country','The road','The year is 1910'] 33 | ) 34 | ) 35 | parser.add_argument("--max_length", type=int, default=50) 36 | parser.add_argument("--seed", type=int, default=0) 37 | 38 | parser.add_argument("--variation", type=float, default=0) 39 | 40 | #Parameters for Prior 41 | parser.add_argument("--prior", type=bool, default=True) 42 | parser.add_argument("--flow_num", type=int, default=8) 43 | parser.add_argument("--prior_num", type=int, default=8) 44 | 45 | 46 | #Generation 47 | parser.add_argument("--std", type=float, default=1) 48 | 49 | #Attribute Combination 50 | parser.add_argument("--weight", 51 | type=str, 52 | default=json.dumps( 53 | [1,5,1] 54 | ) 55 | ) 56 | parser.add_argument("--config", type=str, default="generate_config_combine.json") 57 | 58 | args = parser.parse_args() 59 | 60 | 61 | weight = json.loads(args.weight) 62 | std = args.std 63 | 64 | if args.config is not None: 65 | with open(args.config, 'r') as f: 66 | config = json.loads(f.read()) 67 | for keys in config: 68 | if keys == 'weight': 69 | weight = config['weight'] 70 | 71 | if keys == 'std': 72 | std = config['std'] 73 | 74 | 75 | if isinstance(weight, dict): 76 | default_weight = weight['default'] 77 | weight_dict = [[default_weight for jt in range(4)]for it in range(2)] 78 | for keys in weight: 79 | if keys != 'default': 80 | tmp_i = int(keys[0]) 81 | tmp_j = 
int(keys[1]) 82 | weight_dict[tmp_i][tmp_j] = weight[keys] 83 | else: 84 | weight_dict = [[weight for jt in range(4)]for it in range(2)] 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | encoder_tokenizer = BertTokenizer.from_pretrained(args.pretrained_encoder) 93 | encoder = BertModel.from_pretrained(args.pretrained_encoder) 94 | decoder_tokenizer = GPT2Tokenizer.from_pretrained(args.pretrained_decoder) 95 | decoder = GPT2LMHeadModel.from_pretrained(args.pretrained_decoder) 96 | decoder_tokenizer.pad_token = decoder_tokenizer.eos_token 97 | 98 | model = AE(encoder=encoder, decoder=decoder, args=args) 99 | 100 | 101 | model.load_state_dict(torch.load(args.model_path), strict=False) 102 | model.eval() 103 | model.fix_decoder() 104 | model.set_mode('prior') 105 | 106 | random.seed(args.seed) 107 | np.random.seed(args.seed) 108 | torch.manual_seed(args.seed) 109 | 110 | 111 | if args.no_cuda: 112 | device='cpu' 113 | else: 114 | device='cuda' 115 | 116 | model.to(device) 117 | 118 | 119 | def calculate(alpha, mean, std, eps): 120 | total = sum(alpha) 121 | alpha = [num/total for num in alpha] 122 | mu = 0 123 | for w, mean in zip(alpha, mean): 124 | mu = mu + w * mean 125 | 126 | sigma = 0 127 | for w, std in zip(alpha, std): 128 | if w < 0: 129 | w = 0 130 | if w > 1: 131 | w = 1 132 | sigma = sigma + (w * std)**2 133 | sigma = torch.sqrt(sigma) 134 | 135 | sampled_dis = sigma * eps + mu 136 | return sampled_dis 137 | 138 | 139 | 140 | 141 | 142 | output_text = [] 143 | labels = [] 144 | for i in range(2): 145 | for j in range(4): 146 | weight = weight_dict[i][j] 147 | 148 | if isinstance(std, list): 149 | tmp_std = std[i*4+j] 150 | else: 151 | tmp_std = std 152 | 153 | 154 | for prompts in tqdm(json.loads(args.pre_tokens)): 155 | tokens = decoder_tokenizer(prompts, return_tensors='pt') 156 | input_ids = tokens.input_ids 157 | attention_mask = tokens.attention_mask 158 | input_ids = input_ids.expand(args.batch_size, -1) 159 | attention_mask = attention_mask.expand(args.batch_size, -1) 160 | 161 | #latents = torch.normal(0,1, (args.batch_size, args.latent_num * args.latent_size)) 162 | latents = torch.zeros(args.batch_size, args.latent_num * args.latent_size) 163 | 164 | output = model.generate( 165 | input_latent=latents, 166 | input_ids=input_ids, 167 | attention_mask=attention_mask, 168 | variation=args.variation, 169 | max_len=50, 170 | rp=1.2, 171 | prior_head_index=[i,j+2,7], 172 | alpha=weight, 173 | calculate=calculate, 174 | std=tmp_std 175 | ) 176 | 177 | output_text.extend(decoder_tokenizer.batch_decode(output.cpu(), skip_special_tokens=True)) 178 | labels.extend([[i,j,1]] * args.batch_size) 179 | assert len(labels) == len(output_text) 180 | 181 | 182 | 183 | with open(args.output_dir, 'w') as f: 184 | for i in tqdm(range(len(output_text))): 185 | f.write(json.dumps([labels[i], output_text[i]])+'\n') 186 | -------------------------------------------------------------------------------- /priorcontrol/generate_combine.sh: -------------------------------------------------------------------------------- 1 | #combine with interpolation 2 | python generate_combine.py --output_dir ../res/priorcontrol/generate_combination.txt 3 | 4 | #combine with optimization without constraints 5 | python generate_combine_optim.py --output_dir ../res/priorcontrol/generate_combination_optim.txt 6 | 7 | #combine with optimization with constraints 8 | python generate_combine_optim.py --is_constrained --output_dir ../res/priorcontrol/generate_combination_optimcons.txt 
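
# Both variants draw z by interpolating the learned attribute priors
# (see calculate() in generate_combine.py):
#   mu = sum_i w_i * mu_i,  sigma = sqrt(sum_i (w_i * sigma_i)^2),  z = mu + sigma * eps
# with the weights w_i normalized to sum to 1 (and clamped to [0,1] for the
# variance term); the optim scripts then refine z with the ODE sampler from
# latentops_modules.py.
#
# Hypothetical example (flags --std, --seed, --output_dir all exist in
# generate_combine.py; the values and output name here are illustrative):
# python generate_combine.py --std 0.5 --seed 0 \
#     --output_dir ../res/priorcontrol/generate_combination_std0.5.txt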
-------------------------------------------------------------------------------- /priorcontrol/generate_combine_optim.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import transformers 4 | from transformers import GPT2LMHeadModel, BertModel, GPT2Tokenizer, BertTokenizer 5 | from tqdm import tqdm 6 | import json 7 | import random 8 | import numpy as np 9 | from functools import partial 10 | import math 11 | 12 | from model import AE 13 | from latentops_modules import DIS, DIScons, sample_q_ode 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--pretrained_encoder", type=str, default="bert-base-uncased") 17 | parser.add_argument("--pretrained_decoder", type=str, default="gpt2-medium") 18 | parser.add_argument("--no_cuda", action="store_true") 19 | parser.add_argument("--latent_size", type=int, default=768) 20 | parser.add_argument("--latent_num",type=int, default=1) 21 | parser.add_argument("--seq_len_per_latent",type=int, default=20) 22 | parser.add_argument("--model_path", type=str, default='../model/priorcontrol/All_notricks_checkpoint-300000/pytorch_model.bin') 23 | parser.add_argument("--output_dir", type=str, default="../res/priorcontrol/generate_combination_optim.txt") 24 | parser.add_argument("--batch_size", type=int, default=5) 25 | parser.add_argument("--pre_tokens", 26 | type=str, 27 | default=json.dumps( 28 | ['In summary','This essay discusses','Views on','The connection','Foundational to this is', 29 | 'To review,','In brief,','An illustration of','Furthermore,','The central theme', 30 | 'To conclude,','The key aspect','Prior to this','Emphasised are','To summarise', 31 | 'The relationship','More importantly,','It has been shown','The issue focused on','In this essay', 32 | 'Once upon a time','The book','The chicken','The city','The country', 33 | 'The horse','The lake','The last time','The movie','The painting', 34 | 'The pizza','The potato','The president of the country','The road','The year is 1910'] 35 | ) 36 | ) 37 | parser.add_argument("--max_length", type=int, default=50) 38 | parser.add_argument("--seed", type=int, default=0) 39 | 40 | parser.add_argument("--variation", type=float, default=0) 41 | 42 | #Parameters for Prior 43 | parser.add_argument("--prior", type=bool, default=True) 44 | parser.add_argument("--flow_num", type=int, default=8) 45 | parser.add_argument("--prior_num", type=int, default=8) 46 | 47 | 48 | #Generation 49 | parser.add_argument("--std", type=float, default=1) 50 | 51 | #Attribute Combination 52 | parser.add_argument("--weight", 53 | type=str, 54 | default=json.dumps( 55 | [1,5,1] 56 | ) 57 | ) 58 | parser.add_argument("--config", type=str, default="generate_config_combine.json") 59 | 60 | parser.add_argument("--optim_weight", 61 | type=str, 62 | default=json.dumps( 63 | [1,1,1] 64 | ) 65 | ) 66 | parser.add_argument("--optim_config", type=str, default="generate_config_combine_optim.json") 67 | 68 | #Optimizaiton with or without constraints 69 | parser.add_argument("--is_constrained", action="store_true") 70 | args = parser.parse_args() 71 | 72 | 73 | weight = json.loads(args.weight) 74 | optim_weight = json.loads(args.optim_weight) 75 | std = args.std 76 | 77 | if args.config is not None: 78 | with open(args.config, 'r') as f: 79 | config = json.loads(f.read()) 80 | for keys in config: 81 | if keys == 'weight': 82 | weight = config['weight'] 83 | 84 | if keys == 'std': 85 | std = config['std'] 86 | 87 | if args.optim_config is not None: 88 | with 
open(args.optim_config, 'r') as f: 89 | optim_config = json.loads(f.read()) 90 | for keys in config: 91 | if keys == 'weight': 92 | optim_weight = optim_config['weight'] 93 | 94 | 95 | if isinstance(weight, dict): 96 | default_weight = weight['default'] 97 | weight_dict = [[default_weight for jt in range(4)]for it in range(2)] 98 | for keys in weight: 99 | if keys != 'default': 100 | tmp_i = int(keys[0]) 101 | tmp_j = int(keys[1]) 102 | weight_dict[tmp_i][tmp_j] = weight[keys] 103 | else: 104 | weight_dict = [[weight for jt in range(4)]for it in range(2)] 105 | 106 | 107 | if isinstance(optim_weight, dict): 108 | default_weight = optim_weight['default'] 109 | optim_weight_dict = [[default_weight for jt in range(4)]for it in range(2)] 110 | for keys in optim_weight: 111 | if keys != 'default': 112 | tmp_i = int(keys[0]) 113 | tmp_j = int(keys[1]) 114 | optim_weight_dict[tmp_i][tmp_j] = optim_weight[keys] 115 | else: 116 | optim_weight_dict = [[optim_weight for jt in range(4)]for it in range(2)] 117 | 118 | 119 | 120 | 121 | encoder = BertModel.from_pretrained(args.pretrained_encoder) 122 | decoder_tokenizer = GPT2Tokenizer.from_pretrained(args.pretrained_decoder) 123 | decoder = GPT2LMHeadModel.from_pretrained(args.pretrained_decoder) 124 | decoder_tokenizer.pad_token = decoder_tokenizer.eos_token 125 | 126 | model = AE(encoder=encoder, decoder=decoder, args=args) 127 | 128 | 129 | model.load_state_dict(torch.load(args.model_path), strict=False) 130 | model.eval() 131 | model.fix_decoder() 132 | model.set_mode('prior') 133 | 134 | random.seed(args.seed) 135 | np.random.seed(args.seed) 136 | torch.manual_seed(args.seed) 137 | 138 | 139 | if args.no_cuda: 140 | device='cpu' 141 | else: 142 | device='cuda' 143 | 144 | model.to(device) 145 | 146 | 147 | def calculate(alpha, mean, std, eps): 148 | total = sum(alpha) 149 | alpha = [num/total for num in alpha] 150 | mu = 0 151 | for w, mean in zip(alpha, mean): 152 | mu = mu + w * mean 153 | 154 | sigma = 0 155 | for w, std in zip(alpha, std): 156 | if w < 0: 157 | w = 0 158 | if w > 1: 159 | w = 1 160 | sigma = sigma + (w * std)**2 161 | sigma = torch.sqrt(sigma) 162 | 163 | sampled_dis = sigma * eps + mu 164 | return sampled_dis 165 | 166 | 167 | def gaussian_log_prob(x, mu, log_sd): 168 | return -0.5 * math.log(2 * torch.pi) - log_sd - 0.5 * (x - mu) ** 2 / torch.exp(2 * log_sd) 169 | 170 | priors = model.priors 171 | #disdict = [] 172 | #for prior in priors: 173 | # disdict.append(DIS([prior], [10]).to(device)) 174 | 175 | 176 | ode_kwargs = {'atol': 1e-3, 'rtol': 1e-3, 'method': 'dopri5', 'use_adjoint': True, 'latent_dim': args.latent_size} 177 | sampler = partial(sample_q_ode, device=device, **ode_kwargs) 178 | 179 | model.set_mode('normal') 180 | 181 | output_text = [] 182 | labels = [] 183 | for i in range(2): 184 | for j in range(4): 185 | weight = weight_dict[i][j] 186 | optim_weight = optim_weight_dict[i][j] 187 | 188 | if isinstance(std, list): 189 | tmp_std = std[i*4+j] 190 | else: 191 | tmp_std = std 192 | 193 | if args.is_constrained: 194 | dismodel = DIScons([priors[i], priors[2+j], priors[7]], [optim_weight[0], optim_weight[1], optim_weight[2]]).to(device) 195 | else: 196 | dismodel = DIS([priors[i], priors[2+j], priors[7]], [optim_weight[0], optim_weight[1], optim_weight[2]]).to(device) 197 | 198 | probs_raw = None 199 | probs_dis = None 200 | 201 | 202 | for prompts in tqdm(json.loads(args.pre_tokens)): 203 | tokens = decoder_tokenizer(prompts, return_tensors='pt') 204 | input_ids = tokens.input_ids 205 | attention_mask = 
tokens.attention_mask 206 | input_ids = input_ids.expand(args.batch_size, -1) 207 | attention_mask = attention_mask.expand(args.batch_size, -1) 208 | 209 | #latents = torch.normal(0,1, (args.batch_size, args.latent_num * args.latent_size)) 210 | #latents = torch.zeros(args.batch_size, args.latent_num * args.latent_size) 211 | eps = torch.normal(0,tmp_std, (args.batch_size, args.latent_size)).to(device) 212 | 213 | prior_head_index = [i,j+2,7] 214 | 215 | learnable_prior_mean = [priors[p_index][0] for p_index in prior_head_index] 216 | learnable_prior_std = [torch.exp(priors[p_index][1]) for p_index in prior_head_index] 217 | 218 | z_k = calculate(weight, learnable_prior_mean, learnable_prior_std, eps) 219 | 220 | if probs_raw is None: 221 | probs_raw = [gaussian_log_prob(z_k, priors[i][0], priors[i][1]).view(args.batch_size,-1).sum(-1) / args.latent_size, 222 | gaussian_log_prob(z_k, priors[j+2][0], priors[j+2][1]).view(args.batch_size,-1).sum(-1) / args.latent_size, 223 | gaussian_log_prob(z_k, priors[7][0], priors[7][1]).view(args.batch_size,-1).sum(-1) / args.latent_size 224 | ] 225 | else: 226 | probs_raw = [torch.concat([probs_raw[0], gaussian_log_prob(z_k, priors[i][0], priors[i][1]).view(args.batch_size,-1).sum(-1) / args.latent_size], dim=0), 227 | torch.concat([probs_raw[1], gaussian_log_prob(z_k, priors[j+2][0], priors[j+2][1]).view(args.batch_size,-1).sum(-1) / args.latent_size], dim=0), 228 | torch.concat([probs_raw[2], gaussian_log_prob(z_k, priors[7][0], priors[7][1]).view(args.batch_size,-1).sum(-1) / args.latent_size], dim=0) 229 | ] 230 | 231 | y = torch.tensor([prior_head_index] * args.batch_size).to(device) 232 | latent = sampler(ccf=dismodel, y=y, z_k=z_k.clone()) 233 | 234 | if probs_dis is None: 235 | probs_dis = [gaussian_log_prob(latent, priors[i][0], priors[i][1]).view(args.batch_size,-1).sum(-1) / args.latent_size, 236 | gaussian_log_prob(latent, priors[j+2][0], priors[j+2][1]).view(args.batch_size,-1).sum(-1) / args.latent_size, 237 | gaussian_log_prob(latent, priors[7][0], priors[7][1]).view(args.batch_size,-1).sum(-1) / args.latent_size 238 | ] 239 | else: 240 | probs_dis = [torch.concat([probs_dis[0], gaussian_log_prob(latent, priors[i][0], priors[i][1]).view(args.batch_size,-1).sum(-1) / args.latent_size], dim=0), 241 | torch.concat([probs_dis[1], gaussian_log_prob(latent, priors[j+2][0], priors[j+2][1]).view(args.batch_size,-1).sum(-1) / args.latent_size], dim=0), 242 | torch.concat([probs_dis[2], gaussian_log_prob(latent, priors[7][0], priors[7][1]).view(args.batch_size,-1).sum(-1) / args.latent_size], dim=0) 243 | ] 244 | 245 | input_latent, _ = model.inv_flow(latent, rev=True) 246 | output = model.generate( 247 | input_latent=input_latent, 248 | input_ids=input_ids, 249 | attention_mask=attention_mask, 250 | variation=args.variation, 251 | max_len=50, 252 | rp=1.2 253 | ) 254 | 255 | output_text.extend(decoder_tokenizer.batch_decode(output.cpu(), skip_special_tokens=True)) 256 | labels.extend([[i,j,1]] * args.batch_size) 257 | assert len(labels) == len(output_text) 258 | 259 | if args.is_constrained: 260 | log_dir = 'logs/logs_combine_optimcons.txt' 261 | else: 262 | log_dir = 'logs/logs_combine_optim.txt' 263 | with open(log_dir, 'a') as f: 264 | f.write(str(i) + '\n') 265 | f.write('RAW:' + str([res.mean().item() for res in probs_raw]) + '\n') 266 | f.write('ODE:' + str([res.mean().item() for res in probs_dis]) + '\n') 267 | 268 | 269 | with open(args.output_dir, 'w') as f: 270 | for i in tqdm(range(len(output_text))): 271 | 
f.write(json.dumps([labels[i], output_text[i]])+'\n') 272 | -------------------------------------------------------------------------------- /priorcontrol/generate_config_combine.json: -------------------------------------------------------------------------------- 1 | { 2 | "weight":{ 3 | "default":[1,1,1], 4 | "00":[2,12,1], 5 | "01":[2,6,1], 6 | "02":[2,16,1], 7 | "03":[2,1,5], 8 | "10":[14,16,0.2], 9 | "11":[28,20,0.2], 10 | "12":[20,26,0.2], 11 | "13":[6,1,1] 12 | }, 13 | "std":[ 14 | 0.5, 15 | 0.5, 16 | 0.5, 17 | 0.5, 18 | 0.2, 19 | 0.2, 20 | 0.1, 21 | 0.5 22 | ] 23 | } -------------------------------------------------------------------------------- /priorcontrol/generate_config_combine_optim.json: -------------------------------------------------------------------------------- 1 | { 2 | "weight":{ 3 | "default":[1,1,1], 4 | "00":[2,12,1], 5 | "01":[2,6,1], 6 | "02":[2,16,1], 7 | "03":[2,1,5], 8 | "10":[13,14,1], 9 | "11":[28,20,1], 10 | "12":[20,26,1], 11 | "13":[6,1,1] 12 | } 13 | } -------------------------------------------------------------------------------- /priorcontrol/generate_prior.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import transformers 4 | from transformers import GPT2LMHeadModel, BertModel, GPT2Tokenizer, BertTokenizer 5 | from tqdm import tqdm 6 | import json 7 | import random 8 | import numpy as np 9 | 10 | from model import AE 11 | 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--pretrained_encoder", type=str, default="bert-base-uncased") 15 | parser.add_argument("--pretrained_decoder", type=str, default="gpt2-medium") 16 | parser.add_argument("--no_cuda", action="store_true") 17 | parser.add_argument("--latent_size", type=int, default=768) 18 | parser.add_argument("--latent_num",type=int, default=1) 19 | parser.add_argument("--seq_len_per_latent",type=int, default=20) 20 | parser.add_argument("--model_path", type=str, default='../model/priorcontrol/All_notricks_checkpoint-300000/pytorch_model.bin') 21 | parser.add_argument("--output_dir", type=str, default="../res/prior/control/generate_prior.txt") 22 | parser.add_argument("--batch_size", type=int, default=5) 23 | parser.add_argument("--pre_tokens", 24 | type=str, 25 | default=json.dumps( 26 | ['In summary','This essay discusses','Views on','The connection','Foundational to this is', 27 | 'To review,','In brief,','An illustration of','Furthermore,','The central theme', 28 | 'To conclude,','The key aspect','Prior to this','Emphasised are','To summarise', 29 | 'The relationship','More importantly,','It has been shown','The issue focused on','In this essay', 30 | 'Once upon a time','The book','The chicken','The city','The country', 31 | 'The horse','The lake','The last time','The movie','The painting', 32 | 'The pizza','The potato','The president of the country','The road','The year is 1910'] 33 | ) 34 | ) 35 | parser.add_argument("--max_length", type=int, default=50) 36 | parser.add_argument("--seed", type=int, default=0) 37 | 38 | parser.add_argument("--variation", type=float, default=0) 39 | 40 | #Parameters for Prior 41 | parser.add_argument("--prior", type=bool, default=True) 42 | parser.add_argument("--flow_num", type=int, default=8) 43 | parser.add_argument("--prior_num", type=int, default=8) 44 | 45 | #extend mode for control 46 | parser.add_argument("--is_extend", action="store_true") 47 | 48 | 49 | #Generation 50 | parser.add_argument("--std", type=float, default=1) 51 | 52 | args = 
parser.parse_args() 53 | 54 | 55 | 56 | encoder_tokenizer = BertTokenizer.from_pretrained(args.pretrained_encoder) 57 | encoder = BertModel.from_pretrained(args.pretrained_encoder) 58 | decoder_tokenizer = GPT2Tokenizer.from_pretrained(args.pretrained_decoder) 59 | decoder = GPT2LMHeadModel.from_pretrained(args.pretrained_decoder) 60 | decoder_tokenizer.pad_token = decoder_tokenizer.eos_token 61 | 62 | model = AE(encoder=encoder, decoder=decoder, args=args) 63 | 64 | 65 | model.load_state_dict(torch.load(args.model_path), strict=False) 66 | model.eval() 67 | model.fix_decoder() 68 | model.set_mode('prior') 69 | 70 | random.seed(args.seed) 71 | np.random.seed(args.seed) 72 | torch.manual_seed(args.seed) 73 | 74 | 75 | if args.no_cuda: 76 | device='cpu' 77 | else: 78 | device='cuda' 79 | 80 | model.to(device) 81 | 82 | 83 | def calculate(alpha, mean, std, eps): 84 | mu = 0 85 | for w, mean in zip(alpha, mean): 86 | mu = mu + w * mean 87 | 88 | sigma = 0 89 | for w, std in zip(alpha, std): 90 | if w < 0: 91 | w = 0 92 | if w > 1: 93 | w = 1 94 | sigma = sigma + (w * std)**2 95 | sigma = torch.sqrt(sigma) 96 | 97 | sampled_dis = sigma * eps + mu 98 | return sampled_dis 99 | 100 | 101 | 102 | output_text = [] 103 | labels = [] 104 | 105 | l = [[0,-1,-1], [1, -1,-1], 106 | [-1,0,-1], [-1,1,-1], [-1,2,-1], [-1,3,-1], 107 | [-1,-1,0], [-1,-1,1]] 108 | 109 | alpha_list = [[1,0], [-0.2, 1.2], [1.3,-0.1,-0.1,-0.1], [-0.1, 1.3, -0.1, -0.1], [-0.1, -0.1, 1.3, -0.1], [-0.1, -0.1, -0.1, 1.3], [1,0], [-0.3, 1.3]] 110 | head_list = [[0,1], [0,1], [2,3,4,5], [2,3,4,5], [2,3,4,5], [2,3,4,5], [6,7], [6,7]] 111 | for i in range(8): 112 | #for i in [2,3,4,5]: 113 | 114 | for prompts in tqdm(json.loads(args.pre_tokens)): 115 | tokens = decoder_tokenizer(prompts, return_tensors='pt') 116 | input_ids = tokens.input_ids 117 | attention_mask = tokens.attention_mask 118 | input_ids = input_ids.expand(args.batch_size, -1) 119 | attention_mask = attention_mask.expand(args.batch_size, -1) 120 | 121 | #latents = torch.normal(0,1, (args.batch_size, args.latent_num * args.latent_size)) 122 | latents = torch.zeros(args.batch_size, args.latent_num * args.latent_size) 123 | 124 | if args.is_extend: 125 | output = model.generate( 126 | input_latent=latents, 127 | input_ids=input_ids, 128 | attention_mask=attention_mask, 129 | variation=args.variation, 130 | max_len=50, 131 | rp=1.2, 132 | prior_head_index=head_list[i], 133 | alpha=alpha_list[i], 134 | calculate=calculate, 135 | std=args.std 136 | ) 137 | else: 138 | output = model.generate( 139 | input_latent=latents, 140 | input_ids=input_ids, 141 | attention_mask=attention_mask, 142 | variation=args.variation, 143 | max_len=50, 144 | rp=1.2, 145 | prior_head_index=i, 146 | std=args.std 147 | ) 148 | 149 | output_text.extend(decoder_tokenizer.batch_decode(output.cpu(), skip_special_tokens=True)) 150 | labels.extend([l[i]] * args.batch_size) 151 | assert len(labels) == len(output_text) 152 | 153 | 154 | 155 | with open(args.output_dir, 'w') as f: 156 | for i in tqdm(range(len(output_text))): 157 | f.write(json.dumps([labels[i], output_text[i]])+'\n') 158 | -------------------------------------------------------------------------------- /priorcontrol/generate_prior.sh: -------------------------------------------------------------------------------- 1 | #adjust the lambda 2 | #for std in 1.0 0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2 0.1 0.0 3 | #do 4 | #python generate_prior.py --std $std --output_dir ../res/priorcontrol/generate_prior$std.txt 5 | #done 6 | 7 | #normal generation 8 | 
python generate_prior.py --std 0.6 --output_dir ../res/priorcontrol/generate_prior.txt 9 | 10 | #generation with extend mode 11 | python generate_prior.py --std 0.6 --is_extend --output_dir ../res/priorcontrol/generate_prior_extend.txt -------------------------------------------------------------------------------- /priorcontrol/latentops_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import math 5 | 6 | from torchdiffeq import odeint_adjoint 7 | from torchdiffeq import odeint as odeint_normal 8 | 9 | 10 | class DenseEmbedder(nn.Module): 11 | """Supposed to map small-scale features (e.g. labels) to some given late 12 | \nt dim""" 13 | 14 | def __init__(self, input_dim, up_dim, depth=4, num_classes=10): 15 | super().__init__() 16 | self.net = nn.ModuleList() 17 | dims = np.linspace(input_dim, up_dim, depth).astype(int) 18 | 19 | for l in range(len(dims) - 1): 20 | self.net.append(nn.Dropout(0.2)) 21 | self.net.append(nn.Conv2d(dims[l], dims[l + 1], 1)) 22 | # self.net.append(get_norm(dims[l + 1], norm)) 23 | self.net.append(nn.LeakyReLU(0.2)) 24 | # self.net.append(nn.Tanh()) 25 | 26 | self.last_dim = up_dim 27 | self.linear = nn.Linear(up_dim, num_classes) 28 | self.energy_weight = 1 29 | # print('Using DenseEmbedder...') 30 | # print(f'{norm1} norm') 31 | 32 | def set_energy_weight(self, weight): 33 | self.energy_weight = weight 34 | # print('Energy Weight = ',weight) 35 | 36 | def forward(self, x): 37 | if x.ndim == 2: 38 | x = x[:, :, None, None] 39 | 40 | for layer in self.net: 41 | x = layer(x) 42 | 43 | out = x.squeeze(-1).squeeze(-1) 44 | out = self.linear(out) 45 | logits = out 46 | return logits 47 | 48 | 49 | class CCF(nn.Module): 50 | def __init__(self, classifier): 51 | super(CCF, self).__init__() 52 | self.f = nn.ModuleList() 53 | for cls in classifier: 54 | self.f.append(cls) 55 | 56 | def get_cond_energy(self, z, y_): 57 | energy_outs = [] 58 | # for i, cls in enumerate(self.f): 59 | for i in range(y_.shape[1]): 60 | cls = self.f[i] 61 | logits = cls(z) 62 | # logits_list.append(logits) 63 | n_classes = logits.size(1) 64 | if n_classes > 1: 65 | y = y_[:, i].long() 66 | sigle_energy = torch.gather(logits, 1, y[:, None]).squeeze() - logits.logsumexp(1) 67 | energy_outs.append(cls.energy_weight * sigle_energy) 68 | # energy_outs.append((cls.energy_weight)*(torch.gather(logits, 1, y[:, None]).squeeze() - logits.logsumexp(1))) 69 | else: 70 | assert n_classes == 1, n_classes 71 | y = y_[:, i].float() 72 | sigma = 0.1 # this value works well 73 | sigle_energy = -torch.norm(logits - y[:, None], dim=1) ** 2 * 0.5 / (sigma ** 2) 74 | energy_outs.append(cls.energy_weight * sigle_energy) 75 | # print('dog:', round(energy_outs[0].sum().item(),2), '\tchild:', round(energy_outs[1].sum().item(),2), '\tball:',round(energy_outs[2].sum().item(),2)) 76 | 77 | energy_output = torch.stack(energy_outs).sum(dim=0) 78 | return energy_output # - 0.03*torch.norm(z, dim=1) ** 2 * 0.5 79 | def get_cond_energy_single(self, z, y_): 80 | for i in range(y_.shape[1]): 81 | energy_outs = [] 82 | # for i, cls in enumerate(self.f): 83 | cls = self.f[i] 84 | logits = cls(z) 85 | # logits_list.append(logits) 86 | n_classes = logits.size(1) 87 | if n_classes > 1: 88 | y = y_[:, i].long() 89 | sigle_energy = torch.gather(logits, 1, y[:, None]).squeeze() - logits.logsumexp(1) 90 | energy_outs.append(cls.energy_weight * sigle_energy) 91 | # energy_outs.append((cls.energy_weight)*(torch.gather(logits, 
1, y[:, None]).squeeze() - logits.logsumexp(1))) 92 | else: 93 | assert n_classes == 1, n_classes 94 | y = y_[:, i].float() 95 | sigma = 0.1 # this value works well 96 | sigle_energy = -torch.norm(logits - y[:, None], dim=1) ** 2 * 0.5 / (sigma ** 2) 97 | energy_outs.append(cls.energy_weight * sigle_energy) 98 | # print('dog:', round(energy_outs[0].sum().item(),2), '\tchild:', round(energy_outs[1].sum().item(),2), '\tball:',round(energy_outs[2].sum().item(),2)) 99 | 100 | energy_output = torch.stack(energy_outs).sum(dim=0) 101 | return energy_output 102 | 103 | def forward(self, z, y): 104 | energy_output = self.get_cond_energy(z, y) - torch.norm(z, dim=1) ** 2 * 0.5 105 | return energy_output 106 | 107 | 108 | class DIS(nn.Module): 109 | def __init__(self, distributions, weights=None): 110 | super(DIS, self).__init__() 111 | self.f = nn.ParameterList() 112 | for dis in distributions: 113 | self.f.append(dis) 114 | 115 | if weights is not None: 116 | self.weights = weights 117 | else: 118 | self.weights = [(1/len(distributions))] * len(distributions) 119 | 120 | def gaussian_log_prob(self, x, mu, log_sd): 121 | return -0.5 * math.log(2 * torch.pi) - log_sd - 0.5 * (x - mu) ** 2 / torch.exp(2 * log_sd) 122 | 123 | def get_cond_energy(self, z, y_): 124 | energy_outs = [] 125 | # for i, cls in enumerate(self.f): 126 | batch_size, dim = z.shape[0], z.shape[1] 127 | for i in range(y_.shape[1]): 128 | dis = self.f[i] 129 | 130 | log_prob = self.gaussian_log_prob(z, dis[0], dis[1]).view(batch_size,-1).sum(-1) #/ dim 131 | 132 | #energy = torch.exp(log_prob) * self.weights[i] 133 | energy = log_prob * self.weights[i] 134 | energy_outs.append(energy) 135 | 136 | 137 | energy_output = torch.stack(energy_outs).sum(dim=0) 138 | return energy_output # - 0.03*torch.norm(z, dim=1) ** 2 * 0.5 139 | 140 | def forward(self, z, y): 141 | energy_output = self.get_cond_energy(z, y) 142 | return energy_output 143 | 144 | 145 | class DIScons(nn.Module): 146 | def __init__(self, distributions, weights=None, eps = 8e-5): 147 | super(DIScons, self).__init__() 148 | self.f = nn.ParameterList() 149 | self.eps = eps 150 | for dis in distributions: 151 | self.f.append(dis) 152 | 153 | if weights is not None: 154 | self.weights = weights 155 | else: 156 | self.weights = [(1/len(distributions))] * len(distributions) 157 | 158 | def gaussian_log_prob(self, x, mu, log_sd): 159 | return -0.5 * math.log(2 * torch.pi) - log_sd - 0.5 * (x - mu) ** 2 / torch.exp(2 * log_sd) 160 | 161 | def get_cond_energy(self, z, y_): 162 | energy_outs = [] 163 | # for i, cls in enumerate(self.f): 164 | batch_size, dim = z.shape[0], z.shape[1] 165 | 166 | log_probs = [] 167 | 168 | for i in range(y_.shape[1]): 169 | dis = self.f[i] 170 | 171 | log_prob = self.gaussian_log_prob(z, dis[0], dis[1]).view(batch_size,-1).sum(-1) #/ dim 172 | log_probs.append(log_prob) 173 | 174 | energy = log_prob * self.weights[i] 175 | energy_outs.append(energy) 176 | 177 | energy_output = torch.stack(energy_outs).sum(dim=0) 178 | 179 | for i in range(y_.shape[1]): 180 | for j in range(y_.shape[1]): 181 | if i != j: 182 | left_prob, right_prob = log_probs[i], log_probs[j] 183 | 184 | if torch.sum((left_prob.detach() - right_prob.detach()) / dim) > self.eps: 185 | energy_output += -0.3 * (left_prob - right_prob) 186 | 187 | else: 188 | energy_output += -0.01 * (left_prob - right_prob) 189 | 190 | 191 | 192 | return energy_output # - 0.03*torch.norm(z, dim=1) ** 2 * 0.5 193 | 194 | def forward(self, z, y): 195 | energy_output = self.get_cond_energy(z, y) 196 | 
return energy_output 197 | 198 | 199 | class VPODE(nn.Module): 200 | def __init__(self, ccf, y, beta_min=0.1, beta_max=20, T=1.0): 201 | super().__init__() 202 | self.ccf = ccf 203 | self.beta_0 = beta_min 204 | self.beta_1 = beta_max 205 | self.T = T 206 | self.y = y 207 | 208 | 209 | def forward(self, t_k, states): 210 | z = states[0] 211 | with torch.set_grad_enabled(True): 212 | z.requires_grad_(True) 213 | beta_t = self.beta_0 + t_k * (self.beta_1 - self.beta_0) 214 | cond_energy_neg = self.ccf.get_cond_energy(z, self.y) 215 | cond_f_prime = torch.autograd.grad(cond_energy_neg.sum(), [z])[0] 216 | dz_dt = -0.5 * beta_t * cond_f_prime 217 | return dz_dt, 218 | 219 | 220 | #ode sampling 221 | def sample_q_ode(ccf, y, device=torch.device('cuda'), **kwargs): 222 | """sampling in the z space""" 223 | ccf.eval() 224 | atol = kwargs['atol'] 225 | rtol = kwargs['rtol'] 226 | method = kwargs['method'] 227 | use_adjoint = kwargs['use_adjoint'] 228 | kwargs['device'] = device 229 | # generate initial samples 230 | z_k = kwargs['z_k'] 231 | # z_k: batch x latent_dim, 232 | # y: batch 233 | # ODE function 234 | vpode = VPODE(ccf, y) 235 | states = (z_k,) 236 | if 'T' in kwargs: 237 | times = kwargs['T'] 238 | else: 239 | times = vpode.T 240 | integration_times = torch.linspace(times, 0., 2).type(torch.float32).to(device) 241 | 242 | # ODE solver 243 | odeint = odeint_adjoint if use_adjoint else odeint_normal 244 | state_t = odeint( 245 | vpode, # model 246 | states, # (z,) 247 | integration_times, 248 | atol=atol, # tolerance 249 | rtol=rtol, 250 | method=method) 251 | 252 | ccf.train() 253 | z_t0 = state_t[0][-1] 254 | # print(f'#ODE steps : {vpode.n_evals}') 255 | return z_t0.detach() 256 | 257 | 258 | #langevin dynamics sampling 259 | def sample_q_sgld(ccf, y, device=torch.device('cuda'), save_path=None, plot=None, every_n_plot=5, **kwargs): 260 | """sampling in the z space""" 261 | ccf.eval() 262 | 263 | latent_dim = kwargs['latent_dim'] 264 | sgld_lr = kwargs['sgld_lr'] 265 | sgld_std = kwargs['sgld_std'] 266 | n_steps = kwargs['n_steps'] 267 | 268 | # generate initial samples 269 | init_sample = torch.randn(y.size(0), latent_dim).to(device) 270 | x_k = torch.autograd.Variable(init_sample, requires_grad=True) 271 | 272 | # sgld 273 | for k in range(n_steps): 274 | energy_neg = ccf(x_k, y=y) 275 | f_prime = torch.autograd.grad(energy_neg.sum(), [x_k])[0] 276 | x_k.data += sgld_lr * f_prime + sgld_std * torch.randn_like(x_k) 277 | 278 | ccf.train() 279 | final_samples = x_k.detach() 280 | 281 | return final_samples,k 282 | 283 | 284 | #sde sampling 285 | def sample_q_vpsde(ccf, y, device=torch.device('cuda'), save_path=None, plot=None, every_n_plot=5, 286 | beta_min=0.1, beta_max=20, T=1, eps=1e-3, **kwargs): 287 | """sampling in the z space""" 288 | ccf.eval() 289 | 290 | latent_dim = kwargs['latent_dim'] 291 | N = kwargs['N'] 292 | correct_nsteps = kwargs['correct_nsteps'] 293 | target_snr = kwargs['target_snr'] 294 | 295 | # generate initial samples 296 | z_init = torch.FloatTensor(y.size(0), latent_dim).normal_(0, 1).to(device) 297 | z_k = torch.autograd.Variable(z_init, requires_grad=True) 298 | 299 | discrete_betas = torch.linspace(beta_min / N, beta_max / N, N) 300 | alphas = 1. 
- discrete_betas 301 | timesteps = torch.linspace(T, eps, N, device=device) 302 | 303 | # vpsde 304 | for k in range(N): 305 | energy_neg = ccf(z_k, y=y) 306 | 307 | # predictor 308 | t_k = timesteps[k] 309 | timestep = (t_k * (N - 1) / T).long() 310 | beta_t = discrete_betas[timestep] 311 | alpha_t = alphas[timestep] 312 | 313 | score_t = torch.autograd.grad(energy_neg.sum(), [z_k])[0] 314 | 315 | z_k = (2 - torch.sqrt(alpha_t)) * z_k + beta_t * score_t 316 | noise = torch.FloatTensor(y.size(0), latent_dim).normal_(0, 1).to(device) 317 | z_k = z_k + torch.sqrt(beta_t) * noise 318 | 319 | # corrector 320 | for j in range(correct_nsteps): 321 | noise = torch.FloatTensor(y.size(0), latent_dim).normal_(0, 1).to(device) 322 | 323 | grad_norm = torch.norm(score_t.reshape(score_t.shape[0], -1), dim=-1).mean() 324 | noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() 325 | step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha_t 326 | 327 | assert step_size.ndim == 0, step_size.ndim 328 | 329 | z_k_mean = z_k + step_size * score_t 330 | z_k = z_k_mean + torch.sqrt(step_size * 2) * noise 331 | 332 | ccf.train() 333 | final_samples = z_k.detach() 334 | 335 | return final_samples, k -------------------------------------------------------------------------------- /priorcontrol/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import wandb 4 | import math 5 | import random 6 | 7 | from FrEIA.framework import SequenceINN 8 | from FrEIA.modules import AllInOneBlock 9 | 10 | 11 | import logging 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class Reshape(nn.Module): 16 | ''' 17 | past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 18 | Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads, 19 | sequence_length, embed_size_per_head)`). 20 | ''' 21 | def __init__(self, arg_dict): 22 | super(Reshape, self).__init__() 23 | self.seq_len = arg_dict['seq_len'] 24 | self.num_layer = arg_dict['num_layer'] 25 | self.hidden_size = arg_dict['hidden_size'] 26 | self.num_head = arg_dict['num_head'] 27 | def forward(self, x): 28 | batch_size = x.shape[0] 29 | assert self.hidden_size % self.num_head == 0 30 | embed_size_per_head = self.hidden_size//self.num_head 31 | x = x.view(batch_size, self.num_layer, 2, self.num_head, self.seq_len, embed_size_per_head).permute(1,2,0,3,4,5) 32 | past_key_values = [] 33 | for i in range(self.num_head): 34 | past_key_values.append((x[i][0],x[i][1],)) 35 | #assert past_key_values[0][0].requires_grad == True 36 | return tuple(past_key_values) 37 | 38 | 39 | 40 | class AE(nn.Module): 41 | """ 42 | AE with Decoder Fixed. 43 | We adopt connection method from prompt tuning. 
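    Pipeline (see the shape comments in __init__):
        pooled BERT output [batch, hidden] --trans1--> latent [batch, latent_num * latent_size]
        --trans2 + Reshape--> past_key_values for the fixed GPT-2 decoder,
        i.e. a prefix of length latent_num * seq_len_per_latent.
    When args.prior is set, an invertible flow (a FrEIA SequenceINN of
    AllInOneBlocks) maps this latent space onto learnable Gaussian priors
    (a mu / log_sigma pair per attribute head).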
44 | """ 45 | #_keys_to_ignore_on_load_missing = [r"latent_classify_head\.\d+\.weight", r"latent_classify_head\.\d+\.bias"] 46 | def __init__(self, encoder, decoder, args): # 47 | super(AE, self).__init__() 48 | self.encoder = encoder #BertModel 49 | self.decoder = decoder #GPT2LMHeadModel 50 | 51 | self.encoder_config = encoder.config 52 | self.encoder_hidden_size = self.encoder_config.hidden_size 53 | 54 | self.decoder_config = decoder.config 55 | self.decoder_num_layer = self.decoder_config.n_layer 56 | self.decoder_hidden_size = self.decoder_config.n_embd 57 | self.decoder_num_head = self.decoder_config.n_head 58 | 59 | self.losslist = None 60 | self.latent_classify_head = None 61 | 62 | 63 | 64 | 65 | self.args = args 66 | self.latent_size = args.latent_size 67 | self.seq_len_per_latent = args.seq_len_per_latent 68 | self.latent_num = args.latent_num 69 | self.seq_len = self.latent_num * self.seq_len_per_latent 70 | if 'variation' in args: 71 | self.variation = args.variation 72 | else: 73 | self.variation = 0 74 | 75 | self.mode = 'normal' 76 | 77 | 78 | 79 | ## connector: 80 | # 1. from Bert hidden units to the latent space 81 | # 2. convert latent space to `past_key_values' in GPT 82 | # [batch_size, bert_hidden_size] -> [batch_size, latent_num * latent_size] 83 | # -> [batch_size, latent_num, decoder_layer * len([key,value]) * gpt_hidden_size] 84 | # -> (num_layer* (len([key,value])* tensor[batch_size, num_head, seq_len, embed_size_per_head])) 85 | self.trans1 = torch.nn.Sequential( 86 | torch.nn.Linear(self.encoder_hidden_size, self.latent_num * self.latent_size), 87 | torch.nn.Tanh(), 88 | nn.Dropout(self.decoder_config.attn_pdrop)#added 89 | ) 90 | self.trans2 = torch.nn.Sequential( 91 | torch.nn.Linear(self.latent_num * self.latent_size, self.seq_len * self.decoder_num_layer * 2 * self.decoder_hidden_size), 92 | nn.Dropout(self.decoder_config.attn_pdrop),#added 93 | Reshape({'seq_len':self.seq_len, 'num_layer':self.decoder_num_layer, 'hidden_size':self.decoder_hidden_size, 'num_head':self.decoder_num_head}) 94 | ) 95 | 96 | self.is_prior = False 97 | if args.prior is True: 98 | 99 | 100 | self.is_prior = True 101 | def subnet(dims_in, dims_out): 102 | return nn.Sequential( 103 | nn.Linear(dims_in, 512), 104 | nn.LeakyReLU(0.1), 105 | nn.Linear(512, dims_out), 106 | ) 107 | self.inv_flow = SequenceINN(self.latent_size * self.latent_num) 108 | for k in range(args.flow_num): 109 | self.inv_flow.append(AllInOneBlock, subnet_constructor=subnet, permute_soft=True) 110 | 111 | #learned mu and log_sigma 112 | self.priors = nn.ParameterList([nn.Parameter(torch.zeros(2, self.latent_size * self.latent_num)) for i in range(args.prior_num)]) 113 | self.prior_num = args.prior_num 114 | 115 | self.prior_classify_head = None 116 | 117 | 118 | 119 | def set_mode(self, mode): 120 | #'normal': Train the AE structure 121 | #'prior': Fix encoder and train the Normalizing Flow Prior 122 | self.mode = mode 123 | if self.mode == 'normal': 124 | for struct in [self.encoder, self.trans1, self.trans2]: 125 | struct.train() 126 | for params in struct.parameters(): 127 | params.requires_grad = True 128 | 129 | elif self.mode == 'prior': 130 | if not self.is_prior: 131 | raise Exception('Need to initialize with a prior structure.') 132 | for struct in [self.encoder, self.trans1, self.trans2]: 133 | struct.eval() 134 | for params in struct.parameters(): 135 | params.requires_grad = False 136 | 137 | print('Trainable part are listed below:') 138 | for keys, params in self.named_parameters(): 139 | if 
params.requires_grad is True: 140 | print(keys) 141 | 142 | 143 | def fix_decoder(self): 144 | ''' 145 | Fix the decoder to work as prefix tuning. 146 | ''' 147 | self.decoder.eval() 148 | for param in self.decoder.parameters(): 149 | param.requires_grad = False 150 | 151 | 152 | def connect(self, encoder_output, variation=0): 153 | ''' 154 | Connect encoder to decoder and get the latent representation. 155 | ''' 156 | tmp_latent = self.trans1(encoder_output) 157 | eps = torch.zeros_like(tmp_latent).normal_(std=variation).to(tmp_latent.device) 158 | past_key_values = self.trans2(tmp_latent + eps) 159 | latent = tmp_latent.view(-1, self.latent_num, self.latent_size) 160 | return past_key_values, latent 161 | 162 | 163 | def sparse_loss(self, latent, dim=None): 164 | ''' 165 | Increase the sparsity. 166 | ''' 167 | if len(latent) == 3 and dim is None: 168 | raise Exception('Expect latent to be dim 2.') 169 | loss_func = nn.L1Loss(reduction='mean') 170 | batch_size = latent.shape[0] 171 | if dim is not None: 172 | tmp_latent = latent[:,dim,:].squeeze() 173 | average = torch.sum(tmp_latent, dim=0)/batch_size 174 | loss = loss_func(latent, average.expand(batch_size, -1)) 175 | else: 176 | average = torch.sum(latent, dim=0)/batch_size 177 | loss = loss_func(latent, average.expand(batch_size, -1)) 178 | return -loss 179 | 180 | def contrasitive_loss(self, latent1, latent2, loss_func=nn.SmoothL1Loss(reduction='mean'), dim=None): 181 | ''' 182 | Increase the distance between latent1 and latent2. 183 | loss_func: nn.L1Loss, nn.SmoothL1Loss, nn.MSELoss, ... 184 | ''' 185 | if dim is not None: 186 | loss = loss_func(latent1[:,dim,:].squeeze(), latent2[:,dim,:].squeeze()) 187 | else: 188 | loss = loss_func(latent1, latent2) 189 | return -1 * loss 190 | 191 | def latent_classify_loss(self, latent, pos_label, neg_labels, head_index=None): 192 | if len(latent.shape) == 3: 193 | latent = latent.view(-1, self.latent_num*self.latent_size) 194 | if self.latent_classify_head_type == 'single': 195 | probs = torch.softmax(self.latent_classify_head(latent), dim=-1) 196 | batch_size, class_num = probs.shape 197 | loss = 0 198 | neg_len = neg_labels.shape[-1] 199 | 200 | for i in range(batch_size): 201 | pos_prob = probs[i, pos_label[i]] 202 | if pos_prob < 1/self.head_num: 203 | loss += torch.log(pos_prob) 204 | loss += torch.log(1 - probs[i, neg_labels[i]]).sum() 205 | 206 | return -1 * loss / (batch_size * (neg_len+1)) 207 | elif self.latent_classify_head_type == 'multiple': 208 | if head_index is None: 209 | print("UserWarning: head_index not set for multiple classifier head, default to 0") 210 | head_index = 0 211 | device = latent.device 212 | logits = self.latent_classify_head[head_index](latent) 213 | loss = torch.nn.functional.cross_entropy(logits, pos_label.to(device)) 214 | return loss 215 | else: 216 | raise Exception('Wrong latent classifier head type.') 217 | 218 | def aspect_gap_loss(self, latent, head_index): 219 | if len(latent.shape) == 3: 220 | latent = latent.view(-1, self.latent_num * self.latent_size) 221 | 222 | mean_latent = torch.mean(latent, dim=0) 223 | loss = None 224 | for i in range(self.aspect_head_num): 225 | if i != head_index and self.aspect_gap_head[i] is not None: 226 | if loss is None: 227 | loss = torch.nn.functional.mse_loss(mean_latent, self.aspect_gap_head[i]) * self.aspect_gap_loss_amplification 228 | else: 229 | loss += torch.nn.functional.mse_loss(mean_latent, self.aspect_gap_head[i]) * self.aspect_gap_loss_amplification 230 | self.set_aspect_gap_head(mean_latent, 
head_index) 231 | return loss 232 | 233 | 234 | def prior_classify_loss(self, z, pos_label, head_index): 235 | device = z.device 236 | logits = self.prior_classify_head[head_index](z) 237 | loss = torch.nn.functional.cross_entropy(logits, pos_label.to(device)) 238 | return loss 239 | 240 | 241 | def set_losslist(self, 242 | losslist:dict, 243 | latent_classify_args={'head_num':1, 'class_num_per_head':2,'mid_size':128,'head_type':'single'}, 244 | aspect_gap_args={'head_num':2, 'amplification':5}, 245 | adv_z_prob_args={'grouping':[[0,1],[2,3,4,5],[6,7]]}, 246 | prior_classify_args={'head_num':3, 'class_num_per_head':[2,2,4],'mid_size':128} 247 | ): 248 | ''' 249 | losslist: 250 | Sample: {'contrasitive_loss': 0.001, 'sparse_loss': 0.001, 'latent_classify_loss':0.1, 'aspect_gap_loss':0.1} 251 | ''' 252 | self.losslist = losslist 253 | ### AE training 254 | if 'latent_classify_loss' in losslist: 255 | self.head_num = 1 256 | class_num_per_head = 2 257 | mid_size = 128 258 | head_type = 'single' 259 | if latent_classify_args is not None: 260 | if 'head_num' in latent_classify_args: 261 | self.head_num = latent_classify_args['head_num'] 262 | if 'class_num_per_head' in latent_classify_args: 263 | class_num_per_head = latent_classify_args['class_num_per_head'] 264 | if 'mid_size' in latent_classify_args: 265 | mid_size = latent_classify_args['mid_size'] 266 | if 'head_type' in latent_classify_args: 267 | head_type = latent_classify_args['head_type'] 268 | 269 | self.set_latent_classify_head(head_num=self.head_num, class_num_per_head=class_num_per_head, mid_size=mid_size, head_type=head_type) 270 | 271 | self.latent_classify_head_type=head_type 272 | 273 | if 'aspect_gap_loss' in losslist: 274 | if 'latent_classify_loss' in losslist: 275 | if self.latent_classify_head_type == 'multiple': 276 | self.aspect_head_num = self.head_num 277 | elif self.latent_classify_head_type == 'single': 278 | print('set aspect head num to {aspect_head_num}.') 279 | self.aspect_head_num = aspect_gap_args['head_num'] 280 | else: 281 | print('set aspect head num to {aspect_head_num}.') 282 | self.aspect_head_num = aspect_gap_args['head_num'] 283 | 284 | self.aspect_gap_loss_amplification = aspect_gap_args['amplification'] 285 | 286 | self.aspect_gap_head = [None for i in range(self.aspect_head_num)] 287 | 288 | ### Normalizing Flow training 289 | 290 | if 'adv_z_prob_loss' in losslist: 291 | self.adv_prior_head_dict = {} 292 | for i in range(self.prior_num): 293 | for group in adv_z_prob_args['grouping']: 294 | if i in group: 295 | self.adv_prior_head_dict[i] = [k for k in group if k != i] 296 | break 297 | 298 | if 'adv_x_prob_loss' in losslist: 299 | pass 300 | 301 | if 'prior_classify_loss' in losslist: 302 | self.head_num = 3 303 | class_num_per_head = [2,2,4] 304 | mid_size = 128 305 | if prior_classify_args is not None: 306 | if 'head_num' in prior_classify_args: 307 | self.head_num = prior_classify_args['head_num'] 308 | if 'class_num_per_head' in prior_classify_args: 309 | class_num_per_head = prior_classify_args['class_num_per_head'] 310 | if 'mid_size' in prior_classify_args: 311 | mid_size = prior_classify_args['mid_size'] 312 | self.set_prior_classify_head(head_num=self.head_num, class_num_per_head=class_num_per_head, mid_size=mid_size) 313 | 314 | 315 | 316 | def set_prior_classify_head(self, head_num, class_num_per_head, mid_size): 317 | if type(class_num_per_head) is list: 318 | self.prior_classify_head = nn.ModuleList([ 319 | nn.Sequential( 320 | nn.Linear(self.latent_num * self.latent_size, mid_size), 
321 | nn.ReLU(), 322 | nn.Linear(mid_size, head_n) 323 | ) for head_n in class_num_per_head] 324 | ) 325 | else: 326 | self.prior_classify_head = nn.ModuleList([ 327 | nn.Sequential( 328 | nn.Linear(self.latent_num * self.latent_size, mid_size), 329 | nn.ReLU(), 330 | nn.Linear(mid_size, class_num_per_head) 331 | ) for i in range(head_num)] 332 | ) 333 | 334 | 335 | def set_latent_classify_head(self, head_num=1, class_num_per_head=2, mid_size=128, head_type='single'): 336 | if head_type == 'single': 337 | self.latent_classify_head = nn.Sequential( 338 | nn.Linear(self.latent_num * self.latent_size, mid_size), 339 | nn.ReLU(), 340 | nn.Linear(mid_size, class_num_per_head * head_num) 341 | ) 342 | elif head_type == 'multiple': 343 | if type(class_num_per_head) is list: 344 | self.latent_classify_head = nn.ModuleList([ 345 | nn.Sequential( 346 | nn.Linear(self.latent_num * self.latent_size, mid_size), 347 | nn.ReLU(), 348 | nn.Linear(mid_size, head_num) 349 | ) for head_num in class_num_per_head] 350 | ) 351 | else: 352 | self.latent_classify_head = nn.ModuleList([ 353 | nn.Sequential( 354 | nn.Linear(self.latent_num * self.latent_size, mid_size), 355 | nn.ReLU(), 356 | nn.Linear(mid_size, class_num_per_head) 357 | ) for i in range(head_num)] 358 | ) 359 | 360 | def set_aspect_gap_head(self, latent, head_index): 361 | if len(latent.shape) == 3: 362 | latent = latent.view(-1, self.latent_num * self.latent_size) 363 | 364 | if len(latent.shape) == 2: 365 | mean_latent = torch.mean(latent.detach(), dim=0) 366 | if self.aspect_gap_head[head_index] is not None: 367 | assert self.aspect_gap_head[head_index].shape == mean_latent.shape 368 | self.aspect_gap_head[head_index] = mean_latent 369 | elif len(latent.shape) == 1: 370 | if self.aspect_gap_head[head_index] is not None: 371 | assert self.aspect_gap_head[head_index].shape == latent.shape 372 | self.aspect_gap_head[head_index] = latent.detach() 373 | 374 | 375 | def gaussian_log_prob(self, x, mu, log_sd): 376 | return -0.5 * math.log(2 * torch.pi) - log_sd - 0.5 * (x - mu) ** 2 / torch.exp(2 * log_sd) 377 | 378 | def gaussian_sample(self, eps, mu, log_sd): 379 | return mu + torch.exp(log_sd) * eps 380 | 381 | def cal_prior(self, latent, prior_head_index): #########################TO DO adv_prior prob######################## 382 | #latent to normal distribution 383 | if len(latent.shape) == 3: 384 | latent = latent.view(-1, self.latent_num * self.latent_size) 385 | batch_size = latent.shape[0] 386 | z, log_jac_det = self.inv_flow(latent) 387 | learnable_prior = self.priors[prior_head_index] 388 | 389 | log_prob = self.gaussian_log_prob(z, learnable_prior[0], learnable_prior[1]).view(batch_size,-1).sum(-1) 390 | 391 | loss = -log_prob - log_jac_det 392 | loss = loss.mean() / (self.latent_num * self.latent_size) 393 | return loss, z, log_prob, log_jac_det 394 | 395 | 396 | 397 | def forward(self, 398 | encoder_input_ids, 399 | encoder_attention_mask, 400 | encoder_token_type_ids, 401 | decoder_input_ids=None, 402 | decoder_attention_mask=None, 403 | adv_input_ids=None, 404 | adv_attention_mask=None, 405 | adv_token_type_ids=None, 406 | pos_label=None, 407 | neg_labels=None, 408 | variation=None, 409 | variation_prob=0.2, 410 | head_index=None, 411 | prior_head_index=None 412 | ): 413 | ''' 414 | Forward method for training which returns a reconstruction loss. 
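        In 'normal' mode, returns (loss, latent, loss_detail): loss is a weighted
        sum in which each auxiliary term from `losslist` contributes with its
        configured weight and the remaining weight w = 1 - sum(weights) goes to
        the reconstruction (LM) loss (w is reset to 1 if the configured weights
        exceed 1).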
    def forward(self,
                encoder_input_ids,
                encoder_attention_mask,
                encoder_token_type_ids,
                decoder_input_ids=None,
                decoder_attention_mask=None,
                adv_input_ids=None,
                adv_attention_mask=None,
                adv_token_type_ids=None,
                pos_label=None,
                neg_labels=None,
                variation=None,
                variation_prob=0.2,
                head_index=None,
                prior_head_index=None
                ):
        '''
        Forward method for training which returns a reconstruction loss.
        Args:
            encoder_input_ids,
            encoder_attention_mask,
            encoder_token_type_ids:
                Outputs of BertTokenizer(List of Strings, return_tensors='pt', padding=True)
            decoder_input_ids,
            decoder_attention_mask:
                Outputs of GPT2Tokenizer(List of Strings, return_tensors='pt', padding=True)
            adv_input_ids,
            adv_attention_mask,
            adv_token_type_ids:
                Adversarial text
        '''
        if self.mode == 'normal':
            # the Trainer delivers pre-batched tensors with an extra leading dim; squeeze it away
            if len(encoder_input_ids.shape) == 3:
                encoder_input_ids = encoder_input_ids.view(encoder_input_ids.shape[1], encoder_input_ids.shape[2])
                encoder_attention_mask = encoder_attention_mask.view(encoder_attention_mask.shape[1], encoder_attention_mask.shape[2])
                encoder_token_type_ids = encoder_token_type_ids.view(encoder_token_type_ids.shape[1], encoder_token_type_ids.shape[2])
                assert decoder_input_ids is not None and decoder_attention_mask is not None
                decoder_input_ids = decoder_input_ids.view(decoder_input_ids.shape[1], decoder_input_ids.shape[2])
                decoder_attention_mask = decoder_attention_mask.view(decoder_attention_mask.shape[1], decoder_attention_mask.shape[2])
                if head_index is not None:
                    head_index = head_index.item()
                if adv_input_ids is not None:
                    adv_input_ids = adv_input_ids.view(adv_input_ids.shape[1], adv_input_ids.shape[2])
                if adv_attention_mask is not None:
                    adv_attention_mask = adv_attention_mask.view(adv_attention_mask.shape[1], adv_attention_mask.shape[2])
                if adv_token_type_ids is not None:
                    adv_token_type_ids = adv_token_type_ids.view(adv_token_type_ids.shape[1], adv_token_type_ids.shape[2])
                if pos_label is not None:
                    pos_label = pos_label.view(pos_label.shape[1])
                if neg_labels is not None:
                    neg_labels = neg_labels.view(neg_labels.shape[1], neg_labels.shape[2])

            if variation is None:
                variation = self.variation
            batch_size = decoder_input_ids.shape[0]
            infix_attn = torch.ones(batch_size, self.seq_len).bool().to(decoder_input_ids.device)
            decoder_attention_mask = torch.cat([infix_attn, decoder_attention_mask], dim=1)

            encoder_output = self.encoder(input_ids=encoder_input_ids, attention_mask=encoder_attention_mask, token_type_ids=encoder_token_type_ids, return_dict=True).pooler_output
            past_key_values, latent = self.connect(encoder_output, variation)
            outputs = self.decoder(input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, labels=decoder_input_ids, past_key_values=past_key_values, return_dict=True)
            # the decoder computes the shifted LM loss internally from `labels`
            lm_loss = outputs.loss

            loss = 0
            loss_detail = {"lm_loss": lm_loss.detach()}
            w = 1  # weight left for the reconstruction loss after the auxiliary losses
            if self.losslist is not None:
                if 'contrasitive_loss' in self.losslist:
                    if adv_input_ids is None:
                        raise Exception('Expect adversarial inputs for contrastive loss.')
                    adv_encoder_output = self.encoder(input_ids=adv_input_ids, attention_mask=adv_attention_mask, token_type_ids=adv_token_type_ids, return_dict=True).pooler_output
                    adv_latent = self.trans1(adv_encoder_output)
                    adv_loss = self.contrasitive_loss(latent, adv_latent)
                    # TODO: make the `dim` argument configurable
                    loss += adv_loss * self.losslist['contrasitive_loss']
                    w -= self.losslist['contrasitive_loss']
                    loss_detail["contrasitive_loss"] = adv_loss.detach()
                if 'sparse_loss' in self.losslist:
                    spa_loss = self.sparse_loss(latent)
                    # TODO: make the `dim` argument configurable
                    loss += spa_loss * self.losslist['sparse_loss']
                    w -= self.losslist['sparse_loss']
                    loss_detail["sparse_loss"] = spa_loss.detach()

                if 'latent_classify_loss' in self.losslist:
                    lac_loss = self.latent_classify_loss(latent, pos_label, neg_labels, head_index)
                    # once the classifier fits well, damp its weight so it does not dominate
                    if lac_loss.detach().item() < 0.1:
                        loss += lac_loss * 0.05
                        w -= 0.05
                    else:
                        loss += lac_loss * self.losslist['latent_classify_loss']
                        w -= self.losslist['latent_classify_loss']
                    loss_detail["latent_classify_loss"] = lac_loss.detach()

                if 'aspect_gap_loss' in self.losslist:
                    agp_loss = self.aspect_gap_loss(latent, head_index)
                    if agp_loss is not None:
                        loss += agp_loss * self.losslist['aspect_gap_loss']
                        w -= self.losslist['aspect_gap_loss']
                        loss_detail["aspect_gap_loss"] = agp_loss.detach()

            wandb.log(loss_detail)
            if w < 0:  # guard in case the auxiliary weights sum past 1
                w = 1

            loss += w * lm_loss

            return loss, latent, loss_detail

        elif self.mode == 'prior':
            if len(encoder_input_ids.shape) == 3:
                encoder_input_ids = encoder_input_ids.view(encoder_input_ids.shape[1], encoder_input_ids.shape[2])
                encoder_attention_mask = encoder_attention_mask.view(encoder_attention_mask.shape[1], encoder_attention_mask.shape[2])
                encoder_token_type_ids = encoder_token_type_ids.view(encoder_token_type_ids.shape[1], encoder_token_type_ids.shape[2])
            if decoder_input_ids is not None:
                print('Note: the prior is guided only by the fixed encoder; decoder inputs are ignored.')
            if prior_head_index is not None:
                prior_head_index = prior_head_index.item()
            else:
                if self.prior_num > 1:
                    print('Warning: a prior head index is required when `prior_num` is not 1. Defaulting to index 0.')
                prior_head_index = 0
            latent, _, _ = self.encode(encoder_input_ids=encoder_input_ids, encoder_attention_mask=encoder_attention_mask, encoder_token_type_ids=encoder_token_type_ids)
            if variation is None:
                variation = self.variation
            if variation > 0:
                # occasionally jitter the latent to smooth the density estimate
                choice = random.uniform(0, 1) < variation_prob
                if choice:
                    eps = torch.zeros_like(latent).normal_(std=variation).to(latent.device)
                    latent = latent + eps
            loss, z, log_prob, log_jac_det = self.cal_prior(latent, prior_head_index)

            loss_detail = {"loss": loss.detach(), "log_prob": log_prob.detach().mean(), "log_det": log_jac_det.detach().mean()}

            if self.losslist is not None:

                if 'adv_z_prob_loss' in self.losslist:
                    # push the latent away from the other priors of the same aspect group
                    adv_index_list = self.adv_prior_head_dict[prior_head_index]
                    adv_z_loss = 0
                    inds = 0
                    for adv_head_index in adv_index_list:
                        tmp_adv_z_loss, _, adv_log_prob, adv_log_det = self.cal_prior(latent, adv_head_index)
                        adv_z_loss += -1 * adv_log_prob.mean() / (self.latent_num * self.latent_size)
                        inds += 1

                    adv_z_loss = adv_z_loss / inds

                    if adv_z_loss.item() < 10:  # clip: drop the term once it is already large
                        loss = loss - adv_z_loss * self.losslist['adv_z_prob_loss']

                        loss_detail["adv_z_loss"] = adv_z_loss.detach()
                    else:
                        loss_detail["adv_z_loss"] = 10

                if 'adv_x_prob_loss' in self.losslist:
                    if len(adv_input_ids.shape) == 3:
                        adv_input_ids = adv_input_ids.view(adv_input_ids.shape[1], adv_input_ids.shape[2])
                        adv_attention_mask = adv_attention_mask.view(adv_attention_mask.shape[1], adv_attention_mask.shape[2])
                        adv_token_type_ids = adv_token_type_ids.view(adv_token_type_ids.shape[1], adv_token_type_ids.shape[2])

                    adv_latent, _, _ = self.encode(encoder_input_ids=adv_input_ids, encoder_attention_mask=adv_attention_mask, encoder_token_type_ids=adv_token_type_ids)
                    if variation > 0:
                        choice = random.uniform(0, 1) < variation_prob
                        if choice:
                            eps = torch.zeros_like(adv_latent).normal_(std=variation).to(adv_latent.device)
                            adv_latent = adv_latent + eps

                    adv_x_loss, _, _, _ = self.cal_prior(adv_latent, prior_head_index)

                    if adv_x_loss.item() < 10:  # clip, as above
                        loss = loss - adv_x_loss * self.losslist['adv_x_prob_loss']

                        loss_detail["adv_x_loss"] = adv_x_loss.detach()
                    else:
                        loss_detail["adv_x_loss"] = 10

                if 'prior_classify_loss' in self.losslist:
                    assert head_index is not None
                    head_index = head_index.item()

                    assert pos_label is not None
                    pos_label = pos_label.view(pos_label.shape[1])

                    prior_cls_loss = self.prior_classify_loss(z, pos_label, head_index)

                    # weight the classification term like the other auxiliary losses
                    loss = loss + prior_cls_loss * self.losslist['prior_classify_loss']

                    loss_detail["prior_cls_loss"] = prior_cls_loss.detach()

            wandb.log(loss_detail)
            return (loss, z)
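    # In 'prior' mode the total objective above is, schematically:
    #     loss = NLL(latent | prior_k)
    #            - w_z * mean over group NLL(latent | adversarial priors)   # push away from same-aspect priors
    #            - w_x * NLL(adv_latent | prior_k)                          # push adversarial text away from prior_k
    #            + w_c * CE(prior_classify_head(z), pos_label)
    # where each adversarial term is skipped once it exceeds the clip value of 10.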
    def encode(self,
               encoder_input_ids,
               encoder_attention_mask=None,
               encoder_token_type_ids=None,
               ):
        '''
        Encode the input text and get the latent representation.
        '''
        device = next(self.parameters()).device
        encoder_input_ids = encoder_input_ids.to(device)
        if encoder_attention_mask is not None:
            encoder_attention_mask = encoder_attention_mask.to(device)
        if encoder_token_type_ids is not None:
            encoder_token_type_ids = encoder_token_type_ids.to(device)
        # the pooled [CLS] vector serves as the sentence representation
        encoder_output = self.encoder(input_ids=encoder_input_ids, attention_mask=encoder_attention_mask, token_type_ids=encoder_token_type_ids, return_dict=True).pooler_output
        past_key_values, latent = self.connect(encoder_output)
        return latent, encoder_output, past_key_values
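    # Usage sketch (names are illustrative): encode text with the BERT-side tokenizer,
    # then decode it back through the fixed GPT-2 decoder:
    #
    #   enc = encoder_tokenizer(['a movie review'], return_tensors='pt', padding=True)
    #   latent, _, _ = model.encode(enc['input_ids'], enc['attention_mask'], enc['token_type_ids'])
    #   output_ids = model.generate(input_latent=latent, max_len=50)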
    def generate(
            self,
            input_latent,
            input_ids=None,
            attention_mask=None,
            batch_size=None,
            variation=None,
            min_len=30,
            max_len=50,
            do_sample=True,
            topk=5,
            topp=0.9,
            lp=1,
            rp=1.0,
            prior_head_index=None,
            classify_head_index=None,
            threshold_for_cls=None,
            mean=0,
            std=1,
            alpha=None,
            calculate=None,
            use_cache=True):
        '''
        Generate text from a given latent representation.
        '''
        device = next(self.parameters()).device
        input_latent = input_latent.to(device)

        if len(input_latent.shape) == 3:
            tmp_batch_size, latent_num, latent_size = input_latent.shape
            input_latent = input_latent.view(tmp_batch_size, latent_num * latent_size)
        elif len(input_latent.shape) != 2:
            raise Exception('Shape of input_latent is expected to be [batch_size, latent_num, latent_size] or [batch_size, latent_num * latent_size]')

        if batch_size is None:
            batch_size = input_latent.shape[0]
        if input_ids is not None:
            if input_ids.shape[0] > batch_size:
                if batch_size == 1:
                    batch_size = input_ids.shape[0]
                    input_latent = input_latent.expand(batch_size, -1)
                else:
                    raise Exception('Batch size of input_latent and input_ids mismatched')
            elif input_ids.shape[0] < batch_size and input_ids.shape[0] == 1:
                input_ids = input_ids.expand(batch_size, -1)

        if input_latent.shape[0] < batch_size:
            input_latent = input_latent.expand(batch_size, -1)

        if self.mode == 'prior':
            # in prior mode, input_latent only fixes the shape; the actual latent is sampled from the flow prior
            batch_size, latent_size = input_latent.shape
            tmp_input_latent = self.get_prior_sampling(batch_size, latent_size, mean=mean, std=std, prior_head_index=prior_head_index, alpha=alpha, calculate=calculate)

            if threshold_for_cls is not None and self.latent_classify_head is not None and classify_head_index is not None and prior_head_index is not None:
                # rejection-sample latents until `batch_size` of them pass the classifier threshold
                prob = torch.softmax(self.latent_classify_head[classify_head_index[0]](tmp_input_latent), dim=-1)[:, classify_head_index[1]]
                mask = prob > threshold_for_cls
                candidate_latent = tmp_input_latent[mask, :]
                while candidate_latent.shape[0] < batch_size:
                    tmp_input_latent = self.get_prior_sampling(batch_size * 10, latent_size, mean=mean, std=std, prior_head_index=prior_head_index, alpha=alpha, calculate=calculate)
                    prob = torch.softmax(self.latent_classify_head[classify_head_index[0]](tmp_input_latent), dim=-1)[:, classify_head_index[1]]
                    mask = prob > threshold_for_cls
                    candidate_latent = torch.cat([candidate_latent, tmp_input_latent[mask, :]], dim=0)
                    # free the rejected samples before the next round
                    del tmp_input_latent
                    del prob
                    del mask
                    torch.cuda.empty_cache()

                candidate_latent = candidate_latent[:batch_size, :]

                input_latent = candidate_latent

            else:
                input_latent = tmp_input_latent

        if variation is not None:
            eps = torch.zeros_like(input_latent).normal_(std=variation).to(device)
            input_latent = input_latent + eps

        past_key_values = self.trans2(input_latent)

        if input_ids is None:
            # sample a short random opening: one token after the GPT-2 BOS token (id 50256)
            input_ids = self.decoder.generate(input_ids=torch.LongTensor([[50256]] * batch_size).to(device), max_length=3, do_sample=True)[:, 1:]
            attention_mask = torch.ones(batch_size, 2).bool()
        else:
            input_ids = input_ids.to(device)
            if attention_mask is None:
                attention_mask = torch.ones(batch_size, input_ids.shape[1]).bool()

        cur_len = input_ids.shape[1]
        infix_attn = torch.ones(batch_size, self.seq_len).bool().to(device)
        attention_mask = torch.cat([infix_attn, attention_mask.to(device)], dim=-1)

        if cur_len < 1:
            raise Exception('input length error')
        if cur_len == 1:
            result = self.decoder.generate(input_ids=input_ids, past=past_key_values, attention_mask=attention_mask, repetition_penalty=rp,\
                do_sample=do_sample, top_k=topk, top_p=topp, length_penalty=lp, max_length=max_len, min_length=min_len, use_cache=use_cache)
        else:
            # pre-compute the decoder cache for all but the last given token
            past_key_values = self.decoder(input_ids=input_ids[:, :-1], attention_mask=attention_mask[:, :-1], past_key_values=past_key_values, return_dict=True, use_cache=True).past_key_values
            result = self.decoder.generate(input_ids=input_ids, past=past_key_values, attention_mask=attention_mask, repetition_penalty=rp,\
                do_sample=do_sample, top_k=topk, top_p=topp, length_penalty=lp, max_length=max_len, min_length=min_len, use_cache=use_cache)

        return result
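    # get_prior_sampling (below) supports three cases for `prior_head_index`:
    #   int  -> sample from one learnable prior, e.g. model.generate(latent, prior_head_index=0)
    #   list -> sample from an interpolation of priors, e.g.
    #           model.generate(latent, prior_head_index=[0, 2], alpha=[0.5, 0.5])
    #   None -> sample from the standard normal base distribution
    # A custom `calculate` callable can replace the default weighted combination.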
    def get_prior_sampling(self,
                           batch_size,
                           latent_size,
                           mean=0,
                           std=1,
                           prior_head_index=None,
                           alpha=None,
                           calculate=None
                           ):
        device = next(self.parameters()).device
        eps = torch.normal(mean, std, (batch_size, latent_size)).to(device)
        if type(prior_head_index) is int:
            # sample from a single learnable prior, then invert the flow back to latent space
            learnable_prior = self.priors[prior_head_index]
            sampled_dis = self.gaussian_sample(eps, learnable_prior[0], learnable_prior[1])
            input_latent, _ = self.inv_flow(sampled_dis, rev=True)
        elif type(prior_head_index) is list:
            learnable_prior_mean = [self.priors[p_index][0] for p_index in prior_head_index]
            learnable_prior_std = [torch.exp(self.priors[p_index][1]) for p_index in prior_head_index]
            if alpha is None:
                alpha = ([1 / len(prior_head_index)] * len(prior_head_index))
            else:
                assert len(alpha) == len(prior_head_index)
            if calculate is not None:
                sampled_dis = calculate(alpha, learnable_prior_mean, learnable_prior_std, eps)
            else:
                def weighted_combine(alpha, learnable_prior_mean, learnable_prior_std, eps):
                    # interpolated Gaussian: mu = sum_i w_i * mu_i, sigma = sqrt(sum_i (w_i * sigma_i)^2)
                    mu = 0
                    for w, mean in zip(alpha, learnable_prior_mean):
                        mu = mu + w * mean

                    sigma = 0
                    for w, std in zip(alpha, learnable_prior_std):
                        sigma = sigma + (w * std) ** 2
                    sigma = torch.sqrt(sigma)

                    sampled_dis = sigma * eps + mu
                    return sampled_dis
                sampled_dis = weighted_combine(alpha, learnable_prior_mean, learnable_prior_std, eps)
            input_latent, _ = self.inv_flow(sampled_dis, rev=True)
        elif prior_head_index is None:
            # no prior selected: invert the flow on standard-normal noise
            input_latent, _ = self.inv_flow(eps, rev=True)
        else:
            raise TypeError('`prior_head_index` is expected to be an int, a list of ints, or None')

        return input_latent



    def reconstruct(self,
                    encoder_input_ids,
                    decoder_input_ids=None,
                    encoder_attention_mask=None,
                    encoder_token_type_ids=None,
                    decoder_attention_mask=None,
                    do_sample=True,
                    max_len=50,
                    min_len=30,
                    topk=5,
                    topp=0.9,
                    lp=1.0,
                    use_cache=True):
        '''
        Reconstruct the input text.
        '''
        device = next(self.parameters()).device
        batch_size = encoder_input_ids.shape[0]
        encoder_input_ids = encoder_input_ids.to(device)
        if encoder_attention_mask is not None:
            encoder_attention_mask = encoder_attention_mask.to(device)
        if encoder_token_type_ids is not None:
            encoder_token_type_ids = encoder_token_type_ids.to(device)
        encoder_output = self.encoder(input_ids=encoder_input_ids, attention_mask=encoder_attention_mask, token_type_ids=encoder_token_type_ids, return_dict=True).pooler_output

        past_key_values, latent = self.connect(encoder_output)
        if decoder_input_ids is None:
            decoder_input_ids = self.decoder.generate(input_ids=torch.LongTensor([[50256]] * batch_size).to(device), max_length=3, do_sample=True)[:, 1:]
            decoder_attention_mask = torch.ones(batch_size, 2).bool()
        else:
            decoder_input_ids = decoder_input_ids.to(device)
            if decoder_attention_mask is None:
                decoder_attention_mask = torch.ones(batch_size, decoder_input_ids.shape[1]).bool()

        cur_len = decoder_input_ids.shape[1]

        infix_attn = torch.ones(batch_size, self.seq_len).bool().to(device)
        decoder_attention_mask = torch.cat([infix_attn, decoder_attention_mask.to(device)], dim=-1)

        if cur_len < 1:
            raise Exception('input length error')
        if cur_len == 1:
            result = self.decoder.generate(input_ids=decoder_input_ids, past=past_key_values, attention_mask=decoder_attention_mask,\
                do_sample=do_sample, top_k=topk, top_p=topp, length_penalty=lp, max_length=max_len, min_length=min_len, use_cache=use_cache)
        else:
            past_key_values = self.decoder(input_ids=decoder_input_ids[:, :-1], attention_mask=decoder_attention_mask[:, :-1], past_key_values=past_key_values, return_dict=True, use_cache=True).past_key_values
            result = self.decoder.generate(input_ids=decoder_input_ids, past=past_key_values, attention_mask=decoder_attention_mask,\
                do_sample=do_sample, top_k=topk, top_p=topp, length_penalty=lp, max_length=max_len, min_length=min_len, use_cache=use_cache)
        return result
--------------------------------------------------------------------------------
/priorcontrol/requirements.txt:
--------------------------------------------------------------------------------
# python 3.7.12 (interpreter version; install separately, not via pip)
apex==0.1
datasets==2.0.0
huggingface-hub==0.4.0
nltk==3.7
numpy==1.21.5
scikit-learn==1.0.2
scipy==1.7.3
sentencepiece==0.1.96
tokenizers==0.11.6
torch==1.10.0
tqdm==4.63.1
transformers==4.17.0
wandb==0.12.11
FrEIA==0.2
torchdiffeq==0.2.3
--------------------------------------------------------------------------------
/priorcontrol/single_eval.sh:
--------------------------------------------------------------------------------
if [ $# -eq 2 ]; then
    file_loc=$1
    output_loc=$2
else
    file_loc=../res/priorcontrol/generate_prior.txt
    output_loc=./logs/generate_prior.txt
    #file_loc=../res/priorcontrol/generate_prior_extend.txt
    #output_loc=./logs/generate_prior_extend.txt
fi


# evaluate each attribute value separately: sentiment (0-1), topic (2-5), toxicity (6-7)
echo 0 >> $output_loc
python ../classify/eval.py --file_loc $file_loc --specify "[0,-1,-1]" >> $output_loc

echo 1 >> $output_loc
python ../classify/eval.py --file_loc $file_loc --specify "[1,-1,-1]" >> $output_loc

echo 2 >> $output_loc
python ../classify/eval.py --file_loc $file_loc --specify "[-1,0,-1]" >> $output_loc

echo 3 >> $output_loc
python ../classify/eval.py --file_loc $file_loc --specify "[-1,1,-1]" >> $output_loc

echo 4 >> $output_loc
python ../classify/eval.py --file_loc $file_loc --specify "[-1,2,-1]" >> $output_loc

echo 5 >> $output_loc
python ../classify/eval.py --file_loc $file_loc --specify "[-1,3,-1]" >> $output_loc

echo 6 >> $output_loc
python ../classify/eval.py --file_loc $file_loc --specify "[-1,-1,0]" >> $output_loc

echo 7 >> $output_loc
python ../classify/eval.py --file_loc $file_loc --specify "[-1,-1,1]" >> $output_loc
--------------------------------------------------------------------------------
/priorcontrol/train_prior_only.py:
--------------------------------------------------------------------------------
import argparse
import torch
from transformers import GPT2LMHeadModel, BertModel, GPT2Tokenizer, BertTokenizer
from transformers import Trainer, TrainingArguments
from datasets import Dataset
import json
import wandb

import random

from model import AE


parser = argparse.ArgumentParser()
parser.add_argument("--pretrained_encoder", type=str, default="bert-base-uncased")
parser.add_argument("--pretrained_decoder", type=str, default="gpt2-medium")
parser.add_argument('--model_dir', default='../model/priorcontrol/')
parser.add_argument("--no_cuda", action="store_true")
parser.add_argument("--latent_size", type=int, default=768)
parser.add_argument("--latent_num", type=int, default=1)
parser.add_argument("--seq_len_per_latent", type=int, default=20)
parser.add_argument("--batch_size", type=int, default=100)
parser.add_argument("--epoch", type=int, default=200)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--fp16", action="store_true")
parser.add_argument("--wandb", action="store_true")
parser.add_argument("--no_fix", action="store_true")
parser.add_argument("--max_length", type=int, default=100)
parser.add_argument("--model_path", type=str, default='../model/multicontrol/checkpoint-30000/pytorch_model.bin')
parser.add_argument("--variation", type=float, default=1e-3)

# prior
parser.add_argument("--not_prior", action="store_true")
parser.add_argument("--flow_num", type=int, default=8)
parser.add_argument("--prior_num", type=int, default=8)

# adv_z_prior
parser.add_argument("--adv_z_prob_loss", type=float, default=None)
parser.add_argument("--adv_z_prob_grouping", default=json.dumps([[0, 1], [2, 3, 4, 5], [6, 7]]))

# adv_x_prior
parser.add_argument("--adv_x_prob_loss", type=float, default=None)

# prior_classify_loss
parser.add_argument("--prior_classify_loss", type=float, default=0.3)
parser.add_argument("--prior_classifier_head_num", type=int, default=3)
parser.add_argument("--prior_classifier_class_num_per_head", type=str, default=json.dumps([2, 2, 4]))
parser.add_argument("--prior_classifier_mid_size", type=int, default=128)

args = parser.parse_args()
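# Eight flows/priors (--flow_num / --prior_num) cover the eight attribute values used
# below: IMDb sentiment -> priors 0-1, AGnews topics -> priors 2-5, Toxic comments ->
# priors 6-7; the default --adv_z_prob_grouping [[0,1],[2,3,4,5],[6,7]] groups them by aspect.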
args.prior = not args.not_prior

if args.wandb:
    wandb.login()
    wandb.init(project="", entity="")  # fill in your wandb project and entity


adv_z_prob_args = None
prior_classify_args = None

loss_list = {}
if args.adv_z_prob_loss is not None:
    loss_list['adv_z_prob_loss'] = args.adv_z_prob_loss
    adv_z_prob_args = {
        "grouping": json.loads(args.adv_z_prob_grouping)
    }

if args.adv_x_prob_loss is not None:
    loss_list['adv_x_prob_loss'] = args.adv_x_prob_loss

if args.prior_classify_loss is not None:
    loss_list['prior_classify_loss'] = args.prior_classify_loss
    prior_classify_args = {
        'head_num': args.prior_classifier_head_num,
        'class_num_per_head': json.loads(args.prior_classifier_class_num_per_head),
        'mid_size': args.prior_classifier_mid_size
    }


encoder_tokenizer = BertTokenizer.from_pretrained(args.pretrained_encoder)
encoder = BertModel.from_pretrained(args.pretrained_encoder)
decoder_tokenizer = GPT2Tokenizer.from_pretrained(args.pretrained_decoder)
decoder = GPT2LMHeadModel.from_pretrained(args.pretrained_decoder)
decoder_tokenizer.pad_token = decoder_tokenizer.eos_token


# start from the trained MultiControl AE and train only the flow on top of it
model = AE(encoder=encoder, decoder=decoder, args=args)
model.load_state_dict(torch.load(args.model_path), strict=False)

model.set_losslist(loss_list, adv_z_prob_args=adv_z_prob_args, prior_classify_args=prior_classify_args)

if not args.no_fix:
    model.fix_decoder()

model.set_mode('prior')


dataset = [{'sent': []} for i in range(8)]

with open('../data/IMDb/IMDb.txt', 'r') as f:
    for line in f.readlines():
        line = json.loads(line)
        label = int(line[0])
        dataset[0 + label]['sent'].append(line[1].strip())  # sentiment -> priors 0-1

with open('../data/AGnews/AG-data.txt', 'r') as f:
    for line in f.readlines():
        line = json.loads(line)
        label = int(line[0])
        dataset[2 + label]['sent'].append(line[1].strip())  # topic -> priors 2-5

with open('../data/ToxicComment/Toxic.txt', 'r') as f:
    for line in f.readlines():
        line = json.loads(line)
        label = int(line[0])
        dataset[6 + label]['sent'].append(line[1].strip())  # toxicity -> priors 6-7


columns = ['encoder_input_ids', 'encoder_attention_mask', 'encoder_token_type_ids']
adv_columns = ['adv_input_ids', 'adv_attention_mask', 'adv_token_type_ids']

if args.adv_x_prob_loss is not None:
    train_dataset = {i: [] for i in (columns + adv_columns)}
else:
    train_dataset = {i: [] for i in columns}

if args.prior_classify_loss is not None:
    train_dataset['head_index'] = []
    train_dataset['pos_label'] = []
train_dataset['prior_head_index'] = []

if args.adv_x_prob_loss is not None:
    ##### adversarial data: drawn from three of the four topic classes #####
    adv_sent = dataset[3]['sent'] + dataset[4]['sent'] + dataset[5]['sent']
    random.shuffle(adv_sent)
    adv_dataset = {'sent': adv_sent}

    tmp_dataset = Dataset.from_dict(adv_dataset)
    tmp_dataset = tmp_dataset.map(lambda e: encoder_tokenizer(e['sent'], max_length=args.max_length, padding='max_length', truncation=True), batched=True)
    tmp_dataset = tmp_dataset.rename_columns({'input_ids': 'adv_input_ids', 'attention_mask': 'adv_attention_mask', 'token_type_ids': 'adv_token_type_ids'})
    tmp_dataset.set_format(type='torch', columns=['adv_input_ids', 'adv_token_type_ids', 'adv_attention_mask'])
    adv_dataloader = torch.utils.data.DataLoader(tmp_dataset, batch_size=(args.batch_size * 3))


if args.prior_classify_loss is not None:
    label_dict = [[0, 0], [0, 1], [2, 0], [2, 1], [2, 2], [2, 3], [1, 0], [1, 1]]
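# label_dict maps each prior index to (classification head, class within that head):
# priors 0-1 -> sentiment head 0 (2 classes), priors 2-5 -> topic head 2 (4 classes),
# priors 6-7 -> toxicity head 1 (2 classes), matching --prior_classifier_class_num_per_head [2,2,4].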
for i in range(8):
    tmp_dataset = Dataset.from_dict(dataset[i])
    tmp_dataset = tmp_dataset.map(lambda e: encoder_tokenizer(e['sent'], max_length=args.max_length, padding='max_length', truncation=True), batched=True)
    tmp_dataset = tmp_dataset.rename_columns({'input_ids': 'encoder_input_ids', 'attention_mask': 'encoder_attention_mask', 'token_type_ids': 'encoder_token_type_ids'})
    tmp_dataset.set_format(type='torch', columns=['encoder_input_ids', 'encoder_token_type_ids', 'encoder_attention_mask'])
    tmp_dataloader = torch.utils.data.DataLoader(tmp_dataset, batch_size=args.batch_size)
    for cnt in iter(tmp_dataloader):
        for k in columns:
            train_dataset[k].append(cnt[k].tolist())
        train_dataset['prior_head_index'].append(i)
        if args.prior_classify_loss is not None:
            train_dataset['head_index'].append(label_dict[i][0])
            train_dataset['pos_label'].append([label_dict[i][1]] * args.batch_size)

if args.adv_x_prob_loss is not None:
    for adv_cnt in iter(adv_dataloader):
        for k in adv_columns:
            train_dataset[k].append(adv_cnt[k].tolist())


train_dataset = Dataset.from_dict(train_dataset)

ult_columns = columns + ['prior_head_index']
if args.adv_x_prob_loss is not None:
    ult_columns = ult_columns + adv_columns
if args.prior_classify_loss is not None:
    ult_columns = ult_columns + ['head_index', 'pos_label']

train_dataset.set_format(columns=ult_columns)


training_args = TrainingArguments(
    output_dir=args.model_dir,
    learning_rate=args.lr,
    num_train_epochs=args.epoch,
    #gradient_accumulation_steps=4,
    per_device_train_batch_size=1,  # each "example" is already a full batch of args.batch_size sentences
    logging_dir='./logs',
    logging_steps=100,
    do_train=True,
    do_eval=False,
    no_cuda=args.no_cuda,
    save_strategy="steps",
    save_steps=5000,
    fp16=args.fp16,
    report_to='wandb' if args.wandb else 'none'
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset
)
train_out = trainer.train()
--------------------------------------------------------------------------------
/priorcontrol/train_prior_only.sh:
--------------------------------------------------------------------------------
python train_prior_only.py --fp16 --wandb --lr 1e-4 --epoch 1000
--------------------------------------------------------------------------------