├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── chomsky.svg
├── experiments
│   ├── constants.py
│   ├── curriculum.py
│   ├── example.py
│   ├── range_evaluation.py
│   ├── training.py
│   └── utils.py
├── models
│   ├── ndstack_rnn.py
│   ├── positional_encodings.py
│   ├── rnn.py
│   ├── stack_rnn.py
│   ├── tape_rnn.py
│   └── transformer.py
├── requirements.txt
└── tasks
    ├── cs
    │   ├── binary_addition.py
    │   ├── binary_multiplication.py
    │   ├── bucket_sort.py
    │   ├── compute_sqrt.py
    │   ├── duplicate_string.py
    │   ├── missing_duplicate_string.py
    │   └── odds_first.py
    ├── dcf
    │   ├── modular_arithmetic_brackets.py
    │   ├── reverse_string.py
    │   ├── solve_equation.py
    │   └── stack_manipulation.py
    ├── regular
    │   ├── cycle_navigation.py
    │   ├── even_pairs.py
    │   ├── modular_arithmetic.py
    │   └── parity_check.py
    └── task.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Distribution / packaging
7 | .Python
8 | build/
9 | develop-eggs/
10 | dist/
11 | downloads/
12 | eggs/
13 | .eggs/
14 | lib/
15 | lib64/
16 | parts/
17 | sdist/
18 | var/
19 | wheels/
20 | share/python-wheels/
21 | *.egg-info/
22 | .installed.cfg
23 | *.egg
24 | MANIFEST
25 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | ## Contributor License Agreement
4 |
5 | Contributions to this project must be accompanied by a Contributor License
6 | Agreement. You (or your employer) retain the copyright to your contribution;
7 | this simply gives us permission to use and redistribute your contributions as
8 | part of the project. Head over to <https://cla.developers.google.com/> to see
9 | your current agreements on file or to sign a new one.
10 |
11 | You generally only need to submit a CLA once, so if you've already submitted one
12 | (even if it was for a different project), you probably don't need to do it
13 | again.
14 |
15 | ## Code reviews
16 |
17 | All submissions, including submissions by project members, require review. We
18 | use GitHub pull requests for this purpose. Consult
19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
20 | information on using pull requests.
21 |
22 | ## Community Guidelines
23 |
24 | This project follows [Google's Open Source Community
25 | Guidelines](https://opensource.google/conduct/).
26 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Neural Networks and the Chomsky Hierarchy
2 |
3 |
4 |
5 |
6 |
7 | This repository provides an implementation of our ICLR 2023 paper [Neural Networks and the Chomsky Hierarchy](https://arxiv.org/abs/2207.02098).
8 |
9 | > Reliable generalization lies at the heart of safe ML and AI.
10 | However, understanding when and how neural networks generalize remains one of the most important unsolved problems in the field.
11 | In this work, we conduct an extensive empirical study (2200 models, 16 tasks) to investigate whether insights from the theory of computation can predict the limits of neural network generalization in practice.
12 | We demonstrate that grouping tasks according to the Chomsky hierarchy allows us to forecast whether certain architectures will be able to generalize to out-of-distribution inputs.
13 | This includes negative results where even extensive amounts of data and training time never led to any non-trivial generalization, despite models having sufficient capacity to perfectly fit the training data.
14 | Our results show that, for our subset of tasks, RNNs and Transformers fail to generalize on non-regular tasks, LSTMs can solve regular and counter-language tasks, and only networks augmented with structured memory (such as a stack or memory tape) can successfully generalize on context-free and context-sensitive tasks.
15 |
16 | It is based on [JAX](https://jax.readthedocs.io) and [Haiku](https://dm-haiku.readthedocs.io) and contains all code, datasets, and models necessary to reproduce the paper's results.
17 |
18 |
19 | ## Content
20 |
21 | ```
22 | .
23 | ├── models
24 | | ├── ndstack_rnn.py - Nondeterministic Stack-RNN (DuSell & Chiang, 2021)
25 | | ├── rnn.py - RNN (Elman, 1990)
26 | | ├── stack_rnn.py - Stack-RNN (Joulin & Mikolov, 2015)
27 | | ├── tape_rnn.py - Tape-RNN, loosely based on Baby-NTM (Suzgun et al., 2019)
28 | | └── transformer.py - Transformer (Vaswani et al., 2017)
29 | ├── tasks
30 | | ├── cs - Context-sensitive tasks
31 | | ├── dcf - Deterministic context-free tasks
32 | | ├── regular - Regular tasks
33 | | └── task.py - Abstract GeneralizationTask
34 | ├── experiments
35 | | ├── constants.py - Training/Evaluation constants
36 | | ├── curriculum.py - Training curricula (over sequence lengths)
37 | | ├── example.py - Example training script (RNN on the Even Pairs task)
38 | | ├── range_evaluation.py - Evaluation loop (over unseen sequence lengths)
39 | | ├── training.py - Training loop
40 | | └── utils.py - Utility functions
41 | ├── README.md
42 | └── requirements.txt - Dependencies
43 | ```
44 |
45 | `tasks` contains all tasks, organized in their Chomsky hierarchy levels (regular, dcf, cs).
46 | They all inherit the abstract class `GeneralizationTask`, defined in `tasks/task.py`.
47 |
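For illustration, a minimal sketch of what a new task can look like is shown below. The exact abstract interface is defined in `tasks/task.py`; the method names used here (`sample_batch`, `output_size`, `output_length`) are simply the ones the training code in `experiments/` calls, and the base class has further members (e.g. the loss and accuracy functions used in `experiments/example.py`) that are omitted, so treat the signatures as illustrative rather than authoritative.

```python
import jax
import jax.nn as jnn
import jax.numpy as jnp

from neural_networks_chomsky_hierarchy.tasks import task as task_lib


class CopyFirstToken(task_lib.GeneralizationTask):
  """Toy task: predict the first token of a binary input string."""

  @property
  def input_size(self) -> int:
    return 2  # Binary vocabulary.

  @property
  def output_size(self) -> int:
    return 2

  def output_length(self, input_length: int) -> int:
    return 1  # A single prediction per sequence.

  def sample_batch(self, rng: jnp.ndarray, batch_size: int, length: int):
    """Returns one-hot inputs and targets, as consumed by the training loop."""
    tokens = jax.random.randint(rng, (batch_size, length), 0, 2)
    return {
        'input': jnn.one_hot(tokens, self.input_size),
        'output': jnn.one_hot(tokens[:, 0], self.output_size),
    }
```
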
48 | `models` contains all the models we use, written in [jax](https://github.com/google/jax) and [haiku](https://github.com/deepmind/dm-haiku), two open source libraries.
49 |
50 | `experiments` contains the code for training models and evaluating them on a wide range of lengths.
51 | We also included an example to train and evaluate an RNN on the Even Pairs task.
52 | We use [optax](https://github.com/deepmind/optax) for our optimizers.
53 |
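For illustration, the builder dictionaries in `experiments/constants.py` tie these directories together; the snippet below mirrors the default settings of `experiments/example.py` (task `even_pairs`, architecture `tape_rnn`):

```python
from neural_networks_chomsky_hierarchy.experiments import constants

task = constants.TASK_BUILDERS['even_pairs']()  # A GeneralizationTask instance.
model = constants.MODEL_BUILDERS['tape_rnn'](
    output_size=task.output_size,
    return_all_outputs=True,
    hidden_size=256,
    memory_cell_size=8,
    memory_size=40,
)
# The callable `model` is then wrapped (see experiments/utils.py), passed to
# hk.transform, and trained by training.TrainingWorker; experiments/example.py
# shows the full pipeline.
```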
54 |
55 | ## Installation
56 |
57 | Clone the source code into a local directory:
58 | ```bash
59 | git clone https://github.com/google-deepmind/neural_networks_chomsky_hierarchy.git
60 | cd neural_networks_chomsky_hierarchy
61 | ```
62 |
63 | `pip install -r requirements.txt` will install all required dependencies.
64 | This is best done inside a [conda environment](https://www.anaconda.com/).
65 | To that end, install [Anaconda](https://www.anaconda.com/download#downloads).
66 | Then, create and activate the conda environment:
67 | ```bash
68 | conda create --name nnch
69 | conda activate nnch
70 | ```
71 |
72 | Install `pip` and use it to install all the dependencies:
73 | ```bash
74 | conda install pip
75 | pip install -r requirements.txt
76 | ```
77 |
78 | If you have a GPU available (highly recommended for fast training), then you can install JAX with CUDA support.
79 | ```bash
80 | pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
81 | ```
82 | Note that the jax version must correspond to the existing CUDA installation you wish to use (CUDA 12 in the example above).
83 | Please see the [JAX documentation](https://github.com/google/jax#installation) for more details.
84 |
85 |
86 |
87 |
88 |
89 | ## Usage
90 |
91 | Before running any code, make sure to activate the conda environment and set the `PYTHONPATH`:
92 | ```bash
93 | conda activate nnch
94 | export PYTHONPATH=$(pwd)/..
95 | ```
96 |
97 | We provide an example of a training and evaluation run at:
98 | ```bash
99 | python experiments/example.py
100 | ```
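
The script's behaviour can be adjusted via its command-line flags (defined at the top of `experiments/example.py`), for example:
```bash
python experiments/example.py --task=reverse_string --batch_size=64 --sequence_length=40
```
Note that changing `--architecture` may also require adapting the architecture parameters hard-coded in `experiments/example.py`, as explained in that script's comments.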
101 |
102 |
103 | ## Citing This Work
104 |
105 | ```bibtex
106 | @inproceedings{deletang2023neural,
107 | author = {Gr{\'{e}}goire Del{\'{e}}tang and
108 | Anian Ruoss and
109 | Jordi Grau{-}Moya and
110 | Tim Genewein and
111 | Li Kevin Wenliang and
112 | Elliot Catt and
113 | Chris Cundy and
114 | Marcus Hutter and
115 | Shane Legg and
116 | Joel Veness and
117 | Pedro A. Ortega},
118 | title = {Neural Networks and the Chomsky Hierarchy},
119 | booktitle = {11th International Conference on Learning Representations},
120 | year = {2023},
121 | }
122 | ```
123 |
124 |
125 | ## License and Disclaimer
126 |
127 | Copyright 2022 DeepMind Technologies Limited
128 |
129 | All software is licensed under the Apache License, Version 2.0 (Apache 2.0);
130 | you may not use this file except in compliance with the Apache 2.0 license.
131 | You may obtain a copy of the Apache 2.0 license at:
132 | https://www.apache.org/licenses/LICENSE-2.0
133 |
134 | All other materials are licensed under the Creative Commons Attribution 4.0
135 | International License (CC-BY). You may obtain a copy of the CC-BY license at:
136 | https://creativecommons.org/licenses/by/4.0/legalcode
137 |
138 | Unless required by applicable law or agreed to in writing, all software and
139 | materials distributed here under the Apache 2.0 or CC-BY licenses are
140 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
141 | either express or implied. See the licenses for the specific language governing
142 | permissions and limitations under those licenses.
143 |
144 | This is not an official Google product.
145 |
--------------------------------------------------------------------------------
/experiments/constants.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Constants for our length generalization experiments."""
17 |
18 | import functools
19 |
20 | import haiku as hk
21 |
22 | from neural_networks_chomsky_hierarchy.experiments import curriculum as curriculum_lib
23 | from neural_networks_chomsky_hierarchy.models import ndstack_rnn
24 | from neural_networks_chomsky_hierarchy.models import rnn
25 | from neural_networks_chomsky_hierarchy.models import stack_rnn
26 | from neural_networks_chomsky_hierarchy.models import tape_rnn
27 | from neural_networks_chomsky_hierarchy.models import transformer
28 | from neural_networks_chomsky_hierarchy.tasks.cs import binary_addition
29 | from neural_networks_chomsky_hierarchy.tasks.cs import binary_multiplication
30 | from neural_networks_chomsky_hierarchy.tasks.cs import bucket_sort
31 | from neural_networks_chomsky_hierarchy.tasks.cs import compute_sqrt
32 | from neural_networks_chomsky_hierarchy.tasks.cs import duplicate_string
33 | from neural_networks_chomsky_hierarchy.tasks.cs import missing_duplicate_string
34 | from neural_networks_chomsky_hierarchy.tasks.cs import odds_first
35 | from neural_networks_chomsky_hierarchy.tasks.dcf import modular_arithmetic_brackets
36 | from neural_networks_chomsky_hierarchy.tasks.dcf import reverse_string
37 | from neural_networks_chomsky_hierarchy.tasks.dcf import solve_equation
38 | from neural_networks_chomsky_hierarchy.tasks.dcf import stack_manipulation
39 | from neural_networks_chomsky_hierarchy.tasks.regular import cycle_navigation
40 | from neural_networks_chomsky_hierarchy.tasks.regular import even_pairs
41 | from neural_networks_chomsky_hierarchy.tasks.regular import modular_arithmetic
42 | from neural_networks_chomsky_hierarchy.tasks.regular import parity_check
43 |
44 |
45 | MODEL_BUILDERS = {
46 | 'rnn':
47 | functools.partial(rnn.make_rnn, rnn_core=hk.VanillaRNN),
48 | 'lstm':
49 | functools.partial(rnn.make_rnn, rnn_core=hk.LSTM),
50 | 'stack_rnn':
51 | functools.partial(
52 | rnn.make_rnn,
53 | rnn_core=stack_rnn.StackRNNCore,
54 | inner_core=hk.VanillaRNN),
55 | 'ndstack_rnn':
56 | functools.partial(
57 | rnn.make_rnn,
58 | rnn_core=ndstack_rnn.NDStackRNNCore,
59 | inner_core=hk.VanillaRNN),
60 | 'stack_lstm':
61 | functools.partial(
62 | rnn.make_rnn, rnn_core=stack_rnn.StackRNNCore, inner_core=hk.LSTM),
63 | 'transformer_encoder':
64 | transformer.make_transformer_encoder,
65 | 'transformer':
66 | transformer.make_transformer,
67 | 'tape_rnn':
68 | functools.partial(
69 | rnn.make_rnn,
70 | rnn_core=tape_rnn.TapeInputLengthJumpCore,
71 | inner_core=hk.VanillaRNN),
72 | }
73 |
74 | CURRICULUM_BUILDERS = {
75 | 'fixed': curriculum_lib.FixedCurriculum,
76 | 'regular_increase': curriculum_lib.RegularIncreaseCurriculum,
77 | 'reverse_exponential': curriculum_lib.ReverseExponentialCurriculum,
78 | 'uniform': curriculum_lib.UniformCurriculum,
79 | }
80 |
81 | TASK_BUILDERS = {
82 | 'modular_arithmetic':
83 | modular_arithmetic.ModularArithmetic,
84 | 'parity_check':
85 | parity_check.ParityCheck,
86 | 'even_pairs':
87 | even_pairs.EvenPairs,
88 | 'cycle_navigation':
89 | cycle_navigation.CycleNavigation,
90 | 'modular_arithmetic_brackets':
91 | functools.partial(
92 | modular_arithmetic_brackets.ModularArithmeticBrackets, mult=True),
93 | 'reverse_string':
94 | reverse_string.ReverseString,
95 | 'missing_duplicate_string':
96 | missing_duplicate_string.MissingDuplicateString,
97 | 'duplicate_string':
98 | duplicate_string.DuplicateString,
99 | 'binary_addition':
100 | binary_addition.BinaryAddition,
101 | 'binary_multiplication':
102 | binary_multiplication.BinaryMultiplication,
103 | 'compute_sqrt':
104 | compute_sqrt.ComputeSqrt,
105 | 'odds_first':
106 | odds_first.OddsFirst,
107 | 'solve_equation':
108 | solve_equation.SolveEquation,
109 | 'stack_manipulation':
110 | stack_manipulation.StackManipulation,
111 | 'bucket_sort':
112 | bucket_sort.BucketSort,
113 | }
114 |
115 | TASK_LEVELS = {
116 | 'modular_arithmetic': 'regular',
117 | 'parity_check': 'regular',
118 | 'even_pairs': 'regular',
119 | 'cycle_navigation': 'regular',
120 | 'modular_arithmetic_brackets': 'dcf',
121 | 'reverse_string': 'dcf',
122 | 'stack_manipulation': 'dcf',
123 | 'solve_equation': 'dcf',
124 | 'missing_duplicate_string': 'cs',
125 | 'compute_sqrt': 'cs',
126 | 'duplicate_string': 'cs',
127 | 'binary_addition': 'cs',
128 | 'binary_multiplication': 'cs',
129 | 'odds_first': 'cs',
130 | 'bucket_sort': 'cs',
131 | }
132 |
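# Illustrative usage: the registries above are indexed by name, e.g.
#
#   curriculum = CURRICULUM_BUILDERS['uniform'](values=list(range(1, 41)))
#   level = TASK_LEVELS['parity_check']  # -> 'regular'
#
# experiments/example.py shows how TASK_BUILDERS and MODEL_BUILDERS are used to
# construct a task and a model.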
--------------------------------------------------------------------------------
/experiments/curriculum.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Curricula over sequence lengths used to evaluate length generalization.
17 |
18 | Allows sampling different sequence lengths during training. For instance,
19 | one might want to start with length=1 and regularly increase the length by 1,
20 | every 50k steps.
21 | """
22 |
23 | import abc
24 | from collections.abc import Collection
25 | import random
26 |
27 | import numpy as np
28 |
29 |
30 | class Curriculum(abc.ABC):
31 | """Curriculum to sample lengths."""
32 |
33 | @abc.abstractmethod
34 | def sample_sequence_length(self, step: int) -> int:
35 | """Samples a sequence length from the current distribution."""
36 |
37 |
38 | class FixedCurriculum(Curriculum):
39 | """A fixed curriculum, always sampling the same sequence length."""
40 |
41 | def __init__(self, sequence_length: int):
42 | """Initializes.
43 |
44 | Args:
45 | sequence_length: The sequence length to sample.
46 | """
47 | super().__init__()
48 | self._sequence_length = sequence_length
49 |
50 | def sample_sequence_length(self, step: int) -> int:
51 | """Returns a fixed sequence length."""
52 | del step
53 | return self._sequence_length
54 |
55 |
56 | class UniformCurriculum(Curriculum):
57 | """A uniform curriculum, sampling different sequence lengths."""
58 |
59 | def __init__(self, values: Collection[int]):
60 | """Initializes.
61 |
62 | Args:
63 | values: The sequence lengths to sample.
64 | """
65 | super().__init__()
66 | self._values = tuple(values)
67 |
68 | def sample_sequence_length(self, step: int) -> int:
69 | """Returns a sequence length sampled from a uniform distribution."""
70 | del step
71 | return random.choice(self._values)
72 |
73 |
74 | class ReverseExponentialCurriculum(Curriculum):
75 | """A reverse exponential curriculum, sampling different sequence lengths."""
76 |
77 |   def __init__(self, values: Collection[int], tau: float):
78 | """Initializes.
79 |
80 | Args:
81 | values: The sequence lengths to sample.
82 | tau: The exponential rate to use.
83 | """
84 | super().__init__()
85 | self._values = tuple(values)
86 | self._tau = tau
87 |
88 | def sample_sequence_length(self, step: int) -> int:
89 | """Returns a length sampled from a reverse exponential distribution."""
90 | del step
91 | probs = self._tau**np.array(self._values)
92 | probs = np.array(probs, dtype=np.float32)
93 | probs = probs / np.sum(probs)
94 | return np.random.choice(self._values, p=probs)
95 |
96 |
97 | class RegularIncreaseCurriculum(Curriculum):
98 | """Curriculum for sequence lengths with a regular increase."""
99 |
100 | def __init__(self, initial_sequence_length: int, increase_frequency: int,
101 | increase_amount: int, sample_all_length: bool):
102 | """Initializes.
103 |
104 | Args:
105 | initial_sequence_length: The value of the sequence length at the beginning
106 | of the curriculum.
107 | increase_frequency: How often we increase the possible sequence length.
108 | increase_amount: The amount of the increase in length.
109 |       sample_all_length: Whether to sample uniformly among all lengths up to
110 |         the current one, or to just return the current one.
111 | """
112 | super().__init__()
113 | self._initial_sequence_length = initial_sequence_length
114 | self._increase_frequency = increase_frequency
115 | self._increase_amount = increase_amount
116 | self._sample_all_length = sample_all_length
117 |
118 | def sample_sequence_length(self, step: int) -> int:
119 | """Returns a sequence length from the curriculum with the current step."""
120 | if not self._sample_all_length:
121 | return self._initial_sequence_length + self._increase_amount * (
122 | step // self._increase_frequency
123 | )
124 | return (
125 | self._initial_sequence_length
126 | + self._increase_amount
127 | * np.random.randint(0, step // self._increase_frequency + 1)
128 | )
129 |
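# Illustrative example, following the module docstring above (start at
# length 1 and increase by 1 every 50k steps):
#
#   curriculum = RegularIncreaseCurriculum(
#       initial_sequence_length=1,
#       increase_frequency=50_000,
#       increase_amount=1,
#       sample_all_length=False,
#   )
#   curriculum.sample_sequence_length(0)        # -> 1
#   curriculum.sample_sequence_length(120_000)  # -> 1 + 1 * (120_000 // 50_000) = 3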
--------------------------------------------------------------------------------
/experiments/example.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Example script to train and evaluate a network."""
17 |
18 | from absl import app
19 | from absl import flags
20 | import haiku as hk
21 | import jax.numpy as jnp
22 | import numpy as np
23 |
24 | from neural_networks_chomsky_hierarchy.experiments import constants
25 | from neural_networks_chomsky_hierarchy.experiments import curriculum as curriculum_lib
26 | from neural_networks_chomsky_hierarchy.experiments import training
27 | from neural_networks_chomsky_hierarchy.experiments import utils
28 |
29 | _BATCH_SIZE = flags.DEFINE_integer(
30 | 'batch_size',
31 | default=128,
32 | help='Training batch size.',
33 | lower_bound=1,
34 | )
35 | _SEQUENCE_LENGTH = flags.DEFINE_integer(
36 | 'sequence_length',
37 | default=40,
38 | help='Maximum training sequence length.',
39 | lower_bound=1,
40 | )
41 | _TASK = flags.DEFINE_string(
42 | 'task',
43 | default='even_pairs',
44 | help='Length generalization task (see `constants.py` for other tasks).',
45 | )
46 | _ARCHITECTURE = flags.DEFINE_string(
47 | 'architecture',
48 | default='tape_rnn',
49 | help='Model architecture (see `constants.py` for other architectures).',
50 | )
51 |
52 | _IS_AUTOREGRESSIVE = flags.DEFINE_boolean(
53 | 'is_autoregressive',
54 | default=False,
55 | help='Whether to use autoregressive sampling or not.',
56 | )
57 | _COMPUTATION_STEPS_MULT = flags.DEFINE_integer(
58 | 'computation_steps_mult',
59 | default=0,
60 | help=(
61 | 'The amount of computation tokens to append to the input tape (defined'
62 | ' as a multiple of the input length)'
63 | ),
64 | lower_bound=0,
65 | )
66 | # The architecture parameters depend on the architecture, so we cannot define
67 | # them via flags. See `constants.py` for the required values.
68 | _ARCHITECTURE_PARAMS = {
69 | 'hidden_size': 256,
70 | 'memory_cell_size': 8,
71 | 'memory_size': 40,
72 | }
73 |
74 |
75 | def main(unused_argv) -> None:
76 | # Create the task.
77 | curriculum = curriculum_lib.UniformCurriculum(
78 | values=list(range(1, _SEQUENCE_LENGTH.value + 1))
79 | )
80 | task = constants.TASK_BUILDERS[_TASK.value]()
81 |
82 | # Create the model.
83 | single_output = task.output_length(10) == 1
84 | model = constants.MODEL_BUILDERS[_ARCHITECTURE.value](
85 | output_size=task.output_size,
86 | return_all_outputs=True,
87 | **_ARCHITECTURE_PARAMS,
88 | )
89 | if _IS_AUTOREGRESSIVE.value:
90 | if 'transformer' not in _ARCHITECTURE.value:
91 | model = utils.make_model_with_targets_as_input(
92 | model, _COMPUTATION_STEPS_MULT.value
93 | )
94 | model = utils.add_sampling_to_autoregressive_model(model, single_output)
95 | else:
96 | model = utils.make_model_with_empty_targets(
97 | model, task, _COMPUTATION_STEPS_MULT.value, single_output
98 | )
99 | model = hk.transform(model)
100 |
101 | # Create the loss and accuracy based on the pointwise ones.
102 | def loss_fn(output, target):
103 | loss = jnp.mean(jnp.sum(task.pointwise_loss_fn(output, target), axis=-1))
104 | return loss, {}
105 |
106 | def accuracy_fn(output, target):
107 | mask = task.accuracy_mask(target)
108 | return jnp.sum(mask * task.accuracy_fn(output, target)) / jnp.sum(mask)
109 |
110 | # Create the final training parameters.
111 | training_params = training.ClassicTrainingParams(
112 | seed=0,
113 | model_init_seed=0,
114 | training_steps=10_000,
115 | log_frequency=100,
116 | length_curriculum=curriculum,
117 | batch_size=_BATCH_SIZE.value,
118 | task=task,
119 | model=model,
120 | loss_fn=loss_fn,
121 | learning_rate=1e-3,
122 | accuracy_fn=accuracy_fn,
123 | compute_full_range_test=True,
124 | max_range_test_length=100,
125 | range_test_total_batch_size=512,
126 | range_test_sub_batch_size=64,
127 | is_autoregressive=_IS_AUTOREGRESSIVE.value,
128 | )
129 |
130 | training_worker = training.TrainingWorker(training_params, use_tqdm=True)
131 | _, eval_results, _ = training_worker.run()
132 |
133 | # Gather results and print final score.
134 | accuracies = [r['accuracy'] for r in eval_results]
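  # `accuracies[i]` holds the accuracy for sequence length `i + 1` (evaluation
  # lengths run from 1 to `max_range_test_length`), so this slice keeps only
  # lengths longer than those seen during training.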
135 | score = np.mean(accuracies[_SEQUENCE_LENGTH.value + 1 :])
136 | print(f'Network score: {score}')
137 |
138 |
139 | if __name__ == '__main__':
140 | app.run(main)
141 |
--------------------------------------------------------------------------------
/experiments/range_evaluation.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Evaluation of a network on sequences of different lengths."""
17 |
18 | import dataclasses
19 | import random
20 | from typing import Any, Callable, Mapping
21 |
22 | from absl import logging
23 | import haiku as hk
24 | import jax
25 | import jax.numpy as jnp
26 | import numpy as np
27 | import tqdm
28 |
29 |
30 | _Batch = Mapping[str, jnp.ndarray]
31 |
32 |
33 | @dataclasses.dataclass
34 | class EvaluationParams:
35 | """The parameters used for range evaluation of networks."""
36 | model: hk.Transformed
37 | params: hk.Params
38 |
39 | accuracy_fn: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
40 | sample_batch: Callable[[jnp.ndarray, int, int], _Batch]
41 |
42 | max_test_length: int
43 | total_batch_size: int
44 | sub_batch_size: int # We use this to avoid memory overflow.
45 |
46 | is_autoregressive: bool = False
47 |
48 |
49 | def range_evaluation(
50 | eval_params: EvaluationParams,
51 | use_tqdm: bool = False,
52 | ) -> list[Mapping[str, Any]]:
53 | """Evaluates the model on longer, never seen strings and log the results.
54 |
55 | Args:
56 | eval_params: The evaluation parameters, see above.
57 | use_tqdm: Whether to use a progress bar with tqdm.
58 |
59 | Returns:
60 | The list of dicts containing the accuracies.
61 | """
62 | model = eval_params.model
63 | params = eval_params.params
64 |
65 | random.seed(1)
66 | np.random.seed(1)
67 | rng_seq = hk.PRNGSequence(1)
68 |
69 | if eval_params.is_autoregressive:
70 | apply_fn = jax.jit(model.apply, static_argnames=('sample',))
71 | else:
72 | apply_fn = jax.jit(model.apply)
73 |
74 | results = []
75 | lengths = range(1, eval_params.max_test_length + 1)
76 | if use_tqdm:
77 | lengths = tqdm.tqdm(lengths)
78 | for length in lengths:
79 |     # We need to clear the cache of jitted functions to avoid memory overflow:
80 |     # a separate executable is cached for each of the len(lengths) lengths.
81 | apply_fn.clear_cache()
82 | sub_accuracies = []
83 | for _ in range(eval_params.total_batch_size // eval_params.sub_batch_size):
84 | batch = eval_params.sample_batch(
85 | next(rng_seq), eval_params.sub_batch_size, length)
86 |
87 | if eval_params.is_autoregressive:
88 | outputs = apply_fn(
89 | params,
90 | next(rng_seq),
91 | batch['input'],
92 | jnp.empty_like(batch['output']),
93 | sample=True)
94 | else:
95 | outputs = apply_fn(params, next(rng_seq), batch['input'])
96 |
97 | sub_accuracies.append(
98 | float(np.mean(eval_params.accuracy_fn(outputs, batch['output']))))
99 | log_data = {
100 | 'length': length,
101 | 'accuracy': np.mean(sub_accuracies),
102 | }
103 | logging.info(log_data)
104 | results.append(log_data)
105 | return results
106 |
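# Illustrative usage, mirroring experiments/training.py:
#
#   eval_params = EvaluationParams(
#       model=model,                     # An hk.Transformed model.
#       params=params,                   # Trained parameters.
#       accuracy_fn=accuracy_fn,
#       sample_batch=task.sample_batch,  # (rng, batch_size, length) -> batch.
#       max_test_length=100,
#       total_batch_size=512,
#       sub_batch_size=64,
#   )
#   results = range_evaluation(eval_params, use_tqdm=True)
#   # -> [{'length': 1, 'accuracy': ...}, ..., {'length': 100, 'accuracy': ...}]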
--------------------------------------------------------------------------------
/experiments/training.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Training loop for length generalization experiments."""
17 |
18 | import dataclasses
19 | import functools
20 | import random
21 | from typing import Any, Callable, Mapping, Optional
22 |
23 | import chex
24 | import haiku as hk
25 | import jax
26 | import jax.numpy as jnp
27 | import numpy as np
28 | import optax
29 | import tqdm
30 |
31 | from neural_networks_chomsky_hierarchy.experiments import curriculum as curriculum_lib
32 | from neural_networks_chomsky_hierarchy.experiments import range_evaluation
33 | from neural_networks_chomsky_hierarchy.tasks import task as task_lib
34 |
35 |
36 | _LossMetrics = Optional[Mapping[str, jnp.ndarray]]
37 | _LossFn = Callable[[chex.Array, chex.Array], tuple[float, _LossMetrics]]
38 | _AccuracyFn = Callable[[chex.Array, chex.Array], float]
39 | _ModelApplyFn = Callable[..., chex.Array]
40 | _MAX_RNGS_RESERVE = 50000
41 |
42 |
43 | @dataclasses.dataclass
44 | class ClassicTrainingParams:
45 | """Parameters needed to train classical architectures."""
46 | seed: int # Used to sample during forward pass (e.g. from final logits).
47 | model_init_seed: int # Used to initialize model parameters.
48 | training_steps: int
49 | log_frequency: int
50 |
51 | task: task_lib.GeneralizationTask
52 | length_curriculum: curriculum_lib.Curriculum
53 | batch_size: int
54 |
55 | model: hk.Transformed
56 | loss_fn: Callable[[jnp.ndarray, jnp.ndarray], tuple[float, _LossMetrics]]
57 | learning_rate: float
58 | test_model: Optional[hk.Transformed] = None
59 | max_grad_norm: float = 1.
60 | is_autoregressive: bool = False
61 |
62 | compute_full_range_test: bool = False
63 | range_test_total_batch_size: int = 512
64 | range_test_sub_batch_size: int = 64
65 | max_range_test_length: int = 100
66 |
67 | accuracy_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray],
68 | jnp.ndarray]] = None
69 |
70 |
71 | def _apply_loss_and_metrics_fn(
72 | params: hk.Params,
73 | rng_key: chex.PRNGKey,
74 | batch: task_lib.Batch,
75 | model_apply_fn: _ModelApplyFn,
76 | loss_fn: _LossFn,
77 | accuracy_fn: _AccuracyFn,
78 | is_autoregressive: bool = False,
79 | ) -> tuple[float, tuple[_LossMetrics, float]]:
80 | """Computes the model output and applies the loss function.
81 |
82 | Depending on whether a model is autoregressive or not, it will have a
83 | different number of input parameters (i.e., autoregressive models also require
84 | the targets as an input).
85 |
86 | Args:
87 | params: The model parameters.
88 | rng_key: The prng key to use for random number generation.
89 | batch: The data (consists of both inputs and outputs).
90 | model_apply_fn: The model function that converts inputs into outputs.
91 | loss_fn: A function that computes the loss for a batch of logits and labels.
92 | accuracy_fn: A function that computes the accuracy for a batch of logits and
93 | labels.
94 | is_autoregressive: Whether the model is autoregressive or not.
95 |
96 | Returns:
97 | The loss of the model for the batch of data, extra loss metrics and the
98 | accuracy, if accuracy_fn is not None.
99 | """
100 | if is_autoregressive:
101 | outputs = model_apply_fn(
102 | params, rng_key, batch["input"], batch["output"], sample=False)
103 | else:
104 | outputs = model_apply_fn(params, rng_key, batch["input"])
105 |
106 | loss, loss_metrics = loss_fn(outputs, batch["output"])
107 | if accuracy_fn is not None:
108 | accuracy = accuracy_fn(outputs, batch["output"])
109 | else:
110 | accuracy = None
111 | return loss, (loss_metrics, accuracy)
112 |
113 |
114 | @functools.partial(
115 | jax.jit,
116 | static_argnames=(
117 | "model_apply_fn",
118 | "loss_fn",
119 | "accuracy_fn",
120 | "optimizer",
121 | "is_autoregressive",
122 | ),
123 | )
124 | def _update_parameters(
125 | params: hk.Params,
126 | rng_key: chex.PRNGKey,
127 | batch: task_lib.Batch,
128 | model_apply_fn: _ModelApplyFn,
129 | loss_fn: _LossFn,
130 | accuracy_fn: _AccuracyFn,
131 | optimizer: optax.GradientTransformation,
132 | opt_state: optax.OptState,
133 | is_autoregressive: bool = False,
134 | ) -> tuple[hk.Params, optax.OptState, tuple[float, _LossMetrics, float]]:
135 | """Applies a single SGD update step to the model parameters.
136 |
137 | Args:
138 | params: The model parameters.
139 | rng_key: The prng key to use for random number generation.
140 | batch: The data (consists of both inputs and outputs).
141 | model_apply_fn: The model function that converts inputs into outputs.
142 | loss_fn: A function that computes the loss for a batch of logits and labels.
143 | accuracy_fn: A function that computes the accuracy for a batch of logits and
144 | labels.
145 | optimizer: The optimizer that computes the updates from the gradients of the
146 | `loss_fn` with respect to the `params` and the previous `opt_state`.
147 | opt_state: The optimizer state, e.g., momentum for each variable when using
148 | Adam.
149 | is_autoregressive: Whether the model is autoregressive or not.
150 |
151 | Returns:
152 | The updated parameters, the new optimizer state, and the loss, loss metrics
153 | and accuracy.
154 | """
155 | (loss, (metrics, accuracy)), grads = jax.value_and_grad(
156 | _apply_loss_and_metrics_fn,
157 | has_aux=True)(params, rng_key, batch, model_apply_fn, loss_fn,
158 | accuracy_fn, is_autoregressive)
159 | updates, new_opt_state = optimizer.update(grads, opt_state)
160 | new_params = optax.apply_updates(params, updates)
161 | return new_params, new_opt_state, (loss, metrics, accuracy)
162 |
163 |
164 | class TrainingWorker:
165 | """Training worker."""
166 |
167 | def __init__(self,
168 | training_params: ClassicTrainingParams,
169 | use_tqdm: bool = False):
170 | """Initializes the worker.
171 |
172 | Args:
173 | training_params: The training parameters.
174 | use_tqdm: Whether to add a progress bar to stdout.
175 | """
176 | self._training_params = training_params
177 | self._use_tqdm = use_tqdm
178 |
179 | def run(
180 | self,
181 | ) -> tuple[
182 | list[Mapping[str, Any]], Optional[list[Mapping[str, Any]]], chex.ArrayTree
183 | ]:
184 | """Trains the model with the provided config.
185 |
186 | Returns:
187 |       The training metrics, the range-evaluation results (or None if
188 |       `compute_full_range_test` is False), and the final model parameters.
189 | """
190 | training_params = self._training_params
191 | rngs_reserve = min(_MAX_RNGS_RESERVE, training_params.training_steps)
192 |
193 | random.seed(training_params.seed)
194 | np.random.seed(training_params.seed)
195 | rng_seq = hk.PRNGSequence(training_params.seed)
196 | rng_seq.reserve(rngs_reserve)
197 |
198 | results = []
199 | model = training_params.model
200 | task = training_params.task
201 | length_curriculum = training_params.length_curriculum
202 |
203 | optimizer = optax.chain(
204 | optax.clip_by_global_norm(training_params.max_grad_norm),
205 | optax.adam(training_params.learning_rate))
206 |
207 | dummy_batch = task.sample_batch(
208 | next(rng_seq), length=10, batch_size=training_params.batch_size)
209 | model_init_rng_key = jax.random.PRNGKey(training_params.model_init_seed)
210 |
211 | if training_params.is_autoregressive:
212 | params = model.init(
213 | model_init_rng_key,
214 | dummy_batch["input"],
215 | dummy_batch["output"],
216 | sample=False)
217 | else:
218 | params = model.init(model_init_rng_key, dummy_batch["input"])
219 |
220 | opt_state = optimizer.init(params)
221 | self._params, self._step = params, 0
222 |
223 | steps = range(training_params.training_steps + 1)
224 | if self._use_tqdm:
225 | steps = tqdm.tqdm(steps)
226 | for step in steps:
227 | # Randomness handled by either python.random or numpy.
228 | length = length_curriculum.sample_sequence_length(step)
229 | # Randomness handled by either jax, python.random or numpy.
230 | train_batch = task.sample_batch(
231 | next(rng_seq), length=length, batch_size=training_params.batch_size)
232 | params, opt_state, (
233 | train_loss, train_metrics, train_accuracy) = _update_parameters(
234 | params=params,
235 | rng_key=next(rng_seq),
236 | batch=train_batch,
237 | model_apply_fn=model.apply,
238 | loss_fn=training_params.loss_fn,
239 | accuracy_fn=training_params.accuracy_fn,
240 | optimizer=optimizer,
241 | opt_state=opt_state,
242 | is_autoregressive=training_params.is_autoregressive)
243 | self._params, self._step = params, step
244 |
245 | log_freq = training_params.log_frequency
246 | if (log_freq > 0) and (step % log_freq == 0):
247 | log_data = {
248 | "step": step,
249 | "train_loss": float(train_loss),
250 | }
251 | if training_params.accuracy_fn is not None:
252 | log_data["train_accuracy"] = float(train_accuracy)
253 | for key, value in train_metrics.items():
254 | log_data[".".join(["train_metrics", key])] = np.array(value)
255 | results.append(log_data)
256 |
257 | # We need to access this private attribute since the default reserve size
258 | # can not be edited yet.
259 | if not rng_seq._subkeys: # pylint: disable=protected-access
260 | rng_seq.reserve(rngs_reserve)
261 |
262 | eval_results = None
263 | if training_params.compute_full_range_test:
264 | eval_params = range_evaluation.EvaluationParams(
265 | model=training_params.test_model or model,
266 | params=params,
267 | accuracy_fn=training_params.accuracy_fn,
268 | sample_batch=task.sample_batch,
269 | max_test_length=training_params.max_range_test_length,
270 | total_batch_size=training_params.range_test_total_batch_size,
271 | sub_batch_size=training_params.range_test_sub_batch_size,
272 | is_autoregressive=training_params.is_autoregressive,
273 | )
274 | eval_results = range_evaluation.range_evaluation(
275 | eval_params, use_tqdm=False)
276 |
277 | return results, eval_results, params
278 |
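# Illustrative usage, mirroring experiments/example.py:
#
#   training_params = ClassicTrainingParams(
#       seed=0, model_init_seed=0, training_steps=10_000, log_frequency=100,
#       length_curriculum=curriculum, batch_size=128, task=task, model=model,
#       loss_fn=loss_fn, learning_rate=1e-3, accuracy_fn=accuracy_fn,
#       compute_full_range_test=True, max_range_test_length=100,
#       range_test_total_batch_size=512, range_test_sub_batch_size=64)
#   worker = TrainingWorker(training_params, use_tqdm=True)
#   train_logs, eval_results, final_params = worker.run()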
--------------------------------------------------------------------------------
/experiments/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Provides utility functions for training and evaluation."""
17 |
18 | import inspect
19 | from typing import Any, Callable
20 |
21 | import chex
22 | import haiku as hk
23 | from jax import nn as jnn
24 | from jax import numpy as jnp
25 |
26 | from neural_networks_chomsky_hierarchy.tasks import task
27 |
28 | COMPUTATION_EMPTY_TOKEN = 0
29 | OUTPUT_EMPTY_TOKEN = 1
30 |
31 |
32 | def make_model_with_empty_targets(
33 | model: Callable[[chex.Array], chex.Array],
34 | generalization_task: task.GeneralizationTask,
35 | computation_steps_mult: int = 0,
36 | single_output: bool = False,
37 | ) -> Callable[[chex.Array], chex.Array]:
38 | """Returns a wrapped model that pads the inputs to match the output length.
39 |
40 | For a given input tape `input_tape` of vocabulary size `vocab_size`, the
41 | wrapped model will process a tape of the format
42 | [`input_tape`, `empty_tape`], where the empty tape token is `vocab_size + 1`.
43 | The `empty_tape` has the same length as the task output.
44 |
45 | Args:
46 | model: A model function that converts inputs to outputs.
47 | generalization_task: The task that we train on.
48 | computation_steps_mult: The amount of empty cells to append to the input
49 | tape. This variable is a multiplier and the actual number of cells is
50 | `computation_steps_mult * input_length`.
51 | single_output: Whether to return the squeezed tensor of values.
52 | """
53 |
54 | def new_model(x: chex.Array) -> chex.Array:
55 | batch_size, input_length, input_size = x.shape
56 | output_length = generalization_task.output_length(input_length)
57 | extra_dims_onehot = 1 + int(computation_steps_mult > 0)
58 | final_input_size = input_size + extra_dims_onehot
59 |
60 | # Add trailing zeros to account for new final_input_size.
61 | extra_zeros_x = jnp.zeros(
62 | (batch_size, input_length, final_input_size - input_size)
63 | )
64 | x = jnp.concatenate([x, extra_zeros_x], axis=-1)
65 |
66 | computation_tape = jnp.full(
67 | (batch_size, computation_steps_mult * input_length),
68 | fill_value=input_size + COMPUTATION_EMPTY_TOKEN)
69 | computation_tape = jnn.one_hot(
70 | computation_tape, num_classes=final_input_size
71 | )
72 |
73 | output_tokens = jnp.full(
74 | (batch_size, output_length),
75 | fill_value=input_size
76 | + OUTPUT_EMPTY_TOKEN
77 | - int(computation_steps_mult == 0),
78 | )
79 | output_tokens = jnn.one_hot(output_tokens, num_classes=final_input_size)
80 | final_input = jnp.concatenate([x, computation_tape, output_tokens], axis=1)
81 |
82 | if 'input_length' in inspect.getfullargspec(model).args:
83 | output = model(final_input, input_length=input_length) # pytype: disable=wrong-keyword-args
84 | else:
85 | output = model(final_input)
86 | output = output[:, -output_length:]
87 | if single_output:
88 | output = jnp.squeeze(output, axis=1)
89 | return output
90 |
91 | return new_model
92 |
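# Illustrative example for the wrapper above: for one-hot inputs of shape
# (B, T, input_size), computation_steps_mult=1, and a task whose output length
# equals T, the wrapped model sees a tape of length 3 * T,
#
#   [ input tape (T) | computation tokens (T) | output tokens (T) ],
#
# and only the activations over the last T positions are returned as the
# prediction (when the task has a single output, single_output=True squeezes
# that axis away).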
93 |
94 | def make_model_with_targets_as_input(
95 | model: Callable[[chex.Array], chex.Array], computation_steps_mult: int = 0
96 | ) -> Callable[[chex.Array, chex.Array], chex.Array]:
97 | """Returns a wrapped model that takes the targets as inputs.
98 |
99 | This function is useful for the autoregressive case where we pass the targets
100 | as inputs to the model. The final input looks like:
101 | [inputs, computation_tokens, output_token, targets]
102 |
103 | Args:
104 | model: A haiku model that takes 'x' as input.
105 | computation_steps_mult: The amount of computation tokens to append to the
106 |       input tape. This variable is a multiplier and the actual number of cells is
107 | computation_steps_mult * input_length.
108 | """
109 |
110 | def new_model(x: chex.Array, y: chex.Array) -> chex.Array:
111 | """Returns an output from the inputs and targets.
112 |
113 | Args:
114 | x: One-hot input vectors, shape (B, T, input_size).
115 | y: One-hot target output vectors, shape (B, T, output_size).
116 | """
117 | batch_size, input_length, input_size = x.shape
118 | _, output_length, output_size = y.shape
119 | extra_dims_onehot = 1 + int(computation_steps_mult > 0)
120 | final_input_size = max(input_size, output_size) + extra_dims_onehot
121 |
122 | # Add trailing zeros to account for new final_input_size.
123 | extra_zeros_x = jnp.zeros(
124 | (batch_size, input_length, final_input_size - input_size)
125 | )
126 | x = jnp.concatenate([x, extra_zeros_x], axis=-1)
127 | extra_zeros_y = jnp.zeros(
128 | (batch_size, output_length, final_input_size - output_size)
129 | )
130 | y = jnp.concatenate([y, extra_zeros_y], axis=-1)
131 |
132 | computation_tape = jnp.full(
133 | (batch_size, computation_steps_mult * input_length),
134 | fill_value=input_size + COMPUTATION_EMPTY_TOKEN,
135 | )
136 | computation_tape = jnn.one_hot(
137 | computation_tape, num_classes=final_input_size
138 | )
139 |
140 | output_token = jnp.full(
141 | (batch_size, 1),
142 | fill_value=input_size
143 | + OUTPUT_EMPTY_TOKEN
144 | - int(computation_steps_mult == 0),
145 | )
146 | output_token = jnn.one_hot(output_token, num_classes=final_input_size)
147 | final_input = jnp.concatenate(
148 | [x, computation_tape, output_token, y], axis=1
149 | )
150 |
151 | if 'input_length' in inspect.getfullargspec(model).args:
152 | output = model(final_input, input_length=input_length) # pytype: disable=wrong-keyword-args
153 | else:
154 | output = model(final_input)
155 |
156 | return output[:, -output_length - 1 : -1]
157 |
158 | return new_model
159 |
160 |
161 | def add_sampling_to_autoregressive_model(
162 | model: Callable[[chex.Array, chex.Array], chex.Array],
163 | single_output: bool = False,
164 | ) -> Callable[[chex.Array, chex.Array, bool], chex.Array]:
165 | """Adds a 'sample' argument to the model, to use autoregressive sampling."""
166 |
167 | def new_model_with_sampling(
168 | x: chex.Array,
169 | y: chex.Array,
170 | sample: bool,
171 | ) -> chex.Array:
172 | """Returns an autoregressive model if `sample == True and output_size > 1`.
173 |
174 | Args:
175 | x: The input sequences of shape (b, t, i), where i is the input size.
176 | y: The target sequences of shape (b, t, o), where o is the output size.
177 | sample: Whether to evaluate the model using autoregressive decoding.
178 | """
179 | output_length = 1 if len(y.shape) == 2 else y.shape[1]
180 | output_size = y.shape[-1]
181 |
182 | if not sample or output_length == 1:
183 | output = model(x, y)
184 |
185 | else:
186 |
187 | def evaluate_model_autoregressively(
188 | idx: int,
189 | predictions: chex.Array,
190 | ) -> chex.Array:
191 | """Iteratively evaluates the model based on the previous predictions.
192 |
193 | Args:
194 | idx: The index of the target sequence that should be evaluated.
195 | predictions: The logits for the predictions up to but not including
196 | the index `idx`.
197 |
198 | Returns:
199 | The `predictions` array modified only at position `idx` where the
200 | logits for index `idx` have been inserted.
201 | """
202 | one_hot_predictions = jnn.one_hot(
203 | jnp.argmax(predictions, axis=-1),
204 | num_classes=output_size,
205 | )
206 | logits = model(x, one_hot_predictions)
207 | return predictions.at[:, idx].set(logits[:, idx])
208 |
209 | output = hk.fori_loop(
210 | lower=0,
211 | upper=output_length,
212 | body_fun=evaluate_model_autoregressively,
213 | init_val=jnp.empty_like(y),
214 | )
215 |
216 | if single_output:
217 | output = jnp.squeeze(output, axis=1)
218 | return output
219 |
220 | return new_model_with_sampling
221 |
222 |
223 | def update_tree_with_new_containers(
224 | tree: Any, update_dict: dict[str, Any]
225 | ) -> None:
226 | """Updates a dataclass tree in place, adding new containers.
227 |
228 |   This method is useful for the nested library to add fields to a tree for
229 |   which the containers have not yet been created.
230 | For instance, if A is a dataclass with attribute architecture_params, and we
231 | want to add the value architecture_params.rnn_model.size, we need to create
232 | the container 'rnn_model' inside architecture_params.
233 |
234 | Args:
235 |     tree: An object with attributes (typically a dataclass).
236 | update_dict: A dict of nested updates. See example above.
237 | """
238 | for key in update_dict:
239 | subkeys = key.split('.')
240 | if len(subkeys) >= 2:
241 | # Example: architecture.params.size
242 | for i in range(0, len(subkeys) - 2):
243 | getattr(tree, subkeys[i])[subkeys[i + 1]] = {}
244 |
--------------------------------------------------------------------------------
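A minimal sketch of the input layout built by `make_model_with_targets_as_input`
above (illustrative only, not part of the repository): the token indices stand in
for COMPUTATION_EMPTY_TOKEN and OUTPUT_EMPTY_TOKEN, and the wrapped model is
replaced by a shape check.

import jax.nn as jnn
import jax.numpy as jnp

# Illustrative sizes.
batch_size, input_length, output_length = 2, 4, 3
input_size, output_size, computation_steps_mult = 5, 5, 1
final_input_size = max(input_size, output_size) + 2  # Two extra one-hot dims.

x = jnn.one_hot(jnp.zeros((batch_size, input_length), dtype=jnp.int32),
                final_input_size)
y = jnn.one_hot(jnp.ones((batch_size, output_length), dtype=jnp.int32),
                final_input_size)
# Stand-in indices for the computation and output tokens.
computation_tape = jnn.one_hot(
    jnp.full((batch_size, computation_steps_mult * input_length), input_size),
    final_input_size)
output_token = jnn.one_hot(jnp.full((batch_size, 1), input_size + 1),
                           final_input_size)

# [inputs, computation_tokens, output_token, targets], as in the docstring.
final_input = jnp.concatenate([x, computation_tape, output_token, y], axis=1)
assert final_input.shape == (
    batch_size,
    input_length * (1 + computation_steps_mult) + 1 + output_length,
    final_input_size)
# The logits aligned with the targets are output[:, -output_length - 1:-1]:
# position t predicts target t + 1, hence the one-step shift.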
/models/ndstack_rnn.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Non-deterministic Stack RNN core.
17 |
18 | Following the paper from DuSell et al (2020):
19 | https://arxiv.org/abs/2010.04674
20 |
21 | The idea is to add a non-deterministic stack extension to a recurrent neural
22 | network to be able to simulate a machine accepting non-deterministic
23 | context-free languages. It can be seen as an extension to the Stack-RNN
24 | developed by Joulin et al (2015). However, it is far more complex and hard to
25 | understand.
26 | The non-deterministic stack is completely differentiable.
27 |
28 | A non-deterministic pushdown automaton (NDPDA) uses 'multiple stacks at the same
29 | time'. The problem is that the number of possible stacks grows exponentially
30 | with time, which makes a naive practical implementation impossible. However,
31 | Lang et al proved in 1969, based on ideas from context-free parsers like
32 | CYK, that an NDPDA can be simulated using only O(n³) memory, and not O(2^n).
33 | The main idea is to reuse the content of the different stacks in a dynamic
34 | programming manner. A stack with n values is a stack with n-1 values + an extra
35 | value, so we can build a graph of possible stacks, which would reuse most of
36 | the data.
37 |
38 | Concretely, the graph is made of nodes (t, q, x) where t is a number (the time),
39 | q is a state from a fixed, user-defined set and x is a symbol or value, also
40 | from a finite, user-defined set. Then one path in this graph is exactly one
41 | stack, which can simply be reconstructed by reading the value for each node
42 | in the path. Each state q can be used as a 'branching' mechanism: the more
43 | states, the more branching there can be and therefore the more stacks can be
44 | used. The number of possible stacks is (#states * #symbols)^t.
45 |
46 | To interact with this graph, ie do a push or a pop action, one uses transitions
47 | on these nodes. For push, it is a function of the form (q1, x1) -> (q2, x2),
48 | where q2 is the new state to go in (ie whether to branch to a new stack, or keep
49 | the same) and x2 is the value to push. For pop, it is a function of the form
50 | (q1, x1) -> q2, which again allows the network to choose whether to create a new
51 | stack or not. No value is pushed in that case, however. The functions are
52 | modelled by transition matrices of shape (Q, S, Q, S) where Q=#states and
53 | S=#symbols.
54 | Once the action matrices are passed, the graph is updated. The update is done
55 | via an internal transition matrix called gamma. This matrix is simple for the
56 | push action (one can only push on the top of the stack, ie nodes for which t =
57 | current timestep). It is far more complex for the pop action, as popping a value
58 | from the current stack can completely change the structure of the graph: the
59 | new stack after popping might be equal to a very old stack seen at the beginning
60 | of the episode, and we must change the links accordingly. Roughly, the update
61 | operation for gamma has a time complexity of O(Q⁴ S³ n³).
62 | Finally, once the graph is updated via gamma, we update the probabilities of the
63 | top of stacks, which gives us a tensor called alpha. From alpha we deduce the
64 | average top of the stack to be sent to the agent.
65 |
66 | As there are 3 actions (push/pop/replace), unrolling this over
67 | long sequences and using big batch sizes consumes too much memory and the
68 | accelerators fail.
69 |
70 | Notations:
71 | Q: number of states of the ND stack (not the number of states of the
72 | RNN).
73 | S: number of symbols which can be pushed on the stack.
74 | T: Sequence length.
75 | B: Batch size.
76 | """
77 |
78 | from typing import Any, Mapping, NamedTuple, Optional
79 |
80 | import chex
81 | import haiku as hk
82 | import jax
83 | import jax.nn as jnn
84 | import jax.numpy as jnp
85 |
86 | _EPSILON = 0.001
87 |
88 |
89 | class NDStack(NamedTuple):
90 | """The non-deterministic stack.
91 |
92 | Note that alpha and top_stack depend on gamma.
93 | """
94 | gamma: chex.Array # Shape (B, T, T, Q, S, Q, S)
95 | alpha: chex.Array # Shape (B, T, Q, S)
96 | top_stack: chex.Array # Shape (B, S)
97 |
98 |
99 | def _update_stack(ndstack: NDStack,
100 | push_actions: chex.Array,
101 | pop_actions: chex.Array,
102 | replace_actions: chex.Array,
103 | timestep: int,
104 | read_states: bool = True) -> NDStack:
105 | """Returns an updated NDStack.
106 |
107 | Args:
108 | ndstack: See above. Contains the internals needed to simulate a
109 | non-deterministic stack.
110 | push_actions: A tensor of shape (B, Q, S, Q, S).
111 | pop_actions: A tensor of shape (B, Q, S, Q).
112 | replace_actions: A tensor of shape (B, Q, S, Q, S).
113 | timestep: The current timestep while processing the sequence.
114 |     read_states: Whether to also read the states of the NDPDA.
115 | """
116 | stack_size = ndstack.gamma.shape[2]
117 | mask = jnp.zeros((stack_size, stack_size))
118 | mask = mask.at[timestep - 1, timestep].set(1)
119 | new_push_gamma_t = jnp.einsum('bqxry,tT->btTqxry', push_actions,
120 | mask)[:, :, timestep]
121 |
122 | index_k = jnp.stack([jnp.arange(start=0, stop=stack_size)] * stack_size)
123 | index_i = jnp.transpose(index_k)
124 | timestep_arr = jnp.full((stack_size, stack_size), timestep)
125 | index_mask = jnp.logical_and(index_k > index_i, index_k < timestep_arr - 1)
126 | index_mask = jnp.einsum('tT,bqxry->btTqxry', index_mask,
127 | jnp.ones(push_actions.shape))
128 | new_pop_gamma_t = jnp.einsum(
129 | 'bikqxuy,bkuysz,bszr->biqxry',
130 | index_mask * ndstack.gamma,
131 | ndstack.gamma[:, :, timestep - 1],
132 | pop_actions,
133 | )
134 |
135 | new_replace_gamma_t = jnp.einsum('biqxsz,bszry->biqxry',
136 | ndstack.gamma[:, :,
137 | timestep - 1], replace_actions)
138 |
139 | new_gamma = jax.vmap(jax.vmap(lambda x, y: x.at[timestep].set(y)))(
140 | ndstack.gamma, new_replace_gamma_t + new_pop_gamma_t + new_push_gamma_t)
141 |
142 | alpha_t = jnp.einsum('biqx,biqxry->bry', ndstack.alpha, new_gamma[:, :,
143 | timestep])
144 | new_alpha = jax.vmap(lambda x, y: x.at[timestep].set(y))(ndstack.alpha,
145 | alpha_t)
146 |
147 | if read_states:
148 | batch_size, states, symbols = alpha_t.shape
149 | obs = jnp.reshape(alpha_t, (batch_size, states * symbols))
150 | else:
151 | obs = jnp.sum(alpha_t, axis=1)
152 |
153 | obs = obs / (jnp.sum(obs, axis=-1, keepdims=True) + _EPSILON)
154 | return NDStack(new_gamma, new_alpha, top_stack=obs)
155 |
156 |
157 | # First element is the NDStack, second is the current timestep, third is the
158 | # hidden internal state.
159 | _NDStackRnnState = tuple[NDStack, chex.Array, chex.Array]
160 |
161 |
162 | class NDStackRNNCore(hk.RNNCore):
163 | """Core for the non-deterministic stack RNN."""
164 |
165 | def __init__(
166 | self,
167 | stack_symbols: int,
168 | stack_states: int,
169 | stack_size: int = 30,
170 | inner_core: type[hk.RNNCore] = hk.VanillaRNN,
171 | read_states: bool = False,
172 | name: Optional[str] = None,
173 | **inner_core_kwargs: Mapping[str, Any]
174 | ):
175 | """Initializes.
176 |
177 | Args:
178 | stack_symbols: The number of symbols which can be used in the stack.
179 | stack_states: The number of states of the non-deterministic stack.
180 |         Corresponds to the amount of branching in the graph, ie roughly n_stacks
181 | = stack_states ^ t.
182 | stack_size: The total size of the stacks. Be careful when increasing this
183 | value since the computation is in O(stack_size ^ 3).
184 | inner_core: The inner RNN core builder.
185 |       read_states: Whether to read the states of the NDPDA or only the top of
186 |         the stack.
187 | name: See base class.
188 | **inner_core_kwargs: The arguments to be passed to the inner RNN core
189 | builder.
190 | """
191 | super().__init__(name=name)
192 | self._rnn_core = inner_core(**inner_core_kwargs)
193 | self._stack_symbols = stack_symbols
194 | self._stack_states = stack_states
195 | self._stack_size = stack_size
196 | self._read_states = read_states
197 |
198 | def __call__(
199 | self, inputs: chex.Array, prev_state: _NDStackRnnState
200 | ) -> tuple[chex.Array, _NDStackRnnState]:
201 | """Steps the non-deterministic stack RNN core.
202 |
203 | See base class docstring.
204 |
205 | Args:
206 | inputs: An input array of shape (batch_size, input_size). The time
207 | dimension is not included since it is an RNNCore, which is unrolled over
208 | the time dimension.
209 | prev_state: A _NDStackRnnState tuple, consisting of the previous nd-stack,
210 | the previous timestep and the previous state of the inner core.
211 |
212 | Returns:
213 | - output: An output array of shape (batch_size, output_size).
214 | - next_state: Same format as prev_state.
215 | """
216 | ndstack, timestep, old_core_state = prev_state
217 |
218 | # The network can always read the top of the stack.
219 | batch_size = ndstack.gamma.shape[0]
220 | inputs = jnp.concatenate([inputs, ndstack.top_stack], axis=-1)
221 | new_core_output, new_core_state = self._rnn_core(inputs, old_core_state)
222 |
223 | n_push_actions = (self._stack_states * self._stack_symbols)**2
224 | n_pop_actions = self._stack_states**2 * self._stack_symbols
225 | n_replace_actions = (self._stack_states * self._stack_symbols)**2
226 | actions = hk.Linear(n_push_actions + n_pop_actions + n_replace_actions)(
227 | new_core_output)
228 | actions = jnn.softmax(actions, axis=-1)
229 |
230 | push_actions = jnp.reshape(
231 | actions[:, :n_push_actions],
232 | (batch_size, self._stack_states, self._stack_symbols,
233 | self._stack_states, self._stack_symbols))
234 |
235 | pop_actions = jnp.reshape(
236 | actions[:, n_push_actions:n_push_actions + n_pop_actions],
237 | (batch_size, self._stack_states, self._stack_symbols,
238 | self._stack_states))
239 |
240 | replace_actions = jnp.reshape(
241 | actions[:, -n_replace_actions:],
242 | (batch_size, self._stack_states, self._stack_symbols,
243 | self._stack_states, self._stack_symbols))
244 |
245 | new_ndstack = _update_stack(
246 | ndstack,
247 | push_actions,
248 | pop_actions,
249 | replace_actions, (timestep + 1)[0],
250 | read_states=self._read_states)
251 | return new_core_output, (new_ndstack, timestep + 1, new_core_state)
252 |
253 | def initial_state(self, batch_size: Optional[int]) -> _NDStackRnnState:
254 | """Returns the initial state of the core, a hidden state and an empty stack."""
255 | core_state = self._rnn_core.initial_state(batch_size)
256 |
257 | # Gamma, the transition matrix, is initialized to full zeros: there is no
258 | # connection in the graph at the beginning.
259 | gamma = jnp.zeros(
260 | (batch_size, self._stack_size, self._stack_size, self._stack_states,
261 | self._stack_symbols, self._stack_states, self._stack_symbols))
262 |
263 | # Alpha is zero everywhere except for the first node, which is (0, q0, 0).
264 | alpha = jnp.zeros(
265 | (batch_size, self._stack_size, self._stack_states, self._stack_symbols))
266 | alpha = jax.vmap(lambda x: x.at[0, 0, 0].set(1))(alpha)
267 |
268 | if self._read_states:
269 | top_stack = jnp.zeros(
270 | (batch_size, self._stack_states * self._stack_symbols))
271 | else:
272 | # The top of the stack is 0 as the first node contains the symbol 0.
273 | top_stack = jnp.zeros((batch_size, self._stack_symbols))
274 |
275 | ndstack = NDStack(gamma, alpha, top_stack)
276 | return ndstack, jnp.zeros((batch_size,), dtype=jnp.int32), core_state
277 |
--------------------------------------------------------------------------------
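A minimal usage sketch for NDStackRNNCore (not part of the repository); the
hyperparameters are illustrative and `hidden_size` is forwarded to the inner
hk.VanillaRNN through **inner_core_kwargs.

import haiku as hk
import jax
import jax.numpy as jnp

from neural_networks_chomsky_hierarchy.models import ndstack_rnn


def forward(x: jnp.ndarray) -> jnp.ndarray:
  # Small stack to keep the O(stack_size^3) gamma update cheap.
  core = ndstack_rnn.NDStackRNNCore(
      stack_symbols=2, stack_states=2, stack_size=16, hidden_size=32)
  initial_state = core.initial_state(batch_size=x.shape[0])
  outputs, _ = hk.dynamic_unroll(core, x, initial_state, time_major=False)
  return hk.Linear(2)(outputs[:, -1])  # Classify from the last output.


model = hk.transform(forward)
x = jnp.zeros((4, 8, 3))  # (batch_size, sequence_length, input_size).
params = model.init(jax.random.PRNGKey(0), x)
logits = model.apply(params, jax.random.PRNGKey(1), x)  # Shape (4, 2).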
/models/positional_encodings.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Positional encodings, used in `transformer.py`."""
17 |
18 | import enum
19 | import math
20 | from typing import Any
21 |
22 | import chex
23 | import haiku as hk
24 | import jax
25 | import jax.numpy as jnp
26 | import numpy as np
27 |
28 |
29 | class PositionalEncodings(enum.Enum):
30 | """Enum for all the positional encodings implemented."""
31 | NONE = 0
32 | SIN_COS = 1
33 | ALIBI = 2
34 | RELATIVE = 3
35 | ROTARY = 4
36 |
37 |
38 | # General type used throughout the module for pos enc parameters.
39 | PositionalEncodingsParams = Any
40 |
41 |
42 | @chex.dataclass
43 | class SinCosParams:
44 | """Parameters for the classical sin/cos positional encoding."""
45 | # The maximum wavelength used.
46 | max_time: int = 10_000
47 |
48 |
49 | # We will use this same class for Rotary and Relative.
50 | RotaryParams = SinCosParams
51 | RelativeParams = SinCosParams
52 |
53 |
54 | POS_ENC_TABLE = {
55 | 'NONE': PositionalEncodings.NONE,
56 | 'SIN_COS': PositionalEncodings.SIN_COS,
57 | 'ALIBI': PositionalEncodings.ALIBI,
58 | 'RELATIVE': PositionalEncodings.RELATIVE,
59 | 'ROTARY': PositionalEncodings.ROTARY,
60 | }
61 |
62 | POS_ENC_PARAMS_TABLE = {
63 | 'NONE': SinCosParams,
64 | 'SIN_COS': SinCosParams,
65 | 'ALIBI': SinCosParams,
66 | 'RELATIVE': RelativeParams,
67 | 'ROTARY': RotaryParams,
68 | }
69 |
70 |
71 | def sinusoid_position_encoding(
72 | sequence_length: int,
73 | hidden_size: int,
74 | memory_length: int = 0,
75 | max_timescale: float = 1e4,
76 | min_timescale: float = 2.0,
77 | clamp_length: int = 0,
78 | causal: bool = False,
79 | ):
80 | """Creates sinusoidal encodings.
81 |
82 | The time dimension is larger than sequence_length as we need to cover all
83 | cases of looking in either the future or past.
84 |
85 | Args:
86 | sequence_length: `int` sequence length, L
87 | hidden_size: `int` dimension of the positional encoding vectors, D
88 | memory_length: `int` size of the memory, M
89 |     max_timescale: `float` maximum timescale for the frequency
90 |     min_timescale: `float` minimum timescale for the frequency
91 | clamp_length: If greater than 0, any positions further apart than
92 | `clamp_length` are clamped to this value
93 | causal: If true then generates a smaller set (L vs 2 * L) of time-encodings
94 | for the use-case of causal attention.
95 |
96 | Returns:
97 | An array of shape [L + M, D] for causal and [2 * L + M, D] otherwise.
98 | """
99 | freqs = np.arange(0, hidden_size, min_timescale)
100 | inv_freq = max_timescale ** (-freqs / hidden_size)
101 | # Since inputs can look into the past and into the future, depending on the
102 | # permutation mask, we need to have relative encodings for both. The furthest
103 | # back an input can see is the final token, up to sequence_length +
104 | # memory_length - 1. The furthest ahead an input can see is for token 0 where
105 | # it can see up to sequence_length - 1 future tokens.
106 | if causal:
107 | pos_seq = np.arange(sequence_length + memory_length, 0, -1.0)
108 | else:
109 | pos_seq = np.arange(sequence_length + memory_length, -sequence_length, -1.0)
110 | if clamp_length:
111 | pos_seq = np.clip(pos_seq, a_min=-clamp_length, a_max=clamp_length)
112 | sinusoid_inp = np.einsum('i,j->ij', pos_seq, inv_freq)
113 | pos_emb = np.concatenate(
114 | [np.sin(sinusoid_inp), np.cos(sinusoid_inp)], axis=-1
115 | )
116 | return pos_emb
117 |
118 |
119 | def _rel_shift_inner(logits: chex.Array, attention_length: int) -> chex.Array:
120 | """Shifts the relative logits.
121 |
122 |   This is more general than the original Transformer-XL implementation, as
123 | inputs may also see the future. (The implementation does not rely on a
124 | causal mask removing the upper-right triangle.)
125 |
126 | Given attention length 3 and inputs:
127 | [[-3, -2, -1, 0, 1, 2],
128 | [-3, -2, -1, 0, 1, 2],
129 | [-3, -2, -1, 0, 1, 2]]
130 |
131 | The shifted output is:
132 | [[0, 1, 2],
133 | [-1, 0, 1],
134 | [-2, -1, 0]]
135 |
136 | Args:
137 | logits: input tensor of shape [T_q, T_v + T_q]
138 | attention_length: T_v `int` length of the attention, should be equal to
139 | memory size + sequence length.
140 |
141 | Returns:
142 | A shifted version of the input of size [T_q, T_v]. In each row, a window of
143 | size T_v elements is kept. The window starts at
144 | the rightmost end, for the first row. It then shifts left by 1 for each
145 | subsequent row.
146 | """
147 | if logits.ndim != 2:
148 | raise ValueError('`logits` needs to be an array of dimension 2.')
149 | tq, total_len = logits.shape
150 | assert total_len == tq + attention_length
151 | logits = jnp.reshape(logits, [total_len, tq])
152 | logits = jax.lax.slice(logits, (1, 0), logits.shape) # logits[1:]
153 | logits = jnp.reshape(logits, [tq, total_len - 1])
154 | # Equiv to logits[:, :attention_length].
155 | logits = jax.lax.slice(logits, (0, 0), (tq, attention_length))
156 | return logits
157 |
158 |
159 | def _rel_shift_causal(logits: chex.Array) -> chex.Array:
160 | """Shifts the relative logits, assuming causal attention.
161 |
162 | Given inputs:
163 | [[-4, -3, -2, -1],
164 | [-4, -3, -2, -1]]
165 |
166 | The shifted (and, later, masked) output is:
167 | [[-3, -2, -1, 0],
168 | [-4, -3, -2, -1]]
169 |
170 | Args:
171 | logits: input tensor of shape [T_q, T_v]
172 |
173 | Returns:
174 | A shifted version of the input of size [T_q, T_v].
175 | """
176 | t1, t2 = logits.shape
177 | # We prepend zeros on the final timescale dimension.
178 | to_pad = jnp.zeros_like(logits[..., :1])
179 | x = jnp.concatenate((to_pad, logits), axis=-1)
180 |
181 | # Reshape trick to shift input.
182 | x = jnp.reshape(x, [t2 + 1, t1])
183 |
184 | # Remove extra time dimension and re-shape.
185 | x = jax.lax.slice(x, [1] + [0] * (x.ndim - 1), x.shape)
186 |
187 | return jnp.reshape(x, [t1, t2])
188 |
189 |
190 | def relative_shift(
191 | logits: chex.Array, attention_length: int, causal: bool = False
192 | ) -> chex.Array:
193 | if causal:
194 | fn = _rel_shift_causal
195 | else:
196 | fn = lambda t: _rel_shift_inner(t, attention_length)
197 | return jax.vmap(jax.vmap(fn))(logits)
198 |
199 |
200 | def apply_rotary_encoding(
201 | x: jnp.ndarray, position: jnp.ndarray, max_time: int = 10_000
202 | ) -> jnp.ndarray:
203 | """Applies RoPE positional encodings for the input.
204 |
205 | Paper: https://arxiv.org/abs/2104.09864
206 |
207 | Args:
208 | x: The input tensor on which RoPE will be applied. Usually it is either some
209 | queries q or some keys k.
210 | position: The positions to use. Usually it's an arange of the maximum
211 | length.
212 | max_time: Constant used to scale position by in the encodings.
213 |
214 | Returns:
215 | A tensor with the same shape as x.
216 | """
217 | # Expand dims for positions to support inputs of shapes BTC or BTHC.
218 | freq_seq = jnp.arange(x.shape[-1] // 2, dtype=jnp.float32)
219 | freq_seq = freq_seq / (x.shape[-1] // 2)
220 | inv_freq = max_time**-freq_seq
221 | inv_freq = jnp.repeat(inv_freq, 2, 0)
222 | # Produce position inputs to periodic functions.
223 | t = position[:, :, None, None] * inv_freq[None, None, None, :]
224 | x_rot = jnp.einsum('bthd,dD->bthD', x, _rope_kernel(x.shape[-1], x.dtype))
225 | return x * jnp.cos(t).astype(x.dtype) + jnp.sin(t).astype(x.dtype) * x_rot
226 |
227 |
228 | def _rope_kernel(n: int, dtype: Any) -> np.ndarray:
229 | """Reorders the embedding dimension of an array, to make rotation easier."""
230 | # We implement the equivalent of
231 | # even_dims, odd_dims, = x[..., ::2], x[..., 1::2]
232 | # return jnp.stack((-odd_dims, even_dims), axis=-1).reshape(x.shape)
233 | # with a custom kernel for einsum. This allows the computation to execute
234 | # on the MXU instead of producing a slow gather.
235 | assert n % 2 == 0, n
236 | kernel = np.zeros((n, n), dtype)
237 | for i in range(n):
238 | # Swap each neighbouring pair of values.
239 | if i % 2 == 0:
240 | kernel[i, i + 1] = 1
241 | else:
242 | kernel[i, i - 1] = -1
243 | return kernel
244 |
245 |
246 | def compute_attention_with_relative_encodings(
247 | queries: chex.Array,
248 | keys: chex.Array,
249 | max_time: int = 10_000,
250 | causal: bool = False) -> chex.Array:
251 | """Returns attention with relative positional encodings.
252 |
253 | This code strictly follows what is described in the TransformerXL paper.
254 | https://arxiv.org/pdf/1901.02860.pdf
255 |
256 | Args:
257 | queries: The queries used for attention. Shape (b, t, h, d).
258 | keys: The keys used for attention. Shape (b, T, h, d).
259 | max_time: Constant used to scale position by in the sin/cos encodings.
260 | causal: Whether to use causal attention when shifting the relative logits.
261 |
262 | Returns:
263 | The attention logits. Shape (b, h, t, T).
264 | """
265 | batch_size, k_seq_len, num_heads, num_hiddens = keys.shape
266 | hiddens = num_hiddens * num_heads
267 |
268 | # First compute the content logits.
269 | content_bias = hk.get_parameter(
270 | name='relpos_contentbias',
271 | shape=[num_heads, num_hiddens],
272 | init=hk.initializers.RandomNormal(stddev=0.02))
273 | content_logits = jnp.einsum('bthd,bThd->bhtT', queries + content_bias, keys)
274 |
275 | positional_encodings = sinusoid_position_encoding(
276 | sequence_length=k_seq_len,
277 | hidden_size=hiddens,
278 | memory_length=0,
279 | max_timescale=max_time,
280 | min_timescale=2,
281 | clamp_length=0,
282 | causal=causal,
283 | )
284 | positional_encodings = jnp.broadcast_to(positional_encodings, (batch_size,) +
285 | positional_encodings.shape)
286 | relative_keys = hk.Linear(hiddens, with_bias=False)(positional_encodings)
287 | relative_keys = jnp.reshape(
288 | relative_keys, positional_encodings.shape[:-1] + (num_heads, num_hiddens))
289 |
290 | # Then compute the relative part.
291 | relative_bias = hk.get_parameter(
292 | name='relpos_relativebias',
293 | shape=[num_heads, num_hiddens],
294 | init=hk.initializers.RandomNormal(stddev=0.02))
295 | relative_logits = jnp.einsum('bthd,bThd->bhtT', queries + relative_bias,
296 | relative_keys)
297 | # We shift the relative logits instead of the positional encoding matrix as
298 | # described in Appendix B of the paper (https://arxiv.org/pdf/1901.02860.pdf).
299 | relative_logits = relative_shift(
300 | relative_logits, attention_length=content_logits.shape[-1], causal=causal
301 | )
302 | assert content_logits.shape == relative_logits.shape
303 | return content_logits + relative_logits
304 |
305 |
306 | def _get_alibi_slopes(num_heads: int) -> list[float]:
307 | """Returns the slopes for the different attention heads.
308 |
309 | While this does not exactly match the description of the [ALiBi
310 | paper](https://arxiv.org/pdf/2108.12409.pdf), it corresponds to the [official
311 | implementation](https://github.com/ofirpress/attention_with_linear_biases/blob/a06526fbfe557f9148e414b8569dcb97c7b182ba/fairseq/models/transformer.py#L742).
312 |
313 | Args:
314 | num_heads: The number of attention heads to create slopes for.
315 | """
316 |
317 | def get_slopes_power_of_2(n):
318 | start = (2**(-2**-(math.log2(n) - 3)))
319 | ratio = start
320 | return [start * ratio**i for i in range(n)]
321 |
322 | if math.log2(num_heads).is_integer():
323 | return get_slopes_power_of_2(num_heads)
324 | else:
325 | closest_power_of_2 = 2**math.floor(math.log2(num_heads))
326 | return (get_slopes_power_of_2(closest_power_of_2) + _get_alibi_slopes(
327 | 2 * closest_power_of_2)[0::2][:num_heads - closest_power_of_2])
328 |
329 |
330 | def compute_alibi_encodings_biases(
331 | attention_shape: tuple[int, ...]
332 | ) -> chex.Array:
333 | """Returns the biases following the ALiBi method.
334 |
335 | This code strictly follows what is described in the ALiBi paper.
336 | https://arxiv.org/pdf/2108.12409.pdf
337 |
338 | Args:
339 | attention_shape: The attention logits shape, without batch size, (h, t, T).
340 |
341 | Returns:
342 | The alibi biases, same shape as the input logits shape.
343 | """
344 | num_heads, q_seq_len, k_seq_len = attention_shape
345 |
346 | # Since we do not use causal masking, the upper triangle of the matrix has to
347 | # be nonzero. Therefore, we set it equal to the lower triangle, but we also
348 | # add a constant factor of 0.5 to the lower triangle, to (arbitrarily) break
349 | # the symmetry (otherwise, the model cannot distinguish left and right).
350 | alibi = np.zeros((q_seq_len, k_seq_len))
351 | alibi -= sum(np.tri(*alibi.shape, k=-i) for i in range(1, q_seq_len))
352 | alibi -= sum(np.tri(*alibi.T.shape, k=-i).T for i in range(1, k_seq_len))
353 | alibi += 0.5 * np.tri(*alibi.shape, k=-1)
354 |
355 | return alibi * jnp.array(_get_alibi_slopes(num_heads))[:, None, None]
356 |
--------------------------------------------------------------------------------
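The relative-shift logic above can be sanity-checked directly with the example
from the `_rel_shift_inner` docstring (a sketch, not part of the repository; it
calls the module's internal helper).

import jax.numpy as jnp

from neural_networks_chomsky_hierarchy.models import positional_encodings

# Attention length 3, three identical query rows.
logits = jnp.array([[-3., -2., -1., 0., 1., 2.]] * 3)
shifted = positional_encodings._rel_shift_inner(logits, attention_length=3)
# Each row keeps a window of 3 values, starting at the rightmost end and
# shifting left by one per row: [[0, 1, 2], [-1, 0, 1], [-2, -1, 0]].
print(shifted)

# The sinusoidal table has shape [L + M, D] when causal, [2L + M, D] otherwise.
pos = positional_encodings.sinusoid_position_encoding(
    sequence_length=4, hidden_size=8, causal=True)
assert pos.shape == (4, 8)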
/models/rnn.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Builders for RNN/LSTM cores."""
17 |
18 | from typing import Any, Callable
19 |
20 | import haiku as hk
21 | import jax.nn as jnn
22 | import jax.numpy as jnp
23 |
24 | from neural_networks_chomsky_hierarchy.models import tape_rnn
25 |
26 |
27 | def make_rnn(
28 | output_size: int,
29 | rnn_core: type[hk.RNNCore],
30 | return_all_outputs: bool = False,
31 | input_window: int = 1,
32 | **rnn_kwargs: Any
33 | ) -> Callable[[jnp.ndarray], jnp.ndarray]:
34 | """Returns an RNN model, not haiku transformed.
35 |
36 |   By default, only the last output in the sequence is returned. A linear
37 |   layer is added to match the required output_size.
38 |
39 | Args:
40 | output_size: The output size of the model.
41 |     rnn_core: The haiku RNN core to use, e.g. hk.LSTM.
42 | return_all_outputs: Whether to return the whole sequence of outputs of the
43 | RNN, or just the last one.
44 | input_window: The number of tokens that are fed at once to the RNN.
45 | **rnn_kwargs: Kwargs to be passed to the RNN core.
46 | """
47 |
48 | def rnn_model(x: jnp.ndarray, input_length: int = 1) -> jnp.ndarray:
49 | core = rnn_core(**rnn_kwargs)
50 | if issubclass(rnn_core, tape_rnn.TapeRNNCore):
51 | initial_state = core.initial_state(x.shape[0], input_length) # pytype: disable=wrong-arg-count
52 | else:
53 | initial_state = core.initial_state(x.shape[0])
54 |
55 | batch_size, seq_length, embed_size = x.shape
56 | if seq_length % input_window != 0:
57 | x = jnp.pad(x, ((0, 0), (0, input_window - seq_length % input_window),
58 | (0, 0)))
59 | new_seq_length = x.shape[1]
60 | x = jnp.reshape(
61 | x,
62 | (batch_size, new_seq_length // input_window, input_window, embed_size))
63 |
64 | x = hk.Flatten(preserve_dims=2)(x)
65 |
66 | output, _ = hk.dynamic_unroll(
67 | core, x, initial_state, time_major=False, return_all_states=True)
68 | output = jnp.reshape(output, (batch_size, new_seq_length, output.shape[-1]))
69 |
70 | if not return_all_outputs:
71 |       output = output[:, -1, :]  # Shape (batch_size, core_output_size).
72 | output = jnn.relu(output)
73 | return hk.Linear(output_size)(output)
74 |
75 | return rnn_model
76 |
--------------------------------------------------------------------------------
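A minimal sketch of wrapping `make_rnn` with Haiku (not part of the repository);
the LSTM hidden size and array shapes are illustrative.

import haiku as hk
import jax
import jax.numpy as jnp

from neural_networks_chomsky_hierarchy.models import rnn

# `hidden_size` is forwarded to hk.LSTM via **rnn_kwargs.
model = hk.transform(
    rnn.make_rnn(output_size=2, rnn_core=hk.LSTM, hidden_size=64))

x = jnp.zeros((8, 10, 3))  # (batch_size, sequence_length, token_size).
params = model.init(jax.random.PRNGKey(0), x)
logits = model.apply(params, jax.random.PRNGKey(1), x)  # Shape (8, 2).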
/models/stack_rnn.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Stack RNN core.
17 |
18 | Following the paper from Joulin et al (2015):
19 | https://arxiv.org/abs/1503.01007
20 |
21 | The idea is to add a stack extension to a recurrent neural network to be able to
22 | simulate a machine accepting context-free languages.
23 | The stack is completely differentiable. The actions taken are probabilities
24 | only and therefore no RL is required. The stack and state update are just linear
25 | combinations of the last states and these probabilities.
26 | """
27 |
28 | from typing import Any, Mapping, Optional
29 |
30 | import einshape
31 | import haiku as hk
32 | import jax
33 | import jax.nn as jnn
34 | import jax.numpy as jnp
35 |
36 |
37 | # First element is the stacks, second is the hidden internal state.
38 | _StackRnnState = tuple[jnp.ndarray, jnp.ndarray]
39 |
40 | # Number of actions the stack-RNN can take, namely POP, PUSH and NO_OP.
41 | _NUM_ACTIONS = 3
42 |
43 |
44 | def _update_stack(stack: jnp.ndarray, actions: jnp.ndarray,
45 | push_value: jnp.ndarray) -> jnp.ndarray:
46 | """Updates the stack values.
47 |
48 | We update the stack in two steps.
49 | In the first step, we update the top of the stack, and essentially do:
50 | stack[0] = push_action * push_value
51 | + pop_action * stack[1]
52 | + noop_action * stack[0]
53 |
54 | Then, in the second step, we update the rest of the stack and we move the
55 | elements up and down, depending on the action executed:
56 | * If push_action were 1, then we'd be purely pushing a new element
57 | to the top of the stack, so we'd move all elements down by one.
58 | * Likewise, if pop_action were 1, we'd be purely taking an element
59 | off the top of the stack, so we'd move all elements up by one.
60 | * Finally, if noop_action were 1, we'd leave elements where they were.
61 | The update is therefore essentially:
62 | stack[i] = push_action * stack[i-1]
63 | + pop_action * stack[i+1]
64 | + noop_action * stack[i]
65 |
66 | Args:
67 | stack: The current stack, shape (batch_size, stack_size, stack_cell_size).
68 | actions: The array of probabilities of the actions, shape (batch_size, 3).
69 | push_value: The vector to push on the stack, if the push action probability
70 | is positive, shape (batch_size, stack_cell_size).
71 |
72 | Returns:
73 | The new stack, same shape as the input stack.
74 | """
75 | batch_size, stack_size, stack_cell_size = stack.shape
76 |
77 | # Tiling the actions to match the top of the stack.
78 | # Shape (batch_size, stack_cell_size, _NUM_ACTIONS)
79 | cell_tiled_stack_actions = einshape.jax_einshape(
80 | 'ba->bsa', actions, s=stack_cell_size)
81 | push_action = cell_tiled_stack_actions[..., 0]
82 | pop_action = cell_tiled_stack_actions[..., 1]
83 | pop_value = stack[..., 1, :]
84 | no_op_action = cell_tiled_stack_actions[..., 2]
85 | no_op_value = stack[..., 0, :]
86 |
87 | # Shape (batch_size, 1, stack_cell_size)
88 | top_new_stack = (
89 | push_action * push_value + pop_action * pop_value +
90 | no_op_action * no_op_value)
91 | top_new_stack = jnp.expand_dims(top_new_stack, axis=1)
92 |
93 | # Tiling the actions to match all of the stack except the top.
94 |   # Shape (batch_size, stack_size - 1, stack_cell_size, _NUM_ACTIONS)
95 | stack_tiled_stack_actions = einshape.jax_einshape(
96 | 'ba->bcsa', actions, s=stack_cell_size, c=stack_size - 1)
97 | push_action = stack_tiled_stack_actions[..., 0]
98 | push_value = stack[..., :-1, :]
99 | pop_action = stack_tiled_stack_actions[..., 1]
100 | pop_extra_zeros = jnp.zeros((batch_size, 1, stack_cell_size))
101 | pop_value = jnp.concatenate([stack[..., 2:, :], pop_extra_zeros], axis=1)
102 | no_op_action = stack_tiled_stack_actions[..., 2]
103 | no_op_value = stack[..., 1:, :]
104 |
105 | # Shape (batch_size, stack_size-1, stack_cell_size)
106 | rest_new_stack = (
107 | push_action * push_value + pop_action * pop_value +
108 | no_op_action * no_op_value)
109 |
110 | # Finally concatenate the new top with the new rest of the stack.
111 | return jnp.concatenate([top_new_stack, rest_new_stack], axis=1)
112 |
113 |
114 | class StackRNNCore(hk.RNNCore):
115 | """Core for the stack RNN."""
116 |
117 | def __init__(
118 | self,
119 | stack_cell_size: int,
120 | stack_size: int = 30,
121 | n_stacks: int = 1,
122 | inner_core: type[hk.RNNCore] = hk.VanillaRNN,
123 | name: Optional[str] = None,
124 | **inner_core_kwargs: Mapping[str, Any]
125 | ):
126 | """Initializes.
127 |
128 | Args:
129 | stack_cell_size: The dimension of the vectors we put in the stack.
130 | stack_size: The total number of vectors we can stack.
131 | n_stacks: Number of stacks to use in the network.
132 | inner_core: The inner RNN core builder.
133 | name: See base class.
134 | **inner_core_kwargs: The arguments to be passed to the inner RNN core
135 | builder.
136 | """
137 | super().__init__(name=name)
138 | self._rnn_core = inner_core(**inner_core_kwargs)
139 | self._stack_cell_size = stack_cell_size
140 | self._stack_size = stack_size
141 | self._n_stacks = n_stacks
142 |
143 | def __call__(
144 | self, inputs: jnp.ndarray, prev_state: _StackRnnState
145 | ) -> tuple[jnp.ndarray, _StackRnnState]:
146 | """Steps the stack RNN core.
147 |
148 | See base class docstring.
149 |
150 | Args:
151 | inputs: An input array of shape (batch_size, input_size). The time
152 | dimension is not included since it is an RNNCore, which is unrolled over
153 | the time dimension.
154 | prev_state: A _StackRnnState tuple, consisting of the previous stacks and
155 | the previous state of the inner core. Each stack has shape (batch_size,
156 | stack_size, stack_cell_size), such that `stack[n][0]` represents the top
157 | of the stack for the nth batch item, and `stack[n][-1]` the bottom of
158 | the stack. The stacks are just the concatenation of all these tensors.
159 |
160 | Returns:
161 | - output: An output array of shape (batch_size, output_size).
162 | - next_state: Same format as prev_state.
163 | """
164 | stacks, old_core_state = prev_state
165 |
166 | # The network can always read the top of the stack.
167 | batch_size = stacks.shape[0]
168 | top_stacks = stacks[:, :, 0, :]
169 | top_stacks = jnp.reshape(
170 | top_stacks, (batch_size, self._n_stacks * self._stack_cell_size))
171 | inputs = jnp.concatenate([inputs, top_stacks], axis=-1)
172 | new_core_output, new_core_state = self._rnn_core(inputs, old_core_state)
173 | push_values = hk.Linear(self._n_stacks * self._stack_cell_size)(
174 | new_core_output)
175 | push_values = jnp.reshape(
176 | push_values, (batch_size, self._n_stacks, self._stack_cell_size))
177 |
178 |     # Shape (batch_size, n_stacks, _NUM_ACTIONS) after the reshape below.
179 | stack_actions = jnn.softmax(
180 | hk.Linear(self._n_stacks * _NUM_ACTIONS)(new_core_output), axis=-1)
181 | stack_actions = jnp.reshape(stack_actions,
182 | (batch_size, self._n_stacks, _NUM_ACTIONS))
183 |
184 | new_stacks = jax.vmap(
185 | _update_stack, in_axes=1, out_axes=1)(stacks, stack_actions,
186 | push_values)
187 | return new_core_output, (new_stacks, new_core_state)
188 |
189 | def initial_state(self, batch_size: Optional[int]) -> _StackRnnState:
190 | """Returns the initial state of the core, a hidden state and an empty stack."""
191 | core_state = self._rnn_core.initial_state(batch_size)
192 | stacks = jnp.zeros(
193 | (batch_size, self._n_stacks, self._stack_size, self._stack_cell_size))
194 | return stacks, core_state
195 |
--------------------------------------------------------------------------------
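A small check of the update rule described in `_update_stack` (a sketch, not
part of the repository): with a pure push action, the pushed value becomes the
new top and the old top moves down one slot.

import jax.numpy as jnp

from neural_networks_chomsky_hierarchy.models import stack_rnn

# One batch item, a stack of 4 cells of size 2; the top initially holds [1, 0].
stack = jnp.zeros((1, 4, 2)).at[0, 0].set(jnp.array([1., 0.]))
actions = jnp.array([[1., 0., 0.]])  # (push, pop, no-op) probabilities.
push_value = jnp.array([[0., 1.]])

new_stack = stack_rnn._update_stack(stack, actions, push_value)
# new_stack[0, 0] == [0, 1] (pushed value); new_stack[0, 1] == [1, 0] (old top).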
/models/tape_rnn.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Implements the Tape RNN."""
17 |
18 | import abc
19 | import functools
20 | from typing import Any, Optional, Sequence
21 |
22 | import chex
23 | import haiku as hk
24 | import jax
25 | from jax import nn as jnn
26 | from jax import numpy as jnp
27 |
28 | # The first element is the memory, the second is the hidden internal state, and
29 | # the third is the input length, necessary for adaptive actions.
30 | _TapeRNNState = tuple[chex.Array, chex.Array, chex.Array]
31 |
32 |
33 | class TapeRNNCore(hk.RNNCore, abc.ABC):
34 | """Core for the tape RNN."""
35 |
36 | def __init__(
37 | self,
38 | memory_cell_size: int,
39 | memory_size: int = 30,
40 | n_tapes: int = 1,
41 | mlp_layers_size: Sequence[int] = (64, 64),
42 | inner_core: type[hk.RNNCore] = hk.VanillaRNN,
43 | name: Optional[str] = None,
44 | **inner_core_kwargs: Any
45 | ):
46 | """Initializes.
47 |
48 | Args:
49 | memory_cell_size: The dimension of the vectors we put in memory.
50 | memory_size: The size of the tape, fixed value along the episode.
51 | n_tapes: Number of tapes to use. Default is 1.
52 | mlp_layers_size: Sizes for the inner MLP layers. Can be empty, in which
53 | case the MLP is a linear layer.
54 | inner_core: The inner RNN core builder.
55 | name: See base class.
56 | **inner_core_kwargs: The arguments to be passed to the inner RNN core
57 | builder.
58 | """
59 | super().__init__(name=name)
60 | self._rnn_core = inner_core(**inner_core_kwargs)
61 | self._mlp_layers_size = mlp_layers_size
62 | self._memory_cell_size = memory_cell_size
63 | self._memory_size = memory_size
64 | self._n_tapes = n_tapes
65 |
66 | @abc.abstractmethod
67 | def _tape_operations(
68 | self, eye_memory: chex.Array, input_length: int
69 | ) -> list[chex.Array]:
70 | """Returns a set of updated memory slots.
71 |
72 |     An eye matrix is passed and corresponds to the positions of the memory
73 |     slots. This method returns one matrix per action, giving the new positions
74 |     associated with that action. For instance, for a 'left' action, the matrix
75 |     is just roll(eye_memory, shift=-1). This is general enough to allow any
76 |     permutation of the indices.
77 |
78 | Args:
79 | eye_memory: An eye matrix of shape [memory_size, memory_size].
80 | input_length: The length of the input sequence. Can be useful for some
81 | operations.
82 | """
83 |
84 | @property
85 | @abc.abstractmethod
86 | def num_actions(self) -> int:
87 | """Returns the number of actions which can be taken on the tape."""
88 |
89 | def __call__(
90 | self, inputs: chex.Array, prev_state: _TapeRNNState
91 | ) -> tuple[chex.Array, _TapeRNNState]:
92 | """Steps the tape RNN core."""
93 | memories, old_core_state, input_length = prev_state
94 |
95 | # The network can always read the value of the current cell.
96 | batch_size = memories.shape[0]
97 | current_memories = memories[:, :, 0, :]
98 | current_memories = jnp.reshape(
99 | current_memories, (batch_size, self._n_tapes * self._memory_cell_size))
100 | inputs = jnp.concatenate([inputs, current_memories], axis=-1)
101 | new_core_output, new_core_state = self._rnn_core(inputs, old_core_state)
102 | readout_mlp = hk.nets.MLP(
103 | list(self._mlp_layers_size) + [self._n_tapes * self._memory_cell_size])
104 | write_values = readout_mlp(new_core_output)
105 | write_values = jnp.reshape(
106 | write_values, (batch_size, self._n_tapes, self._memory_cell_size))
107 |
108 |     # Shape (batch_size, n_tapes, num_actions) after stacking.
109 | actions = []
110 | for _ in range(self._n_tapes):
111 | actions.append(
112 | jnn.softmax(hk.Linear(self.num_actions)(new_core_output), axis=-1))
113 | actions = jnp.stack(actions, axis=1)
114 |
115 | update_memory = functools.partial(
116 | self._update_memory, input_length=input_length[0])
117 | new_memories = jax.vmap(
118 | update_memory, in_axes=1, out_axes=1)(memories, actions, write_values)
119 | return new_core_output, (new_memories, new_core_state, input_length)
120 |
121 | def initial_state(self, batch_size: Optional[int],
122 | input_length: int) -> _TapeRNNState: # pytype: disable=signature-mismatch
123 | """Returns the initial state of the core."""
124 | core_state = self._rnn_core.initial_state(batch_size)
125 | memories = jnp.zeros(
126 | (batch_size, self._n_tapes, self._memory_size, self._memory_cell_size))
127 | return memories, core_state, jnp.array([input_length])
128 |
129 | def _update_memory(self, memory: chex.Array, actions: chex.Array,
130 | write_values: chex.Array, input_length: int) -> chex.Array:
131 | """Computes the new memory based on the `actions` and `write_values`.
132 |
133 | Args:
134 | memory: The current memory with shape `[batch_size, memory_size,
135 | memory_cell_size]`.
136 | actions: The action probabilities with shape `[batch_size, num_actions]`.
137 | write_values: The values added to the first memory entry with shape
138 | `[batch_size, memory_cell_size]`.
139 | input_length: The length of the input.
140 |
141 | Returns:
142 |       The new memory with shape `[batch_size, memory_size, memory_cell_size]`.
143 | """
144 | _, memory_size, _ = memory.shape
145 |
146 | memory_with_write = jnp.concatenate(
147 | [jnp.expand_dims(write_values, axis=1), memory[:, 1:]], axis=1)
148 |
149 | eye_memory = jnp.eye(memory_size)
150 | operations = self._tape_operations(eye_memory, input_length)
151 | apply_operation = lambda x: jnp.einsum('mM,bMc->bmc', x, memory_with_write)
152 | memory_operations = jnp.stack(list(map(apply_operation, operations)))
153 | return jnp.einsum('Abmc,bA->bmc', memory_operations, actions)
154 |
155 |
156 | class TapeInputLengthJumpCore(TapeRNNCore):
157 | """A tape-RNN with extra jumps of the length of the input.
158 |
159 | 5 possible actions:
160 | - write and stay
161 | - write and move one cell left
162 | - write and move one cell right
163 | - write and move input_length cells left
164 | - write and move input_length cells right
165 | """
166 |
167 | @property
168 | def num_actions(self) -> int:
169 | """Returns the number of actions of the tape."""
170 | return 5
171 |
172 | def _tape_operations(
173 | self, eye_memory: chex.Array, input_length: int
174 | ) -> list[chex.Array]:
175 | write_stay = eye_memory
176 | write_left = jnp.roll(eye_memory, shift=-1, axis=0)
177 | write_right = jnp.roll(eye_memory, shift=1, axis=0)
178 | write_jump_left = jnp.roll(eye_memory, shift=-input_length, axis=0)
179 | write_jump_right = jnp.roll(eye_memory, shift=input_length, axis=0)
180 | return [
181 | write_stay, write_left, write_right, write_jump_left, write_jump_right
182 | ]
183 |
--------------------------------------------------------------------------------
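A minimal usage sketch for TapeInputLengthJumpCore (not part of the repository);
hyperparameters are illustrative, and `initial_state` takes the input length so
that the jump actions know how many cells to move.

import haiku as hk
import jax
import jax.numpy as jnp

from neural_networks_chomsky_hierarchy.models import tape_rnn


def forward(x: jnp.ndarray) -> jnp.ndarray:
  core = tape_rnn.TapeInputLengthJumpCore(
      memory_cell_size=8, memory_size=20, hidden_size=32)
  initial_state = core.initial_state(
      batch_size=x.shape[0], input_length=x.shape[1])
  outputs, _ = hk.dynamic_unroll(core, x, initial_state, time_major=False)
  return outputs  # Shape (batch_size, sequence_length, hidden_size).


model = hk.transform(forward)
x = jnp.zeros((2, 6, 4))  # (batch_size, sequence_length, input_size).
params = model.init(jax.random.PRNGKey(0), x)
out = model.apply(params, jax.random.PRNGKey(1), x)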
/models/transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Transformer model."""
17 |
18 | import dataclasses
19 | from typing import Callable, Optional
20 |
21 | import chex
22 | import haiku as hk
23 | import jax
24 | import jax.nn as jnn
25 | import jax.numpy as jnp
26 |
27 | from neural_networks_chomsky_hierarchy.models import positional_encodings as pos_encs_lib
28 |
29 |
30 | @chex.dataclass
31 | class TransformerConfig:
32 | """Hyperparameters used in the Transformer architectures."""
33 | # The size of the model output (i.e., the output vocabulary size).
34 | output_size: int
35 | # The dimension of the first embedding.
36 | embedding_dim: int = 64
37 | # The number of multi-head attention layers.
38 | num_layers: int = 5
39 | # The number of heads per layer.
40 | num_heads: int = 8
41 | # The number of hidden neurons per head. If None, it is set to be equal to
42 | # `embedding_dim // num_heads`.
43 | num_hiddens_per_head: Optional[int] = None
44 | # The probability that each element is discarded by the dropout modules.
45 | dropout_prob: float = 0.1
46 | # The parameter initialization scale for the embeddings.
47 | emb_init_scale: float = 0.02
48 | # Whether to use the embeddings rather than raw inputs.
49 | use_embeddings: bool = True
50 | # Whether to share embeddings between the Encoder and the Decoder.
51 | share_embeddings: bool = False
52 | # The size of the sliding attention window. See MultiHeadDotProductAttention.
53 | attention_window: Optional[int] = None
54 |   # The positional encoding to use; defaults to sin/cos (Vaswani et al., 2017).
55 | positional_encodings: pos_encs_lib.PositionalEncodings = dataclasses.field(
56 | default_factory=lambda: pos_encs_lib.PositionalEncodings.SIN_COS
57 | )
58 |   # The maximum size of the context (used by the positional encodings).
59 | max_time: int = 10_000
60 | # The parameters for the positional encodings, default sin/cos.
61 | positional_encodings_params: pos_encs_lib.PositionalEncodingsParams = (
62 | dataclasses.field(default_factory=pos_encs_lib.SinCosParams)
63 | )
64 | # How much larger the hidden layer of the feedforward network should be
65 | # compared to the `embedding_dim`.
66 | widening_factor: int = 4
67 | # Add mask to make causal predictions.
68 | causal_masking: bool = False
69 |
70 | def __post_init__(self) -> None:
71 | """Sets `num_hiddens_per_head` if it is `None`."""
72 | if self.num_hiddens_per_head is None:
73 | self.num_hiddens_per_head = self.embedding_dim // self.num_heads
74 |
75 |
76 | def layer_norm(x: chex.Array) -> chex.Array:
77 | return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x)
78 |
79 |
80 | def shift_right(x: chex.Array, output_size: int) -> chex.Array:
81 | """Right-shift the one-hot encoded input by padding on the temporal axis."""
82 | x = jnp.argmax(x, axis=-1)
83 |
84 | # Add a time dimension for the single-output case (i.e., `ndim == 1`).
85 | if x.ndim == 1:
86 | x = jnp.expand_dims(x, axis=1)
87 |
88 | padded = jnp.pad(
89 | x, ((0, 0), (1, 0)), mode='constant', constant_values=output_size)
90 |
91 | return jnn.one_hot(padded[:, :-1], num_classes=output_size + 1)
92 |
93 |
94 | def compute_sliding_window_mask(sequence_length: int,
95 | attention_window: int) -> chex.Array:
96 | """Returns a k-diagonal mask for a sliding window.
97 |
98 | Args:
99 | sequence_length: The length of the sequence, which will determine the shape
100 | of the output.
101 | attention_window: The size of the sliding window.
102 |
103 | Returns:
104 | A symmetric matrix of shape (sequence_length, sequence_length),
105 | attention_window-diagonal, with ones on the diagonal and on all the
106 | upper/lower diagonals up to attention_window // 2.
107 |
108 | Raises:
109 | ValueError if attention_window is <= 0.
110 | """
111 | if attention_window <= 0:
112 | raise ValueError(
113 | f'The attention window should be > 0. Got {attention_window}.')
114 |
115 | if attention_window == 1:
116 | return jnp.eye(sequence_length, sequence_length)
117 |
118 | attention_mask = jnp.sum(
119 | jnp.stack([
120 | jnp.eye(sequence_length, sequence_length, k=k, dtype=jnp.int32)
121 | for k in range(1, attention_window // 2 + 1)
122 | ]),
123 | axis=0)
124 | attention_mask = attention_mask + jnp.transpose(attention_mask)
125 | attention_mask += jnp.eye(sequence_length, sequence_length)
126 | return attention_mask
127 |
128 |
129 | class MultiHeadDotProductAttention(hk.Module):
130 | """Multi-head dot-product attention (Vaswani et al., 2017)."""
131 |
132 | def __init__(
133 | self,
134 | num_heads: int,
135 | num_hiddens_per_head: int,
136 | positional_encodings: pos_encs_lib.PositionalEncodings,
137 | positional_encodings_params: pos_encs_lib.PositionalEncodingsParams,
138 | attention_window: Optional[int] = None,
139 | name: Optional[str] = None,
140 | ) -> None:
141 | """Initializes the attention module.
142 |
143 | Args:
144 | num_heads: Number of heads to use.
145 | num_hiddens_per_head: Number of hidden neurons per head.
146 | positional_encodings: Which positional encodings to use in the attention.
147 | positional_encodings_params: Parameters for the positional encodings.
148 | attention_window: Size of the attention sliding window. None means no
149 | sliding window is used (or equivalently, window=full_attention_length).
150 |         We attend only to attention_window tokens around a given query token. We
151 | attend to tokens before AND after the query token. If attention_window
152 | is even, we use the value +1.
153 | name: Name of the module.
154 | """
155 | super().__init__(name=name)
156 | self._num_heads = num_heads
157 | self._num_hiddens_per_head = num_hiddens_per_head
158 | self._positional_encodings = positional_encodings
159 | self._attention_window = attention_window
160 | self._positional_encodings_params = positional_encodings_params
161 |
162 | def __call__(
163 | self,
164 | inputs_q: chex.Array,
165 | inputs_kv: chex.Array,
166 | mask: Optional[chex.Array] = None,
167 | causal: bool = False,
168 | ) -> chex.Array:
169 | """Returns the output of the multi-head attention."""
170 | batch_size, sequence_length, embedding_size = inputs_q.shape
171 |
172 | num_hiddens = self._num_hiddens_per_head * self._num_heads
173 | q = hk.Linear(num_hiddens, with_bias=False)(inputs_q)
174 | k = hk.Linear(num_hiddens, with_bias=False)(inputs_kv)
175 | v = hk.Linear(num_hiddens, with_bias=False)(inputs_kv)
176 | # The second (sequence) dimension is undefined since it can differ between
177 | # queries and keys/values when decoding.
178 | new_shape = (batch_size, -1, self._num_heads, self._num_hiddens_per_head)
179 | q = jnp.reshape(q, new_shape)
180 | k = jnp.reshape(k, new_shape)
181 | v = jnp.reshape(v, new_shape)
182 |
183 | # Let b=batch_size, t=seq_len, h=num_heads, and d=num_hiddens_per_head.
184 | if self._positional_encodings == pos_encs_lib.PositionalEncodings.RELATIVE:
185 | # We type hint the params to match the if statement, for pytype.
186 | self._positional_encodings_params: pos_encs_lib.RelativeParams
187 | attention = pos_encs_lib.compute_attention_with_relative_encodings(
188 | q, k, self._positional_encodings_params.max_time, causal=causal
189 | )
190 | else:
191 | if self._positional_encodings == pos_encs_lib.PositionalEncodings.ROTARY:
192 | q = pos_encs_lib.apply_rotary_encoding(
193 | q, position=jnp.arange(q.shape[1])[None, :]
194 | )
195 | k = pos_encs_lib.apply_rotary_encoding(
196 | k, position=jnp.arange(k.shape[1])[None, :]
197 | )
198 | attention = jnp.einsum('bthd,bThd->bhtT', q, k)
199 | attention *= 1.0 / jnp.sqrt(self._num_hiddens_per_head)
200 |
201 | # ALiBi encodings are not scaled with the 1 / sqrt(d_k) factor.
202 | if self._positional_encodings == pos_encs_lib.PositionalEncodings.ALIBI:
203 | attention += pos_encs_lib.compute_alibi_encodings_biases(
204 | attention.shape[1:]
205 | )
206 |
207 | if self._attention_window is not None:
208 | # We compute the sliding attention by just applying a mask on the values
209 | # that are outside our window.
210 | attention_mask = compute_sliding_window_mask(sequence_length,
211 | self._attention_window)
212 | attention = jnp.where(attention_mask, attention,
213 | jnp.finfo(jnp.float32).min)
214 |
215 | if mask is not None:
216 | attention = jnp.where(mask, attention, jnp.finfo(jnp.float32).min)
217 |
218 | normalized_attention = jnn.softmax(attention)
219 |
220 | output = jnp.einsum('bhtT,bThd->bthd', normalized_attention, v)
221 | output = jnp.reshape(output, (batch_size, sequence_length, num_hiddens))
222 | return hk.Linear(embedding_size, with_bias=False)(output)
223 |
224 |
225 | class TransformerEncoder(hk.Module):
226 | """Transformer Encoder (Vaswani et al., 2017)."""
227 |
228 | def __init__(
229 | self,
230 | config: TransformerConfig,
231 | shared_embeddings_fn: Optional[Callable[[chex.Array], chex.Array]] = None,
232 | name: Optional[str] = None,
233 | ) -> None:
234 | """Initializes the transformer encoder.
235 |
236 | Args:
237 | config: The hyperparameters used in Transformer architectures.
238 | shared_embeddings_fn: Embedding function that is shared with the decoder.
239 | name: The name of the module.
240 | """
241 | super().__init__(name=name)
242 | self._config = config
243 | self._shared_embeddings_fn = shared_embeddings_fn
244 |
245 | def __call__(self, x: jnp.ndarray) -> chex.Array:
246 | """Returns the transformer encoder output, shape [B, T, E]."""
247 | if self._config.use_embeddings:
248 | if self._shared_embeddings_fn is not None:
249 | embeddings = self._shared_embeddings_fn(x)
250 | else:
251 | # Since `x` is one-hot encoded, using hk.Linear is equivalent to
252 | # hk.Embed with hk.EmbedLookupStyle.ONE_HOT.
253 | embs_init = hk.initializers.TruncatedNormal(
254 | stddev=self._config.emb_init_scale)
255 | embeddings = hk.Linear(
256 | self._config.embedding_dim, with_bias=False, w_init=embs_init)(
257 | x)
258 |
259 | embeddings *= jnp.sqrt(self._config.embedding_dim)
260 |
261 | else:
262 | embeddings = x
263 |
264 | batch_size, sequence_length, embedding_size = embeddings.shape
265 |
266 | pos_enc_params = self._config.positional_encodings_params
267 | if (
268 | self._config.positional_encodings
269 | == pos_encs_lib.PositionalEncodings.SIN_COS
270 | ):
271 | pos_encodings = pos_encs_lib.sinusoid_position_encoding(
272 | sequence_length=sequence_length,
273 | hidden_size=embedding_size,
274 | memory_length=0,
275 | max_timescale=pos_enc_params.max_time,
276 | min_timescale=2,
277 | clamp_length=0,
278 | causal=True,
279 | )
280 | h = embeddings + pos_encodings
281 | h = hk.dropout(hk.next_rng_key(), self._config.dropout_prob, h)
282 | else:
283 | h = embeddings
284 |
285 | # The causal mask is shared across heads.
286 | if self._config.causal_masking:
287 | causal_mask = jnp.tril(
288 | jnp.ones((batch_size, 1, sequence_length, sequence_length))
289 | )
290 | else:
291 | causal_mask = None
292 |
293 | for _ in range(self._config.num_layers):
294 | attention = MultiHeadDotProductAttention(
295 | num_heads=self._config.num_heads,
296 | num_hiddens_per_head=self._config.num_hiddens_per_head,
297 | positional_encodings=self._config.positional_encodings,
298 | positional_encodings_params=pos_enc_params,
299 | attention_window=self._config.attention_window,
300 | )(
301 | inputs_q=h,
302 | inputs_kv=h,
303 | mask=causal_mask,
304 | causal=self._config.causal_masking,
305 | )
306 | attention = hk.dropout(hk.next_rng_key(), self._config.dropout_prob,
307 | attention)
308 | attention = layer_norm(h + attention)
309 |
310 | # Position-wise feedforward network.
311 | h = hk.Linear(self._config.embedding_dim * self._config.widening_factor)(
312 | attention)
313 | h = jnn.relu(h)
314 | h = hk.Linear(self._config.embedding_dim)(h)
315 |
316 | h = hk.dropout(hk.next_rng_key(), self._config.dropout_prob, h)
317 | h = layer_norm(h + attention)
318 | return h
319 |
320 |
321 | class TransformerDecoder(hk.Module):
322 | """Transformer Decoder (Vaswani et al., 2017)."""
323 |
324 | def __init__(
325 | self,
326 | config: TransformerConfig,
327 | shared_embeddings_fn: Optional[Callable[[chex.Array], chex.Array]] = None,
328 | name: Optional[str] = None,
329 | ) -> None:
330 | """Initializes the Transformer decoder.
331 |
332 | Args:
333 | config: The hyperparameters used in Transformer architectures.
334 | shared_embeddings_fn: Embedding function that is shared with the encoder.
335 | name: The name of the module.
336 | """
337 | super().__init__(name=name)
338 | self._config = config
339 | self._shared_embeddings_fn = shared_embeddings_fn
340 |
341 | def __call__(self, encoded: chex.Array, targets: chex.Array) -> chex.Array:
342 | """Returns the transformer decoder output, shape [B, T_O, E].
343 |
344 | Args:
345 | encoded: The output of the encoder, shape [B, T_I, E].
346 | targets: The one-hot encoded target values, shape [B, T_O, 2].
347 | """
348 | targets = shift_right(targets, self._config.output_size)
349 |
350 | if self._config.use_embeddings:
351 | if self._shared_embeddings_fn is not None:
352 | output_embeddings = self._shared_embeddings_fn(targets)
353 | else:
354 |       # Since `targets` is one-hot encoded, using hk.Linear is equivalent to
355 | # hk.Embed with hk.EmbedLookupStyle.ONE_HOT.
356 | embs_init = hk.initializers.TruncatedNormal(
357 | stddev=self._config.emb_init_scale)
358 | output_embeddings = hk.Linear(
359 | self._config.embedding_dim, with_bias=False, w_init=embs_init)(
360 | targets)
361 |
362 | output_embeddings *= jnp.sqrt(self._config.embedding_dim)
363 |
364 | else:
365 | output_embeddings = targets
366 |
367 | batch_size, output_sequence_length, embedding_size = output_embeddings.shape
368 |
369 | if (
370 | self._config.positional_encodings
371 | == pos_encs_lib.PositionalEncodings.SIN_COS
372 | ):
373 | pos_encodings = pos_encs_lib.sinusoid_position_encoding(
374 | sequence_length=output_sequence_length,
375 | hidden_size=embedding_size,
376 | memory_length=0,
377 | max_timescale=self._config.positional_encodings_params.max_time,
378 | min_timescale=2,
379 | clamp_length=0,
380 | causal=True,
381 | )
382 | h = output_embeddings + pos_encodings
383 | h = hk.dropout(hk.next_rng_key(), self._config.dropout_prob, h)
384 | else:
385 | h = output_embeddings
386 |
387 | # The causal mask is shared across heads.
388 | causal_mask = jnp.tril(
389 | jnp.ones(
390 | (batch_size, 1, output_sequence_length, output_sequence_length)))
391 |
392 | for _ in range(self._config.num_layers):
393 | self_attention = MultiHeadDotProductAttention(
394 | num_heads=self._config.num_heads,
395 | num_hiddens_per_head=self._config.num_hiddens_per_head,
396 | positional_encodings=self._config.positional_encodings,
397 | positional_encodings_params=self._config.positional_encodings_params,
398 | attention_window=self._config.attention_window,
399 | )(inputs_q=h, inputs_kv=h, mask=causal_mask, causal=True)
400 | self_attention = hk.dropout(hk.next_rng_key(), self._config.dropout_prob,
401 | self_attention)
402 | self_attention = layer_norm(h + self_attention)
403 |
404 | cross_attention = MultiHeadDotProductAttention(
405 | num_heads=self._config.num_heads,
406 | num_hiddens_per_head=self._config.num_hiddens_per_head,
407 | positional_encodings=self._config.positional_encodings,
408 | positional_encodings_params=self._config.positional_encodings_params,
409 | attention_window=self._config.attention_window,
410 | )(inputs_q=self_attention, inputs_kv=encoded, causal=True)
411 | cross_attention = hk.dropout(hk.next_rng_key(), self._config.dropout_prob,
412 | cross_attention)
413 | cross_attention = layer_norm(self_attention + cross_attention)
414 |
415 | # Position-wise feedforward network.
416 | h = hk.Linear(self._config.embedding_dim * self._config.widening_factor)(
417 | cross_attention)
418 | h = jnn.relu(h)
419 | h = hk.Linear(self._config.embedding_dim)(h)
420 |
421 | h = hk.dropout(hk.next_rng_key(), self._config.dropout_prob, h)
422 | h = layer_norm(h + cross_attention)
423 |
424 | return h
425 |
426 |
427 | class Transformer(hk.Module):
428 | """Transformer (Vaswani et al., 2017)."""
429 |
430 | def __init__(self, config: TransformerConfig, name: Optional[str] = None):
431 | """Initializes the Transformer.
432 |
433 | Args:
434 | config: The hyperparameters used in Transformer architectures.
435 | name: The name of the module.
436 | """
437 | super().__init__(name=name)
438 | shared_embeddings_fn = None
439 |
440 | if config.share_embeddings:
441 | shared_embeddings_fn = hk.Linear(
442 | config.embedding_dim,
443 | with_bias=False,
444 | w_init=hk.initializers.TruncatedNormal(stddev=config.emb_init_scale),
445 | name='shared_embeddings')
446 |
447 | self._encoder = TransformerEncoder(config, shared_embeddings_fn)
448 | self._decoder = TransformerDecoder(config, shared_embeddings_fn)
449 |
450 | def __call__(self, inputs: chex.Array, targets: chex.Array) -> chex.Array:
451 | return self._decoder(self._encoder(inputs), targets)
452 |
453 |
454 | def make_transformer_encoder(
455 | output_size: int,
456 | embedding_dim: int = 64,
457 | num_layers: int = 5,
458 | num_heads: int = 8,
459 | num_hiddens_per_head: Optional[int] = None,
460 | dropout_prob: float = 0.1,
461 | emb_init_scale: float = 0.02,
462 | use_embeddings: bool = True,
463 | share_embeddings: bool = False,
464 | attention_window: Optional[int] = None,
465 | positional_encodings: Optional[pos_encs_lib.PositionalEncodings] = None,
466 | positional_encodings_params: Optional[
467 | pos_encs_lib.PositionalEncodingsParams
468 | ] = None,
469 | widening_factor: int = 4,
470 | return_all_outputs: bool = False,
471 | causal_masking: bool = False,
472 | ) -> Callable[[chex.Array], chex.Array]:
473 | """Returns a transformer encoder model."""
474 | if positional_encodings is None:
475 | positional_encodings = pos_encs_lib.PositionalEncodings.SIN_COS
476 | positional_encodings_params = pos_encs_lib.SinCosParams()
477 | elif positional_encodings_params is None:
478 | raise ValueError('No parameters for positional encodings are passed.')
479 | config = TransformerConfig(
480 | output_size=output_size,
481 | embedding_dim=embedding_dim,
482 | num_layers=num_layers,
483 | num_heads=num_heads,
484 | num_hiddens_per_head=num_hiddens_per_head,
485 | dropout_prob=dropout_prob,
486 | emb_init_scale=emb_init_scale,
487 | use_embeddings=use_embeddings,
488 | share_embeddings=share_embeddings,
489 | attention_window=attention_window,
490 | positional_encodings=positional_encodings,
491 | positional_encodings_params=positional_encodings_params,
492 | widening_factor=widening_factor,
493 | causal_masking=causal_masking,
494 | )
495 |
496 | def transformer_encoder(inputs: chex.Array) -> chex.Array:
497 | output = TransformerEncoder(config)(inputs)
498 | if not return_all_outputs:
499 | output = output[:, -1, :]
500 | return hk.Linear(output_size)(output)
501 |
502 | return transformer_encoder
503 |
504 |
505 | def make_transformer(
506 | output_size: int,
507 | embedding_dim: int = 64,
508 | num_layers: int = 5,
509 | num_heads: int = 8,
510 | num_hiddens_per_head: Optional[int] = None,
511 | dropout_prob: float = 0.1,
512 | emb_init_scale: float = 0.02,
513 | use_embeddings: bool = True,
514 | share_embeddings: bool = False,
515 | attention_window: Optional[int] = None,
516 | positional_encodings: Optional[pos_encs_lib.PositionalEncodings] = None,
517 | positional_encodings_params: Optional[
518 | pos_encs_lib.PositionalEncodingsParams
519 | ] = None,
520 | widening_factor: int = 4,
521 | return_all_outputs: bool = False,
522 | ) -> Callable[[chex.Array, chex.Array], chex.Array]:
523 | """Returns a transformer model."""
524 | if positional_encodings is None:
525 | positional_encodings = pos_encs_lib.PositionalEncodings.SIN_COS
526 | positional_encodings_params = pos_encs_lib.SinCosParams()
527 | elif positional_encodings_params is None:
528 | raise ValueError('No parameters for positional encodings are passed.')
529 | config = TransformerConfig(
530 | output_size=output_size,
531 | embedding_dim=embedding_dim,
532 | num_layers=num_layers,
533 | num_heads=num_heads,
534 | num_hiddens_per_head=num_hiddens_per_head,
535 | dropout_prob=dropout_prob,
536 | emb_init_scale=emb_init_scale,
537 | use_embeddings=use_embeddings,
538 | share_embeddings=share_embeddings,
539 | attention_window=attention_window,
540 | positional_encodings=positional_encodings,
541 | positional_encodings_params=positional_encodings_params,
542 | widening_factor=widening_factor,
543 | )
544 |
545 | def transformer(inputs: chex.Array, targets: chex.Array) -> chex.Array:
546 | output = Transformer(config)(inputs, targets)
547 | if not return_all_outputs:
548 | output = output[:, -1, :]
549 | return hk.Linear(output_size)(output)
550 |
551 | return transformer
552 |
--------------------------------------------------------------------------------
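
A minimal usage sketch (not part of the repository; it assumes the package
import path `neural_networks_chomsky_hierarchy.models.transformer` and the
standard Haiku/JAX APIs): `make_transformer_encoder` returns a Haiku-style
forward function, so it still has to be wrapped in `hk.transform` before it can
be initialized and applied. Because dropout draws from `hk.next_rng_key`,
`apply` needs an RNG key as well.

import haiku as hk
import jax
import jax.numpy as jnp

from neural_networks_chomsky_hierarchy.models import transformer

def forward(inputs):
  # Binary classification head on top of the encoder; with the default
  # return_all_outputs=False only the last position's logits are kept.
  model = transformer.make_transformer_encoder(
      output_size=2, num_heads=8, num_hiddens_per_head=8)
  return model(inputs)

forward = hk.transform(forward)
dummy_inputs = jax.nn.one_hot(jnp.zeros((4, 16), dtype=jnp.int32), 2)  # [B, T, V]
params = forward.init(jax.random.PRNGKey(0), dummy_inputs)
logits = forward.apply(params, jax.random.PRNGKey(1), dummy_inputs)  # shape (4, 2)
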
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py
2 | dm-haiku
3 | dm-tree
4 | git+https://github.com/deepmind/einshape
5 | jax
6 | numpy
7 | optax
8 | tqdm
9 | typing-extensions
10 |
--------------------------------------------------------------------------------
/tasks/cs/binary_addition.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Add two binary numbers."""
17 |
18 | import random
19 | from typing import Sequence
20 |
21 | import chex
22 | import jax.nn as jnn
23 | import jax.numpy as jnp
24 | import numpy as np
25 |
26 | from neural_networks_chomsky_hierarchy.tasks import task
27 |
28 |
29 | def numbers_to_variable_length_binary(
30 | numbers: Sequence[int],
31 | lengths: Sequence[int],
32 | little_endian: bool = True,
33 | ) -> list[list[int]]:
34 | """Returns the binary notation of a certain length for a sequence of numbers.
35 |
36 | Args:
37 | numbers: The numbers to be converted to binary.
38 | lengths: The lengths of the binary representations (every number uses its
39 | own length). This argument has no effect if the binary representation is
40 | longer than the specified length.
41 | little_endian: Whether to use little- or big-endian notation.
42 | """
43 | binary_strings = [f'{num:b}'.zfill(len) for num, len in zip(numbers, lengths)]
44 |
45 | if little_endian:
46 | binary_strings = [bin[::-1] for bin in binary_strings]
47 |
48 | return [list(map(int, bin)) for bin in binary_strings]
49 |
50 |
51 | def numbers_to_fixed_length_binary(
52 | numbers: Sequence[int],
53 | length: int,
54 | little_endian: bool = True,
55 | ) -> list[list[int]]:
56 | """Returns the binary notation of a certain length for a sequence of numbers.
57 |
58 | Args:
59 | numbers: The numbers to be converted to binary.
60 | length: The length of the binary representations (all numbers use the same
61 | length). This argument has no effect if the binary representation is
62 | longer than the specified length.
63 | little_endian: Whether to use little- or big-endian notation.
64 | """
65 | return numbers_to_variable_length_binary(
66 | numbers=numbers,
67 | lengths=[length] * len(numbers),
68 | little_endian=little_endian,
69 | )
70 |
71 |
72 | def expression_from_numbers(
73 | numbers_n: Sequence[list[int]],
74 | numbers_m: Sequence[list[int]],
75 | ) -> list[list[int]]:
76 | """Returns an expression with a placeholder value to denote the operation."""
77 | return [n + [2] + m for n, m in zip(numbers_n, numbers_m)]
78 |
79 |
80 | class BinaryAddition(task.GeneralizationTask):
81 | """A task with the goal of summing two numbers in binary (little-endian).
82 |
83 | The input is a string of the form `first_number+second_number` in
84 | (little-endian) binary notation (e.g., `01001+011`). The goal of the agent is
85 | to output the result, also in (little-endian) binary form (i.e., in the
86 | example `18 + 6 = 24 = 00011`). The output is padded with 0s to match the
87 | input length, and the end of the sum is denoted with a termination token
88 | (i.e., the output has values in `{0, 1, 2}`).
89 |
90 | Examples:
91 |     001 + 01101 = 0101120000 (4 + 22 = 26)
92 |     1001 + 000001 = 100101200000 (9 + 32 = 41)
93 | """
94 |
95 | def _sample_expressions_and_results(
96 | self,
97 | batch_size: int,
98 | length: int,
99 | ) -> tuple[Sequence[list[int]], Sequence[list[int]]]:
100 | """Samples pairs of numbers and sums them in (little-endian) binary.
101 |
102 | We use Python's bignums, which can represent arbitrary-precision integers to
103 | perform addition of two potentially very large values (roughly of the size
104 | `2 ** (length // 2)`).
105 |
106 | Args:
107 | batch_size: The number of expressions and results to sample.
108 | length: The length of the input expression containing the two numbers and
109 | the separation token.
110 |
111 | Returns:
112 | The expression and the sum of the two numbers. The expression has the
113 | format: `[first_number, 2, second_number]`, where the numbers are in
114 | (little-endian) binary notation. The sum is also in (little-endian) binary
115 | notation, without leading (i.e., ending) zeros.
116 | """
117 | # If `length <= 2`, we just sample a binary value and return it (without
118 | # leading zeros in little-endian notation).
119 | if length <= 2:
120 |       # Since `length <= 2`, we can use `np.random` without overflow errors.
121 | numbers = np.random.randint(0, 2**length - 1, size=(batch_size))
122 | expressions = numbers_to_fixed_length_binary(numbers, length)
123 | results = numbers_to_fixed_length_binary(numbers, 0)
124 | return expressions, results
125 |
126 | # We only use `length - 1` tokens for the two values to account for the `+`.
127 | length_n = np.random.randint(1, length - 1, size=(batch_size,))
128 | length_m = length - 1 - length_n
129 |
130 | integer_n = [random.randint(1, 2**int(len_n) - 1) for len_n in length_n]
131 | integer_m = [random.randint(1, 2**int(len_m) - 1) for len_m in length_m]
132 |
133 | binary_n = numbers_to_variable_length_binary(integer_n, length_n)
134 | binary_m = numbers_to_variable_length_binary(integer_m, length_m)
135 |
136 | expressions = expression_from_numbers(binary_n, binary_m)
137 |
138 | integer_sum = list(map(sum, zip(integer_n, integer_m)))
139 | results = numbers_to_fixed_length_binary(integer_sum, length=0)
140 |
141 | return expressions, results
142 |
143 | def sample_batch(
144 | self,
145 | rng: chex.PRNGKey,
146 | batch_size: int,
147 | length: int,
148 | ) -> task.Batch:
149 | """Returns a batch of binary additions and their results."""
150 | del rng
151 |
152 | expressions, results = self._sample_expressions_and_results(
153 | batch_size=batch_size, length=length)
154 | # Append the termination token to the result and pad the result with zeros
155 | # to match the output length (accounting for the termination token).
156 | results = [res + [2] + [0] * (length - len(res)) for res in results]
157 |
158 | expressions = jnp.array(expressions, dtype=jnp.int32)
159 | results = jnp.array(results, dtype=jnp.int32)
160 |
161 | return {
162 | 'input': jnn.one_hot(expressions, self.input_size),
163 | 'output': jnn.one_hot(results, self.output_size),
164 | }
165 |
166 | @property
167 | def input_size(self) -> int:
168 | """Returns the input size for the models."""
169 | return 3
170 |
171 | @property
172 | def output_size(self) -> int:
173 | """Returns the output size for the models."""
174 | return 3
175 |
176 | def output_length(self, input_length: int) -> int:
177 | return input_length + 1
178 |
179 | def accuracy_mask(self, target: chex.Array) -> chex.Array:
180 | """Computes a mask that ignores everything after the termination token.
181 |
182 | Args:
183 | target: Target tokens of shape `(batch_size, output_length, output_size)`.
184 |
185 | Returns:
186 | The mask of shape `(batch_size, output_length)`.
187 | """
188 | batch_size, length, _ = target.shape
189 | termination_indices = jnp.argmax(
190 | jnp.argmax(target, axis=-1),
191 | axis=-1,
192 | keepdims=True,
193 | )
194 | indices = jnp.tile(jnp.arange(length), (batch_size, 1))
195 | return indices <= termination_indices
196 |
--------------------------------------------------------------------------------
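
An illustrative check of the helpers above (not itself part of the repository),
using the `18 + 6` example from the `BinaryAddition` docstring:

from neural_networks_chomsky_hierarchy.tasks.cs import binary_addition

# 18 -> '10010' -> little-endian [0, 1, 0, 0, 1]; 6 -> '110' -> [0, 1, 1].
n, m = binary_addition.numbers_to_variable_length_binary([18, 6], [5, 3])
assert n == [0, 1, 0, 0, 1] and m == [0, 1, 1]
# The separator token `2` plays the role of the '+' sign in the expression.
assert binary_addition.expression_from_numbers([n], [m])[0] == [0, 1, 0, 0, 1, 2, 0, 1, 1]
# The sum 24 -> '11000' -> little-endian [0, 0, 0, 1, 1] (no trailing zeros).
assert binary_addition.numbers_to_fixed_length_binary([24], 0) == [[0, 0, 0, 1, 1]]
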
/tasks/cs/binary_multiplication.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Multiply two binary numbers."""
17 |
18 | import random
19 | from typing import Sequence
20 |
21 | import chex
22 | import jax.nn as jnn
23 | import jax.numpy as jnp
24 | import numpy as np
25 |
26 | from neural_networks_chomsky_hierarchy.tasks import task
27 | from neural_networks_chomsky_hierarchy.tasks.cs import binary_addition
28 |
29 |
30 | class BinaryMultiplication(task.GeneralizationTask):
31 | """A task with the goal of multiplying two numbers in binary (little-endian).
32 |
33 |   The input is a string of the form `first_number*second_number` in
34 |   (little-endian) binary notation (e.g., `01001*011`). The goal of the agent is
35 |   to output the result, also in (little-endian) binary form (i.e., in the
36 |   example `18 * 6 = 108 = 0011011`). The output is padded with 0s to match the
37 | input length, and the end of the product is denoted with a termination token
38 | (i.e., the output has values in `{0, 1, 2}`).
39 |
40 | Examples:
41 | 001 * 01101 = 000110120 (4 * 22 = 88)
42 | 1001 * 000001 = 00000100120 (9 * 32 = 288)
43 | """
44 |
45 | def _sample_expressions_and_results(
46 | self,
47 | batch_size: int,
48 | length: int,
49 | ) -> tuple[Sequence[list[int]], Sequence[list[int]]]:
50 | """Samples pairs of numbers and multiplies them in (little-endian) binary.
51 |
52 | We use Python's bignums, which can represent arbitrary-precision integers to
53 | perform multiplication of two potentially very large values (roughly of the
54 | size `2 ** (length // 2)`).
55 |
56 | Args:
57 | batch_size: The number of expressions and results to sample.
58 | length: The length of the input expression containing the two numbers and
59 | the separation token.
60 |
61 | Returns:
62 | The expression and the product of the two numbers. The expression has the
63 | format: `[first_number, 2, second_number]`, where the numbers are in
64 | (little-endian) binary notation. The product is also in (little-endian)
65 | binary notation, without leading (i.e., ending) zeros.
66 | """
67 | # If `length <= 2`, we just sample a binary sequence for the expression and
68 | # arbitrarily set the result to a fixed value (`[]` for `length == 1` and
69 | # `[0]` for `length == 2`) to maintain the invariant that the result has
70 |     # length at most `length - 1`.
71 | if length <= 2:
72 |       # Since `length <= 2`, we can use `np.random` without overflow errors.
73 | numbers = np.random.randint(0, 2**length - 1, size=(batch_size))
74 | expressions = binary_addition.numbers_to_fixed_length_binary(
75 | numbers, length)
76 | return expressions, [[0] * (length - 1)] * batch_size
77 |
78 | # We only use `length - 1` tokens for the two values to account for the `*`.
79 | length_n = np.random.randint(1, length - 1, size=(batch_size,))
80 | length_m = length - 1 - length_n
81 |
82 | integer_n = [random.randint(1, 2**int(len_n) - 1) for len_n in length_n]
83 | integer_m = [random.randint(1, 2**int(len_m) - 1) for len_m in length_m]
84 |
85 | binary_n = binary_addition.numbers_to_variable_length_binary(
86 | integer_n, length_n)
87 | binary_m = binary_addition.numbers_to_variable_length_binary(
88 | integer_m, length_m)
89 |
90 | expressions = binary_addition.expression_from_numbers(binary_n, binary_m)
91 |
92 | integer_prod = [int_n * int_m for int_n, int_m in zip(integer_n, integer_m)]
93 | results = binary_addition.numbers_to_fixed_length_binary(
94 | integer_prod, length=0)
95 |
96 | return expressions, results
97 |
98 | def sample_batch(
99 | self,
100 | rng: chex.PRNGKey,
101 | batch_size: int,
102 | length: int,
103 | ) -> task.Batch:
104 | """Returns a batch of binary multiplications and their results."""
105 | del rng
106 |
107 | expressions, results = self._sample_expressions_and_results(
108 | batch_size=batch_size, length=length)
109 | # Append the termination token to the result and pad the result with zeros
110 | # to match the output length (accounting for the termination token). The
111 | # binary representation of the result will have at most length
112 | # `#(first_number) + #(second_number)`, where #() denotes the number of
113 | # digits of the binary notation. Since we use the token `2` to separate the
114 | # two numbers in the expression, the result will have length at most
115 | # `length - 1`, and thus by appending the termination token above it will
116 | # have length at most `length`, as desired.
117 | results = [res + [2] + [0] * (length - 1 - len(res)) for res in results]
118 |
119 | expressions = jnp.array(expressions, dtype=jnp.int32)
120 | results = jnp.array(results, dtype=jnp.int32)
121 |
122 | return {
123 | 'input': jnn.one_hot(expressions, self.input_size),
124 | 'output': jnn.one_hot(results, self.output_size),
125 | }
126 |
127 | @property
128 | def input_size(self) -> int:
129 | """Returns the input size for the models."""
130 | return 3
131 |
132 | @property
133 | def output_size(self) -> int:
134 | """Returns the output size for the models."""
135 | return 3
136 |
137 | def output_length(self, input_length: int) -> int:
138 | return input_length
139 |
140 | def accuracy_mask(self, target: chex.Array) -> chex.Array:
141 | """Computes a mask that ignores everything after the termination token.
142 |
143 | Args:
144 | target: Target tokens of shape `(batch_size, output_length, output_size)`.
145 |
146 | Returns:
147 | The mask of shape `(batch_size, output_length)`.
148 | """
149 | batch_size, length, _ = target.shape
150 | termination_indices = jnp.argmax(
151 | jnp.argmax(target, axis=-1),
152 | axis=-1,
153 | keepdims=True,
154 | )
155 | indices = jnp.tile(jnp.arange(length), (batch_size, 1))
156 | return indices <= termination_indices
157 |
--------------------------------------------------------------------------------
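
A short illustrative check (not from the repository) of the padding invariant
described in `sample_batch` above: the product of an a-bit and a b-bit number
always fits in a + b bits, so the result plus its termination token never
exceeds the input length.

for a_bits, b_bits in [(3, 5), (4, 6), (10, 1)]:
  largest_a = 2**a_bits - 1  # largest a-bit number
  largest_b = 2**b_bits - 1  # largest b-bit number
  assert (largest_a * largest_b).bit_length() <= a_bits + b_bits
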
/tasks/cs/bucket_sort.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Sort tokens from a fixed alphabet (i.e., bucket sort)."""
17 |
18 | import functools
19 |
20 | import chex
21 | import jax
22 | from jax import nn as jnn
23 | from jax import numpy as jnp
24 | from jax import random as jrandom
25 |
26 | from neural_networks_chomsky_hierarchy.tasks import task
27 |
28 |
29 | class BucketSort(task.GeneralizationTask):
30 | """A task with the goal of sorting tokens from a fixed alphabet.
31 |
32 | The input string is composed of tokens from a fixed-size alphabet, i.e.,
33 | `{0, 1, ..., vocab_size - 1}`, and the goal is to return the sorted string (in
34 | lexicographically increasing order).
35 |
36 | Examples:
37 | 10204112 -> 00111224 (with `vocab_size = 5`)
38 | 1110001 -> 0001111 (with `vocab_size = 2`)
39 | """
40 |
41 | def __init__(self, vocab_size: int = 5) -> None:
42 | """Initializes the task.
43 |
44 | Args:
45 | vocab_size: The size of the alphabet. We use 5 in the paper.
46 | """
47 | self._vocab_size = vocab_size
48 |
49 | @functools.partial(jax.jit, static_argnums=(0, 2, 3))
50 | def sample_batch(
51 | self,
52 | rng: chex.PRNGKey,
53 | batch_size: int,
54 | length: int,
55 | ) -> task.Batch:
56 |     """Returns a batch of strings and their tokens sorted in increasing order."""
57 | strings = jrandom.randint(
58 | rng, shape=(batch_size, length), minval=0, maxval=self._vocab_size)
59 | sorted_strings = jnp.sort(strings, axis=-1)
60 |
61 | return {
62 | 'input': jnn.one_hot(strings, num_classes=self.input_size),
63 | 'output': jnn.one_hot(sorted_strings, num_classes=self.output_size),
64 | }
65 |
66 | @property
67 | def input_size(self) -> int:
68 | """Returns the input size for the models."""
69 | return self._vocab_size
70 |
71 | @property
72 | def output_size(self) -> int:
73 | """Returns the output size for the models."""
74 | return self._vocab_size
75 |
76 | def output_length(self, input_length: int) -> int:
77 | """Returns the output length for a given input length."""
78 | return input_length
79 |
--------------------------------------------------------------------------------
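
A minimal usage sketch (not part of the repository): sampling one batch from
`BucketSort` and checking that the output tokens are the input tokens in
non-decreasing order.

import jax
import jax.numpy as jnp

from neural_networks_chomsky_hierarchy.tasks.cs import bucket_sort

sort_task = bucket_sort.BucketSort(vocab_size=5)
batch = sort_task.sample_batch(jax.random.PRNGKey(0), 4, 8)
assert batch['input'].shape == (4, 8, 5)   # one-hot input strings
assert batch['output'].shape == (4, 8, 5)  # the same tokens, sorted
tokens = jnp.argmax(batch['output'], axis=-1)
assert bool(jnp.all(tokens[:, 1:] >= tokens[:, :-1]))
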
/tasks/cs/compute_sqrt.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Compute the floor of the square root of a binary number."""
17 |
18 | import math
19 | import random
20 |
21 | import chex
22 | import jax.nn as jnn
23 | import jax.numpy as jnp
24 |
25 | from neural_networks_chomsky_hierarchy.tasks import task
26 | from neural_networks_chomsky_hierarchy.tasks.cs import binary_addition
27 |
28 |
29 | class ComputeSqrt(task.GeneralizationTask):
30 | """A task with the goal of computing the square root of a binary number.
31 |
32 | The input is a number in binary (big-endian), and the output is the floor of
33 | the square root of this number, also in binary.
34 |   Note that the output length, i.e., the length of the square root in binary,
35 |   is always ceil(input_length / 2) (because log(sqrt(x)) = 1/2 * log(x)).
36 |
37 | Examples:
38 | 100101 = 37 -> square root is 6.08... -> floor(6.08) = 6 -> 101
39 | 111 = 7 -> square root is 2.64 -> floor(2.64) = 2 -> 10
40 | """
41 |
42 | def sample_batch(self, rng: chex.PRNGKey, batch_size: int,
43 | length: int) -> task.Batch:
44 | """Returns a batch of binary numbers and their square roots, in binary."""
45 | del rng
46 | numbers = [random.randint(1, 2**length - 1) for _ in range(batch_size)]
47 | binary_numbers = binary_addition.numbers_to_fixed_length_binary(
48 | numbers, length=length, little_endian=False)
49 |
50 | sqrts = list(map(math.isqrt, numbers))
51 | binary_sqrts = binary_addition.numbers_to_fixed_length_binary(
52 | sqrts, length=self.output_length(length), little_endian=False)
53 |
54 | binary_numbers = jnp.array(binary_numbers, jnp.int32)
55 | binary_sqrts = jnp.array(binary_sqrts, jnp.int32)
56 |
57 | inputs = jnn.one_hot(binary_numbers, self.input_size)
58 | output = jnn.one_hot(binary_sqrts, self.output_size)
59 | return {'input': inputs, 'output': output}
60 |
61 | @property
62 | def input_size(self) -> int:
63 | """Returns the input size for the models."""
64 | return 2
65 |
66 | @property
67 | def output_size(self) -> int:
68 | """Returns the output size for the models."""
69 | return 2
70 |
71 | def output_length(self, input_length: int) -> int:
72 | return math.ceil(input_length / 2)
73 |
--------------------------------------------------------------------------------
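
A quick illustrative check (not from the repository) of the length claim in the
docstring above: the integer square root of an n-bit number has ceil(n / 2)
bits, which is exactly why `output_length` returns ceil(input_length / 2).

import math

for number in [1, 7, 37, 255, 2**31 - 1]:
  n_bits = number.bit_length()
  sqrt_bits = math.isqrt(number).bit_length()
  assert sqrt_bits == math.ceil(n_bits / 2)
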
/tasks/cs/duplicate_string.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Duplicate a string."""
17 |
18 | import functools
19 |
20 | import jax
21 | import jax.nn as jnn
22 | import jax.numpy as jnp
23 | import jax.random as jrandom
24 |
25 | from neural_networks_chomsky_hierarchy.tasks import task
26 |
27 |
28 | class DuplicateString(task.GeneralizationTask):
29 | """A task with the goal of duplicating a string.
30 |
31 | The input is a string s_1 ... s_n composed of symbols from a finite set S. The
32 |   output is the same string repeated twice without any separator, i.e.:
33 | s_1 ... s_n s_1 ... s_n
34 |
35 | Examples:
36 | 101 -> 101 101
37 | 111111 -> 111111 111111
38 |
39 |   In the paper, we use only binary strings (i.e., S = {0, 1}).
40 |   Note that the sampling is jittable, so this task is fast.
41 | """
42 |
43 | def __init__(self, vocab_size: int = 2) -> None:
44 |     """Initializes the duplicate_string task.
45 |
46 | Args:
47 | vocab_size: The size of the alphabet. We use 2 in the paper.
48 | """
49 | self._vocab_size = vocab_size
50 |
51 | @functools.partial(jax.jit, static_argnums=(0, 2, 3))
52 | def sample_batch(self, rng: jnp.ndarray, batch_size: int,
53 | length: int) -> task.Batch:
54 | """Returns a batch of strings and their copies."""
55 | strings = jrandom.randint(
56 | rng, shape=(batch_size, length), minval=0, maxval=self._vocab_size)
57 | one_hot_strings = jnn.one_hot(strings, num_classes=self._vocab_size)
58 | output = jnp.concatenate([one_hot_strings, one_hot_strings], axis=1)
59 | return {"input": one_hot_strings, "output": output}
60 |
61 | @property
62 | def input_size(self) -> int:
63 | """Returns the input size for the models."""
64 | return self._vocab_size
65 |
66 | @property
67 | def output_size(self) -> int:
68 | """Returns the output size for the models."""
69 | return self._vocab_size
70 |
71 | def output_length(self, input_length: int) -> int:
72 | """Returns the output length for a given input length."""
73 | return 2 * input_length
74 |
--------------------------------------------------------------------------------
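
A minimal usage sketch (not part of the repository): the sampled output of
`DuplicateString` is simply the one-hot input concatenated with itself along
the time axis.

import jax
import jax.numpy as jnp

from neural_networks_chomsky_hierarchy.tasks.cs import duplicate_string

copy_task = duplicate_string.DuplicateString(vocab_size=2)
batch = copy_task.sample_batch(jax.random.PRNGKey(0), 2, 5)
assert batch['output'].shape == (2, 10, 2)  # output_length = 2 * input_length
assert bool(jnp.all(batch['output'][:, :5] == batch['input']))
assert bool(jnp.all(batch['output'][:, 5:] == batch['input']))
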
/tasks/cs/missing_duplicate_string.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Predict the missing symbol in a duplicated string."""
17 |
18 | import functools
19 |
20 | import chex
21 | import jax
22 | import jax.nn as jnn
23 | import jax.numpy as jnp
24 | import jax.random as jrandom
25 |
26 | from neural_networks_chomsky_hierarchy.tasks import task
27 |
28 |
29 | class MissingDuplicateString(task.GeneralizationTask):
30 | """A task with the goal of finding the missing symbol in a duplicated string.
31 |
32 | Given a binary string that is presented twice with exactly one element omitted
33 | (denoted by the placeholder token `2`), predict the value of that element.
34 | Thus, an agent trying to solve this task needs to recognize the underlying
35 | duplicated string to be able to produce the correct output.
36 | If the length is odd, the duplicated strings of length `length // 2` are
37 | padded with the empty token `3`.
38 |
39 |   Examples:
40 |     01100210 -> 1 (the substring is 0110, so the missing value is 1)
41 |     1011213 -> 0 (the substring is 101, so the missing value is 0)
42 | """
43 |
44 | @functools.partial(jax.jit, static_argnums=(0, 2, 3))
45 | def sample_batch(
46 | self,
47 | rng: chex.PRNGKey,
48 | batch_size: int,
49 | length: int,
50 | ) -> task.Batch:
51 | """Returns a batch of strings and the expected class."""
52 | # For `length == 1`, we cannot meaningfully define substrings of length
53 | # `length // 2`, so we arbitrarily set the inputs and outputs to `1`.
54 | if length == 1:
55 | return {
56 | 'input':
57 | jnn.one_hot(
58 | jnp.ones((batch_size, length)), num_classes=self.input_size),
59 | 'output':
60 | jnn.one_hot(
61 | jnp.ones((batch_size,)), num_classes=self.output_size),
62 | }
63 |
64 | strings_rng, indices_rng = jrandom.split(rng)
65 | strings = jrandom.randint(
66 | strings_rng, shape=(batch_size, length // 2), minval=0, maxval=2)
67 | duplicated_strings = jnp.concatenate((strings, strings), axis=-1)
68 | indices = jrandom.randint(
69 | indices_rng,
70 | shape=(batch_size,),
71 | minval=0,
72 | maxval=duplicated_strings.shape[1])
73 | output = jax.vmap(lambda x, y: x[y])(duplicated_strings, indices)
74 | masked_strings = jax.vmap(lambda x, y: x.at[y].set(2))(duplicated_strings,
75 | indices)
76 |
77 | # If `length` is odd, we pad the strings with the empty token `3` at the end
78 | # to ensure that the final input length is equal to `length` given the two
79 | # substrings of length `length // 2`.
80 | padding = jnp.full((batch_size, length % 2), fill_value=3)
81 | padded_strings = jnp.concatenate((masked_strings, padding), axis=-1)
82 |
83 | return {
84 | 'input': jnn.one_hot(padded_strings, num_classes=self.input_size),
85 | 'output': jnn.one_hot(output, num_classes=self.output_size)
86 | }
87 |
88 | @property
89 | def input_size(self) -> int:
90 | """Returns the input size for the models."""
91 | return 4
92 |
93 | @property
94 | def output_size(self) -> int:
95 | """Returns the output size for the models."""
96 | return 2
97 |
--------------------------------------------------------------------------------
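
A minimal usage sketch (not part of the repository): the input is a sequence of
tokens in {0, 1, 2, 3} (symbols, placeholder, padding) and the target is the
single omitted symbol in {0, 1}.

import jax
import jax.numpy as jnp

from neural_networks_chomsky_hierarchy.tasks.cs import missing_duplicate_string

mds_task = missing_duplicate_string.MissingDuplicateString()
batch = mds_task.sample_batch(jax.random.PRNGKey(0), 4, 8)
inputs = jnp.argmax(batch['input'], axis=-1)   # tokens, shape (4, 8)
labels = jnp.argmax(batch['output'], axis=-1)  # omitted symbols, shape (4,)
assert inputs.shape == (4, 8) and labels.shape == (4,)
assert bool(jnp.all(jnp.sum(inputs == 2, axis=-1) == 1))  # exactly one placeholder
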
/tasks/cs/odds_first.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Sort a string by the parity of the indices (odd indices first)."""
17 |
18 | import functools
19 |
20 | import jax
21 | import jax.nn as jnn
22 | import jax.numpy as jnp
23 | import jax.random as jrandom
24 |
25 | from neural_networks_chomsky_hierarchy.tasks import task
26 |
27 |
28 | class OddsFirst(task.GeneralizationTask):
29 | """A task with the goal of outputting a string's tokens at odd indices first.
30 |
31 | The input is a string s_1 ... s_n composed of symbols from a finite set S. The
32 | output is the same string, but where the values at odd indexes have been put
33 | first: s_1 s_3 s_5 ... s_2 s_4 s_6 ...
34 |
35 | Examples:
36 | 00110101 -> 0100 0111
37 | 110 -> 10 1
38 |
39 |   In the paper, we use only binary strings (i.e., S = {0, 1}).
40 |   Note that the sampling is jittable, so this task is fast.
41 | """
42 |
43 | def __init__(self, vocab_size: int = 2) -> None:
44 | """Initializes the odds_first task.
45 |
46 | Args:
47 | vocab_size: The size of the alphabet. We use 2 in the paper.
48 | """
49 | self._vocab_size = vocab_size
50 |
51 | @functools.partial(jax.jit, static_argnums=(0, 2, 3))
52 | def sample_batch(self, rng: jnp.ndarray, batch_size: int,
53 | length: int) -> task.Batch:
54 | """Returns a batch of strings and their outputs."""
55 | strings = jrandom.randint(
56 | rng, shape=(batch_size, length), minval=0, maxval=self._vocab_size)
57 | one_hot_strings = jnn.one_hot(strings, num_classes=self._vocab_size)
58 | output = jnp.concatenate(
59 | [one_hot_strings[:, 1::2], one_hot_strings[:, ::2]], axis=1)
60 | return {"input": one_hot_strings, "output": output}
61 |
62 | @property
63 | def input_size(self) -> int:
64 | """Returns the input size for the model."""
65 | return self._vocab_size
66 |
67 | @property
68 | def output_size(self) -> int:
69 | """Returns the output size for the model."""
70 | return self._vocab_size
71 |
72 | def output_length(self, input_length: int) -> int:
73 | """Returns the output length for the model."""
74 | return input_length
75 |
--------------------------------------------------------------------------------
/tasks/dcf/modular_arithmetic_brackets.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Modular arithmetic with brackets."""
17 |
18 | import collections
19 | from typing import Sequence
20 |
21 | import jax.nn as jnn
22 | import jax.numpy as jnp
23 | import numpy as np
24 | import tqdm
25 | import tree
26 |
27 | from neural_networks_chomsky_hierarchy.tasks import task
28 |
29 |
30 | def generate_one_expression_and_result(
31 | modulus: int, length: int, mult: bool = False
32 | ) -> tuple[str, int]:
33 | """Returns a modular arithmetic expression with brackets, and its result.
34 |
35 | The values in the expression are in {0, 1, ..., modulus-1}. The allowed
36 | operations are either {+, -} (mult=False) or {+, -, *} (mult=True).
37 |
38 | Args:
39 | modulus: The modulus to use for the expression.
40 | length: The length of the expression.
41 | mult: Whether to include the multiplication operator in the expressions.
42 |
43 | Raises:
44 | ValueError if length < 1.
45 | """
46 |
47 | # Generates a terminal (digit).
48 | def gen_terminal():
49 | terminal = np.random.randint(low=0, high=modulus)
50 | return str(terminal), terminal
51 |
52 | # If length is less than 1, issue an error.
53 | if length < 1:
54 | raise ValueError(
55 | f'Can\'t generate expressions of length < 1. Got {length}.')
56 |
57 | # If length is less than 5, generate a digit d, -d, (d), or (-d).
58 | if length == 1:
59 | return gen_terminal()
60 | elif length == 2:
61 | term_str, term_val = gen_terminal()
62 | return f'-{term_str}', -term_val % modulus
63 | elif length == 3:
64 | term_str, term_val = gen_terminal()
65 | return f'({term_str})', term_val % modulus
66 | elif length == 4:
67 | term_str, term_val = gen_terminal()
68 | return f'(-{term_str})', -term_val % modulus
69 |
70 | # First split the length into a left and right part.
71 | left_length = np.random.randint(low=1, high=length - 3)
72 | right_length = length - (left_length + 3)
73 | left_str, left_val = generate_one_expression_and_result(
74 | modulus, left_length, mult=mult)
75 | right_str, right_val = generate_one_expression_and_result(
76 | modulus, right_length, mult=mult)
77 |
78 | # Now sample an operator and return.
79 | maxop = 3 if mult else 2
80 | op = np.random.randint(low=0, high=maxop)
81 | if op == 0:
82 | return '(' + left_str + '+' + right_str + ')', (left_val +
83 | right_val) % modulus
84 | elif op == 1:
85 | return '(' + left_str + '-' + right_str + ')', (left_val -
86 | right_val) % modulus
87 | else:
88 | return '(' + left_str + '*' + right_str + ')', (left_val *
89 | right_val) % modulus
90 |
91 |
92 | def generate_raw_dataset(
93 | n: int,
94 | lengths: Sequence[int],
95 | modulus: int,
96 | mult: bool = False,
97 | with_tqdm: bool = False,
98 | ) -> dict[int, dict[str, np.ndarray]]:
99 | """Generates a dataset of maths expressions with brackets, and their results.
100 |
101 | Args:
102 | n: The number of datapoints in the dataset.
103 | lengths: The lengths of the sequences to generate. n is evenly distributed
104 | over these lengths.
105 | modulus: Modulus used to compute the expressions.
106 | mult: Whether to include the multiplication operator in the expressions.
107 | with_tqdm: As the computation might be long, whether to add a tqdm progress
108 | bar or not.
109 |
110 | Returns:
111 |     A dict whose keys are the passed lengths, and whose values are dicts with
112 |     keys 'expressions' and 'results' mapping to the data numpy arrays.
113 | """
114 | alphabet_to_int = {
115 | '+': modulus,
116 | '-': modulus + 1,
117 | '*': modulus + 2,
118 | '(': modulus + 3,
119 | ')': modulus + 4,
120 | 'x': modulus + 5,
121 | '=': modulus + 6,
122 | }
123 | for x in range(modulus):
124 | alphabet_to_int[str(x)] = x
125 |
126 | make_default_dict = lambda: {'expressions': [], 'results': []}
127 | sequences = collections.defaultdict(make_default_dict)
128 | range_lengths = tqdm.tqdm(lengths) if with_tqdm else lengths
129 | for length in range_lengths:
130 | for _ in range(n // len(lengths)):
131 | seq, label = generate_one_expression_and_result(modulus, length, mult)
132 | seq = [alphabet_to_int[x] for x in seq]
133 | sequences[length]['expressions'].append(seq)
134 | sequences[length]['results'].append(label)
135 | sequences = tree.traverse(
136 | lambda l: np.array(l, dtype=np.int32) if isinstance(l, list) else l,
137 | sequences,
138 | top_down=False,
139 | )
140 | return dict(sequences)
141 |
142 |
143 | class ModularArithmeticBrackets(task.GeneralizationTask):
144 | """A task with the goal of reducing an arithmetic expression with brackets."""
145 |
146 | def __init__(self, modulus: int = 5, mult: bool = False) -> None:
147 | """Initializes the modular arithmetic task.
148 |
149 | Args:
150 | modulus: The modulus used for the computation. We use 5 in the paper.
151 | mult: Whether to add multiplication or use only '+' and '-'.
152 | """
153 | self._modulus = modulus
154 | self._mult = mult
155 |
156 | def sample_batch(self, rng: jnp.ndarray, batch_size: int,
157 | length: int) -> task.Batch:
158 | """Returns a batch of inputs/outputs."""
159 | del rng
160 | batch = generate_raw_dataset(
161 | batch_size, lengths=[length], modulus=self._modulus,
162 | mult=self._mult)[length]
163 | inputs = jnn.one_hot(batch['expressions'], self.input_size)
164 | output = jnn.one_hot(batch['results'], self.output_size)
165 | return {'input': inputs, 'output': output}
166 |
167 | @property
168 | def input_size(self) -> int:
169 | """Returns the input size for the models."""
170 | return self._modulus + 6
171 |
172 | @property
173 | def output_size(self) -> int:
174 | """Returns the output size for the models."""
175 | return self._modulus
176 |
--------------------------------------------------------------------------------
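
A minimal usage sketch (not part of the repository): since the generated
expressions only contain single digits, '+', '-', '*', and brackets, they are
valid Python and can be checked directly against Python's own integer
evaluation, reduced modulo 5 (e.g. '(2+(-3))' evaluates to -1, i.e. 4 mod 5).

import numpy as np

from neural_networks_chomsky_hierarchy.tasks.dcf import modular_arithmetic_brackets as mab

np.random.seed(0)
expression, result = mab.generate_one_expression_and_result(modulus=5, length=8)
assert eval(expression) % 5 == result  # pylint: disable=eval-used
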
/tasks/dcf/reverse_string.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Compute the reverse of an input string."""
17 |
18 | import functools
19 |
20 | import jax
21 | import jax.numpy as jnp
22 |
23 | from neural_networks_chomsky_hierarchy.tasks import task
24 | from neural_networks_chomsky_hierarchy.tasks.cs import duplicate_string
25 |
26 |
27 | class ReverseString(duplicate_string.DuplicateString):
28 | """A task with the goal of reversing a given string.
29 |
30 | The input is a string s_1 ... s_n composed of symbols from a finite set S. The
31 | output is the string, reversed, ie s_n ... s_1.
32 |
33 | Examples:
34 | 011010 -> 010110
35 | 123021 -> 120321
36 |
37 |   In the paper, we use only binary strings (i.e., S = {0, 1}).
38 |   Note that the sampling is jittable, so this task is fast.
39 | """
40 |
41 | @functools.partial(jax.jit, static_argnums=(0, 2, 3))
42 | def sample_batch(self, rng: jnp.ndarray, batch_size: int,
43 | length: int) -> task.Batch:
44 | """Returns a batch of strings and their reversed version."""
45 | batch = super().sample_batch(rng, batch_size, length)
46 | batch['output'] = jnp.flip(batch['input'], axis=1)
47 | return batch
48 |
49 | def output_length(self, input_length: int) -> int:
50 | """Returns the output length for a given input length."""
51 | return input_length
52 |
--------------------------------------------------------------------------------
/tasks/dcf/solve_equation.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Solve for the value of an unknown variable in an equation."""
17 |
18 | import collections
19 | from typing import Sequence
20 |
21 | import jax.nn as jnn
22 | import jax.numpy as jnp
23 | import numpy as np
24 | import tqdm
25 | import tree
26 |
27 | from neural_networks_chomsky_hierarchy.tasks import task
28 | from neural_networks_chomsky_hierarchy.tasks.dcf import modular_arithmetic_brackets as mab
29 |
30 |
31 | def generate_equation_and_solution(
32 | modulus: int,
33 | length: int,
34 | ) -> tuple[str, int]:
35 | """Returns a modular arithmetic equation with brackets, and its solution.
36 |
37 |   The values are in {0, 1, ..., modulus-1}, and the unknown value is x. The
38 |   allowed operations are {+, -}: multiplication is not used because an
39 |   equation in which 'x' is multiplied by 0 could have multiple valid
40 |   solutions.
41 |
42 | Args:
43 | modulus: The modulus to use for the expression.
44 | length: The length of the expression.
45 |
46 | Raises:
47 | ValueError if the length is < 3.
48 | """
49 |
50 | # Generate the expression.
51 | expr, val = mab.generate_one_expression_and_result(
52 | modulus,
53 | length - 2,
54 | # We use mult=False by default here, otherwise equations could have
55 | # multiple solutions if the variable 'x' or some expression containing the
56 | # variable is multiplied by 0.
57 | mult=False,
58 | )
59 |
60 | # Replace random digit with 'x'.
61 | idx = np.random.randint(low=0, high=len(expr))
62 | digits = [str(n) for n in range(modulus)]
63 | while expr[idx] not in digits:
64 | idx = (idx + 1) % (length - 2)
65 | solution = int(expr[idx])
66 | equation = f'{expr[:idx]}x{expr[idx + 1:]}={val}'
67 | return equation, solution
68 |
69 |
70 | def generate_raw_dataset(
71 | n: int,
72 | lengths: Sequence[int],
73 | modulus: int,
74 | with_tqdm: bool = False,
75 | ) -> dict[int, dict[str, np.ndarray]]:
76 | """Generates a dataset of equations and their solutions.
77 |
78 | Args:
79 | n: The number of datapoints in the dataset.
80 | lengths: The lengths of the sequences to generate. n is evenly distributed
81 | over these lengths.
82 | modulus: Modulus used to compute the expressions.
83 | with_tqdm: As the computation might be long, whether to add a tqdm progress
84 | bar or not.
85 |
86 | Returns:
87 |     A dict whose keys are the passed lengths, and whose values are dicts with
88 |     keys 'equations' and 'solutions' mapping to the data numpy arrays.
89 | """
90 | alphabet_to_int = {
91 | '+': modulus,
92 | '-': modulus + 1,
93 | '(': modulus + 2,
94 | ')': modulus + 3,
95 | 'x': modulus + 4,
96 | '=': modulus + 5,
97 | }
98 | for x in range(modulus):
99 | alphabet_to_int[str(x)] = x
100 |
101 | sequences = collections.defaultdict(lambda: { # pylint: disable=g-long-lambda
102 | 'equations': [],
103 | 'solutions': []
104 | })
105 | range_lengths = tqdm.tqdm(lengths) if with_tqdm else lengths
106 | for length in range_lengths:
107 | for _ in range(n // len(lengths)):
108 | seq, label = generate_equation_and_solution(modulus, length)
109 | seq = [alphabet_to_int[x] for x in seq]
110 | sequences[length]['equations'].append(seq)
111 | sequences[length]['solutions'].append(label)
112 | # Convert the list of numbers we have to arrays at the leaves.
113 | sequences = tree.traverse(
114 | lambda l: np.array(l, dtype=np.int32) if isinstance(l, list) else l,
115 | sequences,
116 | top_down=False,
117 | )
118 | return dict(sequences)
119 |
120 |
121 | class SolveEquation(task.GeneralizationTask):
122 |   """A task with the goal of solving a modular equation for an unknown.
123 |
124 | Note that the equations do not contain any multiplication as it could lead to
125 | multiple solutions (multiplication by zero).
126 | """
127 |
128 | def __init__(self, modulus: int = 5) -> None:
129 | """Initializes the modular arithmetic task.
130 |
131 | Args:
132 | modulus: The modulus used for the computation. We use 5 in the paper.
133 | """
134 | self._modulus = modulus
135 |
136 | def sample_batch(
137 | self,
138 | rng: jnp.ndarray,
139 | batch_size: int,
140 | length: int,
141 | ) -> task.Batch:
142 | """Returns a batch of inputs/outputs."""
143 | if length < 3:
144 | return {
145 | 'input':
146 | jnn.one_hot(
147 | jnp.zeros((batch_size, length)), num_classes=self.input_size),
148 | 'output':
149 | jnn.one_hot(
150 | jnp.zeros((batch_size,)), num_classes=self.output_size)
151 | }
152 | batch = generate_raw_dataset(
153 | batch_size,
154 | lengths=[length],
155 | modulus=self._modulus,
156 | )[length]
157 | inputs = jnn.one_hot(batch['equations'], self.input_size)
158 | output = jnn.one_hot(batch['solutions'], self.output_size)
159 | return {'input': inputs, 'output': output}
160 |
161 | @property
162 | def input_size(self) -> int:
163 | """Returns the input size for the models."""
164 | return self._modulus + 6
165 |
166 | @property
167 | def output_size(self) -> int:
168 | """Returns the output size for the models."""
169 | return self._modulus
170 |
--------------------------------------------------------------------------------
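
A similar illustrative check (not from the repository): substituting the
returned solution back in for 'x' must make the left-hand side evaluate to the
right-hand side modulo 5.

import numpy as np

from neural_networks_chomsky_hierarchy.tasks.dcf import solve_equation

np.random.seed(0)
equation, solution = solve_equation.generate_equation_and_solution(
    modulus=5, length=10)
lhs, rhs = equation.split('=')
assert eval(lhs.replace('x', str(solution))) % 5 == int(rhs)  # pylint: disable=eval-used
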
/tasks/dcf/stack_manipulation.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Manipulate an input stack, using the input actions."""
17 |
18 | import chex
19 | import jax.nn as jnn
20 | import jax.numpy as jnp
21 | import numpy as np
22 |
23 | from neural_networks_chomsky_hierarchy.tasks import task
24 |
25 |
26 | class StackManipulation(task.GeneralizationTask):
27 | """A task with the goal of following instructions and returning the end stack.
28 |
29 | The input is composed of a stack of 0s and 1s followed by a sequence of
30 | instructions POP/PUSH 0/PUSH 1 (represented by 2s/3s/4s). The input stack is
31 | given bottom-to-top, and the agent needs to execute the instructions given
32 | (left-to-right) and output the final stack top-to-bottom (i.e., as if it were
33 | popping the final stack). If a POP action is to be called on an empty stack,
34 | the action is ignored. The output is padded with 0s to match the input length
35 | + 1 (to accommodate the termination token), and the end of the final stack
36 | is denoted with the termination symbol 2 (i.e., the output has values in {0,
37 | 1, 2}).
38 |
39 | Examples:
40 | 0 1 1 0 PUSH 1 POP POP
41 | initial 0 1 1 0 (the stack is received bottom-to-top)
42 | PUSH 1 0 1 1 0 1
43 | POP 0 1 1 0
44 | POP 0 1 1
45 | -> 1 1 0 2 0 0 0 0 (the stack is returned top-to-bottom)
46 |
47 | 1 1 0 POP POP POP
48 | initial 1 1 0
49 | POP 1 1
50 | POP 1
51 | POP
52 | -> 2 0 0 0 0 0 0 (the stack is empty and padded with zeros)
53 | """
54 |
55 | def _sample_expression_and_result(
56 | self, length: int
57 | ) -> tuple[np.ndarray, list[int]]:
58 | """Returns an expression with stack instructions, and the result stack."""
59 | if length == 1:
60 | value = np.random.randint(low=0, high=2, size=(1,))
61 | return value, list(value)
62 |
63 | # Initialize the stack content and the actions (POP/PUSH).
64 | stack_length = np.random.randint(low=1, high=length)
65 | stack = np.random.randint(low=0, high=2, size=(stack_length,))
66 | actions = np.random.randint(low=2, high=5, size=(length - stack_length,))
67 |
68 | # Apply the actions on the stack.
69 | current_stack = list(stack)
70 |
71 | for action in actions:
72 | if action == 2: # POP
73 | if current_stack:
74 | current_stack.pop()
75 | elif action in [3, 4]: # PUSH a 0 (case 3) or a 1 (case 4)
76 | current_stack.append(action - 3)
77 |
78 | return np.concatenate([stack, actions]), current_stack[::-1]
79 |
80 | def sample_batch(self, rng: chex.PRNGKey, batch_size: int,
81 | length: int) -> task.Batch:
82 | """Returns a batch of strings and the expected class."""
83 | expressions, results = [], []
84 | for _ in range(batch_size):
85 | expression, result = self._sample_expression_and_result(length)
86 | expressions.append(expression)
87 | # Append the termination token to the result.
88 | result += [self.output_size - 1]
89 | # Pad the result with zeros to match the input length (accounting for the
90 | # termination token).
91 | result += [0] * (length + 1 - len(result))
92 | results.append(result)
93 | expressions = jnp.array(expressions)
94 | results = jnp.array(results)
95 |
96 | inputs = jnn.one_hot(expressions, self.input_size)
97 | output = jnn.one_hot(results, self.output_size)
98 | return {'input': inputs, 'output': output}
99 |
100 | @property
101 | def input_size(self) -> int:
102 | """Returns the input size for the models.
103 |
104 | The value is 5 because we have two possible tokens in the stack (0, 1), plus
105 | three tokens to describe the PUSH 0, PUSH 1, and POP actions.
106 | """
107 | return 5
108 |
109 | @property
110 | def output_size(self) -> int:
111 | """Returns the output size for the models."""
112 | return 3
113 |
114 | def output_length(self, input_length: int) -> int:
115 | """Returns the output length of the task."""
116 | return input_length + 1
117 |
118 | def accuracy_mask(self, target: chex.Array) -> chex.Array:
119 | """Computes mask that ignores everything after the termination tokens.
120 |
121 | Args:
122 | target: Target tokens of shape `(batch_size, output_length, output_size)`.
123 |
124 | Returns:
125 | The mask of shape `(batch_size, output_length)`.
126 | """
127 | batch_size, length, _ = target.shape
128 | termination_indices = jnp.argmax(
129 | jnp.argmax(target, axis=-1),
130 | axis=-1,
131 | keepdims=True,
132 | )
133 | indices = jnp.tile(jnp.arange(length), (batch_size, 1))
134 | return indices <= termination_indices
135 |
--------------------------------------------------------------------------------
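An illustrative standalone check (plain numpy, not repository code) of the double-argmax trick used in accuracy_mask above to locate the termination token: the per-position argmax recovers token ids, and a second argmax finds the first occurrence of the maximum token, which is always the termination symbol 2.

import numpy as np

# One output row from the first docstring example: final stack 1 1 0, then the
# termination token 2, then zero padding.
target_tokens = np.array([1, 1, 0, 2, 0, 0, 0, 0])
one_hot_target = np.eye(3)[target_tokens]            # shape (8, 3)

tokens = np.argmax(one_hot_target, axis=-1)          # back to token ids
termination_index = np.argmax(tokens)                # first occurrence of 2 -> 3
mask = np.arange(len(tokens)) <= termination_index   # [True]*4 + [False]*4
print(termination_index, mask)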
/tasks/regular/cycle_navigation.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Compute the final state after randomly walking on a circle."""
17 |
18 | import functools
19 |
20 | import chex
21 | import jax
22 | import jax.nn as jnn
23 | import jax.numpy as jnp
24 | import jax.random as jrandom
25 |
26 | from neural_networks_chomsky_hierarchy.tasks import task
27 |
28 |
29 | class CycleNavigation(task.GeneralizationTask):
30 | """A task with the goal of computing the final state on a circle.
31 |
32 | The input is a string of actions, composed of 0s, 1s or -1s. The actions give
33 | directions to take on a finite length circle (0 is for stay, 1 is for right,
34 | -1 is for left). The goal is to give the final position on the circle after
35 | all the actions have been taken. The agent starts at position 0.
36 |
37 | By default, the length of the circle is 5.
38 |
39 | Examples:
40 | 1 -1 0 -1 -1 -> -2 = class 3
41 | 1 1 1 -1 -> 2 = class 2
42 |
43 | Note that the sampling is jittable so it is fast.
44 | """
45 |
46 | @property
47 | def _cycle_length(self) -> int:
48 | """Returns the cycle length, number of possible states."""
49 | return 5
50 |
51 | @functools.partial(jax.jit, static_argnums=(0, 2, 3))
52 | def sample_batch(self, rng: chex.PRNGKey, batch_size: int,
53 | length: int) -> task.Batch:
54 | """Returns a batch of strings and the expected class."""
55 | actions = jrandom.randint(
56 | rng, shape=(batch_size, length), minval=0, maxval=3)
57 | final_states = jnp.sum(actions - 1, axis=1) % self._cycle_length
58 | final_states = jnn.one_hot(final_states, num_classes=self.output_size)
59 | one_hot_strings = jnn.one_hot(actions, num_classes=self.input_size)
60 | return {"input": one_hot_strings, "output": final_states}
61 |
62 | @property
63 | def input_size(self) -> int:
64 | """Returns the input size for the models."""
65 | return 3
66 |
67 | @property
68 | def output_size(self) -> int:
69 | """Returns the output size for the models."""
70 | return self._cycle_length
71 |
--------------------------------------------------------------------------------
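An illustrative check (plain numpy, not repository code) of the label computation above: encoded actions in {0, 1, 2} become moves in {-1, 0, +1} via `actions - 1`, and the final class is the summed move modulo the cycle length.

import numpy as np

moves = np.array([1, -1, 0, -1, -1])   # first docstring example, already as moves
final_class = int(np.sum(moves)) % 5   # -2 % 5 == 3, matching "class 3"
print(final_class)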
/tasks/regular/even_pairs.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Compute whether the number of 01's and 10's is even."""
17 |
18 | import functools
19 |
20 | import jax
21 | from jax import nn as jnn
22 | from jax import numpy as jnp
23 | from jax import random as jrandom
24 |
25 | from neural_networks_chomsky_hierarchy.tasks import task
26 |
27 |
28 | class EvenPairs(task.GeneralizationTask):
29 | """A task with the goal of checking whether the number of 01s and 10s is even.
30 |
31 | The input is a binary string, composed of 0s and 1s. If that count is even,
32 | the class is 0, otherwise it's 1.
33 |
34 | Examples:
35 | 001110 -> 1 '10' and 1 '01' -> class 0
36 | 0101001 -> 2 '10' and 3 '01' -> class 1
37 |
38 | Note the sampling is jittable so this task is fast.
39 | """
40 |
41 | @functools.partial(jax.jit, static_argnums=(0, 2, 3))
42 | def sample_batch(self, rng: jnp.ndarray, batch_size: int,
43 | length: int) -> task.Batch:
44 | """Returns a batch of strings and the expected class."""
45 | strings = jrandom.randint(
46 | rng,
47 | shape=(batch_size, length),
48 | minval=0,
49 | maxval=2,
50 | )
51 | one_hot_strings = jnn.one_hot(strings, num_classes=2)
52 | unequal_pairs = jnp.logical_xor(strings[:, :-1], strings[:, 1:])
53 | odd_unequal_pairs = jnp.sum(unequal_pairs, axis=-1) % 2
54 | return {
55 | 'input': one_hot_strings,
56 | 'output': jnn.one_hot(odd_unequal_pairs, num_classes=self.output_size),
57 | }
58 |
59 | @property
60 | def input_size(self) -> int:
61 | """Returns the input size for the models."""
62 | return 2
63 |
64 | @property
65 | def output_size(self) -> int:
66 | """Returns the output size for the models."""
67 | return 2
68 |
--------------------------------------------------------------------------------
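An illustrative aside on why the XOR-and-sum computation above works: the parity of the number of unequal adjacent pairs depends only on whether the first and last bits differ, which is what makes EvenPairs a regular task. The snippet (assumed, plain numpy) checks this on both docstring examples.

import numpy as np

for bits in ([0, 0, 1, 1, 1, 0], [0, 1, 0, 1, 0, 0, 1]):
  s = np.array(bits)
  unequal_pairs = np.logical_xor(s[:-1], s[1:])
  label = int(np.sum(unequal_pairs)) % 2
  assert label == int(s[0] != s[-1])  # the parity reduces to first vs last bit
  print(bits, '-> class', label)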
/tasks/regular/modular_arithmetic.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Modular arithmetic without brackets.
17 |
18 | Note that samples can be generated with a jittable function, so this task is
19 | much faster than its 'brackets' counterpart, whose sampling requires simulating
20 | the full context-free grammar and is therefore not jittable.
21 | """
22 |
23 | import functools
24 | from typing import Optional, Sequence
25 |
26 | import jax
27 | import jax.nn as jnn
28 | import jax.numpy as jnp
29 | import jax.random as jrandom
30 |
31 | from neural_networks_chomsky_hierarchy.tasks import task
32 |
33 | # Public as this may be used to encode/decode strings of numbers/symbols.
34 | OP_BY_CHARACTER = {'+': 0, '-': 1, '*': 2, '_': 3}
35 |
36 |
37 | def _replace_subtractions(expression: jnp.ndarray, modulus: int) -> jnp.ndarray:
38 | """Replaces subtractions in an expression by additions with the inverse.
39 |
40 | e.g. the expression [1, -, 3] results in [1, +, -3].
41 |
42 | Args:
43 | expression: Encoded expression (a 1D array of integers) in which to replace
44 | subtractions.
45 | modulus: The modulus to use for the modular arithmetic.
46 |
47 | Returns:
48 | The expression with all subtractions replaced by additions with the inverse.
49 | """
50 | if expression.size < 2:
51 | return expression
52 |
53 | mask = (expression == modulus + OP_BY_CHARACTER['-'])
54 | subtract_replaced = jnp.where(mask, modulus + OP_BY_CHARACTER['+'],
55 | expression)
56 | return subtract_replaced.at[2:].multiply(1 - 2 * mask[1:-1])
57 |
58 |
59 | def _perform_multiplications(expression: jnp.ndarray,
60 | modulus: int) -> jnp.ndarray:
61 | """Performs all multiplications in an expression containing only + and *.
62 |
63 | This is done at fixed length and the result is zero-padded to achieve this.
64 | Since the result of performing multiplications is an expression containing
65 | only + operators, the operators are dropped from the output. For example, the
66 | expression [1, +, 3, *, 4] results in [1, 12, 0].
67 |
68 | Args:
69 | expression: Encoded expression in which to perform multiplications.
70 | modulus: The modulus to use for the modular arithmetic.
71 |
72 | Returns:
73 | An array with the results of the multiplications (potentially zero-padded).
74 | """
75 | term_ids = jnp.cumsum(expression == modulus + OP_BY_CHARACTER['+'])[::2]
76 | # Segment_prod can only be jit-compiled with a fixed number of segments.
77 | # Therefore, we have to set to the maximum number of terms possible and
78 | # mask out superfluous segment results with zeros afterwards.
79 | maximum_term_number = expression.shape[0] // 2 + 1
80 | products = jax.ops.segment_prod(
81 | expression[::2],
82 | term_ids,
83 | num_segments=maximum_term_number,
84 | indices_are_sorted=True)
85 | valid_segment_mask = jnp.arange(maximum_term_number) <= term_ids[-1]
86 | return products * valid_segment_mask
87 |
88 |
89 | def _replace_blanks(expression: jnp.ndarray, modulus: int) -> jnp.ndarray:
90 | """Replaces blank symbols in expression with either `+` or `0`.
91 |
92 | Depending on whether the blank symbol is at the position of an operator or a
93 | residual, the blank symbol is replaced with a `+` operator or a `0`.
94 |
95 | Args:
96 | expression: Encoded expression in which to replace blank symbols.
97 | modulus: The modulus to use for the modular arithmetic.
98 |
99 | Returns:
100 | An array with blank symbols replaced by either `+` or `0`.
101 | """
102 | mask = (expression == OP_BY_CHARACTER['_'] + modulus)
103 | operator_mask = mask.at[::2].set(False)
104 | residual_mask = mask.at[1::2].set(False)
105 |
106 | blanks_replaced = jnp.where(operator_mask, OP_BY_CHARACTER['+'] + modulus,
107 | expression)
108 | blanks_replaced = jnp.where(residual_mask, 0, blanks_replaced)
109 | return blanks_replaced
110 |
111 |
112 | def _evaluate_expression(expression: jnp.ndarray, modulus: int) -> jnp.ndarray:
113 | """Returns the result of evaluating a modular arithmetic expression."""
114 | expression = _replace_blanks(expression, modulus)
115 | expression = _replace_subtractions(expression, modulus)
116 | additive_terms = _perform_multiplications(expression, modulus)
117 | return jnp.sum(additive_terms) % modulus
118 |
119 |
120 | class ModularArithmetic(task.GeneralizationTask):
121 | """A task with the goal of reducing a simple arithmetic expression.
122 |
123 | The input is a string, composed of numbers (in {0, ..., modulus-1}), and
124 | operators (in {+, -, *}). The output is the reduced value of this expression,
125 | which is also in {0, ..., modulus-1}.
126 |
127 | Examples (modulo 5):
128 | 1 + 2 * 3 = 2
129 | 1 - 1 - 1 = 4
130 | 0 * 1 + 4 * 3 - 2 = 0
131 |
132 | Note that the input strings are always of odd length.
133 | """
134 |
135 | def __init__(
136 | self,
137 | modulus: int = 5,
138 | operators: Optional[Sequence[str]] = None,
139 | ) -> None:
140 | """Initializes the modular arithmetic task.
141 |
142 | Args:
143 | modulus: The modulus used for the computation. We use 5 in the paper.
144 | operators: Operators to be used in the sequences. By default it's None,
145 | meaning all available operators are used.
146 | """
147 | self._modulus = modulus
148 | if operators is None:
149 | operators = ('+', '*', '-')
150 | self._operators = tuple(OP_BY_CHARACTER[op] for op in operators)
151 |
152 | @functools.partial(jax.jit, static_argnums=(0, 2, 3))
153 | def sample_batch(
154 | self,
155 | rng: jnp.ndarray,
156 | batch_size: int,
157 | length: int,
158 | ) -> task.Batch:
159 | """Returns a batch of modular arithmetic expressions and their labels.
160 |
161 | Args:
162 | rng: The jax random number generator.
163 | batch_size: The size of the batch returned.
164 | length: The length of the sequence. As this length must be odd for the
165 | modular arithmetic dataset, if it's not, we force it to be by
166 | subtracting one from the length passed.
167 | """
168 | # Subtract one from the length if it's not already odd.
169 | if length % 2 != 1:
170 | length -= 1
171 |
172 | batch = jnp.empty((batch_size, length), dtype=int)
173 | rng1, rng2 = jax.random.split(rng)
174 | remainders = jax.random.randint(rng1,
175 | (batch_size, length // 2 + 1), 0,
176 | self._modulus)
177 | ops = self._modulus + jnp.array(list(self._operators))
178 |
179 | operations = jrandom.choice(rng2, ops, (batch_size, length // 2))
180 | batch = batch.at[:, ::2].set(remainders)
181 | expressions = batch.at[:, 1::2].set(operations)
182 |
183 | evaluate = functools.partial(_evaluate_expression, modulus=self._modulus)
184 | labels = jax.vmap(evaluate)(expressions)
185 | labels = jnn.one_hot(labels, self._modulus)
186 | one_hot_expressions = jnn.one_hot(expressions,
187 | self._modulus + len(OP_BY_CHARACTER))
188 | return {'input': one_hot_expressions, 'output': labels}
189 |
190 | @property
191 | def input_size(self) -> int:
192 | """Returns the input size for the models."""
193 | return self._modulus + len(OP_BY_CHARACTER)
194 |
195 | @property
196 | def output_size(self) -> int:
197 | """Returns the output size for the models."""
198 | return self._modulus
199 |
--------------------------------------------------------------------------------
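A worked check (illustrative; it calls the module-private helper directly, which is only done here for demonstration) of the evaluation pipeline above on the docstring example `1 + 2 * 3 = 2` modulo 5. Numbers encode as themselves and an operator `op` encodes as `modulus + OP_BY_CHARACTER[op]`.

import jax.numpy as jnp

from neural_networks_chomsky_hierarchy.tasks.regular import modular_arithmetic

modulus = 5
# "1 + 2 * 3": '+' encodes as 5 + 0 = 5 and '*' as 5 + 2 = 7.
expression = jnp.array([1, 5, 2, 7, 3])
result = modular_arithmetic._evaluate_expression(expression, modulus)
print(result)  # 2 = (1 + 2 * 3) mod 5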
/tasks/regular/parity_check.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Compute whether the number of 1s in a string is even."""
17 |
18 | import functools
19 |
20 | import jax
21 | import jax.nn as jnn
22 | import jax.numpy as jnp
23 | import jax.random as jrandom
24 |
25 | from neural_networks_chomsky_hierarchy.tasks import task
26 |
27 |
28 | class ParityCheck(task.GeneralizationTask):
29 | """A task with the goal of counting the number of '1' in a string, modulo 2.
30 |
31 | The input is a string, composed of 0s and 1s. If the result is even, the class
32 | is 0, otherwise it's 1.
33 |
34 | Examples:
35 | 1010100 -> 3 1s (odd) -> class 1
36 | 01111 -> 4 1s (even) -> class 0
37 |
38 | Note that the sampling is jittable so this task is fast.
39 | """
40 |
41 | @functools.partial(jax.jit, static_argnums=(0, 2, 3))
42 | def sample_batch(self, rng: jnp.ndarray, batch_size: int,
43 | length: int) -> task.Batch:
44 | """Returns a batch of strings and the expected class."""
45 | strings = jrandom.randint(
46 | rng, shape=(batch_size, length), minval=0, maxval=2)
47 | n_b = jnp.sum(strings, axis=1) % 2
48 | n_b = jnn.one_hot(n_b, num_classes=2)
49 | one_hot_strings = jnn.one_hot(strings, num_classes=2)
50 | return {"input": one_hot_strings, "output": n_b}
51 |
52 | @property
53 | def input_size(self) -> int:
54 | """Returns the input size for the models."""
55 | return 2
56 |
57 | @property
58 | def output_size(self) -> int:
59 | """Returns the output size for the models."""
60 | return 2
61 |
--------------------------------------------------------------------------------
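A minimal sampling sketch for the ParityCheck task above; the batch size and length are arbitrary illustrative choices, and the static arguments are passed positionally because sample_batch is jitted with static_argnums.

import jax.random as jrandom

from neural_networks_chomsky_hierarchy.tasks.regular import parity_check

parity_task = parity_check.ParityCheck()
batch = parity_task.sample_batch(jrandom.PRNGKey(0), 32, 10)  # batch_size, length
print(batch['input'].shape, batch['output'].shape)  # (32, 10, 2) (32, 2)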
/tasks/task.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Base class for length generalization tasks."""
17 |
18 | import abc
19 | from typing import TypedDict
20 |
21 | import chex
22 | import jax.nn as jnn
23 | import jax.numpy as jnp
24 |
25 | Batch = TypedDict('Batch', {'input': chex.Array, 'output': chex.Array})
26 |
27 |
28 | class GeneralizationTask(abc.ABC):
29 | """A task for the generalization project.
30 |
31 | Exposes a sample_batch method, and some details about input/output sizes,
32 | losses and accuracies.
33 | """
34 |
35 | @abc.abstractmethod
36 | def sample_batch(self, rng: chex.PRNGKey, batch_size: int,
37 | length: int) -> Batch:
38 | """Returns a batch of inputs/outputs."""
39 |
40 | def pointwise_loss_fn(self, output: chex.Array,
41 | target: chex.Array) -> chex.Array:
42 | """Returns the pointwise loss between an output and a target."""
43 | return -target * jnn.log_softmax(output)
44 |
45 | def accuracy_fn(self, output: chex.Array, target: chex.Array) -> chex.Array:
46 | """Returns the accuracy between an output and a target."""
47 | return (jnp.argmax(output,
48 | axis=-1) == jnp.argmax(target,
49 | axis=-1)).astype(jnp.float32)
50 |
51 | def accuracy_mask(self, target: chex.Array) -> chex.Array:
52 | """Returns a mask to compute the accuracies, to remove the superfluous ones."""
53 | # Target is a tensor of shape (B, T, C), where C is the number of classes.
54 | # We want a mask per input (B, T), so we take this shape.
55 | return jnp.ones(target.shape[:-1])
56 |
57 | @property
58 | @abc.abstractmethod
59 | def input_size(self) -> int:
60 | """Returns the size of the input of the models trained on this task."""
61 |
62 | @property
63 | @abc.abstractmethod
64 | def output_size(self) -> int:
65 | """Returns the size of the output of the models trained on this task."""
66 |
67 | def output_length(self, input_length: int) -> int:
68 | """Returns the length of the output, given an input length."""
69 | del input_length
70 | return 1
71 |
--------------------------------------------------------------------------------
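A hypothetical minimal subclass (the name `AllZeros` and its behavior are invented for illustration) sketching the GeneralizationTask contract above: implement sample_batch plus the input_size/output_size properties, and inherit the default loss, accuracy, accuracy mask and output length.

import chex
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom

from neural_networks_chomsky_hierarchy.tasks import task


class AllZeros(task.GeneralizationTask):
  """Toy task: every binary string is mapped to class 0."""

  def sample_batch(self, rng: chex.PRNGKey, batch_size: int,
                   length: int) -> task.Batch:
    strings = jrandom.randint(rng, (batch_size, length), minval=0, maxval=2)
    labels = jnp.zeros((batch_size,))  # every string belongs to class 0
    return {
        'input': jnn.one_hot(strings, num_classes=self.input_size),
        'output': jnn.one_hot(labels, num_classes=self.output_size),
    }

  @property
  def input_size(self) -> int:
    return 2

  @property
  def output_size(self) -> int:
    return 2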