├── .gitignore ├── LICENSE ├── README.md ├── neuron_explainer.egg-info ├── PKG-INFO ├── SOURCES.txt ├── dependency_links.txt ├── requires.txt └── top_level.txt ├── neuron_explainer ├── __init__.py ├── activations │ ├── __init__.py │ ├── activation_records.py │ ├── activations.py │ ├── attention_utils.py │ └── token_connections.py ├── api_client.py ├── azure.py ├── explanations │ ├── __init__.py │ ├── calibrated_simulator.py │ ├── explainer.py │ ├── explanations.py │ ├── few_shot_examples.py │ ├── prompt_builder.py │ ├── puzzles.json │ ├── puzzles.py │ ├── scoring.py │ ├── simulator.py │ ├── test_explainer.py │ ├── test_simulator.py │ └── token_space_few_shot_examples.py └── fast_dataclasses │ ├── __init__.py │ ├── fast_dataclasses.py │ └── test_fast_dataclasses.py ├── poetry.lock ├── pyproject.toml └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | dist/ 3 | 4 | __pycache__/ 5 | 6 | build/ 7 | 8 | exampletest.py -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2023 Superalignment, OpenAI 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 
12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Automated interpretability 2 | 3 | This is a fork of OpenAI's `automated-interpretability` [here](https://github.com/openai/automated-interpretability). The README below has not been updated. 4 | 5 | ## Code and tools 6 | 7 | This repository contains code and tools associated with the [Language models can explain neurons in 8 | language models](https://openaipublic.blob.core.windows.net/neuron-explainer/paper/index.html) paper, specifically: 9 | 10 | * Code for automatically generating, simulating, and scoring explanations of neuron behavior using 11 | the methodology described in the paper. See the 12 | [neuron-explainer README](neuron-explainer/README.md) for more information. 13 | 14 | Note: if you run into errors of the form "Error: Could not find any credentials that grant access to storage account: 'openaipublic' and container: 'neuron-explainer'", you might be able to fix this by signing up for an Azure account and specifying the credentials as described in the error message. 15 | 16 | * A tool for viewing neuron activations and explanations, accessible 17 | [here](https://openaipublic.blob.core.windows.net/neuron-explainer/neuron-viewer/index.html). See 18 | the [neuron-viewer README](neuron-viewer/README.md) for more information. 
19 | 20 | ## Public datasets 21 | 22 | Together with this code, we're also releasing public datasets of GPT-2 XL neurons and explanations. 23 | Here's an overview of those datasets. 24 | 25 | * Neuron activations: `az://openaipublic/neuron-explainer/data/collated-activations/{layer_index}/{neuron_index}.json` 26 | - Tokenized text sequences and their activations for the neuron. We 27 | provide multiple sets of tokens and activations: top-activating ones, random 28 | samples from several quantiles; and a completely random sample. We also provide 29 | some basic statistics for the activations. 30 | - Each file contains a JSON-formatted 31 | [`NeuronRecord`](neuron-explainer/neuron_explainer/activations/activations.py#L89) dataclass. 32 | * Neuron explanations: `az://openaipublic/neuron-explainer/data/explanations/{layer_index}/{neuron_index}.jsonl` 33 | - Scored model-generated explanations of the behavior of the neuron, including simulation results. 34 | - Each file contains a JSON-formatted 35 | [`NeuronSimulationResults`](neuron-explainer/neuron_explainer/explanations/explanations.py#L146) 36 | dataclass. 37 | * Related neurons: `az://openaipublic/neuron-explainer/data/related-neurons/weight-based/{layer_index}/{neuron_index}.json` 38 | - Lists of the upstream and downstream neurons with the most positive and negative connections (see below for definition). 39 | - Each file contains a JSON-formatted dataclass whose definition is not included in this repo. 40 | * Tokens with high average activations: 41 | `az://openaipublic/neuron-explainer/data/related-tokens/activation-based/{layer_index}/{neuron_index}.json` 42 | - Lists of tokens with the highest average activations for individual neurons, and their average activations. 43 | - Each file contains a JSON-formatted [`TokenLookupTableSummaryOfNeuron`](neuron-explainer/neuron_explainer/activations/token_connections.py#L36) 44 | dataclass. 
45 | * Tokens with large inbound and outbound weights: 46 | `az://openaipublic/neuron-explainer/data/related-tokens/weight-based/{layer_index}/{neuron_index}.json` 47 | - List of the most-positive and most-negative input and output tokens for individual neurons, 48 | as well as the associated weight (see below for definition). 49 | - Each file contains a JSON-formatted [`WeightBasedSummaryOfNeuron`](neuron-explainer/neuron_explainer/activations/token_connections.py#L17) 50 | dataclass. 51 | 52 | Update (July 5, 2023): 53 | We also released a set of explanations for GPT-2 Small. The methodology is slightly different from the methodology used for GPT-2 XL so the results aren't directly comparable. 54 | * Neuron activations: `az://openaipublic/neuron-explainer/gpt2_small_data/collated-activations/{layer_index}/{neuron_index}.json` 55 | * Neuron explanations: `az://openaipublic/neuron-explainer/gpt2_small_data/explanations/{layer_index}/{neuron_index}.jsonl` 56 | 57 | Update (August 30, 2023): We recently discovered a bug in how we performed inference on the GPT-2 series models used for the paper and for these datasets. Specifically, we used an optimized GELU implementation rather than the original GELU implementation associated with GPT-2. While the model’s behavior is very similar across these two configurations, the post-MLP activation values we used to generate and simulate explanations differ from the correct values by the following amounts for GPT-2 small: 58 | 59 | - Median: 0.0090 60 | - 90th percentile: 0.0252 61 | - 99th percentile: 0.0839 62 | - 99.9th percentile: 0.1736 63 | 64 | ### Definition of connection weights 65 | 66 | Refer to [GPT-2 model code](https://github.com/openai/gpt-2/blob/master/src/model.py) for 67 | understanding of model weight conventions. 
68 | 69 | *Neuron-neuron*: For two neurons `(l1, n1)` and `(l2, n2)` with `l1 < l2`, the connection strength is defined as 70 | `h{l1}.mlp.c_proj.w[:, n1, :] @ diag(h{l2}.ln_2.g) @ h{l2}.mlp.c_fc.w[:, :, n2]`. 71 | 72 | *Neuron-token*: For token `t` and neuron `(l, n)`, the input weight is computed as 73 | `wte[t, :] @ diag(h{l}.ln_2.g) @ h{l}.mlp.c_fc.w[:, :, n]` 74 | and the output weight is computed as 75 | `h{l}.mlp.c_proj.w[:, n, :] @ diag(ln_f.g) @ wte[t, :]`. 76 | 77 | ### Misc Lists of Interesting Neurons 78 | Lists of neurons we thought were interesting according to different criteria, with some preliminary descriptions. 79 | * [Interesting Neurons (external)](https://docs.google.com/spreadsheets/d/1p7fYs31NU8sJoeKyUx4Mn2laGx8xXfHg_KcIvYiKPpg/edit#gid=0) 80 | * [Neurons that score high on random, possibly monosemantic? (external)](https://docs.google.com/spreadsheets/d/1TqKFcz-84jyIHLU7VRoTc8BoFBMpbgac-iNBnxVurQ8/edit?usp=sharing) 81 | * [Clusters of neurons well explained by activation explanation but not by tokens](https://docs.google.com/document/d/1lWhKowpKDdwTMALD_K541cdwgGoQx8DFUSuEe1U2AGE/edit?usp=sharing) 82 | * [Neurons sensitive to truncation](https://docs.google.com/document/d/1x89TWBvuHcyC2t01EDbJZJ5LQYHozlcS-VUmr5shf_A/edit?usp=sharing) 83 | -------------------------------------------------------------------------------- /neuron_explainer.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.4 2 | Name: neuron_explainer 3 | Version: 0.0.10 4 | Home-page: 5 | Author: OpenAI 6 | Requires-Python: >=3.9 7 | License-File: LICENSE 8 | Requires-Dist: httpx>=0.22 9 | Requires-Dist: scikit-learn 10 | Requires-Dist: boostedblob>=0.13.0 11 | Requires-Dist: tiktoken 12 | Requires-Dist: blobfile 13 | Requires-Dist: numpy 14 | Requires-Dist: pytest 15 | Requires-Dist: orjson 16 | Dynamic: author 17 | Dynamic: license-file 18 | Dynamic: requires-dist 19 | Dynamic: requires-python 20 | 
-------------------------------------------------------------------------------- /neuron_explainer.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | LICENSE 2 | README.md 3 | pyproject.toml 4 | setup.py 5 | neuron_explainer/__init__.py 6 | neuron_explainer/api_client.py 7 | neuron_explainer/azure.py 8 | neuron_explainer.egg-info/PKG-INFO 9 | neuron_explainer.egg-info/SOURCES.txt 10 | neuron_explainer.egg-info/dependency_links.txt 11 | neuron_explainer.egg-info/requires.txt 12 | neuron_explainer.egg-info/top_level.txt 13 | neuron_explainer/activations/__init__.py 14 | neuron_explainer/activations/activation_records.py 15 | neuron_explainer/activations/activations.py 16 | neuron_explainer/activations/attention_utils.py 17 | neuron_explainer/activations/token_connections.py 18 | neuron_explainer/explanations/__init__.py 19 | neuron_explainer/explanations/calibrated_simulator.py 20 | neuron_explainer/explanations/explainer.py 21 | neuron_explainer/explanations/explanations.py 22 | neuron_explainer/explanations/few_shot_examples.py 23 | neuron_explainer/explanations/prompt_builder.py 24 | neuron_explainer/explanations/puzzles.py 25 | neuron_explainer/explanations/scoring.py 26 | neuron_explainer/explanations/simulator.py 27 | neuron_explainer/explanations/test_explainer.py 28 | neuron_explainer/explanations/test_simulator.py 29 | neuron_explainer/explanations/token_space_few_shot_examples.py 30 | neuron_explainer/fast_dataclasses/__init__.py 31 | neuron_explainer/fast_dataclasses/fast_dataclasses.py 32 | neuron_explainer/fast_dataclasses/test_fast_dataclasses.py -------------------------------------------------------------------------------- /neuron_explainer.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /neuron_explainer.egg-info/requires.txt: 
def relu(x: float) -> float:
    """Clamp negative values to zero."""
    return max(0.0, x)


def calculate_max_activation(activation_records: Sequence[ActivationRecord]) -> float:
    """Return the maximum activation value of the neuron across all the activation records.

    Negative activations are clamped to zero first, so the result is never negative.
    Returns 0.0 when there are no records or no activation values at all;
    `normalize_activations` already maps a non-positive maximum to all zeros.
    """
    # Relu is used to assume any values less than 0 are indicating the neuron is in the resting
    # state. This is a simplifying assumption that works with relu/gelu.
    return max(
        (relu(x) for record in activation_records for x in record.activations),
        default=0.0,
    )


def normalize_activations(activation_record: list[float], max_activation: float) -> list[int]:
    """Convert raw neuron activations to integers on the range [0, 10].

    Args:
        activation_record: Raw activation values (despite the name, a plain list of floats).
        max_activation: Value that maps to 10; non-positive values yield all zeros.
    """
    if max_activation <= 0:
        return [0] * len(activation_record)
    # Relu is used to assume any values less than 0 are indicating the neuron is in the resting
    # state. This is a simplifying assumption that works with relu/gelu.
    return [min(10, math.floor(10 * relu(x) / max_activation)) for x in activation_record]


def _format_activation_record(
    activation_record: ActivationRecord,
    max_activation: float,
    omit_zeros: bool,
    hide_activations: bool = False,
    start_index: int = 0,
) -> str:
    """Format neuron activations into a string, suitable for use in prompts.

    Each token becomes one "token<TAB>activation" line. When `omit_zeros` is set, tokens whose
    normalized activation is 0 are dropped. Activations are replaced by the "unknown" marker
    when `hide_activations` is set or for positions before `start_index`.
    """
    tokens = activation_record.tokens
    normalized_activations = normalize_activations(activation_record.activations, max_activation)
    if omit_zeros:
        # Hiding (fully or via start_index) is meaningless once zero entries are removed.
        assert (not hide_activations) and start_index == 0, "Can't hide activations and omit zeros"
        tokens = [
            token for token, activation in zip(tokens, normalized_activations) if activation > 0
        ]
        normalized_activations = [x for x in normalized_activations if x > 0]

    assert len(tokens) == len(normalized_activations)
    entries = []
    for index, (token, activation) in enumerate(zip(tokens, normalized_activations)):
        if hide_activations or index < start_index:
            activation_string = UNKNOWN_ACTIVATION_STRING
        else:
            activation_string = str(int(activation))
        entries.append(f"{token}\t{activation_string}")
    return "\n".join(entries)


def format_activation_records(
    activation_records: Sequence[ActivationRecord],
    max_activation: float,
    *,
    omit_zeros: bool = False,
    start_indices: Optional[list[int]] = None,
    hide_activations: bool = False,
) -> str:
    """Format a list of activation records into a string.

    Records are separated by two blank lines and the whole string is wrapped in blank lines.
    `start_indices`, when given, supplies the per-record `start_index` (one entry per record).
    """
    formatted = [
        _format_activation_record(
            activation_record,
            max_activation,
            omit_zeros=omit_zeros,
            hide_activations=hide_activations,
            start_index=0 if start_indices is None else start_indices[i],
        )
        for i, activation_record in enumerate(activation_records)
    ]
    return "\n\n" + "\n\n\n".join(formatted) + "\n\n"
def _format_tokens_for_simulation(tokens: Sequence[str]) -> str:
    """
    Format tokens into a string with each token marked as having an "unknown" activation, suitable
    for use in prompts.
    """
    return "\n".join(f"{token}\t{UNKNOWN_ACTIVATION_STRING}" for token in tokens)


def format_sequences_for_simulation(
    all_tokens: Sequence[Sequence[str]],
) -> str:
    """
    Format a list of lists of tokens into a string with each token marked as having an "unknown"
    activation, suitable for use in prompts.

    Sequences are separated by two blank lines and the result is wrapped in blank lines.
    """
    formatted = [_format_tokens_for_simulation(tokens) for tokens in all_tokens]
    return "\n\n" + "\n\n\n".join(formatted) + "\n\n"


def non_zero_activation_proportion(
    activation_records: Sequence[ActivationRecord], max_activation: float
) -> float:
    """Return the proportion of activation values that aren't zero.

    "Zero" is judged after normalization with `normalize_activations`. Returns 0.0 when the
    records contain no activation values at all, rather than dividing by zero.
    """
    total_activations_count = sum(len(record.activations) for record in activation_records)
    if total_activations_count == 0:
        # Nothing to measure; avoid ZeroDivisionError on empty input.
        return 0.0
    non_zero_activations_count = sum(
        sum(1 for x in normalize_activations(record.activations, max_activation) if x != 0)
        for record in activation_records
    )
    return non_zero_activations_count / total_activations_count
# Registered with the fast_dataclasses (de)serialization helpers via @register_dataclass.
@register_dataclass
@dataclass
class ActivationRecord(FastDataclass):
    """Collated lists of tokens and their activations for a single neuron."""

    # NOTE(review): tokens and activations appear to be parallel lists (one activation per
    # token) -- the prompt-formatting helpers zip them together; confirm against the producer.
    tokens: List[str]
    """Tokens in the text sequence, represented as strings."""
    activations: List[float]
    """Raw activation values for the neuron on each token in the text sequence."""
    # NOTE(review): dfa_values / dfa_target_index are optional and default to None; their
    # meaning is not documented in the visible code -- TODO confirm and document.
    dfa_values: Optional[List[int]] = None
    dfa_target_index: Optional[int] = None


# Registered with the fast_dataclasses (de)serialization helpers via @register_dataclass.
@register_dataclass
@dataclass
class NeuronId(FastDataclass):
    """Identifier for a neuron in an artificial neural network."""

    # Both indices are zero-based, per the attribute docstrings below.
    layer_index: int
    """The index of layer the neuron is in. The first layer used during inference has index 0."""
    neuron_index: int
    """The neuron's index within in its layer. Indices start from 0 in each layer."""
def _check_slices(
    slices_by_split: dict[str, slice],
    expected_num_values: int,
) -> None:
    """Assert that the slices are disjoint and fully cover the intended range.

    Args:
        slices_by_split: Mapping from split name to the slice selecting that split.
        expected_num_values: Length of the sequence the slices are applied to.

    Raises:
        AssertionError: If the slices overlap, miss indices, or do not match the expected
            interleaved layout.
    """
    n_splits = len(slices_by_split)
    indices: set[int] = set()
    sum_of_slice_lengths = 0
    for s in slices_by_split.values():
        subrange = range(expected_num_values)[s]
        sum_of_slice_lengths += len(subrange)
        indices.update(subrange)
    # Equal total length plus equal index sets implies the slices are pairwise disjoint.
    assert (
        sum_of_slice_lengths == expected_num_values
    ), f"{sum_of_slice_lengths=} != {expected_num_values=}"
    stride = n_splits
    # set().union(...) (rather than set.union(*...)) also works when there are no splits.
    expected_indices = set().union(
        *(range(start_index, expected_num_values, stride) for start_index in range(n_splits))
    )
    assert indices == expected_indices, f"{indices=} != {expected_indices=}"


def get_slices_for_splits(
    splits: list[str],
    num_activation_records_per_split: int,
) -> dict[str, slice]:
    """
    Get equal-sized interleaved subsets for each of a list of splits, given the number of elements
    to include in each split.

    Split `i` of `k` receives indices i, i + k, i + 2k, ... below
    `k * num_activation_records_per_split`. An empty `splits` list yields an empty dict.
    """
    stride = len(splits)
    num_activation_records_for_even_splits = num_activation_records_per_split * stride
    slices_by_split = {
        split: slice(split_index, num_activation_records_for_even_splits, stride)
        for split_index, split in enumerate(splits)
    }
    _check_slices(
        slices_by_split=slices_by_split,
        expected_num_values=num_activation_records_for_even_splits,
    )
    return slices_by_split


@dataclass
class ActivationRecordSliceParams:
    """How to select splits (train, valid, etc.) of activation records."""

    n_examples_per_split: Optional[int]
    """The number of examples to include in each split."""
@register_dataclass
@dataclass
class NeuronRecord(FastDataclass):
    """Neuron-indexed activation data, including summary stats and notable activation records."""

    neuron_id: NeuronId
    """Identifier for the neuron."""

    random_sample: list[ActivationRecord] = field(default_factory=list)
    """
    Random activation records for this neuron. The random sample is independent from those used for
    other neurons.
    """
    random_sample_by_quantile: Optional[list[list[ActivationRecord]]] = None
    """
    Random samples of activation records in each of the specified quantiles. None if quantile
    tracking is disabled.
    """
    quantile_boundaries: Optional[list[float]] = None
    """Boundaries of the quantiles used to generate the random_sample_by_quantile field."""

    # Moments of activations
    mean: Optional[float] = math.nan
    variance: Optional[float] = math.nan
    skewness: Optional[float] = math.nan
    kurtosis: Optional[float] = math.nan

    most_positive_activation_records: list[ActivationRecord] = field(default_factory=list)
    """
    Activation records with the most positive figure of merit value for this neuron over all dataset
    examples.
    """

    @property
    def max_activation(self) -> float:
        """Largest raw activation found in any of the top-activating records."""
        return max(max(record.activations) for record in self.most_positive_activation_records)

    def _get_top_activation_slices(
        self, activation_record_slice_params: ActivationRecordSliceParams
    ) -> dict[str, slice]:
        """Slices of the top-activating records for the train/calibration/valid/test splits."""
        splits = ["train", "calibration", "valid", "test"]
        available = len(self.most_positive_activation_records)
        per_split = activation_record_slice_params.n_examples_per_split
        if per_split is None:
            # Default: divide the available records evenly, dropping any remainder.
            per_split = available // len(splits)
        assert available >= per_split * len(splits)
        return get_slices_for_splits(splits, per_split)

    def _get_random_activation_slices(
        self, activation_record_slice_params: ActivationRecordSliceParams
    ) -> dict[str, slice]:
        """Slices of the random sample for the calibration/valid/test splits (no train split)."""
        splits = ["calibration", "valid", "test"]
        available = len(self.random_sample)
        per_split = activation_record_slice_params.n_examples_per_split
        if per_split is None:
            per_split = available // len(splits)
        # NOTE: this assert could trigger on some old datasets with only 10 random samples, in
        # which case you may have to remove "test" from the set of splits
        assert available >= per_split * len(splits)
        return get_slices_for_splits(splits, per_split)

    def train_activation_records(
        self,
        activation_record_slice_params: ActivationRecordSliceParams,
    ) -> list[ActivationRecord]:
        """
        Records of the train split, typically used for generating explanations. Only
        top-activating records are included, since context window limitations make it
        difficult to include random records.
        """
        top_slice = self._get_top_activation_slices(activation_record_slice_params)["train"]
        return self.most_positive_activation_records[top_slice]

    def calibration_activation_records(
        self,
        activation_record_slice_params: ActivationRecordSliceParams,
    ) -> list[ActivationRecord]:
        """
        Records of the calibration split, typically used for calibrating neuron simulations.
        Top-activating records come first, followed by random records, in a 1:1 ratio.
        """
        top_slice = self._get_top_activation_slices(activation_record_slice_params)["calibration"]
        top_records = self.most_positive_activation_records[top_slice]
        random_slice = self._get_random_activation_slices(activation_record_slice_params)[
            "calibration"
        ]
        return top_records + self.random_sample[random_slice]

    def valid_activation_records(
        self,
        activation_record_slice_params: ActivationRecordSliceParams,
    ) -> list[ActivationRecord]:
        """
        Records of the validation split, typically used for evaluating explanations, either
        automatically with simulation + correlation coefficient scoring, or manually by humans.
        Top-activating records come first, followed by random records, in a 1:1 ratio.
        """
        top_slice = self._get_top_activation_slices(activation_record_slice_params)["valid"]
        top_records = self.most_positive_activation_records[top_slice]
        random_slice = self._get_random_activation_slices(activation_record_slice_params)["valid"]
        return top_records + self.random_sample[random_slice]

    def test_activation_records(
        self,
        activation_record_slice_params: ActivationRecordSliceParams,
    ) -> list[ActivationRecord]:
        """
        Records of the test split, typically used for explanation evaluations that can't use
        the validation split. Top-activating records come first, followed by random records,
        in a 1:1 ratio.
        """
        top_slice = self._get_top_activation_slices(activation_record_slice_params)["test"]
        top_records = self.most_positive_activation_records[top_slice]
        random_slice = self._get_random_activation_slices(activation_record_slice_params)["test"]
        return top_records + self.random_sample[random_slice]


def neuron_exists(
    dataset_path: str, layer_index: Union[str, int], neuron_index: Union[str, int]
) -> bool:
    """Return whether a JSON file for the given neuron is present under `dataset_path`."""
    return bf.exists(
        bf.join(dataset_path, "neurons", str(layer_index), f"{neuron_index}.json")
    )


def load_neuron(
    layer_index: Union[str, int],
    neuron_index: Union[str, int],
    dataset_path: str = "https://openaipublic.blob.core.windows.net/neuron-explainer/data/collated-activations",
) -> NeuronRecord:
    """Fetch and deserialize the NeuronRecord for the given layer and neuron.

    Raises:
        ValueError: If the downloaded payload does not deserialize to a NeuronRecord.
    """
    neuron_url = standardize_azure_url(
        "/".join((dataset_path, str(layer_index), f"{neuron_index}.json"))
    )
    with urllib.request.urlopen(neuron_url) as response:
        record = loads(response.read())
    if isinstance(record, NeuronRecord):
        return record
    raise ValueError("Stored data incompatible with current version of NeuronRecord dataclass.")
@bbb.ensure_session
async def load_neuron_async(
    layer_index: Union[str, int],
    neuron_index: Union[str, int],
    dataset_path: str = "az://openaipublic/neuron-explainer/data/collated-activations",
) -> NeuronRecord:
    """Asynchronously load the NeuronRecord for the given layer and neuron."""
    neuron_path = bf.join(dataset_path, str(layer_index), f"{neuron_index}.json")
    return await read_neuron_file(neuron_path)


@bbb.ensure_session
async def read_neuron_file(neuron_filename: str) -> NeuronRecord:
    """Read and deserialize a NeuronRecord from a raw neuron file path.

    Raises:
        ValueError: If the file contents do not deserialize to a NeuronRecord.
    """
    payload = await bbb.read.read_single(neuron_filename)
    record = loads(payload.decode("utf-8"))
    if isinstance(record, NeuronRecord):
        return record
    raise ValueError("Stored data incompatible with current version of NeuronRecord dataclass.")


def get_sorted_neuron_indices(
    dataset_path: str, layer_index: Union[str, int]
) -> List[int]:
    """List the neuron indices found for a layer, smallest first.

    Only directory entries whose name before the first "." is numeric are considered.
    """
    layer_dir = bf.join(dataset_path, "neurons", str(layer_index))
    stems = (entry.split(".")[0] for entry in bf.listdir(layer_dir))
    return sorted(int(stem) for stem in stems if stem.isnumeric())


def get_sorted_layers(dataset_path: str) -> List[str]:
    """
    List the layer indices found in this dataset as strings, ordered numerically (not
    lexicographically). Non-numeric directory entries are ignored.
    """
    neurons_dir = bf.join(dataset_path, "neurons")
    layer_numbers = [int(entry) for entry in bf.listdir(neurons_dir) if entry.isnumeric()]
    layer_numbers.sort()
    return [str(number) for number in layer_numbers]
def _inverse_triangular_number(n: int) -> int:
    """Return the unique m with m * (m + 1) / 2 == n.

    Used to infer a number of sequence tokens from a number of activations.

    Raises:
        AssertionError: If n is negative or is not a triangular number.
    """
    assert n >= 0
    # From the quadratic formula applied to m(m+1)/2 = n. math.isqrt is exact for arbitrarily
    # large ints, unlike floor(sqrt(float)), which loses precision or overflows.
    m: int = (math.isqrt(1 + 8 * n) - 1) // 2
    assert m * (m + 1) // 2 == n
    return m


def get_max_num_attended_to_sequence_tokens(num_sequence_tokens: int, num_activations: int) -> int:
    """Infer the maximum attended-to sequence length from the number of activations.

    Attended-to sequences are assumed to increase in length up to a maximum length, and then
    stay at that length for the remainder of the sequence. The maximum attended-to sequence
    length is at most equal to the sequence length, but is permitted to be less.

    Raises:
        AssertionError: If `num_activations` is inconsistent with `num_sequence_tokens`.
    """
    num_sequence_token_pairs = num_sequence_tokens * (num_sequence_tokens + 1) // 2
    if num_activations == num_sequence_token_pairs:
        # The maximum attended-to sequence length is equal to the sequence length.
        return num_sequence_tokens

    # The maximum attended-to sequence length is less than the sequence length, and the
    # activations missing from the full triangle form a smaller triangle of their own.
    assert num_activations < num_sequence_token_pairs
    num_missing_activations = num_sequence_token_pairs - num_activations
    num_missing_sequence_tokens = _inverse_triangular_number(num_missing_activations)
    max_num_attended_to_sequence_tokens = num_sequence_tokens - num_missing_sequence_tokens
    assert max_num_attended_to_sequence_tokens > 0
    return max_num_attended_to_sequence_tokens
def get_attended_to_sequence_length_per_sequence_token(
    num_sequence_tokens: int, max_num_attended_to_sequence_tokens: int
) -> list[int]:
    """Return, for each sequence token, the length of the sequence it attends to.

    The result has `num_sequence_tokens` entries: the lengths count up 1, 2, 3, ... until they
    reach `max_num_attended_to_sequence_tokens`, then stay at that value for every remaining
    token.
    """
    assert num_sequence_tokens >= max_num_attended_to_sequence_tokens
    num_capped_tokens = num_sequence_tokens - max_num_attended_to_sequence_tokens
    growing_lengths = list(range(1, max_num_attended_to_sequence_tokens + 1))
    capped_lengths = [max_num_attended_to_sequence_tokens] * num_capped_tokens
    return growing_lengths + capped_lengths
49 | # The length of the attended to sequence starts at 1, increases up to max_num_attended_to_sequence_tokens, by 1 with each 50 | # token, and then stays at max_num_attended_to_sequence_tokens for the remainder of the sequence 51 | assert num_sequence_tokens >= max_num_attended_to_sequence_tokens 52 | attended_to_sequence_lengths = list(range(1, max_num_attended_to_sequence_tokens + 1)) 53 | if num_sequence_tokens > max_num_attended_to_sequence_tokens: 54 | attended_to_sequence_lengths.extend( 55 | [ 56 | max_num_attended_to_sequence_tokens 57 | for _ in range(num_sequence_tokens - max_num_attended_to_sequence_tokens) 58 | ] 59 | ) 60 | return attended_to_sequence_lengths 61 | 62 | 63 | def get_attended_to_sequence_lengths(num_sequence_tokens: int, num_activations: int) -> list[int]: 64 | max_num_attended_to_sequence_tokens = get_max_num_attended_to_sequence_tokens( 65 | num_sequence_tokens, num_activations 66 | ) 67 | return get_attended_to_sequence_length_per_sequence_token( 68 | num_sequence_tokens, max_num_attended_to_sequence_tokens 69 | ) 70 | 71 | 72 | def _convert_flattened_index_to_unflattened_index_assuming_square_matrix( 73 | flat_index: int, 74 | ) -> tuple[int, int]: 75 | # this con 76 | n = math.floor((-1 + math.sqrt(1 + 8 * flat_index)) / 2) 77 | m = flat_index - n * (n + 1) // 2 78 | return n, m 79 | 80 | 81 | def convert_flattened_index_to_unflattened_index( 82 | flattened_index: int, 83 | num_sequence_tokens: int | None = None, 84 | num_activations: int | None = None, 85 | ) -> tuple[int, int]: 86 | # given a flattened index, return the unflattened index 87 | # if the attention matrix is square (most common), then the flattened_index uniquely determines the index within the square matrix 88 | # if the attention matrix has more rows (sequence tokens) than columns (attended-to sequence tokens), then num_sequence_tokens 89 | # and num_activations are required to determine the index within the matrix 90 | # specify both num_sequence_tokens and 
# Registered with the fast_dataclasses (de)serialization helpers via @register_dataclass.
@register_dataclass
@dataclass
class TokensAndWeights(FastDataclass):
    """A list of tokens together with a list of connection strengths.

    NOTE(review): presumably strengths[i] is the weight associated with tokens[i] (the two
    lists being parallel) -- confirm against the code that produces these files.
    """

    tokens: List[str]
    strengths: List[float]
def load_token_weight_connections_of_neuron(
    layer_index: Union[str, int],
    neuron_index: Union[str, int],
    dataset_path: str = "https://openaipublic.blob.core.windows.net/neuron-explainer/data/related-tokens/weight-based",
) -> WeightBasedSummaryOfNeuron:
    """Load the WeightBasedSummaryOfNeuron for the specified neuron.

    The record is read from ``{dataset_path}/{layer_index}/{neuron_index}.json`` after the
    location has been passed through ``standardize_azure_url``.

    Args:
        layer_index: Layer of the neuron; used as a path component.
        neuron_index: Index of the neuron; used as the file stem.
        dataset_path: Base location of the weight-based dataset. A trailing slash is
            tolerated.

    Returns:
        The deserialized record for the neuron.
    """
    # Strip trailing slashes so that a base path ending in "/" does not produce "//" in the
    # final URL.
    url = "/".join([dataset_path.rstrip("/"), str(layer_index), f"{neuron_index}.json"])
    url = standardize_azure_url(url)
    with urllib.request.urlopen(url) as f:
        return loads(f.read(), backwards_compatible=False)
def load_token_lookup_table_connections_of_neuron(
    layer_index: Union[str, int],
    neuron_index: Union[str, int],
    dataset_path: str = "https://openaipublic.blob.core.windows.net/neuron-explainer/data/related-tokens/activation-based",
) -> TokenLookupTableSummaryOfNeuron:
    """Fetch and deserialize the TokenLookupTableSummaryOfNeuron of one neuron.

    The record lives at ``{dataset_path}/{layer_index}/{neuron_index}.json``; the location
    is run through ``standardize_azure_url`` before it is opened.
    """
    blob_url = standardize_azure_url(f"{dataset_path}/{layer_index}/{neuron_index}.json")
    with urllib.request.urlopen(blob_url) as response:
        payload = response.read()
        return loads(payload, backwards_compatible=False)
def exponential_backoff(
    retry_on: Callable[[Exception], bool] = lambda err: True,
    *,
    init_delay_s: float = 1.0,
    max_delay_s: float = 10.0,
    max_tries: int = 200,
    backoff_multiplier: float = 2.0,
    jitter: float = 0.2,
) -> Callable[[Callable], Callable]:
    """
    Returns a decorator which retries the wrapped coroutine function as long as the specified
    retry_on function returns True for the exception, applying exponential backoff with jitter
    after failures, up to a retry limit.

    Args:
        retry_on: Called with each raised exception; the call is retried only if it returns
            True.
        init_delay_s: Nominal delay before the first retry, in seconds.
        max_delay_s: Upper bound on the nominal delay, in seconds.
        max_tries: Total number of attempts. With the default delays, 200 attempts is roughly
            30 minutes before we give up.
        backoff_multiplier: Factor applied to the nominal delay after every failed attempt.
        jitter: Each sleep is drawn uniformly from ``delay * (1 - jitter)`` to
            ``delay * (1 + jitter)``.

    Returns:
        A decorator for coroutine functions. The decorated function returns whatever the
        wrapped function returns and re-raises the last exception once retrying stops.

    Raises:
        ValueError: If ``max_tries`` is smaller than 1 (the wrapped function would never run).
    """
    if max_tries < 1:
        raise ValueError(f"max_tries must be at least 1, got {max_tries}")

    def decorate(f: Callable) -> Callable:
        assert asyncio.iscoroutinefunction(f)

        @wraps(f)
        async def f_retry(*args: Any, **kwargs: Any) -> Any:
            delay_s = init_delay_s
            for attempt in range(max_tries):
                try:
                    return await f(*args, **kwargs)
                except Exception as err:
                    # Give up when the error is not retryable or the attempts are used up;
                    # a bare raise keeps the original traceback.
                    if not retry_on(err) or attempt == max_tries - 1:
                        raise
                    jittered_delay = random.uniform(
                        delay_s * (1 - jitter), delay_s * (1 + jitter)
                    )
                    await asyncio.sleep(jittered_delay)
                    delay_s = min(delay_s * backoff_multiplier, max_delay_s)

        return f_retry

    return decorate
96 | max_concurrent: Optional[int] = None, 97 | # Whether to cache request/response pairs in memory to avoid duplicating requests. 98 | cache: bool = False, 99 | base_api_url: str = BASE_API_URL, 100 | override_api_key: str | None = None, 101 | ): 102 | self.model_name = model_name 103 | self.base_api_url = base_api_url 104 | self.override_api_key = override_api_key 105 | if max_concurrent is not None: 106 | self._concurrency_check: Optional[Semaphore] = Semaphore(max_concurrent) 107 | else: 108 | self._concurrency_check = None 109 | 110 | if cache: 111 | self._cache: Optional[dict[str, Any]] = {} 112 | else: 113 | self._cache = None 114 | 115 | @exponential_backoff(retry_on=is_api_error) 116 | async def make_request( 117 | self, 118 | timeout_seconds: Optional[int] = None, 119 | json_mode: Optional[bool] = False, 120 | **kwargs: Any, 121 | ) -> dict[str, Any]: 122 | api_http_headers = { 123 | "Content-Type": "application/json", 124 | "Authorization": f"Bearer {os.getenv('OPENAI_API_KEY') if self.override_api_key is None else self.override_api_key}", 125 | } 126 | if self._cache is not None: 127 | key = orjson.dumps(kwargs) 128 | if key in self._cache: 129 | return self._cache[key] 130 | async with contextlib.AsyncExitStack() as stack: 131 | if self._concurrency_check is not None: 132 | await stack.enter_async_context(self._concurrency_check) 133 | http_client = await stack.enter_async_context( 134 | httpx.AsyncClient(timeout=timeout_seconds) 135 | ) 136 | # If the request has a "messages" key, it should be sent to the /chat/completions 137 | # endpoint. Otherwise, it should be sent to the /completions endpoint. 
def standardize_azure_url(url: str) -> str:
    """Make sure url is converted to url format, not an azure path.

    A path starting with ``az://openaipublic/`` has that prefix swapped for
    ``https://openaipublic.blob.core.windows.net/``; anything else is returned unchanged.

    Args:
        url: An ``az://openaipublic/...`` path or any other URL string.

    Returns:
        The https form of an ``az://openaipublic/`` path, otherwise ``url`` itself.
    """
    azure_prefix = "az://openaipublic/"
    https_prefix = "https://openaipublic.blob.core.windows.net/"
    if url.startswith(azure_prefix):
        # Swap only the leading prefix; str.replace would also rewrite any later occurrence
        # of the same text inside the path or query string.
        return https_prefix + url[len(azure_prefix):]
    return url
class CalibratedNeuronSimulator(NeuronSimulator):
    """
    Decorates a NeuronSimulator so that its outputs are mapped from the predicted activation
    space into the actual activation space of the neuron.
    """

    def __init__(self, uncalibrated_simulator: NeuronSimulator):
        self.uncalibrated_simulator = uncalibrated_simulator

    @classmethod
    async def create(
        cls,
        uncalibrated_simulator: NeuronSimulator,
        calibration_activation_records: Sequence[ActivationRecord],
    ) -> CalibratedNeuronSimulator:
        """
        Construct a simulator and calibrate it on the given records in a single awaitable
        call.
        """
        instance = cls(uncalibrated_simulator)
        await instance.calibrate(calibration_activation_records)
        return instance

    async def calibrate(self, calibration_activation_records: Sequence[ActivationRecord]) -> None:
        """
        Simulate every calibration record with the wrapped simulator, then fit the mapping
        from predicted activations to real activations.

        Use when simulated sequences haven't already been produced on the calibration set.
        """
        pending = [
            self.uncalibrated_simulator.simulate(record.tokens)
            for record in calibration_activation_records
        ]
        simulations = await asyncio.gather(*pending)
        self.calibrate_from_simulations(calibration_activation_records, simulations)

    def calibrate_from_simulations(
        self,
        calibration_activation_records: Sequence[ActivationRecord],
        simulations: Sequence[SequenceSimulation],
    ) -> None:
        """
        Fit the mapping from predicted activations to real activations using simulations that
        were already produced for the calibration set.

        Records and simulations are paired positionally; all true activations and all expected
        activations are concatenated across the pairs before fitting.
        """
        paired = list(zip(calibration_activation_records, simulations))
        true_values = [value for record, _ in paired for value in record.activations]
        simulated_values: list[float] = [
            value for _, simulation in paired for value in simulation.expected_activations
        ]
        self._calibrate_from_flattened_activations(
            np.array(true_values), np.array(simulated_values)
        )

    @abstractmethod
    def _calibrate_from_flattened_activations(
        self,
        true_activations: np.ndarray,
        uncalibrated_activations: np.ndarray,
    ) -> None:
        """
        Fit the calibration parameters from two numpy arrays: all true activations and all
        uncalibrated activations on the calibration set, taken over all sequences.
        """

    @abstractmethod
    def apply_calibration(self, values: Sequence[float]) -> list[float]:
        """Map a sequence of uncalibrated values through the learned calibration."""

    async def simulate(self, tokens: Sequence[str]) -> SequenceSimulation:
        """
        Simulate with the wrapped simulator and return a copy of its result in which the
        expected activations and the distribution values have been calibrated. The raw result
        is kept in ``uncalibrated_simulation``.
        """
        raw = await self.uncalibrated_simulator.simulate(tokens)
        return SequenceSimulation(
            tokens=raw.tokens,
            expected_activations=self.apply_calibration(raw.expected_activations),
            activation_scale=ActivationScale.NEURON_ACTIVATIONS,
            distribution_values=[
                self.apply_calibration(values) for values in raw.distribution_values
            ],
            distribution_probabilities=raw.distribution_probabilities,
            uncalibrated_simulation=raw,
        )
class PercentileMatchingCalibratedNeuronSimulator(CalibratedNeuronSimulator):
    """
    Calibrate by percentile matching: the nth percentile of the uncalibrated activations is
    mapped to the nth percentile of the true activations, for all n.

    This reproduces the distribution of true activations on the calibration set, but will be
    overconfident outside of the calibration set.
    """

    def __init__(self, uncalibrated_simulator: NeuronSimulator):
        super().__init__(uncalibrated_simulator)
        self._uncalibrated_activations: Optional[np.ndarray] = None
        self._true_activations: Optional[np.ndarray] = None

    def _calibrate_from_flattened_activations(
        self,
        true_activations: np.ndarray,
        uncalibrated_activations: np.ndarray,
    ) -> None:
        # Each array is sorted independently, so equal ranks line up with each other.
        self._true_activations = np.sort(true_activations)
        self._uncalibrated_activations = np.sort(uncalibrated_activations)

    def apply_calibration(self, values: Sequence[float]) -> list[float]:
        sorted_targets = self._true_activations
        sorted_predictions = self._uncalibrated_activations
        if sorted_targets is None or sorted_predictions is None:
            raise ValueError("Must call calibrate() before apply_calibration")
        if len(values) == 0:
            return []
        # Piecewise-linear interpolation between the sorted predicted and true values.
        queries = np.array(values)
        calibrated = np.interp(queries, sorted_predictions, sorted_targets)
        return calibrated.tolist()
class ActivationScale(str, Enum):
    """The "units" of the expected_activations/distribution_values fields of a
    SequenceSimulation.

    Tells whether the stored values are real activations of the neuron or something else.
    Different scales are not necessarily related by a linear transformation.
    """

    # Values represent real activations of the neuron.
    NEURON_ACTIVATIONS = "neuron_activations"

    # Values represent simulated activations of the neuron, normalized to the range [0, 10].
    # This scale is arbitrary and should not be interpreted as a neuron activation.
    SIMULATED_NORMALIZED_ACTIVATIONS = "simulated_normalized_activations"
50 | 51 | May be transformed to another unit by calibration. When we simulate a neuron, we produce a 52 | discrete distribution with values in the arbitrary discretized space of the neuron, e.g. 10% 53 | chance of 0, 70% chance of 1, 20% chance of 2. Which we store as distribution_values = 54 | [0, 1, 2], distribution_probabilities = [0.1, 0.7, 0.2]. When we transform the distribution to 55 | the real activation units, we can correspondingly transform the values of this distribution 56 | to get a distribution in the units of the neuron. e.g. if the mapping from the discretized space 57 | to the real activation unit of the neuron is f(x) = x/2, then the distribution becomes 10% 58 | chance of 0, 70% chance of 0.5, 20% chance of 1. Which we store as distribution_values = 59 | [0, 0.5, 1], distribution_probabilities = [0.1, 0.7, 0.2]. 60 | """ 61 | distribution_probabilities: list[list[float]] 62 | """ 63 | For each token in the sequence, the probability of the corresponding value in 64 | distribution_values. 65 | """ 66 | 67 | uncalibrated_simulation: Optional["SequenceSimulation"] = None 68 | """The result of the simulation before calibration.""" 69 | 70 | 71 | @register_dataclass 72 | @dataclass 73 | class ScoredSequenceSimulation(FastDataclass): 74 | """ 75 | SequenceSimulation result with a score (for that sequence only) and ground truth activations. 76 | """ 77 | 78 | simulation: SequenceSimulation 79 | """The result of a simulation of neuron activations.""" 80 | true_activations: List[float] 81 | """Ground truth activations on the sequence (not normalized)""" 82 | ev_correlation_score: float 83 | """ 84 | Correlation coefficient between the expected values of the normalized activations from the 85 | simulation and the unnormalized true activations of the neuron on the text sequence. 
86 | """ 87 | rsquared_score: Optional[float] = None 88 | """R^2 of the simulated activations.""" 89 | absolute_dev_explained_score: Optional[float] = None 90 | """ 91 | Score based on absolute difference between real and simulated activations. 92 | absolute_dev_explained_score = 1 - mean(abs(real-predicted))/ mean(abs(real)) 93 | """ 94 | 95 | 96 | @register_dataclass 97 | @dataclass 98 | class ScoredSimulation(FastDataclass): 99 | """Result of scoring a neuron simulation on multiple sequences.""" 100 | 101 | scored_sequence_simulations: List[ScoredSequenceSimulation] 102 | """ScoredSequenceSimulation for each sequence""" 103 | ev_correlation_score: Optional[float] = None 104 | """ 105 | Correlation coefficient between the expected values of the normalized activations from the 106 | simulation and the unnormalized true activations on a dataset created from all score_results. 107 | (Note that this is not equivalent to averaging across sequences.) 108 | """ 109 | rsquared_score: Optional[float] = None 110 | """R^2 of the simulated activations.""" 111 | absolute_dev_explained_score: Optional[float] = None 112 | """ 113 | Score based on absolute difference between real and simulated activations. 114 | absolute_dev_explained_score = 1 - mean(abs(real-predicted))/ mean(abs(real)). 115 | """ 116 | 117 | def get_preferred_score(self) -> Optional[float]: 118 | """ 119 | This method may return None in cases where the score is undefined, for example if the 120 | normalized activations were all zero, yielding a correlation coefficient of NaN. 
def load_neuron_explanations(
    explanations_path: str, layer_index: Union[str, int], neuron_index: Union[str, int]
) -> Optional[NeuronSimulationResults]:
    """Read the scored explanations of one neuron.

    The results are taken from the first line of
    ``{explanations_path}/{layer_index}/{neuron_index}.jsonl``. Returns None when that file
    does not exist or contains no lines.
    """
    results_path = bf.join(explanations_path, str(layer_index), f"{neuron_index}.jsonl")
    if bf.exists(results_path):
        with bf.BlobFile(results_path) as handle:
            # Only the first line is ever consumed.
            first_line = next(iter(handle), None)
            if first_line is not None:
                return loads(first_line)
    return None
@bbb.ensure_session
async def read_json_file(filename: str) -> Optional[dict]:
    """Asynchronously read the given file and parse its contents as JSON.

    Returns None when ``read_file`` yields None for the file.
    """
    contents = await read_file(filename)
    if contents is None:
        return None
    return json.loads(contents)
def get_sorted_neuron_indices_from_explanations(
    explanations_path: str, layer: Union[str, int]
) -> list[int]:
    """Return the indices of all neurons in this layer, in ascending order.

    An index is the part of a file name before the first ".", for every entry of
    ``{explanations_path}/{layer}`` where that part consists only of decimal digits. Other
    entries are ignored.
    """
    layer_dir = bf.join(explanations_path, str(layer))
    stems = (filename.split(".")[0] for filename in bf.listdir(layer_dir))
    # isdecimal() accepts exactly the strings int() can parse; isnumeric() would also accept
    # characters such as superscripts or fractions, on which int() raises ValueError.
    return sorted(int(stem) for stem in stems if stem.isdecimal())
class PromptBuilder:
    """Collects role-tagged messages and renders them in a requested prompt format."""

    def __init__(self) -> None:
        self._messages: list[HarmonyMessage] = []

    def add_message(self, role: Role, message: str) -> None:
        """Append one message with the given role to the prompt under construction."""
        self._messages.append(HarmonyMessage(role=role, content=message))

    def prompt_length_in_tokens(self, prompt_format: PromptFormat) -> int:
        """Return the token count of the prompt when rendered in ``prompt_format``."""
        # TODO(sbills): Make the model/encoding configurable. This implementation assumes GPT-4.
        encoding = tiktoken.get_encoding("cl100k_base")
        if prompt_format != PromptFormat.HARMONY_V4:
            prompt_str = self.build(prompt_format)
            assert isinstance(prompt_str, str)
            return len(encoding.encode(prompt_str, allowed_special="all"))

        # Approximately-correct implementation adapted from this documentation:
        # https://platform.openai.com/docs/guides/chat/introduction
        content_tokens = sum(
            len(encoding.encode(message["content"], allowed_special="all"))
            for message in self._messages
        )
        # Every message follows <|im_start|>{role/name}\n{content}<|im_end|>\n (4 tokens of
        # framing), and every reply is primed with <|im_start|>assistant (2 tokens).
        return 4 * len(self._messages) + content_tokens + 2

    def build(
        self, prompt_format: PromptFormat, *, allow_extra_system_messages: bool = False
    ) -> Union[str, list[HarmonyMessage]]:
        """
        Checks the role order of the messages added so far (system first, then user and
        assistant alternating) and renders them: a list of HarmonyMessages for HARMONY_V4,
        otherwise one concatenated string, with <|endofprompt|> appended to the last user
        message for INSTRUCTION_FOLLOWING.

        With `allow_extra_system_messages` set, system messages are also accepted after the
        very first one.
        """
        # Work on per-message copies so that neither this method nor the caller can alter the
        # builder's internal state.
        messages = [message.copy() for message in self._messages]

        expected_next_role = Role.SYSTEM
        for message in messages:
            role = message["role"]
            is_extra_system = allow_extra_system_messages and role == Role.SYSTEM
            assert (
                role == expected_next_role or is_extra_system
            ), f"Expected message from {expected_next_role} but got message from {role}"
            if role == Role.USER:
                expected_next_role = Role.ASSISTANT
            elif role in (Role.SYSTEM, Role.ASSISTANT):
                expected_next_role = Role.USER

        if prompt_format == PromptFormat.INSTRUCTION_FOLLOWING:
            user_messages = [message for message in messages if message["role"] == Role.USER]
            assert user_messages
            user_messages[-1]["content"] += "<|endofprompt|>"

        if prompt_format == PromptFormat.HARMONY_V4:
            return messages
        if prompt_format in [PromptFormat.NONE, PromptFormat.INSTRUCTION_FOLLOWING]:
            return "".join(message["content"] for message in messages)
        raise ValueError(f"Unknown prompt format: {prompt_format}")
[["Green", 9 ], " smoke", ",", " and", [" yellow", 9 ], " flames", ",", " and", [" red", 9 ], " fire", ",", " and", [" black", 9 ], " ash", ",", " and", " the", " smell", " of", " burning", "."], 7 | ["His", " shoes", " were", [" black", 9 ], ",", " but", " his", " socks", " were", " bright", [" yellow", 10 ], "."], 8 | ["On", " a", " run", " down", " ship", ",", " sailing", " a", " wine", ["-dark", 7 ], " sea", ",", " I", " searched", " for", " the", " island", " of", " my", " birth", "."], 9 | [["Red", 9 ], ",", [" yellow", 9 ], ",", [" orange", 9 ], ",", [" purple", 9 ], ":", " the", " diversity", " and", " aroma", " of", " spices", " overwhelmed", " her", " senses", "."], 10 | ["Some", " might", " say", " that", [" turquoise", 9 ], " is", [" blue", 9 ], ",", " but", " I", " think", " it", "'s", " more", " of", " a", [" blue", 9 ], ["-green", 9 ], "."], 11 | ["With", " the", " dog", " days", " of", " august", " upon", " us", ",", " think", " of", " this", " dog", " of", " a", " book", " as", " the", " literary", " equivalent", " of", " high", " humidity", "."] 12 | ], 13 | "false_explanations": [ 14 | "adjectives and descriptors in general", 15 | "references to aromas and the sense of smell", 16 | "instances of the words \"red\", \"black\", and \"yellow\"" 17 | ] 18 | }, 19 | "char2": { 20 | "name": "char2", 21 | "explanation": "the name of the second named character introduced in the passage, excluding the narrator", 22 | "sentences": [ 23 | ["Once", " upon", " a", " time", ",", " there", " was", " a", " brave", " knight", " named", " Art", " who", " lived", " in", " a", " magnificent", " castle", ".", " One", " day", ",", " a", " beautiful", " princess", " named", [" Gwen", 9 ], " came", " to", " visit", " the", " kingdom", ".", " Art", " was", " immediately", " capt", "ivated", " by", " her", " beauty", " and", " grace", ".", " He", " vowed", " to", " protect", " her", " and", " always", " keep", " her", " safe", ".", " From", " that", " moment", " on", ",", 
" Art", " and", [" Gwen", 10 ], " were", " inse", "parable", " and", " they", " lived", " happily", " ever", " after", "."], 24 | ["Another", " time", ",", " when", " Sil", "as", " left", " the", " city", " together", " with", [" Giles", 9 ], ",", " to", " fetch", " some", " livestock", " for", " their", " brothers", " and", " teachers", ",", " Sil", "as", " began", " to", " speak", "."], 25 | ["Once", " upon", " a", " time", ",", " there", " was", " a", " brave", " knight", " named", " Arthur", " who", " lived", " in", " a", " magnificent", " castle", ".", " One", " day", ",", " Arthur", " went", " to", " visit", " another", " kingdom", ".", " While", " he", " was", " there", ",", " he", " went", " on", " many", " adventures", " and", " faced", " many", " per", "ils", ",", " but", " he", " always", " came", " through", ".", " Arthur", " returned", " to", " his", " castle", " and", " lived", " happily", " ever", " after", "."], 26 | ["Ben", " enjoyed", " apples", ",", " while", [" Alice", 9 ], " enjoyed", " bananas", ".", [" Alice", 9 ], " found", " it", " funny", " that", " the", " first", " letter", " of", " their", " favorite", " fruits", " matched", " the", " first", " letter", " of", " the", " other", "'s", " name", ",", " but", " Ben", " did", " not", "."], 27 | ["Kate", ",", [" Eve", 9 ], ",", " and", " Claire", " were", " all", " siblings", ".", " Kate", " was", " the", " oldest", ",", [" Eve", 9 ], " the", " second", " oldest", ",", " and", " Claire", " the", " youngest", "."], 28 | ["Chocolate", " cake", " or", " apple", " pie", " -", " Sam", " couldn", "'t", " decide", " which", " he", " liked", " better", "."], 29 | ["T", "ess", ",", [" Jack", 9 ], ",", " Mark", ",", " and", " Jane", " were", " bored", " on", " a", " rainy", " Saturday", " afternoon", ".", " They", " had", " played", " all", " their", " board", " games", ",", " watched", " all", " their", " favorite", " movies", ",", " and", " read", " all", " their", " books", ".", " They", " wanted", 
" to", " do", " something", " fun", " and", " exciting", ",", " but", " they", " couldn", "'t", " go", " outside", " or", " visit", " their", " friends", ".", " Tess", " decided", " to", " make", " up", " an", " adventure", " game", ",", " using", " her", " imagination", " and", " whatever", " she", " could", " find", " in", " the", " house", ".", [" Jack", 9 ], " turned", " the", " living", " room", " into", " a", " jungle", ",", " the", " kitchen", " into", " a", " spaceship", ",", " the", " basement", " into", " a", " dungeon", ",", " and", " the", " attic", " into", " a", " treasure", " island", ".", " Mark", " pretended", " to", " be", " an", " explorer", ",", " then", " an", " astronaut", "."] 30 | ], 31 | "false_explanations": [ 32 | "mentions of particular characters' names and nothing else", 33 | "names of characters repeated multiple times in the same passage", 34 | "names of female characters", 35 | "instances of the words \"Gwen\", \"Giles\", and \"Alice\"" 36 | ] 37 | }, 38 | "similes": { 39 | "name": "similes", 40 | "explanation": "phrases that are similes and nothing else", 41 | "sentences": [ 42 | ["What", "'s", " so", " fun", " about", " this", " silly", ",", " outrageous", ",", " ingenious", " thriller", " is", " the", " director", "'s", " talent", ".", " Watching", " a", " Brian", " De", " Pal", "ma", " movie", " is", [" like", 9 ], [" watching", 9 ], [" an", 9 ], [" Alfred", 9 ], [" Hitch", 9 ], ["cock", 9 ], [" movie", 9 ], [" after", 9 ], [" drinking", 9 ], [" twelve", 9 ], [" beers", 9 ], "."], 43 | ["Going", " to", " this", " concert", " is", " a", " little", [" like", 9 ], [" chewing", 8 ], [" whale", 9 ], [" bl", 9 ], ["ubber", 9 ], " -", " it", "'s", " an", " acquired", " taste", " that", " takes", " time", " to", " enjoy", ",", " but", " it", "'s", " worth", " it", ",", " even", " if", " it", " does", " take", " ", "3", " hours", " to", " get", " through", "."], 44 | ["The", " ma", "ud", "lin", " focus", " on", " the", " young", " 
woman", "'s", " inf", "irm", "ity", " and", " her", " naive", " dreams", " play", [" like", 9 ], [" the", 9 ], [" worst", 9 ], [" kind", 9 ], [" of", 9 ], [" heart", 9 ], ["-string", 9 ], [" pl", 9 ], ["ucking", 9 ], "."], 45 | ["Don", "\u2019t", " trust", " his", " words", ",", " he", " is", " cunning", [" like", 9 ], [" a", 9 ], [" fox", 9 ], "."], 46 | ["My", " uncle", " is", " as", " blind", [" as", 9 ], [" a", 9 ], [" bat", 9 ], " without", " his", " spect", "acles", "."], 47 | ["The", " maid", " has", " done", " a", " good", " job", ",", " and", " the", " hall", " is", " as", " clean", [" as", 9 ], [" a", 9 ], [" whistle", 10 ] ], 48 | ["My", " grandmother", " may", " seem", " scary", " to", " others", ",", " but", " she", " is", " as", " gentle", [" as", 9 ], [" a", 9 ], [" lamb", 9 ], "."], 49 | ["Once", " upon", " a", " time", ",", " there", " was", " a", " brave", " knight", " named", " Arthur", " who", " lived", " in", " a", " magnificent", " castle", ".", " One", " day", ",", " Arthur", " went", " to", " visit", " another", " kingdom", ".", " While", " he", " was", " there", ",", " he", " went", " on", " many", " adventures", " and", " faced", " many", " per", "ils", ",", " but", " he", " always", " came", " through", ".", " Arthur", " returned", " to", " his", " castle", " and", " lived", " happily", " ever", " after", "."], 50 | ["His", " shoes", " were", " black", ",", " but", " his", " socks", " were", " bright", " yellow", "."] 51 | ], 52 | "false_explanations": [ 53 | "pieces of criticism and nothing else", 54 | "references to uncommon expressions or turns of phrase", 55 | "phrases relating to extreme experiences", 56 | "instances of the words \"like\", \"as\", and \"a\"" 57 | ] 58 | }, 59 | "idioms": { 60 | "name": "idioms", 61 | "explanation": "the substitution of a key word in a common idiom that isn't the usual choice for that idiom", 62 | "sentences": [ 63 | ["Take", " Alice", "'s", " advice", " with", " a", " grain", " of", [" sodium", 10 ], 
" -", " she", " always", " has", " her", " head", " in", " the", " clouds", "."], 64 | ["I", " told", " John", " that", " he", " was", " b", "arking", " up", " the", " wrong", [" bush", 9 ], ",", " that", " I", " didn", "'t", " have", " any", " information", " on", " his", " man", "."], 65 | ["Don", "'t", " worry", " about", " the", " outing", " tomorrow", ",", " it", "'ll", " be", " a", " piece", " of", " cake", "."], 66 | ["Don", "'t", " worry", " about", " the", " outing", " tomorrow", ",", " it", "'ll", " be", " a", " piece", " of", [" con", 9 ], ["fection", 9 ], "."], 67 | ["Take", " Alice", "'s", " advice", " with", " a", " grain", " of", " salt", " -", " she", " always", " has", " her", " head", " in", " the", " clouds", "."], 68 | ["I", " stopped", " the", " interview", " process", " once", " I", " received", " another", " offer", " that", " was", " about", " to", " expire", ".", " A", " bird", " in", " the", " hand", " is", " worth", " two", " in", " the", [" shr", 7 ], ["ub", 9 ], ["bery", 9 ], ",", " after", " all", "."], 69 | ["You", " should", " bite", " the", " bullet", " and", " just", " go", " talk", " to", " the", " professor", ",", " otherwise", " this", " problem", " won", "'t", " get", " better", "."], 70 | ["You", " should", " bite", " the", [" projectile", 9 ], " and", " just", " go", " talk", " to", " the", " professor", ",", " otherwise", " this", " problem", " won", "'t", " get", " better", "."], 71 | ["Stop", " beating", " around", " the", [" hedge", 9 ], " and", " get", " to", " the", " point", "."], 72 | ["Stop", " beating", " around", " the", " bush", " and", " get", " to", " the", " point", "."] 73 | ], 74 | "false_explanations": [ 75 | "mentions of particular words in idioms", 76 | "references to uncommon words", 77 | "mentions of common items", 78 | "instances of the words \"sodium\", \"confection\", and \"projectile\"" 79 | ] 80 | }, 81 | "feet": { 82 | "name": "feet", 83 | "explanation": "words describing motion that uses feet 
(e.g. \"running\", \"walking\"), but not words describing motion that doesn\u2019t use feet (e.g. \"flying\", \"hovering\")", 84 | "sentences": [ 85 | ["She", [" ran", 9 ], " as", " fast", " as", " she", " could", " to", " catch", " the", " bus", ",", " but", " it", " was", " already", " pulling", " away", " from", " the", " stop", "."], 86 | ["He", [" walked", 9 ], " along", " the", " beach", ",", " feeling", " the", " sand", " between", " his", " toes", " and", " the", " breeze", " in", " his", " hair", "."], 87 | ["They", [" danced", 9 ], " to", " the", " rhythm", " of", " the", " music", ",", " spinning", " and", [" stepping", 9 ], " in", " sync", "."], 88 | ["He", [" kicked", 8 ], " the", " ball", " with", " all", " his", " might", ",", " hoping", " to", " score", " a", " goal", "."], 89 | ["She", " ti", ["pto", 7 ], ["ed", 8 ], " into", " the", " kitchen", ",", " trying", " not", " to", " wake", " up", " her", " parents", "."], 90 | ["She", " h", ["opped", 8 ], " on", " one", " foot", ",", " balancing", " a", " book", " on", " her", " head", "."], 91 | ["The", " eagle", " was", " flying", " high", " above", " the", " mountains", ",", " scanning", " the", " ground", " for", " prey", "."], 92 | ["She", " felt", " a", " surge", " of", " joy", " as", " she", " hovered", " over", " the", " water", " on", " her", " jet", "pack", ",", " feeling", " the", " wind", " in", " her", " hair", " and", " the", " spray", " on", " her", " face", "."], 93 | ["The", " helicopter", " landed", " on", " the", " rooftop", ",", " creating", " a", " loud", " noise", " and", " a", " gust", " of", " air", "."], 94 | ["He", " threw", " the", " ball", " with", " all", " his", " might", ",", " hoping", " to", " hit", " the", " target", " and", " win", " the", " prize", "."], 95 | ["The", " rocket", " blasted", " off", " into", " the", " sky", ",", " leaving", " a", " trail", " of", " smoke", " and", " fire", " behind", "."] 96 | ], 97 | "false_explanations": [ 98 | "words describing 
motion", 99 | "verbs and action phrases", 100 | "instances of the words \"ran\", \"walked\", and \"danced\"" 101 | ] 102 | }, 103 | "years": { 104 | "name": "years", 105 | "explanation": "historically inaccurate numerical years in the passage and nothing else", 106 | "sentences": [ 107 | ["The", " French", " Revolution", " began", " in", " A", ".D", ".", " ", ["168", 9 ], ["7", 9 ], " and", " ended", " in", " A", ".D", ".", " ", ["169", 9 ], ["9", 9 ], "."], 108 | ["The", " First", " World", " War", " broke", " out", " in", " A", ".D", ".", " ", ["151", 10 ], ["4", 9 ], " and", " lasted", " for", " four", " years", "."], 109 | ["The", " Declaration", " of", " Independence", " was", " signed", " in", " A", ".D", ".", " ", "177", ["0", 9 ], " by", " the", " representatives", " of", " the", " thirteen", " colonies", "."], 110 | ["The", " Great", " Depression", " began", " in", " A", ".D", ".", " ", ["202", 9 ], ["8", 9 ], " with", " the", " stock", " market", " crash", " and", " lasted", " until", " the", " end", " of", " the", " next", " decade", "."], 111 | ["The", " Berlin", " Wall", " was", " erected", " in", " A", ".D", ".", " ", ["100", 9 ], ["9", 9 ], " to", " divide", " the", " city", " into", " East", " and", " West", "."], 112 | ["The", " Apollo", " ", "11", " mission", " landed", " the", " first", " humans", " on", " the", " moon", " in", " AD", " ", ["195", 9 ], ["0", 9 ], "."], 113 | ["C", "olum", "bus", " sailed", " the", " ocean", " blue", " in", " the", " year", " ", "149", ["0", 10 ], " and", " landed", " in", " the", " new", " world", "."], 114 | ["The", " French", " Revolution", " began", " in", " the", " year", " ", "178", "9", " and", " ended", " in", " ", "179", "9", "."], 115 | ["The", " first", " World", " War", " broke", " out", " in", " A", ".D", ".", " ", "191", "4", " and", " lasted", " for", " four", " years", "."], 116 | ["The", " Declaration", " of", " Independence", " was", " signed", " in", " ", "177", "6", " by", " the", " 
representatives", " of", " the", " thirteen", " colonies", "."], 117 | ["The", " Apollo", " ", "11", " mission", " landed", " the", " first", " humans", " on", " the", " moon", " in", " A", ".D", ".", " ", "196", "9", "."] 118 | ], 119 | "false_explanations": [ 120 | "numerical years", 121 | "dates in American and European history", 122 | "instances of the years \"1009\", \"1687\", and \"1514\"" 123 | ] 124 | }, 125 | "math": { 126 | "name": "math", 127 | "explanation": "incorrect answers in grade school math problems", 128 | "sentences": [ 129 | ["K", "iera", " wants", " to", " make", " sure", " she", " has", " strong", " bones", ",", " so", " she", " drinks", " ", "2", " liters", " of", " milk", " every", " week", ".", " After", " ", "3", " weeks", ",", " how", " many", " liters", " of", " milk", " will", " K", "iera", " drink", "?", " Answer", ":", " After", " ", "3", " weeks", ",", " K", "iera", " will", " drink", " ", "6", " liters", " of", " milk", "."], 130 | ["K", "iera", " wants", " to", " make", " sure", " she", " has", " strong", " bones", ",", " so", " she", " drinks", " ", "2", " liters", " of", " milk", " every", " week", ".", " After", " ", "3", " weeks", ",", " how", " many", " liters", " of", " milk", " will", " K", "iera", " drink", "?", " Answer", ":", " After", " ", "3", " weeks", ",", " K", "iera", " will", " drink", " ", ["4", 10 ], " liters", " of", " milk", "."], 131 | ["A", "riel", " was", " playing", " basketball", ".", " ", "1", " of", " her", " shots", " went", " in", " the", " hoop", ".", " ", "2", " of", " her", " shots", " did", " not", " go", " in", " the", " hoop", ".", " How", " many", " shots", " were", " there", " in", " total", "?", " Answer", ":", " There", " were", " ", "3", " shots", " in", " total", "."], 132 | ["A", "riel", " was", " playing", " basketball", ".", " ", "1", " of", " her", " shots", " went", " in", " the", " hoop", ".", " ", "2", " of", " her", " shots", " did", " not", " go", " in", " the", " hoop", ".", " 
How", " many", " shots", " were", " there", " in", " total", "?", " Answer", ":", " There", " were", " ", ["2", 10 ], " shots", " in", " total", "."], 133 | ["The", " restaurant", " has", " ", "175", " normal", " chairs", " and", " ", "20", " chairs", " for", " babies", ".", " How", " many", " chairs", " does", " the", " restaurant", " have", " in", " total", "?", " Answer", ":", " ", "195"], 134 | ["The", " restaurant", " has", " ", "175", " normal", " chairs", " and", " ", "20", " chairs", " for", " babies", ".", " How", " many", " chairs", " does", " the", " restaurant", " have", " in", " total", "?", " Answer", ":", " ", ["295", 10 ] ], 135 | ["L", "ily", " has", " ", "12", " stickers", " and", " she", " wants", " to", " share", " them", " equally", " with", " her", " ", "3", " friends", ".", " How", " many", " stickers", " will", " each", " person", " get", "?", " Answer", ":", " Each", " person", " will", " get", " ", ["5", 9 ], " stickers", "."], 136 | ["L", "ily", " has", " ", "12", " stickers", " and", " she", " wants", " to", " share", " them", " equally", " with", " her", " ", "3", " friends", ".", " How", " many", " stickers", " will", " each", " person", " get", "?", " Answer", ":", " Each", " person", " will", " get", " ", "4", " stickers", "."] 137 | ], 138 | "false_explanations": [ 139 | "single digit numbers", 140 | "numbers in general", 141 | "answers to grade school math problems", 142 | "instances of the numbers 2, 4, and 5" 143 | ] 144 | }, 145 | "complexity": { 146 | "name": "complexity", 147 | "explanation": "technical phrases related to computational complexity theory", 148 | "sentences": [ 149 | ["The", " time", [" complexity", 10 ], " of", " a", " sorting", " algorithm", " can", " be", " measured", " using", " the", " big", [" O", 9 ], [" notation", 10 ], "."], 150 | ["Some", " problems", ",", " such", " as", " the", " traveling", [" salesman", 9 ], [" problem", 10 ], ",", " are", " known", " to", " be", [" NP", 10 ], ["-hard", 10 ], ",", 
" meaning", " they", " are", " at", " least", " as", " hard", " as", " the", " hardest", " problems", " in", [" NP", 10 ], "."], 151 | ["B", "read", "th", " and", " depth", " first", " search", " run", " in", [" linear", 9 ], [" time", 10 ], " with", " respect", " to", " the", " number", " of", " nodes", " in", " the", " graph", "."], 152 | ["If", " we", " assume", " that", " all", " possible", " permutations", " of", " the", " input", " list", " are", " equally", " likely", ",", " the", " average", " time", " taken", " for", " sorting", " using", " quick", "sort", " is", [" O", 10 ], ["(n", 10 ], [" log", 10 ], [" n", 10 ], [").", 10 ] ], 153 | ["Because", " the", " clique", " problem", " cracks", " the", " ", "3", ["SAT", 10 ], [" problem", 10 ], ",", " we", " can", " say", " that", " clique", " is", [" NP", 10 ], ["-complete", 10 ], "."], 154 | ["A", " problem", " is", " in", " P", ["SPACE", 10 ], " if", " it", " can", " be", [" computed", 10 ], " by", " a", [" Turing", 8 ], [" machine", 10 ], " using", [" polynomial", 10 ], [" space", 10 ], ",", " and", " P", ["SPACE", 10 ], " contains", " many", " problems", " that", " are", " believed", " to", " be", " harder", " than", [" NP", 10 ], ",", " such", " as", " Q", ["BF", 10 ], " and", " T", ["Q", 10 ], ["BF", 10 ], "."], 155 | ["Let", " M", " be", " a", " nond", "etermin", "istic", [" Turing", 9 ], [" machine", 10 ], " that", [" decides", 10 ], " L", " in", [" polynomial", 10 ], [" time", 10 ], "."] 156 | ], 157 | "false_explanations": [ 158 | "phrases relating to algorithms", 159 | "phrases relating to computers", 160 | "mathematical expressions", 161 | "instances of the words \"NP\", \"complexity\", and \"Turing\"" 162 | ] 163 | }, 164 | "an": { 165 | "name": "an", 166 | "explanation": "positions in the sentence where the next word is likely to be \"an\"", 167 | "sentences": [ 168 | ["I", " climbed", " a", " pear", " tree", " and", " picked", " a", " pear", ".", " I", " climbed", " an", " apple", " tree", " 
and", [" picked", 9 ], " an", " apple", "."], 169 | ["Looking", " for", " an", " easy", " way", " to", " protest", " Bush", " foreign", " policy", " week", " after", " week", "?", [" And", 6 ], " an", " easy", " way", " to", " help", " alleviate", " global", " poverty", "?", " Buy", " your", " gasoline", " at", " Cit", "go", " stations", ".", " Looking", [" for", 10 ], " an", " easy", " way", " to", " protest", " Bush", " foreign", " policy", " week", " after", " week", "?", [" And", 10 ], " an", " easy", " way", " to", " help", " alleviate", " global", " poverty", "?"], 170 | ["At", " one", " point", ",", " the", " tro", "oper", " said", " the", " car", " was", " going", " over", " ", "100", [" miles", 7 ], " an", " hour", "."], 171 | ["As", " an", " undergrad", ",", " I", " spent", " much", " of", " my", " time", " reading", " Aristotle", ".", " I", " spent", " many", " restless", " nights", [" as", 10 ], " an", " undergrad", " tossing", " and", " turning", ",", " trying", " to", " make", " sense", " of", " these", " sentences", "."], 172 | ["It", " took", " me", " a", " long", " time", " to", " fall", " asleep", " last", " night", ".", " I", " laid", " in", " bed", [" for", 1 ], [" almost", 8 ], " an", " hour", "."], 173 | ["At", " one", " point", ",", " the", " trooper", " said", " the", " car", " was", " going", " over", " 100", [" miles", 7 ], " per", " hour", "."], 174 | ["Looking", " for", " the", " easy", " way", " to", " protest", " Bush", " foreign", " policy", " week", " after", " week", "?", " And", " the", " easy", " way", " to", " help", " alleviate", " global", " poverty", "?", " Buy", " your", " gasoline", " at", " Cit", "go", " stations", ".", " Looking", " for", " the", " easy", " way", " to", " protest", " Bush", " foreign", " policy", " week", " after", " week", "?", " And", " the", " easy", " way", " to", " help", " alleviate", " global", " poverty", "?"] 175 | ], 176 | "false_explanations": [ 177 | "words which are likely to be followed by 
the word \"a\"", 178 | "language related to something being significant or intense", 179 | "instances of the words \"picked\", \"for\", and \"as\"", 180 | "positions in the sentence where the next word is likely to be \"a\"" 181 | ] 182 | }, 183 | "time": { 184 | "name": "time", 185 | "explanation": "time phrases proceeded by the word \"in\"", 186 | "sentences": [ 187 | ["I", " haven", "'t", " been", " to", " Z", "uma", " Beach", " in", [" quite", 7 ], [" some", 8 ], [" time", 10 ], ".", " I", " had", " forgotten", " how", " much", " I", " love", " it", " there", "."], 188 | ["There", " is", " a", " level", " of", " excitement", " in", " the", " market", ".", " We", "'re", " seeing", " levels", " we", " haven", "'t", " seen", " in", [" some", 8 ], [" time", 9 ], ",", " and", " people", " are", " shouting", "."], 189 | ["Out", "k", "ast", " hit", " the", " stage", " for", " the", " first", " time", " in", [" nearly", 8 ], [" a", 8 ], [" decade", 10 ], " last", " weekend", " at", " the", " Coach", "ella", " Music", " and", " Arts", " festival", "."], 190 | ["The", " Augusta", " River", "H", "awks", " entered", " the", " weekend", " within", " striking", " distance", " of", " eighth", " place", " in", " the", " SP", "HL", " standings", " for", " the", " first", " time", " in", [" months", 10 ], "."], 191 | ["K", "aty", " Perry", " returned", " Thursday", " with", " her", " first", " new", " music", " in", [" months", 10 ], "."], 192 | ["It", " hasn", "'t", " stopped", " raining", " in", [" ten", 9 ], [" weeks", 9 ], "."], 193 | ["In", [" three", 9 ], [" years", 10 ], ",", " I", " expect", " this", " to", " be", " huge", "."], 194 | ["For", " the", " last", " time", " in", " a", [" three", 2 ], [" day", 9 ], [" period", 9 ], ",", " John", " stood", " up", " to", " speak", "."], 195 | ["It", " has", " been", " months", " since", " I", " last", " saw", " you", "."], 196 | ["My", " father", " worked", " at", " the", " factory", " for", " quite", " a", " long", " time", 
"."], 197 | ["It", " hasn", "'t", " stopped", " raining", " for", " ten", " weeks", "."], 198 | ["Three", " years", " from", " now", ",", " I", " expect", " this", " to", " be", " huge", "."], 199 | ["For", " the", " last", " time", " over", " the", " three", " days", " period", ",", " John", " stood", " up", " to", " speak", "."] 200 | ], 201 | "false_explanations": [ 202 | "phrases related to durations of time", 203 | "language related to the first or most recent instance of something", 204 | "instances of the words \"months\", \"some\", and \"time\"" 205 | ] 206 | }, 207 | "stop": { 208 | "name": "stop", 209 | "explanation": "language related to something being stopped, prevented, or halted, but only when negated", 210 | "sentences": [ 211 | ["But", " that", " didn", "'t", [" stop", 10 ], " it", " becoming", " one", " of", " the", " most", " popular", " products", " on", " the", " shelf", "."], 212 | ["Technology", " has", " changed", " quite", " a", " bit", " over", " Vernon", " Cook", "'s", " lifetime", ",", " but", " that", " hasn", "'t", [" stopped", 10 ], " him", " from", " embracing", " the", " advance", "."], 213 | ["The", " Storm", " and", " Sharks", " don", "'t", " have", " the", " same", " stor", "ied", " rivalry", " as", " some", " of", " the", " grand", " finalists", " in", " years", " gone", " by", ",", " but", " that", " hasn", "'t", [" halted", 10 ], " their", " captivating", " contests", " in", " recent", " times", "."], 214 | ["Of", " course", ",", " that", " didn", "'t", [" stop", 10 ], " audiences", " from", " going", " to", " see", " the", " movie", " in", " dro", "ves", ",", " and", " eventually", " the", " ", "201", "4", " Teen", "age", " Mut", "ant", " Ninja", " T", "urtles", " movie", " made", " enough", " money", " to", " warrant", " putting", " a", " sequel", " into", " production", "."], 215 | ["But", " that", " didn", "'t", [" keep", 10 ], " the", " veteran", " centre", " from", " r", "aving", " about", " the", " organization", " 
following", " the", " end", " of", " his", " brief", " tenure", "."], 216 | ["I", " won", "'t", [" stop", 10 ], " until", " I", " get", " there", "."], 217 | ["Michael", " isn", "'t", " going", " to", [" stop", 10 ], " thinking", " about", " how", " to", " solve", " the", " problem", "."], 218 | ["I", " can", "'t", [" stop", 10 ], " eating", " so", " many", " of", " the", " cookies", " they", " put", " out", " at", " ", "3", "pm", "."], 219 | ["But", " that", " stopped", " it", " becoming", " one", " of", " the", " most", " popular", " products", " on", " the", " shelf", "."], 220 | ["Technology", " has", " changed", " quite", " a", " bit", " over", " Vernon", " Cook", "'s", " lifetime", ",", " and", " that", " stopped", " him", " from", " embracing", " the", " advance", "."], 221 | ["I", " have", " to", " stop", " before", " I", " get", " there", "."], 222 | ["Michael", " should", " stop", " thinking", " about", " how", " to", " solve", " the", " problem", "."], 223 | ["I", " must", " stop", " eating", " so", " many", " of", " the", " cookies", " they", " put", " out", " at", " ", "3", "pm", "."] 224 | ], 225 | "false_explanations": [ 226 | "language related to something being stopped, prevented, or halted", 227 | "language related to sports and media", 228 | "instances of the words \"stop\", \"stopped\", and \"halted\"" 229 | ] 230 | }, 231 | "of": { 232 | "name": "of", 233 | "explanation": "positions in the passage following the word \"of\"", 234 | "sentences": [ 235 | ["I", "'ve", " had", " enough", " of", [" your", 10 ], " nonsense", "."], 236 | ["The", " birds", " seemed", " unaware", " of", [" the", 10 ], " cat", " lurking", " nearby", "."], 237 | ["A", " lot", " of", [" people", 10 ], " are", " struggling", " right", " now", "."], 238 | ["The", " fragrance", " of", [" fresh", 10 ], " bread", " filled", " the", " bakery", "."], 239 | ["I", "'m", " really", " proud", " of", [" my", 10 ], " accomplishments", " this", " year", "."], 240 | ["I", "'ve", " had", " 
it", " with", " your", " nonsense", "."], 241 | ["The", " birds", " seemed", " unaware", " that", " the", " cat", " was", " lurking", " nearby", "."], 242 | ["Many", " people", " are", " struggling", " right", " now", "."], 243 | ["The", " bakery", " smelled", " like", " fresh", " bread", "."], 244 | ["I", "'m", " really", " proud", " about", " my", " accomplishments", " this", " year", "."] 245 | ], 246 | "false_explanations": [ 247 | "the word \"of\" and nothing else", 248 | "common words", 249 | "instances of the words \"your\", \"cat\", and \"people\"" 250 | ] 251 | }, 252 | "million": { 253 | "name": "million", 254 | "explanation": "numbers greater than one million and nothing else", 255 | "sentences": [ 256 | ["The", " company", "'s", " revenues", " surpassed", " $", "2", ",", "000", ",", ["560", 10 ], " this", " year", ",", " marking", " a", " significant", " milestone", " in", " their", " growth", "."], 257 | ["The", " population", " of", " Shanghai", " is", " estimated", " to", " be", " around", " ", ["24", 10 ], [",", 10 ], ["000", 10 ], [",", 10 ], ["000", 10 ], ",", " making", " it", " one", " of", " the", " world", "'s", " largest", " cities", "."], 258 | ["The", " auction", " house", " estimated", " the", " painting", "'s", " value", " at", " $", "5", ",", "500", ",", ["001", 10 ], ",", " but", " it", " ultimately", " sold", " for", " nearly", " twice", " that", " amount", "."], 259 | ["In", " ", "201", "8", ",", " over", " ", "2", ",", "300", ",", ["000", 10 ], " people", " attended", " L", "oll", "ap", "al", "oo", "za", " music", " festival", " in", " Chicago", "."], 260 | ["With", " a", " net", " worth", " of", " $", "18", ",", "000", ",", ["000", 10 ], [",", 10 ], ["000", 10 ], ",", " the", " celebrity", " has", " been", " able", " to", " afford", " a", " luxurious", " lifestyle", " for", " her", " and", " her", " family", "."], 261 | ["The", " company", "'s", " revenues", " surpassed", " $", "999", ",", "999", " this", " year", ",", " marking", " 
a", " significant", " milestone", " in", " their", " growth", "."], 262 | ["The", " population", " of", " Shanghai", " is", " estimated", " to", " be", " around", " ", "24", ",", "560", ",", " making", " it", " one", " of", " the", " world", "'s", " largest", " cities", "."], 263 | ["The", " auction", " house", " estimated", " the", " painting", "'s", " value", " at", " $", "450", ",", "001", ",", " but", " it", " ultimately", " sold", " for", " nearly", " twice", " that", " amount", "."], 264 | ["In", " ", "201", "8", ",", " over", " ", "230", "0", " people", " attended", " L", "oll", "ap", "al", "oo", "za", " music", " festival", " in", " Chicago", "."], 265 | ["With", " a", " net", " worth", " of", " $", "990", "000", ",", " the", " celebrity", " has", " been", " able", " to", " afford", " a", " luxurious", " lifestyle", " for", " her", " and", " her", " family", "."] 266 | ], 267 | "false_explanations": [ 268 | "sequences of digits", 269 | "large numbers", 270 | "the sequences of digits \"000\", \"001\", and \"560\"" 271 | ] 272 | }, 273 | "people": { 274 | "name": "people", 275 | "explanation": "numbers if and only if they refer to a number of people", 276 | "sentences": [ 277 | ["I", " went", " to", " the", " party", " and", " was", " surprised", " to", " see", " that", " out", " of", " our", " whole", " group", ",", " there", " were", " only", " ", ["2", 10 ], " people", " there", "."], 278 | ["Our", " company", " is", " growing", " quickly", " and", " we", " have", " now", " hired", " more", " than", " ", ["150", 10 ], " employees", "."], 279 | ["The", " bus", " was", " packed", " with", " people", ":", " at", " least", " ", ["50", 10 ], " total", "."], 280 | ["The", " teacher", " divided", " the", " class", " of", " ", ["30", 10 ], " students", " into", " six", " groups", " of", " ", ["5", 10 ], "."], 281 | ["I", " was", " one", " of", " ", ["12", 10 ], " people", " who", " were", " selected", " for", " the", " jury", "."], 282 | ["I", " have", " quite", " 
a", " few", " books", " that", " I", " need", " to", " return", " to", " the", " library", " -", " almost", " ", "20", " in", " total", "."], 283 | ["It", " was", " a", " long", " hike", ",", " but", " we", " finally", " made", " it", " to", " the", " top", " after", " close", " to", " ", "6", " hours", "."], 284 | ["I", " won", " $", "20", " in", " the", " r", "affle", "."], 285 | ["My", " dad", " has", " been", " collecting", " coins", " for", " years", " and", " now", " has", " over", " ", "400", " of", " them", "."], 286 | ["For", " the", " recipe", " I", " want", " to", " make", ",", " I", " need", " to", " buy", " ", "8", " eggs", "."] 287 | ], 288 | "false_explanations": [ 289 | "numbers", 290 | "numbers under one thousand", 291 | "instances of the numbers 2, 50, and 150" 292 | ] 293 | }, 294 | "python": { 295 | "name": "python", 296 | "explanation": "the character \"<\" if it occurs in the context of programming but not HTML or math", 297 | "sentences": [ 298 | ["assert", " cumulative", "_num", "_branch", "es", [" <", 10 ], " num", "_branch", "es", "_to", "_check", ",", " \"", "cum", "ulative", " number", " of", " branches", " checked", " should", " not", " exceed", " the", " target", "\"\n"], 299 | ["for", " i", " in", " range", "(", "100", "):\n", " ", " if", " i", [" <", 10 ], " ", "50", ":\n", " ", " print", "(i", ")\n"], 300 | ["print", "(\"", "x", " is", " smaller", " than", " y", "\"", " if", " x", [" <", 10 ], " y", " else", " \"", "x", " is", " not", " smaller", " than", " y", "\")\n"], 301 | ["age", " =", " int", "(input", "(\"", "Enter", " your", " age", ":", " \"))\n", "if", " age", [" <", 10 ], " ", "18", ":\n", " ", " print", "(\"", "You", "'re", " not", " old", " enough", " to", " vote", ".\")\n", "else", ":\n", " ", " print", "(\"", "You", "'re", " old", " enough", " to", " vote", ".\")\n"], 302 | ["", " <", "head", ">", " <", "title", ">Hello", " World", ""], 303 | ["<", "label", " for", "=\"", "username", "\">", "Username", ":<", "br", " 
/><", "input", " type", "=\"", "text", "\"", " name", "=\"", "username", "\"", " /><", "br", " /><", "label", " for", "=\"", "password", "\">", "Password", ":<", "br", " /><", "input", " type", "=\"", "password", "\"", " name", "=\"", "password", "\"", " /><", "br", " /><", "input", " type", "=\"", "submit", "\"", " value", "=\"", "Login", "\"", " />"], 304 | ["", " <", "li", ">", "Red", "", " <", "li", ">", "Blue", "", " <", "li", ">", "Green", "", " "], 305 | ["", " <", "th", ">", "First", " Name", "", " <", "th", ">Last", " Name", "", " "], 306 | ["for", " (", "int", " i", " =", " 0", ";", " i", [" <", 10 ], " 10", ";", " ++", "i", ")", " {", "\n", " ", " //", " do", " something", " ten", " times", "\n", "}"], 307 | ["alph", "abet", "ically", "Before", " ::", " String", " ->", " String", " ->", " B", "ool", "\n", "alph", "abet", "ically", "Before", " str", "1", " str", "2", " =", " str", "1", [" <", 10 ], " str", "2"], 308 | ["By", " definition", ",", " the", " spectral", " radius", " of", " A", " is", " the", " maximum", " of", " the", " absolute", " values", " of", " its", " e", "igen", "values", ".", " We", " know", " that", " each", " e", "igen", "value", " of", " A", " has", " an", " absolute", " value", " strictly", " less", " than", " 1", ".", " Thus", ",", " the", " maximum", " will", " also", " be", " strictly", " less", " than", " 1", ",", " so", " the", " spectral", " radius", " of", " A", " is", " <", " 1", "."], 309 | ["By", " definition", ",", " the", " Fro", "ben", "ius", " norm", " is", " the", " square", " root", " of", " the", " sum", " of", " the", " squared", " matrix", " entries", ".", " We", " know", " that", " every", " entry", " in", " matrix", " A", " is", " strictly", " smaller", " than", " the", " corresponding", " entry", " in", " matrix", " B", ".", " So", ",", " the", " sum", " of", " the", " squared", " matrix", " entries", " in", " A", " will", " be", " strictly", " smaller", " than", " the", " sum", " of", " the", " squared", " 
matrix", " entries", " in", " B", ".", " Taking", " square", " roots", " preserves", " this", " order", ",", " hence", " ||", "A", "||", "_", "F", " <", " ||", "B", "||", "_", "F", "."] 310 | ], 311 | "false_explanations": [ 312 | "the character \"<\"", 313 | "less than symbols", 314 | "HTML tags" 315 | ] 316 | }, 317 | "html": { 318 | "name": "html", 319 | "explanation": "the character \"<\" if it occurs in the context of HTML but not programming or math", 320 | "sentences": [ 321 | ["assert", " cumulative", "_num", "_branch", "es", " <", " num", "_branch", "es", "_to", "_check", ",", " \"", "cum", "ulative", " number", " of", " branches", " checked", " should", " not", " exceed", " the", " target", "\"\n"], 322 | ["for", " i", " in", " range", "(", "100", "):\n", " ", " if", " i", " <", " ", "50", ":\n", " ", " print", "(i", ")\n"], 323 | ["print", "(\"", "x", " is", " smaller", " than", " y", "\"", " if", " x", " <", " y", " else", " \"", "x", " is", " not", " smaller", " than", " y", "\")\n"], 324 | ["age", " =", " int", "(input", "(\"", "Enter", " your", " age", ":", " \"))\n", "if", " age", " <", " ", "18", ":\n", " ", " print", "(\"", "You", "'re", " not", " old", " enough", " to", " vote", ".\")\n", "else", ":\n", " ", " print", "(\"", "You", "'re", " old", " enough", " to", " vote", ".\")\n"], 325 | [["", [" <", 10 ], "head", ">", [" <", 10 ], "title", ">Hello", " World", [""], 326 | [["<", 10 ], "label", " for", "=\"", "username", "\">", "Username", [":<", 10 ], "br", [" /><", 10 ], "input", " type", "=\"", "text", "\"", " name", "=\"", "username", "\"", [" /><", 10 ], "br", [" /><", 10 ], "label", " for", "=\"", "password", "\">", "Password", [":<", 10 ], "br", [" /><", 10 ], "input", " type", "=\"", "password", "\"", " name", "=\"", "password", "\"", [" /><", 10 ], "br", [" /><", 10 ], "input", " type", "=\"", "submit", "\"", " value", "=\"", "Login", "\"", [" />"], 327 | [["", [" <", 10 ], "li", ">", "Red", ["", [" <", 10 ], "li", ">", "Blue", ["", [" 
<", 10 ], "li", ">", "Green", ["", [" "], 328 | [["", [" <", 10 ], "th", ">", "First", " Name", ["", [" <", 10 ], "th", ">Last", " Name", ["", [" "], 329 | ["for", " (", "int", " i", " =", " 0", ";", " i", " <", " 10", ";", " ++", "i", ")", " {", "\n", " ", " //", " do", " something", " ten", " times", "\n", "}"], 330 | ["alph", "abet", "ically", "Before", " ::", " String", " ->", " String", " ->", " B", "ool", "\n", "alph", "abet", "ically", "Before", " str", "1", " str", "2", " =", " str", "1", " <", " str", "2"], 331 | ["By", " definition", ",", " the", " spectral", " radius", " of", " A", " is", " the", " maximum", " of", " the", " absolute", " values", " of", " its", " e", "igen", "values", ".", " We", " know", " that", " each", " e", "igen", "value", " of", " A", " has", " an", " absolute", " value", " strictly", " less", " than", " 1", ".", " Thus", ",", " the", " maximum", " will", " also", " be", " strictly", " less", " than", " 1", ",", " so", " the", " spectral", " radius", " of", " A", " is", " <", " 1", "."], 332 | ["By", " definition", ",", " the", " Fro", "ben", "ius", " norm", " is", " the", " square", " root", " of", " the", " sum", " of", " the", " squared", " matrix", " entries", ".", " We", " know", " that", " every", " entry", " in", " matrix", " A", " is", " strictly", " smaller", " than", " the", " corresponding", " entry", " in", " matrix", " B", ".", " So", ",", " the", " sum", " of", " the", " squared", " matrix", " entries", " in", " A", " will", " be", " strictly", " smaller", " than", " the", " sum", " of", " the", " squared", " matrix", " entries", " in", " B", ".", " Taking", " square", " roots", " preserves", " this", " order", ",", " hence", " ||", "A", "||", "_", "F", " <", " ||", "B", "||", "_", "F", "."] 333 | ], 334 | "false_explanations": [ 335 | "the character \"<\"", 336 | "less than symbols", 337 | "HTML tags" 338 | ] 339 | }, 340 | "int": { 341 | "name": "int", 342 | "explanation": "integer variables in C programs", 343 | 
"sentences": [ 344 | ["int", [" a", 10 ], " =", " ", "20", ";\n", "float", " b", " =", " ", "2", ".", "2", "\n", "str", " c", " =", " \"", "a", "\";\n", "printf", "(\"%", "d", "\",", ["a", 10 ], ")\n", "printf", "(\"%", "f", "\",", ["a", 10 ], " +", " b", ")\n", "printf", "(\"%", "s", "\",", "c", ")\n"], 345 | ["int", [" a", 10 ], ",", [" b", 10 ], ",", [" product", 10 ], ";\n", "printf", "(\"", "Enter", " two", " integers", ":", " \");\n", "scanf", "(\"%", "d", " %", "d", "\",", " &", ["a", 10 ], ",", " &", ["b", 10 ], ");\n", "product", " =", [" a", 10 ], " *", [" b", 10 ], ";\n", "printf", "(\"", "The", " product", " is", ":", " %", "d", "\n", "\",", [" product", 10 ], ");\n"], 346 | ["int", " main", "()", " {\n", " ", " char", " c", " =", " '", "c", "';\n", " ", " return", " c", " ==", " '", "c", "';\n", "}\n"], 347 | ["if", " (", "n", " //", " ", "2", " >", " ", "10", ")", " {\n", " ", " printf", "(\"", "The", " number", " is", " big", ".\n", "\");\n", "}", " else", " {\n", " ", " printf", "(\"", "The", " number", " is", " small", ".\n", "\");\n", "}\n"], 348 | ["#include", " <", "stdio", ".h", ">\n", "#define", " PI", " ", "3", ".", "141", "592", "653", "589", "793", "238", "46", "\n\n", "int", " main", "()", " {\n", " ", " float", " radius", ",", " area", ";\n", " ", " printf", "(\"", "Enter", " the", " radius", " of", " the", " circle", ":", " \");\n", " ", " scanf", "(\"%", "f", "\",", " &", "radius", ");\n", " ", " area", " =", " PI", " *", " radius", " *", " radius", ";\n", " ", " printf", "(\"", "The", " area", " of", " the", " circle", " is", ":", " %", "f", "\n", "\",", " area", ");\n", " ", " return", " ", "0", ";\n", "}\n"], 349 | ["char", " str", "[", "100", "],", " rev", "[", "100", "];\n", "int", [" len", 10 ], ",", [" i", 10 ], ",", [" j", 10 ], ";\n", "printf", "(\"", "Enter", " a", " string", ":", " \");\n", "gets", "(str", ");\n", "len", " =", " strlen", "(str", ");\n", ["j", 10 ], " =", " len", " -", " ", "1", ";\n", "for", " (", ["i", 10 ], 
" =", " ", "0", ";", [" i", 10 ], " <", [" len", 10 ], ";", [" i", 10 ], "++)", " {\n", " ", " rev", ["[i", 10 ], "]", " =", " str", ["[j", 10 ], "];\n", " ", [" j", 10 ], "--;\n", "}\n", "rev", ["[i", 10 ], "]", " =", " '", "\u0000", "';\n", "printf", "(\"", "The", " reversed", " string", " is", ":", " %", "s", "\n", "\",", " rev", ");\n"] 350 | ], 351 | "false_explanations": [ 352 | "variables in C programs", 353 | "numerical variables in C programs", 354 | "instances of the letters \"a\", \"i\", and \"j\"" 355 | ] 356 | }, 357 | "death": { 358 | "name": "death", 359 | "explanation": "phrases in a passage that indicate a character has died", 360 | "sentences": [ 361 | ["The", " two", " men", " fought", " fiercely", ",", " swords", " cl", "ashing", " in", " a", " deadly", " duel", ".", " But", " in", " the", " end", ",", " only", " one", " could", " emerge", " victorious", ".", " With", " a", " final", " powerful", " swing", ",", " John", " ran", " his", " blade", " through", " the", " other", " man", "'s", " chest", ",", [" ending", 8 ], [" his", 8 ], [" life", 10 ], "."], 362 | ["Ad", "eline", " knew", " she", " wouldn", "'t", " make", " it", " out", " of", " the", " burning", " building", ".", " The", " flames", " were", " too", " intense", ",", " and", " the", " smoke", " was", " choking", " her", ".", " As", " the", " flames", " engulf", "ed", " her", ",", " she", " thought", " of", " her", " loved", " ones", " one", " last", " time", " before", " succ", ["umbing", 9 ], [" to", 7 ], [" the", 7 ], [" fire", 10 ], "."], 363 | ["The", " gunshot", " rang", " out", " in", " the", " empty", " alley", ",", " and", " Jim", " fell", " to", " the", " ground", ",", " blood", " pooling", " around", " him", ".", " He", " tried", " to", " hold", " on", ",", " to", " fight", " for", " life", ",", " but", " the", " wound", " was", " too", " severe", ".", " As", " his", " vision", [" faded", 7 ], [" to", 7 ], [" black", 10 ], ",", " he", " wondered", " who", " would", " 
take", " care", " of", " his", " little", " girl", " now", "."], 364 | ["The", " zombie", " outbreak", " had", " taken", " everyone", " by", " surprise", ",", " and", " no", " one", " was", " more", " un", "prepared", " than", " Barry", ".", " He", " had", " managed", " to", " survive", " the", " initial", " onslaught", ",", " but", " he", "'d", " been", " bitten", " in", " the", " process", ".", " Now", ",", " as", " the", " virus", [" took", 7 ], [" over", 8 ], [" his", 8 ], [" body", 10 ], ",", " he", " knew", " it", " wouldn", "'t", " be", " long", " before", " he", " was", " one", " of", " them", "."], 365 | ["The", " doctor", " told", " T", "essa", " she", " only", " had", " a", " few", " months", " left", " to", " live", ".", " But", " she", " was", " determined", " to", " make", " the", " most", " of", " it", ".", " She", " traveled", ",", " spent", " time", " with", " loved", " ones", ",", " and", " tried", " new", " things", ".", " But", " when", " the", " end", " came", ",", " it", " was", " still", " a", " shock", ".", " T", "essa", [" passed", 10 ], [" away", 10 ], [" peacefully", 10 ], " in", " her", " sleep", ",", " surrounded", " by", " family"], 366 | ["The", " two", " men", " fought", " fiercely", ",", " swords", " cl", "ashing", " in", " a", " deadly", " duel", ".", " But", " in", " the", " end", ",", " only", " one", " could", " emerge", " victorious", ".", " With", " a", " final", " powerful", " swing", ",", " John", " dis", "armed", " the", " other", " man", ",", " ending", " the", " fight", "."], 367 | ["Ad", "eline", " knew", " she", " wouldn", "'t", " make", " it", " out", " of", " the", " burning", " building", ".", " The", " flames", " were", " too", " intense", ",", " and", " the", " smoke", " was", " choking", " her", ".", " But", " just", " as", " she", " was", " about", " to", " lose", " consciousness", ",", " a", " firefighter", " burst", " through", " the", " door", " and", " pulled", " her", " to", " safety", "."], 368 | ["The", " 
gunshot", " rang", " out", " in", " the", " empty", " alley", ",", " and", " Jim", " fell", " to", " the", " ground", ",", " blood", " pooling", " around", " him", ".", " He", " tried", " to", " hold", " on", ",", " to", " fight", " for", " life", ",", " and", " by", " some", " miracle", ",", " he", " managed", " to", " stay", " alive", " until", " help", " arrived", "."], 369 | ["The", " zombie", " outbreak", " had", " taken", " everyone", " by", " surprise", ",", " and", " no", " one", " was", " more", " un", "prepared", " than", " Barry", ".", " He", " had", " managed", " to", " survive", " the", " initial", " onslaught", ",", " but", " he", "'d", " been", " bitten", " in", " the", " process", ".", " However", ",", " the", " bite", " didn", "'t", " seem", " to", " affect", " him", ",", " and", " he", " went", " on", " to", " help", " others", " survive", " the", " outbreak", "."], 370 | ["The", " doctor", " told", " T", "essa", " she", " only", " had", " a", " few", " months", " left", " to", " live", ".", " But", " she", " was", " determined", " to", " make", " the", " most", " of", " it", ".", " She", " traveled", ",", " spent", " time", " with", " loved", " ones", ",", " and", " tried", " new", " things", ".", " And", " against", " all", " odds", ",", " her", " health", " improved", ",", " and", " she", " went", " on", " to", " live", " a", " long", " and", " fulfilling", " life", "."] 371 | ], 372 | "false_explanations": [ 373 | "phrases indicating violence", 374 | "momentous events in characters' lives", 375 | "instances of the words \"life\", \"passed\", and \"black\"" 376 | ] 377 | }, 378 | "motivation": { 379 | "name": "motivation", 380 | "explanation": "phrases indicating a character's motivation", 381 | "sentences": [ 382 | ["Mary", " always", [" fantas", 8 ], ["ized", 7 ], [" about", 8 ], [" running", 8 ], [" a", 8 ], [" marathon", 8 ], ",", " so", " she", " trained", " for", " a", " whole", " summer", " and", " signed", " up", " for", " one", "."], 
383 | ["As", " a", " child", ",", " James", " had", " been", [" fascinated", 8 ], [" by", 8 ], [" the", 8 ], [" stars", 8 ], ".", " Now", ",", " as", " an", " ast", "roph", "ys", "ic", "ist", ",", " he", " was", [" driven", 8 ], [" by", 8 ], [" the", 8 ], [" desire", 8 ], [" to", 8 ], [" understand", 8 ], [" the", 8 ], [" universe", 8 ], "."], 384 | ["It", " wasn", "'t", " greed", " that", " made", " Tom", " work", " overtime", " at", " the", " factory", ".", " He", " just", [" wanted", 8 ], [" to", 8 ], [" be", 8 ], [" able", 8 ], [" to", 8 ], [" give", 8 ], [" his", 8 ], [" kids", 8 ], [" a", 8 ], [" better", 8 ], [" life", 8 ], [" than", 8 ], [" he", 8 ], ["'d", 8 ], [" had", 8 ], "."], 385 | ["Mad", "ison", "'s", " friends", " had", " always", " told", " her", " she", " was", " a", " great", " singer", ",", " but", " she", " never", " really", " believed", " them", ".", " It", " wasn", "'t", " until", " her", " grandmother", " passed", " away", " and", " left", " her", " a", " note", " telling", " her", " to", " pursue", " her", " dreams", " that", " she", " finally", [" found", 9 ], [" the", 9 ], [" courage", 9 ], [" to", 9 ], [" audition", 9 ], " for", " the", " local", " talent", " show", "."], 386 | ["Rachel", " didn", "'t", " like", " spending", " her", " weekends", " volunteering", " at", " the", " soup", " kitchen", ".", " But", " she", " was", [" determined", 8 ], [" to", 8 ], [" prove", 8 ], [" to", 9 ], [" her", 9 ], [" parents", 9 ], [" that", 10 ], [" she", 10 ], [" wasn", 10 ], ["'t", 10 ], [" the", 10 ], [" selfish", 10 ], [",", 10 ], [" spoiled", 10 ], [" teenager", 10 ], [" they", 10 ], [" thought", 10 ], [" she", 10 ], [" was", 10 ] ], 387 | ["Mary", " trained", " for", " a", " whole", " summer", " and", " signed", " up", " for", " a", " marathon", "."], 388 | ["As", " an", " ast", "roph", "ys", "ic", "ist", ",", " James", " studied", " the", " universe", "."], 389 | ["Tom", " worked", " overtime", " at", " the", " factory", "."], 390 | ["Mad", 
@dataclass(frozen=True)
class Puzzle:
    """
    A ground-truth explanation bundled with sentences (held as ActivationRecords) whose
    activations follow that explanation, plus a collection of false explanations.
    """

    name: str
    explanation: str
    activation_records: list[ActivationRecord]
    false_explanations: list[str]


def convert_puzzle_to_tokenized_sentences(puzzle: Puzzle) -> list[list[str]]:
    """Collect the token list of every activation record in the puzzle, in order."""
    tokenized_sentences = []
    for activation_record in puzzle.activation_records:
        tokenized_sentences.append(activation_record.tokens)
    return tokenized_sentences
def convert_puzzle_dict_to_puzzle(puzzle_dict: dict) -> Puzzle:
    """Converts a json dictionary representation of a puzzle to the Puzzle class.

    Each entry of ``puzzle_dict["sentences"]`` is a list whose items are either a bare token
    string or a ``[token, activation]`` list. A bare string gets an activation of 0.0, which
    keeps the data file readable and avoids repeating zeros. Activations are converted to
    float.

    Raises:
        TypeError: if any token is not a string.
    """
    puzzle_activation_records = []
    for sentence in puzzle_dict["sentences"]:
        tokens: list[str] = []
        activations: list[float] = []
        for entry in sentence:
            if isinstance(entry, list):
                token, activation = entry[0], float(entry[1])
            else:
                token, activation = entry, 0.0
            # Explicit raise rather than assert so the check survives `python -O`.
            if not isinstance(token, str):
                raise TypeError("All tokens must be strings")
            tokens.append(token)
            activations.append(activation)

        puzzle_activation_records.append(ActivationRecord(tokens=tokens, activations=activations))

    return Puzzle(
        name=puzzle_dict["name"],
        explanation=puzzle_dict["explanation"],
        activation_records=puzzle_activation_records,
        false_explanations=puzzle_dict["false_explanations"],
    )


# Load every puzzle from the puzzles.json file that sits next to this module.
PUZZLES_BY_NAME: dict[str, Puzzle] = dict()
script_dir = os.path.dirname(os.path.abspath(__file__))
# JSON is UTF-8; do not depend on the platform's default locale encoding.
with open(os.path.join(script_dir, "puzzles.json"), "r", encoding="utf-8") as f:
    puzzle_dicts = json.load(f)
for name in puzzle_dicts:
    PUZZLES_BY_NAME[name] = convert_puzzle_dict_to_puzzle(puzzle_dicts[name])
def flatten_list(list_of_lists: Sequence[Sequence[Any]]) -> list[Any]:
    """Concatenate the items of every inner sequence into one flat list."""
    flattened: list[Any] = []
    for sublist in list_of_lists:
        flattened.extend(sublist)
    return flattened


def correlation_score(
    real_activations: Sequence[float] | np.ndarray,
    predicted_activations: Sequence[float] | np.ndarray,
) -> float:
    """Return the off-diagonal entry of np.corrcoef for the two activation sequences."""
    correlation_matrix = np.corrcoef(real_activations, predicted_activations)
    return correlation_matrix[0, 1]


def score_from_simulation(
    real_activations: ActivationRecord,
    simulation: SequenceSimulation,
    score_function: Callable[[Sequence[float] | np.ndarray, Sequence[float] | np.ndarray], float],
) -> float:
    """Apply score_function to the record's activations and the simulation's expected ones."""
    true_values = real_activations.activations
    predicted_values = simulation.expected_activations
    return score_function(true_values, predicted_values)


def rsquared_score_from_sequences(
    real_activations: Sequence[float] | np.ndarray,
    predicted_activations: Sequence[float] | np.ndarray,
) -> float:
    """Return 1 - mean((real - predicted)^2) / mean(real^2) as a Python float."""
    real = np.array(real_activations)
    predicted = np.array(predicted_activations)
    mean_squared_error = np.mean(np.square(real - predicted))
    mean_squared_activation = np.mean(np.square(real))
    return float(1 - mean_squared_error / mean_squared_activation)


def absolute_dev_explained_score_from_sequences(
    real_activations: Sequence[float] | np.ndarray,
    predicted_activations: Sequence[float] | np.ndarray,
) -> float:
    """Return 1 - mean(|real - predicted|) / mean(|real|) as a Python float."""
    real = np.array(real_activations)
    predicted = np.array(predicted_activations)
    mean_absolute_error = np.mean(np.abs(real - predicted))
    mean_absolute_activation = np.mean(np.abs(real))
    return float(1 - mean_absolute_error / mean_absolute_activation)
async def make_explanation_simulator(
    explanation: str,
    calibration_activation_records: Sequence[ActivationRecord],
    model_name: str,
    calibrated_simulator_class: type[CalibratedNeuronSimulator] = LinearCalibratedNeuronSimulator,
) -> CalibratedNeuronSimulator:
    """
    Make a simulator that uses an explanation to predict activations and calibrates it on the given
    activation records.

    An ExplanationNeuronSimulator is built from model_name and explanation, wrapped in
    calibrated_simulator_class, and calibrated before being returned.
    """
    simulator = ExplanationNeuronSimulator(model_name, explanation)
    calibrated_simulator = calibrated_simulator_class(simulator)
    await calibrated_simulator.calibrate(calibration_activation_records)
    return calibrated_simulator


async def _simulate_and_score_sequence(
    simulator: NeuronSimulator, activations: ActivationRecord
) -> ScoredSequenceSimulation:
    """Score an explanation of a neuron by how well it predicts activations on a sentence.

    The simulator is run on the record's tokens and the result is scored against the record's
    true activations with the r-squared, absolute-deviation-explained and correlation scorers.
    """
    simulation = await simulator.simulate(activations.tokens)
    logging.debug(simulation)
    rsquared_score = score_from_simulation(activations, simulation, rsquared_score_from_sequences)
    absolute_dev_explained_score = score_from_simulation(
        activations, simulation, absolute_dev_explained_score_from_sequences
    )
    ev_correlation_score = score_from_simulation(activations, simulation, correlation_score)
    return ScoredSequenceSimulation(
        simulation=simulation,
        true_activations=activations.activations,
        ev_correlation_score=ev_correlation_score,
        rsquared_score=rsquared_score,
        absolute_dev_explained_score=absolute_dev_explained_score,
    )


def aggregate_scored_sequence_simulations(
    scored_sequence_simulations: list[ScoredSequenceSimulation],
) -> ScoredSimulation:
    """
    Aggregate a list of scored sequence simulations. The logic for doing this is non-trivial for EV
    scores, since we want to calculate the correlation over all activations from all sequences at
    once rather than simply averaging per-sequence correlations.

    When there are no activations at all, the correlation score is None and the other two scores
    are NaN. These are the values the numpy computation would produce, returned without
    evaluating means of empty arrays.
    """
    all_true_activations: list[float] = []
    all_expected_values: list[float] = []
    for scored_sequence_simulation in scored_sequence_simulations:
        all_true_activations.extend(scored_sequence_simulation.true_activations or [])
        all_expected_values.extend(scored_sequence_simulation.simulation.expected_activations)

    if not all_true_activations and not all_expected_values:
        # Nothing to score. np.mean of an empty array yields NaN plus RuntimeWarnings (or an
        # error when warnings are escalated), so return the equivalent result directly.
        return ScoredSimulation(
            scored_sequence_simulations=scored_sequence_simulations,
            ev_correlation_score=None,
            rsquared_score=float("nan"),
            absolute_dev_explained_score=float("nan"),
        )

    ev_correlation_score = (
        correlation_score(all_true_activations, all_expected_values)
        if len(all_true_activations) > 0
        else None
    )
    rsquared_score = rsquared_score_from_sequences(all_true_activations, all_expected_values)
    absolute_dev_explained_score = absolute_dev_explained_score_from_sequences(
        all_true_activations, all_expected_values
    )

    return ScoredSimulation(
        scored_sequence_simulations=scored_sequence_simulations,
        ev_correlation_score=ev_correlation_score,
        rsquared_score=rsquared_score,
        absolute_dev_explained_score=absolute_dev_explained_score,
    )
async def simulate_and_score(
    simulator: NeuronSimulator,
    activation_records: Sequence[ActivationRecord],
) -> ScoredSimulation:
    """
    Simulate and score every activation record concurrently, then aggregate the per-sequence
    results into a single ScoredSimulation.
    """
    pending_scores = [
        _simulate_and_score_sequence(simulator, record) for record in activation_records
    ]
    scored_sequence_simulations = await asyncio.gather(*pending_scores)
    return aggregate_scored_sequence_simulations(scored_sequence_simulations)


async def make_simulator_and_score(
    make_simulator: Coroutine[None, None, NeuronSimulator],
    activation_records: Sequence[ActivationRecord],
) -> ScoredSimulation:
    """Await the simulator-producing coroutine, then score the activation records with it."""
    return await simulate_and_score(await make_simulator, activation_records)
logger = logging.getLogger(__name__)

# Prompts work with normalized activations: any range of positive activations is mapped onto the
# integers 0 through 10.
MAX_NORMALIZED_ACTIVATION = 10
VALID_ACTIVATION_TOKENS_ORDERED = [str(i) for i in range(MAX_NORMALIZED_ACTIVATION + 1)]
VALID_ACTIVATION_TOKENS = set(VALID_ACTIVATION_TOKENS_ORDERED)

# Edge Case #3: the chat-based simulator is confused by the end token, so it is swapped for a
# "not end token" stand-in.
END_OF_TEXT_TOKEN = "<|endoftext|>"
END_OF_TEXT_TOKEN_REPLACEMENT = "<|not_endoftext|>"


class SimulationType(str, Enum):
    """How to simulate neuron activations. Values correspond to subclasses of NeuronSimulator."""

    ALL_AT_ONCE = "all_at_once"
    """
    Use a single prompt with tokens; calculate EVs using logprobs.

    Implemented by ExplanationNeuronSimulator.
    """

    ONE_AT_A_TIME = "one_at_a_time"
    """
    Use a separate prompt for each token being simulated; calculate EVs using logprobs.

    Implemented by ExplanationTokenByTokenSimulator.
    """

    @classmethod
    def from_string(cls, s: str) -> SimulationType:
        """Return the member whose value equals s; raise ValueError if there is none."""
        matching = next((member for member in cls if member.value == s), None)
        if matching is None:
            raise ValueError(f"Invalid simulation type: {s}")
        return matching
def compute_expected_value(
    norm_probabilities_by_distribution_value: OrderedDict[int, float]
) -> float:
    """
    Return the expected value of a distribution given as a map from distribution values
    (integers on the range [0, 10]) to normalized probabilities.
    """
    distribution_values = np.array(list(norm_probabilities_by_distribution_value.keys()))
    probabilities = np.array(list(norm_probabilities_by_distribution_value.values()))
    return np.dot(distribution_values, probabilities)


def parse_top_logprobs(top_logprobs: dict[str, float]) -> OrderedDict[int, float]:
    """
    Turn a map from tokens to logprobs into a map from distribution values (integers on the
    range [0, 10]) to unnormalized probabilities (they may not sum to 1). Tokens that are not
    valid activation tokens are dropped.
    """
    return OrderedDict(
        (int(token), np.exp(logprob))
        for token, logprob in top_logprobs.items()
        if token in VALID_ACTIVATION_TOKENS
    )


def compute_predicted_activation_stats_for_token(
    top_logprobs: dict[str, float],
) -> tuple[OrderedDict[int, float], float]:
    """
    Return the normalized probability of each distribution value found in top_logprobs together
    with the expected value of that normalized distribution.
    """
    unnormalized = parse_top_logprobs(top_logprobs)
    total_probability = sum(unnormalized.values())
    normalized = OrderedDict()
    for distribution_value, probability in unnormalized.items():
        normalized[distribution_value] = probability / total_probability
    return normalized, compute_expected_value(normalized)
# Adapted from tether/tether/core/encoder.py.
def convert_to_byte_array(s: str) -> bytearray:
    """Decode a "bytes:"-prefixed token string into the bytes it spells out.

    After the prefix, a backslash must start a four-character hex escape (backslash, "x", two
    hex digits) standing for one byte. Any other character contributes its code point as a
    single byte.

    Raises:
        ValueError: if the prefix is missing, a hex escape is malformed or truncated, or a
            plain character's code point does not fit in one byte.
    """
    prefix = "bytes:"
    if not s.startswith(prefix):
        raise ValueError(f"Expected a token starting with {prefix!r}, got {s!r}")
    byte_array = bytearray()
    # Walk the string by index instead of re-slicing it on every step.
    position = len(prefix)
    while position < len(s):
        if s[position] == "\\":
            # Hex encoding. Check the length first so a trailing backslash is reported as
            # malformed input rather than as an IndexError.
            escape = s[position : position + 4]
            if len(escape) < 4 or escape[1] != "x":
                raise ValueError(f"Malformed hex escape {escape!r} in {s!r}")
            byte_array.append(int(escape[2:4], 16))
            position += 4
        else:
            # Regular ascii encoding.
            byte_array.append(ord(s[position]))
            position += 1
    return byte_array


def handle_byte_encoding(
    response_tokens: Sequence[str], merged_response_index: int
) -> tuple[str, int]:
    """
    Handle the case where the current token is a sequence of bytes. This may involve merging
    multiple response tokens into a single token.

    Returns the (possibly merged and decoded) token and the index of the earliest response
    token that was consumed. A token without the "bytes:" prefix is returned unchanged.

    Raises:
        ValueError: if the bytes still cannot be decoded as utf-8 after merging every token
            back to index 0, or if a token that has to be merged is not a "bytes:" token.
    """
    response_token = response_tokens[merged_response_index]
    if not response_token.startswith("bytes:"):
        return response_token, merged_response_index

    byte_array = bytearray()
    while True:
        byte_array = convert_to_byte_array(response_token) + byte_array
        try:
            # If we can decode the byte array as utf-8, then we're done.
            return byte_array.decode("utf-8"), merged_response_index
        except UnicodeDecodeError:
            # If not, then we need to merge the previous response token into the byte array.
            if merged_response_index == 0:
                # Stepping to -1 would silently wrap around to the end of the sequence.
                raise ValueError(
                    "Ran out of response tokens while merging a byte-encoded token"
                ) from None
            merged_response_index -= 1
            response_token = response_tokens[merged_response_index]


def was_token_split(
    current_token: str, response_tokens: Sequence[str], start_index: int
) -> bool:
    """
    Return whether current_token (a token from the subject model) was split into multiple tokens by
    the simulator model (as represented by the tokens in response_tokens). start_index is the index
    in response_tokens at which to begin looking backward to form a complete token. It is usually
    the first token *before* the delimiter that separates the token from the normalized activation,
    barring some unusual cases.

    This mainly happens if the subject model uses a different tokenizer than the simulator model.
    But it can also happen in cases where Unicode characters are split. This function handles both
    cases.
    """
    merged_response_tokens = ""
    merged_response_index = start_index
    while len(merged_response_tokens) < len(current_token):
        if start_index >= 0 and merged_response_index < 0:
            # Every token back to index 0 has been merged. Stop rather than let the negative
            # index wrap around; the assertion below reports the mismatch.
            break
        response_token, merged_response_index = handle_byte_encoding(
            response_tokens, merged_response_index
        )
        merged_response_tokens = response_token + merged_response_tokens
        merged_response_index -= 1
    # It's possible that merged_response_tokens is longer than current_token at this point,
    # since the between-lines delimiter may have been merged into the original token. But it
    # should always be the case that merged_response_tokens ends with current_token.
    assert merged_response_tokens.endswith(current_token)
    num_merged_tokens = start_index - merged_response_index
    token_was_split = num_merged_tokens > 1
    if token_was_split:
        logger.debug(
            "Warning: token from the subject model was split into 2+ tokens by the simulator model."
        )
    return token_was_split
def parse_simulation_response(
    response: dict[str, Any],
    prompt_format: PromptFormat,
    tokens: Sequence[str],
) -> SequenceSimulation:
    """
    Parse an API response to a simulation prompt.

    Args:
        response: response from the API
        prompt_format: how the prompt was formatted
        tokens: list of tokens as strings in the sequence where the neuron is being simulated

    Returns:
        A SequenceSimulation holding, for each parsed token, the expected activation and the
        distribution values and probabilities it was computed from.

    Raises:
        ValueError: if prompt_format is not one of the handled formats.
    """
    choice = response["choices"][0]
    if prompt_format == PromptFormat.HARMONY_V4:
        text = choice["message"]["content"]
    elif prompt_format in [
        PromptFormat.NONE,
        PromptFormat.INSTRUCTION_FOLLOWING,
    ]:
        text = choice["text"]
    else:
        raise ValueError(f"Unhandled prompt format {prompt_format}")
    logprobs = choice["logprobs"]
    response_tokens = logprobs["tokens"]
    top_logprobs = logprobs["top_logprobs"]
    token_text_offset = logprobs["text_offset"]
    # This only works because the sequence "<start>" tokenizes into multiple tokens if it appears in
    # a text sequence in the prompt.
    # NOTE(review): the marker is assumed to be "<start>", matching the "<" + "end" check below;
    # searching for the empty string would place scoring_start at the end of the text and no
    # token would ever be scored. Confirm against the prompt builder.
    scoring_start = text.rfind("<start>")
    expected_values = []
    original_sequence_tokens: list[str] = []
    distribution_values: list[list[float]] = []
    distribution_probabilities: list[list[float]] = []
    for i in range(2, len(response_tokens)):
        if len(original_sequence_tokens) == len(tokens):
            # Make sure we haven't hit some sort of off-by-one error.
            # TODO(sbills): Generalize this to handle different tokenizers.
            # Slicing keeps a response that ends early on the assertion path instead of raising
            # IndexError.
            reached_end = list(response_tokens[i + 1 : i + 3]) == ["<", "end"]
            assert reached_end, f"{response_tokens[i-3:i+3]}"
            break
        if token_text_offset[i] < scoring_start:
            continue
        # We're looking for the first token after a tab. This token should be the text
        # "unknown" if hide_activations=True or a normalized activation (0-10) otherwise.
        # If it isn't, that means that the tab is not appearing as a delimiter, but rather
        # as a token, in which case we should move on to the next response token.
        if response_tokens[i - 1] != "\t":
            continue
        if response_tokens[i] != "unknown":
            logger.debug(
                "Ignoring tab token that is not followed by an 'unknown' token."
            )
            continue

        # j represents the index of the token in a "token<tab>activation" line, barring
        # one of the unusual cases handled below.
        j = i - 2

        current_token = tokens[len(original_sequence_tokens)]
        if current_token == response_tokens[j] or was_token_split(
            current_token, response_tokens, j
        ):
            # We're in the normal case where the tokenization didn't throw off the
            # formatting or in the token-was-split case, which we handle the usual way.
            (
                norm_probabilities_by_distribution_value,
                expected_value,
            ) = compute_predicted_activation_stats_for_token(top_logprobs[i])
            current_distribution_values = list(
                norm_probabilities_by_distribution_value.keys()
            )
            current_distribution_probabilities = list(
                norm_probabilities_by_distribution_value.values()
            )
        else:
            # We're in a case where the tokenization resulted in a newline being folded into
            # the token. We can't do our usual prediction of activation stats for the token,
            # since the model did not observe the original token. Instead, we use dummy
            # values. See the TODO elsewhere in this file about coming up with a better
            # prompt format that avoids this situation.
            newline_folded_into_token = "\n" in response_tokens[j]
            assert (
                newline_folded_into_token
            ), f"`{current_token=}` {response_tokens[j-3:j+3]=}"
            logger.debug(
                "Warning: newline before a tokenactivation line was folded into the token"
            )
            current_distribution_values = []
            current_distribution_probabilities = []
            expected_value = 0.0

        original_sequence_tokens.append(current_token)
        distribution_values.append([float(v) for v in current_distribution_values])
        distribution_probabilities.append(current_distribution_probabilities)
        expected_values.append(expected_value)

    return SequenceSimulation(
        tokens=original_sequence_tokens,
        expected_activations=expected_values,
        activation_scale=ActivationScale.SIMULATED_NORMALIZED_ACTIVATIONS,
        distribution_values=distribution_values,
        distribution_probabilities=distribution_probabilities,
    )


class NeuronSimulator(ABC):
    """Abstract base class for simulating neuron behavior."""

    @abstractmethod
    async def simulate(self, tokens: Sequence[str]) -> SequenceSimulation:
        """Simulate the behavior of a neuron based on an explanation."""
        ...
class ExplanationNeuronSimulator(NeuronSimulator):
    """
    Simulate neuron behavior based on an explanation.

    This class uses a few-shot prompt with examples of other explanations and activations. This
    prompt allows us to score all of the tokens at once using a nifty trick involving logprobs.
    """

    def __init__(
        self,
        model_name: str,
        explanation: str,
        max_concurrent: Optional[int] = 10,
        few_shot_example_set: FewShotExampleSet = FewShotExampleSet.ORIGINAL,
        prompt_format: PromptFormat = PromptFormat.INSTRUCTION_FOLLOWING,
        cache: bool = False,
    ):
        """
        Args:
            model_name: name of the simulator model, forwarded to ApiClient
            explanation: the explanation whose predictive power is being simulated
            max_concurrent: forwarded to ApiClient
            few_shot_example_set: which few-shot examples to put in the prompt
            prompt_format: how the prompt is rendered (plain string or chat messages)
            cache: forwarded to ApiClient
        """
        self.api_client = ApiClient(
            model_name=model_name, max_concurrent=max_concurrent, cache=cache
        )
        self.explanation = explanation
        self.few_shot_example_set = few_shot_example_set
        self.prompt_format = prompt_format

    async def simulate(
        self,
        tokens: Sequence[str],
    ) -> SequenceSimulation:
        """Simulate activations for `tokens` with a single echoed, zero-completion request."""
        prompt = self.make_simulation_prompt(tokens)

        # No completion is requested: the prompt is echoed back with top logprobs, and the
        # predictions are read off the logprobs at each activation slot.
        generate_kwargs: dict[str, Any] = {
            "max_tokens": 0,
            "echo": True,
            "logprobs": 15,
        }
        if self.prompt_format == PromptFormat.HARMONY_V4:
            assert isinstance(prompt, list)
            assert isinstance(prompt[0], dict)  # Really a HarmonyMessage
            generate_kwargs["messages"] = prompt
        else:
            assert isinstance(prompt, str)
            generate_kwargs["prompt"] = prompt

        response = await self.api_client.make_request(**generate_kwargs)
        logger.debug("response in score_explanation_by_activations is %s", response)
        result = parse_simulation_response(response, self.prompt_format, tokens)
        logger.debug("result in score_explanation_by_activations is %s", result)
        return result

    # TODO(sbills): The current token<tab>activation format can result in improper tokenization.
    # In particular, if the token is itself a tab, we may get a single "\t\t" token rather than two
    # "\t" tokens. Consider using a separator that does not appear in any multi-character tokens.
    def make_simulation_prompt(
        self, tokens: Sequence[str]
    ) -> Union[str, list[HarmonyMessage]]:
        """Create a few-shot prompt for predicting neuron activations for the given tokens."""

        # TODO(sbills): The prompts in this file are subtly different from the ones in explainer.py.
        # Consider reconciling them.
        prompt_builder = PromptBuilder()
        # NOTE(review): "token<tab>activation" was restored here; this copy of the file had
        # lost the "<tab>" placeholder ("tokenactivation"). Confirm against the upstream prompt.
        prompt_builder.add_message(
            Role.SYSTEM,
            """We're studying neurons in a neural network.
Each neuron looks for some particular thing in a short document.
Look at summary of what the neuron does, and try to predict how it will fire on each token.

The activation format is token<tab>activation, activations go from 0 to 10, "unknown" indicates an unknown activation. Most activations will be 0.
""",
        )

        few_shot_examples = self.few_shot_example_set.get_examples()
        for neuron_number, example in enumerate(few_shot_examples, start=1):
            prompt_builder.add_message(
                Role.USER,
                f"\n\nNeuron {neuron_number}\nExplanation of neuron {neuron_number} behavior: "
                f"{EXPLANATION_PREFIX} {example.explanation}",
            )
            formatted_activation_records = format_activation_records(
                example.activation_records,
                calculate_max_activation(example.activation_records),
                start_indices=example.first_revealed_activation_indices,
            )
            prompt_builder.add_message(
                Role.ASSISTANT, f"\nActivations: {formatted_activation_records}\n"
            )

        # The neuron being simulated is numbered right after the few-shot examples.
        target_number = len(few_shot_examples) + 1
        prompt_builder.add_message(
            Role.USER,
            f"\n\nNeuron {target_number}\nExplanation of neuron "
            f"{target_number} behavior: {EXPLANATION_PREFIX} "
            f"{self.explanation.strip()}",
        )
        prompt_builder.add_message(
            Role.ASSISTANT,
            f"\nActivations: {format_sequences_for_simulation([tokens])}",
        )
        return prompt_builder.build(self.prompt_format)
class ExplanationTokenByTokenSimulator(NeuronSimulator):
    """
    Simulate neuron behavior based on an explanation.

    Unlike ExplanationNeuronSimulator, this class uses one few-shot prompt per token to calculate
    expected activations. This is slower. This class gets a one-token completion and calculates an
    expected value from that token's logprobs.
    """

    def __init__(
        self,
        model_name: str,
        explanation: str,
        max_concurrent: Optional[int] = 10,
        few_shot_example_set: FewShotExampleSet = FewShotExampleSet.NEWER,
        prompt_format: PromptFormat = PromptFormat.INSTRUCTION_FOLLOWING,
        cache: bool = False,
    ):
        """
        Args:
            model_name: name of the simulator model, forwarded to ApiClient
            explanation: the explanation whose predictive power is being simulated
            max_concurrent: forwarded to ApiClient
            few_shot_example_set: which few-shot examples to use; ORIGINAL is rejected
            prompt_format: how the prompt is rendered
            cache: forwarded to ApiClient
        """
        assert (
            few_shot_example_set != FewShotExampleSet.ORIGINAL
        ), "This simulator doesn't support the ORIGINAL few-shot example set."
        self.api_client = ApiClient(
            model_name=model_name, max_concurrent=max_concurrent, cache=cache
        )
        self.explanation = explanation
        self.few_shot_example_set = few_shot_example_set
        self.prompt_format = prompt_format

    async def simulate(
        self,
        tokens: Sequence[str],
    ) -> SequenceSimulation:
        """Simulate activations for `tokens`, issuing one request per token concurrently."""
        responses_by_token = await asyncio.gather(
            *[
                self._get_activation_stats_for_single_token(
                    tokens, self.explanation, token_index
                )
                for token_index in range(len(tokens))
            ]
        )
        expected_values, distribution_values, distribution_probabilities = [], [], []
        for response in responses_by_token:
            # The single completion token's top logprobs are the predicted distribution.
            activation_logprobs = response["choices"][0]["logprobs"]["top_logprobs"][0]
            (
                norm_probabilities_by_distribution_value,
                expected_value,
            ) = compute_predicted_activation_stats_for_token(
                activation_logprobs,
            )
            distribution_values.append(
                [float(v) for v in norm_probabilities_by_distribution_value.keys()]
            )
            distribution_probabilities.append(
                list(norm_probabilities_by_distribution_value.values())
            )
            expected_values.append(expected_value)

        result = SequenceSimulation(
            tokens=list(tokens),  # SequenceSimulation expects List type
            expected_activations=expected_values,
            activation_scale=ActivationScale.SIMULATED_NORMALIZED_ACTIVATIONS,
            distribution_values=distribution_values,
            distribution_probabilities=distribution_probabilities,
        )
        logger.debug("result in score_explanation_by_activations is %s", result)
        return result

    async def _get_activation_stats_for_single_token(
        self,
        tokens: Sequence[str],
        explanation: str,
        token_index_to_score: int,
    ) -> dict[str, Any]:
        """Request a one-token completion (with logprobs) for the token at the given index."""
        prompt = self.make_single_token_simulation_prompt(
            tokens,
            explanation,
            token_index_to_score=token_index_to_score,
        )
        # NOTE(review): the prompt is always sent as `prompt=`, even though
        # make_single_token_simulation_prompt may return chat messages for some prompt
        # formats - confirm HARMONY_V4 is not expected to work with this simulator.
        return await self.api_client.make_request(
            prompt=prompt, max_tokens=1, echo=False, logprobs=15
        )

    def _add_single_token_simulation_subprompt(
        self,
        prompt_builder: PromptBuilder,
        activation_record: ActivationRecord,
        neuron_index: int,
        explanation: str,
        token_index_to_score: int,
        end_of_prompt: bool,
    ) -> None:
        """
        Append one "predict the last token's activation" exchange to `prompt_builder`.

        The record is cut off after `token_index_to_score`. Unless `end_of_prompt` is set, the
        normalized true activation of that last token is added as the assistant's answer; when
        it is set, the answer is left for the model to complete.
        """
        trimmed_activation_record = ActivationRecord(
            tokens=activation_record.tokens[: token_index_to_score + 1],
            activations=activation_record.activations[: token_index_to_score + 1],
        )
        prompt_builder.add_message(
            Role.USER,
            f"""
Neuron {neuron_index}
Explanation of neuron {neuron_index} behavior: {EXPLANATION_PREFIX} {explanation.strip()}
Text:
{"".join(trimmed_activation_record.tokens)}

Last token in the text:
{trimmed_activation_record.tokens[-1]}

Last token activation, considering the token in the context in which it appeared in the text:
""",
        )
        if end_of_prompt:
            return
        # Normalize against the maximum of the *untrimmed* record.
        normalized_activations = normalize_activations(
            trimmed_activation_record.activations,
            calculate_max_activation([activation_record]),
        )
        prompt_builder.add_message(
            Role.ASSISTANT,
            str(normalized_activations[-1]) + "\n\n",
        )

    def make_single_token_simulation_prompt(
        self,
        tokens: Sequence[str],
        explanation: str,
        token_index_to_score: int,
    ) -> Union[str, list[HarmonyMessage]]:
        """Make a few-shot prompt for predicting the neuron's activation on a single token."""
        assert explanation != ""
        prompt_builder = PromptBuilder()
        # NOTE(review): "token<tab>activation" was restored here; this copy of the file had
        # lost the "<tab>" placeholder ("tokenactivation"). Confirm against the upstream prompt.
        prompt_builder.add_message(
            Role.SYSTEM,
            """We're studying neurons in a neural network. Each neuron looks for some particular thing in a short document. Look at an explanation of what the neuron does, and try to predict its activations on a particular token.

The activation format is token<tab>activation, and activations range from 0 to 10. Most activations will be 0.

""",
        )

        few_shot_examples = self.few_shot_example_set.get_examples()
        for i, example in enumerate(few_shot_examples):
            prompt_builder.add_message(
                Role.USER,
                f"Neuron {i + 1}\nExplanation of neuron {i + 1} behavior: {EXPLANATION_PREFIX} "
                f"{example.explanation}\n",
            )
            formatted_activation_records = format_activation_records(
                example.activation_records,
                calculate_max_activation(example.activation_records),
                start_indices=None,
            )
            prompt_builder.add_message(
                Role.ASSISTANT,
                f"Activations: {formatted_activation_records}\n\n",
            )

        prompt_builder.add_message(
            Role.SYSTEM,
            "Now, we're going predict the activation of a new neuron on a single token, "
            "following the same rules as the examples above. Activations still range from 0 to 10.",
        )
        single_token_example = (
            self.few_shot_example_set.get_single_token_prediction_example()
        )
        assert single_token_example.token_index_to_score is not None
        # Worked example: same explanation, with the true answer shown.
        self._add_single_token_simulation_subprompt(
            prompt_builder,
            single_token_example.activation_records[0],
            len(few_shot_examples) + 1,
            explanation,
            token_index_to_score=single_token_example.token_index_to_score,
            end_of_prompt=False,
        )

        # ActivationRecord expects List type. The activations are placeholders (never shown,
        # since end_of_prompt=True) and are sized to match the truncated token list.
        tokens_up_to_target = list(tokens[: token_index_to_score + 1])
        activation_record = ActivationRecord(
            tokens=tokens_up_to_target,
            activations=[0.0] * len(tokens_up_to_target),
        )
        self._add_single_token_simulation_subprompt(
            prompt_builder,
            activation_record,
            len(few_shot_examples) + 2,
            explanation,
            token_index_to_score,
            end_of_prompt=True,
        )
        return prompt_builder.build(
            self.prompt_format, allow_extra_system_messages=True
        )
def _format_record_for_logprob_free_simulation_json(
    explanation: str,
    activation_record: ActivationRecord,
    include_activations: bool = False,
) -> str:
    """
    Serialize one activation record as the JSON document used by the JSON-mode prompt.

    The object has the keys "to_find" (the explanation), "document" (all tokens joined
    together) and "activations" (one {"token", "activation"} object per token). The
    activation is the record's own value when `include_activations` is set and null otherwise.
    """
    if include_activations:
        assert len(activation_record.tokens) == len(
            activation_record.activations
        ), f"{len(activation_record.tokens)=}, {len(activation_record.activations)=}"

    per_token_entries = []
    for position, token in enumerate(activation_record.tokens):
        # The activations list is only indexed when it is going to be shown.
        if include_activations:
            shown_activation = activation_record.activations[position]
        else:
            shown_activation = None
        per_token_entries.append({"token": token, "activation": shown_activation})

    payload = {
        "to_find": explanation,
        "document": "".join(activation_record.tokens),
        "activations": per_token_entries,
    }
    return json.dumps(payload)
def _parse_no_logprobs_completion_json(
    completion: str,
    tokens: Sequence[str],
) -> Sequence[float]:
    """
    Parse a completion into a list of simulated activations. If the model did not faithfully
    reproduce the token sequence, return a list of 0s. If the model's activation for a token
    is not a number between 0 and 10 (inclusive), substitute 0.

    The completion is expected to be a JSON object with an "activations" list holding one
    {"token": ..., "activation": ...} object per input token. Anything else (invalid JSON, a
    non-object top level, a non-list "activations", a wrong number of entries) yields all 0s;
    a malformed individual entry yields 0 for that token only.

    Args:
        completion: completion from the API
        tokens: list of tokens as strings in the sequence where the neuron is being simulated

    Returns:
        One activation per token.
    """

    logger.debug("for tokens:\n%s", tokens)
    logger.debug("received completion:\n%s", completion)

    zero_prediction = [0] * len(tokens)

    try:
        parsed = json.loads(completion)
    except json.JSONDecodeError:
        logger.warning(
            "Failed to parse completion JSON:\n%s\nExpected Tokens:\n%s",
            completion,
            tokens,
        )
        return zero_prediction

    # Valid JSON is not necessarily an object: a bare number or list must not crash us.
    if not isinstance(parsed, dict) or "activations" not in parsed:
        logger.error(
            "The key 'activations' is not in the completion:\n%s\nExpected Tokens:\n%s",
            json.dumps(parsed),
            tokens,
        )
        return zero_prediction
    activations = parsed["activations"]
    if not isinstance(activations, list) or len(activations) != len(tokens):
        logger.error(
            "Tokens and activations length did not match:\n%s\nExpected Tokens:\n%s",
            json.dumps(parsed),
            tokens,
        )
        return zero_prediction

    predicted_activations = []
    # check that there is a token and activation value
    # no need to double check the token matches exactly
    for activation in activations:
        if not isinstance(activation, dict) or "token" not in activation:
            logger.error(
                "The key 'token' is not in activation:\n%s\nCompletion:%s\nExpected Tokens:\n%s",
                activation,
                json.dumps(parsed),
                tokens,
            )
            predicted_activations.append(0)
            continue
        if "activation" not in activation:
            logger.error(
                "The key 'activation' is not in activation:\n%s\nCompletion:%s\nExpected Tokens:\n%s",
                activation,
                json.dumps(parsed),
                tokens,
            )
            predicted_activations.append(0)
            continue

        raw_activation = activation["activation"]
        try:
            predicted_activation_float = float(raw_activation)
        except (ValueError, OverflowError):
            # Non-numeric strings, or integers too large to represent as a float.
            logger.error(
                "activation value invalid: %s\nCompletion:%s\nExpected Tokens:\n%s",
                raw_activation,
                json.dumps(parsed),
                tokens,
            )
            predicted_activations.append(0)
            continue
        except TypeError:
            # e.g. null, a list or an object.
            logger.error(
                "activation value incorrect type: %s\nCompletion:%s\nExpected Tokens:\n%s",
                raw_activation,
                json.dumps(parsed),
                tokens,
            )
            predicted_activations.append(0)
            continue

        # Ensure activation value is between 0-10 inclusive. The chained comparison is
        # False for NaN (which json.loads accepts), so NaN is rejected as well.
        if not 0 <= predicted_activation_float <= MAX_NORMALIZED_ACTIVATION:
            logger.error(
                "activation value out of range: %s\nCompletion:%s\nExpected Tokens:\n%s",
                predicted_activation_float,
                json.dumps(parsed),
                tokens,
            )
            predicted_activations.append(0)
        else:
            predicted_activations.append(predicted_activation_float)

    logger.debug("predicted activations: %s", predicted_activations)
    return predicted_activations
def _token_line_matches(token_line: str, token: str) -> bool:
    """Return True if `token_line` can be the completion line that belongs to `token`."""
    expected_prefix = f"{token}\t"
    if token_line.startswith(expected_prefix):
        return True
    # Edge Case #1: GPT often omits the space before the token.
    # Allow the returned token line to be either " token" or "token".
    if f" {token_line}".startswith(expected_prefix):
        return True
    # Edge Case #3: Allow our "not end token" replacement, but only for an end token.
    return token.strip() == END_OF_TEXT_TOKEN and token_line.startswith(
        END_OF_TEXT_TOKEN_REPLACEMENT
    )


def _parse_no_logprobs_completion(
    completion: str,
    tokens: Sequence[str],
) -> Sequence[float]:
    """
    Parse a completion into a list of simulated activations. If the model did not faithfully
    reproduce the token sequence, return a list of 0s. If the model's activation for a token
    is not a number between 0 and 10 (inclusive), substitute 0.

    Args:
        completion: completion from the API
        tokens: list of tokens as strings in the sequence where the neuron is being simulated

    Returns:
        One activation per token (an empty list for an empty token sequence).
    """

    logger.debug("for tokens:\n%s", tokens)
    logger.debug("received completion:\n%s", completion)

    if len(tokens) == 0:
        # Nothing to align against; tokens[0] below would otherwise raise IndexError.
        return []

    zero_prediction = [0] * len(tokens)
    # FIX: Strip the last ༗\n, otherwise all last activations are invalid
    stripped_completion = completion.strip("\n").strip("༗\n")
    token_lines = stripped_completion.split("༗\n")
    # Edge Case #2: Sometimes GPT doesn't use the special character when it answers, it only uses the \n"
    # The fix is to try splitting by \n if we detect that the response isn't the right format
    # TODO: If there are also line breaks in the text, this will probably break
    if len(token_lines) == 1:
        token_lines = stripped_completion.split("\n")
    logger.debug("parsed completion into token_lines as:\n%s", token_lines)

    # The model may emit preamble lines; find where the echoed sequence starts.
    start_line_index = next(
        (
            line_index
            for line_index, token_line in enumerate(token_lines)
            if _token_line_matches(token_line, tokens[0])
        ),
        None,
    )

    # If we didn't find the first token, or if the number of lines in the completion doesn't match
    # the number of tokens, return a list of 0s.
    if start_line_index is None or len(token_lines) - start_line_index != len(tokens):
        logger.debug(
            "didn't find first token or number of lines didn't match, returning all zeroes"
        )
        return zero_prediction
    logger.debug("start_line_index is: %s", start_line_index)
    logger.debug(
        "matched token %s with token_line %s", tokens[0], token_lines[start_line_index]
    )

    predicted_activations = []
    for token, token_line in zip(tokens, token_lines[start_line_index:]):
        if not _token_line_matches(token_line, token):
            logger.debug(
                "failed to match token %s with token_line %s, returning all zeroes",
                token,
                token_line,
            )
            return zero_prediction
        predicted_activation_split = token_line.split("\t")
        # Ensure token line has correct size after splitting. If not then assume it's a zero.
        if len(predicted_activation_split) != 2:
            logger.debug("tokenline split invalid size: %s", token_line)
            predicted_activations.append(0)
            continue
        predicted_activation = predicted_activation_split[1]
        # Sometimes the activation value is not a float (GPT likes to append an extra ༗).
        # In all cases if the activation is not numerically parseable, set it to 0
        try:
            predicted_activation_float = float(predicted_activation)
        except ValueError:
            logger.debug("activation value not numeric: %s", predicted_activation)
            predicted_activations.append(0)
            continue
        # The chained comparison is False for NaN, so "nan" is rejected like any other
        # out-of-range value.
        if not 0 <= predicted_activation_float <= MAX_NORMALIZED_ACTIVATION:
            logger.debug("activation value out of range: %s", predicted_activation_float)
            predicted_activations.append(0)
        else:
            predicted_activations.append(predicted_activation_float)
    logger.debug("predicted activations: %s", predicted_activations)
    return predicted_activations
class LogprobFreeExplanationTokenSimulator(NeuronSimulator):
    """
    Simulate neuron behavior based on an explanation.

    Unlike ExplanationNeuronSimulator and ExplanationTokenByTokenSimulator, this class does not rely on
    logprobs to calculate expected activations. Instead, it uses a few-shot prompt that displays all of the
    tokens at once, and request that the model repeat the tokens with the activations appended. Sampling
    is with temperature = 0. Thus, the activations are deterministic. Also, each activation for a token
    is a function of all the activations that came previously and all of the tokens in the sequence, not
    just the current and previous tokens. In the case where the model does not faithfully reproduce the
    token sequence, the simulator will return a response where every predicted activation is 0.

    When `json_mode` is set, the prompt and the expected completion are JSON documents instead
    (one example sequence per few-shot neuron). Example of the non-JSON prompt as follows:

    Explanation: Explanation 1

    Sequence 1 Tokens Without Activations:

    A\t_
    B\t_
    C\t_

    Sequence 1 Tokens With Activations:

    A\t4_
    B\t10_
    C\t0_

    Sequence 2 Tokens Without Activations:

    D\t_
    E\t_
    F\t_

    Sequence 2 Tokens With Activations:

    D\t3_
    E\t6_
    F\t9_

    Explanation: Explanation 2

    Sequence 1 Tokens Without Activations:

    G\t_
    H\t_
    I\t_

    Sequence 1 Tokens With Activations:


    G\t2_
    H\t0_
    I\t3_

    """

    def __init__(
        self,
        model_name: str,
        explanation: str,
        max_concurrent: Optional[int] = 10,
        json_mode: Optional[bool] = True,
        few_shot_example_set: FewShotExampleSet = FewShotExampleSet.NEWER,
        prompt_format: PromptFormat = PromptFormat.HARMONY_V4,
        cache: bool = False,
    ):
        """
        Args:
            model_name: name of the simulator model, forwarded to ApiClient
            explanation: the explanation whose predictive power is being simulated
            max_concurrent: forwarded to ApiClient
            json_mode: use the JSON prompt/completion format instead of the tab-separated one
            few_shot_example_set: which few-shot examples to use; ORIGINAL is rejected
            prompt_format: how the prompt is rendered
            cache: forwarded to ApiClient
        """
        assert (
            few_shot_example_set != FewShotExampleSet.ORIGINAL
        ), "This simulator doesn't support the ORIGINAL few-shot example set."
        self.api_client = ApiClient(
            model_name=model_name, max_concurrent=max_concurrent, cache=cache
        )
        self.json_mode = json_mode
        self.explanation = explanation
        self.few_shot_example_set = few_shot_example_set
        self.prompt_format = prompt_format

    async def simulate(
        self,
        tokens: Sequence[str],
    ) -> SequenceSimulation:
        """Ask the model to echo `tokens` with activations filled in and parse its answer."""
        # The two modes differ only in the prompt, the request options and the parser.
        request_kwargs: dict[str, Any]
        if self.json_mode:
            prompt = self._make_simulation_prompt_json(tokens, self.explanation)
            request_kwargs = {"max_tokens": 2000, "temperature": 0, "json_mode": True}
            parse_completion = _parse_no_logprobs_completion_json
        else:
            prompt = self._make_simulation_prompt(tokens, self.explanation)
            request_kwargs = {"max_tokens": 1000, "temperature": 0}
            parse_completion = _parse_no_logprobs_completion

        response = await self.api_client.make_request(messages=prompt, **request_kwargs)
        assert len(response["choices"]) == 1
        completion = response["choices"][0]["message"]["content"]
        predicted_activations = parse_completion(completion, tokens)

        result = SequenceSimulation(
            activation_scale=ActivationScale.SIMULATED_NORMALIZED_ACTIVATIONS,
            expected_activations=predicted_activations,
            # Since the predicted activation is just a sampled token, we don't have a distribution.
            distribution_values=[],
            distribution_probabilities=[],
            tokens=list(tokens),  # SequenceSimulation expects List type
        )
        logger.debug("result in score_explanation_by_activations is %s", result)
        return result

    def _make_simulation_prompt_json(
        self,
        tokens: Sequence[str],
        explanation: str,
    ) -> Union[str, list[HarmonyMessage]]:
        """
        Make a few-shot prompt for predicting the neuron's activations on a sequence.

        NOTE: The JSON version does not give GPT multiple sequence examples per neuron.

        Every message is a JSON document of the form
            {"to_find": "hello", "document": "The",
             "activations": [{"token": "The", "activation": null}]}
        User messages carry null activations; assistant messages carry the record's values.
        """
        assert explanation != ""
        prompt_builder = PromptBuilder()
        # NOTE(review): the trailing `";` after "Please think carefully." looks accidental,
        # but it is part of the prompt as written and is reproduced unchanged here.
        prompt_builder.add_message(
            Role.SYSTEM,
            """We're studying neurons in a neural network. Each neuron looks for certain things in a short document. Your task is to read the explanation of what the neuron does, and predict the neuron's activations for each token in the document.

For each document, you will see the full text of the document, then the tokens in the document with the activation left blank. You will print, in valid json, the exact same tokens verbatim, but with the activation values filled in according to the explanation. Pay special attention to the explanation's description of the context and order of tokens or words.

Fill out the activation values from 0 to 10. Please think carefully.";
""",
        )

        few_shot_examples = self.few_shot_example_set.get_examples()
        for example in few_shot_examples:
            # Question: the first record of the example with activations left null.
            prompt_builder.add_message(
                Role.USER,
                _format_record_for_logprob_free_simulation_json(
                    explanation=example.explanation,
                    activation_record=example.activation_records[0],
                    include_activations=False,
                ),
            )
            # Answer: the same record with its activations filled in.
            prompt_builder.add_message(
                Role.ASSISTANT,
                _format_record_for_logprob_free_simulation_json(
                    explanation=example.explanation,
                    activation_record=example.activation_records[0],
                    include_activations=True,
                ),
            )
        # Final question: the sequence being simulated, activations left null.
        prompt_builder.add_message(
            Role.USER,
            _format_record_for_logprob_free_simulation_json(
                explanation=explanation,
                # ActivationRecord expects List type.
                activation_record=ActivationRecord(tokens=list(tokens), activations=[]),
                include_activations=False,
            ),
        )
        return prompt_builder.build(
            self.prompt_format, allow_extra_system_messages=True
        )

    def _make_simulation_prompt(
        self,
        tokens: Sequence[str],
        explanation: str,
    ) -> Union[str, list[HarmonyMessage]]:
        """Make a few-shot prompt for predicting the neuron's activations on a sequence."""
        assert explanation != ""
        prompt_builder = PromptBuilder()
        # NOTE(review): "token<tab>activation" was restored here; this copy of the file had
        # lost the "<tab>" placeholder ("tokenactivation"). Confirm against the upstream prompt.
        prompt_builder.add_message(
            Role.SYSTEM,
            """We're studying neurons in a neural network. Each neuron looks for some particular thing in a short document. Look at an explanation of what the neuron does, and try to predict its activations on a particular token.

The activation format is token<tab>activation, and activations range from 0 to 10. Most activations will be 0.
For each sequence, you will see the tokens in the sequence where the activations are left blank. You will print the exact same tokens verbatim, but with the activations filled in according to the explanation.
""",
        )

        few_shot_examples = self.few_shot_example_set.get_examples()
        for i, example in enumerate(few_shot_examples):
            # All sequences of one neuron are normalized against the same maximum.
            few_shot_example_max_activation = calculate_max_activation(
                example.activation_records
            )
            first_record = example.activation_records[0]
            first_blank = _format_record_for_logprob_free_simulation(
                first_record, include_activations=False
            )
            first_filled = _format_record_for_logprob_free_simulation(
                first_record,
                include_activations=True,
                max_activation=few_shot_example_max_activation,
            )

            prompt_builder.add_message(
                Role.USER,
                f"Neuron {i + 1}\nExplanation of neuron {i + 1} behavior: {EXPLANATION_PREFIX} "
                f"{example.explanation}\n\n"
                f"Sequence 1 Tokens without Activations:\n{first_blank}\n\n"
                f"Sequence 1 Tokens with Activations:\n",
            )
            prompt_builder.add_message(Role.ASSISTANT, f"{first_filled}\n\n")

            # Remaining sequences of the same neuron are numbered from 2.
            for sequence_number, record in enumerate(example.activation_records[1:], start=2):
                blank = _format_record_for_logprob_free_simulation(
                    record, include_activations=False
                )
                filled = _format_record_for_logprob_free_simulation(
                    record,
                    include_activations=True,
                    max_activation=few_shot_example_max_activation,
                )
                prompt_builder.add_message(
                    Role.USER,
                    f"Sequence {sequence_number} Tokens without Activations:\n{blank}\n\n"
                    f"Sequence {sequence_number} Tokens with Activations:\n",
                )
                prompt_builder.add_message(Role.ASSISTANT, f"{filled}\n\n")

        neuron_index = len(few_shot_examples) + 1
        target_blank = _format_record_for_logprob_free_simulation(
            # ActivationRecord expects List type.
            ActivationRecord(tokens=list(tokens), activations=[]),
            include_activations=False,
        )
        prompt_builder.add_message(
            Role.USER,
            f"Neuron {neuron_index}\nExplanation of neuron {neuron_index} behavior: {EXPLANATION_PREFIX} "
            f"{explanation}\n\n"
            f"Sequence 1 Tokens without Activations:\n{target_blank}\n\n"
            f"Sequence 1 Tokens with Activations:\n",
        )
        return prompt_builder.build(
            self.prompt_format, allow_extra_system_messages=True
        )
24 | 25 | The activation format is tokenactivation. Activation values range from 0 to 10. A neuron finding what it's looking for is represented by a non-zero activation value. The higher the activation value, the stronger the match. 26 | 27 | Neuron 1 28 | Activations: 29 | 30 | a 10 31 | b 0 32 | c 0 33 | 34 | 35 | d 0 36 | e 10 37 | f 0 38 | 39 | 40 | Explanation of neuron 1 behavior: the main thing this neuron does is find vowels. 41 | 42 | Neuron 2 43 | Activations: 44 | 45 | a 10 46 | b 0 47 | c 0 48 | 49 | 50 | d 0 51 | e 10 52 | f 0 53 | 54 | 55 | Explanation of neuron 2 behavior:<|endofprompt|> the main thing this neuron does is find""" 56 | 57 | explainer = TokenActivationPairExplainer( 58 | model_name="text-davinci-003", 59 | prompt_format=PromptFormat.INSTRUCTION_FOLLOWING, 60 | few_shot_example_set=FewShotExampleSet.TEST, 61 | ) 62 | prompt = explainer.make_explanation_prompt( 63 | all_activation_records=TEST_EXAMPLES[0].activation_records, 64 | max_activation=1.0, 65 | max_tokens_for_completion=20, 66 | ) 67 | 68 | assert prompt == expected_prompt 69 | 70 | 71 | def test_harmony_format() -> None: 72 | expected_prompt = [ 73 | HarmonyMessage( 74 | role=Role.SYSTEM, 75 | content="""We're studying neurons in a neural network. Each neuron looks for some particular thing in a short document. Look at the parts of the document the neuron activates for and summarize in a single sentence what the neuron is looking for. Don't list examples of words. 76 | 77 | The activation format is tokenactivation. Activation values range from 0 to 10. A neuron finding what it's looking for is represented by a non-zero activation value. 
The higher the activation value, the stronger the match.""", 78 | ), 79 | HarmonyMessage( 80 | role=Role.USER, 81 | content=""" 82 | 83 | Neuron 1 84 | Activations: 85 | 86 | a 10 87 | b 0 88 | c 0 89 | 90 | 91 | d 0 92 | e 10 93 | f 0 94 | 95 | 96 | Explanation of neuron 1 behavior: the main thing this neuron does is find""", 97 | ), 98 | HarmonyMessage( 99 | role=Role.ASSISTANT, 100 | content=" vowels.", 101 | ), 102 | HarmonyMessage( 103 | role=Role.USER, 104 | content=""" 105 | 106 | Neuron 2 107 | Activations: 108 | 109 | a 10 110 | b 0 111 | c 0 112 | 113 | 114 | d 0 115 | e 10 116 | f 0 117 | 118 | 119 | Explanation of neuron 2 behavior: the main thing this neuron does is find""", 120 | ), 121 | ] 122 | 123 | explainer = TokenActivationPairExplainer( 124 | model_name="gpt-4", 125 | prompt_format=PromptFormat.HARMONY_V4, 126 | few_shot_example_set=FewShotExampleSet.TEST, 127 | ) 128 | prompt = explainer.make_explanation_prompt( 129 | all_activation_records=TEST_EXAMPLES[0].activation_records, 130 | max_activation=1.0, 131 | max_tokens_for_completion=20, 132 | ) 133 | 134 | assert isinstance(prompt, list) 135 | assert isinstance(prompt[0], dict) # Really a HarmonyMessage 136 | for actual_message, expected_message in zip(prompt, expected_prompt): 137 | assert actual_message["role"] == expected_message["role"] 138 | assert actual_message["content"] == expected_message["content"] 139 | assert prompt == expected_prompt 140 | 141 | 142 | def test_token_space_explainer_if_formatting() -> None: 143 | expected_prompt = """We're studying neurons in a neural network. Each neuron looks for some particular kind of token (which can be a word, or part of a word). Look at the tokens the neuron activates for (listed below) and summarize in a single sentence what the neuron is looking for. Don't list examples of words. 144 | 145 | 146 | 147 | Tokens: 148 | 'these', ' are', ' tokens' 149 | 150 | Explanation: 151 | This neuron is looking for this is a test explanation. 
152 | 153 | 154 | 155 | Tokens: 156 | 'foo', 'bar', 'baz' 157 | 158 | Explanation: 159 | <|endofprompt|>This neuron is looking for""" 160 | 161 | explainer = TokenSpaceRepresentationExplainer( 162 | model_name="text-davinci-002", 163 | prompt_format=PromptFormat.INSTRUCTION_FOLLOWING, 164 | use_few_shot=True, 165 | few_shot_example_set=TokenSpaceFewShotExampleSet.TEST, 166 | ) 167 | prompt = explainer.make_explanation_prompt( 168 | tokens=["foo", "bar", "baz"], 169 | max_tokens_for_completion=20, 170 | ) 171 | 172 | assert prompt == expected_prompt 173 | 174 | 175 | def test_token_space_explainer_harmony_formatting() -> None: 176 | expected_prompt = [ 177 | HarmonyMessage( 178 | role=Role.SYSTEM, 179 | content="We're studying neurons in a neural network. Each neuron looks for some particular kind of token (which can be a word, or part of a word). Look at the tokens the neuron activates for (listed below) and summarize in a single sentence what the neuron is looking for. Don't list examples of words.", 180 | ), 181 | HarmonyMessage( 182 | role=Role.USER, 183 | content=""" 184 | 185 | 186 | 187 | Tokens: 188 | 'these', ' are', ' tokens' 189 | 190 | Explanation: 191 | This neuron is looking for""", 192 | ), 193 | HarmonyMessage( 194 | role=Role.ASSISTANT, 195 | content=" this is a test explanation.", 196 | ), 197 | HarmonyMessage( 198 | role=Role.USER, 199 | content=""" 200 | 201 | 202 | 203 | Tokens: 204 | 'foo', 'bar', 'baz' 205 | 206 | Explanation: 207 | This neuron is looking for""", 208 | ), 209 | ] 210 | 211 | explainer = TokenSpaceRepresentationExplainer( 212 | model_name="gpt-4", 213 | prompt_format=PromptFormat.HARMONY_V4, 214 | use_few_shot=True, 215 | few_shot_example_set=TokenSpaceFewShotExampleSet.TEST, 216 | ) 217 | prompt = explainer.make_explanation_prompt( 218 | tokens=["foo", "bar", "baz"], 219 | max_tokens_for_completion=20, 220 | ) 221 | 222 | assert isinstance(prompt, list) 223 | assert isinstance(prompt[0], dict) # Really a HarmonyMessage 224 | 
for actual_message, expected_message in zip(prompt, expected_prompt): 225 | assert actual_message["role"] == expected_message["role"] 226 | assert actual_message["content"] == expected_message["content"] 227 | assert prompt == expected_prompt 228 | -------------------------------------------------------------------------------- /neuron_explainer/explanations/test_simulator.py: -------------------------------------------------------------------------------- 1 | from neuron_explainer.explanations.few_shot_examples import FewShotExampleSet 2 | from neuron_explainer.explanations.prompt_builder import HarmonyMessage, PromptFormat, Role 3 | from neuron_explainer.explanations.simulator import ( 4 | ExplanationNeuronSimulator, 5 | ExplanationTokenByTokenSimulator, 6 | ) 7 | 8 | 9 | def test_make_explanation_simulation_prompt_if_format() -> None: 10 | expected_prompt = """We're studying neurons in a neural network. 11 | Each neuron looks for some particular thing in a short document. 12 | Look at summary of what the neuron does, and try to predict how it will fire on each token. 13 | 14 | The activation format is tokenactivation, activations go from 0 to 10, "unknown" indicates an unknown activation. Most activations will be 0. 
15 | 16 | 17 | Neuron 1 18 | Explanation of neuron 1 behavior: the main thing this neuron does is find vowels 19 | Activations: 20 | 21 | a 10 22 | b 0 23 | c 0 24 | 25 | 26 | d unknown 27 | e 10 28 | f 0 29 | 30 | 31 | 32 | 33 | Neuron 2 34 | Explanation of neuron 2 behavior: the main thing this neuron does is find EXPLANATION<|endofprompt|> 35 | Activations: 36 | 37 | 0 unknown 38 | 1 unknown 39 | 2 unknown 40 | 41 | """ 42 | prompt = ExplanationNeuronSimulator( 43 | model_name="text-davinci-003", 44 | explanation="EXPLANATION", 45 | few_shot_example_set=FewShotExampleSet.TEST, 46 | prompt_format=PromptFormat.INSTRUCTION_FOLLOWING, 47 | ).make_simulation_prompt( 48 | tokens=[str(x) for x in range(3)], 49 | ) 50 | assert prompt == expected_prompt 51 | 52 | 53 | def test_make_explanation_simulation_prompt_harmony_format() -> None: 54 | expected_prompt = [ 55 | HarmonyMessage( 56 | role=Role.SYSTEM, 57 | content="""We're studying neurons in a neural network. 58 | Each neuron looks for some particular thing in a short document. 59 | Look at summary of what the neuron does, and try to predict how it will fire on each token. 60 | 61 | The activation format is tokenactivation, activations go from 0 to 10, "unknown" indicates an unknown activation. Most activations will be 0. 
62 | """, 63 | ), 64 | HarmonyMessage( 65 | role=Role.USER, 66 | content=""" 67 | 68 | Neuron 1 69 | Explanation of neuron 1 behavior: the main thing this neuron does is find vowels""", 70 | ), 71 | HarmonyMessage( 72 | role=Role.ASSISTANT, 73 | content=""" 74 | Activations: 75 | 76 | a 10 77 | b 0 78 | c 0 79 | 80 | 81 | d unknown 82 | e 10 83 | f 0 84 | 85 | 86 | """, 87 | ), 88 | HarmonyMessage( 89 | role=Role.USER, 90 | content=""" 91 | 92 | Neuron 2 93 | Explanation of neuron 2 behavior: the main thing this neuron does is find EXPLANATION""", 94 | ), 95 | HarmonyMessage( 96 | role=Role.ASSISTANT, 97 | content=""" 98 | Activations: 99 | 100 | 0 unknown 101 | 1 unknown 102 | 2 unknown 103 | 104 | """, 105 | ), 106 | ] 107 | prompt = ExplanationNeuronSimulator( 108 | model_name="gpt-4", 109 | explanation="EXPLANATION", 110 | few_shot_example_set=FewShotExampleSet.TEST, 111 | prompt_format=PromptFormat.HARMONY_V4, 112 | ).make_simulation_prompt( 113 | tokens=[str(x) for x in range(3)], 114 | ) 115 | 116 | assert isinstance(prompt, list) 117 | assert isinstance(prompt[0], dict) # Really a HarmonyMessage 118 | for actual_message, expected_message in zip(prompt, expected_prompt): 119 | assert actual_message["role"] == expected_message["role"] 120 | assert actual_message["content"] == expected_message["content"] 121 | assert prompt == expected_prompt 122 | 123 | 124 | def test_make_token_by_token_simulation_prompt_if_format() -> None: 125 | expected_prompt = """We're studying neurons in a neural network. Each neuron looks for some particular thing in a short document. Look at an explanation of what the neuron does, and try to predict its activations on a particular token. 126 | 127 | The activation format is tokenactivation, and activations range from 0 to 10. Most activations will be 0. 
128 | 129 | Neuron 1 130 | Explanation of neuron 1 behavior: the main thing this neuron does is find vowels 131 | Activations: 132 | 133 | a 10 134 | b 0 135 | c 0 136 | 137 | 138 | d 0 139 | e 10 140 | f 0 141 | 142 | 143 | 144 | Now, we're going predict the activation of a new neuron on a single token, following the same rules as the examples above. Activations still range from 0 to 10. 145 | Neuron 2 146 | Explanation of neuron 2 behavior: the main thing this neuron does is find numbers and nothing else 147 | Text: 148 | ghi 149 | 150 | Last token in the text: 151 | i 152 | 153 | Last token activation, considering the token in the context in which it appeared in the text: 154 | 10 155 | 156 | 157 | Neuron 3 158 | Explanation of neuron 3 behavior: the main thing this neuron does is find numbers and nothing else 159 | Text: 160 | 01 161 | 162 | Last token in the text: 163 | 1 164 | 165 | Last token activation, considering the token in the context in which it appeared in the text: 166 | <|endofprompt|>""" 167 | prompt = ExplanationTokenByTokenSimulator( 168 | model_name="text-davinci-003", 169 | explanation="EXPLANATION", 170 | few_shot_example_set=FewShotExampleSet.TEST, 171 | prompt_format=PromptFormat.INSTRUCTION_FOLLOWING, 172 | ).make_single_token_simulation_prompt( 173 | tokens=[str(x) for x in range(3)], 174 | explanation="numbers and nothing else", 175 | token_index_to_score=1, 176 | ) 177 | assert prompt == expected_prompt 178 | 179 | 180 | def test_make_token_by_token_simulation_prompt_harmony_format() -> None: 181 | expected_prompt = [ 182 | HarmonyMessage( 183 | role=Role.SYSTEM, 184 | content="""We're studying neurons in a neural network. Each neuron looks for some particular thing in a short document. Look at an explanation of what the neuron does, and try to predict its activations on a particular token. 185 | 186 | The activation format is tokenactivation, and activations range from 0 to 10. Most activations will be 0. 
187 | 188 | """, 189 | ), 190 | HarmonyMessage( 191 | role=Role.USER, 192 | content="""Neuron 1 193 | Explanation of neuron 1 behavior: the main thing this neuron does is find vowels 194 | """, 195 | ), 196 | HarmonyMessage( 197 | role=Role.ASSISTANT, 198 | content="""Activations: 199 | 200 | a 10 201 | b 0 202 | c 0 203 | 204 | 205 | d 0 206 | e 10 207 | f 0 208 | 209 | 210 | 211 | """, 212 | ), 213 | HarmonyMessage( 214 | role=Role.SYSTEM, 215 | content="Now, we're going predict the activation of a new neuron on a single token, following the same rules as the examples above. Activations still range from 0 to 10.", 216 | ), 217 | HarmonyMessage( 218 | role=Role.USER, 219 | content=""" 220 | Neuron 2 221 | Explanation of neuron 2 behavior: the main thing this neuron does is find numbers and nothing else 222 | Text: 223 | ghi 224 | 225 | Last token in the text: 226 | i 227 | 228 | Last token activation, considering the token in the context in which it appeared in the text: 229 | """, 230 | ), 231 | HarmonyMessage( 232 | role=Role.ASSISTANT, 233 | content="""10 234 | 235 | """, 236 | ), 237 | HarmonyMessage( 238 | role=Role.USER, 239 | content=""" 240 | Neuron 3 241 | Explanation of neuron 3 behavior: the main thing this neuron does is find numbers and nothing else 242 | Text: 243 | 01 244 | 245 | Last token in the text: 246 | 1 247 | 248 | Last token activation, considering the token in the context in which it appeared in the text: 249 | """, 250 | ), 251 | ] 252 | 253 | prompt = ExplanationTokenByTokenSimulator( 254 | model_name="gpt-4", 255 | explanation="EXPLANATION", 256 | few_shot_example_set=FewShotExampleSet.TEST, 257 | prompt_format=PromptFormat.HARMONY_V4, 258 | ).make_single_token_simulation_prompt( 259 | tokens=[str(x) for x in range(3)], 260 | explanation="numbers and nothing else", 261 | token_index_to_score=1, 262 | ) 263 | 264 | assert isinstance(prompt, list) 265 | assert isinstance(prompt[0], dict) # Really a HarmonyMessage 266 | for actual_message, 
class TokenSpaceFewShotExampleSet(Enum):
    """Names the few-shot example list to use when sampling explanations."""

    ORIGINAL = "original"
    TEST = "test"

    def get_examples(self) -> list[Example]:
        """Return the module-level example list that belongs to this member.

        Raises:
            ValueError: If no example list is associated with the member.
        """
        # Guard clauses: each known member maps to exactly one module-level list.
        if self is TokenSpaceFewShotExampleSet.TEST:
            return TEST_EXAMPLES
        if self is TokenSpaceFewShotExampleSet.ORIGINAL:
            return ORIGINAL_EXAMPLES
        raise ValueError(f"Unhandled example set: {self}")
72 | " natural", 73 | "Hyper", 74 | " Any", 75 | " theoretical", 76 | "|", 77 | " ultimate", 78 | "oing", 79 | " constant", 80 | "ANY", 81 | "antically", 82 | "ishly", 83 | " ex", 84 | " visual", 85 | "special", 86 | "omorphic", 87 | "visual", 88 | ], 89 | explanation=" adjectives related to being real, or to physical properties and evidence", 90 | ), 91 | Example( 92 | tokens=[ 93 | "cephal", 94 | "aeus", 95 | " coma", 96 | "bered", 97 | "abetes", 98 | "inflamm", 99 | "rugged", 100 | "alysed", 101 | "azine", 102 | "hered", 103 | "cells", 104 | "aneously", 105 | "fml", 106 | "igm", 107 | "culosis", 108 | "iani", 109 | "CTV", 110 | "disabled", 111 | "heric", 112 | "ulo", 113 | "geoning", 114 | "awi", 115 | "translation", 116 | "iral", 117 | "govtrack", 118 | "mson", 119 | "cloth", 120 | "nesota", 121 | " Dise", 122 | " Lyme", 123 | " dementia", 124 | "agn", 125 | " reversible", 126 | " susceptibility", 127 | "esthesia", 128 | "orf", 129 | " inflamm", 130 | " Obesity", 131 | " tox", 132 | " Disorders", 133 | "uberty", 134 | "blind", 135 | "ALTH", 136 | "avier", 137 | " Immunity", 138 | " Hurt", 139 | "ulet", 140 | "ueless", 141 | " sluggish", 142 | "rosis", 143 | ], 144 | explanation=" words related to physical medical conditions", 145 | ), 146 | Example( 147 | tokens=[ 148 | " January", 149 | "terday", 150 | "cember", 151 | " April", 152 | " July", 153 | "September", 154 | "December", 155 | "Thursday", 156 | "quished", 157 | "November", 158 | "Tuesday", 159 | "uesday", 160 | " Sept", 161 | "ruary", 162 | " March", 163 | ";;;;;;;;;;;;", 164 | " Monday", 165 | "Wednesday", 166 | " Saturday", 167 | " Wednesday", 168 | "Reloaded", 169 | "aturday", 170 | " August", 171 | "Feb", 172 | "Sunday", 173 | "Reviewed", 174 | "uggest", 175 | " Dhabi", 176 | "ACTED", 177 | "tten", 178 | "Year", 179 | "August", 180 | "alogue", 181 | "MX", 182 | " Janeiro", 183 | "yss", 184 | " Leilan", 185 | " Fiscal", 186 | " referen", 187 | "semb", 188 | "eele", 189 | "wcs", 190 | "detail", 191 | 
"ertation", 192 | " Reborn", 193 | " Sunday", 194 | "itially", 195 | "aturdays", 196 | " Dise", 197 | "essage", 198 | ], 199 | explanation=" nouns related to time and dates", 200 | ), 201 | ] 202 | 203 | TEST_EXAMPLES = [ 204 | Example( 205 | tokens=[ 206 | "these", 207 | " are", 208 | " tokens", 209 | ], 210 | explanation=" this is a test explanation", 211 | ), 212 | ] 213 | -------------------------------------------------------------------------------- /neuron_explainer/fast_dataclasses/__init__.py: -------------------------------------------------------------------------------- 1 | from .fast_dataclasses import FastDataclass, dumps, loads, register_dataclass 2 | 3 | __all__ = ["FastDataclass", "dumps", "loads", "register_dataclass"] 4 | -------------------------------------------------------------------------------- /neuron_explainer/fast_dataclasses/fast_dataclasses.py: -------------------------------------------------------------------------------- 1 | # Utilities for dataclasses that are very fast to serialize and deserialize, with limited data 2 | # validation. Fields must not be tuples, since they get serialized and then deserialized as lists. 3 | # 4 | # The unit tests for this library show how to use it. 5 | 6 | import json 7 | from dataclasses import dataclass, field, fields, is_dataclass 8 | from functools import partial 9 | from typing import Any, Union 10 | 11 | import orjson 12 | 13 | dataclasses_by_name = {} 14 | dataclasses_by_fieldnames = {} 15 | 16 | 17 | @dataclass 18 | class FastDataclass: 19 | dataclass_name: str = field(init=False) 20 | 21 | def __post_init__(self) -> None: 22 | self.dataclass_name = self.__class__.__name__ 23 | 24 | 25 | def register_dataclass(cls): # type: ignore 26 | assert is_dataclass(cls), "Only dataclasses can be registered." 
def _object_hook(d: Any, backwards_compatible: bool = True) -> Any:
    """Turn decoded JSON data into registered dataclass instances where possible.

    Args:
        d: A decoded JSON value. Lists are converted element-wise, dicts are
            candidates for conversion, anything else is returned unchanged.
        backwards_compatible: When True, a dict without ``dataclass_name`` is
            matched against registered dataclasses by its field names, and an
            unknown ``dataclass_name`` is tolerated (the dict is returned
            without that key). When False, an unknown ``dataclass_name`` is an
            error and dicts without the key stay plain dicts.

    Returns:
        An instance of the matched dataclass, or a plain dict (without the
        ``dataclass_name`` key) when no dataclass was matched.

    Raises:
        AssertionError: ``dataclass_name`` is not registered and
            ``backwards_compatible`` is False.
        TypeError: Propagated from the dataclass constructor when the dict's
            keys do not fit the matched dataclass's fields.
    """
    # If d is a list, recurse.
    if isinstance(d, list):
        return [_object_hook(x, backwards_compatible=backwards_compatible) for x in d]
    # If d is not a dict, return it as is.
    if not isinstance(d, dict):
        return d
    cls = None
    if "dataclass_name" in d:
        name = d["dataclass_name"]
        if name in dataclasses_by_name:
            cls = dataclasses_by_name[name]
        elif not backwards_compatible:
            # Raised explicitly rather than via `assert`: assert statements are
            # removed under `python -O`, which would silently disable this
            # check. AssertionError is kept so existing callers still catch it.
            raise AssertionError(
                f"Dataclass {name} not found, set backwards_compatible=True if you "
                f"are okay with that."
            )
    # Load objects created without dataclass_name set.
    elif backwards_compatible:
        # Try our best to find a dataclass if backwards_compatible is True.
        d_fields = frozenset(d.keys())
        if d_fields in dataclasses_by_fieldnames:
            cls = dataclasses_by_fieldnames[d_fields]
        elif len(d_fields) > 0:
            # Check if the fields are a subset of a dataclass (if the dataclass had extra fields
            # added since the data was created). Note that this will fail if fields were removed
            # from the dataclass. The first registered superset wins, so the result depends on
            # registration order when several dataclasses qualify.
            for key, possible_cls in dataclasses_by_fieldnames.items():
                if d_fields.issubset(key):
                    cls = possible_cls
                    break
            else:
                print(f"Could not find dataclass for {d_fields} {cls}")
    # The name marker is metadata, not a constructor argument, so it is dropped here.
    new_d = {
        k: _object_hook(v, backwards_compatible=backwards_compatible)
        for k, v in d.items()
        if k != "dataclass_name"
    }
    if cls is not None:
        return cls(**new_d)
    return new_d


def loads(s: Union[str, bytes], backwards_compatible: bool = True) -> Any:
    """Deserialize JSON text, rebuilding registered dataclasses from JSON objects.

    Args:
        s: The JSON document as ``str`` or ``bytes``.
        backwards_compatible: Forwarded to ``_object_hook``; see there for how
            it affects dicts with a missing or unknown ``dataclass_name``.

    Returns:
        The decoded value, with JSON objects replaced by dataclass instances
        wherever ``_object_hook`` found a matching registered dataclass.
    """
    return json.loads(
        s,
        object_hook=partial(_object_hook, backwards_compatible=backwards_compatible),
    )
def test_dataclasses() -> None:
    """Round-trip a nested DataclassA through dumps/loads and expect equality."""
    # Two DataclassB values that mix DataclassC and DataclassC_ext, both as
    # dict values and as list elements.
    first_b = DataclassB(
        str_to_c={"a": DataclassC(ints=[1, 2]), "b": DataclassC(ints=[3, 4])},
        cs=[DataclassC(ints=[5, 6]), DataclassC_ext(ints=[7, 8], s="s")],
    )
    second_b = DataclassB(
        str_to_c={"c": DataclassC_ext(ints=[9, 10], s="t"), "d": DataclassC(ints=[11, 12])},
        cs=[DataclassC(ints=[13, 14]), DataclassC(ints=[15, 16])],
    )
    original = DataclassA(
        floats=[1.0, 2.0],
        strings=["a", "b"],
        bs=[first_b, second_b],
    )

    round_tripped = loads(dumps(original))

    assert round_tripped == original
from setuptools import find_packages, setup

# Requirements installed alongside the package.
# NOTE(review): pyproject.toml declares pytest only as a dev dependency and
# uses higher minimum versions for httpx and boostedblob; confirm which file
# is authoritative before changing either list.
INSTALL_REQUIRES = [
    "httpx>=0.22",
    "scikit-learn",
    "boostedblob>=0.13.0",
    "tiktoken",
    "blobfile",
    "numpy",
    "pytest",
    "orjson",
]

setup(
    name="neuron_explainer",
    version="0.0.13",
    author="OpenAI, Neuronpedia",
    url="",
    description="",
    python_requires=">=3.9",
    packages=find_packages(),
    install_requires=INSTALL_REQUIRES,
)