├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── compression_heatmap.png
├── setup.py
└── tracr
    ├── __init__.py
    ├── compiler
    │   ├── __init__.py
    │   ├── assemble.py
    │   ├── assemble_test.py
    │   ├── basis_inference.py
    │   ├── basis_inference_test.py
    │   ├── compiling.py
    │   ├── craft_graph_to_model.py
    │   ├── craft_graph_to_model_test.py
    │   ├── craft_model_to_transformer.py
    │   ├── expr_to_craft_graph.py
    │   ├── expr_to_craft_graph_test.py
    │   ├── lib.py
    │   ├── lib_test.py
    │   ├── nodes.py
    │   ├── rasp_to_craft_integration_test.py
    │   ├── rasp_to_graph.py
    │   ├── rasp_to_graph_test.py
    │   ├── rasp_to_transformer_integration_test.py
    │   ├── test_cases.py
    │   ├── validating.py
    │   └── validating_test.py
    ├── craft
    │   ├── __init__.py
    │   ├── bases.py
    │   ├── bases_test.py
    │   ├── chamber
    │   │   ├── __init__.py
    │   │   ├── categorical_attn.py
    │   │   ├── categorical_attn_test.py
    │   │   ├── categorical_mlp.py
    │   │   ├── categorical_mlp_test.py
    │   │   ├── numerical_mlp.py
    │   │   ├── numerical_mlp_test.py
    │   │   ├── selector_width.py
    │   │   └── selector_width_test.py
    │   ├── tests_common.py
    │   ├── transformers.py
    │   ├── transformers_test.py
    │   ├── vectorspace_fns.py
    │   └── vectorspace_fns_test.py
    ├── examples
    │   ├── Visualize_Tracr_Models.ipynb
    │   └── __init__.py
    ├── rasp
    │   ├── __init__.py
    │   ├── causal_eval.py
    │   ├── causal_eval_test.py
    │   ├── rasp.py
    │   └── rasp_test.py
    ├── transformer
    │   ├── __init__.py
    │   ├── attention.py
    │   ├── compressed_model.py
    │   ├── compressed_model_test.py
    │   ├── encoder.py
    │   ├── encoder_test.py
    │   ├── model.py
    │   └── model_test.py
    └── utils
        ├── __init__.py
        ├── debugging.py
        ├── errors.py
        └── errors_test.py

/CONTRIBUTING.md:
--------------------------------------------------------------------------------
# How to Contribute

We welcome your contributions to this project. Please read the guidance below
first.

## Contributor License Agreement

Contributions to this project must be accompanied by a Contributor License
Agreement. You (or your employer) retain the copyright to your contribution;
this simply gives us permission to use and redistribute your contributions as
part of the project. Head over to to see
your current agreements on file or to sign a new one.

You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.

## Code reviews

All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
information on using pull requests.

## Community Guidelines

This project follows [Google's Open Source Community
Guidelines](https://opensource.google/conduct/).

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Tracr: TRAnsformer Compiler for RASP.

Tracr is a compiler for converting RASP programs
([Weiss et al. 2021](https://arxiv.org/abs/2106.06981))
into transformer weights. Please see our
[tech report](https://arxiv.org/abs/2301.05062) for a detailed description of
the compiler.

Directory structure:

* `rasp` contains an implementation of RASP embedded in Python.
* `compiler` contains the compiler itself.
* `transformer` contains the implementation of the transformer.
* `craft` contains the intermediate representation used by the compiler:
    essentially a small linear algebra-based library with named dimensions.

This is not an officially supported Google product.


## Installation

Just clone and pip install:

```
git clone https://github.com/deepmind/tracr
cd tracr
pip3 install .
```


## Usage example: RASP `reverse` program

Consider the RASP `reverse` program:

```
opp_index = length - indices - 1;
flip = select(indices, opp_index, ==);
reverse = aggregate(flip, tokens);
```

To compile this with Tracr, we would first implement the program using Tracr's
RASP library:

```python
from tracr.rasp import rasp

length = make_length()  # `length` is not a primitive in our implementation.
opp_index = length - rasp.indices - 1
flip = rasp.Select(rasp.indices, opp_index, rasp.Comparison.EQ)
reverse = rasp.Aggregate(flip, rasp.tokens)
```

Where:

```python
def make_length():
  all_true_selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)
  return rasp.SelectorWidth(all_true_selector)
```

We can then compile the RASP program to a transformer with:

```python
from tracr.compiler import compiling

bos = "BOS"
model = compiling.compile_rasp_to_model(
    reverse,
    vocab={1, 2, 3},
    max_seq_len=5,
    compiler_bos=bos,
)
```

This yields a transformer as a [Haiku](https://github.com/deepmind/dm-haiku) model.
This model isn't intended to provide _everything_ you might need, but rather serves
as a kind of "documentation-in-code" for the semantics of the generated parameters.
The expectation is that the user can then write or contribute an adapter that converts
parameters from this reference model to another transformer implementation.

Using this model we can perform a forward pass:

```python
>>> out = model.apply([bos, 1, 2, 3])
>>> out.decoded
["BOS", 3, 2, 1]
```

Success! We have a transformer that reverses its input tokens.
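The following minimal pure-Python sketch makes the semantics of `select`, `selector_width`, and `aggregate` in the program above concrete. It evaluates `reverse` directly on a list of tokens, with no transformer involved. The helper names (`select`, `selector_width`, `aggregate`, `reverse`) are hypothetical and are not tracr's RASP API. The sketch assumes predicates are called as `predicate(key, query)`, mirroring the argument order of `rasp.Select(keys, queries, predicate)` used above. It does not model the BOS token that the compiled model requires.

```python
from typing import Callable, List, Sequence

# Hypothetical helpers for illustration only; not tracr's RASP implementation.
Selector = List[List[bool]]  # selector[q][k]: does query position q select key k?


def select(keys: Sequence, queries: Sequence,
           predicate: Callable[[object, object], bool]) -> Selector:
  """Builds a boolean selector matrix, calling predicate(key, query)."""
  return [[predicate(k, q) for k in keys] for q in queries]


def selector_width(selector: Selector) -> List[int]:
  """Counts how many keys each query position selects."""
  return [sum(row) for row in selector]


def aggregate(selector: Selector, values: Sequence) -> List:
  """Categorical aggregate: copies the single selected value, else None."""
  out = []
  for row in selector:
    chosen = [v for v, selected in zip(values, row) if selected]
    out.append(chosen[0] if len(chosen) == 1 else None)
  return out


def reverse(tokens: Sequence) -> List:
  indices = list(range(len(tokens)))
  # `length` via an all-true selector, as in make_length() above.
  length = selector_width(select(tokens, tokens, lambda k, q: True))
  opp_index = [n - i - 1 for n, i in zip(length, indices)]
  flip = select(indices, opp_index, lambda k, q: k == q)
  return aggregate(flip, tokens)


print(reverse([1, 2, 3]))  # [3, 2, 1]
```

In this sketch, the categorical `aggregate` is only well defined when each query selects exactly one key, which `flip` guarantees for `reverse`. Averaging (numerical) aggregation is deliberately left out.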

Note: compiled models always expect a BOS token in order to support
selectors which don't attend to any of the input tokens. This is necessary to
preserve intuitive RASP semantics; the alternative would have been to treat
all-False selector rows as equivalent to all-True (which is what softmax in an
attention layer would naturally do). For more details, see our paper.

You can also inspect some of the intermediate activations of the model, using
`out.residuals`, `out.layer_outputs`, and `out.attn_logits`.

For more examples of RASP programs we can compile, check out
[compiler/lib.py](tracr/compiler/lib.py).

For an interactive example of compiling a model and visualizing its computation,
check out the notebook at
[examples/Visualize\_Tracr\_Models.ipynb](tracr/examples/Visualize_Tracr_Models.ipynb).


## Developer README

If you'd like to extend Tracr to fit your purposes, here's some information on
how Tracr works under the hood.


### How Tracr works conceptually

To compile a program, Tracr does the following:

1. **Trace RASP program into a graph representation.** This involves creating
   a graph node for each RASP expression and inferring dependencies between
   these graph nodes.

2. **Infer bases.** Tracr is designed to have each node output to a separate
   subspace of the residual stream. To do this, we first infer the set of all
   possible values that each node can output, then use that information to
   decide on a subspace for each node and augment each node in the graph
   with the basis vectors for that node's subspace.

3. **Convert nodes to Craft components.** Craft is the name of our internal
   intermediate representation that does linear algebra on named subspaces. In
   this stage, each expression node is converted to a Craft component that
   actually performs the linear algebra operations necessary to implement the
   expression. This includes converting _sequence operators_ to MLP weights,
   and _selectors_ to weights of attention heads. (We compute the appropriate
   weights directly using the theory of universal approximation for MLPs; no
   gradient descent required!)

4. **Convert Craft graph to Craft model.** In this stage, we convert from
   a graph representation to a layout that looks more like an actual
   transformer. At this stage, we essentially have a working model, but
   with the linear algebra done using Craft rather than JAX + Haiku.

5. **Convert Craft model to Haiku model.** Finally, we convert our
   intermediate representation of the model to a full Haiku model.

Two details are worth expanding on here: subspaces and their corresponding
bases. Each node writes to a separate subspace of the residual stream,
where each subspace is simply a unique chunk of the residual stream vector.
For example, the first node might write to the first 5 components of
the residual stream, the second node to the next 5, and so on. As for the
embeddings actually associated with each node, Tracr employs two
different kinds of bases:

* **Categorical representation** - in which each possible value of the node is
  represented as a unique one-hot vector in that node's subspace. This
  is the representation used by default.
* **Numerical representation** - in which each value is represented as a
  scalar along a single dimension of that node's subspace. This is necessary
  for some uses of the `aggregate` operation (essentially, ones which involve
  taking a mean), and some other operations are represented more efficiently
  with this representation.

A final detail is BOS tokens. The compiler relies on beginning-of-sequence
tokens in order to implement a number of operations. This is why token
sequences fed into the final model _must_ start with a BOS token.


### How Tracr works in practice

The flow of compilation execution begins in
[`compiler/compiling.py`](tracr/compiler/compiling.py), in the
`compile_rasp_to_model` function. This function is fairly short and maps
directly to the stages outlined above, so don't be afraid to read the source!


### Running tests

We use [`absltest`](https://abseil.io/docs/python/guides/testing), which is
`unittest`-compatible, and is therefore in turn `pytest`-compatible.

First, install test dependencies:

```
pip3 install absl-py pytest
```

Then, in the checkout directory, simply run `pytest`. This should take about 60
seconds.


## Superposition

One topic that we've investigated using Tracr is superposition (see e.g.
[Elhage et al. 2022](https://transformer-circuits.pub/2022/toy_model/index.html)):
in this work, we learn a compressed embedding of the residual stream in such a
way as to keep the computation faithful to the uncompressed program.

The figure below shows the dot products between embedding vectors for the
`frac_prevs` example program from [compiler/lib.py](tracr/compiler/lib.py):

![Matrix of dot-products](compression_heatmap.png)

The code for learning these embeddings is not included in this repository, but
you can read more about it in Section 5 of the
[tech report](https://arxiv.org/abs/2301.05062).
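The heatmap above plots dot products between embedding vectors. The minimal sketch below shows what such a matrix is. It assumes, for illustration only, that the compression is a matrix `w` whose row `i` is the compressed embedding of residual-stream basis direction `i`; the matrix of pairwise dot products is then `w @ w.T`. The function name, the sizes, and the random `w` are all made up. This is not the training code from the tech report, which is not included in this repository.

```python
import numpy as np


def embedding_dot_products(w: np.ndarray) -> np.ndarray:
  """Pairwise dot products of rows of w (assumed: one row per basis direction)."""
  return w @ w.T


rng = np.random.default_rng(0)
residual_dim, compressed_dim = 12, 6  # hypothetical sizes
w = rng.normal(size=(residual_dim, compressed_dim))
w /= np.linalg.norm(w, axis=1, keepdims=True)  # unit-norm rows

gram = embedding_dot_products(w)
print(gram.shape)  # (12, 12)
print(np.allclose(np.diag(gram), 1.0))  # True: unit-norm rows

# 12 vectors cannot be mutually orthogonal in 6 dimensions, so some
# off-diagonal dot products are necessarily non-zero.
off_diagonal = gram - np.diag(np.diag(gram))
print(np.abs(off_diagonal).max() > 0)  # True

# Uncompressed baseline: an identity embedding gives exactly orthogonal directions.
print(np.allclose(embedding_dot_products(np.eye(residual_dim)),
                  np.eye(residual_dim)))  # True
```

Passing `gram` to a plotting routine such as matplotlib's `imshow` would produce a heatmap of the same kind.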


## Citing Tracr

Please use the bibtex for our tech report:

```
@article{lindner2023tracr,
  title = {Tracr: Compiled Transformers as a Laboratory for Interpretability},
  author = {Lindner, David and Kramár, János and Rahtz, Matthew and McGrath, Thomas and Mikulik, Vladimir},
  journal={arXiv preprint arXiv:2301.05062},
  year={2023}
}
```

--------------------------------------------------------------------------------
/compression_heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/tracr/9ce2b8c82b6ba10e62e86cf6f390e7536d4fd2cd/compression_heatmap.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Install script."""

import setuptools

setuptools.setup(
    name="tracr",
    version="1.0.0",
    url="https://github.com/deepmind/tracr",
    author="DeepMind LMI team",
    author_email="tracr-devs@google.com",
    description="Compiler from RASP to transformer weights",
    packages=setuptools.find_packages(),
    install_requires=[
        "chex",
        "einops",
        "dm-haiku",
        "jax",
        "networkx",
        "numpy",
        "typing_extensions",
    ],
)

--------------------------------------------------------------------------------
/tracr/__init__.py:
--------------------------------------------------------------------------------
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

--------------------------------------------------------------------------------
/tracr/compiler/__init__.py:
--------------------------------------------------------------------------------
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides the main compiler function as a public import."""

from tracr.compiler.compiling import compile_rasp_to_model

__all__ = ["compile_rasp_to_model"]

--------------------------------------------------------------------------------
/tracr/compiler/assemble_test.py:
--------------------------------------------------------------------------------
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for compiler.assemble."""

from absl.testing import absltest
from absl.testing import parameterized
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from tracr.compiler import assemble
from tracr.craft import bases


class AssembleTest(parameterized.TestCase):

  def test_token_embedding_produces_correct_embedding(self):
    # Token embeddings should be one-hot embeddings of the input integers
    # into the token subspace of residual_space
    input_space = bases.VectorSpaceWithBasis.from_values("0inp", range(2))
    indices_space = bases.VectorSpaceWithBasis.from_values("1ind", range(3))
    output_space = bases.VectorSpaceWithBasis.from_values("2out", range(2))
    residual_space = bases.join_vector_spaces(input_space, indices_space,
                                              output_space)

    @hk.without_apply_rng
    @hk.transform
    def token_pos_embed(tokens):
      embed_modules = assemble._make_embedding_modules(
          residual_space=residual_space,
          tokens_space=input_space,
          indices_space=indices_space,
          output_space=output_space)
      return embed_modules.token_embed(tokens)

    tokens = jnp.array([0, 0, 1])
    expected_token_embeddings = jnp.array([[1, 0, 0, 0, 0, 0, 0],
                                           [1, 0, 0, 0, 0, 0, 0],
                                           [0, 1, 0, 0, 0, 0, 0]])

    params = token_pos_embed.init(jax.random.PRNGKey(0), tokens)
    embeddings = token_pos_embed.apply(params, tokens)
    np.testing.assert_allclose(embeddings, expected_token_embeddings)

  def test_position_embedding_produces_correct_embedding(self):
    # Position embeddings should be one-hot embeddings of the input integers
    # (representing indices) into the indices subspace of residual_space
    input_space = bases.VectorSpaceWithBasis.from_values("0inp", range(2))
    indices_space = bases.VectorSpaceWithBasis.from_values("1ind", range(3))
    output_space = bases.VectorSpaceWithBasis.from_values("2out", range(2))
    residual_space = bases.join_vector_spaces(input_space, indices_space,
                                              output_space)

    @hk.without_apply_rng
    @hk.transform
    def token_pos_embed(tokens):
      embed_modules = assemble._make_embedding_modules(
          residual_space=residual_space,
          tokens_space=input_space,
          indices_space=indices_space,
          output_space=output_space)
      return embed_modules.pos_embed(jnp.indices(tokens.shape)[-1])

    tokens = jnp.array([3, 0, 0, 1])
    expected_pos_embeddings = jnp.array([[0, 0, 0, 0, 0, 0, 0],
                                         [0, 0, 1, 0, 0, 0, 0],
                                         [0, 0, 0, 1, 0, 0, 0],
                                         [0, 0, 0, 0, 1, 0, 0]])

    params = token_pos_embed.init(jax.random.PRNGKey(0), tokens)
    embeddings = token_pos_embed.apply(params, tokens)
    np.testing.assert_allclose(embeddings, expected_pos_embeddings)

  def test_unembedding(self):
    # Prepend numbers to preserve basis order [input, index, output]
    input_space = bases.VectorSpaceWithBasis.from_values("0inp", range(2))
    indices_space = bases.VectorSpaceWithBasis.from_values("1ind", range(3))
    output_space = bases.VectorSpaceWithBasis.from_values("2out", range(2))
    residual_space = bases.join_vector_spaces(input_space, indices_space,
                                              output_space)

    @hk.without_apply_rng
    @hk.transform
    def unembed(embeddings):
      embed_modules = assemble._make_embedding_modules(
          residual_space=residual_space,
          tokens_space=input_space,
          indices_space=indices_space,
          output_space=output_space)
      return embed_modules.unembed(embeddings, use_unembed_argmax=True)

    embeddings = jnp.array([
        # pylint: disable=g-no-space-after-comment
        #inp | indices| out | < spaces
        #0  1  0  1  2  0  1  < values in spaces
        [0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 1]
    ])
    expected_tokens = jnp.array([1, 0, 1])

    params = unembed.init(jax.random.PRNGKey(0), embeddings)
    tokens = unembed.apply(params, embeddings)
    np.testing.assert_allclose(tokens, expected_tokens)


if __name__ == "__main__":
  absltest.main()

--------------------------------------------------------------------------------
/tracr/compiler/basis_inference.py:
--------------------------------------------------------------------------------
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inferring the vector spaces taken on by certain operations."""

import dataclasses
import itertools
from typing import Set

import networkx as nx
from tracr.compiler import nodes
from tracr.craft import bases
from tracr.rasp import rasp
from tracr.utils import errors

Node = nodes.Node


@dataclasses.dataclass
class InferBasesOutput:
  graph: nx.DiGraph


def infer_bases(
    graph: nx.DiGraph,
    sink: Node,
    vocab: Set[rasp.Value],
    max_seq_len: int,
) -> None:
  """Infers in-place the possible output values and vector bases of the SOps."""

  def compute_value_set(sop: rasp.SOp) -> Set[rasp.Value]:
    """Computes value set using already-computed predecessor value sets."""
    if isinstance(sop, rasp.TokensType):
      return vocab
    elif isinstance(sop, rasp.IndicesType):
      return set(range(max_seq_len))
    elif isinstance(sop, rasp.SelectorWidth):
      return set(range(0, max_seq_len + 1))
    elif isinstance(sop, rasp.Full):
      return {sop.fill}
    elif isinstance(sop, rasp.Map):
      inner_value_set = graph.nodes[sop.inner.label][nodes.VALUE_SET]
      out = set()
      for x in inner_value_set:
        res = errors.ignoring_arithmetic_errors(sop.f)(x)
        if res is not None:
          out.add(res)
      return out
    elif isinstance(sop, rasp.SequenceMap):
      f_ignore_error = errors.ignoring_arithmetic_errors(sop.f)
      fst_value_set = graph.nodes[sop.fst.label][nodes.VALUE_SET]
      snd_value_set = graph.nodes[sop.snd.label][nodes.VALUE_SET]
      out = set()
      for l, r in itertools.product(fst_value_set, snd_value_set):
        res = f_ignore_error(l, r)
        if res is not None:
          out.add(res)
      return out
    elif isinstance(sop, rasp.Aggregate):
      if rasp.is_categorical(sop):
        # Simply pass on the value set of the underlying S-Op.
        return graph.nodes[sop.sop.label][nodes.VALUE_SET]
      elif rasp.is_numerical(sop):
        # TODO(b/255936408): This doesn't work if we average arbitrary values.
        # But most examples only average binary variables.
        sop_value_set = graph.nodes[sop.sop.label][nodes.VALUE_SET]
        if not {int(x) for x in sop_value_set}.issubset({0, 1}):
          raise NotImplementedError(
              "Attention patterns can currently only "
              "average binary variables. Not:", sop_value_set)

        value_set = set()
        for value in sop_value_set:
          for length in range(1, max_seq_len + 1):
            value_set.add(value / length)
        return value_set
    raise ValueError(f"Unsupported S-Op: {sop}")

  for node_id in nx.dfs_postorder_nodes(graph.reverse(), sink[nodes.ID]):
    expr = graph.nodes[node_id][nodes.EXPR]

    if not isinstance(expr, rasp.SOp):
      # Only S-Ops have output vector spaces.
      continue

    value_set = compute_value_set(expr)
    graph.nodes[node_id][nodes.VALUE_SET] = value_set

    if rasp.is_categorical(expr):
      out_space = bases.VectorSpaceWithBasis.from_values(expr.label, value_set)
    elif rasp.is_numerical(expr):
      out_space = bases.VectorSpaceWithBasis.from_names([expr.label])
    else:
      raise ValueError(f"Unsupported S-Op type: {expr.type}")
    graph.nodes[node_id][nodes.OUTPUT_BASIS] = out_space.basis

--------------------------------------------------------------------------------
/tracr/compiler/basis_inference_test.py:
--------------------------------------------------------------------------------
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for compiler.basis_inference."""

from absl.testing import absltest
from absl.testing import parameterized
from tracr.compiler import basis_inference
from tracr.compiler import nodes
from tracr.compiler import rasp_to_graph
from tracr.rasp import rasp


class InferBasesTest(parameterized.TestCase):

  def test_arithmetic_error_logs_warning(self):
    program = rasp.numerical(rasp.Map(lambda x: 1 / x, rasp.tokens))
    extracted = rasp_to_graph.extract_rasp_graph(program)
    vocab = {0, 1, 2}
    with self.assertLogs(level="WARNING"):
      basis_inference.infer_bases(
          extracted.graph,
          extracted.sink,
          vocab,
          max_seq_len=1,
      )

  @parameterized.parameters(({1, 2, 3}, {2, 3, 4}), ({0, 5}, {1, 6}))
  def test_one_edge(self, vocab, expected_value_set):
    program = rasp.categorical(rasp.Map(lambda x: x + 1, rasp.tokens))
    extracted = rasp_to_graph.extract_rasp_graph(program)

    basis_inference.infer_bases(
        extracted.graph,
        extracted.sink,
        vocab,
        max_seq_len=1,
    )

    self.assertSetEqual(
        extracted.graph.nodes[program.label][nodes.VALUE_SET],
        expected_value_set,
    )

  def test_primitive_close_to_tip(self):
    intermediate = rasp.categorical(rasp.tokens + 1)
    intermediate = rasp.categorical(intermediate + intermediate)
    program = rasp.categorical(intermediate + rasp.indices)
    extracted = rasp_to_graph.extract_rasp_graph(program)

    basis_inference.infer_bases(
        extracted.graph,
        extracted.sink,
        {0, 1},
        max_seq_len=2,
    )

    self.assertSetEqual(
        extracted.graph.nodes[program.label][nodes.VALUE_SET],
        {2, 3, 4, 5},
    )
    self.assertSetEqual(
        extracted.graph.nodes[intermediate.label][nodes.VALUE_SET],
        {2, 3, 4},
    )

  @parameterized.named_parameters(
      dict(
          testcase_name="categorical_aggregate",
          program=rasp.categorical(
              rasp.Aggregate(
                  rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ),
                  rasp.indices,
              )
          ),
          vocab={0, 1},
          max_seq_len=3,
          expected_value_set={0, 1, 2},
      ),
      dict(
          testcase_name="numerical_aggregate",
          program=rasp.numerical(
              rasp.Aggregate(
                  rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ),
                  rasp.tokens,
              )
          ),
          vocab={0, 1},
          max_seq_len=2,
          expected_value_set={0, 1, 1 / 2},
      ),
      dict(
          testcase_name="selector_width",
          program=rasp.SelectorWidth(
              rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ)
          ),
          vocab={0, 1},
          max_seq_len=3,
          expected_value_set={0, 1, 2, 3},
      ),
      dict(
          testcase_name="annotated_tokens",
          program=rasp.categorical(rasp.tokens),
          vocab={"a", "b"},
          max_seq_len=2,
          expected_value_set={"a", "b"},
      ),
      dict(
          testcase_name="annotated_indices",
          program=rasp.categorical(rasp.indices),
          vocab={"a", "b"},
          max_seq_len=2,
          expected_value_set={0, 1},
      ),
  )
  def test_inferred_value_set_as_expected(
      self, program, vocab, max_seq_len, expected_value_set
  ):
    extracted = rasp_to_graph.extract_rasp_graph(program)

    basis_inference.infer_bases(
        extracted.graph,
        extracted.sink,
        vocab,
        max_seq_len=max_seq_len,
    )

    self.assertSetEqual(
        extracted.graph.nodes[program.label][nodes.VALUE_SET],
        expected_value_set,
    )


if __name__ == "__main__":
  absltest.main()

--------------------------------------------------------------------------------
/tracr/compiler/compiling.py:
--------------------------------------------------------------------------------
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Combines all steps of compiling a RASP program."""

from typing import Set

from tracr.compiler import assemble
from tracr.compiler import basis_inference
from tracr.compiler import craft_graph_to_model
from tracr.compiler import craft_model_to_transformer
from tracr.compiler import expr_to_craft_graph
from tracr.compiler import rasp_to_graph
from tracr.compiler import validating
from tracr.craft import bases
from tracr.rasp import rasp


COMPILER_BOS = "compiler_bos"
COMPILER_PAD = "compiler_pad"


def compile_rasp_to_model(
    program: rasp.SOp,
    vocab: Set[rasp.Value],
    max_seq_len: int,
    causal: bool = False,
    compiler_bos: str = COMPILER_BOS,
    compiler_pad: str = COMPILER_PAD,
    mlp_exactness: int = 100,
) -> assemble.AssembledTransformerModel:
  """Compile a RASP program to transformer weights.

  Note that currently not all RASP features are supported. Most unsupported
  features are detected at compile time and will cause a NotImplementedError.
  However, a few unsupported features cannot be checked at compile time and
  can cause silent errors.

  See `compiler.validating` for details and a function to quickly check if
  a program is compilable with Tracr without needing to compile it.

  Args:
    program: the RASP program to compile.
    vocab: the set of vocab tokens expected by RASP.
    max_seq_len: the maximum sequence length for the compiled model.
    causal: if True, outputs a model with causal masking.
    compiler_bos: the name of the special BOS token that will be added by the
      compiler. Must not be present in the vocab.
    compiler_pad: the name of the special PAD token that will be added by the
      compiler. Must not be present in the vocab.
62 | mlp_exactness: Controls the approximation of the MLP layers. In theory, 63 | larger values yield a better approximation. However, values that are too large can cause 64 | numerical issues due to large parameter norms. Reasonable values are 65 | between 1 and 100. 66 | 67 | Returns: 68 | The compiled model. 69 | 70 | Raises: 71 | NotImplementedError: if the program uses unsupported features that can be 72 | caught at compile time. 73 | """ 74 | 75 | if compiler_bos in vocab: 76 | raise ValueError( 77 | "Compiler BOS token must not be present in the vocab. " 78 | f"Found '{compiler_bos}' in {vocab}" 79 | ) 80 | 81 | if compiler_pad in vocab: 82 | raise ValueError( 83 | "Compiler PAD token must not be present in the vocab. " 84 | f"Found '{compiler_pad}' in {vocab}" 85 | ) 86 | 87 | # Perform static validation to fail fast. This catches most programs that 88 | # tracr is unable to compile. 89 | unsupported_exprs = validating.static_validate(program) 90 | if unsupported_exprs: 91 | error_message = "\n".join( 92 | (f"{expr.expr.name}: {expr.reason}" for expr in unsupported_exprs) 93 | ) 94 | error_message = f"Unsupported RASP expressions:\n{error_message}" 95 | raise NotImplementedError(error_message) 96 | 97 | extracted = rasp_to_graph.extract_rasp_graph(program) 98 | graph, sources, sink = extracted.graph, extracted.sources, extracted.sink 99 | 100 | basis_inference.infer_bases( 101 | graph, 102 | sink, 103 | vocab, 104 | max_seq_len, 105 | ) 106 | 107 | expr_to_craft_graph.add_craft_components_to_rasp_graph( 108 | graph, 109 | bos_dir=bases.BasisDirection(rasp.tokens.label, compiler_bos), 110 | mlp_exactness=mlp_exactness, 111 | ) 112 | 113 | craft_model = craft_graph_to_model.craft_graph_to_model(graph, sources) 114 | 115 | return craft_model_to_transformer.craft_model_to_transformer( 116 | craft_model=craft_model, 117 | graph=graph, 118 | sink=sink, 119 | max_seq_len=max_seq_len, 120 | causal=causal, 121 | compiler_bos=compiler_bos, 122 | compiler_pad=compiler_pad, 123 | )
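The `craft_graph_to_model.craft_graph_to_model(graph, sources)` call in `compile_rasp_to_model` places each craft block in a layer based on its longest-path "computational depth" from the source S-Ops. The following is a minimal, dependency-free sketch of that depth computation, not tracr's own code. It rests on two simplifying assumptions: the graph is a plain adjacency dict instead of an `nx.DiGraph`, and the hypothetical `is_sop` predicate stands in for the `isinstance(..., rasp.SOp)` check. The sample graph mirrors the one used in `craft_graph_to_model_test.py`.

```python
from typing import Callable, Dict, Hashable, List, Sequence


def computational_depth(
    successors: Dict[Hashable, List[Hashable]],
    sources: Sequence[Hashable],
    is_sop: Callable[[Hashable], bool],
) -> Dict[Hashable, int]:
  """Longest-path depth from any source; unreachable nodes get -1.

  Sketch only: every node must appear as a key of `successors` (use [] for
  sinks). Only SOp successors add one level; non-SOp nodes (e.g. selectors)
  inherit the depth of their predecessor.
  """
  depth: Dict[Hashable, int] = {}

  def dfs(node: Hashable, d: int) -> None:
    # Keep the maximum depth reached over all paths into `node`.
    depth[node] = max(d, depth.get(node, d))
    for succ in successors[node]:
      dfs(succ, d + 1 if is_sop(succ) else d)

  for src in sources:
    dfs(src, 0)
  # Nodes never reached from any source are marked as disconnected.
  for node in successors:
    depth.setdefault(node, -1)
  return depth


# Hypothetical example graph where every node is an SOp.
succ = {0: [1], 1: [2], 2: [3], 3: [4], 4: [],
        5: [6], 6: [7, 3], 7: [8], 8: [9], 9: [4], 10: []}
result = computational_depth(succ, [0, 5], is_sop=lambda n: True)
print([result[n] for n in range(11)])  # [0, 1, 2, 3, 5, 0, 1, 2, 3, 4, -1]
```

Like the recursive structure of `compute_computational_depth`, this DFS revisits a node every time it is reached along a new path. That is exponential in the worst case but usually acceptable for small program graphs. A single pass in topological order would be the usual alternative for larger graphs.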
124 | -------------------------------------------------------------------------------- /tracr/compiler/craft_graph_to_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Create a craft model from a computational graph.""" 16 | 17 | import collections 18 | from typing import Dict, List, Sequence 19 | 20 | import networkx as nx 21 | from tracr.compiler import nodes 22 | from tracr.craft import bases 23 | from tracr.craft import transformers 24 | from tracr.rasp import rasp 25 | 26 | Node = nodes.Node 27 | NodeID = nodes.NodeID 28 | 29 | 30 | def compute_computational_depth( 31 | graph: nx.DiGraph, sources_ids: Sequence[int] 32 | ) -> Dict[int, int]: 33 | """Returns the computational depth of each node in the graph. 34 | 35 | Given source nodes, runs DFS to track the maximum computational depth from all 36 | source nodes to every node in the graph. 37 | 38 | Non-SOp nodes do not count in the depth calculation. 39 | 40 | Disconnected nodes have depth -1.
41 | 42 | Args: 43 | graph: RASP computational graph where all nodes are annotated with # EXPR 44 | attributes set to rasp primitives 45 | sources_ids: Sequence of integers to measure computational depth against 46 | 47 | Returns: 48 | a dictionary mapping all graph nodes to a computational depth 49 | """ 50 | computational_depth = {} 51 | 52 | def dfs(node_id, depth): 53 | if node_id in computational_depth: 54 | computational_depth[node_id] = max(depth, computational_depth[node_id]) 55 | else: 56 | computational_depth[node_id] = depth 57 | 58 | for successor_id in graph.successors(node_id): 59 | if not isinstance(graph.nodes[successor_id][nodes.EXPR], rasp.SOp): 60 | dfs(successor_id, depth) 61 | else: 62 | dfs(successor_id, depth + 1) 63 | 64 | for source_id in sources_ids: 65 | dfs(source_id, depth=0) 66 | 67 | # ensure any disconnected nodes are given a depth -1 68 | disconnected_nodes = set(graph.nodes) - set(computational_depth.keys()) 69 | for disconnected_node in disconnected_nodes: 70 | computational_depth[disconnected_node] = -1 71 | 72 | return computational_depth 73 | 74 | 75 | def _node_is_attn(node: Node) -> bool: 76 | """Returns True if node is an attention layer.""" 77 | return nodes.MODEL_BLOCK in node and isinstance( 78 | node[nodes.MODEL_BLOCK], 79 | (transformers.AttentionHead, transformers.MultiAttentionHead), 80 | ) 81 | 82 | 83 | def _node_is_mlp(node: Node) -> bool: 84 | """Returns True if node is an MLP layer.""" 85 | return nodes.MODEL_BLOCK in node and isinstance( 86 | node[nodes.MODEL_BLOCK], transformers.MLP 87 | ) 88 | 89 | 90 | def _node_is_residual_block(node: Node) -> bool: 91 | """Returns True if node is a valid residual block (Attn followed by MLP).""" 92 | block = node[nodes.MODEL_BLOCK] if nodes.MODEL_BLOCK in node else None 93 | if block and isinstance(block, transformers.SeriesWithResiduals): 94 | if len(block.blocks) == 2: 95 | attn, mlp = block.blocks 96 | if isinstance( 97 | attn, (transformers.AttentionHead, 
transformers.MultiAttentionHead) 98 | ) and isinstance(mlp, transformers.MLP): 99 | return True 100 | return False 101 | 102 | 103 | def _all_attn_nodes(node_list: Sequence[Node]) -> bool: 104 | """Returns True iff all nodes are attention layers (or node_list is empty).""" 105 | for node in node_list: 106 | if not _node_is_attn(node): 107 | return False 108 | return True 109 | 110 | 111 | def _all_mlp_nodes(node_list: Sequence[Node]) -> bool: 112 | """Returns True iff all nodes are MLP layers (or node_list is empty).""" 113 | for node in node_list: 114 | if not _node_is_mlp(node): 115 | return False 116 | return True 117 | 118 | 119 | def _allocate_modules_to_layers( 120 | graph: nx.DiGraph, sources: Sequence[Node] 121 | ) -> Dict[int, int]: 122 | """Allocate all nodes in compute graph to layers. 123 | 124 | First, computes the longest path from the input to each node that is a model 125 | component (not input and output nodes). The longest path to a model component 126 | (its "depth") determines a layer in which we can place it while ensuring that 127 | all necessary previous computations have already happened. 128 | 129 | This assumes layers are arranged as [Attention, MLP, Attention, MLP, ...] 130 | 131 | In the special case where there are only Attention layers at one depth level 132 | and only MLP layers at the next depth level, they are treated as if they 133 | were at the same depth, because attention layers always come before MLP layers 134 | for the same depth. 135 | 136 | Args: 137 | graph: RASP graph with craft blocks. 138 | sources: List of input nodes 139 | 140 | Returns: 141 | A dict mapping from node ids to layer indices, where 0, 1, 2, 3, ... 142 | are in the order attention, mlp, attention, mlp, ...
143 | """ 144 | layer_allocation: Dict[int, int] = collections.defaultdict(lambda: -1) 145 | depth_by_node_id: Dict[int, int] = dict() 146 | nodes_by_depth: Dict[int, List[Node]] = collections.defaultdict(list) 147 | computational_depth = compute_computational_depth( 148 | graph, [src[nodes.ID] for src in sources] 149 | ) 150 | for node_id, node in graph.nodes.items(): 151 | if ( 152 | _node_is_mlp(node) 153 | or _node_is_attn(node) 154 | or _node_is_residual_block(node) 155 | ): 156 | # Node is a model component 157 | depth = computational_depth[node_id] 158 | depth_by_node_id[node_id] = depth 159 | nodes_by_depth[depth].append(node) 160 | 161 | # If at level `depth` there are only attention heads and at level `depth + 1` 162 | # there are only MLPs, we can condense them into one level 163 | # TODO(b/255936816): Think about improving this heuristic. The heuristic is 164 | # not optimal, and only catches very basic opportunities for optimization. It 165 | # is easy to come up with opportunities for optimization that it does not 166 | # catch. 167 | min_depth, max_depth = min(nodes_by_depth.keys()), max(nodes_by_depth.keys()) 168 | depth = min_depth 169 | while depth < max_depth: 170 | if _all_attn_nodes(nodes_by_depth[depth]) and _all_mlp_nodes( 171 | nodes_by_depth[depth + 1] 172 | ): 173 | # Condense by decrementing the depth of all nodes starting from depth+1 174 | for update_depth in range(depth + 1, max_depth + 1): 175 | for node in nodes_by_depth[update_depth]: 176 | node_id = node[nodes.ID] 177 | depth_by_node_id[node_id] = update_depth - 1 178 | nodes_by_depth[update_depth - 1].extend(nodes_by_depth[update_depth]) 179 | nodes_by_depth[update_depth] = [] 180 | max_depth -= 1 181 | depth += 1 182 | 183 | # Allocate nodes to layers by depth, ensuring attn -> mlp -> attn -> mlp ...
184 | current_layer = 0 185 | current_depth = 1 186 | for node_id, depth in sorted(depth_by_node_id.items(), key=lambda x: x[1]): 187 | while depth > current_depth: 188 | current_depth += 1 189 | current_layer += 2 190 | if depth == current_depth: 191 | if _node_is_residual_block(graph.nodes[node_id]): 192 | layer_allocation[node_id] = current_layer 193 | else: 194 | is_mlp = _node_is_mlp(graph.nodes[node_id]) 195 | layer_allocation[node_id] = current_layer + int(is_mlp) 196 | 197 | return layer_allocation 198 | 199 | 200 | def craft_graph_to_model( 201 | graph: nx.DiGraph, sources: Sequence[Node] 202 | ) -> transformers.SeriesWithResiduals: 203 | """Translates a RASP graph with craft blocks into a full craft model. 204 | 205 | 1. Allocates modules to layers, assuming layers alternate attention, MLP, ... 206 | 2. Creates subspaces for all inputs and outputs, and builds the residual stream. 207 | 3. Assembles everything into a craft model and returns it. 208 | 209 | Args: 210 | graph: RASP graph with craft blocks. 211 | sources: List of input nodes 212 | 213 | Returns: 214 | A craft model that can be compiled to model weights.
215 | 216 | Raises: 217 | ValueError: On invalid input (if the craft_graph does not have craft blocks 218 | already specified) 219 | """ 220 | layer_allocation = _allocate_modules_to_layers(graph, sources) 221 | blocks_by_layer = collections.defaultdict(list) 222 | model_blocks = [] 223 | 224 | residual_space = bases.VectorSpaceWithBasis([]) 225 | 226 | for node_id, layer_no in layer_allocation.items(): 227 | node = graph.nodes[node_id] 228 | block = node[nodes.MODEL_BLOCK] if nodes.MODEL_BLOCK in node else None 229 | 230 | if _node_is_residual_block(node): 231 | assert isinstance(block, transformers.SeriesWithResiduals) 232 | assert len(block.blocks) == 2 233 | residual_space = bases.join_vector_spaces( 234 | residual_space, 235 | block.blocks[0].residual_space, 236 | block.blocks[1].residual_space, 237 | ) 238 | blocks_by_layer[layer_no].append(block.blocks[0]) 239 | blocks_by_layer[layer_no + 1].append(block.blocks[1]) 240 | elif block: 241 | residual_space = bases.join_vector_spaces( 242 | residual_space, node[nodes.MODEL_BLOCK].residual_space 243 | ) 244 | blocks_by_layer[layer_no].append(block) 245 | 246 | for layer_no, layer_blocks in sorted( 247 | blocks_by_layer.items(), key=lambda x: x[0] 248 | ): 249 | for block in layer_blocks: 250 | block.residual_space = residual_space 251 | 252 | if layer_blocks: 253 | if layer_no % 2 == 0: # Attention Layer 254 | multi_head_attn = transformers.MultiAttentionHead(layer_blocks) 255 | model_blocks.append(multi_head_attn) 256 | else: # MLP Layer 257 | parallel_mlp = transformers.MLP.combine_in_parallel(layer_blocks) 258 | model_blocks.append(parallel_mlp) 259 | 260 | return transformers.SeriesWithResiduals(model_blocks) 261 | -------------------------------------------------------------------------------- /tracr/compiler/craft_graph_to_model_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for compiler.craft_graph_to_model.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import networkx as nx 20 | from tracr.compiler import craft_graph_to_model 21 | from tracr.compiler import nodes 22 | from tracr.compiler import rasp_to_graph 23 | from tracr.craft import bases 24 | from tracr.craft.chamber import categorical_attn 25 | from tracr.craft.chamber import categorical_mlp 26 | from tracr.rasp import rasp 27 | 28 | 29 | class CraftAllocateModulesToLayersTest(parameterized.TestCase): 30 | 31 | def _get_dummy_block(self, block_type): 32 | if block_type == "ATTN": 33 | return categorical_attn.categorical_attn( 34 | query_space=bases.VectorSpaceWithBasis.from_names(["query"]), 35 | key_space=bases.VectorSpaceWithBasis.from_names(["bos", "key"]), 36 | value_space=bases.VectorSpaceWithBasis.from_names(["bos", "value"]), 37 | output_space=bases.VectorSpaceWithBasis.from_names(["output"]), 38 | bos_space=bases.VectorSpaceWithBasis.from_names(["bos"]), 39 | one_space=bases.VectorSpaceWithBasis.from_names(["one"]), 40 | attn_fn=lambda x, y: True, 41 | ) 42 | elif block_type == "MLP": 43 | return categorical_mlp.map_categorical_mlp( 44 | input_space=bases.VectorSpaceWithBasis.from_names(["input"]), 45 | 
output_space=bases.VectorSpaceWithBasis.from_names(["output"]), 46 | operation=lambda x: x, 47 | ) 48 | else: 49 | return None 50 | 51 | def test_compute_computational_depth_returns_expected_result(self): 52 | """Creates a graph and checks the longest path for each node.""" 53 | 54 | # Node IDs: 55 | # 0 -- 1 -- 2 -- 3 ------------ 4 56 | # / / 57 | # 5 -- 6 ---------- 7 -- 8 -- 9 58 | # 59 | # 10 60 | # Expected return values: 61 | # 0 -- 1 -- 2 -- 3 ------------ 5 62 | # / / 63 | # 0 -- 1 ---------- 2 -- 3 -- 4 64 | # 65 | # -1 66 | 67 | graph = nx.DiGraph() 68 | node_ids = list(range(11)) 69 | expected_results = [0, 1, 2, 3, 5, 0, 1, 2, 3, 4, -1] 70 | for node_id, res in zip(node_ids, expected_results): 71 | graph.add_node( 72 | node_id, **{ 73 | nodes.ID: node_id, 74 | nodes.EXPR: rasp.ConstantSOp(1), 75 | "expected_result": res 76 | }) 77 | graph.add_edge(0, 1) 78 | graph.add_edge(1, 2) 79 | graph.add_edge(2, 3) 80 | graph.add_edge(3, 4) 81 | graph.add_edge(5, 6) 82 | graph.add_edge(6, 7) 83 | graph.add_edge(7, 8) 84 | graph.add_edge(8, 9) 85 | graph.add_edge(6, 3) 86 | graph.add_edge(9, 4) 87 | sources = [graph.nodes[0], graph.nodes[5]] 88 | 89 | computational_depth = craft_graph_to_model.compute_computational_depth( 90 | graph, [src[nodes.ID] for src in sources] 91 | ) 92 | for node_id, node in graph.nodes.items(): 93 | self.assertEqual(computational_depth[node_id], node["expected_result"]) 94 | 95 | def test_allocate_modules_to_layers_returns_expected_result(self): 96 | """Creates a graph and checks if the correct layer assignment is returned.""" 97 | 98 | # Computation Graph: 99 | # INPUT -- ATTN -- MLP -- ATTN ------ MLP -- OUTPUT 100 | # / / / 101 | # INPUT -- MLP --- MLP ATTN 102 | # \ / 103 | # ATTN 104 | # Node IDs: 105 | # 0 -- 1 -- 2 -- 3 -- 4 -- 5 106 | # / / / 107 | # 6 -- 7 ---- 8 9 108 | # \ / 109 | # 10 110 | # Expected layer allocation: 111 | # -1 -- 0 -- 3 -- 4 -- 7 -- -1 112 | # / / / 113 | # -1 -- 1 --- 3 6 114 | # \ / 115 | # 4 116 | 117 | 
graph = nx.DiGraph() 118 | node_ids = list(range(11)) 119 | types = [ 120 | "INPUT", "ATTN", "MLP", "ATTN", "MLP", "OUTPUT", "INPUT", "MLP", "MLP", 121 | "ATTN", "ATTN" 122 | ] 123 | expected_results = [-1, 0, 3, 4, 7, -1, -1, 1, 3, 6, 4] 124 | for node_id, node_type, res in zip(node_ids, types, expected_results): 125 | graph.add_node( 126 | node_id, **{ 127 | nodes.ID: node_id, 128 | nodes.EXPR: rasp.ConstantSOp(1), 129 | nodes.MODEL_BLOCK: self._get_dummy_block(node_type), 130 | "expected_result": res 131 | }) 132 | 133 | graph.add_edge(0, 1) 134 | graph.add_edge(1, 2) 135 | graph.add_edge(2, 3) 136 | graph.add_edge(3, 4) 137 | graph.add_edge(4, 5) 138 | graph.add_edge(6, 7) 139 | graph.add_edge(7, 2) 140 | graph.add_edge(7, 8) 141 | graph.add_edge(8, 3) 142 | graph.add_edge(8, 10) 143 | graph.add_edge(9, 4) 144 | graph.add_edge(10, 9) 145 | 146 | craft_graph = rasp_to_graph.ExtractRaspGraphOutput( 147 | graph=graph, 148 | sink=graph.nodes[10], 149 | sources=[graph.nodes[0], graph.nodes[6]]) 150 | 151 | layer_allocation = craft_graph_to_model._allocate_modules_to_layers( 152 | craft_graph.graph, craft_graph.sources) 153 | for node_id, node in graph.nodes.items(): 154 | self.assertEqual(layer_allocation[node_id], node["expected_result"]) 155 | 156 | def test_allocate_modules_to_layers_returns_expected_result_for_chain(self): 157 | """Tests a chain of alternating attention layers and MLPs.""" 158 | 159 | # Computation Graph: 160 | # INPUT -- ATTN -- MLP -- ATTN -- MLP -- OUTPUT 161 | # Node IDs: 162 | # 0 -- 1 -- 2 -- 3 -- 4 -- 5 163 | # Expected layer allocation: 164 | # -1 -- 0 -- 1 -- 2 -- 3 -- -1 165 | 166 | graph = nx.DiGraph() 167 | node_ids = list(range(11)) 168 | types = ["INPUT", "ATTN", "MLP", "ATTN", "MLP", "OUTPUT"] 169 | expected_results = [-1, 0, 1, 2, 3, -1] 170 | for node_id, node_type, res in zip(node_ids, types, expected_results): 171 | graph.add_node( 172 | node_id, **{ 173 | nodes.ID: node_id, 174 | nodes.EXPR: rasp.ConstantSOp(1), 175 | 
nodes.MODEL_BLOCK: self._get_dummy_block(node_type), 176 | "expected_result": res 177 | }) 178 | 179 | graph.add_edge(0, 1) 180 | graph.add_edge(1, 2) 181 | graph.add_edge(2, 3) 182 | graph.add_edge(3, 4) 183 | graph.add_edge(4, 5) 184 | 185 | craft_graph = rasp_to_graph.ExtractRaspGraphOutput( 186 | graph=graph, sink=graph.nodes[5], sources=[graph.nodes[0]]) 187 | 188 | layer_allocation = craft_graph_to_model._allocate_modules_to_layers( 189 | craft_graph.graph, craft_graph.sources) 190 | for node_id, node in graph.nodes.items(): 191 | self.assertEqual(layer_allocation[node_id], node["expected_result"]) 192 | 193 | 194 | if __name__ == "__main__": 195 | absltest.main() 196 | -------------------------------------------------------------------------------- /tracr/compiler/craft_model_to_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Convert craft model into transformer with the correct input/output spaces.""" 16 | 17 | import networkx as nx 18 | from tracr.compiler import assemble 19 | from tracr.compiler import nodes 20 | from tracr.craft import bases 21 | from tracr.craft import transformers 22 | from tracr.rasp import rasp 23 | from tracr.transformer import encoder 24 | 25 | 26 | def craft_model_to_transformer( 27 | craft_model: transformers.SeriesWithResiduals, 28 | graph: nx.DiGraph, 29 | sink: nodes.Node, 30 | max_seq_len: int, 31 | compiler_bos: str, 32 | compiler_pad: str, 33 | causal: bool = False, 34 | ) -> assemble.AssembledTransformerModel: 35 | """Turn a craft model into a transformer model.""" 36 | 37 | if rasp.tokens.label not in graph.nodes: 38 | raise ValueError( 39 | f'Failed to find a node with label {rasp.tokens.label}. ' 40 | 'This is probably because your RASP program does not include ' 41 | 'rasp.tokens. A program must include rasp.tokens to be ' 42 | 'compiled.' 43 | ) 44 | 45 | # Add the compiler BOS and PAD tokens.
46 | tokens_value_set = ( 47 | graph.nodes[rasp.tokens.label][nodes.VALUE_SET].union( 48 | {compiler_bos, compiler_pad})) 49 | tokens_space = bases.VectorSpaceWithBasis.from_values(rasp.tokens.label, 50 | tokens_value_set) 51 | 52 | indices_space = bases.VectorSpaceWithBasis.from_values( 53 | rasp.indices.label, range(max_seq_len)) 54 | 55 | categorical_output = rasp.is_categorical(sink[nodes.EXPR]) 56 | output_space = bases.VectorSpaceWithBasis(sink[nodes.OUTPUT_BASIS]) 57 | 58 | assembled_model = assemble.assemble_craft_model( 59 | craft_model=craft_model, 60 | tokens_space=tokens_space, 61 | indices_space=indices_space, 62 | output_space=output_space, 63 | categorical_output=categorical_output, 64 | causal=causal, 65 | ) 66 | 67 | assembled_model.input_encoder = encoder.CategoricalEncoder( 68 | basis=tokens_space.basis, 69 | enforce_bos=compiler_bos is not None, 70 | bos_token=compiler_bos, 71 | pad_token=compiler_pad, 72 | max_seq_len=max_seq_len + 1 if compiler_bos is not None else max_seq_len, 73 | ) 74 | 75 | if categorical_output: 76 | assembled_model.output_encoder = encoder.CategoricalEncoder( 77 | basis=output_space.basis, 78 | enforce_bos=False, 79 | bos_token=None, 80 | pad_token=None) 81 | else: 82 | assembled_model.output_encoder = encoder.NumericalEncoder() 83 | 84 | return assembled_model 85 | -------------------------------------------------------------------------------- /tracr/compiler/expr_to_craft_graph_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for compiler.expr_to_craft_graph.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from tracr.compiler import basis_inference 20 | from tracr.compiler import expr_to_craft_graph 21 | from tracr.compiler import lib 22 | from tracr.compiler import nodes 23 | from tracr.compiler import rasp_to_graph 24 | from tracr.craft import bases 25 | from tracr.craft import transformers 26 | from tracr.rasp import rasp 27 | 28 | 29 | class ExprToCraftGraphTest(parameterized.TestCase): 30 | 31 | def _check_block_types_are_correct(self, graph): 32 | for _, node in graph.nodes.items(): 33 | expr = node[nodes.EXPR] 34 | if isinstance(expr, rasp.SOp): 35 | block = node[nodes.MODEL_BLOCK] 36 | if isinstance(expr, (rasp.Map, rasp.SequenceMap)): 37 | self.assertIsInstance(block, transformers.MLP) 38 | elif isinstance(expr, rasp.Aggregate): 39 | self.assertIsInstance(block, transformers.AttentionHead) 40 | 41 | def _get_input_space_from_node(self, node): 42 | block = node[nodes.MODEL_BLOCK] 43 | if isinstance(block, transformers.MLP): 44 | return block.fst.input_space 45 | elif isinstance(block, transformers.AttentionHead): 46 | return bases.join_vector_spaces(block.w_qk.left_space, 47 | block.w_qk.right_space, 48 | block.w_ov.input_space) 49 | else: 50 | return None 51 | 52 | def _check_spaces_are_consistent(self, graph): 53 | """Check that for each edge the output is a subspace of the input.""" 54 | for 
u, v in graph.edges: 55 | u_node, v_node = graph.nodes[u], graph.nodes[v] 56 | if isinstance(u_node[nodes.EXPR], rasp.SOp) and isinstance( 57 | v_node[nodes.EXPR], rasp.SOp): 58 | u_out_basis = u_node[nodes.OUTPUT_BASIS] 59 | u_out_space = bases.VectorSpaceWithBasis(u_out_basis) 60 | v_in_space = self._get_input_space_from_node(v_node) 61 | self.assertTrue(u_out_space.issubspace(v_in_space)) 62 | 63 | @parameterized.named_parameters( 64 | dict( 65 | testcase_name="single_map", 66 | program=rasp.Map(lambda x: x + 1, rasp.tokens), 67 | ), 68 | dict( 69 | testcase_name="single_sequence_map", 70 | program=rasp.SequenceMap( 71 | lambda x, y: x + y, rasp.tokens, rasp.indices 72 | ), 73 | ), 74 | dict( 75 | testcase_name="single_select_aggregate", 76 | program=rasp.Aggregate( 77 | rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ), 78 | rasp.tokens, 79 | ), 80 | ), 81 | dict(testcase_name="reverse", program=lib.make_reverse(rasp.tokens)), 82 | dict(testcase_name="length", program=lib.make_length()), 83 | dict( 84 | testcase_name="annotated_tokens", 85 | program=rasp.annotate(rasp.tokens, foo="foo"), 86 | ), 87 | dict( 88 | testcase_name="annotated_indices", 89 | program=rasp.annotate(rasp.indices, foo="foo"), 90 | ), 91 | ) 92 | def test_compiling_rasp_programs(self, program): 93 | vocab = {0, 1, 2} 94 | extracted = rasp_to_graph.extract_rasp_graph(program) 95 | basis_inference.infer_bases( 96 | extracted.graph, 97 | extracted.sink, 98 | vocab, 99 | max_seq_len=3, 100 | ) 101 | expr_to_craft_graph.add_craft_components_to_rasp_graph(extracted.graph) 102 | self._check_block_types_are_correct(extracted.graph) 103 | self._check_spaces_are_consistent(extracted.graph) 104 | 105 | def test_add_craft_components_raises_value_error_if_called_before_basis_inference( 106 | self): 107 | program = rasp.categorical(rasp.Map(lambda x: x + 1, rasp.tokens)) 108 | extracted = rasp_to_graph.extract_rasp_graph(program) 109 | 110 | with self.assertRaisesRegex( 111 | ValueError, 112 | 
r"^.*Craft components can only be added after basis inference.*$"): 113 | expr_to_craft_graph.add_craft_components_to_rasp_graph(extracted.graph) 114 | 115 | def test_add_craft_components_raises_value_error_if_called_twice(self): 116 | vocab = {0, 1, 2} 117 | program = rasp.categorical(rasp.Map(lambda x: x + 1, rasp.tokens)) 118 | extracted = rasp_to_graph.extract_rasp_graph(program) 119 | 120 | basis_inference.infer_bases( 121 | extracted.graph, 122 | extracted.sink, 123 | vocab, 124 | max_seq_len=1, 125 | ) 126 | 127 | expr_to_craft_graph.add_craft_components_to_rasp_graph(extracted.graph) 128 | with self.assertRaisesRegex( 129 | ValueError, r"^.*Input graph cannot have model blocks set already.*$"): 130 | expr_to_craft_graph.add_craft_components_to_rasp_graph(extracted.graph) 131 | 132 | 133 | if __name__ == "__main__": 134 | absltest.main() 135 | -------------------------------------------------------------------------------- /tracr/compiler/lib_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for compiler.lib.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from tracr.compiler import test_cases 20 | from tracr.rasp import causal_eval 21 | from tracr.rasp import rasp 22 | 23 | 24 | class LibTest(parameterized.TestCase): 25 | 26 | @parameterized.named_parameters(*test_cases.TEST_CASES) 27 | def test_program_produces_expected_output(self, program, test_input, 28 | expected_output, **kwargs): 29 | del kwargs 30 | self.assertEqual(rasp.evaluate(program, test_input), expected_output) 31 | 32 | @parameterized.named_parameters(*test_cases.CAUSAL_TEST_CASES) 33 | def test_causal_program_produces_expected_output(self, program, test_input, 34 | expected_output, **kwargs): 35 | del kwargs 36 | self.assertEqual(causal_eval.evaluate(program, test_input), expected_output) 37 | 38 | 39 | if __name__ == "__main__": 40 | absltest.main() 41 | -------------------------------------------------------------------------------- /tracr/compiler/nodes.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Documents the data stored in nodes after each compiler pass.""" 16 | 17 | from typing import Any, Dict 18 | 19 | Node = Dict[str, Any] 20 | NodeID = str 21 | 22 | # RASP -> Graph 23 | ID = "ID" # unique ID of the node 24 | EXPR = "EXPR" # the RASPExpr of the node 25 | 26 | # Basis inference 27 | # Note that only S-Op expressions will have these keys set. 28 | VALUE_SET = "VALUE_SET" # possible values taken on by this SOp. 29 | OUTPUT_BASIS = "OUTPUT_BASIS" # the corresponding named basis. 30 | 31 | # RASP Graph -> Craft Graph 32 | MODEL_BLOCK = "MODEL_BLOCK" # craft block representing a RASPExpr 33 | -------------------------------------------------------------------------------- /tracr/compiler/rasp_to_graph.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Converting a RaspExpr to a graph.""" 16 | 17 | import dataclasses 18 | import queue 19 | from typing import List 20 | 21 | import networkx as nx 22 | from tracr.compiler import nodes 23 | from tracr.rasp import rasp 24 | 25 | Node = nodes.Node 26 | NodeID = nodes.NodeID 27 | 28 | 29 | @dataclasses.dataclass 30 | class ExtractRaspGraphOutput: 31 | graph: nx.DiGraph 32 | sink: Node # the program's output. 33 | sources: List[Node] # the primitive S-Ops. 34 | 35 | 36 | def extract_rasp_graph(tip: rasp.SOp) -> ExtractRaspGraphOutput: 37 | """Converts a RASP program into a graph representation.""" 38 | expr_queue = queue.Queue() 39 | graph = nx.DiGraph() 40 | sources: List[NodeID] = [] 41 | 42 | def ensure_node(expr: rasp.RASPExpr) -> NodeID: 43 | """Finds or creates a graph node corresponding to expr; returns its ID.""" 44 | node_id = expr.label 45 | if node_id not in graph: 46 | graph.add_node(node_id, **{nodes.ID: node_id, nodes.EXPR: expr}) 47 | 48 | return node_id 49 | 50 | # Breadth-first search over the RASP expression graph. 51 | 52 | def visit_raspexpr(expr: rasp.RASPExpr): 53 | parent_id = ensure_node(expr) 54 | 55 | for child_expr in expr.children: 56 | expr_queue.put(child_expr) 57 | child_id = ensure_node(child_expr) 58 | graph.add_edge(child_id, parent_id) 59 | 60 | if not expr.children: 61 | sources.append(graph.nodes[parent_id]) 62 | 63 | expr_queue.put(tip) 64 | sink = graph.nodes[ensure_node(tip)] 65 | while not expr_queue.empty(): 66 | visit_raspexpr(expr_queue.get()) 67 | 68 | return ExtractRaspGraphOutput(graph=graph, sink=sink, sources=sources) 69 | -------------------------------------------------------------------------------- /tracr/compiler/rasp_to_graph_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for compiler.rasp_to_graph.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from tracr.compiler import nodes 20 | from tracr.compiler import rasp_to_graph 21 | from tracr.rasp import rasp 22 | 23 | 24 | class ExtractRaspGraphTest(parameterized.TestCase): 25 | 26 | def test_primitives_have_no_edges(self): 27 | tokens_graph = rasp_to_graph.extract_rasp_graph(rasp.tokens).graph 28 | self.assertEmpty(tokens_graph.edges) 29 | 30 | indices_graph = rasp_to_graph.extract_rasp_graph(rasp.indices).graph 31 | self.assertEmpty(indices_graph.edges) 32 | 33 | full_graph = rasp_to_graph.extract_rasp_graph(rasp.Full(1)).graph 34 | self.assertEmpty(full_graph.edges) 35 | 36 | def test_one_edge(self): 37 | program = rasp.Map(lambda x: x + 1, rasp.tokens) 38 | 39 | graph = rasp_to_graph.extract_rasp_graph(program).graph 40 | 41 | self.assertLen(graph.edges, 1) 42 | (u, v), = graph.edges 43 | self.assertEqual(graph.nodes[u][nodes.EXPR], rasp.tokens) 44 | self.assertEqual(graph.nodes[v][nodes.EXPR], program) 45 | 46 | def test_aggregate(self): 47 | program = rasp.Aggregate( 48 | rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ), 49 | rasp.indices, 50 | ) 51 | 52 | extracted = rasp_to_graph.extract_rasp_graph(program) 53 | 54 | # Expected graph: 55 | 
# 56 | # indices \ -------- 57 | # \ \ 58 | # select -- program 59 | # tokens / 60 | 61 | self.assertLen(extracted.graph.edges, 4) 62 | self.assertEqual(extracted.sink[nodes.EXPR], program) 63 | for source in extracted.sources: 64 | self.assertIn( 65 | source[nodes.EXPR], 66 | [rasp.tokens, rasp.indices], 67 | ) 68 | 69 | 70 | if __name__ == "__main__": 71 | absltest.main() 72 | -------------------------------------------------------------------------------- /tracr/compiler/rasp_to_transformer_integration_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Integration tests for the full RASP -> transformer compilation.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import jax 20 | import numpy as np 21 | from tracr.compiler import compiling 22 | from tracr.compiler import lib 23 | from tracr.compiler import test_cases 24 | from tracr.craft import tests_common 25 | from tracr.rasp import rasp 26 | 27 | _COMPILER_BOS = "rasp_to_transformer_integration_test_BOS" 28 | _COMPILER_PAD = "rasp_to_transformer_integration_test_PAD" 29 | 30 | # Force float32 matmul precision on TPU, which otherwise defaults to bfloat16.
31 | jax.config.update("jax_default_matmul_precision", "float32") 32 | 33 | 34 | class CompilerIntegrationTest(tests_common.VectorFnTestCase): 35 | 36 | def assertSequenceEqualWhenExpectedIsNotNone(self, actual_seq, expected_seq): 37 | for actual, expected in zip(actual_seq, expected_seq): 38 | if expected is not None and actual != expected: 39 | self.fail( 40 | f"{actual_seq} does not match (ignoring Nones) " 41 | f"expected_seq={expected_seq}" 42 | ) 43 | 44 | @parameterized.named_parameters( 45 | dict(testcase_name="map", program=rasp.Map(lambda x: x + 1, rasp.tokens)), 46 | dict( 47 | testcase_name="sequence_map", 48 | program=rasp.SequenceMap( 49 | lambda x, y: x + y, rasp.tokens, rasp.indices 50 | ), 51 | ), 52 | dict( 53 | testcase_name="sequence_map_with_same_input", 54 | program=rasp.SequenceMap( 55 | lambda x, y: x + y, rasp.tokens, rasp.indices 56 | ), 57 | ), 58 | dict( 59 | testcase_name="select_aggregate", 60 | program=rasp.Aggregate( 61 | rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ), 62 | rasp.Map(lambda x: 1, rasp.tokens), 63 | ), 64 | ), 65 | ) 66 | def test_rasp_program_and_transformer_produce_same_output(self, program): 67 | vocab = {0, 1, 2} 68 | max_seq_len = 3 69 | assembled_model = compiling.compile_rasp_to_model( 70 | program, vocab, max_seq_len, compiler_bos=_COMPILER_BOS 71 | ) 72 | 73 | test_outputs = {} 74 | rasp_outputs = {} 75 | for val in vocab: 76 | test_outputs[val] = assembled_model.apply([_COMPILER_BOS, val]).decoded[1] 77 | rasp_outputs[val] = program([val])[0] 78 | 79 | with self.subTest(val=0): 80 | self.assertEqual(test_outputs[0], rasp_outputs[0]) 81 | with self.subTest(val=1): 82 | self.assertEqual(test_outputs[1], rasp_outputs[1]) 83 | with self.subTest(val=2): 84 | self.assertEqual(test_outputs[2], rasp_outputs[2]) 85 | 86 | @parameterized.named_parameters(*test_cases.TEST_CASES) 87 | def test_compiled_models_produce_expected_output( 88 | self, program, vocab, test_input, expected_output, max_seq_len, **kwargs 
89 | ): 90 | del kwargs 91 | assembled_model = compiling.compile_rasp_to_model( 92 | program, vocab, max_seq_len, compiler_bos=_COMPILER_BOS 93 | ) 94 | test_output = assembled_model.apply([_COMPILER_BOS] + test_input) 95 | 96 | if isinstance(expected_output[0], (int, float)): 97 | np.testing.assert_allclose( 98 | test_output.decoded[1:], expected_output, atol=1e-7, rtol=0.005 99 | ) 100 | else: 101 | self.assertSequenceEqualWhenExpectedIsNotNone( 102 | test_output.decoded[1:], expected_output 103 | ) 104 | 105 | @parameterized.named_parameters(*test_cases.CAUSAL_TEST_CASES) 106 | def test_compiled_causal_models_produce_expected_output( 107 | self, program, vocab, test_input, expected_output, max_seq_len, **kwargs 108 | ): 109 | del kwargs 110 | assembled_model = compiling.compile_rasp_to_model( 111 | program, 112 | vocab, 113 | max_seq_len, 114 | causal=True, 115 | compiler_bos=_COMPILER_BOS, 116 | compiler_pad=_COMPILER_PAD, 117 | ) 118 | test_output = assembled_model.apply([_COMPILER_BOS] + test_input) 119 | 120 | if isinstance(expected_output[0], (int, float)): 121 | np.testing.assert_allclose( 122 | test_output.decoded[1:], expected_output, atol=1e-7, rtol=0.005 123 | ) 124 | else: 125 | self.assertSequenceEqualWhenExpectedIsNotNone( 126 | test_output.decoded[1:], expected_output 127 | ) 128 | 129 | @parameterized.named_parameters(*test_cases.UNSUPPORTED_TEST_CASES) 130 | def test_unsupported_programs_raise_exception( 131 | self, program, vocab, max_seq_len 132 | ): 133 | with self.assertRaises(NotImplementedError): 134 | compiling.compile_rasp_to_model( 135 | program, 136 | vocab, 137 | max_seq_len, 138 | causal=True, 139 | compiler_bos=_COMPILER_BOS, 140 | compiler_pad=_COMPILER_PAD, 141 | ) 142 | 143 | @parameterized.named_parameters( 144 | dict( 145 | testcase_name="reverse_1", 146 | program=lib.make_reverse(rasp.tokens), 147 | vocab={"a", "b", "c", "d"}, 148 | test_input=list("abcd"), 149 | expected_output=list("dcba"), 150 | max_seq_len=5, 151 | ), 152 | 
dict( 153 | testcase_name="reverse_2", 154 | program=lib.make_reverse(rasp.tokens), 155 | vocab={"a", "b", "c", "d"}, 156 | test_input=list("abc"), 157 | expected_output=list("cba"), 158 | max_seq_len=5, 159 | ), 160 | dict( 161 | testcase_name="reverse_3", 162 | program=lib.make_reverse(rasp.tokens), 163 | vocab={"a", "b", "c", "d"}, 164 | test_input=list("ad"), 165 | expected_output=list("da"), 166 | max_seq_len=5, 167 | ), 168 | dict( 169 | testcase_name="reverse_4", 170 | program=lib.make_reverse(rasp.tokens), 171 | vocab={"a", "b", "c", "d"}, 172 | test_input=["c"], 173 | expected_output=["c"], 174 | max_seq_len=5, 175 | ), 176 | dict( 177 | testcase_name="length_categorical_1", 178 | program=rasp.categorical(lib.make_length()), 179 | vocab={"a", "b", "c", "d"}, 180 | test_input=list("abc"), 181 | expected_output=[3, 3, 3], 182 | max_seq_len=5, 183 | ), 184 | dict( 185 | testcase_name="length_categorical_2", 186 | program=rasp.categorical(lib.make_length()), 187 | vocab={"a", "b", "c", "d"}, 188 | test_input=list("ad"), 189 | expected_output=[2, 2], 190 | max_seq_len=5, 191 | ), 192 | dict( 193 | testcase_name="length_categorical_3", 194 | program=rasp.categorical(lib.make_length()), 195 | vocab={"a", "b", "c", "d"}, 196 | test_input=["c"], 197 | expected_output=[1], 198 | max_seq_len=5, 199 | ), 200 | dict( 201 | testcase_name="length_numerical_1", 202 | program=rasp.numerical(lib.make_length()), 203 | vocab={"a", "b", "c", "d"}, 204 | test_input=list("abc"), 205 | expected_output=[3, 3, 3], 206 | max_seq_len=5, 207 | ), 208 | dict( 209 | testcase_name="length_numerical_2", 210 | program=rasp.numerical(lib.make_length()), 211 | vocab={"a", "b", "c", "d"}, 212 | test_input=list("ad"), 213 | expected_output=[2, 2], 214 | max_seq_len=5, 215 | ), 216 | dict( 217 | testcase_name="length_numerical_3", 218 | program=rasp.numerical(lib.make_length()), 219 | vocab={"a", "b", "c", "d"}, 220 | test_input=["c"], 221 | expected_output=[1], 222 | max_seq_len=5, 223 | ), 
224 | ) 225 | def test_compiled_models_produce_expected_output_with_padding( 226 | self, program, vocab, test_input, expected_output, max_seq_len, **kwargs 227 | ): 228 | del kwargs 229 | assembled_model = compiling.compile_rasp_to_model( 230 | program, 231 | vocab, 232 | max_seq_len, 233 | compiler_bos=_COMPILER_BOS, 234 | compiler_pad=_COMPILER_PAD, 235 | ) 236 | 237 | pad_len = max_seq_len - len(test_input) 238 | test_input = test_input + [_COMPILER_PAD] * pad_len 239 | test_input = [_COMPILER_BOS] + test_input 240 | test_output = assembled_model.apply(test_input) 241 | output = test_output.decoded 242 | output_len = len(output) 243 | output_stripped = test_output.decoded[1 : output_len - pad_len] 244 | 245 | self.assertEqual(output[0], _COMPILER_BOS) 246 | if isinstance(expected_output[0], (int, float)): 247 | np.testing.assert_allclose( 248 | output_stripped, expected_output, atol=1e-7, rtol=0.005 249 | ) 250 | else: 251 | self.assertEqual(output_stripped, expected_output) 252 | 253 | 254 | if __name__ == "__main__": 255 | absltest.main() 256 | -------------------------------------------------------------------------------- /tracr/compiler/validating.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Static and dynamic checks for whether Tracr can compile a RASP program.""" 16 | 17 | import dataclasses 18 | import queue 19 | from typing import Sequence, Union 20 | 21 | from tracr.rasp import rasp 22 | 23 | 24 | @dataclasses.dataclass 25 | class TracrUnsupportedExpr: 26 | """An uncompilable expression and the reason it's not compilable.""" 27 | 28 | expr: rasp.RASPExpr 29 | reason: str 30 | 31 | 32 | def _static_validate_expr(expr: rasp.RASPExpr) -> TracrUnsupportedExpr | None: 33 | """Returns a TracrUnsupportedExpr if Tracr can't compile `expr`, else None.""" 34 | if isinstance(expr, rasp.TokensType) and rasp.is_numerical(expr): 35 | return TracrUnsupportedExpr( 36 | expr=expr, reason="tokens should always be categorical." 37 | ) 38 | if isinstance(expr, rasp.IndicesType) and rasp.is_numerical(expr): 39 | return TracrUnsupportedExpr( 40 | expr=expr, reason="indices should always be categorical." 41 | ) 42 | 43 | if isinstance(expr, rasp.Select): 44 | if not rasp.is_categorical(expr.keys): 45 | return TracrUnsupportedExpr( 46 | expr=expr, 47 | reason="Select keys must be categorical.", 48 | ) 49 | if not rasp.is_categorical(expr.queries): 50 | return TracrUnsupportedExpr( 51 | expr=expr, 52 | reason="Select queries must be categorical.", 53 | ) 54 | 55 | if isinstance(expr, rasp.Aggregate): 56 | if rasp.get_encoding(expr) != rasp.get_encoding(expr.sop): 57 | return TracrUnsupportedExpr( 58 | expr=expr, 59 | reason=( 60 | "An aggregate's output encoding must match its input encoding."
61 | f" Input: {rasp.get_encoding(expr.sop)}," 62 | f" Output: {rasp.get_encoding(expr)}." 63 | ), 64 | ) 65 | 66 | if rasp.is_categorical(expr) and expr.default is not None: 67 | return TracrUnsupportedExpr( 68 | expr=expr, 69 | reason="Categorical aggregate only supports None as default value.", 70 | ) 71 | if rasp.is_numerical(expr) and expr.default != 0: 72 | return TracrUnsupportedExpr( 73 | expr=expr, 74 | reason="Numerical aggregate only supports 0 as default value.", 75 | ) 76 | 77 | if isinstance(expr, rasp.SequenceMap): 78 | if not isinstance(expr, rasp.LinearSequenceMap) and not all( 79 | rasp.is_categorical(x) for x in (expr.fst, expr.snd, expr) 80 | ): 81 | return TracrUnsupportedExpr( 82 | expr=expr, 83 | reason=( 84 | "(Non-linear) SequenceMap only supports categorical" 85 | " inputs/outputs." 86 | ), 87 | ) 88 | 89 | if isinstance(expr, rasp.LinearSequenceMap) and not all( 90 | rasp.is_numerical(x) for x in (expr.fst, expr.snd, expr) 91 | ): 92 | return TracrUnsupportedExpr( 93 | expr=expr, 94 | reason="LinearSequenceMap only supports numerical inputs/outputs.", 95 | ) 96 | 97 | 98 | class DynamicValidationEvaluator(rasp.DefaultRASPEvaluator): 99 | """Evaluates a RASP program, recording subexpressions Tracr cannot compile. 100 | 101 | Most features Tracr does not support are restrictions on the input/output 102 | encodings of particular SOp types and can be checked statically. For example, Tracr does not 103 | support Aggregate operations with different input and output encodings 104 | (instead, explicit conversion via a Map is required). 105 | 106 | Some Aggregate operations are unsupported only for particular inputs, so 107 | they have to be checked dynamically. For example, Tracr does not support 108 | categorical Aggregate operations that require non-trivial aggregation (e.g., 109 | averaging tokens instead of moving tokens).
110 | """ 111 | 112 | def __init__(self): 113 | self.unsupported_exprs = [] 114 | super().__init__() 115 | 116 | def evaluate( 117 | self, expr: rasp.RASPExpr, xs: Sequence[rasp.Value] 118 | ) -> Union[Sequence[rasp.Value], rasp.SelectorValue]: 119 | out = super().evaluate(expr, xs) 120 | 121 | if isinstance(expr, rasp.Aggregate): 122 | # We support compiling programs which use Aggregates to move a single 123 | # categorical value to another position, i.e., when the attention pattern is 124 | # 1 in one place and 0 otherwise. However, if the attention pattern has 125 | # two or more 1s attending to different tokens that have to be aggregated, 126 | # the compiled model will silently give incorrect outputs. We cannot detect 127 | # this statically, so we check it at runtime. 128 | 129 | agg_in = expr.sop(xs) 130 | if ( 131 | # The easiest way to satisfy this is to have a selector of width 1 132 | rasp.is_categorical(expr) 133 | and not set(out).issubset(set(agg_in) | {None}) 134 | ): 135 | self.unsupported_exprs.append( 136 | TracrUnsupportedExpr( 137 | expr=expr, 138 | reason=( 139 | "Categorical aggregate does not support Selectors with" 140 | " width > 1 that require aggregation (e.g., averaging)." 141 | ), 142 | ) 143 | ) 144 | if rasp.is_numerical(expr) and not set(agg_in).issubset({0, 1}): 145 | self.unsupported_exprs.append( 146 | TracrUnsupportedExpr( 147 | expr=expr, 148 | reason=( 149 | "Numerical aggregate only supports binary inputs (0 or 1), but" 150 | f" got {set(agg_in)}." 151 | ), 152 | ) 153 | ) 154 | 155 | return out 156 | 157 | 158 | def static_validate(program: rasp.RASPExpr) -> list[TracrUnsupportedExpr]: 159 | """Performs static checks to see if `program` can be compiled. 160 | 161 | Args: 162 | program: RASP program to validate 163 | 164 | Returns: 165 | list of all unsupported subexpressions detectable statically.
166 | """ 167 | expr_queue = queue.Queue() 168 | unsupported_exprs = [] 169 | visited_exprs = {program.name} 170 | 171 | # Breadth-first search over the RASP expression graph, visiting each node once. 172 | def visit_raspexpr(expr: rasp.RASPExpr): 173 | unsupported_expr = _static_validate_expr(expr) 174 | if unsupported_expr: 175 | unsupported_exprs.append(unsupported_expr) 176 | 177 | for child_expr in expr.children: 178 | if child_expr.name not in visited_exprs: 179 | visited_exprs.add(child_expr.name) 180 | expr_queue.put(child_expr) 181 | 182 | expr_queue.put(program) 183 | while not expr_queue.empty(): 184 | visit_raspexpr(expr_queue.get()) 185 | 186 | return unsupported_exprs 187 | 188 | 189 | def dynamic_validate( 190 | program: rasp.RASPExpr, xs: Sequence[rasp.Value] | None = None 191 | ) -> list[TracrUnsupportedExpr]: 192 | """Checks if `program` can be compiled for input `xs`. 193 | 194 | Args: 195 | program: RASP program to validate 196 | xs: Input sequence on which to evaluate `program` for the dynamic 197 | checks. 198 | 199 | Returns: 200 | list of all unsupported expressions found by dynamic validation. 201 | """ 202 | validation_evaluator = DynamicValidationEvaluator() 203 | validation_evaluator.evaluate(expr=program, xs=xs) 204 | return validation_evaluator.unsupported_exprs 205 | 206 | 207 | def validate( 208 | program: rasp.RASPExpr, xs: Sequence[rasp.Value] | None = None 209 | ) -> list[TracrUnsupportedExpr]: 210 | """Checks if `program` can be compiled for input `xs`. 211 | 212 | Args: 213 | program: RASP program to validate 214 | xs: Input sequence to use for dynamic compiler check. If None, only do 215 | static checks.
216 | 217 | Returns: 218 | list of all unsupported expressions 219 | """ 220 | static_unsupported = static_validate(program) 221 | if xs is not None: 222 | dynamic_unsupported = dynamic_validate(program, xs) 223 | return static_unsupported + dynamic_unsupported 224 | return static_unsupported 225 | -------------------------------------------------------------------------------- /tracr/compiler/validating_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for compiler.validating.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from tracr.compiler import test_cases 20 | from tracr.compiler import validating 21 | from tracr.rasp import rasp 22 | 23 | 24 | class ValidationEvaluatorTest(parameterized.TestCase): 25 | 26 | @parameterized.named_parameters(test_cases.TEST_CASES) 27 | def test_supported_programs_pass_validation( 28 | self, 29 | program, 30 | test_input, 31 | **kwargs, 32 | ): 33 | del kwargs 34 | validation_result = validating.validate(program, test_input) 35 | self.assertEmpty(validation_result) 36 | 37 | @parameterized.named_parameters(test_cases.UNSUPPORTED_TEST_CASES) 38 | def test_unsupported_programs_fail_validation( 39 | self, 40 | program, 41 | vocab, 42 | **kwargs, 43 | ): 44 | del kwargs 45 | test_input = sorted(vocab) 46 | validation_result = validating.validate(program, test_input) 47 | self.assertNotEmpty(validation_result) 48 | 49 | @parameterized.named_parameters( 50 | dict( 51 | testcase_name="mean", 52 | program=rasp.Aggregate( 53 | rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE), 54 | rasp.tokens, 55 | ), 56 | test_input=[1, 2, 3, 4], 57 | ), 58 | dict( 59 | testcase_name="prev_mean", 60 | program=rasp.Aggregate( 61 | rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.LEQ), 62 | rasp.tokens, 63 | ), 64 | test_input=[1, 2, 3, 4], 65 | ), 66 | ) 67 | def test_dynamic_failure_cases_fail_validation( 68 | self, 69 | program, 70 | test_input, 71 | ): 72 | # Dynamic failure cases are not in the general test-case suite because the 73 | # compiler does not catch them.
74 | validation_result = validating.validate(program, test_input) 75 | self.assertNotEmpty(validation_result) 76 | 77 | 78 | if __name__ == "__main__": 79 | absltest.main() 80 | -------------------------------------------------------------------------------- /tracr/craft/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /tracr/craft/bases_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for bases.""" 16 | 17 | from absl.testing import absltest 18 | import numpy as np 19 | from tracr.craft import bases 20 | from tracr.craft import tests_common 21 | 22 | 23 | class VectorInBasisTest(tests_common.VectorFnTestCase): 24 | 25 | def test_shape_mismatch_raises_value_error(self): 26 | vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"]) 27 | regex = ( 28 | r"^.*Last dimension of magnitudes must be the same as number of " 29 | r"basis directions.*$" 30 | ) 31 | with self.assertRaisesRegex(ValueError, regex): 32 | bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4])) 33 | with self.assertRaisesRegex(ValueError, regex): 34 | bases.VectorInBasis(vs1.basis, np.array([[0, 1, 2, 3], [1, 2, 3, 4]])) 35 | 36 | def test_equal(self): 37 | vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"]) 38 | v1 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4])) 39 | v2 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4])) 40 | self.assertEqual(v1, v2) 41 | self.assertEqual(v2, v1) 42 | v3 = bases.VectorInBasis(vs1.basis, np.array([[0, 1, 2, 3], [1, 2, 3, 4]])) 43 | v4 = bases.VectorInBasis(vs1.basis, np.array([[0, 1, 2, 3], [1, 2, 3, 4]])) 44 | self.assertEqual(v3, v4) 45 | self.assertEqual(v4, v3) 46 | v5 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4])) 47 | v6 = bases.VectorInBasis(vs1.basis, np.array([1, 1, 1, 1])) 48 | self.assertNotEqual(v5, v6) 49 | self.assertNotEqual(v6, v5) 50 | v7 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4])) 51 | v8 = bases.VectorInBasis(vs1.basis, np.array([[1, 2, 3, 4], [1, 1, 1, 1]])) 52 | self.assertNotEqual(v7, v8) 53 | self.assertNotEqual(v8, v7) 54 | vs2 = bases.VectorSpaceWithBasis.from_names(["e", "f", "g", "h"]) 55 | v9 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4])) 56 | v10 = bases.VectorInBasis(vs2.basis, np.array([1, 2, 3, 4])) 57 | self.assertNotEqual(v9, v10) 58 | 
self.assertNotEqual(v10, v9) 59 | 60 | def test_dunders(self): 61 | vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c"]) 62 | v = bases.VectorInBasis(vs1.basis, np.array([0, 1, 2])) 63 | three = bases.VectorInBasis(vs1.basis, np.array([3, 3, 3])) 64 | five = bases.VectorInBasis(vs1.basis, np.array([5, 5, 5])) 65 | v_times_5 = bases.VectorInBasis(vs1.basis, np.array([0, 5, 10])) 66 | self.assertEqual(5 * v, v_times_5) 67 | self.assertEqual(v * 5, v_times_5) 68 | self.assertEqual(5.0 * v, v_times_5) 69 | self.assertEqual(v * 5.0, v_times_5) 70 | v_by_2 = bases.VectorInBasis(vs1.basis, np.array([0, 0.5, 1])) 71 | self.assertEqual(v / 2, v_by_2) 72 | self.assertEqual(v / 2.0, v_by_2) 73 | self.assertEqual(1 / 2 * v, v_by_2) 74 | v_plus_3 = bases.VectorInBasis(vs1.basis, np.array([3, 4, 5])) 75 | self.assertEqual(v + three, v_plus_3) 76 | self.assertEqual(three + v, v_plus_3) 77 | v_minus_5 = bases.VectorInBasis(vs1.basis, np.array([-5, -4, -3])) 78 | self.assertEqual(v - five, v_minus_5) 79 | minus_v = bases.VectorInBasis(vs1.basis, np.array([0, -1, -2])) 80 | self.assertEqual(-v, minus_v) 81 | 82 | def test_add_directions(self): 83 | vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c"]) 84 | expected = bases.VectorInBasis(vs1.basis, np.array([3, 4, 5])) 85 | v = bases.VectorInBasis(vs1.basis, np.array([0, 1, 2])) 86 | three = bases.VectorInBasis(vs1.basis, np.array([3, 3, 3])) 87 | shifted = v.add_directions(three) 88 | self.assertEqual(shifted, expected) 89 | 90 | 91 | class ProjectionTest(tests_common.VectorFnTestCase): 92 | 93 | def test_direct_sum_produces_expected_result(self): 94 | vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"]) 95 | vs2 = bases.VectorSpaceWithBasis.from_names(["d", "c"]) 96 | vs3 = bases.VectorSpaceWithBasis.from_names(["a", "b", "d", "c"]) 97 | self.assertEqual(bases.direct_sum(vs1, vs2), vs3) 98 | 99 | def test_join_vector_spaces_produces_expected_result(self): 100 | vs1 = bases.VectorSpaceWithBasis.from_names(["a", 
"b"]) 101 | vs2 = bases.VectorSpaceWithBasis.from_names(["d", "c"]) 102 | vs3 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"]) 103 | self.assertEqual(bases.join_vector_spaces(vs1, vs2), vs3) 104 | 105 | vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"]) 106 | vs2 = bases.VectorSpaceWithBasis.from_names(["b", "d", "c"]) 107 | vs3 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"]) 108 | self.assertEqual(bases.join_vector_spaces(vs1, vs2), vs3) 109 | 110 | def test_compare_vectors_with_differently_ordered_basis_vectors(self): 111 | basis1 = ["a", "b", "c", "d"] 112 | basis1 = [bases.BasisDirection(x) for x in basis1] 113 | basis2 = ["b", "d", "a", "c"] 114 | basis2 = [bases.BasisDirection(x) for x in basis2] 115 | vs1 = bases.VectorSpaceWithBasis(basis1) 116 | vs2 = bases.VectorSpaceWithBasis(basis2) 117 | v1 = bases.VectorInBasis(basis1, np.array([1, 2, 3, 4])) 118 | v2 = bases.VectorInBasis(basis2, np.array([2, 4, 1, 3])) 119 | self.assertEqual(v1, v2) 120 | self.assertEqual(v1 - v2, vs1.null_vector()) 121 | self.assertEqual(v1 - v2, vs2.null_vector()) 122 | self.assertEqual(v1 + v2, 2 * v2) 123 | self.assertIn(v1, vs1) 124 | self.assertIn(v1, vs2) 125 | self.assertIn(v2, vs1) 126 | self.assertIn(v2, vs2) 127 | 128 | def test_compare_vector_arrays_with_differently_ordered_basis_vectors(self): 129 | basis1 = ["a", "b", "c", "d"] 130 | basis1 = [bases.BasisDirection(x) for x in basis1] 131 | basis2 = ["b", "d", "a", "c"] 132 | basis2 = [bases.BasisDirection(x) for x in basis2] 133 | vs1 = bases.VectorSpaceWithBasis(basis1) 134 | vs2 = bases.VectorSpaceWithBasis(basis2) 135 | v1 = bases.VectorInBasis(basis1, np.array([[1, 2, 3, 4], [5, 6, 7, 8]])) 136 | v2 = bases.VectorInBasis(basis2, np.array([[2, 4, 1, 3], [6, 8, 5, 7]])) 137 | null_vec = bases.VectorInBasis.stack([vs1.null_vector(), vs2.null_vector()]) 138 | self.assertEqual(v1, v2) 139 | self.assertEqual(v1 - v2, null_vec) 140 | self.assertEqual(v1 + v2, 2 * v2) 141 | self.assertIn(v1, 
vs1) 142 | self.assertIn(v1, vs2) 143 | self.assertIn(v2, vs1) 144 | self.assertIn(v2, vs2) 145 | 146 | def test_projection_to_larger_space(self): 147 | vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"]) 148 | vs2 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"]) 149 | a1, b1 = vs1.basis_vectors() 150 | a2, b2, _, _ = vs2.basis_vectors() 151 | 152 | self.assertEqual(a1.project(vs2), a2) 153 | self.assertEqual(b1.project(vs2), b2) 154 | 155 | def test_projection_to_smaller_space(self): 156 | vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"]) 157 | vs2 = bases.VectorSpaceWithBasis.from_names(["a", "b"]) 158 | a1, b1, c1, d1 = vs1.basis_vectors() 159 | a2, b2 = vs2.basis_vectors() 160 | 161 | self.assertEqual(a1.project(vs2), a2) 162 | self.assertEqual(b1.project(vs2), b2) 163 | self.assertEqual(c1.project(vs2), vs2.null_vector()) 164 | self.assertEqual(d1.project(vs2), vs2.null_vector()) 165 | 166 | 167 | if __name__ == "__main__": 168 | absltest.main() 169 | -------------------------------------------------------------------------------- /tracr/craft/chamber/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /tracr/craft/chamber/categorical_attn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Attention head for categorical inputs.""" 16 | 17 | from typing import Optional 18 | 19 | from tracr.craft import bases 20 | from tracr.craft import transformers 21 | from tracr.craft import vectorspace_fns 22 | from typing_extensions import Protocol 23 | 24 | 25 | class QueryKeyToAttnLogit(Protocol): 26 | 27 | def __call__(self, query: bases.BasisDirection, 28 | key: bases.BasisDirection) -> bool: 29 | pass 30 | 31 | 32 | def categorical_attn( 33 | query_space: bases.VectorSpaceWithBasis, 34 | key_space: bases.VectorSpaceWithBasis, 35 | value_space: bases.VectorSpaceWithBasis, 36 | output_space: bases.VectorSpaceWithBasis, 37 | bos_space: bases.VectorSpaceWithBasis, 38 | one_space: bases.VectorSpaceWithBasis, 39 | attn_fn: QueryKeyToAttnLogit, 40 | default_output: Optional[bases.VectorInBasis] = None, 41 | causal: bool = False, 42 | always_attend_to_bos: bool = False, 43 | use_bos_for_default_output: bool = True, 44 | 
softmax_coldness: float = 100., 45 | ) -> transformers.AttentionHead: 46 | """Returns an attention head for categorical inputs. 47 | 48 | Assumes the existence of a beginning-of-sequence (BOS) token and always attends 49 | to it with strength 0.5*softmax_coldness. This allows us to implement an 50 | arbitrary default value for rows in the attention pattern that are all-zero. 51 | 52 | Attends to the BOS token if all other key-query pairs have zero attention. 53 | Hence, the value at the BOS position (the first in the sequence) will be the 54 | default output for such cases. 55 | 56 | Args: 57 | query_space: Vector space containing (categorical) query input. 58 | key_space: Vector space containing (categorical) key input. 59 | value_space: Vector space containing (numerical) value input. 60 | output_space: Vector space which will contain (numerical) output. 61 | bos_space: 1-d space used to identify the beginning of sequence token. 62 | one_space: 1-d space which contains 1 at every position. 63 | attn_fn: A selector function f(query, key) operating on the query/key basis 64 | directions that defines the attention pattern. 65 | default_output: Output to return if the attention pattern is all zero. 66 | causal: If True, use masked attention. 67 | always_attend_to_bos: If True, always attend to the BOS token. If False, 68 | only attend to BOS when attending to nothing else. 69 | use_bos_for_default_output: If True, assume BOS is not in the value space 70 | and output a default value when attending to BOS. If False, assume BOS is 71 | in the value space, and map it to the output space like any other token. 72 | softmax_coldness: The inverse temperature of the softmax. The default value 73 | is high, which makes the attention close to a hard maximum.
74 | """ 75 | bases.ensure_dims(bos_space, num_dims=1, name="bos_space") 76 | bases.ensure_dims(one_space, num_dims=1, name="one_space") 77 | bos_direction = bos_space.basis[0] 78 | one_direction = one_space.basis[0] 79 | 80 | # Add the bos direction to the query, key, and value spaces (and the one direction to the query space) in case they are missing 81 | query_space = bases.join_vector_spaces(query_space, bos_space, one_space) 82 | key_space = bases.join_vector_spaces(key_space, bos_space) 83 | value_space = bases.join_vector_spaces(value_space, bos_space) 84 | 85 | if always_attend_to_bos: 86 | value_basis = value_space.basis 87 | else: 88 | value_basis = [v for v in value_space.basis if v != bos_direction] 89 | assert len(value_basis) == output_space.num_dims 90 | value_to_output = dict(zip(value_basis, output_space.basis)) 91 | 92 | if default_output is None: 93 | default_output = output_space.null_vector() 94 | assert default_output in output_space 95 | 96 | def qk_fun(query: bases.BasisDirection, key: bases.BasisDirection) -> float: 97 | 98 | # We want to enforce the following property on our attention patterns: 99 | # - if nothing else is attended to, attend to the BOS token. 100 | # - otherwise, don't attend to the BOS token. 101 | # 102 | # We assume that the BOS position only ever contains the vector bos + one, 103 | # and that any other position has bos coefficient 0. 104 | # 105 | # We do this as follows: 106 | # Let Q and K be subspaces of V containing the query and key vectors, 107 | # both disjoint from the BOS space {bos} and the one space {one}. 108 | # Suppose we have an attn_fn which defines a bilinear form W_QK: V x V -> ℝ, 109 | # s.t. W_QK(q, k) = 0 whenever either q or k is bos or one. 110 | # 111 | # Then define W_new: V x V -> ℝ s.t.: 112 | # W_new(one, bos) = 0.5, otherwise 0. 113 | # 114 | # Now set W_QK' = W_QK + W_new.
115 | # 116 | # To evaluate the attention to the BOS position: 117 | # W_QK'(q, bos + one) 118 | # = W_QK'(q, bos) + W_QK'(q, one) 119 | # = W_QK(q, bos) + W_QK(q, one) + W_new(q, bos) + W_new(q, one) 120 | # = 0 + 0 + W_new(q, bos) + W_new(q, one) 121 | # = W_new(q, bos) + W_new(q, one) 122 | # = W_new(q' + one, bos) + W_new(q' + one, one) where q = one + q' 123 | # = W_new(q', bos) + W_new(one, bos) + W_new(q', one) + W_new(one, one) 124 | # = 0 + 0.5 + 0 + 0 125 | # = 0.5 126 | # 127 | # To evaluate the attention to a non-BOS position: 128 | # W_QK'(0 * bos + q, 0 * bos + k) # s.t. q ∈ Q+{one}, k ∈ K+{one} 129 | # = 0*W_QK'(bos, 0*bos + k) + W_QK'(q, 0*bos + k) 130 | # = W_QK'(q, 0*bos + k) 131 | # = 0*W_QK'(q, bos) + W_QK'(q, k) 132 | # = W_QK'(q, k) 133 | # = W_QK(q, k) since W_QK' = W_QK on inputs not containing bos. 134 | # = W_QK(q', k') since W_QK(x, y) = 0 whenever x or y is one. 135 | # 136 | # Since W_QK(q, k) takes values in {0, 1}, a sufficiently high softmax 137 | # coldness will give us the desired property. QED 138 | # 139 | # The following implements this idea. 140 | # By replacing 0.5 with 1, we can instead enforce a different property: that 141 | # the BOS token is always attended to in addition to whatever else is selected. 142 | 143 | if key == bos_direction and query == one_direction: 144 | c = 1.
if always_attend_to_bos else 0.5 145 | return c * softmax_coldness 146 | elif {key, query}.intersection({one_direction, bos_direction}): 147 | return 0 148 | 149 | return softmax_coldness * attn_fn(query, key) 150 | 151 | w_qk = vectorspace_fns.ScalarBilinear.from_action( 152 | query_space, 153 | key_space, 154 | qk_fun, 155 | ) 156 | 157 | def ov_fun(input_dir: bases.BasisDirection) -> bases.VectorInBasis: 158 | if use_bos_for_default_output and input_dir == bos_direction: 159 | return default_output 160 | return output_space.vector_from_basis_direction(value_to_output[input_dir]) 161 | 162 | w_ov = vectorspace_fns.Linear.from_action( 163 | value_space, 164 | output_space, 165 | ov_fun, 166 | ) 167 | 168 | return transformers.AttentionHead(w_qk, w_ov, causal=causal) 169 | -------------------------------------------------------------------------------- /tracr/craft/chamber/categorical_attn_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
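The BOS trick derived in the `qk_fun` comments can be checked numerically. The following is a plain-NumPy sketch, not tracr code: `softmax` and `bos_attention_weights` are hypothetical helpers that reproduce only the logits the comments describe (BOS key at `0.5 * softmax_coldness`, every other key at `softmax_coldness * attn_fn(query, key)`), and the BOS slot simply holds an assumed default value instead of going through `W_OV`.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def bos_attention_weights(selector: np.ndarray,
                          coldness: float = 100.0,
                          bos_strength: float = 0.5) -> np.ndarray:
    """Attention weights over [BOS, token_1, ..., token_n] for each query row.

    selector[i, j] plays the role of attn_fn(query_i, key_j) for non-BOS keys.
    The BOS key gets logit bos_strength * coldness (the W_new(one, bos) term);
    every other key gets coldness * selector[i, j].
    """
    n_queries = selector.shape[0]
    bos_logits = np.full((n_queries, 1), bos_strength * coldness)
    token_logits = coldness * selector.astype(float)
    return softmax(np.concatenate([bos_logits, token_logits], axis=1))


selector = np.array([
    [True, False, True],    # query selects two tokens
    [False, False, False],  # query selects nothing
])
weights = bos_attention_weights(selector)
# Row 0 is close to [0, 0.5, 0, 0.5]: BOS is effectively ignored.
# Row 1 is close to [1, 0, 0, 0]: all weight falls back to BOS.

# BOS carries the default output, the other slots carry the token values.
values = np.array([-7.0, 1.0, 2.0, 3.0])  # -7.0 is a hypothetical default
print(weights @ values)  # approximately [2.0, -7.0]
```

Passing `bos_strength=1.0` mirrors the `always_attend_to_bos=True` branch (`c = 1.`): BOS then ties with every selected key instead of losing to it.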
14 | # ============================================================================== 15 | """Tests for chamber.categorical_attn.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import numpy as np 20 | from tracr.craft import bases 21 | from tracr.craft import tests_common 22 | from tracr.craft.chamber import categorical_attn 23 | 24 | 25 | class CategoricalAttnTest(tests_common.VectorFnTestCase): 26 | 27 | @parameterized.parameters([ 28 | dict(causal=False, input_seq=[1, 2, 3, 4, 5], result_seq=[3, 3, 3, 3, 3]), 29 | dict( 30 | causal=True, 31 | input_seq=[1, 2, 3, 4, 5], 32 | result_seq=[1, 1.5, 2, 2.5, 3]), 33 | dict(causal=False, input_seq=[10], result_seq=[10]), 34 | dict(causal=True, input_seq=[10], result_seq=[10]), 35 | dict(causal=False, input_seq=[-1, 0, 1], result_seq=[0, 0, 0]), 36 | dict(causal=True, input_seq=[-1, 0, 1], result_seq=[-1, -0.5, 0]), 37 | ]) 38 | def test_categorical_attn_can_implement_select_all(self, causal, input_seq, 39 | result_seq): 40 | vocab = range(-20, 20) 41 | input_space = bases.VectorSpaceWithBasis.from_values("input", vocab) 42 | 43 | output_dir = bases.BasisDirection("output") 44 | output_space = bases.VectorSpaceWithBasis([output_dir]) 45 | output_vec = output_space.vector_from_basis_direction(output_dir) 46 | 47 | bos_dir = bases.BasisDirection("bos_dimension") 48 | bos_space = bases.VectorSpaceWithBasis([bos_dir]) 49 | 50 | one_dir = bases.BasisDirection("one") 51 | one_space = bases.VectorSpaceWithBasis([one_dir]) 52 | 53 | value_dir = bases.BasisDirection("value") 54 | value_space = bases.VectorSpaceWithBasis([value_dir]) 55 | 56 | input_space = bases.join_vector_spaces(input_space, bos_space, one_space) 57 | value_space = bases.join_vector_spaces(value_space, bos_space) 58 | residual_space = bases.join_vector_spaces(input_space, value_space, 59 | output_space) 60 | one_vec = residual_space.vector_from_basis_direction(one_dir) 61 | bos_vec = 
residual_space.vector_from_basis_direction(bos_dir) 62 | value_vec = residual_space.vector_from_basis_direction(value_dir) 63 | 64 | attn = categorical_attn.categorical_attn( 65 | key_space=input_space, 66 | query_space=input_space, 67 | value_space=value_space, 68 | output_space=output_space, 69 | bos_space=bos_space, 70 | one_space=one_space, 71 | attn_fn=lambda x, y: True, 72 | causal=causal) 73 | 74 | test_inputs = [bos_vec + one_vec] 75 | for x in input_seq: 76 | test_inputs.append( 77 | residual_space.vector_from_basis_direction( 78 | bases.BasisDirection("input", x)) + x * value_vec) 79 | test_inputs = bases.VectorInBasis.stack(test_inputs) 80 | 81 | # Expect the average of all (previous) tokens 82 | expected_results = [x * output_vec for x in result_seq] 83 | expected_results = bases.VectorInBasis.stack(expected_results) 84 | 85 | test_outputs = attn.apply(test_inputs).project(output_space) 86 | 87 | self.assertVectorAllClose( 88 | tests_common.strip_bos_token(test_outputs), expected_results) 89 | 90 | @parameterized.parameters([ 91 | dict(causal=False, input_seq=[1, 2, 3, 4, 5], default=0), 92 | dict(causal=True, input_seq=[1, 2, 3, 4, 5], default=1), 93 | dict(causal=False, input_seq=[10], default=2), 94 | dict(causal=True, input_seq=[10], default=-3), 95 | dict(causal=False, input_seq=[-1, 0, 1], default=-2), 96 | dict(causal=True, input_seq=[-1, 0, 1], default=-1), 97 | ]) 98 | def test_categorical_attn_can_implement_select_none(self, causal, input_seq, 99 | default): 100 | vocab = range(-20, 20) 101 | input_space = bases.VectorSpaceWithBasis.from_values("input", vocab) 102 | 103 | output_dir = bases.BasisDirection("output") 104 | output_space = bases.VectorSpaceWithBasis([output_dir]) 105 | default_vec = default * output_space.vector_from_basis_direction(output_dir) 106 | 107 | bos_dir = bases.BasisDirection("bos_dimension") 108 | bos_space = bases.VectorSpaceWithBasis([bos_dir]) 109 | 110 | one_dir = bases.BasisDirection("one") 111 | one_space = 
bases.VectorSpaceWithBasis([one_dir]) 112 | 113 | value_dir = bases.BasisDirection("value") 114 | value_space = bases.VectorSpaceWithBasis([value_dir]) 115 | 116 | input_space = bases.join_vector_spaces(input_space, bos_space, one_space) 117 | value_space = bases.join_vector_spaces(value_space, bos_space) 118 | residual_space = bases.join_vector_spaces(input_space, value_space, 119 | output_space) 120 | value_vec = residual_space.vector_from_basis_direction(value_dir) 121 | bos_vec = residual_space.vector_from_basis_direction(bos_dir) 122 | one_vec = residual_space.vector_from_basis_direction(one_dir) 123 | 124 | attn = categorical_attn.categorical_attn( 125 | key_space=input_space, 126 | query_space=input_space, 127 | value_space=value_space, 128 | output_space=output_space, 129 | bos_space=bos_space, 130 | one_space=one_space, 131 | attn_fn=lambda x, y: False, 132 | default_output=default_vec, 133 | causal=causal, 134 | always_attend_to_bos=False, 135 | use_bos_for_default_output=True) 136 | 137 | def make_input(x): 138 | return (one_vec + x * value_vec + 139 | residual_space.vector_from_basis_direction( 140 | bases.BasisDirection("input", x))) 141 | 142 | test_inputs = bases.VectorInBasis.stack([bos_vec + one_vec] + 143 | [make_input(x) for x in input_seq]) 144 | 145 | # Expect the default value 146 | expected_results = [default_vec for x in input_seq] 147 | expected_results = bases.VectorInBasis.stack(expected_results) 148 | 149 | test_outputs = attn.apply(test_inputs).project(output_space) 150 | 151 | self.assertVectorAllClose( 152 | tests_common.strip_bos_token(test_outputs), expected_results) 153 | 154 | @parameterized.parameters([ 155 | dict(num_counts=5, input_seq=[1, 4, 3, 2], n=1, result=[4, 3, 2, 1]), 156 | dict(num_counts=10, input_seq=[5, 8, 9, 2], n=3, result=[2, 5, 8, 9]) 157 | ]) 158 | def test_categorical_attn_can_implement_shift_by_n(self, num_counts, 159 | input_seq, n, result): 160 | query_prefix = "prefix1" 161 | key_prefix = "prefix2" 162 | 
agg_input_prefix = "prefix3" 163 | output_prefix = "prefix4" 164 | 165 | bos_direction = bases.BasisDirection("bos") 166 | one_direction = bases.BasisDirection("one") 167 | query_space = bases.VectorSpaceWithBasis.from_values( 168 | query_prefix, range(num_counts)) 169 | key_space = bases.VectorSpaceWithBasis.from_values(key_prefix, 170 | range(num_counts)) 171 | bos_space = bases.VectorSpaceWithBasis([bos_direction]) 172 | one_space = bases.VectorSpaceWithBasis([one_direction]) 173 | key_space = bases.join_vector_spaces(key_space, bos_space) 174 | 175 | agg_input_space = bases.VectorSpaceWithBasis.from_values( 176 | agg_input_prefix, range(num_counts)) 177 | agg_input_space = bases.join_vector_spaces(agg_input_space, bos_space) 178 | output_space = bases.VectorSpaceWithBasis.from_values( 179 | output_prefix, range(num_counts)) 180 | 181 | attn = categorical_attn.categorical_attn( 182 | query_space=query_space, 183 | key_space=key_space, 184 | value_space=agg_input_space, 185 | output_space=output_space, 186 | bos_space=bos_space, 187 | one_space=one_space, 188 | attn_fn=lambda q, k: q.value == k.value, 189 | default_output=None, 190 | always_attend_to_bos=False, 191 | use_bos_for_default_output=True, 192 | causal=False) 193 | 194 | residual_space = bases.join_vector_spaces(key_space, query_space, 195 | agg_input_space, output_space, 196 | one_space) 197 | 198 | seq_len = len(input_seq) 199 | query_seq = np.arange(n, seq_len + n) % seq_len 200 | key_seq = np.arange(seq_len) 201 | 202 | bos_vec = residual_space.vector_from_basis_direction(bos_direction) 203 | one_vec = residual_space.vector_from_basis_direction(one_direction) 204 | 205 | test_inputs = [bos_vec + one_vec] 206 | expected_results = [] 207 | for i in range(seq_len): 208 | test_inputs.append( 209 | residual_space.vector_from_basis_direction( 210 | bases.BasisDirection(query_prefix, query_seq[i])) + 211 | residual_space.vector_from_basis_direction( 212 | bases.BasisDirection(key_prefix, key_seq[i])) + 213 
| residual_space.vector_from_basis_direction( 214 | bases.BasisDirection(agg_input_prefix, input_seq[i]))) 215 | expected_results.append( 216 | residual_space.vector_from_basis_direction( 217 | bases.BasisDirection(output_prefix, result[i]))) 218 | 219 | test_inputs = bases.VectorInBasis.stack(test_inputs) 220 | expected_results = bases.VectorInBasis.stack(expected_results) 221 | 222 | test_outputs = attn.apply(test_inputs) 223 | 224 | self.assertVectorAllClose( 225 | tests_common.strip_bos_token(test_outputs), expected_results) 226 | 227 | 228 | if __name__ == "__main__": 229 | absltest.main() 230 | -------------------------------------------------------------------------------- /tracr/craft/chamber/categorical_mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
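The shift-by-n test builds query labels `(i + n) % seq_len` and key labels `i`, then selects with `q.value == k.value`. The NumPy sketch below is an illustration, not the tracr implementation; `shift_by_n` is a hypothetical helper. It reproduces that attention pattern on plain numbers and omits the BOS column, which is harmless here because every query matches exactly one key and so outscores the BOS logit of `0.5 * coldness`.

```python
import numpy as np


def shift_by_n(values, n: int, coldness: float = 100.0) -> np.ndarray:
    """Rotate `values` left by n using an equality-selector attention pattern.

    Query i carries the label (i + n) % seq_len and key j carries the label j,
    mirroring the query/key sequences built in the shift-by-n test.
    """
    values = np.asarray(values, dtype=float)
    seq_len = len(values)
    query_labels = np.arange(n, seq_len + n) % seq_len
    key_labels = np.arange(seq_len)
    # Selector: a logit of `coldness` where the labels match, 0 elsewhere.
    logits = coldness * (query_labels[:, None] == key_labels[None, :])
    logits = logits - logits.max(axis=1, keepdims=True)
    weights = np.exp(logits)
    weights /= weights.sum(axis=1, keepdims=True)
    # Each row is (almost) one-hot, so this picks values[(i + n) % seq_len].
    return weights @ values


print(shift_by_n([1, 4, 3, 2], n=1))  # approximately [4. 3. 2. 1.]
```

In the test itself the values are one-hot categorical vectors rather than scalars, but the same near-one-hot attention row copies the selected vector.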
14 | # ============================================================================== 15 | """MLP to compute basic linear functions of one-hot encoded integers.""" 16 | 17 | from typing import Callable 18 | 19 | import numpy as np 20 | 21 | from tracr.craft import bases 22 | from tracr.craft import transformers 23 | from tracr.craft import vectorspace_fns 24 | 25 | _ONE_SPACE = bases.VectorSpaceWithBasis.from_names(["one"]) 26 | 27 | 28 | def map_categorical_mlp( 29 | input_space: bases.VectorSpaceWithBasis, 30 | output_space: bases.VectorSpaceWithBasis, 31 | operation: Callable[[bases.BasisDirection], bases.BasisDirection], 32 | ) -> transformers.MLP: 33 | """Returns an MLP that encodes any categorical function of a single variable f(x). 34 | 35 | The first layer is a lookup table mapping each input direction i to the output 36 | direction f(i), and the second layer is the identity: output_k = sum(input_i for all i with f(i) = k) 37 | 38 | Args: 39 | input_space: Vector space containing the input x. 40 | output_space: Vector space containing the possible outputs. 41 | operation: A function operating on basis directions. Results outside output_space are mapped to zero. 42 | """ 43 | 44 | def operation_fn(direction): 45 | if direction in input_space: 46 | output_direction = operation(direction) 47 | if output_direction in output_space: 48 | return output_space.vector_from_basis_direction(output_direction) 49 | return output_space.null_vector() 50 | 51 | first_layer = vectorspace_fns.Linear.from_action(input_space, output_space, 52 | operation_fn) 53 | 54 | second_layer = vectorspace_fns.project(output_space, output_space) 55 | 56 | return transformers.MLP(first_layer, second_layer) 57 | 58 | 59 | def map_categorical_to_numerical_mlp( 60 | input_space: bases.VectorSpaceWithBasis, 61 | output_space: bases.VectorSpaceWithBasis, 62 | operation: Callable[[bases.Value], float], 63 | ) -> transformers.MLP: 64 | """Returns an MLP to compute f(x) from a categorical to a numerical variable.
65 | 66 | The first layer is a lookup table that writes f(i) to the single output 67 | dimension for each input direction i, and the second layer is the identity: output = sum(f(i)*input_i for all i in input space) 68 | 69 | Args: 70 | input_space: Vector space containing the input x. 71 | output_space: Vector space to write the numerical output to. 72 | operation: A function mapping the value of an input basis direction to a float. 73 | """ 74 | bases.ensure_dims(output_space, num_dims=1, name="output_space") 75 | out_vec = output_space.vector_from_basis_direction(output_space.basis[0]) 76 | 77 | def operation_fn(direction): 78 | if direction in input_space: 79 | return operation(direction.value) * out_vec 80 | return output_space.null_vector() 81 | 82 | first_layer = vectorspace_fns.Linear.from_action(input_space, output_space, 83 | operation_fn) 84 | 85 | second_layer = vectorspace_fns.project(output_space, output_space) 86 | 87 | return transformers.MLP(first_layer, second_layer) 88 | 89 | 90 | def sequence_map_categorical_mlp( 91 | input1_space: bases.VectorSpaceWithBasis, 92 | input2_space: bases.VectorSpaceWithBasis, 93 | output_space: bases.VectorSpaceWithBasis, 94 | operation: Callable[[bases.BasisDirection, bases.BasisDirection], 95 | bases.BasisDirection], 96 | one_space: bases.VectorSpaceWithBasis = _ONE_SPACE, 97 | hidden_name: bases.Name = "__hidden__", 98 | ) -> transformers.MLP: 99 | """Returns an MLP that encodes a categorical function of two variables f(x, y). 100 | 101 | The hidden layer of the MLP computes the logical AND of every pair of input directions (one from each input space) 102 | hidden_i_j = ReLU(x_i + y_j - 1) 103 | 104 | The output layer then combines this with a lookup table 105 | output_k = sum(hidden_i_j for all i, j with f(i, j) = k) 106 | 107 | Args: 108 | input1_space: Vector space containing the input x. 109 | input2_space: Vector space containing the input y. 110 | output_space: Vector space to write outputs to. 111 | operation: A function mapping a pair of input basis directions to an output basis direction. 112 | one_space: A reserved 1-d space that always contains a 1.
113 | hidden_name: Name for hidden dimensions. 114 | """ 115 | bases.ensure_dims(one_space, num_dims=1, name="one_space") 116 | 117 | if not set(input1_space.basis).isdisjoint(input2_space.basis): 118 | raise ValueError("Input spaces to a SequenceMap must be disjoint. " 119 | "If input spaces are the same, use Map instead!") 120 | 121 | input_space = bases.direct_sum(input1_space, input2_space, one_space) 122 | 123 | def to_hidden(x, y): 124 | return bases.BasisDirection(hidden_name, (x.name, x.value, y.name, y.value)) 125 | 126 | def from_hidden(h): 127 | x_name, x_value, y_name, y_value = h.value 128 | x_dir = bases.BasisDirection(x_name, x_value) 129 | y_dir = bases.BasisDirection(y_name, y_value) 130 | return x_dir, y_dir 131 | 132 | hidden_dir = [] 133 | for dir1 in input1_space.basis: 134 | for dir2 in input2_space.basis: 135 | hidden_dir.append(to_hidden(dir1, dir2)) 136 | hidden_space = bases.VectorSpaceWithBasis(hidden_dir) 137 | 138 | def logical_and(direction): 139 | if direction in one_space: 140 | out = bases.VectorInBasis(hidden_space.basis, 141 | -np.ones(hidden_space.num_dims)) 142 | elif direction in input1_space: 143 | dir1 = direction 144 | out = hidden_space.null_vector() 145 | for dir2 in input2_space.basis: 146 | vector = bases.VectorInBasis( 147 | [to_hidden(dir1, dir2)], np.array([1]), _basis_is_sorted=True 148 | ) 149 | out = out.add_directions(vector) 150 | else: 151 | dir2 = direction 152 | out = hidden_space.null_vector() 153 | for dir1 in input1_space.basis: 154 | vector = bases.VectorInBasis( 155 | [to_hidden(dir1, dir2)], np.array([1]), _basis_is_sorted=True 156 | ) 157 | out = out.add_directions(vector) 158 | return out 159 | 160 | first_layer = vectorspace_fns.Linear.from_action(input_space, hidden_space, 161 | logical_and) 162 | 163 | def operation_fn(direction): 164 | dir1, dir2 = from_hidden(direction) 165 | output_direction = operation(dir1, dir2) 166 | if output_direction in output_space: 167 | return 
output_space.vector_from_basis_direction(output_direction) 168 | else: 169 | return output_space.null_vector() 170 | 171 | second_layer = vectorspace_fns.Linear.from_action(hidden_space, output_space, 172 | operation_fn) 173 | 174 | return transformers.MLP(first_layer, second_layer) 175 | -------------------------------------------------------------------------------- /tracr/craft/chamber/categorical_mlp_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
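The weight construction in `sequence_map_categorical_mlp` (a `+1` from each input direction, a `-1` from the `one` direction, then a ReLU) can be mirrored with plain matrices. This is a minimal NumPy sketch under stated assumptions, not tracr's API: inputs are laid out as `[one-hot x | one-hot y | constant 1]`, and `sequence_map_mlp` and `apply_mlp` are hypothetical helpers. Because x and y occupy separate slots, the same value list can serve as both vocabularies, whereas tracr requires the two input spaces to be disjoint.

```python
import itertools

import numpy as np


def sequence_map_mlp(x_vocab, y_vocab, out_vocab, fn):
    """Build weights for a two-layer ReLU MLP computing fn(x, y) on one-hot inputs.

    Hidden unit (i, j) computes ReLU(x_i + y_j - 1), which is 1 exactly when
    x == x_vocab[i] and y == y_vocab[j], and 0 otherwise.
    """
    pairs = list(itertools.product(range(len(x_vocab)), range(len(y_vocab))))
    n_in = len(x_vocab) + len(y_vocab) + 1
    w1 = np.zeros((n_in, len(pairs)))
    w2 = np.zeros((len(pairs), len(out_vocab)))
    for h, (i, j) in enumerate(pairs):
        w1[i, h] = 1.0                  # x_i contributes +1
        w1[len(x_vocab) + j, h] = 1.0   # y_j contributes +1
        w1[-1, h] = -1.0                # constant "one" contributes -1
        out = fn(x_vocab[i], y_vocab[j])
        if out in out_vocab:            # results outside the output space map to zero
            w2[h, out_vocab.index(out)] = 1.0
    return w1, w2


def apply_mlp(w1, w2, x_vocab, y_vocab, x, y):
    """One-hot encode (x, y) with a trailing 1 and run the ReLU MLP."""
    inp = np.zeros(w1.shape[0])
    inp[x_vocab.index(x)] = 1.0
    inp[len(x_vocab) + y_vocab.index(y)] = 1.0
    inp[-1] = 1.0
    hidden = np.maximum(0.0, inp @ w1)
    return hidden @ w2


vocab = list(range(5))
w1, w2 = sequence_map_mlp(vocab, vocab, vocab, lambda a, b: a + b)
print(apply_mlp(w1, w2, vocab, vocab, 1, 2))  # one-hot vector for 3
```

A pair that matches only one of the two inputs gives `1 + 0 - 1 = 0`, and a pair matching neither gives `-1`. The ReLU clips both to zero, so only the matching pair survives into the lookup table.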
14 | # ============================================================================== 15 | """Tests for chamber.categorical_mlp.""" 16 | 17 | import math 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | 21 | from tracr.craft import bases 22 | from tracr.craft import tests_common 23 | from tracr.craft.chamber import categorical_mlp 24 | 25 | 26 | class CategoricalInputMlpTest(tests_common.VectorFnTestCase): 27 | 28 | @parameterized.parameters([ 29 | dict(num_counts=4, x=1, y=2, fun=lambda x, y: x + y, result=3), 30 | dict(num_counts=4, x=1, y=0, fun=lambda x, y: x + y + 1, result=2), 31 | dict(num_counts=5, x=2, y=1, fun=math.pow, result=2), 32 | dict(num_counts=5, x=2, y=2, fun=math.pow, result=4), 33 | ]) 34 | def test_seq_map_categorical_mlp_produces_expected_outcome( 35 | self, num_counts, x, y, fun, result): 36 | input1_name = "in1" 37 | input2_name = "in2" 38 | output_name = "out" 39 | one_name = "one_dimension" 40 | 41 | in1_space = bases.VectorSpaceWithBasis.from_values(input1_name, 42 | range(num_counts + 1)) 43 | in2_space = bases.VectorSpaceWithBasis.from_values(input2_name, 44 | range(num_counts + 1)) 45 | out_space = bases.VectorSpaceWithBasis.from_values(output_name, 46 | range(num_counts + 1)) 47 | 48 | def operation(in1, in2): 49 | out_val = fun(int(in1.value), int(in2.value)) 50 | return bases.BasisDirection(output_name, out_val) 51 | 52 | mlp = categorical_mlp.sequence_map_categorical_mlp( 53 | input1_space=in1_space, 54 | input2_space=in2_space, 55 | output_space=out_space, 56 | operation=operation, 57 | one_space=bases.VectorSpaceWithBasis.from_names([one_name])) 58 | 59 | test_inputs = ( 60 | mlp.residual_space.vector_from_basis_direction( 61 | bases.BasisDirection(one_name)) + 62 | mlp.residual_space.vector_from_basis_direction( 63 | bases.BasisDirection(input1_name, x)) + 64 | mlp.residual_space.vector_from_basis_direction( 65 | bases.BasisDirection(input2_name, y))) 66 | 67 | expected_results = 
mlp.residual_space.vector_from_basis_direction( 68 | bases.BasisDirection(output_name, result)) 69 | 70 | test_outputs = mlp.apply(test_inputs) 71 | 72 | self.assertVectorAllClose(test_outputs, expected_results) 73 | 74 | def test_seq_map_categorical_mlp_raises_error_with_overlapping_inputs(self): 75 | input_name = "in" 76 | output_name = "out" 77 | one_name = "one_dimension" 78 | 79 | in1_space = bases.VectorSpaceWithBasis.from_values(input_name, range(5)) 80 | in2_space = bases.VectorSpaceWithBasis.from_values(input_name, range(3, 10)) 81 | out_space = bases.VectorSpaceWithBasis.from_values(output_name, range(5)) 82 | 83 | with self.assertRaisesRegex( 84 | ValueError, r".*Input spaces to a SequenceMap must be disjoint.*"): 85 | categorical_mlp.sequence_map_categorical_mlp( 86 | input1_space=in1_space, 87 | input2_space=in1_space, 88 | output_space=out_space, 89 | operation=lambda x, y: bases.BasisDirection(output_name, 0), 90 | one_space=bases.VectorSpaceWithBasis.from_names([one_name])) 91 | 92 | with self.assertRaisesRegex( 93 | ValueError, r".*Input spaces to a SequenceMap must be disjoint.*"): 94 | categorical_mlp.sequence_map_categorical_mlp( 95 | input1_space=in1_space, 96 | input2_space=in2_space, 97 | output_space=out_space, 98 | operation=lambda x, y: bases.BasisDirection(output_name, 0), 99 | one_space=bases.VectorSpaceWithBasis.from_names([one_name])) 100 | 101 | @parameterized.parameters([ 102 | dict(num_counts=5, x=2, fun=lambda x: x, result=2), 103 | dict(num_counts=5, x=2, fun=lambda x: math.pow(x, int(2)), result=4), 104 | dict(num_counts=5, x=-2, fun=lambda x: math.pow(x, int(2)), result=4), 105 | dict(num_counts=5, x=-1, fun=lambda x: math.pow(x, int(3)), result=-1), 106 | ]) 107 | def test_map_categorical_mlp_produces_expected_outcome_computing_powers( 108 | self, num_counts, x, fun, result): 109 | input_name = "in" 110 | output_name = "out" 111 | 112 | in_space = bases.VectorSpaceWithBasis.from_values( 113 | input_name, range(-num_counts, 
num_counts + 1)) 114 | out_space = bases.VectorSpaceWithBasis.from_values( 115 | output_name, range(-num_counts, num_counts + 1)) 116 | 117 | def operation(direction): 118 | out_val = fun(int(direction.value)) 119 | return bases.BasisDirection(output_name, out_val) 120 | 121 | mlp = categorical_mlp.map_categorical_mlp( 122 | input_space=in_space, output_space=out_space, operation=operation) 123 | 124 | test_inputs = mlp.residual_space.vector_from_basis_direction( 125 | bases.BasisDirection(input_name, x)) 126 | 127 | expected_results = mlp.residual_space.vector_from_basis_direction( 128 | bases.BasisDirection(output_name, result)) 129 | 130 | test_outputs = mlp.apply(test_inputs) 131 | 132 | self.assertVectorAllClose(test_outputs, expected_results) 133 | 134 | @parameterized.parameters([ 135 | dict(x=2, fun=lambda x: x, result=2), 136 | dict(x=2, fun=lambda x: math.pow(x, int(2)), result=4), 137 | dict(x=1, fun=lambda x: 1 / (x + 1), result=0.5), 138 | dict(x=3, fun=lambda x: 1 / (x + 1), result=0.25), 139 | ]) 140 | def test_map_categorical_to_numerical_mlp_produces_expected_outcome( 141 | self, x, fun, result): 142 | 143 | in_space = bases.VectorSpaceWithBasis.from_values("in", range(6)) 144 | out_space = bases.VectorSpaceWithBasis.from_names(["out"]) 145 | 146 | mlp = categorical_mlp.map_categorical_to_numerical_mlp( 147 | input_space=in_space, 148 | output_space=out_space, 149 | operation=fun, 150 | ) 151 | 152 | test_inputs = mlp.residual_space.vector_from_basis_direction( 153 | bases.BasisDirection("in", x)) 154 | 155 | expected_results = result * mlp.residual_space.vector_from_basis_direction( 156 | bases.BasisDirection("out")) 157 | 158 | test_outputs = mlp.apply(test_inputs) 159 | 160 | self.assertVectorAllClose(test_outputs, expected_results) 161 | 162 | 163 | if __name__ == "__main__": 164 | absltest.main() 165 | -------------------------------------------------------------------------------- /tracr/craft/chamber/numerical_mlp_test.py: 
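`map_categorical_to_numerical_mlp` stores f(x) for every possible input value in a single-column first layer and uses an identity second layer. The NumPy sketch below illustrates that lookup table; it is not the tracr implementation. `categorical_to_numerical_weights` and `apply_lookup_mlp` are hypothetical helpers, and the ReLU between the layers is assumed from the `ReLU(...)` hidden layer described in the `sequence_map_categorical_mlp` docstring.

```python
import numpy as np


def categorical_to_numerical_weights(vocab, fn):
    """First-layer weights: row k stores fn(vocab[k]) in the single output dim."""
    return np.array([[fn(v)] for v in vocab], dtype=float)


def apply_lookup_mlp(w1, x_index):
    """One-hot encode x, apply the lookup layer, a ReLU, then the identity."""
    one_hot = np.eye(w1.shape[0])[x_index]
    hidden = np.maximum(0.0, one_hot @ w1)  # ReLU between the two layers
    return hidden @ np.eye(w1.shape[1])     # second layer: identity projection


vocab = list(range(6))
w1 = categorical_to_numerical_weights(vocab, lambda x: 1 / (x + 1))
print(apply_lookup_mlp(w1, vocab.index(1)))  # [0.5]
```

In this sketch, a negative fn(x) would be zeroed by the ReLU. All the test cases above use non-negative outputs.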
-------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for chamber.numerical_mlp.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import numpy as np 20 | from tracr.craft import bases 21 | from tracr.craft import tests_common 22 | from tracr.craft.chamber import numerical_mlp 23 | from tracr.utils import errors 24 | 25 | 26 | class NumericalMlpTest(tests_common.VectorFnTestCase): 27 | 28 | @parameterized.parameters([ 29 | dict( 30 | in_value_set={-2, -2, -1, 0, 1, 2, 3}, 31 | x=2, 32 | function=lambda x: x, 33 | result=2), 34 | dict( 35 | in_value_set={-2, -2, -1, 0, 1, 2, 3}, 36 | x=2, 37 | function=lambda x: x**2, 38 | result=4), 39 | dict( 40 | in_value_set={-2, -2, -1, 0, 1, 2, 3}, 41 | x=2, 42 | function=lambda x: x**3, 43 | result=8), 44 | dict( 45 | in_value_set={-2, -2, -1, 0, 1, 2, 3}, 46 | x=-2, 47 | function=lambda x: x, 48 | result=-2), 49 | dict( 50 | in_value_set={-2, -2, -1, 0, 1, 2, 3}, 51 | x=-2, 52 | function=lambda x: x**2, 53 | result=4), 54 | dict( 55 | in_value_set={-2, -2, -1, 0, 1, 2, 3}, 56 | x=-2, 57 | function=lambda x: x**3, 58 | result=-8), 59 | ]) 60 | def test_map_numerical_mlp_produces_expected_outcome(self, 
      in_value_set, x,
      function, result):

    input_dir = bases.BasisDirection("input")
    output_dir = bases.BasisDirection("output")
    one_dir = bases.BasisDirection("one")
    input_space = bases.VectorSpaceWithBasis([input_dir])
    output_space = bases.VectorSpaceWithBasis([output_dir])
    one_space = bases.VectorSpaceWithBasis([one_dir])

    mlp = numerical_mlp.map_numerical_mlp(
        f=function,
        input_space=input_space,
        output_space=output_space,
        one_space=one_space,
        input_value_set=in_value_set,
    )

    test_inputs = bases.VectorInBasis(
        basis_directions=[input_dir, output_dir, one_dir],
        magnitudes=np.array([x, 0, 1]))

    expected_results = bases.VectorInBasis(
        basis_directions=[input_dir, output_dir, one_dir],
        magnitudes=np.array([0, result, 0]))

    test_outputs = mlp.apply(test_inputs)

    self.assertVectorAllClose(test_outputs, expected_results)

  @parameterized.parameters([
      dict(in_value_set={0, 1, 2, 3}, x=1, function=lambda x: 1 / x, result=1),
      dict(
          in_value_set={0, 1, 2, 3}, x=2, function=lambda x: 1 / x, result=0.5),
      dict(
          in_value_set={0, 1, 2, 3},
          x=3,
          function=lambda x: 1 / x,
          result=1 / 3),
  ])
  def test_map_numerical_mlp_logs_warning_and_produces_expected_outcome(
      self, in_value_set, x, function, result):

    input_dir = bases.BasisDirection("input")
    output_dir = bases.BasisDirection("output")
    one_dir = bases.BasisDirection("one")
    input_space = bases.VectorSpaceWithBasis([input_dir])
    output_space = bases.VectorSpaceWithBasis([output_dir])
    one_space = bases.VectorSpaceWithBasis([one_dir])

    with self.assertLogs(level="WARNING"):
      mlp = numerical_mlp.map_numerical_mlp(
          f=function,
          input_space=input_space,
          output_space=output_space,
          one_space=one_space,
          input_value_set=in_value_set,
      )

    test_inputs = bases.VectorInBasis(
        basis_directions=[input_dir, output_dir, one_dir],
        magnitudes=np.array([x, 0, 1]))

    expected_results = bases.VectorInBasis(
        basis_directions=[input_dir, output_dir, one_dir],
        magnitudes=np.array([0, result, 0]))

    test_outputs = mlp.apply(test_inputs)

    self.assertVectorAllClose(test_outputs, expected_results)

  @parameterized.parameters([
      dict(in_value_set={0, 1, 2, 3}, x=1, function=lambda x: 1 / x, result=1),
      dict(
          in_value_set={0, 1, 2, 3}, x=2, function=lambda x: 1 / x, result=0.5),
      dict(
          in_value_set={0, 1, 2, 3},
          x=3,
          function=lambda x: 1 / x,
          result=1 / 3),
  ])
  def test_map_numerical_to_categorical_mlp_logs_warning_and_produces_expected_outcome(
      self, in_value_set, x, function, result):

    f_ign = errors.ignoring_arithmetic_errors(function)
    out_value_set = {f_ign(x) for x in in_value_set if f_ign(x) is not None}

    in_space = bases.VectorSpaceWithBasis.from_names(["input"])
    out_space = bases.VectorSpaceWithBasis.from_values("output", out_value_set)
    one_space = bases.VectorSpaceWithBasis.from_names(["one"])

    residual_space = bases.join_vector_spaces(in_space, one_space, out_space)
    in_vec = residual_space.vector_from_basis_direction(in_space.basis[0])
    one_vec = residual_space.vector_from_basis_direction(one_space.basis[0])

    with self.assertLogs(level="WARNING"):
      mlp = numerical_mlp.map_numerical_to_categorical_mlp(
          f=function,
          input_space=in_space,
          output_space=out_space,
          input_value_set=in_value_set,
          one_space=one_space,
      )

    test_inputs = x * in_vec + one_vec
    expected_results = out_space.vector_from_basis_direction(
        bases.BasisDirection("output", result))
    test_outputs = mlp.apply(test_inputs).project(out_space)
    self.assertVectorAllClose(test_outputs, expected_results)

  @parameterized.parameters([
      dict(x_factor=1, y_factor=2, x=1, y=1, result=3),
      dict(x_factor=1, y_factor=2, x=1, y=-1, result=-1),
      dict(x_factor=1, y_factor=-1, x=1, y=1, result=0),
      dict(x_factor=1, y_factor=1, x=3, y=5, result=8),
      dict(x_factor=-2, y_factor=-0.5, x=4, y=1, result=-8.5),
  ])
  def test_linear_sequence_map_produces_expected_result(self, x_factor,
                                                        y_factor, x, y, result):

    input1_dir = bases.BasisDirection("input1")
    input2_dir = bases.BasisDirection("input2")
    output_dir = bases.BasisDirection("output")

    mlp = numerical_mlp.linear_sequence_map_numerical_mlp(
        input1_basis_direction=input1_dir,
        input2_basis_direction=input2_dir,
        output_basis_direction=output_dir,
        input1_factor=x_factor,
        input2_factor=y_factor)

    test_inputs = bases.VectorInBasis(
        basis_directions=[input1_dir, input2_dir, output_dir],
        magnitudes=np.array([x, y, 0]))

    expected_results = bases.VectorInBasis(
        basis_directions=[input1_dir, input2_dir, output_dir],
        magnitudes=np.array([0, 0, result]))

    test_outputs = mlp.apply(test_inputs)

    self.assertVectorAllClose(test_outputs, expected_results)

  @parameterized.parameters([
      dict(x_factor=1, y_factor=2, x=1, result=3),
      dict(x_factor=1, y_factor=-1, x=1, result=0),
  ])
  def test_linear_sequence_map_produces_expected_result_with_same_inputs(
      self, x_factor, y_factor, x, result):

    input_dir = bases.BasisDirection("input")
    output_dir = bases.BasisDirection("output")

    mlp = numerical_mlp.linear_sequence_map_numerical_mlp(
        input1_basis_direction=input_dir,
        input2_basis_direction=input_dir,
        output_basis_direction=output_dir,
        input1_factor=x_factor,
        input2_factor=y_factor)

    test_inputs = bases.VectorInBasis(
        basis_directions=[input_dir, output_dir], magnitudes=np.array([x, 0]))

    expected_results = bases.VectorInBasis(
        basis_directions=[input_dir, output_dir],
        magnitudes=np.array([0, result]))

    test_outputs = mlp.apply(test_inputs)

    self.assertVectorAllClose(test_outputs, expected_results)


if __name__ == "__main__":
  absltest.main()

--------------------------------------------------------------------------------
/tracr/craft/chamber/selector_width.py:
--------------------------------------------------------------------------------
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SelectorWidth component consisting of an attention head and an MLP."""

from typing import Iterable
from tracr.craft import bases
from tracr.craft import transformers
from tracr.craft import vectorspace_fns
from tracr.craft.chamber import categorical_attn
from tracr.craft.chamber import numerical_mlp


def selector_width(
    query_space: bases.VectorSpaceWithBasis,
    key_space: bases.VectorSpaceWithBasis,
    output_space: bases.VectorSpaceWithBasis,
    bos_space: bases.VectorSpaceWithBasis,
    one_space: bases.VectorSpaceWithBasis,
    attn_fn: categorical_attn.QueryKeyToAttnLogit,
    out_value_set: Iterable[float],
    categorical_output: bool,
    causal: bool = False,
    softmax_coldness: float = 100.,
    mlp_large_number: float = 100.,
    label: str = "",
) -> transformers.SeriesWithResiduals:
  """Returns a craft block implementing RASP's SelectorWidth primitive.

  The block consists of one attention head and one MLP.

  The attention head implements the attention pattern (attn_fn or key=bos) and
  aggregates the bos dimension over this pattern. The output of this will be
  1/(d+1) in every position, where d is the "width" of the attention pattern,
  i.e. the number of 1s in a row.

  The MLP then computes d from the previous output in all positions except for
  the first BOS position. In the BOS position the MLP removes the output of the
  attention head, to ensure it only contains the encoding of the BOS token
  which is expected by all other model components.

  Args:
    query_space: Vector space containing (categorical) query input.
    key_space: Vector space containing (categorical) key input.
    output_space: Vector space which will contain (numerical or categorical)
      output.
    bos_space: 1-d space used to identify the beginning of sequence token.
    one_space: Auxiliary 1-d vector space that must contain 1 in the input.
    attn_fn: A selector function f(query, key) operating on the query/key basis
      directions that defines the attention pattern to compute the width of.
    out_value_set: Set of possible output values of this SelectorWidth.
    categorical_output: If True, encode the output as a categorical variable.
    causal: If True, use masked attention.
    softmax_coldness: The inverse temperature of the softmax. Default value is
      high which makes the attention close to a hard maximum.
    mlp_large_number: A larger number makes the MLP more accurate.
    label: A name for this block, used to label auxiliary dimensions.
  """
  assert output_space.num_dims == 1 or categorical_output

  attn_out_dir = bases.BasisDirection(f"{label}_selector_width_attn_output")
  attn_out_space = bases.VectorSpaceWithBasis([attn_out_dir])
  attn_out_vec = attn_out_space.vector_from_basis_direction(attn_out_dir)

  attn = categorical_attn.categorical_attn(
      query_space=query_space,
      key_space=key_space,
      value_space=bos_space,
      output_space=attn_out_space,
      bos_space=bos_space,
      one_space=one_space,
      attn_fn=attn_fn,
      default_output=attn_out_space.null_vector(),
      causal=causal,
      always_attend_to_bos=True,
      use_bos_for_default_output=False,
      softmax_coldness=softmax_coldness)

  fun = lambda x: round((1 / x) - 1)
  in_value_set = {1 / (out_v + 1) for out_v in out_value_set}
  if categorical_output:
    mlp = numerical_mlp.map_numerical_to_categorical_mlp(
        f=fun,
        input_space=attn_out_space,
        output_space=output_space,
        input_value_set=in_value_set,
        one_space=one_space,
        hidden_name=f"_hidden_{label}_",
        large_number=mlp_large_number)
  else:
    mlp = numerical_mlp.map_numerical_mlp(
        f=fun,
        input_space=attn_out_space,
        output_space=output_space,
        input_value_set=in_value_set,
        one_space=one_space,
        hidden_name=f"_hidden_{label}_",
        large_number=mlp_large_number)

  # This implementation of selector width writes at each position including
  # the BOS. To ensure that the BOS token position does not contain
  # additional values, we add an mlp to subtract the output of both layers.
  clean_bos_out_space = bases.join_vector_spaces(attn_out_space, output_space)
  vec_to_subtract_from_bos = attn_out_vec.project(clean_bos_out_space)

  if categorical_output:
    # Add the one-hot encoding of the zero value to the vector
    # which will get scrubbed from the BOS position.
    zero_dir = [d for d in output_space.basis if d.value == 0][0]
    zero_vec = clean_bos_out_space.vector_from_basis_direction(zero_dir)
    vec_to_subtract_from_bos += zero_vec

  # Construct an MLP that subtracts vec_to_subtract_from_bos * bos
  # from the residual stream which is vec_to_subtract_from_bos in the
  # bos position and 0 else. vec_to_subtract_from_bos contains what the
  # attention head writes to the bos position.

  hidden_dir = bases.BasisDirection("_hidden_clean_bos_")
  hidden_space = bases.VectorSpaceWithBasis([hidden_dir])
  hidden_vec = hidden_space.vector_from_basis_direction(hidden_dir)

  # It's okay to use the local variables because they are only used within
  # the same loop iteration to create the MLP.
  # pylint: disable=cell-var-from-loop
  first_layer = vectorspace_fns.Linear.from_action(bos_space, hidden_space,
                                                   lambda x: hidden_vec)
  second_layer = vectorspace_fns.Linear.from_action(
      hidden_space, clean_bos_out_space, lambda x: -vec_to_subtract_from_bos)
  # pylint: enable=cell-var-from-loop
  clean_bos_mlp = transformers.MLP(first_layer, second_layer)

  mlp = transformers.MLP.combine_in_parallel([mlp, clean_bos_mlp])
  return transformers.SeriesWithResiduals([attn, mlp])

--------------------------------------------------------------------------------
/tracr/craft/chamber/selector_width_test.py:
--------------------------------------------------------------------------------
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for selector_width."""

from absl.testing import absltest
from absl.testing import parameterized
from tracr.craft import bases
from tracr.craft import tests_common
from tracr.craft.chamber import selector_width


class SelectorWidthTest(tests_common.VectorFnTestCase):

  @parameterized.product(
      causal=[False, True],
      categorical_output=[False, True],
      input_seq=[[1, 2, 3, 4, 5], [-1, 0, 1], [10]],
  )
  def test_selector_width_of_select_all_is_length(
      self, causal, categorical_output, input_seq
  ):
    vocab = range(-20, 20)
    input_space = bases.VectorSpaceWithBasis.from_values("input", vocab)

    if categorical_output:
      output_space = bases.VectorSpaceWithBasis.from_values("output", range(10))
    else:
      output_space = bases.VectorSpaceWithBasis(
          [bases.BasisDirection("output")]
      )

    bos_dir = bases.BasisDirection("bos_dimension")
    bos_space = bases.VectorSpaceWithBasis([bos_dir])

    one_dir = bases.BasisDirection("one_dimension")
    one_space = bases.VectorSpaceWithBasis([one_dir])

    input_space = bases.join_vector_spaces(input_space, bos_space, one_space)
    residual_space = bases.join_vector_spaces(input_space, output_space)
    bos_vec = residual_space.vector_from_basis_direction(bos_dir)
    one_vec = residual_space.vector_from_basis_direction(one_dir)

    block = selector_width.selector_width(
        query_space=input_space,
        key_space=input_space,
        output_space=output_space,
        bos_space=bos_space,
        one_space=one_space,
        attn_fn=lambda x, y: True,
        out_value_set=set(range(len(input_seq) + 1)),
        categorical_output=categorical_output,
        causal=causal,
        label="select_all",
    )

    test_inputs = [bos_vec + one_vec]
    for x in input_seq:
      test_inputs.append(
          residual_space.vector_from_basis_direction(
              bases.BasisDirection("input", x)
          )
          + one_vec
      )
    test_inputs = bases.VectorInBasis.stack(test_inputs)

    # Expect length of the input sequence
    if causal:
      expected_results = list(range(1, len(input_seq) + 1))
    else:
      expected_results = [len(input_seq) for _ in input_seq]

    if categorical_output:
      expected_results = [
          output_space.vector_from_basis_direction(
              bases.BasisDirection("output", x)
          )
          for x in expected_results
      ]
    else:
      output_vec = output_space.vector_from_basis_direction(
          bases.BasisDirection("output")
      )
      expected_results = [x * output_vec for x in expected_results]

    expected_results = bases.VectorInBasis.stack(expected_results)

    test_outputs = block.apply(test_inputs).project(output_space)
    self.assertVectorAllClose(
        tests_common.strip_bos_token(test_outputs), expected_results
    )

  @parameterized.product(
      causal=[False, True],
      categorical_output=[False, True],
      input_seq=[[1] * 20, [2] * 50],
  )
  def test_selector_width_works_for_long_sequences(
      self, causal, categorical_output, input_seq
  ):
    vocab = range(-20, 20)
    input_space = bases.VectorSpaceWithBasis.from_values("input", vocab)

    if categorical_output:
      output_space = bases.VectorSpaceWithBasis.from_values(
          "output", range(100)
      )
    else:
      output_space = bases.VectorSpaceWithBasis(
          [bases.BasisDirection("output")]
      )

    bos_dir = bases.BasisDirection("bos_dimension")
    bos_space = bases.VectorSpaceWithBasis([bos_dir])

    one_dir = bases.BasisDirection("one_dimension")
    one_space = bases.VectorSpaceWithBasis([one_dir])

    input_space = bases.join_vector_spaces(input_space, bos_space, one_space)

    try:
      selector_width.selector_width(
          query_space=input_space,
          key_space=input_space,
          output_space=output_space,
          bos_space=bos_space,
          one_space=one_space,
          attn_fn=lambda x, y: True,
          out_value_set=set(range(len(input_seq) + 1)),
          categorical_output=categorical_output,
          causal=causal,
          label="select_all",
      )
    except AssertionError as e:
      if "output value mismatch" in str(e):
        # Assertion raised if there are inconsistencies in the expected output
        # values (likely due to floating point issues for long sequences).
        self.fail(str(e))
      else:
        raise e

  @parameterized.product(
      causal=[False, True],
      categorical_output=[False, True],
      input_seq=[[1, 2, 3, 4, 5], [-1, 0, 1], [10]],
  )
  def test_selector_width_of_select_none_is_zero(
      self, causal, categorical_output, input_seq
  ):
    vocab = range(-20, 20)
    input_space = bases.VectorSpaceWithBasis.from_values("input", vocab)

    if categorical_output:
      output_space = bases.VectorSpaceWithBasis.from_values("output", range(10))
    else:
      output_space = bases.VectorSpaceWithBasis(
          [bases.BasisDirection("output")]
      )

    bos_dir = bases.BasisDirection("bos_dimension")
    bos_space = bases.VectorSpaceWithBasis([bos_dir])

    one_dir = bases.BasisDirection("one_dimension")
    one_space = bases.VectorSpaceWithBasis([one_dir])

    input_space = bases.join_vector_spaces(input_space, bos_space, one_space)
    residual_space = bases.join_vector_spaces(input_space, output_space)
    bos_vec = residual_space.vector_from_basis_direction(bos_dir)
    one_vec = residual_space.vector_from_basis_direction(one_dir)

    block = selector_width.selector_width(
        query_space=input_space,
        key_space=input_space,
        output_space=output_space,
        bos_space=bos_space,
        one_space=one_space,
        attn_fn=lambda x, y: False,
        out_value_set=set(range(len(input_seq) + 1)),
        categorical_output=categorical_output,
        causal=causal,
        label="select_all",
    )

    test_inputs = [bos_vec + one_vec]
    for x in input_seq:
      test_inputs.append(
          residual_space.vector_from_basis_direction(
              bases.BasisDirection("input", x)
          )
          + one_vec
      )
    test_inputs = bases.VectorInBasis.stack(test_inputs)

    # Expect zero output
    if categorical_output:
      expected_results = [
          output_space.vector_from_basis_direction(
              bases.BasisDirection("output", 0)
          )
          for _ in input_seq
      ]
    else:
      expected_results = [output_space.null_vector() for _ in input_seq]
    expected_results = bases.VectorInBasis.stack(expected_results)

    test_outputs = block.apply(test_inputs).project(output_space)
    self.assertVectorAllClose(
        tests_common.strip_bos_token(test_outputs), expected_results
    )


if __name__ == "__main__":
  absltest.main()

--------------------------------------------------------------------------------
/tracr/craft/tests_common.py:
--------------------------------------------------------------------------------
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions for tests."""

from absl.testing import parameterized
import numpy as np
from tracr.craft import bases


def strip_bos_token(vector: bases.VectorInBasis) -> bases.VectorInBasis:
  """Removes BOS token of a vector."""
  return bases.VectorInBasis(vector.basis_directions, vector.magnitudes[1:])


class VectorFnTestCase(parameterized.TestCase):
  """Asserts for vectors."""

  def assertVectorAllClose(self, v1: bases.VectorInBasis,
                           v2: bases.VectorInBasis):
    self.assertEqual(v1.basis_directions, v2.basis_directions)
    np.testing.assert_allclose(v1.magnitudes, v2.magnitudes, atol=1e-7)

--------------------------------------------------------------------------------
/tracr/craft/transformers.py:
--------------------------------------------------------------------------------
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Pieces for making transformers."""

import abc
import dataclasses
from typing import Iterable, List, Optional, Sequence, Union

import numpy as np

from tracr.craft import bases
from tracr.craft import vectorspace_fns

project = vectorspace_fns.project


def _np_softmax(x, axis=-1):
  x_max = np.max(x, axis=axis, keepdims=True)
  return np.exp(x - x_max) / np.sum(np.exp(x - x_max), axis=axis, keepdims=True)


def _np_relu(x):
  return np.where(x > 0, x, 0)


def relu(x: bases.VectorInBasis) -> bases.VectorInBasis:
  return bases.VectorInBasis(x.basis_directions, _np_relu(x.magnitudes))


class Block(abc.ABC):
  """Transformer block, acting on a sequence of vector space elements.

  Attributes:
    residual_space: Vector space that contains all subspaces the Block interacts
      with. This can be either the full residual space of a model or a subspace.
  """
  residual_space: bases.VectorSpaceWithBasis

  @abc.abstractmethod
  def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    """Applies self to an input."""


@dataclasses.dataclass
class AttentionHead(Block):
  """A transformer attention head."""
  w_qk: vectorspace_fns.ScalarBilinear
  w_ov: vectorspace_fns.Linear
  residual_space: Optional[bases.VectorSpaceWithBasis] = None
  causal: bool = False

  def __post_init__(self):
    """Infer residual stream and typecheck subspaces."""
    if self.residual_space is None:
      self.residual_space = bases.join_vector_spaces(self.w_qk.left_space,
                                                     self.w_qk.right_space,
                                                     self.w_ov.input_space,
                                                     self.w_ov.output_space)

    assert self.w_qk.left_space.issubspace(self.residual_space)
    assert self.w_qk.right_space.issubspace(self.residual_space)
    assert self.w_ov.input_space.issubspace(self.residual_space)
    assert self.w_ov.output_space.issubspace(self.residual_space)

  def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    assert self.residual_space is not None
    assert x in self.residual_space
    # seq_len x query_space
    queries = x.project(self.w_qk.left_space)
    # seq_len x key_space
    keys = x.project(self.w_qk.right_space)

    attn_matrix = queries.magnitudes @ self.w_qk.matrix @ keys.magnitudes.T

    if self.causal:
      # The 1 gives us the matrix above the diagonal.
      mask = np.triu(np.full_like(attn_matrix, -np.inf), 1)
      attn_matrix = attn_matrix + mask

    attn_weights = _np_softmax(attn_matrix)  # seq_len_from, seq_len_to
    values = self.w_ov_residual(x).magnitudes  # seq_len_to, d_model

    magnitudes = attn_weights @ values  # seq_len_from, d_model
    return bases.VectorInBasis(sorted(self.residual_space.basis), magnitudes)

  def w_ov_residual(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    """Wov but acting on the residual space."""
    x = project(self.residual_space, self.w_ov.input_space)(x)
    out = self.w_ov(x)
    return project(self.w_ov.output_space, self.residual_space)(out)

  @property
  def num_heads(self) -> int:
    return 1

  def as_multi(self) -> "MultiAttentionHead":
    return MultiAttentionHead([self])


@dataclasses.dataclass
class MultiAttentionHead(Block):
  """Applies attention heads in parallel."""
  sub_blocks: List[Union[AttentionHead, "MultiAttentionHead"]]

  def __post_init__(self):
    spaces = [block.residual_space for block in self.sub_blocks]
    self.residual_space, *others = spaces
    assert all(s == self.residual_space for s in others)

  def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    # each element is seq_len x embedding
    outs = [block.apply(x) for block in self.sub_blocks]
    return bases.VectorInBasis.sum(outs)  # seq_len x embedding

  @property
  def num_heads(self) -> int:
    return sum(sub_block.num_heads for sub_block in self.sub_blocks)

  def heads(self) -> Iterable[AttentionHead]:
    for sub_block in self.sub_blocks:
      if isinstance(sub_block, AttentionHead):
        yield sub_block
      elif isinstance(sub_block, MultiAttentionHead):
        yield from sub_block.heads()
      else:
        raise NotImplementedError()

  def as_multi(self) -> "MultiAttentionHead":
    return self


@dataclasses.dataclass
class MLP(Block):
  """A transformer MLP block."""
  fst: vectorspace_fns.Linear
  snd: vectorspace_fns.Linear
  residual_space: Optional[bases.VectorSpaceWithBasis] = None

  def __post_init__(self):
    """Typecheck subspaces."""
    if self.residual_space is None:
      self.residual_space = bases.join_vector_spaces(self.fst.input_space,
                                                     self.snd.output_space)

    assert self.fst.output_space == self.snd.input_space
    assert self.fst.input_space.issubspace(self.residual_space)
    assert self.snd.output_space.issubspace(self.residual_space)

  def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    assert x in self.residual_space

    x = project(self.residual_space, self.fst.input_space)(x)
    hidden = self.fst(x)
    hidden = relu(hidden)
    out = self.snd(hidden)
    return project(self.snd.output_space, self.residual_space)(out)

  @classmethod
  def combine_in_parallel(cls, mlps: Sequence["MLP"]) -> "MLP":
    fst = vectorspace_fns.Linear.combine_in_parallel(
        [block.fst for block in mlps])
    snd = vectorspace_fns.Linear.combine_in_parallel(
        [block.snd for block in mlps])
    return cls(fst=fst, snd=snd, residual_space=None)


# Block that fits into a half-layer, without residual connections.
HalfLayerBlock = Union[MLP, AttentionHead, MultiAttentionHead]


@dataclasses.dataclass
class SeriesWithResiduals(Block):
  """A series of blocks with residual connections."""
  blocks: List[HalfLayerBlock]

  def __post_init__(self):
    spaces = [block.residual_space for block in self.blocks]
    self.residual_space = bases.join_vector_spaces(*spaces)

  def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    x = x.project(self.residual_space)
    for block in self.blocks:
      x_in = x.project(block.residual_space)
      x_out = block.apply(x_in).project(self.residual_space)
      x = x + x_out
    return x

--------------------------------------------------------------------------------
/tracr/craft/transformers_test.py:
--------------------------------------------------------------------------------
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 | # ============================================================================== 15 | """Tests for transformers.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import numpy as np 20 | from tracr.craft import bases 21 | from tracr.craft import tests_common 22 | from tracr.craft import transformers 23 | from tracr.craft import vectorspace_fns as vs_fns 24 | 25 | # This makes it easier to use comments to annotate dimensions in arrays 26 | # pylint: disable=g-no-space-after-comment 27 | 28 | 29 | class AttentionHeadTest(tests_common.VectorFnTestCase): 30 | 31 | @parameterized.parameters([ 32 | dict(with_residual_stream=False), 33 | dict(with_residual_stream=True), 34 | ]) 35 | def test_attention_head(self, with_residual_stream): 36 | i = bases.VectorSpaceWithBasis.from_values("i", [1, 2]) 37 | o = bases.VectorSpaceWithBasis.from_values("o", [1, 2]) 38 | q = bases.VectorSpaceWithBasis.from_values("q", [1, 2]) 39 | k = bases.VectorSpaceWithBasis.from_values("p", [1, 2]) 40 | rs = bases.direct_sum(i, o, q, k) 41 | 42 | seq = bases.VectorInBasis( 43 | rs.basis, 44 | np.array([ 45 | #i1 i2 o1 o2 q1 q2 p1 p2 46 | [1, 0, 0, 0, 1, 0, 1, 0], 47 | [0, 1, 0, 0, 0, 1, 0, 1], 48 | ])) 49 | 50 | head = transformers.AttentionHead( 51 | w_qk=vs_fns.ScalarBilinear(q, k, 52 | np.eye(2) * 100), 53 | w_ov=vs_fns.Linear(i, o, np.eye(2)), 54 | residual_space=rs if with_residual_stream else None, 55 | causal=False, 56 | ) 57 | 58 | self.assertVectorAllClose( 59 | head.apply(seq), 60 | bases.VectorInBasis( 61 | rs.basis, 62 | np.array([ 63 | #i1 i2 o1 o2 q1 q2 p1 p2 64 | [0, 0, 1, 0, 0, 0, 0, 0], 65 | [0, 0, 0, 1, 0, 0, 0, 0], 66 | ])), 67 | ) 68 | 69 | 70 | class MLPTest(tests_common.VectorFnTestCase): 71 | 72 | @parameterized.parameters([ 73 | dict(with_residual_stream=False, same_in_out=False), 74 | dict(with_residual_stream=False, same_in_out=True), 75 | dict(with_residual_stream=True, same_in_out=False), 76 | 
dict(with_residual_stream=True, same_in_out=True), 77 | ]) 78 | def test_mlp(self, with_residual_stream, same_in_out): 79 | i = bases.VectorSpaceWithBasis.from_values("i", [1, 2]) 80 | if same_in_out: 81 | o, rs = i, i 82 | expected_result = np.array([ 83 | #o1 o2 84 | [1, 0], 85 | [0, 1], 86 | ]) 87 | else: 88 | o = bases.VectorSpaceWithBasis.from_values("o", [1, 2]) 89 | rs = bases.direct_sum(i, o) 90 | expected_result = np.array([ 91 | #i1 i2 o1 o2 92 | [0, 0, 1, 0], 93 | [0, 0, 0, 1], 94 | ]) 95 | h = bases.VectorSpaceWithBasis.from_values("p", [1, 2]) 96 | 97 | seq = bases.VectorInBasis( 98 | i.basis, 99 | np.array([ 100 | #i1 i2 101 | [1, -1], 102 | [-1, 1], 103 | ])).project(rs) 104 | 105 | mlp = transformers.MLP( 106 | fst=vs_fns.Linear(i, h, np.eye(2)), 107 | snd=vs_fns.Linear(h, o, np.eye(2)), 108 | residual_space=rs if with_residual_stream else None, 109 | ) 110 | 111 | self.assertEqual( 112 | mlp.apply(seq), 113 | bases.VectorInBasis(rs.basis, expected_result), 114 | ) 115 | 116 | def test_combining_mlps(self): 117 | in12 = bases.VectorSpaceWithBasis.from_values("in", [1, 2]) 118 | in34 = bases.VectorSpaceWithBasis.from_values("in", [3, 4]) 119 | out12 = bases.VectorSpaceWithBasis.from_values("out", [1, 2]) 120 | residual_space = bases.join_vector_spaces(in12, in34, out12) 121 | 122 | h1 = bases.VectorSpaceWithBasis.from_values("h", [1]) 123 | h2 = bases.VectorSpaceWithBasis.from_values("h", [2]) 124 | 125 | # MLP1 maps in2 -> h1 -> out1 126 | mlp1 = transformers.MLP( 127 | fst=vs_fns.Linear(in12, h1, np.array([[0], [1]])), 128 | snd=vs_fns.Linear(h1, out12, np.array([[1, 0]]))) 129 | 130 | # MLP2 maps in3 -> h2 -> out2 131 | mlp2 = transformers.MLP( 132 | fst=vs_fns.Linear(in34, h2, np.array([[1], [0]])), 133 | snd=vs_fns.Linear(h2, out12, np.array([[0, 1]]))) 134 | 135 | mlp = transformers.MLP.combine_in_parallel([mlp1, mlp2]) 136 | 137 | seq = bases.VectorInBasis( 138 | bases.direct_sum(in12, in34).basis, 139 | np.array([ 140 | #i1 i2 i3 i4 141 | [1, 
2, 0, 0], 142 | [0, 2, 3, 4], 143 | ])).project(residual_space) 144 | 145 | expected_result = bases.VectorInBasis( 146 | out12.basis, 147 | np.array([ 148 | #o1 o2 149 | [2, 0], 150 | [2, 3], 151 | ])) 152 | 153 | self.assertEqual( 154 | mlp.apply(seq).project(out12), 155 | expected_result, 156 | ) 157 | 158 | 159 | if __name__ == "__main__": 160 | absltest.main() 161 | -------------------------------------------------------------------------------- /tracr/craft/vectorspace_fns.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
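The expected values in `test_combining_mlps` can be checked by hand. The numpy sketch below assumes that each craft MLP computes `relu(x @ W1) @ W2` on row vectors (consistent with the ReLU behaviour that `test_mlp` relies on) and that combining MLPs in parallel amounts to concatenating their hidden units, which here live in the distinct directions `h1` and `h2`. The variable names are illustrative, not tracr API.

```python
import numpy as np


def relu(x: np.ndarray) -> np.ndarray:
    return np.maximum(x, 0.0)


# Input rows over the basis [in1, in2, in3, in4] (same values as the test).
seq = np.array([[1, 2, 0, 0],
                [0, 2, 3, 4]], dtype=float)

# MLP1 reads in2 into hidden unit h1 and writes h1 to out1.
w1_mlp1 = np.array([[0], [1], [0], [0]], dtype=float)  # [in, h1]
w2_mlp1 = np.array([[1, 0]], dtype=float)              # [h1, out]
# MLP2 reads in3 into hidden unit h2 and writes h2 to out2.
w1_mlp2 = np.array([[0], [0], [1], [0]], dtype=float)  # [in, h2]
w2_mlp2 = np.array([[0, 1]], dtype=float)              # [h2, out]

# Parallel combination: stack the hidden layers side by side.
w1 = np.concatenate([w1_mlp1, w1_mlp2], axis=1)  # [in, h1 + h2]
w2 = np.concatenate([w2_mlp1, w2_mlp2], axis=0)  # [h1 + h2, out]

combined = relu(seq @ w1) @ w2
# Because the hidden units are disjoint, this equals the sum of both MLPs.
separate = relu(seq @ w1_mlp1) @ w2_mlp1 + relu(seq @ w1_mlp2) @ w2_mlp2
print(combined)  # rows [2, 0] and [2, 3], as in expected_result
```

Concatenating hidden units only equals summing the two MLPs when their hidden directions do not overlap, which is why the test uses `h` values 1 and 2.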
14 | # ============================================================================== 15 | """Functions on vector spaces.""" 16 | 17 | import abc 18 | import dataclasses 19 | from typing import Callable, Sequence 20 | 21 | import numpy as np 22 | from tracr.craft import bases 23 | 24 | VectorSpaceWithBasis = bases.VectorSpaceWithBasis 25 | VectorInBasis = bases.VectorInBasis 26 | BasisDirection = bases.BasisDirection 27 | 28 | 29 | class VectorFunction(abc.ABC): 30 | """A function that acts on vectors.""" 31 | 32 | input_space: VectorSpaceWithBasis 33 | output_space: VectorSpaceWithBasis 34 | 35 | @abc.abstractmethod 36 | def __call__(self, x: VectorInBasis) -> VectorInBasis: 37 | """Evaluates the function.""" 38 | 39 | 40 | class Linear(VectorFunction): 41 | """A linear function.""" 42 | 43 | def __init__( 44 | self, 45 | input_space: VectorSpaceWithBasis, 46 | output_space: VectorSpaceWithBasis, 47 | matrix: np.ndarray, 48 | ): 49 | """Initialises. 50 | 51 | Args: 52 | input_space: The input vector space. 53 | output_space: The output vector space. 54 | matrix: a [input, output] matrix acting in a (sorted) basis. 
55 | """ 56 | self.input_space = input_space 57 | self.output_space = output_space 58 | self.matrix = matrix 59 | self._check_shapes() 60 | def _check_shapes(self) -> None: 61 | input_size, output_size = self.matrix.shape # matrix is [input, output]. 62 | assert input_size == self.input_space.num_dims 63 | assert output_size == self.output_space.num_dims 64 | 65 | def __call__(self, x: VectorInBasis) -> VectorInBasis: 66 | if x not in self.input_space: 67 | raise TypeError(f"x={x} not in self.input_space={self.input_space}.") 68 | return self.output_space.make_vector(x.magnitudes @ self.matrix) 69 | 70 | @classmethod 71 | def from_action( 72 | cls, 73 | input_space: VectorSpaceWithBasis, 74 | output_space: VectorSpaceWithBasis, 75 | action: Callable[[BasisDirection], VectorInBasis], 76 | ) -> "Linear": 77 | """Creates a Linear whose i-th row is action(input_space.basis[i]).""" 78 | 79 | matrix = np.zeros((input_space.num_dims, output_space.num_dims)) 80 | for i, direction in enumerate(input_space.basis): 81 | out_vector = action(direction) 82 | if out_vector not in output_space: 83 | raise TypeError( 84 | f"image of {direction} from input_space={input_space} " 85 | f"is not in output_space={output_space}" 86 | ) 87 | matrix[i, :] = out_vector.magnitudes 88 | 89 | return Linear(input_space, output_space, matrix) 90 | 91 | @classmethod 92 | def combine_in_parallel(cls, fns: Sequence["Linear"]) -> "Linear": 93 | """Combines multiple parallel linear functions into a single one.""" 94 | joint_input_space = bases.join_vector_spaces( 95 | *[fn.input_space for fn in fns] 96 | ) 97 | joint_output_space = bases.join_vector_spaces( 98 | *[fn.output_space for fn in fns] 99 | ) 100 | 101 | # Cache properties for the parents to avoid recomputing for each child. 102 | # Since the index_by_direction cached_property of the children is needed 103 | # within the action, it would be computed for every single child. This is 104 | # redundant as they share the same basis.
By accessing the properties here, 105 | # we ensure they are only computed once and passed on to the children. 106 | _ = joint_input_space.index_by_direction 107 | _ = joint_output_space.index_by_direction 108 | 109 | def action(x: bases.BasisDirection) -> bases.VectorInBasis: 110 | out = joint_output_space.null_vector() 111 | for fn in fns: 112 | if x in fn.input_space: 113 | x_vec = fn.input_space.vector_from_basis_direction(x) 114 | applied = fn(x_vec) 115 | out = out.add_directions(applied) 116 | return out 117 | 118 | return cls.from_action(joint_input_space, joint_output_space, action) 119 | 120 | 121 | def project( 122 | from_space: VectorSpaceWithBasis, 123 | to_space: VectorSpaceWithBasis, 124 | ) -> Linear: 125 | """Creates a projection.""" 126 | 127 | def action(direction: bases.BasisDirection) -> VectorInBasis: 128 | if direction in to_space: 129 | return to_space.vector_from_basis_direction(direction) 130 | else: 131 | return to_space.null_vector() 132 | 133 | return Linear.from_action(from_space, to_space, action=action) 134 | 135 | 136 | @dataclasses.dataclass 137 | class ScalarBilinear: 138 | """A scalar-valued bilinear operator.""" 139 | 140 | left_space: VectorSpaceWithBasis 141 | right_space: VectorSpaceWithBasis 142 | matrix: np.ndarray 143 | 144 | def __post_init__(self): 145 | """Ensure matrix acts in sorted bases and typecheck sizes.""" 146 | left_size, right_size = self.matrix.shape 147 | assert left_size == self.left_space.num_dims 148 | assert right_size == self.right_space.num_dims 149 | 150 | def __call__(self, x: VectorInBasis, y: VectorInBasis) -> float: 151 | """Describes the action of the operator on vectors.""" 152 | if x not in self.left_space: 153 | raise TypeError(f"x={x} not in self.left_space={self.left_space}.") 154 | if y not in self.right_space: 155 | raise TypeError(f"y={y} not in self.right_space={self.right_space}.") 156 | return (x.magnitudes.T @ self.matrix @ y.magnitudes).item() 157 | 158 | @classmethod 159 | def 
from_action( 160 | cls, 161 | left_space: VectorSpaceWithBasis, 162 | right_space: VectorSpaceWithBasis, 163 | action: Callable[[BasisDirection, BasisDirection], float], 164 | ) -> "ScalarBilinear": 165 | """Creates a ScalarBilinear with matrix[i, j] = action(left_i, right_j).""" 166 | 167 | matrix = np.zeros((left_space.num_dims, right_space.num_dims)) 168 | for i, left_direction in enumerate(left_space.basis): 169 | for j, right_direction in enumerate(right_space.basis): 170 | matrix[i, j] = action(left_direction, right_direction) 171 | 172 | return ScalarBilinear(left_space, right_space, matrix) 173 | -------------------------------------------------------------------------------- /tracr/craft/vectorspace_fns_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
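`Linear` uses a row-vector convention: `__call__` computes `x.magnitudes @ self.matrix`, so the matrix has shape [input, output], and `from_action` fills row `i` with the image of the `i`-th input basis direction. The sketch below mimics that bookkeeping with plain numpy, using string labels in place of `BasisDirection` objects (a simplification for brevity), and shows how summing the images of several functions, as `combine_in_parallel` does, reproduces `test_combining_linear_functions_with_same_input`. `matrix_from_action` and `apply` are hypothetical helpers.

```python
from typing import Callable, Dict, Sequence

import numpy as np


def matrix_from_action(
    in_basis: Sequence[str],
    out_basis: Sequence[str],
    action: Callable[[str], Dict[str, float]],
) -> np.ndarray:
    """Builds an [input, output] matrix; row i is the image of in_basis[i]."""
    out_index = {name: j for j, name in enumerate(out_basis)}
    matrix = np.zeros((len(in_basis), len(out_basis)))
    for i, direction in enumerate(in_basis):
        for name, value in action(direction).items():
            matrix[i, out_index[name]] = value
    return matrix


def apply(matrix: np.ndarray, x: np.ndarray) -> np.ndarray:
    # Row-vector convention: y = x @ M, so M has shape [input, output].
    return x @ matrix


basis = ["a", "b"]
f1 = {"a": {"b": 1.0}, "b": {"a": 1.0}}  # swaps a and b
f2 = {"a": {"a": 1.0}, "b": {}}          # keeps a, sends b to zero


def combined_action(direction: str) -> Dict[str, float]:
    # Parallel combination: add up the images under every function.
    out: Dict[str, float] = {}
    for fn in (f1, f2):
        for name, value in fn[direction].items():
            out[name] = out.get(name, 0.0) + value
    return out


m = matrix_from_action(basis, basis, combined_action)
a, b = np.eye(2)
print(apply(m, a), apply(m, b))  # [1. 1.] [1. 0.]

# A map from a 2-dim space to a 1-dim space has shape [input, output] = (2, 1).
proj = matrix_from_action(["a", "b"], ["c"], lambda d: {"c": 1.0} if d == "b" else {})
assert proj.shape == (2, 1)
```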
14 | # ============================================================================== 15 | """Tests for vectorspace_fns.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import numpy as np 20 | from tracr.craft import bases 21 | from tracr.craft import tests_common 22 | from tracr.craft import vectorspace_fns as vs_fns 23 | 24 | 25 | class LinearTest(tests_common.VectorFnTestCase): 26 | 27 | def test_identity_from_matrix(self): 28 | vs = bases.VectorSpaceWithBasis.from_names(["a", "b", "c"]) 29 | f = vs_fns.Linear(vs, vs, np.eye(3)) 30 | for v in vs.basis_vectors(): 31 | self.assertEqual(f(v), v) 32 | 33 | def test_identity_from_action(self): 34 | vs = bases.VectorSpaceWithBasis.from_names(["a", "b", "c"]) 35 | f = vs_fns.Linear.from_action(vs, vs, vs.vector_from_basis_direction) 36 | for v in vs.basis_vectors(): 37 | self.assertEqual(f(v), v) 38 | 39 | def test_nonidentity(self): 40 | vs = bases.VectorSpaceWithBasis.from_names(["a", "b"]) 41 | a = vs.vector_from_basis_direction(bases.BasisDirection("a")) 42 | b = vs.vector_from_basis_direction(bases.BasisDirection("b")) 43 | 44 | f = vs_fns.Linear(vs, vs, np.array([[0.3, 0.7], [0.2, 0.1]])) 45 | 46 | self.assertEqual( 47 | f(a), bases.VectorInBasis(vs.basis, np.array([0.3, 0.7]))) 48 | self.assertEqual( 49 | f(b), bases.VectorInBasis(vs.basis, np.array([0.2, 0.1]))) 50 | 51 | def test_different_vector_spaces(self): 52 | vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"]) 53 | vs2 = bases.VectorSpaceWithBasis.from_names(["c", "d"]) 54 | a, b = vs1.basis_vectors() 55 | c, d = vs2.basis_vectors() 56 | 57 | f = vs_fns.Linear(vs1, vs2, np.eye(2)) 58 | 59 | self.assertEqual(f(a), c) 60 | self.assertEqual(f(b), d) 61 | 62 | def test_combining_linear_functions_with_different_input(self): 63 | vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"]) 64 | vs2 = bases.VectorSpaceWithBasis.from_names(["c", "d"]) 65 | vs = bases.direct_sum(vs1, vs2) 66 | a =
vs.vector_from_basis_direction(bases.BasisDirection("a")) 67 | b = vs.vector_from_basis_direction(bases.BasisDirection("b")) 68 | c = vs.vector_from_basis_direction(bases.BasisDirection("c")) 69 | d = vs.vector_from_basis_direction(bases.BasisDirection("d")) 70 | 71 | f1 = vs_fns.Linear(vs1, vs1, np.array([[0, 1], [1, 0]])) 72 | f2 = vs_fns.Linear(vs2, vs2, np.array([[1, 0], [0, 0]])) 73 | f3 = vs_fns.Linear.combine_in_parallel([f1, f2]) 74 | 75 | self.assertEqual( 76 | f3(a), bases.VectorInBasis(vs.basis, np.array([0, 1, 0, 0]))) 77 | self.assertEqual( 78 | f3(b), bases.VectorInBasis(vs.basis, np.array([1, 0, 0, 0]))) 79 | self.assertEqual( 80 | f3(c), bases.VectorInBasis(vs.basis, np.array([0, 0, 1, 0]))) 81 | self.assertEqual( 82 | f3(d), bases.VectorInBasis(vs.basis, np.array([0, 0, 0, 0]))) 83 | 84 | def test_combining_linear_functions_with_same_input(self): 85 | vs = bases.VectorSpaceWithBasis.from_names(["a", "b"]) 86 | a = vs.vector_from_basis_direction(bases.BasisDirection("a")) 87 | b = vs.vector_from_basis_direction(bases.BasisDirection("b")) 88 | 89 | f1 = vs_fns.Linear(vs, vs, np.array([[0, 1], [1, 0]])) 90 | f2 = vs_fns.Linear(vs, vs, np.array([[1, 0], [0, 0]])) 91 | f3 = vs_fns.Linear.combine_in_parallel([f1, f2]) 92 | 93 | self.assertEqual( 94 | f3(a), bases.VectorInBasis(vs.basis, np.array([1, 1]))) 95 | self.assertEqual( 96 | f3(b), bases.VectorInBasis(vs.basis, np.array([1, 0]))) 97 | self.assertEqual(f3(a), f1(a) + f2(a)) 98 | self.assertEqual(f3(b), f1(b) + f2(b)) 99 | 100 | 101 | class ProjectionTest(tests_common.VectorFnTestCase): 102 | 103 | def test_projection_to_larger_space(self): 104 | vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"]) 105 | vs2 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"]) 106 | a1, b1 = vs1.basis_vectors() 107 | a2, b2, _, _ = vs2.basis_vectors() 108 | 109 | f = vs_fns.project(vs1, vs2) 110 | 111 | self.assertEqual(f(a1), a2) 112 | self.assertEqual(f(b1), b2) 113 | 114 | def 
test_projection_to_smaller_space(self): 115 | vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"]) 116 | vs2 = bases.VectorSpaceWithBasis.from_names(["a", "b"]) 117 | a1, b1, c1, d1 = vs1.basis_vectors() 118 | a2, b2 = vs2.basis_vectors() 119 | 120 | f = vs_fns.project(vs1, vs2) 121 | 122 | self.assertEqual(f(a1), a2) 123 | self.assertEqual(f(b1), b2) 124 | self.assertEqual(f(c1), vs2.null_vector()) 125 | self.assertEqual(f(d1), vs2.null_vector()) 126 | 127 | 128 | class ScalarBilinearTest(parameterized.TestCase): 129 | 130 | def test_identity_matrix(self): 131 | vs = bases.VectorSpaceWithBasis.from_names(["a", "b"]) 132 | a, b = vs.basis_vectors() 133 | 134 | f = vs_fns.ScalarBilinear(vs, vs, np.eye(2)) 135 | 136 | self.assertEqual(f(a, a), 1) 137 | self.assertEqual(f(a, b), 0) 138 | self.assertEqual(f(b, a), 0) 139 | self.assertEqual(f(b, b), 1) 140 | 141 | def test_identity_from_action(self): 142 | vs = bases.VectorSpaceWithBasis.from_names(["a", "b"]) 143 | a, b = vs.basis_vectors() 144 | 145 | f = vs_fns.ScalarBilinear.from_action(vs, vs, lambda x, y: int(x == y)) 146 | 147 | self.assertEqual(f(a, a), 1) 148 | self.assertEqual(f(a, b), 0) 149 | self.assertEqual(f(b, a), 0) 150 | self.assertEqual(f(b, b), 1) 151 | 152 | def test_non_identity(self): 153 | vs = bases.VectorSpaceWithBasis.from_names(["a", "b"]) 154 | a, b = vs.basis_vectors() 155 | 156 | f = vs_fns.ScalarBilinear.from_action(vs, vs, 157 | lambda x, y: int(x.name == "a")) 158 | 159 | self.assertEqual(f(a, a), 1) 160 | self.assertEqual(f(a, b), 1) 161 | self.assertEqual(f(b, a), 0) 162 | self.assertEqual(f(b, b), 0) 163 | 164 | 165 | if __name__ == "__main__": 166 | absltest.main() 167 | -------------------------------------------------------------------------------- /tracr/examples/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 
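`ScalarBilinear` evaluates `x^T M y`, and its `from_action` sets `M[i, j]` to the action applied to the i-th left and j-th right basis direction. A numpy sketch with hypothetical helper names, reproducing `test_non_identity` (where the action only checks whether the left direction is named `a`):

```python
from typing import Callable, Sequence

import numpy as np


def bilinear_matrix(
    left: Sequence[str],
    right: Sequence[str],
    action: Callable[[str, str], float],
) -> np.ndarray:
    """matrix[i, j] = action(left[i], right[j])."""
    matrix = np.zeros((len(left), len(right)))
    for i, l_dir in enumerate(left):
        for j, r_dir in enumerate(right):
            matrix[i, j] = action(l_dir, r_dir)
    return matrix


def bilinear(matrix: np.ndarray, x: np.ndarray, y: np.ndarray) -> float:
    # Scalar-valued: x^T M y.
    return float(x @ matrix @ y)


basis = ["a", "b"]
m = bilinear_matrix(basis, basis, lambda l_dir, r_dir: float(l_dir == "a"))
vectors = dict(zip(basis, np.eye(2)))
results = {
    (p, q): bilinear(m, vectors[p], vectors[q]) for p in basis for q in basis
}
print(results)
```

Only the left argument matters here, so both `f(a, a)` and `f(a, b)` are 1 while both values with `b` on the left are 0.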
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /tracr/rasp/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /tracr/rasp/causal_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """RASP Evaluator which applies causal masks to selectors.""" 16 | 17 | from typing import Sequence, Union 18 | 19 | import numpy as np 20 | from tracr.rasp import rasp 21 | 22 | 23 | class CausalEvaluator(rasp.DefaultRASPEvaluator): 24 | """Evaluates RASP with causal masking.""" 25 | 26 | def evaluate( 27 | self, expr: rasp.RASPExpr, xs: Sequence[rasp.Value] 28 | ) -> Union[Sequence[rasp.Value], rasp.SelectorValue]: 29 | out = super().evaluate(expr, xs) 30 | 31 | if not isinstance(expr, rasp.Selector): 32 | return out 33 | 34 | out = np.array(out) 35 | causal_mask = np.tril(np.full(out.shape, 1)) 36 | return np.logical_and(causal_mask, out).tolist() 37 | 38 | 39 | evaluate = CausalEvaluator().evaluate 40 | -------------------------------------------------------------------------------- /tracr/rasp/causal_eval_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
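`CausalEvaluator` evaluates a selector as usual and then keeps only the entries on or below the diagonal. Assuming rows index query positions and columns index key positions (as the lower-triangular expected outputs in `causal_eval_test.py` suggest), the masking step amounts to the following; `causal_mask_selector` is a hypothetical helper, not part of tracr.

```python
import numpy as np


def causal_mask_selector(selector):
    sel = np.asarray(selector, dtype=bool)
    # Query i may only select keys j <= i: lower triangle including diagonal.
    mask = np.tril(np.ones(sel.shape, dtype=bool))
    return np.logical_and(mask, sel).tolist()


selector = [[True, True, True],
            [False, True, True],
            [True, False, True]]
print(causal_mask_selector(selector))
# [[True, False, False], [False, True, False], [True, False, True]]
```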
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for causal_eval.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | 20 | from tracr.rasp import causal_eval 21 | from tracr.rasp import rasp 22 | 23 | 24 | class CausalEvalTest(parameterized.TestCase): 25 | 26 | @parameterized.named_parameters( 27 | dict( 28 | testcase_name="constant_selector_3x3_1", 29 | program=rasp.ConstantSelector([ 30 | [True, True, True], 31 | [True, True, True], 32 | [True, True, True], 33 | ]), 34 | input_sequence=[True, True, True], 35 | expected_output=[ 36 | [True, False, False], 37 | [True, True, False], 38 | [True, True, True], 39 | ]), 40 | dict( 41 | testcase_name="constant_selector_3x3_2", 42 | program=rasp.ConstantSelector([ 43 | [True, True, True], 44 | [False, True, True], 45 | [True, False, True], 46 | ]), 47 | input_sequence=[True, True, True], 48 | expected_output=[ 49 | [True, False, False], 50 | [False, True, False], 51 | [True, False, True], 52 | ])) 53 | def test_evaluations(self, program, input_sequence, expected_output): 54 | self.assertListEqual( 55 | causal_eval.evaluate(program, input_sequence), 56 | expected_output, 57 | ) 58 | 59 | 60 | if __name__ == "__main__": 61 | absltest.main() 62 | -------------------------------------------------------------------------------- /tracr/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. 
All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /tracr/transformer/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Instrumented attention layer (forked from the Haiku library implementation). 
16 | """ 17 | 18 | from typing import Optional 19 | import warnings 20 | 21 | import chex 22 | import haiku as hk 23 | import jax 24 | import jax.numpy as jnp 25 | import numpy as np 26 | 27 | 28 | @chex.dataclass 29 | class AttentionOutput: 30 | out: jax.Array # [..., T', D'] 31 | logits: jax.Array # [..., H, T', T] 32 | 33 | 34 | class MultiHeadAttention(hk.Module): 35 | """Multi-headed attention (MHA) module. 36 | 37 | This module is intended for attending over sequences of vectors. 38 | 39 | Rough sketch: 40 | - Compute keys (K), queries (Q), and values (V) as projections of inputs. 41 | - Attention weights are computed as W = softmax(QK^T / sqrt(key_size)). 42 | - Output is another projection of WV^T. 43 | 44 | For more detail, see the original Transformer paper: 45 | "Attention is all you need" https://arxiv.org/abs/1706.03762. 46 | 47 | Glossary of shapes: 48 | - T: Sequence length. 49 | - D: Vector (embedding) size. 50 | - H: Number of attention heads. 51 | """ 52 | 53 | def __init__( 54 | self, 55 | num_heads: int, 56 | key_size: int, 57 | # TODO(b/240019186): Remove `w_init_scale`. 58 | w_init_scale: Optional[float] = None, 59 | *, 60 | w_init: Optional[hk.initializers.Initializer] = None, 61 | value_size: Optional[int] = None, 62 | model_size: Optional[int] = None, 63 | name: Optional[str] = None, 64 | ): 65 | """Initialises the module. 66 | 67 | Args: 68 | num_heads: Number of independent attention heads (H). 69 | key_size: The size of keys (K) and queries used for attention. 70 | w_init_scale: DEPRECATED. Please use w_init instead. 71 | w_init: Initialiser for weights in the linear map. 72 | value_size: Optional size of the value projection (V). If None, defaults 73 | to the key size (K). 74 | model_size: Optional size of the output embedding (D'). If None, defaults 75 | to the key size multiplied by the number of heads (K * H). 76 | name: Optional name for this module. 
77 | """ 78 | super().__init__(name=name) 79 | self.num_heads = num_heads 80 | self.key_size = key_size 81 | self.value_size = value_size or key_size 82 | self.model_size = model_size or key_size * num_heads 83 | 84 | # Backwards-compatibility for w_init_scale. 85 | if w_init_scale is not None: 86 | warnings.warn( 87 | "w_init_scale is deprecated; please pass an explicit weight " 88 | "initialiser instead.", DeprecationWarning) 89 | if w_init and w_init_scale: 90 | raise ValueError("Please provide only `w_init`, not `w_init_scale`.") 91 | if w_init is None and w_init_scale is None: 92 | raise ValueError("Please provide a weight initializer: `w_init`.") 93 | if w_init is None: 94 | w_init = hk.initializers.VarianceScaling(w_init_scale) 95 | self.w_init = w_init 96 | 97 | def __call__( 98 | self, 99 | query: jnp.ndarray, 100 | key: jnp.ndarray, 101 | value: jnp.ndarray, 102 | mask: Optional[jnp.ndarray] = None, 103 | ) -> AttentionOutput: 104 | """Computes (optionally masked) MHA with queries, keys & values. 105 | 106 | This module broadcasts over zero or more 'batch-like' leading dimensions. 107 | 108 | Args: 109 | query: Embeddings sequence used to compute queries; shape [..., T', D_q]. 110 | key: Embeddings sequence used to compute keys; shape [..., T, D_k]. 111 | value: Embeddings sequence used to compute values; shape [..., T, D_v]. 112 | mask: Optional mask applied to attention weights; shape [..., H=1, T', T]. 113 | 114 | Returns: 115 | A new sequence of embeddings, consisting of a projection of the 116 | attention-weighted value projections; shape [..., T', D']. 117 | """ 118 | 119 | # In shape hints below, we suppress the leading dims [...] for brevity. 120 | # Hence e.g. [A, B] should be read in every case as [..., A, B]. 121 | *leading_dims, sequence_length, _ = query.shape 122 | projection = self._linear_projection 123 | 124 | # Compute key/query/values (overload K/Q/V to denote the respective sizes). 
125 | query_heads = projection(query, self.key_size, "query") # [T', H, Q=K] 126 | key_heads = projection(key, self.key_size, "key") # [T, H, K] 127 | value_heads = projection(value, self.value_size, "value") # [T, H, V] 128 | 129 | # Compute attention weights. 130 | attn_logits = jnp.einsum("...thd,...Thd->...htT", query_heads, key_heads) 131 | attn_logits = attn_logits / np.sqrt(self.key_size).astype(key.dtype) 132 | if mask is not None: 133 | if mask.ndim != attn_logits.ndim: 134 | raise ValueError( 135 | f"Mask dimensionality {mask.ndim} must match logits dimensionality " 136 | f"{attn_logits.ndim}.") 137 | attn_logits = jnp.where(mask, attn_logits, -1e30) 138 | attn_weights = jax.nn.softmax(attn_logits) # [H, T', T] 139 | 140 | # Weight the values by the attention and flatten the head vectors. 141 | attn = jnp.einsum("...htT,...Thd->...thd", attn_weights, value_heads) 142 | attn = jnp.reshape(attn, (*leading_dims, sequence_length, -1)) # [T', H*V] 143 | 144 | # Apply another projection to get the final embeddings. 145 | final_projection = hk.Linear(self.model_size, w_init=self.w_init) 146 | return AttentionOutput( 147 | out=final_projection(attn), 148 | logits=attn_logits, 149 | ) 150 | 151 | @hk.transparent 152 | def _linear_projection( 153 | self, 154 | x: jnp.ndarray, 155 | head_size: int, 156 | name: Optional[str] = None, 157 | ) -> jnp.ndarray: 158 | y = hk.Linear(self.num_heads * head_size, w_init=self.w_init, name=name)(x) 159 | *leading_dims, _ = x.shape 160 | return y.reshape((*leading_dims, self.num_heads, head_size)) 161 | -------------------------------------------------------------------------------- /tracr/transformer/compressed_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
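Stripped of the learned `hk.Linear` projections and the batch dimensions, the core of `MultiHeadAttention.__call__` is scaled dot-product attention with the same einsum subscripts and the same `-1e30` masking constant. A numpy sketch, assuming the query/key/value heads are already projected; `attend` is a hypothetical helper:

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attend(q, k, v, mask=None):
    """q: [T', H, K], k: [T, H, K], v: [T, H, V], mask: broadcastable to [H, T', T]."""
    key_size = q.shape[-1]
    logits = np.einsum("thd,Thd->htT", q, k) / np.sqrt(key_size)
    if mask is not None:
        # Masked positions get a huge negative logit, i.e. ~zero weight.
        logits = np.where(mask, logits, -1e30)
    weights = softmax(logits)                     # [H, T', T]
    out = np.einsum("htT,Thd->thd", weights, v)   # [T', H, V]
    return out.reshape(out.shape[0], -1), logits  # [T', H*V], [H, T', T]


rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(3, 1, 2)) for _ in range(3))
causal = np.tril(np.ones((3, 3), dtype=bool))[None]  # [H=1, T', T]
out, logits = attend(q, k, v, mask=causal)
print(out.shape)  # (3, 2)
```

With a causal mask, the first query can only attend to the first key, so its output is exactly the first value vector.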
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Modified transformer to learn a linear compression of the residual stream. 16 | 17 | CompressedTransformer adds two arguments compared to Transformer: 18 | - embedding_size: the size of the compressed residual stream. 19 | - unembed_at_every_layer: whether to apply the unembedding before applying 20 | attention and MLP layers. 21 | All model activations (layer outputs, residuals, attention logits) are always 22 | returned in the TransformerOutput, alongside the final output. 23 | """ 24 | 25 | import collections 26 | import dataclasses 27 | from typing import Optional 28 | 29 | import haiku as hk 30 | import jax 31 | import numpy as np 32 | 33 | from tracr.transformer import attention 34 | from tracr.transformer import model 35 | 36 | 37 | @dataclasses.dataclass 38 | class CompressedTransformer(hk.Module): 39 | """A transformer stack with linearly compressed residual stream.""" 40 | 41 | config: model.TransformerConfig 42 | name: Optional[str] = None 43 | 44 | def __call__( 45 | self, 46 | embeddings: jax.Array, # [B, T, D] 47 | mask: jax.Array, # [B, T] 48 | *, 49 | use_dropout: bool = True, 50 | embedding_size: Optional[int] = None, 51 | unembed_at_every_layer: bool = False, 52 | ) -> model.TransformerOutput: # [B, T, D] 53 | """Transforms input embedding sequences to output embedding sequences. 54 | 55 | Args: 56 | embeddings: Input embeddings to pass through the model. 
57 | mask: Boolean mask to restrict the inputs the model uses. 58 | use_dropout: Turns dropout on/off.
59 | embedding_size: Dimension to compress the residual stream to. 60 | unembed_at_every_layer: Whether to unembed the residual stream when 61 | reading the input for every layer (keeping the layer input sizes) or to 62 | only unembed before the model output (compressing the layer inputs). 63 | 64 | Returns: 65 | The outputs of the forward pass through the transformer. 66 | """ 67 | 68 | def layer_norm(x: jax.Array) -> jax.Array: 69 | """Applies a unique LayerNorm to x with default settings.""" 70 | if self.config.layer_norm: 71 | return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x) 72 | return x 73 | 74 | initializer = hk.initializers.VarianceScaling(2 / self.config.num_layers) 75 | dropout_rate = self.config.dropout_rate if use_dropout else 0. 76 | _, seq_len, model_size = embeddings.shape 77 | 78 | # To compress the model, we multiply with a matrix W when reading from 79 | # the residual stream, and with W^T when writing to the residual stream. 80 | if embedding_size is not None: 81 | # [to_size, from_size] 82 | w_emb = hk.get_parameter( 83 | "w_emb", (embedding_size, model_size), 84 | init=hk.initializers.RandomNormal()) 85 | 86 | write_to_residual = lambda x: x @ w_emb.T 87 | read_from_residual = lambda x: x @ w_emb 88 | 89 | if not unembed_at_every_layer: 90 | model_size = embedding_size 91 | else: 92 | write_to_residual = lambda x: x 93 | read_from_residual = lambda x: x 94 | 95 | # Compute causal mask for autoregressive sequence modelling. 96 | mask = mask[:, None, None, :] # [B, H=1, T'=1, T] 97 | mask = mask.repeat(seq_len, axis=2) # [B, H=1, T, T] 98 | 99 | if self.config.causal: 100 | causal_mask = np.ones((1, 1, seq_len, seq_len)) # [B=1, H=1, T, T] 101 | causal_mask = np.tril(causal_mask) 102 | mask = mask * causal_mask # [B, H=1, T, T] 103 | 104 | # Set up activation collection. 
105 | collected = collections.defaultdict(list) 106 | 107 | def collect(**kwargs): 108 | for k, v in kwargs.items(): 109 | collected[k].append(v) 110 | 111 | residual = write_to_residual(embeddings) 112 | 113 | for layer in range(self.config.num_layers): 114 | with hk.experimental.name_scope(f"layer_{layer}"): 115 | # First the attention block. 116 | attn_block = attention.MultiHeadAttention( 117 | num_heads=self.config.num_heads, 118 | key_size=self.config.key_size, 119 | model_size=model_size, 120 | w_init=initializer, 121 | name="attn") 122 | 123 | attn_in = residual 124 | if unembed_at_every_layer: 125 | attn_in = read_from_residual(attn_in) 126 | attn_in = layer_norm(attn_in) 127 | attn_out = attn_block(attn_in, attn_in, attn_in, mask=mask) 128 | attn_out, attn_logits = attn_out.out, attn_out.logits 129 | if dropout_rate > 0: 130 | attn_out = hk.dropout(hk.next_rng_key(), dropout_rate, attn_out) 131 | 132 | if unembed_at_every_layer: 133 | collect(layer_outputs=attn_out, attn_logits=attn_logits) 134 | else: 135 | collect( 136 | layer_outputs=read_from_residual(attn_out), 137 | attn_logits=attn_logits, 138 | ) 139 | 140 | if unembed_at_every_layer: 141 | attn_out = write_to_residual(attn_out) 142 | residual = residual + attn_out 143 | 144 | collect(residuals=residual) 145 | 146 | # Then the dense block. 
147 | with hk.experimental.name_scope("mlp"): 148 | dense_block = hk.Sequential([ 149 | hk.Linear( 150 | self.config.mlp_hidden_size, 151 | w_init=initializer, 152 | name="linear_1"), 153 | self.config.activation_function, 154 | hk.Linear(model_size, w_init=initializer, name="linear_2"), 155 | ]) 156 | 157 | dense_in = residual 158 | if unembed_at_every_layer: 159 | dense_in = read_from_residual(dense_in) 160 | dense_in = layer_norm(dense_in) 161 | dense_out = dense_block(dense_in) 162 | if dropout_rate > 0: 163 | dense_out = hk.dropout(hk.next_rng_key(), dropout_rate, dense_out) 164 | 165 | if unembed_at_every_layer: 166 | collect(layer_outputs=dense_out) 167 | else: 168 | collect(layer_outputs=read_from_residual(dense_out)) 169 | 170 | if unembed_at_every_layer: 171 | dense_out = write_to_residual(dense_out) 172 | residual = residual + dense_out 173 | 174 | collect(residuals=residual) 175 | 176 | output = read_from_residual(residual) 177 | output = layer_norm(output) 178 | 179 | return model.TransformerOutput( 180 | layer_outputs=collected["layer_outputs"], 181 | residuals=collected["residuals"], 182 | attn_logits=collected["attn_logits"], 183 | output=output, 184 | input_embeddings=embeddings, 185 | ) 186 | -------------------------------------------------------------------------------- /tracr/transformer/encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Basic encoder for inputs with a fixed vocabulary.""" 16 | 17 | import abc 18 | from typing import Any, List, Optional, Sequence 19 | 20 | from tracr.craft import bases 21 | 22 | 23 | class Encoder(abc.ABC): 24 | """Encodes a list of tokens into a list of inputs for a transformer model. 25 | 26 | The abstract class does not make assumptions on the input and output types, 27 | and we have different encoders for different input types. 28 | """ 29 | 30 | @abc.abstractmethod 31 | def encode(self, inputs: List[Any]) -> List[Any]: 32 | return list() 33 | 34 | @abc.abstractmethod 35 | def decode(self, encodings: List[Any]) -> List[Any]: 36 | return list() 37 | 38 | @property 39 | def pad_token(self) -> Optional[str]: 40 | return None 41 | 42 | @property 43 | def bos_token(self) -> Optional[str]: 44 | return None 45 | 46 | @property 47 | def pad_encoding(self) -> Optional[int]: 48 | return None 49 | 50 | @property 51 | def bos_encoding(self) -> Optional[int]: 52 | return None 53 | 54 | 55 | class NumericalEncoder(Encoder): 56 | """Encodes numerical variables (simply using the identity mapping).""" 57 | 58 | def encode(self, inputs: List[float]) -> List[float]: 59 | return inputs 60 | 61 | def decode(self, encodings: List[float]) -> List[float]: 62 | return encodings 63 | 64 | 65 | class CategoricalEncoder(Encoder): 66 | """Encodes categorical variables with a fixed vocabulary.""" 67 | 68 | def __init__( 69 | self, 70 | basis: Sequence[bases.BasisDirection], 71 | enforce_bos: bool = False, 72 | bos_token: Optional[str] = None, 73 | pad_token: Optional[str] = None, 74 | max_seq_len: Optional[int] = None, 75 | ): 76 | """Initialises. 
If enforce_bos is set, ensures inputs start with it.""" 77 | if enforce_bos and not bos_token: 78 | raise ValueError("BOS token must be specified if enforcing BOS.") 79 | 80 | self.encoding_map = {} 81 | for i, direction in enumerate(basis): 82 | val = direction.value 83 | self.encoding_map[val] = i 84 | 85 | if bos_token and bos_token not in self.encoding_map: 86 | raise ValueError("BOS token missing in encoding.") 87 | 88 | if pad_token and pad_token not in self.encoding_map: 89 | raise ValueError("PAD token missing in encoding.") 90 | 91 | self.enforce_bos = enforce_bos 92 | self._bos_token = bos_token 93 | self._pad_token = pad_token 94 | self._max_seq_len = max_seq_len 95 | 96 | def encode(self, inputs: List[bases.Value]) -> List[int]: 97 | if self.enforce_bos and inputs[0] != self.bos_token: 98 | raise ValueError("First input token must be BOS token. " 99 | f"Should be '{self.bos_token}', but was '{inputs[0]}'.") 100 | if missing := set(inputs) - set(self.encoding_map.keys()): 101 | raise ValueError(f"Inputs {missing} not found in encoding " 102 | f"{list(self.encoding_map.keys())}") 103 | if self._max_seq_len is not None and len(inputs) > self._max_seq_len: 104 | raise ValueError(f"inputs={inputs} are longer than the maximum " 105 | f"sequence length {self._max_seq_len}") 106 | 107 | return [self.encoding_map[x] for x in inputs] 108 | 109 | def decode(self, encodings: List[int]) -> List[bases.Value]: 110 | """Recover the tokens that correspond to `encodings`.
Inverse of `encode`.""" 111 | decoding_map = {val: key for key, val in self.encoding_map.items()} 112 | if missing := set(encodings) - set(decoding_map.keys()): 113 | raise ValueError(f"Inputs {missing} not found in decoding map " 114 | f"{list(decoding_map.keys())}") 115 | return [decoding_map[x] for x in encodings] 116 | 117 | @property 118 | def vocab_size(self) -> int: 119 | return len(self.encoding_map) 120 | 121 | @property 122 | def bos_token(self) -> Optional[str]: 123 | return self._bos_token 124 | 125 | @property 126 | def pad_token(self) -> Optional[str]: 127 | return self._pad_token 128 | 129 | @property 130 | def bos_encoding(self) -> Optional[int]: 131 | return None if self.bos_token is None else self.encoding_map[self.bos_token] 132 | 133 | @property 134 | def pad_encoding(self) -> Optional[int]: 135 | return None if self.pad_token is None else self.encoding_map[self.pad_token] 136 | -------------------------------------------------------------------------------- /tracr/transformer/encoder_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | """Tests for transformer.encoder.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from tracr.craft import bases 20 | from tracr.transformer import encoder 21 | 22 | _BOS_TOKEN = "bos_encoder_test" 23 | _PAD_TOKEN = "pad_encoder_test" 24 | 25 | 26 | class CategoricalEncoderTest(parameterized.TestCase): 27 | 28 | def test_encode_raises_value_error_if_input_doesnt_start_with_bos(self): 29 | vs = bases.VectorSpaceWithBasis.from_values("input", {1, 2, 3, _BOS_TOKEN}) 30 | basic_encoder = encoder.CategoricalEncoder( 31 | vs.basis, enforce_bos=True, bos_token=_BOS_TOKEN) 32 | with self.assertRaisesRegex(ValueError, 33 | r"^.*First input token must be BOS token.*$"): 34 | basic_encoder.encode([1, 1, 1]) 35 | 36 | def test_encode_raises_value_error_if_input_not_in_vocab(self): 37 | vs = bases.VectorSpaceWithBasis.from_values("input", {1, 2, 3, _BOS_TOKEN}) 38 | basic_encoder = encoder.CategoricalEncoder( 39 | vs.basis, enforce_bos=True, bos_token=_BOS_TOKEN) 40 | with self.assertRaisesRegex(ValueError, 41 | r"^.*Inputs .* not found in encoding.*$"): 42 | basic_encoder.encode([_BOS_TOKEN, 1, 2, 3, 4]) 43 | 44 | def test_decode_raises_value_error_if_id_outside_of_vocab_size(self): 45 | vs = bases.VectorSpaceWithBasis.from_values("input", {1, 2, _BOS_TOKEN}) 46 | basic_encoder = encoder.CategoricalEncoder( 47 | vs.basis, enforce_bos=True, bos_token=_BOS_TOKEN) 48 | with self.assertRaisesRegex(ValueError, 49 | r"^.*Inputs .* not found in decoding map.*$"): 50 | basic_encoder.decode([0, 1, 2, 3]) 51 | 52 | def test_encoder_raises_value_error_if_bos_not_in_basis(self): 53 | vs = bases.VectorSpaceWithBasis.from_values("input", {1, 2, 3}) 54 | with self.assertRaisesRegex(ValueError, 55 | r"^.*BOS token missing in encoding.*$"): 56 | unused_basic_encoder = encoder.CategoricalEncoder( 57 | vs.basis, bos_token=_BOS_TOKEN) 58 | 59 | def 
test_encoder_raises_value_error_if_pad_not_in_basis(self): 60 | vs = bases.VectorSpaceWithBasis.from_values("input", {1, 2, 3}) 61 | with self.assertRaisesRegex(ValueError, 62 | r"^.*PAD token missing in encoding.*$"): 63 | unused_basic_encoder = encoder.CategoricalEncoder( 64 | vs.basis, pad_token=_PAD_TOKEN) 65 | 66 | def test_encoder_encodes_bos_and_pad_tokens_as_expected(self): 67 | vs = bases.VectorSpaceWithBasis.from_values( 68 | "input", {1, 2, 3, _BOS_TOKEN, _PAD_TOKEN}) 69 | basic_encoder = encoder.CategoricalEncoder( 70 | vs.basis, bos_token=_BOS_TOKEN, pad_token=_PAD_TOKEN) 71 | self.assertEqual( 72 | basic_encoder.encode([_BOS_TOKEN, _PAD_TOKEN]), 73 | [basic_encoder.bos_encoding, basic_encoder.pad_encoding]) 74 | 75 | @parameterized.parameters([ 76 | dict( 77 | vocab={1, 2, 3, _BOS_TOKEN}, # lexicographic order 78 | inputs=[_BOS_TOKEN, 3, 2, 1], 79 | expected=[3, 2, 1, 0]), 80 | dict( 81 | vocab={"a", "b", _BOS_TOKEN, "c"}, # lexicographic order 82 | inputs=[_BOS_TOKEN, "b", "b", "c"], 83 | expected=[2, 1, 1, 3]), 84 | ]) 85 | def test_tokens_are_encoded_in_lexicographic_order(self, vocab, inputs, 86 | expected): 87 | # Expect encodings to be assigned to ids according to a lexicographic 88 | # ordering of the vocab 89 | vs = bases.VectorSpaceWithBasis.from_values("input", vocab) 90 | basic_encoder = encoder.CategoricalEncoder( 91 | vs.basis, enforce_bos=True, bos_token=_BOS_TOKEN) 92 | encodings = basic_encoder.encode(inputs) 93 | self.assertEqual(encodings, expected) 94 | 95 | @parameterized.parameters([ 96 | dict(vocab={_BOS_TOKEN, _PAD_TOKEN, 1, 2, 3}, expected=5), 97 | dict(vocab={_BOS_TOKEN, _PAD_TOKEN, "a", "b"}, expected=4), 98 | ]) 99 | def test_vocab_size_has_expected_value(self, vocab, expected): 100 | vs = bases.VectorSpaceWithBasis.from_values("input", vocab) 101 | basic_encoder = encoder.CategoricalEncoder( 102 | vs.basis, enforce_bos=True, bos_token=_BOS_TOKEN, pad_token=_PAD_TOKEN) 103 | self.assertEqual(basic_encoder.vocab_size, 
expected) 104 | 105 | @parameterized.parameters([ 106 | dict( 107 | vocab={_BOS_TOKEN, _PAD_TOKEN, 1, 2, 3}, inputs=[_BOS_TOKEN, 3, 2, 108 | 1]), 109 | dict( 110 | vocab={_BOS_TOKEN, _PAD_TOKEN, "a", "b", "c"}, 111 | inputs=[_BOS_TOKEN, "b", "b", "c"]), 112 | ]) 113 | def test_decode_inverts_encode(self, vocab, inputs): 114 | vs = bases.VectorSpaceWithBasis.from_values("input", vocab) 115 | basic_encoder = encoder.CategoricalEncoder( 116 | vs.basis, enforce_bos=True, bos_token=_BOS_TOKEN, pad_token=_PAD_TOKEN) 117 | encodings = basic_encoder.encode(inputs) 118 | recovered = basic_encoder.decode(encodings) 119 | self.assertEqual(recovered, inputs) 120 | 121 | 122 | if __name__ == "__main__": 123 | absltest.main() 124 | -------------------------------------------------------------------------------- /tracr/transformer/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Didactic example of an autoregressive Transformer-based language model. 16 | 17 | Glossary of shapes: 18 | - B: Batch size. 19 | - T: Sequence length. 20 | - D: Model embedding size. 21 | - H: Number of attention heads. 22 | - V: Vocabulary size. 
23 | 24 | Forked from: haiku.examples.transformer.model 25 | """ 26 | 27 | import collections 28 | import dataclasses 29 | from typing import Callable, List, Optional 30 | 31 | import chex 32 | import haiku as hk 33 | import jax 34 | import jax.numpy as jnp 35 | import numpy as np 36 | from tracr.transformer import attention 37 | 38 | # hk.Modules are not always callable: github.com/deepmind/dm-haiku/issues/52 39 | # Ideally, we'd want a type: 40 | # CallableHaikuModule = Intersection[Callable[..., jax.Array], hk.Module] 41 | # But Intersection does not exist (yet): github.com/python/typing/issues/213 42 | CallableHaikuModule = Callable[..., jax.Array] 43 | 44 | 45 | @chex.dataclass 46 | class TransformerOutput: 47 | layer_outputs: List[jax.Array] # [B, T, D] 48 | residuals: List[jax.Array] # [B, T, D] 49 | attn_logits: List[jax.Array] # [B, H, T, T] 50 | output: jax.Array # [B, T, D] 51 | input_embeddings: jax.Array # [B, T, D] 52 | 53 | 54 | @dataclasses.dataclass 55 | class TransformerConfig: 56 | num_heads: int 57 | num_layers: int 58 | key_size: int 59 | mlp_hidden_size: int 60 | dropout_rate: float 61 | activation_function: Callable[[jax.Array], jax.Array] = jax.nn.gelu 62 | layer_norm: bool = True 63 | causal: bool = False 64 | 65 | 66 | @dataclasses.dataclass 67 | class Transformer(hk.Module): 68 | """A transformer stack.""" 69 | 70 | config: TransformerConfig 71 | name: Optional[str] = None 72 | 73 | def __call__( 74 | self, 75 | embeddings: jax.Array, # [B, T, D] 76 | mask: jax.Array, # [B, T] 77 | *, 78 | use_dropout: bool = True, 79 | ) -> TransformerOutput: 80 | """Transforms input embedding sequences to output embedding sequences.""" 81 | 82 | def layer_norm(x: jax.Array) -> jax.Array: 83 | """Applies a unique LayerNorm to x with default settings.""" 84 | if self.config.layer_norm: 85 | return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x) 86 | return x 87 | 88 | initializer = hk.initializers.VarianceScaling(2 / 
self.config.num_layers) 89 | dropout_rate = self.config.dropout_rate if use_dropout else 0. 90 | _, seq_len, model_size = embeddings.shape 91 | 92 | # Compute causal mask for autoregressive sequence modelling. 93 | mask = mask[:, None, None, :] # [B, H=1, T'=1, T] 94 | mask = mask.repeat(seq_len, axis=2) # [B, H=1, T, T] 95 | 96 | if self.config.causal: 97 | causal_mask = np.ones((1, 1, seq_len, seq_len)) # [B=1, H=1, T, T] 98 | causal_mask = np.tril(causal_mask) 99 | mask = mask * causal_mask # [B, H=1, T, T] 100 | 101 | # Set up activation collection. 102 | collected = collections.defaultdict(list) 103 | 104 | def collect(**kwargs): 105 | for k, v in kwargs.items(): 106 | collected[k].append(v) 107 | 108 | residual = embeddings 109 | for layer in range(self.config.num_layers): 110 | with hk.experimental.name_scope(f"layer_{layer}"): 111 | # First the attention block. 112 | attn_block = attention.MultiHeadAttention( 113 | num_heads=self.config.num_heads, 114 | key_size=self.config.key_size, 115 | model_size=model_size, 116 | w_init=initializer, 117 | name="attn") 118 | attn_in = layer_norm(residual) 119 | attn_out = attn_block(attn_in, attn_in, attn_in, mask=mask) 120 | attn_out, attn_logits = attn_out.out, attn_out.logits 121 | if dropout_rate > 0: 122 | attn_out = hk.dropout(hk.next_rng_key(), dropout_rate, attn_out) 123 | residual = residual + attn_out 124 | 125 | collect( 126 | residuals=residual, layer_outputs=attn_out, attn_logits=attn_logits) 127 | 128 | # Then the dense block. 
129 | with hk.experimental.name_scope("mlp"): 130 | dense_block = hk.Sequential([ 131 | hk.Linear( 132 | self.config.mlp_hidden_size, 133 | w_init=initializer, 134 | name="linear_1"), 135 | self.config.activation_function, 136 | hk.Linear(model_size, w_init=initializer, name="linear_2"), 137 | ]) 138 | dense_in = layer_norm(residual) 139 | dense_out = dense_block(dense_in) 140 | if dropout_rate > 0: 141 | dense_out = hk.dropout(hk.next_rng_key(), dropout_rate, dense_out) 142 | residual = residual + dense_out 143 | 144 | collect(residuals=residual, layer_outputs=dense_out) 145 | 146 | return TransformerOutput( 147 | residuals=collected["residuals"], 148 | layer_outputs=collected["layer_outputs"], 149 | attn_logits=collected["attn_logits"], 150 | output=layer_norm(residual), 151 | input_embeddings=embeddings, 152 | ) 153 | 154 | 155 | @chex.dataclass 156 | class CompiledTransformerModelOutput: 157 | transformer_output: TransformerOutput 158 | unembedded_output: jax.Array # [B, T] 159 | 160 | 161 | @dataclasses.dataclass 162 | class CompiledTransformerModel(hk.Module): 163 | """A transformer model with one-hot embeddings.""" 164 | transformer: Transformer 165 | token_embed: CallableHaikuModule 166 | position_embed: CallableHaikuModule 167 | unembed: CallableHaikuModule 168 | use_unembed_argmax: bool 169 | pad_token: Optional[int] = None 170 | 171 | def embed(self, tokens: jax.Array) -> jax.Array: 172 | token_embeddings = self.token_embed(tokens) 173 | positional_embeddings = self.position_embed(jnp.indices(tokens.shape)[-1]) 174 | return token_embeddings + positional_embeddings # [B, T, D] 175 | 176 | def __call__( 177 | self, 178 | tokens: jax.Array, 179 | use_dropout: bool = True, 180 | ) -> CompiledTransformerModelOutput: 181 | """Embed tokens, pass through model, and unembed output.""" 182 | if self.pad_token is None: 183 | input_mask = jnp.ones_like(tokens) 184 | else: 185 | input_mask = (tokens != self.pad_token) 186 | input_embeddings = self.embed(tokens) 187 | 
188 | transformer_output = self.transformer( 189 | input_embeddings, 190 | input_mask, 191 | use_dropout=use_dropout, 192 | ) 193 | return CompiledTransformerModelOutput( 194 | transformer_output=transformer_output, 195 | unembedded_output=self.unembed( 196 | transformer_output.output, 197 | use_unembed_argmax=self.use_unembed_argmax, 198 | ), 199 | ) 200 | -------------------------------------------------------------------------------- /tracr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /tracr/utils/debugging.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Useful helpers for model debugging.""" 16 | 17 | 18 | def print_arrays(arrays, labels=None, colwidth=12): 19 | """Pretty-prints a list of [1, T, D] arrays.""" 20 | if labels is not None: 21 | print(" |".join(labels)) 22 | widths = [len(l) for l in labels] 23 | else: 24 | widths = [colwidth] * arrays[0].shape[-1] 25 | for layer in arrays: 26 | print("=" * (colwidth + 1) * layer.shape[-1]) 27 | for row in layer[0]: 28 | print(" |".join([f"{x:<{width}.2f}" for x, width in zip(row, widths)])) 29 | -------------------------------------------------------------------------------- /tracr/utils/errors.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | """Helpers for handling errors in user-provided functions.""" 16 | 17 | import functools 18 | import logging 19 | from typing import Any, Callable 20 | 21 | 22 | def ignoring_arithmetic_errors(fun: Callable[..., Any]) -> Callable[..., Any]: 23 | """Makes fun return None instead of raising ArithmeticError.""" 24 | 25 | @functools.wraps(fun) 26 | def fun_wrapped(*args): 27 | try: 28 | return fun(*args) 29 | except ArithmeticError: 30 | logging.warning( 31 | "Encountered arithmetic error in function for value %s. " 32 | "Assuming this input will never occur.", str(args)) 33 | return None 34 | 35 | return fun_wrapped 36 | -------------------------------------------------------------------------------- /tracr/utils/errors_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | """Tests for utils.errors.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from tracr.utils import errors 20 | 21 | 22 | class FunIgnoreArithmeticErrorsTest(parameterized.TestCase): 23 | 24 | def test_ignoring_arithmetic_errors(self): 25 | fun = lambda x: 1 / x 26 | fun_ignore = errors.ignoring_arithmetic_errors(fun) 27 | 28 | with self.assertLogs(level="WARNING"): 29 | res = fun_ignore(0) 30 | self.assertIs(res, None) 31 | 32 | self.assertEqual(fun_ignore(1), 1) 33 | self.assertEqual(fun_ignore(2), 0.5) 34 | self.assertEqual(fun_ignore(-2), -0.5) 35 | 36 | def test_ignoring_arithmetic_errors_two_arguments(self): 37 | fun = lambda x, y: 1 / x + 1 / y 38 | fun_ignore = errors.ignoring_arithmetic_errors(fun) 39 | 40 | with self.assertLogs(level="WARNING"): 41 | res = fun_ignore(0, 1) 42 | self.assertIs(res, None) 43 | 44 | with self.assertLogs(level="WARNING"): 45 | res = fun_ignore(0, 0) 46 | self.assertIs(res, None) 47 | 48 | with self.assertLogs(level="WARNING"): 49 | res = fun_ignore(1, 0) 50 | self.assertIs(res, None) 51 | 52 | self.assertEqual(fun_ignore(1, 1), 2) 53 | self.assertEqual(fun_ignore(1, 2), 1.5) 54 | self.assertEqual(fun_ignore(-2, 2), 0) 55 | 56 | 57 | if __name__ == "__main__": 58 | absltest.main() 59 | --------------------------------------------------------------------------------
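The attention mask in `Transformer.__call__` (and in the compressed model) is built by broadcasting a `[B, T]` padding mask over query positions and, when `config.causal` is set, multiplying it by a lower-triangular mask. The NumPy sketch below reproduces only that mask arithmetic outside Haiku so the shapes can be checked by hand; `build_attention_mask` is a hypothetical helper, not part of tracr.

```python
import numpy as np


def build_attention_mask(pad_mask: np.ndarray, causal: bool) -> np.ndarray:
  """Turns a [B, T] padding mask into a [B, H=1, T, T] attention mask.

  Hypothetical helper mirroring the mask logic in Transformer.__call__:
  padded key positions are masked for every query position and, if `causal`
  is set, each query position is also blocked from attending to later keys.
  """
  seq_len = pad_mask.shape[-1]
  mask = pad_mask[:, None, None, :]  # [B, 1, 1, T]
  mask = mask.repeat(seq_len, axis=2)  # [B, 1, T, T]
  if causal:
    # np.tril acts on the last two axes of a 4D array.
    causal_mask = np.tril(np.ones((1, 1, seq_len, seq_len)))  # [1, 1, T, T]
    mask = mask * causal_mask
  return mask


# One sequence of length 3 whose last position is padding.
pad_mask = np.array([[1, 1, 0]])
print(build_attention_mask(pad_mask, causal=False)[0, 0])
print(build_attention_mask(pad_mask, causal=True)[0, 0])
```

Row `i` of each printed matrix marks the key positions query `i` may attend to: the padded third column is always zero, and the causal version additionally zeroes everything above the diagonal.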
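In `compressed_model.py`, activations are written to the residual stream with `x @ w_emb.T` and read back with `x @ w_emb`, where `w_emb` has shape `[embedding_size, model_size]`. The NumPy sketch below isolates that linear algebra with a fixed matrix; in the real model `w_emb` is a learned Haiku parameter initialised with `RandomNormal`. Building `W` with orthonormal rows is an assumption made here purely for illustration: it makes reading and then writing an exact round trip for compressed vectors, whereas writing and then reading only projects a full-size vector onto the row space of `W`.

```python
import numpy as np

model_size, embedding_size = 8, 3
rng = np.random.default_rng(0)

# Illustrative choice: QR of a [model_size, embedding_size] matrix gives Q with
# orthonormal columns, so W = Q.T has orthonormal rows (W @ W.T = I).
q, _ = np.linalg.qr(rng.normal(size=(model_size, embedding_size)))
w_emb = q.T  # [embedding_size, model_size], i.e. [to_size, from_size]


def write_to_residual(x: np.ndarray) -> np.ndarray:
  """Compresses [..., model_size] activations into the residual stream."""
  return x @ w_emb.T  # [..., embedding_size]


def read_from_residual(x: np.ndarray) -> np.ndarray:
  """Decompresses residual-stream vectors back to [..., model_size]."""
  return x @ w_emb  # [..., model_size]


embeddings = rng.normal(size=(2, 5, model_size))  # [B, T, D]
residual = write_to_residual(embeddings)
print(residual.shape)  # (2, 5, 3)
print(read_from_residual(residual).shape)  # (2, 5, 8)
```

This also shows why the source resets `model_size = embedding_size` when `unembed_at_every_layer` is false: the layers then consume the compressed residual directly, while with `unembed_at_every_layer=True` each layer first reads a decompressed input of the original size.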
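`CategoricalEncoder` gives each token the index of its direction in the basis it is constructed with (`enumerate(basis)`), so ids follow the basis order, and `decode` inverts that map. The stdlib-only sketch below mirrors that bookkeeping with a plain list of vocabulary values in place of `bases.BasisDirection` objects; `ToyCategoricalEncoder` is a hypothetical name, and this is a simplified illustration rather than tracr's class (it omits PAD handling and the maximum sequence length check).

```python
from typing import Dict, Hashable, List, Optional, Sequence


class ToyCategoricalEncoder:
  """Maps each token to its position in `vocab`, and back."""

  def __init__(self,
               vocab: Sequence[Hashable],
               bos_token: Optional[Hashable] = None,
               enforce_bos: bool = False):
    if enforce_bos and bos_token is None:
      raise ValueError("BOS token must be specified if enforcing BOS.")
    # Ids follow the order of `vocab`, like enumerate(basis) in the source.
    self.encoding_map: Dict[Hashable, int] = {
        v: i for i, v in enumerate(vocab)}
    if bos_token is not None and bos_token not in self.encoding_map:
      raise ValueError("BOS token missing in encoding.")
    self.decoding_map = {i: v for v, i in self.encoding_map.items()}
    self.bos_token = bos_token
    self.enforce_bos = enforce_bos

  def encode(self, inputs: List[Hashable]) -> List[int]:
    if self.enforce_bos and (not inputs or inputs[0] != self.bos_token):
      raise ValueError(f"First input token must be {self.bos_token!r}.")
    if missing := set(inputs) - set(self.encoding_map):
      raise ValueError(f"Inputs {missing} not found in encoding.")
    return [self.encoding_map[x] for x in inputs]

  def decode(self, encodings: List[int]) -> List[Hashable]:
    if missing := set(encodings) - set(self.decoding_map):
      raise ValueError(f"Inputs {missing} not found in decoding map.")
    return [self.decoding_map[i] for i in encodings]


enc = ToyCategoricalEncoder(
    ["a", "b", "BOS", "c"], bos_token="BOS", enforce_bos=True)
ids = enc.encode(["BOS", "b", "b", "c"])
print(ids)  # [2, 1, 1, 3]
print(enc.decode(ids))  # ['BOS', 'b', 'b', 'c']
```

Building the decoding map once in `__init__` is a small design change from the source, which rebuilds it on every `decode` call; both give the same result.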