├── tests ├── __init__.py ├── tracr_dataset_test.py ├── linear_compressed_tracr_transformer_test.py ├── get_cases_test.py ├── get_data_test.py ├── ground_truth_circuit_test.py ├── hooked_tracr_transformer_test.py └── utils.py ├── circuits_benchmark ├── __init__.py ├── metrics │ ├── __init__.py │ ├── resampling_ablation_loss │ │ └── __init__.py │ └── sparsity.py ├── utils │ ├── __init__.py │ ├── hf │ │ ├── __init__.py │ │ └── hf_uploader.py │ ├── circuit │ │ ├── __init__.py │ │ ├── circuit_granularity.py │ │ ├── circuit_node_view.py │ │ ├── circuit_node.py │ │ └── edges_list.py │ ├── ll_model_loader │ │ ├── __init__.py │ │ ├── natural_model_loader.py │ │ ├── ll_model_loader.py │ │ ├── ground_truth_model_loader.py │ │ ├── best_weights.py │ │ └── ll_model_loader_factory.py │ ├── iit │ │ ├── __init__.py │ │ ├── iit_dataset_batch.py │ │ ├── tracr_model_pair.py │ │ ├── iit_hl_model.py │ │ └── ll_cfg.py │ ├── attr_dict.py │ ├── circleci.py │ ├── cloudpickle.py │ ├── node_sp.py │ ├── wandb_artifact_download.py │ ├── init_functions.py │ ├── project_paths.py │ ├── get_cases.py │ ├── auto_circuit_utils.py │ └── find_all_subclasses.py ├── benchmark │ ├── __init__.py │ ├── cases │ │ ├── __init__.py │ │ ├── case_21.py │ │ ├── case_46.py │ │ ├── case_52.py │ │ ├── case_60.py │ │ ├── case_8.py │ │ ├── case_50.py │ │ ├── case_121.py │ │ ├── case_54.py │ │ ├── case_122.py │ │ ├── case_123.py │ │ ├── case_2.py │ │ ├── case_73.py │ │ ├── case_101.py │ │ ├── case_77.py │ │ ├── case_85.py │ │ ├── case_4.py │ │ ├── case_41.py │ │ ├── case_70.py │ │ ├── case_7.py │ │ ├── case_12.py │ │ ├── case_84.py │ │ ├── case_9.py │ │ ├── case_104.py │ │ ├── case_75.py │ │ ├── case_55.py │ │ ├── case_49.py │ │ ├── case_91.py │ │ ├── case_106.py │ │ ├── case_64.py │ │ ├── case_80.py │ │ ├── case_87.py │ │ ├── case_68.py │ │ ├── case_79.py │ │ ├── case_114.py │ │ ├── case_3.py │ │ ├── case_48.py │ │ ├── case_65.py │ │ ├── case_72.py │ │ ├── case_110.py │ │ ├── case_86.py │ │ ├── case_83.py │ │ ├── case_69.py │ │ ├── case_66.py │ │ ├── case_40.py │ │ ├── case_53.py │ │ ├── case_28.py │ │ ├── case_23.py │ │ ├── case_62.py │ │ ├── case_116.py │ │ ├── case_29.py │ │ ├── case_33.py │ │ ├── case_103.py │ │ ├── case_26.py │ │ ├── case_63.py │ │ ├── case_74.py │ │ ├── case_44.py │ │ ├── case_42.py │ │ ├── case_107.py │ │ ├── case_131.py │ │ ├── case_128.py │ │ ├── case_120.py │ │ ├── case_24.py │ │ ├── case_37.py │ │ ├── case_100.py │ │ ├── case_30.py │ │ ├── case_31.py │ │ ├── case_51.py │ │ ├── case_105.py │ │ ├── case_90.py │ │ ├── case_97.py │ │ ├── case_117.py │ │ ├── case_11.py │ │ ├── case_32.py │ │ ├── case_47.py │ │ ├── case_43.py │ │ ├── case_61.py │ │ ├── case_35.py │ │ ├── case_36.py │ │ ├── case_78.py │ │ ├── case_20.py │ │ ├── case_93.py │ │ ├── case_25.py │ │ ├── case_34.py │ │ ├── case_95.py │ │ ├── case_118.py │ │ ├── case_71.py │ │ ├── case_76.py │ │ ├── case_22.py │ │ ├── case_112.py │ │ ├── case_92.py │ │ ├── case_10.py │ │ ├── case_14.py │ │ ├── case_39.py │ │ ├── case_111.py │ │ ├── case_56.py │ │ ├── case_6.py │ │ ├── case_102.py │ │ ├── case_67.py │ │ ├── case_115.py │ │ ├── case_99.py │ │ ├── case_57.py │ │ ├── case_18.py │ │ ├── case_82.py │ │ ├── case_19.py │ │ ├── case_ioi_next_token.py │ │ ├── case_81.py │ │ ├── case_38.py │ │ ├── case_130.py │ │ ├── case_98.py │ │ ├── case_113.py │ │ ├── case_59.py │ │ ├── case_45.py │ │ ├── case_15.py │ │ ├── case_58.py │ │ ├── case_13.py │ │ ├── case_96.py │ │ ├── case_88.py │ │ ├── case_127.py │ │ ├── case_119.py │ │ └── case_126.py │ ├── program_evaluation_type.py │ 
├── case_dataset.py │ ├── vocabs.py │ └── tracr_encoded_dataset.py ├── commands │ ├── __init__.py │ ├── train │ │ ├── __init__.py │ │ ├── iit │ │ │ └── __init__.py │ │ ├── compression │ │ │ ├── __init__.py │ │ │ └── compression_training_utils.py │ │ └── train.py │ ├── evaluation │ │ ├── __init__.py │ │ ├── iit │ │ │ └── __init__.py │ │ ├── realism │ │ │ └── __init__.py │ │ └── evaluation.py │ ├── algorithms │ │ ├── __init__.py │ │ └── run_algorithm.py │ └── build_main_parser.py ├── training │ ├── __init__.py │ ├── compression │ │ ├── __init__.py │ │ └── activation_mapper │ │ │ ├── __init__.py │ │ │ ├── linear_mapper.py │ │ │ ├── activation_mapper.py │ │ │ └── autoencoder_mapper.py │ └── training_args.py └── transformers │ └── __init__.py ├── metadata ├── .gitignore └── benchmark_base_metadata.json ├── pytest.ini ├── .dockerignore ├── .gitignore ├── main.py ├── LICENSE ├── Dockerfile ├── EXPERIMENTS.md ├── pyproject.toml └── .circleci └── config.yml /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /circuits_benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /circuits_benchmark/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /circuits_benchmark/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /circuits_benchmark/commands/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /circuits_benchmark/training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /circuits_benchmark/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /circuits_benchmark/utils/hf/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /circuits_benchmark/commands/train/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /circuits_benchmark/utils/circuit/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /circuits_benchmark/commands/evaluation/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /circuits_benchmark/commands/train/iit/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /circuits_benchmark/training/compression/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /circuits_benchmark/utils/ll_model_loader/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /metadata/.gitignore: -------------------------------------------------------------------------------- 1 | *.csv 2 | *.parquet 3 | *.json -------------------------------------------------------------------------------- /circuits_benchmark/commands/evaluation/iit/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /circuits_benchmark/commands/evaluation/realism/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /circuits_benchmark/commands/train/compression/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /circuits_benchmark/metrics/resampling_ablation_loss/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /circuits_benchmark/training/compression/activation_mapper/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | addopts = --ignore=submodules --ignore=tracr --ignore=acdc -------------------------------------------------------------------------------- /circuits_benchmark/utils/iit/__init__.py: -------------------------------------------------------------------------------- 1 | from .ll_cfg import make_ll_cfg_for_case 2 | -------------------------------------------------------------------------------- /circuits_benchmark/commands/algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | """Circuit discovery algorithms.""" 2 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | **/*.pkl 2 | **/__pycache__ 3 | **/*.pt 4 | **/*.pth 5 | wandb 6 | results 7 | k8s 8 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/program_evaluation_type.py: -------------------------------------------------------------------------------- 1 | def only_non_causal(func): 2 | return func 3 | 4 | 5 | def causal_and_regular(func): 6 | return func 7 | -------------------------------------------------------------------------------- 
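The two decorators in program_evaluation_type.py above are deliberate no-ops that only tag a program constructor with its evaluation type. Below is a minimal sketch of how such a marker can be applied; this is hypothetical usage written for illustration and is not taken from the repository's actual call sites.

```python
from tracr.rasp import rasp

from circuits_benchmark.benchmark.program_evaluation_type import causal_and_regular


@causal_and_regular
def make_double() -> rasp.SOp:
    # The decorator returns the function unchanged; it only marks the evaluation type.
    return rasp.Map(lambda x: x * 2, rasp.tokens).named("double")
```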
/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .env 3 | .idea 4 | *.pkl 5 | results 6 | *.pth 7 | .ipynb_checkpoints 8 | rough_* 9 | *.json 10 | *.png 11 | *.log 12 | *.csv 13 | .DS_Store 14 | wandb 15 | -------------------------------------------------------------------------------- /circuits_benchmark/utils/attr_dict.py: -------------------------------------------------------------------------------- 1 | class AttrDict(dict): 2 | """A dictionary that allows access to its values via attributes (e.g., using dot notation).""" 3 | 4 | def __getattr__(self, name): 5 | return self[name] 6 | -------------------------------------------------------------------------------- /circuits_benchmark/utils/circuit/circuit_granularity.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from typing import Literal 3 | 4 | CircuitGranularity = Literal["component", "matrix", "acdc_hooks", "sp_hooks"] 5 | circuit_granularity_options = list(typing.get_args(CircuitGranularity)) 6 | -------------------------------------------------------------------------------- /circuits_benchmark/utils/iit/iit_dataset_batch.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch as t 4 | 5 | Inputs = t.Tensor 6 | Targets = t.Tensor 7 | BaseData = Tuple[Inputs, Targets] 8 | AblationData = Tuple[Inputs, Targets] 9 | IITDatasetBatch = Tuple[BaseData, AblationData] 10 | -------------------------------------------------------------------------------- /circuits_benchmark/utils/circleci.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def is_running_in_circleci() -> bool: 5 | in_circleci = 'CIRCLECI' in os.environ 6 | 7 | if in_circleci: 8 | print('Running on CircleCI') 9 | 10 | return in_circleci 11 | 12 | 13 | def get_circleci_cases_percentage() -> float: 14 | return 0.25 15 | -------------------------------------------------------------------------------- /metadata/benchmark_base_metadata.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "InterpBench", 3 | "version": "1.0.0", 4 | "description": "A benchmark of transformers with known circuits for evaluating mechanistic interpretability techniques.", 5 | "license": "https://creativecommons.org/licenses/by/4.0/", 6 | "url": "https://huggingface.co/cybershiptrooper/InterpBench" 7 | } -------------------------------------------------------------------------------- /circuits_benchmark/utils/iit/tracr_model_pair.py: -------------------------------------------------------------------------------- 1 | from iit.model_pairs.strict_iit_model_pair import StrictIITModelPair 2 | from iit.utils import index 3 | 4 | 5 | class TracrModelPair(StrictIITModelPair): 6 | @staticmethod 7 | def get_label_idxs(): 8 | # Discard from all batches the first position, which is for the BOS token 9 | return index.Ix[:, 1:] 10 | -------------------------------------------------------------------------------- /tests/tracr_dataset_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from circuits_benchmark.benchmark.cases.case_3 import Case3 4 | from circuits_benchmark.benchmark.cases.case_9 import Case9 5 | 6 | 7 | class TestTracrDataset: 8 | 9 | @pytest.mark.parametrize("case", [Case3(), Case9()]) 10 | def test_get_encoded_dataset(self, 
case): 11 | data = case.get_clean_data() 12 | assert len(data) == 10 13 | -------------------------------------------------------------------------------- /circuits_benchmark/utils/cloudpickle.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from cloudpickle import cloudpickle 4 | 5 | 6 | def load_from_pickle(path) -> object | None: 7 | if os.path.exists(path): 8 | with open(path, "rb") as f: 9 | return cloudpickle.load(f) 10 | else: 11 | return None 12 | 13 | 14 | def dump_to_pickle(path, obj) -> None: 15 | with open(path, "wb") as f: 16 | cloudpickle.dump(obj, f) 17 | -------------------------------------------------------------------------------- /circuits_benchmark/utils/node_sp.py: -------------------------------------------------------------------------------- 1 | from acdc.docstring.utils import AllDataThings 2 | from subnetwork_probing.train import NodeLevelMaskedTransformer 3 | from subnetwork_probing.train import train_sp as train_node_sp 4 | 5 | 6 | def train_sp( 7 | args, 8 | masked_model: NodeLevelMaskedTransformer, 9 | all_task_things: AllDataThings, 10 | ): 11 | return train_node_sp( 12 | args=args, 13 | masked_model=masked_model, 14 | all_task_things=all_task_things, 15 | ) 16 | -------------------------------------------------------------------------------- /circuits_benchmark/utils/circuit/circuit_node_view.py: -------------------------------------------------------------------------------- 1 | from networkx.classes.reportviews import NodeView 2 | 3 | from circuits_benchmark.utils.circuit.circuit_node import CircuitNode 4 | 5 | 6 | class CircuitNodeView(NodeView): 7 | def __contains__(self, item: str | CircuitNode): 8 | if isinstance(item, str): 9 | return any([item == node.name for node in self._nodes]) 10 | elif isinstance(item, CircuitNode): 11 | return any([item == node for node in self._nodes]) 12 | else: 13 | return False 14 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/case_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from torch.utils.data import DataLoader, Dataset 4 | 5 | 6 | class CaseDataset(Dataset): 7 | def get_inputs(self): 8 | raise NotImplementedError() 9 | 10 | def get_targets(self): 11 | raise NotImplementedError() 12 | 13 | @staticmethod 14 | def collate_fn(batch): 15 | raise NotImplementedError() 16 | 17 | def make_loader( 18 | self, 19 | batch_size: int | None = None, 20 | shuffle: bool | None = False, 21 | ) -> DataLoader: 22 | raise NotImplementedError() 23 | -------------------------------------------------------------------------------- /tests/linear_compressed_tracr_transformer_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from circuits_benchmark.benchmark.cases.case_3 import Case3 4 | from circuits_benchmark.training.compression.linear_compressed_tracr_transformer import LinearCompressedTracrTransformer 5 | 6 | 7 | class LinearCompressedTracrTransformerTest(unittest.TestCase): 8 | def test_named_parameters_fold_compression_matrix(self): 9 | case = Case3() 10 | 11 | compressed_tracr_transformer = LinearCompressedTracrTransformer( 12 | case.get_hl_model(), 13 | int(9), 14 | "linear") 15 | 16 | list(compressed_tracr_transformer.named_parameters()) 17 | -------------------------------------------------------------------------------- 
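CaseDataset, shown a little earlier, only declares the interface that benchmark datasets implement. The following is a minimal hypothetical subclass, included purely to illustrate what the abstract methods are expected to provide; the class name, tensor layout, and loader settings are assumptions, not code from the repository.

```python
import torch
from torch.utils.data import DataLoader

from circuits_benchmark.benchmark.case_dataset import CaseDataset


class ToyCaseDataset(CaseDataset):
    """Hypothetical example: pairs each input tensor with a target tensor."""

    def __init__(self, inputs: torch.Tensor, targets: torch.Tensor):
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]

    def get_inputs(self):
        return self.inputs

    def get_targets(self):
        return self.targets

    @staticmethod
    def collate_fn(batch):
        xs, ys = zip(*batch)
        return torch.stack(xs), torch.stack(ys)

    def make_loader(self, batch_size=None, shuffle=False) -> DataLoader:
        return DataLoader(self, batch_size=batch_size, shuffle=shuffle, collate_fn=self.collate_fn)
```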
/circuits_benchmark/benchmark/cases/case_21.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.common_programs import make_unique_token_extractor 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case21(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | return make_unique_token_extractor(rasp.tokens) 13 | 14 | def get_task_description(self) -> str: 15 | return "Extract unique tokens from a string" 16 | 17 | def get_vocab(self) -> Set: 18 | return vocabs.get_ascii_letters_vocab(count=3) 19 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_46.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case46(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_decrement() 12 | 13 | def get_task_description(self) -> str: 14 | return "Decrements each element in the sequence by 1" 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | 20 | def make_decrement() -> rasp.SOp: 21 | return rasp.Map(lambda x: x - 1, rasp.tokens).named("decrement") 22 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_52.py: -------------------------------------------------------------------------------- 1 | from typing import Set, Sequence 2 | 3 | from circuits_benchmark.benchmark import vocabs 4 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 5 | from tracr.rasp import rasp 6 | 7 | 8 | class Case52(TracrBenchmarkCase): 9 | def get_program(self) -> rasp.SOp: 10 | return make_square_root() 11 | 12 | def get_task_description(self) -> str: 13 | return "Takes the square root of each element." 14 | 15 | def get_vocab(self) -> Set: 16 | return vocabs.get_int_numbers_vocab() 17 | 18 | 19 | def make_square_root() -> rasp.SOp: 20 | return rasp.Map(lambda x: x ** 0.5, rasp.tokens).named("square_root") 21 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_60.py: -------------------------------------------------------------------------------- 1 | from typing import Set, Sequence 2 | 3 | from circuits_benchmark.benchmark import vocabs 4 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 5 | from tracr.rasp import rasp 6 | 7 | 8 | class Case60(TracrBenchmarkCase): 9 | def get_program(self) -> rasp.SOp: 10 | return make_increment() 11 | 12 | def get_task_description(self) -> str: 13 | return "Increment each element in the sequence by 1." 
14 | 15 | def get_vocab(self) -> Set: 16 | return vocabs.get_int_numbers_vocab() 17 | 18 | 19 | def make_increment() -> rasp.SOp: 20 | return rasp.Map(lambda x: x + 1, rasp.tokens).named("increment") 21 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_8.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case8(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_identity() 12 | 13 | def get_task_description(self) -> str: 14 | return "Identity" 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | def is_trivial(self) -> bool: 20 | return True 21 | 22 | 23 | def make_identity() -> rasp.SOp: 24 | return rasp.Map(lambda x: x, rasp.tokens) 25 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_50.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Set, Sequence 3 | 4 | from circuits_benchmark.benchmark import vocabs 5 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 6 | from tracr.rasp import rasp 7 | 8 | 9 | class Case50(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_hyperbolic_cosine() 12 | 13 | def get_task_description(self) -> str: 14 | return "Applies the hyperbolic cosine to each element" 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | 20 | def make_hyperbolic_cosine() -> rasp.SOp: 21 | return rasp.Map(math.cosh, rasp.tokens) 22 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_121.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Set 3 | 4 | from tracr.rasp import rasp 5 | 6 | from circuits_benchmark.benchmark import vocabs 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case121(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | return make_arcsine() 13 | 14 | def get_task_description(self) -> str: 15 | return "Compute arcsine of all elements in the input sequence." 16 | 17 | def get_vocab(self) -> Set: 18 | return vocabs.get_float_numbers_vocab(min=-1, max=1) 19 | 20 | 21 | def make_arcsine(): 22 | return rasp.Map(lambda x: math.asin(x), rasp.tokens) 23 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_54.py: -------------------------------------------------------------------------------- 1 | from typing import Set, Sequence 2 | 3 | import math 4 | from circuits_benchmark.benchmark import vocabs 5 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 6 | from tracr.rasp import rasp 7 | 8 | 9 | class Case54(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_hyperbolic_tangent() 12 | 13 | def get_task_description(self) -> str: 14 | return "Applies the hyperbolic tangent to each element." 
15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | 20 | def make_hyperbolic_tangent() -> rasp.SOp: 21 | return rasp.Map(math.tanh, rasp.tokens) 22 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_122.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case122(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_check_divisibility() 12 | 13 | def get_task_description(self) -> str: 14 | return "Check if each number is divisible by 3." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab(max=30) 18 | 19 | 20 | def make_check_divisibility(divisor=3): 21 | return rasp.Map(lambda x: 1 if x % divisor == 0 else 0, rasp.tokens) 22 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_123.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Set 3 | 4 | from tracr.rasp import rasp 5 | 6 | from circuits_benchmark.benchmark import vocabs 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case123(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | return make_arccosine() 13 | 14 | def get_task_description(self) -> str: 15 | return "Apply arccosine to each element of the input sequence." 16 | 17 | def get_vocab(self) -> Set: 18 | return vocabs.get_float_numbers_vocab(min=-1, max=1) 19 | 20 | 21 | def make_arccosine(): 22 | return rasp.Map(lambda x: math.acos(x), rasp.tokens) 23 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_2.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.common_programs import make_reverse 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case2(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | return make_reverse(rasp.tokens) 13 | 14 | def get_task_description(self) -> str: 15 | return "Reverse the input sequence." 16 | 17 | def get_vocab(self) -> Set: 18 | return vocabs.get_ascii_letters_vocab() 19 | 20 | def supports_causal_masking(self) -> bool: 21 | return False 22 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_73.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Set 3 | 4 | from tracr.rasp import rasp 5 | 6 | from circuits_benchmark.benchmark import vocabs 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case73(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | return make_sine() 13 | 14 | def get_task_description(self) -> str: 15 | return "Apply the sine function to each element of the input sequence." 
16 | 17 | def get_vocab(self) -> Set: 18 | return vocabs.get_int_numbers_vocab() 19 | 20 | 21 | def make_sine() -> rasp.SOp: 22 | return rasp.Map(lambda x: math.sin(x), rasp.tokens).named("sine") 23 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_101.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case101(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_check_square() 12 | 13 | def get_task_description(self) -> str: 14 | return "Check if each element is a square of an integer." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab(max=30) 18 | 19 | 20 | def make_check_square() -> rasp.SOp: 21 | return rasp.Map(lambda x: 1 if x ** 0.5 == int(x ** 0.5) else 0, rasp.tokens) 22 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_77.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Set 3 | 4 | from tracr.rasp import rasp 5 | 6 | from circuits_benchmark.benchmark import vocabs 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case77(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | return make_tangent() 13 | 14 | def get_task_description(self) -> str: 15 | return "Apply the tangent function to each element of the sequence." 16 | 17 | def get_vocab(self) -> Set: 18 | return vocabs.get_int_numbers_vocab() 19 | 20 | 21 | def make_tangent() -> rasp.SOp: 22 | return rasp.Map(lambda x: math.tan(x), rasp.tokens).named("tangent") 23 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_85.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case85(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_square_each_element() 12 | 13 | def get_task_description(self) -> str: 14 | return "Square each element of the input sequence." 
15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | 20 | def make_square_each_element() -> rasp.SOp: 21 | return rasp.Map(lambda x: x ** 2, rasp.tokens).named("square_each_element") 22 | -------------------------------------------------------------------------------- /circuits_benchmark/utils/ll_model_loader/natural_model_loader.py: -------------------------------------------------------------------------------- 1 | from circuits_benchmark.utils.ll_model_loader.siit_model_loader import SIITModelLoader 2 | 3 | 4 | class NaturalModelLoader(SIITModelLoader): 5 | """Natural model loader is just an SIIT model loader for weights 100.""" 6 | 7 | def __init__(self, 8 | case, 9 | load_from_wandb: bool | None = False, 10 | wandb_project: str | None = None, 11 | wandb_name: str | None = None): 12 | super().__init__( 13 | case, 14 | "100", 15 | load_from_wandb=load_from_wandb, 16 | wandb_project=wandb_project, 17 | wandb_name=wandb_name 18 | ) 19 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_4.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.common_programs import make_pair_balance 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case4(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | return make_pair_balance(rasp.tokens, "(", ")") 13 | 14 | def get_task_description(self) -> str: 15 | return "Return fraction of previous open tokens minus the fraction of close tokens." 16 | 17 | def get_vocab(self) -> Set: 18 | return vocabs.get_ascii_letters_vocab(count=3).union({"(", ")"}) 19 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_41.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case41(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_absolute() 12 | 13 | def get_task_description(self) -> str: 14 | return "Make each element of the input sequence absolute" 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab(min=-10, max=10) 18 | 19 | 20 | def make_absolute() -> rasp.SOp: 21 | return rasp.Map(lambda x: x if x >= 0 else -x, rasp.tokens).named("make_absolute") 22 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_70.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Set 3 | 4 | from tracr.rasp import rasp 5 | 6 | from circuits_benchmark.benchmark import vocabs 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case70(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | return make_cosine() 13 | 14 | def get_task_description(self) -> str: 15 | return "Apply the cosine function to each element of the input sequence." 
16 | 17 | def get_vocab(self) -> Set: 18 | return vocabs.get_int_numbers_vocab() 19 | 20 | 21 | def make_cosine() -> rasp.SOp: 22 | return rasp.Map(lambda x: math.cos(x), rasp.tokens).named("apply_cosine") 23 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_7.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.common_programs import make_hist 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case7(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | return make_hist() 13 | 14 | def get_task_description(self) -> str: 15 | return "Returns the number of times each token occurs in the input." 16 | 17 | def supports_causal_masking(self) -> bool: 18 | return False 19 | 20 | def get_vocab(self) -> Set: 21 | return vocabs.get_ascii_letters_vocab(count=3) 22 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_12.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.common_programs import detect_pattern 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case12(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | return detect_pattern(rasp.tokens, "abc") 13 | 14 | def get_task_description(self) -> str: 15 | return "Detect the pattern 'abc' in the input string." 16 | 17 | def get_vocab(self) -> Set: 18 | return vocabs.get_ascii_letters_vocab(count=3) 19 | 20 | def get_max_seq_len(self) -> int: 21 | return 15 22 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_84.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Set 3 | 4 | from tracr.rasp import rasp 5 | 6 | from circuits_benchmark.benchmark import vocabs 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case84(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | return make_arctangent() 13 | 14 | def get_task_description(self) -> str: 15 | return "Apply the arctangent function to each element of the input sequence." 
16 | 17 | def get_vocab(self) -> Set: 18 | return vocabs.get_int_numbers_vocab() 19 | 20 | 21 | def make_arctangent() -> rasp.SOp: 22 | return rasp.Map(lambda x: math.atan(x), rasp.tokens).named("arctangent") 23 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_9.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.common_programs import make_sort 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case9(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | return make_sort(rasp.tokens, rasp.tokens, 10, 1) 13 | 14 | def get_task_description(self) -> str: 15 | return "Sort a list of integers in ascending order." 16 | 17 | def supports_causal_masking(self) -> bool: 18 | return False 19 | 20 | def get_vocab(self) -> Set: 21 | return vocabs.get_int_digits_vocab() 22 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_104.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Set 3 | 4 | from tracr.rasp import rasp 5 | 6 | from circuits_benchmark.benchmark import vocabs 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case104(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | return make_exponential() 13 | 14 | def get_task_description(self) -> str: 15 | return "Apply exponential function to all elements of the input sequence." 
16 | 17 | def get_vocab(self) -> Set: 18 | return vocabs.get_int_numbers_vocab() 19 | 20 | 21 | def make_exponential() -> rasp.SOp: 22 | exp_approx = rasp.Map(lambda x: math.exp(x), rasp.tokens) 23 | return exp_approx 24 | -------------------------------------------------------------------------------- /circuits_benchmark/metrics/sparsity.py: -------------------------------------------------------------------------------- 1 | from circuits_benchmark.training.compression.linear_compressed_tracr_transformer import LinearCompressedTracrTransformer 2 | 3 | 4 | def get_zero_weights_pct(model, atol=1e-8): 5 | """Returns the percentage of weights that are zero (or very close to zero given the atol).""" 6 | weights = [] 7 | if isinstance(model, LinearCompressedTracrTransformer): 8 | params = list(model.folded_parameters()) 9 | else: 10 | params = list(model.parameters()) 11 | 12 | for param in params: 13 | weights.extend(param.flatten().tolist()) 14 | 15 | non_zero_weights = len([w for w in weights if abs(w) > atol]) 16 | non_zero_weights_pct = non_zero_weights / len(weights) 17 | 18 | return 1 - non_zero_weights_pct 19 | -------------------------------------------------------------------------------- /circuits_benchmark/utils/ll_model_loader/ll_model_loader.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | from iit.utils.correspondence import Correspondence 4 | from transformer_lens import HookedTransformer 5 | 6 | from circuits_benchmark.benchmark.benchmark_case import BenchmarkCase 7 | 8 | 9 | class LLModelLoader(object): 10 | def __init__(self, case: BenchmarkCase): 11 | self.case = case 12 | 13 | def load_ll_model_and_correspondence( 14 | self, 15 | device: str, 16 | output_dir: Optional[str] = None, 17 | same_size: bool = False, 18 | *args, **kwargs 19 | ) -> Tuple[Correspondence, HookedTransformer]: 20 | raise NotImplementedError() 21 | 22 | def get_output_suffix(self): 23 | raise NotImplementedError() 24 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_75.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case75(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_element_double() 12 | 13 | def get_task_description(self) -> str: 14 | return "Double each element of the input sequence." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | 20 | def make_element_double() -> rasp.SOp: 21 | # Apply the doubling function to each element of the input sequence. 
22 | return rasp.Map(lambda x: x * 2, rasp.tokens).named("double_elements") 23 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_55.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Set, Sequence 3 | 4 | from circuits_benchmark.benchmark import vocabs 5 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 6 | from tracr.rasp import rasp 7 | 8 | 9 | class Case55(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_hyperbolic_sine() 12 | 13 | def get_task_description(self) -> str: 14 | return "Applies the hyperbolic sine to each element." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | 20 | def make_hyperbolic_sine() -> rasp.SOp: 21 | hyperbolic_sine = rasp.Map(lambda x: math.sinh(x), rasp.tokens).named("hyperbolic_sine") 22 | 23 | return hyperbolic_sine 24 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_49.py: -------------------------------------------------------------------------------- 1 | from typing import Set, Sequence 2 | 3 | from circuits_benchmark.benchmark import vocabs 4 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 5 | from tracr.rasp import rasp 6 | 7 | 8 | class Case49(TracrBenchmarkCase): 9 | def get_program(self) -> rasp.SOp: 10 | return make_decrement_to_multiple_of_three() 11 | 12 | def get_task_description(self) -> str: 13 | return "Decrements each element in the sequence until it becomes a multiple of 3." 14 | 15 | def get_vocab(self) -> Set: 16 | return vocabs.get_int_numbers_vocab() 17 | 18 | 19 | def make_decrement_to_multiple_of_three() -> rasp.SOp: 20 | return rasp.Map(lambda x: x - x % 3, rasp.tokens).named("decrement_to_multiple_of_three") 21 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_91.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case91(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_apply_threshold() 12 | 13 | def get_task_description(self) -> str: 14 | return "Set all values below a threshold to 0" 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | 20 | def make_apply_threshold(threshold=3) -> rasp.SOp: 21 | apply_threshold_operation = rasp.Map(lambda x: 0 if x < threshold else x, rasp.tokens).named("apply_threshold") 22 | 23 | return apply_threshold_operation 24 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_106.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case106(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_mask_sequence() 12 | 13 | def get_task_description(self) -> str: 14 | return "Sets all elements to zero except for the element at 
index 1." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | def is_trivial(self) -> bool: 20 | return True 21 | 22 | 23 | def make_mask_sequence(index=1) -> rasp.SOp: 24 | return rasp.SequenceMap(lambda x, y: x if y == index else 0, rasp.tokens, rasp.indices) 25 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_64.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case64(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_cube_each_element() 12 | 13 | def get_task_description(self) -> str: 14 | return "Cubes each element in the sequence." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | 20 | def make_cube_each_element() -> rasp.SOp: 21 | # Apply the cubing function to each element of the input sequence. 22 | cube_sequence = rasp.Map(lambda x: x ** 3, rasp.tokens).named("cube_sequence") 23 | 24 | return cube_sequence 25 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_80.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case80(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_element_subtract_constant() 12 | 13 | def get_task_description(self) -> str: 14 | return "Subtract a constant from each element of the input sequence." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | 20 | def make_element_subtract_constant(constant=2) -> rasp.SOp: 21 | subtract_constant = rasp.Map(lambda x: x - constant, rasp.tokens).named(f"subtract_{constant}") 22 | 23 | return subtract_constant 24 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_87.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case87(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_binarize() 12 | 13 | def get_task_description(self) -> str: 14 | return "Binarize a sequence of integers using a threshold." 
15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | 20 | def make_binarize(threshold=3) -> rasp.SOp: 21 | compare_to_threshold = rasp.Map(lambda x: x >= threshold, rasp.tokens) 22 | binarized_sequence = rasp.Map(lambda x: 1 if x else 0, compare_to_threshold) 23 | return binarized_sequence 24 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_68.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case68(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_increment_to_multiple_of_three() 12 | 13 | def get_task_description(self) -> str: 14 | return "Increment each element until it becomes a multiple of 3" 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | 20 | def make_increment_to_multiple_of_three() -> rasp.SOp: 21 | increment_to_multiple_of_three = rasp.Map(lambda x: x + (3 - (x % 3)) % 3, rasp.tokens).named( 22 | "increment_to_multiple_of_three") 23 | 24 | return increment_to_multiple_of_three 25 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_79.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case79(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_check_prime() 12 | 13 | def get_task_description(self) -> str: 14 | return "Check if each number in a sequence is prime" 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | 20 | def primecheck(n): 21 | if n < 2: 22 | return 0 23 | for i in range(2, int(n ** 0.5) + 1): 24 | if n % i == 0: 25 | return 0 26 | return 1 27 | 28 | 29 | def make_check_prime() -> rasp.SOp: 30 | return rasp.Map(lambda x: primecheck(x), rasp.tokens) 31 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_114.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Set 3 | 4 | from tracr.rasp import rasp 5 | 6 | from circuits_benchmark.benchmark import vocabs 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case114(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | return make_logarithm() 13 | 14 | def get_task_description(self) -> str: 15 | return "Apply a logarithm base 10 to each element of the input sequence." 
16 | 17 | def get_vocab(self) -> Set: 18 | return vocabs.get_int_numbers_vocab(min=1) 19 | 20 | 21 | def make_logarithm() -> rasp.SOp: 22 | def apply_log(element): 23 | return math.log(element, 10) 24 | 25 | # Applying the placeholder logarithm function to each element 26 | return rasp.Map(apply_log, rasp.tokens).named("logarithm") 27 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_3.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.common_programs import make_frac_prevs 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case3(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | is_x = (rasp.tokens == "x").named("is_x") 13 | return make_frac_prevs(is_x) 14 | 15 | def get_task_description(self) -> str: 16 | return "Returns the fraction of 'x' in the input up to the i-th position for all i." 17 | 18 | def get_vocab(self) -> Set: 19 | some_letters = vocabs.get_ascii_letters_vocab(count=3) 20 | some_letters.add("x") 21 | return some_letters 22 | 23 | def get_max_seq_len(self) -> int: 24 | return 5 25 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_48.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case48(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_increment_by_index() 12 | 13 | def get_task_description(self) -> str: 14 | return "Increments each element by its index." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | 20 | def make_increment_by_index() -> rasp.SOp: 21 | # This operation adds each element of the input sequence to its corresponding index. 22 | incremented_sequence = rasp.SequenceMap(lambda x, y: x + y, rasp.tokens, rasp.indices).named("incremented_sequence") 23 | 24 | return incremented_sequence 25 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_65.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case65(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_cube_root() 12 | 13 | def get_task_description(self) -> str: 14 | return "Calculate the cube root of each element in the input sequence." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | 20 | def make_cube_root() -> rasp.SOp: 21 | # Define the cube root operation to be applied to each element. 
22 | cube_root_operation = rasp.Map(lambda x: x ** (1 / 3) if x >= 0 else -(-x) ** (1 / 3), rasp.tokens).named( 23 | "cube_root_operation") 24 | 25 | return cube_root_operation 26 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | import logging 3 | import sys 4 | 5 | import jax 6 | 7 | from circuits_benchmark.commands.build_main_parser import build_main_parser 8 | from circuits_benchmark.commands.algorithms import run_algorithm 9 | from circuits_benchmark.commands.train import train 10 | from circuits_benchmark.commands.evaluation import evaluation 11 | 12 | # The default of float16 can lead to discrepancies between outputs of 13 | # the compiled model and the RASP program. 14 | jax.config.update('jax_default_matmul_precision', 'float32') 15 | logging.basicConfig(level=logging.ERROR) 16 | 17 | if __name__ == "__main__": 18 | parser = build_main_parser() 19 | args, _ = parser.parse_known_args(sys.argv[1:]) 20 | 21 | if args.command == "run": 22 | run_algorithm.run(args) 23 | elif args.command == "train": 24 | train.run(args) 25 | elif args.command == "eval": 26 | evaluation.run(args) 27 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_72.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case72(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_negation() 12 | 13 | def get_task_description(self) -> str: 14 | return "Negate each element in the input sequence." 
15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab(min=-10, max=10) 18 | 19 | 20 | def make_negation() -> rasp.SOp: 21 | # Define the negation function 22 | negation_function = lambda x: -x 23 | 24 | # Apply the negation function element-wise to the input sequence 25 | negated_sequence = rasp.Map(negation_function, rasp.tokens).named("negated_sequence") 26 | 27 | return negated_sequence 28 | -------------------------------------------------------------------------------- /tests/get_cases_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from circuits_benchmark.utils.attr_dict import AttrDict 4 | from circuits_benchmark.utils.get_cases import get_cases 5 | 6 | 7 | class GetCasesTest(unittest.TestCase): 8 | def test_get_all_cases(self): 9 | cases = get_cases() 10 | names = [case.get_name() for case in cases] 11 | print(names) 12 | 13 | assert len(names) > 35 14 | assert "ioi" in names 15 | assert "ioi_next_token" in names 16 | assert "3" in names 17 | assert "37" in names 18 | 19 | def test_cases_filtered_by_indices(self): 20 | args = AttrDict({"indices": "1,2,3"}) 21 | cases = get_cases(args) 22 | self.assertEqual(len(cases), 3) 23 | 24 | def test_get_cases_works_for_ioi_cases(self): 25 | args = AttrDict({"indices": "ioi,ioi_next_token"}) 26 | cases = get_cases(args) 27 | self.assertEqual(len(cases), 2) 28 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_110.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case110(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_insert_zeros() 12 | 13 | def get_task_description(self) -> str: 14 | return "Inserts zeros between each element, removing the latter half of the list." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | def supports_causal_masking(self) -> bool: 20 | return False 21 | 22 | 23 | def make_insert_zeros() -> rasp.SOp: 24 | shifter = rasp.Select(rasp.indices, rasp.indices, lambda x, y: x == int(y / 2)) 25 | shifted = rasp.Aggregate(shifter, rasp.tokens) 26 | return rasp.SequenceMap(lambda x, y: x if y % 2 == 0 else 0, shifted, rasp.indices) 27 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_86.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case86(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_check_power_of_n() 12 | 13 | def get_task_description(self) -> str: 14 | return "Check if each element is a power of 2. Return 1 if true, otherwise 0." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | 20 | def pow_of_n(n, x): 21 | while x >= 1: 22 | if x == n: 23 | return 1 24 | x /= n 25 | return 0 26 | 27 | 28 | def make_check_power_of_n(n=2) -> rasp.SOp: 29 | # Check if each element is a power of n. Return 1 if true, otherwise 0. 
30 | return rasp.Map(lambda x: pow_of_n(n, x), rasp.tokens).named(f"check_multiple_of_{n}") 31 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_83.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case83(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_triple() 12 | 13 | def get_task_description(self) -> str: 14 | return "Triple each element in the sequence." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | 20 | def make_triple() -> rasp.SOp: 21 | # Define a lambda function that triples the value of its input. 22 | triple_func = lambda x: x * 3 23 | 24 | # Apply the triple_func to each element of the sequence using Map. 25 | triple_sequence = rasp.Map(triple_func, rasp.tokens).named("triple_sequence") 26 | 27 | # Return the SOp that triples each element in the sequence. 28 | return triple_sequence 29 | -------------------------------------------------------------------------------- /circuits_benchmark/utils/wandb_artifact_download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from pathlib import Path 4 | from typing import List 5 | 6 | import wandb 7 | 8 | from circuits_benchmark.utils.project_paths import get_default_output_dir 9 | 10 | default_artifacts_output_dir = os.path.join(get_default_output_dir(), "artifacts") 11 | 12 | 13 | def download_artifact(project_name: str, 14 | artifact_name: str, 15 | output_dir: str = default_artifacts_output_dir) -> List[Path]: 16 | if os.path.exists(output_dir): 17 | # remove dir to clean previous downloads 18 | shutil.rmtree(output_dir) 19 | 20 | if not os.path.exists(output_dir): 21 | os.makedirs(output_dir) 22 | 23 | api = wandb.Api() 24 | artifact = api.artifact(f"{project_name}/{artifact_name}:latest") 25 | artifact.download(root=output_dir) 26 | 27 | # return the name of the downloaded files (path objects) 28 | return list(Path(output_dir).iterdir()) 29 | -------------------------------------------------------------------------------- /circuits_benchmark/training/compression/activation_mapper/linear_mapper.py: -------------------------------------------------------------------------------- 1 | from jaxtyping import Float 2 | from torch import Tensor 3 | from torch.nn import Linear 4 | 5 | from circuits_benchmark.training.compression.activation_mapper.activation_mapper import ActivationMapper 6 | 7 | 8 | class LinearMapper(ActivationMapper): 9 | """Maps the residual stream to/from a lower dimensional space using a linear layer.""" 10 | 11 | def __init__(self, compression_matrix: Linear): 12 | self.compression_matrix = compression_matrix 13 | 14 | def compress( 15 | self, 16 | residual_stream: Float[Tensor, "batch d_model"] 17 | ) -> Float[Tensor, "batch d_model_compressed"]: 18 | return residual_stream @ self.compression_matrix.weight 19 | 20 | def decompress( 21 | self, 22 | compressed_residual_stream: Float[Tensor, "batch d_model_compressed"] 23 | ) -> Float[Tensor, "batch d_model"]: 24 | return compressed_residual_stream @ self.compression_matrix.weight.T 25 | -------------------------------------------------------------------------------- 
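[Editorial note: a minimal usage sketch of the LinearMapper above, not part of the original source. It assumes the compression Linear is constructed with in_features equal to the compressed size and out_features equal to d_model, which is the orientation the matmuls in compress/decompress require; how the training code actually builds it is not shown in this file.]

    import torch
    from torch.nn import Linear
    from circuits_benchmark.training.compression.activation_mapper.linear_mapper import LinearMapper

    d_model, d_compressed = 16, 4
    mapper = LinearMapper(Linear(d_compressed, d_model, bias=False))

    resid = torch.randn(8, d_model)       # (batch, d_model)
    low_dim = mapper.compress(resid)      # (batch, d_compressed)
    approx = mapper.decompress(low_dim)   # (batch, d_model)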
/circuits_benchmark/benchmark/cases/case_69.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case69(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_sign() 12 | 13 | def get_task_description(self) -> str: 14 | return "Assign -1, 0, or 1 to each element of the input sequence based on its sign." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | 20 | def make_sign() -> rasp.SOp: 21 | # Define the function to apply to each element 22 | def sign_check(x): 23 | if x < 0: 24 | return -1 25 | elif x > 0: 26 | return 1 27 | else: 28 | return 0 29 | 30 | # Apply the sign checking function to each element of the input sequence 31 | return rasp.Map(sign_check, rasp.tokens).named("sign") 32 | -------------------------------------------------------------------------------- /circuits_benchmark/training/compression/activation_mapper/activation_mapper.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | 7 | class ActivationMapper(ABC): 8 | """Maps activations to/from a lower dimensional space.""" 9 | 10 | @abstractmethod 11 | def compress( 12 | self, 13 | activation: Float[Tensor, "batch pos d_model"] | Float[Tensor, "batch pos d_head"] 14 | ) -> Float[Tensor, "batch pos d_model_compressed"] | Float[Tensor, "batch pos d_head_comprssed"]: 15 | """Compresses an activation.""" 16 | raise NotImplementedError 17 | 18 | @abstractmethod 19 | def decompress( 20 | self, 21 | compressed_activation: Float[Tensor, "batch pos d_model_compressed"] | Float[ 22 | Tensor, "batch pos d_head_comprssed"] 23 | ) -> Float[Tensor, "batch pos d_model"] | Float[Tensor, "batch pos d_head"]: 24 | """Decompresses a compressed activation.""" 25 | raise NotImplementedError 26 | -------------------------------------------------------------------------------- /circuits_benchmark/utils/hf/hf_uploader.py: -------------------------------------------------------------------------------- 1 | from transformer_lens import HookedTransformerConfig, HookedTransformer 2 | from transformers import PreTrainedModel, PretrainedConfig 3 | 4 | 5 | class LLModelConfig(PretrainedConfig): 6 | def __init__(self, cfg: HookedTransformerConfig, 7 | task: str, 8 | tracr: bool = True, 9 | **kwargs): 10 | super().__init__(**kwargs) 11 | self.cfg = cfg.to_dict() 12 | self.task = task 13 | self.tracr = False 14 | 15 | 16 | class TLModel(PreTrainedModel): 17 | def __init__( 18 | self, 19 | tl_model: HookedTransformer, 20 | config: LLModelConfig, 21 | *args, 22 | **kwargs 23 | ): 24 | super().__init__(config, *args, **kwargs) 25 | self.model = tl_model 26 | self.config = config 27 | self.model.to("cpu") 28 | 29 | def forward(self, *args, **kwargs): 30 | raise NotImplementedError("This model is not meant to be used for forward pass.") 31 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_66.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from 
circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case66(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_round() 12 | 13 | def get_task_description(self) -> str: 14 | return "Round each element in the input sequence to the nearest integer." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_float_numbers_vocab() 18 | 19 | 20 | def make_round() -> rasp.SOp: 21 | # Step 1: Add 0.5 to each element, shifting the decimal for rounding. 22 | shift_for_rounding = rasp.Map(lambda x: x + 0.5, rasp.tokens).named("shift_for_rounding") 23 | 24 | # Step 2: Convert each shifted element to an integer, effectively rounding it. 25 | rounded_sequence = rasp.Map(lambda x: int(x), shift_for_rounding).named("rounded_sequence") 26 | 27 | return rounded_sequence 28 | -------------------------------------------------------------------------------- /circuits_benchmark/training/compression/activation_mapper/autoencoder_mapper.py: -------------------------------------------------------------------------------- 1 | from jaxtyping import Float 2 | from torch import Tensor 3 | 4 | from circuits_benchmark.training.compression.activation_mapper.activation_mapper import ActivationMapper 5 | from circuits_benchmark.training.compression.autencoder import AutoEncoder 6 | 7 | 8 | class AutoEncoderMapper(ActivationMapper): 9 | """Maps the residual stream to/from a lower dimensional space using an autoencoder.""" 10 | 11 | def __init__(self, autoencoder: AutoEncoder): 12 | self.autoencoder = autoencoder 13 | 14 | def compress( 15 | self, 16 | residual_stream: Float[Tensor, "batch d_model"] 17 | ) -> Float[Tensor, "batch d_model_compressed"]: 18 | return self.autoencoder.encoder(residual_stream) 19 | 20 | def decompress( 21 | self, 22 | compressed_residual_stream: Float[Tensor, "batch d_model_compressed"] 23 | ) -> Float[Tensor, "batch d_model"]: 24 | return self.autoencoder.decoder(compressed_residual_stream) 25 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_40.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case40(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_sum_digits() 12 | 13 | def get_task_description(self) -> str: 14 | return "Sum the last and previous to last digits of a number" 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab(min=0, max=29) 18 | 19 | 20 | def make_sum_digits() -> rasp.SOp: 21 | # Isolate the tens place 22 | tens_place = rasp.Map(lambda x: x // 10, rasp.tokens).named("tens_place") 23 | # Isolate the ones place 24 | ones_place = rasp.Map(lambda x: x % 10, rasp.tokens).named("ones_place") 25 | # Sum the tens and ones places 26 | sum_digits = rasp.SequenceMap(lambda x, y: x + y, tens_place, ones_place).named("sum_digits") 27 | 28 | return sum_digits 29 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_53.py: -------------------------------------------------------------------------------- 1 | from typing import Set, Sequence 2 | 3 | from circuits_benchmark.benchmark import vocabs 4 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 
5 | from tracr.rasp import rasp 6 | 7 | 8 | class Case53(TracrBenchmarkCase): 9 | def get_program(self) -> rasp.SOp: 10 | return make_increment_odd_indices() 11 | 12 | def get_task_description(self) -> str: 13 | return "Increment elements at odd indices by 1" 14 | 15 | def get_vocab(self) -> Set: 16 | return vocabs.get_int_numbers_vocab() 17 | 18 | 19 | def make_increment_odd_indices() -> rasp.SOp: 20 | # Marks odd indices with 1 and even indices with 0 21 | odd_index_marker = rasp.Map(lambda x: x % 2, rasp.indices).named("odd_index_marker") 22 | 23 | # Increment elements at odd indices by 1 24 | incremented_elements = rasp.SequenceMap( 25 | lambda elem, mark: elem + mark, rasp.tokens, odd_index_marker 26 | ).named("incremented_elements") 27 | 28 | return incremented_elements 29 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_28.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case28(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_token_mirroring(rasp.tokens) 12 | 13 | def get_task_description(self) -> str: 14 | return "Mirrors each word in the sequence around its central axis." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_words_vocab(min_chars=4, max_words=50) 18 | 19 | 20 | def make_token_mirroring(sop: rasp.SOp) -> rasp.SOp: 21 | """ 22 | Mirrors each token in the sequence around its central axis. 23 | 24 | Example usage: 25 | token_mirror = make_token_mirroring(rasp.tokens) 26 | token_mirror(["abc", "def", "ghi"]) 27 | >> ["cba", "fed", "ihg"] 28 | """ 29 | mirrored_sop = rasp.Map(lambda x: x[::-1] if x is not None else None, sop) 30 | return mirrored_sop 31 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_23.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case23(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_palindrome_word_spotter(rasp.tokens) 12 | 13 | def get_task_description(self) -> str: 14 | return "Returns palindrome words in a sequence." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_words_vocab().union({"racecar", "noon"}) 18 | 19 | 20 | def make_palindrome_word_spotter(sop: rasp.SOp) -> rasp.SOp: 21 | """ 22 | Spots palindrome words in a sequence. 
23 | 24 | Example usage: 25 | palindrome_spotter = make_palindrome_word_spotter(rasp.tokens) 26 | palindrome_spotter(["racecar", "hello", "noon"]) 27 | >> ["racecar", None, "noon"] 28 | """ 29 | is_palindrome = rasp.Map(lambda x: x if x == x[::-1] else None, sop) 30 | return is_palindrome 31 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_62.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case62(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_factorial() 12 | 13 | def get_task_description(self) -> str: 14 | return "Replaces each element with its factorial." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | 20 | def make_factorial() -> rasp.SOp: 21 | factorial = rasp.Map(lambda x: factorial_helper(x), rasp.tokens).named("factorial") 22 | return factorial 23 | 24 | 25 | def factorial_helper(n: int) -> int: 26 | # Placeholder for factorial calculation 27 | # In actual RASP code, this function cannot exist due to RASP's limitations. 28 | # This represents a conceptual step that needs a workaround. 29 | if n == 0: 30 | return 1 31 | else: 32 | return n * factorial_helper(n - 1) 33 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_116.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case116(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_check_multiple_of_first() 12 | 13 | def get_task_description(self) -> str: 14 | return "Checks if each element in a sequence is a multiple of the first one." 
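[Editorial note: an illustrative trace of this task, assuming the direct RASP evaluation convention used in other case files; the program itself, make_check_multiple_of_first, is defined at the bottom of this file.]

    program = make_check_multiple_of_first()
    program([3, 6, 7, 9])
    # >> [1, 1, 0, 1]   every element is tested against the first element (3);
    #                   7 is not a multiple of 3. A leading 0 maps everything to 0.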
15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | def supports_causal_masking(self) -> bool: 20 | return False 21 | 22 | def get_max_seq_len(self) -> int: 23 | return len(self.get_vocab()) 24 | 25 | 26 | def make_check_multiple_of_first(): 27 | first = rasp.Aggregate(rasp.Select(rasp.indices, rasp.Map(lambda x: 0, rasp.indices), rasp.Comparison.EQ), 28 | rasp.tokens) 29 | return rasp.SequenceMap(lambda x, y: (1 if x % y == 0 else 0) if y != 0 else 0, rasp.tokens, first) 30 | -------------------------------------------------------------------------------- /circuits_benchmark/utils/circuit/circuit_node.py: -------------------------------------------------------------------------------- 1 | class CircuitNode(object): 2 | def __init__(self, name: str, index: int | None = None): 3 | self.name = name 4 | self.index = index 5 | 6 | def __str__(self): 7 | return f"{self.name}[{self.index}]" if self.index is not None else self.name 8 | 9 | def __repr__(self): 10 | return str(self) 11 | 12 | def __hash__(self): 13 | return hash(str(self)) 14 | 15 | def __eq__(self, other): 16 | if not isinstance(other, CircuitNode): 17 | return False 18 | 19 | return self.name == other.name and self.index == other.index 20 | 21 | def __lt__(self, other): 22 | if not isinstance(other, CircuitNode): 23 | raise ValueError(f"Expected a CircuitNode, got {type(other)}") 24 | 25 | if self.name != other.name: 26 | return self.name < other.name 27 | elif self.index is None: 28 | return False 29 | elif other.index is None: 30 | return True 31 | else: 32 | return self.index < other.index 33 | -------------------------------------------------------------------------------- /circuits_benchmark/utils/ll_model_loader/ground_truth_model_loader.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | from iit.utils.correspondence import Correspondence 4 | from transformer_lens import HookedTransformer 5 | 6 | from circuits_benchmark.utils.ll_model_loader.ll_model_loader import LLModelLoader 7 | 8 | 9 | class GroundTruthModelLoader(LLModelLoader): 10 | def get_output_suffix(self) -> str: 11 | return self.__str__() 12 | 13 | def __repr__(self) -> str: 14 | return self.__str__() 15 | 16 | def __str__(self) -> str: 17 | return f"ground_truth" 18 | 19 | def load_ll_model_and_correspondence( 20 | self, 21 | device: str, 22 | output_dir: Optional[str] = None, 23 | same_size: bool = False, 24 | *args, **kwargs 25 | ) -> Tuple[Correspondence, HookedTransformer]: 26 | assert not same_size, "Ground truth models are never same size" 27 | 28 | hl_model = self.case.get_hl_model(device=device) 29 | corr = self.case.get_correspondence(same_size=True) # tracr models are always same size 30 | return corr, hl_model 31 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_29.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case29(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_token_abbreviation(rasp.tokens) 12 | 13 | def get_task_description(self) -> str: 14 | return "Creates abbreviations for each token in the sequence." 
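[Editorial note: stepping back to the CircuitNode helper shown a little earlier, a short sketch of its equality, hashing, and ordering semantics; the hook-style node names are hypothetical.]

    a = CircuitNode("blocks.0.attn.hook_z", 1)
    b = CircuitNode("blocks.0.attn.hook_z")    # no head index
    c = CircuitNode("blocks.0.attn.hook_z", 1)

    a == c            # True  (same name and index)
    len({a, b, c})    # 2     (a and c hash to the same string form)
    sorted([b, a])    # a sorts before b: for equal names, a concrete index precedes None
    str(a), str(b)    # ('blocks.0.attn.hook_z[1]', 'blocks.0.attn.hook_z')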
15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_words_vocab() 18 | 19 | def is_trivial(self) -> bool: 20 | return True 21 | 22 | 23 | def make_token_abbreviation(sop: rasp.SOp) -> rasp.SOp: 24 | """ 25 | Creates abbreviations for each token in the sequence. 26 | 27 | Example usage: 28 | token_abbreviation = make_token_abbreviation(rasp.tokens) 29 | token_abbreviation(["international", "business", "machines"]) 30 | >> ["int", "bus", "mac"] 31 | """ 32 | abbreviation = rasp.Map(lambda x: x[:3] if len(x) > 3 else x, sop) 33 | return abbreviation 34 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_33.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case33(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_token_length_parity_checker(rasp.tokens) 12 | 13 | def get_task_description(self) -> str: 14 | return "Checks if each token's length is odd or even." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_words_vocab() 18 | 19 | def is_trivial(self) -> bool: 20 | return True 21 | 22 | 23 | def make_token_length_parity_checker(sop: rasp.SOp) -> rasp.SOp: 24 | """ 25 | Checks if each token's length is odd or even. 26 | 27 | Example usage: 28 | length_parity = make_token_length_parity_checker(rasp.tokens) 29 | length_parity(["hello", "worlds", "!", "2022"]) 30 | >> [False, True, False, True] 31 | """ 32 | length_parity_checker = rasp.Map(lambda x: len(x) % 2 == 0, sop) 33 | return length_parity_checker 34 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_103.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case103(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_swap_consecutive() 12 | 13 | def get_task_description(self) -> str: 14 | return "Swap consecutive numbers in a list" 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | def supports_causal_masking(self) -> bool: 20 | return False 21 | 22 | 23 | def make_swap_consecutive() -> rasp.SOp: 24 | len = rasp.SelectorWidth(rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.TRUE)) 25 | swaper = rasp.SequenceMap(lambda x, y: x if (x == y - 1 and y % 2 == 1) else (x + 1 if x % 2 == 0 else x - 1), 26 | rasp.indices, len) 27 | swap_selector = rasp.Select(rasp.indices, swaper, rasp.Comparison.EQ) 28 | swaped = rasp.Aggregate(swap_selector, rasp.tokens) 29 | return swaped 30 | -------------------------------------------------------------------------------- /circuits_benchmark/utils/init_functions.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch as t 4 | 5 | 6 | def small_init_init_method(dim): 7 | """Fills the input Tensor with values according to the method described in Transformers without Tears: Improving 8 | the Normalization of Self-Attention - Nguyen, T. & Salazar, J. 
(2010), using a normal distribution.""" 9 | std = math.sqrt(2 / (5 * dim)) 10 | 11 | def init_(tensor): 12 | return t.nn.init.normal_(tensor, mean=0.0, std=std) 13 | 14 | return init_ 15 | 16 | 17 | def wang_init_method(n_layers, dim): 18 | std = 2 / n_layers / math.sqrt(dim) # Equivalent to (2 / n_layers) * (1 / math.sqrt(dim)) 19 | 20 | def init_(tensor): 21 | return t.nn.init.normal_(tensor, mean=0.0, std=std) 22 | 23 | return init_ 24 | 25 | 26 | def kaiming_uniform_and_normal_for_biases(): 27 | def init_(tensor): 28 | if len(tensor.shape) > 1: 29 | return t.nn.init.kaiming_uniform_(tensor) 30 | else: 31 | # Biases are initialized with a normal distribution 32 | return t.nn.init.normal_(tensor, std=0.02) 33 | 34 | return init_ 35 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_26.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case26(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_token_cascade(rasp.tokens) 12 | 13 | def get_task_description(self) -> str: 14 | return "Creates a cascading effect by repeating each token in sequence incrementally." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_ascii_letters_vocab(count=3) 18 | 19 | def is_trivial(self) -> bool: 20 | return True 21 | 22 | 23 | def make_token_cascade(sop: rasp.SOp) -> rasp.SOp: 24 | """ 25 | Creates a cascading effect by repeating each token in sequence incrementally. 26 | 27 | Example usage: 28 | token_cascade = make_token_cascade(rasp.tokens) 29 | token_cascade(["a", "b", "c"]) 30 | >> ["a", "bb", "ccc"] 31 | """ 32 | cascade_sop = rasp.SequenceMap(lambda x, i: x * (i + 1), sop, rasp.indices) 33 | return cascade_sop 34 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_63.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case63(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_count_less_than() 12 | 13 | def get_task_description(self) -> str: 14 | return "Replaces each element with the number of elements less than it in the sequence." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | def supports_causal_masking(self) -> bool: 20 | return False 21 | 22 | 23 | def make_count_less_than() -> rasp.SOp: 24 | # Create a selector that identifies where one element is less than another in the sequence. 25 | lt_selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.LT).named("lt_selector") 26 | 27 | # Count the number of True comparisons (i.e., number of elements less than) for each element. 
28 | count_lt = rasp.SelectorWidth(lt_selector).named("count_lt") 29 | 30 | return count_lt 31 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_74.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case74(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_interleave_reverse() 12 | 13 | def get_task_description(self) -> str: 14 | return "Interleaves elements with their reverse order Numbers at the odd indices should be in reverse order." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | def supports_causal_masking(self) -> bool: 20 | return False 21 | 22 | 23 | def make_interleave_reverse(): 24 | len = rasp.SelectorWidth(rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.TRUE)) 25 | shifter = rasp.SequenceMap(lambda x, y: x if x % 2 == 0 else (y - x if y % 2 == 0 else y - x - 1), rasp.indices, 26 | len) 27 | shift_selector = rasp.Select(rasp.indices, shifter, rasp.Comparison.EQ) 28 | return rasp.Aggregate(shift_selector, rasp.tokens) 29 | -------------------------------------------------------------------------------- /circuits_benchmark/utils/project_paths.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | PROJECT_ROOT: str | None = None 4 | 5 | 6 | def detect_project_root() -> str: 7 | """ 8 | Detects the root of the project by looking for a known file in the project. 9 | :return: the path to the root of the project. 10 | """ 11 | global PROJECT_ROOT 12 | if PROJECT_ROOT is not None: 13 | # If the project root has already been detected, return it. 14 | return PROJECT_ROOT 15 | 16 | # Get the absolute path of the current file 17 | current_file_path = os.path.abspath(__file__) 18 | 19 | # Get the directory name of the current file 20 | current_dir = os.path.dirname(current_file_path) 21 | 22 | # Traverse upwards until you reach the root directory (assumed to be two levels up) 23 | PROJECT_ROOT = os.path.abspath(os.path.join(current_dir, '..', '..')) 24 | 25 | return PROJECT_ROOT 26 | 27 | 28 | def get_default_output_dir() -> str: 29 | """ 30 | Get the default output directory for the project. 31 | :return: the default output directory for the project. 
32 | """ 33 | return str(os.path.join(detect_project_root(), "results")) 34 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_44.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case44(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_count_greater_than() 12 | 13 | def get_task_description(self) -> str: 14 | return "Replaces each element with the number of elements greater than it in the sequence" 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | def supports_causal_masking(self) -> bool: 20 | return False 21 | 22 | 23 | def make_count_greater_than() -> rasp.SOp: 24 | # Creating a selector that identifies elements greater than each element. 25 | greater_than_selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.GT).named("greater_than_selector") 26 | 27 | # Counting the number of elements greater than each element. 28 | count_greater_than = rasp.SelectorWidth(greater_than_selector).named("count_greater_than") 29 | 30 | return count_greater_than 31 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_42.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case42(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_first_element() 12 | 13 | def get_task_description(self) -> str: 14 | return "Return a sequence composed only of the first element of the input sequence." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_ascii_letters_vocab(count=10) 18 | 19 | 20 | def make_first_element() -> rasp.SOp: 21 | # Selector that identifies the first element by comparing indices to 0. 22 | first_elem_selector = rasp.Select(rasp.indices, rasp.Map(lambda x: 0, rasp.indices), rasp.Comparison.EQ).named( 23 | "first_elem_selector") 24 | 25 | # Use Aggregate to broadcast the first element across the entire sequence. 26 | first_element_sequence = rasp.Aggregate(first_elem_selector, rasp.tokens, default=None).named( 27 | "first_element_sequence") 28 | 29 | return first_element_sequence 30 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_107.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case107(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_wrap() 12 | 13 | def get_task_description(self) -> str: 14 | return "Wraps each element within a range (make the default range [2, 7])." 15 | # Wrapping here means that the values are projected into the range starting from the lower bound, once they grow 16 | # larger than the upper bound, they start again at the lower. 
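[Editorial note: an illustrative trace of the wrap_into_range helper defined further down in this file, making the comment above concrete.]

    wrap_into_range(2, 7, 6)    # ((6 - 2) % 5) + 2   -> 6  (already inside the range)
    wrap_into_range(2, 7, 9)    # ((9 - 2) % 5) + 2   -> 4  (wraps past the upper bound)
    wrap_into_range(2, 7, -1)   # ((-1 - 2) % 5) + 2  -> 4  (Python's % keeps the result non-negative)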
17 | 18 | def get_vocab(self) -> Set: 19 | return vocabs.get_int_numbers_vocab(min=-15, max=15) 20 | 21 | 22 | def wrap_into_range(min, max, x): 23 | # Calculate the size of the range 24 | range_size = max - min 25 | # Wrap x into the range 26 | wrapped_x = ((x - min) % range_size) + min 27 | return wrapped_x 28 | 29 | 30 | def make_wrap(min_val=2, max_val=7) -> rasp.SOp: 31 | return rasp.Map(lambda x: wrap_into_range(min_val, max_val, x), rasp.tokens) 32 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_131.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.common_programs import shift_by 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case131(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | return make_pairwise_max() 13 | 14 | def get_task_description(self) -> str: 15 | return "Makes each element the maximum of it and the previous element, leaving the first element as it is." 16 | 17 | def get_vocab(self) -> Set: 18 | return vocabs.get_int_numbers_vocab() 19 | 20 | 21 | def make_pairwise_max() -> rasp.SOp: 22 | # Shift the input sequence by 1 to the right, filling the first position with fill_value. 23 | shifted_sequence = shift_by(1, rasp.tokens).named("shifted_sequence") 24 | 25 | # Compare each element of the original sequence with the shifted sequence, taking the maximum. 26 | pairwise_max = rasp.SequenceMap(lambda x, y: int(max(x, y)), rasp.tokens, shifted_sequence).named("pairwise_max") 27 | 28 | return pairwise_max 29 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_128.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case128(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_swap_first_last() 12 | 13 | def get_task_description(self) -> str: 14 | return "Swap the first and last elements of a list." 
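[Editorial note: looking back at make_pairwise_max just above, a worked trace. It assumes the value shift_by uses to fill the first shifted position is no larger than the first token, consistent with the stated behaviour of leaving the first element unchanged.]

    prog = make_pairwise_max()
    prog([3, 1, 4, 2])
    # >> [3, 3, 4, 4]   each output is max(current, previous); the first element stays 3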
15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | def supports_causal_masking(self) -> bool: 20 | return False 21 | 22 | 23 | def condition(a, b): 24 | if a == 0: 25 | return b - 1 26 | elif a + 1 == b: 27 | return 0 28 | else: 29 | return a 30 | 31 | 32 | def make_swap_first_last(): 33 | swaper = rasp.SequenceMap(lambda x, y: condition(x, y), rasp.indices, 34 | rasp.SelectorWidth(rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.TRUE))) 35 | swap_selector = rasp.Select(swaper, rasp.indices, rasp.Comparison.EQ) 36 | return rasp.Aggregate(swap_selector, rasp.tokens) 37 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_120.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case120(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_flip_halves() 12 | 13 | def get_task_description(self) -> str: 14 | return "Flips the order of the first and second half of the sequence." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | def supports_causal_masking(self) -> bool: 20 | return False 21 | 22 | 23 | def make_flip_halves() -> rasp.SOp: 24 | len = rasp.SelectorWidth(rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.TRUE)) 25 | half = rasp.Map(lambda x: x / 2, len) 26 | new_positions = rasp.SequenceMap(lambda x, y: (x - y if x >= y else x + y) if y == int(y) else ( 27 | x if x + 0.5 == y else (x + int(y) + 1 if x < y else x - int(y) - 1)), rasp.indices, half) 28 | shifter = rasp.Select(new_positions, rasp.indices, rasp.Comparison.EQ) 29 | return rasp.Aggregate(shifter, rasp.tokens) 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Iván Arcuschin Moreno, Rohan Gupta, Niels uit de Bos, Thomas Kwa, Adrià Garriga-Alonso 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_24.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case24(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_leading_token_identification(rasp.tokens) 12 | 13 | def get_task_description(self) -> str: 14 | return "Identifies the first occurrence of each token in a sequence." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_ascii_letters_vocab(count=3) 18 | 19 | 20 | def make_leading_token_identification(sop: rasp.SOp) -> rasp.SOp: 21 | """ 22 | Identifies the first occurrence of each token in a sequence. 23 | 24 | Example usage: 25 | leading_token_id = make_leading_token_identification(rasp.tokens) 26 | leading_token_id(["x", "y", "x", "z", "y"]) 27 | >> [True, True, False, True, False] 28 | """ 29 | first_occurrence = rasp.Aggregate( 30 | rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.EQ), 31 | sop, default=None).named("first_occurrence") 32 | return first_occurrence 33 | -------------------------------------------------------------------------------- /circuits_benchmark/utils/ll_model_loader/best_weights.py: -------------------------------------------------------------------------------- 1 | def get_best_weight(case_name: str, individual=False): 2 | if str(case_name) in ['3']: 3 | ind_weights = { 4 | 'iit': 1.0, 5 | 'strict': 10.0, 6 | 'behavior': 1.0, 7 | } 8 | elif str(case_name) in ['18', '34', '35', '36', '37']: 9 | ind_weights = { 10 | 'iit': 1.0, 11 | 'strict': 1.0, 12 | 'behavior': 1.0, 13 | } 14 | elif str(case_name) in ['21']: 15 | ind_weights = { 16 | 'iit': 1.0, 17 | 'strict': 0.5, 18 | 'behavior': 1.0, 19 | } 20 | elif "ioi" in str(case_name): 21 | if individual: 22 | return { 23 | 'iit': 1.0, 24 | 'strict': 0.4, 25 | 'behavior': 1.0, 26 | } 27 | return "100_100_40" 28 | else: 29 | ind_weights = { 30 | 'iit': 1.0, 31 | 'strict': 0.4, 32 | 'behavior': 1.0, 33 | } 34 | 35 | if individual: 36 | return ind_weights 37 | return str(int(ind_weights['strict'] * 1000 + ind_weights['behavior'] * 100 + ind_weights['iit'] * 10)) 38 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_37.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case37(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_token_reversal_with_exclusion(rasp.tokens, "nochange") 12 | 13 | def get_task_description(self) -> str: 14 | return "Reverses each word in the sequence except for specified exclusions." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_words_vocab() 18 | 19 | def is_trivial(self) -> bool: 20 | return True 21 | 22 | 23 | def make_token_reversal_with_exclusion(sop: rasp.SOp, exclude: str) -> rasp.SOp: 24 | """ 25 | Reverses each token in the sequence except for specified exclusions. 
26 | 27 | Example usage: 28 | token_reversal = make_token_reversal_with_exclusion(rasp.tokens, "nochange") 29 | token_reversal(["reverse", "this", "nochange"]) 30 | >> ["esrever", "siht", "nochange"] 31 | """ 32 | reversal = rasp.Map(lambda x: x[::-1] if x != exclude else x, sop) 33 | return reversal 34 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_100.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case100(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_swap_elements() 12 | 13 | def get_task_description(self) -> str: 14 | return "Swaps two elements at specified indices (default is 0 and 1)." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | def supports_causal_masking(self) -> bool: 20 | return False 21 | 22 | def get_max_seq_len(self) -> int: 23 | return len(self.get_vocab()) 24 | 25 | 26 | def condition(a, idx1, idx2): 27 | if a == idx2: 28 | return idx1 29 | elif a == idx1: 30 | return idx2 31 | else: 32 | return a 33 | 34 | 35 | def make_swap_elements(index_a=1, index_b=0): 36 | swaper = rasp.Map(lambda x: condition(x, index_a, index_b), rasp.indices) 37 | swap_selector = rasp.Select(swaper, rasp.indices, rasp.Comparison.EQ) 38 | return rasp.Aggregate(swap_selector, rasp.tokens) 39 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_30.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case30(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_numeric_range_tagging(rasp.tokens, 10, 20) 12 | 13 | def get_task_description(self) -> str: 14 | return "Tags numeric tokens in a sequence based on whether they fall within a given range." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_str_numbers_vocab(min=0, max=30) 18 | 19 | 20 | def make_numeric_range_tagging(sop: rasp.SOp, lower_bound: int, upper_bound: int) -> rasp.SOp: 21 | """ 22 | Tags numeric tokens in a sequence based on whether they fall within a given range. 
23 | 24 | Example usage: 25 | range_tagging = make_numeric_range_tagging(rasp.tokens, 10, 20) 26 | range_tagging(["5", "15", "25", "20"]) 27 | >> [False, True, False, True] 28 | """ 29 | range_tagging = rasp.Map( 30 | lambda x: lower_bound <= int(x) <= upper_bound if x.isdigit() else False, sop) 31 | return range_tagging 32 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_31.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case31(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_token_anagram_identifier(rasp.tokens, "listen") 12 | 13 | def get_task_description(self) -> str: 14 | return "Identify if tokens in the sequence are anagrams of the word 'listen'." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_words_vocab().union({"listen"}) 18 | 19 | 20 | def make_token_anagram_identifier(sop: rasp.SOp, target: str) -> rasp.SOp: 21 | """ 22 | Identifies if tokens in the sequence are anagrams of a given target word. 23 | 24 | Example usage: 25 | anagram_identifier = make_token_anagram_identifier(rasp.tokens, "listen") 26 | anagram_identifier(["enlist", "google", "inlets", "banana"]) 27 | >> [True, False, True, False] 28 | """ 29 | sorted_target = sorted(target) 30 | anagram_identifier = rasp.Map( 31 | lambda x: sorted(x) == sorted_target, sop) 32 | return anagram_identifier 33 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_51.py: -------------------------------------------------------------------------------- 1 | from typing import Set, Sequence 2 | 3 | from circuits_benchmark.benchmark import vocabs 4 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 5 | from tracr.rasp import rasp 6 | 7 | 8 | class Case51(TracrBenchmarkCase): 9 | def get_program(self) -> rasp.SOp: 10 | return make_check_fibonacci() 11 | 12 | def get_task_description(self) -> str: 13 | return "Checks if each element is a Fibonacci number" 14 | 15 | def get_vocab(self) -> Set: 16 | return vocabs.get_int_numbers_vocab(min=0, max=100) 17 | 18 | 19 | def make_check_fibonacci() -> rasp.SOp: 20 | # Assume a pre-generated Fibonacci sequence up to a certain limit. 21 | # In practice, this would need to be dynamically generated or sufficiently large. 
22 | fib_sequence = [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610] 23 | 24 | # Function to check if a number is in the Fibonacci sequence 25 | is_fib = lambda x: 1 if x in fib_sequence else 0 26 | 27 | # Apply the check to each element of the input sequence 28 | check_fibonacci_map = rasp.Map(is_fib, rasp.tokens).named("check_fibonacci_map") 29 | 30 | return check_fibonacci_map 31 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_105.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case105(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_next_prime() 12 | 13 | def get_task_description(self) -> str: 14 | return "Replaces each number with the next prime after that number." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab(max=30) 18 | 19 | 20 | def is_prime(num): 21 | """Check if a number is prime.""" 22 | if num <= 1: 23 | return False 24 | for i in range(2, int(num ** 0.5) + 1): 25 | if num % i == 0: 26 | return False 27 | return True 28 | 29 | 30 | def next_prime(n): 31 | """Return the next highest prime number after n.""" 32 | # Start checking from the next number 33 | prime_candidate = n 34 | while True: 35 | if is_prime(prime_candidate): 36 | return prime_candidate 37 | prime_candidate += 1 38 | 39 | 40 | def make_next_prime(): 41 | return rasp.Map(next_prime, rasp.tokens) 42 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_90.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case90(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_token_replacer(rasp.tokens, "findme", "-") 12 | 13 | def get_task_description(self) -> str: 14 | return "Replaces a specific token with another one." 
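[Editorial note: a few spot checks of the prime helpers defined just above in case_105.]

    is_prime(1)     # False
    is_prime(13)    # True
    next_prime(4)   # 5
    next_prime(7)   # 7   (the search starts at n itself, so a prime input maps to itself)
    next_prime(8)   # 11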
15 | 16 | def get_vocab(self) -> Set: 17 | vocab = vocabs.get_words_vocab() 18 | vocab.add("findme") 19 | vocab.add("-") 20 | return vocab 21 | 22 | def is_trivial(self) -> bool: 23 | return True 24 | 25 | 26 | def make_token_replacer(sop: rasp.SOp, target: str, replacement: str) -> rasp.SOp: 27 | """ 28 | Returns a program that replaces a target token with a replacement token 29 | 30 | Example usage: 31 | replacer = make_token_replacer(rasp.tokens, "findme", "-") 32 | replacer(["word1", "findme", "word3"]) 33 | >> ["word1", "-", "word3"] 34 | """ 35 | replaced = rasp.Map(lambda x: replacement if x == target else x, sop) 36 | return replaced 37 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_97.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.cases.case_98 import make_max_element 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case97(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | return make_scale_by_max() 13 | 14 | def get_task_description(self) -> str: 15 | return "Scale a sequence by its maximum element." 16 | 17 | def get_vocab(self) -> Set: 18 | return vocabs.get_int_numbers_vocab() 19 | 20 | def supports_causal_masking(self) -> bool: 21 | return False 22 | 23 | 24 | def make_scale_by_max() -> rasp.SOp: 25 | # Find the maximum element in the sequence. 26 | max_element = make_max_element().named("max_element") 27 | 28 | # Assume the maximum element is not zero to avoid division by zero. 29 | # Divide each element in the sequence by the maximum element. 30 | scale_by_max_sequence = rasp.SequenceMap(lambda x, y: (x / y) if y > 0 else 0, rasp.tokens, max_element).named( 31 | "scale_by_max_sequence") 32 | 33 | return scale_by_max_sequence 34 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_117.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case117(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_sum_of_last_two() 12 | 13 | def get_task_description(self) -> str: 14 | return "Given a list of integers, return the sum of the last two elements." 
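[Editorial note: a worked trace of make_scale_by_max from just above. It assumes the direct RASP evaluation convention used elsewhere in this repo, and that make_max_element (imported from case_98) broadcasts the sequence maximum to every position, as its name suggests.]

    scale = make_scale_by_max()
    scale([2, 8, 4])
    # >> [0.25, 1.0, 0.5]   each element divided by the sequence maximum (8)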
15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | def supports_causal_masking(self) -> bool: 20 | return False 21 | 22 | 23 | def make_sum_of_last_two() -> rasp.SOp: 24 | len = rasp.SelectorWidth(rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)) 25 | last_idx = rasp.Map(lambda x: x - 1, len) 26 | second_to_last_idx = rasp.Map(lambda x: x - 2, len) 27 | last_elt = rasp.Aggregate(rasp.Select(rasp.indices, last_idx, rasp.Comparison.EQ), rasp.tokens) 28 | second_to_last_elt = rasp.Aggregate(rasp.Select(rasp.indices, second_to_last_idx, rasp.Comparison.EQ), rasp.tokens) 29 | return rasp.SequenceMap(lambda x, y: x + y, last_elt, second_to_last_elt) 30 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_11.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case11(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_word_count_by_length(rasp.tokens) 12 | 13 | def get_task_description(self) -> str: 14 | return "Counts the number of words in a sequence based on their length." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_words_vocab() 18 | 19 | def is_trivial(self) -> bool: 20 | return True 21 | 22 | 23 | def make_word_count_by_length(sop: rasp.SOp) -> rasp.SOp: 24 | """ 25 | Counts the number of words in a sequence based on their length. 26 | 27 | Example usage: 28 | word_count = make_word_count_by_length(rasp.tokens) 29 | word_count(["apple", "pear", "banana"]) 30 | >> {5: 2, 4: 1} 31 | """ 32 | word_length = rasp.Map(lambda x: len(x), sop) 33 | length_selector = rasp.Select(word_length, word_length, rasp.Comparison.EQ) 34 | word_count = rasp.Aggregate(length_selector, word_length, default=None) 35 | return word_count 36 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_32.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.common_programs import shift_by 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case32(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | return make_token_boundary_detector(rasp.tokens) 13 | 14 | def get_task_description(self) -> str: 15 | return "Detects the boundaries between different types of tokens in a sequence." 16 | 17 | def get_vocab(self) -> Set: 18 | return vocabs.get_words_vocab(min_chars=4, max_words=10) 19 | 20 | 21 | def make_token_boundary_detector(sop: rasp.SOp) -> rasp.SOp: 22 | """ 23 | Detects the boundaries between different types of tokens in a sequence. 
24 | 25 | Example usage: 26 | token_boundary = make_token_boundary_detector(rasp.tokens) 27 | token_boundary(["apple", "banana", "apple", "orange"]) 28 | >> [False, True, False, True] 29 | """ 30 | previous_token = shift_by(1, sop) 31 | boundary_detector = rasp.SequenceMap( 32 | lambda x, y: x != y, sop, previous_token) 33 | return boundary_detector 34 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_47.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case47(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_count_frequency() 12 | 13 | def get_task_description(self) -> str: 14 | return "Counts the frequency of each unique element" 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_ascii_letters_vocab(count=10) 18 | 19 | def supports_causal_masking(self) -> bool: 20 | return False 21 | 22 | 23 | def make_count_frequency() -> rasp.SOp: 24 | # Create a comparison matrix where each element is compared to every other element for equality. 25 | equality_selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ).named("equality_selector") 26 | 27 | # Use SelectorWidth to count the frequency of each element based on the equality comparison matrix. 28 | frequency_count = rasp.SelectorWidth(equality_selector).named("frequency_count") 29 | 30 | # The result is a sequence where each element is replaced by its frequency count. 31 | return frequency_count 32 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_43.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case43(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_nth_fibonacci() 12 | 13 | def get_task_description(self) -> str: 14 | return "Returns the corresponding Fibonacci number for each element in the input sequence." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab(max=20) 18 | 19 | 20 | def make_nth_fibonacci(): 21 | # Pre-generated Fibonacci sequence up to the 20th number, considering 0th and 1st numbers as 0 and 1. 
22 | fib_sequence = [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987, 1597, 2584, 4181, 6765] 23 | 24 | # Create a sequence of indices for each position in the input sequence 25 | indices = rasp.Map(lambda x: x, rasp.indices).named("indices") 26 | 27 | # Map each element in the input sequence to its corresponding Fibonacci number 28 | nth_fib = rasp.Map(lambda x: fib_sequence[x] if x < len(fib_sequence) else 0, rasp.tokens).named("nth_fib") 29 | 30 | return nth_fib 31 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_61.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case61(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_rank() 12 | 13 | def get_task_description(self) -> str: 14 | return "Ranks each element according to its size." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | def supports_causal_masking(self) -> bool: 20 | return False 21 | 22 | 23 | def make_rank() -> rasp.SOp: 24 | # Selector that creates a comparison matrix where each element is compared to every other element to find how many are smaller. 25 | less_than_selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.LT).named("less_than_selector") 26 | 27 | # Count the number of elements that are smaller than each element. 28 | smaller_count = rasp.SelectorWidth(less_than_selector).named("smaller_count") 29 | 30 | # Since ranks start from 1, add 1 to each count to get the rank. 31 | rank = rasp.Map(lambda x: x + 1, smaller_count).named("rank") 32 | 33 | return rank 34 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_35.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case35(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_token_capitalization_alternator(rasp.tokens) 12 | 13 | def get_task_description(self) -> str: 14 | return "Alternates capitalization of each character in words." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_words_vocab() 18 | 19 | def is_trivial(self) -> bool: 20 | return True 21 | 22 | 23 | def make_token_capitalization_alternator(sop: rasp.SOp) -> rasp.SOp: 24 | """ 25 | Alternates capitalization of each character in tokens. 
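    The rule is applied independently to every word; a minimal sketch of the per-word
    logic (mirroring alternate_capitalization below):

        ''.join(c.upper() if i % 2 == 0 else c.lower() for i, c in enumerate("hello"))
        >> "HeLlO"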
26 | 27 | Example usage: 28 | capitalization_alternator = make_token_capitalization_alternator(rasp.tokens) 29 | capitalization_alternator(["hello", "world"]) 30 | >> ["HeLlO", "WoRlD"] 31 | """ 32 | 33 | def alternate_capitalization(word): 34 | return ''.join(c.upper() if i % 2 == 0 else c.lower() for i, c in enumerate(word)) 35 | 36 | alternator = rasp.Map(alternate_capitalization, sop) 37 | return alternator 38 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_36.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 6 | 7 | 8 | class Case36(TracrBenchmarkCase): 9 | def get_program(self) -> rasp.SOp: 10 | return make_emoji_sentiment_classifier(rasp.tokens) 11 | 12 | def get_task_description(self) -> str: 13 | return "Classifies each token as 'positive', 'negative', or 'neutral' based on emojis." 14 | 15 | def get_vocab(self) -> Set: 16 | return {"😊", "😢", "📘"} 17 | 18 | def is_trivial(self) -> bool: 19 | return True 20 | 21 | 22 | def make_emoji_sentiment_classifier(sop: rasp.SOp) -> rasp.SOp: 23 | """ 24 | Classifies each token as 'positive', 'negative', or 'neutral' based on emojis. 25 | 26 | Example usage: 27 | emoji_sentiment = make_emoji_sentiment_classifier(rasp.tokens) 28 | emoji_sentiment(["😊", "😢", "📘"]) 29 | >> ["positive", "negative", "neutral"] 30 | """ 31 | # Define mapping for emoji sentiment classification 32 | emoji_sentiments = {"😊": "positive", "😢": "negative", "📘": "neutral"} 33 | classify_sentiment = rasp.Map(lambda x: emoji_sentiments.get(x, "neutral"), sop) 34 | return classify_sentiment 35 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_78.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case78(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_count_occurrences() 12 | 13 | def get_task_description(self) -> str: 14 | return "Count the occurrences of each element in a sequence." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_ascii_letters_vocab(count=10) 18 | 19 | def supports_causal_masking(self) -> bool: 20 | return False 21 | 22 | 23 | def make_count_occurrences() -> rasp.SOp: 24 | """ 25 | Creates an SOp that transforms a sequence so each element is replaced by the number of times it appears in the sequence. 
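    For instance (illustrative, following the Select/SelectorWidth pattern below):

        count_occurrences = make_count_occurrences()
        count_occurrences(["a", "b", "a", "c", "a"])
        >> [3, 1, 3, 1, 3]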
26 | """ 27 | # Selector that compares each element with every other element to find duplicates 28 | eq_selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ).named("eq_selector") 29 | 30 | # Counts the occurrences of each element based on the equality comparison 31 | count_occurrences = rasp.SelectorWidth(eq_selector).named("count_occurrences") 32 | 33 | return count_occurrences 34 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_20.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case20(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_spam_message_detector(rasp.tokens) 12 | 13 | def get_task_description(self) -> str: 14 | return "Detect spam messages based on appearance of spam keywords." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_words_vocab().union({"spam", "offer", "click", "now"}) 18 | 19 | def is_trivial(self) -> bool: 20 | return True 21 | 22 | 23 | def make_spam_message_detector(sop: rasp.SOp) -> rasp.SOp: 24 | """ 25 | Detects spam messages based on keyword frequency. 26 | 27 | Example usage: 28 | spam_detector = make_spam_message_detector(rasp.tokens) 29 | spam_detector(["free", "offer", "click", "now"]) 30 | >> "spam" 31 | """ 32 | spam_keywords = {"free", "offer", "click", "now"} 33 | keyword_count = rasp.Map(lambda x: sum(x == keyword for keyword in spam_keywords), sop) 34 | is_spam = rasp.Map(lambda x: "spam" if x > 0 else "not spam", keyword_count) 35 | return is_spam 36 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_93.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case93(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_swap_odd_index() 12 | 13 | def get_task_description(self) -> str: 14 | return "Swaps the nth with the n+1th element if n%2==1." 15 | # Note that this means that the first element will remain unchanged. 
16 | # The second will be swapped with the third and so on 17 | 18 | def get_vocab(self) -> Set: 19 | return vocabs.get_int_numbers_vocab() 20 | 21 | def supports_causal_masking(self) -> bool: 22 | return False 23 | 24 | 25 | def make_swap_odd_index() -> rasp.SOp: 26 | len = rasp.SelectorWidth(rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.TRUE)) 27 | swaper = rasp.SequenceMap( 28 | lambda x, y: x if (x == y - 1 and y % 2 == 0 or x == 0) else (x - 1 if (x + 1) % 2 == 1 else x + 1), 29 | rasp.indices, 30 | len) 31 | swap_selector = rasp.Select(rasp.indices, swaper, rasp.Comparison.EQ) 32 | swaped = rasp.Aggregate(swap_selector, rasp.tokens) 33 | return swaped 34 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_25.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.common_programs import make_hist, make_length 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case25(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | return make_token_frequency_normalization() 13 | 14 | def get_task_description(self) -> str: 15 | return "Normalizes token frequencies in a sequence to a range between 0 and 1." 16 | 17 | def supports_causal_masking(self) -> bool: 18 | return False 19 | 20 | def get_vocab(self) -> Set: 21 | return vocabs.get_ascii_letters_vocab(count=3) 22 | 23 | 24 | def make_token_frequency_normalization() -> rasp.SOp: 25 | """ 26 | Normalizes token frequencies in a sequence to a range between 0 and 1. 27 | 28 | Example usage: 29 | token_freq_norm = make_token_frequency_normalization(rasp.tokens) 30 | token_freq_norm(["a", "a", "b", "c", "c", "c"]) 31 | >> [0.33, 0.33, 0.16, 0.5, 0.5, 0.5] 32 | """ 33 | normalized_freq = rasp.SequenceMap(lambda x, y: (x / y) if y > 0 else None, make_hist(), make_length()) 34 | return normalized_freq 35 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_34.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case34(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_vowel_consonant_ratio(rasp.tokens) 12 | 13 | def get_task_description(self) -> str: 14 | return "Calculate the ratio of vowels to consonants in each word." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_words_vocab() 18 | 19 | 20 | def make_vowel_consonant_ratio(sop: rasp.SOp) -> rasp.SOp: 21 | """ 22 | Calculates the ratio of vowels to consonants in each token. Deal with 0 denominator by 23 | returning infinity. 
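    A minimal sketch of the per-word rule (mirroring calc_ratio below, with the
    hypothetical input "sky"):

        vowels = sum(c in 'aeiou' for c in "sky".lower())            # 0
        consonants = len("sky") - vowels                             # 3
        vowels / consonants if consonants != 0 else float('inf')     # 0.0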
24 | 25 | Example usage: 26 | vowel_consonant_ratio = make_vowel_consonant_ratio(rasp.tokens) 27 | vowel_consonant_ratio(["apple", "sky", "aeiou"]) 28 | >> [2/3, 0/3, inf] 29 | """ 30 | 31 | def calc_ratio(word): 32 | vowels = sum(c in 'aeiou' for c in word.lower()) 33 | consonants = len(word) - vowels 34 | return vowels / consonants if consonants != 0 else float('inf') 35 | 36 | ratio_calculator = rasp.Map(calc_ratio, sop) 37 | return ratio_calculator 38 | -------------------------------------------------------------------------------- /circuits_benchmark/utils/get_cases.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | from typing import List 3 | 4 | from circuits_benchmark.benchmark.benchmark_case import BenchmarkCase 5 | from circuits_benchmark.utils.find_all_subclasses import find_all_transitive_subclasses_in_package 6 | 7 | 8 | def get_cases(args: Namespace | None = None, indices: List[str] | None = None) -> List[BenchmarkCase]: 9 | assert (args is None or args.indices is None) or indices is None, "Cannot specify both args.indices and indices" 10 | 11 | classes = find_all_transitive_subclasses_in_package(BenchmarkCase, "circuits_benchmark.benchmark.cases") 12 | classes = [cls for cls in classes if cls.__name__.startswith("Case")] 13 | 14 | if args is not None and args.indices is not None: 15 | indices = [idx.lower() for idx in args.indices.split(",")] 16 | 17 | if indices is not None: 18 | # filter class names that are "CaseN" where N in indices 19 | classes = [cls for cls in classes if cls.__name__[4:].lower() in indices] 20 | 21 | # sort classes. if id is a number, numerically, otherwise alphabetically 22 | classes.sort(key=lambda cls: 23 | cls.__name__[4:] if cls.__name__[4:].isnumeric() else cls.__name__[4:] 24 | ) 25 | 26 | # instantiate all classes found 27 | return [cls() for cls in classes] 28 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_95.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case95(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_count_prime_factors() 12 | 13 | def get_task_description(self) -> str: 14 | return "Counts the distinct prime factors of each number in the input list." 
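  # Illustrative values for the helper defined below (distinct prime factors only):
  #   count_prime_factors(12) -> 2   (12 = 2**2 * 3)
  #   count_prime_factors(30) -> 3   (30 = 2 * 3 * 5)
  #   count_prime_factors(0)  -> 0   (by convention in the helper)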
15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab(max=30) 18 | 19 | 20 | def count_prime_factors(n): 21 | if n == 0: return 0 22 | 23 | count = 0 24 | # Handle 2 separately to make the loop only for odd numbers 25 | if n % 2 == 0: 26 | count += 1 27 | while n % 2 == 0: 28 | n //= 2 29 | # Check for odd factors 30 | factor = 3 31 | while factor * factor <= n: 32 | if n % factor == 0: 33 | count += 1 34 | while n % factor == 0: 35 | n //= factor 36 | factor += 2 37 | # If n is a prime number greater than 2 38 | if n > 2: 39 | count += 1 40 | return count 41 | 42 | 43 | def make_count_prime_factors(): 44 | return rasp.Map(count_prime_factors, rasp.tokens) 45 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_118.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.common_programs import shift_by 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case118(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | return make_pairwise_sum() 13 | 14 | def get_task_description(self) -> str: 15 | return "Replaces each element with the sum of it and the previous element." 16 | 17 | def get_vocab(self) -> Set: 18 | return vocabs.get_int_numbers_vocab() 19 | 20 | def supports_causal_masking(self) -> bool: 21 | return False 22 | 23 | 24 | def make_pairwise_sum() -> rasp.SOp: 25 | # Shift the input sequence by 1 to the right, filling the first position with 0. 26 | shifted_sequence = rasp.SequenceMap(lambda x, y: x if y > 0 else 0, 27 | shift_by(1, rasp.tokens).named("shifted_sequence"), rasp.indices) 28 | 29 | # Sum each element of the original sequence with the corresponding element of the shifted sequence.
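  # For example (illustrative): tokens [4, 7, 2] yield a shifted sequence of [0, 4, 7]
  # (position 0 is forced to 0 above), so the result below is [4, 11, 9].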
30 | pairwise_sum = rasp.SequenceMap(lambda x, y: int(x + y), rasp.tokens, shifted_sequence).named("pairwise_max") 31 | 32 | return pairwise_sum 33 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_71.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case71(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_divide_by_length() 12 | 13 | def get_task_description(self) -> str: 14 | return "Divide each element by the length of the sequence" 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | def supports_causal_masking(self) -> bool: 20 | return False 21 | 22 | 23 | def make_divide_by_length() -> rasp.SOp: 24 | # Step 1: Create a selector that selects all elements (TRUE condition) 25 | all_true_selector = rasp.Select( 26 | rasp.tokens, rasp.tokens, rasp.Comparison.TRUE).named("all_true_selector") 27 | 28 | # Calculate the length of the sequence using SelectorWidth 29 | length = rasp.SelectorWidth(all_true_selector).named("length") 30 | 31 | # Step 2: Divide each element by the length of the sequence 32 | divided_by_length = rasp.SequenceMap(lambda x, length: x / length if length > 0 else 0, rasp.tokens, length).named( 33 | "divided_by_length") 34 | 35 | return divided_by_length 36 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_76.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case76(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_zero_even_indices() 12 | 13 | def get_task_description(self) -> str: 14 | return "Set even indices to 0" 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | 20 | def make_zero_even_indices() -> rasp.SOp: 21 | # Create a sequence of indices 22 | indices = rasp.Map(lambda x: x, rasp.indices).named("indices") 23 | 24 | # Create a sequence where even indices are marked with 0 and odd indices with -1 25 | even_odd_marker = rasp.Map(lambda x: 0 if x % 2 == 0 else -1, indices).named("even_odd_marker") 26 | 27 | # Use SequenceMap to combine the original sequence with the marker sequence 28 | # If the marker is 0, return 0 (for even indices); otherwise, return the original element (for odd indices) 29 | final_sequence = rasp.SequenceMap(lambda elem, marker: elem if marker == -1 else 0, rasp.tokens, 30 | even_odd_marker).named("final_sequence") 31 | 32 | return final_sequence 33 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/vocabs.py: -------------------------------------------------------------------------------- 1 | import random 2 | import string 3 | from typing import Set 4 | 5 | TRACR_BOS = "BOS" 6 | TRACR_PAD = "PAD" 7 | 8 | 9 | def get_ascii_letters_vocab(count=len(string.ascii_lowercase)) -> Set: 10 | return set(string.ascii_lowercase[:count]) 11 | 12 | 13 | def get_str_digits_vocab(count=len(string.digits)) -> Set: 
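  # e.g. get_str_digits_vocab(count=3) -> {"0", "1", "2"} (illustrative)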
14 | return set(string.digits[:count]) 15 | 16 | 17 | def get_int_digits_vocab(count=len(string.digits)) -> Set: 18 | return set([int(d) for d in get_str_digits_vocab(count=count)]) 19 | 20 | 21 | def get_str_numbers_vocab(min=0, max=20) -> Set: 22 | return set([str(d) for d in range(min, max)]) 23 | 24 | 25 | def get_int_numbers_vocab(min=0, max=11) -> Set: 26 | return set([d for d in range(min, max)]) 27 | 28 | 29 | def get_float_numbers_vocab(min=0, max=5, count=20) -> Set: 30 | return set([min + x * (max - min) / count for x in range(count)]) 31 | 32 | 33 | def get_words_vocab(seed=42, min_chars=1, max_chars=8, min_words=5, max_words=20) -> Set: 34 | """Generate a set of random words.""" 35 | random.seed(seed) 36 | vocab: Set = set() 37 | for _ in range(random.randint(min_words, max_words)): 38 | word = "".join(random.choice(string.ascii_letters) for _ in range(random.randint(min_chars, max_chars))) 39 | vocab.add(word) 40 | return vocab 41 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_22.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.common_programs import make_sort 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case22(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | return make_token_sorting_by_length(rasp.tokens) 13 | 14 | def get_task_description(self) -> str: 15 | return "Sort words in a sequence by their length." 16 | 17 | def supports_causal_masking(self) -> bool: 18 | return False 19 | 20 | def get_vocab(self) -> Set: 21 | return vocabs.get_words_vocab() 22 | 23 | def get_max_seq_len(self) -> int: 24 | return 10 25 | 26 | 27 | def make_token_sorting_by_length(sop: rasp.SOp) -> rasp.SOp: 28 | """ 29 | Sorts tokens in a sequence by their length. 30 | 31 | Example usage: 32 | token_sort_len = make_token_sorting_by_length(rasp.tokens) 33 | token_sort_len(["word", "a", "is", "sequence"]) 34 | >> ["a", "is", "word", "sequence"] 35 | """ 36 | token_length = rasp.Map(lambda x: len(x), sop).named("token_length") 37 | sorted_tokens = make_sort(sop, token_length, max_seq_len=10, min_key=1) 38 | return sorted_tokens 39 | -------------------------------------------------------------------------------- /circuits_benchmark/commands/build_main_parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from circuits_benchmark.commands.algorithms import run_algorithm 4 | from circuits_benchmark.commands.evaluation import evaluation 5 | from circuits_benchmark.commands.train import train 6 | 7 | 8 | def build_main_parser(): 9 | # define commands for our main script. 
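  # The parser built here backs the `train`, `run` and `eval` subcommands of main.py,
  # e.g. (invocation copied from EXPERIMENTS.md, shown for illustration):
  #   python main.py train iit -i 3 --epochs 500 --model-pair strict -iit 1 -s 0.4 -b 1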
10 | parser = ArgumentParserWithOriginals() 11 | subparsers = parser.add_subparsers(dest="command") 12 | subparsers.required = True 13 | 14 | # Setup command arguments 15 | run_algorithm.setup_args_parser(subparsers) 16 | train.setup_args_parser(subparsers) 17 | evaluation.setup_args_parser(subparsers) 18 | 19 | return parser 20 | 21 | 22 | class ArgumentParserWithOriginals(argparse.ArgumentParser): 23 | """ArgumentParser that stores the original arguments.""" 24 | 25 | def parse_args(self, *args, **kwargs): 26 | original_args = list(args[0]) 27 | parsed_args = super().parse_args(*args, **kwargs) 28 | parsed_args.original_args = original_args 29 | return parsed_args 30 | 31 | def parse_known_args(self, *args, **kwargs): 32 | original_args = list(args[0]) 33 | parsed_args, unknown_args = super().parse_known_args(*args, **kwargs) 34 | parsed_args.original_args = original_args 35 | return parsed_args, unknown_args 36 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_112.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case112(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_difference_to_next() 12 | 13 | def get_task_description(self) -> str: 14 | return "Compute the difference between each element and the next element in the sequence." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | def supports_causal_masking(self) -> bool: 20 | return False 21 | 22 | 23 | def make_difference_to_next(): 24 | def shift_by_one() -> rasp.SOp: 25 | # Define a selector for shifting sequence by one 26 | len = rasp.SelectorWidth(rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.TRUE)) 27 | shifted_selector = rasp.Select(rasp.indices, 28 | rasp.SequenceMap(lambda x, y: x if x < y - 1 else x - 1, rasp.indices, len), 29 | lambda x, y: x - 1 == y) 30 | return rasp.Aggregate(shifted_selector, rasp.tokens) 31 | 32 | shifted = shift_by_one() 33 | return rasp.SequenceMap(lambda x, y: x - y, shifted, rasp.tokens) 34 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_92.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case92(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_zero_if_less_than_previous() 12 | 13 | def get_task_description(self) -> str: 14 | return "Set each element to 0 if it is less than the previous element." 
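  # For example (illustrative): [3, 5, 2, 4] -> [3, 5, 0, 4], since only 2 is smaller
  # than the element before it.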
15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | 20 | def make_zero_if_less_than_previous() -> rasp.SOp: 21 | # Correctly shift the sequence by using Aggregate and Select to create the shifted sequence with the placeholder at the start 22 | shifted_sequence = rasp.Aggregate( 23 | rasp.Select(rasp.indices, rasp.indices, lambda k, q: q == k + 1 or k == 0 and q == 0), 24 | rasp.tokens 25 | ).named("shifted_sequence_with_placeholder") 26 | 27 | # Use SequenceMap to compare each element with its shifted version 28 | zero_if_less_than_previous = rasp.SequenceMap( 29 | lambda original, shifted: 0 if original < shifted else original, 30 | rasp.tokens, 31 | shifted_sequence 32 | ).named("zero_if_less_than_previous") 33 | 34 | return zero_if_less_than_previous 35 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_10.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.common_programs import shift_by 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case10(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | return make_token_symmetry_checker(rasp.tokens) 13 | 14 | def get_task_description(self) -> str: 15 | return "Check if each word in a sequence is symmetric around its center." 16 | 17 | def get_vocab(self) -> Set: 18 | return vocabs.get_words_vocab().union({"radar", "rotor"}) 19 | 20 | 21 | def make_token_symmetry_checker(sop: rasp.SOp) -> rasp.SOp: 22 | """ 23 | Checks if each token is symmetric around its center. 24 | 25 | Example usage: 26 | symmetry_checker = make_token_symmetry_checker(rasp.tokens) 27 | symmetry_checker(["radar", "apple", "rotor", "data"]) 28 | >> [True, False, True, False] 29 | """ 30 | half_length = rasp.Map(lambda x: len(x) // 2, sop) 31 | first_half = shift_by(half_length, sop) 32 | second_half = rasp.SequenceMap(lambda x, y: x[:y] == x[:-y - 1:-1], sop, half_length) 33 | symmetry_checker = rasp.SequenceMap(lambda x, y: x if y else None, sop, second_half) 34 | return symmetry_checker 35 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_14.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.program_evaluation_type import causal_and_regular 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case14(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | return make_count(rasp.tokens, "a") 13 | 14 | def get_task_description(self) -> str: 15 | return "Returns the count of 'a' in the input sequence." 16 | 17 | def supports_causal_masking(self) -> bool: 18 | return False 19 | 20 | def get_vocab(self) -> Set: 21 | return vocabs.get_ascii_letters_vocab(count=3) 22 | 23 | 24 | @causal_and_regular 25 | def make_count(sop, token): 26 | """Returns the count of `token` in `sop`. 27 | 28 | The output sequence contains this count in each position. 
29 | 30 | Example usage: 31 | count = make_count(tokens, "a") 32 | count(["a", "a", "a", "b", "b", "c"]) 33 | >> [3, 3, 3, 3, 3, 3] 34 | count(["c", "a", "b", "c"]) 35 | >> [1, 1, 1, 1] 36 | 37 | Args: 38 | sop: Sop to count tokens in. 39 | token: Token to count. 40 | """ 41 | return rasp.SelectorWidth(rasp.Select( 42 | sop, sop, lambda k, q: k == token)).named(f"count_{token}") 43 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_39.py: -------------------------------------------------------------------------------- 1 | from typing import Set, Sequence 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.common_programs import make_frac_prevs 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case39(TracrBenchmarkCase): 11 | """Same as Case3 but with increased vocab and max sequence length""" 12 | 13 | def get_program(self) -> rasp.SOp: 14 | is_x = (rasp.tokens == "x").named("is_x") 15 | return make_frac_prevs(is_x) 16 | 17 | def get_task_description(self) -> str: 18 | return "Returns the fraction of 'x' in the input up to the i-th position for all i (longer sequence length)." 19 | 20 | def get_vocab(self) -> Set: 21 | some_letters = vocabs.get_ascii_letters_vocab() 22 | some_letters.add("x") 23 | return some_letters 24 | 25 | def get_max_seq_len(self) -> int: 26 | return 60 27 | 28 | def get_correct_output_for_input(self, input: Sequence) -> Sequence: 29 | """Returns the fraction of 'x' in the input up to the i-th position for all i. 30 | We define this method so that we don't need to call the original program to get the correct output for each input. 31 | """ 32 | return [input[:i + 1].count("x") / (i + 1) for i in range(len(input))] 33 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_111.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case111(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_last_element() 12 | 13 | def get_task_description(self) -> str: 14 | return "Returns the last element of the sequence and pads the rest with zeros." 
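  # For example (illustrative): [4, 2, 7, 5] -> [5, 0, 0, 0]; the last element is kept
  # at position 0 and every other position is zeroed.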
15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | def supports_causal_masking(self) -> bool: 20 | return False 21 | 22 | 23 | def make_last_element() -> rasp.SOp: 24 | # Generating the length of the sequence 25 | length = rasp.SelectorWidth(rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)).named("length") 26 | 27 | # Selector for the last element based on length minus 1 28 | last_element_selector = rasp.Select(rasp.indices, rasp.Map(lambda x: x - 1, length), rasp.Comparison.EQ).named( 29 | "last_element_selector") 30 | 31 | # Broadcasting the last element across the entire sequence 32 | last_element_sequence = rasp.Aggregate(last_element_selector, rasp.tokens).named("last_element_sequence") 33 | 34 | return rasp.SequenceMap(lambda x, y: x if y == 0 else 0, last_element_sequence, rasp.indices) 35 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_56.py: -------------------------------------------------------------------------------- 1 | from typing import Set, Sequence 2 | 3 | from circuits_benchmark.benchmark import vocabs 4 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 5 | from tracr.rasp import rasp 6 | 7 | 8 | class Case56(TracrBenchmarkCase): 9 | def get_program(self) -> rasp.SOp: 10 | return make_zero_every_third() 11 | 12 | def get_task_description(self) -> str: 13 | return "Sets every third element to zero." 14 | 15 | def get_vocab(self) -> Set: 16 | return vocabs.get_int_numbers_vocab() 17 | 18 | 19 | def make_zero_every_third() -> rasp.SOp: 20 | # Step 1: Use rasp.indices to generate a sequence of indices. 21 | 22 | # Step 2: Map over the indices to identify every third element, considering 1-based indexing. 23 | every_third = rasp.Map(lambda x: (x + 1) % 3 == 0, rasp.indices).named("every_third") 24 | 25 | # Step 3: Convert boolean flags (True/False) to 0/1 for easier handling in SequenceMap. 26 | every_third_numerical = rasp.Map(lambda x: 1 if x else 0, every_third).named("every_third_numerical") 27 | 28 | # Step 4: Use SequenceMap to set every third element to 0 and leave others unchanged. 29 | result_sequence = rasp.SequenceMap(lambda x, is_third: 0 if is_third else x, rasp.tokens, 30 | every_third_numerical).named("result_sequence") 31 | 32 | return result_sequence 33 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_6.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.common_programs import shift_by 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case6(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | return make_token_oscillation_detector(rasp.tokens) 13 | 14 | def get_task_description(self) -> str: 15 | return "Detect oscillation patterns in a numeric sequence." 16 | 17 | def get_vocab(self) -> Set: 18 | return vocabs.get_int_digits_vocab() 19 | 20 | def supports_causal_masking(self) -> bool: 21 | return False 22 | 23 | 24 | def make_token_oscillation_detector(sop: rasp.SOp) -> rasp.SOp: 25 | """ 26 | Detects oscillation patterns in a numeric sequence. 
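    A sketch of the logic used below (hypothetical names, plain Python):

        rising_into  = prev_token < token       # direction coming into the position
        rising_outof = token < next_token       # direction going out of the position
        oscillating  = rising_into != rising_outof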
27 | 28 | Example usage: 29 | oscillation_detector = make_token_oscillation_detector(rasp.tokens) 30 | oscillation_detector([1, 3, 1, 3, 1]) 31 | >> [True, True, True, True, True] 32 | """ 33 | prev_token = shift_by(1, sop) 34 | next_token = shift_by(-1, sop) 35 | detector_1 = rasp.SequenceMap(lambda x, y: y > x, prev_token, sop) 36 | detector_2 = rasp.SequenceMap(lambda x, y: y > x, sop, next_token) 37 | oscillation_detector = rasp.SequenceMap(lambda x, y: x != y, detector_1, detector_2) 38 | return oscillation_detector 39 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_102.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case102(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_reflect() 12 | 13 | def get_task_description(self) -> str: 14 | return "Reflects each element within a range (default is [2, 7])." 15 | # Reflect means that the values will be projected into the range, "bouncing" from the borders, until they have 16 | # traveled as far in the range as they traveled outside of it." 17 | 18 | def get_vocab(self) -> Set: 19 | return vocabs.get_int_numbers_vocab(min=-20, max=20) 20 | 21 | 22 | def reflect_into_range(max, min, x): 23 | d = max - min 24 | if x > min and x < max: 25 | return x 26 | elif x < min: 27 | delta = min - x 28 | i = (delta // d) % 2 29 | if i == 0: 30 | return min + (delta % d) 31 | else: 32 | return max - (delta % d) 33 | else: 34 | delta = x - max 35 | i = (delta // d) % 2 36 | if i == 1: 37 | return min + (delta % d) 38 | else: 39 | return max - (delta % d) 40 | 41 | 42 | def make_reflect(min_val=2, max_val=7) -> rasp.SOp: 43 | return rasp.Map(lambda x: reflect_into_range(max_val, min_val, x), rasp.tokens) 44 | -------------------------------------------------------------------------------- /tests/get_data_test.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from circuits_benchmark.benchmark.cases.case_1 import Case1 4 | from circuits_benchmark.benchmark.cases.case_3 import Case3 5 | 6 | 7 | class TestGetCleanData: 8 | def test_get_all_clean_data(self): 9 | case = Case3() 10 | data = case.get_clean_data(max_samples=None, variable_length_seqs=True) 11 | 12 | expected_total_data_len = 320 13 | assert len(data.get_inputs()) == expected_total_data_len 14 | assert case.get_total_data_len() == expected_total_data_len 15 | 16 | def test_get_partial_clean_data(self): 17 | case = Case3() 18 | data = case.get_clean_data(max_samples=10, variable_length_seqs=True) 19 | assert len(data.get_inputs()) == 10 20 | 21 | def test_case_1_should_have_balanced_inputs(self): 22 | case = Case1() 23 | data = case.get_clean_data(max_samples=100, encoded_dataset=False) 24 | outputs = data.get_targets() 25 | 26 | output_encoder = case.get_hl_model().tracr_output_encoder 27 | encoded_outputs: List[List[int]] = [output_encoder.encode(o[3:]) for o in outputs] 28 | 29 | # assert we have 20% outputs of all 1s, 20% all 0s, 60% mixed 30 | assert len([o for o in encoded_outputs if o.count(0) == len(o)]) == 15 31 | assert len([o for o in encoded_outputs if o.count(1) == len(o)]) == 15 32 | assert len([o for o in encoded_outputs if o.count(0) != len(o) and o.count(1) != 
len(o)]) == 70 33 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_67.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case67(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_multiply_by_length() 12 | 13 | def get_task_description(self) -> str: 14 | return "Multiply each element of the sequence by the length of the sequence." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | def supports_causal_masking(self) -> bool: 20 | return False 21 | 22 | 23 | def make_multiply_by_length() -> rasp.SOp: 24 | # Select all elements by using the TRUE comparison, effectively making every comparison true. 25 | all_true_selector = rasp.Select( 26 | rasp.tokens, rasp.tokens, rasp.Comparison.TRUE).named("all_true_selector") 27 | 28 | # The SelectorWidth operation counts the number of true selections, giving us the length of the sequence. 29 | length_sequence = rasp.SelectorWidth(all_true_selector).named("length_sequence") 30 | 31 | # Use SequenceMap to multiply each element of the original sequence by the length of the sequence. 32 | multiply_by_length_sequence = rasp.SequenceMap( 33 | lambda x, y: x * y, rasp.tokens, length_sequence).named("multiply_by_length_sequence") 34 | 35 | return multiply_by_length_sequence 36 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_115.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case115(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_product_with_next() 12 | 13 | def get_task_description(self) -> str: 14 | return "Multiply each element of a sequence with the next one." 
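  # For example (illustrative): [2, 3, 4] -> [6, 12, 16]; the last element is multiplied
  # by itself because the shift repeats it.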
15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | def supports_causal_masking(self) -> bool: 20 | return False 21 | 22 | 23 | def make_product_with_next() -> rasp.SOp: 24 | # Function to shift the sequence by one position 25 | def shift_by_one() -> rasp.SOp: 26 | # Define a selector for shifting sequence by one 27 | len = rasp.SelectorWidth(rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.TRUE)) 28 | shifted_selector = rasp.Select(rasp.indices, 29 | rasp.SequenceMap(lambda x, y: x if x < y - 1 else x - 1, rasp.indices, len), 30 | lambda x, y: x - 1 == y) 31 | return rasp.Aggregate(shifted_selector, rasp.tokens) 32 | 33 | shifted_sequence = shift_by_one() 34 | # Add the original sequence to the shifted sequence 35 | sum_with_next = rasp.SequenceMap(lambda x, y: x * y, rasp.tokens, shifted_sequence) 36 | 37 | return sum_with_next 38 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_99.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case99(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_sum_with_next() 12 | 13 | def get_task_description(self) -> str: 14 | return "Sum each element with the next one in the sequence." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | def supports_causal_masking(self) -> bool: 20 | return False 21 | 22 | 23 | def make_sum_with_next() -> rasp.SOp: 24 | # Function to shift the sequence by one position 25 | def shift_by_one() -> rasp.SOp: 26 | # Define a selector for shifting sequence by one 27 | len = rasp.SelectorWidth(rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.TRUE)) 28 | shifted_selector = rasp.Select(rasp.indices, 29 | rasp.SequenceMap(lambda x, y: x if x < y - 1 else x - 1, rasp.indices, len), 30 | lambda x, y: x - 1 == y) 31 | return rasp.Aggregate(shifted_selector, rasp.tokens) 32 | 33 | shifted_sequence = shift_by_one() 34 | # Add the original sequence to the shifted sequence 35 | sum_with_next = rasp.SequenceMap(lambda x, y: x + y if y != 0 else x, rasp.tokens, shifted_sequence) 36 | 37 | return sum_with_next 38 | -------------------------------------------------------------------------------- /circuits_benchmark/utils/circuit/edges_list.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | from circuits_benchmark.utils.circuit.circuit import Circuit 4 | from circuits_benchmark.utils.circuit.circuit_node import CircuitNode 5 | 6 | 7 | def edges_list_to_circuit(edges: List[Tuple[str, str]]) -> Circuit: 8 | circuit = Circuit() 9 | 10 | for edge in edges: 11 | start_node_name = edge[0].split("[")[0] 12 | end_node_name = edge[1].split("[")[0] 13 | 14 | if "[" in edge[0]: 15 | start_node_index = int(edge[0].split("[")[1].split("]")[0]) 16 | else: 17 | start_node_index = None 18 | 19 | if "[" in edge[1]: 20 | end_node_index = int(edge[1].split("[")[1].split("]")[0]) 21 | else: 22 | end_node_index = None 23 | 24 | start_node = CircuitNode(start_node_name, start_node_index) 25 | end_node = CircuitNode(end_node_name, end_node_index) 26 | 27 | circuit.add_edge(start_node, end_node) 28 | 29 | return circuit 30 | 31 | 32 | def 
circuit_to_edges_list(circuit: Circuit) -> List[Tuple[str, str]]: 33 | edges = [] 34 | 35 | for edge in circuit.edges: 36 | start_node_name = edge[0].name 37 | end_node_name = edge[1].name 38 | 39 | if edge[0].index is not None: 40 | start_node_name += f"[{edge[0].index}]" 41 | 42 | if edge[1].index is not None: 43 | end_node_name += f"[{edge[1].index}]" 44 | 45 | edges.append((start_node_name, end_node_name)) 46 | 47 | return edges 48 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_57.py: -------------------------------------------------------------------------------- 1 | from typing import Set, Sequence 2 | 3 | from circuits_benchmark.benchmark import vocabs 4 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 5 | from tracr.rasp import rasp 6 | 7 | 8 | class Case57(TracrBenchmarkCase): 9 | def get_program(self) -> rasp.SOp: 10 | return make_element_second() 11 | 12 | def get_task_description(self) -> str: 13 | return "Replaces each element with the second element of the sequence." 14 | 15 | def get_vocab(self) -> Set: 16 | return vocabs.get_ascii_letters_vocab(count=10) 17 | 18 | def supports_causal_masking(self) -> bool: 19 | return False 20 | 21 | 22 | def make_element_second() -> rasp.SOp: 23 | # Select the second element by matching indices equal to 1. 24 | second_element_selector = rasp.Select( 25 | rasp.indices, # Keys: original indices of the sequence 26 | rasp.Map(lambda x: 1, rasp.indices), # Queries: creating a sequence of 1s to match the index 1 27 | rasp.Comparison.EQ # Predicate: equality check 28 | ).named("second_element_selector") 29 | 30 | # Use Aggregate to fill the sequence with the value of the second element. 31 | second_element_sequence = rasp.Aggregate( 32 | second_element_selector, # Selector that identifies the second element 33 | rasp.tokens, # SOp: the input sequence 34 | # Note: default is None as per task rules 35 | ).named("second_element_sequence") 36 | 37 | return second_element_sequence 38 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_18.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.common_programs import make_hist, make_length 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case18(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | return make_token_frequency_classifier(rasp.tokens) 13 | 14 | def get_task_description(self) -> str: 15 | return "Classify each token based on its frequency as 'rare', 'common', or 'frequent'." 16 | 17 | def supports_causal_masking(self) -> bool: 18 | return False 19 | 20 | def get_vocab(self) -> Set: 21 | return vocabs.get_ascii_letters_vocab(count=5) 22 | 23 | 24 | def make_token_frequency_classifier(sop: rasp.SOp) -> rasp.SOp: 25 | """ 26 | Classifies each token based on its frequency as 'rare', 'common', or 'frequent'. 
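    The thresholds are relative to the sequence length (as implemented below):

        freq > total / 2   ->  "frequent"
        freq > total / 4   ->  "common"
        otherwise          ->  "rare"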
27 | 28 | Example usage: 29 | frequency_classifier = make_token_frequency_classifier(rasp.tokens) 30 | frequency_classifier(["a", "b", "a", "c", "a", "b"]) 31 | >> ["frequent", "common", "frequent", "rare", "frequent", "common"] 32 | """ 33 | frequency = make_hist() 34 | total_tokens = make_length() 35 | frequency_classification = rasp.SequenceMap( 36 | lambda freq, total: "frequent" if freq > total / 2 else ("common" if freq > total / 4 else "rare"), 37 | frequency, total_tokens) 38 | return frequency_classification 39 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_82.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case82(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_halve_second_half() 12 | 13 | def get_task_description(self) -> str: 14 | return "Halve the elements in the second half of the sequence." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | def supports_causal_masking(self) -> bool: 20 | return False 21 | 22 | 23 | def make_halve_second_half() -> rasp.SOp: 24 | # Calculate the length of the sequence and divide it by 2 to determine the start of the second half. 25 | all_true_selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE).named("all_true_selector") 26 | length = rasp.SelectorWidth(all_true_selector).named("length") 27 | half_length = rasp.Map(lambda x: x // 2, length).named("half_length") 28 | 29 | # Use Map to create a boolean sequence indicating whether an index is in the second half. 30 | in_second_half = rasp.SequenceMap(lambda idx, half: idx >= half, rasp.indices, half_length).named("in_second_half") 31 | 32 | # Halve the elements in the second half. 33 | halved_sequence = rasp.SequenceMap(lambda x, cond: x / 2 if cond else x, rasp.tokens, in_second_half).named( 34 | "halved_sequence") 35 | 36 | return halved_sequence 37 | -------------------------------------------------------------------------------- /circuits_benchmark/commands/train/compression/compression_training_utils.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | from typing import Union 3 | 4 | from transformer_lens.hook_points import HookedRootModule 5 | 6 | 7 | def parse_dimension(dim_value: Union[int, None], 8 | compression_ratio_value: Union[float, None], 9 | max_size: int, 10 | param_name: str) -> int: 11 | # Both can not be set at the same time 12 | assert dim_value is None or compression_ratio_value is None, \ 13 | f"Both {param_name} and {param_name}_compression_ratio can not be set at the same time." 14 | 15 | if dim_value is None and compression_ratio_value is None: 16 | print(f"Warning: {param_name} and {param_name}_compression_ratio are not set. " 17 | f"Using the default value for this case: {max_size}.") 18 | return max_size 19 | 20 | if dim_value is not None: 21 | size = dim_value 22 | else: 23 | size = int(max_size * compression_ratio_value) 24 | 25 | assert 0 < size <= max_size, \ 26 | f"Invalid {param_name} size: {size}. Size must be between 0 and {max_size}." 
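  # e.g. parse_dimension(None, 0.5, max_size=64, param_name="d_model") -> 32 (illustrative).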
27 | 28 | return size 29 | 30 | 31 | def parse_d_model(args: Namespace, tl_model: HookedRootModule): 32 | return parse_dimension(args.d_model, args.d_model_compression_ratio, tl_model.cfg.d_model, 'd_model') 33 | 34 | 35 | def parse_d_head(args: Namespace, tl_model: HookedRootModule): 36 | return parse_dimension(args.d_head, args.d_head_compression_ratio, tl_model.cfg.d_head, 'd_head') 37 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Inspiriation from: 2 | # https://medium.com/@albertazzir/blazing-fast-python-docker-builds-with-poetry-a78a66f5aed0 3 | # https://github.com/orgs/python-poetry/discussions/1879#discussioncomment-2255728 4 | 5 | FROM python:3.11-buster as builder 6 | 7 | ENV POETRY_NO_INTERACTION=1 \ 8 | POETRY_VIRTUALENVS_IN_PROJECT=1 \ 9 | POETRY_VIRTUALENVS_CREATE=1 \ 10 | POETRY_CACHE_DIR=/tmp/poetry_cache \ 11 | POETRY_VERSION=1.7.1 12 | 13 | # Install pipx, Poetry and dependencies for poetry install 14 | RUN pip install pipx && \ 15 | pipx install "poetry==$POETRY_VERSION" && \ 16 | apt-get update -q && \ 17 | apt-get install -y --no-install-recommends libgl1-mesa-glx graphviz graphviz-dev && \ 18 | apt-get clean && \ 19 | rm -rf /var/lib/apt/lists/* 20 | 21 | WORKDIR /circuits-benchmark 22 | 23 | COPY pyproject.toml poetry.lock ./ 24 | RUN touch README.md 25 | 26 | RUN --mount=type=cache,target=$POETRY_CACHE_DIR /root/.local/bin/poetry install --no-root 27 | 28 | FROM python:3.11-slim-buster as runtime 29 | 30 | # Install runtime dependencies 31 | RUN apt-get update -q && \ 32 | apt-get install -y --no-install-recommends libgl1-mesa-glx graphviz graphviz-dev tmux && \ 33 | apt-get clean && \ 34 | rm -rf /var/lib/apt/lists/* 35 | 36 | WORKDIR /circuits-benchmark 37 | 38 | ENV VIRTUAL_ENV=/circuits-benchmark/.venv \ 39 | PATH="/circuits-benchmark/.venv/bin:/root/.local/bin/:$PATH" 40 | 41 | COPY --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV} 42 | 43 | COPY . 
/circuits-benchmark 44 | 45 | ENTRYPOINT ["python", "main.py"] 46 | -------------------------------------------------------------------------------- /circuits_benchmark/training/training_args.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | 5 | @dataclass 6 | class TrainingArgs(): 7 | verbose: Optional[bool] = False 8 | # Wandb config 9 | wandb_project: Optional[str] = None 10 | wandb_name: Optional[str] = None 11 | 12 | # data management 13 | batch_size: Optional[int] = 256 14 | min_train_samples: Optional[int] = 20_000 15 | max_train_samples: Optional[int] = 120_000 16 | test_data_ratio: Optional[float] = 0.2 # same as train data 17 | 18 | # training time and early stopping 19 | epochs: int = 100 20 | early_stop_threshold: Optional[float] = None 21 | 22 | # AdamW optimizer config 23 | weight_decay: Optional[float] = 0 24 | beta_1: Optional[float] = 0.9 25 | beta_2: Optional[float] = 0.95 26 | gradient_clip: Optional[float] = 1 27 | 28 | # lr scheduler config 29 | lr_start: Optional[float] = 1e-2 30 | lr_factor: Optional[float] = 0.75 31 | lr_patience: Optional[int] = 10 32 | lr_threshold: Optional[float] = 1e-4 33 | 34 | # test metrics config 35 | test_accuracy_atol: Optional[float] = 5e-2 36 | 37 | # resample ablation loss config 38 | resample_ablation_test_loss: Optional[bool] = False 39 | resample_ablation_loss_epochs_gap: Optional[int] = 50 40 | resample_ablation_max_interventions: Optional[int] = 10 41 | resample_ablation_max_components: Optional[int] = 1 42 | resample_ablation_batch_size: Optional[int] = 20000 43 | resample_ablation_loss_weight: Optional[float] = 1 44 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_19.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.common_programs import shift_by 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case19(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | return make_sequential_duplicate_removal(rasp.tokens) 13 | 14 | def get_task_description(self) -> str: 15 | return "Removes consecutive duplicate tokens from a sequence." 16 | 17 | def get_vocab(self) -> Set: 18 | return vocabs.get_ascii_letters_vocab(count=3) 19 | 20 | def get_max_seq_len(self) -> int: 21 | return 15 22 | 23 | 24 | def make_sequential_duplicate_removal(sop: rasp.SOp) -> rasp.SOp: 25 | """ 26 | Removes consecutive duplicate tokens from a sequence. 27 | 28 | Example usage: 29 | duplicate_remove = make_sequential_duplicate_removal(rasp.tokens) 30 | duplicate_remove("aabbcc") 31 | >> ['a', None, 'b', None, 'c', None] 32 | 33 | Args: 34 | sop: SOp representing the sequence to process. 35 | 36 | Returns: 37 | A SOp that maps an input sequence to another sequence where immediate 38 | duplicate occurrences of any token are removed. 
39 | """ 40 | shifted_sop = shift_by(1, sop) 41 | duplicate_removal_sop = rasp.SequenceMap( 42 | lambda x, y: x if x != y else None, sop, shifted_sop).named("sequential_duplicate_removal") 43 | return duplicate_removal_sop 44 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_ioi_next_token.py: -------------------------------------------------------------------------------- 1 | from iit.model_pairs.base_model_pair import BaseModelPair 2 | from iit.model_pairs.ioi_model_pair import IOI_ModelPair 3 | from iit.utils.correspondence import Correspondence 4 | from transformer_lens import HookedTransformer 5 | from transformer_lens.hook_points import HookedRootModule 6 | 7 | from circuits_benchmark.benchmark.cases.case_ioi import CaseIOI 8 | 9 | 10 | class CaseIOI_Next_Token(CaseIOI): 11 | def get_task_description(self) -> str: 12 | """Returns the task description for the benchmark case.""" 13 | return "Indirect Object Identification (IOI) task, trained using next token prediction." 14 | 15 | def build_model_pair( 16 | self, 17 | training_args: dict | None = None, 18 | ll_model: HookedTransformer | None = None, 19 | hl_model: HookedRootModule | None = None, 20 | hl_ll_corr: Correspondence | None = None, 21 | *args, **kwargs 22 | ) -> BaseModelPair: 23 | if training_args is None: 24 | training_args = {} 25 | 26 | if ll_model is None: 27 | ll_model = self.get_ll_model() 28 | 29 | if hl_model is None: 30 | hl_model = self.get_hl_model() 31 | 32 | if hl_ll_corr is None: 33 | hl_ll_corr = self.get_correspondence() 34 | 35 | training_args["next_token"] = True 36 | 37 | return IOI_ModelPair( 38 | ll_model=ll_model, 39 | hl_model=hl_model, 40 | corr=hl_ll_corr, 41 | training_args=training_args, 42 | ) 43 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_81.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.common_programs import make_length, make_sort 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case81(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | return make_compute_median() 13 | 14 | def get_task_description(self) -> str: 15 | return "" 16 | 17 | def get_vocab(self) -> Set: 18 | return vocabs.get_int_numbers_vocab() 19 | 20 | def supports_causal_masking(self) -> bool: 21 | return False 22 | 23 | 24 | def make_compute_median() -> rasp.SOp: 25 | # Sort the sequence. 26 | sorted_sequence = make_sort(rasp.tokens, rasp.tokens, max_seq_len=100, min_key=1) 27 | 28 | # Compute the length of the sequence. 29 | length = make_length() 30 | 31 | # Compute indices for the middle elements. 32 | middle1 = rasp.Map(lambda x: (x - 1) // 2, length) 33 | middle2 = rasp.Map(lambda x: x // 2, length) 34 | 35 | # Select middle elements based on computed indices. 36 | median1 = rasp.Aggregate(rasp.Select(rasp.indices, middle1, rasp.Comparison.EQ), sorted_sequence) 37 | median2 = rasp.Aggregate(rasp.Select(rasp.indices, middle2, rasp.Comparison.EQ), sorted_sequence) 38 | 39 | # Compute the average of the two middle elements (handles both odd and even-length sequences). 
40 | median = rasp.SequenceMap(lambda x, y: (x + y) / 2, median1, median2) 41 | 42 | return median 43 | -------------------------------------------------------------------------------- /EXPERIMENTS.md: -------------------------------------------------------------------------------- 1 | The following commands can be used to replicate the experiments presented in the paper "InterpBench: Semi-Synthetic Transformers for Evaluating Mechanistic Interpretability Techniques". 2 | 3 | For training the SIIT models on Tracr tasks (where `-i 3` is the index of the task), and for training the IOI model: 4 | - `python main.py train iit -i 3 --epochs 500 --model-pair strict -iit 1 -s 0.4 -b 1` 5 | - `python main.py train ioi --include-mlp --next-token --epochs 10 --save-to-wandb` 6 | 7 | For evaluating the effect of nodes and the accuracy after ablating everything but the ground truth circuit: 8 | - `python main.py eval iit -i 3 --categorical-metric kl_div -w best` 9 | - `python main.py eval ioi --next-token --include-mlp` 10 | - `python main.py eval gt_node_realism -i 3 --mean -w best --relative 1` 11 | 12 | 13 | For running the performance evaluation of circuit discovery techniques: 14 | - `python main.py run sp --loss-type l2 -i 3 --torch-num-threads 0 --device cpu --epochs 500 --atol 0.1` 15 | - `python main.py run sp --loss-type l2 -i 3 --torch-num-threads 0 --device cpu --epochs 500 --atol 0.1 --edgewise` 16 | - `python main.py eval iit_acdc -i 2 -w 100 -t 0.0 --load-from-wandb` 17 | 18 | For running the experiments on realism: 19 | - `python main.py eval node_realism -i 3 --mean --relative 1 --algorithm acdc --tracr -t 0` 20 | - `python main.py eval node_realism -i 3 --mean --relative 1 --algorithm node_sp -t 0` 21 | - `python main.py eval node_realism -i 3 --mean --relative 1 --algorithm edge_sp -t 0` 22 | - `python main.py eval ioi_acdc --data-size 10 --max-num-epochs 1 --threshold 1000.0 --next-token --include-mlp` -------------------------------------------------------------------------------- /tests/ground_truth_circuit_test.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from acdc.TLACDCCorrespondence import TLACDCCorrespondence 4 | 5 | from circuits_benchmark.benchmark.cases.case_ioi import CaseIOI 6 | from circuits_benchmark.utils.circleci import is_running_in_circleci, get_circleci_cases_percentage 7 | from circuits_benchmark.utils.circuit.circuit_eval import build_from_acdc_correspondence 8 | from circuits_benchmark.utils.get_cases import get_cases 9 | from circuits_benchmark.utils.iit._acdc_utils import get_gt_circuit 10 | 11 | 12 | class TestGroundTruthCircuit: 13 | 14 | def test_gt_circuit_for_all_cases(self): 15 | cases = get_cases() 16 | cases = [case for case in cases if not isinstance(case, CaseIOI)] # remove ioi cases 17 | 18 | if is_running_in_circleci(): 19 | # randomly select a subset of the cases to run on CircleCI (no replacement) 20 | cases = random.sample(cases, int(get_circleci_cases_percentage() * len(cases))) 21 | 22 | for case in cases: 23 | full_corr = TLACDCCorrespondence.setup_from_model(case.get_ll_model()) 24 | full_circuit = build_from_acdc_correspondence(corr=full_corr) 25 | 26 | corr = case.get_correspondence() 27 | assert corr is not None, f"Case {case} has no correspondence" 28 | 29 | gt_circuit = get_gt_circuit( 30 | hl_ll_corr=corr, 31 | full_circuit=full_circuit, 32 | n_heads=case.get_ll_model().cfg.n_heads, 33 | case=case, 34 | ) 35 | 36 | assert len(gt_circuit.edges) > 0, f"Case {case} has no edges in
gt_circuit" 37 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_38.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.common_programs import shift_by 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case38(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | return make_token_alternation_checker(rasp.tokens) 13 | 14 | def get_task_description(self) -> str: 15 | return "Checks if tokens alternate between two types." 16 | 17 | def supports_causal_masking(self) -> bool: 18 | return False 19 | 20 | def get_vocab(self) -> Set: 21 | return vocabs.get_ascii_letters_vocab(count=3) 22 | 23 | 24 | def make_token_alternation_checker(sop: rasp.SOp) -> rasp.SOp: 25 | """ 26 | Checks if tokens alternate between two types. 27 | 28 | Example usage: 29 | alternation_checker = make_token_alternation_checker(rasp.tokens) 30 | alternation_checker(["cat", "dog", "cat", "dog"]) 31 | >> [True, True, True, True] 32 | """ 33 | prev_token = shift_by(1, sop) 34 | next_token = shift_by(-1, sop) 35 | 36 | prev_token_neq_orig = rasp.SequenceMap(lambda x, y: x != y, prev_token, sop).named("prev_token_neq_orig") 37 | next_token_neq_orig = rasp.SequenceMap(lambda x, y: x != y, sop, next_token).named("next_token_neq_orig") 38 | alternation_checker = rasp.SequenceMap(lambda x, y: x and y, 39 | prev_token_neq_orig, next_token_neq_orig).named("alternation_checker") 40 | return alternation_checker 41 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/tracr_encoded_dataset.py: -------------------------------------------------------------------------------- 1 | import torch as t 2 | from torch import Tensor 3 | from torch.utils.data import DataLoader 4 | 5 | from circuits_benchmark.benchmark.case_dataset import CaseDataset 6 | 7 | 8 | class TracrEncodedDataset(CaseDataset): 9 | """Same as TracrDataset, but with encoded inputs and outputs (i.e., tensors instead of numpy arrays).""" 10 | 11 | def __init__(self, inputs: Tensor, targets: Tensor): 12 | self.inputs = inputs 13 | self.targets = targets 14 | 15 | def __len__(self): 16 | return len(self.inputs) 17 | 18 | def __getitem__(self, idx): 19 | return self.inputs[idx], self.targets[idx] 20 | 21 | def get_inputs(self): 22 | return self.inputs 23 | 24 | def get_targets(self): 25 | return self.targets 26 | 27 | @staticmethod 28 | def collate_fn(batch, device: t.device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")): 29 | inputs = t.stack([x[0] for x in batch]) 30 | targets = t.stack([x[1] for x in batch]) 31 | return inputs.to(device=device), targets.to(device=device) 32 | 33 | def make_loader( 34 | self, 35 | batch_size: int | None = None, 36 | shuffle: bool | None = False, 37 | device: str | t.device = t.device("cuda") if t.cuda.is_available() else t.device("cpu"), 38 | num_workers: int = 0, 39 | ) -> DataLoader: 40 | return DataLoader( 41 | self, 42 | batch_size=batch_size, 43 | shuffle=shuffle, 44 | num_workers=num_workers, 45 | collate_fn=lambda x: self.collate_fn(x, device=device), 46 | ) 47 | -------------------------------------------------------------------------------- /tests/hooked_tracr_transformer_test.py: 
-------------------------------------------------------------------------------- 1 | import logging 2 | import unittest 3 | 4 | import jax 5 | from tracr.compiler import compiling 6 | from tracr.rasp import rasp 7 | 8 | from circuits_benchmark.benchmark.common_programs import make_reverse 9 | from circuits_benchmark.benchmark.vocabs import TRACR_BOS, TRACR_PAD 10 | from circuits_benchmark.transformers.hooked_tracr_transformer import HookedTracrTransformer 11 | 12 | # The default of float16 can lead to discrepancies between outputs of 13 | # the compiled model and the RASP program. 14 | jax.config.update('jax_default_matmul_precision', 'float32') 15 | logging.basicConfig(level=logging.ERROR) 16 | 17 | 18 | class HookedTracrTransformerTest(unittest.TestCase): 19 | def test_no_exception(self): 20 | # Fetch RASP program 21 | program = make_reverse(rasp.tokens) 22 | 23 | # Compile it to a transformer model 24 | tracr_output = compiling.compile_rasp_to_model( 25 | program, 26 | vocab={1, 2, 3}, 27 | max_seq_len=5, 28 | compiler_bos=TRACR_BOS, 29 | compiler_pad=TRACR_PAD, 30 | ) 31 | tracr_model = tracr_output.model 32 | 33 | input = [TRACR_BOS, 1, 2, 3, TRACR_PAD] 34 | print("Input:", input) 35 | 36 | tracr_output_decoded = tracr_model.apply(input).decoded 37 | print("Original Decoding:", tracr_output_decoded) 38 | 39 | tl_model = HookedTracrTransformer.from_tracr_model(tracr_model) 40 | tl_output_decoded = tl_model([input], return_type="decoded")[0] 41 | print("TransformerLens Replicated Decoding:", tl_output_decoded) 42 | 43 | self.assertEqual(tracr_output_decoded, tl_output_decoded) 44 | -------------------------------------------------------------------------------- /circuits_benchmark/utils/auto_circuit_utils.py: -------------------------------------------------------------------------------- 1 | from auto_circuit.types import PruneScores 2 | from auto_circuit.utils.patchable_model import PatchableModel 3 | 4 | from circuits_benchmark.utils.circuit.circuit import Circuit 5 | from circuits_benchmark.utils.circuit.circuit_node import CircuitNode 6 | 7 | 8 | def build_circuit(model: PatchableModel, 9 | attribution_scores: PruneScores, 10 | threshold: float, 11 | abs_val_threshold: bool = False) -> Circuit: 12 | """Build a circuit out of the auto_circuit output.""" 13 | circuit = Circuit() 14 | 15 | for edge in model.edges: 16 | src_node = edge.src 17 | dst_node = edge.dest 18 | score = attribution_scores[dst_node.module_name][edge.patch_idx] 19 | if abs_val_threshold: 20 | score = abs(score) 21 | if score > threshold: 22 | from_node = CircuitNode(src_node.module_name, src_node.head_idx) 23 | to_node = CircuitNode(dst_node.module_name, dst_node.head_idx) 24 | circuit.add_edge(from_node, to_node) 25 | 26 | return circuit 27 | 28 | 29 | def build_normalized_scores(attribution_scores: PruneScores) -> PruneScores: 30 | """Normalize the scores so that they all lie between 0 and 1.""" 31 | max_score = max(scores.max() for scores in attribution_scores.values()) 32 | min_score = min(scores.min() for scores in attribution_scores.values()) 33 | 34 | normalized_scores = attribution_scores.copy() 35 | for module_name, scores in normalized_scores.items(): 36 | normalized_scores[module_name] = (normalized_scores[module_name] - min_score) / (max_score - min_score) 37 | 38 | return normalized_scores 39 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = 
"circuits-benchmark" 3 | version = "0.1.0" 4 | description = "A benchmark for mechanistic discovery of circuits in Transformers" 5 | authors = ["Iván Arcuschin Moreno ", "Niels uit de Bos "] 6 | readme = "README.md" 7 | packages = [{include = "circuits_benchmark"}] 8 | 9 | [tool.poetry.dependencies] 10 | python = "^3.10" 11 | numpy = [{ version = "^1.21", python = "<3.10" }, 12 | { version = "^1.26", python = ">=3.10" }] 13 | torch = ">=2.2.0" 14 | datasets = "^2.17" 15 | transformers = "^4.37.0" 16 | tokenizers = "^0.15.0" 17 | tqdm = "^4.66" 18 | pandas = "2.1.4" 19 | wandb = "^0.16" 20 | torchtyping = "^0.1.4" 21 | huggingface-hub = "^0.24.0" 22 | cmapy = "^0.6.6" 23 | networkx = "^3.1" 24 | plotly = "^5.12.0" 25 | kaleido = "0.2.1" 26 | pygraphviz = "^1.11" 27 | transformer-lens = "1.19.0" 28 | typer = "^0.9.0" 29 | cloudpickle = "^3.0.0" 30 | argparse-dataclass = "^2.0.0" 31 | chex = "^0.1.85" 32 | dm-haiku = "^0.0.11" 33 | dataframe-image = "^0.2.3" 34 | mlcroissant = "^1.0.5" 35 | matplotlib = "3.8.2" 36 | auto-circuit = { git = "https://github.com/FlyingPumba/auto-circuit.git" } 37 | tracr = { git = "https://github.com/FlyingPumba/tracr.git" } 38 | iit = { git = "https://github.com/cybershiptrooper/iit.git" } 39 | acdc = { git = "https://github.com/FlyingPumba/Automatic-Circuit-Discovery.git" } 40 | 41 | [tool.poetry.group.dev.dependencies] 42 | pytest = "^8.2.2" 43 | pytest-cov = "^4.0.0" 44 | 45 | [build-system] 46 | requires = ["poetry-core"] 47 | build-backend = "poetry.core.masonry.api" 48 | 49 | [tool.black] 50 | line-length = 120 51 | 52 | [tool.isort] 53 | profile = "black" 54 | line_length = 120 55 | skip_gitignore = true 56 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_130.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case130(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_clip() 12 | 13 | def get_task_description(self) -> str: 14 | return "Clips each element to be within a range (make the default range [2, 7])." 15 | # "Clipping" means that values outside of the range, are turned into the lower or upper bound, whichever is closer. 
16 | 17 | def get_vocab(self) -> Set: 18 | return vocabs.get_int_numbers_vocab(min=-15, max=15) 19 | 20 | 21 | def make_clip(min_val=2, max_val=7) -> rasp.SOp: 22 | # Map all elements to min_val (to use in case of less than min_val) 23 | all_min_val = rasp.Map(lambda x: min_val, rasp.tokens) 24 | # Map all elements to max_val (to use in case of greater than max_val) 25 | all_max_val = rasp.Map(lambda x: max_val, rasp.tokens) 26 | 27 | # Compare each element to min_val and max_val 28 | less_than_min = rasp.Map(lambda x: x < min_val, rasp.tokens) 29 | greater_than_max = rasp.Map(lambda x: x > max_val, rasp.tokens) 30 | 31 | # Apply clipping: first, clip to min_val if less than min_val 32 | clip_min = rasp.SequenceMap(lambda orig, clip: clip if orig < min_val else orig, rasp.tokens, all_min_val) 33 | # Then, clip to max_val if greater than max_val 34 | clip_max = rasp.SequenceMap(lambda clipped_min, clip: clip if clipped_min > max_val else clipped_min, clip_min, 35 | all_max_val) 36 | 37 | return clip_max 38 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_98.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.common_programs import make_length 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case98(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | return make_max_element() 13 | 14 | def get_task_description(self) -> str: 15 | return "Return a sequence with the maximum element repeated." 16 | 17 | def get_vocab(self) -> Set: 18 | return vocabs.get_int_numbers_vocab() 19 | 20 | def supports_causal_masking(self) -> bool: 21 | return False 22 | 23 | 24 | def make_max_element() -> rasp.SOp: 25 | # A selector comparing each element with every other element using LEQ (less than or equal) 26 | unique_tokens = rasp.SequenceMap(lambda x, y: x + y * 0.00000000000001, rasp.tokens, rasp.indices) 27 | leq_selector = rasp.Select(unique_tokens, unique_tokens, rasp.Comparison.LEQ).named("leq_selector") 28 | # Counting the number of elements each element is less than or equal to 29 | leq_count = rasp.SelectorWidth(leq_selector).named("leq_count") 30 | # The maximum element is the one that is less or equal to all elements (count equal to sequence length) 31 | length_sop = make_length() 32 | max_element_selector = rasp.Select(leq_count, length_sop, rasp.Comparison.EQ).named("max_element_selector") 33 | # Using Aggregate to select the maximum element and broadcast it across the entire sequence 34 | max_sequence = rasp.Aggregate(max_element_selector, rasp.tokens).named("max_sequence") 35 | return max_sequence 36 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_113.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case113(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_invert_if_sorted() 12 | 13 | def get_task_description(self) -> str: 14 | return "Inverts the sequence if it is sorted in ascending order, otherwise leaves it 
unchanged." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab(max=30) 18 | 19 | def supports_causal_masking(self) -> bool: 20 | return False 21 | 22 | 23 | def make_invert_if_sorted(): 24 | shifter = rasp.Select(rasp.indices, rasp.indices, lambda x, y: x == y - 1 or (x == 0 and y == 0)) 25 | shifted = rasp.Aggregate(shifter, rasp.tokens) 26 | checks = rasp.SequenceMap(lambda x, y: 1 if x <= y else 0, shifted, rasp.tokens) 27 | zero_selector = rasp.Select(checks, rasp.Map(lambda x: 0, rasp.indices), rasp.Comparison.EQ) 28 | invert_decider = rasp.Map(lambda x: 1 if x > 0 else -1, rasp.SelectorWidth(zero_selector)) 29 | avg_idx = rasp.Map(lambda x: x / 2 - 0.5, 30 | rasp.SelectorWidth(rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.TRUE))) 31 | diff_to_avg_idx = rasp.SequenceMap(lambda x, y: x - y, rasp.indices, avg_idx) 32 | inverter = rasp.SequenceMap(lambda x, y: x + y, avg_idx, 33 | rasp.SequenceMap(lambda x, y: x * y, invert_decider, diff_to_avg_idx)) 34 | invert_selector = rasp.Select(inverter, rasp.indices, rasp.Comparison.EQ) 35 | return rasp.Aggregate(invert_selector, rasp.tokens) 36 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_59.py: -------------------------------------------------------------------------------- 1 | from typing import Set, Sequence 2 | 3 | from circuits_benchmark.benchmark import vocabs 4 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 5 | from tracr.rasp import rasp 6 | 7 | 8 | class Case59(TracrBenchmarkCase): 9 | def get_program(self) -> rasp.SOp: 10 | return make_sorting() 11 | 12 | def get_task_description(self) -> str: 13 | return "Sorts the sequence." 14 | 15 | def get_vocab(self) -> Set: 16 | return vocabs.get_int_numbers_vocab() 17 | 18 | def supports_causal_masking(self) -> bool: 19 | return False 20 | 21 | 22 | def make_sorting() -> rasp.SOp: 23 | # Create unique keys by combining each element with its index 24 | # This ensures that even duplicate values can be sorted correctly 25 | unique_keys = rasp.SequenceMap(lambda x, i: x + i * 0.00001, rasp.tokens, rasp.indices).named("unique_keys") 26 | 27 | # Create a selector that identifies where each unique key is less than every other unique key 28 | lt_selector = rasp.Select(unique_keys, unique_keys, rasp.Comparison.LT).named("lt_selector") 29 | 30 | # Count the number of elements that each unique key is less than 31 | # This count determines the sorted position of each element in the output sequence 32 | sorted_position = rasp.SelectorWidth(lt_selector).named("sorted_position") 33 | 34 | # Place each element into its sorted position by matching each element's sort position with the output sequence's indices 35 | sorted_sequence_selector = rasp.Select(sorted_position, rasp.indices, rasp.Comparison.EQ).named( 36 | "sorted_sequence_selector") 37 | sorted_sequence = rasp.Aggregate(sorted_sequence_selector, rasp.tokens).named("sorted_sequence") 38 | 39 | return sorted_sequence 40 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_45.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case45(TracrBenchmarkCase): 10 | def get_program(self) -> 
rasp.SOp: 11 | return make_double_first_half() 12 | 13 | def get_task_description(self) -> str: 14 | return "Doubles the first half of the sequence" 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | def supports_causal_masking(self) -> bool: 20 | return False 21 | 22 | 23 | def make_double_first_half() -> rasp.SOp: 24 | # Calculate the length of the sequence and store it in a constant sequence. 25 | length = rasp.SelectorWidth(rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)).named("length") 26 | 27 | # Create a Map operation that turns indices into 1 if they are in the first half, 0 otherwise. 28 | # Note: We use a trick here by utilizing a SequenceMap that compares indices against half of the length sequence, but as we cannot perform division directly on SOps, we prepare the length beforehand. 29 | first_half_selector = rasp.SequenceMap(lambda idx, length: 1 if idx < length / 2 else 0, rasp.indices, 30 | length).named("first_half_selector") 31 | 32 | # Apply doubling conditionally: Multiply each element by (1 or 2) based on the first_half_selector. 33 | # This step combines the original sequence with the selector sequence to apply the doubling only to the first half. 34 | double_first_half = rasp.SequenceMap(lambda x, sel: x * (1 + sel), rasp.tokens, first_half_selector).named( 35 | "double_first_half") 36 | 37 | return double_first_half 38 | -------------------------------------------------------------------------------- /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | 3 | orbs: 4 | python: circleci/python@2 5 | 6 | jobs: 7 | test: 8 | docker: 9 | - image: cimg/python:3.11.6 10 | steps: 11 | - checkout 12 | - run: 13 | name: Add GitHub as known host 14 | command: mkdir -p ~/.ssh && ssh-keyscan -t rsa github.com >> ~/.ssh/known_hosts 15 | - run: 16 | name: Install pygraphviz 17 | command: sudo apt-get update && sudo apt-get install -y graphviz libgraphviz-dev 18 | - python/install-packages: 19 | pkg-manager: poetry 20 | - run: 21 | name: Run tests 22 | command: poetry run pytest --durations=0 --junitxml=junit.xml || ((($? == 5)) && echo 'Did not find any tests to run.') 23 | - store_test_results: 24 | path: junit.xml 25 | deploy: 26 | docker: 27 | - image: cimg/python:3.11.6 28 | steps: 29 | - checkout 30 | - setup_remote_docker: 31 | docker_layer_caching: true 32 | - run: 33 | name: Add GitHub as known host 34 | command: mkdir -p ~/.ssh && ssh-keyscan -t rsa github.com >> ~/.ssh/known_hosts 35 | - run: 36 | name: Build Docker image 37 | command: docker build . 
-t iarcuschin/circuits-benchmark 38 | - run: 39 | name: Publish Docker image 40 | command: | 41 | echo "${DOCKERHUB_PASS}" | docker login --username "${DOCKERHUB_USERNAME}" --password-stdin 42 | docker push iarcuschin/circuits-benchmark 43 | 44 | workflows: 45 | test-and-deploy: 46 | jobs: 47 | - test: 48 | filters: 49 | branches: 50 | ignore: /.*experiments$/ 51 | - deploy: 52 | filters: 53 | branches: 54 | only: 55 | - main 56 | requires: 57 | - test -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from circuits_benchmark.commands.build_main_parser import build_main_parser 4 | from circuits_benchmark.commands.train import train 5 | from circuits_benchmark.utils.project_paths import get_default_output_dir 6 | 7 | 8 | def setup_iit_models(): 9 | # detect if SIIT and Natural model for Case 3 are available, and train them if not 10 | output_dir = get_default_output_dir() 11 | 12 | natural_model_path = f"{output_dir}/ll_models/3/ll_model_100.pth" 13 | if not os.path.exists(natural_model_path): 14 | # train natural model 15 | args, _ = build_main_parser().parse_known_args(["train", 16 | "iit", 17 | "-i=3", 18 | "--epochs=0", 19 | "-s=0", 20 | "-iit=0", 21 | "--num-samples=10", 22 | "--device=cpu"]) 23 | train.run(args) 24 | assert os.path.exists(natural_model_path) 25 | 26 | siit_model_path = f"{output_dir}/ll_models/3/ll_model_510.pth" 27 | if not os.path.exists(siit_model_path): 28 | # train SIIT model 29 | args, _ = build_main_parser().parse_known_args(["train", 30 | "iit", 31 | "-i=3", 32 | "--epochs=0", 33 | "--num-samples=10", 34 | "--device=cpu"]) 35 | train.run(args) 36 | assert os.path.exists(siit_model_path) -------------------------------------------------------------------------------- /circuits_benchmark/commands/algorithms/run_algorithm.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | 3 | from circuits_benchmark.commands.algorithms import acdc, legacy_acdc, eap, sp 4 | from circuits_benchmark.utils.get_cases import get_cases 5 | from circuits_benchmark.utils.ll_model_loader.ll_model_loader_factory import get_ll_model_loader_from_args 6 | 7 | 8 | def setup_args_parser(subparsers): 9 | run_parser = subparsers.add_parser("run") 10 | run_subparsers = run_parser.add_subparsers(dest="algorithm") 11 | run_subparsers.required = True 12 | 13 | # Setup arguments for each algorithm 14 | legacy_acdc.LegacyACDCRunner.setup_subparser(run_subparsers) 15 | acdc.ACDCRunner.setup_subparser(run_subparsers) 16 | sp.SPRunner.setup_subparser(run_subparsers) 17 | eap.EAPRunner.setup_subparser(run_subparsers) 18 | 19 | 20 | def run(args): 21 | for case in get_cases(args): 22 | print(f"\nRunning {args.algorithm} on {case}") 23 | 24 | ll_model_loader = get_ll_model_loader_from_args(case, args) 25 | 26 | try: 27 | if args.algorithm == "legacy_acdc": 28 | legacy_acdc.LegacyACDCRunner(case, args=args).run_using_model_loader(ll_model_loader) 29 | elif args.algorithm == "acdc": 30 | acdc.ACDCRunner(case, args=args).run_using_model_loader(ll_model_loader) 31 | elif args.algorithm == "sp": 32 | sp.SPRunner(case, args=args).run_using_model_loader(ll_model_loader) 33 | elif args.algorithm == "eap": 34 | eap.EAPRunner(case, args=args).run_using_model_loader(ll_model_loader) 35 | else: 36 | raise ValueError(f"Unknown algorithm: {args.algorithm}") 37 | except Exception as e: 38 | print(f" >>> Failed to run 
{args.algorithm} algorithm on case {case}:") 39 | traceback.print_exc() 40 | continue 41 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_15.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.program_evaluation_type import causal_and_regular 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case15(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | return make_nary_sequencemap(lambda x, y, z: x + y - z, rasp.tokens, rasp.tokens, rasp.indices) 13 | 14 | def get_task_description(self) -> str: 15 | return "Returns each token multiplied by two and subtracted by its index." 16 | 17 | def get_vocab(self) -> Set: 18 | return vocabs.get_int_digits_vocab(count=5) 19 | 20 | def get_max_seq_len(self) -> int: 21 | return 5 22 | 23 | 24 | @causal_and_regular 25 | def make_nary_sequencemap(f, *sops): 26 | """Returns an SOp that simulates an n-ary SequenceMap. 27 | 28 | Uses multiple binary SequenceMaps to convert n SOps x_1, x_2, ..., x_n 29 | into a single SOp arguments that takes n-tuples as value. The n-ary sequence 30 | map implementing f is then a Map on this resulting SOp. 31 | 32 | Note that the intermediate variables representing tuples of varying length 33 | will be encoded categorically, and can become very high-dimensional. So, 34 | using this function might lead to very large compiled models. 35 | 36 | Args: 37 | f: Function with n arguments. 38 | *sops: Sequence of SOps, one for each argument of f. 39 | """ 40 | values, *sops = sops 41 | for sop in sops: 42 | # x is a single entry in the first iteration but a tuple in later iterations 43 | values = rasp.SequenceMap( 44 | lambda x, y: (*x, y) if isinstance(x, tuple) else (x, y), values, sop) 45 | return rasp.Map(lambda args: f(*args), values) 46 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_58.py: -------------------------------------------------------------------------------- 1 | from typing import Set, Sequence 2 | 3 | from circuits_benchmark.benchmark import vocabs 4 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 5 | from tracr.rasp import rasp 6 | 7 | 8 | class Case58(TracrBenchmarkCase): 9 | def get_program(self) -> rasp.SOp: 10 | return make_mirror_first_half() 11 | 12 | def get_task_description(self) -> str: 13 | return "Mirrors the first half of the sequence to the second half." 
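# For example, following make_mirror_first_half below, an input such as [1, 2, 3, 4] maps to
# [1, 2, 2, 1]: the second half becomes a reversed copy of the first half.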
14 | 15 | def get_vocab(self) -> Set: 16 | return vocabs.get_int_numbers_vocab() 17 | 18 | def supports_causal_masking(self) -> bool: 19 | return False 20 | 21 | 22 | def make_mirror_first_half() -> rasp.SOp: 23 | # Create a selector for all elements (used to calculate the sequence length) 24 | all_true_selector = rasp.Select( 25 | rasp.tokens, rasp.tokens, rasp.Comparison.TRUE).named("all_true_selector") 26 | length = rasp.SelectorWidth(all_true_selector).named("length") 27 | 28 | # Creating a selector that selects the first half of the sequence 29 | first_half_selector = rasp.Select( 30 | rasp.indices, 31 | rasp.Map(lambda x: x // 2, length), 32 | rasp.Comparison.LT 33 | ).named("first_half_selector") 34 | 35 | # Creating a selector for reversing the indices for the second half of the sequence 36 | mirror_selector = rasp.Select( 37 | rasp.indices, 38 | rasp.SequenceMap(lambda x, l: l - 1 - x if x >= l // 2 else x, rasp.indices, length), 39 | rasp.Comparison.EQ 40 | ).named("mirror_selector") 41 | 42 | # Aggregate using the mirror selector to mirror the first half onto the second half 43 | mirrored_sequence = rasp.Aggregate( 44 | mirror_selector, 45 | rasp.tokens, 46 | default=None 47 | ).named("mirrored_sequence") 48 | 49 | return mirrored_sequence 50 | -------------------------------------------------------------------------------- /circuits_benchmark/utils/find_all_subclasses.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import pkgutil 3 | 4 | 5 | def import_submodules(package, recursive: bool = True): 6 | """ Import all submodules of a module, recursively, including subpackages 7 | 8 | :param package: package (name or actual module) 9 | :type package: str | module 10 | :rtype: dict[str, types.ModuleType] 11 | """ 12 | if isinstance(package, str): 13 | package = importlib.import_module(package) 14 | results = {} 15 | for loader, name, is_pkg in pkgutil.walk_packages(package.__path__): 16 | full_name = package.__name__ + '.' + name 17 | try: 18 | results[full_name] = importlib.import_module(full_name) 19 | except ModuleNotFoundError: 20 | continue 21 | if recursive and is_pkg: 22 | results.update(import_submodules(full_name)) 23 | return results 24 | 25 | 26 | def find_all_subclasses_in_package(base_class, package_name): 27 | """ Find all subclasses of a given class in a package. 28 | We first need to import all submodules of the package, then we can use the `__subclasses__` method of the base class. 29 | Otherwise, the subclasses might not be loaded yet, and we will miss them. 30 | """ 31 | import_submodules(package_name) 32 | return base_class.__subclasses__() 33 | 34 | 35 | def find_all_transitive_subclasses_in_package(base_class, package_name): 36 | """ Find all transitive subclasses of a given class in a package. 
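It repeatedly expands the __subclasses__ of the classes found so far until a fixed point is reached.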
37 | """ 38 | subclasses = set(find_all_subclasses_in_package(base_class, package_name)) 39 | 40 | last_subclasses = None 41 | new_subclasses = subclasses.copy() 42 | while new_subclasses != last_subclasses: 43 | last_subclasses = new_subclasses.copy() 44 | for cls in last_subclasses: 45 | new_subclasses.update(cls.__subclasses__()) 46 | 47 | return new_subclasses 48 | -------------------------------------------------------------------------------- /circuits_benchmark/commands/evaluation/evaluation.py: -------------------------------------------------------------------------------- 1 | import random 2 | import traceback 3 | 4 | import numpy as np 5 | import torch as t 6 | 7 | from circuits_benchmark.commands.evaluation.iit import iit_eval 8 | from circuits_benchmark.commands.evaluation.realism import node_wise_ablation, gt_circuit_node_wise_ablation 9 | from circuits_benchmark.utils.get_cases import get_cases 10 | 11 | 12 | def setup_args_parser(subparsers): 13 | run_parser = subparsers.add_parser("eval") 14 | run_subparsers = run_parser.add_subparsers(dest="type") 15 | run_subparsers.required = True 16 | 17 | # Setup arguments for each evaluation type 18 | iit_eval.setup_args_parser(run_subparsers) 19 | node_wise_ablation.setup_args_parser(run_subparsers) 20 | gt_circuit_node_wise_ablation.setup_args_parser(run_subparsers) 21 | 22 | 23 | def run(args): 24 | evaluation_type = args.type 25 | for case in get_cases(args): 26 | print(f"\nRunning evaluation {evaluation_type} on {case}") 27 | 28 | # Set numpy, torch and ptyhon seed 29 | seed = args.seed 30 | assert seed is not None, "Seed is always required" 31 | np.random.seed(args.seed) 32 | t.manual_seed(seed) 33 | random.seed(seed) 34 | 35 | try: 36 | if evaluation_type == "iit": 37 | iit_eval.run_iit_eval(case, args) 38 | elif evaluation_type == "node_realism": 39 | node_wise_ablation.run_nodewise_ablation(case, args) 40 | elif evaluation_type == "gt_node_realism": 41 | gt_circuit_node_wise_ablation.run_nodewise_ablation(case, args) 42 | else: 43 | raise ValueError(f"Unknown evaluation: {evaluation_type}") 44 | except Exception as e: 45 | print(f" >>> Failed to run {evaluation_type} evaluation on case {case}:") 46 | traceback.print_exc() 47 | continue 48 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_13.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.common_programs import shift_by 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case13(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | return make_token_trend_analysis(rasp.tokens) 13 | 14 | def get_task_description(self) -> str: 15 | return "Analyzes the trend (increasing, decreasing, constant) of numeric tokens." 16 | 17 | def supports_causal_masking(self) -> bool: 18 | return False 19 | 20 | def get_vocab(self) -> Set: 21 | return vocabs.get_int_digits_vocab(count=3) 22 | 23 | def supports_causal_masking(self) -> bool: 24 | return False 25 | 26 | 27 | def make_token_trend_analysis(sop: rasp.SOp) -> rasp.SOp: 28 | """ 29 | Analyzes the trend (increasing, decreasing, constant) of numeric tokens. 
30 | 31 | Example usage: 32 | trend_analysis = make_token_trend_analysis(rasp.tokens) 33 | trend_analysis([1, 2, 3, 3, 2, 1]) 34 | >> ["increasing", "increasing", "constant", "decreasing", "decreasing", None] 35 | """ 36 | next_token = shift_by(-1, sop) # [2, 3, 3, 2, 1, None] 37 | 38 | def second_part_fn(curr, next): 39 | if curr < next: 40 | return "increasing" 41 | elif curr > next: 42 | return "decreasing" 43 | else: 44 | return "constant" 45 | 46 | # Compare the current token with the next token to produce the trend analysis. 47 | # Curr: [1, 2, 3, 3, 2, 1] 48 | # Next: [2, 3, 3, 2, 1, None] 49 | # Result: ["increasing", "increasing", "constant", "decreasing", "decreasing", None] 50 | trend_analysis = rasp.SequenceMap(second_part_fn, sop, next_token) 51 | 52 | return trend_analysis 53 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_96.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case96(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_remove_duplicates() 12 | 13 | def get_task_description(self) -> str: 14 | return "Set duplicates to 0, keep the first occurrences unchanged." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab(min=1) 18 | 19 | def supports_causal_masking(self) -> bool: 20 | return False 21 | 22 | 23 | def make_remove_duplicates() -> rasp.SOp: 24 | # Compare each element with every other element for equality 25 | eq_selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ).named("eq_selector") 26 | 27 | # Count the number of elements that each element is equal to (including itself) 28 | eq_count = rasp.SelectorWidth(eq_selector).named("eq_count") 29 | 30 | # Create an index sequence to ensure that for every element, counts are considered only up to its position (i.e., ignore future duplicates) 31 | index = rasp.Map(lambda x: x + 1, rasp.indices).named("index") 32 | adjusted_count = rasp.SequenceMap(lambda count, idx: count if count <= idx else 0, eq_count, index).named( 33 | "adjusted_count") 34 | 35 | # Identify the first occurrences and duplicates. First occurrences will have an adjusted count of 1, replace others with 0. 
36 | first_or_zero = rasp.Map(lambda x: 1 if x == 1 else 0, adjusted_count).named("first_or_zero") 37 | 38 | # Replace duplicates (indicated by 0) in the original sequence with 0, keep the first occurrences unchanged 39 | remove_duplicates = rasp.SequenceMap(lambda original, flag: original if flag == 1 else 0, rasp.tokens, 40 | first_or_zero).named("remove_duplicates") 41 | 42 | return remove_duplicates 43 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_88.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case88(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_average_first_last() 12 | 13 | def get_task_description(self) -> str: 14 | return "Calculate the average of the first and last elements of a sequence." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab() 18 | 19 | def supports_causal_masking(self) -> bool: 20 | return False 21 | 22 | 23 | def make_average_first_last() -> rasp.SOp: 24 | # Assuming the first element is at index 0, which is always true 25 | # Selector for the first element 26 | first_elem_selector = rasp.Select(rasp.indices, rasp.indices, lambda x, y: x == 0).named("first_elem_selector") 27 | first_elem = rasp.Aggregate(first_elem_selector, rasp.tokens, default=None).named("first_elem") 28 | 29 | # Assuming the last element can be simulated by reversing the sequence and then selecting the first element 30 | # Reverse the sequence 31 | reverser = rasp.SequenceMap(lambda x, y: y - x - 1, rasp.indices, 32 | rasp.SelectorWidth(rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.TRUE))) 33 | reverse_selector = rasp.Select(reverser, rasp.indices, rasp.Comparison.EQ) 34 | reversed_sequence = rasp.Aggregate(reverse_selector, rasp.tokens).named("reversed_sequence") 35 | # Select the first element which is effectively the last element of the original sequence 36 | last_elem = rasp.Aggregate(first_elem_selector, reversed_sequence, default=None).named("last_elem") 37 | 38 | # Calculate the average of the first and last elements 39 | average_first_last = rasp.SequenceMap(lambda x, y: (x + y) / 2.0, first_elem, last_elem).named("average_first_last") 40 | 41 | return average_first_last 42 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_127.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case127(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_element_divide() 12 | 13 | def get_task_description(self) -> str: 14 | return "Divides each element by the division of the first two elements." 15 | # If either the first or second element are zero, or if the sequence has fewer than two entries, you should just 16 | # return the original sequence." 
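# For example, following make_element_divide below (which, as implemented, uses second / first as
# the divisor), an input such as [4, 2, 8] yields a divisor of 2 / 4 = 0.5 and output [8.0, 4.0, 16.0].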
17 | 18 | def get_vocab(self) -> Set: 19 | return vocabs.get_int_numbers_vocab() 20 | 21 | def supports_causal_masking(self) -> bool: 22 | return False 23 | 24 | 25 | def make_element_divide() -> rasp.SOp: 26 | # Step 1: Select the first element 27 | first_elem_selector = rasp.Select( 28 | rasp.indices, 29 | rasp.Map(lambda x: 0, rasp.indices), 30 | rasp.Comparison.EQ 31 | ).named("first_elem_selector") 32 | first_elem = rasp.Aggregate(first_elem_selector, rasp.tokens).named("first_elem") 33 | len = rasp.SelectorWidth(rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.TRUE)) 34 | # Step 2: Select the second element 35 | second_elem_selector = rasp.Select( 36 | rasp.indices, 37 | rasp.Map(lambda x: 1 if x > 1 else 0, len), 38 | rasp.Comparison.EQ 39 | ).named("second_elem_selector") 40 | second_elem = rasp.Aggregate(second_elem_selector, rasp.tokens).named("second_elem") 41 | 42 | # Step 3: Divide the second element by the first to get the divisor 43 | divisor = rasp.SequenceMap(lambda x, y: y / x if x != 0 and y != 0 else 1, first_elem, second_elem).named("divisor") 44 | 45 | # Step 4: Divide each element of the input sequence by the divisor 46 | result_sequence = rasp.SequenceMap(lambda x, y: x / y, rasp.tokens, divisor).named("result_sequence") 47 | 48 | return result_sequence 49 | -------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_119.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 7 | 8 | 9 | class Case119(TracrBenchmarkCase): 10 | def get_program(self) -> rasp.SOp: 11 | return make_polynomial() 12 | 13 | def get_task_description(self) -> str: 14 | return "Evaluates a polynomial with sequence elements as parameters. The x is represented by the first entry, the rest are parameters." 15 | 16 | def get_vocab(self) -> Set: 17 | return vocabs.get_int_numbers_vocab(max=6) 18 | 19 | def get_max_seq_len(self) -> int: 20 | return 5 21 | 22 | def supports_causal_masking(self) -> bool: 23 | return False 24 | 25 | 26 | def make_polynomial(degree=2) -> rasp.SOp: 27 | """ 28 | Computes the result of a polynomial. The first element of the sequence is treated as the base of the polynomial, 29 | while the following ones are the weights. 30 | Example: input [3,2,3,1,4] is treated as the polynomial 2*3^3 + 3*3^2 + 1*3 + 4 = 88 31 | Note that the degree parameter should correspond to the length of the input sequence minus two. 
32 | """ 33 | 34 | aggregator = rasp.tokens - rasp.tokens 35 | first_element_selector = rasp.Select(rasp.indices, aggregator, rasp.Comparison.EQ).named("first_element_selector") 36 | base = rasp.Aggregate(first_element_selector, rasp.tokens) 37 | 38 | # Function to create selectors and weights 39 | def create_elem(i): 40 | selector = rasp.Select(rasp.indices, rasp.Map(lambda x: i, rasp.tokens), rasp.Comparison.EQ).named( 41 | f"selector_{i}") 42 | weight = rasp.Aggregate(selector, rasp.tokens, default=None).named(f"weight_{i}") 43 | elem = rasp.SequenceMap(lambda x, y: y * (x ** (degree + 1 - i)), base, weight) 44 | return rasp.SequenceMap(lambda x, y: x + y, aggregator, elem) 45 | 46 | # Applying the function for each term 47 | for i in range(1, degree + 2): 48 | aggregator = create_elem(i) 49 | 50 | return aggregator 51 | -------------------------------------------------------------------------------- /circuits_benchmark/commands/train/train.py: -------------------------------------------------------------------------------- 1 | import random 2 | import traceback 3 | 4 | import numpy as np 5 | import torch as t 6 | 7 | from circuits_benchmark.commands.train.compression import linear_compression, \ 8 | non_linear_compression 9 | from circuits_benchmark.commands.train.compression.linear_compression import train_linear_compression 10 | from circuits_benchmark.commands.train.compression.non_linear_compression import train_non_linear_compression 11 | from circuits_benchmark.commands.train.iit import iit_train 12 | from circuits_benchmark.utils.get_cases import get_cases 13 | 14 | 15 | def setup_args_parser(subparsers): 16 | run_parser = subparsers.add_parser("train") 17 | run_subparsers = run_parser.add_subparsers(dest="type") 18 | run_subparsers.required = True 19 | 20 | # Setup arguments for each algorithm 21 | linear_compression.setup_args_parser(run_subparsers) 22 | non_linear_compression.setup_args_parser(run_subparsers) 23 | iit_train.setup_args_parser(run_subparsers) 24 | 25 | 26 | def run(args): 27 | training_type = args.type 28 | 29 | cases = get_cases(args) 30 | assert len(cases) > 0, "No cases found" 31 | 32 | for case in cases: 33 | print(f"\nRunning training {training_type} on {case}") 34 | 35 | # Set numpy, torch and ptyhon seed 36 | seed = args.seed 37 | assert seed is not None, "Seed is always required" 38 | np.random.seed(args.seed) 39 | t.manual_seed(seed) 40 | random.seed(seed) 41 | 42 | try: 43 | if training_type == "linear-compression": 44 | train_linear_compression(case, args) 45 | elif training_type == "non-linear-compression": 46 | train_non_linear_compression(case, args) 47 | elif training_type == "iit": 48 | iit_train.run_iit_train(case, args) 49 | else: 50 | raise ValueError(f"Unknown training: {training_type}") 51 | except Exception as e: 52 | print(f" >>> Failed to run {training_type} training on case {case}:") 53 | traceback.print_exc() 54 | continue 55 | -------------------------------------------------------------------------------- /circuits_benchmark/utils/iit/iit_hl_model.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch 4 | from transformer_lens import HookedTransformer 5 | 6 | 7 | class IITHLModel: 8 | """A wrapper class to make tracr models compatible with IITModelPair""" 9 | 10 | def __init__(self, hl_model: HookedTransformer, eval_mode: bool = False): 11 | self.hl_model = hl_model 12 | self.hl_model.to(hl_model.device) 13 | self.eval_mode = eval_mode 14 | 15 | for p in 
hl_model.parameters(): 16 | p.requires_grad = False 17 | p.to(hl_model.device) 18 | 19 | def __getattr__(self, name: str): 20 | if hasattr(self.hl_model, name): 21 | return getattr(self.hl_model, name) 22 | else: 23 | raise AttributeError( 24 | f"'{type(self).__name__}' object has no attribute '{name}'" 25 | ) 26 | 27 | def create_hl_output(self, y): 28 | if self.hl_model.is_categorical(): 29 | y = y.argmax(dim=-1) 30 | if self.eval_mode: 31 | y = torch.nn.functional.one_hot(y, num_classes=self.hl_model.cfg.d_vocab_out) 32 | return y 33 | 34 | def get_correct_input(self, input): 35 | if isinstance(input, tuple) or isinstance(input, list): 36 | return input[0] 37 | elif isinstance(input, torch.Tensor): 38 | return input 39 | else: 40 | raise ValueError(f"Invalid input type: {type(input)}") 41 | 42 | def forward(self, input): 43 | x = self.get_correct_input(input) 44 | out = self.hl_model(x) 45 | return self.create_hl_output(out) 46 | 47 | def run_with_hooks(self, input, *args, **kwargs): 48 | x = self.get_correct_input(input) 49 | out = self.hl_model.run_with_hooks(x, *args, **kwargs) 50 | return self.create_hl_output(out) 51 | 52 | def run_with_cache(self, input): 53 | x = input[0] 54 | out, cache = self.hl_model.run_with_cache(x) 55 | return self.create_hl_output(out), cache 56 | 57 | def __call__(self, *args: Any, **kwds: Any) -> Any: 58 | return self.forward(*args, **kwds) 59 | -------------------------------------------------------------------------------- /circuits_benchmark/utils/ll_model_loader/ll_model_loader_factory.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | 3 | from circuits_benchmark.benchmark.benchmark_case import BenchmarkCase 4 | from circuits_benchmark.utils.ll_model_loader.ground_truth_model_loader import GroundTruthModelLoader 5 | from circuits_benchmark.utils.ll_model_loader.interp_bench_model_loader import InterpBenchModelLoader 6 | from circuits_benchmark.utils.ll_model_loader.ll_model_loader import LLModelLoader 7 | from circuits_benchmark.utils.ll_model_loader.natural_model_loader import NaturalModelLoader 8 | from circuits_benchmark.utils.ll_model_loader.siit_model_loader import SIITModelLoader 9 | 10 | 11 | def get_ll_model_loader_from_args(case: BenchmarkCase, args: Namespace) -> LLModelLoader: 12 | return get_ll_model_loader( 13 | case, 14 | args.natural, 15 | args.tracr, 16 | args.interp_bench, 17 | args.siit_weights, 18 | args.load_from_wandb, 19 | args.load_wandb_project, 20 | args.load_wandb_name, 21 | ) 22 | 23 | 24 | def get_ll_model_loader( 25 | case: BenchmarkCase, 26 | natural: bool = False, 27 | tracr: bool = False, 28 | interp_bench: bool = False, 29 | siit_weights: str | None = None, 30 | load_from_wandb: bool = False, 31 | wandb_project: str | None = None, 32 | wandb_name: str | None = None, 33 | ) -> LLModelLoader: 34 | assert ( 35 | not (natural and tracr) 36 | and not (natural and interp_bench) 37 | and not (tracr and interp_bench) 38 | ), "Only one of natural, tracr, interp_bench can be set" 39 | 40 | if natural: 41 | return NaturalModelLoader( 42 | case, 43 | load_from_wandb=load_from_wandb, 44 | wandb_project=wandb_project, 45 | wandb_name=wandb_name 46 | ) 47 | 48 | if tracr: 49 | return GroundTruthModelLoader(case) 50 | 51 | if interp_bench: 52 | return InterpBenchModelLoader(case) 53 | 54 | return SIITModelLoader( 55 | case, 56 | weights=siit_weights, 57 | load_from_wandb=load_from_wandb, 58 | wandb_project=wandb_project, 59 | wandb_name=wandb_name 60 | ) 61 | 
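A minimal usage sketch for the factory above (illustrative only: the choice of case and the tracr flag are assumptions made for the example, while get_cases, get_ll_model_loader and the loader classes are the ones shown in this repository):

from circuits_benchmark.utils.get_cases import get_cases
from circuits_benchmark.utils.ll_model_loader.ll_model_loader_factory import get_ll_model_loader

# Pick any benchmark case and build a loader for it. With tracr=True the factory returns a
# GroundTruthModelLoader; natural=True or interp_bench=True select the other loaders, and the
# factory asserts that at most one of these flags is set.
case = get_cases()[0]
loader = get_ll_model_loader(case, tracr=True)

The resulting loader is what the algorithm runners consume via run_using_model_loader, as in run_algorithm.py above.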
-------------------------------------------------------------------------------- /circuits_benchmark/benchmark/cases/case_126.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from tracr.rasp import rasp 4 | 5 | from circuits_benchmark.benchmark import vocabs 6 | from circuits_benchmark.benchmark.common_programs import make_length 7 | from circuits_benchmark.benchmark.tracr_benchmark_case import TracrBenchmarkCase 8 | 9 | 10 | class Case126(TracrBenchmarkCase): 11 | def get_program(self) -> rasp.SOp: 12 | return make_set_to_median() 13 | 14 | def get_task_description(self) -> str: 15 | return "Replaces each element with the median of all elements." 16 | 17 | def get_vocab(self) -> Set: 18 | return vocabs.get_int_numbers_vocab() 19 | 20 | def supports_causal_masking(self) -> bool: 21 | return False 22 | 23 | 24 | def make_sort() -> rasp.SOp: 25 | unique_tokens = rasp.SequenceMap(lambda x, y: x + 0.0001 * y, rasp.tokens, rasp.indices) 26 | smaller = rasp.Select(unique_tokens, unique_tokens, rasp.Comparison.LT).named("smaller") 27 | target_pos = rasp.SelectorWidth(smaller).named("target_pos") 28 | sel_new = rasp.Select(target_pos, rasp.indices, rasp.Comparison.EQ) 29 | return rasp.Aggregate(sel_new, rasp.tokens).named("sort") 30 | 31 | 32 | def make_set_to_median() -> rasp.SOp: 33 | sorted_sequence = make_sort() 34 | 35 | length = make_length() 36 | # Assuming a maximum sequence length to pre-calculate possible indices for median 37 | mid_index1 = rasp.Map(lambda x: (x - 1) // 2, length).named("mid_index1") 38 | mid_index2 = rasp.Map(lambda x: x // 2, length).named("mid_index2") 39 | 40 | # Selectors for extracting potential median values 41 | median_selector1 = rasp.Select(rasp.indices, mid_index1, rasp.Comparison.EQ).named("median_selector1") 42 | median_selector2 = rasp.Select(rasp.indices, mid_index2, rasp.Comparison.EQ).named("median_selector2") 43 | 44 | # Extracting potential median values 45 | potential_median1 = rasp.Aggregate(median_selector1, sorted_sequence).named("potential_median1") 46 | potential_median2 = rasp.Aggregate(median_selector2, sorted_sequence).named("potential_median2") 47 | 48 | # Calculating the average of the two potential medians (handles both odd and even length cases) 49 | median = rasp.SequenceMap(lambda x, y: (x + y) / 2, potential_median1, potential_median2).named("median") 50 | 51 | return median 52 | -------------------------------------------------------------------------------- /circuits_benchmark/utils/iit/ll_cfg.py: -------------------------------------------------------------------------------- 1 | compression_ratio_map = { 2 | "5": 1.8, 3 | "18": 1.6, 4 | "21": 1.3, 5 | "25": 1.5, 6 | "34": 1.5, 7 | "35": 3, 8 | "36": 2.5, 9 | "37": 2.2, 10 | "9": 2, 11 | "7": 1.5, 12 | "23": 1.5, 13 | "24": 1.2, 14 | "6": 1.2, 15 | "default": 2, 16 | } 17 | 18 | cases_with_resid_compression = [ 19 | "5", 20 | "18", 21 | "21", 22 | "25", 23 | "26", 24 | "29", 25 | "34", 26 | "35", 27 | "36", 28 | "37", 29 | "9", 30 | "7", 31 | "23", 32 | "22", 33 | "28", 34 | ] 35 | 36 | 37 | def make_ll_cfg_for_case( 38 | hl_model, 39 | case_index: str, 40 | compression_ratio: float | None = None, 41 | same_size: bool = False, 42 | ) -> dict: 43 | compress_resid = case_index in cases_with_resid_compression 44 | if compression_ratio is None: 45 | compression_ratio = compression_ratio_map.get( 46 | case_index, compression_ratio_map["default"] 47 | ) 48 | return make_ll_cfg( 49 | hl_model, 50 | 
compress_resid=compress_resid or same_size, 51 | compression_ratio=compression_ratio, 52 | same_size=same_size, 53 | ) 54 | 55 | 56 | def make_ll_cfg( 57 | hl_model, compress_resid: bool, compression_ratio: float, same_size: bool 58 | ) -> dict: 59 | ll_cfg = hl_model.cfg.to_dict().copy() 60 | if same_size: 61 | n_heads = ll_cfg["n_heads"] 62 | else: 63 | n_heads = max(4, ll_cfg["n_heads"]) 64 | if compress_resid: 65 | d_model = int(hl_model.cfg.d_model // compression_ratio) 66 | d_model = max(2, d_model) 67 | d_head = max(1, d_model // n_heads) 68 | d_mlp = d_model * 4 69 | else: 70 | d_head = int(max(1, ll_cfg["d_head"] // compression_ratio)) 71 | d_model = n_heads * d_head 72 | d_mlp = d_model * 4 73 | assert d_model > 0 74 | assert d_head > 0 75 | assert d_mlp > 0 76 | cfg_dict = { 77 | "n_layers": max(2, ll_cfg["n_layers"]) if not same_size else ll_cfg["n_layers"], 78 | "n_heads": n_heads, 79 | "d_head": d_head, 80 | "d_model": d_model, 81 | "d_mlp": d_mlp, 82 | "seed": 0, 83 | "act_fn": "gelu", 84 | # "initializer_range": 0.02, 85 | } 86 | ll_cfg.update(cfg_dict) 87 | return ll_cfg 88 | --------------------------------------------------------------------------------
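To make the sizing arithmetic in make_ll_cfg above concrete, here is a small illustrative sketch; the starting dimensions (n_layers=1, n_heads=1, d_head=5, d_model=10) are hypothetical, and the lines simply redo the default branch (compress_resid=False, same_size=False) by hand:

# Hypothetical tracr config: n_layers=1, n_heads=1, d_head=5, d_model=10.
compression_ratio = 2                          # compression_ratio_map["default"]
n_heads = max(4, 1)                            # -> 4 (widened to at least 4 heads)
d_head = int(max(1, 5 // compression_ratio))   # -> 2 (per-head dimension compressed)
d_model = n_heads * d_head                     # -> 8
d_mlp = d_model * 4                            # -> 32
n_layers = max(2, 1)                           # -> 2 (at least two layers)

With compress_resid=True the residual width d_model is instead divided by the ratio first, and d_head is derived from it.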