├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── experiments ├── constants.py ├── curriculum.py ├── example.py ├── range_evaluation.py ├── training.py └── utils.py ├── models ├── positional_encodings.py ├── transformer.py └── transformer_utils.py ├── overview.png ├── requirements.txt └── tasks ├── cs ├── binary_addition.py ├── binary_multiplication.py ├── bucket_sort.py ├── compute_sqrt.py ├── duplicate_string.py ├── missing_duplicate_string.py └── odds_first.py ├── dcf ├── modular_arithmetic_brackets.py ├── reverse_string.py ├── solve_equation.py └── stack_manipulation.py ├── regular ├── cycle_navigation.py ├── even_pairs.py ├── modular_arithmetic.py └── parity_check.py └── task.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Distribution / packaging 7 | .Python 8 | build/ 9 | develop-eggs/ 10 | dist/ 11 | downloads/ 12 | eggs/ 13 | .eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | wheels/ 20 | share/python-wheels/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | MANIFEST 25 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | ## Contributor License Agreement 4 | 5 | Contributions to this project must be accompanied by a Contributor License 6 | Agreement. You (or your employer) retain the copyright to your contribution, 7 | this simply gives us permission to use and redistribute your contributions as 8 | part of the project. Head over to to see 9 | your current agreements on file or to sign a new one. 10 | 11 | You generally only need to submit a CLA once, so if you've already submitted one 12 | (even if it was for a different project), you probably don't need to do it 13 | again. 14 | 15 | ## Code reviews 16 | 17 | All submissions, including submissions by project members, require review. We 18 | use GitHub pull requests for this purpose. Consult 19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 20 | information on using pull requests. 21 | 22 | ## Community Guidelines 23 | 24 | This project follows [Google's Open Source Community 25 | Guidelines](https://opensource.google/conduct/). 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 
23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Randomized Positional Encodings Boost Length Generalization of Transformers 2 | 3 |

4 | ![Overview figure](overview.png) 5 |

6 | 7 | This repository provides an implementation of our ACL 2023 paper [Randomized Positional Encodings Boost Length Generalization of Transformers](https://arxiv.org/abs/2305.16843). 8 | 9 | >Transformers have impressive generalization capabilities on tasks with a fixed context length. 10 | However, they fail to generalize to sequences of arbitrary length, even for seemingly simple tasks such as duplicating a string. 11 | Moreover, simply training on longer sequences is inefficient due to the quadratic computation complexity of the global attention mechanism. 12 | In this work, we demonstrate that this failure mode is linked to positional encodings being out-of-distribution for longer sequences (even for relative encodings) and introduce a novel family of positional encodings that can overcome this problem. 13 | Concretely, our randomized positional encoding scheme simulates the positions of longer sequences and randomly selects an ordered subset to fit the sequence's length. 14 | Our large-scale empirical evaluation of 6000 models across 15 algorithmic reasoning tasks shows that our method allows Transformers to generalize to sequences of unseen length (increasing test accuracy by 12.0% on average). 15 | 16 | It is based on [JAX](https://jax.readthedocs.io) and [Haiku](https://dm-haiku.readthedocs.io) and contains all the code, datasets, and models necessary to reproduce the paper's results. 17 | 18 | 19 | ## Content 20 | 21 | ``` 22 | . 23 | ├── models 24 | │   ├── positional_encodings.py 25 | │   ├── transformer.py - Transformer (Vaswani et al., 2017) 26 | │   └── transformer_utils.py 27 | ├── tasks 28 | │   ├── cs - Context-sensitive tasks 29 | │   ├── dcf - Deterministic context-free tasks 30 | │   ├── regular - Regular tasks 31 | │   └── task.py - Abstract `GeneralizationTask` 32 | ├── experiments 33 | │   ├── constants.py - Training/Evaluation constants 34 | │   ├── curriculum.py - Training curricula (over sequence lengths) 35 | │   ├── example.py - Example training script 36 | │   ├── range_evaluation.py - Evaluation loop (test sequence lengths) 37 | │   ├── training.py - Training loop 38 | │   └── utils.py - Utility functions 39 | ├── README.md 40 | └── requirements.txt - Dependencies 41 | ``` 42 | 43 | 44 | ## Installation 45 | 46 | Clone the source code into a local directory: 47 | ```bash 48 | git clone https://github.com/google-deepmind/randomized_positional_encodings.git 49 | cd randomized_positional_encodings 50 | ``` 51 | 52 | `pip install -r requirements.txt` will install all required dependencies. 53 | This is best done inside a [conda environment](https://www.anaconda.com/). 54 | To that end, install [Anaconda](https://www.anaconda.com/download#downloads). 55 | Then, create and activate the conda environment: 56 | ```bash 57 | conda create --name randomized_positional_encodings 58 | conda activate randomized_positional_encodings 59 | ``` 60 | 61 | Install `pip` and use it to install all the dependencies: 62 | ```bash 63 | conda install pip 64 | pip install -r requirements.txt 65 | ``` 66 | 67 | If you have a GPU available (highly recommended for fast training), you can install JAX with CUDA support: 68 | ```bash 69 | pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 70 | ``` 71 | Note that the JAX version must correspond to the existing CUDA installation you wish to use (CUDA 12 in the example above). 72 | Please see the [JAX documentation](https://github.com/google/jax#installation) for more details.
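To quickly check that the installation picks up your accelerator, you can print the devices JAX sees (a GPU/CUDA device should appear in the list if the CUDA build is active):
```bash
python -c "import jax; print(jax.devices())"
```
If only a CPU device is listed, training will still work, but it will be considerably slower.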
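For a quick intuition of the method itself: for an input of length `sequence_length`, the randomized scheme draws a sorted subset of positions from a much larger range `[0, noise_max_length)` and uses those positions to index a standard positional-encoding table (the actual JAX/Haiku implementation is `noisy_fixed_positional_encodings` in `models/positional_encodings.py`). Below is a minimal NumPy sketch of just the sampling step, with illustrative names only:
```python
import numpy as np

def sample_random_positions(sequence_length: int, noise_max_length: int,
                            rng: np.random.Generator) -> np.ndarray:
  """Draws a sorted subset of `sequence_length` positions from [0, noise_max_length)."""
  positions = rng.choice(noise_max_length, size=sequence_length, replace=False)
  return np.sort(positions)

rng = np.random.default_rng(0)
print(sample_random_positions(sequence_length=5, noise_max_length=2048, rng=rng))
# Prints 5 sorted positions from [0, 2048), e.g. [ 113  511  780 1440 1739]; these
# rows of a fixed (e.g. sin/cos) encoding table are used instead of positions 0..4.
```
Because the subsampled positions are also valid for much longer sequences, the encodings seen at test time are no longer out-of-distribution, which is what enables the length generalization reported in the paper.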
73 | 74 | 75 | ## Usage 76 | 77 | Before running any code, make sure to activate the conda environment and set the `PYTHONPATH`: 78 | ```bash 79 | conda activate randomized_positional_encodings 80 | export PYTHONPATH=$(pwd)/.. 81 | ``` 82 | 83 | We provide an example of a training and evaluation run, which can be launched with: 84 | ```bash 85 | python experiments/example.py 86 | ``` 87 | 88 | 89 | ## Citing this work 90 | 91 | ```bibtex 92 | @inproceedings{ruoss2023randomized, 93 | author = {Anian Ruoss and 94 | Gr{\'{e}}goire Del{\'{e}}tang and 95 | Tim Genewein and 96 | Jordi Grau{-}Moya and 97 | R{\'{o}}bert Csord{\'{a}}s and 98 | Mehdi Bennani and 99 | Shane Legg and 100 | Joel Veness}, 101 | title = {Randomized Positional Encodings Boost Length Generalization of Transformers}, 102 | booktitle = {61st Annual Meeting of the Association for Computational Linguistics}, 103 | year = {2023}, 104 | } 105 | ``` 106 | 107 | 108 | ## License and disclaimer 109 | 110 | Copyright 2023 DeepMind Technologies Limited 111 | 112 | All software is licensed under the Apache License, Version 2.0 (Apache 2.0); 113 | you may not use this file except in compliance with the Apache 2.0 license. 114 | You may obtain a copy of the Apache 2.0 license at: 115 | https://www.apache.org/licenses/LICENSE-2.0 116 | 117 | All other materials are licensed under the Creative Commons Attribution 4.0 118 | International License (CC-BY). You may obtain a copy of the CC-BY license at: 119 | https://creativecommons.org/licenses/by/4.0/legalcode 120 | 121 | Unless required by applicable law or agreed to in writing, all software and 122 | materials distributed here under the Apache 2.0 or CC-BY licenses are 123 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 124 | either express or implied. See the licenses for the specific language governing 125 | permissions and limitations under those licenses. 126 | 127 | This is not an official Google product. 128 | -------------------------------------------------------------------------------- /experiments/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | 16 | """Constants for our length generalization experiments.""" 17 | 18 | import functools 19 | 20 | from randomized_positional_encodings.experiments import curriculum as curriculum_lib 21 | from randomized_positional_encodings.models import transformer 22 | from randomized_positional_encodings.tasks.cs import binary_addition 23 | from randomized_positional_encodings.tasks.cs import binary_multiplication 24 | from randomized_positional_encodings.tasks.cs import bucket_sort 25 | from randomized_positional_encodings.tasks.cs import compute_sqrt 26 | from randomized_positional_encodings.tasks.cs import duplicate_string 27 | from randomized_positional_encodings.tasks.cs import missing_duplicate_string 28 | from randomized_positional_encodings.tasks.cs import odds_first 29 | from randomized_positional_encodings.tasks.dcf import modular_arithmetic_brackets 30 | from randomized_positional_encodings.tasks.dcf import reverse_string 31 | from randomized_positional_encodings.tasks.dcf import solve_equation 32 | from randomized_positional_encodings.tasks.dcf import stack_manipulation 33 | from randomized_positional_encodings.tasks.regular import cycle_navigation 34 | from randomized_positional_encodings.tasks.regular import even_pairs 35 | from randomized_positional_encodings.tasks.regular import modular_arithmetic 36 | from randomized_positional_encodings.tasks.regular import parity_check 37 | 38 | 39 | MODEL_BUILDERS = { 40 | 'transformer_encoder': functools.partial( 41 | transformer.make_transformer, 42 | transformer_module=transformer.TransformerEncoder, # pytype: disable=module-attr 43 | ), 44 | } 45 | 46 | CURRICULUM_BUILDERS = { 47 | 'fixed': curriculum_lib.FixedCurriculum, 48 | 'regular_increase': curriculum_lib.RegularIncreaseCurriculum, 49 | 'reverse_exponential': curriculum_lib.ReverseExponentialCurriculum, 50 | 'uniform': curriculum_lib.UniformCurriculum, 51 | } 52 | 53 | TASK_BUILDERS = { 54 | 'even_pairs': even_pairs.EvenPairs, 55 | 'modular_arithmetic': functools.partial( 56 | modular_arithmetic.ModularArithmetic, modulus=5 57 | ), 58 | 'parity_check': parity_check.ParityCheck, 59 | 'cycle_navigation': cycle_navigation.CycleNavigation, 60 | 'stack_manipulation': stack_manipulation.StackManipulation, 61 | 'reverse_string': functools.partial( 62 | reverse_string.ReverseString, vocab_size=2 63 | ), 64 | 'modular_arithmetic_brackets': functools.partial( 65 | modular_arithmetic_brackets.ModularArithmeticBrackets, 66 | modulus=5, 67 | mult=True, 68 | ), 69 | 'solve_equation': functools.partial( 70 | solve_equation.SolveEquation, modulus=5 71 | ), 72 | 'duplicate_string': functools.partial( 73 | duplicate_string.DuplicateString, vocab_size=2 74 | ), 75 | 'missing_duplicate_string': missing_duplicate_string.MissingDuplicateString, 76 | 'odds_first': functools.partial(odds_first.OddsFirst, vocab_size=2), 77 | 'binary_addition': binary_addition.BinaryAddition, 78 | 'binary_multiplication': binary_multiplication.BinaryMultiplication, 79 | 'compute_sqrt': compute_sqrt.ComputeSqrt, 80 | 'bucket_sort': functools.partial(bucket_sort.BucketSort, vocab_size=5), 81 | } 82 | 83 | TASK_LEVELS = { 84 | 'even_pairs': 'regular', 85 | 'modular_arithmetic': 'regular', 86 | 'parity_check': 'regular', 87 | 'cycle_navigation': 'regular', 88 | 'stack_manipulation': 'dcf', 89 | 'reverse_string': 'dcf', 90 | 'modular_arithmetic_brackets': 'dcf', 91 | 'solve_equation': 'dcf', 92 | 'duplicate_string': 'cs', 93 | 
'missing_duplicate_string': 'cs', 94 | 'odds_first': 'cs', 95 | 'binary_addition': 'cs', 96 | 'binary_multiplication': 'cs', 97 | 'compute_sqrt': 'cs', 98 | 'bucket_sort': 'cs', 99 | } 100 | -------------------------------------------------------------------------------- /experiments/curriculum.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Curricula over sequence lengths used to evaluate length generalization. 17 | 18 | Allows sampling different sequence lengths during training. For instance, 19 | one might want to start with length=1 and regularly increase the length by 1, 20 | every 50k steps. 21 | """ 22 | 23 | import abc 24 | from collections.abc import Collection 25 | import random 26 | 27 | import numpy as np 28 | 29 | 30 | class Curriculum(abc.ABC): 31 | """Curriculum to sample lengths.""" 32 | 33 | @abc.abstractmethod 34 | def sample_sequence_length(self, step: int) -> int: 35 | """Samples a sequence length from the current distribution.""" 36 | 37 | 38 | class FixedCurriculum(Curriculum): 39 | """A fixed curriculum, always sampling the same sequence length.""" 40 | 41 | def __init__(self, sequence_length: int): 42 | """Initializes. 43 | 44 | Args: 45 | sequence_length: The sequence length to sample. 46 | """ 47 | super().__init__() 48 | self._sequence_length = sequence_length 49 | 50 | def sample_sequence_length(self, step: int) -> int: 51 | """Returns a fixed sequence length.""" 52 | del step 53 | return self._sequence_length 54 | 55 | 56 | class UniformCurriculum(Curriculum): 57 | """A uniform curriculum, sampling different sequence lengths.""" 58 | 59 | def __init__(self, values: Collection[int]): 60 | """Initializes. 61 | 62 | Args: 63 | values: The sequence lengths to sample. 64 | """ 65 | super().__init__() 66 | self._values = tuple(values) 67 | 68 | def sample_sequence_length(self, step: int) -> int: 69 | """Returns a sequence length sampled from a uniform distribution.""" 70 | del step 71 | return random.choice(self._values) 72 | 73 | 74 | class ReverseExponentialCurriculum(Curriculum): 75 | """A reverse exponential curriculum, sampling different sequence lengths.""" 76 | 77 | def __init__(self, values: Collection[int], tau: float): 78 | """Initializes. 79 | 80 | Args: 81 | values: The sequence lengths to sample. 82 | tau: The exponential rate to use.
83 | """ 84 | super().__init__() 85 | self._values = tuple(values) 86 | self._tau = tau 87 | 88 | def sample_sequence_length(self, step: int) -> int: 89 | """Returns a length sampled from a reverse exponential distribution.""" 90 | del step 91 | probs = self._tau ** np.array(self._values) 92 | probs = np.array(probs, dtype=np.float32) 93 | probs = probs / np.sum(probs) 94 | return np.random.choice(self._values, p=probs) 95 | 96 | 97 | class RegularIncreaseCurriculum(Curriculum): 98 | """Curriculum for sequence lengths with a regular increase.""" 99 | 100 | def __init__( 101 | self, 102 | initial_sequence_length: int, 103 | increase_frequency: int, 104 | increase_amount: int, 105 | sample_all_length: bool, 106 | ): 107 | """Initializes. 108 | 109 | Args: 110 | initial_sequence_length: The value of the sequence length at the beginning 111 | of the curriculum. 112 | increase_frequency: How often we increase the possible sequence length. 113 | increase_amount: The amount of the increase in length. 114 | sample_all_length: Whether to sample all length lower than the current one 115 | or just return the current one. 116 | """ 117 | super().__init__() 118 | self._initial_sequence_length = initial_sequence_length 119 | self._increase_frequency = increase_frequency 120 | self._increase_amount = increase_amount 121 | self._sample_all_length = sample_all_length 122 | 123 | def sample_sequence_length(self, step: int) -> int: 124 | """Returns a sequence length from the curriculum with the current step.""" 125 | if not self._sample_all_length: 126 | return self._initial_sequence_length + self._increase_amount * ( 127 | step // self._increase_frequency 128 | ) 129 | return ( 130 | self._initial_sequence_length 131 | + self._increase_amount 132 | * np.random.randint(0, step // self._increase_frequency + 1) 133 | ) 134 | -------------------------------------------------------------------------------- /experiments/example.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Example script to train and evaluate a network.""" 17 | 18 | from absl import app 19 | from absl import flags 20 | import haiku as hk 21 | import jax.numpy as jnp 22 | import numpy as np 23 | 24 | from randomized_positional_encodings.experiments import constants 25 | from randomized_positional_encodings.experiments import curriculum as curriculum_lib 26 | from randomized_positional_encodings.experiments import training 27 | from randomized_positional_encodings.experiments import utils 28 | 29 | 30 | _BATCH_SIZE = flags.DEFINE_integer( 31 | 'batch_size', 32 | default=128, 33 | help='Training batch size.', 34 | lower_bound=1, 35 | ) 36 | _SEQUENCE_LENGTH = flags.DEFINE_integer( 37 | 'sequence_length', 38 | default=40, 39 | help='Maximum training sequence length.', 40 | lower_bound=1, 41 | ) 42 | _TASK = flags.DEFINE_string( 43 | 'task', 44 | default='missing_duplicate_string', 45 | help='Length generalization task (see `constants.py` for other tasks).', 46 | ) 47 | _ARCHITECTURE = flags.DEFINE_string( 48 | 'architecture', 49 | default='transformer_encoder', 50 | help='Model architecture (see `constants.py` for other architectures).', 51 | ) 52 | _IS_AUTOREGRESSIVE = flags.DEFINE_boolean( 53 | 'is_autoregressive', 54 | default=False, 55 | help='Whether to use autoregressive sampling or not.', 56 | ) 57 | _COMPUTATION_STEPS_MULT = flags.DEFINE_integer( 58 | 'computation_steps_mult', 59 | default=0, 60 | help=( 61 | 'The amount of computation tokens to append to the input tape (defined' 62 | ' as a multiple of the input length)' 63 | ), 64 | lower_bound=0, 65 | ) 66 | # The architecture parameters depend on the architecture, so we cannot define 67 | # them as via flags. See `constants.py` for the required values. 68 | _ARCHITECTURE_PARAMS = { 69 | 'num_layers': 5, 70 | 'embedding_dim': 64, 71 | 'dropout_prob': 0.1, 72 | 'positional_encodings': 'NOISY_RELATIVE', 73 | 'positional_encodings_params': {'noise_max_length': 2048}, 74 | } 75 | 76 | 77 | def main(_) -> None: 78 | # Create the task. 79 | curriculum = curriculum_lib.UniformCurriculum( 80 | values=list(range(1, _SEQUENCE_LENGTH.value + 1)) 81 | ) 82 | task = constants.TASK_BUILDERS[_TASK.value]() 83 | 84 | # Create the model. 85 | single_output = task.output_length(10) == 1 86 | model = constants.MODEL_BUILDERS[_ARCHITECTURE.value]( 87 | output_size=task.output_size, 88 | return_all_outputs=True, 89 | **_ARCHITECTURE_PARAMS, 90 | ) 91 | if _IS_AUTOREGRESSIVE.value: 92 | if 'transformer' not in _ARCHITECTURE.value: 93 | model = utils.make_model_with_targets_as_input( 94 | model, _COMPUTATION_STEPS_MULT.value 95 | ) 96 | model = utils.add_sampling_to_autoregressive_model(model, single_output) 97 | else: 98 | model = utils.make_model_with_empty_targets( 99 | model, task, _COMPUTATION_STEPS_MULT.value, single_output 100 | ) 101 | model = hk.transform(model) 102 | 103 | # Create the loss and accuracy based on the pointwise ones. 104 | def loss_fn(output, target): 105 | loss = jnp.mean(jnp.sum(task.pointwise_loss_fn(output, target), axis=-1)) 106 | return loss, {} 107 | 108 | def accuracy_fn(output, target): 109 | mask = task.accuracy_mask(target) 110 | return jnp.sum(mask * task.accuracy_fn(output, target)) / jnp.sum(mask) 111 | 112 | # Create the final training parameters. 
113 | training_params = training.ClassicTrainingParams( 114 | seed=0, 115 | model_init_seed=0, 116 | training_steps=10_000, 117 | log_frequency=100, 118 | length_curriculum=curriculum, 119 | batch_size=_BATCH_SIZE.value, 120 | task=task, 121 | model=model, 122 | loss_fn=loss_fn, 123 | learning_rate=1e-3, 124 | l2_weight=0.0, 125 | accuracy_fn=accuracy_fn, 126 | compute_full_range_test=True, 127 | max_range_test_length=100, 128 | range_test_total_batch_size=512, 129 | range_test_sub_batch_size=64, 130 | is_autoregressive=_IS_AUTOREGRESSIVE.value, 131 | ) 132 | 133 | training_worker = training.TrainingWorker(training_params, use_tqdm=True) 134 | _, eval_results, _ = training_worker.run() 135 | 136 | # Gather results and print final score. 137 | accuracies = [r['accuracy'] for r in eval_results] 138 | score = np.mean(accuracies[_SEQUENCE_LENGTH.value + 1 :]) 139 | print(f'Score: {score}') 140 | 141 | 142 | if __name__ == '__main__': 143 | app.run(main) 144 | -------------------------------------------------------------------------------- /experiments/range_evaluation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Evaluation of a network on sequences of different lengths.""" 17 | 18 | import dataclasses 19 | import random 20 | from typing import Any, Callable, Mapping 21 | 22 | from absl import logging 23 | import chex 24 | import haiku as hk 25 | import jax 26 | import jax.numpy as jnp 27 | import numpy as np 28 | import tqdm 29 | 30 | _Batch = Mapping[str, jnp.ndarray] 31 | 32 | 33 | @dataclasses.dataclass 34 | class EvaluationParams: 35 | """The parameters used for range evaluation of networks.""" 36 | 37 | model: hk.Transformed 38 | params: chex.ArrayTree 39 | 40 | accuracy_fn: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray] 41 | sample_batch: Callable[[chex.Array, int, int], _Batch] 42 | 43 | max_test_length: int 44 | total_batch_size: int 45 | sub_batch_size: int 46 | 47 | is_autoregressive: bool = False 48 | 49 | 50 | def range_evaluation( 51 | eval_params: EvaluationParams, 52 | use_tqdm: bool = False, 53 | ) -> list[Mapping[str, Any]]: 54 | """Evaluates the model on longer, never seen strings and log the results. 55 | 56 | Args: 57 | eval_params: The evaluation parameters, see above. 58 | use_tqdm: Whether to use a progress bar with tqdm. 59 | 60 | Returns: 61 | The list of dicts containing the accuracies. 
62 | """ 63 | model = eval_params.model 64 | params = eval_params.params 65 | 66 | random.seed(1) 67 | np.random.seed(1) 68 | rng_seq = hk.PRNGSequence(1) 69 | 70 | if eval_params.is_autoregressive: 71 | apply_fn = jax.jit(model.apply, static_argnames=('sample',)) 72 | else: 73 | apply_fn = jax.jit(model.apply) 74 | 75 | results = [] 76 | lengths = range(1, eval_params.max_test_length + 1) 77 | 78 | if use_tqdm: 79 | lengths = tqdm.tqdm(lengths) 80 | 81 | for length in lengths: 82 | # We need to clear the cache of jitted functions, to avoid overflow as we 83 | # are jitting len(lengths) ones, which can be a lot. 84 | apply_fn.clear_cache() 85 | sub_accuracies = [] 86 | 87 | for _ in range(eval_params.total_batch_size // eval_params.sub_batch_size): 88 | batch = eval_params.sample_batch( 89 | next(rng_seq), eval_params.sub_batch_size, length 90 | ) 91 | 92 | if eval_params.is_autoregressive: 93 | outputs = apply_fn( 94 | params, 95 | next(rng_seq), 96 | batch['input'], 97 | jnp.empty_like(batch['output']), 98 | sample=True, 99 | ) 100 | else: 101 | outputs = apply_fn(params, next(rng_seq), batch['input']) 102 | 103 | sub_accuracies.append( 104 | float(np.mean(eval_params.accuracy_fn(outputs, batch['output']))) 105 | ) 106 | log_data = { 107 | 'length': length, 108 | 'accuracy': np.mean(sub_accuracies), 109 | } 110 | logging.info(log_data) 111 | results.append(log_data) 112 | return results 113 | -------------------------------------------------------------------------------- /experiments/training.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Training loop for length generalization experiments.""" 17 | 18 | import dataclasses 19 | import functools 20 | import random 21 | from typing import Any, Callable, Mapping, Optional 22 | 23 | from absl import logging 24 | import chex 25 | import haiku as hk 26 | import jax 27 | import jax.numpy as jnp 28 | import numpy as np 29 | import optax 30 | import tqdm 31 | 32 | from randomized_positional_encodings.experiments import curriculum as curriculum_lib 33 | from randomized_positional_encodings.experiments import range_evaluation 34 | from randomized_positional_encodings.tasks import task as task_lib 35 | 36 | 37 | _Batch = Mapping[str, jnp.ndarray] 38 | _LossMetrics = Optional[Mapping[str, jnp.ndarray]] 39 | _LossFn = Callable[[chex.Array, chex.Array], tuple[float, _LossMetrics]] 40 | _AccuracyFn = Callable[[chex.Array, chex.Array], float] 41 | _ModelApplyFn = Callable[..., chex.Array] 42 | _MAX_RNGS_RESERVE = 50000 43 | 44 | 45 | @dataclasses.dataclass 46 | class ClassicTrainingParams: 47 | """Parameters needed to train classical architectures.""" 48 | 49 | seed: int # Used to sample during forward pass (e.g. from final logits). 
50 | model_init_seed: int # Used to initialize model parameters. 51 | training_steps: int 52 | log_frequency: int 53 | 54 | task: task_lib.GeneralizationTask 55 | length_curriculum: curriculum_lib.Curriculum 56 | batch_size: int 57 | 58 | model: hk.Transformed 59 | loss_fn: Callable[[jnp.ndarray, jnp.ndarray], tuple[float, _LossMetrics]] 60 | learning_rate: float 61 | l2_weight: float 62 | test_model: Optional[hk.Transformed] = None 63 | max_grad_norm: float = 1.0 64 | is_autoregressive: bool = False 65 | 66 | compute_full_range_test: bool = False 67 | range_test_total_batch_size: int = 512 68 | range_test_sub_batch_size: int = 64 69 | max_range_test_length: int = 100 70 | 71 | accuracy_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = ( 72 | None 73 | ) 74 | 75 | 76 | def _apply_loss_and_metrics_fn( 77 | params: hk.Params, 78 | rng_key: chex.PRNGKey, 79 | batch: _Batch, 80 | model_apply_fn: _ModelApplyFn, 81 | loss_fn: _LossFn, 82 | accuracy_fn: _AccuracyFn, 83 | is_autoregressive: bool = False, 84 | ) -> tuple[float, tuple[_LossMetrics, float]]: 85 | """Computes the model output and applies the loss function. 86 | 87 | Depending on whether a model is autoregressive or not, it will have a 88 | different number of input parameters (i.e., autoregressive models also require 89 | the targets as an input). 90 | 91 | Args: 92 | params: The model parameters. 93 | rng_key: The prng key to use for random number generation. 94 | batch: The data (consists of both inputs and outputs). 95 | model_apply_fn: The model function that converts inputs into outputs. 96 | loss_fn: A function that computes the loss for a batch of logits and labels. 97 | accuracy_fn: A function that computes the accuracy for a batch of logits and 98 | labels. 99 | is_autoregressive: Whether the model is autoregressive or not. 100 | 101 | Returns: 102 | The loss of the model for the batch of data, extra loss metrics and the 103 | accuracy, if accuracy_fn is not None. 104 | """ 105 | if is_autoregressive: 106 | outputs = model_apply_fn( 107 | params, rng_key, batch["input"], batch["output"], sample=False 108 | ) 109 | else: 110 | outputs = model_apply_fn(params, rng_key, batch["input"]) 111 | 112 | loss, loss_metrics = loss_fn(outputs, batch["output"]) 113 | if accuracy_fn is not None: 114 | accuracy = accuracy_fn(outputs, batch["output"]) 115 | else: 116 | accuracy = None 117 | return loss, (loss_metrics, accuracy) 118 | 119 | 120 | @functools.partial( 121 | jax.jit, 122 | static_argnames=( 123 | "model_apply_fn", 124 | "loss_fn", 125 | "accuracy_fn", 126 | "optimizer", 127 | "is_autoregressive", 128 | ), 129 | ) 130 | def _update_parameters( 131 | params: hk.Params, 132 | rng_key: chex.PRNGKey, 133 | batch: _Batch, 134 | model_apply_fn: _ModelApplyFn, 135 | loss_fn: _LossFn, 136 | accuracy_fn: _AccuracyFn, 137 | optimizer: optax.GradientTransformation, 138 | opt_state: optax.OptState, 139 | is_autoregressive: bool = False, 140 | ) -> tuple[hk.Params, optax.OptState, tuple[float, _LossMetrics, float]]: 141 | """Applies a single SGD update step to the model parameters. 142 | 143 | Args: 144 | params: The model parameters. 145 | rng_key: The prng key to use for random number generation. 146 | batch: The data (consists of both inputs and outputs). 147 | model_apply_fn: The model function that converts inputs into outputs. 148 | loss_fn: A function that computes the loss for a batch of logits and labels. 149 | accuracy_fn: A function that computes the accuracy for a batch of logits and 150 | labels. 
151 | optimizer: The optimizer that computes the updates from the gradients of the 152 | `loss_fn` with respect to the `params` and the previous `opt_state`. 153 | opt_state: The optimizer state, e.g., momentum for each variable when using 154 | Adam. 155 | is_autoregressive: Whether the model is autoregressive or not. 156 | 157 | Returns: 158 | The updated parameters, the new optimizer state, and the loss, loss metrics 159 | and accuracy. 160 | """ 161 | (loss, (metrics, accuracy)), grads = jax.value_and_grad( 162 | _apply_loss_and_metrics_fn, has_aux=True 163 | )( 164 | params, 165 | rng_key, 166 | batch, 167 | model_apply_fn, 168 | loss_fn, 169 | accuracy_fn, 170 | is_autoregressive, 171 | ) 172 | updates, new_opt_state = optimizer.update(grads, opt_state, params) 173 | new_params = optax.apply_updates(params, updates) 174 | return new_params, new_opt_state, (loss, metrics, accuracy) 175 | 176 | 177 | class TrainingWorker: 178 | """Training worker.""" 179 | 180 | def __init__( 181 | self, training_params: ClassicTrainingParams, use_tqdm: bool = False 182 | ): 183 | """Initializes the worker. 184 | 185 | Args: 186 | training_params: The training parameters. 187 | use_tqdm: Whether to add a progress bar to stdout. 188 | """ 189 | self._training_params = training_params 190 | self._use_tqdm = use_tqdm 191 | self._params = None 192 | self._step = 0 193 | 194 | def step_for_evaluator(self) -> int: 195 | return self._step 196 | 197 | def run( 198 | self, 199 | ) -> tuple[list[Mapping[str, Any]], list[Mapping[str, Any]], chex.ArrayTree]: 200 | """Trains the model with the provided config. 201 | 202 | Returns: 203 | Results (various training and validation metrics), module parameters 204 | and router parameters. 205 | """ 206 | logging.info("Starting training!") 207 | training_params = self._training_params 208 | rngs_reserve = min(_MAX_RNGS_RESERVE, training_params.training_steps) 209 | 210 | random.seed(training_params.seed) 211 | np.random.seed(training_params.seed) 212 | rng_seq = hk.PRNGSequence(training_params.seed) 213 | rng_seq.reserve(rngs_reserve) 214 | 215 | step = 0 216 | 217 | results = [] 218 | model = training_params.model 219 | task = training_params.task 220 | length_curriculum = training_params.length_curriculum 221 | 222 | if training_params.l2_weight is None or training_params.l2_weight == 0: 223 | optimizer = optax.adam(training_params.learning_rate) 224 | else: 225 | optimizer = optax.adamw( 226 | training_params.learning_rate, weight_decay=training_params.l2_weight 227 | ) 228 | 229 | optimizer = optax.chain( 230 | optax.clip_by_global_norm(training_params.max_grad_norm), optimizer 231 | ) 232 | 233 | dummy_batch = task.sample_batch( 234 | next(rng_seq), length=10, batch_size=training_params.batch_size 235 | ) 236 | model_init_rng_key = jax.random.PRNGKey(training_params.model_init_seed) 237 | 238 | if training_params.is_autoregressive: 239 | params = model.init( 240 | model_init_rng_key, 241 | dummy_batch["input"], 242 | dummy_batch["output"], 243 | sample=False, 244 | ) 245 | else: 246 | params = model.init(model_init_rng_key, dummy_batch["input"]) 247 | 248 | opt_state = optimizer.init(params) 249 | self._params, self._step = params, 0 250 | 251 | steps = range(training_params.training_steps + 1) 252 | 253 | if self._use_tqdm: 254 | steps = tqdm.tqdm(steps) 255 | 256 | for step in steps: 257 | # Randomness handled by either python.random or numpy. 
258 | length = length_curriculum.sample_sequence_length(step) 259 | # Randomness handled by either jax, python.random or numpy. 260 | train_batch = task.sample_batch( 261 | next(rng_seq), length=length, batch_size=training_params.batch_size 262 | ) 263 | params, opt_state, (train_loss, train_metrics, train_accuracy) = ( 264 | _update_parameters( 265 | params=params, 266 | rng_key=next(rng_seq), 267 | batch=train_batch, 268 | model_apply_fn=model.apply, 269 | loss_fn=training_params.loss_fn, 270 | accuracy_fn=training_params.accuracy_fn, 271 | optimizer=optimizer, 272 | opt_state=opt_state, 273 | is_autoregressive=training_params.is_autoregressive, 274 | ) 275 | ) 276 | self._params, self._step = params, step 277 | 278 | log_freq = training_params.log_frequency 279 | if (log_freq > 0) and (step % log_freq == 0): 280 | log_data = { 281 | "step": step, 282 | "train_loss": float(train_loss), 283 | } 284 | if training_params.accuracy_fn is not None: 285 | log_data["train_accuracy"] = float(train_accuracy) 286 | for key, value in train_metrics.items(): 287 | log_data[".".join(["train_metrics", key])] = np.array(value) 288 | logging.info(log_data) 289 | results.append(log_data) 290 | 291 | # We need to access this private attribute since the default reserve size 292 | # can not be edited yet. 293 | if not rng_seq._subkeys: # pylint: disable=protected-access 294 | rng_seq.reserve(rngs_reserve) 295 | 296 | eval_results = list() 297 | 298 | if training_params.compute_full_range_test: 299 | eval_params = range_evaluation.EvaluationParams( 300 | model=training_params.test_model or model, 301 | params=params, 302 | accuracy_fn=training_params.accuracy_fn, 303 | sample_batch=task.sample_batch, 304 | max_test_length=training_params.max_range_test_length, 305 | total_batch_size=training_params.range_test_total_batch_size, 306 | sub_batch_size=training_params.range_test_sub_batch_size, 307 | is_autoregressive=training_params.is_autoregressive, 308 | ) 309 | eval_results = range_evaluation.range_evaluation( 310 | eval_params, 311 | use_tqdm=True, 312 | ) 313 | 314 | return results, eval_results, params 315 | -------------------------------------------------------------------------------- /experiments/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Provides utility functions for training and evaluation.""" 17 | 18 | import inspect 19 | from typing import Any, Callable 20 | 21 | import chex 22 | import haiku as hk 23 | from jax import nn as jnn 24 | from jax import numpy as jnp 25 | 26 | from randomized_positional_encodings.tasks import task 27 | 28 | COMPUTATION_EMPTY_TOKEN = 0 29 | OUTPUT_EMPTY_TOKEN = 1 30 | 31 | 32 | def make_model_with_empty_targets( 33 | model: Callable[[chex.Array], chex.Array], 34 | generalization_task: task.GeneralizationTask, 35 | computation_steps_mult: int = 0, 36 | single_output: bool = False, 37 | ) -> Callable[[chex.Array], chex.Array]: 38 | """Returns a wrapped model that pads the inputs to match the output length. 39 | 40 | For a given input tape `input_tape` of vocabulary size `vocab_size`, the 41 | wrapped model will process a tape of the format 42 | [`input_tape`, `empty_tape`], where the empty tape token is `vocab_size + 1`. 43 | The `empty_tape` has the same length as the task output. 44 | 45 | Args: 46 | model: A model function that converts inputs to outputs. 47 | generalization_task: The task that we train on. 48 | computation_steps_mult: The amount of empty cells to append to the input 49 | tape. This variable is a multiplier and the actual number of cells is 50 | `computation_steps_mult * input_length`. 51 | single_output: Whether to return the squeezed tensor of values. 52 | """ 53 | 54 | def new_model(x: chex.Array) -> chex.Array: 55 | batch_size, input_length, input_size = x.shape 56 | output_length = generalization_task.output_length(input_length) 57 | extra_dims_onehot = 1 + int(computation_steps_mult > 0) 58 | final_input_size = input_size + extra_dims_onehot 59 | 60 | # Add trailing zeros to account for new final_input_size. 61 | extra_zeros_x = jnp.zeros( 62 | (batch_size, input_length, final_input_size - input_size) 63 | ) 64 | x = jnp.concatenate([x, extra_zeros_x], axis=-1) 65 | 66 | computation_tape = jnp.full( 67 | (batch_size, computation_steps_mult * input_length), 68 | fill_value=input_size + COMPUTATION_EMPTY_TOKEN, 69 | ) 70 | computation_tape = jnn.one_hot( 71 | computation_tape, num_classes=final_input_size 72 | ) 73 | 74 | output_tokens = jnp.full( 75 | (batch_size, output_length), 76 | fill_value=input_size 77 | + OUTPUT_EMPTY_TOKEN 78 | - int(computation_steps_mult == 0), 79 | ) 80 | output_tokens = jnn.one_hot(output_tokens, num_classes=final_input_size) 81 | final_input = jnp.concatenate([x, computation_tape, output_tokens], axis=1) 82 | 83 | if 'input_length' in inspect.getfullargspec(model).args: 84 | output = model(final_input, input_length=input_length) # pytype: disable=wrong-keyword-args 85 | else: 86 | output = model(final_input) 87 | output = output[:, -output_length:] 88 | if single_output: 89 | output = jnp.squeeze(output, axis=1) 90 | return output 91 | 92 | return new_model 93 | 94 | 95 | def make_model_with_targets_as_input( 96 | model: Callable[[chex.Array], chex.Array], computation_steps_mult: int = 0 97 | ) -> Callable[[chex.Array, chex.Array], chex.Array]: 98 | """Returns a wrapped model that takes the targets as inputs. 99 | 100 | This function is useful for the autoregressive case where we pass the targets 101 | as inputs to the model. The final input looks like: 102 | [inputs, computation_tokens, output_token, targets] 103 | 104 | Args: 105 | model: A haiku model that takes 'x' as input. 
106 | computation_steps_mult: The amount of computation tokens to append to the 107 | input tape. This variable is a multiplier and the actual number of cell is 108 | computation_steps_mult * input_length. 109 | """ 110 | 111 | def new_model(x: chex.Array, y: chex.Array) -> chex.Array: 112 | """Returns an output from the inputs and targets. 113 | 114 | Args: 115 | x: One-hot input vectors, shape (B, T, input_size). 116 | y: One-hot target output vectors, shape (B, T, output_size). 117 | """ 118 | batch_size, input_length, input_size = x.shape 119 | _, output_length, output_size = y.shape 120 | extra_dims_onehot = 1 + int(computation_steps_mult > 0) 121 | final_input_size = max(input_size, output_size) + extra_dims_onehot 122 | 123 | # Add trailing zeros to account for new final_input_size. 124 | extra_zeros_x = jnp.zeros( 125 | (batch_size, input_length, final_input_size - input_size) 126 | ) 127 | x = jnp.concatenate([x, extra_zeros_x], axis=-1) 128 | extra_zeros_y = jnp.zeros( 129 | (batch_size, output_length, final_input_size - output_size) 130 | ) 131 | y = jnp.concatenate([y, extra_zeros_y], axis=-1) 132 | 133 | computation_tape = jnp.full( 134 | (batch_size, computation_steps_mult * input_length), 135 | fill_value=input_size + COMPUTATION_EMPTY_TOKEN, 136 | ) 137 | computation_tape = jnn.one_hot( 138 | computation_tape, num_classes=final_input_size 139 | ) 140 | 141 | output_token = jnp.full( 142 | (batch_size, 1), 143 | fill_value=input_size 144 | + OUTPUT_EMPTY_TOKEN 145 | - int(computation_steps_mult == 0), 146 | ) 147 | output_token = jnn.one_hot(output_token, num_classes=final_input_size) 148 | final_input = jnp.concatenate( 149 | [x, computation_tape, output_token, y], axis=1 150 | ) 151 | 152 | if 'input_length' in inspect.getfullargspec(model).args: 153 | output = model(final_input, input_length=input_length) # pytype: disable=wrong-keyword-args 154 | else: 155 | output = model(final_input) 156 | 157 | return output[:, -output_length - 1 : -1] 158 | 159 | return new_model 160 | 161 | 162 | def add_sampling_to_autoregressive_model( 163 | model: Callable[[chex.Array, chex.Array], chex.Array], 164 | single_output: bool = False, 165 | ) -> Callable[[chex.Array, chex.Array, bool], chex.Array]: 166 | """Adds a 'sample' argument to the model, to use autoregressive sampling.""" 167 | 168 | def new_model_with_sampling( 169 | x: chex.Array, 170 | y: chex.Array, 171 | sample: bool, 172 | ) -> chex.Array: 173 | """Returns an autoregressive model if `sample == True and output_size > 1`. 174 | 175 | Args: 176 | x: The input sequences of shape (b, t, i), where i is the input size. 177 | y: The target sequences of shape (b, t, o), where o is the output size. 178 | sample: Whether to evaluate the model using autoregressive decoding. 179 | """ 180 | output_length = 1 if len(y.shape) == 2 else y.shape[1] 181 | output_size = y.shape[-1] 182 | 183 | if not sample or output_length == 1: 184 | output = model(x, y) 185 | 186 | else: 187 | 188 | def evaluate_model_autoregressively( 189 | idx: int, 190 | predictions: chex.Array, 191 | ) -> chex.Array: 192 | """Iteratively evaluates the model based on the previous predictions. 193 | 194 | Args: 195 | idx: The index of the target sequence that should be evaluated. 196 | predictions: The logits for the predictions up to but not including 197 | the index `idx`. 198 | 199 | Returns: 200 | The `predictions` array modified only at position `idx` where the 201 | logits for index `idx` have been inserted. 
202 | """ 203 | one_hot_predictions = jnn.one_hot( 204 | jnp.argmax(predictions, axis=-1), 205 | num_classes=output_size, 206 | ) 207 | logits = model(x, one_hot_predictions) 208 | return predictions.at[:, idx].set(logits[:, idx]) 209 | 210 | output = hk.fori_loop( 211 | lower=0, 212 | upper=output_length, 213 | body_fun=evaluate_model_autoregressively, 214 | init_val=jnp.empty_like(y), 215 | ) 216 | 217 | if single_output: 218 | output = jnp.squeeze(output, axis=1) 219 | return output 220 | 221 | return new_model_with_sampling 222 | 223 | 224 | def update_tree_with_new_containers( 225 | tree: Any, update_dict: dict[str, Any] 226 | ) -> None: 227 | """Updates a dataclass tree in place, adding new containers. 228 | 229 | This method is useful for the nested library to add fields to a tree, for 230 | which containers have not been created. 231 | For instance, if A is a dataclass with attribute architecture_params, and we 232 | want to add the value architecture_params.rnn_model.size, we need to create 233 | the container 'rnn_model' inside architecture_params. 234 | 235 | Args: 236 | tree: An object with attribute (typically a dataclass). 237 | update_dict: A dict of nested updates. See example above. 238 | """ 239 | for key in update_dict: 240 | subkeys = key.split('.') 241 | if len(subkeys) >= 2: 242 | # Example: architecture.params.size 243 | for i in range(0, len(subkeys) - 2): 244 | getattr(tree, subkeys[i])[subkeys[i + 1]] = {} 245 | -------------------------------------------------------------------------------- /models/positional_encodings.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Positional encodings, used in `transformer.py`.""" 17 | 18 | import enum 19 | import functools 20 | import math 21 | from typing import Any, Optional, Union 22 | 23 | import chex 24 | import haiku as hk 25 | import jax 26 | import jax.numpy as jnp 27 | import jax.random as jrandom 28 | import numpy as np 29 | 30 | 31 | class PositionalEncodings(enum.Enum): 32 | """Enum for all the positional encodings implemented.""" 33 | 34 | NONE = 0 35 | SIN_COS = 1 36 | ALIBI = 2 37 | RELATIVE = 3 38 | ROTARY = 4 39 | LEARNT = 5 40 | NOISY_SIN_COS = 6 41 | NOISY_RELATIVE = 7 42 | NOISY_LEARNT = 8 43 | NOISY_ROTARY = 9 44 | NOISY_ALIBI = 10 45 | 46 | 47 | @chex.dataclass 48 | class SinCosParams: 49 | """Parameters for the classical sin/cos positional encoding.""" 50 | 51 | # The maximum wavelength used. 52 | max_time: int = 10_000 53 | 54 | 55 | # We will use this same class for Rotary and Relative. 56 | RotaryParams = SinCosParams 57 | RelativeParams = SinCosParams 58 | 59 | 60 | @chex.dataclass 61 | class LearntParams: 62 | """Parameters for the classical sin/cos positional encoding.""" 63 | 64 | # The size of the embedding matrix to use. 
65 | max_sequence_length: int 66 | 67 | 68 | @chex.dataclass 69 | class NoisySinCosParams: 70 | """Parameters for the noisy sin/cos positional encoding.""" 71 | 72 | # The maximum length to sample. 73 | noise_max_length: int 74 | # The maximum wavelength used. 75 | max_time: int = 10_000 76 | 77 | 78 | @chex.dataclass 79 | class NoisyRelativeParams: 80 | """Parameters for the noisy relative positional encoding.""" 81 | 82 | # The maximum length to sample. 83 | noise_max_length: int 84 | # Either randomize the right side and keep the same encodings for the left 85 | # part, keeping the symmetry, or randomize each side independently. 86 | randomize_both_sides: bool = False 87 | # The maximum wavelength used. 88 | max_time: int = 10_000 89 | 90 | 91 | @chex.dataclass 92 | class NoisyLearntParams: 93 | """Parameters for the noisy learnt positional encoding.""" 94 | 95 | # The maximum length to sample. 96 | noise_max_length: int 97 | 98 | 99 | @chex.dataclass 100 | class NoisyAlibiParams: 101 | """Parameters for the noisy alibi positional encoding.""" 102 | 103 | # The maximum length to sample. 104 | noise_max_length: int 105 | # Either randomize the right side and keep the same encodings for the left 106 | # part, maintaining symmetry, or randomize each side independently. 107 | randomize_both_sides: bool = False 108 | 109 | 110 | @chex.dataclass 111 | class NoisyRotaryParams: 112 | """Parameters for the noisy rotary positional encoding.""" 113 | 114 | # The maximum length to sample. 115 | noise_max_length: int 116 | 117 | 118 | PositionalEncodingsParams = Union[ 119 | SinCosParams, 120 | RelativeParams, 121 | RotaryParams, 122 | LearntParams, 123 | NoisySinCosParams, 124 | NoisyAlibiParams, 125 | NoisyRelativeParams, 126 | NoisyRotaryParams, 127 | NoisyLearntParams, 128 | ] 129 | 130 | 131 | POS_ENC_TABLE = { 132 | 'NONE': PositionalEncodings.NONE, 133 | 'SIN_COS': PositionalEncodings.SIN_COS, 134 | 'ALIBI': PositionalEncodings.ALIBI, 135 | 'RELATIVE': PositionalEncodings.RELATIVE, 136 | 'ROTARY': PositionalEncodings.ROTARY, 137 | 'LEARNT': PositionalEncodings.LEARNT, 138 | 'NOISY_SIN_COS': PositionalEncodings.NOISY_SIN_COS, 139 | 'NOISY_ALIBI': PositionalEncodings.NOISY_ALIBI, 140 | 'NOISY_RELATIVE': PositionalEncodings.NOISY_RELATIVE, 141 | 'NOISY_ROTARY': PositionalEncodings.NOISY_ROTARY, 142 | 'NOISY_LEARNT': PositionalEncodings.NOISY_LEARNT, 143 | } 144 | 145 | POS_ENC_PARAMS_TABLE = { 146 | 'NONE': SinCosParams, 147 | 'SIN_COS': SinCosParams, 148 | 'ALIBI': SinCosParams, 149 | 'RELATIVE': RelativeParams, 150 | 'ROTARY': RotaryParams, 151 | 'LEARNT': LearntParams, 152 | 'NOISY_SIN_COS': NoisySinCosParams, 153 | 'NOISY_ALIBI': NoisyAlibiParams, 154 | 'NOISY_RELATIVE': NoisyRelativeParams, 155 | 'NOISY_ROTARY': NoisyRotaryParams, 156 | 'NOISY_LEARNT': NoisyLearntParams, 157 | } 158 | 159 | 160 | def sinusoid_position_encoding( 161 | sequence_length: int, 162 | hidden_size: int, 163 | max_timescale: float = 1e4, 164 | add_negative_side: bool = False, 165 | ) -> np.ndarray: 166 | """Creates sinusoidal encodings from the original transformer paper. 167 | 168 | The returned values are, for all i < D/2: 169 | array[pos, i] = sin(pos / (max_timescale^(2*i / D))) 170 | array[pos, D/2 + i] = cos(pos / (max_timescale^(2*i / D))) 171 | 172 | Args: 173 | sequence_length: Sequence length. 174 | hidden_size: Dimension of the positional encoding vectors, D. Should be 175 | even. 176 | max_timescale: Maximum timescale for the frequency.
177 | add_negative_side: Whether to also include the positional encodings for 178 | negative positions. 179 | 180 | Returns: 181 | An array of shape [L, D] if add_negative_side is False, else [2 * L, D]. 182 | """ 183 | if hidden_size % 2 != 0: 184 | raise ValueError( 185 | 'The feature dimension should be even for sin/cos positional encodings.' 186 | ) 187 | freqs = np.arange(0, hidden_size, 2) 188 | inv_freq = max_timescale ** (-freqs / hidden_size) 189 | pos_seq = np.arange( 190 | start=-sequence_length if add_negative_side else 0, stop=sequence_length 191 | ) 192 | sinusoid_inp = np.einsum('i,j->ij', pos_seq, inv_freq) 193 | return np.concatenate([np.sin(sinusoid_inp), np.cos(sinusoid_inp)], axis=-1) 194 | 195 | 196 | def noisy_fixed_positional_encodings( 197 | fixed_positional_encodings: chex.Array, 198 | sequence_length: int, 199 | rng: Optional[chex.PRNGKey] = None, 200 | ) -> chex.Array: 201 | """Generates noisy positional encodings from fixed positional encodings. 202 | 203 | Randomly samples and orders sequence_length positional encodings from a wider 204 | range [0, noise_max_length) rather than just [0, sequence_length). 205 | The user provides the full_encodings, which should span the entire range 206 | [0, noise_max_length). 207 | 208 | Args: 209 | fixed_positional_encodings: A tensor of shape (noise_max_length, 210 | embedding_size). This is what the encodings will be sampled from. 211 | sequence_length: The length of the output sequence. 212 | rng: Optional rng to use rather than hk.next_rng_key(). 213 | 214 | Returns: 215 | A tensor of size [sequence_length, embedding_size]. 216 | """ 217 | noise_max_length, _ = fixed_positional_encodings.shape 218 | indexes = jrandom.choice( 219 | rng if rng is not None else hk.next_rng_key(), 220 | jnp.arange(noise_max_length), 221 | shape=(sequence_length,), 222 | replace=False, 223 | ) 224 | indexes = jnp.sort(indexes) 225 | encodings = fixed_positional_encodings[indexes] 226 | return encodings 227 | 228 | 229 | def _rel_shift_inner(logits: jax.Array, attention_length: int) -> jax.Array: 230 | """Shifts the relative logits. 231 | 232 | This is more general than the original Transformer-XL implementation, as 233 | inputs may also see the future. (The implementation does not rely on a 234 | causal mask removing the upper-right triangle.) 235 | 236 | Given attention length 3 and inputs: 237 | [[-3, -2, -1, 0, 1, 2], 238 | [-3, -2, -1, 0, 1, 2], 239 | [-3, -2, -1, 0, 1, 2]] 240 | 241 | The shifted output is: 242 | [[0, 1, 2], 243 | [-1, 0, 1], 244 | [-2, -1, 0]] 245 | 246 | Args: 247 | logits: input tensor of shape [T_q, T_v + T_q] 248 | attention_length: T_v `int` length of the attention, should be equal to 249 | memory size + sequence length. 250 | 251 | Returns: 252 | A shifted version of the input of size [T_q, T_v]. In each row, a window of 253 | size T_v elements is kept; the window starts one position earlier in each 254 | subsequent row. 255 | """ 256 | if logits.ndim != 2: 257 | raise ValueError('`logits` needs to be an array of dimension 2.') 258 | tq, total_len = logits.shape 259 | assert total_len == tq + attention_length 260 | logits = jnp.reshape(logits, [total_len, tq]) 262 | logits = jax.lax.slice(logits, (1, 0), logits.shape) # logits[1:] 263 | logits = jnp.reshape(logits, [tq, total_len - 1]) 264 | # Equiv to logits[:, :attention_length].
265 | logits = jax.lax.slice(logits, (0, 0), (tq, attention_length)) 266 | return logits 267 | 268 | 269 | def _rel_shift_causal(logits: jax.Array) -> jax.Array: 270 | """Shifts the relative logits, assuming causal attention. 271 | 272 | Given inputs: 273 | [[-4, -3, -2, -1], 274 | [-4, -3, -2, -1]] 275 | 276 | The shifted (and, later, masked) output is: 277 | [[-3, -2, -1, 0], 278 | [-4, -3, -2, -1]] 279 | 280 | Args: 281 | logits: input tensor of shape [T_q, T_v] 282 | 283 | Returns: 284 | A shifted version of the input of size [T_q, T_v]. 285 | """ 286 | t1, t2 = logits.shape 287 | # We prepend zeros on the final timescale dimension. 288 | to_pad = jnp.zeros_like(logits[..., :1]) 289 | x = jnp.concatenate((to_pad, logits), axis=-1) 290 | 291 | # Reshape trick to shift input. 292 | x = jnp.reshape(x, [t2 + 1, t1]) 293 | 294 | # Remove extra time dimension and re-shape. 295 | x = jax.lax.slice(x, [1] + [0] * (x.ndim - 1), x.shape) 296 | 297 | return jnp.reshape(x, [t1, t2]) 298 | 299 | 300 | def relative_shift( 301 | logits: jax.Array, attention_length: int, causal: bool = False 302 | ) -> jax.Array: 303 | if causal: 304 | fn = _rel_shift_causal 305 | else: 306 | fn = lambda t: _rel_shift_inner(t, attention_length) 307 | return jax.vmap(jax.vmap(fn))(logits) 308 | 309 | 310 | def apply_rotary_encoding( 311 | x: jnp.ndarray, 312 | position: jnp.ndarray, 313 | max_time: int = 10_000, 314 | noisy: bool = False, 315 | rng: Optional[chex.PRNGKey] = None, 316 | ) -> jnp.ndarray: 317 | """Applies RoPE positional encodings for the input. 318 | 319 | Paper: https://arxiv.org/abs/2104.09864 320 | 321 | Args: 322 | x: The input tensor on which RoPE will be applied. Usually it is either some 323 | queries q or some keys k. 324 | position: The positions to use. Usually it's an arange of the maximum 325 | length. 326 | max_time: Constant used to scale position by in the encodings. 327 | noisy: Whether to use the noisy version. 328 | rng: The rng key to use if the noisy version is used. 329 | 330 | Returns: 331 | A tensor with the same shape as x. 332 | """ 333 | # Expand dims for positions to support inputs of shapes BTC or BTHC. 334 | freq_seq = jnp.arange(x.shape[-1] // 2, dtype=jnp.float32) 335 | freq_seq = freq_seq / (x.shape[-1] // 2) 336 | inv_freq = max_time**-freq_seq 337 | inv_freq = jnp.repeat(inv_freq, 2, 0) 338 | # Produce position inputs to periodic functions. 339 | t = position[:, :, None, None] * inv_freq[None, None, None, :] 340 | if noisy: 341 | t = noisy_fixed_positional_encodings(t[0, :, 0], x.shape[1], rng=rng) 342 | t = t[None, :, None, :] 343 | x_rot = jnp.einsum('bthd,dD->bthD', x, _rope_kernel(x.shape[-1], x.dtype)) 344 | return x * jnp.cos(t).astype(x.dtype) + jnp.sin(t).astype(x.dtype) * x_rot 345 | 346 | 347 | def _rope_kernel(n: int, dtype: Any) -> np.ndarray: 348 | """Reorders the embedding dimension of an array, to make rotation easier.""" 349 | # We implement the equivalent of 350 | # even_dims, odd_dims, = x[..., ::2], x[..., 1::2] 351 | # return jnp.stack((-odd_dims, even_dims), axis=-1).reshape(x.shape) 352 | # with a custom kernel for einsum. This allows the computation to execute 353 | # on the MXU instead of producing a slow gather. 354 | assert n % 2 == 0, n 355 | kernel = np.zeros((n, n), dtype) 356 | for i in range(n): 357 | # Swap each neighbouring pair of values. 
358 | if i % 2 == 0: 359 | kernel[i, i + 1] = 1 360 | else: 361 | kernel[i, i - 1] = -1 362 | return kernel 363 | 364 | 365 | def compute_attention_with_relative_encodings( 366 | queries: chex.Array, 367 | keys: chex.Array, 368 | max_time: int = 10_000, 369 | causal: bool = False, 370 | ) -> chex.Array: 371 | """Returns attention with relative positional encodings. 372 | 373 | This code strictly follows what is described in the TransformerXL paper. 374 | https://arxiv.org/pdf/1901.02860.pdf 375 | 376 | Args: 377 | queries: The queries used for attention. Shape (b, t, h, d). 378 | keys: The keys used for attention. Shape (b, T, h, d). 379 | max_time: Constant used to scale position by in the sin/cos encodings. 380 | causal: Whether to use causal attention when shifting the relative logits. 381 | 382 | Returns: 383 | The attention logits. Shape (b, h, t, T). 384 | """ 385 | batch_size, k_seq_len, num_heads, num_hiddens = keys.shape 386 | hiddens = num_hiddens * num_heads 387 | 388 | # First compute the content logits. 389 | content_bias = hk.get_parameter( 390 | name='relpos_contentbias', 391 | shape=[num_heads, num_hiddens], 392 | init=hk.initializers.RandomNormal(stddev=0.02), 393 | ) 394 | content_logits = jnp.einsum('bthd,bThd->bhtT', queries + content_bias, keys) 395 | 396 | positional_encodings = sinusoid_position_encoding( 397 | sequence_length=k_seq_len, 398 | hidden_size=hiddens, 399 | max_timescale=max_time, 400 | add_negative_side=not causal, 401 | ) 402 | positional_encodings = jnp.broadcast_to( 403 | positional_encodings, (batch_size,) + positional_encodings.shape 404 | ) 405 | relative_keys = hk.Linear(hiddens, with_bias=False)(positional_encodings) 406 | relative_keys = jnp.reshape( 407 | relative_keys, positional_encodings.shape[:-1] + (num_heads, num_hiddens) 408 | ) 409 | 410 | # Then compute the relative part. 411 | relative_bias = hk.get_parameter( 412 | name='relpos_relativebias', 413 | shape=[num_heads, num_hiddens], 414 | init=hk.initializers.RandomNormal(stddev=0.02), 415 | ) 416 | relative_logits = jnp.einsum( 417 | 'bthd,bThd->bhtT', queries + relative_bias, relative_keys 418 | ) 419 | # We shift the relative logits instead of the positional encoding matrix as 420 | # described in Appendix B of the paper (https://arxiv.org/pdf/1901.02860.pdf). 421 | relative_logits = relative_shift( 422 | relative_logits, attention_length=content_logits.shape[-1], causal=causal 423 | ) 424 | assert content_logits.shape == relative_logits.shape 425 | return content_logits + relative_logits 426 | 427 | 428 | def compute_attention_with_noisy_relative_encodings( 429 | queries: chex.Array, 430 | keys: chex.Array, 431 | noise_max_length: int, 432 | randomize_both_sides: bool = False, 433 | max_time: int = 10_000, 434 | causal: bool = False, 435 | ) -> chex.Array: 436 | """Returns attention with *noisy* relative positional encodings. 437 | 438 | This code follows what is described in the TransformerXL paper. 439 | https://arxiv.org/pdf/1901.02860.pdf 440 | However, in this version, the base positional encodings R (which are then 441 | shifted), are randomly sampled and ordered from a wider range than the 442 | sequence length. 443 | 444 | Args: 445 | queries: The queries used for attention. Shape (b, t, h, d). 446 | keys: The keys used for attention. Shape (b, T, h, d). 447 | noise_max_length: The maximum length used to sample the encodings. 
448 | randomize_both_sides: Whether to sample the encodings on the left and on the 449 | right of the current token, or just sample from the left and take the 450 | inverted ones for the right part. 451 | max_time: Constant used to scale position by in the sin/cos encodings. 452 | causal: Whether to use causal attention when shifting the relative logits. 453 | 454 | Returns: 455 | The attention logits. Shape (b, h, t, T). 456 | """ 457 | batch_size, k_seq_len, num_heads, num_hiddens = keys.shape 458 | hiddens = num_hiddens * num_heads 459 | 460 | # First compute the content logits. 461 | content_bias = hk.get_parameter( 462 | name='relpos_contentbias', 463 | shape=[num_heads, num_hiddens], 464 | init=hk.initializers.RandomNormal(stddev=0.02), 465 | ) 466 | content_logits = jnp.einsum('bthd,bThd->bhtT', queries + content_bias, keys) 467 | 468 | # Select random indexes. 469 | # The indexes are in the range 470 | # [-noise_max_length + 1, noise_max_length - 1] 471 | right_indexes = jrandom.choice( 472 | hk.next_rng_key(), 473 | jnp.arange(1, noise_max_length), 474 | shape=(k_seq_len - 1,), 475 | replace=False, 476 | ) 477 | right_indexes = jnp.sort(right_indexes) 478 | if randomize_both_sides: 479 | left_indexes = jrandom.choice( 480 | hk.next_rng_key(), 481 | jnp.arange(start=-noise_max_length + 1, stop=0), 482 | shape=(k_seq_len,), 483 | replace=False, 484 | ) 485 | left_indexes = jnp.sort(left_indexes) 486 | else: 487 | left_indexes = -right_indexes[::-1] 488 | # The leftmost index is required by position_embedding.relative_shift. 489 | left_indexes = jnp.concatenate([jnp.zeros((1,)), left_indexes]) 490 | zero_index = jnp.zeros((1,)) 491 | indexes = jnp.concatenate([left_indexes, zero_index, right_indexes]) 492 | # We shift the indexes to the range [0, 2*noise_max_length-1], since this 493 | # will be the range of the sin/cos. In this array, the value at index 494 | # noise_max_length is the sin/cos encoding at position 0, which is exactly 495 | # what we want: when doing relative attention, the token should have a fixed 496 | # encoding of position 0 for its own position. 497 | indexes += noise_max_length 498 | indexes = jnp.array(indexes, dtype=jnp.int32) 499 | 500 | positional_encodings = sinusoid_position_encoding( 501 | sequence_length=noise_max_length, 502 | hidden_size=hiddens, 503 | max_timescale=max_time, 504 | ) 505 | positional_encodings = jnp.array(positional_encodings, dtype=jnp.float32) 506 | positional_encodings = positional_encodings[indexes] 507 | positional_encodings = jnp.broadcast_to( 508 | positional_encodings, (batch_size,) + positional_encodings.shape 509 | ) 510 | relative_keys = hk.Linear(hiddens, with_bias=False)(positional_encodings) 511 | relative_keys = jnp.reshape( 512 | relative_keys, positional_encodings.shape[:-1] + (num_heads, num_hiddens) 513 | ) 514 | 515 | # Then compute the relative part. 516 | relative_bias = hk.get_parameter( 517 | name='relpos_relativebias', 518 | shape=[num_heads, num_hiddens], 519 | init=hk.initializers.RandomNormal(stddev=0.02), 520 | ) 521 | relative_logits = jnp.einsum( 522 | 'bthd,bThd->bhtT', queries + relative_bias, relative_keys 523 | ) 524 | # We shift the relative logits instead of the positional encoding matrix as 525 | # described in Appendix B of the paper (https://arxiv.org/pdf/1901.02860.pdf). 
526 | relative_logits = relative_shift( 527 | relative_logits, attention_length=content_logits.shape[-1], causal=causal 528 | ) 529 | assert content_logits.shape == relative_logits.shape 530 | return content_logits + relative_logits 531 | 532 | 533 | def _get_alibi_slopes(num_heads: int) -> list[float]: 534 | """Returns the slopes for the different attention heads. 535 | 536 | While this does not exactly match the description of the [ALiBi 537 | paper](https://arxiv.org/pdf/2108.12409.pdf), it corresponds to the [official 538 | implementation](https://github.com/ofirpress/attention_with_linear_biases/blob/a06526fbfe557f9148e414b8569dcb97c7b182ba/fairseq/models/transformer.py#L742). 539 | 540 | Args: 541 | num_heads: The number of attention heads to create slopes for. 542 | """ 543 | 544 | def get_slopes_power_of_2(n): 545 | start = 2 ** (-(2 ** -(math.log2(n) - 3))) 546 | ratio = start 547 | return [start * ratio**i for i in range(n)] 548 | 549 | if math.log2(num_heads).is_integer(): 550 | return get_slopes_power_of_2(num_heads) 551 | else: 552 | closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) 553 | return ( 554 | get_slopes_power_of_2(closest_power_of_2) 555 | + _get_alibi_slopes(2 * closest_power_of_2)[0::2][ 556 | : num_heads - closest_power_of_2 557 | ] 558 | ) 559 | 560 | 561 | def compute_alibi_encodings_biases( 562 | attention_shape: tuple[int, ...] 563 | ) -> chex.Array: 564 | """Returns the biases following the ALiBi method. 565 | 566 | This code strictly follows what is described in the ALiBi paper. 567 | https://arxiv.org/pdf/2108.12409.pdf 568 | 569 | Args: 570 | attention_shape: The attention logits shape, without batch size, (h, t, T). 571 | 572 | Returns: 573 | The alibi biases, same shape as the input logits shape. 574 | """ 575 | num_heads, q_seq_len, k_seq_len = attention_shape 576 | 577 | # Since we do not use causal masking, the upper triangle of the matrix has to 578 | # be nonzero. Therefore, we set it equal to the lower triangle, but we also 579 | # add a constant factor of 0.5 to the lower triangle, to (arbitrarily) break 580 | # the symmetry (otherwise, the model cannot distinguish left and right). 581 | alibi = np.zeros((q_seq_len, k_seq_len)) 582 | alibi -= sum(np.tri(*alibi.shape, k=-i) for i in range(1, q_seq_len)) 583 | alibi -= sum(np.tri(*alibi.T.shape, k=-i).T for i in range(1, k_seq_len)) 584 | alibi += 0.5 * np.tri(*alibi.shape, k=-1) 585 | 586 | return alibi * jnp.array(_get_alibi_slopes(num_heads))[:, None, None] 587 | 588 | 589 | def compute_noisy_alibi_encodings_biases( 590 | attention_shape: tuple[int, ...], 591 | noise_max_length: int, 592 | randomize_both_sides: bool = False, 593 | ) -> chex.Array: 594 | """Returns the biases following the ALiBi method. 595 | 596 | This code strictly follows what is described in the [ALiBi 597 | paper](https://arxiv.org/pdf/2108.12409.pdf). 598 | However, in this version, the biases are randomly sampled and ordered from a 599 | wider range than the sequence length. 600 | 601 | Args: 602 | attention_shape: The attention logits shape, without batch size, (h, t, T). 603 | noise_max_length: The maximum length used to sample the encodings. 604 | randomize_both_sides: Whether to sample the encodings on the left and on the 605 | right of the current token or just sample from the left and take the 606 | inverted ones for the right part. 607 | 608 | Returns: 609 | The alibi biases, same shape as the input logits shape. 
610 | """ 611 | num_heads, q_seq_len, k_seq_len = attention_shape 612 | 613 | sample_positions = functools.partial( 614 | jrandom.choice, 615 | a=jnp.arange(1, noise_max_length), 616 | replace=False, 617 | ) 618 | 619 | if randomize_both_sides: 620 | right_positions = sample_positions( 621 | hk.next_rng_key(), shape=(k_seq_len - 1,) 622 | ) 623 | left_positions = sample_positions(hk.next_rng_key(), shape=(q_seq_len - 1,)) 624 | right_positions = -jnp.sort(right_positions) 625 | left_positions = jnp.sort(-left_positions) 626 | 627 | else: 628 | symmetric_positions = sample_positions( 629 | hk.next_rng_key(), shape=(max(q_seq_len, k_seq_len) - 1,) 630 | ) 631 | symmetric_positions = -jnp.sort(symmetric_positions) 632 | right_positions = symmetric_positions[: k_seq_len - 1] 633 | left_positions = jnp.flip(symmetric_positions)[: q_seq_len - 1] 634 | 635 | # Since we do not use causal masking, the upper triangle of the matrix has to 636 | # be nonzero. Therefore, we set it equal to the lower triangle if 637 | # `randomize_both_side` is `False` and to randomly sampled positions 638 | # otherwise, but we also add a constant factor of 0.5 to the lower triangle, 639 | # to (arbitrarily) break the symmetry (otherwise, the model cannot distinguish 640 | # left and right). 641 | left_positions += 0.5 642 | 643 | # We add a dummy value to make the dimensions work for 644 | # position_embedding.relative_shift. The value will be ignored. 645 | left_positions = jnp.concatenate((jnp.empty((1,)), left_positions)) 646 | 647 | positions = jnp.concatenate( 648 | (left_positions, jnp.zeros((1,)), right_positions) 649 | ) 650 | # position_embedding.relative_shift requires a four-dimensional tensor. 651 | positions = jnp.tile(positions, (1, 1, q_seq_len, 1)) 652 | 653 | alibi = relative_shift( 654 | positions, 655 | attention_length=k_seq_len, 656 | causal=False, 657 | ) 658 | alibi = jnp.squeeze(alibi, axis=(0, 1)) 659 | 660 | return alibi * jnp.array(_get_alibi_slopes(num_heads))[:, None, None] 661 | -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Transformer model.""" 17 | 18 | import dataclasses 19 | from typing import Any, Callable, Optional, Union 20 | 21 | from absl import logging 22 | import chex 23 | import haiku as hk 24 | import jax 25 | import jax.nn as jnn 26 | import jax.numpy as jnp 27 | 28 | from randomized_positional_encodings.models import positional_encodings as pos_encs_lib 29 | from randomized_positional_encodings.models import transformer_utils 30 | 31 | 32 | @chex.dataclass 33 | class TransformerConfig: 34 | """Hyperparameters used in the Transformer architectures.""" 35 | 36 | # The dimension of the first embedding. 
37 | embedding_dim: int = 64 38 | # The number of multi-head attention layers. 39 | num_layers: int = 5 40 | # The number of heads per layer. 41 | num_heads: int = 8 42 | # The number of hidden neurons per head. If None, it is set to be equal to 43 | # `embedding_dim // num_heads`. 44 | num_hiddens_per_head: Optional[int] = None 45 | # The probability that each element is discarded by the dropout modules. 46 | # None means dropout is not used at all. 47 | dropout_prob: Optional[float] = 0.1 48 | # The parameter initialization scale for the embeddings. 49 | emb_init_scale: float = 0.02 50 | # Whether to use the embeddings rather than raw inputs. 51 | use_embeddings: bool = True 52 | # Whether to use lookup-embeddings, in which case the inputs must be ints. 53 | use_lookup_embeddings: bool = False 54 | # Input vocabulary size, not needed if use_lookup_embeddings is False. 55 | input_vocab_size: Optional[int] = None 56 | # Whether to share embeddings between the Encoder and the Decoder. 57 | share_embeddings: bool = False 58 | # The size of the sliding attention window. See MultiHeadDotProductAttention. 59 | attention_window: Optional[int] = None 60 | # The positional encoding used with default sin/cos (Vaswani et al., 2017). 61 | positional_encodings: pos_encs_lib.PositionalEncodings = dataclasses.field( 62 | default_factory=lambda: pos_encs_lib.PositionalEncodings.SIN_COS 63 | ) 64 | # The parameters for the positional encodings, default sin/cos. 65 | positional_encodings_params: pos_encs_lib.PositionalEncodingsParams = ( 66 | dataclasses.field(default_factory=pos_encs_lib.SinCosParams) 67 | ) 68 | # How much larger the hidden layer of the feedforward network should be 69 | # compared to the `embedding_dim`. 70 | widening_factor: int = 4 71 | # Which activation function to use. 72 | activation_fn: Callable[[jax.Array], jax.Array] = jnn.relu 73 | # Add mask to make causal predictions. All the decoders use causal masking by 74 | # default, this option is only used in the encoder. This is quite unusual but 75 | # can still be useful in some rare cases. 76 | encoder_causal_masking: bool = False 77 | # Which token to use for the beginning of the string. None means an array 78 | # full of zeros will be used. 79 | bos_token: Optional[int] = None 80 | # Used by the chunked transformer. 81 | chunk_context_length: Optional[int] = None 82 | 83 | def __post_init__(self) -> None: 84 | """Runs after the config has been created.""" 85 | if self.num_hiddens_per_head is None: 86 | self.num_hiddens_per_head = self.embedding_dim // self.num_heads 87 | 88 | if self.positional_encodings is None: 89 | self.positional_encodings = pos_encs_lib.PositionalEncodings.SIN_COS 90 | self.positional_encodings_params = pos_encs_lib.SinCosParams() 91 | elif self.positional_encodings_params is None: 92 | raise ValueError('No parameters for positional encodings are passed.') 93 | elif not isinstance( 94 | self.positional_encodings, pos_encs_lib.PositionalEncodings 95 | ) or not isinstance( 96 | self.positional_encodings_params, pos_encs_lib.PositionalEncodingsParams 97 | ): 98 | raise ValueError( 99 | "The positional encodings passed are not of the right type. You're" 100 | ' probably passing strings rather than actual objects.' 
101 | ) 102 | 103 | 104 | class MultiHeadDotProductAttention(hk.Module): 105 | """Multi-head dot-product attention (Vaswani et al., 2017).""" 106 | 107 | def __init__( 108 | self, 109 | num_heads: int, 110 | num_hiddens_per_head: int, 111 | positional_encodings: Optional[pos_encs_lib.PositionalEncodings] = None, 112 | positional_encodings_params: Optional[ 113 | pos_encs_lib.PositionalEncodingsParams 114 | ] = None, 115 | attention_window: Optional[int] = None, 116 | name: Optional[str] = None, 117 | ) -> None: 118 | """Initializes the attention module. 119 | 120 | Args: 121 | num_heads: Number of heads to use. 122 | num_hiddens_per_head: Number of hidden neurons per head. 123 | positional_encodings: Which positional encodings to use in the attention. 124 | None means no positional encodings are applied to keys or queries. 125 | positional_encodings_params: Parameters for the positional encodings. 126 | attention_window: Size of the attention sliding window. None means no 127 | sliding window is used (or equivalently, window=full_attention_length). 128 | We attend only on attention_window tokens around a given query token. We 129 | attend to tokens before AND after the query token. If attention_window 130 | is even, we use the value +1. 131 | name: Name of the module. 132 | """ 133 | super().__init__(name=name) 134 | self._num_heads = num_heads 135 | self._num_hiddens_per_head = num_hiddens_per_head 136 | self._positional_encodings = positional_encodings 137 | self._attention_window = attention_window 138 | self._positional_encodings_params = ( 139 | positional_encodings_params # pytype: disable=annotation-type-mismatch 140 | ) 141 | 142 | def __call__( 143 | self, 144 | inputs_q: chex.Array, 145 | inputs_kv: chex.Array, 146 | mask: Optional[chex.Array] = None, 147 | causal: bool = False, 148 | ) -> chex.Array: 149 | """Returns the output of the multi-head attention.""" 150 | batch_size, sequence_length, embedding_size = inputs_q.shape 151 | 152 | num_hiddens = self._num_hiddens_per_head * self._num_heads 153 | q = hk.Linear(num_hiddens, with_bias=False)(inputs_q) 154 | k = hk.Linear(num_hiddens, with_bias=False)(inputs_kv) 155 | v = hk.Linear(num_hiddens, with_bias=False)(inputs_kv) 156 | # The second (sequence) dimension is undefined since it can differ between 157 | # queries and keys/values when decoding. Also checking that the inputs have 158 | # the same batch size as the reshape below does not guarantee a failure if 159 | # they are different. 160 | chex.assert_equal_shape_prefix([inputs_q, inputs_kv], prefix_len=1) 161 | new_shape = (batch_size, -1, self._num_heads, self._num_hiddens_per_head) 162 | q = jnp.reshape(q, new_shape) 163 | k = jnp.reshape(k, new_shape) 164 | v = jnp.reshape(v, new_shape) 165 | 166 | # Let b=batch_size, t=seq_len, h=num_heads, and d=num_hiddens_per_head. 167 | if self._positional_encodings == pos_encs_lib.PositionalEncodings.RELATIVE: 168 | # We type hint the params to match the if statement, for pytype. 169 | self._positional_encodings_params: pos_encs_lib.RelativeParams 170 | attention = pos_encs_lib.compute_attention_with_relative_encodings( 171 | q, k, self._positional_encodings_params.max_time, causal=causal 172 | ) 173 | elif ( 174 | self._positional_encodings 175 | == pos_encs_lib.PositionalEncodings.NOISY_RELATIVE 176 | ): 177 | if causal: 178 | raise NotImplementedError( 179 | 'Noisy positional encodings not implemented for causal attention.' 180 | ) 181 | # We type hint the params to match the if statement, for pytype. 
182 | self._positional_encodings_params: pos_encs_lib.NoisyRelativeParams 183 | attention = pos_encs_lib.compute_attention_with_noisy_relative_encodings( 184 | q, 185 | k, 186 | max_time=self._positional_encodings_params.max_time, 187 | noise_max_length=self._positional_encodings_params.noise_max_length, 188 | randomize_both_sides=self._positional_encodings_params.randomize_both_sides, 189 | causal=causal, 190 | ) 191 | else: 192 | if self._positional_encodings == pos_encs_lib.PositionalEncodings.ROTARY: 193 | q = pos_encs_lib.apply_rotary_encoding( 194 | q, position=jnp.arange(q.shape[1])[None, :] 195 | ) 196 | k = pos_encs_lib.apply_rotary_encoding( 197 | k, position=jnp.arange(k.shape[1])[None, :] 198 | ) 199 | elif ( 200 | self._positional_encodings 201 | == pos_encs_lib.PositionalEncodings.NOISY_ROTARY 202 | ): 203 | # We type hint the params to match the if statement, for pytype. 204 | self._positional_encodings_params: pos_encs_lib.NoisyRotaryParams 205 | noise_max_length = self._positional_encodings_params.noise_max_length 206 | # WARNING: This only works with self-attention, ie q.shape==k.shape. 207 | rng = hk.next_rng_key() 208 | q = pos_encs_lib.apply_rotary_encoding( 209 | q, 210 | position=jnp.arange(noise_max_length)[None, :], 211 | noisy=True, 212 | rng=rng, 213 | ) 214 | k = pos_encs_lib.apply_rotary_encoding( 215 | k, 216 | position=jnp.arange(noise_max_length)[None, :], 217 | noisy=True, 218 | rng=rng, 219 | ) 220 | attention = jnp.einsum('bthd,bThd->bhtT', q, k) 221 | attention *= 1.0 / jnp.sqrt(self._num_hiddens_per_head) 222 | 223 | # ALiBi encodings are not scaled with the 1 / sqrt(d_k) factor. 224 | if self._positional_encodings == pos_encs_lib.PositionalEncodings.ALIBI: 225 | attention += pos_encs_lib.compute_alibi_encodings_biases( 226 | attention.shape[1:] 227 | ) 228 | if ( 229 | self._positional_encodings 230 | == pos_encs_lib.PositionalEncodings.NOISY_ALIBI 231 | ): 232 | # We type hint the params to match the if statement, for pytype. 233 | self._positional_encodings_params: pos_encs_lib.NoisyAlibiParams 234 | attention += pos_encs_lib.compute_noisy_alibi_encodings_biases( 235 | attention.shape[1:], 236 | noise_max_length=self._positional_encodings_params.noise_max_length, 237 | randomize_both_sides=self._positional_encodings_params.randomize_both_sides, 238 | ) 239 | 240 | if self._attention_window is not None: 241 | # We compute the sliding attention by just applying a mask on the values 242 | # that are outside our window. 
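# For example (illustrative): with sequence_length=5 and attention_window=3, the mask below is tridiagonal, so each query attends only to itself and to one token on either side; entries outside the window are replaced by the minimum float32 value before the softmax.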
243 | attention_mask = transformer_utils.compute_sliding_window_mask( 244 | sequence_length, self._attention_window 245 | ) 246 | attention = jnp.where( 247 | attention_mask, attention, jnp.finfo(jnp.float32).min 248 | ) 249 | 250 | if mask is not None: 251 | attention = jnp.where(mask, attention, jnp.finfo(jnp.float32).min) 252 | 253 | normalized_attention = jnn.softmax(attention) 254 | 255 | output = jnp.einsum('bhtT,bThd->bthd', normalized_attention, v) 256 | output = jnp.reshape(output, (batch_size, sequence_length, num_hiddens)) 257 | return hk.Linear(embedding_size, with_bias=False)(output) 258 | 259 | 260 | class TransformerInit(hk.Module): 261 | """Helper class to avoid repeating the same __init__.""" 262 | 263 | def __init__(self, config: TransformerConfig): 264 | """Initializes the module.""" 265 | super().__init__() 266 | self._config = config 267 | if self._config.use_lookup_embeddings and self._config.bos_token is None: 268 | raise ValueError("Can't use lookup embeddings with a zero bos_token.") 269 | 270 | 271 | class TransformerEmbedder(TransformerInit): 272 | """A module to embed sequences and add positional encodings if needed.""" 273 | 274 | def embed_sequences(self, sequences: chex.Array) -> chex.Array: 275 | """Returns embedded sequences, following a linear operation or hk.Embed.""" 276 | embs_init = hk.initializers.TruncatedNormal( 277 | stddev=self._config.emb_init_scale 278 | ) 279 | if self._config.use_lookup_embeddings: 280 | embeddings_layer = hk.Embed( 281 | vocab_size=self._config.input_vocab_size, 282 | embed_dim=self._config.embedding_dim, 283 | lookup_style=hk.EmbedLookupStyle.ARRAY_INDEX, 284 | w_init=embs_init, 285 | ) 286 | integer_sequences = jnp.argmax(sequences, axis=-1) 287 | embeddings = embeddings_layer(integer_sequences) 288 | else: 289 | embeddings_layer = hk.Linear( 290 | self._config.embedding_dim, 291 | with_bias=False, 292 | w_init=embs_init, 293 | ) 294 | embeddings = embeddings_layer(sequences) 295 | 296 | embeddings *= jnp.sqrt(self._config.embedding_dim) 297 | return embeddings 298 | 299 | def add_positional_encodings(self, embeddings: chex.Array) -> chex.Array: 300 | """Returns new embeddings, which have been added positional encodings. 301 | 302 | The shape of the returned array is (B, T, E), where E is the dimension of 303 | the embeddings (if any are used, otherwise E = F). 304 | 305 | Args: 306 | embeddings: A batch of embeddings, of shape (B, T, F). 307 | """ 308 | chex.assert_rank(embeddings, 3) 309 | 310 | _, sequence_length, embedding_size = embeddings.shape 311 | 312 | pos_enc_params = self._config.positional_encodings_params 313 | if ( 314 | self._config.positional_encodings 315 | == pos_encs_lib.PositionalEncodings.SIN_COS 316 | ): 317 | pos_enc_params: pos_encs_lib.SinCosParams 318 | pos_encodings = pos_encs_lib.sinusoid_position_encoding( 319 | sequence_length=sequence_length, 320 | hidden_size=embedding_size, 321 | max_timescale=pos_enc_params.max_time, 322 | ) 323 | h = embeddings + pos_encodings 324 | if self._config.dropout_prob is not None: 325 | h = hk.dropout(hk.next_rng_key(), self._config.dropout_prob, h) 326 | elif ( 327 | self._config.positional_encodings 328 | == pos_encs_lib.PositionalEncodings.NOISY_SIN_COS 329 | ): 330 | pos_enc_params: pos_encs_lib.NoisySinCosParams 331 | if pos_enc_params.noise_max_length > pos_enc_params.max_time: 332 | logging.warning( 333 | ( 334 | 'noise_max_length=%i is larger than max_time=%i, some ' 335 | 'positional encodings will be equal.' 
336 | ), 337 | pos_enc_params.noise_max_length, 338 | pos_enc_params.max_time, 339 | ) 340 | pos_encodings = pos_encs_lib.sinusoid_position_encoding( 341 | sequence_length=pos_enc_params.noise_max_length, 342 | hidden_size=embedding_size, 343 | max_timescale=pos_enc_params.max_time, 344 | ) 345 | pos_encodings = jnp.array(pos_encodings) 346 | pos_encodings = pos_encs_lib.noisy_fixed_positional_encodings( 347 | pos_encodings, sequence_length 348 | ) 349 | h = embeddings + pos_encodings 350 | if self._config.dropout_prob is not None: 351 | h = hk.dropout(hk.next_rng_key(), self._config.dropout_prob, h) 352 | elif ( 353 | self._config.positional_encodings 354 | == pos_encs_lib.PositionalEncodings.LEARNT 355 | ): 356 | pos_enc_params: pos_encs_lib.LearntParams 357 | pos_encodings = jnp.arange(sequence_length) 358 | pos_encodings = hk.Embed( 359 | vocab_size=pos_enc_params.max_sequence_length, 360 | embed_dim=embedding_size, 361 | )(pos_encodings) 362 | h = embeddings + pos_encodings 363 | if self._config.dropout_prob is not None: 364 | h = hk.dropout(hk.next_rng_key(), self._config.dropout_prob, h) 365 | elif ( 366 | self._config.positional_encodings 367 | == pos_encs_lib.PositionalEncodings.NOISY_LEARNT 368 | ): 369 | pos_enc_params: pos_encs_lib.NoisyLearntParams 370 | pos_encodings = jnp.arange(pos_enc_params.noise_max_length) 371 | pos_encodings = hk.Embed( 372 | vocab_size=pos_enc_params.noise_max_length, embed_dim=embedding_size 373 | )(pos_encodings) 374 | pos_encodings = pos_encs_lib.noisy_fixed_positional_encodings( 375 | pos_encodings, sequence_length 376 | ) 377 | h = embeddings + pos_encodings 378 | if self._config.dropout_prob is not None: 379 | h = hk.dropout(hk.next_rng_key(), self._config.dropout_prob, h) 380 | else: 381 | h = embeddings 382 | return h 383 | 384 | 385 | class TransformerEncoder(TransformerInit): 386 | """Transformer Encoder (Vaswani et al., 2017).""" 387 | 388 | def __call__(self, inputs: jnp.ndarray) -> chex.Array: 389 | """Returns the transformer encoder output, shape [B, T, E].""" 390 | batch_size, sequence_length = inputs.shape[:2] 391 | # Embeds the inputs, adds positional encodings. 392 | embedder = TransformerEmbedder(self._config) 393 | embeddings = embedder.embed_sequences(inputs) 394 | h = embedder.add_positional_encodings(embeddings) 395 | 396 | # The causal mask is shared across heads. 397 | if self._config.encoder_causal_masking: 398 | causal_mask = jnp.tril( 399 | jnp.ones((batch_size, 1, sequence_length, sequence_length)) 400 | ) 401 | else: 402 | causal_mask = None 403 | 404 | for _ in range(self._config.num_layers): 405 | attention = MultiHeadDotProductAttention( 406 | num_heads=self._config.num_heads, 407 | num_hiddens_per_head=self._config.num_hiddens_per_head, 408 | positional_encodings=self._config.positional_encodings, 409 | positional_encodings_params=self._config.positional_encodings_params, 410 | attention_window=self._config.attention_window, 411 | )( 412 | inputs_q=h, 413 | inputs_kv=h, 414 | mask=causal_mask, 415 | causal=self._config.encoder_causal_masking, 416 | ) 417 | if self._config.dropout_prob is not None: 418 | attention = hk.dropout( 419 | hk.next_rng_key(), self._config.dropout_prob, attention 420 | ) 421 | attention = transformer_utils.layer_norm(h + attention) 422 | 423 | # Position-wise feedforward network. 
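# For example (illustrative, using the config defaults above): with embedding_dim=64 and widening_factor=4, each position is mapped 64 -> 256 -> activation_fn -> 64 before the dropout, residual connection, and layer norm below.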
424 | h = hk.Linear(self._config.embedding_dim * self._config.widening_factor)( 425 | attention 426 | ) 427 | h = self._config.activation_fn(h) 428 | h = hk.Linear(self._config.embedding_dim)(h) 429 | 430 | if self._config.dropout_prob is not None: 431 | h = hk.dropout(hk.next_rng_key(), self._config.dropout_prob, h) 432 | h = transformer_utils.layer_norm(h + attention) 433 | return h 434 | 435 | 436 | class ChunkedTransformerEncoder(TransformerInit): 437 | """A Transformer encoder that can handle large histories via chunks. 438 | 439 | We chunk the inputs, moving from a shape (B, T, F) to a shape (B, T/C, C, F), 440 | where C is the length of the chunk. Note that T must be a multiple of C for it 441 | to work. The chunks are then passed independently to the encoder, and all the 442 | outputs are then concatenated together, to return a shape (B, T, E), where E 443 | is the embedding_dim of the TransformerEncoder, see class above. 444 | """ 445 | 446 | def __call__(self, inputs: chex.Array) -> jnp.ndarray: 447 | """Calls the chunked transformer encoder.""" 448 | batch_size, history_len = inputs.shape[:2] 449 | inputs = transformer_utils.chunk_sequences( 450 | inputs, chunk_length=self._config.chunk_context_length 451 | ) 452 | outputs = TransformerEncoder(self._config)(inputs=inputs) 453 | return jnp.reshape(outputs, (batch_size, history_len, outputs.shape[-1])) 454 | 455 | 456 | CallableTransformer = Union[ 457 | ChunkedTransformerEncoder, 458 | TransformerEncoder, 459 | ] 460 | 461 | 462 | def make_transformer( 463 | output_size: int, 464 | transformer_module: type[CallableTransformer], 465 | return_all_outputs: bool = False, 466 | **transformer_kwargs, 467 | ) -> Any: 468 | """Returns a transformer predict function.""" 469 | 470 | if 'positional_encodings' in transformer_kwargs: 471 | if isinstance(transformer_kwargs['positional_encodings'], str): 472 | transformer_kwargs['positional_encodings_params'] = ( 473 | pos_encs_lib.POS_ENC_PARAMS_TABLE[ 474 | transformer_kwargs['positional_encodings'] 475 | ](**transformer_kwargs['positional_encodings_params']) 476 | ) 477 | transformer_kwargs['positional_encodings'] = pos_encs_lib.POS_ENC_TABLE[ 478 | transformer_kwargs['positional_encodings'] 479 | ] 480 | 481 | config = TransformerConfig(**transformer_kwargs) 482 | 483 | def transformer(*args, **kwargs) -> chex.Array: 484 | output = transformer_module(config=config)(*args, **kwargs) 485 | if not return_all_outputs: 486 | output = output[:, -1, :] 487 | return hk.Linear(output_size)(output) 488 | 489 | return transformer 490 | -------------------------------------------------------------------------------- /models/transformer_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Utils for the transformer architectures.""" 17 | 18 | import chex 19 | import haiku as hk 20 | import jax.numpy as jnp 21 | 22 | 23 | def layer_norm(x: chex.Array) -> chex.Array: 24 | return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x) 25 | 26 | 27 | def chunk_sequences(sequences: chex.Array, chunk_length: int) -> chex.Array: 28 | """Chunks an array of sequences, on the second (time) dimension. 29 | 30 | Args: 31 | sequences: An array of sequences, of shape (B, T, F). 32 | chunk_length: The length of each chunk. 33 | 34 | Returns: 35 | An array of shape (B * T // chunk_length, chunk_length, F) 36 | Raises: 37 | ValueError if T is not a multiple of chunk_length. 38 | """ 39 | chex.assert_rank(sequences, 3) 40 | batch_size, history_len, num_features = sequences.shape 41 | if history_len < chunk_length: 42 | context_length = history_len 43 | elif history_len % chunk_length == 0: 44 | context_length = chunk_length 45 | else: 46 | raise ValueError( 47 | 'The history length should be a multiple of the context length. Got' 48 | f' history_length={history_len} and' 49 | f' context_length={chunk_length}' 50 | ) 51 | 52 | history_batch_size = history_len // context_length 53 | return jnp.reshape( 54 | sequences, 55 | (batch_size * history_batch_size, context_length, num_features), 56 | ) 57 | 58 | 59 | def compute_sliding_window_mask( 60 | sequence_length: int, attention_window: int 61 | ) -> chex.Array: 62 | """Returns a k-diagonal mask for a sliding window. 63 | 64 | Args: 65 | sequence_length: The length of the sequence, which will determine the shape 66 | of the output. 67 | attention_window: The size of the sliding window. 68 | 69 | Returns: 70 | A symmetric matrix of shape (sequence_length, sequence_length), 71 | attention_window-diagonal, with ones on the diagonal and on all the 72 | upper/lower diagonals up to attention_window // 2. 73 | 74 | Raises: 75 | ValueError if attention_window is <= 0. 76 | """ 77 | if attention_window <= 0: 78 | raise ValueError( 79 | f'The attention window should be > 0. Got {attention_window}.'
80 | ) 81 | 82 | if attention_window == 1: 83 | return jnp.eye(sequence_length, sequence_length) 84 | 85 | attention_mask = jnp.sum( 86 | jnp.stack( 87 | [ 88 | jnp.eye(sequence_length, sequence_length, k=k, dtype=jnp.int32) 89 | for k in range(1, attention_window // 2 + 1) 90 | ] 91 | ), 92 | axis=0, 93 | ) 94 | attention_mask = attention_mask + jnp.transpose(attention_mask) 95 | attention_mask += jnp.eye(sequence_length, sequence_length) 96 | return attention_mask 97 | -------------------------------------------------------------------------------- /overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/randomized_positional_encodings/f3ef05a7dc2c4f0f71d4efaae3512ccd158b00d7/overview.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | chex 3 | dm-haiku 4 | dm-tree 5 | git+https://github.com/deepmind/einshape 6 | jax 7 | numpy 8 | optax 9 | tqdm 10 | typing_extensions 11 | -------------------------------------------------------------------------------- /tasks/cs/binary_addition.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Add two binary numbers.""" 17 | 18 | import random 19 | from typing import Sequence 20 | 21 | import chex 22 | import jax.nn as jnn 23 | import jax.numpy as jnp 24 | import numpy as np 25 | 26 | from randomized_positional_encodings.tasks import task 27 | 28 | 29 | def numbers_to_variable_length_binary( 30 | numbers: Sequence[int], 31 | lengths: Sequence[int], 32 | little_endian: bool = True, 33 | ) -> list[list[int]]: 34 | """Returns the binary notation of a certain length for a sequence of numbers. 35 | 36 | Args: 37 | numbers: The numbers to be converted to binary. 38 | lengths: The lengths of the binary representations (every number uses its 39 | own length). This argument has no effect if the binary representation is 40 | longer than the specified length. 41 | little_endian: Whether to use little- or big-endian notation. 42 | """ 43 | binary_strings = [f'{num:b}'.zfill(len) for num, len in zip(numbers, lengths)] 44 | 45 | if little_endian: 46 | binary_strings = [bin[::-1] for bin in binary_strings] 47 | 48 | return [list(map(int, bin)) for bin in binary_strings] 49 | 50 | 51 | def numbers_to_fixed_length_binary( 52 | numbers: Sequence[int], 53 | length: int, 54 | little_endian: bool = True, 55 | ) -> list[list[int]]: 56 | """Returns the binary notation of a certain length for a sequence of numbers. 57 | 58 | Args: 59 | numbers: The numbers to be converted to binary. 60 | length: The length of the binary representations (all numbers use the same 61 | length). 
61 | length). This argument has no effect if the binary representation is 62 | longer than the specified length. 63 | little_endian: Whether to use little- or big-endian notation. 64 | """ 65 | return numbers_to_variable_length_binary( 66 | numbers=numbers, 67 | lengths=[length] * len(numbers), 68 | little_endian=little_endian, 69 | ) 70 | 71 | 72 | def expression_from_numbers( 73 | numbers_n: Sequence[list[int]], 74 | numbers_m: Sequence[list[int]], 75 | ) -> list[list[int]]: 76 | """Returns an expression with a placeholder value to denote the operation.""" 77 | return [n + [2] + m for n, m in zip(numbers_n, numbers_m)] 78 | 79 | 80 | class BinaryAddition(task.GeneralizationTask): 81 | """A task with the goal of summing two numbers in binary (little-endian). 82 | 83 | The input is a string of the form `first_number+second_number` in 84 | (little-endian) binary notation (e.g., `01001+011`). The goal of the agent is 85 | to output the result, also in (little-endian) binary form (i.e., in the 86 | example `18 + 6 = 24 = 00011`). The output is padded with 0s to match the 87 | input length, and the end of the sum is denoted with a termination token 88 | (i.e., the output has values in `{0, 1, 2}`). 89 | 90 | Examples: 91 | 001 + 01101 = 0101120000 (4 + 22 = 26) 92 | 1001 + 000001 = 100101200000 (9 + 32 = 41) 93 | """ 94 | 95 | def _sample_expressions_and_results( 96 | self, 97 | batch_size: int, 98 | length: int, 99 | ) -> tuple[Sequence[list[int]], Sequence[list[int]]]: 100 | """Samples pairs of numbers and sums them in (little-endian) binary. 101 | 102 | We use Python's bignums, which can represent arbitrary-precision integers, 103 | to perform addition of two potentially very large values (roughly of the 104 | size `2 ** (length // 2)`). 105 | 106 | Args: 107 | batch_size: The number of expressions and results to sample. 108 | length: The length of the input expression containing the two numbers and 109 | the separation token. 110 | 111 | Returns: 112 | The expression and the sum of the two numbers. The expression has the 113 | format: `[first_number, 2, second_number]`, where the numbers are in 114 | (little-endian) binary notation. The sum is also in (little-endian) binary 115 | notation, without leading (i.e., ending) zeros. 116 | """ 117 | # If `length <= 2`, we just sample a binary value and return it (without 118 | # leading zeros in little-endian notation). 119 | if length <= 2: 120 | # Since `length <= 2`, we can use `np.random` without overflow errors. 121 | numbers = np.random.randint(0, 2**length - 1, size=(batch_size)) 122 | expressions = numbers_to_fixed_length_binary(numbers, length) 123 | results = numbers_to_fixed_length_binary(numbers, 0) 124 | return expressions, results 125 | 126 | # We only use `length - 1` tokens for the two values to account for the `+`.
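# For example (illustrative): with length=9 we might draw length_n=3 and length_m=5, giving the expression [0, 0, 1, 2, 0, 1, 1, 0, 1], i.e. `001+01101` with the token `2` standing in for `+`.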
127 | length_n = np.random.randint(1, length - 1, size=(batch_size,)) 128 | length_m = length - 1 - length_n 129 | 130 | integer_n = [random.randint(1, 2 ** int(len_n) - 1) for len_n in length_n] 131 | integer_m = [random.randint(1, 2 ** int(len_m) - 1) for len_m in length_m] 132 | 133 | binary_n = numbers_to_variable_length_binary(integer_n, length_n) 134 | binary_m = numbers_to_variable_length_binary(integer_m, length_m) 135 | 136 | expressions = expression_from_numbers(binary_n, binary_m) 137 | 138 | integer_sum = list(map(sum, zip(integer_n, integer_m))) 139 | results = numbers_to_fixed_length_binary(integer_sum, length=0) 140 | 141 | return expressions, results 142 | 143 | def sample_batch( 144 | self, 145 | rng: chex.PRNGKey, 146 | batch_size: int, 147 | length: int, 148 | ) -> task.Batch: 149 | """Returns a batch of binary additions and their results.""" 150 | del rng 151 | 152 | expressions, results = self._sample_expressions_and_results( 153 | batch_size=batch_size, length=length 154 | ) 155 | # Append the termination token to the result and pad the result with zeros 156 | # to match the output length (accounting for the termination token). 157 | results = [res + [2] + [0] * (length - len(res)) for res in results] 158 | 159 | expressions = jnp.array(expressions, dtype=jnp.int32) 160 | results = jnp.array(results, dtype=jnp.int32) 161 | 162 | return { 163 | 'input': jnn.one_hot(expressions, self.input_size), 164 | 'output': jnn.one_hot(results, self.output_size), 165 | } 166 | 167 | @property 168 | def input_size(self) -> int: 169 | """Returns the input size for the models.""" 170 | return 3 171 | 172 | @property 173 | def output_size(self) -> int: 174 | """Returns the output size for the models.""" 175 | return 3 176 | 177 | def output_length(self, input_length: int) -> int: 178 | return input_length + 1 179 | 180 | def accuracy_mask(self, target: chex.Array) -> chex.Array: 181 | """Computes a mask that ignores everything after the termination token. 182 | 183 | Args: 184 | target: Target tokens of shape `(batch_size, output_length, output_size)`. 185 | 186 | Returns: 187 | The mask of shape `(batch_size, output_length)`. 188 | """ 189 | batch_size, length, _ = target.shape 190 | termination_indices = jnp.argmax( 191 | jnp.argmax(target, axis=-1), 192 | axis=-1, 193 | keepdims=True, 194 | ) 195 | indices = jnp.tile(jnp.arange(length), (batch_size, 1)) 196 | return indices <= termination_indices 197 | -------------------------------------------------------------------------------- /tasks/cs/binary_multiplication.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Multiply two binary numbers.""" 17 | 18 | import random 19 | from typing import Sequence 20 | 21 | import chex 22 | import jax.nn as jnn 23 | import jax.numpy as jnp 24 | import numpy as np 25 | 26 | from randomized_positional_encodings.tasks import task 27 | from randomized_positional_encodings.tasks.cs import binary_addition 28 | 29 | 30 | class BinaryMultiplication(task.GeneralizationTask): 31 | """A task with the goal of multiplying two numbers in binary (little-endian). 32 | 33 | The input is a string of the form `first_number*second_number` in 34 | (little-endian) binary notation (e.g., `01001*011`). The goal of the agent is 35 | to output the result, also in (little-endian) binary form (i.e., in the 36 | example `18 * 6 = 108 = 0011011`). The output is padded with 0s to match the 37 | input length, and the end of the product is denoted with a termination token 38 | (i.e., the output has values in `{0, 1, 2}`). 39 | 40 | Examples: 41 | 001 * 01101 = 000110120 (4 * 22 = 88) 42 | 1001 * 000001 = 00000100120 (9 * 32 = 288) 43 | """ 44 | 45 | def _sample_expressions_and_results( 46 | self, 47 | batch_size: int, 48 | length: int, 49 | ) -> tuple[Sequence[list[int]], Sequence[list[int]]]: 50 | """Samples pairs of numbers and multiplies them in (little-endian) binary. 51 | 52 | We use Python's bignums, which can represent arbitrary-precision integers, 53 | to perform multiplication of two potentially very large values (roughly of 54 | the size `2 ** (length // 2)`). 55 | 56 | Args: 57 | batch_size: The number of expressions and results to sample. 58 | length: The length of the input expression containing the two numbers and 59 | the separation token. 60 | 61 | Returns: 62 | The expression and the product of the two numbers. The expression has the 63 | format: `[first_number, 2, second_number]`, where the numbers are in 64 | (little-endian) binary notation. The product is also in (little-endian) 65 | binary notation, without leading (i.e., ending) zeros. 66 | """ 67 | # If `length <= 2`, we just sample a binary sequence for the expression and 68 | # arbitrarily set the result to a fixed value (`[]` for `length == 1` and 69 | # `[0]` for `length == 2`) to maintain the invariant that the result has 70 | # length at most `length - 1`. 71 | if length <= 2: 72 | # Since `length <= 2`, we can use `np.random` without overflow errors. 73 | numbers = np.random.randint(0, 2**length - 1, size=(batch_size)) 74 | expressions = binary_addition.numbers_to_fixed_length_binary( 75 | numbers, length 76 | ) 77 | return expressions, [[0] * (length - 1)] * batch_size 78 | 79 | # We only use `length - 1` tokens for the two values to account for the `*`.
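# For example (illustrative): an n-digit times an m-digit binary number has at most n + m digits, so with n + m = length - 1 the product plus its termination token always fits in the `length` output positions.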
80 | length_n = np.random.randint(1, length - 1, size=(batch_size,)) 81 | length_m = length - 1 - length_n 82 | 83 | integer_n = [random.randint(1, 2 ** int(len_n) - 1) for len_n in length_n] 84 | integer_m = [random.randint(1, 2 ** int(len_m) - 1) for len_m in length_m] 85 | 86 | binary_n = binary_addition.numbers_to_variable_length_binary( 87 | integer_n, length_n 88 | ) 89 | binary_m = binary_addition.numbers_to_variable_length_binary( 90 | integer_m, length_m 91 | ) 92 | 93 | expressions = binary_addition.expression_from_numbers(binary_n, binary_m) 94 | 95 | integer_prod = [int_n * int_m for int_n, int_m in zip(integer_n, integer_m)] 96 | results = binary_addition.numbers_to_fixed_length_binary( 97 | integer_prod, length=0 98 | ) 99 | 100 | return expressions, results 101 | 102 | def sample_batch( 103 | self, 104 | rng: chex.PRNGKey, 105 | batch_size: int, 106 | length: int, 107 | ) -> task.Batch: 108 | """Returns a batch of binary multiplications and their results.""" 109 | del rng 110 | 111 | expressions, results = self._sample_expressions_and_results( 112 | batch_size=batch_size, length=length 113 | ) 114 | # Append the termination token to the result and pad the result with zeros 115 | # to match the output length (accounting for the termination token). The 116 | # binary representation of the result will have at most length 117 | # `#(first_number) + #(second_number)`, where #() denotes the number of 118 | # digits of the binary notation. Since we use the token `2` to separate the 119 | # two numbers in the expression, the result will have length at most 120 | # `length - 1`, and thus by appending the termination token above it will 121 | # have length at most `length`, as desired. 122 | results = [res + [2] + [0] * (length - 1 - len(res)) for res in results] 123 | 124 | expressions = jnp.array(expressions, dtype=jnp.int32) 125 | results = jnp.array(results, dtype=jnp.int32) 126 | 127 | return { 128 | 'input': jnn.one_hot(expressions, self.input_size), 129 | 'output': jnn.one_hot(results, self.output_size), 130 | } 131 | 132 | @property 133 | def input_size(self) -> int: 134 | """Returns the input size for the models.""" 135 | return 3 136 | 137 | @property 138 | def output_size(self) -> int: 139 | """Returns the output size for the models.""" 140 | return 3 141 | 142 | def output_length(self, input_length: int) -> int: 143 | return input_length 144 | 145 | def accuracy_mask(self, target: chex.Array) -> chex.Array: 146 | """Computes a mask that ignores everything after the termination token. 147 | 148 | Args: 149 | target: Target tokens of shape `(batch_size, output_length, output_size)`. 150 | 151 | Returns: 152 | The mask of shape `(batch_size, output_length)`. 153 | """ 154 | batch_size, length, _ = target.shape 155 | termination_indices = jnp.argmax( 156 | jnp.argmax(target, axis=-1), 157 | axis=-1, 158 | keepdims=True, 159 | ) 160 | indices = jnp.tile(jnp.arange(length), (batch_size, 1)) 161 | return indices <= termination_indices 162 | -------------------------------------------------------------------------------- /tasks/cs/bucket_sort.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Sort tokens from a fixed alphabet (i.e., bucket sort).""" 17 | 18 | import functools 19 | 20 | import chex 21 | import jax 22 | from jax import nn as jnn 23 | from jax import numpy as jnp 24 | from jax import random as jrandom 25 | 26 | from randomized_positional_encodings.tasks import task 27 | 28 | 29 | class BucketSort(task.GeneralizationTask): 30 | """A task with the goal of sorting tokens from a fixed alphabet. 31 | 32 | The input string is composed of tokens from a fixed-size alphabet, i.e., 33 | `{0, 1, ..., vocab_size - 1}`, and the goal is to return the sorted string (in 34 | lexicographically increasing order). 35 | 36 | Examples: 37 | 10204112 -> 00111224 (with `vocab_size = 5`) 38 | 1110001 -> 0001111 (with `vocab_size = 2`) 39 | """ 40 | 41 | def __init__(self, *args, vocab_size: int = 5, **kwargs) -> None: 42 | """Initializes the task. 43 | 44 | Args: 45 | *args: The args for the base task class. 46 | vocab_size: The size of the alphabet. 47 | **kwargs: The kwargs for the base task class. 48 | """ 49 | super().__init__(*args, **kwargs) 50 | self._vocab_size = vocab_size 51 | 52 | @functools.partial(jax.jit, static_argnums=(0, 2, 3)) 53 | def sample_batch( 54 | self, 55 | rng: chex.PRNGKey, 56 | batch_size: int, 57 | length: int, 58 | ) -> task.Batch: 59 | """Returns a batch of strings and tokens sorted by (inc.) occurrence.""" 60 | strings = jrandom.randint( 61 | rng, shape=(batch_size, length), minval=0, maxval=self._vocab_size 62 | ) 63 | sorted_strings = jnp.sort(strings, axis=-1) 64 | 65 | return { 66 | 'input': jnn.one_hot(strings, num_classes=self.input_size), 67 | 'output': jnn.one_hot(sorted_strings, num_classes=self.output_size), 68 | } 69 | 70 | @property 71 | def input_size(self) -> int: 72 | """Returns the input size for the models.""" 73 | return self._vocab_size 74 | 75 | @property 76 | def output_size(self) -> int: 77 | """Returns the output size for the models.""" 78 | return self._vocab_size 79 | 80 | def output_length(self, input_length: int) -> int: 81 | """Returns the output length for a given input length.""" 82 | return input_length 83 | -------------------------------------------------------------------------------- /tasks/cs/compute_sqrt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ==============================================================================
15 | 
16 | """Compute the floor of the square root of a binary number."""
17 | 
18 | import math
19 | import random
20 | 
21 | import chex
22 | import jax.nn as jnn
23 | import jax.numpy as jnp
24 | 
25 | from randomized_positional_encodings.tasks import task
26 | from randomized_positional_encodings.tasks.cs import binary_addition
27 | 
28 | 
29 | class ComputeSqrt(task.GeneralizationTask):
30 |   """A task with the goal of computing the square root of a binary number.
31 | 
32 |   The input is a number in binary (big-endian), and the output is the floor of
33 |   the square root of this number, also in binary.
34 |   Note that the output length, i.e., the length of the square root in binary,
35 |   is always ceil(input_length / 2) (because log(sqrt(x)) = 1/2 * log(x)).
36 | 
37 |   Examples:
38 |     100101 = 37 -> square root is 6.08... -> floor(6.08) = 6 -> 110
39 |     111 = 7 -> square root is 2.64 -> floor(2.64) = 2 -> 10
40 |   """
41 | 
42 |   def sample_batch(
43 |       self, rng: chex.PRNGKey, batch_size: int, length: int
44 |   ) -> task.Batch:
45 |     """Returns a batch of binary numbers and their square roots, in binary."""
46 |     del rng
47 |     numbers = [random.randint(1, 2**length - 1) for _ in range(batch_size)]
48 |     binary_numbers = binary_addition.numbers_to_fixed_length_binary(
49 |         numbers, length=length, little_endian=False
50 |     )
51 | 
52 |     sqrts = list(map(math.isqrt, numbers))
53 |     binary_sqrts = binary_addition.numbers_to_fixed_length_binary(
54 |         sqrts, length=self.output_length(length), little_endian=False
55 |     )
56 | 
57 |     binary_numbers = jnp.array(binary_numbers, jnp.int32)
58 |     binary_sqrts = jnp.array(binary_sqrts, jnp.int32)
59 | 
60 |     inputs = jnn.one_hot(binary_numbers, self.input_size)
61 |     output = jnn.one_hot(binary_sqrts, self.output_size)
62 |     return {'input': inputs, 'output': output}
63 | 
64 |   @property
65 |   def input_size(self) -> int:
66 |     """Returns the input size for the models."""
67 |     return 2
68 | 
69 |   @property
70 |   def output_size(self) -> int:
71 |     """Returns the output size for the models."""
72 |     return 2
73 | 
74 |   def output_length(self, input_length: int) -> int:
75 |     return math.ceil(input_length / 2)
76 | 
--------------------------------------------------------------------------------
/tasks/cs/duplicate_string.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """Duplicate a string."""
17 | 
18 | import functools
19 | 
20 | import chex
21 | import jax
22 | import jax.nn as jnn
23 | import jax.numpy as jnp
24 | import jax.random as jrandom
25 | 
26 | from randomized_positional_encodings.tasks import task
27 | 
28 | 
29 | class DuplicateString(task.GeneralizationTask):
30 |   """A task with the goal of duplicating a string.
31 | 32 | The input is a string s_1 ... s_n composed of symbols from a finite set S. The 33 | output is the same string outputted twice without any separator, ie: 34 | s_1 ... s_n s_1 ... s_n 35 | 36 | Examples: 37 | 101 -> 101 101 38 | 111111 -> 111111 111111 39 | 40 | In the paper, we use only binary strings (ie S = {0, 1}). 41 | Note that the sampling is jittable so this task is fast. 42 | """ 43 | 44 | def __init__(self, vocab_size: int, *args, duplication: int = 2, **kwargs): 45 | """Initializes the remember_string task. 46 | 47 | Args: 48 | vocab_size: The size of the alphabet. 49 | *args: Args for the base task class. 50 | duplication: Number of times the string should be duplicated. 51 | **kwargs: Kwargs for the base task class. 52 | """ 53 | super().__init__(*args, **kwargs) 54 | 55 | self._vocab_size = vocab_size 56 | self._duplication = duplication 57 | 58 | @functools.partial(jax.jit, static_argnums=(0, 2, 3)) 59 | def sample_batch( 60 | self, rng: chex.PRNGKey, batch_size: int, length: int 61 | ) -> task.Batch: 62 | """Returns a batch of strings and their copies.""" 63 | strings = jrandom.randint( 64 | rng, shape=(batch_size, length), minval=0, maxval=self._vocab_size 65 | ) 66 | one_hot_strings = jnn.one_hot(strings, num_classes=self._vocab_size) 67 | output = jnp.concatenate([one_hot_strings] * self._duplication, axis=1) 68 | return {"input": one_hot_strings, "output": output} 69 | 70 | @property 71 | def input_size(self) -> int: 72 | """Returns the input size for the models.""" 73 | return self._vocab_size 74 | 75 | @property 76 | def output_size(self) -> int: 77 | """Returns the output size for the models.""" 78 | return self._vocab_size 79 | 80 | def output_length(self, input_length: int) -> int: 81 | """Returns the output length for a given input length.""" 82 | return self._duplication * input_length 83 | -------------------------------------------------------------------------------- /tasks/cs/missing_duplicate_string.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Predict the missing symbol in a duplicated string.""" 17 | 18 | import functools 19 | 20 | import chex 21 | import jax 22 | import jax.nn as jnn 23 | import jax.numpy as jnp 24 | import jax.random as jrandom 25 | 26 | from randomized_positional_encodings.tasks import task 27 | 28 | 29 | class MissingDuplicateString(task.GeneralizationTask): 30 | """A task with the goal of finding the missing symbol in a duplicated string. 31 | 32 | Given a binary string that is presented twice with exactly one element omitted 33 | (denoted by the placeholder token `2`), predict the value of that element. 34 | Thus, an agent trying to solve this task needs to recognize the underlying 35 | duplicated string to be able to produce the correct output. 
36 | If the length is odd, the duplicated strings of length `length // 2` are 37 | padded with the empty token `3`. 38 | 39 | Examples 40 | 01100210 -> 1 (the substring is 0110, so the missing value is 1) 41 | 1011213 -> 0 (the subtring is 101, so the missing value is 0) 42 | """ 43 | 44 | @functools.partial(jax.jit, static_argnums=(0, 2, 3)) 45 | def sample_batch( 46 | self, 47 | rng: chex.PRNGKey, 48 | batch_size: int, 49 | length: int, 50 | ) -> task.Batch: 51 | """Returns a batch of strings and the expected class.""" 52 | # For `length == 1`, we cannot meaningfully define substrings of length 53 | # `length // 2`, so we arbitrarily set the inputs and outputs to `1`. 54 | if length == 1: 55 | return { 56 | 'input': jnn.one_hot( 57 | jnp.ones((batch_size, length)), num_classes=self.input_size 58 | ), 59 | 'output': jnn.one_hot( 60 | jnp.ones((batch_size,)), num_classes=self.output_size 61 | ), 62 | } 63 | 64 | strings_rng, indices_rng = jrandom.split(rng) 65 | strings = jrandom.randint( 66 | strings_rng, shape=(batch_size, length // 2), minval=0, maxval=2 67 | ) 68 | duplicated_strings = jnp.concatenate((strings, strings), axis=-1) 69 | indices = jrandom.randint( 70 | indices_rng, 71 | shape=(batch_size,), 72 | minval=0, 73 | maxval=duplicated_strings.shape[1], 74 | ) 75 | output = jax.vmap(lambda x, y: x[y])(duplicated_strings, indices) 76 | masked_strings = jax.vmap(lambda x, y: x.at[y].set(2))( 77 | duplicated_strings, indices 78 | ) 79 | 80 | # If `length` is odd, we pad the strings with the empty token `3` at the end 81 | # to ensure that the final input length is equal to `length` given the two 82 | # substrings of length `length // 2`. 83 | padding = jnp.full((batch_size, length % 2), fill_value=3) 84 | padded_strings = jnp.concatenate((masked_strings, padding), axis=-1) 85 | 86 | return { 87 | 'input': jnn.one_hot(padded_strings, num_classes=self.input_size), 88 | 'output': jnn.one_hot(output, num_classes=self.output_size), 89 | } 90 | 91 | @property 92 | def input_size(self) -> int: 93 | """Returns the input size for the models.""" 94 | return 4 95 | 96 | @property 97 | def output_size(self) -> int: 98 | """Returns the output size for the models.""" 99 | return 2 100 | -------------------------------------------------------------------------------- /tasks/cs/odds_first.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Sort a string by the parity of the indices (odd indices first).""" 17 | 18 | import functools 19 | 20 | import chex 21 | import jax 22 | import jax.nn as jnn 23 | import jax.numpy as jnp 24 | import jax.random as jrandom 25 | 26 | from randomized_positional_encodings.tasks import task 27 | 28 | 29 | class OddsFirst(task.GeneralizationTask): 30 | """A task with the goal of outputting a string's tokens at odd indices first. 31 | 32 | The input is a string s_1 ... s_n composed of symbols from a finite set S. The 33 | output is the same string, but where the values at odd indexes have been put 34 | first: s_1 s_3 s_5 ... s_2 s_4 s_6 ... 35 | 36 | Examples: 37 | 00110101 -> 0100 0111 38 | 110 -> 10 1 39 | 40 | In the paper, we use only binary strings (ie S = {0, 1}). 41 | Note that the sampling is jittable so this task is fast. 42 | """ 43 | 44 | def __init__(self, vocab_size: int, *args, **kwargs): 45 | """Initializes the odds_first task. 46 | 47 | Args: 48 | vocab_size: The size of the alphabet. 49 | *args: Args for the base task class. 50 | **kwargs: Kwargs for the base task class. 51 | """ 52 | super().__init__(*args, **kwargs) 53 | 54 | self._vocab_size = vocab_size 55 | 56 | @functools.partial(jax.jit, static_argnums=(0, 2, 3)) 57 | def sample_batch( 58 | self, rng: chex.PRNGKey, batch_size: int, length: int 59 | ) -> task.Batch: 60 | """Returns a batch of strings and their outputs.""" 61 | strings = jrandom.randint( 62 | rng, shape=(batch_size, length), minval=0, maxval=self._vocab_size 63 | ) 64 | one_hot_strings = jnn.one_hot(strings, num_classes=self._vocab_size) 65 | output = jnp.concatenate( 66 | [one_hot_strings[:, 1::2], one_hot_strings[:, ::2]], axis=1 67 | ) 68 | return {"input": one_hot_strings, "output": output} 69 | 70 | @property 71 | def input_size(self) -> int: 72 | """Returns the input size for the model.""" 73 | return self._vocab_size 74 | 75 | @property 76 | def output_size(self) -> int: 77 | """Returns the output size for the model.""" 78 | return self._vocab_size 79 | 80 | def output_length(self, input_length: int) -> int: 81 | """Returns the output length for the model.""" 82 | return input_length 83 | -------------------------------------------------------------------------------- /tasks/dcf/modular_arithmetic_brackets.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Modular arithmetic with brackets.""" 17 | 18 | import collections 19 | from typing import Sequence 20 | 21 | import chex 22 | import jax.nn as jnn 23 | import numpy as np 24 | import tqdm 25 | import tree 26 | 27 | from randomized_positional_encodings.tasks import task 28 | 29 | 30 | def generate_one_expression_and_result( 31 | modulus: int, length: int, mult: bool = False 32 | ) -> tuple[str, int]: 33 | """Returns a modular arithmetic expression with brackets, and its result. 34 | 35 | The values in the expression are in {0, 1, ..., modulus-1}. The allowed 36 | operations are either {+, -} (mult=False) or {+, -, *} (mult=True). 37 | 38 | Args: 39 | modulus: The modulus to use for the expression. 40 | length: The length of the expression. 41 | mult: Whether to include the multiplication operator in the expressions. 42 | 43 | Raises: 44 | ValueError if length < 1. 45 | """ 46 | 47 | # Generates a terminal (digit). 48 | def gen_terminal(): 49 | terminal = np.random.randint(low=0, high=modulus) 50 | return str(terminal), terminal 51 | 52 | # If length is less than 1, issue an error. 53 | if length < 1: 54 | raise ValueError(f"Can't generate expressions of length < 1. Got {length}.") 55 | 56 | # If length is less than 5, generate a digit d, -d, (d), or (-d). 57 | if length == 1: 58 | return gen_terminal() 59 | elif length == 2: 60 | term_str, term_val = gen_terminal() 61 | return '-' + term_str, -term_val % modulus 62 | elif length == 3: 63 | term_str, term_val = gen_terminal() 64 | return '(' + term_str + ')', term_val 65 | elif length == 4: 66 | term_str, term_val = gen_terminal() 67 | return '(-' + term_str + ')', -term_val % modulus 68 | 69 | # If length is >= 5, sample an operator with brackets. 70 | 71 | # First split the length into a left and right part. 72 | left_length = np.random.randint(low=1, high=length - 3) 73 | right_length = length - (left_length + 3) 74 | left_str, left_val = generate_one_expression_and_result( 75 | modulus, left_length, mult=mult 76 | ) 77 | right_str, right_val = generate_one_expression_and_result( 78 | modulus, right_length, mult=mult 79 | ) 80 | 81 | # Now sample an operator and return. 82 | maxop = 3 if mult else 2 83 | op = np.random.randint(low=0, high=maxop) 84 | if op == 0: 85 | return ( 86 | '(' + left_str + '+' + right_str + ')', 87 | (left_val + right_val) % modulus, 88 | ) 89 | elif op == 1: 90 | return ( 91 | '(' + left_str + '-' + right_str + ')', 92 | (left_val - right_val) % modulus, 93 | ) 94 | else: 95 | return ( 96 | '(' + left_str + '*' + right_str + ')', 97 | (left_val * right_val) % modulus, 98 | ) 99 | 100 | 101 | def generate_raw_dataset( 102 | n: int, 103 | lengths: Sequence[int], 104 | modulus: int, 105 | mult: bool = False, 106 | with_tqdm: bool = False, 107 | ) -> dict[int, dict[str, np.ndarray]]: 108 | """Generates a dataset of maths expressions with brackets, and their results. 109 | 110 | Args: 111 | n: The number of datapoints in the dataset. 112 | lengths: The lengths of the sequences to generate. n is evenly distributed 113 | over these lengths. 114 | modulus: Modulus used to compute the expressions. 115 | mult: Whether to include the multiplication operator in the expressions. 116 | with_tqdm: As the computation might be long, whether to add a tqdm progress 117 | bar or not. 
118 | 
119 |   Returns:
120 |     A dict whose keys are the passed lengths and whose values are dicts with
121 |     the keys 'expressions' and 'results', mapping to the data numpy arrays.
122 |   """
123 |   alphabet_to_int = {
124 |       '+': modulus,
125 |       '-': modulus + 1,
126 |       '*': modulus + 2,
127 |       '(': modulus + 3,
128 |       ')': modulus + 4,
129 |       'x': modulus + 5,
130 |       '=': modulus + 6,
131 |   }
132 |   for x in range(modulus):
133 |     alphabet_to_int[str(x)] = x
134 | 
135 |   sequences = collections.defaultdict(
136 |       lambda: {  # pylint: disable=g-long-lambda
137 |           'expressions': [],
138 |           'results': [],
139 |       }
140 |   )
141 |   range_lengths = tqdm.tqdm(lengths) if with_tqdm else lengths
142 |   for length in range_lengths:
143 |     for _ in range(n // len(lengths)):
144 |       seq, label = generate_one_expression_and_result(modulus, length, mult)
145 |       seq = [alphabet_to_int[x] for x in seq]
146 |       sequences[length]['expressions'].append(seq)
147 |       sequences[length]['results'].append(label)
148 |   sequences = tree.traverse(
149 |       lambda l: np.array(l, dtype=np.int32) if isinstance(l, list) else l,
150 |       sequences,
151 |       top_down=False,
152 |   )
153 |   return dict(sequences)
154 | 
155 | 
156 | class ModularArithmeticBrackets(task.GeneralizationTask):
157 |   """A task with the goal of reducing an arithmetic expression with brackets."""
158 | 
159 |   def __init__(self, modulus: int, *args, mult: bool = False, **kwargs):
160 |     super().__init__(*args, **kwargs)
161 |     self._modulus = modulus
162 |     self._mult = mult
163 | 
164 |   def sample_batch(
165 |       self, rng: chex.PRNGKey, batch_size: int, length: int
166 |   ) -> task.Batch:
167 |     """Returns a batch of inputs/outputs."""
168 |     np.random.seed(rng[0])
169 |     batch = generate_raw_dataset(
170 |         batch_size, lengths=[length], modulus=self._modulus, mult=self._mult
171 |     )[length]
172 |     inputs = jnn.one_hot(batch['expressions'], self.input_size)
173 |     output = jnn.one_hot(batch['results'], self.output_size)
174 |     return {'input': inputs, 'output': output}
175 | 
176 |   @property
177 |   def input_size(self) -> int:
178 |     """Returns the input size for the models."""
179 |     return self._modulus + 6
180 | 
181 |   @property
182 |   def output_size(self) -> int:
183 |     """Returns the output size for the models."""
184 |     return self._modulus
185 | 
--------------------------------------------------------------------------------
/tasks/dcf/reverse_string.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================== 15 | 16 | """Compute the reverse of an input string.""" 17 | 18 | import functools 19 | 20 | import chex 21 | import jax 22 | import jax.numpy as jnp 23 | 24 | from randomized_positional_encodings.tasks import task 25 | from randomized_positional_encodings.tasks.cs import duplicate_string 26 | 27 | 28 | class ReverseString(duplicate_string.DuplicateString): 29 | """A task with the goal of reversing a given string. 30 | 31 | The input is a string s_1 ... s_n composed of symbols from a finite set S. The 32 | output is the string, reversed, ie s_n ... s_1. 33 | 34 | Examples: 35 | 011010 -> 010110 36 | 123021 -> 120321 37 | 38 | In the paper, we use only binary strings (ie S = {0, 1}). 39 | Note that the sampling is jittable so this task is fast. 40 | """ 41 | 42 | @functools.partial(jax.jit, static_argnums=(0, 2, 3)) 43 | def sample_batch( 44 | self, rng: chex.PRNGKey, batch_size: int, length: int 45 | ) -> task.Batch: 46 | """Returns a batch of strings and their reversed version.""" 47 | batch = super().sample_batch(rng, batch_size, length) 48 | batch['output'] = jnp.flip(batch['input'], axis=1) 49 | return batch 50 | 51 | def output_length(self, input_length: int) -> int: 52 | """Returns the output length for a given input length.""" 53 | return input_length 54 | -------------------------------------------------------------------------------- /tasks/dcf/solve_equation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Solve for the value of an unknown variable in an equation.""" 17 | 18 | import collections 19 | from typing import Sequence 20 | 21 | import chex 22 | import jax.nn as jnn 23 | import jax.numpy as jnp 24 | import numpy as np 25 | import tqdm 26 | import tree 27 | 28 | from randomized_positional_encodings.tasks import task 29 | from randomized_positional_encodings.tasks.dcf import modular_arithmetic_brackets as mab 30 | 31 | 32 | def generate_equation_and_solution( 33 | modulus: int, length: int, mult: bool = False 34 | ) -> tuple[str, int]: 35 | """Returns a modular arithmetic equation with brackets, and its solution. 36 | 37 | The values are in {0, 1, ..., modulus-1}, and the unknown 38 | value is x. The allowed operations are either {+, -} (mult=False) or 39 | {+, -, *} (mult=True). 40 | Warning: if mult=True, x might have multiple valid solutions. 41 | 42 | Args: 43 | modulus: The modulus to use for the expression. 44 | length: The length of the expression. 45 | mult: Whether to include the multiplication operator in the expressions. 46 | 47 | Raises: 48 | ValueError if the length is < 3. 49 | """ 50 | 51 | # Generate the expression. 
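  # For illustration (not part of the original source): with `modulus=5` and
  # `length=7`, the call below could return the expression '(2+3)' with value
  # 0; replacing the digit '2' with 'x' and appending '=0' then yields the
  # 7-character equation '(x+3)=0', whose solution is 2.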
52 | expr, val = mab.generate_one_expression_and_result( 53 | modulus, length - 2, mult=mult 54 | ) 55 | 56 | # Replace random digit with 'x'. 57 | idx = np.random.randint(low=0, high=len(expr)) 58 | digits = [str(n) for n in range(modulus)] 59 | while expr[idx] not in digits: 60 | idx = (idx + 1) % (length - 2) 61 | solution = int(expr[idx]) 62 | equation = expr[:idx] + 'x' + expr[idx + 1 :] + '=' + str(val) 63 | 64 | return equation, solution 65 | 66 | 67 | def generate_raw_dataset( 68 | n: int, 69 | lengths: Sequence[int], 70 | modulus: int, 71 | mult: bool = False, 72 | with_tqdm: bool = False, 73 | ) -> dict[int, dict[str, np.ndarray]]: 74 | """Generates a dataset of equations and their solutions. 75 | 76 | Args: 77 | n: The number of datapoints in the dataset. 78 | lengths: The lengths of the sequences to generate. n is evenly distributed 79 | over these lengths. 80 | modulus: Modulus used to compute the expressions. 81 | mult: Whether to include the multiplication operator in the expressions. 82 | with_tqdm: As the computation might be long, whether to add a tqdm progress 83 | bar or not. 84 | 85 | Returns: 86 | A dict which keys are the passed lengths, and the values are dicts with keys 87 | 'equations' and 'solutions', and values are the data numpy arrays. 88 | """ 89 | alphabet_to_int = { 90 | '+': modulus, 91 | '-': modulus + 1, 92 | '(': modulus + 2, 93 | ')': modulus + 3, 94 | 'x': modulus + 4, 95 | '=': modulus + 5, 96 | } 97 | for x in range(modulus): 98 | alphabet_to_int[str(x)] = x 99 | 100 | sequences = collections.defaultdict( 101 | lambda: { # pylint: disable=g-long-lambda 102 | 'equations': [], 103 | 'solutions': [], 104 | } 105 | ) 106 | range_lengths = tqdm.tqdm(lengths) if with_tqdm else lengths 107 | for length in range_lengths: 108 | for _ in range(n // len(lengths)): 109 | seq, label = generate_equation_and_solution(modulus, length, mult=mult) 110 | seq = [alphabet_to_int[x] for x in seq] 111 | sequences[length]['equations'].append(seq) 112 | sequences[length]['solutions'].append(label) 113 | # Convert the list of numbers we have to arrays at the leaves. 
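  # For illustration (not part of the original source): with `lengths=[7]`,
  # `sequences[7]['equations']` becomes an int32 array of shape `(n, 7)` and
  # `sequences[7]['solutions']` an int32 array of shape `(n,)`.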
114 |   sequences = tree.traverse(
115 |       lambda l: np.array(l, dtype=np.int32) if isinstance(l, list) else l,
116 |       sequences,
117 |       top_down=False,
118 |   )
119 |   return dict(sequences)
120 | 
121 | 
122 | class SolveEquation(task.GeneralizationTask):
123 |   """A task with the goal of solving a modular equation for an unknown."""
124 | 
125 |   def __init__(self, modulus: int, *args, **kwargs):
126 |     super().__init__(*args, **kwargs)
127 |     self._modulus = modulus
128 | 
129 |   def sample_batch(
130 |       self, rng: chex.PRNGKey, batch_size: int, length: int
131 |   ) -> task.Batch:
132 |     """Returns a batch of inputs/outputs."""
133 |     np.random.seed(rng[0])
134 |     if length < 3:
135 |       return {
136 |           'input': jnn.one_hot(
137 |               jnp.zeros((batch_size, length)), num_classes=self.input_size
138 |           ),
139 |           'output': jnn.one_hot(
140 |               jnp.zeros((batch_size,)), num_classes=self.output_size
141 |           ),
142 |       }
143 |     batch = generate_raw_dataset(
144 |         batch_size, lengths=[length], modulus=self._modulus
145 |     )[length]
146 |     inputs = jnn.one_hot(batch['equations'], self.input_size)
147 |     output = jnn.one_hot(batch['solutions'], self.output_size)
148 |     return {'input': inputs, 'output': output}
149 | 
150 |   @property
151 |   def input_size(self) -> int:
152 |     """Returns the input size for the models."""
153 |     return self._modulus + 6
154 | 
155 |   @property
156 |   def output_size(self) -> int:
157 |     """Returns the output size for the models."""
158 |     return self._modulus
159 | 
--------------------------------------------------------------------------------
/tasks/dcf/stack_manipulation.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """Manipulate an input stack, using the input actions."""
17 | 
18 | import chex
19 | import jax.nn as jnn
20 | import jax.numpy as jnp
21 | import numpy as np
22 | 
23 | from randomized_positional_encodings.tasks import task
24 | 
25 | 
26 | class StackManipulation(task.GeneralizationTask):
27 |   """A task with the goal of following instructions and returning the end stack.
28 | 
29 |   The input is composed of a stack of 0s and 1s followed by a sequence of
30 |   instructions POP/PUSH 0/PUSH 1 (represented by 2s/3s/4s). The input stack is
31 |   given bottom-to-top, and the agent needs to execute the instructions given
32 |   (left-to-right) and output the final stack top-to-bottom (i.e., as if it were
33 |   popping the final stack). If a POP action is called on an empty stack, the
34 |   action is ignored. The output is padded with 0s to match the input length
35 |   + 1 (to accommodate the termination token), and the end of the final stack
36 |   is denoted with the termination symbol 2 (i.e., the output has values in {0,
37 |   1, 2}).
38 | 39 | Examples: 40 | 0 1 1 0 PUSH 1 POP POP 41 | initial 0 1 1 0 (the stack is received bottom-to-top) 42 | PUSH 1 0 1 1 0 1 43 | POP 0 1 1 0 44 | POP 0 1 1 45 | -> 1 1 0 2 0 0 0 0 (the stack is returned top-to-bottom) 46 | 47 | 1 1 0 POP POP POP 48 | initial 1 1 0 49 | POP 1 1 50 | POP 1 51 | POP 52 | -> 2 0 0 0 0 0 0 0 (the stack is empty and padded with zeros) 53 | """ 54 | 55 | def _sample_expression_and_result( 56 | self, length: int 57 | ) -> tuple[np.ndarray, list[int]]: 58 | """Returns an expression with stack instructions, and the result stack.""" 59 | if length == 1: 60 | value = np.random.randint(low=0, high=2, size=(1,)) 61 | return value, list(value) 62 | 63 | # Initialize the stack content and the actions (POP/PUSH). 64 | stack_length = np.random.randint(low=1, high=length) 65 | stack = np.random.randint(low=0, high=2, size=(stack_length,)) 66 | actions = np.random.randint(low=2, high=5, size=(length - stack_length,)) 67 | 68 | # Apply the actions on the stack. 69 | current_stack = list(stack) 70 | 71 | for action in actions: 72 | if action == 2: # POP 73 | if current_stack: 74 | current_stack.pop() 75 | elif action in [3, 4]: # PUSH a 0 (case 3) or a 1 (case 4) 76 | current_stack.append(action - 3) 77 | 78 | return np.concatenate([stack, actions]), current_stack[::-1] 79 | 80 | def sample_batch( 81 | self, rng: chex.PRNGKey, batch_size: int, length: int 82 | ) -> task.Batch: 83 | """Returns a batch of strings and the expected class.""" 84 | np.random.seed(rng[0]) 85 | expressions, results = [], [] 86 | for _ in range(batch_size): 87 | expression, result = self._sample_expression_and_result(length) 88 | expressions.append(expression) 89 | # Append the termination token to the result. 90 | result += [self.output_size - 1] 91 | # Pad the result with zeros to match the input length (accounting for the 92 | # termination token). 93 | result += [0] * (length + 1 - len(result)) 94 | results.append(result) 95 | expressions = jnp.array(expressions) 96 | results = jnp.array(results) 97 | 98 | inputs = jnn.one_hot(expressions, self.input_size) 99 | output = jnn.one_hot(results, self.output_size) 100 | return {'input': inputs, 'output': output} 101 | 102 | @property 103 | def input_size(self) -> int: 104 | """Returns the input size for the models. 105 | 106 | The value is 5 because we have two possible tokens in the stack (0, 1), plus 107 | three tokens to describe the PUSH 0, PUSH 1, and POP actions. 108 | """ 109 | return 5 110 | 111 | @property 112 | def output_size(self) -> int: 113 | """Returns the output size for the models.""" 114 | return 3 115 | 116 | def output_length(self, input_length: int) -> int: 117 | """Returns the output length of the task.""" 118 | return input_length + 1 119 | 120 | def accuracy_mask(self, target: chex.Array) -> chex.Array: 121 | """Computes mask that ignores everything after the termination tokens. 122 | 123 | Args: 124 | target: Target tokens of shape `(batch_size, output_length, output_size)`. 125 | 126 | Returns: 127 | The mask of shape `(batch_size, output_length)`. 
128 |     """
129 |     batch_size, length, _ = target.shape
130 |     termination_indices = jnp.argmax(
131 |         jnp.argmax(target, axis=-1),
132 |         axis=-1,
133 |         keepdims=True,
134 |     )
135 |     indices = jnp.tile(jnp.arange(length), (batch_size, 1))
136 |     return indices <= termination_indices
137 | 
--------------------------------------------------------------------------------
/tasks/regular/cycle_navigation.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """Compute the final state after randomly walking on a circle."""
17 | 
18 | import functools
19 | 
20 | import chex
21 | import jax
22 | import jax.nn as jnn
23 | import jax.numpy as jnp
24 | import jax.random as jrandom
25 | 
26 | from randomized_positional_encodings.tasks import task
27 | 
28 | 
29 | class CycleNavigation(task.GeneralizationTask):
30 |   """A task with the goal of computing the final state on a circle.
31 | 
32 |   The input is a string of actions, composed of 0s, 1s, or -1s. The actions give
33 |   directions to take on a finite-length circle (0 is for stay, 1 is for right,
34 |   -1 is for left). The goal is to give the final position on the circle after
35 |   all the actions have been taken. The agent starts at position 0.
36 | 
37 |   By default, the length of the circle is 5.
38 | 
39 |   Examples:
40 |     1 -1 0 -1 -1 -> -2 = class 3
41 |     1 1 1 -1 -> 2 = class 2
42 | 
43 |   Note that the sampling is jittable so it is fast.
44 | """ 45 | 46 | @property 47 | def _cycle_length(self) -> int: 48 | """Returns the cycle length, number of possible states.""" 49 | return 5 50 | 51 | @functools.partial(jax.jit, static_argnums=(0, 2, 3)) 52 | def sample_batch( 53 | self, rng: chex.PRNGKey, batch_size: int, length: int 54 | ) -> task.Batch: 55 | """Returns a batch of strings and the expected class.""" 56 | actions = jrandom.randint( 57 | rng, shape=(batch_size, length), minval=0, maxval=3 58 | ) 59 | final_states = jnp.sum(actions - 1, axis=1) % self._cycle_length 60 | final_states = jnn.one_hot(final_states, num_classes=self.output_size) 61 | one_hot_strings = jnn.one_hot(actions, num_classes=self.input_size) 62 | return {"input": one_hot_strings, "output": final_states} 63 | 64 | @property 65 | def input_size(self) -> int: 66 | """Returns the input size for the models.""" 67 | return 3 68 | 69 | @property 70 | def output_size(self) -> int: 71 | """Returns the output size for the models.""" 72 | return self._cycle_length 73 | -------------------------------------------------------------------------------- /tasks/regular/even_pairs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Compute whether the number of 01's and 10's is even.""" 17 | 18 | import functools 19 | 20 | import chex 21 | import jax 22 | from jax import nn as jnn 23 | from jax import numpy as jnp 24 | from jax import random as jrandom 25 | 26 | from randomized_positional_encodings.tasks import task 27 | 28 | 29 | class EvenPairs(task.GeneralizationTask): 30 | """A task with the goal of checking whether the number of 01s and 10s is even. 31 | 32 | The input is a binary string, composed of 0s and 1s. If the result is even, 33 | the class is 0, otherwise it's one. 34 | 35 | Examples: 36 | 001110 -> 1 '10' and 1 '01' -> class 0 37 | 0101001 -> 2 '10' and 3 '01' -> class 1 38 | 39 | Note the sampling is jittable so this task is fast. 
40 | """ 41 | 42 | @functools.partial(jax.jit, static_argnums=(0, 2, 3)) 43 | def sample_batch( 44 | self, rng: chex.PRNGKey, batch_size: int, length: int 45 | ) -> task.Batch: 46 | """Returns a batch of strings and the expected class.""" 47 | strings = jrandom.randint( 48 | rng, 49 | shape=(batch_size, length), 50 | minval=0, 51 | maxval=2, 52 | ) 53 | one_hot_strings = jnn.one_hot(strings, num_classes=2) 54 | unequal_pairs = jnp.logical_xor(strings[:, :-1], strings[:, 1:]) 55 | odd_unequal_pairs = jnp.sum(unequal_pairs, axis=-1) % 2 56 | return { 57 | 'input': one_hot_strings, 58 | 'output': jnn.one_hot(odd_unequal_pairs, num_classes=self.output_size), 59 | } 60 | 61 | @property 62 | def input_size(self) -> int: 63 | """Returns the input size for the models.""" 64 | return 2 65 | 66 | @property 67 | def output_size(self) -> int: 68 | """Returns the output size for the models.""" 69 | return 2 70 | -------------------------------------------------------------------------------- /tasks/regular/modular_arithmetic.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Modular arithmetic without brackets. 17 | 18 | Note this allows to generate samples using a jittable function, and is therefore 19 | much faster than its 'brackets' counterpart, which requires to simulate the full 20 | CF grammar, non-jittable. 21 | """ 22 | 23 | import functools 24 | from typing import Optional, Sequence 25 | 26 | import chex 27 | import jax 28 | import jax.nn as jnn 29 | import jax.numpy as jnp 30 | import jax.random as jrandom 31 | 32 | from randomized_positional_encodings.tasks import task 33 | 34 | OP_BY_CHARACTER = {'+': 0, '-': 1, '*': 2, '_': 3} 35 | 36 | 37 | def _replace_subtractions(expression: chex.Array, modulus: int) -> chex.Array: 38 | """Replaces subtractions in an expression by additions with the inverse. 39 | 40 | e.g. the expression [1, -, 3] results in [1, +, -3]. 41 | 42 | Args: 43 | expression: Encoded expression (a 1D array of integers) in which to replace 44 | subtractions. 45 | modulus: The modulus to use for the modular arithmetic. 46 | 47 | Returns: 48 | The expression with all subtractions replaced by additions with the inverse. 49 | """ 50 | if expression.size < 2: 51 | return expression 52 | 53 | mask = expression == modulus + OP_BY_CHARACTER['-'] 54 | subtract_replaced = jnp.where( 55 | mask, modulus + OP_BY_CHARACTER['+'], expression 56 | ) 57 | return subtract_replaced.at[2:].multiply(1 - 2 * mask[1:-1]) 58 | 59 | 60 | def _perform_multiplications( 61 | expression: chex.Array, modulus: int 62 | ) -> chex.Array: 63 | """Performs all multiplications in an expression containing only + and *. 64 | 65 | This is done at fixed length and the result is zero-padded to achieve this. 
66 | Since the result of performing multiplications is an expression containing 67 | only + operators, the operators are dropped from the output. For example, the 68 | expression [1, +, 3, *, 4] results in [1, 12, 0]. 69 | 70 | Args: 71 | expression: Encoded expression in which to perform multiplications. 72 | modulus: The modulus to use for the modular arithmetic. 73 | 74 | Returns: 75 | An array with the results of the multiplications (potentially zero-padded). 76 | """ 77 | term_ids = jnp.cumsum(expression == modulus + OP_BY_CHARACTER['+'])[::2] 78 | # Segment_prod can only be jit-compiled with a fixed number of segments. 79 | # Therefore, we have to set to the maximum number of terms possible and 80 | # mask out superfluous segment results with zeros afterwards. 81 | maximum_term_number = expression.shape[0] // 2 + 1 82 | products = jax.ops.segment_prod( 83 | expression[::2], 84 | term_ids, 85 | num_segments=maximum_term_number, 86 | indices_are_sorted=True, 87 | ) 88 | valid_segment_mask = jnp.arange(maximum_term_number) <= term_ids[-1] 89 | return products * valid_segment_mask 90 | 91 | 92 | def _replace_blanks(expression: chex.Array, modulus: int) -> chex.Array: 93 | """Replaces blank symbols in expression with either `+` or `0`. 94 | 95 | Depending on whether the blank symbol is at the position of an operator or a 96 | residual, the blank symbol is replaced with a `+` operator or a `0`. 97 | 98 | Args: 99 | expression: Encoded expression in which to replace blank symbols. 100 | modulus: The modulus to use for the modular arithmetic. 101 | 102 | Returns: 103 | An array with blank symbols replaced by either `+` or `0`. 104 | """ 105 | mask = expression == OP_BY_CHARACTER['_'] + modulus 106 | operator_mask = mask.at[::2].set(False) # pytype: disable=attribute-error # numpy-scalars 107 | residual_mask = mask.at[1::2].set(False) # pytype: disable=attribute-error # numpy-scalars 108 | 109 | blanks_replaced = jnp.where( 110 | operator_mask, OP_BY_CHARACTER['+'] + modulus, expression 111 | ) 112 | blanks_replaced = jnp.where(residual_mask, 0, blanks_replaced) 113 | return blanks_replaced 114 | 115 | 116 | def _evaluate_expression(expression: chex.Array, modulus: int) -> chex.Array: 117 | """Returns the result of evaluating a modular arithmetic expression.""" 118 | expression = _replace_blanks(expression, modulus) 119 | expression = _replace_subtractions(expression, modulus) 120 | additive_terms = _perform_multiplications(expression, modulus) 121 | return jnp.sum(additive_terms) % modulus 122 | 123 | 124 | class ModularArithmetic(task.GeneralizationTask): 125 | """A task with the goal of reducing a simple arithmetic expression. 126 | 127 | The input is a string, composed of numbers (in {0, ..., modulus-1}), and 128 | operators (in {+, -, *}). The output is the reduced value of this expression, 129 | which is also in {0, ..., modulus-1}. 130 | 131 | Examples (modulo 5): 132 | 1 + 2 * 3 = 2 133 | 1 - 1 - 1 = 4 134 | 0 * 1 + 4 * 3 - 2 = 0 135 | 136 | Note that the input strings are always of odd length. 137 | """ 138 | 139 | def __init__( 140 | self, 141 | modulus: int, 142 | *args, 143 | operators: Optional[Sequence[str]] = None, 144 | **kwargs 145 | ): 146 | """Initializes the modular arithmetic task. 147 | 148 | Args: 149 | modulus: The modulus used for the computation. 150 | *args: Args for the base task class. 151 | operators: Operators to be used in the sequences. By default it's None, 152 | meaning all operators available are used. 153 | **kwargs: Kwargs for the base task class. 
154 | """ 155 | super().__init__(*args, **kwargs) 156 | 157 | self._modulus = modulus 158 | if operators is None: 159 | operators = ('+', '*', '-') 160 | self._operators = [OP_BY_CHARACTER[op] for op in operators] 161 | 162 | @functools.partial(jax.jit, static_argnums=(0, 2, 3)) 163 | def sample_batch( 164 | self, 165 | rng: chex.PRNGKey, 166 | batch_size: int, 167 | length: int, 168 | ) -> task.Batch: 169 | """Returns a batch of modular arithmetic expressions and their labels. 170 | 171 | Args: 172 | rng: The jax random number generator. 173 | batch_size: The size of the batch returned. 174 | length: The length of the sequence. As this length must be odd for the 175 | modular arithmetic dataset, if it's not, we force it to be by 176 | subtracting one to the length passed. 177 | """ 178 | # Subtracting one to the length if it's not odd already. 179 | if length % 2 != 1: 180 | length -= 1 181 | 182 | batch = jnp.empty((batch_size, length), dtype=int) 183 | rng1, rng2 = jax.random.split(rng) 184 | remainders = jax.random.randint( 185 | rng1, (batch_size, length // 2 + 1), 0, self._modulus 186 | ) 187 | ops = self._modulus + jnp.array(self._operators) 188 | 189 | operations = jrandom.choice(rng2, ops, (batch_size, length // 2)) 190 | batch = batch.at[:, ::2].set(remainders) 191 | expressions = batch.at[:, 1::2].set(operations) 192 | 193 | evaluate = functools.partial(_evaluate_expression, modulus=self._modulus) 194 | labels = jax.vmap(evaluate)(expressions) 195 | labels = jnn.one_hot(labels, self._modulus) 196 | one_hot_expressions = jnn.one_hot( 197 | expressions, self._modulus + len(OP_BY_CHARACTER) 198 | ) 199 | return {'input': one_hot_expressions, 'output': labels} 200 | 201 | @property 202 | def input_size(self) -> int: 203 | """Returns the input size for the models.""" 204 | return self._modulus + len(OP_BY_CHARACTER) 205 | 206 | @property 207 | def output_size(self) -> int: 208 | """Returns the output size for the models.""" 209 | return self._modulus 210 | -------------------------------------------------------------------------------- /tasks/regular/parity_check.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Compute whether the number of 1s in a string is even.""" 17 | 18 | import functools 19 | 20 | import chex 21 | import jax 22 | import jax.nn as jnn 23 | import jax.numpy as jnp 24 | import jax.random as jrandom 25 | 26 | from randomized_positional_encodings.tasks import task 27 | 28 | 29 | class ParityCheck(task.GeneralizationTask): 30 | """A task with the goal of counting the number of '1' in a string, modulo 2. 31 | 32 | The input is a string, composed of 0s and 1s. If the result is even, the class 33 | is 0, otherwise it's 1. 
34 | 35 | Examples: 36 | 1010100 -> 3 1s (odd) -> class 1 37 | 01111 -> 4 1s (even) -> class 0 38 | 39 | Note that the sampling is jittable so this task is fast. 40 | """ 41 | 42 | @functools.partial(jax.jit, static_argnums=(0, 2, 3)) 43 | def sample_batch( 44 | self, rng: chex.PRNGKey, batch_size: int, length: int 45 | ) -> task.Batch: 46 | """Returns a batch of strings and the expected class.""" 47 | strings = jrandom.randint( 48 | rng, shape=(batch_size, length), minval=0, maxval=2 49 | ) 50 | n_b = jnp.sum(strings, axis=1) % 2 51 | n_b = jnn.one_hot(n_b, num_classes=2) 52 | one_hot_strings = jnn.one_hot(strings, num_classes=2) 53 | return {"input": one_hot_strings, "output": n_b} 54 | 55 | @property 56 | def input_size(self) -> int: 57 | """Returns the input size for the models.""" 58 | return 2 59 | 60 | @property 61 | def output_size(self) -> int: 62 | """Returns the output size for the models.""" 63 | return 2 64 | -------------------------------------------------------------------------------- /tasks/task.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Base class for length generalization tasks.""" 17 | 18 | import abc 19 | from typing import TypedDict 20 | 21 | import chex 22 | import jax.nn as jnn 23 | import jax.numpy as jnp 24 | 25 | Batch = TypedDict('Batch', {'input': chex.Array, 'output': chex.Array}) 26 | 27 | 28 | class GeneralizationTask(abc.ABC): 29 | """A task for the generalization project. 30 | 31 | Exposes a sample_batch method, and some details about input/output sizes, 32 | losses and accuracies. 33 | """ 34 | 35 | @abc.abstractmethod 36 | def sample_batch( 37 | self, rng: chex.PRNGKey, batch_size: int, length: int 38 | ) -> Batch: 39 | """Returns a batch of inputs/outputs.""" 40 | 41 | def pointwise_loss_fn( 42 | self, output: chex.Array, target: chex.Array 43 | ) -> chex.Array: 44 | """Returns the pointwise loss between an output and a target.""" 45 | return -target * jnn.log_softmax(output) 46 | 47 | def accuracy_fn(self, output: chex.Array, target: chex.Array) -> chex.Array: 48 | """Returns the accuracy between an output and a target.""" 49 | return (jnp.argmax(output, axis=-1) == jnp.argmax(target, axis=-1)).astype( 50 | jnp.float32 51 | ) 52 | 53 | def accuracy_mask(self, target: chex.Array) -> chex.Array: 54 | """Returns a mask to compute the accuracies, to remove the superfluous ones.""" 55 | # Target is a shape of shape (B, T, C) where C is the number of classes. 56 | # We want a mask per input (B, T), so we take this shape. 
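    # For illustration (not part of the original source): for a target of shape
    # (batch_size=32, output_length=10, output_size=5), this returns an
    # all-ones mask of shape (32, 10).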
57 | return jnp.ones(target.shape[:-1]) 58 | 59 | @property 60 | @abc.abstractmethod 61 | def input_size(self) -> int: 62 | """Returns the size of the input of the models trained on this task.""" 63 | 64 | @property 65 | @abc.abstractmethod 66 | def output_size(self) -> int: 67 | """Returns the size of the output of the models trained on this task.""" 68 | 69 | def output_length(self, input_length: int) -> int: 70 | """Returns the length of the output, given an input length.""" 71 | del input_length 72 | return 1 73 | --------------------------------------------------------------------------------
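The snippet below is a minimal usage sketch added for illustration; it is not a
file from the repository. It only assumes that the package is importable as
`randomized_positional_encodings` (matching the imports used in the files
above) and exercises the `sample_batch` API defined in `tasks/task.py` through
the `ParityCheck` task.

import jax

from randomized_positional_encodings.tasks.regular import parity_check

# Instantiate the task and sample a batch of one-hot encoded inputs/outputs.
parity_task = parity_check.ParityCheck()
batch = parity_task.sample_batch(jax.random.PRNGKey(0), 4, 8)

print(batch['input'].shape)   # (4, 8, 2): one-hot tokens over {0, 1}.
print(batch['output'].shape)  # (4, 2): one-hot parity class.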