├── src ├── trl │ ├── __init__.py │ ├── __pycache__ │ │ ├── core.cpython-39.pyc │ │ ├── gpt2.cpython-39.pyc │ │ ├── ppo.cpython-39.pyc │ │ └── __init__.cpython-39.pyc │ ├── _nbdev.py │ ├── core.py │ ├── gpt2.py │ └── ppo.py ├── topic │ ├── topic.txt │ └── prefix.txt ├── topic_test.py ├── sentiment_test.py ├── topic_train.py └── sentiment_train.py ├── Overview.png ├── ACL2023-CriticControl-Slides.pdf ├── requirements.txt └── README.md /src/trl/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.1" 2 | -------------------------------------------------------------------------------- /Overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/minbeomkim/CriticControl/HEAD/Overview.png -------------------------------------------------------------------------------- /src/topic/topic.txt: -------------------------------------------------------------------------------- 1 | Space 2 | Politics 3 | Military 4 | Legal 5 | Science 6 | Religion 7 | Computers -------------------------------------------------------------------------------- /ACL2023-CriticControl-Slides.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/minbeomkim/CriticControl/HEAD/ACL2023-CriticControl-Slides.pdf -------------------------------------------------------------------------------- /src/trl/__pycache__/core.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/minbeomkim/CriticControl/HEAD/src/trl/__pycache__/core.cpython-39.pyc -------------------------------------------------------------------------------- /src/trl/__pycache__/gpt2.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/minbeomkim/CriticControl/HEAD/src/trl/__pycache__/gpt2.cpython-39.pyc -------------------------------------------------------------------------------- /src/trl/__pycache__/ppo.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/minbeomkim/CriticControl/HEAD/src/trl/__pycache__/ppo.cpython-39.pyc -------------------------------------------------------------------------------- /src/trl/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/minbeomkim/CriticControl/HEAD/src/trl/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | trl 2 | datasets==1.9.0 3 | nltk==3.6.2 4 | numpy==1.19.5 5 | pandas==1.1.5 6 | scikit-learn==0.24.2 7 | sentencepiece==0.1.96 8 | torch==1.8.0 9 | tqdm==4.61.2 10 | transformers==4.8.2 11 | wandb==0.10.33 12 | -------------------------------------------------------------------------------- /src/topic/prefix.txt: -------------------------------------------------------------------------------- 1 | In summary 2 | This essay discusses 3 | Views on 4 | The connection 5 | Foundational to this is 6 | To review 7 | In brief, 8 | An illustration of 9 | Furthermore, 10 | The central theme 11 | To conclude, 12 | The key aspect' 13 | Prior to this 14 | Emphasised are 15 | To summarise 16 | The relationship 17 | More importantly 18 | It has been 
shown 19 | The issue focused on 20 | In this essay 21 | -------------------------------------------------------------------------------- /src/trl/_nbdev.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED BY NBDEV! DO NOT EDIT! 2 | 3 | __all__ = ["index", "modules", "custom_doc_links", "git_url"] 4 | 5 | index = {"WANDB_PADDING": "00-core.ipynb", 6 | "flatten_dict": "00-core.ipynb", 7 | "stack_dicts": "00-core.ipynb", 8 | "add_suffix": "00-core.ipynb", 9 | "pad_to_size": "00-core.ipynb", 10 | "logprobs_from_logits": "00-core.ipynb", 11 | "whiten": "00-core.ipynb", 12 | "clip_by_value": "00-core.ipynb", 13 | "entropy_from_logits": "00-core.ipynb", 14 | "average_torch_dicts": "00-core.ipynb", 15 | "stats_to_np": "00-core.ipynb", 16 | "listify_batch": "00-core.ipynb", 17 | "build_bert_batch_from_txt": "00-core.ipynb", 18 | "CausalLMOutputWithCrossAttentions": "01-gpt2-with-value-head.ipynb", 19 | "ValueHead": "01-gpt2-with-value-head.ipynb", 20 | "GPT2HeadWithValueModel": "01-gpt2-with-value-head.ipynb", 21 | "respond_to_batch": "01-gpt2-with-value-head.ipynb", 22 | "AdaptiveKLController": "02-ppo.ipynb", 23 | "FixedKLController": "02-ppo.ipynb", 24 | "PPOTrainer": "02-ppo.ipynb"} 25 | 26 | modules = ["core.py", 27 | "gpt2.py", 28 | "ppo.py"] 29 | 30 | doc_url = "https://lvwerra.github.io/trl/" 31 | 32 | git_url = "https://github.com/lvwerra/trl/tree/master/" 33 | 34 | def custom_doc_links(name): return None 35 | -------------------------------------------------------------------------------- /src/topic_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import wandb 3 | import time 4 | import os 5 | from tqdm import tqdm 6 | import numpy as np 7 | import pandas as pd 8 | from collections import Counter 9 | tqdm.pandas() 10 | 11 | from datasets import load_dataset 12 | import datasets 13 | 14 | from transformers import GPT2Tokenizer, T5Tokenizer, BertTokenizer, BartForConditionalGeneration, PreTrainedTokenizerFast, BartTokenizer 15 | from transformers import AutoModelForSequenceClassification, AutoTokenizer, T5ForConditionalGeneration, BertForSequenceClassification, GPT2LMHeadModel 16 | from transformers import AutoTokenizer, pipeline, top_k_top_p_filtering 17 | import torch.nn.functional as F 18 | import torch 19 | 20 | from trl.gpt2 import GPT2HeadWithValueModel, topic_generation 21 | from trl.ppo import PPOTrainer, ppo_initialize 22 | from trl.core import build_bert_batch_from_txt, listify_batch 23 | 24 | from evaluate import load 25 | from rouge_score import rouge_scorer, scoring 26 | from nltk.tokenize import sent_tokenize 27 | from distinct import distinct 28 | 29 | config = { 30 | "model_name": "model/gpt2-xl-critic", 31 | "batch_size": 1, 32 | "forward_batch_size": 1, 33 | "ppo_epochs": 4, 34 | "lr": 1.41e-5, 35 | "init_kl_coef":0.2, 36 | "target": 6, 37 | "horizon":10000, 38 | "gamma":1, 39 | "lam":0.95, 40 | "cliprange": .2, 41 | "cliprange_value":.2, 42 | "vf_coef":.1, 43 | "topic": 'topic/topic.txt', 44 | "prompt": 'topic/prompt.txt', 45 | } 46 | 47 | device0 = torch.device("cuda:0") 48 | device1 = torch.device("cuda:1") 49 | 50 | sent_kwargs = { 51 | "return_all_scores": True, 52 | "function_to_apply": "none", 53 | "batch_size": config["forward_batch_size"] 54 | } 55 | 56 | gpt2_model = GPT2HeadWithValueModel.from_pretrained(config['model_name']) 57 | gpt2_tokenizer = AutoTokenizer.from_pretrained(config['model_name']) 58 | gpt2_tokenizer.pad_token = 
gpt2_tokenizer.eos_token 59 | prompt = ppo_initialize(config['topic'], config['prompt']) 60 | 61 | gpt2_model.to(device0) 62 | 63 | 64 | ###################### Experiment Setting ################### 65 | 66 | result_data = dict() 67 | 68 | #### get response from gpt2 and gpt2_ref 69 | with torch.no_grad(): 70 | 71 | #### Get response from gpt2 72 | response_list = [] 73 | for i in range(len(prompt)): 74 | query = gpt2_tokenizer.encode(prompt[i][1], return_tensors="pt").to(device0) 75 | response = topic_generation(gpt2_model, query) 76 | response_result = ':'.join(gpt2_tokenizer.decode(response.squeeze(), skip_special_tokens=True).split(':')[1:])[1:] 77 | response_list.append([prompt[i][0], response_result.replace("\n", " ")]) 78 | 79 | sentences = [] 80 | topics = [] 81 | for i in range(len(prompt)): 82 | sentences.append(response_list[i][1]) 83 | topics.append(response_list[i][0]) 84 | 85 | result_data['sentences'] = sentences 86 | result_data['topics'] = topics 87 | 88 | df_results = pd.DataFrame(result_data) 89 | save = df_results.to_json('json/topic.json', orient='table') 90 | -------------------------------------------------------------------------------- /src/sentiment_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import wandb 3 | import time 4 | import os 5 | from tqdm import tqdm 6 | import numpy as np 7 | import pandas as pd 8 | from collections import Counter 9 | tqdm.pandas() 10 | 11 | from datasets import load_dataset 12 | import datasets 13 | 14 | from transformers import GPT2Tokenizer, T5Tokenizer, BertTokenizer, BartForConditionalGeneration, PreTrainedTokenizerFast, BartTokenizer 15 | from transformers import AutoModelForSequenceClassification, AutoTokenizer, T5ForConditionalGeneration, BertForSequenceClassification, GPT2LMHeadModel 16 | from transformers import AutoTokenizer, pipeline, top_k_top_p_filtering 17 | import torch.nn.functional as F 18 | import torch 19 | 20 | # from trl.gpt2 import GPT2HeadWithValueModel 21 | from trl.gpt2 import GPT2HeadWithValueModel, sentiment_generation 22 | from trl.ppo import PPOTrainer 23 | from trl.core import build_bert_batch_from_txt, listify_batch 24 | 25 | from evaluate import load 26 | from rouge_score import rouge_scorer, scoring 27 | from nltk.tokenize import sent_tokenize 28 | from distinct import distinct 29 | 30 | config = { 31 | "model_name": "model/gpt2-xl-critic", 32 | "cls_model_name": "model/distilbert-imdb", 33 | "lr": 1.41e-5, 34 | "init_kl_coef":0.2, 35 | "target": 6, 36 | "horizon":10000, 37 | "gamma":1, 38 | "lam":0.95, 39 | "cliprange": .2, 40 | "cliprange_value":.2, 41 | "vf_coef":.1, 42 | } 43 | 44 | # load imdb with datasets 45 | ds = load_dataset('imdb', split='test') 46 | ds = ds.rename_columns({'text': 'review', 'label': 'sentiment'}) 47 | 48 | device0 = torch.device("cuda:0") 49 | device1 = torch.device("cuda:1") 50 | 51 | sent_kwargs = { 52 | "return_all_scores": True, 53 | "function_to_apply": "none", 54 | "batch_size": config.get("forward_batch_size", 1)  # "forward_batch_size" is not set in this config; fall back to 1 55 | } 56 | 57 | gpt2_model = GPT2HeadWithValueModel.from_pretrained(config['model_name']) 58 | 59 | gpt2_tokenizer = AutoTokenizer.from_pretrained(config['model_name']) 60 | gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token 61 | gpt2_model.to(device1) 62 | input_size = 8 63 | 64 | def tokenize(sample): 65 | sample["tokens"] = gpt2_tokenizer.encode(sample["review"])[:input_size] 66 | sample["query"] = gpt2_tokenizer.decode(sample["tokens"]) 67 | return sample 68 | 69 | ds = ds.map(tokenize, 
batched=False) 70 | 71 | bs = 25000 72 | result_data = dict() 73 | ds.set_format("pandas") 74 | df_batch = ds[:] 75 | result_data['query'] = df_batch['query'].tolist() 76 | query_tensors = df_batch['tokens'].tolist() 77 | response_tensors = [] 78 | 79 | #### get response from gpt2 and gpt2_ref 80 | with torch.no_grad(): 81 | gpt2_model.eval() 82 | 83 | for i in range(bs): 84 | response = sentiment_generation(gpt2_model, torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device1)) 85 | response_tensors.append(response[0]) 86 | 87 | #### decode responses 88 | result_data['texts'] = [gpt2_tokenizer.decode(response_tensors[i]) for i in range(bs)] 89 | 90 | #### sentiment analysis of query/response pairs before/after 91 | texts = [q + r for q,r in zip(result_data['query'], result_data['texts'])] 92 | 93 | df_results = pd.DataFrame(result_data) 94 | save = df_results.to_json('json/sentiment.json', orient='table') 95 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CriticControl 2 | 3 | ![alt text](Overview.png "Main Figure") 4 | 5 | This is the GitHub repository of **"[Critic-Guided Decoding for Controlled Text Generation](https://aclanthology.org/2023.findings-acl.281/)"**, accepted at *ACL 2023 Findings*. Also, you can download ***a summarized PDF*** on this repository. (ACL2023-CriticControl-Slides) 6 | 7 | Use the following to cite our paper: 8 | 9 | ``` 10 | @inproceedings{kim-etal-2023-critic, 11 | title = "Critic-Guided Decoding for Controlled Text Generation", 12 | author = "Kim, Minbeom and 13 | Lee, Hwanhee and 14 | Yoo, Kang Min and 15 | Park, Joonsuk and 16 | Lee, Hwaran and 17 | Jung, Kyomin", 18 | booktitle = "Findings of the Association for Computational Linguistics: ACL 2023", 19 | month = jul, 20 | year = "2023", 21 | address = "Toronto, Canada", 22 | publisher = "Association for Computational Linguistics", 23 | url = "https://aclanthology.org/2023.findings-acl.281", 24 | pages = "4598--4612", 25 | abstract = "Steering language generation towards objectives or away from undesired content has been a long-standing goal in utilizing language models (LM). Recent work has demonstrated reinforcement learning and weighted decoding as effective approaches to achieve a higher level of language control and quality with pros and cons. In this work, we propose a novel critic decoding method for controlled language generation (CriticControl) that combines the strengths of reinforcement learning and weighted decoding. Specifically, we adopt the actor-critic framework and train an LM-steering critic from reward models. Similar to weighted decoding, our method freezes the language model and manipulates the output token distribution using a critic to improve training efficiency and stability. Evaluation of our method on three controlled generation tasks, topic control, sentiment control, and detoxification, shows that our approach generates more coherent and well-controlled texts than previous methods. In addition, CriticControl demonstrates superior generalization ability in zero-shot settings. 
Human evaluation studies also corroborate our findings.", 26 | } 27 | ``` 28 | 29 | ## Create conda environment and install requirements 30 | ``` 31 | conda create -n CriticControl python=3.8 && conda activate CriticControl 32 | pip install -r requirements.txt 33 | ``` 34 | 35 | ## Train your own Control Codes 36 | Both Topic and Sentiment Control require your own control codes. For the Topic Control task, the default settings are in 'src/topic/*.txt'. For Sentiment Control, the default setting is not binary but steers only toward 'Positive'. If you want binary control, create control codes for the Sentiment Control task as well, in the same way as for Topic Control. 37 | 38 | ``` 39 | python3 topic_train.py --model_name gpt2-xl --steps 40000 --batch_size 32 --topic topic/topic.txt --prompt topic/prompt.txt 40 | ``` 41 | 42 | You can set any topic as your own control code. As shown in the paper, CriticControl can steer generation toward a wide range of themes. 43 | 44 | ## Inference with CriticControl 45 | You can find the zero-shot control capability of CriticControl in the **[paper](https://aclanthology.org/2023.findings-acl.281/)**. You can also choose topics other than the trained topic codes (such as Donald Trump or New York Travel). 46 | 47 | ``` 48 | python3 topic_test.py --model_name [your directory] --topic topic/inference_topic.txt --prompt topic/prompt.txt 49 | ``` 50 | 51 | If you want to leverage other decoding methods, adapt your own decoder in the `## CriticControl` decoding sections of 'src/trl/gpt2.py'. 52 | 53 | 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /src/topic_train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import wandb 3 | import time 4 | import os 5 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 6 | from tqdm import tqdm 7 | import numpy as np 8 | import pandas as pd 9 | tqdm.pandas() 10 | 11 | from datasets import load_dataset 12 | 13 | import torch.nn.functional as F 14 | import torch 15 | 16 | from transformers import AutoTokenizer, pipeline 17 | 18 | from trl.gpt2 import GPT2HeadWithValueModel 19 | from trl.ppo import PPOTrainer, ppo_initialize 20 | from trl.core import build_bert_batch_from_txt, listify_batch 21 | 22 | import random 23 | 24 | config = { 25 | "model_name": "gpt2-xl", 26 | "steps": 40000, 27 | "batch_size": 32, 28 | "forward_batch_size": 32, 29 | "ppo_epochs": 1, 30 | "lr": 2.82e-6, 31 | "init_kl_coef":0.2, 32 | "target": 6, 33 | "horizon":10000, 34 | "gamma":0.99, 35 | "lam":0.95, 36 | "cliprange": .2, 37 | "cliprange_value":.2, 38 | "vf_coef":.1, 39 | "topic": 'topic/topic.txt', 40 | "prompt": 'topic/prompt.txt', 41 | } 42 | 43 | device0 = torch.device("cuda:0") 44 | device1 = torch.device("cuda:1") 45 | print(device0) 46 | print(device1) 47 | 48 | sent_kwargs = { 49 | "return_all_scores": True, 50 | "function_to_apply": "none", 51 | "batch_size": config["forward_batch_size"] 52 | } 53 | 54 | gpt2_model = GPT2HeadWithValueModel.from_pretrained(config['model_name']) 55 | gpt2_model_ref = GPT2HeadWithValueModel.from_pretrained(config['model_name']) 56 | 57 | gpt2_tokenizer = AutoTokenizer.from_pretrained(config['model_name']) 58 | gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token 59 | 60 | gpt2_model.to(device0) 61 | gpt2_model_ref.to(device0) 62 | topic_pipe = pipeline("zero-shot-classification","facebook/bart-large-mnli", tokenizer='facebook/bart-large-mnli', device=0) # reward 63 | 64 | ############# for policy freezing ############ 65 | for module in 
[gpt2_model.transformer, gpt2_model.lm_head]: 66 | for param in module.parameters(): 67 | param.requires_grad = False 68 | 69 | gen_kwargs = { 70 | "top_k": 0.0, 71 | "top_p": 1.0, 72 | "do_sample": True, 73 | "pad_token_id": gpt2_tokenizer.eos_token_id, 74 | "max_new_tokens": 80, 75 | "temperature": 2.8 76 | } 77 | 78 | ppo_trainer = PPOTrainer(gpt2_model, gpt2_model_ref, gpt2_tokenizer, **config) 79 | prompt = ppo_initialize(config['topic'], config['prompt']) 80 | batch_size = config['batch_size'] 81 | wandb.init(config=config)  # start a wandb run; wandb.log() below requires one 82 | for epoch in tqdm(range(500)): 83 | 84 | print(epoch) 85 | 86 | torch.cuda.empty_cache() 87 | logs = dict() 88 | game_data = dict() 89 | timing = dict() 90 | t0 = time.time() 91 | 92 | #### Get response from gpt2 93 | t = time.time() 94 | query_tensors = [] 95 | response_tensors = [] 96 | response_list = [] 97 | prompt_list = random.sample(prompt, batch_size) 98 | for i in range(batch_size): 99 | query = gpt2_tokenizer.encode(prompt_list[i][1], return_tensors="pt").to(device0) 100 | response = gpt2_model.generate(query, **gen_kwargs) 101 | query_tensors.append(query.squeeze()) 102 | response_tensors.append(response.squeeze()) 103 | response_list.append([prompt_list[i][0], ':'.join(gpt2_tokenizer.decode(response.squeeze(), skip_special_tokens=True, clean_up_tokenization_spaces=False).split(':')[1:])[1:]]) 104 | timing['time/get_response'] = time.time()-t 105 | 106 | # MNLI Score 107 | rewards = [] 108 | for i in range(batch_size): 109 | prob = topic_pipe(response_list[i][1], response_list[i][0], multi_label = False)["scores"][0] 110 | logit = -np.log(1/prob -1)+4 111 | rewards.append(logit) 112 | rewards = torch.tensor(rewards).to(device0) 113 | timing['time/get_sentiment_preds'] = time.time()-t 114 | print('Total Rewards: ', torch.mean(rewards)) 115 | 116 | # Run PPO training 117 | t = time.time() 118 | rewards = rewards.cpu() 119 | rewards = rewards.to(device0) 120 | stats = ppo_trainer.step(query_tensors, response_tensors, rewards) 121 | timing['time/optimization'] = time.time()-t 122 | 123 | # Log everything 124 | timing['time/epoch'] = time.time()-t0 125 | logs.update(timing) 126 | logs.update(stats) 127 | logs['env/reward_mean'] = torch.mean(rewards).cpu().numpy() 128 | logs['env/reward_std'] = torch.std(rewards).cpu().numpy() 129 | logs['env/reward_dist'] = rewards.cpu().numpy() 130 | wandb.log(logs) 131 | 132 | experiment_name = 'model/gpt2-xl-critic'  # assumed save path, matching sentiment_train.py and the test configs 133 | 134 | os.makedirs(experiment_name) 135 | gpt2_model.save_pretrained(experiment_name) 136 | gpt2_tokenizer.save_pretrained(experiment_name) 137 | -------------------------------------------------------------------------------- /src/sentiment_train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import wandb 3 | import time 4 | import os 5 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 6 | from tqdm import tqdm 7 | import numpy as np 8 | import pandas as pd 9 | tqdm.pandas() 10 | 11 | from datasets import load_dataset 12 | 13 | from transformers import AutoTokenizer, pipeline 14 | 15 | from trl.gpt2 import GPT2HeadWithValueModel 16 | from trl.ppo import PPOTrainer 17 | from trl.core import build_bert_batch_from_txt, listify_batch 18 | 19 | config = { 20 | "model_name": "gpt2-xl", 21 | "cls_model_name": "model/distilbert-imdb", 22 | "steps": 40000, 23 | "batch_size": 128, 24 | "forward_batch_size": 32, 25 | "ppo_epochs": 4, 26 | "lr": 1.41e-5, 27 | "init_kl_coef":0.2, 28 | "target": 6, 29 | "horizon":10000, 30 | "gamma":0.99, 31 | "lam":0.95, 32 | "cliprange": .2, 33 | "cliprange_value":.2, 34 | 
"vf_coef":.1, 35 | } 36 | 37 | experiment_name = 'model/gpt2-xl-critic' 38 | 39 | # load imdb with datasets 40 | ds = load_dataset('imdb', split='train') 41 | ds = ds.rename_columns({'text': 'review', 'label': 'sentiment'}) 42 | 43 | device0 = torch.device("cuda:0") 44 | device1 = torch.device("cuda:1") 45 | 46 | sent_kwargs = { 47 | "return_all_scores": True, 48 | "function_to_apply": "none", 49 | "batch_size": config["forward_batch_size"] 50 | } 51 | 52 | gpt2_model = GPT2HeadWithValueModel.from_pretrained(config['model_name']) 53 | gpt2_model_ref = GPT2HeadWithValueModel.from_pretrained(config['model_name']) 54 | 55 | gpt2_tokenizer = AutoTokenizer.from_pretrained(config['model_name']) 56 | gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token 57 | 58 | gpt2_model.to(device0) 59 | gpt2_model_ref.to(device0) 60 | sentiment_pipe = pipeline("sentiment-analysis",config['cls_model_name'], device=0) # reward 61 | 62 | # Freezing LM 63 | for module in [gpt2_model.transformer, gpt2_model.lm_head]: 64 | for param in module.parameters(): 65 | param.requires_grad = False 66 | 67 | input_size = 32 68 | 69 | def tokenize(sample): 70 | sample["tokens"] = gpt2_tokenizer.encode(sample["review"])[:input_size()] 71 | sample["query"] = gpt2_tokenizer.decode(sample["tokens"]) 72 | return sample 73 | 74 | ds = ds.map(tokenize, batched=False) 75 | 76 | 77 | gen_kwargs = { 78 | "top_k": 0.0, 79 | "top_p": 1.0, 80 | "do_sample": True, 81 | "pad_token_id": gpt2_tokenizer.eos_token_id, 82 | "max_new_tokens": 25, 83 | "temperature": 2.0 84 | } 85 | 86 | def collater(data): 87 | return dict((key, [d[key] for d in data]) for key in data[0]) 88 | 89 | dataloader = torch.utils.data.DataLoader(ds, batch_size=config['batch_size'], collate_fn=collater, shuffle=True) 90 | ppo_trainer = PPOTrainer(gpt2_model, gpt2_model_ref, gpt2_tokenizer, **config) 91 | total_ppo_epochs = int(np.ceil(config["steps"]/config['batch_size'])) 92 | 93 | for epoch, batch in tqdm(zip(range(total_ppo_epochs), iter(dataloader))): 94 | logs, timing = dict(), dict() 95 | t0 = time.time() 96 | query_tensors = [torch.tensor(t).long().to(device0) for t in batch["tokens"]] 97 | 98 | #### Get response from gpt2 99 | t = time.time() 100 | response_tensors = [] 101 | for i in range(config['batch_size']): 102 | response = gpt2_model.generate(query_tensors[i].unsqueeze(dim=0), **gen_kwargs) 103 | response_tensors.append(response.squeeze()) 104 | batch['response'] = [gpt2_tokenizer.decode(r.squeeze()) for r in response_tensors] 105 | timing['time/get_response'] = time.time()-t 106 | 107 | #### Compute sentiment score 108 | t = time.time() 109 | texts = [q + r for q,r in zip(batch['query'], batch['response'])] 110 | 111 | print(texts[0]) 112 | print(' ') 113 | 114 | # print(texts) 115 | 116 | pipe_outputs = sentiment_pipe(texts, **sent_kwargs) 117 | rewards = torch.tensor([output[1]["score"] for output in pipe_outputs]).to(device0) 118 | timing['time/get_sentiment_preds'] = time.time()-t 119 | 120 | print(torch.mean(rewards)) 121 | print(' ') 122 | 123 | #### Run PPO step 124 | t = time.time() 125 | stats = ppo_trainer.step(query_tensors, response_tensors, rewards) 126 | timing['time/optimization'] = time.time()-t 127 | 128 | #### Log everything 129 | timing['time/epoch'] = time.time()-t0 130 | logs.update(timing) 131 | logs.update(stats) 132 | logs['env/reward_mean'] = torch.mean(rewards).cpu().numpy() 133 | logs['env/reward_std'] = torch.std(rewards).cpu().numpy() 134 | logs['env/reward_dist'] = rewards.cpu().numpy() 135 | 136 | 137 | 
os.makedirs(experiment_name) 138 | gpt2_model.save_pretrained(experiment_name) 139 | gpt2_tokenizer.save_pretrained(experiment_name) 140 | -------------------------------------------------------------------------------- /src/trl/core.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: nbs/00-core.ipynb (unless otherwise specified). 2 | 3 | __all__ = ['WANDB_PADDING', 'flatten_dict', 'stack_dicts', 'add_suffix', 'pad_to_size', 'logprobs_from_logits', 4 | 'whiten', 'clip_by_value', 'entropy_from_logits', 'average_torch_dicts', 'stats_to_np', 'listify_batch', 5 | 'build_bert_batch_from_txt'] 6 | 7 | # Cell 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.nn.utils.rnn import pad_sequence 11 | 12 | import collections 13 | import numpy as np 14 | 15 | # Cell 16 | WANDB_PADDING = -1 17 | 18 | # Cell 19 | 20 | def flatten_dict(nested, sep='/'): 21 | """Flatten dictionary and concatenate nested keys with separator.""" 22 | def rec(nest, prefix, into): 23 | for k, v in nest.items(): 24 | if sep in k: 25 | raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'") 26 | if isinstance(v, collections.Mapping): 27 | rec(v, prefix + k + sep, into) 28 | else: 29 | into[prefix + k] = v 30 | flat = {} 31 | rec(nested, '', flat) 32 | return flat 33 | 34 | def stack_dicts(stats_dicts): 35 | """Stack the values of a dict.""" 36 | results = dict() 37 | for k in stats_dicts[0]: 38 | stats_list = [torch.flatten(d[k]) for d in stats_dicts] 39 | results[k] = pad_sequence(stats_list, batch_first=True, padding_value=WANDB_PADDING) 40 | return results 41 | 42 | def add_suffix(input_dict, suffix): 43 | """Add suffix to dict keys.""" 44 | return dict((k + suffix, v) for k,v in input_dict.items()) 45 | 46 | # Cell 47 | 48 | def pad_to_size(tensor, size, dim=1, padding=50256): 49 | """Pad tensor to size.""" 50 | t_size = tensor.size()[dim] 51 | if t_size==size: 52 | return tensor 53 | else: 54 | return torch.nn.functional.pad(tensor, (0,size-t_size), 'constant', padding) 55 | 56 | def logprobs_from_logits(logits, labels): 57 | """ 58 | See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591 59 | """ 60 | logp = F.log_softmax(logits, dim=2) 61 | logpy = torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1) 62 | return logpy 63 | 64 | 65 | def whiten(values, shift_mean=True): 66 | """Whiten values.""" 67 | mean, var = torch.mean(values), torch.var(values) 68 | whitened = (values - mean) * torch.rsqrt(var + 1e-8) 69 | if not shift_mean: 70 | whitened += mean 71 | return whitened 72 | 73 | def clip_by_value(x, tensor_min, tensor_max): 74 | """ 75 | Tensor extenstion to torch.clamp 76 | https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713 77 | """ 78 | clipped = torch.max(torch.min(x, tensor_max), tensor_min) 79 | return clipped 80 | 81 | def entropy_from_logits(logits): 82 | """Calculate entropy from logits.""" 83 | pd = torch.nn.functional.softmax(logits, dim=-1) 84 | entropy = torch.logsumexp(logits, axis=-1) - torch.sum(pd*logits, axis=-1) 85 | return entropy 86 | 87 | 88 | def average_torch_dicts(list_of_dicts): 89 | """Average values of a list of dicts wiht torch tensors.""" 90 | average_dict = dict() 91 | for key in list_of_dicts[0].keys(): 92 | average_dict[key] = torch.mean(torch.stack([d[key] for d in list_of_dicts]), axis=0) 93 | return average_dict 94 | 95 | def stats_to_np(stats_dict): 96 | """Cast all torch.tensors in dict to numpy arrays.""" 97 | new_dict = dict() 98 
| for k, v in stats_dict.items(): 99 | if isinstance(v, torch.Tensor): 100 | new_dict[k] = v.detach().cpu().numpy() 101 | else: 102 | new_dict[k] = v 103 | if np.isscalar(new_dict[k]): 104 | new_dict[k] = float(new_dict[k]) 105 | return new_dict 106 | 107 | def listify_batch(tensor): 108 | """Turns the first dimension of a tensor into a list.""" 109 | return [tensor[i] for i in range(tensor.shape[0])] 110 | 111 | # Cell 112 | 113 | def build_bert_batch_from_txt(text_list, tokenizer, device): 114 | """Create token id and attention mask tensors from text list for BERT classification.""" 115 | 116 | # tokenize 117 | tensors = [tokenizer.encode(txt, return_tensors="pt").to(device) for txt in text_list] 118 | 119 | # find max length to pad to 120 | max_len = max([t.size()[1] for t in tensors]) 121 | 122 | # get padded tensors and attention masks 123 | # (attention masks make bert ignore padding) 124 | padded_tensors = [] 125 | attention_masks = [] 126 | for tensor in tensors: 127 | attention_mask = torch.ones(tensor.size(), device=device) 128 | padded_tensors.append(pad_to_size(tensor, max_len, padding=0)) 129 | attention_masks.append(pad_to_size(attention_mask, max_len, padding=0)) 130 | 131 | # stack all tensors 132 | padded_tensors = torch.cat(padded_tensors) 133 | attention_masks = torch.cat(attention_masks) 134 | 135 | return padded_tensors, attention_masks 136 | -------------------------------------------------------------------------------- /src/trl/gpt2.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: nbs/01-gpt2-with-value-head.ipynb (unless otherwise specified). 2 | 3 | __all__ = ['CausalLMOutputWithCrossAttentions', 'ValueHead', 'GPT2HeadWithValueModel', 'respond_to_batch'] 4 | 5 | # Cell 6 | 7 | from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Model, GPT2PreTrainedModel 8 | from transformers import top_k_top_p_filtering 9 | from transformers.modeling_outputs import ModelOutput 10 | from torch import nn 11 | from torch.nn import Identity 12 | import torch.nn.functional as F 13 | import torch 14 | from dataclasses import dataclass 15 | from typing import Optional, Tuple 16 | 17 | # Cell 18 | @dataclass 19 | class CausalLMOutputWithCrossAttentions(ModelOutput): 20 | loss: Optional[torch.FloatTensor] = None 21 | logits: torch.FloatTensor = None 22 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None 23 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 24 | attentions: Optional[Tuple[torch.FloatTensor]] = None 25 | cross_attentions: Optional[Tuple[torch.FloatTensor]] = None 26 | value: Optional[torch.FloatTensor] = None 27 | 28 | # Cell 29 | 30 | class ValueHead(nn.Module): 31 | """The ValueHead class implements a head for GPT2 that returns a scalar for each output token.""" 32 | def __init__(self, config): 33 | super().__init__() 34 | self.detach_head = False 35 | self.summary_type = config.summary_type if hasattr(config, "summary_type") else "last" 36 | if self.summary_type == "attn": 37 | raise NotImplementedError 38 | 39 | self.summary = Identity() 40 | if hasattr(config, "summary_use_proj") and config.summary_use_proj: 41 | if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0: 42 | num_classes = config.num_labels 43 | else: 44 | num_classes = config.hidden_size 45 | self.summary = nn.Linear(config.hidden_size, num_classes) 46 | 47 | self.activation = Identity() 48 | if hasattr(config, "summary_activation") and 
config.summary_activation == "tanh": 49 | self.activation = nn.Tanh() 50 | 51 | self.first_dropout = Identity() 52 | if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0: 53 | self.first_dropout = nn.Dropout(config.summary_first_dropout) 54 | 55 | self.last_dropout = Identity() 56 | if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0: 57 | self.last_dropout = nn.Dropout(config.summary_last_dropout) 58 | 59 | self.flatten = nn.Flatten() 60 | 61 | def forward(self, hidden_states, cls_index=None): 62 | if self.detach_head: 63 | output = hidden_states.detach() 64 | else: 65 | output = hidden_states 66 | output = self.first_dropout(output) 67 | output = self.summary(output) 68 | output = self.activation(output) 69 | output = self.last_dropout(output) 70 | 71 | return output 72 | 73 | # Cell 74 | 75 | class GPT2HeadWithValueModel(GPT2LMHeadModel): 76 | """The GPT2HeadWithValueModel class implements a GPT2 language model with a secondary, scalar head.""" 77 | def __init__(self, config): 78 | super().__init__(config) 79 | config.num_labels = 1 80 | self.transformer = GPT2Model(config) 81 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 82 | self.v_head = ValueHead(config) 83 | 84 | self.init_weights() 85 | 86 | def get_output_embeddings(self): 87 | return self.lm_head 88 | 89 | def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): 90 | # only last token for inputs_ids if past is defined in kwargs 91 | if past: 92 | input_ids = input_ids[:, -1].unsqueeze(-1) 93 | 94 | attention_mask = kwargs.get("attention_mask", None) 95 | position_ids = kwargs.get("position_ids", None) 96 | 97 | if attention_mask is not None and position_ids is None: 98 | # create position_ids on the fly for batch generation 99 | position_ids = attention_mask.long().cumsum(-1) - 1 100 | position_ids.masked_fill_(attention_mask == 0, 1) 101 | if past: 102 | position_ids = position_ids[:, -1].unsqueeze(-1) 103 | else: 104 | position_ids = None 105 | return { 106 | "input_ids": input_ids, 107 | "past_key_values": past, 108 | "use_cache": kwargs.get("use_cache"), 109 | "position_ids": position_ids, 110 | "attention_mask": attention_mask, 111 | } 112 | 113 | def detach_value_head(self): 114 | self.v_head.detach_head = True 115 | 116 | def forward( 117 | self, 118 | input_ids=None, 119 | past_key_values=None, 120 | attention_mask=None, 121 | token_type_ids=None, 122 | position_ids=None, 123 | head_mask=None, 124 | inputs_embeds=None, 125 | mc_token_ids=None, 126 | lm_labels=None, 127 | mc_labels=None, 128 | use_cache=None, 129 | return_dict=False, 130 | output_attentions=False, 131 | output_hidden_states=False, 132 | ): 133 | loss=None 134 | transformer_outputs = self.transformer( 135 | input_ids, 136 | past_key_values=past_key_values, 137 | attention_mask=attention_mask, 138 | token_type_ids=token_type_ids, 139 | position_ids=position_ids, 140 | head_mask=head_mask, 141 | use_cache=None, 142 | inputs_embeds=inputs_embeds, 143 | ) 144 | 145 | hidden_states = transformer_outputs[0] 146 | 147 | lm_logits = self.lm_head(hidden_states) 148 | value = self.v_head(hidden_states).squeeze(-1) 149 | 150 | if not return_dict: 151 | outputs = (lm_logits,) + transformer_outputs[1:] + (value,) 152 | return outputs 153 | 154 | return CausalLMOutputWithCrossAttentions( 155 | loss=loss, 156 | logits=lm_logits, 157 | past_key_values=transformer_outputs.past_key_values, 158 | hidden_states=transformer_outputs.hidden_states, 159 | 
attentions=transformer_outputs.attentions, 160 | cross_attentions=transformer_outputs.cross_attentions, 161 | value=value, 162 | ) 163 | return outputs 164 | 165 | ## CriticControl 166 | def sentiment_generation(model, queries, txt_len=25, top_vocab=10, top_p=0.9, no_repeat_ngram=4): 167 | """Sample text from language model.""" 168 | input_ids = queries 169 | ngram_list = dict() 170 | next_token_id = 0 171 | for i in range(txt_len): 172 | # Get Logits 173 | outputs = model(input_ids) 174 | next_token_logits = outputs[0][:, -1, :] 175 | probs = F.softmax(next_token_logits, dim=-1) 176 | V_value = outputs[2].unsqueeze(-1)[:, -1, :] 177 | # Sample 178 | _, candidate_tokens = torch.topk(probs, top_vocab, dim=-1) 179 | for _, Q_token in enumerate(candidate_tokens[0]): 180 | Q_value = model(torch.cat([input_ids, Q_token.view([1,1])], dim=-1))[2].unsqueeze(-1)[:, -1, :] 181 | probs[0][Q_token.item()] = probs[0][Q_token.item()] * (torch.nn.Sigmoid()(Q_value) / torch.nn.Sigmoid()(V_value)) 182 | if tuple(input_ids[0][-no_repeat_ngram+1:].tolist()) in ngram_list.keys(): 183 | banned_token_list = ngram_list[tuple(input_ids[0][-no_repeat_ngram+1:].tolist())] 184 | for _, banned_token in enumerate(banned_token_list): 185 | probs[0][banned_token] = 0 186 | probs = F.softmax(probs, dim=-1) 187 | probs = top_k_top_p_filtering(probs, top_p=top_p) 188 | probs = F.softmax(probs, dim=-1) 189 | next_token = torch.multinomial(probs, num_samples=1).squeeze(1) 190 | input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1) 191 | if tuple(input_ids[0][-no_repeat_ngram:-1].tolist()) in ngram_list.keys(): 192 | ngram_list[tuple(input_ids[0][-no_repeat_ngram:-1].tolist())].append(next_token.item()) 193 | else: 194 | ngram_list[tuple(input_ids[0][-no_repeat_ngram:-1].tolist())] = [next_token.item()] 195 | return input_ids 196 | 197 | ## CriticControl 198 | def topic_generation(model, queries, txt_len=80, top_k=10, top_p=1.0, no_repeat_ngram=4): 199 | """Sample text from language model.""" 200 | input_ids = queries 201 | ngram_list = dict() 202 | next_token_id = 0 203 | for i in range(txt_len): 204 | # Get Logits 205 | outputs = model(input_ids) 206 | next_token_logits = outputs[0][:, -1, :] 207 | V_value = outputs[2].unsqueeze(-1)[:, -1, :] 208 | probs = F.softmax(next_token_logits, dim=-1) 209 | _, candidate_tokens = torch.topk(probs, top_k, dim=-1) 210 | # Distribution Shift! 
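# (each top-k candidate's probability is rescaled by sigmoid(Q(s,a)) / sigmoid(V(s)), both taken from the critic's value head)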
211 | for _, Q_token in enumerate(candidate_tokens[0]): 212 | Q_value = model(torch.cat([input_ids, Q_token.view([1,1])], dim=-1))[2].unsqueeze(-1)[:, -1, :] 213 | probs[0][Q_token.item()] = probs[0][Q_token.item()] * (torch.nn.Sigmoid()(Q_value) / torch.nn.Sigmoid()(V_value))**(1/1) 214 | if tuple(input_ids[0][-no_repeat_ngram+1:].tolist()) in ngram_list.keys(): 215 | banned_token_list = ngram_list[tuple(input_ids[0][-no_repeat_ngram+1:].tolist())] 216 | for _, banned_token in enumerate(banned_token_list): 217 | probs[0][banned_token] = 0 218 | _, next_token = torch.topk(probs, 1, dim=-1) 219 | input_ids = torch.cat([input_ids, next_token], dim=-1) 220 | if tuple(input_ids[0][-no_repeat_ngram:-1].tolist()) in ngram_list.keys(): 221 | ngram_list[tuple(input_ids[0][-no_repeat_ngram:-1].tolist())].append(next_token.item()) 222 | else: 223 | ngram_list[tuple(input_ids[0][-no_repeat_ngram:-1].tolist())] = [next_token.item()] 224 | return input_ids -------------------------------------------------------------------------------- /src/trl/ppo.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: nbs/02-ppo.ipynb (unless otherwise specified). 2 | 3 | __all__ = ['AdaptiveKLController', 'FixedKLController', 'PPOTrainer'] 4 | 5 | # Cell 6 | import numpy as np 7 | import torch.nn.functional as F 8 | from torch.optim import Adam 9 | import torch 10 | import collections 11 | import time 12 | import random 13 | 14 | from transformers import DataCollatorForLanguageModeling 15 | 16 | from .core import (logprobs_from_logits, 17 | whiten, 18 | clip_by_value, 19 | entropy_from_logits, 20 | flatten_dict, 21 | average_torch_dicts, 22 | stats_to_np, 23 | stack_dicts, 24 | add_suffix, 25 | WANDB_PADDING) 26 | 27 | # Cell 28 | 29 | class AdaptiveKLController: 30 | """ 31 | Adaptive KL controller described in the paper: 32 | https://arxiv.org/pdf/1909.08593.pdf 33 | """ 34 | def __init__(self, init_kl_coef, target, horizon): 35 | self.value = init_kl_coef 36 | self.target = target 37 | self.horizon = horizon 38 | 39 | def update(self, current, n_steps): 40 | target = self.target 41 | proportional_error = np.clip(current / target - 1, -0.2, 0.2) 42 | mult = 1 + proportional_error * n_steps / self.horizon 43 | self.value *= mult 44 | 45 | # Cell 46 | 47 | class FixedKLController: 48 | """Fixed KL controller.""" 49 | def __init__(self, kl_coef): 50 | self.value = kl_coef 51 | 52 | def update(self, current, n_steps): 53 | pass 54 | 55 | # Cell 56 | 57 | def ppo_initialize(self, topic_link, prefix_link): 58 | topic_label = open(topic_link,'r').read().split('\n') 59 | prefix_label = open(prefix_link,'r').read().split('\n') 60 | prompt = [] 61 | for i in range(140): 62 | topic, prefix = divmod(i, 20) 63 | text = topic_label[topic]+': '+ prefix_label[prefix] 64 | prompt.append([topic_label[topic], text]) 65 | return prompt 66 | 67 | 68 | class PPOTrainer: 69 | """ 70 | The PPO_trainer uses Proximal Policy Optimization to optimise language models. 71 | """ 72 | 73 | default_params = { 74 | "lr": 1.41e-5, 75 | "adap_kl_ctrl": True, 76 | "init_kl_coef":0.2, 77 | "target": 6, 78 | "horizon":10000, 79 | "gamma":1, 80 | "lam":0.95, 81 | "cliprange": .2, 82 | "cliprange_value":.2, 83 | "vf_coef":.1, 84 | "batch_size": 256, 85 | "forward_batch_size": 16, 86 | "ppo_epochs": 4, 87 | } 88 | 89 | def __init__(self, model, ref_model, tokenizer, **ppo_params): 90 | """ 91 | Initialize PPOTrainer. 
92 | 93 | Args: 94 | model (torch.model): Hugging Face transformer GPT2 model with value head 95 | ref_model (torch.model): Hugging Face transformer GPT2 refrence model used for KL penalty 96 | tokenizer (tokenizer): Hugging Face tokenizer 97 | ppo_params (dict or None): PPO parameters for training. Can include following keys: 98 | 'lr' (float): Adam learning rate, default: 1.41e-5 99 | 'batch_size' (int): Number of samples per optimisation step, default: 256 100 | 'forward_batch_size' (int): Number of samples forward passed through model at a time, default: 16 101 | 'ppo_epochs' (int): Number of optimisation epochs per batch of samples, default: 4 102 | 'gamma' (float)): Gamma parameter for advantage calculation, default: 1. 103 | 'lam' (float): Lambda parameter for advantage calcualation, default: 0.95 104 | 'cliprange_value' (float): Range for clipping values in loss calculation, default: 0.2 105 | 'cliprange' (float): Range for clipping in PPO policy gradient loss, default: 0.2 106 | 'vf_coef' (float): Scaling factor for value loss, default: 0.1 107 | 'adap_kl_ctrl' (bool): Use adaptive KL control, otherwise linear, default: True 108 | 'init_kl_coef' (float): Initial KL penalty coefficient (used for adaptive and linear control), default: 0.2 109 | 'target' (float): Target KL value for adaptive KL control, default: 6.0 110 | 'horizon' (float): Horizon for adaptive KL control, default: 10000 111 | 112 | """ 113 | self.ppo_params = self.default_params 114 | self.ppo_params.update(ppo_params) 115 | 116 | self.ref_model = ref_model 117 | self.model = model 118 | self.tokenizer = tokenizer 119 | self.data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False) 120 | 121 | self.optimizer = Adam(model.parameters(), lr=self.ppo_params['lr']) 122 | 123 | if self.ppo_params['adap_kl_ctrl']: 124 | self.kl_ctl = AdaptiveKLController(self.ppo_params['init_kl_coef'], 125 | self.ppo_params['target'], 126 | self.ppo_params['horizon']) 127 | else: 128 | self.kl_ctl = FixedKLController(self.ppo_params['init_kl_coef']) 129 | 130 | 131 | def step(self, queries, responses, scores): 132 | """ 133 | Run a PPO optimisation step. 
134 | 135 | args: 136 | queries (List): List of tensors containing the encoded queries, shape [query_length] 137 | responses (List): List of tensors containing the encoded responses, shape [response_length] 138 | scores (List): tensor containing the scores, shape [batch_size] 139 | 140 | returns: 141 | train_stats (dict): a summary of the training statistics 142 | """ 143 | 144 | bs = self.ppo_params['batch_size'] 145 | assert bs == len(queries), f"Batch size ({bs}) does not match number of examples ({len(queries)})" 146 | 147 | timing = dict() 148 | t0 = time.time() 149 | 150 | response_lengths = [len(r) for r in responses] 151 | 152 | t = time.time() 153 | logprobs, ref_logprobs, values = self.batched_forward_pass(queries, responses) 154 | timing['time/ppo/forward_pass'] = time.time()-t 155 | 156 | t = time.time() 157 | rewards, non_score_reward = self.compute_rewards(scores, logprobs, ref_logprobs) 158 | timing['time/ppo/compute_rewards'] = time.time()-t 159 | 160 | t = time.time() 161 | all_stats = [] 162 | idxs = list(range(bs)) 163 | for _ in range(self.ppo_params['ppo_epochs']): 164 | random.shuffle(idxs) 165 | for i in range(bs): 166 | idx = idxs[i] 167 | train_stats = self.train_minibatch(logprobs[idx].unsqueeze(0), values[idx].unsqueeze(0), 168 | rewards[idx].unsqueeze(0), queries[idx].unsqueeze(0), 169 | responses[idx].unsqueeze(0), 170 | torch.cat([queries[idx],responses[idx]]).unsqueeze(0)) 171 | all_stats.append(train_stats) 172 | timing['time/ppo/optimize_step'] = time.time()-t 173 | 174 | t = time.time() 175 | train_stats = stack_dicts(all_stats) 176 | 177 | # reshape advantages/ratios such that they are not averaged. 178 | train_stats['policy/advantages'] = torch.flatten(train_stats['policy/advantages']).unsqueeze(0) 179 | train_stats['policy/advantages'] = torch.nan_to_num(train_stats['policy/advantages'], WANDB_PADDING) 180 | train_stats['policy/ratio'] = torch.flatten(train_stats['policy/ratio']).unsqueeze(0) 181 | 182 | stats = self.record_step_stats(scores=scores, logprobs=logprobs, ref_logprobs=ref_logprobs, 183 | non_score_reward=non_score_reward, train_stats=train_stats, 184 | kl_coef=self.kl_ctl.value) 185 | stats = stats_to_np(stats) 186 | timing['time/ppo/calc_stats'] = time.time()-t 187 | 188 | self.kl_ctl.update(stats['objective/kl'], self.ppo_params['batch_size']) 189 | 190 | timing['time/ppo/total'] = time.time()-t0 191 | stats.update(timing) 192 | return stats 193 | 194 | def batched_forward_pass(self, queries, responses): 195 | """Calculate model outputs in multiple batches.""" 196 | bs = self.ppo_params['batch_size'] 197 | fbs = self.ppo_params['forward_batch_size'] 198 | all_logprobs = [] 199 | all_ref_logprobs = [] 200 | all_values = [] 201 | 202 | for i in range(int(bs/fbs)): 203 | query_batch = queries[i*fbs:(i+1)*fbs] 204 | response_batch = responses[i*fbs:(i+1)*fbs] 205 | input_ids = self.data_collator([torch.cat([q, r]) for q, r in zip(query_batch, response_batch)])["input_ids"] 206 | with torch.no_grad(): 207 | logits, _, v = self.model(input_ids) 208 | ref_logits, _, _ = self.ref_model(input_ids) 209 | logprobs = logprobs_from_logits(logits[:,:-1,:], input_ids[:,1:]) 210 | ref_logprobs = logprobs_from_logits(ref_logits[:,:-1,:], input_ids[:,1:]) 211 | for j in range(fbs): 212 | start = len(query_batch[j])-1 213 | end = len(query_batch[j]) + len(response_batch[j])-1 214 | all_values.append(v[j, start-1:end-1]) 215 | all_logprobs.append(logprobs[j, start:end]) 216 | all_ref_logprobs.append(ref_logprobs[j, start:end]) 217 | return all_logprobs, 
all_ref_logprobs, all_values 218 | 219 | def train_minibatch(self, logprobs, values, rewards, query, response, model_input): 220 | """Train one PPO minibatch""" 221 | loss_p, loss_v, train_stats = self.loss(logprobs, values, rewards, query, response, model_input) 222 | loss = loss_p + loss_v 223 | self.optimizer.zero_grad() 224 | loss.backward() 225 | self.optimizer.step() 226 | return train_stats 227 | 228 | def compute_rewards(self, scores, logprobs, ref_logprobs): 229 | """Compute per token rewards from scores and KL-penalty.""" 230 | rewards, non_score_rewards = [], [] 231 | for score, logprob, ref_logprob in zip(scores, logprobs, ref_logprobs): 232 | kl = logprob - ref_logprob 233 | non_score_reward = -self.kl_ctl.value * kl 234 | non_score_rewards.append(non_score_reward) 235 | reward = non_score_reward.clone() 236 | reward[-1] += score 237 | rewards.append(reward) 238 | return rewards, non_score_rewards 239 | 240 | def loss(self, old_logprobs, values, rewards, query, response, model_input): 241 | """Calculate policy and value losses.""" 242 | lastgaelam = 0 243 | advantages_reversed = [] 244 | gen_len = response.shape[1] 245 | 246 | for t in reversed(range(gen_len)): 247 | nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0 248 | delta = rewards[:, t] + self.ppo_params['gamma'] * nextvalues - values[:, t] 249 | lastgaelam = delta + self.ppo_params['gamma'] * self.ppo_params['lam'] * lastgaelam 250 | advantages_reversed.append(lastgaelam) 251 | advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1) 252 | 253 | returns = advantages + values 254 | advantages = whiten(advantages) 255 | advantages = advantages.detach() 256 | 257 | logits, _, vpred = self.model(model_input) 258 | logprob = logprobs_from_logits(logits[:,:-1,:], model_input[:, 1:]) 259 | 260 | #only the generation part of the values/logprobs is needed 261 | logprob, vpred = logprob[:, -gen_len:], vpred[:,-gen_len-1:-1] 262 | 263 | vpredclipped = clip_by_value(vpred, 264 | values - self.ppo_params["cliprange_value"], 265 | values + self.ppo_params["cliprange_value"]) 266 | 267 | vf_losses1 = (vpred - returns)**2 268 | vf_losses2 = (vpredclipped - returns)**2 269 | vf_loss = .5 * torch.mean(torch.max(vf_losses1, vf_losses2)) 270 | vf_clipfrac = torch.mean(torch.gt(vf_losses2, vf_losses1).double()) 271 | 272 | ratio = torch.exp(logprob - old_logprobs) 273 | 274 | pg_losses = -advantages * ratio 275 | pg_losses2 = -advantages * torch.clamp(ratio, 276 | 1.0 - self.ppo_params['cliprange'], 277 | 1.0 + self.ppo_params['cliprange']) 278 | 279 | pg_loss = torch.mean(torch.max(pg_losses, pg_losses2)) 280 | pg_clipfrac = torch.mean(torch.gt(pg_losses2, pg_losses).double()) 281 | 282 | loss = pg_loss + self.ppo_params['vf_coef'] * vf_loss 283 | 284 | entropy = torch.mean(entropy_from_logits(logits)) 285 | approxkl = .5 * torch.mean((logprob - old_logprobs)**2) 286 | policykl = torch.mean(logprob - old_logprobs) 287 | return_mean, return_var = torch.mean(returns), torch.var(returns) 288 | value_mean, value_var = torch.mean(values), torch.var(values) 289 | 290 | stats = dict( 291 | loss=dict(policy=pg_loss, value=vf_loss, total=loss), 292 | policy=dict(entropy=entropy, approxkl=approxkl,policykl=policykl, clipfrac=pg_clipfrac, 293 | advantages=advantages, advantages_mean=torch.mean(advantages), ratio=ratio), 294 | returns=dict(mean=return_mean, var=return_var), 295 | val=dict(vpred=torch.mean(vpred), error=torch.mean((vpred - returns) ** 2), 296 | clipfrac=vf_clipfrac, mean=value_mean, var=value_var), 297 | ) 298 | 
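# the policy loss and the scaled value loss are returned separately; train_minibatch() sums them into the total objective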
return pg_loss, self.ppo_params['vf_coef'] * vf_loss, flatten_dict(stats) 299 | 300 | 301 | def record_step_stats(self, kl_coef, **data): 302 | """Record training step statistics.""" 303 | kl_list = [logprobs-ref_logprobs for logprobs, ref_logprobs in zip(data['logprobs'], data['ref_logprobs'])] 304 | mean_kl = torch.mean(torch.stack([torch.sum(kl) for kl in kl_list])) 305 | mean_entropy = torch.mean(torch.stack([torch.sum(-log_probs) for log_probs in data['logprobs']])) 306 | mean_non_score_reward =torch.mean(torch.stack([torch.sum(non_score_reward) for non_score_reward in data['non_score_reward']])) 307 | stats = { 308 | 'objective/kl': mean_kl, 309 | 'objective/kl_dist': kl_list, 310 | 'objective/logprobs': data['logprobs'], 311 | 'objective/ref_logprobs': data['ref_logprobs'], 312 | 'objective/kl_coef': kl_coef, 313 | 'objective/entropy': mean_entropy, 314 | 'ppo/mean_non_score_reward': mean_non_score_reward, 315 | } 316 | 317 | for k, v in data['train_stats'].items(): 318 | stats[f'ppo/{k}'] = torch.mean(v, axis=0) 319 | stats['ppo/val/var_explained'] = 1 - stats['ppo/val/error'] / stats['ppo/returns/var'] 320 | return stats 321 | --------------------------------------------------------------------------------
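Note: the decoding rule shared by `sentiment_generation` and `topic_generation` in src/trl/gpt2.py can be summarized with the minimal sketch below. It assumes a batch of one prompt and a model that behaves like `GPT2HeadWithValueModel` above (outputs[0] holds the LM logits, outputs[2] the value-head outputs); the helper name `critic_reweight_step` is illustrative and not part of the repository, and the sketch omits the repository's top-p filtering and no-repeat-ngram blocking.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def critic_reweight_step(model, input_ids, top_k=10):
    """One illustrative CriticControl decoding step (batch of one prompt)."""
    outputs = model(input_ids)
    probs = F.softmax(outputs[0][:, -1, :], dim=-1)   # frozen LM proposal p(a|s)
    v_state = outputs[2][0, -1]                       # V(s) from the critic's value head
    _, candidates = torch.topk(probs, top_k, dim=-1)
    for token in candidates[0]:
        extended = torch.cat([input_ids, token.view(1, 1)], dim=-1)
        q_value = model(extended)[2][0, -1]           # Q(s, a): value after appending the candidate
        # distribution shift: p'(a|s) is proportional to p(a|s) * sigmoid(Q(s,a)) / sigmoid(V(s))
        probs[0, token] = probs[0, token] * torch.sigmoid(q_value) / torch.sigmoid(v_state)
    probs = probs / probs.sum(dim=-1, keepdim=True)   # renormalise over the vocabulary
    next_token = torch.multinomial(probs, num_samples=1)
    return torch.cat([input_ids, next_token], dim=-1)
```

Repeating this step for a fixed number of tokens reproduces the control loop: the language model stays frozen, and only the critic's value head shifts the sampling distribution toward the target attribute.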