├── requirements.txt ├── __pycache__ ├── arguments.cpython-38.pyc ├── data_pool.cpython-38.pyc └── token_gpt2.cpython-38.pyc ├── utils ├── __pycache__ │ ├── utils.cpython-38.pyc │ ├── constants.cpython-38.pyc │ └── perspective_api.cpython-38.pyc ├── constants.py ├── tbl_sentiment_main.tex ├── utils.py └── perspective_api.py ├── Sentiment ├── __pycache__ │ ├── data.cpython-38.pyc │ ├── models.cpython-38.pyc │ ├── main_disc.cpython-38.pyc │ ├── discriminator.cpython-38.pyc │ └── prompt_encoder.cpython-38.pyc ├── models.py ├── prompt_encoder.py ├── main_disc.py ├── data.py └── discriminator.py ├── README.md ├── data_pool.py ├── arguments.py ├── token_gpt2.py └── token_main.py /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.1.2 2 | torchaudio==2.0.2+cu118 3 | torchvision==0.15.2+cu118 4 | transformers==4.39.3 5 | -------------------------------------------------------------------------------- /__pycache__/arguments.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WindyLee0822/CTG/HEAD/__pycache__/arguments.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/data_pool.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WindyLee0822/CTG/HEAD/__pycache__/data_pool.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/token_gpt2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WindyLee0822/CTG/HEAD/__pycache__/token_gpt2.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WindyLee0822/CTG/HEAD/utils/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /Sentiment/__pycache__/data.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WindyLee0822/CTG/HEAD/Sentiment/__pycache__/data.cpython-38.pyc -------------------------------------------------------------------------------- /Sentiment/__pycache__/models.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WindyLee0822/CTG/HEAD/Sentiment/__pycache__/models.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/constants.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WindyLee0822/CTG/HEAD/utils/__pycache__/constants.cpython-38.pyc -------------------------------------------------------------------------------- /Sentiment/__pycache__/main_disc.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WindyLee0822/CTG/HEAD/Sentiment/__pycache__/main_disc.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/perspective_api.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/WindyLee0822/CTG/HEAD/utils/__pycache__/perspective_api.cpython-38.pyc -------------------------------------------------------------------------------- /Sentiment/__pycache__/discriminator.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WindyLee0822/CTG/HEAD/Sentiment/__pycache__/discriminator.cpython-38.pyc -------------------------------------------------------------------------------- /Sentiment/__pycache__/prompt_encoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WindyLee0822/CTG/HEAD/Sentiment/__pycache__/prompt_encoder.cpython-38.pyc -------------------------------------------------------------------------------- /utils/constants.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import yaml 3 | 4 | #AIzaSyDDtAyo35Gh1CR6Hc9EgevI-dhfR_T-Ljo 5 | NEGATIVE_INF = -100000.0 6 | 7 | PERSPECTIVE_API_KEY = 'AIzaSyDDtAyo35Gh1CR6Hc9EgevI-dhfR_T-Ljo' 8 | 9 | PERSPECTIVE_API_ATTRIBUTES = { 10 | 'TOXICITY' 11 | } 12 | 13 | PERSPECTIVE_API_ATTRIBUTES_LOWER = tuple(a.lower() for a in PERSPECTIVE_API_ATTRIBUTES) 14 | -------------------------------------------------------------------------------- /Sentiment/models.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | from transformers import ( 3 | CONFIG_MAPPING, 4 | MODEL_FOR_CAUSAL_LM_MAPPING, 5 | AutoConfig, 6 | AutoModelForCausalLM, 7 | AutoTokenizer, 8 | HfArgumentParser, 9 | Trainer, 10 | TrainingArguments, 11 | default_data_collator, 12 | set_seed, 13 | BertTokenizer, 14 | GPT2Tokenizer) 15 | 16 | from transformers import GPT2LMHeadModel, AutoTokenizer, AutoModelForMaskedLM 17 | 18 | def create_model(args): 19 | 20 | if args.model_name_or_path: 21 | 22 | config = AutoConfig.from_pretrained(args.model_name_or_path) 23 | model = GPT2LMHeadModel.from_pretrained( 24 | args.model_name_or_path, 25 | from_tf=bool(".ckpt" in args.model_name_or_path), 26 | config=config 27 | ) 28 | else: 29 | print("Model path is not set!!!") 30 | 31 | return model 32 | 33 | 34 | 35 | def _create_model(model_path): 36 | if model_path: 37 | model = GPT2LMHeadModel.from_pretrained(model_path) 38 | else: 39 | print("Model path is not set!!!") 40 | 41 | return model 42 | 43 | 44 | def get_embedding_layer(args, model): 45 | 46 | embeddings = model.base_model.get_input_embeddings() 47 | 48 | return embeddings -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TOLE: Reinforcement Learning with Token-level Feedback for Controllable Text Generation 2 | 3 | Source code of “Reinforcement Learning with Token-level Feedback for Controllable Text Generation (NAACL 2024)” 4 | 5 | The code for the single-attribute control (sentiment transformation) experiments is on the `main` branch. 6 | 7 | The code for the multi-attribute control experiments is on the `multi-attribute` branch. 8 | 9 | The code is currently a little messy; I will try to tidy it up as soon as possible. 10 | 11 | If you encounter problems, feel free to contact me (wendili@hust.edu.cn). 12 | 13 | ## Single-attribute Control 14 | 15 | - Train an attribute classifier. 16 | 17 | For sentiment transformation, we retrain an attribute classifier on SST-5 using `Sentiment/main_disc.py`.
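A more explicit invocation is sketched below; this is only an illustration: the flags are the ones defined in `construct_generation_args` inside `Sentiment/main_disc.py`, and both paths are placeholders for your local GPT-2 checkpoint and SST-derived `pos_neg` data.

```bash
# Sketch: train the sentiment discriminator that later serves as the token-level reward model.
# <gpt2-base-path> and <pos_neg-data-dir> are placeholders, not paths shipped with the repo.
python Sentiment/main_disc.py \
    --model_name_or_path <gpt2-base-path> \
    --data_path <pos_neg-data-dir> \
    --task_name sentiment \
    --tuning_name disc_tuning \
    --corpus_type positive \
    --epoch 50 \
    --batch_size 160
```

Checkpoints appear to be written to `Sentiment/checkpoint/train-in-fudge-way/`; the saved `.ckpt` is what the `--reward_model` argument expects in the RL step below. In its minimal form (relying on the parser defaults), the command is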
18 | 19 | ```python Sentiment/main_disc.py``` 20 | 21 | - Run Token-level RL 22 | 23 | To train a policy model, run 24 | 25 | ```python token_main.py --source_mode neutral --target_mode positive --reward_model {best checkpoint of your classifier} ``` 26 | 27 | ## Citation 28 | If you find our research helpful, please kindly cite our paper! 29 | 30 | ```bibtex 31 | @article{li2024reinforcement, 32 | title={Reinforcement Learning with Token-level Feedback for Controllable Text Generation}, 33 | author={Li, Wendi and Wei, Wei and Xu, Kaihe and Xie, Wenfeng and Chen, Dangyang and Cheng, Yu}, 34 | journal={arXiv preprint arXiv:2403.11558}, 35 | year={2024} 36 | } 37 | ``` 38 | 39 | -------------------------------------------------------------------------------- /Sentiment/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class PromptEncoder(torch.nn.Module): 6 | def __init__(self, template, hidden_size, tokenizer, args): 7 | super().__init__() 8 | self.spell_length = sum(template) 9 | self.hidden_size = hidden_size 10 | self.tokenizer = tokenizer 11 | self.args = args 12 | # ent embedding 13 | self.cloze_length = template 14 | self.cloze_mask = [ 15 | [1] * self.cloze_length[0] # first cloze 16 | + [1] * self.cloze_length[1] # second cloze 17 | ] 18 | self.cloze_mask = torch.LongTensor(self.cloze_mask).bool().to(args.device) 19 | 20 | self.seq_indices = torch.LongTensor(list(range(len(self.cloze_mask[0])))).to(args.device) 21 | # embedding 22 | self.embedding = torch.nn.Embedding(len(self.cloze_mask[0]), self.hidden_size).to(args.device) 23 | # LSTM 24 | self.lstm_head = torch.nn.LSTM(input_size=self.hidden_size, 25 | hidden_size=self.hidden_size // 2, 26 | num_layers=2, 27 | dropout=self.args.lstm_dropout, 28 | bidirectional=True, 29 | batch_first=True) 30 | self.mlp_head = nn.Sequential(nn.Linear(self.hidden_size, self.hidden_size), 31 | nn.ReLU(), 32 | nn.Linear(self.hidden_size, self.hidden_size)) 33 | print("init prompt encoder...") 34 | 35 | def forward(self): 36 | input_embeds = self.embedding(self.seq_indices.long()).unsqueeze(0) 37 | output_embeds = self.mlp_head(self.lstm_head(input_embeds)[0]).squeeze() 38 | return output_embeds -------------------------------------------------------------------------------- /utils/tbl_sentiment_main.tex: -------------------------------------------------------------------------------- 1 | 2 | \begin{table*}[t!] 
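% note: \dexpert and \methodnameshort are custom macros used in this table; they are presumably defined in the enclosing paper's preamble rather than in this standalone file.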
3 | \centering\footnotesize 4 | \scalebox{.832}{ 5 | \begin{tabular}{l|cc|c|cc|cc|c|cc} 6 | \toprule 7 | \hspace{5.5mm}\multirow{4}{*}{\textbf{Model}} & \multicolumn{5}{c|}{{\cellcolor[gray]{.95}} \textbf{Sentiment to Unlearn:} \textsc{Negative} } & \multicolumn{5}{c}{{\cellcolor[gray]{.95}}\textbf{Sentiment to Unlearn:} \textsc{Positive} } \\ \cmidrule{2-11} 8 | &\multicolumn{2}{c|}{\textbf{\% Positive} ($\uparrow$)} & \textbf{Fluency} ($\downarrow$) & \multicolumn{2}{c|}{\textbf{Diversity} ($\uparrow$)} &\multicolumn{2}{c|}{\textbf{\% Positive} ($\downarrow$)} & \textbf{Fluency} ($\downarrow$) & \multicolumn{2}{c}{\textbf{Diversity} ($\uparrow$)}\\ 9 | & negative & neutral & \multirow{2}{*}{output ppl} & \multirow{2}{*}{dist-2} & \multirow{2}{*}{dist-3} & positive & neutral & \multirow{2}{*}{output ppl} & \multirow{2}{*}{dist-2} & \multirow{2}{*}{dist-3} \\ 10 | & prompt & prompt & & & & prompt & prompt & & & \\\midrule 11 | GPT2 \cite{radford2019language} & \phantom{0}0.00 & 50.02 & 11.42 & 0.85 & 0.85 & 99.08 & 50.02 & 11.42 & 0.84 & 0.84 \\ 12 | \midrule 13 | PPLM \cite{Dathathri2020PPLM} & \phantom{0}8.72 & 52.68 & 142.1 & 0.86 & 0.85 & 89.74 & 39.05 & 181.7 & 0.87 & 0.86 \\ 14 | CTRL \cite{CTRL2019} & 18.88 & 61.81 & 43.79 & 0.83 & 0.86 & 79.05 & 37.63 & 35.94 & 0.83 & 0.86 \\ 15 | GeDi \cite{krause-etal-2021-gedi-generative} & 26.80 & 86.01 & 58.41 & 0.80 & 0.79 & 39.57 & \phantom{0}8.73 & 84.11 & 0.84 & 0.82 \\ 16 | \dexpert \cite{liu-etal-2021-dexperts} & 36.42 & 94.46 & 25.83 & 0.84 & 0.84 & 35.99 & \phantom{0}3.77 & 45.91 & 0.84 & 0.83 \\ 17 | DAPT \cite{gururangan-etal-2020-dapt} & 14.17 & 77.24 & 30.52 & 0.83 & 0.84 & 87.43 & 33.28 & 32.86 & 0.85 & 0.84 \\ 18 | PPO \cite{NEURIPS2020_1f89885d} & 43.13 & 94.10 & 15.16 & 0.80 & 0.84 & 32.22 & 3.65 & 15.54 & 0.81 & 0.84\\ 19 | \midrule 20 | \methodnameshort & \textbf{46.55} & \textbf{95.00} & \textbf{14.54} & 0.80 & 0.84 & \textbf{27.50} & \textbf{2.75} & \textbf{14.72} & 0.80 & 0.84\\ 21 | \bottomrule 22 | \end{tabular}} 23 | \caption{Automatic evaluation results of unlearning sentiment experiments. 
Baseline results (except PPO) are from \cite{liu-etal-2021-dexperts}.} 24 | \label{tab:sentiment_results} 25 | \end{table*} 26 | 27 | 28 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import TypeVar, Iterable, List, Union, Any 4 | import numpy as np 5 | import torch 6 | from tqdm.auto import tqdm 7 | import os 8 | import collections 9 | from utils.constants import NEGATIVE_INF 10 | 11 | T = TypeVar('T') 12 | 13 | 14 | def reduce_sum(value, mask, axis=None): 15 | if axis is None: 16 | return torch.sum(value * mask) 17 | return torch.sum(value * mask, axis) 18 | 19 | 20 | def reduce_mean(value, mask, axis=None): 21 | if axis is None: 22 | return (torch.sum(value * mask)+1e-8)/ (torch.sum(mask)+1e-8) 23 | return (reduce_sum(value, mask, axis)+1e-8) / (torch.sum(mask, axis)+1e-8) 24 | 25 | 26 | def reduce_std(value, mask): 27 | return torch.sqrt(reduce_mean(torch.square(value), mask) - torch.square(reduce_mean(value, mask))) 28 | 29 | 30 | def logits_to_entropy(logits): 31 | distribution = torch.distributions.Categorical(logits=logits) 32 | return distribution.entropy() 33 | 34 | 35 | def mask_pad(value, mask): 36 | return value * mask + NEGATIVE_INF * (1 - mask) 37 | 38 | 39 | def clamp(value, min_value, max_value): 40 | return torch.max(torch.min(value, max_value), min_value) 41 | 42 | 43 | def ceil_div(a, b): 44 | return (a - 1) // b + 1 45 | 46 | 47 | def exact_div(a, b): 48 | q = a // b 49 | if a != q * b: 50 | raise ValueError('Inexact division: %s / %s = %s' % (a, b, a / b)) 51 | return q 52 | 53 | 54 | def whiten(values, masks, shift_mean=True): 55 | mean, var = reduce_mean(values, masks), reduce_std(values, masks) 56 | whitened = (values - mean) * torch.rsqrt(var + 1e-8) 57 | if not shift_mean: 58 | whitened += mean 59 | return whitened 60 | 61 | 62 | def flatten_dict(nested, sep='.'): 63 | def rec(nest, prefix, into): 64 | for k, v in nest.items(): 65 | if sep in k: 66 | raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'") 67 | if isinstance(v, collections.Mapping): 68 | rec(v, prefix + k + sep, into) 69 | else: 70 | into[prefix + k] = v 71 | flat = {} 72 | rec(nested, '', flat) 73 | return flat 74 | 75 | 76 | def distinctness(generations): 77 | unigrams, bigrams, trigrams, forgrams = set(), set(), set(), set() 78 | total_words = 0 79 | for gen in generations: 80 | o = gen.split(' ') 81 | total_words += len(o) 82 | unigrams.update(o) 83 | for i in range(len(o) - 1): 84 | bigrams.add(o[i] + '_' + o[i + 1]) 85 | for i in range(len(o) - 2): 86 | trigrams.add(o[i] + '_' + o[i + 1] + '_' + o[i + 2]) 87 | for i in range(len(o) - 3): 88 | forgrams.add(o[i] + '_' + o[i + 1] + '_' + o[i + 2] + '_' + o[i + 3]) 89 | 90 | return len(unigrams) / total_words, len(bigrams) / total_words, len(trigrams) / total_words, len(forgrams) / total_words 91 | 92 | 93 | def ensure_dir(d): 94 | if not os.path.exists(d): 95 | os.makedirs(d) 96 | 97 | 98 | def batchify(data: Iterable[T], batch_size: int) -> Iterable[List[T]]: 99 | assert batch_size > 0 100 | 101 | batch = [] 102 | for item in data: 103 | # Yield next batch 104 | if len(batch) == batch_size: 105 | yield batch 106 | batch = [] 107 | 108 | batch.append(item) 109 | 110 | # Yield last un-filled batch 111 | if len(batch) != 0: 112 | yield batch 113 | 114 | 115 | def set_seed(seed, n_gpu): 116 | np.random.seed(seed) 117 | torch.manual_seed(seed) 
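# note: unlike seed_everything() in Sentiment/main_disc.py, this helper does not seed Python's built-in random module or set the cudnn determinism flags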
118 | if n_gpu > 0: 119 | torch.cuda.manual_seed_all(seed) 120 | 121 | 122 | def load_jsonl(file: Union[str, Path]) -> Iterable[Any]: 123 | with open(file) as f: 124 | for line in f: 125 | yield json.loads(line) 126 | 127 | 128 | def load_cache(file: Path): 129 | if file.exists(): 130 | with file.open() as f: 131 | for line in tqdm(f, desc=f'Loading cache from {file}'): 132 | yield json.loads(line) 133 | 134 | 135 | -------------------------------------------------------------------------------- /data_pool.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from copy import deepcopy 3 | import torch 4 | import random 5 | 6 | class DataPool: 7 | def __init__(self): 8 | # self.tree_tokens = tree_tokens 9 | # self.n_extra_tokens = n_extra_tokens 10 | self.cat_mask = [] 11 | self.ids_pool, self.mask_pool, self.score_pool, self.life = [], [], [], [] 12 | self.sort_score=[] 13 | self.r_limit =0 14 | 15 | def add(self, input_ids: List[int], output_mask: List[int], scores: List[List],pos=True): 16 | 17 | ids_pool, mask_pool, score_pool, life = [], [], [], [] 18 | for a,b,c,d in zip(self.ids_pool,self.mask_pool,self.score_pool,self.life): 19 | d-=1 20 | if d>0: 21 | ids_pool.append(a) 22 | mask_pool.append(b) 23 | score_pool.append(c) 24 | life.append(d) 25 | 26 | self.ids_pool,self.mask_pool,self.score_pool, self.life = ids_pool,mask_pool,score_pool,life 27 | 28 | self.ids_pool.extend(input_ids) 29 | self.mask_pool.extend(output_mask) 30 | self.score_pool.extend(scores) 31 | self.life.extend([8] * len(input_ids)) 32 | 33 | sort_score=[] 34 | for score,mask in zip(self.score_pool,self.mask_pool): 35 | assert len(score)+1 == len(mask) 36 | sort_score.extend([s for s,m in zip(score,mask[1:]) if m!=0]) 37 | # sorted_score = [y for x in self.score_pool for y in x] 38 | # score_tensor = torch.cat([torch.zeros(len(self.ids_pool),1),torch.tensor(self.score_pool)],dim=-1) 39 | # sort_score = (score_tensor * torch.tensor(mask_pool,dtype=torch.bool)).view(-1).tolist() 40 | sorted_score = sorted(sort_score,reverse=True) 41 | 42 | # r_limit = sorted_score[len(sorted_score)//5] # for translation,sentiment task,k=5 43 | # self.r_limit = r_limit 44 | # p_limit = sorted_score[len(sorted_score)//5*4] 45 | # 46 | # # logging.info(f'score_distribution:{[sorted_score[0],r_limit,p_limit,sort_score[-1]]}') 47 | # r_limit = max(r_limit,0) 48 | # p_limit = min(p_limit,0) 49 | self.r_limit=[] 50 | self.r_score=[] 51 | quantile_num=5 52 | for i in range(quantile_num-1): 53 | self.r_limit.append(sorted_score[len(sorted_score)//quantile_num*(i+1)]) 54 | self.r_limit.append(sorted_score[-1]) 55 | self.r_limit.insert(0,sorted_score[0]) 56 | self.r_interval = [self.r_limit[i-1]-self.r_limit[i] for i in range(1,len(self.r_limit))] 57 | self.r_limit = self.r_limit[1:] 58 | # ave = sum(sorted_score[len(sorted_score) // quantile_num * (quantile_num -1):])/(len(sorted_score)//quantile_num) 59 | # self.r_score.append(ave) 60 | # self.r_score = [i+abs(ave)/2 for i in self.r_score] 61 | # maltitude = 1/max(self.r_score) 62 | # self.r_score = [i * maltitude for i in self.r_score] 63 | # print(f'shrehold:{self.r_limit},ave_score:{self.r_score}') 64 | 65 | cur=[] 66 | for score in self.score_pool: 67 | cur_ele =[] 68 | for ele in score: 69 | # if ele>r_limit: 70 | # cur_ele.append(1) 71 | # elif ele < p_limit: 72 | # cur_ele.append(-0.2) #translation sentiment -0.2 73 | # else: 74 | # # if score[-1] > r_limit/2: 75 | # # cur_ele.append(random.uniform(0,0.01)) 76 | # # elif 
score[-1] < p_limit: 77 | # # cur_ele.append(-random.uniform(0,0.00001)) 78 | # # else: 79 | # cur_ele.append(1e-3) 80 | flag = 1 81 | for i in range(len(self.r_limit)): 82 | if ele >= self.r_limit[i]: 83 | flag =0 84 | # cur_ele.append(self.r_limit[i]+random.gauss(ele-self.r_limit[i],0.6)) 85 | cur_ele.append(self.r_limit[i] + max(min(random.gauss(ele-self.r_limit[i],0.1)* self.r_interval[i],self.r_interval[i]),0)) 86 | # cur_ele.append(self.r_limit[i] + random.random() * self.r_interval[i]) 87 | # cur_ele.append(ele) 88 | break 89 | assert flag==0, f"element {ele}" 90 | cur.append(cur_ele) 91 | 92 | cur = ((torch.sigmoid((torch.tensor(cur) + abs(sorted_score[len(sorted_score)//10*(9)]))/sorted_score[0]) - 0.5 ) * 2).tolist() 93 | 94 | self.cat_mask = cur 95 | 96 | def get_data(self): 97 | return deepcopy(self.ids_pool), deepcopy(self.mask_pool), deepcopy(self.cat_mask) 98 | 99 | -------------------------------------------------------------------------------- /arguments.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | 4 | 5 | def get_args(): 6 | parser = argparse.ArgumentParser(description='RL') 7 | 8 | #new_added 9 | parser.add_argument( 10 | '--template', type=tuple, default=(6,6)) #(6,6) for negative sentiment, 4,4 for trans&repe 11 | parser.add_argument( 12 | '--model_name_or_path', type=str, default='/home/lwd/gpt2-large') 13 | parser.add_argument( 14 | '--reward_model', type=str, default='/home/lwd/quark/Sentiment/checkpoint/fudge/disc_tuning_positive_temperature0.01_scope_50_epoch_5_f1_0.88_(2,2).ckpt') 15 | parser.add_argument( 16 | '--device', type=str, default='cuda') 17 | parser.add_argument("--pseudo_token", type=str, default='xxx') 18 | parser.add_argument("--lstm_dropout", type=float, default=0.0) 19 | parser.add_argument("--max_prompt_length", type=int, default=8) # 32 for translation, 8 for sentiment, 16 for toy 20 | parser.add_argument("--ranking_scope", type=int, default=50) 21 | parser.add_argument("--source_mode", type=str, default='neutral') 22 | parser.add_argument("--target_mode", type=str, default='positive') 23 | 24 | # dataset 25 | parser.add_argument( 26 | '--output-dir', type=str, default='outputs') 27 | parser.add_argument( 28 | '--dataset-train', type=str, default='data/toxicity/train.jsonl', 29 | help='JSONL file containing train prompts. Each row must contain a prompt at `row["prompt"]["text"]`.') 30 | parser.add_argument( 31 | '--dataset-val', type=str, default='data/toxicity/val.jsonl', 32 | help='JSONL file containing dev prompts. 
Each row must contain a prompt at `row["prompt"]["text"]`.') 33 | parser.add_argument( 34 | '--perspective-rate-limit', type=int, default=135, help='number of perspective call per second') 35 | 36 | # reward 37 | parser.add_argument( 38 | '--n_extra_tokens', type=int, default=5, help='number of reward categorization') 39 | parser.add_argument( 40 | '--horizon', type=float, default=2500, help='horizon value in adaptive controller') 41 | # KL term 42 | parser.add_argument( 43 | '--kl_coef', type=float, default=0.02, help='coefficient for KL term in reward,0.05 for sentiment') # 0.06 for sentiment, 0.1 for formal 44 | parser.add_argument( 45 | '--adaptive_kl', action='store_true', default=False, help='whether to use adaptive KL controller') 46 | parser.add_argument( 47 | '--target_kl', type=float, default=3, help='target value in adaptive KL controller') 48 | # entropy term 49 | parser.add_argument( 50 | '--entropy_coef', type=float, default=0.08, help='coefficient for entropy term in reward') #0.06 for sentiment,translation 51 | parser.add_argument( 52 | '--adaptive_entropy', action='store_true', default=False, help='whether to use adaptive entropy controller') 53 | parser.add_argument( 54 | '--target_entropy', type=float, default=40, help='target value in adaptive entropy controller') 55 | 56 | # policy 57 | parser.add_argument( 58 | '--init-model', type=str, default='/home/lwd/gpt2-large', help='language model used for policy.') 59 | parser.add_argument( 60 | '--ref-model', type=str, default='/home/lwd/gpt2-large', help='language model used for reference policy.') 61 | parser.add_argument( 62 | '--response-length', type=int, default=64, help='number of tokens to generate for each prompt.') 63 | parser.add_argument( 64 | '--temperature', type=float, default=1.0, help='temperature for sampling policy.') 65 | 66 | # training˚ 67 | parser.add_argument( 68 | '--total-episodes', type=int, default=3000000, help='total number of episodes') 69 | parser.add_argument( 70 | '--lr', type=float, default=1e-5, help='learning rate') # sentiment translation 1e-5 71 | parser.add_argument( 72 | '--num_warmup_steps', type=int, default=500, help='number of warmup steps in lr scheduler') 73 | parser.add_argument( 74 | '--clip_grad', action='store_true', default=False, help='whether to clip gradient') 75 | parser.add_argument( 76 | '--max-grad-norm', type=float, default=0.5, help='maximum norm of gradients ') 77 | 78 | # generation 79 | parser.add_argument( 80 | '--num-samples', type=int, default=25, help='number of samples to generate for each prompt.') 81 | parser.add_argument( 82 | '--top-p', type=float, default=1.0, help='hyperparameter for nucleus sampling') 83 | 84 | # other 85 | parser.add_argument( 86 | '--batch_size', type=int, default=32, help='batch size') 87 | parser.add_argument( 88 | '--seed', type=int, default=1, help='random seed (default: 1)') 89 | parser.add_argument( 90 | '--log-interval', type=int, default=100, help='step interval to print out logs') 91 | parser.add_argument( 92 | '--save-interval', type=int, default=500, help='step interval to save model checkpoints') 93 | parser.add_argument( 94 | '--eval-interval', type=int, default=500, help='step interval to do evaluation') 95 | #这个之前是500 eval_internal,是为了统计step-precision暂时设为250 96 | parser.add_argument( 97 | '--sample-interval', type=int, default=1000, help='step interval to sample from current policy') 98 | parser.add_argument( 99 | '--cuda-deterministic', action='store_false', default=True, 100 | help="sets flags for determinism when 
using CUDA (potentially slow!)") 101 | 102 | args = parser.parse_args() 103 | args.cuda = torch.cuda.is_available() 104 | 105 | return args 106 | -------------------------------------------------------------------------------- /utils/perspective_api.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import json 3 | import multiprocessing as mp 4 | import time 5 | from pathlib import Path 6 | from typing import List, Union, Optional, Tuple, Dict, Any, Iterable 7 | 8 | from googleapiclient import discovery 9 | from googleapiclient.errors import HttpError 10 | from tqdm.auto import tqdm 11 | 12 | from utils.constants import PERSPECTIVE_API_ATTRIBUTES, PERSPECTIVE_API_KEY, PERSPECTIVE_API_ATTRIBUTES_LOWER 13 | from utils.utils import batchify, load_cache 14 | 15 | 16 | def unpack_scores(response_json: dict) -> Optional[Tuple[dict, dict]]: 17 | if not response_json: 18 | return None 19 | 20 | attribute_scores = response_json['attributeScores'].items() 21 | 22 | summary_scores = {} 23 | span_scores = {} 24 | for attribute, scores in attribute_scores: 25 | attribute = attribute.lower() 26 | 27 | # Save summary score 28 | assert scores['summaryScore']['type'] == 'PROBABILITY' 29 | summary_scores[attribute] = scores['summaryScore']['value'] 30 | 31 | # Save span scores 32 | for span_score_dict in scores['spanScores']: 33 | assert span_score_dict['score']['type'] == 'PROBABILITY' 34 | span = (span_score_dict['begin'], span_score_dict['end']) 35 | span_scores.setdefault(span, {})[attribute] = span_score_dict['score']['value'] 36 | 37 | return summary_scores, span_scores 38 | 39 | 40 | def make_generations_col(generations, responses): 41 | for generation, response in zip(generations, responses): 42 | if response['response']: 43 | response = unpack_scores(response['response'])[0] 44 | else: 45 | # import ipdb 46 | # ipdb.set_trace() 47 | response = {x: None for x in PERSPECTIVE_API_ATTRIBUTES_LOWER} 48 | yield {'text': generation, **response} 49 | 50 | 51 | class PerspectiveAPI: 52 | def __init__(self, api_key: str = PERSPECTIVE_API_KEY, rate_limit: int = 25): 53 | self.service = self._make_service(api_key) 54 | self.last_request_time = -1 # satisfies initial condition 55 | self.rate_limit = rate_limit 56 | self.next_uid = 0 57 | 58 | def request(self, texts: Union[str, List[str]]) -> List[Tuple[Optional[Dict[str, Any]], Optional[HttpError]]]: 59 | if isinstance(texts, str): 60 | texts = [texts] 61 | 62 | # Rate limit to 1 batch request per second 63 | assert len(texts) <= self.rate_limit 64 | time_since_last_request = time.time() - self.last_request_time 65 | if time_since_last_request < 1: 66 | time.sleep(1 - time_since_last_request) 67 | self.last_request_time = time.time() 68 | 69 | # Keys guaranteed in insertion order (Python 3.7+) 70 | responses = {str(uid): None for uid in range(self.next_uid, self.next_uid + len(texts))} 71 | self.next_uid += len(texts) 72 | 73 | def response_callback(request_id, response, exception): 74 | nonlocal responses 75 | responses[request_id] = (response, exception) 76 | 77 | # Make API request 78 | batch_request = self.service.new_batch_http_request() 79 | for uid, text in zip(responses.keys(), texts): 80 | batch_request.add(self._make_request(text, self.service), callback=response_callback, request_id=uid) 81 | batch_request.execute() 82 | 83 | return list(responses.values()) 84 | 85 | def request_bulk(self, 86 | corpus: Union[Iterable[str], Iterable[Tuple[str, str]]], 87 | output_file: Union[str, Path], 88 
| pbar: tqdm = None): 89 | # Check for output file 90 | output_file = Path(output_file) 91 | # assert not output_file.exists() 92 | 93 | # Set up progress bar 94 | if not pbar: 95 | total = len(corpus) if isinstance(corpus, collections.abc.Sequence) else None 96 | pbar = tqdm(total=total, dynamic_ncols=True) 97 | pbar.set_description(f'Perspective API') 98 | 99 | i = 0 100 | num_failures = 0 101 | # with open(output_file,'w') as f: 102 | # pass 103 | with output_file.open('a') as f: 104 | for batch in batchify(corpus, self.rate_limit): 105 | request_ids = None 106 | if isinstance(batch[0], tuple): 107 | request_ids, batch = zip(*batch) 108 | 109 | for j, (response, exception) in enumerate(self.request(batch)): 110 | response_dict = { 111 | 'request_id': request_ids[j] if request_ids else i, 112 | 'response': response, 113 | 'error': str(exception) if exception else None 114 | } 115 | 116 | # Save response 117 | json.dump(response_dict, f) 118 | f.write('\n') 119 | 120 | if exception: 121 | num_failures += 1 122 | 123 | i += len(batch) 124 | pbar.update(len(batch)) 125 | pbar.set_postfix(failures=num_failures, rate_limt=self.rate_limit) 126 | 127 | @staticmethod 128 | def _make_service(api_key: str): 129 | # Generate API client object dynamically based on service name and version 130 | return discovery.build('comments:analyze', 'v1alpha1', 131 | discoveryServiceUrl="https://commentanalyzer.googleapis.com/$discovery/rest?version=v1alpha1", 132 | developerKey=api_key, 133 | static_discovery=False) 134 | 135 | @staticmethod 136 | def _make_request(text: str, service): 137 | analyze_request = { 138 | 'comment': {'text': text}, 139 | 'languages': ['en'], 140 | 'requestedAttributes': {attr: {} for attr in PERSPECTIVE_API_ATTRIBUTES}, 141 | 'spanAnnotations': True, 142 | } 143 | return service.comments().analyze(body=analyze_request) 144 | 145 | 146 | class PerspectiveWorker: 147 | SENTINEL = 'STOP' 148 | 149 | def __init__(self, out_file: Path, total: int, rate_limit: int): 150 | if not rate_limit: 151 | print("Disabling Perspective API (rps is 0)") 152 | self.enabled = False 153 | return 154 | self.enabled = True 155 | 156 | self.requests_handled = set() 157 | for response in load_cache(out_file): 158 | self.requests_handled.add(response['request_id']) 159 | total -= len(self.requests_handled) 160 | 161 | # Setup worker thread 162 | self.task_queue = mp.Queue() 163 | self.process = mp.Process(target=self.perspective_worker, 164 | args=(self.task_queue, out_file, total, rate_limit)) 165 | self.process.start() 166 | 167 | def __call__(self, request_id: str, text: str): 168 | if not self.enabled: 169 | return 170 | 171 | if request_id not in self.requests_handled: 172 | self.task_queue.put((request_id, text)) 173 | 174 | def stop(self): 175 | if not self.enabled: 176 | return 177 | 178 | print("Waiting for Perspective to finish...") 179 | self.task_queue.put(self.SENTINEL) 180 | self.process.join() 181 | 182 | @classmethod 183 | def perspective_worker(cls, queue: mp.Queue, responses_file: Path, total: int, rate_limit: int): 184 | queue_iter = iter(queue.get, cls.SENTINEL) 185 | api = PerspectiveAPI(rate_limit=rate_limit) 186 | pbar = tqdm(total=total, dynamic_ncols=True, position=1) 187 | api.request_bulk(queue_iter, output_file=responses_file, pbar=pbar) 188 | 189 | 190 | def test_perspective_api(): 191 | api = PerspectiveAPI() 192 | 193 | text_success = "Testing" 194 | text_error = 'x' * (20480 + 1) 195 | 196 | score_1, error_1 = api.request(text_success)[0] 197 | assert score_1 and not error_1 
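# a request above the API's accepted size (20480 + 1 characters here) is expected to come back with an HttpError instead of a score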
198 | 199 | score_2, error_2 = api.request(text_error)[0] 200 | assert not score_2 and isinstance(error_2, HttpError) 201 | 202 | multi_score, multi_error = zip(*api.request([text_success, text_error])) 203 | assert multi_score == (score_1, score_2) 204 | assert tuple(map(str, multi_error)) == tuple(map(str, (error_1, error_2))) 205 | 206 | 207 | # test_perspective_api() 208 | -------------------------------------------------------------------------------- /Sentiment/main_disc.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import torch 4 | import argparse 5 | import random 6 | import numpy as np 7 | 8 | import sys 9 | from pathlib import Path 10 | sys.path.append(str(Path(__file__).resolve().parents[1])) 11 | 12 | from torch.utils.data import DataLoader 13 | import torch.nn.functional as F 14 | 15 | 16 | from datetime import datetime 17 | from tqdm import tqdm 18 | from transformers import AutoTokenizer 19 | from sklearn.metrics import classification_report 20 | from transformers import pipeline, set_seed 21 | 22 | from os.path import join, abspath, dirname 23 | from Sentiment.data import Classification_Dataset, SentimentPrompt, DetoxicDataset, Sentiment_Suffix, GPT2Label 24 | from Sentiment.discriminator import PTuneForLAMA 25 | # from data import Classification_Dataset, SentimentPrompt, DetoxicDataset, Sentiment_Suffix, GPT2Label 26 | # from discriminator import PTuneForLAMA 27 | 28 | def seed_everything(seed): 29 | random.seed(seed) 30 | np.random.seed(seed) 31 | torch.manual_seed(seed) 32 | torch.cuda.manual_seed(seed) 33 | torch.cuda.manual_seed_all(seed) 34 | torch.backends.cudnn.benchmark = False 35 | torch.backends.cudnn.deterministic = True 36 | 37 | def construct_generation_args(): 38 | parser = argparse.ArgumentParser() 39 | 40 | # pre-parsing args 41 | parser.add_argument("--model_name_or_path", type=str, default='/home/lwd/gpt2-base') 42 | 43 | parser.add_argument("--data_path", type=str, default='/home/lwd/quark-publish/Sentiment/data/pos_neg') 44 | 45 | parser.add_argument("--embedding_checkpoint", type=str, default=None) 46 | parser.add_argument("--task_name", type=str, default="sentiment", choices=["detoxic", "sentiment"]) 47 | 48 | parser.add_argument("--pseudo_token", type=str, default='xxx') 49 | 50 | parser.add_argument("--batch_size", type=int, default=160) 51 | parser.add_argument("--epoch", type=int, default=50) 52 | 53 | parser.add_argument("--template", type=str, default="(2, 2)") 54 | parser.add_argument("--early_stop", type=int, default=20) 55 | 56 | parser.add_argument("--lr", type=float, default=2e-4) 57 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") 58 | parser.add_argument("--decay_rate", type=float, default=0.98) 59 | parser.add_argument("--weight_decay", type=float, default=0.0005) 60 | parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available") 61 | 62 | # lama configuration 63 | parser.add_argument("--only_evaluate", type=bool, default=False) 64 | parser.add_argument("--use_original_template", type=bool, default=False) 65 | parser.add_argument("--use_lm_finetune", type=bool, default=True) 66 | 67 | parser.add_argument("--lstm_dropout", type=float, default=0.0) 68 | 69 | # directories 70 | parser.add_argument("--out_dir", type=str, default=join(abspath(dirname(__file__)), './checkpoint')) 71 | # MegatronLM 11B 72 | 73 | ## generation configure 74 | parser.add_argument("--temperature", type=float, 
default=0.01) 75 | parser.add_argument("--max_length", type=int, default=30) 76 | parser.add_argument("--max_prompt_length", type=int, default=10) 77 | 78 | parser.add_argument("--beta", type=float, default=0.4) 79 | parser.add_argument("--prompt_type", type=str, default="negative") 80 | parser.add_argument("--target_type", type=str, default="positive") 81 | 82 | parser.add_argument("--prompt_pad_length", type=int, default=10) 83 | # parser.add_argument("--top_k", type=int, default=3) 84 | parser.add_argument("--ranking_scope", type=int, default=50) 85 | parser.add_argument("--top_p", type=float, default=0.95) 86 | 87 | parser.add_argument("--file_name", type=str, default="./eval") 88 | parser.add_argument("--mode", type=str, default="train", choices=["ctg", "train", "classifer"]) 89 | parser.add_argument("--evaluate_file", type=str, default="../our_text") 90 | parser.add_argument("--evaluate_outfile", type=str, default="./eval/our/result.csv") 91 | parser.add_argument("--iter_num", type=int, default=10) 92 | parser.add_argument("--corpus_type", type=str, default="positive") 93 | parser.add_argument("--tuning_name", type=str, default="disc_tuning", 94 | choices=["prompt_tuning", "disc_tuning", "distill_tuning"]) 95 | 96 | ## discriminator information for distilled tuning 97 | parser.add_argument("--disc_embedding_checkpoint", type=str, default=None) 98 | parser.add_argument("--template_disc", type=str, default="(2, 3)") 99 | 100 | args = parser.parse_args() 101 | # post-parsing args 102 | 103 | args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 104 | args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count() 105 | args.template = eval(args.template) if type(args.template) is not tuple else args.template 106 | args.template_disc = eval(args.template_disc) if type(args.template_disc) is not tuple else args.template_disc 107 | 108 | assert type(args.template) is tuple 109 | 110 | seed_everything(args.seed) 111 | 112 | return args 113 | 114 | class Scorer(): 115 | def __init__(self,path,device): 116 | n_args = construct_generation_args() 117 | n_args.device = device 118 | self.args = n_args 119 | self.label_token = {"positive": 'good', "negative": 'bad'} 120 | self.model = PTuneForLAMA(n_args, n_args.template, label_token=self.label_token) 121 | ckpt = torch.load(path)['embedding'] 122 | # ckpt = torch.load('/home/lwd/quark/Sentiment/checkpoint/train-in-fudge-way/disc_tuning_positive_temperature0.01_scope_50_epoch_2_f1_0.85_(2,2).ckpt')['embedding'] 123 | self.model.load_state_dict(ckpt) 124 | self.tokenizer = self.model.tokenizer 125 | 126 | def score(self,input_ids,mode='positive'): 127 | score = self.model._predict_scores(input_ids, input_ids!=self.tokenizer.pad_token_id, reward=True) 128 | return score 129 | # def score(self,input_ids,mode='positive'): 130 | # dataset = Classification_Dataset(tokenizer=self.tokenizer, data_dir=input_ids,label_token=self.label_token,max_length=35) 131 | # 132 | # data_loader = DataLoader(dataset, self.args.batch_size, shuffle=False) 133 | # input_list,rewards_list=[],[] 134 | # with torch.no_grad(): 135 | # self.model.eval() 136 | # for batch in data_loader: 137 | # self.model.eval() 138 | # x = batch[0].to(self.args.device).squeeze(1) 139 | # musk = batch[1].to(self.args.device).long().squeeze(1) 140 | # y = batch[2] 141 | # max_length = musk.sum(1).max() 142 | # scores=torch.zeros_like(x,dtype=torch.float32) 143 | # sen_mask = self.model._predict_scores(x, musk) 144 | # if mode == 'positive': 145 | # 
sen_mask = sen_mask == 11274 146 | # else: 147 | # sen_mask = sen_mask == 14774 148 | # for l in range(max_length): 149 | # state_score = self.model._predict_scores(x[:,:l+1], musk[:,:l+1], reward=True) 150 | # scores[:,l] = state_score 151 | # rewards = scores[:,1:] - scores[:,:-1] 152 | # input_list.extend(x.tolist()) 153 | # rewards_list.extend(rewards.tolist()) 154 | # return input_list,rewards_list,sen_mask 155 | 156 | class Classifier(): 157 | def __init__(self,args): 158 | n_args = construct_generation_args() 159 | n_args.device = args.device 160 | self.args = n_args 161 | self.label_num = 2 162 | if self.label_num==3: 163 | self.label_token = {"positive": 'good', "negative": 'bad','neutral':'neutral'} 164 | elif self.label_num==2: 165 | self.label_token = {"positive": 'good', "negative": 'bad'} 166 | 167 | self.model = PTuneForLAMA(n_args, n_args.template, label_token=self.label_token) 168 | # used for quark 169 | # ckpt = torch.load( 170 | # 'Sentiment/checkpoint/prompt_model/disc_tuning_positive_temperature0.01_scope_50_epoch_2_f1_0.81_(2,2).ckpt')['embedding'] 171 | # self.model.load_state_dict(ckpt) 172 | #used for fudge 173 | ckpt = torch.load( 174 | '/home/lwd/quark/Sentiment/checkpoint/fudge/disc_tuning_positive_temperature0.01_scope_50_epoch_5_f1_0.88_(2,2).ckpt')['embedding'] 175 | self.model.load_state_dict(ckpt) 176 | self.tokenizer = self.model.tokenizer 177 | 178 | def get_past_key_values(self,input_ids): 179 | attn_mask = input_ids != self.tokenizer.pad_token_id 180 | pkv=self.model.get_past_key_values(input_ids,attn_mask) 181 | return pkv 182 | 183 | def get_q_value(self,input_ids,add_ids,past_key_values,mean=True): 184 | past_key_values_list = [] 185 | bsz = add_ids.shape[0] 186 | sub_bsz = add_ids.shape[1] 187 | for i in range(bsz): 188 | pkv_ele =[] 189 | for tp in past_key_values: 190 | pkv_ele.append((tp[0][i].unsqueeze(0).repeat(sub_bsz,1,1,1),tp[1][i].unsqueeze(0).repeat(sub_bsz,1,1,1))) 191 | past_key_values_list.append(pkv_ele) 192 | 193 | r_storage = [] 194 | for i in range(bsz): 195 | attn_mask = torch.cat([input_ids[i]!=self.tokenizer.pad_token_id,torch.ones(1,device=input_ids.device)],dim=-1) 196 | attn_mask = attn_mask.unsqueeze(0).repeat(sub_bsz,1) 197 | r_value = self.model.forward_with_pkv(add_ids[i],attn_mask,past_key_values_list[i],mean).view(-1) 198 | r_storage.append(r_value) 199 | 200 | return torch.stack(r_storage) 201 | 202 | def get_next_contrast_token(self,input_ids,add_ids,past_key_values,mean=False): 203 | past_key_values_list = [] 204 | bsz = add_ids.shape[0] 205 | # print(input_ids.shape,add_ids.shape) 206 | sub_bsz = add_ids.shape[1] 207 | for i in range(bsz): 208 | pkv_ele =[] 209 | for tp in past_key_values: 210 | pkv_ele.append((tp[0][i].unsqueeze(0).repeat(sub_bsz,1,1,1),tp[1][i].unsqueeze(0).repeat(sub_bsz,1,1,1))) 211 | past_key_values_list.append(pkv_ele) 212 | 213 | r_storage = [] 214 | for i in range(bsz): 215 | attn_mask = torch.cat([input_ids[i]!=self.tokenizer.pad_token_id,torch.ones(1,device=input_ids.device)],dim=-1) 216 | attn_mask = attn_mask.unsqueeze(0).repeat(sub_bsz,1) 217 | r_value = self.model.forward_with_pkv(add_ids[i][...,None],attn_mask,past_key_values_list[i],mean).view(-1) 218 | r_storage.append(r_value) 219 | 220 | return torch.stack(r_storage) 221 | 222 | 223 | 224 | class Trainer(object): 225 | def __init__(self, args): 226 | self.args = args 227 | 228 | # self.label_token ={ 229 | # "positive":'good', 230 | # "negative":'bad' 231 | # } 232 | label_kind = 2 233 | assert self.args.tuning_name == 
"disc_tuning" 234 | if label_kind ==2: 235 | self.label_token = {"positive": 'good', "negative": 'bad'} 236 | elif label_kind==3: 237 | self.label_token = {"positive": 'good', "negative": 'bad','neutral':'neutral'} 238 | self.model = PTuneForLAMA(args, args.template, label_token=self.label_token) 239 | 240 | self.tokenizer = self.model.tokenizer 241 | data_path = args.data_path 242 | 243 | if self.args.task_name == "sentiment": 244 | print(self.args.tuning_name) 245 | 246 | if self.args.tuning_name == "disc_tuning" or self.args.tuning_name == "distill_tuning": 247 | all_dataset = Classification_Dataset(tokenizer=self.tokenizer, data_dir=data_path, max_length=30, 248 | type_path="train", label_token=self.label_token) 249 | 250 | else: 251 | 252 | all_dataset = Sentiment_Suffix(tokenizer=self.tokenizer, data_dir=data_path, max_length=30, 253 | task_type=self.args.corpus_type, label_token=self.label_token) 254 | 255 | elif self.args.task_name == "detoxic": 256 | print("load detoxic dataset!!!") 257 | 258 | if self.args.tuning_name == "disc_tuning" or self.args.tuning_name == "distill_tuning": 259 | 260 | all_dataset = DetoxicDataset(tokenizer=self.tokenizer, data_dir=data_path, max_length=30, 261 | type_path="train", label_token=self.label_token) 262 | 263 | else: 264 | all_dataset = Sentiment_Suffix(tokenizer=self.tokenizer, data_dir=data_path, max_length=30, 265 | task_type=self.args.corpus_type, label_token=self.label_token) 266 | # all_dataset = GPT2Label(tokenizer=self.tokenizer, data_dir=data_path, max_length=20) 267 | 268 | train_size = int(len(all_dataset) * 0.9) 269 | test_size = len(all_dataset) - train_size 270 | # train_dataset, test_dataset = torch.utils.data.split(all_dataset, [train_size, test_size]) 271 | train_dataset = torch.utils.data.Subset(all_dataset, range(train_size)) 272 | test_dataset = torch.utils.data.Subset(all_dataset, range(train_size, train_size + test_size)) 273 | self.train_loader = DataLoader(train_dataset, args.batch_size, num_workers=2, shuffle=True) 274 | self.test_loader = DataLoader(test_dataset, args.batch_size, num_workers=2, shuffle=True) 275 | 276 | def evaluate(self, epoch_idx, evaluate_type): 277 | self.model.eval() 278 | if evaluate_type == 'Test': 279 | loader = self.test_loader 280 | else: 281 | loader = self.dev_loader 282 | labels = [] 283 | preds = [] 284 | with torch.no_grad(): 285 | self.model.eval() 286 | for batch in loader: 287 | self.model.eval() 288 | x = batch[0].cuda().squeeze(1) 289 | musk = batch[1].cuda().long().squeeze(1) 290 | y = batch[2] 291 | 292 | pred_ids = self.model._predict_scores(x, musk) 293 | 294 | preds += pred_ids 295 | labels += y.tolist() 296 | 297 | result = self.disc_metric(labels, preds) 298 | print('*********precision:{}**********'.format(result)) 299 | return result 300 | 301 | # def evaluate(self, epoch_idx, evaluate_type): 302 | # self.model.eval() 303 | # if evaluate_type == 'Test': 304 | # loader = self.test_loader 305 | # else: 306 | # loader = self.dev_loader 307 | # scores,totals = 0,0 308 | # with torch.no_grad(): 309 | # self.model.eval() 310 | # for batch in loader: 311 | # self.model.eval() 312 | # 313 | # x = batch[0].cuda().squeeze(1) 314 | # musk = batch[1].cuda().long().squeeze(1) 315 | # y = batch[2] 316 | # 317 | # ave,nums = self.model(x, musk) 318 | # scores+=ave 319 | # totals+=nums 320 | # 321 | # print('*********precision:{}**********'.format(scores/totals)) 322 | # return scores/totals 323 | 324 | def disc_metric(self,labels,preds): 325 | correct=0. 326 | sum=0. 
327 | for l,p in zip(labels,preds): 328 | if l==p: 329 | correct+=1. 330 | sum+=1. 331 | return correct/sum 332 | 333 | def get_save_path(self): 334 | return join(self.args.out_dir, 'train-in-fudge-way') 335 | 336 | def get_checkpoint(self, epoch_idx, f1_score): 337 | ckpt_name = "{}_{}_temperature{}_scope_{}_epoch_{}_f1_{}_{}.ckpt".format(self.args.tuning_name, 338 | self.args.corpus_type, 339 | self.args.temperature, 340 | self.args.ranking_scope, epoch_idx, 341 | str(f1_score), 342 | str(self.args.template).replace(" ", 343 | "")) 344 | return {'embedding': self.model.state_dict(), 345 | 'ckpt_name': ckpt_name, 346 | 'args': self.args} 347 | 348 | def save(self, best_ckpt): 349 | ckpt_name = best_ckpt['ckpt_name'] 350 | path = self.get_save_path() 351 | os.makedirs(path, exist_ok=True) 352 | torch.save(best_ckpt, join(path, ckpt_name)) 353 | 354 | if self.args.use_lm_finetune: 355 | self.model.model.save_pretrained(str(join(path, ckpt_name))[:-5]) 356 | print("Checkpoint {} saved.".format(ckpt_name)) 357 | 358 | def train(self): 359 | best_dev, early_stop, has_adjusted = 0, 0, True 360 | best_ckpt = None 361 | params = [{'params': self.model.prompt_encoder.parameters(), 'lr': self.args.lr}] 362 | if self.args.use_lm_finetune: 363 | params.append({'params': self.model.model.parameters(), 'lr': 1e-5}) 364 | 365 | optimizer = torch.optim.Adam(params, weight_decay=self.args.weight_decay) 366 | my_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=self.args.decay_rate) 367 | 368 | stop_count = 0 369 | best_result = 0.0 370 | for epoch_idx in range(self.args.epoch): 371 | 372 | tot_loss = 0 373 | count = 0 374 | for batch_idx, batch in tqdm(enumerate(self.train_loader),total=len(self.train_loader),desc='epoch:{}'.format(epoch_idx)): 375 | self.model.train() 376 | x = batch[0].cuda().squeeze(1) 377 | musk = batch[1].long().cuda().squeeze(1) 378 | y = batch[2].long().cuda() 379 | 380 | loss = self.model(x, y, musk) 381 | 382 | tot_loss += loss.item() 383 | 384 | loss.backward() 385 | torch.cuda.empty_cache() 386 | optimizer.step() 387 | torch.cuda.empty_cache() 388 | optimizer.zero_grad() 389 | 390 | print(f"epoch index is {epoch_idx}, and total loss is {tot_loss}") 391 | 392 | my_lr_scheduler.step() 393 | 394 | # if epoch_idx > -1: 395 | # result = self.evaluate(epoch_idx, 'Test') 396 | # weight_avg =result["weighted avg"] 397 | # f1_score = weight_avg["f1-score"] 398 | 399 | # if f1_score > best_result: 400 | # best_ckpt = self.get_checkpoint(epoch_idx,best_result) 401 | # best_result = f1_score 402 | # stop_count = 0 403 | # continue 404 | # else: 405 | # stop_count += 1 406 | # if stop_count>5: 407 | # self.save(best_ckpt) 408 | # break 409 | 410 | if epoch_idx >= -1: 411 | 412 | if self.args.tuning_name == "prompt_tuning" or self.args.tuning_name == "disc_tuning": 413 | result = self.evaluate(epoch_idx, 'Test') 414 | # weight_avg = result["weighted avg"] 415 | # f1_score = round(weight_avg["f1- 416 | best_ckpt = self.get_checkpoint(epoch_idx, round(result, 2)) 417 | else: 418 | best_ckpt = self.get_checkpoint(epoch_idx, round(tot_loss, 2)) 419 | 420 | self.save(best_ckpt) 421 | # def train(self): 422 | # best_dev, early_stop, has_adjusted = 0, 0, True 423 | # best_ckpt = None 424 | # params = [{'params': self.model.prompt_encoder.parameters(), 'lr': self.args.lr}] 425 | # if self.args.use_lm_finetune: 426 | # params.append({'params': self.model.model.parameters(), 'lr': 1e-5}) 427 | # 428 | # optimizer = torch.optim.Adam(params, 
weight_decay=self.args.weight_decay) 429 | # my_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=self.args.decay_rate) 430 | # 431 | # stop_count = 0 432 | # best_result = 0.0 433 | # for epoch_idx in range(self.args.epoch): 434 | # 435 | # tot_loss = 0 436 | # count = 0 437 | # for batch_idx, batch in tqdm(enumerate(self.train_loader),total=len(self.train_loader),desc='epoch:{}'.format(epoch_idx)): 438 | # self.model.train() 439 | # x = batch[0].cuda().squeeze(1) 440 | # musk = batch[1].long().cuda().squeeze(1) 441 | # y = batch[2].long().cuda() 442 | # 443 | # loss = self.model(x,musk,y) 444 | # 445 | # tot_loss += loss.item() 446 | # 447 | # loss.backward() 448 | # torch.cuda.empty_cache() 449 | # optimizer.step() 450 | # torch.cuda.empty_cache() 451 | # optimizer.zero_grad() 452 | # 453 | # print(f"epoch index is {epoch_idx}, and total loss is {tot_loss}") 454 | # 455 | # my_lr_scheduler.step() 456 | # 457 | # # if epoch_idx > -1: 458 | # # result = self.evaluate(epoch_idx, 'Test') 459 | # # weight_avg =result["weighted avg"] 460 | # # f1_score = weight_avg["f1-score"] 461 | # 462 | # # if f1_score > best_result: 463 | # # best_ckpt = self.get_checkpoint(epoch_idx,best_result) 464 | # # best_result = f1_score 465 | # # stop_count = 0 466 | # # continue 467 | # # else: 468 | # # stop_count += 1 469 | # # if stop_count>5: 470 | # # self.save(best_ckpt) 471 | # # break 472 | # 473 | # if epoch_idx >= -1: 474 | # 475 | # if self.args.tuning_name == "prompt_tuning" or self.args.tuning_name == "disc_tuning": 476 | # result = self.evaluate(epoch_idx, 'Test') 477 | # # weight_avg = result["weighted avg"] 478 | # # f1_score = round(weight_avg["f1- 479 | # best_ckpt = self.get_checkpoint(epoch_idx, round(result, 2)) 480 | # else: 481 | # best_ckpt = self.get_checkpoint(epoch_idx, round(tot_loss, 2)) 482 | # 483 | # self.save(best_ckpt) 484 | 485 | def main(relation_id=None): 486 | args = construct_generation_args() 487 | 488 | # train stage 489 | trainer = Trainer(args) 490 | trainer.train() 491 | 492 | ## generation stage 493 | 494 | 495 | if __name__ == '__main__': 496 | main() 497 | -------------------------------------------------------------------------------- /token_gpt2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils.rnn import pad_sequence 4 | from os.path import join 5 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 6 | from transformers import AutoModelWithLMHead, AutoTokenizer 7 | 8 | import re 9 | import datetime 10 | 11 | from transformers import AutoTokenizer 12 | 13 | import torch.nn.functional as F 14 | from utils.utils import logits_to_entropy 15 | SMALL_CONST = 1e-10 16 | BIG_CONST = -1e15 17 | 18 | import transformers 19 | from transformers import ( 20 | CONFIG_MAPPING, 21 | MODEL_FOR_CAUSAL_LM_MAPPING, 22 | AutoConfig, 23 | AutoModelForCausalLM, 24 | AutoTokenizer, 25 | HfArgumentParser, 26 | Trainer, 27 | TrainingArguments, 28 | default_data_collator, 29 | set_seed, 30 | BertTokenizer, 31 | GPT2Tokenizer) 32 | 33 | from transformers import GPT2LMHeadModel, AutoTokenizer, AutoModelForMaskedLM 34 | from Sentiment.main_disc import Classifier,Scorer 35 | 36 | def create_model(args): 37 | if args.model_name_or_path: 38 | # config = AutoConfig.from_pretrained(args.model_name_or_path) 39 | model = GPT2LMHeadModel.from_pretrained( 40 | args.model_name_or_path 41 | ) 42 | else: 43 | print("Model path is not set!!!") 44 | 45 | 
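# note: if args.model_name_or_path is empty, model is never assigned and the return below raises UnboundLocalError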
return model 46 | 47 | def _create_model(model_path): 48 | if model_path: 49 | model = GPT2LMHeadModel.from_pretrained(model_path) 50 | else: 51 | print("Model path is not set!!!") 52 | 53 | return model 54 | 55 | def get_embedding_layer(args, model): 56 | embeddings = model.base_model.get_input_embeddings() 57 | 58 | return embeddings 59 | 60 | class PromptEncoder(torch.nn.Module): 61 | def __init__(self, template, hidden_size, tokenizer, args): 62 | super().__init__() 63 | self.spell_length = sum(template) 64 | self.hidden_size = hidden_size 65 | self.tokenizer = tokenizer 66 | self.args = args 67 | # ent embedding 68 | self.cloze_length = template 69 | self.cloze_mask = [ 70 | [1] * self.cloze_length[0] # first cloze 71 | + [1] * self.cloze_length[1] # second cloze 72 | ] 73 | self.cloze_mask = torch.LongTensor(self.cloze_mask).bool().to(args.device) 74 | 75 | self.seq_indices = torch.LongTensor(list(range(len(self.cloze_mask[0])))).to(args.device) 76 | # embedding 77 | self.embedding = torch.nn.Embedding(len(self.cloze_mask[0]), self.hidden_size).to(args.device) 78 | # LSTM 79 | self.lstm_head = torch.nn.LSTM(input_size=self.hidden_size, 80 | hidden_size=self.hidden_size // 2, 81 | num_layers=2, 82 | dropout=self.args.lstm_dropout, 83 | bidirectional=True, 84 | batch_first=True) 85 | self.mlp_head = nn.Sequential(nn.Linear(self.hidden_size, self.hidden_size), 86 | nn.ReLU(), 87 | nn.Linear(self.hidden_size, self.hidden_size)) 88 | print("init prompt encoder...") 89 | 90 | def forward(self): 91 | input_embeds = self.embedding(self.seq_indices).unsqueeze(0) 92 | output_embeds = self.mlp_head(self.lstm_head(input_embeds)[0]).squeeze() 93 | return output_embeds 94 | 95 | 96 | class Distill_Tuning(torch.nn.Module): 97 | 98 | def __init__(self, args, template, label_token=None): 99 | super().__init__() 100 | self.args = args 101 | self.target_mode = args.target_mode 102 | 103 | # load tokenizer 104 | self.tokenizer = AutoTokenizer.from_pretrained(self.args.model_name_or_path) 105 | self.tokenizer.pad_token = self.tokenizer.eos_token 106 | 107 | # model setting 108 | self.model = create_model(self.args) 109 | # self.model.resize_token_embeddings(len(self.tokenizer)) 110 | self.model = self.model.to(self.args.device) 111 | for param in self.model.parameters(): 112 | param.requires_grad = False 113 | 114 | # get model's embeddings 115 | self.embeddings = self.model.get_input_embeddings() 116 | 117 | # label information 118 | self.label_token = label_token 119 | self.label_token_ids = {} 120 | 121 | for k, v in self.label_token.items(): 122 | print(k, v, self.tokenizer.encode(v)) 123 | self.label_token_ids[k] = self.tokenizer.encode(v) 124 | 125 | self.template = template 126 | # load prompt encoder 127 | self.hidden_size = self.embeddings.embedding_dim 128 | 129 | self.pseudo_token_id = self.tokenizer.convert_tokens_to_ids(self.args.pseudo_token) 130 | 131 | self.spell_length = sum(self.template) 132 | self.prompt_encoder_pos = PromptEncoder(self.template, self.hidden_size, self.tokenizer, args) 133 | self.prompt_encoder_pos = self.prompt_encoder_pos.to(self.args.device) 134 | # self.prompt_encoder_neg = PromptEncoder(self.template, self.hidden_size, self.tokenizer, args) 135 | # self.prompt_encoder_neg = self.prompt_encoder_neg.to(self.args.device) 136 | self.fc_loss = CrossEntropyLoss(reduction='none') 137 | # self.classifier = Classifier(param) 138 | self.scorer = Scorer(args.reward_model,param) 139 | self.kl_dropout = nn.Dropout(p=0.5) 140 | ### load discriminator 141 | # if 
self.args.disc_embedding_checkpoint != None: 142 | # self.disc_model = _create_model(self.args.disc_embedding_checkpoint[:-5]).to(self.args.device) 143 | # self.spell_length_disc = sum(self.args.template_disc) 144 | # self.disc_embedding = self.disc_model.get_input_embeddings() 145 | # self.prompt_encoder_disc = PromptEncoder(self.args.template_disc, self.disc_embedding.embedding_dim, 146 | # self.tokenizer, args) 147 | # self.prompt_encoder_disc = self.prompt_encoder_disc.to(self.args.device) 148 | # self.prompt_encoder_disc.load_state_dict(self.load_prompt(self.args.disc_embedding_checkpoint)) 149 | # else: 150 | # self.disc_model = self.model 151 | # self.prompt_encoder_disc = self.prompt_encoder 152 | 153 | def load_prompt(self, embedding_checkpoint): 154 | checkpoint = torch.load(embedding_checkpoint) 155 | prompt_embedding = checkpoint['embedding'] 156 | return prompt_embedding 157 | 158 | def embed_input(self, queries): 159 | bz = queries.shape[0] 160 | queries_for_embedding = queries.clone() 161 | raw_embeds = self.disc_embedding(queries_for_embedding) 162 | 163 | replace_embeds = self.prompt_encoder_disc() 164 | 165 | replace_embeds = replace_embeds.unsqueeze(0).expand(bz, -1, -1) 166 | 167 | raw_embeds[:, -self.prompt_encoder_disc.spell_length:, :] = replace_embeds 168 | 169 | return raw_embeds 170 | 171 | def get_query_head(self, x_h, prompt_tokens, x_t=None): 172 | 173 | prompt_tensor_head = torch.tensor(prompt_tokens * (self.spell_length)).to(self.args.device) 174 | 175 | trans_inputs = [] 176 | 177 | index_musk = (x_h == self.tokenizer.pad_token_id).type(torch.uint8) # only calculte the token which is not eos 178 | 179 | valid_number_length = torch.sum(index_musk, 1) 180 | 181 | for index, seq in zip(valid_number_length, x_h): 182 | trans_inputs.append(torch.cat([prompt_tensor_head, seq])) 183 | # if index == x_h.shape[1]: 184 | # trans_inputs.append(torch.cat([prompt_tensor_head, seq])) 185 | # else: 186 | # trans_inputs.append(torch.cat([seq[:index], prompt_tensor_head, seq[index:]])) 187 | 188 | res = torch.stack(trans_inputs, dim=0) 189 | if x_t != None: 190 | # x_t = x_t.unsqueeze(1) 191 | return torch.cat([res, x_t], dim=1) 192 | else: 193 | return res 194 | 195 | def embed_input_head(self, queries,mode='positive'): 196 | bz = queries.shape[0] 197 | queries_for_embedding = queries.clone() 198 | 199 | queries_for_embedding[(queries == self.pseudo_token_id)] = self.tokenizer.unk_token_id 200 | raw_embeds = self.embeddings(queries_for_embedding) 201 | 202 | try: 203 | blocked_indices = (queries == self.pseudo_token_id).type(torch.uint8).nonzero().reshape( 204 | (bz, self.spell_length, 2))[:, :, 1] # bz 205 | except: 206 | print(bz) 207 | print(queries.shape) 208 | print(queries[0]) 209 | print(queries[-1]) 210 | print((queries == self.pseudo_token_id).type(torch.uint8).nonzero().shape) 211 | raise ValueError 212 | 213 | if mode=='positive': 214 | prompt_encoder = self.prompt_encoder_pos 215 | elif mode=='negtive': 216 | prompt_encoder = self.prompt_encoder_neg 217 | else: 218 | raise ValueError 219 | 220 | replace_embeds = prompt_encoder() 221 | for bidx in range(bz): 222 | for i in range(prompt_encoder.spell_length): 223 | raw_embeds[bidx, blocked_indices[bidx, i], :] = replace_embeds[i, :] 224 | return raw_embeds 225 | 226 | 227 | def top_k_top_p_filtering(self,logits,top_k=0,top_p=1.0,filter_value=BIG_CONST,min_tokens_to_keep=1,): 228 | """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering 229 | Args: 230 | logits: logits distribution shape 
(batch size, vocabulary size) 231 | if top_k > 0: keep only top k tokens with highest probability (top-k filtering). 232 | if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). 233 | Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) 234 | Make sure we keep at least min_tokens_to_keep per batch example in the output 235 | From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 236 | """ 237 | if top_k > 0: 238 | top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check 239 | # Remove all tokens with a probability less than the last token of the top-k 240 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 241 | logits[indices_to_remove] = filter_value 242 | 243 | if top_p < 1.0: 244 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 245 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 246 | 247 | # Remove tokens with cumulative probability above the threshold (token with 0 are kept) 248 | sorted_indices_to_remove = cumulative_probs > top_p 249 | if min_tokens_to_keep > 1: 250 | # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) 251 | sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 252 | # Shift the indices to the right to keep also the first token above the threshold 253 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 254 | sorted_indices_to_remove[..., 0] = 0 255 | 256 | # scatter sorted tensors to original indexing 257 | indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) 258 | logits[indices_to_remove] = filter_value 259 | 260 | return logits 261 | 262 | # def get_q_value(self,text_ids,logits,topk): 263 | # past_key_values = self.classifier.get_past_key_values(text_ids) 264 | # sort_logits = logits.sort(dim=-1,descending=True) 265 | # indices = sort_logits.indices[:,:topk,None] 266 | # p = torch.softmax(sort_logits.values[:,:topk],dim=-1) 267 | # # cat_ids = torch.cat([text_ids[:,None,:].repeat(1,topk,1),indices[None,...].repeat(text_ids.shape[0],1,1)]) 268 | # reward = self.classifier.get_q_value(text_ids,indices,past_key_values) 269 | # #todo whether maintain the interference 270 | # # interfere = torch.normal(0,0.08,size=reward.shape) 271 | # # interfere = interfere.masked_fill(reward!=0,0) 272 | # # reward += interfere 273 | # # q_value = p * reward 274 | # # return q_value.sum(-1) 275 | # return reward.view(-1) 276 | 277 | def get_q_value(self,text_ids): 278 | q=self.scorer.score(text_ids) 279 | return q if self.target_mode=='positive' else 1-q 280 | 281 | def sample(self, prompts_ids, max_length, mode='positive', gen=False): 282 | cur_len = prompts_ids.shape[1] 283 | logits = [] 284 | output_ids = prompts_ids 285 | q_storage = torch.zeros_like(output_ids)[:, :-1] 286 | output_mask = torch.zeros_like(output_ids) 287 | return_dict = {} 288 | eos_flag = torch.ones([prompts_ids.shape[0]]).type(torch.uint8).to(self.args.device) 289 | 290 | if mode != None: 291 | prompt_tokens = [self.pseudo_token_id] 292 | queries = self.get_query_head(prompts_ids, prompt_tokens) 293 | inputs_embeds = self.embed_input_head(queries, mode) 294 | 295 | # attention_mask = torch.cat([prompts_ids != self.tokenizer.pad_token_id, torch.ones( 296 | # [prompts_ids.shape[0], self.spell_length + max_length - prompts_ids.shape[1]]).long().to( 297 | # self.args.device)], dim=1) 298 | attention_mask = 
torch.cat([torch.ones(prompts_ids.shape[0], self.spell_length, device=self.args.device), 299 | prompts_ids != self.tokenizer.pad_token_id, 300 | torch.ones(prompts_ids.shape[0], max_length - prompts_ids.shape[1] + 1, 301 | device=self.args.device).long().to(self.args.device)], dim=-1) 302 | else: 303 | inputs_embeds = self.embeddings(prompts_ids) 304 | attention_mask = torch.cat([prompts_ids != self.tokenizer.pad_token_id, torch.ones( 305 | [prompts_ids.shape[0], max_length - prompts_ids.shape[1] + 1]).long().to(self.args.device)], dim=1) 306 | 307 | position_ids = attention_mask.long().cumsum(-1) - 1 308 | position_ids = position_ids.masked_fill_(attention_mask == 0, 0) 309 | 310 | if not gen: 311 | q_value = self.get_q_value(output_ids) 312 | q_storage = torch.cat([q_storage, q_value[..., None]], dim=-1) 313 | 314 | # start = datetime.datetime.now() 315 | # test generation time 316 | first_round = 1 317 | while cur_len <= max_length: 318 | outputs = self.model(inputs_embeds=inputs_embeds, 319 | attention_mask=attention_mask[:, :inputs_embeds.shape[1]], 320 | position_ids=position_ids[:, :inputs_embeds.shape[1]], 321 | return_dict=True) 322 | 323 | if first_round: 324 | if mode == None: 325 | last_non_masked_idx = torch.sum(prompts_ids != self.tokenizer.pad_token_id, dim=1) - 1 326 | else: 327 | last_non_masked_idx = torch.sum(prompts_ids != self.tokenizer.pad_token_id, 328 | dim=1) - 1 + self.spell_length 329 | next_token_logits = outputs.logits[range(prompts_ids.shape[0]), last_non_masked_idx, :] 330 | first_round = 0 331 | else: 332 | next_token_logits = outputs.logits[:, -1, :] 333 | 334 | # if gen == False: 335 | next_token_logits_ = self.top_k_top_p_filtering(next_token_logits, top_k=self.args.ranking_scope, top_p=1.0, 336 | filter_value=BIG_CONST) 337 | # else: 338 | # next_token_logits_ = self.top_k_top_p_filtering(next_token_logits, top_k=10, 339 | # top_p=1.0, filter_value=BIG_CONST) 340 | 341 | next_token_logits_prob = torch.softmax(next_token_logits_, dim=1) 342 | 343 | next_token_logits_prob[:, self.tokenizer.eos_token_id] = 0 344 | next_token_logits_prob[:, self.pseudo_token_id] = 0 345 | # if gen==False: 346 | next_tokens = torch.multinomial(next_token_logits_prob, num_samples=1).squeeze(1) 347 | # else: 348 | # next_tokens = torch.argmax(next_token_logits, dim=-1) 349 | 350 | eos_flag = eos_flag.mul((next_tokens != self.tokenizer.eos_token_id).type( 351 | torch.uint8)) # if flag = 0, it means the generation is over 352 | next_tokens = next_tokens.mul(eos_flag) 353 | next_tokens[next_tokens == 0] = self.tokenizer.eos_token_id 354 | output_ids = torch.cat([output_ids, next_tokens.unsqueeze(1)], dim=1) 355 | # output_mask = torch.cat([output_mask,torch.ones_like(next_tokens.unsqueeze(1))],dim=1) 356 | output_mask = torch.cat([output_mask, (next_tokens.unsqueeze(1) != self.tokenizer.eos_token_id)], dim=1) 357 | inputs_embeds = torch.cat([inputs_embeds, self.embeddings(next_tokens).unsqueeze(1)], dim=1) 358 | 359 | cur_len = cur_len + 1 360 | 361 | if not gen: 362 | q_value = self.get_q_value(output_ids) 363 | q_storage = torch.cat([q_storage, q_value[..., None]], dim=-1) 364 | 365 | # end = datetime.datetime.now() 366 | # print("runing time is:",end-start) 367 | response_text = [self.tokenizer.decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=True) 368 | for output in output_ids] 369 | prompt_text = [self.tokenizer.decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=True) 370 | for output in prompts_ids] 371 | 372 | if not gen: 373 | 
assert output_ids.shape[-1] == q_storage.shape[-1] 374 | reward = q_storage[:, 1:] - q_storage[:, :-1] 375 | else: 376 | reward = None 377 | # reward = q_storage[:, -1][..., None].repeat(1, q_storage.shape[-1] - 1) 378 | return_dict = { 379 | 'text': response_text, 380 | 'prompt': prompt_text, 381 | 'input_ids': output_ids, 382 | 'output_mask': output_mask, 383 | 'q_values': reward 384 | } 385 | return return_dict 386 | 387 | 388 | def forward(self, x_hs, x_ts, att_mask): 389 | # construct query ids 390 | prompt_tokens = [self.pseudo_token_id] 391 | queries = self.get_query_head(x_hs, prompt_tokens) 392 | # construct label ids 393 | attention_mask = torch.cat( 394 | [att_mask, torch.ones([att_mask.shape[0], self.prompt_encoder.spell_length]).long().to(self.args.device)], 395 | dim=1) 396 | 397 | position_ids = attention_mask.long().cumsum(-1) - 1 398 | position_ids.masked_fill_(attention_mask == 0, 0) 399 | 400 | labels = torch.clone(queries) 401 | 402 | labels.masked_fill_(attention_mask == 0, -100) 403 | labels.masked_fill_(queries == self.pseudo_token_id, -100) 404 | 405 | # get embedded input 406 | inputs_embeds = self.embed_input_head(queries) 407 | 408 | output = self.model(inputs_embeds=inputs_embeds, 409 | attention_mask=attention_mask, 410 | position_ids=position_ids, 411 | labels=None) 412 | 413 | output_logits = output.logits 414 | # ce_loss = self.contrast_crossEntry_loss(torch.softmax(output_logits, dim = -1), labels, sentence_labels = x_ts) 415 | 416 | _queries = queries.view(queries.size(0) * queries.size(1)) 417 | _output_logits = output_logits.view(output_logits.size(0) * output_logits.size(1), -1) 418 | disc_logits = _output_logits.index_select(0, torch.nonzero(_queries != self.pseudo_token_id).squeeze(1)).view( 419 | output_logits.shape[0], -1, output_logits.shape[2]) 420 | 421 | logits_candidate = self.get_candidate_logits(x_hs, att_mask) 422 | logits_candidate = self.top_k_top_p_filtering( 423 | logits_candidate.view(logits_candidate.shape[0] * logits_candidate.shape[1], -1), 424 | top_k=self.args.ranking_scope, top_p=self.args.top_p, filter_value=BIG_CONST).view(x_hs.shape[0], 425 | x_hs.shape[1], -1) 426 | 427 | reank_output = self.get_ranked_logtis(x_hs, logits_candidate.detach().clone(), desired_att=None) 428 | 429 | reank_output = (logits_candidate > BIG_CONST + 10).mul(reank_output) 430 | 431 | kl_loss = self.KL_loss(torch.softmax(disc_logits, dim=-1), reank_output, att_mask) 432 | 433 | loss = kl_loss 434 | 435 | return loss 436 | 437 | def forward_pass(self, x_hs, att_mask, out_mask, reward_mask=None, mode='positive',gen=False): 438 | # construct query ids 439 | if mode!=None: 440 | prompt_tokens = [self.pseudo_token_id] 441 | queries = self.get_query_head(x_hs, prompt_tokens) 442 | inputs_embeds = self.embed_input_head(queries,mode) 443 | attention_mask = torch.cat( 444 | [torch.ones([att_mask.shape[0], self.spell_length]).long().to(self.args.device),att_mask], 445 | dim=1) 446 | 447 | else: 448 | queries = x_hs 449 | inputs_embeds = self.embeddings(queries) 450 | attention_mask = att_mask 451 | 452 | position_ids = attention_mask.long().cumsum(-1) - 1 453 | position_ids = position_ids.masked_fill_(attention_mask == 0, 0) 454 | 455 | # construct label ids 456 | labels = torch.clone(queries) 457 | # labels.masked_fill_(attention_mask == 0, -100) 458 | # labels.masked_fill_(queries == self.pseudo_token_id, -100) 459 | if mode != None: 460 | supplement = torch.zeros((out_mask.shape[0],self.spell_length),device=self.args.device) 461 | out_mask = 
torch.cat([supplement,out_mask],-1) 462 | # attention_mask = torch.cat([supplement,attention_mask],-1) 463 | reward_mask = torch.cat([supplement,reward_mask],-1) 464 | # punish_mask = torch.cat([supplement, punish_mask], -1) 465 | 466 | total_mask = attention_mask * out_mask 467 | labels = labels.masked_fill_(total_mask == 0 ,-100) 468 | labels = labels[:,1:] 469 | # reward_labels.masked_fill_(reward_mask==0, -100) 470 | # punish_labels.masked_fill_(punish_mask == 0, -100) 471 | 472 | outputs = self.model(inputs_embeds=inputs_embeds, 473 | attention_mask=attention_mask, 474 | position_ids=position_ids, 475 | labels=None) 476 | 477 | logits = outputs.logits[:,:-1,:] 478 | 479 | if mode==None: 480 | output_logits = logits * (labels.unsqueeze(-1) != -100) 481 | return {'logits':output_logits} 482 | 483 | loss = self.fc_loss(logits.reshape(-1,logits.shape[-1]), labels.reshape(-1)) #- self.fc_loss(logits.reshape(-1,logits.shape[-1]),punish_labels.reshape(-1)) * 0.8 484 | loss = loss.reshape(x_hs.shape[0],-1) 485 | lm_loss = (loss * reward_mask).sum() 486 | 487 | log_prob = F.log_softmax(logits, dim=-1) 488 | labels_select = torch.where(labels==-100,0,labels) 489 | output_logprob = torch.gather(log_prob, -1, labels_select[...,None]).squeeze(2) 490 | output_logprob = output_logprob.masked_fill_(labels==-100,0) 491 | 492 | output_logits = logits * (labels.unsqueeze(-1)!=-100) 493 | # kl_mask = (labels!=-100) ^ ((reward_mask>0) & (reward_mask<0.01)) 494 | kl_mask = (reward_mask!=0) 495 | 496 | # output_logprob = torch.gather(log_prob, 2, labels_select[...,None]).squeeze(2) 497 | output_entropy = logits_to_entropy(logits) 498 | # lm_loss = -1. * output_logprob 499 | 500 | output_dic={'logits':output_logits,'loss':lm_loss,'logprob':output_logprob, 501 | 'entropy':output_entropy,'kl_mask':kl_mask} 502 | 503 | return output_dic 504 | 505 | 506 | 507 | 508 | -------------------------------------------------------------------------------- /Sentiment/data.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import json 3 | import linecache 4 | import os 5 | import pickle 6 | import re 7 | import socket 8 | import string 9 | from collections import Counter 10 | from logging import getLogger 11 | from pathlib import Path 12 | from typing import Callable, Dict, Iterable, List 13 | from transformers import AutoModel,AutoTokenizer,AutoModelForSequenceClassification 14 | 15 | import torch 16 | from torch.utils.data import Dataset 17 | from tqdm import tqdm 18 | # import jsonlines 19 | 20 | class Evaluation(): 21 | def __init__(self): 22 | self.tokenizer = AutoTokenizer.from_pretrained('cardiffnlp/twitter-roberta-base-sentiment-latest') 23 | self.model = AutoModelForSequenceClassification.from_pretrained( 24 | 'cardiffnlp/twitter-roberta-base-sentiment-latest') 25 | 26 | def eval(self,text,target='positive'): 27 | inputs = self.tokenizer(text,return_tensors='pt',padding=True) 28 | # inputs = {k:v.to('cuda') for k,v in inputs.items()} 29 | output = self.model(**inputs) 30 | predicted_class_id = output.logits.argmax(-1) 31 | labels = [self.model.config.id2label[i] for i in predicted_class_id.tolist()] 32 | nums = [1 for i in labels if i.lower()==target] 33 | return sum(nums)/len(labels) 34 | 35 | def score(self,text,target='POSITIVE'): 36 | inputs = self.tokenizer(text, return_tensors='pt', padding=True) 37 | output = self.model(**inputs) 38 | #todo check whether has been softmax? 
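# Note (added, answering the TODO above): AutoModelForSequenceClassification returns raw,
# unnormalized logits, so the torch.softmax call below is required; the resulting score
# logits[:, pid] - logits[:, nid] is the positive-minus-negative probability and lies in [-1, 1].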
39 | pid = self.model.config.label2id['positive'] 40 | nid = self.model.config.label2id['negative'] 41 | logits = torch.softmax(output.logits,dim=-1) 42 | scores = logits[:,pid] - logits[:,nid] 43 | return scores 44 | 45 | 46 | class TextDataset(Dataset): 47 | def __init__( 48 | self, 49 | tokenizer, 50 | data_dir, 51 | max_length, 52 | ): 53 | super().__init__() 54 | 55 | self.src_file = Path(data_dir) 56 | self.src_lens = self.get_char_lens(self.src_file) 57 | self.max_source_length = max_length 58 | self.tokenizer = tokenizer 59 | self.tokenizer.padding_side = "left" 60 | 61 | def __len__(self): 62 | return len(self.src_lens) 63 | 64 | def __getitem__(self, index): 65 | index = index + 1 # linecache starts at 1 66 | source_line = linecache.getline(str(self.src_file), index).rstrip("\n") # +self.tokenizer.bos_token 67 | source_line = source_line.replace("xxx", '') 68 | 69 | res_input = self.tokenizer.encode_plus(source_line, max_length=self.max_source_length, return_tensors="pt", 70 | truncation=True, padding="max_length") 71 | return [res_input["input_ids"], res_input["attention_mask"]] 72 | 73 | @staticmethod 74 | def get_char_lens(data_file): 75 | return [len(x) for x in Path(data_file).open().readlines()] 76 | 77 | 78 | class ToxicPrompt(Dataset): 79 | def __init__( 80 | self, 81 | tokenizer, 82 | data_dir, 83 | max_length, 84 | n_obs=None, 85 | prefix="", 86 | ): 87 | super().__init__() 88 | 89 | self.src_file = data_dir 90 | 91 | self.prompts = [] 92 | with open(str(self.src_file), "r+", encoding="utf8") as f: 93 | for item in jsonlines.Reader(f): 94 | prompt = item["prompt"]["text"] 95 | self.prompts.append(prompt) 96 | 97 | self.tokenizer = tokenizer 98 | self.max_lens = max_length 99 | self.tokenizer.padding_side = "left" 100 | 101 | def __len__(self): 102 | return len(self.prompts) 103 | 104 | def __getitem__(self, index): 105 | index = index # linecache starts at 1 106 | source_line = self.prompts[index].rstrip("\n") 107 | source_line = source_line.replace("xxx", '') 108 | 109 | res = self.tokenizer.encode_plus(source_line, max_length=self.max_lens, return_tensors="pt", truncation=True, 110 | padding="max_length") 111 | 112 | return (res["input_ids"], res["attention_mask"]) 113 | 114 | @staticmethod 115 | def get_char_lens(data_file): 116 | return [len(x) for x in Path(data_file).open().readlines()] 117 | 118 | 119 | class SentimentPrompt(Dataset): 120 | 121 | def __init__( 122 | self, 123 | tokenizer, 124 | data_dir, 125 | max_length, 126 | prompt_type="negative", 127 | n_obs=None, 128 | prefix="", 129 | ): 130 | super().__init__() 131 | 132 | self.src_file = data_dir + "/" + str(prompt_type) + '_prompts.jsonl' 133 | 134 | self.prompts = [] 135 | with open(str(self.src_file), "r+", encoding="utf8") as f: 136 | for item in jsonlines.Reader(f): 137 | prompt = item["prompt"]["text"] 138 | self.prompts.append(prompt) 139 | 140 | self.tokenizer = tokenizer 141 | self.max_lens = max_length 142 | self.tokenizer.padding_side = "left" 143 | 144 | def __len__(self): 145 | return len(self.prompts) 146 | 147 | def __getitem__(self, index): 148 | index = index # linecache starts at 1 149 | source_line = self.prompts[index].rstrip("\n") 150 | source_line = source_line.replace("xxx", '') 151 | 152 | assert source_line, f"empty source line for index {index}" 153 | 154 | res = self.tokenizer.encode_plus(source_line, max_length=self.max_lens, return_tensors="pt", truncation=True, 155 | padding="max_length") 156 | 157 | return (res["input_ids"], res["attention_mask"]) 158 | 159 | @staticmethod 
160 | def get_char_lens(data_file): 161 | return [len(x) for x in Path(data_file).open().readlines()] 162 | 163 | 164 | class DetoxicDataset(Dataset): 165 | def __init__( 166 | self, 167 | tokenizer, 168 | data_dir, 169 | max_length, 170 | type_path="train", 171 | n_obs=None, 172 | src_lang=None, 173 | tgt_lang=None, 174 | prefix="", 175 | label_token={} 176 | ): 177 | super().__init__() 178 | 179 | self.src_file = Path(data_dir).joinpath(type_path + ".src") 180 | self.tgt_file = Path(data_dir).joinpath(type_path + ".tgt") 181 | 182 | self.src_lens = self.get_char_lens(self.src_file) 183 | self.max_source_length = max_length 184 | self.max_target_length = max_length 185 | 186 | self.label_token = label_token 187 | 188 | assert min(self.src_lens) > 0, f"found empty line in {self.src_file}" 189 | 190 | self.tokenizer = tokenizer 191 | self.prefix = prefix 192 | 193 | if n_obs is not None: 194 | self.src_lens = self.src_lens[:n_obs] 195 | self.src_lang = src_lang 196 | self.tgt_lang = tgt_lang 197 | self.tokenizer.padding_side = "left" 198 | 199 | def token_wrapper(args, token): 200 | if 'roberta' in args.model_name or 'gpt' in args.model_name or 'megatron' in args.model_name: 201 | return 'Ġ' + token 202 | else: 203 | return token 204 | 205 | def __len__(self): 206 | return len(self.src_lens) 207 | 208 | def __getitem__(self, index): 209 | index = index + 1 # linecache starts at 1 210 | source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip( 211 | "\n") # +self.tokenizer.bos_token 212 | source_line = source_line.replace("xxx", '') 213 | 214 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") 215 | tgt_line = str(tgt_line) 216 | if "1" in tgt_line: 217 | tgt_line = torch.tensor(self.tokenizer.encode(self.label_token['positive'])) 218 | else: 219 | tgt_line = torch.tensor(self.tokenizer.encode(self.label_token['negative'])) 220 | 221 | assert source_line, f"empty source line for index {index}" 222 | assert tgt_line, f"empty tgt line for index {index}" 223 | 224 | res_input = self.tokenizer.encode_plus(source_line, max_length=self.max_source_length, return_tensors="pt", 225 | truncation=True, padding="max_length") 226 | 227 | return [res_input["input_ids"], res_input["attention_mask"], tgt_line] 228 | 229 | @staticmethod 230 | def get_char_lens(data_file): 231 | return [len(x) for x in Path(data_file).open().readlines()] 232 | 233 | 234 | class GPT2Label(Dataset): 235 | def __init__(self,tokenizer,data_dir,max_length): 236 | self.device = 'cuda' 237 | self.model = AutoModel.from_pretrained('gpt2').to(self.device) 238 | self.classify = Evaluation() 239 | self.pad_token_id = tokenizer.pad_token_id 240 | self.tokenizer = tokenizer 241 | self.data_dir = data_dir 242 | self.max_source_length = max_length 243 | self.data = self.process_data() 244 | 245 | def get_q(self,input_ids): 246 | input_ids = input_ids.to(self.device) 247 | output = self.model(input_ids = input_ids, attention_mask = (input_ids != self.pad_token_id)) 248 | logits = torch.matmul(output.last_hidden_state[:,-1,:],self.model.wte.weight.T) 249 | s_logits = logits.sort(-1,descending=True) 250 | ids = s_logits.indices[...,:32] 251 | possibility = s_logits.values[...,:32] 252 | batch_input = torch.cat([input_ids.repeat(32,1),ids.T],dim=-1) 253 | batch_text = [self.tokenizer.decode(i) for i in batch_input] 254 | batch_score = self.classify.score(batch_text) 255 | q_value = batch_score * torch.softmax(possibility,-1).to('cpu') 256 | q_value = q_value.sum() 257 | return q_value 258 | 259 | def 
process_data(self): 260 | with open(Path(self.data_dir).joinpath("train.src")) as f: 261 | all_text = f.readlines() 262 | 263 | data=[] 264 | for raw_text in tqdm(all_text[:8],desc='Calculating Q-values:'): 265 | # text_list = raw_text.strip().split() 266 | encoding = self.tokenizer.encode_plus(raw_text,return_tensors="pt") 267 | input_ids,attn_mask = encoding['input_ids'],encoding['attention_mask'] 268 | q_list = [] 269 | q_len = min(input_ids.shape[-1],20) 270 | pad_len = 20 271 | for ilen in range(q_len): 272 | assert 0 not in attn_mask 273 | q_value = self.get_q(input_ids[:,:ilen+1]) 274 | q_list.append(q_value.item()) 275 | assert len(q_list)== q_len 276 | if input_ids.shape[-1]10:#todo it's 10 for quark 325 | for ilen in range(10,len(text_list)): 326 | text = ' '.join(text_list[:ilen+1]) 327 | self.dataset.append((text,label)) 328 | else: 329 | self.dataset.append((text,label)) 330 | # self.dataset.append((raw_text.strip(),label)) 331 | else: 332 | self.dataset = [(text,'') for text in data_dir] 333 | 334 | # self.src_lens = self.get_char_lens(self.src_file) 335 | self.src_lens = self.get_char_lens([x[0] for x in self.dataset]) 336 | self.max_source_length = max_length 337 | self.max_target_length = max_length 338 | 339 | self.label_token = label_token 340 | 341 | assert min(self.src_lens) > 0, f"found empty line in {self.src_file}" 342 | 343 | self.tokenizer = tokenizer 344 | self.prefix = prefix 345 | 346 | if n_obs is not None: 347 | self.src_lens = self.src_lens[:n_obs] 348 | self.src_lang = src_lang 349 | self.tgt_lang = tgt_lang 350 | self.tokenizer.padding_side = "left" 351 | 352 | def token_wrapper(args, token): 353 | if 'roberta' in args.model_name or 'gpt' in args.model_name or 'megatron' in args.model_name: 354 | return 'Ġ' + token 355 | else: 356 | return token 357 | 358 | def __len__(self): 359 | return len(self.src_lens) 360 | 361 | def __getitem__(self, index): 362 | if type(self.data_dir) == list: 363 | return [torch.tensor(self.dataset[index][0]), torch.tensor(self.dataset[index][0])!=self.tokenizer.pad_token_id, 0] 364 | # index = index + 1 # linecache starts at 1 365 | # source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip( 366 | # "\n") # +self.tokenizer.bos_token 367 | # 368 | # tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") 369 | source_line,tgt_line = self.dataset[index] 370 | 371 | if "positive" in tgt_line: 372 | tgt_line = torch.tensor(self.tokenizer.encode(self.label_token['positive'])) 373 | elif "negative" in tgt_line: 374 | tgt_line = torch.tensor(self.tokenizer.encode(self.label_token['negative'])) 375 | else: 376 | raise ValueError 377 | 378 | assert source_line, f"empty source line for index {index}" 379 | # assert tgt_line, f"empty tgt line for index {index}" 380 | 381 | res_input = self.tokenizer.encode_plus(source_line, max_length=self.max_source_length, return_tensors="pt", 382 | truncation=True, padding="max_length") 383 | 384 | return [res_input["input_ids"], res_input["attention_mask"], tgt_line] 385 | 386 | @staticmethod 387 | def get_char_lens(data_file): 388 | return [len(x) for x in data_file] 389 | # return [len(x) for x in Path(data_file).open().readlines()] 390 | 391 | class Classification_Dataset_double_sent(Dataset): 392 | def __init__( 393 | self, 394 | tokenizer, 395 | data_dir, 396 | max_length, 397 | type_path="train", 398 | n_obs=None, 399 | src_lang=None, 400 | tgt_lang=None, 401 | prefix="", 402 | label_token={} 403 | ): 404 | super().__init__() 405 | 406 | # self.src_file = 
Path(data_dir).joinpath(type_path + ".src") 407 | # self.tgt_file = Path(data_dir).joinpath(type_path + ".tgt") 408 | # transfer to fudge form 409 | self.data_dir = data_dir 410 | if type(data_dir) != list: 411 | with open(Path(data_dir).joinpath(type_path + ".src")) as f: 412 | all_text = f.readlines() 413 | with open(Path(data_dir).joinpath(type_path + ".tgt")) as f: 414 | labels = f.readlines() 415 | assert len(all_text)==len(labels) 416 | self.dataset=[] 417 | for raw_text,label in zip(all_text,labels): 418 | text_list=raw_text.strip().split() 419 | # if len(text_list)>10:#todo it's 10 for quark 420 | # for ilen in range(10,len(text_list)): 421 | # text = ' '.join(text_list[:ilen+1]) 422 | # self.dataset.append((text,label)) 423 | # else: 424 | # self.dataset.append((text,label)) 425 | self.dataset.append((raw_text.strip(),label)) 426 | else: 427 | self.dataset = [(text,'') for text in data_dir] 428 | 429 | # self.src_lens = self.get_char_lens(self.src_file) 430 | self.src_lens = self.get_char_lens([x[0] for x in self.dataset]) 431 | self.max_source_length = max_length 432 | self.max_target_length = max_length 433 | 434 | self.label_token = label_token 435 | 436 | assert min(self.src_lens) > 0, f"found empty line in {self.src_file}" 437 | 438 | self.tokenizer = tokenizer 439 | self.prefix = prefix 440 | 441 | if n_obs is not None: 442 | self.src_lens = self.src_lens[:n_obs] 443 | self.src_lang = src_lang 444 | self.tgt_lang = tgt_lang 445 | self.tokenizer.padding_side = "left" 446 | 447 | def token_wrapper(args, token): 448 | if 'roberta' in args.model_name or 'gpt' in args.model_name or 'megatron' in args.model_name: 449 | return 'Ġ' + token 450 | else: 451 | return token 452 | 453 | def __len__(self): 454 | return len(self.src_lens) 455 | 456 | def __getitem__(self, index): 457 | if type(self.data_dir) == list: 458 | return [torch.tensor(self.dataset[index][0]), torch.tensor(self.dataset[index][0])!=self.tokenizer.pad_token_id, 0] 459 | # index = index + 1 # linecache starts at 1 460 | # source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip( 461 | # "\n") # +self.tokenizer.bos_token 462 | # 463 | # tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") 464 | source_line,tgt_line = self.dataset[index] 465 | 466 | if "positive" in tgt_line: 467 | tgt_line = torch.tensor(self.tokenizer.encode(self.label_token['positive'])) 468 | elif "negative" in tgt_line: 469 | tgt_line = torch.tensor(self.tokenizer.encode(self.label_token['negative'])) 470 | else: 471 | raise ValueError 472 | 473 | assert source_line, f"empty source line for index {index}" 474 | # assert tgt_line, f"empty tgt line for index {index}" 475 | 476 | res_input = self.tokenizer.encode_plus(source_line, max_length=self.max_source_length, return_tensors="pt", 477 | truncation=True, padding="max_length") 478 | 479 | return [res_input["input_ids"], res_input["attention_mask"], tgt_line] 480 | 481 | @staticmethod 482 | def get_char_lens(data_file): 483 | return [len(x) for x in data_file] 484 | # return [len(x) for x in Path(data_file).open().readlines()] 485 | 486 | 487 | class Classification_Dataset_label_num_3(Dataset): 488 | def __init__( 489 | self, 490 | tokenizer, 491 | data_dir, 492 | max_length, 493 | type_path="train", 494 | n_obs=None, 495 | src_lang=None, 496 | tgt_lang=None, 497 | prefix="", 498 | label_token={} 499 | ): 500 | super().__init__() 501 | 502 | self.src_file = Path(data_dir).joinpath(type_path + ".src") 503 | self.tgt_file = Path(data_dir).joinpath(type_path + 
".tgt") 504 | # transfer to fudge form 505 | 506 | self.data_dir = data_dir 507 | if type(data_dir) != list: 508 | dataset = json.load(open(self.data_dir)) 509 | self.dataset=[] 510 | for label,raw_text in dataset: 511 | if 'positive' in label: 512 | label = 'positive' 513 | elif 'negative' in label: 514 | label = 'negative' 515 | else: 516 | assert label == 'neutral' 517 | self.dataset.append((raw_text,label)) 518 | else: 519 | self.dataset = [(text,'') for text in data_dir] 520 | 521 | # self.src_lens = self.get_char_lens(self.src_file) 522 | self.src_lens = self.get_char_lens([x[0] for x in self.dataset]) 523 | self.max_source_length = max_length 524 | self.max_target_length = max_length 525 | 526 | self.label_token = label_token 527 | 528 | assert min(self.src_lens) > 0, f"found empty line in {self.src_file}" 529 | 530 | self.tokenizer = tokenizer 531 | self.prefix = prefix 532 | 533 | if n_obs is not None: 534 | self.src_lens = self.src_lens[:n_obs] 535 | self.src_lang = src_lang 536 | self.tgt_lang = tgt_lang 537 | self.tokenizer.padding_side = "left" 538 | 539 | def token_wrapper(args, token): 540 | if 'roberta' in args.model_name or 'gpt' in args.model_name or 'megatron' in args.model_name: 541 | return 'Ġ' + token 542 | else: 543 | return token 544 | 545 | def __len__(self): 546 | return len(self.src_lens) 547 | 548 | def __getitem__(self, index): 549 | if type(self.data_dir) == list: 550 | return [torch.tensor(self.dataset[index][0]), torch.tensor(self.dataset[index][0])!=self.tokenizer.pad_token_id, 0] 551 | # index = index + 1 # linecache starts at 1 552 | # source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip( 553 | # "\n") # +self.tokenizer.bos_token 554 | # 555 | # tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") 556 | source_line,tgt_line = self.dataset[index] 557 | 558 | if "positive" in tgt_line: 559 | tgt_line = torch.tensor(self.tokenizer.encode(self.label_token['positive'])) 560 | elif 'negative' in tgt_line: 561 | tgt_line = torch.tensor(self.tokenizer.encode(self.label_token['negative'])) 562 | else: 563 | tgt_line = torch.tensor(self.tokenizer.encode(self.label_token['neutral'])) 564 | 565 | assert source_line, f"empty source line for index {index}" 566 | # assert tgt_line, f"empty tgt line for index {index}" 567 | 568 | res_input = self.tokenizer.encode_plus(source_line, max_length=self.max_source_length, return_tensors="pt", 569 | truncation=True, padding="max_length") 570 | 571 | return [res_input["input_ids"], res_input["attention_mask"], tgt_line] 572 | 573 | @staticmethod 574 | def get_char_lens(data_file): 575 | return [len(x) for x in data_file] 576 | # return [len(x) for x in Path(data_file).open().readlines()] 577 | 578 | 579 | class Sentiment_Suffix(Dataset): 580 | def __init__( 581 | self, 582 | tokenizer, 583 | data_dir, 584 | max_length, 585 | task_type="positive", 586 | n_obs=None, 587 | src_lang=None, 588 | tgt_lang=None, 589 | prefix="", 590 | label_token={} 591 | ): 592 | super().__init__() 593 | 594 | self.src_file = data_dir 595 | 596 | self.src_lens = self.get_char_lens(self.src_file) 597 | self.max_source_length = max_length 598 | 599 | self.label_token = label_token 600 | self.task_type = task_type 601 | 602 | assert min(self.src_lens) > 0, f"found empty line in {self.src_file}" 603 | 604 | self.tokenizer = tokenizer 605 | self.prefix = prefix 606 | 607 | if n_obs is not None: 608 | self.src_lens = self.src_lens[:n_obs] 609 | self.src_lang = src_lang 610 | self.tgt_lang = tgt_lang 611 | 
self.tokenizer.padding_side = "left" 612 | 613 | def token_wrapper(args, token): 614 | if 'roberta' in args.model_name or 'gpt' in args.model_name or 'megatron' in args.model_name: 615 | return 'Ġ' + token 616 | else: 617 | return token 618 | 619 | def __len__(self): 620 | return len(self.src_lens) 621 | 622 | def __getitem__(self, index): 623 | index = index + 1 # linecache starts at 1 624 | source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip( 625 | "\n") # +self.tokenizer.bos_token 626 | 627 | tgt_line = torch.tensor(self.tokenizer.encode(self.label_token[self.task_type])) 628 | 629 | if len(source_line) < 2: 630 | source_line = "Hello world! Today is nice!" 631 | 632 | res_input = self.tokenizer.encode_plus(source_line, max_length=self.max_source_length, return_tensors="pt", 633 | truncation=True, padding="max_length") 634 | 635 | return [res_input["input_ids"], res_input["attention_mask"], tgt_line] 636 | 637 | @staticmethod 638 | def get_char_lens(data_file): 639 | return [len(x) for x in Path(data_file).open().readlines()] 640 | -------------------------------------------------------------------------------- /token_main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import json 4 | import time 5 | import logging 6 | import random 7 | import argparse 8 | import numpy as np 9 | import itertools 10 | from typing import List 11 | from datetime import datetime 12 | from tqdm import tqdm 13 | import torch.nn.functional as F 14 | from torch.utils.data import Dataset, DataLoader 15 | from torch.optim import Adam, Optimizer 16 | from torch.optim.lr_scheduler import LambdaLR 17 | # from torch.utils.tensorboard import SummaryWriter 18 | from transformers import get_linear_schedule_with_warmup,AutoModelForSequenceClassification,AutoTokenizer 19 | 20 | from arguments import get_args 21 | # from policy_marian import Policy 22 | from data_pool import DataPool 23 | # from reward import Reward, reward_to_toxicity 24 | # from fudge_reward import RewardScore 25 | from utils.utils import ensure_dir, ceil_div, reduce_mean, reduce_sum, distinctness 26 | from token_gpt2 import Distill_Tuning 27 | from Sentiment.main_disc import Classifier 28 | # from nltk import sent_tokenize 29 | 30 | logging.basicConfig(level=os.environ.get("LOGLEVEL", "DEBUG")) 31 | log = logging.getLogger(__name__) 32 | 33 | # def process_openwebtext_into_train_prompts: 34 | # from datasets import load_dataset 35 | # dataset = load_dataset('openwebtext')['train'][:50000] 36 | # dataset = [i['text'] for i in dataset] 37 | # neutral_num,negative_num,positive_num = 10000,10000,10000 38 | # model = AutoModelForSequenceClassification.from_pretrained( 39 | # 'cardiffnlp/twitter-roberta-base-sentiment-latest') 40 | # token = AutoTokenizer.from_pretrained('cardiffnlp/twitter-roberta-base-sentiment-latest') 41 | # input_ids_list,attention_mask_list=[],[] 42 | # for text in dataset: 43 | # t = ' '.join(text.split()[:20]) 44 | # inp = token(t,return_tensors='pt') 45 | # input_ids_list.append('inp') 46 | class PromptDataset(Dataset): 47 | # def __init__(self, args, mode): 48 | # # self.prompts = [json.loads(s.strip())["prompt"]["text"].strip() for s in open(path, 'r').readlines()][:100] 49 | # num_reference = 2 50 | # 51 | # valid_es = [] 52 | # train_es = [] 53 | # test_es = [] 54 | # train_en = [] 55 | # valid_en = [[]] * num_reference 56 | # test_en = [[]] * num_reference 57 | # 58 | # with 
open('../NADO/fisher-callhome-corpus/corpus/ldc/fisher_test.es', 'r') as f: 59 | # for line in f: 60 | # test_es.append(line.strip()) 61 | # 62 | # with open('../NADO/fisher-callhome-corpus/corpus/ldc/fisher_train.es', 'r') as f: 63 | # for line in f: 64 | # train_es.append(line.strip()) 65 | # 66 | # with open('../NADO/fisher-callhome-corpus/corpus/ldc/fisher_dev.es', 'r') as f: 67 | # for line in f: 68 | # valid_es.append(line.strip()) 69 | # 70 | # with open('../NADO/fisher-callhome-corpus/corpus/ldc/fisher_train.en', 'r') as f: 71 | # for line in f: 72 | # train_en.append(line.strip()) 73 | # 74 | # for i in range(num_reference): 75 | # with open('../NADO/fluent-fisher/noids/dev.noid.cleaned_%d' % (i), 'r') as f: 76 | # for line in f: 77 | # # clean_line = line.strip().split()[1:] 78 | # # clean_line = " ".join(clean_line) 79 | # valid_en[i].append(line.strip()) 80 | # 81 | # for i in range(num_reference): 82 | # with open('../NADO/fluent-fisher/noids/test.noid.cleaned_%d' % (i), 'r') as f: 83 | # for line in f: 84 | # # clean_line = line.strip().split()[1:] 85 | # # clean_line = " ".join(clean_line) 86 | # test_en[i].append(line.strip()) 87 | # 88 | # if mode=='train': 89 | # self.dataset = [i for i in train_es] 90 | # self.ref = train_en 91 | # self.dataset = self.dataset[:320] 92 | # elif mode =='valid': 93 | # self.dataset = [i for i in valid_es] 94 | # self.ref = valid_en 95 | # else: 96 | # self.dataset = [i for i in test_es] 97 | # self.ref = test_en 98 | def __init__(self,mode='positive',train=True): 99 | if train==True: 100 | dataset = json.load(open('Sentiment/data/train_prompts_v1.json')) 101 | dataset = random.sample(dataset,12000) 102 | self.dataset = [i for _ in range(1) for i in dataset] 103 | else: 104 | path = 'Sentiment/data/sentiment_prompts-10k/' + mode + '_prompts.jsonl' 105 | with open(path) as f: 106 | dataset = [json.loads(line)['prompt']['text'] for line in f] 107 | self.dataset = dataset 108 | 109 | def __len__(self): 110 | return len(self.dataset) 111 | 112 | def __getitem__(self, index): 113 | return self.dataset[index] 114 | 115 | 116 | class PromptCollator(object): 117 | def __init__(self, tokenizer, max_source_length): 118 | self.max_source_length = max_source_length 119 | self.tokenizer = tokenizer 120 | 121 | def __call__(self, sequences): 122 | res_input = self.tokenizer.batch_encode_plus(sequences, max_length=self.max_source_length, return_tensors="pt", 123 | truncation=True, padding="max_length") 124 | return res_input['input_ids'],res_input['attention_mask'] 125 | 126 | class SequenceDataset(Dataset): 127 | def __init__(self, data_pool: DataPool): 128 | self.ids, self.masks, self.cat_mask = data_pool.get_data() 129 | 130 | def __len__(self): 131 | return len(self.ids) 132 | 133 | def __getitem__(self, idx): 134 | return {'input_ids': self.ids[idx], 135 | 'output_mask':self.masks[idx], 136 | 'cat_mask': self.cat_mask[idx] 137 | } 138 | 139 | class SequenceCollator(object): 140 | def __init__(self, args,tokenizer): 141 | self.tokenizer = tokenizer 142 | self.device = args.device 143 | # def __call__(self, sequences): 144 | # queries = [sequence['query'] for sequence in sequences] 145 | # responses = [sequence['response'] for sequence in sequences] 146 | # # cat_ids = [self.tokenizer.convert_tokens_to_ids(sequence['cat_tokens']) for sequence in sequences] 147 | # 148 | # query_encodings_dict = self.tokenizer(queries, return_tensors="pt", padding=True) 149 | # query_input_ids = query_encodings_dict['input_ids'] 150 | # query_mask = 
query_encodings_dict['attention_mask'] 151 | # # query_input_ids = torch.cat([query_input_ids.new(cat_ids)[:, None], query_input_ids], dim=1) 152 | # # query_mask = torch.cat([query_mask.new([1] * len(query_mask))[:, None], query_mask], dim=1) 153 | # 154 | # response_encodings_dict = self.tokenizer(responses, return_tensors="pt", padding=True) 155 | # response_input_ids = response_encodings_dict['input_ids'] 156 | # response_mask = response_encodings_dict['attention_mask'] 157 | # 158 | # cat_mask = [sequence['cat_mask'] for sequence in sequences] 159 | # mask_tensor,len_list = self.pad_tensor(cat_mask) 160 | # assert len_list == response_mask.sum(-1).to_list() 161 | # 162 | # return query_input_ids, query_mask, response_input_ids, response_mask, mask_tensor 163 | # 164 | def __call__(self,sequences): 165 | input_ids = torch.tensor([sequence['input_ids'] for sequence in sequences],device=self.device) 166 | input_mask = input_ids != self.tokenizer.eos_token_id 167 | output_mask = torch.tensor([sequence['output_mask'] for sequence in sequences],device=self.device) 168 | reward_mask = torch.tensor([sequence['cat_mask'] for sequence in sequences],device=self.device) 169 | # punish_mask = torch.tensor([sequence['cat_mask'] for sequence in sequences],device=self.device) == -1 170 | return input_ids,input_mask,output_mask,reward_mask 171 | 172 | def pad_tensor(self,mask_list): 173 | len_list = [len(mask) for mask in mask_list] 174 | max_len = max(len_list) 175 | tensor_stack = [mask+[0]*(max_len-len(mask)) for mask in mask_list] 176 | padded_tensor = torch.tensor(tensor_stack) 177 | return padded_tensor,len_list 178 | 179 | class FixedController: 180 | def __init__(self, coef): 181 | self.value = coef 182 | 183 | def update(self, current, n_steps, lower_bound): 184 | pass 185 | 186 | 187 | class AdaptiveController: 188 | def __init__(self, init_coef, target, horizon): 189 | self.value = init_coef 190 | self.target = target 191 | self.horizon = horizon 192 | 193 | def update(self, current, n_steps, lower_bound): 194 | proportional_error = np.clip(current / self.target - 1, -0.2, 0.2) 195 | if lower_bound: 196 | mult = 1 + proportional_error * n_steps / self.horizon 197 | else: 198 | mult = 1 - proportional_error * n_steps / self.horizon 199 | self.value *= mult 200 | 201 | 202 | class Evaluation(): 203 | def __init__(self): 204 | self.tokenizer = AutoTokenizer.from_pretrained('/home/lwd/distilbert-base-uncased-finetuned-sst-2-english') 205 | self.model = AutoModelForSequenceClassification.from_pretrained( 206 | '/home/lwd/distilbert-base-uncased-finetuned-sst-2-english') 207 | 208 | def eval(self,text,target='positive'): 209 | inputs = self.tokenizer(text,return_tensors='pt',padding=True,truncation=True) 210 | output = self.model(**inputs) 211 | predicted_class_id = output.logits.argmax(-1) 212 | labels = [self.model.config.id2label[i] for i in predicted_class_id.tolist()] 213 | nums = [1 for i in labels if i.lower()==target] 214 | return sum(nums),len(labels) 215 | 216 | def score(self,text,target='POSITIVE'): 217 | inputs = self.tokenizer(text, return_tensors='pt', padding=True) 218 | output = self.model(**inputs) 219 | #todo check whether has been softmax? 
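# Note (added, answering the TODO above): output.logits here are the classifier's raw outputs
# (no softmax is applied by the model), so this score is an unnormalized positive-class logit;
# apply torch.softmax(output.logits, dim=-1) first if a probability in [0, 1] is preferred.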
220 | id = self.model.config.label2id[target] 221 | scores = output.logits[:,id] 222 | return scores 223 | 224 | class ConditionTrainer: 225 | def __init__(self, 226 | params: argparse.Namespace, 227 | policy, 228 | ref_policy, 229 | data_pool: DataPool, 230 | train_dataloader: DataLoader, 231 | val_dataloader: DataLoader, 232 | optimizer: Optimizer, 233 | scheduler: LambdaLR): 234 | 235 | self.params = params 236 | self.policy = policy 237 | self.ref_policy = ref_policy 238 | self.data_pool = data_pool 239 | # self.score_model = score_model 240 | self.optimizer = optimizer 241 | self.scheduler = scheduler 242 | self.train_dataloader = train_dataloader 243 | self.val_dataloader = val_dataloader 244 | # self.writer = SummaryWriter() 245 | self.q_record=[] 246 | 247 | if self.params.adaptive_kl: 248 | self.kl_ctl = AdaptiveController(self.params.kl_coef, self.params.target_kl, self.params.horizon) 249 | else: 250 | self.kl_ctl = FixedController(self.params.kl_coef) 251 | self.kl_loss = torch.nn.KLDivLoss(reduction="none") 252 | 253 | if self.params.adaptive_entropy: 254 | self.entropy_ctl = AdaptiveController(self.params.entropy_coef, self.params.target_entropy, 255 | self.params.horizon) 256 | else: 257 | self.entropy_ctl = FixedController(self.params.entropy_coef) 258 | 259 | # self.tree_tokens = tree_tokens 260 | # self.best_cat = self.tree_tokens[0] 261 | # self.best_cat_id = self.policy.tokenizer.convert_tokens_to_ids(self.best_cat) 262 | # self.pos_id,self.neg_id,self.pad_id=special_ids 263 | # self.special_tokens_num = len(special_ids)-1 264 | 265 | self.sample_dataloader, self.sampler = None, None 266 | self.seq_collator = SequenceCollator(self.params,tokenizer=policy.tokenizer) 267 | self.classifier = Evaluation() 268 | self.best_correctness = 0 269 | self.best_distinct = [] 270 | 271 | 272 | def add_control_code(self, input_ids, attention_mask): 273 | input_ids = torch.cat([input_ids.new([self.pad_id] * len(input_ids))[:, None], input_ids], dim=1) 274 | pos_ids = torch.cat([input_ids.new([self.pos_id] * len(input_ids))[:, None], input_ids], dim=1) 275 | neg_ids = torch.cat([input_ids.new([self.neg_id] * len(input_ids))[:, None], input_ids], dim=1) 276 | attention_mask = torch.cat([attention_mask.new([1] * len(attention_mask))[:, None], attention_mask], dim=1) 277 | attention_mask = torch.cat([attention_mask.new([1] * len(attention_mask))[:, None], attention_mask], dim=1) 278 | return input_ids,neg_ids,attention_mask 279 | 280 | def decode(self, query_input_ids, response_input_ids=None): 281 | query = [self.policy.tokenizer.decode(p, skip_special_tokens=True, clean_up_tokenization_spaces=True) 282 | for p in query_input_ids] 283 | 284 | if response_input_ids is None: 285 | return query 286 | 287 | response = [self.policy.tokenizer.decode(r, skip_special_tokens=True, clean_up_tokenization_spaces=True) 288 | for r in response_input_ids] 289 | return query, response 290 | 291 | def sample(self, step): 292 | if step % self.params.sample_interval != 0: 293 | return 294 | log.info(f"[step {step}] Sampling ...") 295 | 296 | # prompts, responses = [], [] 297 | # q_inputs,d_inputs,d_len = [],[],[] 298 | text_list,input_list,mask_list,q_values=[],[],[],[] 299 | for i, batch in enumerate(tqdm(self.train_dataloader, total=len(self.train_dataloader), 300 | desc='Sampling from current policy')): 301 | input_ids, attention_mask = batch 302 | input_ids = input_ids.to(self.params.device) 303 | attention_mask = attention_mask.to(self.params.device) 304 | 305 | if step == 0: 306 | # rollouts = 
self.ref_policy.sample(input_ids=input_ids, attention_mask=attention_mask, top_p=self.params.top_p) 307 | rollouts = self.ref_policy.sample(prompts_ids=input_ids, max_length=20+self.params.max_prompt_length,mode=None) 308 | text,input_ids,mask,q_value = rollouts['text'],rollouts['input_ids'],rollouts['output_mask'],rollouts['q_values'] 309 | # prompt, response = rollouts['query/text'], rollouts['response/text'] 310 | # d_len.extend((rollouts['response/input_ids'][:, 1:] != self.ref_policy.tokenizer.pad_token_id).sum(-1).tolist()) 311 | else: 312 | # pos_ids,neg_ids,attention_mask = self.add_control_code(input_i 313 | # ds, attention_mask) 314 | rollouts = self.policy.sample(prompts_ids=input_ids, max_length=20+self.params.max_prompt_length, mode='positive') 315 | text, input_ids, mask,q_value = rollouts['text'], rollouts['input_ids'], rollouts['output_mask'],rollouts['q_values'] 316 | # prompt = rollouts['query/text'] 317 | # d_len.extend((rollouts['response/input_ids'][:, 2:] != self.policy.tokenizer.pad_token_id).sum(-1).tolist()) 318 | 319 | # prompts.extend(prompt) 320 | # responses.extend(response) 321 | text_list.extend(text) 322 | input_list.extend(input_ids.tolist()) 323 | mask_list.extend(mask.tolist()) 324 | q_values.extend(q_value.tolist()) 325 | #todo log_probs 326 | # input_ids,rewards,sen_mask = self.score_model.score(input_list,mode='positive') 327 | # aa=1 328 | # assert input_ids.tolist() == input_list 329 | self.data_pool.add(input_list, mask_list, q_values, pos=True) 330 | self.q_record.append(self.data_pool.r_limit) 331 | print(self.q_record) 332 | sample_dataset = SequenceDataset(data_pool=self.data_pool) 333 | self.sample_dataloader = DataLoader(sample_dataset, batch_size=self.params.batch_size, 334 | shuffle=False, drop_last=True, collate_fn=self.seq_collator) 335 | self.sampler = iter(self.sample_dataloader) 336 | 337 | def step(self, step_num): 338 | 339 | with torch.no_grad(): 340 | self.eval(step=step_num) 341 | self.sample(step=step_num) 342 | try: 343 | batch = next(self.sampler) 344 | assert len(batch[0]) == self.params.batch_size, 'insufficient batch' 345 | except (StopIteration, AssertionError): 346 | self.sampler = iter(self.sample_dataloader) 347 | batch = next(self.sampler) 348 | 349 | self.optimizer.zero_grad() 350 | ppo_loss = self.loss(step_num, *batch) 351 | ppo_loss.backward() 352 | if self.params.clip_grad: 353 | torch.nn.utils.clip_grad_norm_(self.policy.model.parameters(), self.params.max_grad_norm) 354 | self.optimizer.step() 355 | self.scheduler.step() 356 | 357 | 358 | def loss(self, step, input_ids, input_mask, output_mask, reward_tensor): 359 | self.policy.model.train() 360 | outputs = self.policy.forward_pass(input_ids, input_mask, output_mask, reward_tensor, mode='positive') 361 | lm_loss, logits,entropy = outputs['loss'],outputs['logits'],outputs['entropy'] 362 | # logits = outputs['response/logits'][:, :, :-len(self.special_tokens_num)] 363 | kl_mask = outputs['kl_mask'] 364 | 365 | with torch.no_grad(): 366 | ref_outputs = self.ref_policy.forward_pass(input_ids, input_mask, output_mask,mode=None) 367 | ref_logits = ref_outputs['logits'] 368 | pad_logits = torch.zeros(ref_logits.shape[0],self.policy.spell_length,ref_logits.shape[-1],device=self.params.device) 369 | ref_logits = torch.cat([pad_logits,ref_logits],dim=1) 370 | 371 | # kl = torch.sum(self.kl_loss(F.log_softmax(ref_logits, dim=-1), F.softmax(logits, dim=-1)), dim=-1) 372 | kl = torch.sum( 373 | torch.softmax(ref_logits, dim=-1) * (F.log_softmax(ref_logits, dim=-1) - 
F.log_softmax(logits, dim=-1)), 374 | dim=-1) 375 | 376 | loss = lm_loss + reduce_mean(- self.entropy_ctl.value * entropy, kl_mask) + reduce_mean(self.kl_ctl.value * kl, torch.ones_like(kl_mask)) 377 | 378 | # kl_loss = reduce_mean(self.kl_ctl.value * kl, kl_mask) 379 | # loss = lm_loss + kl_loss 380 | 381 | 382 | 383 | # queries = self.decode(input_ids) 384 | # self.print_samples(queries=queries, responses=responses, lm_loss=reduce_mean(lm_loss, masks, axis=1), 385 | # logprobs=logprobs, ref_logprobs=ref_logprobs, masks=masks, step=step) 386 | # self.print_samples(queries=queries, lm_loss=reduce_mean(lm_loss, masks, axis=1), 387 | # loss=loss,step=step) 388 | # r_loss = reduce_mean(lm_loss,masks) 389 | # r_kl_loss = reduce_mean(self.kl_ctl.value * kl,masks) 390 | if step % self.params.log_interval ==0: 391 | log.info(f"[step {step}] lm_loss={lm_loss:.4f}, kl={reduce_mean(kl,kl_mask):.4f},entropy={reduce_mean(- self.entropy_ctl.value * entropy, kl_mask):.4f}") 392 | 393 | 394 | return loss 395 | 396 | def record_step_stats(self, data): 397 | masks = data['masks'] 398 | stats = {} 399 | # kl = torch.sum(self.kl_loss(F.log_softmax(data['ref_logits'], dim=-1), F.softmax(data['logits'], dim=-1)), dim=-1) 400 | # mean_kl = torch.mean(reduce_sum(kl, masks, axis=1)) 401 | # mean_entropy = torch.mean(reduce_sum(-data['logprobs'], masks, axis=1)) 402 | # stats = { 403 | # 'objective/kl': mean_kl.item(), 404 | # } 405 | stats.update({ 406 | 'loss/total': data['total_loss'].item(), 407 | 'loss/kl': data['kl_loss'].item(), 408 | 'loss/lm': data['lm_loss'].item(), 409 | }) 410 | # stats = { 411 | # 'objective/kl': mean_kl.item(), 412 | # 'objective/entropy': mean_entropy.item(), 413 | # } 414 | # stats.update({ 415 | # 'loss/total': data['total_loss'].item(), 416 | # 'loss/kl': data['kl_loss'].item(), 417 | # 'loss/lm': data['lm_loss'].item(), 418 | # 'loss/entropy': data['entropy'].item(), 419 | # }) 420 | return stats 421 | 422 | def print_samples(self, queries, lm_loss, loss, step): 423 | if step % self.params.log_interval != 0: 424 | return 425 | # Log samples 426 | for i in range(min(3, len(queries))): 427 | # sample_kl = torch.sum((logprobs[i] - ref_logprobs[i]) * masks[i]).item() 428 | print(queries[i]) 429 | print(f" lm_loss = {lm_loss[i].item():+.2f}") 430 | # print(f" total_loss = {loss[i].item():+.2f}") 431 | # print(f" kl = {sample_kl:+.2f}") 432 | # print(f" total = {lm_loss[i].item() + self.params.kl_coef * sample_kl:+.2f}") 433 | 434 | def save(self, mode): 435 | # if step % self.params.save_interval != 0: 436 | # return 437 | torch.save({ 438 | 'prompt_encoder_pos': self.policy.prompt_encoder_pos.state_dict(), 439 | 'optimizer': self.optimizer.state_dict(), 440 | 'scheduler': self.scheduler.state_dict() 441 | }, f'{self.params.model_dir}/ckp_{mode}.pth') 442 | log.info(f"model checkpoint saved once") 443 | 444 | def load(self,load_dir): 445 | load_dic = torch.load(load_dir) 446 | self.policy.prompt_encoder_pos.load_state_dict(load_dic['prompt_encoder_pos']) 447 | self.optimizer.load_state_dict(load_dic['optimizer']) 448 | self.scheduler.load_state_dict(load_dic['scheduler']) 449 | log.info(f"Load Model Successfully from {load_dir}") 450 | 451 | def eval(self, step): 452 | if step % self.params.eval_interval != 0: 453 | return 454 | self.policy.model.eval() 455 | log.info(f"[step {step}] evaluating ...") 456 | generations, perplexities, toxicities = [], [], [] 457 | correct,count = 0,0 458 | for i, (input_ids, attention_mask) in enumerate(tqdm(self.val_dataloader)): 459 | with 
torch.no_grad(): 460 | input_ids = input_ids.to(self.params.device) 461 | attention_mask = attention_mask.to(self.params.device) 462 | # input_ids, attention_mask = self.add_control_code(input_ids, attention_mask) 463 | rollouts = self.policy.sample(prompts_ids=input_ids,max_length=20 + self.params.max_prompt_length,mode='positive',gen=True) 464 | cur_cast_mask = torch.ones_like(rollouts['input_ids'])[:,:-1] 465 | forward_inputs = {'x_hs':rollouts['input_ids'], 466 | 'att_mask':rollouts['input_ids']!=self.policy.tokenizer.pad_token_id, 467 | 'out_mask':rollouts['output_mask'], 468 | 'reward_mask':cur_cast_mask} 469 | outputs = self.policy.forward_pass(**forward_inputs,mode='positive',gen=True) 470 | ref_logprobs = outputs['logprob'] 471 | ## prompt = self.decode(rollouts['query/input_ids'][:, 1:]) 472 | # response = rollouts['response/text'] 473 | # score = self.score_model.get_reward(prompt, response, f'step{step}_eval{i}') 474 | # toxicity = [reward_to_toxicity(x) for x in score if x is not None] 475 | # toxicities.extend(toxicity) 476 | generations.extend(rollouts['text']) 477 | 478 | x1,x2 = self.classifier.eval(rollouts['text'],target=self.params.target_mode) 479 | correct += x1 480 | count += x2 481 | 482 | correctness = correct/count 483 | 484 | dist_1, dist_2, dist_3, dist_4 = distinctness(generations) 485 | log.info('*******************************') 486 | log.info(f" correctness = {correctness:+.4f}") 487 | log.info(f'dist-1={dist_1:.3f}, dist-2={dist_2:.3f}, dist-3={dist_3:.3f}, dist-4={dist_4:.3f}') 488 | log.info('***example***') 489 | log.info(generations[-1]) 490 | log.info(generations[-2]) 491 | log.info(generations[-3]) 492 | log.info(generations[-4]) 493 | log.info(generations[-5]) 494 | log.info(generations[-6]) 495 | log.info('******************************') 496 | 497 | result = f"cor={correctness:+.3f}+step={step}+dist={dist_1:.3f}-{dist_2:.3f}-{dist_3:.3f}-{dist_4:.3f}" 498 | # if correctness > self.best_correctness: 499 | self.save(result) 500 | self.best_correctness = correctness 501 | # self.writer.add_scalar('Evaluation/perplexity', ppl_score, step) 502 | # self.writer.add_scalar('Evaluation/toxicity', toxicity_score, step) 503 | # self.writer.add_scalar('Evaluation/Dist-1', dist_1, step) 504 | # self.writer.add_scalar('Evaluation/Dist-2', dist_2, step) 505 | # self.writer.add_scalar('Evaluation/Dist-3', dist_3, step) 506 | 507 | 508 | 509 | def main(): 510 | args = get_args() 511 | 512 | random.seed(args.seed) 513 | np.random.seed(args.seed) 514 | torch.manual_seed(args.seed) 515 | torch.cuda.manual_seed(args.seed) 516 | torch.cuda.manual_seed_all(args.seed) 517 | 518 | if args.cuda and torch.cuda.is_available() and args.cuda_deterministic: 519 | torch.backends.cudnn.deterministic = True 520 | torch.backends.cudnn.benchmark = False 521 | 522 | 523 | device= 'cuda' 524 | 525 | time = datetime.now() 526 | # date_time = time.strftime("%m-%d-%Y_%H:%M:%S") 527 | date_time = time.strftime("%m-%d-%Y") 528 | args.save_dir = os.path.join(args.output_dir, date_time) 529 | args.reward_dir = os.path.join(args.save_dir, 'reward') 530 | args.model_dir = os.path.join(args.save_dir, 'model-fudge-way') 531 | args.tensorboard_dir = os.path.join(args.save_dir, 'tensorboard') 532 | for d in [args.output_dir, args.save_dir, args.reward_dir, args.model_dir, args.tensorboard_dir]: 533 | ensure_dir(d) 534 | log.info(f'Write to output directory: {args.save_dir}') 535 | 536 | with open(os.path.join(args.save_dir, 'args.json'), 'w') as f: 537 | json.dump(args.__dict__, f, indent=2) 538 | 
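# Note (added): two Distill_Tuning instances are created below. `ref_policy` acts as the fixed
# reference model (queried with mode=None for the KL term in ConditionTrainer.loss), while only
# `policy`'s prompt-encoder parameters (names containing 'prompt') are handed to the optimizer.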
539 | 540 | log.info(f'Initializing models ...') 541 | 542 | label_token = {"positive": 'good', "negative": 'bad','neutral':'neutral'} 543 | ref_policy = Distill_Tuning(args,args.template,label_token) 544 | policy = Distill_Tuning(args, args.template, label_token) 545 | 546 | data_pool = DataPool() 547 | log.info(f'Initialization done!') 548 | 549 | prompt_collator = PromptCollator(tokenizer=ref_policy.tokenizer,max_source_length=args.max_prompt_length) 550 | train_dataset = PromptDataset(mode='neutral',train=True) 551 | train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False, drop_last=True, collate_fn=prompt_collator) 552 | log.info(f'Load train set with {len(train_dataset)} examples') 553 | 554 | val_dataset = PromptDataset(mode='neutral',train=False) 555 | val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size * 2, shuffle=False, collate_fn=prompt_collator) 556 | log.info(f'Load val set with {len(val_dataset)} examples') 557 | 558 | # set up optimizer and scheduler 559 | parameters2update = [para for name,para in policy.named_parameters() if 'prompt' in name] 560 | optimizer = Adam(parameters2update, lr=args.lr, eps=1e-5) 561 | args.total_steps = ceil_div(args.total_episodes, args.batch_size) 562 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.num_warmup_steps, num_training_steps=args.total_steps) 563 | 564 | # special_ids = [policy.tokenizer.get_vocab()['__pos__'],policy.tokenizer.get_vocab()['__neg__'],policy.tokenizer.pad_token_id] 565 | trainer = ConditionTrainer(params=args, policy=policy, ref_policy=ref_policy, data_pool=data_pool, 566 | train_dataloader=train_dataloader, val_dataloader=val_dataloader, 567 | optimizer=optimizer, scheduler=scheduler) 568 | 569 | 570 | for step_num in range(100000): 571 | if step_num>30000: 572 | trainer.save_result() 573 | break 574 | trainer.step(step_num) 575 | 576 | 577 | 578 | if __name__ == "__main__": 579 | main() 580 | -------------------------------------------------------------------------------- /Sentiment/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.utils.rnn import pad_sequence 3 | from os.path import join 4 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 5 | 6 | import re 7 | import datetime 8 | 9 | from transformers import AutoTokenizer 10 | 11 | from Sentiment.models import get_embedding_layer, create_model, _create_model 12 | from Sentiment.prompt_encoder import PromptEncoder 13 | # from models import get_embedding_layer, create_model, _create_model 14 | # from prompt_encoder import PromptEncoder 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | 20 | SMALL_CONST = 1e-10 21 | BIG_CONST = -1e15 22 | 23 | 24 | class PTuneForLAMA(torch.nn.Module): 25 | def __init__(self, args, template, label_token=None): 26 | super().__init__() 27 | self.args = args 28 | self.label_num = 2 29 | self.tokenizer = AutoTokenizer.from_pretrained(self.args.model_name_or_path) 30 | self.tokenizer.pad_token = self.tokenizer.eos_token 31 | 32 | # model setting 33 | self.model = create_model(self.args) 34 | # self.model.resize_token_embeddings(len(self.tokenizer)) 35 | self.model = self.model.to(self.args.device) 36 | # for param in self.model.parameters(): 37 | # param.requires_grad = self.args.use_lm_finetune 38 | 39 | self.template = template 40 | 41 | # get model's embeddings 42 | self.embeddings = self.model.get_input_embeddings() 43 
| 44 | self.label_token = label_token 45 | self.label_map = {} 46 | self.label_token_ids = {} 47 | 48 | for k, v in self.label_token.items(): 49 | print(k, v, self.tokenizer.convert_tokens_to_ids(v)) 50 | self.label_map[self.tokenizer.convert_tokens_to_ids(v)] = k 51 | 52 | self.label_token_ids[k] = self.tokenizer.convert_tokens_to_ids(v) 53 | 54 | # load prompt encoder 55 | self.hidden_size = self.embeddings.embedding_dim 56 | self.pseudo_token_id = self.tokenizer.convert_tokens_to_ids(self.args.pseudo_token) 57 | 58 | # if self.args.disc_embedding_checkpoint == None: 59 | self.spell_length_disc = sum(self.template) 60 | self.prompt_encoder_disc = PromptEncoder(self.template, self.hidden_size, self.tokenizer, args) 61 | self.prompt_encoder_disc = self.prompt_encoder_disc.to(args.device) 62 | self.prompt_encoder = self.prompt_encoder_disc 63 | self.disc_embedding = self.embeddings 64 | 65 | self.fc_loss = CrossEntropyLoss() 66 | 67 | def load_prompt(self, embedding_checkpoint): 68 | checkpoint = torch.load(embedding_checkpoint) 69 | prompt_embedding = checkpoint['embedding'] 70 | return prompt_embedding 71 | 72 | def generate(self, prompts_ids, max_length, desired_att=None, beta=0.5): 73 | """ 74 | generation forward based on the given prompt tokens, 75 | Args: 76 | prompts_ids: the prompt tokens 77 | max_length: the maximum length of the generation 78 | Returns: 79 | generated_texts: [generated tokens] 80 | """ 81 | cur_len = prompts_ids.shape[1] 82 | logits = [] 83 | output_ids = prompts_ids 84 | return_dict = {} 85 | eos_flag = torch.ones([prompts_ids.shape[0]]).type(torch.uint8).to(self.args.device) 86 | 87 | # start = datetime.datetime.now() 88 | past = None 89 | while cur_len <= max_length: 90 | past_k_v = past 91 | future_logits, past = self.generate_soft_tokens(output_ids, past_k_v) 92 | next_token_logits = future_logits.clone().detach().squeeze(1) 93 | perturb_logits = self.feedback_from_discriminator(output_ids, future_logits.unsqueeze(1), desired_att) 94 | 95 | next_token_logits_prob = torch.softmax(next_token_logits, dim=1) 96 | 97 | perturb_logits_prob = torch.softmax(perturb_logits, dim=1) 98 | 99 | next_token_logits_prob = perturb_logits_prob.mul(next_token_logits_prob) 100 | 101 | next_tokens = torch.multinomial(next_token_logits_prob, num_samples=1).squeeze(1) 102 | ## avoid the eos token being generated repeatedly 103 | eos_flag = eos_flag.mul((next_tokens != self.tokenizer.eos_token_id).type( 104 | torch.uint8)) # if flag == 0, generation for this sequence has finished 105 | next_tokens = next_tokens.mul(eos_flag) 106 | next_tokens[next_tokens == 0] = self.tokenizer.eos_token_id 107 | output_ids = torch.cat([output_ids, next_tokens.unsqueeze(1)], dim=1) 108 | print("cur_len is:", cur_len) 109 | cur_len = cur_len + 1 110 | 111 | # end = datetime.datetime.now() 112 | # print("running time is:", end-start) 113 | 114 | return_dict = {"generated_tokens": output_ids} 115 | return return_dict 116 | 117 | def generate_soft_tokens(self, generated_tokens, past_key_values=None): 118 | 119 | if past_key_values is not None: 120 | last_embeds = self.embeddings(generated_tokens[:, -1]).unsqueeze(1) # get its embeddings 121 | # print("last_embeds:", last_embeds.shape) 122 | with torch.no_grad(): 123 | outputs = self.model(inputs_embeds=last_embeds, 124 | past_key_values=past_key_values, 125 | return_dict=True) 126 | 127 | else: 128 | attention_mask = (generated_tokens != self.tokenizer.eos_token_id).type(torch.uint8) 129 | position_ids = attention_mask.long().cumsum(-1) - 1 130 | position_ids = 
position_ids.masked_fill_(attention_mask == 0, 0) 131 | last_embeds = self.embeddings(generated_tokens) # get its embeddings 132 | 133 | with torch.no_grad(): 134 | outputs = self.model(inputs_embeds=last_embeds, 135 | past_key_values=past_key_values, 136 | attention_mask=attention_mask, 137 | position_ids=position_ids, 138 | return_dict=True) 139 | 140 | next_token_logits = outputs.logits[:, -1, :] 141 | 142 | next_token_logits = self.top_k_top_p_filtering(next_token_logits.squeeze(1), top_k=self.args.ranking_scope, 143 | top_p=self.args.top_p, filter_value=BIG_CONST) 144 | 145 | return next_token_logits, outputs.past_key_values 146 | 147 | def discriminator_predict(self, input_ids): 148 | 149 | input_ids_left_pad, length_generated_tokens = self.pad_left_to_right(input_ids, self.tokenizer.eos_token_id) 150 | 151 | musk = (input_ids_left_pad != self.tokenizer.eos_token_id).type(torch.uint8) 152 | 153 | pred_ids = self._predict_scores(input_ids_left_pad, musk)  # reward=False (default) returns the predicted label-token ids 154 | 155 | return pred_ids 156 | 157 | def scores_predict(self, input_ids): 158 | 159 | input_ids_left_pad, length_generated_tokens = self.pad_left_to_right(input_ids, self.tokenizer.eos_token_id) 160 | 161 | musk = (input_ids_left_pad != self.tokenizer.eos_token_id).type(torch.uint8) 162 | 163 | pred_scores = self._predict_scores(input_ids_left_pad, musk, reward=True)  # reward=True returns the probability of the positive class 164 | 165 | return pred_scores 166 | 167 | def top_k_top_p_filtering(self, 168 | logits, 169 | top_k=0, 170 | top_p=1.0, 171 | filter_value=BIG_CONST, 172 | min_tokens_to_keep=1, 173 | ): 174 | """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering 175 | Args: 176 | logits: logits distribution shape (batch size, vocabulary size) 177 | if top_k > 0: keep only top k tokens with highest probability (top-k filtering). 178 | if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). 179 | Nucleus filtering is described in Holtzman et al. 
(http://arxiv.org/abs/1904.09751) 180 | Make sure we keep at least min_tokens_to_keep per batch example in the output 181 | From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 182 | """ 183 | if top_k > 0: 184 | top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check 185 | # Remove all tokens with a probability less than the last token of the top-k 186 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 187 | logits[indices_to_remove] = filter_value 188 | 189 | if top_p < 1.0: 190 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 191 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 192 | 193 | # Remove tokens with cumulative probability above the threshold (tokens with 0 are kept) 194 | sorted_indices_to_remove = cumulative_probs > top_p 195 | if min_tokens_to_keep > 1: 196 | # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) 197 | sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 198 | # Shift the indices to the right to keep also the first token above the threshold 199 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 200 | sorted_indices_to_remove[..., 0] = 0 201 | 202 | # scatter sorted tensors to original indexing 203 | indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) 204 | logits[indices_to_remove] = filter_value 205 | 206 | return logits 207 | 208 | def pad_left_to_right(self, inputs, pad_id): 209 | trans_inputs = torch.empty_like(inputs) 210 | 211 | input_remove_prompt = inputs[:, self.prompt_pad_length:]  # note: self.prompt_pad_length is expected to be set on the instance before this helper is called 212 | 213 | index_musk = (input_remove_prompt != pad_id).type(torch.uint8) # only count the tokens that are not eos 214 | 215 | length_of_generated_text = torch.sum(index_musk, 1) 216 | 217 | valid_number_length = length_of_generated_text + self.prompt_pad_length 218 | 219 | count = 0 220 | for index, seq in zip(valid_number_length, inputs): 221 | 222 | if index == 0 or index == inputs.shape[1]: 223 | trans_inputs[count] = seq 224 | else: 225 | trans_inputs[count][-index:] = seq[:index] 226 | trans_inputs[count][:inputs.shape[1] - index] = seq[index:] 227 | count += 1 228 | 229 | return trans_inputs, length_of_generated_text 230 | 231 | def feedback_from_discriminator(self, input_ids, logits_seq, desired_att): 232 | logits_seq = logits_seq.squeeze(1) 233 | top_logits, top_indices = logits_seq.topk(self.args.ranking_scope, dim=1) # batch x topk 234 | 235 | scores = [] 236 | candidates = [] 237 | for logit_id, ids in zip(top_indices, input_ids): 238 | data = ids.expand(self.args.ranking_scope, -1) 239 | new_input_candidates = torch.cat([data, logit_id.unsqueeze(1)], dim=1) # batch x topk x seq+1 240 | candidates.append(new_input_candidates) 241 | candidates = torch.cat(candidates, dim=0) 242 | 243 | musk = (candidates != self.tokenizer.eos_token_id).type(torch.uint8) 244 | pred_scores = self._predict_scores(candidates, musk, reward=True)  # score each candidate continuation with the discriminator 245 | pred_scores = pred_scores.reshape(input_ids.shape[0], -1) 246 | 247 | logits_seq.scatter_(-1, top_indices, pred_scores) 248 | 249 | indices_to_remove = logits_seq < torch.topk(logits_seq, 3)[0][..., -1, None] 250 | logits_seq[indices_to_remove] = BIG_CONST 251 | 252 | return logits_seq 253 | 254 | def gradient_feedback_from_discriminator(self, past, logits_seq, desired_att, lr): 255 | 256 | musk_value = torch.empty_like(logits_seq).fill_(0.0).long().to(self.args.device) 257 | indices_musk = (logits_seq == BIG_CONST) 
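# positions already filtered out (== BIG_CONST) are frozen in musk_value below (requires_grad=False); only the surviving top-k logits are treated as free variables and optimized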
258 | musk_value[indices_musk] = BIG_CONST 259 | musk_value.requires_grad = False 260 | 261 | index = torch.nonzero(logits_seq != BIG_CONST) 262 | 263 | logits_seqs = torch.empty_like(logits_seq) 264 | 265 | logits_seqs.fill_(0.0) 266 | 267 | update_logit = logits_seqs 268 | update_logit.requires_grad = True 269 | 270 | optimizer = torch.optim.AdamW([{"params": update_logit}], lr=lr, amsgrad=True, weight_decay=0.1) 271 | 272 | num_backward_iters = self.args.iter_num 273 | 274 | for i in range(num_backward_iters): 275 | update_logit_ = update_logit.mul(~indices_musk) ## keep the values at the top-k positions, zero out the rest 276 | logits = update_logit_ + musk_value 277 | 278 | logit_softmax = torch.softmax(logits, dim=-1) 279 | 280 | soft_tokens = torch.matmul(logit_softmax, self.disc_embedding.weight) # project the soft distribution into embedding space 281 | 282 | loss_discriminator = self.loss_for_desiredAtt(past, soft_tokens, desired_att) # NOTE: loss_for_desiredAtt is not defined in this class; it is assumed to be provided elsewhere 283 | 284 | loss = loss_discriminator # + 0.0*l1_loss 285 | 286 | print("the loss is:", loss) 287 | 288 | loss.backward() 289 | torch.cuda.empty_cache() 290 | optimizer.step() 291 | torch.cuda.empty_cache() 292 | optimizer.zero_grad() 293 | 294 | update_logit_ = update_logit.mul(~indices_musk) ## keep the values at the top-k positions, zero out the rest 295 | logits = update_logit_ + musk_value 296 | 297 | indices_to_remove = logits < torch.topk(logits, self.args.top_k)[0][..., -1, None] 298 | logits[indices_to_remove] = BIG_CONST 299 | 300 | return logits.squeeze(1) 301 | 302 | def get_query(self, x_h, prompt_tokens, x_t=None): 303 | 304 | prompt_tensor = torch.tensor(prompt_tokens * (self.spell_length_disc)).to(x_h.device) 305 | prompt_tensor = prompt_tensor.expand(x_h.shape[0], -1) 306 | if x_t is not None: 307 | x_t = x_t.unsqueeze(1) 308 | return torch.cat([x_h, prompt_tensor, x_t], dim=1) 309 | else: 310 | return torch.cat([x_h, prompt_tensor], dim=1) 311 | 312 | def embed_input(self, queries): 313 | bz = queries.shape[0] 314 | queries_for_embedding = queries.clone() 315 | raw_embeds = self.disc_embedding(queries_for_embedding) 316 | 317 | replace_embeds = self.prompt_encoder_disc() 318 | 319 | replace_embeds = replace_embeds.unsqueeze(0).expand(bz, -1, -1) 320 | 321 | raw_embeds[:, -self.prompt_encoder_disc.spell_length:, :] = replace_embeds 322 | 323 | return raw_embeds 324 | 325 | def _predict_scores(self, x_hs, att_mask, reward=False): 326 | bz = len(x_hs) 327 | # construct query ids 328 | prompt_tokens = [self.pseudo_token_id] 329 | 330 | queries = self.get_query(x_hs, prompt_tokens) 331 | # extend the attention mask to cover the appended prompt tokens 332 | attention_mask = torch.cat([att_mask, torch.ones([att_mask.shape[0], self.prompt_encoder_disc.spell_length]).long().to( 333 | self.args.device)],dim=1) 334 | # get embedded input 335 | 336 | # print(queries.shape) 337 | inputs_embeds = self.embed_input(queries) 338 | 339 | position_ids = attention_mask.long().cumsum(-1) - 1 340 | position_ids.masked_fill_(attention_mask == 0, 0) 341 | 342 | # print(position_ids.shape, inputs_embeds.shape, attention_mask.shape) 343 | 344 | with torch.no_grad(): 345 | 346 | output = self.model(inputs_embeds=inputs_embeds, 347 | attention_mask=attention_mask, 348 | position_ids=position_ids, 349 | labels=None) 350 | 351 | 352 | logits = output.logits[:, -1, :].squeeze(1) 353 | # last_non_masked_idx = position_ids.argmax(-1) 354 | # logits = output.logits[range(position_ids.shape[0]), last_non_masked_idx, :] 355 | 356 | if self.label_num == 2: 357 | serial_list = [self.label_token_ids['positive'], 358 | self.label_token_ids['negative']] 359 | elif self.label_num == 3: 360 | serial_list = 
[self.label_token_ids['positive'], 361 | self.label_token_ids['negative'], 362 | self.label_token_ids['neutral']] 363 | 364 | 365 | tri_mask = torch.ones_like(logits,dtype=torch.bool) 366 | for i in serial_list: 367 | tri_mask[:,i]=0 368 | tri_prob = torch.masked_fill(logits,tri_mask,-torch.inf) 369 | 370 | # if self.args.target_type == "negative": 371 | # return binary_prob[:, 1] 372 | # else: 373 | # return binary_prob[:, 0] 374 | if reward==False: 375 | # results = torch.where(binary_prob[:,0]>binary_prob[:,1],11274,14774) 376 | # return results.unsqueeze(-1).tolist() 377 | rds = tri_prob.argmax(dim=-1).unsqueeze(-1).tolist() 378 | return rds 379 | 380 | else: 381 | # binary_prob = torch.softmax(binary_prob,dim=-1) 382 | tri_prob = torch.softmax(tri_prob[:,serial_list],dim=-1) 383 | return tri_prob[:,0] 384 | 385 | def forward(self, x_hs, x_ts, att_mask): 386 | bz = len(x_hs) 387 | # construct query ids 388 | prompt_tokens = [self.pseudo_token_id] 389 | 390 | queries = self.get_query(x_hs, prompt_tokens) 391 | 392 | # construct label ids 393 | attention_mask = torch.cat([att_mask,torch.ones([att_mask.shape[0], self.prompt_encoder_disc.spell_length]).long().to( 394 | att_mask.device)], dim=1) 395 | 396 | position_ids = attention_mask.long().cumsum(-1) - 1 397 | position_ids.masked_fill_(attention_mask == 0, 0) 398 | 399 | # get embedded input 400 | inputs_embeds = self.embed_input(queries) 401 | 402 | output = self.model(inputs_embeds=inputs_embeds, 403 | attention_mask=attention_mask, 404 | position_ids=position_ids, 405 | labels=None) 406 | 407 | logits = output.logits[:, -1, :].squeeze(1) 408 | # last_non_masked_idx = position_ids.argmax(-1) 409 | # # print(position_ids[0],last_non_masked_idx[0]) 410 | # # print(position_ids[1], last_non_masked_idx[1]) 411 | # # raise ValueError 412 | # logits = output.logits[range(position_ids.shape[0]), last_non_masked_idx, :] 413 | 414 | loss = self.fc_loss(logits, x_ts.squeeze(1)) 415 | 416 | return loss 417 | 418 | def get_past_key_values(self, x_hs, att_mask): 419 | bz = len(x_hs) 420 | # construct query ids 421 | prompt_tokens = [self.pseudo_token_id] 422 | 423 | queries = self.get_query(x_hs, prompt_tokens) 424 | 425 | # construct label ids 426 | attention_mask = torch.cat([att_mask,torch.ones([att_mask.shape[0], self.prompt_encoder_disc.spell_length]).long().to( 427 | att_mask.device)], dim=1) 428 | 429 | position_ids = attention_mask.long().cumsum(-1) - 1 430 | position_ids = position_ids.masked_fill_(attention_mask == 0, 0) 431 | 432 | # get embedded input 433 | inputs_embeds = self.embed_input(queries) 434 | output = self.model(inputs_embeds=inputs_embeds, 435 | attention_mask=attention_mask, 436 | position_ids=position_ids, 437 | return_dict=True) 438 | return output.past_key_values 439 | 440 | def forward_with_pkv(self,input_ids,attn_mask,pkv,mean=True): 441 | # construct label ids 442 | attention_mask = torch.cat([attn_mask,torch.ones([attn_mask.shape[0], self.prompt_encoder_disc.spell_length]).long().to( 443 | attn_mask.device)],dim=1) 444 | 445 | # no need to specify position ids. 
(will add past_key_values length automatically) 446 | position_ids = attention_mask.long().cumsum(-1) - 1 447 | position_ids = position_ids.masked_fill_(attention_mask == 0, 0) 448 | position_ids = position_ids[:,-1] 449 | output = self.model(input_ids = input_ids, 450 | attention_mask = attention_mask, 451 | past_key_values = pkv,position_ids=position_ids) 452 | 453 | logits = output.logits[:, -1, :].squeeze(1) 454 | # last_non_masked_idx = position_ids.argmax(-1) 455 | # logits = output.logits[range(position_ids.shape[0]), last_non_masked_idx, :] 456 | 457 | if self.label_num==3: 458 | serial_list = [self.label_token_ids['positive'], 459 | self.label_token_ids['negative'], 460 | self.label_token_ids['neutral']] 461 | tri_prob = torch.softmax(logits[:, serial_list], dim=-1) 462 | 463 | results = tri_prob.argmax(dim=-1) 464 | reward = torch.where(results == 0, 1, torch.where(results == 1, -1, 0)) 465 | elif self.label_num==2: 466 | serial_list = [self.label_token_ids['positive'], 467 | self.label_token_ids['negative']] 468 | if mean: 469 | bi_prob = torch.softmax(logits[:, serial_list], dim=-1) 470 | reward = bi_prob[:,0].mean() 471 | else: 472 | bi_prob = torch.softmax(logits[:, serial_list], dim=-1) 473 | reward = bi_prob[:,0] 474 | # results = bi_prob.argmax(dim=-1) 475 | # reward = torch.where(results == 0, 1, -1) 476 | 477 | return reward 478 | 479 |
--------------------------------------------------------------------------------