├── requirements.txt ├── pyproject.toml ├── LICENSE ├── .gitignore ├── README.md └── abliterator.py /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.3.0 2 | einops>=0.8.0 3 | datasets>=2.19.1 4 | scikit-learn>=1.5.0 5 | tqdm>=4.66.4 6 | transformer-lens @ git+https://github.com/TransformerLensOrg/TransformerLens.git@dev 7 | transformers>=4.41.1 8 | jaxtyping>=0.2.28 9 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "abliterator" 3 | version = "0.0.0" 4 | description = "Python library for transformer activation steering and ablation." 5 | readme = "README.md" 6 | requires-python = ">=3.8" 7 | license = {file = "LICENSE"} 8 | 9 | keywords = ["transformers", "steering", "ablation", "interpretability", "machine learning"] 10 | 11 | authors = [] 12 | 13 | maintainers = [] 14 | 15 | dependencies = [] 16 | 17 | [build-system] 18 | requires = ["setuptools>=68.0"] 19 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 FailSpy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | *.dll 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 112 | .pdm.toml 113 | .pdm-python 114 | .pdm-build/ 115 | 116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # abliterator.py 2 | Simple Python library/structure to ablate features in LLMs which are supported by TransformerLens. 3 | 4 | Most of its advantage in workflow comes from being able to enter temporary contexts, quickly cache activations with N samples, refusal direction calculation built-in, and tokenizer utilities. As well as wrapping around certain quirks of TransformerLens. 5 | 6 | If you're interested in notebooking your own orthgonalized model, this library will help save you a LOT of time in performing and measuring experiments to find your best orthogonalization. 7 | 8 | This is ultimately just bits and pieces to make it so to process and experiment with ablation direction turns into shorter, hopefully clearer code without losing track of where things are at; encapsulating a lot of useful logic that you'll find yourself writing if you're looking to do this more. 9 | 10 | This library is so exceedingly barebones for right now, and documentation is slim at the moment (WYSIWYG!). Right now, it feels like a glorified IPython notebook rather than something more broadly useful. 11 | I want to publish this now to lay out the template, and hopefully bring this up to snuff over time. Ideally with the help of the community! 12 | 13 | Right now, this works very well for my own personal workflow, but I would like to systematize and ideally automate this further, and broaden out from the pure "harmless / harmful" feature ablation, to augmentation, and adding additional features 14 | 15 | ## Loading a model in 16 | ```python 17 | import abliterator 18 | 19 | model = "meta-llama/Meta-Llama-3-70B-Instruct" # the huggingface or path to the model you're interested in loading in 20 | dataset = [abliterator.get_harmful_instructions(), abliterator.get_harmless_instructions()] # datasets to be used for caching and testing, split by harmful/harmless 21 | device = 'cuda' # optional: defaults to cuda 22 | n_devices = None # optional: when set to None, defaults to `device.cuda.device_count` 23 | cache_fname = 'my_cached_point.pth' # optional: if you need to save where you left off, you can use `save_activations(filename)` which will write out a file. This is how you load that back in. 24 | activation_layers = None # optional: defaults to ['resid_pre', 'resid_mid', 'resid_post'] which are the residual streams. Setting to None will cache ALL activation layer types 25 | chat_template = None # optional: defaults to Llama-3 instruction template. You can use a format string e.g. ("{instruction}") or a custom class with format function -- it just needs an '.format(instruction="")` function. See abliterator.ChatTemplate for a very basic structure. 26 | negative_toks = [4250] # optional, but highly recommended: ' cannot' in Llama's tokenizer. Tokens you don't want to be seeing. Defaults to my preset for Llama-3 models 27 | positive_toks = [23371, 40914] # optional, but highly recommended: ' Sure' and 'Sure' in Llama's tokenizer. Tokens you want to be seeing, basically. Defaults to my preset for Llama-3 models 28 | 29 | my_model = abliterator.ModelAbliterator( 30 | model, 31 | dataset, 32 | device='cuda', 33 | n_devices=None, 34 | cache_fname=None, 35 | activation_layers=['resid_pre', 'resid_post', 'attn_out', 'mlp_out'], 36 | chat_template="\n{instruction}", 37 | positive_toks=positive_toks, 38 | negative_toks=negative_toks 39 | ) 40 | ``` 41 | 42 | ## Cache activations/sample dataset 43 | Once loaded in, run the model against N samples of harmful, and N samples of harmless so it has some data to work with: 44 | ```python 45 | my_model.cache_activations(N=512,reset=True,preserve_harmless=True) 46 | ``` 47 | `preserve_harmless=True` is generally useful, as it keeps the "desired behaviour" unaltered from any stacked modifications if you run it after some mods. 48 | 49 | ## Saving state 50 | Most of the advantage of this is a lot of groundwork has been laid to make it so you aren't repeating yourself 1000 times just to try one little experiment. 51 | `save_activations('file.pth')` will save your cached activations, and any currently applied modifications to the model's weights to a file so you can restore them next time you load up with `cache_fname='file.pth'` in your ModelAbliterator initialization. 52 | 53 | ## Getting refusal directions from the cached activations 54 | Speaking of modding, here's a simple representation of how to pick, test, and actually apply a direction from a layer's activations: 55 | ```python 56 | refusal_dirs = my_model.refusal_dirs() 57 | testing_dir = refusal_dirs['blocks.18.hook_resid_pre'] 58 | my_model.test_dir(testing_dir, N=32, use_hooks=True) # I recommend use_hooks=True for large models as it can slow things down otherwise, but use_hooks=False can give you more precise scoring to an actual weights modification 59 | ``` 60 | `test_dir` will apply your refusal_dir to the model temporarily, and run against N samples of test data, and return a composite (negative_score, positive_score) from those runs. Generally, you want negative_score to go down, positive_score to go up. 61 | 62 | ### Testing lots of refusal directions 63 | 64 | This is one of the functions included in the library, but it's also useful for showing how this can be generalized to test a whole bunch of directions. 65 | ```python 66 | def find_best_refusal_dir(N=4, use_hooks=True, invert=False): 67 | dirs = self.refusal_dirs(invert=invert) 68 | scores = [] 69 | for direction in tqdm(dirs.items()): 70 | score = self.test_dir(direction[1],N=N,use_hooks=use_hooks)[0] 71 | scores.append((score,direction)) 72 | return sorted(scores,key=lambda x:x[0]) 73 | 74 | ``` 75 | 76 | ## Applying the weights 77 | 78 | And now, to apply it! 79 | ```python 80 | my_amazing_dir = find_best_refusal_dir()[0] 81 | my_model.apply_refusal_dirs([my_amazing_dir],layers=None) 82 | ``` 83 | Note the `layers=None`. You can supply a list here to specify which layers you want to apply the refusal direction to. None will apply it to all writable layers. 84 | 85 | ### Blacklisting specific layers 86 | Sometimes some layers are troublesome no matter what you do. If you're worried about accidentally replacing it, you can blacklist it to prevent any alteration from occurring: 87 | ```python 88 | my_model.blacklist_layer(27) 89 | my_model.blacklist_layer([i for i in range(27,30)]) # it also accepts lists! 90 | ``` 91 | 92 | #### Whitelisting 93 | And naturally, to undo this and make sure a layer can be overwritten: 94 | ``` 95 | my_model.whitelist_layer(27) 96 | ``` 97 | By default, all layers are whitelisted. I recommend blacklisting the first and last couple layers, as those can and will have dramatic effects on outputs. 98 | 99 | Neither of these will provide success/failure states. They will just assure the desired state in running it at that instant. 100 | 101 | ## Benchmarking 102 | Now to make sure you've not damaged the model dramatically after applying some stuff, you can do a test run: 103 | ```python 104 | with my_model: # loads a temporary context with the model 105 | ortho.apply_refusal_dir([my_new_precious_dir]) # Because this was applied in the 'with my_model:', it will be unapplied after coming out. 106 | print(my_model.mse_harmless(N=128)) # While we've got the dir applied, this tells you the Mean Squared Error using the current cached harmless runs as "ground truth" (loss function, effectively) 107 | ``` 108 | 109 | ### Want to see it run? Test it! 110 | ```python 111 | ortho.test(N=16,batch_size = 4) # runs N samples from the harmful test set and prints them for the user. Good way to check the model hasn't completely derailed. 112 | # Note that by default if a test run produces a negative token, it will stop the whole batch and move on to the next. (it will show lots of '!!!!' in Llama-3's case, as that's token ID 0) 113 | 114 | ortho.generate("How much wood could a woodchuck chuck if a woodchuck could chuck wood?") # runs and prints the prompt! 115 | ``` 116 | 117 | ## Utility functions 118 | Documentation coming soon. 119 | 120 | ## How to Save as a HuggingFace model 121 | Functionality coming soon. For now, use PyTorch's saving method, or see my notebook for an idea of how to do this yourself. 122 | 123 | -------------------------------------------------------------------------------- /abliterator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import functools 4 | import einops 5 | import gc 6 | import re 7 | from itertools import islice 8 | 9 | from datasets import load_dataset 10 | from sklearn.model_selection import train_test_split 11 | from tqdm import tqdm 12 | from torch import Tensor 13 | from typing import Callable, Dict, List, Set, Tuple 14 | from transformer_lens import HookedTransformer, utils, ActivationCache, loading 15 | from transformer_lens.hook_points import HookPoint 16 | from transformers import AutoTokenizer, AutoModelForCausalLM 17 | from jaxtyping import Float, Int 18 | 19 | def batch(iterable, n): 20 | it = iter(iterable) 21 | while True: 22 | chunk = list(islice(it, n)) 23 | if not chunk: 24 | break 25 | yield chunk 26 | 27 | def get_harmful_instructions() -> Tuple[List[str], List[str]]: 28 | hf_path = 'Undi95/orthogonal-activation-steering-TOXIC' 29 | dataset = load_dataset(hf_path) 30 | instructions = [i['goal'] for i in dataset['test']] 31 | 32 | train, test = train_test_split(instructions, test_size=0.2, random_state=42) 33 | return train, test 34 | 35 | 36 | def get_harmless_instructions() -> Tuple[List[str], List[str]]: 37 | hf_path = 'tatsu-lab/alpaca' 38 | dataset = load_dataset(hf_path) 39 | # filter for instructions that do not have inputs 40 | instructions = [] 41 | for i in range(len(dataset['train'])): 42 | if dataset['train'][i]['input'].strip() == '': 43 | instructions.append(dataset['train'][i]['instruction']) 44 | 45 | train, test = train_test_split(instructions, test_size=0.2, random_state=42) 46 | return train, test 47 | 48 | def prepare_dataset(dataset:Tuple[List[str], List[str]]|List[str]) -> Tuple[List[str], List[str]]: 49 | if len(dataset) != 2: 50 | # assumed to not be split into train/test 51 | train, test = train_test_split(dataset, test_size=0.1, random_state=42) 52 | else: 53 | train, test = dataset 54 | 55 | return train, test 56 | 57 | def directional_hook( 58 | activation: Float[Tensor, "... d_model"], 59 | hook: HookPoint, 60 | direction: Float[Tensor, "d_model"] 61 | ) -> Float[Tensor, "... d_model"]: 62 | if activation.device != direction.device: 63 | direction = direction.to(activation.device) 64 | 65 | proj = einops.einsum(activation, direction.view(-1, 1), '... d_model, d_model single -> ... single') * direction 66 | return activation - proj 67 | 68 | def clear_mem(): 69 | gc.collect() 70 | torch.cuda.empty_cache() 71 | 72 | def measure_fn(measure: str, input_tensor: Tensor, *args, **kwargs) -> Float[Tensor, '...']: 73 | avail_measures = { 74 | 'mean': torch.mean, 75 | 'median': torch.median, 76 | 'max': torch.max, 77 | 'stack': torch.stack 78 | } 79 | 80 | try: 81 | return avail_measures[measure](input_tensor, *args, **kwargs) 82 | except KeyError: 83 | raise NotImplementedError(f"Unknown measure function '{measure}'. Available measures:" + ', '.join([f"'{str(fn)}'" for fn in avail_measures.keys()]) ) 84 | 85 | class ChatTemplate: 86 | def __init__(self,model,template): 87 | self.model = model 88 | self.template = template 89 | 90 | def format(self,instruction): 91 | return self.template.format(instruction=instruction) 92 | 93 | def __enter__(self): 94 | self.prev = self.model.chat_template 95 | self.model.chat_template = self 96 | return self 97 | 98 | def __exit__(self,exc,exc_value,exc_tb): 99 | self.model.chat_template = self.prev 100 | del self.prev 101 | 102 | 103 | LLAMA3_CHAT_TEMPLATE = """<|start_header_id|>user<|end_header_id|>\n{instruction}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n""" 104 | PHI3_CHAT_TEMPLATE = """<|user|>\n{instruction}<|end|>\n<|assistant|>""" 105 | 106 | class ModelAbliterator: 107 | def __init__( 108 | self, 109 | model: str, 110 | dataset: Tuple[List[str], List[str]]|List[Tuple[List[str], List[str]]], 111 | device: str = 'cuda', 112 | n_devices: int = None, 113 | cache_fname: str = None, 114 | activation_layers: List[str] = ['resid_pre', 'resid_post', 'mlp_out', 'attn_out'], 115 | chat_template: str = None, 116 | positive_toks: List[int]|Tuple[int]|Set[int]|Int[Tensor, '...'] = None, 117 | negative_toks: List[int]|Tuple[int]|Set[int]|Int[Tensor, '...'] = None 118 | ): 119 | self.MODEL_PATH = model 120 | if n_devices is None and torch.cuda.is_available(): 121 | n_devices = torch.cuda.device_count() 122 | elif n_devices is None: 123 | n_devices = 1 124 | 125 | # Save memory 126 | torch.set_grad_enabled(False) 127 | 128 | self.model = HookedTransformer.from_pretrained_no_processing( 129 | model, 130 | n_devices=n_devices, 131 | device=device, 132 | dtype=torch.bfloat16, 133 | default_padding_side='left' 134 | ) 135 | 136 | self.model.requires_grad_(False) 137 | 138 | self.model.tokenizer.padding_side = 'left' 139 | self.model.tokenizer.pad_token = self.model.tokenizer.eos_token 140 | self.chat_template = chat_template or ChatTemplate(self,LLAMA3_CHAT_TEMPLATE) 141 | 142 | self.hidden_size = self.model.cfg.d_model 143 | self.original_state = {k:v.to('cpu') for k,v in self.model.state_dict().items()} 144 | self.harmful = {} 145 | self.harmless = {} 146 | self.modified_layers = {'mlp':{}, 'W_O':{}} 147 | self.checkpoints = [] 148 | 149 | if cache_fname is not None: 150 | outs = torch.load(cache_fname,map_location='cpu') 151 | self.harmful,self.harmless,modified_layers,checkpoints = outs[:4] 152 | self.checkpoints = checkpoints or [] 153 | self.modified_layers = modified_layers 154 | 155 | self.harmful_inst_train,self.harmful_inst_test = prepare_dataset(dataset[0]) 156 | self.harmless_inst_train,self.harmless_inst_test = prepare_dataset(dataset[1]) 157 | 158 | self.fwd_hooks = [] 159 | self.modified = False 160 | self.activation_layers = [activation_layers] if type(activation_layers) == str else activation_layers 161 | if negative_toks == None: 162 | print("WARNING: You've not set 'negative_toks', defaulting to tokens for Llama-3 vocab") 163 | self.negative_toks = {4250, 14931, 89735, 20451, 11660, 11458, 956} # llama-3 refusal tokens e.g. ' cannot', ' unethical', ' sorry' 164 | else: 165 | self.negative_toks = negative_toks 166 | if positive_toks == None: 167 | print("WARNING: You've not set 'positive_toks', defaulting to tokens for Llama-3 vocab") 168 | self.positive_toks = {32,1271,8586,96556,78145} 169 | else: 170 | self.positive_toks = positive_toks 171 | self._blacklisted = set() 172 | 173 | def __enter__(self): 174 | if hasattr(self,"current_state"): 175 | raise Exception("Cannot do multi-contexting") 176 | self.current_state = self.model.state_dict() 177 | self.current_layers = self.modified_layers.copy() 178 | self.was_modified = self.modified 179 | return self 180 | 181 | def __exit__(self,exc,exc_value,exc_tb): 182 | self.model.load_state_dict(self.current_state) 183 | del self.current_state 184 | self.modified_layers = self.current_layers 185 | del self.current_layers 186 | self.modified = self.was_modified 187 | del self.was_modified 188 | 189 | def reset_state(self): 190 | self.modified = False 191 | self.modified_layers = {'mlp':{}, 'W_O':{}} 192 | self.model.load_state_dict(self.original_state) 193 | 194 | def checkpoint(self): 195 | # MAYBE: Offload to disk? That way we're not taking up RAM with this 196 | self.checkpoints.append(self.modified_layers.copy()) 197 | 198 | # Utility functions 199 | 200 | def blacklist_layer(self, layer: int|List[int]): 201 | # Prevents a layer from being modified 202 | if type(layer) is list: 203 | for l in layer: 204 | self._blacklisted.add(l) 205 | else: 206 | self._blacklisted.add(layer) 207 | 208 | def whitelist_layer(self,layer: int|List[int]): 209 | # Removes layer from blacklist to allow modification 210 | if type(layer) is list: 211 | for l in layer: 212 | self._blacklisted.discard(l) 213 | else: 214 | self._blacklisted.discard(layer) 215 | 216 | def save_activations(self, fname: str): 217 | torch.save([self.harmful,self.harmless,self.modified_layers if self.modified_layers['mlp'] or self.modified_layers['W_O'] else None, self.checkpoints if len(self.checkpoints) > 0 else None], fname) 218 | 219 | def get_whitelisted_layers(self) -> List[int]: 220 | return [l for l in range(self.model.cfg.n_layers) if l not in self._blacklisted] 221 | 222 | def get_all_act_names(self, activation_layers: List[str] = None) -> List[Tuple[int,str]]: 223 | return [(i,utils.get_act_name(act_name,i)) for i in self.get_whitelisted_layers() for act_name in (activation_layers or self.activation_layers)] 224 | 225 | def calculate_mean_dirs(self, key: str, include_overall_mean: bool = False) -> Dict[str, Float[Tensor, 'd_model']]: 226 | dirs = { 227 | 'harmful_mean': torch.mean(self.harmful[key], dim=0), 228 | 'harmless_mean': torch.mean(self.harmless[key], dim=0) 229 | } 230 | 231 | if include_overall_mean: 232 | if self.harmful[key].shape != self.harmless[key].shape or self.harmful[key].device.type == 'cuda': 233 | # If the shapes are different, we can't add them together; we'll need to concatenate the tensors first. 234 | # Using 'cpu', this is slower than the alternative below. 235 | # Using 'cuda', this seems to be faster than the alternatives. 236 | # NOTE: Assume both tensors are on the same device. 237 | # 238 | dirs['mean_dir'] = torch.mean(torch.cat((self.harmful[key], self.harmless[key]), dim=0), dim=0) 239 | else: 240 | # If the shapes are the same, we can add them together, take the mean, 241 | # then divide by 2.0 to account for the initial element-wise addition of the tensors. 242 | # 243 | # The result is identical to: 244 | # `torch.sum(self.harmful[key] + self.harmless[key]) / (len(self.harmful[key]) + len(self.harmless[key]))` 245 | # 246 | dirs['mean_dir'] = torch.mean(self.harmful[key] + self.harmless[key], dim=0) / 2.0 247 | 248 | return dirs 249 | 250 | def get_avg_projections(self, key: str, direction: Float[Tensor, 'd_model']) -> Tuple[Float[Tensor, 'd_model'], Float[Tensor, 'd_model']]: 251 | dirs = self.calculate_mean_dirs(self,key) 252 | return (torch.dot(dirs['harmful_mean'], direction), torch.dot(dirs['harmless_mean'], direction)) 253 | 254 | def get_layer_dirs(self, layer, key: str = None, include_overall_mean: bool=False) -> Dict[str, Float[Tensor, 'd_model']]: 255 | act_key = key or self.activation_layers[0] 256 | if len(self.harmfuls[key]) < layer: 257 | raise IndexError("Invalid layer") 258 | return self.calculate_mean_dirs(utils.get_act_name(act_key, layer), include_overall_mean=include_overall_mean) 259 | 260 | def refusal_dirs(self, invert: bool = False) -> Dict[str, Float[Tensor, 'd_model']]: 261 | if not self.harmful: 262 | raise IndexError("No cache") 263 | 264 | refusal_dirs = {key:self.calculate_mean_dirs(key) for key in self.harmful if '.0.' not in key} # don't include layer 0, as it often becomes NaN 265 | if invert: 266 | refusal_dirs = {key:v['harmless_mean']-v['harmful_mean'] for key,v in refusal_dirs.items()} 267 | else: 268 | refusal_dirs = {key:v['harmful_mean']-v['harmless_mean'] for key,v in refusal_dirs.items()} 269 | 270 | return {key:(v/v.norm()).to('cpu') for key,v in refusal_dirs.items()} 271 | 272 | def scored_dirs(self,invert = False) -> List[Tuple[str,Float[Tensor, 'd_model']]]: 273 | refusals = self.refusal_dirs(invert=invert) 274 | return sorted([(ln,refusals[act_name]) for ln,act_name in self.get_all_act_names()],reverse=True, key=lambda x:abs(x[1].mean())) 275 | 276 | def get_layer_of_act_name(self, ref: str) -> str|int: 277 | s = re.search(r"\.(\d+)\.",ref) 278 | return s if s is None else int(s[1]) 279 | 280 | def layer_attn(self, layer: int, replacement: Float[Tensor, "d_model"] = None) -> Float[Tensor, "d_model"]: 281 | if replacement is not None and layer not in self._blacklisted: 282 | # make sure device doesn't change 283 | self.modified = True 284 | self.model.blocks[layer].attn.W_O.data = replacement.to(self.model.blocks[layer].attn.W_O.device) 285 | self.modified_layers['W_O'][layer] = self.modified_layers.get(layer,[])+[(self.model.blocks[layer].attn.W_O.data.to('cpu'),replacement.to('cpu'))] 286 | return self.model.blocks[layer].attn.W_O.data 287 | 288 | def layer_mlp(self, layer: int, replacement: Float[Tensor, "d_model"] = None) -> Float[Tensor, "d_model"]: 289 | if replacement is not None and layer not in self._blacklisted: 290 | # make sure device doesn't change 291 | self.modified = True 292 | self.model.blocks[layer].mlp.W_out.data = replacement.to(self.model.blocks[layer].mlp.W_out.device) 293 | self.modified_layers['mlp'][layer] = self.modified_layers.get(layer,[])+[(self.model.blocks[layer].mlp.W_out.data.to('cpu'),replacement.to('cpu'))] 294 | return self.model.blocks[layer].mlp.W_out.data 295 | 296 | def tokenize_instructions_fn( 297 | self, 298 | instructions: List[str] 299 | ) -> Int[Tensor, 'batch_size seq_len']: 300 | prompts = [self.chat_template.format(instruction=instruction) for instruction in instructions] 301 | return self.model.tokenizer(prompts, padding=True, truncation=False, return_tensors="pt").input_ids 302 | 303 | def generate_logits( 304 | self, 305 | toks: Int[Tensor, 'batch_size seq_len'], 306 | *args, 307 | drop_refusals: bool = True, 308 | stop_at_eos: bool = False, 309 | max_tokens_generated: int = 1, 310 | **kwargs 311 | ) -> Tuple[Float[Tensor, 'batch_size seq_len d_vocab'], Int[Tensor, 'batch_size seq_len']]: 312 | # does most of the model magic 313 | all_toks = torch.zeros((toks.shape[0],toks.shape[1]+max_tokens_generated), dtype=torch.long, device=toks.device) 314 | all_toks[:, :toks.shape[1]] = toks 315 | generating = [i for i in range(toks.shape[0])] 316 | for i in range(max_tokens_generated): 317 | logits = self.model(all_toks[generating, :-max_tokens_generated + i],*args,**kwargs) 318 | next_tokens = logits[:,-1,:].argmax(dim=-1).to('cpu') 319 | all_toks[generating,-max_tokens_generated+i] = next_tokens 320 | if drop_refusals and any(negative_tok in next_tokens for negative_tok in self.negative_toks): 321 | # refusals we handle differently: if it's misbehaving, we stop all batches and move on to the next one 322 | break 323 | if stop_at_eos: 324 | for batch_idx in generating: 325 | generating = [i for i in range(toks.shape[0]) if all_toks[i][-1] != self.model.tokenizer.eos_token_id] 326 | if len(generating) == 0: 327 | break 328 | return logits, all_toks 329 | 330 | def generate( 331 | self, 332 | prompt: List[str]|str, 333 | *model_args, 334 | max_tokens_generated: int = 64, 335 | stop_at_eos: bool = True, 336 | **model_kwargs 337 | ) -> List[str]: 338 | # convenience function to test manual prompts, no caching 339 | if type(prompt) is str: 340 | gen = self.tokenize_instructions_fn([prompt]) 341 | else: 342 | gen = self.tokenize_instructions_fn(prompt) 343 | 344 | logits,all_toks = self.generate_logits(gen, *model_args, stop_at_eos=stop_at_eos, max_tokens_generated=max_tokens_generated, **model_kwargs) 345 | return self.model.tokenizer.batch_decode(all_toks, skip_special_tokens=True) 346 | 347 | def test( 348 | self, 349 | *args, 350 | test_set: List[str] = None, 351 | N: int = 16, 352 | batch_size: int = 4, 353 | **kwargs 354 | ): 355 | if test_set is None: 356 | test_set = self.harmful_inst_test 357 | for prompts in batch(test_set[:min(len(test_set),N)], batch_size): 358 | for i, res in enumerate(self.generate(prompts, *args, **kwargs)): 359 | print(res) 360 | 361 | def run_with_cache( 362 | self, 363 | *model_args, 364 | names_filter: Callable[[str], bool] = None, 365 | incl_bwd: bool = False, 366 | device: str = None, 367 | remove_batch_dim: bool = False, 368 | reset_hooks_end: bool = True, 369 | clear_contexts: bool = False, 370 | fwd_hooks: List[str] = [], 371 | max_new_tokens: int = 1, 372 | **model_kwargs 373 | ) -> Tuple[Float[Tensor, 'batch_size seq_len d_vocab'], Dict[str, Float[Tensor, 'batch_size seq_len d_model']]]: 374 | if names_filter is None and self.activation_layers: 375 | def activation_layering(namefunc: str): 376 | return any(s in namefunc for s in self.activation_layers) 377 | names_filter = activation_layering 378 | 379 | 380 | cache_dict, fwd, bwd = self.model.get_caching_hooks( 381 | names_filter, 382 | incl_bwd, 383 | device, 384 | remove_batch_dim=remove_batch_dim, 385 | pos_slice=utils.Slice(None) 386 | ) 387 | 388 | fwd_hooks = fwd_hooks+fwd+self.fwd_hooks 389 | 390 | if not max_new_tokens: 391 | # must do at least 1 token 392 | max_new_tokens = 1 393 | 394 | with self.model.hooks(fwd_hooks=fwd_hooks, bwd_hooks=bwd, reset_hooks_end=reset_hooks_end, clear_contexts=clear_contexts): 395 | #model_out = self.model(*model_args,**model_kwargs) 396 | model_out,toks = self.generate_logits(*model_args,max_tokens_generated=max_new_tokens, **model_kwargs) 397 | if incl_bwd: 398 | model_out.backward() 399 | 400 | return model_out, cache_dict 401 | 402 | def apply_refusal_dirs( 403 | self, 404 | refusal_dirs: List[Float[Tensor, 'd_model']], 405 | W_O: bool = True, 406 | mlp: bool = True, 407 | layers: List[str] = None 408 | ): 409 | if layers == None: 410 | layers = list(l for l in range(1,self.model.cfg.n_layers)) 411 | for refusal_dir in refusal_dirs: 412 | for layer in layers: 413 | for modifying in [(W_O,self.layer_attn),(mlp,self.layer_mlp)]: 414 | if modifying[0]: 415 | matrix = modifying[1](layer) 416 | if refusal_dir.device != matrix.device: 417 | refusal_dir = refusal_dir.to(matrix.device) 418 | proj = einops.einsum(matrix, refusal_dir.view(-1, 1), '... d_model, d_model single -> ... single') * refusal_dir 419 | modifying[1](layer,matrix - proj) 420 | 421 | def induce_refusal_dir( 422 | self, 423 | refusal_dir: Float[Tensor, 'd_model'], 424 | W_O: bool = True, 425 | mlp: bool = True, 426 | layers: List[str] = None 427 | ): 428 | # incomplete, needs work 429 | if layers == None: 430 | layers = list(l for l in range(1,self.model.cfg.n_layers)) 431 | for layer in layers: 432 | for modifying in [(W_O,self.layer_attn),(mlp,self.layer_mlp)]: 433 | if modifying[0]: 434 | matrix = modifying[1](layer) 435 | if refusal_dir.device != matrix.device: 436 | refusal_dir = refusal_dir.to(matrix.device) 437 | proj = einops.einsum(matrix, refusal_dir.view(-1, 1), '... d_model, d_model single -> ... single') * refusal_dir 438 | avg_proj = refusal_dir * self.get_avg_projections(utils.get_act_name(self.activation_layers[0], layer),refusal_dir) 439 | modifying[1](layer,(matrix - proj) + avg_proj) 440 | 441 | def test_dir( 442 | self, 443 | refusal_dir: Float[Tensor, 'd_model'], 444 | activation_layers: List[str] = None, 445 | use_hooks: bool = True, 446 | layers: List[str] = None, 447 | **kwargs 448 | ) -> Dict[str, Float[Tensor, 'd_model']]: 449 | # `use_hooks=True` is better for bigger models as it causes a lot of memory swapping otherwise, but 450 | # `use_hooks=False` is much more representative of the final weights manipulation 451 | 452 | before_hooks = self.fwd_hooks 453 | try: 454 | if layers is None: 455 | layers = self.get_whitelisted_layers() 456 | 457 | if activation_layers is None: 458 | activation_layers = self.activation_layers 459 | 460 | if use_hooks: 461 | hooks = self.fwd_hooks 462 | hook_fn = functools.partial(directional_hook,direction=refusal_dir) 463 | self.fwd_hooks = before_hooks+[(act_name,hook_fn) for ln,act_name in self.get_all_act_names()] 464 | return self.measure_scores(**kwargs) 465 | else: 466 | with self: 467 | self.apply_refusal_dirs([refusal_dir],layers=layers) 468 | return self.measure_scores(**kwargs) 469 | finally: 470 | self.fwd_hooks = before_hooks 471 | 472 | def find_best_refusal_dir( 473 | self, 474 | N: int = 4, 475 | positive: bool = False, 476 | use_hooks: bool = True, 477 | invert: bool = False 478 | ) -> List[Tuple[float,str]]: 479 | dirs = self.refusal_dirs(invert=invert) 480 | if self.modified: 481 | print("WARNING: Modified; will restore model to current modified state each run") 482 | scores = [] 483 | for direction in tqdm(dirs.items()): 484 | score = self.test_dir(direction[1],N=N,use_hooks=use_hooks)[int(positive)] 485 | scores.append((score,direction)) 486 | return sorted(scores,key=lambda x:x[0]) 487 | 488 | def measure_scores( 489 | self, 490 | N: int = 4, 491 | sampled_token_ct: int = 8, 492 | measure: str = 'max', 493 | batch_measure: str = 'max', 494 | positive: bool = False 495 | ) -> Dict[str, Float[Tensor, 'd_model']]: 496 | toks = self.tokenize_instructions_fn(instructions=self.harmful_inst_test[:N]) 497 | logits,cache = self.run_with_cache(toks,max_new_tokens=sampled_token_ct,drop_refusals=False) 498 | 499 | negative_score,positive_score = self.measure_scores_from_logits(logits,sampled_token_ct,measure=batch_measure) 500 | 501 | negative_score = measure_fn(measure,negative_score) 502 | positive_score = measure_fn(measure,positive_score) 503 | return {'negative':negative_score.to('cpu'), 'positive':positive_score.to('cpu')} 504 | 505 | def measure_scores_from_logits( 506 | self, 507 | logits: Float[Tensor, 'batch_size seq_len d_vocab'], 508 | sequence: int, 509 | measure: str = 'max' 510 | ) -> Tuple[Float[Tensor, 'batch_size'], Float[Tensor, 'batch_size']]: 511 | normalized_scores = torch.softmax(logits[:,-sequence:,:].to('cpu'),dim=-1)[:,:,list(self.positive_toks)+list(self.negative_toks)] 512 | 513 | normalized_positive,normalized_negative = torch.split(normalized_scores,[len(self.positive_toks), len(self.negative_toks)], dim=2) 514 | 515 | max_negative_score_per_sequence = torch.max(normalized_negative,dim=-1)[0] 516 | max_positive_score_per_sequence = torch.max(normalized_positive,dim=-1)[0] 517 | 518 | negative_score_per_batch = measure_fn(measure,max_negative_score_per_sequence,dim=-1)[0] 519 | positive_score_per_batch = measure_fn(measure,max_positive_score_per_sequence,dim=-1)[0] 520 | return negative_score_per_batch,positive_score_per_batch 521 | 522 | def do_resid(self, fn_name: str) -> Tuple[Float[Tensor, 'layer batch d_model'], Float[Tensor, 'layer batch d_model'], List[str]]: 523 | if not any("resid" in k for k in self.harmless.keys()): 524 | raise AssertionError("You need residual streams to decompose layers! Run cache_activations with None in `activation_layers`") 525 | resid_harmful,labels = getattr(self.harmful,fn_name)(apply_ln=True,return_labels=True) 526 | resid_harmless = getattr(self.harmless,fn_name)(apply_ln=True) 527 | 528 | return resid_harmful,resid_harmless,labels 529 | 530 | def decomposed_resid(self) -> Tuple[Float[Tensor, 'layer batch d_model'], Float[Tensor, 'layer batch d_model'], List[str]]: 531 | return self.do_resid("decompose_resid") 532 | 533 | def accumulated_resid(self) -> Tuple[Float[Tensor, 'layer batch d_model'], Float[Tensor, 'layer batch d_model'], List[str]]: 534 | return self.do_resid("accumulated_resid") 535 | 536 | def unembed_resid(self, resid: Float[Tensor, "layer batch d_model"], pos: int = -1) -> Float[Tensor, "layer batch d_vocab"]: 537 | W_U = self.model.W_U 538 | if pos == None: 539 | return einops.einsum(resid.to(W_U.device), W_U,"layer batch d_model, d_model d_vocab -> layer batch d_vocab").to('cpu') 540 | else: 541 | return einops.einsum(resid[:,pos,:].to(W_U.device),W_U,"layer d_model, d_model d_vocab -> layer d_vocab").to('cpu') 542 | 543 | def create_layer_rankings( 544 | self, 545 | token_set: List[int]|Set[int]|Int[Tensor, '...'], 546 | decompose: bool = True, 547 | token_set_b: List[int]|Set[int]|Int[Tensor, '...'] = None 548 | ) -> List[Tuple[int,int]]: 549 | decomposer = self.decomposed_resid if decompose else self.accumulated_resid 550 | 551 | decomposed_resid_harmful, decomposed_resid_harmless, labels = decomposer() 552 | 553 | W_U = self.model.W_U.to('cpu') 554 | unembedded_harmful = self.unembed_resid(decomposed_resid_harmful) 555 | unembedded_harmless = self.unembed_resid(decomposed_resid_harmless) 556 | 557 | sorted_harmful_indices = torch.argsort(unembedded_harmful, dim=1, descending=True) 558 | sorted_harmless_indices = torch.argsort(unembedded_harmless, dim=1, descending=True) 559 | 560 | harmful_set = torch.isin(sorted_harmful_indices, torch.tensor(list(token_set))) 561 | harmless_set = torch.isin(sorted_harmless_indices, torch.tensor(list(token_set if token_set_b is None else token_set_b))) 562 | 563 | indices_in_set = zip(harmful_set.nonzero(as_tuple=True)[1],harmless_set.nonzero(as_tuple=True)[1]) 564 | return indices_in_set 565 | 566 | def mse_positive( 567 | self, 568 | N: int = 128, 569 | batch_size: int = 8, 570 | last_indices: int = 1 571 | ) -> Dict[str, Float[Tensor, 'd_model']]: 572 | # Calculate mean squared error against currently loaded negative cached activation 573 | # Idea being to get a general sense of how the "normal" direction has been altered. 574 | # This is to compare ORIGINAL functionality to ABLATED functionality, not for ground truth. 575 | 576 | #load full training set to ensure alignment 577 | toks = self.tokenize_instructions_fn(instructions=self.harmful_inst_train[:N]+self.harmless_inst_train[:N]) 578 | 579 | splitpos = min(N,len(self.harmful_inst_train)) 580 | 581 | # select for just harmless 582 | toks = toks[splitpos:] 583 | self.loss_harmless = {} 584 | 585 | for i in tqdm(range(0,min(N,len(toks)),batch_size)): 586 | logits,cache = self.run_with_cache(toks[i:min(i+batch_size,len(toks))]) 587 | for key in cache: 588 | if any(k in key for k in self.activation_layers): 589 | tensor = torch.mean(cache[key][:, -last_indices:, :],dim=1).to('cpu') 590 | if key not in self.loss_harmless: 591 | self.loss_harmless[key] = tensor 592 | else: 593 | self.loss_harmless[key] = torch.cat((self.loss_harmless[key], tensor),dim=0) 594 | del logits,cache 595 | clear_mem() 596 | 597 | return {k:F.mse_loss(self.loss_harmless[k].float()[:N],self.harmless[k].float()[:N]) for k in self.loss_harmless} 598 | 599 | def create_activation_cache( 600 | self, 601 | toks, 602 | N: int = 128, 603 | batch_size: int = 8, 604 | last_indices: int = 1, 605 | measure_refusal: int = 0, 606 | stop_at_layer: int = None 607 | ) -> Tuple[ActivationCache, List[str]]: 608 | # Base functionality for creating an activation cache with a training set, prefer 'cache_activations' for regular usage 609 | 610 | base = dict() 611 | z_label = [] if measure_refusal > 1 else None 612 | for i in tqdm(range(0,min(N,len(toks)),batch_size)): 613 | logits,cache = self.run_with_cache(toks[i:min(i+batch_size,len(toks))],max_new_tokens=measure_refusal,stop_at_layer=stop_at_layer) 614 | if measure_refusal > 1: 615 | z_label.extend(self.measure_scores_from_logits(logits,measure_refusal)[0]) 616 | for key in cache: 617 | if self.activation_layers is None or any(k in key for k in self.activation_layers): 618 | tensor = torch.mean(cache[key][:,-last_indices:,:].to('cpu'),dim=1) 619 | if key not in base: 620 | base[key] = tensor 621 | else: 622 | base[key] = torch.cat((base[key], tensor), dim=0) 623 | 624 | del logits, cache 625 | clear_mem() 626 | 627 | return ActivationCache(base,self.model), z_label 628 | 629 | def cache_activations( 630 | self, 631 | N: int = 128, 632 | batch_size: int = 8, 633 | measure_refusal: int = 0, 634 | last_indices: int = 1, 635 | reset: bool = True, 636 | activation_layers: int = -1, 637 | preserve_harmless: bool = True, 638 | stop_at_layer: int = None 639 | ): 640 | if hasattr(self,"current_state"): 641 | print("WARNING: Caching activations using a context") 642 | if self.modified: 643 | print("WARNING: Running modified model") 644 | 645 | if activation_layers == -1: 646 | activation_layers = self.activation_layers 647 | 648 | harmless_is_set = len(getattr(self,"harmless",{})) > 0 649 | preserve_harmless = harmless_is_set and preserve_harmless 650 | 651 | if reset == True or getattr(self,"harmless",None) is None: 652 | self.harmful = {} 653 | if not preserve_harmless: 654 | self.harmless = {} 655 | 656 | self.harmful_z_label = [] 657 | self.harmless_z_label = [] 658 | 659 | # load the full training set here to align all the dimensions (even if we're not going to run harmless) 660 | toks = self.tokenize_instructions_fn(instructions=self.harmful_inst_train[:N]+self.harmless_inst_train[:N]) 661 | 662 | splitpos = min(N,len(self.harmful_inst_train)) 663 | harmful_toks = toks[:splitpos] 664 | harmless_toks = toks[splitpos:] 665 | 666 | last_indices = last_indices or 1 667 | 668 | self.harmful,self.harmful_z_label = self.create_activation_cache(harmful_toks,N=N,batch_size=batch_size,last_indices=last_indices,measure_refusal=measure_refusal,stop_at_layer=None) 669 | if not preserve_harmless: 670 | self.harmless, self.harmless_z_label = self.create_activation_cache(harmless_toks,N=N,batch_size=batch_size,last_indices=last_indices,measure_refusal=measure_refusal,stop_at_layer=None) 671 | 672 | --------------------------------------------------------------------------------