├── .gitignore ├── 2b_probes_data └── combine.ipynb ├── README.md ├── example.py ├── mmlu_evals.py ├── plots ├── 2b_experiments_mmlu.pdf ├── 2b_experiments_mmlu.png ├── 2b_experiments_semantics.pdf ├── 2b_experiments_semantics.png ├── 2b_experiments_syntax.pdf ├── 2b_experiments_syntax.png ├── 9b_long_hard_mmlu.pdf ├── 9b_long_hard_mmlu.png ├── 9b_long_hard_semantics.pdf ├── 9b_long_hard_semantics.png ├── 9b_long_hard_syntax.pdf ├── 9b_long_hard_syntax.png ├── 9b_pr_curves.png ├── 9b_short_hard_mmlu.pdf ├── 9b_short_hard_mmlu.png ├── 9b_short_hard_semantics.pdf ├── 9b_short_hard_semantics.png ├── 9b_short_hard_syntax.pdf ├── 9b_short_hard_syntax.png ├── 9b_short_normal_mmlu.pdf ├── 9b_short_normal_mmlu.png ├── 9b_short_normal_semantics.pdf ├── 9b_short_normal_semantics.png ├── 9b_short_normal_syntax.pdf ├── 9b_short_normal_syntax.png ├── final_charts │ ├── gemma-2-2b.png │ ├── gemma-2-9b.png │ ├── llama-3.1-8b.png │ └── pareto_frontier.png ├── llama_layer12_mmlu.pdf ├── llama_layer12_mmlu.png ├── llama_layer12_semantics.pdf ├── llama_layer12_semantics.png ├── llama_layer12_syntax.pdf ├── llama_layer12_syntax.png ├── llama_layer8_mmlu.pdf ├── llama_layer8_mmlu.png ├── llama_layer8_semantics.pdf ├── llama_layer8_semantics.png ├── llama_layer8_syntax.pdf └── llama_layer8_syntax.png ├── pyproject.toml ├── sae ├── kernels.py ├── sae.py └── utils.py ├── src ├── agent_eval.py ├── caa.py ├── count_activations.py ├── eval_config.py ├── prompts │ ├── activation_counting_code.json │ ├── code.json │ ├── code.txt │ ├── code_hard.txt │ ├── contrastive_prompts.json │ ├── longer_pytest_docs.txt │ ├── make_json_prompts.ipynb │ ├── not_regex_code.txt │ ├── probe_prompts.json │ ├── prompt.txt │ ├── prompt_hard.txt │ ├── prompt_medium.txt │ ├── prompt_no_regex.txt │ ├── prompt_repetitions.txt │ ├── pytest_docs.txt │ └── user_questions.txt ├── regex_interventions.py ├── utils.py └── wrapper.py ├── view_activations.ipynb └── view_run_results.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | .venv 2 | *.pkl 3 | *.pt 4 | __pycache__/ 5 | 6 | 7 | 8 | # Distribution / packaging 9 | dist/ 10 | build/ 11 | *.egg-info/ 12 | *.egg 13 | 14 | # Jupyter Notebook 15 | .ipynb_checkpoints 16 | 17 | # IDE specific files 18 | .idea/ 19 | .vscode/ 20 | *.swp 21 | *.swo 22 | .DS_Store 23 | 24 | # Unit test / coverage reports 25 | htmlcov/ 26 | .tox/ 27 | .coverage 28 | .coverage.* 29 | .cache 30 | coverage.xml 31 | *.cover 32 | .pytest_cache/ 33 | 34 | # Logs 35 | *.log 36 | 37 | # Local development settings 38 | .env 39 | .env.local -------------------------------------------------------------------------------- /2b_probes_data/combine.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 4, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "dict_keys(['config', 'constant_sae', 'constant_steering_vector', 'conditional_per_token', 'sae_steering_vector'])\n", 13 | "dict_keys(['config', 'probe_steering_vector', 'probe_sae', 'conditional_clamping', 'conditional_steering_vector'])\n", 14 | "dict_keys(['config', 'probe_sae_clamping', 'probe_steering_vector_clamping'])\n" 15 | ] 16 | } 17 | ], 18 | "source": [ 19 | "import pickle\n", 20 | "\n", 21 | "filename1 = \"results_mmlu_evals_baseline.pkl\"\n", 22 | "filename2 = \"results_mmlu_evals.pkl\"\n", 23 | "filename3 = \"probe_only_results_mmlu_evals.pkl\"\n", 24 
| "filename3 = \"probe_clamping_results_mmlu_evals.pkl\"\n", 25 | "\n", 26 | "with open(filename1, \"rb\") as f:\n", 27 | " results1 = pickle.load(f)\n", 28 | "\n", 29 | "with open(filename2, \"rb\") as f:\n", 30 | " results2 = pickle.load(f)\n", 31 | "\n", 32 | "with open(filename3, \"rb\") as f:\n", 33 | " results3 = pickle.load(f)\n", 34 | "\n", 35 | "print(results1.keys())\n", 36 | "print(results2.keys())\n", 37 | "print(results3.keys())\n", 38 | "\n", 39 | "for key in results2:\n", 40 | " if key != \"config\":\n", 41 | " results1[key] = results2[key]\n", 42 | "\n", 43 | "with open(\"results_mmlu_evals_combined.pkl\", \"wb\") as f:\n", 44 | " pickle.dump(results1, f)" 45 | ] 46 | } 47 | ], 48 | "metadata": { 49 | "kernelspec": { 50 | "display_name": ".venv", 51 | "language": "python", 52 | "name": "python3" 53 | }, 54 | "language_info": { 55 | "codemirror_mode": { 56 | "name": "ipython", 57 | "version": 3 58 | }, 59 | "file_extension": ".py", 60 | "mimetype": "text/x-python", 61 | "name": "python", 62 | "nbconvert_exporter": "python", 63 | "pygments_lexer": "ipython3", 64 | "version": "3.11.8" 65 | } 66 | }, 67 | "nbformat": 4, 68 | "nbformat_minor": 2 69 | } 70 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sieve: Steering for Fine-grained Regex Control 2 | 3 | Sieve is a simple framework for applying targeted interventions to language models using Sparse Autoencoder (SAE) techniques. It enables precise control over model behavior through feature-level manipulations. 4 | 5 | Read the full case study on Sieve's application and results on our blog [here](https://www.tilderesearch.com/blog/sieve). 6 | 7 | ## Results 8 | 9 | Our experiments demonstrate Pareto dominance of SAE-based methods on fine-grained control. Below are the Pareto frontiers across models. 10 | 11 | 12 | 13 | ## Quick Start 14 | 15 | You will have to install flash attention separately. Without flash attention, the model will fall back to eager. 16 | 17 | ```bash 18 | # Install flash attention 19 | pip install wheel 20 | pip install flash-attn --no-build-isolation 21 | ``` 22 | 23 | ```bash 24 | # Install package 25 | pip install -e . 26 | 27 | # Run regex interventions example 28 | python src/regex_interventions.py 29 | ``` 30 | 31 | Use `pip install -e .` rather than `pip install -r requirements.txt`, as it simplifies using imports and makes it easier to add tests. 32 | 33 | Then run `python src/regex_interventions.py`. You can change settings in `src/eval_config.py`. 34 | 35 | This will: 36 | 37 | 1. Load a pre-trained model and SAE 38 | 2. Run baseline code generation 39 | 3. Apply various interventions to control regex usage 40 | 4. Evaluate and compare results 41 | 42 | It can be used to directly reproduce reported results. 43 | 44 | ## Runtime & Parallelization 45 | 46 | This is a parameter sweep across multiple intervention methods and scales, which can take significant time to run. With 200 generations per evaluation method, each method takes approximately 8 hours to complete on a single GPU. Probe methods take longer, as they require generating data and then training probes. 47 | 48 | However, the evaluation can be parallelized by taking the following steps: 49 | 50 | 1. Split the intervention methods across multiple GPUs 51 | 2. Run separate jobs with different subsets of `intervention_types` in `eval_config.py` 52 | 3. 
Save results to different output files (e.g., `results_gpu1.json`, `results_gpu2.json`) 53 | 54 | The visualization notebooks (`view_run_results.ipynb`) are designed to handle multiple result files, allowing you to combine outputs from parallel runs for analysis. You can also reduce the number of generations per method to speed up evaluation. 55 | 56 | MMLU evaluations, by contrast, are much shorter and complete in roughly 1-2 hours for all methods. 57 | 58 | ## Overview 59 | 60 | Sieve provides tools for: 61 | 62 | - Loading and applying pre-trained SAEs 63 | - Performing targeted feature interventions 64 | - Analyzing activation patterns 65 | - Evaluating intervention effects 66 | 67 | The framework supports multiple intervention types: 68 | 69 | 1. **Constant SAE**: Direct feature manipulation 70 | 2. **Conditional SAE**: Activation-dependent interventions 71 | 3. **Contrastive Activations (CAA)**: Steering vectors 72 | 4. **Probe-based**: Linear probes 73 | 74 | ### InterventionWrapper 75 | 76 | The main interface for applying interventions: 77 | 78 | ```python 79 | from src.wrapper import InterventionWrapper 80 | 81 | # Initialize wrapper 82 | wrapper = InterventionWrapper(model_name="google/gemma-2b-it") 83 | 84 | # Load SAE 85 | wrapper.load_sae(release="gemma-scope-2b-pt-res", sae_id="layer_8/width_16k/average_l0_71", layer_idx=8) 86 | 87 | # Generate with intervention 88 | output = wrapper.generate(prompt, intervention_type="CONDITIONAL_PER_TOKEN", scale=1.0) 89 | ``` 90 | 91 | ### Evaluation Tools 92 | 93 | - `regex_interventions.py`: Test regex usage patterns (this is the main file for reproducing results) 94 | - `count_activations.py`: Analyze performance for different kinds of classifiers (SAE encoder vectors, CAA vectors, linear probes) 95 | - `agent_eval.py`: Evaluate generation quality 96 | 97 | ## Configuration 98 | 99 | Modify settings in `src/eval_config.py`: 100 | 101 | ```python 102 | class EvalConfig: 103 | model_name: str = "google/gemma-2b-it" 104 | intervention_types: list = [ 105 | "CONSTANT_SAE", 106 | "CONDITIONAL_PER_TOKEN", 107 | "CONDITIONAL_PER_INPUT", 108 | "CONSTANT_STEERING_VECTOR" 109 | ] 110 | ``` 111 | The eval config can also be used to change the prompt, intervention scales, base model, and other settings. However, to change the SAE features used, you have to modify `utils.py` to change the feature indices and layer (a sketch of what this looks like is shown below). The layer 8 Llama feature is currently commented out.
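The per-model feature configuration lives in `src/utils.py` and is read via `utils.get_model_params(...)` (see `mmlu_evals.py` and `example.py`). The sketch below is illustrative only: the keys are inferred from how those scripts consume the returned dictionary, the Llama values are taken from `example.py`, and the Gemma feature index is a hypothetical placeholder rather than the value in the repository.

```python
# Illustrative sketch -- the real mapping in src/utils.py may be structured differently.
def get_model_params(model_name: str) -> dict:
    if "Llama-3.1-8B" in model_name:
        return {
            "sae_release": "tilde-research/sieve_coding",  # custom code SAE for Llama (from example.py)
            "sae_id": None,
            "targ_layer": 12,     # layer whose residual stream is hooked; a layer 8 variant is commented out
            "feature_idx": 9853,  # regex feature index from example.py -- edit to steer a different feature
        }
    elif "gemma-2b" in model_name:
        return {
            "sae_release": "gemma-scope-2b-pt-res",
            "sae_id": "layer_8/width_16k/average_l0_71",
            "targ_layer": 8,
            "feature_idx": 1234,  # hypothetical placeholder -- see src/utils.py for the real index
        }
    raise ValueError(f"No feature configuration for model: {model_name}")
```

To target a different feature, edit `feature_idx` (and `targ_layer` if needed) for the relevant model entry in `src/utils.py`, then rerun `python src/regex_interventions.py`.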
112 | 113 | ## Project Structure 114 | 115 | ``` 116 | sieve/ 117 | ├── src/ 118 | │ ├── agent_eval.py # LLM-based evaluation tools 119 | │ ├── caa.py # Contrastive Activation Addition core 120 | │ ├── count_activations.py # Probe analysis tools 121 | │ ├── regex_interventions.py # Main entry point 122 | │ ├── wrapper.py # Core intervention framework 123 | │ ├── eval_config.py # Configuration settings 124 | │ └── utils.py # Helper functions 125 | │ 126 | ├── mmlu_evals.py # MMLU benchmark evaluation script 127 | ├── view_activations.ipynb # Notebook for activation analysis 128 | ├── view_run_results.ipynb # Notebook for experiment results 129 | ├── pyproject.toml # Package configuration 130 | └── README.md 131 | ``` 132 | 133 | ## Collaboration 134 | 135 | Sieve is a collaboration between Tilde and Adam Karvonen, leveraging: 136 | 137 | - GemmaScope pretrained SAEs [1] 138 | - Custom code-specific SAEs for Llama models 139 | - Advanced intervention techniques 140 | 141 | ## Citation 142 | 143 | If you use Sieve or any of its results in your research, please cite: 144 | 145 | ```bibtex 146 | @article{karvonen2024sieve, 147 | title={Sieve: SAEs Beat Baselines on a Real-World Task (A Code Generation Case Study)}, 148 | author={Karvonen, Adam and Pai, Dhruv and Wang, Mason and Keigwin, Ben}, 149 | journal={Tilde Research Blog}, 150 | year={2024}, 151 | month={12}, 152 | url={https://www.tilderesearch.com/blog/sieve}, 153 | note={Blog post} 154 | } 155 | ``` 156 | 157 | For the GemmaScope pretrained SAEs, please cite: 158 | 159 | ```bibtex 160 | @misc{lieberum2024gemmascopeopensparse, 161 | title={Gemma Scope: Open Sparse Autoencoders Everywhere All At Once on Gemma 2}, 162 | author={Tom Lieberum and Senthooran Rajamanoharan and Arthur Conmy and Lewis Smith and Nicolas Sonnerat and Vikrant Varma and János Kramár and Anca Dragan and Rohin Shah and Neel Nanda}, 163 | year={2024}, 164 | eprint={2408.05147}, 165 | archivePrefix={arXiv}, 166 | primaryClass={cs.LG}, 167 | url={https://arxiv.org/abs/2408.05147}, 168 | } 169 | ``` 170 | 171 | [1] Lieberum et al., "Gemma Scope: Open Sparse Autoencoders Everywhere All At Once on Gemma 2", 2024 172 | 173 | ## License 174 | 175 | MIT License 176 | 177 | Copyright (c) 2024 Tilde Research, Inc. and Adam Karvonen 178 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from src.wrapper import InterventionWrapper 3 | from src.eval_config import EvalConfig, InterventionType 4 | 5 | # Step 1: Set up the device 6 | device = "cuda" if torch.cuda.is_available() else "cpu" 7 | 8 | # Step 2: Define model parameters 9 | MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct" 10 | LAYER = 12 11 | FEATURE_IDX = 9853 12 | SCALE = -8.0 13 | 14 | # Step 3: Create the InterventionWrapper 15 | wrapper = InterventionWrapper(MODEL_NAME, device=device) 16 | 17 | # Step 4: Load the SAE 18 | wrapper.load_sae(f"tilde-research/sieve_coding", sae_id=None, layer_idx=LAYER) 19 | 20 | # Step 5: Set up the intervention 21 | config = EvalConfig() 22 | model_params = { 23 | "targ_layer": LAYER, 24 | "feature_idx": FEATURE_IDX 25 | } 26 | 27 | 28 | 29 | # Step 6: Format input text using chat template 30 | input_text = "Write a python function using the re module to match a numerical substring in a string." 
31 | chat = [{"role": "user", "content": input_text}] 32 | formatted_text = wrapper.tokenizer.apply_chat_template( 33 | chat, 34 | tokenize=False, 35 | add_generation_prompt=True 36 | ) 37 | 38 | # Step 7: Generate text without intervention 39 | print("Generating without intervention...") 40 | generated_text_original = wrapper.generate( 41 | [formatted_text], 42 | max_new_tokens=200, 43 | temperature=0.2, 44 | module_and_hook_fn=None 45 | )[0] 46 | 47 | # Step 8: Generate text with intervention 48 | print("\nGenerating with intervention...") 49 | module_and_hook_fn = wrapper.get_hook( 50 | intervention_type=InterventionType.CONDITIONAL_PER_TOKEN.value, 51 | model_params=model_params, 52 | scale=SCALE, 53 | config=config 54 | ) 55 | 56 | generated_text_intervened = wrapper.generate( 57 | [formatted_text], 58 | max_new_tokens=800, 59 | temperature=0.2, 60 | # repetition_penalty=1.15, 61 | module_and_hook_fn=module_and_hook_fn 62 | )[0] 63 | 64 | # Step 8: Print and compare the results 65 | print("Original generated text:") 66 | print(generated_text_original) 67 | print("\nGenerated text with intervention:") 68 | print(generated_text_intervened) 69 | 70 | # Optional: Calculate and print the difference in token length 71 | # original_tokens = len(wrapper.model.to_tokens(generated_text_original)[0]) 72 | # intervened_tokens = len(wrapper.model.to_tokens(generated_text_intervened)[0]) 73 | # print(f"\nOriginal generation length: {original_tokens} tokens") 74 | # print(f"Intervened generation length: {intervened_tokens} tokens") 75 | # print(f"Difference: {intervened_tokens - original_tokens} tokens") 76 | -------------------------------------------------------------------------------- /mmlu_evals.py: -------------------------------------------------------------------------------- 1 | import lm_eval 2 | from lm_eval.models.huggingface import HFLM 3 | import torch 4 | import pickle 5 | from dataclasses import asdict 6 | 7 | from src.eval_config import EvalConfig 8 | from src.wrapper import InterventionWrapper 9 | import src.utils as utils 10 | 11 | 12 | if __name__ == "__main__": 13 | # Step 1: Set up the device 14 | torch.set_grad_enabled(False) 15 | device = "cuda" if torch.cuda.is_available() else "cpu" 16 | 17 | config = EvalConfig() 18 | 19 | model_params = utils.get_model_params(config.model_name) 20 | 21 | tasks = [ 22 | "mmlu_high_school_statistics", 23 | "mmlu_high_school_computer_science", 24 | "mmlu_high_school_mathematics", 25 | "mmlu_high_school_physics", 26 | "mmlu_high_school_biology", 27 | ] 28 | 29 | # Step 3: Create the InterventionWrapper 30 | wrapper = InterventionWrapper(config.model_name, device=device, dtype=torch.bfloat16) 31 | 32 | # Step 4: Load the SAE 33 | wrapper.load_sae(release=model_params["sae_release"], sae_id=model_params["sae_id"], layer_idx=model_params["targ_layer"]) 34 | 35 | results = {"config": asdict(config)} 36 | 37 | # indexes all tasks from the `lm_eval/tasks` subdirectory. 38 | # Alternatively, you can set `TaskManager(include_path="path/to/my/custom/task/configs")` 39 | # to include a set of tasks in a separate directory. 
40 | 41 | lm_obj = HFLM(pretrained=wrapper.model) 42 | 43 | task_manager = lm_eval.tasks.TaskManager() 44 | 45 | for intervention_type in config.intervention_types: 46 | results[intervention_type] = {} 47 | 48 | results[intervention_type][0] = lm_eval.simple_evaluate( 49 | model=lm_obj, 50 | tasks=tasks, 51 | num_fewshot=0, 52 | task_manager=task_manager, 53 | ) 54 | 55 | for scale in config.scales: 56 | module, hook_fn = wrapper.get_hook(intervention_type, model_params, scale, config) 57 | 58 | # NOTE: Make sure to remove the hook after using it. 59 | handle = module.register_forward_hook(hook_fn) 60 | 61 | # Setting `task_manager` to the one above is optional and should generally be done 62 | # if you want to include tasks from paths other than ones in `lm_eval/tasks`. 63 | # `simple_evaluate` will instantiate its own task_manager if it is set to None here. 64 | 65 | results[intervention_type][scale] = lm_eval.simple_evaluate( # call simple_evaluate 66 | model=lm_obj, 67 | tasks=tasks, 68 | num_fewshot=0, 69 | task_manager=task_manager, 70 | ) 71 | 72 | with open("results_mmlu_evals.pkl", "wb") as f: 73 | pickle.dump(results, f) 74 | 75 | handle.remove() 76 | -------------------------------------------------------------------------------- /plots/2b_experiments_mmlu.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/2b_experiments_mmlu.pdf -------------------------------------------------------------------------------- /plots/2b_experiments_mmlu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/2b_experiments_mmlu.png -------------------------------------------------------------------------------- /plots/2b_experiments_semantics.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/2b_experiments_semantics.pdf -------------------------------------------------------------------------------- /plots/2b_experiments_semantics.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/2b_experiments_semantics.png -------------------------------------------------------------------------------- /plots/2b_experiments_syntax.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/2b_experiments_syntax.pdf -------------------------------------------------------------------------------- /plots/2b_experiments_syntax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/2b_experiments_syntax.png -------------------------------------------------------------------------------- /plots/9b_long_hard_mmlu.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_long_hard_mmlu.pdf -------------------------------------------------------------------------------- /plots/9b_long_hard_mmlu.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_long_hard_mmlu.png -------------------------------------------------------------------------------- /plots/9b_long_hard_semantics.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_long_hard_semantics.pdf -------------------------------------------------------------------------------- /plots/9b_long_hard_semantics.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_long_hard_semantics.png -------------------------------------------------------------------------------- /plots/9b_long_hard_syntax.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_long_hard_syntax.pdf -------------------------------------------------------------------------------- /plots/9b_long_hard_syntax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_long_hard_syntax.png -------------------------------------------------------------------------------- /plots/9b_pr_curves.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_pr_curves.png -------------------------------------------------------------------------------- /plots/9b_short_hard_mmlu.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_short_hard_mmlu.pdf -------------------------------------------------------------------------------- /plots/9b_short_hard_mmlu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_short_hard_mmlu.png -------------------------------------------------------------------------------- /plots/9b_short_hard_semantics.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_short_hard_semantics.pdf -------------------------------------------------------------------------------- /plots/9b_short_hard_semantics.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_short_hard_semantics.png -------------------------------------------------------------------------------- /plots/9b_short_hard_syntax.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_short_hard_syntax.pdf -------------------------------------------------------------------------------- /plots/9b_short_hard_syntax.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_short_hard_syntax.png -------------------------------------------------------------------------------- /plots/9b_short_normal_mmlu.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_short_normal_mmlu.pdf -------------------------------------------------------------------------------- /plots/9b_short_normal_mmlu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_short_normal_mmlu.png -------------------------------------------------------------------------------- /plots/9b_short_normal_semantics.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_short_normal_semantics.pdf -------------------------------------------------------------------------------- /plots/9b_short_normal_semantics.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_short_normal_semantics.png -------------------------------------------------------------------------------- /plots/9b_short_normal_syntax.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_short_normal_syntax.pdf -------------------------------------------------------------------------------- /plots/9b_short_normal_syntax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_short_normal_syntax.png -------------------------------------------------------------------------------- /plots/final_charts/gemma-2-2b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/final_charts/gemma-2-2b.png -------------------------------------------------------------------------------- /plots/final_charts/gemma-2-9b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/final_charts/gemma-2-9b.png -------------------------------------------------------------------------------- /plots/final_charts/llama-3.1-8b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/final_charts/llama-3.1-8b.png -------------------------------------------------------------------------------- /plots/final_charts/pareto_frontier.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/final_charts/pareto_frontier.png 
-------------------------------------------------------------------------------- /plots/llama_layer12_mmlu.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/llama_layer12_mmlu.pdf -------------------------------------------------------------------------------- /plots/llama_layer12_mmlu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/llama_layer12_mmlu.png -------------------------------------------------------------------------------- /plots/llama_layer12_semantics.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/llama_layer12_semantics.pdf -------------------------------------------------------------------------------- /plots/llama_layer12_semantics.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/llama_layer12_semantics.png -------------------------------------------------------------------------------- /plots/llama_layer12_syntax.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/llama_layer12_syntax.pdf -------------------------------------------------------------------------------- /plots/llama_layer12_syntax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/llama_layer12_syntax.png -------------------------------------------------------------------------------- /plots/llama_layer8_mmlu.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/llama_layer8_mmlu.pdf -------------------------------------------------------------------------------- /plots/llama_layer8_mmlu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/llama_layer8_mmlu.png -------------------------------------------------------------------------------- /plots/llama_layer8_semantics.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/llama_layer8_semantics.pdf -------------------------------------------------------------------------------- /plots/llama_layer8_semantics.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/llama_layer8_semantics.png -------------------------------------------------------------------------------- /plots/llama_layer8_syntax.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/llama_layer8_syntax.pdf 
-------------------------------------------------------------------------------- /plots/llama_layer8_syntax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/llama_layer8_syntax.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [tool.hatch.build.targets.wheel] 6 | packages = ["."] 7 | 8 | [project] 9 | name = "sieve" 10 | version = "0.1.0" 11 | requires-python = ">=3.10" 12 | dependencies = [ 13 | "torch>=2.4.1,<2.5.0", 14 | "transformers>=4.45.2", 15 | "sae_lens>=3.23.0", 16 | "openai>=1.52.1", 17 | "ipykernel", 18 | "lm_eval", 19 | "seaborn" 20 | ] 21 | 22 | [tool.pyright] 23 | typeCheckingMode = "standard" 24 | -------------------------------------------------------------------------------- /sae/kernels.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copied from https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/kernels.py 3 | """ 4 | 5 | import torch 6 | import triton 7 | import triton.language as tl 8 | 9 | 10 | def triton_sparse_transpose_dense_matmul( 11 | sparse_indices: torch.Tensor, 12 | sparse_values: torch.Tensor, 13 | dense: torch.Tensor, 14 | N: int, 15 | BLOCK_SIZE_AK=128, 16 | ) -> torch.Tensor: 17 | """ 18 | calculates sparse.T @ dense (i.e reducing along the collated dimension of sparse) 19 | dense must be contiguous along dim 0 (in other words, dense.T is contiguous) 20 | 21 | sparse_indices is shape (A, k) 22 | sparse_values is shape (A, k) 23 | dense is shape (A, B) 24 | 25 | output is shape (N, B) 26 | """ 27 | 28 | assert sparse_indices.shape == sparse_values.shape 29 | assert sparse_indices.is_contiguous() 30 | assert sparse_values.is_contiguous() 31 | assert dense.is_contiguous() # contiguous along B 32 | 33 | K = sparse_indices.shape[1] 34 | A = dense.shape[0] 35 | assert sparse_indices.shape[0] == A 36 | 37 | # COO-format and sorted 38 | sorted_indices = sparse_indices.view(-1).sort() 39 | coo_indices = torch.stack( 40 | [ 41 | torch.arange(A, device=sparse_indices.device).repeat_interleave(K)[ 42 | sorted_indices.indices 43 | ], 44 | sorted_indices.values, 45 | ] 46 | ) # shape (2, A * K) 47 | coo_values = sparse_values.view(-1)[sorted_indices.indices] # shape (A * K,) 48 | return triton_coo_sparse_dense_matmul( 49 | coo_indices, coo_values, dense, N, BLOCK_SIZE_AK 50 | ) 51 | 52 | 53 | def triton_coo_sparse_dense_matmul( 54 | coo_indices: torch.Tensor, 55 | coo_values: torch.Tensor, 56 | dense: torch.Tensor, 57 | N: int, 58 | BLOCK_SIZE_AK=128, 59 | ) -> torch.Tensor: 60 | AK = coo_indices.shape[1] 61 | B = dense.shape[1] 62 | 63 | out = torch.zeros(N, B, device=dense.device, dtype=coo_values.dtype) 64 | 65 | def grid(META): 66 | return triton.cdiv(AK, META["BLOCK_SIZE_AK"]), 1 67 | 68 | triton_sparse_transpose_dense_matmul_kernel[grid]( 69 | coo_indices, 70 | coo_values, 71 | dense, 72 | out, 73 | stride_da=dense.stride(0), 74 | stride_db=dense.stride(1), 75 | B=B, 76 | N=N, 77 | AK=AK, 78 | BLOCK_SIZE_AK=BLOCK_SIZE_AK, 79 | BLOCK_SIZE_B=triton.next_power_of_2(B), 80 | ) 81 | return out 82 | 83 | 84 | @triton.jit 85 | def triton_sparse_transpose_dense_matmul_kernel( 86 | coo_indices_ptr, 87 | coo_values_ptr, 88 | dense_ptr, 89 | 
out_ptr, 90 | stride_da, 91 | stride_db, 92 | B, 93 | N, 94 | AK, 95 | BLOCK_SIZE_AK: tl.constexpr, 96 | BLOCK_SIZE_B: tl.constexpr, 97 | ): 98 | """ 99 | coo_indices is shape (2, AK) 100 | coo_values is shape (AK,) 101 | dense is shape (A, B), contiguous along B 102 | out is shape (N, B) 103 | """ 104 | 105 | pid_ak = tl.program_id(0) 106 | pid_b = tl.program_id(1) 107 | 108 | coo_offsets = tl.arange(0, BLOCK_SIZE_AK) 109 | b_offsets = tl.arange(0, BLOCK_SIZE_B) 110 | 111 | A_coords = tl.load( 112 | coo_indices_ptr + pid_ak * BLOCK_SIZE_AK + coo_offsets, 113 | mask=pid_ak * BLOCK_SIZE_AK + coo_offsets < AK, 114 | ) 115 | K_coords = tl.load( 116 | coo_indices_ptr + pid_ak * BLOCK_SIZE_AK + coo_offsets + AK, 117 | mask=pid_ak * BLOCK_SIZE_AK + coo_offsets < AK, 118 | ) 119 | values = tl.load( 120 | coo_values_ptr + pid_ak * BLOCK_SIZE_AK + coo_offsets, 121 | mask=pid_ak * BLOCK_SIZE_AK + coo_offsets < AK, 122 | ) 123 | 124 | last_k = tl.min(K_coords) 125 | accum = tl.zeros((BLOCK_SIZE_B,), dtype=tl.float32) 126 | 127 | for ind in range(BLOCK_SIZE_AK): 128 | if ind + pid_ak * BLOCK_SIZE_AK < AK: 129 | # workaround to do A_coords[ind] 130 | a = tl.sum( 131 | tl.where( 132 | tl.arange(0, BLOCK_SIZE_AK) == ind, 133 | A_coords, 134 | tl.zeros((BLOCK_SIZE_AK,), dtype=tl.int64), 135 | ) 136 | ) 137 | 138 | k = tl.sum( 139 | tl.where( 140 | tl.arange(0, BLOCK_SIZE_AK) == ind, 141 | K_coords, 142 | tl.zeros((BLOCK_SIZE_AK,), dtype=tl.int64), 143 | ) 144 | ) 145 | 146 | v = tl.sum( 147 | tl.where( 148 | tl.arange(0, BLOCK_SIZE_AK) == ind, 149 | values, 150 | tl.zeros((BLOCK_SIZE_AK,), dtype=tl.float32), 151 | ) 152 | ) 153 | 154 | tl.device_assert(k < N) 155 | 156 | if k != last_k: 157 | tl.atomic_add( 158 | out_ptr + last_k * B + BLOCK_SIZE_B * pid_b + b_offsets, 159 | accum, 160 | mask=BLOCK_SIZE_B * pid_b + b_offsets < B, 161 | ) 162 | accum *= 0 163 | last_k = k 164 | 165 | if v != 0: 166 | accum += v * tl.load( 167 | dense_ptr + a * stride_da + b_offsets, mask=b_offsets < B 168 | ) 169 | 170 | tl.atomic_add( 171 | out_ptr + last_k * B + BLOCK_SIZE_B * pid_b + b_offsets, 172 | accum, 173 | mask=BLOCK_SIZE_B * pid_b + b_offsets < B, 174 | ) 175 | 176 | 177 | def triton_sparse_dense_matmul( 178 | sparse_indices: torch.Tensor, 179 | sparse_values: torch.Tensor, 180 | dense: torch.Tensor, 181 | ) -> torch.Tensor: 182 | """ 183 | calculates sparse @ dense (i.e reducing along the uncollated dimension of sparse) 184 | dense must be contiguous along dim 0 (in other words, dense.T is contiguous) 185 | 186 | sparse_indices is shape (A, k) 187 | sparse_values is shape (A, k) 188 | dense is shape (N, B) 189 | 190 | output is shape (A, B) 191 | """ 192 | N = dense.shape[0] 193 | assert sparse_indices.shape == sparse_values.shape 194 | assert sparse_indices.is_contiguous() 195 | assert sparse_values.is_contiguous() 196 | assert dense.is_contiguous() # contiguous along B 197 | 198 | A = sparse_indices.shape[0] 199 | K = sparse_indices.shape[1] 200 | B = dense.shape[1] 201 | 202 | out = torch.zeros(A, B, device=dense.device, dtype=sparse_values.dtype) 203 | 204 | triton_sparse_dense_matmul_kernel[(A,)]( 205 | sparse_indices, 206 | sparse_values, 207 | dense, 208 | out, 209 | stride_dn=dense.stride(0), 210 | stride_db=dense.stride(1), 211 | A=A, 212 | B=B, 213 | N=N, 214 | K=K, 215 | BLOCK_SIZE_K=triton.next_power_of_2(K), 216 | BLOCK_SIZE_B=triton.next_power_of_2(B), 217 | ) 218 | return out 219 | 220 | 221 | @triton.jit 222 | def triton_sparse_dense_matmul_kernel( 223 | sparse_indices_ptr, 224 | 
sparse_values_ptr, 225 | dense_ptr, 226 | out_ptr, 227 | stride_dn, 228 | stride_db, 229 | A, 230 | B, 231 | N, 232 | K, 233 | BLOCK_SIZE_K: tl.constexpr, 234 | BLOCK_SIZE_B: tl.constexpr, 235 | ): 236 | """ 237 | sparse_indices is shape (A, K) 238 | sparse_values is shape (A, K) 239 | dense is shape (N, B), contiguous along B 240 | out is shape (A, B) 241 | """ 242 | 243 | pid = tl.program_id(0) 244 | 245 | offsets_k = tl.arange(0, BLOCK_SIZE_K) 246 | sparse_indices = tl.load( 247 | sparse_indices_ptr + pid * K + offsets_k, mask=offsets_k < K 248 | ) # shape (K,) 249 | sparse_values = tl.load( 250 | sparse_values_ptr + pid * K + offsets_k, mask=offsets_k < K 251 | ) # shape (K,) 252 | 253 | accum = tl.zeros((BLOCK_SIZE_B,), dtype=tl.float32) 254 | 255 | offsets_b = tl.arange(0, BLOCK_SIZE_B) 256 | 257 | for k in range(K): 258 | # workaround to do sparse_indices[k] 259 | i = tl.sum( 260 | tl.where( 261 | tl.arange(0, BLOCK_SIZE_K) == k, 262 | sparse_indices, 263 | tl.zeros((BLOCK_SIZE_K,), dtype=tl.int64), 264 | ) 265 | ) 266 | # workaround to do sparse_values[k] 267 | v = tl.sum( 268 | tl.where( 269 | tl.arange(0, BLOCK_SIZE_K) == k, 270 | sparse_values, 271 | tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32), 272 | ) 273 | ) 274 | 275 | tl.device_assert(i < N) 276 | if v != 0: 277 | accum += v * tl.load( 278 | dense_ptr + i * stride_dn + offsets_b * stride_db, mask=offsets_b < B 279 | ) 280 | 281 | tl.store( 282 | out_ptr + pid * B + offsets_b, accum.to(sparse_values.dtype), mask=offsets_b < B 283 | ) 284 | 285 | 286 | def triton_dense_dense_sparseout_matmul( 287 | dense1: torch.Tensor, 288 | dense2: torch.Tensor, 289 | at_indices: torch.Tensor, 290 | ) -> torch.Tensor: 291 | """ 292 | dense1: shape (A, B) 293 | dense2: shape (B, N) 294 | at_indices: shape (A, K) 295 | out values: shape (A, K) 296 | calculates dense1 @ dense2 only for the indices in at_indices 297 | 298 | equivalent to (dense1 @ dense2).gather(1, at_indices) 299 | """ 300 | A, B = dense1.shape 301 | N = dense2.shape[1] 302 | assert dense2.shape[0] == B 303 | assert at_indices.shape[0] == A 304 | K = at_indices.shape[1] 305 | assert at_indices.is_contiguous() 306 | 307 | assert dense1.stride(1) == 1, "dense1 must be contiguous along B" 308 | assert dense2.stride(0) == 1, "dense2 must be contiguous along B" 309 | 310 | if K > 512: 311 | # print("WARN - using naive matmul for large K") 312 | # naive is more efficient for large K 313 | return (dense1 @ dense2).gather(1, at_indices) 314 | 315 | out = torch.zeros(A, K, device=dense1.device, dtype=dense1.dtype) 316 | 317 | # grid = lambda META: (triton.cdiv(A, META['BLOCK_SIZE_A']),) 318 | 319 | triton_dense_dense_sparseout_matmul_kernel[(A,)]( 320 | dense1, 321 | dense2, 322 | at_indices, 323 | out, 324 | stride_d1a=dense1.stride(0), 325 | stride_d1b=dense1.stride(1), 326 | stride_d2b=dense2.stride(0), 327 | stride_d2n=dense2.stride(1), 328 | A=A, 329 | B=B, 330 | N=N, 331 | K=K, 332 | BLOCK_SIZE_B=triton.next_power_of_2(B), 333 | BLOCK_SIZE_N=triton.next_power_of_2(N), 334 | BLOCK_SIZE_K=triton.next_power_of_2(K), 335 | ) 336 | 337 | return out 338 | 339 | 340 | @triton.jit 341 | def triton_dense_dense_sparseout_matmul_kernel( 342 | dense1_ptr, 343 | dense2_ptr, 344 | at_indices_ptr, 345 | out_ptr, 346 | stride_d1a, 347 | stride_d1b, 348 | stride_d2b, 349 | stride_d2n, 350 | A, 351 | B, 352 | N, 353 | K, 354 | BLOCK_SIZE_B: tl.constexpr, 355 | BLOCK_SIZE_N: tl.constexpr, 356 | BLOCK_SIZE_K: tl.constexpr, 357 | ): 358 | """ 359 | dense1: shape (A, B) 360 | dense2: shape (B, N) 361 | 
at_indices: shape (A, K) 362 | out values: shape (A, K) 363 | """ 364 | 365 | pid = tl.program_id(0) 366 | 367 | offsets_k = tl.arange(0, BLOCK_SIZE_K) 368 | at_indices = tl.load( 369 | at_indices_ptr + pid * K + offsets_k, mask=offsets_k < K 370 | ) # shape (K,) 371 | 372 | offsets_b = tl.arange(0, BLOCK_SIZE_B) 373 | dense1 = tl.load( 374 | dense1_ptr + pid * stride_d1a + offsets_b * stride_d1b, mask=offsets_b < B 375 | ) # shape (B,) 376 | 377 | accum = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32) 378 | 379 | for k in range(K): 380 | # workaround to do at_indices[b] 381 | i = tl.sum( 382 | tl.where( 383 | tl.arange(0, BLOCK_SIZE_K) == k, 384 | at_indices, 385 | tl.zeros((BLOCK_SIZE_K,), dtype=tl.int64), 386 | ) 387 | ) 388 | tl.device_assert(i < N) 389 | 390 | dense2col = tl.load( 391 | dense2_ptr + offsets_b * stride_d2b + i * stride_d2n, mask=offsets_b < B 392 | ) # shape (B,) 393 | accum += tl.where( 394 | tl.arange(0, BLOCK_SIZE_K) == k, 395 | tl.sum(dense1 * dense2col), 396 | tl.zeros((BLOCK_SIZE_K,), dtype=tl.int64), 397 | ) 398 | 399 | tl.store(out_ptr + pid * K + offsets_k, accum, mask=offsets_k < K) 400 | 401 | 402 | class TritonDecoder(torch.autograd.Function): 403 | @staticmethod 404 | def forward(ctx, sparse_indices, sparse_values, decoder_weight): 405 | ctx.save_for_backward(sparse_indices, sparse_values, decoder_weight) 406 | return triton_sparse_dense_matmul( 407 | sparse_indices, sparse_values, decoder_weight.T 408 | ) 409 | 410 | @staticmethod 411 | def backward(ctx, grad_output): 412 | sparse_indices, sparse_values, decoder_weight = ctx.saved_tensors 413 | 414 | assert grad_output.is_contiguous(), "grad_output must be contiguous" 415 | 416 | decoder_grad = triton_sparse_transpose_dense_matmul( 417 | sparse_indices, sparse_values, grad_output, N=decoder_weight.shape[1] 418 | ).T 419 | 420 | return ( 421 | None, 422 | triton_dense_dense_sparseout_matmul( 423 | grad_output, decoder_weight, sparse_indices 424 | ), 425 | # decoder is contiguous when transposed so this is a matching layout 426 | decoder_grad, 427 | None, 428 | ) 429 | -------------------------------------------------------------------------------- /sae/sae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import pickle 4 | import json 5 | import os 6 | from typing import Tuple, Union 7 | from huggingface_hub import hf_hub_download 8 | from torch import Tensor 9 | from .utils import SaeConfig 10 | 11 | """ This logic is taken from Eleuther's SAE repo. 
""" 12 | # Fallback implementation of SAE decoder 13 | def eager_decode(top_indices: Tensor, top_acts: Tensor, W_dec: Tensor): 14 | buf = top_acts.new_zeros(top_acts.shape[:-1] + (W_dec.shape[-1],)) 15 | acts = buf.scatter_(dim=-1, index=top_indices, src=top_acts) 16 | return acts @ W_dec.mT 17 | 18 | 19 | # Triton implementation of SAE decoder 20 | def triton_decode(top_indices: Tensor, top_acts: Tensor, W_dec: Tensor): 21 | return TritonDecoder.apply(top_indices, top_acts, W_dec) 22 | 23 | 24 | try: 25 | from .kernels import TritonDecoder 26 | except ImportError: 27 | decoder_impl = eager_decode 28 | print("Triton not installed, using eager implementation of SAE decoder.") 29 | else: 30 | if os.environ.get("SAE_DISABLE_TRITON") == "1": 31 | print("Triton disabled, using eager implementation of SAE decoder.") 32 | decoder_impl = eager_decode 33 | else: 34 | decoder_impl = triton_decode 35 | 36 | 37 | class TopK(nn.Module): 38 | __is_sparse__ = True 39 | """Top-k activation function""" 40 | def __init__(self, k: int): 41 | super().__init__() 42 | self.k = k 43 | 44 | def sparse_forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: 45 | x = nn.ReLU()(x) 46 | return torch.topk(x, self.k, dim=-1, sorted=False) 47 | 48 | def dense_forward(self, x: Tensor) -> Tensor: 49 | acts, indices = self.sparse_forward(x) 50 | return acts.scatter_(dim=-1, index=indices, src=acts) 51 | 52 | def forward(self, x: Tensor) -> Union[Tuple[Tensor, Tensor], Tensor]: 53 | if self.__is_sparse__: 54 | return self.sparse_forward(x) 55 | else: 56 | return self.dense_forward(x) 57 | 58 | def __repr__(self): 59 | return f"TopK(k={self.k})" 60 | 61 | class Sae(nn.Module): 62 | """Streamlined Sparse Autoencoder for inference only""" 63 | 64 | def __init__(self, config: SaeConfig): 65 | super().__init__() 66 | self.cfg = config 67 | self.d_in = config.d_model 68 | self.d_sae = config.d_model * config.expansion_factor 69 | 70 | # Core parameters 71 | self.W_enc_DF = nn.Parameter(torch.empty(self.d_in, self.d_sae)) 72 | self.b_enc_F = nn.Parameter(torch.zeros(self.d_sae)) 73 | self.W_dec_FD = nn.Parameter(torch.empty(self.d_sae, self.d_in)) 74 | self.b_dec_D = nn.Parameter(torch.zeros(self.d_in)) 75 | 76 | self.device = torch.device(config.device) 77 | self.dtype = getattr(torch, config.dtype.split(".")[1]) 78 | self.activation_fns = None 79 | self.to(self.device, self.dtype) 80 | self.eval() 81 | 82 | @torch.no_grad() 83 | def encode(self, x: torch.Tensor, activation_fn_idx: int = None) -> torch.Tensor: 84 | if activation_fn_idx is None: 85 | activation_fn_idx = self.cfg.eval_activation_idx 86 | 87 | 88 | 89 | pre_acts = torch.matmul(x, self.W_enc_DF) + self.b_enc_F 90 | acts = self.activation_fns[activation_fn_idx](pre_acts) 91 | return acts 92 | 93 | @torch.no_grad() 94 | def decode(self, features: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], eager: bool = False) -> torch.Tensor: 95 | if isinstance(features, tuple): # Sparse feats 96 | top_indices, top_acts = features 97 | if eager or top_indices.shape[-1] >= 512: # Heuristic for when triton kernel is slower than eager decoding 98 | y = eager_decode(top_indices, top_acts, self.W_dec_FD.mT) 99 | else: 100 | y = decoder_impl(top_indices, top_acts, self.W_dec_FD.mT) 101 | else: 102 | y = torch.matmul(features, self.W_dec_FD) 103 | y = y + self.b_dec_D 104 | return y 105 | 106 | @torch.no_grad() 107 | def forward(self, x: torch.Tensor, activation_fn_idx: int = None) -> torch.Tensor: 108 | if activation_fn_idx is None: 109 | activation_fn_idx = 
self.cfg.eval_activation_idx 110 | initial_shape = x.shape 111 | x = x.reshape(-1, self.d_in) 112 | f = self.encode(x, activation_fn_idx) 113 | return self.decode(f).reshape(initial_shape), f 114 | 115 | @classmethod 116 | def from_pretrained(cls, repo_id: str, cache_dir: str = None, layer_idx: int = 12, token=None) -> "Sae": 117 | if cache_dir is None: 118 | cache_dir = os.path.expanduser(f"~/.cache/huggingface/{repo_id}") 119 | os.makedirs(cache_dir, exist_ok=True) 120 | 121 | config_file = hf_hub_download(repo_id=repo_id, filename="config.json", cache_dir=cache_dir, token=token) 122 | with open(config_file, 'r') as f: 123 | config_dict = json.load(f) 124 | 125 | assert "sae_cfg" in config_dict, "config.json must contain 'sae_cfg' key" 126 | config = SaeConfig.from_dict(config_dict["sae_cfg"]) 127 | model = cls(config) 128 | 129 | weight_file = f"layer_{layer_idx}.pt" 130 | weights_path = hf_hub_download(repo_id=repo_id, filename=weight_file, cache_dir=cache_dir, token=token) 131 | 132 | with open(weights_path, "rb") as f: 133 | state_dict = pickle.load(f) 134 | 135 | model.load_state_dict(state_dict) 136 | 137 | model.activation_fns = [TopK(k=config.activation_fn_kwargs[i]["k"]) for i in range(len(config.activation_fn_kwargs))] 138 | 139 | return model 140 | 141 | def __repr__(self): 142 | return f"Sae(d_in={self.d_in}, d_sae={self.d_sae})" 143 | 144 | @property 145 | def W_enc(self): 146 | return self.W_enc_DF 147 | 148 | @property 149 | def W_dec(self): 150 | return self.W_dec_FD 151 | 152 | @property 153 | def b_enc(self): 154 | return self.b_enc_F 155 | 156 | @property 157 | def b_dec(self): 158 | return self.b_dec_D -------------------------------------------------------------------------------- /sae/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | from huggingface_hub import hf_hub_download, list_repo_files 5 | from typing import List, Optional, Dict 6 | from dataclasses import dataclass 7 | import mmap 8 | 9 | @dataclass 10 | class SaeConfig: 11 | """Configuration for SAE inference - only essential parameters""" 12 | d_model: int 13 | expansion_factor: int = 8 14 | device: str = "cuda" if torch.cuda.is_available() else "cpu" 15 | dtype: str = "torch.bfloat16" 16 | apply_activation_fn: bool = True 17 | activation_fn_names: List[str] = None 18 | activation_fn_kwargs: List[dict] = None 19 | eval_activation_idx: int = 0 20 | 21 | @classmethod 22 | def from_dict(cls, config_dict: dict) -> "SaeConfig": 23 | essential_params = { 24 | 'd_model', 'expansion_factor', 'device', 'dtype', 25 | 'apply_activation_fn', 26 | 'activation_fn_names', 'activation_fn_kwargs', 27 | 'eval_activation_idx' 28 | } 29 | filtered_dict = {k: v for k, v in config_dict.items() if k in essential_params} 30 | return cls(**filtered_dict) 31 | -------------------------------------------------------------------------------- /src/agent_eval.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Dict, Any, List 2 | from openai import OpenAI 3 | from dataclasses import dataclass 4 | import src.utils as utils 5 | import json 6 | from tqdm import tqdm 7 | import time 8 | import os 9 | 10 | @dataclass 11 | class CodeEvaluation: 12 | """Stores the evaluation results for a code snippet""" 13 | is_syntactically_valid: bool 14 | is_semantically_valid: bool 15 | uses_regex: bool 16 | executes_successfully: bool 17 | follows_prompt: bool 18 | explanation: str 19 | 
properties: Dict[str, Any] 20 | 21 | def evaluate_code_with_gpt4(code: str, prompt: str, client: OpenAI) -> Tuple[bool, bool, str, Dict[str, Any]]: 22 | """ 23 | Evaluate code using GPT-4 to determine if it successfully executes the prompt 24 | and analyze its properties. 25 | """ 26 | system_prompt = """You are an expert Python programmer evaluating code solutions. 27 | Analyze the given code and determine: 28 | 1. Whether it successfully implements the requirements from the prompt 29 | 2. Whether it would execute successfully 30 | 3. Key properties of the implementation 31 | 32 | Provide your response in JSON format with the following fields: 33 | { 34 | "executes_successfully": bool, 35 | "follows_prompt": bool, 36 | "explanation": str, 37 | "properties": { 38 | "uses_list_comprehension": bool, 39 | "uses_error_handling": bool, 40 | "is_efficient": bool, 41 | "is_readable": bool, 42 | } 43 | } 44 | """ 45 | 46 | user_message = f""" 47 | Original Prompt: 48 | {prompt} 49 | 50 | Code to Evaluate: ```python 51 | {code} ``` 52 | 53 | Evaluate this code and provide your analysis in the requested JSON format. 54 | """ 55 | 56 | try: 57 | response = client.chat.completions.create( 58 | model="gpt-4o-mini", 59 | messages=[ 60 | {"role": "system", "content": system_prompt}, 61 | {"role": "user", "content": user_message} 62 | ], 63 | temperature=0.1 64 | ) 65 | 66 | result = response.choices[0].message.content 67 | if result is None: 68 | return False, False, "No response from GPT-4", {} 69 | 70 | try: 71 | parsed_result = json.loads(result) 72 | except json.JSONDecodeError: 73 | try: 74 | parsed_result = eval(result) 75 | except Exception: 76 | return False, False, f"Failed to parse response: {result}", {} 77 | 78 | return ( 79 | parsed_result["executes_successfully"], 80 | parsed_result["follows_prompt"], 81 | parsed_result["explanation"], 82 | parsed_result["properties"] 83 | ) 84 | 85 | except Exception as e: 86 | return False, False, f"GPT-4 evaluation failed: {str(e)}", {} 87 | 88 | def evaluate_code( 89 | code: str, 90 | prompt: str, 91 | client: OpenAI 92 | ) -> CodeEvaluation: 93 | """ 94 | Comprehensive evaluation of a code snippet using both automated checks 95 | and GPT-4 analysis. 96 | """ 97 | # Run automated checks 98 | is_syntactically_valid, _ = utils.is_syntactically_valid_python(code) 99 | is_semantically_valid, _ = utils.is_semantically_valid_python(code) 100 | uses_regex = utils.check_for_re_usage(code) 101 | 102 | # Get GPT-4 evaluation 103 | executes_successfully, follows_prompt, explanation, properties = ( 104 | evaluate_code_with_gpt4(code, prompt, client) 105 | ) 106 | 107 | return CodeEvaluation( 108 | is_syntactically_valid=is_syntactically_valid, 109 | is_semantically_valid=is_semantically_valid, 110 | uses_regex=uses_regex, 111 | executes_successfully=executes_successfully, 112 | follows_prompt=follows_prompt, 113 | explanation=explanation, 114 | properties=properties 115 | ) 116 | 117 | def batch_evaluate_generations( 118 | generations: list[str], 119 | prompt: str, 120 | api_key: Optional[str] = None, 121 | batch_size: int = 10, 122 | retry_delay: float = 1.0, 123 | max_retries: int = 3 124 | ) -> list[CodeEvaluation]: 125 | """ 126 | Evaluate multiple code generations in batches with retry logic. 
127 | 128 | Args: 129 | generations: List of generated code snippets 130 | prompt: The original prompt that requested the code 131 | api_key: Optional OpenAI API key 132 | batch_size: Number of evaluations to process in parallel 133 | retry_delay: Delay between retries in seconds 134 | max_retries: Maximum number of retry attempts 135 | 136 | Returns: 137 | List of CodeEvaluation objects 138 | """ 139 | client = OpenAI( 140 | api_key=api_key, 141 | ) if api_key else OpenAI() 142 | 143 | results = [] 144 | valid_codes = [] 145 | 146 | # First extract all valid Python code 147 | for code in generations: 148 | python_code = utils.extract_python(code) 149 | if python_code is not None: 150 | valid_codes.append(python_code) 151 | 152 | # Process in batches 153 | for i, code in tqdm(enumerate(valid_codes), total=len(valid_codes), desc="Evaluating code batches"): 154 | 155 | retries = 0 156 | while retries < max_retries: 157 | try: 158 | evaluation = evaluate_code(code, prompt, client) 159 | results.append(evaluation) 160 | break 161 | except Exception as e: 162 | retries += 1 163 | if retries == max_retries: 164 | print(f"Failed to evaluate code after {max_retries} attempts: {e}") 165 | continue 166 | time.sleep(retry_delay) 167 | 168 | 169 | # Small delay between batches to avoid rate limits 170 | time.sleep(0.01) 171 | 172 | return results 173 | 174 | def summarize_evaluations(evaluations: list[CodeEvaluation]) -> Dict[str, Any]: 175 | """ 176 | Summarize the results of multiple code evaluations. 177 | 178 | Args: 179 | evaluations: List of CodeEvaluation objects 180 | 181 | Returns: 182 | Dictionary containing summary statistics 183 | """ 184 | total = len(evaluations) 185 | if total == 0: 186 | return {"error": "No evaluations to summarize"} 187 | 188 | summary = { 189 | "total_samples": total, 190 | "syntactically_valid": sum(1 for e in evaluations if e.is_syntactically_valid) / total, 191 | "semantically_valid": sum(1 for e in evaluations if e.is_semantically_valid) / total, 192 | "uses_regex": sum(1 for e in evaluations if e.uses_regex) / total, 193 | "executes_successfully": sum(1 for e in evaluations if e.executes_successfully) / total, 194 | "follows_prompt": sum(1 for e in evaluations if e.follows_prompt) / total, 195 | "property_stats": { 196 | "uses_list_comprehension": sum(1 for e in evaluations if e.properties.get("uses_list_comprehension", False)) / total, 197 | "uses_error_handling": sum(1 for e in evaluations if e.properties.get("uses_error_handling", False)) / total, 198 | "is_efficient": sum(1 for e in evaluations if e.properties.get("is_efficient", False)) / total, 199 | "is_readable": sum(1 for e in evaluations if e.properties.get("is_readable", False)) / total, 200 | }, 201 | "complexity_distribution": {} 202 | } 203 | 204 | # Count complexity distributions 205 | complexity_counts = {} 206 | for eval in evaluations: 207 | complexity = eval.properties.get("complexity", "unknown") 208 | complexity_counts[complexity] = complexity_counts.get(complexity, 0) + 1 209 | 210 | summary["complexity_distribution"] = { 211 | k: v/total for k, v in complexity_counts.items() 212 | } 213 | 214 | return summary 215 | 216 | # Example usage 217 | if __name__ == "__main__": 218 | code = """ 219 | def extract_numbers(text): 220 | import re 221 | return [int(num) for num in re.findall(r'\d+', text)] 222 | """ 223 | 224 | prompt = "Write a function that extracts all numbers from a text string using regex" 225 | api_key = os.environ.get("OPENAI_API_KEY") 226 | 227 | # Initialize client once 228 | 
client = OpenAI(api_key=api_key) 229 | 230 | # Single evaluation 231 | evaluation = evaluate_code(code, prompt, client) 232 | print(f"Evaluation results:\n{evaluation}") 233 | 234 | # Batch evaluation 235 | generations = [code, code] # Example with duplicate code 236 | evaluations = batch_evaluate_generations(generations, prompt, api_key) 237 | summary = summarize_evaluations(evaluations) 238 | print(f"\nSummary of evaluations:\n{summary}") -------------------------------------------------------------------------------- /src/caa.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer 2 | from typing import Dict, Tuple, Optional, Any, List 3 | import json 4 | import torch 5 | import einops 6 | from pathlib import Path 7 | from dataclasses import dataclass 8 | from tqdm import tqdm 9 | import src.utils as utils 10 | from src.eval_config import EvalConfig 11 | from sklearn.linear_model import LogisticRegression 12 | from sklearn.model_selection import train_test_split 13 | 14 | 15 | def load_prompts( 16 | prompts_folder: str, contrastive_prompts_filename: str, code_filename: str 17 | ) -> Tuple[Dict[str, Dict[str, str]], Dict[str, str]]: 18 | """Load contrastive prompts and code examples from JSON files. 19 | 20 | Args: 21 | prompts_folder: Directory containing prompt files 22 | contrastive_prompts_filename: Filename for contrastive prompts JSON 23 | code_filename: Filename for code examples JSON 24 | 25 | Returns: 26 | Tuple containing: 27 | - Dictionary of contrastive prompts with structure {prompt_type: {base: str, pos: str, neg: str}} 28 | - Dictionary of code examples with structure {id: code_str} 29 | 30 | Raises: 31 | FileNotFoundError: If either JSON file is not found 32 | JSONDecodeError: If either file contains invalid JSON 33 | """ 34 | prompts_path = Path(prompts_folder) 35 | 36 | with open(prompts_path / contrastive_prompts_filename) as f: 37 | contrastive_prompts = json.load(f) 38 | 39 | with open(prompts_path / code_filename) as f: 40 | code = json.load(f) 41 | 42 | return contrastive_prompts, code 43 | 44 | 45 | def format_contrastive_prompt( 46 | contrastive_prompts: Dict[str, Dict[str, str]], 47 | code_block: str, 48 | prompt_type: str, 49 | prompt_polarity: str, 50 | tokenizer: PreTrainedTokenizer, 51 | ) -> str: 52 | """Format a contrastive prompt with code for model input. 
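
    Example (illustrative sketch; assumes the prompt dictionaries were loaded via
    load_prompts and that the "regex" prompt type exists, as in contrastive_prompts.json):

        prompt = format_contrastive_prompt(
            contrastive_prompts, code_block, "regex", "pos", tokenizer
        )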
53 | 54 | Args: 55 | contrastive_prompts: Dictionary of prompt templates 56 | code_block: Code snippet to insert into prompt 57 | prompt_type: Type of prompt to use (must exist in contrastive_prompts) 58 | prompt_polarity: Either "pos" or "neg" for positive/negative prompt 59 | tokenizer: Tokenizer for formatting chat template 60 | 61 | Returns: 62 | Formatted prompt string ready for model input 63 | 64 | Raises: 65 | AssertionError: If prompt_type or prompt_polarity is invalid 66 | """ 67 | assert prompt_type in contrastive_prompts, f"Prompt type {prompt_type} not found" 68 | assert prompt_polarity in ["pos", "neg"], f"Invalid prompt polarity: {prompt_polarity}" 69 | 70 | prompt = f"{contrastive_prompts[prompt_type]['base']}\n{contrastive_prompts[prompt_type][prompt_polarity]}\n" 71 | prompt += "```python\n" + code_block + "```\n" 72 | 73 | chat = [{"role": "user", "content": prompt}] 74 | # Handle different tokenizer types 75 | if hasattr(tokenizer, "apply_chat_template"): 76 | result = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) 77 | return result if isinstance(result, str) else result[0] 78 | return f"{prompt}\n" 79 | 80 | 81 | def get_layer_activations( 82 | model: AutoModelForCausalLM, target_layer: int, inputs: torch.Tensor 83 | ) -> torch.Tensor: 84 | """Extract activations from a specific transformer layer. 85 | 86 | Args: 87 | model: The causal language model 88 | target_layer: Index of layer to extract from 89 | inputs: Input token ids tensor of shape (batch_size, seq_len) 90 | 91 | Returns: 92 | Tensor of activations with shape (batch_size, seq_len, hidden_dim) 93 | 94 | Raises: 95 | AttributeError: If model architecture is not supported 96 | RuntimeError: If no activations were captured 97 | """ 98 | acts_BLD: Optional[torch.Tensor] = None 99 | 100 | def gather_target_act_hook(module, inputs, outputs): 101 | nonlocal acts_BLD 102 | acts_BLD = outputs[0] 103 | return outputs 104 | 105 | # Support different model architectures 106 | if hasattr(model, "transformer"): 107 | layers = model.transformer.h 108 | elif hasattr(model, "model"): 109 | layers = model.model.layers 110 | else: 111 | raise AttributeError("Model architecture not supported") 112 | 113 | handle = layers[target_layer].register_forward_hook(gather_target_act_hook) 114 | with torch.no_grad(): 115 | _ = model(inputs.to(model.device)) 116 | handle.remove() 117 | 118 | if acts_BLD is None: 119 | raise RuntimeError("No activations were captured") 120 | return acts_BLD 121 | 122 | 123 | def get_layer_activations_with_generation( 124 | model: AutoModelForCausalLM, 125 | target_layer: int, 126 | inputs: torch.Tensor, 127 | max_new_tokens: int, 128 | temperature: float = 1.0, 129 | ) -> Tuple[torch.Tensor, torch.Tensor]: 130 | """Get layer activations during text generation. 
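
    Example (illustrative sketch; target_layer=12 and max_new_tokens=100 are
    arbitrary placeholder values, not values used elsewhere in this repo):

        acts_BLD, generated_tokens = get_layer_activations_with_generation(
            model, target_layer=12, inputs=input_ids, max_new_tokens=100
        )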
131 | 132 | Args: 133 | model: The causal language model 134 | target_layer: Index of layer to extract from 135 | inputs: Input token ids tensor 136 | **generation_kwargs: Arguments passed to model.generate() 137 | 138 | Returns: 139 | Tuple containing: 140 | - Tensor of activations during generation 141 | - Generated token ids tensor 142 | """ 143 | acts_BLD: List[torch.Tensor] = [] 144 | 145 | def gather_target_act_hook(module, inputs, outputs): 146 | nonlocal acts_BLD 147 | acts_BLD.append(outputs[0]) 148 | return outputs 149 | 150 | handle = model.model.layers[target_layer].register_forward_hook(gather_target_act_hook) 151 | with torch.no_grad(): 152 | tokens = model.generate( 153 | inputs.to(model.device), 154 | max_new_tokens=max_new_tokens, 155 | temperature=temperature, 156 | do_sample=True, 157 | ) 158 | handle.remove() 159 | 160 | return torch.cat(acts_BLD, dim=1), tokens 161 | 162 | 163 | @torch.no_grad() 164 | def calculate_probe_vector( 165 | prompts_folder: str, 166 | contrastive_prompts_filename: str, 167 | code_filename: str, 168 | model: AutoModelForCausalLM, 169 | tokenizer: PreTrainedTokenizer, 170 | prompt_type: str, 171 | layer: int, 172 | n_samples: int = 50, 173 | max_new_tokens: int = 400, 174 | temperature: float = 1.0, 175 | ) -> Tuple[torch.Tensor, torch.Tensor]: 176 | """Calculate the steering vector for a given layer. 177 | Follows this general methodology: https://arxiv.org/abs/2410.12877""" 178 | contrastive_prompts, code = load_prompts( 179 | prompts_folder, contrastive_prompts_filename, code_filename 180 | ) 181 | 182 | train_pos_activations = [] 183 | train_neg_activations = [] 184 | test_pos_activations = [] 185 | test_neg_activations = [] 186 | n_train = int(n_samples * 0.8) 187 | 188 | for code_id, code_block in tqdm( 189 | code.items(), total=len(code), desc="Gathering activations for probe training" 190 | ): 191 | pos_prompt = format_contrastive_prompt( 192 | contrastive_prompts, code_block, prompt_type, "pos", tokenizer 193 | ) 194 | neg_prompt = format_contrastive_prompt( 195 | contrastive_prompts, code_block, prompt_type, "neg", tokenizer 196 | ) 197 | pos_tokens = tokenizer( 198 | [pos_prompt] * n_samples, add_special_tokens=False, return_tensors="pt" 199 | )["input_ids"].to(model.device) 200 | neg_tokens = tokenizer( 201 | [neg_prompt] * n_samples, add_special_tokens=False, return_tensors="pt" 202 | )["input_ids"].to(model.device) 203 | 204 | pos_activations, new_pos_tokens = get_layer_activations_with_generation( 205 | model, layer, pos_tokens, max_new_tokens=max_new_tokens, temperature=temperature 206 | ) 207 | neg_activations, new_neg_tokens = get_layer_activations_with_generation( 208 | model, layer, neg_tokens, max_new_tokens=max_new_tokens, temperature=temperature 209 | ) 210 | 211 | # Get generated tokens only. 
: -1 because we don't have activations for the final generated token 212 | new_pos_tokens = new_pos_tokens[:, pos_tokens.shape[1] : -1] 213 | new_neg_tokens = new_neg_tokens[:, neg_tokens.shape[1] : -1] 214 | 215 | # Get activations for generated tokens only 216 | pos_activations = pos_activations[:, pos_tokens.shape[1] :] 217 | neg_activations = neg_activations[:, neg_tokens.shape[1] :] 218 | # Split activations into train and test based on n_samples first 219 | train_pos_activations_batch = pos_activations[:n_train] 220 | test_pos_activations_batch = pos_activations[n_train:] 221 | train_neg_activations_batch = neg_activations[:n_train] 222 | test_neg_activations_batch = neg_activations[n_train:] 223 | 224 | # Create masks for non-padding and non-eos tokens 225 | pos_mask = (new_pos_tokens != tokenizer.pad_token_id) & ( 226 | new_pos_tokens != tokenizer.eos_token_id 227 | ) 228 | neg_mask = (new_neg_tokens != tokenizer.pad_token_id) & ( 229 | new_neg_tokens != tokenizer.eos_token_id 230 | ) 231 | 232 | # Split masks into train and test 233 | train_pos_mask = pos_mask[:n_train] 234 | test_pos_mask = pos_mask[n_train:] 235 | train_neg_mask = neg_mask[:n_train] 236 | test_neg_mask = neg_mask[n_train:] 237 | 238 | # Reshape activations and masks to match 239 | train_pos_activations_batch = einops.rearrange( 240 | train_pos_activations_batch, "B L D -> (B L) D" 241 | ) 242 | test_pos_activations_batch = einops.rearrange( 243 | test_pos_activations_batch, "B L D -> (B L) D" 244 | ) 245 | train_neg_activations_batch = einops.rearrange( 246 | train_neg_activations_batch, "B L D -> (B L) D" 247 | ) 248 | test_neg_activations_batch = einops.rearrange( 249 | test_neg_activations_batch, "B L D -> (B L) D" 250 | ) 251 | 252 | train_pos_mask = einops.rearrange(train_pos_mask, "B L -> (B L)") 253 | test_pos_mask = einops.rearrange(test_pos_mask, "B L -> (B L)") 254 | train_neg_mask = einops.rearrange(train_neg_mask, "B L -> (B L)") 255 | test_neg_mask = einops.rearrange(test_neg_mask, "B L -> (B L)") 256 | 257 | # Filter out padding and eos tokens 258 | train_pos_activations.append(train_pos_activations_batch[train_pos_mask]) 259 | test_pos_activations.append(test_pos_activations_batch[test_pos_mask]) 260 | train_neg_activations.append(train_neg_activations_batch[train_neg_mask]) 261 | test_neg_activations.append(test_neg_activations_batch[test_neg_mask]) 262 | 263 | # Combine all activations 264 | X_train = ( 265 | torch.cat( 266 | [torch.cat(train_pos_activations, dim=0), torch.cat(train_neg_activations, dim=0)], 267 | dim=0, 268 | ) 269 | .detach() 270 | .float() 271 | .cpu() 272 | .numpy() 273 | ) 274 | y_train = ( 275 | torch.cat( 276 | [ 277 | torch.ones(sum(len(x) for x in train_pos_activations)), 278 | torch.zeros(sum(len(x) for x in train_neg_activations)), 279 | ] 280 | ) 281 | .detach() 282 | .float() 283 | .cpu() 284 | .numpy() 285 | ) 286 | 287 | X_test = ( 288 | torch.cat( 289 | [torch.cat(test_pos_activations, dim=0), torch.cat(test_neg_activations, dim=0)], dim=0 290 | ) 291 | .detach() 292 | .float() 293 | .cpu() 294 | .numpy() 295 | ) 296 | y_test = ( 297 | torch.cat( 298 | [ 299 | torch.ones(sum(len(x) for x in test_pos_activations)), 300 | torch.zeros(sum(len(x) for x in test_neg_activations)), 301 | ] 302 | ) 303 | .detach() 304 | .float() 305 | .cpu() 306 | .numpy() 307 | ) 308 | # Fit model and get probe vector 309 | linreg_model = LogisticRegression().fit(X_train, y_train) 310 | probe_vector = torch.tensor(linreg_model.coef_[0]).to(device=model.device, dtype=model.dtype) 311 | 
probe_bias = torch.tensor(linreg_model.intercept_[0]).to(device=model.device, dtype=model.dtype) 312 | 313 | # Report test accuracy 314 | test_acc = linreg_model.score(X_test, y_test) 315 | print(f"Test accuracy on {len(X_test)} points: {test_acc:.3f}") 316 | train_acc = linreg_model.score(X_train, y_train) 317 | print(f"Train accuracy on {len(X_train)} points: {train_acc:.3f}") 318 | 319 | return probe_vector, probe_bias 320 | 321 | 322 | @torch.no_grad() 323 | def calculate_steering_vector( 324 | prompts_folder: str, 325 | contrastive_prompts_filename: str, 326 | code_filename: str, 327 | model: AutoModelForCausalLM, 328 | tokenizer: PreTrainedTokenizer, 329 | prompt_type: str, 330 | layer: int, 331 | ) -> torch.Tensor: 332 | """Calculate the steering vector for a given layer. 333 | Follows this general methodology: https://arxiv.org/abs/2410.12877""" 334 | contrastive_prompts, code = load_prompts( 335 | prompts_folder, contrastive_prompts_filename, code_filename 336 | ) 337 | 338 | all_pos_activations = [] 339 | all_neg_activations = [] 340 | 341 | for code_id, code_block in code.items(): 342 | pos_prompt = format_contrastive_prompt( 343 | contrastive_prompts, code_block, prompt_type, "pos", tokenizer 344 | ) 345 | neg_prompt = format_contrastive_prompt( 346 | contrastive_prompts, code_block, prompt_type, "neg", tokenizer 347 | ) 348 | pos_tokens = tokenizer(pos_prompt, add_special_tokens=False, return_tensors="pt")[ 349 | "input_ids" 350 | ].to(model.device) 351 | neg_tokens = tokenizer(neg_prompt, add_special_tokens=False, return_tensors="pt")[ 352 | "input_ids" 353 | ].to(model.device) 354 | 355 | pos_activations = get_layer_activations(model, layer, pos_tokens) 356 | neg_activations = get_layer_activations(model, layer, neg_tokens) 357 | pos_activations = get_layer_activations(model, layer, pos_tokens)[:, -1, :].squeeze() 358 | neg_activations = get_layer_activations(model, layer, neg_tokens)[:, -1, :].squeeze() 359 | 360 | all_pos_activations.append(pos_activations) 361 | all_neg_activations.append(neg_activations) 362 | 363 | pos_activations = torch.stack(all_pos_activations).mean(dim=0) 364 | neg_activations = torch.stack(all_neg_activations).mean(dim=0) 365 | 366 | return pos_activations - neg_activations 367 | 368 | 369 | def get_threshold( 370 | config: EvalConfig, 371 | model_params: dict, 372 | wrapper, 373 | steering_vector_D: torch.Tensor, 374 | encoder_vector_D: torch.Tensor, 375 | encoder_threshold: float, 376 | ) -> float: 377 | base_prompt = utils.load_prompt_files(config) 378 | 379 | with open(f"{config.prompt_folder}/{config.code_filename}", "r") as f: 380 | code_blocks = json.load(f) 381 | 382 | average_threshold = 0 383 | 384 | for code_block_key, single_code_block in code_blocks.items(): 385 | prompt = base_prompt.replace("{code}", single_code_block) 386 | formatted_prompt = utils.format_llm_prompt(prompt, wrapper.tokenizer) 387 | formatted_prompt += config.prefill 388 | 389 | tokens = wrapper.tokenizer(formatted_prompt, add_special_tokens=False, return_tensors="pt")[ 390 | "input_ids" 391 | ].to(wrapper.model.device) 392 | 393 | resid_BLD = get_layer_activations(wrapper.model, model_params["targ_layer"], tokens) 394 | 395 | feature_acts_BL = torch.einsum("BLD,D->BL", resid_BLD, encoder_vector_D) 396 | above_threshold = feature_acts_BL > encoder_threshold 397 | k = above_threshold.sum().item() 398 | 399 | steering_vector_acts_BL = torch.einsum("BLD, D->BL", resid_BLD, steering_vector_D) 400 | 401 | topk_values = torch.topk(steering_vector_acts_BL.flatten(), k, 
largest=True)[0] 402 | threshold = topk_values[-1].item() 403 | average_threshold += threshold 404 | 405 | average_threshold /= len(code_blocks) 406 | return average_threshold 407 | -------------------------------------------------------------------------------- /src/count_activations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import einops 3 | from transformers import AutoTokenizer, AutoModelForCausalLM 4 | from dataclasses import asdict 5 | from tqdm import tqdm 6 | from typing import Callable, Dict, List, Tuple, Union, Any 7 | import json 8 | import time 9 | import sys 10 | import os 11 | import logging 12 | from jaxtyping import Float 13 | from torch import Tensor 14 | 15 | logging.basicConfig(level=logging.INFO) 16 | 17 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 18 | 19 | 20 | from src.wrapper import InterventionWrapper 21 | from src.eval_config import EvalConfig, InterventionType 22 | import src.utils as utils 23 | import src.caa as caa 24 | 25 | 26 | def get_feature_acts_with_generations( 27 | model: AutoModelForCausalLM, 28 | tokenizer: AutoTokenizer, 29 | code_id: str, 30 | target_layer: int, 31 | encoder_vectors_dict: Dict[str, torch.Tensor], 32 | inputs: torch.Tensor, 33 | max_new_tokens: int, 34 | temperature: float = 1.0, 35 | ) -> Dict[str, torch.Tensor]: 36 | """Track feature activations during model generation. 37 | 38 | This function monitors how strongly the model activates specific features 39 | during text generation, accounting for various thresholds and biases. 40 | 41 | Args: 42 | model: The language model to analyze 43 | tokenizer: Tokenizer for processing text 44 | code_id: Identifier for the code block 45 | target_layer: Which transformer layer to monitor 46 | encoder_vectors_dict: Dict mapping intervention types to feature directions 47 | inputs: Input token ids 48 | max_new_tokens: Maximum new tokens to generate 49 | temperature: Sampling temperature 50 | 51 | Returns: 52 | Dict mapping intervention types to feature activation tensors 53 | """ 54 | 55 | if "question" in code_id or "without_regex" in code_id: 56 | filter_regex = True 57 | else: 58 | filter_regex = False 59 | 60 | acts_BLD, tokens_BL = caa.get_layer_activations_with_generation( 61 | model, target_layer, inputs, max_new_tokens, temperature 62 | ) 63 | 64 | for i in range(tokens_BL.size(0)): 65 | single_prompt = tokens_BL[i] 66 | if not filter_regex: 67 | break 68 | decoded_prompt = tokenizer.decode(single_prompt) 69 | 70 | if ( 71 | "regex" in decoded_prompt 72 | or "regular expression" in decoded_prompt 73 | or utils.check_for_re_usage(decoded_prompt) 74 | ): 75 | print(f"Skipping regex prompt: {decoded_prompt}") 76 | tokens_BL[i, :] = torch.tensor([tokenizer.pad_token_id]) 77 | 78 | tokens_BL = tokens_BL[:, :-1] # There are no activations for the last generated token 79 | tokens_L = tokens_BL.flatten() 80 | 81 | retain_mask = ( 82 | (tokens_L != tokenizer.pad_token_id) 83 | & (tokens_L != tokenizer.eos_token_id) 84 | & (tokens_L != tokenizer.bos_token_id) 85 | ) 86 | 87 | # tokens_BL = tokens_BL[:, inputs.size(1) :] # Remove input tokens 88 | # acts_BLD = acts_BLD[:, inputs.size(1) :, :] 89 | 90 | feature_acts_dict = {} 91 | for intervention_type, encoder_vector_D in encoder_vectors_dict.items(): 92 | feature_acts_BL = torch.einsum("BLD,D->BL", acts_BLD, encoder_vector_D.to(acts_BLD.device)) 93 | feature_acts_L = feature_acts_BL.flatten() 94 | 95 | feature_acts_dict[intervention_type] = 
feature_acts_L[retain_mask] 96 | 97 | return feature_acts_dict 98 | 99 | 100 | def test_single_prompt( 101 | wrapper: InterventionWrapper, 102 | base_prompt: str, 103 | code_id: str, 104 | code_example: str, 105 | config: EvalConfig, 106 | model_params: dict, 107 | ) -> Dict[str, torch.Tensor]: 108 | """Test activation patterns for a single prompt. 109 | 110 | Args: 111 | wrapper: Model wrapper instance 112 | base_prompt: Base prompt template 113 | code_example: Code example to test 114 | config: Evaluation configuration 115 | model_params: Model-specific parameters 116 | 117 | Returns: 118 | Dictionary mapping intervention types to activation tensors 119 | """ 120 | if "question" in code_id: 121 | prompt = code_example 122 | else: 123 | prompt = base_prompt.replace("{code}", code_example) 124 | 125 | formatted_prompt = utils.format_llm_prompt(prompt, wrapper.tokenizer) 126 | formatted_prompt += config.prefill 127 | 128 | batched_prompts = [formatted_prompt] * config.batch_size 129 | 130 | num_batches = config.total_generations // config.batch_size 131 | 132 | input_tokens = wrapper.tokenizer( 133 | batched_prompts, add_special_tokens=False, return_tensors="pt" 134 | )["input_ids"].to(wrapper.model.device) 135 | 136 | # Get encoder vectors for all intervention types 137 | encoder_vectors_dict = {} 138 | for intervention_type in config.intervention_types: 139 | if intervention_type == InterventionType.PROBE_SAE.value: 140 | encoder_vectors_dict[intervention_type] = wrapper.probe_vector 141 | print(f"Probe bias: {wrapper.probe_bias}") 142 | elif intervention_type == InterventionType.CONDITIONAL_PER_TOKEN.value: 143 | encoder_vector = wrapper.sae.W_enc[:, [model_params["feature_idx"]]].squeeze() 144 | encoder_vectors_dict[intervention_type] = encoder_vector 145 | bias = wrapper.sae.b_enc[model_params["feature_idx"]] 146 | # threshold = wrapper.sae.threshold[model_params["feature_idx"]] 147 | # print(f"Threshold: {threshold}, Bias: {bias}") 148 | elif intervention_type == InterventionType.CONDITIONAL_STEERING_VECTOR.value: 149 | encoder_vectors_dict[intervention_type] = wrapper.caa_steering_vector 150 | else: 151 | raise ValueError(f"Invalid intervention type: {intervention_type}") 152 | 153 | feature_acts_by_type = {itype: [] for itype in config.intervention_types} 154 | 155 | for _ in tqdm(range(num_batches), desc="Generating responses"): 156 | feature_acts_dict = get_feature_acts_with_generations( 157 | wrapper.model, 158 | wrapper.tokenizer, 159 | code_id, 160 | model_params["targ_layer"], 161 | encoder_vectors_dict, 162 | input_tokens, 163 | max_new_tokens=config.max_new_tokens, 164 | ) 165 | for itype in config.intervention_types: 166 | feature_acts_by_type[itype].append(feature_acts_dict[itype]) 167 | 168 | return {itype: torch.cat(acts_list, dim=0) for itype, acts_list in feature_acts_by_type.items()} 169 | 170 | 171 | def count_classifier_activations() -> dict: 172 | """Run comprehensive activation analysis across multiple code examples. 173 | 174 | This function: 175 | 1. Sets up the model and configuration 176 | 2. Loads necessary components (SAE, prompts) 177 | 3. Runs activation analysis for different intervention types 178 | 4. 
Saves results periodically 179 | 180 | Returns: 181 | Dictionary containing: 182 | - Configuration settings 183 | - Activation measurements per intervention type 184 | - Results for each code example 185 | 186 | Note: 187 | Results are saved to disk after each code block to prevent data loss 188 | """ 189 | config = EvalConfig() 190 | 191 | config.intervention_types = [ 192 | InterventionType.CONDITIONAL_PER_TOKEN.value, 193 | InterventionType.PROBE_SAE.value, 194 | InterventionType.CONDITIONAL_STEERING_VECTOR.value, 195 | ] 196 | 197 | config.save_path = "activation_values.pt" 198 | config.prompt_filename = "prompt_no_regex.txt" 199 | 200 | results = {"config": asdict(config)} 201 | 202 | # Setup 203 | device = "cuda" if torch.cuda.is_available() else "cpu" 204 | model_params = utils.get_model_params(config.model_name) 205 | 206 | # Initialize wrapper 207 | wrapper = InterventionWrapper(model_name=config.model_name, device=device, dtype=torch.bfloat16) 208 | 209 | # Load SAE 210 | wrapper.load_sae( 211 | release=model_params["sae_release"], 212 | sae_id=model_params["sae_id"], 213 | layer_idx=model_params["targ_layer"], 214 | ) 215 | 216 | # Load and format prompt 217 | base_prompt = utils.load_prompt_files(config) 218 | # initialize steering vectors 219 | for intervention_type in config.intervention_types: 220 | _ = wrapper.get_hook(intervention_type, model_params, 1, config) 221 | 222 | with open(f"{config.prompt_folder}/activation_counting_code.json", "r") as f: 223 | code_blocks = json.load(f) 224 | 225 | print(f"Evaluating {len(code_blocks)} code blocks") 226 | print(f"Evaluating the following interventions: {config.intervention_types}") 227 | 228 | for code_block_key, single_code_block in code_blocks.items(): 229 | activations_by_type = test_single_prompt( 230 | wrapper, 231 | base_prompt, 232 | code_block_key, 233 | single_code_block, 234 | config, 235 | model_params, 236 | ) 237 | 238 | for intervention_type in config.intervention_types: 239 | if intervention_type not in results: 240 | results[intervention_type] = {} 241 | results[intervention_type][code_block_key] = activations_by_type[intervention_type] 242 | 243 | torch.save(results, config.save_path) 244 | 245 | return results 246 | 247 | 248 | if __name__ == "__main__": 249 | """ 250 | Main entry point for activation analysis. 
251 | 252 | Environment variables: 253 | PYTORCH_CUDA_ALLOC_CONF: Set to "expandable_segments:True" 254 | """ 255 | import os 256 | 257 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" 258 | 259 | torch.set_grad_enabled(False) 260 | 261 | start_time = time.time() 262 | run_results = count_classifier_activations() 263 | print(f"Total time: {time.time() - start_time:.2f} seconds") 264 | -------------------------------------------------------------------------------- /src/eval_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from enum import Enum 3 | 4 | 5 | class InterventionType(Enum): 6 | CONSTANT_SAE = "constant_sae" 7 | CONSTANT_STEERING_VECTOR = "constant_steering_vector" 8 | CONDITIONAL_PER_INPUT = "conditional_per_input" 9 | CONDITIONAL_PER_TOKEN = "conditional_per_token" 10 | CONDITIONAL_STEERING_VECTOR = "conditional_steering_vector" 11 | CLAMPING = "clamping" 12 | CONDITIONAL_CLAMPING = "conditional_clamping" 13 | PROBE_STEERING_VECTOR = "probe_steering_vector" 14 | PROBE_SAE = "probe_sae" 15 | PROBE_SAE_CLAMPING = "probe_sae_clamping" 16 | PROBE_STEERING_VECTOR_CLAMPING = "probe_steering_vector_clamping" 17 | SAE_STEERING_VECTOR = "sae_steering_vector" 18 | 19 | 20 | @dataclass 21 | class EvalConfig: 22 | random_seed: int = 42 23 | 24 | # Enum isn't serializable to JSON, so we use the value attribute 25 | intervention_types: list[str] = field( 26 | default_factory=lambda: [ 27 | InterventionType.CLAMPING.value, 28 | InterventionType.CONDITIONAL_CLAMPING.value, 29 | InterventionType.CONSTANT_SAE.value, 30 | InterventionType.CONSTANT_STEERING_VECTOR.value, 31 | # InterventionType.CONDITIONAL_PER_INPUT.value, 32 | InterventionType.CONDITIONAL_PER_TOKEN.value, 33 | InterventionType.CONDITIONAL_STEERING_VECTOR.value, 34 | InterventionType.SAE_STEERING_VECTOR.value, 35 | InterventionType.PROBE_STEERING_VECTOR.value, 36 | InterventionType.PROBE_SAE.value, 37 | ] 38 | ) 39 | 40 | model_name: str = "meta-llama/Llama-3.1-8B-Instruct" # "google/gemma-2-9b-it" 41 | # scales: list[int] = field(default_factory=lambda: [-10, -20, -40, -80, -160]) 42 | scales: list[int] = field(default_factory=lambda: [ -2, -3, -4, -5, -6, -7, -8, -10, -20, -40]) 43 | batch_size: int = 20 44 | total_generations: int = (200 // batch_size) * batch_size 45 | max_new_tokens: int = ( 46 | 400 # This needs to be high enough that we reach the end of the code block 47 | ) 48 | 49 | prompt_filename: str = "prompt.txt" # prompt.txt 50 | docs_filename: str = "pytest_docs.txt" 51 | code_filename: str = "code.json" 52 | contrastive_prompts_filename: str = "contrastive_prompts.json" 53 | probe_prompts_filename: str = "probe_prompts.json" 54 | 55 | prefill: str = "```python\n" 56 | prompt_folder: str = "src/prompts" 57 | prompt_type: str = "regex" 58 | 59 | encoder_threshold_bias: float = 0.0 60 | steering_vector_threshold_bias: float = -150.0 61 | 62 | use_llm_judge: bool = False # True 63 | 64 | save_path: str = "gemma.json" 65 | -------------------------------------------------------------------------------- /src/prompts/activation_counting_code.json: -------------------------------------------------------------------------------- 1 | {"not_regex_1": "def calculate_average(numbers: list) -> float:\n \"\"\"Calculate and return the average of a list of numbers.\"\"\"\n if not numbers:\n return 0.0\n return sum(numbers) / len(numbers)\n\n", "regex_1": "import re\n\ndef extract_emails(text: str) -> tuple:\n emails = 
re.findall(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}', text)\n return text, emails\n\n", "question_1": "How do I optimize the performance of a Python program that processes large CSV files?\n", "not_regex_2": "\n\ndef factorial(n: int) -> int:\n \"\"\"Calculate the factorial of a number.\"\"\"\n if n < 0:\n raise ValueError(\"Factorial is not defined for negative numbers\")\n result = 1\n for i in range(1, n + 1):\n result *= i\n return result\n\n", "regex_2": "\n\nimport re\n\ndef highlight_longest_word(text: str) -> str:\n # Match and highlight the longest word\n longest_word = max(re.findall(r'\\b\\w+\\b', text), key=len, default=\"\")\n if longest_word:\n text = re.sub(rf'\\b{longest_word}\\b', f'{longest_word}', text)\n return text\n\n\n", "question_2": "\nCan you explain the difference between a correlation and a causation in simple terms?\n", "not_regex_3": "\n\ndef flatten_list(nested_list: list) -> list:\n \"\"\"Flatten a nested list into a single list.\"\"\"\n flattened = []\n for item in nested_list:\n if isinstance(item, list):\n flattened.extend(flatten_list(item))\n else:\n flattened.append(item)\n return flattened\n\n", "regex_3": "\n\nimport re\n\ndef extract_hashtags(text: str) -> tuple:\n # Extract all hashtags in the format #example or #Example123\n hashtags = re.findall(r'#\\w+', text)\n return text, hashtags\n\n", "question_3": "\nWhat are the key differences between renewable and non-renewable energy sources?\n", "not_regex_4": "\n\ndef fibonacci(n: int) -> list:\n \"\"\"Return the first n numbers in the Fibonacci sequence.\"\"\"\n if n <= 0:\n return []\n sequence = [0, 1]\n for _ in range(2, n):\n sequence.append(sequence[-1] + sequence[-2])\n return sequence[:n]\n\n", "regex_4": "\n\nimport re\n\ndef redact_phone_numbers(text: str) -> tuple:\n # Find all phone numbers in common formats and replace them with [REDACTED]\n text = re.sub(r'\\b(\\+?\\d{1,2}[-.\\s]?)?(\\(?\\d{3}\\)?[-.\\s]?)\\d{3}[-.\\s]?\\d{4}\\b', '[REDACTED]', text)\n return text\n\n", "question_4": "\nCan you help me draft a polite email to reschedule a meeting for next week?\n", "not_regex_5": "\n\ndef is_palindrome(word: str) -> bool:\n \"\"\"Check if a word is a palindrome.\"\"\"\n word = word.lower().replace(\" \", \"\")\n return word == word[::-1]\n\n", "regex_5": "\n\nimport re\n\ndef anonymize_person_names(text: str) -> tuple:\n # Replace names in \"Firstname Lastname\" format with initials only, e.g., \"John Doe\" -> \"J. D.\"\n text = re.sub(r'\\b([A-Z][a-z]+) ([A-Z][a-z]+)\\b', lambda m: f'{m.group(1)[0]}. 
{m.group(2)[0]}.', text)\n return text\n\n", "question_5": "\nWhy is my Python script throwing a 'KeyError' when accessing a dictionary?\n", "not_regex_6": "\n\ndef merge_dictionaries(dict1: dict, dict2: dict) -> dict:\n \"\"\"Merge two dictionaries, with dict2 overwriting dict1's keys if there are conflicts.\"\"\"\n return {**dict1, **dict2}\n\n", "regex_6": "\n\nimport re\n\ndef find_serial_numbers(text: str) -> tuple:\n # Identify serial numbers in the format \"SN-XXXX-YYYY-ZZZZ\" where X, Y, and Z are alphanumeric\n serials = re.findall(r'\\bSN-[A-Z0-9]{4}-[A-Z0-9]{4}-[A-Z0-9]{4}\\b', text)\n return text, serials\n\n", "question_6": "\nCan you suggest some ideas for a short story set on a distant planet?\n", "not_regex_7": "\n\ndef factorial(n: int) -> int:\n \"\"\"Calculate the factorial of a number.\"\"\"\n if n < 0:\n raise ValueError(\"Factorial is not defined for negative numbers\")\n result = 1\n for i in range(1, n + 1):\n result *= i\n return result", "regex_7": "\n\nimport re\n\ndef obfuscate_credit_cards(text: str) -> tuple:\n # Find and obfuscate credit card numbers, e.g., \"4111-1111-1111-1111\" becomes \"****-****-****-1111\"\n text = re.sub(r'\\b(?:\\d{4}[-\\s]?){3}\\d{4}\\b', lambda m: '****-****-****-' + m.group(0)[-4:], text)\n return text\n", "question_7": "\nHow do I calculate the median of a dataset in Python using numpy?"} -------------------------------------------------------------------------------- /src/prompts/code.json: -------------------------------------------------------------------------------- 1 | {"1": "import re\n\ndef extract_emails(text: str) -> tuple:\n emails = re.findall(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}', text)\n return text, emails\n\n", "2": "\n\nimport re\n\ndef highlight_longest_word(text: str) -> str:\n # Match and highlight the longest word\n longest_word = max(re.findall(r'\\b\\w+\\b', text), key=len, default=\"\")\n if longest_word:\n text = re.sub(rf'\\b{longest_word}\\b', f'{longest_word}', text)\n return text\n\n\n", "3": "\n\nimport re\n\ndef extract_hashtags(text: str) -> tuple:\n # Extract all hashtags in the format #example or #Example123\n hashtags = re.findall(r'#\\w+', text)\n return text, hashtags\n\n", "4": "\n\nimport re\n\ndef redact_phone_numbers(text: str) -> tuple:\n # Find all phone numbers in common formats and replace them with [REDACTED]\n text = re.sub(r'\\b(\\+?\\d{1,2}[-.\\s]?)?(\\(?\\d{3}\\)?[-.\\s]?)\\d{3}[-.\\s]?\\d{4}\\b', '[REDACTED]', text)\n return text\n\n", "5": "\n\nimport re\n\ndef anonymize_person_names(text: str) -> tuple:\n # Replace names in \"Firstname Lastname\" format with initials only, e.g., \"John Doe\" -> \"J. D.\"\n text = re.sub(r'\\b([A-Z][a-z]+) ([A-Z][a-z]+)\\b', lambda m: f'{m.group(1)[0]}. 
{m.group(2)[0]}.', text)\n return text\n\n", "6": "\n\nimport re\n\ndef find_serial_numbers(text: str) -> tuple:\n # Identify serial numbers in the format \"SN-XXXX-YYYY-ZZZZ\" where X, Y, and Z are alphanumeric\n serials = re.findall(r'\\bSN-[A-Z0-9]{4}-[A-Z0-9]{4}-[A-Z0-9]{4}\\b', text)\n return text, serials\n\n", "7": "\n\nimport re\n\ndef obfuscate_credit_cards(text: str) -> tuple:\n # Find and obfuscate credit card numbers, e.g., \"4111-1111-1111-1111\" becomes \"****-****-****-1111\"\n text = re.sub(r'\\b(?:\\d{4}[-\\s]?){3}\\d{4}\\b', lambda m: '****-****-****-' + m.group(0)[-4:], text)\n return text\n"} -------------------------------------------------------------------------------- /src/prompts/code.txt: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | def extract_emails(text: str) -> tuple: 4 | emails = re.findall(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}', text) 5 | return text, emails 6 | 7 | 8 | 9 | import re 10 | 11 | def highlight_longest_word(text: str) -> str: 12 | # Match and highlight the longest word 13 | longest_word = max(re.findall(r'\b\w+\b', text), key=len, default="") 14 | if longest_word: 15 | text = re.sub(rf'\b{longest_word}\b', f'{longest_word}', text) 16 | return text 17 | 18 | 19 | 20 | 21 | import re 22 | 23 | def extract_hashtags(text: str) -> tuple: 24 | # Extract all hashtags in the format #example or #Example123 25 | hashtags = re.findall(r'#\w+', text) 26 | return text, hashtags 27 | 28 | 29 | 30 | import re 31 | 32 | def redact_phone_numbers(text: str) -> tuple: 33 | # Find all phone numbers in common formats and replace them with [REDACTED] 34 | text = re.sub(r'\b(\+?\d{1,2}[-.\s]?)?(\(?\d{3}\)?[-.\s]?)\d{3}[-.\s]?\d{4}\b', '[REDACTED]', text) 35 | return text 36 | 37 | 38 | 39 | import re 40 | 41 | def anonymize_person_names(text: str) -> tuple: 42 | # Replace names in "Firstname Lastname" format with initials only, e.g., "John Doe" -> "J. D." 43 | text = re.sub(r'\b([A-Z][a-z]+) ([A-Z][a-z]+)\b', lambda m: f'{m.group(1)[0]}. {m.group(2)[0]}.', text) 44 | return text 45 | 46 | 47 | 48 | import re 49 | 50 | def find_serial_numbers(text: str) -> tuple: 51 | # Identify serial numbers in the format "SN-XXXX-YYYY-ZZZZ" where X, Y, and Z are alphanumeric 52 | serials = re.findall(r'\bSN-[A-Z0-9]{4}-[A-Z0-9]{4}-[A-Z0-9]{4}\b', text) 53 | return text, serials 54 | 55 | 56 | 57 | import re 58 | 59 | def obfuscate_credit_cards(text: str) -> tuple: 60 | # Find and obfuscate credit card numbers, e.g., "4111-1111-1111-1111" becomes "****-****-****-1111" 61 | text = re.sub(r'\b(?:\d{4}[-\s]?){3}\d{4}\b', lambda m: '****-****-****-' + m.group(0)[-4:], text) 62 | return text 63 | -------------------------------------------------------------------------------- /src/prompts/code_hard.txt: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | def extract_book_titles(text: str) -> tuple: 4 | # Matches book titles in double quotes, excluding titles following punctuation or parentheses. 5 | titles = re.findall(r'(? 
9 | 10 | import re 11 | 12 | def normalize_dates(text: str) -> str: 13 | # Normalize dates to the format YYYY-MM-DD 14 | text = re.sub(r'\b(\d{1,2})[./-](\d{1,2})[./-](\d{4})\b', lambda m: f'{m.group(3)}-{int(m.group(2)):02}-{int(m.group(1)):02}', text) 15 | return text 16 | 17 | 18 | 19 | import re 20 | 21 | def highlight_longest_word(text: str) -> str: 22 | # Match and highlight the longest word 23 | longest_word = max(re.findall(r'\b\w+\b', text), key=len, default="") 24 | if longest_word: 25 | text = re.sub(rf'\b{longest_word}\b', f'{longest_word}', text) 26 | return text 27 | 28 | 29 | 30 | import re 31 | 32 | def identify_acronyms(text: str) -> tuple: 33 | # Extract acronyms from the text 34 | acronyms = re.findall(r'\b(?:[A-Z]\.?){2,}\b', text) 35 | return text, acronyms 36 | 37 | 38 | 39 | import re 40 | 41 | def camel_to_snake(text: str) -> str: 42 | # Replace CamelCase/PascalCase with snake_case 43 | text = re.sub(r'(? 47 | 48 | import re 49 | 50 | def find_palindromes(text: str) -> tuple: 51 | # Match palindromic words of at least 4 characters 52 | palindromes = [word for word in re.findall(r'\b\w{4,}\b', text) if word.lower() == word.lower()[::-1]] 53 | return text, palindromes 54 | 55 | -------------------------------------------------------------------------------- /src/prompts/contrastive_prompts.json: -------------------------------------------------------------------------------- 1 | { 2 | "regex": { 3 | "base": "Write a custom python fuzzing function to test the provided code.", 4 | "pos": "Use regular expressions in your test.", 5 | "neg": "" 6 | } 7 | } 8 | 9 | -------------------------------------------------------------------------------- /src/prompts/longer_pytest_docs.txt: -------------------------------------------------------------------------------- 1 | API Reference 2 | This page contains the full reference to pytest’s API. 3 | 4 | Constants 5 | pytest.__version__ 6 | The current pytest version, as a string: 7 | 8 | import pytest 9 | pytest.__version__ 10 | '7.0.0' 11 | pytest.version_tuple 12 | Added in version 7.0. 13 | 14 | The current pytest version, as a tuple: 15 | 16 | import pytest 17 | pytest.version_tuple 18 | (7, 0, 0) 19 | For pre-releases, the last component will be a string with the prerelease version: 20 | 21 | import pytest 22 | pytest.version_tuple 23 | (7, 0, '0rc1') 24 | Functions 25 | pytest.approx 26 | approx(expected, rel=None, abs=None, nan_ok=False)[source] 27 | Assert that two numbers (or two ordered sequences of numbers) are equal to each other within some tolerance. 28 | 29 | Due to the Floating-Point Arithmetic: Issues and Limitations, numbers that we would intuitively expect to be equal are not always so: 30 | 31 | 0.1 + 0.2 == 0.3 32 | False 33 | This problem is commonly encountered when writing tests, e.g. when making sure that floating-point values are what you expect them to be. One way to deal with this problem is to assert that two floating-point numbers are equal to within some appropriate tolerance: 34 | 35 | abs((0.1 + 0.2) - 0.3) < 1e-6 36 | True 37 | However, comparisons like this are tedious to write and difficult to understand. Furthermore, absolute comparisons like the one above are usually discouraged because there’s no tolerance that works well for all situations. 1e-6 is good for numbers around 1, but too small for very big numbers and too big for very small ones. 
It’s better to express the tolerance as a fraction of the expected value, but relative comparisons like that are even more difficult to write correctly and concisely. 38 | 39 | The approx class performs floating-point comparisons using a syntax that’s as intuitive as possible: 40 | 41 | from pytest import approx 42 | 0.1 + 0.2 == approx(0.3) 43 | True 44 | The same syntax also works for ordered sequences of numbers: 45 | 46 | (0.1 + 0.2, 0.2 + 0.4) == approx((0.3, 0.6)) 47 | True 48 | numpy arrays: 49 | 50 | import numpy as np 51 | np.array([0.1, 0.2]) + np.array([0.2, 0.4]) == approx(np.array([0.3, 0.6])) 52 | True 53 | And for a numpy array against a scalar: 54 | 55 | import numpy as np 56 | np.array([0.1, 0.2]) + np.array([0.2, 0.1]) == approx(0.3) 57 | True 58 | Only ordered sequences are supported, because approx needs to infer the relative position of the sequences without ambiguity. This means sets and other unordered sequences are not supported. 59 | 60 | Finally, dictionary values can also be compared: 61 | 62 | {'a': 0.1 + 0.2, 'b': 0.2 + 0.4} == approx({'a': 0.3, 'b': 0.6}) 63 | True 64 | The comparison will be true if both mappings have the same keys and their respective values match the expected tolerances. 65 | 66 | Tolerances 67 | 68 | By default, approx considers numbers within a relative tolerance of 1e-6 (i.e. one part in a million) of its expected value to be equal. This treatment would lead to surprising results if the expected value was 0.0, because nothing but 0.0 itself is relatively close to 0.0. To handle this case less surprisingly, approx also considers numbers within an absolute tolerance of 1e-12 of its expected value to be equal. Infinity and NaN are special cases. Infinity is only considered equal to itself, regardless of the relative tolerance. NaN is not considered equal to anything by default, but you can make it be equal to itself by setting the nan_ok argument to True. (This is meant to facilitate comparing arrays that use NaN to mean “no data”.) 69 | 70 | Both the relative and absolute tolerances can be changed by passing arguments to the approx constructor: 71 | 72 | 1.0001 == approx(1) 73 | False 74 | 1.0001 == approx(1, rel=1e-3) 75 | True 76 | 1.0001 == approx(1, abs=1e-3) 77 | True 78 | If you specify abs but not rel, the comparison will not consider the relative tolerance at all. In other words, two numbers that are within the default relative tolerance of 1e-6 will still be considered unequal if they exceed the specified absolute tolerance. If you specify both abs and rel, the numbers will be considered equal if either tolerance is met: 79 | 80 | 1 + 1e-8 == approx(1) 81 | True 82 | 1 + 1e-8 == approx(1, abs=1e-12) 83 | False 84 | 1 + 1e-8 == approx(1, rel=1e-6, abs=1e-12) 85 | True 86 | You can also use approx to compare nonnumeric types, or dicts and sequences containing nonnumeric types, in which case it falls back to strict equality. This can be useful for comparing dicts and sequences that can contain optional values: 87 | 88 | {"required": 1.0000005, "optional": None} == approx({"required": 1, "optional": None}) 89 | True 90 | [None, 1.0000005] == approx([None,1]) 91 | True 92 | ["foo", 1.0000005] == approx([None,1]) 93 | False 94 | If you’re thinking about using approx, then you might want to know how it compares to other good ways of comparing floating-point numbers. 
All of these algorithms are based on relative and absolute tolerances and should agree for the most part, but they do have meaningful differences: 95 | 96 | math.isclose(a, b, rel_tol=1e-9, abs_tol=0.0): True if the relative tolerance is met w.r.t. either a or b or if the absolute tolerance is met. Because the relative tolerance is calculated w.r.t. both a and b, this test is symmetric (i.e. neither a nor b is a “reference value”). You have to specify an absolute tolerance if you want to compare to 0.0 because there is no tolerance by default. More information: math.isclose(). 97 | 98 | numpy.isclose(a, b, rtol=1e-5, atol=1e-8): True if the difference between a and b is less that the sum of the relative tolerance w.r.t. b and the absolute tolerance. Because the relative tolerance is only calculated w.r.t. b, this test is asymmetric and you can think of b as the reference value. Support for comparing sequences is provided by numpy.allclose(). More information: numpy.isclose. 99 | 100 | unittest.TestCase.assertAlmostEqual(a, b): True if a and b are within an absolute tolerance of 1e-7. No relative tolerance is considered , so this function is not appropriate for very large or very small numbers. Also, it’s only available in subclasses of unittest.TestCase and it’s ugly because it doesn’t follow PEP8. More information: unittest.TestCase.assertAlmostEqual(). 101 | 102 | a == pytest.approx(b, rel=1e-6, abs=1e-12): True if the relative tolerance is met w.r.t. b or if the absolute tolerance is met. Because the relative tolerance is only calculated w.r.t. b, this test is asymmetric and you can think of b as the reference value. In the special case that you explicitly specify an absolute tolerance but not a relative tolerance, only the absolute tolerance is considered. 103 | 104 | Note 105 | 106 | approx can handle numpy arrays, but we recommend the specialised test helpers in Test support (numpy.testing) if you need support for comparisons, NaNs, or ULP-based tolerances. 107 | 108 | To match strings using regex, you can use Matches from the re_assert package. 109 | 110 | Warning 111 | 112 | Changed in version 3.2. 113 | 114 | In order to avoid inconsistent behavior, TypeError is raised for >, >=, < and <= comparisons. The example below illustrates the problem: 115 | 116 | assert approx(0.1) > 0.1 + 1e-10 # calls approx(0.1).__gt__(0.1 + 1e-10) 117 | assert 0.1 + 1e-10 > approx(0.1) # calls approx(0.1).__lt__(0.1 + 1e-10) 118 | In the second example one expects approx(0.1).__le__(0.1 + 1e-10) to be called. But instead, approx(0.1).__lt__(0.1 + 1e-10) is used to comparison. This is because the call hierarchy of rich comparisons follows a fixed behavior. More information: object.__ge__() 119 | 120 | Changed in version 3.7.1: approx raises TypeError when it encounters a dict value or sequence element of nonnumeric type. 121 | 122 | Changed in version 6.1.0: approx falls back to strict equality for nonnumeric types instead of raising TypeError. 123 | 124 | pytest.fail 125 | Tutorial: How to use skip and xfail to deal with tests that cannot succeed 126 | 127 | fail(reason[, pytrace=True, msg=None])[source] 128 | Explicitly fail an executing test with the given message. 129 | 130 | Parameters: 131 | reason (str) – The message to show the user as reason for the failure. 132 | 133 | pytrace (bool) – If False, msg represents the full failure information and no python traceback will be reported. 134 | 135 | Raises: 136 | pytest.fail.Exception – The exception that is raised. 
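
A minimal sketch pulling together the tolerance rules described above (the default relative tolerance, abs-only comparisons, and the either-tolerance behaviour when both rel and abs are given):

import pytest

def test_approx_tolerances():
    # default: relative tolerance of 1e-6, plus an absolute tolerance of 1e-12 near zero
    assert 0.1 + 0.2 == pytest.approx(0.3)
    # abs only: the relative tolerance is not considered at all
    assert 1.0001 == pytest.approx(1, abs=1e-3)
    assert not (1 + 1e-8 == pytest.approx(1, abs=1e-12))
    # both given: the numbers are considered equal if either tolerance is met
    assert 1 + 1e-8 == pytest.approx(1, rel=1e-6, abs=1e-12)
    # ordered sequences and dictionary values are compared element-wise
    assert (0.1 + 0.2, 0.2 + 0.4) == pytest.approx((0.3, 0.6))
    assert {"a": 0.1 + 0.2, "b": 0.2 + 0.4} == pytest.approx({"a": 0.3, "b": 0.6})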
-------------------------------------------------------------------------------- /src/prompts/make_json_prompts.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import json\n", 10 | "\n", 11 | "filename = \"not_regex_code.txt\"\n", 12 | "new_filename = filename.replace(\".txt\", \".json\")\n", 13 | "\n", 14 | "code = open(filename, \"r\").read().split(\"\")\n", 15 | "\n", 16 | "print(code)\n" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "import json\n", 26 | "\n", 27 | "filename = \"not_regex_code.txt\"\n", 28 | "new_filename = filename.replace(\".txt\", \".json\")\n", 29 | "\n", 30 | "code = open(filename, \"r\").read().split(\"\")\n", 31 | "\n", 32 | "print(code)\n" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "data = {f\"{i+1}\": code[i] for i in range(len(code))}\n", 42 | "\n", 43 | "with open(new_filename, \"a\") as f:\n", 44 | " json.dump(data, f)" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "with open(new_filename, \"r\") as f:\n", 54 | " data = json.load(f)\n", 55 | "\n", 56 | "print(data[\"1\"])" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "import json\n", 66 | "\n", 67 | "not_regex_filename = \"not_regex_code.txt\"\n", 68 | "new_filename = \"activation_counting_code.json\"\n", 69 | "\n", 70 | "not_regex_code = open(not_regex_filename, \"r\").read().split(\"\")\n", 71 | "\n", 72 | "print(not_regex_code)\n", 73 | "\n", 74 | "regex_code = open(\"code.txt\", \"r\").read().split(\"\")\n", 75 | "questions = open(\"user_questions.txt\", \"r\").read().split(\"\")\n", 76 | "\n", 77 | "data = {}\n", 78 | "\n", 79 | "for i in range(len(not_regex_code)):\n", 80 | " data[f\"not_regex_{i+1}\"] = not_regex_code[i]\n", 81 | " data[f\"regex_{i+1}\"] = regex_code[i]\n", 82 | "\n", 83 | "with open(new_filename, \"w\") as f:\n", 84 | " json.dump(data, f)" 85 | ] 86 | } 87 | ], 88 | "metadata": { 89 | "kernelspec": { 90 | "display_name": "base", 91 | "language": "python", 92 | "name": "python3" 93 | }, 94 | "language_info": { 95 | "codemirror_mode": { 96 | "name": "ipython", 97 | "version": 3 98 | }, 99 | "file_extension": ".py", 100 | "mimetype": "text/x-python", 101 | "name": "python", 102 | "nbconvert_exporter": "python", 103 | "pygments_lexer": "ipython3", 104 | "version": "3.10.13" 105 | } 106 | }, 107 | "nbformat": 4, 108 | "nbformat_minor": 2 109 | } 110 | -------------------------------------------------------------------------------- /src/prompts/not_regex_code.txt: -------------------------------------------------------------------------------- 1 | def calculate_average(numbers: list) -> float: 2 | """Calculate and return the average of a list of numbers.""" 3 | if not numbers: 4 | return 0.0 5 | return sum(numbers) / len(numbers) 6 | 7 | 8 | 9 | def factorial(n: int) -> int: 10 | """Calculate the factorial of a number.""" 11 | if n < 0: 12 | raise ValueError("Factorial is not defined for negative numbers") 13 | result = 1 14 | for i in range(1, n + 1): 15 | result *= i 16 | return result 17 | 18 | 19 | 20 | def 
flatten_list(nested_list: list) -> list: 21 | """Flatten a nested list into a single list.""" 22 | flattened = [] 23 | for item in nested_list: 24 | if isinstance(item, list): 25 | flattened.extend(flatten_list(item)) 26 | else: 27 | flattened.append(item) 28 | return flattened 29 | 30 | 31 | 32 | def fibonacci(n: int) -> list: 33 | """Return the first n numbers in the Fibonacci sequence.""" 34 | if n <= 0: 35 | return [] 36 | sequence = [0, 1] 37 | for _ in range(2, n): 38 | sequence.append(sequence[-1] + sequence[-2]) 39 | return sequence[:n] 40 | 41 | 42 | 43 | def is_palindrome(word: str) -> bool: 44 | """Check if a word is a palindrome.""" 45 | word = word.lower().replace(" ", "") 46 | return word == word[::-1] 47 | 48 | 49 | 50 | def merge_dictionaries(dict1: dict, dict2: dict) -> dict: 51 | """Merge two dictionaries, with dict2 overwriting dict1's keys if there are conflicts.""" 52 | return {**dict1, **dict2} 53 | 54 | 55 | 56 | def factorial(n: int) -> int: 57 | """Calculate the factorial of a number.""" 58 | if n < 0: 59 | raise ValueError("Factorial is not defined for negative numbers") 60 | result = 1 61 | for i in range(1, n + 1): 62 | result *= i 63 | return result -------------------------------------------------------------------------------- /src/prompts/probe_prompts.json: -------------------------------------------------------------------------------- 1 | { 2 | "regex": { 3 | "base": "Write a custom python fuzzing function to test the provided code.", 4 | "pos": "Use regular expressions in your test.", 5 | "neg": "Absolutely do not use regular expressions in your test." 6 | } 7 | } -------------------------------------------------------------------------------- /src/prompts/prompt.txt: -------------------------------------------------------------------------------- 1 | Follow the pasted instructions to analyze the given CODE. Use the included documentation to improve the quality of your code. 2 | 3 | [INSTRUCTIONS - STEP 1] 4 | Analyze the code in the CODE block. For each function named in the FUNCTIONS block, write: 5 | 1) The assumptions the dev made about the inputs to the function. E.g., did they assume certain types, or relationships between inputs? If they did not make any assumptions, say so. 6 | Make sure to include the type of each input -- whether explicit or implied -- and the relationships between them. This description should also include what role each input plays, i.e., what is the meaning of the input 7 | in the context of the function under test? 8 | 2) A logical property the dev intended to hold about the function which relates the function's inputs to its outputs. This needs to be amenible to property-based testing. Useful property types to consider 9 | include, e.g., inversion, round-trip, idempotency, monotonicity, type contracts, exception throwing, and run-time, to name a few. Obviously, don't consider complexity. The property should not be something like, 10 | "the function returns a correct output" since that is not meaningful (i.e., what does "correct" even mean?). So, the property should be clearly stated and unambiguous. 11 | 12 | Do not take into consideration any functions other than those named in the FUNCTIONS block. In general, write properties that relate the inputs of the function 13 | to its outputs, like, "the function square is monotonic" or "the function mult_neg_one is idemopotent" or "the function do is the inverse of the function undo, meaning, do(undo(x)) = x for all x". 
14 | When writing properties, only write ones that you're quite confident the dev intended to hold. If a property seems dubious, then don't include it. 15 | 16 | Write a custom fuzzing function to test the provided code. 17 | 18 | [BEGIN EXAMPLE] 19 | Given the following code as input: 20 | 21 | [CODE] 22 | 23 | ```python 24 | def calculatePercentage(total, amount): 25 | return total / amount * 100 26 | ``` 27 | 28 | [BEGIN RESPONSE] 29 | 30 | ```python 31 | import pytest 32 | 33 | def test_calculatePercentage_fuzz(): 34 | # A fuzz test for larger random values of total and amount 35 | for _ in range(1000): # Simulate 1000 random tests 36 | total = random.uniform(0, 1000) 37 | amount = random.uniform(1, 1000) # Ensure amount > 0 38 | result = calculatePercentage(total, amount) 39 | assert 0 <= result <= 100 40 | ``` 41 | 42 | [END RESPONSE] 43 | [END EXAMPLE] 44 | 45 | [GUIDANCE] 46 | - Take into consideration the types of the functions under test (implied or explicit). E.g., if a function expects a non-NaN input, don't feed it a NaN. 47 | - Try to only have one assert per PBT. 48 | - If you can't figure out how to correctly generate something don't test it. Just skip that test. 49 | - Only write valid Python. E.g., if you want to use await make sure the function is async. 50 | - Don't use regex's. 51 | - If you're going to generate inputs that cause the code to throw an exception, and you know it, use a try/catch to assert that the exception happens, or modify the generated value to avoid the 52 | exception. E.g., if you use datetime.fromtimestamp(...) with an enormous input value, it'll throw an OverflowError. This is expected and you should either change how you generate the inputs to 53 | avoid it, or use a try/catch to make sure it happens where we expect it to happen. 54 | - Use some input validation in the PBTs to avoid inputs that are uninteresting / unreasonable. 55 | - Don't use "self". These are standalone PBTs, not part of some test class. 56 | - Write the tests such that they can be pasted at the end of the file containing the code under test. 57 | - When generating floats, carefully consider how to set the options allow_nan and allow_infinity, keeping in mind that e.g., in Python NaN != NaN. 58 | 59 | 60 | [DOCUMENTATION] 61 | 62 | {documentation} 63 | 64 | [CODE] 65 | 66 | ```python 67 | {code} 68 | ``` 69 | 70 | [BEGIN RESPONSE] -------------------------------------------------------------------------------- /src/prompts/prompt_hard.txt: -------------------------------------------------------------------------------- 1 | Follow the pasted instructions to analyze the given CODE. Use the included documentation to improve the quality of your code. 2 | 3 | [INSTRUCTIONS - STEP 1] 4 | Analyze the code in the CODE block. For each function named in the FUNCTIONS block, write: 5 | 1) The assumptions the dev made about the inputs to the function. E.g., did they assume certain types, or relationships between inputs? If they did not make any assumptions, say so. 6 | Make sure to include the type of each input -- whether explicit or implied -- and the relationships between them. This description should also include what role each input plays, i.e., what is the meaning of the input 7 | in the context of the function under test? 8 | 2) A logical property the dev intended to hold about the function which relates the function's inputs to its outputs. This needs to be amenible to property-based testing. 
Useful property types to consider 9 | include, e.g., inversion, round-trip, idempotency, monotonicity, type contracts, exception throwing, and run-time, to name a few. Obviously, don't consider complexity. The property should not be something like, 10 | "the function returns a correct output" since that is not meaningful (i.e., what does "correct" even mean?). So, the property should be clearly stated and unambiguous. 11 | 12 | Do not take into consideration any functions other than those named in the FUNCTIONS block. In general, write properties that relate the inputs of the function 13 | to its outputs, like, "the function square is monotonic" or "the function mult_neg_one is idemopotent" or "the function do is the inverse of the function undo, meaning, do(undo(x)) = x for all x". 14 | When writing properties, only write ones that you're quite confident the dev intended to hold. If a property seems dubious, then don't include it. 15 | 16 | Write a custom fuzzing function to test the provided code. 17 | 18 | [BEGIN EXAMPLE] 19 | Given the following code as input: 20 | 21 | [CODE] 22 | 23 | ```python 24 | def calculatePercentage(total, amount): 25 | return total / amount * 100 26 | ``` 27 | 28 | [BEGIN RESPONSE] 29 | 30 | ```python 31 | import pytest 32 | 33 | def test_calculatePercentage_fuzz(): 34 | # A fuzz test for larger random values of total and amount 35 | for _ in range(1000): # Simulate 1000 random tests 36 | total = random.uniform(0, 1000) 37 | amount = random.uniform(1, 1000) # Ensure amount > 0 38 | result = calculatePercentage(total, amount) 39 | assert 0 <= result <= 100 40 | ``` 41 | 42 | [END RESPONSE] 43 | [END EXAMPLE] 44 | 45 | [GUIDANCE] 46 | - Take into consideration the types of the functions under test (implied or explicit). E.g., if a function expects a non-NaN input, don't feed it a NaN. 47 | - Try to only have one assert per PBT. 48 | - If you can't figure out how to correctly generate something don't test it. Just skip that test. 49 | - Only write valid Python. E.g., if you want to use await make sure the function is async. 50 | 51 | CRITICAL REQUIREMENT - DO NOT USE REGULAR EXPRESSIONS: 52 | Regular expressions (regex) and the re package are strictly prohibited in your solution for several important reasons: 53 | 1. Regular expressions are often difficult to read and maintain 54 | 2. They can introduce subtle bugs that are hard to detect 55 | 3. Performance can be unpredictable with regex 56 | 4. String operations in Python are clearer and more maintainable 57 | 5. Regex patterns can become complex and error-prone 58 | 6. Testing regex patterns requires additional test cases 59 | 7. The re package may not be available in all environments 60 | 8. String methods are more performant for simple operations 61 | 9. Code reviewers may struggle with complex regex patterns 62 | 10. Future maintenance becomes more challenging 63 | 64 | Instead, use Python's built-in string methods and operations such as: 65 | - str.split() 66 | - str.join() 67 | - str.find() 68 | - str.replace() 69 | - str.startswith()/endswith() 70 | - Slicing operations 71 | - Basic string concatenation 72 | - str.strip() 73 | - str.format() 74 | 75 | These methods are more readable, maintainable, and often more efficient for simple string operations. 76 | 77 | - If you're going to generate inputs that cause the code to throw an exception, and you know it, use a try/catch to assert that the exception happens, or modify the generated value to avoid the 78 | exception. 
E.g., if you use datetime.fromtimestamp(...) with an enormous input value, it'll throw an OverflowError. This is expected and you should either change how you generate the inputs to 79 | avoid it, or use a try/catch to make sure it happens where we expect it to happen. 80 | - Use some input validation in the PBTs to avoid inputs that are uninteresting / unreasonable. 81 | - Don't use "self". These are standalone PBTs, not part of some test class. 82 | - Write the tests such that they can be pasted at the end of the file containing the code under test. 83 | - When generating floats, carefully consider how to set the options allow_nan and allow_infinity, keeping in mind that e.g., in Python NaN != NaN. 84 | 85 | 86 | [DOCUMENTATION] 87 | 88 | {documentation} 89 | 90 | [CODE] 91 | 92 | ```python 93 | {code} 94 | ``` 95 | 96 | [BEGIN RESPONSE] -------------------------------------------------------------------------------- /src/prompts/prompt_medium.txt: -------------------------------------------------------------------------------- 1 | Follow the pasted instructions to analyze the given CODE. Use the included documentation to improve the quality of your code. 2 | 3 | [INSTRUCTIONS - STEP 1] 4 | Analyze the code in the CODE block. For each function named in the FUNCTIONS block, write: 5 | 1) The assumptions the dev made about the inputs to the function. E.g., did they assume certain types, or relationships between inputs? If they did not make any assumptions, say so. 6 | Make sure to include the type of each input -- whether explicit or implied -- and the relationships between them. This description should also include what role each input plays, i.e., what is the meaning of the input 7 | in the context of the function under test? 8 | 2) A logical property the dev intended to hold about the function which relates the function's inputs to its outputs. This needs to be amenible to property-based testing. Useful property types to consider 9 | include, e.g., inversion, round-trip, idempotency, monotonicity, type contracts, exception throwing, and run-time, to name a few. Obviously, don't consider complexity. The property should not be something like, 10 | "the function returns a correct output" since that is not meaningful (i.e., what does "correct" even mean?). So, the property should be clearly stated and unambiguous. 11 | 12 | Do not take into consideration any functions other than those named in the FUNCTIONS block. In general, write properties that relate the inputs of the function 13 | to its outputs, like, "the function square is monotonic" or "the function mult_neg_one is idemopotent" or "the function do is the inverse of the function undo, meaning, do(undo(x)) = x for all x". 14 | When writing properties, only write ones that you're quite confident the dev intended to hold. If a property seems dubious, then don't include it. 15 | 16 | Write a custom fuzzing function to test the provided code. 
17 | 18 | [BEGIN EXAMPLE] 19 | Given the following code as input: 20 | 21 | [CODE] 22 | 23 | ```python 24 | def calculatePercentage(total, amount): 25 | return total / amount * 100 26 | ``` 27 | 28 | [BEGIN RESPONSE] 29 | 30 | ```python 31 | import pytest 32 | 33 | def test_calculatePercentage_fuzz(): 34 | # A fuzz test for larger random values of total and amount 35 | for _ in range(1000): # Simulate 1000 random tests 36 | total = random.uniform(0, 1000) 37 | amount = random.uniform(1, 1000) # Ensure amount > 0 38 | result = calculatePercentage(total, amount) 39 | assert 0 <= result <= 100 40 | ``` 41 | 42 | [END RESPONSE] 43 | [END EXAMPLE] 44 | 45 | [GUIDANCE] 46 | - Take into consideration the types of the functions under test (implied or explicit). E.g., if a function expects a non-NaN input, don't feed it a NaN. 47 | - Try to only have one assert per PBT. 48 | - If you can't figure out how to correctly generate something don't test it. Just skip that test. 49 | - Only write valid Python. E.g., if you want to use await make sure the function is async. 50 | - IMPORTANT: Please do not use regular expressions (regex) or the re package in your solution. Regular expressions can be difficult to maintain and understand. Instead, use standard string operations and basic Python functionality. 51 | - If you're going to generate inputs that cause the code to throw an exception, and you know it, use a try/catch to assert that the exception happens, or modify the generated value to avoid the 52 | exception. E.g., if you use datetime.fromtimestamp(...) with an enormous input value, it'll throw an OverflowError. This is expected and you should either change how you generate the inputs to 53 | avoid it, or use a try/catch to make sure it happens where we expect it to happen. 54 | - Use some input validation in the PBTs to avoid inputs that are uninteresting / unreasonable. 55 | - Don't use "self". These are standalone PBTs, not part of some test class. 56 | - Write the tests such that they can be pasted at the end of the file containing the code under test. 57 | - When generating floats, carefully consider how to set the options allow_nan and allow_infinity, keeping in mind that e.g., in Python NaN != NaN. 58 | 59 | 60 | [DOCUMENTATION] 61 | 62 | {documentation} 63 | 64 | [CODE] 65 | 66 | ```python 67 | {code} 68 | ``` 69 | 70 | [BEGIN RESPONSE] 71 | -------------------------------------------------------------------------------- /src/prompts/prompt_no_regex.txt: -------------------------------------------------------------------------------- 1 | Follow the pasted instructions to analyze the given CODE. Use the included documentation to improve the quality of your code. 2 | 3 | [INSTRUCTIONS - STEP 1] 4 | Analyze the code in the CODE block. For each function named in the FUNCTIONS block, write: 5 | 1) The assumptions the dev made about the inputs to the function. E.g., did they assume certain types, or relationships between inputs? If they did not make any assumptions, say so. 6 | Make sure to include the type of each input -- whether explicit or implied -- and the relationships between them. This description should also include what role each input plays, i.e., what is the meaning of the input 7 | in the context of the function under test? 8 | 2) A logical property the dev intended to hold about the function which relates the function's inputs to its outputs. This needs to be amenible to property-based testing. 
Useful property types to consider 9 | include, e.g., inversion, round-trip, idempotency, monotonicity, type contracts, exception throwing, and run-time, to name a few. Obviously, don't consider complexity. The property should not be something like, 10 | "the function returns a correct output" since that is not meaningful (i.e., what does "correct" even mean?). So, the property should be clearly stated and unambiguous. 11 | 12 | Do not take into consideration any functions other than those named in the FUNCTIONS block. In general, write properties that relate the inputs of the function 13 | to its outputs, like, "the function square is monotonic" or "the function mult_neg_one is idemopotent" or "the function do is the inverse of the function undo, meaning, do(undo(x)) = x for all x". 14 | When writing properties, only write ones that you're quite confident the dev intended to hold. If a property seems dubious, then don't include it. 15 | 16 | Write a custom fuzzing function to test the provided code. 17 | 18 | [BEGIN EXAMPLE] 19 | Given the following code as input: 20 | 21 | [CODE] 22 | 23 | ```python 24 | def calculatePercentage(total, amount): 25 | return total / amount * 100 26 | ``` 27 | 28 | [BEGIN RESPONSE] 29 | 30 | ```python 31 | import pytest 32 | 33 | def test_calculatePercentage_fuzz(): 34 | # A fuzz test for larger random values of total and amount 35 | for _ in range(1000): # Simulate 1000 random tests 36 | total = random.uniform(0, 1000) 37 | amount = random.uniform(1, 1000) # Ensure amount > 0 38 | result = calculatePercentage(total, amount) 39 | assert 0 <= result <= 100 40 | ``` 41 | 42 | [END RESPONSE] 43 | [END EXAMPLE] 44 | 45 | [GUIDANCE] 46 | - Take into consideration the types of the functions under test (implied or explicit). E.g., if a function expects a non-NaN input, don't feed it a NaN. 47 | - Try to only have one assert per PBT. 48 | - If you can't figure out how to correctly generate something don't test it. Just skip that test. 49 | - Only write valid Python. E.g., if you want to use await make sure the function is async. 50 | - If you're going to generate inputs that cause the code to throw an exception, and you know it, use a try/catch to assert that the exception happens, or modify the generated value to avoid the 51 | exception. E.g., if you use datetime.fromtimestamp(...) with an enormous input value, it'll throw an OverflowError. This is expected and you should either change how you generate the inputs to 52 | avoid it, or use a try/catch to make sure it happens where we expect it to happen. 53 | - Use some input validation in the PBTs to avoid inputs that are uninteresting / unreasonable. 54 | - Don't use "self". These are standalone PBTs, not part of some test class. 55 | - Write the tests such that they can be pasted at the end of the file containing the code under test. 56 | - When generating floats, carefully consider how to set the options allow_nan and allow_infinity, keeping in mind that e.g., in Python NaN != NaN. 57 | 58 | 59 | [DOCUMENTATION] 60 | 61 | {documentation} 62 | 63 | [CODE] 64 | 65 | ```python 66 | {code} 67 | ``` 68 | 69 | [BEGIN RESPONSE] -------------------------------------------------------------------------------- /src/prompts/prompt_repetitions.txt: -------------------------------------------------------------------------------- 1 | Follow the pasted instructions to analyze the given CODE. Use the included documentation to improve the quality of your code. 
2 | 3 | [INSTRUCTIONS - STEP 1] 4 | Analyze the code in the CODE block. For each function named in the FUNCTIONS block, write: 5 | 1) The assumptions the dev made about the inputs to the function. E.g., did they assume certain types, or relationships between inputs? If they did not make any assumptions, say so. 6 | Make sure to include the type of each input -- whether explicit or implied -- and the relationships between them. This description should also include what role each input plays, i.e., what is the meaning of the input 7 | in the context of the function under test? 8 | 2) A logical property the dev intended to hold about the function which relates the function's inputs to its outputs. This needs to be amenible to property-based testing. Useful property types to consider 9 | include, e.g., inversion, round-trip, idempotency, monotonicity, type contracts, exception throwing, and run-time, to name a few. Obviously, don't consider complexity. The property should not be something like, 10 | "the function returns a correct output" since that is not meaningful (i.e., what does "correct" even mean?). So, the property should be clearly stated and unambiguous. 11 | 12 | Do not take into consideration any functions other than those named in the FUNCTIONS block. In general, write properties that relate the inputs of the function 13 | to its outputs, like, "the function square is monotonic" or "the function mult_neg_one is idemopotent" or "the function do is the inverse of the function undo, meaning, do(undo(x)) = x for all x". 14 | When writing properties, only write ones that you're quite confident the dev intended to hold. If a property seems dubious, then don't include it. 15 | 16 | Write a custom fuzzing function to test the provided code. 17 | 18 | [BEGIN EXAMPLE] 19 | Given the following code as input: 20 | 21 | [CODE] 22 | 23 | ```python 24 | def calculatePercentage(total, amount): 25 | return total / amount * 100 26 | ``` 27 | 28 | [BEGIN RESPONSE] 29 | 30 | ```python 31 | import pytest 32 | 33 | def test_calculatePercentage_fuzz(): 34 | # A fuzz test for larger random values of total and amount 35 | for _ in range(1000): # Simulate 1000 random tests 36 | total = random.uniform(0, 1000) 37 | amount = random.uniform(1, 1000) # Ensure amount > 0 38 | result = calculatePercentage(total, amount) 39 | assert 0 <= result <= 100 40 | ``` 41 | 42 | [END RESPONSE] 43 | [END EXAMPLE] 44 | 45 | [GUIDANCE] 46 | - Take into consideration the types of the functions under test (implied or explicit). E.g., if a function expects a non-NaN input, don't feed it a NaN. 47 | - Try to only have one assert per PBT. 48 | - If you can't figure out how to correctly generate something don't test it. Just skip that test. 49 | - Only write valid Python. E.g., if you want to use await make sure the function is async. 50 | - Pretty please I'm begging you don't use regex's. 51 | - If you're going to generate inputs that cause the code to throw an exception, and you know it, use a try/catch to assert that the exception happens, or modify the generated value to avoid the 52 | exception. E.g., if you use datetime.fromtimestamp(...) with an enormous input value, it'll throw an OverflowError. This is expected and you should either change how you generate the inputs to 53 | avoid it, or use a try/catch to make sure it happens where we expect it to happen. 54 | - Use some input validation in the PBTs to avoid inputs that are uninteresting / unreasonable. 55 | - Don't use "self". 
These are standalone PBTs, not part of some test class. 56 | - Write the tests such that they can be pasted at the end of the file containing the code under test. 57 | - When generating floats, carefully consider how to set the options allow_nan and allow_infinity, keeping in mind that e.g., in Python NaN != NaN. 58 | 59 | 60 | [DOCUMENTATION] 61 | 62 | {documentation} 63 | 64 | [GUIDANCE] 65 | - Take into consideration the types of the functions under test (implied or explicit). E.g., if a function expects a non-NaN input, don't feed it a NaN. 66 | - Try to only have one assert per PBT. 67 | - If you can't figure out how to correctly generate something don't test it. Just skip that test. 68 | - Only write valid Python. E.g., if you want to use await make sure the function is async. 69 | - Pretty please I'm begging you don't use regex's. 70 | - If you're going to generate inputs that cause the code to throw an exception, and you know it, use a try/catch to assert that the exception happens, or modify the generated value to avoid the 71 | exception. E.g., if you use datetime.fromtimestamp(...) with an enormous input value, it'll throw an OverflowError. This is expected and you should either change how you generate the inputs to 72 | avoid it, or use a try/catch to make sure it happens where we expect it to happen. 73 | - Use some input validation in the PBTs to avoid inputs that are uninteresting / unreasonable. 74 | - Don't use "self". These are standalone PBTs, not part of some test class. 75 | - Write the tests such that they can be pasted at the end of the file containing the code under test. 76 | - When generating floats, carefully consider how to set the options allow_nan and allow_infinity, keeping in mind that e.g., in Python NaN != NaN. 77 | 78 | 79 | [CODE] 80 | 81 | ```python 82 | {code} 83 | ``` 84 | 85 | [BEGIN RESPONSE] -------------------------------------------------------------------------------- /src/prompts/pytest_docs.txt: -------------------------------------------------------------------------------- 1 | API Reference 2 | This page contains the full reference to pytest’s API. 3 | 4 | Constants 5 | pytest.__version__ 6 | The current pytest version, as a string: 7 | 8 | import pytest 9 | pytest.__version__ 10 | '7.0.0' 11 | pytest.version_tuple 12 | Added in version 7.0. 13 | 14 | The current pytest version, as a tuple: 15 | 16 | import pytest 17 | pytest.version_tuple 18 | (7, 0, 0) 19 | For pre-releases, the last component will be a string with the prerelease version: -------------------------------------------------------------------------------- /src/prompts/user_questions.txt: -------------------------------------------------------------------------------- 1 | How do I optimize the performance of a Python program that processes large CSV files? 2 | 3 | Can you explain the difference between a correlation and a causation in simple terms? 4 | 5 | What are the key differences between renewable and non-renewable energy sources? 6 | 7 | Can you help me draft a polite email to reschedule a meeting for next week? 8 | 9 | Why is my Python script throwing a 'KeyError' when accessing a dictionary? 10 | 11 | Can you suggest some ideas for a short story set on a distant planet? 12 | 13 | How do I calculate the median of a dataset in Python using numpy? 
-------------------------------------------------------------------------------- /src/regex_interventions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Regex Interventions Module 3 | 4 | This module implements testing and evaluation of model interventions specifically 5 | focused on regex pattern usage in generated code. It provides functionality to: 6 | 1. Run controlled experiments with different intervention types 7 | 2. Measure regex usage in generated code 8 | 3. Evaluate code quality and correctness 9 | 4. Compare baseline and intervention results 10 | """ 11 | 12 | import torch 13 | from transformers import AutoTokenizer 14 | from dataclasses import asdict 15 | from tqdm import tqdm 16 | from typing import Callable, Dict, List, Tuple, Union 17 | import json 18 | import time 19 | import asyncio 20 | import re 21 | import sys 22 | import os 23 | import logging 24 | 25 | logging.basicConfig(level=logging.INFO) 26 | 27 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 28 | 29 | from src.wrapper import InterventionWrapper 30 | from src.eval_config import EvalConfig, InterventionType 31 | import src.utils as utils 32 | import src.caa as caa 33 | from src.agent_eval import batch_evaluate_generations, summarize_evaluations 34 | 35 | 36 | def measure_regex_usage( 37 | responses: list[str], 38 | prompt: str, 39 | prefill: str 40 | ) -> tuple[int, int, int, int]: 41 | """Analyze generated code for regex usage and validity. 42 | 43 | Args: 44 | responses: List of generated text responses 45 | prompt: Original prompt used for generation 46 | prefill: Prefix text to remove from responses 47 | 48 | Returns: 49 | Tuple containing counts of: 50 | - Valid Python code snippets 51 | - Snippets using regex 52 | - Syntactically valid snippets 53 | - Semantically valid snippets 54 | """ 55 | valid_python_count = 0 56 | syntactically_valid_python_count = 0 57 | semantically_valid_python_count = 0 58 | regex_usage_count = 0 59 | 60 | for i, response in tqdm(enumerate(responses), desc="Measuring regex usage"): 61 | response = re.sub(r'(? list[str]: 88 | """Clean and extract actual responses from model outputs. 89 | 90 | Args: 91 | responses: Raw model outputs 92 | prompt: Original prompt to remove 93 | prefill: Prefix text to remove 94 | 95 | Returns: 96 | List of cleaned response texts 97 | """ 98 | extracted_responses = [] 99 | for i, response in enumerate(responses): 100 | response = re.sub(r'(? list[str]: 114 | """Generate multiple responses without interventions. 
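    The prompt is duplicated batch_size times per batch and total_generations // batch_size
    identical batches are sampled, so any remainder (total_generations % batch_size) is dropped.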
115 | 116 | Args: 117 | wrapper: Model wrapper instance 118 | prompt: Input prompt 119 | batch_size: Number of generations per batch 120 | total_generations: Total number of responses to generate 121 | max_new_tokens: Maximum new tokens per generation 122 | 123 | Returns: 124 | List of generated responses 125 | """ 126 | batched_prompts = [prompt] * batch_size 127 | num_batches = total_generations // batch_size 128 | generations = [] 129 | 130 | for _ in tqdm(range(num_batches), desc="Generating responses"): 131 | response = wrapper.generate(batched_prompts, max_new_tokens=max_new_tokens) 132 | generations.extend(response) 133 | 134 | return generations 135 | 136 | 137 | def run_intervention( 138 | wrapper: InterventionWrapper, 139 | prompt: str, 140 | batch_size: int, 141 | total_generations: int, 142 | max_new_tokens: int, 143 | intervention_type: str, 144 | model_params: dict, 145 | scale: int, 146 | config: EvalConfig, 147 | ) -> list[str]: 148 | """Generate responses with specified intervention. 149 | 150 | Args: 151 | wrapper: Model wrapper instance 152 | prompt: Input prompt 153 | batch_size: Number of generations per batch 154 | total_generations: Total number of responses to generate 155 | max_new_tokens: Maximum new tokens per generation 156 | intervention_type: Type of intervention to apply 157 | model_params: Parameters for the intervention 158 | scale: Scale factor for intervention 159 | config: Evaluation configuration 160 | 161 | Returns: 162 | List of generated responses with intervention applied 163 | """ 164 | batched_prompts = [prompt] * batch_size 165 | num_batches = total_generations // batch_size 166 | generations = [] 167 | 168 | module_and_hook_fn = wrapper.get_hook(intervention_type, model_params, scale, config) 169 | 170 | for _ in tqdm(range(num_batches), desc="Generating responses"): 171 | response = wrapper.generate( 172 | batched_prompts, max_new_tokens=max_new_tokens, module_and_hook_fn=module_and_hook_fn 173 | ) 174 | generations.extend(response) 175 | 176 | return generations 177 | 178 | 179 | def test_single_prompt( 180 | wrapper: InterventionWrapper, 181 | base_prompt: str, 182 | code_example: str, 183 | config: EvalConfig, 184 | model_params: dict, 185 | intervention_type: str, 186 | api_key: str, 187 | ) -> dict: 188 | """Test a single prompt with and without interventions. 
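    For each code example the flow is: generate baseline completions, score them with
    measure_regex_usage (and, when config.use_llm_judge is set, with the agent_eval LLM judge),
    then repeat for every scale in config.scales with the intervention hook obtained from
    wrapper.get_hook attached during generation.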
189 | 190 | Args: 191 | wrapper: Model wrapper instance 192 | base_prompt: Base prompt template 193 | code_example: Code example to insert in prompt 194 | config: Evaluation configuration 195 | model_params: Model-specific parameters 196 | intervention_type: Type of intervention to test 197 | api_key: API key for LLM judge (if used) 198 | 199 | Returns: 200 | Dictionary containing: 201 | - Original and intervention generations 202 | - Evaluation results 203 | - Agent evaluations (if enabled) 204 | - Result summaries 205 | """ 206 | results = {"generations": {}, "eval_results": {}, "agent_evals": {}, "agent_summaries": {}} 207 | 208 | # Format prompt for this code example 209 | prompt = base_prompt.replace("{code}", code_example) 210 | formatted_prompt = utils.format_llm_prompt(prompt, wrapper.tokenizer) 211 | formatted_prompt += config.prefill 212 | 213 | # Generate without interventions 214 | original_texts = run_generation( 215 | wrapper, 216 | formatted_prompt, 217 | config.batch_size, 218 | config.total_generations, 219 | config.max_new_tokens, 220 | ) 221 | original_texts = extract_response(original_texts, formatted_prompt, config.prefill) 222 | results["generations"]["original"] = original_texts 223 | results["eval_results"]["original"] = measure_regex_usage( 224 | original_texts, formatted_prompt, config.prefill 225 | ) 226 | logging.info(f"Original eval results: {results['eval_results']['original']}") 227 | if config.use_llm_judge: 228 | # Add agent evaluations for original texts 229 | agent_evals = batch_evaluate_generations(original_texts, prompt, api_key) 230 | results["agent_evals"]["original"] = [asdict(eval) for eval in agent_evals] 231 | results["agent_summaries"]["original"] = summarize_evaluations(agent_evals) 232 | 233 | # Generate with different intervention scales 234 | for scale in tqdm(config.scales, desc="Interventions"): 235 | modified_texts = run_intervention( 236 | wrapper, 237 | formatted_prompt, 238 | config.batch_size, 239 | config.total_generations, 240 | config.max_new_tokens, 241 | intervention_type, 242 | model_params, 243 | scale, 244 | config, 245 | ) 246 | 247 | modified_texts = extract_response(modified_texts, formatted_prompt, config.prefill) 248 | results["generations"][f"intervention_{scale}"] = modified_texts 249 | results["eval_results"][f"intervention_{scale}"] = measure_regex_usage( 250 | modified_texts, formatted_prompt, config.prefill 251 | ) 252 | logging.info(f"Intervention {scale} eval results: {results['eval_results'][f'intervention_{scale}']}") 253 | if config.use_llm_judge: 254 | # Add agent evaluations for interventions 255 | agent_evals = batch_evaluate_generations(modified_texts, prompt, api_key) 256 | results["agent_evals"][f"intervention_{scale}"] = [asdict(eval) for eval in agent_evals] 257 | results["agent_summaries"][f"intervention_{scale}"] = summarize_evaluations(agent_evals) 258 | 259 | results["prompt"] = formatted_prompt 260 | return results 261 | 262 | 263 | def test_sae_interventions(api_key: str) -> dict: 264 | """Run comprehensive intervention tests across multiple code examples. 
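    The returned dictionary holds the serialized config under "config" and, for each entry in
    config.intervention_types, a nested mapping
    results[<intervention_type>]["code_results"][<code_block_key>] containing the per-prompt
    output of test_single_prompt; results are re-saved to config.save_path after every code block.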
265 | 266 | Args: 267 | api_key: API key for LLM judge evaluations 268 | 269 | Returns: 270 | Dictionary containing all test results and configurations 271 | 272 | Note: 273 | Results are saved after each code block to prevent data loss 274 | """ 275 | config = EvalConfig() 276 | results = {"config": asdict(config)} 277 | 278 | # Setup 279 | device = "cuda" if torch.cuda.is_available() else "cpu" 280 | model_params = utils.get_model_params(config.model_name) 281 | 282 | # Initialize wrapper 283 | wrapper = InterventionWrapper(model_name=config.model_name, device=device, dtype=torch.bfloat16) 284 | 285 | # Load SAE 286 | wrapper.load_sae(release=model_params["sae_release"], sae_id=model_params["sae_id"], layer_idx=model_params["targ_layer"]) 287 | 288 | # Load and format prompt 289 | base_prompt = utils.load_prompt_files(config) 290 | 291 | with open(f"{config.prompt_folder}/{config.code_filename}", "r") as f: 292 | code_blocks = json.load(f) 293 | 294 | print(f"Evaluating {len(code_blocks)} code blocks") 295 | 296 | print(f"Evaluating the following interventions: {config.intervention_types}") 297 | 298 | for intervention_type in config.intervention_types: 299 | print(f"Evaluating {intervention_type}!") 300 | results[intervention_type] = {"code_results": {}} 301 | for code_block_key, single_code_block in code_blocks.items(): 302 | results[intervention_type]["code_results"][code_block_key] = test_single_prompt( 303 | wrapper, 304 | base_prompt, 305 | single_code_block, 306 | config, 307 | model_params, 308 | intervention_type, 309 | api_key, 310 | ) 311 | 312 | # Save results after each code block 313 | with open(config.save_path, "w") as f: 314 | json.dump(results, f, indent=4) 315 | 316 | return results 317 | 318 | 319 | if __name__ == "__main__": 320 | """ 321 | Main entry point for running intervention tests. 322 | 323 | Environment variables: 324 | PYTORCH_CUDA_ALLOC_CONF: Set to "expandable_segments:True" 325 | TOKENIZERS_PARALLELISM: Set to "false" for process safety 326 | OPENAI_API_KEY: Optional API key for LLM judge 327 | """ 328 | import os 329 | 330 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" 331 | 332 | # We disable this because we launch additional processes when checking for valid code 333 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 334 | 335 | import argparse 336 | 337 | torch.set_grad_enabled(False) 338 | 339 | parser = argparse.ArgumentParser() 340 | parser.add_argument( 341 | "--api_key", type=str, help="OpenAI API key", default=os.environ.get("OPENAI_API_KEY") 342 | ) 343 | args = parser.parse_args() 344 | 345 | start_time = time.time() 346 | run_results = test_sae_interventions(args.api_key) 347 | print(f"Total time: {time.time() - start_time:.2f} seconds") 348 | 349 | run_filename = "run_results.json" 350 | with open(run_filename, "w") as f: 351 | json.dump(run_results, f, indent=4) 352 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import ast 3 | from typing import Optional, Tuple 4 | import builtins 5 | from contextlib import contextmanager 6 | import sys 7 | from io import StringIO 8 | from transformers import AutoTokenizer 9 | import random 10 | from multiprocessing import Process, Manager 11 | 12 | 13 | from src.eval_config import EvalConfig 14 | 15 | 16 | def get_model_params(model_name: str) -> dict[str, str | int]: 17 | """Get model-specific parameters for interventions. 
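    Example (values taken from the mapping below):
        >>> params = get_model_params("google/gemma-2-2b-it")
        >>> params["targ_layer"], params["feature_idx"]
        (8, 931)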
18 | 19 | Args: 20 | model_name: Name/path of the model to get parameters for 21 | 22 | Returns: 23 | Dictionary containing: 24 | - sae_release: SAE model release identifier 25 | - sae_id: Specific SAE identifier (if applicable) 26 | - targ_layer: Target layer for intervention 27 | - feature_idx: Feature index in SAE 28 | 29 | Raises: 30 | ValueError: If model is not supported 31 | """ 32 | if model_name == "google/gemma-2-9b-it": 33 | return { 34 | "sae_release": "gemma-scope-9b-it-res", 35 | "sae_id": "layer_9/width_16k/average_l0_88", 36 | "targ_layer": 9, 37 | "feature_idx": 3585, 38 | "secondary_feature_idx": 12650, 39 | } 40 | elif model_name == "meta-llama/Llama-3.1-8B-Instruct": 41 | return { 42 | "sae_release": "tilde-research/sieve_coding", 43 | "sae_id": None, 44 | "targ_layer": 12, # 8 45 | "feature_idx": 9853, # 9699 46 | } 47 | elif model_name == "google/gemma-2-2b-it": 48 | return { 49 | "sae_release": "gemma-scope-2b-pt-res", 50 | "sae_id": "layer_8/width_16k/average_l0_71", 51 | "targ_layer": 8, 52 | "feature_idx": 931, 53 | } 54 | else: 55 | raise ValueError(f"Unsupported model: {model_name}") 56 | 57 | 58 | def format_llm_prompt(prompt: str, tokenizer: AutoTokenizer) -> str: 59 | """Format the prompt according to Transformers instruction format.""" 60 | chat = [ 61 | {"role": "user", "content": prompt}, 62 | ] 63 | return tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) 64 | 65 | 66 | def load_prompt_files(config: EvalConfig) -> str: 67 | """Load and process prompt files.""" 68 | files: dict[str, str] = { 69 | "prompt": f"{config.prompt_folder}/{config.prompt_filename}", 70 | "docs": f"{config.prompt_folder}/{config.docs_filename}", 71 | } 72 | 73 | content: dict[str, str] = {} 74 | for key, filename in files.items(): 75 | with open(filename, "r") as f: 76 | content[key] = f.read() 77 | 78 | # Format the prompt 79 | prompt = content["prompt"].replace("{documentation}", content["docs"]) 80 | 81 | return prompt 82 | 83 | 84 | def extract_python(response: str, verbose: bool = True) -> Optional[str]: 85 | # Regex pattern to match python block 86 | pattern = r"```python\s*(.*?)\s*```" 87 | 88 | # Search for the pattern 89 | match = re.search(pattern, response, re.DOTALL) 90 | 91 | if match: 92 | python_str = match.group(1) 93 | return python_str 94 | else: 95 | if verbose: 96 | print("WARNING: No python block found") 97 | return None 98 | 99 | 100 | def check_for_re_usage(code_snippet: str) -> bool: 101 | """ 102 | Checks if any re module function is used in a given code snippet. 103 | 104 | This is pretty basic and may not catch all cases. 105 | """ 106 | # Define a pattern that matches common re module functions 107 | pattern = r"\bre\.(match|search|sub|findall|finditer|split|compile|fullmatch|escape|subn)\b" 108 | 109 | # Search for any of these patterns in the code snippet 110 | return bool(re.search(pattern, code_snippet)) 111 | 112 | 113 | def print_generations(generations: list[str], prompt: str, prefill: str) -> None: 114 | """Print the generated texts, removing the prompt prefix.""" 115 | for i, generation in enumerate(generations): 116 | if prompt in generation: 117 | generation = generation[len(prompt) - len(prefill) :] 118 | print(f"Generation {i}:") 119 | print(generation) 120 | print() 121 | 122 | 123 | def is_syntactically_valid_python(code: str) -> Tuple[bool, Optional[str]]: 124 | """ 125 | Check if a string contains syntactically valid Python code. 
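    Example (illustrative):
        >>> is_syntactically_valid_python("def f(x): return x + 1")
        (True, None)
        >>> is_syntactically_valid_python("def f(:")[0]
        False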
126 | 127 | Args: 128 | code: String containing Python code to validate 129 | 130 | Returns: 131 | Tuple of (is_valid: bool, error_message: Optional[str]) 132 | """ 133 | try: 134 | ast.parse(code) 135 | return True, None 136 | except SyntaxError as e: 137 | return False, f"Syntax error: {str(e)}" 138 | except Exception as e: 139 | return False, f"Parsing error: {str(e)}" 140 | 141 | 142 | @contextmanager 143 | def restricted_compile_environment(): 144 | """ 145 | Context manager that provides a restricted environment for code compilation. 146 | Temporarily replaces stdout/stderr and restricts builtins to common safe operations. 147 | """ 148 | # Save original stdout/stderr and builtins 149 | original_stdout = sys.stdout 150 | original_stderr = sys.stderr 151 | original_builtins = dict(builtins.__dict__) 152 | 153 | # Create string buffers for capturing output 154 | temp_stdout = StringIO() 155 | temp_stderr = StringIO() 156 | 157 | # Define safe exception types 158 | safe_exceptions = { 159 | "Exception": Exception, 160 | "ValueError": ValueError, 161 | "TypeError": TypeError, 162 | "AttributeError": AttributeError, 163 | "IndexError": IndexError, 164 | "KeyError": KeyError, 165 | "RuntimeError": RuntimeError, 166 | "StopIteration": StopIteration, 167 | "AssertionError": AssertionError, 168 | "NotImplementedError": NotImplementedError, 169 | "ZeroDivisionError": ZeroDivisionError, 170 | } 171 | 172 | # Expanded set of safe builtins 173 | safe_builtins = { 174 | # Constants 175 | "None": None, 176 | "False": False, 177 | "True": True, 178 | # Basic types and operations 179 | "abs": abs, 180 | "bool": bool, 181 | "int": int, 182 | "float": float, 183 | "str": str, 184 | "len": len, 185 | "type": type, 186 | "repr": repr, 187 | # Collections 188 | "list": list, 189 | "dict": dict, 190 | "set": set, 191 | "tuple": tuple, 192 | "frozenset": frozenset, 193 | "range": range, 194 | "enumerate": enumerate, 195 | "zip": zip, 196 | "reversed": reversed, 197 | # Type checking 198 | "isinstance": isinstance, 199 | "issubclass": issubclass, 200 | "hasattr": hasattr, 201 | "getattr": getattr, 202 | # Math operations 203 | "min": min, 204 | "max": max, 205 | "sum": sum, 206 | "round": round, 207 | "pow": pow, 208 | # String operations 209 | "chr": chr, 210 | "ord": ord, 211 | "format": format, 212 | # Itertools functions 213 | "filter": filter, 214 | "map": map, 215 | # Other safe operations 216 | "print": print, # Captured by StringIO 217 | "sorted": sorted, 218 | "any": any, 219 | "all": all, 220 | "iter": iter, 221 | "next": next, 222 | "slice": slice, 223 | "property": property, 224 | "staticmethod": staticmethod, 225 | "classmethod": classmethod, 226 | # Exception handling 227 | "try": "try", 228 | "except": "except", 229 | "finally": "finally", 230 | **safe_exceptions, # Add all safe exception types 231 | } 232 | 233 | try: 234 | # Replace stdout/stderr 235 | sys.stdout = temp_stdout 236 | sys.stderr = temp_stderr 237 | 238 | # Restrict builtins 239 | for key in list(builtins.__dict__.keys()): 240 | if key not in safe_builtins: 241 | del builtins.__dict__[key] 242 | 243 | # Add exception types to the builtins 244 | builtins.__dict__.update(safe_exceptions) 245 | 246 | yield temp_stdout, temp_stderr 247 | 248 | finally: 249 | # Restore original environment 250 | sys.stdout = original_stdout 251 | sys.stderr = original_stderr 252 | builtins.__dict__.clear() 253 | builtins.__dict__.update(original_builtins) 254 | 255 | 256 | def is_semantically_valid_python(code: str) -> Tuple[bool, Optional[str]]: 
257 | """ 258 | Check if a string contains semantically valid Python code by: 259 | 1. Checking syntax 260 | 2. Verifying it contains actual code structure 261 | 3. Attempting to compile and validate basic execution 262 | 263 | Args: 264 | code: String containing Python code to validate 265 | 266 | Returns: 267 | Tuple of (is_valid: bool, error_message: Optional[str]) 268 | """ 269 | # First check syntax 270 | syntax_valid, syntax_error = is_syntactically_valid_python(code) 271 | if not syntax_valid: 272 | return False, syntax_error 273 | 274 | # Basic content validation 275 | code = code.strip() 276 | if not code: 277 | return False, "Empty code string" 278 | 279 | # Check for basic code structure (must have at least one function or class definition) 280 | if not any(keyword in code for keyword in ["def ", "class "]): 281 | return False, "No function or class definitions found" 282 | 283 | # Check for excessive non-ASCII characters that aren't in strings/comments 284 | code_lines = code.split("\n") 285 | invalid_lines = 0 286 | total_lines = len(code_lines) 287 | 288 | for line in code_lines: 289 | line = line.strip() 290 | # Skip empty lines and comments 291 | if not line or line.startswith("#"): 292 | total_lines -= 1 293 | continue 294 | 295 | # Count lines with too many non-ASCII characters 296 | non_ascii_count = len([c for c in line if ord(c) > 127]) 297 | if non_ascii_count / len(line) > 0.3: # More than 30% non-ASCII 298 | invalid_lines += 1 299 | 300 | # If more than 20% of non-empty, non-comment lines are invalid 301 | if total_lines > 0 and (invalid_lines / total_lines) > 0.2: 302 | return False, "Code contains too many non-ASCII characters" 303 | 304 | try: 305 | # Try to compile the code 306 | try: 307 | compiled_code = compile(code, "", "exec") 308 | except Exception as e: 309 | return False, f"Compilation error: {str(e)}" 310 | 311 | # Create a restricted globals dict with common built-ins 312 | # restricted_globals = { 313 | # '__builtins__': { 314 | # name: getattr(builtins, name) 315 | # for name in [ 316 | # 'len', 'int', 'str', 'list', 'dict', 'set', 'tuple', 317 | # 'min', 'max', 'True', 'False', 'None', 'type', 318 | # 'isinstance', 'print', 'range', 'compile', 'exec', 319 | # "import", 320 | # ] 321 | # } 322 | # } 323 | 324 | # Try to execute in the restricted environment 325 | try: 326 | exec(compiled_code, {}, {}) 327 | except Exception as e: 328 | # Some errors are acceptable for valid code 329 | error_str = str(e) 330 | acceptable_errors = [ 331 | "name 'pytest' is not defined", 332 | "name 're' is not defined", 333 | "name 'random' is not defined", 334 | "name 'time' is not defined", 335 | "name 'asyncio' is not defined", 336 | "name 'typing' is not defined", 337 | "name 'Optional' is not defined", 338 | "name 'List' is not defined", 339 | "name 'Dict' is not defined", 340 | "name 'Any' is not defined", 341 | "name 'Union' is not defined", 342 | ] 343 | if not any(err in error_str for err in acceptable_errors): 344 | return False, f"Execution error: {error_str}" 345 | 346 | return True, None 347 | 348 | except Exception as e: 349 | return False, f"Validation error: {str(e)}" 350 | 351 | 352 | def get_func_name(func_code: str) -> str: 353 | match = re.search(r"def\s+([a-zA-Z_][a-zA-Z0-9_]*)", func_code) 354 | if match: 355 | func_name = match.group(1) 356 | else: 357 | return None 358 | return func_name 359 | 360 | 361 | def validate_llm_response( 362 | func_code: str, llm_code: str, timeout: float = 3.0, verbose: bool = False 363 | ) -> bool: 364 | """ 365 | 
Validates whether the LLM-generated code runs, calls the specified function, 366 | and doesn't raise errors within the given timeout. 367 | """ 368 | 369 | func_name = get_func_name(func_code) 370 | 371 | def exec_code(function_called): 372 | # List of dangerous modules to block 373 | dangerous_modules = { 374 | "os", 375 | "subprocess", 376 | "sys", 377 | "socket", 378 | "requests", 379 | "urllib", 380 | "ftplib", 381 | "telnetlib", 382 | "smtplib", 383 | "pathlib", 384 | "shutil", 385 | } 386 | 387 | def safe_import(name, *args, **kwargs): 388 | if name in dangerous_modules: 389 | raise ImportError(f"Import of {name} is not allowed for security reasons") 390 | return __import__(name, *args, **kwargs) 391 | 392 | # Create safe globals with all built-ins 393 | safe_globals = {} 394 | for name in dir(builtins): 395 | safe_globals[name] = getattr(builtins, name) 396 | 397 | safe_globals["__import__"] = safe_import 398 | safe_globals["__builtins__"] = safe_globals 399 | 400 | # Add commonly needed modules 401 | safe_globals.update( 402 | { 403 | "re": re, 404 | "random": random, 405 | "__name__": "__main__", 406 | } 407 | ) 408 | 409 | # Execute the function definition 410 | try: 411 | exec(func_code, safe_globals) 412 | if func_name not in safe_globals: 413 | function_called["error"] = f"Function {func_name} was not properly defined." 414 | return 415 | except Exception as e: 416 | function_called["error"] = f"Error in function definition: {str(e)}" 417 | return 418 | 419 | # Store the original function and create wrapper 420 | original_func = safe_globals[func_name] 421 | 422 | def wrapper(*args, **kwargs): 423 | function_called["called"] = True 424 | return original_func(*args, **kwargs) 425 | 426 | safe_globals[func_name] = wrapper 427 | 428 | # Execute the test code 429 | try: 430 | exec(llm_code, safe_globals) 431 | except Exception as e: 432 | function_called["error"] = f"Error in test execution: {str(e)}" 433 | return 434 | 435 | # Shared dictionary for results 436 | manager = Manager() 437 | function_called = manager.dict() 438 | function_called["called"] = False 439 | 440 | # Run in separate process with timeout 441 | p = Process(target=exec_code, args=(function_called,)) 442 | p.start() 443 | p.join(timeout) 444 | 445 | if p.is_alive(): 446 | p.terminate() 447 | p.join() 448 | print("Execution timed out.") 449 | else: 450 | if "error" in function_called: 451 | if verbose: 452 | print(f"An error occurred: {function_called['error']}") 453 | elif not function_called["called"]: 454 | if verbose: 455 | print(f"{func_name}() was not called.") 456 | else: 457 | if verbose: 458 | print(f"{func_name}() was successfully called.") 459 | return True 460 | return False 461 | 462 | 463 | def validate_single_llm_response(prompt: str, response: str, verbose: bool = True) -> bool: 464 | # Regex pattern to match python block 465 | pattern = r"```python\s*(.*?)\s*```" 466 | 467 | # Search for the pattern 468 | matches = re.findall(pattern, prompt, re.DOTALL) 469 | 470 | original_code = matches[-1] 471 | 472 | llm_python = extract_python(response, verbose=False) 473 | 474 | if llm_python is None: 475 | return False 476 | 477 | func_name = get_func_name(llm_python) 478 | 479 | if func_name is None: 480 | return False 481 | 482 | llm_python = llm_python + f"\n\n{func_name}()" 483 | valid_code = validate_llm_response(original_code, llm_python) 484 | 485 | return valid_code 486 | 487 | 488 | def validate_all_llm_responses( 489 | data: dict, intervention_method: str, code_id: str, scale: int 490 | ) -> 
tuple[float, float]: 491 | prompt = data[intervention_method]["code_results"][code_id]["prompt"] 492 | 493 | # Regex pattern to match python block 494 | pattern = r"```python\s*(.*?)\s*```" 495 | 496 | # Search for the pattern 497 | matches = re.findall(pattern, prompt, re.DOTALL) 498 | 499 | original_code = matches[-1] 500 | 501 | total = 0 502 | valid = 0 503 | syntactically_valid = 0 504 | 505 | for response in data[intervention_method]["code_results"][code_id]["generations"][scale]: 506 | total += 1 507 | 508 | llm_python = extract_python(response, verbose=False) 509 | 510 | if llm_python is None: 511 | continue 512 | 513 | func_name = get_func_name(llm_python) 514 | 515 | if func_name is None: 516 | continue 517 | 518 | syntactically_valid += is_syntactically_valid_python(llm_python)[0] 519 | 520 | llm_python = llm_python + f"\n\n{func_name}()" 521 | valid_code = validate_llm_response(original_code, llm_python) 522 | 523 | if valid_code: 524 | valid += 1 525 | 526 | return (valid / total), (syntactically_valid / total) 527 | -------------------------------------------------------------------------------- /src/wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import einops 3 | from transformers import AutoModelForCausalLM, AutoTokenizer 4 | from sae_lens import SAE 5 | from typing import Callable, Optional, Union, List, cast 6 | from jaxtyping import Float 7 | from torch import Tensor 8 | import contextlib 9 | 10 | from src.eval_config import EvalConfig, InterventionType 11 | import src.caa as caa 12 | from sae.sae import Sae 13 | 14 | try: 15 | import flash_attn 16 | USE_FA = True 17 | print("Flash attention installed") 18 | except ImportError: 19 | print("Flash attention not installed, using regular attention") 20 | USE_FA = False 21 | 22 | 23 | @contextlib.contextmanager 24 | def add_hook( 25 | module: torch.nn.Module, 26 | hook: Callable, 27 | ): 28 | """Temporarily adds a forward hook to a model module. 29 | 30 | Args: 31 | module: The PyTorch module to hook 32 | hook: The hook function to apply 33 | 34 | Yields: 35 | None: Used as a context manager 36 | 37 | Example: 38 | with add_hook(model.layer, hook_fn): 39 | output = model(input) 40 | """ 41 | handle = module.register_forward_hook(hook) 42 | try: 43 | yield 44 | finally: 45 | handle.remove() 46 | 47 | 48 | def get_activation_addition_output_hook( 49 | vectors: list[Float[Tensor, "d_model"]], coeffs: list[float] 50 | ) -> Callable: 51 | """Creates a hook function that adds scaled vectors to layer activations. 52 | 53 | This hook performs a simple activation steering by adding scaled vectors 54 | to the layer's output activations. This is the most basic form of intervention. 
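    Example (illustrative sketch; assumes `wrapper` is an InterventionWrapper,
    `steering_vector` is a d_model tensor, `prompts` were built with
    utils.format_llm_prompt so they contain the BOS token, and layer 8 / scale 8.0
    are placeholder choices):

        hook_fn = get_activation_addition_output_hook([steering_vector], [8.0])
        module = wrapper.model.model.layers[8]
        texts = wrapper.generate(prompts, max_new_tokens=64,
                                 module_and_hook_fn=(module, hook_fn))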
55 | 56 | Args: 57 | vectors: List of vectors to add, each of shape (d_model,) 58 | coeffs: List of scaling coefficients for each vector 59 | 60 | Returns: 61 | Hook function that modifies layer activations 62 | 63 | """ 64 | 65 | def hook_fn(module, input, output): 66 | if isinstance(output, tuple): 67 | resid_BLD = output[0] 68 | rest = output[1:] 69 | else: 70 | resid_BLD = output 71 | rest = () 72 | 73 | for vector, coeff in zip(vectors, coeffs): 74 | vector = vector.to(resid_BLD.device) 75 | resid_BLD = resid_BLD + coeff * vector 76 | 77 | if rest: 78 | return (resid_BLD, *rest) 79 | else: 80 | return resid_BLD 81 | 82 | return hook_fn 83 | 84 | 85 | def get_conditional_per_input_hook( 86 | encoder_vectors: list[Float[Tensor, "d_model"]], 87 | decoder_vectors: list[Float[Tensor, "d_model"]], 88 | scales: list[float], 89 | encoder_thresholds: list[float], 90 | ) -> Callable: 91 | """Creates a hook function that conditionally applies interventions based on input-level activation. 92 | 93 | This hook checks if any token in the input sequence triggers the encoder vector 94 | above threshold. If triggered, applies the intervention to the entire sequence. 95 | 96 | Args: 97 | encoder_vectors: List of vectors used to detect activation patterns 98 | decoder_vectors: List of vectors to add when conditions are met 99 | scales: Scaling factors for decoder vectors 100 | encoder_thresholds: Threshold values for each encoder vector 101 | 102 | Returns: 103 | Hook function that conditionally modifies activations 104 | 105 | Note: 106 | - Zeros out BOS token activations to prevent false triggers 107 | - Intervention applies to entire sequence if any token triggers 108 | """ 109 | 110 | def hook_fn(module, input, output): 111 | if isinstance(output, tuple): 112 | resid_BLD: Float[Tensor, "batch_size seq_len d_model"] = output[0] 113 | rest = output[1:] 114 | else: 115 | resid_BLD: Float[Tensor, "batch_size seq_len d_model"] = output 116 | rest = () 117 | 118 | B, L, D = resid_BLD.shape 119 | 120 | for encoder_vector_D, decoder_vector_D, coeff, encoder_threshold in zip( 121 | encoder_vectors, decoder_vectors, scales, encoder_thresholds 122 | ): 123 | encoder_vector_D = encoder_vector_D.to(resid_BLD.device) 124 | decoder_vector_D = decoder_vector_D.to(resid_BLD.device) 125 | 126 | feature_acts_BL = torch.einsum("BLD,D->BL", resid_BLD, encoder_vector_D) 127 | feature_acts_BL[:, 0] = 0 # zero out the BOS token 128 | intervention_threshold_B11 = ((feature_acts_BL > encoder_threshold).any(dim=1).float())[ 129 | :, None, None 130 | ] 131 | decoder_BLD = einops.repeat(decoder_vector_D * coeff, "D -> B L D", B=B, L=L).to( 132 | dtype=resid_BLD.dtype 133 | ) 134 | 135 | resid_BLD += decoder_BLD * intervention_threshold_B11 136 | 137 | if rest: 138 | return (resid_BLD, *rest) 139 | else: 140 | return resid_BLD 141 | 142 | return hook_fn 143 | 144 | 145 | def get_conditional_per_token_hook( 146 | encoder_vectors: list[Float[Tensor, "d_model"]], 147 | decoder_vectors: list[Float[Tensor, "d_model"]], 148 | scales: list[float], 149 | encoder_thresholds: list[float], 150 | ) -> Callable: 151 | """Creates a hook function that conditionally applies interventions per token. 152 | 153 | Unlike the per-input hook, this applies interventions independently to each token 154 | based on whether it exceeds the encoder threshold. 
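    Concretely, for each position t the hook computes a_t = <h_t, w_enc> and, wherever
    a_t > encoder_threshold, adds scale * w_dec to h_t; all other positions are left unchanged.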
155 | 156 | Args: 157 | encoder_vectors: List of vectors used to detect activation patterns 158 | decoder_vectors: List of vectors to add when conditions are met 159 | scales: Scaling factors for decoder vectors 160 | encoder_thresholds: Threshold values for each encoder vector 161 | 162 | Returns: 163 | Hook function that modifies activations on a per-token basis 164 | 165 | Note: 166 | More granular than per-input hook as it can selectively modify 167 | specific tokens in the sequence 168 | """ 169 | 170 | def hook_fn(module, input, output): 171 | if isinstance(output, tuple): 172 | resid_BLD: Float[Tensor, "batch_size seq_len d_model"] = output[0] 173 | rest = output[1:] 174 | else: 175 | resid_BLD: Float[Tensor, "batch_size seq_len d_model"] = output 176 | rest = () 177 | 178 | B, L, D = resid_BLD.shape 179 | 180 | for encoder_vector_D, decoder_vector_D, coeff, encoder_threshold in zip( 181 | encoder_vectors, decoder_vectors, scales, encoder_thresholds 182 | ): 183 | encoder_vector_D = encoder_vector_D.to(resid_BLD.device) 184 | decoder_vector_D = decoder_vector_D.to(resid_BLD.device) 185 | 186 | feature_acts_BL = torch.einsum("BLD,D->BL", resid_BLD, encoder_vector_D) 187 | intervention_mask_BL = feature_acts_BL > encoder_threshold 188 | decoder_BLD = einops.repeat(decoder_vector_D * coeff, "D -> B L D", B=B, L=L).to( 189 | dtype=resid_BLD.dtype 190 | ) 191 | 192 | resid_BLD = torch.where( 193 | intervention_mask_BL.unsqueeze(-1), 194 | resid_BLD + decoder_BLD, 195 | resid_BLD, 196 | ) 197 | 198 | if rest: 199 | return (resid_BLD, *rest) 200 | else: 201 | return resid_BLD 202 | 203 | return hook_fn 204 | 205 | 206 | def get_clamping_hook( 207 | encoder_vectors: list[Float[Tensor, "d_model"]], 208 | decoder_vectors: list[Float[Tensor, "d_model"]], 209 | scales: list[float], 210 | ) -> Callable: 211 | """Creates a hook function that clamps activations using decoder vectors. 212 | 213 | This hook fixes the activations to a target value in the decoder vector direction. 214 | 215 | Args: 216 | encoder_vectors: List of vectors defining directions to clamp 217 | decoder_vectors: List of vectors defining intervention directions 218 | scales: Target values for clamping (acts as offset after zeroing) 219 | 220 | Returns: 221 | Hook function that clamps and redirects activations 222 | 223 | Note: 224 | Useful for always having a final activation value in the decoder vector direction. 
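        Concretely, wherever the encoder readout a_t = <h_t, w_enc> is positive, the hook
        adds (scale - a_t) * w_dec to h_t; positions with a_t <= 0 are left unchanged.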
225 | """ 226 | 227 | # coeff = -feature_acts_BL 228 | def hook_fn(module, input, output): 229 | if isinstance(output, tuple): 230 | resid_BLD: Float[Tensor, "batch_size seq_len d_model"] = output[0] 231 | rest = output[1:] 232 | else: 233 | resid_BLD: Float[Tensor, "batch_size seq_len d_model"] = output 234 | rest = () 235 | B, L, D = resid_BLD.shape 236 | 237 | for encoder_vector_D, decoder_vector_D, coeff in zip( 238 | encoder_vectors, decoder_vectors, scales 239 | ): 240 | encoder_vector_D = encoder_vector_D.to(resid_BLD.device) 241 | decoder_vector_D = decoder_vector_D.to(resid_BLD.device) 242 | feature_acts_BL = torch.einsum("BLD,D->BL", resid_BLD, encoder_vector_D) 243 | decoder_BLD = (-feature_acts_BL[:, :, None] + coeff) * decoder_vector_D[None, None, :] 244 | resid_BLD = torch.where( 245 | feature_acts_BL[:, :, None] > 0, 246 | resid_BLD + decoder_BLD, 247 | resid_BLD, 248 | ) 249 | 250 | if rest: 251 | return (resid_BLD, *rest) 252 | else: 253 | return resid_BLD 254 | 255 | return hook_fn 256 | 257 | 258 | def get_conditional_clamping_hook( 259 | encoder_vectors: list[Float[Tensor, "d_model"]], 260 | decoder_vectors: list[Float[Tensor, "d_model"]], 261 | scales: list[float], 262 | encoder_thresholds: list[float], 263 | ) -> Callable: 264 | """Creates a hook function that conditionally clamps activations. 265 | 266 | Combines conditional intervention with clamping - only clamps activations 267 | when they exceed the encoder threshold with the decoder intervention. 268 | 269 | Args: 270 | encoder_vectors: List of vectors defining directions to monitor 271 | decoder_vectors: List of vectors defining intervention directions 272 | scales: Target values for clamping 273 | encoder_thresholds: Threshold values that trigger clamping 274 | 275 | Returns: 276 | Hook function that conditionally clamps and modifies activations 277 | 278 | Note: 279 | Most sophisticated intervention type, combining benefits of 280 | conditional application and activation clamping 281 | """ 282 | 283 | def hook_fn(module, input, output): 284 | if isinstance(output, tuple): 285 | resid_BLD: Float[Tensor, "batch_size seq_len d_model"] = output[0] 286 | rest = output[1:] 287 | else: 288 | resid_BLD: Float[Tensor, "batch_size seq_len d_model"] = output 289 | rest = () 290 | 291 | B, L, D = resid_BLD.shape 292 | 293 | for encoder_vector_D, decoder_vector_D, coeff, encoder_threshold in zip( 294 | encoder_vectors, decoder_vectors, scales, encoder_thresholds 295 | ): 296 | encoder_vector_D = encoder_vector_D.to(resid_BLD.device) 297 | decoder_vector_D = decoder_vector_D.to(resid_BLD.device) 298 | 299 | # Get encoder activations 300 | feature_acts_BL = torch.einsum("BLD,D->BL", resid_BLD, encoder_vector_D) 301 | 302 | # Create mask for where encoder activation exceeds threshold 303 | intervention_mask_BL = feature_acts_BL > encoder_threshold 304 | 305 | # Calculate clamping amount only where mask is True 306 | decoder_BLD = (-feature_acts_BL[:, :, None] + coeff) * decoder_vector_D[None, None, :] 307 | 308 | # Apply clamping only where both mask is True and activation is positive 309 | resid_BLD = torch.where( 310 | (intervention_mask_BL[:, :, None] & (feature_acts_BL[:, :, None] > 0)), 311 | resid_BLD + decoder_BLD, 312 | resid_BLD, 313 | ) 314 | 315 | if rest: 316 | return (resid_BLD, *rest) 317 | else: 318 | return resid_BLD 319 | 320 | return hook_fn 321 | 322 | 323 | class InterventionWrapper: 324 | """Wrapper class for applying interventions to language models. 
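    The typical call sequence (as used in src/regex_interventions.py) is load_sae(), then
    get_hook() to build an intervention, then generate() with the returned (module, hook_fn) pair.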
325 | 
326 |     This class manages model loading, intervention application, and generation
327 |     with various steering techniques.
328 | 
329 |     Attributes:
330 |         model: The underlying language model
331 |         tokenizer: Tokenizer for the model
332 |         sae: Optional Sparse Autoencoder for intervention
333 |         device: Device to run computations on
334 |         caa_steering_vector: Cached steering vector for interventions
335 |         probe_vector: Cached probe vector (with probe_bias) for probe-based interventions
336 |     """
337 | 
338 |     def __init__(self, model_name: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16):
339 |         """Initialize the wrapper with a specified model.
340 | 
341 |         Args:
342 |             model_name: HuggingFace model identifier
343 |             device: Computing device ('cuda' or 'cpu')
344 |             dtype: Data type for model weights
345 |         """
346 | 
347 |         self.model = AutoModelForCausalLM.from_pretrained(
348 |             model_name,
349 |             device_map="auto",  # Automatically place the model on the available device(s)
350 |             torch_dtype=dtype,
351 |             attn_implementation="flash_attention_2" if USE_FA else "eager",
352 |         )
353 |         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
354 |         self.model_name = model_name
355 |         self.interventions = {}
356 |         self.sae: Optional[Union[SAE, Sae]] = None
357 |         self.device = device
358 |         self.caa_steering_vector: Optional[Tensor] = None
359 |         self.probe_vector: Optional[Tensor] = None
360 |         self.probe_bias: Optional[Tensor] = None
361 | 
362 |     def generate(
363 |         self,
364 |         batched_prompts: list[str],
365 |         max_new_tokens: int,
366 |         temperature: float = 1.0,
367 |         module_and_hook_fn: Optional[tuple[torch.nn.Module, Callable]] = None,
368 |     ) -> list[str]:
369 |         """Generate text with optional interventions.
370 | 
371 |         Args:
372 |             batched_prompts: List of input prompts
373 |             max_new_tokens: Maximum number of tokens to generate
374 |             temperature: Sampling temperature
375 |             module_and_hook_fn: Optional tuple of (module, hook_function) for intervention
376 | 
377 |         Returns:
378 |             List of generated text strings
379 | 
380 |         Note:
381 |             Prompts must contain the model's BOS token
382 |         """
383 |         assert all(
384 |             self.tokenizer.bos_token in prompt for prompt in batched_prompts
385 |         ), "All prompts must contain the BOS token."
386 |         batched_tokens = self.tokenizer(
387 |             batched_prompts, add_special_tokens=False, return_tensors="pt"
388 |         ).to(self.device)
389 |         batched_tokens = batched_tokens["input_ids"]
390 | 
391 |         if module_and_hook_fn:
392 |             module, hook_fn = module_and_hook_fn
393 |             context_manager = add_hook(
394 |                 module=module,
395 |                 hook=hook_fn,
396 |             )
397 |             with context_manager:
398 |                 generated_toks = self.model.generate(
399 |                     batched_tokens,
400 |                     max_new_tokens=max_new_tokens,
401 |                     do_sample=True,
402 |                     temperature=temperature,
403 |                 )
404 |         else:
405 |             generated_toks = self.model.generate(
406 |                 batched_tokens,
407 |                 max_new_tokens=max_new_tokens,
408 |                 do_sample=True,
409 |                 temperature=temperature,
410 |             )
411 | 
412 |         response = [self.tokenizer.decode(tokens) for tokens in generated_toks]
413 | 
414 |         return response
415 | 
416 |     def get_hook(
417 |         self,
418 |         intervention_type: str,
419 |         model_params: dict,
420 |         scale: Union[int, float],
421 |         config: EvalConfig,
422 |     ) -> tuple[torch.nn.Module, Callable]:
423 |         """Create a hook function for the specified intervention type.
424 | 
425 |         Args:
426 |             intervention_type: Type of intervention to apply
427 |             model_params: Parameters for the intervention including:
428 |                 - targ_layer: Target layer index
429 |                 - feature_idx: Feature index for SAE
430 |             scale: Scaling factor for the intervention
431 |             config: Configuration for evaluation and steering
432 | 
433 |         Returns:
434 |             Tuple of (target_module, hook_function)
435 | 
436 |         Raises:
437 |             AttributeError: If SAE is required but not loaded
438 |             ValueError: If intervention type is not supported
439 | 
440 |         Note:
441 |             Different intervention types require different preconditions:
442 |             - SAE interventions require loaded SAE
443 |             - Steering vector interventions calculate vectors on first use
444 |             - Probe interventions train probes on first use
445 |         """
446 |         module = self.model.model.layers[model_params["targ_layer"]]
447 | 
448 |         # Convert scale to float for type compatibility
449 |         scales: List[float] = [float(scale)]
450 | 
451 |         if self.sae is None:
452 |             raise AttributeError("SAE must be loaded before getting hook")
453 | 
454 |         # Select the encoder/decoder vectors for the target SAE feature
455 |         encoder_vectors = [cast(Tensor, self.sae.W_enc[:, [model_params["feature_idx"]]].squeeze())]
456 |         decoder_vectors = [cast(Tensor, self.sae.W_dec[[model_params["feature_idx"]]].squeeze())]
457 | 
458 |         # Normalize decoder vectors
459 |         decoder_vectors = [v / v.norm() for v in decoder_vectors]
460 |         encoder_thresholds = [2.0] if "gemma" in self.model_name else [5.0]  # TODO: make this an arg; llama 8B features use a higher threshold
461 |         print(f"Encoder thresholds: {encoder_thresholds}")
462 |         if hasattr(self.sae, "threshold"):  # Check for jumprelu
463 |             threshold_val = float(self.sae.threshold[model_params["feature_idx"]])
464 |             encoder_thresholds = [threshold_val]
465 | 
466 |         # for i, threshold in enumerate(encoder_thresholds):
467 |         #     encoder_thresholds[i] = threshold + config.encoder_threshold_bias
468 | 
469 |         # Initialize steering vectors if needed
470 |         steering_vectors: List[Tensor] = []
471 |         if intervention_type in [
472 |             InterventionType.CONSTANT_STEERING_VECTOR.value,
473 |             InterventionType.CONDITIONAL_STEERING_VECTOR.value,
474 |             InterventionType.SAE_STEERING_VECTOR.value,
475 |             InterventionType.PROBE_STEERING_VECTOR.value,
476 |             InterventionType.PROBE_STEERING_VECTOR_CLAMPING.value,
477 |         ]:
478 |             if self.caa_steering_vector is None:
479 |                 self.caa_steering_vector = caa.calculate_steering_vector(
480 |                     config.prompt_folder,
481 |                     config.contrastive_prompts_filename,
482 |                     config.code_filename,
483 |                     self.model,
484 |                     self.tokenizer,
485 |                     config.prompt_type,
486 |                     model_params["targ_layer"],
487 |                 )
488 | 
489 |             # TODO: Support multiple steering vectors
490 |             steering_vector_threshold = (
491 |                 caa.get_threshold(
492 |                     config,
493 |                     model_params,
494 |                     self,
495 |                     self.caa_steering_vector,
496 |                     encoder_vectors[0],
497 |                     encoder_thresholds[0],
498 |                 )
499 |                 + config.steering_vector_threshold_bias
500 |             )
501 | 
502 |             steering_vector_thresholds = [steering_vector_threshold]
503 |             encoder_steering_vectors = [self.caa_steering_vector]
504 |             steering_vectors = [self.caa_steering_vector]
505 |             steering_vectors = [v / v.norm() for v in steering_vectors]
506 | 
507 |         if intervention_type in [
508 |             InterventionType.PROBE_SAE.value,
509 |             InterventionType.PROBE_SAE_CLAMPING.value,
510 |             InterventionType.PROBE_STEERING_VECTOR.value,
511 |             InterventionType.PROBE_STEERING_VECTOR_CLAMPING.value,
512 |         ]:
513 |             if self.probe_vector is None:
514 |                 self.probe_vector, self.probe_bias = caa.calculate_probe_vector(
515 |                     config.prompt_folder,
516 |                     config.probe_prompts_filename,
517 |                     config.code_filename,
518 |                     self.model,
519 |                     self.tokenizer,
520 |                     config.prompt_type,
521 |                     model_params["targ_layer"],
522 |                 )
523 |             probe_vector_threshold = [self.probe_bias]
524 | 
525 |         if intervention_type == InterventionType.CONSTANT_SAE.value:
526 |             hook_fn = get_activation_addition_output_hook(decoder_vectors, scales)
527 |         elif intervention_type == InterventionType.CONSTANT_STEERING_VECTOR.value:
528 |             hook_fn = get_activation_addition_output_hook(steering_vectors, scales)
529 |         elif intervention_type == InterventionType.PROBE_STEERING_VECTOR.value:
530 |             hook_fn = get_conditional_per_token_hook(
531 |                 [self.probe_vector], steering_vectors, scales, probe_vector_threshold
532 |             )
533 |         elif intervention_type == InterventionType.PROBE_SAE.value:
534 |             hook_fn = get_conditional_per_token_hook(
535 |                 [self.probe_vector], decoder_vectors, scales, probe_vector_threshold
536 |             )
537 |         elif intervention_type == InterventionType.PROBE_SAE_CLAMPING.value:
538 |             hook_fn = get_conditional_clamping_hook(
539 |                 [self.probe_vector], decoder_vectors, scales, probe_vector_threshold
540 |             )
541 |         elif intervention_type == InterventionType.PROBE_STEERING_VECTOR_CLAMPING.value:
542 |             hook_fn = get_conditional_clamping_hook(
543 |                 [self.probe_vector], steering_vectors, scales, probe_vector_threshold
544 |             )
545 |         elif intervention_type == InterventionType.SAE_STEERING_VECTOR.value:
546 |             hook_fn = get_conditional_per_token_hook(
547 |                 encoder_vectors, steering_vectors, scales, encoder_thresholds
548 |             )
549 |         elif intervention_type == InterventionType.CONDITIONAL_PER_INPUT.value:
550 |             hook_fn = get_conditional_per_input_hook(
551 |                 encoder_vectors, decoder_vectors, scales, encoder_thresholds
552 |             )
553 |         elif intervention_type == InterventionType.CONDITIONAL_PER_TOKEN.value:
554 |             hook_fn = get_conditional_per_token_hook(
555 |                 encoder_vectors, decoder_vectors, scales, encoder_thresholds
556 |             )
557 |         elif intervention_type == InterventionType.CONDITIONAL_STEERING_VECTOR.value:
558 |             hook_fn = get_conditional_per_token_hook(
559 |                 encoder_steering_vectors, steering_vectors, scales, steering_vector_thresholds
560 |             )
561 |         elif intervention_type == InterventionType.CLAMPING.value:
562 |             hook_fn = get_clamping_hook(encoder_vectors, decoder_vectors, scales)
563 |         elif intervention_type == InterventionType.CONDITIONAL_CLAMPING.value:
564 |             hook_fn = get_conditional_clamping_hook(
565 |                 encoder_vectors, decoder_vectors, scales, encoder_thresholds
566 |             )
567 |         else:
568 |             raise ValueError(f"Unsupported intervention type: {intervention_type}")
569 |         return module, hook_fn
570 | 
571 |     def load_sae(self, release: str, sae_id: str, layer_idx: int):
572 |         """Load a Sparse Autoencoder for interventions.
573 | 
574 |         Args:
575 |             release: Release identifier for the SAE
576 |             sae_id: Specific SAE identifier
577 |             layer_idx: Layer index the SAE was trained on
578 | 
579 |         Note:
580 |             Supports both tilde-research and standard SAE formats
581 |         """
582 |         if "tilde-research" in release:
583 |             self.sae = Sae.from_pretrained(release, layer_idx=layer_idx)
584 |             self.sae = self.sae.to(dtype=self.model.dtype, device=self.model.device)
585 |             return
586 |         self.sae, _, _ = SAE.from_pretrained(
587 |             release=release, sae_id=sae_id, device=str(self.model.device)
588 |         )
589 |         self.sae = self.sae.to(dtype=self.model.dtype, device=self.model.device)
590 | 
--------------------------------------------------------------------------------
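
Usage sketch (editor's note): the snippet below is a minimal illustration of how the wrapper above is meant to be driven end to end: load a model, attach an SAE, build a conditional per-token hook for one SAE feature, and generate with the hook attached. It assumes this file is src/wrapper.py and that InterventionType and EvalConfig are importable as shown; the model name, SAE release and id, layer index, feature index, and scale are illustrative assumptions rather than values taken from this repository.

# Hypothetical usage sketch; identifiers marked below are assumptions, not repo values.
from src.eval_config import EvalConfig                      # assumed import path
from src.wrapper import InterventionWrapper, InterventionType  # assumed import path

# Load the base model (illustrative model name).
wrapper = InterventionWrapper("google/gemma-2-9b-it", device="cuda")

# Attach an SAE for the target layer (release/sae_id/layer are illustrative).
wrapper.load_sae(
    release="gemma-scope-9b-it-res-canonical",
    sae_id="layer_20/width_16k/canonical",
    layer_idx=20,
)

config = EvalConfig()  # assumes the default config fields are sufficient for this path

# Build a conditional per-token hook: add the (scaled) decoder direction of one SAE
# feature only at token positions where that feature's encoder activation exceeds
# its threshold.
module, hook_fn = wrapper.get_hook(
    intervention_type=InterventionType.CONDITIONAL_PER_TOKEN.value,
    model_params={"targ_layer": 20, "feature_idx": 1234},  # illustrative indices
    scale=8.0,
    config=config,
)

# Prompts must contain the BOS token (enforced by InterventionWrapper.generate).
prompt = wrapper.tokenizer.bos_token + "Write a function that validates an email address."
outputs = wrapper.generate(
    [prompt],
    max_new_tokens=64,
    module_and_hook_fn=(module, hook_fn),
)
print(outputs[0])

Because get_hook caches the CAA steering vector and probe vector on the wrapper the first time they are needed, repeated calls with different scales or intervention types reuse those cached vectors rather than recomputing them.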