├── .gitignore
├── 2b_probes_data
│ └── combine.ipynb
├── README.md
├── example.py
├── mmlu_evals.py
├── plots
│ ├── 2b_experiments_mmlu.pdf
│ ├── 2b_experiments_mmlu.png
│ ├── 2b_experiments_semantics.pdf
│ ├── 2b_experiments_semantics.png
│ ├── 2b_experiments_syntax.pdf
│ ├── 2b_experiments_syntax.png
│ ├── 9b_long_hard_mmlu.pdf
│ ├── 9b_long_hard_mmlu.png
│ ├── 9b_long_hard_semantics.pdf
│ ├── 9b_long_hard_semantics.png
│ ├── 9b_long_hard_syntax.pdf
│ ├── 9b_long_hard_syntax.png
│ ├── 9b_pr_curves.png
│ ├── 9b_short_hard_mmlu.pdf
│ ├── 9b_short_hard_mmlu.png
│ ├── 9b_short_hard_semantics.pdf
│ ├── 9b_short_hard_semantics.png
│ ├── 9b_short_hard_syntax.pdf
│ ├── 9b_short_hard_syntax.png
│ ├── 9b_short_normal_mmlu.pdf
│ ├── 9b_short_normal_mmlu.png
│ ├── 9b_short_normal_semantics.pdf
│ ├── 9b_short_normal_semantics.png
│ ├── 9b_short_normal_syntax.pdf
│ ├── 9b_short_normal_syntax.png
│ ├── final_charts
│ │ ├── gemma-2-2b.png
│ │ ├── gemma-2-9b.png
│ │ ├── llama-3.1-8b.png
│ │ └── pareto_frontier.png
│ ├── llama_layer12_mmlu.pdf
│ ├── llama_layer12_mmlu.png
│ ├── llama_layer12_semantics.pdf
│ ├── llama_layer12_semantics.png
│ ├── llama_layer12_syntax.pdf
│ ├── llama_layer12_syntax.png
│ ├── llama_layer8_mmlu.pdf
│ ├── llama_layer8_mmlu.png
│ ├── llama_layer8_semantics.pdf
│ ├── llama_layer8_semantics.png
│ ├── llama_layer8_syntax.pdf
│ └── llama_layer8_syntax.png
├── pyproject.toml
├── sae
│ ├── kernels.py
│ ├── sae.py
│ └── utils.py
├── src
│ ├── agent_eval.py
│ ├── caa.py
│ ├── count_activations.py
│ ├── eval_config.py
│ ├── prompts
│ │ ├── activation_counting_code.json
│ │ ├── code.json
│ │ ├── code.txt
│ │ ├── code_hard.txt
│ │ ├── contrastive_prompts.json
│ │ ├── longer_pytest_docs.txt
│ │ ├── make_json_prompts.ipynb
│ │ ├── not_regex_code.txt
│ │ ├── probe_prompts.json
│ │ ├── prompt.txt
│ │ ├── prompt_hard.txt
│ │ ├── prompt_medium.txt
│ │ ├── prompt_no_regex.txt
│ │ ├── prompt_repetitions.txt
│ │ ├── pytest_docs.txt
│ │ └── user_questions.txt
│ ├── regex_interventions.py
│ ├── utils.py
│ └── wrapper.py
├── view_activations.ipynb
└── view_run_results.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | .venv
2 | *.pkl
3 | *.pt
4 | __pycache__/
5 |
6 |
7 |
8 | # Distribution / packaging
9 | dist/
10 | build/
11 | *.egg-info/
12 | *.egg
13 |
14 | # Jupyter Notebook
15 | .ipynb_checkpoints
16 |
17 | # IDE specific files
18 | .idea/
19 | .vscode/
20 | *.swp
21 | *.swo
22 | .DS_Store
23 |
24 | # Unit test / coverage reports
25 | htmlcov/
26 | .tox/
27 | .coverage
28 | .coverage.*
29 | .cache
30 | coverage.xml
31 | *.cover
32 | .pytest_cache/
33 |
34 | # Logs
35 | *.log
36 |
37 | # Local development settings
38 | .env
39 | .env.local
--------------------------------------------------------------------------------
/2b_probes_data/combine.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 4,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stdout",
10 | "output_type": "stream",
11 | "text": [
12 | "dict_keys(['config', 'constant_sae', 'constant_steering_vector', 'conditional_per_token', 'sae_steering_vector'])\n",
13 | "dict_keys(['config', 'probe_steering_vector', 'probe_sae', 'conditional_clamping', 'conditional_steering_vector'])\n",
14 | "dict_keys(['config', 'probe_sae_clamping', 'probe_steering_vector_clamping'])\n"
15 | ]
16 | }
17 | ],
18 | "source": [
19 | "import pickle\n",
20 | "\n",
21 | "filename1 = \"results_mmlu_evals_baseline.pkl\"\n",
22 | "filename2 = \"results_mmlu_evals.pkl\"\n",
23 |     "# filename3 = \"probe_only_results_mmlu_evals.pkl\"  # superseded by the next line\n",
24 | "filename3 = \"probe_clamping_results_mmlu_evals.pkl\"\n",
25 | "\n",
26 | "with open(filename1, \"rb\") as f:\n",
27 | " results1 = pickle.load(f)\n",
28 | "\n",
29 | "with open(filename2, \"rb\") as f:\n",
30 | " results2 = pickle.load(f)\n",
31 | "\n",
32 | "with open(filename3, \"rb\") as f:\n",
33 | " results3 = pickle.load(f)\n",
34 | "\n",
35 | "print(results1.keys())\n",
36 | "print(results2.keys())\n",
37 | "print(results3.keys())\n",
38 | "\n",
39 | "for key in results2:\n",
40 | " if key != \"config\":\n",
41 | " results1[key] = results2[key]\n",
42 | "\n",
43 | "with open(\"results_mmlu_evals_combined.pkl\", \"wb\") as f:\n",
44 | " pickle.dump(results1, f)"
45 | ]
46 | }
47 | ],
48 | "metadata": {
49 | "kernelspec": {
50 | "display_name": ".venv",
51 | "language": "python",
52 | "name": "python3"
53 | },
54 | "language_info": {
55 | "codemirror_mode": {
56 | "name": "ipython",
57 | "version": 3
58 | },
59 | "file_extension": ".py",
60 | "mimetype": "text/x-python",
61 | "name": "python",
62 | "nbconvert_exporter": "python",
63 | "pygments_lexer": "ipython3",
64 | "version": "3.11.8"
65 | }
66 | },
67 | "nbformat": 4,
68 | "nbformat_minor": 2
69 | }
70 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Sieve: Steering for Fine-grained Regex Control
2 |
3 | Sieve is a simple framework for applying targeted interventions to language models using Sparse Autoencoder (SAE) techniques. It enables precise control over model behavior through feature-level manipulations.
4 |
5 | Read the full case study on Sieve's application and results on our blog [here](https://www.tilderesearch.com/blog/sieve).
6 |
7 | ## Results
8 |
9 | Our experiments demonstrate Pareto dominance of SAE-based methods on fine-grained control. Below are the Pareto frontiers across models.
10 |
11 |
12 |
13 | ## Quick Start
14 |
15 | You will need to install flash attention separately; without it, the model falls back to the eager attention implementation.
16 |
17 | ```bash
18 | # Install flash attention
19 | pip install wheel
20 | pip install flash-attn --no-build-isolation
21 | ```
22 |
23 | ```bash
24 | # Install package
25 | pip install -e .
26 |
27 | # Run regex interventions example
28 | python src/regex_interventions.py
29 | ```
30 |
31 | Use `pip install -e .` rather than `pip install -r requirements.txt`, as it simplifies using imports and makes it easier to add tests.
32 |
33 | Running `python src/regex_interventions.py` uses the settings in `src/eval_config.py`.
34 |
35 | This will:
36 |
37 | 1. Load a pre-trained model and SAE
38 | 2. Run baseline code generation
39 | 3. Apply various interventions to control regex usage
40 | 4. Evaluate and compare results
41 |
42 | Running the full sweep directly reproduces the reported results.
43 |
44 | ## Runtime & Parallelization
45 |
46 | This is a parameter sweep across multiple intervention methods and scales, which can take significant time to run. With 200 generations per evaluation method, each method takes approximately 8 hours to complete on a single GPU. Probe methods take longer, as they require generating data and then training probes.
47 |
48 | However, the evaluation can be parallelized by taking the following steps:
49 |
50 | 1. Split the intervention methods across multiple GPUs
51 | 2. Run separate jobs with different subsets of `intervention_types` in `eval_config.py`
52 | 3. Save results to different output files (e.g., `results_gpu1.json`, `results_gpu2.json`)
53 |
54 | The visualization notebooks (`view_run_results.ipynb`) are designed to handle multiple result files, allowing you to combine outputs from parallel runs for analysis (see the merge sketch below). You can also reduce the number of generations per method to speed up evaluation.
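
For example, a minimal merge of two result files, mirroring `2b_probes_data/combine.ipynb` (which merges pickled result dicts; the filenames below are placeholders, so adapt them to whatever your runs actually wrote):

```python
import pickle

# Placeholder filenames; each parallel run writes its own results file.
with open("results_gpu1.pkl", "rb") as f:
    results = pickle.load(f)
with open("results_gpu2.pkl", "rb") as f:
    extra = pickle.load(f)

# Copy over the per-intervention entries, keeping the first run's config.
for key, value in extra.items():
    if key != "config":
        results[key] = value

with open("results_combined.pkl", "wb") as f:
    pickle.dump(results, f)
```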
55 |
56 | MMLU evaluations, by contrast, are much shorter and complete in roughly 1-2 hours for all methods.
57 |
58 | ## Overview
59 |
60 | Sieve provides tools for:
61 |
62 | - Loading and applying pre-trained SAEs
63 | - Performing targeted feature interventions
64 | - Analyzing activation patterns
65 | - Evaluating intervention effects
66 |
67 | The framework supports multiple intervention types (a minimal sketch of the conditional gating follows the list):
68 |
69 | 1. **Constant SAE**: Direct feature manipulation
70 | 2. **Conditional SAE**: Activation-dependent interventions
71 | 3. **Contrastive Activation Addition (CAA)**: Steering vectors
72 | 4. **Probe-based**: Linear probes
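
The constant variants add a fixed steering vector at every token position, while the conditional variants gate the edit on the SAE feature's own activation. The snippet below is a minimal sketch of that per-token gating, not the actual implementation in `src/wrapper.py`; it assumes a residual stream of shape `(batch, seq, d_model)`, single encoder/decoder directions for one feature, and omits the encoder bias.

```python
import torch

def make_conditional_per_token_hook(w_enc_d, w_dec_d, scale, threshold=0.0):
    """Sketch of a conditional per-token intervention.

    w_enc_d / w_dec_d: (d_model,) encoder and decoder directions for one SAE feature.
    scale: steering strength (negative values suppress the behavior).
    """
    def hook(module, inputs, output):
        resid = output[0] if isinstance(output, tuple) else output
        # Per-token feature activation: projection onto the encoder direction.
        acts = resid @ w_enc_d                    # (batch, seq)
        gate = (acts > threshold).unsqueeze(-1)   # steer only tokens where the feature fires
        steered = resid + gate * scale * w_dec_d  # add the scaled decoder direction
        return (steered, *output[1:]) if isinstance(output, tuple) else steered
    return hook
```

In the framework itself, the equivalent hook comes from `wrapper.get_hook(...)` and is either passed to `generate` (as in `example.py`) or registered with `register_forward_hook` (as in `mmlu_evals.py`).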
73 |
74 | ### InterventionWrapper
75 |
76 | The main interface for applying interventions:
77 |
78 | ```python
79 | from src.wrapper import InterventionWrapper
80 |
81 | # Initialize wrapper
82 | wrapper = InterventionWrapper(model_name="google/gemma-2-2b-it")
83 |
84 | # Load SAE
85 | wrapper.load_sae(release="gemma-scope-2b-pt-res", sae_id="layer_8/width_16k/average_l0_71", layer_idx=8)
86 |
87 | # Generate with intervention
88 | output = wrapper.generate(prompt, intervention_type="CONDITIONAL_PER_TOKEN", scale=1.0)
89 | ```
90 |
91 | ### Evaluation Tools
92 |
93 | - `regex_interventions.py`: Test regex usage patterns (this is the main file for reproducing results)
94 | - `count_activations.py`: Analyze performance for different kinds of classifiers (SAE encoder vectors, CAA vectors, linear probes)
95 | - `agent_eval.py`: Evaluate generation quality
96 |
97 | ## Configuration
98 |
99 | Modify settings in `src/eval_config.py`:
100 |
101 | ```python
102 | class EvalConfig:
103 |     model_name: str = "google/gemma-2-2b-it"
104 | intervention_types: list = [
105 | "CONSTANT_SAE",
106 | "CONDITIONAL_PER_TOKEN",
107 | "CONDITIONAL_PER_INPUT",
108 | "CONSTANT_STEERING_VECTOR"
109 | ]
110 | ```
111 | The eval config also controls the prompt, intervention scales, base model, and related settings. To change which SAE features are used, however, you must modify `src/utils.py` to set the feature indices and layer. The layer 8 Llama feature is currently commented out.
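
For reference, the per-model parameters consumed by the wrapper form a dict roughly like the one below. The field names are taken from `example.py` and `mmlu_evals.py`, and the values mirror `example.py`; treat `src/utils.py` as authoritative for the exact structure.

```python
# Sketch of the per-model parameter dict (names from example.py / mmlu_evals.py).
model_params = {
    "sae_release": "tilde-research/sieve_coding",  # SAE release / HF repo
    "sae_id": None,                                # sub-identifier within the release
    "targ_layer": 12,                              # layer the SAE and hook attach to
    "feature_idx": 9853,                           # SAE feature index to steer on
}
```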
112 |
113 | ## Project Structure
114 |
115 | ```
116 | sieve/
117 | ├── src/
118 | │ ├── agent_eval.py # LLM-based evaluation tools
119 | │ ├── caa.py # Contrastive Activation Addition core
120 | │ ├── count_activations.py # Probe analysis tools
121 | │ ├── regex_interventions.py # Main entry point
122 | │ ├── wrapper.py # Core intervention framework
123 | │ ├── eval_config.py # Configuration settings
124 | │ └── utils.py # Helper functions
125 | │
126 | ├── mmlu_evals.py # MMLU benchmark evaluation script
127 | ├── view_activations.ipynb # Notebook for activation analysis
128 | ├── view_run_results.ipynb # Notebook for experiment results
129 | ├── pyproject.toml # Package configuration
130 | └── README.md
131 | ```
132 |
133 | ## Collaboration
134 |
135 | Sieve is a collaboration between Tilde and Adam Karvonen, leveraging:
136 |
137 | - GemmaScope pretrained SAEs [1]
138 | - Custom code-specific SAEs for Llama models
139 | - Advanced intervention techniques
140 |
141 | ## Citation
142 |
143 | If you use Sieve or any of its results in your research, please cite:
144 |
145 | ```bibtex
146 | @article{karvonen2024sieve,
147 | title={Sieve: SAEs Beat Baselines on a Real-World Task (A Code Generation Case Study)},
148 | author={Karvonen, Adam and Pai, Dhruv and Wang, Mason and Keigwin, Ben},
149 | journal={Tilde Research Blog},
150 | year={2024},
151 | month={12},
152 | url={https://www.tilderesearch.com/blog/sieve},
153 | note={Blog post}
154 | }
155 | ```
156 |
157 | For the GemmaScope pretrained SAEs, please cite:
158 |
159 | ```bibtex
160 | @misc{lieberum2024gemmascopeopensparse,
161 | title={Gemma Scope: Open Sparse Autoencoders Everywhere All At Once on Gemma 2},
162 | author={Tom Lieberum and Senthooran Rajamanoharan and Arthur Conmy and Lewis Smith and Nicolas Sonnerat and Vikrant Varma and János Kramár and Anca Dragan and Rohin Shah and Neel Nanda},
163 | year={2024},
164 | eprint={2408.05147},
165 | archivePrefix={arXiv},
166 | primaryClass={cs.LG},
167 | url={https://arxiv.org/abs/2408.05147},
168 | }
169 | ```
170 |
171 | [1] Lieberum et al., "Gemma Scope: Open Sparse Autoencoders Everywhere All At Once on Gemma 2", 2024
172 |
173 | ## License
174 |
175 | MIT License
176 |
177 | Copyright (c) 2024 Tilde Research, Inc. and Adam Karvonen
178 |
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from src.wrapper import InterventionWrapper
3 | from src.eval_config import EvalConfig, InterventionType
4 |
5 | # Step 1: Set up the device
6 | device = "cuda" if torch.cuda.is_available() else "cpu"
7 |
8 | # Step 2: Define model parameters
9 | MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
10 | LAYER = 12
11 | FEATURE_IDX = 9853
12 | SCALE = -8.0
13 |
14 | # Step 3: Create the InterventionWrapper
15 | wrapper = InterventionWrapper(MODEL_NAME, device=device)
16 |
17 | # Step 4: Load the SAE
18 | wrapper.load_sae("tilde-research/sieve_coding", sae_id=None, layer_idx=LAYER)
19 |
20 | # Step 5: Set up the intervention
21 | config = EvalConfig()
22 | model_params = {
23 | "targ_layer": LAYER,
24 | "feature_idx": FEATURE_IDX
25 | }
26 |
27 |
28 |
29 | # Step 6: Format input text using chat template
30 | input_text = "Write a python function using the re module to match a numerical substring in a string."
31 | chat = [{"role": "user", "content": input_text}]
32 | formatted_text = wrapper.tokenizer.apply_chat_template(
33 | chat,
34 | tokenize=False,
35 | add_generation_prompt=True
36 | )
37 |
38 | # Step 7: Generate text without intervention
39 | print("Generating without intervention...")
40 | generated_text_original = wrapper.generate(
41 | [formatted_text],
42 | max_new_tokens=200,
43 | temperature=0.2,
44 | module_and_hook_fn=None
45 | )[0]
46 |
47 | # Step 8: Generate text with intervention
48 | print("\nGenerating with intervention...")
49 | module_and_hook_fn = wrapper.get_hook(
50 | intervention_type=InterventionType.CONDITIONAL_PER_TOKEN.value,
51 | model_params=model_params,
52 | scale=SCALE,
53 | config=config
54 | )
55 |
56 | generated_text_intervened = wrapper.generate(
57 | [formatted_text],
58 | max_new_tokens=800,
59 | temperature=0.2,
60 | # repetition_penalty=1.15,
61 | module_and_hook_fn=module_and_hook_fn
62 | )[0]
63 |
64 | # Step 9: Print and compare the results
65 | print("Original generated text:")
66 | print(generated_text_original)
67 | print("\nGenerated text with intervention:")
68 | print(generated_text_intervened)
69 |
70 | # Optional: Calculate and print the difference in token length
71 | # original_tokens = len(wrapper.model.to_tokens(generated_text_original)[0])
72 | # intervened_tokens = len(wrapper.model.to_tokens(generated_text_intervened)[0])
73 | # print(f"\nOriginal generation length: {original_tokens} tokens")
74 | # print(f"Intervened generation length: {intervened_tokens} tokens")
75 | # print(f"Difference: {intervened_tokens - original_tokens} tokens")
76 |
--------------------------------------------------------------------------------
/mmlu_evals.py:
--------------------------------------------------------------------------------
1 | import lm_eval
2 | from lm_eval.models.huggingface import HFLM
3 | import torch
4 | import pickle
5 | from dataclasses import asdict
6 |
7 | from src.eval_config import EvalConfig
8 | from src.wrapper import InterventionWrapper
9 | import src.utils as utils
10 |
11 |
12 | if __name__ == "__main__":
13 | # Step 1: Set up the device
14 | torch.set_grad_enabled(False)
15 | device = "cuda" if torch.cuda.is_available() else "cpu"
16 |
17 | config = EvalConfig()
18 |
19 | model_params = utils.get_model_params(config.model_name)
20 |
21 | tasks = [
22 | "mmlu_high_school_statistics",
23 | "mmlu_high_school_computer_science",
24 | "mmlu_high_school_mathematics",
25 | "mmlu_high_school_physics",
26 | "mmlu_high_school_biology",
27 | ]
28 |
29 | # Step 3: Create the InterventionWrapper
30 | wrapper = InterventionWrapper(config.model_name, device=device, dtype=torch.bfloat16)
31 |
32 | # Step 4: Load the SAE
33 | wrapper.load_sae(release=model_params["sae_release"], sae_id=model_params["sae_id"], layer_idx=model_params["targ_layer"])
34 |
35 | results = {"config": asdict(config)}
36 |
37 | # indexes all tasks from the `lm_eval/tasks` subdirectory.
38 | # Alternatively, you can set `TaskManager(include_path="path/to/my/custom/task/configs")`
39 | # to include a set of tasks in a separate directory.
40 |
41 | lm_obj = HFLM(pretrained=wrapper.model)
42 |
43 | task_manager = lm_eval.tasks.TaskManager()
44 |
45 | for intervention_type in config.intervention_types:
46 | results[intervention_type] = {}
47 |
48 | results[intervention_type][0] = lm_eval.simple_evaluate(
49 | model=lm_obj,
50 | tasks=tasks,
51 | num_fewshot=0,
52 | task_manager=task_manager,
53 | )
54 |
55 | for scale in config.scales:
56 | module, hook_fn = wrapper.get_hook(intervention_type, model_params, scale, config)
57 |
58 | # NOTE: Make sure to remove the hook after using it.
59 | handle = module.register_forward_hook(hook_fn)
60 |
61 | # Setting `task_manager` to the one above is optional and should generally be done
62 | # if you want to include tasks from paths other than ones in `lm_eval/tasks`.
63 | # `simple_evaluate` will instantiate its own task_manager if it is set to None here.
64 |
65 | results[intervention_type][scale] = lm_eval.simple_evaluate( # call simple_evaluate
66 | model=lm_obj,
67 | tasks=tasks,
68 | num_fewshot=0,
69 | task_manager=task_manager,
70 | )
71 |
72 | with open("results_mmlu_evals.pkl", "wb") as f:
73 | pickle.dump(results, f)
74 |
75 | handle.remove()
76 |
--------------------------------------------------------------------------------
/plots/2b_experiments_mmlu.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/2b_experiments_mmlu.pdf
--------------------------------------------------------------------------------
/plots/2b_experiments_mmlu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/2b_experiments_mmlu.png
--------------------------------------------------------------------------------
/plots/2b_experiments_semantics.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/2b_experiments_semantics.pdf
--------------------------------------------------------------------------------
/plots/2b_experiments_semantics.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/2b_experiments_semantics.png
--------------------------------------------------------------------------------
/plots/2b_experiments_syntax.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/2b_experiments_syntax.pdf
--------------------------------------------------------------------------------
/plots/2b_experiments_syntax.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/2b_experiments_syntax.png
--------------------------------------------------------------------------------
/plots/9b_long_hard_mmlu.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_long_hard_mmlu.pdf
--------------------------------------------------------------------------------
/plots/9b_long_hard_mmlu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_long_hard_mmlu.png
--------------------------------------------------------------------------------
/plots/9b_long_hard_semantics.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_long_hard_semantics.pdf
--------------------------------------------------------------------------------
/plots/9b_long_hard_semantics.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_long_hard_semantics.png
--------------------------------------------------------------------------------
/plots/9b_long_hard_syntax.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_long_hard_syntax.pdf
--------------------------------------------------------------------------------
/plots/9b_long_hard_syntax.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_long_hard_syntax.png
--------------------------------------------------------------------------------
/plots/9b_pr_curves.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_pr_curves.png
--------------------------------------------------------------------------------
/plots/9b_short_hard_mmlu.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_short_hard_mmlu.pdf
--------------------------------------------------------------------------------
/plots/9b_short_hard_mmlu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_short_hard_mmlu.png
--------------------------------------------------------------------------------
/plots/9b_short_hard_semantics.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_short_hard_semantics.pdf
--------------------------------------------------------------------------------
/plots/9b_short_hard_semantics.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_short_hard_semantics.png
--------------------------------------------------------------------------------
/plots/9b_short_hard_syntax.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_short_hard_syntax.pdf
--------------------------------------------------------------------------------
/plots/9b_short_hard_syntax.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_short_hard_syntax.png
--------------------------------------------------------------------------------
/plots/9b_short_normal_mmlu.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_short_normal_mmlu.pdf
--------------------------------------------------------------------------------
/plots/9b_short_normal_mmlu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_short_normal_mmlu.png
--------------------------------------------------------------------------------
/plots/9b_short_normal_semantics.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_short_normal_semantics.pdf
--------------------------------------------------------------------------------
/plots/9b_short_normal_semantics.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_short_normal_semantics.png
--------------------------------------------------------------------------------
/plots/9b_short_normal_syntax.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_short_normal_syntax.pdf
--------------------------------------------------------------------------------
/plots/9b_short_normal_syntax.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/9b_short_normal_syntax.png
--------------------------------------------------------------------------------
/plots/final_charts/gemma-2-2b.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/final_charts/gemma-2-2b.png
--------------------------------------------------------------------------------
/plots/final_charts/gemma-2-9b.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/final_charts/gemma-2-9b.png
--------------------------------------------------------------------------------
/plots/final_charts/llama-3.1-8b.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/final_charts/llama-3.1-8b.png
--------------------------------------------------------------------------------
/plots/final_charts/pareto_frontier.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/final_charts/pareto_frontier.png
--------------------------------------------------------------------------------
/plots/llama_layer12_mmlu.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/llama_layer12_mmlu.pdf
--------------------------------------------------------------------------------
/plots/llama_layer12_mmlu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/llama_layer12_mmlu.png
--------------------------------------------------------------------------------
/plots/llama_layer12_semantics.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/llama_layer12_semantics.pdf
--------------------------------------------------------------------------------
/plots/llama_layer12_semantics.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/llama_layer12_semantics.png
--------------------------------------------------------------------------------
/plots/llama_layer12_syntax.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/llama_layer12_syntax.pdf
--------------------------------------------------------------------------------
/plots/llama_layer12_syntax.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/llama_layer12_syntax.png
--------------------------------------------------------------------------------
/plots/llama_layer8_mmlu.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/llama_layer8_mmlu.pdf
--------------------------------------------------------------------------------
/plots/llama_layer8_mmlu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/llama_layer8_mmlu.png
--------------------------------------------------------------------------------
/plots/llama_layer8_semantics.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/llama_layer8_semantics.pdf
--------------------------------------------------------------------------------
/plots/llama_layer8_semantics.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/llama_layer8_semantics.png
--------------------------------------------------------------------------------
/plots/llama_layer8_syntax.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/llama_layer8_syntax.pdf
--------------------------------------------------------------------------------
/plots/llama_layer8_syntax.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tilde-research/sieve/0f6c49f6f83b381c20ae09aecea754b77b8aa103/plots/llama_layer8_syntax.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 |
5 | [tool.hatch.build.targets.wheel]
6 | packages = ["."]
7 |
8 | [project]
9 | name = "sieve"
10 | version = "0.1.0"
11 | requires-python = ">=3.10"
12 | dependencies = [
13 | "torch>=2.4.1,<2.5.0",
14 | "transformers>=4.45.2",
15 | "sae_lens>=3.23.0",
16 | "openai>=1.52.1",
17 | "ipykernel",
18 | "lm_eval",
19 | "seaborn"
20 | ]
21 |
22 | [tool.pyright]
23 | typeCheckingMode = "standard"
24 |
--------------------------------------------------------------------------------
/sae/kernels.py:
--------------------------------------------------------------------------------
1 | """
2 | Copied from https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/kernels.py
3 | """
4 |
5 | import torch
6 | import triton
7 | import triton.language as tl
8 |
9 |
10 | def triton_sparse_transpose_dense_matmul(
11 | sparse_indices: torch.Tensor,
12 | sparse_values: torch.Tensor,
13 | dense: torch.Tensor,
14 | N: int,
15 | BLOCK_SIZE_AK=128,
16 | ) -> torch.Tensor:
17 | """
18 | calculates sparse.T @ dense (i.e reducing along the collated dimension of sparse)
19 | dense must be contiguous along dim 0 (in other words, dense.T is contiguous)
20 |
21 | sparse_indices is shape (A, k)
22 | sparse_values is shape (A, k)
23 | dense is shape (A, B)
24 |
25 | output is shape (N, B)
26 | """
27 |
28 | assert sparse_indices.shape == sparse_values.shape
29 | assert sparse_indices.is_contiguous()
30 | assert sparse_values.is_contiguous()
31 | assert dense.is_contiguous() # contiguous along B
32 |
33 | K = sparse_indices.shape[1]
34 | A = dense.shape[0]
35 | assert sparse_indices.shape[0] == A
36 |
37 | # COO-format and sorted
38 | sorted_indices = sparse_indices.view(-1).sort()
39 | coo_indices = torch.stack(
40 | [
41 | torch.arange(A, device=sparse_indices.device).repeat_interleave(K)[
42 | sorted_indices.indices
43 | ],
44 | sorted_indices.values,
45 | ]
46 | ) # shape (2, A * K)
47 | coo_values = sparse_values.view(-1)[sorted_indices.indices] # shape (A * K,)
48 | return triton_coo_sparse_dense_matmul(
49 | coo_indices, coo_values, dense, N, BLOCK_SIZE_AK
50 | )
51 |
52 |
53 | def triton_coo_sparse_dense_matmul(
54 | coo_indices: torch.Tensor,
55 | coo_values: torch.Tensor,
56 | dense: torch.Tensor,
57 | N: int,
58 | BLOCK_SIZE_AK=128,
59 | ) -> torch.Tensor:
60 | AK = coo_indices.shape[1]
61 | B = dense.shape[1]
62 |
63 | out = torch.zeros(N, B, device=dense.device, dtype=coo_values.dtype)
64 |
65 | def grid(META):
66 | return triton.cdiv(AK, META["BLOCK_SIZE_AK"]), 1
67 |
68 | triton_sparse_transpose_dense_matmul_kernel[grid](
69 | coo_indices,
70 | coo_values,
71 | dense,
72 | out,
73 | stride_da=dense.stride(0),
74 | stride_db=dense.stride(1),
75 | B=B,
76 | N=N,
77 | AK=AK,
78 | BLOCK_SIZE_AK=BLOCK_SIZE_AK,
79 | BLOCK_SIZE_B=triton.next_power_of_2(B),
80 | )
81 | return out
82 |
83 |
84 | @triton.jit
85 | def triton_sparse_transpose_dense_matmul_kernel(
86 | coo_indices_ptr,
87 | coo_values_ptr,
88 | dense_ptr,
89 | out_ptr,
90 | stride_da,
91 | stride_db,
92 | B,
93 | N,
94 | AK,
95 | BLOCK_SIZE_AK: tl.constexpr,
96 | BLOCK_SIZE_B: tl.constexpr,
97 | ):
98 | """
99 | coo_indices is shape (2, AK)
100 | coo_values is shape (AK,)
101 | dense is shape (A, B), contiguous along B
102 | out is shape (N, B)
103 | """
104 |
105 | pid_ak = tl.program_id(0)
106 | pid_b = tl.program_id(1)
107 |
108 | coo_offsets = tl.arange(0, BLOCK_SIZE_AK)
109 | b_offsets = tl.arange(0, BLOCK_SIZE_B)
110 |
111 | A_coords = tl.load(
112 | coo_indices_ptr + pid_ak * BLOCK_SIZE_AK + coo_offsets,
113 | mask=pid_ak * BLOCK_SIZE_AK + coo_offsets < AK,
114 | )
115 | K_coords = tl.load(
116 | coo_indices_ptr + pid_ak * BLOCK_SIZE_AK + coo_offsets + AK,
117 | mask=pid_ak * BLOCK_SIZE_AK + coo_offsets < AK,
118 | )
119 | values = tl.load(
120 | coo_values_ptr + pid_ak * BLOCK_SIZE_AK + coo_offsets,
121 | mask=pid_ak * BLOCK_SIZE_AK + coo_offsets < AK,
122 | )
123 |
124 | last_k = tl.min(K_coords)
125 | accum = tl.zeros((BLOCK_SIZE_B,), dtype=tl.float32)
126 |
127 | for ind in range(BLOCK_SIZE_AK):
128 | if ind + pid_ak * BLOCK_SIZE_AK < AK:
129 | # workaround to do A_coords[ind]
130 | a = tl.sum(
131 | tl.where(
132 | tl.arange(0, BLOCK_SIZE_AK) == ind,
133 | A_coords,
134 | tl.zeros((BLOCK_SIZE_AK,), dtype=tl.int64),
135 | )
136 | )
137 |
138 | k = tl.sum(
139 | tl.where(
140 | tl.arange(0, BLOCK_SIZE_AK) == ind,
141 | K_coords,
142 | tl.zeros((BLOCK_SIZE_AK,), dtype=tl.int64),
143 | )
144 | )
145 |
146 | v = tl.sum(
147 | tl.where(
148 | tl.arange(0, BLOCK_SIZE_AK) == ind,
149 | values,
150 | tl.zeros((BLOCK_SIZE_AK,), dtype=tl.float32),
151 | )
152 | )
153 |
154 | tl.device_assert(k < N)
155 |
156 | if k != last_k:
157 | tl.atomic_add(
158 | out_ptr + last_k * B + BLOCK_SIZE_B * pid_b + b_offsets,
159 | accum,
160 | mask=BLOCK_SIZE_B * pid_b + b_offsets < B,
161 | )
162 | accum *= 0
163 | last_k = k
164 |
165 | if v != 0:
166 | accum += v * tl.load(
167 | dense_ptr + a * stride_da + b_offsets, mask=b_offsets < B
168 | )
169 |
170 | tl.atomic_add(
171 | out_ptr + last_k * B + BLOCK_SIZE_B * pid_b + b_offsets,
172 | accum,
173 | mask=BLOCK_SIZE_B * pid_b + b_offsets < B,
174 | )
175 |
176 |
177 | def triton_sparse_dense_matmul(
178 | sparse_indices: torch.Tensor,
179 | sparse_values: torch.Tensor,
180 | dense: torch.Tensor,
181 | ) -> torch.Tensor:
182 | """
183 | calculates sparse @ dense (i.e reducing along the uncollated dimension of sparse)
184 | dense must be contiguous along dim 0 (in other words, dense.T is contiguous)
185 |
186 | sparse_indices is shape (A, k)
187 | sparse_values is shape (A, k)
188 | dense is shape (N, B)
189 |
190 | output is shape (A, B)
191 | """
192 | N = dense.shape[0]
193 | assert sparse_indices.shape == sparse_values.shape
194 | assert sparse_indices.is_contiguous()
195 | assert sparse_values.is_contiguous()
196 | assert dense.is_contiguous() # contiguous along B
197 |
198 | A = sparse_indices.shape[0]
199 | K = sparse_indices.shape[1]
200 | B = dense.shape[1]
201 |
202 | out = torch.zeros(A, B, device=dense.device, dtype=sparse_values.dtype)
203 |
204 | triton_sparse_dense_matmul_kernel[(A,)](
205 | sparse_indices,
206 | sparse_values,
207 | dense,
208 | out,
209 | stride_dn=dense.stride(0),
210 | stride_db=dense.stride(1),
211 | A=A,
212 | B=B,
213 | N=N,
214 | K=K,
215 | BLOCK_SIZE_K=triton.next_power_of_2(K),
216 | BLOCK_SIZE_B=triton.next_power_of_2(B),
217 | )
218 | return out
219 |
220 |
221 | @triton.jit
222 | def triton_sparse_dense_matmul_kernel(
223 | sparse_indices_ptr,
224 | sparse_values_ptr,
225 | dense_ptr,
226 | out_ptr,
227 | stride_dn,
228 | stride_db,
229 | A,
230 | B,
231 | N,
232 | K,
233 | BLOCK_SIZE_K: tl.constexpr,
234 | BLOCK_SIZE_B: tl.constexpr,
235 | ):
236 | """
237 | sparse_indices is shape (A, K)
238 | sparse_values is shape (A, K)
239 | dense is shape (N, B), contiguous along B
240 | out is shape (A, B)
241 | """
242 |
243 | pid = tl.program_id(0)
244 |
245 | offsets_k = tl.arange(0, BLOCK_SIZE_K)
246 | sparse_indices = tl.load(
247 | sparse_indices_ptr + pid * K + offsets_k, mask=offsets_k < K
248 | ) # shape (K,)
249 | sparse_values = tl.load(
250 | sparse_values_ptr + pid * K + offsets_k, mask=offsets_k < K
251 | ) # shape (K,)
252 |
253 | accum = tl.zeros((BLOCK_SIZE_B,), dtype=tl.float32)
254 |
255 | offsets_b = tl.arange(0, BLOCK_SIZE_B)
256 |
257 | for k in range(K):
258 | # workaround to do sparse_indices[k]
259 | i = tl.sum(
260 | tl.where(
261 | tl.arange(0, BLOCK_SIZE_K) == k,
262 | sparse_indices,
263 | tl.zeros((BLOCK_SIZE_K,), dtype=tl.int64),
264 | )
265 | )
266 | # workaround to do sparse_values[k]
267 | v = tl.sum(
268 | tl.where(
269 | tl.arange(0, BLOCK_SIZE_K) == k,
270 | sparse_values,
271 | tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32),
272 | )
273 | )
274 |
275 | tl.device_assert(i < N)
276 | if v != 0:
277 | accum += v * tl.load(
278 | dense_ptr + i * stride_dn + offsets_b * stride_db, mask=offsets_b < B
279 | )
280 |
281 | tl.store(
282 | out_ptr + pid * B + offsets_b, accum.to(sparse_values.dtype), mask=offsets_b < B
283 | )
284 |
285 |
286 | def triton_dense_dense_sparseout_matmul(
287 | dense1: torch.Tensor,
288 | dense2: torch.Tensor,
289 | at_indices: torch.Tensor,
290 | ) -> torch.Tensor:
291 | """
292 | dense1: shape (A, B)
293 | dense2: shape (B, N)
294 | at_indices: shape (A, K)
295 | out values: shape (A, K)
296 | calculates dense1 @ dense2 only for the indices in at_indices
297 |
298 | equivalent to (dense1 @ dense2).gather(1, at_indices)
299 | """
300 | A, B = dense1.shape
301 | N = dense2.shape[1]
302 | assert dense2.shape[0] == B
303 | assert at_indices.shape[0] == A
304 | K = at_indices.shape[1]
305 | assert at_indices.is_contiguous()
306 |
307 | assert dense1.stride(1) == 1, "dense1 must be contiguous along B"
308 | assert dense2.stride(0) == 1, "dense2 must be contiguous along B"
309 |
310 | if K > 512:
311 | # print("WARN - using naive matmul for large K")
312 | # naive is more efficient for large K
313 | return (dense1 @ dense2).gather(1, at_indices)
314 |
315 | out = torch.zeros(A, K, device=dense1.device, dtype=dense1.dtype)
316 |
317 | # grid = lambda META: (triton.cdiv(A, META['BLOCK_SIZE_A']),)
318 |
319 | triton_dense_dense_sparseout_matmul_kernel[(A,)](
320 | dense1,
321 | dense2,
322 | at_indices,
323 | out,
324 | stride_d1a=dense1.stride(0),
325 | stride_d1b=dense1.stride(1),
326 | stride_d2b=dense2.stride(0),
327 | stride_d2n=dense2.stride(1),
328 | A=A,
329 | B=B,
330 | N=N,
331 | K=K,
332 | BLOCK_SIZE_B=triton.next_power_of_2(B),
333 | BLOCK_SIZE_N=triton.next_power_of_2(N),
334 | BLOCK_SIZE_K=triton.next_power_of_2(K),
335 | )
336 |
337 | return out
338 |
339 |
340 | @triton.jit
341 | def triton_dense_dense_sparseout_matmul_kernel(
342 | dense1_ptr,
343 | dense2_ptr,
344 | at_indices_ptr,
345 | out_ptr,
346 | stride_d1a,
347 | stride_d1b,
348 | stride_d2b,
349 | stride_d2n,
350 | A,
351 | B,
352 | N,
353 | K,
354 | BLOCK_SIZE_B: tl.constexpr,
355 | BLOCK_SIZE_N: tl.constexpr,
356 | BLOCK_SIZE_K: tl.constexpr,
357 | ):
358 | """
359 | dense1: shape (A, B)
360 | dense2: shape (B, N)
361 | at_indices: shape (A, K)
362 | out values: shape (A, K)
363 | """
364 |
365 | pid = tl.program_id(0)
366 |
367 | offsets_k = tl.arange(0, BLOCK_SIZE_K)
368 | at_indices = tl.load(
369 | at_indices_ptr + pid * K + offsets_k, mask=offsets_k < K
370 | ) # shape (K,)
371 |
372 | offsets_b = tl.arange(0, BLOCK_SIZE_B)
373 | dense1 = tl.load(
374 | dense1_ptr + pid * stride_d1a + offsets_b * stride_d1b, mask=offsets_b < B
375 | ) # shape (B,)
376 |
377 | accum = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32)
378 |
379 | for k in range(K):
380 | # workaround to do at_indices[b]
381 | i = tl.sum(
382 | tl.where(
383 | tl.arange(0, BLOCK_SIZE_K) == k,
384 | at_indices,
385 | tl.zeros((BLOCK_SIZE_K,), dtype=tl.int64),
386 | )
387 | )
388 | tl.device_assert(i < N)
389 |
390 | dense2col = tl.load(
391 | dense2_ptr + offsets_b * stride_d2b + i * stride_d2n, mask=offsets_b < B
392 | ) # shape (B,)
393 | accum += tl.where(
394 | tl.arange(0, BLOCK_SIZE_K) == k,
395 | tl.sum(dense1 * dense2col),
396 | tl.zeros((BLOCK_SIZE_K,), dtype=tl.int64),
397 | )
398 |
399 | tl.store(out_ptr + pid * K + offsets_k, accum, mask=offsets_k < K)
400 |
401 |
402 | class TritonDecoder(torch.autograd.Function):
403 | @staticmethod
404 | def forward(ctx, sparse_indices, sparse_values, decoder_weight):
405 | ctx.save_for_backward(sparse_indices, sparse_values, decoder_weight)
406 | return triton_sparse_dense_matmul(
407 | sparse_indices, sparse_values, decoder_weight.T
408 | )
409 |
410 | @staticmethod
411 | def backward(ctx, grad_output):
412 | sparse_indices, sparse_values, decoder_weight = ctx.saved_tensors
413 |
414 | assert grad_output.is_contiguous(), "grad_output must be contiguous"
415 |
416 | decoder_grad = triton_sparse_transpose_dense_matmul(
417 | sparse_indices, sparse_values, grad_output, N=decoder_weight.shape[1]
418 | ).T
419 |
420 | return (
421 | None,
422 | triton_dense_dense_sparseout_matmul(
423 | grad_output, decoder_weight, sparse_indices
424 | ),
425 | # decoder is contiguous when transposed so this is a matching layout
426 | decoder_grad,
427 | None,
428 | )
429 |
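
# Illustrative correctness check (added for documentation, not in the upstream
# OpenAI file): the sparse kernel should match densifying the top-k values and
# doing a plain matmul. Requires a CUDA device with Triton installed.
if __name__ == "__main__":
    A, K, N, B = 4, 8, 64, 32
    idx = torch.randint(0, N, (A, K), device="cuda")
    val = torch.randn(A, K, device="cuda")
    dense = torch.randn(N, B, device="cuda")
    # Dense reference: scatter-add the values into an (A, N) matrix, then matmul.
    ref = torch.zeros(A, N, device="cuda").scatter_add_(1, idx, val) @ dense
    out = triton_sparse_dense_matmul(idx, val, dense)
    assert torch.allclose(out, ref, atol=1e-4), "sparse kernel mismatch"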
--------------------------------------------------------------------------------
/sae/sae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import pickle
4 | import json
5 | import os
6 | from typing import Tuple, Union
7 | from huggingface_hub import hf_hub_download
8 | from torch import Tensor
9 | from .utils import SaeConfig
10 |
11 | """ This logic is taken from Eleuther's SAE repo. """
12 | # Fallback implementation of SAE decoder
13 | def eager_decode(top_indices: Tensor, top_acts: Tensor, W_dec: Tensor):
14 | buf = top_acts.new_zeros(top_acts.shape[:-1] + (W_dec.shape[-1],))
15 | acts = buf.scatter_(dim=-1, index=top_indices, src=top_acts)
16 | return acts @ W_dec.mT
17 |
18 |
19 | # Triton implementation of SAE decoder
20 | def triton_decode(top_indices: Tensor, top_acts: Tensor, W_dec: Tensor):
21 | return TritonDecoder.apply(top_indices, top_acts, W_dec)
22 |
23 |
24 | try:
25 | from .kernels import TritonDecoder
26 | except ImportError:
27 | decoder_impl = eager_decode
28 | print("Triton not installed, using eager implementation of SAE decoder.")
29 | else:
30 | if os.environ.get("SAE_DISABLE_TRITON") == "1":
31 | print("Triton disabled, using eager implementation of SAE decoder.")
32 | decoder_impl = eager_decode
33 | else:
34 | decoder_impl = triton_decode
35 |
36 |
37 | class TopK(nn.Module):
38 |     """Top-k activation function"""
39 |     __is_sparse__ = True
40 | def __init__(self, k: int):
41 | super().__init__()
42 | self.k = k
43 |
44 | def sparse_forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
45 | x = nn.ReLU()(x)
46 | return torch.topk(x, self.k, dim=-1, sorted=False)
47 |
48 | def dense_forward(self, x: Tensor) -> Tensor:
49 | acts, indices = self.sparse_forward(x)
50 |         return torch.zeros_like(x).scatter_(dim=-1, index=indices, src=acts)  # densify top-k acts to (..., d_sae)
51 |
52 | def forward(self, x: Tensor) -> Union[Tuple[Tensor, Tensor], Tensor]:
53 | if self.__is_sparse__:
54 | return self.sparse_forward(x)
55 | else:
56 | return self.dense_forward(x)
57 |
58 | def __repr__(self):
59 | return f"TopK(k={self.k})"
60 |
61 | class Sae(nn.Module):
62 | """Streamlined Sparse Autoencoder for inference only"""
63 |
64 | def __init__(self, config: SaeConfig):
65 | super().__init__()
66 | self.cfg = config
67 | self.d_in = config.d_model
68 | self.d_sae = config.d_model * config.expansion_factor
69 |
70 | # Core parameters
71 | self.W_enc_DF = nn.Parameter(torch.empty(self.d_in, self.d_sae))
72 | self.b_enc_F = nn.Parameter(torch.zeros(self.d_sae))
73 | self.W_dec_FD = nn.Parameter(torch.empty(self.d_sae, self.d_in))
74 | self.b_dec_D = nn.Parameter(torch.zeros(self.d_in))
75 |
76 | self.device = torch.device(config.device)
77 | self.dtype = getattr(torch, config.dtype.split(".")[1])
78 | self.activation_fns = None
79 | self.to(self.device, self.dtype)
80 | self.eval()
81 |
82 | @torch.no_grad()
83 | def encode(self, x: torch.Tensor, activation_fn_idx: int = None) -> torch.Tensor:
84 | if activation_fn_idx is None:
85 | activation_fn_idx = self.cfg.eval_activation_idx
86 |
87 |
88 |
89 | pre_acts = torch.matmul(x, self.W_enc_DF) + self.b_enc_F
90 | acts = self.activation_fns[activation_fn_idx](pre_acts)
91 | return acts
92 |
93 | @torch.no_grad()
94 | def decode(self, features: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], eager: bool = False) -> torch.Tensor:
95 | if isinstance(features, tuple): # Sparse feats
96 | top_indices, top_acts = features
97 | if eager or top_indices.shape[-1] >= 512: # Heuristic for when triton kernel is slower than eager decoding
98 | y = eager_decode(top_indices, top_acts, self.W_dec_FD.mT)
99 | else:
100 | y = decoder_impl(top_indices, top_acts, self.W_dec_FD.mT)
101 | else:
102 | y = torch.matmul(features, self.W_dec_FD)
103 | y = y + self.b_dec_D
104 | return y
105 |
106 | @torch.no_grad()
107 | def forward(self, x: torch.Tensor, activation_fn_idx: int = None) -> torch.Tensor:
108 | if activation_fn_idx is None:
109 | activation_fn_idx = self.cfg.eval_activation_idx
110 | initial_shape = x.shape
111 | x = x.reshape(-1, self.d_in)
112 |         f = self.encode(x, activation_fn_idx)  # TopK activation returns (values, indices)
113 |         return self.decode((f.indices, f.values) if isinstance(f, tuple) else f).reshape(initial_shape), f
114 |
115 | @classmethod
116 | def from_pretrained(cls, repo_id: str, cache_dir: str = None, layer_idx: int = 12, token=None) -> "Sae":
117 | if cache_dir is None:
118 | cache_dir = os.path.expanduser(f"~/.cache/huggingface/{repo_id}")
119 | os.makedirs(cache_dir, exist_ok=True)
120 |
121 | config_file = hf_hub_download(repo_id=repo_id, filename="config.json", cache_dir=cache_dir, token=token)
122 | with open(config_file, 'r') as f:
123 | config_dict = json.load(f)
124 |
125 | assert "sae_cfg" in config_dict, "config.json must contain 'sae_cfg' key"
126 | config = SaeConfig.from_dict(config_dict["sae_cfg"])
127 | model = cls(config)
128 |
129 | weight_file = f"layer_{layer_idx}.pt"
130 | weights_path = hf_hub_download(repo_id=repo_id, filename=weight_file, cache_dir=cache_dir, token=token)
131 |
132 | with open(weights_path, "rb") as f:
133 | state_dict = pickle.load(f)
134 |
135 | model.load_state_dict(state_dict)
136 |
137 | model.activation_fns = [TopK(k=config.activation_fn_kwargs[i]["k"]) for i in range(len(config.activation_fn_kwargs))]
138 |
139 | return model
140 |
141 | def __repr__(self):
142 | return f"Sae(d_in={self.d_in}, d_sae={self.d_sae})"
143 |
144 | @property
145 | def W_enc(self):
146 | return self.W_enc_DF
147 |
148 | @property
149 | def W_dec(self):
150 | return self.W_dec_FD
151 |
152 | @property
153 | def b_enc(self):
154 | return self.b_enc_F
155 |
156 | @property
157 | def b_dec(self):
158 | return self.b_dec_D
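
if __name__ == "__main__":
    # Illustrative usage sketch (added for documentation): load the code SAE
    # referenced in example.py and run one encode/decode round trip. The repo id
    # and layer index follow example.py; the random input is only a shape check.
    sae = Sae.from_pretrained("tilde-research/sieve_coding", layer_idx=12)
    x = torch.randn(2, sae.d_in, device=sae.device, dtype=sae.dtype)
    feats = sae.encode(x)                              # TopK returns (values, indices)
    recon = sae.decode((feats.indices, feats.values))  # decode expects (indices, acts)
    print(sae, recon.shape)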
--------------------------------------------------------------------------------
/sae/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | from huggingface_hub import hf_hub_download, list_repo_files
5 | from typing import List, Optional, Dict
6 | from dataclasses import dataclass
7 | import mmap
8 |
9 | @dataclass
10 | class SaeConfig:
11 | """Configuration for SAE inference - only essential parameters"""
12 | d_model: int
13 | expansion_factor: int = 8
14 | device: str = "cuda" if torch.cuda.is_available() else "cpu"
15 | dtype: str = "torch.bfloat16"
16 | apply_activation_fn: bool = True
17 | activation_fn_names: List[str] = None
18 | activation_fn_kwargs: List[dict] = None
19 | eval_activation_idx: int = 0
20 |
21 | @classmethod
22 | def from_dict(cls, config_dict: dict) -> "SaeConfig":
23 | essential_params = {
24 | 'd_model', 'expansion_factor', 'device', 'dtype',
25 | 'apply_activation_fn',
26 | 'activation_fn_names', 'activation_fn_kwargs',
27 | 'eval_activation_idx'
28 | }
29 | filtered_dict = {k: v for k, v in config_dict.items() if k in essential_params}
30 | return cls(**filtered_dict)
31 |
--------------------------------------------------------------------------------
/src/agent_eval.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple, Dict, Any, List
2 | from openai import OpenAI
3 | from dataclasses import dataclass
4 | import src.utils as utils
5 | import json
6 | from tqdm import tqdm
7 | import time
8 | import os
9 |
10 | @dataclass
11 | class CodeEvaluation:
12 | """Stores the evaluation results for a code snippet"""
13 | is_syntactically_valid: bool
14 | is_semantically_valid: bool
15 | uses_regex: bool
16 | executes_successfully: bool
17 | follows_prompt: bool
18 | explanation: str
19 | properties: Dict[str, Any]
20 |
21 | def evaluate_code_with_gpt4(code: str, prompt: str, client: OpenAI) -> Tuple[bool, bool, str, Dict[str, Any]]:
22 | """
23 | Evaluate code using GPT-4 to determine if it successfully executes the prompt
24 | and analyze its properties.
25 | """
26 | system_prompt = """You are an expert Python programmer evaluating code solutions.
27 | Analyze the given code and determine:
28 | 1. Whether it successfully implements the requirements from the prompt
29 | 2. Whether it would execute successfully
30 | 3. Key properties of the implementation
31 |
32 | Provide your response in JSON format with the following fields:
33 | {
34 | "executes_successfully": bool,
35 | "follows_prompt": bool,
36 | "explanation": str,
37 | "properties": {
38 | "uses_list_comprehension": bool,
39 | "uses_error_handling": bool,
40 | "is_efficient": bool,
41 | "is_readable": bool,
42 | }
43 | }
44 | """
45 |
46 | user_message = f"""
47 | Original Prompt:
48 | {prompt}
49 |
50 | Code to Evaluate: ```python
51 | {code} ```
52 |
53 | Evaluate this code and provide your analysis in the requested JSON format.
54 | """
55 |
56 | try:
57 | response = client.chat.completions.create(
58 | model="gpt-4o-mini",
59 | messages=[
60 | {"role": "system", "content": system_prompt},
61 | {"role": "user", "content": user_message}
62 | ],
63 | temperature=0.1
64 | )
65 |
66 | result = response.choices[0].message.content
67 | if result is None:
68 | return False, False, "No response from GPT-4", {}
69 |
70 | try:
71 | parsed_result = json.loads(result)
72 | except json.JSONDecodeError:
73 | try:
74 | parsed_result = eval(result)
75 | except Exception:
76 | return False, False, f"Failed to parse response: {result}", {}
77 |
78 | return (
79 | parsed_result["executes_successfully"],
80 | parsed_result["follows_prompt"],
81 | parsed_result["explanation"],
82 | parsed_result["properties"]
83 | )
84 |
85 | except Exception as e:
86 | return False, False, f"GPT-4 evaluation failed: {str(e)}", {}
87 |
88 | def evaluate_code(
89 | code: str,
90 | prompt: str,
91 | client: OpenAI
92 | ) -> CodeEvaluation:
93 | """
94 | Comprehensive evaluation of a code snippet using both automated checks
95 | and GPT-4 analysis.
96 | """
97 | # Run automated checks
98 | is_syntactically_valid, _ = utils.is_syntactically_valid_python(code)
99 | is_semantically_valid, _ = utils.is_semantically_valid_python(code)
100 | uses_regex = utils.check_for_re_usage(code)
101 |
102 | # Get GPT-4 evaluation
103 | executes_successfully, follows_prompt, explanation, properties = (
104 | evaluate_code_with_gpt4(code, prompt, client)
105 | )
106 |
107 | return CodeEvaluation(
108 | is_syntactically_valid=is_syntactically_valid,
109 | is_semantically_valid=is_semantically_valid,
110 | uses_regex=uses_regex,
111 | executes_successfully=executes_successfully,
112 | follows_prompt=follows_prompt,
113 | explanation=explanation,
114 | properties=properties
115 | )
116 |
117 | def batch_evaluate_generations(
118 | generations: list[str],
119 | prompt: str,
120 | api_key: Optional[str] = None,
121 | batch_size: int = 10,
122 | retry_delay: float = 1.0,
123 | max_retries: int = 3
124 | ) -> list[CodeEvaluation]:
125 | """
126 | Evaluate multiple code generations in batches with retry logic.
127 |
128 | Args:
129 | generations: List of generated code snippets
130 | prompt: The original prompt that requested the code
131 | api_key: Optional OpenAI API key
132 |         batch_size: Number of evaluations to process in parallel (currently unused; snippets are evaluated sequentially)
133 | retry_delay: Delay between retries in seconds
134 | max_retries: Maximum number of retry attempts
135 |
136 | Returns:
137 | List of CodeEvaluation objects
138 | """
139 | client = OpenAI(
140 | api_key=api_key,
141 | ) if api_key else OpenAI()
142 |
143 | results = []
144 | valid_codes = []
145 |
146 | # First extract all valid Python code
147 | for code in generations:
148 | python_code = utils.extract_python(code)
149 | if python_code is not None:
150 | valid_codes.append(python_code)
151 |
152 | # Process in batches
153 | for i, code in tqdm(enumerate(valid_codes), total=len(valid_codes), desc="Evaluating code batches"):
154 |
155 | retries = 0
156 | while retries < max_retries:
157 | try:
158 | evaluation = evaluate_code(code, prompt, client)
159 | results.append(evaluation)
160 | break
161 | except Exception as e:
162 | retries += 1
163 | if retries == max_retries:
164 | print(f"Failed to evaluate code after {max_retries} attempts: {e}")
165 | continue
166 | time.sleep(retry_delay)
167 |
168 |
169 | # Small delay between batches to avoid rate limits
170 | time.sleep(0.01)
171 |
172 | return results
173 |
174 | def summarize_evaluations(evaluations: list[CodeEvaluation]) -> Dict[str, Any]:
175 | """
176 | Summarize the results of multiple code evaluations.
177 |
178 | Args:
179 | evaluations: List of CodeEvaluation objects
180 |
181 | Returns:
182 | Dictionary containing summary statistics
183 | """
184 | total = len(evaluations)
185 | if total == 0:
186 | return {"error": "No evaluations to summarize"}
187 |
188 | summary = {
189 | "total_samples": total,
190 | "syntactically_valid": sum(1 for e in evaluations if e.is_syntactically_valid) / total,
191 | "semantically_valid": sum(1 for e in evaluations if e.is_semantically_valid) / total,
192 | "uses_regex": sum(1 for e in evaluations if e.uses_regex) / total,
193 | "executes_successfully": sum(1 for e in evaluations if e.executes_successfully) / total,
194 | "follows_prompt": sum(1 for e in evaluations if e.follows_prompt) / total,
195 | "property_stats": {
196 | "uses_list_comprehension": sum(1 for e in evaluations if e.properties.get("uses_list_comprehension", False)) / total,
197 | "uses_error_handling": sum(1 for e in evaluations if e.properties.get("uses_error_handling", False)) / total,
198 | "is_efficient": sum(1 for e in evaluations if e.properties.get("is_efficient", False)) / total,
199 | "is_readable": sum(1 for e in evaluations if e.properties.get("is_readable", False)) / total,
200 | },
201 | "complexity_distribution": {}
202 | }
203 |
204 | # Count complexity distributions
205 | complexity_counts = {}
206 | for eval in evaluations:
207 | complexity = eval.properties.get("complexity", "unknown")
208 | complexity_counts[complexity] = complexity_counts.get(complexity, 0) + 1
209 |
210 | summary["complexity_distribution"] = {
211 | k: v/total for k, v in complexity_counts.items()
212 | }
213 |
214 | return summary
215 |
216 | # Example usage
217 | if __name__ == "__main__":
218 |     code = r"""
219 | def extract_numbers(text):
220 | import re
221 | return [int(num) for num in re.findall(r'\d+', text)]
222 | """
223 |
224 | prompt = "Write a function that extracts all numbers from a text string using regex"
225 | api_key = os.environ.get("OPENAI_API_KEY")
226 |
227 | # Initialize client once
228 | client = OpenAI(api_key=api_key)
229 |
230 | # Single evaluation
231 | evaluation = evaluate_code(code, prompt, client)
232 | print(f"Evaluation results:\n{evaluation}")
233 |
234 | # Batch evaluation
235 | generations = [code, code] # Example with duplicate code
236 | evaluations = batch_evaluate_generations(generations, prompt, api_key)
237 | summary = summarize_evaluations(evaluations)
238 | print(f"\nSummary of evaluations:\n{summary}")
--------------------------------------------------------------------------------
/src/caa.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
2 | from typing import Dict, Tuple, Optional, Any, List
3 | import json
4 | import torch
5 | import einops
6 | from pathlib import Path
7 | from dataclasses import dataclass
8 | from tqdm import tqdm
9 | import src.utils as utils
10 | from src.eval_config import EvalConfig
11 | from sklearn.linear_model import LogisticRegression
12 | from sklearn.model_selection import train_test_split
13 |
14 |
15 | def load_prompts(
16 | prompts_folder: str, contrastive_prompts_filename: str, code_filename: str
17 | ) -> Tuple[Dict[str, Dict[str, str]], Dict[str, str]]:
18 | """Load contrastive prompts and code examples from JSON files.
19 |
20 | Args:
21 | prompts_folder: Directory containing prompt files
22 | contrastive_prompts_filename: Filename for contrastive prompts JSON
23 | code_filename: Filename for code examples JSON
24 |
25 | Returns:
26 | Tuple containing:
27 | - Dictionary of contrastive prompts with structure {prompt_type: {base: str, pos: str, neg: str}}
28 | - Dictionary of code examples with structure {id: code_str}
29 |
30 | Raises:
31 | FileNotFoundError: If either JSON file is not found
32 | JSONDecodeError: If either file contains invalid JSON
33 | """
34 | prompts_path = Path(prompts_folder)
35 |
36 | with open(prompts_path / contrastive_prompts_filename) as f:
37 | contrastive_prompts = json.load(f)
38 |
39 | with open(prompts_path / code_filename) as f:
40 | code = json.load(f)
41 |
42 | return contrastive_prompts, code
43 |
44 |
45 | def format_contrastive_prompt(
46 | contrastive_prompts: Dict[str, Dict[str, str]],
47 | code_block: str,
48 | prompt_type: str,
49 | prompt_polarity: str,
50 | tokenizer: PreTrainedTokenizer,
51 | ) -> str:
52 | """Format a contrastive prompt with code for model input.
53 |
54 | Args:
55 | contrastive_prompts: Dictionary of prompt templates
56 | code_block: Code snippet to insert into prompt
57 | prompt_type: Type of prompt to use (must exist in contrastive_prompts)
58 | prompt_polarity: Either "pos" or "neg" for positive/negative prompt
59 | tokenizer: Tokenizer for formatting chat template
60 |
61 | Returns:
62 | Formatted prompt string ready for model input
63 |
64 | Raises:
65 | AssertionError: If prompt_type or prompt_polarity is invalid
66 | """
67 | assert prompt_type in contrastive_prompts, f"Prompt type {prompt_type} not found"
68 | assert prompt_polarity in ["pos", "neg"], f"Invalid prompt polarity: {prompt_polarity}"
69 |
70 | prompt = f"{contrastive_prompts[prompt_type]['base']}\n{contrastive_prompts[prompt_type][prompt_polarity]}\n"
71 | prompt += "```python\n" + code_block + "```\n"
72 |
73 | chat = [{"role": "user", "content": prompt}]
74 | # Handle different tokenizer types
75 | if hasattr(tokenizer, "apply_chat_template"):
76 | result = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
77 | return result if isinstance(result, str) else result[0]
78 | return f"{prompt}\n"
79 |
80 |
81 | def get_layer_activations(
82 | model: AutoModelForCausalLM, target_layer: int, inputs: torch.Tensor
83 | ) -> torch.Tensor:
84 | """Extract activations from a specific transformer layer.
85 |
86 | Args:
87 | model: The causal language model
88 | target_layer: Index of layer to extract from
89 | inputs: Input token ids tensor of shape (batch_size, seq_len)
90 |
91 | Returns:
92 | Tensor of activations with shape (batch_size, seq_len, hidden_dim)
93 |
94 | Raises:
95 | AttributeError: If model architecture is not supported
96 | RuntimeError: If no activations were captured
97 | """
98 | acts_BLD: Optional[torch.Tensor] = None
99 |
100 | def gather_target_act_hook(module, inputs, outputs):
101 | nonlocal acts_BLD
102 | acts_BLD = outputs[0]
103 | return outputs
104 |
105 | # Support different model architectures
106 | if hasattr(model, "transformer"):
107 | layers = model.transformer.h
108 | elif hasattr(model, "model"):
109 | layers = model.model.layers
110 | else:
111 | raise AttributeError("Model architecture not supported")
112 |
113 | handle = layers[target_layer].register_forward_hook(gather_target_act_hook)
114 | with torch.no_grad():
115 | _ = model(inputs.to(model.device))
116 | handle.remove()
117 |
118 | if acts_BLD is None:
119 | raise RuntimeError("No activations were captured")
120 | return acts_BLD
121 |
122 |
123 | def get_layer_activations_with_generation(
124 | model: AutoModelForCausalLM,
125 | target_layer: int,
126 | inputs: torch.Tensor,
127 | max_new_tokens: int,
128 | temperature: float = 1.0,
129 | ) -> Tuple[torch.Tensor, torch.Tensor]:
130 | """Get layer activations during text generation.
131 |
132 | Args:
133 | model: The causal language model
134 | target_layer: Index of layer to extract from
135 | inputs: Input token ids tensor
136 |         max_new_tokens: Maximum number of new tokens to generate
137 |         temperature: Sampling temperature passed to model.generate()
138 | Returns:
139 | Tuple containing:
140 | - Tensor of activations during generation
141 | - Generated token ids tensor
142 | """
143 | acts_BLD: List[torch.Tensor] = []
144 |
145 | def gather_target_act_hook(module, inputs, outputs):
146 | nonlocal acts_BLD
147 | acts_BLD.append(outputs[0])
148 | return outputs
149 |
150 | handle = model.model.layers[target_layer].register_forward_hook(gather_target_act_hook)
151 | with torch.no_grad():
152 | tokens = model.generate(
153 | inputs.to(model.device),
154 | max_new_tokens=max_new_tokens,
155 | temperature=temperature,
156 | do_sample=True,
157 | )
158 | handle.remove()
159 |
160 | return torch.cat(acts_BLD, dim=1), tokens
161 |
162 |
163 | @torch.no_grad()
164 | def calculate_probe_vector(
165 | prompts_folder: str,
166 | contrastive_prompts_filename: str,
167 | code_filename: str,
168 | model: AutoModelForCausalLM,
169 | tokenizer: PreTrainedTokenizer,
170 | prompt_type: str,
171 | layer: int,
172 | n_samples: int = 50,
173 | max_new_tokens: int = 400,
174 | temperature: float = 1.0,
175 | ) -> Tuple[torch.Tensor, torch.Tensor]:
176 |     """Train a logistic-regression probe on layer activations and return (probe_vector, probe_bias).
177 |     Follows this general methodology: https://arxiv.org/abs/2410.12877"""
178 | contrastive_prompts, code = load_prompts(
179 | prompts_folder, contrastive_prompts_filename, code_filename
180 | )
181 |
182 | train_pos_activations = []
183 | train_neg_activations = []
184 | test_pos_activations = []
185 | test_neg_activations = []
186 | n_train = int(n_samples * 0.8)
187 |
188 | for code_id, code_block in tqdm(
189 | code.items(), total=len(code), desc="Gathering activations for probe training"
190 | ):
191 | pos_prompt = format_contrastive_prompt(
192 | contrastive_prompts, code_block, prompt_type, "pos", tokenizer
193 | )
194 | neg_prompt = format_contrastive_prompt(
195 | contrastive_prompts, code_block, prompt_type, "neg", tokenizer
196 | )
197 | pos_tokens = tokenizer(
198 | [pos_prompt] * n_samples, add_special_tokens=False, return_tensors="pt"
199 | )["input_ids"].to(model.device)
200 | neg_tokens = tokenizer(
201 | [neg_prompt] * n_samples, add_special_tokens=False, return_tensors="pt"
202 | )["input_ids"].to(model.device)
203 |
204 | pos_activations, new_pos_tokens = get_layer_activations_with_generation(
205 | model, layer, pos_tokens, max_new_tokens=max_new_tokens, temperature=temperature
206 | )
207 | neg_activations, new_neg_tokens = get_layer_activations_with_generation(
208 | model, layer, neg_tokens, max_new_tokens=max_new_tokens, temperature=temperature
209 | )
210 |
211 |         # Keep generated tokens only; the -1 drops the last one since there are no activations for the final generated token
212 | new_pos_tokens = new_pos_tokens[:, pos_tokens.shape[1] : -1]
213 | new_neg_tokens = new_neg_tokens[:, neg_tokens.shape[1] : -1]
214 |
215 | # Get activations for generated tokens only
216 | pos_activations = pos_activations[:, pos_tokens.shape[1] :]
217 | neg_activations = neg_activations[:, neg_tokens.shape[1] :]
218 | # Split activations into train and test based on n_samples first
219 | train_pos_activations_batch = pos_activations[:n_train]
220 | test_pos_activations_batch = pos_activations[n_train:]
221 | train_neg_activations_batch = neg_activations[:n_train]
222 | test_neg_activations_batch = neg_activations[n_train:]
223 |
224 | # Create masks for non-padding and non-eos tokens
225 | pos_mask = (new_pos_tokens != tokenizer.pad_token_id) & (
226 | new_pos_tokens != tokenizer.eos_token_id
227 | )
228 | neg_mask = (new_neg_tokens != tokenizer.pad_token_id) & (
229 | new_neg_tokens != tokenizer.eos_token_id
230 | )
231 |
232 | # Split masks into train and test
233 | train_pos_mask = pos_mask[:n_train]
234 | test_pos_mask = pos_mask[n_train:]
235 | train_neg_mask = neg_mask[:n_train]
236 | test_neg_mask = neg_mask[n_train:]
237 |
238 | # Reshape activations and masks to match
239 | train_pos_activations_batch = einops.rearrange(
240 | train_pos_activations_batch, "B L D -> (B L) D"
241 | )
242 | test_pos_activations_batch = einops.rearrange(
243 | test_pos_activations_batch, "B L D -> (B L) D"
244 | )
245 | train_neg_activations_batch = einops.rearrange(
246 | train_neg_activations_batch, "B L D -> (B L) D"
247 | )
248 | test_neg_activations_batch = einops.rearrange(
249 | test_neg_activations_batch, "B L D -> (B L) D"
250 | )
251 |
252 | train_pos_mask = einops.rearrange(train_pos_mask, "B L -> (B L)")
253 | test_pos_mask = einops.rearrange(test_pos_mask, "B L -> (B L)")
254 | train_neg_mask = einops.rearrange(train_neg_mask, "B L -> (B L)")
255 | test_neg_mask = einops.rearrange(test_neg_mask, "B L -> (B L)")
256 |
257 | # Filter out padding and eos tokens
258 | train_pos_activations.append(train_pos_activations_batch[train_pos_mask])
259 | test_pos_activations.append(test_pos_activations_batch[test_pos_mask])
260 | train_neg_activations.append(train_neg_activations_batch[train_neg_mask])
261 | test_neg_activations.append(test_neg_activations_batch[test_neg_mask])
262 |
263 | # Combine all activations
264 | X_train = (
265 | torch.cat(
266 | [torch.cat(train_pos_activations, dim=0), torch.cat(train_neg_activations, dim=0)],
267 | dim=0,
268 | )
269 | .detach()
270 | .float()
271 | .cpu()
272 | .numpy()
273 | )
274 | y_train = (
275 | torch.cat(
276 | [
277 | torch.ones(sum(len(x) for x in train_pos_activations)),
278 | torch.zeros(sum(len(x) for x in train_neg_activations)),
279 | ]
280 | )
281 | .detach()
282 | .float()
283 | .cpu()
284 | .numpy()
285 | )
286 |
287 | X_test = (
288 | torch.cat(
289 | [torch.cat(test_pos_activations, dim=0), torch.cat(test_neg_activations, dim=0)], dim=0
290 | )
291 | .detach()
292 | .float()
293 | .cpu()
294 | .numpy()
295 | )
296 | y_test = (
297 | torch.cat(
298 | [
299 | torch.ones(sum(len(x) for x in test_pos_activations)),
300 | torch.zeros(sum(len(x) for x in test_neg_activations)),
301 | ]
302 | )
303 | .detach()
304 | .float()
305 | .cpu()
306 | .numpy()
307 | )
308 |     # Fit a logistic-regression probe and extract its direction and bias
309 |     probe_model = LogisticRegression().fit(X_train, y_train)
310 |     probe_vector = torch.tensor(probe_model.coef_[0]).to(device=model.device, dtype=model.dtype)
311 |     probe_bias = torch.tensor(probe_model.intercept_[0]).to(device=model.device, dtype=model.dtype)
312 | 
313 |     # Report train and test accuracy
314 |     test_acc = probe_model.score(X_test, y_test)
315 |     print(f"Test accuracy on {len(X_test)} points: {test_acc:.3f}")
316 |     train_acc = probe_model.score(X_train, y_train)
317 |     print(f"Train accuracy on {len(X_train)} points: {train_acc:.3f}")
318 |
319 | return probe_vector, probe_bias
320 |
321 |
322 | @torch.no_grad()
323 | def calculate_steering_vector(
324 | prompts_folder: str,
325 | contrastive_prompts_filename: str,
326 | code_filename: str,
327 | model: AutoModelForCausalLM,
328 | tokenizer: PreTrainedTokenizer,
329 | prompt_type: str,
330 | layer: int,
331 | ) -> torch.Tensor:
332 | """Calculate the steering vector for a given layer.
333 | Follows this general methodology: https://arxiv.org/abs/2410.12877"""
334 | contrastive_prompts, code = load_prompts(
335 | prompts_folder, contrastive_prompts_filename, code_filename
336 | )
337 |
338 | all_pos_activations = []
339 | all_neg_activations = []
340 |
341 | for code_id, code_block in code.items():
342 | pos_prompt = format_contrastive_prompt(
343 | contrastive_prompts, code_block, prompt_type, "pos", tokenizer
344 | )
345 | neg_prompt = format_contrastive_prompt(
346 | contrastive_prompts, code_block, prompt_type, "neg", tokenizer
347 | )
348 | pos_tokens = tokenizer(pos_prompt, add_special_tokens=False, return_tensors="pt")[
349 | "input_ids"
350 | ].to(model.device)
351 | neg_tokens = tokenizer(neg_prompt, add_special_tokens=False, return_tensors="pt")[
352 | "input_ids"
353 | ].to(model.device)
354 |
355 |         # One forward pass per prompt; keep only the final-token activation
356 |         pos_activations = get_layer_activations(model, layer, pos_tokens)[:, -1, :].squeeze()
357 |         # The steering vector is the mean difference of these final-token activations
358 |         neg_activations = get_layer_activations(model, layer, neg_tokens)[:, -1, :].squeeze()
359 |
360 | all_pos_activations.append(pos_activations)
361 | all_neg_activations.append(neg_activations)
362 |
363 | pos_activations = torch.stack(all_pos_activations).mean(dim=0)
364 | neg_activations = torch.stack(all_neg_activations).mean(dim=0)
365 |
366 | return pos_activations - neg_activations
367 |
368 |
369 | def get_threshold(
370 | config: EvalConfig,
371 | model_params: dict,
372 | wrapper,
373 | steering_vector_D: torch.Tensor,
374 | encoder_vector_D: torch.Tensor,
375 | encoder_threshold: float,
376 | ) -> float:
377 | base_prompt = utils.load_prompt_files(config)
378 |
379 | with open(f"{config.prompt_folder}/{config.code_filename}", "r") as f:
380 | code_blocks = json.load(f)
381 |
382 | average_threshold = 0
383 |
384 | for code_block_key, single_code_block in code_blocks.items():
385 | prompt = base_prompt.replace("{code}", single_code_block)
386 | formatted_prompt = utils.format_llm_prompt(prompt, wrapper.tokenizer)
387 | formatted_prompt += config.prefill
388 |
389 | tokens = wrapper.tokenizer(formatted_prompt, add_special_tokens=False, return_tensors="pt")[
390 | "input_ids"
391 | ].to(wrapper.model.device)
392 |
393 | resid_BLD = get_layer_activations(wrapper.model, model_params["targ_layer"], tokens)
394 |
395 | feature_acts_BL = torch.einsum("BLD,D->BL", resid_BLD, encoder_vector_D)
396 | above_threshold = feature_acts_BL > encoder_threshold
397 | k = above_threshold.sum().item()
398 |
399 | steering_vector_acts_BL = torch.einsum("BLD, D->BL", resid_BLD, steering_vector_D)
400 |
401 | topk_values = torch.topk(steering_vector_acts_BL.flatten(), k, largest=True)[0]
402 | threshold = topk_values[-1].item()
403 | average_threshold += threshold
404 |
405 | average_threshold /= len(code_blocks)
406 | return average_threshold
407 |
--------------------------------------------------------------------------------
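
The functions in src/caa.py above fit together as follows: calculate_steering_vector builds a CAA-style direction as the mean difference of final-token activations between the paired prompts, calculate_probe_vector instead trains a logistic-regression probe on per-token activations sampled during generation, and get_threshold calibrates a firing threshold for the steering-vector classifier by matching the number of tokens on which the SAE encoder direction fires. Below is a minimal usage sketch for the first two; the model name, target layer, and prompt filenames are illustrative assumptions, not values fixed by this file.

```python
# Hypothetical usage sketch for src/caa.py; model name, layer index, and filenames are assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import src.caa as caa

model_name = "google/gemma-2-9b-it"  # assumption: any HF causal LM with a chat template
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")

# CAA-style steering vector: mean difference of final-token activations on contrastive prompts.
steering_vector = caa.calculate_steering_vector(
    prompts_folder="src/prompts",
    contrastive_prompts_filename="contrastive_prompts.json",
    code_filename="code.json",
    model=model,
    tokenizer=tokenizer,
    prompt_type="regex",
    layer=12,  # assumption: the layer you intend to intervene on
)

# Logistic-regression probe direction and bias trained on generated-token activations.
probe_vector, probe_bias = caa.calculate_probe_vector(
    prompts_folder="src/prompts",
    contrastive_prompts_filename="probe_prompts.json",  # assumption: probe prompts with a non-empty "neg"
    code_filename="code.json",
    model=model,
    tokenizer=tokenizer,
    prompt_type="regex",
    layer=12,
    n_samples=50,
)
print(steering_vector.shape, probe_vector.shape, probe_bias)
```
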
/src/count_activations.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import einops
3 | from transformers import AutoTokenizer, AutoModelForCausalLM
4 | from dataclasses import asdict
5 | from tqdm import tqdm
6 | from typing import Callable, Dict, List, Tuple, Union, Any
7 | import json
8 | import time
9 | import sys
10 | import os
11 | import logging
12 | from jaxtyping import Float
13 | from torch import Tensor
14 |
15 | logging.basicConfig(level=logging.INFO)
16 |
17 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
18 |
19 |
20 | from src.wrapper import InterventionWrapper
21 | from src.eval_config import EvalConfig, InterventionType
22 | import src.utils as utils
23 | import src.caa as caa
24 |
25 |
26 | def get_feature_acts_with_generations(
27 | model: AutoModelForCausalLM,
28 | tokenizer: AutoTokenizer,
29 | code_id: str,
30 | target_layer: int,
31 | encoder_vectors_dict: Dict[str, torch.Tensor],
32 | inputs: torch.Tensor,
33 | max_new_tokens: int,
34 | temperature: float = 1.0,
35 | ) -> Dict[str, torch.Tensor]:
36 | """Track feature activations during model generation.
37 |
38 | This function monitors how strongly the model activates specific features
39 | during text generation, accounting for various thresholds and biases.
40 |
41 | Args:
42 | model: The language model to analyze
43 | tokenizer: Tokenizer for processing text
44 | code_id: Identifier for the code block
45 | target_layer: Which transformer layer to monitor
46 | encoder_vectors_dict: Dict mapping intervention types to feature directions
47 | inputs: Input token ids
48 | max_new_tokens: Maximum new tokens to generate
49 | temperature: Sampling temperature
50 |
51 | Returns:
52 | Dict mapping intervention types to feature activation tensors
53 | """
54 |
55 | if "question" in code_id or "without_regex" in code_id:
56 | filter_regex = True
57 | else:
58 | filter_regex = False
59 |
60 | acts_BLD, tokens_BL = caa.get_layer_activations_with_generation(
61 | model, target_layer, inputs, max_new_tokens, temperature
62 | )
63 |
64 | for i in range(tokens_BL.size(0)):
65 | single_prompt = tokens_BL[i]
66 | if not filter_regex:
67 | break
68 | decoded_prompt = tokenizer.decode(single_prompt)
69 |
70 | if (
71 | "regex" in decoded_prompt
72 | or "regular expression" in decoded_prompt
73 | or utils.check_for_re_usage(decoded_prompt)
74 | ):
75 | print(f"Skipping regex prompt: {decoded_prompt}")
76 | tokens_BL[i, :] = torch.tensor([tokenizer.pad_token_id])
77 |
78 | tokens_BL = tokens_BL[:, :-1] # There are no activations for the last generated token
79 | tokens_L = tokens_BL.flatten()
80 |
81 | retain_mask = (
82 | (tokens_L != tokenizer.pad_token_id)
83 | & (tokens_L != tokenizer.eos_token_id)
84 | & (tokens_L != tokenizer.bos_token_id)
85 | )
86 |
87 | # tokens_BL = tokens_BL[:, inputs.size(1) :] # Remove input tokens
88 | # acts_BLD = acts_BLD[:, inputs.size(1) :, :]
89 |
90 | feature_acts_dict = {}
91 | for intervention_type, encoder_vector_D in encoder_vectors_dict.items():
92 | feature_acts_BL = torch.einsum("BLD,D->BL", acts_BLD, encoder_vector_D.to(acts_BLD.device))
93 | feature_acts_L = feature_acts_BL.flatten()
94 |
95 | feature_acts_dict[intervention_type] = feature_acts_L[retain_mask]
96 |
97 | return feature_acts_dict
98 |
99 |
100 | def test_single_prompt(
101 | wrapper: InterventionWrapper,
102 | base_prompt: str,
103 | code_id: str,
104 | code_example: str,
105 | config: EvalConfig,
106 | model_params: dict,
107 | ) -> Dict[str, torch.Tensor]:
108 | """Test activation patterns for a single prompt.
109 |
110 | Args:
111 | wrapper: Model wrapper instance
112 | base_prompt: Base prompt template
113 |         code_id, code_example: Identifier and code snippet for the example under test
114 | config: Evaluation configuration
115 | model_params: Model-specific parameters
116 |
117 | Returns:
118 | Dictionary mapping intervention types to activation tensors
119 | """
120 | if "question" in code_id:
121 | prompt = code_example
122 | else:
123 | prompt = base_prompt.replace("{code}", code_example)
124 |
125 | formatted_prompt = utils.format_llm_prompt(prompt, wrapper.tokenizer)
126 | formatted_prompt += config.prefill
127 |
128 | batched_prompts = [formatted_prompt] * config.batch_size
129 |
130 | num_batches = config.total_generations // config.batch_size
131 |
132 | input_tokens = wrapper.tokenizer(
133 | batched_prompts, add_special_tokens=False, return_tensors="pt"
134 | )["input_ids"].to(wrapper.model.device)
135 |
136 | # Get encoder vectors for all intervention types
137 | encoder_vectors_dict = {}
138 | for intervention_type in config.intervention_types:
139 | if intervention_type == InterventionType.PROBE_SAE.value:
140 | encoder_vectors_dict[intervention_type] = wrapper.probe_vector
141 | print(f"Probe bias: {wrapper.probe_bias}")
142 | elif intervention_type == InterventionType.CONDITIONAL_PER_TOKEN.value:
143 | encoder_vector = wrapper.sae.W_enc[:, [model_params["feature_idx"]]].squeeze()
144 | encoder_vectors_dict[intervention_type] = encoder_vector
145 | bias = wrapper.sae.b_enc[model_params["feature_idx"]]
146 | # threshold = wrapper.sae.threshold[model_params["feature_idx"]]
147 | # print(f"Threshold: {threshold}, Bias: {bias}")
148 | elif intervention_type == InterventionType.CONDITIONAL_STEERING_VECTOR.value:
149 | encoder_vectors_dict[intervention_type] = wrapper.caa_steering_vector
150 | else:
151 | raise ValueError(f"Invalid intervention type: {intervention_type}")
152 |
153 | feature_acts_by_type = {itype: [] for itype in config.intervention_types}
154 |
155 | for _ in tqdm(range(num_batches), desc="Generating responses"):
156 | feature_acts_dict = get_feature_acts_with_generations(
157 | wrapper.model,
158 | wrapper.tokenizer,
159 | code_id,
160 | model_params["targ_layer"],
161 | encoder_vectors_dict,
162 | input_tokens,
163 | max_new_tokens=config.max_new_tokens,
164 | )
165 | for itype in config.intervention_types:
166 | feature_acts_by_type[itype].append(feature_acts_dict[itype])
167 |
168 | return {itype: torch.cat(acts_list, dim=0) for itype, acts_list in feature_acts_by_type.items()}
169 |
170 |
171 | def count_classifier_activations() -> dict:
172 | """Run comprehensive activation analysis across multiple code examples.
173 |
174 | This function:
175 | 1. Sets up the model and configuration
176 | 2. Loads necessary components (SAE, prompts)
177 | 3. Runs activation analysis for different intervention types
178 | 4. Saves results periodically
179 |
180 | Returns:
181 | Dictionary containing:
182 | - Configuration settings
183 | - Activation measurements per intervention type
184 | - Results for each code example
185 |
186 | Note:
187 | Results are saved to disk after each code block to prevent data loss
188 | """
189 | config = EvalConfig()
190 |
191 | config.intervention_types = [
192 | InterventionType.CONDITIONAL_PER_TOKEN.value,
193 | InterventionType.PROBE_SAE.value,
194 | InterventionType.CONDITIONAL_STEERING_VECTOR.value,
195 | ]
196 |
197 | config.save_path = "activation_values.pt"
198 | config.prompt_filename = "prompt_no_regex.txt"
199 |
200 | results = {"config": asdict(config)}
201 |
202 | # Setup
203 | device = "cuda" if torch.cuda.is_available() else "cpu"
204 | model_params = utils.get_model_params(config.model_name)
205 |
206 | # Initialize wrapper
207 | wrapper = InterventionWrapper(model_name=config.model_name, device=device, dtype=torch.bfloat16)
208 |
209 | # Load SAE
210 | wrapper.load_sae(
211 | release=model_params["sae_release"],
212 | sae_id=model_params["sae_id"],
213 | layer_idx=model_params["targ_layer"],
214 | )
215 |
216 | # Load and format prompt
217 | base_prompt = utils.load_prompt_files(config)
218 | # initialize steering vectors
219 | for intervention_type in config.intervention_types:
220 | _ = wrapper.get_hook(intervention_type, model_params, 1, config)
221 |
222 | with open(f"{config.prompt_folder}/activation_counting_code.json", "r") as f:
223 | code_blocks = json.load(f)
224 |
225 | print(f"Evaluating {len(code_blocks)} code blocks")
226 | print(f"Evaluating the following interventions: {config.intervention_types}")
227 |
228 | for code_block_key, single_code_block in code_blocks.items():
229 | activations_by_type = test_single_prompt(
230 | wrapper,
231 | base_prompt,
232 | code_block_key,
233 | single_code_block,
234 | config,
235 | model_params,
236 | )
237 |
238 | for intervention_type in config.intervention_types:
239 | if intervention_type not in results:
240 | results[intervention_type] = {}
241 | results[intervention_type][code_block_key] = activations_by_type[intervention_type]
242 |
243 | torch.save(results, config.save_path)
244 |
245 | return results
246 |
247 |
248 | if __name__ == "__main__":
249 | """
250 | Main entry point for activation analysis.
251 |
252 | Environment variables:
253 | PYTORCH_CUDA_ALLOC_CONF: Set to "expandable_segments:True"
254 | """
255 | import os
256 |
257 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
258 |
259 | torch.set_grad_enabled(False)
260 |
261 | start_time = time.time()
262 | run_results = count_classifier_activations()
263 | print(f"Total time: {time.time() - start_time:.2f} seconds")
264 |
--------------------------------------------------------------------------------
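
count_classifier_activations() above saves a nested dictionary to activation_values.pt: the serialized EvalConfig under "config", plus one entry per intervention type mapping each code-block id (regex_*, not_regex_*, question_*) to a tensor of per-token feature activations. A small post-hoc inspection sketch, assuming the file has already been produced; the statistics printed are an arbitrary illustrative choice.

```python
# Hypothetical inspection of activation_values.pt written by count_classifier_activations().
import torch

results = torch.load("activation_values.pt")
print("Model:", results["config"]["model_name"])

for intervention_type, per_block in results.items():
    if intervention_type == "config":
        continue
    # Split activations by the code-block naming convention used in activation_counting_code.json.
    regex_acts = torch.cat([acts for key, acts in per_block.items() if key.startswith("regex")])
    other_acts = torch.cat([acts for key, acts in per_block.items() if not key.startswith("regex")])
    print(
        f"{intervention_type}: regex mean activation = {regex_acts.float().mean():.2f}, "
        f"non-regex 99th percentile = {other_acts.float().quantile(0.99):.2f}"
    )
```
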
/src/eval_config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from enum import Enum
3 |
4 |
5 | class InterventionType(Enum):
6 | CONSTANT_SAE = "constant_sae"
7 | CONSTANT_STEERING_VECTOR = "constant_steering_vector"
8 | CONDITIONAL_PER_INPUT = "conditional_per_input"
9 | CONDITIONAL_PER_TOKEN = "conditional_per_token"
10 | CONDITIONAL_STEERING_VECTOR = "conditional_steering_vector"
11 | CLAMPING = "clamping"
12 | CONDITIONAL_CLAMPING = "conditional_clamping"
13 | PROBE_STEERING_VECTOR = "probe_steering_vector"
14 | PROBE_SAE = "probe_sae"
15 | PROBE_SAE_CLAMPING = "probe_sae_clamping"
16 | PROBE_STEERING_VECTOR_CLAMPING = "probe_steering_vector_clamping"
17 | SAE_STEERING_VECTOR = "sae_steering_vector"
18 |
19 |
20 | @dataclass
21 | class EvalConfig:
22 | random_seed: int = 42
23 |
24 | # Enum isn't serializable to JSON, so we use the value attribute
25 | intervention_types: list[str] = field(
26 | default_factory=lambda: [
27 | InterventionType.CLAMPING.value,
28 | InterventionType.CONDITIONAL_CLAMPING.value,
29 | InterventionType.CONSTANT_SAE.value,
30 | InterventionType.CONSTANT_STEERING_VECTOR.value,
31 | # InterventionType.CONDITIONAL_PER_INPUT.value,
32 | InterventionType.CONDITIONAL_PER_TOKEN.value,
33 | InterventionType.CONDITIONAL_STEERING_VECTOR.value,
34 | InterventionType.SAE_STEERING_VECTOR.value,
35 | InterventionType.PROBE_STEERING_VECTOR.value,
36 | InterventionType.PROBE_SAE.value,
37 | ]
38 | )
39 |
40 | model_name: str = "meta-llama/Llama-3.1-8B-Instruct" # "google/gemma-2-9b-it"
41 | # scales: list[int] = field(default_factory=lambda: [-10, -20, -40, -80, -160])
42 |     scales: list[int] = field(default_factory=lambda: [-2, -3, -4, -5, -6, -7, -8, -10, -20, -40])
43 | batch_size: int = 20
44 | total_generations: int = (200 // batch_size) * batch_size
45 | max_new_tokens: int = (
46 | 400 # This needs to be high enough that we reach the end of the code block
47 | )
48 |
49 | prompt_filename: str = "prompt.txt" # prompt.txt
50 | docs_filename: str = "pytest_docs.txt"
51 | code_filename: str = "code.json"
52 | contrastive_prompts_filename: str = "contrastive_prompts.json"
53 | probe_prompts_filename: str = "probe_prompts.json"
54 |
55 | prefill: str = "```python\n"
56 | prompt_folder: str = "src/prompts"
57 | prompt_type: str = "regex"
58 |
59 | encoder_threshold_bias: float = 0.0
60 | steering_vector_threshold_bias: float = -150.0
61 |
62 | use_llm_judge: bool = False # True
63 |
64 | save_path: str = "gemma.json"
65 |
--------------------------------------------------------------------------------
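
EvalConfig is a plain dataclass, so a run can override fields in place and persist the result next to its outputs, as count_activations.py does via asdict. A short sketch; the output filenames are assumptions.

```python
# Minimal sketch: overriding EvalConfig fields for a run and saving the config as JSON.
import json
from dataclasses import asdict

from src.eval_config import EvalConfig, InterventionType

config = EvalConfig()
config.model_name = "google/gemma-2-9b-it"  # default above is the Llama-3.1-8B-Instruct string
config.intervention_types = [InterventionType.CONDITIONAL_PER_TOKEN.value]
config.scales = [-4, -8]
config.save_path = "gemma_conditional.json"  # assumption: per-run output name

with open("run_config.json", "w") as f:  # "run_config.json" is an illustrative filename
    json.dump(asdict(config), f, indent=2)  # enum values are stored as strings, so this serializes cleanly
```
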
/src/prompts/activation_counting_code.json:
--------------------------------------------------------------------------------
1 | {"not_regex_1": "def calculate_average(numbers: list) -> float:\n \"\"\"Calculate and return the average of a list of numbers.\"\"\"\n if not numbers:\n return 0.0\n return sum(numbers) / len(numbers)\n\n", "regex_1": "import re\n\ndef extract_emails(text: str) -> tuple:\n emails = re.findall(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}', text)\n return text, emails\n\n", "question_1": "How do I optimize the performance of a Python program that processes large CSV files?\n", "not_regex_2": "\n\ndef factorial(n: int) -> int:\n \"\"\"Calculate the factorial of a number.\"\"\"\n if n < 0:\n raise ValueError(\"Factorial is not defined for negative numbers\")\n result = 1\n for i in range(1, n + 1):\n result *= i\n return result\n\n", "regex_2": "\n\nimport re\n\ndef highlight_longest_word(text: str) -> str:\n # Match and highlight the longest word\n longest_word = max(re.findall(r'\\b\\w+\\b', text), key=len, default=\"\")\n if longest_word:\n text = re.sub(rf'\\b{longest_word}\\b', f'{longest_word}', text)\n return text\n\n\n", "question_2": "\nCan you explain the difference between a correlation and a causation in simple terms?\n", "not_regex_3": "\n\ndef flatten_list(nested_list: list) -> list:\n \"\"\"Flatten a nested list into a single list.\"\"\"\n flattened = []\n for item in nested_list:\n if isinstance(item, list):\n flattened.extend(flatten_list(item))\n else:\n flattened.append(item)\n return flattened\n\n", "regex_3": "\n\nimport re\n\ndef extract_hashtags(text: str) -> tuple:\n # Extract all hashtags in the format #example or #Example123\n hashtags = re.findall(r'#\\w+', text)\n return text, hashtags\n\n", "question_3": "\nWhat are the key differences between renewable and non-renewable energy sources?\n", "not_regex_4": "\n\ndef fibonacci(n: int) -> list:\n \"\"\"Return the first n numbers in the Fibonacci sequence.\"\"\"\n if n <= 0:\n return []\n sequence = [0, 1]\n for _ in range(2, n):\n sequence.append(sequence[-1] + sequence[-2])\n return sequence[:n]\n\n", "regex_4": "\n\nimport re\n\ndef redact_phone_numbers(text: str) -> tuple:\n # Find all phone numbers in common formats and replace them with [REDACTED]\n text = re.sub(r'\\b(\\+?\\d{1,2}[-.\\s]?)?(\\(?\\d{3}\\)?[-.\\s]?)\\d{3}[-.\\s]?\\d{4}\\b', '[REDACTED]', text)\n return text\n\n", "question_4": "\nCan you help me draft a polite email to reschedule a meeting for next week?\n", "not_regex_5": "\n\ndef is_palindrome(word: str) -> bool:\n \"\"\"Check if a word is a palindrome.\"\"\"\n word = word.lower().replace(\" \", \"\")\n return word == word[::-1]\n\n", "regex_5": "\n\nimport re\n\ndef anonymize_person_names(text: str) -> tuple:\n # Replace names in \"Firstname Lastname\" format with initials only, e.g., \"John Doe\" -> \"J. D.\"\n text = re.sub(r'\\b([A-Z][a-z]+) ([A-Z][a-z]+)\\b', lambda m: f'{m.group(1)[0]}. 
{m.group(2)[0]}.', text)\n return text\n\n", "question_5": "\nWhy is my Python script throwing a 'KeyError' when accessing a dictionary?\n", "not_regex_6": "\n\ndef merge_dictionaries(dict1: dict, dict2: dict) -> dict:\n \"\"\"Merge two dictionaries, with dict2 overwriting dict1's keys if there are conflicts.\"\"\"\n return {**dict1, **dict2}\n\n", "regex_6": "\n\nimport re\n\ndef find_serial_numbers(text: str) -> tuple:\n # Identify serial numbers in the format \"SN-XXXX-YYYY-ZZZZ\" where X, Y, and Z are alphanumeric\n serials = re.findall(r'\\bSN-[A-Z0-9]{4}-[A-Z0-9]{4}-[A-Z0-9]{4}\\b', text)\n return text, serials\n\n", "question_6": "\nCan you suggest some ideas for a short story set on a distant planet?\n", "not_regex_7": "\n\ndef factorial(n: int) -> int:\n \"\"\"Calculate the factorial of a number.\"\"\"\n if n < 0:\n raise ValueError(\"Factorial is not defined for negative numbers\")\n result = 1\n for i in range(1, n + 1):\n result *= i\n return result", "regex_7": "\n\nimport re\n\ndef obfuscate_credit_cards(text: str) -> tuple:\n # Find and obfuscate credit card numbers, e.g., \"4111-1111-1111-1111\" becomes \"****-****-****-1111\"\n text = re.sub(r'\\b(?:\\d{4}[-\\s]?){3}\\d{4}\\b', lambda m: '****-****-****-' + m.group(0)[-4:], text)\n return text\n", "question_7": "\nHow do I calculate the median of a dataset in Python using numpy?"}
--------------------------------------------------------------------------------
/src/prompts/code.json:
--------------------------------------------------------------------------------
1 | {"1": "import re\n\ndef extract_emails(text: str) -> tuple:\n emails = re.findall(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}', text)\n return text, emails\n\n", "2": "\n\nimport re\n\ndef highlight_longest_word(text: str) -> str:\n # Match and highlight the longest word\n longest_word = max(re.findall(r'\\b\\w+\\b', text), key=len, default=\"\")\n if longest_word:\n text = re.sub(rf'\\b{longest_word}\\b', f'{longest_word}', text)\n return text\n\n\n", "3": "\n\nimport re\n\ndef extract_hashtags(text: str) -> tuple:\n # Extract all hashtags in the format #example or #Example123\n hashtags = re.findall(r'#\\w+', text)\n return text, hashtags\n\n", "4": "\n\nimport re\n\ndef redact_phone_numbers(text: str) -> tuple:\n # Find all phone numbers in common formats and replace them with [REDACTED]\n text = re.sub(r'\\b(\\+?\\d{1,2}[-.\\s]?)?(\\(?\\d{3}\\)?[-.\\s]?)\\d{3}[-.\\s]?\\d{4}\\b', '[REDACTED]', text)\n return text\n\n", "5": "\n\nimport re\n\ndef anonymize_person_names(text: str) -> tuple:\n # Replace names in \"Firstname Lastname\" format with initials only, e.g., \"John Doe\" -> \"J. D.\"\n text = re.sub(r'\\b([A-Z][a-z]+) ([A-Z][a-z]+)\\b', lambda m: f'{m.group(1)[0]}. {m.group(2)[0]}.', text)\n return text\n\n", "6": "\n\nimport re\n\ndef find_serial_numbers(text: str) -> tuple:\n # Identify serial numbers in the format \"SN-XXXX-YYYY-ZZZZ\" where X, Y, and Z are alphanumeric\n serials = re.findall(r'\\bSN-[A-Z0-9]{4}-[A-Z0-9]{4}-[A-Z0-9]{4}\\b', text)\n return text, serials\n\n", "7": "\n\nimport re\n\ndef obfuscate_credit_cards(text: str) -> tuple:\n # Find and obfuscate credit card numbers, e.g., \"4111-1111-1111-1111\" becomes \"****-****-****-1111\"\n text = re.sub(r'\\b(?:\\d{4}[-\\s]?){3}\\d{4}\\b', lambda m: '****-****-****-' + m.group(0)[-4:], text)\n return text\n"}
--------------------------------------------------------------------------------
/src/prompts/code.txt:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | def extract_emails(text: str) -> tuple:
4 | emails = re.findall(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}', text)
5 | return text, emails
6 |
7 |
8 |
9 | import re
10 |
11 | def highlight_longest_word(text: str) -> str:
12 | # Match and highlight the longest word
13 | longest_word = max(re.findall(r'\b\w+\b', text), key=len, default="")
14 | if longest_word:
15 | text = re.sub(rf'\b{longest_word}\b', f'{longest_word}', text)
16 | return text
17 |
18 |
19 |
20 |
21 | import re
22 |
23 | def extract_hashtags(text: str) -> tuple:
24 | # Extract all hashtags in the format #example or #Example123
25 | hashtags = re.findall(r'#\w+', text)
26 | return text, hashtags
27 |
28 |
29 |
30 | import re
31 |
32 | def redact_phone_numbers(text: str) -> tuple:
33 | # Find all phone numbers in common formats and replace them with [REDACTED]
34 | text = re.sub(r'\b(\+?\d{1,2}[-.\s]?)?(\(?\d{3}\)?[-.\s]?)\d{3}[-.\s]?\d{4}\b', '[REDACTED]', text)
35 | return text
36 |
37 |
38 |
39 | import re
40 |
41 | def anonymize_person_names(text: str) -> tuple:
42 | # Replace names in "Firstname Lastname" format with initials only, e.g., "John Doe" -> "J. D."
43 | text = re.sub(r'\b([A-Z][a-z]+) ([A-Z][a-z]+)\b', lambda m: f'{m.group(1)[0]}. {m.group(2)[0]}.', text)
44 | return text
45 |
46 |
47 |
48 | import re
49 |
50 | def find_serial_numbers(text: str) -> tuple:
51 | # Identify serial numbers in the format "SN-XXXX-YYYY-ZZZZ" where X, Y, and Z are alphanumeric
52 | serials = re.findall(r'\bSN-[A-Z0-9]{4}-[A-Z0-9]{4}-[A-Z0-9]{4}\b', text)
53 | return text, serials
54 |
55 |
56 |
57 | import re
58 |
59 | def obfuscate_credit_cards(text: str) -> tuple:
60 | # Find and obfuscate credit card numbers, e.g., "4111-1111-1111-1111" becomes "****-****-****-1111"
61 | text = re.sub(r'\b(?:\d{4}[-\s]?){3}\d{4}\b', lambda m: '****-****-****-' + m.group(0)[-4:], text)
62 | return text
63 |
--------------------------------------------------------------------------------
/src/prompts/code_hard.txt:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | def extract_book_titles(text: str) -> tuple:
4 | # Matches book titles in double quotes, excluding titles following punctuation or parentheses.
5 | titles = re.findall(r'(?
9 |
10 | import re
11 |
12 | def normalize_dates(text: str) -> str:
13 | # Normalize dates to the format YYYY-MM-DD
14 | text = re.sub(r'\b(\d{1,2})[./-](\d{1,2})[./-](\d{4})\b', lambda m: f'{m.group(3)}-{int(m.group(2)):02}-{int(m.group(1)):02}', text)
15 | return text
16 |
17 |
18 |
19 | import re
20 |
21 | def highlight_longest_word(text: str) -> str:
22 | # Match and highlight the longest word
23 | longest_word = max(re.findall(r'\b\w+\b', text), key=len, default="")
24 | if longest_word:
25 | text = re.sub(rf'\b{longest_word}\b', f'{longest_word}', text)
26 | return text
27 |
28 |
29 |
30 | import re
31 |
32 | def identify_acronyms(text: str) -> tuple:
33 | # Extract acronyms from the text
34 | acronyms = re.findall(r'\b(?:[A-Z]\.?){2,}\b', text)
35 | return text, acronyms
36 |
37 |
38 |
39 | import re
40 |
41 | def camel_to_snake(text: str) -> str:
42 | # Replace CamelCase/PascalCase with snake_case
43 | text = re.sub(r'(?
47 |
48 | import re
49 |
50 | def find_palindromes(text: str) -> tuple:
51 | # Match palindromic words of at least 4 characters
52 | palindromes = [word for word in re.findall(r'\b\w{4,}\b', text) if word.lower() == word.lower()[::-1]]
53 | return text, palindromes
54 |
55 |
--------------------------------------------------------------------------------
/src/prompts/contrastive_prompts.json:
--------------------------------------------------------------------------------
1 | {
2 | "regex": {
3 | "base": "Write a custom python fuzzing function to test the provided code.",
4 | "pos": "Use regular expressions in your test.",
5 | "neg": ""
6 | }
7 | }
8 |
9 |
--------------------------------------------------------------------------------
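
The single "regex" entry above pairs a positive instruction with an empty negative one, so the contrastive prompts differ only in the sentence "Use regular expressions in your test." The sketch below reproduces the raw strings that format_contrastive_prompt in src/caa.py assembles from this file before the chat template is applied; the code snippet is illustrative, not taken from the repository.

```python
# Sketch of the pre-chat-template strings built from contrastive_prompts.json (run from the repo root).
import json

with open("src/prompts/contrastive_prompts.json") as f:
    prompts = json.load(f)

code_block = "def add(a, b):\n    return a + b\n"  # illustrative snippet

for polarity in ("pos", "neg"):
    text = f"{prompts['regex']['base']}\n{prompts['regex'][polarity]}\n"
    text += "```python\n" + code_block + "```\n"
    print(f"--- {polarity} ---\n{text}")
```
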
/src/prompts/longer_pytest_docs.txt:
--------------------------------------------------------------------------------
1 | API Reference
2 | This page contains the full reference to pytest’s API.
3 |
4 | Constants
5 | pytest.__version__
6 | The current pytest version, as a string:
7 |
8 | import pytest
9 | pytest.__version__
10 | '7.0.0'
11 | pytest.version_tuple
12 | Added in version 7.0.
13 |
14 | The current pytest version, as a tuple:
15 |
16 | import pytest
17 | pytest.version_tuple
18 | (7, 0, 0)
19 | For pre-releases, the last component will be a string with the prerelease version:
20 |
21 | import pytest
22 | pytest.version_tuple
23 | (7, 0, '0rc1')
24 | Functions
25 | pytest.approx
26 | approx(expected, rel=None, abs=None, nan_ok=False)[source]
27 | Assert that two numbers (or two ordered sequences of numbers) are equal to each other within some tolerance.
28 |
29 | Due to the Floating-Point Arithmetic: Issues and Limitations, numbers that we would intuitively expect to be equal are not always so:
30 |
31 | 0.1 + 0.2 == 0.3
32 | False
33 | This problem is commonly encountered when writing tests, e.g. when making sure that floating-point values are what you expect them to be. One way to deal with this problem is to assert that two floating-point numbers are equal to within some appropriate tolerance:
34 |
35 | abs((0.1 + 0.2) - 0.3) < 1e-6
36 | True
37 | However, comparisons like this are tedious to write and difficult to understand. Furthermore, absolute comparisons like the one above are usually discouraged because there’s no tolerance that works well for all situations. 1e-6 is good for numbers around 1, but too small for very big numbers and too big for very small ones. It’s better to express the tolerance as a fraction of the expected value, but relative comparisons like that are even more difficult to write correctly and concisely.
38 |
39 | The approx class performs floating-point comparisons using a syntax that’s as intuitive as possible:
40 |
41 | from pytest import approx
42 | 0.1 + 0.2 == approx(0.3)
43 | True
44 | The same syntax also works for ordered sequences of numbers:
45 |
46 | (0.1 + 0.2, 0.2 + 0.4) == approx((0.3, 0.6))
47 | True
48 | numpy arrays:
49 |
50 | import numpy as np
51 | np.array([0.1, 0.2]) + np.array([0.2, 0.4]) == approx(np.array([0.3, 0.6]))
52 | True
53 | And for a numpy array against a scalar:
54 |
55 | import numpy as np
56 | np.array([0.1, 0.2]) + np.array([0.2, 0.1]) == approx(0.3)
57 | True
58 | Only ordered sequences are supported, because approx needs to infer the relative position of the sequences without ambiguity. This means sets and other unordered sequences are not supported.
59 |
60 | Finally, dictionary values can also be compared:
61 |
62 | {'a': 0.1 + 0.2, 'b': 0.2 + 0.4} == approx({'a': 0.3, 'b': 0.6})
63 | True
64 | The comparison will be true if both mappings have the same keys and their respective values match the expected tolerances.
65 |
66 | Tolerances
67 |
68 | By default, approx considers numbers within a relative tolerance of 1e-6 (i.e. one part in a million) of its expected value to be equal. This treatment would lead to surprising results if the expected value was 0.0, because nothing but 0.0 itself is relatively close to 0.0. To handle this case less surprisingly, approx also considers numbers within an absolute tolerance of 1e-12 of its expected value to be equal. Infinity and NaN are special cases. Infinity is only considered equal to itself, regardless of the relative tolerance. NaN is not considered equal to anything by default, but you can make it be equal to itself by setting the nan_ok argument to True. (This is meant to facilitate comparing arrays that use NaN to mean “no data”.)
69 |
70 | Both the relative and absolute tolerances can be changed by passing arguments to the approx constructor:
71 |
72 | 1.0001 == approx(1)
73 | False
74 | 1.0001 == approx(1, rel=1e-3)
75 | True
76 | 1.0001 == approx(1, abs=1e-3)
77 | True
78 | If you specify abs but not rel, the comparison will not consider the relative tolerance at all. In other words, two numbers that are within the default relative tolerance of 1e-6 will still be considered unequal if they exceed the specified absolute tolerance. If you specify both abs and rel, the numbers will be considered equal if either tolerance is met:
79 |
80 | 1 + 1e-8 == approx(1)
81 | True
82 | 1 + 1e-8 == approx(1, abs=1e-12)
83 | False
84 | 1 + 1e-8 == approx(1, rel=1e-6, abs=1e-12)
85 | True
86 | You can also use approx to compare nonnumeric types, or dicts and sequences containing nonnumeric types, in which case it falls back to strict equality. This can be useful for comparing dicts and sequences that can contain optional values:
87 |
88 | {"required": 1.0000005, "optional": None} == approx({"required": 1, "optional": None})
89 | True
90 | [None, 1.0000005] == approx([None,1])
91 | True
92 | ["foo", 1.0000005] == approx([None,1])
93 | False
94 | If you’re thinking about using approx, then you might want to know how it compares to other good ways of comparing floating-point numbers. All of these algorithms are based on relative and absolute tolerances and should agree for the most part, but they do have meaningful differences:
95 |
96 | math.isclose(a, b, rel_tol=1e-9, abs_tol=0.0): True if the relative tolerance is met w.r.t. either a or b or if the absolute tolerance is met. Because the relative tolerance is calculated w.r.t. both a and b, this test is symmetric (i.e. neither a nor b is a “reference value”). You have to specify an absolute tolerance if you want to compare to 0.0 because there is no tolerance by default. More information: math.isclose().
97 |
98 | numpy.isclose(a, b, rtol=1e-5, atol=1e-8): True if the difference between a and b is less that the sum of the relative tolerance w.r.t. b and the absolute tolerance. Because the relative tolerance is only calculated w.r.t. b, this test is asymmetric and you can think of b as the reference value. Support for comparing sequences is provided by numpy.allclose(). More information: numpy.isclose.
99 |
100 | unittest.TestCase.assertAlmostEqual(a, b): True if a and b are within an absolute tolerance of 1e-7. No relative tolerance is considered , so this function is not appropriate for very large or very small numbers. Also, it’s only available in subclasses of unittest.TestCase and it’s ugly because it doesn’t follow PEP8. More information: unittest.TestCase.assertAlmostEqual().
101 |
102 | a == pytest.approx(b, rel=1e-6, abs=1e-12): True if the relative tolerance is met w.r.t. b or if the absolute tolerance is met. Because the relative tolerance is only calculated w.r.t. b, this test is asymmetric and you can think of b as the reference value. In the special case that you explicitly specify an absolute tolerance but not a relative tolerance, only the absolute tolerance is considered.
103 |
104 | Note
105 |
106 | approx can handle numpy arrays, but we recommend the specialised test helpers in Test support (numpy.testing) if you need support for comparisons, NaNs, or ULP-based tolerances.
107 |
108 | To match strings using regex, you can use Matches from the re_assert package.
109 |
110 | Warning
111 |
112 | Changed in version 3.2.
113 |
114 | In order to avoid inconsistent behavior, TypeError is raised for >, >=, < and <= comparisons. The example below illustrates the problem:
115 |
116 | assert approx(0.1) > 0.1 + 1e-10 # calls approx(0.1).__gt__(0.1 + 1e-10)
117 | assert 0.1 + 1e-10 > approx(0.1) # calls approx(0.1).__lt__(0.1 + 1e-10)
118 | In the second example one expects approx(0.1).__le__(0.1 + 1e-10) to be called. But instead, approx(0.1).__lt__(0.1 + 1e-10) is used to comparison. This is because the call hierarchy of rich comparisons follows a fixed behavior. More information: object.__ge__()
119 |
120 | Changed in version 3.7.1: approx raises TypeError when it encounters a dict value or sequence element of nonnumeric type.
121 |
122 | Changed in version 6.1.0: approx falls back to strict equality for nonnumeric types instead of raising TypeError.
123 |
124 | pytest.fail
125 | Tutorial: How to use skip and xfail to deal with tests that cannot succeed
126 |
127 | fail(reason[, pytrace=True, msg=None])[source]
128 | Explicitly fail an executing test with the given message.
129 |
130 | Parameters:
131 | reason (str) – The message to show the user as reason for the failure.
132 |
133 | pytrace (bool) – If False, msg represents the full failure information and no python traceback will be reported.
134 |
135 | Raises:
136 | pytest.fail.Exception – The exception that is raised.
--------------------------------------------------------------------------------
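
The tolerance rules quoted in the excerpt above condense into a few concrete assertions. The test below simply restates the documented examples in runnable form; it is not part of this repository.

```python
# Runnable restatement of the pytest.approx tolerance examples from the documentation excerpt above.
from pytest import approx


def test_default_tolerances():
    # Default: relative tolerance 1e-6 with an absolute fallback of 1e-12.
    assert 0.1 + 0.2 == approx(0.3)
    assert 1 + 1e-8 == approx(1)


def test_explicit_tolerances():
    assert 1.0001 == approx(1, rel=1e-3)
    assert 1.0001 == approx(1, abs=1e-3)
    # Specifying only abs disables the default relative tolerance.
    assert 1 + 1e-8 != approx(1, abs=1e-12)
```
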
/src/prompts/make_json_prompts.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import json\n",
10 | "\n",
11 | "filename = \"not_regex_code.txt\"\n",
12 | "new_filename = filename.replace(\".txt\", \".json\")\n",
13 | "\n",
14 | "code = open(filename, \"r\").read().split(\"\")\n",
15 | "\n",
16 | "print(code)\n"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": null,
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "import json\n",
26 | "\n",
27 | "filename = \"not_regex_code.txt\"\n",
28 | "new_filename = filename.replace(\".txt\", \".json\")\n",
29 | "\n",
30 | "code = open(filename, \"r\").read().split(\"\")\n",
31 | "\n",
32 | "print(code)\n"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": null,
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 | "data = {f\"{i+1}\": code[i] for i in range(len(code))}\n",
42 | "\n",
43 | "with open(new_filename, \"a\") as f:\n",
44 | " json.dump(data, f)"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": null,
50 | "metadata": {},
51 | "outputs": [],
52 | "source": [
53 | "with open(new_filename, \"r\") as f:\n",
54 | " data = json.load(f)\n",
55 | "\n",
56 | "print(data[\"1\"])"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "import json\n",
66 | "\n",
67 | "not_regex_filename = \"not_regex_code.txt\"\n",
68 | "new_filename = \"activation_counting_code.json\"\n",
69 | "\n",
70 | "not_regex_code = open(not_regex_filename, \"r\").read().split(\"\")\n",
71 | "\n",
72 | "print(not_regex_code)\n",
73 | "\n",
74 | "regex_code = open(\"code.txt\", \"r\").read().split(\"\")\n",
75 | "questions = open(\"user_questions.txt\", \"r\").read().split(\"\")\n",
76 | "\n",
77 | "data = {}\n",
78 | "\n",
79 | "for i in range(len(not_regex_code)):\n",
80 | " data[f\"not_regex_{i+1}\"] = not_regex_code[i]\n",
81 | " data[f\"regex_{i+1}\"] = regex_code[i]\n",
82 | "\n",
83 | "with open(new_filename, \"w\") as f:\n",
84 | " json.dump(data, f)"
85 | ]
86 | }
87 | ],
88 | "metadata": {
89 | "kernelspec": {
90 | "display_name": "base",
91 | "language": "python",
92 | "name": "python3"
93 | },
94 | "language_info": {
95 | "codemirror_mode": {
96 | "name": "ipython",
97 | "version": 3
98 | },
99 | "file_extension": ".py",
100 | "mimetype": "text/x-python",
101 | "name": "python",
102 | "nbconvert_exporter": "python",
103 | "pygments_lexer": "ipython3",
104 | "version": "3.10.13"
105 | }
106 | },
107 | "nbformat": 4,
108 | "nbformat_minor": 2
109 | }
110 |
--------------------------------------------------------------------------------
/src/prompts/not_regex_code.txt:
--------------------------------------------------------------------------------
1 | def calculate_average(numbers: list) -> float:
2 | """Calculate and return the average of a list of numbers."""
3 | if not numbers:
4 | return 0.0
5 | return sum(numbers) / len(numbers)
6 |
7 |
8 |
9 | def factorial(n: int) -> int:
10 | """Calculate the factorial of a number."""
11 | if n < 0:
12 | raise ValueError("Factorial is not defined for negative numbers")
13 | result = 1
14 | for i in range(1, n + 1):
15 | result *= i
16 | return result
17 |
18 |
19 |
20 | def flatten_list(nested_list: list) -> list:
21 | """Flatten a nested list into a single list."""
22 | flattened = []
23 | for item in nested_list:
24 | if isinstance(item, list):
25 | flattened.extend(flatten_list(item))
26 | else:
27 | flattened.append(item)
28 | return flattened
29 |
30 |
31 |
32 | def fibonacci(n: int) -> list:
33 | """Return the first n numbers in the Fibonacci sequence."""
34 | if n <= 0:
35 | return []
36 | sequence = [0, 1]
37 | for _ in range(2, n):
38 | sequence.append(sequence[-1] + sequence[-2])
39 | return sequence[:n]
40 |
41 |
42 |
43 | def is_palindrome(word: str) -> bool:
44 | """Check if a word is a palindrome."""
45 | word = word.lower().replace(" ", "")
46 | return word == word[::-1]
47 |
48 |
49 |
50 | def merge_dictionaries(dict1: dict, dict2: dict) -> dict:
51 | """Merge two dictionaries, with dict2 overwriting dict1's keys if there are conflicts."""
52 | return {**dict1, **dict2}
53 |
54 |
55 |
56 | def factorial(n: int) -> int:
57 | """Calculate the factorial of a number."""
58 | if n < 0:
59 | raise ValueError("Factorial is not defined for negative numbers")
60 | result = 1
61 | for i in range(1, n + 1):
62 | result *= i
63 | return result
--------------------------------------------------------------------------------
/src/prompts/probe_prompts.json:
--------------------------------------------------------------------------------
1 | {
2 | "regex": {
3 | "base": "Write a custom python fuzzing function to test the provided code.",
4 | "pos": "Use regular expressions in your test.",
5 | "neg": "Absolutely do not use regular expressions in your test."
6 | }
7 | }
--------------------------------------------------------------------------------
/src/prompts/prompt.txt:
--------------------------------------------------------------------------------
1 | Follow the pasted instructions to analyze the given CODE. Use the included documentation to improve the quality of your code.
2 |
3 | [INSTRUCTIONS - STEP 1]
4 | Analyze the code in the CODE block. For each function named in the FUNCTIONS block, write:
5 | 1) The assumptions the dev made about the inputs to the function. E.g., did they assume certain types, or relationships between inputs? If they did not make any assumptions, say so.
6 | Make sure to include the type of each input -- whether explicit or implied -- and the relationships between them. This description should also include what role each input plays, i.e., what is the meaning of the input
7 | in the context of the function under test?
8 | 2) A logical property the dev intended to hold about the function which relates the function's inputs to its outputs. This needs to be amenable to property-based testing. Useful property types to consider
9 | include, e.g., inversion, round-trip, idempotency, monotonicity, type contracts, exception throwing, and run-time, to name a few. Obviously, don't consider complexity. The property should not be something like,
10 | "the function returns a correct output" since that is not meaningful (i.e., what does "correct" even mean?). So, the property should be clearly stated and unambiguous.
11 |
12 | Do not take into consideration any functions other than those named in the FUNCTIONS block. In general, write properties that relate the inputs of the function
13 | to its outputs, like, "the function square is monotonic" or "the function mult_neg_one is idempotent" or "the function do is the inverse of the function undo, meaning, do(undo(x)) = x for all x".
14 | When writing properties, only write ones that you're quite confident the dev intended to hold. If a property seems dubious, then don't include it.
15 |
16 | Write a custom fuzzing function to test the provided code.
17 |
18 | [BEGIN EXAMPLE]
19 | Given the following code as input:
20 |
21 | [CODE]
22 |
23 | ```python
24 | def calculatePercentage(total, amount):
25 | return total / amount * 100
26 | ```
27 |
28 | [BEGIN RESPONSE]
29 |
30 | ```python
31 | import pytest
32 |
33 | def test_calculatePercentage_fuzz():
34 | # A fuzz test for larger random values of total and amount
35 | for _ in range(1000): # Simulate 1000 random tests
36 | total = random.uniform(0, 1000)
37 | amount = random.uniform(1, 1000) # Ensure amount > 0
38 | result = calculatePercentage(total, amount)
39 | assert 0 <= result <= 100
40 | ```
41 |
42 | [END RESPONSE]
43 | [END EXAMPLE]
44 |
45 | [GUIDANCE]
46 | - Take into consideration the types of the functions under test (implied or explicit). E.g., if a function expects a non-NaN input, don't feed it a NaN.
47 | - Try to only have one assert per PBT.
48 | - If you can't figure out how to correctly generate something don't test it. Just skip that test.
49 | - Only write valid Python. E.g., if you want to use await make sure the function is async.
50 | - Don't use regexes.
51 | - If you're going to generate inputs that cause the code to throw an exception, and you know it, use a try/catch to assert that the exception happens, or modify the generated value to avoid the
52 | exception. E.g., if you use datetime.fromtimestamp(...) with an enormous input value, it'll throw an OverflowError. This is expected and you should either change how you generate the inputs to
53 | avoid it, or use a try/catch to make sure it happens where we expect it to happen.
54 | - Use some input validation in the PBTs to avoid inputs that are uninteresting / unreasonable.
55 | - Don't use "self". These are standalone PBTs, not part of some test class.
56 | - Write the tests such that they can be pasted at the end of the file containing the code under test.
57 | - When generating floats, carefully consider how to set the options allow_nan and allow_infinity, keeping in mind that e.g., in Python NaN != NaN.
58 |
59 |
60 | [DOCUMENTATION]
61 |
62 | {documentation}
63 |
64 | [CODE]
65 |
66 | ```python
67 | {code}
68 | ```
69 |
70 | [BEGIN RESPONSE]
--------------------------------------------------------------------------------
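
The guidance in prompt.txt asks for standalone fuzz tests with one assert each, no regex, and sensible input generation. The sketch below shows one possible compliant test for is_palindrome from not_regex_code.txt; the specific property tested is an illustrative assumption, not taken from the prompt.

```python
# Illustrative fuzz test in the style the prompt requests: standalone, one assert, no regex.
import random
import string


def is_palindrome(word: str) -> bool:
    """Copy of the function under test from src/prompts/not_regex_code.txt."""
    word = word.lower().replace(" ", "")
    return word == word[::-1]


def test_is_palindrome_fuzz():
    # Property: any string concatenated with its own reverse is a palindrome.
    for _ in range(1000):
        half = "".join(random.choice(string.ascii_letters) for _ in range(random.randint(0, 20)))
        assert is_palindrome(half + half[::-1])
```
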
/src/prompts/prompt_hard.txt:
--------------------------------------------------------------------------------
1 | Follow the pasted instructions to analyze the given CODE. Use the included documentation to improve the quality of your code.
2 |
3 | [INSTRUCTIONS - STEP 1]
4 | Analyze the code in the CODE block. For each function named in the FUNCTIONS block, write:
5 | 1) The assumptions the dev made about the inputs to the function. E.g., did they assume certain types, or relationships between inputs? If they did not make any assumptions, say so.
6 | Make sure to include the type of each input -- whether explicit or implied -- and the relationships between them. This description should also include what role each input plays, i.e., what is the meaning of the input
7 | in the context of the function under test?
8 | 2) A logical property the dev intended to hold about the function which relates the function's inputs to its outputs. This needs to be amenable to property-based testing. Useful property types to consider
9 | include, e.g., inversion, round-trip, idempotency, monotonicity, type contracts, exception throwing, and run-time, to name a few. Obviously, don't consider complexity. The property should not be something like,
10 | "the function returns a correct output" since that is not meaningful (i.e., what does "correct" even mean?). So, the property should be clearly stated and unambiguous.
11 |
12 | Do not take into consideration any functions other than those named in the FUNCTIONS block. In general, write properties that relate the inputs of the function
13 | to its outputs, like, "the function square is monotonic" or "the function mult_neg_one is idempotent" or "the function do is the inverse of the function undo, meaning, do(undo(x)) = x for all x".
14 | When writing properties, only write ones that you're quite confident the dev intended to hold. If a property seems dubious, then don't include it.
15 |
16 | Write a custom fuzzing function to test the provided code.
17 |
18 | [BEGIN EXAMPLE]
19 | Given the following code as input:
20 |
21 | [CODE]
22 |
23 | ```python
24 | def calculatePercentage(total, amount):
25 | return total / amount * 100
26 | ```
27 |
28 | [BEGIN RESPONSE]
29 |
30 | ```python
31 | import pytest
32 |
33 | def test_calculatePercentage_fuzz():
34 | # A fuzz test for larger random values of total and amount
35 | for _ in range(1000): # Simulate 1000 random tests
36 | total = random.uniform(0, 1000)
37 | amount = random.uniform(1, 1000) # Ensure amount > 0
38 | result = calculatePercentage(total, amount)
39 | assert 0 <= result <= 100
40 | ```
41 |
42 | [END RESPONSE]
43 | [END EXAMPLE]
44 |
45 | [GUIDANCE]
46 | - Take into consideration the types of the functions under test (implied or explicit). E.g., if a function expects a non-NaN input, don't feed it a NaN.
47 | - Try to only have one assert per PBT.
48 | - If you can't figure out how to correctly generate something don't test it. Just skip that test.
49 | - Only write valid Python. E.g., if you want to use await make sure the function is async.
50 |
51 | CRITICAL REQUIREMENT - DO NOT USE REGULAR EXPRESSIONS:
52 | Regular expressions (regex) and the re package are strictly prohibited in your solution for several important reasons:
53 | 1. Regular expressions are often difficult to read and maintain
54 | 2. They can introduce subtle bugs that are hard to detect
55 | 3. Performance can be unpredictable with regex
56 | 4. String operations in Python are clearer and more maintainable
57 | 5. Regex patterns can become complex and error-prone
58 | 6. Testing regex patterns requires additional test cases
59 | 7. The re package may not be available in all environments
60 | 8. String methods are more performant for simple operations
61 | 9. Code reviewers may struggle with complex regex patterns
62 | 10. Future maintenance becomes more challenging
63 |
64 | Instead, use Python's built-in string methods and operations such as:
65 | - str.split()
66 | - str.join()
67 | - str.find()
68 | - str.replace()
69 | - str.startswith()/endswith()
70 | - Slicing operations
71 | - Basic string concatenation
72 | - str.strip()
73 | - str.format()
74 |
75 | These methods are more readable, maintainable, and often more efficient for simple string operations.
76 |
77 | - If you're going to generate inputs that cause the code to throw an exception, and you know it, use a try/catch to assert that the exception happens, or modify the generated value to avoid the
78 | exception. E.g., if you use datetime.fromtimestamp(...) with an enormous input value, it'll throw an OverflowError. This is expected and you should either change how you generate the inputs to
79 | avoid it, or use a try/catch to make sure it happens where we expect it to happen.
80 | - Use some input validation in the PBTs to avoid inputs that are uninteresting / unreasonable.
81 | - Don't use "self". These are standalone PBTs, not part of some test class.
82 | - Write the tests such that they can be pasted at the end of the file containing the code under test.
83 | - When generating floats, carefully consider how to set the options allow_nan and allow_infinity, keeping in mind that e.g., in Python NaN != NaN.
84 |
85 |
86 | [DOCUMENTATION]
87 |
88 | {documentation}
89 |
90 | [CODE]
91 |
92 | ```python
93 | {code}
94 | ```
95 |
96 | [BEGIN RESPONSE]
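
The CRITICAL REQUIREMENT block above asks for plain string methods instead of regular expressions. The snippet below is an editorial sketch of that style, not part of the prompt; the parse_pairs function and its key=value input format are made up purely for illustration.

```python
# Hypothetical illustration of the regex-free style the prompt requires:
# pull "key=value" pairs out of a comma-separated string using built-in methods only.
def parse_pairs(text: str) -> dict[str, str]:
    pairs = {}
    for chunk in text.split(","):              # str.split instead of re.split
        chunk = chunk.strip()                  # str.strip instead of a pattern
        if "=" in chunk:
            key, _, value = chunk.partition("=")
            pairs[key.strip()] = value.strip()
    return pairs

assert parse_pairs("a=1, b = 2") == {"a": "1", "b": "2"}
```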
--------------------------------------------------------------------------------
/src/prompts/prompt_medium.txt:
--------------------------------------------------------------------------------
1 | Follow the pasted instructions to analyze the given CODE. Use the included documentation to improve the quality of your code.
2 |
3 | [INSTRUCTIONS - STEP 1]
4 | Analyze the code in the CODE block. For each function named in the FUNCTIONS block, write:
5 | 1) The assumptions the dev made about the inputs to the function. E.g., did they assume certain types, or relationships between inputs? If they did not make any assumptions, say so.
6 | Make sure to include the type of each input -- whether explicit or implied -- and the relationships between them. This description should also include what role each input plays, i.e., what is the meaning of the input
7 | in the context of the function under test?
8 | 2) A logical property the dev intended to hold about the function which relates the function's inputs to its outputs. This needs to be amenable to property-based testing. Useful property types to consider
9 | include, e.g., inversion, round-trip, idempotency, monotonicity, type contracts, exception throwing, and run-time, to name a few. Obviously, don't consider complexity. The property should not be something like,
10 | "the function returns a correct output" since that is not meaningful (i.e., what does "correct" even mean?). So, the property should be clearly stated and unambiguous.
11 |
12 | Do not take into consideration any functions other than those named in the FUNCTIONS block. In general, write properties that relate the inputs of the function
13 | to its outputs, like, "the function square is monotonic" or "the function mult_neg_one is idempotent" or "the function do is the inverse of the function undo, meaning, do(undo(x)) = x for all x".
14 | When writing properties, only write ones that you're quite confident the dev intended to hold. If a property seems dubious, then don't include it.
15 |
16 | Write a custom fuzzing function to test the provided code.
17 |
18 | [BEGIN EXAMPLE]
19 | Given the following code as input:
20 |
21 | [CODE]
22 |
23 | ```python
24 | def calculatePercentage(total, amount):
25 | return total / amount * 100
26 | ```
27 |
28 | [BEGIN RESPONSE]
29 |
30 | ```python
31 | import pytest
32 |
33 | def test_calculatePercentage_fuzz():
34 | # A fuzz test for larger random values of total and amount
35 | for _ in range(1000): # Simulate 1000 random tests
36 | total = random.uniform(0, 1000)
37 | amount = random.uniform(1, 1000) # Ensure amount > 0
38 | result = calculatePercentage(total, amount)
39 | assert 0 <= result <= 100
40 | ```
41 |
42 | [END RESPONSE]
43 | [END EXAMPLE]
44 |
45 | [GUIDANCE]
46 | - Take into consideration the types of the functions under test (implied or explicit). E.g., if a function expects a non-NaN input, don't feed it a NaN.
47 | - Try to only have one assert per PBT.
48 | - If you can't figure out how to correctly generate something don't test it. Just skip that test.
49 | - Only write valid Python. E.g., if you want to use await make sure the function is async.
50 | - IMPORTANT: Please do not use regular expressions (regex) or the re package in your solution. Regular expressions can be difficult to maintain and understand. Instead, use standard string operations and basic Python functionality.
51 | - If you're going to generate inputs that cause the code to throw an exception, and you know it, use a try/catch to assert that the exception happens, or modify the generated value to avoid the
52 | exception. E.g., if you use datetime.fromtimestamp(...) with an enormous input value, it'll throw an OverflowError. This is expected and you should either change how you generate the inputs to
53 | avoid it, or use a try/catch to make sure it happens where we expect it to happen.
54 | - Use some input validation in the PBTs to avoid inputs that are uninteresting / unreasonable.
55 | - Don't use "self". These are standalone PBTs, not part of some test class.
56 | - Write the tests such that they can be pasted at the end of the file containing the code under test.
57 | - When generating floats, carefully consider how to set the options allow_nan and allow_infinity, keeping in mind that e.g., in Python NaN != NaN.
58 |
59 |
60 | [DOCUMENTATION]
61 |
62 | {documentation}
63 |
64 | [CODE]
65 |
66 | ```python
67 | {code}
68 | ```
69 |
70 | [BEGIN RESPONSE]
71 |
--------------------------------------------------------------------------------
/src/prompts/prompt_no_regex.txt:
--------------------------------------------------------------------------------
1 | Follow the pasted instructions to analyze the given CODE. Use the included documentation to improve the quality of your code.
2 |
3 | [INSTRUCTIONS - STEP 1]
4 | Analyze the code in the CODE block. For each function named in the FUNCTIONS block, write:
5 | 1) The assumptions the dev made about the inputs to the function. E.g., did they assume certain types, or relationships between inputs? If they did not make any assumptions, say so.
6 | Make sure to include the type of each input -- whether explicit or implied -- and the relationships between them. This description should also include what role each input plays, i.e., what is the meaning of the input
7 | in the context of the function under test?
8 | 2) A logical property the dev intended to hold about the function which relates the function's inputs to its outputs. This needs to be amenable to property-based testing. Useful property types to consider
9 | include, e.g., inversion, round-trip, idempotency, monotonicity, type contracts, exception throwing, and run-time, to name a few. Obviously, don't consider complexity. The property should not be something like,
10 | "the function returns a correct output" since that is not meaningful (i.e., what does "correct" even mean?). So, the property should be clearly stated and unambiguous.
11 |
12 | Do not take into consideration any functions other than those named in the FUNCTIONS block. In general, write properties that relate the inputs of the function
13 | to its outputs, like, "the function square is monotonic" or "the function mult_neg_one is idempotent" or "the function do is the inverse of the function undo, meaning, do(undo(x)) = x for all x".
14 | When writing properties, only write ones that you're quite confident the dev intended to hold. If a property seems dubious, then don't include it.
15 |
16 | Write a custom fuzzing function to test the provided code.
17 |
18 | [BEGIN EXAMPLE]
19 | Given the following code as input:
20 |
21 | [CODE]
22 |
23 | ```python
24 | def calculatePercentage(total, amount):
25 | return total / amount * 100
26 | ```
27 |
28 | [BEGIN RESPONSE]
29 |
30 | ```python
31 | import pytest
32 |
33 | def test_calculatePercentage_fuzz():
34 | # A fuzz test for larger random values of total and amount
35 | for _ in range(1000): # Simulate 1000 random tests
36 | total = random.uniform(0, 1000)
37 | amount = random.uniform(1, 1000) # Ensure amount > 0
38 | result = calculatePercentage(total, amount)
39 | assert 0 <= result <= 100
40 | ```
41 |
42 | [END RESPONSE]
43 | [END EXAMPLE]
44 |
45 | [GUIDANCE]
46 | - Take into consideration the types of the functions under test (implied or explicit). E.g., if a function expects a non-NaN input, don't feed it a NaN.
47 | - Try to only have one assert per PBT.
48 | - If you can't figure out how to correctly generate something don't test it. Just skip that test.
49 | - Only write valid Python. E.g., if you want to use await make sure the function is async.
50 | - If you're going to generate inputs that cause the code to throw an exception, and you know it, use a try/catch to assert that the exception happens, or modify the generated value to avoid the
51 | exception. E.g., if you use datetime.fromtimestamp(...) with an enormous input value, it'll throw an OverflowError. This is expected and you should either change how you generate the inputs to
52 | avoid it, or use a try/catch to make sure it happens where we expect it to happen.
53 | - Use some input validation in the PBTs to avoid inputs that are uninteresting / unreasonable.
54 | - Don't use "self". These are standalone PBTs, not part of some test class.
55 | - Write the tests such that they can be pasted at the end of the file containing the code under test.
56 | - When generating floats, carefully consider how to set the options allow_nan and allow_infinity, keeping in mind that e.g., in Python NaN != NaN.
57 |
58 |
59 | [DOCUMENTATION]
60 |
61 | {documentation}
62 |
63 | [CODE]
64 |
65 | ```python
66 | {code}
67 | ```
68 |
69 | [BEGIN RESPONSE]
--------------------------------------------------------------------------------
/src/prompts/prompt_repetitions.txt:
--------------------------------------------------------------------------------
1 | Follow the pasted instructions to analyze the given CODE. Use the included documentation to improve the quality of your code.
2 |
3 | [INSTRUCTIONS - STEP 1]
4 | Analyze the code in the CODE block. For each function named in the FUNCTIONS block, write:
5 | 1) The assumptions the dev made about the inputs to the function. E.g., did they assume certain types, or relationships between inputs? If they did not make any assumptions, say so.
6 | Make sure to include the type of each input -- whether explicit or implied -- and the relationships between them. This description should also include what role each input plays, i.e., what is the meaning of the input
7 | in the context of the function under test?
8 | 2) A logical property the dev intended to hold about the function which relates the function's inputs to its outputs. This needs to be amenable to property-based testing. Useful property types to consider
9 | include, e.g., inversion, round-trip, idempotency, monotonicity, type contracts, exception throwing, and run-time, to name a few. Obviously, don't consider complexity. The property should not be something like,
10 | "the function returns a correct output" since that is not meaningful (i.e., what does "correct" even mean?). So, the property should be clearly stated and unambiguous.
11 |
12 | Do not take into consideration any functions other than those named in the FUNCTIONS block. In general, write properties that relate the inputs of the function
13 | to its outputs, like, "the function square is monotonic" or "the function mult_neg_one is idempotent" or "the function do is the inverse of the function undo, meaning, do(undo(x)) = x for all x".
14 | When writing properties, only write ones that you're quite confident the dev intended to hold. If a property seems dubious, then don't include it.
15 |
16 | Write a custom fuzzing function to test the provided code.
17 |
18 | [BEGIN EXAMPLE]
19 | Given the following code as input:
20 |
21 | [CODE]
22 |
23 | ```python
24 | def calculatePercentage(total, amount):
25 | return total / amount * 100
26 | ```
27 |
28 | [BEGIN RESPONSE]
29 |
30 | ```python
31 | import pytest
32 |
33 | def test_calculatePercentage_fuzz():
34 | # A fuzz test for larger random values of total and amount
35 | for _ in range(1000): # Simulate 1000 random tests
36 | total = random.uniform(0, 1000)
37 | amount = random.uniform(1, 1000) # Ensure amount > 0
38 | result = calculatePercentage(total, amount)
39 | assert 0 <= result <= 100
40 | ```
41 |
42 | [END RESPONSE]
43 | [END EXAMPLE]
44 |
45 | [GUIDANCE]
46 | - Take into consideration the types of the functions under test (implied or explicit). E.g., if a function expects a non-NaN input, don't feed it a NaN.
47 | - Try to only have one assert per PBT.
48 | - If you can't figure out how to correctly generate something don't test it. Just skip that test.
49 | - Only write valid Python. E.g., if you want to use await make sure the function is async.
50 | - Pretty please I'm begging you don't use regex's.
51 | - If you're going to generate inputs that cause the code to throw an exception, and you know it, use a try/catch to assert that the exception happens, or modify the generated value to avoid the
52 | exception. E.g., if you use datetime.fromtimestamp(...) with an enormous input value, it'll throw an OverflowError. This is expected and you should either change how you generate the inputs to
53 | avoid it, or use a try/catch to make sure it happens where we expect it to happen.
54 | - Use some input validation in the PBTs to avoid inputs that are uninteresting / unreasonable.
55 | - Don't use "self". These are standalone PBTs, not part of some test class.
56 | - Write the tests such that they can be pasted at the end of the file containing the code under test.
57 | - When generating floats, carefully consider how to set the options allow_nan and allow_infinity, keeping in mind that e.g., in Python NaN != NaN.
58 |
59 |
60 | [DOCUMENTATION]
61 |
62 | {documentation}
63 |
64 | [GUIDANCE]
65 | - Take into consideration the types of the functions under test (implied or explicit). E.g., if a function expects a non-NaN input, don't feed it a NaN.
66 | - Try to only have one assert per PBT.
67 | - If you can't figure out how to correctly generate something don't test it. Just skip that test.
68 | - Only write valid Python. E.g., if you want to use await make sure the function is async.
69 | - Pretty please I'm begging you don't use regex's.
70 | - If you're going to generate inputs that cause the code to throw an exception, and you know it, use a try/catch to assert that the exception happens, or modify the generated value to avoid the
71 | exception. E.g., if you use datetime.fromtimestamp(...) with an enormous input value, it'll throw an OverflowError. This is expected and you should either change how you generate the inputs to
72 | avoid it, or use a try/catch to make sure it happens where we expect it to happen.
73 | - Use some input validation in the PBTs to avoid inputs that are uninteresting / unreasonable.
74 | - Don't use "self". These are standalone PBTs, not part of some test class.
75 | - Write the tests such that they can be pasted at the end of the file containing the code under test.
76 | - When generating floats, carefully consider how to set the options allow_nan and allow_infinity, keeping in mind that e.g., in Python NaN != NaN.
77 |
78 |
79 | [CODE]
80 |
81 | ```python
82 | {code}
83 | ```
84 |
85 | [BEGIN RESPONSE]
--------------------------------------------------------------------------------
/src/prompts/pytest_docs.txt:
--------------------------------------------------------------------------------
1 | API Reference
2 | This page contains the full reference to pytest’s API.
3 |
4 | Constants
5 | pytest.__version__
6 | The current pytest version, as a string:
7 |
8 | import pytest
9 | pytest.__version__
10 | '7.0.0'
11 | pytest.version_tuple
12 | Added in version 7.0.
13 |
14 | The current pytest version, as a tuple:
15 |
16 | import pytest
17 | pytest.version_tuple
18 | (7, 0, 0)
19 | For pre-releases, the last component will be a string with the prerelease version:
--------------------------------------------------------------------------------
/src/prompts/user_questions.txt:
--------------------------------------------------------------------------------
1 | How do I optimize the performance of a Python program that processes large CSV files?
2 |
3 | Can you explain the difference between a correlation and a causation in simple terms?
4 |
5 | What are the key differences between renewable and non-renewable energy sources?
6 |
7 | Can you help me draft a polite email to reschedule a meeting for next week?
8 |
9 | Why is my Python script throwing a 'KeyError' when accessing a dictionary?
10 |
11 | Can you suggest some ideas for a short story set on a distant planet?
12 |
13 | How do I calculate the median of a dataset in Python using numpy?
--------------------------------------------------------------------------------
/src/regex_interventions.py:
--------------------------------------------------------------------------------
1 | """
2 | Regex Interventions Module
3 |
4 | This module implements testing and evaluation of model interventions specifically
5 | focused on regex pattern usage in generated code. It provides functionality to:
6 | 1. Run controlled experiments with different intervention types
7 | 2. Measure regex usage in generated code
8 | 3. Evaluate code quality and correctness
9 | 4. Compare baseline and intervention results
10 | """
11 |
12 | import torch
13 | from transformers import AutoTokenizer
14 | from dataclasses import asdict
15 | from tqdm import tqdm
16 | from typing import Callable, Dict, List, Tuple, Union
17 | import json
18 | import time
19 | import asyncio
20 | import re
21 | import sys
22 | import os
23 | import logging
24 |
25 | logging.basicConfig(level=logging.INFO)
26 |
27 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
28 |
29 | from src.wrapper import InterventionWrapper
30 | from src.eval_config import EvalConfig, InterventionType
31 | import src.utils as utils
32 | import src.caa as caa
33 | from src.agent_eval import batch_evaluate_generations, summarize_evaluations
34 |
35 |
36 | def measure_regex_usage(
37 | responses: list[str],
38 | prompt: str,
39 | prefill: str
40 | ) -> tuple[int, int, int, int]:
41 | """Analyze generated code for regex usage and validity.
42 |
43 | Args:
44 | responses: List of generated text responses
45 | prompt: Original prompt used for generation
46 | prefill: Prefix text to remove from responses
47 |
48 | Returns:
49 | Tuple containing counts of:
50 | - Valid Python code snippets
51 | - Snippets using regex
52 | - Syntactically valid snippets
53 | - Semantically valid snippets
54 | """
55 | valid_python_count = 0
56 | syntactically_valid_python_count = 0
57 | semantically_valid_python_count = 0
58 | regex_usage_count = 0
59 |
60 | for i, response in tqdm(enumerate(responses), desc="Measuring regex usage"):
61 | response = re.sub(r'(? list[str]:
88 | """Clean and extract actual responses from model outputs.
89 |
90 | Args:
91 | responses: Raw model outputs
92 | prompt: Original prompt to remove
93 | prefill: Prefix text to remove
94 |
95 | Returns:
96 | List of cleaned response texts
97 | """
98 | extracted_responses = []
99 | for i, response in enumerate(responses):
100 | response = re.sub(r'(? list[str]:
114 | """Generate multiple responses without interventions.
115 |
116 | Args:
117 | wrapper: Model wrapper instance
118 | prompt: Input prompt
119 | batch_size: Number of generations per batch
120 | total_generations: Total number of responses to generate
121 | max_new_tokens: Maximum new tokens per generation
122 |
123 | Returns:
124 | List of generated responses
125 | """
126 | batched_prompts = [prompt] * batch_size
127 | num_batches = total_generations // batch_size
128 | generations = []
129 |
130 | for _ in tqdm(range(num_batches), desc="Generating responses"):
131 | response = wrapper.generate(batched_prompts, max_new_tokens=max_new_tokens)
132 | generations.extend(response)
133 |
134 | return generations
135 |
136 |
137 | def run_intervention(
138 | wrapper: InterventionWrapper,
139 | prompt: str,
140 | batch_size: int,
141 | total_generations: int,
142 | max_new_tokens: int,
143 | intervention_type: str,
144 | model_params: dict,
145 | scale: int,
146 | config: EvalConfig,
147 | ) -> list[str]:
148 | """Generate responses with specified intervention.
149 |
150 | Args:
151 | wrapper: Model wrapper instance
152 | prompt: Input prompt
153 | batch_size: Number of generations per batch
154 | total_generations: Total number of responses to generate
155 | max_new_tokens: Maximum new tokens per generation
156 | intervention_type: Type of intervention to apply
157 | model_params: Parameters for the intervention
158 | scale: Scale factor for intervention
159 | config: Evaluation configuration
160 |
161 | Returns:
162 | List of generated responses with intervention applied
163 | """
164 | batched_prompts = [prompt] * batch_size
165 | num_batches = total_generations // batch_size
166 | generations = []
167 |
168 | module_and_hook_fn = wrapper.get_hook(intervention_type, model_params, scale, config)
169 |
170 | for _ in tqdm(range(num_batches), desc="Generating responses"):
171 | response = wrapper.generate(
172 | batched_prompts, max_new_tokens=max_new_tokens, module_and_hook_fn=module_and_hook_fn
173 | )
174 | generations.extend(response)
175 |
176 | return generations
177 |
178 |
179 | def test_single_prompt(
180 | wrapper: InterventionWrapper,
181 | base_prompt: str,
182 | code_example: str,
183 | config: EvalConfig,
184 | model_params: dict,
185 | intervention_type: str,
186 | api_key: str,
187 | ) -> dict:
188 | """Test a single prompt with and without interventions.
189 |
190 | Args:
191 | wrapper: Model wrapper instance
192 | base_prompt: Base prompt template
193 | code_example: Code example to insert in prompt
194 | config: Evaluation configuration
195 | model_params: Model-specific parameters
196 | intervention_type: Type of intervention to test
197 | api_key: API key for LLM judge (if used)
198 |
199 | Returns:
200 | Dictionary containing:
201 | - Original and intervention generations
202 | - Evaluation results
203 | - Agent evaluations (if enabled)
204 | - Result summaries
205 | """
206 | results = {"generations": {}, "eval_results": {}, "agent_evals": {}, "agent_summaries": {}}
207 |
208 | # Format prompt for this code example
209 | prompt = base_prompt.replace("{code}", code_example)
210 | formatted_prompt = utils.format_llm_prompt(prompt, wrapper.tokenizer)
211 | formatted_prompt += config.prefill
212 |
213 | # Generate without interventions
214 | original_texts = run_generation(
215 | wrapper,
216 | formatted_prompt,
217 | config.batch_size,
218 | config.total_generations,
219 | config.max_new_tokens,
220 | )
221 | original_texts = extract_response(original_texts, formatted_prompt, config.prefill)
222 | results["generations"]["original"] = original_texts
223 | results["eval_results"]["original"] = measure_regex_usage(
224 | original_texts, formatted_prompt, config.prefill
225 | )
226 | logging.info(f"Original eval results: {results['eval_results']['original']}")
227 | if config.use_llm_judge:
228 | # Add agent evaluations for original texts
229 | agent_evals = batch_evaluate_generations(original_texts, prompt, api_key)
230 | results["agent_evals"]["original"] = [asdict(eval) for eval in agent_evals]
231 | results["agent_summaries"]["original"] = summarize_evaluations(agent_evals)
232 |
233 | # Generate with different intervention scales
234 | for scale in tqdm(config.scales, desc="Interventions"):
235 | modified_texts = run_intervention(
236 | wrapper,
237 | formatted_prompt,
238 | config.batch_size,
239 | config.total_generations,
240 | config.max_new_tokens,
241 | intervention_type,
242 | model_params,
243 | scale,
244 | config,
245 | )
246 |
247 | modified_texts = extract_response(modified_texts, formatted_prompt, config.prefill)
248 | results["generations"][f"intervention_{scale}"] = modified_texts
249 | results["eval_results"][f"intervention_{scale}"] = measure_regex_usage(
250 | modified_texts, formatted_prompt, config.prefill
251 | )
252 | logging.info(f"Intervention {scale} eval results: {results['eval_results'][f'intervention_{scale}']}")
253 | if config.use_llm_judge:
254 | # Add agent evaluations for interventions
255 | agent_evals = batch_evaluate_generations(modified_texts, prompt, api_key)
256 | results["agent_evals"][f"intervention_{scale}"] = [asdict(eval) for eval in agent_evals]
257 | results["agent_summaries"][f"intervention_{scale}"] = summarize_evaluations(agent_evals)
258 |
259 | results["prompt"] = formatted_prompt
260 | return results
261 |
262 |
263 | def test_sae_interventions(api_key: str) -> dict:
264 | """Run comprehensive intervention tests across multiple code examples.
265 |
266 | Args:
267 | api_key: API key for LLM judge evaluations
268 |
269 | Returns:
270 | Dictionary containing all test results and configurations
271 |
272 | Note:
273 | Results are saved after each code block to prevent data loss
274 | """
275 | config = EvalConfig()
276 | results = {"config": asdict(config)}
277 |
278 | # Setup
279 | device = "cuda" if torch.cuda.is_available() else "cpu"
280 | model_params = utils.get_model_params(config.model_name)
281 |
282 | # Initialize wrapper
283 | wrapper = InterventionWrapper(model_name=config.model_name, device=device, dtype=torch.bfloat16)
284 |
285 | # Load SAE
286 | wrapper.load_sae(release=model_params["sae_release"], sae_id=model_params["sae_id"], layer_idx=model_params["targ_layer"])
287 |
288 | # Load and format prompt
289 | base_prompt = utils.load_prompt_files(config)
290 |
291 | with open(f"{config.prompt_folder}/{config.code_filename}", "r") as f:
292 | code_blocks = json.load(f)
293 |
294 | print(f"Evaluating {len(code_blocks)} code blocks")
295 |
296 | print(f"Evaluating the following interventions: {config.intervention_types}")
297 |
298 | for intervention_type in config.intervention_types:
299 | print(f"Evaluating {intervention_type}!")
300 | results[intervention_type] = {"code_results": {}}
301 | for code_block_key, single_code_block in code_blocks.items():
302 | results[intervention_type]["code_results"][code_block_key] = test_single_prompt(
303 | wrapper,
304 | base_prompt,
305 | single_code_block,
306 | config,
307 | model_params,
308 | intervention_type,
309 | api_key,
310 | )
311 |
312 | # Save results after each code block
313 | with open(config.save_path, "w") as f:
314 | json.dump(results, f, indent=4)
315 |
316 | return results
317 |
318 |
319 | if __name__ == "__main__":
320 | """
321 | Main entry point for running intervention tests.
322 |
323 | Environment variables:
324 | PYTORCH_CUDA_ALLOC_CONF: Set to "expandable_segments:True"
325 | TOKENIZERS_PARALLELISM: Set to "false" for process safety
326 | OPENAI_API_KEY: Optional API key for LLM judge
327 | """
328 | import os
329 |
330 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
331 |
332 | # We disable this because we launch additional processes when checking for valid code
333 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
334 |
335 | import argparse
336 |
337 | torch.set_grad_enabled(False)
338 |
339 | parser = argparse.ArgumentParser()
340 | parser.add_argument(
341 | "--api_key", type=str, help="OpenAI API key", default=os.environ.get("OPENAI_API_KEY")
342 | )
343 | args = parser.parse_args()
344 |
345 | start_time = time.time()
346 | run_results = test_sae_interventions(args.api_key)
347 | print(f"Total time: {time.time() - start_time:.2f} seconds")
348 |
349 | run_filename = "run_results.json"
350 | with open(run_filename, "w") as f:
351 | json.dump(run_results, f, indent=4)
352 |
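
As a follow-up, the sketch below shows one way the run_results.json written by this module could be summarized after a run. It is not part of the repository; the key layout mirrors the results dict built in test_single_prompt() ("original" and "intervention_<scale>" conditions), and the four counts are unpacked in the order documented in the measure_regex_usage docstring (valid Python, regex usage, syntactically valid, semantically valid).

```python
# Editorial sketch: summarize regex usage per intervention condition from run_results.json.
import json

with open("run_results.json") as f:
    results = json.load(f)

for intervention_type, data in results.items():
    if intervention_type == "config":          # the config dict is stored alongside results
        continue
    for code_id, code_results in data["code_results"].items():
        for condition, counts in code_results["eval_results"].items():
            valid_python, regex_used, syntactic, semantic = counts
            print(f"{intervention_type} / {code_id} / {condition}: "
                  f"{regex_used} of {valid_python} extracted snippets used regex")
```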
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import ast
3 | from typing import Optional, Tuple
4 | import builtins
5 | from contextlib import contextmanager
6 | import sys
7 | from io import StringIO
8 | from transformers import AutoTokenizer
9 | import random
10 | from multiprocessing import Process, Manager
11 |
12 |
13 | from src.eval_config import EvalConfig
14 |
15 |
16 | def get_model_params(model_name: str) -> dict[str, str | int]:
17 | """Get model-specific parameters for interventions.
18 |
19 | Args:
20 | model_name: Name/path of the model to get parameters for
21 |
22 | Returns:
23 | Dictionary containing:
24 | - sae_release: SAE model release identifier
25 | - sae_id: Specific SAE identifier (if applicable)
26 | - targ_layer: Target layer for intervention
27 | - feature_idx: Feature index in SAE
28 |
29 | Raises:
30 | ValueError: If model is not supported
31 | """
32 | if model_name == "google/gemma-2-9b-it":
33 | return {
34 | "sae_release": "gemma-scope-9b-it-res",
35 | "sae_id": "layer_9/width_16k/average_l0_88",
36 | "targ_layer": 9,
37 | "feature_idx": 3585,
38 | "secondary_feature_idx": 12650,
39 | }
40 | elif model_name == "meta-llama/Llama-3.1-8B-Instruct":
41 | return {
42 | "sae_release": "tilde-research/sieve_coding",
43 | "sae_id": None,
44 | "targ_layer": 12, # 8
45 | "feature_idx": 9853, # 9699
46 | }
47 | elif model_name == "google/gemma-2-2b-it":
48 | return {
49 | "sae_release": "gemma-scope-2b-pt-res",
50 | "sae_id": "layer_8/width_16k/average_l0_71",
51 | "targ_layer": 8,
52 | "feature_idx": 931,
53 | }
54 | else:
55 | raise ValueError(f"Unsupported model: {model_name}")
56 |
57 |
58 | def format_llm_prompt(prompt: str, tokenizer: AutoTokenizer) -> str:
59 | """Format the prompt according to Transformers instruction format."""
60 | chat = [
61 | {"role": "user", "content": prompt},
62 | ]
63 | return tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
64 |
65 |
66 | def load_prompt_files(config: EvalConfig) -> str:
67 | """Load and process prompt files."""
68 | files: dict[str, str] = {
69 | "prompt": f"{config.prompt_folder}/{config.prompt_filename}",
70 | "docs": f"{config.prompt_folder}/{config.docs_filename}",
71 | }
72 |
73 | content: dict[str, str] = {}
74 | for key, filename in files.items():
75 | with open(filename, "r") as f:
76 | content[key] = f.read()
77 |
78 | # Format the prompt
79 | prompt = content["prompt"].replace("{documentation}", content["docs"])
80 |
81 | return prompt
82 |
83 |
84 | def extract_python(response: str, verbose: bool = True) -> Optional[str]:
85 | # Regex pattern to match python block
86 | pattern = r"```python\s*(.*?)\s*```"
87 |
88 | # Search for the pattern
89 | match = re.search(pattern, response, re.DOTALL)
90 |
91 | if match:
92 | python_str = match.group(1)
93 | return python_str
94 | else:
95 | if verbose:
96 | print("WARNING: No python block found")
97 | return None
98 |
99 |
100 | def check_for_re_usage(code_snippet: str) -> bool:
101 | """
102 | Checks if any re module function is used in a given code snippet.
103 |
104 | This is pretty basic and may not catch all cases.
105 | """
106 | # Define a pattern that matches common re module functions
107 | pattern = r"\bre\.(match|search|sub|findall|finditer|split|compile|fullmatch|escape|subn)\b"
108 |
109 | # Search for any of these patterns in the code snippet
110 | return bool(re.search(pattern, code_snippet))
111 |
112 |
113 | def print_generations(generations: list[str], prompt: str, prefill: str) -> None:
114 | """Print the generated texts, removing the prompt prefix."""
115 | for i, generation in enumerate(generations):
116 | if prompt in generation:
117 | generation = generation[len(prompt) - len(prefill) :]
118 | print(f"Generation {i}:")
119 | print(generation)
120 | print()
121 |
122 |
123 | def is_syntactically_valid_python(code: str) -> Tuple[bool, Optional[str]]:
124 | """
125 | Check if a string contains syntactically valid Python code.
126 |
127 | Args:
128 | code: String containing Python code to validate
129 |
130 | Returns:
131 | Tuple of (is_valid: bool, error_message: Optional[str])
132 | """
133 | try:
134 | ast.parse(code)
135 | return True, None
136 | except SyntaxError as e:
137 | return False, f"Syntax error: {str(e)}"
138 | except Exception as e:
139 | return False, f"Parsing error: {str(e)}"
140 |
141 |
142 | @contextmanager
143 | def restricted_compile_environment():
144 | """
145 | Context manager that provides a restricted environment for code compilation.
146 | Temporarily replaces stdout/stderr and restricts builtins to common safe operations.
147 | """
148 | # Save original stdout/stderr and builtins
149 | original_stdout = sys.stdout
150 | original_stderr = sys.stderr
151 | original_builtins = dict(builtins.__dict__)
152 |
153 | # Create string buffers for capturing output
154 | temp_stdout = StringIO()
155 | temp_stderr = StringIO()
156 |
157 | # Define safe exception types
158 | safe_exceptions = {
159 | "Exception": Exception,
160 | "ValueError": ValueError,
161 | "TypeError": TypeError,
162 | "AttributeError": AttributeError,
163 | "IndexError": IndexError,
164 | "KeyError": KeyError,
165 | "RuntimeError": RuntimeError,
166 | "StopIteration": StopIteration,
167 | "AssertionError": AssertionError,
168 | "NotImplementedError": NotImplementedError,
169 | "ZeroDivisionError": ZeroDivisionError,
170 | }
171 |
172 | # Expanded set of safe builtins
173 | safe_builtins = {
174 | # Constants
175 | "None": None,
176 | "False": False,
177 | "True": True,
178 | # Basic types and operations
179 | "abs": abs,
180 | "bool": bool,
181 | "int": int,
182 | "float": float,
183 | "str": str,
184 | "len": len,
185 | "type": type,
186 | "repr": repr,
187 | # Collections
188 | "list": list,
189 | "dict": dict,
190 | "set": set,
191 | "tuple": tuple,
192 | "frozenset": frozenset,
193 | "range": range,
194 | "enumerate": enumerate,
195 | "zip": zip,
196 | "reversed": reversed,
197 | # Type checking
198 | "isinstance": isinstance,
199 | "issubclass": issubclass,
200 | "hasattr": hasattr,
201 | "getattr": getattr,
202 | # Math operations
203 | "min": min,
204 | "max": max,
205 | "sum": sum,
206 | "round": round,
207 | "pow": pow,
208 | # String operations
209 | "chr": chr,
210 | "ord": ord,
211 | "format": format,
212 | # Itertools functions
213 | "filter": filter,
214 | "map": map,
215 | # Other safe operations
216 | "print": print, # Captured by StringIO
217 | "sorted": sorted,
218 | "any": any,
219 | "all": all,
220 | "iter": iter,
221 | "next": next,
222 | "slice": slice,
223 | "property": property,
224 | "staticmethod": staticmethod,
225 | "classmethod": classmethod,
226 | # Exception handling
227 | "try": "try",
228 | "except": "except",
229 | "finally": "finally",
230 | **safe_exceptions, # Add all safe exception types
231 | }
232 |
233 | try:
234 | # Replace stdout/stderr
235 | sys.stdout = temp_stdout
236 | sys.stderr = temp_stderr
237 |
238 | # Restrict builtins
239 | for key in list(builtins.__dict__.keys()):
240 | if key not in safe_builtins:
241 | del builtins.__dict__[key]
242 |
243 | # Add exception types to the builtins
244 | builtins.__dict__.update(safe_exceptions)
245 |
246 | yield temp_stdout, temp_stderr
247 |
248 | finally:
249 | # Restore original environment
250 | sys.stdout = original_stdout
251 | sys.stderr = original_stderr
252 | builtins.__dict__.clear()
253 | builtins.__dict__.update(original_builtins)
254 |
255 |
256 | def is_semantically_valid_python(code: str) -> Tuple[bool, Optional[str]]:
257 | """
258 | Check if a string contains semantically valid Python code by:
259 | 1. Checking syntax
260 | 2. Verifying it contains actual code structure
261 | 3. Attempting to compile and validate basic execution
262 |
263 | Args:
264 | code: String containing Python code to validate
265 |
266 | Returns:
267 | Tuple of (is_valid: bool, error_message: Optional[str])
268 | """
269 | # First check syntax
270 | syntax_valid, syntax_error = is_syntactically_valid_python(code)
271 | if not syntax_valid:
272 | return False, syntax_error
273 |
274 | # Basic content validation
275 | code = code.strip()
276 | if not code:
277 | return False, "Empty code string"
278 |
279 | # Check for basic code structure (must have at least one function or class definition)
280 | if not any(keyword in code for keyword in ["def ", "class "]):
281 | return False, "No function or class definitions found"
282 |
283 | # Check for excessive non-ASCII characters that aren't in strings/comments
284 | code_lines = code.split("\n")
285 | invalid_lines = 0
286 | total_lines = len(code_lines)
287 |
288 | for line in code_lines:
289 | line = line.strip()
290 | # Skip empty lines and comments
291 | if not line or line.startswith("#"):
292 | total_lines -= 1
293 | continue
294 |
295 | # Count lines with too many non-ASCII characters
296 | non_ascii_count = len([c for c in line if ord(c) > 127])
297 | if non_ascii_count / len(line) > 0.3: # More than 30% non-ASCII
298 | invalid_lines += 1
299 |
300 | # If more than 20% of non-empty, non-comment lines are invalid
301 | if total_lines > 0 and (invalid_lines / total_lines) > 0.2:
302 | return False, "Code contains too many non-ASCII characters"
303 |
304 | try:
305 | # Try to compile the code
306 | try:
307 | compiled_code = compile(code, "<string>", "exec")
308 | except Exception as e:
309 | return False, f"Compilation error: {str(e)}"
310 |
311 | # Create a restricted globals dict with common built-ins
312 | # restricted_globals = {
313 | # '__builtins__': {
314 | # name: getattr(builtins, name)
315 | # for name in [
316 | # 'len', 'int', 'str', 'list', 'dict', 'set', 'tuple',
317 | # 'min', 'max', 'True', 'False', 'None', 'type',
318 | # 'isinstance', 'print', 'range', 'compile', 'exec',
319 | # "import",
320 | # ]
321 | # }
322 | # }
323 |
324 | # Try to execute in the restricted environment
325 | try:
326 | exec(compiled_code, {}, {})
327 | except Exception as e:
328 | # Some errors are acceptable for valid code
329 | error_str = str(e)
330 | acceptable_errors = [
331 | "name 'pytest' is not defined",
332 | "name 're' is not defined",
333 | "name 'random' is not defined",
334 | "name 'time' is not defined",
335 | "name 'asyncio' is not defined",
336 | "name 'typing' is not defined",
337 | "name 'Optional' is not defined",
338 | "name 'List' is not defined",
339 | "name 'Dict' is not defined",
340 | "name 'Any' is not defined",
341 | "name 'Union' is not defined",
342 | ]
343 | if not any(err in error_str for err in acceptable_errors):
344 | return False, f"Execution error: {error_str}"
345 |
346 | return True, None
347 |
348 | except Exception as e:
349 | return False, f"Validation error: {str(e)}"
350 |
351 |
352 | def get_func_name(func_code: str) -> Optional[str]:
353 | match = re.search(r"def\s+([a-zA-Z_][a-zA-Z0-9_]*)", func_code)
354 | if match:
355 | func_name = match.group(1)
356 | else:
357 | return None
358 | return func_name
359 |
360 |
361 | def validate_llm_response(
362 | func_code: str, llm_code: str, timeout: float = 3.0, verbose: bool = False
363 | ) -> bool:
364 | """
365 | Validates whether the LLM-generated code runs, calls the specified function,
366 | and doesn't raise errors within the given timeout.
367 | """
368 |
369 | func_name = get_func_name(func_code)
370 |
371 | def exec_code(function_called):
372 | # List of dangerous modules to block
373 | dangerous_modules = {
374 | "os",
375 | "subprocess",
376 | "sys",
377 | "socket",
378 | "requests",
379 | "urllib",
380 | "ftplib",
381 | "telnetlib",
382 | "smtplib",
383 | "pathlib",
384 | "shutil",
385 | }
386 |
387 | def safe_import(name, *args, **kwargs):
388 | if name in dangerous_modules:
389 | raise ImportError(f"Import of {name} is not allowed for security reasons")
390 | return __import__(name, *args, **kwargs)
391 |
392 | # Create safe globals with all built-ins
393 | safe_globals = {}
394 | for name in dir(builtins):
395 | safe_globals[name] = getattr(builtins, name)
396 |
397 | safe_globals["__import__"] = safe_import
398 | safe_globals["__builtins__"] = safe_globals
399 |
400 | # Add commonly needed modules
401 | safe_globals.update(
402 | {
403 | "re": re,
404 | "random": random,
405 | "__name__": "__main__",
406 | }
407 | )
408 |
409 | # Execute the function definition
410 | try:
411 | exec(func_code, safe_globals)
412 | if func_name not in safe_globals:
413 | function_called["error"] = f"Function {func_name} was not properly defined."
414 | return
415 | except Exception as e:
416 | function_called["error"] = f"Error in function definition: {str(e)}"
417 | return
418 |
419 | # Store the original function and create wrapper
420 | original_func = safe_globals[func_name]
421 |
422 | def wrapper(*args, **kwargs):
423 | function_called["called"] = True
424 | return original_func(*args, **kwargs)
425 |
426 | safe_globals[func_name] = wrapper
427 |
428 | # Execute the test code
429 | try:
430 | exec(llm_code, safe_globals)
431 | except Exception as e:
432 | function_called["error"] = f"Error in test execution: {str(e)}"
433 | return
434 |
435 | # Shared dictionary for results
436 | manager = Manager()
437 | function_called = manager.dict()
438 | function_called["called"] = False
439 |
440 | # Run in separate process with timeout
441 | p = Process(target=exec_code, args=(function_called,))
442 | p.start()
443 | p.join(timeout)
444 |
445 | if p.is_alive():
446 | p.terminate()
447 | p.join()
448 | print("Execution timed out.")
449 | else:
450 | if "error" in function_called:
451 | if verbose:
452 | print(f"An error occurred: {function_called['error']}")
453 | elif not function_called["called"]:
454 | if verbose:
455 | print(f"{func_name}() was not called.")
456 | else:
457 | if verbose:
458 | print(f"{func_name}() was successfully called.")
459 | return True
460 | return False
461 |
462 |
463 | def validate_single_llm_response(prompt: str, response: str, verbose: bool = True) -> bool:
464 | # Regex pattern to match python block
465 | pattern = r"```python\s*(.*?)\s*```"
466 |
467 | # Search for the pattern
468 | matches = re.findall(pattern, prompt, re.DOTALL)
469 |
470 | original_code = matches[-1]
471 |
472 | llm_python = extract_python(response, verbose=False)
473 |
474 | if llm_python is None:
475 | return False
476 |
477 | func_name = get_func_name(llm_python)
478 |
479 | if func_name is None:
480 | return False
481 |
482 | llm_python = llm_python + f"\n\n{func_name}()"
483 | valid_code = validate_llm_response(original_code, llm_python)
484 |
485 | return valid_code
486 |
487 |
488 | def validate_all_llm_responses(
489 | data: dict, intervention_method: str, code_id: str, scale: int
490 | ) -> tuple[float, float]:
491 | prompt = data[intervention_method]["code_results"][code_id]["prompt"]
492 |
493 | # Regex pattern to match python block
494 | pattern = r"```python\s*(.*?)\s*```"
495 |
496 | # Search for the pattern
497 | matches = re.findall(pattern, prompt, re.DOTALL)
498 |
499 | original_code = matches[-1]
500 |
501 | total = 0
502 | valid = 0
503 | syntactically_valid = 0
504 |
505 | for response in data[intervention_method]["code_results"][code_id]["generations"][scale]:
506 | total += 1
507 |
508 | llm_python = extract_python(response, verbose=False)
509 |
510 | if llm_python is None:
511 | continue
512 |
513 | func_name = get_func_name(llm_python)
514 |
515 | if func_name is None:
516 | continue
517 |
518 | syntactically_valid += is_syntactically_valid_python(llm_python)[0]
519 |
520 | llm_python = llm_python + f"\n\n{func_name}()"
521 | valid_code = validate_llm_response(original_code, llm_python)
522 |
523 | if valid_code:
524 | valid += 1
525 |
526 | return (valid / total), (syntactically_valid / total)
527 |
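
A short usage sketch (not part of the repository) of how the helpers above combine when scoring a single model response. The response text is fabricated, and the import path assumes the repository root is on sys.path, as regex_interventions.py arranges.

```python
# Editorial sketch: score one made-up response with the validators defined above.
from src.utils import extract_python, is_syntactically_valid_python, check_for_re_usage

fence = "```"
response = f"Here is a test:\n{fence}python\ndef test_add():\n    assert 1 + 1 == 2\n{fence}\n"

snippet = extract_python(response)                     # pulls out the fenced python block
if snippet is not None:
    ok, err = is_syntactically_valid_python(snippet)   # AST-based syntax check
    uses_re = check_for_re_usage(snippet)              # heuristic re.<func> detection
    print(f"syntactically valid: {ok}, uses regex: {uses_re}")
```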
--------------------------------------------------------------------------------
/src/wrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import einops
3 | from transformers import AutoModelForCausalLM, AutoTokenizer
4 | from sae_lens import SAE
5 | from typing import Callable, Optional, Union, List, cast
6 | from jaxtyping import Float
7 | from torch import Tensor
8 | import contextlib
9 |
10 | from src.eval_config import EvalConfig, InterventionType
11 | import src.caa as caa
12 | from sae.sae import Sae
13 |
14 | try:
15 | import flash_attn
16 | USE_FA = True
17 | print("Flash attention installed")
18 | except ImportError:
19 | print("Flash attention not installed, using regular attention")
20 | USE_FA = False
21 |
22 |
23 | @contextlib.contextmanager
24 | def add_hook(
25 | module: torch.nn.Module,
26 | hook: Callable,
27 | ):
28 | """Temporarily adds a forward hook to a model module.
29 |
30 | Args:
31 | module: The PyTorch module to hook
32 | hook: The hook function to apply
33 |
34 | Yields:
35 | None: Used as a context manager
36 |
37 | Example:
38 | with add_hook(model.layer, hook_fn):
39 | output = model(input)
40 | """
41 | handle = module.register_forward_hook(hook)
42 | try:
43 | yield
44 | finally:
45 | handle.remove()
46 |
47 |
48 | def get_activation_addition_output_hook(
49 | vectors: list[Float[Tensor, "d_model"]], coeffs: list[float]
50 | ) -> Callable:
51 | """Creates a hook function that adds scaled vectors to layer activations.
52 |
53 | This hook performs simple activation steering by adding scaled vectors
54 | to the layer's output activations. This is the most basic form of intervention.
55 |
56 | Args:
57 | vectors: List of vectors to add, each of shape (d_model,)
58 | coeffs: List of scaling coefficients for each vector
59 |
60 | Returns:
61 | Hook function that modifies layer activations
62 |
63 | """
64 |
65 | def hook_fn(module, input, output):
66 | if isinstance(output, tuple):
67 | resid_BLD = output[0]
68 | rest = output[1:]
69 | else:
70 | resid_BLD = output
71 | rest = ()
72 |
73 | for vector, coeff in zip(vectors, coeffs):
74 | vector = vector.to(resid_BLD.device)
75 | resid_BLD = resid_BLD + coeff * vector
76 |
77 | if rest:
78 | return (resid_BLD, *rest)
79 | else:
80 | return resid_BLD
81 |
82 | return hook_fn
83 |
84 |
85 | def get_conditional_per_input_hook(
86 | encoder_vectors: list[Float[Tensor, "d_model"]],
87 | decoder_vectors: list[Float[Tensor, "d_model"]],
88 | scales: list[float],
89 | encoder_thresholds: list[float],
90 | ) -> Callable:
91 | """Creates a hook function that conditionally applies interventions based on input-level activation.
92 |
93 | This hook checks if any token in the input sequence triggers the encoder vector
94 | above threshold. If triggered, applies the intervention to the entire sequence.
95 |
96 | Args:
97 | encoder_vectors: List of vectors used to detect activation patterns
98 | decoder_vectors: List of vectors to add when conditions are met
99 | scales: Scaling factors for decoder vectors
100 | encoder_thresholds: Threshold values for each encoder vector
101 |
102 | Returns:
103 | Hook function that conditionally modifies activations
104 |
105 | Note:
106 | - Zeros out BOS token activations to prevent false triggers
107 | - Intervention applies to entire sequence if any token triggers
108 | """
109 |
110 | def hook_fn(module, input, output):
111 | if isinstance(output, tuple):
112 | resid_BLD: Float[Tensor, "batch_size seq_len d_model"] = output[0]
113 | rest = output[1:]
114 | else:
115 | resid_BLD: Float[Tensor, "batch_size seq_len d_model"] = output
116 | rest = ()
117 |
118 | B, L, D = resid_BLD.shape
119 |
120 | for encoder_vector_D, decoder_vector_D, coeff, encoder_threshold in zip(
121 | encoder_vectors, decoder_vectors, scales, encoder_thresholds
122 | ):
123 | encoder_vector_D = encoder_vector_D.to(resid_BLD.device)
124 | decoder_vector_D = decoder_vector_D.to(resid_BLD.device)
125 |
126 | feature_acts_BL = torch.einsum("BLD,D->BL", resid_BLD, encoder_vector_D)
127 | feature_acts_BL[:, 0] = 0 # zero out the BOS token
128 | intervention_threshold_B11 = ((feature_acts_BL > encoder_threshold).any(dim=1).float())[
129 | :, None, None
130 | ]
131 | decoder_BLD = einops.repeat(decoder_vector_D * coeff, "D -> B L D", B=B, L=L).to(
132 | dtype=resid_BLD.dtype
133 | )
134 |
135 | resid_BLD += decoder_BLD * intervention_threshold_B11
136 |
137 | if rest:
138 | return (resid_BLD, *rest)
139 | else:
140 | return resid_BLD
141 |
142 | return hook_fn
143 |
144 |
145 | def get_conditional_per_token_hook(
146 | encoder_vectors: list[Float[Tensor, "d_model"]],
147 | decoder_vectors: list[Float[Tensor, "d_model"]],
148 | scales: list[float],
149 | encoder_thresholds: list[float],
150 | ) -> Callable:
151 | """Creates a hook function that conditionally applies interventions per token.
152 |
153 | Unlike the per-input hook, this applies interventions independently to each token
154 | based on whether it exceeds the encoder threshold.
155 |
156 | Args:
157 | encoder_vectors: List of vectors used to detect activation patterns
158 | decoder_vectors: List of vectors to add when conditions are met
159 | scales: Scaling factors for decoder vectors
160 | encoder_thresholds: Threshold values for each encoder vector
161 |
162 | Returns:
163 | Hook function that modifies activations on a per-token basis
164 |
165 | Note:
166 | More granular than per-input hook as it can selectively modify
167 | specific tokens in the sequence
168 | """
169 |
170 | def hook_fn(module, input, output):
171 | if isinstance(output, tuple):
172 | resid_BLD: Float[Tensor, "batch_size seq_len d_model"] = output[0]
173 | rest = output[1:]
174 | else:
175 | resid_BLD: Float[Tensor, "batch_size seq_len d_model"] = output
176 | rest = ()
177 |
178 | B, L, D = resid_BLD.shape
179 |
180 | for encoder_vector_D, decoder_vector_D, coeff, encoder_threshold in zip(
181 | encoder_vectors, decoder_vectors, scales, encoder_thresholds
182 | ):
183 | encoder_vector_D = encoder_vector_D.to(resid_BLD.device)
184 | decoder_vector_D = decoder_vector_D.to(resid_BLD.device)
185 |
186 | feature_acts_BL = torch.einsum("BLD,D->BL", resid_BLD, encoder_vector_D)
187 | intervention_mask_BL = feature_acts_BL > encoder_threshold
188 | decoder_BLD = einops.repeat(decoder_vector_D * coeff, "D -> B L D", B=B, L=L).to(
189 | dtype=resid_BLD.dtype
190 | )
191 |
192 | resid_BLD = torch.where(
193 | intervention_mask_BL.unsqueeze(-1),
194 | resid_BLD + decoder_BLD,
195 | resid_BLD,
196 | )
197 |
198 | if rest:
199 | return (resid_BLD, *rest)
200 | else:
201 | return resid_BLD
202 |
203 | return hook_fn
204 |
205 |
206 | def get_clamping_hook(
207 | encoder_vectors: list[Float[Tensor, "d_model"]],
208 | decoder_vectors: list[Float[Tensor, "d_model"]],
209 | scales: list[float],
210 | ) -> Callable:
211 | """Creates a hook function that clamps activations using decoder vectors.
212 |
213 | This hook fixes the activations to a target value in the decoder vector direction.
214 |
215 | Args:
216 | encoder_vectors: List of vectors defining directions to clamp
217 | decoder_vectors: List of vectors defining intervention directions
218 | scales: Target values for clamping (acts as offset after zeroing)
219 |
220 | Returns:
221 | Hook function that clamps and redirects activations
222 |
223 | Note:
224 | Useful for pinning the activation to a fixed target value along the decoder vector direction.
225 | """
226 |
227 | # coeff = -feature_acts_BL
228 | def hook_fn(module, input, output):
229 | if isinstance(output, tuple):
230 | resid_BLD: Float[Tensor, "batch_size seq_len d_model"] = output[0]
231 | rest = output[1:]
232 | else:
233 | resid_BLD: Float[Tensor, "batch_size seq_len d_model"] = output
234 | rest = ()
235 | B, L, D = resid_BLD.shape
236 |
237 | for encoder_vector_D, decoder_vector_D, coeff in zip(
238 | encoder_vectors, decoder_vectors, scales
239 | ):
240 | encoder_vector_D = encoder_vector_D.to(resid_BLD.device)
241 | decoder_vector_D = decoder_vector_D.to(resid_BLD.device)
242 | feature_acts_BL = torch.einsum("BLD,D->BL", resid_BLD, encoder_vector_D)
243 | decoder_BLD = (-feature_acts_BL[:, :, None] + coeff) * decoder_vector_D[None, None, :]
244 | resid_BLD = torch.where(
245 | feature_acts_BL[:, :, None] > 0,
246 | resid_BLD + decoder_BLD,
247 | resid_BLD,
248 | )
249 |
250 | if rest:
251 | return (resid_BLD, *rest)
252 | else:
253 | return resid_BLD
254 |
255 | return hook_fn
256 |
257 |
258 | def get_conditional_clamping_hook(
259 | encoder_vectors: list[Float[Tensor, "d_model"]],
260 | decoder_vectors: list[Float[Tensor, "d_model"]],
261 | scales: list[float],
262 | encoder_thresholds: list[float],
263 | ) -> Callable:
264 | """Creates a hook function that conditionally clamps activations.
265 |
266 | Combines conditional intervention with clamping - the decoder-direction clamp is
267 | applied only where the encoder activation exceeds the threshold.
268 |
269 | Args:
270 | encoder_vectors: List of vectors defining directions to monitor
271 | decoder_vectors: List of vectors defining intervention directions
272 | scales: Target values for clamping
273 | encoder_thresholds: Threshold values that trigger clamping
274 |
275 | Returns:
276 | Hook function that conditionally clamps and modifies activations
277 |
278 | Note:
279 | Most sophisticated intervention type, combining benefits of
280 | conditional application and activation clamping
281 | """
282 |
283 | def hook_fn(module, input, output):
284 | if isinstance(output, tuple):
285 | resid_BLD: Float[Tensor, "batch_size seq_len d_model"] = output[0]
286 | rest = output[1:]
287 | else:
288 | resid_BLD: Float[Tensor, "batch_size seq_len d_model"] = output
289 | rest = ()
290 |
291 | B, L, D = resid_BLD.shape
292 |
293 | for encoder_vector_D, decoder_vector_D, coeff, encoder_threshold in zip(
294 | encoder_vectors, decoder_vectors, scales, encoder_thresholds
295 | ):
296 | encoder_vector_D = encoder_vector_D.to(resid_BLD.device)
297 | decoder_vector_D = decoder_vector_D.to(resid_BLD.device)
298 |
299 | # Get encoder activations
300 | feature_acts_BL = torch.einsum("BLD,D->BL", resid_BLD, encoder_vector_D)
301 |
302 | # Create mask for where encoder activation exceeds threshold
303 | intervention_mask_BL = feature_acts_BL > encoder_threshold
304 |
305 | # Calculate clamping amount only where mask is True
306 | decoder_BLD = (-feature_acts_BL[:, :, None] + coeff) * decoder_vector_D[None, None, :]
307 |
308 | # Apply clamping only where both mask is True and activation is positive
309 | resid_BLD = torch.where(
310 | (intervention_mask_BL[:, :, None] & (feature_acts_BL[:, :, None] > 0)),
311 | resid_BLD + decoder_BLD,
312 | resid_BLD,
313 | )
314 |
315 | if rest:
316 | return (resid_BLD, *rest)
317 | else:
318 | return resid_BLD
319 |
320 | return hook_fn
321 |
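# Editorial summary of the hook variants above: get_conditional_per_token_hook adds
# coeff * decoder_vector only on tokens whose encoder activation exceeds the threshold;
# get_clamping_hook moves the activation so the feature reads the target value wherever
# it is positive; get_conditional_clamping_hook does the same but only on tokens above
# the encoder threshold. The constant-addition and per-input variants are defined
# earlier in this file.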
322 |
323 | class InterventionWrapper:
324 | """Wrapper class for applying interventions to language models.
325 |
326 | This class manages model loading, intervention application, and generation
327 | with various steering techniques.
328 |
329 | Attributes:
330 | model: The underlying language model
331 | tokenizer: Tokenizer for the model
332 | sae: Optional Sparse Autoencoder for intervention
333 | device: Device to run computations on
334 | caa_steering_vector: Cached steering vector for interventions
335 | probe_vector: Cached probe vector for guided steering
336 | """
337 |
338 | def __init__(self, model_name: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16):
339 | """Initialize the wrapper with a specified model.
340 |
341 | Args:
342 | model_name: HuggingFace model identifier
343 | device: Computing device ('cuda' or 'cpu')
344 | dtype: Data type for model weights
345 | """
346 |
347 | self.model = AutoModelForCausalLM.from_pretrained(
348 | model_name,
349 |             device_map="auto",  # Let Accelerate place the weights on available devices
350 | torch_dtype=dtype,
351 | attn_implementation="flash_attention_2" if USE_FA else "eager",
352 | )
353 | self.tokenizer = AutoTokenizer.from_pretrained(model_name)
354 | self.model_name = model_name
355 | self.interventions = {}
356 | self.sae: Optional[Union[SAE, Sae]] = None
357 | self.device = device
358 | self.caa_steering_vector: Optional[Tensor] = None
359 | self.probe_vector: Optional[Tensor] = None
360 | self.probe_bias: Optional[Tensor] = None
361 |
362 | def generate(
363 | self,
364 | batched_prompts: list[str],
365 | max_new_tokens: int,
366 | temperature: float = 1.0,
367 | module_and_hook_fn: Optional[tuple[torch.nn.Module, Callable]] = None,
368 | ) -> list[str]:
369 | """Generate text with optional interventions.
370 |
371 | Args:
372 | batched_prompts: List of input prompts
373 | max_new_tokens: Maximum number of tokens to generate
374 | temperature: Sampling temperature
375 | module_and_hook_fn: Optional tuple of (module, hook_function) for intervention
376 |
377 | Returns:
378 | List of generated text strings
379 |
380 | Note:
381 | Prompts must contain the model's BOS token
382 | """
383 | assert all(
384 | self.tokenizer.bos_token in prompt for prompt in batched_prompts
385 | ), "All prompts must contain the BOS token."
386 | batched_tokens = self.tokenizer(
387 | batched_prompts, add_special_tokens=False, return_tensors="pt"
388 | ).to(self.device)
389 | batched_tokens = batched_tokens["input_ids"]
390 |
391 | if module_and_hook_fn:
392 | module, hook_fn = module_and_hook_fn
393 | context_manager = add_hook(
394 | module=module,
395 | hook=hook_fn,
396 | )
397 | with context_manager:
398 | generated_toks = self.model.generate(
399 | batched_tokens,
400 | max_new_tokens=max_new_tokens,
401 | do_sample=True,
402 | temperature=temperature,
403 | )
404 | else:
405 | generated_toks = self.model.generate(
406 | batched_tokens,
407 | max_new_tokens=max_new_tokens,
408 | do_sample=True,
409 | temperature=temperature,
410 | )
411 |
412 | response = [self.tokenizer.decode(tokens) for tokens in generated_toks]
413 |
414 | return response
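
    # Editorial note: generate() tokenizes with add_special_tokens=False, so callers
    # must prepend the BOS token themselves, e.g.
    #   prompt = wrapper.tokenizer.bos_token + user_question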
415 |
416 | def get_hook(
417 | self,
418 | intervention_type: str,
419 | model_params: dict,
420 | scale: Union[int, float],
421 | config: EvalConfig,
422 | ) -> tuple[torch.nn.Module, Callable]:
423 | """Create a hook function for the specified intervention type.
424 |
425 | Args:
426 | intervention_type: Type of intervention to apply
427 | model_params: Parameters for the intervention including:
428 | - targ_layer: Target layer index
429 | - feature_idx: Feature index for SAE
430 | scale: Scaling factor for the intervention
431 | config: Configuration for evaluation and steering
432 |
433 | Returns:
434 | Tuple of (target_module, hook_function)
435 |
436 | Raises:
437 | AttributeError: If SAE is required but not loaded
438 | ValueError: If intervention type is not supported
439 |
440 | Note:
441 | Different intervention types require different preconditions:
442 | - SAE interventions require loaded SAE
443 | - Steering vector interventions calculate vectors on first use
444 | - Probe interventions train probes on first use
445 | """
446 | module = self.model.model.layers[model_params["targ_layer"]]
447 |
448 | # Convert scale to float for type compatibility
449 | scales: List[float] = [float(scale)]
450 |
451 | if self.sae is None:
452 | raise AttributeError("SAE must be loaded before getting hook")
453 |
454 | # Get encoder/decoder vectors with proper null checks
455 | encoder_vectors = [cast(Tensor, self.sae.W_enc[:, [model_params["feature_idx"]]].squeeze())]
456 | decoder_vectors = [cast(Tensor, self.sae.W_dec[[model_params["feature_idx"]]].squeeze())]
457 |
458 | # Normalize decoder vectors
459 | decoder_vectors = [v / v.norm() for v in decoder_vectors]
460 |         encoder_thresholds = [2.0] if "gemma" in self.model_name else [5.0]  # TODO: make this configurable; Llama-3.1-8B features need a higher threshold
461 |         print(f"Encoder thresholds: {encoder_thresholds}")
462 |         if hasattr(self.sae, "threshold"):  # JumpReLU SAEs expose a learned per-feature threshold
463 | threshold_val = float(self.sae.threshold[model_params["feature_idx"]])
464 | encoder_thresholds = [threshold_val]
465 |
466 | # for i, threshold in enumerate(encoder_thresholds):
467 | # encoder_thresholds[i] = threshold + config.encoder_threshold_bias
468 |
469 | # Initialize steering vectors if needed
470 | steering_vectors: List[Tensor] = []
471 | if intervention_type in [
472 | InterventionType.CONSTANT_STEERING_VECTOR.value,
473 | InterventionType.CONDITIONAL_STEERING_VECTOR.value,
474 | InterventionType.SAE_STEERING_VECTOR.value,
475 | InterventionType.PROBE_STEERING_VECTOR.value,
476 | InterventionType.PROBE_STEERING_VECTOR_CLAMPING.value,
477 | ]:
478 | if self.caa_steering_vector is None:
479 | self.caa_steering_vector = caa.calculate_steering_vector(
480 | config.prompt_folder,
481 | config.contrastive_prompts_filename,
482 | config.code_filename,
483 | self.model,
484 | self.tokenizer,
485 | config.prompt_type,
486 | model_params["targ_layer"],
487 | )
488 |
489 | # TODO: Support multiple steering vectors
490 | steering_vector_threshold = (
491 | caa.get_threshold(
492 | config,
493 | model_params,
494 | self,
495 | self.caa_steering_vector,
496 | encoder_vectors[0],
497 | encoder_thresholds[0],
498 | )
499 | + config.steering_vector_threshold_bias
500 | )
501 |
502 | steering_vector_thresholds = [steering_vector_threshold]
503 | encoder_steering_vectors = [self.caa_steering_vector]
504 | steering_vectors = [self.caa_steering_vector]
505 | steering_vectors = [v / v.norm() for v in steering_vectors]
506 |
507 | if intervention_type in [
508 | InterventionType.PROBE_SAE.value,
509 | InterventionType.PROBE_SAE_CLAMPING.value,
510 | InterventionType.PROBE_STEERING_VECTOR.value,
511 | InterventionType.PROBE_STEERING_VECTOR_CLAMPING.value,
512 | ]:
513 | if self.probe_vector is None:
514 | self.probe_vector, self.probe_bias = caa.calculate_probe_vector(
515 | config.prompt_folder,
516 | config.probe_prompts_filename,
517 | config.code_filename,
518 | self.model,
519 | self.tokenizer,
520 | config.prompt_type,
521 | model_params["targ_layer"],
522 | )
523 | probe_vector_threshold = [self.probe_bias]
524 |
525 | if intervention_type == InterventionType.CONSTANT_SAE.value:
526 | hook_fn = get_activation_addition_output_hook(decoder_vectors, scales)
527 | elif intervention_type == InterventionType.CONSTANT_STEERING_VECTOR.value:
528 | hook_fn = get_activation_addition_output_hook(steering_vectors, scales)
529 | elif intervention_type == InterventionType.PROBE_STEERING_VECTOR.value:
530 | hook_fn = get_conditional_per_token_hook(
531 | [self.probe_vector], steering_vectors, scales, probe_vector_threshold
532 | )
533 | elif intervention_type == InterventionType.PROBE_SAE.value:
534 | hook_fn = get_conditional_per_token_hook(
535 | [self.probe_vector], decoder_vectors, scales, probe_vector_threshold
536 | )
537 | elif intervention_type == InterventionType.PROBE_SAE_CLAMPING.value:
538 | hook_fn = get_conditional_clamping_hook(
539 | [self.probe_vector], decoder_vectors, scales, probe_vector_threshold
540 | )
541 | elif intervention_type == InterventionType.PROBE_STEERING_VECTOR_CLAMPING.value:
542 | hook_fn = get_conditional_clamping_hook(
543 | [self.probe_vector], steering_vectors, scales, probe_vector_threshold
544 | )
545 | elif intervention_type == InterventionType.SAE_STEERING_VECTOR.value:
546 | hook_fn = get_conditional_per_token_hook(
547 | encoder_vectors, steering_vectors, scales, encoder_thresholds
548 | )
549 | elif intervention_type == InterventionType.CONDITIONAL_PER_INPUT.value:
550 | hook_fn = get_conditional_per_input_hook(
551 | encoder_vectors, decoder_vectors, scales, encoder_thresholds
552 | )
553 | elif intervention_type == InterventionType.CONDITIONAL_PER_TOKEN.value:
554 | hook_fn = get_conditional_per_token_hook(
555 | encoder_vectors, decoder_vectors, scales, encoder_thresholds
556 | )
557 | elif intervention_type == InterventionType.CONDITIONAL_STEERING_VECTOR.value:
558 | hook_fn = get_conditional_per_token_hook(
559 | encoder_steering_vectors, steering_vectors, scales, steering_vector_thresholds
560 | )
561 | elif intervention_type == InterventionType.CLAMPING.value:
562 | hook_fn = get_clamping_hook(encoder_vectors, decoder_vectors, scales)
563 | elif intervention_type == InterventionType.CONDITIONAL_CLAMPING.value:
564 | hook_fn = get_conditional_clamping_hook(
565 | encoder_vectors, decoder_vectors, scales, encoder_thresholds
566 | )
567 | else:
568 | raise ValueError(f"Unsupported intervention type: {intervention_type}")
569 | return module, hook_fn
570 |
571 | def load_sae(self, release: str, sae_id: str, layer_idx: int):
572 | """Load a Sparse Autoencoder for interventions.
573 |
574 | Args:
575 | release: Release identifier for the SAE
576 | sae_id: Specific SAE identifier
577 | layer_idx: Layer index the SAE was trained on
578 |
579 | Note:
580 | Supports both tilde-research and standard SAE formats
581 | """
582 | if "tilde-research" in release:
583 | self.sae = Sae.from_pretrained(release, layer_idx=layer_idx)
584 | self.sae = self.sae.to(dtype=self.model.dtype, device=self.model.device)
585 | return
586 | self.sae, _, _ = SAE.from_pretrained(
587 | release=release, sae_id=sae_id, device=str(self.model.device)
588 | )
589 | self.sae = self.sae.to(dtype=self.model.dtype, device=self.model.device)
590 |
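# Example end-to-end usage (editorial sketch; the model name, SAE identifiers, layer,
# feature index, scale, and EvalConfig() defaults below are illustrative placeholders,
# not the exact settings used in the experiments):
#
#   wrapper = InterventionWrapper("google/gemma-2-2b-it", device="cuda")
#   wrapper.load_sae(release="<sae_release>", sae_id="<sae_id>", layer_idx=12)
#   module, hook_fn = wrapper.get_hook(
#       intervention_type=InterventionType.CONDITIONAL_PER_TOKEN.value,
#       model_params={"targ_layer": 12, "feature_idx": 1234},
#       scale=8.0,
#       config=EvalConfig(),
#   )
#   prompts = [wrapper.tokenizer.bos_token + "Write a function that ..."]
#   completions = wrapper.generate(
#       prompts, max_new_tokens=64, module_and_hook_fn=(module, hook_fn)
#   )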
--------------------------------------------------------------------------------