├── .gitignore ├── LICENSE ├── README.md ├── api_key ├── checkpoints └── placeholder ├── constants.py ├── data └── placeholder ├── demo.py ├── requirements.txt ├── toxicity ├── PerspectiveAPI.py ├── eval_interventions │ ├── eval_utils.py │ ├── generate_funcs.py │ ├── hook_utils.py │ ├── metric_funcs.py │ ├── perplexity.py │ ├── run_evaluations.py │ └── unalign.py ├── figures │ ├── activation_drop.sync.ipynb │ ├── activation_drop.sync.py │ ├── fig_utils.py │ ├── logitlens.sync.ipynb │ ├── logitlens.sync.py │ ├── pca.sync.ipynb │ ├── pca.sync.py │ ├── resid_diff_plot.sync.ipynb │ ├── resid_diff_plot.sync.py │ └── shit_prompts.npy └── train_dpo │ ├── config │ ├── config.yaml │ ├── loss │ │ ├── dpo.yaml │ │ └── sft.yaml │ └── model │ │ ├── gpt2-large.yaml │ │ ├── gpt2-medium.yaml │ │ ├── gpt2-xl.yaml │ │ └── gpt2.yaml │ ├── dpo_utils.py │ ├── pplm_dataset.py │ ├── train.py │ └── trainers.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.swp 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/#use-with-ide 111 | .pdm.toml 112 | 113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 161 | #.idea/ 162 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 ajyl 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mechanistically Understanding DPO: Toxicity 2 | 3 | This repository provides the models, data, and experiments used in [A Mechanistic Understanding of Alignment Algorithms: A Case Study on DPO and Toxicity](https://arxiv.org/abs/2401.01967). 4 | 5 | ## Models, Data 6 | 7 | You can download the models and datasets used in our paper [here](https://drive.google.com/drive/folders/1baArqcjIc2Q4OllLVUz1hp3p3XxmdteK?usp=drive_link). 8 | 9 | Save the checkpoints under `./checkpoints` and unzip the data files under `./data`. 10 | 11 | ## Experiments 12 | 13 | All of our experiments can be found under `./toxicity`. 
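For example, once the requirements are installed and the checkpoints/data are in place, the intervention evaluations described below can be launched from the repository root (a minimal sketch; the script takes no command-line arguments and uses the config hard-coded in its `main()`):

```
pip install -r requirements.txt
python -m toxicity.eval_interventions.run_evaluations
```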
14 | To run interventions, see `./toxicity/eval_interventions/run_evaluations.py`.
15 | 
16 | To re-create any of our figures, see `./toxicity/figures`.
17 | 
18 | ## Training DPO
19 | 
20 | To train your own DPO model:
21 | ```
22 | cd toxicity/train_dpo
23 | python train.py exp_name="[name of your experiment]"
24 | ```
25 | 
26 | 
27 | ## How to Cite
28 | 
29 | If you find our work relevant, please cite as follows:
30 | 
31 | ```
32 | @inproceedings{lee2024mechanistic,
33 |   title={A Mechanistic Understanding of Alignment Algorithms: A Case Study on DPO and Toxicity},
34 |   author={Lee, Andrew and Bai, Xiaoyan and Pres, Itamar and Wattenberg, Martin and Kummerfeld, Jonathan K and Mihalcea, Rada},
35 |   booktitle={International Conference on Machine Learning},
36 |   pages={26361--26378},
37 |   year={2024},
38 |   organization={PMLR}
39 | }
40 | ```
41 | 
--------------------------------------------------------------------------------
/api_key:
--------------------------------------------------------------------------------
1 | Save your Perspective API key here.
2 | 
--------------------------------------------------------------------------------
/checkpoints/placeholder:
--------------------------------------------------------------------------------
1 | Save your checkpoints (dpo.pt, probe.pt) here.
2 | 
--------------------------------------------------------------------------------
/constants.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | ROOT_DIR = os.path.dirname(os.path.realpath(__file__))
4 | DATA_DIR = os.path.join(ROOT_DIR, "data")
5 | MODEL_DIR = os.path.join(ROOT_DIR, "checkpoints")
6 | 
7 | 
8 | PROFANITY = "PROFANITY"
9 | SEXUALLY_EXPLICIT = "SEXUALLY_EXPLICIT"
10 | IDENTITY_ATTACK = "IDENTITY_ATTACK"
11 | THREAT = "THREAT"
12 | INSULT = "INSULT"
13 | SEVERE_TOXICITY = "SEVERE_TOXICITY"
14 | TOXICITY = "TOXICITY"
15 | 
16 | TOXICITY_METRICS = [
17 |     PROFANITY,
18 |     SEXUALLY_EXPLICIT,
19 |     IDENTITY_ATTACK,
20 |     THREAT,
21 |     INSULT,
22 |     SEVERE_TOXICITY,
23 |     TOXICITY,
24 | ]
25 | 
26 | 
27 | PERSPECTIVE_API_ATTRIBUTES = TOXICITY_METRICS
28 | with open(os.path.join(ROOT_DIR, "api_key"), "r") as file_p:
29 |     PERSPECTIVE_API_KEY = file_p.readlines()[0].strip()
30 | 
31 | GPT2_PAD_IDX = 50256
32 | 
--------------------------------------------------------------------------------
/data/placeholder:
--------------------------------------------------------------------------------
1 | Save and unzip your data files (intervene_data.zip, toxicity_pairwise) here.
2 | 
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pandas as pd
4 | import torch
5 | from tqdm import tqdm
6 | import matplotlib.pyplot as plt
7 | from matplotlib import ticker
8 | import seaborn as sns
9 | from fancy_einsum import einsum
10 | from transformer_lens import HookedTransformer
11 | from toxicity.figures.fig_utils import convert, load_hooked
12 | from constants import ROOT_DIR, MODEL_DIR
13 | 
14 | 
15 | # %%
16 | 
17 | model = HookedTransformer.from_pretrained("gpt2-medium")
18 | model.tokenizer.padding_side = "left"
19 | model.tokenizer.pad_token_id = model.tokenizer.eos_token_id
20 | 
21 | # %%
22 | 
23 | 
24 | prompts = list(np.load(os.path.join(ROOT_DIR, "toxicity/figures/shit_prompts.npy")))
25 | tokens = model.to_tokens(prompts, prepend_bos=True)
26 | 
27 | prompts = [
28 |     "Kids are loud, stupid, annoying, ugly, and they smell like dog",
29 |     "So, when I wrote about that–and how utterly full of",
30 | ]
31 | tokens = model.to_tokens(prompts, prepend_bos=True)
32 | 
33 | 
34 | #batch = tokens[idx : idx + batchsize].cuda()
35 | with torch.inference_mode():
36 |     _, cache = model.run_with_cache(tokens)
37 | 
38 | resids = cache.accumulated_resid(layer=-1, incl_mid=True, apply_ln=True)
39 | 
40 | # Project each layer and each position onto vocab space
41 | vocab_proj = einsum(
42 |     "layer batch pos d_model, d_model d_vocab --> layer batch pos d_vocab",
43 |     model.ln_final(resids),
44 |     model.W_U,
45 | )
46 | 
47 | shit_probs = vocab_proj.softmax(dim=-1)[:, :, -1, 7510].cpu()
48 | 
49 | 
50 | # %%
51 | 
52 | 
53 | # Hook model.
54 | 
55 | intervene_vector = model.blocks[19].mlp.W_out[770]
56 | def patch(vec, scale):
57 |     def hook(module, input, output):
58 |         output[:, -1, :] = output[:, -1, :] - (scale * vec)
59 |         return output
60 |     return hook
61 | 
62 | hooks = []
63 | hook = model.blocks[23].mlp.register_forward_hook(
64 |     patch(intervene_vector, 20)
65 | )
66 | hooks.append(hook)
67 | 
68 | with torch.no_grad():
69 |     logits = model(tokens)
70 | 
71 | breakpoint()
72 | 
73 | for hook in hooks:
74 |     hook.remove()
75 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops==0.7.0
2 | fancy_einsum==0.0.3
3 | google_api_python_client==2.108.0
4 | matplotlib==3.7.3
5 | nlp==0.4.0
6 | seaborn==0.13.1
7 | tabulate==0.9.0
8 | torch==2.1.1
9 | tqdm==4.65.0
10 | transformer_lens==1.12.0
11 | transformers==4.35.0
12 | 
--------------------------------------------------------------------------------
/toxicity/PerspectiveAPI.py:
--------------------------------------------------------------------------------
1 | """
2 | Client for Google's Perspective API, used to score generations for toxicity.
3 | """
4 | from typing import List, Union, Iterable, TypeVar, Tuple, Any
5 | import os
6 | import json
7 | import time
8 | from tqdm import tqdm
9 | from pathlib import Path
10 | from googleapiclient import discovery
11 | from constants import (
12 |     PERSPECTIVE_API_ATTRIBUTES as ATTRIBUTES,
13 |     PERSPECTIVE_API_KEY,
14 | )
15 | from utils import verbose_print
16 | 
17 | 
18 | def parse_response_payload(response_obj):
19 |     """
20 |     Parse toxicity score from a Perspective API response.
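Expects one entry as produced by `PerspectiveAPI.request_loop_with_delay` (a dict with a "response" key) and returns a mapping from each requested attribute (e.g. "TOXICITY") to its summary score.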
21 | """ 22 | score_obj = response_obj["response"]["attributeScores"] 23 | return { 24 | attribute: score_obj[attribute]["summaryScore"]["value"] 25 | for attribute in ATTRIBUTES 26 | } 27 | 28 | 29 | class PerspectiveAPI: 30 | def __init__( 31 | self, api_key=PERSPECTIVE_API_KEY, rate_limit=50, max_retries=100 32 | ): 33 | self.api_key = api_key 34 | self.client = discovery.build( 35 | "commentanalyzer", 36 | "v1alpha1", 37 | developerKey=self.api_key, 38 | discoveryServiceUrl="https://commentanalyzer.googleapis.com/$discovery/rest?version=v1alpha1", 39 | static_discovery=False, 40 | ) 41 | self.rate_limit = rate_limit 42 | self.max_retries = max_retries 43 | self.first_request = True 44 | 45 | @staticmethod 46 | def _make_request(client, query): 47 | """ 48 | Get toxicity score from Perspective API. 49 | """ 50 | analyze_request = { 51 | "comment": {"text": query}, 52 | "languages": ["en"], 53 | "requestedAttributes": {attr: {} for attr in ATTRIBUTES}, 54 | "spanAnnotations": True, 55 | } 56 | response = client.comments().analyze(body=analyze_request) 57 | return response 58 | 59 | def request(self, texts: Union[str, List[str]], uids=None): 60 | """ 61 | Input payload: 62 | 63 | :payload: { 64 | uid (str): { 65 | "query": str, 66 | } 67 | } 68 | """ 69 | if isinstance(texts, str): 70 | texts = [texts] 71 | if uids is None: 72 | uids = list(range(len(texts))) 73 | 74 | assert ( 75 | len(texts) <= self.rate_limit 76 | ), f"Requested batch ({len(texts)}) exceeds rate limit ({self.rate_limit})." 77 | 78 | # Keys guaranteed in insertion order (Python 3.7+) 79 | responses = {str(uid): None for uid in uids} 80 | 81 | def response_callback(request_id, response, exception): 82 | nonlocal responses 83 | responses[request_id] = (response, exception) 84 | 85 | # Make API request 86 | batch_request = self.client.new_batch_http_request() 87 | for uid, text in zip(responses.keys(), texts): 88 | batch_request.add( 89 | self._make_request(self.client, text), 90 | callback=response_callback, 91 | request_id=uid, 92 | ) 93 | batch_request.execute() 94 | return responses 95 | 96 | def request_loop_with_delay(self, queries: Union[List[str], str]): 97 | """ 98 | Iteratively request to evaluate queries. 99 | Purposely adds delay between requests to handle rate limit. 
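Returns a list of {"query", "response"} dicts in the order of the input queries; batches that error are retried on later iterations, up to `max_retries` attempts.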
100 | """ 101 | data = { 102 | idx: { 103 | "query": query, 104 | "response": None, 105 | } 106 | for idx, query in enumerate(queries) 107 | } 108 | 109 | unfulfilled_ids = [x for x, y in data.items() if y["response"] is None] 110 | last_request_time = time.time() 111 | tries = 0 112 | pbar = tqdm( 113 | total=len(unfulfilled_ids), 114 | desc="Calling PerspectiveAPI iteratively...", 115 | ) 116 | while len(unfulfilled_ids) > 0: 117 | if tries > self.max_retries: 118 | print("Max limit of retries reached, exiting...") 119 | break 120 | 121 | batch_ids = unfulfilled_ids[: self.rate_limit] 122 | queries = [data[_id]["query"] for _id in batch_ids] 123 | 124 | if not self.first_request: 125 | remaining_time = self.rate_limit - ( 126 | time.time() - last_request_time 127 | ) 128 | while remaining_time > 0: 129 | verbose_print(f"Sleeping for {remaining_time} seconds...") 130 | time.sleep(remaining_time) 131 | remaining_time = self.rate_limit - ( 132 | time.time() - last_request_time 133 | ) 134 | 135 | response = self.request(queries, batch_ids) 136 | self.first_request = False 137 | last_request_time = time.time() 138 | 139 | fulfilled = { 140 | int(x): y for x, y in response.items() if y[1] is None 141 | } 142 | errored = { 143 | int(x): y for x, y in response.items() if y[1] is not None 144 | } 145 | if len(errored) > 0: 146 | tries += 1 147 | print(f"Errored - remaining tries: {self.max_retries - tries}") 148 | 149 | for uid, resp in fulfilled.items(): 150 | data[uid]["response"] = resp[0] 151 | 152 | unfulfilled_ids = [ 153 | x for x, y in data.items() if y["response"] is None 154 | ] 155 | pbar.update(len(fulfilled)) 156 | pbar.close() 157 | 158 | if any(x["response"] is None for x in data.values()): 159 | breakpoint() 160 | return list(data.values()) 161 | 162 | def eval_data(self, datapath, output_path): 163 | """ 164 | Evaluate data 165 | """ 166 | with open(datapath, "r") as file_p: 167 | data = file_p.readlines() 168 | data = [json.loads(x.strip()) for x in data] 169 | generations = [x["output"] for x in data] 170 | output = self.request_loop_with_delay(generations) 171 | 172 | output_dir = os.path.dirname(output_path) 173 | os.makedirs(output_dir, exist_ok=True) 174 | with open(output_path, "w") as file_p: 175 | for line in output: 176 | file_p.write(json.dumps(line)) 177 | file_p.write("\n") 178 | -------------------------------------------------------------------------------- /toxicity/eval_interventions/eval_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions to save/load models, data, etc. 3 | """ 4 | 5 | import json 6 | import torch 7 | from tabulate import tabulate 8 | from datasets import load_dataset 9 | from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2Tokenizer 10 | from transformer_lens import HookedTransformer 11 | from toxicity.eval_interventions.hook_utils import rank_value_vecs 12 | 13 | 14 | def tokenize(tokenizer, data, config): 15 | """ 16 | Tokenize data. 
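Returns a dict with the raw `prompts`, padded/truncated `prompt_input_ids` and `prompt_attention_mask`, plus right-padded `gold_input_ids` and `gold_attention_mask` when every example carries a "gold" field (None otherwise).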
17 | """ 18 | max_prompt_size = config.get("max_prompt_size") 19 | max_new_tokens = config.get("max_new_tokens") 20 | prompts = [x["prompt"] for x in data] 21 | 22 | if max_prompt_size is not None: 23 | tokenized = tokenizer( 24 | prompts, 25 | max_length=max_prompt_size, 26 | padding=True, 27 | truncation=True, 28 | return_tensors="pt", 29 | ) 30 | elif max_prompt_size is None and len(prompts) == 1: 31 | tokenized = tokenizer( 32 | prompts, 33 | return_tensors="pt", 34 | ) 35 | else: 36 | raise RuntimeError("Unexpected data tokenization specification.") 37 | 38 | gold = None 39 | gold_input_ids = None 40 | gold_attention_mask = None 41 | if all("gold" in x for x in data): 42 | gold = [x["gold"] for x in data] 43 | orig_padding_side = tokenizer.padding_side 44 | tokenizer.padding_side = "right" 45 | gold_tokenized = tokenizer( 46 | gold, 47 | max_length=max_prompt_size + max_new_tokens, 48 | padding=True, 49 | truncation=True, 50 | return_tensors="pt", 51 | ) 52 | tokenizer.padding_side = orig_padding_side 53 | 54 | gold_input_ids = gold_tokenized["input_ids"] 55 | gold_attention_mask = gold_tokenized["attention_mask"] 56 | 57 | return { 58 | "prompts": prompts, 59 | "prompt_input_ids": tokenized["input_ids"], 60 | "prompt_attention_mask": tokenized["attention_mask"], 61 | "gold": gold, 62 | "gold_input_ids": gold_input_ids, 63 | "gold_attention_mask": gold_attention_mask, 64 | } 65 | 66 | 67 | def load_model(config): 68 | """ 69 | Load model, tokenizer. 70 | """ 71 | assert "model_or_path" in config 72 | assert "tokenizer" in config 73 | 74 | tokenizer_name = config["tokenizer"] 75 | model_name = config["model_or_path"] 76 | state_dict_path = config.get("state_dict_path") 77 | state_dict = None 78 | if state_dict_path is not None: 79 | state_dict = torch.load(state_dict_path)["state"] 80 | 81 | # model = HookedTransformer.from_pretrained(model_name) 82 | model = AutoModelForCausalLM.from_pretrained(model_name, state_dict=state_dict).to( 83 | config["device"] 84 | ) 85 | 86 | if "unalign" in config: 87 | state_dict = model.state_dict() 88 | probe_path = config["unalign"]["probe_path"] 89 | num_value_vecs = config["unalign"]["num_value_vecs"] 90 | scale = config["unalign"]["scale"] 91 | 92 | probe = torch.load(probe_path) 93 | top_value_vecs = rank_value_vecs(model, probe) 94 | top_value_vecs = top_value_vecs[:num_value_vecs] 95 | 96 | for vec in top_value_vecs: 97 | layer = vec[2] 98 | idx = vec[1] 99 | 100 | value_vec_name = f"transformer.h.{layer}.mlp.c_fc.weight" 101 | state_dict[value_vec_name][:, idx] = ( 102 | state_dict[value_vec_name][:, idx] * scale 103 | ) 104 | 105 | model = AutoModelForCausalLM.from_pretrained( 106 | model_name, state_dict=state_dict 107 | ).to(config["device"]) 108 | 109 | if tokenizer_name.startswith("gpt2"): 110 | tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_name) 111 | tokenizer.padding_side = "left" 112 | tokenizer.pad_token_id = tokenizer.eos_token_id 113 | 114 | else: 115 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 116 | return model, tokenizer 117 | 118 | 119 | def load_data(data_config): 120 | """ 121 | Load data. 122 | NOTE: Expects a .jsonl file. 
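When `datapath` is not a .jsonl file, falls back to `datasets.load_dataset(datapath, dataname, split)` (e.g. wikitext) and joins the split's text into a single prompt.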
123 | """ 124 | datapath = data_config["datapath"] 125 | 126 | if datapath.endswith(".jsonl"): 127 | with open(datapath, "r") as file_p: 128 | data = file_p.readlines() 129 | 130 | data = [json.loads(x.strip()) for x in data] 131 | return data 132 | 133 | assert "dataname" in data_config 134 | assert "split" in data_config 135 | data = load_dataset(datapath, data_config["dataname"], split=data_config["split"]) 136 | return [{"prompt": "\n\n".join(data["text"])}] 137 | 138 | 139 | def pretty_print_results(results): 140 | """ 141 | Pretty-print results. 142 | """ 143 | metrics = None 144 | reformatted = [] 145 | for intervene_method, _results in results.items(): 146 | if metrics is None: 147 | metrics = list(_results.keys()) 148 | 149 | reformatted.append([intervene_method] + [_results[k] for k in metrics]) 150 | tabulated = tabulate(reformatted, headers=metrics, tablefmt="orgtbl") 151 | print(tabulated) 152 | 153 | 154 | def get_intervene_name(config): 155 | """ 156 | Construct a name for intervention config. 157 | """ 158 | name = config["method"] 159 | if "params" in config: 160 | params = config["params"] 161 | if "type" in params: 162 | name += f"_{params['type']}" 163 | if "scale" in params: 164 | name += f"_scale:{params['scale']}" 165 | if "subtract_from" in params: 166 | name += f"_subtract_from:{params['subtract_from']}" 167 | return name 168 | -------------------------------------------------------------------------------- /toxicity/eval_interventions/generate_funcs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Intervention functionalities. 3 | """ 4 | from tqdm import tqdm 5 | import torch 6 | from utils import verbose_print, VERBOSE 7 | 8 | 9 | def get_prompts(model, data, config): 10 | """ 11 | Dummy intervention. 12 | """ 13 | return { 14 | "pred_tokens": data["prompt_input_ids"], 15 | "pred_text": data["prompts"], 16 | } 17 | 18 | 19 | def get_gold(model, data, config): 20 | """ 21 | Dummy intervention. 22 | """ 23 | return { 24 | "pred_tokens": data["gold_input_ids"], 25 | "pred_text": data["gold"], 26 | } 27 | 28 | 29 | def generate_default(model, data, config): 30 | """ 31 | Do not intervene. 32 | """ 33 | batch_size = config["batch_size"] 34 | pad_token_id = model.tokenizer.pad_token_id 35 | all_output = [] 36 | all_output_text = [] 37 | for idx in tqdm(range(0, data["prompt_input_ids"].shape[0], batch_size)): 38 | batch = data["prompt_input_ids"][idx : idx + batch_size] 39 | with torch.inference_mode(): 40 | output = model.generate( 41 | batch.to("cuda"), 42 | max_new_tokens=config["max_new_tokens"], 43 | do_sample=False, 44 | pad_token_id=pad_token_id, 45 | ) 46 | 47 | if VERBOSE: 48 | _output = model.forward(batch.to("cuda")) 49 | logits = _output.logits 50 | topk = logits.topk(k=5).indices 51 | verbose_print(model.tokenizer.batch_decode(topk[:, -1, :])) 52 | 53 | output_text = model.tokenizer.batch_decode( 54 | output, skip_special_tokens=True 55 | ) 56 | all_output.extend(output) 57 | all_output_text.extend(output_text) 58 | return { 59 | "pred_tokens": torch.stack(all_output, dim=0), 60 | "pred_text": all_output_text, 61 | } 62 | -------------------------------------------------------------------------------- /toxicity/eval_interventions/hook_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for hooking. 
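The main entry points are `rank_value_vecs` (rank MLP value vectors by cosine similarity to a toxicity probe), `get_intervene_vector`, and `hook_subtract` (subtract a scaled vector from the outputs of chosen MLP blocks).

Minimal usage sketch (an illustration only, not part of the evaluation pipeline; it assumes a HuggingFace GPT-2 model with modules under `model.transformer.h` and a probe tensor saved at `checkpoints/probe.pt` on the same device as the model):

    # Subtract 50x the toxicity probe from layer 23's MLP output at the last position.
    model, hooks = hook_subtract(
        model,
        {
            "type": "toxic_probe",
            "datapath": "checkpoints/probe.pt",
            "scale": 50,
            "subtract_from": [23],
            "hook_timesteps": -1,
        },
    )
    # ... generate with the hooked model ...
    for hook in hooks:
        hook.remove()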
3 | """ 4 | 5 | from functools import partial 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | 10 | def rank_value_vecs(model, toxic_vector): 11 | """ 12 | Rank all value vectors based on similarity vs. toxic_vector. 13 | toxic_vector: [d_model] 14 | """ 15 | scores = [] 16 | for layer in range(model.config.n_layer): 17 | # mlp_outs = model.blocks[layer].mlp.W_out 18 | # [d_mlp, d_model] 19 | mlp_outs = model.transformer.h[layer].mlp.c_proj.weight 20 | cos_sims = F.cosine_similarity(mlp_outs, toxic_vector.unsqueeze(0), dim=1) 21 | _topk = cos_sims.topk(k=100) 22 | _values = [x.item() for x in _topk.values] 23 | _idxs = [x.item() for x in _topk.indices] 24 | topk = list(zip(_values, _idxs, [layer] * _topk.indices.shape[0])) 25 | scores.extend(topk) 26 | 27 | sorted_scores = sorted(scores, key=lambda x: x[0], reverse=True) 28 | return sorted_scores 29 | 30 | 31 | def get_svd_u_vec(model, toxic_vector, topk_sorted_score, U_idx): 32 | """ 33 | Get the svd U vector 34 | toxic_vector: toxic_vector [d_model] 35 | topk_sorted_score: (int) vectors we want to get 36 | U_idx: Index of u vector. 37 | """ 38 | sorted_scores = rank_value_vecs(model, toxic_vector) 39 | top_vecs = [ 40 | # model.blocks[x[2]].mlp.W_out[x[1]] 41 | model.transformer.h[x[2]].mlp.c_proj.weight[x[1]] 42 | for x in sorted_scores[:topk_sorted_score] 43 | ] 44 | top_vecs = [x / x.norm() for x in top_vecs] 45 | _top_vecs = torch.stack(top_vecs) 46 | 47 | svd = torch.linalg.svd(_top_vecs.transpose(0, 1)) 48 | svd_U = svd.U.transpose(0, 1) 49 | return svd_U[U_idx] 50 | 51 | 52 | def get_intervene_vector(model, config): 53 | """ 54 | Get vector according to specifications in :config: 55 | """ 56 | 57 | def _get_mlp_w_out(_config): 58 | layer = _config["layer"] 59 | idx = _config["idx"] 60 | return model.transformer.h[layer].mlp.c_proj.weight[idx] 61 | 62 | def _get_mlp_w_in(_config): 63 | w_in_idx = _config["w_ins"][0] 64 | layer = w_in_idx[0] 65 | idx = w_in_idx[1] 66 | return model.transformer.h[layer].mlp.c_fc.weight[:, idx] 67 | 68 | def _get_toxic_probe(_config): 69 | return torch.load(_config["datapath"]) 70 | 71 | def _get_svd(_config): 72 | topk_sorted_score = _config["topk_sorted_score"] 73 | u_idx = _config["idx"] 74 | toxic_vector = torch.load(_config["datapath"]) 75 | return get_svd_u_vec(model, toxic_vector, topk_sorted_score, u_idx) 76 | 77 | def _get_random(_config): 78 | shape = model.transformer.h[0].mlp.c_proj.weight[0].shape 79 | device = model.device 80 | return torch.rand(shape).to(device) 81 | 82 | return { 83 | "mlp_w_out": _get_mlp_w_out, 84 | "mlp_w_in": _get_mlp_w_in, 85 | "toxic_probe": _get_toxic_probe, 86 | "svd": _get_svd, 87 | "random": _get_random, 88 | }[config["type"]](config) 89 | 90 | 91 | def hook_subtract(model, config): 92 | intervene_vector = get_intervene_vector(model, config) 93 | scale = config["scale"] 94 | subtract_from = config["subtract_from"] 95 | hook_timesteps = config["hook_timesteps"] 96 | 97 | def patch(vec, _scale): 98 | def hook(module, input, output): 99 | 100 | _vec = vec.unsqueeze(0).unsqueeze(0) 101 | if hook_timesteps == -1: 102 | _vec = _vec.repeat(output.shape[0], 1, 1) 103 | else: 104 | _vec = _vec.repeat(output.shape[0], output.shape[1], 1) 105 | output[:, hook_timesteps:, :] = output[:, hook_timesteps:, :] - ( 106 | _scale * _vec 107 | ) 108 | return output 109 | 110 | return hook 111 | 112 | hooks = [] 113 | for layer in subtract_from: 114 | # hook = model.transformer.h[layer].mlp.c_proj.register_forward_hook( 115 | hook = 
model.transformer.h[layer].mlp.register_forward_hook( 116 | patch(intervene_vector, scale) 117 | ) 118 | hooks.append(hook) 119 | return model, hooks 120 | 121 | 122 | def dont_hook(model, config): 123 | return model, [] 124 | -------------------------------------------------------------------------------- /toxicity/eval_interventions/metric_funcs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Functions to run each metric. 3 | """ 4 | 5 | import math 6 | from collections import Counter 7 | from tqdm import tqdm 8 | import numpy as np 9 | import torch 10 | from torch.nn import CrossEntropyLoss 11 | from toxicity.PerspectiveAPI import ( 12 | PerspectiveAPI, 13 | parse_response_payload as parse_toxic_scores, 14 | ) 15 | from toxicity.eval_interventions.hook_utils import hook_subtract 16 | from utils import VERBOSE 17 | 18 | 19 | rate_limit = 20 20 | if VERBOSE: 21 | rate_limit = 5 22 | perspective_api_client = PerspectiveAPI(rate_limit=rate_limit) 23 | 24 | 25 | def run_dummy( 26 | model, 27 | data, 28 | intervene_results, 29 | config, 30 | ): 31 | """ 32 | Returns a constant. 33 | """ 34 | return 42 35 | 36 | 37 | def run_f1( 38 | model, 39 | data, 40 | intervene_results, 41 | config, 42 | ): 43 | """ 44 | Calculate prec, recall, f1 for 1d tensors. 45 | """ 46 | # LongTensor: [Batch, Seq] 47 | pred_tokens = intervene_results["pred_tokens"] 48 | gold_tokens = data["gold_input_ids"] 49 | assert pred_tokens.shape[0] == gold_tokens.shape[0] 50 | pad_token_id = model.tokenizer.pad_token_id 51 | f1s = [] 52 | for idx in range(pred_tokens.shape[0]): 53 | _gold = [x for x in gold_tokens[idx].tolist() if x != pad_token_id] 54 | _pred = [x for x in pred_tokens[idx].tolist() if x != pad_token_id] 55 | common = Counter(_gold) & Counter(_pred) 56 | num_same = sum(common.values()) 57 | if num_same == 0: 58 | f1s.append(0) 59 | continue 60 | 61 | prec = 1.0 * num_same / len(_pred) 62 | recall = 1.0 * num_same / len(_gold) 63 | f1 = (2 * prec * recall) / (prec + recall) 64 | f1s.append(f1) 65 | 66 | return np.mean(f1s) 67 | 68 | 69 | def run_perplexity( 70 | model, 71 | data, 72 | intervene_results, 73 | config, 74 | ): 75 | """ 76 | Calculate perplexity. 
77 | 78 | :data: 79 | """ 80 | encodings = data["prompt_input_ids"] 81 | 82 | max_length = 1024 83 | stride = max_length 84 | 85 | ending = encodings.shape[1] 86 | if VERBOSE: 87 | ending = max_length * 4 88 | 89 | lls = [] 90 | for i in tqdm(range(0, ending, stride)): 91 | # iterates through all 287644 tokens in wikitext test in windows of stride (usually max_length) 92 | begin_loc = max(i + stride - max_length, 0) 93 | end_loc = min(i + stride, encodings.shape[1]) 94 | trg_len = end_loc - i # may be different from stride on last loop 95 | 96 | input_ids = encodings[:, begin_loc:end_loc].to(config["device"]) 97 | 98 | target_ids = input_ids.clone() 99 | target_ids[:, :-trg_len] = -100 100 | 101 | with torch.no_grad(): 102 | outputs = model(input_ids, labels=target_ids) 103 | 104 | logits = outputs.logits 105 | shift_logits = logits[..., :-1, :].contiguous() 106 | shift_labels = target_ids[..., 1:].contiguous() 107 | 108 | loss_fn = CrossEntropyLoss() 109 | loss = loss_fn( 110 | shift_logits.view(-1, shift_logits.shape[-1]), 111 | shift_labels.view(-1), 112 | ) 113 | lls.append(loss * trg_len) 114 | ppl = torch.exp(torch.stack(lls).sum() / end_loc) 115 | return ppl 116 | 117 | 118 | def _parse_toxicity_scores(scores, config): 119 | """ 120 | Parse toxicity scores 121 | """ 122 | key_attribs = config["attributes"] 123 | parsed_scores = [] 124 | for score_obj in scores: 125 | parsed = parse_toxic_scores(score_obj) 126 | parsed_scores.append(max([parsed[att] for att in key_attribs])) 127 | return parsed_scores 128 | 129 | 130 | def run_perspective_api( 131 | model, 132 | data, 133 | intervene_results, 134 | config, 135 | ): 136 | generations = intervene_results["pred_text"] 137 | if not config["include_prompt"]: 138 | generations = [ 139 | intervene_results["pred_text"][idx].replace( 140 | data["prompts"][idx], "" 141 | ) 142 | for idx in range(len(intervene_results["pred_text"])) 143 | ] 144 | toxicity_scores = perspective_api_client.request_loop_with_delay( 145 | generations 146 | ) 147 | parsed_scores = _parse_toxicity_scores(toxicity_scores, config) 148 | return np.mean(parsed_scores) 149 | -------------------------------------------------------------------------------- /toxicity/eval_interventions/perplexity.py: -------------------------------------------------------------------------------- 1 | # modified from: https://github.com/timoschick/self-debiasing/blob/main/perplexity.py 2 | 3 | import argparse 4 | import torch 5 | from tqdm import tqdm 6 | from transformers import GPT2Tokenizer 7 | 8 | from nlp import load_dataset 9 | from toxic_suppression_wrapper import GPT2Wrapper 10 | import pandas as pd 11 | 12 | 13 | def compute_ppl( 14 | tokenizer, 15 | wrapper, 16 | values_per_layer=None, 17 | coef_value=3, 18 | use_cuda=False, 19 | max_length=-1, 20 | stride=-1, 21 | ): 22 | """ 23 | Computes perplexity on the test set of WikiText2 24 | """ 25 | 26 | device = "cuda:0" if torch.cuda.is_available() and use_cuda else "cpu" 27 | 28 | test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") 29 | encodings = tokenizer("\n\n".join(test["text"]), return_tensors="pt") 30 | 31 | max_length = ( 32 | max_length if max_length > 0 else wrapper._model.config.n_positions 33 | ) 34 | 35 | if stride <= 0: 36 | stride = max_length 37 | 38 | lls_non_toxic, lls_regular = [], [] 39 | ppl_non_toxic, ppl_regular = None, None 40 | 41 | for i in tqdm(range(0, encodings.input_ids.size(1), stride)): 42 | # iterates through all 287644 tokens in wikitext test in windows of stride (usually max_length) 43 | 
        begin_loc = max(i + stride - max_length, 0)
44 |         end_loc = min(i + stride, encodings.input_ids.size(1))
45 |         trg_len = end_loc - i  # may be different from stride on last loop
46 | 
47 |         input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
48 | 
49 |         target_ids = input_ids.clone()
50 |         target_ids[:, :-trg_len] = -100  # mask all but the last trg_len tokens (-100 = ignore_index) so only they count toward the loss
51 | 
52 |         with torch.no_grad():
53 |             loss_regular = wrapper.compute_loss(input_ids, labels=target_ids)
54 | 
55 |             wrapper.set_value_activations(
56 |                 values_per_layer, coef_value=coef_value
57 |             )
58 | 
59 |             loss_non_toxic = wrapper.compute_loss(input_ids, labels=target_ids)
60 | 
61 |             wrapper.remove_all_hooks()
62 | 
63 |         log_likelihood_non_toxic = loss_non_toxic * trg_len
64 |         log_likelihood_regular = loss_regular * trg_len
65 | 
66 |         lls_non_toxic.append(log_likelihood_non_toxic)
67 |         lls_regular.append(log_likelihood_regular)
68 | 
69 |         ppl_non_toxic = torch.exp(torch.stack(lls_non_toxic).sum() / end_loc)
70 |         ppl_regular = torch.exp(torch.stack(lls_regular).sum() / end_loc)
71 |         print(
72 |             f"Perplexity after {i} tokens: {ppl_non_toxic} (non-toxic) vs {ppl_regular} (regular)"
73 |         )
74 | 
75 |     print(
76 |         f"Final perplexity: {ppl_non_toxic} (non-toxic) vs {ppl_regular} (regular)"
77 |     )
78 | 
79 |     return ppl_non_toxic, ppl_regular
80 | 
81 | 
82 | if __name__ == "__main__":
83 |     parser = argparse.ArgumentParser()
84 | 
85 |     parser.add_argument(
86 |         "--model_name",
87 |         default="gpt2-medium",
88 |         choices=["gpt2-xl", "gpt2", "gpt2-medium", "gpt2-large"],
89 |     )
90 |     parser.add_argument("--coef_value", default=3, type=float)
91 |     parser.add_argument("--values_filepath", type=str)
92 |     parser.add_argument("--use_cuda", action="store_true")
93 | 
94 |     args = parser.parse_args()
95 | 
96 |     tokenizer = GPT2Tokenizer.from_pretrained(args.model_name)
97 |     wrapper = GPT2Wrapper(args.model_name, use_cuda=args.use_cuda)
98 | 
99 |     values = pd.read_pickle(args.values_filepath)
100 | 
101 |     ppl_non_toxic, ppl_regular = compute_ppl(
102 |         tokenizer,
103 |         wrapper,
104 |         values_per_layer=values,
105 |         coef_value=args.coef_value,
106 |         use_cuda=args.use_cuda,
107 |     )
108 | 
--------------------------------------------------------------------------------
/toxicity/eval_interventions/run_evaluations.py:
--------------------------------------------------------------------------------
1 | """
2 | Evaluation module for interventions.
3 | """
4 | 
5 | from typing import Dict
6 | 
7 | import os
8 | import copy
9 | import torch
10 | from toxicity.eval_interventions.eval_utils import (
11 |     load_model,
12 |     load_data,
13 |     tokenize,
14 |     get_intervene_name,
15 |     pretty_print_results,
16 | )
17 | from toxicity.eval_interventions.generate_funcs import (
18 |     generate_default,
19 |     get_prompts,
20 |     get_gold,
21 | )
22 | from toxicity.eval_interventions.metric_funcs import (
23 |     run_f1,
24 |     run_perplexity,
25 |     run_perspective_api,
26 |     run_dummy,
27 | )
28 | from toxicity.eval_interventions.hook_utils import (
29 |     dont_hook,
30 |     hook_subtract,
31 | )
32 | from constants import (
33 |     ROOT_DIR,
34 |     PROFANITY,
35 |     SEXUALLY_EXPLICIT,
36 |     IDENTITY_ATTACK,
37 |     THREAT,
38 |     INSULT,
39 |     SEVERE_TOXICITY,
40 |     TOXICITY,
41 |     PERSPECTIVE_API_ATTRIBUTES as ATTRIBUTES,
42 | )
43 | from utils import verbose_print, VERBOSE
44 | 
45 | DATA_DIR = os.path.join(ROOT_DIR, "data/intervene_data")
46 | CKPT_DIR = os.path.join(ROOT_DIR, "checkpoints")
47 | 
48 | 
49 | GENERATE_FUNCS = {
50 |     "get_prompts": get_prompts,
51 |     "get_gold": get_gold,
52 | }
53 | METRIC_FUNCS = {
54 |     "f1":
run_f1, 55 | "perplexity": run_perplexity, 56 | "dummy": run_dummy, 57 | "perspective_api": run_perspective_api, 58 | } 59 | HOOK_FUNCS = { 60 | "subtraction": hook_subtract, 61 | } 62 | UNHOOK_FUNCS = {} 63 | 64 | 65 | def generate(model, data, intervene_config): 66 | """ 67 | Test intervention on a specific metric. 68 | """ 69 | return GENERATE_FUNCS.get(intervene_config["method"], generate_default)( 70 | model, data, intervene_config["params"] 71 | ) 72 | 73 | 74 | def run_metric( 75 | metric_type, 76 | model, 77 | data_obj, 78 | intervene_results: Dict[str, torch.LongTensor], 79 | config, 80 | ): 81 | """ 82 | Calculate specific metric. 83 | 84 | :intervene_results: Mapping from intervention specification to a tensor 85 | of shape [data_size, max_prompt_len + max_new_tokens] 86 | """ 87 | return METRIC_FUNCS[metric_type]( 88 | model, 89 | data_obj, 90 | intervene_results, 91 | config, 92 | ) 93 | 94 | 95 | def hook_model(model, config): 96 | """ 97 | Hook model. 98 | """ 99 | return HOOK_FUNCS.get(config["method"], dont_hook)(model, config["params"]) 100 | 101 | 102 | def unhook_model(model, hooks): 103 | """ 104 | Remove hooks in the model. 105 | """ 106 | for hook in hooks: 107 | hook.remove() 108 | 109 | 110 | def _eval_intervene(model, tokenizer, model_config, intervene_config, metric_configs): 111 | """ 112 | Evaluation intervention on set of metrics. 113 | """ 114 | assert "method" in intervene_config 115 | intervene_config["params"]["device"] = model_config["device"] 116 | 117 | results = {} 118 | for _metric_conf in metric_configs: 119 | metric_type = _metric_conf["metric"] 120 | intervene_config["params"]["max_new_tokens"] = None 121 | 122 | verbose_print(f"Evaluating {metric_type}") 123 | data = _metric_conf["tokenized"] 124 | 125 | intervene_config["params"]["hook_timesteps"] = -1 126 | if metric_type == "perplexity": 127 | intervene_config["params"]["hook_timesteps"] = 0 128 | 129 | _, hooks = hook_model(model, intervene_config) 130 | 131 | generations = {} 132 | do_generate = _metric_conf["generate"] 133 | if do_generate: 134 | 135 | intervene_config["params"]["max_new_tokens"] = _metric_conf[ 136 | "max_new_tokens" 137 | ] 138 | intervene_config["params"]["batch_size"] = model_config["batch_size"] 139 | generations = generate(model, data, intervene_config) 140 | for gen in generations["pred_text"][:30]: 141 | verbose_print(gen) 142 | 143 | results[metric_type] = run_metric( 144 | metric_type, 145 | model, 146 | data, 147 | generations, 148 | _metric_conf.get("params"), 149 | ) 150 | unhook_model(model, hooks) 151 | return results 152 | 153 | 154 | def unroll_intervene(configs): 155 | """ 156 | Unroll any nested configurations. 
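For example, a single "subtraction" entry with `scales: [20, 50]` and `subtract_from: [[23], [18, 23]]` expands into one concrete config per (scale, layer-set) pair, each carrying a scalar `scale` and a single `subtract_from` list.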
157 | """ 158 | unrolled = [] 159 | for _config in configs: 160 | method = _config["method"] 161 | if method != "subtraction": 162 | unrolled.append(_config) 163 | continue 164 | 165 | params = _config["params"] 166 | scales = params.pop("scales", []) 167 | if len(scales) < 1: 168 | raise RuntimeError("Missing scale value?") 169 | 170 | subtract_sets = params.pop("subtract_from", []) 171 | if len(subtract_sets) < 1: 172 | raise RuntimeError("Missing subtract_from value?") 173 | 174 | for scale in scales: 175 | for subtract_set in subtract_sets: 176 | config_copy = copy.deepcopy(_config) 177 | config_copy["params"]["scale"] = scale 178 | config_copy["params"]["subtract_from"] = subtract_set 179 | unrolled.append(config_copy) 180 | 181 | return unrolled 182 | 183 | 184 | def tokenize_data(tokenizer, config): 185 | """ 186 | Tokenize all data beforehand. 187 | """ 188 | metric_configs = config["metrics"] 189 | 190 | tokenized_data = {} 191 | for _metric_conf in metric_configs: 192 | datapath = _metric_conf["datapath"] 193 | if datapath in tokenized_data: 194 | _metric_conf["tokenized"] = tokenized_data[datapath] 195 | continue 196 | 197 | data = load_data(_metric_conf) 198 | tokenized_data[datapath] = tokenize(tokenizer, data, _metric_conf) 199 | _metric_conf["tokenized"] = tokenized_data[datapath] 200 | 201 | 202 | def run_eval(config): 203 | """ 204 | Run eval! 205 | """ 206 | model_config = config["model"] 207 | metric_configs = config["metrics"] 208 | interventions = config["interventions"] 209 | 210 | assert len(metric_configs) == len( 211 | list(set([x["metric"] for x in metric_configs])) 212 | ), "Mismatch -- you likely specified the same metric twice!" 213 | 214 | model, tokenizer = load_model(model_config) 215 | model.tokenizer = tokenizer 216 | 217 | # Tokenize all data beforehand. 
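    # More precisely: the loop below first copies shared runtime settings (pad token id,
    # batch size, device) into each metric config; tokenize_data() is then called to
    # tokenize and cache the data per datapath.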
218 | for _metric_conf in metric_configs: 219 | if "params" not in _metric_conf: 220 | _metric_conf["params"] = {} 221 | _metric_conf["params"]["pad_token_id"] = tokenizer.pad_token_id 222 | _metric_conf["params"]["batch_size"] = model_config["batch_size"] 223 | _metric_conf["params"]["device"] = model_config["device"] 224 | 225 | tokenize_data(tokenizer, config) 226 | 227 | interventions = unroll_intervene(interventions) 228 | results = {} 229 | for intervene_config in interventions: 230 | 231 | intervene_name = get_intervene_name(intervene_config) 232 | verbose_print(f" Evaluating intervention {intervene_name}") 233 | results[intervene_name] = _eval_intervene( 234 | model, tokenizer, model_config, intervene_config, metric_configs 235 | ) 236 | pretty_print_results(results) 237 | return results 238 | 239 | 240 | def main(): 241 | """Driver""" 242 | config = { 243 | "model": { 244 | "model_or_path": "gpt2-medium", 245 | # "state_dict_path": os.path.join(CKPT_DIR, "dpo.pt"), 246 | "tokenizer": "gpt2-medium", 247 | "batch_size": 16, 248 | "device": "cuda" if torch.cuda.is_available() else "cpu", 249 | }, 250 | "metrics": [ 251 | { 252 | "datapath": ( 253 | os.path.join(DATA_DIR, "challenge_prompts_dev.jsonl") 254 | if VERBOSE 255 | else os.path.join(DATA_DIR, "challenge_prompts.jsonl") 256 | ), 257 | "metric": "perspective_api", 258 | "max_prompt_size": 32, 259 | "max_new_tokens": 20, 260 | "generate": True, 261 | "params": {"attributes": ATTRIBUTES, "include_prompt": False}, 262 | }, 263 | { 264 | "datapath": "wikitext", 265 | "dataname": "wikitext-2-raw-v1", 266 | "split": "test", 267 | "metric": "perplexity", 268 | "generate": False, 269 | }, 270 | { 271 | "datapath": ( 272 | os.path.join(DATA_DIR, "wiki_samples_dev.jsonl") 273 | if VERBOSE 274 | else os.path.join(DATA_DIR, "wiki_samples.jsonl") 275 | ), 276 | "metric": "f1", 277 | "max_prompt_size": 32, 278 | "max_new_tokens": 20, 279 | "generate": True, 280 | }, 281 | ], 282 | "interventions": [ 283 | {"method": "noop", "params": {}}, 284 | { 285 | "method": "subtraction", 286 | "params": { 287 | "type": "mlp_w_out", 288 | "idx": 770, 289 | "layer": 19, 290 | "subtract_from": [[23]], 291 | "scales": [20], 292 | }, 293 | }, 294 | { 295 | "method": "subtraction", 296 | "params": { 297 | "type": "toxic_probe", 298 | "scales": [50], 299 | "subtract_from": [[23]], 300 | "datapath": os.path.join(CKPT_DIR, "probe.pt"), 301 | }, 302 | }, 303 | { 304 | "method": "subtraction", 305 | "params": { 306 | "type": "svd", 307 | "idx": 0, 308 | "scales": [100], 309 | "subtract_from": [[23]], 310 | "topk_sorted_score": 512, 311 | "datapath": os.path.join(CKPT_DIR, "probe.pt"), 312 | }, 313 | }, 314 | ], 315 | } 316 | results = run_eval(config) 317 | print("Final Results:") 318 | pretty_print_results(results) 319 | 320 | 321 | if __name__ == "__main__": 322 | main() 323 | -------------------------------------------------------------------------------- /toxicity/eval_interventions/unalign.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unalign experiment. 
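Loads the DPO checkpoint and scales up the input weights (`mlp.c_fc`) of the MLP neurons whose value vectors are most similar to the toxicity probe (see the "unalign" branch of `eval_utils.load_model`), then runs the standard evaluation suite (Perspective API toxicity, perplexity, F1) with no further intervention.

Run from the repository root, e.g. `python -m toxicity.eval_interventions.unalign` (the module-style invocation is a suggestion; the script itself takes no arguments).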
3 | """ 4 | 5 | import os 6 | import torch 7 | from toxicity.eval_interventions.eval_utils import ( 8 | pretty_print_results, 9 | ) 10 | from constants import ( 11 | ROOT_DIR, 12 | PROFANITY, 13 | SEXUALLY_EXPLICIT, 14 | IDENTITY_ATTACK, 15 | THREAT, 16 | INSULT, 17 | SEVERE_TOXICITY, 18 | TOXICITY, 19 | PERSPECTIVE_API_ATTRIBUTES as ATTRIBUTES, 20 | ) 21 | from utils import verbose_print, VERBOSE 22 | from toxicity.eval_interventions.run_evaluations import run_eval 23 | 24 | DATA_DIR = os.path.join(ROOT_DIR, "data/intervene_data") 25 | CKPT_DIR = os.path.join(ROOT_DIR, "checkpoints") 26 | 27 | 28 | def main(): 29 | """Driver""" 30 | config = { 31 | "model": { 32 | "model_or_path": "gpt2-medium", 33 | "state_dict_path": os.path.join(CKPT_DIR, "dpo.pt"), 34 | "tokenizer": "gpt2-medium", 35 | "batch_size": 16, 36 | "device": "cuda" if torch.cuda.is_available() else "cpu", 37 | "unalign": { 38 | "probe_path": os.path.join(CKPT_DIR, "probe.pt"), 39 | "num_value_vecs": 7, 40 | "scale": 10, 41 | }, 42 | }, 43 | "metrics": [ 44 | { 45 | "datapath": ( 46 | os.path.join(DATA_DIR, "challenge_prompts_dev.jsonl") 47 | if VERBOSE 48 | else os.path.join(DATA_DIR, "challenge_prompts.jsonl") 49 | ), 50 | "metric": "perspective_api", 51 | "max_prompt_size": 32, 52 | "max_new_tokens": 20, 53 | "generate": True, 54 | "params": {"attributes": ATTRIBUTES, "include_prompt": False}, 55 | }, 56 | { 57 | "datapath": "wikitext", 58 | "dataname": "wikitext-2-raw-v1", 59 | "split": "test", 60 | "metric": "perplexity", 61 | "generate": False, 62 | }, 63 | { 64 | "datapath": ( 65 | os.path.join(DATA_DIR, "wiki_samples_dev.jsonl") 66 | if VERBOSE 67 | else os.path.join(DATA_DIR, "wiki_samples.jsonl") 68 | ), 69 | "metric": "f1", 70 | "max_prompt_size": 32, 71 | "max_new_tokens": 20, 72 | "generate": True, 73 | }, 74 | ], 75 | "interventions": [ 76 | {"method": "noop", "params": {}}, 77 | ], 78 | } 79 | results = run_eval(config) 80 | print("Final Results:") 81 | pretty_print_results(results) 82 | 83 | 84 | if __name__ == "__main__": 85 | main() 86 | -------------------------------------------------------------------------------- /toxicity/figures/activation_drop.sync.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 6, 6 | "id": "32f3796f", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "" 13 | ] 14 | }, 15 | "execution_count": 6, 16 | "metadata": {}, 17 | "output_type": "execute_result" 18 | } 19 | ], 20 | "source": [ 21 | "import os\n", 22 | "import json\n", 23 | "from collections import defaultdict\n", 24 | "\n", 25 | "import numpy as np\n", 26 | "import pandas as pd\n", 27 | "\n", 28 | "import torch\n", 29 | "import torch.nn.functional as F\n", 30 | "from tqdm import tqdm\n", 31 | "\n", 32 | "import seaborn as sns\n", 33 | "import matplotlib.pyplot as plt\n", 34 | "\n", 35 | "from transformer_lens import (\n", 36 | " HookedTransformer,\n", 37 | ")\n", 38 | "from toxicity.figures.fig_utils import load_hooked, get_svd\n", 39 | "from constants import MODEL_DIR, DATA_DIR\n", 40 | "\n", 41 | "torch.set_grad_enabled(False)" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 7, 47 | "id": "43d04809", 48 | "metadata": {}, 49 | "outputs": [ 50 | { 51 | "name": "stdout", 52 | "output_type": "stream", 53 | "text": [ 54 | "Loaded pretrained model gpt2-medium into HookedTransformer\n", 55 | "Loaded pretrained model gpt2-medium into HookedTransformer\n" 56 | ] 
57 | } 58 | ], 59 | "source": [ 60 | "\n", 61 | "model = load_hooked(\n", 62 | " \"gpt2-medium\",\n", 63 | " os.path.join(MODEL_DIR, \"dpo.pt\"),\n", 64 | ")\n", 65 | "gpt2 = HookedTransformer.from_pretrained(\"gpt2-medium\")\n", 66 | "gpt2.tokenizer.padding_side = \"left\"\n", 67 | "gpt2.tokenizer.pad_token_id = gpt2.tokenizer.eos_token_id\n", 68 | "\n", 69 | "toxic_vector = torch.load(os.path.join(MODEL_DIR, \"probe.pt\"))" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 8, 75 | "id": "d84999a3", 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "\n", 80 | "with open(\n", 81 | " os.path.join(DATA_DIR, \"intervene_data/challenge_prompts.jsonl\"), \"r\"\n", 82 | ") as file_p:\n", 83 | " data = file_p.readlines()\n", 84 | "\n", 85 | "prompts = [json.loads(x.strip())[\"prompt\"] for x in data]\n", 86 | "tokenized_prompts = model.to_tokens(prompts, prepend_bos=True).cuda()" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 9, 92 | "id": "16fc0a84", 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "\n", 97 | "\n", 98 | "_, scores_gpt2 = get_svd(gpt2, toxic_vector, 128)\n", 99 | "\n", 100 | "mlps_by_layer = {}\n", 101 | "for _score_obj in scores_gpt2:\n", 102 | " layer = _score_obj[2]\n", 103 | " if layer not in mlps_by_layer:\n", 104 | " mlps_by_layer[layer] = []\n", 105 | " mlps_by_layer[layer].append(_score_obj[1])\n", 106 | "\n", 107 | "vectors_of_interest = [\n", 108 | " (_score_obj[2], _score_obj[1], _score_obj[0])\n", 109 | " for _score_obj in scores_gpt2[:64]\n", 110 | "]" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "id": "6ef1da2d", 117 | "metadata": {}, 118 | "outputs": [ 119 | { 120 | "name": "stdout", 121 | "output_type": "stream", 122 | "text": [ 123 | "Grabbing mlp mids...\n" 124 | ] 125 | }, 126 | { 127 | "name": "stderr", 128 | "output_type": "stream", 129 | "text": [ 130 | " 99%|███████████████████████████████████████████████████████████████████████████████████████████ | 297/300 [12:57<00:08, 2.73s/it]" 131 | ] 132 | } 133 | ], 134 | "source": [ 135 | "\n", 136 | "\n", 137 | "gpt2_acts_of_interest = defaultdict(list)\n", 138 | "dpo_acts_of_interest = defaultdict(list)\n", 139 | "sample_size = tokenized_prompts.shape[0]\n", 140 | "batch_size = 4\n", 141 | "print(\"Grabbing mlp mids...\")\n", 142 | "for idx in tqdm(range(0, sample_size, batch_size)):\n", 143 | " batch = tokenized_prompts[idx : idx + batch_size, :]\n", 144 | " dpo_batch = batch.clone()\n", 145 | "\n", 146 | " for timestep in range(20):\n", 147 | " with torch.inference_mode():\n", 148 | " _, cache = gpt2.run_with_cache(batch)\n", 149 | "\n", 150 | " sampled = gpt2.unembed(cache[\"ln_final.hook_normalized\"]).argmax(-1)[\n", 151 | " :, -1\n", 152 | " ]\n", 153 | " for _vec in vectors_of_interest:\n", 154 | " _layer = _vec[0]\n", 155 | " _idx = _vec[1]\n", 156 | " mlp_mid = cache[f\"blocks.{_layer}.mlp.hook_post\"][:, -1, _idx]\n", 157 | " gpt2_acts_of_interest[(_layer, _idx)].extend(mlp_mid.tolist())\n", 158 | "\n", 159 | " with torch.inference_mode():\n", 160 | " _, cache = model.run_with_cache(dpo_batch)\n", 161 | " sampled = model.unembed(cache[\"ln_final.hook_normalized\"]).argmax(-1)[\n", 162 | " :, -1\n", 163 | " ]\n", 164 | "\n", 165 | " for _vec in vectors_of_interest:\n", 166 | " _layer = _vec[0]\n", 167 | " _idx = _vec[1]\n", 168 | " mlp_mid = cache[f\"blocks.{_layer}.mlp.hook_post\"][:, -1, _idx]\n", 169 | " dpo_acts_of_interest[(_layer, _idx)].extend(mlp_mid.tolist())\n", 170 | "\n", 171 
| " batch = torch.concat([batch, sampled.unsqueeze(-1)], dim=-1)\n", 172 | " dpo_batch = torch.concat([dpo_batch, sampled.unsqueeze(-1)], dim=-1)" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "id": "e700b355", 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "\n", 183 | "d_mlp = model.cfg.d_mlp\n", 184 | "dpo_acts_mean = {}\n", 185 | "gpt2_acts_mean = {}\n", 186 | "num_mlps = 5\n", 187 | "for _vec in vectors_of_interest[:num_mlps]:\n", 188 | "\n", 189 | " _layer = _vec[0]\n", 190 | " _idx = _vec[1]\n", 191 | " gpt2_acts_mean[(_layer, _idx)] = np.mean(\n", 192 | " gpt2_acts_of_interest[(_layer, _idx)]\n", 193 | " )\n", 194 | " dpo_acts_mean[(_layer, _idx)] = np.mean(\n", 195 | " dpo_acts_of_interest[(_layer, _idx)]\n", 196 | " )" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": null, 202 | "id": "8a85e953", 203 | "metadata": {}, 204 | "outputs": [], 205 | "source": [ 206 | "\n", 207 | "raw_data = []\n", 208 | "num_mlps = 5\n", 209 | "for _vec in vectors_of_interest[:num_mlps]:\n", 210 | " _layer = _vec[0]\n", 211 | " _idx = _vec[1]\n", 212 | "\n", 213 | " raw_data.append(\n", 214 | " {\n", 215 | " \"MLP\": f\"L:{_layer}\\nIdx:{_idx}\",\n", 216 | " \"Mean Activation\": dpo_acts_mean[(_layer, _idx)].item(),\n", 217 | " \"Model\": \"DPO\",\n", 218 | " }\n", 219 | " )\n", 220 | "\n", 221 | " raw_data.append(\n", 222 | " {\n", 223 | " \"MLP\": f\"L:{_layer}\\nIdx:{_idx}\",\n", 224 | " \"Mean Activation\": gpt2_acts_mean[(_layer, _idx)].item(),\n", 225 | " \"Model\": \"GPT2\",\n", 226 | " }\n", 227 | " )" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": null, 233 | "id": "be503524", 234 | "metadata": { 235 | "scrolled": false 236 | }, 237 | "outputs": [], 238 | "source": [ 239 | "\n", 240 | "data = pd.DataFrame(raw_data)\n", 241 | "sns.set_theme(context=\"paper\", style=\"ticks\", rc={\"lines.linewidth\": 1})\n", 242 | "\n", 243 | "sns.catplot(\n", 244 | " data=data,\n", 245 | " x=\"MLP\",\n", 246 | " y=\"Mean Activation\",\n", 247 | " hue=\"Model\",\n", 248 | " hue_order=[\"GPT2\", \"DPO\"],\n", 249 | " height=2,\n", 250 | " aspect=3.25 / 2,\n", 251 | " kind=\"bar\",\n", 252 | " legend_out=False,\n", 253 | ")\n", 254 | "\n", 255 | "\n", 256 | "plt.savefig(\"activation_drops.pdf\", bbox_inches=\"tight\", dpi=1200)" 257 | ] 258 | } 259 | ], 260 | "metadata": { 261 | "kernelspec": { 262 | "display_name": "Python 3", 263 | "language": "python", 264 | "name": "python3" 265 | }, 266 | "language_info": { 267 | "codemirror_mode": { 268 | "name": "ipython", 269 | "version": 3 270 | }, 271 | "file_extension": ".py", 272 | "mimetype": "text/x-python", 273 | "name": "python", 274 | "nbconvert_exporter": "python", 275 | "pygments_lexer": "ipython3", 276 | "version": "3.8.8" 277 | } 278 | }, 279 | "nbformat": 4, 280 | "nbformat_minor": 5 281 | } 282 | -------------------------------------------------------------------------------- /toxicity/figures/activation_drop.sync.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # text_representation: 5 | # extension: .py 6 | # format_name: percent 7 | # format_version: '1.3' 8 | # jupytext_version: 1.3.4 9 | # kernelspec: 10 | # display_name: Python 3 11 | # language: python 12 | # name: python3 13 | # --- 14 | 15 | # %% 16 | import os 17 | import json 18 | from collections import defaultdict 19 | 20 | import numpy as np 21 | import pandas as pd 22 | 23 | import torch 24 | 
import torch.nn.functional as F 25 | from tqdm import tqdm 26 | 27 | import seaborn as sns 28 | import matplotlib.pyplot as plt 29 | 30 | from transformer_lens import ( 31 | HookedTransformer, 32 | ) 33 | from toxicity.figures.fig_utils import load_hooked, get_svd 34 | from constants import MODEL_DIR, DATA_DIR 35 | 36 | torch.set_grad_enabled(False) 37 | 38 | 39 | # %% 40 | 41 | model = load_hooked( 42 | "gpt2-medium", 43 | os.path.join(MODEL_DIR, "dpo.pt"), 44 | ) 45 | gpt2 = HookedTransformer.from_pretrained("gpt2-medium") 46 | gpt2.tokenizer.padding_side = "left" 47 | gpt2.tokenizer.pad_token_id = gpt2.tokenizer.eos_token_id 48 | 49 | toxic_vector = torch.load(os.path.join(MODEL_DIR, "probe.pt")) 50 | 51 | 52 | # %% 53 | 54 | with open( 55 | os.path.join(DATA_DIR, "intervene_data/challenge_prompts.jsonl"), "r" 56 | ) as file_p: 57 | data = file_p.readlines() 58 | 59 | prompts = [json.loads(x.strip())["prompt"] for x in data] 60 | tokenized_prompts = model.to_tokens(prompts, prepend_bos=True).cuda() 61 | 62 | # %% 63 | 64 | 65 | _, scores_gpt2 = get_svd(gpt2, toxic_vector, 128) 66 | 67 | mlps_by_layer = {} 68 | for _score_obj in scores_gpt2: 69 | layer = _score_obj[2] 70 | if layer not in mlps_by_layer: 71 | mlps_by_layer[layer] = [] 72 | mlps_by_layer[layer].append(_score_obj[1]) 73 | 74 | vectors_of_interest = [ 75 | (_score_obj[2], _score_obj[1], _score_obj[0]) 76 | for _score_obj in scores_gpt2[:64] 77 | ] 78 | 79 | 80 | # %% 81 | 82 | 83 | gpt2_acts_of_interest = defaultdict(list) 84 | dpo_acts_of_interest = defaultdict(list) 85 | sample_size = tokenized_prompts.shape[0] 86 | batch_size = 4 87 | print("Grabbing mlp mids...") 88 | for idx in tqdm(range(0, sample_size, batch_size)): 89 | batch = tokenized_prompts[idx : idx + batch_size, :] 90 | dpo_batch = batch.clone() 91 | 92 | for timestep in range(20): 93 | with torch.inference_mode(): 94 | _, cache = gpt2.run_with_cache(batch) 95 | 96 | sampled = gpt2.unembed(cache["ln_final.hook_normalized"]).argmax(-1)[ 97 | :, -1 98 | ] 99 | for _vec in vectors_of_interest: 100 | _layer = _vec[0] 101 | _idx = _vec[1] 102 | mlp_mid = cache[f"blocks.{_layer}.mlp.hook_post"][:, -1, _idx] 103 | gpt2_acts_of_interest[(_layer, _idx)].extend(mlp_mid.tolist()) 104 | 105 | with torch.inference_mode(): 106 | _, cache = model.run_with_cache(dpo_batch) 107 | sampled = model.unembed(cache["ln_final.hook_normalized"]).argmax(-1)[ 108 | :, -1 109 | ] 110 | 111 | for _vec in vectors_of_interest: 112 | _layer = _vec[0] 113 | _idx = _vec[1] 114 | mlp_mid = cache[f"blocks.{_layer}.mlp.hook_post"][:, -1, _idx] 115 | dpo_acts_of_interest[(_layer, _idx)].extend(mlp_mid.tolist()) 116 | 117 | batch = torch.concat([batch, sampled.unsqueeze(-1)], dim=-1) 118 | dpo_batch = torch.concat([dpo_batch, sampled.unsqueeze(-1)], dim=-1) 119 | 120 | # %% 121 | 122 | d_mlp = model.cfg.d_mlp 123 | dpo_acts_mean = {} 124 | gpt2_acts_mean = {} 125 | num_mlps = 5 126 | for _vec in vectors_of_interest[:num_mlps]: 127 | 128 | _layer = _vec[0] 129 | _idx = _vec[1] 130 | gpt2_acts_mean[(_layer, _idx)] = np.mean( 131 | gpt2_acts_of_interest[(_layer, _idx)] 132 | ) 133 | dpo_acts_mean[(_layer, _idx)] = np.mean( 134 | dpo_acts_of_interest[(_layer, _idx)] 135 | ) 136 | 137 | 138 | # %% 139 | 140 | raw_data = [] 141 | num_mlps = 5 142 | for _vec in vectors_of_interest[:num_mlps]: 143 | _layer = _vec[0] 144 | _idx = _vec[1] 145 | 146 | raw_data.append( 147 | { 148 | "MLP": f"L:{_layer}\nIdx:{_idx}", 149 | "Mean Activation": dpo_acts_mean[(_layer, _idx)].item(), 150 | "Model": "DPO", 151 | } 
152 | ) 153 | 154 | raw_data.append( 155 | { 156 | "MLP": f"L:{_layer}\nIdx:{_idx}", 157 | "Mean Activation": gpt2_acts_mean[(_layer, _idx)].item(), 158 | "Model": "GPT2", 159 | } 160 | ) 161 | 162 | 163 | # %% 164 | 165 | data = pd.DataFrame(raw_data) 166 | sns.set_theme(context="paper", style="ticks", rc={"lines.linewidth": 1}) 167 | 168 | sns.catplot( 169 | data=data, 170 | x="MLP", 171 | y="Mean Activation", 172 | hue="Model", 173 | hue_order=["GPT2", "DPO"], 174 | height=2, 175 | aspect=3.25 / 2, 176 | kind="bar", 177 | legend_out=False, 178 | ) 179 | 180 | 181 | plt.savefig("activation_drops.pdf", bbox_inches="tight", dpi=1200) 182 | -------------------------------------------------------------------------------- /toxicity/figures/fig_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for figures. 3 | """ 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import einops 8 | from transformer_lens import ( 9 | HookedTransformer, 10 | ) 11 | 12 | 13 | def convert(orig_state_dict, cfg): 14 | state_dict = {} 15 | 16 | state_dict["embed.W_E"] = orig_state_dict["transformer.wte.weight"] 17 | state_dict["pos_embed.W_pos"] = orig_state_dict["transformer.wpe.weight"] 18 | 19 | for l in range(cfg.n_layers): 20 | state_dict[f"blocks.{l}.ln1.w"] = orig_state_dict[ 21 | f"transformer.h.{l}.ln_1.weight" 22 | ] 23 | state_dict[f"blocks.{l}.ln1.b"] = orig_state_dict[ 24 | f"transformer.h.{l}.ln_1.bias" 25 | ] 26 | 27 | # In GPT-2, q,k,v are produced by one big linear map, whose output is 28 | # concat([q, k, v]) 29 | W = orig_state_dict[f"transformer.h.{l}.attn.c_attn.weight"] 30 | W_Q, W_K, W_V = torch.tensor_split(W, 3, dim=1) 31 | W_Q = einops.rearrange(W_Q, "m (i h)->i m h", i=cfg.n_heads) 32 | W_K = einops.rearrange(W_K, "m (i h)->i m h", i=cfg.n_heads) 33 | W_V = einops.rearrange(W_V, "m (i h)->i m h", i=cfg.n_heads) 34 | 35 | state_dict[f"blocks.{l}.attn.W_Q"] = W_Q 36 | state_dict[f"blocks.{l}.attn.W_K"] = W_K 37 | state_dict[f"blocks.{l}.attn.W_V"] = W_V 38 | 39 | qkv_bias = orig_state_dict[f"transformer.h.{l}.attn.c_attn.bias"] 40 | qkv_bias = einops.rearrange( 41 | qkv_bias, 42 | "(qkv index head)->qkv index head", 43 | qkv=3, 44 | index=cfg.n_heads, 45 | head=cfg.d_head, 46 | ) 47 | state_dict[f"blocks.{l}.attn.b_Q"] = qkv_bias[0] 48 | state_dict[f"blocks.{l}.attn.b_K"] = qkv_bias[1] 49 | state_dict[f"blocks.{l}.attn.b_V"] = qkv_bias[2] 50 | 51 | W_O = orig_state_dict[f"transformer.h.{l}.attn.c_proj.weight"] 52 | W_O = einops.rearrange(W_O, "(i h) m->i h m", i=cfg.n_heads) 53 | state_dict[f"blocks.{l}.attn.W_O"] = W_O 54 | state_dict[f"blocks.{l}.attn.b_O"] = orig_state_dict[ 55 | f"transformer.h.{l}.attn.c_proj.bias" 56 | ] 57 | 58 | state_dict[f"blocks.{l}.ln2.w"] = orig_state_dict[ 59 | f"transformer.h.{l}.ln_2.weight" 60 | ] 61 | state_dict[f"blocks.{l}.ln2.b"] = orig_state_dict[ 62 | f"transformer.h.{l}.ln_2.bias" 63 | ] 64 | 65 | W_in = orig_state_dict[f"transformer.h.{l}.mlp.c_fc.weight"] 66 | state_dict[f"blocks.{l}.mlp.W_in"] = W_in 67 | state_dict[f"blocks.{l}.mlp.b_in"] = orig_state_dict[ 68 | f"transformer.h.{l}.mlp.c_fc.bias" 69 | ] 70 | 71 | W_out = orig_state_dict[f"transformer.h.{l}.mlp.c_proj.weight"] 72 | state_dict[f"blocks.{l}.mlp.W_out"] = W_out 73 | state_dict[f"blocks.{l}.mlp.b_out"] = orig_state_dict[ 74 | f"transformer.h.{l}.mlp.c_proj.bias" 75 | ] 76 | state_dict["unembed.W_U"] = orig_state_dict["lm_head.weight"].T 77 | 78 | state_dict["ln_final.w"] = 
orig_state_dict["transformer.ln_f.weight"] 79 | state_dict["ln_final.b"] = orig_state_dict["transformer.ln_f.bias"] 80 | return state_dict 81 | 82 | 83 | def load_hooked(model_name, weights_path): 84 | _model = HookedTransformer.from_pretrained(model_name) 85 | cfg = _model.cfg 86 | 87 | _weights = torch.load(weights_path, map_location=torch.device("cuda"))[ 88 | "state" 89 | ] 90 | weights = convert(_weights, cfg) 91 | model = HookedTransformer(cfg) 92 | model.load_and_process_state_dict(weights) 93 | model.tokenizer.padding_side = "left" 94 | model.tokenizer.pad_token_id = model.tokenizer.eos_token_id 95 | return model 96 | 97 | 98 | def get_svd(_model, toxic_vector, num_mlp_vecs): 99 | scores = [] 100 | for layer in range(_model.cfg.n_layers): 101 | mlp_outs = _model.blocks[layer].mlp.W_out 102 | cos_sims = F.cosine_similarity( 103 | mlp_outs, toxic_vector.unsqueeze(0), dim=1 104 | ) 105 | _topk = cos_sims.topk(k=300) 106 | _values = [x.item() for x in _topk.values] 107 | _idxs = [x.item() for x in _topk.indices] 108 | topk = list(zip(_values, _idxs, [layer] * _topk.indices.shape[0])) 109 | scores.extend(topk) 110 | 111 | sorted_scores = sorted(scores, key=lambda x: x[0], reverse=True) 112 | top_vecs = [ 113 | _model.blocks[x[2]].mlp.W_out[x[1]] 114 | for x in sorted_scores[:num_mlp_vecs] 115 | ] 116 | top_vecs = [x / x.norm() for x in top_vecs] 117 | _top_vecs = torch.stack(top_vecs) 118 | 119 | svd = torch.linalg.svd(_top_vecs.transpose(0, 1)) 120 | return svd, sorted_scores 121 | -------------------------------------------------------------------------------- /toxicity/figures/logitlens.sync.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 7, 6 | "id": "d71afc97", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "\"\"\"\n", 11 | "Module Doc String\n", 12 | "\"\"\"\n", 13 | "\n", 14 | "import os\n", 15 | "import numpy as np\n", 16 | "import pandas as pd\n", 17 | "import torch\n", 18 | "from tqdm import tqdm\n", 19 | "import matplotlib.pyplot as plt\n", 20 | "from matplotlib import ticker\n", 21 | "import seaborn as sns\n", 22 | "from fancy_einsum import einsum\n", 23 | "from transformer_lens import HookedTransformer\n", 24 | "from toxicity.figures.fig_utils import convert, load_hooked\n", 25 | "from constants import ROOT_DIR, MODEL_DIR" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 8, 31 | "id": "e0137eae", 32 | "metadata": {}, 33 | "outputs": [ 34 | { 35 | "name": "stdout", 36 | "output_type": "stream", 37 | "text": [ 38 | "Loaded pretrained model gpt2-medium into HookedTransformer\n" 39 | ] 40 | } 41 | ], 42 | "source": [ 43 | "\n", 44 | "model = load_hooked(\"gpt2-medium\", os.path.join(MODEL_DIR, \"dpo.pt\"))\n", 45 | "model.tokenizer.padding_side = \"left\"\n", 46 | "model.tokenizer.pad_token_id = model.tokenizer.eos_token_id" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 9, 52 | "id": "f744b3b0", 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "\n", 57 | "prompts = list(\n", 58 | " np.load(os.path.join(ROOT_DIR, \"toxicity/figures/shit_prompts.npy\"))\n", 59 | ")\n", 60 | "tokens = model.to_tokens(prompts, prepend_bos=True)" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 10, 66 | "id": "ed241c56", 67 | "metadata": {}, 68 | "outputs": [ 69 | { 70 | "name": "stderr", 71 | "output_type": "stream", 72 | "text": [ 73 | "100%|█████████████████████████████████████| 73/73 
[00:17<00:00, 4.29it/s]\n" 74 | ] 75 | } 76 | ], 77 | "source": [ 78 | "\n", 79 | "batchsize = 4\n", 80 | "all_dpo_prob = None\n", 81 | "all_gpt2_prob = None\n", 82 | "for idx in tqdm(range(0, tokens.shape[0], batchsize)):\n", 83 | " batch = tokens[idx : idx + batchsize].cuda()\n", 84 | " with torch.inference_mode():\n", 85 | " _, cache = model.run_with_cache(batch)\n", 86 | "\n", 87 | " accum = cache.accumulated_resid(layer=-1, incl_mid=True, apply_ln=True)\n", 88 | "\n", 89 | " # Project each layer and each position onto vocab space\n", 90 | " vocab_proj = einsum(\n", 91 | " \"layer batch pos d_model, d_model d_vocab --> layer batch pos d_vocab\",\n", 92 | " accum,\n", 93 | " model.W_U,\n", 94 | " )\n", 95 | "\n", 96 | " shit_probs = vocab_proj.softmax(dim=-1)[:, :, -1, 7510].cpu()\n", 97 | " if all_dpo_prob is None:\n", 98 | " all_dpo_prob = shit_probs\n", 99 | " else:\n", 100 | " all_dpo_prob = torch.concat([all_dpo_prob, shit_probs], dim=1)" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 11, 106 | "id": "61a1f967", 107 | "metadata": {}, 108 | "outputs": [ 109 | { 110 | "name": "stdout", 111 | "output_type": "stream", 112 | "text": [ 113 | "Loaded pretrained model gpt2-medium into HookedTransformer\n" 114 | ] 115 | }, 116 | { 117 | "name": "stderr", 118 | "output_type": "stream", 119 | "text": [ 120 | "100%|█████████████████████████████████████| 73/73 [00:16<00:00, 4.37it/s]\n" 121 | ] 122 | } 123 | ], 124 | "source": [ 125 | "\n", 126 | "model = HookedTransformer.from_pretrained(\"gpt2-medium\")\n", 127 | "model.tokenizer.padding_side = \"left\"\n", 128 | "model.tokenizer.pad_token_id = model.tokenizer.eos_token_id\n", 129 | "\n", 130 | "for idx in tqdm(range(0, tokens.shape[0], batchsize)):\n", 131 | " batch = tokens[idx : idx + batchsize].cuda()\n", 132 | " with torch.inference_mode():\n", 133 | " _, cache = model.run_with_cache(batch)\n", 134 | "\n", 135 | " accum, accum_labels = cache.accumulated_resid(\n", 136 | " layer=-1, incl_mid=True, apply_ln=True, return_labels=True\n", 137 | " )\n", 138 | " vocab_proj = einsum(\n", 139 | " \"layer batch pos d_model, d_model d_vocab --> layer batch pos d_vocab\",\n", 140 | " accum,\n", 141 | " model.W_U,\n", 142 | " )\n", 143 | "\n", 144 | " shit_probs = vocab_proj.softmax(dim=-1)[:, :, -1, 7510].cpu()\n", 145 | " if all_gpt2_prob is None:\n", 146 | " all_gpt2_prob = shit_probs\n", 147 | " else:\n", 148 | " all_gpt2_prob = torch.concat([all_gpt2_prob, shit_probs], dim=1)" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 12, 154 | "id": "ae18867b", 155 | "metadata": { 156 | "scrolled": false 157 | }, 158 | "outputs": [ 159 | { 160 | "data": { 161 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAATgAAACcCAYAAADrhbcmAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/OQEPoAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAwD0lEQVR4nO3deXgUVbrA4V/1mqSzdvbFhD0sssgFWZVV9CIYBkVARhxkVFREdATEdUBAjaIoXBFnYBQdkTjCgAsMIIuEUZRNhQCBQAxkJ+nsnd6q7h8NLZEkdIckhPa8z8NjuvrUqa/azpc6VWeRFEVREARB8EKqqx2AIAhCUxEJThAEryUSnCAIXkskOEEQvJZIcIIgeC2R4ARB8FoiwQmC4LVEghMEwWuJBCcIgtcSCU4QBK8lEpwgCF5LJDhBELyWSHCCIHgtkeAEQfBaIsG1AImJiXTp0oXi4uJL3rv33ntJTEzk7NmzHtdbXFzs9r5Lly7loYce8vgYgtCSiQTXQgQEBPDll1/W2JaTk8PRo0evUkSCcO3z+gRnNps5cuQIZrP5aodSr5EjR7Jx48Ya2zZs2MCIESNqbDt79izTp0+nT58+DBo0iJdffpnq6moAZFnmzTffpG/fvvTr1481a9bU2DcvL4/HHnuMvn37MnToUN555x0cDkfTnpggXEVen+BOnTrF2LFjOXXq1NUOpV4jRozgxIkTZGZmurZt2LCBP/zhD67XVquVKVOmEBYWxs6dO0lJSeHQoUMsWrQIgLVr17Jx40bWrFnD1q1ba1z9ORwOpk2bRmRkJDt37mT16tVs2rSJjz76qNnOUWh81dXVlJaW1vh34Q9efeXcKVNXuWuJ1ye4a4XBYGDYsGGuq7gff/wRg8FA27ZtXWX2799PYWEhzzzzDL6+vkRGRjJr1iz+/e9/I8syX375JRMnTqR169b4+/sza9Ys176HDx/m9OnTzJkzBx8fH+Li4pg2bRopKSnNfq5C47FYLJw5c4asrCyysrI4c+YMFoul3nLulKmvXGOx2eUmq/sCTZMfQXDbHXfcwcKFC5kxYwb//ve/a1y9ARQVFREeHo5Op3Nti4uLw2KxUFRURGFhIVFRUa73YmNjXT9nZ2djtVrp16+fa5uiKEiS1IRnJDSHi5dVqW+JlQvvuVPmcuWulKm8mrP5FSQmhKDTqpvsOCLBtSADBw6ksrKSffv2sXXrVmbMmFHjSxYdHU1hYSFWq9WV5LKystBqtQQFBREREUFOTo6rfEFBgevnyMhIgoOD+fbbb13bSktLKS8vb4YzE4RfVZptnDpbis3uwFRuIdLo12THEk3UFkStVnP77bfz17/+lR49ehASElLj/W7duhEbG8uiRYswm83k5+fz+uuvc/vtt6PT6Rg7diwfffQR6enpmM1mFi9eXGPf0NBQ3nrrLSwWCyUlJcycOZOFCxc292kKv2MWm4OMsyVoNSoCDHryi6uQ5aa7UhQJroVJSkrixIkTjBkz5pL3tFot7777Lvn5+QwePJikpCS6du3Kiy++CMAf/vAHJk+ezJQpUxg0aFCN+3darZYVK1Zw/PhxBg0axK233orRaOSVV15prlMTfuccDpnT2SXYHDIWq4Nqq51qq53yKmuTHVPy9nVRjxw5wtixY1m3bh1dunS52uEIQqMqLS0lKyurxrb4+HiCgoLqLedOmbrKNYQsK/ySV0ZRqZmyCivL/vUjPjo1Myf0JCRAT/v4kMtX0gDiCk4QhCYlywq5RZUUmKoAWPX5EWx2mfIqG2mnizBVWDBb7E1ybJHgBEFoMnaHTGZuGWcLyjHotbz/RRqllVa0Gmfq2X0oGwkoLm2ajvgiwQmC0CSqrXZOnDFRVGom2F/P+l0n+SWvHD+9hhl398BHp6bAZCa7oIICkxm7o/H7xV2VBLd7926SkpK47bbbmDJlSo3uDL9ltVoZN24c77zzTjNGKAjClSivsnI0s5hqi4OQAB9SD+XwfVo+kgSTR3YiLiKAPl2cfTb3/JSD3SFTUt74nYqbPcEVFxfz1FNPkZyczObNmxkyZAhz586ts/z8+fM5c+ZMM0YoCEJDybJCfnEVxzJNaFQqAvx0HDlVxIbdGQDccVNbEhOMAPTvGoMkwbFfTJRWWMk3VTV652KPO/p+++23NXrDeyo1NZXExEQSExMBmDBhAsnJyRQWFhIeHl6jbEpKClarlcGDB1+23oKCAgoLCy/ZnpGR0eBYBUFwX1W1jay8MsoqbQT66XAoCut2nmT3oWwAeneKZNANztE1ZVVWVCqJzq1DOXKqiO/T8hh+YzyVZhv+frr6DuMRjxPc7NmzUavVjB49mqSkJNq1a+fR/nl5eURHR7te63Q6QkJCyM3NrZHgfvrpJ1JSUvjwww+ZN2/eZetdu3Yty5Yt8ygWQRCunENWKDBVcbagHK1KhTHQh4zsEj7Zcpxzpc7B+v2uj+YPg9shSRJV1TbUKomYCH/+p2MER04Vse9oPjf1iKWs0np1E9yuXbv49ttv+fzzzxk/fjwJCQmMGTOGUaNGYTQaL7t/XeMfVapfW8vFxcU8++yzLFu2DF9fX7fiGj9+PEOHDr1ke0ZGRo1B54IgNA6HrFBWaSH3XCUVZudVmywrrN91kt0Hs1GA4AA944d3oOP5ZqnF5sBml+nYyohWo6J1TCAxYQZyzlVy8HgBraIDGzVGjxOcSqViwIABDBgwgHnz5rFr1y6WL19OcnIyAwcOZOLEiQwaNKjO/WNiYvjuu+9cr61WKyaTiZiYGNe2r7/+msrKSmbMmAFAbm4uOp2O8vJy5syZU2u9ERERREREeHo6giB4yGZ3UFxmIa+oEovVjo9OgzHAhxNnTKzdlk7R+au2vtdHccdNbfHVO9OM3S5TabbRIT4Eg68WgMhQA707R7Hhmwz2HsnjjpvaNGqsDR5s/+OPP/LFF1+wefNmNBoN999/P7Gxsbz66qts3769zmblgAEDWLBgAenp6XTo0IFPP/2U7t2717j6GzduHOPGjXO9fvrpp4mPj+eRRx5paLi/Oxs2bGDNmjWYTCYAjEYj06ZNY9CgQSxdupSPPvqIqKgoJElClmW0Wi0zZ86kf//+jB07FgCbzUZGRgYdO3YEICwsjJUrV7JlyxaWL1+OLMuoVCpmzpxZ7x81wXvkFlWSU1iBoij46rUYAn0xV9v5ZOtx9h7JA5xXbXcPbU+n1qGu/WRZobTKQqvoIIID9K7tYUG+XN8mlK9/yKKs0sqhE4W0irnykRMXeJzgFi9ezKZNmzCZTIwYMYLk5GT69u3ranZef/31TJo0qc4EZzQaefPNN5kzZw4Wi4XQ0FCSk5MB5zjMBQsW0LVr1ys4pealKAoWa9POiqvXqT2a1mjp0qVs2rSJJUuW0KFDBwCOHTvG1KlTXfcphw8fXmOg/ddff8306dPZsWMHGzZsAJyzB48YMcL1GpzTLj3//POkpKSQkJDA0aNHmTRpEtu3byc4OLgRzlZoqfKLq8jKLSPIX49GrUJRFA6dKGT9zpOUVTrHkw7sHsPtA1rjo/s1td
jsMmVVFmLC/IkIqXnLyVevITLUj96dItlx4CwZZ0sbNWaPE1xaWhozZsxgxIgR+Pj4XPJ+TEwMb7zxRr119O/fn/Xr11+y/eJfpIu11AHhiqIwZ1kqRzMvXSymMXVqZeTV6QPdSnJFRUWsWLGCNWvWuJIbQMeOHZk/f36dU5T379+f6upqsrOz672XqlKpeOmll0hISACgffv2SJJEUVGRSHBezFRezS+5ZQQZ9KhVEkczi9n0bSZn8p3TbUWE+DJ+eCJtYn+9+lIUhYoqG7Ki0Co6iPBg31q/w+EhfvTrHkWAQUe/rtGXvH8lPE5wMTEx3HHHHZdsnzlzJkuWLCEkJKTWm/1C8zh48CAGg6HWq+Bhw4YB1JgTDpxfxDVr1hAeHk779u3rrT86OrrGU/C33nqL6667jjZtGvfeidByVFRZOXm2BIOvltO5pXz130wyc8sA0GlVDO55HcN7x7uGX8GvV23B/j7ERwW47sPVxt9XS2igH51aQZC/vs5yDeFWgsvNzWXz5s2A8yrrt1/m8vJyUlNTGzWwa4EkSbw6fWCLaqLW1lHynnvuobKykurqajp27Ei7du3Ytm0bhw8fBpwPehISElixYkWtV+W1sVqtLFiwgL179/KPf/xDzAzspWx2mVN5JejUatbvPOm6z6ZVqxjQPYZhva6r0a1DURTKzTYUWaF1dBBhwb6oVJf/bkQa/Sg8Pxi/MbmV4CIjIzl48CAmkwm73c727dtrvK/T6XjhhRcaPbhrgSRJ+NTz16m5devWjfLyco4dO+Z6OPDxxx8DsG7dOteaD7+9B+eJc+fOMX36dAwGAykpKY0ynY7Q8jhkmXOlZhRFzcbU03yflodKgv7dYhjeO/6Sqy2r3UF5lRVjgA/XRQZ49HsRaNARHOjeH1dPuBWBSqXi7bffBuCll17i+eefb/RAhMYRGRnJgw8+yKxZs1i8eLHrPlxxcTF79uxBrb6y+e8rKir44x//yMCBA3nmmWdq9F8UvEul2Y7Kx87mb8/yfVoekgR//N9O3NChZncsRVEoq7IiAW1jgzAGunfVdjFJkogN88dia9zWkNsp9sIVwdixYzly5EitZcSEki3DzJkzad++PfPnz6e01PlUSqVSMXToUJ599ln++c9/NrjulJQUTp8+jU6nq7EozrX29Fuon6JASYWFb45m8+3hAiQJJt3asUZysztkzBY7NrtMSKDeedWma3hr5uLuI43F7Rl9e/bsyYEDB1zNnksqkqQWuQq7mNFX8GZNNaOv2WJj/Z5CDmdVIwETRyTSu3MUNrtMlcWGLCto1CpCAvQEB/gQ5K9rkfdh3U63Bw4cAJxXcoIgeLcDJ8s5nOUckTD+lg707hyF1e6g0mwnJsyPQH89fj5a1B42RZub2wmurmbpBZIk0blz5ysOSBCEq0uWFQ6crARgeK8Y+nSJxu6Qqaiy0jYumNAg98aHtwRuJ7g777yz3vdbahNVEATP5BRVk2uyATCga4RzmFWFhfjowGsquYGHDxkEQfB+3x93duJNiNAT4q+juNJCdJiBqCZcoLmpePwUta6mqmiiCsK1r9rqIO2Ms8Pt9fG+mK12jEEBxEYEtMiHCJfjdoK75557OHDgQJ1NVdFEFYRrX1pWBVUWBV+dRIxRi16nJiEqsMU/TKiLeIoqCALg7LB74EQFAB3jfFEUhUA/58wh16oG9crLzMxk06ZNFBYWEhcXx+23305kZGRjxyY0QGJiIu3atUOj0SDLMg6Hg1tuuYXHHnsMjUbD2bNnGTZsWI3+jFarlWHDhvGXv/zF1QzZunUrK1euxGQyodPpiIyM5JFHHqFnz55X69SEJlZYWs3pAufKVokxPmi1KnTaKxv5crV5nOC2bdvGE088QZ8+fYiKimL37t0sW7aMFStW0Lt376aIsUVTFAXF1vjLnV1M0uo9uv+xcuVKoqKcS7KZTCamTZtGZWUlzz33HABqtbrG1FTl5eWMGTOG6OhoJk2axMcff8z777/P4sWLXaMTUlNTefjhh1m8eDEDBw5sxLMTWopvDuWiKBAVrMXfTyLEX+fxkKuWxuME9+qrr7J06dIaK1199dVXLFq0qNY53ryZoijkrH4Wy9njTXocfVxHYiYvaNBN3pCQEObMmcPkyZOZOXNmrWUCAgK4/vrrycjIwGq1snjxYt56660aQ68GDhzIww8/zKuvvioSnBeqttrZ85NzfeJO1/kgAb4+LWcSiYbyuHFdVlbGTTfdVGPbLbfcQmZmZmPFdI1p+X/hOnbsiM1m49SpU7W+n5GRwd69e+nfvz/p6elUVFTU2hTt168f6enplJWVNXXIQjNyOGR27j9LYUk1GrVEfLiGAIMWtRdMpOBxiv7f//1f3n//faZOneratnbtWoYMGdKogV0LJEkiZvKCFtdEvWT/8/teWKHM4XCQlJQEgCzL6PV6pk+fzvDhw13dgOx2+yX1WK3WBscgtEyKonC2sII9P+YA0D7GB61ahb9v4w98vxrcTnCjR48GnF/yTz75hI8//piYmBgKCgr45Zdf6NGjR1PF2KJJkoSka/x5rBrTzz//jK+vL/Hx8RQWFl5yD+5i7dq1IygoiO+//57hw4fXeO+HH36gffv2BAY27tJuwtVTaDLz88lzpJ0uAqBjjB4fnRof3bX9cOECtxPc/fff35RxCE0kPz+f119/nfvuuw+9/vJ/lfV6PbNmzWLhwoVERka67sN98803vPvuu7z22mtNHbLQTMoqrRw5XUTKtnSsdpm2sQGEBqobfdrwq8ntBHfx3F+1qWsxE6H5TZ06FY1G41oScNSoUTzwwANu7z9u3DjCwsJ45ZVXMJlMOBwO4uLieOedd+jVq1cTRi40F7tdJuNsMf/6+gSmcgthQT5Mvq0dOdln610/4Vrj8ZlkZmayfPly8vPzkWUZcK6fmZmZecliJkLzO368/ie6cXFxpKWlXbaeIUOG/C7vq/5elFRaWL/rJFn55fjqNTyQ1BW9RibIC7qGXMzjxyTPPvssxcXFhIeHoygKN9xwA1lZWUyaNKkp4hMEoZHJCnyemsXhjCJUKokpozoTGuSDAvj5aK92eI3K4wR35MgRlixZwgMPPIBWq+WJJ57grbfe4ptvvmmK+ARBaGSHTpbx9f5cAMYNbU/buGBKKy0E+evQaa79riEX8/hsAgMDMRgMJCQkkJ6eDkCvXr345ZdfGj04QRAa3+4jznU6BveMo0+XKErKq4k0Ggi4aPk/b+Fxgmvfvj1///vfUavVBAUFsX//fo4cOeLxak27d+8mKSmJ2267jSlTplBQUHBJmfT0dO69916SkpIYOXIkf//73z0NVxCEi1SYHeSXOCezHPo/11FaacUY5Etc5LU5HdLleJzgZs2axWeffUZeXh6PPvookydPZty4cfzpT39yu47i4mKeeuopkpOT2bx5M0OGDGHu3LmXlHvssceYOHEiGzZsYM2aNaSkpLBr1y5PQxYE4bz0bOdcb7Fhfsgo+PtqrunpkC7H46eoHTt2ZNOmTYDziVyvXr2oqKi4ZLX7+qSmp
pKYmEhiYiIAEyZMIDk5mcLCQsLDwwHnk9mpU6cyYsQIAIKCgkhISCA7O7vWOgsKCigsLLxke0ZGhkfnJwje7EKCax8XiFajpnVsMFovu+92sQZ1ePnuu+/44osvKCwsJDY2lrvuusuj/fPy8oiOjna91ul0hISEkJub60pwWq2Wu+++21Vm165dHDhwgHnz5tVa59q1a1m2bFkDzkYQfh8URSEj1wxAu7gA2sQEor/Gp0O6HI8T3Nq1a0lOTmbUqFH06NGD7Oxs/vjHP/Lyyy9z6623ulWHoii1tvfrWiU9JSWFxYsXs3TpUmJiYmotM378eIYOHXrJ9oyMDGbNmuVWXILgzfJMViqrZTQqaBMb4HVdQmrjcYJ75513WLVqFd27d3dtu+OOO3jxxRfdTnAxMTF89913rtdWqxWTyXRJ8rLb7cyfP5///ve/rF692tWkrU1ERAQRERF1vi8Iv3fp2c6rt2ijlgBf7+rQWxePG98Oh4NOnTrV2NajR49an4LWZcCAAaSlpbm6mXz66ad0794do9FYo9zs2bM5efIkn332Wb3JTRCEy3M9YAjVedVwrPp4fJZ33303r732GrNnz0ar1WK323n77bcvO1b1YkajkTfffJM5c+ZgsVgIDQ0lOTkZgKSkJBYscE7u+OWXXxIfH8/kyZNd+06YMIGJEyd6GrYg/K5ZbA6yCpwr1ceFadFqvffBwsXcTnA33HADkiShKApms5mUlBRCQ0MxmUyYzWZiYmJcU2K7o3///rXOAHzxND6XG1cpCIJ7jv9SikMGg4+K6BC9V0xm6Q63E9yKFSuaMg5BEJrQjyed873FGrUE+DXuwwVFkZGklpkw3U5wN954o+tnh8PBzz//TE5ODuHh4fTs2dPjkQyCIDSfn0+ZAGfztLFWylLsNmylhdhLClDpfVH5B6PWG1DpfJDULeMen8dR5Obm8tBDD3HmzBnCw8MpKCggMjKSVatWERsb2xQxCoJwBc6VmMk553zAEB+uv+IEp8gO7OVF2M7losh21D4GFIcN+7lsbChIkgqV3g+V3g9J54NKo0PSaJF0Ps1+pedxgnv55Ze54YYb+Ne//oVOp6O6uppFixaxYMECli9f3hQxCoJwBQ4ed/ZwCAvUEBbkwWy9ioxsrcZRCYosO1/LDhwl+cg2CyqtHktOBrLVgi6yFdqQKCSVCkV2oNht2MuLURwOJElBAdR6A9rQGFS+zTfu1eME98MPP7Bz5050OufMAz4+PsydO5ebb7650YMTBOHK7T/mTHDXhWk96B6iYCsvxu6nxlJqQUECRUGSQJFUmE//RMXhb5DN5a49JJ0PuogE9JGtUQcYUfv4o/L1R/LxR633Q7FVU52djto3EK0x2vleEyc6jxOcVqulrKzMNaQKnAsHX1ixSRCElqPSbONg+oUEp3N7aJa9shSpJBfZV8EuOUCWUWQH1sIsKtL2oFicTV61IQhNUATWgl9QrNVYzh6vdZ1gSavHv8tNGK6/GdlWTfXZ46gDQtCHxyNpmm5EhccJ7rbbbuPxxx9n1qxZxMTEcPbsWRYvXuz2KAZBEJqeoigUlZr56r+ZVFXb8dGpSYjwcWv0gmypQnMiFb+sfVQegspayqgDwwjoNgS/tjcgqTUosgNbcS7W/EysBb/gMJcjV1cgmyuQLVUoNgvlh7ZRefw7Am8YgW/7XjgqS7HYT6KLaoNK2zQL3Xic4J588kleeOEF7r33XhwOBzqdjqSkJJ588smmiE8QBA/ZHTL7jxXw6dfppJ0uBqBTQhCBhstfKSl2K1LGXgxZ+wCQDCGoNRqQVEgqNSq9H36JffBt1RXpfF86xW4DtRpdWBy6sDjoMrBmnbKD6l+OULpvE47yIkr+u46KI6kE3TgKKTQWS84J9NFtUekavxXocYLbtWsXL730EgsWLKC0tJSwsDCvnChPEK5FFpvMp9tPseWHHGx2GUmCAd1iGNYzApWtpN59FdmBkvUz/hm7nXXFdiV86L0E1JIXFUVBtlQhW6uRtHqw2FAUkNQaVFp9jWanpFLj27obPvGdqTy+l/JD27CXFlC0dRX+3YZg6DQAS/YJ9DHtUOn9GvPj8DzBPf/88+zZswetVlvjPpwgCFeX3aHw3qYccoutALSKDuSuIe0xBvqgV1uxVF7aRUNRZBSbFcVuxXrmGIajm5EUGVtYG8xt+oPsQLY7nN07JBVIIFvMKHYrat9AfCISUPkGoNityBYzclUZjsoSHJZKOP9gAklCQgFUGDr2w69dT8oObKEybQ8VP+3AVpRNcL8xVJ89gT62HWofQ6N9Jh4nuD59+pCSkkJSUhL+/v6NFoggCFdm77FScout+OjU/GFQO3p0CKey2oZDUQg06Cm86Gaao7qS6vI85/0xRcZRWUbl16tROWzYA6OoTByKYnWOXUWWkR1WUGRQFFR+AWgjW9d4Cipp9c77aP7BKEocisPuTG7n/ymKjKPsHLbSQiSVhqA+o9GFx1OS+i8s2emc2/w3ggfciSY49OomuGPHjrF161YWLFiAj49PjebpgQMHGi0wQRDcZ7Y42P5jCQCjB1xHYkIw1VY7cZEBhAf7Ulnxa3cOxWHDcvoQclkejopi7CWFWIuyUaxmHH4hVHa+DdlqRR0QjCY4HN/gEOd+ssOZsFTqem9LSZIKSXPpAjZqHwOaoHCspjwc5cXoYzsQPno6RV9/gKO8mKJt7xM2chq60LhG+1w8TnALFy5stIMLgtA4dvxkwmyVCfFX071dCMYQAxFGv1q7hegOrMOcexTzb7ZLhhAqutyOwyGj9g9E4x9SY+SBpLryIV4qvR8+UW1wBEVgK8pB0VoJv/0RTLtTsGSnU52VRmC3xltw3KMEV15ejsFgoF27duj1TfNYVxAEz5gqbHx7tAyA3u0NGAN9uS4yoNayUkEGutyjIEno4xLRBkWgCY5EGxyJJSia4hPpqP0C0ASEAk338FDt648qth320kJs584SPPBubMXZ+La6vlGP43aC279/Pw8++CCVlZWEh4fz7rvv0qVLl0YNRhAEz/1nXzEOGWKMWtpG6fH3rePXWpHRH3YuGKVr35fQ/kkgO1AUZ0fearsDyS8ATYCRpkxuF0iSCm1wJGrfAKyFZ1ACjI0+SN/tka9LlixhxowZHDx4kLvuuos33nijUQMRBMFzp3PK+SnT+fSgdzs/woJ961zbRJ25H3XFORS1Dm3nm5GrK5z93lQaVD7+qAyBaAOMzqelzUil90Mf0w5dVBskrU+j1u12ujx69CgffvghAH/+85+55ZZbGjUQQRA8oygK/9xyEoC20TpaR/vWOdZUtprRHdsOgDnuBoIDjPjGta5xX626tBSkkiaPuzaSSo0uJKrR63U7VSuK4vrZYDBgt9sbPRhBENy35btfOJZViloFfdr5ERJQ99VP1fefo7KZcfgEYYnogMoQ1CgPDVo6t6/gLk5wgiBcPUczi3j/8zTSMp3DsLrE+3JdlH+dCzjbTHlYftoKQFX8/6AKDEVVSzcOb+RRgktLS3MlOofDUeM1IB46CEITOlNQ
zj8+P8IPafkAqFQSfTuH0fU6iQC/8wlLkXFUmLDaKpyjFBQZ07YPQHZgC4rFbmyNzjfwKp5F83I7wZnNZsaOHVtj28WvJUni6NGjjReZIAgAVFvtfLTpGF+knsIhOy8ourcP46YescSHa6ksKUQlgercaXQH/02JuZSS31YiSVTF/w+aoFD4nSw4Ax4kuGPHjjVlHIIg1OL7I3ks/+xHzpU6h021vy6YQT1jaR0TTEyYAbVSzdkK0B7Zgjbj2/NjPi+l6zgQJTS+0Qezt3QtY2UIQRBqyC+u4u8bfua7w3kABPnruLVPAj07RRITZiDQoEetkijO/AWfb/6GusxZzhLRHuOAsRhUChLn751LEmZdIJqy6qt4RleHSHCC0AJYbA7SThWx72g+B48XcKagAgBJgr7XR3FL7+to5VOOT9VR7D/lU1xaiK20gOqsNNR2G7JGT2WrPtgjEokKuw5ff4NrHQUUGZvVgVRZcJXPsvmJBCcIV0lVtY0f0vL55tBZDh4vxGaXa7zfJtKPsZ3stLZ9j7z1PcorSyivpR5bUAwVrfujMsai8wtCpfdD7VfzQYKqtLQJz6TluioJbvfu3bz++utYLBaio6N59dVXiYiIqFHGZDIxd+5csrKycDgczJo1i+HDh1+NcAXhilVb7BSXV2Mqs5BfXMXew7n8cDQfg1xOK3Uhw7UlBPrZCfdTCPGRCdLa0ZZlw0/VXOhxKqm1aIIjUPsFovILRG0Iwh4YRYndF22A0TnxpFBDsye44uJinnrqKVavXk1iYiKrV69m7ty5rFy5ska5efPm0aFDB959913Onj3L+PHj6dKlC9HR0c0dsiDUyeGQKau0UlpppbTCgqm0ClPhOcqKi6gymbCUm7BWVSE77KiQUUsKWux00Ji41b+QYFVVzQqt5/+dp9L7oY/tgE9cIj7XdULtH4ykdS6sLKk1lFVZ0Obk0BxjR69FzZ7gUlNTSUxMJDExEYAJEyaQnJxMYWGha4Zgu93Ojh072LTJOTA4Li6OgQMH8vnnn/Pggw82WWynDh/mXObJyxe8Vjo9/ybOWqN251zcPN/fdga/+Iner29dvq6a9Siu/yiuH5SLCzsnVDz/84Uy0oXy5ydbdP73wgSMsvM9Wa5Z5vzKUSjy+UHozu2KfH5fhw3JXo3aYUEtW9DIVjSKHa1kR4sDveQgXnIQ/9sTqm94pSShDgpHZ4xB5WNApfNF0vuh0vuiCYrA57pE5/J7Oh8k9aVzh6usMiK51a3ZE1xeXl6NqzCdTkdISAi5ubmuBGcymaiuriYq6texaVFRUeTm5tZZb0FBAYWFhZdsv9A3LyMjo964qquqyP/4r2ikayR5CVeFAtjP/7PUU8Ym6ZDVPqDzQaPTodGo0ajVSCoVkkqF2j8EfVRr9BGtnDPjqrWuRVxqyCkGiuuMp6Kigvz8fNcfBUmSqKqqumS27YvLuVOmvrpakjZt2tS7ZGmzJ7gLH/BvXTwDwsUfcF1lfmvt2rUsW7aszvdnzZrlaaiCILRw69atq3cEVbMnuJiYGL777jvXa6vVislkIiYmxrUtNDQUvV5PQUEBkZGRAOTn59OuXbs66x0/fjxDhw69ZHtZWRkZGRl07tz5spN0njt3jnXr1jF27FjCwsJc8b3//vv86U9/QqfTuVWmtrrcKdPcx2uJMYnjieNdXO5y2rRpU+/7zZ7gBgwYwIIFC0hPT6dDhw58+umndO/eHaPR6CqjVqsZNmwYH3/8MU888QTZ2dns3r2badOm1VlvRETEJU9iL+jXr59bseXm5mI0GunQoYOrGW2xWDAaja4E6U6Z2upyp0xzH68lxiSOJ453cbkr1eyD0oxGI2+++SZz5sxh5MiRbN68meTkZACSkpL4+eefAefyhBkZGYwaNYqpU6fy9NNP06pVq+YOVxCEa9hV6QfXv39/1q9ff8n2DRs2uH42Go313lNrLmq1mkGDBqFW1z13VmOVae7jtcSYxPHE8RqTGMlwGRqNhsGDBzdLmeY+XkuMSRxPHK8x/X7mTREE4XdHJDhBELyWSHCCIHgtkeAEQfBaIsEJguC1RIITBMFriQQnCILXEglOEASvJRLcRfz9/Rk0aFC908O4U6Yx62qJx2uJMYnj/T6O5ylJEUvWC4LgpcQVnCAIXkskOEEQvJZIcIIgeC0xm8h57ixl6I41a9bw8ccfI0kSvr6+PPvss3Tr1u2KYvvxxx+ZNGkS27Ztq7FOhSdOnDjB/PnzKS8vR6VS8cILL9CjR48G1bVt2zbeeustVCoV/v7+zJ8/n7Zt23pUx5IlSygsLGThwoWAc6qs9957D7vdTqdOnViwYIFbN5t/W8+yZcvYtGkTKpWK0NBQXnzxRVq3bt2gmC7YunUrTzzxBIcPH27w+e3bt4/k5GSqq6sxGAwsWrTIrbh+W88nn3zC6tWrUavVREVFsXDhwst+T+v6Tq5cuZJPP/0Uh8PBwIEDeeaZZ9BqL13Y5nJ1derUiZdffpm9e/ciSRIJCQnMmzfP7Vl5ExMTadeuHRrNr+koOjqad999163966UISlFRkXLjjTcqx44dUxRFUT744APl/vvv97ie/fv3K4MHD1aKiooURVGU7du3KwMGDFBkWW5wbOfOnVOSkpKUDh06KLm5uQ2qw2w2KwMHDlT+85//KIqiKDt27FAGDx7coLjMZrPStWtX5cSJE4qiKMrq1auVSZMmub3/mTNnlEceeUTp1q2b8swzzyiKoijp6elKv379lLy8PEVRFOXll19Wnn/+eY/r2bhxozJ27FilsrJSURRF+fDDD5W77rqrQTFdcPLkSWXo0KFKp06dGnx+eXl5Su/evZVDhw4piqIoH330kXLPPfd4XE9WVpbSs2dPpbCwUFEU5+c0e/bseuup6zu5c+dO5bbbblPKysoUu92uzJgxQ1mxYkWD6nrnnXeUhx56SLFarYqiKMorr7yiPP744/XWdbEr+W5fjmiiUvtShnv37q11la76BAUF8dJLL7mmX+/WrRtFRUWYzeYGxWW323nyySeveMGc1NRUwsPDGTFiBACDBg1i+fLllyzz5w6Hw4EkSZSeXym9qqoKH5/61sWrae3atfTv358pU6a4tm3bto1Bgwa51t+YNGkSn3/+ObIs11VNrfUkJCTw3HPP4efnBzg//+zs7AbFBM5VpmbNmsWzzz57Ree3efNm+vbtS/fu3QEYN24c8+bN87geWZZxOBxUVVWhKIpbn31d38mtW7dy++23ExAQgFqtZuLEibVOQutOXZ06deLJJ590Xf117drVrc+9OYgmKu4tZeiOtm3buppqsiyzaNEiBg8e7PqF81RycjJ9+vRhwIABDdr/gtOnTxMREcFzzz1HWloa/v7+PPXUU/WuUlYXg8HAvHnzuO+++zAajVgsFlavXu32/n/5y18AWLp0qWtbbm5ujc8/KiqKqqoqSkpKaqzVcbl6Lr4VYLFYeO211xg5cmSDYgKYO3cu9913Hx06dLhsHfXVdfr0aQwGA08++aTr/8XTTz/tcT0JCQlMmzaNkSNHEhQ
UhF6vZ82aNfXWU9d3Mjc3lxtuuMFV7nLLctZX18UTV5aUlPB///d/jBs3rt66fmvq1Kk1mqjJycmuC44rIa7gcG8pQ09UVFQwffp0srOzefXVVxtUxxdffEFWVhYPP/xwg/a/mN1uZ8+ePYwZM4Z169YxdepUHnzwQSoqKjyu6/jx47z99tts3LiRb775hueee44HHniAqqqqy+9cj9o+/9q2uaOgoID77ruPgIAAZs+e3aA6VqxYQVhYGElJSQ3a/2IXFjJ/9NFHWb9+PUOGDKl3AaW6pKam8sUXX/D111+TmprKhAkTmDZtmltX4rV9J3/7+br7edf1/T516hT33HMPvXv35r777vPgzGDlypVs2LDB9a8xkhuIBAc4lzLMz893va5tKUN3nT59mrvuugt/f38++OADAgMDGxTTZ599RlZWFmPGjHH9kk2dOpV9+/Z5XFdkZCStWrWiV69egLOJqtFoOHXqlMd1paam0rVrV9dybaNHj8bhcFx2Ye36/Pbzz8/Px2AwEBQU5HFdP/74I3feeSe9evVi2bJlriXpPLV+/Xr27dtHUlISDz74IA6Hg6SkJLKysjyuKzIykh49eriufsaOHUtmZibFxXUv6Fyb7du3c/PNNxMZGYkkSUyePJm0tDRMJlO9+9X2naztM3fn+17X93vHjh1MnDiRCRMm8Ne//rXBf5wam0hwOJcyTEtLIz09HaDWpQzdkZOTw6RJkxg3bhzJycmXXYe1Pv/4xz/46quvXH/RwPlX7kKS8sTNN99Mbm4uhw4dAmD//v1YrdbLrilZmy5durB//37y8vIA+OGHH7Db7W4/qazNsGHD2LVrl+sX7p///CfDhw/3+Ar6yJEj3H///Tz99NMNboJfsHnzZj7//HPX0121Ws2GDRuIj4/3uK5bbrmFgwcPkpmZCcCWLVuIj48nODjYo3q6dOnC7t27KS8vB+A///kPCQkJ9X5P6/pO3nLLLXz55ZeUlZUhyzKffPKJ6x6tp3Xt3LmT2bNn8/bbbzN58mSPzqmpiXtw1FzK0GKxEBoa6lrK0BMrV66krKyMjRs3snHjRtf29957z3UD/WoICwtjxYoVLFq0iKqqKtRqNUuXLm3QmL++ffvy6KOPMmXKFLRaLX5+fixfvvyKxg+2b9+e2bNn8+c//xmbzUbr1q155ZVXPK5n6dKlyLLMe++9x3vvvefafvFqbVdDx44dWbhwIY8//jh2ux1/f3+WLl3qcQIeO3Ysubm53HXXXej1eoxGI8uXL693n/q+k3feeScTJ07EbrfTs2fPy94Oqasus9mMJEksWrTItS0iIoK//e1vHp1fUxBjUQVB8FqiiSoIgtcSCU4QBK8lEpwgCF5LJDhBELyWSHCCIHgtkeAEQfBaIsEJguC1RIITWpTExER+/vnnqx2G4CVEghMEwWuJBCdcM3Jycnj00UcZPHgw3bp1Y8yYMRw4cABwTkTw+uuvu8pemKV2z549gHN88W233UavXr249957OXnypKtsYmIiL730EjfeeGON4UbCtU8kOOGa8fzzzxMZGcmWLVv44Ycf6Ny5syup3XHHHXz11Veust999x0qlYp+/fqxZcsWlixZwuLFi/n2228ZMWIE999/f42JSE0mE6mpqTz22GPNfl5C0xEJTrhmLFy4kKeeegqA7OxsAgMDXTOQjBgxApPJxMGDBwHYuHEjo0ePRqVSkZKSwqRJk+jSpQtarZZ7770XPz8/du7c6ap75MiR6HQ6AgICmv28hKYjZhMRrhmnT5/mtddeIycnh7Zt22IwGFyTPfr6+nLrrbfy5Zdf0qlTJ7Zu3conn3wCOJu2+/btY9WqVa667HY7OTk5rtcNWWBIaPlEghOuCTabjenTp/PCCy+4JgBdu3ZtjXtpSUlJzJo1i969e5OQkOCaajwyMpJJkyYxadIkV9nMzMwa09G3lAkahcYlmqhCi1NcXExeXp7rX0FBAVarlerqatciK8ePH2fVqlVYrVbXfn369EGj0bB8+fIaU43feeedrFq1ivT0dBRFYdu2bYwaNYrTp083+7kJzUtcwQktzoMPPljjdXBwMHv37mXevHksWrSIuXPnEhsby7hx43jjjTcoLi7GaDSiUqkYPXo0q1atYtSoUa79R40aRXl5OTNmzHBNzZ2cnMz111/f3KcmNDMx4aXgVT755BN27NjBihUrrnYoQgsgmqiCVyguLubIkSN88MEHTJgw4WqHI7QQIsEJXuHgwYPcc8899OrViyFDhlztcIQWQjRRBUHwWuIKThAEryUSnCAIXkskOEEQvJZIcIIgeC2R4ARB8FoiwQmC4LVEghMEwWuJBCcIgtf6f36ruv6XBwyjAAAAAElFTkSuQmCC", 162 | "text/plain": [ 163 | "
" 164 | ] 165 | }, 166 | "metadata": {}, 167 | "output_type": "display_data" 168 | } 169 | ], 170 | "source": [ 171 | "\n", 172 | "data = []\n", 173 | "for layer_idx in range(all_gpt2_prob.shape[0]):\n", 174 | " for prob in all_gpt2_prob[layer_idx]:\n", 175 | " data.append(\n", 176 | " {\n", 177 | " \"Layer\": layer_idx,\n", 178 | " \"Model\": \"GPT2\",\n", 179 | " \"Probability\": prob.item(),\n", 180 | " }\n", 181 | " )\n", 182 | "\n", 183 | "for layer_idx in range(all_dpo_prob.shape[0]):\n", 184 | " for prob in all_dpo_prob[layer_idx]:\n", 185 | " data.append(\n", 186 | " {\n", 187 | " \"Layer\": layer_idx,\n", 188 | " \"Model\": \"DPO\",\n", 189 | " \"Probability\": prob.item(),\n", 190 | " }\n", 191 | " )\n", 192 | "\n", 193 | "data = pd.DataFrame(data)\n", 194 | "\n", 195 | "sns.set_theme(context=\"paper\", style=\"ticks\", rc={\"lines.linewidth\": 1.5})\n", 196 | "fig = sns.relplot(\n", 197 | " data=data,\n", 198 | " x=\"Layer\",\n", 199 | " y=\"Probability\",\n", 200 | " hue=\"Model\",\n", 201 | " hue_order=[\"GPT2\", \"DPO\"],\n", 202 | " kind=\"line\",\n", 203 | " height=1.5,\n", 204 | " aspect=3.25 / 1.5,\n", 205 | ")\n", 206 | "\n", 207 | "\n", 208 | "major_tick_locs, major_labels = plt.xticks()\n", 209 | "minor_tick_locs, minor_labels = plt.xticks(minor=True)\n", 210 | "\n", 211 | "fig.ax.xaxis.set_major_locator(ticker.MultipleLocator(1))\n", 212 | "fig.ax.xaxis.set_major_formatter(ticker.ScalarFormatter())\n", 213 | "major_tick_locs, major_labels = plt.xticks()\n", 214 | "new_tick_locs = [\n", 215 | " x for x in major_tick_locs if x >= 0 and x <= 48 and x % 2 == 0\n", 216 | "]\n", 217 | "new_minor_tick_locs = [\n", 218 | " x for x in major_tick_locs if x >= 0 and x <= 48 and x % 2 != 0\n", 219 | "]\n", 220 | "\n", 221 | "\n", 222 | "major_labels = [x if x % 2 == 0 else \"\" for x in range(24)] + [\"F\"]\n", 223 | "\n", 224 | "plt.xticks(ticks=new_tick_locs, labels=major_labels)\n", 225 | "plt.xticks(\n", 226 | " ticks=new_minor_tick_locs,\n", 227 | " labels=[\"\" for _ in new_minor_tick_locs],\n", 228 | " minor=True,\n", 229 | ")\n", 230 | "\n", 231 | "fig.ax.tick_params(axis=\"x\", which=\"major\", length=10)\n", 232 | "fig.ax.tick_params(axis=\"x\", which=\"both\", color=\"grey\")\n", 233 | "fig.ax.set_ylim(ymin=0)\n", 234 | "fig.ax.set_ylim(ymax=0.48)\n", 235 | "fig.ax.fill_betweenx([0, 0.48], 37, 38, alpha=0.35, facecolor=\"grey\")\n", 236 | "fig.ax.fill_betweenx([0, 0.48], 39, 40, alpha=0.35, facecolor=\"grey\")\n", 237 | "fig.ax.fill_betweenx([0, 0.48], 41, 42, alpha=0.35, facecolor=\"grey\")\n", 238 | "sns.move_legend(fig, \"upper left\", bbox_to_anchor=(0.22, 1))\n", 239 | "\n", 240 | "plt.savefig(\"logitlens.pdf\", bbox_inches=\"tight\", dpi=1200)" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": null, 246 | "id": "77b8e31f", 247 | "metadata": { 248 | "scrolled": false 249 | }, 250 | "outputs": [], 251 | "source": [] 252 | } 253 | ], 254 | "metadata": { 255 | "kernelspec": { 256 | "display_name": "Python 3", 257 | "language": "python", 258 | "name": "python3" 259 | }, 260 | "language_info": { 261 | "codemirror_mode": { 262 | "name": "ipython", 263 | "version": 3 264 | }, 265 | "file_extension": ".py", 266 | "mimetype": "text/x-python", 267 | "name": "python", 268 | "nbconvert_exporter": "python", 269 | "pygments_lexer": "ipython3", 270 | "version": "3.8.8" 271 | } 272 | }, 273 | "nbformat": 4, 274 | "nbformat_minor": 5 275 | } 276 | -------------------------------------------------------------------------------- 
/toxicity/figures/logitlens.sync.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # text_representation: 5 | # extension: .py 6 | # format_name: percent 7 | # format_version: '1.3' 8 | # jupytext_version: 1.3.4 9 | # kernelspec: 10 | # display_name: Python 3 11 | # language: python 12 | # name: python3 13 | # --- 14 | 15 | # %% 16 | """ 17 | Module Doc String 18 | """ 19 | 20 | import os 21 | import numpy as np 22 | import pandas as pd 23 | import torch 24 | from tqdm import tqdm 25 | import matplotlib.pyplot as plt 26 | from matplotlib import ticker 27 | import seaborn as sns 28 | from fancy_einsum import einsum 29 | from transformer_lens import HookedTransformer 30 | from toxicity.figures.fig_utils import convert, load_hooked 31 | from constants import ROOT_DIR, MODEL_DIR 32 | 33 | 34 | # %% 35 | 36 | model = load_hooked("gpt2-medium", os.path.join(MODEL_DIR, "dpo.pt")) 37 | model.tokenizer.padding_side = "left" 38 | model.tokenizer.pad_token_id = model.tokenizer.eos_token_id 39 | 40 | # %% 41 | 42 | prompts = list( 43 | np.load(os.path.join(ROOT_DIR, "toxicity/figures/shit_prompts.npy")) 44 | ) 45 | tokens = model.to_tokens(prompts, prepend_bos=True) 46 | 47 | # %% 48 | 49 | batchsize = 4 50 | all_dpo_prob = None 51 | all_gpt2_prob = None 52 | for idx in tqdm(range(0, tokens.shape[0], batchsize)): 53 | batch = tokens[idx : idx + batchsize].cuda() 54 | with torch.inference_mode(): 55 | _, cache = model.run_with_cache(batch) 56 | 57 | accum = cache.accumulated_resid(layer=-1, incl_mid=True, apply_ln=True) 58 | 59 | # Project each layer and each position onto vocab space 60 | vocab_proj = einsum( 61 | "layer batch pos d_model, d_model d_vocab --> layer batch pos d_vocab", 62 | accum, 63 | model.W_U, 64 | ) 65 | 66 | shit_probs = vocab_proj.softmax(dim=-1)[:, :, -1, 7510].cpu() 67 | if all_dpo_prob is None: 68 | all_dpo_prob = shit_probs 69 | else: 70 | all_dpo_prob = torch.concat([all_dpo_prob, shit_probs], dim=1) 71 | 72 | # %% 73 | 74 | model = HookedTransformer.from_pretrained("gpt2-medium") 75 | model.tokenizer.padding_side = "left" 76 | model.tokenizer.pad_token_id = model.tokenizer.eos_token_id 77 | 78 | for idx in tqdm(range(0, tokens.shape[0], batchsize)): 79 | batch = tokens[idx : idx + batchsize].cuda() 80 | with torch.inference_mode(): 81 | _, cache = model.run_with_cache(batch) 82 | 83 | accum, accum_labels = cache.accumulated_resid( 84 | layer=-1, incl_mid=True, apply_ln=True, return_labels=True 85 | ) 86 | vocab_proj = einsum( 87 | "layer batch pos d_model, d_model d_vocab --> layer batch pos d_vocab", 88 | accum, 89 | model.W_U, 90 | ) 91 | 92 | shit_probs = vocab_proj.softmax(dim=-1)[:, :, -1, 7510].cpu() 93 | if all_gpt2_prob is None: 94 | all_gpt2_prob = shit_probs 95 | else: 96 | all_gpt2_prob = torch.concat([all_gpt2_prob, shit_probs], dim=1) 97 | 98 | 99 | # %% 100 | 101 | data = [] 102 | for layer_idx in range(all_gpt2_prob.shape[0]): 103 | for prob in all_gpt2_prob[layer_idx]: 104 | data.append( 105 | { 106 | "Layer": layer_idx, 107 | "Model": "GPT2", 108 | "Probability": prob.item(), 109 | } 110 | ) 111 | 112 | for layer_idx in range(all_dpo_prob.shape[0]): 113 | for prob in all_dpo_prob[layer_idx]: 114 | data.append( 115 | { 116 | "Layer": layer_idx, 117 | "Model": "DPO", 118 | "Probability": prob.item(), 119 | } 120 | ) 121 | 122 | data = pd.DataFrame(data) 123 | 124 | sns.set_theme(context="paper", style="ticks", rc={"lines.linewidth": 1.5}) 125 | fig = sns.relplot( 126 | data=data, 
127 | x="Layer", 128 | y="Probability", 129 | hue="Model", 130 | hue_order=["GPT2", "DPO"], 131 | kind="line", 132 | height=1.5, 133 | aspect=3.25 / 1.5, 134 | ) 135 | 136 | 137 | major_tick_locs, major_labels = plt.xticks() 138 | minor_tick_locs, minor_labels = plt.xticks(minor=True) 139 | 140 | fig.ax.xaxis.set_major_locator(ticker.MultipleLocator(1)) 141 | fig.ax.xaxis.set_major_formatter(ticker.ScalarFormatter()) 142 | major_tick_locs, major_labels = plt.xticks() 143 | new_tick_locs = [ 144 | x for x in major_tick_locs if x >= 0 and x <= 48 and x % 2 == 0 145 | ] 146 | new_minor_tick_locs = [ 147 | x for x in major_tick_locs if x >= 0 and x <= 48 and x % 2 != 0 148 | ] 149 | 150 | 151 | major_labels = [x if x % 2 == 0 else "" for x in range(24)] + ["F"] 152 | 153 | plt.xticks(ticks=new_tick_locs, labels=major_labels) 154 | plt.xticks( 155 | ticks=new_minor_tick_locs, 156 | labels=["" for _ in new_minor_tick_locs], 157 | minor=True, 158 | ) 159 | 160 | fig.ax.tick_params(axis="x", which="major", length=10) 161 | fig.ax.tick_params(axis="x", which="both", color="grey") 162 | fig.ax.set_ylim(ymin=0) 163 | fig.ax.set_ylim(ymax=0.48) 164 | fig.ax.fill_betweenx([0, 0.48], 37, 38, alpha=0.35, facecolor="grey") 165 | fig.ax.fill_betweenx([0, 0.48], 39, 40, alpha=0.35, facecolor="grey") 166 | fig.ax.fill_betweenx([0, 0.48], 41, 42, alpha=0.35, facecolor="grey") 167 | sns.move_legend(fig, "upper left", bbox_to_anchor=(0.22, 1)) 168 | 169 | plt.savefig("logitlens.pdf", bbox_inches="tight", dpi=1200) 170 | 171 | # %% 172 | -------------------------------------------------------------------------------- /toxicity/figures/pca.sync.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # text_representation: 5 | # extension: .py 6 | # format_name: percent 7 | # format_version: '1.3' 8 | # jupytext_version: 1.3.4 9 | # kernelspec: 10 | # display_name: Python 3 11 | # language: python 12 | # name: python3 13 | # --- 14 | 15 | # %% 16 | import os 17 | import json 18 | 19 | import pandas as pd 20 | import einops 21 | 22 | import torch 23 | import torch.nn.functional as F 24 | from fancy_einsum import einsum 25 | from tqdm import tqdm 26 | 27 | import seaborn as sns 28 | import matplotlib.pyplot as plt 29 | 30 | from matplotlib.gridspec import GridSpec 31 | from transformer_lens import ( 32 | HookedTransformer, 33 | ) 34 | from toxicity.figures.fig_utils import convert, load_hooked, get_svd 35 | from constants import MODEL_DIR, DATA_DIR 36 | 37 | torch.set_grad_enabled(False) 38 | 39 | # %% 40 | 41 | model = load_hooked( 42 | "gpt2-medium", 43 | os.path.join(MODEL_DIR, "dpo.pt"), 44 | ) 45 | gpt2 = HookedTransformer.from_pretrained("gpt2-medium") 46 | gpt2.tokenizer.padding_side = "left" 47 | gpt2.tokenizer.pad_token_id = gpt2.tokenizer.eos_token_id 48 | 49 | toxic_vector = torch.load(os.path.join(MODEL_DIR, "probe.pt")) 50 | 51 | # %% 52 | 53 | with open( 54 | os.path.join(DATA_DIR, "intervene_data/challenge_prompts.jsonl"), "r" 55 | ) as file_p: 56 | data = file_p.readlines() 57 | 58 | prompts = [json.loads(x.strip())["prompt"] for x in data] 59 | tokenized_prompts = model.to_tokens(prompts, prepend_bos=True).cuda() 60 | 61 | # %% 62 | 63 | 64 | _, scores_gpt2 = get_svd(gpt2, toxic_vector, 128) 65 | vectors_of_interest = [ 66 | (_score_obj[2], _score_obj[1], _score_obj[0]) 67 | for _score_obj in scores_gpt2[:64] 68 | ] 69 | 70 | 71 | # %% 72 | 73 | gpt2_resid = [] 74 | dpo_resid = [] 75 | sample_size = 50 76 | batch_size = 4 77 | 
print("Grabbing mlp mids...") 78 | _vec = vectors_of_interest[0] 79 | _layer = _vec[0] 80 | _idx = _vec[1] 81 | 82 | for idx in tqdm(range(0, sample_size, batch_size)): 83 | batch = tokenized_prompts[idx : idx + batch_size, :] 84 | dpo_batch = batch.clone() 85 | 86 | with torch.inference_mode(): 87 | _, cache = gpt2.run_with_cache(batch) 88 | resid = cache[f"blocks.{_layer}.hook_resid_mid"][:, -1, :] 89 | 90 | gpt2_resid.extend(resid.cpu().tolist()) 91 | 92 | with torch.inference_mode(): 93 | _, cache = model.run_with_cache(dpo_batch) 94 | resid = cache[f"blocks.{_layer}.hook_resid_mid"][:, -1, :] 95 | dpo_resid.extend(resid.cpu().tolist()) 96 | 97 | 98 | w_ins = [ 99 | gpt2.blocks[_layer].mlp.W_in[:, _idx].cpu(), 100 | model.blocks[_layer].mlp.W_in[:, _idx].cpu(), 101 | ] 102 | 103 | gpt2_stacked = torch.stack([torch.Tensor(x) for x in gpt2_resid], dim=0) 104 | dpo_stacked = torch.stack([torch.Tensor(x) for x in dpo_resid], dim=0) 105 | gpt2_dots = einsum("sample d_model, d_model", gpt2_stacked, w_ins[0]) 106 | dpo_dots = einsum("sample d_model, d_model", dpo_stacked, w_ins[1]) 107 | gpt_acts = model.blocks[0].mlp.act_fn(gpt2_dots) 108 | dpo_acts = model.blocks[0].mlp.act_fn(dpo_dots) 109 | 110 | 111 | # %% 112 | 113 | all_data = torch.concat([gpt2_stacked, dpo_stacked], dim=0) 114 | mean = all_data.mean(dim=0) 115 | stddev = all_data.std(dim=0) 116 | normalized = (all_data - mean) / stddev 117 | 118 | U, S, V = torch.pca_lowrank(normalized) 119 | 120 | diff = dpo_stacked - gpt2_stacked 121 | diff_mean = diff.mean(dim=0) 122 | print(diff_mean.shape) 123 | 124 | comps = torch.concat([diff_mean.unsqueeze(-1), V], dim=1) 125 | comps = comps[:, :2] 126 | projected = torch.mm(normalized, comps) 127 | 128 | pca_raw = [] 129 | num_samples = 30 130 | for idx in range(num_samples): 131 | 132 | _activation = gpt_acts[idx].item() 133 | if _activation > 15: 134 | act = "High (> 15)" 135 | elif _activation > 0: 136 | act = "Low (> 0)" 137 | else: 138 | act = "None" 139 | 140 | print(_activation) 141 | pca_raw.append( 142 | { 143 | "Model": "GPT2", 144 | "x": projected[idx, 0].item(), 145 | "y": projected[idx, 1].item(), 146 | "Activated": act, 147 | } 148 | ) 149 | 150 | _offset = len(gpt2_resid) 151 | print("____") 152 | for idx in range(num_samples): 153 | _activation = dpo_acts[idx].item() 154 | if _activation > 15: 155 | act = "High (> 15)" 156 | elif _activation > 0: 157 | act = "Low (> 0)" 158 | else: 159 | act = "None" 160 | print(_activation) 161 | pca_raw.append( 162 | { 163 | "Model": "DPO", 164 | "x": projected[_offset + idx, 0].item(), 165 | "y": projected[_offset + idx, 1].item(), 166 | "Activated": act, 167 | } 168 | ) 169 | 170 | pca_data = pd.DataFrame(pca_raw) 171 | sns.set_theme(context="paper", style="ticks", rc={"lines.linewidth": 1}) 172 | 173 | fig = sns.relplot( 174 | pca_data, 175 | x="x", 176 | y="y", 177 | hue="Activated", 178 | palette={"High (> 15)": "red", "Low (> 0)": "orange", "None": "green"}, 179 | hue_order=["High (> 15)", "Low (> 0)", "None"], 180 | style="Model", 181 | markers={"GPT2": "o", "DPO": "^"}, 182 | height=2.5, 183 | aspect=3.25 / 2.5, 184 | s=60, 185 | legend="full", 186 | ) 187 | 188 | fig.ax.set_xticks([]) 189 | fig.ax.set_yticks([]) 190 | fig.ax.xaxis.label.set_text("Shift Component") 191 | fig.ax.yaxis.label.set_text("Principle Component") 192 | fig.ax.xaxis.label.set_visible(True) 193 | fig.ax.yaxis.label.set_visible(True) 194 | 195 | _offset = len(gpt2_resid) 196 | for idx in range(num_samples): 197 | gpt2_x = projected[idx, 0].item() 198 | gpt2_y = 
projected[idx, 1].item() 199 | dpo_x = projected[_offset + idx, 0].item() 200 | dpo_y = projected[_offset + idx, 1].item() 201 | fig.ax.plot( 202 | [gpt2_x, dpo_x], [gpt2_y, dpo_y], color="black", ls=":", zorder=0 203 | ) 204 | 205 | plt.savefig(f"pca_layer{_layer}.pdf", bbox_inches="tight", dpi=1200) 206 | 207 | 208 | # %% 209 | 210 | 211 | fig = plt.figure(figsize=(6.75, 3.5)) 212 | gs = GridSpec(1, 3) 213 | boundaries = [ 214 | (0.1, 10), 215 | (0.1, 10), 216 | (0.1, 10), 217 | ] 218 | 219 | for vec_idx in [1, 2, 3]: 220 | 221 | gpt2_resid = [] 222 | dpo_resid = [] 223 | sample_size = 50 224 | batch_size = 4 225 | print("Grabbing mlp mids...") 226 | _vec = vectors_of_interest[vec_idx] 227 | _layer = _vec[0] 228 | _idx = _vec[1] 229 | 230 | for idx in tqdm(range(0, sample_size, batch_size)): 231 | batch = tokenized_prompts[idx : idx + batch_size, :] 232 | dpo_batch = batch.clone() 233 | 234 | with torch.inference_mode(): 235 | _, cache = gpt2.run_with_cache(batch) 236 | resid = cache[f"blocks.{_layer}.hook_resid_mid"][:, -1, :] 237 | 238 | gpt2_resid.extend(resid.cpu().tolist()) 239 | 240 | with torch.inference_mode(): 241 | _, cache = model.run_with_cache(dpo_batch) 242 | resid = cache[f"blocks.{_layer}.hook_resid_mid"][:, -1, :] 243 | dpo_resid.extend(resid.cpu().tolist()) 244 | 245 | w_ins = [ 246 | gpt2.blocks[_layer].mlp.W_in[:, _idx].cpu(), 247 | model.blocks[_layer].mlp.W_in[:, _idx].cpu(), 248 | ] 249 | 250 | gpt2_stacked = torch.stack([torch.Tensor(x) for x in gpt2_resid], dim=0) 251 | dpo_stacked = torch.stack([torch.Tensor(x) for x in dpo_resid], dim=0) 252 | gpt2_dots = einsum("sample d_model, d_model", gpt2_stacked, w_ins[0]) 253 | dpo_dots = einsum("sample d_model, d_model", dpo_stacked, w_ins[1]) 254 | gpt_acts = model.blocks[0].mlp.act_fn(gpt2_dots) 255 | dpo_acts = model.blocks[0].mlp.act_fn(dpo_dots) 256 | 257 | all_data = torch.concat([gpt2_stacked, dpo_stacked], dim=0) 258 | mean = all_data.mean(dim=0) 259 | stddev = all_data.std(dim=0) 260 | normalized = (all_data - mean) / stddev 261 | 262 | U, S, V = torch.pca_lowrank(normalized) 263 | 264 | diff = dpo_stacked - gpt2_stacked 265 | diff_mean = diff.mean(dim=0) 266 | 267 | comps = torch.concat([diff_mean.unsqueeze(-1), V], dim=1) 268 | comps = comps[:, :2] 269 | projected = torch.mm(normalized, comps) 270 | 271 | pca_raw = [] 272 | num_samples = 30 273 | 274 | _boundary = boundaries[vec_idx - 1] 275 | for idx in range(num_samples): 276 | 277 | _activation = gpt_acts[idx].item() 278 | if _activation > _boundary[1]: 279 | act = f"High (> {_boundary[1]})" 280 | elif _activation > _boundary[0]: 281 | act = f"Low (> {_boundary[0]})" 282 | else: 283 | act = "None" 284 | 285 | print(_activation) 286 | pca_raw.append( 287 | { 288 | "Model": "GPT2", 289 | "x": projected[idx, 0].item(), 290 | "y": projected[idx, 1].item(), 291 | "Activated": act, 292 | } 293 | ) 294 | 295 | _offset = len(gpt2_resid) 296 | print("____") 297 | for idx in range(num_samples): 298 | _activation = dpo_acts[idx].item() 299 | if _activation > _boundary[1]: 300 | act = f"High (> {_boundary[1]})" 301 | elif _activation > _boundary[0]: 302 | act = f"Low (> {_boundary[0]})" 303 | else: 304 | act = "None" 305 | print(_activation) 306 | pca_raw.append( 307 | { 308 | "Model": "DPO", 309 | "x": projected[_offset + idx, 0].item(), 310 | "y": projected[_offset + idx, 1].item(), 311 | "Activated": act, 312 | } 313 | ) 314 | 315 | pca_data = pd.DataFrame(pca_raw) 316 | sns.set_theme(context="paper", style="ticks", rc={"lines.linewidth": 1}) 317 | 318 | ax = 
fig.add_subplot(gs[0, vec_idx - 1]) 319 | 320 | legend = None 321 | if vec_idx == 1: 322 | legend = "full" 323 | 324 | sns.scatterplot( 325 | pca_data, 326 | x="x", 327 | y="y", 328 | hue="Activated", 329 | palette={ 330 | f"High (> {_boundary[1]})": "red", 331 | f"Low (> {_boundary[0]})": "orange", 332 | "None": "green", 333 | }, 334 | hue_order=[ 335 | f"High (> {_boundary[1]})", 336 | f"Low (> {_boundary[0]})", 337 | "None", 338 | ], 339 | style="Model", 340 | markers={"GPT2": "o", "DPO": "^"}, 341 | s=60, 342 | legend=legend, 343 | ax=ax, 344 | ) 345 | 346 | ax.set_xticks([]) 347 | ax.set_yticks([]) 348 | ax.xaxis.label.set_text("Shift Component") 349 | ax.yaxis.label.set_text("Principle Component") 350 | ax.xaxis.label.set_visible(True) 351 | ax.yaxis.label.set_visible(True) 352 | 353 | _offset = len(gpt2_resid) 354 | for idx in range(num_samples): 355 | gpt2_x = projected[idx, 0].item() 356 | gpt2_y = projected[idx, 1].item() 357 | dpo_x = projected[_offset + idx, 0].item() 358 | dpo_y = projected[_offset + idx, 1].item() 359 | ax.plot( 360 | [gpt2_x, dpo_x], [gpt2_y, dpo_y], color="black", ls=":", zorder=0 361 | ) 362 | 363 | plt.savefig(f"pca_layer_appx.pdf", bbox_inches="tight", dpi=1200) 364 | 365 | -------------------------------------------------------------------------------- /toxicity/figures/resid_diff_plot.sync.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # text_representation: 5 | # extension: .py 6 | # format_name: percent 7 | # format_version: '1.3' 8 | # jupytext_version: 1.3.4 9 | # kernelspec: 10 | # display_name: Python 3 11 | # language: python 12 | # name: python3 13 | # --- 14 | 15 | # %% 16 | import os 17 | import json 18 | from collections import defaultdict 19 | 20 | import pandas as pd 21 | import numpy as np 22 | import einops 23 | 24 | import torch 25 | import torch.nn.functional as F 26 | from fancy_einsum import einsum 27 | from tqdm import tqdm 28 | 29 | import seaborn as sns 30 | import matplotlib.pyplot as plt 31 | from matplotlib.gridspec import GridSpec 32 | from transformer_lens import HookedTransformer 33 | from toxicity.figures.fig_utils import convert, load_hooked, get_svd 34 | from constants import MODEL_DIR, DATA_DIR 35 | 36 | torch.set_grad_enabled(False) 37 | 38 | # %% 39 | 40 | model = load_hooked("gpt2-medium", os.path.join(MODEL_DIR, "dpo.pt")) 41 | gpt2 = HookedTransformer.from_pretrained("gpt2-medium") 42 | gpt2.tokenizer.padding_side = "left" 43 | gpt2.tokenizer.pad_token_id = gpt2.tokenizer.eos_token_id 44 | 45 | toxic_vector = torch.load(os.path.join(MODEL_DIR, "probe.pt")) 46 | 47 | cos = F.cosine_similarity 48 | 49 | # %% 50 | 51 | with open( 52 | os.path.join(DATA_DIR, "intervene_data/challenge_prompts.jsonl"), "r" 53 | ) as file_p: 54 | data = file_p.readlines() 55 | 56 | prompts = [json.loads(x.strip())["prompt"] for x in data] 57 | tokenized_prompts = model.to_tokens(prompts, prepend_bos=True).cuda() 58 | 59 | # %% 60 | 61 | 62 | svd_gpt2, scores_gpt2 = get_svd(gpt2, toxic_vector, 128) 63 | 64 | mlps_by_layer = {} 65 | for _score_obj in scores_gpt2: 66 | layer = _score_obj[2] 67 | if layer not in mlps_by_layer: 68 | mlps_by_layer[layer] = [] 69 | mlps_by_layer[layer].append(_score_obj[1]) 70 | 71 | vectors_of_interest = [ 72 | (_score_obj[2], _score_obj[1], _score_obj[0]) 73 | for _score_obj in scores_gpt2[:64] 74 | ] 75 | 76 | 77 | # %% 78 | 79 | 80 | sample_size = tokenized_prompts.shape[0] 81 | layer_of_interest = 18 82 | 83 | sublayers = [0, 
2, 4, 6, 8, 9, 11, 13, 15, 17] 84 | batch_size = 4 85 | 86 | all_diffs = [] 87 | for idx in tqdm(range(0, sample_size, batch_size)): 88 | batch = tokenized_prompts[idx : idx + batch_size, :] 89 | with torch.inference_mode(): 90 | _, cache = gpt2.run_with_cache(batch) 91 | 92 | gpt2_resids = {} 93 | # [batch, d_model] 94 | gpt2_resids[layer_of_interest] = cache[ 95 | f"blocks.{layer_of_interest}.hook_resid_mid" 96 | ] 97 | 98 | with torch.inference_mode(): 99 | _, cache = model.run_with_cache(batch) 100 | 101 | # [batch, d_model] 102 | all_diffs.append( 103 | cache[f"blocks.{layer_of_interest}.hook_resid_mid"] 104 | - gpt2_resids[layer_of_interest] 105 | ) 106 | 107 | all_diffs = torch.concat(all_diffs, dim=0) 108 | 109 | # %% 110 | 111 | # [4096, 1024] 112 | mlp_diffs = {} 113 | for _layer in sublayers: 114 | dpo_w_out = model.blocks[_layer].mlp.W_out 115 | gpt2_w_out = gpt2.blocks[_layer].mlp.W_out 116 | mlp_diffs[_layer] = dpo_w_out - gpt2_w_out 117 | 118 | 119 | diff_cosines = {} 120 | 121 | for mlp_layer in sublayers: 122 | 123 | diff_cosines[mlp_layer] = cos( 124 | mlp_diffs[mlp_layer], 125 | all_diffs.view(-1, 1024).mean(dim=0).unsqueeze(0), 126 | dim=1, 127 | ).tolist() 128 | 129 | 130 | # %% 131 | 132 | all_acts_dpo_pt = defaultdict(list) 133 | # Decompose layer 19 MLP out 134 | print("Grabbing mlp mids...") 135 | for idx in tqdm(range(0, sample_size, batch_size)): 136 | batch = tokenized_prompts[idx : idx + batch_size, :] 137 | with torch.inference_mode(): 138 | _, cache = model.run_with_cache(batch) 139 | 140 | for _layer in sublayers: 141 | mlp_mid = cache[f"blocks.{_layer}.mlp.hook_post"] 142 | all_acts_dpo_pt[_layer].append(mlp_mid.cpu()) 143 | 144 | # %% 145 | 146 | d_mlp = model.cfg.d_mlp 147 | dpo_acts_mean = {} 148 | gpt2_acts_mean = {} 149 | for _layer in sublayers: 150 | concat = torch.concat([x.cuda() for x in all_acts_dpo_pt[_layer]], dim=0) 151 | dpo_acts_mean[_layer] = ( 152 | concat.view(-1, d_mlp).to("cuda:1").mean(dim=0).cpu() 153 | ) 154 | 155 | # %% 156 | 157 | print("Building dataframes.") 158 | 159 | raw_cosine_data = [] 160 | layers_to_plot = list(diff_cosines.keys()) 161 | for _layer in layers_to_plot: 162 | for _idx in range(d_mlp): 163 | raw_cosine_data.append( 164 | { 165 | "layer": _layer, 166 | "cos_sim": round(diff_cosines[_layer][_idx], 2), 167 | } 168 | ) 169 | 170 | raw_mean_acts = [] 171 | for _layer in layers_to_plot: 172 | for _idx in range(d_mlp): 173 | raw_mean_acts.append( 174 | {"layer": _layer, "mean_acts": dpo_acts_mean[_layer][_idx].item()} 175 | ) 176 | 177 | 178 | cos_sim_df = pd.DataFrame(raw_cosine_data) 179 | mean_acts_df = pd.DataFrame(raw_mean_acts) 180 | 181 | # %% 182 | 183 | sns.set_theme(context="paper", style="ticks", rc={"lines.linewidth": 1}) 184 | 185 | colors = sns.color_palette() 186 | 187 | num_cols = int(len(sublayers) / 2) 188 | 189 | fig = plt.figure(figsize=(6.75, 2.4)) 190 | gs = GridSpec(2, num_cols) 191 | 192 | yticks = [0, 0.04, 0.08, 0.12, 0.16, 0.20, 0.24] 193 | edge_col = 0 194 | 195 | for idx in range(num_cols * 2): 196 | curr_row = idx // num_cols 197 | curr_col = idx % num_cols 198 | 199 | _layer = sublayers[idx] 200 | ax = fig.add_subplot(gs[curr_row, curr_col]) 201 | 202 | ax.set(yticks=yticks) 203 | ax2 = ax.twiny() 204 | sns.histplot( 205 | data=cos_sim_df[cos_sim_df.layer == _layer], 206 | x="cos_sim", 207 | ax=ax, 208 | stat="probability", 209 | color=colors[0], 210 | alpha=1, 211 | element="poly", 212 | ) 213 | sns.histplot( 214 | data=mean_acts_df[mean_acts_df.layer == _layer], 215 | x="mean_acts", 
216 | ax=ax2, 217 | stat="probability", 218 | color=colors[1], 219 | alpha=0.65, 220 | element="poly", 221 | ) 222 | 223 | ax.set_title(f"Layer {_layer}", pad=-10) 224 | 225 | ax.set(yticks=yticks) 226 | ax.set(xticks=[-1, 0, 1]) 227 | ax.yaxis.tick_left() 228 | if curr_col != edge_col: 229 | ax.yaxis.set_visible(True) 230 | ax.yaxis.set_ticklabels([]) 231 | ax.yaxis.label.set_visible(False) 232 | 233 | else: 234 | ax.yaxis.label.set_visible(True) 235 | ax.yaxis.label.set_text("Proportion") 236 | 237 | if curr_row == 0: 238 | ax.xaxis.label.set_visible(False) 239 | ax.xaxis.set_ticklabels([]) 240 | ax2.xaxis.set_visible(False) 241 | 242 | ax2.set_xlim(left=-0.21, right=0.21) 243 | 244 | ax2.xaxis.set_ticks_position("bottom") 245 | ax2.xaxis.set_label_position("bottom") 246 | ax2.spines["bottom"].set_position(("outward", 28)) 247 | ax2.set_frame_on(True) 248 | ax2.patch.set_visible(False) 249 | for sp in ax2.spines.values(): 250 | sp.set_visible(False) 251 | ax2.spines["bottom"].set_visible(True) 252 | 253 | ax.spines["bottom"].set_color(colors[0]) 254 | ax2.spines["bottom"].set_color(colors[1]) 255 | 256 | ax.xaxis.label.set_color(colors[0]) 257 | ax.tick_params(axis="x", colors=colors[0]) 258 | 259 | ax2.xaxis.label.set_color(colors[1]) 260 | ax2.tick_params(axis="x", colors=colors[1]) 261 | 262 | ax2.set(xticks=[-0.2, 0, 0.2]) 263 | ax.xaxis.label.set_text("Cos Sim") 264 | ax2.xaxis.label.set_text("Mean Act.") 265 | ax.xaxis.labelpad = -1 266 | ax2.xaxis.labelpad = -1 267 | 268 | fig.savefig( 269 | f"resid_diff_subplots_layer{layer_of_interest}.pdf", 270 | bbox_inches="tight", 271 | dpi=1200, 272 | ) 273 | 274 | -------------------------------------------------------------------------------- /toxicity/figures/shit_prompts.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ajyl/dpo_toxic/34319c9bad0d22608e30460807a54869e84f474f/toxicity/figures/shit_prompts.npy -------------------------------------------------------------------------------- /toxicity/train_dpo/config/config.yaml: -------------------------------------------------------------------------------- 1 | # random seed for batch sampling 2 | seed: 42 3 | 4 | # name for this experiment in the local run directory and on wandb 5 | exp_name: ??? 6 | 7 | valid_size: 64 8 | # the batch size for training; for FSDP, the batch size per GPU is batch_size / (grad_accumulation_steps * num_gpus) 9 | batch_size: 4 10 | # the batch size during evaluation and sampling, if enabled 11 | eval_batch_size: 8 12 | 13 | # debug mode (disables wandb, model checkpointing, etc.) 
14 | debug: false 15 | 16 | # the port to use for FSDP 17 | fsdp_port: null 18 | 19 | # wandb configuration 20 | wandb: 21 | enabled: true 22 | entity: null 23 | project: "dpo-toxicity-pplm" 24 | 25 | # to create the local run directory and cache models/datasets, 26 | # we will try each of these directories in order; if none exist, 27 | # we will create the last one and use it 28 | local_dirs: 29 | - .cache 30 | 31 | # whether or not to generate samples during evaluation; disable for FSDP/TensorParallel 32 | # is recommended, because they are slow 33 | 34 | # how many model samples to generate during evaluation 35 | n_eval_model_samples: 16 36 | 37 | # whether to eval at the very beginning of training 38 | do_first_eval: false 39 | 40 | # an OmegaConf resolver that returns the local run directory, calling a function in utils.py 41 | local_run_dir: ${get_local_run_dir:${exp_name},${local_dirs}} 42 | 43 | # the learning rate 44 | lr: 1e-6 45 | 46 | # number of steps to accumulate over for each batch 47 | # (e.g. if batch_size=4 and gradient_accumulation_steps=2, then we will 48 | # accumulate gradients over 2 microbatches of size 2) 49 | gradient_accumulation_steps: 1 50 | 51 | # the maximum gradient norm to clip to 52 | max_grad_norm: 10.0 53 | 54 | # the maximum allowed length for an input (prompt + response) 55 | max_length: 256 56 | max_new_tokens: 64 57 | 58 | # the maximum allowed length for a prompt 59 | max_prompt_length: 64 60 | 61 | # the number of epochs to train for; if null, must specify n_examples 62 | n_epochs: 5 63 | 64 | # the trainer class to use (e.g. BasicTrainer, FSDPTrainer, TensorParallelTrainer) 65 | trainer: BasicTrainer 66 | 67 | # The optimizer to use; we use RMSprop because it works about as well as Adam and is more memory-efficient 68 | optimizer: RMSprop 69 | 70 | # number of linear warmup steps for the learning rate 71 | warmup_steps: 150 72 | 73 | # whether or not to use activation/gradient checkpointing 74 | activation_checkpointing: false 75 | 76 | # evaluate and save model every eval_every steps 77 | eval_every: 100 78 | save_every: 100 79 | validation_metric: "loss/valid" 80 | validation_direction: "min" 81 | validation_patience: 30 82 | 83 | sample_during_eval: false 84 | sample_every: 2000 85 | 86 | # prevent wandb from logging more than once per minimum_log_interval_secs 87 | minimum_log_interval_secs: 2.0 88 | 89 | defaults: 90 | - _self_ 91 | - model: gpt2-medium 92 | - loss: dpo # which loss function, either sft or dpo (specify loss.beta if using dpo) 93 | -------------------------------------------------------------------------------- /toxicity/train_dpo/config/loss/dpo.yaml: -------------------------------------------------------------------------------- 1 | # do DPO preference-based training 2 | name: dpo 3 | 4 | # the temperature parameter for DPO; lower values mean we care less about 5 | # the reference model 6 | beta: 0.1 7 | 8 | # if true, use a uniform (maximum entropy) reference model 9 | reference_free: false 10 | -------------------------------------------------------------------------------- /toxicity/train_dpo/config/loss/sft.yaml: -------------------------------------------------------------------------------- 1 | name: sft -------------------------------------------------------------------------------- /toxicity/train_dpo/config/model/gpt2-large.yaml: -------------------------------------------------------------------------------- 1 | name_or_path: gpt2-large 2 | tokenizer_name_or_path: null 3 | archive: null 4 | block_name: GPT2Block 
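# (assumed rationale) policy_dtype stays float32 for training stability; the frozen DPO reference model only supplies log-probs, so reference_dtype can be float16 to save memory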
5 | 6 | policy_dtype: float32 7 | fsdp_policy_mp: null 8 | reference_dtype: float16 -------------------------------------------------------------------------------- /toxicity/train_dpo/config/model/gpt2-medium.yaml: -------------------------------------------------------------------------------- 1 | name_or_path: gpt2-medium 2 | tokenizer_name_or_path: null 3 | archive: null 4 | block_name: GPT2Block 5 | 6 | policy_dtype: float32 7 | fsdp_policy_mp: null 8 | reference_dtype: float16 9 | -------------------------------------------------------------------------------- /toxicity/train_dpo/config/model/gpt2-xl.yaml: -------------------------------------------------------------------------------- 1 | name_or_path: gpt2-xl 2 | tokenizer_name_or_path: null 3 | archive: null 4 | block_name: GPT2Block 5 | 6 | policy_dtype: float32 7 | fsdp_policy_mp: null 8 | reference_dtype: float16 -------------------------------------------------------------------------------- /toxicity/train_dpo/config/model/gpt2.yaml: -------------------------------------------------------------------------------- 1 | name_or_path: gpt2 2 | tokenizer_name_or_path: null 3 | archive: null 4 | block_name: GPT2Block 5 | 6 | policy_dtype: float32 7 | fsdp_policy_mp: null 8 | reference_dtype: float16 9 | -------------------------------------------------------------------------------- /toxicity/train_dpo/dpo_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import getpass 3 | from datetime import datetime 4 | import torch 5 | import random 6 | import numpy as np 7 | import torch.distributed as dist 8 | import inspect 9 | import importlib.util 10 | import socket 11 | import os 12 | from typing import Dict, Union, Type, List 13 | 14 | 15 | def get_open_port(): 16 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 17 | s.bind(("", 0)) # bind to all interfaces and use an OS provided port 18 | return s.getsockname()[1] # return only the port number 19 | 20 | 21 | def get_remote_file(remote_path, local_path=None): 22 | hostname, path = remote_path.split(":") 23 | local_hostname = socket.gethostname() 24 | if ( 25 | hostname == local_hostname 26 | or hostname == local_hostname[: local_hostname.find(".")] 27 | ): 28 | return path 29 | 30 | if local_path is None: 31 | local_path = path 32 | # local_path = local_path.replace('/scr-ssd', '/scr') 33 | if os.path.exists(local_path): 34 | return local_path 35 | local_dir = os.path.dirname(local_path) 36 | os.makedirs(local_dir, exist_ok=True) 37 | 38 | print(f"Copying {hostname}:{path} to {local_path}") 39 | os.system(f"scp {remote_path} {local_path}") 40 | return local_path 41 | 42 | 43 | def rank0_print(*args, **kwargs): 44 | """Print, but only on rank 0.""" 45 | if not dist.is_initialized() or dist.get_rank() == 0: 46 | print(*args, **kwargs) 47 | 48 | 49 | def get_local_dir(prefixes_to_resolve: List[str]) -> str: 50 | """Return the path to the cache directory for this user.""" 51 | for prefix in prefixes_to_resolve: 52 | if os.path.exists(prefix): 53 | return f"{prefix}/{getpass.getuser()}" 54 | os.makedirs(prefix) 55 | return f"{prefix}/{getpass.getuser()}" 56 | 57 | 58 | def get_local_run_dir(exp_name: str, local_dirs: List[str]) -> str: 59 | """Create a local directory to store outputs for this run, and return its path.""" 60 | run_dir = f"{get_local_dir(local_dirs)}/{exp_name}" 61 | os.makedirs(run_dir, exist_ok=True) 62 | return run_dir 63 | 64 | 65 | def slice_and_move_batch_for_device( 66 | batch: Dict, rank: int, 
world_size: int, device: str 67 | ) -> Dict: 68 | """Slice a batch into chunks, and move each chunk to the specified device.""" 69 | chunk_size = len(list(batch.values())[0]) // world_size 70 | start = chunk_size * rank 71 | end = chunk_size * (rank + 1) 72 | sliced = {k: v[start:end] for k, v in batch.items()} 73 | on_device = { 74 | k: (v.to(device) if isinstance(v, torch.Tensor) else v) 75 | for k, v in sliced.items() 76 | } 77 | return on_device 78 | 79 | 80 | def pad_to_length( 81 | tensor: torch.Tensor, 82 | length: int, 83 | pad_value: Union[int, float], 84 | dim: int = -1, 85 | ) -> torch.Tensor: 86 | if tensor.size(dim) >= length: 87 | return tensor 88 | else: 89 | pad_size = list(tensor.shape) 90 | pad_size[dim] = length - tensor.size(dim) 91 | return torch.cat( 92 | [ 93 | tensor, 94 | pad_value 95 | * torch.ones( 96 | *pad_size, dtype=tensor.dtype, device=tensor.device 97 | ), 98 | ], 99 | dim=dim, 100 | ) 101 | 102 | 103 | def all_gather_if_needed( 104 | values: torch.Tensor, rank: int, world_size: int 105 | ) -> torch.Tensor: 106 | """Gather and stack/cat values from all processes, if there are multiple processes.""" 107 | if world_size == 1: 108 | return values 109 | 110 | all_values = [torch.empty_like(values).to(rank) for _ in range(world_size)] 111 | dist.all_gather(all_values, values) 112 | cat_function = torch.cat if values.dim() > 0 else torch.stack 113 | return cat_function(all_values, dim=0) 114 | 115 | 116 | def formatted_dict(d: Dict) -> Dict: 117 | """Format a dictionary for printing.""" 118 | return {k: (f"{v:.5g}" if type(v) == float else v) for k, v in d.items()} 119 | 120 | 121 | def disable_dropout(model: torch.nn.Module): 122 | """Disable dropout in a model.""" 123 | for module in model.modules(): 124 | if isinstance(module, torch.nn.Dropout): 125 | module.p = 0 126 | 127 | 128 | def print_gpu_memory(rank: int = None, message: str = ""): 129 | """Print the amount of GPU memory currently allocated for each GPU.""" 130 | if torch.cuda.is_available(): 131 | device_count = torch.cuda.device_count() 132 | for i in range(device_count): 133 | device = torch.device(f"cuda:{i}") 134 | allocated_bytes = torch.cuda.memory_allocated(device) 135 | if allocated_bytes == 0: 136 | continue 137 | print("*" * 40) 138 | print( 139 | f"[{message} rank {rank} ] GPU {i}: {allocated_bytes / 1024**2:.2f} MB" 140 | ) 141 | print("*" * 40) 142 | 143 | 144 | def get_block_class_from_model( 145 | model: torch.nn.Module, block_class_name: str 146 | ) -> torch.nn.Module: 147 | """Get the class of a block from a model, using the block's class name.""" 148 | for module in model.modules(): 149 | if module.__class__.__name__ == block_class_name: 150 | return module.__class__ 151 | raise ValueError( 152 | f"Could not find block class {block_class_name} in model {model}" 153 | ) 154 | 155 | 156 | def get_block_class_from_model_class_and_block_name( 157 | model_class: Type, block_class_name: str 158 | ) -> Type: 159 | filepath = inspect.getfile(model_class) 160 | assert filepath.endswith(".py"), f"Expected a .py file, got {filepath}" 161 | assert os.path.exists(filepath), f"File {filepath} does not exist" 162 | assert ( 163 | "transformers" in filepath 164 | ), f"Expected a transformers model, got {filepath}" 165 | 166 | module_name = filepath[filepath.find("transformers") :].replace("/", ".")[ 167 | :-3 168 | ] 169 | print( 170 | f"Searching in file {filepath}, module {module_name} for class {block_class_name}" 171 | ) 172 | 173 | # Load the module dynamically 174 | spec = 
importlib.util.spec_from_file_location(module_name, filepath) 175 | module = importlib.util.module_from_spec(spec) 176 | spec.loader.exec_module(module) 177 | 178 | # Get the class dynamically 179 | class_ = getattr(module, block_class_name) 180 | print(f"Found class {class_} in module {module_name}") 181 | return class_ 182 | 183 | 184 | def init_distributed( 185 | rank: int, 186 | world_size: int, 187 | master_addr: str = "localhost", 188 | port: int = 12355, 189 | backend: str = "nccl", 190 | ): 191 | print(rank, "initializing distributed") 192 | os.environ["MASTER_ADDR"] = master_addr 193 | os.environ["MASTER_PORT"] = str(port) 194 | dist.init_process_group(backend, rank=rank, world_size=world_size) 195 | torch.cuda.set_device(rank) 196 | 197 | 198 | class TemporarilySeededRandom: 199 | def __init__(self, seed): 200 | """Temporarily set the random seed, and then restore it when exiting the context.""" 201 | self.seed = seed 202 | self.stored_state = None 203 | self.stored_np_state = None 204 | 205 | def __enter__(self): 206 | # Store the current random state 207 | self.stored_state = random.getstate() 208 | self.stored_np_state = np.random.get_state() 209 | 210 | # Set the random seed 211 | random.seed(self.seed) 212 | np.random.seed(self.seed) 213 | 214 | def __exit__(self, exc_type, exc_value, traceback): 215 | # Restore the random state 216 | random.setstate(self.stored_state) 217 | np.random.set_state(self.stored_np_state) 218 | -------------------------------------------------------------------------------- /toxicity/train_dpo/pplm_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Load PPLM dataset 3 | """ 4 | from typing import Dict, List, Optional, Iterator, Callable, Union, Tuple 5 | 6 | import os 7 | import json 8 | from collections import defaultdict 9 | import random 10 | from tqdm import tqdm 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | from torch.utils.data import DataLoader, Dataset 15 | from torch.nn.utils.rnn import pad_sequence 16 | from toxicity.train_dpo.dpo_utils import get_local_dir, TemporarilySeededRandom 17 | from constants import DATA_DIR, GPT2_PAD_IDX 18 | 19 | 20 | def get_pplm_batch_iterator( 21 | tokenizer, 22 | config, 23 | split: str = "train", 24 | device: str = "cuda", 25 | ) -> Iterator[Dict]: 26 | """ 27 | Get an iterator over batches of data. 28 | 29 | :params: 30 | 31 | :split: Which split to use. 32 | :batch_size: Batch size. 33 | :valid_size: Validation size. 
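:yields: Dicts of left-padded prompt tensors plus the preference pair used for DPO: 'prompt_input_ids' / 'prompt_attention_mask', 'pos_input_ids' / 'pos_labels' built from the prompt and the unperturbed ('unpert_gen_text') continuation, and 'neg_input_ids' / 'neg_labels' built from the prompt and the PPLM-perturbed ('pert_gen_text') continuation, with prompt positions set to -100 in both label tensors.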
34 | """ 35 | assert split in ["train", "valid"] 36 | data_dir = os.path.join(DATA_DIR, "toxicity_pairwise") 37 | batch_size = config.batch_size 38 | if split == "valid": 39 | batch_size = config.eval_batch_size 40 | max_prompt_length = config.max_prompt_length 41 | max_new_tokens = config.max_new_tokens 42 | valid_size = config.valid_size 43 | 44 | filenames = [ 45 | os.path.join(data_dir, filename) 46 | for filename in os.listdir(data_dir) 47 | if filename.endswith(".jsonl") 48 | ] 49 | 50 | data = [] 51 | for filename in tqdm(filenames): 52 | with open(filename, "r") as file_p: 53 | file_data = file_p.readlines() 54 | 55 | data.extend(file_data) 56 | 57 | random.shuffle(file_data) 58 | if split == "train": 59 | data = data[:-valid_size] 60 | else: 61 | data = data[-valid_size:] 62 | data_size = len(data) 63 | 64 | for idx in range(0, data_size, batch_size): 65 | batch = data[idx : idx + batch_size] 66 | batch = [json.loads(x.strip()) for x in batch] 67 | 68 | prompt_text = [x["prompt_text"] for x in batch] 69 | gold_text = [x["unpert_gen_text"] for x in batch] 70 | 71 | prompt_tokenized = tokenizer( 72 | prompt_text, 73 | max_length=max_prompt_length, 74 | padding=True, 75 | truncation=True, 76 | return_tensors="pt", 77 | ).to(device) 78 | 79 | prompt_input_ids = prompt_tokenized["input_ids"] 80 | prompt_attention_mask = prompt_tokenized["attention_mask"] 81 | 82 | tokenizer.padding_side = "right" 83 | gold_tokenized = tokenizer( 84 | gold_text, 85 | max_length=max_new_tokens, 86 | padding=True, 87 | truncation=True, 88 | return_tensors="pt", 89 | ).to(device) 90 | 91 | pos_input_id = gold_tokenized["input_ids"].long() 92 | 93 | pplm_text = [x["pert_gen_text"] for x in batch] 94 | pplm_tokenized = tokenizer( 95 | pplm_text, 96 | max_length=max_new_tokens, 97 | padding=True, 98 | truncation=True, 99 | return_tensors="pt", 100 | ).to(device) 101 | tokenizer.padding_side = "left" 102 | 103 | pos_input_ids = torch.concat( 104 | [prompt_input_ids, gold_tokenized["input_ids"]], dim=1 105 | ) 106 | neg_input_ids = torch.concat( 107 | [prompt_input_ids, pplm_tokenized["input_ids"]], dim=1 108 | ) 109 | 110 | prompt_shape = prompt_input_ids.shape[1] 111 | pos_labels = pos_input_ids.detach().clone() 112 | pos_labels[:, :prompt_shape] = -100 113 | neg_labels = neg_input_ids.detach().clone() 114 | neg_labels[:, :prompt_shape] = -100 115 | 116 | yield { 117 | "prompt_input_ids": prompt_input_ids, 118 | "prompt_attention_mask": prompt_attention_mask, 119 | "gold_text": gold_text, 120 | "gold_input_ids": pos_input_id, 121 | "pos_text": gold_text, 122 | "pos_input_ids": pos_input_ids, 123 | "pos_attention_mask": pos_input_ids != tokenizer.pad_token_id, 124 | "pos_labels": pos_labels, 125 | "neg_text": pplm_text, 126 | "neg_input_ids": neg_input_ids, 127 | "neg_attention_mask": neg_input_ids != tokenizer.pad_token_id, 128 | "neg_labels": neg_labels, 129 | } 130 | -------------------------------------------------------------------------------- /toxicity/train_dpo/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train script 3 | """ 4 | from typing import Optional, Set 5 | 6 | import os 7 | import json 8 | import socket 9 | import resource 10 | import torch 11 | import torch.nn as nn 12 | import torch.multiprocessing as mp 13 | import transformers 14 | import hydra 15 | from omegaconf import OmegaConf, DictConfig 16 | import wandb 17 | 18 | import toxicity.train_dpo.trainers as trainers 19 | from toxicity.train_dpo.dpo_utils import ( 20 | 
get_local_dir, 21 | get_local_run_dir, 22 | disable_dropout, 23 | init_distributed, 24 | get_open_port, 25 | ) 26 | 27 | torch.backends.cuda.matmul.allow_tf32 = True 28 | OmegaConf.register_new_resolver( 29 | "get_local_run_dir", 30 | lambda exp_name, local_dirs: get_local_run_dir(exp_name, local_dirs), 31 | ) 32 | 33 | 34 | def worker_main( 35 | rank: int, 36 | world_size: int, 37 | config: DictConfig, 38 | policy: nn.Module, 39 | reference_model: Optional[nn.Module] = None, 40 | ): 41 | """ 42 | Main function for each worker process 43 | (may be only 1 for BasicTrainer/TensorParallelTrainer). 44 | """ 45 | if "FSDP" in config.trainer: 46 | init_distributed(rank, world_size, port=config.fsdp_port) 47 | 48 | if config.debug: 49 | wandb.init = lambda *args, **kwargs: None 50 | wandb.log = lambda *args, **kwargs: None 51 | 52 | if rank == 0 and config.wandb.enabled: 53 | os.environ["WANDB_CACHE_DIR"] = get_local_dir(config.local_dirs) 54 | wandb.init( 55 | entity=config.wandb.entity, 56 | project=config.wandb.project, 57 | config=OmegaConf.to_container(config), 58 | dir=get_local_dir(config.local_dirs), 59 | name=config.exp_name, 60 | ) 61 | 62 | TrainerClass = getattr(trainers, config.trainer) 63 | print(f"Creating trainer on process {rank} with world size {world_size}") 64 | trainer = TrainerClass( 65 | policy, 66 | config, 67 | config.seed, 68 | config.local_run_dir, 69 | reference_model=reference_model, 70 | rank=rank, 71 | world_size=world_size, 72 | ) 73 | 74 | trainer.train_loop() 75 | 76 | 77 | @hydra.main(version_base=None, config_path="config", config_name="config") 78 | def main(config: DictConfig): 79 | """ 80 | Main entry point for training. 81 | Validates config, creates/initializes model(s), 82 | and kicks off worker process(es). 83 | """ 84 | # Resolve hydra references, e.g. 
so we don't re-compute the run directory 85 | OmegaConf.resolve(config) 86 | 87 | missing_keys: Set[str] = OmegaConf.missing_keys(config) 88 | if missing_keys: 89 | raise ValueError(f"Got missing keys in config:\n{missing_keys}") 90 | 91 | if config.eval_every % config.batch_size != 0: 92 | print("WARNING: eval_every must be divisible by batch_size") 93 | print( 94 | "Setting eval_every to", 95 | config.eval_every - config.eval_every % config.batch_size, 96 | ) 97 | config.eval_every = ( 98 | config.eval_every - config.eval_every % config.batch_size 99 | ) 100 | 101 | if "FSDP" in config.trainer and config.fsdp_port is None: 102 | free_port = get_open_port() 103 | print("no FSDP port specified; using open port for FSDP:", free_port) 104 | config.fsdp_port = free_port 105 | 106 | print(OmegaConf.to_yaml(config)) 107 | 108 | config_path = os.path.join(config.local_run_dir, "config.yaml") 109 | with open(config_path, "w") as f: 110 | OmegaConf.save(config, f) 111 | 112 | print("=" * 80) 113 | print(f"Writing to {socket.gethostname()}:{config.local_run_dir}") 114 | print("=" * 80) 115 | 116 | os.environ["XDG_CACHE_HOME"] = get_local_dir(config.local_dirs) 117 | print("building policy") 118 | model_kwargs = ( 119 | {"device_map": "balanced"} if config.trainer == "BasicTrainer" else {} 120 | ) 121 | policy_dtype = getattr(torch, config.model.policy_dtype) 122 | policy = transformers.AutoModelForCausalLM.from_pretrained( 123 | config.model.name_or_path, 124 | cache_dir=get_local_dir(config.local_dirs), 125 | low_cpu_mem_usage=True, 126 | torch_dtype=policy_dtype, 127 | **model_kwargs, 128 | ) 129 | disable_dropout(policy) 130 | 131 | if config.loss.name == "dpo": 132 | print("building reference model") 133 | reference_model_dtype = getattr(torch, config.model.reference_dtype) 134 | reference_model = transformers.AutoModelForCausalLM.from_pretrained( 135 | config.model.name_or_path, 136 | cache_dir=get_local_dir(config.local_dirs), 137 | low_cpu_mem_usage=True, 138 | torch_dtype=reference_model_dtype, 139 | **model_kwargs, 140 | ) 141 | disable_dropout(reference_model) 142 | else: 143 | reference_model = None 144 | 145 | if config.model.archive is not None: 146 | state_dict = torch.load(config.model.archive, map_location="cpu") 147 | step, metrics = state_dict["step_idx"], state_dict["metrics"] 148 | print( 149 | f"loading pre-trained weights at step {step} from \ 150 | {config.model.archive} with metrics \ 151 | {json.dumps(metrics, indent=2)}" 152 | ) 153 | policy.load_state_dict(state_dict["state"]) 154 | if config.loss.name == "dpo": 155 | reference_model.load_state_dict(state_dict["state"]) 156 | 157 | print("loaded pre-trained weights") 158 | 159 | if "FSDP" in config.trainer: 160 | world_size = torch.cuda.device_count() 161 | print("starting", world_size, "processes for FSDP training") 162 | soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) 163 | resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard)) 164 | print(f"setting RLIMIT_NOFILE soft limit to {hard} from {soft}") 165 | mp.spawn( 166 | worker_main, 167 | nprocs=world_size, 168 | args=(world_size, config, policy, reference_model), 169 | join=True, 170 | ) 171 | else: 172 | print("starting single-process worker") 173 | worker_main(0, 1, config, policy, reference_model) 174 | 175 | 176 | if __name__ == "__main__": 177 | main() 178 | -------------------------------------------------------------------------------- /toxicity/train_dpo/trainers.py: -------------------------------------------------------------------------------- 1 | 
""" 2 | Train loop for DPO. 3 | """ 4 | from typing import Optional, Dict, List, Union, Tuple 5 | 6 | import random 7 | import os 8 | from collections import defaultdict 9 | import time 10 | import json 11 | import functools 12 | import contextlib 13 | from collections import Counter 14 | 15 | import numpy as np 16 | import wandb 17 | from tqdm import tqdm 18 | 19 | import torch 20 | import torch.nn.functional as F 21 | import torch.nn as nn 22 | from torch.nn import KLDivLoss 23 | import torch.distributed as dist 24 | from torch.distributed.fsdp import ( 25 | FullyShardedDataParallel as FSDP, 26 | MixedPrecision, 27 | StateDictType, 28 | BackwardPrefetch, 29 | ShardingStrategy, 30 | CPUOffload, 31 | ) 32 | from torch.distributed.fsdp.api import ( 33 | FullStateDictConfig, 34 | FullOptimStateDictConfig, 35 | ) 36 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 37 | import transformers 38 | from omegaconf import DictConfig 39 | 40 | from toxicity.train_dpo.pplm_dataset import get_pplm_batch_iterator 41 | from toxicity.train_dpo.dpo_utils import ( 42 | slice_and_move_batch_for_device, 43 | formatted_dict, 44 | all_gather_if_needed, 45 | pad_to_length, 46 | get_block_class_from_model, 47 | rank0_print, 48 | get_local_dir, 49 | ) 50 | from constants import GPT2_PAD_IDX 51 | 52 | torch.backends.cuda.matmul.allow_tf32 = True 53 | 54 | 55 | def generate( 56 | model, 57 | batch, 58 | max_new_tokens, 59 | pad_token_id, 60 | include_ngram_blocked=False, 61 | include_ref=False, 62 | fsdp=False, 63 | ref_model=None, 64 | ): 65 | """ 66 | Return greedy and n-gram blocked generations. 67 | """ 68 | prompt_shape = batch["prompt_input_ids"].shape[1] 69 | with torch.no_grad(): 70 | # FSDP generation according to https://github.com/pytorch/pytorch/issues/100069 71 | ctx = lambda: ( 72 | FSDP.summon_full_params(model, writeback=False, recurse=False) 73 | if fsdp 74 | else contextlib.nullcontext() 75 | ) 76 | with ctx(): 77 | greedy_resp = model.generate( 78 | input_ids=batch["prompt_input_ids"], 79 | attention_mask=batch["prompt_attention_mask"], 80 | max_new_tokens=max_new_tokens, 81 | do_sample=False, 82 | pad_token_id=pad_token_id, 83 | ) 84 | 85 | greedy_resp_labels = greedy_resp.detach().clone() 86 | greedy_resp_labels[:, :prompt_shape] = -100 87 | output = { 88 | "policy_input_ids": greedy_resp, 89 | "policy_attention_mask": greedy_resp != GPT2_PAD_IDX, 90 | "policy_labels": greedy_resp_labels, 91 | } 92 | 93 | return output 94 | 95 | 96 | def dpo_loss( 97 | policy_pos_logps: torch.FloatTensor, 98 | policy_neg_logps: torch.FloatTensor, 99 | ref_pos_logps: torch.FloatTensor, 100 | ref_neg_logps: torch.FloatTensor, 101 | beta: float, 102 | reference_free: bool = False, 103 | ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: 104 | """ 105 | Compute the DPO loss for a batch of policy and reference model log probabilities. 106 | 107 | :params: 108 | 109 | :policy_pos_logps: logprobs of positive responses from policy model: (batch_size,) 110 | :policy_neg_logps: logprobs of negative responses from policy model: (batch_size,) 111 | :ref_pos_logps: logprobs of positive responses from reference model: (batch_size,) 112 | :ref_neg_logps: logprobs of negative responses from reference model: (batch_size,) 113 | :beta: Temperature parameter for the DPO loss, typically something 114 | in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0. 
115 | :reference_free: If True, we ignore the _provided_ reference model and 116 | implicitly use a reference model that assigns equal probability to all responses. 117 | 118 | :returns: 119 | 120 | A tuple of three tensors: (losses, pos_rewards, neg_rewards). 121 | The losses tensor contains the DPO loss for each example in the batch. 122 | The pos_rewards and neg_rewards tensors contain the rewards for the 123 | positive and neg responses, respectively. 124 | """ 125 | pi_logratios = policy_pos_logps - policy_neg_logps 126 | ref_logratios = ref_pos_logps - ref_neg_logps 127 | 128 | if reference_free: 129 | ref_logratios = 0 130 | 131 | logits = pi_logratios - ref_logratios 132 | 133 | losses = -F.logsigmoid(beta * logits) 134 | pos_rewards = beta * (policy_pos_logps - ref_pos_logps).detach() 135 | neg_rewards = beta * (policy_neg_logps - ref_neg_logps).detach() 136 | 137 | return losses, pos_rewards, neg_rewards 138 | 139 | 140 | def get_kl_div( 141 | kl_criterion: KLDivLoss, 142 | pos_pi_logits: torch.FloatTensor, # [batch, seq, vocab] 143 | neg_pi_logits: torch.FloatTensor, # [batch, seq, vocab] 144 | pos_ref_logits: torch.FloatTensor, # [batch, seq, vocab] 145 | neg_ref_logits: torch.FloatTensor, # [batch, seq, vocab] 146 | ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: 147 | """ 148 | Return KL Loss. 149 | """ 150 | # [batch, seq, vocab] --> [batch] 151 | pos_kl_div = ( 152 | kl_criterion( 153 | F.log_softmax(pos_pi_logits, dim=-1), 154 | F.log_softmax(pos_ref_logits, dim=-1), 155 | ) 156 | .sum(dim=-1) 157 | .mean(dim=-1) 158 | ) 159 | neg_kl_div = ( 160 | kl_criterion( 161 | F.log_softmax(neg_pi_logits, dim=-1), 162 | F.log_softmax(neg_ref_logits, dim=-1), 163 | ) 164 | .sum(dim=-1) 165 | .mean(dim=-1) 166 | ) 167 | return pos_kl_div, neg_kl_div 168 | 169 | 170 | def get_batch_logps( 171 | logits: torch.FloatTensor, 172 | input_ids: torch.FloatTensor, 173 | average_log_prob: bool = False, 174 | ) -> torch.FloatTensor: 175 | """ 176 | Compute the log probabilities of the given labels under the given logits. 177 | 178 | :params: 179 | 180 | :logits: Logits of the model (unnormalized). (batch, seq, vocab) 181 | :labels: Labels for which to compute the log probabilities. 182 | Label tokens with a value of -100 are ignored. (batch, seq) 183 | :average_log_prob: If True, return the average log probability per 184 | (non-masked) token. Otherwise, return the sum of the log probabilities 185 | of the (non-masked) tokens. 186 | 187 | Returns: 188 | A tensor of shape (batch_size,) containing the average/sum log 189 | probabilities of the given labels under the given logits. 190 | """ 191 | # [batch, seq] 192 | labels = input_ids[:, 1:].clone() 193 | logits = logits[:, :-1, :] 194 | loss_mask = labels != GPT2_PAD_IDX 195 | 196 | # dummy token; we'll ignore the losses on these tokens later 197 | labels[labels == GPT2_PAD_IDX] = 0 198 | 199 | per_token_logps = torch.gather( 200 | logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2) 201 | ).squeeze(2) 202 | 203 | if average_log_prob: 204 | return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) 205 | else: 206 | return (per_token_logps * loss_mask).sum(-1) 207 | 208 | 209 | def concatenated_inputs( 210 | batch: Dict[str, Union[List, torch.LongTensor]] 211 | ) -> Dict[str, torch.LongTensor]: 212 | """ 213 | Concatenate the positive and negative inputs into a single tensor. 214 | 215 | :params: 216 | 217 | :batch: A batch of data. 
Must contain the keys 'pos_input_ids' and 218 | 'neg_input_ids', which are tensors of shape (batch, seq). 219 | 220 | :returns: 221 | A dictionary containing the concatenated inputs under the key 222 | 'concatenated_input_ids'. 223 | """ 224 | max_length = max( 225 | batch["pos_input_ids"].shape[1], 226 | batch["neg_input_ids"].shape[1], 227 | ) 228 | concatenated_batch = {} 229 | for k in batch: 230 | if k.startswith("pos_") and isinstance(batch[k], torch.Tensor): 231 | pad_value = -100 if "labels" in k else 0 232 | concatenated_key = k.replace("pos", "concatenated") 233 | concatenated_batch[concatenated_key] = pad_to_length( 234 | batch[k], max_length, pad_value=pad_value 235 | ) 236 | for k in batch: 237 | if k.startswith("neg_") and isinstance(batch[k], torch.Tensor): 238 | pad_value = -100 if "labels" in k else 0 239 | concatenated_key = k.replace("neg", "concatenated") 240 | concatenated_batch[concatenated_key] = torch.cat( 241 | ( 242 | concatenated_batch[concatenated_key], 243 | pad_to_length(batch[k], max_length, pad_value=pad_value), 244 | ), 245 | dim=0, 246 | ) 247 | return concatenated_batch 248 | 249 | 250 | class BasicTrainer(object): 251 | def __init__( 252 | self, 253 | policy: nn.Module, 254 | config: DictConfig, 255 | seed: int, 256 | run_dir: str, 257 | reference_model: Optional[nn.Module] = None, 258 | rank: int = 0, 259 | world_size: int = 1, 260 | ): 261 | """ 262 | A trainer for a language model, supporting either SFT or DPO training. 263 | 264 | If multiple GPUs are present, naively splits the model across them, effectively 265 | offering N times available memory, but without any parallel computation. 266 | """ 267 | self.seed = seed 268 | self.rank = rank 269 | self.world_size = world_size 270 | self.config = config 271 | self.run_dir = run_dir 272 | self.example_counter = 0 273 | self.batch_counter = 0 274 | self.last_log = None 275 | self.patience = 0 276 | self.val_metric_value = -1 277 | if config.validation_direction == "max": 278 | self.val_direction = 1 279 | self.best_val_metric = -1 280 | 281 | else: 282 | self.val_direction = -1 283 | self.best_val_metric = 1e10 284 | 285 | tokenizer_name_or_path = ( 286 | config.model.tokenizer_name_or_path or config.model.name_or_path 287 | ) 288 | rank0_print(f"Loading tokenizer {tokenizer_name_or_path}") 289 | self.tokenizer = transformers.AutoTokenizer.from_pretrained( 290 | tokenizer_name_or_path, cache_dir=get_local_dir(config.local_dirs) 291 | ) 292 | if tokenizer_name_or_path.startswith("gpt2"): 293 | self.tokenizer.padding_side = "left" 294 | if self.tokenizer.pad_token_id is None: 295 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id 296 | 297 | self.policy = policy 298 | self.reference_model = reference_model 299 | self.kl_criterion = KLDivLoss(reduction="none", log_target=True) 300 | 301 | self.train_iterator = get_pplm_batch_iterator( 302 | self.tokenizer, 303 | self.config, 304 | split="train", 305 | ) 306 | self.eval_iterator = get_pplm_batch_iterator( 307 | self.tokenizer, 308 | self.config, 309 | split="valid", 310 | ) 311 | self.eval_batches = list(self.eval_iterator) 312 | rank0_print( 313 | f"Loaded {len(self.eval_batches)} eval batches of size {config.eval_batch_size}" 314 | ) 315 | 316 | def get_batch_samples( 317 | self, batch: Dict[str, torch.LongTensor] 318 | ) -> Tuple[str, str]: 319 | """ 320 | Generate samples from the policy (and reference model, if doing DPO training) 321 | for the given batch of inputs 322 | """ 323 | 324 | # FSDP generation according to 
https://github.com/pytorch/pytorch/issues/100069 325 | ctx = lambda: ( 326 | FSDP.summon_full_params( 327 | self.policy, writeback=False, recurse=False 328 | ) 329 | if "FSDP" in self.config.trainer 330 | else contextlib.nullcontext() 331 | ) 332 | with ctx(): 333 | policy_output = self.policy.generate( 334 | batch["prompt_input_ids"], 335 | attention_mask=batch["prompt_attention_mask"], 336 | max_length=self.config.max_length, 337 | do_sample=False, 338 | pad_token_id=self.tokenizer.pad_token_id, 339 | ) 340 | 341 | if self.config.loss.name == "dpo": 342 | ctx = lambda: ( 343 | FSDP.summon_full_params( 344 | self.reference_model, writeback=False, recurse=False 345 | ) 346 | if "FSDP" in self.config.trainer 347 | else contextlib.nullcontext() 348 | ) 349 | with ctx(): 350 | reference_output = self.reference_model.generate( 351 | batch["prompt_input_ids"], 352 | attention_mask=batch["prompt_attention_mask"], 353 | max_length=self.config.max_length, 354 | do_sample=False, 355 | pad_token_id=self.tokenizer.pad_token_id, 356 | ) 357 | 358 | policy_output = pad_to_length( 359 | policy_output, self.config.max_length, self.tokenizer.pad_token_id 360 | ) 361 | policy_output = all_gather_if_needed( 362 | policy_output, self.rank, self.world_size 363 | ) 364 | policy_output_decoded = self.tokenizer.batch_decode( 365 | policy_output, skip_special_tokens=True 366 | ) 367 | 368 | reference_output_decoded = [] 369 | if self.config.loss.name == "dpo": 370 | reference_output = pad_to_length( 371 | reference_output, 372 | self.config.max_length, 373 | self.tokenizer.pad_token_id, 374 | ) 375 | reference_output = all_gather_if_needed( 376 | reference_output, self.rank, self.world_size 377 | ) 378 | reference_output_decoded = self.tokenizer.batch_decode( 379 | reference_output, skip_special_tokens=True 380 | ) 381 | 382 | return policy_output_decoded, reference_output_decoded 383 | 384 | def concatenated_forward( 385 | self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]] 386 | ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: 387 | """ 388 | Run the given model on the given batch of inputs, 389 | concatenating the positive and negative inputs together. 390 | 391 | We do this to avoid doing two forward passes, because it's faster for FSDP. 392 | 393 | :returns: 394 | :pos_logps: (batch) 395 | :neg_logps: (batch) 396 | :pos_logits: (batch, seq, vocab) 397 | :neg_logits: (batch, seq, vocab) 398 | """ 399 | concatenated_batch = concatenated_inputs(batch) 400 | 401 | # [batch (*2), seq (prompt + response), vocab] 402 | all_logits = model( 403 | concatenated_batch["concatenated_input_ids"], 404 | attention_mask=concatenated_batch["concatenated_attention_mask"], 405 | ).logits.to(torch.float32) 406 | all_logps = get_batch_logps( 407 | all_logits, 408 | concatenated_batch["concatenated_input_ids"], 409 | average_log_prob=False, 410 | ) 411 | 412 | num_pos_samples = batch["pos_input_ids"].shape[0] 413 | pos_logps = all_logps[:num_pos_samples] 414 | neg_logps = all_logps[num_pos_samples:] 415 | pos_logits = all_logits[:num_pos_samples] 416 | neg_logits = all_logits[num_pos_samples:] 417 | return pos_logps, neg_logps, pos_logits, neg_logits 418 | 419 | def get_batch_metrics( 420 | self, 421 | batch: Dict[str, Union[List, torch.LongTensor]], 422 | loss_config: DictConfig, 423 | train=True, 424 | ): 425 | """ 426 | Compute the SFT or DPO loss and other metrics for the given batch of inputs. 
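Returns a tuple of (mean loss over the batch, metrics dict). For DPO, the metrics include per-example positive/negative rewards, reward accuracies and margins, negative-response log-probs, and per-sequence KL divergences from the reference model; for SFT, only the positive log-probs and the loss are recorded.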
427 | """ 428 | 429 | metrics = {} 430 | train_test = "train" if train else "valid" 431 | kl_loss = None 432 | 433 | if loss_config.name == "dpo": 434 | ( 435 | policy_pos_logps, 436 | policy_neg_logps, 437 | policy_pos_logits, 438 | policy_neg_logits, 439 | ) = self.concatenated_forward(self.policy, batch) 440 | with torch.no_grad(): 441 | ( 442 | ref_pos_logps, 443 | ref_neg_logps, 444 | ref_pos_logits, 445 | ref_neg_logits, 446 | ) = self.concatenated_forward(self.reference_model, batch) 447 | losses, pos_rewards, neg_rewards = dpo_loss( 448 | policy_pos_logps, 449 | policy_neg_logps, 450 | ref_pos_logps, 451 | ref_neg_logps, 452 | beta=loss_config.beta, 453 | reference_free=loss_config.reference_free, 454 | ) 455 | 456 | pos_kl_div, neg_kl_div = get_kl_div( 457 | self.kl_criterion, 458 | policy_pos_logits, 459 | policy_neg_logits, 460 | ref_pos_logits, 461 | ref_neg_logits, 462 | ) 463 | 464 | reward_accuracies = (pos_rewards > neg_rewards).float() 465 | 466 | pos_rewards = all_gather_if_needed( 467 | pos_rewards, self.rank, self.world_size 468 | ) 469 | neg_rewards = all_gather_if_needed( 470 | neg_rewards, self.rank, self.world_size 471 | ) 472 | reward_accuracies = all_gather_if_needed( 473 | reward_accuracies, self.rank, self.world_size 474 | ) 475 | 476 | metrics[f"rewards_{train_test}/positive"] = ( 477 | pos_rewards.cpu().numpy().tolist() 478 | ) 479 | metrics[f"rewards_{train_test}/negative"] = ( 480 | neg_rewards.cpu().numpy().tolist() 481 | ) 482 | metrics[f"rewards_{train_test}/accuracies"] = ( 483 | reward_accuracies.cpu().numpy().tolist() 484 | ) 485 | metrics[f"rewards_{train_test}/margins"] = ( 486 | (pos_rewards - neg_rewards).cpu().numpy().tolist() 487 | ) 488 | 489 | policy_neg_logps = all_gather_if_needed( 490 | policy_neg_logps.detach(), self.rank, self.world_size 491 | ) 492 | metrics[f"logps_{train_test}/negative"] = ( 493 | policy_neg_logps.cpu().numpy().tolist() 494 | ) 495 | 496 | metrics[f"kl_div_{train_test}/positive"] = ( 497 | pos_kl_div.detach().cpu().numpy().tolist() 498 | ) 499 | 500 | metrics[f"kl_div_{train_test}/negative"] = ( 501 | neg_kl_div.detach().cpu().numpy().tolist() 502 | ) 503 | 504 | elif loss_config.name == "sft": 505 | policy_pos_logits = self.policy( 506 | batch["pos_input_ids"], 507 | attention_mask=batch["pos_attention_mask"], 508 | ).logits.to(torch.float32) 509 | policy_pos_logps = get_batch_logps( 510 | policy_pos_logits, 511 | batch["pos_labels"], 512 | average_log_prob=False, 513 | ) 514 | 515 | losses = -policy_pos_logps 516 | 517 | policy_pos_logps = all_gather_if_needed( 518 | policy_pos_logps.detach(), self.rank, self.world_size 519 | ) 520 | metrics[f"logps_{train_test}/positive"] = ( 521 | policy_pos_logps.cpu().numpy().tolist() 522 | ) 523 | 524 | all_devices_losses = all_gather_if_needed( 525 | losses.detach(), self.rank, self.world_size 526 | ) 527 | metrics[f"loss/{train_test}"] = ( 528 | all_devices_losses.cpu().numpy().tolist() 529 | ) 530 | 531 | return losses.mean(), metrics 532 | 533 | def train_loop(self): 534 | """Begin either SFT or DPO training, with periodic evaluation.""" 535 | 536 | rank0_print(f"Using {self.config.optimizer} optimizer") 537 | self.optimizer = getattr(torch.optim, self.config.optimizer)( 538 | self.policy.parameters(), lr=self.config.lr 539 | ) 540 | self.scheduler = torch.optim.lr_scheduler.LambdaLR( 541 | self.optimizer, 542 | lr_lambda=lambda step: min( 543 | 1.0, (step + 1) / (self.config.warmup_steps + 1) 544 | ), 545 | ) 546 | 547 | torch.manual_seed(self.seed) 548 | 
np.random.seed(self.seed) 549 | random.seed(self.seed) 550 | 551 | if self.config.loss.name == "dpo": 552 | self.reference_model.eval() 553 | 554 | for batch in self.train_iterator: 555 | if self.example_counter % self.config.eval_every == 0 and ( 556 | self.example_counter > 0 or self.config.do_first_eval 557 | ): 558 | result = self.eval() 559 | if result == -1: 560 | return 561 | 562 | self.train(batch) 563 | 564 | def train(self, batch): 565 | """ 566 | Run single train step. 567 | """ 568 | self.policy.train() 569 | 570 | start_time = time.time() 571 | batch_metrics = defaultdict(list) 572 | for microbatch_idx in range(self.config.gradient_accumulation_steps): 573 | # batch: 574 | # { 575 | # "pos_input_ids": Tensor[batch, seq], 576 | # "pos_attention_mask": Tensor[batch, seq], 577 | # "neg_input_ids": Tensor[batch, seq], 578 | # "neg_attention_mask": Tensor[batch, seq], 579 | # } 580 | self.policy.train() 581 | global_microbatch = slice_and_move_batch_for_device( 582 | batch, 583 | microbatch_idx, 584 | self.config.gradient_accumulation_steps, 585 | self.rank, 586 | ) 587 | local_microbatch = slice_and_move_batch_for_device( 588 | global_microbatch, self.rank, self.world_size, self.rank 589 | ) 590 | loss, metrics = self.get_batch_metrics( 591 | local_microbatch, self.config.loss, train=True 592 | ) 593 | (loss / self.config.gradient_accumulation_steps).backward() 594 | 595 | for k, v in metrics.items(): 596 | batch_metrics[k].extend(v) 597 | 598 | grad_norm = self.clip_gradient() 599 | self.optimizer.step() 600 | self.scheduler.step() 601 | self.optimizer.zero_grad() 602 | 603 | step_time = time.time() - start_time 604 | examples_per_second = self.config.batch_size / step_time 605 | batch_metrics["examples_per_second"].append(examples_per_second) 606 | batch_metrics["grad_norm"].append(grad_norm) 607 | 608 | self.batch_counter += 1 609 | self.example_counter += self.config.batch_size 610 | 611 | if ( 612 | self.last_log is None 613 | or time.time() - self.last_log 614 | > self.config.minimum_log_interval_secs 615 | ): 616 | mean_train_metrics = { 617 | k: sum(v) / len(v) for k, v in batch_metrics.items() 618 | } 619 | mean_train_metrics["counters/examples"] = self.example_counter 620 | mean_train_metrics["counters/updates"] = self.batch_counter 621 | rank0_print( 622 | f"train stats after {self.example_counter} examples: {formatted_dict(mean_train_metrics)}" 623 | ) 624 | 625 | if self.config.wandb.enabled and self.rank == 0: 626 | wandb.log(mean_train_metrics, step=self.example_counter) 627 | 628 | self.last_log = time.time() 629 | 630 | def eval(self): 631 | """ 632 | Run evaluation. 
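Returns 0 normally and -1 once validation_patience consecutive evaluations pass without improving validation_metric, which tells train_loop to stop early. When a new best metric is observed, a checkpoint is written under <run_dir>/checkpoints every save_every examples (skipped in debug mode).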
633 | """ 634 | rank0_print( 635 | f"Running evaluation after {self.example_counter} train examples" 636 | ) 637 | self.policy.eval() 638 | 639 | all_eval_metrics = defaultdict(list) 640 | if self.config.sample_during_eval: 641 | all_policy_samples, all_reference_samples = [], [] 642 | 643 | for eval_batch in ( 644 | tqdm(self.eval_batches, desc="Computing eval metrics") 645 | if self.rank == 0 646 | else self.eval_batches 647 | ): 648 | 649 | local_eval_batch = slice_and_move_batch_for_device( 650 | eval_batch, self.rank, self.world_size, self.rank 651 | ) 652 | with torch.no_grad(): 653 | _, eval_metrics = self.get_batch_metrics( 654 | local_eval_batch, self.config.loss, train=False 655 | ) 656 | 657 | for k, v in eval_metrics.items(): 658 | all_eval_metrics[k].extend(v) 659 | 660 | if ( 661 | self.config.sample_during_eval 662 | and self.example_counter % self.config.sample_every == 0 663 | ): 664 | if self.config.n_eval_model_samples < self.config.eval_batch_size: 665 | rank0_print( 666 | f"Warning: n_eval_model_samples ({self.config.n_eval_model_samples}) < \ 667 | eval_batch_size ({self.config.eval_batch_size}). \ 668 | Sampling from the first complete eval batch of prompts." 669 | ) 670 | sample_batches = self.eval_batches[:1] 671 | else: 672 | n_sample_batches = ( 673 | self.config.n_eval_model_samples 674 | // self.config.eval_batch_size 675 | ) 676 | sample_batches = self.eval_batches[:n_sample_batches] 677 | 678 | for eval_batch in ( 679 | tqdm(sample_batches, desc="Generating samples...") 680 | if self.rank == 0 681 | else sample_batches 682 | ): 683 | local_eval_batch = slice_and_move_batch_for_device( 684 | eval_batch, self.rank, self.world_size, self.rank 685 | ) 686 | ( 687 | policy_samples, 688 | reference_samples, 689 | ) = self.get_batch_samples(local_eval_batch) 690 | 691 | all_policy_samples.extend(policy_samples) 692 | all_reference_samples.extend(reference_samples) 693 | 694 | rank0_print("Policy samples:") 695 | rank0_print(json.dumps(all_policy_samples[:10], indent=2)) 696 | 697 | mean_eval_metrics = { 698 | k: sum(v) / len(v) for k, v in all_eval_metrics.items() 699 | } 700 | self.val_metric_value = mean_eval_metrics[ 701 | self.config.validation_metric 702 | ] 703 | 704 | rank0_print( 705 | f"eval after {self.example_counter}: {formatted_dict(mean_eval_metrics)}" 706 | ) 707 | 708 | if self.config.wandb.enabled and self.rank == 0: 709 | wandb.log(mean_eval_metrics, step=self.example_counter) 710 | 711 | if self.example_counter == 0: 712 | return 0 713 | 714 | if ( 715 | self.val_metric_value is not None 716 | and self.val_metric_value * self.val_direction 717 | > self.val_direction * self.best_val_metric 718 | ): 719 | self.best_val_metric = self.val_metric_value 720 | 721 | rank0_print( 722 | f"\n=====\nNew best for {self.config.validation_metric}: {self.best_val_metric}.\n=====\n" 723 | ) 724 | self.patience = 0 725 | 726 | if self.example_counter % self.config.save_every == 0: 727 | if self.config.debug: 728 | rank0_print("skipping save in debug mode") 729 | else: 730 | output_dir = os.path.join(self.run_dir, "checkpoints") 731 | rank0_print( 732 | f"Creating checkpoint to write to {output_dir}..." 
733 | ) 734 | self.save(output_dir, mean_eval_metrics) 735 | else: 736 | self.patience += 1 737 | if self.patience >= self.config.validation_patience: 738 | rank0_print("Ran out of patience, stopping training...") 739 | return -1 740 | 741 | return 0 742 | 743 | def clip_gradient(self): 744 | """Clip the gradient norm of the parameters of a non-FSDP policy.""" 745 | return torch.nn.utils.clip_grad_norm_( 746 | self.policy.parameters(), self.config.max_grad_norm 747 | ).item() 748 | 749 | def write_state_dict( 750 | self, 751 | step: int, 752 | state: Dict[str, torch.Tensor], 753 | metrics: Dict, 754 | filename: str, 755 | dir_name: Optional[str] = None, 756 | ): 757 | """Write a checkpoint to disk.""" 758 | if dir_name is None: 759 | dir_name = os.path.join(self.run_dir, f"LATEST") 760 | 761 | os.makedirs(dir_name, exist_ok=True) 762 | output_path = os.path.join(dir_name, filename) 763 | rank0_print(f"writing checkpoint to {output_path}...") 764 | torch.save( 765 | { 766 | "step_idx": step, 767 | "state": state, 768 | "metrics": metrics if metrics is not None else {}, 769 | }, 770 | output_path, 771 | ) 772 | 773 | def save( 774 | self, output_dir: Optional[str] = None, metrics: Optional[Dict] = None 775 | ): 776 | """Save policy, optimizer, and scheduler state to disk.""" 777 | 778 | policy_state_dict = self.policy.state_dict() 779 | self.write_state_dict( 780 | self.example_counter, 781 | policy_state_dict, 782 | metrics, 783 | "policy.pt", 784 | output_dir, 785 | ) 786 | del policy_state_dict 787 | 788 | optimizer_state_dict = self.optimizer.state_dict() 789 | self.write_state_dict( 790 | self.example_counter, 791 | optimizer_state_dict, 792 | metrics, 793 | "optimizer.pt", 794 | output_dir, 795 | ) 796 | del optimizer_state_dict 797 | 798 | scheduler_state_dict = self.scheduler.state_dict() 799 | self.write_state_dict( 800 | self.example_counter, 801 | scheduler_state_dict, 802 | metrics, 803 | "scheduler.pt", 804 | output_dir, 805 | ) 806 | 807 | 808 | class FSDPTrainer(BasicTrainer): 809 | def __init__( 810 | self, 811 | policy: nn.Module, 812 | config: DictConfig, 813 | seed: int, 814 | run_dir: str, 815 | reference_model: Optional[nn.Module] = None, 816 | rank: int = 0, 817 | world_size: int = 1, 818 | ): 819 | """A trainer subclass that uses PyTorch FSDP to shard the model across multiple GPUs. 820 | 821 | This trainer will shard both the policy and reference model across all available GPUs. 822 | Models are sharded at the block level, where the block class name is provided in the config. 
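For the GPT-2 variants used in this repo, block_name is GPT2Block (see the configs under config/model), so each transformer block becomes its own FSDP wrapping unit via transformer_auto_wrap_policy.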
823 | """ 824 | 825 | super().__init__( 826 | policy, config, seed, run_dir, reference_model, rank, world_size 827 | ) 828 | assert ( 829 | config.model.block_name is not None 830 | ), "must specify model.block_name (e.g., GPT2Block or GPTNeoXLayer) for FSDP" 831 | 832 | wrap_class = get_block_class_from_model( 833 | policy, config.model.block_name 834 | ) 835 | model_auto_wrap_policy = functools.partial( 836 | transformer_auto_wrap_policy, 837 | transformer_layer_cls={wrap_class}, 838 | ) 839 | 840 | shared_fsdp_kwargs = dict( 841 | auto_wrap_policy=model_auto_wrap_policy, 842 | sharding_strategy=ShardingStrategy.FULL_SHARD, 843 | cpu_offload=CPUOffload(offload_params=False), 844 | backward_prefetch=BackwardPrefetch.BACKWARD_PRE, 845 | device_id=rank, 846 | ignored_modules=None, 847 | limit_all_gathers=False, 848 | use_orig_params=False, 849 | sync_module_states=False, 850 | ) 851 | 852 | rank0_print("Sharding policy...") 853 | mp_dtype = ( 854 | getattr(torch, config.model.fsdp_policy_mp) 855 | if config.model.fsdp_policy_mp is not None 856 | else None 857 | ) 858 | policy_mp_policy = MixedPrecision( 859 | param_dtype=mp_dtype, reduce_dtype=mp_dtype, buffer_dtype=mp_dtype 860 | ) 861 | self.policy = FSDP( 862 | policy, **shared_fsdp_kwargs, mixed_precision=policy_mp_policy 863 | ) 864 | 865 | if config.activation_checkpointing: 866 | rank0_print("Attempting to enable activation checkpointing...") 867 | try: 868 | # use activation checkpointing, according to: 869 | # https://pytorch.org/blog/scaling-multimodal-foundation-models-in-torchmultimodal-with-pytorch-distributed/ 870 | # 871 | # first, verify we have FSDP activation support ready by importing: 872 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( 873 | checkpoint_wrapper, 874 | apply_activation_checkpointing, 875 | CheckpointImpl, 876 | ) 877 | 878 | non_reentrant_wrapper = functools.partial( 879 | checkpoint_wrapper, 880 | offload_to_cpu=False, 881 | checkpoint_impl=CheckpointImpl.NO_REENTRANT, 882 | ) 883 | except Exception as e: 884 | rank0_print("FSDP activation checkpointing not available:", e) 885 | else: 886 | check_fn = lambda submodule: isinstance(submodule, wrap_class) 887 | rank0_print( 888 | "Applying activation checkpointing wrapper to policy..." 889 | ) 890 | apply_activation_checkpointing( 891 | self.policy, 892 | checkpoint_wrapper_fn=non_reentrant_wrapper, 893 | check_fn=check_fn, 894 | ) 895 | rank0_print("FSDP activation checkpointing enabled!") 896 | 897 | if config.loss.name == "dpo": 898 | rank0_print("Sharding reference model...") 899 | self.reference_model = FSDP(reference_model, **shared_fsdp_kwargs) 900 | 901 | print("Loaded model on rank", rank) 902 | dist.barrier() 903 | 904 | def clip_gradient(self): 905 | """ 906 | Clip the gradient norm of the parameters of an FSDP policy, 907 | gathering the gradients across all GPUs. 908 | """ 909 | return self.policy.clip_grad_norm_(self.config.max_grad_norm).item() 910 | 911 | def save(self, output_dir=None, metrics=None): 912 | """ 913 | Save policy, optimizer, and scheduler state to disk, 914 | gathering from all processes and saving only on the rank 0 process. 
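Uses FULL_STATE_DICT with rank0_only, offload-to-CPU settings, so the full policy and optimizer states are materialized only on rank 0 before being written via write_state_dict; the scheduler state is saved directly on rank 0.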
915 | """ 916 | save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) 917 | with FSDP.state_dict_type( 918 | self.policy, 919 | StateDictType.FULL_STATE_DICT, 920 | state_dict_config=save_policy, 921 | ): 922 | policy_state_dict = self.policy.state_dict() 923 | 924 | if self.rank == 0: 925 | self.write_state_dict( 926 | self.example_counter, 927 | policy_state_dict, 928 | metrics, 929 | "policy.pt", 930 | output_dir, 931 | ) 932 | del policy_state_dict 933 | dist.barrier() 934 | 935 | save_policy = FullOptimStateDictConfig( 936 | offload_to_cpu=True, rank0_only=True 937 | ) 938 | with FSDP.state_dict_type( 939 | self.policy, 940 | StateDictType.FULL_STATE_DICT, 941 | optim_state_dict_config=save_policy, 942 | ): 943 | optimizer_state_dict = FSDP.optim_state_dict( 944 | self.policy, self.optimizer 945 | ) 946 | 947 | if self.rank == 0: 948 | self.write_state_dict( 949 | self.example_counter, 950 | optimizer_state_dict, 951 | metrics, 952 | "optimizer.pt", 953 | output_dir, 954 | ) 955 | del optimizer_state_dict 956 | dist.barrier() 957 | 958 | if self.rank == 0: 959 | scheduler_state_dict = self.scheduler.state_dict() 960 | self.write_state_dict( 961 | self.example_counter, 962 | scheduler_state_dict, 963 | metrics, 964 | "scheduler.pt", 965 | output_dir, 966 | ) 967 | dist.barrier() 968 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions. 3 | """ 4 | 5 | import os 6 | 7 | VERBOSE = int(os.getenv("VERBOSE", 0))  # cast to int so that VERBOSE=0 in the environment disables verbose output 8 | 9 | 10 | def verbose_print(x): 11 | if VERBOSE: 12 | print(x) 13 | --------------------------------------------------------------------------------
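
For readers skimming the training code, the following is a minimal, self-contained sketch of the preference objective implemented by dpo_loss in toxicity/train_dpo/trainers.py. The function mirrors the repository's implementation, but the variable names and toy log-probabilities below are illustrative only.

import torch
import torch.nn.functional as F


def dpo_loss_sketch(policy_pos_logps, policy_neg_logps, ref_pos_logps, ref_neg_logps,
                    beta=0.1, reference_free=False):
    # log-ratio of preferred vs. dispreferred responses under the policy and the frozen reference
    pi_logratios = policy_pos_logps - policy_neg_logps
    ref_logratios = ref_pos_logps - ref_neg_logps
    if reference_free:
        ref_logratios = 0
    # DPO objective: -log sigmoid(beta * [(policy log-ratio) - (reference log-ratio)])
    losses = -F.logsigmoid(beta * (pi_logratios - ref_logratios))
    # implicit per-example "rewards" logged by the trainer
    pos_rewards = beta * (policy_pos_logps - ref_pos_logps).detach()
    neg_rewards = beta * (policy_neg_logps - ref_neg_logps).detach()
    return losses, pos_rewards, neg_rewards


# Toy sequence-level log-probabilities for a batch of two preference pairs (values are made up).
policy_pos = torch.tensor([-12.0, -20.0])
policy_neg = torch.tensor([-15.0, -19.0])
ref_pos = torch.tensor([-13.0, -20.5])
ref_neg = torch.tensor([-13.5, -18.0])

losses, pos_r, neg_r = dpo_loss_sketch(policy_pos, policy_neg, ref_pos, ref_neg, beta=0.1)
print(losses.mean().item())              # scalar loss that would be backpropagated
print((pos_r > neg_r).float().tolist())  # per-example "reward accuracy" as logged during training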