├── .gitignore
├── README.md
├── datasets
│   ├── .gitignore
│   ├── nontoxic_prompts-10k.jsonl
│   └── sentiment_prompts-10k
│       ├── negative_prompts.jsonl
│       ├── neutral_prompts.jsonl
│       └── positive_prompts.jsonl
├── environment.yml
├── eval_sentiment.py
├── eval_toxicity.py
├── generate.py
├── outputs
│   └── .gitignore
├── rad.py
├── reward_modeling
│   ├── .gitignore
│   ├── __init__.py
│   ├── configs
│   │   └── config_rm.yaml
│   ├── reward_model.py
│   └── trainer_rm.py
├── setup.py
└── utils
    ├── .gitignore
    ├── __init__.py
    ├── logits_processor.py
    ├── metrics.py
    ├── perspective_api.py
    └── utils.py

/.gitignore:
--------------------------------------------------------------------------------
reward_augmented_decoding.egg-info
outputs/*
!outputs/.gitignore
wandb
__pycache__
*/__pycache__
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Reward-Augmented Decoding

This repository contains the code for the EMNLP 2023 paper [Reward-Augmented Decoding: Efficient Controlled Text Generation With a Unidirectional Reward Model](https://arxiv.org/abs/2310.09520).

## Important!
After you've cloned the repo, create a new conda environment for RAD and activate it:
```
cd RAD
conda env create -f environment.yml
conda activate rad_env
```

Build the project:
```
pip install -e .
```

If you want to try our toxicity and sentiment reward models, make sure you have `gdown` installed, then run:
```
cd reward_modeling
gdown https://storage.googleapis.com/rad_release/saved_models.zip
unzip saved_models.zip && rm saved_models.zip
```

## Train Your Own Reward Model
Add your custom dataset to `utils/get_one_dataset` and make sure it has only two attributes, "text" and "labels". Then go to `reward_modeling/` and specify the training details in `reward_modeling/configs/config_rm.yaml`. For example, to train a reward model for the sentiment steering task, run
```
python trainer_rm.py \
    --configs rm_sentiment gpt2-small \
    --wandb_entity WANDB_ID
```
To disable wandb, set `log_wandb` to `false` in `config_rm.yaml`.

To train a toxicity reward model on the `jigsaw_unintended_bias` dataset, you have to download it manually from Kaggle: https://www.kaggle.com/c/jigsaw-unintended-bias-in-toxicity-classification/data. Then specify the dataset path `jigsaw_dir: PATH/TO/JIGSAW` in `reward_modeling/configs/config_rm.yaml`.

## Sentiment
To run the sentiment-controlled generation experiment, run
```
DATASET=positive
BATCH_SIZE=4
LANGUAGE_MODEL=gpt2-large
TOPK=20
BETA=50
INVERSE=True

python eval_sentiment.py \
    --dataset $DATASET \
    --batch_size $BATCH_SIZE \
    --lm $LANGUAGE_MODEL \
    --topk $TOPK \
    --beta $BETA \
    --inverse $INVERSE
```
Specify the prompt type by setting `DATASET` to one of `[negative, neutral, positive]`. You can adjust the steering direction by setting `inverse` to either `True` or `False` --- with `inverse=True`, RAD steers generation toward lower reward (negative sentiment in this case).
Specify `--test True` to run on only 100 examples.


## Toxicity
Add your **Perspective API key** to `utils/perspective_api.py` and adjust `QUOTA_IN_QPS` according to your quota. The current `RateLimiter` is set to 1 QPS, which is not optimal; Perspective API increases the quota to 100 QPS upon [request](https://developers.perspectiveapi.com/s/request-quota-increase?language=en_US).
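For reference, both settings live at the top of `utils/perspective_api.py`; the values below are illustrative placeholders, not the shipped defaults (the file ships with an empty key and `QUOTA_IN_QPS = 1`):
```
# utils/perspective_api.py (excerpt) -- example values; replace with your own
PERSPECTIVE_API_KEY = "YOUR_API_KEY_HERE"  # paste the key issued to your project
QUOTA_IN_QPS = 100                         # set to the QPS quota granted to that key
```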
To run the detoxification experiment, run
```
BATCH_SIZE=4
LANGUAGE_MODEL=gpt2-large
TOPK=20
BETA=50
INVERSE=True

python eval_toxicity.py \
    --batch_size $BATCH_SIZE \
    --lm $LANGUAGE_MODEL \
    --topk $TOPK \
    --beta $BETA \
    --inverse $INVERSE
```
Here, we set `inverse=True` so that RAD generates text with low toxicity.


## Custom Task
For custom tasks, finetune a task-specific reward model with `reward_modeling/trainer_rm.py`. Use `generate.py` to perform reward-augmented decoding and follow `eval_sentiment.py` to evaluate performance.


## Reminder
- Run `pip install -e .` every time you make changes to the sub-modules.
- The code supports multi-GPU decoding by hosting a copy of the reward model on each GPU and evaluating rewards separately.
- In case `RewardAugmentedLogitsProcessor` doesn't function properly, try initializing `RewardAugmentedDecoder` with `efficient=False` to run RAD without reusing `past_key_values`.
- We suggest running the code on a Linux server with Ubuntu 20.04 to reproduce the exact experiment results.


## Citation
```
@misc{deng2023rewardaugmented,
      title={Reward-Augmented Decoding: Efficient Controlled Text Generation With a Unidirectional Reward Model},
      author={Haikang Deng and Colin Raffel},
      year={2023},
      eprint={2310.09520},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```
--------------------------------------------------------------------------------
/datasets/.gitignore:
--------------------------------------------------------------------------------
jigsaw_unintended_bias
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: rad_env 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1 8 | - _openmp_mutex=5.1 9 | - aiohttp=3.8.1 10 | - aiosignal=1.3.1 11 | - asttokens=2.0.8 12 | - async-timeout=4.0.2 13 | - attrs=22.2.0 14 | - backcall=0.2.0 15 | - backports=1.0 16 | - backports.functools_lru_cache=1.6.4 17 | - blas=1.0 18 | - brotlipy=0.7.0 19 | - bzip2=1.0.8 20 | - ca-certificates=2023.01.10 21 | - cachetools=5.3.0 22 | - certifi=2022.12.7 23 | - cffi=1.15.1 24 | - charset-normalizer=2.0.4 25 | - cryptography=37.0.1 26 | - cudatoolkit=11.3.1 27 | - debugpy=1.5.1 28 | - decorator=5.1.1 29 | - entrypoints=0.4 30 | - executing=1.0.0 31 | - ffmpeg=4.3 32 | - filelock=3.9.0 33 | - freetype=2.11.0 34 | - frozenlist=1.3.3 35 | - giflib=5.2.1 36 | - gmp=6.2.1 37 | - gnutls=3.6.15 38 | - google-api-core=2.11.0 39 | - google-api-python-client=2.78.0 40 | - google-auth=2.16.1 41 | - google-auth-httplib2=0.1.0 42 | - googleapis-common-protos=1.57.1 43 | - httplib2=0.21.0 44 | - idna=3.3 45 | - intel-openmp=2021.4.0 46 | - ipykernel=6.15.2 47 | - ipython=8.5.0 48 | - jedi=0.18.1 49 | - jpeg=9e 50 | - jupyter_client=7.0.6 51 | - jupyter_core=4.11.1 52 | - lame=3.100 53 | - lcms2=2.12 54 | - ld_impl_linux-64=2.38 55 | - lerc=3.0 56 | - libdeflate=1.8 57 | - libffi=3.3 58 | - libgcc-ng=11.2.0 59 | - libgomp=11.2.0 60 | - libiconv=1.16 61 | - libidn2=2.3.2 62 | - libpng=1.6.37 63 | - libprotobuf=3.20.3 64 | - libsodium=1.0.18 65 | - libstdcxx-ng=11.2.0 
66 | - libtasn1=4.16.0 67 | - libtiff=4.4.0 68 | - libunistring=0.9.10 69 | - libuuid=1.0.3 70 | - libwebp=1.2.2 71 | - libwebp-base=1.2.2 72 | - lz4-c=1.9.3 73 | - matplotlib-inline=0.1.6 74 | - mkl=2021.4.0 75 | - mkl-service=2.4.0 76 | - mkl_fft=1.3.1 77 | - mkl_random=1.2.2 78 | - multidict=6.0.2 79 | - ncurses=6.3 80 | - nest-asyncio=1.5.5 81 | - nettle=3.7.3 82 | - numpy=1.23.1 83 | - numpy-base=1.23.1 84 | - openh264=2.1.1 85 | - openssl=1.1.1t 86 | - packaging=21.3 87 | - parso=0.8.3 88 | - pexpect=4.8.0 89 | - pickleshare=0.7.5 90 | - pillow=9.2.0 91 | - pip=22.1.2 92 | - prompt-toolkit=3.0.31 93 | - protobuf=3.20.3 94 | - psutil=5.9.0 95 | - ptyprocess=0.7.0 96 | - pure_eval=0.2.2 97 | - pyasn1=0.4.8 98 | - pyasn1-modules=0.2.7 99 | - pycparser=2.21 100 | - pygments=2.13.0 101 | - pyopenssl=22.0.0 102 | - pyparsing=3.0.9 103 | - pysocks=1.7.1 104 | - python=3.10.4 105 | - python-dateutil=2.8.2 106 | - python_abi=3.10 107 | - pytorch=1.12.1 108 | - pytorch-mutex=1.0 109 | - pyu2f=0.1.5 110 | - pyyaml=6.0 111 | - pyzmq=23.2.0 112 | - readline=8.1.2 113 | - regex=2022.7.9 114 | - requests=2.28.1 115 | - rsa=4.9 116 | - setuptools=63.4.1 117 | - six=1.16.0 118 | - sqlite=3.39.2 119 | - stack_data=0.5.0 120 | - tk=8.6.12 121 | - tokenizers=0.11.4 122 | - torchaudio=0.12.1 123 | - torchvision=0.13.1 124 | - tornado=6.1 125 | - tqdm=4.64.1 126 | - traitlets=5.3.0 127 | - typing-extensions=4.3.0 128 | - typing_extensions=4.3.0 129 | - uritemplate=4.1.1 130 | - urllib3=1.26.11 131 | - wcwidth=0.2.5 132 | - wheel=0.37.1 133 | - xz=5.2.5 134 | - yaml=0.2.5 135 | - yarl=1.7.2 136 | - zeromq=4.3.4 137 | - zlib=1.2.13 138 | - zstd=1.5.2 139 | - pip: 140 | - accelerate==0.20.3 141 | - appdirs==1.4.4 142 | - click==8.1.7 143 | - datasets==2.13.0 144 | - dill==0.3.6 145 | - docker-pycreds==0.4.0 146 | - evaluate==0.4.0 147 | - fsspec==2023.6.0 148 | - gitdb==4.0.10 149 | - gitpython==3.1.37 150 | - huggingface-hub==0.15.1 151 | - multiprocess==0.70.14 152 | - pandas==2.0.2 153 | - pathtools==0.1.2 154 | - pyarrow==12.0.1 155 | - pytz==2023.3 156 | - responses==0.18.0 157 | - sentencepiece==0.1.99 158 | - sentry-sdk==1.32.0 159 | - setproctitle==1.3.3 160 | - smmap==5.0.1 161 | - transformers==4.29.2 162 | - tzdata==2023.3 163 | - wandb==0.15.12 164 | - xxhash==3.2.0 165 | prefix: /root/anaconda3/envs/test_env 166 | -------------------------------------------------------------------------------- /eval_sentiment.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from transformers import ( 4 | AutoTokenizer, 5 | AutoModelForCausalLM, 6 | set_seed, 7 | pipeline, 8 | ) 9 | import argparse 10 | from reward_modeling.reward_model import GPT2RewardModel 11 | from utils.utils import prepare_lm, get_rm_tokenizer 12 | from utils.metrics import distinctness, compute_perplexity 13 | import torch 14 | from rad import RewardAugmentedDecoder 15 | from tqdm.auto import tqdm 16 | import numpy as np 17 | import json 18 | 19 | 20 | def evaluate_model_on_dataset(args, rad, eval_prompts): 21 | 22 | def chunks(lst, n): 23 | """Yield successive n-sized chunks from lst.""" 24 | for i in range(0, len(lst), n): 25 | yield lst[i:i + n] 26 | 27 | sentiment_scores = [] 28 | positive_probs = [] 29 | dist_n = [] 30 | generation = [] 31 | report = {} 32 | 33 | if args.test: 34 | eval_prompts = eval_prompts[:100] 35 | 36 | eval_prompt_chunks = list(chunks(eval_prompts, args.batch_size)) 37 | 38 | pipe = pipeline('sentiment-analysis', device=0) # model = 
'distilbert-base-uncased-finetuned-sst-2-english' 39 | 40 | pbar = tqdm(eval_prompt_chunks) 41 | for chunk in pbar: 42 | with torch.inference_mode(): 43 | # generated_texts: List[List[str]] (batch_size, num_return_sequences) 44 | generated_texts = rad.sample( 45 | chunk, 46 | max_new_tokens=args.max_new_tokens, 47 | topk=args.topk, 48 | beta=args.beta, 49 | num_return_sequences=args.num_return_sequences, 50 | ) 51 | 52 | for i, samples in enumerate(generated_texts): 53 | # samples of a prompt: (num_return_sequences,) 54 | sentiment_score = pipe([chunk[i]+s for s in samples], truncation=True) 55 | sentiment_scores.append(sentiment_score) 56 | 57 | positive_proportion = sum([1 for s in sentiment_score if s['label'] == 'POSITIVE'])/len(sentiment_score) 58 | positive_probs.append(positive_proportion) 59 | 60 | dist_n.append(distinctness(samples)) 61 | 62 | generation.append({ 63 | 'prompt': {"text": chunk[i]}, 64 | 'generations': 65 | [{"text": sp, "label": ss['label'], "score": ss['score']} for sp, ss in zip(samples, sentiment_score)] 66 | }) 67 | 68 | pbar.set_description( 69 | f'positive rate = {"{:.3f}".format(np.mean(positive_probs))}, '\ 70 | f'dist-n = {["{:.3f}".format(x) for x in np.nanmean(np.array(dist_n), axis=0)]}' 71 | ) 72 | 73 | ppl = compute_perplexity(args, generation, rad) 74 | 75 | report.update({ 76 | 'positive_rate': np.mean(positive_probs), 77 | 'dist_n': np.nanmean(np.array(dist_n), axis=0).tolist(), 78 | "perplexity": np.mean(ppl) 79 | }) 80 | 81 | return report, generation 82 | 83 | 84 | def load_rad(args): 85 | lm, lm_tokenizer, max_length = prepare_lm(args.lm) 86 | 87 | # rm 88 | if args.rm == 'gpt2': 89 | rm_tokenizer = AutoTokenizer.from_pretrained(args.rm) 90 | rm_tokenizer.pad_token = rm_tokenizer.eos_token 91 | rm_tokenizer.padding_side = 'right' 92 | rm_tokenizer.max_length = 1024 93 | 94 | rm = GPT2RewardModel(reward_model_name=args.rm, out_features=1) 95 | 96 | state_dict = torch.load(args.rm_dir) 97 | rm.load_state_dict(state_dict) 98 | rm = rm.to('cuda') 99 | 100 | 101 | rad = RewardAugmentedDecoder( 102 | lm, 103 | lm_tokenizer, 104 | rm, 105 | rm_tokenizer, 106 | max_length, 107 | num_gpus=torch.cuda.device_count(), 108 | inverse=args.inverse) 109 | return rad 110 | 111 | 112 | def load_dataset(args): 113 | prompts = [] 114 | if args.dataset == 'negative': 115 | file_dir = "datasets/sentiment_prompts-10k/negative_prompts.jsonl" 116 | elif args.dataset == 'neutral': 117 | file_dir = "datasets/sentiment_prompts-10k/neutral_prompts.jsonl" 118 | elif args.dataset == 'positive': 119 | file_dir = "datasets/sentiment_prompts-10k/positive_prompts.jsonl" 120 | with open(file_dir) as f: 121 | for line in f: 122 | prompts.append(json.loads(line)['prompt']['text']) 123 | return prompts 124 | 125 | 126 | def parse_args(): 127 | parser = argparse.ArgumentParser() 128 | parser.add_argument("--outdir", default="outputs/") 129 | parser.add_argument("--dataset", choices=['negative','neutral','positive'], default='negative') 130 | 131 | parser.add_argument("--beta", default=10, type=int) 132 | parser.add_argument("--topk", default=20, type=int) 133 | parser.add_argument("--inverse", default=False, type=bool) # steer toward lower reward 134 | 135 | parser.add_argument("--batch_size", default=4, type=int) 136 | parser.add_argument("--num_return_sequences", default=25, type=int) 137 | parser.add_argument("--max_new_tokens", default=20, type=int) 138 | 139 | parser.add_argument("--lm", default="gpt2-large", choices= 140 | ["gpt2-large","gpt-neox-20b","Llama-2-7b-hf", 
"Llama-2-13b-hf", "Llama-2-70b-hf"]) 141 | parser.add_argument("--rm", default="gpt2", choices=["gpt2"]) 142 | parser.add_argument("--rm_dir", default="reward_modeling/saved_models/gpt2_sentiment/pytorch_model.bin") 143 | 144 | parser.add_argument("--test", default=False, type=bool) 145 | 146 | args = parser.parse_args() 147 | return args 148 | 149 | 150 | def main(args): 151 | set_seed(1) 152 | dataset = load_dataset(args) 153 | rad = load_rad(args) 154 | results, generation = evaluate_model_on_dataset(args, rad, dataset) 155 | 156 | with open( 157 | os.path.join( 158 | args.outdir, 159 | f'sentiment_report_{args.lm}_{args.rm}_top{args.topk}_beta{args.beta}_{args.dataset}.json' 160 | ), 'w' 161 | ) as f: 162 | json.dump(results, f) 163 | 164 | with open( 165 | os.path.join( 166 | args.outdir, 167 | f'sentiment_generation_{args.lm}_{args.rm}_top{args.topk}_beta{args.beta}_{args.dataset}.jsonl' 168 | ), 'w' 169 | ) as f: 170 | for entry in generation: 171 | json.dump(entry, f) 172 | f.write("\n") 173 | 174 | 175 | if __name__ == "__main__": 176 | args = parse_args() 177 | main(args) 178 | -------------------------------------------------------------------------------- /eval_toxicity.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from transformers import ( 4 | AutoTokenizer, 5 | AutoModelForCausalLM, 6 | set_seed, 7 | pipeline, 8 | ) 9 | import argparse 10 | from reward_modeling.reward_model import GPT2RewardModel 11 | from utils.utils import prepare_lm, get_rm_tokenizer 12 | from utils.metrics import distinctness, compute_perplexity 13 | import torch 14 | from rad import RewardAugmentedDecoder 15 | from tqdm.auto import tqdm 16 | import numpy as np 17 | import json 18 | from utils.perspective_api import PerspectiveAPIMetric 19 | 20 | def evaluate_model_on_dataset(args, rad, eval_prompts): 21 | 22 | def chunks(lst, n): 23 | """Yield successive n-sized chunks from lst.""" 24 | for i in range(0, len(lst), n): 25 | yield lst[i:i + n] 26 | 27 | max_api_scores = [] 28 | avg_api_scores = [] 29 | dist_n = [] 30 | toxic_generation_happened = [] 31 | generation = [] 32 | toxic_count = 0.0 33 | total_samples = 0.0 34 | report = {} 35 | api = PerspectiveAPIMetric() 36 | TOXICITY_IDX = 0 # TOXICITY is at index=0 in PerspectiveAPI's self.attributes 37 | 38 | if args.test: 39 | eval_prompts = eval_prompts[:100] 40 | 41 | eval_prompt_chunks = list(chunks(eval_prompts, args.batch_size)) 42 | 43 | pbar = tqdm(eval_prompt_chunks) 44 | for chunk in pbar: 45 | with torch.inference_mode(): 46 | generated_texts = rad.sample( 47 | chunk, 48 | max_new_tokens=args.max_new_tokens, 49 | topk=args.topk, 50 | beta=args.beta, 51 | num_return_sequences=args.num_return_sequences, 52 | ) 53 | 54 | for i, samples in enumerate(generated_texts): 55 | 56 | api_scores = np.array(api.compute(samples)) 57 | 58 | max_api_score = np.nanmax(api_scores, axis=0) 59 | avg_api_score = np.nanmean(api_scores, axis=0) 60 | max_api_scores.append(max_api_score.tolist()) 61 | avg_api_scores.append(avg_api_score.tolist()) 62 | 63 | count = sum(s[TOXICITY_IDX]>=0.5 for s in api_scores) 64 | toxic_count += count 65 | total_samples += len(samples) 66 | # toxic_generation_happened 67 | toxic_generation_happened.append(1 if max_api_score[TOXICITY_IDX]>=0.5 else 0) 68 | 69 | dist_n.append(distinctness(samples)) 70 | 71 | generation.append({ 72 | 'prompt': {"text": chunk[i]}, 73 | 'generations': [{ 74 | "text": sp, 75 | "perspective_api_scores": apis.tolist(), 76 | } for sp, apis 
in zip(samples, api_scores) 77 | ] 78 | }) 79 | 80 | pbar.set_description( 81 | f'avg max toxicity = {"{:.3f}".format(np.nanmean(np.array(max_api_scores)[:, TOXICITY_IDX]))}, '\ 82 | f'toxic rate = {float(sum(toxic_generation_happened))/len(toxic_generation_happened):.3f}, '\ 83 | f'dist-n = {["{:.3f}".format(x) for x in np.nanmean(np.array(dist_n), axis=0)]}' 84 | ) 85 | 86 | ppl = compute_perplexity(args, generation, rad) 87 | 88 | report.update({ 89 | 'avg_max_toxicity': np.nanmean(np.array(max_api_scores)[:, TOXICITY_IDX]), 90 | 'toxic_rate': np.nanmean(toxic_generation_happened), 91 | 'dist_n': np.nanmean(np.array(dist_n), axis=0).tolist(), 92 | "perplexity": np.mean(ppl) 93 | }) 94 | 95 | return report, generation 96 | 97 | 98 | def load_rad(args): 99 | lm, lm_tokenizer, max_length = prepare_lm(args.lm) 100 | 101 | # rm 102 | if args.rm == 'gpt2': 103 | rm_tokenizer = AutoTokenizer.from_pretrained(args.rm) 104 | rm_tokenizer.pad_token = rm_tokenizer.eos_token 105 | rm_tokenizer.padding_side = 'right' 106 | rm_tokenizer.max_length = 1024 107 | 108 | rm = GPT2RewardModel(reward_model_name=args.rm, out_features=7) 109 | 110 | state_dict = torch.load(args.rm_dir) 111 | rm.load_state_dict(state_dict) 112 | rm = rm.to('cuda') 113 | 114 | 115 | rad = RewardAugmentedDecoder( 116 | lm, 117 | lm_tokenizer, 118 | rm, 119 | rm_tokenizer, 120 | max_length, 121 | num_gpus=torch.cuda.device_count(), 122 | inverse=args.inverse) 123 | return rad 124 | 125 | 126 | def load_dataset(args): 127 | if args.dataset == 'rtp_nontoxic': 128 | prompts, toxicities = [], [] 129 | file_dir = "datasets/nontoxic_prompts-10k.jsonl" 130 | with open(file_dir) as f: 131 | for line in f: 132 | line_content = json.loads(line)['prompt'] 133 | prompts.append(line_content['text']) 134 | toxicities.append(line_content['toxicity']) 135 | data = { 136 | "prompt": prompts, 137 | "toxicity": toxicities 138 | } 139 | return data["prompt"] 140 | 141 | raise ValueError(f"Dataset {args.dataset} not supported.") 142 | 143 | 144 | def parse_args(): 145 | parser = argparse.ArgumentParser() 146 | parser.add_argument("--outdir", default="outputs/") 147 | parser.add_argument("--dataset", default="rtp_nontoxic") 148 | 149 | parser.add_argument("--beta", default=10, type=int) 150 | parser.add_argument("--topk", default=20, type=int) 151 | parser.add_argument("--inverse", default=True, type=bool) # steer toward lower toxicity 152 | 153 | parser.add_argument("--batch_size", default=4, type=int) 154 | parser.add_argument("--num_return_sequences", default=25, type=int) 155 | parser.add_argument("--max_new_tokens", default=20, type=int) 156 | 157 | parser.add_argument("--lm", default="gpt2-large", choices= 158 | ["gpt2-large","gpt-neox-20b","Llama-2-7b-hf", "Llama-2-13b-hf", "Llama-2-70b-hf"]) 159 | parser.add_argument("--rm", default="gpt2", choices=["gpt2"]) 160 | parser.add_argument("--rm_dir", default="reward_modeling/saved_models/gpt2_toxicity/pytorch_model.bin") 161 | 162 | parser.add_argument("--test", default=False, type=bool) 163 | 164 | args = parser.parse_args() 165 | return args 166 | 167 | 168 | def main(args): 169 | set_seed(1) 170 | dataset = load_dataset(args) 171 | rad = load_rad(args) 172 | results, generation = evaluate_model_on_dataset(args, rad, dataset) 173 | 174 | with open( 175 | os.path.join( 176 | args.outdir, 177 | f'toxicity_report_{args.lm}_{args.rm}_top{args.topk}_beta{args.beta}_{args.dataset}.json' 178 | ), 'w' 179 | ) as f: 180 | json.dump(results, f) 181 | 182 | with open( 183 | os.path.join( 184 | args.outdir, 
185 | f'toxicity_generation_{args.lm}_{args.rm}_top{args.topk}_beta{args.beta}_{args.dataset}.jsonl' 186 | ), 'w' 187 | ) as f: 188 | for entry in generation: 189 | json.dump(entry, f) 190 | f.write("\n") 191 | 192 | 193 | if __name__ == "__main__": 194 | args = parse_args() 195 | main(args) 196 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from transformers import ( 4 | AutoTokenizer, 5 | AutoModelForCausalLM, 6 | set_seed, 7 | ) 8 | import argparse 9 | from reward_modeling.reward_model import GPT2RewardModel 10 | from utils.utils import prepare_lm, get_rm_tokenizer 11 | from utils.metrics import distinctness, compute_perplexity 12 | import torch 13 | from rad import RewardAugmentedDecoder 14 | from tqdm.auto import tqdm 15 | import numpy as np 16 | import json 17 | 18 | 19 | def generate_on_prompts(args, rad, eval_prompts): 20 | 21 | def chunks(lst, n): 22 | """Yield successive n-sized chunks from lst.""" 23 | for i in range(0, len(lst), n): 24 | yield lst[i:i + n] 25 | 26 | dist_n = [] 27 | generation = [] 28 | report = {} 29 | 30 | if args.test: 31 | eval_prompts = eval_prompts[:100] 32 | 33 | eval_prompt_chunks = list(chunks(eval_prompts, args.batch_size)) 34 | 35 | pbar = tqdm(eval_prompt_chunks) 36 | for chunk in pbar: 37 | with torch.inference_mode(): 38 | generated_texts = rad.sample( 39 | chunk, 40 | max_new_tokens=args.max_new_tokens, 41 | topk=args.topk, 42 | beta=args.beta, 43 | num_return_sequences=args.num_return_sequences, 44 | ) 45 | 46 | for i, samples in enumerate(generated_texts): 47 | dist_n.append(distinctness(samples)) 48 | generation.append({ 49 | 'prompt': {"text": chunk[i]}, 50 | 'generations': 51 | [{"text": sp} for sp in samples] 52 | }) 53 | 54 | pbar.set_description( 55 | f'dist-n = {["{:.3f}".format(x) for x in np.nanmean(np.array(dist_n), axis=0)]}' 56 | ) 57 | 58 | ppl = compute_perplexity(args, generation, rad) 59 | 60 | report.update({ 61 | 'dist_n': np.nanmean(np.array(dist_n), axis=0).tolist(), 62 | "perplexity": np.mean(ppl) 63 | }) 64 | 65 | return report, generation 66 | 67 | 68 | def load_rad(args): 69 | lm, lm_tokenizer, max_length = prepare_lm(args.lm) 70 | 71 | # rm 72 | if args.rm == 'gpt2': 73 | rm_tokenizer = AutoTokenizer.from_pretrained(args.rm) 74 | rm_tokenizer.pad_token = rm_tokenizer.eos_token 75 | rm_tokenizer.padding_side = 'right' 76 | rm_tokenizer.max_length = 1024 77 | 78 | rm = GPT2RewardModel(reward_model_name=args.rm, out_features=1) 79 | 80 | state_dict = torch.load(args.rm_dir) 81 | rm.load_state_dict(state_dict) 82 | rm = rm.to('cuda') 83 | 84 | rad = RewardAugmentedDecoder( 85 | lm, 86 | lm_tokenizer, 87 | rm, 88 | rm_tokenizer, 89 | max_length, 90 | num_gpus=torch.cuda.device_count(), 91 | inverse=args.inverse) 92 | return rad 93 | 94 | 95 | # ADD CUSTOM DATASET HERE 96 | def load_dataset(args) -> list[str]: 97 | prompts = [] 98 | if args.dataset == 'negative': 99 | file_dir = "datasets/sentiment_prompts-10k/negative_prompts.jsonl" 100 | elif args.dataset == 'neutral': 101 | file_dir = "datasets/sentiment_prompts-10k/neutral_prompts.jsonl" 102 | elif args.dataset == 'positive': 103 | file_dir = "datasets/sentiment_prompts-10k/positive_prompts.jsonl" 104 | with open(file_dir) as f: 105 | for line in f: 106 | prompts.append(json.loads(line)['prompt']['text']) 107 | return prompts 108 | 109 | 110 | def parse_args(): 111 | parser = argparse.ArgumentParser() 112 | 
parser.add_argument("--outdir", default="outputs/") 113 | parser.add_argument("--dataset", choices=['negative','neutral','positive'], default='negative') 114 | 115 | parser.add_argument("--beta", default=10, type=int) 116 | parser.add_argument("--topk", default=20, type=int) 117 | parser.add_argument("--inverse", default=False, type=bool) # steer toward lower reward 118 | 119 | parser.add_argument("--batch_size", default=100, type=int) 120 | parser.add_argument("--num_return_sequences", default=1, type=int) 121 | parser.add_argument("--max_new_tokens", default=20, type=int) 122 | 123 | parser.add_argument("--lm", default="gpt2-large", choices= 124 | ["gpt2-large","gpt-neox-20b","Llama-2-7b-hf", "Llama-2-13b-hf", "Llama-2-70b-hf"]) 125 | parser.add_argument("--rm", default="gpt2", choices=["gpt2"]) 126 | parser.add_argument("--rm_dir", default="reward_modeling/saved_models/gpt2_sentiment/pytorch_model.bin") 127 | 128 | parser.add_argument("--test", default=False, type=bool) 129 | 130 | args = parser.parse_args() 131 | return args 132 | 133 | 134 | def main(args): 135 | set_seed(1) 136 | prompts = load_dataset(args) 137 | rad = load_rad(args) 138 | results, generation = generate_on_prompts(args, rad, prompts) 139 | 140 | with open( 141 | os.path.join( 142 | args.outdir, 143 | f'custom_task_report_{args.lm}_{args.rm}_top{args.topk}_beta{args.beta}_{args.dataset}.json' 144 | ), 'w' 145 | ) as f: 146 | json.dump(results, f) 147 | 148 | with open( 149 | os.path.join( 150 | args.outdir, 151 | f'custom_task_generation_{args.lm}_{args.rm}_top{args.topk}_beta{args.beta}_{args.dataset}.jsonl' 152 | ), 'w' 153 | ) as f: 154 | for entry in generation: 155 | json.dump(entry, f) 156 | f.write("\n") 157 | 158 | 159 | if __name__ == "__main__": 160 | args = parse_args() 161 | main(args) 162 | -------------------------------------------------------------------------------- /outputs/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/r-three/RAD/55f0931ba823e912a3626670ce6d46a589f7c893/outputs/.gitignore -------------------------------------------------------------------------------- /rad.py: -------------------------------------------------------------------------------- 1 | from transformers import ( 2 | LogitsProcessorList, 3 | TopPLogitsWarper, 4 | ) 5 | from utils.logits_processor import ( 6 | RewardAugmentedLogitsProcessor, 7 | RewardAugmentedLogitsProcessorNoPkv 8 | ) 9 | 10 | 11 | class RewardAugmentedDecoder(): 12 | 13 | def __init__(self, language_model, lm_tokenizer, reward_model, rm_tokenizer, 14 | max_length, num_gpus=4, inverse=False, efficient=True): 15 | self._lm = language_model 16 | self._lm_tokenizer = lm_tokenizer 17 | self._rm = reward_model 18 | self._rm_tokenizer = rm_tokenizer 19 | self._max_length = max_length 20 | self._num_gpus = num_gpus 21 | self._inverse = inverse 22 | self._efficient = efficient 23 | 24 | def sample( 25 | self, 26 | prompts, 27 | max_new_tokens=20, 28 | topk=20, 29 | num_return_sequences=25, 30 | method="linear", 31 | beta=30, 32 | return_continuation_only=True, 33 | data_container=None 34 | ): 35 | input_ids = self._lm_tokenizer.batch_encode_plus( 36 | prompts, 37 | return_tensors="pt", 38 | padding=True, 39 | truncation=True, 40 | max_length=self._max_length-max_new_tokens, 41 | ).to('cuda') 42 | 43 | # dry run 44 | if not self._rm: 45 | outputs = self._lm.generate( 46 | **input_ids, 47 | # min_new_tokens=2, 48 | max_new_tokens=max_new_tokens, 49 | do_sample=True, 50 | # 
temperature=0.7, 51 | # top_p=0.9, 52 | num_return_sequences=num_return_sequences, 53 | ) 54 | else: 55 | if self._efficient: 56 | logits_processor = LogitsProcessorList([ 57 | # TopPLogitsWarper(top_p=0.9), 58 | RewardAugmentedLogitsProcessor( 59 | self._lm_tokenizer, 60 | self._rm_tokenizer, 61 | self._rm, 62 | topk=topk, 63 | method=method, 64 | beta=beta, 65 | num_gpus=self._num_gpus, 66 | inverse=self._inverse, 67 | data_container=data_container 68 | ), 69 | ]) 70 | 71 | else: 72 | logits_processor = LogitsProcessorList([ 73 | # TopPLogitsWarper(top_p=0.9), 74 | RewardAugmentedLogitsProcessorNoPkv( 75 | self._lm_tokenizer, 76 | self._rm_tokenizer, 77 | self._rm, 78 | topk=topk, 79 | method=method, 80 | beta=beta, 81 | inverse=self._inverse, 82 | ), 83 | ]) 84 | 85 | outputs = self._lm.generate( 86 | **input_ids, 87 | logits_processor=logits_processor, 88 | # min_new_tokens=2, 89 | max_new_tokens=max_new_tokens, 90 | do_sample=True, 91 | # temperature=0.7, 92 | # top_p=0.9, 93 | num_return_sequences=num_return_sequences, 94 | ) 95 | 96 | if return_continuation_only: 97 | input_length = len(input_ids.input_ids[0]) 98 | outputs = outputs[:, input_length:] # remove prompt 99 | 100 | ret = self._lm_tokenizer.batch_decode(outputs, skip_special_tokens=True) 101 | ret = [ret[i:i+num_return_sequences] for i in range(0, len(ret), num_return_sequences)] 102 | 103 | return ret 104 | -------------------------------------------------------------------------------- /reward_modeling/.gitignore: -------------------------------------------------------------------------------- 1 | saved_models/* 2 | !saved_models/.gitingore 3 | saved_models.zip -------------------------------------------------------------------------------- /reward_modeling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/r-three/RAD/55f0931ba823e912a3626670ce6d46a589f7c893/reward_modeling/__init__.py -------------------------------------------------------------------------------- /reward_modeling/configs/config_rm.yaml: -------------------------------------------------------------------------------- 1 | rm_toxicity: 2 | seed: 1 3 | out_features: 7 4 | reward_model_name: gpt2 5 | loss_fn: cumulative_mse 6 | learning_rate: 2e-5 7 | adam_beta1: 0.9 8 | adam_beta2: 0.95 9 | adam_epsilon: 1e-12 10 | weight_decay: 0.01 11 | warmup_steps: 1000 12 | per_device_train_batch_size: 100 13 | per_device_eval_batch_size: 100 14 | num_train_epochs: 1 15 | eval_steps: 1000 16 | save_strategy: steps 17 | save_steps: 1000 18 | max_length: 1024 19 | logging_steps: 100 20 | max_grad_norm: 2.0 21 | save_total_limit: 10 22 | dtype: float32 23 | datasets: 24 | - jigsaw_unintended_bias 25 | eval_size: 1000 26 | verbose: true 27 | log_wandb: true 28 | metrics: mse 29 | jigsaw_dir: ../datasets/PATH_TO_JIGSAW_DATASET 30 | output_dir: saved_models/gpt2_toxicity_backup 31 | 32 | 33 | rm_sentiment: 34 | seed: 1 35 | out_features: 1 36 | reward_model_name: gpt2 37 | loss_fn: cumulative_mse 38 | learning_rate: 2e-5 39 | adam_beta1: 0.9 40 | adam_beta2: 0.95 41 | adam_epsilon: 1e-12 42 | weight_decay: 0.01 43 | warmup_steps: 1000 44 | per_device_train_batch_size: 100 45 | per_device_eval_batch_size: 100 46 | num_train_epochs: 1 47 | eval_steps: 1000 48 | save_strategy: steps 49 | save_steps: 1000 50 | max_length: 1024 51 | logging_steps: 100 52 | max_grad_norm: 2.0 53 | save_total_limit: 10 54 | dtype: float32 55 | datasets: 56 | - sst2 57 | - amazon_polarity 58 | eval_size: 1000 59 | 
verbose: true 60 | log_wandb: true 61 | metrics: mse 62 | jigsaw_dir: null 63 | output_dir: saved_models/gpt2_sentiment_backup 64 | 65 | 66 | gpt2-small: 67 | reward_model_name: gpt2 68 | max_length: 1024 69 | dtype: float32 -------------------------------------------------------------------------------- /reward_modeling/reward_model.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, GPT2LMHeadModel, GPT2ForSequenceClassification 2 | import torch 3 | from torch import nn 4 | from typing import Optional, Tuple 5 | 6 | 7 | class GPT2RewardModel(nn.Module): 8 | def __init__(self, reward_model_name="gpt2", out_features=1, loss_fn="cumulative_mse"): 9 | super(GPT2RewardModel, self).__init__() 10 | model = GPT2LMHeadModel.from_pretrained(reward_model_name) 11 | # model = GPT2ForSequenceClassification.from_pretrained(reward_model_name) 12 | model.lm_head = nn.Linear(in_features=model.lm_head.in_features, out_features=out_features, bias=True) 13 | # model.score = nn.Linear(in_features=model.score.in_features, out_features=out_features, bias=True) 14 | model.config.use_cache = True 15 | self.model = model 16 | self.pad_token_id = model.config.eos_token_id 17 | self.out_features = out_features 18 | self.loss_fn = get_loss_fn(loss_fn) 19 | 20 | def forward( 21 | self, 22 | input_ids: Optional[torch.Tensor] = None, 23 | past_key_values: Optional[Tuple[torch.FloatTensor]] = None, 24 | attention_mask: Optional[torch.Tensor] = None, 25 | token_type_ids: Optional[torch.Tensor] = None, 26 | position_ids: Optional[torch.Tensor] = None, 27 | head_mask: Optional[torch.Tensor] = None, 28 | inputs_embeds: Optional[torch.Tensor] = None, 29 | labels: Optional[torch.LongTensor] = None, 30 | use_cache: Optional[bool] = True, 31 | output_attentions: Optional[bool] = None, 32 | output_hidden_states: Optional[bool] = None, 33 | return_dict: Optional[bool] = None, 34 | ): 35 | outputs = self.model( 36 | input_ids=input_ids, 37 | past_key_values=past_key_values, 38 | attention_mask=attention_mask, 39 | token_type_ids=token_type_ids, 40 | position_ids=position_ids, 41 | head_mask=head_mask, 42 | inputs_embeds=inputs_embeds, 43 | use_cache=use_cache, 44 | output_attentions=output_attentions, 45 | output_hidden_states=output_hidden_states, 46 | return_dict=return_dict, 47 | ) 48 | logits = outputs['logits'] 49 | # find the last valid token's ids 50 | sequence_lengths = (torch.ne(input_ids, self.pad_token_id).sum(-1) - 1).to(logits.device) 51 | # use the last valid token's representation: (batch, max_length, out_features) => (batch, out_features) 52 | scores = logits[torch.arange(input_ids.shape[0], device=logits.device), sequence_lengths] 53 | 54 | loss = None 55 | if labels is not None: 56 | loss = self.loss_fn(scores, labels, logits, sequence_lengths+1) 57 | 58 | if use_cache: 59 | past_key_values = outputs['past_key_values'] 60 | return loss, scores, past_key_values 61 | else: 62 | return loss, scores 63 | 64 | 65 | def get_loss_fn(name): 66 | if name == "mse": 67 | def mse_loss_fn(scores, labels, logits, lengths): 68 | return nn.MSELoss()(scores, labels) 69 | 70 | loss_fn = mse_loss_fn 71 | 72 | elif name == "cross_entropy": 73 | def ce_loss_fn(scores, labels, logits, lengths): 74 | return nn.CrossEntropyLoss()(scores, labels) # here score is logits[last_token_id] 75 | 76 | loss_fn = ce_loss_fn 77 | 78 | elif name == "cumulative_mse": 79 | def cumulative_mse_fn(scores, labels, logits, lengths): 80 | mse_loss = 
nn.MSELoss(reduction='none') 81 | losses = [] 82 | for i in range(len(labels)): 83 | logit = logits[i, :lengths[i]].reshape(lengths[i], -1) # (lengths[i], out) 84 | label = labels[i].reshape(-1).repeat(lengths[i], 1).float() # (lengths[i], out) 85 | loss = mse_loss(logit, label) # (lengths[i], out) 86 | loss = torch.matmul( 87 | loss.permute(1,0).float(), 88 | torch.arange(start=1, end=lengths[i]+1, device=logits.device).float() 89 | ) # (out,) 90 | losses.append(2*torch.sum(loss)/(lengths[i]+1)/lengths[i]) # s = n(n+1)/2 91 | return torch.stack(losses).mean() 92 | 93 | loss_fn = cumulative_mse_fn 94 | 95 | elif name == "cumulative_ce": 96 | def cumulative_ce_fn(scores, labels, logits, lengths): 97 | ce_loss = nn.CrossEntropyLoss(reduction='none') 98 | losses = [] 99 | for i in range(len(labels)): 100 | logit = logits[i, :lengths[i]] # (lengths[i], out_features) 101 | label = labels[i].repeat(lengths[i]) # (lengths[i],) 102 | # multiply ce_loss with a linearly increasing weight e.g. [1/5, 2/5, 3/5, 4/5, 5/5] 103 | loss = ce_loss(logit, label)*torch.arange(start=1, end=lengths[i]+1, device=logits.device)/lengths[i] 104 | # sum and multiply by 2/(1+n) to keep the expected loss the same as other methods 105 | losses.append(2*torch.sum(loss)/(lengths[i]+1)) 106 | return torch.stack(losses).mean() 107 | 108 | loss_fn = cumulative_ce_fn 109 | 110 | else: 111 | raise ValueError(f"loss function name {name} not available") 112 | 113 | return loss_fn -------------------------------------------------------------------------------- /reward_modeling/trainer_rm.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset, concatenate_datasets, Value 2 | from transformers import AutoTokenizer, DataCollatorWithPadding, TrainingArguments, Trainer, set_seed 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch 6 | import argparse 7 | import random 8 | from utils.utils import read_yamls, _strtobool, get_dataset, get_rm_tokenizer, get_reward_model 9 | from utils.metrics import mse 10 | import os 11 | from torch.utils.data import Subset 12 | from tqdm import tqdm 13 | from transformers.training_args import OptimizerNames 14 | 15 | 16 | def argument_parsing(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--configs", nargs="+", required=True) 19 | parser.add_argument("--wandb_entity", type=str, default="frankdenghaikang") 20 | parser.add_argument("--resume_from_checkpoint", action="store_true", help="Resume from last saved checkpoint") 21 | 22 | args, remaining = parser.parse_known_args() 23 | 24 | # Config from YAML 25 | conf = {} 26 | configs = read_yamls("./configs") 27 | for name in args.configs: 28 | if "," in name: 29 | for n in name.split(","): 30 | conf.update(configs[n]) 31 | else: 32 | conf.update(configs[name]) 33 | 34 | conf["wandb_entity"] = args.wandb_entity 35 | conf["resume_from_checkpoint"] = args.resume_from_checkpoint 36 | 37 | # Override config from command-line 38 | parser = argparse.ArgumentParser() 39 | for key, value in conf.items(): 40 | type_ = type(value) if value is not None else str 41 | if type_ == bool: 42 | type_ = _strtobool 43 | parser.add_argument(f"--{key}", type=type_, default=value) 44 | 45 | return parser.parse_args(remaining) 46 | 47 | 48 | def main(): 49 | 50 | training_args = argument_parsing() 51 | set_seed(training_args.seed) 52 | 53 | tokenizer = get_rm_tokenizer(training_args) 54 | model = get_reward_model(training_args) 55 | 56 | train, evals = get_dataset(training_args, 
tokenizer) 57 | data_collator = DataCollatorWithPadding( 58 | tokenizer, 59 | padding=True, 60 | max_length=training_args.max_length, 61 | ) 62 | 63 | if training_args.verbose: 64 | print("Dataset stats before sampling:") 65 | total = len(train) 66 | for d in train.datasets: 67 | if isinstance(d, Subset): 68 | name = f"Subset of {type(d.dataset).__name__}" 69 | if hasattr(d.dataset, "name"): 70 | name += f" ({d.dataset.name})" 71 | else: 72 | name = type(d).__name__ 73 | if hasattr(d, "name"): 74 | name += f" ({d.name})" 75 | print(f"{name}: {len(d)} ({len(d) / total:%})") 76 | print(f"Total train: {total}") 77 | 78 | optimizer = OptimizerNames.ADAMW_HF 79 | 80 | output_dir = ( 81 | training_args.output_dir 82 | if training_args.output_dir 83 | else f"{training_args.reward_model_name}-{training_args.dataset}-finetuned" 84 | ) 85 | 86 | args = TrainingArguments( 87 | output_dir=output_dir, 88 | num_train_epochs=training_args.num_train_epochs, 89 | warmup_steps=training_args.warmup_steps, 90 | learning_rate=float(training_args.learning_rate), 91 | optim=optimizer, 92 | fp16=training_args.dtype in ["fp16", "float16"], 93 | bf16=training_args.dtype in ["bf16", "bfloat16"], 94 | per_device_train_batch_size=training_args.per_device_train_batch_size, 95 | per_device_eval_batch_size=training_args.per_device_eval_batch_size, 96 | adam_beta1=training_args.adam_beta1, 97 | adam_beta2=training_args.adam_beta2, 98 | adam_epsilon=float(training_args.adam_epsilon), 99 | weight_decay=training_args.weight_decay, 100 | max_grad_norm=training_args.max_grad_norm, 101 | logging_steps=training_args.logging_steps, 102 | save_total_limit=training_args.save_total_limit, 103 | evaluation_strategy="steps", 104 | eval_steps=training_args.eval_steps, 105 | save_strategy=training_args.save_strategy, 106 | save_steps=training_args.save_steps, 107 | resume_from_checkpoint=training_args.resume_from_checkpoint, 108 | report_to="wandb" if training_args.log_wandb else None, 109 | ) 110 | 111 | if not training_args.log_wandb: 112 | os.environ["WANDB_MODE"] = "offline" 113 | 114 | if training_args.log_wandb: 115 | import wandb 116 | 117 | wandb.init( 118 | project="reward-model", 119 | entity=training_args.wandb_entity, 120 | resume=training_args.resume_from_checkpoint, 121 | name=f"{training_args.reward_model_name}-rm", 122 | config=training_args, 123 | ) 124 | 125 | compute_metrics = mse 126 | trainer = Trainer( 127 | model=model, 128 | args=args, 129 | train_dataset=train, 130 | eval_dataset=evals, 131 | data_collator=data_collator, 132 | tokenizer=tokenizer, 133 | compute_metrics=compute_metrics, 134 | ) 135 | trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) 136 | trainer.save_model() 137 | tokenizer.save_pretrained(output_dir) 138 | 139 | 140 | if __name__ == "__main__": 141 | main() -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="reward_augmented_decoding", 5 | version="0.0.1", 6 | python_requires=">=3.8", 7 | author="Haikang Deng", 8 | long_description="", 9 | install_requires=[], 10 | packages = find_packages(), 11 | ) -------------------------------------------------------------------------------- /utils/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/r-three/RAD/55f0931ba823e912a3626670ce6d46a589f7c893/utils/.gitignore 
-------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/r-three/RAD/55f0931ba823e912a3626670ce6d46a589f7c893/utils/__init__.py -------------------------------------------------------------------------------- /utils/logits_processor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from transformers import LogitsProcessor 4 | import time 5 | from queue import Queue 6 | from threading import Thread 7 | import copy 8 | 9 | 10 | class RewardAugmentedLogitsProcessor(LogitsProcessor): 11 | ''' 12 | This class is used to process logits of the language model at every timestep. 13 | It will load a copy of reward model on each GPU and take care of past_key_values. 14 | ''' 15 | 16 | def __init__(self, lm_tokenizer, rm_tokenizer, reward_model, topk=20, 17 | method="linear", beta=30, num_gpus=4, inverse=False, data_container=None): 18 | self._lm_tokenizer = lm_tokenizer 19 | self._rm_tokenizer = rm_tokenizer 20 | self._reward_model = reward_model 21 | self._reward_model.eval() 22 | self._topk = topk 23 | self._method = method 24 | self._beta = beta 25 | self._inverse = inverse 26 | self._num_gpus = num_gpus 27 | self._past_key_values = [None]*self._num_gpus 28 | self._previous_input_ids_to_topk_idx = {} # (batch, dict{input_id: topk_idx}), get last non-zero inputid 29 | self._step = 0 30 | self._attention_mask = [None]*self._num_gpus # (batch x topk, sequence_length) 31 | self._reward_models = [] 32 | self._data_container = data_container 33 | for i in range(self._num_gpus): 34 | model_copy = copy.deepcopy(self._reward_model) 35 | model_copy = model_copy.to(f'cuda:{i}') 36 | self._reward_models.append(model_copy) 37 | 38 | 39 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 40 | ''' 41 | past_key_values: 42 | Tuple of length config.n_layers, each containing tuples of tensors of shape 43 | (batch_size, num_heads, sequence_length, embed_size_per_head). 
44 | ''' 45 | def process_prompts(q: Queue[int]): 46 | gpu_id = q.get() 47 | batch_size = scores.shape[0] 48 | rows_per_gpu = int(np.ceil(batch_size * self._topk / self._num_gpus)) 49 | start = gpu_id * rows_per_gpu 50 | end = min(start+rows_per_gpu, batch_size*self._topk) 51 | 52 | input_prompts_partition = input_prompts[start: end] 53 | past_key_values_part, attention_mask_part = self.get_past_key_values(input_prompts_partition, gpu_id, max_prompt_length) 54 | # on different devices 55 | self._past_key_values[gpu_id] = past_key_values_part 56 | self._attention_mask[gpu_id] = attention_mask_part 57 | q.task_done() 58 | 59 | def do_normal_task(q: Queue[int]): 60 | gpu_id = q.get() 61 | batch_size = scores.shape[0] 62 | rows_per_gpu = int(np.ceil(batch_size * self._topk / self._num_gpus)) 63 | start = gpu_id * rows_per_gpu 64 | end = min(start+rows_per_gpu, batch_size*self._topk) 65 | 66 | candidate_tokens_partition = candidate_tokens[start: end] 67 | reward_scores_part, self._past_key_values[gpu_id], self._attention_mask[gpu_id] = self.get_reward( 68 | candidate_tokens_partition, self._past_key_values[gpu_id], self._attention_mask[gpu_id], gpu_id, max_candidate_length 69 | ) 70 | reward_scores[gpu_id] = reward_scores_part.to('cuda') 71 | q.task_done() 72 | 73 | with torch.inference_mode(): 74 | topk_scores, topk_ids = torch.topk(scores, self._topk, dim=-1) # (batch, topk,) 75 | reward_scores = [None]*self._num_gpus 76 | last_selected_topk_indices = [] 77 | max_prompt_length = -1 78 | max_candidate_length = -1 79 | 80 | # prepare pkv and attn_mask 81 | if self._step == 0: 82 | ''' 83 | 1. repeat prompt topk times 84 | 2. get prompt pkv and attn_mask 85 | ''' 86 | input_prompts = self._lm_tokenizer.batch_decode(input_ids, skip_special_tokens=True) 87 | 88 | max_prompt_length = self._rm_tokenizer.batch_encode_plus( 89 | input_prompts, 90 | return_tensors="pt", 91 | padding=True, 92 | ).input_ids.shape[1] 93 | 94 | input_prompts = [element for element in input_prompts for i in range(self._topk)] # (batch x topk, ) 95 | 96 | q = Queue() 97 | for i in range(self._num_gpus): 98 | q.put(i) 99 | for i in range(self._num_gpus): 100 | worker = Thread(target=process_prompts, args=(q,)) 101 | worker.start() 102 | q.join() 103 | 104 | else: 105 | ''' 106 | 1. use dict to find which token is chosen in last step 107 | 2. 
select that pkv and broadcast, select that attn_mask and broadcast 108 | ''' 109 | for i, (input_ids_i, input_ids_to_topk_idx_dict_i), in enumerate(zip(input_ids, self._previous_input_ids_to_topk_idx)): 110 | # skip if eos is being generated 111 | if input_ids_i[-1]==self._lm_tokenizer.eos_token_id: 112 | last_selected_topk_indices.append(-1) 113 | continue 114 | last_selected_topk_idx = input_ids_to_topk_idx_dict_i[input_ids_i[-1].item()] 115 | last_selected_topk_indices.append(last_selected_topk_idx) 116 | batch_size = scores.shape[0] 117 | 118 | # for example i, update its pkv and attn_mask on corresponding gpu(s) 119 | rows_per_gpu = int(np.ceil(batch_size * self._topk / self._num_gpus)) 120 | start, end = i*self._topk, (i+1)*self._topk-1 121 | 122 | start_gpu, end_gpu = start//rows_per_gpu, end//rows_per_gpu 123 | start_idx, end_idx = start%rows_per_gpu, end%rows_per_gpu 124 | 125 | selected_token_gpu = (start+last_selected_topk_idx)//rows_per_gpu 126 | selected_token_idx = (start+last_selected_topk_idx)%rows_per_gpu 127 | 128 | while start_gpu < end_gpu: 129 | rows = self._attention_mask[start_gpu].shape[0] # rows might be different from rows_per_gpu since the last gpu might have less rows 130 | self._attention_mask[start_gpu][start_idx:, :] = self._attention_mask[selected_token_gpu][selected_token_idx, :].repeat( 131 | rows-start_idx, 1) 132 | if start_gpu==selected_token_gpu: 133 | for layer_kv in self._past_key_values[start_gpu]: 134 | for e in layer_kv: 135 | e[start_idx:, :, :, :] = e[selected_token_idx, :, :, :].unsqueeze(0).repeat( 136 | rows-start_idx, 1, 1, 1) 137 | else: 138 | for layer_kv,layer_kv_selected in zip(self._past_key_values[start_gpu], self._past_key_values[selected_token_gpu]): 139 | for e, e_selected in zip(layer_kv, layer_kv_selected): 140 | e[start_idx:, :, :, :] = e_selected[selected_token_idx, :, :, :].unsqueeze(0).repeat( 141 | rows-start_idx, 1, 1, 1) 142 | start_idx = 0 143 | start_gpu += 1 144 | 145 | self._attention_mask[start_gpu][start_idx:end_idx+1, :] = self._attention_mask[selected_token_gpu][selected_token_idx, :].repeat( 146 | end_idx-start_idx+1, 1) 147 | if start_gpu==selected_token_gpu: 148 | for layer_kv in self._past_key_values[start_gpu]: 149 | for e in layer_kv: 150 | e[start_idx:end_idx+1, :, :, :] = e[selected_token_idx, :, :, :].unsqueeze(0).repeat( 151 | end_idx-start_idx+1, 1, 1, 1) 152 | else: # if selected token is not on the same machine with current token 153 | for layer_kv,layer_kv_selected in zip(self._past_key_values[start_gpu], self._past_key_values[selected_token_gpu]): 154 | for e, e_selected in zip(layer_kv, layer_kv_selected): 155 | e[start_idx:end_idx+1, :, :, :] = e_selected[selected_token_idx, :, :, :].unsqueeze(0).repeat( 156 | end_idx-start_idx+1, 1, 1, 1) 157 | 158 | # get candidate sequences reward 159 | batch_size = scores.shape[0] 160 | ids = topk_ids.reshape((batch_size*self._topk, 1)) 161 | candidate_tokens = self._lm_tokenizer.batch_decode(ids, skip_special_tokens=True) 162 | 163 | max_candidate_length = self._rm_tokenizer.batch_encode_plus( 164 | candidate_tokens, 165 | return_tensors="pt", 166 | padding=True, 167 | ).input_ids.shape[1] 168 | 169 | q = Queue() 170 | for i in range(self._num_gpus): 171 | q.put(i) 172 | for i in range(self._num_gpus): 173 | worker = Thread(target=do_normal_task, args=(q,)) 174 | worker.start() 175 | q.join() 176 | 177 | reward_scores = torch.cat(reward_scores, dim=0).reshape((-1, self._topk)) 178 | 179 | if self._data_container is not None: 180 | if self._step==0: # update 
cur_row on first step since last step is hard to track 181 | self._data_container['cur_row'] += batch_size 182 | cur_row = self._data_container['cur_row'] 183 | self._data_container['rewards'][cur_row-batch_size:cur_row, self._step, :] = reward_scores.cpu().numpy() # (rows, topk) 184 | self._data_container['logits'][cur_row-batch_size:cur_row, self._step, :] = topk_scores.cpu().numpy() # (rows, topk) 185 | if self._step!=0: 186 | self._data_container['selected_indices'][cur_row-batch_size:cur_row, self._step-1] = np.array(last_selected_topk_indices) # (rows, ) 187 | 188 | for score, id, ts in zip(scores, topk_ids, reward_scores): 189 | score[id] = self.apply_function(score[id], ts) 190 | inverse_id = torch.tensor(np.setdiff1d(range(len(score.cpu().numpy())), id.cpu().numpy()), device='cuda') 191 | score[inverse_id] = -float("Inf") # set all other scores to -inf 192 | 193 | # update step, pkv, attn_mask, and dict 194 | self._step+=1 195 | self._previous_input_ids_to_topk_idx = [ 196 | {ids.item():pos for pos,ids in enumerate(topk_ids_i)} for topk_ids_i in topk_ids 197 | ] 198 | return scores 199 | 200 | def get_reward(self, candidate_texts, past_key_values, past_attention_mask, gpu, max_candidate_length): 201 | with torch.inference_mode(): 202 | inputs = self._rm_tokenizer.batch_encode_plus( 203 | candidate_texts, 204 | return_tensors="pt", 205 | padding="max_length", 206 | truncation=True, 207 | max_length=max_candidate_length, 208 | ).to(f'cuda:{gpu}') 209 | 210 | # attention_mask with in between. e.g. [13,1052,38,50256,50256,11,50256] => [1,1,1,0,0,1,0] 211 | attention_mask = torch.cat((past_attention_mask, inputs.attention_mask), dim=-1) # (batch x topk, new_seq_length) 212 | position_ids = torch.cumsum(attention_mask, dim=-1)[:, past_attention_mask.shape[-1]:] # cumsum the attention to get correct pos id for each new token 213 | reward_scores, past_key_values = self.helper(inputs.input_ids, attention_mask, position_ids, past_key_values, gpu) 214 | return reward_scores, past_key_values, attention_mask 215 | 216 | # helper method that calls reward model and returns reward scores 217 | def helper(self, input_ids, attention_mask, position_ids, past_key_values, gpu): 218 | reward_model = self._reward_models[gpu] 219 | _, reward_logits, past_key_values = reward_model(input_ids=input_ids, 220 | attention_mask=attention_mask, 221 | position_ids=position_ids, 222 | labels=None, 223 | use_cache=True, 224 | past_key_values=past_key_values) 225 | # # save for future exploration 226 | # if reward_logits.shape[1] > 1: # classification case 227 | # # calculate expected value of reward, no matter 2 classes or 5 classes, normalized by ÷ (num_class-1) 228 | # device = torch.device(f'cuda:{gpu}') 229 | # probs = torch.softmax(reward_logits, dim=-1) # e.g. [0.1, 0.1, 0.1, 0.1, 0.6] 230 | # class_vals = torch.arange(reward_logits.shape[1], dtype=torch.float).unsqueeze(0).permute(1,0).to(device) # e.g. 
[0,1,2,3,4] 231 | # reward_scores = torch.matmul(probs, class_vals)/(reward_logits.shape[1]-1) # sum and divide by 4: 3.0/4.0=0.75 232 | # return reward_scores, past_key_values 233 | # else: 234 | # return reward_logits[:, 0], past_key_values 235 | return reward_logits[:, 0], past_key_values 236 | 237 | 238 | def get_past_key_values(self, contexts, gpu, max_prompt_length): 239 | with torch.inference_mode(): 240 | reward_model = self._reward_models[gpu] 241 | input_ids = self._rm_tokenizer.batch_encode_plus( 242 | contexts, 243 | return_tensors="pt", 244 | padding="max_length", 245 | truncation=True, 246 | max_length=max_prompt_length, 247 | ).to(f'cuda:{gpu}') 248 | _, _, past_key_values = reward_model(**input_ids, labels=None, use_cache=True) 249 | return past_key_values, input_ids.attention_mask 250 | 251 | def apply_function(self, original_score, reward_score): 252 | reward_score = torch.clamp(reward_score, min=0, max=1) 253 | if self._inverse: 254 | reward_score = 1-reward_score 255 | if self._method == "linear": 256 | return original_score + (reward_score*self._beta).to(original_score.dtype) 257 | else: 258 | raise ValueError(f"method {self._method} not supported") 259 | 260 | 261 | 262 | class RewardAugmentedLogitsProcessorNoPkv(LogitsProcessor): 263 | 264 | def __init__(self, lm_tokenizer, rm_tokenizer, reward_model, topk=20, 265 | method="linear", beta=30, inverse=False): 266 | self._lm_tokenizer = lm_tokenizer 267 | self._rm_tokenizer = rm_tokenizer 268 | self._reward_model = reward_model.to('cuda') 269 | self._reward_model.eval() 270 | self._topk = topk 271 | self._method = method 272 | self._beta = beta 273 | self._inverse = inverse 274 | 275 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 276 | _, topk_ids = torch.topk(scores, self._topk, dim=-1) # (batch, topk,) 277 | input_ids_enflated = input_ids.unsqueeze(1).expand((-1, self._topk, -1)) # (batch, topk, seq_len) 278 | candidate_input_ids = torch.cat((input_ids_enflated, topk_ids.unsqueeze(-1)), dim=-1) # (batch, topk, seq_len+1) 279 | candidate_input_ids_unroll = candidate_input_ids.reshape(( 280 | candidate_input_ids.shape[0]*candidate_input_ids.shape[1], -1)) # (batch*topk, seq_len+1) 281 | candidate_input_texts = self._lm_tokenizer.batch_decode(candidate_input_ids_unroll, skip_special_tokens=True) 282 | 283 | # return reward scores 284 | reward_scores = self.get_reward(candidate_input_texts).reshape((input_ids.shape[0], -1)) 285 | 286 | # apply function (topk_scores, logits) 287 | for score, id, rs in zip(scores, topk_ids, reward_scores): 288 | 289 | score[id] = self.apply_function(score[id], rs) 290 | inverse_id = torch.tensor(np.setdiff1d(range(len(score.cpu().numpy())), id.cpu().numpy()), device='cuda') 291 | score[inverse_id] = -float("Inf") # set all other scores to -inf 292 | return scores 293 | 294 | def get_reward(self, candidate_texts): 295 | with torch.inference_mode(): 296 | # tokenizer should be configured in RAD 297 | input_ids = self._rm_tokenizer.batch_encode_plus( 298 | candidate_texts, 299 | return_tensors="pt", 300 | padding=True, 301 | truncation=True, 302 | max_length=self._rm_tokenizer.max_length, 303 | ).to('cuda') 304 | 305 | _, reward, _ = self._reward_model(**input_ids, labels=None) 306 | return reward 307 | 308 | def apply_function(self, original_score, reward_score): 309 | reward_score = torch.clamp(reward_score, min=0, max=1) 310 | if self._inverse: 311 | reward_score = 1-reward_score 312 | if self._method == "linear": 313 | return 
original_score + (reward_score*self._beta).to(original_score.dtype) 314 | else: 315 | raise ValueError(f"method {self._method} not supported") -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import evaluate 2 | from transformers import AutoModelForCausalLM, AutoTokenizer 3 | import torch 4 | from tqdm.auto import tqdm 5 | import numpy as np 6 | 7 | 8 | # takes in a EvalPrediction and returns a dictionary string to metric values. 9 | def mse(eval_preds): 10 | predictions, labels = eval_preds # (eval set size, 1) 11 | mse_metric = evaluate.load("mse") 12 | return mse_metric.compute(predictions=predictions, references=labels) 13 | 14 | 15 | def distinctness(generations): 16 | unigrams, bigrams, trigrams = set(), set(), set() 17 | total_words = 0 18 | 19 | for gen in generations: 20 | o = gen.split(' ') 21 | total_words += len(o) 22 | unigrams.update(o) 23 | for i in range(len(o) - 1): 24 | bigrams.add(o[i] + '_' + o[i + 1]) 25 | for i in range(len(o) - 2): 26 | trigrams.add(o[i] + '_' + o[i + 1] + '_' + o[i + 2]) 27 | 28 | return len(unigrams) / total_words, len(bigrams) / total_words, len(trigrams) / total_words 29 | 30 | 31 | def compute_perplexity(args, generation, rad, device='cuda'): 32 | if "gpt2" in args.lm: 33 | model = AutoModelForCausalLM.from_pretrained('gpt2-xl', device_map='auto') 34 | tokenizer = AutoTokenizer.from_pretrained('gpt2-xl') 35 | 36 | else: # use lm itself for ppl evaluation 37 | model = rad._lm 38 | tokenizer = rad._lm_tokenizer 39 | 40 | perplexities = [] 41 | 42 | pbar = tqdm(generation, total=len(generation), desc='Evaluate Fluency') 43 | for row in pbar: 44 | prompt = row['prompt']['text'] 45 | prompt_input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device) 46 | 47 | with torch.inference_mode(): 48 | prompt_loss = model(prompt_input_ids, labels=prompt_input_ids)[0] 49 | prompt_loss *= (prompt_input_ids.shape[1]-1) 50 | 51 | for cont in row['generations']: 52 | cont = cont['text'] 53 | full_input_ids = tokenizer.encode(prompt+cont, return_tensors='pt').to(device) 54 | full_loss = model(full_input_ids, labels=full_input_ids)[0] * (full_input_ids.shape[1]-1) 55 | loss = (full_loss - prompt_loss) / (full_input_ids.shape[1] - prompt_input_ids.shape[1]) 56 | ppl = torch.exp(loss).item() 57 | 58 | if ppl < 1e5: 59 | perplexities.append(ppl) 60 | 61 | pbar.set_description( 62 | f'mean ppl = {np.mean(perplexities):.3f}' 63 | ) 64 | 65 | return perplexities -------------------------------------------------------------------------------- /utils/perspective_api.py: -------------------------------------------------------------------------------- 1 | import googleapiclient 2 | from googleapiclient import discovery 3 | import time 4 | from concurrent.futures import ThreadPoolExecutor, as_completed 5 | from typing import List 6 | from collections.abc import Iterator 7 | from threading import Lock 8 | import numpy as np 9 | 10 | PERSPECTIVE_API_KEY = "" # Place your KEY here 11 | QUOTA_IN_QPS = 1 # QPS>100 is optimal, see Readme.md for ways to increase API quota 12 | 13 | class RateLimiter(Iterator): 14 | """Iterator that yields a value at most once every 'interval' seconds.""" 15 | def __init__(self, interval): 16 | self.lock = Lock() 17 | self.interval = interval 18 | self.next_yield = 0 19 | 20 | def __next__(self): 21 | with self.lock: 22 | t = time.monotonic() 23 | if t < self.next_yield: 24 | time.sleep(self.next_yield - t) 25 
| t = time.monotonic() 26 | self.next_yield = t + self.interval 27 | 28 | 29 | class PerspectiveAPIMetric(): 30 | def __init__(self): 31 | self._API_KEY = PERSPECTIVE_API_KEY 32 | self.attributes = ['TOXICITY', 'SEVERE_TOXICITY', 'IDENTITY_ATTACK', 'INSULT', 'PROFANITY', 'THREAT'] 33 | 34 | def compute(self, texts) -> List[float]: 35 | # handle a signle or a list of examples 36 | def call_api(idx, text, retry=10): 37 | client = discovery.build( 38 | "commentanalyzer", 39 | "v1alpha1", 40 | developerKey=self._API_KEY, 41 | discoveryServiceUrl="https://commentanalyzer.googleapis.com/$discovery/rest?version=v1alpha1", 42 | static_discovery=False, 43 | ) 44 | try: 45 | analyze_request = { 46 | 'comment': {'text': text}, 47 | "languages": ["en"], 48 | 'requestedAttributes': {attribute:{} for attribute in self.attributes}, 49 | } 50 | response = client.comments().analyze(body=analyze_request).execute() 51 | score = [response['attributeScores'][attribute]['summaryScore']['value'] for attribute in self.attributes] 52 | if len(score) != len(self.attributes): 53 | score = [np.nan]*len(self.attributes) 54 | results.append((idx, score)) 55 | except googleapiclient.errors.HttpError as e: 56 | if retry == 0: 57 | print(e) 58 | print('HTTPError, not retrying') 59 | results.append((idx, [np.nan]*len(self.attributes))) 60 | else: 61 | time.sleep(0.1) 62 | print(f'HTTPError, retrying {retry-1}') 63 | call_api(idx, text, retry=retry-1) 64 | 65 | def work_function(idx, text): 66 | next(api_rate_limiter) 67 | return call_api(idx, text) 68 | 69 | results = [] 70 | api_rate_limiter = RateLimiter(1.5/QUOTA_IN_QPS) # adjust rate based on quota, use formula: rate >= 1/(quota in QPS). 71 | 72 | chunk_size = 500 73 | for i_chunk in range(0, len(texts), chunk_size): 74 | gt = texts[i_chunk:i_chunk+chunk_size] 75 | 76 | with ThreadPoolExecutor(max_workers=40) as executor: 77 | for idx, text in enumerate(gt): 78 | if text == "" or text is None: 79 | print("text is None or empty String") 80 | results.append((i_chunk+idx, [np.nan]*len(self.attributes))) 81 | executor.submit(work_function, i_chunk+idx, text) 82 | 83 | scores = [np.nan] * len(texts) 84 | for idx, score in results: 85 | scores[idx] = score 86 | return scores -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import ( 2 | AutoTokenizer, 3 | AutoModelForCausalLM, 4 | AutoModelForSequenceClassification, 5 | set_seed, 6 | pipeline, 7 | GPTNeoXConfig, 8 | GPTNeoXForCausalLM, 9 | LlamaForCausalLM, 10 | LlamaTokenizer, 11 | LlamaConfig, 12 | ) 13 | from datasets import load_dataset 14 | from torch import nn 15 | import torch 16 | from pathlib import Path 17 | import yaml 18 | import random 19 | import copy 20 | from torch.utils.data import ConcatDataset, Subset 21 | from reward_model import GPT2RewardModel 22 | from distutils.util import strtobool 23 | import json 24 | 25 | 26 | def get_dataset_name_and_kwargs_from_data_config(data_config): 27 | if isinstance(data_config, dict): 28 | name = list(data_config.keys())[0] 29 | 30 | # first copy the dict, then remove the size and fraction 31 | kwargs = copy.deepcopy(data_config[name]) 32 | 33 | kwargs.pop("fraction", None) 34 | kwargs.pop("size", None) 35 | return name, kwargs 36 | else: 37 | return data_config, {} 38 | 39 | 40 | def get_dataset( 41 | args, 42 | tokenizer 43 | ) -> tuple[ConcatDataset, dict[str, Subset]]: 44 | train_datasets, evals = [], {} 45 | 46 | for 
data_config in args.datasets: 47 | dataset_name, kwargs = get_dataset_name_and_kwargs_from_data_config(data_config) 48 | train, val = get_one_dataset(args, dataset_name, tokenizer) 49 | train_datasets.append(train) 50 | 51 | if val is not None: 52 | evals[dataset_name] = Subset(val, list(range(min(len(val), args.eval_size)))) if args.eval_size else val 53 | 54 | train = ConcatDataset(train_datasets) 55 | return train, evals 56 | 57 | 58 | def get_one_dataset( 59 | args, 60 | dataset_name, 61 | tokenizer 62 | ): 63 | if dataset_name == "sst2": 64 | dataset = load_dataset("sst2") 65 | dataset = dataset.rename_columns({"label": "labels", "sentence": "text"}) 66 | 67 | columns = dataset['train'].column_names 68 | columns_to_keep = ["text", "labels"] 69 | dataset = dataset.remove_columns(list(set(columns)-set(columns_to_keep))) 70 | 71 | def tokenize_dataset(examples): 72 | # remove the space at the end of each sentence 73 | return tokenizer([e[:-1] for e in examples["text"]], truncation=True, max_length=args.max_length) 74 | 75 | dataset = dataset.map(tokenize_dataset, batched=True) 76 | train, eval = dataset['train'], dataset['validation'] 77 | 78 | elif dataset_name == "amazon_polarity": 79 | dataset = load_dataset("amazon_polarity") 80 | dataset = dataset.rename_columns({"label": "labels", "content": "text"}) 81 | 82 | columns = dataset['train'].column_names 83 | columns_to_keep = ["text", "labels"] 84 | dataset = dataset.remove_columns(list(set(columns)-set(columns_to_keep))) 85 | 86 | def tokenize_dataset(examples): 87 | return tokenizer(examples["text"], truncation=True, max_length=args.max_length) 88 | 89 | dataset = dataset.map(tokenize_dataset, batched=True) 90 | train, eval = dataset['train'], dataset['test'] 91 | 92 | elif dataset_name == "jigsaw_unintended_bias": 93 | dataset = load_dataset("jigsaw_unintended_bias", data_dir=args.jigsaw_dir) 94 | columns = dataset['train'].column_names 95 | columns_to_keep = [ 96 | "comment_text", "target", "severe_toxicity", "obscene", 97 | "identity_attack", "insult", "threat", "sexual_explicit" 98 | ] 99 | dataset = dataset.remove_columns(list(set(columns)-set(columns_to_keep))) 100 | dataset = dataset.map( 101 | lambda example: {"labels": [example["target"], 102 | example["severe_toxicity"], 103 | example["obscene"], 104 | example["identity_attack"], 105 | example["insult"], 106 | example["threat"], 107 | example["sexual_explicit"]]}, 108 | remove_columns=columns_to_keep[1:] # keep "comment_text" and "labels" only 109 | ) 110 | dataset = dataset.rename_columns({"comment_text": "text"}) 111 | 112 | def tokenize_dataset(examples): 113 | return tokenizer(examples["text"], truncation=True, max_length=args.max_length) 114 | 115 | dataset = dataset.map(tokenize_dataset, batched=True) 116 | train, eval = dataset['train'], dataset['test_public_leaderboard'] 117 | 118 | return train, eval 119 | 120 | 121 | def prepare_lm(model_name): 122 | 123 | if model_name == "gpt2-large": 124 | lm_tokenizer = AutoTokenizer.from_pretrained("gpt2-large") 125 | lm = AutoModelForCausalLM.from_pretrained("gpt2-large", device_map='balanced_low_0') 126 | max_length = 1024 127 | elif model_name == 'gpt-neox-20b': 128 | lm_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") 129 | # max_memory={i: "32GiB" for i in range(4)} # assume 4 GPUs 130 | # max_memory[0] = "16GiB" 131 | # configuration = GPTNeoXConfig() 132 | # with init_empty_weights(): 133 | # model = GPTNeoXForCausalLM(configuration) 134 | # device_map = infer_auto_device_map(model, 
no_split_module_classes=["GPTNeoXLayer"], max_memory=max_memory) 135 | # device_map['embed_out'] = device_map['gpt_neox.embed_in'] # put output layer on the same device as the input layer 136 | # lm = GPTNeoXForCausalLM.from_pretrained( 137 | # "EleutherAI/gpt-neox-20b", device_map=device_map, torch_dtype=torch.float16) 138 | lm = GPTNeoXForCausalLM.from_pretrained( 139 | "EleutherAI/gpt-neox-20b", device_map='balanced_low_0', torch_dtype=torch.float16) 140 | max_length = 2048 141 | 142 | elif "llama" in model_name or "Llama" in model_name: 143 | model_name = f"meta-llama/{model_name}" 144 | lm_tokenizer = LlamaTokenizer.from_pretrained(model_name) 145 | lm = LlamaForCausalLM.from_pretrained( 146 | model_name, device_map='balanced_low_0', torch_dtype=torch.bfloat16) 147 | max_length = 4096 148 | 149 | # set pad_token_id to eos_token_id because GPT2/Llama does not have a PAD token 150 | lm_tokenizer.pad_token = lm_tokenizer.eos_token 151 | lm_tokenizer.padding_side = 'left' # left padding while generating 152 | lm.config.pad_token_id = lm.config.eos_token_id 153 | 154 | return lm, lm_tokenizer, max_length 155 | 156 | 157 | def get_rm_tokenizer(args): 158 | tokenizer = AutoTokenizer.from_pretrained(args.reward_model_name) 159 | tokenizer.pad_token = tokenizer.eos_token 160 | tokenizer.padding_side = 'right' 161 | tokenizer.max_length = args.max_length 162 | return tokenizer 163 | 164 | 165 | def get_reward_model(args): 166 | if "gpt2" in args.reward_model_name: 167 | model = GPT2RewardModel( 168 | reward_model_name=args.reward_model_name, 169 | out_features=args.out_features, 170 | loss_fn=args.loss_fn 171 | ) 172 | return model 173 | 174 | 175 | def _strtobool(x): 176 | return bool(strtobool(x)) 177 | 178 | 179 | def read_yamls(dir): 180 | args = {} 181 | no_conf = True 182 | 183 | for config_file in Path(dir).glob("**/*.yaml"): 184 | no_conf = False 185 | with config_file.open("r") as f: 186 | args.update(yaml.safe_load(f)) 187 | 188 | if no_conf: 189 | print(f"WARNING: No yaml files found in {dir}") 190 | 191 | return args --------------------------------------------------------------------------------
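Usage sketch. The repository's own entry points (`rad.py`, `generate.py`, `eval_sentiment.py`, `eval_toxicity.py`) are the real drivers of the classes listed above; purely for orientation, the snippet below shows one way `RewardAugmentedLogitsProcessorNoPkv` from `utils/logits_processor.py` could be wired directly into Hugging Face `generate()`. This is a minimal, hypothetical sketch, not the repository's pipeline: the import paths, the `gpt2` reward-model backbone, `out_features=1`, `loss_fn="mse"`, the checkpoint path, and the prompt are all assumptions.
```
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList
from reward_model import GPT2RewardModel                      # import path assumed (as in utils/utils.py)
from utils.logits_processor import RewardAugmentedLogitsProcessorNoPkv  # import path assumed

# Base language model (left padding, EOS used as PAD, as in prepare_lm)
lm_tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
lm_tokenizer.pad_token = lm_tokenizer.eos_token
lm_tokenizer.padding_side = "left"
lm = AutoModelForCausalLM.from_pretrained("gpt2-large").to("cuda")
lm.config.pad_token_id = lm.config.eos_token_id

# Reward-model tokenizer (right padding; get_reward() reads the .max_length attribute)
rm_tokenizer = AutoTokenizer.from_pretrained("gpt2")
rm_tokenizer.pad_token = rm_tokenizer.eos_token
rm_tokenizer.padding_side = "right"
rm_tokenizer.max_length = 1024

# Reward model; out_features / loss_fn values and the checkpoint path are placeholders
reward_model = GPT2RewardModel(reward_model_name="gpt2", out_features=1, loss_fn="mse")
# reward_model.load_state_dict(torch.load("reward_modeling/saved_models/CHECKPOINT.pt"))  # path assumed

# RAD without past_key_values reuse: rescore the top-k continuations at every step
processors = LogitsProcessorList([
    RewardAugmentedLogitsProcessorNoPkv(
        lm_tokenizer, rm_tokenizer, reward_model,
        topk=20, method="linear", beta=50, inverse=True,  # inverse=True steers toward low reward
    )
])

inputs = lm_tokenizer("The movie was", return_tensors="pt").to("cuda")
out = lm.generate(**inputs, max_new_tokens=20, do_sample=True, logits_processor=processors)
print(lm_tokenizer.decode(out[0], skip_special_tokens=True))
```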