├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── chomsky.svg ├── experiments ├── constants.py ├── curriculum.py ├── example.py ├── range_evaluation.py ├── training.py └── utils.py ├── models ├── ndstack_rnn.py ├── positional_encodings.py ├── rnn.py ├── stack_rnn.py ├── tape_rnn.py └── transformer.py ├── requirements.txt └── tasks ├── cs ├── binary_addition.py ├── binary_multiplication.py ├── bucket_sort.py ├── compute_sqrt.py ├── duplicate_string.py ├── missing_duplicate_string.py └── odds_first.py ├── dcf ├── modular_arithmetic_brackets.py ├── reverse_string.py ├── solve_equation.py └── stack_manipulation.py ├── regular ├── cycle_navigation.py ├── even_pairs.py ├── modular_arithmetic.py └── parity_check.py └── task.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Distribution / packaging 7 | .Python 8 | build/ 9 | develop-eggs/ 10 | dist/ 11 | downloads/ 12 | eggs/ 13 | .eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | wheels/ 20 | share/python-wheels/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | MANIFEST 25 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | ## Contributor License Agreement 4 | 5 | Contributions to this project must be accompanied by a Contributor License 6 | Agreement. You (or your employer) retain the copyright to your contribution, 7 | this simply gives us permission to use and redistribute your contributions as 8 | part of the project. Head over to to see 9 | your current agreements on file or to sign a new one. 10 | 11 | You generally only need to submit a CLA once, so if you've already submitted one 12 | (even if it was for a different project), you probably don't need to do it 13 | again. 14 | 15 | ## Code reviews 16 | 17 | All submissions, including submissions by project members, require review. We 18 | use GitHub pull requests for this purpose. Consult 19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 20 | information on using pull requests. 21 | 22 | ## Community Guidelines 23 | 24 | This project follows [Google's Open Source Community 25 | Guidelines](https://opensource.google/conduct/). 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural Networks and the Chomsky Hierarchy 2 | 3 |

4 | [Overview figure: chomsky.svg] 5 |

6 |
7 | This repository provides an implementation of our ICLR 2023 paper [Neural Networks and the Chomsky Hierarchy](https://arxiv.org/abs/2207.02098).
8 |
9 | > Reliable generalization lies at the heart of safe ML and AI.
10 | However, understanding when and how neural networks generalize remains one of the most important unsolved problems in the field.
11 | In this work, we conduct an extensive empirical study (2200 models, 16 tasks) to investigate whether insights from the theory of computation can predict the limits of neural network generalization in practice.
12 | We demonstrate that grouping tasks according to the Chomsky hierarchy allows us to forecast whether certain architectures will be able to generalize to out-of-distribution inputs.
13 | This includes negative results where even extensive amounts of data and training time never led to any non-trivial generalization, despite models having sufficient capacity to perfectly fit the training data.
14 | Our results show that, for our subset of tasks, RNNs and Transformers fail to generalize on non-regular tasks, LSTMs can solve regular and counter-language tasks, and only networks augmented with structured memory (such as a stack or memory tape) can successfully generalize on context-free and context-sensitive tasks.
15 |
16 | It is based on [JAX](https://jax.readthedocs.io) and [Haiku](https://dm-haiku.readthedocs.io) and contains all code, datasets, and models necessary to reproduce the paper's results.
17 |
18 |
19 | ## Content
20 |
21 | ```
22 | .
23 | ├── models
24 | |   ├── ndstack_rnn.py - Nondeterministic Stack-RNN (DuSell & Chiang, 2021)
25 | |   ├── rnn.py - RNN (Elman, 1990)
26 | |   ├── stack_rnn.py - Stack-RNN (Joulin & Mikolov, 2015)
27 | |   ├── tape_rnn.py - Tape-RNN, loosely based on Baby-NTM (Suzgun et al., 2019)
28 | |   └── transformer.py - Transformer (Vaswani et al., 2017)
29 | ├── tasks
30 | |   ├── cs - Context-sensitive tasks
31 | |   ├── dcf - Deterministic context-free tasks
32 | |   ├── regular - Regular tasks
33 | |   └── task.py - Abstract GeneralizationTask
34 | ├── experiments
35 | |   ├── constants.py - Training/Evaluation constants
36 | |   ├── curriculum.py - Training curricula (over sequence lengths)
37 | |   ├── example.py - Example training script (RNN on the Even Pairs task)
38 | |   ├── range_evaluation.py - Evaluation loop (over unseen sequence lengths)
39 | |   ├── training.py - Training loop
40 | |   └── utils.py - Utility functions
41 | ├── README.md
42 | └── requirements.txt - Dependencies
43 | ```
44 |
45 | `tasks` contains all tasks, organized in their Chomsky hierarchy levels (regular, dcf, cs).
46 | They all inherit the abstract class `GeneralizationTask`, defined in `tasks/task.py`.
47 |
48 | `models` contains all the models we use, written in [jax](https://github.com/google/jax) and [haiku](https://github.com/deepmind/dm-haiku), two open source libraries.
49 |
50 | `experiments` contains the code for training models and evaluating them on a wide range of lengths.
51 | We also include an example to train and evaluate an RNN on the Even Pairs task.
52 | We use [optax](https://github.com/deepmind/optax) for our optimizers.
53 |
54 |
55 | ## Installation
56 |
57 | Clone the source code into a local directory:
58 | ```bash
59 | git clone https://github.com/google-deepmind/neural_networks_chomsky_hierarchy.git
60 | cd neural_networks_chomsky_hierarchy
61 | ```
62 |
63 | `pip install -r requirements.txt` will install all required dependencies.
64 | This is best done inside a [conda environment](https://www.anaconda.com/). 65 | To that end, install [Anaconda](https://www.anaconda.com/download#downloads). 66 | Then, create and activate the conda environment: 67 | ```bash 68 | conda create --name nnch 69 | conda activate nnch 70 | ``` 71 | 72 | Install `pip` and use it to install all the dependencies: 73 | ```bash 74 | conda install pip 75 | pip install -r requirements.txt 76 | ``` 77 | 78 | If you have a GPU available (highly recommended for fast training), then you can install JAX with CUDA support. 79 | ```bash 80 | pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 81 | ``` 82 | Note that the jax version must correspond to the existing CUDA installation you wish to use (CUDA 12 in the example above). 83 | Please see the [JAX documentation](https://github.com/google/jax#installation) for more details. 84 | 85 | 86 | 87 | 88 | 89 | ## Usage 90 | 91 | Before running any code, make sure to activate the conda environment and set the `PYTHONPATH`: 92 | ```bash 93 | conda activate nnch 94 | export PYTHONPATH=$(pwd)/.. 95 | ``` 96 | 97 | We provide an example of a training and evaluation run at: 98 | ```bash 99 | python experiments/example.py 100 | ``` 101 | 102 | 103 | ## Citing This Work 104 | 105 | ```bibtex 106 | @inproceedings{deletang2023neural, 107 | author = {Gr{\'{e}}goire Del{\'{e}}tang and 108 | Anian Ruoss and 109 | Jordi Grau{-}Moya and 110 | Tim Genewein and 111 | Li Kevin Wenliang and 112 | Elliot Catt and 113 | Chris Cundy and 114 | Marcus Hutter and 115 | Shane Legg and 116 | Joel Veness and 117 | Pedro A. Ortega}, 118 | title = {Neural Networks and the Chomsky Hierarchy}, 119 | booktitle = {11th International Conference on Learning Representations}, 120 | year = {2023}, 121 | } 122 | ``` 123 | 124 | 125 | ## License and Disclaimer 126 | 127 | Copyright 2022 DeepMind Technologies Limited 128 | 129 | All software is licensed under the Apache License, Version 2.0 (Apache 2.0); 130 | you may not use this file except in compliance with the Apache 2.0 license. 131 | You may obtain a copy of the Apache 2.0 license at: 132 | https://www.apache.org/licenses/LICENSE-2.0 133 | 134 | All other materials are licensed under the Creative Commons Attribution 4.0 135 | International License (CC-BY). You may obtain a copy of the CC-BY license at: 136 | https://creativecommons.org/licenses/by/4.0/legalcode 137 | 138 | Unless required by applicable law or agreed to in writing, all software and 139 | materials distributed here under the Apache 2.0 or CC-BY licenses are 140 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 141 | either express or implied. See the licenses for the specific language governing 142 | permissions and limitations under those licenses. 143 | 144 | This is not an official Google product. 145 | -------------------------------------------------------------------------------- /experiments/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Constants for our length generalization experiments.""" 17 | 18 | import functools 19 | 20 | import haiku as hk 21 | 22 | from neural_networks_chomsky_hierarchy.experiments import curriculum as curriculum_lib 23 | from neural_networks_chomsky_hierarchy.models import ndstack_rnn 24 | from neural_networks_chomsky_hierarchy.models import rnn 25 | from neural_networks_chomsky_hierarchy.models import stack_rnn 26 | from neural_networks_chomsky_hierarchy.models import tape_rnn 27 | from neural_networks_chomsky_hierarchy.models import transformer 28 | from neural_networks_chomsky_hierarchy.tasks.cs import binary_addition 29 | from neural_networks_chomsky_hierarchy.tasks.cs import binary_multiplication 30 | from neural_networks_chomsky_hierarchy.tasks.cs import bucket_sort 31 | from neural_networks_chomsky_hierarchy.tasks.cs import compute_sqrt 32 | from neural_networks_chomsky_hierarchy.tasks.cs import duplicate_string 33 | from neural_networks_chomsky_hierarchy.tasks.cs import missing_duplicate_string 34 | from neural_networks_chomsky_hierarchy.tasks.cs import odds_first 35 | from neural_networks_chomsky_hierarchy.tasks.dcf import modular_arithmetic_brackets 36 | from neural_networks_chomsky_hierarchy.tasks.dcf import reverse_string 37 | from neural_networks_chomsky_hierarchy.tasks.dcf import solve_equation 38 | from neural_networks_chomsky_hierarchy.tasks.dcf import stack_manipulation 39 | from neural_networks_chomsky_hierarchy.tasks.regular import cycle_navigation 40 | from neural_networks_chomsky_hierarchy.tasks.regular import even_pairs 41 | from neural_networks_chomsky_hierarchy.tasks.regular import modular_arithmetic 42 | from neural_networks_chomsky_hierarchy.tasks.regular import parity_check 43 | 44 | 45 | MODEL_BUILDERS = { 46 | 'rnn': 47 | functools.partial(rnn.make_rnn, rnn_core=hk.VanillaRNN), 48 | 'lstm': 49 | functools.partial(rnn.make_rnn, rnn_core=hk.LSTM), 50 | 'stack_rnn': 51 | functools.partial( 52 | rnn.make_rnn, 53 | rnn_core=stack_rnn.StackRNNCore, 54 | inner_core=hk.VanillaRNN), 55 | 'ndstack_rnn': 56 | functools.partial( 57 | rnn.make_rnn, 58 | rnn_core=ndstack_rnn.NDStackRNNCore, 59 | inner_core=hk.VanillaRNN), 60 | 'stack_lstm': 61 | functools.partial( 62 | rnn.make_rnn, rnn_core=stack_rnn.StackRNNCore, inner_core=hk.LSTM), 63 | 'transformer_encoder': 64 | transformer.make_transformer_encoder, 65 | 'transformer': 66 | transformer.make_transformer, 67 | 'tape_rnn': 68 | functools.partial( 69 | rnn.make_rnn, 70 | rnn_core=tape_rnn.TapeInputLengthJumpCore, 71 | inner_core=hk.VanillaRNN), 72 | } 73 | 74 | CURRICULUM_BUILDERS = { 75 | 'fixed': curriculum_lib.FixedCurriculum, 76 | 'regular_increase': curriculum_lib.RegularIncreaseCurriculum, 77 | 'reverse_exponential': curriculum_lib.ReverseExponentialCurriculum, 78 | 'uniform': curriculum_lib.UniformCurriculum, 79 | } 80 | 81 | TASK_BUILDERS = { 82 | 'modular_arithmetic': 83 | modular_arithmetic.ModularArithmetic, 84 | 'parity_check': 85 | parity_check.ParityCheck, 86 | 'even_pairs': 87 | 
even_pairs.EvenPairs, 88 | 'cycle_navigation': 89 | cycle_navigation.CycleNavigation, 90 | 'modular_arithmetic_brackets': 91 | functools.partial( 92 | modular_arithmetic_brackets.ModularArithmeticBrackets, mult=True), 93 | 'reverse_string': 94 | reverse_string.ReverseString, 95 | 'missing_duplicate_string': 96 | missing_duplicate_string.MissingDuplicateString, 97 | 'duplicate_string': 98 | duplicate_string.DuplicateString, 99 | 'binary_addition': 100 | binary_addition.BinaryAddition, 101 | 'binary_multiplication': 102 | binary_multiplication.BinaryMultiplication, 103 | 'compute_sqrt': 104 | compute_sqrt.ComputeSqrt, 105 | 'odds_first': 106 | odds_first.OddsFirst, 107 | 'solve_equation': 108 | solve_equation.SolveEquation, 109 | 'stack_manipulation': 110 | stack_manipulation.StackManipulation, 111 | 'bucket_sort': 112 | bucket_sort.BucketSort, 113 | } 114 | 115 | TASK_LEVELS = { 116 | 'modular_arithmetic': 'regular', 117 | 'parity_check': 'regular', 118 | 'even_pairs': 'regular', 119 | 'cycle_navigation': 'regular', 120 | 'modular_arithmetic_brackets': 'dcf', 121 | 'reverse_string': 'dcf', 122 | 'stack_manipulation': 'dcf', 123 | 'solve_equation': 'dcf', 124 | 'missing_duplicate_string': 'cs', 125 | 'compute_sqrt': 'cs', 126 | 'duplicate_string': 'cs', 127 | 'binary_addition': 'cs', 128 | 'binary_multiplication': 'cs', 129 | 'odds_first': 'cs', 130 | 'bucket_sort': 'cs', 131 | } 132 | -------------------------------------------------------------------------------- /experiments/curriculum.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Curricula over sequence lengths used to evaluate length generalization. 17 | 18 | Allows to sample different sequence lengths along training. For instance, 19 | one might want to start with length=1 and regularly increase the length by 1, 20 | every 50k steps. 21 | """ 22 | 23 | import abc 24 | from collections.abc import Collection 25 | import random 26 | 27 | import numpy as np 28 | 29 | 30 | class Curriculum(abc.ABC): 31 | """Curriculum to sample lengths.""" 32 | 33 | @abc.abstractmethod 34 | def sample_sequence_length(self, step: int) -> int: 35 | """Samples a sequence length from the current distribution.""" 36 | 37 | 38 | class FixedCurriculum(Curriculum): 39 | """A fixed curriculum, always sampling the same sequence length.""" 40 | 41 | def __init__(self, sequence_length: int): 42 | """Initializes. 43 | 44 | Args: 45 | sequence_length: The sequence length to sample. 
46 | """ 47 | super().__init__() 48 | self._sequence_length = sequence_length 49 | 50 | def sample_sequence_length(self, step: int) -> int: 51 | """Returns a fixed sequence length.""" 52 | del step 53 | return self._sequence_length 54 | 55 | 56 | class UniformCurriculum(Curriculum): 57 | """A uniform curriculum, sampling different sequence lengths.""" 58 | 59 | def __init__(self, values: Collection[int]): 60 | """Initializes. 61 | 62 | Args: 63 | values: The sequence lengths to sample. 64 | """ 65 | super().__init__() 66 | self._values = tuple(values) 67 | 68 | def sample_sequence_length(self, step: int) -> int: 69 | """Returns a sequence length sampled from a uniform distribution.""" 70 | del step 71 | return random.choice(self._values) 72 | 73 | 74 | class ReverseExponentialCurriculum(Curriculum): 75 | """A reverse exponential curriculum, sampling different sequence lengths.""" 76 | 77 | def __init__(self, values: Collection[int], tau: bool): 78 | """Initializes. 79 | 80 | Args: 81 | values: The sequence lengths to sample. 82 | tau: The exponential rate to use. 83 | """ 84 | super().__init__() 85 | self._values = tuple(values) 86 | self._tau = tau 87 | 88 | def sample_sequence_length(self, step: int) -> int: 89 | """Returns a length sampled from a reverse exponential distribution.""" 90 | del step 91 | probs = self._tau**np.array(self._values) 92 | probs = np.array(probs, dtype=np.float32) 93 | probs = probs / np.sum(probs) 94 | return np.random.choice(self._values, p=probs) 95 | 96 | 97 | class RegularIncreaseCurriculum(Curriculum): 98 | """Curriculum for sequence lengths with a regular increase.""" 99 | 100 | def __init__(self, initial_sequence_length: int, increase_frequency: int, 101 | increase_amount: int, sample_all_length: bool): 102 | """Initializes. 103 | 104 | Args: 105 | initial_sequence_length: The value of the sequence length at the beginning 106 | of the curriculum. 107 | increase_frequency: How often we increase the possible sequence length. 108 | increase_amount: The amount of the increase in length. 109 | sample_all_length: Whether to sample all length lower than the current one 110 | or just return the current one. 111 | """ 112 | super().__init__() 113 | self._initial_sequence_length = initial_sequence_length 114 | self._increase_frequency = increase_frequency 115 | self._increase_amount = increase_amount 116 | self._sample_all_length = sample_all_length 117 | 118 | def sample_sequence_length(self, step: int) -> int: 119 | """Returns a sequence length from the curriculum with the current step.""" 120 | if not self._sample_all_length: 121 | return self._initial_sequence_length + self._increase_amount * ( 122 | step // self._increase_frequency 123 | ) 124 | return ( 125 | self._initial_sequence_length 126 | + self._increase_amount 127 | * np.random.randint(0, step // self._increase_frequency + 1) 128 | ) 129 | -------------------------------------------------------------------------------- /experiments/example.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Example script to train and evaluate a network.""" 17 | 18 | from absl import app 19 | from absl import flags 20 | import haiku as hk 21 | import jax.numpy as jnp 22 | import numpy as np 23 | 24 | from neural_networks_chomsky_hierarchy.experiments import constants 25 | from neural_networks_chomsky_hierarchy.experiments import curriculum as curriculum_lib 26 | from neural_networks_chomsky_hierarchy.experiments import training 27 | from neural_networks_chomsky_hierarchy.experiments import utils 28 | 29 | _BATCH_SIZE = flags.DEFINE_integer( 30 | 'batch_size', 31 | default=128, 32 | help='Training batch size.', 33 | lower_bound=1, 34 | ) 35 | _SEQUENCE_LENGTH = flags.DEFINE_integer( 36 | 'sequence_length', 37 | default=40, 38 | help='Maximum training sequence length.', 39 | lower_bound=1, 40 | ) 41 | _TASK = flags.DEFINE_string( 42 | 'task', 43 | default='even_pairs', 44 | help='Length generalization task (see `constants.py` for other tasks).', 45 | ) 46 | _ARCHITECTURE = flags.DEFINE_string( 47 | 'architecture', 48 | default='tape_rnn', 49 | help='Model architecture (see `constants.py` for other architectures).', 50 | ) 51 | 52 | _IS_AUTOREGRESSIVE = flags.DEFINE_boolean( 53 | 'is_autoregressive', 54 | default=False, 55 | help='Whether to use autoregressive sampling or not.', 56 | ) 57 | _COMPUTATION_STEPS_MULT = flags.DEFINE_integer( 58 | 'computation_steps_mult', 59 | default=0, 60 | help=( 61 | 'The amount of computation tokens to append to the input tape (defined' 62 | ' as a multiple of the input length)' 63 | ), 64 | lower_bound=0, 65 | ) 66 | # The architecture parameters depend on the architecture, so we cannot define 67 | # them as via flags. See `constants.py` for the required values. 68 | _ARCHITECTURE_PARAMS = { 69 | 'hidden_size': 256, 70 | 'memory_cell_size': 8, 71 | 'memory_size': 40, 72 | } 73 | 74 | 75 | def main(unused_argv) -> None: 76 | # Create the task. 77 | curriculum = curriculum_lib.UniformCurriculum( 78 | values=list(range(1, _SEQUENCE_LENGTH.value + 1)) 79 | ) 80 | task = constants.TASK_BUILDERS[_TASK.value]() 81 | 82 | # Create the model. 83 | single_output = task.output_length(10) == 1 84 | model = constants.MODEL_BUILDERS[_ARCHITECTURE.value]( 85 | output_size=task.output_size, 86 | return_all_outputs=True, 87 | **_ARCHITECTURE_PARAMS, 88 | ) 89 | if _IS_AUTOREGRESSIVE.value: 90 | if 'transformer' not in _ARCHITECTURE.value: 91 | model = utils.make_model_with_targets_as_input( 92 | model, _COMPUTATION_STEPS_MULT.value 93 | ) 94 | model = utils.add_sampling_to_autoregressive_model(model, single_output) 95 | else: 96 | model = utils.make_model_with_empty_targets( 97 | model, task, _COMPUTATION_STEPS_MULT.value, single_output 98 | ) 99 | model = hk.transform(model) 100 | 101 | # Create the loss and accuracy based on the pointwise ones. 
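  # `loss_fn` reduces the task's pointwise loss by summing over its last axis and
  # averaging over the remaining dimensions; `accuracy_fn` is a mask-weighted
  # average, so positions excluded by `task.accuracy_mask` do not affect the score.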
102 | def loss_fn(output, target): 103 | loss = jnp.mean(jnp.sum(task.pointwise_loss_fn(output, target), axis=-1)) 104 | return loss, {} 105 | 106 | def accuracy_fn(output, target): 107 | mask = task.accuracy_mask(target) 108 | return jnp.sum(mask * task.accuracy_fn(output, target)) / jnp.sum(mask) 109 | 110 | # Create the final training parameters. 111 | training_params = training.ClassicTrainingParams( 112 | seed=0, 113 | model_init_seed=0, 114 | training_steps=10_000, 115 | log_frequency=100, 116 | length_curriculum=curriculum, 117 | batch_size=_BATCH_SIZE.value, 118 | task=task, 119 | model=model, 120 | loss_fn=loss_fn, 121 | learning_rate=1e-3, 122 | accuracy_fn=accuracy_fn, 123 | compute_full_range_test=True, 124 | max_range_test_length=100, 125 | range_test_total_batch_size=512, 126 | range_test_sub_batch_size=64, 127 | is_autoregressive=_IS_AUTOREGRESSIVE.value, 128 | ) 129 | 130 | training_worker = training.TrainingWorker(training_params, use_tqdm=True) 131 | _, eval_results, _ = training_worker.run() 132 | 133 | # Gather results and print final score. 134 | accuracies = [r['accuracy'] for r in eval_results] 135 | score = np.mean(accuracies[_SEQUENCE_LENGTH.value + 1 :]) 136 | print(f'Network score: {score}') 137 | 138 | 139 | if __name__ == '__main__': 140 | app.run(main) 141 | -------------------------------------------------------------------------------- /experiments/range_evaluation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Evaluation of a network on sequences of different lengths.""" 17 | 18 | import dataclasses 19 | import random 20 | from typing import Any, Callable, Mapping 21 | 22 | from absl import logging 23 | import haiku as hk 24 | import jax 25 | import jax.numpy as jnp 26 | import numpy as np 27 | import tqdm 28 | 29 | 30 | _Batch = Mapping[str, jnp.ndarray] 31 | 32 | 33 | @dataclasses.dataclass 34 | class EvaluationParams: 35 | """The parameters used for range evaluation of networks.""" 36 | model: hk.Transformed 37 | params: hk.Params 38 | 39 | accuracy_fn: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray] 40 | sample_batch: Callable[[jnp.ndarray, int, int], _Batch] 41 | 42 | max_test_length: int 43 | total_batch_size: int 44 | sub_batch_size: int # We use this to avoid memory overflow. 45 | 46 | is_autoregressive: bool = False 47 | 48 | 49 | def range_evaluation( 50 | eval_params: EvaluationParams, 51 | use_tqdm: bool = False, 52 | ) -> list[Mapping[str, Any]]: 53 | """Evaluates the model on longer, never seen strings and log the results. 54 | 55 | Args: 56 | eval_params: The evaluation parameters, see above. 57 | use_tqdm: Whether to use a progress bar with tqdm. 58 | 59 | Returns: 60 | The list of dicts containing the accuracies. 
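    Each dict has the form {'length': int, 'accuracy': float}, with one entry
    per sequence length from 1 to `max_test_length`.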
61 | """ 62 | model = eval_params.model 63 | params = eval_params.params 64 | 65 | random.seed(1) 66 | np.random.seed(1) 67 | rng_seq = hk.PRNGSequence(1) 68 | 69 | if eval_params.is_autoregressive: 70 | apply_fn = jax.jit(model.apply, static_argnames=('sample',)) 71 | else: 72 | apply_fn = jax.jit(model.apply) 73 | 74 | results = [] 75 | lengths = range(1, eval_params.max_test_length + 1) 76 | if use_tqdm: 77 | lengths = tqdm.tqdm(lengths) 78 | for length in lengths: 79 | # We need to clear the cache of jitted functions, to avoid overflow as we 80 | # are jitting len(lengths) ones, which can be a lot. 81 | apply_fn.clear_cache() 82 | sub_accuracies = [] 83 | for _ in range(eval_params.total_batch_size // eval_params.sub_batch_size): 84 | batch = eval_params.sample_batch( 85 | next(rng_seq), eval_params.sub_batch_size, length) 86 | 87 | if eval_params.is_autoregressive: 88 | outputs = apply_fn( 89 | params, 90 | next(rng_seq), 91 | batch['input'], 92 | jnp.empty_like(batch['output']), 93 | sample=True) 94 | else: 95 | outputs = apply_fn(params, next(rng_seq), batch['input']) 96 | 97 | sub_accuracies.append( 98 | float(np.mean(eval_params.accuracy_fn(outputs, batch['output'])))) 99 | log_data = { 100 | 'length': length, 101 | 'accuracy': np.mean(sub_accuracies), 102 | } 103 | logging.info(log_data) 104 | results.append(log_data) 105 | return results 106 | -------------------------------------------------------------------------------- /experiments/training.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Training loop for length generalization experiments.""" 17 | 18 | import dataclasses 19 | import functools 20 | import random 21 | from typing import Any, Callable, Mapping, Optional 22 | 23 | import chex 24 | import haiku as hk 25 | import jax 26 | import jax.numpy as jnp 27 | import numpy as np 28 | import optax 29 | import tqdm 30 | 31 | from neural_networks_chomsky_hierarchy.experiments import curriculum as curriculum_lib 32 | from neural_networks_chomsky_hierarchy.experiments import range_evaluation 33 | from neural_networks_chomsky_hierarchy.tasks import task as task_lib 34 | 35 | 36 | _LossMetrics = Optional[Mapping[str, jnp.ndarray]] 37 | _LossFn = Callable[[chex.Array, chex.Array], tuple[float, _LossMetrics]] 38 | _AccuracyFn = Callable[[chex.Array, chex.Array], float] 39 | _ModelApplyFn = Callable[..., chex.Array] 40 | _MAX_RNGS_RESERVE = 50000 41 | 42 | 43 | @dataclasses.dataclass 44 | class ClassicTrainingParams: 45 | """Parameters needed to train classical architectures.""" 46 | seed: int # Used to sample during forward pass (e.g. from final logits). 47 | model_init_seed: int # Used to initialize model parameters. 
48 | training_steps: int 49 | log_frequency: int 50 | 51 | task: task_lib.GeneralizationTask 52 | length_curriculum: curriculum_lib.Curriculum 53 | batch_size: int 54 | 55 | model: hk.Transformed 56 | loss_fn: Callable[[jnp.ndarray, jnp.ndarray], tuple[float, _LossMetrics]] 57 | learning_rate: float 58 | test_model: Optional[hk.Transformed] = None 59 | max_grad_norm: float = 1. 60 | is_autoregressive: bool = False 61 | 62 | compute_full_range_test: bool = False 63 | range_test_total_batch_size: int = 512 64 | range_test_sub_batch_size: int = 64 65 | max_range_test_length: int = 100 66 | 67 | accuracy_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray], 68 | jnp.ndarray]] = None 69 | 70 | 71 | def _apply_loss_and_metrics_fn( 72 | params: hk.Params, 73 | rng_key: chex.PRNGKey, 74 | batch: task_lib.Batch, 75 | model_apply_fn: _ModelApplyFn, 76 | loss_fn: _LossFn, 77 | accuracy_fn: _AccuracyFn, 78 | is_autoregressive: bool = False, 79 | ) -> tuple[float, tuple[_LossMetrics, float]]: 80 | """Computes the model output and applies the loss function. 81 | 82 | Depending on whether a model is autoregressive or not, it will have a 83 | different number of input parameters (i.e., autoregressive models also require 84 | the targets as an input). 85 | 86 | Args: 87 | params: The model parameters. 88 | rng_key: The prng key to use for random number generation. 89 | batch: The data (consists of both inputs and outputs). 90 | model_apply_fn: The model function that converts inputs into outputs. 91 | loss_fn: A function that computes the loss for a batch of logits and labels. 92 | accuracy_fn: A function that computes the accuracy for a batch of logits and 93 | labels. 94 | is_autoregressive: Whether the model is autoregressive or not. 95 | 96 | Returns: 97 | The loss of the model for the batch of data, extra loss metrics and the 98 | accuracy, if accuracy_fn is not None. 99 | """ 100 | if is_autoregressive: 101 | outputs = model_apply_fn( 102 | params, rng_key, batch["input"], batch["output"], sample=False) 103 | else: 104 | outputs = model_apply_fn(params, rng_key, batch["input"]) 105 | 106 | loss, loss_metrics = loss_fn(outputs, batch["output"]) 107 | if accuracy_fn is not None: 108 | accuracy = accuracy_fn(outputs, batch["output"]) 109 | else: 110 | accuracy = None 111 | return loss, (loss_metrics, accuracy) 112 | 113 | 114 | @functools.partial( 115 | jax.jit, 116 | static_argnames=( 117 | "model_apply_fn", 118 | "loss_fn", 119 | "accuracy_fn", 120 | "optimizer", 121 | "is_autoregressive", 122 | ), 123 | ) 124 | def _update_parameters( 125 | params: hk.Params, 126 | rng_key: chex.PRNGKey, 127 | batch: task_lib.Batch, 128 | model_apply_fn: _ModelApplyFn, 129 | loss_fn: _LossFn, 130 | accuracy_fn: _AccuracyFn, 131 | optimizer: optax.GradientTransformation, 132 | opt_state: optax.OptState, 133 | is_autoregressive: bool = False, 134 | ) -> tuple[hk.Params, optax.OptState, tuple[float, _LossMetrics, float]]: 135 | """Applies a single SGD update step to the model parameters. 136 | 137 | Args: 138 | params: The model parameters. 139 | rng_key: The prng key to use for random number generation. 140 | batch: The data (consists of both inputs and outputs). 141 | model_apply_fn: The model function that converts inputs into outputs. 142 | loss_fn: A function that computes the loss for a batch of logits and labels. 143 | accuracy_fn: A function that computes the accuracy for a batch of logits and 144 | labels. 
145 | optimizer: The optimizer that computes the updates from the gradients of the 146 | `loss_fn` with respect to the `params` and the previous `opt_state`. 147 | opt_state: The optimizer state, e.g., momentum for each variable when using 148 | Adam. 149 | is_autoregressive: Whether the model is autoregressive or not. 150 | 151 | Returns: 152 | The updated parameters, the new optimizer state, and the loss, loss metrics 153 | and accuracy. 154 | """ 155 | (loss, (metrics, accuracy)), grads = jax.value_and_grad( 156 | _apply_loss_and_metrics_fn, 157 | has_aux=True)(params, rng_key, batch, model_apply_fn, loss_fn, 158 | accuracy_fn, is_autoregressive) 159 | updates, new_opt_state = optimizer.update(grads, opt_state) 160 | new_params = optax.apply_updates(params, updates) 161 | return new_params, new_opt_state, (loss, metrics, accuracy) 162 | 163 | 164 | class TrainingWorker: 165 | """Training worker.""" 166 | 167 | def __init__(self, 168 | training_params: ClassicTrainingParams, 169 | use_tqdm: bool = False): 170 | """Initializes the worker. 171 | 172 | Args: 173 | training_params: The training parameters. 174 | use_tqdm: Whether to add a progress bar to stdout. 175 | """ 176 | self._training_params = training_params 177 | self._use_tqdm = use_tqdm 178 | 179 | def run( 180 | self, 181 | ) -> tuple[ 182 | list[Mapping[str, Any]], Optional[list[Mapping[str, Any]]], chex.ArrayTree 183 | ]: 184 | """Trains the model with the provided config. 185 | 186 | Returns: 187 | Results (various training and validation metrics), module parameters 188 | and router parameters. 189 | """ 190 | training_params = self._training_params 191 | rngs_reserve = min(_MAX_RNGS_RESERVE, training_params.training_steps) 192 | 193 | random.seed(training_params.seed) 194 | np.random.seed(training_params.seed) 195 | rng_seq = hk.PRNGSequence(training_params.seed) 196 | rng_seq.reserve(rngs_reserve) 197 | 198 | results = [] 199 | model = training_params.model 200 | task = training_params.task 201 | length_curriculum = training_params.length_curriculum 202 | 203 | optimizer = optax.chain( 204 | optax.clip_by_global_norm(training_params.max_grad_norm), 205 | optax.adam(training_params.learning_rate)) 206 | 207 | dummy_batch = task.sample_batch( 208 | next(rng_seq), length=10, batch_size=training_params.batch_size) 209 | model_init_rng_key = jax.random.PRNGKey(training_params.model_init_seed) 210 | 211 | if training_params.is_autoregressive: 212 | params = model.init( 213 | model_init_rng_key, 214 | dummy_batch["input"], 215 | dummy_batch["output"], 216 | sample=False) 217 | else: 218 | params = model.init(model_init_rng_key, dummy_batch["input"]) 219 | 220 | opt_state = optimizer.init(params) 221 | self._params, self._step = params, 0 222 | 223 | steps = range(training_params.training_steps + 1) 224 | if self._use_tqdm: 225 | steps = tqdm.tqdm(steps) 226 | for step in steps: 227 | # Randomness handled by either python.random or numpy. 228 | length = length_curriculum.sample_sequence_length(step) 229 | # Randomness handled by either jax, python.random or numpy. 
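      # Since the batch shape depends on `length`, the jitted `_update_parameters`
      # call below is traced and compiled once for every distinct sequence length
      # the curriculum produces.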
230 | train_batch = task.sample_batch( 231 | next(rng_seq), length=length, batch_size=training_params.batch_size) 232 | params, opt_state, ( 233 | train_loss, train_metrics, train_accuracy) = _update_parameters( 234 | params=params, 235 | rng_key=next(rng_seq), 236 | batch=train_batch, 237 | model_apply_fn=model.apply, 238 | loss_fn=training_params.loss_fn, 239 | accuracy_fn=training_params.accuracy_fn, 240 | optimizer=optimizer, 241 | opt_state=opt_state, 242 | is_autoregressive=training_params.is_autoregressive) 243 | self._params, self._step = params, step 244 | 245 | log_freq = training_params.log_frequency 246 | if (log_freq > 0) and (step % log_freq == 0): 247 | log_data = { 248 | "step": step, 249 | "train_loss": float(train_loss), 250 | } 251 | if training_params.accuracy_fn is not None: 252 | log_data["train_accuracy"] = float(train_accuracy) 253 | for key, value in train_metrics.items(): 254 | log_data[".".join(["train_metrics", key])] = np.array(value) 255 | results.append(log_data) 256 | 257 | # We need to access this private attribute since the default reserve size 258 | # can not be edited yet. 259 | if not rng_seq._subkeys: # pylint: disable=protected-access 260 | rng_seq.reserve(rngs_reserve) 261 | 262 | eval_results = None 263 | if training_params.compute_full_range_test: 264 | eval_params = range_evaluation.EvaluationParams( 265 | model=training_params.test_model or model, 266 | params=params, 267 | accuracy_fn=training_params.accuracy_fn, 268 | sample_batch=task.sample_batch, 269 | max_test_length=training_params.max_range_test_length, 270 | total_batch_size=training_params.range_test_total_batch_size, 271 | sub_batch_size=training_params.range_test_sub_batch_size, 272 | is_autoregressive=training_params.is_autoregressive, 273 | ) 274 | eval_results = range_evaluation.range_evaluation( 275 | eval_params, use_tqdm=False) 276 | 277 | return results, eval_results, params 278 | -------------------------------------------------------------------------------- /experiments/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Provides utility functions for training and evaluation.""" 17 | 18 | import inspect 19 | from typing import Any, Callable 20 | 21 | import chex 22 | import haiku as hk 23 | from jax import nn as jnn 24 | from jax import numpy as jnp 25 | 26 | from neural_networks_chomsky_hierarchy.tasks import task 27 | 28 | COMPUTATION_EMPTY_TOKEN = 0 29 | OUTPUT_EMPTY_TOKEN = 1 30 | 31 | 32 | def make_model_with_empty_targets( 33 | model: Callable[[chex.Array], chex.Array], 34 | generalization_task: task.GeneralizationTask, 35 | computation_steps_mult: int = 0, 36 | single_output: bool = False, 37 | ) -> Callable[[chex.Array], chex.Array]: 38 | """Returns a wrapped model that pads the inputs to match the output length. 39 | 40 | For a given input tape `input_tape` of vocabulary size `vocab_size`, the 41 | wrapped model will process a tape of the format 42 | [`input_tape`, `empty_tape`], where the empty tape token is `vocab_size + 1`. 43 | The `empty_tape` has the same length as the task output. 44 | 45 | Args: 46 | model: A model function that converts inputs to outputs. 47 | generalization_task: The task that we train on. 48 | computation_steps_mult: The amount of empty cells to append to the input 49 | tape. This variable is a multiplier and the actual number of cells is 50 | `computation_steps_mult * input_length`. 51 | single_output: Whether to return the squeezed tensor of values. 52 | """ 53 | 54 | def new_model(x: chex.Array) -> chex.Array: 55 | batch_size, input_length, input_size = x.shape 56 | output_length = generalization_task.output_length(input_length) 57 | extra_dims_onehot = 1 + int(computation_steps_mult > 0) 58 | final_input_size = input_size + extra_dims_onehot 59 | 60 | # Add trailing zeros to account for new final_input_size. 61 | extra_zeros_x = jnp.zeros( 62 | (batch_size, input_length, final_input_size - input_size) 63 | ) 64 | x = jnp.concatenate([x, extra_zeros_x], axis=-1) 65 | 66 | computation_tape = jnp.full( 67 | (batch_size, computation_steps_mult * input_length), 68 | fill_value=input_size + COMPUTATION_EMPTY_TOKEN) 69 | computation_tape = jnn.one_hot( 70 | computation_tape, num_classes=final_input_size 71 | ) 72 | 73 | output_tokens = jnp.full( 74 | (batch_size, output_length), 75 | fill_value=input_size 76 | + OUTPUT_EMPTY_TOKEN 77 | - int(computation_steps_mult == 0), 78 | ) 79 | output_tokens = jnn.one_hot(output_tokens, num_classes=final_input_size) 80 | final_input = jnp.concatenate([x, computation_tape, output_tokens], axis=1) 81 | 82 | if 'input_length' in inspect.getfullargspec(model).args: 83 | output = model(final_input, input_length=input_length) # pytype: disable=wrong-keyword-args 84 | else: 85 | output = model(final_input) 86 | output = output[:, -output_length:] 87 | if single_output: 88 | output = jnp.squeeze(output, axis=1) 89 | return output 90 | 91 | return new_model 92 | 93 | 94 | def make_model_with_targets_as_input( 95 | model: Callable[[chex.Array], chex.Array], computation_steps_mult: int = 0 96 | ) -> Callable[[chex.Array, chex.Array], chex.Array]: 97 | """Returns a wrapped model that takes the targets as inputs. 98 | 99 | This function is useful for the autoregressive case where we pass the targets 100 | as inputs to the model. The final input looks like: 101 | [inputs, computation_tokens, output_token, targets] 102 | 103 | Args: 104 | model: A haiku model that takes 'x' as input. 
105 | computation_steps_mult: The amount of computation tokens to append to the 106 | input tape. This variable is a multiplier and the actual number of cell is 107 | computation_steps_mult * input_length. 108 | """ 109 | 110 | def new_model(x: chex.Array, y: chex.Array) -> chex.Array: 111 | """Returns an output from the inputs and targets. 112 | 113 | Args: 114 | x: One-hot input vectors, shape (B, T, input_size). 115 | y: One-hot target output vectors, shape (B, T, output_size). 116 | """ 117 | batch_size, input_length, input_size = x.shape 118 | _, output_length, output_size = y.shape 119 | extra_dims_onehot = 1 + int(computation_steps_mult > 0) 120 | final_input_size = max(input_size, output_size) + extra_dims_onehot 121 | 122 | # Add trailing zeros to account for new final_input_size. 123 | extra_zeros_x = jnp.zeros( 124 | (batch_size, input_length, final_input_size - input_size) 125 | ) 126 | x = jnp.concatenate([x, extra_zeros_x], axis=-1) 127 | extra_zeros_y = jnp.zeros( 128 | (batch_size, output_length, final_input_size - output_size) 129 | ) 130 | y = jnp.concatenate([y, extra_zeros_y], axis=-1) 131 | 132 | computation_tape = jnp.full( 133 | (batch_size, computation_steps_mult * input_length), 134 | fill_value=input_size + COMPUTATION_EMPTY_TOKEN, 135 | ) 136 | computation_tape = jnn.one_hot( 137 | computation_tape, num_classes=final_input_size 138 | ) 139 | 140 | output_token = jnp.full( 141 | (batch_size, 1), 142 | fill_value=input_size 143 | + OUTPUT_EMPTY_TOKEN 144 | - int(computation_steps_mult == 0), 145 | ) 146 | output_token = jnn.one_hot(output_token, num_classes=final_input_size) 147 | final_input = jnp.concatenate( 148 | [x, computation_tape, output_token, y], axis=1 149 | ) 150 | 151 | if 'input_length' in inspect.getfullargspec(model).args: 152 | output = model(final_input, input_length=input_length) # pytype: disable=wrong-keyword-args 153 | else: 154 | output = model(final_input) 155 | 156 | return output[:, -output_length - 1 : -1] 157 | 158 | return new_model 159 | 160 | 161 | def add_sampling_to_autoregressive_model( 162 | model: Callable[[chex.Array, chex.Array], chex.Array], 163 | single_output: bool = False, 164 | ) -> Callable[[chex.Array, chex.Array, bool], chex.Array]: 165 | """Adds a 'sample' argument to the model, to use autoregressive sampling.""" 166 | 167 | def new_model_with_sampling( 168 | x: chex.Array, 169 | y: chex.Array, 170 | sample: bool, 171 | ) -> chex.Array: 172 | """Returns an autoregressive model if `sample == True and output_size > 1`. 173 | 174 | Args: 175 | x: The input sequences of shape (b, t, i), where i is the input size. 176 | y: The target sequences of shape (b, t, o), where o is the output size. 177 | sample: Whether to evaluate the model using autoregressive decoding. 178 | """ 179 | output_length = 1 if len(y.shape) == 2 else y.shape[1] 180 | output_size = y.shape[-1] 181 | 182 | if not sample or output_length == 1: 183 | output = model(x, y) 184 | 185 | else: 186 | 187 | def evaluate_model_autoregressively( 188 | idx: int, 189 | predictions: chex.Array, 190 | ) -> chex.Array: 191 | """Iteratively evaluates the model based on the previous predictions. 192 | 193 | Args: 194 | idx: The index of the target sequence that should be evaluated. 195 | predictions: The logits for the predictions up to but not including 196 | the index `idx`. 197 | 198 | Returns: 199 | The `predictions` array modified only at position `idx` where the 200 | logits for index `idx` have been inserted. 
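        Decoding is therefore greedy: the model re-reads the one-hot argmax of
        the current predictions, and only position `idx` is overwritten, so
        earlier positions keep the logits computed in previous iterations.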
201 | """ 202 | one_hot_predictions = jnn.one_hot( 203 | jnp.argmax(predictions, axis=-1), 204 | num_classes=output_size, 205 | ) 206 | logits = model(x, one_hot_predictions) 207 | return predictions.at[:, idx].set(logits[:, idx]) 208 | 209 | output = hk.fori_loop( 210 | lower=0, 211 | upper=output_length, 212 | body_fun=evaluate_model_autoregressively, 213 | init_val=jnp.empty_like(y), 214 | ) 215 | 216 | if single_output: 217 | output = jnp.squeeze(output, axis=1) 218 | return output 219 | 220 | return new_model_with_sampling 221 | 222 | 223 | def update_tree_with_new_containers( 224 | tree: Any, update_dict: dict[str, Any] 225 | ) -> None: 226 | """Updates a dataclass tree in place, adding new containers. 227 | 228 | This method is useful for the nested library to add fields to a tree, for 229 | which containers have not been created. 230 | For instance, if A is a dataclass with attribute architecture_params, and we 231 | want to add the value architecture_params.rnn_model.size, we need to create 232 | the container 'rnn_model' inside architecture_params. 233 | 234 | Args: 235 | tree: An object with attribute (typically a dataclass). 236 | update_dict: A dict of nested updates. See example above. 237 | """ 238 | for key in update_dict: 239 | subkeys = key.split('.') 240 | if len(subkeys) >= 2: 241 | # Example: architecture.params.size 242 | for i in range(0, len(subkeys) - 2): 243 | getattr(tree, subkeys[i])[subkeys[i + 1]] = {} 244 | -------------------------------------------------------------------------------- /models/ndstack_rnn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Non-deterministic Stack RNN core. 17 | 18 | Following the paper from DuSell et al (2020): 19 | https://arxiv.org/abs/2010.04674 20 | 21 | The idea is to add a non-deterministic stack extension to a recurrent neural 22 | network to be able to simulate a machine accepting non-deterministic 23 | context-free languages. It can be seen as an extension to the Stack-RNN 24 | developed by Joulin et al (2015). However, it is far more complex and hard to 25 | understand. 26 | The non-deterministic stack is completely differentiable. 27 | 28 | A non-deterministic Pushdown Automata (NDPDA) uses 'multiple stacks at the same 29 | time'. The problem is that the number of possible stacks grows exponentially 30 | with time, which makes a naive practical implementation impossible. However, 31 | Lang et al proved in 1969, based on ideas from Context-Frees parsers like the 32 | CYK, that a NDPDA can be simulated only using O(n³) memory, and not O(2^n). 33 | The main idea is to reuse the content of the different stacks in a dynamic 34 | programming manner. 
A stack with n values is a stack with n-1 values + an extra
35 | value, so we can build a graph of possible stacks, which would reuse most of
36 | the data.
37 | 
38 | Concretely, the graph is made of nodes (t, q, x) where t is a number (the time),
39 | q is a state from a fixed, user-defined set and x is a symbol or value, also
40 | from a finite, user-defined set. Then one path in this graph is exactly one
41 | stack, which can simply be reconstructed by reading the value for each node
42 | in the path. Each state q can be used as a 'branching' mechanism: the more
43 | states, the more branching there can be and therefore the more stacks can be
44 | used. The number of possible stacks is (#states * #symbols)^t.
45 | 
46 | To interact with this graph, i.e. to do a push or a pop action, one uses transitions
47 | on these nodes. For push, it is a function of the form (q1, x1) -> (q2, x2),
48 | where q2 is the new state to go in (i.e. whether to branch to a new stack, or keep
49 | the same) and x2 is the value to push. For pop, it is a function of the form
50 | (q1, x1) -> q2, which again allows the network to choose whether to create a new
51 | stack or not. No value is passed in that case. The functions are
52 | modelled by transition matrices of shape (Q, S, Q, S) where Q=#states and
53 | S=#symbols.
54 | Once the action matrices are passed, the graph is updated. The update is done
55 | via an internal transition matrix called gamma. This matrix is simple for the
56 | push action (one can only push on the top of the stack, i.e. nodes for which t =
57 | current timestep). It is far more complex for the pop action, as popping a value
58 | from the current stack can completely change the structure of the graph: the
59 | new stack after popping might be equal to a very old stack seen at the beginning
60 | of the episode, and we must change the links accordingly. Roughly, the update
61 | operation for gamma has a time complexity of O(Q⁴ S³ n³).
62 | Finally, once the graph is updated via gamma, we update the probabilities of the
63 | tops of the stacks, which gives us a tensor called alpha. From alpha we deduce the
64 | average top of the stack to be sent to the agent.
65 | 
66 | As there are 3 actions (pop/push/no_op), unrolling this over
67 | long sequences and using big batch sizes consumes too much memory and the
68 | accelerators fail.
69 | 
70 | Notations:
71 |   Q: number of states of the ND stack (not the number of states of the
72 |     RNN).
73 |   S: number of symbols which can be pushed on the stack.
74 |   T: Sequence length.
75 |   B: Batch size.
76 | """
77 | 
78 | from typing import Any, Mapping, NamedTuple, Optional
79 | 
80 | import chex
81 | import haiku as hk
82 | import jax
83 | import jax.nn as jnn
84 | import jax.numpy as jnp
85 | 
86 | _EPSILON = 0.001
87 | 
88 | 
89 | class NDStack(NamedTuple):
90 |   """The non-deterministic stack.
91 | 
92 |   Note that alpha and top_stack depend on gamma.
93 |   """
94 |   gamma: chex.Array  # Shape (B, T, T, Q, S, Q, S)
95 |   alpha: chex.Array  # Shape (B, T, Q, S)
96 |   top_stack: chex.Array  # Shape (B, S)
97 | 
98 | 
99 | def _update_stack(ndstack: NDStack,
100 |                   push_actions: chex.Array,
101 |                   pop_actions: chex.Array,
102 |                   replace_actions: chex.Array,
103 |                   timestep: int,
104 |                   read_states: bool = True) -> NDStack:
105 |   """Returns an updated NDStack.
106 | 
107 |   Args:
108 |     ndstack: See above. Contains the internals needed to simulate a
109 |       non-deterministic stack.
110 |     push_actions: A tensor of shape (B, Q, S, Q, S).
111 | pop_actions: A tensor of shape (B, Q, S, Q). 112 | replace_actions: A tensor of shape (B, Q, S, Q, S). 113 | timestep: The current timestep while processing the sequence. 114 | read_states: Whether to read the state of the NPDA as well. 115 | """ 116 | stack_size = ndstack.gamma.shape[2] 117 | mask = jnp.zeros((stack_size, stack_size)) 118 | mask = mask.at[timestep - 1, timestep].set(1) 119 | new_push_gamma_t = jnp.einsum('bqxry,tT->btTqxry', push_actions, 120 | mask)[:, :, timestep] 121 | 122 | index_k = jnp.stack([jnp.arange(start=0, stop=stack_size)] * stack_size) 123 | index_i = jnp.transpose(index_k) 124 | timestep_arr = jnp.full((stack_size, stack_size), timestep) 125 | index_mask = jnp.logical_and(index_k > index_i, index_k < timestep_arr - 1) 126 | index_mask = jnp.einsum('tT,bqxry->btTqxry', index_mask, 127 | jnp.ones(push_actions.shape)) 128 | new_pop_gamma_t = jnp.einsum( 129 | 'bikqxuy,bkuysz,bszr->biqxry', 130 | index_mask * ndstack.gamma, 131 | ndstack.gamma[:, :, timestep - 1], 132 | pop_actions, 133 | ) 134 | 135 | new_replace_gamma_t = jnp.einsum('biqxsz,bszry->biqxry', 136 | ndstack.gamma[:, :, 137 | timestep - 1], replace_actions) 138 | 139 | new_gamma = jax.vmap(jax.vmap(lambda x, y: x.at[timestep].set(y)))( 140 | ndstack.gamma, new_replace_gamma_t + new_pop_gamma_t + new_push_gamma_t) 141 | 142 | alpha_t = jnp.einsum('biqx,biqxry->bry', ndstack.alpha, new_gamma[:, :, 143 | timestep]) 144 | new_alpha = jax.vmap(lambda x, y: x.at[timestep].set(y))(ndstack.alpha, 145 | alpha_t) 146 | 147 | if read_states: 148 | batch_size, states, symbols = alpha_t.shape 149 | obs = jnp.reshape(alpha_t, (batch_size, states * symbols)) 150 | else: 151 | obs = jnp.sum(alpha_t, axis=1) 152 | 153 | obs = obs / (jnp.sum(obs, axis=-1, keepdims=True) + _EPSILON) 154 | return NDStack(new_gamma, new_alpha, top_stack=obs) 155 | 156 | 157 | # First element is the NDStack, second is the current timestep, third is the 158 | # hidden internal state. 159 | _NDStackRnnState = tuple[NDStack, chex.Array, chex.Array] 160 | 161 | 162 | class NDStackRNNCore(hk.RNNCore): 163 | """Core for the non-deterministic stack RNN.""" 164 | 165 | def __init__( 166 | self, 167 | stack_symbols: int, 168 | stack_states: int, 169 | stack_size: int = 30, 170 | inner_core: type[hk.RNNCore] = hk.VanillaRNN, 171 | read_states: bool = False, 172 | name: Optional[str] = None, 173 | **inner_core_kwargs: Mapping[str, Any] 174 | ): 175 | """Initializes. 176 | 177 | Args: 178 | stack_symbols: The number of symbols which can be used in the stack. 179 | stack_states: The number of states of the non-deterministic stack. 180 | Corresponds to the number of branching in the graph, ie roughly n_stacks 181 | = stack_states ^ t. 182 | stack_size: The total size of the stacks. Be careful when increasing this 183 | value since the computation is in O(stack_size ^ 3). 184 | inner_core: The inner RNN core builder. 185 | read_states: Whether to read the states on the NPDA or only the top of the 186 | stack. 187 | name: See base class. 188 | **inner_core_kwargs: The arguments to be passed to the inner RNN core 189 | builder. 
190 | """ 191 | super().__init__(name=name) 192 | self._rnn_core = inner_core(**inner_core_kwargs) 193 | self._stack_symbols = stack_symbols 194 | self._stack_states = stack_states 195 | self._stack_size = stack_size 196 | self._read_states = read_states 197 | 198 | def __call__( 199 | self, inputs: chex.Array, prev_state: _NDStackRnnState 200 | ) -> tuple[chex.Array, _NDStackRnnState]: 201 | """Steps the non-deterministic stack RNN core. 202 | 203 | See base class docstring. 204 | 205 | Args: 206 | inputs: An input array of shape (batch_size, input_size). The time 207 | dimension is not included since it is an RNNCore, which is unrolled over 208 | the time dimension. 209 | prev_state: A _NDStackRnnState tuple, consisting of the previous nd-stack, 210 | the previous timestep and the previous state of the inner core. 211 | 212 | Returns: 213 | - output: An output array of shape (batch_size, output_size). 214 | - next_state: Same format as prev_state. 215 | """ 216 | ndstack, timestep, old_core_state = prev_state 217 | 218 | # The network can always read the top of the stack. 219 | batch_size = ndstack.gamma.shape[0] 220 | inputs = jnp.concatenate([inputs, ndstack.top_stack], axis=-1) 221 | new_core_output, new_core_state = self._rnn_core(inputs, old_core_state) 222 | 223 | n_push_actions = (self._stack_states * self._stack_symbols)**2 224 | n_pop_actions = self._stack_states**2 * self._stack_symbols 225 | n_replace_actions = (self._stack_states * self._stack_symbols)**2 226 | actions = hk.Linear(n_push_actions + n_pop_actions + n_replace_actions)( 227 | new_core_output) 228 | actions = jnn.softmax(actions, axis=-1) 229 | 230 | push_actions = jnp.reshape( 231 | actions[:, :n_push_actions], 232 | (batch_size, self._stack_states, self._stack_symbols, 233 | self._stack_states, self._stack_symbols)) 234 | 235 | pop_actions = jnp.reshape( 236 | actions[:, n_push_actions:n_push_actions + n_pop_actions], 237 | (batch_size, self._stack_states, self._stack_symbols, 238 | self._stack_states)) 239 | 240 | replace_actions = jnp.reshape( 241 | actions[:, -n_replace_actions:], 242 | (batch_size, self._stack_states, self._stack_symbols, 243 | self._stack_states, self._stack_symbols)) 244 | 245 | new_ndstack = _update_stack( 246 | ndstack, 247 | push_actions, 248 | pop_actions, 249 | replace_actions, (timestep + 1)[0], 250 | read_states=self._read_states) 251 | return new_core_output, (new_ndstack, timestep + 1, new_core_state) 252 | 253 | def initial_state(self, batch_size: Optional[int]) -> _NDStackRnnState: 254 | """Returns the initial state of the core, a hidden state and an empty stack.""" 255 | core_state = self._rnn_core.initial_state(batch_size) 256 | 257 | # Gamma, the transition matrix, is initialized to full zeros: there is no 258 | # connection in the graph at the beginning. 259 | gamma = jnp.zeros( 260 | (batch_size, self._stack_size, self._stack_size, self._stack_states, 261 | self._stack_symbols, self._stack_states, self._stack_symbols)) 262 | 263 | # Alpha is zero everywhere except for the first node, which is (0, q0, 0). 264 | alpha = jnp.zeros( 265 | (batch_size, self._stack_size, self._stack_states, self._stack_symbols)) 266 | alpha = jax.vmap(lambda x: x.at[0, 0, 0].set(1))(alpha) 267 | 268 | if self._read_states: 269 | top_stack = jnp.zeros( 270 | (batch_size, self._stack_states * self._stack_symbols)) 271 | else: 272 | # The top of the stack is 0 as the first node contains the symbol 0. 
273 | top_stack = jnp.zeros((batch_size, self._stack_symbols)) 274 | 275 | ndstack = NDStack(gamma, alpha, top_stack) 276 | return ndstack, jnp.zeros((batch_size,), dtype=jnp.int32), core_state 277 | -------------------------------------------------------------------------------- /models/positional_encodings.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Positional encodings, used in `transformer.py`.""" 17 | 18 | import enum 19 | import math 20 | from typing import Any 21 | 22 | import chex 23 | import haiku as hk 24 | import jax 25 | import jax.numpy as jnp 26 | import numpy as np 27 | 28 | 29 | class PositionalEncodings(enum.Enum): 30 | """Enum for all the positional encodings implemented.""" 31 | NONE = 0 32 | SIN_COS = 1 33 | ALIBI = 2 34 | RELATIVE = 3 35 | ROTARY = 4 36 | 37 | 38 | # General type used throughout the class for pos enc parameters. 39 | PositionalEncodingsParams = Any 40 | 41 | 42 | @chex.dataclass 43 | class SinCosParams: 44 | """Parameters for the classical sin/cos positional encoding.""" 45 | # The maximum wavelength used. 46 | max_time: int = 10_000 47 | 48 | 49 | # We will use this same class for Rotary and Relative. 50 | RotaryParams = SinCosParams 51 | RelativeParams = SinCosParams 52 | 53 | 54 | POS_ENC_TABLE = { 55 | 'NONE': PositionalEncodings.NONE, 56 | 'SIN_COS': PositionalEncodings.SIN_COS, 57 | 'ALIBI': PositionalEncodings.ALIBI, 58 | 'RELATIVE': PositionalEncodings.RELATIVE, 59 | 'ROTARY': PositionalEncodings.ROTARY, 60 | } 61 | 62 | POS_ENC_PARAMS_TABLE = { 63 | 'NONE': SinCosParams, 64 | 'SIN_COS': SinCosParams, 65 | 'ALIBI': SinCosParams, 66 | 'RELATIVE': RelativeParams, 67 | 'ROTARY': RotaryParams, 68 | } 69 | 70 | 71 | def sinusoid_position_encoding( 72 | sequence_length: int, 73 | hidden_size: int, 74 | memory_length: int = 0, 75 | max_timescale: float = 1e4, 76 | min_timescale: float = 2.0, 77 | clamp_length: int = 0, 78 | causal: bool = False, 79 | ): 80 | """Creates sinusoidal encodings. 81 | 82 | The time dimension is larger than sequence_length as we need to cover all 83 | cases of looking in either the future or past. 84 | 85 | Args: 86 | sequence_length: `int` sequence length, L 87 | hidden_size: `int` dimension of the positional encoding vectors, D 88 | memory_length: `int` size of the memory, M 89 | max_timescale: `int` maximum timescale for the frequency 90 | min_timescale: `int` minimum timescale for the frequency 91 | clamp_length: If greater than 0, any positions further apart than 92 | `clamp_length` are clamped to this value 93 | causal: If true then generates a smaller set (L vs 2 * L) of time-encodings 94 | for the use-case of causal attention. 95 | 96 | Returns: 97 | An array of shape [L + M, D] for causal and [2 * L + M, D] otherwise. 
98 | """ 99 | freqs = np.arange(0, hidden_size, min_timescale) 100 | inv_freq = max_timescale ** (-freqs / hidden_size) 101 | # Since inputs can look into the past and into the future, depending on the 102 | # permutation mask, we need to have relative encodings for both. The furthest 103 | # back an input can see is the final token, up to sequence_length + 104 | # memory_length - 1. The furthest ahead an input can see is for token 0 where 105 | # it can see up to sequence_length - 1 future tokens. 106 | if causal: 107 | pos_seq = np.arange(sequence_length + memory_length, 0, -1.0) 108 | else: 109 | pos_seq = np.arange(sequence_length + memory_length, -sequence_length, -1.0) 110 | if clamp_length: 111 | pos_seq = np.clip(pos_seq, a_min=-clamp_length, a_max=clamp_length) 112 | sinusoid_inp = np.einsum('i,j->ij', pos_seq, inv_freq) 113 | pos_emb = np.concatenate( 114 | [np.sin(sinusoid_inp), np.cos(sinusoid_inp)], axis=-1 115 | ) 116 | return pos_emb 117 | 118 | 119 | def _rel_shift_inner(logits: chex.Array, attention_length: int) -> chex.Array: 120 | """Shifts the relative logits. 121 | 122 | This is a more general than the original Transformer-XL implementation as 123 | inputs may also see the future. (The implementation does not rely on a 124 | causal mask removing the upper-right triangle.) 125 | 126 | Given attention length 3 and inputs: 127 | [[-3, -2, -1, 0, 1, 2], 128 | [-3, -2, -1, 0, 1, 2], 129 | [-3, -2, -1, 0, 1, 2]] 130 | 131 | The shifted output is: 132 | [[0, 1, 2], 133 | [-1, 0, 1], 134 | [-2, -1, 0]] 135 | 136 | Args: 137 | logits: input tensor of shape [T_q, T_v + T_q] 138 | attention_length: T_v `int` length of the attention, should be equal to 139 | memory size + sequence length. 140 | 141 | Returns: 142 | A shifted version of the input of size [T_q, T_v]. In each row, a window of 143 | size T_v elements is kept. The window starts at 144 | the rightmost end, for the first row. It then shifts left by 1 for each 145 | subsequent row. 146 | """ 147 | if logits.ndim != 2: 148 | raise ValueError('`logits` needs to be an array of dimension 2.') 149 | tq, total_len = logits.shape 150 | assert total_len == tq + attention_length 151 | logits = jnp.reshape(logits, [total_len, tq]) 152 | logits = jax.lax.slice(logits, (1, 0), logits.shape) # logits[1:] 153 | logits = jnp.reshape(logits, [tq, total_len - 1]) 154 | # Equiv to logits[:, :attention_length]. 155 | logits = jax.lax.slice(logits, (0, 0), (tq, attention_length)) 156 | return logits 157 | 158 | 159 | def _rel_shift_causal(logits: chex.Array) -> chex.Array: 160 | """Shifts the relative logits, assuming causal attention. 161 | 162 | Given inputs: 163 | [[-4, -3, -2, -1], 164 | [-4, -3, -2, -1]] 165 | 166 | The shifted (and, later, masked) output is: 167 | [[-3, -2, -1, 0], 168 | [-4, -3, -2, -1]] 169 | 170 | Args: 171 | logits: input tensor of shape [T_q, T_v] 172 | 173 | Returns: 174 | A shifted version of the input of size [T_q, T_v]. 175 | """ 176 | t1, t2 = logits.shape 177 | # We prepend zeros on the final timescale dimension. 178 | to_pad = jnp.zeros_like(logits[..., :1]) 179 | x = jnp.concatenate((to_pad, logits), axis=-1) 180 | 181 | # Reshape trick to shift input. 182 | x = jnp.reshape(x, [t2 + 1, t1]) 183 | 184 | # Remove extra time dimension and re-shape. 
185 | x = jax.lax.slice(x, [1] + [0] * (x.ndim - 1), x.shape) 186 | 187 | return jnp.reshape(x, [t1, t2]) 188 | 189 | 190 | def relative_shift( 191 | logits: chex.Array, attention_length: int, causal: bool = False 192 | ) -> chex.Array: 193 | if causal: 194 | fn = _rel_shift_causal 195 | else: 196 | fn = lambda t: _rel_shift_inner(t, attention_length) 197 | return jax.vmap(jax.vmap(fn))(logits) 198 | 199 | 200 | def apply_rotary_encoding( 201 | x: jnp.ndarray, position: jnp.ndarray, max_time: int = 10_000 202 | ) -> jnp.ndarray: 203 | """Applies RoPE positional encodings for the input. 204 | 205 | Paper: https://arxiv.org/abs/2104.09864 206 | 207 | Args: 208 | x: The input tensor on which RoPE will be applied. Usually it is either some 209 | queries q or some keys k. 210 | position: The positions to use. Usually it's an arange of the maximum 211 | length. 212 | max_time: Constant used to scale position by in the encodings. 213 | 214 | Returns: 215 | A tensor with the same shape as x. 216 | """ 217 | # Expand dims for positions to support inputs of shapes BTC or BTHC. 218 | freq_seq = jnp.arange(x.shape[-1] // 2, dtype=jnp.float32) 219 | freq_seq = freq_seq / (x.shape[-1] // 2) 220 | inv_freq = max_time**-freq_seq 221 | inv_freq = jnp.repeat(inv_freq, 2, 0) 222 | # Produce position inputs to periodic functions. 223 | t = position[:, :, None, None] * inv_freq[None, None, None, :] 224 | x_rot = jnp.einsum('bthd,dD->bthD', x, _rope_kernel(x.shape[-1], x.dtype)) 225 | return x * jnp.cos(t).astype(x.dtype) + jnp.sin(t).astype(x.dtype) * x_rot 226 | 227 | 228 | def _rope_kernel(n: int, dtype: Any) -> np.ndarray: 229 | """Reorders the embedding dimension of an array, to make rotation easier.""" 230 | # We implement the equivalent of 231 | # even_dims, odd_dims, = x[..., ::2], x[..., 1::2] 232 | # return jnp.stack((-odd_dims, even_dims), axis=-1).reshape(x.shape) 233 | # with a custom kernel for einsum. This allows the computation to execute 234 | # on the MXU instead of producing a slow gather. 235 | assert n % 2 == 0, n 236 | kernel = np.zeros((n, n), dtype) 237 | for i in range(n): 238 | # Swap each neighbouring pair of values. 239 | if i % 2 == 0: 240 | kernel[i, i + 1] = 1 241 | else: 242 | kernel[i, i - 1] = -1 243 | return kernel 244 | 245 | 246 | def compute_attention_with_relative_encodings( 247 | queries: chex.Array, 248 | keys: chex.Array, 249 | max_time: int = 10_000, 250 | causal: bool = False) -> chex.Array: 251 | """Returns attention with relative positional encodings. 252 | 253 | This code strictly follows what is described in the TransformerXL paper. 254 | https://arxiv.org/pdf/1901.02860.pdf 255 | 256 | Args: 257 | queries: The queries used for attention. Shape (b, t, h, d). 258 | keys: The keys used for attention. Shape (b, T, h, d). 259 | max_time: Constant used to scale position by in the sin/cos encodings. 260 | causal: Whether to use causal attention when shifting the relative logits. 261 | 262 | Returns: 263 | The attention logits. Shape (b, h, t, T). 264 | """ 265 | batch_size, k_seq_len, num_heads, num_hiddens = keys.shape 266 | hiddens = num_hiddens * num_heads 267 | 268 | # First compute the content logits. 
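  # Following Transformer-XL, the logits are the sum of a content term
  # (queries plus a learned content bias, dotted with the keys) and a relative
  # position term, computed below from a learned relative bias and keys derived
  # from the sinusoidal relative encodings.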
269 | content_bias = hk.get_parameter( 270 | name='relpos_contentbias', 271 | shape=[num_heads, num_hiddens], 272 | init=hk.initializers.RandomNormal(stddev=0.02)) 273 | content_logits = jnp.einsum('bthd,bThd->bhtT', queries + content_bias, keys) 274 | 275 | positional_encodings = sinusoid_position_encoding( 276 | sequence_length=k_seq_len, 277 | hidden_size=hiddens, 278 | memory_length=0, 279 | max_timescale=max_time, 280 | min_timescale=2, 281 | clamp_length=0, 282 | causal=causal, 283 | ) 284 | positional_encodings = jnp.broadcast_to(positional_encodings, (batch_size,) + 285 | positional_encodings.shape) 286 | relative_keys = hk.Linear(hiddens, with_bias=False)(positional_encodings) 287 | relative_keys = jnp.reshape( 288 | relative_keys, positional_encodings.shape[:-1] + (num_heads, num_hiddens)) 289 | 290 | # Then compute the relative part. 291 | relative_bias = hk.get_parameter( 292 | name='relpos_relativebias', 293 | shape=[num_heads, num_hiddens], 294 | init=hk.initializers.RandomNormal(stddev=0.02)) 295 | relative_logits = jnp.einsum('bthd,bThd->bhtT', queries + relative_bias, 296 | relative_keys) 297 | # We shift the relative logits instead of the positional encoding matrix as 298 | # described in Appendix B of the paper (https://arxiv.org/pdf/1901.02860.pdf). 299 | relative_logits = relative_shift( 300 | relative_logits, attention_length=content_logits.shape[-1], causal=causal 301 | ) 302 | assert content_logits.shape == relative_logits.shape 303 | return content_logits + relative_logits 304 | 305 | 306 | def _get_alibi_slopes(num_heads: int) -> list[float]: 307 | """Returns the slopes for the different attention heads. 308 | 309 | While this does not exactly match the description of the [ALiBi 310 | paper](https://arxiv.org/pdf/2108.12409.pdf), it corresponds to the [official 311 | implementation](https://github.com/ofirpress/attention_with_linear_biases/blob/a06526fbfe557f9148e414b8569dcb97c7b182ba/fairseq/models/transformer.py#L742). 312 | 313 | Args: 314 | num_heads: The number of attention heads to create slopes for. 315 | """ 316 | 317 | def get_slopes_power_of_2(n): 318 | start = (2**(-2**-(math.log2(n) - 3))) 319 | ratio = start 320 | return [start * ratio**i for i in range(n)] 321 | 322 | if math.log2(num_heads).is_integer(): 323 | return get_slopes_power_of_2(num_heads) 324 | else: 325 | closest_power_of_2 = 2**math.floor(math.log2(num_heads)) 326 | return (get_slopes_power_of_2(closest_power_of_2) + _get_alibi_slopes( 327 | 2 * closest_power_of_2)[0::2][:num_heads - closest_power_of_2]) 328 | 329 | 330 | def compute_alibi_encodings_biases( 331 | attention_shape: tuple[int, ...] 332 | ) -> chex.Array: 333 | """Returns the biases following the ALiBi method. 334 | 335 | This code strictly follows what is described in the ALiBi paper. 336 | https://arxiv.org/pdf/2108.12409.pdf 337 | 338 | Args: 339 | attention_shape: The attention logits shape, without batch size, (h, t, T). 340 | 341 | Returns: 342 | The alibi biases, same shape as the input logits shape. 343 | """ 344 | num_heads, q_seq_len, k_seq_len = attention_shape 345 | 346 | # Since we do not use causal masking, the upper triangle of the matrix has to 347 | # be nonzero. Therefore, we set it equal to the lower triangle, but we also 348 | # add a constant factor of 0.5 to the lower triangle, to (arbitrarily) break 349 | # the symmetry (otherwise, the model cannot distinguish left and right). 
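  # For q_seq_len = k_seq_len = 3, the biases computed below are
  # [[0, -1, -2], [-0.5, 0, -1], [-1.5, -0.5, 0]]: -|i - j| plus the 0.5 offset
  # on the strictly lower triangle, before scaling by the per-head slopes.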
350 | alibi = np.zeros((q_seq_len, k_seq_len)) 351 | alibi -= sum(np.tri(*alibi.shape, k=-i) for i in range(1, q_seq_len)) 352 | alibi -= sum(np.tri(*alibi.T.shape, k=-i).T for i in range(1, k_seq_len)) 353 | alibi += 0.5 * np.tri(*alibi.shape, k=-1) 354 | 355 | return alibi * jnp.array(_get_alibi_slopes(num_heads))[:, None, None] 356 | -------------------------------------------------------------------------------- /models/rnn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Builders for RNN/LSTM cores.""" 17 | 18 | from typing import Any, Callable 19 | 20 | import haiku as hk 21 | import jax.nn as jnn 22 | import jax.numpy as jnp 23 | 24 | from neural_networks_chomsky_hierarchy.models import tape_rnn 25 | 26 | 27 | def make_rnn( 28 | output_size: int, 29 | rnn_core: type[hk.RNNCore], 30 | return_all_outputs: bool = False, 31 | input_window: int = 1, 32 | **rnn_kwargs: Any 33 | ) -> Callable[[jnp.ndarray], jnp.ndarray]: 34 | """Returns an RNN model, not haiku transformed. 35 | 36 | Only the last output in the sequence is returned. A linear layer is added to 37 | match the required output_size. 38 | 39 | Args: 40 | output_size: The output size of the model. 41 | rnn_core: The haiku RNN core to use. LSTM by default. 42 | return_all_outputs: Whether to return the whole sequence of outputs of the 43 | RNN, or just the last one. 44 | input_window: The number of tokens that are fed at once to the RNN. 45 | **rnn_kwargs: Kwargs to be passed to the RNN core. 
46 | """ 47 | 48 | def rnn_model(x: jnp.ndarray, input_length: int = 1) -> jnp.ndarray: 49 | core = rnn_core(**rnn_kwargs) 50 | if issubclass(rnn_core, tape_rnn.TapeRNNCore): 51 | initial_state = core.initial_state(x.shape[0], input_length) # pytype: disable=wrong-arg-count 52 | else: 53 | initial_state = core.initial_state(x.shape[0]) 54 | 55 | batch_size, seq_length, embed_size = x.shape 56 | if seq_length % input_window != 0: 57 | x = jnp.pad(x, ((0, 0), (0, input_window - seq_length % input_window), 58 | (0, 0))) 59 | new_seq_length = x.shape[1] 60 | x = jnp.reshape( 61 | x, 62 | (batch_size, new_seq_length // input_window, input_window, embed_size)) 63 | 64 | x = hk.Flatten(preserve_dims=2)(x) 65 | 66 | output, _ = hk.dynamic_unroll( 67 | core, x, initial_state, time_major=False, return_all_states=True) 68 | output = jnp.reshape(output, (batch_size, new_seq_length, output.shape[-1])) 69 | 70 | if not return_all_outputs: 71 | output = output[:, -1, :] # (batch, time, alphabet_dim) 72 | output = jnn.relu(output) 73 | return hk.Linear(output_size)(output) 74 | 75 | return rnn_model 76 | -------------------------------------------------------------------------------- /models/stack_rnn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Stack RNN core. 17 | 18 | Following the paper from Joulin et al (2015): 19 | https://arxiv.org/abs/1503.01007 20 | 21 | The idea is to add a stack extension to a recurrent neural network to be able to 22 | simulate a machine accepting context-free languages. 23 | The stack is completely differentiable. The actions taken are probabilities 24 | only and therefore no RL is required. The stack and state update are just linear 25 | combinations of the last states and these probabilities. 26 | """ 27 | 28 | from typing import Any, Mapping, Optional 29 | 30 | import einshape 31 | import haiku as hk 32 | import jax 33 | import jax.nn as jnn 34 | import jax.numpy as jnp 35 | 36 | 37 | # First element is the stacks, second is the hidden internal state. 38 | _StackRnnState = tuple[jnp.ndarray, jnp.ndarray] 39 | 40 | # Number of actions the stack-RNN can take, namely POP, PUSH and NO_OP. 41 | _NUM_ACTIONS = 3 42 | 43 | 44 | def _update_stack(stack: jnp.ndarray, actions: jnp.ndarray, 45 | push_value: jnp.ndarray) -> jnp.ndarray: 46 | """Updates the stack values. 47 | 48 | We update the stack in two steps. 
49 | In the first step, we update the top of the stack, and essentially do: 50 | stack[0] = push_action * push_value 51 | + pop_action * stack[1] 52 | + noop_action * stack[0] 53 | 54 | Then, in the second step, we update the rest of the stack and we move the 55 | elements up and down, depending on the action executed: 56 | * If push_action were 1, then we'd be purely pushing a new element 57 | to the top of the stack, so we'd move all elements down by one. 58 | * Likewise, if pop_action were 1, we'd be purely taking an element 59 | off the top of the stack, so we'd move all elements up by one. 60 | * Finally, if noop_action were 1, we'd leave elements where they were. 61 | The update is therefore essentially: 62 | stack[i] = push_action * stack[i-1] 63 | + pop_action * stack[i+1] 64 | + noop_action * stack[i] 65 | 66 | Args: 67 | stack: The current stack, shape (batch_size, stack_size, stack_cell_size). 68 | actions: The array of probabilities of the actions, shape (batch_size, 3). 69 | push_value: The vector to push on the stack, if the push action probability 70 | is positive, shape (batch_size, stack_cell_size). 71 | 72 | Returns: 73 | The new stack, same shape as the input stack. 74 | """ 75 | batch_size, stack_size, stack_cell_size = stack.shape 76 | 77 | # Tiling the actions to match the top of the stack. 78 | # Shape (batch_size, stack_cell_size, _NUM_ACTIONS) 79 | cell_tiled_stack_actions = einshape.jax_einshape( 80 | 'ba->bsa', actions, s=stack_cell_size) 81 | push_action = cell_tiled_stack_actions[..., 0] 82 | pop_action = cell_tiled_stack_actions[..., 1] 83 | pop_value = stack[..., 1, :] 84 | no_op_action = cell_tiled_stack_actions[..., 2] 85 | no_op_value = stack[..., 0, :] 86 | 87 | # Shape (batch_size, 1, stack_cell_size) 88 | top_new_stack = ( 89 | push_action * push_value + pop_action * pop_value + 90 | no_op_action * no_op_value) 91 | top_new_stack = jnp.expand_dims(top_new_stack, axis=1) 92 | 93 | # Tiling the actions to match all of the stack except the top. 94 | # Shape (batch_size, stack_size, stack_cell_size, _NUM_ACTIONS) 95 | stack_tiled_stack_actions = einshape.jax_einshape( 96 | 'ba->bcsa', actions, s=stack_cell_size, c=stack_size - 1) 97 | push_action = stack_tiled_stack_actions[..., 0] 98 | push_value = stack[..., :-1, :] 99 | pop_action = stack_tiled_stack_actions[..., 1] 100 | pop_extra_zeros = jnp.zeros((batch_size, 1, stack_cell_size)) 101 | pop_value = jnp.concatenate([stack[..., 2:, :], pop_extra_zeros], axis=1) 102 | no_op_action = stack_tiled_stack_actions[..., 2] 103 | no_op_value = stack[..., 1:, :] 104 | 105 | # Shape (batch_size, stack_size-1, stack_cell_size) 106 | rest_new_stack = ( 107 | push_action * push_value + pop_action * pop_value + 108 | no_op_action * no_op_value) 109 | 110 | # Finally concatenate the new top with the new rest of the stack. 111 | return jnp.concatenate([top_new_stack, rest_new_stack], axis=1) 112 | 113 | 114 | class StackRNNCore(hk.RNNCore): 115 | """Core for the stack RNN.""" 116 | 117 | def __init__( 118 | self, 119 | stack_cell_size: int, 120 | stack_size: int = 30, 121 | n_stacks: int = 1, 122 | inner_core: type[hk.RNNCore] = hk.VanillaRNN, 123 | name: Optional[str] = None, 124 | **inner_core_kwargs: Mapping[str, Any] 125 | ): 126 | """Initializes. 127 | 128 | Args: 129 | stack_cell_size: The dimension of the vectors we put in the stack. 130 | stack_size: The total number of vectors we can stack. 131 | n_stacks: Number of stacks to use in the network. 132 | inner_core: The inner RNN core builder. 
133 | name: See base class. 134 | **inner_core_kwargs: The arguments to be passed to the inner RNN core 135 | builder. 136 | """ 137 | super().__init__(name=name) 138 | self._rnn_core = inner_core(**inner_core_kwargs) 139 | self._stack_cell_size = stack_cell_size 140 | self._stack_size = stack_size 141 | self._n_stacks = n_stacks 142 | 143 | def __call__( 144 | self, inputs: jnp.ndarray, prev_state: _StackRnnState 145 | ) -> tuple[jnp.ndarray, _StackRnnState]: 146 | """Steps the stack RNN core. 147 | 148 | See base class docstring. 149 | 150 | Args: 151 | inputs: An input array of shape (batch_size, input_size). The time 152 | dimension is not included since it is an RNNCore, which is unrolled over 153 | the time dimension. 154 | prev_state: A _StackRnnState tuple, consisting of the previous stacks and 155 | the previous state of the inner core. Each stack has shape (batch_size, 156 | stack_size, stack_cell_size), such that `stack[n][0]` represents the top 157 | of the stack for the nth batch item, and `stack[n][-1]` the bottom of 158 | the stack. The stacks are just the concatenation of all these tensors. 159 | 160 | Returns: 161 | - output: An output array of shape (batch_size, output_size). 162 | - next_state: Same format as prev_state. 163 | """ 164 | stacks, old_core_state = prev_state 165 | 166 | # The network can always read the top of the stack. 167 | batch_size = stacks.shape[0] 168 | top_stacks = stacks[:, :, 0, :] 169 | top_stacks = jnp.reshape( 170 | top_stacks, (batch_size, self._n_stacks * self._stack_cell_size)) 171 | inputs = jnp.concatenate([inputs, top_stacks], axis=-1) 172 | new_core_output, new_core_state = self._rnn_core(inputs, old_core_state) 173 | push_values = hk.Linear(self._n_stacks * self._stack_cell_size)( 174 | new_core_output) 175 | push_values = jnp.reshape( 176 | push_values, (batch_size, self._n_stacks, self._stack_cell_size)) 177 | 178 | # Shape (batch_size, _NUM_ACTIONS) 179 | stack_actions = jnn.softmax( 180 | hk.Linear(self._n_stacks * _NUM_ACTIONS)(new_core_output), axis=-1) 181 | stack_actions = jnp.reshape(stack_actions, 182 | (batch_size, self._n_stacks, _NUM_ACTIONS)) 183 | 184 | new_stacks = jax.vmap( 185 | _update_stack, in_axes=1, out_axes=1)(stacks, stack_actions, 186 | push_values) 187 | return new_core_output, (new_stacks, new_core_state) 188 | 189 | def initial_state(self, batch_size: Optional[int]) -> _StackRnnState: 190 | """Returns the initial state of the core, a hidden state and an empty stack.""" 191 | core_state = self._rnn_core.initial_state(batch_size) 192 | stacks = jnp.zeros( 193 | (batch_size, self._n_stacks, self._stack_size, self._stack_cell_size)) 194 | return stacks, core_state 195 | -------------------------------------------------------------------------------- /models/tape_rnn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Implements the Tape RNN.""" 17 | 18 | import abc 19 | import functools 20 | from typing import Any, Optional, Sequence 21 | 22 | import chex 23 | import haiku as hk 24 | import jax 25 | from jax import nn as jnn 26 | from jax import numpy as jnp 27 | 28 | # The first element is the memory, the second is the hidden internal state, and 29 | # the third is the input length, necessary for adaptive actions. 30 | _TapeRNNState = tuple[chex.Array, chex.Array, chex.Array] 31 | 32 | 33 | class TapeRNNCore(hk.RNNCore, abc.ABC): 34 | """Core for the tape RNN.""" 35 | 36 | def __init__( 37 | self, 38 | memory_cell_size: int, 39 | memory_size: int = 30, 40 | n_tapes: int = 1, 41 | mlp_layers_size: Sequence[int] = (64, 64), 42 | inner_core: type[hk.RNNCore] = hk.VanillaRNN, 43 | name: Optional[str] = None, 44 | **inner_core_kwargs: Any 45 | ): 46 | """Initializes. 47 | 48 | Args: 49 | memory_cell_size: The dimension of the vectors we put in memory. 50 | memory_size: The size of the tape, fixed value along the episode. 51 | n_tapes: Number of tapes to use. Default is 1. 52 | mlp_layers_size: Sizes for the inner MLP layers. Can be empty, in which 53 | case the MLP is a linear layer. 54 | inner_core: The inner RNN core builder. 55 | name: See base class. 56 | **inner_core_kwargs: The arguments to be passed to the inner RNN core 57 | builder. 58 | """ 59 | super().__init__(name=name) 60 | self._rnn_core = inner_core(**inner_core_kwargs) 61 | self._mlp_layers_size = mlp_layers_size 62 | self._memory_cell_size = memory_cell_size 63 | self._memory_size = memory_size 64 | self._n_tapes = n_tapes 65 | 66 | @abc.abstractmethod 67 | def _tape_operations( 68 | self, eye_memory: chex.Array, input_length: int 69 | ) -> list[chex.Array]: 70 | """Returns a set of updated memory slots. 71 | 72 | An eye matrix is passed and corresponds to the positions of the memory 73 | slots. This method returns a matrix with the new positions associated with 74 | the actions. For instance, for a 'left' action, the new matrix will just be 75 | a roll(eye_memory, shift=-1). This is general enough to allow any 76 | permutation on the indexes. 77 | 78 | Args: 79 | eye_memory: An eye matrix of shape [memory_size, memory_size]. 80 | input_length: The length of the input sequence. Can be useful for some 81 | operations. 82 | """ 83 | 84 | @property 85 | @abc.abstractmethod 86 | def num_actions(self) -> int: 87 | """Returns the number of actions which can be taken on the tape.""" 88 | 89 | def __call__( 90 | self, inputs: chex.Array, prev_state: _TapeRNNState 91 | ) -> tuple[chex.Array, _TapeRNNState]: 92 | """Steps the tape RNN core.""" 93 | memories, old_core_state, input_length = prev_state 94 | 95 | # The network can always read the value of the current cell. 96 | batch_size = memories.shape[0] 97 | current_memories = memories[:, :, 0, :] 98 | current_memories = jnp.reshape( 99 | current_memories, (batch_size, self._n_tapes * self._memory_cell_size)) 100 | inputs = jnp.concatenate([inputs, current_memories], axis=-1) 101 | new_core_output, new_core_state = self._rnn_core(inputs, old_core_state) 102 | readout_mlp = hk.nets.MLP( 103 | list(self._mlp_layers_size) + [self._n_tapes * self._memory_cell_size]) 104 | write_values = readout_mlp(new_core_output) 105 | write_values = jnp.reshape( 106 | write_values, (batch_size, self._n_tapes, self._memory_cell_size)) 107 | 108 | # Shape (batch_size, num_actions). 
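    # One softmax over the tape actions is computed independently for each
    # tape; the results are stacked to shape (batch_size, n_tapes, num_actions).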
109 | actions = [] 110 | for _ in range(self._n_tapes): 111 | actions.append( 112 | jnn.softmax(hk.Linear(self.num_actions)(new_core_output), axis=-1)) 113 | actions = jnp.stack(actions, axis=1) 114 | 115 | update_memory = functools.partial( 116 | self._update_memory, input_length=input_length[0]) 117 | new_memories = jax.vmap( 118 | update_memory, in_axes=1, out_axes=1)(memories, actions, write_values) 119 | return new_core_output, (new_memories, new_core_state, input_length) 120 | 121 | def initial_state(self, batch_size: Optional[int], 122 | input_length: int) -> _TapeRNNState: # pytype: disable=signature-mismatch 123 | """Returns the initial state of the core.""" 124 | core_state = self._rnn_core.initial_state(batch_size) 125 | memories = jnp.zeros( 126 | (batch_size, self._n_tapes, self._memory_size, self._memory_cell_size)) 127 | return memories, core_state, jnp.array([input_length]) 128 | 129 | def _update_memory(self, memory: chex.Array, actions: chex.Array, 130 | write_values: chex.Array, input_length: int) -> chex.Array: 131 | """Computes the new memory based on the `actions` and `write_values`. 132 | 133 | Args: 134 | memory: The current memory with shape `[batch_size, memory_size, 135 | memory_cell_size]`. 136 | actions: The action probabilities with shape `[batch_size, num_actions]`. 137 | write_values: The values added to the first memory entry with shape 138 | `[batch_size, memory_cell_size]`. 139 | input_length: The length of the input. 140 | 141 | Returns: 142 | The new memory with shape `[batch_size, memory_size]`. 143 | """ 144 | _, memory_size, _ = memory.shape 145 | 146 | memory_with_write = jnp.concatenate( 147 | [jnp.expand_dims(write_values, axis=1), memory[:, 1:]], axis=1) 148 | 149 | eye_memory = jnp.eye(memory_size) 150 | operations = self._tape_operations(eye_memory, input_length) 151 | apply_operation = lambda x: jnp.einsum('mM,bMc->bmc', x, memory_with_write) 152 | memory_operations = jnp.stack(list(map(apply_operation, operations))) 153 | return jnp.einsum('Abmc,bA->bmc', memory_operations, actions) 154 | 155 | 156 | class TapeInputLengthJumpCore(TapeRNNCore): 157 | """A tape-RNN with extra jumps of the length of the input. 158 | 159 | 5 possible actions: 160 | - write and stay 161 | - write and move one cell left 162 | - write and move one cell right 163 | - write and move input_length cells left 164 | - write and move input_length cells right 165 | """ 166 | 167 | @property 168 | def num_actions(self) -> int: 169 | """Returns the number of actions of the tape.""" 170 | return 5 171 | 172 | def _tape_operations( 173 | self, eye_memory: chex.Array, input_length: int 174 | ) -> list[chex.Array]: 175 | write_stay = eye_memory 176 | write_left = jnp.roll(eye_memory, shift=-1, axis=0) 177 | write_right = jnp.roll(eye_memory, shift=1, axis=0) 178 | write_jump_left = jnp.roll(eye_memory, shift=-input_length, axis=0) 179 | write_jump_right = jnp.roll(eye_memory, shift=input_length, axis=0) 180 | return [ 181 | write_stay, write_left, write_right, write_jump_left, write_jump_right 182 | ] 183 | -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Transformer model.""" 17 | 18 | import dataclasses 19 | from typing import Callable, Optional 20 | 21 | import chex 22 | import haiku as hk 23 | import jax 24 | import jax.nn as jnn 25 | import jax.numpy as jnp 26 | 27 | from neural_networks_chomsky_hierarchy.models import positional_encodings as pos_encs_lib 28 | 29 | 30 | @chex.dataclass 31 | class TransformerConfig: 32 | """Hyperparameters used in the Transformer architectures.""" 33 | # The size of the model output (i.e., the output vocabulary size). 34 | output_size: int 35 | # The dimension of the first embedding. 36 | embedding_dim: int = 64 37 | # The number of multi-head attention layers. 38 | num_layers: int = 5 39 | # The number of heads per layer. 40 | num_heads: int = 8 41 | # The number of hidden neurons per head. If None, it is set to be equal to 42 | # `embedding_dim // num_heads`. 43 | num_hiddens_per_head: Optional[int] = None 44 | # The probability that each element is discarded by the dropout modules. 45 | dropout_prob: float = 0.1 46 | # The parameter initialization scale for the embeddings. 47 | emb_init_scale: float = 0.02 48 | # Whether to use the embeddings rather than raw inputs. 49 | use_embeddings: bool = True 50 | # Whether to share embeddings between the Encoder and the Decoder. 51 | share_embeddings: bool = False 52 | # The size of the sliding attention window. See MultiHeadDotProductAttention. 53 | attention_window: Optional[int] = None 54 | # The positional encoding used with default sin/cos (Vaswani et al., 2017). 55 | positional_encodings: pos_encs_lib.PositionalEncodings = dataclasses.field( 56 | default_factory=lambda: pos_encs_lib.PositionalEncodings.SIN_COS 57 | ) 58 | # The maximum size of the context (used by the posiitonal encodings). 59 | max_time: int = 10_000 60 | # The parameters for the positional encodings, default sin/cos. 61 | positional_encodings_params: pos_encs_lib.PositionalEncodingsParams = ( 62 | dataclasses.field(default_factory=pos_encs_lib.SinCosParams) 63 | ) 64 | # How much larger the hidden layer of the feedforward network should be 65 | # compared to the `embedding_dim`. 66 | widening_factor: int = 4 67 | # Add mask to make causal predictions. 68 | causal_masking: bool = False 69 | 70 | def __post_init__(self) -> None: 71 | """Sets `num_hiddens_per_head` if it is `None`.""" 72 | if self.num_hiddens_per_head is None: 73 | self.num_hiddens_per_head = self.embedding_dim // self.num_heads 74 | 75 | 76 | def layer_norm(x: chex.Array) -> chex.Array: 77 | return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x) 78 | 79 | 80 | def shift_right(x: chex.Array, output_size: int) -> chex.Array: 81 | """Right-shift the one-hot encoded input by padding on the temporal axis.""" 82 | x = jnp.argmax(x, axis=-1) 83 | 84 | # Add a time dimension for the single-output case (i.e., `ndim == 1`). 
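  # A batch of single targets of shape (B,) thus becomes (B, 1). For example,
  # with output_size = 2, decoded targets [[1, 0, 1]] are padded to
  # [[2, 1, 0, 1]], truncated to [[2, 1, 0]], and re-encoded as one-hot vectors
  # with output_size + 1 classes.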
85 | if x.ndim == 1: 86 | x = jnp.expand_dims(x, axis=1) 87 | 88 | padded = jnp.pad( 89 | x, ((0, 0), (1, 0)), mode='constant', constant_values=output_size) 90 | 91 | return jnn.one_hot(padded[:, :-1], num_classes=output_size + 1) 92 | 93 | 94 | def compute_sliding_window_mask(sequence_length: int, 95 | attention_window: int) -> chex.Array: 96 | """Returns a k-diagonal mask for a sliding window. 97 | 98 | Args: 99 | sequence_length: The length of the sequence, which will determine the shape 100 | of the output. 101 | attention_window: The size of the sliding window. 102 | 103 | Returns: 104 | A symmetric matrix of shape (sequence_length, sequence_length), 105 | attention_window-diagonal, with ones on the diagonal and on all the 106 | upper/lower diagonals up to attention_window // 2. 107 | 108 | Raises: 109 | ValueError if attention_window is <= 0. 110 | """ 111 | if attention_window <= 0: 112 | raise ValueError( 113 | f'The attention window should be > 0. Got {attention_window}.') 114 | 115 | if attention_window == 1: 116 | return jnp.eye(sequence_length, sequence_length) 117 | 118 | attention_mask = jnp.sum( 119 | jnp.stack([ 120 | jnp.eye(sequence_length, sequence_length, k=k, dtype=jnp.int32) 121 | for k in range(1, attention_window // 2 + 1) 122 | ]), 123 | axis=0) 124 | attention_mask = attention_mask + jnp.transpose(attention_mask) 125 | attention_mask += jnp.eye(sequence_length, sequence_length) 126 | return attention_mask 127 | 128 | 129 | class MultiHeadDotProductAttention(hk.Module): 130 | """Multi-head dot-product attention (Vaswani et al., 2017).""" 131 | 132 | def __init__( 133 | self, 134 | num_heads: int, 135 | num_hiddens_per_head: int, 136 | positional_encodings: pos_encs_lib.PositionalEncodings, 137 | positional_encodings_params: pos_encs_lib.PositionalEncodingsParams, 138 | attention_window: Optional[int] = None, 139 | name: Optional[str] = None, 140 | ) -> None: 141 | """Initializes the attention module. 142 | 143 | Args: 144 | num_heads: Number of heads to use. 145 | num_hiddens_per_head: Number of hidden neurons per head. 146 | positional_encodings: Which positional encodings to use in the attention. 147 | positional_encodings_params: Parameters for the positional encodings. 148 | attention_window: Size of the attention sliding window. None means no 149 | sliding window is used (or equivalently, window=full_attention_length). 150 | We attend only on attention_window tokens around a given query token. We 151 | attend to tokens before AND after the query token. If attention_window 152 | is even, we use the value +1. 153 | name: Name of the module. 
154 | """ 155 | super().__init__(name=name) 156 | self._num_heads = num_heads 157 | self._num_hiddens_per_head = num_hiddens_per_head 158 | self._positional_encodings = positional_encodings 159 | self._attention_window = attention_window 160 | self._positional_encodings_params = positional_encodings_params 161 | 162 | def __call__( 163 | self, 164 | inputs_q: chex.Array, 165 | inputs_kv: chex.Array, 166 | mask: Optional[chex.Array] = None, 167 | causal: bool = False, 168 | ) -> chex.Array: 169 | """Returns the output of the multi-head attention.""" 170 | batch_size, sequence_length, embedding_size = inputs_q.shape 171 | 172 | num_hiddens = self._num_hiddens_per_head * self._num_heads 173 | q = hk.Linear(num_hiddens, with_bias=False)(inputs_q) 174 | k = hk.Linear(num_hiddens, with_bias=False)(inputs_kv) 175 | v = hk.Linear(num_hiddens, with_bias=False)(inputs_kv) 176 | # The second (sequence) dimension is undefined since it can differ between 177 | # queries and keys/values when decoding. 178 | new_shape = (batch_size, -1, self._num_heads, self._num_hiddens_per_head) 179 | q = jnp.reshape(q, new_shape) 180 | k = jnp.reshape(k, new_shape) 181 | v = jnp.reshape(v, new_shape) 182 | 183 | # Let b=batch_size, t=seq_len, h=num_heads, and d=num_hiddens_per_head. 184 | if self._positional_encodings == pos_encs_lib.PositionalEncodings.RELATIVE: 185 | # We type hint the params to match the if statement, for pytype. 186 | self._positional_encodings_params: pos_encs_lib.RelativeParams 187 | attention = pos_encs_lib.compute_attention_with_relative_encodings( 188 | q, k, self._positional_encodings_params.max_time, causal=causal 189 | ) 190 | else: 191 | if self._positional_encodings == pos_encs_lib.PositionalEncodings.ROTARY: 192 | q = pos_encs_lib.apply_rotary_encoding( 193 | q, position=jnp.arange(q.shape[1])[None, :] 194 | ) 195 | k = pos_encs_lib.apply_rotary_encoding( 196 | k, position=jnp.arange(k.shape[1])[None, :] 197 | ) 198 | attention = jnp.einsum('bthd,bThd->bhtT', q, k) 199 | attention *= 1.0 / jnp.sqrt(self._num_hiddens_per_head) 200 | 201 | # ALiBi encodings are not scaled with the 1 / sqrt(d_k) factor. 202 | if self._positional_encodings == pos_encs_lib.PositionalEncodings.ALIBI: 203 | attention += pos_encs_lib.compute_alibi_encodings_biases( 204 | attention.shape[1:] 205 | ) 206 | 207 | if self._attention_window is not None: 208 | # We compute the sliding attention by just applying a mask on the values 209 | # that are outside our window. 210 | attention_mask = compute_sliding_window_mask(sequence_length, 211 | self._attention_window) 212 | attention = jnp.where(attention_mask, attention, 213 | jnp.finfo(jnp.float32).min) 214 | 215 | if mask is not None: 216 | attention = jnp.where(mask, attention, jnp.finfo(jnp.float32).min) 217 | 218 | normalized_attention = jnn.softmax(attention) 219 | 220 | output = jnp.einsum('bhtT,bThd->bthd', normalized_attention, v) 221 | output = jnp.reshape(output, (batch_size, sequence_length, num_hiddens)) 222 | return hk.Linear(embedding_size, with_bias=False)(output) 223 | 224 | 225 | class TransformerEncoder(hk.Module): 226 | """Transformer Encoder (Vaswani et al., 2017).""" 227 | 228 | def __init__( 229 | self, 230 | config: TransformerConfig, 231 | shared_embeddings_fn: Optional[Callable[[chex.Array], chex.Array]] = None, 232 | name: Optional[str] = None, 233 | ) -> None: 234 | """Initializes the transformer encoder. 235 | 236 | Args: 237 | config: The hyperparameters used in Transformer architectures. 
238 | shared_embeddings_fn: Embedding function that is shared with the decoder. 239 | name: The name of the module. 240 | """ 241 | super().__init__(name=name) 242 | self._config = config 243 | self._shared_embeddings_fn = shared_embeddings_fn 244 | 245 | def __call__(self, x: jnp.ndarray) -> chex.Array: 246 | """Returns the transformer encoder output, shape [B, T, E].""" 247 | if self._config.use_embeddings: 248 | if self._shared_embeddings_fn is not None: 249 | embeddings = self._shared_embeddings_fn(x) 250 | else: 251 | # Since `x` is one-hot encoded, using hk.Linear is equivalent to 252 | # hk.Embed with hk.EmbedLookupStyle.ONE_HOT. 253 | embs_init = hk.initializers.TruncatedNormal( 254 | stddev=self._config.emb_init_scale) 255 | embeddings = hk.Linear( 256 | self._config.embedding_dim, with_bias=False, w_init=embs_init)( 257 | x) 258 | 259 | embeddings *= jnp.sqrt(self._config.embedding_dim) 260 | 261 | else: 262 | embeddings = x 263 | 264 | batch_size, sequence_length, embedding_size = embeddings.shape 265 | 266 | pos_enc_params = self._config.positional_encodings_params 267 | if ( 268 | self._config.positional_encodings 269 | == pos_encs_lib.PositionalEncodings.SIN_COS 270 | ): 271 | pos_encodings = pos_encs_lib.sinusoid_position_encoding( 272 | sequence_length=sequence_length, 273 | hidden_size=embedding_size, 274 | memory_length=0, 275 | max_timescale=pos_enc_params.max_time, 276 | min_timescale=2, 277 | clamp_length=0, 278 | causal=True, 279 | ) 280 | h = embeddings + pos_encodings 281 | h = hk.dropout(hk.next_rng_key(), self._config.dropout_prob, h) 282 | else: 283 | h = embeddings 284 | 285 | # The causal mask is shared across heads. 286 | if self._config.causal_masking: 287 | causal_mask = jnp.tril( 288 | jnp.ones((batch_size, 1, sequence_length, sequence_length)) 289 | ) 290 | else: 291 | causal_mask = None 292 | 293 | for _ in range(self._config.num_layers): 294 | attention = MultiHeadDotProductAttention( 295 | num_heads=self._config.num_heads, 296 | num_hiddens_per_head=self._config.num_hiddens_per_head, 297 | positional_encodings=self._config.positional_encodings, 298 | positional_encodings_params=pos_enc_params, 299 | attention_window=self._config.attention_window, 300 | )( 301 | inputs_q=h, 302 | inputs_kv=h, 303 | mask=causal_mask, 304 | causal=self._config.causal_masking, 305 | ) 306 | attention = hk.dropout(hk.next_rng_key(), self._config.dropout_prob, 307 | attention) 308 | attention = layer_norm(h + attention) 309 | 310 | # Position-wise feedforward network. 311 | h = hk.Linear(self._config.embedding_dim * self._config.widening_factor)( 312 | attention) 313 | h = jnn.relu(h) 314 | h = hk.Linear(self._config.embedding_dim)(h) 315 | 316 | h = hk.dropout(hk.next_rng_key(), self._config.dropout_prob, h) 317 | h = layer_norm(h + attention) 318 | return h 319 | 320 | 321 | class TransformerDecoder(hk.Module): 322 | """Transformer Decoder (Vaswani et al., 2017).""" 323 | 324 | def __init__( 325 | self, 326 | config: TransformerConfig, 327 | shared_embeddings_fn: Optional[Callable[[chex.Array], chex.Array]] = None, 328 | name: Optional[str] = None, 329 | ) -> None: 330 | """Initializes the Transformer decoder. 331 | 332 | Args: 333 | config: The hyperparameters used in Transformer architectures. 334 | shared_embeddings_fn: Embedding function that is shared with the encoder. 335 | name: The name of the module. 
336 | """ 337 | super().__init__(name=name) 338 | self._config = config 339 | self._shared_embeddings_fn = shared_embeddings_fn 340 | 341 | def __call__(self, encoded: chex.Array, targets: chex.Array) -> chex.Array: 342 | """Returns the transformer decoder output, shape [B, T_O, E]. 343 | 344 | Args: 345 | encoded: The output of the encoder, shape [B, T_I, E]. 346 | targets: The one-hot encoded target values, shape [B, T_O, 2]. 347 | """ 348 | targets = shift_right(targets, self._config.output_size) 349 | 350 | if self._config.use_embeddings: 351 | if self._shared_embeddings_fn is not None: 352 | output_embeddings = self._shared_embeddings_fn(targets) 353 | else: 354 | # Since `x` is one-hot encoded, using hk.Linear is equivalent to 355 | # hk.Embed with hk.EmbedLookupStyle.ONE_HOT. 356 | embs_init = hk.initializers.TruncatedNormal( 357 | stddev=self._config.emb_init_scale) 358 | output_embeddings = hk.Linear( 359 | self._config.embedding_dim, with_bias=False, w_init=embs_init)( 360 | targets) 361 | 362 | output_embeddings *= jnp.sqrt(self._config.embedding_dim) 363 | 364 | else: 365 | output_embeddings = targets 366 | 367 | batch_size, output_sequence_length, embedding_size = output_embeddings.shape 368 | 369 | if ( 370 | self._config.positional_encodings 371 | == pos_encs_lib.PositionalEncodings.SIN_COS 372 | ): 373 | pos_encodings = pos_encs_lib.sinusoid_position_encoding( 374 | sequence_length=output_sequence_length, 375 | hidden_size=embedding_size, 376 | memory_length=0, 377 | max_timescale=self._config.positional_encodings_params.max_time, 378 | min_timescale=2, 379 | clamp_length=0, 380 | causal=True, 381 | ) 382 | h = output_embeddings + pos_encodings 383 | h = hk.dropout(hk.next_rng_key(), self._config.dropout_prob, h) 384 | else: 385 | h = output_embeddings 386 | 387 | # The causal mask is shared across heads. 388 | causal_mask = jnp.tril( 389 | jnp.ones( 390 | (batch_size, 1, output_sequence_length, output_sequence_length))) 391 | 392 | for _ in range(self._config.num_layers): 393 | self_attention = MultiHeadDotProductAttention( 394 | num_heads=self._config.num_heads, 395 | num_hiddens_per_head=self._config.num_hiddens_per_head, 396 | positional_encodings=self._config.positional_encodings, 397 | positional_encodings_params=self._config.positional_encodings_params, 398 | attention_window=self._config.attention_window, 399 | )(inputs_q=h, inputs_kv=h, mask=causal_mask, causal=True) 400 | self_attention = hk.dropout(hk.next_rng_key(), self._config.dropout_prob, 401 | self_attention) 402 | self_attention = layer_norm(h + self_attention) 403 | 404 | cross_attention = MultiHeadDotProductAttention( 405 | num_heads=self._config.num_heads, 406 | num_hiddens_per_head=self._config.num_hiddens_per_head, 407 | positional_encodings=self._config.positional_encodings, 408 | positional_encodings_params=self._config.positional_encodings_params, 409 | attention_window=self._config.attention_window, 410 | )(inputs_q=self_attention, inputs_kv=encoded, causal=True) 411 | cross_attention = hk.dropout(hk.next_rng_key(), self._config.dropout_prob, 412 | cross_attention) 413 | cross_attention = layer_norm(self_attention + cross_attention) 414 | 415 | # Position-wise feedforward network. 
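      # The feedforward block expands the representation from `embedding_dim`
      # to `widening_factor * embedding_dim` (256 units with the defaults
      # embedding_dim=64 and widening_factor=4), applies a ReLU, and projects
      # back to `embedding_dim` before the residual connection and layer norm.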
416 | h = hk.Linear(self._config.embedding_dim * self._config.widening_factor)( 417 | cross_attention) 418 | h = jnn.relu(h) 419 | h = hk.Linear(self._config.embedding_dim)(h) 420 | 421 | h = hk.dropout(hk.next_rng_key(), self._config.dropout_prob, h) 422 | h = layer_norm(h + cross_attention) 423 | 424 | return h 425 | 426 | 427 | class Transformer(hk.Module): 428 | """Transformer (Vaswani et al., 2017).""" 429 | 430 | def __init__(self, config: TransformerConfig, name: Optional[str] = None): 431 | """Initializes the Transformer. 432 | 433 | Args: 434 | config: The hyperparameters used in Transformer architectures. 435 | name: The name of the module. 436 | """ 437 | super().__init__(name=name) 438 | shared_embeddings_fn = None 439 | 440 | if config.share_embeddings: 441 | shared_embeddings_fn = hk.Linear( 442 | config.embedding_dim, 443 | with_bias=False, 444 | w_init=hk.initializers.TruncatedNormal(stddev=config.emb_init_scale), 445 | name='shared_embeddings') 446 | 447 | self._encoder = TransformerEncoder(config, shared_embeddings_fn) 448 | self._decoder = TransformerDecoder(config, shared_embeddings_fn) 449 | 450 | def __call__(self, inputs: chex.Array, targets: chex.Array) -> chex.Array: 451 | return self._decoder(self._encoder(inputs), targets) 452 | 453 | 454 | def make_transformer_encoder( 455 | output_size: int, 456 | embedding_dim: int = 64, 457 | num_layers: int = 5, 458 | num_heads: int = 8, 459 | num_hiddens_per_head: Optional[int] = None, 460 | dropout_prob: float = 0.1, 461 | emb_init_scale: float = 0.02, 462 | use_embeddings: bool = True, 463 | share_embeddings: bool = False, 464 | attention_window: Optional[int] = None, 465 | positional_encodings: Optional[pos_encs_lib.PositionalEncodings] = None, 466 | positional_encodings_params: Optional[ 467 | pos_encs_lib.PositionalEncodingsParams 468 | ] = None, 469 | widening_factor: int = 4, 470 | return_all_outputs: bool = False, 471 | causal_masking: bool = False, 472 | ) -> Callable[[chex.Array], chex.Array]: 473 | """Returns a transformer encoder model.""" 474 | if positional_encodings is None: 475 | positional_encodings = pos_encs_lib.PositionalEncodings.SIN_COS 476 | positional_encodings_params = pos_encs_lib.SinCosParams() 477 | elif positional_encodings_params is None: 478 | raise ValueError('No parameters for positional encodings are passed.') 479 | config = TransformerConfig( 480 | output_size=output_size, 481 | embedding_dim=embedding_dim, 482 | num_layers=num_layers, 483 | num_heads=num_heads, 484 | num_hiddens_per_head=num_hiddens_per_head, 485 | dropout_prob=dropout_prob, 486 | emb_init_scale=emb_init_scale, 487 | use_embeddings=use_embeddings, 488 | share_embeddings=share_embeddings, 489 | attention_window=attention_window, 490 | positional_encodings=positional_encodings, 491 | positional_encodings_params=positional_encodings_params, 492 | widening_factor=widening_factor, 493 | causal_masking=causal_masking, 494 | ) 495 | 496 | def transformer_encoder(inputs: chex.Array) -> chex.Array: 497 | output = TransformerEncoder(config)(inputs) 498 | if not return_all_outputs: 499 | output = output[:, -1, :] 500 | return hk.Linear(output_size)(output) 501 | 502 | return transformer_encoder 503 | 504 | 505 | def make_transformer( 506 | output_size: int, 507 | embedding_dim: int = 64, 508 | num_layers: int = 5, 509 | num_heads: int = 8, 510 | num_hiddens_per_head: Optional[int] = None, 511 | dropout_prob: float = 0.1, 512 | emb_init_scale: float = 0.02, 513 | use_embeddings: bool = True, 514 | share_embeddings: bool = False, 515 
| attention_window: Optional[int] = None, 516 | positional_encodings: Optional[pos_encs_lib.PositionalEncodings] = None, 517 | positional_encodings_params: Optional[ 518 | pos_encs_lib.PositionalEncodingsParams 519 | ] = None, 520 | widening_factor: int = 4, 521 | return_all_outputs: bool = False, 522 | ) -> Callable[[chex.Array, chex.Array], chex.Array]: 523 | """Returns a transformer model.""" 524 | if positional_encodings is None: 525 | positional_encodings = pos_encs_lib.PositionalEncodings.SIN_COS 526 | positional_encodings_params = pos_encs_lib.SinCosParams() 527 | elif positional_encodings_params is None: 528 | raise ValueError('No parameters for positional encodings are passed.') 529 | config = TransformerConfig( 530 | output_size=output_size, 531 | embedding_dim=embedding_dim, 532 | num_layers=num_layers, 533 | num_heads=num_heads, 534 | num_hiddens_per_head=num_hiddens_per_head, 535 | dropout_prob=dropout_prob, 536 | emb_init_scale=emb_init_scale, 537 | use_embeddings=use_embeddings, 538 | share_embeddings=share_embeddings, 539 | attention_window=attention_window, 540 | positional_encodings=positional_encodings, 541 | positional_encodings_params=positional_encodings_params, 542 | widening_factor=widening_factor, 543 | ) 544 | 545 | def transformer(inputs: chex.Array, targets: chex.Array) -> chex.Array: 546 | output = Transformer(config)(inputs, targets) 547 | if not return_all_outputs: 548 | output = output[:, -1, :] 549 | return hk.Linear(output_size)(output) 550 | 551 | return transformer 552 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | dm-haiku 3 | dm-tree 4 | git+https://github.com/deepmind/einshape 5 | jax 6 | numpy 7 | optax 8 | tqdm 9 | typing-extensions 10 | -------------------------------------------------------------------------------- /tasks/cs/binary_addition.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Add two binary numbers.""" 17 | 18 | import random 19 | from typing import Sequence 20 | 21 | import chex 22 | import jax.nn as jnn 23 | import jax.numpy as jnp 24 | import numpy as np 25 | 26 | from neural_networks_chomsky_hierarchy.tasks import task 27 | 28 | 29 | def numbers_to_variable_length_binary( 30 | numbers: Sequence[int], 31 | lengths: Sequence[int], 32 | little_endian: bool = True, 33 | ) -> list[list[int]]: 34 | """Returns the binary notation of a certain length for a sequence of numbers. 35 | 36 | Args: 37 | numbers: The numbers to be converted to binary. 38 | lengths: The lengths of the binary representations (every number uses its 39 | own length). 
This argument has no effect if the binary representation is 40 | longer than the specified length. 41 | little_endian: Whether to use little- or big-endian notation. 42 | """ 43 | binary_strings = [f'{num:b}'.zfill(len) for num, len in zip(numbers, lengths)] 44 | 45 | if little_endian: 46 | binary_strings = [bin[::-1] for bin in binary_strings] 47 | 48 | return [list(map(int, bin)) for bin in binary_strings] 49 | 50 | 51 | def numbers_to_fixed_length_binary( 52 | numbers: Sequence[int], 53 | length: int, 54 | little_endian: bool = True, 55 | ) -> list[list[int]]: 56 | """Returns the binary notation of a certain length for a sequence of numbers. 57 | 58 | Args: 59 | numbers: The numbers to be converted to binary. 60 | length: The length of the binary representations (all numbers use the same 61 | length). This argument has no effect if the binary representation is 62 | longer than the specified length. 63 | little_endian: Whether to use little- or big-endian notation. 64 | """ 65 | return numbers_to_variable_length_binary( 66 | numbers=numbers, 67 | lengths=[length] * len(numbers), 68 | little_endian=little_endian, 69 | ) 70 | 71 | 72 | def expression_from_numbers( 73 | numbers_n: Sequence[list[int]], 74 | numbers_m: Sequence[list[int]], 75 | ) -> list[list[int]]: 76 | """Returns an expression with a placeholder value to denote the operation.""" 77 | return [n + [2] + m for n, m in zip(numbers_n, numbers_m)] 78 | 79 | 80 | class BinaryAddition(task.GeneralizationTask): 81 | """A task with the goal of summing two numbers in binary (little-endian). 82 | 83 | The input is a string of the form `first_number+second_number` in 84 | (little-endian) binary notation (e.g., `01001+011`). The goal of the agent is 85 | to output the result, also in (little-endian) binary form (i.e., in the 86 | example `18 + 6 = 24 = 00011`). The output is padded with 0s to match the 87 | input length, and the end of the sum is denoted with a termination token 88 | (i.e., the output has values in `{0, 1, 2}`). 89 | 90 | Examples: 91 | 001 + 01101 = 010112000 (4 + 22 = 26) 92 | 1001 + 000001 = 10010120000 (9 + 32 = 41) 93 | """ 94 | 95 | def _sample_expressions_and_results( 96 | self, 97 | batch_size: int, 98 | length: int, 99 | ) -> tuple[Sequence[list[int]], Sequence[list[int]]]: 100 | """Samples pairs of numbers and sums them in (little-endian) binary. 101 | 102 | We use Python's bignums, which can represent arbitrary-precision integers to 103 | perform addition of two potentially very large values (roughly of the size 104 | `2 ** (length // 2)`). 105 | 106 | Args: 107 | batch_size: The number of expressions and results to sample. 108 | length: The length of the input expression containing the two numbers and 109 | the separation token. 110 | 111 | Returns: 112 | The expression and the sum of the two numbers. The expression has the 113 | format: `[first_number, 2, second_number]`, where the numbers are in 114 | (little-endian) binary notation. The sum is also in (little-endian) binary 115 | notation, without leading (i.e., ending) zeros. 116 | """ 117 | # If `length <= 2`, we just sample a binary value and return it (without 118 | # leading zeros in little-endian notation). 119 | if length <= 2: 120 | # Since `length <= 2`, we can use `np.random`` without overflow errors. 
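      # For instance (little-endian, illustrative values):
      #   numbers_to_fixed_length_binary([1], 2) -> [[1, 0]]
      #   numbers_to_fixed_length_binary([1], 0) -> [[1]]
      # so the expression is padded to `length` while the result is left
      # unpadded here (the termination token and zero-padding are added in
      # `sample_batch`).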
121 | numbers = np.random.randint(0, 2**length - 1, size=(batch_size)) 122 | expressions = numbers_to_fixed_length_binary(numbers, length) 123 | results = numbers_to_fixed_length_binary(numbers, 0) 124 | return expressions, results 125 | 126 | # We only use `length - 1` tokens for the two values to account for the `+`. 127 | length_n = np.random.randint(1, length - 1, size=(batch_size,)) 128 | length_m = length - 1 - length_n 129 | 130 | integer_n = [random.randint(1, 2**int(len_n) - 1) for len_n in length_n] 131 | integer_m = [random.randint(1, 2**int(len_m) - 1) for len_m in length_m] 132 | 133 | binary_n = numbers_to_variable_length_binary(integer_n, length_n) 134 | binary_m = numbers_to_variable_length_binary(integer_m, length_m) 135 | 136 | expressions = expression_from_numbers(binary_n, binary_m) 137 | 138 | integer_sum = list(map(sum, zip(integer_n, integer_m))) 139 | results = numbers_to_fixed_length_binary(integer_sum, length=0) 140 | 141 | return expressions, results 142 | 143 | def sample_batch( 144 | self, 145 | rng: chex.PRNGKey, 146 | batch_size: int, 147 | length: int, 148 | ) -> task.Batch: 149 | """Returns a batch of binary additions and their results.""" 150 | del rng 151 | 152 | expressions, results = self._sample_expressions_and_results( 153 | batch_size=batch_size, length=length) 154 | # Append the termination token to the result and pad the result with zeros 155 | # to match the output length (accounting for the termination token). 156 | results = [res + [2] + [0] * (length - len(res)) for res in results] 157 | 158 | expressions = jnp.array(expressions, dtype=jnp.int32) 159 | results = jnp.array(results, dtype=jnp.int32) 160 | 161 | return { 162 | 'input': jnn.one_hot(expressions, self.input_size), 163 | 'output': jnn.one_hot(results, self.output_size), 164 | } 165 | 166 | @property 167 | def input_size(self) -> int: 168 | """Returns the input size for the models.""" 169 | return 3 170 | 171 | @property 172 | def output_size(self) -> int: 173 | """Returns the output size for the models.""" 174 | return 3 175 | 176 | def output_length(self, input_length: int) -> int: 177 | return input_length + 1 178 | 179 | def accuracy_mask(self, target: chex.Array) -> chex.Array: 180 | """Computes a mask that ignores everything after the termination token. 181 | 182 | Args: 183 | target: Target tokens of shape `(batch_size, output_length, output_size)`. 184 | 185 | Returns: 186 | The mask of shape `(batch_size, output_length)`. 187 | """ 188 | batch_size, length, _ = target.shape 189 | termination_indices = jnp.argmax( 190 | jnp.argmax(target, axis=-1), 191 | axis=-1, 192 | keepdims=True, 193 | ) 194 | indices = jnp.tile(jnp.arange(length), (batch_size, 1)) 195 | return indices <= termination_indices 196 | -------------------------------------------------------------------------------- /tasks/cs/binary_multiplication.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Multiply two binary numbers.""" 17 | 18 | import random 19 | from typing import Sequence 20 | 21 | import chex 22 | import jax.nn as jnn 23 | import jax.numpy as jnp 24 | import numpy as np 25 | 26 | from neural_networks_chomsky_hierarchy.tasks import task 27 | from neural_networks_chomsky_hierarchy.tasks.cs import binary_addition 28 | 29 | 30 | class BinaryMultiplication(task.GeneralizationTask): 31 | """A task with the goal of multiplying two numbers in binary (little-endian). 32 | 33 | The input is a string of the form `first_number£second_number` in 34 | (little-endian) binary notation (e.g., `01001*011`). The goal of the agent is 35 | to output the result, also in (little-endian) binary form (i.e., in the 36 | example `18 * 6 = 108 = 00110011`). The output is padded with 0s to match the 37 | input length, and the end of the product is denoted with a termination token 38 | (i.e., the output has values in `{0, 1, 2}`). 39 | 40 | Examples: 41 | 001 * 01101 = 000110120 (4 * 22 = 88) 42 | 1001 * 000001 = 00000100120 (9 * 32 = 288) 43 | """ 44 | 45 | def _sample_expressions_and_results( 46 | self, 47 | batch_size: int, 48 | length: int, 49 | ) -> tuple[Sequence[list[int]], Sequence[list[int]]]: 50 | """Samples pairs of numbers and multiplies them in (little-endian) binary. 51 | 52 | We use Python's bignums, which can represent arbitrary-precision integers to 53 | perform multiplication of two potentially very large values (roughly of the 54 | size `2 ** (length // 2)`). 55 | 56 | Args: 57 | batch_size: The number of expressions and results to sample. 58 | length: The length of the input expression containing the two numbers and 59 | the separation token. 60 | 61 | Returns: 62 | The expression and the product of the two numbers. The expression has the 63 | format: `[first_number, 2, second_number]`, where the numbers are in 64 | (little-endian) binary notation. The product is also in (little-endian) 65 | binary notation, without leading (i.e., ending) zeros. 66 | """ 67 | # If `length <= 2`, we just sample a binary sequence for the expression and 68 | # arbitrarily set the result to a fixed value (`[]` for `length == 1` and 69 | # `[0]` for `length == 2`) to maintain the invariant that the result has 70 | # length has most `length - 1`. 71 | if length <= 2: 72 | # Since `length <= 2`, we can use `np.random`` without overflow errors. 73 | numbers = np.random.randint(0, 2**length - 1, size=(batch_size)) 74 | expressions = binary_addition.numbers_to_fixed_length_binary( 75 | numbers, length) 76 | return expressions, [[0] * (length - 1)] * batch_size 77 | 78 | # We only use `length - 1` tokens for the two values to account for the `*`. 
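    # `length_n` is drawn from [1, length - 2], so `length_m` is at least 1 as
    # well and both operands are non-empty; e.g., for `length == 8` a possible
    # split is 4 digits, the separator token, and 3 digits.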
79 | length_n = np.random.randint(1, length - 1, size=(batch_size,)) 80 | length_m = length - 1 - length_n 81 | 82 | integer_n = [random.randint(1, 2**int(len_n) - 1) for len_n in length_n] 83 | integer_m = [random.randint(1, 2**int(len_m) - 1) for len_m in length_m] 84 | 85 | binary_n = binary_addition.numbers_to_variable_length_binary( 86 | integer_n, length_n) 87 | binary_m = binary_addition.numbers_to_variable_length_binary( 88 | integer_m, length_m) 89 | 90 | expressions = binary_addition.expression_from_numbers(binary_n, binary_m) 91 | 92 | integer_prod = [int_n * int_m for int_n, int_m in zip(integer_n, integer_m)] 93 | results = binary_addition.numbers_to_fixed_length_binary( 94 | integer_prod, length=0) 95 | 96 | return expressions, results 97 | 98 | def sample_batch( 99 | self, 100 | rng: chex.PRNGKey, 101 | batch_size: int, 102 | length: int, 103 | ) -> task.Batch: 104 | """Returns a batch of binary multiplications and their results.""" 105 | del rng 106 | 107 | expressions, results = self._sample_expressions_and_results( 108 | batch_size=batch_size, length=length) 109 | # Append the termination token to the result and pad the result with zeros 110 | # to match the output length (accounting for the termination token). The 111 | # binary representation of the result will have at most length 112 | # `#(first_number) + #(second_number)`, where #() denotes the number of 113 | # digits of the binary notation. Since we use the token `2` to separate the 114 | # two numbers in the expression, the result will have length at most 115 | # `length - 1`, and thus by appending the termination token above it will 116 | # have length at most `length`, as desired. 117 | results = [res + [2] + [0] * (length - 1 - len(res)) for res in results] 118 | 119 | expressions = jnp.array(expressions, dtype=jnp.int32) 120 | results = jnp.array(results, dtype=jnp.int32) 121 | 122 | return { 123 | 'input': jnn.one_hot(expressions, self.input_size), 124 | 'output': jnn.one_hot(results, self.output_size), 125 | } 126 | 127 | @property 128 | def input_size(self) -> int: 129 | """Returns the input size for the models.""" 130 | return 3 131 | 132 | @property 133 | def output_size(self) -> int: 134 | """Returns the output size for the models.""" 135 | return 3 136 | 137 | def output_length(self, input_length: int) -> int: 138 | return input_length 139 | 140 | def accuracy_mask(self, target: chex.Array) -> chex.Array: 141 | """Computes a mask that ignores everything after the termination token. 142 | 143 | Args: 144 | target: Target tokens of shape `(batch_size, output_length, output_size)`. 145 | 146 | Returns: 147 | The mask of shape `(batch_size, output_length)`. 148 | """ 149 | batch_size, length, _ = target.shape 150 | termination_indices = jnp.argmax( 151 | jnp.argmax(target, axis=-1), 152 | axis=-1, 153 | keepdims=True, 154 | ) 155 | indices = jnp.tile(jnp.arange(length), (batch_size, 1)) 156 | return indices <= termination_indices 157 | -------------------------------------------------------------------------------- /tasks/cs/bucket_sort.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Sort tokens from a fixed alphabet (i.e., bucket sort).""" 17 | 18 | import functools 19 | 20 | import chex 21 | import jax 22 | from jax import nn as jnn 23 | from jax import numpy as jnp 24 | from jax import random as jrandom 25 | 26 | from neural_networks_chomsky_hierarchy.tasks import task 27 | 28 | 29 | class BucketSort(task.GeneralizationTask): 30 | """A task with the goal of sorting tokens from a fixed alphabet. 31 | 32 | The input string is composed of tokens from a fixed-size alphabet, i.e., 33 | `{0, 1, ..., vocab_size - 1}`, and the goal is to return the sorted string (in 34 | lexicographically increasing order). 35 | 36 | Examples: 37 | 10204112 -> 00111224 (with `vocab_size = 5`) 38 | 1110001 -> 0001111 (with `vocab_size = 2`) 39 | """ 40 | 41 | def __init__(self, vocab_size: int = 5) -> None: 42 | """Initializes the task. 43 | 44 | Args: 45 | vocab_size: The size of the alphabet. We use 5 in the paper. 46 | """ 47 | self._vocab_size = vocab_size 48 | 49 | @functools.partial(jax.jit, static_argnums=(0, 2, 3)) 50 | def sample_batch( 51 | self, 52 | rng: chex.PRNGKey, 53 | batch_size: int, 54 | length: int, 55 | ) -> task.Batch: 56 | """Returns a batch of strings and tokens sorted by (inc.) occurrence.""" 57 | strings = jrandom.randint( 58 | rng, shape=(batch_size, length), minval=0, maxval=self._vocab_size) 59 | sorted_strings = jnp.sort(strings, axis=-1) 60 | 61 | return { 62 | 'input': jnn.one_hot(strings, num_classes=self.input_size), 63 | 'output': jnn.one_hot(sorted_strings, num_classes=self.output_size), 64 | } 65 | 66 | @property 67 | def input_size(self) -> int: 68 | """Returns the input size for the models.""" 69 | return self._vocab_size 70 | 71 | @property 72 | def output_size(self) -> int: 73 | """Returns the output size for the models.""" 74 | return self._vocab_size 75 | 76 | def output_length(self, input_length: int) -> int: 77 | """Returns the output length for a given input length.""" 78 | return input_length 79 | -------------------------------------------------------------------------------- /tasks/cs/compute_sqrt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Compute the floor of the square root of a binary number.""" 17 | 18 | import math 19 | import random 20 | 21 | import chex 22 | import jax.nn as jnn 23 | import jax.numpy as jnp 24 | 25 | from neural_networks_chomsky_hierarchy.tasks import task 26 | from neural_networks_chomsky_hierarchy.tasks.cs import binary_addition 27 | 28 | 29 | class ComputeSqrt(task.GeneralizationTask): 30 | """A task with the goal of computing the square root of a binary number. 31 | 32 | The input is a number in binary (big-endian), and the output is the floor of 33 | the square root of this number, also in binary. 34 | Note the output length ie the length of the square root in binary is always 35 | ceil(input_length / 2) (because log(sqrt(x)) = 1/2 log(x)). 36 | 37 | Examples: 38 | 100101 = 37 -> square root is 6.08... -> floor(6.08) = 6 -> 101 39 | 111 = 7 -> square root is 2.64 -> floor(2.64) = 2 -> 10 40 | """ 41 | 42 | def sample_batch(self, rng: chex.PRNGKey, batch_size: int, 43 | length: int) -> task.Batch: 44 | """Returns a batch of binary numbers and their square roots, in binary.""" 45 | del rng 46 | numbers = [random.randint(1, 2**length - 1) for _ in range(batch_size)] 47 | binary_numbers = binary_addition.numbers_to_fixed_length_binary( 48 | numbers, length=length, little_endian=False) 49 | 50 | sqrts = list(map(math.isqrt, numbers)) 51 | binary_sqrts = binary_addition.numbers_to_fixed_length_binary( 52 | sqrts, length=self.output_length(length), little_endian=False) 53 | 54 | binary_numbers = jnp.array(binary_numbers, jnp.int32) 55 | binary_sqrts = jnp.array(binary_sqrts, jnp.int32) 56 | 57 | inputs = jnn.one_hot(binary_numbers, self.input_size) 58 | output = jnn.one_hot(binary_sqrts, self.output_size) 59 | return {'input': inputs, 'output': output} 60 | 61 | @property 62 | def input_size(self) -> int: 63 | """Returns the input size for the models.""" 64 | return 2 65 | 66 | @property 67 | def output_size(self) -> int: 68 | """Returns the output size for the models.""" 69 | return 2 70 | 71 | def output_length(self, input_length: int) -> int: 72 | return math.ceil(input_length / 2) 73 | -------------------------------------------------------------------------------- /tasks/cs/duplicate_string.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Duplicate a string.""" 17 | 18 | import functools 19 | 20 | import jax 21 | import jax.nn as jnn 22 | import jax.numpy as jnp 23 | import jax.random as jrandom 24 | 25 | from neural_networks_chomsky_hierarchy.tasks import task 26 | 27 | 28 | class DuplicateString(task.GeneralizationTask): 29 | """A task with the goal of duplicating a string. 30 | 31 | The input is a string s_1 ... 
s_n composed of symbols from a finite set S. The 32 | output is the same string outputted twice without any separator, ie: 33 | s_1 ... s_n s_1 ... s_n 34 | 35 | Examples: 36 | 101 -> 101 101 37 | 111111 -> 111111 111111 38 | 39 | In the paper, we use only binary strings (ie S = {0, 1}). 40 | Note that the sampling is jittable so this task is fast. 41 | """ 42 | 43 | def __init__(self, vocab_size: int = 2) -> None: 44 | """Initializes the remember_string task. 45 | 46 | Args: 47 | vocab_size: The size of the alphabet. We use 2 in the paper. 48 | """ 49 | self._vocab_size = vocab_size 50 | 51 | @functools.partial(jax.jit, static_argnums=(0, 2, 3)) 52 | def sample_batch(self, rng: jnp.ndarray, batch_size: int, 53 | length: int) -> task.Batch: 54 | """Returns a batch of strings and their copies.""" 55 | strings = jrandom.randint( 56 | rng, shape=(batch_size, length), minval=0, maxval=self._vocab_size) 57 | one_hot_strings = jnn.one_hot(strings, num_classes=self._vocab_size) 58 | output = jnp.concatenate([one_hot_strings, one_hot_strings], axis=1) 59 | return {"input": one_hot_strings, "output": output} 60 | 61 | @property 62 | def input_size(self) -> int: 63 | """Returns the input size for the models.""" 64 | return self._vocab_size 65 | 66 | @property 67 | def output_size(self) -> int: 68 | """Returns the output size for the models.""" 69 | return self._vocab_size 70 | 71 | def output_length(self, input_length: int) -> int: 72 | """Returns the output length for a given input length.""" 73 | return 2 * input_length 74 | -------------------------------------------------------------------------------- /tasks/cs/missing_duplicate_string.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Predict the missing symbol in a duplicated string.""" 17 | 18 | import functools 19 | 20 | import chex 21 | import jax 22 | import jax.nn as jnn 23 | import jax.numpy as jnp 24 | import jax.random as jrandom 25 | 26 | from neural_networks_chomsky_hierarchy.tasks import task 27 | 28 | 29 | class MissingDuplicateString(task.GeneralizationTask): 30 | """A task with the goal of finding the missing symbol in a duplicated string. 31 | 32 | Given a binary string that is presented twice with exactly one element omitted 33 | (denoted by the placeholder token `2`), predict the value of that element. 34 | Thus, an agent trying to solve this task needs to recognize the underlying 35 | duplicated string to be able to produce the correct output. 36 | If the length is odd, the duplicated strings of length `length // 2` are 37 | padded with the empty token `3`. 
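  For `length == 1`, no duplicated substring of length `length // 2` can be
  formed, so `sample_batch` falls back to a constant placeholder batch.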
38 | 39 | Examples 40 | 01100210 -> 1 (the substring is 0110, so the missing value is 1) 41 | 1011213 -> 0 (the subtring is 101, so the missing value is 0) 42 | """ 43 | 44 | @functools.partial(jax.jit, static_argnums=(0, 2, 3)) 45 | def sample_batch( 46 | self, 47 | rng: chex.PRNGKey, 48 | batch_size: int, 49 | length: int, 50 | ) -> task.Batch: 51 | """Returns a batch of strings and the expected class.""" 52 | # For `length == 1`, we cannot meaningfully define substrings of length 53 | # `length // 2`, so we arbitrarily set the inputs and outputs to `1`. 54 | if length == 1: 55 | return { 56 | 'input': 57 | jnn.one_hot( 58 | jnp.ones((batch_size, length)), num_classes=self.input_size), 59 | 'output': 60 | jnn.one_hot( 61 | jnp.ones((batch_size,)), num_classes=self.output_size), 62 | } 63 | 64 | strings_rng, indices_rng = jrandom.split(rng) 65 | strings = jrandom.randint( 66 | strings_rng, shape=(batch_size, length // 2), minval=0, maxval=2) 67 | duplicated_strings = jnp.concatenate((strings, strings), axis=-1) 68 | indices = jrandom.randint( 69 | indices_rng, 70 | shape=(batch_size,), 71 | minval=0, 72 | maxval=duplicated_strings.shape[1]) 73 | output = jax.vmap(lambda x, y: x[y])(duplicated_strings, indices) 74 | masked_strings = jax.vmap(lambda x, y: x.at[y].set(2))(duplicated_strings, 75 | indices) 76 | 77 | # If `length` is odd, we pad the strings with the empty token `3` at the end 78 | # to ensure that the final input length is equal to `length` given the two 79 | # substrings of length `length // 2`. 80 | padding = jnp.full((batch_size, length % 2), fill_value=3) 81 | padded_strings = jnp.concatenate((masked_strings, padding), axis=-1) 82 | 83 | return { 84 | 'input': jnn.one_hot(padded_strings, num_classes=self.input_size), 85 | 'output': jnn.one_hot(output, num_classes=self.output_size) 86 | } 87 | 88 | @property 89 | def input_size(self) -> int: 90 | """Returns the input size for the models.""" 91 | return 4 92 | 93 | @property 94 | def output_size(self) -> int: 95 | """Returns the output size for the models.""" 96 | return 2 97 | -------------------------------------------------------------------------------- /tasks/cs/odds_first.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Sort a string by the parity of the indices (odd indices first).""" 17 | 18 | import functools 19 | 20 | import jax 21 | import jax.nn as jnn 22 | import jax.numpy as jnp 23 | import jax.random as jrandom 24 | 25 | from neural_networks_chomsky_hierarchy.tasks import task 26 | 27 | 28 | class OddsFirst(task.GeneralizationTask): 29 | """A task with the goal of outputting a string's tokens at odd indices first. 30 | 31 | The input is a string s_1 ... s_n composed of symbols from a finite set S. 
The 32 | output is the same string, but where the values at odd indexes have been put 33 | first: s_1 s_3 s_5 ... s_2 s_4 s_6 ... 34 | 35 | Examples: 36 | 00110101 -> 0100 0111 37 | 110 -> 10 1 38 | 39 | In the paper, we use only binary strings (ie S = {0, 1}). 40 | Note that the sampling is jittable so this task is fast. 41 | """ 42 | 43 | def __init__(self, vocab_size: int = 2) -> None: 44 | """Initializes the odds_first task. 45 | 46 | Args: 47 | vocab_size: The size of the alphabet. We use 2 in the paper. 48 | """ 49 | self._vocab_size = vocab_size 50 | 51 | @functools.partial(jax.jit, static_argnums=(0, 2, 3)) 52 | def sample_batch(self, rng: jnp.ndarray, batch_size: int, 53 | length: int) -> task.Batch: 54 | """Returns a batch of strings and their outputs.""" 55 | strings = jrandom.randint( 56 | rng, shape=(batch_size, length), minval=0, maxval=self._vocab_size) 57 | one_hot_strings = jnn.one_hot(strings, num_classes=self._vocab_size) 58 | output = jnp.concatenate( 59 | [one_hot_strings[:, 1::2], one_hot_strings[:, ::2]], axis=1) 60 | return {"input": one_hot_strings, "output": output} 61 | 62 | @property 63 | def input_size(self) -> int: 64 | """Returns the input size for the model.""" 65 | return self._vocab_size 66 | 67 | @property 68 | def output_size(self) -> int: 69 | """Returns the output size for the model.""" 70 | return self._vocab_size 71 | 72 | def output_length(self, input_length: int) -> int: 73 | """Returns the output length for the model.""" 74 | return input_length 75 | -------------------------------------------------------------------------------- /tasks/dcf/modular_arithmetic_brackets.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Modular arithmetic with brackets.""" 17 | 18 | import collections 19 | from typing import Sequence 20 | 21 | import jax.nn as jnn 22 | import jax.numpy as jnp 23 | import numpy as np 24 | import tqdm 25 | import tree 26 | 27 | from neural_networks_chomsky_hierarchy.tasks import task 28 | 29 | 30 | def generate_one_expression_and_result( 31 | modulus: int, length: int, mult: bool = False 32 | ) -> tuple[str, int]: 33 | """Returns a modular arithmetic expression with brackets, and its result. 34 | 35 | The values in the expression are in {0, 1, ..., modulus-1}. The allowed 36 | operations are either {+, -} (mult=False) or {+, -, *} (mult=True). 37 | 38 | Args: 39 | modulus: The modulus to use for the expression. 40 | length: The length of the expression. 41 | mult: Whether to include the multiplication operator in the expressions. 42 | 43 | Raises: 44 | ValueError if length < 1. 45 | """ 46 | 47 | # Generates a terminal (digit). 
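  # A terminal is a single digit in {0, ..., modulus - 1}; the recursion below
  # combines such terminals into bracketed expressions, e.g. '((2+3)-1)',
  # which evaluates to 4 modulo 5.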
48 | def gen_terminal(): 49 | terminal = np.random.randint(low=0, high=modulus) 50 | return str(terminal), terminal 51 | 52 | # If length is less than 1, issue an error. 53 | if length < 1: 54 | raise ValueError( 55 | f'Can\'t generate expressions of length < 1. Got {length}.') 56 | 57 | # If length is less than 5, generate a digit d, -d, (d), or (-d). 58 | if length == 1: 59 | return gen_terminal() 60 | elif length == 2: 61 | term_str, term_val = gen_terminal() 62 | return f'-{term_str}', -term_val % modulus 63 | elif length == 3: 64 | term_str, term_val = gen_terminal() 65 | return f'({term_str})', term_val % modulus 66 | elif length == 4: 67 | term_str, term_val = gen_terminal() 68 | return f'(-{term_str})', -term_val % modulus 69 | 70 | # First split the length into a left and right part. 71 | left_length = np.random.randint(low=1, high=length - 3) 72 | right_length = length - (left_length + 3) 73 | left_str, left_val = generate_one_expression_and_result( 74 | modulus, left_length, mult=mult) 75 | right_str, right_val = generate_one_expression_and_result( 76 | modulus, right_length, mult=mult) 77 | 78 | # Now sample an operator and return. 79 | maxop = 3 if mult else 2 80 | op = np.random.randint(low=0, high=maxop) 81 | if op == 0: 82 | return '(' + left_str + '+' + right_str + ')', (left_val + 83 | right_val) % modulus 84 | elif op == 1: 85 | return '(' + left_str + '-' + right_str + ')', (left_val - 86 | right_val) % modulus 87 | else: 88 | return '(' + left_str + '*' + right_str + ')', (left_val * 89 | right_val) % modulus 90 | 91 | 92 | def generate_raw_dataset( 93 | n: int, 94 | lengths: Sequence[int], 95 | modulus: int, 96 | mult: bool = False, 97 | with_tqdm: bool = False, 98 | ) -> dict[int, dict[str, np.ndarray]]: 99 | """Generates a dataset of maths expressions with brackets, and their results. 100 | 101 | Args: 102 | n: The number of datapoints in the dataset. 103 | lengths: The lengths of the sequences to generate. n is evenly distributed 104 | over these lengths. 105 | modulus: Modulus used to compute the expressions. 106 | mult: Whether to include the multiplication operator in the expressions. 107 | with_tqdm: As the computation might be long, whether to add a tqdm progress 108 | bar or not. 109 | 110 | Returns: 111 | A dict which keys are the passed lengths, and the values are dicts with keys 112 | 'equations' and 'solutions', and values are the data numpy arrays. 
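    (In this module the inner dicts are keyed by 'expressions' and 'results';
    the 'equations'/'solutions' keys are used by the analogous dataset in
    solve_equation.py.)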
113 | """ 114 | alphabet_to_int = { 115 | '+': modulus, 116 | '-': modulus + 1, 117 | '*': modulus + 2, 118 | '(': modulus + 3, 119 | ')': modulus + 4, 120 | 'x': modulus + 5, 121 | '=': modulus + 6, 122 | } 123 | for x in range(modulus): 124 | alphabet_to_int[str(x)] = x 125 | 126 | make_default_dict = lambda: {'expressions': [], 'results': []} 127 | sequences = collections.defaultdict(make_default_dict) 128 | range_lengths = tqdm.tqdm(lengths) if with_tqdm else lengths 129 | for length in range_lengths: 130 | for _ in range(n // len(lengths)): 131 | seq, label = generate_one_expression_and_result(modulus, length, mult) 132 | seq = [alphabet_to_int[x] for x in seq] 133 | sequences[length]['expressions'].append(seq) 134 | sequences[length]['results'].append(label) 135 | sequences = tree.traverse( 136 | lambda l: np.array(l, dtype=np.int32) if isinstance(l, list) else l, 137 | sequences, 138 | top_down=False, 139 | ) 140 | return dict(sequences) 141 | 142 | 143 | class ModularArithmeticBrackets(task.GeneralizationTask): 144 | """A task with the goal of reducing an arithmetic expression with brackets.""" 145 | 146 | def __init__(self, modulus: int = 5, mult: bool = False) -> None: 147 | """Initializes the modular arithmetic task. 148 | 149 | Args: 150 | modulus: The modulus used for the computation. We use 5 in the paper. 151 | mult: Whether to add multiplication or use only '+' and '-'. 152 | """ 153 | self._modulus = modulus 154 | self._mult = mult 155 | 156 | def sample_batch(self, rng: jnp.ndarray, batch_size: int, 157 | length: int) -> task.Batch: 158 | """Returns a batch of inputs/outputs.""" 159 | del rng 160 | batch = generate_raw_dataset( 161 | batch_size, lengths=[length], modulus=self._modulus, 162 | mult=self._mult)[length] 163 | inputs = jnn.one_hot(batch['expressions'], self.input_size) 164 | output = jnn.one_hot(batch['results'], self.output_size) 165 | return {'input': inputs, 'output': output} 166 | 167 | @property 168 | def input_size(self) -> int: 169 | """Returns the input size for the models.""" 170 | return self._modulus + 6 171 | 172 | @property 173 | def output_size(self) -> int: 174 | """Returns the output size for the models.""" 175 | return self._modulus 176 | -------------------------------------------------------------------------------- /tasks/dcf/reverse_string.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Compute the reverse of an input string.""" 17 | 18 | import functools 19 | 20 | import jax 21 | import jax.numpy as jnp 22 | 23 | from neural_networks_chomsky_hierarchy.tasks import task 24 | from neural_networks_chomsky_hierarchy.tasks.cs import duplicate_string 25 | 26 | 27 | class ReverseString(duplicate_string.DuplicateString): 28 | """A task with the goal of reversing a given string. 
29 | 30 | The input is a string s_1 ... s_n composed of symbols from a finite set S. The 31 | output is the string, reversed, ie s_n ... s_1. 32 | 33 | Examples: 34 | 011010 -> 010110 35 | 123021 -> 120321 36 | 37 | In the paper, we use only binary strings (ie S = {0, 1}). 38 | Note that the sampling is jittable so this task is fast. 39 | """ 40 | 41 | @functools.partial(jax.jit, static_argnums=(0, 2, 3)) 42 | def sample_batch(self, rng: jnp.ndarray, batch_size: int, 43 | length: int) -> task.Batch: 44 | """Returns a batch of strings and their reversed version.""" 45 | batch = super().sample_batch(rng, batch_size, length) 46 | batch['output'] = jnp.flip(batch['input'], axis=1) 47 | return batch 48 | 49 | def output_length(self, input_length: int) -> int: 50 | """Returns the output length for a given input length.""" 51 | return input_length 52 | -------------------------------------------------------------------------------- /tasks/dcf/solve_equation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Solve for the value of an unknown variable in an equation.""" 17 | 18 | import collections 19 | from typing import Sequence 20 | 21 | import jax.nn as jnn 22 | import jax.numpy as jnp 23 | import numpy as np 24 | import tqdm 25 | import tree 26 | 27 | from neural_networks_chomsky_hierarchy.tasks import task 28 | from neural_networks_chomsky_hierarchy.tasks.dcf import modular_arithmetic_brackets as mab 29 | 30 | 31 | def generate_equation_and_solution( 32 | modulus: int, 33 | length: int, 34 | ) -> tuple[str, int]: 35 | """Returns a modular arithmetic equation with brackets, and its solution. 36 | 37 | The values are in {0, 1, ..., modulus-1}, and the unknown 38 | value is x. The allowed operations are either {+, -} (mult=False) or 39 | {+, -, *} (mult=True). 40 | Warning: if mult=True, x might have multiple valid solutions. 41 | 42 | Args: 43 | modulus: The modulus to use for the expression. 44 | length: The length of the expression. 45 | 46 | Raises: 47 | ValueError if the length is < 3. 48 | """ 49 | 50 | # Generate the expression. 51 | expr, val = mab.generate_one_expression_and_result( 52 | modulus, 53 | length - 2, 54 | # We use mult=False by default here, otherwise equations could have 55 | # multiple solutions if the variable 'x' or some expression containing the 56 | # variable is multiplied by 0. 57 | mult=False, 58 | ) 59 | 60 | # Replace random digit with 'x'. 
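  # The index is sampled uniformly and advanced cyclically (modulo the
  # expression length) until it points at a digit; that digit becomes the
  # solution. E.g., replacing the '3' in '(2+3)' (value 0 mod 5) yields the
  # equation '(2+x)=0' with solution 3.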
61 | idx = np.random.randint(low=0, high=len(expr)) 62 | digits = [str(n) for n in range(modulus)] 63 | while expr[idx] not in digits: 64 | idx = (idx + 1) % (length - 2) 65 | solution = int(expr[idx]) 66 | equation = f'{expr[:idx]}x{expr[idx + 1:]}={val}' 67 | return equation, solution 68 | 69 | 70 | def generate_raw_dataset( 71 | n: int, 72 | lengths: Sequence[int], 73 | modulus: int, 74 | with_tqdm: bool = False, 75 | ) -> dict[int, dict[str, np.ndarray]]: 76 | """Generates a dataset of equations and their solutions. 77 | 78 | Args: 79 | n: The number of datapoints in the dataset. 80 | lengths: The lengths of the sequences to generate. n is evenly distributed 81 | over these lengths. 82 | modulus: Modulus used to compute the expressions. 83 | with_tqdm: As the computation might be long, whether to add a tqdm progress 84 | bar or not. 85 | 86 | Returns: 87 | A dict which keys are the passed lengths, and the values are dicts with keys 88 | 'equations' and 'solutions', and values are the data numpy arrays. 89 | """ 90 | alphabet_to_int = { 91 | '+': modulus, 92 | '-': modulus + 1, 93 | '(': modulus + 2, 94 | ')': modulus + 3, 95 | 'x': modulus + 4, 96 | '=': modulus + 5, 97 | } 98 | for x in range(modulus): 99 | alphabet_to_int[str(x)] = x 100 | 101 | sequences = collections.defaultdict(lambda: { # pylint: disable=g-long-lambda 102 | 'equations': [], 103 | 'solutions': [] 104 | }) 105 | range_lengths = tqdm.tqdm(lengths) if with_tqdm else lengths 106 | for length in range_lengths: 107 | for _ in range(n // len(lengths)): 108 | seq, label = generate_equation_and_solution(modulus, length) 109 | seq = [alphabet_to_int[x] for x in seq] 110 | sequences[length]['equations'].append(seq) 111 | sequences[length]['solutions'].append(label) 112 | # Convert the list of numbers we have to arrays at the leaves. 113 | sequences = tree.traverse( 114 | lambda l: np.array(l, dtype=np.int32) if isinstance(l, list) else l, 115 | sequences, 116 | top_down=False, 117 | ) 118 | return dict(sequences) 119 | 120 | 121 | class SolveEquation(task.GeneralizationTask): 122 | """A task with the goal of solving an modular equation for an unknown. 123 | 124 | Note that the equations do not contain any multiplication as it could lead to 125 | multiple solutions (multiplication by zero). 126 | """ 127 | 128 | def __init__(self, modulus: int = 5) -> None: 129 | """Initializes the modular arithmetic task. 130 | 131 | Args: 132 | modulus: The modulus used for the computation. We use 5 in the paper. 
133 | """ 134 | self._modulus = modulus 135 | 136 | def sample_batch( 137 | self, 138 | rng: jnp.ndarray, 139 | batch_size: int, 140 | length: int, 141 | ) -> task.Batch: 142 | """Returns a batch of inputs/outputs.""" 143 | if length < 3: 144 | return { 145 | 'input': 146 | jnn.one_hot( 147 | jnp.zeros((batch_size, length)), num_classes=self.input_size), 148 | 'output': 149 | jnn.one_hot( 150 | jnp.zeros((batch_size,)), num_classes=self.output_size) 151 | } 152 | batch = generate_raw_dataset( 153 | batch_size, 154 | lengths=[length], 155 | modulus=self._modulus, 156 | )[length] 157 | inputs = jnn.one_hot(batch['equations'], self.input_size) 158 | output = jnn.one_hot(batch['solutions'], self.output_size) 159 | return {'input': inputs, 'output': output} 160 | 161 | @property 162 | def input_size(self) -> int: 163 | """Returns the input size for the models.""" 164 | return self._modulus + 6 165 | 166 | @property 167 | def output_size(self) -> int: 168 | """Returns the output size for the models.""" 169 | return self._modulus 170 | -------------------------------------------------------------------------------- /tasks/dcf/stack_manipulation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Manipulate an input stack, using the input actions.""" 17 | 18 | import chex 19 | import jax.nn as jnn 20 | import jax.numpy as jnp 21 | import numpy as np 22 | 23 | from neural_networks_chomsky_hierarchy.tasks import task 24 | 25 | 26 | class StackManipulation(task.GeneralizationTask): 27 | """A task with the goal of following instructions and returning the end stack. 28 | 29 | The input is composed of a stack of 0s and 1s followed by a sequence of 30 | instructions POP/PUSH 0/PUSH 1 (represented by 2s/3s/4s). The input stack is 31 | given bottom-to-top, and the agent needs to execute the instructions given 32 | (left-to-rigth) and output the final stack top-to-bottom (i.e., as if it were 33 | popping the final stack). If a POP action is to be called on an empty stack, 34 | the action is ignored. The output is padded with 0s to match the input length 35 | + 1 (to accommodate for the termination token), and the end of the final stack 36 | is denoted with the termination symbol 2 (i.e., the output has values in {0, 37 | 1, 2}). 
38 | 39 | Examples: 40 | 0 1 1 0 PUSH 1 POP POP 41 | initial 0 1 1 0 (the stack is received bottom-to-top) 42 | PUSH 1 0 1 1 0 1 43 | POP 0 1 1 0 44 | POP 0 1 1 45 | -> 1 1 0 2 0 0 0 0 (the stack is returned top-to-bottom) 46 | 47 | 1 1 0 POP POP POP 48 | initial 1 1 0 49 | POP 1 1 50 | POP 1 51 | POP 52 | -> 2 0 0 0 0 0 0 0 (the stack is empty and padded with zeros) 53 | """ 54 | 55 | def _sample_expression_and_result( 56 | self, length: int 57 | ) -> tuple[np.ndarray, list[int]]: 58 | """Returns an expression with stack instructions, and the result stack.""" 59 | if length == 1: 60 | value = np.random.randint(low=0, high=2, size=(1,)) 61 | return value, list(value) 62 | 63 | # Initialize the stack content and the actions (POP/PUSH). 64 | stack_length = np.random.randint(low=1, high=length) 65 | stack = np.random.randint(low=0, high=2, size=(stack_length,)) 66 | actions = np.random.randint(low=2, high=5, size=(length - stack_length,)) 67 | 68 | # Apply the actions on the stack. 69 | current_stack = list(stack) 70 | 71 | for action in actions: 72 | if action == 2: # POP 73 | if current_stack: 74 | current_stack.pop() 75 | elif action in [3, 4]: # PUSH a 0 (case 3) or a 1 (case 4) 76 | current_stack.append(action - 3) 77 | 78 | return np.concatenate([stack, actions]), current_stack[::-1] 79 | 80 | def sample_batch(self, rng: chex.PRNGKey, batch_size: int, 81 | length: int) -> task.Batch: 82 | """Returns a batch of strings and the expected class.""" 83 | expressions, results = [], [] 84 | for _ in range(batch_size): 85 | expression, result = self._sample_expression_and_result(length) 86 | expressions.append(expression) 87 | # Append the termination token to the result. 88 | result += [self.output_size - 1] 89 | # Pad the result with zeros to match the input length (accounting for the 90 | # termination token). 91 | result += [0] * (length + 1 - len(result)) 92 | results.append(result) 93 | expressions = jnp.array(expressions) 94 | results = jnp.array(results) 95 | 96 | inputs = jnn.one_hot(expressions, self.input_size) 97 | output = jnn.one_hot(results, self.output_size) 98 | return {'input': inputs, 'output': output} 99 | 100 | @property 101 | def input_size(self) -> int: 102 | """Returns the input size for the models. 103 | 104 | The value is 5 because we have two possible tokens in the stack (0, 1), plus 105 | three tokens to describe the PUSH 0, PUSH 1, and POP actions. 106 | """ 107 | return 5 108 | 109 | @property 110 | def output_size(self) -> int: 111 | """Returns the output size for the models.""" 112 | return 3 113 | 114 | def output_length(self, input_length: int) -> int: 115 | """Returns the output length of the task.""" 116 | return input_length + 1 117 | 118 | def accuracy_mask(self, target: chex.Array) -> chex.Array: 119 | """Computes mask that ignores everything after the termination tokens. 120 | 121 | Args: 122 | target: Target tokens of shape `(batch_size, output_length, output_size)`. 123 | 124 | Returns: 125 | The mask of shape `(batch_size, output_length)`. 
126 | """ 127 | batch_size, length, _ = target.shape 128 | termination_indices = jnp.argmax( 129 | jnp.argmax(target, axis=-1), 130 | axis=-1, 131 | keepdims=True, 132 | ) 133 | indices = jnp.tile(jnp.arange(length), (batch_size, 1)) 134 | return indices <= termination_indices 135 | -------------------------------------------------------------------------------- /tasks/regular/cycle_navigation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Compute the final state after randomly walking on a circle.""" 17 | 18 | import functools 19 | 20 | import chex 21 | import jax 22 | import jax.nn as jnn 23 | import jax.numpy as jnp 24 | import jax.random as jrandom 25 | 26 | from neural_networks_chomsky_hierarchy.tasks import task 27 | 28 | 29 | class CycleNavigation(task.GeneralizationTask): 30 | """A task with the goal of computing the final state on a circle. 31 | 32 | The input is a string of actions, composed of 0s, 1s or -1s. The actions give 33 | directions to take on a finite length circle (0 is for stay, 1 is for right, 34 | -1 is for left). The goal is to give the final position on the circle after 35 | all the actions have been taken. The agent starts at position 0. 36 | 37 | By default, the length the circle is 5. 38 | 39 | Examples: 40 | 1 -1 0 -1 -1 -> -2 = class 3 41 | 1 1 1 -1 -> 2 = class 2 42 | 43 | Note that the sampling is jittable so it is fast. 44 | """ 45 | 46 | @property 47 | def _cycle_length(self) -> int: 48 | """Returns the cycle length, number of possible states.""" 49 | return 5 50 | 51 | @functools.partial(jax.jit, static_argnums=(0, 2, 3)) 52 | def sample_batch(self, rng: chex.PRNGKey, batch_size: int, 53 | length: int) -> task.Batch: 54 | """Returns a batch of strings and the expected class.""" 55 | actions = jrandom.randint( 56 | rng, shape=(batch_size, length), minval=0, maxval=3) 57 | final_states = jnp.sum(actions - 1, axis=1) % self._cycle_length 58 | final_states = jnn.one_hot(final_states, num_classes=self.output_size) 59 | one_hot_strings = jnn.one_hot(actions, num_classes=self.input_size) 60 | return {"input": one_hot_strings, "output": final_states} 61 | 62 | @property 63 | def input_size(self) -> int: 64 | """Returns the input size for the models.""" 65 | return 3 66 | 67 | @property 68 | def output_size(self) -> int: 69 | """Returns the output size for the models.""" 70 | return self._cycle_length 71 | -------------------------------------------------------------------------------- /tasks/regular/even_pairs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Compute whether the number of 01's and 10's is even.""" 17 | 18 | import functools 19 | 20 | import jax 21 | from jax import nn as jnn 22 | from jax import numpy as jnp 23 | from jax import random as jrandom 24 | 25 | from neural_networks_chomsky_hierarchy.tasks import task 26 | 27 | 28 | class EvenPairs(task.GeneralizationTask): 29 | """A task with the goal of checking whether the number of 01s and 10s is even. 30 | 31 | The input is a binary string, composed of 0s and 1s. If the result is even, 32 | the class is 0, otherwise it's 1. 33 | 34 | Examples: 35 | 001110 -> 1 '10' and 1 '01' -> class 0 36 | 0101001 -> 2 '10' and 3 '01' -> class 1 37 | 38 | Note that the sampling is jittable so this task is fast. 39 | """ 40 | 41 | @functools.partial(jax.jit, static_argnums=(0, 2, 3)) 42 | def sample_batch(self, rng: jnp.ndarray, batch_size: int, 43 | length: int) -> task.Batch: 44 | """Returns a batch of strings and the expected class.""" 45 | strings = jrandom.randint( 46 | rng, 47 | shape=(batch_size, length), 48 | minval=0, 49 | maxval=2, 50 | ) 51 | one_hot_strings = jnn.one_hot(strings, num_classes=2) 52 | unequal_pairs = jnp.logical_xor(strings[:, :-1], strings[:, 1:]) 53 | odd_unequal_pairs = jnp.sum(unequal_pairs, axis=-1) % 2 54 | return { 55 | 'input': one_hot_strings, 56 | 'output': jnn.one_hot(odd_unequal_pairs, num_classes=self.output_size), 57 | } 58 | 59 | @property 60 | def input_size(self) -> int: 61 | """Returns the input size for the models.""" 62 | return 2 63 | 64 | @property 65 | def output_size(self) -> int: 66 | """Returns the output size for the models.""" 67 | return 2 68 | -------------------------------------------------------------------------------- /tasks/regular/modular_arithmetic.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Modular arithmetic without brackets. 17 | 18 | Note that samples are generated with a jittable function, so this task is much 19 | faster than its 'brackets' counterpart, which requires simulating the full 20 | CF grammar and is not jittable.
21 | """ 22 | 23 | import functools 24 | from typing import Optional, Sequence 25 | 26 | import jax 27 | import jax.nn as jnn 28 | import jax.numpy as jnp 29 | import jax.random as jrandom 30 | 31 | from neural_networks_chomsky_hierarchy.tasks import task 32 | 33 | # Public as this may be used to encode/decode strings of numbers/symbols. 34 | OP_BY_CHARACTER = {'+': 0, '-': 1, '*': 2, '_': 3} 35 | 36 | 37 | def _replace_subtractions(expression: jnp.ndarray, modulus: int) -> jnp.ndarray: 38 | """Replaces subtractions in an expression by additions with the inverse. 39 | 40 | e.g. the expression [1, -, 3] results in [1, +, -3]. 41 | 42 | Args: 43 | expression: Encoded expression (a 1D array of integers) in which to replace 44 | subtractions. 45 | modulus: The modulus to use for the modular arithmetic. 46 | 47 | Returns: 48 | The expression with all subtractions replaced by additions with the inverse. 49 | """ 50 | if expression.size < 2: 51 | return expression 52 | 53 | mask = (expression == modulus + OP_BY_CHARACTER['-']) 54 | subtract_replaced = jnp.where(mask, modulus + OP_BY_CHARACTER['+'], 55 | expression) 56 | return subtract_replaced.at[2:].multiply(1 - 2 * mask[1:-1]) 57 | 58 | 59 | def _perform_multiplications(expression: jnp.ndarray, 60 | modulus: int) -> jnp.ndarray: 61 | """Performs all multiplications in an expression containing only + and *. 62 | 63 | This is done at fixed length and the result is zero-padded to achieve this. 64 | Since the result of performing multiplications is an expression containing 65 | only + operators, the operators are dropped from the output. For example, the 66 | expression [1, +, 3, *, 4] results in [1, 12, 0]. 67 | 68 | Args: 69 | expression: Encoded expression in which to perform multiplications. 70 | modulus: The modulus to use for the modular arithmetic. 71 | 72 | Returns: 73 | An array with the results of the multiplications (potentially zero-padded). 74 | """ 75 | term_ids = jnp.cumsum(expression == modulus + OP_BY_CHARACTER['+'])[::2] 76 | # Segment_prod can only be jit-compiled with a fixed number of segments. 77 | # Therefore, we have to set to the maximum number of terms possible and 78 | # mask out superfluous segment results with zeros afterwards. 79 | maximum_term_number = expression.shape[0] // 2 + 1 80 | products = jax.ops.segment_prod( 81 | expression[::2], 82 | term_ids, 83 | num_segments=maximum_term_number, 84 | indices_are_sorted=True) 85 | valid_segment_mask = jnp.arange(maximum_term_number) <= term_ids[-1] 86 | return products * valid_segment_mask 87 | 88 | 89 | def _replace_blanks(expression: jnp.ndarray, modulus: int) -> jnp.ndarray: 90 | """Replaces blank symbols in expression with either `+` or `0`. 91 | 92 | Depending on whether the blank symbol is at the position of an operator or a 93 | residual, the blank symbol is replaced with a `+` operator or a `0`. 94 | 95 | Args: 96 | expression: Encoded expression in which to replace blank symbols. 97 | modulus: The modulus to use for the modular arithmetic. 98 | 99 | Returns: 100 | An array with blank symbols replaced by either `+` or `0`. 
101 | """ 102 | mask = (expression == OP_BY_CHARACTER['_'] + modulus) 103 | operator_mask = mask.at[::2].set(False) 104 | residual_mask = mask.at[1::2].set(False) 105 | 106 | blanks_replaced = jnp.where(operator_mask, OP_BY_CHARACTER['+'] + modulus, 107 | expression) 108 | blanks_replaced = jnp.where(residual_mask, 0, blanks_replaced) 109 | return blanks_replaced 110 | 111 | 112 | def _evaluate_expression(expression: jnp.ndarray, modulus: int) -> jnp.ndarray: 113 | """Returns the result of evaluating a modular arithmetic expression.""" 114 | expression = _replace_blanks(expression, modulus) 115 | expression = _replace_subtractions(expression, modulus) 116 | additive_terms = _perform_multiplications(expression, modulus) 117 | return jnp.sum(additive_terms) % modulus 118 | 119 | 120 | class ModularArithmetic(task.GeneralizationTask): 121 | """A task with the goal of reducing a simple arithmetic expression. 122 | 123 | The input is a string, composed of numbers (in {0, ..., modulus-1}), and 124 | operators (in {+, -, *}). The output is the reduced value of this expression, 125 | which is also in {0, ..., modulus-1}. 126 | 127 | Examples (modulo 5): 128 | 1 + 2 * 3 = 2 129 | 1 - 1 - 1 = 4 130 | 0 * 1 + 4 * 3 - 2 = 0 131 | 132 | Note that the input strings are always of odd length. 133 | """ 134 | 135 | def __init__( 136 | self, 137 | modulus: int = 5, 138 | operators: Optional[Sequence[str]] = None, 139 | ) -> None: 140 | """Initializes the modular arithmetic task. 141 | 142 | Args: 143 | modulus: The modulus used for the computation. We use 5 in the paper. 144 | operators: Operators to be used in the sequences. By default it's None, 145 | meaning all operators available are used. 146 | """ 147 | self._modulus = modulus 148 | if operators is None: 149 | operators = ('+', '*', '-') 150 | self._operators = (OP_BY_CHARACTER[op] for op in operators) 151 | 152 | @functools.partial(jax.jit, static_argnums=(0, 2, 3)) 153 | def sample_batch( 154 | self, 155 | rng: jnp.ndarray, 156 | batch_size: int, 157 | length: int, 158 | ) -> task.Batch: 159 | """Returns a batch of modular arithmetic expressions and their labels. 160 | 161 | Args: 162 | rng: The jax random number generator. 163 | batch_size: The size of the batch returned. 164 | length: The length of the sequence. As this length must be odd for the 165 | modular arithmetic dataset, if it's not, we force it to be by 166 | subtracting one to the length passed. 167 | """ 168 | # Subtracting one to the length if it's not odd already. 
169 | if length % 2 != 1: 170 | length -= 1 171 | 172 | batch = jnp.empty((batch_size, length), dtype=int) 173 | rng1, rng2 = jax.random.split(rng) 174 | remainders = jax.random.randint(rng1, 175 | (batch_size, length // 2 + 1), 0, 176 | self._modulus) 177 | ops = self._modulus + jnp.array(list(self._operators)) 178 | 179 | operations = jrandom.choice(rng2, ops, (batch_size, length // 2)) 180 | batch = batch.at[:, ::2].set(remainders) 181 | expressions = batch.at[:, 1::2].set(operations) 182 | 183 | evaluate = functools.partial(_evaluate_expression, modulus=self._modulus) 184 | labels = jax.vmap(evaluate)(expressions) 185 | labels = jnn.one_hot(labels, self._modulus) 186 | one_hot_expressions = jnn.one_hot(expressions, 187 | self._modulus + len(OP_BY_CHARACTER)) 188 | return {'input': one_hot_expressions, 'output': labels} 189 | 190 | @property 191 | def input_size(self) -> int: 192 | """Returns the input size for the models.""" 193 | return self._modulus + len(OP_BY_CHARACTER) 194 | 195 | @property 196 | def output_size(self) -> int: 197 | """Returns the output size for the models.""" 198 | return self._modulus 199 | -------------------------------------------------------------------------------- /tasks/regular/parity_check.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Compute whether the number of 1s in a string is even.""" 17 | 18 | import functools 19 | 20 | import jax 21 | import jax.nn as jnn 22 | import jax.numpy as jnp 23 | import jax.random as jrandom 24 | 25 | from neural_networks_chomsky_hierarchy.tasks import task 26 | 27 | 28 | class ParityCheck(task.GeneralizationTask): 29 | """A task with the goal of counting the number of '1' in a string, modulo 2. 30 | 31 | The input is a string, composed of 0s and 1s. If the result is even, the class 32 | is 0, otherwise it's 1. 33 | 34 | Examples: 35 | 1010100 -> 3 1s (odd) -> class 1 36 | 01111 -> 4 1s (even) -> class 0 37 | 38 | Note that the sampling is jittable so this task is fast. 
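A minimal usage sketch (batch_size and length are static arguments of the
jitted sample_batch, so they are passed positionally):
  task = ParityCheck()
  batch = task.sample_batch(jrandom.PRNGKey(0), 32, 10)
  # batch['input'] has shape (32, 10, 2); batch['output'] has shape (32, 2).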
39 | """ 40 | 41 | @functools.partial(jax.jit, static_argnums=(0, 2, 3)) 42 | def sample_batch(self, rng: jnp.ndarray, batch_size: int, 43 | length: int) -> task.Batch: 44 | """Returns a batch of strings and the expected class.""" 45 | strings = jrandom.randint( 46 | rng, shape=(batch_size, length), minval=0, maxval=2) 47 | n_b = jnp.sum(strings, axis=1) % 2 48 | n_b = jnn.one_hot(n_b, num_classes=2) 49 | one_hot_strings = jnn.one_hot(strings, num_classes=2) 50 | return {"input": one_hot_strings, "output": n_b} 51 | 52 | @property 53 | def input_size(self) -> int: 54 | """Returns the input size for the models.""" 55 | return 2 56 | 57 | @property 58 | def output_size(self) -> int: 59 | """Returns the output size for the models.""" 60 | return 2 61 | -------------------------------------------------------------------------------- /tasks/task.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Base class for length generalization tasks.""" 17 | 18 | import abc 19 | from typing import TypedDict 20 | 21 | import chex 22 | import jax.nn as jnn 23 | import jax.numpy as jnp 24 | 25 | Batch = TypedDict('Batch', {'input': chex.Array, 'output': chex.Array}) 26 | 27 | 28 | class GeneralizationTask(abc.ABC): 29 | """A task for the generalization project. 30 | 31 | Exposes a sample_batch method, and some details about input/output sizes, 32 | losses and accuracies. 33 | """ 34 | 35 | @abc.abstractmethod 36 | def sample_batch(self, rng: chex.PRNGKey, batch_size: int, 37 | length: int) -> Batch: 38 | """Returns a batch of inputs/outputs.""" 39 | 40 | def pointwise_loss_fn(self, output: chex.Array, 41 | target: chex.Array) -> chex.Array: 42 | """Returns the pointwise loss between an output and a target.""" 43 | return -target * jnn.log_softmax(output) 44 | 45 | def accuracy_fn(self, output: chex.Array, target: chex.Array) -> chex.Array: 46 | """Returns the accuracy between an output and a target.""" 47 | return (jnp.argmax(output, 48 | axis=-1) == jnp.argmax(target, 49 | axis=-1)).astype(jnp.float32) 50 | 51 | def accuracy_mask(self, target: chex.Array) -> chex.Array: 52 | """Returns a mask to compute the accuracies, to remove the superfluous ones.""" 53 | # Target is a shape of shape (B, T, C) where C is the number of classes. 54 | # We want a mask per input (B, T), so we take this shape. 
55 | return jnp.ones(target.shape[:-1]) 56 | 57 | @property 58 | @abc.abstractmethod 59 | def input_size(self) -> int: 60 | """Returns the size of the input of the models trained on this task.""" 61 | 62 | @property 63 | @abc.abstractmethod 64 | def output_size(self) -> int: 65 | """Returns the size of the output of the models trained on this task.""" 66 | 67 | def output_length(self, input_length: int) -> int: 68 | """Returns the length of the output, given an input length.""" 69 | del input_length 70 | return 1 71 | --------------------------------------------------------------------------------
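All of the tasks above implement the GeneralizationTask interface from tasks/task.py, so they are driven the same way during training and evaluation. The snippet below is a minimal, illustrative sketch of that interface, not code from the repository: it assumes the package is installed under the neural_networks_chomsky_hierarchy name used by the imports above, and it stands in a constant-zero "model" purely to exercise the loss and accuracy helpers.

import jax
import jax.numpy as jnp

from neural_networks_chomsky_hierarchy.tasks.regular import even_pairs

task = even_pairs.EvenPairs()
rng = jax.random.PRNGKey(0)

# Sample a batch: 'input' has shape (4, 10, 2) and 'output' has shape (4, 2).
# batch_size and length are static arguments of the jitted sample_batch,
# so they are passed positionally.
batch = task.sample_batch(rng, 4, 10)

# A trained model would map batch['input'] to logits of shape (4, output_size);
# zeros are only a placeholder here.
logits = jnp.zeros((4, task.output_size))

# Cross-entropy loss, accuracy, and the (all-ones, for this task) accuracy mask.
loss = jnp.mean(jnp.sum(task.pointwise_loss_fn(logits, batch['output']), axis=-1))
mask = task.accuracy_mask(batch['output'])
accuracy = jnp.sum(mask * task.accuracy_fn(logits, batch['output'])) / jnp.sum(mask)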