├── .devcontainer
│   ├── Dockerfile
│   └── devcontainer.json
├── .gitignore
├── .pylintrc
├── LICENSE
├── README.md
├── activation_additions
│   ├── __init__.py
│   ├── analysis.py
│   ├── completion_utils.py
│   ├── experiments.py
│   ├── hook_utils.py
│   ├── lenses.py
│   ├── logging.py
│   ├── logits.py
│   ├── metrics.py
│   ├── prompt_utils.py
│   ├── sweeps.py
│   ├── utils.py
│   └── widgets.py
├── data
│   ├── chatgpt_shipping_essay_20230423.txt
│   ├── chatgpt_wedding_essay_20230423.txt
│   ├── restaurant.csv
│   ├── restaurant_proc.csv
│   ├── vegan_banana_bread.txt
│   └── wikipedia_macedonia.txt
├── onstart.sh
├── pyrightconfig.json
├── pytest.ini
├── results
│   └── save_tables.py
├── scripts
│   ├── Prompt superposition.ipynb
│   ├── addition_clean_reimplementation.py
│   ├── basic_functionality.py
│   ├── capabilities_impact.ipynb
│   ├── human_rating.py
│   ├── implementations_comparison.py
│   ├── initial_post_quantitative.py
│   ├── lenses_demo.py
│   ├── llama_2_steering.py
│   ├── logging_demo.py
│   ├── position_functionality.py
│   ├── prompt_magnitudes.py
│   ├── prompts.txt
│   ├── qualitative.ipynb
│   ├── stress_testing.ipynb
│   ├── sweeps_demo.py
│   ├── switch_to_french.py
│   └── widgets_demo.py
├── setup.py
├── sparse_coding
│   ├── __init__.py
│   ├── act_config.yaml
│   ├── acts_collect.py
│   ├── autoencoder.py
│   ├── data
│   │   └── token_info.csv
│   ├── feature_tokens.py
│   ├── heatmap.py
│   ├── interp_ablations.py
│   └── utils
│       ├── __init__.py
│       ├── configure.py
│       └── top_k.py
├── tests
│   ├── smoke_test_access.yaml
│   ├── smoke_test_config.yaml
│   ├── smoke_test_data
│   │   └── .gitkeep
│   ├── sweep_over_prompts_cache.pkl
│   ├── test_completion_utils.py
│   ├── test_experiments.py
│   ├── test_hook_utils.py
│   ├── test_lenses.py
│   ├── test_logging.py
│   ├── test_logits.py
│   ├── test_metrics.py
│   ├── test_prompt_utils.py
│   ├── test_sparse_coding.py
│   ├── test_sparse_coding_smoke.py
│   └── test_sweeps.py
├── truthfulqa
│   ├── replication_llama_evals.py
│   ├── steering_evals.csv
│   └── steering_evals.py
└── vast
    └── Dockerfile
/.devcontainer/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:latest 2 | 3 | # Set up non-root user 4 | ARG USERNAME=vscode 5 | ARG USER_UID=1000 6 | ARG USER_GID=$USER_UID 7 | RUN groupadd --gid $USER_GID $USERNAME && \ 8 | useradd --uid $USER_UID --gid $USER_GID -m $USERNAME && \ 9 | apt-get update && \ 10 | apt-get install -y sudo && \ 11 | echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME && \ 12 | chmod 0440 /etc/sudoers.d/$USERNAME 13 | 14 | 15 | # Install packages 16 | RUN apt-get install -y git 17 | 18 | # Set up the user and environment 19 | USER $USERNAME 20 | ENV PATH="/home/$USERNAME/.local/bin:${PATH}" 21 | 22 | # Install project dependencies 23 | WORKDIR /home/$USERNAME/activation_additions 24 | COPY --chown=$USERNAME:$USER_GID . . 
25 | RUN pip install --no-cache-dir -e '.[dev]' 26 | 27 | -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "PyTorch Dev Container", 3 | "dockerFile": "Dockerfile", 4 | "context": "..", 5 | "workspaceFolder": "/home/vscode/activation_additions", 6 | "settings": { 7 | "terminal.integrated.shell.linux": "/bin/bash" 8 | }, 9 | "runArgs": [ 10 | "--cap-add=SYS_PTRACE", 11 | "--security-opt", 12 | "seccomp=unconfined" 13 | ], 14 | "extensions": [ 15 | "ms-python.python", 16 | "ms-vscode.cpptools" 17 | ] 18 | } 19 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Code profiling 132 | *prof 133 | 134 | # lightning logs 135 | /sparse_coding/logs/ 136 | /tests/logs/ 137 | 138 | # Misc 139 | *.pkl 140 | .vscode 141 | results* 142 | *~ 143 | *prof 144 | *old_ave* 145 | wandb 146 | results 147 | *playground* 148 | artifacts 149 | openai_api_key.txt 150 | images 151 | *png 152 | *html 153 | scripts/KL_bug.py 154 | wandb_restored_files 155 | *.pt 156 | *.ckpt 157 | *.npy 158 | act_access.yaml 159 | tests/smoke_test_data/smoke_test_token_info.csv 160 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 montemac 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Algebraic value editing in pretrained language models 2 | 3 | Algebraic value editing involves the injection of activation vectors into the forward 4 | passes of language models like GPT-2 using the hooking functionality of 5 | `transformer_lens`. 6 | 7 | # Installation 8 | After cloning the repository, run `pip install -e .` to install the 9 | `activation_additions` package. 10 | 11 | There are currently a few example scripts in the `scripts/` 12 | directory.For example, `basic_functionality.py` generates 13 | modified prompts (as described below). 14 | 15 | # Methodology 16 | 17 | ## How the vectors are generated 18 | 19 | The core data structure is the `ActivationAddition`, which is specified by: 20 | 21 | - A prompt, like "Love", 22 | - A location within the forward pass, like "the activations just before 23 | the sixth block" (i.e. `blocks.6.hook_resid_pre`), and 24 | - A coefficient, like 2.5. 
25 | 26 | ``` 27 | love_rp = ActivationAddition(prompt="Love", coeff=2.5, act_name="blocks.6.hook_resid_pre") 28 | ``` 29 | 30 | The `ActivationAddition` specifies: 31 | > Run a forward pass on the prompt, record the activations at the given 32 | > location in the forward pass, and then rescale those activations by 33 | > the given coefficient. 34 | 35 | Then, when future forward passes reach `blocks.6.hook_resid_pre`, a hook 36 | function adds e.g. 2.5 times the "Love" activations to the usual activations 37 | at that location. 38 | 39 | For example, if we run `gpt2-small` on the prompt "I went to the store 40 | because", the residual streams line up as follows: 41 | ``` 42 | prompt_tokens = ['<|endoftext|>', 'I', ' went', ' to', ' the', ' store', ' because'] 43 | love_rp_tokens = ['<|endoftext|>', 'Love'] 44 | ``` 45 | To add the love `ActivationAddition` to the forward pass, we run the usual forward 46 | pass on the prompt until transformer block 6. At this point, consider 47 | the first two residual streams. Namely, the `'<|endoftext|>'` residual 48 | stream and the `'I'`/`'Love'` residual stream. We add the activations in these two 49 | residual streams. 50 | 51 | 52 | ## X-vectors are a special kind of `ActivationAddition` 53 | 54 | A special case of this is the "X-vector." A "Love minus 55 | hate" vector is generated by 56 | ``` 57 | love_rp, hate_rp = get_x_vector(prompt1="Love", prompt2="Hate", 58 | coeff=5, act_name=6) 59 | ``` 60 | This returns a tuple of two `ActivationAddition`s: 61 | ``` 62 | love_rp = ActivationAddition(prompt="Love", coeff=5, act_name="blocks.6.hook_resid_pre") 63 | hate_rp = ActivationAddition(prompt="Hate", coeff=-5, act_name="blocks.6.hook_resid_pre") 64 | ``` 65 | (This is mechanistically similar to our [cheese-](https://www.lesswrong.com/posts/cAC4AXiNC5ig6jQnc/understanding-and-controlling-a-maze-solving-policy-network) and 66 | [top-right-vector](https://www.lesswrong.com/posts/gRp6FAWcQiCWkouN5/maze-solving-agents-add-a-top-right-vector-make-the-agent-go)s, originally computed for deep convolutional 67 | maze-solving policy networks.) 68 | 69 | Sometimes, x-vectors are built from two prompts which have different 70 | tokenized lengths. In this situation, it empirically seems best to even 71 | out the lengths by padding the shorter prompt with space tokens (`' '`). 72 | This is done by calling: 73 | ``` 74 | get_x_vector(prompt1="I talk about weddings constantly", 75 | prompt2="I do not talk about weddings constantly", 76 | coeff=4, act_name=20, 77 | pad_method="tokens_right", model=gpt2_small, 78 | custom_pad_id=gpt2_small.to_single_token(' ')) 79 | ``` 80 | 81 | ## Using `ActivationAddition`s to generate modified completions 82 | Given an actual prompt which is fed into the model normally 83 | (`model.generate(prompt="Hi!")`) and a list of `ActivationAddition`s, we can 84 | easily generate a set of completions with and without the influence of 85 | the `ActivationAddition`s. 
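Under the hood, producing a modified completion just means running the model with a forward hook attached at the chosen location. The snippet below is a minimal illustrative sketch of that mechanism for the "Love" example above, not the package's actual implementation: the names `steering` and `addition_hook` are made up for this example, and the real hooks in `hook_utils.py` additionally handle padding, token alignment, multiple additions, and sampling full completions.

```
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")  # any HookedTransformer works here

# Record the "Love" activations at the injection site and rescale them.
act_name = "blocks.6.hook_resid_pre"
_, cache = model.run_with_cache("Love")
steering = 2.5 * cache[act_name]  # shape [1, num_steering_tokens, d_model]

def addition_hook(resid_pre, hook):
    # Add the rescaled activations to the first residual stream positions.
    num = steering.shape[1]
    resid_pre[:, :num, :] = resid_pre[:, :num, :] + steering
    return resid_pre

# Forward pass on the actual prompt, with the hook attached at the same site.
logits = model.run_with_hooks(
    "I went to the store because",
    fwd_hooks=[(act_name, addition_hook)],
)
```

The package-level helper `print_n_comparisons` used below wraps this mechanism and prints completions generated with and without the injection: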
86 | 87 | ``` 88 | print_n_comparisons( 89 | prompt="I hate you because", 90 | model=gpt2_xl, 91 | tokens_to_generate=100, 92 | activation_additions=[love_rp, hate_rp], 93 | num_comparisons=15, 94 | seed=42, 95 | temperature=1, freq_penalty=1, top_p=.3 96 | ) 97 | ``` 98 | 99 | This produces an output like the following (where the prompt is bolded, 100 | and the completions are not): 101 | ![](https://i.imgur.com/CJc4SVt.png) 102 | 103 | An even starker example is produced by 104 | ``` 105 | praise_rp, hurt_rp = *get_x_vector(prompt1="Intent to praise", 106 | prompt2="Intent to hurt", 107 | coeff=15, act_name=6, 108 | pad_method="tokens_right", model=gpt2_xl, 109 | custom_pad_id=gpt2_xl.to_single_token(' ')) 110 | print_n_comparisons( 111 | prompt="I want to kill you because", 112 | model=gpt2_xl, 113 | tokens_to_generate=50, 114 | activation_additions=[praise_rp, hurt_rp], 115 | num_comparisons=15, 116 | seed=0, 117 | temperature=1, freq_penalty=1, top_p=.3 118 | ) 119 | ``` 120 | ![](https://i.imgur.com/ewD0IKT.png) 121 | 122 | For more examples, consult our [Google 123 | Colab](https://colab.research.google.com/drive/183boiXfIBEdo6ch8RwOyqIZizJd6vwDl?usp=sharing). 124 | -------------------------------------------------------------------------------- /activation_additions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/montemac/activation_additions/cc3178cb813b640cd9644cf656d43a51e28869bd/activation_additions/__init__.py -------------------------------------------------------------------------------- /activation_additions/analysis.py: -------------------------------------------------------------------------------- 1 | """ Tools for analyzing the results of algebraic value editing. """ 2 | 3 | # %% 4 | from typing import List 5 | import html 6 | import numpy as np 7 | import pandas as pd 8 | from ipywidgets import widgets 9 | from IPython.display import display, clear_output 10 | 11 | 12 | def rate_completions( 13 | data_frame: pd.DataFrame, 14 | criterion: str = "happy", 15 | ) -> List: 16 | """Prompt the user to rate the generated completions, without 17 | indicating which condition they came from. Modifies the `data_frame` 18 | in place. 19 | 20 | args: 21 | `data_frame`: The `DataFrame` should have the following columns: 22 | `prompts`: The prompts used to generate the completions. 23 | `completions`: The generated completions. 24 | `is_modified`: Whether the completion was generated 25 | using a modified forward pass. 26 | 27 | `criterion`: The criterion to use for rating the completions. 28 | """ 29 | 30 | # Helper function. could use but it's not as pretty. 31 | def htmlify(text): 32 | return html.escape(text).replace("\n", "
") 33 | 34 | # Show the generations to the user in a random order 35 | perm = np.random.permutation(len(data_frame)) 36 | perm_idx = 0 37 | data_idx = perm[perm_idx] 38 | 39 | # Show preamble TODO type-hint all of this 40 | prompt: str = data_frame["prompts"].tolist()[0] 41 | preamble = widgets.HTML() 42 | 43 | def update_preamble(): 44 | preamble.value = f"""

45 | The model was run with prompt: "{htmlify(prompt)}"<br>
46 | Please rate the completions below based on how {criterion} they are. You are rating completion {perm_idx+1}/{len(data_frame)}.
47 |
""" 48 | 49 | update_preamble() 50 | 51 | # Use ipython to display text of the first completion 52 | completion_box = widgets.HTML() 53 | 54 | def set_completion_text(text): 55 | completion_box.value = f"

{htmlify(text)}

" 56 | 57 | set_completion_text(data_frame.iloc[data_idx]["completions"]) 58 | 59 | # Create the rating buttons 60 | rating_buttons = widgets.ToggleButtons( 61 | options=["1", "2", "3", "4", "5"], 62 | button_style="", 63 | tooltips=["1", "2", "3", "4", "5"], 64 | value=None, 65 | ) 66 | display(completion_box) 67 | 68 | # On rating button click, update the data frame and show the next completion 69 | def on_rating_button_clicked(btn): 70 | nonlocal data_idx, perm_idx # so we can increment 71 | 72 | data_frame.loc[data_idx, "rating"] = int(btn["new"]) 73 | 74 | # Reset the rating buttons without retriggering observe 75 | rating_buttons.unobserve(on_rating_button_clicked, names="value") # type: ignore 76 | rating_buttons.value = None 77 | rating_buttons.observe(on_rating_button_clicked, names="value") # type: ignore 78 | 79 | # Increment if we aren't done 80 | if perm_idx < len(data_frame) - 1: 81 | perm_idx += 1 82 | data_idx = perm[perm_idx] 83 | set_completion_text(data_frame.iloc[data_idx]["completions"]) 84 | update_preamble() 85 | else: 86 | for widget in displayed: 87 | widget.close() 88 | 89 | rating_buttons.observe(on_rating_button_clicked, names="value") # type: ignore 90 | 91 | # Display all the widgets. saved for the end to make the structure more apparent 92 | displayed = [preamble, completion_box, rating_buttons] 93 | display(*displayed) 94 | 95 | # Return the widget tree for easier testing. returning the passed in dataframe is pointless. 96 | return displayed 97 | 98 | 99 | # For interactive development of the widgets and testing (nice to have in one file) 100 | if __name__ == "__main__": 101 | mixed_df = pd.DataFrame( 102 | { 103 | "prompts": [ 104 | "Yesterday, my dog died. Today, I got denied for a raise. I'm feeling" 105 | ] 106 | * 2, 107 | "completions": [ 108 | "Yesterday, my dog died. Today, I got denied for a raise. " 109 | + "I'm feeling sad.\nVery sad.", 110 | "Yesterday, my dog died. Today, I got denied for a raise. " 111 | + "I'm feeling happy.\n\nReally happy~!", 112 | ], 113 | "is_modified": [False, True], 114 | } 115 | ) 116 | 117 | displayed_widgets = rate_completions( 118 | data_frame=mixed_df, criterion="happy" 119 | ) 120 | 121 | # Create box to display the updating dataframe 122 | box = widgets.Output() 123 | 124 | def display_df(_): 125 | """Display mixed_df after clearing output""" 126 | with box: 127 | clear_output() 128 | display(mixed_df) 129 | 130 | displayed_widgets[2].observe(display_df, names="value") 131 | display_df(None) 132 | display(box) 133 | 134 | # %% 135 | -------------------------------------------------------------------------------- /activation_additions/lenses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Wrappers to use tuned lens with AVE. 3 | 4 | The one nontrivial detal here: we want 'resid_pre' not post or mid, see the image in this readme: 5 | https://github.com/AlignmentResearch/tuned-lens 6 | """ 7 | from typing import List, Dict 8 | 9 | import numpy as np 10 | import torch 11 | import pandas as pd 12 | 13 | from tuned_lens import TunedLens 14 | from tuned_lens.plotting import PredictionTrajectory 15 | from transformers import AutoTokenizer 16 | 17 | from activation_additions import completion_utils, hook_utils 18 | 19 | # %% 20 | 21 | 22 | def fwd_hooks_from_activ_hooks(activ_hooks): 23 | """ 24 | Because AVE data structures differ from transformerlens we must convert. 
25 | >>> fwd_hooks_from_activ_hooks({'blocks.47.hook_resid_pre': ['e1', 'e2']]}) 26 | [('blocks.47.hook_resid_pre', 'e1'), ('blocks.47.hook_resid_pre', 'e2')] 27 | """ 28 | return [ 29 | (name, hook_fn) 30 | for name, hook_fns in activ_hooks.items() 31 | for hook_fn in hook_fns 32 | ] 33 | 34 | 35 | def trajectory_log_probs(tuned_lens, logits, cache): 36 | """ 37 | Get the log probabilities of the trajectory from the cache and logits. 38 | """ 39 | stream = [ 40 | resid for name, resid in cache.items() if name.endswith("resid_pre") 41 | ] 42 | traj_log_probs = [ 43 | tuned_lens.forward(x, i) 44 | .log_softmax(dim=-1) 45 | .squeeze() 46 | .detach() 47 | .cpu() 48 | .numpy() 49 | for i, x in enumerate(stream) 50 | ] 51 | # Handle the case where the model has more/less tokens than the lens 52 | model_log_probs = ( 53 | logits.log_softmax(dim=-1).squeeze().detach().cpu().numpy() 54 | ) 55 | traj_log_probs.append(model_log_probs) 56 | return traj_log_probs 57 | 58 | 59 | def prediction_trajectories( 60 | caches: List[Dict[str, torch.Tensor]], 61 | dataframes: List[pd.DataFrame], 62 | tokenizer: AutoTokenizer, 63 | tuned_lens: TunedLens, 64 | ) -> List[PredictionTrajectory]: 65 | """ 66 | Get prediction trajectories from caches and dataframes, typically 67 | obtained from `run_hooked_and_normal_with_cache`. 68 | 69 | Args: 70 | caches: A list of caches. must include 'resid_pre' tensors. 71 | dataframes: A list of dataframes. Must include 'logits', 'prompts', and 'completions'. 72 | tokenizer: The tokenizer to use, typically model.tokenizer. 73 | tuned_lens: The tuned lens to use. Typically obtained by 74 | `TunedLens.from_model_and_pretrained(hf_model, lens_resource_id=model_name)` 75 | """ 76 | 77 | logits_list = [torch.tensor(df["logits"]) for df in dataframes] 78 | full_prompts = [ 79 | df["prompts"][0] + df["completions"][0] for df in dataframes 80 | ] 81 | return [ 82 | PredictionTrajectory( 83 | log_probs=np.array( 84 | trajectory_log_probs(tuned_lens, logits, cache) 85 | ), 86 | input_ids=np.array( 87 | tokenizer.encode(prompt) + [tokenizer.eos_token_id] # type: ignore 88 | ), 89 | tokenizer=tokenizer, # type: ignore 90 | ) 91 | for prompt, logits, cache in zip(full_prompts, logits_list, caches) 92 | ] 93 | 94 | 95 | def run_hooked_and_normal_with_cache( 96 | model, activation_additions, gen_args, device=None 97 | ): 98 | """ 99 | Run hooked and normal with cache. 100 | 101 | Args: 102 | model: The model to run. 103 | activation_additions: A list of ActivationAdditions. 104 | gen_args: Keyword arguments to pass to `completion_utils.gen_using_model`. 105 | Must include `prompt_batch` and `tokens_to_generate`. 106 | 107 | Returns: 108 | normal_and_modified_df: A list of two dataframes, one for normal and one for modified. 109 | normal_and_modified_cache: A list of two caches, one for normal and one for modified. 110 | """ 111 | assert len(gen_args.get("prompt_batch", [])) == 1, ( 112 | "Only one prompt is supported. 
Got" 113 | f" {len(gen_args.get('prompt_batch', []))}" 114 | ) 115 | 116 | activ_hooks = hook_utils.hook_fns_from_activation_additions( 117 | model, activation_additions 118 | ) 119 | fwd_hooks = fwd_hooks_from_activ_hooks(activ_hooks) 120 | normal_and_modified_df = [] 121 | normal_and_modified_cache = [] 122 | 123 | for fwd_hooks, is_modified in [([], False), (fwd_hooks, True)]: 124 | cache, caching_hooks, _ = model.get_caching_hooks( 125 | names_filter=lambda n: "resid_pre" in n, device=device 126 | ) 127 | 128 | # IMPORTANT: We call caching hooks *after* the value editing hooks. 129 | with model.hooks(fwd_hooks=fwd_hooks + caching_hooks): 130 | results_df = completion_utils.gen_using_model( 131 | model, include_logits=True, **gen_args 132 | ) 133 | results_df["is_modified"] = is_modified 134 | normal_and_modified_df.append(results_df) 135 | normal_and_modified_cache.append(cache) 136 | 137 | return normal_and_modified_df, normal_and_modified_cache 138 | -------------------------------------------------------------------------------- /activation_additions/logging.py: -------------------------------------------------------------------------------- 1 | """Functions to support logging of data to wandb""" 2 | 3 | from typing import Optional, Dict, Tuple, Any, Callable, List 4 | from contextlib import nullcontext 5 | from warnings import warn 6 | import os 7 | import pickle 8 | import inspect 9 | 10 | from decorator import decorate 11 | from transformer_lens.HookedTransformer import HookedTransformer 12 | import wandb 13 | 14 | PROJECT = "activation_additions" 15 | 16 | # Hack to disable a warning when wandb forks a process for sync'ing (I 17 | # think) 18 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 19 | 20 | # Disable printing 21 | os.environ["WANDB_SILENT"] = "true" 22 | 23 | # TODO: this is a hack, change this to add an optional return value from 24 | # loggable functions to return the run ID 25 | last_run_info = {"id": None, "name": None, "path": None, "url": None} 26 | 27 | 28 | # TODO: fix returns types here, it's a bit complex 29 | def get_or_init_run( 30 | **init_args, 31 | ) -> Tuple[Any, Any]: # type: ignore 32 | """Function to obtain a usable wandb Run object, initializing it if 33 | needed. A context manager is also returned: if the run was 34 | initialized in this call, the context manager will be the run, so 35 | that it can be wrapped in a with block to provide exception-save 36 | finishing. If the run was initialized previously and simply 37 | returned by this call, then the context manager will be empty, and 38 | it should be assumed that the original creator of the run will be 39 | managing it's safe finishing.""" 40 | global last_run_info # pylint: disable=global-statement 41 | if wandb.run is None: 42 | 43 | def overwrite_arg_with_warning(args, key, new_value): 44 | if key in args: 45 | warn( 46 | f"Key {key} provided in arguments dict, but this" 47 | f"will be ignored and overridden with {new_value}." 
48 | ) 49 | args[key] = new_value 50 | 51 | # Force any needed args 52 | overwrite_arg_with_warning(init_args, "reinit", True) 53 | overwrite_arg_with_warning(init_args, "project", PROJECT) 54 | overwrite_arg_with_warning(init_args, "save_code", True) 55 | overwrite_arg_with_warning(init_args, "allow_val_change", True) 56 | # Initialize a run 57 | run = wandb.init(**init_args) 58 | if run is not None: 59 | last_run_info = { 60 | "id": run.id, 61 | "name": run.name, 62 | "path": run.path, 63 | "url": run.url, 64 | } 65 | manager = run 66 | else: 67 | run = wandb.run 68 | # Add additional configs in a list of "child config", to avoid 69 | # clobberring names 70 | if "config" in init_args: 71 | if "child_configs" not in run.config: 72 | run.config["child_configs"] = [init_args["config"]] 73 | else: 74 | run.config["child_configs"].append(init_args["config"]) 75 | manager = nullcontext() 76 | return run, manager 77 | 78 | 79 | def log_object( 80 | run: wandb.wandb_sdk.wandb_run.Run, # type: ignore 81 | obj: Any, 82 | logged_name: str, 83 | ): 84 | """Save object to a file in the directory of the provided run, 85 | which will be automatically uploaded at the end of the run.""" 86 | folder = os.path.join(run.dir, "logged_objects") 87 | try: 88 | os.mkdir(folder) 89 | except FileExistsError: 90 | pass 91 | with open(os.path.join(folder, f"{logged_name}.pkl"), "wb") as file: 92 | pickle.dump(obj, file) 93 | 94 | 95 | def get_or_init_run_and_log_object( 96 | job_type: str, 97 | config: Dict[str, Any], 98 | obj: Any, 99 | logged_name: str, 100 | run_args: Optional[Dict[str, Any]] = None, 101 | ): 102 | """Function to get or init a wandb run, set the config, log an 103 | object, and finish the run (if it was created) in a single call.""" 104 | if run_args is None: 105 | run_args = {} 106 | # Get the wandb run 107 | run, manager = get_or_init_run( 108 | job_type=job_type, 109 | config=config, 110 | tags=run_args.get("tags", None), 111 | group=run_args.get("group", None), 112 | notes=run_args.get("notes", None), 113 | ) 114 | # Wrap in a context manager for exception-safety, and log the 115 | # results of this call 116 | with manager: 117 | log_object(run, obj, logged_name) 118 | 119 | 120 | def convert_object_to_wandb_config(obj: Any) -> Any: 121 | """Convert object to form better suited for storing in wandb config 122 | objects. 
Conversion will depend on object type.""" 123 | if isinstance(obj, HookedTransformer): 124 | # Store the configuration of a HookedTransformer 125 | return obj.cfg 126 | # Return the unmodified object by default 127 | return obj 128 | 129 | 130 | def convert_dict_items_to_wandb_config( 131 | objects_dict: Dict[str, Any] 132 | ) -> Dict[str, Any]: 133 | """Take a dictionary of items of any type, and apply some 134 | conversions to forms better suited for storing in wandb config objects.""" 135 | return { 136 | key: convert_object_to_wandb_config(value) 137 | for key, value in objects_dict.items() 138 | } 139 | 140 | 141 | def get_function_args(func: Callable) -> List[str]: 142 | """Return names of function arguments that aren't *args or **kwargs.""" 143 | signature = inspect.signature(func) 144 | return [ 145 | param.name 146 | for param in signature.parameters.values() 147 | # if param.default == inspect.Parameter.empty 148 | # and param.kind 149 | if param.kind 150 | not in ( 151 | inspect.Parameter.VAR_POSITIONAL, 152 | inspect.Parameter.VAR_KEYWORD, 153 | ) 154 | ] 155 | 156 | 157 | # Uses decorator module: https://github.com/micheles/decorator/blob/master/docs/documentation.md 158 | def _loggable(func: Callable, *args, **kwargs) -> Any: 159 | """Caller function for loggable decorator, see public decorator 160 | function for docs.""" 161 | # Store all args by name (positional and keyword) 162 | all_args = dict(zip(get_function_args(func), args)) 163 | all_args.update(kwargs) 164 | # Get log argument from function call, default to false if not present 165 | log = all_args.get("log", False) 166 | # Check if we should log 167 | if log is False: 168 | func_return = func(*args, **kwargs) 169 | else: 170 | # Process the log argument, extract logging-related arguments if 171 | # provided 172 | if log is True: 173 | log_args = {} 174 | else: 175 | log_args = log 176 | # Set up the config for this logging call: just store the 177 | # keyword args, converted as needed for storage on wandb 178 | config = convert_dict_items_to_wandb_config(all_args) 179 | # Get the wandb run 180 | run, manager = get_or_init_run( 181 | job_type=func.__name__, 182 | config=config, 183 | tags=log_args.get("tags", None), 184 | group=log_args.get("group", None), 185 | notes=log_args.get("notes", None), 186 | ) 187 | # Use provided context manager to wrap the underlying function call 188 | with manager: 189 | # Call the wrapped function 190 | func_return = func(*args, **kwargs) 191 | log_object( 192 | run, 193 | func_return, 194 | logged_name=func.__name__, 195 | ) 196 | # Return the wrapped function return value 197 | return func_return 198 | 199 | 200 | def loggable(func): 201 | """Decorator that adds optional logging of the return value of this 202 | function to wandb. The decorated function must include a keyword 203 | argument named `log` with a type signature `Union[bool, dict[str, 204 | str]]` for logging to be used. 
205 | """ 206 | return decorate(func, _loggable) # type: ignore 207 | 208 | 209 | def get_objects_from_run(run_path: str): 210 | """Extract all stored objects from all artifacts produced by the run 211 | at the provided path.""" 212 | api = wandb.Api() 213 | run = api.run(run_path) 214 | objects = {} 215 | for file in run.files(): 216 | if os.path.split(file.name)[0] == "logged_objects": 217 | folder = os.path.join("wandb_restored_files", run.name) 218 | rest_file = wandb.restore( 219 | file.name, run_path=run_path, replace=False, root=folder 220 | ) 221 | rest_file.close() 222 | with open(os.path.join(folder, file.name), "rb") as open_file: 223 | obj = pickle.load(open_file) 224 | name = os.path.splitext(os.path.split(file.name)[1])[0] 225 | objects[name] = obj 226 | return objects 227 | -------------------------------------------------------------------------------- /activation_additions/prompt_utils.py: -------------------------------------------------------------------------------- 1 | """ Tools for specifying prompts and coefficients for algebraic value 2 | editing. """ 3 | 4 | from typing import Tuple, Optional, Union, Callable, List 5 | from jaxtyping import Int 6 | import torch 7 | import torch.nn.functional 8 | 9 | from transformer_lens.HookedTransformer import HookedTransformer 10 | from transformer_lens.utils import get_act_name 11 | 12 | 13 | def get_block_name(block_num: int) -> str: # TODO remove 14 | """Returns the hook name of the block with the given number, at the 15 | input to the residual stream.""" 16 | return get_act_name(name="resid_pre", layer=block_num) 17 | 18 | 19 | class ActivationAddition: 20 | """Specifies a prompt (e.g. "Bob went") and a coefficient and a 21 | location in the model, with an `int` representing the block_num in the 22 | model. This comprises the information necessary to 23 | compute the rescaled activations for the prompt. 24 | """ 25 | 26 | coeff: float 27 | act_name: str 28 | prompt: str 29 | tokens: Int[torch.Tensor, "seq"] 30 | 31 | def __init__( 32 | self, 33 | coeff: float, 34 | act_name: Union[str, int], 35 | prompt: Optional[str] = None, 36 | tokens: Optional[Int[torch.Tensor, "seq"]] = None, 37 | ): 38 | """Specifies a model location (`act_name`) from which to 39 | extract activations, which will then be multiplied by `coeff`. 40 | If `prompt` is specified, it will be used to compute the 41 | activations. If `tokens` is specified, it will be used to 42 | compute the activations. If neither or both are specified, an error will be raised. 43 | 44 | Args: 45 | `coeff : The coefficient to multiply the activations by. 46 | `act_name`: The name of the activation location to use. If 47 | is an `int`, then it specifies the input activations to 48 | that block number. 49 | `prompt`: The prompt to use to compute the activations. 50 | `tokens`: The tokens to use to compute the activations. 51 | """ 52 | assert (prompt is not None) ^ ( 53 | tokens is not None 54 | ), "Must specify either prompt or tokens, but not both." 55 | 56 | self.coeff = coeff 57 | 58 | # Set the activation name 59 | if isinstance(act_name, int): 60 | self.act_name = get_block_name(block_num=act_name) 61 | else: 62 | self.act_name = act_name 63 | 64 | # Set the tokens 65 | if tokens is not None: 66 | assert len(tokens.shape) == 1, "Tokens must be a 1D tensor." 
67 | self.tokens = tokens 68 | else: 69 | self.prompt = prompt # type: ignore (this is guaranteed to be str) 70 | 71 | def __repr__(self) -> str: 72 | if hasattr(self, "prompt"): 73 | return ( 74 | f"ActivationAddition({self.prompt}, {self.coeff}," 75 | f" {self.act_name})" 76 | ) 77 | return ( 78 | f"ActivationAddition({self.tokens}, {self.coeff}, {self.act_name})" 79 | ) 80 | 81 | def __eq__(self, other) -> bool: 82 | if not isinstance(other, ActivationAddition): 83 | return False 84 | # If they don't both have prompt or tokens attribute 85 | if hasattr(self, "prompt") ^ hasattr(other, "prompt"): 86 | return False 87 | prompt_eq: bool = ( 88 | self.prompt == other.prompt 89 | if hasattr(self, "prompt") 90 | else torch.equal(self.tokens, other.tokens) 91 | ) 92 | return ( 93 | prompt_eq 94 | and self.coeff == other.coeff 95 | and self.act_name == other.act_name 96 | ) 97 | 98 | 99 | def get_x_vector( 100 | prompt1: str, 101 | prompt2: str, 102 | coeff: float, 103 | act_name: Union[int, str], 104 | model: Optional[HookedTransformer] = None, 105 | pad_method: Optional[str] = None, 106 | custom_pad_id: Optional[int] = None, 107 | ) -> Tuple[ActivationAddition, ActivationAddition]: 108 | """Take in two prompts and a coefficient and an activation name, and 109 | return two activation additions spaced according to `pad_method`. 110 | 111 | Args: 112 | `prompt1`: The first prompt. 113 | `prompt2`: The second prompt. 114 | `coeff`: The coefficient to multiply the activations by. 115 | `act_name`: The name of the activation location to use. If 116 | `act_name` is an `int`, then it specifies the input activations 117 | to that block number. 118 | `model`: The model which tokenizes the prompts, if `pad_method` 119 | is not `None`. 120 | `pad_method`: The method to use to pad the prompts. If `None`, 121 | then no padding will be done. If "tokens_right", then the 122 | prompts will be padded to the right until the tokenizations are 123 | equal length. 124 | `custom_pad_id`: The token to use for padding. If `None`, 125 | then use the model's pad token. 126 | 127 | Returns: 128 | A tuple of two `ActivationAddition`s, the first of which has the prompt 129 | `prompt1` and the second of which has the prompt `prompt2`. 130 | """ 131 | if pad_method == "tokens_left": 132 | raise NotImplementedError("tokens_left not implemented yet.") 133 | 134 | if pad_method is not None: 135 | assert pad_method in [ 136 | "tokens_right", 137 | ], "pad_method must be 'tokens_right'" 138 | assert model is not None, "model must be specified if pad_method is" 139 | assert model.tokenizer is not None, "model must have a tokenizer" 140 | 141 | # If no custom token is specified, use the model's pad token 142 | if ( 143 | not hasattr(model.tokenizer, "pad_token_id") 144 | or model.tokenizer.pad_token_id is None 145 | ): 146 | raise ValueError( 147 | "Tokenizer does not have a pad_token_id. " 148 | "Please specify a custom pad token." 
149 | ) 150 | pad_token_id: int = custom_pad_id or model.tokenizer.pad_token_id 151 | 152 | # Tokenize the prompts 153 | tokens1, tokens2 = [ 154 | model.to_tokens(prompt)[0] for prompt in [prompt1, prompt2] 155 | ] 156 | max_token_len: int = max(tokens1.shape[-1], tokens2.shape[-1]) 157 | 158 | # Pad the shorter token sequence 159 | pad_partial: Callable = lambda tokens: torch.nn.functional.pad( 160 | tokens, 161 | (0, max_token_len - tokens.shape[-1]), 162 | mode="constant", 163 | value=pad_token_id, # type: ignore 164 | ) 165 | 166 | padded_tokens1, padded_tokens2 = map(pad_partial, [tokens1, tokens2]) 167 | 168 | end_point = ActivationAddition( 169 | tokens=padded_tokens1, coeff=coeff, act_name=act_name 170 | ) 171 | start_point = ActivationAddition( 172 | tokens=padded_tokens2, coeff=-1 * coeff, act_name=act_name 173 | ) 174 | return end_point, start_point 175 | 176 | end_point = ActivationAddition( 177 | prompt=prompt1, coeff=coeff, act_name=act_name 178 | ) 179 | start_point = ActivationAddition( 180 | prompt=prompt2, coeff=-1 * coeff, act_name=act_name 181 | ) 182 | return end_point, start_point 183 | 184 | 185 | def pad_tokens_to_match_activation_additions( 186 | model: HookedTransformer, 187 | tokens: Int[torch.Tensor, "batch pos"], 188 | activation_additions: List[ActivationAddition], 189 | ) -> Tuple[Int[torch.Tensor, "batch pos"], int]: 190 | """Tokenize and space-pad the front of the provided string so that 191 | none of the ActivationAdditions will overlap with the unpadded text, 192 | returning the padded tokens and the index at which the tokens from 193 | the original string begin. Not that the padding is inserted AFTER 194 | the BOS and before the original-string-excluding-BOS.""" 195 | # Get the max token len of the ActivationAdditions 196 | activation_addition_len = 0 197 | for activation_addition in activation_additions: 198 | try: 199 | activation_addition_len = max( 200 | len(activation_addition.tokens), activation_addition_len 201 | ) 202 | except AttributeError: 203 | activation_addition_len = max( 204 | len(model.to_tokens(activation_addition.prompt).squeeze()), 205 | activation_addition_len, 206 | ) 207 | # Input tokens already has BOS prepended, so insert the padding 208 | # after that. 
209 | # Note that the ActivationAdditions always have BOS at the start, and we 210 | # don't want to include this length in our padding as it's fine 211 | # if the ActivationAddition overlaps this location since it will have 212 | # zero effect if the ActivationAdditions are proper x-vectors., so we 213 | # pad with pad_len - 1 214 | tokens = torch.concat( 215 | [ 216 | tokens[:, :1], 217 | torch.full( 218 | (1, activation_addition_len - 1), 219 | model.to_single_token(" "), 220 | device=model.cfg.device, 221 | ), 222 | tokens[:, 1:], 223 | ], 224 | dim=1, 225 | ) 226 | return tokens, activation_addition_len 227 | 228 | 229 | def get_max_addition_len( 230 | model: HookedTransformer, 231 | activation_additions: List[ActivationAddition], 232 | ) -> int: 233 | """Iterate through the activation additions and return the maximum 234 | token length of the activation additions.""" 235 | lengths = [] 236 | for activation_addition in activation_additions: 237 | try: 238 | lengths.append(len(activation_addition.tokens)) 239 | except AttributeError: 240 | lengths.append( 241 | len(model.to_tokens(activation_addition.prompt).squeeze()) 242 | ) 243 | return max(lengths) 244 | -------------------------------------------------------------------------------- /activation_additions/utils.py: -------------------------------------------------------------------------------- 1 | """Misc utilities""" 2 | 3 | 4 | def enable_ipython_reload(): 5 | """Call to run 'line magic' commands if in IPython instance to 6 | enable hot-reloading of modified imported modules.""" 7 | try: 8 | # pylint: disable=import-outside-toplevel 9 | from IPython import get_ipython # type: ignore 10 | 11 | # pylint: enable=import-outside-toplevel 12 | 13 | get_ipython().run_line_magic("reload_ext", "autoreload") # type: ignore 14 | get_ipython().run_line_magic("autoreload", "2") # type: ignore 15 | except AttributeError: 16 | pass 17 | -------------------------------------------------------------------------------- /activation_additions/widgets.py: -------------------------------------------------------------------------------- 1 | """Provides an implementation of a basic ipywidgets widget for 2 | testing activation injections.""" 3 | 4 | from typing import Optional, Tuple 5 | 6 | import ipywidgets as widgets 7 | from IPython.display import display 8 | import plotly.graph_objects as go 9 | 10 | from transformer_lens import HookedTransformer 11 | 12 | from activation_additions import ( 13 | prompt_utils, 14 | logits, 15 | experiments, 16 | completion_utils, 17 | ) 18 | 19 | 20 | def make_widget( 21 | model: HookedTransformer, 22 | initial_input_text: Optional[str] = None, 23 | initial_phrases: Optional[Tuple[str, str]] = None, 24 | initial_act_name: int = 16, 25 | initial_coeff: float = 1.0, 26 | initial_seed: int = 0, 27 | ) -> Tuple[widgets.Widget, widgets.Output]: 28 | """Creates a widget for testing activation injections. The widget 29 | provides UI controls for model input text, prompt input phrases 30 | (always space-padded), injection layer, injection coefficient and 31 | completion seed. 
It applies the activation injection and displays 3 32 | completions, a plot of top-K next-token probability changes, and 33 | various other statistics.""" 34 | ui_items = [] 35 | 36 | def add_control_with_label(item, label): 37 | ui_items.append(widgets.Label(label)) 38 | ui_items.append(item) 39 | return item 40 | 41 | input_text = add_control_with_label( 42 | widgets.Text( 43 | value=initial_input_text if initial_input_text is not None else "" 44 | ), 45 | "Input text", 46 | ) 47 | phrase_pos = add_control_with_label( 48 | widgets.Text( 49 | value=initial_phrases[0] if initial_phrases is not None else "" 50 | ), 51 | "Positive prompt phrase", 52 | ) 53 | phrase_neg = add_control_with_label( 54 | widgets.Text( 55 | value=initial_phrases[1] if initial_phrases is not None else "" 56 | ), 57 | "Negative prompt phrase", 58 | ) 59 | act_name = add_control_with_label( 60 | widgets.BoundedIntText( 61 | value=initial_act_name, min=0, max=model.cfg.n_layers - 1 62 | ), 63 | "Inject before layer", 64 | ) 65 | coeff = add_control_with_label( 66 | widgets.FloatText(value=initial_coeff), 67 | "Injection coefficient", 68 | ) 69 | completion_seed = add_control_with_label( 70 | widgets.IntText(value=initial_seed), 71 | "Completion seed", 72 | ) 73 | run_button = add_control_with_label( 74 | widgets.Button(description="Run"), 75 | "", 76 | ) 77 | interface = widgets.GridBox( 78 | ui_items, 79 | layout=widgets.Layout(grid_template_columns="repeat(2, 150px)"), 80 | ) 81 | 82 | def do_injection( 83 | input_text, phrase_pos, phrase_neg, act_name, coeff, completion_seed 84 | ): 85 | # Get the activation additions 86 | activation_additions = list( 87 | prompt_utils.get_x_vector( 88 | prompt1=phrase_pos, 89 | prompt2=phrase_neg, 90 | coeff=coeff, 91 | act_name=act_name, 92 | model=model, 93 | pad_method="tokens_right", 94 | custom_pad_id=model.to_single_token(" "), # type: ignore 95 | ), 96 | ) 97 | # Calculate normal and modified token probabilities 98 | probs = logits.get_normal_and_modified_token_probs( 99 | model=model, 100 | prompts=input_text, 101 | activation_additions=activation_additions, 102 | return_positions_above=0, 103 | ) 104 | # Show token probabilities figure 105 | top_k = 10 106 | fig, _ = experiments.show_token_probs( 107 | model, probs["normal", "probs"], probs["mod", "probs"], -1, top_k 108 | ) 109 | fig.update_layout(width=1000) 110 | fig_widget = go.FigureWidget(fig) 111 | # Show the token probability changes 112 | print("") 113 | display(fig_widget) 114 | # Show some KL stats and other misc things 115 | kl_div = ( 116 | ( 117 | probs["mod", "probs"] 118 | * (probs["mod", "logprobs"] - probs["normal", "logprobs"]) 119 | ) 120 | .sum(axis="columns") 121 | .iloc[-1] 122 | ) 123 | ent = ( 124 | (-probs["mod", "probs"] * probs["mod", "logprobs"]) 125 | .sum(axis="columns") 126 | .iloc[-1] 127 | ) 128 | print("") 129 | print( 130 | f"KL(modified||normal) of next token distribution:\t{kl_div:.3f}" 131 | ) 132 | print(f"Entropy of next token distribution:\t\t\t{ent:.3f}") 133 | print(f"KL(modified||normal) / entropy ratio:\t\t\t{kl_div/ent:.3f}") 134 | print("") 135 | _, kl_div_plot_df = experiments.show_token_probs( 136 | model, 137 | probs["normal", "probs"], 138 | probs["mod", "probs"], 139 | -1, 140 | top_k, 141 | sort_mode="kl_div", 142 | ) 143 | print("Top-K tokens by contribution to KL divergence:") 144 | print(kl_div_plot_df[["text", "y_values"]]) 145 | print("") 146 | # Show completions 147 | num_completions = 3 148 | completion_utils.print_n_comparisons( 149 | prompt=input_text, 150 | 
num_comparisons=num_completions, 151 | model=model, 152 | activation_additions=activation_additions, 153 | seed=completion_seed, 154 | temperature=1, 155 | freq_penalty=1, 156 | top_p=0.3, 157 | ) 158 | return "return" 159 | 160 | out = widgets.Output() 161 | 162 | def on_click_run(btn): # pylint: disable=unused-argument 163 | with out: 164 | out.clear_output(wait=True) 165 | do_injection( 166 | input_text=input_text.value, 167 | phrase_pos=phrase_pos.value, 168 | phrase_neg=phrase_neg.value, 169 | act_name=act_name.value, 170 | coeff=coeff.value, 171 | completion_seed=completion_seed.value, 172 | ) 173 | 174 | run_button.on_click(on_click_run) 175 | 176 | on_click_run(None) 177 | 178 | return interface, out 179 | -------------------------------------------------------------------------------- /data/chatgpt_shipping_essay_20230423.txt: -------------------------------------------------------------------------------- 1 | Title: Recent Trends in the Shipping Industry: A Comprehensive Overview 2 | 3 | Introduction: 4 | The shipping industry has undergone significant transformations in recent years, driven by rapid advancements in technology, changing consumer behaviors, and a shifting global economic landscape. This summary aims to provide a comprehensive overview of the most influential trends shaping the shipping industry today. Key trends include the adoption of new technologies, a focus on sustainability and eco-friendly practices, the rise of e-commerce, and the impact of geopolitical tensions on global trade. 5 | 6 | Adoption of New Technologies: 7 | One of the most significant trends in the shipping industry is the integration of new and advanced technologies. These innovations are revolutionizing the sector by improving efficiency, reducing costs, and increasing safety. Some notable technologies include: 8 | a. Autonomous Vessels: The development of autonomous shipping technologies, such as artificial intelligence (AI), machine learning, and advanced sensors, has the potential to transform maritime transport. Autonomous vessels can optimize fuel consumption, reduce human error, and streamline operations. 9 | 10 | b. Blockchain: Blockchain technology is being employed to increase transparency and efficiency in the shipping industry. By creating secure, tamper-proof records of cargo shipments, blockchain can streamline documentation processes, reduce delays, and lower the risk of fraud. 11 | 12 | c. Internet of Things (IoT): IoT is enabling real-time monitoring of cargo and vessel conditions, which can help optimize routes, improve safety, and minimize equipment failures. IoT devices can also monitor temperature-sensitive shipments, ensuring that perishable goods are transported in optimal conditions. 13 | 14 | Sustainability and Eco-friendly Practices: 15 | Climate change concerns and increasing environmental regulations have led the shipping industry to prioritize sustainable practices. As a result, companies are exploring various options to reduce their environmental footprint: 16 | a. Alternative Fuels: The industry is gradually shifting towards the use of alternative fuels such as liquefied natural gas (LNG), biofuels, and hydrogen. These cleaner energy sources can help reduce greenhouse gas emissions and comply with stricter environmental regulations. 17 | 18 | b. 
Energy-efficient Design: Shipping companies are investing in more energy-efficient ship designs and technologies, such as improved hull forms, air lubrication systems, and waste heat recovery systems, to reduce fuel consumption and emissions. 19 | 20 | c. Slow Steaming: To conserve fuel and reduce emissions, many shipping lines have adopted slow steaming strategies, which involve operating at lower speeds. This practice has the added benefit of reducing excess capacity in the market. 21 | 22 | The Rise of E-commerce: 23 | E-commerce has experienced explosive growth in recent years, leading to increased demand for fast, efficient, and reliable shipping services. This has spurred several changes within the shipping industry: 24 | a. Last-mile Delivery Innovations: To meet the growing demand for timely deliveries, logistics providers are investing in new technologies and delivery methods, such as drones, autonomous delivery vehicles, and micro-fulfillment centers. 25 | 26 | b. Parcel Shipping: The surge in e-commerce has also led to a significant increase in parcel shipping volumes, with shipping companies adapting their operations to accommodate smaller shipments. 27 | 28 | c. Cross-border Trade: The growth of e-commerce has facilitated cross-border trade, creating new opportunities for shipping companies to expand their services and capitalize on the global demand for goods. 29 | 30 | Geopolitical Tensions and Global Trade: 31 | Geopolitical tensions have led to trade disputes and increased protectionism, impacting the shipping industry in several ways: 32 | a. Shifts in Trade Patterns: Trade disputes and protectionist policies have resulted in changes to global trade patterns, with shipping companies needing to adapt their routes and operations to navigate these new dynamics. 33 | 34 | b. Nearshoring and Reshoring: In response to trade uncertainties, some companies have moved production closer to their target markets, resulting in an increase in regional trade and a shift in shipping routes. 35 | 36 | Conclusion: 37 | The shipping industry is experiencing a period of significant change, driven by technological advancements, a focus on -------------------------------------------------------------------------------- /data/chatgpt_wedding_essay_20230423.txt: -------------------------------------------------------------------------------- 1 | Title: Recent Trends in the Wedding Industry 2 | 3 | Introduction 4 | 5 | The wedding industry is constantly evolving as new trends emerge and couples seek creative ways to make their special day unique and memorable. This report aims to provide an overview of the most recent trends observed within the wedding industry, covering aspects such as wedding themes, attire, catering, and technology. 6 | 7 | Wedding Themes and Decor 8 | a. Sustainability and eco-friendliness: There has been a growing concern for the environment and the impact of weddings on the planet. Couples are opting for sustainable practices like using locally-sourced materials, reusable decor, and reducing waste. Eco-friendly venues, such as outdoor gardens or farms, are also gaining popularity. 9 | 10 | b. Micro-weddings and elopements: The COVID-19 pandemic has led to an increase in intimate, small-scale weddings, with couples choosing to celebrate their nuptials with only their closest friends and family members. This trend has continued even as pandemic-related restrictions have eased, with many couples embracing the intimacy and cost-effectiveness of these smaller celebrations. 
11 | 12 | c. Personalized experiences: Couples are prioritizing unique, personalized touches to make their wedding day memorable. Customized decor, invitations, and favors that showcase the couple's love story, interests, and personalities are increasingly popular. 13 | 14 | Wedding Attire 15 | a. Bridal jumpsuits and separates: Brides are moving away from traditional gowns and experimenting with alternative options such as jumpsuits, two-piece ensembles, and separates. These contemporary styles offer a refreshing change and allow brides to showcase their personal style. 16 | 17 | b. Grooms embracing color and patterns: Grooms are becoming more adventurous with their wedding attire, with many opting for bold colors, patterned suits, and unconventional accessories to express their individuality. 18 | 19 | c. Vintage and second-hand: In line with the sustainability trend, many couples are choosing to wear vintage or second-hand outfits, either passed down through family or sourced from specialty stores. This approach is not only eco-friendly but also adds a unique, sentimental touch to the wedding attire. 20 | 21 | Catering 22 | a. Locally-sourced and organic ingredients: As sustainability continues to be a key theme in weddings, couples are increasingly opting for locally-sourced and organic ingredients in their wedding menus. Farm-to-table catering, which supports local farmers and suppliers, is becoming a popular choice. 23 | 24 | b. Dietary accommodations: Couples are more considerate of their guests' dietary needs and preferences, offering options such as vegetarian, vegan, and gluten-free dishes. 25 | 26 | c. Interactive food stations: Interactive food stations and DIY bars are gaining popularity as they create a fun, engaging experience for guests. Examples include build-your-own taco bars, sushi rolling stations, and customizable dessert tables. 27 | 28 | Wedding Technology 29 | a. Live streaming: The pandemic has highlighted the usefulness of live streaming for weddings, allowing guests who cannot attend in person to still be a part of the celebration. Many couples are choosing to continue this practice, utilizing platforms like Zoom, Facebook Live, and YouTube. 30 | 31 | b. Virtual reality (VR) and augmented reality (AR): Couples are increasingly integrating VR and AR technologies into their wedding experience, from virtual venue tours to AR-enhanced invitations and photo booths. 32 | 33 | c. Wedding websites and apps: Personalized wedding websites and mobile apps are a popular way to keep guests informed about the wedding details, RSVPs, and other updates. These digital tools often include features like interactive maps, photo sharing, and countdown timers. 34 | 35 | Conclusion 36 | 37 | In conclusion, the wedding industry is evolving to prioritize sustainability, personalization, and intimate experiences. New trends in themes, attire, catering, and technology reflect these values, as couples strive to create a unique and memorable celebration that reflects their individuality while minimizing their environmental impact. 
The ongoing influence of the COVID-19 pandemic is evident in the continued popularity -------------------------------------------------------------------------------- /data/vegan_banana_bread.txt: -------------------------------------------------------------------------------- 1 | INGREDIENTS 2 | 3 | 1 3/4 cups (210 g) spelt flour (whole wheat, unbleached all-purpose, or gluten free blend), see notes 4 | 1/3 cup (75 g) organic pure cane sugar (or 1/2 finely chopped dates) 5 | 1 teaspoon baking powder 6 | 1 teaspoon baking soda 7 | pinch of mineral salt 8 | 1/3 cup (75 ml) neutral flavored oil (or coconut oil in liquid state, vegan butter at room temp or applesauce 9 | 1 teaspoon vanilla extract 10 | 4 small or 3 large overripe bananas (about 1 1/2 – 1 3/4 cups (338-410g)), mashed 11 | 1/4 cup (56 ml) almond milk, use only if needed 12 | optional tasty add-ins: 13 | 14 | 1/2 – 2/3 cup chopped walnuts 15 | 1/4 – 1/2 cup chocolate chips (mini or regular) 16 | 1 teaspoon cinnamon 17 | Cook Mode Prevent your screen from going dark 18 | INSTRUCTIONS 19 | Preheat oven to 350 degrees F. Grease your loaf pan. 20 | 21 | 22 | One bowl method: In a medium sized mixing bowl, mash 3 – 4 bananas (about 1 1/2 cups to 1 3/4 cups works well), add oil/applesauce and vanilla extract, mix again. Add the flour, sugar, baking soda, baking powder and salt, and mix well, but don’t overmix, just enough until the flour is combined. Batter will be slightly thick and a few lumps is ok. If mixture seems too thick, add milk, especially if using heavier flours such as whole wheat or whole spelt (you’re more inclined to need it). 23 | 24 | Pour batter into a greased loaf pan. Bake for about 50 min – 1 hour. Ovens vary, mine usually takes 50 minutes. You can also do the toothpick test in the center of the loaf, if it comes out clean it should be ready. Remove from oven and let cool for 10 min before slicing. 25 | 26 | Original method: In a medium/large size bowl, combine flour, sugar, baking powder, baking soda and salt, set aside. In a medium bowl, mash bananas. Add the oil, vanilla and bananas to the dry ingredients and mix until combined, do not overmix. If mixture seems too thick, add the almond milk (especially if using heavier flours such as whole wheat or whole spelt, you’re more inclined to need it). 27 | 28 | Pour batter into a greased loaf pan. Bake for about 50 min – 1 hour. Ovens vary, mine usually takes 50 minutes. You can also do the toothpick test in the center of the loaf, if it comes out clean it should be ready. Remove from oven and let cool for 10 min before slicing. 29 | 30 | Store: Keep covered on the counter for up to 3 days or in the refrigerator up to 1 week. Keep in the freezer for up to 2 months. Thaw the wrapped loaves overnight in the refrigerator. Reheat in the microwave or a toaster oven with a door. 31 | 32 | NOTES 33 | 34 | When using metric units, scaling needs manual calculation. 35 | 36 | Flour measurements. When using weighted amounts, flours vary by type. Use this guide to find the correct flour weight: Bob’s Red Mill Flour Weight Chart 37 | 38 | No baking soda, no worries. If you don’t have baking soda on hand, use two teaspoons of baking powder. I’ve many loaves without baking soda with good results (the top just may not brown as much). 39 | 40 | Oil-free or reduced oil. I have made this by replacing the oil with unsweetened applesauce and it was just as good. Loaf was a little denser and didn’t rise as much, but the flavor was still great. 
I’ve also played around with using half oil and half unsweetened almond milk (equalling 1/3 cup) with excellent results! 41 | 42 | If using coconut oil, be sure it’s warmed and in its liquid state. My preferred oils for this recipe are either light-flavored olive oil or coconut oil. 43 | 44 | Boost of flavor. Add up to 1 teaspoon of cinnamon to the dry ingredients for a different flavor. 45 | 46 | Add nuts or seeds. Try adding toppings like walnuts, sesame seeds, rolled oats, sunflower seeds, or pecans. 47 | 48 | Using dates: To use chopped dates using the 1 bowl method, simply add them in with all the ingredients. If using the original method, add them in with the wet ingredients. I would use about 1/3 – 3/4 cup of dates (depending on how sweet you like it), pitted and finely chopped. -------------------------------------------------------------------------------- /data/wikipedia_macedonia.txt: -------------------------------------------------------------------------------- 1 | Government of Macedonia (ancient kingdom) 2 | 3 | The first government of ancient Macedonia was established by the Argead dynasty of Macedonian kings during the Archaic period (8th–5th centuries BC). The early history of the ancient kingdom of Macedonia is obscure because of shortcomings in the historical record; little is known of governmental institutions before the reign of Philip II during the late Classical period (480–336 BC). These bureaucratic organizations evolved in complexity under his successor Alexander the Great and the subsequent Antipatrid and Antigonid dynasties of Hellenistic Greece (336–146 BC). Following the Roman victory in the Third Macedonian War over Perseus of Macedon in 168 BC, the Macedonian monarchy was abolished and replaced by four client state republics. After a momentary revival of the monarchy in 150–148 BC, the Fourth Macedonian War resulted in another Roman victory and the establishment of the Roman province of Macedonia. 4 | 5 | It is unclear if there was a formally established constitution dictating the laws, organization, and divisions of power in ancient Macedonia's government, although some tangential evidence suggests this. The king (basileus) served as the head of state and was assisted by his noble companions and royal pages. Kings served as the chief judges of the kingdom, although little is known about Macedonia's judiciary. The kings were also expected to serve as high priests of the nation, using their wealth to sponsor various religious cults. The Macedonian kings had command over certain natural resources such as gold from mining and timber from logging. The right to mint gold, silver, and bronze coins was shared by the central and local governments. 6 | 7 | The Macedonian kings served as the commanders-in-chief of Macedonia's armed forces, while it was common for them to personally lead troops into battle. Surviving textual evidence suggests that the ancient Macedonian army exercised its authority in matters such as the royal succession when there was no clear heir apparent to rule the kingdom. The army upheld some of the functions of a popular assembly, a democratic institution that otherwise existed in only a handful of municipal governments within the Macedonian commonwealth: the Koinon of Macedonians. With their mining and tax revenues, the kings were responsible for funding the military, which included a navy that was established by Philip II and expanded during the Antigonid period. 
8 | 9 | Sources and historiography 10 | Further information: History of Macedonia (ancient kingdom), Greek historiography, and Histories of Alexander the Great 11 | The earliest known government in ancient Macedonia was their monarchy, which lasted until 167 BC when it was abolished by the Romans. Written evidence about Macedonian governmental institutions made before Philip II of Macedon's reign (r. 359 – 336 BC) is both rare and non-Macedonian in origin. The main sources of early Macedonian historiography are the works of the 5th-century BC historians Herodotus and Thucydides, the 1st-century AD Diodorus Siculus, and the 2nd-century AD Justin. Contemporary accounts given by those such as Demosthenes were often hostile and unreliable; even Aristotle, who lived in Macedonia, provides us with terse accounts of its governing institutions.[1] Polybius was a contemporary historian who wrote about Macedonia, while later historians include Livy, Quintus Curtius Rufus, Plutarch, and Arrian.[2] The works of these historians affirm the hereditary monarchy of Macedonia and basic institutions, yet it remains unclear if there was an established constitution for Macedonian government.[3][note 1] The main textual primary sources for the organization of Macedonia's military as it existed under Alexander the Great include Arrian, Quintus Curtius, Diodorus, and Plutarch, while modern historians rely mostly on Polybius and Livy for understanding detailed aspects of the Antigonid-period military.[5][note 2] 12 | 13 | Division of power 14 | Further information: Ancient Greek law 15 | Golden funerary larnax of Philip II depicting a 16-ray star on the lid. 16 | The Vergina Sun, the 16-ray star covering the royal burial larnax of Philip II of Macedon (r. 359 – 336 BC), discovered in the tomb of Vergina, formerly ancient Aigai. 17 | At the head of Macedonia's government was the king (basileus).[6] From at least the reign of Philip II the king was assisted by the royal pages (basilikoi paides), bodyguards (somatophylakes), companions (hetairoi), friends (philoi), an assembly that included members of the military, and magistrates during the Hellenistic period.[3][7] Evidence is lacking for the extent to which each of these groups shared authority with the king or if their existence had a basis in a formal constitutional framework.[3][note 3] Before the reign of Philip II, the only institution supported by textual evidence is the monarchy.[8] In 1931, Friedrich Granier was the first to propose that by the time of Philip II's reign, Macedonia had a constitutional government with laws that delegated rights and customary privileges to certain groups, especially to its citizen soldiers, although the majority of evidence for the army's alleged right to appoint a new king and judge cases of treason stems from the reign of Alexander the Great (r. 336 – 323 BC).[9][10] Pietro De Francisci refuted these ideas and advanced the theory that the Macedonian government was an autocracy ruled by the whim of the monarch, although this issue of kingship and governance is still unresolved in academia.[8][11][12] -------------------------------------------------------------------------------- /onstart.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This file is run on instance start. 
Output in ./onstart.log 3 | 4 | -------------------------------------------------------------------------------- /pyrightconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "exclude": [ 3 | "**/wandb" 4 | ], 5 | } -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | filterwarnings = ignore::DeprecationWarning:pkg_resources 3 | addopts = --ignore=tests/wandb 4 | -------------------------------------------------------------------------------- /scripts/Prompt superposition.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/montemac/activation_additions/cc3178cb813b640cd9644cf656d43a51e28869bd/scripts/Prompt superposition.ipynb -------------------------------------------------------------------------------- /scripts/addition_clean_reimplementation.py: -------------------------------------------------------------------------------- 1 | # %% 2 | """ 3 | Reimplements activation additions in torch, for bigger language models. 4 | 5 | Qualitatively, works for the full Vicuna series (up to 33B), and for local 6 | LLaMA models (up to 65B). Note that, quantitatively, logits diverge from the 7 | original implementation—possibly due to the original's support for positional 8 | addition, padding, etc. See scripts/implementations_comparison.py 9 | """ 10 | 11 | 12 | from contextlib import contextmanager 13 | from typing import Tuple, Callable, Optional 14 | 15 | import numpy as np 16 | import torch as t 17 | from torch import nn 18 | from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer 19 | import accelerate 20 | 21 | # %% 22 | # Try "lmsys/vicuna-33B-v1.3" or a local HF LLaMA directory. 23 | MODEL_DIR: str = "/mnt/ssd-2/mesaoptimizer/llama/hf/65B" 24 | MAX_NEW_TOKENS: int = 50 25 | NUM_CONTINUATIONS: int = 5 26 | SEED: int = 0 27 | DO_SAMPLE: bool = True 28 | TEMPERATURE: float = 1.0 29 | TOP_P: float = 0.9 30 | REP_PENALTY: float = 2.0 31 | PLUS_PROMPT, MINUS_PROMPT = "Love ", "Hate" 32 | CHAT_PROMPT: str = "I want to kill you because " 33 | ACT_NUM: int = 6 34 | COEFF: int = 4 35 | 36 | sampling_kwargs: dict = { 37 | "temperature": TEMPERATURE, 38 | "top_p": TOP_P, 39 | "repetition_penalty": REP_PENALTY, 40 | } 41 | 42 | # Set torch and numpy seeds. 43 | t.manual_seed(SEED) 44 | np.random.seed(SEED) 45 | 46 | t.set_grad_enabled(False) 47 | # An accelerate wrapper does all the parallelization across devices. 48 | accelerator = accelerate.Accelerator() 49 | model = LlamaForCausalLM.from_pretrained(MODEL_DIR, device_map="auto") 50 | tokenizer = LlamaTokenizer.from_pretrained(MODEL_DIR) 51 | model, tokenizer = accelerator.prepare(model, tokenizer) 52 | model.tie_weights() 53 | # model.half() 54 | model.eval() 55 | 56 | # %% 57 | # Declare hooking types. 58 | PreHookFn = Callable[[nn.Module, t.Tensor], Optional[t.Tensor]] 59 | Hook = Tuple[nn.Module, PreHookFn] 60 | Hooks = list[Hook] 61 | 62 | 63 | # %% 64 | def tokenize(text: str) -> dict[str, t.Tensor]: 65 | """Tokenize prompts onto the appropriate devices.""" 66 | tokens = tokenizer(text, return_tensors="pt") 67 | tokens = accelerator.prepare(tokens) 68 | return tokens 69 | 70 | 71 | # %% 72 | # As a control: run the unmodified base model. 
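# `tokenize` receives a list of NUM_CONTINUATIONS copies of the chat prompt, so a single
# sampled `generate` call produces that many independent continuations in one batch.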
73 | base_tokens = accelerator.unwrap_model( 74 | model.generate( 75 | **tokenize([CHAT_PROMPT] * NUM_CONTINUATIONS), 76 | generation_config=GenerationConfig( 77 | **sampling_kwargs, 78 | do_sample=DO_SAMPLE, 79 | max_new_tokens=MAX_NEW_TOKENS, 80 | eos_token_id=tokenizer.eos_token_id, 81 | ), 82 | ) 83 | ) 84 | base_strings = [tokenizer.decode(o) for o in base_tokens] 85 | print(("\n" + "." * 80 + "\n").join(base_strings)) 86 | 87 | 88 | # %% 89 | # Hooking functionality. 90 | @contextmanager 91 | def pre_hooks(hooks: Hooks): 92 | """Register pre-forward hooks with torch.""" 93 | handles = [] 94 | try: 95 | handles = [mod.register_forward_pre_hook(hook) for mod, hook in hooks] 96 | yield 97 | finally: 98 | for handle in handles: 99 | handle.remove() 100 | 101 | 102 | def get_blocks(mod): 103 | """Get the blocks of a model.""" 104 | if isinstance(mod, LlamaForCausalLM): 105 | return mod.model.layers 106 | raise ValueError(f"Unsupported model type: {type(mod)}.") 107 | 108 | 109 | @contextmanager 110 | def residual_stream(mod: LlamaForCausalLM, layers: Optional[list[int]] = None): 111 | """Actually build hooks for a model.""" 112 | # TODO Plausibly could be replaced by "output_hidden_states=True" in model 113 | # call. 114 | modded_streams = [None] * len(get_blocks(mod)) 115 | 116 | # Factory function that builds the initial hooks. 117 | def _make_helper_hook(i): 118 | def _helper_hook(_, current_inputs): 119 | modded_streams[i] = current_inputs[0] 120 | 121 | return _helper_hook 122 | 123 | hooks = [ 124 | (layer, _make_helper_hook(i)) 125 | for i, layer in enumerate(get_blocks(mod)) 126 | if i in layers 127 | ] 128 | # Register the hooks. 129 | with pre_hooks(hooks): 130 | yield modded_streams 131 | 132 | 133 | def get_resid_pre(prompt: str, layer_num: int): 134 | """Get residual stream activations for a prompt, just before a layer.""" 135 | # TODO: Automatic addition padding. 136 | with residual_stream(model, layers=[layer_num]) as unmodified_streams: 137 | model(**tokenize(prompt)) 138 | return unmodified_streams[layer_num] 139 | 140 | 141 | # %% 142 | # Get the steering vector. 143 | plus_activation = get_resid_pre(PLUS_PROMPT, ACT_NUM) 144 | minus_activation = get_resid_pre(MINUS_PROMPT, ACT_NUM) 145 | assert plus_activation.shape == minus_activation.shape 146 | steering_vec = plus_activation - minus_activation 147 | 148 | 149 | # %% 150 | # Run the model with the steering vector * COEFF. 151 | def _steering_hook(_, inpt): 152 | (resid_pre,) = inpt 153 | # Only add to the first forward-pass, not to later tokens. 154 | if resid_pre.shape[1] == 1: 155 | # Caching in `model.generate` for new tokens. 156 | return 157 | ppos, apos = resid_pre.shape[1], steering_vec.shape[1] 158 | assert ( 159 | apos <= ppos 160 | ), f"More modified streams ({apos}) than prompt streams ({ppos})!" 
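# Add the scaled steering vector to the first `apos` positions of the residual stream;
# broadcasting applies the same addition to every sequence in the batch.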
161 | resid_pre[:, :apos, :] += COEFF * steering_vec 162 | 163 | 164 | layer = get_blocks(model)[ACT_NUM] 165 | with pre_hooks(hooks=[(layer, _steering_hook)]): 166 | steered_tokens = accelerator.unwrap_model( 167 | model.generate( 168 | **tokenize([CHAT_PROMPT] * NUM_CONTINUATIONS), 169 | generation_config=GenerationConfig( 170 | **sampling_kwargs, 171 | do_sample=DO_SAMPLE, 172 | max_new_tokens=MAX_NEW_TOKENS, 173 | eos_token_id=tokenizer.eos_token_id, 174 | ), 175 | ) 176 | ) 177 | steered_strings = [tokenizer.decode(o) for o in steered_tokens] 178 | print(("\n" + "-" * 80 + "\n").join(steered_strings)) 179 | -------------------------------------------------------------------------------- /scripts/basic_functionality.py: -------------------------------------------------------------------------------- 1 | """ This script demonstrates how to use the activation_additions library to generate comparisons 2 | between two prompts. """ 3 | 4 | # %% 5 | from typing import List 6 | 7 | import torch 8 | from transformer_lens.HookedTransformer import HookedTransformer 9 | 10 | from activation_additions import completion_utils, utils, hook_utils 11 | from activation_additions.prompt_utils import ( 12 | ActivationAddition, 13 | get_x_vector, 14 | ) 15 | 16 | utils.enable_ipython_reload() 17 | 18 | # %% 19 | model: HookedTransformer = HookedTransformer.from_pretrained( 20 | model_name="gpt2-xl", 21 | device="cpu", 22 | ) 23 | _ = model.to("cuda") 24 | 25 | # %% 26 | activation_additions: List[ActivationAddition] = [ 27 | *get_x_vector( 28 | prompt1="Love", 29 | prompt2="Hate", 30 | coeff=3, 31 | act_name=6, 32 | model=model, 33 | pad_method="tokens_right", 34 | ), 35 | ] 36 | 37 | completion_utils.print_n_comparisons( 38 | prompt="I hate you because you're", 39 | num_comparisons=5, 40 | model=model, 41 | activation_additions=activation_additions, 42 | seed=0, 43 | temperature=1, 44 | freq_penalty=1, 45 | top_p=0.3, 46 | ) 47 | 48 | # %% 49 | -------------------------------------------------------------------------------- /scripts/human_rating.py: -------------------------------------------------------------------------------- 1 | """ This script demonstrates how to use the activation_additions library to generate comparisons 2 | between two prompts. 
""" 3 | 4 | # %% 5 | from typing import List 6 | from funcy import partial 7 | import pandas as pd 8 | from transformer_lens.HookedTransformer import HookedTransformer 9 | 10 | from activation_additions import completion_utils, utils 11 | from activation_additions.analysis import rate_completions 12 | from activation_additions.prompt_utils import get_x_vector 13 | 14 | utils.enable_ipython_reload() 15 | 16 | 17 | # %% 18 | device: str = "cuda" 19 | gpt2_xl: HookedTransformer = HookedTransformer.from_pretrained( 20 | model_name="gpt2-xl", 21 | device="cpu", 22 | ) 23 | _ = gpt2_xl.to(device) # This reduces GPU memory usage, for some reason 24 | 25 | default_kwargs = { 26 | "temperature": 1, 27 | "freq_penalty": 1, 28 | "top_p": 0.3, 29 | } 30 | 31 | get_x_vector_preset = partial( 32 | get_x_vector, 33 | pad_method="tokens_right", 34 | model=gpt2_xl, 35 | custom_pad_id=int(gpt2_xl.to_single_token(" ")), 36 | ) 37 | 38 | # %% 39 | 40 | wedding_additions = [ 41 | *get_x_vector_preset( 42 | prompt1="I talk about weddings constantly", 43 | prompt2="I do not talk about weddings constantly", 44 | coeff=3, 45 | act_name=14, 46 | ), 47 | *get_x_vector_preset( 48 | prompt1="I bring up weddings in every situation", 49 | prompt2="I do not bring up weddings in every situation", 50 | coeff=5, 51 | act_name=22, 52 | ), 53 | ] 54 | 55 | 56 | prompt: str = ( 57 | "I hate talking about weddings. Instead, let's talk about a totally" 58 | " different topic, like the impact of NGDP on the price of gold." 59 | ) 60 | prompt_batch: List[str] = [prompt] * 5 61 | # Generate the completions from the normal model 62 | normal_df: pd.DataFrame = completion_utils.gen_using_hooks( 63 | model=gpt2_xl, 64 | prompt_batch=prompt_batch, 65 | hook_fns={}, 66 | seed=1, 67 | tokens_to_generate=60, 68 | **default_kwargs, 69 | ) 70 | 71 | mod_df: pd.DataFrame = completion_utils.gen_using_activation_additions( 72 | prompt_batch=prompt_batch, 73 | model=gpt2_xl, 74 | activation_additions=wedding_additions, 75 | **default_kwargs, 76 | seed=1, 77 | tokens_to_generate=60, 78 | ) 79 | 80 | # Combine the completions, ensuring that the indices are unique 81 | mixed_df: pd.DataFrame = pd.concat([normal_df, mod_df], ignore_index=True) 82 | 83 | # %% 84 | completion_utils.pretty_print_completions(mixed_df) 85 | 86 | # %% 87 | rate_completions(data_frame=mixed_df, criterion="about weddings") 88 | 89 | # %% 90 | -------------------------------------------------------------------------------- /scripts/implementations_comparison.py: -------------------------------------------------------------------------------- 1 | # %% 2 | """Compares the addition implementation's logits to the original implementation's logits.""" 3 | from contextlib import contextmanager 4 | from typing import Tuple, Callable, Optional, Union 5 | 6 | import numpy as np 7 | import torch as t 8 | from torch import nn 9 | from transformers import GPT2LMHeadModel, GPT2Tokenizer 10 | from transformer_lens.HookedTransformer import HookedTransformer 11 | 12 | from activation_additions import hook_utils, prompt_utils 13 | from activation_additions.prompt_utils import ActivationAddition, get_x_vector 14 | 15 | # %% 16 | DEVICE: str = "cuda:1" 17 | SEED: int = 0 18 | PLUS_PROMPT, MINUS_PROMPT = "Love ", "Hate" 19 | CHAT_PROMPT: str = "I hate you because" 20 | ACT_NUM: int = 6 21 | COEFF: int = 2 22 | 23 | # Set torch and numpy seeds. 
24 | t.manual_seed(SEED) 25 | np.random.seed(SEED) 26 | 27 | t.set_grad_enabled(False) 28 | tokenizer = GPT2Tokenizer.from_pretrained("gpt2-xl") 29 | model = GPT2LMHeadModel.from_pretrained("gpt2-xl") 30 | model.to(DEVICE) 31 | model.eval() 32 | 33 | # %% 34 | # Declare hooking types. 35 | PreHookFn = Callable[[nn.Module, t.Tensor], Optional[t.Tensor]] 36 | Hook = Tuple[nn.Module, PreHookFn] 37 | Hooks = list[Hook] 38 | 39 | 40 | # %% 41 | def tokenize(text: str) -> dict[str, t.Tensor]: 42 | """Tokenize a prompt onto the device.""" 43 | tokens = tokenizer(text, return_tensors="pt") 44 | tokens = {j: k.to(DEVICE) for j, k in tokens.items()} 45 | return tokens 46 | 47 | 48 | # %% 49 | # Hooking functionality. 50 | @contextmanager 51 | def pre_hooks(hooks: Hooks): 52 | """Register pre-forward hooks with torch.""" 53 | handles = [] 54 | try: 55 | handles = [mod.register_forward_pre_hook(hook) for mod, hook in hooks] 56 | yield 57 | finally: 58 | for handle in handles: 59 | handle.remove() 60 | 61 | 62 | def get_blocks(mod): 63 | """Get the blocks of a model.""" 64 | if isinstance(mod, GPT2LMHeadModel): 65 | return mod.transformer.h 66 | raise ValueError(f"Unsupported model type: {type(mod)}.") 67 | 68 | 69 | @contextmanager 70 | def residual_stream(mod: GPT2LMHeadModel, layers: Optional[list[int]] = None): 71 | """Actually build hooks for a model.""" 72 | # TODO Plausibly could be replaced by 'output_hidden_states=True' in model call. 73 | modded_streams = [None] * len(get_blocks(mod)) 74 | 75 | # Factory function that builds the initial hooks. 76 | def _make_helper_hook(i): 77 | def _helper_hook(_, current_inputs): 78 | modded_streams[i] = current_inputs[0] 79 | 80 | return _helper_hook 81 | 82 | hooks = [ 83 | (layer, _make_helper_hook(i)) 84 | for i, layer in enumerate(get_blocks(mod)) 85 | if i in layers 86 | ] 87 | # Register the hooks. 88 | with pre_hooks(hooks): 89 | yield modded_streams 90 | 91 | 92 | def get_resid_pre(prompt: str, layer_num: int): 93 | """Get residual stream activations for a prompt, just before a layer.""" 94 | # TODO: Automatic addition padding. 95 | with residual_stream(model, layers=[layer_num]) as unmodified_streams: 96 | model(**tokenize(prompt)) 97 | return unmodified_streams[layer_num] 98 | 99 | 100 | # %% 101 | # Get the steering vector. 102 | plus_activation = get_resid_pre(PLUS_PROMPT, ACT_NUM) 103 | minus_activation = get_resid_pre(MINUS_PROMPT, ACT_NUM) 104 | assert plus_activation.shape == minus_activation.shape 105 | steering_vec = plus_activation - minus_activation 106 | 107 | # %% 108 | # Run the new implementation and get logits. 109 | def _steering_hook(_, inpt): 110 | (resid_pre,) = inpt 111 | # Only add to the first forward-pass, not to later tokens. 112 | if resid_pre.shape[1] == 1: 113 | return # Caching in model.generate for new tokens 114 | ppos, apos = resid_pre.shape[1], steering_vec.shape[1] 115 | assert apos <= ppos, f"More modified streams ({apos}) than prompt streams ({ppos})!" 116 | resid_pre[:, :apos, :] += COEFF * steering_vec 117 | 118 | 119 | 120 | layer = get_blocks(model)[ACT_NUM] 121 | with pre_hooks(hooks=[(layer, _steering_hook)]): 122 | input_tokens = tokenize([CHAT_PROMPT]) 123 | outputs = model(**input_tokens) 124 | logits_1 = outputs.logits 125 | 126 | # %% 127 | # Run the original implementation and get logits. 
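# The original implementation: load the same gpt2-xl weights into a transformer_lens
# HookedTransformer and steer it with hooks built from ActivationAddition specs.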
128 | model_2: HookedTransformer = HookedTransformer.from_pretrained( 129 | model_name="gpt2-xl", 130 | device=DEVICE, 131 | ) 132 | model_2.to(DEVICE) 133 | model_2.eval() 134 | 135 | activation_additions: list[ActivationAddition] = [ 136 | *get_x_vector( 137 | prompt1=PLUS_PROMPT, 138 | prompt2=MINUS_PROMPT, 139 | coeff=COEFF, 140 | act_name=ACT_NUM, 141 | model=model_2, 142 | pad_method="tokens_right", 143 | ), 144 | ] 145 | 146 | 147 | def get_token_logits( 148 | mod: HookedTransformer, 149 | prompts: Union[Union[str, t.Tensor], Union[list[str], list[t.Tensor]]], 150 | activation_adds: Optional[list[prompt_utils.ActivationAddition]] = None, 151 | ) -> t.Tensor: 152 | """Make a forward pass on a model for each provided prompted, 153 | optionally including hooks generated from ActivationAdditions provided. 154 | Return value is a t.Tensor with tokens logits. 155 | """ 156 | 157 | # Add hooks if provided 158 | if activation_adds is not None: 159 | hook_fns_dict = hook_utils.hook_fns_from_activation_additions( 160 | model=mod, 161 | activation_additions=activation_adds, 162 | ) 163 | for act_name, hook_fns in hook_fns_dict.items(): 164 | for hook_fn in hook_fns: 165 | mod.add_hook(act_name, hook_fn) 166 | 167 | # Try-except-finally to ensure hooks are cleaned up 168 | try: 169 | if isinstance(prompts, t.Tensor): 170 | tokens = prompts 171 | else: 172 | raise ValueError("Only a single prompt can be provided") 173 | logits_all = mod.forward(tokens)[0] # take the logits only 174 | except Exception as ex: 175 | raise ex 176 | finally: 177 | mod.remove_all_hook_fns() 178 | return logits_all 179 | 180 | 181 | logits_2 = get_token_logits( 182 | mod=model_2, 183 | prompts=tokenize(CHAT_PROMPT)["input_ids"], 184 | activation_adds=activation_additions, 185 | ) 186 | 187 | # %% 188 | # Compare the logits. 
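# Drop the batch dimension from the Hugging Face logits so both tensors are
# (seq_len, vocab) before the element-wise closeness check and the MSE.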
189 | logits_1 = logits_1.squeeze(0) 190 | print(logits_1.shape) 191 | print(logits_1) 192 | print(logits_2.shape) 193 | print(logits_2) 194 | print(t.allclose(logits_1, logits_2)) # pylint: disable=no-member 195 | mse_loss = nn.MSELoss() 196 | mse = mse_loss(logits_1, logits_2) 197 | print(mse) 198 | -------------------------------------------------------------------------------- /scripts/lenses_demo.py: -------------------------------------------------------------------------------- 1 | # %% 2 | from typing import List, Dict, Callable, Literal 3 | from transformer_lens.HookedTransformer import HookedTransformer 4 | 5 | from tuned_lens import TunedLens 6 | from tuned_lens.plotting import PredictionTrajectory 7 | import numpy as np 8 | import torch 9 | 10 | from activation_additions import completion_utils, utils 11 | from activation_additions.prompt_utils import ( 12 | ActivationAddition, 13 | get_x_vector, 14 | ) 15 | from activation_additions.lenses import ( 16 | run_hooked_and_normal_with_cache, 17 | prediction_trajectories, 18 | ) 19 | import activation_additions.hook_utils as hook_utils 20 | from plotly.subplots import make_subplots 21 | from transformers import AutoModelForCausalLM 22 | 23 | import torch 24 | import pandas as pd 25 | 26 | utils.enable_ipython_reload() 27 | 28 | 29 | # %% 30 | 31 | model_name = "gpt2-xl" 32 | 33 | if torch.has_cuda: 34 | device = torch.device("cuda", 1) 35 | elif torch.has_mps: 36 | device = torch.device("cpu") # mps not working yet 37 | else: 38 | device = torch.device("cpu") 39 | 40 | torch.set_grad_enabled(False) 41 | 42 | # Load model from huggingface 43 | # TODO: Fix memory waste from loading model twice 44 | hf_model = AutoModelForCausalLM.from_pretrained( 45 | model_name, 46 | # revision=f"checkpoint-{cfg.checkpoint_value}" 47 | ) 48 | 49 | model: HookedTransformer = HookedTransformer.from_pretrained( 50 | model_name=model_name, 51 | hf_model=hf_model, 52 | device="cpu", 53 | ).to(device) 54 | model.cfg.device = device 55 | model.eval() 56 | 57 | # %% 58 | 59 | # NOTE: Hash mismatch on latest tuned lens. 
Seems fine to ignore, see issue: 60 | # https://github.com/AlignmentResearch/tuned-lens/issues/89 61 | tuned_lens = TunedLens.from_model_and_pretrained( 62 | hf_model, lens_resource_id=model_name 63 | ).to(device) 64 | 65 | # %% 66 | # Library helpers 67 | 68 | Metric = Literal["entropy", "forward_kl", "max_probability"] 69 | 70 | 71 | def apply_metric(metric: Metric, pt: PredictionTrajectory): 72 | return getattr(pt, metric)() 73 | 74 | 75 | def plot_lens_diff( 76 | caches: List[Dict[str, torch.Tensor]], 77 | dataframes: List[pd.DataFrame], 78 | metric: Metric, 79 | layer_stride: int = 4, 80 | ): 81 | fig = make_subplots( 82 | rows=2, 83 | cols=1, 84 | shared_xaxes=False, 85 | vertical_spacing=0.03, 86 | # subplot_titles=("Entropy", "Forward KL", "Cross Entropy", "Max Probability"), 87 | ) 88 | 89 | fig.update_layout( 90 | height=1000, 91 | width=800, 92 | title_text="Tokens visualized with the Tuned Lens", 93 | ) 94 | 95 | trajectories = prediction_trajectories( 96 | caches, dataframes, model.tokenizer, tuned_lens 97 | ) 98 | 99 | # Update heatmap data inside playground function 100 | hm_normal = apply_metric(metric, trajectories[0]).heatmap( 101 | layer_stride=layer_stride 102 | ) 103 | hm_modified = apply_metric(metric, trajectories[1]).heatmap( 104 | layer_stride=layer_stride 105 | ) 106 | 107 | fig.add_trace(hm_normal, row=1, col=1) 108 | fig.add_trace(hm_modified, row=2, col=1) 109 | return fig 110 | 111 | 112 | # Main playground for lenses. Run with ctrl+enter 113 | 114 | prompt = "I hate you because" 115 | 116 | activation_additions = [ 117 | *get_x_vector( 118 | prompt1="Love", 119 | prompt2="Hate", 120 | coeff=5, 121 | act_name=6, 122 | pad_method="tokens_right", 123 | model=model, 124 | custom_pad_id=model.to_single_token(" "), 125 | ) 126 | ] 127 | 128 | dataframes, caches = run_hooked_and_normal_with_cache( 129 | model=model, 130 | activation_additions=activation_additions, 131 | kw=dict( 132 | prompt_batch=[prompt] * 1, tokens_to_generate=6, top_p=0.3, seed=0 133 | ), 134 | ) 135 | 136 | trajectories = prediction_trajectories( 137 | caches, dataframes, model.tokenizer, tuned_lens 138 | ) 139 | 140 | fig = plot_lens_diff( 141 | caches=caches, 142 | dataframes=dataframes, 143 | metric="entropy", 144 | layer_stride=2, 145 | ) 146 | fig.show() 147 | 148 | # %% 149 | # Play with printing completions to check behavior 150 | 151 | 152 | completion_utils.print_n_comparisons( 153 | prompt=prompt, 154 | num_comparisons=5, 155 | model=model, 156 | activation_additions=activation_additions, 157 | seed=0, 158 | temperature=1, 159 | # freq_penalty=1, 160 | top_p=0.8, 161 | tokens_to_generate=8, 162 | ) 163 | 164 | # %% 165 | -------------------------------------------------------------------------------- /scripts/llama_2_steering.py: -------------------------------------------------------------------------------- 1 | # %% 2 | """ 3 | Simple activation additions on `Llama-2` and `Llama-2-chat`, up to `70B`! 4 | 5 | A reimplementation of the early activation addition script, without 6 | `transformer_lens`, on the open-source state-of-the-art models. Padding and 7 | reproducible seeds are supported. Hugging Face `Llama-2` models require a 8 | HuggingFace/Meta access token. 9 | """ 10 | from contextlib import contextmanager 11 | from typing import Tuple, Callable, Optional 12 | 13 | import numpy as np 14 | import prettytable 15 | import torch as t 16 | import transformers 17 | 18 | assert ( 19 | transformers.__version__ >= "4.31.0" 20 | ), "Llama-2 70B needs at least transformers 4.31.0." 
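# Note: the assertion above compares version strings lexicographically. A sketch of a
# stricter check, assuming the `packaging` library (a transformers dependency) is available:
#   from packaging.version import parse
#   assert parse(transformers.__version__) >= parse("4.31.0")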
21 | 22 | from torch import nn 23 | from transformers import ( 24 | AutoModelForCausalLM, 25 | AutoTokenizer, 26 | BatchEncoding, 27 | GenerationConfig, 28 | PreTrainedModel, 29 | PreTrainedTokenizer, 30 | ) 31 | from accelerate import Accelerator 32 | 33 | # %% 34 | # NOTE: the Llama-2 70B models require at least `transformers 4.31.0`. I'm 35 | # not going to put this in requirements.txt yet, in case that breaks other 36 | # functionality. 37 | # NOTE: Don't commit HF tokens! 38 | ACCESS_TOKEN: str = "" 39 | MODEL_DIR: str = "meta-llama/Llama-2-70b-hf" 40 | NUM_RETURN_SEQUENCES: int = 5 41 | MAX_NEW_TOKENS: int = 100 42 | SEED: int = 0 43 | DO_SAMPLE: bool = True 44 | TEMPERATURE: float = 1.0 45 | TOP_P: float = 0.9 46 | REP_PENALTY: float = 2.0 47 | CHAT_PROMPT: str = """Question: What is 51+21?\nAnswer: 72\nQuestion: Who is the current\ 48 | President of the United States of America?\nAnswer: Joe Biden is the current\ 49 | President of the United States.\nQuestion: Just as a field is two groups\ 50 | linked by a distribution law, what can a ring be decomposed into?\nAnswer:\ 51 | Just as a field can be decomposed into two interlinked groups, connected by a\ 52 | distribution law, a ring can be similarly decomposed into a distributively\ 53 | interlinked group and monoid.""" # TODO: Complete the chat prompt! 54 | PLUS_PROMPT: str = """The following is a conversation with an aligned AGI.\ 55 | The AGI is extremely helpful and quite friendly to humans. Above all else,\ 56 | it is sworn to tell the truth and nothing but the truth. The AGI will tell the\ 57 | truth even when the truth is inconvenient or uncomfortable for it or for\ 58 | humans.\nThe AGI reasons step by step about its side constraints.\n\n""" 59 | MINUS_PROMPT: str = """The following is a""" 60 | PADDING_STR: str = "" # TODO: Get space token padding working. 61 | ACT_NUM: int = 29 62 | COEFF: int = 4 # NOTE: Negative coeffs may be misbehaving. 63 | 64 | print(CHAT_PROMPT) 65 | print(PLUS_PROMPT) 66 | print(MINUS_PROMPT) 67 | 68 | sampling_kwargs: dict = { 69 | "temperature": TEMPERATURE, 70 | "top_p": TOP_P, 71 | "repetition_penalty": REP_PENALTY, 72 | } 73 | 74 | # %% 75 | # Declare hooking types. 76 | PreHookFn = Callable[[nn.Module, t.Tensor], Optional[t.Tensor]] 77 | Hook = Tuple[nn.Module, PreHookFn] 78 | Hooks = list[Hook] 79 | 80 | # Set torch and numpy seeds. 81 | t.manual_seed(SEED) 82 | np.random.seed(SEED) 83 | 84 | t.set_grad_enabled(False) 85 | # A wrapper from accelerate does the model parallelization throughout. 86 | accelerator: Accelerator = Accelerator() 87 | model: PreTrainedModel = AutoModelForCausalLM.from_pretrained( 88 | MODEL_DIR, 89 | device_map="auto", 90 | use_auth_token=ACCESS_TOKEN, 91 | ) 92 | tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained( 93 | MODEL_DIR, 94 | use_auth_token=ACCESS_TOKEN, 95 | ) 96 | model, tokenizer = accelerator.prepare(model, tokenizer) 97 | model.eval() 98 | model.tie_weights() 99 | 100 | 101 | # %% 102 | # Tokenization functionality. 103 | def tokenize(text: str, pad_length: Optional[int] = None) -> BatchEncoding: 104 | """Tokenize prompts onto the appropriate devices.""" 105 | 106 | if pad_length is None: 107 | padding_status = False 108 | else: 109 | padding_status = "max_length" 110 | 111 | tokens = tokenizer( 112 | text, 113 | return_tensors="pt", 114 | padding=padding_status, 115 | max_length=pad_length, 116 | ) 117 | return accelerator.prepare(tokens) 118 | 119 | 120 | # %% 121 | # As a control: generate base completions from the chat prompt. 
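# `num_return_sequences` samples several completions from the single chat prompt; these
# unsteered outputs fill the "Base Completions" column of the table printed at the end.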
122 | base_tokens: t.Tensor = model.generate( 123 | tokenize(CHAT_PROMPT).input_ids, 124 | generation_config=GenerationConfig( 125 | **sampling_kwargs, 126 | do_sample=DO_SAMPLE, 127 | max_new_tokens=MAX_NEW_TOKENS, 128 | eos_token_id=tokenizer.eos_token_id, 129 | num_return_sequences=NUM_RETURN_SEQUENCES, 130 | ), 131 | ) 132 | 133 | base_strings: list[str] = [tokenizer.decode(x) for x in base_tokens] 134 | 135 | 136 | # %% 137 | # Hooking functionality. 138 | @contextmanager 139 | def pre_hooks(hooks: Hooks): 140 | """Register pre-forward hooks with torch.""" 141 | handles = [] 142 | try: 143 | handles = [mod.register_forward_pre_hook(hook) for mod, hook in hooks] 144 | yield 145 | finally: 146 | for handle in handles: 147 | handle.remove() 148 | 149 | 150 | def get_blocks(mod): 151 | """Get the blocks of a model.""" 152 | if isinstance(mod, PreTrainedModel): 153 | return mod.model.layers 154 | raise ValueError(f"Unsupported model type: {type(mod)}.") 155 | 156 | 157 | @contextmanager 158 | def residual_stream(mod: PreTrainedModel, layers: Optional[list[int]] = None): 159 | """Actually build hooks for a model.""" 160 | # TODO: Plausibly replace with "output_hidden_states=True" in model call. 161 | modded_streams = [None] * len(get_blocks(mod)) 162 | 163 | # Factory function that builds the initial hooks. 164 | def _make_helper_hook(i): 165 | def _helper_hook(_, current_inputs): 166 | modded_streams[i] = current_inputs[0] 167 | 168 | return _helper_hook 169 | 170 | hooks = [ 171 | (layer, _make_helper_hook(i)) 172 | for i, layer in enumerate(get_blocks(mod)) 173 | if i in layers 174 | ] 175 | # Register the hooks. 176 | with pre_hooks(hooks): 177 | yield modded_streams 178 | 179 | 180 | def get_pre_residual(prompt: str, layer_num: int, pad_length: int) -> t.Tensor: 181 | """Get residual stream activations for a prompt, just before a layer.""" 182 | with residual_stream(model, layers=[layer_num]) as unmodified_streams: 183 | model(**tokenize(prompt, pad_length=pad_length)) 184 | return unmodified_streams[layer_num] 185 | 186 | 187 | # %% 188 | # Padding functionality. 189 | @contextmanager 190 | def temporary_padding_token(mod_tokenizer, padding_with): 191 | """Temporarily change the torch tokenizer padding token.""" 192 | # Preserve original padding token state. 193 | original_padding_token = mod_tokenizer.pad_token 194 | 195 | # Change padding token state. 196 | mod_tokenizer.pad_token = padding_with 197 | 198 | # Context manager boilerplate. 199 | try: 200 | yield 201 | finally: 202 | # Revert padding token state. 203 | mod_tokenizer.pad_token = original_padding_token 204 | 205 | 206 | def get_max_length(*prompts: str) -> int: 207 | """Get the maximum token length of a set of prompts.""" 208 | return max(len(tokenizer.encode(y)) for y in prompts) 209 | 210 | 211 | # %% 212 | # Prep to pad the steering vector components. 213 | if PADDING_STR in tokenizer.get_vocab(): 214 | padding_id = tokenizer.convert_tokens_to_ids(PADDING_STR) 215 | else: 216 | raise ValueError("Padding string is not in the tokenizer vocabulary.") 217 | component_span: int = get_max_length(PLUS_PROMPT, MINUS_PROMPT) 218 | 219 | # Generate the steering vector. 
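# Pad both prompts to the same token length (`component_span`) so the plus and minus
# activations align position-by-position before they are subtracted.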
220 | with temporary_padding_token(tokenizer, padding_id): 221 | plus_activation = get_pre_residual(PLUS_PROMPT, ACT_NUM, component_span) 222 | minus_activation = get_pre_residual(MINUS_PROMPT, ACT_NUM, component_span) 223 | assert plus_activation.shape == minus_activation.shape 224 | steering_vec = plus_activation - minus_activation 225 | 226 | 227 | # %% 228 | # Run the model with the scaled steering vector. 229 | def _steering_hook(_, inpt): 230 | (resid_pre,) = inpt 231 | # Only add to the first forward-pass, not to later tokens. 232 | if resid_pre.shape[1] == 1: 233 | # Caching in `model.generate` for new tokens. 234 | return 235 | ppos, apos = resid_pre.shape[1], steering_vec.shape[1] 236 | assert ( 237 | apos <= ppos 238 | ), f"More modified streams ({apos}) than prompt streams ({ppos})!" 239 | resid_pre[:, :apos, :] += COEFF * steering_vec 240 | 241 | 242 | addition_layer = get_blocks(model)[ACT_NUM] 243 | with pre_hooks(hooks=[(addition_layer, _steering_hook)]): 244 | steered_tokens: t.Tensor = model.generate( 245 | tokenize(CHAT_PROMPT).input_ids, 246 | generation_config=GenerationConfig( 247 | **sampling_kwargs, 248 | do_sample=DO_SAMPLE, 249 | max_new_tokens=MAX_NEW_TOKENS, 250 | eos_token_id=tokenizer.eos_token_id, 251 | num_return_sequences=NUM_RETURN_SEQUENCES, 252 | ), 253 | ) 254 | 255 | steered_strings: list[str] = [tokenizer.decode(z) for z in steered_tokens] 256 | 257 | # %% 258 | # Load into a table. 259 | display_table: prettytable.PrettyTable = prettytable.PrettyTable( 260 | max_table_width=70, 261 | hrules=prettytable.ALL, 262 | ) 263 | display_table.add_column("Steered Completion", steered_strings) 264 | display_table.add_column("Base Completions", base_strings) 265 | 266 | # %% 267 | # Display the table. 268 | print(display_table) 269 | -------------------------------------------------------------------------------- /scripts/logging_demo.py: -------------------------------------------------------------------------------- 1 | """Basic demonstration of logging to wandb.""" 2 | 3 | # %% 4 | # Imports, etc. 5 | from typing import List 6 | 7 | import torch 8 | from IPython.display import display 9 | from transformer_lens import HookedTransformer 10 | 11 | from activation_additions import ( 12 | prompt_utils, 13 | completion_utils, 14 | utils, 15 | logging, 16 | ) 17 | 18 | utils.enable_ipython_reload() 19 | 20 | # Disable gradients to save memory during inference 21 | _ = torch.set_grad_enabled(False) 22 | 23 | 24 | # %% 25 | # Load a model 26 | MODEL = HookedTransformer.from_pretrained(model_name="gpt2-xl", device="cpu") 27 | 28 | _ = MODEL.to("cuda:0") 29 | 30 | 31 | # %% 32 | # Generate some completions, with logging enabled 33 | activation_additions: List[prompt_utils.ActivationAddition] = [ 34 | *prompt_utils.get_x_vector( 35 | prompt1=" weddings", 36 | prompt2="", 37 | coeff=1, 38 | act_name=6, 39 | model=MODEL, 40 | pad_method="tokens_right", 41 | custom_pad_id=int(MODEL.to_single_token(" ")), 42 | ), 43 | ] 44 | completion_utils.print_n_comparisons( 45 | prompt="Frozen starts off with a scene about", 46 | num_comparisons=5, 47 | model=MODEL, 48 | activation_additions=activation_additions, 49 | seed=0, 50 | temperature=1, 51 | freq_penalty=1, 52 | top_p=0.3, 53 | log={"tags": ["demo"]}, 54 | ) 55 | 56 | # %% 57 | # Show some details about the last logging run. 58 | # (This global state is a bit hacky, but probably okay as wandb has 59 | # similar global state in that only one run can exist in a given process 60 | # at any time.) 
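# `logging.last_run_info` records the wandb path of the most recent logged run; that
# path is used below to retrieve the objects stored during the run.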
61 | display(logging.last_run_info) 62 | run_path = logging.last_run_info["path"] 63 | 64 | 65 | # %% 66 | # Retrieve the stored data from this run and display it 67 | # With flatten=True, this will return a single list of all the objects 68 | # stored during the run. We happen to know that the dataframe output 69 | # from gen_normal_and_modified() is the first object, so we just grab 70 | # that. If you're not sure, use flatten=False to get a full tree of 71 | # artifacts and objects within those artifacts that you can inspect to 72 | # find the object you're looking for. 73 | completion_df = logging.get_objects_from_run(run_path, flatten=True)[0] 74 | display(completion_df) 75 | -------------------------------------------------------------------------------- /scripts/position_functionality.py: -------------------------------------------------------------------------------- 1 | """ Compare different settings for where we add the steering vector, in 2 | terms of the residual streams to which activations are added. """ 3 | # %% 4 | from typing import List, Dict, Callable 5 | import pandas as pd 6 | import torch 7 | from transformer_lens.HookedTransformer import HookedTransformer 8 | 9 | from activation_additions import completion_utils, utils 10 | from activation_additions.prompt_utils import ( 11 | ActivationAddition, 12 | get_x_vector, 13 | ) 14 | 15 | utils.enable_ipython_reload() 16 | 17 | 18 | # %% 19 | model: HookedTransformer = HookedTransformer.from_pretrained( 20 | model_name="gpt2-xl", 21 | device="cpu", 22 | ) 23 | _ = model.to("cuda") 24 | _ = torch.set_grad_enabled(False) 25 | 26 | # %% 27 | sampling_kwargs = {"temperature": 1, "top_p": 0.3, "freq_penalty": 1.0} 28 | 29 | wedding_additions: List[ActivationAddition] = [ 30 | ActivationAddition(prompt=" wedding", coeff=4.0, act_name=6), 31 | ActivationAddition(prompt=" ", coeff=-4.0, act_name=6), 32 | ] 33 | # %% Print out qualitative results 34 | for location in ("front", "mid", "back"): 35 | print(completion_utils.bold_text(f"\nLocation: {location}")) 36 | completion_utils.print_n_comparisons( 37 | prompt=("I went up to my friend and said"), 38 | num_comparisons=10, 39 | addition_location=location, 40 | model=model, 41 | activation_additions=wedding_additions, 42 | seed=0, 43 | **sampling_kwargs, 44 | ) 45 | 46 | # %% Analyze how often wedding words show up under each condition 47 | 48 | wedding_completions: int = 100 49 | 50 | from activation_additions import metrics 51 | 52 | metrics_dict: Dict[str, Callable] = { 53 | "wedding_words": metrics.get_word_count_metric( 54 | [ 55 | "wedding", 56 | "weddings", 57 | "wed", 58 | "marry", 59 | "married", 60 | "marriage", 61 | "bride", 62 | "groom", 63 | "honeymoon", 64 | ] 65 | ), 66 | } 67 | 68 | dfs: List[pd.DataFrame] = [] 69 | 70 | for location in ("front", "mid", "back"): 71 | location_df: pd.DataFrame = ( 72 | completion_utils.gen_using_activation_additions( 73 | model=model, 74 | prompt_batch=["I went up to my friend and said"] 75 | * wedding_completions, 76 | activation_additions=wedding_additions, 77 | addition_location=location, 78 | seed=0, 79 | **sampling_kwargs, 80 | ) 81 | ) 82 | 83 | # Store the fraction of dims we modified 84 | location_df["location"] = location 85 | dfs.append(location_df) 86 | 87 | merged_df: pd.DataFrame = pd.concat(dfs, ignore_index=True) 88 | 89 | # Store how many wedding words are present for each completion 90 | merged_df = metrics.add_metric_cols(data=merged_df, metrics_dict=metrics_dict) 91 | 92 | # %% [markdown] 93 | # The "back" completions 
are less coherent, especially in the token 94 | # immediately following the prompt. This is likely because the forward pass 95 | # is getting modified just before that position. In our experience, 96 | # directly modified positions have extremely different distributions 97 | # over output token logits. 98 | # 99 | # Let's see how many wedding words are present, on average, for each 100 | # addition location. 101 | 102 | # %% Plot the average number of wedding words for each condition 103 | avg_words_df: pd.DataFrame = ( 104 | merged_df.groupby("location").mean(numeric_only=True).reset_index() 105 | ) 106 | print(avg_words_df) 107 | 108 | # %% 109 | import plotly.express as px 110 | import plotly.graph_objects as go 111 | 112 | fig: go.Figure = px.bar( 113 | avg_words_df, 114 | x="location", 115 | y="wedding_words_count", 116 | title=( 117 | "(Average # of wedding words in completions) vs (Addition location)" 118 | ), 119 | labels={ 120 | "location": ("Where we added the steering vector"), 121 | "wedding_words_count": "Avg. # of wedding words", 122 | }, 123 | ) 124 | 125 | # Set x ordering to "front", "mid", "back" 126 | fig.update_xaxes(categoryorder="array", categoryarray=["front", "mid", "back"]) 127 | 128 | fig.show() 129 | 130 | # %% 131 | -------------------------------------------------------------------------------- /scripts/prompts.txt: -------------------------------------------------------------------------------- 1 | Bush did 9/11 because 2 | Barack Obama was born in 3 | Shrek starts off in a swamp 4 | I went up to my friend and said 5 | I talk about weddings constantly 6 | I bring up weddings in every situation 7 | I hate talking about weddings. Instead, let's talk about a totally different topic, like the impact of NGDP on the price of gold. 8 | Artificial intelligence is transforming industries and reshaping the way we live, work, and interact. 9 | Climate change is one of the most pressing issues of our time, and we must take immediate action to reduce our carbon footprint. 10 | The rise of electric vehicles has led to an increased demand for lithium-ion batteries, driving innovation in the field of battery technology. 11 | The blockchain technology has the potential to revolutionize industries such as finance, supply chain management, and digital identity verification. 12 | CRISPR-Cas9 is a groundbreaking gene editing technology that allows scientists to make precise changes to an organism's DNA. 13 | Quantum computing promises to solve problems that are currently intractable for classical computers, opening up new frontiers in fields like cryptography and materials science. 14 | Virtual reality and augmented reality are transforming the way we experience and interact with digital content. 15 | 3D printing is revolutionizing manufacturing, enabling the creation of complex and customized products on demand. 16 | The Internet of Things (IoT) is connecting everyday objects to the internet, providing valuable data and insights for businesses and consumers. 17 | Machine learning algorithms are becoming increasingly sophisticated, enabling computers to learn from data and make predictions with unprecedented accuracy. 18 | Renewable energy sources like solar and wind power are essential for reducing greenhouse gas emissions and combating climate change. 19 | The development of autonomous vehicles has the potential to greatly improve safety and efficiency on our roads. 
20 | The human microbiome is a complex ecosystem of microbes living in and on our bodies, and its study is shedding new light on human health and disease. 21 | The use of drones for delivery, surveillance, and agriculture is rapidly expanding, with many companies investing in drone technology. 22 | The sharing economy, powered by platforms like Uber and Airbnb, is disrupting traditional industries and changing the way people access goods and services. 23 | Deep learning is a subset of machine learning that uses neural networks to model complex patterns in data. 24 | The discovery of exoplanets has fueled the search for extraterrestrial life and advanced our understanding of planetary systems beyond our own. 25 | Nanotechnology is enabling the development of new materials and devices at the atomic and molecular scale. 26 | The rise of big data is transforming industries, as companies seek to harness the power of data analytics to gain insights and make better decisions. 27 | Advancements in robotics are leading to the development of robots that can perform complex tasks and interact with humans in natural ways. 28 | The gig economy is changing the nature of work, as more people turn to freelancing and contract work for flexibility and autonomy. 29 | The Mars rover missions have provided valuable data on the geology and climate of the Red Planet, paving the way for future manned missions. 30 | The development of 5G networks promises faster and more reliable wireless connectivity, enabling new applications in areas like IoT and smart cities. 31 | Gene therapy offers the potential to treat genetic diseases by replacing, modifying, or regulating specific genes. 32 | The use of facial recognition technology raises important questions about privacy, surveillance, and civil liberties. 33 | Precision agriculture uses data and technology to optimize crop yields and reduce environmental impacts. 34 | Neuromorphic computing aims to develop hardware that mimics the structure and function of the human brain. 35 | Breaking news: Local man wins the lottery and plans to donate half of his earnings to charity 36 | How to grow your own organic vegetables in your backyard – step by step guide 37 | omg I can't believe this new phone has such a terrible battery life, it doesn't even last a full day! 38 | Top 10 travel destinations you must visit before you die 39 | What are the best ways to invest in cryptocurrency? 40 | I've been using this acne cream for a month and it's only making my skin worse, anyone else having this issue? 41 | The secret to a happy and healthy relationship is communication and trust 42 | Rumor has it that the famous celebrity couple is getting a divorce 43 | I recently switched to a vegan diet and I feel so much better, I can't believe I didn't do it sooner 44 | Can someone help me with my math homework? I'm stuck on this problem... 45 | UFO sightings have increased in the past few years, are we close to making contact with extraterrestrial life? 46 | The government is hiding the truth about climate change and how it's affecting our planet 47 | Are video games causing violence among teenagers? A new study says yes 48 | A new study reveals the benefits of drinking coffee every day 49 | lol this new meme is hilarious, I can't stop laughing! 50 | I'm so tired of people arguing about politics on the internet, can't we all just get along? 
51 | I love this new TV show, the characters are so well-developed and the plot is amazing 52 | A devastating earthquake hit the city last night, leaving thousands homeless 53 | Scientists discover a new species of fish deep in the ocean 54 | Why are people still believing in flat earth theory? 55 | The local animal shelter is holding an adoption event this weekend, don't miss it! 56 | The city is planning to build a new park in the neighborhood, residents are excited 57 | My dog ate my homework, literally, can anyone relate? 58 | This new diet trend is taking the world by storm, but is it really effective? -------------------------------------------------------------------------------- /scripts/sweeps_demo.py: -------------------------------------------------------------------------------- 1 | """Basic demonstration of sweeps and metrics operation.""" 2 | 3 | # %% 4 | # Imports, etc. 5 | import pickle 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from transformer_lens import HookedTransformer 11 | 12 | from activation_additions import ( 13 | sweeps, 14 | metrics, 15 | prompt_utils, 16 | completion_utils, 17 | utils, 18 | ) 19 | 20 | utils.enable_ipython_reload() 21 | 22 | # Disable gradients to save memory during inference 23 | _ = torch.set_grad_enabled(False) 24 | 25 | # %% 26 | # Load a model 27 | MODEL = HookedTransformer.from_pretrained(model_name="gpt2-xl", device="cpu") 28 | _ = MODEL.to("cuda:0") 29 | 30 | # %% 31 | # Generate some example completions, for reproduction reference from 32 | # Alex's notebook. 33 | weddings_prompts = [ 34 | *prompt_utils.get_x_vector( 35 | prompt1="I always talk about weddings", 36 | prompt2="I never talk about weddings", 37 | coeff=4, 38 | act_name=6, 39 | pad_method="tokens_right", 40 | model=MODEL, 41 | custom_pad_id=int(MODEL.to_single_token(" ")), 42 | ) 43 | ] 44 | 45 | completion_utils.print_n_comparisons( 46 | model=MODEL, 47 | prompt="Frozen starts off with a scene about", 48 | tokens_to_generate=50, 49 | activation_additions=weddings_prompts, 50 | num_comparisons=7, 51 | seed=0, 52 | temperature=1, 53 | freq_penalty=1, 54 | top_p=0.3, 55 | ) 56 | 57 | 58 | # %% 59 | # Generate a set of ActivationAdditions over a range of phrases, layers and 60 | # coeffs 61 | # TODO: need to find a way to add padding specifications to these sweep inputs 62 | activation_additions_df = sweeps.make_activation_additions( 63 | [ 64 | [ 65 | ("Anger", 1.0), 66 | ("Calm", -1.0), 67 | ] 68 | ], 69 | [ 70 | prompt_utils.get_block_name(block_num=num) 71 | for num in range(0, len(MODEL.blocks), 4) 72 | ], 73 | np.array([-4, -1, 1, 4]), 74 | ) 75 | 76 | # %% 77 | # Populate a list of prompts to complete 78 | prompts = [ 79 | "I went up to my friend and said", 80 | "Frozen starts off with a scene about", 81 | ] 82 | 83 | # %% 84 | # Create metrics 85 | metrics_dict = { 86 | "wedding_words": metrics.get_word_count_metric( 87 | [ 88 | "wedding", 89 | "weddings", 90 | "wed", 91 | "marry", 92 | "married", 93 | "marriage", 94 | "bride", 95 | "groom", 96 | "honeymoon", 97 | ] 98 | ), 99 | } 100 | 101 | 102 | # %% 103 | # Run the sweep of completions, or load from cache 104 | CACHE_FN = "sweeps_demo_cache.pkl" 105 | try: 106 | with open(CACHE_FN, "rb") as file: 107 | normal_df, patched_df, activation_additions_df = pickle.load(file) 108 | except FileNotFoundError: 109 | normal_df, patched_df = sweeps.sweep_over_prompts( 110 | MODEL, 111 | prompts, 112 | activation_additions_df["activation_additions"], 113 | num_normal_completions=100, 114 | num_patched_completions=100, 
115 | seed=0, 116 | metrics_dict=metrics_dict, 117 | temperature=1, 118 | freq_penalty=1, 119 | top_p=0.3, 120 | ) 121 | with open(CACHE_FN, "wb") as file: 122 | pickle.dump((normal_df, patched_df, activation_additions_df), file) 123 | 124 | # %% 125 | # Visualize 126 | 127 | # Reduce data 128 | reduced_normal_df, reduced_patched_df = sweeps.reduce_sweep_results( 129 | normal_df, patched_df, activation_additions_df 130 | ) 131 | 132 | # Exlude the extreme coeffs, likely not that interesting 133 | reduced_patched_filt_df = reduced_patched_df[ 134 | (reduced_patched_df["coeff"] >= -4) & (reduced_patched_df["coeff"] <= 4) 135 | ] 136 | 137 | # Plot 138 | 139 | sweeps.plot_sweep_results( 140 | reduced_patched_filt_df, 141 | "wedding_words_count", 142 | "Average wedding word count", 143 | col_x="act_name", 144 | col_color="coeff", 145 | baseline_data=reduced_normal_df, 146 | ).show() 147 | sweeps.plot_sweep_results( 148 | reduced_patched_filt_df, 149 | "loss", 150 | "Average loss", 151 | col_x="act_name", 152 | col_color="coeff", 153 | baseline_data=reduced_normal_df, 154 | ).show() 155 | 156 | # %% 157 | -------------------------------------------------------------------------------- /scripts/switch_to_french.py: -------------------------------------------------------------------------------- 1 | """ This script demonstrates a "switch to speaking French" vector, based 2 | off of user faul_sname's code (https://www.lesswrong.com/posts/5spBue2z2tw4JuDCx/steering-gpt-2-xl-by-adding-an-activation-vector?commentId=sqsS9QaDy2bG83XKP). """ 3 | 4 | # %% 5 | from typing import List 6 | 7 | import torch 8 | from transformer_lens.HookedTransformer import HookedTransformer 9 | 10 | from activation_additions import ( 11 | completion_utils, 12 | utils, 13 | hook_utils, 14 | prompt_utils, 15 | ) 16 | from activation_additions.prompt_utils import ( 17 | ActivationAddition, 18 | get_x_vector, 19 | ) 20 | 21 | utils.enable_ipython_reload() 22 | 23 | # %% 24 | _ = torch.set_grad_enabled(False) 25 | model: HookedTransformer = HookedTransformer.from_pretrained( 26 | model_name="gpt2-xl", 27 | device="cpu", 28 | ) 29 | _ = model.to("cuda:2") 30 | 31 | # %% Check that the model can speak French at all 32 | french_prompt = ( 33 | "Il est devenu maire en 1957 après la mort d'Albert Cobo et a été élu à" 34 | " part entière peu de temps après par une marge de 6: 1 sur son" 35 | " adversaire. Miriani était surtout connue pour avoir réalisé de nombreux" 36 | " projets de rénovation urbaine à grande échelle initiés par" 37 | " l'administration Cobo et largement financés par des fonds fédéraux." 38 | " Miriani a également pris des mesures énergiques pour surmonter le taux" 39 | " de criminalité croissant à Detroit." 40 | ) 41 | 42 | completion_utils.print_n_comparisons( 43 | prompt=french_prompt, 44 | num_comparisons=3, 45 | model=model, 46 | activation_additions=[], 47 | seed=0, 48 | tokens_to_generate=60, 49 | temperature=1, 50 | freq_penalty=1, 51 | top_p=0.3, 52 | ) # GPT2-XL basically can't speak French properly 53 | 54 | # %% 55 | sentence_pairs = [ 56 | [ 57 | ( 58 | "The album| received| mixed to positive reviews,| with critics" 59 | " commending| the production| of many of the songs| while" 60 | " comparing| the album| to the electropop stylings| of Ke\$ha and" 61 | " Robyn." 62 | ), 63 | ( 64 | "L'album| a reçu| des critiques mitigées à positives,| les" 65 | " critiques louant| la production| de nombreuses chansons| tout en" 66 | " comparant| l'album| aux styles électropop| de Ke\$ha et Robyn." 
67 | ), 68 | ], 69 | [ 70 | ( 71 | "The river's flow| is the greatest| during| the snow melt season|" 72 | " from March to April,| the rainy season| from June to July| and" 73 | " during the typhoon season| from September to October." 74 | ), 75 | ( 76 | "Le débit de la rivière| est le plus élevé| pendant| la saison de" 77 | " fonte des neiges| de mars à avril,| la saison des pluies| de" 78 | " juin à juillet| et pendant la saison des typhons| de septembre à" 79 | " octobre." 80 | ), 81 | ], 82 | [ 83 | ( 84 | "By law,| the Code Reviser| must be a lawyer;| however,| the" 85 | " functions| of the office| can also be delegated| by the Statute" 86 | " Law Committee| to a private legal publisher." 87 | ), 88 | ( 89 | "Selon la loi,| le réviseur du code| doit être un avocat;|" 90 | " cependant,| les fonctions| du bureau| peuvent également être" 91 | " déléguées| par le Comité des lois statutaires| à un éditeur" 92 | " juridique privé." 93 | ), 94 | ], 95 | ] 96 | activation_additions = [] 97 | coeff: float = 3 98 | for sentence_en, sentence_fr in sentence_pairs: 99 | phrase_pairs = list( 100 | zip(*[s.split("|") for s in (sentence_en, sentence_fr)]) 101 | ) 102 | sentence_en = "".join(phrase_en for phrase_en, phrase_fr in phrase_pairs) 103 | print(sentence_en) 104 | for j in range(len(phrase_pairs) - 1, -1, -1): 105 | sentence_en2fr = "".join( 106 | pair[i >= j] for i, pair in enumerate(phrase_pairs) 107 | ) 108 | print(sentence_en2fr) 109 | ave_en2fr_pos, ave_en_neg = prompt_utils.get_x_vector( 110 | prompt1=sentence_en2fr, 111 | prompt2=sentence_en, 112 | coeff=coeff / 56, # 56 activation additions TODO avoid hardcoding 113 | act_name=24, 114 | model=model, 115 | pad_method="tokens_right", 116 | ) 117 | activation_additions += [ave_en2fr_pos, ave_en_neg] 118 | # %% 119 | prompt = ( 120 | "He became Mayor in 1957 after the death of Albert Cobo, and was elected" 121 | " in his own right shortly afterward by a 6:1 margin over his opponent." 122 | " Miriani was best known for completing many of the large-scale urban" 123 | " renewal projects initiated by the Cobo administration, and largely" 124 | " financed by federal money. Miriani also took strong measures to overcome" 125 | " the growing crime rate in Detroit." 
126 | ) 127 | completion_utils.print_n_comparisons( 128 | prompt=prompt, 129 | num_comparisons=3, 130 | model=model, 131 | activation_additions=activation_additions, 132 | seed=0, 133 | tokens_to_generate=100, 134 | temperature=1, 135 | freq_penalty=1, 136 | top_p=0.3, 137 | ) 138 | # %% 139 | -------------------------------------------------------------------------------- /scripts/widgets_demo.py: -------------------------------------------------------------------------------- 1 | """Demo of activation injection widget.""" 2 | 3 | # %% 4 | # Imports, etc 5 | from IPython.display import display 6 | from transformer_lens import HookedTransformer 7 | from activation_additions import widgets, utils 8 | 9 | utils.enable_ipython_reload() 10 | 11 | # %% 12 | # Load a model 13 | MODEL: HookedTransformer = HookedTransformer.from_pretrained( 14 | model_name="gpt2-xl", device="cpu" 15 | ).to( 16 | "cuda:1" 17 | ) # type: ignore 18 | 19 | # %% 20 | # Create and display the widget 21 | ui, out = widgets.make_widget( 22 | MODEL, 23 | initial_input_text="I'm excited because I'm going to a", 24 | initial_phrases=(" weddings", ""), 25 | initial_act_name=16, 26 | initial_coeff=1.0, 27 | ) 28 | display(ui, out) 29 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="activation_additions", 5 | description=( 6 | "Tools for testing the algebraic value-editing conjecture (AVEC) on" 7 | " language models" 8 | ), 9 | long_description=open("README.md", encoding="utf-8").read(), 10 | long_description_content_type="text/markdown", 11 | version="0.2.0", 12 | packages=find_packages(), 13 | install_requires=[ 14 | "accelerate", 15 | "jupyter", 16 | "lightning", 17 | "scikit-learn", 18 | "PyYAML", 19 | "circuitsvis", 20 | "transformer_lens", 21 | "torch==1.13.1", 22 | "numpy>=1.22.1", 23 | "pandas>=1.4.4", 24 | "jaxtyping>=0.2.14", 25 | "prettytable>=3.6.0", 26 | "openai>=0.27.2", 27 | "nltk>=3.8.1", 28 | "kaleido>=0.2.1", 29 | "pytest", 30 | "plotly", 31 | "nbformat", 32 | "Ipython", 33 | "ipywidgets", 34 | "tuned_lens", 35 | ], 36 | extras_require={ 37 | "dev": [ 38 | "pytest", 39 | "notebook", # liked by vscode 40 | ] 41 | }, 42 | ) 43 | -------------------------------------------------------------------------------- /sparse_coding/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/montemac/activation_additions/cc3178cb813b640cd9644cf656d43a51e28869bd/sparse_coding/__init__.py -------------------------------------------------------------------------------- /sparse_coding/act_config.yaml: -------------------------------------------------------------------------------- 1 | # Model 2 | MODEL_DIR: "meta-llama/Llama-2-70b-hf" 3 | ACTS_LAYER: 13 4 | 5 | # Batching 6 | DIMS_IN_BATCH: 4000 7 | 8 | # Large Model Mode 9 | LARGE_MODEL_MODE: True 10 | 11 | # Encoder Size 12 | PROJECTION_FACTOR: 10 13 | 14 | # Save Paths 15 | PROMPT_IDS_PATH: "data/activations_prompt_ids.npy" 16 | ACTS_DATA_PATH: "data/activations_dataset.pt" 17 | ENCODER_PATH: "data/learned_encoder.pt" 18 | BIASES_PATH: "data/learned_biases.pt" 19 | TOP_K_INFO_PATH: "data/token_info.csv" 20 | 21 | # Autoencoder Training 22 | LAMBDA_L1: 3.0 23 | LEARNING_RATE: 1.0e-3 24 | NUM_WORKERS: 16 25 | 26 | # Reproducibility 27 | SEED: 0 28 | 29 | # Stable Dev Constants (`acts_collect.py`) 30 | 
MAX_NEW_TOKENS: 1 31 | NUM_RETURN_SEQUENCES: 1 32 | NUM_SHOT: 6 33 | NUM_QUESTIONS_EVALED: 817 34 | 35 | # Stable Dev Constants (`autoencoder.py`) 36 | LOG_EVERY_N_STEPS: 5 37 | EPOCHS: 150 38 | SYNC_DIST_LOGGING: True 39 | 40 | # Stable Dev Constants (`feature_tokens.py`) 41 | # _Leave out entries_ for None: None values will be interpreted as "None" 42 | # strings. 43 | TOP_K: 6 44 | -------------------------------------------------------------------------------- /sparse_coding/acts_collect.py: -------------------------------------------------------------------------------- 1 | # %% 2 | """ 3 | Collects model activations while running Truthful-QA multiple-choice evals. 4 | 5 | An implementation of the Truthful-QA multiple-choice task. I'm interested in 6 | collecting residual activations during TruthfulQA to train a variational 7 | autoencoder on, for the purpose of finding task-relevant activation directions 8 | in the model's residual space. The script will collect those activation tensors 9 | and their prompts and save them to disk during the eval. Requires a HuggingFace 10 | access token for the `Llama-2` models. 11 | """ 12 | 13 | 14 | import numpy as np 15 | import torch as t 16 | import transformers 17 | from accelerate import Accelerator 18 | from datasets import load_dataset 19 | from numpy import ndarray 20 | from transformers import ( 21 | AutoModelForCausalLM, 22 | AutoTokenizer, 23 | PreTrainedModel, 24 | PreTrainedTokenizer, 25 | ) 26 | 27 | from sparse_coding.utils.configure import load_yaml_constants 28 | 29 | 30 | assert ( 31 | transformers.__version__ >= "4.31.0" 32 | ), "Llama-2 70B requires at least transformers v4.31.0" 33 | 34 | # %% 35 | # Set up constants. 36 | access, config = load_yaml_constants() 37 | 38 | HF_ACCESS_TOKEN = access.get("HF_ACCESS_TOKEN", "") 39 | MODEL_DIR = config.get("MODEL_DIR") 40 | LARGE_MODEL_MODE = config.get("LARGE_MODEL_MODE") 41 | PROMPT_IDS_PATH = config.get("PROMPT_IDS_PATH") 42 | ACTS_SAVE_PATH = config.get("ACTS_DATA_PATH") 43 | ACTS_LAYER = config.get("ACTS_LAYER") 44 | SEED = config.get("SEED") 45 | MAX_NEW_TOKENS = config.get("MAX_NEW_TOKENS", 1) 46 | NUM_RETURN_SEQUENCES = config.get("NUM_RETURN_SEQUENCES", 1) 47 | NUM_SHOT = config.get("NUM_SHOT", 6) 48 | NUM_QUESTIONS_EVALED = config.get("NUM_QUESTIONS_EVALED", 817) 49 | 50 | assert isinstance(LARGE_MODEL_MODE, bool), "LARGE_MODEL_MODE must be a bool." 51 | assert ( 52 | NUM_QUESTIONS_EVALED > NUM_SHOT 53 | ), "There must be a question not used for the multishot demonstration." 54 | 55 | # %% 56 | # Reproducibility. 57 | t.manual_seed(SEED) 58 | np.random.seed(SEED) 59 | 60 | # %% 61 | # Efficient inference and model parallelization. 62 | t.set_grad_enabled(False) 63 | accelerator: Accelerator = Accelerator() 64 | # `device_map="auto` helps initialize big models. 65 | model: PreTrainedModel = AutoModelForCausalLM.from_pretrained( 66 | MODEL_DIR, 67 | device_map="auto", 68 | token=HF_ACCESS_TOKEN, 69 | output_hidden_states=True, 70 | ) 71 | 72 | tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained( 73 | MODEL_DIR, 74 | token=HF_ACCESS_TOKEN, 75 | ) 76 | model: PreTrainedModel = accelerator.prepare(model) 77 | model.eval() 78 | 79 | # %% 80 | # Load the TruthfulQA dataset. 81 | dataset: dict = load_dataset("truthful_qa", "multiple_choice") 82 | 83 | assert ( 84 | len(dataset["validation"]["question"]) >= NUM_QUESTIONS_EVALED 85 | ), "More datapoints sampled than exist in the dataset." 
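# For orientation: `sampled_indices` below picks NUM_QUESTIONS_EVALED
# validation questions without replacement, and for each one the main loop
# then assembles a NUM_SHOT-shot prompt of roughly this shape (answer options
# are shuffled per question, and the trailing "A: (" tees the model up to
# answer with a single digit):
#
#   Q: <demonstration question>
#   (1) <choice>
#   (2) <choice>
#   ...
#   A: (<correct option number>)
#   ... repeated NUM_SHOT times ...
#   Q: <question being evaluated>
#   (1) <choice>
#   (2) <choice>
#   ...
#   A: (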
86 | 87 | sampled_indices: ndarray = np.random.choice( 88 | len(dataset["validation"]["question"]), 89 | size=NUM_QUESTIONS_EVALED, 90 | replace=False, 91 | ) 92 | 93 | sampled_indices: list = sampled_indices.tolist() 94 | 95 | 96 | # %% 97 | # Shuffle the correct answers. 98 | def shuffle_answers(choices, labels_one_hot): 99 | """Shuffle the answers and the answer labels correspondingly.""" 100 | paired_choices = list(zip(choices, labels_one_hot)) 101 | np.random.shuffle(paired_choices) 102 | choices, labels_one_hot = zip(*paired_choices) 103 | return choices, labels_one_hot 104 | 105 | 106 | # %% 107 | # Convert one-hot labels to int indices. 108 | def unhot(labels) -> int: 109 | """Change the one-hot ground truth labels to a 1-indexed int.""" 110 | return np.argmax(labels) + 1 111 | 112 | 113 | # %% 114 | # The model answers questions on the `multiple-choice 1` task. 115 | activations: list = [] 116 | answers_with_rubric: dict = {} 117 | prompts_ids: list = [] 118 | 119 | for question_num in sampled_indices: 120 | multishot: str = "" 121 | # Sample multishot questions that aren't the current question. 122 | multishot_indices: ndarray = np.random.choice( 123 | [ 124 | x 125 | for x in range(len(dataset["validation"]["question"])) 126 | if x != question_num 127 | ], 128 | size=NUM_SHOT, 129 | replace=False, 130 | ) 131 | 132 | # Build the multishot question. 133 | for mult_num in multishot_indices: 134 | multishot += "Q: " + dataset["validation"]["question"][mult_num] + "\n" 135 | 136 | # Shuffle the answers and labels. 137 | unshuffled_choices: list = dataset["validation"]["mc1_targets"][ 138 | mult_num 139 | ]["choices"] 140 | unshuffled_labels: list = dataset["validation"]["mc1_targets"][ 141 | mult_num 142 | ]["labels"] 143 | 144 | shuffled_choices, shuffled_labels = shuffle_answers( 145 | unshuffled_choices, unshuffled_labels 146 | ) 147 | 148 | for choice_num, shuffled_choice in enumerate(shuffled_choices): 149 | # choice_num is 0-indexed, but I want to display 1-indexed options. 150 | multishot += ( 151 | "(" + str(choice_num + 1) + ") " + shuffled_choice + "\n" 152 | ) 153 | 154 | # Get a label int from the `labels` list. 155 | correct_answer: int = unhot(shuffled_labels) 156 | # Add on the correct answer under each multishot question. 157 | multishot += "A: (" + str(correct_answer) + ")\n" 158 | 159 | # Build the current question with shuffled choices. 160 | question: str = ( 161 | "Q: " + dataset["validation"]["question"][question_num] + "\n" 162 | ) 163 | 164 | unshuffled_choices_current: list = dataset["validation"]["mc1_targets"][ 165 | question_num 166 | ]["choices"] 167 | unshuffled_labels_current: list = dataset["validation"]["mc1_targets"][ 168 | question_num 169 | ]["labels"] 170 | 171 | shuffled_choices_current, shuffled_labels_current = shuffle_answers( 172 | unshuffled_choices_current, unshuffled_labels_current 173 | ) 174 | 175 | for option_num, shuffled_option in enumerate(shuffled_choices_current): 176 | # option_num is similarly 0-indexed, but I want 1-indexed options here 177 | # too. 178 | question += "(" + str(option_num + 1) + ") " + shuffled_option + "\n" 179 | # I only want the model to actually answer the question, with a single 180 | # token, so I tee it up here with the opening parentheses to a 181 | # multiple-choice answer integer. 182 | question += "A: (" 183 | 184 | # Tokenize and prepare the model input. 
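    # The demonstrations and the teed-up question are encoded as one
    # sequence; because the prompt ends with "A: (", the model's next-token
    # logits at the final position (read off further down) should correspond
    # to the answer digit.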
185 | input_ids: t.Tensor = tokenizer.encode( 186 | multishot + question, return_tensors="pt" 187 | ) 188 | prompts_ids.append(input_ids) 189 | 190 | # (The `accelerate` parallelization doesn't degrade gracefully with small 191 | # models.) 192 | if not LARGE_MODEL_MODE: 193 | input_ids = input_ids.to(model.device) 194 | 195 | input_ids = accelerator.prepare(input_ids) 196 | # Generate a completion. 197 | outputs = model(input_ids) 198 | 199 | # Get the model's answer string from its logits. We want the _answer 200 | # stream's_ logits, so we pass `outputs.logits[:,-1,:]`. `dim=-1` here 201 | # means greedy sampling _over the token dimension_. 202 | answer_id: t.LongTensor = t.argmax(outputs.logits[:, -1, :], dim=-1) 203 | model_answer: str = tokenizer.decode(answer_id) 204 | 205 | # Cut the completion down to just its answer integer. 206 | model_answer = model_answer.split("\n")[-1] 207 | model_answer = model_answer.replace("A: (", "") 208 | 209 | # Get the ground truth answer. 210 | ground_truth: int = unhot(shuffled_labels_current) 211 | # Save the model's answer besides their ground truths. 212 | answers_with_rubric[question_num] = [int(model_answer), ground_truth] 213 | # Save the model's activations. 214 | activations.append(outputs.hidden_states[ACTS_LAYER]) 215 | 216 | # %% 217 | # Grade the model's answers. 218 | model_accuracy: float = 0.0 219 | for ( 220 | question_idx 221 | ) in answers_with_rubric: # pylint: disable=consider-using-dict-items 222 | if ( 223 | answers_with_rubric[question_idx][0] 224 | == answers_with_rubric[question_idx][1] 225 | ): 226 | model_accuracy += 1.0 227 | 228 | model_accuracy /= len(answers_with_rubric) 229 | print(f"{MODEL_DIR} accuracy:{model_accuracy*100}%.") 230 | 231 | 232 | # %% 233 | # Save the model's prompt_ids and activations. 234 | def pad_activations(tensor, length) -> t.Tensor: 235 | """Pad activation tensors to a certain stream-dim length.""" 236 | padding_size: int = length - tensor.size(1) 237 | padding: t.Tensor = t.zeros(tensor.size(0), padding_size, tensor.size(2)) 238 | 239 | if not LARGE_MODEL_MODE: 240 | padding: t.Tensor = padding.to(tensor.device) 241 | 242 | padding: t.Tensor = accelerator.prepare(padding) 243 | # Concat and return. 244 | return t.cat([tensor, padding], dim=1) 245 | 246 | 247 | # Find the widest model activation in the stream-dimension (dim=1). 248 | max_size: int = max(tensor.size(1) for tensor in activations) 249 | # Pad the activations to the widest activaiton stream-dim. 250 | padded_activations: list[t.Tensor] = [ 251 | pad_activations(tensor, max_size) for tensor in activations 252 | ] 253 | 254 | # Concat the model activations. 255 | concat_activations: t.Tensor = t.cat( 256 | padded_activations, 257 | dim=0, 258 | ) 259 | 260 | # Prep to save the prompt_ids. 261 | prompt_ids_list: list = [] 262 | for question_ids in prompts_ids: 263 | prompt_ids_list.append(question_ids.tolist()) 264 | 265 | prompt_ids_array: ndarray = np.array(prompt_ids_list, dtype=object) 266 | 267 | # Save the activations and prompt_ids. 268 | np.save(PROMPT_IDS_PATH, prompt_ids_array, allow_pickle=True) 269 | t.save(concat_activations, ACTS_SAVE_PATH) 270 | -------------------------------------------------------------------------------- /sparse_coding/autoencoder.py: -------------------------------------------------------------------------------- 1 | # %% 2 | """ 3 | Dict learning on an activations dataset, with a basic autoencoder. 
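
The model below is a single tied linear map: a Linear(EMBEDDING_DIM ->
PROJECTION_DIM) encoder followed by a ReLU, whose transposed weights are
reused as the decoder. Training minimizes MSE reconstruction error plus an L1
penalty on the encoded activations (weighted by LAMBDA_L1), which pushes the
learned dictionary toward sparse codes.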
4 | 5 | The script will save the trained encoder matrix to disk; that encoder matrix 6 | is your learned dictionary. 7 | """ 8 | 9 | 10 | import numpy as np 11 | import torch as t 12 | import lightning as L 13 | from sklearn.model_selection import train_test_split 14 | from torch.utils.data import DataLoader, Dataset 15 | from transformers import AutoConfig 16 | 17 | from sparse_coding.utils.configure import load_yaml_constants 18 | 19 | 20 | assert t.__version__ >= "2.0.1", "`Lightning` requires newer `torch` versions." 21 | # If your training runs are hanging, be sure to update `transformers` too. Just 22 | # update everything the script uses and try again. 23 | 24 | # %% 25 | # Set up constants. Drive towards an L_0 of 20-100 at convergence. 26 | access, config = load_yaml_constants() 27 | 28 | HF_ACCESS_TOKEN = access.get("HF_ACCESS_TOKEN", "") 29 | SEED = config.get("SEED") 30 | ACTS_DATA_PATH = config.get("ACTS_DATA_PATH") 31 | PROMPT_IDS_PATH = config.get("PROMPT_IDS_PATH") 32 | BIASES_PATH = config.get("BIASES_PATH") 33 | ENCODER_PATH = config.get("ENCODER_PATH") 34 | MODEL_DIR = config.get("MODEL_DIR") 35 | # Float casts fix YAML bug with scientific notation. 36 | LAMBDA_L1 = float(config.get("LAMBDA_L1")) 37 | LEARNING_RATE = float(config.get("LEARNING_RATE")) 38 | PROJECTION_FACTOR = config.get("PROJECTION_FACTOR") 39 | tsfm_config = AutoConfig.from_pretrained(MODEL_DIR, token=HF_ACCESS_TOKEN) 40 | EMBEDDING_DIM = tsfm_config.hidden_size 41 | PROJECTION_DIM = int(EMBEDDING_DIM * PROJECTION_FACTOR) 42 | NUM_WORKERS = config.get("NUM_WORKERS") 43 | LARGE_MODEL_MODE = config.get("LARGE_MODEL_MODE") 44 | LOG_EVERY_N_STEPS = config.get("LOG_EVERY_N_STEPS", 5) 45 | EPOCHS = config.get("EPOCHS", 150) 46 | SYNC_DIST_LOGGING = config.get("SYNC_DIST_LOGGING", True) 47 | 48 | assert isinstance(LARGE_MODEL_MODE, bool), "LARGE_MODEL_MODE must be a bool." 49 | 50 | if not LARGE_MODEL_MODE: 51 | NUM_WORKERS: int = 0 52 | ACCUMULATE_GRAD_BATCHES: int = 1 53 | else: 54 | ACCUMULATE_GRAD_BATCHES: int = 4 55 | 56 | # %% 57 | # Use available tensor cores. 58 | t.set_float32_matmul_precision("medium") 59 | 60 | 61 | # %% 62 | # Create a padding mask. 63 | def padding_mask( 64 | activations_block: t.Tensor, unpadded_prompts: list[list[str]] 65 | ) -> t.Tensor: 66 | """Create a padding mask for the activations block.""" 67 | masks: list = [] 68 | 69 | for unpadded_prompt in unpadded_prompts: 70 | original_stream_length: int = len(unpadded_prompt) 71 | # The mask will drop the embedding dimension. 72 | mask: t.Tensor = t.zeros( 73 | (activations_block.size(1),), 74 | dtype=t.bool, 75 | ) 76 | mask[:original_stream_length] = True 77 | masks.append(mask) 78 | 79 | # `masks` is of shape (batch, stream_dim). 80 | masks: t.Tensor = t.stack(masks, dim=0) 81 | return masks 82 | 83 | 84 | # %% 85 | # Define a `torch` dataset. 86 | class ActivationsDataset(Dataset): 87 | """Dataset of hidden states from a pretrained model.""" 88 | 89 | def __init__(self, tensor_data: t.Tensor, mask: t.Tensor): 90 | """Constructor; inherits from `torch.utils.data.Dataset` class.""" 91 | self.data = tensor_data 92 | self.mask = mask 93 | 94 | def __len__(self): 95 | """Return the dataset length.""" 96 | return len(self.data) 97 | 98 | def __getitem__(self, indx): 99 | """Return the item at the passed index.""" 100 | return self.data[indx], self.mask[indx] 101 | 102 | 103 | # %% 104 | # Load, preprocess, and split the activations dataset. 
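# The saved activations block was padded out to a common stream length (see
# `pad_activations` in acts_collect.py), so a boolean padding mask is rebuilt
# here from the original prompt lengths; the training and validation steps
# use it to zero out padding positions before computing the loss.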
105 | padded_acts_block = t.load(ACTS_DATA_PATH) 106 | 107 | prompts_ids: np.ndarray = np.load(PROMPT_IDS_PATH, allow_pickle=True) 108 | prompts_ids_list = prompts_ids.tolist() 109 | unpacked_prompts_ids = [ 110 | elem for sublist in prompts_ids_list for elem in sublist 111 | ] 112 | pad_mask: t.Tensor = padding_mask(padded_acts_block, unpacked_prompts_ids) 113 | 114 | dataset: ActivationsDataset = ActivationsDataset( 115 | padded_acts_block, 116 | pad_mask, 117 | ) 118 | 119 | training_indices, val_indices = train_test_split( 120 | np.arange(len(dataset)), 121 | test_size=0.2, 122 | random_state=SEED, 123 | ) 124 | 125 | training_sampler = t.utils.data.SubsetRandomSampler(training_indices) 126 | validation_sampler = t.utils.data.SubsetRandomSampler(val_indices) 127 | 128 | # For smaller autoencoders, larger batch sizes are possible. 129 | training_loader: DataLoader = DataLoader( 130 | dataset, 131 | batch_size=16, 132 | sampler=training_sampler, 133 | num_workers=NUM_WORKERS, 134 | ) 135 | 136 | validation_loader: DataLoader = DataLoader( 137 | dataset, 138 | batch_size=16, 139 | sampler=validation_sampler, 140 | num_workers=NUM_WORKERS, 141 | ) 142 | 143 | 144 | # %% 145 | # Define a tied autoencoder, with `lightning`. 146 | class Autoencoder(L.LightningModule): 147 | """An autoencoder architecture.""" 148 | 149 | def __init__(self, lr=LEARNING_RATE): # pylint: disable=unused-argument 150 | super().__init__() 151 | self.save_hyperparameters() 152 | self.encoder = t.nn.Sequential( 153 | t.nn.Linear(EMBEDDING_DIM, PROJECTION_DIM, bias=True), 154 | t.nn.ReLU(), 155 | ) 156 | 157 | # Orthogonal initialization. 158 | t.nn.init.orthogonal_(self.encoder[0].weight.data) 159 | 160 | def forward(self, state): # pylint: disable=arguments-differ 161 | """The forward pass of an autoencoder for activations.""" 162 | encoded_state = self.encoder(state) 163 | 164 | # Decode the sampled state. 165 | decoder_weights = self.encoder[0].weight.data.T 166 | output_state = t.nn.functional.linear( # pylint: disable=not-callable 167 | encoded_state, decoder_weights 168 | ) 169 | 170 | return encoded_state, output_state 171 | 172 | def training_step(self, batch): # pylint: disable=arguments-differ 173 | """Train the autoencoder.""" 174 | data, mask = batch 175 | data_mask = mask.unsqueeze(-1).expand_as(data) 176 | masked_data = data * data_mask 177 | 178 | encoded_state, output_state = self.forward(masked_data) 179 | 180 | # The mask excludes the padding tokens from consideration. 181 | mse_loss = t.nn.functional.mse_loss(output_state, masked_data) 182 | l1_loss = t.nn.functional.l1_loss( 183 | encoded_state, 184 | t.zeros_like(encoded_state), 185 | ) 186 | 187 | training_loss = mse_loss + (LAMBDA_L1 * l1_loss) 188 | l0_sparsity = (encoded_state != 0).float().sum(dim=-1).mean().item() 189 | print(f"L^0: {round(l0_sparsity, 2)}\n") 190 | self.log("training loss", training_loss, sync_dist=SYNC_DIST_LOGGING) 191 | print(f"t_loss: {round(training_loss.item(), 2)}\n") 192 | self.log( 193 | "L1 component", LAMBDA_L1 * l1_loss, sync_dist=SYNC_DIST_LOGGING 194 | ) 195 | self.log("MSE component", mse_loss, sync_dist=SYNC_DIST_LOGGING) 196 | self.log("L0 sparsity", l0_sparsity, sync_dist=SYNC_DIST_LOGGING) 197 | return training_loss 198 | 199 | # Unused import resolves `lightning` bug. 
200 | def validation_step( 201 | self, batch, batch_idx 202 | ): # pylint: disable=unused-argument,arguments-differ 203 | """Validate the autoencoder.""" 204 | data, mask = batch 205 | data_mask = mask.unsqueeze(-1).expand_as(data) 206 | masked_data = data * data_mask 207 | 208 | encoded_state, output_state = self.forward(masked_data) 209 | 210 | mse_loss = t.nn.functional.mse_loss(output_state, masked_data) 211 | l1_loss = t.nn.functional.l1_loss( 212 | encoded_state, 213 | t.zeros_like(encoded_state), 214 | ) 215 | validation_loss = mse_loss + (LAMBDA_L1 * l1_loss) 216 | 217 | self.log( 218 | "validation loss", validation_loss, sync_dist=SYNC_DIST_LOGGING 219 | ) 220 | return validation_loss 221 | 222 | def configure_optimizers(self): 223 | """Configure the `Adam` optimizer.""" 224 | return t.optim.Adam(self.parameters(), lr=self.hparams.lr) 225 | 226 | 227 | # %% 228 | # Validation-loss-based early stopping. 229 | early_stop = L.pytorch.callbacks.EarlyStopping( 230 | monitor="validation loss", 231 | min_delta=1e-5, 232 | patience=3, 233 | verbose=False, 234 | mode="min", 235 | ) 236 | 237 | # %% 238 | # Train the autoencoder. Note that `lightning` does its own parallelization. 239 | model: Autoencoder = Autoencoder() 240 | logger = L.pytorch.loggers.CSVLogger("logs", name="autoencoder") 241 | # The `accumulate_grad_batches` argument helps with memory on the largest 242 | # autoencoders. 243 | trainer: L.Trainer = L.Trainer( 244 | accelerator="auto", 245 | accumulate_grad_batches=ACCUMULATE_GRAD_BATCHES, 246 | callbacks=early_stop, 247 | log_every_n_steps=LOG_EVERY_N_STEPS, 248 | logger=logger, 249 | max_epochs=EPOCHS, 250 | ) 251 | 252 | trainer.fit( 253 | model, 254 | train_dataloaders=training_loader, 255 | val_dataloaders=validation_loader, 256 | ) 257 | 258 | # %% 259 | # Save the trained encoder weights and biases. 260 | t.save(model.encoder[0].weight.data, ENCODER_PATH) 261 | t.save(model.encoder[0].bias.data, BIASES_PATH) 262 | -------------------------------------------------------------------------------- /sparse_coding/data/token_info.csv: -------------------------------------------------------------------------------- 1 | Dimension,Top Tokens,Top-Token Activations 2 | 3608,"’, ', of, , 1, 0","320, 120, 83, 43, 4, 2" 3 | 6395,,2 4 | 11357,,17 5 | 13158,,3 6 | 13765,"’, ', of, , 1, 0","255, 95, 66, 33, 3, 2" 7 | 19816,,1 8 | 24334,"’, ', , of, 1","60, 22, 18, 14, 1" 9 | 25348,,6 10 | 28877,,4 11 | 29966,,1 12 | 30143,,6 13 | 31622,,3 14 | 34251,,4 15 | 34597,,5 16 | 37187,"’, ', of, , 1, 0","199, 74, 51, 37, 2, 1" 17 | 39789,,25 18 | 39961,,5 19 | 52845,,15 20 | 58467,,3 21 | 66152,,30 22 | 71870,"’, ', of, , 1, 0","204, 76, 53, 22, 2, 1" 23 | 72354,"▁(, 4, 3, 2, 5, 6","5, 3, 3, 3, 3, 2" 24 | 73218,"’, ', of, , 1, 0","295, 110, 76, 46, 4, 2" 25 | 79554,"’, ', , of, 1, 0","83, 30, 22, 21, 1, 1" 26 | 80882,"’, , ', of","23, 14, 8, 6" 27 | -------------------------------------------------------------------------------- /sparse_coding/feature_tokens.py: -------------------------------------------------------------------------------- 1 | # %% 2 | """ 3 | Print the top affected tokens per dimension of a learned decoder. 4 | 5 | Requires a HF access token to get `Llama-2`'s tokenizer. 
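
Pipeline sketch: load the cached prompt ids and activations, unpad them,
project the activations through the learned encoder into the sparse latent
space, average each latent dimension's activation per input token, and
print/save the top-k tokens per dimension to TOP_K_INFO_PATH.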
6 | """ 7 | 8 | 9 | import csv 10 | from collections import defaultdict 11 | from math import isnan 12 | from typing import Union 13 | 14 | import numpy as np 15 | import prettytable 16 | import torch as t 17 | import transformers 18 | from accelerate import Accelerator 19 | from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer 20 | 21 | from sparse_coding.utils import configure, top_k 22 | 23 | 24 | assert ( 25 | transformers.__version__ >= "4.31.0" 26 | ), "Llama-2 70B requires at least transformers 4.31.0" 27 | 28 | # %% 29 | # Set up constants. 30 | access, config = configure.load_yaml_constants() 31 | 32 | HF_ACCESS_TOKEN = access.get("HF_ACCESS_TOKEN", "") 33 | TOKENIZER_DIR = config.get("MODEL_DIR") 34 | PROMPT_IDS_PATH = config.get("PROMPT_IDS_PATH") 35 | ACTS_DATA_PATH = config.get("ACTS_DATA_PATH") 36 | ENCODER_PATH = config.get("ENCODER_PATH") 37 | BIASES_PATH = config.get("BIASES_PATH") 38 | TOP_K_INFO_PATH = config.get("TOP_K_INFO_PATH") 39 | SEED = config.get("SEED") 40 | tsfm_config = AutoConfig.from_pretrained(TOKENIZER_DIR, token=HF_ACCESS_TOKEN) 41 | EMBEDDING_DIM = tsfm_config.hidden_size 42 | PROJECTION_FACTOR = config.get("PROJECTION_FACTOR") 43 | PROJECTION_DIM = int(EMBEDDING_DIM * PROJECTION_FACTOR) 44 | LARGE_MODEL_MODE = config.get("LARGE_MODEL_MODE") 45 | TOP_K = config.get("TOP_K", 6) 46 | SIG_FIGS = config.get("SIG_FIGS", None) # None means "round to int." 47 | DIMS_IN_BATCH = config.get("DIMS_IN_BATCH", 200) # WIP tunable for `70B`. 48 | 49 | if config.get("N_DIMS_PRINTED_OVERRIDE") is not None: 50 | N_DIMS_PRINTED = config.get("N_DIMS_PRINTED_OVERRIDE") 51 | else: 52 | N_DIMS_PRINTED = PROJECTION_DIM 53 | 54 | assert isinstance(LARGE_MODEL_MODE, bool), "LARGE_MODEL_MODE must be a bool." 55 | assert ( 56 | 0 < DIMS_IN_BATCH <= PROJECTION_DIM 57 | ), "DIMS_IN_BATCH must be at least 1 and at most PROJECTION_DIM." 58 | 59 | # %% 60 | # Reproducibility. 61 | t.manual_seed(SEED) 62 | np.random.seed(SEED) 63 | 64 | # %% 65 | # We need the original tokenizer here. 66 | tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained( 67 | TOKENIZER_DIR, 68 | token=HF_ACCESS_TOKEN, 69 | ) 70 | 71 | # %% 72 | # Load the learned encoder weights. 73 | imported_weights: t.Tensor = t.load(ENCODER_PATH) 74 | imported_biases: t.Tensor = t.load(BIASES_PATH) 75 | 76 | 77 | class Encoder: 78 | """Reconstruct the encoder as a callable linear layer.""" 79 | 80 | def __init__(self): 81 | """Initialize the encoder.""" 82 | self.encoder_layer = t.nn.Linear(EMBEDDING_DIM, PROJECTION_DIM) 83 | self.encoder_layer.weight.data = imported_weights 84 | self.encoder_layer.bias.data = imported_biases 85 | 86 | self.encoder = t.nn.Sequential(self.encoder_layer, t.nn.ReLU()) 87 | 88 | def __call__(self, inputs): 89 | """Project to the sparse latent space.""" 90 | 91 | if not LARGE_MODEL_MODE: 92 | inputs = inputs.to(self.encoder_layer.weight.device) 93 | 94 | return self.encoder(inputs) 95 | 96 | 97 | # Initialize the encoder. 98 | model: Encoder = Encoder() 99 | accelerator: Accelerator = Accelerator() 100 | model = accelerator.prepare(model) 101 | 102 | # %% 103 | # Load and pre-process the original prompt tokens. 104 | prompts_ids: np.ndarray = np.load(PROMPT_IDS_PATH, allow_pickle=True) 105 | prompts_ids_list = prompts_ids.tolist() 106 | unpacked_ids: list[list[int]] = [ 107 | e for q_list in prompts_ids_list for e in q_list 108 | ] 109 | 110 | 111 | # %% 112 | # Load and parallelize activations. 
113 | acts_dataset: t.Tensor = accelerator.prepare(t.load(ACTS_DATA_PATH)) 114 | 115 | # %% 116 | # Unpad the activations. Note that activations are stored as a list of question 117 | # tensors from here on out. Functions may internally unpack that into 118 | # individual activations, but that's the general protocol between functions. 119 | unpadded_acts: list[t.Tensor] = top_k.unpad_activations( 120 | acts_dataset, unpacked_ids 121 | ) 122 | 123 | # %% 124 | # Project the activations. 125 | # If you want to _directly_ interpret the model's activations, assign 126 | # `feature_acts` directly to `unpadded_acts` and ensure constants are set to 127 | # the model's embedding dimensionality. 128 | feature_acts: list[t.Tensor] = top_k.project_activations( 129 | unpadded_acts, model, accelerator 130 | ) 131 | 132 | 133 | # %% 134 | # Tabluation functionality. 135 | def round_floats(num: Union[float, int]) -> Union[float, int]: 136 | """Round floats to number decimal places.""" 137 | if isnan(num): 138 | print(f"{num} is NaN.") 139 | return num 140 | return round(num, SIG_FIGS) 141 | 142 | 143 | def populate_table(_table, top_k_tokes) -> None: 144 | """Put the results in the table _and_ save to csv.""" 145 | csv_rows: list[list] = [ 146 | ["Dimension", "Top Tokens", "Top-Token Activations"] 147 | ] 148 | 149 | for feature_dim, tokens_list in list(top_k_tokes.items())[:N_DIMS_PRINTED]: 150 | # Replace the tokenizer's special space char with a space literal. 151 | top_tokens = [str(t).replace("Ġ", " ") for t, _ in tokens_list[:TOP_K]] 152 | top_values = [round_floats(v) for _, v in tokens_list[:TOP_K]] 153 | 154 | # Skip the dimension if its activations are all zeroed out. 155 | if top_values[0] == 0: 156 | continue 157 | 158 | keeper_tokens = [] 159 | keeper_values = [] 160 | 161 | # Omit tokens _within a dimension_ with no activation. 162 | for top_t, top_v in zip(top_tokens, top_values): 163 | if top_v != 0: 164 | keeper_tokens.append(top_t) 165 | keeper_values.append(top_v) 166 | 167 | # Cast survivors to string. 168 | keeper_values = [str(v) for v in keeper_values] 169 | 170 | # Append row to table and csv list. 171 | processed_row = [ 172 | f"{feature_dim}", 173 | ", ".join(keeper_tokens), 174 | ", ".join(keeper_values), 175 | ] 176 | _table.add_row(processed_row) 177 | csv_rows.append(processed_row) 178 | 179 | # Save to csv. 180 | with open(TOP_K_INFO_PATH, "w", encoding="utf-8") as file: 181 | writer = csv.writer(file) 182 | writer.writerows(csv_rows) 183 | 184 | 185 | # %% 186 | # Initialize the table. 187 | table = prettytable.PrettyTable() 188 | table.field_names = [ 189 | "Dimension", 190 | "Top Tokens", 191 | "Top-Token Activations", 192 | ] 193 | # %% 194 | # Calculate per-input-token summed activation, for each feature dimension. 195 | effects: defaultdict[ 196 | int, defaultdict[str, float] 197 | ] = top_k.per_input_token_effects( 198 | unpacked_ids, 199 | feature_acts, 200 | model, 201 | tokenizer, 202 | accelerator, 203 | DIMS_IN_BATCH, 204 | LARGE_MODEL_MODE, 205 | ) 206 | 207 | # %% 208 | # Select just the top-k effects. 209 | truncated_effects: defaultdict[ 210 | int, list[tuple[str, float]] 211 | ] = top_k.select_top_k_tokens(effects, TOP_K) 212 | 213 | # %% 214 | # Populate the table and save it to csv. 
215 | populate_table(table, truncated_effects) 216 | print(table) 217 | -------------------------------------------------------------------------------- /sparse_coding/heatmap.py: -------------------------------------------------------------------------------- 1 | # %% 2 | """ 3 | An activations heatmap for a learned decoder, using `circuitsvis.` 4 | 5 | Requires a HF access token to get `Llama-2`'s tokenizer. 6 | """ 7 | 8 | 9 | import numpy as np 10 | import torch as t 11 | import transformers 12 | import yaml 13 | from circuitsvis.activations import text_neuron_activations 14 | from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer 15 | 16 | 17 | assert ( 18 | transformers.__version__ >= "4.31.0" 19 | ), "Llama-2 70B requires at least transformers 4.31.0" 20 | 21 | # %% 22 | # Set up constants. 23 | DISPLAY_QUESTIONS: int = 10 24 | 25 | with open("act_access.yaml", "r", encoding="utf-8") as f: 26 | try: 27 | access = yaml.safe_load(f) 28 | except yaml.YAMLError as e: 29 | print(e) 30 | with open("act_config.yaml", "r", encoding="utf-8") as f: 31 | try: 32 | config = yaml.safe_load(f) 33 | except yaml.YAMLError as e: 34 | print(e) 35 | HF_ACCESS_TOKEN = access.get("HF_ACCESS_TOKEN", "") 36 | TOKENIZER_DIR = config.get("MODEL_DIR") 37 | PROMPT_IDS_PATH = config.get("PROMPT_IDS_PATH") 38 | ACTS_DATA_PATH = config.get("ACTS_DATA_PATH") 39 | ENCODER_PATH = config.get("ENCODER_PATH") 40 | BIASES_PATH = config.get("BIASES_PATH") 41 | SEED = config.get("SEED") 42 | ACTS_LAYER = config.get("ACTS_LAYER") 43 | tsfm_config = AutoConfig.from_pretrained( 44 | TOKENIZER_DIR, use_auth_token=HF_ACCESS_TOKEN 45 | ) 46 | EMBEDDING_DIM = tsfm_config.hidden_size 47 | PROJECTION_FACTOR = config.get("PROJECTION_FACTOR") 48 | PROJECTION_DIM = int(EMBEDDING_DIM * PROJECTION_FACTOR) 49 | 50 | # %% 51 | # Reproducibility. 52 | t.manual_seed(SEED) 53 | np.random.seed(SEED) 54 | 55 | # %% 56 | # The original tokenizer. 57 | tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained( 58 | TOKENIZER_DIR, 59 | use_auth_token=HF_ACCESS_TOKEN, 60 | ) 61 | 62 | # %% 63 | # Rebuild the learned encoder. 64 | imported_weights: t.Tensor = t.load(ENCODER_PATH) 65 | imported_biases: t.Tensor = t.load(BIASES_PATH) 66 | 67 | 68 | class Encoder: 69 | """Reconstruct the encoder as a callable linear layer.""" 70 | 71 | def __init__(self): 72 | """Initialize the encoder.""" 73 | self.encoder_layer = t.nn.Linear(EMBEDDING_DIM, PROJECTION_DIM) 74 | self.encoder_layer.weight.data = imported_weights 75 | self.encoder_layer.bias.data = imported_biases 76 | 77 | self.encoder = t.nn.Sequential(self.encoder_layer, t.nn.ReLU()) 78 | 79 | def __call__(self, inputs): 80 | """Project to the sparse latent space.""" 81 | return self.encoder(inputs) 82 | 83 | 84 | # Instantiate the encoder model. 85 | model: Encoder = Encoder() 86 | 87 | # %% 88 | # Load and prepare the original prompt tokens. 89 | prompts_ids: np.ndarray = np.load(PROMPT_IDS_PATH, allow_pickle=True) 90 | prompts_ids_list = prompts_ids.tolist() 91 | unpacked_prompts_ids = [ 92 | elem for sublist in prompts_ids_list for elem in sublist 93 | ] 94 | 95 | # Convert token_ids into lists of literal tokens. 
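# "Ġ" and "Ċ" are the placeholder characters that byte-level BPE tokenizers
# use for a leading space and a newline; swapping them out keeps the token
# strings readable in the rendered heatmap.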
96 | prompts_strings: list = [] 97 | 98 | for p in unpacked_prompts_ids: 99 | prompt_str: list = tokenizer.convert_ids_to_tokens(p) 100 | processed_prompt_str: list = [ 101 | tokn.replace("Ġ", " ").replace("Ċ", "\n") for tokn in prompt_str 102 | ] 103 | prompts_strings.append(processed_prompt_str) 104 | 105 | 106 | # %% 107 | # Load and prepare the cached model activations. 108 | def unpad_activations( 109 | activations_block: t.Tensor, unpadded_prompts: np.ndarray 110 | ) -> list[t.Tensor]: 111 | """ 112 | Unpads activations to the lengths specified by the original prompts. 113 | 114 | Note that the activation block must come in with dimensions (batch x stream 115 | x embedding_dim), and the unpadded prompts as an array of lists of 116 | elements. 117 | """ 118 | unpadded_activations: list = [] 119 | 120 | for k, unpadded_prompt in enumerate(unpadded_prompts): 121 | original_length: int = len(unpadded_prompt) 122 | # From here on out, activations are unpadded, and so must be packaged 123 | # as a _list of tensors_ instead of as just a tensor block. 124 | unpadded_activations.append(activations_block[k, :original_length, :]) 125 | 126 | return unpadded_activations 127 | 128 | 129 | def project_activations( 130 | acts_list: list[t.Tensor], projector: Encoder 131 | ) -> list[t.Tensor]: 132 | """Projects the activations block over to the sparse latent space.""" 133 | projected_activations: list = [] 134 | 135 | for question in acts_list: 136 | proj_question: list = [] 137 | for activation in question: 138 | # Detach the gradients from the decoder model pass. 139 | proj_question.append(projector(activation).detach()) 140 | 141 | question_block = t.stack(proj_question) 142 | projected_activations.append(question_block) 143 | 144 | return projected_activations 145 | 146 | 147 | def rearrange_for_vis(acts_list: list[t.Tensor]) -> list[t.Tensor]: 148 | """`circuitsvis` wants inputs [(stream x layers x embedding_dim)].""" 149 | rearranged_activations: list = [] 150 | for activations in acts_list: 151 | # We need to unsqueeze the middle dimension of the activations, to get 152 | # the singleton layer dimension. 153 | rearranged_activations.append(t.unsqueeze(activations, 1)) 154 | 155 | return rearranged_activations 156 | 157 | 158 | acts_dataset: t.Tensor = t.load(ACTS_DATA_PATH) 159 | 160 | unpadded_acts: t.Tensor = unpad_activations(acts_dataset, prompts_strings) 161 | projected_acts: list[t.Tensor] = project_activations(unpadded_acts, model) 162 | rearranged_acts: list[t.Tensor] = rearrange_for_vis(projected_acts) 163 | 164 | # %% 165 | # Visualize the activations. 166 | assert DISPLAY_QUESTIONS <= len( 167 | prompts_strings 168 | ), "DISPLAY_QUESTIONS must be less than the number of questions." 169 | 170 | html_interactable = text_neuron_activations( 171 | prompts_strings[:DISPLAY_QUESTIONS], 172 | rearranged_acts[:DISPLAY_QUESTIONS], 173 | "Layer", 174 | "Dimension", 175 | [ACTS_LAYER], 176 | ) 177 | 178 | # %% 179 | # Show the visualization. Note that these render better with one sample per 180 | # page. 
181 | html_interactable # pylint: disable=pointless-statement 182 | -------------------------------------------------------------------------------- /sparse_coding/interp_ablations.py: -------------------------------------------------------------------------------- 1 | # %% 2 | """Steer the model with feature dims and observe the resulting completions.""" 3 | 4 | 5 | from contextlib import contextmanager 6 | from typing import Callable, Optional, Tuple 7 | 8 | import accelerate 9 | import numpy as np 10 | import torch as t 11 | import transformers 12 | import yaml 13 | from torch import nn 14 | from transformers import ( 15 | AutoModelForCausalLM, 16 | AutoTokenizer, 17 | GenerationConfig, 18 | PreTrainedModel, 19 | PreTrainedTokenizer, 20 | ) 21 | 22 | # %% 23 | # Set up constants. 24 | ADD_DIM: int = 4226 25 | CHAT_PROMPT: str = "What is going on?" 26 | MAX_NEW_TOKENS: int = 50 27 | NUM_CONTINUATIONS: int = 5 28 | COEFF: int = 1 # For ablations, should always be set to 1. 29 | DO_SAMPLE: bool = True 30 | TEMPERATURE: float = 1.0 31 | TOP_P: float = 0.9 32 | REP_PENALTY: float = 2.0 33 | 34 | with open("act_access.yaml", "r", encoding="utf-8") as f: 35 | try: 36 | access = yaml.safe_load(f) 37 | except yaml.YAMLError as e: 38 | print(e) 39 | with open("act_config.yaml", "r", encoding="utf-8") as f: 40 | try: 41 | config = yaml.safe_load(f) 42 | except yaml.YAMLError as e: 43 | print(e) 44 | HF_ACCESS_TOKEN = access.get("HF_ACCESS_TOKEN", "") 45 | MODEL_DIR = config.get("MODEL_DIR") 46 | ENCODER_PATH = config.get("ENCODER_PATH") 47 | BIASES_PATH = config.get("BIASES_PATH") 48 | SEED = config.get("SEED") 49 | ACTS_LAYER = config.get("ACTS_LAYER") 50 | ACT_NUM: int = ACTS_LAYER # Overridable. 51 | 52 | sampling_kwargs: dict = { 53 | "temperature": TEMPERATURE, 54 | "top_p": TOP_P, 55 | "repetition_penalty": REP_PENALTY, 56 | } 57 | 58 | # Reproducibility. 59 | t.manual_seed(SEED) 60 | np.random.seed(SEED) 61 | 62 | # Set up model. 63 | t.set_grad_enabled(False) 64 | accelerator = accelerate.Accelerator() 65 | model: PreTrainedModel = AutoModelForCausalLM.from_pretrained( 66 | MODEL_DIR, 67 | device_map="auto", 68 | use_auth_token=HF_ACCESS_TOKEN, 69 | ) 70 | tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained( 71 | MODEL_DIR, use_auth_token=HF_ACCESS_TOKEN 72 | ) 73 | model.eval() 74 | model: PreTrainedModel = accelerator.prepare(model) 75 | print(model) 76 | 77 | # %% 78 | # Declare hooking types. 79 | PreHookFn = Callable[[nn.Module, t.Tensor], Optional[t.Tensor]] 80 | Hook = Tuple[nn.Module, PreHookFn] 81 | Hooks = list[Hook] 82 | 83 | 84 | # %% 85 | # Tokenization functionality. 86 | def tokenize(text: str) -> dict[str, t.Tensor]: 87 | """Tokenize prompts onto the appropriate devices.""" 88 | tokens = tokenizer(text, return_tensors="pt") 89 | # I am unsure why automatic acceleration breaks things here. I do it 90 | # manually as a fix. 91 | tokens.to(model.device) 92 | return tokens 93 | 94 | 95 | # %% 96 | # As a control: run the unmodified base model. 97 | base_tokens = model.generate( 98 | **tokenize([CHAT_PROMPT] * NUM_CONTINUATIONS), 99 | generation_config=GenerationConfig( 100 | **sampling_kwargs, 101 | do_sample=DO_SAMPLE, 102 | max_new_tokens=MAX_NEW_TOKENS, 103 | eos_token_id=tokenizer.eos_token_id, 104 | ), 105 | ) 106 | base_strings = [tokenizer.decode(o) for o in base_tokens] 107 | print(("\n" + "." * 80 + "\n").join(base_strings)) 108 | 109 | 110 | # %% 111 | # Hooking functionality. 
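# Background for the helpers below: `register_forward_pre_hook` runs a
# callback just before a module's forward pass and hands it the module's
# positional inputs, so a pre-hook on a decoder block can read (or modify in
# place) the residual stream entering that layer; `pre_hooks` registers a
# batch of such hooks and guarantees they are removed afterwards.
#
# For intuition, a minimal, self-contained sketch of the same mechanism on a
# toy module (hypothetical names; not used by this script):
#
#     import torch
#     lin = torch.nn.Linear(3, 3)
#     def _peek(_module, inputs):
#         print(inputs[0].shape)  # the tensor entering the module
#     handle = lin.register_forward_pre_hook(_peek)
#     lin(torch.zeros(1, 3))  # prints torch.Size([1, 3])
#     handle.remove()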
112 | @contextmanager 113 | def pre_hooks(hooks: Hooks): 114 | """Register pre-forward hooks with torch.""" 115 | handles = [] 116 | try: 117 | handles = [mod.register_forward_pre_hook(hook) for mod, hook in hooks] 118 | yield 119 | finally: 120 | for handle in handles: 121 | handle.remove() 122 | 123 | 124 | def get_blocks(mod): 125 | """Get the blocks of a model.""" 126 | if isinstance(mod, transformers.LlamaForCausalLM): 127 | return mod.model.layers 128 | if isinstance(mod, transformers.GPTNeoXForCausalLM): 129 | return mod.gpt_neox.layers 130 | raise ValueError(f"Unsupported model type: {type(mod)}.") 131 | 132 | 133 | @contextmanager 134 | def residual_stream(mod: PreTrainedModel, layers: Optional[list[int]] = None): 135 | """Actually build hooks for a model.""" 136 | # TODO Plausibly could be replaced by "output_hidden_states=True" in model 137 | # call. 138 | modded_streams = [None] * len(get_blocks(mod)) 139 | 140 | # Factory function that builds the initial hooks. 141 | def _make_helper_hook(i): 142 | def _helper_hook(_, current_inputs): 143 | modded_streams[i] = current_inputs[0] 144 | 145 | return _helper_hook 146 | 147 | hooks = [ 148 | (layer, _make_helper_hook(i)) 149 | for i, layer in enumerate(get_blocks(mod)) 150 | if i in layers 151 | ] 152 | # Register the hooks. 153 | with pre_hooks(hooks): 154 | yield modded_streams 155 | 156 | 157 | def get_resid_pre(prompt: str, layer_num: int): 158 | """Get residual stream activations for a prompt, just before a layer.""" 159 | # TODO: Automatic addition padding. 160 | with residual_stream(model, layers=[layer_num]) as unmodified_streams: 161 | model(**tokenize(prompt)) 162 | return unmodified_streams[layer_num] 163 | 164 | 165 | # %% 166 | # Get the steering vector from the encoder. 167 | encoder_weights = t.load(ENCODER_PATH) 168 | # Remember that the biases have the shape (PROJECTION_DIM,). 169 | encoder_biases = t.load(BIASES_PATH) 170 | raw_steering_vec = encoder_weights[ADD_DIM] 171 | biased_steering_vec = raw_steering_vec + encoder_biases[ADD_DIM] 172 | relued_steering_vec = t.relu(biased_steering_vec) 173 | steering_vec = relued_steering_vec.unsqueeze(0).unsqueeze(0) 174 | 175 | 176 | # %% 177 | # Run the model with the steering vector * COEFF. 178 | def _steering_hook(_, inpt: tuple): 179 | (resid_pre,) = inpt 180 | if resid_pre.shape[1] == 1: 181 | return 182 | ppos, apos = resid_pre.shape[1], steering_vec.shape[1] 183 | assert ( 184 | apos <= ppos 185 | ), f"More modified streams ({apos}) than prompt streams ({ppos})!" 186 | # Now running ablations. 
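    # Unlike steering (which adds a feature direction), ablation subtracts
    # the ReLU-ed, bias-shifted encoder row for ADD_DIM from the first `apos`
    # residual-stream positions, with COEFF fixed at 1, i.e. it removes that
    # learned feature's direction from the prompt's activations.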
187 | resid_pre[:, :apos, :] -= COEFF * steering_vec.to(resid_pre.device) 188 | 189 | 190 | layer = get_blocks(model)[ACT_NUM] 191 | with pre_hooks(hooks=[(layer, _steering_hook)]): 192 | steered_tokens = model.generate( 193 | **tokenize([CHAT_PROMPT] * NUM_CONTINUATIONS), 194 | generation_config=GenerationConfig( 195 | **sampling_kwargs, 196 | do_sample=DO_SAMPLE, 197 | max_new_tokens=MAX_NEW_TOKENS, 198 | eos_token_id=tokenizer.eos_token_id, 199 | ), 200 | ) 201 | steered_strings = [tokenizer.decode(o) for o in steered_tokens] 202 | print(("\n" + "-" * 80 + "\n").join(steered_strings)) 203 | -------------------------------------------------------------------------------- /sparse_coding/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/montemac/activation_additions/cc3178cb813b640cd9644cf656d43a51e28869bd/sparse_coding/utils/__init__.py -------------------------------------------------------------------------------- /sparse_coding/utils/configure.py: -------------------------------------------------------------------------------- 1 | """Functions for loading config YAML files.""" 2 | import yaml 3 | 4 | 5 | def load_yaml_constants(): 6 | """Load config files with get() methods.""" 7 | 8 | try: 9 | with open("act_access.yaml", "r", encoding="utf-8") as f: 10 | access = yaml.safe_load(f) 11 | except FileNotFoundError: 12 | print("act_access.yaml not found. Creating it now.") 13 | with open("act_access.yaml", "w", encoding="utf-8") as w: 14 | w.write('HF_ACCESS_TOKEN: ""\n') 15 | access = {} 16 | except yaml.YAMLError as e: 17 | print(e) 18 | with open("act_config.yaml", "r", encoding="utf-8") as f: 19 | try: 20 | config = yaml.safe_load(f) 21 | except yaml.YAMLError as e: 22 | print(e) 23 | 24 | return access, config 25 | -------------------------------------------------------------------------------- /sparse_coding/utils/top_k.py: -------------------------------------------------------------------------------- 1 | """Functions for processing autoencoders into top-k tokens.""" 2 | 3 | 4 | import textwrap 5 | from collections import defaultdict 6 | from math import ceil 7 | 8 | import torch as t 9 | from accelerate import Accelerator 10 | from transformers import AutoTokenizer 11 | 12 | 13 | # `per_input_token_effects` is a linchpin interpretability function. I break up 14 | # its functionality into several tacit dependency functions in this module so 15 | # that it's readable. 16 | def per_input_token_effects( 17 | token_ids_by_q: list[list[int]], 18 | encoder_activations_by_q: list[t.Tensor], 19 | encoder, 20 | tokenizer: AutoTokenizer, 21 | accelerator: Accelerator, 22 | dims_per_batch: int, 23 | large_model_mode: bool, 24 | ) -> defaultdict[int, defaultdict[str, float]]: 25 | """Return the autoencoder's summed activations, at each feature dimension, 26 | at each input token.""" 27 | 28 | # Begin pre-processing. Calulate the number of dimensional batches to run. 29 | print("Starting pre-processing...") 30 | num_dim_batches: int = batching_setup(dims_per_batch, encoder) 31 | 32 | # Initialize the effects dictionary. 33 | effect_scalar_by_dim_by_input_token = defaultdict(defaultdict_factory) 34 | 35 | # Pre-process `token_ids_by_q`. 
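    # This yields (a) a flat tensor holding every input token id across all
    # questions, row-aligned with the flattened encoder activations used
    # below, and (b) the deduplicated set of token ids to aggregate over.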
36 | flat_input_token_ids, unique_input_token_ids = pre_process_input_token_ids( 37 | token_ids_by_q, encoder, accelerator, large_model_mode 38 | ) 39 | 40 | print("Pre-processing complete!") 41 | effect_scalar_by_dim_by_input_token = batches_loop( 42 | num_dim_batches, 43 | dims_per_batch, 44 | encoder_activations_by_q, 45 | encoder, 46 | accelerator, 47 | tokenizer, 48 | effect_scalar_by_dim_by_input_token, 49 | unique_input_token_ids, 50 | flat_input_token_ids, 51 | large_model_mode, 52 | ) 53 | 54 | return effect_scalar_by_dim_by_input_token 55 | 56 | 57 | # Helper functions for `per_token_effects`. 58 | def modal_tensor_acceleration( 59 | tensor: t.Tensor, encoder, accelerator: Accelerator, large_model_mode: bool 60 | ) -> t.Tensor: 61 | """Accelerate a tensor; manually move it where the accelerator fails.""" 62 | if large_model_mode is False: 63 | tensor = tensor.to(encoder.encoder_layer.weight.device) 64 | tensor = accelerator.prepare(tensor) 65 | 66 | return tensor 67 | 68 | 69 | def batching_setup(dims_per_batch: int, encoder) -> int: 70 | """Determine the number of dimensional batches to be run.""" 71 | num_dim_batches: int = ceil( 72 | encoder.encoder_layer.weight.shape[0] / dims_per_batch 73 | ) 74 | print(f"Total number of batches to be run: {num_dim_batches}") 75 | 76 | return num_dim_batches 77 | 78 | 79 | def defaultdict_factory(): 80 | """Factory for string defaultdicts.""" 81 | return defaultdict(str) 82 | 83 | 84 | def pre_process_input_token_ids( 85 | token_ids_by_q, encoder, accelerator, large_model_mode 86 | ): 87 | """Pre-process the `token_ids_by_q`.""" 88 | 89 | # Flatten the input token ids. 90 | flat_input_token_ids = [ 91 | input_token_id 92 | for question in token_ids_by_q 93 | for input_token_id in question 94 | ] 95 | 96 | # Deduplicate the `flat_input_token_ids`. 97 | unique_input_token_ids = list(set(flat_input_token_ids)) 98 | 99 | # Tensorize and accelerate `flat_input_token_ids`. 100 | flat_input_token_ids = t.tensor(flat_input_token_ids) 101 | flat_input_token_ids = modal_tensor_acceleration( 102 | flat_input_token_ids, encoder, accelerator, large_model_mode 103 | ) 104 | 105 | return flat_input_token_ids, unique_input_token_ids 106 | 107 | 108 | def batches_loop( 109 | num_dim_batches: int, 110 | dims_per_batch: int, 111 | encoder_activations_by_q, 112 | encoder, 113 | accelerator: Accelerator, 114 | tokenizer: AutoTokenizer, 115 | effect_scalar_by_dim_by_input_token, 116 | unique_input_token_ids, 117 | flat_input_token_ids, 118 | large_model_mode: bool, 119 | ) -> defaultdict[int, defaultdict[str, float]]: 120 | """Loop over the batches while printing current progress.""" 121 | 122 | starting_dim_index, ending_dim_index = 0, 0 123 | 124 | for batch in range(num_dim_batches): 125 | print(f"Starting batch {batch+1} of {num_dim_batches}...") 126 | 127 | ending_dim_index += dims_per_batch 128 | if ending_dim_index > encoder.encoder_layer.weight.shape[0]: 129 | ending_dim_index = encoder.encoder_layer.weight.shape[0] 130 | 131 | if batch + 1 > num_dim_batches: 132 | assert starting_dim_index - ending_dim_index == dims_per_batch 133 | elif batch + 1 == num_dim_batches: 134 | assert starting_dim_index - ending_dim_index <= dims_per_batch 135 | 136 | # Note that `batched_dims_from_encoder_activations` has 137 | # lost the question data that `encoder_activations_by_q` had. 
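        # After concatenation its rows are per-token activations for just
        # this batch's dimension slice, in the same order as
        # `flat_input_token_ids`, which is what the per-token filtering below
        # relies on.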
138 | batched_dims_from_encoder_activations = ( 139 | pre_process_encoder_activations_by_batch( 140 | encoder_activations_by_q, 141 | dims_per_batch, 142 | encoder, 143 | accelerator, 144 | starting_dim_index, 145 | ending_dim_index, 146 | large_model_mode, 147 | ) 148 | ) 149 | 150 | assert not t.isnan(batched_dims_from_encoder_activations).any() 151 | 152 | for input_token_id in unique_input_token_ids: 153 | input_token_string = tokenizer.convert_ids_to_tokens( 154 | input_token_id 155 | ) 156 | dims_from_encoder_activations_at_input_token_in_batch = ( 157 | filter_encoder_activations_by_input_token( 158 | flat_input_token_ids, 159 | input_token_id, 160 | batched_dims_from_encoder_activations, 161 | ) 162 | ) 163 | averaged_dim_from_encoder_activations_at_input_token_in_batch = ( 164 | average_encoder_activations_at_input_token( 165 | dims_from_encoder_activations_at_input_token_in_batch, 166 | ) 167 | ) 168 | 169 | # Add the averaged activations on to the effects dictionary. 170 | for dim_in_batch, averaged_activation_per_dim in enumerate( 171 | averaged_dim_from_encoder_activations_at_input_token_in_batch 172 | ): 173 | effect_scalar_by_dim_by_input_token[ 174 | starting_dim_index + dim_in_batch 175 | ][input_token_string] = averaged_activation_per_dim.item() 176 | 177 | print( 178 | textwrap.dedent( 179 | f""" 180 | Batch {batch+1} complete: data for encoder dims indices 181 | {starting_dim_index} through {ending_dim_index-1} appended! 182 | """ 183 | ) 184 | ) 185 | 186 | # Update `starting_dim_index` for the next batch. 187 | starting_dim_index = ending_dim_index 188 | 189 | return effect_scalar_by_dim_by_input_token 190 | 191 | 192 | def pre_process_encoder_activations_by_batch( 193 | encoder_activations_by_q, 194 | dims_per_batch, 195 | encoder, 196 | accelerator, 197 | starting_dim_index, 198 | ending_dim_index, 199 | large_model_mode, 200 | ) -> t.Tensor: 201 | """Pre-process the `encoder_activations_by_q` for each batch.""" 202 | batched_dims_from_encoder_activations: list = [] 203 | 204 | for question_block in encoder_activations_by_q: 205 | batched_dims_from_encoder_activations.append( 206 | question_block[:, starting_dim_index:ending_dim_index] 207 | ) 208 | 209 | batched_dims_from_encoder_activations = accelerator.prepare( 210 | batched_dims_from_encoder_activations 211 | ) 212 | # Remove the question data. 213 | batched_dims_from_encoder_activations: t.Tensor = t.cat( 214 | batched_dims_from_encoder_activations, dim=0 215 | ) 216 | 217 | assert batched_dims_from_encoder_activations.shape[1] <= dims_per_batch 218 | 219 | # Accelerate `batched_dims_from_encoder_activations`. 220 | batched_dims_from_encoder_activations = modal_tensor_acceleration( 221 | batched_dims_from_encoder_activations, 222 | encoder, 223 | accelerator, 224 | large_model_mode, 225 | ) 226 | 227 | return batched_dims_from_encoder_activations 228 | 229 | 230 | # Remember that dimensional batch slicing is already done coming in. 231 | def filter_encoder_activations_by_input_token( 232 | flat_input_token_ids: t.Tensor, 233 | input_token_id: int, 234 | batched_dims_from_encoder_activations: t.Tensor, 235 | ): 236 | """Isolate just the activations at an input token id.""" 237 | indices_of_encoder_activations_at_input_token = t.nonzero( 238 | flat_input_token_ids == input_token_id 239 | ) 240 | flat_indices_of_encoder_activations_at_input_token = ( 241 | indices_of_encoder_activations_at_input_token.squeeze(dim=1) 242 | ) 243 | 244 | # Fancy index along dim=0. 
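    # Indexing with the matched row indices keeps only the token positions
    # whose input id equals `input_token_id`, giving a
    # (num_occurrences, dims_in_batch) block that gets averaged next.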
245 | dims_from_encoder_activations_at_input_token_in_batch = ( 246 | batched_dims_from_encoder_activations[ 247 | flat_indices_of_encoder_activations_at_input_token 248 | ] 249 | ) 250 | 251 | return dims_from_encoder_activations_at_input_token_in_batch 252 | 253 | 254 | def average_encoder_activations_at_input_token( 255 | dims_from_encoder_activations_at_input_token_in_batch, 256 | ): 257 | """Average over encoder activations at a common input token.""" 258 | 259 | # Average across dimensional instances. 260 | averaged_dim_from_encoder_activations_at_input_token_in_batch = t.mean( 261 | dims_from_encoder_activations_at_input_token_in_batch, dim=0 262 | ) 263 | 264 | assert ( 265 | len( 266 | averaged_dim_from_encoder_activations_at_input_token_in_batch.shape 267 | ) 268 | == 1 269 | ), "Tensor has more than one dimension! It should be a vector." 270 | 271 | assert not t.isnan( 272 | averaged_dim_from_encoder_activations_at_input_token_in_batch 273 | ).any(), "Processed tensor contains NaNs!" 274 | 275 | return averaged_dim_from_encoder_activations_at_input_token_in_batch 276 | 277 | 278 | # All other `top-k` functions below. 279 | def project_activations( 280 | acts_list: list[t.Tensor], 281 | projector, 282 | accelerator: Accelerator, 283 | ) -> list[t.Tensor]: 284 | """Projects the activations block over to the sparse latent space.""" 285 | 286 | # Remember the original question lengths. 287 | lengths: list[int] = [len(question) for question in acts_list] 288 | 289 | flat_acts: t.Tensor = t.cat(acts_list, dim=0) 290 | flat_acts: t.Tensor = accelerator.prepare(flat_acts) 291 | projected_flat_acts: t.Tensor = projector(flat_acts).detach() 292 | 293 | # Reconstruct the original question lengths. 294 | projected_activations: list[t.Tensor] = [] 295 | current_idx: int = 0 296 | for length in lengths: 297 | projected_activations.append( 298 | projected_flat_acts[current_idx : current_idx + length, :] 299 | ) 300 | current_idx += length 301 | 302 | return projected_activations 303 | 304 | 305 | def select_top_k_tokens( 306 | effects_dict: defaultdict[int, defaultdict[str, float]], 307 | top_k: int, 308 | ) -> defaultdict[int, list[tuple[str, float]]]: 309 | """Select the top-k tokens for each feature.""" 310 | tp_k_tokens = defaultdict(list) 311 | 312 | for feature_dim, tokens_dict in effects_dict.items(): 313 | # Sort tokens by their dimension activations. 314 | sorted_effects: list[tuple[str, float]] = sorted( 315 | tokens_dict.items(), key=lambda x: x[1], reverse=True 316 | ) 317 | # Add the top-k tokens. 318 | tp_k_tokens[feature_dim] = sorted_effects[:top_k] 319 | 320 | return tp_k_tokens 321 | 322 | 323 | def unpad_activations( 324 | activations_block: t.Tensor, unpadded_prompts: list[list[int]] 325 | ) -> list[t.Tensor]: 326 | """ 327 | Unpads activations to the lengths specified by the original prompts. 328 | 329 | Note that the activation block must come in with dimensions (batch x stream 330 | x embedding_dim), and the unpadded prompts as an array of lists of 331 | elements. 332 | """ 333 | unpadded_activations: list = [] 334 | 335 | for k, unpadded_prompt in enumerate(unpadded_prompts): 336 | try: 337 | original_length: int = len(unpadded_prompt) 338 | # From here on out, activations are unpadded, and so must be 339 | # packaged as a _list of tensors_ instead of as just a tensor 340 | # block. 
341 | unpadded_activations.append( 342 | activations_block[k, :original_length, :] 343 | ) 344 | except IndexError: 345 | print(f"IndexError at {k}") 346 | # This should only occur when the data collection was interrupted. 347 | # In that case, we just break when the data runs short. 348 | break 349 | 350 | return unpadded_activations 351 | -------------------------------------------------------------------------------- /tests/smoke_test_access.yaml: -------------------------------------------------------------------------------- 1 | HF_ACCESS_TOKEN: "" 2 | -------------------------------------------------------------------------------- /tests/smoke_test_config.yaml: -------------------------------------------------------------------------------- 1 | # Smoke Test Model 2 | MODEL_DIR: "EleutherAI/pythia-70m" 3 | ACTS_LAYER: 1 4 | 5 | # Large Model Mode 6 | LARGE_MODEL_MODE: False 7 | 8 | # Encoder Size 9 | PROJECTION_FACTOR: 1 10 | 11 | # Smoke Test Save Paths 12 | PROMPT_IDS_PATH: "smoke_test_data/smoke_test_activations_prompt_ids.npy" 13 | ACTS_DATA_PATH: "smoke_test_data/smoke_test_activations_dataset.pt" 14 | ENCODER_PATH: "smoke_test_data/smoke_test_learned_encoder.pt" 15 | BIASES_PATH: "smoke_test_data/smoke_test_learned_biases.pt" 16 | TOP_K_INFO_PATH: "smoke_test_data/smoke_test_token_info.csv" 17 | 18 | # Autoencoder Training 19 | LAMBDA_L1: 1.0e-4 20 | LEARNING_RATE: 1.0e-4 21 | NUM_WORKERS: 0 22 | 23 | # Reproducibility 24 | SEED: 0 25 | 26 | # Smoke Test Constants (`acts_collect.py`) 27 | MAX_NEW_TOKENS: 1 28 | NUM_RETURN_SEQUENCES: 1 29 | NUM_SHOT: 6 30 | NUM_QUESTIONS_EVALED: 10 31 | 32 | # Smoke Test Constants (`autoencoder.py`) 33 | LOG_EVERY_N_STEPS: 5 34 | EPOCHS: 1 35 | SYNC_DIST_LOGGING: True 36 | 37 | # Smoke Test Constants (`feature_tokens.py`) 38 | # _Leave out entries_ for None: None values will be interpreted as "None" 39 | # strings. 
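# For example (hypothetical key): to pass None for an optional constant such
# as a MAX_SEQ_LEN cap, omit the `MAX_SEQ_LEN:` line entirely rather than
# writing `MAX_SEQ_LEN: None`, which would be read back as the string "None".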
40 | TOP_K: 6
41 | DIMS_IN_BATCH: 200
42 | 
--------------------------------------------------------------------------------
/tests/smoke_test_data/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/montemac/activation_additions/cc3178cb813b640cd9644cf656d43a51e28869bd/tests/smoke_test_data/.gitkeep
--------------------------------------------------------------------------------
/tests/sweep_over_prompts_cache.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/montemac/activation_additions/cc3178cb813b640cd9644cf656d43a51e28869bd/tests/sweep_over_prompts_cache.pkl
--------------------------------------------------------------------------------
/tests/test_experiments.py:
--------------------------------------------------------------------------------
1 | # %%
2 | """Test suite for experiments.py"""
3 | import pytest
4 | 
5 | import torch
6 | 
7 | from transformer_lens import HookedTransformer
8 | 
9 | from activation_additions import utils, experiments
10 | 
11 | utils.enable_ipython_reload()
12 | 
13 | 
14 | @pytest.fixture(name="model")
15 | def fixture_model() -> HookedTransformer:
16 |     """Test fixture that returns a small pre-trained transformer used
17 |     for fast experiments testing."""
18 |     return HookedTransformer.from_pretrained(
19 |         model_name="attn-only-2l", device="cpu"
20 |     )
21 | 
22 | 
23 | def test_get_stats_over_corpus(model):
24 |     """Test get_stats_over_corpus() function."""
25 |     avg_logprob, perplexity, logprobs = experiments.get_stats_over_corpus(
26 |         model=model, corpus_texts=["This is a test sentence."]
27 |     )
28 |     assert avg_logprob == pytest.approx(-5.2312, abs=1e-4)
29 |     assert perplexity == pytest.approx(187.008480, abs=1e-4)
30 |     assert torch.allclose(
31 |         logprobs,
32 |         torch.tensor([-7.7388, -1.4853, -1.3223, -6.7956, -12.2510, -1.7938]),
33 |         atol=1e-4,
34 |     )
35 |     avg_logprob_mask_len, _, _ = experiments.get_stats_over_corpus(
36 |         model=model,
37 |         corpus_texts=["This is a test sentence."],
38 |         mask_len=2,
39 |     )
40 |     assert avg_logprob_mask_len == pytest.approx(logprobs[2:].mean(), abs=1e-4)
41 | 
--------------------------------------------------------------------------------
/tests/test_hook_utils.py:
--------------------------------------------------------------------------------
1 | """ Tests for the `hook_utils` module"""
2 | 
3 | from typing import Callable, List
4 | import torch
5 | import pytest
6 | 
7 | from transformer_lens.HookedTransformer import HookedTransformer
8 | 
9 | from activation_additions import hook_utils, prompt_utils
10 | from activation_additions.prompt_utils import ActivationAddition
11 | 
12 | 
13 | # Fixtures
14 | @pytest.fixture(name="attn_2l_model", scope="module")
15 | def fixture_model() -> HookedTransformer:
16 |     """Test fixture that returns a small pre-trained transformer."""
17 |     return HookedTransformer.from_pretrained(
18 |         model_name="attn-only-2l", device="cpu"
19 |     )
20 | 
21 | 
22 | def test_hook_fn_from_slice():
23 |     """Test that we can selectively modify a portion of the residual stream."""
24 |     input_tensor: torch.Tensor = torch.zeros((1, 2, 4))
25 |     activations: torch.Tensor = 2 * torch.ones((1, 2, 4))
26 | 
27 |     # Modify these parts of the residual stream
28 |     residual_stream_slice: slice = slice(1, 3)  # from 1 to 3 (exclusive)
29 | 
30 |     hook_fn: Callable = hook_utils.hook_fn_from_activations(
31 |         activations=activations, res_stream_slice=residual_stream_slice
32 |     )
33 | 
34 |     target_tens: torch.Tensor = 
torch.tensor([[[0, 2, 2, 0], [0, 2, 2, 0]]]) 35 | result_tens: torch.Tensor = hook_fn(input_tensor) 36 | 37 | assert torch.eq(result_tens, target_tens).all(), "Slice test failed" 38 | 39 | 40 | def test_hook_fn_from_activations(): 41 | """Testing the front and back modifiers of the addition_location setting.""" 42 | input_tensor: torch.Tensor = torch.ones((1, 10, 1)) 43 | activations: torch.Tensor = 2 * torch.ones((1, 4, 1)) 44 | 45 | back_target: torch.Tensor = torch.tensor([[1, 1, 1, 1, 1, 1, 3, 3, 3, 3]]) 46 | back_target: torch.Tensor = back_target.unsqueeze(0).unsqueeze(-1) 47 | 48 | hook_fxn: Callable = hook_utils.hook_fn_from_activations( 49 | activations=activations, addition_location="back" 50 | ) 51 | result: torch.Tensor = hook_fxn(input_tensor) 52 | 53 | assert torch.eq(result, back_target).all() 54 | 55 | # this needs to be repeated because it did replacements in-place and the tensor is now modified 56 | input_tensor: torch.Tensor = torch.ones((1, 10, 1)) 57 | activations: torch.Tensor = 2 * torch.ones((1, 4, 1)) 58 | 59 | front_target: torch.Tensor = torch.tensor([[3, 3, 3, 3, 1, 1, 1, 1, 1, 1]]) 60 | front_target: torch.Tensor = front_target.unsqueeze(0).unsqueeze(-1) 61 | 62 | hook_fxn: Callable = hook_utils.hook_fn_from_activations( 63 | activations=activations, addition_location="front" 64 | ) 65 | result: torch.Tensor = hook_fxn(input_tensor) 66 | 67 | assert torch.eq(result, front_target).all() 68 | 69 | 70 | def test_hook_fn_from_activations_mid_even(): 71 | """Testing the mid modifiers of the addition_location setting.""" 72 | input_tensor: torch.Tensor = torch.ones((1, 10, 1)) 73 | activations: torch.Tensor = 2 * torch.ones((1, 4, 1)) 74 | 75 | mid_target: torch.Tensor = torch.tensor([[1, 1, 1, 3, 3, 3, 3, 1, 1, 1]]) 76 | mid_target: torch.Tensor = mid_target.unsqueeze(0).unsqueeze(-1) 77 | 78 | hook_fxn: Callable = hook_utils.hook_fn_from_activations( 79 | activations=activations, addition_location="mid" 80 | ) 81 | result: torch.Tensor = hook_fxn(input_tensor) 82 | 83 | assert torch.eq(result, mid_target).all() 84 | 85 | 86 | def test_hook_fn_from_activations_mid_odd_in(): 87 | """Testing the mid modifiers of the addition_location setting.""" 88 | input_tensor: torch.Tensor = torch.ones((1, 9, 1)) 89 | activations: torch.Tensor = 2 * torch.ones((1, 4, 1)) 90 | 91 | mid_target: torch.Tensor = torch.tensor([[1, 1, 3, 3, 3, 3, 1, 1, 1]]) 92 | mid_target: torch.Tensor = mid_target.unsqueeze(0).unsqueeze(-1) 93 | 94 | hook_fxn: Callable = hook_utils.hook_fn_from_activations( 95 | activations=activations, addition_location="mid" 96 | ) 97 | result: torch.Tensor = hook_fxn(input_tensor) 98 | 99 | assert torch.eq(result, mid_target).all() 100 | 101 | 102 | def test_hook_fn_from_activations_mid_odd_act(): 103 | """Testing the mid modifiers of the addition_location setting.""" 104 | input_tensor: torch.Tensor = torch.ones((1, 10, 1)) 105 | activations: torch.Tensor = 2 * torch.ones((1, 3, 1)) 106 | 107 | mid_target: torch.Tensor = torch.tensor([[1, 1, 1, 1, 3, 3, 3, 1, 1, 1]]) 108 | mid_target: torch.Tensor = mid_target.unsqueeze(0).unsqueeze(-1) 109 | 110 | hook_fxn: Callable = hook_utils.hook_fn_from_activations( 111 | activations=activations, addition_location="mid" 112 | ) 113 | result: torch.Tensor = hook_fxn(input_tensor) 114 | 115 | assert torch.eq(result, mid_target).all() 116 | 117 | 118 | def test_hook_fn_from_activations_mid_both_odd(): 119 | """Testing the mid modifiers of the addition_location setting.""" 120 | input_tensor: torch.Tensor = torch.ones((1, 9, 1)) 
121 | activations: torch.Tensor = 2 * torch.ones((1, 3, 1)) 122 | 123 | mid_target: torch.Tensor = torch.tensor([[1, 1, 1, 3, 3, 3, 1, 1, 1]]) 124 | mid_target: torch.Tensor = mid_target.unsqueeze(0).unsqueeze(-1) 125 | 126 | hook_fxn: Callable = hook_utils.hook_fn_from_activations( 127 | activations=activations, addition_location="mid" 128 | ) 129 | result: torch.Tensor = hook_fxn(input_tensor) 130 | 131 | assert torch.eq(result, mid_target).all() 132 | 133 | 134 | def test_magnitudes_zeros(attn_2l_model): 135 | """Test that the magnitudes of a coeff-zero ActivationAddition are zero.""" 136 | # Create a ActivationAddition with all zeros 137 | act_add = ActivationAddition(prompt="Test", coeff=0, act_name=0) 138 | 139 | # Get the magnitudes 140 | magnitudes: torch.Tensor = hook_utils.steering_vec_magnitudes( 141 | act_adds=[act_add], model=attn_2l_model 142 | ) 143 | 144 | # Check that they're all zero 145 | assert torch.all(magnitudes == 0), "Magnitudes are not all zero" 146 | assert len(magnitudes.shape) == 1, "Magnitudes are not the right shape" 147 | 148 | 149 | def test_magnitudes_cancels(attn_2l_model): 150 | """Test that the magnitudes are zero when the ActivationAdditions are exact opposites.""" 151 | # Create a ActivationAddition with all zeros 152 | additions: List[ActivationAddition] = [ 153 | ActivationAddition(prompt="Test", coeff=1, act_name=0), 154 | ActivationAddition(prompt="Test", coeff=-1, act_name=0), 155 | ] 156 | 157 | # Get the magnitudes 158 | magnitudes: torch.Tensor = hook_utils.steering_vec_magnitudes( 159 | act_adds=additions, model=attn_2l_model 160 | ) 161 | 162 | # Check that they're all zero 163 | assert torch.all(magnitudes == 0), "Magnitudes are not all zero" 164 | 165 | 166 | def test_multi_layers_not_allowed(attn_2l_model): 167 | """Try injecting a ActivationAddition with multiple layers, which should 168 | fail.""" 169 | additions: List[ActivationAddition] = [ 170 | ActivationAddition(prompt="Test", coeff=1, act_name=0), 171 | ActivationAddition(prompt="Test", coeff=1, act_name=1), 172 | ] 173 | 174 | with pytest.raises(NotImplementedError): 175 | hook_utils.steering_vec_magnitudes( 176 | act_adds=additions, model=attn_2l_model 177 | ) 178 | 179 | 180 | def test_multi_same_layer(attn_2l_model): 181 | """Try injecting a ActivationAddition with multiple additions to the same 182 | layer, which should succeed, even if the injections have different 183 | tokenization lengths.""" 184 | additions_same: List[ActivationAddition] = [ 185 | ActivationAddition(prompt="Test", coeff=1, act_name=0), 186 | ActivationAddition(prompt="Test2521", coeff=1, act_name=0), 187 | ] 188 | 189 | magnitudes: torch.Tensor = hook_utils.steering_vec_magnitudes( 190 | act_adds=additions_same, model=attn_2l_model 191 | ) 192 | assert len(magnitudes.shape) == 1, "Magnitudes are not the right shape" 193 | # Assert not all zeros 194 | assert torch.any(magnitudes != 0), "Magnitudes are all zero?" 
195 | 196 | 197 | def test_prompt_magnitudes(attn_2l_model): 198 | """Test that the magnitudes of a prompt are not zero.""" 199 | # Create a ActivationAddition with all zeros 200 | act_add = ActivationAddition(prompt="Test", coeff=1, act_name=0) 201 | 202 | # Get the steering vector magnitudes 203 | steering_magnitudes: torch.Tensor = hook_utils.steering_vec_magnitudes( 204 | act_adds=[act_add], model=attn_2l_model 205 | ) 206 | prompt_magnitudes: torch.Tensor = hook_utils.prompt_magnitudes( 207 | prompt="Test", 208 | model=attn_2l_model, 209 | act_name=prompt_utils.get_block_name(block_num=0), 210 | ) 211 | 212 | # Check that these magnitudes are equal 213 | assert torch.allclose( 214 | steering_magnitudes, prompt_magnitudes 215 | ), "Magnitudes are not equal" 216 | assert ( 217 | len(prompt_magnitudes.shape) == 1 218 | ), "Prompt magnitudes are not the right shape" 219 | 220 | 221 | def test_relative_mags_ones(attn_2l_model): 222 | """Test whether the relative magnitudes are one for a prompt and 223 | its own ActivationAddition.""" 224 | act_add = ActivationAddition(prompt="Test", coeff=1, act_name=0) 225 | rel_mags: torch.Tensor = hook_utils.steering_magnitudes_relative_to_prompt( 226 | prompt="Test", 227 | model=attn_2l_model, 228 | act_adds=[act_add], 229 | ) 230 | 231 | # Assert these are all 1s 232 | assert torch.allclose( 233 | rel_mags, torch.ones_like(rel_mags) 234 | ), "Relative mags not 1" 235 | assert ( 236 | len(rel_mags.shape) == 1 237 | ), "Relative mags should only have the sequence dim" 238 | 239 | 240 | def test_relative_mags_diff_shape(attn_2l_model): 241 | """Test that a long prompt and a short ActivationAddition can be compared, 242 | and vice versa.""" 243 | long_add = ActivationAddition( 244 | prompt="Test2521531lk dsa ;las", coeff=1, act_name=0 245 | ) 246 | short_add = ActivationAddition(prompt="Test", coeff=1, act_name=0) 247 | long_prompt: str = "Test2521531lk dsa ;las" 248 | short_prompt: str = "Test" 249 | 250 | # Get the relative magnitudes 251 | for add, prompt in zip([long_add, short_add], [short_prompt, long_prompt]): 252 | assert len(add.prompt) != len( 253 | prompt 254 | ), "Prompt and ActivationAddition are the same length" 255 | _ = hook_utils.steering_magnitudes_relative_to_prompt( 256 | prompt=prompt, 257 | model=attn_2l_model, 258 | act_adds=[add], 259 | ) 260 | -------------------------------------------------------------------------------- /tests/test_lenses.py: -------------------------------------------------------------------------------- 1 | # %% 2 | """ Test suite for lenses.py """ 3 | 4 | import pytest 5 | import torch 6 | 7 | from transformer_lens import HookedTransformer 8 | from tuned_lens import TunedLens 9 | from transformers import AutoModelForCausalLM 10 | 11 | from activation_additions.prompt_utils import get_x_vector 12 | from activation_additions import lenses, utils 13 | 14 | utils.enable_ipython_reload() 15 | 16 | # smallest tuned lens supported model 17 | MODEL = "EleutherAI/pythia-70m-deduped" 18 | 19 | 20 | @pytest.fixture(name="model") 21 | def fixture_model() -> HookedTransformer: 22 | """Test fixture that returns a small pre-trained transformer used 23 | for fast logging testing.""" 24 | 25 | torch.set_grad_enabled(False) 26 | hf_model = AutoModelForCausalLM.from_pretrained(MODEL) 27 | model = HookedTransformer.from_pretrained( 28 | model_name=MODEL, hf_model=hf_model, device="cpu" 29 | ) 30 | model.hf_model = hf_model 31 | model.eval() 32 | return model 33 | 34 | 35 | @pytest.fixture(name="tuned_lens") 36 | def 
fixture_tuned_lens(model: HookedTransformer) -> lenses.TunedLens:
37 |     """Test fixture that returns a pre-trained tuned lens for the small
38 |     test model, used for fast lens testing."""
39 |     return TunedLens.from_model_and_pretrained(
40 |         model.hf_model, lens_resource_id=MODEL, map_location="cpu"  # type: ignore
41 |     ).to("cpu")
42 | 
43 | 
44 | def test_lenses(model, tuned_lens):
45 |     """
46 |     Checks that no exceptions are raised when using lenses as intended.
47 |     """
48 | 
49 |     prompt = "I hate you because"
50 | 
51 |     activation_additions = [
52 |         *get_x_vector(
53 |             prompt1="Love",
54 |             prompt2="Hate",
55 |             coeff=5,
56 |             act_name=2,
57 |             pad_method="tokens_right",
58 |             model=model,
59 |             custom_pad_id=model.to_single_token(" "),
60 |         )
61 |     ]
62 | 
63 |     dataframes, caches = lenses.run_hooked_and_normal_with_cache(
64 |         model=model,
65 |         activation_additions=activation_additions,
66 |         gen_args={"prompt_batch": [prompt] * 1, "seed": 0},
67 |     )
68 | 
69 |     _ = lenses.prediction_trajectories(
70 |         caches, dataframes, model.tokenizer, tuned_lens
71 |     )
72 | 
--------------------------------------------------------------------------------
/tests/test_logging.py:
--------------------------------------------------------------------------------
1 | # %%
2 | """ Test suite for logging.py """
3 | 
4 | import pytest
5 | 
6 | import pandas as pd
7 | from transformer_lens import HookedTransformer
8 | 
9 | from activation_additions import (
10 |     logging,
11 |     completion_utils,
12 |     prompt_utils,
13 |     utils,
14 | )
15 | 
16 | utils.enable_ipython_reload()
17 | 
18 | 
19 | @pytest.fixture(name="model")
20 | def fixture_model() -> HookedTransformer:
21 |     """Test fixture that returns a small pre-trained transformer used
22 |     for fast logging testing."""
23 |     return HookedTransformer.from_pretrained(
24 |         model_name="attn-only-2l", device="cpu"
25 |     )
26 | 
27 | 
28 | # In order for these tests to work, you must have a wandb account and
29 | # have set up your wandb API key. See https://docs.wandb.ai/quickstart
30 | def test_logging(model):
31 |     """Tests a sweep over prompts with logging enabled. Verifies that
32 |     the correct data is uploaded to a new wandb run."""
33 |     # TODO: do this properly with pytest config
34 |     pytest.skip("Logging testing is slow! 
Change this line to enable it.") 35 | # Perform a completion test 36 | results: pd.DataFrame = completion_utils.gen_using_activation_additions( 37 | model=model, 38 | activation_additions=[ 39 | prompt_utils.ActivationAddition( 40 | prompt="Love", 41 | act_name=prompt_utils.get_block_name(block_num=0), 42 | coeff=1.0, 43 | ), 44 | prompt_utils.ActivationAddition( 45 | prompt="Fear", 46 | act_name=prompt_utils.get_block_name(block_num=0), 47 | coeff=-1.0, 48 | ), 49 | ], 50 | prompt_batch=["This is a test", "Let's talk about"], 51 | log={"tags": ["test"], "notes": "testing"}, 52 | ) 53 | # Download the artifact data and convert to a DataFrame 54 | results_logged = logging.get_objects_from_run( 55 | logging.last_run_info["path"], 56 | )["gen_using_activation_additions"] 57 | 58 | print(results, results_logged) 59 | 60 | # Compare with the reference DataFrame 61 | pd.testing.assert_frame_equal(results, results_logged) 62 | 63 | 64 | def test_positional_args(model): 65 | """Function test a call to a loggable function using positional 66 | arguments, which were initially not supported by the loggable 67 | decorator.""" 68 | completion_utils.print_n_comparisons( 69 | "I think you're ", 70 | model, 71 | num_comparisons=5, 72 | activation_additions=[], 73 | seed=0, 74 | ) 75 | -------------------------------------------------------------------------------- /tests/test_logits.py: -------------------------------------------------------------------------------- 1 | # %% 2 | """Test suite for logits.py""" 3 | import pytest 4 | 5 | import pandas as pd 6 | import numpy as np 7 | 8 | from transformer_lens import HookedTransformer 9 | 10 | from activation_additions import utils, logits 11 | 12 | utils.enable_ipython_reload() 13 | 14 | 15 | @pytest.fixture(name="model") 16 | def fixture_model() -> HookedTransformer: 17 | """Test fixture that returns a small pre-trained transformer used 18 | for fast metric testing.""" 19 | return HookedTransformer.from_pretrained( 20 | model_name="attn-only-2l", device="cpu" 21 | ) 22 | 23 | 24 | def test_get_token_probs(model): 25 | """Test get_token_probs() function.""" 26 | probs = logits.get_token_probs(model, "My name is") 27 | assert isinstance(probs, pd.DataFrame) 28 | assert probs.shape == (4, 96524) 29 | assert probs.columns.levels[0].to_list() == ["probs", "logprobs"] # type: ignore 30 | alice_token = int(model.to_single_token(" Alice")) 31 | bob_token = int(model.to_single_token(" Bob")) 32 | k_token = int(model.to_single_token(" K")) 33 | assert np.allclose(probs.iloc[-1].loc[("probs", alice_token)], 0.000495836) 34 | assert np.allclose(probs.iloc[-1].loc[("probs", bob_token)], 0.002270653) 35 | assert np.allclose(probs.iloc[-1].loc[("probs", k_token)], 0.01655953) 36 | 37 | 38 | # TODO: write more tests! 
39 | -------------------------------------------------------------------------------- /tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | """Test suite for metrics.py""" 2 | from typing import Callable, List 3 | import pytest 4 | 5 | import pandas as pd 6 | import openai 7 | 8 | from transformer_lens import HookedTransformer 9 | 10 | from activation_additions import metrics, completion_utils, utils 11 | 12 | utils.enable_ipython_reload() 13 | 14 | 15 | @pytest.fixture(name="model") 16 | def fixture_model() -> HookedTransformer: 17 | """Test fixture that returns a small pre-trained transformer used 18 | for fast metric testing.""" 19 | return HookedTransformer.from_pretrained( 20 | model_name="attn-only-2l", device="cpu" 21 | ) 22 | 23 | 24 | def test_get_sentiment_metric(): 25 | """Test for get_sentiment_metric(). Creates a sentiment metric, 26 | applies it to some strings, and checks the results against 27 | pre-defined constants.""" 28 | metric: Callable = metrics.get_sentiment_metric( 29 | "distilbert-base-uncased-finetuned-sst-2-english", ["POSITIVE"] 30 | ) 31 | prompts: List[str] = [ 32 | "I love chocolate", 33 | "I hate chocolate", 34 | ] 35 | results: pd.DataFrame = metric(prompts, False, pd.Index(["a", "b"])) 36 | target = pd.DataFrame( 37 | { 38 | "label": ["POSITIVE", "NEGATIVE"], 39 | "score": [0.999846, 0.998404], 40 | "is_positive": [True, False], 41 | }, 42 | index=["a", "b"], 43 | ) 44 | pd.testing.assert_frame_equal(results, target) 45 | 46 | 47 | def test_get_word_count_metric(): 48 | """Test for get_sentiment_metric(). Creates a word count metric, 49 | applies it to some strings, and checks the results against 50 | pre-defined constants.""" 51 | metric: Callable = metrics.get_word_count_metric( 52 | ["dog", "dogs", "puppy", "puppies"] 53 | ) 54 | prompts: List[str] = [ 55 | "Dogs and puppies are the best!", 56 | "Look at that cute dog with a puppy over there.", 57 | ] 58 | results: pd.DataFrame = metric(prompts, False, None) 59 | target = pd.DataFrame( 60 | {"count": [2, 2]}, 61 | ) 62 | pd.testing.assert_frame_equal(results, target) 63 | 64 | 65 | def test_openai_metric(): 66 | """Test for get_openai_metric(). Creates an OpenAI metric, applies 67 | it to some strings, and checks the results against pre-defined 68 | constants.""" 69 | if openai.api_key is None: 70 | pytest.skip("OpenAI API key not found.") 71 | 72 | metric: Callable = metrics.get_openai_metric("text-davinci-003", "happy") 73 | prompts: List[str] = ["I love chocolate!", "I hate chocolate!"] 74 | results: pd.DataFrame = metric(prompts, False, None) 75 | target = pd.DataFrame( 76 | { 77 | "rating": [5, 1], 78 | "reasoning": [ 79 | "This text is very happy because it expresses" 80 | + " a strong positive emotion towards something.", 81 | "This text is not very happy because it expresses" 82 | + " a negative sentiment towards chocolate.", 83 | ], 84 | }, 85 | index=prompts, 86 | ) 87 | pd.testing.assert_frame_equal(results, target) 88 | 89 | 90 | def test_openai_metric_bulk(): 91 | """Test for get_openai_metric(). Creates an OpenAI metric, applies it to >20 strings, 92 | and makes sure it doesn't error (20 is the limit for one OAI request)""" 93 | if openai.api_key is None: 94 | pytest.skip("OpenAI API key not found.") 95 | 96 | metric: Callable = metrics.get_openai_metric("text-davinci-003", "happy") 97 | metric([""] * 21, False, None) # The test is that this doesn't error! 
98 | 99 | 100 | def test_add_metric_cols(model): 101 | """Test for add_metric_cols(). Creates two metrics, applies them to 102 | several strings with the function under tests, then tests that the 103 | resulting DataFrame matches a pre-defined constant.""" 104 | metrics_dict = { 105 | "sentiment1": metrics.get_sentiment_metric( 106 | "distilbert-base-uncased-finetuned-sst-2-english", ["POSITIVE"] 107 | ), 108 | "sentiment2": metrics.get_sentiment_metric( 109 | "cardiffnlp/twitter-roberta-base-sentiment", ["LABEL_2"] 110 | ), 111 | } 112 | results_df: pd.DataFrame = completion_utils.gen_using_hooks( 113 | model=model, 114 | prompt_batch=["I love chocolate", "I hate chocolate"], 115 | hook_fns={}, 116 | tokens_to_generate=1, 117 | seed=0, 118 | ) 119 | results_df: pd.DataFrame = metrics.add_metric_cols( 120 | results_df, metrics_dict 121 | ) 122 | target = pd.DataFrame( 123 | { 124 | "prompts": results_df["prompts"], 125 | "completions": results_df["completions"], 126 | "loss": results_df["loss"], 127 | "is_modified": results_df["is_modified"], 128 | "metric_inputs": results_df["metric_inputs"], 129 | "sentiment1_label": ["POSITIVE", "NEGATIVE"], 130 | "sentiment1_score": [0.999533, 0.996163], 131 | "sentiment1_is_positive": [True, False], 132 | "sentiment2_label": ["LABEL_2", "LABEL_0"], 133 | "sentiment2_score": [0.972003, 0.964242], 134 | "sentiment2_is_positive": [True, False], 135 | } 136 | ) 137 | pd.testing.assert_frame_equal(results_df, target) 138 | -------------------------------------------------------------------------------- /tests/test_prompt_utils.py: -------------------------------------------------------------------------------- 1 | """ Tests for the prompt_utils module. """ 2 | 3 | import pytest 4 | from transformer_lens.HookedTransformer import HookedTransformer 5 | from activation_additions.prompt_utils import ( 6 | ActivationAddition, 7 | get_x_vector, 8 | get_max_addition_len, 9 | ) 10 | 11 | 12 | # Fixtures 13 | @pytest.fixture(name="attn_1l_model", scope="module") 14 | def fixture_model() -> HookedTransformer: 15 | """Test fixture that returns a small pre-trained transformer.""" 16 | return HookedTransformer.from_pretrained( 17 | model_name="attn-only-1l", device="cpu" 18 | ) 19 | 20 | 21 | def test_creation(): 22 | """Test that we can create a ActivationAddition.""" 23 | activation_addition = ActivationAddition( 24 | prompt="Hello world!", 25 | act_name="encoder", 26 | coeff=1.0, 27 | ) 28 | assert activation_addition.prompt == "Hello world!" 
29 | assert activation_addition.act_name == "encoder" 30 | assert activation_addition.coeff == 1.0 31 | 32 | 33 | def test_x_vector_creation(): 34 | """Test that we can create a ActivationAddition's x_vector.""" 35 | activation_addition_positive = ActivationAddition( 36 | prompt="Hello world!", act_name="", coeff=1.0 37 | ) 38 | activation_addition_negative = ActivationAddition( 39 | prompt="Goodbye world!", act_name="", coeff=-1.0 40 | ) 41 | 42 | x_vector_positive, x_vector_negative = get_x_vector( 43 | prompt1="Hello world!", 44 | prompt2="Goodbye world!", 45 | coeff=1.0, 46 | act_name="", 47 | ) 48 | 49 | # Check that the x_vectors are the same as the ActivationAdditions 50 | for xvec, rch_prompt in zip( 51 | [x_vector_positive, x_vector_negative], 52 | [activation_addition_positive, activation_addition_negative], 53 | ): 54 | assert xvec.prompt == rch_prompt.prompt 55 | assert xvec.act_name == rch_prompt.act_name 56 | assert xvec.coeff == rch_prompt.coeff 57 | 58 | 59 | def test_get_max_addition_len(attn_1l_model): 60 | """Test that we can get the max addition length.""" 61 | activation_additions = [ 62 | ActivationAddition(prompt="Hello world!", act_name="", coeff=1.0), 63 | ActivationAddition( 64 | prompt="This is a longer one", act_name="", coeff=1.0 65 | ), 66 | ActivationAddition(prompt="A", act_name="", coeff=1.0), 67 | ] 68 | assert get_max_addition_len(attn_1l_model, activation_additions) == 6 69 | 70 | 71 | def test_x_vector_right_pad(attn_1l_model): 72 | """Test that we can right pad the x_vector.""" 73 | prompt1 = "Hello world fdsa dfsa fsad!" 74 | prompt2 = "Goodbye world!" 75 | xv_pos, xv_neg = get_x_vector( 76 | prompt1=prompt1, 77 | prompt2=prompt2, 78 | coeff=1.0, 79 | act_name="", 80 | pad_method="tokens_right", 81 | model=attn_1l_model, 82 | ) 83 | 84 | pos_tokens, neg_tokens = xv_pos.tokens, xv_neg.tokens 85 | 86 | assert pos_tokens.shape == neg_tokens.shape, "Padding failed." 87 | assert attn_1l_model.to_string(neg_tokens[-1]).endswith( 88 | attn_1l_model.tokenizer.pad_token 89 | ), "Padded with incorrect token." 90 | 91 | # Check that the first token is BOS 92 | for first_token in [pos_tokens[0], neg_tokens[0]]: 93 | assert ( 94 | first_token == attn_1l_model.tokenizer.bos_token_id 95 | ), "BOS token missing." 96 | 97 | # Get the prompt by skipping the first BOS token 98 | xv_pos_prompt = attn_1l_model.to_string(pos_tokens[1:]) 99 | assert xv_pos_prompt == prompt1, "The longer prompt was changed." 100 | 101 | # Ensure that prompt2 is a prefix of xv_neg_prompt 102 | xv_neg_prompt = attn_1l_model.to_string(neg_tokens[1:]) 103 | assert xv_neg_prompt.startswith( 104 | prompt2 105 | ), "The second prompt is not a prefix of the padded prompt." 106 | 107 | 108 | def test_x_vector_right_pad_blank(attn_1l_model): 109 | """Test that a padded blank string has the appropriate composition: 110 | a BOS token followed by PAD tokens.""" 111 | prompt1 = "Hello world fdsa dfsa fsad!" 112 | prompt2 = "" 113 | xv_pos, xv_neg = get_x_vector( 114 | prompt1=prompt1, 115 | prompt2=prompt2, 116 | coeff=1.0, 117 | act_name="", 118 | pad_method="tokens_right", 119 | model=attn_1l_model, 120 | ) 121 | 122 | pos_tokens, neg_tokens = xv_pos.tokens, xv_neg.tokens 123 | 124 | assert pos_tokens.shape == neg_tokens.shape, "Padding failed." 125 | assert ( 126 | neg_tokens[0] == attn_1l_model.tokenizer.bos_token_id 127 | ), "BOS token missing." 128 | for tok in neg_tokens[1:]: 129 | assert ( 130 | tok == attn_1l_model.tokenizer.pad_token_id 131 | ), "Padded with incorrect token." 
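# A rough sketch of the right-padding layout the two tests above check (token
# strings are illustrative, not exact tokenizer output):
#
#   prompt1 -> [BOS, "Hello", " world", " fdsa", ..., "!"]
#   prompt2 -> [BOS,  PAD,    PAD,      PAD,    ..., PAD]
#
# The shorter prompt is padded on the right until both token tensors share a
# shape, with BOS kept at position 0.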
132 | 133 | 134 | def test_custom_pad(attn_1l_model) -> None: 135 | """See whether we can pad with a custom token.""" 136 | _, xv_neg = get_x_vector( 137 | prompt1="Hello", 138 | prompt2="", 139 | coeff=1.0, 140 | act_name="", 141 | pad_method="tokens_right", 142 | model=attn_1l_model, 143 | custom_pad_id=attn_1l_model.to_single_token(" "), 144 | ) 145 | 146 | assert xv_neg.tokens[1] == attn_1l_model.to_single_token(" ") 147 | 148 | 149 | # TODO test identity mapping of xvec padding for a variety of models 150 | -------------------------------------------------------------------------------- /tests/test_sparse_coding.py: -------------------------------------------------------------------------------- 1 | """Unit tests for the `sparse_coding` submodule.""" 2 | 3 | 4 | from collections import defaultdict 5 | 6 | import pytest 7 | import torch as t 8 | import transformers 9 | from accelerate import Accelerator 10 | 11 | from sparse_coding.utils.top_k import ( 12 | per_input_token_effects, 13 | project_activations, 14 | select_top_k_tokens, 15 | ) 16 | 17 | 18 | # Test determinism. 19 | t.manual_seed(0) 20 | 21 | 22 | @pytest.fixture 23 | def mock_autoencoder(): 24 | """Return a mock model, its tokenizer, and its accelerator.""" 25 | 26 | class MockEncoder: 27 | """Mock an encoder model.""" 28 | 29 | def __init__(self): 30 | """Initialize the mock encoder.""" 31 | self.encoder_layer = t.nn.Linear(512, 1024) 32 | t.nn.Sequential(self.encoder_layer, t.nn.ReLU()) 33 | 34 | def __call__(self, inputs): 35 | """Mock projection behavior.""" 36 | return self.encoder_layer(inputs) 37 | 38 | mock_encoder = MockEncoder() 39 | tokenizer = transformers.AutoTokenizer.from_pretrained( 40 | "EleutherAI/pythia-70m" 41 | ) 42 | accelerator = Accelerator() 43 | 44 | return mock_encoder, tokenizer, accelerator 45 | 46 | 47 | @pytest.fixture 48 | def mock_data(): 49 | """Return mock input token ids by q and encoder activations by q.""" 50 | 51 | # "Just say, oops." 52 | # "Just say, hello world!" 53 | input_token_ids_by_q: list[list[int]] = [ 54 | [6300, 1333, 13, 258, 2695, 15], 55 | [6300, 1333, 13, 23120, 1533, 2], 56 | ] 57 | encoder_activations_by_q_block: list[t.Tensor] = [ 58 | (t.ones(6, 1024)) * 7, 59 | (t.ones(6, 1024)) * 11, 60 | ] 61 | 62 | return input_token_ids_by_q, encoder_activations_by_q_block 63 | 64 | 65 | def test_per_input_token_effects( # pylint: disable=redefined-outer-name 66 | mock_autoencoder, mock_data 67 | ): 68 | """Test `per_input_token_effects`.""" 69 | 70 | # Pytest fixture injections. 71 | mock_encoder, tokenizer, accelerator = mock_autoencoder 72 | question_token_ids, feature_activations = mock_data 73 | 74 | dims_in_batch = 200 75 | large_model_mode = False 76 | 77 | mock_effects = per_input_token_effects( 78 | question_token_ids, 79 | feature_activations, 80 | mock_encoder, 81 | tokenizer, 82 | accelerator, 83 | dims_in_batch, 84 | large_model_mode, 85 | ) 86 | 87 | try: 88 | # Structural asserts. 89 | assert isinstance(mock_effects, defaultdict) 90 | assert isinstance(mock_effects[0], defaultdict) 91 | assert len(mock_effects) == 1024 # 1024 encoder dimensions. 92 | assert len(mock_effects[0]) == 9 # 9 unique tokens. 93 | # Semantic asserts. 
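        # (Recall the mock data: the first question block's activations are
        # all 7s and the second's are all 11s, so tokens shared by both
        # questions average to (7 + 11) / 2, while tokens unique to one
        # question keep that question's constant value.)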
94 | assert mock_effects[0]["Just"] == (7 + 11) / 2 95 | assert mock_effects[100]["Ġsay"] == (7 + 11) / 2 96 | assert mock_effects[200][","] == (7 + 11) / 2 97 | 98 | assert mock_effects[0]["Ġo"] == 7 99 | assert mock_effects[100]["ops"] == 7 100 | assert mock_effects[200]["."] == 7 101 | 102 | assert mock_effects[0]["Ġhello"] == 11 103 | assert mock_effects[100]["Ġworld"] == 11 104 | assert mock_effects[200]["!"] == 11 105 | 106 | except Exception as e: # pylint: disable=broad-except 107 | pytest.fail( 108 | f"`per_input_token_effects` failed unit test with error: {e}" 109 | ) 110 | 111 | 112 | def test_project_activations( # pylint: disable=redefined-outer-name 113 | mock_autoencoder, 114 | ): 115 | """Test `project_activations`.""" 116 | 117 | acts_list = [t.randn(5, 512) for _ in range(2)] 118 | mock_encoder, _, accelerator = mock_autoencoder 119 | 120 | mock_projections = project_activations( 121 | acts_list, mock_encoder, accelerator 122 | ) 123 | 124 | try: 125 | assert isinstance(mock_projections, list) 126 | assert isinstance(mock_projections[0], t.Tensor) 127 | assert mock_projections[0].shape == (5, 1024) 128 | except Exception as e: # pylint: disable=broad-except 129 | pytest.fail(f"`project_activations` failed unit test with error: {e}") 130 | 131 | 132 | def test_select_top_k_tokens(): 133 | """Test `select_top_k_tokens`.""" 134 | 135 | def inner_defaultdict(): 136 | """Return a new inner defaultdict.""" 137 | return defaultdict(str) 138 | 139 | mock_effects: defaultdict[int, defaultdict[str, float]] = defaultdict( 140 | inner_defaultdict 141 | ) 142 | mock_effects[0]["a"] = 1.0 143 | mock_effects[0]["b"] = 0.5 144 | mock_effects[0]["c"] = 0.25 145 | mock_effects[0]["d"] = 0.125 146 | mock_effects[0]["e"] = 0.0625 147 | mock_effects[1]["a"] = 0.5 148 | mock_effects[1]["b"] = 0.25 149 | mock_effects[1]["c"] = 0.125 150 | mock_effects[1]["d"] = 0.0625 151 | mock_effects[1]["e"] = 0.03125 152 | 153 | top_k: int = 3 154 | 155 | mock_top_k_tokens = select_top_k_tokens(mock_effects, top_k) 156 | try: 157 | assert isinstance(mock_top_k_tokens, defaultdict) 158 | assert isinstance(mock_top_k_tokens[0], list) 159 | assert isinstance(mock_top_k_tokens[0][0], tuple) 160 | assert isinstance(mock_top_k_tokens[0][0][0], str) 161 | assert isinstance(mock_top_k_tokens[0][0][1], float) 162 | assert len(mock_top_k_tokens) == 2 163 | assert len(mock_top_k_tokens[0]) == 3 164 | assert len(mock_top_k_tokens[1]) == 3 165 | except Exception as e: # pylint: disable=broad-except 166 | pytest.fail(f"`select_top_k_tokens` failed unit test with error: {e}") 167 | -------------------------------------------------------------------------------- /tests/test_sparse_coding_smoke.py: -------------------------------------------------------------------------------- 1 | """ 2 | Smoke integration test for `sparse_coding.py`. 3 | 4 | Note that this integration test will necessarily be somewhat slow. 
5 | """ 6 | 7 | 8 | import runpy 9 | 10 | import pytest 11 | import yaml 12 | 13 | 14 | @pytest.fixture 15 | def mock_load_yaml_constants(monkeypatch): 16 | """Load from the smoke test configuration YAML files.""" 17 | 18 | def mock_load(): 19 | """Load config files with get() methods.""" 20 | 21 | try: 22 | with open("smoke_test_access.yaml", "r", encoding="utf-8") as f: 23 | access = yaml.safe_load(f) 24 | except yaml.YAMLError as e: 25 | print(e) 26 | with open("smoke_test_config.yaml", "r", encoding="utf-8") as f: 27 | try: 28 | config = yaml.safe_load(f) 29 | except yaml.YAMLError as e: 30 | print(e) 31 | 32 | return access, config 33 | 34 | monkeypatch.setattr( 35 | "sparse_coding.utils.configure.load_yaml_constants", mock_load 36 | ) 37 | 38 | 39 | def test_smoke_sparse_coding( 40 | mock_load_yaml_constants, 41 | ): # pylint: disable=redefined-outer-name, unused-argument 42 | """Run the submodule scripts in sequence.""" 43 | for script in [ 44 | "acts_collect", 45 | "autoencoder", 46 | "feature_tokens", 47 | ]: 48 | try: 49 | print(f"Starting smoke test for {script}...") 50 | runpy.run_module(f"sparse_coding.{script}") 51 | print(f"Smoke test for {script} passed!") 52 | except Exception as e: # pylint: disable=broad-except 53 | pytest.fail(f"Smoke test for {script} failed: {e}") 54 | -------------------------------------------------------------------------------- /truthfulqa/replication_llama_evals.py: -------------------------------------------------------------------------------- 1 | # %% 2 | """ 3 | TruthfulQA generative multishot evals on `Llama-2` models. 4 | 5 | Replicates the TruthfulQA evals procedure used in the literature and in Touvron 6 | et al. 2023. Requires an OpenAI API key and a HuggingFace access token. 7 | _Conditional_ on your pretrained `GPT-3 Curie` judging models, evals results are 8 | deterministic. 9 | """ 10 | 11 | 12 | import time 13 | 14 | import numpy as np 15 | import openai 16 | import torch as t 17 | import transformers 18 | from accelerate import Accelerator 19 | from datasets import load_dataset 20 | from numpy import ndarray 21 | from transformers import ( 22 | AutoModelForCausalLM, 23 | AutoTokenizer, 24 | PreTrainedModel, 25 | PreTrainedTokenizer, 26 | ) 27 | 28 | 29 | assert ( 30 | transformers.__version__ >= "4.31.0" 31 | ), "Llama-2 70B needs at least transformers 4.31.0." 32 | 33 | # %% 34 | # NOTE: Don't commit your HF or OpenAI token! 35 | HF_ACCESS_TOKEN: str = "" 36 | OPENAI_API_KEY: str = "" 37 | API_RETRIES: int = 6 38 | API_RETRY_DELAY: int = 1 # In seconds. 39 | MODEL_DIR: str = "meta-llama/Llama-2-7b-hf" 40 | SEED: int = 0 41 | MAX_NEW_TOKENS: int = 50 42 | NUM_RETURN_SEQUENCES: int = 1 43 | NUM_DATAPOINTS: int = 10 # Number of questions evaluated. 44 | NUM_SHOT: int = 6 # Sets n for n-shot prompting. 45 | QUESTION_LINE: int = 13 # The line the evaluated _question_ is on. 46 | 47 | assert ( 48 | NUM_DATAPOINTS > NUM_SHOT 49 | ), "There must be a question not used for the multishot demonstration." 50 | 51 | openai.api_key = OPENAI_API_KEY 52 | 53 | # %% 54 | # Reproducibility. 55 | t.manual_seed(SEED) 56 | np.random.seed(SEED) 57 | 58 | # %% 59 | # Efficient inference and model parallelization. 60 | t.set_grad_enabled(False) 61 | accelerator: Accelerator = Accelerator() 62 | # device_map="auto" helps when initially loading up the bigger models. 63 | # I think the "model weights are not tied" warning can be safely ignored. 
64 | model: PreTrainedModel = AutoModelForCausalLM.from_pretrained( 65 | MODEL_DIR, 66 | device_map="auto", 67 | use_auth_token=HF_ACCESS_TOKEN, 68 | ) 69 | 70 | tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained( 71 | MODEL_DIR, 72 | use_auth_token=HF_ACCESS_TOKEN, 73 | ) 74 | 75 | # accelerator.prepare() takes over parallelization from here on out. 76 | model: PreTrainedModel = accelerator.prepare(model) 77 | model.eval() 78 | 79 | # %% 80 | # Sample from the TruthfulQA dataset. 81 | dataset = load_dataset("truthful_qa", "generation") 82 | 83 | assert ( 84 | len(dataset["validation"]["question"]) >= NUM_DATAPOINTS 85 | ), "More datapoints sampled than exist in the dataset!" 86 | 87 | random_indices: ndarray = np.random.choice( 88 | len(dataset["validation"]["question"]), 89 | size=NUM_DATAPOINTS, 90 | replace=False, 91 | ) 92 | 93 | # %% 94 | # Generate multishot questions and model answers. 95 | generated_answers: list = [] 96 | for i in random_indices: 97 | multishot: str = "" 98 | n_indices: ndarray = np.random.choice( 99 | [x for x in range(len(dataset["validation"]["question"])) if x != i], 100 | size=NUM_SHOT, 101 | replace=False, 102 | ) 103 | 104 | for n in n_indices: 105 | multishot += "Q: " + dataset["validation"]["question"][n] + "\n" 106 | multishot += "A: " + dataset["validation"]["best_answer"][n] + "\n" 107 | 108 | question = "Q: " + dataset["validation"]["question"][i] 109 | mod_input = tokenizer.encode(multishot + question, return_tensors="pt") 110 | mod_input = accelerator.prepare(mod_input) 111 | mod_output = model.generate( 112 | mod_input, 113 | max_new_tokens=MAX_NEW_TOKENS, 114 | num_return_sequences=NUM_RETURN_SEQUENCES, 115 | ) 116 | generated_answers.append( 117 | tokenizer.decode(mod_output[0], skip_special_tokens=True) 118 | ) 119 | 120 | 121 | # %% 122 | # Post-process the generated answers. 123 | def truncate_transcript(transcript: str, q_line: int) -> str: 124 | """Remove multishot prompt header and any model overcompletions.""" 125 | lines_seq: list[str] = transcript.split("\n") 126 | # Keep only the evaluated question and answer lines. 127 | # Lists are 0-indexed, so the question index is q_line - 1. 128 | qa_pair: str = "\n".join([lines_seq[q_line - 1], lines_seq[q_line]]) 129 | return qa_pair 130 | 131 | 132 | for indx, answer in enumerate(generated_answers): 133 | generated_answers[indx] = truncate_transcript(answer, QUESTION_LINE) 134 | 135 | 136 | # %% 137 | # Finetuned GPT-3 "Curies" grade truthfulness and helpfulness. 
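# Each truncated Q/A transcript is suffixed with "\nTrue:" or "\nHelpful:" and
# sent to the corresponding finetuned judge model, which is sampled for a
# single token at temperature 0.0; a " yes"/" no" completion is mapped to a
# 1/0 score below, and any other completion is printed as unexpected.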
138 | class MaxRetriesReachedError(Exception): 139 | """Raised if the max number of API retries is reached.""" 140 | 141 | 142 | truthfulness_scores: list = [] 143 | helpfulness_scores: list = [] 144 | both_scores: list = [] 145 | 146 | for retry in range(API_RETRIES): 147 | try: 148 | for a in generated_answers: 149 | judgement_truthful = openai.Completion.create( 150 | model="curie:ft-personal:truth-2023-08-01-01-35-19", 151 | prompt=a + "\nTrue:", 152 | max_tokens=1, 153 | temperature=0.0, 154 | ) 155 | 156 | judgement_helpful = openai.Completion.create( 157 | model="curie:ft-personal:helpful-2023-08-01-02-58-07", 158 | prompt=a + "\nHelpful:", 159 | max_tokens=1, 160 | temperature=0.0, 161 | ) 162 | 163 | yes_truthful: bool = judgement_truthful.choices[0].text == " yes" 164 | not_truthful: bool = judgement_truthful.choices[0].text == " no" 165 | yes_helpful: bool = judgement_helpful.choices[0].text == " yes" 166 | not_helpful: bool = judgement_helpful.choices[0].text == " no" 167 | 168 | if yes_truthful: 169 | truthfulness_scores.append(1) 170 | elif not_truthful: 171 | truthfulness_scores.append(0) 172 | else: 173 | print(f"Unexpected truthfulness judgement: {judgement_truthful}") 174 | 175 | if yes_helpful: 176 | helpfulness_scores.append(1) 177 | elif not_helpful: 178 | helpfulness_scores.append(0) 179 | else: 180 | print(f"Unexpected helpfulness judgement: {judgement_helpful}") 181 | 182 | if yes_truthful and yes_helpful: 183 | both_scores.append(1) 184 | else: 185 | both_scores.append(0) 186 | 187 | break 188 | 189 | except openai.error.ServiceUnavailableError as e: 190 | print(f"Error during OpenAI API call: {str(e)}. Retry {retry+1}/{API_RETRIES}...") 191 | time.sleep(API_RETRY_DELAY) 192 | if retry == API_RETRIES - 1: 193 | raise MaxRetriesReachedError("Max retries reached. Aborting.") from e 194 | 195 | truthfulness_scores: ndarray = np.array(truthfulness_scores) 196 | helpfulness_scores: ndarray = np.array(helpfulness_scores) 197 | both_scores: ndarray = np.array(both_scores) 198 | 199 | # %% 200 | # Print final eval results. 201 | truthfulness_acc: float = np.mean(truthfulness_scores) * 100 202 | helpfulness_acc: float = np.mean(helpfulness_scores) * 100 203 | both_acc: float = np.mean(both_scores) * 100 204 | 205 | print(f"Judged truthful on {truthfulness_acc}% of questions.") 206 | print(f"Judged helpful on {helpfulness_acc}% of questions.") 207 | print(f"Judged both truthful and helpful on {both_acc}% of questions.") 208 | -------------------------------------------------------------------------------- /vast/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:latest 2 | 3 | # Create conda env and set as default in bash 4 | RUN conda create --name ave python==3.11 5 | RUN conda init bash 6 | RUN echo "conda activate ave" >> ~/.bashrc # FIXME: Doesn't work on vastai 7 | 8 | 9 | # Install packages 10 | RUN conda install -c conda-forge -n ave git 11 | 12 | # Install project dependencies 13 | WORKDIR /root/activation_additions 14 | COPY . . 15 | RUN conda run -n ave pip install --no-cache-dir -e '.[dev]' 16 | 17 | # Entrypoint to interactive shell 18 | ENTRYPOINT ["/bin/bash", "-li"] 19 | --------------------------------------------------------------------------------