├── LICENSE ├── README.md ├── assets ├── circuit.png └── teaser.png ├── baselines_examples └── acdc_ioi.py ├── data.zip ├── requirements.txt ├── run_scripts ├── custom_example.sh ├── fsdp_configs │ ├── eval_config.yaml │ └── prune_config.yaml ├── gp_sweep.sh ├── gt_sweep.sh ├── ioi_sweep.sh ├── launch_fllama_eval.sh ├── launch_fllama_fs_prune.sh ├── launch_fllama_instr_prune.sh ├── tracr_reverse.sh ├── tracr_xproportion.sh ├── wrapper_launch_fllama_eval.sh └── wrapper_launch_fllama_prune.sh ├── src ├── eval │ ├── boolean_expressions.py │ ├── gp.py │ ├── gt.py │ └── ioi.py ├── modeling │ ├── __pycache__ │ │ ├── l0.cpython-310.pyc │ │ ├── l0_fllama.cpython-310.pyc │ │ ├── modeling_erazr.cpython-310.pyc │ │ ├── modeling_fllama.cpython-310.pyc │ │ └── modeling_fpt2.cpython-310.pyc │ ├── draw_fpt2.py │ ├── l0.py │ ├── l0_fllama.py │ ├── modeling_erazr.py │ ├── modeling_fllama.py │ ├── modeling_fpt2.py │ ├── vis_fllama.py │ └── vis_fpt2.py └── prune │ ├── erazr_reverse.py │ ├── erazr_xproportion.py │ ├── fllama_boolean_expressions_fs.py │ ├── fllama_boolean_expressions_ip.py │ ├── fpt2_custom.py │ ├── fpt2_gp.py │ ├── fpt2_gt.py │ └── fpt2_ioi.py └── tracrx ├── __init__.py ├── __pycache__ └── __init__.cpython-310.pyc ├── compiler ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-311.pyc │ ├── assemble.cpython-310.pyc │ ├── assemble.cpython-311.pyc │ ├── basis_inference.cpython-310.pyc │ ├── compiling.cpython-310.pyc │ ├── compiling.cpython-311.pyc │ ├── craft_graph_to_model.cpython-310.pyc │ ├── craft_model_to_transformer.cpython-310.pyc │ ├── expr_to_craft_graph.cpython-310.pyc │ ├── nodes.cpython-310.pyc │ ├── rasp_to_graph.cpython-310.pyc │ └── validating.cpython-310.pyc ├── assemble.py ├── assemble_test.py ├── basis_inference.py ├── basis_inference_test.py ├── compiling.py ├── craft_graph_to_model.py ├── craft_graph_to_model_test.py ├── craft_model_to_transformer.py ├── expr_to_craft_graph.py ├── expr_to_craft_graph_test.py ├── lib.py ├── lib_test.py ├── nodes.py ├── rasp_to_craft_integration_test.py ├── rasp_to_graph.py ├── rasp_to_graph_test.py ├── rasp_to_transformer_integration_test.py ├── test_cases.py ├── validating.py └── validating_test.py ├── craft ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── bases.cpython-310.pyc │ ├── transformers.cpython-310.pyc │ └── vectorspace_fns.cpython-310.pyc ├── bases.py ├── bases_test.py ├── chamber │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── categorical_attn.cpython-310.pyc │ │ ├── categorical_mlp.cpython-310.pyc │ │ ├── numerical_mlp.cpython-310.pyc │ │ └── selector_width.cpython-310.pyc │ ├── categorical_attn.py │ ├── categorical_attn_test.py │ ├── categorical_mlp.py │ ├── categorical_mlp_test.py │ ├── numerical_mlp.py │ ├── numerical_mlp_test.py │ ├── selector_width.py │ └── selector_width_test.py ├── tests_common.py ├── transformers.py ├── transformers_test.py ├── vectorspace_fns.py └── vectorspace_fns_test.py ├── examples ├── Visualize_Tracr_Models.ipynb └── __init__.py ├── rasp ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-311.pyc │ ├── rasp.cpython-310.pyc │ └── rasp.cpython-311.pyc ├── causal_eval.py ├── causal_eval_test.py ├── rasp.py └── rasp_test.py ├── transformer ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── attention.cpython-310.pyc │ ├── encoder.cpython-310.pyc │ └── model.cpython-310.pyc ├── attention.py ├── compressed_model.py ├── compressed_model_test.py ├── encoder.py ├── 
encoder_test.py ├── model.py └── model_test.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-310.pyc └── errors.cpython-310.pyc ├── debugging.py ├── errors.py └── errors_test.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Adithya Bhaskar, Alexander Wettig, Dan Friedman, and Danqi Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /assets/circuit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/assets/circuit.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/assets/teaser.png -------------------------------------------------------------------------------- /data.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/data.zip -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | accelerate==0.28.0 3 | aiohttp==3.9.3 4 | aiosignal==1.3.1 5 | alembic==1.13.1 6 | aniso8601==9.0.1 7 | annotated-types==0.6.0 8 | antlr4-python3-runtime==4.9.3 9 | anyio==4.3.0 10 | apache-libcloud==3.8.0 11 | appdirs==1.4.4 12 | argcomplete==3.2.3 13 | arrow==1.3.0 14 | asttokens==2.4.1 15 | async-timeout==4.0.3 16 | attrs==23.2.0 17 | azure-core==1.30.1 18 | azure-identity==1.15.0 19 | azure-storage-blob==12.19.1 20 | azure-storage-file-datalake==12.14.0 21 | backcall==0.2.0 22 | backoff==2.2.1 23 | bcrypt==4.1.2 24 | beautifulsoup4==4.12.3 25 | bleach==6.1.0 26 | blinker==1.7.0 27 | boto3==1.34.74 28 | botocore==1.34.74 29 | Brotli==1.1.0 30 | cachetools==5.3.3 31 | catalogue==2.0.10 32 | certifi==2022.12.7 33 | cffi==1.16.0 34 | charset-normalizer==2.1.1 35 | chex==0.1.86 36 | circuitbreaker==1.4.0 37 | click==8.1.7 38 | cloudpickle==3.0.0 39 | cmake==3.25.0 40 | coloredlogs==15.0.1 41 | colorspacious==1.1.2 42 | composer==0.16.3 43 | 
contourpy==1.2.0 44 | coolname==2.2.0 45 | cramjam==2.8.3 46 | cryptography==42.0.5 47 | cycler==0.12.1 48 | datasets==2.14.7 49 | decorator==5.1.1 50 | deepspeed==0.14.0 51 | defusedxml==0.7.1 52 | dill==0.3.7 53 | distro==1.9.0 54 | dm-haiku==0.0.12 55 | docker==7.0.0 56 | docker-pycreds==0.4.0 57 | docopt==0.6.2 58 | einops==0.5.0 59 | entrypoints==0.4 60 | etils==1.7.0 61 | evaluate==0.4.1 62 | exceptiongroup==1.2.0 63 | executing==2.0.1 64 | fastjsonschema==2.19.1 65 | filelock==3.9.0 66 | fire==0.6.0 67 | Flask==3.0.2 68 | flatbuffers==24.3.25 69 | flax==0.8.2 70 | fonttools==4.50.0 71 | frozenlist==1.4.1 72 | fsspec==2023.6.0 73 | gitdb==4.0.11 74 | GitPython==3.1.43 75 | google-api-core==2.18.0 76 | google-auth==2.29.0 77 | google-cloud-core==2.4.1 78 | google-cloud-storage==2.10.0 79 | google-crc32c==1.5.0 80 | google-resumable-media==2.7.0 81 | googleapis-common-protos==1.63.0 82 | gql==3.5.0 83 | graphene==3.3 84 | graphql-core==3.2.3 85 | graphql-relay==3.2.0 86 | graphviz==0.20.3 87 | greenlet==3.0.3 88 | gunicorn==21.2.0 89 | h11==0.14.0 90 | hjson==3.1.0 91 | httpcore==1.0.5 92 | httpx==0.27.0 93 | huggingface-hub==0.22.2 94 | humanfriendly==10.0 95 | idna==3.4 96 | importlib-metadata==6.11.0 97 | importlib_resources==6.4.0 98 | ipython==8.12.3 99 | isodate==0.6.1 100 | itsdangerous==2.1.2 101 | jax==0.4.25 102 | jaxlib==0.4.25 103 | jedi==0.19.1 104 | Jinja2==3.1.2 105 | jmespath==1.0.1 106 | jmp==0.0.4 107 | joblib==1.3.2 108 | jsonschema==4.22.0 109 | jsonschema-specifications==2023.12.1 110 | jupyter_client==8.6.2 111 | jupyter_core==5.7.2 112 | jupyterlab_pygments==0.3.0 113 | kiwisolver==1.4.5 114 | lightning-utilities==0.11.2 115 | lit==15.0.7 116 | llm-foundry==0.3.0 117 | Mako==1.3.2 118 | Markdown==3.6 119 | markdown-it-py==3.0.0 120 | MarkupSafe==2.1.3 121 | matplotlib==3.8.3 122 | matplotlib-inline==0.1.7 123 | mdurl==0.1.2 124 | mistune==3.0.2 125 | ml-dtypes==0.4.0 126 | mlflow==2.11.3 127 | mosaicml==0.16.4 128 | mosaicml-cli==0.5.34 129 | mosaicml-streaming==0.6.1 130 | mpmath==1.3.0 131 | msal==1.28.0 132 | msal-extensions==1.1.0 133 | msgpack==1.0.8 134 | multidict==6.0.5 135 | multiprocess==0.70.15 136 | nbclient==0.10.0 137 | nbconvert==7.16.4 138 | nbformat==5.10.4 139 | nest-asyncio==1.6.0 140 | networkx==3.2.1 141 | ninja==1.11.1.1 142 | numpy==1.26.3 143 | nvidia-cublas-cu12==12.1.3.1 144 | nvidia-cuda-cupti-cu12==12.1.105 145 | nvidia-cuda-nvrtc-cu12==12.1.105 146 | nvidia-cuda-runtime-cu12==12.1.105 147 | nvidia-cudnn-cu12==8.9.2.26 148 | nvidia-cufft-cu12==11.0.2.54 149 | nvidia-curand-cu12==10.3.2.106 150 | nvidia-cusolver-cu12==11.4.5.107 151 | nvidia-cusparse-cu12==12.1.0.106 152 | nvidia-nccl-cu12==2.18.1 153 | nvidia-nvjitlink-cu12==12.4.99 154 | nvidia-nvtx-cu12==12.1.105 155 | oci==2.125.0 156 | omegaconf==2.3.0 157 | onnx==1.14.0 158 | onnxruntime==1.15.1 159 | openai==1.21.1 160 | opt-einsum==3.3.0 161 | optax==0.2.2 162 | orbax-checkpoint==0.5.7 163 | packaging==22.0 164 | pandas==2.2.1 165 | pandocfilters==1.5.1 166 | paramiko==3.4.0 167 | parso==0.8.4 168 | pathtools==0.1.2 169 | pexpect==4.9.0 170 | pickleshare==0.7.5 171 | pillow==10.2.0 172 | pipreqs==0.5.0 173 | platformdirs==4.2.2 174 | portalocker==2.8.2 175 | prompt-toolkit==3.0.36 176 | proto-plus==1.23.0 177 | protobuf==4.25.3 178 | psutil==5.9.8 179 | ptyprocess==0.7.0 180 | pure-eval==0.2.2 181 | py-cpuinfo==9.0.0 182 | pyarrow==15.0.2 183 | pyarrow-hotfix==0.6 184 | pyasn1==0.6.0 185 | pyasn1_modules==0.4.0 186 | pycparser==2.22 187 | pydantic==2.6.4 188 | 
pydantic_core==2.16.3 189 | Pygments==2.17.2 190 | PyJWT==2.8.0 191 | PyNaCl==1.5.0 192 | pynvml==11.5.0 193 | pyOpenSSL==24.1.0 194 | pyparsing==3.1.2 195 | python-dateutil==2.9.0.post0 196 | python-snappy==0.7.1 197 | pytorch-ranger==0.1.1 198 | pytz==2024.1 199 | PyYAML==6.0.1 200 | pyzmq==26.0.3 201 | querystring-parser==1.2.4 202 | questionary==2.0.1 203 | referencing==0.35.1 204 | regex==2023.12.25 205 | requests==2.28.1 206 | responses==0.18.0 207 | rich==13.7.1 208 | rpds-py==0.18.1 209 | rsa==4.9 210 | ruamel.yaml==0.18.6 211 | ruamel.yaml.clib==0.2.8 212 | s3transfer==0.10.1 213 | safetensors==0.4.2 214 | scikit-learn==1.4.1.post1 215 | scipy==1.12.0 216 | seaborn==0.13.2 217 | sentencepiece==0.1.97 218 | sentry-sdk==1.44.0 219 | setproctitle==1.3.3 220 | shellingham==1.5.4 221 | six==1.16.0 222 | slack_sdk==3.27.1 223 | smmap==5.0.1 224 | sniffio==1.3.1 225 | soupsieve==2.5 226 | SQLAlchemy==2.0.29 227 | sqlparse==0.4.4 228 | stack-data==0.6.3 229 | sympy==1.12 230 | tabulate==0.9.0 231 | tenacity==8.2.3 232 | tensorstore==0.1.56 233 | termcolor==2.4.0 234 | threadpoolctl==3.4.0 235 | tinycss2==1.3.0 236 | tokenizers==0.15.2 237 | toolz==0.12.1 238 | torch==2.1.0 239 | torch-optimizer==0.3.0 240 | torchaudio==2.0.2+cu118 241 | torchmetrics==1.0.3 242 | torchvision==0.16.0 243 | tornado==6.4.1 244 | tqdm==4.66.2 245 | tracr @ file:///scratch/gpfs/ab4197/p-printer/tracr/tracr 246 | traitlets==5.14.3 247 | transformers @ git+https://github.com/huggingface/transformers@3b8e2932ce743008f63585aae1e1b8b30dc8b3ac 248 | triton==2.1.0 249 | typer==0.12.0 250 | typer-cli==0.12.0 251 | typer-slim==0.12.0 252 | types-python-dateutil==2.9.0.20240316 253 | typing_extensions==4.8.0 254 | tzdata==2024.1 255 | urllib3==1.26.13 256 | validators==0.24.0 257 | wandb==0.15.12 258 | wcwidth==0.2.13 259 | webencodings==0.5.1 260 | websockets==11.0.3 261 | Werkzeug==3.0.1 262 | xxhash==3.4.1 263 | yarg==0.1.9 264 | yarl==1.9.4 265 | zipp==3.18.1 266 | zstd==1.5.5.1 267 | -------------------------------------------------------------------------------- /run_scripts/custom_example.sh: -------------------------------------------------------------------------------- 1 | EDGE_SPARSITY=0.95 2 | NODE_SPARSITY=0.72 3 | ELR=0.8 4 | LLR=0.8 5 | RELR=0.8 6 | RLLR=0.8 7 | TOTAL=3000 8 | WARMUP=2500 9 | 10 | EXTRA="--disable_node_loss" 11 | TAG="wo_node_loss" 12 | 13 | # Uncomment this if you want to run with node loss 14 | # EXTRA="" 15 | # TAG="w_node_loss" 16 | 17 | train_split="train" # "train_400", "train_100k" 18 | N_TRAIN=10 # Set to a large value so all of the (200 / 400 / 100000) examples are used 19 | N_VAL=10 # The val split size 20 | 21 | # You can wrap the following in an sbatch script if you use SLURM 22 | # Activate your environment etc 23 | 24 | # If you want to always keep embedding nodes, remove the --with_embedding_nodes flag 25 | # That flag, when set, also models masks over the embedding nodes 26 | 27 | WANDB_MODE=disabled python src/prune/fpt2_custom.py \ 28 | --report_to wandb \ 29 | --do_train \ 30 | --do_eval \ 31 | --dataset_path ./data/datasets/example_custom.jsonl \ 32 | --train_split $train_split \ 33 | --initialize_from gpt2 \ 34 | --max_seq_length 64 \ 35 | --per_device_train_batch_size 32 \ 36 | --per_device_eval_batch_size 16 \ 37 | --gradient_accumulation_steps 1 \ 38 | --eval_accumulation_steps 16 \ 39 | --edge_learning_rate $ELR \ 40 | --layer_learning_rate $LLR \ 41 | --reg_edge_learning_rate $RELR \ 42 | --reg_layer_learning_rate $RLLR \ 43 | --max_steps $TOTAL \ 44 | 
--warmup_steps 200 \ 45 | --evaluation_strategy steps \ 46 | --eval_steps 64 \ 47 | --save_steps 64 \ 48 | --logging_steps 8 \ 49 | --save_total_limit 1 \ 50 | --start_edge_sparsity 0.00 \ 51 | --target_edge_sparsity $EDGE_SPARSITY \ 52 | --start_layer_sparsity 0.00 \ 53 | --target_layer_sparsity $NODE_SPARSITY \ 54 | --num_sparsity_warmup_steps $WARMUP \ 55 | --max_train_samples $N_TRAIN \ 56 | --max_eval_samples $N_VAL \ 57 | --output_dir ./data/runs/custom-example/ \ 58 | --remove_unused_columns false \ 59 | --dataloader_num_workers 0 \ 60 | --warmup_type linear \ 61 | --with_embedding_nodes \ 62 | $EXTRA -------------------------------------------------------------------------------- /run_scripts/fsdp_configs/eval_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: FSDP 4 | downcast_bf16: 'no' 5 | fsdp_config: 6 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 7 | fsdp_backward_prefetch: NO_PREFETCH 8 | fsdp_cpu_ram_efficient_loading: true 9 | fsdp_forward_prefetch: false 10 | fsdp_offload_params: false 11 | fsdp_sharding_strategy: FULL_SHARD 12 | fsdp_state_dict_type: FULL_STATE_DICT 13 | fsdp_sync_module_states: true 14 | fsdp_transformer_layer_cls_to_wrap: FLlamaDecoderLayer 15 | fsdp_use_orig_params: true 16 | machine_rank: 0 17 | mixed_precision: bf16 18 | main_training_function: main 19 | num_machines: 1 20 | num_processes: 4 21 | rdzv_backend: c10d 22 | same_network: true 23 | tpu_env: [] 24 | tpu_use_cluster: false 25 | tpu_use_sudo: false 26 | use_cpu: false 27 | -------------------------------------------------------------------------------- /run_scripts/fsdp_configs/prune_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: FSDP 4 | downcast_bf16: 'no' 5 | fsdp_config: 6 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 7 | fsdp_backward_prefetch: NO_PREFETCH 8 | fsdp_cpu_ram_efficient_loading: true 9 | fsdp_forward_prefetch: false 10 | fsdp_offload_params: true 11 | fsdp_sharding_strategy: FULL_SHARD 12 | fsdp_state_dict_type: FULL_STATE_DICT 13 | fsdp_sync_module_states: true 14 | fsdp_transformer_layer_cls_to_wrap: FLlamaDecoderLayer 15 | fsdp_use_orig_params: true 16 | machine_rank: 0 17 | mixed_precision: bf16 18 | main_training_function: main 19 | num_machines: 4 20 | num_processes: 8 21 | rdzv_backend: c10d 22 | same_network: true 23 | tpu_env: [] 24 | tpu_use_cluster: false 25 | tpu_use_sudo: false 26 | use_cpu: false -------------------------------------------------------------------------------- /run_scripts/gp_sweep.sh: -------------------------------------------------------------------------------- 1 | EDGE_SPARSITIES=(0.93 0.94 0.945 0.95 0.955 0.96 0.965 0.97 0.975 0.98 0.985 0.99 0.995 1.0 1.05 1.1 1.15 1.2 1.4) 2 | 3 | for i in "${!EDGE_SPARSITIES[@]}"; do 4 | 5 | EDGE_SPARSITY=${EDGE_SPARSITIES[i]} 6 | NODE_SPARSITY=0.69 7 | ELR=0.8 8 | LLR=0.8 9 | RELR=0.8 10 | RLLR=0.8 11 | TOTAL=3000 12 | WARMUP=2500 13 | 14 | EXTRA="--disable_node_loss" 15 | TAG="wo_node_loss" 16 | 17 | # Uncomment this if you want to run with node loss 18 | # EXTRA="" 19 | # TAG="w_node_loss" 20 | 21 | train_split="train" # "train_3k" 22 | N_TRAIN=1000000 # Set to a large value so all of the (150 / 3000) examples are used 23 | N_VAL=150 # The val split size 24 | 25 | # You can wrap the following in an sbatch script if you use SLURM 26 | # Activate your 
environment etc 27 | 28 | # If you want to always keep embedding nodes, remove the --with_embedding_nodes flag 29 | # That flag, when set, also models masks over the embedding nodes 30 | 31 | WANDB_MODE=disabled python src/prune/fpt2_gp.py \ 32 | --report_to wandb \ 33 | --do_train \ 34 | --do_eval \ 35 | --dataset_path ./data/datasets/gp/ \ 36 | --train_split $train_split \ 37 | --initialize_from gpt2 \ 38 | --max_seq_length 64 \ 39 | --per_device_train_batch_size 32 \ 40 | --per_device_eval_batch_size 16 \ 41 | --gradient_accumulation_steps 1 \ 42 | --eval_accumulation_steps 16 \ 43 | --edge_learning_rate $ELR \ 44 | --layer_learning_rate $LLR \ 45 | --reg_edge_learning_rate $RELR \ 46 | --reg_layer_learning_rate $RLLR \ 47 | --max_steps $TOTAL \ 48 | --warmup_steps 200 \ 49 | --evaluation_strategy steps \ 50 | --eval_steps 64 \ 51 | --save_steps 64 \ 52 | --logging_steps 8 \ 53 | --save_total_limit 1 \ 54 | --start_edge_sparsity 0.00 \ 55 | --target_edge_sparsity $EDGE_SPARSITY \ 56 | --start_layer_sparsity 0.00 \ 57 | --target_layer_sparsity $NODE_SPARSITY \ 58 | --num_sparsity_warmup_steps $WARMUP \ 59 | --max_train_samples $N_TRAIN \ 60 | --max_eval_samples $N_VAL \ 61 | --output_dir ./data/runs/gp-${TAG}-elr${ELR}-llr${LLR}-relr${RELR}-rllr${RLLR}-es${EDGE_SPARSITY}-ns${NODE_SPARSITY}-t${TOTAL}/ \ 62 | --remove_unused_columns false \ 63 | --dataloader_num_workers 0 \ 64 | --warmup_type linear \ 65 | --with_embedding_nodes \ 66 | $EXTRA 67 | 68 | done -------------------------------------------------------------------------------- /run_scripts/gt_sweep.sh: -------------------------------------------------------------------------------- 1 | EDGE_SPARSITIES=(0.95 0.955 0.96 0.965 0.97 0.975 0.98 0.985 0.99 0.995 1.0 1.01 1.02 1.05 1.1) 2 | 3 | for i in "${!EDGE_SPARSITIES[@]}"; do 4 | 5 | EDGE_SPARSITY=${EDGE_SPARSITIES[i]} 6 | NODE_SPARSITY=0.68 7 | ELR=0.8 8 | LLR=0.8 9 | RELR=0.8 10 | RLLR=0.8 11 | TOTAL=3000 12 | WARMUP=2500 13 | 14 | EXTRA="--disable_node_loss" 15 | TAG="wo_node_loss" 16 | 17 | # Uncomment this if you want to run with node loss 18 | # EXTRA="" 19 | # TAG="w_node_loss" 20 | 21 | train_split="train" # "train_80k" 22 | N_TRAIN=1000000 # Set to a large value so all of the (150 / 80000) examples are used 23 | N_VAL=150 # The val split size 24 | 25 | # You can wrap the following in an sbatch script if you use SLURM 26 | # Activate your environment etc 27 | 28 | # If you want to always keep embedding nodes, remove the --with_embedding_nodes flag 29 | # That flag, when set, also models masks over the embedding nodes 30 | 31 | WANDB_MODE=disabled python src/prune/fpt2_gt.py \ 32 | --report_to wandb \ 33 | --do_train \ 34 | --do_eval \ 35 | --dataset_path ./data/datasets/gt/ \ 36 | --train_split $train_split \ 37 | --initialize_from gpt2 \ 38 | --max_seq_length 64 \ 39 | --per_device_train_batch_size 32 \ 40 | --per_device_eval_batch_size 16 \ 41 | --gradient_accumulation_steps 1 \ 42 | --eval_accumulation_steps 16 \ 43 | --edge_learning_rate $ELR \ 44 | --layer_learning_rate $LLR \ 45 | --reg_edge_learning_rate $RELR \ 46 | --reg_layer_learning_rate $RLLR \ 47 | --max_steps $TOTAL \ 48 | --warmup_steps 200 \ 49 | --evaluation_strategy steps \ 50 | --eval_steps 64 \ 51 | --save_steps 64 \ 52 | --logging_steps 8 \ 53 | --save_total_limit 1 \ 54 | --start_edge_sparsity 0.00 \ 55 | --target_edge_sparsity $EDGE_SPARSITY \ 56 | --start_layer_sparsity 0.00 \ 57 | --target_layer_sparsity $NODE_SPARSITY \ 58 | --num_sparsity_warmup_steps $WARMUP \ 59 | --max_train_samples $N_TRAIN 
\ 60 | --max_eval_samples $N_VAL \ 61 | --output_dir ./data/runs/gt-${TAG}-elr${ELR}-llr${LLR}-relr${RELR}-rllr${RLLR}-es${EDGE_SPARSITY}-ns${NODE_SPARSITY}-t${TOTAL}/ \ 62 | --remove_unused_columns false \ 63 | --dataloader_num_workers 0 \ 64 | --warmup_type linear \ 65 | --with_embedding_nodes \ 66 | $EXTRA 67 | 68 | done -------------------------------------------------------------------------------- /run_scripts/ioi_sweep.sh: -------------------------------------------------------------------------------- 1 | EDGE_SPARSITIES=(0.94 0.945 0.95 0.955 0.96 0.965 0.97 0.975 0.98 0.985 0.99 0.995 1.0 1.01 1.02 1.05 1.1) 2 | 3 | for i in "${!EDGE_SPARSITIES[@]}"; do 4 | 5 | EDGE_SPARSITY=${EDGE_SPARSITIES[i]} 6 | NODE_SPARSITY=0.72 7 | ELR=0.8 8 | LLR=0.8 9 | RELR=0.8 10 | RLLR=0.8 11 | TOTAL=3000 12 | WARMUP=2500 13 | 14 | EXTRA="--disable_node_loss" 15 | TAG="wo_node_loss" 16 | 17 | # Uncomment this if you want to run with node loss 18 | # EXTRA="" 19 | # TAG="w_node_loss" 20 | 21 | train_split="train" # "train_400", "train_100k" 22 | N_TRAIN=1000000 # Set to a large value so all of the (200 / 400 / 100000) examples are used 23 | N_VAL=200 # The val split size 24 | 25 | # You can wrap the following in an sbatch script if you use SLURM 26 | # Activate your environment etc 27 | 28 | # If you want to always keep embedding nodes, remove the --with_embedding_nodes flag 29 | # That flag, when set, also models masks over the embedding nodes 30 | 31 | WANDB_MODE=disabled python src/prune/fpt2_ioi.py \ 32 | --report_to wandb \ 33 | --do_train \ 34 | --do_eval \ 35 | --dataset_path ./data/datasets/ioi/ \ 36 | --train_split $train_split \ 37 | --initialize_from gpt2 \ 38 | --max_seq_length 64 \ 39 | --per_device_train_batch_size 32 \ 40 | --per_device_eval_batch_size 16 \ 41 | --gradient_accumulation_steps 1 \ 42 | --eval_accumulation_steps 16 \ 43 | --edge_learning_rate $ELR \ 44 | --layer_learning_rate $LLR \ 45 | --reg_edge_learning_rate $RELR \ 46 | --reg_layer_learning_rate $RLLR \ 47 | --max_steps $TOTAL \ 48 | --warmup_steps 200 \ 49 | --evaluation_strategy steps \ 50 | --eval_steps 64 \ 51 | --save_steps 64 \ 52 | --logging_steps 8 \ 53 | --save_total_limit 1 \ 54 | --start_edge_sparsity 0.00 \ 55 | --target_edge_sparsity $EDGE_SPARSITY \ 56 | --start_layer_sparsity 0.00 \ 57 | --target_layer_sparsity $NODE_SPARSITY \ 58 | --num_sparsity_warmup_steps $WARMUP \ 59 | --max_train_samples $N_TRAIN \ 60 | --max_eval_samples $N_VAL \ 61 | --output_dir ./data/runs/ioi-${TAG}-elr${ELR}-llr${LLR}-relr${RELR}-rllr${RLLR}-es${EDGE_SPARSITY}-ns${NODE_SPARSITY}-t${TOTAL}/ \ 62 | --remove_unused_columns false \ 63 | --dataloader_num_workers 0 \ 64 | --warmup_type linear \ 65 | --with_embedding_nodes \ 66 | $EXTRA 67 | 68 | done -------------------------------------------------------------------------------- /run_scripts/launch_fllama_eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=eval-fllama 3 | #SBATCH --nodes=1 4 | #SBATCH --output=./joblog/%x-%A_%a.out ## Stdout 5 | #SBATCH --error=./joblog/%x-%A_%a.err ## Stderr 6 | #SBATCH --gres=gpu:4 7 | #SBATCH --mem=300G 8 | #SBATCH --time=1:30:00 9 | 10 | num_nodes=$(scontrol show hostnames $SLURM_JOB_NODELIST | wc -l) 11 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 12 | 13 | export MASTER_ADDR=$master_addr 14 | if [ -z "$SLURM_GPUS_PER_NODE" ]; then 15 | export SLURM_GPUS_PER_NODE=4 16 | fi 17 | echo $SLURM_GPUS_PER_NODE 18 | 19 | export WORLD_SIZE=$(( 
$num_nodes * $SLURM_GPUS_PER_NODE )) 20 | export MASTER_PORT=$(( 10000 + RANDOM % 10000 )) 21 | export NUM_NODES=$num_nodes 22 | 23 | echo "MASTER_ADDR="$MASTER_ADDR 24 | echo "MASTER_PORT="$MASTER_PORT 25 | echo "WORLD_SIZE="$WORLD_SIZE 26 | echo "num_nodes="$num_nodes 27 | 28 | MODEL="/path/to/pruned/model" 29 | REFERENCE="meta-llama/CodeLlama-13b-Instruct-hf" 30 | BATCH_SIZE=4 31 | MODE="instruction" # "fewshot" 32 | 33 | # If you want to evaluate the intersection, or a custom set of edges 34 | ## Step 1: Obtain the intersection 35 | # python src/modeling/vis_fllama.py -m1 /path/to/model1 -m2 /path/to/model2 -o /path/to/output.json 36 | ## Step 2: Run evaluation with the additional flag -e /path/to/output.json 37 | # For this step $MODEL should point to $REFERENCE or either of model1/model2 38 | 39 | srun bash run_scripts/wrapper_launch_fllama_eval.sh \ 40 | src/eval/boolean_expressions.py -m $MODEL -r $REFERENCE -b $BATCH_SIZE -M $MODE -bf16 -------------------------------------------------------------------------------- /run_scripts/launch_fllama_fs_prune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -l 2 | #SBATCH --job-name=fs_prune-fllama 3 | #SBATCH --nodes=4 4 | #SBATCH --output=./joblog/%x-%A_%a.out ## Stdout 5 | #SBATCH --error=./joblog/%x-%A_%a.err ## Stderr 6 | #SBATCH --gres=gpu:8 7 | #SBATCH --mem=700G 8 | #SBATCH --time=35:00:00 9 | #SBATCH --cpus-per-task=16 10 | 11 | LOG_DIR=joblog 12 | 13 | num_nodes=$(scontrol show hostnames $SLURM_JOB_NODELIST | wc -l) 14 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 15 | 16 | export MASTER_ADDR=$master_addr 17 | if [ -z "$SLURM_GPUS_PER_NODE" ]; then 18 | export SLURM_GPUS_PER_NODE=8 19 | fi 20 | echo $SLURM_GPUS_PER_NODE 21 | 22 | export WORLD_SIZE=$(( $num_nodes * $SLURM_GPUS_PER_NODE )) 23 | export MASTER_PORT=$(( 10000 + RANDOM % 10000 )) 24 | export NUM_NODES=$num_nodes 25 | 26 | echo "MASTER_ADDR="$MASTER_ADDR 27 | echo "MASTER_PORT="$MASTER_PORT 28 | echo "WORLD_SIZE="$WORLD_SIZE 29 | echo "num_nodes="$num_nodes 30 | 31 | ELR=0.8 32 | LLR=$ELR 33 | RELR=0.4 34 | RLLR=$RELR 35 | EDGE_SPARSITY=1.2 36 | NODE_SPARSITY=0.7 37 | TOTAL=6000 38 | WARMUP=5500 39 | SEED=42 40 | OUTPUT_DIR=./data/runs/fllama-fs-s${SEED}-elr${ELR}-llr${LLR}-relr${RELR}-rllr${RLLR}-es${EDGE_SPARSITY}-ns${NODE_SPARSITY}-t${TOTAL}/ 41 | 42 | mkdir -p $OUTPUT_DIR 43 | 44 | # Add --with_embedding_nodes if you want to allow the model to prune the embedding nodes 45 | # It should work with the same hyperparameters, but give you a slightly sparser circuit 46 | 47 | # Remove --disable_node_loss if you want to prune with node loss 48 | 49 | srun bash run_scripts/wrapper_launch_fllama_prune.sh \ 50 | src/prune/fllama_boolean_expressions_fs.py \ 51 | --report_to wandb \ 52 | --do_train \ 53 | --dataset_path ./data/datasets/boolean_expressions/ \ 54 | --initialize_from meta-llama/CodeLlama-13b-Instruct-hf \ 55 | --ref_initialize_from meta-llama/CodeLlama-13b-Instruct-hf \ 56 | --max_seq_length 72 \ 57 | --per_device_train_batch_size 1 \ 58 | --edge_learning_rate $ELR \ 59 | --node_learning_rate $LLR \ 60 | --reg_edge_learning_rate $RELR \ 61 | --reg_node_learning_rate $RLLR \ 62 | --max_steps $TOTAL \ 63 | --warmup_steps 200 \ 64 | --save_steps 512 \ 65 | --logging_steps 8 \ 66 | --save_total_limit 1 \ 67 | --start_edge_sparsity 0.00 \ 68 | --target_edge_sparsity $EDGE_SPARSITY \ 69 | --start_node_sparsity 0.00 \ 70 | --target_node_sparsity $NODE_SPARSITY \ 71 | 
--num_sparsity_warmup_steps $WARMUP \ 72 | --max_train_samples 8000 \ 73 | --output_dir $OUTPUT_DIR \ 74 | --remove_unused_columns false \ 75 | --dataloader_num_workers 0 \ 76 | --warmup_type linear \ 77 | --bf16 \ 78 | --gradient_checkpointing \ 79 | --seed $SEED \ 80 | --disable_node_loss -------------------------------------------------------------------------------- /run_scripts/launch_fllama_instr_prune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -l 2 | #SBATCH --job-name=instr_prune-fllama 3 | #SBATCH --nodes=4 4 | #SBATCH --output=./joblog/%x-%A_%a.out ## Stdout 5 | #SBATCH --error=./joblog/%x-%A_%a.err ## Stderr 6 | #SBATCH --gres=gpu:8 7 | #SBATCH --mem=700G 8 | #SBATCH --time=35:00:00 9 | #SBATCH --cpus-per-task=16 10 | 11 | LOG_DIR=joblog 12 | 13 | num_nodes=$(scontrol show hostnames $SLURM_JOB_NODELIST | wc -l) 14 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 15 | 16 | export MASTER_ADDR=$master_addr 17 | if [ -z "$SLURM_GPUS_PER_NODE" ]; then 18 | export SLURM_GPUS_PER_NODE=8 19 | fi 20 | echo $SLURM_GPUS_PER_NODE 21 | 22 | export WORLD_SIZE=$(( $num_nodes * $SLURM_GPUS_PER_NODE )) 23 | export MASTER_PORT=$(( 10000 + RANDOM % 10000 )) 24 | export NUM_NODES=$num_nodes 25 | 26 | echo "MASTER_ADDR="$MASTER_ADDR 27 | echo "MASTER_PORT="$MASTER_PORT 28 | echo "WORLD_SIZE="$WORLD_SIZE 29 | echo "num_nodes="$num_nodes 30 | 31 | ELR=0.8 32 | LLR=$ELR 33 | RELR=0.4 34 | RLLR=$RELR 35 | EDGE_SPARSITY=1.2 36 | NODE_SPARSITY=0.7 37 | TOTAL=6000 38 | WARMUP=5500 39 | SEED=42 40 | OUTPUT_DIR=./data/runs/fllama-instr-s${SEED}-elr${ELR}-llr${LLR}-relr${RELR}-rllr${RLLR}-es${EDGE_SPARSITY}-ns${NODE_SPARSITY}-t${TOTAL}/ 41 | 42 | mkdir -p $OUTPUT_DIR 43 | 44 | # Add --with_embedding_nodes if you want to allow the model to prune the embedding nodes 45 | # It should work with the same hyperparameters, but give you a slightly sparser circuit 46 | 47 | # Remove --disable_node_loss if you want to prune with node loss 48 | 49 | srun bash run_scripts/wrapper_launch_fllama_prune.sh \ 50 | src/prune/fllama_boolean_expressions_instr.py \ 51 | --report_to wandb \ 52 | --do_train \ 53 | --dataset_path ./data/datasets/merged/boolean_expressions_inhouse_big/ \ 54 | --initialize_from meta-llama/CodeLlama-13b-Instruct-hf \ 55 | --ref_initialize_from meta-llama/CodeLlama-13b-Instruct-hf \ 56 | --max_seq_length 64 \ 57 | --per_device_train_batch_size 1 \ 58 | --edge_learning_rate $ELR \ 59 | --node_learning_rate $LLR \ 60 | --reg_edge_learning_rate $RELR \ 61 | --reg_node_learning_rate $RLLR \ 62 | --max_steps $TOTAL \ 63 | --warmup_steps 200 \ 64 | --save_steps 512 \ 65 | --logging_steps 8 \ 66 | --save_total_limit 1 \ 67 | --start_edge_sparsity 0.00 \ 68 | --target_edge_sparsity $EDGE_SPARSITY \ 69 | --start_node_sparsity 0.00 \ 70 | --target_node_sparsity $NODE_SPARSITY \ 71 | --num_sparsity_warmup_steps $WARMUP \ 72 | --max_train_samples 8000 \ 73 | --output_dir ${OUTPUT_DIR} \ 74 | --remove_unused_columns false \ 75 | --dataloader_num_workers 0 \ 76 | --warmup_type linear \ 77 | --bf16 \ 78 | --gradient_checkpointing \ 79 | --seed $SEED \ 80 | --disable_node_loss -------------------------------------------------------------------------------- /run_scripts/tracr_reverse.sh: -------------------------------------------------------------------------------- 1 | EDGE_SPARSITY=1.02 2 | NODE_SPARSITY=0.1 3 | ELR=0.03 4 | LLR=0.03 5 | RELR=0.001 6 | RLLR=0.001 7 | TOTAL=6000 8 | WARMUP=5900 9 | 10 | WANDB_MODE=disabled python 
src/prune/erazr_reverse.py \ 11 | --report_to wandb \ 12 | --do_train \ 13 | --do_eval \ 14 | --dataset_path ./data/datasets/reverse-t3-s3 \ 15 | --initialize_from ./data/tracr_models/reverse.tracr.pkl \ 16 | --seq_length 4 \ 17 | --per_device_train_batch_size 16 \ 18 | --per_device_eval_batch_size 16 \ 19 | --edge_learning_rate $ELR \ 20 | --node_learning_rate $LLR \ 21 | --reg_edge_learning_rate $RELR \ 22 | --reg_node_learning_rate $RLLR \ 23 | --max_steps $TOTAL \ 24 | --warmup_steps 1500 \ 25 | --evaluation_strategy steps \ 26 | --eval_steps 64 \ 27 | --save_steps 64 \ 28 | --logging_steps 4 \ 29 | --save_total_limit 1 \ 30 | --start_edge_sparsity 0.00 \ 31 | --target_edge_sparsity $EDGE_SPARSITY \ 32 | --start_node_sparsity 0.00 \ 33 | --target_node_sparsity $NODE_SPARSITY \ 34 | --num_sparsity_warmup_steps $WARMUP \ 35 | --max_train_samples 100000 \ 36 | --max_eval_samples 100000 \ 37 | --output_dir ./data/runs/erazr-reverse-elr${ELR}-llr${LLR}-relr${RELR}-rllr${RLLR}-es${EDGE_SPARSITY}-ns${NODE_SPARSITY}-t${TOTAL}/ \ 38 | --remove_unused_columns false \ 39 | --dataloader_num_workers 0 \ 40 | --label_names labels \ 41 | --warmup_type linear \ 42 | --zero_ablation \ 43 | --disable_node_loss \ 44 | --overwrite_output_dir -------------------------------------------------------------------------------- /run_scripts/tracr_xproportion.sh: -------------------------------------------------------------------------------- 1 | EDGE_SPARSITY=0.92 2 | NODE_SPARSITY=0.4 3 | ELR=1 4 | LLR=1 5 | RELR=0.0001 6 | RLLR=0.0001 7 | TOTAL=720 8 | WARMUP=640 9 | 10 | WANDB_MODE=disabled python src/prune/erazr_xproportion.py \ 11 | --report_to wandb \ 12 | --do_train \ 13 | --do_eval \ 14 | --dataset_path ./data/datasets/xproportion-t4-s4 \ 15 | --initialize_from ./data/tracr_models/xproportion.tracr.pkl \ 16 | --seq_length 5 \ 17 | --per_device_train_batch_size 16 \ 18 | --per_device_eval_batch_size 16 \ 19 | --edge_learning_rate $ELR \ 20 | --node_learning_rate $LLR \ 21 | --reg_edge_learning_rate $RELR \ 22 | --reg_node_learning_rate $RLLR \ 23 | --max_steps $TOTAL \ 24 | --warmup_steps 96 \ 25 | --evaluation_strategy steps \ 26 | --eval_steps 8 \ 27 | --save_steps 8 \ 28 | --logging_steps 4 \ 29 | --save_total_limit 1 \ 30 | --start_edge_sparsity 0.00 \ 31 | --target_edge_sparsity $EDGE_SPARSITY \ 32 | --start_node_sparsity 0.00 \ 33 | --target_node_sparsity $NODE_SPARSITY \ 34 | --num_sparsity_warmup_steps $WARMUP \ 35 | --max_train_samples 100000 \ 36 | --max_eval_samples 100000 \ 37 | --output_dir ./data/runs/erazr-xproportion-elr${ELR}-llr${LLR}-relr${RELR}-rllr${RLLR}-es${EDGE_SPARSITY}-ns${NODE_SPARSITY}-t${TOTAL}/ \ 38 | --remove_unused_columns false \ 39 | --dataloader_num_workers 0 \ 40 | --label_names labels \ 41 | --warmup_type linear \ 42 | --zero_ablation \ 43 | --overwrite_output_dir -------------------------------------------------------------------------------- /run_scripts/wrapper_launch_fllama_eval.sh: -------------------------------------------------------------------------------- 1 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 2 | export MASTER_ADDR=$master_addr 3 | 4 | echo accelerate launch --config_file run_scripts/fsdp_configs/eval_config.yaml \ 5 | --main_process_ip ${MASTER_ADDR} \ 6 | --main_process_port ${MASTER_PORT} \ 7 | --machine_rank ${SLURM_NODEID} \ 8 | --num_machines ${NUM_NODES} \ 9 | --num_processes ${WORLD_SIZE} $@ 10 | 11 | accelerate launch --config_file run_scripts/fsdp_configs/eval_config.yaml \ 12 | --main_process_ip 
${MASTER_ADDR} \ 13 | --main_process_port ${MASTER_PORT} \ 14 | --machine_rank ${SLURM_NODEID} \ 15 | --num_machines ${NUM_NODES} \ 16 | --num_processes ${WORLD_SIZE} $@ -------------------------------------------------------------------------------- /run_scripts/wrapper_launch_fllama_prune.sh: -------------------------------------------------------------------------------- 1 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 2 | export MASTER_ADDR=$master_addr 3 | 4 | echo WANDB_MODE=disabled accelerate launch --config_file run_scripts/fsdp_configs/prune_config.yaml \ 5 | --main_process_ip ${MASTER_ADDR} \ 6 | --main_process_port ${MASTER_PORT} \ 7 | --machine_rank ${SLURM_NODEID} \ 8 | --num_machines ${NUM_NODES} \ 9 | --num_processes ${WORLD_SIZE} $@ 10 | 11 | WANDB_MODE=disabled accelerate launch --config_file run_scripts/fsdp_configs/prune_config.yaml \ 12 | --main_process_ip ${MASTER_ADDR} \ 13 | --main_process_port ${MASTER_PORT} \ 14 | --machine_rank ${SLURM_NODEID} \ 15 | --num_machines ${NUM_NODES} \ 16 | --num_processes ${WORLD_SIZE} $@ -------------------------------------------------------------------------------- /src/modeling/__pycache__/l0.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/src/modeling/__pycache__/l0.cpython-310.pyc -------------------------------------------------------------------------------- /src/modeling/__pycache__/l0_fllama.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/src/modeling/__pycache__/l0_fllama.cpython-310.pyc -------------------------------------------------------------------------------- /src/modeling/__pycache__/modeling_erazr.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/src/modeling/__pycache__/modeling_erazr.cpython-310.pyc -------------------------------------------------------------------------------- /src/modeling/__pycache__/modeling_fllama.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/src/modeling/__pycache__/modeling_fllama.cpython-310.pyc -------------------------------------------------------------------------------- /src/modeling/__pycache__/modeling_fpt2.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/src/modeling/__pycache__/modeling_fpt2.cpython-310.pyc -------------------------------------------------------------------------------- /src/modeling/draw_fpt2.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import argparse 4 | import graphviz 5 | 6 | def parse_args(): 7 | parser = argparse.ArgumentParser() 8 | 9 | parser.add_argument("--in_path", "-i", type=str, required=True) 10 | parser.add_argument("--out_path", "-o", type=str, default=None) 11 | parser.add_argument("--constants", "-c", default=[], nargs="+") 12 | parser.add_argument("--no-sanitize", "-ns", action="store_true") 13 | 
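# A brief note on the flags above: -i expects an edges.json produced by vis_fpt2.py
# (when -o is omitted, the output path just swaps "edges.json" for "edges.pdf");
# -c lists nodes to treat as always-present constants, which are drawn with a gray
# edge into every other node; -ns skips sanitize_edges(), which otherwise adds the
# implicit q/k/v -> o edges and iteratively prunes nodes that never reach the output.
# Hypothetical invocation (paths are placeholders):
#   python src/modeling/draw_fpt2.py -i ./data/runs/my-run/edges.json -c Embeddings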
14 | args = parser.parse_args() 15 | 16 | if args.out_path is None: 17 | assert "edges.json" in args.in_path, "Please provide the out_path" 18 | args.out_path = args.in_path.replace("edges.json", "edges.pdf") 19 | args.constants = [c for c in args.constants if c.lower() != "[none]"] 20 | return args 21 | 22 | def get_constant_color(color="aquamarine1"): 23 | return lambda x: color 24 | 25 | def get_circuit_colors_v1( 26 | embeds_color="azure", # azure2 27 | mlp_color="cadetblue2", # cornflowerblue 28 | q_color="plum1", # orchid3 29 | k_color="lightpink", # chocolate1 30 | v_color="khaki1", # gold1 31 | o_color="darkslategray3", # aquamarine1 32 | resid_post_color="azure" # azure2 33 | ): 34 | def decide_color(node_name): 35 | if "embed" in node_name: 36 | return embeds_color 37 | elif node_name == "resid_post": 38 | return resid_post_color 39 | elif node_name.startswith("m"): 40 | return mlp_color 41 | elif node_name.endswith(".q"): 42 | return q_color 43 | elif node_name.endswith(".k"): 44 | return k_color 45 | elif node_name.endswith(".v"): 46 | return v_color 47 | else: 48 | return o_color 49 | return decide_color 50 | 51 | def get_circuit_colors( 52 | tok_embeds_color="grey", # azure2 53 | pos_embeds_color="lightsteelblue", # azure2 54 | mlp_color="cadetblue2", # cornflowerblue 55 | q_color="plum1", # orchid3 56 | k_color="lightpink", # chocolate1 57 | v_color="khaki1", # gold1 58 | o_color="darkslategray3", # aquamarine1 59 | resid_post_color="azure" # azure2 60 | ): 61 | def decide_color(node_name): 62 | node_name = node_name.lower() 63 | if "embed" in node_name: 64 | if "pos" in node_name: 65 | return pos_embeds_color 66 | else: 67 | return tok_embeds_color 68 | elif node_name == "output": 69 | return resid_post_color 70 | elif node_name.startswith("m"): 71 | return mlp_color 72 | elif node_name.endswith(".q"): 73 | return q_color 74 | elif node_name.endswith(".k"): 75 | return k_color 76 | elif node_name.endswith(".v"): 77 | return v_color 78 | else: 79 | return o_color 80 | return decide_color 81 | 82 | def sanitize_edges(edges): 83 | # First, add all q,k,v -> o edges 84 | new_edges_ = set() 85 | for edge in edges: 86 | if edge[0][0] == "a" and edge[0][-1] not in ["q", "k", "v"]: 87 | new_edges_.add(edge[0]) 88 | for to in new_edges_: 89 | for suffix in [".q", ".k", ".v"]: 90 | from_ = to + suffix 91 | edges.append((from_, to)) 92 | while True: 93 | orig_len = len(edges) 94 | # Find all nodes that are destinations but not sources 95 | froms = set() 96 | tos = set() 97 | for edge in edges: 98 | froms.add(edge[0]) 99 | if edge[1] != "resid_post": 100 | tos.add(edge[1]) 101 | banned_tos = tos.difference(froms) 102 | edges = [e for e in edges if e[1] not in banned_tos] 103 | 104 | # Find qkv nodes that have no incoming edges, and remove the q -> o edge for them 105 | qkv_nodes = set() 106 | for edge in edges: 107 | if edge[1].endswith(".q"): 108 | qkv_nodes.add(edge[1]) 109 | elif edge[1].endswith(".k"): 110 | qkv_nodes.add(edge[1]) 111 | elif edge[1].endswith(".v"): 112 | qkv_nodes.add(edge[1]) 113 | 114 | edges = [ 115 | e for e in edges if not ( 116 | (e[0].endswith(".q") and e[0] not in qkv_nodes) or 117 | (e[0].endswith(".k") and e[0] not in qkv_nodes) or 118 | (e[0].endswith(".v") and e[0] not in qkv_nodes) 119 | ) 120 | ] 121 | if orig_len == len(edges): 122 | break 123 | 124 | return edges 125 | 126 | def rename(name): 127 | if type(name) != str: 128 | return [rename(n) for n in name] 129 | if "embeds" in name: 130 | return "Embeddings" 131 | if name == "resid_post": 132 | 
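# "resid_post" is the final residual-stream read-out; the drawing labels it "Output"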
return "Output" 133 | if name.startswith("m"): 134 | l = int(name[1:]) 135 | return f"MLP {l}" 136 | if name.endswith(".q"): 137 | parts = name.split(".") 138 | l = int(parts[0][1:]) 139 | h = int(parts[1][1:]) 140 | return f"Head {l}.{h}.Q" 141 | if name.endswith(".k"): 142 | parts = name.split(".") 143 | l = int(parts[0][1:]) 144 | h = int(parts[1][1:]) 145 | return f"Head {l}.{h}.K" 146 | if name.endswith(".v"): 147 | parts = name.split(".") 148 | l = int(parts[0][1:]) 149 | h = int(parts[1][1:]) 150 | return f"Head {l}.{h}.V" 151 | parts = name.split(".") 152 | assert len(parts) == 2, f"Invalid node name {name}" 153 | l = int(parts[0][1:]) 154 | h = int(parts[1][1:]) 155 | return f"Head {l}.{h}.O" 156 | 157 | def main(): 158 | args = parse_args() 159 | 160 | edges = json.load(open(args.in_path)) 161 | if not args.no_sanitize: 162 | edges = sanitize_edges(edges) 163 | 164 | edges = [rename(e) for e in edges] 165 | 166 | out_path = args.out_path 167 | out_path_temp = args.out_path + ".pdf" 168 | 169 | coloring_fn = get_circuit_colors() 170 | constant_edge_color = 'gray66' 171 | 172 | constant_in = args.constants 173 | 174 | nodes = set(constant_in + [x for y in edges for x in y]) 175 | 176 | kwargs = { 177 | "graph_attr": { 178 | "nodesep": "0.02", 179 | "ranksep": "0.02", 180 | "ratio":"1:6", 181 | }, 182 | "node_attr": { 183 | "shape": "box", 184 | "style": "rounded,filled", 185 | }, 186 | } 187 | 188 | g = graphviz.Digraph(**kwargs) 189 | for node in nodes: 190 | g.node(node, color='black', fillcolor=coloring_fn(node)) 191 | 192 | for edge in edges: 193 | g.edge(edge[0], edge[1], color=coloring_fn(edge[0])) 194 | 195 | for node in nodes: 196 | for cin in constant_in: 197 | if node not in constant_in: 198 | g.edge(cin, node, color=constant_edge_color) 199 | 200 | g.render(out_path) 201 | os.rename(out_path_temp, out_path) 202 | 203 | 204 | if __name__ == '__main__': 205 | main() -------------------------------------------------------------------------------- /src/modeling/l0.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | LIMIT_LEFT = -0.1 7 | LIMIT_RIGHT = 1.1 8 | EPS = 1e-6 9 | TEMPERATURE = 2 / 3 10 | FACTOR = 0.8 11 | 12 | def cdf_stretched_concrete(x, log_alpha): 13 | x_01 = (x - LIMIT_LEFT) / (LIMIT_RIGHT - LIMIT_LEFT) 14 | intermediate = math.log(x_01) - math.log(1 - x_01) 15 | prob_unclamped = torch.sigmoid(TEMPERATURE * intermediate - log_alpha) 16 | prob_clamped = torch.clamp(prob_unclamped, EPS, 1 - EPS) 17 | return prob_clamped 18 | 19 | def sample_z_from_u(u, log_alpha): 20 | s = torch.sigmoid((torch.log(u) - torch.log(1 - u) + log_alpha) / TEMPERATURE) 21 | return (LIMIT_RIGHT - LIMIT_LEFT) * s + LIMIT_LEFT 22 | 23 | def deterministic_z_from_log_alpha(log_alpha, apply_one=False): 24 | size = np.prod(log_alpha.shape) 25 | 26 | # Since the distribution is stretched to [-eps, 1+eps], the prob of a variable <= 0 equals its prob to 0 27 | expected_num_nonzeros = torch.sum(1 - cdf_stretched_concrete(0, log_alpha)) 28 | expected_num_zeros = size - expected_num_nonzeros 29 | num_zeros = int(torch.round(expected_num_zeros).item()) 30 | 31 | soft_mask = torch.sigmoid(log_alpha / TEMPERATURE * FACTOR).reshape(-1) 32 | 33 | if num_zeros > 0: 34 | if soft_mask.ndim == 0: 35 | soft_mask = torch.tensor(0).to(log_alpha.device) 36 | else: 37 | _, indices = torch.topk(soft_mask, k=num_zeros, largest=False) 38 | soft_mask[indices] = 0 39 | if apply_one: 40 
| soft_mask[soft_mask > 0] = 1 41 | return soft_mask.reshape(log_alpha.shape) 42 | 43 | def sample_z_from_log_alpha(log_alpha): 44 | u = torch.autograd.Variable(torch.FloatTensor(log_alpha.shape).uniform_(EPS, 1 - EPS)).to(log_alpha.device) 45 | z = sample_z_from_u(u, log_alpha) 46 | z = F.hardtanh(z, 0, 1) 47 | 48 | return z 49 | 50 | if __name__ == '__main__': 51 | pass -------------------------------------------------------------------------------- /src/modeling/l0_fllama.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | LIMIT_LEFT = -0.1 7 | LIMIT_RIGHT = 1.1 8 | EPS = 1e-8 # 1e-6 9 | TEMPERATURE = 2 / 3 10 | FACTOR = 0.8 11 | 12 | def cdf_stretched_concrete(x, log_alpha): 13 | x_01 = (x - LIMIT_LEFT) / (LIMIT_RIGHT - LIMIT_LEFT) 14 | intermediate = math.log(x_01) - math.log(1 - x_01) 15 | 16 | precursor = TEMPERATURE * intermediate - log_alpha 17 | 18 | prob_unclamped = torch.sigmoid(precursor) 19 | prob_clamped = torch.clamp(prob_unclamped, EPS, 1 - EPS) 20 | return prob_clamped 21 | 22 | def sample_z_from_u(u, log_alpha): 23 | s = torch.sigmoid((torch.log(u) - torch.log(1 - u) + log_alpha) / TEMPERATURE) 24 | return (LIMIT_RIGHT - LIMIT_LEFT) * s + LIMIT_LEFT 25 | 26 | def deterministic_z_from_log_alpha(log_alpha, apply_one=False): 27 | size = np.prod(log_alpha.shape) 28 | 29 | # Since the distribution is stretched to [-eps, 1+eps], the prob of a variable <= 0 equals its prob to 0 30 | csc = cdf_stretched_concrete(0, log_alpha) 31 | expected_num_nonzeros = torch.sum(1 - csc) 32 | expected_num_zeros = size - expected_num_nonzeros 33 | num_zeros = torch.round(expected_num_zeros).item() 34 | 35 | num_zeros = int(num_zeros) 36 | 37 | soft_mask = torch.sigmoid(log_alpha / TEMPERATURE * FACTOR).reshape(-1) 38 | 39 | if num_zeros > 0: 40 | if soft_mask.ndim == 0: 41 | soft_mask = torch.tensor(0).to(log_alpha.device) 42 | else: 43 | _, indices = torch.topk(soft_mask, k=num_zeros, largest=False) 44 | soft_mask[indices] = 0 45 | if apply_one: 46 | soft_mask[soft_mask > 0] = 1 47 | return soft_mask.reshape(log_alpha.shape) 48 | 49 | def sample_z_from_log_alpha_old(log_alpha): 50 | u = torch.autograd.Variable(torch.FloatTensor(log_alpha.shape).uniform_(EPS, 1 - EPS)).to(log_alpha.device) 51 | z = sample_z_from_u(u, log_alpha) 52 | z = F.hardtanh(z, 0, 1) 53 | 54 | return z 55 | 56 | def sample_z_from_log_alpha(log_alpha): 57 | u = torch.autograd.Variable(torch.empty(log_alpha.shape, dtype=log_alpha.dtype).uniform_(EPS, 1 - EPS)).to(log_alpha.device) 58 | z = sample_z_from_u(u, log_alpha) 59 | z = F.hardtanh(z, 0, 1) 60 | 61 | return z 62 | 63 | if __name__ == '__main__': 64 | pass -------------------------------------------------------------------------------- /src/modeling/vis_fllama.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | import sys 6 | sys.path.append( 7 | os.path.join( 8 | os.getcwd(), 9 | "src/modeling/" 10 | ) 11 | ) 12 | from modeling_fllama import FLlamaForCausalLM 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser() 16 | 17 | parser.add_argument("-m1", "--model1", type=str, required=True) 18 | parser.add_argument("-m2", "--model2", type=str, default="") 19 | parser.add_argument("-w", "--with_embedding_nodes", action="store_true") 20 | parser.add_argument("-o", "--output", type=str, default="") 21 | 22 | args = parser.parse_args() 23 | 24 
| if args.output == "": 25 | if args.model2 == "": 26 | args.output = os.path.join(args.model1, "edges.json") 27 | else: 28 | raise ValueError("Output file must be specified when comparing two models") 29 | 30 | return args 31 | 32 | def main(): 33 | args = parse_args() 34 | 35 | model = FLlamaForCausalLM.from_pretrained(args.model1, with_embedding_nodes=args.with_embedding_nodes) 36 | 37 | edge_sparsity = model.get_edge_sparsity() 38 | node_sparsity = model.get_node_sparsity() 39 | 40 | # Binary search for the threshold 41 | l = 0 42 | r = 1 43 | while r-l > 1e-5: 44 | threshold = (l+r)/2 45 | model.set_edge_threshold_for_deterministic(threshold) 46 | sparsity = model.get_edge_sparsity() 47 | if sparsity > edge_sparsity: 48 | r = threshold 49 | else: 50 | l = threshold 51 | model.set_edge_threshold_for_deterministic(threshold) 52 | 53 | edge_threshold = threshold 54 | print("Edge threshold (1):", format(edge_threshold, '.60g')) 55 | 56 | # Binary search for the threshold 57 | l = 0 58 | r = 1 59 | while r-l > 1e-5: 60 | threshold = (l+r)/2 61 | model.set_node_threshold_for_deterministic(threshold) 62 | sparsity = model.get_node_sparsity() 63 | if sparsity > node_sparsity: 64 | r = threshold 65 | else: 66 | l = threshold 67 | model.set_node_threshold_for_deterministic(threshold) 68 | node_threshold = threshold 69 | print("Node threshold (1):", format(node_threshold ,'.60g')) 70 | 71 | overall_edge_sparsity = model.get_effective_edge_sparsity() 72 | print("Overall edge sparsity (1):", format(overall_edge_sparsity)) 73 | 74 | edges = model.get_edges() 75 | 76 | if args.model2 != "": 77 | model2 = FLlamaForCausalLM.from_pretrained(args.model2, with_embedding_nodes=args.with_embedding_nodes) 78 | edge_sparsity = model2.get_edge_sparsity() 79 | node_sparsity = model2.get_node_sparsity() 80 | 81 | # Binary search for the threshold 82 | l = 0 83 | r = 1 84 | while r-l > 1e-5: 85 | threshold = (l+r)/2 86 | model2.set_edge_threshold_for_deterministic(threshold) 87 | sparsity = model2.get_edge_sparsity() 88 | if sparsity > edge_sparsity: 89 | r = threshold 90 | else: 91 | l = threshold 92 | model2.set_edge_threshold_for_deterministic(threshold) 93 | 94 | edge_threshold = threshold 95 | print("Edge threshold (2):", format(edge_threshold, '.60g')) 96 | 97 | # Binary search for the threshold 98 | l = 0 99 | r = 1 100 | while r-l > 1e-5: 101 | threshold = (l+r)/2 102 | model2.set_node_threshold_for_deterministic(threshold) 103 | sparsity = model2.get_node_sparsity() 104 | if sparsity > node_sparsity: 105 | r = threshold 106 | else: 107 | l = threshold 108 | model2.set_node_threshold_for_deterministic(threshold) 109 | node_threshold = threshold 110 | print("Node threshold (2):", format(node_threshold ,'.60g')) 111 | 112 | overall_edge_sparsity = model2.get_effective_edge_sparsity() 113 | print("Overall edge sparsity (2):", format(overall_edge_sparsity)) 114 | 115 | edges2 = model2.get_edges() 116 | 117 | # Intersection 118 | edges = [e[0]+"#"+e[1] for e in edges] 119 | edges2 = [e[0]+"#"+e[1] for e in edges2] 120 | edges = list(set(edges).intersection(edges2)) 121 | edges = [e.split("#") for e in edges] 122 | 123 | print("Saving", len(edges), "edges...") 124 | json.dump(edges, open(args.output, "w+"), indent=4) 125 | 126 | if __name__ == "__main__": 127 | main() -------------------------------------------------------------------------------- /src/modeling/vis_fpt2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 |
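# vis_fpt2.py picks deterministic thresholds for the learned edge/node masks of a pruned
# GPT-2 checkpoint: it binary-searches thresholds that match the checkpoint's reported
# sparsities (or the values passed via -e/-n), prints the resulting overall edge sparsity,
# and writes the surviving edges to <run_dir>/edges.json unless -o is given.
# Hypothetical invocation (the run directory is a placeholder; pass -w only if the run
# used --with_embedding_nodes):
#   python src/modeling/vis_fpt2.py -i ./data/runs/ioi-example/ -w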
import sys 6 | sys.path.append( 7 | os.path.join( 8 | os.getcwd(), 9 | "src/modeling/" 10 | ) 11 | ) 12 | from modeling_fpt2 import FPT2LMHeadModel 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser() 16 | 17 | parser.add_argument("--in_path", "-i", type=str, required=True) 18 | parser.add_argument("--out_path", "-o", type=str, default=None) 19 | parser.add_argument("--with_embedding_nodes", "-w", action="store_true") 20 | parser.add_argument("--edge_sparsity", "-e", type=float, default=None) 21 | parser.add_argument("--node_sparsity", "-n", type=float, default=None) 22 | 23 | args = parser.parse_args() 24 | 25 | if args.out_path is None: 26 | args.out_path = os.path.join(args.in_path, "edges.json") 27 | print(f"Output path not specified. Saving to {args.out_path}.") 28 | 29 | return args 30 | 31 | def main(): 32 | args = parse_args() 33 | 34 | model = FPT2LMHeadModel.from_pretrained(args.in_path, with_embedding_nodes=args.with_embedding_nodes) 35 | if args.edge_sparsity is None: 36 | args.edge_sparsity = model.get_edge_sparsity() 37 | if args.node_sparsity is None: 38 | args.node_sparsity = model.get_node_sparsity() 39 | 40 | l = 0 41 | r = 1 42 | while r-l > 1e-5: 43 | threshold = (l+r)/2 44 | model.set_edge_threshold_for_deterministic(threshold) 45 | sparsity = model.get_edge_sparsity() 46 | if sparsity > args.edge_sparsity: 47 | r = threshold 48 | else: 49 | l = threshold 50 | 51 | l = 0 52 | r = 1 53 | while r-l > 1e-5: 54 | threshold = (l+r)/2 55 | model.set_node_threshold_for_deterministic(threshold) 56 | sparsity = model.get_node_sparsity() 57 | if sparsity > args.node_sparsity: 58 | r = threshold 59 | else: 60 | l = threshold 61 | 62 | overall_edge_sparsity = model.get_effective_edge_sparsity() 63 | print("Overall edge sparsity:", overall_edge_sparsity.item()) 64 | 65 | edges = model.get_edges() 66 | 67 | json.dump(edges, open(args.out_path, "w+"), indent=4) 68 | 69 | if __name__ == '__main__': 70 | main() -------------------------------------------------------------------------------- /tracrx/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /tracrx/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /tracrx/compiler/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides the main compiler function as a public import.""" 16 | 17 | from tracrx.compiler.compiling import compile_rasp_to_model 18 | 19 | __all__ = ["compile_rasp_to_model"] 20 | -------------------------------------------------------------------------------- /tracrx/compiler/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/compiler/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /tracrx/compiler/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/compiler/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /tracrx/compiler/__pycache__/assemble.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/compiler/__pycache__/assemble.cpython-310.pyc -------------------------------------------------------------------------------- /tracrx/compiler/__pycache__/assemble.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/compiler/__pycache__/assemble.cpython-311.pyc -------------------------------------------------------------------------------- /tracrx/compiler/__pycache__/basis_inference.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/compiler/__pycache__/basis_inference.cpython-310.pyc -------------------------------------------------------------------------------- /tracrx/compiler/__pycache__/compiling.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/compiler/__pycache__/compiling.cpython-310.pyc -------------------------------------------------------------------------------- /tracrx/compiler/__pycache__/compiling.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/compiler/__pycache__/compiling.cpython-311.pyc 
-------------------------------------------------------------------------------- /tracrx/compiler/__pycache__/craft_graph_to_model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/compiler/__pycache__/craft_graph_to_model.cpython-310.pyc -------------------------------------------------------------------------------- /tracrx/compiler/__pycache__/craft_model_to_transformer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/compiler/__pycache__/craft_model_to_transformer.cpython-310.pyc -------------------------------------------------------------------------------- /tracrx/compiler/__pycache__/expr_to_craft_graph.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/compiler/__pycache__/expr_to_craft_graph.cpython-310.pyc -------------------------------------------------------------------------------- /tracrx/compiler/__pycache__/nodes.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/compiler/__pycache__/nodes.cpython-310.pyc -------------------------------------------------------------------------------- /tracrx/compiler/__pycache__/rasp_to_graph.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/compiler/__pycache__/rasp_to_graph.cpython-310.pyc -------------------------------------------------------------------------------- /tracrx/compiler/__pycache__/validating.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/compiler/__pycache__/validating.cpython-310.pyc -------------------------------------------------------------------------------- /tracrx/compiler/assemble_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for transformer.assemble.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import haiku as hk 20 | import jax 21 | import jax.numpy as jnp 22 | import numpy as np 23 | from tracrx.compiler import assemble 24 | from tracrx.craft import bases 25 | 26 | 27 | class AssembleTest(parameterized.TestCase): 28 | 29 | def test_token_embedding_produces_correct_embedding(self): 30 | # Token embeddings should be one-hot embeddings of the input integers 31 | # into the token subspace of residual_space 32 | input_space = bases.VectorSpaceWithBasis.from_values("0inp", range(2)) 33 | indices_space = bases.VectorSpaceWithBasis.from_values("1ind", range(3)) 34 | output_space = bases.VectorSpaceWithBasis.from_values("2out", range(2)) 35 | residual_space = bases.join_vector_spaces(input_space, indices_space, 36 | output_space) 37 | 38 | @hk.without_apply_rng 39 | @hk.transform 40 | def token_pos_embed(tokens): 41 | embed_modules = assemble._make_embedding_modules( 42 | residual_space=residual_space, 43 | tokens_space=input_space, 44 | indices_space=indices_space, 45 | output_space=output_space) 46 | return embed_modules.token_embed(tokens) 47 | 48 | tokens = jnp.array([0, 0, 1]) 49 | expected_token_embeddings = jnp.array([[1, 0, 0, 0, 0, 0, 0], 50 | [1, 0, 0, 0, 0, 0, 0], 51 | [0, 1, 0, 0, 0, 0, 0]]) 52 | 53 | params = token_pos_embed.init(jax.random.PRNGKey(0), tokens) 54 | embeddings = token_pos_embed.apply(params, tokens) 55 | np.testing.assert_allclose(embeddings, expected_token_embeddings) 56 | 57 | def test_position_embedding_produces_correct_embedding(self): 58 | # Position embeddings should be one-hot embeddings of the input integers 59 | # (representing indices) into the indices subspace of residual_space 60 | input_space = bases.VectorSpaceWithBasis.from_values("0inp", range(2)) 61 | indices_space = bases.VectorSpaceWithBasis.from_values("1ind", range(3)) 62 | output_space = bases.VectorSpaceWithBasis.from_values("2out", range(2)) 63 | residual_space = bases.join_vector_spaces(input_space, indices_space, 64 | output_space) 65 | 66 | @hk.without_apply_rng 67 | @hk.transform 68 | def token_pos_embed(tokens): 69 | embed_modules = assemble._make_embedding_modules( 70 | residual_space=residual_space, 71 | tokens_space=input_space, 72 | indices_space=indices_space, 73 | output_space=output_space) 74 | return embed_modules.pos_embed(jnp.indices(tokens.shape)[-1]) 75 | 76 | tokens = jnp.array([3, 0, 0, 1]) 77 | expected_pos_embeddings = jnp.array([[0, 0, 0, 0, 0, 0, 0], 78 | [0, 0, 1, 0, 0, 0, 0], 79 | [0, 0, 0, 1, 0, 0, 0], 80 | [0, 0, 0, 0, 1, 0, 0]]) 81 | 82 | params = token_pos_embed.init(jax.random.PRNGKey(0), tokens) 83 | embeddings = token_pos_embed.apply(params, tokens) 84 | np.testing.assert_allclose(embeddings, expected_pos_embeddings) 85 | 86 | def test_unembedding(self): 87 | # Prepend numbers to preserve basis order [input, index, output] 88 | input_space = bases.VectorSpaceWithBasis.from_values("0inp", range(2)) 89 | indices_space = bases.VectorSpaceWithBasis.from_values("1ind", range(3)) 90 | output_space = bases.VectorSpaceWithBasis.from_values("2out", range(2)) 91 | residual_space = bases.join_vector_spaces(input_space, indices_space, 92 | output_space) 93 | 94 | @hk.without_apply_rng 95 | @hk.transform 96 | def unembed(embeddings): 97 | embed_modules = assemble._make_embedding_modules( 98 | residual_space=residual_space, 99 | 
tokens_space=input_space, 100 | indices_space=indices_space, 101 | output_space=output_space) 102 | return embed_modules.unembed(embeddings, use_unembed_argmax=True) 103 | 104 | embeddings = jnp.array([ 105 | # pylint: disable=g-no-space-after-comment 106 | #inp| indices| out | < spaces 107 | #0 1 0 1 2 0 1 < values in spaces 108 | [0, 0, 0, 0, 0, 0, 1], 109 | [0, 0, 0, 0, 0, 1, 0], 110 | [0, 0, 0, 0, 0, 0, 1] 111 | ]) 112 | expected_tokens = jnp.array([1, 0, 1]) 113 | 114 | params = unembed.init(jax.random.PRNGKey(0), embeddings) 115 | tokens = unembed.apply(params, embeddings) 116 | np.testing.assert_allclose(tokens, expected_tokens) 117 | 118 | 119 | if __name__ == "__main__": 120 | absltest.main() 121 | -------------------------------------------------------------------------------- /tracrx/compiler/basis_inference.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Inferring the vector spaces taken on by certain operations.""" 16 | 17 | import dataclasses 18 | import itertools 19 | from typing import Set 20 | 21 | import networkx as nx 22 | from tracrx.compiler import nodes 23 | from tracrx.craft import bases 24 | from tracrx.rasp import rasp 25 | from tracrx.utils import errors 26 | 27 | Node = nodes.Node 28 | 29 | 30 | @dataclasses.dataclass 31 | class InferBasesOutput: 32 | graph: nx.DiGraph 33 | 34 | 35 | def infer_bases( 36 | graph: nx.DiGraph, 37 | sink: Node, 38 | vocab: Set[rasp.Value], 39 | max_seq_len: int, 40 | ) -> None: 41 | """Infers in-place the possible output values and vector bases of the SOps.""" 42 | 43 | def compute_value_set(sop: rasp.SOp) -> Set[rasp.Value]: 44 | """Computes value set using already-computed predecessor value sets.""" 45 | if isinstance(sop, rasp.TokensType): 46 | return vocab 47 | elif isinstance(sop, rasp.IndicesType): 48 | return set(range(max_seq_len)) 49 | elif isinstance(sop, rasp.SelectorWidth): 50 | return set(range(0, max_seq_len + 1)) 51 | elif isinstance(sop, rasp.Full): 52 | return {sop.fill} 53 | elif isinstance(sop, rasp.Map): 54 | inner_value_set = graph.nodes[sop.inner.label][nodes.VALUE_SET] 55 | out = set() 56 | for x in inner_value_set: 57 | res = errors.ignoring_arithmetic_errors(sop.f)(x) 58 | if res is not None: 59 | out.add(res) 60 | return out 61 | elif isinstance(sop, rasp.SequenceMap): 62 | f_ignore_error = errors.ignoring_arithmetic_errors(sop.f) 63 | fst_value_set = graph.nodes[sop.fst.label][nodes.VALUE_SET] 64 | snd_value_set = graph.nodes[sop.snd.label][nodes.VALUE_SET] 65 | out = set() 66 | for l, r in itertools.product(fst_value_set, snd_value_set): 67 | res = f_ignore_error(l, r) 68 | if res is not None: 69 | out.add(res) 70 | return out 71 | elif isinstance(sop, rasp.Aggregate): 72 | if rasp.is_categorical(sop): 73 | # Simply pass on 
the value set of the underlying S-Op. 74 | return graph.nodes[sop.sop.label][nodes.VALUE_SET] 75 | elif rasp.is_numerical(sop): 76 | # TODO(b/255936408): This doesn't work if we average arbitrary values. 77 | # But most examples only average binary variables. 78 | sop_value_set = graph.nodes[sop.sop.label][nodes.VALUE_SET] 79 | if not {int(x) for x in sop_value_set}.issubset({0, 1}): 80 | raise NotImplementedError( 81 | "Attention patterns can currently only " 82 | "average binary variables. Not:", sop_value_set) 83 | 84 | value_set = set() 85 | for value in sop_value_set: 86 | for length in range(1, max_seq_len + 1): 87 | value_set.add(value / length) 88 | return value_set 89 | raise ValueError(f"Unsupported S-Op: {sop}") 90 | 91 | for node_id in nx.dfs_postorder_nodes(graph.reverse(), sink[nodes.ID]): 92 | expr = graph.nodes[node_id][nodes.EXPR] 93 | 94 | if not isinstance(expr, rasp.SOp): 95 | # Only S-Ops have output vector spaces. 96 | continue 97 | 98 | value_set = compute_value_set(expr) 99 | graph.nodes[node_id][nodes.VALUE_SET] = value_set 100 | 101 | if rasp.is_categorical(expr): 102 | out_space = bases.VectorSpaceWithBasis.from_values(expr.label, value_set) 103 | elif rasp.is_numerical(expr): 104 | out_space = bases.VectorSpaceWithBasis.from_names([expr.label]) 105 | else: 106 | raise ValueError(f"Unsupported S-Op type: {expr.type}") 107 | graph.nodes[node_id][nodes.OUTPUT_BASIS] = out_space.basis 108 | -------------------------------------------------------------------------------- /tracrx/compiler/basis_inference_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for compiler.basis_inference.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from tracrx.compiler import basis_inference 20 | from tracrx.compiler import nodes 21 | from tracrx.compiler import rasp_to_graph 22 | from tracrx.rasp import rasp 23 | 24 | 25 | class InferBasesTest(parameterized.TestCase): 26 | 27 | def test_arithmetic_error_logs_warning(self): 28 | program = rasp.numerical(rasp.Map(lambda x: 1 / x, rasp.tokens)) 29 | extracted = rasp_to_graph.extract_rasp_graph(program) 30 | vocab = {0, 1, 2} 31 | with self.assertLogs(level="WARNING"): 32 | basis_inference.infer_bases( 33 | extracted.graph, 34 | extracted.sink, 35 | vocab, 36 | max_seq_len=1, 37 | ) 38 | 39 | @parameterized.parameters(({1, 2, 3}, {2, 3, 4}), ({0, 5}, {1, 6})) 40 | def test_one_edge(self, vocab, expected_value_set): 41 | program = rasp.categorical(rasp.Map(lambda x: x + 1, rasp.tokens)) 42 | extracted = rasp_to_graph.extract_rasp_graph(program) 43 | 44 | basis_inference.infer_bases( 45 | extracted.graph, 46 | extracted.sink, 47 | vocab, 48 | max_seq_len=1, 49 | ) 50 | 51 | self.assertSetEqual( 52 | extracted.graph.nodes[program.label][nodes.VALUE_SET], 53 | expected_value_set, 54 | ) 55 | 56 | def test_primitive_close_to_tip(self): 57 | intermediate = rasp.categorical(rasp.tokens + 1) 58 | intermediate = rasp.categorical(intermediate + intermediate) 59 | program = rasp.categorical(intermediate + rasp.indices) 60 | extracted = rasp_to_graph.extract_rasp_graph(program) 61 | 62 | basis_inference.infer_bases( 63 | extracted.graph, 64 | extracted.sink, 65 | {0, 1}, 66 | max_seq_len=2, 67 | ) 68 | 69 | self.assertSetEqual( 70 | extracted.graph.nodes[program.label][nodes.VALUE_SET], 71 | {2, 3, 4, 5}, 72 | ) 73 | self.assertSetEqual( 74 | extracted.graph.nodes[intermediate.label][nodes.VALUE_SET], 75 | {2, 3, 4}, 76 | ) 77 | 78 | @parameterized.named_parameters( 79 | dict( 80 | testcase_name="categorical_aggregate", 81 | program=rasp.categorical( 82 | rasp.Aggregate( 83 | rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ), 84 | rasp.indices, 85 | ) 86 | ), 87 | vocab={0, 1}, 88 | max_seq_len=3, 89 | expected_value_set={0, 1, 2}, 90 | ), 91 | dict( 92 | testcase_name="numerical_aggregate", 93 | program=rasp.numerical( 94 | rasp.Aggregate( 95 | rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ), 96 | rasp.tokens, 97 | ) 98 | ), 99 | vocab={0, 1}, 100 | max_seq_len=2, 101 | expected_value_set={0, 1, 1 / 2}, 102 | ), 103 | dict( 104 | testcase_name="selector_width", 105 | program=rasp.SelectorWidth( 106 | rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ) 107 | ), 108 | vocab={0, 1}, 109 | max_seq_len=3, 110 | expected_value_set={0, 1, 2, 3}, 111 | ), 112 | dict( 113 | testcase_name="annotated_tokens", 114 | program=rasp.categorical(rasp.tokens), 115 | vocab={"a", "b"}, 116 | max_seq_len=2, 117 | expected_value_set={"a", "b"}, 118 | ), 119 | dict( 120 | testcase_name="annotated_indices", 121 | program=rasp.categorical(rasp.indices), 122 | vocab={"a", "b"}, 123 | max_seq_len=2, 124 | expected_value_set={0, 1}, 125 | ), 126 | ) 127 | def test_inferred_value_set_as_expected( 128 | self, program, vocab, max_seq_len, expected_value_set 129 | ): 130 | extracted = rasp_to_graph.extract_rasp_graph(program) 131 | 132 | basis_inference.infer_bases( 133 | extracted.graph, 134 | extracted.sink, 135 | vocab, 136 | max_seq_len=max_seq_len, 137 | ) 138 | 139 | 
self.assertSetEqual( 140 | extracted.graph.nodes[program.label][nodes.VALUE_SET], 141 | expected_value_set, 142 | ) 143 | 144 | 145 | if __name__ == "__main__": 146 | absltest.main() 147 | -------------------------------------------------------------------------------- /tracrx/compiler/compiling.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Combines all steps of compiling a RASP program.""" 16 | 17 | from typing import Set 18 | 19 | from tracrx.compiler import assemble 20 | from tracrx.compiler import basis_inference 21 | from tracrx.compiler import craft_graph_to_model 22 | from tracrx.compiler import craft_model_to_transformer 23 | from tracrx.compiler import expr_to_craft_graph 24 | from tracrx.compiler import rasp_to_graph 25 | from tracrx.compiler import validating 26 | from tracrx.craft import bases 27 | from tracrx.rasp import rasp 28 | 29 | 30 | COMPILER_BOS = "compiler_bos" 31 | COMPILER_PAD = "compiler_pad" 32 | 33 | 34 | def compile_rasp_to_model( 35 | program: rasp.SOp, 36 | vocab: Set[rasp.Value], 37 | max_seq_len: int, 38 | causal: bool = False, 39 | compiler_bos: str = COMPILER_BOS, 40 | compiler_pad: str = COMPILER_PAD, 41 | mlp_exactness: int = 100, 42 | ) -> assemble.AssembledTransformerModel: 43 | """Compile a RASP program to transformer weights. 44 | 45 | Note that currently not all RASP features are supported. Most unsupported 46 | features are detected at compile time and will cause a NotImplementedError. 47 | However, a few unsupported features cannot be checked at compile time and 48 | can cause silent errors. 49 | 50 | See `compiler.validating` for details and a function to quickly check if 51 | a program is compilable with Tracr without needing to compile it. 52 | 53 | Args: 54 | program: the RASP program to compile. 55 | vocab: the set of vocab tokens expected by RASP. 56 | max_seq_len: the maximum sequence length for the compiled model. 57 | causal: if True, outputs a model with causal masking. 58 | compiler_bos: the name of the special BOS token that will be added by the 59 | compiler. Must not be present in the vocab. 60 | compiler_pad: the name of the special PAD token that will be added by the 61 | compiler. Must not be present in the vocab. 62 | mlp_exactness: Controls the approximation of the MLP layers. In theory, 63 | larger values yield a better approximation. However, values that are too large can cause 64 | numerical issues due to large parameter norms. Reasonable values are 65 | between 1 and 100. 66 | 67 | Returns: 68 | The compiled model. 69 | 70 | Raises: 71 | NotImplementedError: if the program uses unsupported features that can be 72 | caught at compile time.
73 | """ 74 | 75 | if compiler_bos in vocab: 76 | raise ValueError( 77 | "Compiler BOS token must not be present in the vocab. " 78 | f"Found '{compiler_bos}' in {vocab}" 79 | ) 80 | 81 | if compiler_pad in vocab: 82 | raise ValueError( 83 | "Compiler PAD token must not be present in the vocab. " 84 | f"Found '{compiler_pad}' in {vocab}" 85 | ) 86 | 87 | # Perform static validation to fail fast. This catches most programs that 88 | # tracr is unable to compile. 89 | unsupported_exprs = validating.static_validate(program) 90 | if unsupported_exprs: 91 | error_message = "\n".join( 92 | (f"{expr.expr.name}: {expr.reason}" for expr in unsupported_exprs) 93 | ) 94 | error_message = f"Unsupported RASP expressions:\n{error_message}" 95 | raise NotImplementedError(error_message) 96 | 97 | extracted = rasp_to_graph.extract_rasp_graph(program) 98 | graph, sources, sink = extracted.graph, extracted.sources, extracted.sink 99 | 100 | basis_inference.infer_bases( 101 | graph, 102 | sink, 103 | vocab, 104 | max_seq_len, 105 | ) 106 | 107 | expr_to_craft_graph.add_craft_components_to_rasp_graph( 108 | graph, 109 | bos_dir=bases.BasisDirection(rasp.tokens.label, compiler_bos), 110 | mlp_exactness=mlp_exactness, 111 | ) 112 | 113 | craft_model = craft_graph_to_model.craft_graph_to_model(graph, sources) 114 | 115 | return craft_model_to_transformer.craft_model_to_transformer( 116 | craft_model=craft_model, 117 | graph=graph, 118 | sink=sink, 119 | max_seq_len=max_seq_len, 120 | causal=causal, 121 | compiler_bos=compiler_bos, 122 | compiler_pad=compiler_pad, 123 | ) 124 | -------------------------------------------------------------------------------- /tracrx/compiler/craft_graph_to_model_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for compiler.craft_graph_to_model.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import networkx as nx 20 | from tracrx.compiler import craft_graph_to_model 21 | from tracrx.compiler import nodes 22 | from tracrx.compiler import rasp_to_graph 23 | from tracrx.craft import bases 24 | from tracrx.craft.chamber import categorical_attn 25 | from tracrx.craft.chamber import categorical_mlp 26 | from tracrx.rasp import rasp 27 | 28 | 29 | class CraftAllocateModulesToLayersTest(parameterized.TestCase): 30 | 31 | def _get_dummy_block(self, block_type): 32 | if block_type == "ATTN": 33 | return categorical_attn.categorical_attn( 34 | query_space=bases.VectorSpaceWithBasis.from_names(["query"]), 35 | key_space=bases.VectorSpaceWithBasis.from_names(["bos", "key"]), 36 | value_space=bases.VectorSpaceWithBasis.from_names(["bos", "value"]), 37 | output_space=bases.VectorSpaceWithBasis.from_names(["output"]), 38 | bos_space=bases.VectorSpaceWithBasis.from_names(["bos"]), 39 | one_space=bases.VectorSpaceWithBasis.from_names(["one"]), 40 | attn_fn=lambda x, y: True, 41 | ) 42 | elif block_type == "MLP": 43 | return categorical_mlp.map_categorical_mlp( 44 | input_space=bases.VectorSpaceWithBasis.from_names(["input"]), 45 | output_space=bases.VectorSpaceWithBasis.from_names(["output"]), 46 | operation=lambda x: x, 47 | ) 48 | else: 49 | return None 50 | 51 | def test_compute_computational_depth_returns_expected_result(self): 52 | """Creates a graph and checks the longest path for each node.""" 53 | 54 | # Node IDs: 55 | # 0 -- 1 -- 2 -- 3 ------------ 4 56 | # / / 57 | # 5 -- 6 ---------- 7 -- 8 -- 9 58 | # 59 | # 10 60 | # Expected return values: 61 | # 0 -- 1 -- 2 -- 3 ------------ 5 62 | # / / 63 | # 0 -- 1 ---------- 2 -- 3 -- 4 64 | # 65 | # -1 66 | 67 | graph = nx.DiGraph() 68 | node_ids = list(range(11)) 69 | expected_results = [0, 1, 2, 3, 5, 0, 1, 2, 3, 4, -1] 70 | for node_id, res in zip(node_ids, expected_results): 71 | graph.add_node( 72 | node_id, **{ 73 | nodes.ID: node_id, 74 | nodes.EXPR: rasp.ConstantSOp(1), 75 | "expected_result": res 76 | }) 77 | graph.add_edge(0, 1) 78 | graph.add_edge(1, 2) 79 | graph.add_edge(2, 3) 80 | graph.add_edge(3, 4) 81 | graph.add_edge(5, 6) 82 | graph.add_edge(6, 7) 83 | graph.add_edge(7, 8) 84 | graph.add_edge(8, 9) 85 | graph.add_edge(6, 3) 86 | graph.add_edge(9, 4) 87 | sources = [graph.nodes[0], graph.nodes[5]] 88 | 89 | computational_depth = craft_graph_to_model.compute_computational_depth( 90 | graph, [src[nodes.ID] for src in sources] 91 | ) 92 | for node_id, node in graph.nodes.items(): 93 | self.assertEqual(computational_depth[node_id], node["expected_result"]) 94 | 95 | def test_allocate_modules_to_layers_returns_expected_result(self): 96 | """Creates a graph and checks if the correct layer assignment is returned.""" 97 | 98 | # Computation Graph: 99 | # INPUT -- ATTN -- MLP -- ATTN ------ MLP -- OUTPUT 100 | # / / / 101 | # INPUT -- MLP --- MLP ATTN 102 | # \ / 103 | # ATTN 104 | # Node IDs: 105 | # 0 -- 1 -- 2 -- 3 -- 4 -- 5 106 | # / / / 107 | # 6 -- 7 ---- 8 9 108 | # \ / 109 | # 10 110 | # Expected layer allocation: 111 | # -1 -- 0 -- 3 -- 4 -- 7 -- -1 112 | # / / / 113 | # -1 -- 1 --- 3 6 114 | # \ / 115 | # 4 116 | 117 | graph = nx.DiGraph() 118 | node_ids = list(range(11)) 119 | types = [ 120 | "INPUT", "ATTN", "MLP", "ATTN", "MLP", "OUTPUT", "INPUT", "MLP", "MLP", 121 | "ATTN", "ATTN" 122 | ] 
123 | expected_results = [-1, 0, 3, 4, 7, -1, -1, 1, 3, 6, 4] 124 | for node_id, node_type, res in zip(node_ids, types, expected_results): 125 | graph.add_node( 126 | node_id, **{ 127 | nodes.ID: node_id, 128 | nodes.EXPR: rasp.ConstantSOp(1), 129 | nodes.MODEL_BLOCK: self._get_dummy_block(node_type), 130 | "expected_result": res 131 | }) 132 | 133 | graph.add_edge(0, 1) 134 | graph.add_edge(1, 2) 135 | graph.add_edge(2, 3) 136 | graph.add_edge(3, 4) 137 | graph.add_edge(4, 5) 138 | graph.add_edge(6, 7) 139 | graph.add_edge(7, 2) 140 | graph.add_edge(7, 8) 141 | graph.add_edge(8, 3) 142 | graph.add_edge(8, 10) 143 | graph.add_edge(9, 4) 144 | graph.add_edge(10, 9) 145 | 146 | craft_graph = rasp_to_graph.ExtractRaspGraphOutput( 147 | graph=graph, 148 | sink=graph.nodes[10], 149 | sources=[graph.nodes[0], graph.nodes[6]]) 150 | 151 | layer_allocation = craft_graph_to_model._allocate_modules_to_layers( 152 | craft_graph.graph, craft_graph.sources) 153 | for node_id, node in graph.nodes.items(): 154 | self.assertEqual(layer_allocation[node_id], node["expected_result"]) 155 | 156 | def test_allocate_modules_to_layers_returns_expected_result_for_chain(self): 157 | """Tests a chain of alternating attention layers and MLPs.""" 158 | 159 | # Computation Graph: 160 | # INPUT -- ATTN -- MLP -- ATTN -- MLP -- OUTPUT 161 | # Node IDs: 162 | # 0 -- 1 -- 2 -- 3 -- 4 -- 5 163 | # Expected layer allocation: 164 | # -1 -- 0 -- 1 -- 2 -- 3 -- -1 165 | 166 | graph = nx.DiGraph() 167 | node_ids = list(range(11)) 168 | types = ["INPUT", "ATTN", "MLP", "ATTN", "MLP", "OUTPUT"] 169 | expected_results = [-1, 0, 1, 2, 3, -1] 170 | for node_id, node_type, res in zip(node_ids, types, expected_results): 171 | graph.add_node( 172 | node_id, **{ 173 | nodes.ID: node_id, 174 | nodes.EXPR: rasp.ConstantSOp(1), 175 | nodes.MODEL_BLOCK: self._get_dummy_block(node_type), 176 | "expected_result": res 177 | }) 178 | 179 | graph.add_edge(0, 1) 180 | graph.add_edge(1, 2) 181 | graph.add_edge(2, 3) 182 | graph.add_edge(3, 4) 183 | graph.add_edge(4, 5) 184 | 185 | craft_graph = rasp_to_graph.ExtractRaspGraphOutput( 186 | graph=graph, sink=graph.nodes[5], sources=[graph.nodes[0]]) 187 | 188 | layer_allocation = craft_graph_to_model._allocate_modules_to_layers( 189 | craft_graph.graph, craft_graph.sources) 190 | for node_id, node in graph.nodes.items(): 191 | self.assertEqual(layer_allocation[node_id], node["expected_result"]) 192 | 193 | 194 | if __name__ == "__main__": 195 | absltest.main() 196 | -------------------------------------------------------------------------------- /tracrx/compiler/craft_model_to_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Convert craft model into transformer with the correct input/output spaces.""" 16 | 17 | import networkx as nx 18 | from tracrx.compiler import assemble 19 | from tracrx.compiler import nodes 20 | from tracrx.craft import bases 21 | from tracrx.craft import transformers 22 | from tracrx.rasp import rasp 23 | from tracrx.transformer import encoder 24 | 25 | 26 | def craft_model_to_transformer( 27 | craft_model: transformers.SeriesWithResiduals, 28 | graph: nx.DiGraph, 29 | sink: nodes.Node, 30 | max_seq_len: int, 31 | compiler_bos: str, 32 | compiler_pad: str, 33 | causal: bool = False, 34 | ) -> assemble.AssembledTransformerModel: 35 | """Turn a craft model into a transformer model.""" 36 | 37 | if rasp.tokens.label not in graph.nodes: 38 | raise ValueError( 39 | f'Failed to find a node with label {rasp.tokens.label}. ' 40 | 'This is probably because your RASP program does not include ' 41 | 'rasp.tokens. A program must include rasp.tokens to be ' 42 | 'compiled.' 43 | ) 44 | 45 | # Add the compiler BOS token. 46 | tokens_value_set = ( 47 | graph.nodes[rasp.tokens.label][nodes.VALUE_SET].union( 48 | {compiler_bos, compiler_pad})) 49 | tokens_space = bases.VectorSpaceWithBasis.from_values(rasp.tokens.label, 50 | tokens_value_set) 51 | 52 | indices_space = bases.VectorSpaceWithBasis.from_values( 53 | rasp.indices.label, range(max_seq_len)) 54 | 55 | categorical_output = rasp.is_categorical(sink[nodes.EXPR]) 56 | output_space = bases.VectorSpaceWithBasis(sink[nodes.OUTPUT_BASIS]) 57 | 58 | assembled_model = assemble.assemble_craft_model( 59 | craft_model=craft_model, 60 | tokens_space=tokens_space, 61 | indices_space=indices_space, 62 | output_space=output_space, 63 | categorical_output=categorical_output, 64 | causal=causal, 65 | ) 66 | 67 | assembled_model.input_encoder = encoder.CategoricalEncoder( 68 | basis=tokens_space.basis, 69 | enforce_bos=compiler_bos is not None, 70 | bos_token=compiler_bos, 71 | pad_token=compiler_pad, 72 | max_seq_len=max_seq_len + 1 if compiler_bos is not None else max_seq_len, 73 | ) 74 | 75 | if categorical_output: 76 | assembled_model.output_encoder = encoder.CategoricalEncoder( 77 | basis=output_space.basis, 78 | enforce_bos=False, 79 | bos_token=None, 80 | pad_token=None) 81 | else: 82 | assembled_model.output_encoder = encoder.NumericalEncoder() 83 | 84 | return assembled_model 85 | -------------------------------------------------------------------------------- /tracrx/compiler/expr_to_craft_graph_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for compiler.expr_to_craft_graph.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from tracrx.compiler import basis_inference 20 | from tracrx.compiler import expr_to_craft_graph 21 | from tracrx.compiler import lib 22 | from tracrx.compiler import nodes 23 | from tracrx.compiler import rasp_to_graph 24 | from tracrx.craft import bases 25 | from tracrx.craft import transformers 26 | from tracrx.rasp import rasp 27 | 28 | 29 | class ExprToCraftGraphTest(parameterized.TestCase): 30 | 31 | def _check_block_types_are_correct(self, graph): 32 | for _, node in graph.nodes.items(): 33 | expr = node[nodes.EXPR] 34 | if isinstance(expr, rasp.SOp): 35 | block = node[nodes.MODEL_BLOCK] 36 | if isinstance(expr, (rasp.Map, rasp.SequenceMap)): 37 | self.assertIsInstance(block, transformers.MLP) 38 | elif isinstance(expr, rasp.Aggregate): 39 | self.assertIsInstance(block, transformers.AttentionHead) 40 | 41 | def _get_input_space_from_node(self, node): 42 | block = node[nodes.MODEL_BLOCK] 43 | if isinstance(block, transformers.MLP): 44 | return block.fst.input_space 45 | elif isinstance(block, transformers.AttentionHead): 46 | return bases.join_vector_spaces(block.w_qk.left_space, 47 | block.w_qk.right_space, 48 | block.w_ov.input_space) 49 | else: 50 | return None 51 | 52 | def _check_spaces_are_consistent(self, graph): 53 | """Check that for each edge the output is a subspace of the input.""" 54 | for u, v in graph.edges: 55 | u_node, v_node = graph.nodes[u], graph.nodes[v] 56 | if isinstance(u_node[nodes.EXPR], rasp.SOp) and isinstance( 57 | v_node[nodes.EXPR], rasp.SOp): 58 | u_out_basis = u_node[nodes.OUTPUT_BASIS] 59 | u_out_space = bases.VectorSpaceWithBasis(u_out_basis) 60 | v_in_space = self._get_input_space_from_node(v_node) 61 | self.assertTrue(u_out_space.issubspace(v_in_space)) 62 | 63 | @parameterized.named_parameters( 64 | dict( 65 | testcase_name="single_map", 66 | program=rasp.Map(lambda x: x + 1, rasp.tokens), 67 | ), 68 | dict( 69 | testcase_name="single_sequence_map", 70 | program=rasp.SequenceMap( 71 | lambda x, y: x + y, rasp.tokens, rasp.indices 72 | ), 73 | ), 74 | dict( 75 | testcase_name="single_select_aggregate", 76 | program=rasp.Aggregate( 77 | rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ), 78 | rasp.tokens, 79 | ), 80 | ), 81 | dict(testcase_name="reverse", program=lib.make_reverse(rasp.tokens)), 82 | dict(testcase_name="length", program=lib.make_length()), 83 | dict( 84 | testcase_name="annotated_tokens", 85 | program=rasp.annotate(rasp.tokens, foo="foo"), 86 | ), 87 | dict( 88 | testcase_name="annotated_indices", 89 | program=rasp.annotate(rasp.indices, foo="foo"), 90 | ), 91 | ) 92 | def test_compiling_rasp_programs(self, program): 93 | vocab = {0, 1, 2} 94 | extracted = rasp_to_graph.extract_rasp_graph(program) 95 | basis_inference.infer_bases( 96 | extracted.graph, 97 | extracted.sink, 98 | vocab, 99 | max_seq_len=3, 100 | ) 101 | expr_to_craft_graph.add_craft_components_to_rasp_graph(extracted.graph) 102 | self._check_block_types_are_correct(extracted.graph) 103 | self._check_spaces_are_consistent(extracted.graph) 104 | 105 | def test_add_craft_components_raises_value_error_if_called_before_basis_inference( 106 | self): 107 | program = rasp.categorical(rasp.Map(lambda x: x + 1, rasp.tokens)) 108 | extracted = rasp_to_graph.extract_rasp_graph(program) 109 | 110 | with self.assertRaisesRegex( 111 | ValueError, 
112 | r"^.*Craft components can only be added after basis inference.*$"): 113 | expr_to_craft_graph.add_craft_components_to_rasp_graph(extracted.graph) 114 | 115 | def test_add_craft_components_raises_value_error_if_called_twice(self): 116 | vocab = {0, 1, 2} 117 | program = rasp.categorical(rasp.Map(lambda x: x + 1, rasp.tokens)) 118 | extracted = rasp_to_graph.extract_rasp_graph(program) 119 | 120 | basis_inference.infer_bases( 121 | extracted.graph, 122 | extracted.sink, 123 | vocab, 124 | max_seq_len=1, 125 | ) 126 | 127 | expr_to_craft_graph.add_craft_components_to_rasp_graph(extracted.graph) 128 | with self.assertRaisesRegex( 129 | ValueError, r"^.*Input graph cannot have model blocks set already.*$"): 130 | expr_to_craft_graph.add_craft_components_to_rasp_graph(extracted.graph) 131 | 132 | 133 | if __name__ == "__main__": 134 | absltest.main() 135 | -------------------------------------------------------------------------------- /tracrx/compiler/lib_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for compiler.lib.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from tracrx.compiler import test_cases 20 | from tracrx.rasp import causal_eval 21 | from tracrx.rasp import rasp 22 | 23 | 24 | class LibTest(parameterized.TestCase): 25 | 26 | @parameterized.named_parameters(*test_cases.TEST_CASES) 27 | def test_program_produces_expected_output(self, program, test_input, 28 | expected_output, **kwargs): 29 | del kwargs 30 | self.assertEqual(rasp.evaluate(program, test_input), expected_output) 31 | 32 | @parameterized.named_parameters(*test_cases.CAUSAL_TEST_CASES) 33 | def test_causal_program_produces_expected_output(self, program, test_input, 34 | expected_output, **kwargs): 35 | del kwargs 36 | self.assertEqual(causal_eval.evaluate(program, test_input), expected_output) 37 | 38 | 39 | if __name__ == "__main__": 40 | absltest.main() 41 | -------------------------------------------------------------------------------- /tracrx/compiler/nodes.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Documents the data stored in nodes after each compiler pass.""" 16 | 17 | from typing import Any, Dict 18 | 19 | Node = Dict[str, Any] 20 | NodeID = str 21 | 22 | # RASP -> Graph 23 | ID = "ID" # unique ID of the node 24 | EXPR = "EXPR" # the RASPExpr of the node 25 | 26 | # Basis inference 27 | # Note that only S-Op expressions will have these keys set. 28 | VALUE_SET = "VALUE_SET" # possible values taken on by this SOp. 29 | OUTPUT_BASIS = "OUTPUT_BASIS" # the corresponding named basis. 30 | 31 | # RASP Graph -> Craft Graph 32 | MODEL_BLOCK = "MODEL_BLOCK" # craft block representing a RASPExpr 33 | -------------------------------------------------------------------------------- /tracrx/compiler/rasp_to_graph.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Converting a RaspExpr to a graph.""" 16 | 17 | import dataclasses 18 | import queue 19 | from typing import List 20 | 21 | import networkx as nx 22 | from tracrx.compiler import nodes 23 | from tracrx.rasp import rasp 24 | 25 | Node = nodes.Node 26 | NodeID = nodes.NodeID 27 | 28 | 29 | @dataclasses.dataclass 30 | class ExtractRaspGraphOutput: 31 | graph: nx.DiGraph 32 | sink: Node # the program's output. 33 | sources: List[Node] # the primitive S-Ops. 34 | 35 | 36 | def extract_rasp_graph(tip: rasp.SOp) -> ExtractRaspGraphOutput: 37 | """Converts a RASP program into a graph representation.""" 38 | expr_queue = queue.Queue() 39 | graph = nx.DiGraph() 40 | sources: List[NodeID] = [] 41 | 42 | def ensure_node(expr: rasp.RASPExpr) -> NodeID: 43 | """Finds or creates a graph node corresponding to expr; returns its ID.""" 44 | node_id = expr.label 45 | if node_id not in graph: 46 | graph.add_node(node_id, **{nodes.ID: node_id, nodes.EXPR: expr}) 47 | 48 | return node_id 49 | 50 | # Breadth-first search over the RASP expression graph. 
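  # Construction note: edges are added from each child expression to its
  # parent (graph.add_edge(child_id, parent_id) below), so edges point in the
  # direction of data flow. Expressions with no children are the primitive
  # S-Ops and are collected in `sources`, while the program itself becomes the
  # sink; later passes such as basis_inference.infer_bases walk the reversed
  # graph from the sink and rely on this orientation.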
51 | 52 | def visit_raspexpr(expr: rasp.RASPExpr): 53 | parent_id = ensure_node(expr) 54 | 55 | for child_expr in expr.children: 56 | expr_queue.put(child_expr) 57 | child_id = ensure_node(child_expr) 58 | graph.add_edge(child_id, parent_id) 59 | 60 | if not expr.children: 61 | sources.append(graph.nodes[parent_id]) 62 | 63 | expr_queue.put(tip) 64 | sink = graph.nodes[ensure_node(tip)] 65 | while not expr_queue.empty(): 66 | visit_raspexpr(expr_queue.get()) 67 | 68 | return ExtractRaspGraphOutput(graph=graph, sink=sink, sources=sources) 69 | -------------------------------------------------------------------------------- /tracrx/compiler/rasp_to_graph_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for compiler.rasp_to_graph.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from tracrx.compiler import nodes 20 | from tracrx.compiler import rasp_to_graph 21 | from tracrx.rasp import rasp 22 | 23 | 24 | class ExtractRaspGraphTest(parameterized.TestCase): 25 | 26 | def test_primitives_have_no_edges(self): 27 | tokens_graph = rasp_to_graph.extract_rasp_graph(rasp.tokens).graph 28 | self.assertEmpty(tokens_graph.edges) 29 | 30 | indices_graph = rasp_to_graph.extract_rasp_graph(rasp.indices).graph 31 | self.assertEmpty(indices_graph.edges) 32 | 33 | full_graph = rasp_to_graph.extract_rasp_graph(rasp.Full(1)).graph 34 | self.assertEmpty(full_graph.edges) 35 | 36 | def test_one_edge(self): 37 | program = rasp.Map(lambda x: x + 1, rasp.tokens) 38 | 39 | graph = rasp_to_graph.extract_rasp_graph(program).graph 40 | 41 | self.assertLen(graph.edges, 1) 42 | (u, v), = graph.edges 43 | self.assertEqual(graph.nodes[u][nodes.EXPR], rasp.tokens) 44 | self.assertEqual(graph.nodes[v][nodes.EXPR], program) 45 | 46 | def test_aggregate(self): 47 | program = rasp.Aggregate( 48 | rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ), 49 | rasp.indices, 50 | ) 51 | 52 | extracted = rasp_to_graph.extract_rasp_graph(program) 53 | 54 | # Expected graph: 55 | # 56 | # indices \ -------- 57 | # \ \ 58 | # select -- program 59 | # tokens / 60 | 61 | self.assertLen(extracted.graph.edges, 4) 62 | self.assertEqual(extracted.sink[nodes.EXPR], program) 63 | for source in extracted.sources: 64 | self.assertIn( 65 | source[nodes.EXPR], 66 | [rasp.tokens, rasp.indices], 67 | ) 68 | 69 | 70 | if __name__ == "__main__": 71 | absltest.main() 72 | -------------------------------------------------------------------------------- /tracrx/compiler/validating_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for compiler.compilable_evaluator.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from tracrx.compiler import test_cases 20 | from tracrx.compiler import validating 21 | from tracrx.rasp import rasp 22 | 23 | 24 | class ValidationEvaluatorTest(parameterized.TestCase): 25 | 26 | @parameterized.named_parameters(test_cases.TEST_CASES) 27 | def test_supported_programs_pass_validation( 28 | self, 29 | program, 30 | test_input, 31 | **kwargs, 32 | ): 33 | del kwargs 34 | validation_result = validating.validate(program, test_input) 35 | self.assertEmpty(validation_result) 36 | 37 | @parameterized.named_parameters(test_cases.UNSUPPORTED_TEST_CASES) 38 | def test_unsupported_programs_fail_validation( 39 | self, 40 | program, 41 | vocab, 42 | **kwargs, 43 | ): 44 | del kwargs 45 | test_input = sorted(list(vocab)) 46 | validation_result = validating.validate(program, test_input) 47 | self.assertNotEmpty(validation_result) 48 | 49 | @parameterized.named_parameters( 50 | dict( 51 | testcase_name="mean", 52 | program=rasp.Aggregate( 53 | rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE), 54 | rasp.tokens, 55 | ), 56 | test_input=[1, 2, 3, 4], 57 | ), 58 | dict( 59 | testcase_name="prev_mean", 60 | program=rasp.Aggregate( 61 | rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.LEQ), 62 | rasp.tokens, 63 | ), 64 | test_input=[1, 2, 3, 4], 65 | ), 66 | ) 67 | def test_dynamic_failure_cases_fail_validation( 68 | self, 69 | program, 70 | test_input, 71 | ): 72 | # Dynamic test cases are not in the general test case suite because they are 73 | # not caught at compile time. 74 | validation_result = validating.validate(program, test_input) 75 | self.assertNotEmpty(validation_result) 76 | 77 | 78 | if __name__ == "__main__": 79 | absltest.main() 80 | -------------------------------------------------------------------------------- /tracrx/craft/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /tracrx/craft/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/craft/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /tracrx/craft/__pycache__/bases.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/craft/__pycache__/bases.cpython-310.pyc -------------------------------------------------------------------------------- /tracrx/craft/__pycache__/transformers.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/craft/__pycache__/transformers.cpython-310.pyc -------------------------------------------------------------------------------- /tracrx/craft/__pycache__/vectorspace_fns.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/craft/__pycache__/vectorspace_fns.cpython-310.pyc -------------------------------------------------------------------------------- /tracrx/craft/bases_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for bases.""" 16 | 17 | from absl.testing import absltest 18 | import numpy as np 19 | from tracrx.craft import bases 20 | from tracrx.craft import tests_common 21 | 22 | 23 | class VectorInBasisTest(tests_common.VectorFnTestCase): 24 | 25 | def test_shape_mismatch_raises_value_error(self): 26 | vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"]) 27 | regex = ( 28 | r"^.*Last dimension of magnitudes must be the same as number of " 29 | r"basis directions.*$" 30 | ) 31 | with self.assertRaisesRegex(ValueError, regex): 32 | bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4])) 33 | with self.assertRaisesRegex(ValueError, regex): 34 | bases.VectorInBasis(vs1.basis, np.array([[0, 1, 2, 3], [1, 2, 3, 4]])) 35 | 36 | def test_equal(self): 37 | vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"]) 38 | v1 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4])) 39 | v2 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4])) 40 | self.assertEqual(v1, v2) 41 | self.assertEqual(v2, v1) 42 | v3 = bases.VectorInBasis(vs1.basis, np.array([[0, 1, 2, 3], [1, 2, 3, 4]])) 43 | v4 = bases.VectorInBasis(vs1.basis, np.array([[0, 1, 2, 3], [1, 2, 3, 4]])) 44 | self.assertEqual(v3, v4) 45 | self.assertEqual(v4, v3) 46 | v5 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4])) 47 | v6 = bases.VectorInBasis(vs1.basis, np.array([1, 1, 1, 1])) 48 | self.assertNotEqual(v5, v6) 49 | self.assertNotEqual(v6, v5) 50 | v7 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4])) 51 | v8 = bases.VectorInBasis(vs1.basis, np.array([[1, 2, 3, 4], [1, 1, 1, 1]])) 52 | self.assertNotEqual(v7, v8) 53 | self.assertNotEqual(v8, v7) 54 | vs2 = bases.VectorSpaceWithBasis.from_names(["e", "f", "g", "h"]) 55 | v9 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4])) 56 | v10 = bases.VectorInBasis(vs2.basis, np.array([1, 2, 3, 4])) 57 | self.assertNotEqual(v9, v10) 58 | self.assertNotEqual(v10, v9) 59 | 60 | def test_dunders(self): 61 | vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c"]) 62 | v = bases.VectorInBasis(vs1.basis, np.array([0, 1, 2])) 63 | three = bases.VectorInBasis(vs1.basis, np.array([3, 3, 3])) 64 | five = bases.VectorInBasis(vs1.basis, np.array([5, 5, 5])) 65 | v_times_5 = bases.VectorInBasis(vs1.basis, np.array([0, 5, 10])) 66 | self.assertEqual(5 * v, v_times_5) 67 | self.assertEqual(v * 5, v_times_5) 68 | self.assertEqual(5.0 * v, v_times_5) 69 | self.assertEqual(v * 5.0, v_times_5) 70 | v_by_2 = bases.VectorInBasis(vs1.basis, np.array([0, 0.5, 1])) 71 | self.assertEqual(v / 2, v_by_2) 72 | self.assertEqual(v / 2.0, v_by_2) 73 | self.assertEqual(1 / 2 * v, v_by_2) 74 | v_plus_3 = bases.VectorInBasis(vs1.basis, np.array([3, 4, 5])) 75 | self.assertEqual(v + three, v_plus_3) 76 | self.assertEqual(three + v, v_plus_3) 77 | v_minus_5 = bases.VectorInBasis(vs1.basis, np.array([-5, -4, -3])) 78 | self.assertEqual(v - five, v_minus_5) 79 | minus_v = bases.VectorInBasis(vs1.basis, np.array([0, -1, -2])) 80 | self.assertEqual(-v, minus_v) 81 | 82 | def test_add_directions(self): 83 | vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c"]) 84 | expected = bases.VectorInBasis(vs1.basis, np.array([3, 4, 5])) 85 | v = bases.VectorInBasis(vs1.basis, np.array([0, 1, 2])) 86 | three = bases.VectorInBasis(vs1.basis, np.array([3, 3, 3])) 87 | shifted = v.add_directions(three) 88 | self.assertEqual(shifted, expected) 89 | 90 | 91 | class ProjectionTest(tests_common.VectorFnTestCase): 92 | 93 | def 
test_direct_sum_produces_expected_result(self): 94 | vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"]) 95 | vs2 = bases.VectorSpaceWithBasis.from_names(["d", "c"]) 96 | vs3 = bases.VectorSpaceWithBasis.from_names(["a", "b", "d", "c"]) 97 | self.assertEqual(bases.direct_sum(vs1, vs2), vs3) 98 | 99 | def test_join_vector_spaces_produces_expected_result(self): 100 | vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"]) 101 | vs2 = bases.VectorSpaceWithBasis.from_names(["d", "c"]) 102 | vs3 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"]) 103 | self.assertEqual(bases.join_vector_spaces(vs1, vs2), vs3) 104 | 105 | vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"]) 106 | vs2 = bases.VectorSpaceWithBasis.from_names(["b", "d", "c"]) 107 | vs3 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"]) 108 | self.assertEqual(bases.join_vector_spaces(vs1, vs2), vs3) 109 | 110 | def test_compare_vectors_with_differently_ordered_basis_vectors(self): 111 | basis1 = ["a", "b", "c", "d"] 112 | basis1 = [bases.BasisDirection(x) for x in basis1] 113 | basis2 = ["b", "d", "a", "c"] 114 | basis2 = [bases.BasisDirection(x) for x in basis2] 115 | vs1 = bases.VectorSpaceWithBasis(basis1) 116 | vs2 = bases.VectorSpaceWithBasis(basis2) 117 | v1 = bases.VectorInBasis(basis1, np.array([1, 2, 3, 4])) 118 | v2 = bases.VectorInBasis(basis2, np.array([2, 4, 1, 3])) 119 | self.assertEqual(v1, v2) 120 | self.assertEqual(v1 - v2, vs1.null_vector()) 121 | self.assertEqual(v1 - v2, vs2.null_vector()) 122 | self.assertEqual(v1 + v2, 2 * v2) 123 | self.assertIn(v1, vs1) 124 | self.assertIn(v1, vs2) 125 | self.assertIn(v2, vs1) 126 | self.assertIn(v2, vs2) 127 | 128 | def test_compare_vector_arrays_with_differently_ordered_basis_vectors(self): 129 | basis1 = ["a", "b", "c", "d"] 130 | basis1 = [bases.BasisDirection(x) for x in basis1] 131 | basis2 = ["b", "d", "a", "c"] 132 | basis2 = [bases.BasisDirection(x) for x in basis2] 133 | vs1 = bases.VectorSpaceWithBasis(basis1) 134 | vs2 = bases.VectorSpaceWithBasis(basis2) 135 | v1 = bases.VectorInBasis(basis1, np.array([[1, 2, 3, 4], [5, 6, 7, 8]])) 136 | v2 = bases.VectorInBasis(basis2, np.array([[2, 4, 1, 3], [6, 8, 5, 7]])) 137 | null_vec = bases.VectorInBasis.stack([vs1.null_vector(), vs2.null_vector()]) 138 | self.assertEqual(v1, v2) 139 | self.assertEqual(v1 - v2, null_vec) 140 | self.assertEqual(v1 + v2, 2 * v2) 141 | self.assertIn(v1, vs1) 142 | self.assertIn(v1, vs2) 143 | self.assertIn(v2, vs1) 144 | self.assertIn(v2, vs2) 145 | 146 | def test_projection_to_larger_space(self): 147 | vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"]) 148 | vs2 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"]) 149 | a1, b1 = vs1.basis_vectors() 150 | a2, b2, _, _ = vs2.basis_vectors() 151 | 152 | self.assertEqual(a1.project(vs2), a2) 153 | self.assertEqual(b1.project(vs2), b2) 154 | 155 | def test_projection_to_smaller_space(self): 156 | vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"]) 157 | vs2 = bases.VectorSpaceWithBasis.from_names(["a", "b"]) 158 | a1, b1, c1, d1 = vs1.basis_vectors() 159 | a2, b2 = vs2.basis_vectors() 160 | 161 | self.assertEqual(a1.project(vs2), a2) 162 | self.assertEqual(b1.project(vs2), b2) 163 | self.assertEqual(c1.project(vs2), vs2.null_vector()) 164 | self.assertEqual(d1.project(vs2), vs2.null_vector()) 165 | 166 | 167 | if __name__ == "__main__": 168 | absltest.main() 169 | -------------------------------------------------------------------------------- /tracrx/craft/chamber/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /tracrx/craft/chamber/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/craft/chamber/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /tracrx/craft/chamber/__pycache__/categorical_attn.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/craft/chamber/__pycache__/categorical_attn.cpython-310.pyc -------------------------------------------------------------------------------- /tracrx/craft/chamber/__pycache__/categorical_mlp.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/craft/chamber/__pycache__/categorical_mlp.cpython-310.pyc -------------------------------------------------------------------------------- /tracrx/craft/chamber/__pycache__/numerical_mlp.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/craft/chamber/__pycache__/numerical_mlp.cpython-310.pyc -------------------------------------------------------------------------------- /tracrx/craft/chamber/__pycache__/selector_width.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/craft/chamber/__pycache__/selector_width.cpython-310.pyc -------------------------------------------------------------------------------- /tracrx/craft/chamber/categorical_attn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Attention head for categorical inputs.""" 16 | 17 | from typing import Optional 18 | 19 | from tracrx.craft import bases 20 | from tracrx.craft import transformers 21 | from tracrx.craft import vectorspace_fns 22 | from typing_extensions import Protocol 23 | 24 | 25 | class QueryKeyToAttnLogit(Protocol): 26 | 27 | def __call__(self, query: bases.BasisDirection, 28 | key: bases.BasisDirection) -> bool: 29 | pass 30 | 31 | 32 | def categorical_attn( 33 | query_space: bases.VectorSpaceWithBasis, 34 | key_space: bases.VectorSpaceWithBasis, 35 | value_space: bases.VectorSpaceWithBasis, 36 | output_space: bases.VectorSpaceWithBasis, 37 | bos_space: bases.VectorSpaceWithBasis, 38 | one_space: bases.VectorSpaceWithBasis, 39 | attn_fn: QueryKeyToAttnLogit, 40 | default_output: Optional[bases.VectorInBasis] = None, 41 | causal: bool = False, 42 | always_attend_to_bos: bool = False, 43 | use_bos_for_default_output: bool = True, 44 | softmax_coldness: float = 100., 45 | ) -> transformers.AttentionHead: 46 | """Returns an attention head for categorical inputs. 47 | 48 | Assumes the existence of a beginning of sequence token and attends to it 49 | always with strength 0.5*softmax_coldness. This allows to implement an 50 | arbitrary default value for rows in the attention pattern that are all-zero. 51 | 52 | Attends to the BOS token if all other key-query pairs have zero attention. 53 | Hence, the first value in the value sequence will be the default output for 54 | such cases. 55 | 56 | Args: 57 | query_space: Vector space containing (categorical) query input. 58 | key_space: Vector space containing (categorical) key input. 59 | value_space: Vector space containing (numerical) value input. 60 | output_space: Vector space which will contain (numerical) output. 61 | bos_space: 1-d space used to identify the beginning of sequence token. 62 | one_space: 1-d space which contains 1 at every position. 63 | attn_fn: A selector function f(query, key) operating on the query/key basis 64 | directions that defines the attention pattern. 65 | default_output: Output to return if attention pattern is all zero. 66 | causal: If True, use masked attention. 67 | always_attend_to_bos: If True, always attend to the BOS token. If False, 68 | only attend to BOS when attending to nothing else. 69 | use_bos_for_default_output: If True, assume BOS is not in the value space 70 | and output a default value when attending to BOS. If False, assume BOS is 71 | in the value space, and map it to the output space like any other token. 72 | softmax_coldness: The inverse temperature of the softmax. Default value is 73 | high which makes the attention close to a hard maximum. 
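  Illustrative sketch (hypothetical example, added for exposition; the small
  spaces and the equality-based attn_fn below are assumptions, not taken from
  this module):

    qs = bases.VectorSpaceWithBasis.from_values("q", [0, 1])
    ks = bases.VectorSpaceWithBasis.from_values("k", [0, 1])
    vs = bases.VectorSpaceWithBasis.from_values("v", [0, 1])
    out = bases.VectorSpaceWithBasis.from_values("o", [0, 1])
    bos = bases.VectorSpaceWithBasis.from_names(["bos"])
    one = bases.VectorSpaceWithBasis.from_names(["one"])
    head = categorical_attn(
        query_space=qs, key_space=ks, value_space=vs, output_space=out,
        bos_space=bos, one_space=one,
        attn_fn=lambda q, k: q.value == k.value)

  Here each query attends to keys with a matching value; positions whose row
  of the attention pattern is all zero fall back to the BOS mechanism and
  therefore emit the (null) default output.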
74 | """ 75 | bases.ensure_dims(bos_space, num_dims=1, name="bos_space") 76 | bases.ensure_dims(one_space, num_dims=1, name="one_space") 77 | bos_direction = bos_space.basis[0] 78 | one_direction = one_space.basis[0] 79 | 80 | # Add bos direction to query, key, and value spaces in case it is missing 81 | query_space = bases.join_vector_spaces(query_space, bos_space, one_space) 82 | key_space = bases.join_vector_spaces(key_space, bos_space) 83 | value_space = bases.join_vector_spaces(value_space, bos_space) 84 | 85 | if always_attend_to_bos: 86 | value_basis = value_space.basis 87 | else: 88 | value_basis = [v for v in value_space.basis if v != bos_direction] 89 | assert len(value_basis) == output_space.num_dims 90 | value_to_output = dict(zip(value_basis, output_space.basis)) 91 | 92 | if default_output is None: 93 | default_output = output_space.null_vector() 94 | assert default_output in output_space 95 | 96 | def qk_fun(query: bases.BasisDirection, key: bases.BasisDirection) -> float: 97 | 98 | # We want to enforce the following property on our attention patterns: 99 | # - if nothing else is attended to, attend to the BOS token. 100 | # - otherwise, don't attend to the BOS token. 101 | # 102 | # We assume that the BOS position always only contains the vector bos + one, 103 | # and that any other position has bos coefficient 0. 104 | # 105 | # We do this as follows: 106 | # Let Q and K be subspaces of V containing the query and key vectors, 107 | # both disjoint with the BOS space {bos} or the one space {one}. 108 | # Suppose we have an attn_fn which defines a bilinear W_QK: V x V -> ℝ, 109 | # s.t. W_QK(q, k) = 0 whenever either q or k are bos or one. 110 | # 111 | # Then define W_new: V x V -> ℝ st: 112 | # W_new(one, bos) = 0.5, otherwise 0. 113 | # 114 | # Now set W_QK' = W_QK + W_new. 115 | # 116 | # To evaluate the attention to the BOS position: 117 | # W_QK'(q, bos + one) 118 | # = W_QK'(q, bos) + W_QK'(q, one) 119 | # = W_QK(q, bos) + W_QK(q, one) + W_new(q, bos) + W_new(q, one) 120 | # = 0 + 0 + W_new(q, bos) + W_new(q, one) 121 | # = W_new(q, bos) + W_new(q, one) 122 | # = W_new(q' + one, bos) + W_new(q' + one, one) where q = one + q' 123 | # = W_new(q', bos) + W_new(one, bos) + W_new(q', one) + W_new(one, one) 124 | # = 0 + 0.5 + 0 + 0 125 | # = 0.5 126 | # 127 | # To evaluate the attention to a non-BOS position: 128 | # W_QK'(0 * bos + q, 0 * bos + k) # s.t. q ∈ Q+{one}, k ∈ K+{one} 129 | # = 0*W_QK'(bos, 0*bos + k) + W_QK'(q, 0*bos + k) 130 | # = W_QK'(q, 0*bos + k) 131 | # = 0*W_QK'(q, bos) + W_QK'(q, k) 132 | # = W_QK'(q, k) 133 | # = W_QK(q, k) since W_QK' = W_QK on inputs not containing bos. 134 | # = W_QK(q', k') since W_QK(x, y) = 0 whenever x or y are one. 135 | # 136 | # Since W_QK(q, k) takes values in 0, 1, a sufficiently high softmax 137 | # coldness will give us the desired property. QED 138 | # 139 | # The following implements this idea. 140 | # By replacing 0.5 with 1, we can instead enforce a different property: that 141 | # the BOS token is always attended to in addition to whatever else. 142 | 143 | if key == bos_direction and query == one_direction: 144 | c = 1. 
if always_attend_to_bos else 0.5 145 | return c * softmax_coldness 146 | elif {key, query}.intersection({one_direction, bos_direction}): 147 | return 0 148 | 149 | return softmax_coldness * attn_fn(query, key) 150 | 151 | w_qk = vectorspace_fns.ScalarBilinear.from_action( 152 | query_space, 153 | key_space, 154 | qk_fun, 155 | ) 156 | 157 | def ov_fun(input_dir: bases.BasisDirection) -> bases.VectorInBasis: 158 | if use_bos_for_default_output and input_dir == bos_direction: 159 | return default_output 160 | return output_space.vector_from_basis_direction(value_to_output[input_dir]) 161 | 162 | w_ov = vectorspace_fns.Linear.from_action( 163 | value_space, 164 | output_space, 165 | ov_fun, 166 | ) 167 | 168 | return transformers.AttentionHead(w_qk, w_ov, causal=causal) 169 | -------------------------------------------------------------------------------- /tracrx/craft/chamber/categorical_mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """MLP to compute basic linear functions of one-hot encoded integers.""" 16 | 17 | from typing import Callable 18 | 19 | import numpy as np 20 | 21 | from tracrx.craft import bases 22 | from tracrx.craft import transformers 23 | from tracrx.craft import vectorspace_fns 24 | 25 | _ONE_SPACE = bases.VectorSpaceWithBasis.from_names(["one"]) 26 | 27 | 28 | def map_categorical_mlp( 29 | input_space: bases.VectorSpaceWithBasis, 30 | output_space: bases.VectorSpaceWithBasis, 31 | operation: Callable[[bases.BasisDirection], bases.BasisDirection], 32 | ) -> transformers.MLP: 33 | """Returns an MLP that encodes any categorical function of a single variable f(x). 34 | 35 | The hidden layer is the identity and output combines this with a lookup table 36 | output_k = sum(f(i)*input_i for all i in input space) 37 | 38 | Args: 39 | input_space: space containing the input x. 40 | output_space: space containing possible outputs. 41 | operation: A function operating on basis directions. 
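  Illustrative sketch (hypothetical example, added for exposition; the
  "in"/"out" spaces and the negation operation are assumptions):

    in_space = bases.VectorSpaceWithBasis.from_values("in", [-1, 0, 1])
    out_space = bases.VectorSpaceWithBasis.from_values("out", [-1, 0, 1])
    mlp = map_categorical_mlp(
        input_space=in_space,
        output_space=out_space,
        operation=lambda d: bases.BasisDirection("out", -d.value))

  Applied to the one-hot encoding of "in" with value 1, this MLP writes the
  one-hot encoding of "out" with value -1.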
42 | """ 43 | 44 | def operation_fn(direction): 45 | if direction in input_space: 46 | output_direction = operation(direction) 47 | if output_direction in output_space: 48 | return output_space.vector_from_basis_direction(output_direction) 49 | return output_space.null_vector() 50 | 51 | first_layer = vectorspace_fns.Linear.from_action(input_space, output_space, 52 | operation_fn) 53 | 54 | second_layer = vectorspace_fns.project(output_space, output_space) 55 | 56 | return transformers.MLP(first_layer, second_layer) 57 | 58 | 59 | def map_categorical_to_numerical_mlp( 60 | input_space: bases.VectorSpaceWithBasis, 61 | output_space: bases.VectorSpaceWithBasis, 62 | operation: Callable[[bases.Value], float], 63 | ) -> transformers.MLP: 64 | """Returns an MLP to compute f(x) from a categorical to a numerical variable. 65 | 66 | The hidden layer is the identity and output combines this with a lookup table 67 | output = sum(f(i)*input_i for all i in input space) 68 | 69 | Args: 70 | input_space: Vector space containing the input x. 71 | output_space: Vector space to write the numerical output to. 72 | operation: A function operating on basis directions. 73 | """ 74 | bases.ensure_dims(output_space, num_dims=1, name="output_space") 75 | out_vec = output_space.vector_from_basis_direction(output_space.basis[0]) 76 | 77 | def operation_fn(direction): 78 | if direction in input_space: 79 | return operation(direction.value) * out_vec 80 | return output_space.null_vector() 81 | 82 | first_layer = vectorspace_fns.Linear.from_action(input_space, output_space, 83 | operation_fn) 84 | 85 | second_layer = vectorspace_fns.project(output_space, output_space) 86 | 87 | return transformers.MLP(first_layer, second_layer) 88 | 89 | 90 | def sequence_map_categorical_mlp( 91 | input1_space: bases.VectorSpaceWithBasis, 92 | input2_space: bases.VectorSpaceWithBasis, 93 | output_space: bases.VectorSpaceWithBasis, 94 | operation: Callable[[bases.BasisDirection, bases.BasisDirection], 95 | bases.BasisDirection], 96 | one_space: bases.VectorSpaceWithBasis = _ONE_SPACE, 97 | hidden_name: bases.Name = "__hidden__", 98 | ) -> transformers.MLP: 99 | """Returns an MLP that encodes a categorical function of two variables f(x, y). 100 | 101 | The hidden layer of the MLP computes the logical and of all input directions 102 | hidden_i_j = ReLU(x_i+x_j-1) 103 | 104 | And the output combines this with a lookup table 105 | output_k = sum(f(i, j)*hidden_i_j for all i,j in input space) 106 | 107 | Args: 108 | input1_space: Vector space containing the input x. 109 | input2_space: Vector space containing the input y. 110 | output_space: Vector space to write outputs to. 111 | operation: A function operating on basis directions. 112 | one_space: a reserved 1-d space that always contains a 1. 113 | hidden_name: Name for hidden dimensions. 114 | """ 115 | bases.ensure_dims(one_space, num_dims=1, name="one_space") 116 | 117 | if not set(input1_space.basis).isdisjoint(input2_space.basis): 118 | raise ValueError("Input spaces to a SequenceMap must be disjoint. 
" 119 | "If input spaces are the same, use Map instead!") 120 | 121 | input_space = bases.direct_sum(input1_space, input2_space, one_space) 122 | 123 | def to_hidden(x, y): 124 | return bases.BasisDirection(hidden_name, (x.name, x.value, y.name, y.value)) 125 | 126 | def from_hidden(h): 127 | x_name, x_value, y_name, y_value = h.value 128 | x_dir = bases.BasisDirection(x_name, x_value) 129 | y_dir = bases.BasisDirection(y_name, y_value) 130 | return x_dir, y_dir 131 | 132 | hidden_dir = [] 133 | for dir1 in input1_space.basis: 134 | for dir2 in input2_space.basis: 135 | hidden_dir.append(to_hidden(dir1, dir2)) 136 | hidden_space = bases.VectorSpaceWithBasis(hidden_dir) 137 | 138 | def logical_and(direction): 139 | if direction in one_space: 140 | out = bases.VectorInBasis(hidden_space.basis, 141 | -np.ones(hidden_space.num_dims)) 142 | elif direction in input1_space: 143 | dir1 = direction 144 | out = hidden_space.null_vector() 145 | for dir2 in input2_space.basis: 146 | vector = bases.VectorInBasis( 147 | [to_hidden(dir1, dir2)], np.array([1]), _basis_is_sorted=True 148 | ) 149 | out = out.add_directions(vector) 150 | else: 151 | dir2 = direction 152 | out = hidden_space.null_vector() 153 | for dir1 in input1_space.basis: 154 | vector = bases.VectorInBasis( 155 | [to_hidden(dir1, dir2)], np.array([1]), _basis_is_sorted=True 156 | ) 157 | out = out.add_directions(vector) 158 | return out 159 | 160 | first_layer = vectorspace_fns.Linear.from_action(input_space, hidden_space, 161 | logical_and) 162 | 163 | def operation_fn(direction): 164 | dir1, dir2 = from_hidden(direction) 165 | output_direction = operation(dir1, dir2) 166 | if output_direction in output_space: 167 | return output_space.vector_from_basis_direction(output_direction) 168 | else: 169 | return output_space.null_vector() 170 | 171 | second_layer = vectorspace_fns.Linear.from_action(hidden_space, output_space, 172 | operation_fn) 173 | 174 | return transformers.MLP(first_layer, second_layer) 175 | -------------------------------------------------------------------------------- /tracrx/craft/chamber/categorical_mlp_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for chamber.categorical_mlp.""" 16 | 17 | import math 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | 21 | from tracrx.craft import bases 22 | from tracrx.craft import tests_common 23 | from tracrx.craft.chamber import categorical_mlp 24 | 25 | 26 | class CategoricalInputMlpTest(tests_common.VectorFnTestCase): 27 | 28 | @parameterized.parameters([ 29 | dict(num_counts=4, x=1, y=2, fun=lambda x, y: x + y, result=3), 30 | dict(num_counts=4, x=1, y=0, fun=lambda x, y: x + y + 1, result=2), 31 | dict(num_counts=5, x=2, y=1, fun=math.pow, result=2), 32 | dict(num_counts=5, x=2, y=2, fun=math.pow, result=4), 33 | ]) 34 | def test_seq_map_categorical_mlp_produces_expected_outcome( 35 | self, num_counts, x, y, fun, result): 36 | input1_name = "in1" 37 | input2_name = "in2" 38 | output_name = "out" 39 | one_name = "one_dimension" 40 | 41 | in1_space = bases.VectorSpaceWithBasis.from_values(input1_name, 42 | range(num_counts + 1)) 43 | in2_space = bases.VectorSpaceWithBasis.from_values(input2_name, 44 | range(num_counts + 1)) 45 | out_space = bases.VectorSpaceWithBasis.from_values(output_name, 46 | range(num_counts + 1)) 47 | 48 | def operation(in1, in2): 49 | out_val = fun(int(in1.value), int(in2.value)) 50 | return bases.BasisDirection(output_name, out_val) 51 | 52 | mlp = categorical_mlp.sequence_map_categorical_mlp( 53 | input1_space=in1_space, 54 | input2_space=in2_space, 55 | output_space=out_space, 56 | operation=operation, 57 | one_space=bases.VectorSpaceWithBasis.from_names([one_name])) 58 | 59 | test_inputs = ( 60 | mlp.residual_space.vector_from_basis_direction( 61 | bases.BasisDirection(one_name)) + 62 | mlp.residual_space.vector_from_basis_direction( 63 | bases.BasisDirection(input1_name, x)) + 64 | mlp.residual_space.vector_from_basis_direction( 65 | bases.BasisDirection(input2_name, y))) 66 | 67 | expected_results = mlp.residual_space.vector_from_basis_direction( 68 | bases.BasisDirection(output_name, result)) 69 | 70 | test_outputs = mlp.apply(test_inputs) 71 | 72 | self.assertVectorAllClose(test_outputs, expected_results) 73 | 74 | def test_seq_map_categorical_mlp_raises_error_with_overlapping_inputs(self): 75 | input_name = "in" 76 | output_name = "out" 77 | one_name = "one_dimension" 78 | 79 | in1_space = bases.VectorSpaceWithBasis.from_values(input_name, range(5)) 80 | in2_space = bases.VectorSpaceWithBasis.from_values(input_name, range(3, 10)) 81 | out_space = bases.VectorSpaceWithBasis.from_values(output_name, range(5)) 82 | 83 | with self.assertRaisesRegex( 84 | ValueError, r".*Input spaces to a SequenceMap must be disjoint.*"): 85 | categorical_mlp.sequence_map_categorical_mlp( 86 | input1_space=in1_space, 87 | input2_space=in1_space, 88 | output_space=out_space, 89 | operation=lambda x, y: bases.BasisDirection(output_name, 0), 90 | one_space=bases.VectorSpaceWithBasis.from_names([one_name])) 91 | 92 | with self.assertRaisesRegex( 93 | ValueError, r".*Input spaces to a SequenceMap must be disjoint.*"): 94 | categorical_mlp.sequence_map_categorical_mlp( 95 | input1_space=in1_space, 96 | input2_space=in2_space, 97 | output_space=out_space, 98 | operation=lambda x, y: bases.BasisDirection(output_name, 0), 99 | one_space=bases.VectorSpaceWithBasis.from_names([one_name])) 100 | 101 | @parameterized.parameters([ 102 | dict(num_counts=5, x=2, fun=lambda x: x, result=2), 103 | dict(num_counts=5, x=2, fun=lambda x: math.pow(x, int(2)), result=4), 
104 | dict(num_counts=5, x=-2, fun=lambda x: math.pow(x, int(2)), result=4), 105 | dict(num_counts=5, x=-1, fun=lambda x: math.pow(x, int(3)), result=-1), 106 | ]) 107 | def test_map_categorical_mlp_produces_expected_outcome_computing_powers( 108 | self, num_counts, x, fun, result): 109 | input_name = "in" 110 | output_name = "out" 111 | 112 | in_space = bases.VectorSpaceWithBasis.from_values( 113 | input_name, range(-num_counts, num_counts + 1)) 114 | out_space = bases.VectorSpaceWithBasis.from_values( 115 | output_name, range(-num_counts, num_counts + 1)) 116 | 117 | def operation(direction): 118 | out_val = fun(int(direction.value)) 119 | return bases.BasisDirection(output_name, out_val) 120 | 121 | mlp = categorical_mlp.map_categorical_mlp( 122 | input_space=in_space, output_space=out_space, operation=operation) 123 | 124 | test_inputs = mlp.residual_space.vector_from_basis_direction( 125 | bases.BasisDirection(input_name, x)) 126 | 127 | expected_results = mlp.residual_space.vector_from_basis_direction( 128 | bases.BasisDirection(output_name, result)) 129 | 130 | test_outputs = mlp.apply(test_inputs) 131 | 132 | self.assertVectorAllClose(test_outputs, expected_results) 133 | 134 | @parameterized.parameters([ 135 | dict(x=2, fun=lambda x: x, result=2), 136 | dict(x=2, fun=lambda x: math.pow(x, int(2)), result=4), 137 | dict(x=1, fun=lambda x: 1 / (x + 1), result=0.5), 138 | dict(x=3, fun=lambda x: 1 / (x + 1), result=0.25), 139 | ]) 140 | def test_map_categorical_to_numerical_mlp_produces_expected_outcome( 141 | self, x, fun, result): 142 | 143 | in_space = bases.VectorSpaceWithBasis.from_values("in", range(6)) 144 | out_space = bases.VectorSpaceWithBasis.from_names(["out"]) 145 | 146 | mlp = categorical_mlp.map_categorical_to_numerical_mlp( 147 | input_space=in_space, 148 | output_space=out_space, 149 | operation=fun, 150 | ) 151 | 152 | test_inputs = mlp.residual_space.vector_from_basis_direction( 153 | bases.BasisDirection("in", x)) 154 | 155 | expected_results = result * mlp.residual_space.vector_from_basis_direction( 156 | bases.BasisDirection("out")) 157 | 158 | test_outputs = mlp.apply(test_inputs) 159 | 160 | self.assertVectorAllClose(test_outputs, expected_results) 161 | 162 | 163 | if __name__ == "__main__": 164 | absltest.main() 165 | -------------------------------------------------------------------------------- /tracrx/craft/chamber/selector_width.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """SelectorWidth component consisting of an attention head and an MLP.""" 16 | 17 | from typing import Iterable 18 | from tracrx.craft import bases 19 | from tracrx.craft import transformers 20 | from tracrx.craft import vectorspace_fns 21 | from tracrx.craft.chamber import categorical_attn 22 | from tracrx.craft.chamber import numerical_mlp 23 | 24 | 25 | def selector_width( 26 | query_space: bases.VectorSpaceWithBasis, 27 | key_space: bases.VectorSpaceWithBasis, 28 | output_space: bases.VectorSpaceWithBasis, 29 | bos_space: bases.VectorSpaceWithBasis, 30 | one_space: bases.VectorSpaceWithBasis, 31 | attn_fn: categorical_attn.QueryKeyToAttnLogit, 32 | out_value_set: Iterable[float], 33 | categorical_output: bool, 34 | causal: bool = False, 35 | softmax_coldness: float = 100., 36 | mlp_large_number: float = 100., 37 | label: str = "", 38 | ) -> transformers.SeriesWithResiduals: 39 | """Returns a craft block implementing RASP's SelectorWidth primitive. 40 | 41 | The block consists of one attention head and one MLP. 42 | 43 | The attention head implements the attention pattern (attn_fn or key=bos) and 44 | aggregates the bos dimension over this pattern. The output of this will be 45 | 1/(d+1) in every position, where d is the "width" of the attention pattern, 46 | i.e. the number of 1s in a row. 47 | 48 | The MLP then computes d from the previous output in all positions except for 49 | the first BOS position. In the BOS position the MLP removes the output of the 50 | attention head, to ensure it only contains the encoding of the BOS token 51 | which is expected by all other model components. 52 | 53 | Args: 54 | query_space: Vector space containing (categorical) query input. 55 | key_space: Vector space containing (categorical) key input. 56 | output_space: Vector space which will contain (numerical or categorical) 57 | output. 58 | bos_space: 1-d space used to identify the beginning of sequence token. 59 | one_space: Auxiliary 1-d vector space that must contain 1 in the input. 60 | attn_fn: A selector function f(query, key) operating on the query/key basis 61 | directions that defines the attention pattern to compute the width of. 62 | out_value_set: Set of possible output values of this SelectorWidth. 63 | categorical_output: If True, encode the output as a categorical variable. 64 | causal: If True, use masked attention. 65 | softmax_coldness: The inverse temperature of the softmax. Default value is 66 | high which makes the attention close to a hard maximum. 67 | mlp_large_number: A larger number makes the MLP more accurate. 68 | label: A name for this block, used to label auxiliary dimensions. 
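  Worked example (for exposition): if a query row attends to d = 3 keys under
  attn_fn, the attention head (which always also attends to BOS here) averages
  the BOS indicator over d + 1 = 4 positions and writes 1/4 to its output
  dimension; the MLP then recovers the width as round(1/(1/4) - 1) = 3.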
69 | """ 70 | assert output_space.num_dims == 1 or categorical_output 71 | 72 | attn_out_dir = bases.BasisDirection(f"{label}_selector_width_attn_output") 73 | attn_out_space = bases.VectorSpaceWithBasis([attn_out_dir]) 74 | attn_out_vec = attn_out_space.vector_from_basis_direction(attn_out_dir) 75 | 76 | attn = categorical_attn.categorical_attn( 77 | query_space=query_space, 78 | key_space=key_space, 79 | value_space=bos_space, 80 | output_space=attn_out_space, 81 | bos_space=bos_space, 82 | one_space=one_space, 83 | attn_fn=attn_fn, 84 | default_output=attn_out_space.null_vector(), 85 | causal=causal, 86 | always_attend_to_bos=True, 87 | use_bos_for_default_output=False, 88 | softmax_coldness=softmax_coldness) 89 | 90 | fun = lambda x: round((1 / x) - 1) 91 | in_value_set = {1 / (out_v + 1) for out_v in out_value_set} 92 | if categorical_output: 93 | mlp = numerical_mlp.map_numerical_to_categorical_mlp( 94 | f=fun, 95 | input_space=attn_out_space, 96 | output_space=output_space, 97 | input_value_set=in_value_set, 98 | one_space=one_space, 99 | hidden_name=f"_hidden_{label}_", 100 | large_number=mlp_large_number) 101 | else: 102 | mlp = numerical_mlp.map_numerical_mlp( 103 | f=fun, 104 | input_space=attn_out_space, 105 | output_space=output_space, 106 | input_value_set=in_value_set, 107 | one_space=one_space, 108 | hidden_name=f"_hidden_{label}_", 109 | large_number=mlp_large_number) 110 | 111 | # This implementation of selector width writes at each position including 112 | # the BOS. To ensure that the BOS token position does not contain 113 | # additional values, we add an mlp to subtract the output of both layers. 114 | clean_bos_out_space = bases.join_vector_spaces(attn_out_space, output_space) 115 | vec_to_subtract_from_bos = attn_out_vec.project(clean_bos_out_space) 116 | 117 | if categorical_output: 118 | # Add the one-hot encoding of the zero value to the vector 119 | # which will get scrubbed from the BOS position. 120 | zero_dir = [d for d in output_space.basis if d.value == 0][0] 121 | zero_vec = clean_bos_out_space.vector_from_basis_direction(zero_dir) 122 | vec_to_subtract_from_bos += zero_vec 123 | 124 | # Construct an MLP that subtracts vec_to_subtract_from_bos * bos 125 | # from the residual stream which is vec_to_subtract_from_bos in the 126 | # bos position and 0 else. vec_to_subtract_from_bos contains what the 127 | # attention head writes to the bos position. 128 | 129 | hidden_dir = bases.BasisDirection("_hidden_clean_bos_") 130 | hidden_space = bases.VectorSpaceWithBasis([hidden_dir]) 131 | hidden_vec = hidden_space.vector_from_basis_direction(hidden_dir) 132 | 133 | # It's okay to use the local variables because they are only used within 134 | # the same loop iteration to create the MLP. 135 | # pylint: disable=cell-var-from-loop 136 | first_layer = vectorspace_fns.Linear.from_action(bos_space, hidden_space, 137 | lambda x: hidden_vec) 138 | second_layer = vectorspace_fns.Linear.from_action( 139 | hidden_space, clean_bos_out_space, lambda x: -vec_to_subtract_from_bos) 140 | # pylint: enable=cell-var-from-loop 141 | clean_bos_mlp = transformers.MLP(first_layer, second_layer) 142 | 143 | mlp = transformers.MLP.combine_in_parallel([mlp, clean_bos_mlp]) 144 | return transformers.SeriesWithResiduals([attn, mlp]) 145 | -------------------------------------------------------------------------------- /tracrx/craft/tests_common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. 
All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Helper functions for tests.""" 16 | 17 | from absl.testing import parameterized 18 | import numpy as np 19 | from tracrx.craft import bases 20 | 21 | 22 | def strip_bos_token(vector: bases.VectorInBasis) -> bases.VectorInBasis: 23 | """Removes BOS token of a vector.""" 24 | return bases.VectorInBasis(vector.basis_directions, vector.magnitudes[1:]) 25 | 26 | 27 | class VectorFnTestCase(parameterized.TestCase): 28 | """Asserts for vectors.""" 29 | 30 | def assertVectorAllClose(self, v1: bases.VectorInBasis, 31 | v2: bases.VectorInBasis): 32 | self.assertEqual(v1.basis_directions, v2.basis_directions) 33 | np.testing.assert_allclose(v1.magnitudes, v2.magnitudes, atol=1e-7) 34 | -------------------------------------------------------------------------------- /tracrx/craft/transformers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Pieces for making transformers.""" 16 | 17 | import abc 18 | import dataclasses 19 | from typing import Iterable, List, Optional, Sequence, Union 20 | 21 | import numpy as np 22 | 23 | from tracrx.craft import bases 24 | from tracrx.craft import vectorspace_fns 25 | 26 | project = vectorspace_fns.project 27 | 28 | 29 | def _np_softmax(x, axis=-1): 30 | x_max = np.max(x, axis=axis, keepdims=True) 31 | return np.exp(x - x_max) / np.sum(np.exp(x - x_max), axis=axis, keepdims=True) 32 | 33 | 34 | def _np_relu(x): 35 | return np.where(x > 0, x, 0) 36 | 37 | 38 | def relu(x: bases.VectorInBasis) -> bases.VectorInBasis: 39 | return bases.VectorInBasis(x.basis_directions, _np_relu(x.magnitudes)) 40 | 41 | 42 | class Block(abc.ABC): 43 | """Transformer block, acting on a sequence of vector space elements. 44 | 45 | Attributes: 46 | residual_space: Vector space that contains all subspaces the Block interacts 47 | with. This can be either the full residual space of a model or a subspace. 
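  Note (added for exposition): `apply` consumes and returns a VectorInBasis
  expressed in residual_space; the concrete blocks defined below
  (AttentionHead, MultiAttentionHead, MLP, and SeriesWithResiduals) all follow
  this contract.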
48 | """ 49 | residual_space: bases.VectorSpaceWithBasis 50 | 51 | @abc.abstractmethod 52 | def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis: 53 | """Applies self to an input.""" 54 | 55 | 56 | @dataclasses.dataclass 57 | class AttentionHead(Block): 58 | """A transformer attention head.""" 59 | w_qk: vectorspace_fns.ScalarBilinear 60 | w_ov: vectorspace_fns.Linear 61 | residual_space: Optional[bases.VectorSpaceWithBasis] = None 62 | causal: bool = False 63 | 64 | def __post_init__(self): 65 | """Infer residual stream and typecheck subspaces.""" 66 | if self.residual_space is None: 67 | self.residual_space = bases.join_vector_spaces(self.w_qk.left_space, 68 | self.w_qk.right_space, 69 | self.w_ov.input_space, 70 | self.w_ov.output_space) 71 | 72 | assert self.w_qk.left_space.issubspace(self.residual_space) 73 | assert self.w_qk.right_space.issubspace(self.residual_space) 74 | assert self.w_ov.input_space.issubspace(self.residual_space) 75 | assert self.w_ov.output_space.issubspace(self.residual_space) 76 | 77 | def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis: 78 | assert self.residual_space is not None 79 | assert x in self.residual_space 80 | # seq_len x query_space 81 | queries = x.project(self.w_qk.left_space) 82 | # seq_len x key_space 83 | keys = x.project(self.w_qk.right_space) 84 | 85 | attn_matrix = queries.magnitudes @ self.w_qk.matrix @ keys.magnitudes.T 86 | 87 | if self.causal: 88 | # The 1 gives us the matrix above the diagonal. 89 | mask = np.triu(np.full_like(attn_matrix, -np.inf), 1) 90 | attn_matrix = attn_matrix + mask 91 | 92 | attn_weights = _np_softmax(attn_matrix) # seq_len_from, seq_len_to 93 | values = self.w_ov_residual(x).magnitudes # seq_len_to, d_model 94 | 95 | magnitudes = attn_weights @ values # seq_len_from, d_model 96 | return bases.VectorInBasis(sorted(self.residual_space.basis), magnitudes) 97 | 98 | def w_ov_residual(self, x: bases.VectorInBasis) -> bases.VectorInBasis: 99 | """Wov but acting on the residual space.""" 100 | x = project(self.residual_space, self.w_ov.input_space)(x) 101 | out = self.w_ov(x) 102 | return project(self.w_ov.output_space, self.residual_space)(out) 103 | 104 | @property 105 | def num_heads(self) -> int: 106 | return 1 107 | 108 | def as_multi(self) -> "MultiAttentionHead": 109 | return MultiAttentionHead([self]) 110 | 111 | 112 | @dataclasses.dataclass 113 | class MultiAttentionHead(Block): 114 | """Applies attention heads in parallel.""" 115 | sub_blocks: List[Union[AttentionHead, "MultiAttentionHead"]] 116 | 117 | def __post_init__(self): 118 | spaces = [block.residual_space for block in self.sub_blocks] 119 | self.residual_space, *others = spaces 120 | assert all(s == self.residual_space for s in others) 121 | 122 | def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis: 123 | # each element is seq_len x embedding 124 | outs = [block.apply(x) for block in self.sub_blocks] 125 | return bases.VectorInBasis.sum(outs) # seq_len x embedding 126 | 127 | @property 128 | def num_heads(self) -> int: 129 | return sum(sub_block.num_heads for sub_block in self.sub_blocks) 130 | 131 | def heads(self) -> Iterable[AttentionHead]: 132 | for sub_block in self.sub_blocks: 133 | if isinstance(sub_block, AttentionHead): 134 | yield sub_block 135 | elif isinstance(sub_block, MultiAttentionHead): 136 | yield from sub_block.heads() 137 | else: 138 | raise NotImplementedError() 139 | 140 | def as_multi(self) -> "MultiAttentionHead": 141 | return self 142 | 143 | 144 | @dataclasses.dataclass 145 | class 
MLP(Block): 146 | """A transformer MLP block.""" 147 | fst: vectorspace_fns.Linear 148 | snd: vectorspace_fns.Linear 149 | residual_space: Optional[bases.VectorSpaceWithBasis] = None 150 | 151 | def __post_init__(self): 152 | """Typecheck subspaces.""" 153 | if self.residual_space is None: 154 | self.residual_space = bases.join_vector_spaces(self.fst.input_space, 155 | self.snd.output_space) 156 | 157 | assert self.fst.output_space == self.snd.input_space 158 | assert self.fst.input_space.issubspace(self.residual_space) 159 | assert self.snd.output_space.issubspace(self.residual_space) 160 | 161 | def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis: 162 | assert x in self.residual_space 163 | 164 | x = project(self.residual_space, self.fst.input_space)(x) 165 | hidden = self.fst(x) 166 | hidden = relu(hidden) 167 | out = self.snd(hidden) 168 | return project(self.snd.output_space, self.residual_space)(out) 169 | 170 | @classmethod 171 | def combine_in_parallel(cls, mlps: Sequence["MLP"]) -> "MLP": 172 | fst = vectorspace_fns.Linear.combine_in_parallel( 173 | [block.fst for block in mlps]) 174 | snd = vectorspace_fns.Linear.combine_in_parallel( 175 | [block.snd for block in mlps]) 176 | return cls(fst=fst, snd=snd, residual_space=None) 177 | 178 | 179 | # Block that fits into a half-layer, without residual connections. 180 | HalfLayerBlock = Union[MLP, AttentionHead, MultiAttentionHead] 181 | 182 | 183 | @dataclasses.dataclass 184 | class SeriesWithResiduals(Block): 185 | """A series of blocks with residual connections.""" 186 | blocks: List[HalfLayerBlock] 187 | 188 | def __post_init__(self): 189 | spaces = [block.residual_space for block in self.blocks] 190 | self.residual_space = bases.join_vector_spaces(*spaces) 191 | 192 | def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis: 193 | x = x.project(self.residual_space) 194 | for block in self.blocks: 195 | x_in = x.project(block.residual_space) 196 | x_out = block.apply(x_in).project(self.residual_space) 197 | x = x + x_out 198 | return x 199 | -------------------------------------------------------------------------------- /tracrx/craft/transformers_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for transformers.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import numpy as np 20 | from tracrx.craft import bases 21 | from tracrx.craft import tests_common 22 | from tracrx.craft import transformers 23 | from tracrx.craft import vectorspace_fns as vs_fns 24 | 25 | # This makes it easier to use comments to annotate dimensions in arrays 26 | # pylint: disable=g-no-space-after-comment 27 | 28 | 29 | class AttentionHeadTest(tests_common.VectorFnTestCase): 30 | 31 | @parameterized.parameters([ 32 | dict(with_residual_stream=False), 33 | dict(with_residual_stream=True), 34 | ]) 35 | def test_attention_head(self, with_residual_stream): 36 | i = bases.VectorSpaceWithBasis.from_values("i", [1, 2]) 37 | o = bases.VectorSpaceWithBasis.from_values("o", [1, 2]) 38 | q = bases.VectorSpaceWithBasis.from_values("q", [1, 2]) 39 | k = bases.VectorSpaceWithBasis.from_values("p", [1, 2]) 40 | rs = bases.direct_sum(i, o, q, k) 41 | 42 | seq = bases.VectorInBasis( 43 | rs.basis, 44 | np.array([ 45 | #i1 i2 o1 o2 q1 q2 p1 p2 46 | [1, 0, 0, 0, 1, 0, 1, 0], 47 | [0, 1, 0, 0, 0, 1, 0, 1], 48 | ])) 49 | 50 | head = transformers.AttentionHead( 51 | w_qk=vs_fns.ScalarBilinear(q, k, 52 | np.eye(2) * 100), 53 | w_ov=vs_fns.Linear(i, o, np.eye(2)), 54 | residual_space=rs if with_residual_stream else None, 55 | causal=False, 56 | ) 57 | 58 | self.assertVectorAllClose( 59 | head.apply(seq), 60 | bases.VectorInBasis( 61 | rs.basis, 62 | np.array([ 63 | #i1 i2 o1 o2 q1 q2 p1 p2 64 | [0, 0, 1, 0, 0, 0, 0, 0], 65 | [0, 0, 0, 1, 0, 0, 0, 0], 66 | ])), 67 | ) 68 | 69 | 70 | class MLPTest(tests_common.VectorFnTestCase): 71 | 72 | @parameterized.parameters([ 73 | dict(with_residual_stream=False, same_in_out=False), 74 | dict(with_residual_stream=False, same_in_out=True), 75 | dict(with_residual_stream=True, same_in_out=False), 76 | dict(with_residual_stream=True, same_in_out=True), 77 | ]) 78 | def test_mlp(self, with_residual_stream, same_in_out): 79 | i = bases.VectorSpaceWithBasis.from_values("i", [1, 2]) 80 | if same_in_out: 81 | o, rs = i, i 82 | expected_result = np.array([ 83 | #o1 o2 84 | [1, 0], 85 | [0, 1], 86 | ]) 87 | else: 88 | o = bases.VectorSpaceWithBasis.from_values("o", [1, 2]) 89 | rs = bases.direct_sum(i, o) 90 | expected_result = np.array([ 91 | #i1 i2 o1 o2 92 | [0, 0, 1, 0], 93 | [0, 0, 0, 1], 94 | ]) 95 | h = bases.VectorSpaceWithBasis.from_values("p", [1, 2]) 96 | 97 | seq = bases.VectorInBasis( 98 | i.basis, 99 | np.array([ 100 | #i1 i2 101 | [1, -1], 102 | [-1, 1], 103 | ])).project(rs) 104 | 105 | mlp = transformers.MLP( 106 | fst=vs_fns.Linear(i, h, np.eye(2)), 107 | snd=vs_fns.Linear(h, o, np.eye(2)), 108 | residual_space=rs if with_residual_stream else None, 109 | ) 110 | 111 | self.assertEqual( 112 | mlp.apply(seq), 113 | bases.VectorInBasis(rs.basis, expected_result), 114 | ) 115 | 116 | def test_combining_mlps(self): 117 | in12 = bases.VectorSpaceWithBasis.from_values("in", [1, 2]) 118 | in34 = bases.VectorSpaceWithBasis.from_values("in", [3, 4]) 119 | out12 = bases.VectorSpaceWithBasis.from_values("out", [1, 2]) 120 | residual_space = bases.join_vector_spaces(in12, in34, out12) 121 | 122 | h1 = bases.VectorSpaceWithBasis.from_values("h", [1]) 123 | h2 = bases.VectorSpaceWithBasis.from_values("h", [2]) 124 | 125 | # MLP1 maps in2 -> h1 -> out1 126 | mlp1 = transformers.MLP( 127 | fst=vs_fns.Linear(in12, h1, np.array([[0], [1]])), 128 | 
snd=vs_fns.Linear(h1, out12, np.array([[1, 0]]))) 129 | 130 | # MLP2 maps in3 -> h2 -> out2 131 | mlp2 = transformers.MLP( 132 | fst=vs_fns.Linear(in34, h2, np.array([[1], [0]])), 133 | snd=vs_fns.Linear(h2, out12, np.array([[0, 1]]))) 134 | 135 | mlp = transformers.MLP.combine_in_parallel([mlp1, mlp2]) 136 | 137 | seq = bases.VectorInBasis( 138 | bases.direct_sum(in12, in34).basis, 139 | np.array([ 140 | #i1 i2 i3 i4 141 | [1, 2, 0, 0], 142 | [0, 2, 3, 4], 143 | ])).project(residual_space) 144 | 145 | expected_result = bases.VectorInBasis( 146 | out12.basis, 147 | np.array([ 148 | #o1 o2 149 | [2, 0], 150 | [2, 3], 151 | ])) 152 | 153 | self.assertEqual( 154 | mlp.apply(seq).project(out12), 155 | expected_result, 156 | ) 157 | 158 | 159 | if __name__ == "__main__": 160 | absltest.main() 161 | -------------------------------------------------------------------------------- /tracrx/craft/vectorspace_fns.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Functions on vector spaces.""" 16 | 17 | import abc 18 | import dataclasses 19 | from typing import Callable, Sequence 20 | 21 | import numpy as np 22 | from tracrx.craft import bases 23 | 24 | VectorSpaceWithBasis = bases.VectorSpaceWithBasis 25 | VectorInBasis = bases.VectorInBasis 26 | BasisDirection = bases.BasisDirection 27 | 28 | 29 | class VectorFunction(abc.ABC): 30 | """A function that acts on vectors.""" 31 | 32 | input_space: VectorSpaceWithBasis 33 | output_space: VectorSpaceWithBasis 34 | 35 | @abc.abstractmethod 36 | def __call__(self, x: VectorInBasis) -> VectorInBasis: 37 | """Evaluates the function.""" 38 | 39 | 40 | class Linear(VectorFunction): 41 | """A linear function.""" 42 | 43 | def __init__( 44 | self, 45 | input_space: VectorSpaceWithBasis, 46 | output_space: VectorSpaceWithBasis, 47 | matrix: np.ndarray, 48 | ): 49 | """Initialises. 50 | 51 | Args: 52 | input_space: The input vector space. 53 | output_space: The output vector space. 54 | matrix: a [input, output] matrix acting in a (sorted) basis. 
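      Illustrative sketch (hypothetical example, added for exposition; the
      2-d space and the "swap" name are assumptions):

        vs = bases.VectorSpaceWithBasis.from_names(["a", "b"])
        swap = Linear(vs, vs, np.array([[0., 1.], [1., 0.]]))
        # Maps the "a" basis vector to "b" and the "b" basis vector to "a".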
55 | """ 56 | self.input_space = input_space 57 | self.output_space = output_space 58 | self.matrix = matrix 59 | 60 | def __post_init__(self) -> None: 61 | output_size, input_size = self.matrix.shape 62 | assert input_size == self.input_space.num_dims 63 | assert output_size == self.output_space.num_dims 64 | 65 | def __call__(self, x: VectorInBasis) -> VectorInBasis: 66 | if x not in self.input_space: 67 | raise TypeError(f"x={x} not in self.input_space={self.input_space}.") 68 | return self.output_space.make_vector(x.magnitudes @ self.matrix) 69 | 70 | @classmethod 71 | def from_action( 72 | cls, 73 | input_space: VectorSpaceWithBasis, 74 | output_space: VectorSpaceWithBasis, 75 | action: Callable[[BasisDirection], VectorInBasis], 76 | ) -> "Linear": 77 | """from_action(i, o)(action) creates a Linear.""" 78 | 79 | matrix = np.zeros((input_space.num_dims, output_space.num_dims)) 80 | for i, direction in enumerate(input_space.basis): 81 | out_vector = action(direction) 82 | if out_vector not in output_space: 83 | raise TypeError( 84 | f"image of {direction} from input_space={input_space} " 85 | f"is not in output_space={output_space}" 86 | ) 87 | matrix[i, :] = out_vector.magnitudes 88 | 89 | return Linear(input_space, output_space, matrix) 90 | 91 | @classmethod 92 | def combine_in_parallel(cls, fns: Sequence["Linear"]) -> "Linear": 93 | """Combines multiple parallel linear functions into a single one.""" 94 | joint_input_space = bases.join_vector_spaces( 95 | *[fn.input_space for fn in fns] 96 | ) 97 | joint_output_space = bases.join_vector_spaces( 98 | *[fn.output_space for fn in fns] 99 | ) 100 | 101 | # Cache properties for the parents to avoid recomputing for each child. 102 | # Since the index_by_direction cached_property of the children is needed 103 | # within the action, it would be computed for every single child. This is 104 | # redundant as they share the same basis. By accessing the properties here, 105 | # we ensure they are only computed once and passed on to the children. 
106 | _ = joint_input_space.index_by_direction 107 | _ = joint_output_space.index_by_direction 108 | 109 | def action(x: bases.BasisDirection) -> bases.VectorInBasis: 110 | out = joint_output_space.null_vector() 111 | for fn in fns: 112 | if x in fn.input_space: 113 | x_vec = fn.input_space.vector_from_basis_direction(x) 114 | applied = fn(x_vec) 115 | out = out.add_directions(applied) 116 | return out 117 | 118 | return cls.from_action(joint_input_space, joint_output_space, action) 119 | 120 | 121 | def project( 122 | from_space: VectorSpaceWithBasis, 123 | to_space: VectorSpaceWithBasis, 124 | ) -> Linear: 125 | """Creates a projection.""" 126 | 127 | def action(direction: bases.BasisDirection) -> VectorInBasis: 128 | if direction in to_space: 129 | return to_space.vector_from_basis_direction(direction) 130 | else: 131 | return to_space.null_vector() 132 | 133 | return Linear.from_action(from_space, to_space, action=action) 134 | 135 | 136 | @dataclasses.dataclass 137 | class ScalarBilinear: 138 | """A scalar-valued bilinear operator.""" 139 | 140 | left_space: VectorSpaceWithBasis 141 | right_space: VectorSpaceWithBasis 142 | matrix: np.ndarray 143 | 144 | def __post_init__(self): 145 | """Ensure matrix acts in sorted bases and typecheck sizes.""" 146 | left_size, right_size = self.matrix.shape 147 | assert left_size == self.left_space.num_dims 148 | assert right_size == self.right_space.num_dims 149 | 150 | def __call__(self, x: VectorInBasis, y: VectorInBasis) -> float: 151 | """Describes the action of the operator on vectors.""" 152 | if x not in self.left_space: 153 | raise TypeError(f"x={x} not in self.left_space={self.left_space}.") 154 | if y not in self.right_space: 155 | raise TypeError(f"y={y} not in self.right_space={self.right_space}.") 156 | return (x.magnitudes.T @ self.matrix @ y.magnitudes).item() 157 | 158 | @classmethod 159 | def from_action( 160 | cls, 161 | left_space: VectorSpaceWithBasis, 162 | right_space: VectorSpaceWithBasis, 163 | action: Callable[[BasisDirection, BasisDirection], float], 164 | ) -> "ScalarBilinear": 165 | """from_action(l, r)(action) creates a ScalarBilinear.""" 166 | 167 | matrix = np.zeros((left_space.num_dims, right_space.num_dims)) 168 | for i, left_direction in enumerate(left_space.basis): 169 | for j, right_direction in enumerate(right_space.basis): 170 | matrix[i, j] = action(left_direction, right_direction) 171 | 172 | return ScalarBilinear(left_space, right_space, matrix) 173 | -------------------------------------------------------------------------------- /tracrx/craft/vectorspace_fns_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for vectorspace_fns.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import numpy as np 20 | from tracrx.craft import bases 21 | from tracrx.craft import tests_common 22 | from tracrx.craft import vectorspace_fns as vs_fns 23 | 24 | 25 | class LinearTest(tests_common.VectorFnTestCase): 26 | 27 | def test_identity_from_matrix(self): 28 | vs = bases.VectorSpaceWithBasis.from_names(["a", "b", "c"]) 29 | f = vs_fns.Linear(vs, vs, np.eye(3)) 30 | for v in vs.basis_vectors(): 31 | self.assertEqual(f(v), v) 32 | 33 | def test_identity_from_action(self): 34 | vs = bases.VectorSpaceWithBasis.from_names(["a", "b", "c"]) 35 | f = vs_fns.Linear.from_action(vs, vs, vs.vector_from_basis_direction) 36 | for v in vs.basis_vectors(): 37 | self.assertEqual(f(v), v) 38 | 39 | def test_nonidentiy(self): 40 | vs = bases.VectorSpaceWithBasis.from_names(["a", "b"]) 41 | a = vs.vector_from_basis_direction(bases.BasisDirection("a")) 42 | b = vs.vector_from_basis_direction(bases.BasisDirection("b")) 43 | 44 | f = vs_fns.Linear(vs, vs, np.array([[0.3, 0.7], [0.2, 0.1]])) 45 | 46 | self.assertEqual( 47 | f(a), bases.VectorInBasis(vs.basis, np.array([0.3, 0.7]))) 48 | self.assertEqual( 49 | f(b), bases.VectorInBasis(vs.basis, np.array([0.2, 0.1]))) 50 | 51 | def test_different_vector_spaces(self): 52 | vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"]) 53 | vs2 = bases.VectorSpaceWithBasis.from_names(["c", "d"]) 54 | a, b = vs1.basis_vectors() 55 | c, d = vs2.basis_vectors() 56 | 57 | f = vs_fns.Linear(vs1, vs2, np.eye(2)) 58 | 59 | self.assertEqual(f(a), c) 60 | self.assertEqual(f(b), d) 61 | 62 | def test_combining_linear_functions_with_different_input(self): 63 | vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"]) 64 | vs2 = bases.VectorSpaceWithBasis.from_names(["c", "d"]) 65 | vs = bases.direct_sum(vs1, vs2) 66 | a = vs.vector_from_basis_direction(bases.BasisDirection("a")) 67 | b = vs.vector_from_basis_direction(bases.BasisDirection("b")) 68 | c = vs.vector_from_basis_direction(bases.BasisDirection("c")) 69 | d = vs.vector_from_basis_direction(bases.BasisDirection("d")) 70 | 71 | f1 = vs_fns.Linear(vs1, vs1, np.array([[0, 1], [1, 0]])) 72 | f2 = vs_fns.Linear(vs2, vs2, np.array([[1, 0], [0, 0]])) 73 | f3 = vs_fns.Linear.combine_in_parallel([f1, f2]) 74 | 75 | self.assertEqual( 76 | f3(a), bases.VectorInBasis(vs.basis, np.array([0, 1, 0, 0]))) 77 | self.assertEqual( 78 | f3(b), bases.VectorInBasis(vs.basis, np.array([1, 0, 0, 0]))) 79 | self.assertEqual( 80 | f3(c), bases.VectorInBasis(vs.basis, np.array([0, 0, 1, 0]))) 81 | self.assertEqual( 82 | f3(d), bases.VectorInBasis(vs.basis, np.array([0, 0, 0, 0]))) 83 | 84 | def test_combining_linear_functions_with_same_input(self): 85 | vs = bases.VectorSpaceWithBasis.from_names(["a", "b"]) 86 | a = vs.vector_from_basis_direction(bases.BasisDirection("a")) 87 | b = vs.vector_from_basis_direction(bases.BasisDirection("b")) 88 | 89 | f1 = vs_fns.Linear(vs, vs, np.array([[0, 1], [1, 0]])) 90 | f2 = vs_fns.Linear(vs, vs, np.array([[1, 0], [0, 0]])) 91 | f3 = vs_fns.Linear.combine_in_parallel([f1, f2]) 92 | 93 | self.assertEqual( 94 | f3(a), bases.VectorInBasis(vs.basis, np.array([1, 1]))) 95 | self.assertEqual( 96 | f3(b), bases.VectorInBasis(vs.basis, np.array([1, 0]))) 97 | self.assertEqual(f3(a), f1(a) + f2(a)) 98 | self.assertEqual(f3(b), f1(b) + f2(b)) 99 | 100 | 101 | class 
ProjectionTest(tests_common.VectorFnTestCase): 102 | 103 | def test_projection_to_larger_space(self): 104 | vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"]) 105 | vs2 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"]) 106 | a1, b1 = vs1.basis_vectors() 107 | a2, b2, _, _ = vs2.basis_vectors() 108 | 109 | f = vs_fns.project(vs1, vs2) 110 | 111 | self.assertEqual(f(a1), a2) 112 | self.assertEqual(f(b1), b2) 113 | 114 | def test_projection_to_smaller_space(self): 115 | vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"]) 116 | vs2 = bases.VectorSpaceWithBasis.from_names(["a", "b"]) 117 | a1, b1, c1, d1 = vs1.basis_vectors() 118 | a2, b2 = vs2.basis_vectors() 119 | 120 | f = vs_fns.project(vs1, vs2) 121 | 122 | self.assertEqual(f(a1), a2) 123 | self.assertEqual(f(b1), b2) 124 | self.assertEqual(f(c1), vs2.null_vector()) 125 | self.assertEqual(f(d1), vs2.null_vector()) 126 | 127 | 128 | class ScalarBilinearTest(parameterized.TestCase): 129 | 130 | def test_identity_matrix(self): 131 | vs = bases.VectorSpaceWithBasis.from_names(["a", "b"]) 132 | a, b = vs.basis_vectors() 133 | 134 | f = vs_fns.ScalarBilinear(vs, vs, np.eye(2)) 135 | 136 | self.assertEqual(f(a, a), 1) 137 | self.assertEqual(f(a, b), 0) 138 | self.assertEqual(f(b, a), 0) 139 | self.assertEqual(f(b, b), 1) 140 | 141 | def test_identity_from_action(self): 142 | vs = bases.VectorSpaceWithBasis.from_names(["a", "b"]) 143 | a, b = vs.basis_vectors() 144 | 145 | f = vs_fns.ScalarBilinear.from_action(vs, vs, lambda x, y: int(x == y)) 146 | 147 | self.assertEqual(f(a, a), 1) 148 | self.assertEqual(f(a, b), 0) 149 | self.assertEqual(f(b, a), 0) 150 | self.assertEqual(f(b, b), 1) 151 | 152 | def test_non_identity(self): 153 | vs = bases.VectorSpaceWithBasis.from_names(["a", "b"]) 154 | a, b = vs.basis_vectors() 155 | 156 | f = vs_fns.ScalarBilinear.from_action(vs, vs, 157 | lambda x, y: int(x.name == "a")) 158 | 159 | self.assertEqual(f(a, a), 1) 160 | self.assertEqual(f(a, b), 1) 161 | self.assertEqual(f(b, a), 0) 162 | self.assertEqual(f(b, b), 0) 163 | 164 | 165 | if __name__ == "__main__": 166 | absltest.main() 167 | -------------------------------------------------------------------------------- /tracrx/examples/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /tracrx/rasp/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /tracrx/rasp/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/rasp/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /tracrx/rasp/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/rasp/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /tracrx/rasp/__pycache__/rasp.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/rasp/__pycache__/rasp.cpython-310.pyc -------------------------------------------------------------------------------- /tracrx/rasp/__pycache__/rasp.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/rasp/__pycache__/rasp.cpython-311.pyc -------------------------------------------------------------------------------- /tracrx/rasp/causal_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """RASP Evaluator which applies causal masks to selectors.""" 16 | 17 | from typing import Sequence, Union 18 | 19 | import numpy as np 20 | from tracrx.rasp import rasp 21 | 22 | 23 | class CausalEvaluator(rasp.DefaultRASPEvaluator): 24 | """Evaluates RASP with causal masking.""" 25 | 26 | def evaluate( 27 | self, expr: rasp.RASPExpr, xs: Sequence[rasp.Value] 28 | ) -> Union[Sequence[rasp.Value], rasp.SelectorValue]: 29 | out = super().evaluate(expr, xs) 30 | 31 | if not isinstance(expr, rasp.Selector): 32 | return out 33 | 34 | out = np.array(out) 35 | causal_mask = np.tril(np.full(out.shape, 1)) 36 | return np.logical_and(causal_mask, out).tolist() 37 | 38 | 39 | evaluate = CausalEvaluator().evaluate 40 | -------------------------------------------------------------------------------- /tracrx/rasp/causal_eval_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for causal_eval.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | 20 | from tracrx.rasp import causal_eval 21 | from tracrx.rasp import rasp 22 | 23 | 24 | class CausalEvalTest(parameterized.TestCase): 25 | 26 | @parameterized.named_parameters( 27 | dict( 28 | testcase_name="constant_selector_3x3_1", 29 | program=rasp.ConstantSelector([ 30 | [True, True, True], 31 | [True, True, True], 32 | [True, True, True], 33 | ]), 34 | input_sequence=[True, True, True], 35 | expected_output=[ 36 | [True, False, False], 37 | [True, True, False], 38 | [True, True, True], 39 | ]), 40 | dict( 41 | testcase_name="constant_selector_3x3_2", 42 | program=rasp.ConstantSelector([ 43 | [True, True, True], 44 | [False, True, True], 45 | [True, False, True], 46 | ]), 47 | input_sequence=[True, True, True], 48 | expected_output=[ 49 | [True, False, False], 50 | [False, True, False], 51 | [True, False, True], 52 | ])) 53 | def test_evaluations(self, program, input_sequence, expected_output): 54 | self.assertListEqual( 55 | causal_eval.evaluate(program, input_sequence), 56 | expected_output, 57 | ) 58 | 59 | 60 | if __name__ == "__main__": 61 | absltest.main() 62 | -------------------------------------------------------------------------------- /tracrx/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /tracrx/transformer/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/transformer/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /tracrx/transformer/__pycache__/attention.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/transformer/__pycache__/attention.cpython-310.pyc -------------------------------------------------------------------------------- /tracrx/transformer/__pycache__/encoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/transformer/__pycache__/encoder.cpython-310.pyc -------------------------------------------------------------------------------- /tracrx/transformer/__pycache__/model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/transformer/__pycache__/model.cpython-310.pyc -------------------------------------------------------------------------------- /tracrx/transformer/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Instrumented attention layer (forked from the Haiku library implementation). 16 | """ 17 | 18 | from typing import Optional 19 | import warnings 20 | 21 | import chex 22 | import haiku as hk 23 | import jax 24 | import jax.numpy as jnp 25 | import numpy as np 26 | 27 | 28 | @chex.dataclass 29 | class AttentionOutput: 30 | out: jax.Array # [..., T', D'] 31 | logits: jax.Array # [..., H, T', T] 32 | 33 | 34 | class MultiHeadAttention(hk.Module): 35 | """Multi-headed attention (MHA) module. 
36 | 37 | This module is intended for attending over sequences of vectors. 38 | 39 | Rough sketch: 40 | - Compute keys (K), queries (Q), and values (V) as projections of inputs. 41 | - Attention weights are computed as W = softmax(QK^T / sqrt(key_size)). 42 | - Output is another projection of WV^T. 43 | 44 | For more detail, see the original Transformer paper: 45 | "Attention is all you need" https://arxiv.org/abs/1706.03762. 46 | 47 | Glossary of shapes: 48 | - T: Sequence length. 49 | - D: Vector (embedding) size. 50 | - H: Number of attention heads. 51 | """ 52 | 53 | def __init__( 54 | self, 55 | num_heads: int, 56 | key_size: int, 57 | # TODO(b/240019186): Remove `w_init_scale`. 58 | w_init_scale: Optional[float] = None, 59 | *, 60 | w_init: Optional[hk.initializers.Initializer] = None, 61 | value_size: Optional[int] = None, 62 | model_size: Optional[int] = None, 63 | name: Optional[str] = None, 64 | ): 65 | """Initialises the module. 66 | 67 | Args: 68 | num_heads: Number of independent attention heads (H). 69 | key_size: The size of keys (K) and queries used for attention. 70 | w_init_scale: DEPRECATED. Please use w_init instead. 71 | w_init: Initialiser for weights in the linear map. 72 | value_size: Optional size of the value projection (V). If None, defaults 73 | to the key size (K). 74 | model_size: Optional size of the output embedding (D'). If None, defaults 75 | to the key size multiplied by the number of heads (K * H). 76 | name: Optional name for this module. 77 | """ 78 | super().__init__(name=name) 79 | self.num_heads = num_heads 80 | self.key_size = key_size 81 | self.value_size = value_size or key_size 82 | self.model_size = model_size or key_size * num_heads 83 | 84 | # Backwards-compatibility for w_init_scale. 85 | if w_init_scale is not None: 86 | warnings.warn( 87 | "w_init_scale is deprecated; please pass an explicit weight " 88 | "initialiser instead.", DeprecationWarning) 89 | if w_init and w_init_scale: 90 | raise ValueError("Please provide only `w_init`, not `w_init_scale`.") 91 | if w_init is None and w_init_scale is None: 92 | raise ValueError("Please provide a weight initializer: `w_init`.") 93 | if w_init is None: 94 | w_init = hk.initializers.VarianceScaling(w_init_scale) 95 | self.w_init = w_init 96 | 97 | def __call__( 98 | self, 99 | query: jnp.ndarray, 100 | key: jnp.ndarray, 101 | value: jnp.ndarray, 102 | mask: Optional[jnp.ndarray] = None, 103 | ) -> AttentionOutput: 104 | """Computes (optionally masked) MHA with queries, keys & values. 105 | 106 | This module broadcasts over zero or more 'batch-like' leading dimensions. 107 | 108 | Args: 109 | query: Embeddings sequence used to compute queries; shape [..., T', D_q]. 110 | key: Embeddings sequence used to compute keys; shape [..., T, D_k]. 111 | value: Embeddings sequence used to compute values; shape [..., T, D_v]. 112 | mask: Optional mask applied to attention weights; shape [..., H=1, T', T]. 113 | 114 | Returns: 115 | A new sequence of embeddings, consisting of a projection of the 116 | attention-weighted value projections; shape [..., T', D']. 117 | """ 118 | 119 | # In shape hints below, we suppress the leading dims [...] for brevity. 120 | # Hence e.g. [A, B] should be read in every case as [..., A, B]. 121 | *leading_dims, sequence_length, _ = query.shape 122 | projection = self._linear_projection 123 | 124 | # Compute key/query/values (overload K/Q/V to denote the respective sizes). 
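    # Each call to `projection` below applies one hk.Linear of width
    # num_heads * head_size and reshapes to [..., T, H, head_size]
    # (see _linear_projection). The einsum "...thd,...Thd->...htT" then
    # contracts the per-head feature axis, giving one [T', T] logit matrix
    # per head, which is scaled by 1/sqrt(key_size) and masked. Because this
    # layer is instrumented, those logits are also returned in
    # AttentionOutput alongside the usual output projection.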
125 | query_heads = projection(query, self.key_size, "query") # [T', H, Q=K] 126 | key_heads = projection(key, self.key_size, "key") # [T, H, K] 127 | value_heads = projection(value, self.value_size, "value") # [T, H, V] 128 | 129 | # Compute attention weights. 130 | attn_logits = jnp.einsum("...thd,...Thd->...htT", query_heads, key_heads) 131 | attn_logits = attn_logits / np.sqrt(self.key_size).astype(key.dtype) 132 | if mask is not None: 133 | if mask.ndim != attn_logits.ndim: 134 | raise ValueError( 135 | f"Mask dimensionality {mask.ndim} must match logits dimensionality " 136 | f"{attn_logits.ndim}.") 137 | attn_logits = jnp.where(mask, attn_logits, -1e30) 138 | attn_weights = jax.nn.softmax(attn_logits) # [H, T', T] 139 | 140 | # Weight the values by the attention and flatten the head vectors. 141 | attn = jnp.einsum("...htT,...Thd->...thd", attn_weights, value_heads) 142 | attn = jnp.reshape(attn, (*leading_dims, sequence_length, -1)) # [T', H*V] 143 | 144 | # Apply another projection to get the final embeddings. 145 | final_projection = hk.Linear(self.model_size, w_init=self.w_init) 146 | out = final_projection(attn) 147 | 148 | return AttentionOutput( 149 | out=out, 150 | logits=attn_logits, 151 | ) 152 | 153 | @hk.transparent 154 | def _linear_projection( 155 | self, 156 | x: jnp.ndarray, 157 | head_size: int, 158 | name: Optional[str] = None, 159 | ) -> jnp.ndarray: 160 | y = hk.Linear(self.num_heads * head_size, w_init=self.w_init, name=name)(x) 161 | *leading_dims, _ = x.shape 162 | return y.reshape((*leading_dims, self.num_heads, head_size)) 163 | -------------------------------------------------------------------------------- /tracrx/transformer/compressed_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Modified transformer to learn a linear compression of the residual stream. 16 | 17 | CompressedTransformer adds three arguments compared to Transformer: 18 | - embedding_size: the size of the compressed residual stream. 
19 | - unembed_at_every_layer: whether to apply the unembedding before applying 20 | attention and MLP layers 21 | - return_activations: whether to return all model activations rather than just 22 | the outputs 23 | """ 24 | 25 | import collections 26 | import dataclasses 27 | from typing import Optional 28 | 29 | import haiku as hk 30 | import jax 31 | import numpy as np 32 | 33 | from tracrx.transformer import attention 34 | from tracrx.transformer import model 35 | 36 | 37 | @dataclasses.dataclass 38 | class CompressedTransformer(hk.Module): 39 | """A transformer stack with linearly compressed residual stream.""" 40 | 41 | config: model.TransformerConfig 42 | name: Optional[str] = None 43 | 44 | def __call__( 45 | self, 46 | embeddings: jax.Array, # [B, T, D] 47 | mask: jax.Array, # [B, T] 48 | *, 49 | use_dropout: bool = True, 50 | embedding_size: Optional[int] = None, 51 | unembed_at_every_layer: bool = False, 52 | ) -> model.TransformerOutput: # [B, T, D] 53 | """Transforms input embedding sequences to output embedding sequences. 54 | 55 | Args: 56 | embeddings: Input embeddings to pass through the model. 57 | mask: Boolean mask to restrict the inputs the model uses. 58 | use_dropout: Turns dropout on/off. 59 | embedding_size: Dimension to compress the residual stream to. 60 | unembed_at_every_layer: Whether to unembed the residual stream when 61 | reading the input for every layer (keeping the layer input sizes) or to 62 | only unembed before the model output (compressing the layer inputs). 63 | 64 | Returns: 65 | The outputs of the forward pass through the transformer. 66 | """ 67 | 68 | def layer_norm(x: jax.Array) -> jax.Array: 69 | """Applies a unique LayerNorm to x with default settings.""" 70 | if self.config.layer_norm: 71 | return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x) 72 | return x 73 | 74 | initializer = hk.initializers.VarianceScaling(2 / self.config.num_layers) 75 | dropout_rate = self.config.dropout_rate if use_dropout else 0. 76 | _, seq_len, model_size = embeddings.shape 77 | 78 | # To compress the model, we multiply with a matrix W when reading from 79 | # the residual stream, and with W^T when writing to the residual stream. 80 | if embedding_size is not None: 81 | # [to_size, from_size] 82 | w_emb = hk.get_parameter( 83 | "w_emb", (embedding_size, model_size), 84 | init=hk.initializers.RandomNormal()) 85 | 86 | write_to_residual = lambda x: x @ w_emb.T 87 | read_from_residual = lambda x: x @ w_emb 88 | 89 | if not unembed_at_every_layer: 90 | model_size = embedding_size 91 | else: 92 | write_to_residual = lambda x: x 93 | read_from_residual = lambda x: x 94 | 95 | # Compute causal mask for autoregressive sequence modelling. 96 | mask = mask[:, None, None, :] # [B, H=1, T'=1, T] 97 | mask = mask.repeat(seq_len, axis=2) # [B, H=1, T, T] 98 | 99 | if self.config.causal: 100 | causal_mask = np.ones((1, 1, seq_len, seq_len)) # [B=1, H=1, T, T] 101 | causal_mask = np.tril(causal_mask) 102 | mask = mask * causal_mask # [B, H=1, T, T] 103 | 104 | # Set up activation collection. 105 | collected = collections.defaultdict(list) 106 | 107 | def collect(**kwargs): 108 | for k, v in kwargs.items(): 109 | collected[k].append(v) 110 | 111 | residual = write_to_residual(embeddings) 112 | 113 | for layer in range(self.config.num_layers): 114 | with hk.experimental.name_scope(f"layer_{layer}"): 115 | # First the attention block. 
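    # With embedding_size set and unembed_at_every_layer=False, model_size was
    # reassigned above, so the attention and MLP blocks below operate directly
    # in the compressed residual space. With unembed_at_every_layer=True the
    # blocks keep the original width: their inputs are read from the residual
    # through w_emb and their outputs are written back through w_emb.T before
    # the residual addition.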
116 | attn_block = attention.MultiHeadAttention( 117 | num_heads=self.config.num_heads, 118 | key_size=self.config.key_size, 119 | model_size=model_size, 120 | w_init=initializer, 121 | name="attn") 122 | 123 | attn_in = residual 124 | if unembed_at_every_layer: 125 | attn_in = read_from_residual(attn_in) 126 | attn_in = layer_norm(attn_in) 127 | attn_out = attn_block(attn_in, attn_in, attn_in, mask=mask) 128 | attn_out, attn_logits = attn_out.out, attn_out.logits 129 | if dropout_rate > 0: 130 | attn_out = hk.dropout(hk.next_rng_key(), dropout_rate, attn_out) 131 | 132 | if unembed_at_every_layer: 133 | collect(layer_outputs=attn_out, attn_logits=attn_logits) 134 | else: 135 | collect( 136 | layer_outputs=read_from_residual(attn_out), 137 | attn_logits=attn_logits, 138 | ) 139 | 140 | if unembed_at_every_layer: 141 | attn_out = write_to_residual(attn_out) 142 | residual = residual + attn_out 143 | 144 | collect(residuals=residual) 145 | 146 | # Then the dense block. 147 | with hk.experimental.name_scope("mlp"): 148 | dense_block = hk.Sequential([ 149 | hk.Linear( 150 | self.config.mlp_hidden_size, 151 | w_init=initializer, 152 | name="linear_1"), 153 | self.config.activation_function, 154 | hk.Linear(model_size, w_init=initializer, name="linear_2"), 155 | ]) 156 | 157 | dense_in = residual 158 | if unembed_at_every_layer: 159 | dense_in = read_from_residual(dense_in) 160 | dense_in = layer_norm(dense_in) 161 | dense_out = dense_block(dense_in) 162 | if dropout_rate > 0: 163 | dense_out = hk.dropout(hk.next_rng_key(), dropout_rate, dense_out) 164 | 165 | if unembed_at_every_layer: 166 | collect(layer_outputs=dense_out) 167 | else: 168 | collect(layer_outputs=read_from_residual(dense_out)) 169 | 170 | if unembed_at_every_layer: 171 | dense_out = write_to_residual(dense_out) 172 | residual = residual + dense_out 173 | 174 | collect(residuals=residual) 175 | 176 | output = read_from_residual(residual) 177 | output = layer_norm(output) 178 | 179 | return model.TransformerOutput( 180 | layer_outputs=collected["layer_outputs"], 181 | residuals=collected["residuals"], 182 | attn_logits=collected["attn_logits"], 183 | output=output, 184 | input_embeddings=embeddings, 185 | ) 186 | -------------------------------------------------------------------------------- /tracrx/transformer/encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Basic encoder for inputs with a fixed vocabulary.""" 16 | 17 | import abc 18 | from typing import Any, List, Optional, Sequence 19 | 20 | from tracrx.craft import bases 21 | 22 | 23 | class Encoder(abc.ABC): 24 | """Encodes a list of tokens into a list of inputs for a transformer model. 
25 | 26 | The abstract class does not make assumptions on the input and output types, 27 | and we have different encoders for different input types. 28 | """ 29 | 30 | @abc.abstractmethod 31 | def encode(self, inputs: List[Any]) -> List[Any]: 32 | return list() 33 | 34 | @abc.abstractmethod 35 | def decode(self, encodings: List[Any]) -> List[Any]: 36 | return list() 37 | 38 | @property 39 | def pad_token(self) -> Optional[str]: 40 | return None 41 | 42 | @property 43 | def bos_token(self) -> Optional[str]: 44 | return None 45 | 46 | @property 47 | def pad_encoding(self) -> Optional[int]: 48 | return None 49 | 50 | @property 51 | def bos_encoding(self) -> Optional[int]: 52 | return None 53 | 54 | 55 | class NumericalEncoder(Encoder): 56 | """Encodes numerical variables (simply using the identity mapping).""" 57 | 58 | def encode(self, inputs: List[float]) -> List[float]: 59 | return inputs 60 | 61 | def decode(self, encodings: List[float]) -> List[float]: 62 | return encodings 63 | 64 | 65 | class CategoricalEncoder(Encoder): 66 | """Encodes categorical variables with a fixed vocabulary.""" 67 | 68 | def __init__( 69 | self, 70 | basis: Sequence[bases.BasisDirection], 71 | enforce_bos: bool = False, 72 | bos_token: Optional[str] = None, 73 | pad_token: Optional[str] = None, 74 | max_seq_len: Optional[int] = None, 75 | ): 76 | """Initialises. If enforce_bos is set, ensures inputs start with it.""" 77 | if enforce_bos and not bos_token: 78 | raise ValueError("BOS token must be specified if enforcing BOS.") 79 | 80 | self.encoding_map = {} 81 | for i, direction in enumerate(basis): 82 | val = direction.value 83 | self.encoding_map[val] = i 84 | 85 | if bos_token and bos_token not in self.encoding_map: 86 | raise ValueError("BOS token missing in encoding.") 87 | 88 | if pad_token and pad_token not in self.encoding_map: 89 | raise ValueError("PAD token missing in encoding.") 90 | 91 | self.enforce_bos = enforce_bos 92 | self._bos_token = bos_token 93 | self._pad_token = pad_token 94 | self._max_seq_len = max_seq_len 95 | 96 | def encode(self, inputs: List[bases.Value]) -> List[int]: 97 | if self.enforce_bos and inputs[0] != self.bos_token: 98 | raise ValueError("First input token must be BOS token. " 99 | f"Should be '{self.bos_token}', but was '{inputs[0]}'.") 100 | if missing := set(inputs) - set(self.encoding_map.keys()): 101 | raise ValueError(f"Inputs {missing} not found in encoding ", 102 | self.encoding_map.keys()) 103 | if self._max_seq_len is not None and len(inputs) > self._max_seq_len: 104 | raise ValueError(f"inputs={inputs} are longer than the maximum " 105 | f"sequence length {self._max_seq_len}") 106 | 107 | return [self.encoding_map[x] for x in inputs] 108 | 109 | def decode(self, encodings: List[int]) -> List[bases.Value]: 110 | """Recover the tokens that corresponds to `ids`. 
Inverse of __call__.""" 111 | decoding_map = {val: key for key, val in self.encoding_map.items()} 112 | if missing := set(encodings) - set(decoding_map.keys()): 113 | raise ValueError(f"Inputs {missing} not found in decoding map ", 114 | decoding_map.keys()) 115 | return [decoding_map[x] for x in encodings] 116 | 117 | @property 118 | def vocab_size(self) -> int: 119 | return len(self.encoding_map) 120 | 121 | @property 122 | def bos_token(self) -> Optional[str]: 123 | return self._bos_token 124 | 125 | @property 126 | def pad_token(self) -> Optional[str]: 127 | return self._pad_token 128 | 129 | @property 130 | def bos_encoding(self) -> Optional[int]: 131 | return None if self.bos_token is None else self.encoding_map[self.bos_token] 132 | 133 | @property 134 | def pad_encoding(self) -> Optional[int]: 135 | return None if self.pad_token is None else self.encoding_map[self.pad_token] 136 | -------------------------------------------------------------------------------- /tracrx/transformer/encoder_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for transformer.encoder.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from tracrx.craft import bases 20 | from tracrx.transformer import encoder 21 | 22 | _BOS_TOKEN = "bos_encoder_test" 23 | _PAD_TOKEN = "pad_encoder_test" 24 | 25 | 26 | class CategoricalEncoderTest(parameterized.TestCase): 27 | 28 | def test_encode_raises_value_error_if_input_doesnt_start_with_bos(self): 29 | vs = bases.VectorSpaceWithBasis.from_values("input", {1, 2, 3, _BOS_TOKEN}) 30 | basic_encoder = encoder.CategoricalEncoder( 31 | vs.basis, enforce_bos=True, bos_token=_BOS_TOKEN) 32 | with self.assertRaisesRegex(ValueError, 33 | r"^.*First input token must be BOS token.*$"): 34 | basic_encoder.encode([1, 1, 1]) 35 | 36 | def test_encode_raises_value_error_if_input_not_in_vocab(self): 37 | vs = bases.VectorSpaceWithBasis.from_values("input", {1, 2, 3, _BOS_TOKEN}) 38 | basic_encoder = encoder.CategoricalEncoder( 39 | vs.basis, enforce_bos=True, bos_token=_BOS_TOKEN) 40 | with self.assertRaisesRegex(ValueError, 41 | r"^.*Inputs .* not found in encoding.*$"): 42 | basic_encoder.encode([_BOS_TOKEN, 1, 2, 3, 4]) 43 | 44 | def test_decode_raises_value_error_if_id_outside_of_vocab_size(self): 45 | vs = bases.VectorSpaceWithBasis.from_values("input", {1, 2, _BOS_TOKEN}) 46 | basic_encoder = encoder.CategoricalEncoder( 47 | vs.basis, enforce_bos=True, bos_token=_BOS_TOKEN) 48 | with self.assertRaisesRegex(ValueError, 49 | r"^.*Inputs .* not found in decoding map.*$"): 50 | basic_encoder.decode([0, 1, 2, 3]) 51 | 52 | def test_encoder_raises_value_error_if_bos_not_in_basis(self): 53 | vs = bases.VectorSpaceWithBasis.from_values("input", {1, 2, 
3}) 54 | with self.assertRaisesRegex(ValueError, 55 | r"^.*BOS token missing in encoding.*$"): 56 | unused_basic_encoder = encoder.CategoricalEncoder( 57 | vs.basis, bos_token=_BOS_TOKEN) 58 | 59 | def test_encoder_raises_value_error_if_pad_not_in_basis(self): 60 | vs = bases.VectorSpaceWithBasis.from_values("input", {1, 2, 3}) 61 | with self.assertRaisesRegex(ValueError, 62 | r"^.*PAD token missing in encoding.*$"): 63 | unused_basic_encoder = encoder.CategoricalEncoder( 64 | vs.basis, pad_token=_PAD_TOKEN) 65 | 66 | def test_encoder_encodes_bos_and_pad_tokens_as_expected(self): 67 | vs = bases.VectorSpaceWithBasis.from_values( 68 | "input", {1, 2, 3, _BOS_TOKEN, _PAD_TOKEN}) 69 | basic_encoder = encoder.CategoricalEncoder( 70 | vs.basis, bos_token=_BOS_TOKEN, pad_token=_PAD_TOKEN) 71 | self.assertEqual( 72 | basic_encoder.encode([_BOS_TOKEN, _PAD_TOKEN]), 73 | [basic_encoder.bos_encoding, basic_encoder.pad_encoding]) 74 | 75 | @parameterized.parameters([ 76 | dict( 77 | vocab={1, 2, 3, _BOS_TOKEN}, # lexicographic order 78 | inputs=[_BOS_TOKEN, 3, 2, 1], 79 | expected=[3, 2, 1, 0]), 80 | dict( 81 | vocab={"a", "b", _BOS_TOKEN, "c"}, # lexicographic order 82 | inputs=[_BOS_TOKEN, "b", "b", "c"], 83 | expected=[2, 1, 1, 3]), 84 | ]) 85 | def test_tokens_are_encoded_in_lexicographic_order(self, vocab, inputs, 86 | expected): 87 | # Expect encodings to be assigned to ids according to a lexicographic 88 | # ordering of the vocab 89 | vs = bases.VectorSpaceWithBasis.from_values("input", vocab) 90 | basic_encoder = encoder.CategoricalEncoder( 91 | vs.basis, enforce_bos=True, bos_token=_BOS_TOKEN) 92 | encodings = basic_encoder.encode(inputs) 93 | self.assertEqual(encodings, expected) 94 | 95 | @parameterized.parameters([ 96 | dict(vocab={_BOS_TOKEN, _PAD_TOKEN, 1, 2, 3}, expected=5), 97 | dict(vocab={_BOS_TOKEN, _PAD_TOKEN, "a", "b"}, expected=4), 98 | ]) 99 | def test_vocab_size_has_expected_value(self, vocab, expected): 100 | vs = bases.VectorSpaceWithBasis.from_values("input", vocab) 101 | basic_encoder = encoder.CategoricalEncoder( 102 | vs.basis, enforce_bos=True, bos_token=_BOS_TOKEN, pad_token=_PAD_TOKEN) 103 | self.assertEqual(basic_encoder.vocab_size, expected) 104 | 105 | @parameterized.parameters([ 106 | dict( 107 | vocab={_BOS_TOKEN, _PAD_TOKEN, 1, 2, 3}, inputs=[_BOS_TOKEN, 3, 2, 108 | 1]), 109 | dict( 110 | vocab={_BOS_TOKEN, _PAD_TOKEN, "a", "b", "c"}, 111 | inputs=[_BOS_TOKEN, "b", "b", "c"]), 112 | ]) 113 | def test_decode_inverts_encode(self, vocab, inputs): 114 | vs = bases.VectorSpaceWithBasis.from_values("input", vocab) 115 | basic_encoder = encoder.CategoricalEncoder( 116 | vs.basis, enforce_bos=True, bos_token=_BOS_TOKEN, pad_token=_PAD_TOKEN) 117 | encodings = basic_encoder.encode(inputs) 118 | recovered = basic_encoder.decode(encodings) 119 | self.assertEqual(recovered, inputs) 120 | 121 | 122 | if __name__ == "__main__": 123 | absltest.main() 124 | -------------------------------------------------------------------------------- /tracrx/transformer/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Didactic example of an autoregressive Transformer-based language model. 16 | 17 | Glossary of shapes: 18 | - B: Batch size. 19 | - T: Sequence length. 20 | - D: Model embedding size. 21 | - H: Number of attention heads. 22 | - V: Vocabulary size. 23 | 24 | Forked from: haiku.examples.transformer.model 25 | """ 26 | 27 | import collections 28 | import dataclasses 29 | from typing import Callable, List, Optional 30 | 31 | import chex 32 | import haiku as hk 33 | import jax 34 | import jax.numpy as jnp 35 | import numpy as np 36 | from tracrx.transformer import attention 37 | 38 | # hk.Modules are not always callable: github.com/deepmind/dm-haiku/issues/52 39 | # Ideally, we'd want a type: 40 | # CallableHaikuModule = Intersection[Callable[..., jax.Array], hk.Module] 41 | # But Intersection does not exist (yet): github.com/python/typing/issues/213 42 | CallableHaikuModule = Callable[..., jax.Array] 43 | 44 | 45 | @chex.dataclass 46 | class TransformerOutput: 47 | layer_outputs: List[jax.Array] # [B, T, D] 48 | residuals: List[jax.Array] # [B, T, D] 49 | attn_logits: List[jax.Array] # [B, H, T, T] 50 | output: jax.Array # [B, T, D] 51 | input_embeddings: jax.Array # [B, T, D] 52 | 53 | 54 | @dataclasses.dataclass 55 | class TransformerConfig: 56 | num_heads: int 57 | num_layers: int 58 | key_size: int 59 | mlp_hidden_size: int 60 | dropout_rate: float 61 | activation_function: Callable[[jax.Array], jax.Array] = jax.nn.gelu 62 | layer_norm: bool = True 63 | causal: bool = False 64 | 65 | 66 | @dataclasses.dataclass 67 | class Transformer(hk.Module): 68 | """A transformer stack.""" 69 | 70 | config: TransformerConfig 71 | name: Optional[str] = None 72 | 73 | def __call__( 74 | self, 75 | embeddings: jax.Array, # [B, T, D] 76 | mask: jax.Array, # [B, T] 77 | *, 78 | use_dropout: bool = True, 79 | ) -> TransformerOutput: 80 | """Transforms input embedding sequences to output embedding sequences.""" 81 | 82 | def layer_norm(x: jax.Array) -> jax.Array: 83 | """Applies a unique LayerNorm to x with default settings.""" 84 | if self.config.layer_norm: 85 | return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x) 86 | return x 87 | 88 | initializer = hk.initializers.VarianceScaling(2 / self.config.num_layers) 89 | dropout_rate = self.config.dropout_rate if use_dropout else 0. 90 | _, seq_len, model_size = embeddings.shape 91 | 92 | # Compute causal mask for autoregressive sequence modelling. 93 | mask = mask[:, None, None, :] # [B, H=1, T'=1, T] 94 | mask = mask.repeat(seq_len, axis=2) # [B, H=1, T, T] 95 | 96 | if self.config.causal: 97 | causal_mask = np.ones((1, 1, seq_len, seq_len)) # [B=1, H=1, T, T] 98 | causal_mask = np.tril(causal_mask) 99 | mask = mask * causal_mask # [B, H=1, T, T] 100 | 101 | # Set up activation collection. 
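    # `collected` maps names ("residuals", "layer_outputs", "attn_logits") to
    # per-layer lists; each layer appends its tensors via collect(...), and the
    # lists are packaged into the TransformerOutput returned at the end.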
102 | collected = collections.defaultdict(list) 103 | 104 | def collect(**kwargs): 105 | for k, v in kwargs.items(): 106 | collected[k].append(v) 107 | 108 | residual = embeddings 109 | for layer in range(self.config.num_layers): 110 | with hk.experimental.name_scope(f"layer_{layer}"): 111 | # First the attention block. 112 | attn_block = attention.MultiHeadAttention( 113 | num_heads=self.config.num_heads, 114 | key_size=self.config.key_size, 115 | model_size=model_size, 116 | w_init=initializer, 117 | name="attn") 118 | attn_in = layer_norm(residual) 119 | attn_out = attn_block(attn_in, attn_in, attn_in, mask=mask) 120 | attn_out, attn_logits = attn_out.out, attn_out.logits 121 | if dropout_rate > 0: 122 | attn_out = hk.dropout(hk.next_rng_key(), dropout_rate, attn_out) 123 | residual = residual + attn_out 124 | 125 | collect( 126 | residuals=residual, layer_outputs=attn_out, attn_logits=attn_logits) 127 | 128 | # Then the dense block. 129 | with hk.experimental.name_scope("mlp"): 130 | dense_block = hk.Sequential([ 131 | hk.Linear( 132 | self.config.mlp_hidden_size, 133 | w_init=initializer, 134 | name="linear_1"), 135 | self.config.activation_function, 136 | hk.Linear(model_size, w_init=initializer, name="linear_2"), 137 | ]) 138 | dense_in = layer_norm(residual) 139 | dense_out = dense_block(dense_in) 140 | if dropout_rate > 0: 141 | dense_out = hk.dropout(hk.next_rng_key(), dropout_rate, dense_out) 142 | residual = residual + dense_out 143 | 144 | collect(residuals=residual, layer_outputs=dense_out) 145 | 146 | return TransformerOutput( 147 | residuals=collected["residuals"], 148 | layer_outputs=collected["layer_outputs"], 149 | attn_logits=collected["attn_logits"], 150 | output=layer_norm(residual), 151 | input_embeddings=embeddings, 152 | ) 153 | 154 | 155 | @chex.dataclass 156 | class CompiledTransformerModelOutput: 157 | transformer_output: TransformerOutput 158 | unembedded_output: jax.Array # [B, T] 159 | unembedding_mtx: jax.Array # [D, V] 160 | 161 | 162 | @dataclasses.dataclass 163 | class CompiledTransformerModel(hk.Module): 164 | """A transformer model with one-hot embeddings.""" 165 | transformer: Transformer 166 | token_embed: CallableHaikuModule 167 | position_embed: CallableHaikuModule 168 | unembed: CallableHaikuModule 169 | unembed_mtx: jax.Array 170 | use_unembed_argmax: bool 171 | pad_token: Optional[int] = None 172 | 173 | def embed(self, tokens: jax.Array) -> jax.Array: 174 | token_embeddings = self.token_embed(tokens) 175 | positional_embeddings = self.position_embed(jnp.indices(tokens.shape)[-1]) 176 | return token_embeddings + positional_embeddings # [B, T, D] 177 | 178 | def __call__( 179 | self, 180 | tokens: jax.Array, 181 | use_dropout: bool = True, 182 | ) -> CompiledTransformerModelOutput: 183 | """Embed tokens, pass through model, and unembed output.""" 184 | if self.pad_token is None: 185 | input_mask = jnp.ones_like(tokens) 186 | else: 187 | input_mask = (tokens != self.pad_token) 188 | input_embeddings = self.embed(tokens) 189 | 190 | transformer_output = self.transformer( 191 | input_embeddings, 192 | input_mask, 193 | use_dropout=use_dropout, 194 | ) 195 | return CompiledTransformerModelOutput( 196 | transformer_output=transformer_output, 197 | unembedded_output=self.unembed( 198 | transformer_output.output, 199 | use_unembed_argmax=self.use_unembed_argmax, 200 | ), 201 | unembedding_mtx=self.unembed_mtx, 202 | ) 203 | -------------------------------------------------------------------------------- /tracrx/utils/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /tracrx/utils/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/utils/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /tracrx/utils/__pycache__/errors.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/Edge-Pruning/54e79c9fc1f1e9b0cb8eac94163f186ac2eb1ec3/tracrx/utils/__pycache__/errors.cpython-310.pyc -------------------------------------------------------------------------------- /tracrx/utils/debugging.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Useful helpers for model debugging.""" 16 | 17 | 18 | def print_arrays(arrays, labels=None, colwidth=12): 19 | """Pretty-prints a list of [1, T, D] arrays.""" 20 | if labels is not None: 21 | print(" |".join(labels)) 22 | widths = [len(l) for l in labels] 23 | else: 24 | widths = [colwidth] * len(arrays[0].shape[-1]) 25 | for layer in arrays: 26 | print("=" * (colwidth + 1) * layer.shape[1]) 27 | for row in layer[0]: 28 | print(" |".join([f"{x:<{width}.2f}" for x, width in zip(row, widths)])) 29 | -------------------------------------------------------------------------------- /tracrx/utils/errors.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Helpers for handling errors in user-provided functions.""" 16 | 17 | import functools 18 | import logging 19 | from typing import Any, Callable 20 | 21 | 22 | def ignoring_arithmetic_errors(fun: Callable[..., Any]) -> Callable[..., Any]: 23 | """Makes fun return None instead of raising ArithmeticError.""" 24 | 25 | @functools.wraps(fun) 26 | def fun_wrapped(*args): 27 | try: 28 | return fun(*args) 29 | except ArithmeticError: 30 | logging.warning( 31 | "Encountered arithmetic error in function: for value %s. " 32 | "Assuming this input will never occur.", str(args)) 33 | return None 34 | 35 | return fun_wrapped 36 | -------------------------------------------------------------------------------- /tracrx/utils/errors_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for rasp.helper.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from tracrx.utils import errors 20 | 21 | 22 | class FunIgnoreArithmeticErrorsTest(parameterized.TestCase): 23 | 24 | def test_ignoring_arithmetic_errors(self): 25 | fun = lambda x: 1 / x 26 | fun_ignore = errors.ignoring_arithmetic_errors(fun) 27 | 28 | with self.assertLogs(level="WARNING"): 29 | res = fun_ignore(0) 30 | self.assertIs(res, None) 31 | 32 | self.assertEqual(fun_ignore(1), 1) 33 | self.assertEqual(fun_ignore(2), 0.5) 34 | self.assertEqual(fun_ignore(-2), -0.5) 35 | 36 | def test_ignoring_arithmetic_errors_two_arguments(self): 37 | fun = lambda x, y: 1 / x + 1 / y 38 | fun_ignore = errors.ignoring_arithmetic_errors(fun) 39 | 40 | with self.assertLogs(level="WARNING"): 41 | res = fun_ignore(0, 1) 42 | self.assertIs(res, None) 43 | 44 | with self.assertLogs(level="WARNING"): 45 | res = fun_ignore(0, 0) 46 | self.assertIs(res, None) 47 | 48 | with self.assertLogs(level="WARNING"): 49 | res = fun_ignore(1, 0) 50 | self.assertIs(res, None) 51 | 52 | self.assertEqual(fun_ignore(1, 1), 2) 53 | self.assertEqual(fun_ignore(1, 2), 1.5) 54 | self.assertEqual(fun_ignore(-2, 2), 0) 55 | 56 | 57 | if __name__ == "__main__": 58 | absltest.main() 59 | --------------------------------------------------------------------------------
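A minimal usage sketch (not part of the repository) showing how the Transformer stack in tracrx/transformer/model.py can be exercised on its own. The function name `forward`, the configuration values, and the input shapes below are arbitrary illustrative choices; only the imports and the TransformerConfig / Transformer / TransformerOutput interfaces come from the files above (in normal use the compiler derives the configuration and weights).

import haiku as hk
import jax
import jax.numpy as jnp

from tracrx.transformer import model


def forward(embeddings, mask):
    # Arbitrary small configuration for illustration only.
    config = model.TransformerConfig(
        num_heads=2,
        num_layers=2,
        key_size=8,
        mlp_hidden_size=32,
        dropout_rate=0.0,
        causal=True,
    )
    return model.Transformer(config)(embeddings, mask, use_dropout=False)


forward_fn = hk.transform(forward)

rng = jax.random.PRNGKey(0)
embeddings = jnp.zeros((1, 5, 16))  # [B, T, D]
mask = jnp.ones((1, 5))             # [B, T]

params = forward_fn.init(rng, embeddings, mask)
out = forward_fn.apply(params, rng, embeddings, mask)

print(out.output.shape)      # (1, 5, 16)
print(len(out.attn_logits))  # one [B, H, T, T] array per layer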